# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def apply_bitmask_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_bitmask: T.handle):
        # Mask logits in place using a packed bitmask.
        #
        # For every pair (vs, vv) with vs in [0, num_seq) and vv in [0, vocab_size),
        # the kernel looks at row `seq_ids[vs]` and keeps logits[row, vv] when bit
        # (vv % 32) of the int32 word bitmask[row, vv // 32] is 1; otherwise the
        # logit is overwritten with a very large negative float32 constant.
        #
        # Parameters (all opaque handles, bound to buffers below):
        #   var_logits  -> logits  (batch_size, vocab_size); no dtype is given, so the
        #                  T.match_buffer default applies (the value stored is float32)
        #   var_seq_ids -> seq_ids (num_seq,) int32; used as row indices into both
        #                  logits and bitmask
        #   var_bitmask -> bitmask (batch_size, ceil(vocab_size / 32)) int32; one bit
        #                  per vocabulary entry
        #
        # NOTE(review): seq_ids values are not range-checked here; this assumes each
        # entry lies in [0, batch_size) -- confirm against the caller.

        # Function attributes: Vulkan device target with an LLVM CPU host; the function
        # is flagged as already scheduled and its buffers as non-aliasing.
        # NOTE(review): the target lists max_num_threads=256 while the loop below binds
        # 1024 threads to threadIdx.x (equal to max_threads_per_block) -- confirm the
        # device accepts this launch size.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        # Symbolic int32 extents shared by the buffer shapes below.
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_seq = T.int32(is_size_var=True)
        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
        # 32 vocabulary entries are packed into each int32 word, rounded up.
        bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), "int32")
        # with T.block("root"):
        # The (num_seq x vocab_size) iteration space is flattened and split into
        # chunks of 1024: the outer loop is bound to blockIdx.x, the inner to threadIdx.x.
        for fused_s_v_0 in T.thread_binding((num_seq * vocab_size + 1023) // 1024, thread="blockIdx.x"):
            for fused_s_v_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    # Recover the 2-D coordinates from the flattened index.
                    vs = T.axis.spatial(num_seq, (fused_s_v_0 * 1024 + fused_s_v_1) // vocab_size)
                    vv = T.axis.spatial(vocab_size, (fused_s_v_0 * 1024 + fused_s_v_1) % vocab_size)
                    # Guard for the last chunk, which may extend past num_seq * vocab_size.
                    T.where(fused_s_v_0 * 1024 + fused_s_v_1 < num_seq * vocab_size)
                    T.reads(bitmask[seq_ids[vs], vv // 32], seq_ids[vs], logits[seq_ids[vs], vv])
                    T.writes(logits[seq_ids[vs], vv])
                    # Bit set -> keep the logit; bit clear -> store the negative sentinel
                    # (the literal is the lowest finite float32 value, about -3.4028235e38).
                    logits[seq_ids[vs], vv] = T.if_then_else(T.bitwise_and(T.shift_right(bitmask[seq_ids[vs], vv // 32], vv % 32), 1) == 1, logits[seq_ids[vs], vv], T.float32(-340282346638528859811704183484516925440.0))

    @T.prim_func
    def apply_logit_bias_inplace(var_logits: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_logit_bias: T.handle):
        # Add a per-entry bias to selected logits, in place.
        #
        # For every position vp in [0, num_token):
        #   logits[pos2seq_id[vp], token_ids[vp]] += logit_bias[vp]
        #
        # Parameters (all opaque handles, bound to buffers below):
        #   var_logits     -> logits     (batch_size, vocab_size), default dtype
        #   var_pos2seq_id -> pos2seq_id (num_token,) int32; used directly as the
        #                     logits row index
        #   var_token_ids  -> token_ids  (num_token,) int32; logits column index
        #   var_logit_bias -> logit_bias (num_token,), default dtype; value to add
        #
        # NOTE(review): indices are not range-checked, and the update is a plain
        # read-modify-write. If two positions map to the same (row, token) pair the
        # result depends on thread ordering -- presumably the caller guarantees
        # unique pairs; confirm.

        # Function attributes: Vulkan device target with an LLVM CPU host; the function
        # is flagged as already scheduled and its buffers as non-aliasing.
        # NOTE(review): the target lists max_num_threads=256 while the loop below binds
        # 1024 threads to threadIdx.x (equal to max_threads_per_block) -- confirm the
        # device accepts this launch size.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        # Symbolic int32 extents shared by the buffer shapes below.
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_token = T.int32(is_size_var=True)
        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
        logit_bias = T.match_buffer(var_logit_bias, (num_token,))
        # with T.block("root"):
        # One thread per position: ceil(num_token / 1024) blocks of 1024 threads.
        for p0 in T.thread_binding((num_token + 1023) // 1024, thread="blockIdx.x"):
            for p1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    vp = T.axis.spatial(num_token, p0 * 1024 + p1)
                    # Guard for the last block, which may extend past num_token.
                    T.where(p0 * 1024 + p1 < num_token)
                    T.reads(logits[pos2seq_id[vp], token_ids[vp]], pos2seq_id[vp], token_ids[vp], logit_bias[vp])
                    T.writes(logits[pos2seq_id[vp], token_ids[vp]])
                    logits[pos2seq_id[vp], token_ids[vp]] = logits[pos2seq_id[vp], token_ids[vp]] + logit_bias[vp]

    @T.prim_func
    def apply_penalty_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_token_cnt: T.handle, var_penalties: T.handle):
        # Apply per-sequence penalties to selected logits, in place.
        #
        # For every position vp in [0, num_token), with s = pos2seq_id[vp],
        # row = seq_ids[s] and col = token_ids[vp]:
        #   1. logits[row, col] -= penalties[s, 0] + float32(token_cnt[vp]) * penalties[s, 1]
        #   2. logits[row, col] is multiplied by penalties[s, 2] if it is now negative,
        #      otherwise divided by penalties[s, 2]
        #
        # Parameters (all opaque handles, bound to buffers below):
        #   var_logits     -> logits     (batch_size, vocab_size), default dtype
        #   var_seq_ids    -> seq_ids    (num_seq,) int32; maps s to a logits row
        #   var_pos2seq_id -> pos2seq_id (num_token,) int32; maps a position to s, which
        #                     indexes both seq_ids and penalties
        #   var_token_ids  -> token_ids  (num_token,) int32; logits column index
        #   var_token_cnt  -> token_cnt  (num_token,) int32; count scaled by column 1
        #   var_penalties  -> penalties  (num_seq, 3), default dtype
        #
        # NOTE(review): the three penalty columns are presumably presence, frequency
        # and repetition penalties (in that order) -- confirm against the caller.
        # NOTE(review): indices are not range-checked and penalties[s, 2] is not
        # guarded against zero before the division -- confirm the caller's guarantees.

        # Function attributes: Vulkan device target with an LLVM CPU host; the function
        # is flagged as already scheduled and its buffers as non-aliasing.
        # NOTE(review): the target lists max_num_threads=256 while the loop below binds
        # 1024 threads to threadIdx.x (equal to max_threads_per_block) -- confirm the
        # device accepts this launch size.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        # Symbolic int32 extents shared by the buffer shapes below.
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_seq = T.int32(is_size_var=True)
        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
        num_token = T.int32(is_size_var=True)
        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
        token_cnt = T.match_buffer(var_token_cnt, (num_token,), "int32")
        penalties = T.match_buffer(var_penalties, (num_seq, 3))
        # with T.block("root"):
        # One thread per position: ceil(num_token / 1024) blocks of 1024 threads.
        for p0 in T.thread_binding((num_token + 1023) // 1024, thread="blockIdx.x"):
            for p1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    vp = T.axis.spatial(num_token, p0 * 1024 + p1)
                    # Guard for the last block, which may extend past num_token.
                    T.where(p0 * 1024 + p1 < num_token)
                    T.reads(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]], seq_ids[pos2seq_id[vp]], pos2seq_id[vp], token_ids[vp], penalties[pos2seq_id[vp], 0:3], token_cnt[vp])
                    T.writes(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]])
                    # Step 1: additive penalty (column 0 plus count-scaled column 1).
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] - (penalties[pos2seq_id[vp], 0] + T.Cast("float32", token_cnt[vp]) * penalties[pos2seq_id[vp], 1])
                    # Step 2: multiplicative penalty (column 2), applied to the value
                    # produced by step 1: negative logits are multiplied, others divided.
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] < T.float32(0.0), logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] * penalties[pos2seq_id[vp], 2], logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] / penalties[pos2seq_id[vp], 2])

    @T.prim_func(private=True)
    def argsort(var_probs: T.handle, var_argsort_gpu_v1: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(), T.int64()
        probs = T.match_buffer(var_probs, (batch_size, vocab_size), offset_factor=1)
        out_buf = T.match_buffer(var_argsort_gpu_v1, (batch_size, vocab_size), "int32", align=8)
        # with T.block("root"):
        value_buf = T.alloc_buffer((batch_size, vocab_size), align=8)
        value_swap_buf = T.alloc_buffer((batch_size, vocab_size), align=8)
        out_swap_buf = T.alloc_buffer((batch_size, vocab_size), "int32", align=8)
        with T.block("argsort_gpu"):
            T.reads()
            T.writes()
            if vocab_size > T.int64(0):
                with T.launch_thread("threadIdx.x", T.int64(256)) as threadIdx_x:
                    blockIdx_x = T.launch_thread("blockIdx.x", T.max(T.int64(1), (vocab_size + T.int64(255)) // T.int64(256)))
                    blockIdx_y = T.launch_thread("blockIdx.y", T.max(T.int64(1), batch_size))
                    if blockIdx_x * T.int64(256) + threadIdx_x < vocab_size:
                        value_buf[(blockIdx_y % batch_size * vocab_size + (blockIdx_x * T.int64(256) + threadIdx_x) + blockIdx_y // batch_size) // vocab_size, (blockIdx_y % batch_size * vocab_size + (blockIdx_x * T.int64(256) + threadIdx_x) + blockIdx_y // batch_size) % vocab_size] = probs[(blockIdx_y % batch_size * vocab_size + (blockIdx_x * T.int64(256) + threadIdx_x) + blockIdx_y // batch_size) // vocab_size, (blockIdx_y % batch_size * vocab_size + (blockIdx_x * T.int64(256) + threadIdx_x) + blockIdx_y // batch_size) % vocab_size]
                        out_buf[(blockIdx_y % batch_size * vocab_size + (blockIdx_x * T.int64(256) + threadIdx_x) + blockIdx_y // batch_size) // vocab_size, (blockIdx_y % batch_size * vocab_size + (blockIdx_x * T.int64(256) + threadIdx_x) + blockIdx_y // batch_size) % vocab_size] = T.Cast("int32", blockIdx_x * T.int64(256) + threadIdx_x)
                with T.attr(0, "hand_threaded", 0):
                    threadIdx_x = T.launch_thread("threadIdx.x", T.int64(64))
                    blockIdx_x = T.launch_thread("blockIdx.x", T.max(T.int64(1), (vocab_size + T.int64(127)) // T.int64(128)))
                    blockIdx_y = T.launch_thread("blockIdx.y", T.max(T.int64(1), batch_size))
                    temp_keys_swap = T.allocate([T.int64(128)], "float32", "shared")
                    temp_values_swap = T.allocate([T.int64(128)], "int32", "shared")
                    temp_keys = T.allocate([T.int64(1)], "float32", "local")
                    temp_values = T.allocate([T.int64(1)], "int32", "local")
                    temp_cond1 = T.allocate([T.int64(1)], "float32", "local")
                    temp_cond2 = T.allocate([T.int64(1)], "float32", "local")
                    temp_keys_swap_1 = T.Buffer((128,), data=temp_keys_swap, scope="shared")
                    temp_values_swap_1 = T.Buffer((128,), "int32", data=temp_values_swap, scope="shared")
                    for i in range(T.int64(2)):
                        if T.int64(2) * threadIdx_x + i + blockIdx_x * T.int64(128) < vocab_size:
                            temp_keys_swap_1[T.int64(2) * threadIdx_x + i] = value_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + i + blockIdx_x * T.int64(128))) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + i + blockIdx_x * T.int64(128))) % vocab_size]
                            temp_values_swap_1[T.int64(2) * threadIdx_x + i] = out_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + i + blockIdx_x * T.int64(128))) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + i + blockIdx_x * T.int64(128))) % vocab_size]
                    T.tvm_storage_sync("shared")
                    for j in range(T.min(T.int64(128), vocab_size - blockIdx_x * T.int64(128))):
                        if T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2) < T.min(T.int64(128), vocab_size - blockIdx_x * T.int64(128)) - T.int64(1):
                            temp_cond1_1 = T.Buffer((1,), data=temp_cond1, scope="local")
                            temp_cond1_1[T.int64(0)] = temp_keys_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2)]
                            temp_cond2_1 = T.Buffer((1,), data=temp_cond2, scope="local")
                            temp_cond2_1[T.int64(0)] = temp_keys_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2) + T.int64(1)]
                            if temp_cond1_1[T.int64(0)] < temp_cond2_1[T.int64(0)]:
                                temp_keys_1 = T.Buffer((1,), data=temp_keys, scope="local")
                                temp_keys_1[T.int64(0)] = temp_keys_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2)]
                                temp_keys_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2)] = temp_keys_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2) + T.int64(1)]
                                temp_keys_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2) + T.int64(1)] = temp_keys_1[T.int64(0)]
                                temp_values_1 = T.Buffer((1,), "int32", data=temp_values, scope="local")
                                temp_values_1[T.int64(0)] = temp_values_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2)]
                                temp_values_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2)] = temp_values_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2) + T.int64(1)]
                                temp_values_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2) + T.int64(1)] = temp_values_1[T.int64(0)]
                        T.tvm_storage_sync("shared")
                    for k in range(T.int64(2)):
                        if T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128) < vocab_size:
                            value_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) % vocab_size] = temp_keys_swap_1[T.int64(2) * threadIdx_x + k]
                            value_swap_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) % vocab_size] = temp_keys_swap_1[T.int64(2) * threadIdx_x + k]
                            out_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) % vocab_size] = temp_values_swap_1[T.int64(2) * threadIdx_x + k]
                            out_swap_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) % vocab_size] = temp_values_swap_1[T.int64(2) * threadIdx_x + k]
                for i_0 in range(T.if_then_else(T.bitwise_and(vocab_size, vocab_size - T.int64(1)) == T.int64(0), T.int64(64) - T.Cast("int64", T.clz(vocab_size) - 64 + 64) - T.int64(1), T.int64(64) - T.Cast("int64", T.clz(vocab_size) - 64 + 64)) - (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))):
                    threadIdx_x = T.launch_thread("threadIdx.x", T.int64(256))
                    blockIdx_x = T.launch_thread("blockIdx.x", T.max(T.int64(1), (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(1023)) // T.int64(1024)))
                    blockIdx_y = T.launch_thread("blockIdx.y", T.max(T.int64(1), batch_size * ((vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) - T.int64(1))) // T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))))))
                    if T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) < vocab_size:
                        if (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(1023)) // T.int64(1024) == T.int64(1):
                            if i_0 % T.int64(2) == T.int64(0):
                                first = T.allocate([T.int64(1)], "int64", "local")
                                mid = T.allocate([T.int64(1)], "int64", "local")
                                last = T.allocate([T.int64(1)], "int64", "local")
                                first_1 = T.Buffer((1,), "int64", data=first, scope="local")
                                first_1[T.int64(0)] = T.max(T.int64(0), threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) - (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size)))
                                last_1 = T.Buffer((1,), "int64", data=last, scope="local")
                                last_1[T.int64(0)] = T.min(threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)), T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size))
                                while first_1[T.int64(0)] < last_1[T.int64(0)]:
                                    if value_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size] <= value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size]:
                                        first_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)) + T.int64(1)
                                    else:
                                        last_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1))
                                i = T.allocate([T.int64(1)], "int64", "local")
                                j = T.allocate([T.int64(1)], "int64", "local")
                                i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                i_1[T.int64(0)] = T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]
                                j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                j_1[T.int64(0)] = T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) - last_1[T.int64(0)]
                                for i_1_1 in range(T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size)) - threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)), (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256))):
                                    if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size)) and j_1[T.int64(0)] < T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size)):
                                        if value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size] <= value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]:
                                            value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_1_1)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_1_1)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                        else:
                                            value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_1_1)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_1_1)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                    else:
                                        if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size)):
                                            value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_1_1)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_1_1)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                        else:
                                            value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_1_1)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_1_1)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                            else:
                                first = T.allocate([T.int64(1)], "int64", "local")
                                mid = T.allocate([T.int64(1)], "int64", "local")
                                last = T.allocate([T.int64(1)], "int64", "local")
                                first_1 = T.Buffer((1,), "int64", data=first, scope="local")
                                first_1[T.int64(0)] = T.max(T.int64(0), threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) - (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size)))
                                last_1 = T.Buffer((1,), "int64", data=last, scope="local")
                                last_1[T.int64(0)] = T.min(threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)), T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size))
                                while first_1[T.int64(0)] < last_1[T.int64(0)]:
                                    if value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size] <= value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size]:
                                        first_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)) + T.int64(1)
                                    else:
                                        last_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1))
                                i = T.allocate([T.int64(1)], "int64", "local")
                                j = T.allocate([T.int64(1)], "int64", "local")
                                i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                i_1[T.int64(0)] = T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]
                                j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                j_1[T.int64(0)] = T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) - last_1[T.int64(0)]
                                for i_2 in range(T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size)) - threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)), (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256))):
                                    if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size)) and j_1[T.int64(0)] < T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size)):
                                        if value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size] <= value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]:
                                            value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_2)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_2)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                        else:
                                            value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_2)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_2)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                    else:
                                        if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size)):
                                            value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_2)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_2)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                        else:
                                            value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_2)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) + T.int64(255)) // T.int64(256)) + i_2)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                        else:
                            if i_0 % T.int64(2) == T.int64(0):
                                first = T.allocate([T.int64(1)], "int64", "local")
                                mid = T.allocate([T.int64(1)], "int64", "local")
                                last = T.allocate([T.int64(1)], "int64", "local")
                                first_1 = T.Buffer((1,), "int64", data=first, scope="local")
                                first_1[T.int64(0)] = T.max(T.int64(0), blockIdx_x * T.int64(1024) - (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size)))
                                last_1 = T.Buffer((1,), "int64", data=last, scope="local")
                                last_1[T.int64(0)] = T.min(blockIdx_x * T.int64(1024), T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size))
                                while first_1[T.int64(0)] < last_1[T.int64(0)]:
                                    if value_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size] <= value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size]:
                                        first_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)) + T.int64(1)
                                    else:
                                        last_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1))
                                if i_0 % T.int64(2) == T.int64(0):
                                    first_2 = T.allocate([T.int64(1)], "int64", "local")
                                    mid_1 = T.allocate([T.int64(1)], "int64", "local")
                                    last_2 = T.allocate([T.int64(1)], "int64", "local")
                                    first_3 = T.Buffer((1,), "int64", data=first_2, scope="local")
                                    first_3[T.int64(0)] = T.max(T.int64(0), threadIdx_x * T.int64(4) - T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)]), T.int64(1024)))
                                    last_3 = T.Buffer((1,), "int64", data=last_2, scope="local")
                                    last_3[T.int64(0)] = T.min(threadIdx_x * T.int64(4), T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(1024)))
                                    while first_3[T.int64(0)] < last_3[T.int64(0)]:
                                        if value_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size] <= value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size]:
                                            first_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)) + T.int64(1)
                                        else:
                                            last_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1))
                                    i = T.allocate([T.int64(1)], "int64", "local")
                                    j = T.allocate([T.int64(1)], "int64", "local")
                                    i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                    i_1[T.int64(0)] = T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + first_3[T.int64(0)]
                                    j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                    j_1[T.int64(0)] = T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - last_3[T.int64(0)]
                                    for i_3 in range(T.min(T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(1024)) + T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)]), T.int64(1024)) - threadIdx_x * T.int64(4), T.int64(4))):
                                        if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(1024)) and j_1[T.int64(0)] < T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)]), T.int64(1024)):
                                            if value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size] <= value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]:
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                        else:
                                            if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(1024)):
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                else:
                                    first_2 = T.allocate([T.int64(1)], "int64", "local")
                                    mid_1 = T.allocate([T.int64(1)], "int64", "local")
                                    last_2 = T.allocate([T.int64(1)], "int64", "local")
                                    first_3 = T.Buffer((1,), "int64", data=first_2, scope="local")
                                    first_3[T.int64(0)] = T.max(T.int64(0), threadIdx_x * T.int64(4) - T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)]), T.int64(1024)))
                                    last_3 = T.Buffer((1,), "int64", data=last_2, scope="local")
                                    last_3[T.int64(0)] = T.min(threadIdx_x * T.int64(4), T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(1024)))
                                    while first_3[T.int64(0)] < last_3[T.int64(0)]:
                                        if value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size] <= value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size]:
                                            first_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)) + T.int64(1)
                                        else:
                                            last_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1))
                                    i = T.allocate([T.int64(1)], "int64", "local")
                                    j = T.allocate([T.int64(1)], "int64", "local")
                                    i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                    i_1[T.int64(0)] = T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + first_3[T.int64(0)]
                                    j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                    j_1[T.int64(0)] = T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - last_3[T.int64(0)]
                                    for i_4 in range(T.min(T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(1024)) + T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)]), T.int64(1024)) - threadIdx_x * T.int64(4), T.int64(4))):
                                        if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(1024)) and j_1[T.int64(0)] < T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)]), T.int64(1024)):
                                            if value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size] <= value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]:
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                        else:
                                            if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(1024)):
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                            else:
                                first = T.allocate([T.int64(1)], "int64", "local")
                                mid = T.allocate([T.int64(1)], "int64", "local")
                                last = T.allocate([T.int64(1)], "int64", "local")
                                first_1 = T.Buffer((1,), "int64", data=first, scope="local")
                                first_1[T.int64(0)] = T.max(T.int64(0), blockIdx_x * T.int64(1024) - (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size)))
                                last_1 = T.Buffer((1,), "int64", data=last, scope="local")
                                last_1[T.int64(0)] = T.min(blockIdx_x * T.int64(1024), T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size))
                                while first_1[T.int64(0)] < last_1[T.int64(0)]:
                                    if value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size] <= value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size]:
                                        first_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)) + T.int64(1)
                                    else:
                                        last_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1))
                                if i_0 % T.int64(2) == T.int64(0):
                                    first_2 = T.allocate([T.int64(1)], "int64", "local")
                                    mid_1 = T.allocate([T.int64(1)], "int64", "local")
                                    last_2 = T.allocate([T.int64(1)], "int64", "local")
                                    first_3 = T.Buffer((1,), "int64", data=first_2, scope="local")
                                    first_3[T.int64(0)] = T.max(T.int64(0), threadIdx_x * T.int64(4) - T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)]), T.int64(1024)))
                                    last_3 = T.Buffer((1,), "int64", data=last_2, scope="local")
                                    last_3[T.int64(0)] = T.min(threadIdx_x * T.int64(4), T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(1024)))
                                    while first_3[T.int64(0)] < last_3[T.int64(0)]:
                                        if value_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size] <= value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size]:
                                            first_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)) + T.int64(1)
                                        else:
                                            last_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1))
                                    i = T.allocate([T.int64(1)], "int64", "local")
                                    j = T.allocate([T.int64(1)], "int64", "local")
                                    i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                    i_1[T.int64(0)] = T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + first_3[T.int64(0)]
                                    j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                    j_1[T.int64(0)] = T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - last_3[T.int64(0)]
                                    for i_5 in range(T.min(T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(1024)) + T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)]), T.int64(1024)) - threadIdx_x * T.int64(4), T.int64(4))):
                                        if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(1024)) and j_1[T.int64(0)] < T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)]), T.int64(1024)):
                                            if value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size] <= value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]:
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                        else:
                                            if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(1024)):
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                else:
                                    first_2 = T.allocate([T.int64(1)], "int64", "local")
                                    mid_1 = T.allocate([T.int64(1)], "int64", "local")
                                    last_2 = T.allocate([T.int64(1)], "int64", "local")
                                    first_3 = T.Buffer((1,), "int64", data=first_2, scope="local")
                                    first_3[T.int64(0)] = T.max(T.int64(0), threadIdx_x * T.int64(4) - T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)]), T.int64(1024)))
                                    last_3 = T.Buffer((1,), "int64", data=last_2, scope="local")
                                    last_3[T.int64(0)] = T.min(threadIdx_x * T.int64(4), T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(1024)))
                                    while first_3[T.int64(0)] < last_3[T.int64(0)]:
                                        if value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size] <= value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size]:
                                            first_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)) + T.int64(1)
                                        else:
                                            last_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1))
                                    i = T.allocate([T.int64(1)], "int64", "local")
                                    j = T.allocate([T.int64(1)], "int64", "local")
                                    i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                    i_1[T.int64(0)] = T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + first_3[T.int64(0)]
                                    j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                    j_1[T.int64(0)] = T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - last_3[T.int64(0)]
                                    for i_6 in range(T.min(T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(1024)) + T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)]), T.int64(1024)) - threadIdx_x * T.int64(4), T.int64(4))):
                                        if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(1024)) and j_1[T.int64(0)] < T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(1024) - last_1[T.int64(0)]), T.int64(1024)):
                                            if value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size] <= value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]:
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                        else:
                                            if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(1024)):
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(1024) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                if T.if_then_else(T.bitwise_and(vocab_size, vocab_size - T.int64(1)) == T.int64(0), T.int64(64) - T.Cast("int64", T.clz(vocab_size) - 64 + 64) - T.int64(1), T.int64(64) - T.Cast("int64", T.clz(vocab_size) - 64 + 64)) > T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1) and (T.if_then_else(T.bitwise_and(vocab_size, vocab_size - T.int64(1)) == T.int64(0), T.int64(64) - T.Cast("int64", T.clz(vocab_size) - 64 + 64) - T.int64(1), T.int64(64) - T.Cast("int64", T.clz(vocab_size) - 64 + 64)) - (T.int64(32) - T.Cast("int64", T.clz(T.int64(128)) - 64 + 32) - T.int64(1))) % T.int64(2) == T.int64(1):
                    threadIdx_x = T.launch_thread("threadIdx.x", T.int64(256))
                    blockIdx_x = T.launch_thread("blockIdx.x", T.max(T.int64(1), (vocab_size + T.int64(255)) // T.int64(256)))
                    blockIdx_y = T.launch_thread("blockIdx.y", T.max(T.int64(1), batch_size))
                    if blockIdx_x * T.int64(256) + threadIdx_x < vocab_size:
                        value_buf[(blockIdx_y * vocab_size + (blockIdx_x * T.int64(256) + threadIdx_x)) // vocab_size, (blockIdx_y * vocab_size + (blockIdx_x * T.int64(256) + threadIdx_x)) % vocab_size] = value_swap_buf[(blockIdx_y * vocab_size + (blockIdx_x * T.int64(256) + threadIdx_x)) // vocab_size, (blockIdx_y * vocab_size + (blockIdx_x * T.int64(256) + threadIdx_x)) % vocab_size]
                        out_buf[(blockIdx_y * vocab_size + (blockIdx_x * T.int64(256) + threadIdx_x)) // vocab_size, (blockIdx_y * vocab_size + (blockIdx_x * T.int64(256) + threadIdx_x)) % vocab_size] = out_swap_buf[(blockIdx_y * vocab_size + (blockIdx_x * T.int64(256) + threadIdx_x)) // vocab_size, (blockIdx_y * vocab_size + (blockIdx_x * T.int64(256) + threadIdx_x)) % vocab_size]

    @T.prim_func
    def batch_decode_paged_kv(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        B = T.int32(is_size_var=True)
        Q = T.match_buffer(Q_handle, (B, 16, 128), "float16")
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(pages_handle, (max_num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (B,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1)
        output = T.match_buffer(output_handle, (B, 16, 128), "float16")
        lse = T.match_buffer(lse_handle, (B, 16))
        # with T.block("root"):
        sm_scale: T.float32 = T.float32(0.12751743082459868)
        for bx in T.thread_binding(B, thread="blockIdx.x"):
            for fused_by_bz in T.thread_binding(2, thread="blockIdx.y"):
                for ty in T.thread_binding(8, thread="threadIdx.y"):
                    for tx in T.thread_binding(32, thread="threadIdx.x"):
                        for tz in T.thread_binding(2, thread="threadIdx.z"):
                            with T.block("attn"):
                                T.reads(page_table_indptr[bx:bx + 2], length_info[bx], q_rope_position[bx], Q[bx, (fused_by_bz + 1) // 2 * 8 + ty, tx * 4 - 64:tx * 4 - 64 + 132])
                                T.writes(output[bx, fused_by_bz % 2 * 8 + fused_by_bz // 2 * 8 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 2 * 8 + fused_by_bz // 2 * 8 + ty])
                                Q_local = T.alloc_buffer((4,), "float16", scope="local")
                                kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                                K_smem = T.alloc_buffer((16, 128), "float16", scope="shared")
                                V_smem = T.alloc_buffer((16, 128), "float16", scope="shared")
                                O_allreduce = T.alloc_buffer((2, 8, 128), scope="shared")
                                md_allreduce = T.alloc_buffer((2, 8, 2), scope="shared")
                                S_reduce_local = T.alloc_buffer((1,), scope="local")
                                t0 = T.alloc_buffer((1,), scope="local")
                                S_local = T.alloc_buffer((8,), scope="local")
                                QK_local = T.alloc_buffer((4,), scope="local")
                                V_local = T.alloc_buffer((4,), "float16", scope="local")
                                m_prev = T.alloc_buffer((1,), scope="local")
                                d_prev = T.alloc_buffer((1,), scope="local")
                                other_m = T.alloc_buffer((1,), scope="local")
                                other_d = T.alloc_buffer((1,), scope="local")
                                exp_mprev = T.alloc_buffer((1,), scope="local")
                                exp_otherm = T.alloc_buffer((1,), scope="local")
                                other_o = T.alloc_buffer((4,), scope="local")
                                st_m = T.alloc_buffer((1,), scope="local")
                                st_d = T.alloc_buffer((1,), scope="local")
                                O_local = T.alloc_buffer((4,), scope="local")
                                by: T.int32 = fused_by_bz % 2
                                bz: T.int32 = fused_by_bz // 2
                                batch_idx: T.int32 = bx
                                cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx]
                                cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1]
                                kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[batch_idx], 0)
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                for vec in T.vectorized(4):
                                    freq = T.float32()
                                    Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec + 64] * T.float16(-1.0), Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec])
                                for iterator in range((kv_chunk_len[0] + 15) // 16):
                                    tile_start_s: T.int32 = tz * 8 + ty
                                    tile_start_g: T.int32 = (iterator * 2 + tz) * 8 + ty
                                    for j in range(1):
                                        with T.block("KV_load"):
                                            T.reads()
                                            T.writes()
                                            row_g: T.int32 = tile_start_g + j
                                            if row_g < kv_chunk_len[0]:
                                                seq_offset: T.int32 = row_g
                                                page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16]
                                                page_offset: T.int32 = seq_offset % 16
                                                for vec in T.vectorized(4):
                                                    freq = T.float32()
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, pages[page_no, 0, by, page_offset, tx * 4 + vec + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, tx * 4 + vec])
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec]
                                            else:
                                                for vec in T.vectorized(4):
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    m_prev[0] = st_m[0]
                                    for j in range(8):
                                        for vec in T.vectorized(4):
                                            QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz * 8 + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale
                                        S_reduce_local[0] = T.float32(0.0)
                                        for vec in T.unroll(4):
                                            S_reduce_local[0] = S_reduce_local[0] + QK_local[vec]
                                        with T.block("block_cross_thread"):
                                            T.reads(S_reduce_local[0])
                                            T.writes(t0[0])
                                            T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                            T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx)
                                        S_local[j] = T.float32(-50000.0)
                                        if (iterator * 2 + tz) * 8 + j < kv_chunk_len[0]:
                                            S_local[j] = t0[0]
                                        st_m[0] = T.max(st_m[0], S_local[j])
                                    o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0])
                                    st_d[0] = st_d[0] * o_scale
                                    for j in range(8):
                                        S_local[j] = T.exp2(S_local[j] - st_m[0])
                                        st_d[0] = st_d[0] + S_local[j]
                                    for j in T.vectorized(4):
                                        O_local[j] = O_local[j] * o_scale
                                    for j in range(8):
                                        for vec in T.vectorized(4):
                                            V_local[vec] = V_smem[tz * 8 + j, tx * 4 + vec]
                                        for vec in T.vectorized(4):
                                            O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * S_local[j]
                                for vec in T.vectorized(4):
                                    O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec]
                                md_allreduce[tz, ty, 0] = st_m[0]
                                md_allreduce[tz, ty, 1] = st_d[0]
                                T.tvm_storage_sync("shared")
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                for j in range(2):
                                    m_prev[0] = st_m[0]
                                    d_prev[0] = st_d[0]
                                    other_m[0] = md_allreduce[j, ty, 0]
                                    other_d[0] = md_allreduce[j, ty, 1]
                                    for vec in T.vectorized(4):
                                        other_o[vec] = O_allreduce[j, ty, tx * 4 + vec]
                                    st_m[0] = T.max(st_m[0], other_m[0])
                                    st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0])
                                    exp_mprev[0] = T.exp2(m_prev[0] - st_m[0])
                                    exp_otherm[0] = T.exp2(other_m[0] - st_m[0])
                                    for vec in T.vectorized(4):
                                        O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0]
                                for vec in T.vectorized(4):
                                    O_local[vec] = O_local[vec] / st_d[0]
                                for vec in T.vectorized(4):
                                    output[batch_idx, by * 8 + bz * 8 + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec])
                                lse[batch_idx, by * 8 + bz * 8 + ty] = st_m[0] + T.log2(st_d[0])

    @T.prim_func
    def batch_decode_paged_kv_sliding_window(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        # Attention kernel with one query row per batch entry, reading K/V from a paged buffer.
        #
        # For every batch entry `bx` and every one of the 16 query rows Q[bx, h, :], the kernel
        # walks the K/V rows of that entry in tiles of 16, keeps a running max (st_m), a running
        # denominator (st_d) and a running weighted sum of V (O_local), and finally writes
        #   output[bx, h, :] = O / st_d        (cast to float16)
        #   lse[bx, h]       = st_m + log2(st_d)
        # All exponentials are base 2 (T.exp2 / T.log2).
        #
        # Parameters (buffer shapes/dtypes are taken from the match_buffer calls below):
        #   _0                        : int32, not referenced anywhere in the body.
        #   Q_handle                  : (B, 16, 128) float16 queries.
        #   pages_handle              : (max_num_pages, 2, 2, 16, 128) float16; axis 1 selects
        #                               K (index 0) or V (index 1), axis 3 is the row inside a page.
        #                               NOTE(review): axis 2 (size 2) is presumably the KV head - confirm.
        #   page_table_indptr_handle  : (B + 1,) int32; [indptr[b], indptr[b + 1]) is the slice of
        #                               page_table_values that belongs to batch entry b.
        #   page_table_values_handle  : (nnz_pages,) int32 page numbers indexing axis 0 of `pages`.
        #   var_length_info           : (3, B) int32.
        #                               NOTE(review): presumably row 0 = number of valid rows in the last
        #                               page, row 1 = sliding-window offset, row 2 = attention-sink size;
        #                               only the arithmetic below is visible here - confirm with caller.
        #   k_rope_pos_offset_handle  : (B,) int32, added to the K row index to form its rotary position.
        #   q_rope_position_handle    : (B,) int32 rotary position of the query.
        #   output_handle             : (B, 16, 128) float16 result.
        #   lse_handle                : (B, 16) buffer (default dtype) receiving st_m + log2(st_d).
        #   rotary_mode               : rotary embedding is applied to Q and K only when this equals 1.
        #   rope_scale, rope_theta    : parameters of the rotary frequency formula.
        #   attn_score_scaling_factor : extra multiplier applied to every Q.K score.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        B = T.int32(is_size_var=True)
        Q = T.match_buffer(Q_handle, (B, 16, 128), "float16")
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(pages_handle, (max_num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (3, B), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1)
        output = T.match_buffer(output_handle, (B, 16, 128), "float16")
        lse = T.match_buffer(lse_handle, (B, 16))
        # with T.block("root"):
        # Constant score scale. NOTE(review): numerically ~= log2(e) / sqrt(128), which would fold the
        # 1/sqrt(head_dim) factor and the exp -> exp2 conversion into one multiply - confirm with generator.
        sm_scale: T.float32 = T.float32(0.12751743082459868)
        # Launch layout: blockIdx.x = batch entry, blockIdx.y = 2 groups of 8 query rows,
        # threadIdx.y = query row inside the group, threadIdx.x = 4-element slice of the 128-wide
        # feature axis (32 * 4 = 128), threadIdx.z = which half (8 rows) of each 16-row K/V tile.
        for bx in T.thread_binding(B, thread="blockIdx.x"):
            for fused_by_bz in T.thread_binding(2, thread="blockIdx.y"):
                for ty in T.thread_binding(8, thread="threadIdx.y"):
                    for tx in T.thread_binding(32, thread="threadIdx.x"):
                        for tz in T.thread_binding(2, thread="threadIdx.z"):
                            with T.block("attn"):
                                T.reads(page_table_indptr[bx:bx + 2], length_info[0:3, bx], q_rope_position[bx], Q[bx, (fused_by_bz + 1) // 2 * 8 + ty, tx * 4 - 64:tx * 4 - 64 + 132])
                                T.writes(output[bx, fused_by_bz % 2 * 8 + fused_by_bz // 2 * 8 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 2 * 8 + fused_by_bz // 2 * 8 + ty])
                                # Per-thread ("local") scratch and per-block ("shared") staging buffers.
                                Q_local = T.alloc_buffer((4,), "float16", scope="local")
                                kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                                # One 16 x 128 tile of K and of V, refilled on every iteration of the main loop.
                                K_smem = T.alloc_buffer((16, 128), "float16", scope="shared")
                                V_smem = T.alloc_buffer((16, 128), "float16", scope="shared")
                                # Staging for merging the two threadIdx.z partial results at the end.
                                O_allreduce = T.alloc_buffer((2, 8, 128), scope="shared")
                                md_allreduce = T.alloc_buffer((2, 8, 2), scope="shared")
                                S_reduce_local = T.alloc_buffer((1,), scope="local")
                                t0 = T.alloc_buffer((1,), scope="local")
                                S_local = T.alloc_buffer((8,), scope="local")
                                QK_local = T.alloc_buffer((4,), scope="local")
                                V_local = T.alloc_buffer((4,), "float16", scope="local")
                                m_prev = T.alloc_buffer((1,), scope="local")
                                d_prev = T.alloc_buffer((1,), scope="local")
                                other_m = T.alloc_buffer((1,), scope="local")
                                other_d = T.alloc_buffer((1,), scope="local")
                                exp_mprev = T.alloc_buffer((1,), scope="local")
                                exp_otherm = T.alloc_buffer((1,), scope="local")
                                other_o = T.alloc_buffer((4,), scope="local")
                                st_m = T.alloc_buffer((1,), scope="local")
                                st_d = T.alloc_buffer((1,), scope="local")
                                O_local = T.alloc_buffer((4,), scope="local")
                                # blockIdx.y has extent 2, so by is 0 or 1 and bz is always 0 here.
                                by: T.int32 = fused_by_bz % 2
                                bz: T.int32 = fused_by_bz // 2
                                batch_idx: T.int32 = bx
                                cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx]
                                cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1]
                                # Number of K/V rows to attend over: 0 when the entry owns no pages, otherwise
                                # (pages - 1) * 16 + length_info[0] - length_info[1] + length_info[2].
                                kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, batch_idx] - length_info[1, batch_idx] + length_info[2, batch_idx], 0)
                                # Running state: max = -50000 (also used below as the masked score),
                                # denominator = 1, weighted V sum = 0.
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                # Load this thread's 4 query elements. With rotary_mode == 1 each element d becomes
                                #   cos(freq) * q[d] + sin(freq) * (-q[d + 64] if d < 64 else q[d - 64]),
                                #   freq = q_rope_position * rope_scale / rope_theta ** (((2 * d) % 128) / 128);
                                # otherwise the raw Q value is used.
                                for vec in T.vectorized(4):
                                    freq = T.float32()
                                    Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec + 64] * T.float16(-1.0), Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec])
                                # Main loop: ceil(kv_chunk_len / 16) tiles of 16 K/V rows each.
                                for iterator in range((kv_chunk_len[0] + 15) // 16):
                                    # Row inside the shared tile / row inside the logical K/V sequence that this
                                    # (tz, ty) pair is responsible for loading.
                                    tile_start_s: T.int32 = tz * 8 + ty
                                    tile_start_g: T.int32 = (iterator * 2 + tz) * 8 + ty
                                    for j in range(1):
                                        with T.block("KV_load"):
                                            T.reads()
                                            T.writes()
                                            row_g: T.int32 = tile_start_g + j
                                            if row_g < kv_chunk_len[0]:
                                                # Rows below length_info[2] map to themselves; later rows are shifted
                                                # by length_info[1] - length_info[2] before the page lookup.
                                                seq_offset: T.int32 = T.if_then_else(row_g < length_info[2, batch_idx], row_g, row_g - length_info[2, batch_idx] + length_info[1, batch_idx])
                                                # 16 rows per page: quotient picks the page-table slot, remainder the row.
                                                page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16]
                                                page_offset: T.int32 = seq_offset % 16
                                                for vec in T.vectorized(4):
                                                    freq = T.float32()
                                                    # K gets the same optional rotary transform as Q, with position
                                                    # k_rope_pos_offset + row_g (row_g, not seq_offset). V is copied as is.
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, pages[page_no, 0, by, page_offset, tx * 4 + vec + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, tx * 4 + vec])
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec]
                                            else:
                                                # Rows past the end of the sequence are zero-filled.
                                                for vec in T.vectorized(4):
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                    # Make the freshly written K/V tile visible before it is read below.
                                    T.tvm_storage_sync("shared")
                                    m_prev[0] = st_m[0]
                                    # Scores for the 8 tile rows owned by this tz half.
                                    for j in range(8):
                                        # Partial dot product over this thread's 4 feature elements, already scaled.
                                        for vec in T.vectorized(4):
                                            QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz * 8 + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale
                                        S_reduce_local[0] = T.float32(0.0)
                                        for vec in T.unroll(4):
                                            S_reduce_local[0] = S_reduce_local[0] + QK_local[vec]
                                        # Sum the per-thread partials with an all-reduce in which tx is the reduction
                                        # thread variable; the result lands in t0[0].
                                        with T.block("block_cross_thread"):
                                            T.reads(S_reduce_local[0])
                                            T.writes(t0[0])
                                            T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                            T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx)
                                        # Rows beyond kv_chunk_len keep the -50000 sentinel instead of the real score.
                                        S_local[j] = T.float32(-50000.0)
                                        if (iterator * 2 + tz) * 8 + j < kv_chunk_len[0]:
                                            S_local[j] = t0[0]
                                        st_m[0] = T.max(st_m[0], S_local[j])
                                    # Rescale the previous denominator and output to the new running max.
                                    o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0])
                                    st_d[0] = st_d[0] * o_scale
                                    # Turn scores into unnormalised weights and add them to the denominator.
                                    for j in range(8):
                                        S_local[j] = T.exp2(S_local[j] - st_m[0])
                                        st_d[0] = st_d[0] + S_local[j]
                                    for j in T.vectorized(4):
                                        O_local[j] = O_local[j] * o_scale
                                    # Accumulate weight * V for the same 8 rows.
                                    for j in range(8):
                                        for vec in T.vectorized(4):
                                            V_local[vec] = V_smem[tz * 8 + j, tx * 4 + vec]
                                        for vec in T.vectorized(4):
                                            O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * S_local[j]
                                # Publish this tz half's partial output, max and denominator to shared memory.
                                for vec in T.vectorized(4):
                                    O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec]
                                md_allreduce[tz, ty, 0] = st_m[0]
                                md_allreduce[tz, ty, 1] = st_d[0]
                                T.tvm_storage_sync("shared")
                                # Reset the running state, then fold in the partial results of both tz halves.
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                for j in range(2):
                                    m_prev[0] = st_m[0]
                                    d_prev[0] = st_d[0]
                                    other_m[0] = md_allreduce[j, ty, 0]
                                    other_d[0] = md_allreduce[j, ty, 1]
                                    for vec in T.vectorized(4):
                                        other_o[vec] = O_allreduce[j, ty, tx * 4 + vec]
                                    # Merge two (max, denominator, output) triples by rescaling both to the new max.
                                    st_m[0] = T.max(st_m[0], other_m[0])
                                    st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0])
                                    exp_mprev[0] = T.exp2(m_prev[0] - st_m[0])
                                    exp_otherm[0] = T.exp2(other_m[0] - st_m[0])
                                    for vec in T.vectorized(4):
                                        O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0]
                                # Normalise by the merged denominator and write the results.
                                for vec in T.vectorized(4):
                                    O_local[vec] = O_local[vec] / st_d[0]
                                for vec in T.vectorized(4):
                                    output[batch_idx, by * 8 + bz * 8 + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec])
                                lse[batch_idx, by * 8 + bz * 8 + ty] = st_m[0] + T.log2(st_d[0])

    @T.prim_func
    def batch_prefill_paged_kv(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (batch_size,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (total_len, 16))
        # with T.block("root"):
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[b_idx], 0)
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_0, lj_0_0 in T.grid(2, 4):
                                                for lj_1 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, li_0 * 16 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 8)
                                                        j = T.axis.spatial(128, lj_0_0 * 32 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 8 * 4 + lj_1)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                freq = T.float32()
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, pages[page_no, 0, by, page_offset, j + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, j - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, j])
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(2):
                                                        for lj_1_0_init in T.unroll(1):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("S_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                                    j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(S_local[i, j])
                                                                    S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0 in range(8):
                                                        for li_1 in T.unroll(2):
                                                            for lj_1_0 in T.unroll(1):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    for lk_1 in range(16):
                                                                        with T.block("S_gemm_update"):
                                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                            k = T.axis.reduce(128, lk_0 * 16 + lk_1)
                                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                                            T.writes(S_local[i, j])
                                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1 in range(2):
                                                    for lj_1_0 in T.unroll(1):
                                                        for lj_1_1 in T.vectorized(4):
                                                            with T.block("S_store"):
                                                                i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                T.reads(S_local[i, j])
                                                                T.writes(S_smem[i, j])
                                                                S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(4):
                                                        for lj_1_0_init in T.unroll(2):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("O_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(O_local[i, j])
                                                                    O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1 in T.grid(2, 16):
                                                        for li_1 in T.unroll(4):
                                                            for lj_1_0 in T.unroll(2):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    with T.block("O_gemm_update"):
                                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                                        j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                                        k = T.axis.reduce(32, lk_0 * 16 + lk_1)
                                                                        T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j])
                                                                        T.writes(O_local[i, j])
                                                                        O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_store"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                            T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                            cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                            if cur_L < q_indptr[b_idx + 1]:
                                                                output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func
    def batch_prefill_paged_kv_sliding_window(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (3, batch_size), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (total_len, 16))
        # with T.block("root"):
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, b_idx] - length_info[1, b_idx] + length_info[2, b_idx], 0)
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_0, lj_0_0 in T.grid(2, 4):
                                                for lj_1 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, li_0 * 16 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 8)
                                                        j = T.axis.spatial(128, lj_0_0 * 32 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 8 * 4 + lj_1)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, b_idx] + length_info[1, b_idx])
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                freq = T.float32()
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, pages[page_no, 0, by, page_offset, j + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, j - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, j])
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, b_idx] + length_info[1, b_idx])
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(2):
                                                        for lj_1_0_init in T.unroll(1):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("S_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                                    j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(S_local[i, j])
                                                                    S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0 in range(8):
                                                        for li_1 in T.unroll(2):
                                                            for lj_1_0 in T.unroll(1):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    for lk_1 in range(16):
                                                                        with T.block("S_gemm_update"):
                                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                            k = T.axis.reduce(128, lk_0 * 16 + lk_1)
                                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                                            T.writes(S_local[i, j])
                                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1 in range(2):
                                                    for lj_1_0 in T.unroll(1):
                                                        for lj_1_1 in T.vectorized(4):
                                                            with T.block("S_store"):
                                                                i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                T.reads(S_local[i, j])
                                                                T.writes(S_smem[i, j])
                                                                S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(4):
                                                        for lj_1_0_init in T.unroll(2):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("O_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(O_local[i, j])
                                                                    O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1 in T.grid(2, 16):
                                                        for li_1 in T.unroll(4):
                                                            for lj_1_0 in T.unroll(2):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    with T.block("O_gemm_update"):
                                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                                        j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                                        k = T.axis.reduce(32, lk_0 * 16 + lk_1)
                                                                        T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j])
                                                                        T.writes(O_local[i, j])
                                                                        O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_store"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                            T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                            cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                            if cur_L < q_indptr[b_idx + 1]:
                                                                output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func
    def batch_prefill_ragged_kv(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_k_rope_pos_offset: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        # Attention kernel over ragged (indptr-delimited) batches of queries and
        # keys/values, using a running-max / running-sum softmax in base 2.
        # Launch shape: 16 x 2 blocks (blockIdx.x, blockIdx.y), each with
        # 32 x 4 threads (threadIdx.x, threadIdx.y).
        #
        # Buffers (shapes and dtypes as declared by the T.match_buffer calls below):
        #   q, output          (qo_len, 16, 128) float16
        #   k, v               (kv_len, 2, 128) float16
        #   q_indptr           (batch_size + 1,) int32; sequence b owns q rows
        #                      [q_indptr[b], q_indptr[b + 1])
        #   kv_indptr          (batch_size + 1,) int32; sequence b owns k/v rows
        #                      [kv_indptr[b], kv_indptr[b + 1])
        #   q_rope_position    (qo_len,) int32; rotary position of each q row
        #   k_rope_pos_offset  (batch_size,) int32; added to the in-sequence k row
        #                      index to form that row's rotary position
        #   lse                (qo_len, 16), no dtype given so the match_buffer
        #                      default applies; receives m + log2(d) per row, i.e.
        #                      a base-2 log-sum-exp of the row's scaled scores
        # Scalars:
        #   causal                     > 0 selects the causal mask, otherwise only
        #                              the kv-length bound is applied
        #   rotary_mode                == 1 rotates q and k while they are loaded
        #   rope_scale, rope_theta     parameters of the rotary frequency
        #   attn_score_scaling_factor  extra multiplier on every q.k product
        #
        # Axis 1 of q/output is indexed by `by * 8 + (row % 8)` and axis 1 of k/v
        # by `by`, so each blockIdx.y value pairs one k/v slice with 8 q slices
        # (presumably KV head and query heads of a grouped-query layout).
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        qo_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (qo_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        kv_len = T.int32(is_size_var=True)
        k = T.match_buffer(var_k, (kv_len, 2, 128), "float16")
        v = T.match_buffer(var_v, (kv_len, 2, 128), "float16")
        kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (qo_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (qo_len, 16))
        # with T.block("root"):
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            # Per-thread scalar state for walking the tile list.
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            # NOTE(review): `iterator` is allocated but never referenced
                            # in this function (the kv loop uses `iterator_1`).
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            # Shared tiles: 32 query rows x 128 features, and a
                            # 32-row kv chunk (K is kept transposed as 128 x 32).
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((128, 32), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            # Per-row softmax state: running max (m), previous max
                            # (m_prev) and running denominator (d).
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            # Start at tile `bx` of sequence 0. A sequence contributes
                            # (q_len * 8) rows, split into ceil(rows / 32) tiles.
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                # Skip forward over sequences until tile_id falls inside
                                # the current sequence's tile range (or sequences run out).
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    # First row of this tile in the sequence's (q_len * 8)-row space.
                                    LH_start: T.int32 = tile_id[0] * 32
                                    kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
                                    T.tvm_storage_sync("shared")
                                    # Reset softmax state. row = ty * 32 + tx, so only the
                                    # threads with ty == 0 satisfy row < 32 (one thread per row).
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    # Zero the output accumulator. With t = ty * 32 + tx, each
                                    # thread covers rows (t // 16) * 4 + [0, 4) and columns
                                    # (t % 16) * 8 + [0, 8) of the 32 x 128 tile.
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Load the 32 x 128 query tile into shared memory. Tile row i
                                    # maps to q row cur_L = q_indptr_val + (LH_start + i) // 8 and
                                    # axis-1 index cur_H_qo = by * 8 + (LH_start + i) % 8.
                                    # NOTE(review): this block declares empty T.reads()/T.writes()
                                    # although it reads q / q_rope_position and writes Q_smem.
                                    for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_0, lj_0_0 in T.grid(2, 4):
                                                for lj_1 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, li_0 * 16 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 8)
                                                        j = T.axis.spatial(128, lj_0_0 * 32 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 8 * 4 + lj_1)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        # Rows past the end of this sequence are zero-filled.
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            # `freq` is bound by the T.Let's where= clause:
                                                            #   freq = pos * rope_scale / rope_theta ** ((2 * j % 128) / 128)
                                                            # With rotary_mode == 1 the stored value is
                                                            #   cos(freq) * q[j] + sin(freq) * (-q[j + 64] if j < 64 else q[j - 64])
                                                            # otherwise q[j] is copied unchanged.
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Sweep the sequence's keys/values in chunks of 32 rows.
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        L_kv_base: T.int32 = kv_indptr[b_idx]
                                        # Load the key chunk, transposed (K_smem[feature, kv_row]),
                                        # applying the same rotary formula as for q but with
                                        # position k_rope_pos_offset[b_idx] + cur_L.
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads(kv_chunk_len[0], k_rope_pos_offset[b_idx], k[L_kv_base + L_kv_start + i, by, j - 64:j - 64 + 129])
                                                            T.writes(K_smem[j, i])
                                                            cur_L: T.int32 = L_kv_start + i
                                                            # Rows past the end of the kv sequence are zero-filled.
                                                            if cur_L < kv_chunk_len[0]:
                                                                freq = T.float32()
                                                                K_smem[j, i] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", k[L_kv_base + cur_L, by, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, k[L_kv_base + cur_L, by, j + 64] * T.float16(-1.0), k[L_kv_base + cur_L, by, j - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), k[L_kv_base + cur_L, by, j])
                                                            else:
                                                                K_smem[j, i] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        # Load the value chunk as-is (no rotary), zero-filling
                                        # rows past the end of the kv sequence.
                                        # NOTE(review): empty T.reads()/T.writes() here as well,
                                        # although the block reads v and writes V_smem.
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                V_smem[i, j] = v[L_kv_base + cur_L, by, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        # Scores S = Q_smem @ K_smem (32 x 32), reduced over the 128
                                        # features. With t = ty * 32 + tx, each thread covers rows
                                        # (t // 8) * 2 + [0, 2) and columns (t % 8) * 4 + [0, 4).
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:128, 0:32])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(2):
                                                        for lj_1_0_init in T.unroll(1):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("S_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                                    j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(S_local[i, j])
                                                                    S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0 in range(8):
                                                        for li_1 in T.unroll(2):
                                                            for lj_1_0 in T.unroll(1):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    for lk_1 in range(16):
                                                                        with T.block("S_gemm_update"):
                                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                            k_1 = T.axis.reduce(128, lk_0 * 16 + lk_1)
                                                                            T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[k_1, j])
                                                                            T.writes(S_local[i, j])
                                                                            # The constant is numerically ~ log2(e) / sqrt(128);
                                                                            # presumably it folds the 1/sqrt(feature_dim) scale and
                                                                            # the exp -> exp2 base change into one multiply.
                                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * T.Cast("float32", K_smem[k_1, j]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        # Publish each thread's score patch to shared memory so a
                                        # single thread can scan a whole row below.
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1 in range(2):
                                                    for lj_1_0 in T.unroll(1):
                                                        for lj_1_1 in T.vectorized(4):
                                                            with T.block("S_store"):
                                                                i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                T.reads(S_local[i, j])
                                                                T.writes(S_smem[i, j])
                                                                S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        # Step 1 of the softmax update (one thread per row): raise
                                        # the running max over the unmasked scores of this chunk and
                                        # rescale the running denominator to the new max.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    # Offset of this row's query within its sequence.
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        # Mask: with causal > 0, kv position L_kv_start + j is
                                                        # visible iff it is <= (kv_len - q_len) + row_;
                                                        # otherwise only the kv-length bound applies.
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        # Step 2: replace each score by exp2(score - m_new); masked
                                        # entries use -50000.0 in place of the score.
                                        # NOTE(review): this block and the next are both named
                                        # "update"; here the row < 32 guard sits inside the j loop.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        # Step 3: add the row's exponentiated scores to the
                                        # denominator and write m, d and m_prev back to shared memory.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        # Output update: rescale the accumulator by
                                        # exp2(m_prev - m), then add S_smem @ V_smem for this chunk.
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(4):
                                                        for lj_1_0_init in T.unroll(2):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("O_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(O_local[i, j])
                                                                    O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1 in T.grid(2, 16):
                                                        for li_1 in T.unroll(4):
                                                            for lj_1_0 in T.unroll(2):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    with T.block("O_gemm_update"):
                                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                                        j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                                        k_1 = T.axis.reduce(32, lk_0 * 16 + lk_1)
                                                                        T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j])
                                                                        T.writes(O_local[i, j])
                                                                        O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j])
                                    # After the last kv chunk: divide the accumulator by the
                                    # denominator and store as float16, skipping tile rows that
                                    # fall past the end of this sequence.
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_store"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                            T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                            cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                            if cur_L < q_indptr[b_idx + 1]:
                                                                output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    # Store m + log2(d) per row; the T.where predicate limits
                                    # this to the 32 threads with li_1 * 32 + li_2 < 32.
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    # Advance by 16, the blockIdx.x extent, to this block's next tile.
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func
    def batch_tree_attn(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_mn_indptr: T.handle, var_mask: T.handle, var_output: T.handle, var_lse: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32, batch_size: T.int32):
        # Batched attention with a per-entry tree mask (scheduled TIR kernel).
        #
        # For each batch entry b, query rows q[q_indptr[b]:q_indptr[b + 1]] are
        # scored against key/value rows k/v[kv_indptr[b]:kv_indptr[b + 1]].
        # Work is done in tiles of 32 fused (query row, head) pairs by 32 KV rows,
        # keeping a running row maximum `m` and running denominator `d` in base 2
        # (exp2 / log2), so the softmax is accumulated chunk by chunk.
        #
        # Buffers (shapes as declared by the match_buffer calls below):
        #   q, output        (qo_len, 16, 128) float16
        #   k, v             (kv_len, 2, 128) float16
        #   q_indptr, kv_indptr, mn_indptr   (batch_size + 1,) int32 offsets
        #   q_rope_position  (qo_len,) int32; only read when rotary_mode == 1
        #   mask             (tree_size, 2) int32; rows addressed via mn_indptr
        #   lse              (qo_len, 16); receives m + log2(d) per (row, head)
        # Scalars:
        #   rotary_mode      1 -> rotate q and k while loading; otherwise load as-is
        #   rope_scale, rope_theta   parameters of the rotary angle `freq`
        #   attn_score_scaling_factor   extra multiplier applied to every q.k term
        #   batch_size       number of entries described by the *_indptr buffers
        #
        # Head mapping visible in the indexing: blockIdx.y (`by`) selects one of the
        # 2 k/v heads and the query head is by * 8 + (fused_row % 8), i.e. 8 query
        # heads share each k/v head.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        qo_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (qo_len, 16, 128), "float16")
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        kv_len = T.int32(is_size_var=True)
        k = T.match_buffer(var_k, (kv_len, 2, 128), "float16")
        v = T.match_buffer(var_v, (kv_len, 2, 128), "float16")
        kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1)
        mn_indptr = T.match_buffer(var_mn_indptr, (batch_size + 1,), "int32", offset_factor=1)
        tree_size = T.int32(is_size_var=True)
        mask = T.match_buffer(var_mask, (tree_size, 2), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (qo_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (qo_len, 16))
        # with T.block("root"):
        # Launch shape: 16 x 2 blocks (blockIdx.x, blockIdx.y), 4 x 32 threads
        # (threadIdx.y, threadIdx.x) = 128 threads per block.
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            # Per-thread scheduling state (single-element "local" buffers).
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            # NOTE(review): `iterator` is allocated but never referenced in this
                            # function; the KV loop below uses the loop variable `iterator_1`.
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            # Tile staging buffers: 32 rows x 128 features for Q/K/V, 32 x 32 scores.
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            # Per-row softmax statistics: running max (m), previous max, denominator (d).
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            # Each blockIdx.x starts at tile `bx` and advances by 16 (the
                            # blockIdx.x extent) after every processed tile; see the last line.
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            # Rows of a batch entry = query length * 8 (the 8 heads fused into rows);
                            # tiles = ceil(rows / 32).
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                # Convert the running tile id into (batch_idx, tile id within that
                                # entry) by skipping entries whose tiles are already exhausted.
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    # First fused (row, head) index covered by this tile.
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
                                    T.tvm_storage_sync("shared")
                                    # Reset softmax statistics. row = ty * 32 + tx here, so only
                                    # the threads with ty == 0 pass the `row < 32` guard.
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    # Zero the output accumulator; each thread covers a 4 x 8 patch
                                    # of the 32 x 128 tile.
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_init"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads()
                                                    T.writes(O_local[i, j])
                                                    O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Load the 32 x 128 query tile into Q_smem (8 passes x 128 threads
                                    # x 4 vectorized elements). Rows past the end of this batch
                                    # entry's queries are zero-filled.
                                    for li_lj_fused_0 in range(8):
                                        for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_lj_fused_3 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 128)
                                                        j = T.axis.spatial(128, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 128)
                                                        T.reads()
                                                        T.writes()
                                                        # Un-fuse the tile row: query position and query head.
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            # rotary_mode == 1: cos(freq) * q[j] + sin(freq) * partner, where
                                                            # partner is -q[j + 64] for j < 64 and q[j - 64] otherwise, and
                                                            # freq = position * rope_scale / rope_theta ** ((2 * j % 128) / 128).
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Sweep this entry's KV rows in chunks of 32.
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        L_kv_base: T.int32 = kv_indptr[b_idx]
                                        # Load the K and V chunks for k/v head `by`; rows beyond
                                        # kv_chunk_len are zero-filled.
                                        for lz_ly_fused_0 in range(8):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("KV_load"):
                                                            i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 128)
                                                            j = T.axis.spatial(128, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 128)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_base + L_kv_start + i
                                                            if L_kv_start + i < kv_chunk_len[0]:
                                                                freq = T.float32()
                                                                # NOTE(review): the K rotary angle reads q_rope_position[cur_L]
                                                                # with cur_L being a KV row index, although the buffer is
                                                                # declared with qo_len rows - presumably Q and KV rows share
                                                                # positions in the tree-attention setting; confirm with caller.
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", k[cur_L, by, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, k[cur_L, by, j + 64] * T.float16(-1.0), k[cur_L, by, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), k[cur_L, by, j])
                                                                V_smem[i, j] = v[cur_L, by, j]
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        # S = Q . K^T for the 32 x 32 tile; each thread computes a
                                        # 2 x 4 patch and reduces over the 128 features.
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(2, 4):
                                                        with T.block("S_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_init)
                                                            T.reads()
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, li_1, lj_1, lk_1 in T.grid(16, 2, 4, 8):
                                                        with T.block("S_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                            k_1 = T.axis.reduce(128, lk_0 * 8 + lk_1)
                                                            T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[j, k_1])
                                                            T.writes(S_local[i, j])
                                                            # The constant 0.12751743... numerically equals log2(e) / sqrt(128),
                                                            # presumably the 1/sqrt(head_dim) scale folded with the base-2
                                                            # conversion required by the exp2 calls below - TODO confirm.
                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * T.Cast("float32", K_smem[j, k_1]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        # Publish every thread's score patch to S_smem so that a single
                                        # thread can scan a whole row in the update steps.
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1, lj_1 in T.grid(2, 4):
                                                    with T.block("S_store"):
                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                        j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                        T.reads(S_local[i, j])
                                                        T.writes(S_smem[i, j])
                                                        S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        # Update step 1 (one thread per row, ty == 0 only): new row
                                        # maximum over the unmasked columns, and the old denominator
                                        # rescaled to that maximum.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], mn_indptr[b_idx:b_idx + 2], mask[T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]):T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    # Query position of this row within the batch entry.
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        # Column j counts when it lies inside the KV chunk AND either
                                                        #   (a) it precedes the trailing (mn_indptr[b+1] - mn_indptr[b])
                                                        #       KV positions, or
                                                        #   (b) mask[q_node, 0] lies in [mask[kv_node, 0], mask[kv_node, 1]),
                                                        # where q_node / kv_node are the mask rows obtained by aligning
                                                        # the query row and the KV row to the end of the mn_indptr range.
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) or mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 0] and mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 1]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        # Update step 2: turn scores into unnormalized probabilities
                                        # exp2(S - m_new). Columns failing the same condition as above
                                        # are replaced by exp2(-50000 - m_new).
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], mn_indptr[b_idx:b_idx + 2], mask[T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]):T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    # Here the `row < 32` guard sits inside the block rather than
                                                    # around it (unlike "update1").
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) or mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 0] and mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 1]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        # Update step 3: add this chunk's probabilities to the denominator
                                        # and publish the new statistics. m_prev_smem keeps the previous
                                        # maximum so the output accumulator can be rescaled below.
                                        # (This block reuses the name "update" from the block above.)
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        # O = O * exp2(m_prev - m) + P . V for this chunk; each thread
                                        # owns a 4 x 8 patch of the 32 x 128 accumulator.
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(4, 8):
                                                        with T.block("O_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_init)
                                                            # Despite the "_init" name this does not zero the accumulator: it
                                                            # rescales the running value to the new row maximum. The empty
                                                            # T.reads() does not list the O_local / m_* reads made here.
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1, li_1, lj_1 in T.grid(4, 8, 4, 8):
                                                        with T.block("O_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                            k_1 = T.axis.reduce(32, lk_0 * 8 + lk_1)
                                                            T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j])
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j])
                                    # After all KV chunks: normalize by the denominator and write the
                                    # float16 result; rows past the end of the entry's queries are skipped.
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_store"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                    T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    # Write the base-2 log-sum-exp per (query row, head): m + log2(d).
                                    # The T.where predicate restricts this to the first 32 thread slots.
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    # Advance to this block's next tile (stride = blockIdx.x extent).
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func(private=True)
    def batch_verify_on_gpu_single_kernel(var_draft_probs: T.handle, var_draft_tokens: T.handle, var_model_probs: T.handle, var_token_tree_first_child: T.handle, var_token_tree_next_sibling: T.handle, var_uniform_samples: T.handle, var_token_tree_parent_ptr: T.handle):
        # Walks a token tree once per batch entry, comparing model and draft
        # probabilities to decide whether to descend into a child node or move
        # on to its next sibling. Runs as a single kernel: one blockIdx.x per
        # entry of token_tree_parent_ptr, 1024 threadIdx.x threads per block.
        #
        # Buffers (shapes as declared by the match_buffer calls below):
        #   draft_probs, model_probs   (num_nodes, vocab_size)
        #   draft_tokens               (num_nodes,) int32
        #   token_tree_first_child     (num_nodes,) int32; -1 ends the walk
        #   token_tree_next_sibling    (num_nodes,) int32
        #   uniform_samples            (num_nodes,)
        #   token_tree_parent_ptr      (nbatch,) int32; read as the start node
        #                              and overwritten with the final node
        # Side effects: token_tree_parent_ptr[b] is updated, and rows of
        # model_probs are overwritten in place when a child is rejected.
        #
        # NOTE(review): the loop binds 1024 threads on threadIdx.x while the
        # target attribute lists max_num_threads = 256 (max_threads_per_block is
        # 1024) - confirm this launch size is valid for the intended device.
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        num_nodes, vocab_size = T.int32(is_size_var=True), T.int64(is_size_var=True)
        draft_probs = T.match_buffer(var_draft_probs, (num_nodes, vocab_size))
        draft_tokens = T.match_buffer(var_draft_tokens, (num_nodes,), "int32")
        model_probs = T.match_buffer(var_model_probs, (num_nodes, vocab_size))
        token_tree_first_child = T.match_buffer(var_token_tree_first_child, (num_nodes,), "int32")
        token_tree_next_sibling = T.match_buffer(var_token_tree_next_sibling, (num_nodes,), "int32")
        uniform_samples = T.match_buffer(var_uniform_samples, (num_nodes,))
        nbatch = T.int32(is_size_var=True)
        token_tree_parent_ptr = T.match_buffer(var_token_tree_parent_ptr, (nbatch,), "int32")
        # with T.block("root"):
        # Single-element scratch buffers. All are "local" except pred_shared,
        # which carries thread 0's decision to the rest of the block.
        child_ptr = T.alloc_buffer((1,), "int32", scope="local")
        parent_ptr = T.alloc_buffer((1,), "int32", scope="local")
        child_token = T.alloc_buffer((1,), "int32", scope="local")
        done = T.alloc_buffer((1,), "bool", scope="local")
        psum = T.alloc_buffer((1,), scope="local")
        t0 = T.alloc_buffer((1,), scope="local")
        model_prob_local = T.alloc_buffer((1,), scope="local")
        draft_prob_local = T.alloc_buffer((1,), scope="local")
        p_child = T.alloc_buffer((1,), scope="local")
        q_child = T.alloc_buffer((1,), scope="local")
        uniform_sample = T.alloc_buffer((1,), scope="local")
        pred_shared = T.alloc_buffer((1,), "bool", scope="shared")
        pred_local = T.alloc_buffer((1,), "bool", scope="local")
        for _bx in T.thread_binding(nbatch, thread="blockIdx.x"):
            for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("CTA"):
                    b, tx = T.axis.remap("SS", [_bx, _tx])
                    T.reads(token_tree_parent_ptr[b], token_tree_first_child[T.min(parent_ptr[0], child_ptr[0]):T.min(parent_ptr[0], child_ptr[0]) + (T.max(parent_ptr[0], child_ptr[0]) + 1 - T.min(parent_ptr[0], child_ptr[0]))], parent_ptr[0], done[0], child_ptr[0], draft_tokens[child_ptr[0]], model_probs[parent_ptr[0], T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)):T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)) + (T.max(T.Cast("int64", child_token[0]), (vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) + T.Cast("int64", tx) - T.int64(1024)) + T.int64(1) - T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)))], child_token[0], draft_probs[child_ptr[0], T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)):T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)) + (T.max(T.Cast("int64", child_token[0]), (vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) + T.Cast("int64", tx) - T.int64(1024)) + T.int64(1) - T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)))], uniform_samples[child_ptr[0]], p_child[0], uniform_sample[0], q_child[0], pred_shared[0], pred_local[0], model_prob_local[0], draft_prob_local[0], psum[0], t0[0], token_tree_next_sibling[child_ptr[0]])
                    T.writes(parent_ptr[0], child_ptr[0], done[0], child_token[0], p_child[0], q_child[0], uniform_sample[0], pred_shared[0], pred_local[0], psum[0], model_prob_local[0], draft_prob_local[0], t0[0], model_probs[parent_ptr[0], T.Cast("int64", tx):T.Cast("int64", tx) + ((vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) - T.int64(1023))], token_tree_parent_ptr[b])
                    # Start at the node recorded for this batch entry and its first child.
                    parent_ptr[0] = token_tree_parent_ptr[b]
                    child_ptr[0] = token_tree_first_child[parent_ptr[0]]
                    done[0] = T.bool(False)
                    while not done[0]:
                        T.tvm_storage_sync("shared")
                        if child_ptr[0] == -1:
                            # No candidate child left: stop the walk.
                            done[0] = T.bool(True)
                            # Presumably mirrors the barrier in the else branch so both
                            # paths issue the same barriers per iteration - TODO confirm.
                            T.tvm_storage_sync("shared")
                        else:
                            # Thread 0 evaluates the acceptance test for the child's
                            # token: p >= u * q, with p = model prob at the parent,
                            # q = draft prob at the child, u = the child's uniform sample.
                            if tx == 0:
                                child_token[0] = draft_tokens[child_ptr[0]]
                                p_child[0] = model_probs[parent_ptr[0], child_token[0]]
                                q_child[0] = draft_probs[child_ptr[0], child_token[0]]
                                uniform_sample[0] = uniform_samples[child_ptr[0]]
                                pred_shared[0] = p_child[0] >= uniform_sample[0] * q_child[0]
                            T.tvm_storage_sync("shared")
                            # Every thread picks up thread 0's decision.
                            pred_local[0] = pred_shared[0]
                            if pred_local[0]:
                                # Accepted: descend into the child.
                                parent_ptr[0] = child_ptr[0]
                                child_ptr[0] = token_tree_first_child[child_ptr[0]]
                            else:
                                # Rejected: compute the residual mass
                                #   sum_v max(model_probs[parent, v] - draft_probs[child, v], 0).
                                # Each thread covers vocabulary entries tx, tx + 1024, ...
                                psum[0] = T.float32(0.0)
                                for i in range((vocab_size + T.int64(1023)) // T.int64(1024)):
                                    if i * T.int64(1024) + T.Cast("int64", tx) < vocab_size:
                                        model_prob_local[0] = model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                        draft_prob_local[0] = draft_probs[child_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                        model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], T.float32(0.0))
                                        psum[0] = psum[0] + model_prob_local[0]
                                # Sum the per-thread partials across threadIdx.x; the total
                                # lands in t0 (reducer x0 + y0, identity 0.0).
                                with T.block("block_cross_thread"):
                                    T.reads(psum[0])
                                    T.writes(t0[0])
                                    T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                    T.tvm_thread_allreduce(T.uint32(1), psum[0], T.bool(True), t0[0], tx)
                                if t0[0] < T.float32(9.9999999999999995e-08):
                                    # Residual mass is below ~1e-7: descend into the child
                                    # anyway instead of dividing by a near-zero total.
                                    parent_ptr[0] = child_ptr[0]
                                    child_ptr[0] = token_tree_first_child[child_ptr[0]]
                                else:
                                    # Overwrite the parent's model_probs row with the
                                    # normalized residual max(model - draft, 0) / t0, then
                                    # try the next sibling against the updated row.
                                    for i in range((vocab_size + T.int64(1023)) // T.int64(1024)):
                                        if i * T.int64(1024) + T.Cast("int64", tx) < vocab_size:
                                            model_prob_local[0] = model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                            draft_prob_local[0] = draft_probs[child_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                            model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], T.float32(0.0))
                                            model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)] = model_prob_local[0] / t0[0]
                                    child_ptr[0] = token_tree_next_sibling[child_ptr[0]]
                    # Thread 0 records the node at which the walk stopped.
                    if tx == 0:
                        token_tree_parent_ptr[b] = parent_ptr[0]

    @T.prim_func
    def chunk_lse(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle):
        # Per-chunk max / sum-of-exponentials over the columns of A.
        #
        # A is (batch_size, vocab_size); each row is processed in chunks of 4096
        # consecutive columns. For every (row, chunk) pair the kernel writes:
        #   chunked_max[row, chunk] = max over the chunk of x
        #   chunked_sum[row, chunk] = log(sum(exp(x - max)))      if temperature[row] > 1e-5
        #                             count of x equal to the max otherwise
        # where x = A / temperature[row] when temperature[row] > 1e-5, else x = A.
        # Columns at or beyond vocab_size are masked out (lowest float32 value for
        # the max, 0 for the sum).
        #
        # Launch shape: one blockIdx.x per (row, chunk) pair, 64 threads on
        # threadIdx.x; each thread walks 64 columns of the chunk (stride 64).
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        A = T.match_buffer(var_A, (batch_size, vocab_size))
        temperature = T.match_buffer(var_temperature, (batch_size,))
        num_chunks = T.int64(is_size_var=True)
        chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks))
        chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks))
        # with T.block("root"):
        # Intermediate per-(row, chunk) results, held in shared scope between the stages below.
        temp_max_shared = T.alloc_buffer((batch_size, num_chunks), scope="shared")
        temp_sum_shared = T.alloc_buffer((batch_size, num_chunks), scope="shared")
        for ax0_ax1_fused in T.thread_binding(batch_size * num_chunks, thread="blockIdx.x"):
            # Stage 1: running max of the (temperature-scaled) values in this chunk.
            for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
                for ax2_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
                    for ax2_fused_0 in T.serial(T.int64(64), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                        with T.block("max"):
                            # v0 = row, v1 = chunk index, v2 = column offset inside the chunk.
                            v0 = T.axis.spatial(batch_size, ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0)
                            v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks + ax1)
                            v2 = T.axis.reduce(T.int64(4096), ax2_fused_0 * T.int64(64) + ax2_fused_1)
                            T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2])
                            T.writes(temp_max_shared[v0, v1])
                            with T.init():
                                temp_max_shared[v0, v1] = T.float32(-340282346638528859811704183484516925440.0)
                            temp_max_shared[v0, v1] = T.max(temp_max_shared[v0, v1], T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), A[v0, v1 * T.int64(4096) + v2] / temperature[v0], A[v0, v1 * T.int64(4096) + v2]), T.float32(-340282346638528859811704183484516925440.0)))
            # Stage 2: accumulate exp(x - max) (or, for temperature <= 1e-5, a 1.0 for
            # every element equal to the max) over the same chunk.
            for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
                for ax2_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
                    for ax2_fused_0 in T.serial(T.int64(64), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                        with T.block("sum_exp"):
                            v0 = T.axis.spatial(batch_size, ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0)
                            v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks + ax1)
                            v2 = T.axis.reduce(T.int64(4096), ax2_fused_0 * T.int64(64) + ax2_fused_1)
                            T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2], temp_max_shared[v0, v1])
                            T.writes(temp_sum_shared[v0, v1])
                            with T.init():
                                temp_sum_shared[v0, v1] = T.float32(0.0)
                            temp_sum_shared[v0, v1] = temp_sum_shared[v0, v1] + T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), A[v0, v1 * T.int64(4096) + v2] / temperature[v0], A[v0, v1 * T.int64(4096) + v2]), T.float32(-340282346638528859811704183484516925440.0)) - temp_max_shared[v0, v1]), T.Cast("float32", T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), A[v0, v1 * T.int64(4096) + v2] / temperature[v0], A[v0, v1 * T.int64(4096) + v2]), T.float32(-340282346638528859811704183484516925440.0)) == temp_max_shared[v0, v1])), T.float32(0.0))
            # Stage 3: write the results out. The T.where predicate only admits the
            # thread with index 0, so a single thread per block performs the store.
            for ax2_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
                for ax2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                    with T.block("log"):
                        v0 = T.axis.spatial(batch_size, ax0_ax1_fused % (num_chunks * batch_size) // num_chunks)
                        v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks)
                        v2 = T.axis.spatial(T.int64(1), ax2_0 * T.int64(64) + ax2_1)
                        T.where(ax2_0 * T.int64(64) + ax2_1 < T.int64(1))
                        T.reads(temperature[v0], temp_sum_shared[v0, v1], temp_max_shared[v0, v1])
                        T.writes(chunked_sum[v0, v1], chunked_max[v0, v1])
                        # The log is only taken on the exp-sum path; the count path is stored as-is.
                        chunked_sum[v0, v1] = T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.log(temp_sum_shared[v0, v1]), temp_sum_shared[v0, v1])
                        chunked_max[v0, v1] = temp_max_shared[v0, v1]

    @T.prim_func
    def compact_kv_copy(var_pages: T.handle, var_copy_length_indptr: T.handle, var_copy_src_dst_pos: T.handle, batch_size: T.int32):
        # In-place copy of entries inside `pages` from source positions to
        # destination positions, batched per sequence.
        #
        # pages is float16 with shape (num_pages, 2, 2, 16, 128). A flat position
        # `pos` addresses page `pos // 16`, slot `pos % 16`.
        # NOTE(review): axes look like (page, K/V, head, slot, feature) -- confirm
        # against the KV-cache layout used by the caller.
        #
        # copy_length_indptr is a CSR-style offset array of length batch_size + 1:
        # batch entry b owns columns [indptr[b], indptr[b + 1]) of copy_src_dst_pos,
        # whose row 0 holds source positions and row 1 holds destination positions.
        #
        # One thread is launched per (b, h, d) triple; each thread performs its
        # batch entry's copies sequentially, in index order, for both values of axis 1.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        num_pages = T.int32()
        pages = T.match_buffer(var_pages, (num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        copy_length_indptr = T.match_buffer(var_copy_length_indptr, (batch_size + 1,), "int32", offset_factor=1)
        total_copy_length = T.int32()
        copy_src_dst_pos = T.match_buffer(var_copy_src_dst_pos, (2, total_copy_length), "int32", offset_factor=1)
        with T.block("root"):
            T.reads()
            T.writes()
            # batch_size * 256 work items (256 = 2 * 128), rounded up to whole blocks of 1024 threads.
            for bhd_o in T.thread_binding((batch_size * 256 + 1023) // 1024, thread="blockIdx.x"):
                for bhd_i in T.thread_binding(1024, thread="threadIdx.x"):
                    # Decompose the flat thread index into (b, h, d).
                    b: T.int32 = (bhd_o * 1024 + bhd_i) // 256
                    h: T.int32 = (bhd_o * 1024 + bhd_i) // 128 % 2
                    d: T.int32 = (bhd_o * 1024 + bhd_i) % 128
                    # Threads in the rounded-up tail of the last block do nothing.
                    if bhd_o * 1024 + bhd_i < batch_size * 2 * 128:
                        for i in range(copy_length_indptr[b + 1] - copy_length_indptr[b]):
                            src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i]
                            dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i]
                            # Copy both entries along axis 1 (indices 0 and 1).
                            pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[src_pos // 16, 0, h, src_pos % 16, d]
                            pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[src_pos // 16, 1, h, src_pos % 16, d]

    @T.prim_func
    def copy_single_page(var_pages: T.handle, src_page_id: T.int64, tgt_page_id: T.int64, copy_length: T.int64):
        # Copy the first `copy_length` slots (axis 3) of page `src_page_id` into
        # page `tgt_page_id`, for both indices of axis 1, both indices of axis 2
        # and all 128 elements of axis 4.
        #
        # pages is float16 with shape (num_pages, 2, 2, page_size, 128).
        # NOTE(review): copy_length is presumably <= page_size; nothing here checks it.
        #
        # One thread per (vh, vp, vd) triple: copy_length * 2 * 128 work items,
        # rounded up to whole blocks of 1024 threads.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        num_pages, page_size = T.int32(), T.int64()
        pages = T.match_buffer(var_pages, (num_pages, 2, 2, page_size, 128), "float16", offset_factor=1)
        # with T.block("root"):
        for b in T.thread_binding((copy_length * T.int64(256) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"):
            for t in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("copy"):
                    # Flat index = b * 1024 + t, decomposed as
                    # vh (axis 2) major, vp (slot, axis 3) middle, vd (axis 4) minor.
                    vh = T.axis.spatial(2, T.Cast("int32", (b * T.int64(1024) + T.Cast("int64", t)) // (copy_length * T.int64(128))))
                    vp = T.axis.spatial(copy_length, (b * T.int64(1024) + T.Cast("int64", t)) % (copy_length * T.int64(128)) // T.int64(128))
                    vd = T.axis.spatial(128, T.Cast("int32", (b * T.int64(1024) + T.Cast("int64", t)) % T.int64(128)))
                    # Mask off the threads in the rounded-up tail of the last block.
                    T.where(b * T.int64(1024) + T.Cast("int64", t) < copy_length * T.int64(2) * T.int64(128))
                    T.reads(pages[src_page_id, 0:2, vh, vp, vd])
                    T.writes(pages[tgt_page_id, 0:2, vh, vp, vd])
                    pages[tgt_page_id, 0, vh, vp, vd] = pages[src_page_id, 0, vh, vp, vd]
                    pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd]

    @T.prim_func
    def full(var_result: T.handle, value: T.int32):
        # Fill the int32 buffer `result`, shape (batch_size, 1), with `value`.
        # One thread per row, launched in blocks of 1024 threads.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        batch_size = T.int32(is_size_var=True)
        result = T.match_buffer(var_result, (batch_size, 1), "int32")
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding((batch_size + 1023) // 1024, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    v0 = T.axis.spatial(batch_size, ax0_fused_0 * 1024 + ax0_fused_1)
                    # Mask off the threads past batch_size in the last block.
                    T.where(ax0_fused_0 * 1024 + ax0_fused_1 < batch_size)
                    T.reads()
                    T.writes(result[v0, 0])
                    result[v0, 0] = value

    @T.prim_func(private=True)
    def fuse_add_norm_decode(pA: T.handle, pB: T.handle, C: T.Buffer((2048,), "float16"), pO: T.handle, pAdd: T.handle):
        # Fused elementwise add + RMS-style normalization over the last axis (2048).
        #
        # Inputs A and B are float16 with shape (batch_size, 1, 2048). Outputs:
        #   add = A + B                                           (float16)
        #   O   = rsqrt(sum(add^2) / 2048 + 1e-6) * add * C       (computed in float32, stored as float16)
        # C is a float16 vector of length 2048 multiplied in per element.
        #
        # Launch shape: one blockIdx.x per batch row, 1024 threads on threadIdx.x;
        # each thread owns two elements of the row, h = i * 1024 + tx for i in {0, 1}.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int32()
        A = T.match_buffer(pA, (batch_size, 1, 2048), "float16")
        B = T.match_buffer(pB, (batch_size, 1, 2048), "float16")
        O = T.match_buffer(pO, (batch_size, 1, 2048), "float16")
        add = T.match_buffer(pAdd, (batch_size, 1, 2048), "float16")
        # with T.block("root"):
        # add_local holds the thread's two summed elements; sum_local holds one
        # partial sum of squares per thread; sum_shared holds the per-row total.
        add_local = T.alloc_buffer((2,), "float16", scope="local")
        sum_shared = T.alloc_buffer((batch_size, 1), scope="shared")
        sum_local = T.alloc_buffer((1024, batch_size, 1), scope="local")
        for v_bx in T.thread_binding(batch_size, thread="blockIdx.x"):
            for v_tx in T.thread_binding(1024, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                # Stage 1: add the two inputs, keep the result locally and write it to `add`.
                for i in range(2):
                    with T.block("T_add"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        h = T.axis.spatial(2048, i * 1024 + v_tx)
                        T.reads(A[bx, 0, h], B[bx, 0, h])
                        T.writes(add_local[h // 1024])
                        add_local[h // 1024] = A[bx, 0, h] + B[bx, 0, h]
                    with T.block("T_write_back"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        v_ax1 = T.axis.spatial(1, 0)
                        h = T.axis.spatial(2048, i * 1024 + v_tx)
                        T.reads(add_local[h // 1024])
                        T.writes(add[bx, v_ax1, h])
                        add[bx, v_ax1, h] = add_local[h // 1024]
                # Stage 2: per-thread partial sum of squares, accumulated in float32.
                with T.block("T_multiply_red_rf_init"):
                    tx, bx = T.axis.remap("SS", [v_tx, v_bx])
                    T.reads()
                    T.writes(sum_local[tx, bx, 0])
                    sum_local[tx, bx, 0] = T.float32(0.0)
                for v_i, _j in T.grid(2, 1):
                    with T.block("T_multiply_red_rf_update"):
                        tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                        T.reads(sum_local[tx, bx, 0], add_local[i])
                        T.writes(sum_local[tx, bx, 0])
                        sum_local[tx, bx, 0] = sum_local[tx, bx, 0] + T.Cast("float32", add_local[i]) * T.Cast("float32", add_local[i])
            # Stage 3: reduce the 1024 per-thread partial sums across threadIdx.x
            # (tx is a reduction axis here) into sum_shared.
            for _j in range(1):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                        T.reads(sum_local[tx, bx, 0])
                        T.writes(sum_shared[bx, 0])
                        with T.init():
                            sum_shared[bx, 0] = T.float32(0.0)
                        sum_shared[bx, 0] = sum_shared[bx, 0] + sum_local[tx, bx, 0]
            # Stage 4: scale each element by rsqrt(mean square + eps) and by C.
            # 0.00048828125 == 1 / 2048, i.e. the sum is turned into a mean.
            for i in range(2):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        h = T.axis.spatial(2048, i * 1024 + v_tx_2)
                        T.reads(sum_shared[bx, 0], add_local[h // 1024], C[h])
                        T.writes(O[bx, 0, h])
                        O[bx, 0, h] = T.Cast("float16", T.rsqrt(sum_shared[bx, 0] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07)) * T.Cast("float32", add_local[h // 1024]) * T.Cast("float32", C[h]))

    @T.prim_func(private=True)
    def fuse_add_norm_prefill(pA: T.handle, pB: T.handle, C: T.Buffer((2048,), "float16"), pO: T.handle, pAdd: T.handle):
        # Fused elementwise add + RMS-style normalization over the last axis (2048).
        #
        # Inputs A and B are float16 with shape (1, seq_len, 2048). Outputs:
        #   add = A + B                                           (float16)
        #   O   = rsqrt(sum(add^2) / 2048 + 1e-6) * add * C       (computed in float32, stored as float16)
        # C is a float16 vector of length 2048 multiplied in per element.
        #
        # Launch shape: one blockIdx.x per sequence position, 1024 threads on
        # threadIdx.x; each thread owns two elements of the row,
        # h = v_i * 1024 + tx for v_i in {0, 1}.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        A = T.match_buffer(pA, (1, seq_len, 2048), "float16")
        B = T.match_buffer(pB, (1, seq_len, 2048), "float16")
        O = T.match_buffer(pO, (1, seq_len, 2048), "float16")
        add = T.match_buffer(pAdd, (1, seq_len, 2048), "float16")
        # with T.block("root"):
        # add_local holds the thread's two summed elements; sum_local holds one
        # partial sum of squares per thread; sum_shared holds the per-row total.
        add_local = T.alloc_buffer((2,), "float16", scope="local")
        sum_shared = T.alloc_buffer((1, seq_len), scope="shared")
        sum_local = T.alloc_buffer((1024, 1, seq_len), scope="local")
        for v_bx in T.thread_binding(seq_len, thread="blockIdx.x"):
            for v_tx in T.thread_binding(1024, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                # Stage 1: add the two inputs, keep the result locally and write it to `add`.
                for v_i in range(2):
                    with T.block("T_add"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        h = T.axis.spatial(2048, v_i * 1024 + v_tx)
                        T.reads(A[0, bx, h], B[0, bx, h])
                        T.writes(add_local[h // 1024])
                        add_local[h // 1024] = A[0, bx, h] + B[0, bx, h]
                    with T.block("T_write_back"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        h = T.axis.spatial(2048, v_i * 1024 + v_tx)
                        T.reads(add_local[h // 1024])
                        T.writes(add[0, bx, h])
                        add[0, bx, h] = add_local[h // 1024]
                # Stage 2: per-thread partial sum of squares, accumulated in float32.
                with T.block("T_multiply_red_rf_init"):
                    tx, bx = T.axis.remap("SS", [v_tx, v_bx])
                    T.reads()
                    T.writes(sum_local[tx, 0, bx])
                    sum_local[tx, 0, bx] = T.float32(0.0)
                for v_i, _j in T.grid(2, 1):
                    with T.block("T_multiply_red_rf_update"):
                        tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                        T.reads(sum_local[tx, 0, bx], add_local[i])
                        T.writes(sum_local[tx, 0, bx])
                        sum_local[tx, 0, bx] = sum_local[tx, 0, bx] + T.Cast("float32", add_local[i]) * T.Cast("float32", add_local[i])
            # Stage 3: reduce the 1024 per-thread partial sums across threadIdx.x
            # (tx is a reduction axis here) into sum_shared.
            for _j in range(1):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                        T.reads(sum_local[tx, 0, bx])
                        T.writes(sum_shared[0, bx])
                        with T.init():
                            sum_shared[0, bx] = T.float32(0.0)
                        sum_shared[0, bx] = sum_shared[0, bx] + sum_local[tx, 0, bx]
            # Stage 4: scale each element by rsqrt(mean square + eps) and by C.
            # 0.00048828125 == 1 / 2048, i.e. the sum is turned into a mean.
            for v_i in range(2):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        v1 = T.axis.spatial(2048, v_i * 1024 + v_tx_2)
                        T.reads(sum_shared[0, bx], add_local[v1 // 1024], C[v1])
                        T.writes(O[0, bx, v1])
                        O[0, bx, v1] = T.Cast("float16", T.rsqrt(sum_shared[0, bx] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07)) * T.Cast("float32", add_local[v1 // 1024]) * T.Cast("float32", C[v1]))

    @T.prim_func(private=True)
    def fused_dequantize1_fused_NT_matmul10_add2(model_layers_0_self_attn_c_attn_q_weight2: T.Buffer((T.int64(2560), T.int64(256)), "uint32"), model_layers_0_self_attn_c_attn_q_scale2: T.Buffer((T.int64(2560), T.int64(64)), "float16"), rms_norm73: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), model_layers_0_self_attn_c_attn_bias2: T.Buffer((T.int64(2560),), "float16"), T_add_intermediate_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16")):
        # Fused 4-bit weight dequantization + matrix-vector product + bias add.
        #
        # For each output row v0 in [0, 2560):
        #   out[0, 0, v0] = sum_{k < 2048} rms_norm73[0, 0, k] * W[v0, k] + bias[v0]
        # where W is never materialized; each element is decoded on the fly as
        #   W[v0, k] = (float16((q_weight[v0, k // 8] >> (k % 8 * 4)) & 15) - 7) * q_scale[v0, k // 32]
        # i.e. eight 4-bit values per uint32 word and one float16 scale per 32 columns.
        # All accumulation is carried out in float16.
        #
        # Launch shape: 640 blocks on blockIdx.x, each with 4 (threadIdx.y) x 32
        # (threadIdx.x) threads. Every (block, threadIdx.y) pair owns one output
        # row; the 32 threadIdx.x lanes split the reduction over k.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        # NT_matmul_intermediate_rf_local:   128 partial sums per row (32 lanes x 4 vector slots).
        # NT_matmul_intermediate_rf_local_1: 32 partial sums per row (one per threadIdx.x lane).
        NT_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16", scope="local")
        NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(2560)), "float16", scope="local")
        NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), T.int64(1), T.int64(2560)), "float16", scope="local")
        model_layers_0_self_attn_c_attn_q_weight2_local = T.alloc_buffer((T.int64(2560), T.int64(256)), "uint32", scope="local")
        rms_norm73_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16", scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(T.int64(640), thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                    # Stage 0: cooperatively copy the 2048-element input vector into the
                    # shared-scope buffer (4 serial steps x 4 x 32 threads x 4-wide vector).
                    for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
                        for ax2_0 in T.serial(T.int64(4), annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(T.int64(4)):
                                        with T.block("rms_norm73_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(T.int64(2048), ax2_0 * T.int64(512) + ax2_1 * T.int64(128) + ax2_2 * T.int64(4) + ax2_3)
                                            T.reads(rms_norm73[v0, v1, v2])
                                            T.writes(rms_norm73_shared[v0, v1, v2])
                                            rms_norm73_shared[v0, v1, v2] = rms_norm73[v0, v1, v2]
                    # Stage 1a: zero this thread's four partial-sum slots for its output row.
                    for u_fused_ax0_fused_fused_2_init in range(T.int64(1)):
                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
                                v0 = T.axis.spatial(T.int64(2560), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0])
                                NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float16(0.0)
                    # Stage 1b: 8 outer reduction steps. Each step loads one packed uint32
                    # word (eight 4-bit weights) and accumulates its 8 products as
                    # 2 serial steps x 4 vector lanes.
                    for ax1_0_fused_ax1_1_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax0_ax1_fused_0 in range(T.int64(1)):
                            for ax0_ax1_fused_1 in T.vectorized(T.int64(1)):
                                with T.block("model_layers_0_self_attn_c_attn_q_weight2_local"):
                                    v0 = T.axis.spatial(T.int64(2560), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1)
                                    v1 = T.axis.spatial(T.int64(256), ax1_0_fused_ax1_1_fused_0 * T.int64(32) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
                                    T.reads(model_layers_0_self_attn_c_attn_q_weight2[v0, v1])
                                    T.writes(model_layers_0_self_attn_c_attn_q_weight2_local[v0, v1])
                                    model_layers_0_self_attn_c_attn_q_weight2_local[v0, v1] = model_layers_0_self_attn_c_attn_q_weight2[v0, v1]
                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(T.int64(1), T.int64(2)):
                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                with T.block("NT_matmul_rf_update"):
                                    # The column index used below is
                                    #   k = outer * 256 + (rf // 4) * 8 + inner * 4 + rf % 4
                                    # so the packed word index is k // 8 and the scale index is k // 32.
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
                                    v0 = T.axis.spatial(T.int64(2560), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
                                    T.reads(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0], rms_norm73_shared[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)], model_layers_0_self_attn_c_attn_q_weight2_local[v0, vax1_0_fused_ax1_1_fused_0 * T.int64(32) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) + vax1_0_fused_ax1_1_fused_2 // T.int64(2)], model_layers_0_self_attn_c_attn_q_scale2[v0, (vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)) // T.int64(32)])
                                    T.writes(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0])
                                    NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0] + rms_norm73_shared[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)] * ((T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight2_local[v0, vax1_0_fused_ax1_1_fused_0 * T.int64(32) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) + vax1_0_fused_ax1_1_fused_2 // T.int64(2)], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)) % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale2[v0, (vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)) // T.int64(32)])
            # Stage 2: fold each thread's 4 vector-slot partials into one value per
            # threadIdx.x lane (128 partials -> 32). Note the block names repeat
            # those used in stage 1.
            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                    for ax2_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_2_1 in T.vectorized(T.int64(1)):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                v0 = T.axis.spatial(T.int64(2560), u_fused_ax0_fused_fused_0 * T.int64(4) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                T.reads()
                                T.writes(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                                NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0.0)
                            for ax1 in range(T.int64(4)):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(T.int64(2560), u_fused_ax0_fused_fused_0 * T.int64(4) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                    T.reads(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0], NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, T.int64(0), T.int64(0), v0])
                                    T.writes(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                                    NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, T.int64(0), T.int64(0), v0]
            # Stage 3: reduce the 32 per-lane partials across threadIdx.x (bound to a
            # reduction axis) into the final dot product for each row.
            for ax1_fused_2 in range(T.int64(1)):
                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                    for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                            v0 = T.axis.spatial(T.int64(2560), u_fused_ax0_fused_fused_0 * T.int64(4) + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2)
                            T.reads(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                            T.writes(NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0])
                            with T.init():
                                NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = T.float16(0.0)
                            NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0]
            # Stage 4: add the bias and store the output row (one row per threadIdx.y).
            for ax0_fused_0_ax0_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                for ax0_fused_2 in range(T.int64(1)):
                    with T.block("T_add"):
                        v0 = T.axis.spatial(T.int64(2560), u_fused_ax0_fused_fused_0 * T.int64(4) + ax0_fused_0_ax0_fused_1_fused + ax0_fused_2)
                        T.reads(NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0], model_layers_0_self_attn_c_attn_bias2[v0])
                        T.writes(T_add_intermediate_intermediate[T.int64(0), T.int64(0), v0])
                        T_add_intermediate_intermediate[T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + model_layers_0_self_attn_c_attn_bias2[v0]

    @T.prim_func(private=True)
    def fused_dequantize1_fused_NT_matmul5_add1(model_layers_0_self_attn_c_attn_q_weight3: T.Buffer((T.int64(2560), T.int64(256)), "uint32"), model_layers_0_self_attn_c_attn_q_scale3: T.Buffer((T.int64(2560), T.int64(64)), "float16"), p_rms_norm146: T.handle, model_layers_0_self_attn_c_attn_bias3: T.Buffer((T.int64(2560),), "float16"), p_output0: T.handle):
        # Fused kernel: unpack the 4-bit packed weight, multiply the activation by it
        # (weight indexed [out, k]) and add the bias:
        #   out[0, i, j] = sum_k rms_norm146[0, i, k] * W[j, k] + bias3[j]
        #   W[j, k] = (float16((q_weight3[j, k // 8] >> ((k % 8) * 4)) & 15) - 7.0) * q_scale3[j, k // 32]
        # Buffers: q_weight3 is (2560, 256) uint32 (eight 4-bit fields per word),
        # q_scale3 is (2560, 64) float16 (one scale per 32 consecutive k),
        # rms_norm146 is (1, seq_len, 2048) float16, bias3 is (2560,) float16, and the
        # output bound to p_output0 is (1, seq_len, 2560) float16.
        # Three schedules are selected on seq_len: <= 2, <= 8, and everything else.
        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int64()
        rms_norm146 = T.match_buffer(p_rms_norm146, (T.int64(1), seq_len, T.int64(2048)), "float16")
        T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(2560)), "float16")
        # with T.block("root"):
        # Schedule 1 (seq_len <= 2): rows are handled in tiles of 2; the intermediate
        # buffers are sized to seq_len rounded up to a multiple of 2.
        if T.tvm_thread_invariant(seq_len <= T.int64(2)):
            with T.block("root"):
                T.reads()
                T.writes()
                dequantize_intermediate_local = T.alloc_buffer((T.int64(2560), T.int64(2048)), "float16", scope="local")
                NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(2560)), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), (seq_len + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(2560)), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), (seq_len + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(2560)), "float16", scope="local")
                # blockIdx.y walks the 2-row tiles and blockIdx.x the 320 groups of 8
                # output columns (320 * 8 = 2560). threadIdx.y selects 2 of those 8
                # columns; threadIdx.x splits the 2048-long reduction 32 ways.
                for ax0_0 in T.thread_binding((seq_len + T.int64(1)) // T.int64(2), thread="blockIdx.y"):
                    for ax1_fused_0 in T.thread_binding(T.int64(320), thread="blockIdx.x"):
                        for ax1_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                            for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                # Stage 1 init: zero this thread's 4 vector-lane partial sums
                                # for each of its 2 rows x 2 columns.
                                for ax0_1_init, ax1_fused_2_init in T.grid(T.int64(2), T.int64(2)):
                                    for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                        with T.block("NT_matmul_rf_init"):
                                            vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1_init)
                                            v0 = T.axis.spatial((seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax0_1_init)
                                            v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2_init)
                                            T.reads()
                                            T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1])
                                            NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] = T.float16(0.0)
                                # The reduction axis is walked in 8 steps of 256 elements.
                                for ax2_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    # Unpack 8 consecutive k values (one uint32 word) for each of
                                    # this thread's 2 weight rows.
                                    for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
                                        for ax0_1 in T.vectorized(T.int64(1)):
                                            with T.block("dequantize"):
                                                v0 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
                                                v1 = T.axis.spatial(T.int64(2048), ax2_fused_0 * T.int64(256) + ax2_fused_1_ax2_fused_3_fused_0 * T.int64(8) + ax1)
                                                T.reads(model_layers_0_self_attn_c_attn_q_weight3[v0, v1 // T.int64(8)], model_layers_0_self_attn_c_attn_q_scale3[v0, v1 // T.int64(32)])
                                                T.writes(dequantize_intermediate_local[v0, v1])
                                                dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight3[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale3[v0, v1 // T.int64(32)]
                                    # Stage 1 update: accumulate activation * weight into the lane
                                    # partials; padded rows (v0 >= seq_len) contribute 0.
                                    for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(T.int64(2), T.int64(2), T.int64(2)):
                                        for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                            with T.block("NT_matmul_rf_update"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1)
                                                v0 = T.axis.spatial((seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax0_1)
                                                v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2)
                                                vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1], rms_norm146[T.int64(0), v0, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] + T.if_then_else(v0 < seq_len, rms_norm146[T.int64(0), v0, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)]
                        # Stage 2: fold each thread's 4 lane partials into one value, leaving
                        # 32 partials (one per threadIdx.x) per output element.
                        for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                            for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax2 in range(T.int64(2)):
                                        for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                                v0 = T.axis.spatial((seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax2)
                                                v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                                NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] = T.float16(0.0)
                                            for ax1 in range(T.int64(4)):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                    v0 = T.axis.spatial((seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax2)
                                                    v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                    T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, T.int64(0), v0, v1])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, T.int64(0), v0, v1]
                        # Stage 3: sum the 32 remaining partials; the reduce axis of this
                        # block is bound to threadIdx.x.
                        for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(2)):
                            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    with T.block("NT_matmul"):
                                        vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                        v0 = T.axis.spatial((seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax1)
                                        v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                        T.writes(NT_matmul_intermediate_pad_local[T.int64(0), v0, v1])
                                        with T.init():
                                            NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] = T.float16(0.0)
                                        NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] = NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1]
                        # Epilogue: add the bias and store; the T.where predicate skips
                        # padded rows (row index >= seq_len).
                        for ax0 in range(T.int64(2)):
                            for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax1_fused_2 in range(T.int64(2)):
                                    with T.block("NT_matmul_intermediate_pad"):
                                        v0 = T.axis.spatial(seq_len, ax0_0 * T.int64(2) + ax0)
                                        v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                        T.where((ax0_0 - (seq_len + T.int64(1)) // T.int64(2) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(2) + ax0 < seq_len)
                                        T.reads(NT_matmul_intermediate_pad_local[T.int64(0), v0, v1], model_layers_0_self_attn_c_attn_bias3[v1])
                                        T.writes(T_add_intermediate_intermediate[T.int64(0), v0, v1])
                                        T_add_intermediate_intermediate[T.int64(0), v0, v1] = NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] + model_layers_0_self_attn_c_attn_bias3[v1]
        else:
            # Schedule 2 (2 < seq_len <= 8): same three-stage reduction as above, but
            # with 4-row tiles (buffers sized to seq_len rounded up to a multiple of 4).
            if T.tvm_thread_invariant(seq_len <= T.int64(8)):
                with T.block("root"):
                    T.reads()
                    T.writes()
                    dequantize_intermediate_local = T.alloc_buffer((T.int64(2560), T.int64(2048)), "float16", scope="local")
                    NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(2560)), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), (seq_len + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(2560)), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), (seq_len + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(2560)), "float16", scope="local")
                    # blockIdx.y walks the 4-row tiles; the column/reduction thread
                    # layout matches schedule 1.
                    for ax0_0 in T.thread_binding((seq_len + T.int64(3)) // T.int64(4), thread="blockIdx.y"):
                        for ax1_fused_0 in T.thread_binding(T.int64(320), thread="blockIdx.x"):
                            for ax1_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    # Stage 1 init: zero the lane partials for 4 rows x 2 columns.
                                    for ax0_1_init, ax1_fused_2_init in T.grid(T.int64(4), T.int64(2)):
                                        for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1_init)
                                                v0 = T.axis.spatial((seq_len + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init)
                                                v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2_init)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] = T.float16(0.0)
                                    # The reduction axis is walked in 8 steps of 256 elements.
                                    for ax2_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        # Unpack 8 consecutive k values for each of this thread's
                                        # 2 weight rows.
                                        for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
                                            for ax0_1 in T.vectorized(T.int64(1)):
                                                with T.block("dequantize"):
                                                    v0 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
                                                    v1 = T.axis.spatial(T.int64(2048), ax2_fused_0 * T.int64(256) + ax2_fused_1_ax2_fused_3_fused_0 * T.int64(8) + ax1)
                                                    T.reads(model_layers_0_self_attn_c_attn_q_weight3[v0, v1 // T.int64(8)], model_layers_0_self_attn_c_attn_q_scale3[v0, v1 // T.int64(32)])
                                                    T.writes(dequantize_intermediate_local[v0, v1])
                                                    dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight3[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale3[v0, v1 // T.int64(32)]
                                        # Stage 1 update: accumulate activation * weight; padded rows
                                        # (v0 >= seq_len) contribute 0.
                                        for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(2)):
                                            for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1)
                                                    v0 = T.axis.spatial((seq_len + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1)
                                                    v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2)
                                                    vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                    T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1], rms_norm146[T.int64(0), v0, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] + T.if_then_else(v0 < seq_len, rms_norm146[T.int64(0), v0, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)]
                            # Stage 2: fold each thread's 4 lane partials into one value.
                            for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax2 in range(T.int64(4)):
                                            for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                                with T.block("NT_matmul_rf_init"):
                                                    vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                                    v0 = T.axis.spatial((seq_len + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
                                                    v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                    T.reads()
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] = T.float16(0.0)
                                                for ax1 in range(T.int64(4)):
                                                    with T.block("NT_matmul_rf_update"):
                                                        vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                        v0 = T.axis.spatial((seq_len + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
                                                        v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, T.int64(0), v0, v1])
                                                        T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                                        NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, T.int64(0), v0, v1]
                            # Stage 3: sum the 32 remaining partials; the reduce axis of this
                            # block is bound to threadIdx.x.
                            for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(4)):
                                for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                    for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                        with T.block("NT_matmul"):
                                            vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                            v0 = T.axis.spatial((seq_len + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1)
                                            v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                            T.writes(NT_matmul_intermediate_pad_local[T.int64(0), v0, v1])
                                            with T.init():
                                                NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] = T.float16(0.0)
                                            NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] = NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1]
                            # Epilogue: add the bias and store; the T.where predicate skips
                            # padded rows (row index >= seq_len).
                            for ax0 in range(T.int64(4)):
                                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                    for ax1_fused_2 in range(T.int64(2)):
                                        with T.block("NT_matmul_intermediate_pad"):
                                            v0 = T.axis.spatial(seq_len, ax0_0 * T.int64(4) + ax0)
                                            v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                            T.where((ax0_0 - (seq_len + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < seq_len)
                                            T.reads(NT_matmul_intermediate_pad_local[T.int64(0), v0, v1], model_layers_0_self_attn_c_attn_bias3[v1])
                                            T.writes(T_add_intermediate_intermediate[T.int64(0), v0, v1])
                                            T_add_intermediate_intermediate[T.int64(0), v0, v1] = NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] + model_layers_0_self_attn_c_attn_bias3[v1]
            else:
                # Schedule 3 (seq_len > 8): tiled matmul over 32-row x 32-column output
                # tiles. Rows are padded to a multiple of 32, and both operands are
                # staged through buffers allocated with scope="shared".
                with T.block("root"):
                    T.reads()
                    T.writes()
                    NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2560)), "float16", scope="local")
                    rms_norm146_reindex_pad_shared = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2048)), "float16", scope="shared")
                    dequantize_intermediate_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(2560), T.int64(2048)), "float16", scope="shared")
                    # blockIdx.y walks the 80 column tiles (80 * 32 = 2560) and blockIdx.x
                    # the 32-row tiles; the 8 x 8 threads each own a 4 x 4 output patch.
                    for ax0_ax2_0_fused in T.thread_binding(T.int64(80), thread="blockIdx.y"):
                        for ax1_0 in T.thread_binding((seq_len + T.int64(31)) // T.int64(32), thread="blockIdx.x"):
                            for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
                                for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
                                    for ax2_2 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                        for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                            # Zero this thread's 4 x 4 accumulators.
                                            for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(4)):
                                                for ax2_3_1_init in T.vectorized(T.int64(1)):
                                                    with T.block("NT_matmul_init"):
                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                        v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
                                                        v2 = T.axis.spatial(T.int64(2560), ax0_ax2_0_fused * T.int64(32) + ax2_1 * T.int64(32) + ax2_2 * T.int64(4) + ax2_3_0_init + ax2_3_1_init)
                                                        T.reads()
                                                        T.writes(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
                                                        NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = T.float16(0.0)
                                            # The 2048-long reduction is walked in 256 steps of 8.
                                            for ax3_0 in range(T.int64(256)):
                                                # Stage a 32-row x 8-k activation slab (8 * 8 threads x 4
                                                # elements); rows >= seq_len are filled with 0.
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(1)):
                                                                with T.block("rms_norm146_reindex_pad_shared"):
                                                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                                    v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // T.int64(8))
                                                                    v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(8) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % T.int64(8))
                                                                    T.reads(rms_norm146[v0, v1, v2])
                                                                    T.writes(rms_norm146_reindex_pad_shared[v0, v1, v2])
                                                                    rms_norm146_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, rms_norm146[v0, v1, v2], T.float16(0.0))
                                                # Dequantize the matching 32-weight-row x 8-k slab into the
                                                # shared-scope weight buffer.
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(1)):
                                                                with T.block("dequantize_intermediate_reindex_shared"):
                                                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                                    v1 = T.axis.spatial(T.int64(2560), ax0_ax2_0_fused * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // T.int64(8))
                                                                    v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(8) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % T.int64(8))
                                                                    T.reads(model_layers_0_self_attn_c_attn_q_weight3[v1, v2 // T.int64(8)], model_layers_0_self_attn_c_attn_q_scale3[v1, v2 // T.int64(32)])
                                                                    T.writes(dequantize_intermediate_reindex_shared[v0, v1, v2])
                                                                    dequantize_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight3[v1, v2 // T.int64(8)], T.Cast("uint32", v2 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale3[v1, v2 // T.int64(32)]
                                                # Accumulate the 8 staged k values into the 4 x 4 patch.
                                                for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(8), T.int64(4), T.int64(4)):
                                                    for ax2_3_1 in T.vectorized(T.int64(1)):
                                                        with T.block("NT_matmul_update"):
                                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                            v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
                                                            v2 = T.axis.spatial(T.int64(2560), ax0_ax2_0_fused * T.int64(32) + ax2_1 * T.int64(32) + ax2_2 * T.int64(4) + ax2_3_0 + ax2_3_1)
                                                            v3 = T.axis.reduce(T.int64(2048), ax3_0 * T.int64(8) + ax3_1)
                                                            T.reads(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2], rms_norm146_reindex_pad_shared[T.int64(0), v1, v3], dequantize_intermediate_reindex_shared[T.int64(0), v2, v3])
                                                            T.writes(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
                                                            NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] + rms_norm146_reindex_pad_shared[T.int64(0), v1, v3] * dequantize_intermediate_reindex_shared[T.int64(0), v2, v3]
                                            # Epilogue: add the bias and store the 4 x 4 patch; the
                                            # T.where predicate skips padded rows (>= seq_len).
                                            for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(4)):
                                                for ax2_1_1 in T.vectorized(T.int64(1)):
                                                    with T.block("NT_matmul_intermediate_reindex_pad_local"):
                                                        v0 = T.axis.spatial(T.int64(1), ax0)
                                                        v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
                                                        v2 = T.axis.spatial(T.int64(2560), ax0_ax2_0_fused * T.int64(32) + ax2_2 * T.int64(4) + ax2_0 + ax2_1_1)
                                                        T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < seq_len)
                                                        T.reads(NT_matmul_intermediate_reindex_pad_local[v0, v1, v2], model_layers_0_self_attn_c_attn_bias3[v2])
                                                        T.writes(T_add_intermediate_intermediate[T.int64(0), v1, v2])
                                                        T_add_intermediate_intermediate[T.int64(0), v1, v2] = NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] + model_layers_0_self_attn_c_attn_bias3[v2]

    @T.prim_func(private=True)
    def fused_dequantize1_fused_NT_matmul_add(model_layers_0_self_attn_c_attn_q_weight4: T.Buffer((T.int64(2560), T.int64(256)), "uint32"), model_layers_0_self_attn_c_attn_q_scale4: T.Buffer((T.int64(2560), T.int64(64)), "float16"), p_rms_norm219: T.handle, model_layers_0_self_attn_c_attn_bias4: T.Buffer((T.int64(2560),), "float16"), p_output0: T.handle):
        # Fused kernel: unpack the 4-bit quantized weight, multiply the activation by
        # the (transposed) dequantized weight, and add the bias:
        #
        #   out[b, 0, n] = sum_k rms_norm219[b, 0, k] * W[n, k] + bias[n]
        #   W[n, k]      = (float16((q_weight[n, k // 8] >> (k % 8 * 4)) & 15) - 7.0)
        #                  * q_scale[n, k // 32]
        #
        # Parameters:
        #   model_layers_0_self_attn_c_attn_q_weight4: (2560, 256) uint32; every word
        #       is unpacked into eight 4-bit values (2048 values per row).
        #   model_layers_0_self_attn_c_attn_q_scale4: (2560, 64) float16; one scale is
        #       shared by each run of 32 consecutive k.
        #   p_rms_norm219: handle matched to a (batch_size, 1, 2048) float16 buffer.
        #   model_layers_0_self_attn_c_attn_bias4: (2560,) float16 bias.
        #   p_output0: handle matched to the (batch_size, 1, 2560) float16 output.
        #
        # Three schedules are emitted and selected at run time on batch_size:
        #   batch_size <= 2 : staged reduction, batch padded to a multiple of 2
        #   batch_size <= 8 : same structure, batch padded to a multiple of 4
        #   otherwise       : 32x32 output tiles with operands staged in "shared" scope
        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int64()
        rms_norm219 = T.match_buffer(p_rms_norm219, (batch_size, T.int64(1), T.int64(2048)), "float16")
        T_add_intermediate_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(2560)), "float16")
        # with T.block("root"):
        # ---- Schedule 1: batch_size <= 2 ----
        if T.tvm_thread_invariant(batch_size <= T.int64(2)):
            with T.block("root"):
                T.reads()
                T.writes()
                # Scratch buffers in "local" scope. The batch extent is rounded up to a
                # multiple of 2 so every blockIdx.y handles exactly two rows.
                dequantize_intermediate_local = T.alloc_buffer((T.int64(2560), T.int64(2048)), "float16", scope="local")
                NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(1), T.int64(2560)), "float16", scope="local")
                # Partial-sum buffers: 128 partials per output, then folded down to 32.
                NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), (batch_size + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(1), T.int64(2560)), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(32), (batch_size + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(1), T.int64(2560)), "float16", scope="local")
                # Launch grid: blockIdx.y walks pairs of batch rows, blockIdx.x walks
                # 320 groups of 8 output columns (320 * 8 = 2560).
                for ax0_0 in T.thread_binding((batch_size + T.int64(1)) // T.int64(2), thread="blockIdx.y"):
                    for ax1_fused_0 in T.thread_binding(T.int64(320), thread="blockIdx.x"):
                        # Stage 1: each (threadIdx.y, threadIdx.x) pair owns 2 output
                        # columns and 4 of the 128 partial-sum slots.
                        for ax1_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                            for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                # Zero the 2 rows x 2 columns x 4 partials owned here.
                                for ax0_1_init, ax1_fused_2_init in T.grid(T.int64(2), T.int64(2)):
                                    for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                        with T.block("NT_matmul_rf_init"):
                                            vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1_init)
                                            v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax0_1_init)
                                            v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2_init)
                                            T.reads()
                                            T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1])
                                            NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] = T.float16(0.0)
                                # Walk the 2048-long reduction axis in 8 chunks of 256.
                                for ax2_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    # Dequantize 2 weight rows x 8 consecutive k for this
                                    # thread: extract the 4-bit field, recentre by -7, scale.
                                    for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
                                        for ax0_1 in T.vectorized(T.int64(1)):
                                            with T.block("dequantize"):
                                                v0 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
                                                v1 = T.axis.spatial(T.int64(2048), ax2_fused_0 * T.int64(256) + ax2_fused_1_ax2_fused_3_fused_0 * T.int64(8) + ax1)
                                                T.reads(model_layers_0_self_attn_c_attn_q_weight4[v0, v1 // T.int64(8)], model_layers_0_self_attn_c_attn_q_scale4[v0, v1 // T.int64(32)])
                                                T.writes(dequantize_intermediate_local[v0, v1])
                                                dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight4[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale4[v0, v1 // T.int64(32)]
                                    # Accumulate activation * weight into the partial sums.
                                    # Padded rows (v0 >= batch_size) contribute 0 through the
                                    # T.if_then_else guard on the activation value.
                                    for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(T.int64(2), T.int64(2), T.int64(2)):
                                        for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                            with T.block("NT_matmul_rf_update"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1)
                                                v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax0_1)
                                                v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2)
                                                vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1], rms_norm219[v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] + T.if_then_else(v0 < batch_size, rms_norm219[v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)]
                        # Stage 2: fold the 128 partials into 32 by summing each group of
                        # 4 consecutive slots (slot index = ax0 * 4 + ax1).
                        for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                            for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax2 in range(T.int64(2)):
                                        for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                                v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax2)
                                                v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                                NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] = T.float16(0.0)
                                            for ax1 in range(T.int64(4)):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                    v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax2)
                                                    v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                    T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, v0, T.int64(0), v1])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, v0, T.int64(0), v1]
                        # Stage 3: sum the remaining 32 partials. The reduce axis is the
                        # loop bound to threadIdx.x.
                        for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(2)):
                            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    with T.block("NT_matmul"):
                                        vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                        v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax1)
                                        v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                        T.writes(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                                        with T.init():
                                            NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = T.float16(0.0)
                                        NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1]
                        # Epilogue: add the bias and store. T.where drops the padded
                        # rows so only indices below batch_size reach the output.
                        for ax0 in range(T.int64(2)):
                            for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax1_fused_2 in range(T.int64(2)):
                                    with T.block("NT_matmul_intermediate_pad"):
                                        v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(2) + ax0)
                                        v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                        T.where((ax0_0 - (batch_size + T.int64(1)) // T.int64(2) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(2) + ax0 < batch_size)
                                        T.reads(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1], model_layers_0_self_attn_c_attn_bias4[v1])
                                        T.writes(T_add_intermediate_intermediate[v0, T.int64(0), v1])
                                        T_add_intermediate_intermediate[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] + model_layers_0_self_attn_c_attn_bias4[v1]
        else:
            # ---- Schedule 2: 2 < batch_size <= 8 ----
            # Same three-stage reduction as schedule 1, but each blockIdx.y handles
            # 4 batch rows and the batch extent is padded to a multiple of 4.
            if T.tvm_thread_invariant(batch_size <= T.int64(8)):
                with T.block("root"):
                    T.reads()
                    T.writes()
                    dequantize_intermediate_local = T.alloc_buffer((T.int64(2560), T.int64(2048)), "float16", scope="local")
                    NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(2560)), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(2560)), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(32), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(2560)), "float16", scope="local")
                    for ax0_0 in T.thread_binding((batch_size + T.int64(3)) // T.int64(4), thread="blockIdx.y"):
                        for ax1_fused_0 in T.thread_binding(T.int64(320), thread="blockIdx.x"):
                            # Stage 1: zero, dequantize, and accumulate 128 partials per
                            # output (4 rows x 2 columns x 4 partials per thread).
                            for ax1_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax0_1_init, ax1_fused_2_init in T.grid(T.int64(4), T.int64(2)):
                                        for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1_init)
                                                v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init)
                                                v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2_init)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] = T.float16(0.0)
                                    # 8 chunks of 256 along the reduction axis.
                                    for ax2_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        # Unpack 2 weight rows x 8 k-values (4-bit field, -7, * scale).
                                        for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
                                            for ax0_1 in T.vectorized(T.int64(1)):
                                                with T.block("dequantize"):
                                                    v0 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
                                                    v1 = T.axis.spatial(T.int64(2048), ax2_fused_0 * T.int64(256) + ax2_fused_1_ax2_fused_3_fused_0 * T.int64(8) + ax1)
                                                    T.reads(model_layers_0_self_attn_c_attn_q_weight4[v0, v1 // T.int64(8)], model_layers_0_self_attn_c_attn_q_scale4[v0, v1 // T.int64(32)])
                                                    T.writes(dequantize_intermediate_local[v0, v1])
                                                    dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight4[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale4[v0, v1 // T.int64(32)]
                                        # Multiply-accumulate; rows at or beyond batch_size read as 0.
                                        for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(2)):
                                            for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1)
                                                    v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1)
                                                    v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2)
                                                    vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                    T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1], rms_norm219[v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1])
                                                    NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] + T.if_then_else(v0 < batch_size, rms_norm219[v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)]
                            # Stage 2: fold 128 partials into 32 (groups of 4).
                            for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax2 in range(T.int64(4)):
                                            for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                                with T.block("NT_matmul_rf_init"):
                                                    vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                                    v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
                                                    v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                    T.reads()
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] = T.float16(0.0)
                                                for ax1 in range(T.int64(4)):
                                                    with T.block("NT_matmul_rf_update"):
                                                        vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                        v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
                                                        v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, v0, T.int64(0), v1])
                                                        T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                                        NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, v0, T.int64(0), v1]
                            # Stage 3: reduce the 32 partials over the threadIdx.x-bound axis.
                            for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(4)):
                                for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                    for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                        with T.block("NT_matmul"):
                                            vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                            v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1)
                                            v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                            T.writes(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                                            with T.init():
                                                NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = T.float16(0.0)
                                            NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1]
                            # Epilogue: bias add and guarded store of the real rows only.
                            for ax0 in range(T.int64(4)):
                                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                    for ax1_fused_2 in range(T.int64(2)):
                                        with T.block("NT_matmul_intermediate_pad"):
                                            v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0)
                                            v1 = T.axis.spatial(T.int64(2560), ax1_fused_0 * T.int64(8) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                            T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
                                            T.reads(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1], model_layers_0_self_attn_c_attn_bias4[v1])
                                            T.writes(T_add_intermediate_intermediate[v0, T.int64(0), v1])
                                            T_add_intermediate_intermediate[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] + model_layers_0_self_attn_c_attn_bias4[v1]
            else:
                # ---- Schedule 3: batch_size > 8 ----
                # Tiled matmul. The batch axis is re-indexed as the middle dimension of
                # (1, padded_batch, N) buffers and padded to a multiple of 32.
                with T.block("root"):
                    T.reads()
                    T.writes()
                    # Accumulator in "local" scope; both operands staged in "shared" scope.
                    NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((T.int64(1), (batch_size + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2560)), "float16", scope="local")
                    rms_norm219_reindex_pad_shared = T.alloc_buffer((T.int64(1), (batch_size + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2048)), "float16", scope="shared")
                    dequantize_intermediate_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(2560), T.int64(2048)), "float16", scope="shared")
                    # blockIdx.y: 80 tiles of 32 output columns (80 * 32 = 2560).
                    # blockIdx.x: tiles of 32 (padded) batch rows.
                    # 8 x 8 threads per block, each computing a 4 x 4 patch of the tile.
                    for ax0_ax2_0_fused in T.thread_binding(T.int64(80), thread="blockIdx.y"):
                        for ax1_0 in T.thread_binding((batch_size + T.int64(31)) // T.int64(32), thread="blockIdx.x"):
                            for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
                                for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
                                    for ax2_2 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                        for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                            # Zero this thread's 4 x 4 accumulator patch.
                                            for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(4)):
                                                for ax2_3_1_init in T.vectorized(T.int64(1)):
                                                    with T.block("NT_matmul_init"):
                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                        v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
                                                        v2 = T.axis.spatial(T.int64(2560), ax0_ax2_0_fused * T.int64(32) + ax2_1 * T.int64(32) + ax2_2 * T.int64(4) + ax2_3_0_init + ax2_3_1_init)
                                                        T.reads()
                                                        T.writes(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
                                                        NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = T.float16(0.0)
                                            # Reduction over k = 2048 in 256 steps of 8.
                                            for ax3_0 in range(T.int64(256)):
                                                # Stage a 32-row x 8-k activation tile; the fused
                                                # index (0..255) is split as row = idx // 8,
                                                # k = idx % 8. Padded rows are filled with 0.
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(1)):
                                                                with T.block("rms_norm219_reindex_pad_shared"):
                                                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                                    v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // T.int64(8))
                                                                    v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(8) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % T.int64(8))
                                                                    T.reads(rms_norm219[v1, T.int64(0), v2])
                                                                    T.writes(rms_norm219_reindex_pad_shared[v0, v1, v2])
                                                                    rms_norm219_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < batch_size, rms_norm219[v1, T.int64(0), v2], T.float16(0.0))
                                                # Stage the matching 32-column x 8-k weight tile,
                                                # dequantizing on the fly (4-bit field, -7, * scale).
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(1)):
                                                                with T.block("dequantize_intermediate_reindex_shared"):
                                                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                                    v1 = T.axis.spatial(T.int64(2560), ax0_ax2_0_fused * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // T.int64(8))
                                                                    v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(8) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % T.int64(8))
                                                                    T.reads(model_layers_0_self_attn_c_attn_q_weight4[v1, v2 // T.int64(8)], model_layers_0_self_attn_c_attn_q_scale4[v1, v2 // T.int64(32)])
                                                                    T.writes(dequantize_intermediate_reindex_shared[v0, v1, v2])
                                                                    dequantize_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight4[v1, v2 // T.int64(8)], T.Cast("uint32", v2 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale4[v1, v2 // T.int64(32)]
                                                # Accumulate the 8 staged k-values into the 4 x 4 patch.
                                                for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(8), T.int64(4), T.int64(4)):
                                                    for ax2_3_1 in T.vectorized(T.int64(1)):
                                                        with T.block("NT_matmul_update"):
                                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                            v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
                                                            v2 = T.axis.spatial(T.int64(2560), ax0_ax2_0_fused * T.int64(32) + ax2_1 * T.int64(32) + ax2_2 * T.int64(4) + ax2_3_0 + ax2_3_1)
                                                            v3 = T.axis.reduce(T.int64(2048), ax3_0 * T.int64(8) + ax3_1)
                                                            T.reads(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2], rms_norm219_reindex_pad_shared[T.int64(0), v1, v3], dequantize_intermediate_reindex_shared[T.int64(0), v2, v3])
                                                            T.writes(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
                                                            NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] + rms_norm219_reindex_pad_shared[T.int64(0), v1, v3] * dequantize_intermediate_reindex_shared[T.int64(0), v2, v3]
                                            # Epilogue: add the bias and write the patch back.
                                            # The padded row index v1 is the batch index of the
                                            # output; T.where skips rows at or beyond batch_size.
                                            for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(4)):
                                                for ax2_1_1 in T.vectorized(T.int64(1)):
                                                    with T.block("NT_matmul_intermediate_reindex_pad_local"):
                                                        v0 = T.axis.spatial(T.int64(1), ax0)
                                                        v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
                                                        v2 = T.axis.spatial(T.int64(2560), ax0_ax2_0_fused * T.int64(32) + ax2_2 * T.int64(4) + ax2_0 + ax2_1_1)
                                                        T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < batch_size)
                                                        T.reads(NT_matmul_intermediate_reindex_pad_local[v0, v1, v2], model_layers_0_self_attn_c_attn_bias4[v2])
                                                        T.writes(T_add_intermediate_intermediate[v1, T.int64(0), v2])
                                                        T_add_intermediate_intermediate[v1, T.int64(0), v2] = NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] + model_layers_0_self_attn_c_attn_bias4[v2]

    @T.prim_func(private=True)
    def fused_dequantize2_NT_matmul1(model_layers_0_self_attn_o_proj_q_weight4: T.Buffer((T.int64(2048), T.int64(256)), "uint32"), model_layers_0_self_attn_o_proj_q_scale4: T.Buffer((T.int64(2048), T.int64(64)), "float16"), p_reshape435: T.handle, p_output0: T.handle):
        # Fused 4-bit weight dequantization followed by an "NT" matmul (the weight is
        # indexed [out_col, k], i.e. used transposed):
        #   NT_matmul_intermediate[b, 0, j] = sum_k reshape435[b, 0, k] * W[j, k]
        #   W[j, k] = (((q_weight[j, k // 8] >> ((k % 8) * 4)) & 15) - 7.0) * q_scale[j, k // 32]
        # Each uint32 of q_weight therefore packs eight 4-bit values, and every 32
        # consecutive k share one float16 scale.
        #
        # Parameters:
        #   model_layers_0_self_attn_o_proj_q_weight4: (2048, 256) uint32 packed weights.
        #   model_layers_0_self_attn_o_proj_q_scale4:  (2048, 64) float16 scales.
        #   p_reshape435: handle matched below to a (batch_size, 1, 2048) float16 input.
        #   p_output0:    handle matched below to the (batch_size, 1, 2048) float16 output.
        #
        # Three schedules are selected on the symbolic batch_size:
        #   batch_size <= 2 : per-thread partial sums, batch padded to a multiple of 2.
        #   batch_size <= 8 : same structure, batch padded to a multiple of 4.
        #   otherwise       : 32x32 output tiles staged through "shared"-scope buffers.
        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int64()
        reshape435 = T.match_buffer(p_reshape435, (batch_size, T.int64(1), T.int64(2048)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(2048)), "float16")
        # with T.block("root"):
        # ---- Schedule 1: batch_size <= 2 (batch rows processed in tiles of 2) ----
        if T.tvm_thread_invariant(batch_size <= T.int64(2)):
            with T.block("root"):
                T.reads()
                T.writes()
                # Scratch buffers, all in "local" scope. The *_rf_* buffers carry an extra
                # leading axis (128, then 32) holding partial sums of the k reduction.
                dequantize_intermediate_local = T.alloc_buffer((T.int64(2048), T.int64(2048)), "float16", scope="local")
                NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(1), T.int64(2048)), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), (batch_size + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(1), T.int64(2048)), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(32), (batch_size + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(1), T.int64(2048)), "float16", scope="local")
                # blockIdx.y: batch tile of 2 rows; blockIdx.x: 16 output columns (128 * 16 = 2048).
                for ax0_0 in T.thread_binding((batch_size + T.int64(1)) // T.int64(2), thread="blockIdx.y"):
                    for ax1_fused_0 in T.thread_binding(T.int64(128), thread="blockIdx.x"):
                        # threadIdx.y picks 2 of the 16 output columns; threadIdx.x splits the k axis.
                        for ax1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                            for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                # Zero the partial sums: 2 rows x 2 columns x 4 vector lanes per thread
                                # (lane index = threadIdx.x * 4 + vector lane, 128 lanes total).
                                for ax0_1_init, ax1_fused_2_init in T.grid(T.int64(2), T.int64(2)):
                                    for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                        with T.block("NT_matmul_rf_init"):
                                            vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1_init)
                                            v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax0_1_init)
                                            v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax1_fused_2_init)
                                            T.reads()
                                            T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1])
                                            NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] = T.float16(0.0)
                                # Outer k loop: 8 steps of 256 columns each (8 * 256 = 2048).
                                for ax2_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    # Dequantize 2 weight rows x 8 consecutive k (one packed uint32 per row)
                                    # for this thread's slice of the current k step.
                                    for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
                                        for ax0_1 in T.vectorized(T.int64(1)):
                                            with T.block("dequantize"):
                                                v0 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
                                                v1 = T.axis.spatial(T.int64(2048), ax2_fused_0 * T.int64(256) + ax2_fused_1_ax2_fused_3_fused_0 * T.int64(8) + ax1)
                                                T.reads(model_layers_0_self_attn_o_proj_q_weight4[v0, v1 // T.int64(8)], model_layers_0_self_attn_o_proj_q_scale4[v0, v1 // T.int64(32)])
                                                T.writes(dequantize_intermediate_local[v0, v1])
                                                dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight4[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale4[v0, v1 // T.int64(32)]
                                    # Accumulate input * weight into the per-lane partial sums. Rows with
                                    # v0 >= batch_size (padding) contribute 0 via T.if_then_else.
                                    for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(T.int64(2), T.int64(2), T.int64(2)):
                                        for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                            with T.block("NT_matmul_rf_update"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1)
                                                v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax0_1)
                                                v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax1_fused_2)
                                                vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1], reshape435[v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] + T.if_then_else(v0 < batch_size, reshape435[v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)]
                        # Second stage: fold the 4 vector lanes of each threadIdx.x into a single
                        # partial sum (128 lanes -> 32 entries).
                        for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                            for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax2 in range(T.int64(2)):
                                        for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                                v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax2)
                                                v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                                NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] = T.float16(0.0)
                                            for ax1 in range(T.int64(4)):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                    v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax2)
                                                    v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                    T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, v0, T.int64(0), v1])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, v0, T.int64(0), v1]
                        # Final stage: sum the 32 partial sums; the reduce axis is the loop bound
                        # to threadIdx.x.
                        for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(2)):
                            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    with T.block("NT_matmul"):
                                        vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                        v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax1)
                                        v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                        T.writes(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                                        with T.init():
                                            NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = T.float16(0.0)
                                        NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1]
                        # Copy the padded result to the output; T.where drops rows >= batch_size.
                        for ax0 in range(T.int64(2)):
                            for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                for ax1_fused_2 in range(T.int64(2)):
                                    with T.block("NT_matmul_intermediate_pad"):
                                        v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(2) + ax0)
                                        v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                        T.where((ax0_0 - (batch_size + T.int64(1)) // T.int64(2) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(2) + ax0 < batch_size)
                                        T.reads(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                                        T.writes(NT_matmul_intermediate[v0, T.int64(0), v1])
                                        NT_matmul_intermediate[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1]
        else:
            # ---- Schedule 2: 2 < batch_size <= 8 (same stages as above, batch tile of 4) ----
            if T.tvm_thread_invariant(batch_size <= T.int64(8)):
                with T.block("root"):
                    T.reads()
                    T.writes()
                    # Same "local"-scope scratch buffers as schedule 1, with the batch extent
                    # rounded up to a multiple of 4.
                    dequantize_intermediate_local = T.alloc_buffer((T.int64(2048), T.int64(2048)), "float16", scope="local")
                    NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(2048)), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(2048)), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(32), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(2048)), "float16", scope="local")
                    # blockIdx.y: batch tile of 4 rows; blockIdx.x: 16 output columns.
                    for ax0_0 in T.thread_binding((batch_size + T.int64(3)) // T.int64(4), thread="blockIdx.y"):
                        for ax1_fused_0 in T.thread_binding(T.int64(128), thread="blockIdx.x"):
                            for ax1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    # Zero the partial sums: 4 rows x 2 columns x 4 vector lanes per thread.
                                    for ax0_1_init, ax1_fused_2_init in T.grid(T.int64(4), T.int64(2)):
                                        for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1_init)
                                                v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init)
                                                v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax1_fused_2_init)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] = T.float16(0.0)
                                    # Outer k loop: 8 steps of 256 columns each.
                                    for ax2_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        # Dequantize 2 weight rows x 8 consecutive k for this thread.
                                        for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
                                            for ax0_1 in T.vectorized(T.int64(1)):
                                                with T.block("dequantize"):
                                                    v0 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
                                                    v1 = T.axis.spatial(T.int64(2048), ax2_fused_0 * T.int64(256) + ax2_fused_1_ax2_fused_3_fused_0 * T.int64(8) + ax1)
                                                    T.reads(model_layers_0_self_attn_o_proj_q_weight4[v0, v1 // T.int64(8)], model_layers_0_self_attn_o_proj_q_scale4[v0, v1 // T.int64(32)])
                                                    T.writes(dequantize_intermediate_local[v0, v1])
                                                    dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight4[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale4[v0, v1 // T.int64(32)]
                                        # Accumulate partial products; padded rows (v0 >= batch_size) add 0.
                                        for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(2)):
                                            for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1)
                                                    v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1)
                                                    v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax1_fused_2)
                                                    vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                    T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1], reshape435[v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1])
                                                    NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] + T.if_then_else(v0 < batch_size, reshape435[v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)]
                            # Second stage: fold 4 vector lanes per threadIdx.x (128 -> 32 entries).
                            for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax2 in range(T.int64(4)):
                                            for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                                with T.block("NT_matmul_rf_init"):
                                                    vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                                    v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
                                                    v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                    T.reads()
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] = T.float16(0.0)
                                                for ax1 in range(T.int64(4)):
                                                    with T.block("NT_matmul_rf_update"):
                                                        vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                        v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
                                                        v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, v0, T.int64(0), v1])
                                                        T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                                        NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, v0, T.int64(0), v1]
                            # Final stage: sum the 32 partial sums indexed by the threadIdx.x loop.
                            for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(4)):
                                for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                    for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                        with T.block("NT_matmul"):
                                            vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                            v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1)
                                            v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                            T.writes(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                                            with T.init():
                                                NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = T.float16(0.0)
                                            NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1]
                            # Copy the padded result to the output; T.where drops rows >= batch_size.
                            for ax0 in range(T.int64(4)):
                                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                    for ax1_fused_2 in range(T.int64(2)):
                                        with T.block("NT_matmul_intermediate_pad"):
                                            v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0)
                                            v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                            T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
                                            T.reads(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                                            T.writes(NT_matmul_intermediate[v0, T.int64(0), v1])
                                            NT_matmul_intermediate[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1]
            else:
                # ---- Schedule 3: batch_size > 8, tiled matmul with "shared"-scope staging ----
                with T.block("root"):
                    T.reads()
                    T.writes()
                    # Accumulator in "local" scope; input and dequantized-weight tiles in "shared"
                    # scope. The batch extent is rounded up to a multiple of 32.
                    NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((T.int64(1), (batch_size + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2048)), "float16", scope="local")
                    reshape435_reindex_pad_shared = T.alloc_buffer((T.int64(1), (batch_size + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2048)), "float16", scope="shared")
                    dequantize_intermediate_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(2048), T.int64(2048)), "float16", scope="shared")
                    # blockIdx.y: 32 output columns (64 * 32 = 2048); blockIdx.x: 32 batch rows.
                    for ax0_ax2_0_fused in T.thread_binding(T.int64(64), thread="blockIdx.y"):
                        for ax1_0 in T.thread_binding((batch_size + T.int64(31)) // T.int64(32), thread="blockIdx.x"):
                            for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
                                for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
                                    # 8 x 8 threads; each thread owns a 4 (rows) x 4 (columns) output patch.
                                    for ax2_2 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                        for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                            # Zero this thread's 4 x 4 accumulators.
                                            for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(4)):
                                                for ax2_3_1_init in T.vectorized(T.int64(1)):
                                                    with T.block("NT_matmul_init"):
                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                        v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
                                                        v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(32) + ax2_1 * T.int64(32) + ax2_2 * T.int64(4) + ax2_3_0_init + ax2_3_1_init)
                                                        T.reads()
                                                        T.writes(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
                                                        NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = T.float16(0.0)
                                            # Outer k loop: 256 steps of 8 columns each (256 * 8 = 2048).
                                            for ax3_0 in range(T.int64(256)):
                                                # Stage a 32 (rows) x 8 (k) input tile; rows >= batch_size are
                                                # filled with 0. The flat index ty*32 + tx*4 + i is split as
                                                # row = idx // 8, col = idx % 8.
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(1)):
                                                                with T.block("reshape435_reindex_pad_shared"):
                                                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                                    v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // T.int64(8))
                                                                    v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(8) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % T.int64(8))
                                                                    T.reads(reshape435[v1, T.int64(0), v2])
                                                                    T.writes(reshape435_reindex_pad_shared[v0, v1, v2])
                                                                    reshape435_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < batch_size, reshape435[v1, T.int64(0), v2], T.float16(0.0))
                                                # Dequantize and stage the matching 32 (output cols) x 8 (k) weight tile.
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(1)):
                                                                with T.block("dequantize_intermediate_reindex_shared"):
                                                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                                    v1 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // T.int64(8))
                                                                    v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(8) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % T.int64(8))
                                                                    T.reads(model_layers_0_self_attn_o_proj_q_weight4[v1, v2 // T.int64(8)], model_layers_0_self_attn_o_proj_q_scale4[v1, v2 // T.int64(32)])
                                                                    T.writes(dequantize_intermediate_reindex_shared[v0, v1, v2])
                                                                    dequantize_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight4[v1, v2 // T.int64(8)], T.Cast("uint32", v2 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale4[v1, v2 // T.int64(32)]
                                                # Multiply-accumulate over the 8 staged k values into the 4 x 4 patch.
                                                for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(8), T.int64(4), T.int64(4)):
                                                    for ax2_3_1 in T.vectorized(T.int64(1)):
                                                        with T.block("NT_matmul_update"):
                                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                            v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
                                                            v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(32) + ax2_1 * T.int64(32) + ax2_2 * T.int64(4) + ax2_3_0 + ax2_3_1)
                                                            v3 = T.axis.reduce(T.int64(2048), ax3_0 * T.int64(8) + ax3_1)
                                                            T.reads(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2], reshape435_reindex_pad_shared[T.int64(0), v1, v3], dequantize_intermediate_reindex_shared[T.int64(0), v2, v3])
                                                            T.writes(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
                                                            NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] + reshape435_reindex_pad_shared[T.int64(0), v1, v3] * dequantize_intermediate_reindex_shared[T.int64(0), v2, v3]
                                            # Write the 4 x 4 patch to the output; T.where drops rows >= batch_size.
                                            for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(4)):
                                                for ax2_1_1 in T.vectorized(T.int64(1)):
                                                    with T.block("NT_matmul_intermediate_reindex_pad_local"):
                                                        v0 = T.axis.spatial(T.int64(1), ax0)
                                                        v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
                                                        v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(32) + ax2_2 * T.int64(4) + ax2_0 + ax2_1_1)
                                                        T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < batch_size)
                                                        T.reads(NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                                        T.writes(NT_matmul_intermediate[v1, T.int64(0), v2])
                                                        NT_matmul_intermediate[v1, T.int64(0), v2] = NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]

    @T.prim_func(private=True)
    def fused_dequantize2_NT_matmul11(model_layers_0_self_attn_o_proj_q_weight2: T.Buffer((T.int64(2048), T.int64(256)), "uint32"), model_layers_0_self_attn_o_proj_q_scale2: T.Buffer((T.int64(2048), T.int64(64)), "float16"), lv218: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), NT_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")):
        # Fused 4-bit weight dequantization + vector-matrix product.
        #
        # For every output index v0 in [0, 2048) this kernel computes
        #   NT_matmul_intermediate[0, 0, v0] = sum over k in [0, 2048) of lv218[0, 0, k] * W[v0, k]
        # where W is never stored in full; each element is decoded on the fly as
        #   W[v0, k] = (float16((q_weight[v0, k // 8] >> (k % 8 * 4)) & 15) - 7) * q_scale[v0, k // 32]
        #
        # Parameters:
        #   model_layers_0_self_attn_o_proj_q_weight2: (2048, 256) uint32. Each word is read as eight
        #       4-bit fields (shift by (k % 8) * 4, mask with 15).
        #   model_layers_0_self_attn_o_proj_q_scale2: (2048, 64) float16. One scale per 32 consecutive k.
        #   lv218: (1, 1, 2048) float16 input vector.
        #   NT_matmul_intermediate: (1, 1, 2048) float16 output, fully overwritten.
        #
        # Thread layout: 256 blockIdx.x x 8 threadIdx.y x 32 threadIdx.x. The output index handled by a
        # (blockIdx.x, threadIdx.y) pair is blockIdx.x * 8 + threadIdx.y; the 32 threadIdx.x lanes split
        # the reduction over k. The reduction is staged as 128 partial sums -> 32 partial sums -> 1 value.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        # First-level partial sums: 128 lanes per output index (4 lanes per threadIdx.x value).
        NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(2048)), "float16", scope="local")
        # Second-level partial sums: one per threadIdx.x value (32 per output index).
        NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), T.int64(1), T.int64(2048)), "float16", scope="local")
        # Staging buffer for the packed weight words read by the update block below.
        model_layers_0_self_attn_o_proj_q_weight2_local = T.alloc_buffer((T.int64(2048), T.int64(256)), "uint32", scope="local")
        # Copy of the input vector in "shared" scope, filled once per blockIdx.x iteration.
        lv218_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16", scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                    # Stage 1: copy all 2048 elements of lv218 into lv218_shared. The copy loops bind
                    # threadIdx.y / threadIdx.x again, so each (y, x) pair copies 4 contiguous elements
                    # in each of the 2 serial iterations: 2 * 8 * 32 * 4 = 2048.
                    for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
                        for ax2_0 in T.serial(T.int64(2), annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(T.int64(4)):
                                        with T.block("lv218_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(T.int64(2048), ax2_0 * T.int64(1024) + ax2_1 * T.int64(128) + ax2_2 * T.int64(4) + ax2_3)
                                            T.reads(lv218[v0, v1, v2])
                                            T.writes(lv218_shared[v0, v1, v2])
                                            lv218_shared[v0, v1, v2] = lv218[v0, v1, v2]
                    # Stage 2: zero the 4 first-level partial sums owned by this threadIdx.x value
                    # (lanes threadIdx.x * 4 .. threadIdx.x * 4 + 3) for this output index.
                    for u_fused_ax0_fused_fused_2_init in range(T.int64(1)):
                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
                                v0 = T.axis.spatial(T.int64(2048), u_fused_ax0_fused_fused_0 * T.int64(8) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0])
                                NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float16(0.0)
                    # Stage 3: main reduction. Each of the 8 outer iterations covers a 256-wide slice of k.
                    for ax1_0_fused_ax1_1_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        # Load the single packed uint32 word (8 quantized values) this thread needs for
                        # the current slice: column index = outer * 32 + threadIdx.x.
                        for ax0_ax1_fused_0 in range(T.int64(1)):
                            for ax0_ax1_fused_1 in T.vectorized(T.int64(1)):
                                with T.block("model_layers_0_self_attn_o_proj_q_weight2_local"):
                                    v0 = T.axis.spatial(T.int64(2048), u_fused_ax0_fused_fused_0 * T.int64(8) + u_fused_ax0_fused_fused_1)
                                    v1 = T.axis.spatial(T.int64(256), ax1_0_fused_ax1_1_fused_0 * T.int64(32) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
                                    T.reads(model_layers_0_self_attn_o_proj_q_weight2[v0, v1])
                                    T.writes(model_layers_0_self_attn_o_proj_q_weight2_local[v0, v1])
                                    model_layers_0_self_attn_o_proj_q_weight2_local[v0, v1] = model_layers_0_self_attn_o_proj_q_weight2[v0, v1]
                        # Decode and accumulate the 8 values of that word: 2 serial steps x 4 vector lanes.
                        # With lane = the 128-wide rfactor axis, the reduction index is
                        #   k = outer * 256 + (lane // 4) * 8 + inner * 4 + lane % 4.
                        # The weight column "outer * 32 + lane // 4 + inner // 2" equals k // 8, because the
                        # inner axis has extent 2 and so "inner // 2" is always 0.
                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(T.int64(1), T.int64(2)):
                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
                                    v0 = T.axis.spatial(T.int64(2048), u_fused_ax0_fused_fused_0 * T.int64(8) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
                                    T.reads(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0], lv218_shared[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)], model_layers_0_self_attn_o_proj_q_weight2_local[v0, vax1_0_fused_ax1_1_fused_0 * T.int64(32) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) + vax1_0_fused_ax1_1_fused_2 // T.int64(2)], model_layers_0_self_attn_o_proj_q_scale2[v0, (vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)) // T.int64(32)])
                                    T.writes(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0])
                                    # rf_local[lane] += lv218_shared[k] * ((float16((word >> (k % 8 * 4)) & 15) - 7) * scale[v0, k // 32])
                                    NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0] + lv218_shared[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)] * ((T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight2_local[v0, vax1_0_fused_ax1_1_fused_0 * T.int64(32) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) + vax1_0_fused_ax1_1_fused_2 // T.int64(2)], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)) % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale2[v0, (vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)) // T.int64(32)])
            # Stage 4: second-level reduction. For each threadIdx.x value, fold its 4 first-level
            # partial sums (lanes x * 4 .. x * 4 + 3) into one entry of rf_local_1.
            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                    for ax2_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_2_1 in T.vectorized(T.int64(1)):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                v0 = T.axis.spatial(T.int64(2048), u_fused_ax0_fused_fused_0 * T.int64(8) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                T.reads()
                                T.writes(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                                NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0.0)
                            for ax1 in range(T.int64(4)):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(T.int64(2048), u_fused_ax0_fused_fused_0 * T.int64(8) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                    T.reads(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0], NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, T.int64(0), T.int64(0), v0])
                                    T.writes(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                                    NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, T.int64(0), T.int64(0), v0]
            # Stage 5: final reduction. The reduce axis (extent 32) is bound to threadIdx.x, so the sum
            # combines the per-threadIdx.x partial sums into the output element for each output index.
            # NOTE(review): presumably lowered by TVM into a cross-thread reduction - confirm against
            # the lowering pipeline in use.
            for ax1_fused_2 in range(T.int64(1)):
                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                    for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                            v0 = T.axis.spatial(T.int64(2048), u_fused_ax0_fused_fused_0 * T.int64(8) + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2)
                            T.reads(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                            T.writes(NT_matmul_intermediate[T.int64(0), T.int64(0), v0])
                            # The output is zero-initialized here, so the caller need not clear it.
                            with T.init():
                                NT_matmul_intermediate[T.int64(0), T.int64(0), v0] = T.float16(0.0)
                            NT_matmul_intermediate[T.int64(0), T.int64(0), v0] = NT_matmul_intermediate[T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0]

    @T.prim_func(private=True)
    def fused_dequantize2_NT_matmul6(model_layers_0_self_attn_o_proj_q_weight3: T.Buffer((T.int64(2048), T.int64(256)), "uint32"), model_layers_0_self_attn_o_proj_q_scale3: T.Buffer((T.int64(2048), T.int64(64)), "float16"), p_reshape291: T.handle, p_output0: T.handle):
        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int64()
        reshape291 = T.match_buffer(p_reshape291, (T.int64(1), seq_len, T.int64(2048)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(2048)), "float16")
        # with T.block("root"):
        if T.tvm_thread_invariant(seq_len <= T.int64(2)):
            with T.block("root"):
                T.reads()
                T.writes()
                dequantize_intermediate_local = T.alloc_buffer((T.int64(2048), T.int64(2048)), "float16", scope="local")
                NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(2048)), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), (seq_len + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(2048)), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), (seq_len + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(2048)), "float16", scope="local")
                for ax0_0 in T.thread_binding((seq_len + T.int64(1)) // T.int64(2), thread="blockIdx.y"):
                    for ax1_fused_0 in T.thread_binding(T.int64(128), thread="blockIdx.x"):
                        for ax1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                            for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                for ax0_1_init, ax1_fused_2_init in T.grid(T.int64(2), T.int64(2)):
                                    for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                        with T.block("NT_matmul_rf_init"):
                                            vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1_init)
                                            v0 = T.axis.spatial((seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax0_1_init)
                                            v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax1_fused_2_init)
                                            T.reads()
                                            T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1])
                                            NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] = T.float16(0.0)
                                for ax2_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
                                        for ax0_1 in T.vectorized(T.int64(1)):
                                            with T.block("dequantize"):
                                                v0 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
                                                v1 = T.axis.spatial(T.int64(2048), ax2_fused_0 * T.int64(256) + ax2_fused_1_ax2_fused_3_fused_0 * T.int64(8) + ax1)
                                                T.reads(model_layers_0_self_attn_o_proj_q_weight3[v0, v1 // T.int64(8)], model_layers_0_self_attn_o_proj_q_scale3[v0, v1 // T.int64(32)])
                                                T.writes(dequantize_intermediate_local[v0, v1])
                                                dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight3[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale3[v0, v1 // T.int64(32)]
                                    for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(T.int64(2), T.int64(2), T.int64(2)):
                                        for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                            with T.block("NT_matmul_rf_update"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1)
                                                v0 = T.axis.spatial((seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax0_1)
                                                v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax1_fused_2)
                                                vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1], reshape291[T.int64(0), v0, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] + T.if_then_else(v0 < seq_len, reshape291[T.int64(0), v0, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)]
                        for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                            for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax2 in range(T.int64(2)):
                                        for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                                v0 = T.axis.spatial((seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax2)
                                                v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                                NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] = T.float16(0.0)
                                            for ax1 in range(T.int64(4)):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                    v0 = T.axis.spatial((seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax2)
                                                    v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                    T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, T.int64(0), v0, v1])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, T.int64(0), v0, v1]
                        for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(2)):
                            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    with T.block("NT_matmul"):
                                        vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                        v0 = T.axis.spatial((seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax1)
                                        v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                        T.writes(NT_matmul_intermediate_pad_local[T.int64(0), v0, v1])
                                        with T.init():
                                            NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] = T.float16(0.0)
                                        NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] = NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1]
                        for ax0 in range(T.int64(2)):
                            for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                for ax1_fused_2 in range(T.int64(2)):
                                    with T.block("NT_matmul_intermediate_pad"):
                                        v0 = T.axis.spatial(seq_len, ax0_0 * T.int64(2) + ax0)
                                        v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                        T.where((ax0_0 - (seq_len + T.int64(1)) // T.int64(2) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(2) + ax0 < seq_len)
                                        T.reads(NT_matmul_intermediate_pad_local[T.int64(0), v0, v1])
                                        T.writes(NT_matmul_intermediate[T.int64(0), v0, v1])
                                        NT_matmul_intermediate[T.int64(0), v0, v1] = NT_matmul_intermediate_pad_local[T.int64(0), v0, v1]
        else:
            if T.tvm_thread_invariant(seq_len <= T.int64(8)):
                with T.block("root"):
                    T.reads()
                    T.writes()
                    dequantize_intermediate_local = T.alloc_buffer((T.int64(2048), T.int64(2048)), "float16", scope="local")
                    NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(2048)), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), (seq_len + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(2048)), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), (seq_len + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(2048)), "float16", scope="local")
                    for ax0_0 in T.thread_binding((seq_len + T.int64(3)) // T.int64(4), thread="blockIdx.y"):
                        for ax1_fused_0 in T.thread_binding(T.int64(128), thread="blockIdx.x"):
                            for ax1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax0_1_init, ax1_fused_2_init in T.grid(T.int64(4), T.int64(2)):
                                        for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1_init)
                                                v0 = T.axis.spatial((seq_len + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init)
                                                v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax1_fused_2_init)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] = T.float16(0.0)
                                    for ax2_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
                                            for ax0_1 in T.vectorized(T.int64(1)):
                                                with T.block("dequantize"):
                                                    v0 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
                                                    v1 = T.axis.spatial(T.int64(2048), ax2_fused_0 * T.int64(256) + ax2_fused_1_ax2_fused_3_fused_0 * T.int64(8) + ax1)
                                                    T.reads(model_layers_0_self_attn_o_proj_q_weight3[v0, v1 // T.int64(8)], model_layers_0_self_attn_o_proj_q_scale3[v0, v1 // T.int64(32)])
                                                    T.writes(dequantize_intermediate_local[v0, v1])
                                                    dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight3[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale3[v0, v1 // T.int64(32)]
                                        for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(2)):
                                            for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1)
                                                    v0 = T.axis.spatial((seq_len + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1)
                                                    v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax1_fused_2)
                                                    vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                    T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1], reshape291[T.int64(0), v0, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] + T.if_then_else(v0 < seq_len, reshape291[T.int64(0), v0, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)]
                            for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax2 in range(T.int64(4)):
                                            for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                                with T.block("NT_matmul_rf_init"):
                                                    vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                                    v0 = T.axis.spatial((seq_len + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
                                                    v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                    T.reads()
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] = T.float16(0.0)
                                                for ax1 in range(T.int64(4)):
                                                    with T.block("NT_matmul_rf_update"):
                                                        vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                        v0 = T.axis.spatial((seq_len + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
                                                        v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, T.int64(0), v0, v1])
                                                        T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                                        NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, T.int64(0), v0, v1]
                            for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(4)):
                                for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                    for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                        with T.block("NT_matmul"):
                                            vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                            v0 = T.axis.spatial((seq_len + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1)
                                            v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                            T.writes(NT_matmul_intermediate_pad_local[T.int64(0), v0, v1])
                                            with T.init():
                                                NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] = T.float16(0.0)
                                            NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] = NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1]
                            for ax0 in range(T.int64(4)):
                                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                    for ax1_fused_2 in range(T.int64(2)):
                                        with T.block("NT_matmul_intermediate_pad"):
                                            v0 = T.axis.spatial(seq_len, ax0_0 * T.int64(4) + ax0)
                                            v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                            T.where((ax0_0 - (seq_len + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < seq_len)
                                            T.reads(NT_matmul_intermediate_pad_local[T.int64(0), v0, v1])
                                            T.writes(NT_matmul_intermediate[T.int64(0), v0, v1])
                                            NT_matmul_intermediate[T.int64(0), v0, v1] = NT_matmul_intermediate_pad_local[T.int64(0), v0, v1]
            else:
                with T.block("root"):
                    T.reads()
                    T.writes()
                    NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2048)), "float16", scope="local")
                    reshape291_reindex_pad_shared = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2048)), "float16", scope="shared")
                    dequantize_intermediate_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(2048), T.int64(2048)), "float16", scope="shared")
                    for ax0_ax2_0_fused in T.thread_binding(T.int64(64), thread="blockIdx.y"):
                        for ax1_0 in T.thread_binding((seq_len + T.int64(31)) // T.int64(32), thread="blockIdx.x"):
                            for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
                                for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
                                    for ax2_2 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                        for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                            for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(4)):
                                                for ax2_3_1_init in T.vectorized(T.int64(1)):
                                                    with T.block("NT_matmul_init"):
                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                        v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
                                                        v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(32) + ax2_1 * T.int64(32) + ax2_2 * T.int64(4) + ax2_3_0_init + ax2_3_1_init)
                                                        T.reads()
                                                        T.writes(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
                                                        NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = T.float16(0.0)
                                            for ax3_0 in range(T.int64(256)):
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(1)):
                                                                with T.block("reshape291_reindex_pad_shared"):
                                                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                                    v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // T.int64(8))
                                                                    v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(8) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % T.int64(8))
                                                                    T.reads(reshape291[v0, v1, v2])
                                                                    T.writes(reshape291_reindex_pad_shared[v0, v1, v2])
                                                                    reshape291_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, reshape291[v0, v1, v2], T.float16(0.0))
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(1)):
                                                                with T.block("dequantize_intermediate_reindex_shared"):
                                                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                                    v1 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // T.int64(8))
                                                                    v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(8) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % T.int64(8))
                                                                    T.reads(model_layers_0_self_attn_o_proj_q_weight3[v1, v2 // T.int64(8)], model_layers_0_self_attn_o_proj_q_scale3[v1, v2 // T.int64(32)])
                                                                    T.writes(dequantize_intermediate_reindex_shared[v0, v1, v2])
                                                                    dequantize_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight3[v1, v2 // T.int64(8)], T.Cast("uint32", v2 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale3[v1, v2 // T.int64(32)]
                                                for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(8), T.int64(4), T.int64(4)):
                                                    for ax2_3_1 in T.vectorized(T.int64(1)):
                                                        with T.block("NT_matmul_update"):
                                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                            v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
                                                            v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(32) + ax2_1 * T.int64(32) + ax2_2 * T.int64(4) + ax2_3_0 + ax2_3_1)
                                                            v3 = T.axis.reduce(T.int64(2048), ax3_0 * T.int64(8) + ax3_1)
                                                            T.reads(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2], reshape291_reindex_pad_shared[T.int64(0), v1, v3], dequantize_intermediate_reindex_shared[T.int64(0), v2, v3])
                                                            T.writes(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
                                                            NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] + reshape291_reindex_pad_shared[T.int64(0), v1, v3] * dequantize_intermediate_reindex_shared[T.int64(0), v2, v3]
                                            for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(4)):
                                                for ax2_1_1 in T.vectorized(T.int64(1)):
                                                    with T.block("NT_matmul_intermediate_reindex_pad_local"):
                                                        v0 = T.axis.spatial(T.int64(1), ax0)
                                                        v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
                                                        v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(32) + ax2_2 * T.int64(4) + ax2_0 + ax2_1_1)
                                                        T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < seq_len)
                                                        T.reads(NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                                        T.writes(NT_matmul_intermediate[T.int64(0), v1, v2])
                                                        NT_matmul_intermediate[T.int64(0), v1, v2] = NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]

    @T.prim_func(private=True)
    def fused_dequantize3_NT_matmul12(model_layers_0_mlp_gate_up_proj_q_weight2: T.Buffer((T.int64(22016), T.int64(256)), "uint32"), model_layers_0_mlp_gate_up_proj_q_scale2: T.Buffer((T.int64(22016), T.int64(64)), "float16"), rms_norm74: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), NT_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(22016)), "float16")):
        # Fused 4-bit weight dequantization + vector-matrix product (single input row).
        #
        # For every output column v0 in [0, 22016):
        #   NT_matmul_intermediate[0, 0, v0] = sum over k in [0, 2048) of rms_norm74[0, 0, k] * W[v0, k]
        # with
        #   W[v0, k] = (float16((q_weight[v0, k // 8] >> (k % 8 * 4)) & 15) - 7.0) * q_scale[v0, k // 32]
        # i.e. each uint32 word packs eight 4-bit fields, and one float16 scale covers 32 consecutive k.
        #
        # Parameters:
        #   model_layers_0_mlp_gate_up_proj_q_weight2: (22016, 256) uint32 packed weights; only read.
        #   model_layers_0_mlp_gate_up_proj_q_scale2:  (22016, 64) float16 scales; only read.
        #   rms_norm74:                                (1, 1, 2048) float16 input vector; only read.
        #   NT_matmul_intermediate:                    (1, 1, 22016) float16 output; every element is written.
        #
        # Launch shape: 5504 blocks (blockIdx.x) x 4 (threadIdx.y) x 32 (threadIdx.x).
        # Each block produces 4 output columns (5504 * 4 = 22016), one per threadIdx.y value.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        # 128 partial sums per output column (first-level reduction factor), local scope.
        NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(22016)), "float16", scope="local")
        # 32 partial sums per output column (one per threadIdx.x value), local scope.
        NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), T.int64(1), T.int64(22016)), "float16", scope="local")
        # Local-scope staging copy of the packed weight words consumed by each thread.
        model_layers_0_mlp_gate_up_proj_q_weight2_local = T.alloc_buffer((T.int64(22016), T.int64(256)), "uint32", scope="local")
        # Shared-scope copy of the full input vector.
        rms_norm74_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16", scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(T.int64(5504), thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                    # Stage 1: copy the whole 2048-element input vector into shared scope
                    # (4 serial steps x 4 threadIdx.y x 32 threadIdx.x x 4-wide vector = 2048 elements).
                    # The inner thread bindings reuse the same thread tags and extents as the enclosing loops.
                    for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
                        for ax2_0 in T.serial(T.int64(4), annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(T.int64(4)):
                                        with T.block("rms_norm74_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(T.int64(2048), ax2_0 * T.int64(512) + ax2_1 * T.int64(128) + ax2_2 * T.int64(4) + ax2_3)
                                            T.reads(rms_norm74[v0, v1, v2])
                                            T.writes(rms_norm74_shared[v0, v1, v2])
                                            rms_norm74_shared[v0, v1, v2] = rms_norm74[v0, v1, v2]
                    # Stage 2: zero the 4 partial-sum slots (threadIdx.x * 4 .. threadIdx.x * 4 + 3)
                    # that this thread accumulates for its output column.
                    for u_fused_ax0_fused_fused_2_init in range(T.int64(1)):
                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
                                v0 = T.axis.spatial(T.int64(22016), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0])
                                NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float16(0.0)
                    # Stage 3: walk the 2048-long reduction dimension in 8 chunks of 256 elements.
                    for ax1_0_fused_ax1_1_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        # Fetch the single packed uint32 word (column chunk * 32 + threadIdx.x) that this
                        # thread consumes in this chunk; it holds the 8 values used below.
                        for ax0_ax1_fused_0 in range(T.int64(1)):
                            for ax0_ax1_fused_1 in T.vectorized(T.int64(1)):
                                with T.block("model_layers_0_mlp_gate_up_proj_q_weight2_local"):
                                    v0 = T.axis.spatial(T.int64(22016), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1)
                                    v1 = T.axis.spatial(T.int64(256), ax1_0_fused_ax1_1_fused_0 * T.int64(32) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
                                    T.reads(model_layers_0_mlp_gate_up_proj_q_weight2[v0, v1])
                                    T.writes(model_layers_0_mlp_gate_up_proj_q_weight2_local[v0, v1])
                                    model_layers_0_mlp_gate_up_proj_q_weight2_local[v0, v1] = model_layers_0_mlp_gate_up_proj_q_weight2[v0, v1]
                        # Accumulate 2 x 4 dequantized products per thread. With slot = threadIdx.x * 4 + lane:
                        #   k = chunk * 256 + (slot // 4) * 8 + half * 4 + slot % 4
                        # so the 8 values of k are consecutive and all fall in the word fetched above
                        # (the "+ half // 2" term in the word index is always 0 because half is 0 or 1).
                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(T.int64(1), T.int64(2)):
                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
                                    v0 = T.axis.spatial(T.int64(22016), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
                                    T.reads(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0], rms_norm74_shared[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)], model_layers_0_mlp_gate_up_proj_q_weight2_local[v0, vax1_0_fused_ax1_1_fused_0 * T.int64(32) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) + vax1_0_fused_ax1_1_fused_2 // T.int64(2)], model_layers_0_mlp_gate_up_proj_q_scale2[v0, (vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)) // T.int64(32)])
                                    T.writes(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0])
                                    # partial[slot] += x[k] * ((float16((word >> (k % 8 * 4)) & 15) - 7.0) * scale[v0, k // 32])
                                    NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0] + rms_norm74_shared[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)] * ((T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_gate_up_proj_q_weight2_local[v0, vax1_0_fused_ax1_1_fused_0 * T.int64(32) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) + vax1_0_fused_ax1_1_fused_2 // T.int64(2)], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)) % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_gate_up_proj_q_scale2[v0, (vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)) // T.int64(32)])
            # Stage 4: fold the 128 partial sums down to 32 -- each (threadIdx.y, threadIdx.x) pair
            # sums its own 4 slots (ax0 * 4 .. ax0 * 4 + 3) into one entry of the second rf buffer.
            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                    for ax2_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_2_1 in T.vectorized(T.int64(1)):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                v0 = T.axis.spatial(T.int64(22016), u_fused_ax0_fused_fused_0 * T.int64(4) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                T.reads()
                                T.writes(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                                NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0.0)
                            for ax1 in range(T.int64(4)):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(T.int64(22016), u_fused_ax0_fused_fused_0 * T.int64(4) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                    T.reads(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0], NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, T.int64(0), T.int64(0), v0])
                                    T.writes(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                                    NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, T.int64(0), T.int64(0), v0]
            # Stage 5: final reduction of the 32 per-thread sums into the output element.
            # The reduce axis is the loop bound to threadIdx.x; T.init() zeroes the output first.
            for ax1_fused_2 in range(T.int64(1)):
                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                    for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                            v0 = T.axis.spatial(T.int64(22016), u_fused_ax0_fused_fused_0 * T.int64(4) + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2)
                            T.reads(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                            T.writes(NT_matmul_intermediate[T.int64(0), T.int64(0), v0])
                            with T.init():
                                NT_matmul_intermediate[T.int64(0), T.int64(0), v0] = T.float16(0.0)
                            NT_matmul_intermediate[T.int64(0), T.int64(0), v0] = NT_matmul_intermediate[T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0]

    @T.prim_func(private=True)
    def fused_dequantize3_NT_matmul2(model_layers_0_mlp_gate_up_proj_q_weight4: T.Buffer((T.int64(22016), T.int64(256)), "uint32"), model_layers_0_mlp_gate_up_proj_q_scale4: T.Buffer((T.int64(22016), T.int64(64)), "float16"), p_rms_norm220: T.handle, p_output0: T.handle):
        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int64()
        rms_norm220 = T.match_buffer(p_rms_norm220, (batch_size, T.int64(1), T.int64(2048)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(22016)), "float16")
        # with T.block("root"):
        if T.tvm_thread_invariant(batch_size <= T.int64(2)):
            with T.block("root"):
                T.reads()
                T.writes()
                dequantize_intermediate_local = T.alloc_buffer((T.int64(22016), T.int64(2048)), "float16", scope="local")
                NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(1), T.int64(22016)), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), (batch_size + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(1), T.int64(22016)), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(32), (batch_size + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(1), T.int64(22016)), "float16", scope="local")
                for ax0_0 in T.thread_binding((batch_size + T.int64(1)) // T.int64(2), thread="blockIdx.y"):
                    for ax1_fused_0 in T.thread_binding(T.int64(2752), thread="blockIdx.x"):
                        for ax1_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                            for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                for ax0_1_init, ax1_fused_2_init in T.grid(T.int64(2), T.int64(2)):
                                    for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                        with T.block("NT_matmul_rf_init"):
                                            vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1_init)
                                            v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax0_1_init)
                                            v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2_init)
                                            T.reads()
                                            T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1])
                                            NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] = T.float16(0.0)
                                for ax2_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
                                        for ax0_1 in T.vectorized(T.int64(1)):
                                            with T.block("dequantize"):
                                                v0 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
                                                v1 = T.axis.spatial(T.int64(2048), ax2_fused_0 * T.int64(256) + ax2_fused_1_ax2_fused_3_fused_0 * T.int64(8) + ax1)
                                                T.reads(model_layers_0_mlp_gate_up_proj_q_weight4[v0, v1 // T.int64(8)], model_layers_0_mlp_gate_up_proj_q_scale4[v0, v1 // T.int64(32)])
                                                T.writes(dequantize_intermediate_local[v0, v1])
                                                dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_gate_up_proj_q_weight4[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_gate_up_proj_q_scale4[v0, v1 // T.int64(32)]
                                    for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(T.int64(2), T.int64(2), T.int64(2)):
                                        for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                            with T.block("NT_matmul_rf_update"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1)
                                                v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax0_1)
                                                v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2)
                                                vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1], rms_norm220[v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] + T.if_then_else(v0 < batch_size, rms_norm220[v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)]
                        for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                            for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax2 in range(T.int64(2)):
                                        for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                                v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax2)
                                                v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                                NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] = T.float16(0.0)
                                            for ax1 in range(T.int64(4)):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                    v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax2)
                                                    v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                    T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, v0, T.int64(0), v1])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, v0, T.int64(0), v1]
                        for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(2)):
                            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    with T.block("NT_matmul"):
                                        vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                        v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax1)
                                        v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                        T.writes(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                                        with T.init():
                                            NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = T.float16(0.0)
                                        NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1]
                        for ax0 in range(T.int64(2)):
                            for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax1_fused_2 in range(T.int64(2)):
                                    with T.block("NT_matmul_intermediate_pad"):
                                        v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(2) + ax0)
                                        v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                        T.where((ax0_0 - (batch_size + T.int64(1)) // T.int64(2) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(2) + ax0 < batch_size)
                                        T.reads(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                                        T.writes(NT_matmul_intermediate[v0, T.int64(0), v1])
                                        NT_matmul_intermediate[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1]
        else:
            if T.tvm_thread_invariant(batch_size <= T.int64(8)):
                with T.block("root"):
                    T.reads()
                    T.writes()
                    dequantize_intermediate_local = T.alloc_buffer((T.int64(22016), T.int64(2048)), "float16", scope="local")
                    NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(22016)), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(22016)), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(32), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(22016)), "float16", scope="local")
                    for ax0_0 in T.thread_binding((batch_size + T.int64(3)) // T.int64(4), thread="blockIdx.y"):
                        for ax1_fused_0 in T.thread_binding(T.int64(2752), thread="blockIdx.x"):
                            for ax1_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax0_1_init, ax1_fused_2_init in T.grid(T.int64(4), T.int64(2)):
                                        for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1_init)
                                                v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init)
                                                v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2_init)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] = T.float16(0.0)
                                    for ax2_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
                                            for ax0_1 in T.vectorized(T.int64(1)):
                                                with T.block("dequantize"):
                                                    v0 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
                                                    v1 = T.axis.spatial(T.int64(2048), ax2_fused_0 * T.int64(256) + ax2_fused_1_ax2_fused_3_fused_0 * T.int64(8) + ax1)
                                                    T.reads(model_layers_0_mlp_gate_up_proj_q_weight4[v0, v1 // T.int64(8)], model_layers_0_mlp_gate_up_proj_q_scale4[v0, v1 // T.int64(32)])
                                                    T.writes(dequantize_intermediate_local[v0, v1])
                                                    dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_gate_up_proj_q_weight4[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_gate_up_proj_q_scale4[v0, v1 // T.int64(32)]
                                        for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(2)):
                                            for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1)
                                                    v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1)
                                                    v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2)
                                                    vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                    T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1], rms_norm220[v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1])
                                                    NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] + T.if_then_else(v0 < batch_size, rms_norm220[v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)]
                            for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax2 in range(T.int64(4)):
                                            for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                                with T.block("NT_matmul_rf_init"):
                                                    vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                                    v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
                                                    v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                    T.reads()
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] = T.float16(0.0)
                                                for ax1 in range(T.int64(4)):
                                                    with T.block("NT_matmul_rf_update"):
                                                        vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                        v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
                                                        v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, v0, T.int64(0), v1])
                                                        T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                                        NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, v0, T.int64(0), v1]
                            for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(4)):
                                for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                    for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                        with T.block("NT_matmul"):
                                            vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                            v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1)
                                            v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                            T.writes(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                                            with T.init():
                                                NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = T.float16(0.0)
                                            NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1]
                            for ax0 in range(T.int64(4)):
                                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                    for ax1_fused_2 in range(T.int64(2)):
                                        with T.block("NT_matmul_intermediate_pad"):
                                            v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0)
                                            v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                            T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
                                            T.reads(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                                            T.writes(NT_matmul_intermediate[v0, T.int64(0), v1])
                                            NT_matmul_intermediate[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1]
            else:
                with T.block("root"):
                    T.reads()
                    T.writes()
                    NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((T.int64(1), (batch_size + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(22016)), "float16", scope="local")
                    rms_norm220_reindex_pad_shared = T.alloc_buffer((T.int64(1), (batch_size + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2048)), "float16", scope="shared")
                    dequantize_intermediate_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(22016), T.int64(2048)), "float16", scope="shared")
                    for ax0_ax2_0_fused in T.thread_binding(T.int64(688), thread="blockIdx.y"):
                        for ax1_0 in T.thread_binding((batch_size + T.int64(31)) // T.int64(32), thread="blockIdx.x"):
                            for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
                                for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
                                    for ax2_2 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                        for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                            for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(4)):
                                                for ax2_3_1_init in T.vectorized(T.int64(1)):
                                                    with T.block("NT_matmul_init"):
                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                        v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
                                                        v2 = T.axis.spatial(T.int64(22016), ax0_ax2_0_fused * T.int64(32) + ax2_1 * T.int64(32) + ax2_2 * T.int64(4) + ax2_3_0_init + ax2_3_1_init)
                                                        T.reads()
                                                        T.writes(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
                                                        NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = T.float16(0.0)
                                            for ax3_0 in range(T.int64(256)):
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(1)):
                                                                with T.block("rms_norm220_reindex_pad_shared"):
                                                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                                    v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // T.int64(8))
                                                                    v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(8) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % T.int64(8))
                                                                    T.reads(rms_norm220[v1, T.int64(0), v2])
                                                                    T.writes(rms_norm220_reindex_pad_shared[v0, v1, v2])
                                                                    rms_norm220_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < batch_size, rms_norm220[v1, T.int64(0), v2], T.float16(0.0))
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(1)):
                                                                with T.block("dequantize_intermediate_reindex_shared"):
                                                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                                    v1 = T.axis.spatial(T.int64(22016), ax0_ax2_0_fused * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // T.int64(8))
                                                                    v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(8) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % T.int64(8))
                                                                    T.reads(model_layers_0_mlp_gate_up_proj_q_weight4[v1, v2 // T.int64(8)], model_layers_0_mlp_gate_up_proj_q_scale4[v1, v2 // T.int64(32)])
                                                                    T.writes(dequantize_intermediate_reindex_shared[v0, v1, v2])
                                                                    dequantize_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_gate_up_proj_q_weight4[v1, v2 // T.int64(8)], T.Cast("uint32", v2 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_gate_up_proj_q_scale4[v1, v2 // T.int64(32)]
                                                for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(8), T.int64(4), T.int64(4)):
                                                    for ax2_3_1 in T.vectorized(T.int64(1)):
                                                        with T.block("NT_matmul_update"):
                                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                            v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
                                                            v2 = T.axis.spatial(T.int64(22016), ax0_ax2_0_fused * T.int64(32) + ax2_1 * T.int64(32) + ax2_2 * T.int64(4) + ax2_3_0 + ax2_3_1)
                                                            v3 = T.axis.reduce(T.int64(2048), ax3_0 * T.int64(8) + ax3_1)
                                                            T.reads(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2], rms_norm220_reindex_pad_shared[T.int64(0), v1, v3], dequantize_intermediate_reindex_shared[T.int64(0), v2, v3])
                                                            T.writes(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
                                                            NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] + rms_norm220_reindex_pad_shared[T.int64(0), v1, v3] * dequantize_intermediate_reindex_shared[T.int64(0), v2, v3]
                                            for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(4)):
                                                for ax2_1_1 in T.vectorized(T.int64(1)):
                                                    with T.block("NT_matmul_intermediate_reindex_pad_local"):
                                                        v0 = T.axis.spatial(T.int64(1), ax0)
                                                        v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
                                                        v2 = T.axis.spatial(T.int64(22016), ax0_ax2_0_fused * T.int64(32) + ax2_2 * T.int64(4) + ax2_0 + ax2_1_1)
                                                        T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < batch_size)
                                                        T.reads(NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                                        T.writes(NT_matmul_intermediate[v1, T.int64(0), v2])
                                                        NT_matmul_intermediate[v1, T.int64(0), v2] = NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]

    @T.prim_func(private=True)
    def fused_dequantize3_NT_matmul7(model_layers_0_mlp_gate_up_proj_q_weight3: T.Buffer((T.int64(22016), T.int64(256)), "uint32"), model_layers_0_mlp_gate_up_proj_q_scale3: T.Buffer((T.int64(22016), T.int64(64)), "float16"), p_rms_norm147: T.handle, p_output0: T.handle):
        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int64()
        rms_norm147 = T.match_buffer(p_rms_norm147, (T.int64(1), seq_len, T.int64(2048)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(22016)), "float16")
        # with T.block("root"):
        if T.tvm_thread_invariant(seq_len <= T.int64(2)):
            with T.block("root"):
                T.reads()
                T.writes()
                dequantize_intermediate_local = T.alloc_buffer((T.int64(22016), T.int64(2048)), "float16", scope="local")
                NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(22016)), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), (seq_len + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(22016)), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), (seq_len + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(22016)), "float16", scope="local")
                for ax0_0 in T.thread_binding((seq_len + T.int64(1)) // T.int64(2), thread="blockIdx.y"):
                    for ax1_fused_0 in T.thread_binding(T.int64(2752), thread="blockIdx.x"):
                        for ax1_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                            for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                for ax0_1_init, ax1_fused_2_init in T.grid(T.int64(2), T.int64(2)):
                                    for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                        with T.block("NT_matmul_rf_init"):
                                            vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1_init)
                                            v0 = T.axis.spatial((seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax0_1_init)
                                            v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2_init)
                                            T.reads()
                                            T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1])
                                            NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] = T.float16(0.0)
                                for ax2_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
                                        for ax0_1 in T.vectorized(T.int64(1)):
                                            with T.block("dequantize"):
                                                v0 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
                                                v1 = T.axis.spatial(T.int64(2048), ax2_fused_0 * T.int64(256) + ax2_fused_1_ax2_fused_3_fused_0 * T.int64(8) + ax1)
                                                T.reads(model_layers_0_mlp_gate_up_proj_q_weight3[v0, v1 // T.int64(8)], model_layers_0_mlp_gate_up_proj_q_scale3[v0, v1 // T.int64(32)])
                                                T.writes(dequantize_intermediate_local[v0, v1])
                                                dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_gate_up_proj_q_weight3[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_gate_up_proj_q_scale3[v0, v1 // T.int64(32)]
                                    for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(T.int64(2), T.int64(2), T.int64(2)):
                                        for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                            with T.block("NT_matmul_rf_update"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1)
                                                v0 = T.axis.spatial((seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax0_1)
                                                v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2)
                                                vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1], rms_norm147[T.int64(0), v0, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] + T.if_then_else(v0 < seq_len, rms_norm147[T.int64(0), v0, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)]
                        for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                            for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax2 in range(T.int64(2)):
                                        for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                                v0 = T.axis.spatial((seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax2)
                                                v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                                NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] = T.float16(0.0)
                                            for ax1 in range(T.int64(4)):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                    v0 = T.axis.spatial((seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax2)
                                                    v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                    T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, T.int64(0), v0, v1])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, T.int64(0), v0, v1]
                        for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(2)):
                            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    with T.block("NT_matmul"):
                                        vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                        v0 = T.axis.spatial((seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax1)
                                        v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                        T.writes(NT_matmul_intermediate_pad_local[T.int64(0), v0, v1])
                                        with T.init():
                                            NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] = T.float16(0.0)
                                        NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] = NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1]
                        for ax0 in range(T.int64(2)):
                            for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax1_fused_2 in range(T.int64(2)):
                                    with T.block("NT_matmul_intermediate_pad"):
                                        v0 = T.axis.spatial(seq_len, ax0_0 * T.int64(2) + ax0)
                                        v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                        T.where((ax0_0 - (seq_len + T.int64(1)) // T.int64(2) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(2) + ax0 < seq_len)
                                        T.reads(NT_matmul_intermediate_pad_local[T.int64(0), v0, v1])
                                        T.writes(NT_matmul_intermediate[T.int64(0), v0, v1])
                                        NT_matmul_intermediate[T.int64(0), v0, v1] = NT_matmul_intermediate_pad_local[T.int64(0), v0, v1]
        else:
            if T.tvm_thread_invariant(seq_len <= T.int64(8)):
                with T.block("root"):
                    T.reads()
                    T.writes()
                    dequantize_intermediate_local = T.alloc_buffer((T.int64(22016), T.int64(2048)), "float16", scope="local")
                    NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(22016)), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), (seq_len + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(22016)), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), (seq_len + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(22016)), "float16", scope="local")
                    for ax0_0 in T.thread_binding((seq_len + T.int64(3)) // T.int64(4), thread="blockIdx.y"):
                        for ax1_fused_0 in T.thread_binding(T.int64(2752), thread="blockIdx.x"):
                            for ax1_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax0_1_init, ax1_fused_2_init in T.grid(T.int64(4), T.int64(2)):
                                        for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1_init)
                                                v0 = T.axis.spatial((seq_len + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init)
                                                v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2_init)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] = T.float16(0.0)
                                    for ax2_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
                                            for ax0_1 in T.vectorized(T.int64(1)):
                                                with T.block("dequantize"):
                                                    v0 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
                                                    v1 = T.axis.spatial(T.int64(2048), ax2_fused_0 * T.int64(256) + ax2_fused_1_ax2_fused_3_fused_0 * T.int64(8) + ax1)
                                                    T.reads(model_layers_0_mlp_gate_up_proj_q_weight3[v0, v1 // T.int64(8)], model_layers_0_mlp_gate_up_proj_q_scale3[v0, v1 // T.int64(32)])
                                                    T.writes(dequantize_intermediate_local[v0, v1])
                                                    dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_gate_up_proj_q_weight3[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_gate_up_proj_q_scale3[v0, v1 // T.int64(32)]
                                        for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(2)):
                                            for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1)
                                                    v0 = T.axis.spatial((seq_len + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1)
                                                    v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2)
                                                    vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                    T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1], rms_norm147[T.int64(0), v0, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] + T.if_then_else(v0 < seq_len, rms_norm147[T.int64(0), v0, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)]
                            for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax2 in range(T.int64(4)):
                                            for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                                with T.block("NT_matmul_rf_init"):
                                                    vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                                    v0 = T.axis.spatial((seq_len + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
                                                    v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                    T.reads()
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] = T.float16(0.0)
                                                for ax1 in range(T.int64(4)):
                                                    with T.block("NT_matmul_rf_update"):
                                                        vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                        v0 = T.axis.spatial((seq_len + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
                                                        v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, T.int64(0), v0, v1])
                                                        T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                                        NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, T.int64(0), v0, v1]
                            for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(4)):
                                for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                    for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                        with T.block("NT_matmul"):
                                            vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                            v0 = T.axis.spatial((seq_len + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1)
                                            v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                            T.writes(NT_matmul_intermediate_pad_local[T.int64(0), v0, v1])
                                            with T.init():
                                                NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] = T.float16(0.0)
                                            NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] = NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1]
                            for ax0 in range(T.int64(4)):
                                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                    for ax1_fused_2 in range(T.int64(2)):
                                        with T.block("NT_matmul_intermediate_pad"):
                                            v0 = T.axis.spatial(seq_len, ax0_0 * T.int64(4) + ax0)
                                            v1 = T.axis.spatial(T.int64(22016), ax1_fused_0 * T.int64(8) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                            T.where((ax0_0 - (seq_len + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < seq_len)
                                            T.reads(NT_matmul_intermediate_pad_local[T.int64(0), v0, v1])
                                            T.writes(NT_matmul_intermediate[T.int64(0), v0, v1])
                                            NT_matmul_intermediate[T.int64(0), v0, v1] = NT_matmul_intermediate_pad_local[T.int64(0), v0, v1]
            else:
                with T.block("root"):
                    T.reads()
                    T.writes()
                    NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(22016)), "float16", scope="local")
                    rms_norm147_reindex_pad_shared = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2048)), "float16", scope="shared")
                    dequantize_intermediate_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(22016), T.int64(2048)), "float16", scope="shared")
                    for ax0_ax2_0_fused in T.thread_binding(T.int64(688), thread="blockIdx.y"):
                        for ax1_0 in T.thread_binding((seq_len + T.int64(31)) // T.int64(32), thread="blockIdx.x"):
                            for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
                                for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
                                    for ax2_2 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                        for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                            for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(4)):
                                                for ax2_3_1_init in T.vectorized(T.int64(1)):
                                                    with T.block("NT_matmul_init"):
                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                        v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
                                                        v2 = T.axis.spatial(T.int64(22016), ax0_ax2_0_fused * T.int64(32) + ax2_1 * T.int64(32) + ax2_2 * T.int64(4) + ax2_3_0_init + ax2_3_1_init)
                                                        T.reads()
                                                        T.writes(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
                                                        NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = T.float16(0.0)
                                            for ax3_0 in range(T.int64(256)):
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(1)):
                                                                with T.block("rms_norm147_reindex_pad_shared"):
                                                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                                    v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // T.int64(8))
                                                                    v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(8) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % T.int64(8))
                                                                    T.reads(rms_norm147[v0, v1, v2])
                                                                    T.writes(rms_norm147_reindex_pad_shared[v0, v1, v2])
                                                                    rms_norm147_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, rms_norm147[v0, v1, v2], T.float16(0.0))
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(1)):
                                                                with T.block("dequantize_intermediate_reindex_shared"):
                                                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                                    v1 = T.axis.spatial(T.int64(22016), ax0_ax2_0_fused * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // T.int64(8))
                                                                    v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(8) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % T.int64(8))
                                                                    T.reads(model_layers_0_mlp_gate_up_proj_q_weight3[v1, v2 // T.int64(8)], model_layers_0_mlp_gate_up_proj_q_scale3[v1, v2 // T.int64(32)])
                                                                    T.writes(dequantize_intermediate_reindex_shared[v0, v1, v2])
                                                                    dequantize_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_gate_up_proj_q_weight3[v1, v2 // T.int64(8)], T.Cast("uint32", v2 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_gate_up_proj_q_scale3[v1, v2 // T.int64(32)]
                                                for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(8), T.int64(4), T.int64(4)):
                                                    for ax2_3_1 in T.vectorized(T.int64(1)):
                                                        with T.block("NT_matmul_update"):
                                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                            v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
                                                            v2 = T.axis.spatial(T.int64(22016), ax0_ax2_0_fused * T.int64(32) + ax2_1 * T.int64(32) + ax2_2 * T.int64(4) + ax2_3_0 + ax2_3_1)
                                                            v3 = T.axis.reduce(T.int64(2048), ax3_0 * T.int64(8) + ax3_1)
                                                            T.reads(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2], rms_norm147_reindex_pad_shared[T.int64(0), v1, v3], dequantize_intermediate_reindex_shared[T.int64(0), v2, v3])
                                                            T.writes(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
                                                            NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] + rms_norm147_reindex_pad_shared[T.int64(0), v1, v3] * dequantize_intermediate_reindex_shared[T.int64(0), v2, v3]
                                            for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(4)):
                                                for ax2_1_1 in T.vectorized(T.int64(1)):
                                                    with T.block("NT_matmul_intermediate_reindex_pad_local"):
                                                        v0 = T.axis.spatial(T.int64(1), ax0)
                                                        v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
                                                        v2 = T.axis.spatial(T.int64(22016), ax0_ax2_0_fused * T.int64(32) + ax2_2 * T.int64(4) + ax2_0 + ax2_1_1)
                                                        T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < seq_len)
                                                        T.reads(NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                                        T.writes(NT_matmul_intermediate[T.int64(0), v1, v2])
                                                        NT_matmul_intermediate[T.int64(0), v1, v2] = NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]

    @T.prim_func(private=True)
    def fused_dequantize4_NT_matmul13(model_layers_0_mlp_down_proj_q_weight2: T.Buffer((T.int64(2048), T.int64(1376)), "uint32"), model_layers_0_mlp_down_proj_q_scale2: T.Buffer((T.int64(2048), T.int64(344)), "float16"), lv219: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), NT_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(2048)), "float16", scope="local")
        NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), T.int64(1), T.int64(2048)), "float16", scope="local")
        model_layers_0_mlp_down_proj_q_weight2_local = T.alloc_buffer((T.int64(2048), T.int64(1376)), "uint32", scope="local")
        lv219_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16", scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                    for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
                        for ax2_0 in T.serial(T.int64(43), annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(T.int64(1)):
                                        with T.block("lv219_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(T.int64(11008), ax2_0 * T.int64(256) + ax2_1 * T.int64(32) + ax2_2 + ax2_3)
                                            T.reads(lv219[v0, v1, v2])
                                            T.writes(lv219_shared[v0, v1, v2])
                                            lv219_shared[v0, v1, v2] = lv219[v0, v1, v2]
                    for u_fused_ax0_fused_fused_2_init in range(T.int64(1)):
                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
                                v0 = T.axis.spatial(T.int64(2048), u_fused_ax0_fused_fused_0 * T.int64(8) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0])
                                NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float16(0.0)
                    for ax1_0_fused_ax1_1_fused_0 in T.serial(T.int64(43), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax0_ax1_fused_0 in range(T.int64(1)):
                            for ax0_ax1_fused_1 in T.vectorized(T.int64(1)):
                                with T.block("model_layers_0_mlp_down_proj_q_weight2_local"):
                                    v0 = T.axis.spatial(T.int64(2048), u_fused_ax0_fused_fused_0 * T.int64(8) + u_fused_ax0_fused_fused_1)
                                    v1 = T.axis.spatial(T.int64(1376), ax1_0_fused_ax1_1_fused_0 * T.int64(32) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
                                    T.reads(model_layers_0_mlp_down_proj_q_weight2[v0, v1])
                                    T.writes(model_layers_0_mlp_down_proj_q_weight2_local[v0, v1])
                                    model_layers_0_mlp_down_proj_q_weight2_local[v0, v1] = model_layers_0_mlp_down_proj_q_weight2[v0, v1]
                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(T.int64(1), T.int64(2)):
                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
                                    v0 = T.axis.spatial(T.int64(2048), u_fused_ax0_fused_fused_0 * T.int64(8) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
                                    T.reads(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0], lv219_shared[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)], model_layers_0_mlp_down_proj_q_weight2_local[v0, vax1_0_fused_ax1_1_fused_0 * T.int64(32) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) + vax1_0_fused_ax1_1_fused_2 // T.int64(2)], model_layers_0_mlp_down_proj_q_scale2[v0, (vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)) // T.int64(32)])
                                    T.writes(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0])
                                    NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0] + lv219_shared[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)] * ((T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_down_proj_q_weight2_local[v0, vax1_0_fused_ax1_1_fused_0 * T.int64(32) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) + vax1_0_fused_ax1_1_fused_2 // T.int64(2)], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)) % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_down_proj_q_scale2[v0, (vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)) // T.int64(32)])
            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                    for ax2_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_2_1 in T.vectorized(T.int64(1)):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                v0 = T.axis.spatial(T.int64(2048), u_fused_ax0_fused_fused_0 * T.int64(8) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                T.reads()
                                T.writes(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                                NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0.0)
                            for ax1 in range(T.int64(4)):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(T.int64(2048), u_fused_ax0_fused_fused_0 * T.int64(8) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                    T.reads(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0], NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, T.int64(0), T.int64(0), v0])
                                    T.writes(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                                    NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, T.int64(0), T.int64(0), v0]
            for ax1_fused_2 in range(T.int64(1)):
                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                    for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                            v0 = T.axis.spatial(T.int64(2048), u_fused_ax0_fused_fused_0 * T.int64(8) + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2)
                            T.reads(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                            T.writes(NT_matmul_intermediate[T.int64(0), T.int64(0), v0])
                            with T.init():
                                NT_matmul_intermediate[T.int64(0), T.int64(0), v0] = T.float16(0.0)
                            NT_matmul_intermediate[T.int64(0), T.int64(0), v0] = NT_matmul_intermediate[T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0]

    @T.prim_func(private=True)
    def fused_dequantize4_NT_matmul3(model_layers_0_mlp_down_proj_q_weight4: T.Buffer((T.int64(2048), T.int64(1376)), "uint32"), model_layers_0_mlp_down_proj_q_scale4: T.Buffer((T.int64(2048), T.int64(344)), "float16"), p_lv1: T.handle, p_output0: T.handle):
        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int64()
        lv1 = T.match_buffer(p_lv1, (batch_size, T.int64(1), T.int64(11008)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(2048)), "float16")
        # with T.block("root"):
        if T.tvm_thread_invariant(batch_size <= T.int64(2)):
            with T.block("root"):
                T.reads()
                T.writes()
                dequantize_intermediate_local = T.alloc_buffer((T.int64(2048), T.int64(11008)), "float16", scope="local")
                NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(1), T.int64(2048)), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), (batch_size + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(1), T.int64(2048)), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(32), (batch_size + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(1), T.int64(2048)), "float16", scope="local")
                for ax0_0 in T.thread_binding((batch_size + T.int64(1)) // T.int64(2), thread="blockIdx.y"):
                    for ax1_fused_0 in T.thread_binding(T.int64(128), thread="blockIdx.x"):
                        for ax1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                            for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                for ax0_1_init, ax1_fused_2_init in T.grid(T.int64(2), T.int64(2)):
                                    for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                        with T.block("NT_matmul_rf_init"):
                                            vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1_init)
                                            v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax0_1_init)
                                            v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax1_fused_2_init)
                                            T.reads()
                                            T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1])
                                            NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] = T.float16(0.0)
                                for ax2_fused_0 in T.serial(T.int64(43), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
                                        for ax0_1 in T.vectorized(T.int64(1)):
                                            with T.block("dequantize"):
                                                v0 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
                                                v1 = T.axis.spatial(T.int64(11008), ax2_fused_0 * T.int64(256) + ax2_fused_1_ax2_fused_3_fused_0 * T.int64(8) + ax1)
                                                T.reads(model_layers_0_mlp_down_proj_q_weight4[v0, v1 // T.int64(8)], model_layers_0_mlp_down_proj_q_scale4[v0, v1 // T.int64(32)])
                                                T.writes(dequantize_intermediate_local[v0, v1])
                                                dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_down_proj_q_weight4[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_down_proj_q_scale4[v0, v1 // T.int64(32)]
                                    for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(T.int64(2), T.int64(2), T.int64(2)):
                                        for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                            with T.block("NT_matmul_rf_update"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1)
                                                v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax0_1)
                                                v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax1_fused_2)
                                                vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1], lv1[v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] + T.if_then_else(v0 < batch_size, lv1[v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)]
                        for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                            for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax2 in range(T.int64(2)):
                                        for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                                v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax2)
                                                v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                                NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] = T.float16(0.0)
                                            for ax1 in range(T.int64(4)):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                    v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax2)
                                                    v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                    T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, v0, T.int64(0), v1])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, v0, T.int64(0), v1]
                        for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(2)):
                            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    with T.block("NT_matmul"):
                                        vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                        v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax1)
                                        v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                        T.writes(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                                        with T.init():
                                            NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = T.float16(0.0)
                                        NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1]
                        for ax0 in range(T.int64(2)):
                            for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                for ax1_fused_2 in range(T.int64(2)):
                                    with T.block("NT_matmul_intermediate_pad"):
                                        v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(2) + ax0)
                                        v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                        T.where((ax0_0 - (batch_size + T.int64(1)) // T.int64(2) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(2) + ax0 < batch_size)
                                        T.reads(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                                        T.writes(NT_matmul_intermediate[v0, T.int64(0), v1])
                                        NT_matmul_intermediate[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1]
        else:
            if T.tvm_thread_invariant(batch_size <= T.int64(8)):
                with T.block("root"):
                    T.reads()
                    T.writes()
                    dequantize_intermediate_local = T.alloc_buffer((T.int64(2048), T.int64(11008)), "float16", scope="local")
                    NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(2048)), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(2048)), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(32), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(2048)), "float16", scope="local")
                    for ax0_0 in T.thread_binding((batch_size + T.int64(3)) // T.int64(4), thread="blockIdx.y"):
                        for ax1_fused_0 in T.thread_binding(T.int64(128), thread="blockIdx.x"):
                            for ax1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax0_1_init, ax1_fused_2_init in T.grid(T.int64(4), T.int64(2)):
                                        for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1_init)
                                                v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init)
                                                v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax1_fused_2_init)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] = T.float16(0.0)
                                    for ax2_fused_0 in T.serial(T.int64(43), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
                                            for ax0_1 in T.vectorized(T.int64(1)):
                                                with T.block("dequantize"):
                                                    v0 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
                                                    v1 = T.axis.spatial(T.int64(11008), ax2_fused_0 * T.int64(256) + ax2_fused_1_ax2_fused_3_fused_0 * T.int64(8) + ax1)
                                                    T.reads(model_layers_0_mlp_down_proj_q_weight4[v0, v1 // T.int64(8)], model_layers_0_mlp_down_proj_q_scale4[v0, v1 // T.int64(32)])
                                                    T.writes(dequantize_intermediate_local[v0, v1])
                                                    dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_down_proj_q_weight4[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_down_proj_q_scale4[v0, v1 // T.int64(32)]
                                        for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(2)):
                                            for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1)
                                                    v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1)
                                                    v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax1_fused_2)
                                                    vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                    T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1], lv1[v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1])
                                                    NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] + T.if_then_else(v0 < batch_size, lv1[v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)]
                            for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax2 in range(T.int64(4)):
                                            for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                                with T.block("NT_matmul_rf_init"):
                                                    vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                                    v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
                                                    v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                    T.reads()
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] = T.float16(0.0)
                                                for ax1 in range(T.int64(4)):
                                                    with T.block("NT_matmul_rf_update"):
                                                        vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                        v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
                                                        v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, v0, T.int64(0), v1])
                                                        T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                                        NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, v0, T.int64(0), v1]
                            for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(4)):
                                for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                    for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                        with T.block("NT_matmul"):
                                            vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                            v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1)
                                            v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                            T.writes(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                                            with T.init():
                                                NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = T.float16(0.0)
                                            NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1]
                            for ax0 in range(T.int64(4)):
                                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                    for ax1_fused_2 in range(T.int64(2)):
                                        with T.block("NT_matmul_intermediate_pad"):
                                            v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0)
                                            v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                            T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
                                            T.reads(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                                            T.writes(NT_matmul_intermediate[v0, T.int64(0), v1])
                                            NT_matmul_intermediate[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1]
            else:
                with T.block("root"):
                    T.reads()
                    T.writes()
                    NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((T.int64(1), (batch_size + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2048)), "float16", scope="local")
                    lv1_reindex_pad_shared = T.alloc_buffer((T.int64(1), (batch_size + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(11008)), "float16", scope="shared")
                    dequantize_intermediate_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(2048), T.int64(11008)), "float16", scope="shared")
                    for ax0_ax2_0_fused in T.thread_binding(T.int64(64), thread="blockIdx.y"):
                        for ax1_0 in T.thread_binding((batch_size + T.int64(31)) // T.int64(32), thread="blockIdx.x"):
                            for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
                                for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
                                    for ax2_2 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                        for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                            for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(4)):
                                                for ax2_3_1_init in T.vectorized(T.int64(1)):
                                                    with T.block("NT_matmul_init"):
                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                        v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
                                                        v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(32) + ax2_1 * T.int64(32) + ax2_2 * T.int64(4) + ax2_3_0_init + ax2_3_1_init)
                                                        T.reads()
                                                        T.writes(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
                                                        NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = T.float16(0.0)
                                            for ax3_0 in range(T.int64(1376)):
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(1)):
                                                                with T.block("lv1_reindex_pad_shared"):
                                                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                                    v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // T.int64(8))
                                                                    v2 = T.axis.spatial(T.int64(11008), ax3_0 * T.int64(8) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % T.int64(8))
                                                                    T.reads(lv1[v1, T.int64(0), v2])
                                                                    T.writes(lv1_reindex_pad_shared[v0, v1, v2])
                                                                    lv1_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < batch_size, lv1[v1, T.int64(0), v2], T.float16(0.0))
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(1)):
                                                                with T.block("dequantize_intermediate_reindex_shared"):
                                                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                                    v1 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // T.int64(8))
                                                                    v2 = T.axis.spatial(T.int64(11008), ax3_0 * T.int64(8) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % T.int64(8))
                                                                    T.reads(model_layers_0_mlp_down_proj_q_weight4[v1, v2 // T.int64(8)], model_layers_0_mlp_down_proj_q_scale4[v1, v2 // T.int64(32)])
                                                                    T.writes(dequantize_intermediate_reindex_shared[v0, v1, v2])
                                                                    dequantize_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_down_proj_q_weight4[v1, v2 // T.int64(8)], T.Cast("uint32", v2 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_down_proj_q_scale4[v1, v2 // T.int64(32)]
                                                for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(8), T.int64(4), T.int64(4)):
                                                    for ax2_3_1 in T.vectorized(T.int64(1)):
                                                        with T.block("NT_matmul_update"):
                                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                            v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
                                                            v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(32) + ax2_1 * T.int64(32) + ax2_2 * T.int64(4) + ax2_3_0 + ax2_3_1)
                                                            v3 = T.axis.reduce(T.int64(11008), ax3_0 * T.int64(8) + ax3_1)
                                                            T.reads(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2], lv1_reindex_pad_shared[T.int64(0), v1, v3], dequantize_intermediate_reindex_shared[T.int64(0), v2, v3])
                                                            T.writes(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
                                                            NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] + lv1_reindex_pad_shared[T.int64(0), v1, v3] * dequantize_intermediate_reindex_shared[T.int64(0), v2, v3]
                                            for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(4)):
                                                for ax2_1_1 in T.vectorized(T.int64(1)):
                                                    with T.block("NT_matmul_intermediate_reindex_pad_local"):
                                                        v0 = T.axis.spatial(T.int64(1), ax0)
                                                        v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
                                                        v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(32) + ax2_2 * T.int64(4) + ax2_0 + ax2_1_1)
                                                        T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < batch_size)
                                                        T.reads(NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                                        T.writes(NT_matmul_intermediate[v1, T.int64(0), v2])
                                                        NT_matmul_intermediate[v1, T.int64(0), v2] = NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]

    @T.prim_func(private=True)
    def fused_dequantize4_NT_matmul8(model_layers_0_mlp_down_proj_q_weight3: T.Buffer((T.int64(2048), T.int64(1376)), "uint32"), model_layers_0_mlp_down_proj_q_scale3: T.Buffer((T.int64(2048), T.int64(344)), "float16"), p_lv73: T.handle, p_output0: T.handle):
        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int64()
        lv73 = T.match_buffer(p_lv73, (T.int64(1), seq_len, T.int64(11008)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(2048)), "float16")
        # with T.block("root"):
        if T.tvm_thread_invariant(seq_len <= T.int64(2)):
            with T.block("root"):
                T.reads()
                T.writes()
                dequantize_intermediate_local = T.alloc_buffer((T.int64(2048), T.int64(11008)), "float16", scope="local")
                NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(2048)), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), (seq_len + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(2048)), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), (seq_len + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(2048)), "float16", scope="local")
                for ax0_0 in T.thread_binding((seq_len + T.int64(1)) // T.int64(2), thread="blockIdx.y"):
                    for ax1_fused_0 in T.thread_binding(T.int64(128), thread="blockIdx.x"):
                        for ax1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                            for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                for ax0_1_init, ax1_fused_2_init in T.grid(T.int64(2), T.int64(2)):
                                    for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                        with T.block("NT_matmul_rf_init"):
                                            vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1_init)
                                            v0 = T.axis.spatial((seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax0_1_init)
                                            v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax1_fused_2_init)
                                            T.reads()
                                            T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1])
                                            NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] = T.float16(0.0)
                                for ax2_fused_0 in T.serial(T.int64(43), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
                                        for ax0_1 in T.vectorized(T.int64(1)):
                                            with T.block("dequantize"):
                                                v0 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
                                                v1 = T.axis.spatial(T.int64(11008), ax2_fused_0 * T.int64(256) + ax2_fused_1_ax2_fused_3_fused_0 * T.int64(8) + ax1)
                                                T.reads(model_layers_0_mlp_down_proj_q_weight3[v0, v1 // T.int64(8)], model_layers_0_mlp_down_proj_q_scale3[v0, v1 // T.int64(32)])
                                                T.writes(dequantize_intermediate_local[v0, v1])
                                                dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_down_proj_q_weight3[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_down_proj_q_scale3[v0, v1 // T.int64(32)]
                                    for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(T.int64(2), T.int64(2), T.int64(2)):
                                        for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                            with T.block("NT_matmul_rf_update"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1)
                                                v0 = T.axis.spatial((seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax0_1)
                                                v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax1_fused_2)
                                                vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1], lv73[T.int64(0), v0, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] + T.if_then_else(v0 < seq_len, lv73[T.int64(0), v0, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)]
                        for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                            for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax2 in range(T.int64(2)):
                                        for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                                v0 = T.axis.spatial((seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax2)
                                                v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                                NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] = T.float16(0.0)
                                            for ax1 in range(T.int64(4)):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                    v0 = T.axis.spatial((seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax2)
                                                    v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                    T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, T.int64(0), v0, v1])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, T.int64(0), v0, v1]
                        for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(2)):
                            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    with T.block("NT_matmul"):
                                        vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                        v0 = T.axis.spatial((seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax1)
                                        v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                        T.writes(NT_matmul_intermediate_pad_local[T.int64(0), v0, v1])
                                        with T.init():
                                            NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] = T.float16(0.0)
                                        NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] = NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1]
                        for ax0 in range(T.int64(2)):
                            for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                for ax1_fused_2 in range(T.int64(2)):
                                    with T.block("NT_matmul_intermediate_pad"):
                                        v0 = T.axis.spatial(seq_len, ax0_0 * T.int64(2) + ax0)
                                        v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                        T.where((ax0_0 - (seq_len + T.int64(1)) // T.int64(2) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(2) + ax0 < seq_len)
                                        T.reads(NT_matmul_intermediate_pad_local[T.int64(0), v0, v1])
                                        T.writes(NT_matmul_intermediate[T.int64(0), v0, v1])
                                        NT_matmul_intermediate[T.int64(0), v0, v1] = NT_matmul_intermediate_pad_local[T.int64(0), v0, v1]
        else:
            if T.tvm_thread_invariant(seq_len <= T.int64(8)):
                with T.block("root"):
                    T.reads()
                    T.writes()
                    dequantize_intermediate_local = T.alloc_buffer((T.int64(2048), T.int64(11008)), "float16", scope="local")
                    NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(2048)), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), (seq_len + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(2048)), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), (seq_len + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(2048)), "float16", scope="local")
                    for ax0_0 in T.thread_binding((seq_len + T.int64(3)) // T.int64(4), thread="blockIdx.y"):
                        for ax1_fused_0 in T.thread_binding(T.int64(128), thread="blockIdx.x"):
                            for ax1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax0_1_init, ax1_fused_2_init in T.grid(T.int64(4), T.int64(2)):
                                        for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1_init)
                                                v0 = T.axis.spatial((seq_len + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init)
                                                v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax1_fused_2_init)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] = T.float16(0.0)
                                    for ax2_fused_0 in T.serial(T.int64(43), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
                                            for ax0_1 in T.vectorized(T.int64(1)):
                                                with T.block("dequantize"):
                                                    v0 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
                                                    v1 = T.axis.spatial(T.int64(11008), ax2_fused_0 * T.int64(256) + ax2_fused_1_ax2_fused_3_fused_0 * T.int64(8) + ax1)
                                                    T.reads(model_layers_0_mlp_down_proj_q_weight3[v0, v1 // T.int64(8)], model_layers_0_mlp_down_proj_q_scale3[v0, v1 // T.int64(32)])
                                                    T.writes(dequantize_intermediate_local[v0, v1])
                                                    dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_down_proj_q_weight3[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_down_proj_q_scale3[v0, v1 // T.int64(32)]
                                        for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(2)):
                                            for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1)
                                                    v0 = T.axis.spatial((seq_len + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1)
                                                    v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_1 * T.int64(2) + ax1_fused_2)
                                                    vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                    T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1], lv73[T.int64(0), v0, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] + T.if_then_else(v0 < seq_len, lv73[T.int64(0), v0, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)]
                            for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax2 in range(T.int64(4)):
                                            for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                                with T.block("NT_matmul_rf_init"):
                                                    vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                                    v0 = T.axis.spatial((seq_len + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
                                                    v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                    T.reads()
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] = T.float16(0.0)
                                                for ax1 in range(T.int64(4)):
                                                    with T.block("NT_matmul_rf_update"):
                                                        vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                        v0 = T.axis.spatial((seq_len + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
                                                        v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, T.int64(0), v0, v1])
                                                        T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                                        NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, T.int64(0), v0, v1]
                            for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(4)):
                                for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                    for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                        with T.block("NT_matmul"):
                                            vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                            v0 = T.axis.spatial((seq_len + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1)
                                            v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                            T.writes(NT_matmul_intermediate_pad_local[T.int64(0), v0, v1])
                                            with T.init():
                                                NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] = T.float16(0.0)
                                            NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] = NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1]
                            for ax0 in range(T.int64(4)):
                                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                    for ax1_fused_2 in range(T.int64(2)):
                                        with T.block("NT_matmul_intermediate_pad"):
                                            v0 = T.axis.spatial(seq_len, ax0_0 * T.int64(4) + ax0)
                                            v1 = T.axis.spatial(T.int64(2048), ax1_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                            T.where((ax0_0 - (seq_len + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < seq_len)
                                            T.reads(NT_matmul_intermediate_pad_local[T.int64(0), v0, v1])
                                            T.writes(NT_matmul_intermediate[T.int64(0), v0, v1])
                                            NT_matmul_intermediate[T.int64(0), v0, v1] = NT_matmul_intermediate_pad_local[T.int64(0), v0, v1]
            else:
                with T.block("root"):
                    T.reads()
                    T.writes()
                    NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2048)), "float16", scope="local")
                    lv73_reindex_pad_shared = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(11008)), "float16", scope="shared")
                    dequantize_intermediate_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(2048), T.int64(11008)), "float16", scope="shared")
                    for ax0_ax2_0_fused in T.thread_binding(T.int64(64), thread="blockIdx.y"):
                        for ax1_0 in T.thread_binding((seq_len + T.int64(31)) // T.int64(32), thread="blockIdx.x"):
                            for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
                                for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
                                    for ax2_2 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                        for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                            for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(4)):
                                                for ax2_3_1_init in T.vectorized(T.int64(1)):
                                                    with T.block("NT_matmul_init"):
                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                        v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
                                                        v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(32) + ax2_1 * T.int64(32) + ax2_2 * T.int64(4) + ax2_3_0_init + ax2_3_1_init)
                                                        T.reads()
                                                        T.writes(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
                                                        NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = T.float16(0.0)
                                            for ax3_0 in range(T.int64(1376)):
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(1)):
                                                                with T.block("lv73_reindex_pad_shared"):
                                                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                                    v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // T.int64(8))
                                                                    v2 = T.axis.spatial(T.int64(11008), ax3_0 * T.int64(8) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % T.int64(8))
                                                                    T.reads(lv73[v0, v1, v2])
                                                                    T.writes(lv73_reindex_pad_shared[v0, v1, v2])
                                                                    lv73_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, lv73[v0, v1, v2], T.float16(0.0))
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(1)):
                                                                with T.block("dequantize_intermediate_reindex_shared"):
                                                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                                    v1 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // T.int64(8))
                                                                    v2 = T.axis.spatial(T.int64(11008), ax3_0 * T.int64(8) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % T.int64(8))
                                                                    T.reads(model_layers_0_mlp_down_proj_q_weight3[v1, v2 // T.int64(8)], model_layers_0_mlp_down_proj_q_scale3[v1, v2 // T.int64(32)])
                                                                    T.writes(dequantize_intermediate_reindex_shared[v0, v1, v2])
                                                                    dequantize_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_down_proj_q_weight3[v1, v2 // T.int64(8)], T.Cast("uint32", v2 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_down_proj_q_scale3[v1, v2 // T.int64(32)]
                                                for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(8), T.int64(4), T.int64(4)):
                                                    for ax2_3_1 in T.vectorized(T.int64(1)):
                                                        with T.block("NT_matmul_update"):
                                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                            v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
                                                            v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(32) + ax2_1 * T.int64(32) + ax2_2 * T.int64(4) + ax2_3_0 + ax2_3_1)
                                                            v3 = T.axis.reduce(T.int64(11008), ax3_0 * T.int64(8) + ax3_1)
                                                            T.reads(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2], lv73_reindex_pad_shared[T.int64(0), v1, v3], dequantize_intermediate_reindex_shared[T.int64(0), v2, v3])
                                                            T.writes(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
                                                            NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] + lv73_reindex_pad_shared[T.int64(0), v1, v3] * dequantize_intermediate_reindex_shared[T.int64(0), v2, v3]
                                            for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(4)):
                                                for ax2_1_1 in T.vectorized(T.int64(1)):
                                                    with T.block("NT_matmul_intermediate_reindex_pad_local"):
                                                        v0 = T.axis.spatial(T.int64(1), ax0)
                                                        v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
                                                        v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(32) + ax2_2 * T.int64(4) + ax2_0 + ax2_1_1)
                                                        T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < seq_len)
                                                        T.reads(NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                                        T.writes(NT_matmul_intermediate[T.int64(0), v1, v2])
                                                        NT_matmul_intermediate[T.int64(0), v1, v2] = NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]

    @T.prim_func(private=True)
    def fused_dequantize_NT_matmul14(model_embed_tokens_q_weight2: T.Buffer((T.int64(151936), T.int64(256)), "uint32"), model_embed_tokens_q_scale2: T.Buffer((T.int64(151936), T.int64(64)), "float16"), rms_norm145: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), NT_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(151936)), "float32")):
        # Fused 4-bit dequantize + NT matmul for a single (1, 1, 2048) input row.
        #
        # For every output column v0 in [0, 151936):
        #   NT_matmul_intermediate[0, 0, v0] =
        #       sum_{k in [0, 2048)} float32(rms_norm145[0, 0, k]) * float32(w(v0, k))
        # where the dequantized weight is
        #   w(v0, k) = (float16((q_weight2[v0, k // 8] >> ((k % 8) * 4)) & 15) - 7) * q_scale2[v0, k // 32]
        # i.e. each uint32 word packs eight 4-bit values and each scale covers 32
        # consecutive k.
        #
        # Parameters:
        #   model_embed_tokens_q_weight2: (151936, 256) uint32 packed weights.
        #   model_embed_tokens_q_scale2:  (151936, 64) float16 scales.
        #   rms_norm145:                  (1, 1, 2048) float16 input row.
        #   NT_matmul_intermediate:       (1, 1, 151936) float32 output (written).
        #
        # The reduction over k is done in three steps: 128 partial sums per
        # output column, folded 4-to-1 into 32 partial sums, then a final
        # reduction across the 32 threadIdx.x lanes.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(151936)), scope="local")
        NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), T.int64(1), T.int64(151936)), scope="local")
        model_embed_tokens_q_weight2_local = T.alloc_buffer((T.int64(151936), T.int64(256)), "uint32", scope="local")
        rms_norm145_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16", scope="shared")
        # 37984 blocks x 4 (threadIdx.y) = 151936 output columns; threadIdx.x (32) splits the reduction.
        for u_fused_ax0_fused_fused_0 in T.thread_binding(T.int64(37984), thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                    # Stage 0: copy the whole 2048-element input row into the shared-scope buffer
                    # (4 * 4 * 32 * 4 = 2048 elements).
                    for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
                        for ax2_0 in T.serial(T.int64(4), annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(T.int64(4)):
                                        with T.block("rms_norm145_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(T.int64(2048), ax2_0 * T.int64(512) + ax2_1 * T.int64(128) + ax2_2 * T.int64(4) + ax2_3)
                                            T.reads(rms_norm145[v0, v1, v2])
                                            T.writes(rms_norm145_shared[v0, v1, v2])
                                            rms_norm145_shared[v0, v1, v2] = rms_norm145[v0, v1, v2]
                    # Stage 1 init: zero the 4 partial sums (of 128) owned by this threadIdx.x lane.
                    for u_fused_ax0_fused_fused_2_init in range(T.int64(1)):
                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
                                v0 = T.axis.spatial(T.int64(151936), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0])
                                NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float32(0.0)
                    # Stage 1 update: 8 outer steps, each covering 256 consecutive k (32 packed words).
                    for ax1_0_fused_ax1_1_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        # Fetch the one packed uint32 word (8 weights) this lane needs for this step.
                        for ax0_ax1_fused_0 in range(T.int64(1)):
                            for ax0_ax1_fused_1 in T.vectorized(T.int64(1)):
                                with T.block("model_embed_tokens_q_weight2_local"):
                                    v0 = T.axis.spatial(T.int64(151936), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1)
                                    v1 = T.axis.spatial(T.int64(256), ax1_0_fused_ax1_1_fused_0 * T.int64(32) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
                                    T.reads(model_embed_tokens_q_weight2[v0, v1])
                                    T.writes(model_embed_tokens_q_weight2_local[v0, v1])
                                    model_embed_tokens_q_weight2_local[v0, v1] = model_embed_tokens_q_weight2[v0, v1]
                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(T.int64(1), T.int64(2)):
                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
                                    v0 = T.axis.spatial(T.int64(151936), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
                                    T.reads(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0], rms_norm145_shared[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)], model_embed_tokens_q_weight2_local[v0, vax1_0_fused_ax1_1_fused_0 * T.int64(32) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) + vax1_0_fused_ax1_1_fused_2 // T.int64(2)], model_embed_tokens_q_scale2[v0, (vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)) // T.int64(32)])
                                    T.writes(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0])
                                    # Reduction index: k = outer * 256 + (lane // 4) * 8 + half * 4 + lane % 4.
                                    # Accumulate the plain product input[k] * dequantized_weight[v0, k]; the term
                                    # must not carry any extra k-dependent factor.
                                    NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0] + T.Cast("float32", rms_norm145_shared[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)]) * T.Cast("float32", (T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight2_local[v0, vax1_0_fused_ax1_1_fused_0 * T.int64(32) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) + vax1_0_fused_ax1_1_fused_2 // T.int64(2)], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)) % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_embed_tokens_q_scale2[v0, (vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)) // T.int64(32)])
            # Stage 2: fold the 128 partial sums into 32 by adding each group of 4 consecutive lanes.
            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                    for ax2_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_2_1 in T.vectorized(T.int64(1)):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                v0 = T.axis.spatial(T.int64(151936), u_fused_ax0_fused_fused_0 * T.int64(4) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                T.reads()
                                T.writes(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                                NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0] = T.float32(0.0)
                            for ax1 in range(T.int64(4)):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(T.int64(151936), u_fused_ax0_fused_fused_0 * T.int64(4) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                    T.reads(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0], NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, T.int64(0), T.int64(0), v0])
                                    T.writes(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                                    NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, T.int64(0), T.int64(0), v0]
            # Stage 3: reduce the 32 per-lane partial sums (reduce axis bound to threadIdx.x) into the output.
            for ax1_fused_2 in range(T.int64(1)):
                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                    for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                            v0 = T.axis.spatial(T.int64(151936), u_fused_ax0_fused_fused_0 * T.int64(4) + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2)
                            T.reads(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                            T.writes(NT_matmul_intermediate[T.int64(0), T.int64(0), v0])
                            with T.init():
                                NT_matmul_intermediate[T.int64(0), T.int64(0), v0] = T.float32(0.0)
                            NT_matmul_intermediate[T.int64(0), T.int64(0), v0] = NT_matmul_intermediate[T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0]

    @T.prim_func(private=True)
    def fused_dequantize_NT_matmul4(model_embed_tokens_q_weight4: T.Buffer((T.int64(151936), T.int64(256)), "uint32"), model_embed_tokens_q_scale4: T.Buffer((T.int64(151936), T.int64(64)), "float16"), p_rms_norm291: T.handle, p_output0: T.handle):
        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int64()
        rms_norm291 = T.match_buffer(p_rms_norm291, (batch_size, T.int64(1), T.int64(2048)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(151936)))
        # with T.block("root"):
        if T.tvm_thread_invariant(batch_size <= T.int64(2)):
            with T.block("root"):
                T.reads()
                T.writes()
                dequantize_intermediate_local = T.alloc_buffer((T.int64(151936), T.int64(2048)), "float16", scope="local")
                NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(1), T.int64(151936)), scope="local")
                NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), (batch_size + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(1), T.int64(151936)), scope="local")
                NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(32), (batch_size + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(1), T.int64(151936)), scope="local")
                for ax0_0 in T.thread_binding((batch_size + T.int64(1)) // T.int64(2), thread="blockIdx.y"):
                    for ax1_fused_0 in T.thread_binding(T.int64(18992), thread="blockIdx.x"):
                        for ax1_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                            for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                for ax0_1_init, ax1_fused_2_init in T.grid(T.int64(2), T.int64(2)):
                                    for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                        with T.block("NT_matmul_rf_init"):
                                            vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1_init)
                                            v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax0_1_init)
                                            v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2_init)
                                            T.reads()
                                            T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1])
                                            NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] = T.float32(0.0)
                                for ax2_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
                                        for ax0_1 in T.vectorized(T.int64(1)):
                                            with T.block("dequantize"):
                                                v0 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
                                                v1 = T.axis.spatial(T.int64(2048), ax2_fused_0 * T.int64(256) + ax2_fused_1_ax2_fused_3_fused_0 * T.int64(8) + ax1)
                                                T.reads(model_embed_tokens_q_weight4[v0, v1 // T.int64(8)], model_embed_tokens_q_scale4[v0, v1 // T.int64(32)])
                                                T.writes(dequantize_intermediate_local[v0, v1])
                                                dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight4[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_embed_tokens_q_scale4[v0, v1 // T.int64(32)]
                                    for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(T.int64(2), T.int64(2), T.int64(2)):
                                        for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                            with T.block("NT_matmul_rf_update"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1)
                                                v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax0_1)
                                                v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2)
                                                vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1], rms_norm291[v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] + T.Cast("float32", T.if_then_else(v0 < batch_size, rms_norm291[v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], T.float16(0.0))) * T.Cast("float32", dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                        for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                            for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax2 in range(T.int64(2)):
                                        for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                                v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax2)
                                                v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                                NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] = T.float32(0.0)
                                            for ax1 in range(T.int64(4)):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                    v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax2)
                                                    v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                    T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, v0, T.int64(0), v1])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, v0, T.int64(0), v1]
                        for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(2)):
                            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    with T.block("NT_matmul"):
                                        vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                        v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax1)
                                        v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                        T.writes(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                                        with T.init():
                                            NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = T.float32(0.0)
                                        NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1]
                        for ax0 in range(T.int64(2)):
                            for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax1_fused_2 in range(T.int64(2)):
                                    with T.block("NT_matmul_intermediate_pad"):
                                        v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(2) + ax0)
                                        v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                        T.where((ax0_0 - (batch_size + T.int64(1)) // T.int64(2) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(2) + ax0 < batch_size)
                                        T.reads(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                                        T.writes(NT_matmul_intermediate[v0, T.int64(0), v1])
                                        NT_matmul_intermediate[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1]
        else:
            if T.tvm_thread_invariant(batch_size <= T.int64(8)):
                with T.block("root"):
                    T.reads()
                    T.writes()
                    dequantize_intermediate_local = T.alloc_buffer((T.int64(151936), T.int64(2048)), "float16", scope="local")
                    NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(151936)), scope="local")
                    NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(151936)), scope="local")
                    NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(32), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(151936)), scope="local")
                    for ax0_0 in T.thread_binding((batch_size + T.int64(3)) // T.int64(4), thread="blockIdx.y"):
                        for ax1_fused_0 in T.thread_binding(T.int64(18992), thread="blockIdx.x"):
                            for ax1_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax0_1_init, ax1_fused_2_init in T.grid(T.int64(4), T.int64(2)):
                                        for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1_init)
                                                v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init)
                                                v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2_init)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] = T.float32(0.0)
                                    for ax2_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
                                            for ax0_1 in T.vectorized(T.int64(1)):
                                                with T.block("dequantize"):
                                                    v0 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
                                                    v1 = T.axis.spatial(T.int64(2048), ax2_fused_0 * T.int64(256) + ax2_fused_1_ax2_fused_3_fused_0 * T.int64(8) + ax1)
                                                    T.reads(model_embed_tokens_q_weight4[v0, v1 // T.int64(8)], model_embed_tokens_q_scale4[v0, v1 // T.int64(32)])
                                                    T.writes(dequantize_intermediate_local[v0, v1])
                                                    dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight4[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_embed_tokens_q_scale4[v0, v1 // T.int64(32)]
                                        for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(2)):
                                            for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1)
                                                    v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1)
                                                    v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2)
                                                    vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                    T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1], rms_norm291[v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1])
                                                    NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, T.int64(0), v1] + T.Cast("float32", T.if_then_else(v0 < batch_size, rms_norm291[v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], T.float16(0.0))) * T.Cast("float32", dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                            for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax2 in range(T.int64(4)):
                                            for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                                with T.block("NT_matmul_rf_init"):
                                                    vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                                    v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
                                                    v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                    T.reads()
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] = T.float32(0.0)
                                                for ax1 in range(T.int64(4)):
                                                    with T.block("NT_matmul_rf_update"):
                                                        vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                        v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
                                                        v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, v0, T.int64(0), v1])
                                                        T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                                        NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, v0, T.int64(0), v1]
                            for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(4)):
                                for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                    for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                        with T.block("NT_matmul"):
                                            vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                            v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1)
                                            v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1])
                                            T.writes(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                                            with T.init():
                                                NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = T.float32(0.0)
                                            NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, T.int64(0), v1]
                            for ax0 in range(T.int64(4)):
                                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                    for ax1_fused_2 in range(T.int64(2)):
                                        with T.block("NT_matmul_intermediate_pad"):
                                            v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0)
                                            v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                            T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
                                            T.reads(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                                            T.writes(NT_matmul_intermediate[v0, T.int64(0), v1])
                                            NT_matmul_intermediate[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1]
            else:
                with T.block("root"):
                    T.reads()
                    T.writes()
                    NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((T.int64(1), (batch_size + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(151936)), scope="local")
                    rms_norm291_reindex_pad_shared = T.alloc_buffer((T.int64(1), (batch_size + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2048)), "float16", scope="shared")
                    dequantize_intermediate_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(151936), T.int64(2048)), "float16", scope="shared")
                    for ax0_ax2_0_fused in T.thread_binding(T.int64(4748), thread="blockIdx.y"):
                        for ax1_0 in T.thread_binding((batch_size + T.int64(31)) // T.int64(32), thread="blockIdx.x"):
                            for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
                                for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
                                    for ax2_2 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                        for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                            for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(4)):
                                                for ax2_3_1_init in T.vectorized(T.int64(1)):
                                                    with T.block("NT_matmul_init"):
                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                        v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
                                                        v2 = T.axis.spatial(T.int64(151936), ax0_ax2_0_fused * T.int64(32) + ax2_1 * T.int64(32) + ax2_2 * T.int64(4) + ax2_3_0_init + ax2_3_1_init)
                                                        T.reads()
                                                        T.writes(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
                                                        NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0.0)
                                            for ax3_0 in range(T.int64(256)):
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(1)):
                                                                with T.block("rms_norm291_reindex_pad_shared"):
                                                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                                    v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // T.int64(8))
                                                                    v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(8) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % T.int64(8))
                                                                    T.reads(rms_norm291[v1, T.int64(0), v2])
                                                                    T.writes(rms_norm291_reindex_pad_shared[v0, v1, v2])
                                                                    rms_norm291_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < batch_size, rms_norm291[v1, T.int64(0), v2], T.float16(0.0))
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(1)):
                                                                with T.block("dequantize_intermediate_reindex_shared"):
                                                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                                    v1 = T.axis.spatial(T.int64(151936), ax0_ax2_0_fused * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // T.int64(8))
                                                                    v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(8) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % T.int64(8))
                                                                    T.reads(model_embed_tokens_q_weight4[v1, v2 // T.int64(8)], model_embed_tokens_q_scale4[v1, v2 // T.int64(32)])
                                                                    T.writes(dequantize_intermediate_reindex_shared[v0, v1, v2])
                                                                    dequantize_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight4[v1, v2 // T.int64(8)], T.Cast("uint32", v2 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_embed_tokens_q_scale4[v1, v2 // T.int64(32)]
                                                for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(8), T.int64(4), T.int64(4)):
                                                    for ax2_3_1 in T.vectorized(T.int64(1)):
                                                        with T.block("NT_matmul_update"):
                                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                            v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
                                                            v2 = T.axis.spatial(T.int64(151936), ax0_ax2_0_fused * T.int64(32) + ax2_1 * T.int64(32) + ax2_2 * T.int64(4) + ax2_3_0 + ax2_3_1)
                                                            v3 = T.axis.reduce(T.int64(2048), ax3_0 * T.int64(8) + ax3_1)
                                                            T.reads(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2], rms_norm291_reindex_pad_shared[T.int64(0), v1, v3], dequantize_intermediate_reindex_shared[T.int64(0), v2, v3])
                                                            T.writes(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
                                                            NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] + T.Cast("float32", rms_norm291_reindex_pad_shared[T.int64(0), v1, v3]) * T.Cast("float32", dequantize_intermediate_reindex_shared[T.int64(0), v2, v3])
                                            for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(4)):
                                                for ax2_1_1 in T.vectorized(T.int64(1)):
                                                    with T.block("NT_matmul_intermediate_reindex_pad_local"):
                                                        v0 = T.axis.spatial(T.int64(1), ax0)
                                                        v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
                                                        v2 = T.axis.spatial(T.int64(151936), ax0_ax2_0_fused * T.int64(32) + ax2_2 * T.int64(4) + ax2_0 + ax2_1_1)
                                                        T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < batch_size)
                                                        T.reads(NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                                        T.writes(NT_matmul_intermediate[v1, T.int64(0), v2])
                                                        NT_matmul_intermediate[v1, T.int64(0), v2] = NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]

    @T.prim_func(private=True)
    def fused_dequantize_NT_matmul9(model_embed_tokens_q_weight3: T.Buffer((T.int64(151936), T.int64(256)), "uint32"), model_embed_tokens_q_scale3: T.Buffer((T.int64(151936), T.int64(64)), "float16"), p_take1: T.handle, p_output0: T.handle):
        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int64()
        take1 = T.match_buffer(p_take1, (T.int64(1), batch_size, T.int64(2048)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), batch_size, T.int64(151936)))
        # with T.block("root"):
        if T.tvm_thread_invariant(batch_size <= T.int64(2)):
            with T.block("root"):
                T.reads()
                T.writes()
                dequantize_intermediate_local = T.alloc_buffer((T.int64(151936), T.int64(2048)), "float16", scope="local")
                NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (batch_size + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(151936)), scope="local")
                NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), (batch_size + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(151936)), scope="local")
                NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), (batch_size + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(151936)), scope="local")
                for ax0_0 in T.thread_binding((batch_size + T.int64(1)) // T.int64(2), thread="blockIdx.y"):
                    for ax1_fused_0 in T.thread_binding(T.int64(18992), thread="blockIdx.x"):
                        for ax1_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                            for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                for ax0_1_init, ax1_fused_2_init in T.grid(T.int64(2), T.int64(2)):
                                    for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                        with T.block("NT_matmul_rf_init"):
                                            vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1_init)
                                            v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax0_1_init)
                                            v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2_init)
                                            T.reads()
                                            T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1])
                                            NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] = T.float32(0.0)
                                for ax2_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
                                        for ax0_1 in T.vectorized(T.int64(1)):
                                            with T.block("dequantize"):
                                                v0 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
                                                v1 = T.axis.spatial(T.int64(2048), ax2_fused_0 * T.int64(256) + ax2_fused_1_ax2_fused_3_fused_0 * T.int64(8) + ax1)
                                                T.reads(model_embed_tokens_q_weight3[v0, v1 // T.int64(8)], model_embed_tokens_q_scale3[v0, v1 // T.int64(32)])
                                                T.writes(dequantize_intermediate_local[v0, v1])
                                                dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight3[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_embed_tokens_q_scale3[v0, v1 // T.int64(32)]
                                    for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(T.int64(2), T.int64(2), T.int64(2)):
                                        for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                            with T.block("NT_matmul_rf_update"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1)
                                                v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax0_1)
                                                v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2)
                                                vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1], take1[T.int64(0), v0, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] + T.Cast("float32", T.if_then_else(v0 < batch_size, take1[T.int64(0), v0, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], T.float16(0.0))) * T.Cast("float32", dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                        for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                            for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax2 in range(T.int64(2)):
                                        for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                                v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax2)
                                                v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                                NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] = T.float32(0.0)
                                            for ax1 in range(T.int64(4)):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                    v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax2)
                                                    v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                    T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, T.int64(0), v0, v1])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, T.int64(0), v0, v1]
                        for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(2)):
                            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    with T.block("NT_matmul"):
                                        vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                        v0 = T.axis.spatial((batch_size + T.int64(1)) // T.int64(2) * T.int64(2), ax0_0 * T.int64(2) + ax1)
                                        v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                        T.writes(NT_matmul_intermediate_pad_local[T.int64(0), v0, v1])
                                        with T.init():
                                            NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] = T.float32(0.0)
                                        NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] = NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1]
                        for ax0 in range(T.int64(2)):
                            for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax1_fused_2 in range(T.int64(2)):
                                    with T.block("NT_matmul_intermediate_pad"):
                                        v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(2) + ax0)
                                        v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                        T.where((ax0_0 - (batch_size + T.int64(1)) // T.int64(2) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(2) + ax0 < batch_size)
                                        T.reads(NT_matmul_intermediate_pad_local[T.int64(0), v0, v1])
                                        T.writes(NT_matmul_intermediate[T.int64(0), v0, v1])
                                        NT_matmul_intermediate[T.int64(0), v0, v1] = NT_matmul_intermediate_pad_local[T.int64(0), v0, v1]
        else:
            if T.tvm_thread_invariant(batch_size <= T.int64(8)):
                with T.block("root"):
                    T.reads()
                    T.writes()
                    dequantize_intermediate_local = T.alloc_buffer((T.int64(151936), T.int64(2048)), "float16", scope="local")
                    NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(151936)), scope="local")
                    NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(151936)), scope="local")
                    NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(151936)), scope="local")
                    for ax0_0 in T.thread_binding((batch_size + T.int64(3)) // T.int64(4), thread="blockIdx.y"):
                        for ax1_fused_0 in T.thread_binding(T.int64(18992), thread="blockIdx.x"):
                            for ax1_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax0_1_init, ax1_fused_2_init in T.grid(T.int64(4), T.int64(2)):
                                        for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1_init)
                                                v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init)
                                                v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2_init)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] = T.float32(0.0)
                                    for ax2_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
                                            for ax0_1 in T.vectorized(T.int64(1)):
                                                with T.block("dequantize"):
                                                    v0 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
                                                    v1 = T.axis.spatial(T.int64(2048), ax2_fused_0 * T.int64(256) + ax2_fused_1_ax2_fused_3_fused_0 * T.int64(8) + ax1)
                                                    T.reads(model_embed_tokens_q_weight3[v0, v1 // T.int64(8)], model_embed_tokens_q_scale3[v0, v1 // T.int64(32)])
                                                    T.writes(dequantize_intermediate_local[v0, v1])
                                                    dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight3[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_embed_tokens_q_scale3[v0, v1 // T.int64(32)]
                                        for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(2)):
                                            for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + ax2_fused_1_ax2_fused_3_fused_1)
                                                    v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1)
                                                    v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax1_fused_1 * T.int64(2) + ax1_fused_2)
                                                    vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                    T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1], take1[T.int64(0), v0, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, T.int64(0), v0, v1] + T.Cast("float32", T.if_then_else(v0 < batch_size, take1[T.int64(0), v0, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)], T.float16(0.0))) * T.Cast("float32", dequantize_intermediate_local[v1, vax2_fused_0 * T.int64(256) + vax2_fused_1_ax2_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_2 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused % T.int64(4)])
                            for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax2 in range(T.int64(4)):
                                            for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                                with T.block("NT_matmul_rf_init"):
                                                    vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                                    v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
                                                    v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                    T.reads()
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] = T.float32(0.0)
                                                for ax1 in range(T.int64(4)):
                                                    with T.block("NT_matmul_rf_update"):
                                                        vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                        v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
                                                        v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, T.int64(0), v0, v1])
                                                        T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                                        NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * T.int64(4) + vax2_fused_1_ax2_fused_3_fused_1, T.int64(0), v0, v1]
                            for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(4)):
                                for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                    for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                        with T.block("NT_matmul"):
                                            vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                            v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1)
                                            v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1])
                                            T.writes(NT_matmul_intermediate_pad_local[T.int64(0), v0, v1])
                                            with T.init():
                                                NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] = T.float32(0.0)
                                            NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] = NT_matmul_intermediate_pad_local[T.int64(0), v0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, T.int64(0), v0, v1]
                            for ax0 in range(T.int64(4)):
                                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                    for ax1_fused_2 in range(T.int64(2)):
                                        with T.block("NT_matmul_intermediate_pad"):
                                            v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0)
                                            v1 = T.axis.spatial(T.int64(151936), ax1_fused_0 * T.int64(8) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                            T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
                                            T.reads(NT_matmul_intermediate_pad_local[T.int64(0), v0, v1])
                                            T.writes(NT_matmul_intermediate[T.int64(0), v0, v1])
                                            NT_matmul_intermediate[T.int64(0), v0, v1] = NT_matmul_intermediate_pad_local[T.int64(0), v0, v1]
            else:
                with T.block("root"):
                    T.reads()
                    T.writes()
                    NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((T.int64(1), (batch_size + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(151936)), scope="local")
                    take1_reindex_pad_shared = T.alloc_buffer((T.int64(1), (batch_size + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2048)), "float16", scope="shared")
                    dequantize_intermediate_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(151936), T.int64(2048)), "float16", scope="shared")
                    for ax0_ax2_0_fused in T.thread_binding(T.int64(4748), thread="blockIdx.y"):
                        for ax1_0 in T.thread_binding((batch_size + T.int64(31)) // T.int64(32), thread="blockIdx.x"):
                            for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
                                for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
                                    for ax2_2 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                        for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                            for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(4)):
                                                for ax2_3_1_init in T.vectorized(T.int64(1)):
                                                    with T.block("NT_matmul_init"):
                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                        v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
                                                        v2 = T.axis.spatial(T.int64(151936), ax0_ax2_0_fused * T.int64(32) + ax2_1 * T.int64(32) + ax2_2 * T.int64(4) + ax2_3_0_init + ax2_3_1_init)
                                                        T.reads()
                                                        T.writes(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
                                                        NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0.0)
                                            for ax3_0 in range(T.int64(256)):
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(1)):
                                                                with T.block("take1_reindex_pad_shared"):
                                                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                                    v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // T.int64(8))
                                                                    v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(8) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % T.int64(8))
                                                                    T.reads(take1[v0, v1, v2])
                                                                    T.writes(take1_reindex_pad_shared[v0, v1, v2])
                                                                    take1_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < batch_size, take1[v0, v1, v2], T.float16(0.0))
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(1)):
                                                                with T.block("dequantize_intermediate_reindex_shared"):
                                                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                                    v1 = T.axis.spatial(T.int64(151936), ax0_ax2_0_fused * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // T.int64(8))
                                                                    v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(8) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % T.int64(8))
                                                                    T.reads(model_embed_tokens_q_weight3[v1, v2 // T.int64(8)], model_embed_tokens_q_scale3[v1, v2 // T.int64(32)])
                                                                    T.writes(dequantize_intermediate_reindex_shared[v0, v1, v2])
                                                                    dequantize_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight3[v1, v2 // T.int64(8)], T.Cast("uint32", v2 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7.0)) * model_embed_tokens_q_scale3[v1, v2 // T.int64(32)]
                                                for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(8), T.int64(4), T.int64(4)):
                                                    for ax2_3_1 in T.vectorized(T.int64(1)):
                                                        with T.block("NT_matmul_update"):
                                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                            v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
                                                            v2 = T.axis.spatial(T.int64(151936), ax0_ax2_0_fused * T.int64(32) + ax2_1 * T.int64(32) + ax2_2 * T.int64(4) + ax2_3_0 + ax2_3_1)
                                                            v3 = T.axis.reduce(T.int64(2048), ax3_0 * T.int64(8) + ax3_1)
                                                            T.reads(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2], take1_reindex_pad_shared[T.int64(0), v1, v3], dequantize_intermediate_reindex_shared[T.int64(0), v2, v3])
                                                            T.writes(NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
                                                            NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] + T.Cast("float32", take1_reindex_pad_shared[T.int64(0), v1, v3]) * T.Cast("float32", dequantize_intermediate_reindex_shared[T.int64(0), v2, v3])
                                            for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(4)):
                                                for ax2_1_1 in T.vectorized(T.int64(1)):
                                                    with T.block("NT_matmul_intermediate_reindex_pad_local"):
                                                        v0 = T.axis.spatial(T.int64(1), ax0)
                                                        v1 = T.axis.spatial((batch_size + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
                                                        v2 = T.axis.spatial(T.int64(151936), ax0_ax2_0_fused * T.int64(32) + ax2_2 * T.int64(4) + ax2_0 + ax2_1_1)
                                                        T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < batch_size)
                                                        T.reads(NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                                        T.writes(NT_matmul_intermediate[T.int64(0), v1, v2])
                                                        NT_matmul_intermediate[T.int64(0), v1, v2] = NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]

    @T.prim_func(private=True)
    def fused_dequantize_take1(model_embed_tokens_q_weight: T.Buffer((151936, 256), "uint32"), model_embed_tokens_q_scale: T.Buffer((151936, 64), "float16"), p_input_ids: T.handle, p_output0: T.handle):
        # Fused row gather + dequantize.
        # For every position v0, row input_ids[v0] of the packed weight/scale buffers is
        # selected and expanded into 2048 float16 values written to the output buffer.
        #   - Each uint32 weight word holds eight 4-bit fields: column v1 lives in word
        #     v1 // 8 and is extracted with a right shift of (v1 % 8) * 4 and a mask of 15.
        #   - Each scale entry is shared by 32 consecutive columns (index v1 // 32).
        #   - Output value = (float16(field) - 7.0) * scale.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        input_ids = T.match_buffer(p_input_ids, (seq_len,), "int32")
        T_take_intermediate = T.match_buffer(p_output0, (seq_len, 2048), "float16")
        # with T.block("root"):
        # seq_len * 2 blocks of 1024 threads cover exactly seq_len * 2048 elements,
        # which is why this kernel carries no T.where bounds predicate.
        for ax0_ax1_fused_0 in T.thread_binding(seq_len * 2, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_take"):
                    # Split the fused linear index into (row v0, column v1).
                    v0 = T.axis.spatial(seq_len, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) // 2048)
                    v1 = T.axis.spatial(2048, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % 2048)
                    T.reads(input_ids[v0], model_embed_tokens_q_weight[input_ids[v0], v1 // 8], model_embed_tokens_q_scale[input_ids[v0], v1 // 32])
                    T.writes(T_take_intermediate[v0, v1])
                    T_take_intermediate[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight[input_ids[v0], v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_embed_tokens_q_scale[input_ids[v0], v1 // 32]

    @T.prim_func(private=True)
    def fused_reshape10_reshape11(lv184: T.Buffer((T.int64(1), T.int64(16), T.int64(128)), "float16"), T_reshape_intermediate_1: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")):
        # Element-wise copy that flattens lv184 of shape (1, 16, 128) into the output of
        # shape (1, 1, 2048): output[0, 0, v0] = lv184[0, v0 // 128, v0 % 128].
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        # 2 blocks of 1024 threads cover exactly the 2048 elements, so no bounds predicate is used.
        for ax0_fused_0 in T.thread_binding(T.int64(2), thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_reshape_1"):
                    # v0 is the flat output index.
                    v0 = T.axis.spatial(T.int64(2048), ax0_fused_0 * T.int64(1024) + ax0_fused_1)
                    T.reads(lv184[T.int64(0), v0 // T.int64(128), v0 % T.int64(128)])
                    T.writes(T_reshape_intermediate_1[T.int64(0), T.int64(0), v0])
                    T_reshape_intermediate_1[T.int64(0), T.int64(0), v0] = lv184[T.int64(0), v0 // T.int64(128), v0 % T.int64(128)]

    @T.prim_func(private=True)
    def fused_reshape8_reshape9(add108: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), T_reshape_intermediate_1: T.Buffer((T.int64(1), T.int64(20), T.int64(128)), "float16")):
        # Element-wise copy that reshapes add108 of shape (1, 1, 2560) into the output of
        # shape (1, 20, 128): output[0, v0, v1] = add108[0, 0, v0 * 128 + v1].
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        # 3 blocks of 1024 threads give 3072 iterations for 2560 elements; the T.where
        # predicate below masks off the surplus iterations.
        for ax0_ax1_fused_0 in T.thread_binding(T.int64(3), thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_reshape_1"):
                    # Split the fused linear index into (v0, v1) over the (20, 128) output axes.
                    v0 = T.axis.spatial(T.int64(20), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(128))
                    v1 = T.axis.spatial(T.int64(128), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(128))
                    T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < T.int64(2560))
                    T.reads(add108[T.int64(0), T.int64(0), v0 * T.int64(128) + v1])
                    T.writes(T_reshape_intermediate_1[T.int64(0), v0, v1])
                    T_reshape_intermediate_1[T.int64(0), v0, v1] = add108[T.int64(0), T.int64(0), v0 * T.int64(128) + v1]

    @T.prim_func
    def fused_rope(var_qkv: T.handle, var_position_map: T.handle, var_q: T.handle, var_k: T.handle, var_v: T.handle, apply_rope: T.int32):
        # Splits the fused qkv buffer of shape (seq_len, 20, 128) along its second axis into
        # q (indices 0..15), k (indices 16..17) and v (indices 18..19), optionally applying
        # a rotary transform to q and k.
        #
        # When apply_rope > 0, each q/k element becomes
        #     cos(freq) * x[v2] + sin(freq) * (-x[v2 + 64] if v2 < 64 else x[v2 - 64])
        # with freq = position_map[v0] / 1000000 ** ((v2 * 2 % 128) / 128),
        # computed in float32 and cast back to float16. Otherwise q/k are plain copies.
        # v is always a plain copy.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        qkv = T.match_buffer(var_qkv, (seq_len, 20, 128), "float16")
        position_map = T.match_buffer(var_position_map, (seq_len,), "int32", offset_factor=1)
        q = T.match_buffer(var_q, (seq_len, 16, 128), "float16")
        k = T.match_buffer(var_k, (seq_len, 2, 128), "float16")
        v = T.match_buffer(var_v, (seq_len, 2, 128), "float16")
        # with T.block("root"):
        # One iteration per qkv element (20 * 128 = 2560 per position); the block count is
        # rounded up to a multiple of 1024 and the T.where predicate masks the tail.
        for ax0_ax1_ax2_fused_0 in T.thread_binding((seq_len * 2560 + 1023) // 1024, thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("llama_fused_rope"):
                    # Split the fused linear index into (v0, v1, v2) over the (seq_len, 20, 128) axes of qkv.
                    v0 = T.axis.spatial(seq_len, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2560)
                    v1 = T.axis.spatial(20, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2560 // 128)
                    v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 128)
                    T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < seq_len * 2560)
                    # The declared read region spans v2 - 64 .. v2 + 64, covering both partner
                    # elements (v2 + 64 and v2 - 64) used by the rotation below.
                    T.reads(position_map[v0], qkv[v0, v1, v2 + -64:v2 + -64 + 129])
                    T.writes(q[v0, v1, v2], k[v0, v1 + -16, v2], v[v0, v1 + -18, v2])
                    if v1 < 16:
                        # Indices 0..15 go to q. Note `v2 < 128` always holds for the 128-extent
                        # axis, so the transform is effectively gated by apply_rope alone.
                        freq = T.float32()
                        q[v0, v1, v2] = T.if_then_else(apply_rope > 0 and v2 < 128, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", qkv[v0, v1, v2]) + T.sin(freq) * T.Cast("float32", T.if_then_else(v2 < 64, qkv[v0, v1, v2 + 64] * T.float16(-1.0), qkv[v0, v1, v2 + -64]))), where={freq: T.Cast("float32", position_map[v0]) / T.pow(T.float32(1000000.0), T.Cast("float32", v2 * 2 % 128) / T.float32(128.0))}), qkv[v0, v1, v2])
                    else:
                        if v1 < 18:
                            # Indices 16..17 go to k (stored at v1 - 16), with the same transform as q.
                            freq = T.float32()
                            k[v0, v1 + -16, v2] = T.if_then_else(apply_rope > 0 and v2 < 128, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", qkv[v0, v1, v2]) + T.sin(freq) * T.Cast("float32", T.if_then_else(v2 < 64, qkv[v0, v1, v2 + 64] * T.float16(-1.0), qkv[v0, v1, v2 + -64]))), where={freq: T.Cast("float32", position_map[v0]) / T.pow(T.float32(1000000.0), T.Cast("float32", v2 * 2 % 128) / T.float32(128.0))}), qkv[v0, v1, v2])
                        else:
                            # Indices 18..19 go to v (stored at v1 - 18) unchanged.
                            v[v0, v1 + -18, v2] = qkv[v0, v1, v2]

    @T.prim_func(private=True)
    def fused_split1_silu1_multiply1(p_lv147: T.handle, p_output0: T.handle):
        # Fused split + SiLU + multiply over an input of shape (1, seq_len, 22016).
        # The last axis is treated as two halves of 11008 elements:
        #     output[0, v0, v1] = a * sigmoid(a) * b
        # where a = lv147[0, v0, v1] and b = lv147[0, v0, v1 + 11008].
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int64()
        lv147 = T.match_buffer(p_lv147, (T.int64(1), seq_len, T.int64(22016)), "float16")
        T_multiply_intermediate_1 = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(11008)), "float16")
        # with T.block("root"):
        # One iteration per output element; the block count is rounded up to a multiple
        # of 1024 and the T.where predicate masks the tail.
        for ax0_ax1_fused_0 in T.thread_binding((seq_len * T.int64(11008) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_multiply_1"):
                    # Split the fused linear index into (v0, v1) over the (seq_len, 11008) output axes.
                    v0 = T.axis.spatial(seq_len, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(11008))
                    v1 = T.axis.spatial(T.int64(11008), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(11008))
                    T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < seq_len * T.int64(11008))
                    # The declared read region v1 .. v1 + 11008 covers both operands.
                    T.reads(lv147[T.int64(0), v0, v1:v1 + T.int64(11009)])
                    T.writes(T_multiply_intermediate_1[T.int64(0), v0, v1])
                    T_multiply_intermediate_1[T.int64(0), v0, v1] = lv147[T.int64(0), v0, v1] * T.sigmoid(lv147[T.int64(0), v0, v1]) * lv147[T.int64(0), v0, v1 + T.int64(11008)]

    @T.prim_func(private=True)
    def fused_split2_silu2_multiply2(lv437: T.Buffer((T.int64(1), T.int64(1), T.int64(22016)), "float16"), T_multiply_intermediate_1: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")):
        # Fused split + SiLU + multiply for a fixed (1, 1, 22016) input.
        # The last axis is treated as two halves of 11008 elements:
        #     output[0, 0, v0] = a * sigmoid(a) * b
        # where a = lv437[0, 0, v0] and b = lv437[0, 0, v0 + 11008].
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        # 11 blocks of 1024 threads give 11264 iterations for 11008 elements; the T.where
        # predicate masks off the surplus iterations.
        for ax0_fused_0 in T.thread_binding(T.int64(11), thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_multiply_1"):
                    # v0 is the flat output index.
                    v0 = T.axis.spatial(T.int64(11008), ax0_fused_0 * T.int64(1024) + ax0_fused_1)
                    T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(11008))
                    # The declared read region v0 .. v0 + 11008 covers both operands.
                    T.reads(lv437[T.int64(0), T.int64(0), v0:v0 + T.int64(11009)])
                    T.writes(T_multiply_intermediate_1[T.int64(0), T.int64(0), v0])
                    T_multiply_intermediate_1[T.int64(0), T.int64(0), v0] = lv437[T.int64(0), T.int64(0), v0] * T.sigmoid(lv437[T.int64(0), T.int64(0), v0]) * lv437[T.int64(0), T.int64(0), v0 + T.int64(11008)]

    @T.prim_func(private=True)
    def fused_split_silu_multiply(p_lv2: T.handle, p_output0: T.handle):
        # Fused split + SiLU + multiply over an input of shape (batch_size, 1, 22016).
        # The last axis is treated as two halves of 11008 elements:
        #     output[v0, 0, v1] = a * sigmoid(a) * b
        # where a = lv2[v0, 0, v1] and b = lv2[v0, 0, v1 + 11008].
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int64()
        lv2 = T.match_buffer(p_lv2, (batch_size, T.int64(1), T.int64(22016)), "float16")
        T_multiply_intermediate_1 = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(11008)), "float16")
        # with T.block("root"):
        # One iteration per output element; the block count is rounded up to a multiple
        # of 1024 and the T.where predicate masks the tail.
        for ax0_ax1_fused_0 in T.thread_binding((batch_size * T.int64(11008) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_multiply_1"):
                    # Split the fused linear index into (v0, v1) over the (batch_size, 11008) output axes.
                    v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(11008))
                    v1 = T.axis.spatial(T.int64(11008), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(11008))
                    T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < batch_size * T.int64(11008))
                    # The declared read region v1 .. v1 + 11008 covers both operands.
                    T.reads(lv2[v0, T.int64(0), v1:v1 + T.int64(11009)])
                    T.writes(T_multiply_intermediate_1[v0, T.int64(0), v1])
                    T_multiply_intermediate_1[v0, T.int64(0), v1] = lv2[v0, T.int64(0), v1] * T.sigmoid(lv2[v0, T.int64(0), v1]) * lv2[v0, T.int64(0), v1 + T.int64(11008)]

    @T.prim_func
    def gather_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle):
        # Row gather: dst[v0, v1] = src[indices[v0], v1] for every v0 in [0, batch_size)
        # and v1 in [0, n). src has shape (m, n), indices has shape (batch_size,), and dst
        # has shape (batch_size, n).
        # NOTE(review): no bounds check on indices[v0] is visible here; presumably callers
        # guarantee 0 <= indices[v0] < m - confirm.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        m, n = T.int32(is_size_var=True), T.int32(is_size_var=True)
        src = T.match_buffer(var_src, (m, n))
        batch_size = T.int32(is_size_var=True)
        indices = T.match_buffer(var_indices, (batch_size,), "int32")
        dst = T.match_buffer(var_dst, (batch_size, n))
        # with T.block("root"):
        # One iteration per dst element; the block count is rounded up to a multiple of
        # 1024 and the T.where predicate masks the tail.
        for ax0_ax1_fused_0 in T.thread_binding((batch_size * n + 1023) // 1024, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("gather_2d"):
                    # Split the fused linear index into (v0, v1). For iterations that pass the
                    # predicate below, the index is already < n * batch_size, so the
                    # `% (n * batch_size)` leaves it unchanged.
                    v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % (n * batch_size) // n)
                    v1 = T.axis.spatial(n, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % n)
                    T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < batch_size * n)
                    T.reads(src[indices[v0], v1], indices[v0])
                    T.writes(dst[v0, v1])
                    dst[v0, v1] = src[indices[v0], v1]

    @T.prim_func(private=True)
    def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle):
        # For every sample v0, pick the entry of `indices` at the position where the
        # renormalised cumulative sum first exceeds the uniform sample.
        #   A -> cumsum_sorted  (batch, vocab_size)      cumulative sums per row
        #   B -> indices        (batch, vocab_size) int32 values copied to the output
        #   C -> renorm_prob    (batch, 1)               divisor applied to cumsum_sorted
        #   D -> usample        (out_batch, 1)           threshold compared against the ratio
        #   E -> sample_indices (out_batch, 1) int32     row of A/B/C used by each sample
        #   F -> output_index   (out_batch, 1) int32     result, written by this kernel
        # NOTE(review): presumably cumsum_sorted is non-decreasing along axis 1, so that
        # exactly one position per sample satisfies the write condition -- confirm with caller.
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
        indices = T.match_buffer(B, (batch, vocab_size), "int32")
        renorm_prob = T.match_buffer(C, (batch, 1))
        out_batch = T.int64(is_size_var=True)
        usample = T.match_buffer(D, (out_batch, 1))
        sample_indices = T.match_buffer(E, (out_batch, 1), "int32")
        output_index = T.match_buffer(F, (out_batch, 1), "int32")
        # with T.block("root"):
        # One thread per (sample, position) pair, 1024 threads per block.
        for ax0_ax1_fused_0 in T.thread_binding((out_batch * vocab_size + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_get_index_from_sorted"):
                    # v0 = sample, v1 = position within the row selected by sample_indices[v0, 0].
                    v0 = T.axis.spatial(out_batch, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % (vocab_size * out_batch) // vocab_size)
                    v1 = T.axis.spatial(vocab_size, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % vocab_size)
                    # Mask off the surplus threads of the last block.
                    T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < out_batch * vocab_size)
                    T.reads(usample[v0, T.int64(0)], cumsum_sorted[sample_indices[v0, T.int64(0)], v1 - T.int64(1):v1 - T.int64(1) + T.int64(2)], sample_indices[v0, T.int64(0)], renorm_prob[sample_indices[v0, T.int64(0)], 0], indices[sample_indices[v0, T.int64(0)], T.min(T.int64(0), v1):T.min(T.int64(0), v1) + (v1 + T.int64(1))])
                    T.writes(output_index[v0, 0])
                    # Upper test: the sample lies below cumsum[v1] / renorm, or v1 is the
                    # last position (fallback so the final position can always qualify).
                    if usample[v0, T.int64(0)] < cumsum_sorted[sample_indices[v0, T.int64(0)], v1] / renorm_prob[sample_indices[v0, T.int64(0)], 0] or v1 + T.int64(1) == vocab_size:
                        if v1 == T.int64(0):
                            # Position 0 has no predecessor, so the upper test alone decides.
                            output_index[v0, 0] = indices[sample_indices[v0, T.int64(0)], 0]
                        else:
                            # Lower test: the sample is not below cumsum[v1 - 1] / renorm,
                            # i.e. v1 is the first position that passes the upper test.
                            if usample[v0, T.int64(0)] >= cumsum_sorted[sample_indices[v0, T.int64(0)], v1 - T.int64(1)] / renorm_prob[sample_indices[v0, T.int64(0)], 0]:
                                output_index[v0, 0] = indices[sample_indices[v0, T.int64(0)], v1]

    @T.prim_func(private=True)
    def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
        # Select, per row, the cumulative-sum value at the top-p / top-k cut-off and
        # store it in renorm_prob[row, 0].
        #   A -> cumsum_sorted (batch, vocab_size)  cumulative sums per row
        #   B -> top_p         (batch, 1)           threshold compared against cumsum_sorted
        #   C -> top_k         (batch, 1) int32     limit compared against position + 1
        #   D -> renorm_prob   (batch, 1)           result, written by this kernel
        # NOTE(review): presumably cumsum_sorted is non-decreasing along axis 1, which
        # makes the cut-off position unique per row -- confirm with caller.
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
        top_p = T.match_buffer(B, (batch, 1))
        top_k = T.match_buffer(C, (batch, 1), "int32")
        renorm_prob = T.match_buffer(D, (batch, 1))
        # with T.block("root"):
        # One thread per (row, position) pair, 1024 threads per block.
        for ax0_ax1_fused_0 in T.thread_binding((batch * vocab_size + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_get_renorm_prob"):
                    v0 = T.axis.spatial(batch, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % (vocab_size * batch) // vocab_size)
                    v1 = T.axis.spatial(vocab_size, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % vocab_size)
                    # Mask off the surplus threads of the last block.
                    T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < batch * vocab_size)
                    T.reads(cumsum_sorted[v0, T.min(T.min(T.int64(0), v1), v1 + T.int64(1)):T.min(T.min(T.int64(0), v1), v1 + T.int64(1)) + (v1 + T.int64(2))], top_p[v0, 0], top_k[v0, 0])
                    T.writes(renorm_prob[v0, 0])
                    # Case 1: position 0 already violates a limit (cumsum[0] >= top_p or
                    # top_k <= 1). Every thread of the row stores the same value, cumsum[0].
                    if not (cumsum_sorted[v0, 0] < top_p[v0, 0] and top_k[v0, 0] > 1):
                        renorm_prob[v0, 0] = cumsum_sorted[v0, 0]
                    else:
                        # Case 2: only positions still inside both limits can be the cut-off.
                        if cumsum_sorted[v0, v1] < top_p[v0, 0] and v1 + T.int64(1) < T.Cast("int64", top_k[v0, 0]):
                            if v1 + T.int64(1) == vocab_size:
                                # Last position of the row is still inside: use its cumsum.
                                renorm_prob[v0, 0] = cumsum_sorted[v0, v1]
                            else:
                                # The next position is the first one outside the limits:
                                # store the cumsum at that next position.
                                if not (cumsum_sorted[v0, v1 + T.int64(1)] < top_p[v0, 0] and v1 + T.int64(1) + T.int64(1) < T.Cast("int64", top_k[v0, 0])):
                                    renorm_prob[v0, 0] = cumsum_sorted[v0, v1 + T.int64(1)]

    @T.prim_func(private=True)
    def gpu_2d_continuous_cumsum(var_a: T.handle, var_out: T.handle):
        # Inclusive prefix sum along axis 1 of a 2-D buffer: Out[r, c] = sum(A[r, 0..c]).
        #   var_a   -> A,   shape (m, n)
        #   var_out -> Out, shape (m, n)
        # The scan is hierarchical with a fan-out of 512 (= 2**9) elements per thread block:
        #   phase 1: scan each 512-element chunk of A into Out, store chunk totals in Tmp
        #   phase 2: repeatedly scan the totals of the previous level (total_rounds levels)
        #   phase 3: walk the levels top-down, adding each block's preceding total
        #   phase 4: add the scanned chunk totals back onto Out
        # Tmp is a scratch buffer of shape (m, n); level k of the hierarchy is stored at
        # column offset k * ceil(n / 512).
        T.func_attr({"tir.is_scheduled": 1})
        m, n = T.int64(), T.int64()
        A = T.match_buffer(var_a, (m, n))
        Out = T.match_buffer(var_out, (m, n))
        # with T.block("root"):
        Tmp = T.alloc_buffer((m, n))
        # Number of extra scan levels: ceil(log2(n)) // 9, since every level divides the length by 512.
        ceil_log2: T.int64 = T.Cast("int64", T.ceil(T.log2(T.Cast("float32", n))))
        total_rounds: T.int64 = ceil_log2 // T.int64(9)
        # ---- Phase 1: per-chunk scan of A (one block per row and 512-column chunk) ----
        for by in T.thread_binding(m, thread="blockIdx.y"):
            for bx in T.thread_binding((n + T.int64(511)) // T.int64(512), thread="blockIdx.x"):
                with T.block(""):
                    T.reads(A[by, bx * T.int64(512):bx * T.int64(512) + T.int64(512)])
                    T.writes(Out[by, bx * T.int64(512):bx * T.int64(512) + T.int64(512)], Tmp[by, bx])
                    local_buf = T.alloc_buffer((4,), scope="local")
                    shared_buf = T.alloc_buffer((T.int64(512),), scope="shared")
                    # 4 x 32 threads, each owning 4 consecutive elements of the chunk.
                    for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                        for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                            tx_idx: T.int64 = bx * T.int64(512) + ty * T.int64(128) + tx * T.int64(4)
                            # Load the thread's 4 elements; columns past n contribute 0.
                            for i in T.vectorized(T.int64(4)):
                                local_buf[i] = T.if_then_else(tx_idx + i < n, T.Cast("float32", A[by, tx_idx + i]), T.Cast("float32", 0))
                            # Serial inclusive scan of the 4 thread-local values.
                            for i in T.unroll(T.int64(1), T.int64(4)):
                                local_buf[i] = local_buf[i] + local_buf[i - T.int64(1)]
                            for i in T.vectorized(T.int64(4)):
                                shared_buf[ty * T.int64(128) + tx * T.int64(4) + i] = local_buf[i]
                            # 5 doubling steps across the 32 tx lanes of one ty row: add the last
                            # element of the 4-group owned by lane (tx - 2**i).
                            for i in T.unroll(T.int64(5)):
                                for j in T.vectorized(T.int64(4)):
                                    idx: T.int64 = ty * T.int64(128) + tx * T.int64(4)
                                    if tx >= T.shift_left(T.int64(1), i):
                                        shared_buf[idx + j] = shared_buf[idx + j] + shared_buf[idx - T.shift_left(T.int64(1), i) * T.int64(4) + T.int64(4) - T.int64(1)]
                            # Threads with ty == 0 carry the running total into the 128-element
                            # segments 1..3, one segment after the other.
                            for i in T.unroll(T.int64(1), T.int64(4)):
                                for j in T.vectorized(T.int64(4)):
                                    if ty == T.int64(0):
                                        idx: T.int64 = i * T.int64(128) + tx * T.int64(4)
                                        shared_buf[idx + j] = shared_buf[idx + j] + shared_buf[i * T.int64(128) - T.int64(1)]
                            # Write the scanned chunk back, skipping columns past n.
                            for i in T.vectorized(T.int64(4)):
                                idx: T.int64 = ty * T.int64(128) + tx * T.int64(4) + i
                                if bx * T.int64(512) + idx < n:
                                    Out[by, bx * T.int64(512) + idx] = shared_buf[idx]
                            # Thread (0, 0) publishes the chunk total (last scanned element).
                            if tx == T.int64(0) and ty == T.int64(0):
                                for i in T.vectorized(T.int64(4)):
                                    Tmp[by, bx] = shared_buf[T.int64(511)]
        # ---- Phase 2: scan the totals level by level, in place inside Tmp ----
        for i in range(total_rounds):
            # Number of totals at level i: ceil(n / 512**(i + 1)).
            cur_len: T.int64 = (n + T.shift_left(T.int64(1), T.int64(9) * (i + T.int64(1))) - T.int64(1)) // T.shift_left(T.int64(1), T.int64(9) * (i + T.int64(1)))
            for by in T.thread_binding(m, thread="blockIdx.y"):
                for bx in T.thread_binding((cur_len + T.int64(511)) // T.int64(512), thread="blockIdx.x"):
                    with T.block(""):
                        T.reads(Tmp[by, bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)):bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)) + T.int64(512)])
                        T.writes(Tmp[by, T.min(bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)), (i + T.int64(1)) * ((n + T.int64(511)) // T.int64(512)) + bx):T.min(bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)), (i + T.int64(1)) * ((n + T.int64(511)) // T.int64(512)) + bx) + (T.max(bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)) + T.int64(511), (i + T.int64(1)) * ((n + T.int64(511)) // T.int64(512)) + bx) + T.int64(1) - T.min(bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)), (i + T.int64(1)) * ((n + T.int64(511)) // T.int64(512)) + bx))])
                        local_buf = T.alloc_buffer((4,), scope="local")
                        shared_buf = T.alloc_buffer((T.int64(512),), scope="shared")
                        # Same 512-element block scan as phase 1, reading level i of Tmp
                        # (column offset i * ceil(n / 512)) instead of A.
                        for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                            for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                tx_idx: T.int64 = bx * T.int64(512) + ty * T.int64(128) + tx * T.int64(4)
                                for i_1 in T.vectorized(T.int64(4)):
                                    local_buf[i_1] = T.if_then_else(tx_idx + i_1 < cur_len, T.Cast("float32", Tmp[by, i * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + tx_idx + i_1]), T.Cast("float32", 0))
                                for i_1 in T.unroll(T.int64(1), T.int64(4)):
                                    local_buf[i_1] = local_buf[i_1] + local_buf[i_1 - T.int64(1)]
                                for i_1 in T.vectorized(T.int64(4)):
                                    shared_buf[ty * T.int64(128) + tx * T.int64(4) + i_1] = local_buf[i_1]
                                for i_1 in T.unroll(T.int64(5)):
                                    for j in T.vectorized(T.int64(4)):
                                        idx: T.int64 = ty * T.int64(128) + tx * T.int64(4)
                                        if tx >= T.shift_left(T.int64(1), i_1):
                                            shared_buf[idx + j] = shared_buf[idx + j] + shared_buf[idx - T.shift_left(T.int64(1), i_1) * T.int64(4) + T.int64(4) - T.int64(1)]
                                for i_1 in T.unroll(T.int64(1), T.int64(4)):
                                    for j in T.vectorized(T.int64(4)):
                                        if ty == T.int64(0):
                                            idx: T.int64 = i_1 * T.int64(128) + tx * T.int64(4)
                                            shared_buf[idx + j] = shared_buf[idx + j] + shared_buf[i_1 * T.int64(128) - T.int64(1)]
                                # Scanned totals overwrite level i in place.
                                for i_1 in T.vectorized(T.int64(4)):
                                    idx: T.int64 = ty * T.int64(128) + tx * T.int64(4) + i_1
                                    if bx * T.int64(512) + idx < cur_len:
                                        Tmp[by, i * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + bx * T.int64(512) + idx] = shared_buf[idx]
                                # The block total becomes entry bx of level i + 1.
                                if tx == T.int64(0) and ty == T.int64(0):
                                    for i_1 in T.vectorized(T.int64(4)):
                                        Tmp[by, (i + T.int64(1)) * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + bx] = shared_buf[T.int64(511)]
        # ---- Phase 3: propagate offsets from the highest level down to level 0 ----
        for i in range(total_rounds - T.int64(1)):
            # Levels are visited in descending order: total_rounds - 2, ..., 0.
            real_idx: T.int64 = total_rounds - T.int64(1) - i - T.int64(1)
            cur_len: T.int64 = (n + T.shift_left(T.int64(1), T.int64(9) * (real_idx + T.int64(1))) - T.int64(1)) // T.shift_left(T.int64(1), T.int64(9) * (real_idx + T.int64(1)))
            for by in T.thread_binding(m, thread="blockIdx.y"):
                for bx in T.thread_binding((cur_len + T.int64(511)) // T.int64(512), thread="blockIdx.x"):
                    for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                        for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                            for i_1 in range(T.int64(4)):
                                idx: T.int64 = bx * T.int64(512) + ty * T.int64(128) + i_1 * T.int64(32) + tx
                                if idx < cur_len:
                                    # Every block except the first adds entry (bx - 1) of the
                                    # next-higher level to each of its elements.
                                    Tmp[by, real_idx * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + idx] = Tmp[by, real_idx * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + idx] + T.if_then_else(bx > T.int64(0), Tmp[by, (real_idx + T.int64(1)) * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + bx - T.int64(1)], T.float32(0.0))
        # ---- Phase 4: add level-0 entry (bx - 1) to every element of chunk bx > 0 ----
        for by in T.thread_binding(m, thread="blockIdx.y"):
            for bx in T.thread_binding((n + T.int64(511)) // T.int64(512), thread="blockIdx.x"):
                for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                    for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                        for i in range(T.int64(4)):
                            idx: T.int64 = bx * T.int64(512) + ty * T.int64(128) + i * T.int64(32) + tx
                            if idx < n:
                                Out[by, idx] = Out[by, idx] + T.if_then_else(bx > T.int64(0), Tmp[by, bx - T.int64(1)], T.float32(0.0))

    @T.prim_func(private=True)
    def index(var_rms_norm72: T.handle, index: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")):
        # Copy the last sequence position of the input into the output:
        #   index[0, 0, :] = rms_norm72[0, seq_len - 1, :]
        #   var_rms_norm72 -> rms_norm72, shape (1, seq_len, 2048), float16
        #   index                         shape (1, 1, 2048),       float16 (output)
        # Note that the output buffer parameter carries the same name as the function.
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int64()
        rms_norm72 = T.match_buffer(var_rms_norm72, (T.int64(1), seq_len, T.int64(2048)), "float16")
        # with T.block("root"):
        # 2 blocks x 1024 threads cover the 2048 elements exactly, so no tail guard is needed.
        for ax0_fused_0 in T.thread_binding(T.int64(2), thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("index"):
                    v0 = T.axis.spatial(T.int64(2048), ax0_fused_0 * T.int64(1024) + ax0_fused_1)
                    T.reads(rms_norm72[T.int64(0), seq_len - T.int64(1), v0])
                    T.writes(index[T.int64(0), T.int64(0), v0])
                    index[T.int64(0), T.int64(0), v0] = rms_norm72[T.int64(0), seq_len - T.int64(1), v0]

    @T.prim_func
    def merge_state_inplace(v: T.handle, s: T.handle, v_other: T.handle, s_other: T.handle):
        # Merge two (value, score) states in place, weighting the values by base-2
        # exponentials of their scores:
        #   w  = 2**(S - max) / (2**(S - max) + 2**(S_other - max))
        #   V <- w * V + (1 - w) * V_other          (computed in float32, stored as float16)
        #   S <- log2(2**(S - max) + 2**(S_other - max)) + max
        # where max = max(S, S_other) is subtracted before exponentiation.
        #   v       -> V,       shape (N, H, D), float16, updated in place
        #   s       -> S,       shape (N, H),    default dtype, updated in place
        #   v_other -> V_other, shape (N, H, D), float16, read only
        #   s_other -> S_other, shape (N, H),    default dtype, read only
        # NOTE(review): the thread layout below touches exactly heads 0..15 and
        # elements 0..127 of the last axis without any bounds guard; presumably the
        # kernel is only launched with H == 16 and D == 128 -- confirm with caller.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        N, H, D = T.int32(is_size_var=True), T.int32(is_size_var=True), T.int32(is_size_var=True)
        V = T.match_buffer(v, (N, H, D), "float16")
        S = T.match_buffer(s, (N, H))
        V_other = T.match_buffer(v_other, (N, H, D), "float16")
        S_other = T.match_buffer(s_other, (N, H))
        # with T.block("root"):
        # bx selects the first-axis entry, ty + by * 16 the head, and tx a group of 4
        # consecutive elements along the last axis.
        for bx in T.thread_binding(N, thread="blockIdx.x"):
            for by in T.thread_binding(1, thread="blockIdx.y"):
                for ty in T.thread_binding(16, thread="threadIdx.y"):
                    for tx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("merge"):
                            T.reads(S[bx, ty + by * 16], S_other[bx, ty + by * 16], V[bx, ty + by * 16, tx * 4:tx * 4 + 4], V_other[bx, ty + by * 16, tx * 4:tx * 4 + 4])
                            T.writes(V[bx, ty + by * 16, tx * 4:tx * 4 + 4], S[bx, ty + by * 16])
                            s_val = T.alloc_buffer((1,), scope="local")
                            s_other_val = T.alloc_buffer((1,), scope="local")
                            s_max = T.alloc_buffer((1,), scope="local")
                            scale = T.alloc_buffer((1,), scope="local")
                            other_scale = T.alloc_buffer((1,), scope="local")
                            v_vec = T.alloc_buffer((4,), "float16", scope="local")
                            v_other_vec = T.alloc_buffer((4,), "float16", scope="local")
                            s_val[0] = S[bx, ty + by * 16]
                            s_other_val[0] = S_other[bx, ty + by * 16]
                            # Subtract the larger score so both exponents are <= 0.
                            s_max[0] = T.max(s_val[0], s_other_val[0])
                            # From here on s_val / s_other_val hold 2**(score - max), not the raw scores.
                            s_val[0] = T.exp2(s_val[0] - s_max[0])
                            s_other_val[0] = T.exp2(s_other_val[0] - s_max[0])
                            # Normalised weights of the two states; they sum to 1.
                            scale[0] = s_val[0] / (s_val[0] + s_other_val[0])
                            other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0])
                            # Vectorised loads of the 4 elements owned by this thread.
                            for vec in T.vectorized(4):
                                v_vec[vec] = V[bx, ty + by * 16, tx * 4 + vec]
                            for vec in T.vectorized(4):
                                v_other_vec[vec] = V_other[bx, ty + by * 16, tx * 4 + vec]
                            # Blend in float32, then narrow back to float16.
                            for vec in range(4):
                                v_vec[vec] = T.Cast("float16", T.Cast("float32", v_vec[vec]) * scale[0] + T.Cast("float32", v_other_vec[vec]) * other_scale[0])
                            for vec in T.vectorized(4):
                                V[bx, ty + by * 16, tx * 4 + vec] = v_vec[vec]
                            # All 32 tx threads of a head store the same merged score.
                            S[bx, ty + by * 16] = T.log2(s_val[0] + s_other_val[0]) + s_max[0]

    @T.prim_func
    def parallel_sampling_from_prob(var_prob: T.handle, var_uniform_samples: T.handle, var_row_indices: T.handle, var_sampled_token_ids: T.handle):
        # For every sample bx, find the smallest vocabulary index whose running sum of
        # positive probabilities reaches u - 1e-6, where u = uniform_samples[bx, 0] and
        # the row of `prob` is selected by row_indices[bx, 0].
        #   var_prob              -> prob,            shape (n, vocab_size)
        #   var_uniform_samples   -> uniform_samples, shape (batch_size, 1)
        #   var_row_indices       -> row_indices,     shape (batch_size, 1), int32
        #   var_sampled_token_ids -> token_ids,       shape (batch_size, 1), int32 (output)
        # One thread block of 4 x 32 threads handles one sample and walks the row in
        # steps of 512 entries (4 per thread), stopping once the running sum is large enough.
        T.func_attr({"tir.is_scheduled": 1})
        n, vocab_size = T.int64(), T.int64()
        prob = T.match_buffer(var_prob, (n, vocab_size))
        batch_size = T.int64()
        uniform_samples = T.match_buffer(var_uniform_samples, (batch_size, 1))
        row_indices = T.match_buffer(var_row_indices, (batch_size, 1), "int32")
        token_ids = T.match_buffer(var_sampled_token_ids, (batch_size, 1), "int32")
        # with T.block("root"):
        aggregate = T.alloc_buffer((), scope="local")  # running sum over all finished steps
        sample_id_local = T.alloc_buffer((), "int32", scope="local")  # selected vocabulary index
        step_iter = T.alloc_buffer((), "int32", scope="local")  # index of the current 512-entry step
        for bx in T.thread_binding(batch_size, thread="blockIdx.x"):
            row_idx: T.int32 = row_indices[bx, 0]
            for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                    u: T.float32 = uniform_samples[bx, 0]
                    aggregate[()] = T.Cast("float32", 0)
                    step_iter[()] = 0
                    # Always run the first step; afterwards continue while the running sum is
                    # still below u - 1e-6 and steps remain (ceil(vocab_size / 512) in total).
                    while T.tvm_thread_invariant((step_iter[()] == 0 or aggregate[()] < u - T.float32(9.9999999999999995e-07)) and T.Cast("int64", step_iter[()]) < (vocab_size + T.int64(512) - T.int64(1)) // T.int64(512)):
                        with T.block(""):
                            T.reads(step_iter[()], prob[row_idx, T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4):T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + T.int64(4)], aggregate[()])
                            T.writes(sample_id_local[()], aggregate[()])
                            prob_gt_threshold = T.alloc_buffer((T.int64(4),), scope="local")
                            cumsum = T.alloc_buffer((T.int64(512),), scope="shared")
                            greater_than_u = T.alloc_buffer((T.int64(4),), "bool", scope="local")
                            mask = T.alloc_buffer((T.int64(4),), "bool", scope="local")
                            valid = T.alloc_buffer((T.int64(4),), "bool", scope="local")
                            indices = T.alloc_buffer((T.int64(4),), "int32", scope="local")
                            step_aggregate = T.alloc_buffer((), scope="local")
                            # Load this thread's 4 entries; out-of-range and non-positive
                            # probabilities count as 0 and are flagged invalid.
                            for v in T.unroll(T.int64(4)):
                                idx: T.int64 = T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + v
                                prob_local: T.float32 = T.if_then_else(idx < vocab_size, prob[row_idx, idx], T.Cast("float32", 0))
                                prob_gt_threshold[v] = T.if_then_else(prob_local > T.float32(0.0), prob_local, T.Cast("float32", 0))
                                valid[v] = prob_local > T.float32(0.0) and idx < vocab_size
                            # Sum of the step's 512 entries: per-thread partial sums followed by a
                            # 7-level pairwise reduction over the 128 threads; result in slot 0.
                            with T.block(""):
                                T.reads(prob_gt_threshold[T.int64(0):T.int64(4)])
                                T.writes(step_aggregate[()])
                                local_sum = T.alloc_buffer((), scope="local")
                                shared_buf = T.alloc_buffer((T.int64(128),), scope="shared")
                                idx: T.int64 = ty * T.int64(32) + tx
                                local_sum[()] = T.Cast("float32", 0)
                                for i in T.unroll(T.int64(4)):
                                    local_sum[()] = local_sum[()] + prob_gt_threshold[i]
                                shared_buf[idx] = local_sum[()]
                                for i in T.unroll(T.int64(7)):
                                    if idx % T.shift_left(T.int64(1), i + T.int64(1)) == T.int64(0):
                                        shared_buf[idx] = shared_buf[idx] + shared_buf[idx + T.shift_left(T.int64(1), i)]
                                step_aggregate[()] = shared_buf[0]
                            # Only the step in which the running sum reaches u - 1e-6 needs the
                            # detailed search below.
                            if T.tvm_thread_invariant(aggregate[()] + step_aggregate[()] >= u - T.float32(9.9999999999999995e-07)):
                                # Inclusive scan of the step's 512 entries into `cumsum`:
                                # thread-local scan of 4, doubling steps across the 32 tx lanes,
                                # then carry across the four 128-element segments.
                                for i in T.unroll(T.int64(1), T.int64(4)):
                                    prob_gt_threshold[i] = prob_gt_threshold[i] + prob_gt_threshold[i - T.int64(1)]
                                for i in T.vectorized(T.int64(4)):
                                    cumsum[ty * T.int64(128) + tx * T.int64(4) + i] = prob_gt_threshold[i]
                                for i in T.unroll(T.int64(5)):
                                    for j in T.vectorized(T.int64(4)):
                                        idx: T.int64 = ty * T.int64(128) + tx * T.int64(4)
                                        if tx >= T.shift_left(T.int64(1), i):
                                            cumsum[idx + j] = cumsum[idx + j] + cumsum[idx - T.shift_left(T.int64(1), i) * T.int64(4) + T.int64(4) - T.int64(1)]
                                for i in T.unroll(T.int64(1), T.int64(4)):
                                    for j in T.vectorized(T.int64(4)):
                                        if ty == T.int64(0):
                                            idx: T.int64 = i * T.int64(128) + tx * T.int64(4)
                                            cumsum[idx + j] = cumsum[idx + j] + cumsum[i * T.int64(128) - T.int64(1)]
                                # Flag every entry whose global running sum has reached u - 1e-6.
                                for v in T.unroll(T.int64(4)):
                                    greater_than_u[v] = cumsum[ty * T.int64(128) + tx * T.int64(4) + v] + aggregate[()] >= u - T.float32(9.9999999999999995e-07)
                                # Adjacent difference: mask marks entries whose flag differs from the
                                # preceding entry. The first entry of a thread compares against the
                                # last entry of the previous thread, exchanged through shared_buf.
                                with T.block(""):
                                    T.reads(greater_than_u[T.int64(0):T.int64(4)])
                                    T.writes(mask[T.int64(0):T.int64(4)])
                                    shared_buf = T.alloc_buffer((T.int64(128),), "bool", scope="shared")
                                    tx_idx: T.int64 = ty * T.int64(32) + tx
                                    shared_buf[tx_idx] = greater_than_u[T.int64(3)]
                                    mask[0] = T.if_then_else(tx_idx != T.int64(0), T.Cast("int8", greater_than_u[0]) != T.Cast("int8", shared_buf[tx_idx - T.int64(1)]), greater_than_u[0])
                                    for i in T.unroll(T.int64(1), T.int64(4)):
                                        mask[i] = T.Cast("int8", greater_than_u[i]) != T.Cast("int8", greater_than_u[i - T.int64(1)])
                                # Drop invalid entries and record each entry's global vocabulary index.
                                for v in T.unroll(T.int64(4)):
                                    mask[v] = mask[v] and valid[v]
                                    indices[v] = T.Cast("int32", T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + v)
                                # Minimum index over all masked entries, defaulting to vocab_size - 1
                                # when nothing is masked; same pairwise reduction shape as the sum.
                                with T.block(""):
                                    T.reads(mask[T.int64(0):T.int64(4)], indices[T.int64(0):T.int64(4)])
                                    T.writes(sample_id_local[()])
                                    local_sum = T.alloc_buffer((), "int32", scope="local")
                                    shared_buf = T.alloc_buffer((T.int64(128),), "int32", scope="shared")
                                    idx: T.int64 = ty * T.int64(32) + tx
                                    local_sum[()] = T.Cast("int32", vocab_size - T.int64(1))
                                    for i in T.unroll(T.int64(4)):
                                        if mask[i]:
                                            local_sum[()] = T.min(local_sum[()], indices[i])
                                    shared_buf[idx] = local_sum[()]
                                    for i in T.unroll(T.int64(7)):
                                        if idx % T.shift_left(T.int64(1), i + T.int64(1)) == T.int64(0):
                                            shared_buf[idx] = T.min(shared_buf[idx], shared_buf[idx + T.shift_left(T.int64(1), i)])
                                    sample_id_local[()] = shared_buf[0]
                            # Fold this step's total into the running sum tested by the loop condition.
                            aggregate[()] = aggregate[()] + step_aggregate[()]
                        step_iter[()] = step_iter[()] + 1
                    # A single thread per block publishes the selected index.
                    if tx == T.int64(0) and ty == T.int64(0):
                        token_ids[bx, 0] = sample_id_local[()]

    @T.prim_func(private=True)
    def reshape(var_add324: T.handle, var_T_reshape: T.handle):
        # Element-wise copy that splits the last axis of 2560 into (20, 128):
        #   T_reshape[b, 0, h, d] = add324[b, 0, h * 128 + d]
        #   var_add324    -> add324,    shape (batch_size, 1, 2560),    float16
        #   var_T_reshape -> T_reshape, shape (batch_size, 1, 20, 128), float16 (output)
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int64()
        add324 = T.match_buffer(var_add324, (batch_size, T.int64(1), T.int64(2560)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(20), T.int64(128)), "float16")
        # with T.block("root"):
        # One thread per element; ceil(batch_size * 2560 / 1024) blocks of 1024 threads.
        for ax0_ax1_ax2_fused_0 in T.thread_binding((batch_size * T.int64(2560) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_reshape"):
                    # Decompose the flat thread index into (batch, head, element).
                    v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(2560))
                    v1 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(2560) // T.int64(128))
                    v2 = T.axis.spatial(T.int64(128), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(128))
                    # Mask off the surplus threads of the last block (2560 is not a multiple of 1024).
                    T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < batch_size * T.int64(2560))
                    T.reads(add324[v0, T.int64(0), v1 * T.int64(128) + v2])
                    T.writes(T_reshape[v0, T.int64(0), v1, v2])
                    T_reshape[v0, T.int64(0), v1, v2] = add324[v0, T.int64(0), v1 * T.int64(128) + v2]

    @T.prim_func(private=True)
    def reshape1(var_reshape432: T.handle, var_T_reshape: T.handle):
        # Element-wise copy that drops the unit axis 1:
        #   T_reshape[b, h, d] = reshape432[b, 0, h, d]
        #   var_reshape432 -> reshape432, shape (batch_size, 1, 20, 128), float16
        #   var_T_reshape  -> T_reshape,  shape (batch_size, 20, 128),    float16 (output)
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int64()
        reshape432 = T.match_buffer(var_reshape432, (batch_size, T.int64(1), T.int64(20), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(20), T.int64(128)), "float16")
        # with T.block("root"):
        # One thread per element; ceil(batch_size * 2560 / 1024) blocks of 1024 threads.
        for ax0_ax1_ax2_fused_0 in T.thread_binding((batch_size * T.int64(2560) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_reshape"):
                    # Decompose the flat thread index into (batch, head, element).
                    v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(2560))
                    v1 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(2560) // T.int64(128))
                    v2 = T.axis.spatial(T.int64(128), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(128))
                    # Mask off the surplus threads of the last block.
                    T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < batch_size * T.int64(2560))
                    T.reads(reshape432[v0, T.int64(0), v1, v2])
                    T.writes(T_reshape[v0, v1, v2])
                    T_reshape[v0, v1, v2] = reshape432[v0, T.int64(0), v1, v2]

    @T.prim_func(private=True)
    def reshape2(var_lv546: T.handle, var_T_reshape: T.handle):
        # Element-wise copy that inserts a unit axis at position 1:
        #   T_reshape[b, 0, h, d] = lv546[b, h, d]
        #   var_lv546     -> lv546,     shape (batch_size, 16, 128),    float16
        #   var_T_reshape -> T_reshape, shape (batch_size, 1, 16, 128), float16 (output)
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int64()
        lv546 = T.match_buffer(var_lv546, (batch_size, T.int64(16), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(16), T.int64(128)), "float16")
        # with T.block("root"):
        # Each batch entry holds 16 * 128 = 2048 elements, i.e. exactly 2 blocks of 1024
        # threads, so the grid covers the data exactly and no tail guard is needed.
        for ax0_ax1_ax2_fused_0 in T.thread_binding(batch_size * T.int64(2), thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_reshape"):
                    # Decompose the flat thread index into (batch, head, element).
                    v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(2048))
                    v1 = T.axis.spatial(T.int64(16), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(2048) // T.int64(128))
                    v2 = T.axis.spatial(T.int64(128), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(128))
                    T.reads(lv546[v0, v1, v2])
                    T.writes(T_reshape[v0, T.int64(0), v1, v2])
                    T_reshape[v0, T.int64(0), v1, v2] = lv546[v0, v1, v2]

    @T.prim_func(private=True)
    def reshape3(var_reshape434: T.handle, var_T_reshape: T.handle):
        # Element-wise copy of reshape434 (batch_size, 1, 16, 128) into T_reshape
        # (batch_size, 1, 2048): the trailing (16, 128) axes are flattened into one
        # axis of length 2048. Both buffers are float16.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int64()
        reshape434 = T.match_buffer(var_reshape434, (batch_size, T.int64(1), T.int64(16), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(2048)), "float16")
        # with T.block("root"):
        # batch_size * 2 blocks of 1024 threads cover the batch_size * 2048 output
        # elements exactly, so no T.where predicate is used.
        for ax0_ax1_fused_0 in T.thread_binding(batch_size * T.int64(2), thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_reshape"):
                    v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(2048))
                    v1 = T.axis.spatial(T.int64(2048), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(2048))
                    # The flattened index v1 maps back to input coordinates (v1 // 128, v1 % 128).
                    T.reads(reshape434[v0, T.int64(0), v1 // T.int64(128), v1 % T.int64(128)])
                    T.writes(T_reshape[v0, T.int64(0), v1])
                    T_reshape[v0, T.int64(0), v1] = reshape434[v0, T.int64(0), v1 // T.int64(128), v1 % T.int64(128)]

    @T.prim_func(private=True)
    def reshape4(var_add216: T.handle, var_T_reshape: T.handle):
        # Element-wise copy of add216 (1, seq_len, 2560) into T_reshape
        # (1, seq_len, 20, 128): the last axis of length 2560 is split into (20, 128).
        # Both buffers are float16; seq_len is a symbolic int64.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int64()
        add216 = T.match_buffer(var_add216, (T.int64(1), seq_len, T.int64(2560)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(20), T.int64(128)), "float16")
        # with T.block("root"):
        # The grid is ceil(seq_len * 2560 / 1024) blocks of 1024 threads; the T.where
        # predicate below masks off threads whose flat index is >= seq_len * 2560.
        for ax0_ax1_ax2_fused_0 in T.thread_binding((seq_len * T.int64(2560) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_reshape"):
                    # Decompose the flat thread index into (v0, v1, v2) over (seq_len, 20, 128).
                    v0 = T.axis.spatial(seq_len, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(2560))
                    v1 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(2560) // T.int64(128))
                    v2 = T.axis.spatial(T.int64(128), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(128))
                    T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < seq_len * T.int64(2560))
                    # Input column is recomposed as v1 * 128 + v2.
                    T.reads(add216[T.int64(0), v0, v1 * T.int64(128) + v2])
                    T.writes(T_reshape[T.int64(0), v0, v1, v2])
                    T_reshape[T.int64(0), v0, v1, v2] = add216[T.int64(0), v0, v1 * T.int64(128) + v2]

    @T.prim_func(private=True)
    def reshape5(var_reshape288: T.handle, var_T_reshape: T.handle):
        # Element-wise copy of reshape288 (1, seq_len, 20, 128) into T_reshape
        # (seq_len, 20, 128): the leading unit axis is dropped. Both buffers are float16.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int64()
        reshape288 = T.match_buffer(var_reshape288, (T.int64(1), seq_len, T.int64(20), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (seq_len, T.int64(20), T.int64(128)), "float16")
        # with T.block("root"):
        # The grid is ceil(seq_len * 2560 / 1024) blocks of 1024 threads; the T.where
        # predicate below masks off threads whose flat index is >= seq_len * 2560.
        for ax0_ax1_ax2_fused_0 in T.thread_binding((seq_len * T.int64(2560) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_reshape"):
                    # Decompose the flat thread index into (v0, v1, v2) over (seq_len, 20, 128).
                    v0 = T.axis.spatial(seq_len, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(2560))
                    v1 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(2560) // T.int64(128))
                    v2 = T.axis.spatial(T.int64(128), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(128))
                    T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < seq_len * T.int64(2560))
                    T.reads(reshape288[T.int64(0), v0, v1, v2])
                    T.writes(T_reshape[v0, v1, v2])
                    T_reshape[v0, v1, v2] = reshape288[T.int64(0), v0, v1, v2]

    @T.prim_func(private=True)
    def reshape6(var_lv365: T.handle, var_T_reshape: T.handle):
        # Element-wise copy of lv365 (seq_len, 16, 128) into T_reshape
        # (1, seq_len, 16, 128): the output gains a leading unit axis.
        # Both buffers are float16; seq_len is a symbolic int64.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int64()
        lv365 = T.match_buffer(var_lv365, (seq_len, T.int64(16), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(16), T.int64(128)), "float16")
        # with T.block("root"):
        # seq_len * 2 blocks of 1024 threads = seq_len * 2048 threads, exactly one per
        # element (16 * 128 = 2048), so no T.where predicate is used.
        for ax0_ax1_ax2_fused_0 in T.thread_binding(seq_len * T.int64(2), thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_reshape"):
                    # Decompose the flat thread index into (v0, v1, v2) over (seq_len, 16, 128).
                    v0 = T.axis.spatial(seq_len, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(2048))
                    v1 = T.axis.spatial(T.int64(16), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(2048) // T.int64(128))
                    v2 = T.axis.spatial(T.int64(128), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(128))
                    T.reads(lv365[v0, v1, v2])
                    T.writes(T_reshape[T.int64(0), v0, v1, v2])
                    T_reshape[T.int64(0), v0, v1, v2] = lv365[v0, v1, v2]

    @T.prim_func(private=True)
    def reshape7(var_reshape290: T.handle, var_T_reshape: T.handle):
        # Element-wise copy of reshape290 (1, seq_len, 16, 128) into T_reshape
        # (1, seq_len, 2048): the trailing (16, 128) axes are flattened into one
        # axis of length 2048. Both buffers are float16.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int64()
        reshape290 = T.match_buffer(var_reshape290, (T.int64(1), seq_len, T.int64(16), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(2048)), "float16")
        # with T.block("root"):
        # seq_len * 2 blocks of 1024 threads cover the seq_len * 2048 output elements
        # exactly, so no T.where predicate is used.
        for ax0_ax1_fused_0 in T.thread_binding(seq_len * T.int64(2), thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_reshape"):
                    v0 = T.axis.spatial(seq_len, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(2048))
                    v1 = T.axis.spatial(T.int64(2048), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(2048))
                    # The flattened index v1 maps back to input coordinates (v1 // 128, v1 % 128).
                    T.reads(reshape290[T.int64(0), v0, v1 // T.int64(128), v1 % T.int64(128)])
                    T.writes(T_reshape[T.int64(0), v0, v1])
                    T_reshape[T.int64(0), v0, v1] = reshape290[T.int64(0), v0, v1 // T.int64(128), v1 % T.int64(128)]

    @T.prim_func(private=True)
    def rms_norm(var_input_embeds: T.handle, model_layers_0_input_layernorm_weight4: T.Buffer((T.int64(2048),), "float16"), var_T_cast: T.handle):
        # Per-row normalisation of input_embeds (batch_size, 1, 2048) with a 2048-element
        # float16 weight. For each row v0 the kernel writes
        #   T_cast[v0, 0, i] = float16(rsqrt(sum_j(x_j^2) * (1/2048) + ~1e-6) * x_i * w_i)
        # with the sum of squares and the product evaluated in float32.
        # One blockIdx.x per row, 64 threads (threadIdx.x) per block, three phases:
        #   1. each thread accumulates 32 squared elements into a local-scope partial;
        #   2. the 64 partials are summed into one shared-scope value;
        #   3. each thread writes 32 normalised outputs.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int64()
        input_embeds = T.match_buffer(var_input_embeds, (batch_size, T.int64(1), T.int64(2048)), "float16")
        T_cast = T.match_buffer(var_T_cast, (batch_size, T.int64(1), T.int64(2048)), "float16")
        # with T.block("root"):
        # Row-wise sum of squares (shared scope) and its 64 per-thread partials (local scope).
        T_multiply_red_shared = T.alloc_buffer((batch_size, T.int64(1)), scope="shared")
        T_multiply_red_rf_local = T.alloc_buffer((T.int64(64), batch_size, T.int64(1)), scope="local")
        for ax0_fused in T.thread_binding(batch_size, thread="blockIdx.x"):
            for ax1_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                # Phase 1a: zero this thread's partial sum.
                with T.block("T_multiply_red_rf_init"):
                    vax1_fused_1, v0 = T.axis.remap("SS", [ax1_fused_1, ax0_fused])
                    T.reads()
                    T.writes(T_multiply_red_rf_local[vax1_fused_1, v0, T.int64(0)])
                    T_multiply_red_rf_local[vax1_fused_1, v0, T.int64(0)] = T.float32(0.0)
                # Phase 1b: thread t accumulates x^2 for columns t, t + 64, ..., t + 31 * 64.
                # `u` is a unit-extent loop variable that the block does not reference.
                for ax1_fused_0, u in T.grid(T.int64(32), 1):
                    with T.block("T_multiply_red_rf_update"):
                        vax1_fused_1, v0, vax1_fused_0 = T.axis.remap("SSR", [ax1_fused_1, ax0_fused, ax1_fused_0])
                        T.reads(T_multiply_red_rf_local[vax1_fused_1, v0, T.int64(0)], input_embeds[v0, T.int64(0), vax1_fused_0 * T.int64(64) + vax1_fused_1])
                        T.writes(T_multiply_red_rf_local[vax1_fused_1, v0, T.int64(0)])
                        T_multiply_red_rf_local[vax1_fused_1, v0, T.int64(0)] = T_multiply_red_rf_local[vax1_fused_1, v0, T.int64(0)] + T.Cast("float32", input_embeds[v0, T.int64(0), vax1_fused_0 * T.int64(64) + vax1_fused_1]) * T.Cast("float32", input_embeds[v0, T.int64(0), vax1_fused_0 * T.int64(64) + vax1_fused_1])
            # Phase 2: the reduce axis is bound to threadIdx.x, so the 64 per-thread
            # partials are combined into T_multiply_red_shared[v0, 0].
            for ax1_fused in range(T.int64(1)):
                for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax0_fused])
                        T.reads(T_multiply_red_rf_local[vax1_fused_1, v0, T.int64(0)])
                        T.writes(T_multiply_red_shared[v0, T.int64(0)])
                        with T.init():
                            T_multiply_red_shared[v0, T.int64(0)] = T.float32(0.0)
                        T_multiply_red_shared[v0, T.int64(0)] = T_multiply_red_shared[v0, T.int64(0)] + T_multiply_red_rf_local[vax1_fused_1, v0, T.int64(0)]
            # Phase 3: 32 iterations x 64 threads cover the 2048 output columns.
            for ax0_ax1_fused_0 in range(T.int64(32)):
                for ax0_ax1_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        v0 = T.axis.spatial(batch_size, ax0_fused)
                        v1 = T.axis.spatial(T.int64(2048), ax0_ax1_fused_0 * T.int64(64) + ax0_ax1_fused_1)
                        T.reads(T_multiply_red_shared[v0, T.int64(0)], input_embeds[v0, T.int64(0), v1], model_layers_0_input_layernorm_weight4[v1])
                        T.writes(T_cast[v0, T.int64(0), v1])
                        # 0.00048828125 == 1 / 2048 turns the sum of squares into a mean;
                        # 9.9999999999999995e-07 is the additive epsilon inside rsqrt.
                        T_cast[v0, T.int64(0), v1] = T.Cast("float16", T.rsqrt(T_multiply_red_shared[v0, T.int64(0)] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07)) * T.Cast("float32", input_embeds[v0, T.int64(0), v1]) * T.Cast("float32", model_layers_0_input_layernorm_weight4[v1]))

    @T.prim_func(private=True)
    def rms_norm1(var_input_embeds: T.handle, model_layers_0_input_layernorm_weight3: T.Buffer((T.int64(2048),), "float16"), var_T_cast: T.handle):
        # Per-row normalisation of input_embeds (1, seq_len, 2048) with a 2048-element
        # float16 weight. For each position v0 the kernel writes
        #   T_cast[0, v0, i] = float16(rsqrt(sum_j(x_j^2) * (1/2048) + ~1e-6) * x_i * w_i)
        # with the sum of squares and the product evaluated in float32.
        # One blockIdx.x per position, 64 threads (threadIdx.x) per block, three phases:
        #   1. each thread accumulates 32 squared elements into a local-scope partial;
        #   2. the 64 partials are summed into one shared-scope value;
        #   3. each thread writes 32 normalised outputs.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int64()
        input_embeds = T.match_buffer(var_input_embeds, (T.int64(1), seq_len, T.int64(2048)), "float16")
        T_cast = T.match_buffer(var_T_cast, (T.int64(1), seq_len, T.int64(2048)), "float16")
        # with T.block("root"):
        # Per-position sum of squares (shared scope) and its 64 per-thread partials (local scope).
        T_multiply_red_shared = T.alloc_buffer((T.int64(1), seq_len), scope="shared")
        T_multiply_red_rf_local = T.alloc_buffer((T.int64(64), T.int64(1), seq_len), scope="local")
        for ax0_fused in T.thread_binding(seq_len, thread="blockIdx.x"):
            for ax1_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                # Phase 1a: zero this thread's partial sum.
                with T.block("T_multiply_red_rf_init"):
                    vax1_fused_1, v0 = T.axis.remap("SS", [ax1_fused_1, ax0_fused])
                    T.reads()
                    T.writes(T_multiply_red_rf_local[vax1_fused_1, T.int64(0), v0])
                    T_multiply_red_rf_local[vax1_fused_1, T.int64(0), v0] = T.float32(0.0)
                # Phase 1b: thread t accumulates x^2 for columns t, t + 64, ..., t + 31 * 64.
                # `u` is a unit-extent loop variable that the block does not reference.
                for ax1_fused_0, u in T.grid(T.int64(32), 1):
                    with T.block("T_multiply_red_rf_update"):
                        vax1_fused_1, v0, vax1_fused_0 = T.axis.remap("SSR", [ax1_fused_1, ax0_fused, ax1_fused_0])
                        T.reads(T_multiply_red_rf_local[vax1_fused_1, T.int64(0), v0], input_embeds[T.int64(0), v0, vax1_fused_0 * T.int64(64) + vax1_fused_1])
                        T.writes(T_multiply_red_rf_local[vax1_fused_1, T.int64(0), v0])
                        T_multiply_red_rf_local[vax1_fused_1, T.int64(0), v0] = T_multiply_red_rf_local[vax1_fused_1, T.int64(0), v0] + T.Cast("float32", input_embeds[T.int64(0), v0, vax1_fused_0 * T.int64(64) + vax1_fused_1]) * T.Cast("float32", input_embeds[T.int64(0), v0, vax1_fused_0 * T.int64(64) + vax1_fused_1])
            # Phase 2: the reduce axis is bound to threadIdx.x, so the 64 per-thread
            # partials are combined into T_multiply_red_shared[0, v0].
            for ax1_fused in range(T.int64(1)):
                for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax0_fused])
                        T.reads(T_multiply_red_rf_local[vax1_fused_1, T.int64(0), v0])
                        T.writes(T_multiply_red_shared[T.int64(0), v0])
                        with T.init():
                            T_multiply_red_shared[T.int64(0), v0] = T.float32(0.0)
                        T_multiply_red_shared[T.int64(0), v0] = T_multiply_red_shared[T.int64(0), v0] + T_multiply_red_rf_local[vax1_fused_1, T.int64(0), v0]
            # Phase 3: 32 iterations x 64 threads cover the 2048 output columns.
            for ax0_ax1_fused_0 in range(T.int64(32)):
                for ax0_ax1_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        v0 = T.axis.spatial(seq_len, ax0_fused)
                        v1 = T.axis.spatial(T.int64(2048), ax0_ax1_fused_0 * T.int64(64) + ax0_ax1_fused_1)
                        T.reads(T_multiply_red_shared[T.int64(0), v0], input_embeds[T.int64(0), v0, v1], model_layers_0_input_layernorm_weight3[v1])
                        T.writes(T_cast[T.int64(0), v0, v1])
                        # 0.00048828125 == 1 / 2048 turns the sum of squares into a mean;
                        # 9.9999999999999995e-07 is the additive epsilon inside rsqrt.
                        T_cast[T.int64(0), v0, v1] = T.Cast("float16", T.rsqrt(T_multiply_red_shared[T.int64(0), v0] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07)) * T.Cast("float32", input_embeds[T.int64(0), v0, v1]) * T.Cast("float32", model_layers_0_input_layernorm_weight3[v1]))

    @T.prim_func(private=True)
    def rms_norm2(input_embed: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), model_layers_0_input_layernorm_weight2: T.Buffer((T.int64(2048),), "float16"), T_cast: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")):
        # Static-shape variant: normalises the single row input_embed (1, 1, 2048) with a
        # 2048-element float16 weight, writing
        #   T_cast[0, 0, i] = float16(rsqrt(sum_j(x_j^2) * (1/2048) + ~1e-6) * x_i * w_i)
        # with the sum of squares and the product evaluated in float32.
        # A single block (blockIdx.x extent 1) of 64 threads runs three phases:
        #   1. each thread accumulates 32 squared elements into a local-scope partial;
        #   2. the 64 partials are summed into one shared-scope value;
        #   3. each thread writes 32 normalised outputs.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Sum of squares (shared scope) and its 64 per-thread partials (local scope).
        T_multiply_red_shared = T.alloc_buffer((T.int64(1), T.int64(1)), scope="shared")
        T_multiply_red_rf_local = T.alloc_buffer((T.int64(64), T.int64(1), T.int64(1)), scope="local")
        for ax0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"):
            for ax1_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                # Phase 1a: zero this thread's partial sum.
                with T.block("T_multiply_red_rf_init"):
                    vax1_fused_1 = T.axis.spatial(T.int64(64), ax1_fused_1)
                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                    T.reads()
                    T.writes(T_multiply_red_rf_local[vax1_fused_1, T.int64(0), T.int64(0)])
                    T_multiply_red_rf_local[vax1_fused_1, T.int64(0), T.int64(0)] = T.float32(0.0)
                # Phase 1b: thread t accumulates x^2 for columns t, t + 64, ..., t + 31 * 64.
                # `u` is a unit-extent loop variable that the block does not reference.
                for ax1_fused_0, u in T.grid(T.int64(32), 1):
                    with T.block("T_multiply_red_rf_update"):
                        vax1_fused_1 = T.axis.spatial(T.int64(64), ax1_fused_1)
                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
                        vax1_fused_0 = T.axis.reduce(T.int64(32), ax1_fused_0)
                        T.reads(T_multiply_red_rf_local[vax1_fused_1, T.int64(0), T.int64(0)], input_embed[T.int64(0), T.int64(0), vax1_fused_0 * T.int64(64) + vax1_fused_1])
                        T.writes(T_multiply_red_rf_local[vax1_fused_1, T.int64(0), T.int64(0)])
                        T_multiply_red_rf_local[vax1_fused_1, T.int64(0), T.int64(0)] = T_multiply_red_rf_local[vax1_fused_1, T.int64(0), T.int64(0)] + T.Cast("float32", input_embed[T.int64(0), T.int64(0), vax1_fused_0 * T.int64(64) + vax1_fused_1]) * T.Cast("float32", input_embed[T.int64(0), T.int64(0), vax1_fused_0 * T.int64(64) + vax1_fused_1])
            # Phase 2: the reduce axis is bound to threadIdx.x, so the 64 per-thread
            # partials are combined into T_multiply_red_shared[0, 0].
            for ax1_fused in range(T.int64(1)):
                for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        vax1_fused_1 = T.axis.reduce(T.int64(64), ax0)
                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
                        T.reads(T_multiply_red_rf_local[vax1_fused_1, T.int64(0), T.int64(0)])
                        T.writes(T_multiply_red_shared[T.int64(0), T.int64(0)])
                        with T.init():
                            T_multiply_red_shared[T.int64(0), T.int64(0)] = T.float32(0.0)
                        T_multiply_red_shared[T.int64(0), T.int64(0)] = T_multiply_red_shared[T.int64(0), T.int64(0)] + T_multiply_red_rf_local[vax1_fused_1, T.int64(0), T.int64(0)]
            # Phase 3: 32 iterations x 64 threads cover the 2048 output columns.
            for ax0_fused_0 in range(T.int64(32)):
                for ax0_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        v0 = T.axis.spatial(T.int64(2048), ax0_fused_0 * T.int64(64) + ax0_fused_1)
                        T.reads(T_multiply_red_shared[T.int64(0), T.int64(0)], input_embed[T.int64(0), T.int64(0), v0], model_layers_0_input_layernorm_weight2[v0])
                        T.writes(T_cast[T.int64(0), T.int64(0), v0])
                        # 0.00048828125 == 1 / 2048 turns the sum of squares into a mean;
                        # 9.9999999999999995e-07 is the additive epsilon inside rsqrt.
                        T_cast[T.int64(0), T.int64(0), v0] = T.Cast("float16", T.rsqrt(T_multiply_red_shared[T.int64(0), T.int64(0)] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07)) * T.Cast("float32", input_embed[T.int64(0), T.int64(0), v0]) * T.Cast("float32", model_layers_0_input_layernorm_weight2[v0]))

    @T.prim_func
    def sampler_take_probs_tir(var_unsorted_probs: T.handle, var_sorted_indices: T.handle, var_sample_indices: T.handle, var_sampling_results: T.handle, var_top_prob_offsets: T.handle, var_sampled_values: T.handle, var_top_prob_probs: T.handle, var_top_prob_indices: T.handle):
        # Two gathers from unsorted_probs (batch_size, vocab_size) fused into one launch
        # of num_positions + num_samples logical threads:
        #   * threads [0, num_positions): decode the flat offset top_prob_offsets[v0] into
        #     (row, col) by vocab_size, then emit sorted_indices[row, col] and the
        #     probability stored at that index in the same row;
        #   * threads [num_positions, num_positions + num_samples): emit
        #     unsorted_probs[sample_indices[vj], sampling_results[vj]] for vj = v0 - num_positions.
        # NOTE(review): index values are used as-is; this kernel performs no range checks on
        # top_prob_offsets, sorted_indices, sample_indices or sampling_results.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        unsorted_probs = T.match_buffer(var_unsorted_probs, (batch_size, vocab_size))
        sorted_indices = T.match_buffer(var_sorted_indices, (batch_size, vocab_size), "int32")
        num_samples = T.int32(is_size_var=True)
        sample_indices = T.match_buffer(var_sample_indices, (num_samples,), "int32")
        sampling_results = T.match_buffer(var_sampling_results, (num_samples,), "int32")
        num_positions = T.int32(is_size_var=True)
        top_prob_offsets = T.match_buffer(var_top_prob_offsets, (num_positions,), "int32")
        sampled_values = T.match_buffer(var_sampled_values, (num_samples,))
        top_prob_probs = T.match_buffer(var_top_prob_probs, (num_positions,))
        top_prob_indices = T.match_buffer(var_top_prob_indices, (num_positions,), "int32")
        # with T.block("root"):
        # ceil((num_positions + num_samples) / 1024) blocks of 1024 threads; T.where masks the tail.
        for ax0_fused_0 in T.thread_binding((num_positions + num_samples + 1023) // 1024, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    v0 = T.axis.spatial(num_positions + num_samples, ax0_fused_0 * 1024 + ax0_fused_1)
                    T.where(ax0_fused_0 * 1024 + ax0_fused_1 < num_positions + num_samples)
                    # The read/write regions below list the accesses of both branches of the
                    # if/else together; unsorted_probs is declared as a min/max region over them.
                    T.reads(top_prob_offsets[v0], sorted_indices[top_prob_offsets[v0] // vocab_size, top_prob_offsets[v0] % vocab_size], unsorted_probs[T.min(top_prob_offsets[v0] // vocab_size, sample_indices[v0 + (0 - num_positions)]):T.min(top_prob_offsets[v0] // vocab_size, sample_indices[v0 + (0 - num_positions)]) + (T.max(top_prob_offsets[v0] // vocab_size, sample_indices[v0 - num_positions]) + 1 - T.min(top_prob_offsets[v0] // vocab_size, sample_indices[v0 - num_positions])), T.min(sorted_indices[top_prob_offsets[v0] // vocab_size, top_prob_offsets[v0] % vocab_size], sampling_results[v0 + (0 - num_positions)]):T.min(sorted_indices[top_prob_offsets[v0] // vocab_size, top_prob_offsets[v0] % vocab_size], sampling_results[v0 + (0 - num_positions)]) + (T.max(sorted_indices[top_prob_offsets[v0] // vocab_size, top_prob_offsets[v0] % vocab_size], sampling_results[v0 - num_positions]) + 1 - T.min(sorted_indices[top_prob_offsets[v0] // vocab_size, top_prob_offsets[v0] % vocab_size], sampling_results[v0 - num_positions]))], sample_indices[v0 + (0 - num_positions)], sampling_results[v0 + (0 - num_positions)])
                    T.writes(top_prob_indices[v0], top_prob_probs[v0], sampled_values[v0 + (0 - num_positions)])
                    if v0 < num_positions:
                        # Top-prob branch: flat offset -> (row, col) within (batch_size, vocab_size).
                        row: T.int32 = top_prob_offsets[v0] // vocab_size
                        col: T.int32 = top_prob_offsets[v0] % vocab_size
                        top_prob_indices[v0] = sorted_indices[row, col]
                        top_prob_probs[v0] = unsorted_probs[row, sorted_indices[row, col]]
                    else:
                        # Sample branch: vj indexes the per-sample buffers.
                        vj: T.int32 = v0 - num_positions
                        sampled_values[vj] = unsorted_probs[sample_indices[vj], sampling_results[vj]]

    @T.prim_func
    def scatter_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle):
        # Row scatter: for every row v0 of src (batch_size, n), copy it into row
        # indices[v0] of dst (m, n). Rows of dst not named by indices are not written here.
        # NOTE(review): assumes every indices[v0] lies in [0, m) and that entries are
        # distinct — nothing in this kernel checks either; confirm against the caller.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size, n = T.int32(is_size_var=True), T.int32(is_size_var=True)
        src = T.match_buffer(var_src, (batch_size, n))
        indices = T.match_buffer(var_indices, (batch_size,), "int32")
        m = T.int32(is_size_var=True)
        dst = T.match_buffer(var_dst, (m, n))
        # with T.block("root"):
        # ceil(batch_size * n / 1024) blocks of 1024 threads, one thread per src element;
        # T.where masks threads past batch_size * n.
        for ax0_ax1_fused_0 in T.thread_binding((batch_size * n + 1023) // 1024, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("scatter_2d"):
                    # Flat thread index -> (row v0, column v1) of src.
                    v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % (n * batch_size) // n)
                    v1 = T.axis.spatial(n, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % n)
                    T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < batch_size * n)
                    T.reads(src[v0, v1], indices[v0])
                    T.writes(dst[indices[v0], v1])
                    dst[indices[v0], v1] = src[v0, v1]

    @T.prim_func
    def softmax_with_chunked_sum(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle, var_softmax: T.handle):
        # Final stage of a chunked softmax over A (batch_size, vocab_size): combines the
        # per-chunk statistics chunked_max / chunked_sum (batch_size, num_chunks) into a
        # per-row max and sum, then writes softmax (batch_size, vocab_size).
        # One blockIdx.x per (row, chunk) pair; each block
        #   1. reduces chunked_max[v0, :] into temp_max_shared[v0];
        #   2. reduces the per-chunk sums, rescaled to the row max, into temp_sum_shared[v0];
        #   3. writes the 4096 outputs of its own chunk (columns v1 * 4096 + v2).
        # Rows whose temperature is <= 1e-5 take the alternate branch in steps 2 and 3,
        # which is built from equality tests against the row max instead of exponentials.
        # NOTE(review): presumably chunked_sum holds the log of each chunk's exp-sum when the
        # temperature is above the threshold (it is added to chunked_max inside T.exp) —
        # confirm against the kernel that produces it.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        A = T.match_buffer(var_A, (batch_size, vocab_size))
        temperature = T.match_buffer(var_temperature, (batch_size,))
        num_chunks = T.int64(is_size_var=True)
        chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks))
        chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks))
        softmax = T.match_buffer(var_softmax, (batch_size, vocab_size))
        # with T.block("root"):
        # Per-row max and sum, held in shared scope.
        temp_max_shared = T.alloc_buffer((batch_size,), scope="shared")
        temp_sum_shared = T.alloc_buffer((batch_size,), scope="shared")
        for l0_l1_fused in T.thread_binding(batch_size * num_chunks, thread="blockIdx.x"):
            # Step 1: row max over all chunks. 32 threads each visit ceil(num_chunks / 32)
            # chunks; T.where masks indices >= num_chunks.
            for ax0_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                for ax0_0 in T.serial((num_chunks + T.int64(31)) // T.int64(32), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                    with T.block("max"):
                        v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks)
                        v1 = T.axis.reduce(num_chunks, ax0_0 * T.int64(32) + ax0_1)
                        T.where(ax0_0 * T.int64(32) + ax0_1 < num_chunks)
                        T.reads(chunked_max[v0, v1])
                        T.writes(temp_max_shared[v0])
                        with T.init():
                            # Initial value is the most negative finite float32.
                            temp_max_shared[v0] = T.float32(-340282346638528859811704183484516925440.0)
                        temp_max_shared[v0] = T.max(temp_max_shared[v0], chunked_max[v0, v1])
            # Step 2: row sum over all chunks, same thread/loop layout as step 1.
            for ax0_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                for ax0_0 in T.serial((num_chunks + T.int64(31)) // T.int64(32), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                    with T.block("sum_exp"):
                        v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks)
                        v1 = T.axis.reduce(num_chunks, ax0_0 * T.int64(32) + ax0_1)
                        T.where(ax0_0 * T.int64(32) + ax0_1 < num_chunks)
                        T.reads(temperature[v0], chunked_sum[v0, v1], chunked_max[v0, v1], temp_max_shared[v0])
                        T.writes(temp_sum_shared[v0])
                        with T.init():
                            temp_sum_shared[v0] = T.float32(0.0)
                        # temperature > 1e-5: add exp(chunked_sum + chunked_max - row_max).
                        # Otherwise: add chunked_sum only for chunks whose max equals the row max.
                        temp_sum_shared[v0] = temp_sum_shared[v0] + T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(chunked_sum[v0, v1] + chunked_max[v0, v1] - temp_max_shared[v0]), T.Cast("float32", chunked_max[v0, v1] == temp_max_shared[v0]) * chunked_sum[v0, v1])
            # Step 3: 4 serial iterations x 32 (threadIdx.y) x 32 (threadIdx.x) = 4096
            # elements, i.e. the block's own chunk of the vocabulary axis.
            for l2_0 in T.serial(T.int64(4), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                for l2_1 in T.thread_binding(T.int64(32), thread="threadIdx.y"):
                    for l2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                        with T.block("log_pad"):
                            v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks)
                            v1 = T.axis.spatial(num_chunks, l0_l1_fused % num_chunks)
                            v2 = T.axis.spatial(T.int64(4096), l2_0 * T.int64(1024) + l2_1 * T.int64(32) + l2_2)
                            T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2], temp_sum_shared[v0], temp_max_shared[v0])
                            T.writes(softmax[v0, v1 * T.int64(4096) + v2])
                            # Columns at or beyond vocab_size (padding of the last chunk) are skipped.
                            if v1 * T.int64(4096) + v2 < vocab_size:
                                # temperature > 1e-5: exp(A / temperature - (log(row_sum) + row_max)).
                                # Otherwise: (A == row_max) / row_sum.
                                softmax[v0, v1 * T.int64(4096) + v2] = T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(A[v0, v1 * T.int64(4096) + v2] / temperature[v0] - (T.log(temp_sum_shared[v0]) + temp_max_shared[v0])), T.Cast("float32", A[v0, v1 * T.int64(4096) + v2] == temp_max_shared[v0]) / temp_sum_shared[v0])

    @T.prim_func(private=True)
    def take(var_rms_norm218: T.handle, var_logit_positions: T.handle, var_T_take: T.handle):
        # Row gather along axis 1: T_take[0, v0, :] = rms_norm218[0, logit_positions[v0], :].
        # Input is (1, seq_len, 2048) float16, logit_positions is (batch_size,) int32 and
        # the output is (1, batch_size, 2048) float16.
        # NOTE(review): assumes every logit_positions[v0] lies in [0, seq_len) — the kernel
        # does not check it.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int64()
        rms_norm218 = T.match_buffer(var_rms_norm218, (T.int64(1), seq_len, T.int64(2048)), "float16")
        batch_size = T.int64()
        logit_positions = T.match_buffer(var_logit_positions, (batch_size,), "int32")
        T_take = T.match_buffer(var_T_take, (T.int64(1), batch_size, T.int64(2048)), "float16")
        # with T.block("root"):
        # batch_size * 2 blocks of 1024 threads cover the batch_size * 2048 output
        # elements exactly, so no T.where predicate is used.
        for ax0_ax1_fused_0 in T.thread_binding(batch_size * T.int64(2), thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_take"):
                    # Flat thread index -> (output row v0, column v1).
                    v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(2048))
                    v1 = T.axis.spatial(T.int64(2048), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(2048))
                    T.reads(rms_norm218[T.int64(0), logit_positions[v0], v1], logit_positions[v0])
                    T.writes(T_take[T.int64(0), v0, v1])
                    T_take[T.int64(0), v0, v1] = rms_norm218[T.int64(0), logit_positions[v0], v1]

    @T.prim_func(private=True)
    def take_sorted_probs(var_probs: T.handle, var_lv1: T.handle, var_take_sorted_probs: T.handle):
        # Per-row gather: take_sorted_probs[v0, v1] = probs[v0, lv1[v0, v1]], i.e. each row
        # of probs is reordered by the int32 column indices stored in the same row of lv1.
        # The inputs are shaped (batch_size, vocab_size); the output shape uses the separate
        # symbolic variables (batch_size_1, vocab_size_1), and the iteration space below
        # follows the output shape.
        # NOTE(review): presumably the two shape pairs are equal at runtime — confirm with the caller.
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(), T.int64()
        probs = T.match_buffer(var_probs, (batch_size, vocab_size))
        lv1 = T.match_buffer(var_lv1, (batch_size, vocab_size), "int32")
        batch_size_1, vocab_size_1 = T.int64(is_size_var=True), T.int64(is_size_var=True)
        take_sorted_probs = T.match_buffer(var_take_sorted_probs, (batch_size_1, vocab_size_1))
        # with T.block("root"):
        # ceil(batch_size_1 * vocab_size_1 / 1024) blocks of 1024 threads, one thread per
        # output element; T.where masks the tail.
        for ax0_ax1_fused_0 in T.thread_binding((batch_size_1 * vocab_size_1 + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("take_sorted_probs"):
                    # Flat thread index -> (row v0, column v1) of the output.
                    v0 = T.axis.spatial(batch_size_1, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % (vocab_size_1 * batch_size_1) // vocab_size_1)
                    v1 = T.axis.spatial(vocab_size_1, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % vocab_size_1)
                    T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < batch_size_1 * vocab_size_1)
                    T.reads(probs[v0, lv1[v0, v1]], lv1[v0, v1])
                    T.writes(take_sorted_probs[v0, v1])
                    take_sorted_probs[v0, v1] = probs[v0, lv1[v0, v1]]

    @T.prim_func
    def tir_kv_cache_debug_get_kv(var_pages: T.handle, var_position_map: T.handle, var_k_data: T.handle, var_v_data: T.handle, layer_id: T.int64):
        # Copy K and V entries out of the paged cache into dense buffers.
        #
        # For every sequence slot vp, position_map[vp] is split into
        # (page = position // page_size, slot = position % page_size) and the values stored there
        # are written to k_data[layer_id, vp, :, :] and v_data[layer_id, vp, :, :].
        #
        # Buffers:
        #   pages        (num_pages, 2, 2, page_size, 128), float16 -- axis 1 selects K (0) or V (1), per the two stores below
        #   position_map (seqlen,), int32
        #   k_data       (36, seqlen, 2, 128), float16 -- output, first axis indexed by layer_id
        #   v_data       (36, seqlen, 2, 128), float16 -- output, first axis indexed by layer_id
        # NOTE(review): `pages` has no layer axis, so presumably the caller passes the page buffer
        # of the layer named by layer_id -- confirm against the call site.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        num_pages, page_size = T.int64(), T.int64(is_size_var=True)
        pages = T.match_buffer(var_pages, (num_pages, 2, 2, page_size, 128), "float16", offset_factor=1)
        seqlen = T.int64(is_size_var=True)
        position_map = T.match_buffer(var_position_map, (seqlen,), "int32", offset_factor=1)
        k_data = T.match_buffer(var_k_data, (36, seqlen, 2, 128), "float16")
        v_data = T.match_buffer(var_v_data, (36, seqlen, 2, 128), "float16")
        # with T.block("root"):
        # One thread per (vp, vh, vd) element: seqlen * 2 * 128 = seqlen * 256 elements,
        # split into blocks of 1024 threads.
        for p_h_d_fused_0 in T.thread_binding((seqlen * T.int64(256) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"):
            for p_h_d_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("copy0"):
                    # Decompose the flattened index into sequence slot (vp), axis of extent 2 (vh)
                    # and axis of extent 128 (vd).
                    vp = T.axis.spatial(seqlen, (p_h_d_fused_0 * T.int64(1024) + p_h_d_fused_1) // T.int64(256))
                    vh = T.axis.spatial(2, T.Cast("int32", (p_h_d_fused_0 * T.int64(1024) + p_h_d_fused_1) % T.int64(256) // T.int64(128)))
                    vd = T.axis.spatial(128, T.Cast("int32", (p_h_d_fused_0 * T.int64(1024) + p_h_d_fused_1) % T.int64(128)))
                    # Mask off tail threads past the last element.
                    T.where(p_h_d_fused_0 * T.int64(1024) + p_h_d_fused_1 < seqlen * T.int64(256))
                    T.reads(position_map[vp], pages[T.Cast("int64", position_map[vp]) // page_size, 0:2, vh, T.Cast("int64", position_map[vp]) % page_size, vd])
                    T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd])
                    position: T.int32 = position_map[vp]
                    # position is int32; it is widened to int64 to match page_size before div/mod.
                    k_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 0, vh, T.Cast("int64", position) % page_size, vd]
                    v_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 1, vh, T.Cast("int64", position) % page_size, vd]

    @T.prim_func
    def tir_kv_cache_transpose_append(var_pages: T.handle, var_k_data: T.handle, var_v_data: T.handle, var_position_map: T.handle):
        # Scatter K and V rows into the paged cache.
        #
        # For every token vgpos whose position_map entry is not -1, the value
        # k_data[vgpos, vh, vf] / v_data[vgpos, vh, vf] is stored at
        # pages[position // 16, 0 or 1, vh, position % 16, vf], where position = position_map[vgpos].
        #
        # Buffers:
        #   pages        (num_pages, 2, 2, 16, 128), float16 -- written in place; axis 1 is 0 for K, 1 for V
        #   k_data       (ntoken, 2, 128), float16
        #   v_data       (ntoken, 2, 128), float16
        #   position_map (ntoken,), int32 -- entries equal to -1 are skipped
        # The page size is the literal 16 in this kernel.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        num_pages = T.int64()
        pages = T.match_buffer(var_pages, (num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        ntoken = T.int64(is_size_var=True)
        k_data = T.match_buffer(var_k_data, (ntoken, 2, 128), "float16")
        v_data = T.match_buffer(var_v_data, (ntoken, 2, 128), "float16")
        position_map = T.match_buffer(var_position_map, (ntoken,), "int32", offset_factor=1)
        # with T.block("root"):
        # One thread per (token, vh, vf) element: ntoken * 2 * 128 = ntoken * 256 elements,
        # split into blocks of 1024 threads.
        for global_pos_h_f_fused_0 in T.thread_binding((ntoken * T.int64(256) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"):
            for global_pos_h_f_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                # Skip tokens whose position_map entry is -1.
                # NOTE(review): this read sits outside the T.where bounds predicates below, so when
                # ntoken * 256 is not a multiple of 1024 the tail threads of the last block index
                # position_map at >= ntoken -- looks like a possible out-of-bounds read; confirm.
                if position_map[(global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1) // T.int64(256)] != -1:
                    with T.block("k_transpose_append"):
                        # Decompose the flattened index into token (vgpos), axis of extent 2 (vh)
                        # and axis of extent 128 (vf).
                        vgpos = T.axis.spatial(ntoken, (global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1) // T.int64(256))
                        vh = T.axis.spatial(2, T.Cast("int32", (global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1) % T.int64(256) // T.int64(128)))
                        vf = T.axis.spatial(128, T.Cast("int32", (global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1) % T.int64(128)))
                        # Mask off tail threads past the last element.
                        T.where(global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1 < ntoken * T.int64(256))
                        T.reads(position_map[vgpos], k_data[vgpos, vh, vf])
                        T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf])
                        position: T.int32 = position_map[vgpos]
                        # K goes to index 0 of pages' second axis.
                        pages[position // 16, 0, vh, position % 16, vf] = k_data[vgpos, vh, vf]
                    with T.block("v_transpose_append"):
                        # Same index decomposition as the K block above.
                        vgpos = T.axis.spatial(ntoken, (global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1) // T.int64(256))
                        vh = T.axis.spatial(2, T.Cast("int32", (global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1) % T.int64(256) // T.int64(128)))
                        vf = T.axis.spatial(128, T.Cast("int32", (global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1) % T.int64(128)))
                        T.where(global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1 < ntoken * T.int64(256))
                        T.reads(position_map[vgpos], v_data[vgpos, vh, vf])
                        T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf])
                        position: T.int32 = position_map[vgpos]
                        # V goes to index 1 of pages' second axis.
                        pages[position // 16, 1, vh, position % 16, vf] = v_data[vgpos, vh, vf]

    @T.prim_func(private=True)
    def top_p_pivot_cutoff(var_prob: T.handle, var_top_p_arr: T.handle, var_init_pivots: T.handle, var_final_pivot: T.handle, var_final_lsum: T.handle):
        # Per-row search for a probability threshold ("pivot") for top-p filtering.
        #
        # For each row b of `prob`, the kernel narrows an interval [R, L] of candidate pivots,
        # testing three pivots per pass, until it finds a pivot for which
        #     lsum >= top_p  and  top_p > lsum - cmin * lmin
        # where lsum is the sum of all prob[b, :] entries >= pivot, lmin is the smallest such entry
        # and cmin is the number of entries equal to lmin. The search also ends once
        # L - R <= ~1e-7.
        #
        # Buffers (all default dtype unless noted):
        #   prob        (B, N)  -- input values
        #   top_p_arr   (B,)    -- per-row top-p target
        #   init_pivots (B, 3)  -- the three pivots tried on the first pass
        #   final_pivot (B,)    -- output: chosen pivot
        #   final_lsum  (B,)    -- output: sum associated with the chosen pivot
        # Presumably each row of `prob` is a non-negative distribution summing to 1 (the early-stop
        # test below uses 1 - running_sum) -- TODO confirm against the caller.
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        B, N = T.int32(is_size_var=True), T.int32(is_size_var=True)
        prob = T.match_buffer(var_prob, (B, N))
        top_p_arr = T.match_buffer(var_top_p_arr, (B,))
        init_pivots = T.match_buffer(var_init_pivots, (B, 3))
        final_pivot = T.match_buffer(var_final_pivot, (B,))
        final_lsum = T.match_buffer(var_final_lsum, (B,))
        # with T.block("root"):
        # Scratch buffers. "shared"-scope buffers (L, R_1, lmin_broadcast, es, find_pivot) are
        # written by thread 0 and read by every thread after a barrier; each has a "local"-scope
        # counterpart holding the per-thread copy.
        pivot = T.alloc_buffer((3,), scope="local")
        top_p = T.alloc_buffer((1,), scope="local")
        L = T.alloc_buffer((1,), scope="shared")
        R_1 = T.alloc_buffer((1,), scope="shared")
        L_local = T.alloc_buffer((1,), scope="local")
        R_local = T.alloc_buffer((1,), scope="local")
        q = T.alloc_buffer((1,), scope="local")
        lsum = T.alloc_buffer((3,), scope="local")
        lmin_broadcast = T.alloc_buffer((1,), scope="shared")
        lmin_broadcast_local = T.alloc_buffer((1,), scope="local")
        lmin = T.alloc_buffer((3,), scope="local")
        cmin = T.alloc_buffer((3,), "int32", scope="local")
        total_sum = T.alloc_buffer((1,), scope="local")
        it = T.alloc_buffer((1,), "int32", scope="local")
        es_local = T.alloc_buffer((1,), "bool", scope="local")
        es = T.alloc_buffer((1,), "bool", scope="shared")
        find_pivot_local = T.alloc_buffer((1,), "bool", scope="local")
        find_pivot = T.alloc_buffer((1,), "bool", scope="shared")
        total_sum_reduce = T.alloc_buffer((1,), scope="local")
        lsum_reduce = T.alloc_buffer((1,), scope="local")
        lmin_reduce = T.alloc_buffer((1,), scope="local")
        cmin_reduce = T.alloc_buffer((1,), "int32", scope="local")
        # One thread block per row b, 1024 threads per block.
        for _bx in T.thread_binding(B, thread="blockIdx.x"):
            for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("CTA"):
                    b, tx = T.axis.remap("SS", [_bx, _tx])
                    T.reads(top_p_arr[b], top_p[0], L[0], R_1[0], init_pivots[b, 0:3], L_local[0], R_local[0], find_pivot_local[0], it[0], es_local[0], prob[b, it[0] * 1024 + tx], total_sum[0], q[0], pivot[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], lsum[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], lmin[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], cmin[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], total_sum_reduce[0], es[0], lmin_reduce[0], lmin_broadcast[0], lmin_broadcast_local[0], lsum_reduce[0], cmin_reduce[0], find_pivot[0])
                    T.writes(top_p[0], L[0], R_1[0], find_pivot[0], L_local[0], R_local[0], pivot[0:3], find_pivot_local[0], final_lsum[b], final_pivot[b], lsum[0:3], lmin[0:3], cmin[0:3], total_sum[0], it[0], es_local[0], q[0], total_sum_reduce[0], es[0], lsum_reduce[0], lmin_reduce[0], lmin_broadcast[0], lmin_broadcast_local[0], cmin_reduce[0])
                    # --- Setup: thread 0 seeds the search interval, everyone copies it locally. ---
                    top_p[0] = top_p_arr[b]
                    if tx == 0:
                        # Initial interval: upper bound L = 1 - top_p, lower bound R = ~1e-7.
                        L[0] = T.float32(1.0) - top_p[0]
                        R_1[0] = T.float32(9.9999999999999995e-08)
                        find_pivot[0] = T.bool(False)
                    T.tvm_storage_sync("shared")
                    L_local[0] = L[0]
                    R_local[0] = R_1[0]
                    # First pass uses the caller-supplied pivots.
                    for i in T.unroll(3):
                        pivot[i] = init_pivots[b, i]
                    find_pivot_local[0] = T.bool(False)
                    # Degenerate interval (L - R <= ~1e-7): emit pivot 0 with sum 1 and skip the
                    # search loop by marking the pivot as found.
                    if L_local[0] - R_local[0] <= T.float32(9.9999999999999995e-08):
                        if tx == 0:
                            final_lsum[b] = T.float32(1.0)
                            final_pivot[b] = T.float32(0.0)
                        find_pivot_local[0] = T.bool(True)
                    # --- Main search loop: one pass over the row per iteration. ---
                    # The condition is wrapped in tvm_thread_invariant; it depends only on values
                    # that every thread loaded from shared memory.
                    while T.tvm_thread_invariant(L_local[0] - R_local[0] > T.float32(9.9999999999999995e-08) and not find_pivot_local[0]):
                        T.tvm_storage_sync("shared")
                        # Reset the per-thread accumulators for the three pivots.
                        for pidx in T.unroll(3):
                            lsum[pidx] = T.float32(0.0)
                            lmin[pidx] = T.float32(340282346638528859811704183484516925440.0)
                            cmin[pidx] = 0
                        total_sum[0] = T.float32(0.0)
                        it[0] = 0
                        es_local[0] = T.bool(False)
                        # Each thread visits elements tx, tx + 1024, tx + 2048, ... of row b.
                        while it[0] < (N + 1024 - 1) // 1024 and not es_local[0]:
                            # Out-of-range lanes contribute 0.
                            q[0] = T.if_then_else(it[0] * 1024 + tx < N, prob[b, it[0] * 1024 + tx], T.float32(0.0))
                            total_sum[0] = total_sum[0] + q[0]
                            for pidx in T.unroll(3):
                                if q[0] >= pivot[pidx]:
                                    # Accumulate the sum of entries >= pivot and track the smallest
                                    # such entry (lmin) together with how often it occurred (cmin).
                                    lsum[pidx] = lsum[pidx] + q[0]
                                    if lmin[pidx] > q[0]:
                                        lmin[pidx] = q[0]
                                        cmin[pidx] = 1
                                    else:
                                        if lmin[pidx] == q[0]:
                                            cmin[pidx] = cmin[pidx] + 1
                            it[0] = it[0] + 1
                            # Every 32 chunks, check for early stop: sum the running totals across
                            # threads and stop once 1 - total < pivot[2].
                            # Presumably this means no unseen entry can still reach pivot[2]
                            # (relies on the row summing to 1) -- TODO confirm.
                            if it[0] % 32 == 0:
                                with T.block("block_cross_thread"):
                                    T.reads(total_sum[0])
                                    T.writes(total_sum_reduce[0])
                                    T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                    T.tvm_thread_allreduce(T.uint32(1), total_sum[0], T.bool(True), total_sum_reduce[0], tx)
                                if tx == 0:
                                    es[0] = T.float32(1.0) - total_sum_reduce[0] < pivot[2]
                                T.tvm_storage_sync("shared")
                                es_local[0] = es[0]
                        T.tvm_storage_sync("shared")
                        # --- Combine per-thread statistics across the block, one pivot at a time. ---
                        for pidx in range(3):
                            # Sum of lsum over all threads.
                            with T.block("block_cross_thread"):
                                T.reads(lsum[pidx])
                                T.writes(lsum_reduce[0])
                                T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), lsum[pidx], T.bool(True), lsum_reduce[0], tx)
                            # Minimum of lmin over all threads.
                            # NOTE(review): the declared identity for this T.min reducer is 0.0;
                            # presumably unused because the predicate is constant True -- confirm.
                            with T.block("block_cross_thread"):
                                T.reads(lmin[pidx])
                                T.writes(lmin_reduce[0])
                                T.attr(T.comm_reducer(lambda x0, y0: T.min(x0, y0), [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), lmin[pidx], T.bool(True), lmin_reduce[0], tx)
                            # Thread 0 publishes the block-wide minimum through shared memory.
                            if tx == 0:
                                lmin_broadcast[0] = lmin_reduce[0]
                            T.tvm_storage_sync("shared")
                            lmin_broadcast_local[0] = lmin_broadcast[0]
                            # A thread whose own minimum is larger than the block-wide minimum saw
                            # no copies of it, so its count must not be added.
                            if lmin[pidx] > lmin_broadcast_local[0]:
                                cmin[pidx] = 0
                            # Only thread 0 keeps the reduced values; it alone evaluates the pivots.
                            if tx == 0:
                                lsum[pidx] = lsum_reduce[0]
                                lmin[pidx] = lmin_reduce[0]
                            # Sum of cmin over all threads: occurrences of the block-wide minimum.
                            with T.block("block_cross_thread"):
                                T.reads(cmin[pidx])
                                T.writes(cmin_reduce[0])
                                T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [0]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), cmin[pidx], T.bool(True), cmin_reduce[0], tx)
                            if tx == 0:
                                cmin[pidx] = cmin_reduce[0]
                        T.tvm_storage_sync("shared")
                        # --- Thread 0 evaluates the three pivots in order and updates the interval. ---
                        if tx == 0:
                            it[0] = 0
                            while it[0] < 3 and not find_pivot_local[0]:
                                # Accept: the kept sum reaches top_p, but would fall below top_p if
                                # all copies of the smallest kept value were dropped.
                                if lsum[it[0]] >= top_p[0] and top_p[0] > lsum[it[0]] - T.Cast("float32", cmin[it[0]]) * lmin[it[0]]:
                                    find_pivot[0] = T.bool(True)
                                    find_pivot_local[0] = T.bool(True)
                                    final_pivot[b] = pivot[it[0]]
                                    final_lsum[b] = lsum[it[0]]
                                else:
                                    # Still >= top_p without the smallest kept value: raise the
                                    # lower bound R to this pivot and remember its sum.
                                    if lsum[it[0]] - lmin[it[0]] * T.Cast("float32", cmin[it[0]]) >= top_p[0]:
                                        R_1[0] = pivot[it[0]]
                                        final_lsum[b] = lsum[it[0]]
                                    else:
                                        # Kept sum is below top_p: lower the upper bound L to this pivot.
                                        if lsum[it[0]] < top_p[0]:
                                            L[0] = pivot[it[0]]
                                it[0] = it[0] + 1
                        T.tvm_storage_sync("shared")
                        # Every thread picks up the updated interval and found-flag.
                        L_local[0] = L[0]
                        R_local[0] = R_1[0]
                        find_pivot_local[0] = find_pivot[0]
                        # Next pivots: L - (L - R) / 4, L - 2 (L - R) / 4, L - 3 (L - R) / 4.
                        for pidx in T.unroll(3):
                            pivot[pidx] = L[0] - T.Cast("float32", pidx + 1) * (L_local[0] - R_local[0]) / T.float32(4.0)
                    # --- Fallback when the loop ended without an accepted pivot. ---
                    if tx == 0:
                        if not find_pivot_local[0]:
                            final_pivot[b] = R_local[0]
                            # R still at its initial value means no pass ever wrote final_lsum via
                            # the R-update branch; use the sum recorded for pivot index 2.
                            if R_local[0] == T.float32(9.9999999999999995e-08):
                                final_lsum[b] = lsum[2]

    @T.prim_func(private=True)
    def top_p_renorm_after_cutoff(var_prob: T.handle, var_final_pivot: T.handle, var_final_lsum: T.handle, var_renorm_prob: T.handle):
        # Apply a per-row threshold and renormalise:
        #     renorm_prob[b, j] = prob[b, j] / final_lsum[b]   if prob[b, j] >= final_pivot[b]
        #                         0.0                          otherwise
        #
        # Buffers (all default dtype):
        #   prob        (B, N) -- input values
        #   final_pivot (B,)   -- per-row threshold
        #   final_lsum  (B,)   -- per-row divisor
        #   renorm_prob (B, N) -- output
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        B, N = T.int32(is_size_var=True), T.int32(is_size_var=True)
        prob = T.match_buffer(var_prob, (B, N))
        final_pivot = T.match_buffer(var_final_pivot, (B,))
        final_lsum = T.match_buffer(var_final_lsum, (B,))
        renorm_prob = T.match_buffer(var_renorm_prob, (B, N))
        # with T.block("root"):
        # Per-thread copies of the row's threshold and divisor.
        pivot = T.alloc_buffer((1,), scope="local")
        lsum = T.alloc_buffer((1,), scope="local")
        # Launch shape: blockIdx.y selects the row; 511 // B + 1 blocks of 1024 threads share
        # each row, i.e. ceil(512 / B) blocks per row for positive B.
        for _by in T.thread_binding(B, thread="blockIdx.y"):
            for _bx in T.thread_binding(511 // B + 1, thread="blockIdx.x"):
                for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("CTA"):
                        by, bx, tx = T.axis.remap("SSS", [_by, _bx, _tx])
                        T.reads(final_pivot[by], final_lsum[by], prob[by, bx * 1024 + tx:bx * 1024 + tx + (((511 // B * 1024 + N + 1023) // (511 // B * 1024 + 1024) - 1) * (511 // B + 1) * 1024 + 1)], pivot[0], lsum[0])
                        T.writes(pivot[0], lsum[0], renorm_prob[by, bx * 1024 + tx:bx * 1024 + tx + (((511 // B * 1024 + N + 1023) // (511 // B * 1024 + 1024) - 1) * (511 // B + 1) * 1024 + 1)])
                        pivot[0] = final_pivot[by]
                        lsum[0] = final_lsum[by]
                        # Each thread handles columns bx * 1024 + tx + i * stride, where
                        # stride = (511 // B + 1) * 1024 is the number of threads per row; the trip
                        # count below is ceil(N / stride).
                        for i in range((511 // B * 1024 + N + 1023) // (511 // B * 1024 + 1024)):
                            # (512 + B - 1) // B equals 511 // B + 1 for positive B, so this is the
                            # same stride as above. Columns past N are skipped.
                            if i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx < N:
                                renorm_prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] = T.if_then_else(prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] >= pivot[0], prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] / lsum[0], T.float32(0.0))

    @T.prim_func
    def tree_attn_paged_kv(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32, tree_order_indptr_handle: T.handle, tree_order_handle: T.handle):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 2, 16, 128), "float16")
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (batch_size,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (total_len, 16))
        tree_order_indptr = T.match_buffer(tree_order_indptr_handle, (batch_size + 1,), "int32", offset_factor=1)
        total_tree_order_len = T.int32(is_size_var=True)
        tree_order = T.match_buffer(tree_order_handle, (total_tree_order_len, 2), "int32", offset_factor=1)
        # with T.block("root"):
        assert rotary_mode == 0, "Inline rotary mode is not supported in tree attention."
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[b_idx], 0)
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_init"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads()
                                                    T.writes(O_local[i, j])
                                                    O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_lj_fused_0 in range(8):
                                        for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_lj_fused_3 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 128)
                                                        j = T.axis.spatial(128, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 128)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        for lz_ly_fused_0 in range(8):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 128)
                                                            j = T.axis.spatial(128, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 128)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                K_smem[i, j] = pages[page_no, 0, by, page_offset, j]
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        for lz_ly_fused_0 in range(8):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 128)
                                                            j = T.axis.spatial(128, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 128)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(2, 4):
                                                        with T.block("S_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_init)
                                                            T.reads()
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, li_1, lj_1, lk_1 in T.grid(16, 2, 4, 8):
                                                        with T.block("S_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                            k = T.axis.reduce(128, lk_0 * 8 + lk_1)
                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1, lj_1 in T.grid(2, 4):
                                                    with T.block("S_store"):
                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                        j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                        T.reads(S_local[i, j])
                                                        T.writes(S_smem[i, j])
                                                        S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], tree_order_indptr[b_idx:b_idx + 2], tree_order[T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]):T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) or tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 0] and tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 1]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], tree_order_indptr[b_idx:b_idx + 2], tree_order[T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]):T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) or tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 0] and tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 1]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(4, 8):
                                                        with T.block("O_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_init)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1, li_1, lj_1 in T.grid(4, 8, 4, 8):
                                                        with T.block("O_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                            k = T.axis.reduce(32, lk_0 * 8 + lk_1)
                                                            T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j])
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_store"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                    T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    @R.function
    def alloc_embedding_tensor() -> R.Tensor((2048, 2048), dtype="float16"):
        # Relax function that takes no arguments and returns a freshly
        # allocated float16 tensor of static shape (2048, 2048).
        # Nothing in this function writes to the tensor; it is only allocated
        # and handed back to the caller.
        #
        # The attribute below is a hint consumed by the Relax compilation
        # pipeline (presumably the static memory planner -- confirm against
        # the TVM pass that reads "relax.memory_plan_dynamic_func_output").
        R.func_attr({"relax.memory_plan_dynamic_func_output": True})
        # NOTE(review): the trailing arguments R.prim_value(0) and
        # R.str("global") look like the runtime device index and the storage
        # scope of R.builtin.alloc_tensor -- confirm against the TVM docs.
        gv: R.Tensor((2048, 2048), dtype="float16") = R.builtin.alloc_tensor(R.shape([2048, 2048]), R.dtype("float16"), R.prim_value(0), R.str("global"))
        return gv

    @R.function
    def argsort_probs(probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32")) -> R.Tuple(R.Tensor(("batch_size", "vocab_size"), dtype="float32"), R.Tensor(("batch_size", "vocab_size"), dtype="int32")):
        # Relax function that sorts a float32 tensor `probs` of symbolic shape
        # (batch_size, vocab_size) by delegating to two TIR kernels of this
        # module, and returns a 2-tuple:
        #   [0] float32 (batch_size, vocab_size): output of `take_sorted_probs`
        #   [1] int32   (batch_size, vocab_size): output of `argsort`
        # The sort axis and direction are decided inside `cls.argsort`, which
        # is not visible here (presumably descending along vocab_size --
        # TODO confirm against the `argsort` prim_func).
        batch_size = T.int64(is_size_var=True)
        vocab_size = T.int64(is_size_var=True)
        # Compiler hints: `vocab_size` is declared non-negative and
        # `batch_size` is bounded by 80. `num_positions` and `num_samples`
        # are also listed although neither symbol appears in this function's
        # signature or body.
        R.func_attr({"relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "num_positions": 480, "num_samples": 80}})
        cls = Module
        with R.dataflow():
            # Step 1: compute an int32 index tensor from `probs`.
            lv1 = R.call_tir(cls.argsort, (probs,), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="int32"))
            # Step 2: feed both `probs` and the indices from step 1 to
            # `take_sorted_probs`, yielding a float32 tensor of the same shape.
            lv2 = R.call_tir(cls.take_sorted_probs, (probs, lv1), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="float32"))
            # Tuple order is (values, indices): lv2 first, then lv1.
            gv1: R.Tuple(R.Tensor((batch_size, vocab_size), dtype="float32"), R.Tensor((batch_size, vocab_size), dtype="int32")) = lv2, lv1
            R.output(gv1)
        return gv1

    @R.function
    def batch_decode(input_embeds: R.Tensor(("batch_size", 1, 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), 
R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 
64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), 
R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), 
R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), 
dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"))) -> R.Tuple(R.Tensor(("batch_size", 1, 151936), dtype="float32"), R.Object):
        batch_size = T.int64()
        R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "relax.rewrite_cuda_graph.capture_symbolic_vars": ["batch_size"], "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 32768}})
        cls = Module
        with R.dataflow():
            model_embed_tokens_q_weight4: R.Tensor((151936, 256), dtype="uint32") = packed_params[0]
            model_embed_tokens_q_scale4: R.Tensor((151936, 64), dtype="float16") = packed_params[1]
            model_layers_0_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[2]
            model_layers_0_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[3]
            model_layers_0_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[4]
            model_layers_0_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[5]
            model_layers_0_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[6]
            model_layers_0_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[7]
            model_layers_0_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[8]
            model_layers_0_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[9]
            model_layers_0_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[10]
            model_layers_0_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[11]
            model_layers_0_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[12]
            model_layers_1_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[13]
            model_layers_1_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[14]
            model_layers_1_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[15]
            model_layers_1_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[16]
            model_layers_1_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[17]
            model_layers_1_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[18]
            model_layers_1_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[19]
            model_layers_1_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[20]
            model_layers_1_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[21]
            model_layers_1_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[22]
            model_layers_1_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[23]
            model_layers_2_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[24]
            model_layers_2_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[25]
            model_layers_2_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[26]
            model_layers_2_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[27]
            model_layers_2_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[28]
            model_layers_2_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[29]
            model_layers_2_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[30]
            model_layers_2_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[31]
            model_layers_2_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[32]
            model_layers_2_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[33]
            model_layers_2_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[34]
            model_layers_3_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[35]
            model_layers_3_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[36]
            model_layers_3_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[37]
            model_layers_3_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[38]
            model_layers_3_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[39]
            model_layers_3_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[40]
            model_layers_3_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[41]
            model_layers_3_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[42]
            model_layers_3_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[43]
            model_layers_3_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[44]
            model_layers_3_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[45]
            model_layers_4_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[46]
            model_layers_4_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[47]
            model_layers_4_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[48]
            model_layers_4_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[49]
            model_layers_4_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[50]
            model_layers_4_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[51]
            model_layers_4_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[52]
            model_layers_4_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[53]
            model_layers_4_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[54]
            model_layers_4_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[55]
            model_layers_4_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[56]
            model_layers_5_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[57]
            model_layers_5_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[58]
            model_layers_5_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[59]
            model_layers_5_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[60]
            model_layers_5_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[61]
            model_layers_5_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[62]
            model_layers_5_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[63]
            model_layers_5_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[64]
            model_layers_5_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[65]
            model_layers_5_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[66]
            model_layers_5_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[67]
            model_layers_6_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[68]
            model_layers_6_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[69]
            model_layers_6_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[70]
            model_layers_6_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[71]
            model_layers_6_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[72]
            model_layers_6_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[73]
            model_layers_6_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[74]
            model_layers_6_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[75]
            model_layers_6_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[76]
            model_layers_6_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[77]
            model_layers_6_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[78]
            model_layers_7_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[79]
            model_layers_7_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[80]
            model_layers_7_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[81]
            model_layers_7_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[82]
            model_layers_7_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[83]
            model_layers_7_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[84]
            model_layers_7_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[85]
            model_layers_7_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[86]
            model_layers_7_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[87]
            model_layers_7_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[88]
            model_layers_7_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[89]
            model_layers_8_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[90]
            model_layers_8_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[91]
            model_layers_8_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[92]
            model_layers_8_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[93]
            model_layers_8_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[94]
            model_layers_8_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[95]
            model_layers_8_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[96]
            model_layers_8_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[97]
            model_layers_8_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[98]
            model_layers_8_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[99]
            model_layers_8_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[100]
            model_layers_9_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[101]
            model_layers_9_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[102]
            model_layers_9_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[103]
            model_layers_9_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[104]
            model_layers_9_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[105]
            model_layers_9_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[106]
            model_layers_9_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[107]
            model_layers_9_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[108]
            model_layers_9_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[109]
            model_layers_9_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[110]
            model_layers_9_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[111]
            model_layers_10_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[112]
            model_layers_10_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[113]
            model_layers_10_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[114]
            model_layers_10_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[115]
            model_layers_10_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[116]
            model_layers_10_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[117]
            model_layers_10_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[118]
            model_layers_10_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[119]
            model_layers_10_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[120]
            model_layers_10_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[121]
            model_layers_10_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[122]
            model_layers_11_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[123]
            model_layers_11_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[124]
            model_layers_11_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[125]
            model_layers_11_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[126]
            model_layers_11_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[127]
            model_layers_11_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[128]
            model_layers_11_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[129]
            model_layers_11_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[130]
            model_layers_11_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[131]
            model_layers_11_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[132]
            model_layers_11_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[133]
            model_layers_12_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[134]
            model_layers_12_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[135]
            model_layers_12_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[136]
            model_layers_12_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[137]
            model_layers_12_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[138]
            model_layers_12_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[139]
            model_layers_12_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[140]
            model_layers_12_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[141]
            model_layers_12_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[142]
            model_layers_12_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[143]
            model_layers_12_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[144]
            model_layers_13_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[145]
            model_layers_13_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[146]
            model_layers_13_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[147]
            model_layers_13_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[148]
            model_layers_13_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[149]
            model_layers_13_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[150]
            model_layers_13_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[151]
            model_layers_13_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[152]
            model_layers_13_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[153]
            model_layers_13_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[154]
            model_layers_13_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[155]
            model_layers_14_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[156]
            model_layers_14_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[157]
            model_layers_14_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[158]
            model_layers_14_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[159]
            model_layers_14_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[160]
            model_layers_14_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[161]
            model_layers_14_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[162]
            model_layers_14_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[163]
            model_layers_14_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[164]
            model_layers_14_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[165]
            model_layers_14_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[166]
            model_layers_15_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[167]
            model_layers_15_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[168]
            model_layers_15_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[169]
            model_layers_15_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[170]
            model_layers_15_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[171]
            model_layers_15_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[172]
            model_layers_15_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[173]
            model_layers_15_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[174]
            model_layers_15_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[175]
            model_layers_15_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[176]
            model_layers_15_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[177]
            model_layers_16_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[178]
            model_layers_16_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[179]
            model_layers_16_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[180]
            model_layers_16_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[181]
            model_layers_16_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[182]
            model_layers_16_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[183]
            model_layers_16_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[184]
            model_layers_16_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[185]
            model_layers_16_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[186]
            model_layers_16_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[187]
            model_layers_16_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[188]
            model_layers_17_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[189]
            model_layers_17_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[190]
            model_layers_17_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[191]
            model_layers_17_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[192]
            model_layers_17_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[193]
            model_layers_17_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[194]
            model_layers_17_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[195]
            model_layers_17_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[196]
            model_layers_17_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[197]
            model_layers_17_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[198]
            model_layers_17_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[199]
            model_layers_18_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[200]
            model_layers_18_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[201]
            model_layers_18_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[202]
            model_layers_18_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[203]
            model_layers_18_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[204]
            model_layers_18_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[205]
            model_layers_18_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[206]
            model_layers_18_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[207]
            model_layers_18_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[208]
            model_layers_18_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[209]
            model_layers_18_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[210]
            model_layers_19_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[211]
            model_layers_19_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[212]
            model_layers_19_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[213]
            model_layers_19_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[214]
            model_layers_19_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[215]
            model_layers_19_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[216]
            model_layers_19_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[217]
            model_layers_19_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[218]
            model_layers_19_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[219]
            model_layers_19_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[220]
            model_layers_19_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[221]
            model_layers_20_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[222]
            model_layers_20_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[223]
            model_layers_20_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[224]
            model_layers_20_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[225]
            model_layers_20_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[226]
            model_layers_20_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[227]
            model_layers_20_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[228]
            model_layers_20_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[229]
            model_layers_20_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[230]
            model_layers_20_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[231]
            model_layers_20_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[232]
            model_layers_21_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[233]
            model_layers_21_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[234]
            model_layers_21_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[235]
            model_layers_21_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[236]
            model_layers_21_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[237]
            model_layers_21_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[238]
            model_layers_21_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[239]
            model_layers_21_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[240]
            model_layers_21_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[241]
            model_layers_21_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[242]
            model_layers_21_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[243]
            model_layers_22_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[244]
            model_layers_22_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[245]
            model_layers_22_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[246]
            model_layers_22_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[247]
            model_layers_22_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[248]
            model_layers_22_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[249]
            model_layers_22_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[250]
            model_layers_22_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[251]
            model_layers_22_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[252]
            model_layers_22_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[253]
            model_layers_22_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[254]
            model_layers_23_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[255]
            model_layers_23_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[256]
            model_layers_23_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[257]
            model_layers_23_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[258]
            model_layers_23_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[259]
            model_layers_23_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[260]
            model_layers_23_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[261]
            model_layers_23_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[262]
            model_layers_23_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[263]
            model_layers_23_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[264]
            model_layers_23_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[265]
            model_layers_24_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[266]
            model_layers_24_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[267]
            model_layers_24_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[268]
            model_layers_24_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[269]
            model_layers_24_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[270]
            model_layers_24_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[271]
            model_layers_24_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[272]
            model_layers_24_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[273]
            model_layers_24_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[274]
            model_layers_24_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[275]
            model_layers_24_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[276]
            model_layers_25_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[277]
            model_layers_25_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[278]
            model_layers_25_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[279]
            model_layers_25_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[280]
            model_layers_25_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[281]
            model_layers_25_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[282]
            model_layers_25_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[283]
            model_layers_25_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[284]
            model_layers_25_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[285]
            model_layers_25_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[286]
            model_layers_25_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[287]
            model_layers_26_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[288]
            model_layers_26_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[289]
            model_layers_26_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[290]
            model_layers_26_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[291]
            model_layers_26_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[292]
            model_layers_26_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[293]
            model_layers_26_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[294]
            model_layers_26_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[295]
            model_layers_26_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[296]
            model_layers_26_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[297]
            model_layers_26_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[298]
            model_layers_27_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[299]
            model_layers_27_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[300]
            model_layers_27_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[301]
            model_layers_27_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[302]
            model_layers_27_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[303]
            model_layers_27_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[304]
            model_layers_27_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[305]
            model_layers_27_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[306]
            model_layers_27_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[307]
            model_layers_27_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[308]
            model_layers_27_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[309]
            model_layers_28_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[310]
            model_layers_28_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[311]
            model_layers_28_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[312]
            model_layers_28_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[313]
            model_layers_28_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[314]
            model_layers_28_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[315]
            model_layers_28_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[316]
            model_layers_28_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[317]
            model_layers_28_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[318]
            model_layers_28_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[319]
            model_layers_28_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[320]
            model_layers_29_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[321]
            model_layers_29_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[322]
            model_layers_29_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[323]
            model_layers_29_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[324]
            model_layers_29_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[325]
            model_layers_29_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[326]
            model_layers_29_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[327]
            model_layers_29_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[328]
            model_layers_29_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[329]
            model_layers_29_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[330]
            model_layers_29_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[331]
            model_layers_30_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[332]
            model_layers_30_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[333]
            model_layers_30_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[334]
            model_layers_30_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[335]
            model_layers_30_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[336]
            model_layers_30_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[337]
            model_layers_30_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[338]
            model_layers_30_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[339]
            model_layers_30_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[340]
            model_layers_30_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[341]
            model_layers_30_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[342]
            model_layers_31_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[343]
            model_layers_31_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[344]
            model_layers_31_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[345]
            model_layers_31_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[346]
            model_layers_31_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[347]
            model_layers_31_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[348]
            model_layers_31_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[349]
            model_layers_31_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[350]
            model_layers_31_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[351]
            model_layers_31_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[352]
            model_layers_31_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[353]
            model_layers_32_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[354]
            model_layers_32_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[355]
            model_layers_32_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[356]
            model_layers_32_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[357]
            model_layers_32_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[358]
            model_layers_32_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[359]
            model_layers_32_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[360]
            model_layers_32_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[361]
            model_layers_32_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[362]
            model_layers_32_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[363]
            model_layers_32_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[364]
            model_layers_33_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[365]
            model_layers_33_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[366]
            model_layers_33_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[367]
            model_layers_33_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[368]
            model_layers_33_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[369]
            model_layers_33_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[370]
            model_layers_33_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[371]
            model_layers_33_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[372]
            model_layers_33_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[373]
            model_layers_33_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[374]
            model_layers_33_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[375]
            model_layers_34_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[376]
            model_layers_34_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[377]
            model_layers_34_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[378]
            model_layers_34_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[379]
            model_layers_34_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[380]
            model_layers_34_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[381]
            model_layers_34_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[382]
            model_layers_34_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[383]
            model_layers_34_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[384]
            model_layers_34_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[385]
            model_layers_34_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[386]
            model_layers_35_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[387]
            model_layers_35_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[388]
            model_layers_35_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[389]
            model_layers_35_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[390]
            model_layers_35_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[391]
            model_layers_35_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[392]
            model_layers_35_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[393]
            model_layers_35_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[394]
            model_layers_35_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[395]
            model_layers_35_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[396]
            model_layers_35_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[397]
            model_norm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[398]
            rms_norm219 = R.call_tir(cls.rms_norm, (input_embeds, model_layers_0_input_layernorm_weight4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_0_self_attn_c_attn_q_weight4, model_layers_0_self_attn_c_attn_q_scale4, rms_norm219, model_layers_0_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape432 = R.call_tir(cls.reshape, (lv,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape433 = R.call_tir(cls.reshape1, (reshape432,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv546 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape433), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape434 = R.call_tir(cls.reshape2, (lv546,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape435 = R.call_tir(cls.reshape3, (reshape434,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_0_self_attn_o_proj_q_weight4, model_layers_0_self_attn_o_proj_q_scale4, reshape435), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv_2 = R.call_tir(cls.fuse_add_norm_decode, (lv_1, input_embeds, model_layers_0_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv_2[1]
            rms_norm220: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv_2[0]
            lv1_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_0_mlp_gate_up_proj_q_weight4, model_layers_0_mlp_gate_up_proj_q_scale4, rms_norm220), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv1_2 = R.call_tir(cls.fused_split_silu_multiply, (lv1_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv2 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_0_mlp_down_proj_q_weight4, model_layers_0_mlp_down_proj_q_scale4, lv1_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv2_1 = R.call_tir(cls.fuse_add_norm_decode, (lv2, lv1, model_layers_1_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv3: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv2_1[1]
            rms_norm221: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv2_1[0]
            lv1_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_1_self_attn_c_attn_q_weight4, model_layers_1_self_attn_c_attn_q_scale4, rms_norm221, model_layers_1_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape436 = R.call_tir(cls.reshape, (lv1_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape437 = R.call_tir(cls.reshape1, (reshape436,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv551 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape437), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape438 = R.call_tir(cls.reshape2, (lv551,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape439 = R.call_tir(cls.reshape3, (reshape438,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv3_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_1_self_attn_o_proj_q_weight4, model_layers_1_self_attn_o_proj_q_scale4, reshape439), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv4 = R.call_tir(cls.fuse_add_norm_decode, (lv3_1, lv3, model_layers_1_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv5: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv4[1]
            rms_norm222: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv4[0]
            lv4_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_1_mlp_gate_up_proj_q_weight4, model_layers_1_mlp_gate_up_proj_q_scale4, rms_norm222), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv3_2 = R.call_tir(cls.fused_split_silu_multiply, (lv4_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv5_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_1_mlp_down_proj_q_weight4, model_layers_1_mlp_down_proj_q_scale4, lv3_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv6 = R.call_tir(cls.fuse_add_norm_decode, (lv5_1, lv5, model_layers_2_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv7: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv6[1]
            rms_norm223: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv6[0]
            lv2_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_2_self_attn_c_attn_q_weight4, model_layers_2_self_attn_c_attn_q_scale4, rms_norm223, model_layers_2_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape440 = R.call_tir(cls.reshape, (lv2_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape441 = R.call_tir(cls.reshape1, (reshape440,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv556 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape441), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape442 = R.call_tir(cls.reshape2, (lv556,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape443 = R.call_tir(cls.reshape3, (reshape442,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv6_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_2_self_attn_o_proj_q_weight4, model_layers_2_self_attn_o_proj_q_scale4, reshape443), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv8 = R.call_tir(cls.fuse_add_norm_decode, (lv6_1, lv7, model_layers_2_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv9: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv8[1]
            rms_norm224: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv8[0]
            lv7_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_2_mlp_gate_up_proj_q_weight4, model_layers_2_mlp_gate_up_proj_q_scale4, rms_norm224), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv5_2 = R.call_tir(cls.fused_split_silu_multiply, (lv7_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv8_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_2_mlp_down_proj_q_weight4, model_layers_2_mlp_down_proj_q_scale4, lv5_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv10 = R.call_tir(cls.fuse_add_norm_decode, (lv8_1, lv9, model_layers_3_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv11: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv10[1]
            rms_norm225: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv10[0]
            lv3_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_3_self_attn_c_attn_q_weight4, model_layers_3_self_attn_c_attn_q_scale4, rms_norm225, model_layers_3_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape444 = R.call_tir(cls.reshape, (lv3_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape445 = R.call_tir(cls.reshape1, (reshape444,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv561 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape445), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape446 = R.call_tir(cls.reshape2, (lv561,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape447 = R.call_tir(cls.reshape3, (reshape446,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv9_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_3_self_attn_o_proj_q_weight4, model_layers_3_self_attn_o_proj_q_scale4, reshape447), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv12 = R.call_tir(cls.fuse_add_norm_decode, (lv9_1, lv11, model_layers_3_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv13: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv12[1]
            rms_norm226: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv12[0]
            lv10_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_3_mlp_gate_up_proj_q_weight4, model_layers_3_mlp_gate_up_proj_q_scale4, rms_norm226), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv7_2 = R.call_tir(cls.fused_split_silu_multiply, (lv10_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv11_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_3_mlp_down_proj_q_weight4, model_layers_3_mlp_down_proj_q_scale4, lv7_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv14 = R.call_tir(cls.fuse_add_norm_decode, (lv11_1, lv13, model_layers_4_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv15: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv14[1]
            rms_norm227: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv14[0]
            lv4_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_4_self_attn_c_attn_q_weight4, model_layers_4_self_attn_c_attn_q_scale4, rms_norm227, model_layers_4_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape448 = R.call_tir(cls.reshape, (lv4_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape449 = R.call_tir(cls.reshape1, (reshape448,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv566 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape449), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape450 = R.call_tir(cls.reshape2, (lv566,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape451 = R.call_tir(cls.reshape3, (reshape450,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv12_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_4_self_attn_o_proj_q_weight4, model_layers_4_self_attn_o_proj_q_scale4, reshape451), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv16 = R.call_tir(cls.fuse_add_norm_decode, (lv12_1, lv15, model_layers_4_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv17: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv16[1]
            rms_norm228: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv16[0]
            lv13_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_4_mlp_gate_up_proj_q_weight4, model_layers_4_mlp_gate_up_proj_q_scale4, rms_norm228), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv9_2 = R.call_tir(cls.fused_split_silu_multiply, (lv13_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv14_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_4_mlp_down_proj_q_weight4, model_layers_4_mlp_down_proj_q_scale4, lv9_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv18 = R.call_tir(cls.fuse_add_norm_decode, (lv14_1, lv17, model_layers_5_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv19: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv18[1]
            rms_norm229: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv18[0]
            lv5_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_5_self_attn_c_attn_q_weight4, model_layers_5_self_attn_c_attn_q_scale4, rms_norm229, model_layers_5_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape452 = R.call_tir(cls.reshape, (lv5_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape453 = R.call_tir(cls.reshape1, (reshape452,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv571 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape453), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape454 = R.call_tir(cls.reshape2, (lv571,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape455 = R.call_tir(cls.reshape3, (reshape454,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv15_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_5_self_attn_o_proj_q_weight4, model_layers_5_self_attn_o_proj_q_scale4, reshape455), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv20 = R.call_tir(cls.fuse_add_norm_decode, (lv15_1, lv19, model_layers_5_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv21: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv20[1]
            rms_norm230: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv20[0]
            lv16_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_5_mlp_gate_up_proj_q_weight4, model_layers_5_mlp_gate_up_proj_q_scale4, rms_norm230), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv11_2 = R.call_tir(cls.fused_split_silu_multiply, (lv16_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv17_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_5_mlp_down_proj_q_weight4, model_layers_5_mlp_down_proj_q_scale4, lv11_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv22 = R.call_tir(cls.fuse_add_norm_decode, (lv17_1, lv21, model_layers_6_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv23: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv22[1]
            rms_norm231: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv22[0]
            lv6_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_6_self_attn_c_attn_q_weight4, model_layers_6_self_attn_c_attn_q_scale4, rms_norm231, model_layers_6_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape456 = R.call_tir(cls.reshape, (lv6_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape457 = R.call_tir(cls.reshape1, (reshape456,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv576 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape457), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape458 = R.call_tir(cls.reshape2, (lv576,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape459 = R.call_tir(cls.reshape3, (reshape458,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv18_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_6_self_attn_o_proj_q_weight4, model_layers_6_self_attn_o_proj_q_scale4, reshape459), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv24 = R.call_tir(cls.fuse_add_norm_decode, (lv18_1, lv23, model_layers_6_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv25: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv24[1]
            rms_norm232: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv24[0]
            lv19_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_6_mlp_gate_up_proj_q_weight4, model_layers_6_mlp_gate_up_proj_q_scale4, rms_norm232), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv13_2 = R.call_tir(cls.fused_split_silu_multiply, (lv19_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv20_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_6_mlp_down_proj_q_weight4, model_layers_6_mlp_down_proj_q_scale4, lv13_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv26 = R.call_tir(cls.fuse_add_norm_decode, (lv20_1, lv25, model_layers_7_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv27: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv26[1]
            rms_norm233: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv26[0]
            lv7_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_7_self_attn_c_attn_q_weight4, model_layers_7_self_attn_c_attn_q_scale4, rms_norm233, model_layers_7_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape460 = R.call_tir(cls.reshape, (lv7_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape461 = R.call_tir(cls.reshape1, (reshape460,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv581 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape461), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape462 = R.call_tir(cls.reshape2, (lv581,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape463 = R.call_tir(cls.reshape3, (reshape462,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv21_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_7_self_attn_o_proj_q_weight4, model_layers_7_self_attn_o_proj_q_scale4, reshape463), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv28 = R.call_tir(cls.fuse_add_norm_decode, (lv21_1, lv27, model_layers_7_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv29: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv28[1]
            rms_norm234: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv28[0]
            lv22_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_7_mlp_gate_up_proj_q_weight4, model_layers_7_mlp_gate_up_proj_q_scale4, rms_norm234), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv15_2 = R.call_tir(cls.fused_split_silu_multiply, (lv22_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv23_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_7_mlp_down_proj_q_weight4, model_layers_7_mlp_down_proj_q_scale4, lv15_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv30 = R.call_tir(cls.fuse_add_norm_decode, (lv23_1, lv29, model_layers_8_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv31: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv30[1]
            rms_norm235: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv30[0]
            lv8_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_8_self_attn_c_attn_q_weight4, model_layers_8_self_attn_c_attn_q_scale4, rms_norm235, model_layers_8_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape464 = R.call_tir(cls.reshape, (lv8_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape465 = R.call_tir(cls.reshape1, (reshape464,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv586 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape465), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape466 = R.call_tir(cls.reshape2, (lv586,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape467 = R.call_tir(cls.reshape3, (reshape466,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv24_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_8_self_attn_o_proj_q_weight4, model_layers_8_self_attn_o_proj_q_scale4, reshape467), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv32 = R.call_tir(cls.fuse_add_norm_decode, (lv24_1, lv31, model_layers_8_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv33: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv32[1]
            rms_norm236: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv32[0]
            lv25_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_8_mlp_gate_up_proj_q_weight4, model_layers_8_mlp_gate_up_proj_q_scale4, rms_norm236), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv17_2 = R.call_tir(cls.fused_split_silu_multiply, (lv25_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv26_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_8_mlp_down_proj_q_weight4, model_layers_8_mlp_down_proj_q_scale4, lv17_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv34 = R.call_tir(cls.fuse_add_norm_decode, (lv26_1, lv33, model_layers_9_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv35: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv34[1]
            rms_norm237: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv34[0]
            lv9_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_9_self_attn_c_attn_q_weight4, model_layers_9_self_attn_c_attn_q_scale4, rms_norm237, model_layers_9_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape468 = R.call_tir(cls.reshape, (lv9_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape469 = R.call_tir(cls.reshape1, (reshape468,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv591 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape469), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape470 = R.call_tir(cls.reshape2, (lv591,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape471 = R.call_tir(cls.reshape3, (reshape470,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv27_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_9_self_attn_o_proj_q_weight4, model_layers_9_self_attn_o_proj_q_scale4, reshape471), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv36 = R.call_tir(cls.fuse_add_norm_decode, (lv27_1, lv35, model_layers_9_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv37: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv36[1]
            rms_norm238: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv36[0]
            lv28_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_9_mlp_gate_up_proj_q_weight4, model_layers_9_mlp_gate_up_proj_q_scale4, rms_norm238), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv19_2 = R.call_tir(cls.fused_split_silu_multiply, (lv28_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv29_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_9_mlp_down_proj_q_weight4, model_layers_9_mlp_down_proj_q_scale4, lv19_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv38 = R.call_tir(cls.fuse_add_norm_decode, (lv29_1, lv37, model_layers_10_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv39: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv38[1]
            rms_norm239: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv38[0]
            lv10_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_10_self_attn_c_attn_q_weight4, model_layers_10_self_attn_c_attn_q_scale4, rms_norm239, model_layers_10_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape472 = R.call_tir(cls.reshape, (lv10_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape473 = R.call_tir(cls.reshape1, (reshape472,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv596 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape473), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape474 = R.call_tir(cls.reshape2, (lv596,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape475 = R.call_tir(cls.reshape3, (reshape474,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv30_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_10_self_attn_o_proj_q_weight4, model_layers_10_self_attn_o_proj_q_scale4, reshape475), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv40 = R.call_tir(cls.fuse_add_norm_decode, (lv30_1, lv39, model_layers_10_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv41: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv40[1]
            rms_norm240: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv40[0]
            lv31_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_10_mlp_gate_up_proj_q_weight4, model_layers_10_mlp_gate_up_proj_q_scale4, rms_norm240), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv21_2 = R.call_tir(cls.fused_split_silu_multiply, (lv31_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv32_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_10_mlp_down_proj_q_weight4, model_layers_10_mlp_down_proj_q_scale4, lv21_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv42 = R.call_tir(cls.fuse_add_norm_decode, (lv32_1, lv41, model_layers_11_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv43: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv42[1]
            rms_norm241: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv42[0]
            lv11_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_11_self_attn_c_attn_q_weight4, model_layers_11_self_attn_c_attn_q_scale4, rms_norm241, model_layers_11_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape476 = R.call_tir(cls.reshape, (lv11_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape477 = R.call_tir(cls.reshape1, (reshape476,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv601 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape477), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape478 = R.call_tir(cls.reshape2, (lv601,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape479 = R.call_tir(cls.reshape3, (reshape478,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv33_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_11_self_attn_o_proj_q_weight4, model_layers_11_self_attn_o_proj_q_scale4, reshape479), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv44 = R.call_tir(cls.fuse_add_norm_decode, (lv33_1, lv43, model_layers_11_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv45: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv44[1]
            rms_norm242: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv44[0]
            lv34_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_11_mlp_gate_up_proj_q_weight4, model_layers_11_mlp_gate_up_proj_q_scale4, rms_norm242), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv23_2 = R.call_tir(cls.fused_split_silu_multiply, (lv34_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv35_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_11_mlp_down_proj_q_weight4, model_layers_11_mlp_down_proj_q_scale4, lv23_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv46 = R.call_tir(cls.fuse_add_norm_decode, (lv35_1, lv45, model_layers_12_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv47: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv46[1]
            rms_norm243: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv46[0]
            lv12_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_12_self_attn_c_attn_q_weight4, model_layers_12_self_attn_c_attn_q_scale4, rms_norm243, model_layers_12_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape480 = R.call_tir(cls.reshape, (lv12_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape481 = R.call_tir(cls.reshape1, (reshape480,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv606 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape481), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape482 = R.call_tir(cls.reshape2, (lv606,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape483 = R.call_tir(cls.reshape3, (reshape482,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv36_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_12_self_attn_o_proj_q_weight4, model_layers_12_self_attn_o_proj_q_scale4, reshape483), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv48 = R.call_tir(cls.fuse_add_norm_decode, (lv36_1, lv47, model_layers_12_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv49: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv48[1]
            rms_norm244: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv48[0]
            lv37_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_12_mlp_gate_up_proj_q_weight4, model_layers_12_mlp_gate_up_proj_q_scale4, rms_norm244), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv25_2 = R.call_tir(cls.fused_split_silu_multiply, (lv37_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv38_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_12_mlp_down_proj_q_weight4, model_layers_12_mlp_down_proj_q_scale4, lv25_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv50 = R.call_tir(cls.fuse_add_norm_decode, (lv38_1, lv49, model_layers_13_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv51: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv50[1]
            rms_norm245: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv50[0]
            lv13_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_13_self_attn_c_attn_q_weight4, model_layers_13_self_attn_c_attn_q_scale4, rms_norm245, model_layers_13_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape484 = R.call_tir(cls.reshape, (lv13_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape485 = R.call_tir(cls.reshape1, (reshape484,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv611 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape485), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape486 = R.call_tir(cls.reshape2, (lv611,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape487 = R.call_tir(cls.reshape3, (reshape486,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv39_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_13_self_attn_o_proj_q_weight4, model_layers_13_self_attn_o_proj_q_scale4, reshape487), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv52 = R.call_tir(cls.fuse_add_norm_decode, (lv39_1, lv51, model_layers_13_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv53: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv52[1]
            rms_norm246: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv52[0]
            lv40_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_13_mlp_gate_up_proj_q_weight4, model_layers_13_mlp_gate_up_proj_q_scale4, rms_norm246), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv27_2 = R.call_tir(cls.fused_split_silu_multiply, (lv40_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv41_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_13_mlp_down_proj_q_weight4, model_layers_13_mlp_down_proj_q_scale4, lv27_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv54 = R.call_tir(cls.fuse_add_norm_decode, (lv41_1, lv53, model_layers_14_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv55: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv54[1]
            rms_norm247: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv54[0]
            lv14_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_14_self_attn_c_attn_q_weight4, model_layers_14_self_attn_c_attn_q_scale4, rms_norm247, model_layers_14_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape488 = R.call_tir(cls.reshape, (lv14_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape489 = R.call_tir(cls.reshape1, (reshape488,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv616 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape489), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape490 = R.call_tir(cls.reshape2, (lv616,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape491 = R.call_tir(cls.reshape3, (reshape490,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv42_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_14_self_attn_o_proj_q_weight4, model_layers_14_self_attn_o_proj_q_scale4, reshape491), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv56 = R.call_tir(cls.fuse_add_norm_decode, (lv42_1, lv55, model_layers_14_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv57: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv56[1]
            rms_norm248: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv56[0]
            lv43_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_14_mlp_gate_up_proj_q_weight4, model_layers_14_mlp_gate_up_proj_q_scale4, rms_norm248), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv29_2 = R.call_tir(cls.fused_split_silu_multiply, (lv43_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv44_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_14_mlp_down_proj_q_weight4, model_layers_14_mlp_down_proj_q_scale4, lv29_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv58 = R.call_tir(cls.fuse_add_norm_decode, (lv44_1, lv57, model_layers_15_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv59: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv58[1]
            rms_norm249: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv58[0]
            lv15_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_15_self_attn_c_attn_q_weight4, model_layers_15_self_attn_c_attn_q_scale4, rms_norm249, model_layers_15_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape492 = R.call_tir(cls.reshape, (lv15_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape493 = R.call_tir(cls.reshape1, (reshape492,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv621 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape493), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape494 = R.call_tir(cls.reshape2, (lv621,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape495 = R.call_tir(cls.reshape3, (reshape494,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv45_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_15_self_attn_o_proj_q_weight4, model_layers_15_self_attn_o_proj_q_scale4, reshape495), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv60 = R.call_tir(cls.fuse_add_norm_decode, (lv45_1, lv59, model_layers_15_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv61: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv60[1]
            rms_norm250: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv60[0]
            lv46_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_15_mlp_gate_up_proj_q_weight4, model_layers_15_mlp_gate_up_proj_q_scale4, rms_norm250), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv31_2 = R.call_tir(cls.fused_split_silu_multiply, (lv46_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv47_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_15_mlp_down_proj_q_weight4, model_layers_15_mlp_down_proj_q_scale4, lv31_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv62 = R.call_tir(cls.fuse_add_norm_decode, (lv47_1, lv61, model_layers_16_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv63: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv62[1]
            rms_norm251: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv62[0]
            lv16_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_16_self_attn_c_attn_q_weight4, model_layers_16_self_attn_c_attn_q_scale4, rms_norm251, model_layers_16_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape496 = R.call_tir(cls.reshape, (lv16_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape497 = R.call_tir(cls.reshape1, (reshape496,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv626 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1.0)), reshape497), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape498 = R.call_tir(cls.reshape2, (lv626,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape499 = R.call_tir(cls.reshape3, (reshape498,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv48_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_16_self_attn_o_proj_q_weight4, model_layers_16_self_attn_o_proj_q_scale4, reshape499), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv64 = R.call_tir(cls.fuse_add_norm_decode, (lv48_1, lv63, model_layers_16_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv65: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv64[1]
            rms_norm252: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv64[0]
            lv49_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_16_mlp_gate_up_proj_q_weight4, model_layers_16_mlp_gate_up_proj_q_scale4, rms_norm252), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv33_2 = R.call_tir(cls.fused_split_silu_multiply, (lv49_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv50_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_16_mlp_down_proj_q_weight4, model_layers_16_mlp_down_proj_q_scale4, lv33_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv66 = R.call_tir(cls.fuse_add_norm_decode, (lv50_1, lv65, model_layers_17_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv67: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv66[1]
            rms_norm253: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv66[0]
            lv17_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_17_self_attn_c_attn_q_weight4, model_layers_17_self_attn_c_attn_q_scale4, rms_norm253, model_layers_17_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape500 = R.call_tir(cls.reshape, (lv17_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape501 = R.call_tir(cls.reshape1, (reshape500,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv631 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1.0)), reshape501), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape502 = R.call_tir(cls.reshape2, (lv631,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape503 = R.call_tir(cls.reshape3, (reshape502,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv51_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_17_self_attn_o_proj_q_weight4, model_layers_17_self_attn_o_proj_q_scale4, reshape503), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv68 = R.call_tir(cls.fuse_add_norm_decode, (lv51_1, lv67, model_layers_17_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv69: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv68[1]
            rms_norm254: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv68[0]
            lv52_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_17_mlp_gate_up_proj_q_weight4, model_layers_17_mlp_gate_up_proj_q_scale4, rms_norm254), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv35_2 = R.call_tir(cls.fused_split_silu_multiply, (lv52_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv53_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_17_mlp_down_proj_q_weight4, model_layers_17_mlp_down_proj_q_scale4, lv35_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv70 = R.call_tir(cls.fuse_add_norm_decode, (lv53_1, lv69, model_layers_18_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv71: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv70[1]
            rms_norm255: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv70[0]
            lv18_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_18_self_attn_c_attn_q_weight4, model_layers_18_self_attn_c_attn_q_scale4, rms_norm255, model_layers_18_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape504 = R.call_tir(cls.reshape, (lv18_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape505 = R.call_tir(cls.reshape1, (reshape504,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv636 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1.0)), reshape505), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape506 = R.call_tir(cls.reshape2, (lv636,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape507 = R.call_tir(cls.reshape3, (reshape506,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv54_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_18_self_attn_o_proj_q_weight4, model_layers_18_self_attn_o_proj_q_scale4, reshape507), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv72 = R.call_tir(cls.fuse_add_norm_decode, (lv54_1, lv71, model_layers_18_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv73: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv72[1]
            rms_norm256: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv72[0]
            lv55_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_18_mlp_gate_up_proj_q_weight4, model_layers_18_mlp_gate_up_proj_q_scale4, rms_norm256), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv37_2 = R.call_tir(cls.fused_split_silu_multiply, (lv55_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv56_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_18_mlp_down_proj_q_weight4, model_layers_18_mlp_down_proj_q_scale4, lv37_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv74 = R.call_tir(cls.fuse_add_norm_decode, (lv56_1, lv73, model_layers_19_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv75: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv74[1]
            rms_norm257: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv74[0]
            lv19_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_19_self_attn_c_attn_q_weight4, model_layers_19_self_attn_c_attn_q_scale4, rms_norm257, model_layers_19_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape508 = R.call_tir(cls.reshape, (lv19_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape509 = R.call_tir(cls.reshape1, (reshape508,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv641 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1.0)), reshape509), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape510 = R.call_tir(cls.reshape2, (lv641,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape511 = R.call_tir(cls.reshape3, (reshape510,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv57_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_19_self_attn_o_proj_q_weight4, model_layers_19_self_attn_o_proj_q_scale4, reshape511), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv76 = R.call_tir(cls.fuse_add_norm_decode, (lv57_1, lv75, model_layers_19_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv77: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv76[1]
            rms_norm258: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv76[0]
            lv58_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_19_mlp_gate_up_proj_q_weight4, model_layers_19_mlp_gate_up_proj_q_scale4, rms_norm258), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv39_2 = R.call_tir(cls.fused_split_silu_multiply, (lv58_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv59_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_19_mlp_down_proj_q_weight4, model_layers_19_mlp_down_proj_q_scale4, lv39_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv78 = R.call_tir(cls.fuse_add_norm_decode, (lv59_1, lv77, model_layers_20_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv79: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv78[1]
            rms_norm259: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv78[0]
            lv20_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_20_self_attn_c_attn_q_weight4, model_layers_20_self_attn_c_attn_q_scale4, rms_norm259, model_layers_20_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape512 = R.call_tir(cls.reshape, (lv20_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape513 = R.call_tir(cls.reshape1, (reshape512,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv646 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1.0)), reshape513), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape514 = R.call_tir(cls.reshape2, (lv646,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape515 = R.call_tir(cls.reshape3, (reshape514,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv60_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_20_self_attn_o_proj_q_weight4, model_layers_20_self_attn_o_proj_q_scale4, reshape515), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv80 = R.call_tir(cls.fuse_add_norm_decode, (lv60_1, lv79, model_layers_20_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv81: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv80[1]
            rms_norm260: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv80[0]
            lv61_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_20_mlp_gate_up_proj_q_weight4, model_layers_20_mlp_gate_up_proj_q_scale4, rms_norm260), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv41_2 = R.call_tir(cls.fused_split_silu_multiply, (lv61_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv62_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_20_mlp_down_proj_q_weight4, model_layers_20_mlp_down_proj_q_scale4, lv41_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv82 = R.call_tir(cls.fuse_add_norm_decode, (lv62_1, lv81, model_layers_21_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv83: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv82[1]
            rms_norm261: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv82[0]
            lv21_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_21_self_attn_c_attn_q_weight4, model_layers_21_self_attn_c_attn_q_scale4, rms_norm261, model_layers_21_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape516 = R.call_tir(cls.reshape, (lv21_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape517 = R.call_tir(cls.reshape1, (reshape516,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv651 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1.0)), reshape517), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape518 = R.call_tir(cls.reshape2, (lv651,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape519 = R.call_tir(cls.reshape3, (reshape518,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv63_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_21_self_attn_o_proj_q_weight4, model_layers_21_self_attn_o_proj_q_scale4, reshape519), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv84 = R.call_tir(cls.fuse_add_norm_decode, (lv63_1, lv83, model_layers_21_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv85: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv84[1]
            rms_norm262: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv84[0]
            lv64_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_21_mlp_gate_up_proj_q_weight4, model_layers_21_mlp_gate_up_proj_q_scale4, rms_norm262), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv43_2 = R.call_tir(cls.fused_split_silu_multiply, (lv64_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv65_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_21_mlp_down_proj_q_weight4, model_layers_21_mlp_down_proj_q_scale4, lv43_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv86 = R.call_tir(cls.fuse_add_norm_decode, (lv65_1, lv85, model_layers_22_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv87: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv86[1]
            rms_norm263: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv86[0]
            lv22_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_22_self_attn_c_attn_q_weight4, model_layers_22_self_attn_c_attn_q_scale4, rms_norm263, model_layers_22_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape520 = R.call_tir(cls.reshape, (lv22_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape521 = R.call_tir(cls.reshape1, (reshape520,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv656 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1.0)), reshape521), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape522 = R.call_tir(cls.reshape2, (lv656,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape523 = R.call_tir(cls.reshape3, (reshape522,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv66_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_22_self_attn_o_proj_q_weight4, model_layers_22_self_attn_o_proj_q_scale4, reshape523), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv88 = R.call_tir(cls.fuse_add_norm_decode, (lv66_1, lv87, model_layers_22_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv89: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv88[1]
            rms_norm264: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv88[0]
            lv67_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_22_mlp_gate_up_proj_q_weight4, model_layers_22_mlp_gate_up_proj_q_scale4, rms_norm264), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv45_2 = R.call_tir(cls.fused_split_silu_multiply, (lv67_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv68_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_22_mlp_down_proj_q_weight4, model_layers_22_mlp_down_proj_q_scale4, lv45_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv90 = R.call_tir(cls.fuse_add_norm_decode, (lv68_1, lv89, model_layers_23_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv91: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv90[1]
            rms_norm265: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv90[0]
            lv23_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_23_self_attn_c_attn_q_weight4, model_layers_23_self_attn_c_attn_q_scale4, rms_norm265, model_layers_23_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape524 = R.call_tir(cls.reshape, (lv23_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape525 = R.call_tir(cls.reshape1, (reshape524,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv661 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1.0)), reshape525), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape526 = R.call_tir(cls.reshape2, (lv661,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape527 = R.call_tir(cls.reshape3, (reshape526,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv69_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_23_self_attn_o_proj_q_weight4, model_layers_23_self_attn_o_proj_q_scale4, reshape527), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv92 = R.call_tir(cls.fuse_add_norm_decode, (lv69_1, lv91, model_layers_23_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv93: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv92[1]
            rms_norm266: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv92[0]
            lv70_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_23_mlp_gate_up_proj_q_weight4, model_layers_23_mlp_gate_up_proj_q_scale4, rms_norm266), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv47_2 = R.call_tir(cls.fused_split_silu_multiply, (lv70_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv71_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_23_mlp_down_proj_q_weight4, model_layers_23_mlp_down_proj_q_scale4, lv47_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv94 = R.call_tir(cls.fuse_add_norm_decode, (lv71_1, lv93, model_layers_24_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv95: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv94[1]
            rms_norm267: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv94[0]
            lv24_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_24_self_attn_c_attn_q_weight4, model_layers_24_self_attn_c_attn_q_scale4, rms_norm267, model_layers_24_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape528 = R.call_tir(cls.reshape, (lv24_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape529 = R.call_tir(cls.reshape1, (reshape528,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv666 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1.0)), reshape529), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape530 = R.call_tir(cls.reshape2, (lv666,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape531 = R.call_tir(cls.reshape3, (reshape530,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv72_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_24_self_attn_o_proj_q_weight4, model_layers_24_self_attn_o_proj_q_scale4, reshape531), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv96 = R.call_tir(cls.fuse_add_norm_decode, (lv72_1, lv95, model_layers_24_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv97: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv96[1]
            rms_norm268: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv96[0]
            lv73_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_24_mlp_gate_up_proj_q_weight4, model_layers_24_mlp_gate_up_proj_q_scale4, rms_norm268), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv49_2 = R.call_tir(cls.fused_split_silu_multiply, (lv73_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv74_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_24_mlp_down_proj_q_weight4, model_layers_24_mlp_down_proj_q_scale4, lv49_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv98 = R.call_tir(cls.fuse_add_norm_decode, (lv74_1, lv97, model_layers_25_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv99: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv98[1]
            rms_norm269: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv98[0]
            lv25_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_25_self_attn_c_attn_q_weight4, model_layers_25_self_attn_c_attn_q_scale4, rms_norm269, model_layers_25_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape532 = R.call_tir(cls.reshape, (lv25_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape533 = R.call_tir(cls.reshape1, (reshape532,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv671 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1.0)), reshape533), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape534 = R.call_tir(cls.reshape2, (lv671,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape535 = R.call_tir(cls.reshape3, (reshape534,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv75_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_25_self_attn_o_proj_q_weight4, model_layers_25_self_attn_o_proj_q_scale4, reshape535), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv100 = R.call_tir(cls.fuse_add_norm_decode, (lv75_1, lv99, model_layers_25_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv101: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv100[1]
            rms_norm270: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv100[0]
            lv76_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_25_mlp_gate_up_proj_q_weight4, model_layers_25_mlp_gate_up_proj_q_scale4, rms_norm270), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv51_2 = R.call_tir(cls.fused_split_silu_multiply, (lv76_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv77_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_25_mlp_down_proj_q_weight4, model_layers_25_mlp_down_proj_q_scale4, lv51_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv102 = R.call_tir(cls.fuse_add_norm_decode, (lv77_1, lv101, model_layers_26_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv103: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv102[1]
            rms_norm271: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv102[0]
            lv26_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_26_self_attn_c_attn_q_weight4, model_layers_26_self_attn_c_attn_q_scale4, rms_norm271, model_layers_26_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape536 = R.call_tir(cls.reshape, (lv26_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape537 = R.call_tir(cls.reshape1, (reshape536,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv676 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1.0)), reshape537), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape538 = R.call_tir(cls.reshape2, (lv676,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape539 = R.call_tir(cls.reshape3, (reshape538,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv78_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_26_self_attn_o_proj_q_weight4, model_layers_26_self_attn_o_proj_q_scale4, reshape539), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv104 = R.call_tir(cls.fuse_add_norm_decode, (lv78_1, lv103, model_layers_26_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv105: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv104[1]
            rms_norm272: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv104[0]
            lv79_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_26_mlp_gate_up_proj_q_weight4, model_layers_26_mlp_gate_up_proj_q_scale4, rms_norm272), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv53_2 = R.call_tir(cls.fused_split_silu_multiply, (lv79_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv80_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_26_mlp_down_proj_q_weight4, model_layers_26_mlp_down_proj_q_scale4, lv53_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv106 = R.call_tir(cls.fuse_add_norm_decode, (lv80_1, lv105, model_layers_27_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv107: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv106[1]
            rms_norm273: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv106[0]
            lv27_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_27_self_attn_c_attn_q_weight4, model_layers_27_self_attn_c_attn_q_scale4, rms_norm273, model_layers_27_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape540 = R.call_tir(cls.reshape, (lv27_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape541 = R.call_tir(cls.reshape1, (reshape540,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv681 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1.0)), reshape541), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape542 = R.call_tir(cls.reshape2, (lv681,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape543 = R.call_tir(cls.reshape3, (reshape542,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv81_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_27_self_attn_o_proj_q_weight4, model_layers_27_self_attn_o_proj_q_scale4, reshape543), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv108 = R.call_tir(cls.fuse_add_norm_decode, (lv81_1, lv107, model_layers_27_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv109: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv108[1]
            rms_norm274: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv108[0]
            lv82_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_27_mlp_gate_up_proj_q_weight4, model_layers_27_mlp_gate_up_proj_q_scale4, rms_norm274), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv55_2 = R.call_tir(cls.fused_split_silu_multiply, (lv82_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv83_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_27_mlp_down_proj_q_weight4, model_layers_27_mlp_down_proj_q_scale4, lv55_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv110 = R.call_tir(cls.fuse_add_norm_decode, (lv83_1, lv109, model_layers_28_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv111: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv110[1]
            rms_norm275: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv110[0]
            lv28_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_28_self_attn_c_attn_q_weight4, model_layers_28_self_attn_c_attn_q_scale4, rms_norm275, model_layers_28_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape544 = R.call_tir(cls.reshape, (lv28_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape545 = R.call_tir(cls.reshape1, (reshape544,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv686 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1.0)), reshape545), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape546 = R.call_tir(cls.reshape2, (lv686,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape547 = R.call_tir(cls.reshape3, (reshape546,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv84_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_28_self_attn_o_proj_q_weight4, model_layers_28_self_attn_o_proj_q_scale4, reshape547), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv112 = R.call_tir(cls.fuse_add_norm_decode, (lv84_1, lv111, model_layers_28_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv113: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv112[1]
            rms_norm276: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv112[0]
            lv85_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_28_mlp_gate_up_proj_q_weight4, model_layers_28_mlp_gate_up_proj_q_scale4, rms_norm276), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv57_2 = R.call_tir(cls.fused_split_silu_multiply, (lv85_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv86_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_28_mlp_down_proj_q_weight4, model_layers_28_mlp_down_proj_q_scale4, lv57_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv114 = R.call_tir(cls.fuse_add_norm_decode, (lv86_1, lv113, model_layers_29_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv115: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv114[1]
            rms_norm277: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv114[0]
            lv29_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_29_self_attn_c_attn_q_weight4, model_layers_29_self_attn_c_attn_q_scale4, rms_norm277, model_layers_29_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape548 = R.call_tir(cls.reshape, (lv29_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape549 = R.call_tir(cls.reshape1, (reshape548,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv691 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1.0)), reshape549), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape550 = R.call_tir(cls.reshape2, (lv691,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape551 = R.call_tir(cls.reshape3, (reshape550,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv87_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_29_self_attn_o_proj_q_weight4, model_layers_29_self_attn_o_proj_q_scale4, reshape551), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv116 = R.call_tir(cls.fuse_add_norm_decode, (lv87_1, lv115, model_layers_29_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv117: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv116[1]
            rms_norm278: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv116[0]
            lv88_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_29_mlp_gate_up_proj_q_weight4, model_layers_29_mlp_gate_up_proj_q_scale4, rms_norm278), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv59_2 = R.call_tir(cls.fused_split_silu_multiply, (lv88_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv89_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_29_mlp_down_proj_q_weight4, model_layers_29_mlp_down_proj_q_scale4, lv59_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv118 = R.call_tir(cls.fuse_add_norm_decode, (lv89_1, lv117, model_layers_30_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv119: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv118[1]
            rms_norm279: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv118[0]
            lv30_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_30_self_attn_c_attn_q_weight4, model_layers_30_self_attn_c_attn_q_scale4, rms_norm279, model_layers_30_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape552 = R.call_tir(cls.reshape, (lv30_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape553 = R.call_tir(cls.reshape1, (reshape552,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv696 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1.0)), reshape553), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape554 = R.call_tir(cls.reshape2, (lv696,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape555 = R.call_tir(cls.reshape3, (reshape554,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv90_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_30_self_attn_o_proj_q_weight4, model_layers_30_self_attn_o_proj_q_scale4, reshape555), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv120 = R.call_tir(cls.fuse_add_norm_decode, (lv90_1, lv119, model_layers_30_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv121: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv120[1]
            rms_norm280: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv120[0]
            lv91_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_30_mlp_gate_up_proj_q_weight4, model_layers_30_mlp_gate_up_proj_q_scale4, rms_norm280), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv61_2 = R.call_tir(cls.fused_split_silu_multiply, (lv91_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv92_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_30_mlp_down_proj_q_weight4, model_layers_30_mlp_down_proj_q_scale4, lv61_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv122 = R.call_tir(cls.fuse_add_norm_decode, (lv92_1, lv121, model_layers_31_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv123: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv122[1]
            rms_norm281: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv122[0]
            lv31_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_31_self_attn_c_attn_q_weight4, model_layers_31_self_attn_c_attn_q_scale4, rms_norm281, model_layers_31_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape556 = R.call_tir(cls.reshape, (lv31_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape557 = R.call_tir(cls.reshape1, (reshape556,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv701 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1.0)), reshape557), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape558 = R.call_tir(cls.reshape2, (lv701,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape559 = R.call_tir(cls.reshape3, (reshape558,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv93_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_31_self_attn_o_proj_q_weight4, model_layers_31_self_attn_o_proj_q_scale4, reshape559), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv124 = R.call_tir(cls.fuse_add_norm_decode, (lv93_1, lv123, model_layers_31_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv125: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv124[1]
            rms_norm282: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv124[0]
            lv94_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_31_mlp_gate_up_proj_q_weight4, model_layers_31_mlp_gate_up_proj_q_scale4, rms_norm282), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv63_2 = R.call_tir(cls.fused_split_silu_multiply, (lv94_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv95_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_31_mlp_down_proj_q_weight4, model_layers_31_mlp_down_proj_q_scale4, lv63_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv126 = R.call_tir(cls.fuse_add_norm_decode, (lv95_1, lv125, model_layers_32_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv127: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv126[1]
            rms_norm283: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv126[0]
            lv32_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_32_self_attn_c_attn_q_weight4, model_layers_32_self_attn_c_attn_q_scale4, rms_norm283, model_layers_32_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape560 = R.call_tir(cls.reshape, (lv32_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape561 = R.call_tir(cls.reshape1, (reshape560,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv706 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(32), R.prim_value(T.float32(1.0)), reshape561), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape562 = R.call_tir(cls.reshape2, (lv706,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape563 = R.call_tir(cls.reshape3, (reshape562,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv96_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_32_self_attn_o_proj_q_weight4, model_layers_32_self_attn_o_proj_q_scale4, reshape563), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv128 = R.call_tir(cls.fuse_add_norm_decode, (lv96_1, lv127, model_layers_32_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv129: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv128[1]
            rms_norm284: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv128[0]
            lv97_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_32_mlp_gate_up_proj_q_weight4, model_layers_32_mlp_gate_up_proj_q_scale4, rms_norm284), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv65_2 = R.call_tir(cls.fused_split_silu_multiply, (lv97_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv98_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_32_mlp_down_proj_q_weight4, model_layers_32_mlp_down_proj_q_scale4, lv65_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv130 = R.call_tir(cls.fuse_add_norm_decode, (lv98_1, lv129, model_layers_33_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv131: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv130[1]
            rms_norm285: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv130[0]
            lv33_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_33_self_attn_c_attn_q_weight4, model_layers_33_self_attn_c_attn_q_scale4, rms_norm285, model_layers_33_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape564 = R.call_tir(cls.reshape, (lv33_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape565 = R.call_tir(cls.reshape1, (reshape564,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv711 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(33), R.prim_value(T.float32(1.0)), reshape565), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape566 = R.call_tir(cls.reshape2, (lv711,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape567 = R.call_tir(cls.reshape3, (reshape566,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv99_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_33_self_attn_o_proj_q_weight4, model_layers_33_self_attn_o_proj_q_scale4, reshape567), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv132 = R.call_tir(cls.fuse_add_norm_decode, (lv99_1, lv131, model_layers_33_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv133: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv132[1]
            rms_norm286: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv132[0]
            lv100_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_33_mlp_gate_up_proj_q_weight4, model_layers_33_mlp_gate_up_proj_q_scale4, rms_norm286), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv67_2 = R.call_tir(cls.fused_split_silu_multiply, (lv100_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv101_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_33_mlp_down_proj_q_weight4, model_layers_33_mlp_down_proj_q_scale4, lv67_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv134 = R.call_tir(cls.fuse_add_norm_decode, (lv101_1, lv133, model_layers_34_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv135: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv134[1]
            rms_norm287: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv134[0]
            lv34_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_34_self_attn_c_attn_q_weight4, model_layers_34_self_attn_c_attn_q_scale4, rms_norm287, model_layers_34_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape568 = R.call_tir(cls.reshape, (lv34_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape569 = R.call_tir(cls.reshape1, (reshape568,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv716 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(34), R.prim_value(T.float32(1.0)), reshape569), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape570 = R.call_tir(cls.reshape2, (lv716,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape571 = R.call_tir(cls.reshape3, (reshape570,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv102_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_34_self_attn_o_proj_q_weight4, model_layers_34_self_attn_o_proj_q_scale4, reshape571), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv136 = R.call_tir(cls.fuse_add_norm_decode, (lv102_1, lv135, model_layers_34_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv137: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv136[1]
            rms_norm288: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv136[0]
            lv103_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_34_mlp_gate_up_proj_q_weight4, model_layers_34_mlp_gate_up_proj_q_scale4, rms_norm288), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv69_2 = R.call_tir(cls.fused_split_silu_multiply, (lv103_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv104_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_34_mlp_down_proj_q_weight4, model_layers_34_mlp_down_proj_q_scale4, lv69_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv138 = R.call_tir(cls.fuse_add_norm_decode, (lv104_1, lv137, model_layers_35_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv139: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv138[1]
            rms_norm289: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv138[0]
            lv35_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_35_self_attn_c_attn_q_weight4, model_layers_35_self_attn_c_attn_q_scale4, rms_norm289, model_layers_35_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape572 = R.call_tir(cls.reshape, (lv35_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape573 = R.call_tir(cls.reshape1, (reshape572,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv721 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(35), R.prim_value(T.float32(1.0)), reshape573), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape574 = R.call_tir(cls.reshape2, (lv721,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape575 = R.call_tir(cls.reshape3, (reshape574,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv105_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_35_self_attn_o_proj_q_weight4, model_layers_35_self_attn_o_proj_q_scale4, reshape575), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv140 = R.call_tir(cls.fuse_add_norm_decode, (lv105_1, lv139, model_layers_35_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv141: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv140[1]
            rms_norm290: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv140[0]
            lv106_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_35_mlp_gate_up_proj_q_weight4, model_layers_35_mlp_gate_up_proj_q_scale4, rms_norm290), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv71_2 = R.call_tir(cls.fused_split_silu_multiply, (lv106_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv107_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_35_mlp_down_proj_q_weight4, model_layers_35_mlp_down_proj_q_scale4, lv71_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv142 = R.call_tir(cls.fuse_add_norm_decode, (lv107_1, lv141, model_norm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            rms_norm291: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv142[0]
            lv108_1 = R.call_tir(cls.fused_dequantize_NT_matmul4, (model_embed_tokens_q_weight4, model_embed_tokens_q_scale4, rms_norm291), out_sinfo=R.Tensor((batch_size, 1, 151936), dtype="float32"))
            gv4: R.Tuple(R.Tensor((batch_size, 1, 151936), dtype="float32"), R.Object) = lv108_1, paged_kv_cache
            R.output(gv4)
        return gv4

    @R.function
    def batch_prefill(input_embeds: R.Tensor((1, "seq_len", 2048), dtype="float16"), logit_positions: R.Tensor(("batch_size",), dtype="int32"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), 
dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), 
R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), 
R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), 
dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), 
R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), 
R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), 
dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"))) -> R.Tuple(R.Tensor((1, "batch_size", 151936), dtype="float32"), R.Object):
        batch_size = T.int64()
        seq_len = T.int64()
        R.func_attr({"num_input": 3, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 32768}})
        cls = Module
        with R.dataflow():
            model_embed_tokens_q_weight3: R.Tensor((151936, 256), dtype="uint32") = packed_params[0]
            model_embed_tokens_q_scale3: R.Tensor((151936, 64), dtype="float16") = packed_params[1]
            model_layers_0_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[2]
            model_layers_0_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[3]
            model_layers_0_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[4]
            model_layers_0_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[5]
            model_layers_0_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[6]
            model_layers_0_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[7]
            model_layers_0_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[8]
            model_layers_0_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[9]
            model_layers_0_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[10]
            model_layers_0_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[11]
            model_layers_0_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[12]
            model_layers_1_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[13]
            model_layers_1_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[14]
            model_layers_1_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[15]
            model_layers_1_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[16]
            model_layers_1_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[17]
            model_layers_1_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[18]
            model_layers_1_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[19]
            model_layers_1_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[20]
            model_layers_1_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[21]
            model_layers_1_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[22]
            model_layers_1_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[23]
            model_layers_2_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[24]
            model_layers_2_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[25]
            model_layers_2_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[26]
            model_layers_2_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[27]
            model_layers_2_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[28]
            model_layers_2_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[29]
            model_layers_2_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[30]
            model_layers_2_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[31]
            model_layers_2_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[32]
            model_layers_2_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[33]
            model_layers_2_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[34]
            model_layers_3_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[35]
            model_layers_3_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[36]
            model_layers_3_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[37]
            model_layers_3_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[38]
            model_layers_3_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[39]
            model_layers_3_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[40]
            model_layers_3_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[41]
            model_layers_3_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[42]
            model_layers_3_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[43]
            model_layers_3_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[44]
            model_layers_3_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[45]
            model_layers_4_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[46]
            model_layers_4_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[47]
            model_layers_4_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[48]
            model_layers_4_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[49]
            model_layers_4_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[50]
            model_layers_4_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[51]
            model_layers_4_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[52]
            model_layers_4_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[53]
            model_layers_4_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[54]
            model_layers_4_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[55]
            model_layers_4_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[56]
            model_layers_5_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[57]
            model_layers_5_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[58]
            model_layers_5_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[59]
            model_layers_5_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[60]
            model_layers_5_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[61]
            model_layers_5_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[62]
            model_layers_5_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[63]
            model_layers_5_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[64]
            model_layers_5_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[65]
            model_layers_5_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[66]
            model_layers_5_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[67]
            model_layers_6_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[68]
            model_layers_6_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[69]
            model_layers_6_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[70]
            model_layers_6_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[71]
            model_layers_6_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[72]
            model_layers_6_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[73]
            model_layers_6_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[74]
            model_layers_6_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[75]
            model_layers_6_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[76]
            model_layers_6_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[77]
            model_layers_6_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[78]
            model_layers_7_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[79]
            model_layers_7_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[80]
            model_layers_7_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[81]
            model_layers_7_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[82]
            model_layers_7_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[83]
            model_layers_7_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[84]
            model_layers_7_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[85]
            model_layers_7_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[86]
            model_layers_7_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[87]
            model_layers_7_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[88]
            model_layers_7_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[89]
            model_layers_8_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[90]
            model_layers_8_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[91]
            model_layers_8_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[92]
            model_layers_8_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[93]
            model_layers_8_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[94]
            model_layers_8_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[95]
            model_layers_8_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[96]
            model_layers_8_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[97]
            model_layers_8_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[98]
            model_layers_8_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[99]
            model_layers_8_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[100]
            model_layers_9_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[101]
            model_layers_9_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[102]
            model_layers_9_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[103]
            model_layers_9_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[104]
            model_layers_9_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[105]
            model_layers_9_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[106]
            model_layers_9_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[107]
            model_layers_9_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[108]
            model_layers_9_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[109]
            model_layers_9_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[110]
            model_layers_9_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[111]
            model_layers_10_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[112]
            model_layers_10_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[113]
            model_layers_10_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[114]
            model_layers_10_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[115]
            model_layers_10_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[116]
            model_layers_10_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[117]
            model_layers_10_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[118]
            model_layers_10_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[119]
            model_layers_10_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[120]
            model_layers_10_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[121]
            model_layers_10_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[122]
            model_layers_11_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[123]
            model_layers_11_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[124]
            model_layers_11_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[125]
            model_layers_11_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[126]
            model_layers_11_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[127]
            model_layers_11_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[128]
            model_layers_11_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[129]
            model_layers_11_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[130]
            model_layers_11_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[131]
            model_layers_11_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[132]
            model_layers_11_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[133]
            model_layers_12_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[134]
            model_layers_12_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[135]
            model_layers_12_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[136]
            model_layers_12_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[137]
            model_layers_12_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[138]
            model_layers_12_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[139]
            model_layers_12_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[140]
            model_layers_12_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[141]
            model_layers_12_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[142]
            model_layers_12_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[143]
            model_layers_12_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[144]
            model_layers_13_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[145]
            model_layers_13_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[146]
            model_layers_13_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[147]
            model_layers_13_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[148]
            model_layers_13_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[149]
            model_layers_13_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[150]
            model_layers_13_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[151]
            model_layers_13_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[152]
            model_layers_13_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[153]
            model_layers_13_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[154]
            model_layers_13_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[155]
            model_layers_14_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[156]
            model_layers_14_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[157]
            model_layers_14_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[158]
            model_layers_14_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[159]
            model_layers_14_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[160]
            model_layers_14_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[161]
            model_layers_14_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[162]
            model_layers_14_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[163]
            model_layers_14_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[164]
            model_layers_14_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[165]
            model_layers_14_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[166]
            model_layers_15_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[167]
            model_layers_15_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[168]
            model_layers_15_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[169]
            model_layers_15_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[170]
            model_layers_15_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[171]
            model_layers_15_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[172]
            model_layers_15_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[173]
            model_layers_15_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[174]
            model_layers_15_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[175]
            model_layers_15_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[176]
            model_layers_15_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[177]
            model_layers_16_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[178]
            model_layers_16_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[179]
            model_layers_16_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[180]
            model_layers_16_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[181]
            model_layers_16_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[182]
            model_layers_16_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[183]
            model_layers_16_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[184]
            model_layers_16_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[185]
            model_layers_16_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[186]
            model_layers_16_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[187]
            model_layers_16_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[188]
            model_layers_17_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[189]
            model_layers_17_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[190]
            model_layers_17_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[191]
            model_layers_17_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[192]
            model_layers_17_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[193]
            model_layers_17_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[194]
            model_layers_17_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[195]
            model_layers_17_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[196]
            model_layers_17_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[197]
            model_layers_17_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[198]
            model_layers_17_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[199]
            model_layers_18_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[200]
            model_layers_18_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[201]
            model_layers_18_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[202]
            model_layers_18_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[203]
            model_layers_18_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[204]
            model_layers_18_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[205]
            model_layers_18_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[206]
            model_layers_18_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[207]
            model_layers_18_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[208]
            model_layers_18_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[209]
            model_layers_18_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[210]
            model_layers_19_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[211]
            model_layers_19_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[212]
            model_layers_19_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[213]
            model_layers_19_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[214]
            model_layers_19_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[215]
            model_layers_19_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[216]
            model_layers_19_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[217]
            model_layers_19_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[218]
            model_layers_19_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[219]
            model_layers_19_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[220]
            model_layers_19_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[221]
            model_layers_20_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[222]
            model_layers_20_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[223]
            model_layers_20_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[224]
            model_layers_20_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[225]
            model_layers_20_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[226]
            model_layers_20_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[227]
            model_layers_20_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[228]
            model_layers_20_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[229]
            model_layers_20_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[230]
            model_layers_20_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[231]
            model_layers_20_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[232]
            model_layers_21_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[233]
            model_layers_21_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[234]
            model_layers_21_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[235]
            model_layers_21_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[236]
            model_layers_21_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[237]
            model_layers_21_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[238]
            model_layers_21_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[239]
            model_layers_21_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[240]
            model_layers_21_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[241]
            model_layers_21_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[242]
            model_layers_21_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[243]
            model_layers_22_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[244]
            model_layers_22_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[245]
            model_layers_22_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[246]
            model_layers_22_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[247]
            model_layers_22_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[248]
            model_layers_22_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[249]
            model_layers_22_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[250]
            model_layers_22_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[251]
            model_layers_22_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[252]
            model_layers_22_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[253]
            model_layers_22_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[254]
            model_layers_23_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[255]
            model_layers_23_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[256]
            model_layers_23_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[257]
            model_layers_23_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[258]
            model_layers_23_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[259]
            model_layers_23_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[260]
            model_layers_23_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[261]
            model_layers_23_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[262]
            model_layers_23_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[263]
            model_layers_23_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[264]
            model_layers_23_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[265]
            model_layers_24_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[266]
            model_layers_24_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[267]
            model_layers_24_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[268]
            model_layers_24_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[269]
            model_layers_24_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[270]
            model_layers_24_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[271]
            model_layers_24_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[272]
            model_layers_24_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[273]
            model_layers_24_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[274]
            model_layers_24_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[275]
            model_layers_24_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[276]
            model_layers_25_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[277]
            model_layers_25_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[278]
            model_layers_25_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[279]
            model_layers_25_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[280]
            model_layers_25_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[281]
            model_layers_25_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[282]
            model_layers_25_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[283]
            model_layers_25_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[284]
            model_layers_25_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[285]
            model_layers_25_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[286]
            model_layers_25_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[287]
            model_layers_26_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[288]
            model_layers_26_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[289]
            model_layers_26_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[290]
            model_layers_26_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[291]
            model_layers_26_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[292]
            model_layers_26_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[293]
            model_layers_26_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[294]
            model_layers_26_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[295]
            model_layers_26_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[296]
            model_layers_26_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[297]
            model_layers_26_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[298]
            model_layers_27_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[299]
            model_layers_27_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[300]
            model_layers_27_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[301]
            model_layers_27_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[302]
            model_layers_27_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[303]
            model_layers_27_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[304]
            model_layers_27_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[305]
            model_layers_27_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[306]
            model_layers_27_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[307]
            model_layers_27_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[308]
            model_layers_27_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[309]
            model_layers_28_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[310]
            model_layers_28_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[311]
            model_layers_28_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[312]
            model_layers_28_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[313]
            model_layers_28_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[314]
            model_layers_28_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[315]
            model_layers_28_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[316]
            model_layers_28_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[317]
            model_layers_28_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[318]
            model_layers_28_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[319]
            model_layers_28_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[320]
            model_layers_29_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[321]
            model_layers_29_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[322]
            model_layers_29_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[323]
            model_layers_29_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[324]
            model_layers_29_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[325]
            model_layers_29_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[326]
            model_layers_29_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[327]
            model_layers_29_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[328]
            model_layers_29_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[329]
            model_layers_29_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[330]
            model_layers_29_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[331]
            model_layers_30_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[332]
            model_layers_30_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[333]
            model_layers_30_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[334]
            model_layers_30_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[335]
            model_layers_30_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[336]
            model_layers_30_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[337]
            model_layers_30_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[338]
            model_layers_30_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[339]
            model_layers_30_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[340]
            model_layers_30_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[341]
            model_layers_30_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[342]
            model_layers_31_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[343]
            model_layers_31_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[344]
            model_layers_31_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[345]
            model_layers_31_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[346]
            model_layers_31_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[347]
            model_layers_31_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[348]
            model_layers_31_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[349]
            model_layers_31_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[350]
            model_layers_31_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[351]
            model_layers_31_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[352]
            model_layers_31_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[353]
            model_layers_32_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[354]
            model_layers_32_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[355]
            model_layers_32_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[356]
            model_layers_32_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[357]
            model_layers_32_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[358]
            model_layers_32_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[359]
            model_layers_32_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[360]
            model_layers_32_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[361]
            model_layers_32_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[362]
            model_layers_32_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[363]
            model_layers_32_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[364]
            model_layers_33_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[365]
            model_layers_33_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[366]
            model_layers_33_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[367]
            model_layers_33_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[368]
            model_layers_33_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[369]
            model_layers_33_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[370]
            model_layers_33_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[371]
            model_layers_33_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[372]
            model_layers_33_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[373]
            model_layers_33_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[374]
            model_layers_33_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[375]
            model_layers_34_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[376]
            model_layers_34_self_attn_c_attn_