# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    # apply_bitmask_inplace: in-place token masking of `logits`.
    #
    # For every pair (s, v) with s in [0, num_seq) and v in [0, vocab_size):
    #   row = seq_ids[s]
    #   keep logits[row, v] if bit (v % 32) of bitmask[row, v // 32] is 1,
    #   otherwise overwrite it with -340282346638528859811704183484516925440.0
    #   (the most negative finite float32). Presumably this is so that later
    #   sampling never selects that token; confirm with the caller.
    #
    # Parameters (buffers bound via T.match_buffer):
    #   var_logits  -> logits (batch_size, vocab_size). No dtype is given, so
    #                  TVM's default (float32) applies. Read and written in place.
    #   var_seq_ids -> seq_ids (num_seq,) int32. These are row indices into both
    #                  logits and bitmask. They are not bounds-checked here, so
    #                  this assumes 0 <= seq_ids[s] < batch_size (TODO confirm
    #                  at the caller).
    #   var_bitmask -> bitmask (batch_size, ceil(vocab_size / 32)) int32. It
    #                  holds one bit per vocabulary entry, packed 32 entries
    #                  per word.
    # Returns nothing. The only side effect is the write to logits.
    #
    # Already scheduled ("tir.is_scheduled") for the Vulkan target named in
    # func_attr: one thread per (s, v) pair, 1024 threads per block.
    @T.prim_func
    def apply_bitmask_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_bitmask: T.handle):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_seq = T.int32(is_size_var=True)
        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
        # Packed mask: word v // 32 of a row holds the bits for entries v..v+31.
        bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), "int32")
        # with T.block("root"):
        # Flatten (s, v) into one index f = s * vocab_size + v, then launch
        # ceil(num_seq * vocab_size / 1024) blocks of 1024 threads.
        # NOTE(review): num_seq * vocab_size is computed in int32 (both are
        # declared T.int32 above). It would overflow past 2**31 - 1 elements;
        # confirm sizes stay below that.
        for fused_s_v_0 in T.thread_binding((num_seq * vocab_size + 1023) // 1024, thread="blockIdx.x"):
            for fused_s_v_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    # Recover (s, v) from the flat index f.
                    vs = T.axis.spatial(num_seq, (fused_s_v_0 * 1024 + fused_s_v_1) // vocab_size)
                    vv = T.axis.spatial(vocab_size, (fused_s_v_0 * 1024 + fused_s_v_1) % vocab_size)
                    # Skip the tail threads of the last block, whose flat index
                    # lies past the end of the iteration space.
                    T.where(fused_s_v_0 * 1024 + fused_s_v_1 < num_seq * vocab_size)
                    T.reads(bitmask[seq_ids[vs], vv // 32], seq_ids[vs], logits[seq_ids[vs], vv])
                    T.writes(logits[seq_ids[vs], vv])
                    # shift_right on a signed int32 word may sign-extend. Masking
                    # with 1 keeps only the selected bit, so the test is still
                    # exact for bit 31.
                    logits[seq_ids[vs], vv] = T.if_then_else(T.bitwise_and(T.shift_right(bitmask[seq_ids[vs], vv // 32], vv % 32), 1) == 1, logits[seq_ids[vs], vv], T.float32(-340282346638528859811704183484516925440.0))

    # apply_logit_bias_inplace: adds per-entry biases to `logits` in place.
    #
    # For every p in [0, num_token):
    #   logits[pos2seq_id[p], token_ids[p]] += logit_bias[p]
    #
    # Parameters (buffers bound via T.match_buffer):
    #   var_logits     -> logits (batch_size, vocab_size). No dtype is given,
    #                     so TVM's default (float32) applies. Updated in place.
    #   var_pos2seq_id -> pos2seq_id (num_token,) int32. Used directly as the
    #                     logits row index. There is no seq_ids indirection
    #                     here, unlike apply_penalty_inplace.
    #   var_token_ids  -> token_ids (num_token,) int32. The logits column index.
    #   var_logit_bias -> logit_bias (num_token,). Default dtype (float32).
    # Returns nothing. The only side effect is the write to logits.
    #
    # Already scheduled for the Vulkan target named in func_attr: one thread
    # per entry p, 1024 threads per block.
    # NOTE(review): the update is a plain load/add/store with no atomics. Two
    # entries with the same (pos2seq_id, token_ids) pair would race, and one
    # bias could be lost. Presumably callers guarantee unique pairs; confirm.
    @T.prim_func
    def apply_logit_bias_inplace(var_logits: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_logit_bias: T.handle):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_token = T.int32(is_size_var=True)
        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
        logit_bias = T.match_buffer(var_logit_bias, (num_token,))
        # with T.block("root"):
        # ceil(num_token / 1024) blocks of 1024 threads, one thread per entry.
        for p0 in T.thread_binding((num_token + 1023) // 1024, thread="blockIdx.x"):
            for p1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    vp = T.axis.spatial(num_token, p0 * 1024 + p1)
                    # Skip the tail threads of the last block.
                    T.where(p0 * 1024 + p1 < num_token)
                    T.reads(logits[pos2seq_id[vp], token_ids[vp]], pos2seq_id[vp], token_ids[vp], logit_bias[vp])
                    T.writes(logits[pos2seq_id[vp], token_ids[vp]])
                    logits[pos2seq_id[vp], token_ids[vp]] = logits[pos2seq_id[vp], token_ids[vp]] + logit_bias[vp]

    # apply_penalty_inplace: applies per-sequence token penalties to `logits`
    # in place.
    #
    # For every p in [0, num_token), let s = pos2seq_id[p], row = seq_ids[s],
    # tok = token_ids[p] and x = logits[row, tok]. Then:
    #   x = x - (penalties[s, 0] + float32(token_cnt[p]) * penalties[s, 1])
    #   x = x * penalties[s, 2]  if x < 0  else  x / penalties[s, 2]
    # The second step tests the sign of the value produced by the first, so
    # the two statements must stay in this order.
    # The column meanings are presumably (presence, frequency, repetition)
    # penalties. This is inferred from the arithmetic; confirm against the caller.
    #
    # Parameters (buffers bound via T.match_buffer):
    #   var_logits     -> logits (batch_size, vocab_size). No dtype is given,
    #                     so TVM's default (float32) applies. Updated in place.
    #   var_seq_ids    -> seq_ids (num_seq,) int32. Maps s to a logits row.
    #   var_pos2seq_id -> pos2seq_id (num_token,) int32. Maps entry p to s,
    #                     which indexes both seq_ids and penalties.
    #   var_token_ids  -> token_ids (num_token,) int32. The logits column index.
    #   var_token_cnt  -> token_cnt (num_token,) int32. Cast to float32 and
    #                     scaled by penalties[s, 1].
    #   var_penalties  -> penalties (num_seq, 3). Default dtype (float32).
    # Returns nothing. The only side effect is the write to logits.
    #
    # Already scheduled for the Vulkan target named in func_attr: one thread
    # per entry p, 1024 threads per block.
    # NOTE(review):
    #   - penalties[s, 2] == 0 would make the non-negative branch divide by
    #     zero.
    #   - Two entries mapping to the same (row, tok) pair would race, because
    #     the updates are not atomic.
    # Presumably both cases are excluded by the caller; confirm.
    @T.prim_func
    def apply_penalty_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_token_cnt: T.handle, var_penalties: T.handle):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_seq = T.int32(is_size_var=True)
        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
        num_token = T.int32(is_size_var=True)
        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
        token_cnt = T.match_buffer(var_token_cnt, (num_token,), "int32")
        # One row of 3 penalty values per sequence, indexed by pos2seq_id[p].
        penalties = T.match_buffer(var_penalties, (num_seq, 3))
        # with T.block("root"):
        # ceil(num_token / 1024) blocks of 1024 threads, one thread per entry.
        for p0 in T.thread_binding((num_token + 1023) // 1024, thread="blockIdx.x"):
            for p1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    vp = T.axis.spatial(num_token, p0 * 1024 + p1)
                    # Skip the tail threads of the last block.
                    T.where(p0 * 1024 + p1 < num_token)
                    T.reads(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]], seq_ids[pos2seq_id[vp]], pos2seq_id[vp], token_ids[vp], penalties[pos2seq_id[vp], 0:3], token_cnt[vp])
                    T.writes(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]])
                    # Step 1: subtract the constant term (column 0) and the
                    # count-scaled term (column 1).
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] - (penalties[pos2seq_id[vp], 0] + T.Cast("float32", token_cnt[vp]) * penalties[pos2seq_id[vp], 1])
                    # Step 2: scale by column 2 using the updated value's sign.
                    # Negative values are multiplied; all others are divided.
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] < T.float32(0.0), logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] * penalties[pos2seq_id[vp], 2], logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] / penalties[pos2seq_id[vp], 2])

    @T.prim_func
    def argsort1(var_probs: T.handle, var_argsort_gpu_v1: T.handle, var_value_buf: T.handle, var_value_swap_buf: T.handle, var_out_swap_buf: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(), T.int32()
        probs = T.match_buffer(var_probs, (batch_size, vocab_size), offset_factor=1)
        out_buf = T.match_buffer(var_argsort_gpu_v1, (batch_size, vocab_size), "int32", align=8)
        value_buf = T.match_buffer(var_value_buf, (batch_size, vocab_size), align=8)
        value_swap_buf = T.match_buffer(var_value_swap_buf, (batch_size, vocab_size), align=8)
        out_swap_buf = T.match_buffer(var_out_swap_buf, (batch_size, vocab_size), "int32", align=8)
        with T.block("root"):
            T.reads()
            T.writes()
            with T.block("argsort_gpu"):
                T.reads()
                T.writes()
                if vocab_size > 0:
                    with T.launch_thread("threadIdx.x", 256) as threadIdx_x:
                        blockIdx_x = T.launch_thread("blockIdx.x", T.max(1, (vocab_size + 255) // 256))
                        blockIdx_y = T.launch_thread("blockIdx.y", T.max(1, batch_size))
                        if blockIdx_x * 256 + threadIdx_x < vocab_size:
                            value_buf[(blockIdx_y % batch_size * vocab_size + (blockIdx_x * 256 + threadIdx_x) + blockIdx_y // batch_size) // vocab_size, (blockIdx_y % batch_size * vocab_size + (blockIdx_x * 256 + threadIdx_x) + blockIdx_y // batch_size) % vocab_size] = probs[(blockIdx_y % batch_size * vocab_size + (blockIdx_x * 256 + threadIdx_x) + blockIdx_y // batch_size) // vocab_size, (blockIdx_y % batch_size * vocab_size + (blockIdx_x * 256 + threadIdx_x) + blockIdx_y // batch_size) % vocab_size]
                            out_buf[(blockIdx_y % batch_size * vocab_size + (blockIdx_x * 256 + threadIdx_x) + blockIdx_y // batch_size) // vocab_size, (blockIdx_y % batch_size * vocab_size + (blockIdx_x * 256 + threadIdx_x) + blockIdx_y // batch_size) % vocab_size] = blockIdx_x * 256 + threadIdx_x
                    with T.attr(0, "hand_threaded", 0):
                        threadIdx_x = T.launch_thread("threadIdx.x", 64)
                        blockIdx_x = T.launch_thread("blockIdx.x", T.max(1, (vocab_size + 127) // 128))
                        blockIdx_y = T.launch_thread("blockIdx.y", T.max(1, batch_size))
                        temp_keys_swap = T.allocate([128], "float32", "shared")
                        temp_values_swap = T.allocate([128], "int32", "shared")
                        temp_keys = T.allocate([1], "float32", "local")
                        temp_values = T.allocate([1], "int32", "local")
                        temp_cond1 = T.allocate([1], "float32", "local")
                        temp_cond2 = T.allocate([1], "float32", "local")
                        temp_keys_swap_1 = T.Buffer((128,), data=temp_keys_swap, scope="shared")
                        temp_values_swap_1 = T.Buffer((128,), "int32", data=temp_values_swap, scope="shared")
                        for i in range(2):
                            if 2 * threadIdx_x + i + blockIdx_x * 128 < vocab_size:
                                temp_keys_swap_1[2 * threadIdx_x + i] = value_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (2 * threadIdx_x + i + blockIdx_x * 128)) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (2 * threadIdx_x + i + blockIdx_x * 128)) % vocab_size]
                                temp_values_swap_1[2 * threadIdx_x + i] = out_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (2 * threadIdx_x + i + blockIdx_x * 128)) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (2 * threadIdx_x + i + blockIdx_x * 128)) % vocab_size]
                        T.tvm_storage_sync("shared")
                        for j in range(T.min(128, vocab_size - blockIdx_x * 128)):
                            if 2 * threadIdx_x + (2 * threadIdx_x + j) % 2 < T.min(128, vocab_size - blockIdx_x * 128) - 1:
                                temp_cond1_1 = T.Buffer((1,), data=temp_cond1, scope="local")
                                temp_cond1_1[0] = temp_keys_swap_1[2 * threadIdx_x + (2 * threadIdx_x + j) % 2]
                                temp_cond2_1 = T.Buffer((1,), data=temp_cond2, scope="local")
                                temp_cond2_1[0] = temp_keys_swap_1[2 * threadIdx_x + (2 * threadIdx_x + j) % 2 + 1]
                                if temp_cond1_1[0] < temp_cond2_1[0]:
                                    temp_keys_1 = T.Buffer((1,), data=temp_keys, scope="local")
                                    temp_keys_1[0] = temp_keys_swap_1[2 * threadIdx_x + (2 * threadIdx_x + j) % 2]
                                    temp_keys_swap_1[2 * threadIdx_x + (2 * threadIdx_x + j) % 2] = temp_keys_swap_1[2 * threadIdx_x + (2 * threadIdx_x + j) % 2 + 1]
                                    temp_keys_swap_1[2 * threadIdx_x + (2 * threadIdx_x + j) % 2 + 1] = temp_keys_1[0]
                                    temp_values_1 = T.Buffer((1,), "int32", data=temp_values, scope="local")
                                    temp_values_1[0] = temp_values_swap_1[2 * threadIdx_x + (2 * threadIdx_x + j) % 2]
                                    temp_values_swap_1[2 * threadIdx_x + (2 * threadIdx_x + j) % 2] = temp_values_swap_1[2 * threadIdx_x + (2 * threadIdx_x + j) % 2 + 1]
                                    temp_values_swap_1[2 * threadIdx_x + (2 * threadIdx_x + j) % 2 + 1] = temp_values_1[0]
                            T.tvm_storage_sync("shared")
                        for k in range(2):
                            if 2 * threadIdx_x + k + blockIdx_x * 128 < vocab_size:
                                value_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (2 * threadIdx_x + k + blockIdx_x * 128)) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (2 * threadIdx_x + k + blockIdx_x * 128)) % vocab_size] = temp_keys_swap_1[2 * threadIdx_x + k]
                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (2 * threadIdx_x + k + blockIdx_x * 128)) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (2 * threadIdx_x + k + blockIdx_x * 128)) % vocab_size] = temp_keys_swap_1[2 * threadIdx_x + k]
                                out_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (2 * threadIdx_x + k + blockIdx_x * 128)) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (2 * threadIdx_x + k + blockIdx_x * 128)) % vocab_size] = temp_values_swap_1[2 * threadIdx_x + k]
                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (2 * threadIdx_x + k + blockIdx_x * 128)) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (2 * threadIdx_x + k + blockIdx_x * 128)) % vocab_size] = temp_values_swap_1[2 * threadIdx_x + k]
                    for i_0 in range(T.if_then_else(T.bitwise_and(vocab_size, vocab_size - 1) == 0, 64 - (T.clz(vocab_size) - 32 + 64 - 64 + 64) - 1, 64 - (T.clz(vocab_size) - 32 + 64 - 64 + 64)) - (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)):
                        threadIdx_x = T.launch_thread("threadIdx.x", 256)
                        blockIdx_x = T.launch_thread("blockIdx.x", T.max(1, (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 1023) // 1024))
                        blockIdx_y = T.launch_thread("blockIdx.y", T.max(1, batch_size * ((vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) - 1)) // T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)))))
                        if T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) < vocab_size:
                            if (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 1023) // 1024 == 1:
                                if i_0 % 2 == 0:
                                    first = T.allocate([1], "int64", "local")
                                    mid = T.allocate([1], "int64", "local")
                                    last = T.allocate([1], "int64", "local")
                                    first_1 = T.Buffer((1,), "int64", data=first, scope="local")
                                    first_1[0] = T.Cast("int64", T.max(0, threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) - (T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)), vocab_size) - T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size))))
                                    last_1 = T.Buffer((1,), "int64", data=last, scope="local")
                                    last_1[0] = T.Cast("int64", T.min(threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256), T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) - T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)))
                                    while first_1[0] < last_1[0]:
                                        if value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) - 1) - T.shift_right(first_1[0] + last_1[0], T.int64(1)))) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) - 1) - T.shift_right(first_1[0] + last_1[0], T.int64(1)))) % T.Cast("int64", vocab_size)] <= value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + T.shift_right(first_1[0] + last_1[0], T.int64(1)))) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + T.shift_right(first_1[0] + last_1[0], T.int64(1)))) % T.Cast("int64", vocab_size)]:
                                            first_1[0] = T.shift_right(first_1[0] + last_1[0], T.int64(1)) + T.int64(1)
                                        else:
                                            last_1[0] = T.shift_right(first_1[0] + last_1[0], T.int64(1))
                                    i = T.allocate([1], "int64", "local")
                                    j = T.allocate([1], "int64", "local")
                                    i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                    i_1[0] = T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0]
                                    j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                    j_1[0] = T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256)) - last_1[0]
                                    for i_1_1 in range(T.min(T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) - T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + (T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)), vocab_size) - T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size)) - threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256), (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256)):
                                        if i_1[0] < T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + (T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) - T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size))) and j_1[0] < T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + (T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)), vocab_size) - T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size))):
                                            if value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)] <= value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]:
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_1_1)) % vocab_size] = value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_1_1)) % vocab_size] = out_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                i_1[0] = i_1[0] + T.int64(1)
                                            else:
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_1_1)) % vocab_size] = value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_1_1)) % vocab_size] = out_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                j_1[0] = j_1[0] + T.int64(1)
                                        else:
                                            if i_1[0] < T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + (T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) - T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size))):
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_1_1)) % vocab_size] = value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_1_1)) % vocab_size] = out_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                i_1[0] = i_1[0] + T.int64(1)
                                            else:
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_1_1)) % vocab_size] = value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_1_1)) % vocab_size] = out_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                j_1[0] = j_1[0] + T.int64(1)
                                else:
                                    first = T.allocate([1], "int64", "local")
                                    mid = T.allocate([1], "int64", "local")
                                    last = T.allocate([1], "int64", "local")
                                    first_1 = T.Buffer((1,), "int64", data=first, scope="local")
                                    first_1[0] = T.Cast("int64", T.max(0, threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) - (T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)), vocab_size) - T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size))))
                                    last_1 = T.Buffer((1,), "int64", data=last, scope="local")
                                    last_1[0] = T.Cast("int64", T.min(threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256), T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) - T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)))
                                    while first_1[0] < last_1[0]:
                                        if value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) - 1) - T.shift_right(first_1[0] + last_1[0], T.int64(1)))) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) - 1) - T.shift_right(first_1[0] + last_1[0], T.int64(1)))) % T.Cast("int64", vocab_size)] <= value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + T.shift_right(first_1[0] + last_1[0], T.int64(1)))) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + T.shift_right(first_1[0] + last_1[0], T.int64(1)))) % T.Cast("int64", vocab_size)]:
                                            first_1[0] = T.shift_right(first_1[0] + last_1[0], T.int64(1)) + T.int64(1)
                                        else:
                                            last_1[0] = T.shift_right(first_1[0] + last_1[0], T.int64(1))
                                    i = T.allocate([1], "int64", "local")
                                    j = T.allocate([1], "int64", "local")
                                    i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                    i_1[0] = T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0]
                                    j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                    j_1[0] = T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256)) - last_1[0]
                                    for i_2 in range(T.min(T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) - T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + (T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)), vocab_size) - T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size)) - threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256), (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256)):
                                        if i_1[0] < T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + (T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) - T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size))) and j_1[0] < T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + (T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)), vocab_size) - T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size))):
                                            if value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)] <= value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]:
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_2)) % vocab_size] = value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_2)) % vocab_size] = out_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                i_1[0] = i_1[0] + T.int64(1)
                                            else:
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_2)) % vocab_size] = value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_2)) % vocab_size] = out_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                j_1[0] = j_1[0] + T.int64(1)
                                        else:
                                            if i_1[0] < T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + (T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) - T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size))):
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_2)) % vocab_size] = value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_2)) % vocab_size] = out_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                i_1[0] = i_1[0] + T.int64(1)
                                            else:
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_2)) % vocab_size] = value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) + 255) // 256) + i_2)) % vocab_size] = out_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                j_1[0] = j_1[0] + T.int64(1)
                            else:
                                if i_0 % 2 == 0:
                                    first = T.allocate([1], "int64", "local")
                                    mid = T.allocate([1], "int64", "local")
                                    last = T.allocate([1], "int64", "local")
                                    first_1 = T.Buffer((1,), "int64", data=first, scope="local")
                                    first_1[0] = T.Cast("int64", T.max(0, blockIdx_x * 1024 - (T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)), vocab_size) - T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size))))
                                    last_1 = T.Buffer((1,), "int64", data=last, scope="local")
                                    last_1[0] = T.Cast("int64", T.min(blockIdx_x * 1024, T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) - T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)))
                                    while first_1[0] < last_1[0]:
                                        if value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024 - 1) - T.shift_right(first_1[0] + last_1[0], T.int64(1)))) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024 - 1) - T.shift_right(first_1[0] + last_1[0], T.int64(1)))) % T.Cast("int64", vocab_size)] <= value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + T.shift_right(first_1[0] + last_1[0], T.int64(1)))) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + T.shift_right(first_1[0] + last_1[0], T.int64(1)))) % T.Cast("int64", vocab_size)]:
                                            first_1[0] = T.shift_right(first_1[0] + last_1[0], T.int64(1)) + T.int64(1)
                                        else:
                                            last_1[0] = T.shift_right(first_1[0] + last_1[0], T.int64(1))
                                    if i_0 % 2 == 0:
                                        first_2 = T.allocate([1], "int64", "local")
                                        mid_1 = T.allocate([1], "int64", "local")
                                        last_2 = T.allocate([1], "int64", "local")
                                        first_3 = T.Buffer((1,), "int64", data=first_2, scope="local")
                                        first_3[0] = T.max(T.int64(0), T.Cast("int64", threadIdx_x * 4) - T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)), vocab_size)) - (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0]), T.int64(1024)))
                                        last_3 = T.Buffer((1,), "int64", data=last_2, scope="local")
                                        last_3[0] = T.min(T.Cast("int64", threadIdx_x * 4), T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size)) - (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0]), T.int64(1024)))
                                        while first_3[0] < last_3[0]:
                                            if value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0] + T.Cast("int64", threadIdx_x * 4) - T.int64(1) - T.shift_right(first_3[0] + last_3[0], T.int64(1)))) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0] + T.Cast("int64", threadIdx_x * 4) - T.int64(1) - T.shift_right(first_3[0] + last_3[0], T.int64(1)))) % T.Cast("int64", vocab_size)] <= value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0] + T.shift_right(first_3[0] + last_3[0], T.int64(1)))) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0] + T.shift_right(first_3[0] + last_3[0], T.int64(1)))) % T.Cast("int64", vocab_size)]:
                                                first_3[0] = T.shift_right(first_3[0] + last_3[0], T.int64(1)) + T.int64(1)
                                            else:
                                                last_3[0] = T.shift_right(first_3[0] + last_3[0], T.int64(1))
                                        i = T.allocate([1], "int64", "local")
                                        j = T.allocate([1], "int64", "local")
                                        i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                        i_1[0] = T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0] + first_3[0]
                                        j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                        j_1[0] = T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0] + T.Cast("int64", threadIdx_x * 4) - last_3[0]
                                        for i_3 in range(T.Cast("int32", T.min(T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size)) - (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0]), T.int64(1024)) + T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)), vocab_size)) - (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0]), T.int64(1024)) - T.Cast("int64", threadIdx_x * 4), T.int64(4)))):
                                            if i_1[0] < T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0] + T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size)) - (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0]), T.int64(1024)) and j_1[0] < T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0] + T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)), vocab_size)) - (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0]), T.int64(1024)):
                                                if value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)] <= value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]:
                                                    value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_3)) % vocab_size] = value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                    out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_3)) % vocab_size] = out_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                    i_1[0] = i_1[0] + T.int64(1)
                                                else:
                                                    value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_3)) % vocab_size] = value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                    out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_3)) % vocab_size] = out_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                    j_1[0] = j_1[0] + T.int64(1)
                                            else:
                                                if i_1[0] < T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0] + T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size)) - (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0]), T.int64(1024)):
                                                    value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_3)) % vocab_size] = value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                    out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_3)) % vocab_size] = out_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                    i_1[0] = i_1[0] + T.int64(1)
                                                else:
                                                    value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_3)) % vocab_size] = value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                    out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_3)) % vocab_size] = out_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                    j_1[0] = j_1[0] + T.int64(1)
                                    else:
                                        first_2 = T.allocate([1], "int64", "local")
                                        mid_1 = T.allocate([1], "int64", "local")
                                        last_2 = T.allocate([1], "int64", "local")
                                        first_3 = T.Buffer((1,), "int64", data=first_2, scope="local")
                                        first_3[0] = T.max(T.int64(0), T.Cast("int64", threadIdx_x * 4) - T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)), vocab_size)) - (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0]), T.int64(1024)))
                                        last_3 = T.Buffer((1,), "int64", data=last_2, scope="local")
                                        last_3[0] = T.min(T.Cast("int64", threadIdx_x * 4), T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size)) - (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0]), T.int64(1024)))
                                        while first_3[0] < last_3[0]:
                                            if value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0] + T.Cast("int64", threadIdx_x * 4) - T.int64(1) - T.shift_right(first_3[0] + last_3[0], T.int64(1)))) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0] + T.Cast("int64", threadIdx_x * 4) - T.int64(1) - T.shift_right(first_3[0] + last_3[0], T.int64(1)))) % T.Cast("int64", vocab_size)] <= value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0] + T.shift_right(first_3[0] + last_3[0], T.int64(1)))) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0] + T.shift_right(first_3[0] + last_3[0], T.int64(1)))) % T.Cast("int64", vocab_size)]:
                                                first_3[0] = T.shift_right(first_3[0] + last_3[0], T.int64(1)) + T.int64(1)
                                            else:
                                                last_3[0] = T.shift_right(first_3[0] + last_3[0], T.int64(1))
                                        i = T.allocate([1], "int64", "local")
                                        j = T.allocate([1], "int64", "local")
                                        i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                        i_1[0] = T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0] + first_3[0]
                                        j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                        j_1[0] = T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0] + T.Cast("int64", threadIdx_x * 4) - last_3[0]
                                        for i_4 in range(T.Cast("int32", T.min(T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size)) - (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0]), T.int64(1024)) + T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)), vocab_size)) - (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0]), T.int64(1024)) - T.Cast("int64", threadIdx_x * 4), T.int64(4)))):
                                            if i_1[0] < T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0] + T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size)) - (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0]), T.int64(1024)) and j_1[0] < T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0] + T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)), vocab_size)) - (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0]), T.int64(1024)):
                                                if value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)] <= value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]:
                                                    value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_4)) % vocab_size] = value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                    out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_4)) % vocab_size] = out_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                    i_1[0] = i_1[0] + T.int64(1)
                                                else:
                                                    value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_4)) % vocab_size] = value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                    out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_4)) % vocab_size] = out_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                    j_1[0] = j_1[0] + T.int64(1)
                                            else:
                                                if i_1[0] < T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0] + T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size)) - (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0]), T.int64(1024)):
                                                    value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_4)) % vocab_size] = value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                    out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_4)) % vocab_size] = out_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                    i_1[0] = i_1[0] + T.int64(1)
                                                else:
                                                    value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_4)) % vocab_size] = value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                    out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_4)) % vocab_size] = out_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                    j_1[0] = j_1[0] + T.int64(1)
                                else:
                                    first = T.allocate([1], "int64", "local")
                                    mid = T.allocate([1], "int64", "local")
                                    last = T.allocate([1], "int64", "local")
                                    first_1 = T.Buffer((1,), "int64", data=first, scope="local")
                                    first_1[0] = T.Cast("int64", T.max(0, blockIdx_x * 1024 - (T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)), vocab_size) - T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size))))
                                    last_1 = T.Buffer((1,), "int64", data=last, scope="local")
                                    last_1[0] = T.Cast("int64", T.min(blockIdx_x * 1024, T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) - T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)))
                                    while first_1[0] < last_1[0]:
                                        if value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024 - 1) - T.shift_right(first_1[0] + last_1[0], T.int64(1)))) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024 - 1) - T.shift_right(first_1[0] + last_1[0], T.int64(1)))) % T.Cast("int64", vocab_size)] <= value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + T.shift_right(first_1[0] + last_1[0], T.int64(1)))) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + T.shift_right(first_1[0] + last_1[0], T.int64(1)))) % T.Cast("int64", vocab_size)]:
                                            first_1[0] = T.shift_right(first_1[0] + last_1[0], T.int64(1)) + T.int64(1)
                                        else:
                                            last_1[0] = T.shift_right(first_1[0] + last_1[0], T.int64(1))
                                    if i_0 % 2 == 0:
                                        first_2 = T.allocate([1], "int64", "local")
                                        mid_1 = T.allocate([1], "int64", "local")
                                        last_2 = T.allocate([1], "int64", "local")
                                        first_3 = T.Buffer((1,), "int64", data=first_2, scope="local")
                                        first_3[0] = T.max(T.int64(0), T.Cast("int64", threadIdx_x * 4) - T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)), vocab_size)) - (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0]), T.int64(1024)))
                                        last_3 = T.Buffer((1,), "int64", data=last_2, scope="local")
                                        last_3[0] = T.min(T.Cast("int64", threadIdx_x * 4), T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size)) - (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0]), T.int64(1024)))
                                        while first_3[0] < last_3[0]:
                                            if value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0] + T.Cast("int64", threadIdx_x * 4) - T.int64(1) - T.shift_right(first_3[0] + last_3[0], T.int64(1)))) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0] + T.Cast("int64", threadIdx_x * 4) - T.int64(1) - T.shift_right(first_3[0] + last_3[0], T.int64(1)))) % T.Cast("int64", vocab_size)] <= value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0] + T.shift_right(first_3[0] + last_3[0], T.int64(1)))) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0] + T.shift_right(first_3[0] + last_3[0], T.int64(1)))) % T.Cast("int64", vocab_size)]:
                                                first_3[0] = T.shift_right(first_3[0] + last_3[0], T.int64(1)) + T.int64(1)
                                            else:
                                                last_3[0] = T.shift_right(first_3[0] + last_3[0], T.int64(1))
                                        i = T.allocate([1], "int64", "local")
                                        j = T.allocate([1], "int64", "local")
                                        i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                        i_1[0] = T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0] + first_3[0]
                                        j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                        j_1[0] = T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0] + T.Cast("int64", threadIdx_x * 4) - last_3[0]
                                        for i_5 in range(T.Cast("int32", T.min(T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size)) - (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0]), T.int64(1024)) + T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)), vocab_size)) - (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0]), T.int64(1024)) - T.Cast("int64", threadIdx_x * 4), T.int64(4)))):
                                            if i_1[0] < T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0] + T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size)) - (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0]), T.int64(1024)) and j_1[0] < T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0] + T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)), vocab_size)) - (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0]), T.int64(1024)):
                                                if value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)] <= value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]:
                                                    value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_5)) % vocab_size] = value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                    out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_5)) % vocab_size] = out_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                    i_1[0] = i_1[0] + T.int64(1)
                                                else:
                                                    value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_5)) % vocab_size] = value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                    out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_5)) % vocab_size] = out_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                    j_1[0] = j_1[0] + T.int64(1)
                                            else:
                                                if i_1[0] < T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0] + T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size)) - (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0]), T.int64(1024)):
                                                    value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_5)) % vocab_size] = value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                    out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_5)) % vocab_size] = out_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                    i_1[0] = i_1[0] + T.int64(1)
                                                else:
                                                    value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_5)) % vocab_size] = value_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                    out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_5)) % vocab_size] = out_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                    j_1[0] = j_1[0] + T.int64(1)
                                    else:
                                        first_2 = T.allocate([1], "int64", "local")
                                        mid_1 = T.allocate([1], "int64", "local")
                                        last_2 = T.allocate([1], "int64", "local")
                                        first_3 = T.Buffer((1,), "int64", data=first_2, scope="local")
                                        first_3[0] = T.max(T.int64(0), T.Cast("int64", threadIdx_x * 4) - T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)), vocab_size)) - (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0]), T.int64(1024)))
                                        last_3 = T.Buffer((1,), "int64", data=last_2, scope="local")
                                        last_3[0] = T.min(T.Cast("int64", threadIdx_x * 4), T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size)) - (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0]), T.int64(1024)))
                                        while first_3[0] < last_3[0]:
                                            if value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0] + T.Cast("int64", threadIdx_x * 4) - T.int64(1) - T.shift_right(first_3[0] + last_3[0], T.int64(1)))) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0] + T.Cast("int64", threadIdx_x * 4) - T.int64(1) - T.shift_right(first_3[0] + last_3[0], T.int64(1)))) % T.Cast("int64", vocab_size)] <= value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0] + T.shift_right(first_3[0] + last_3[0], T.int64(1)))) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0] + T.shift_right(first_3[0] + last_3[0], T.int64(1)))) % T.Cast("int64", vocab_size)]:
                                                first_3[0] = T.shift_right(first_3[0] + last_3[0], T.int64(1)) + T.int64(1)
                                            else:
                                                last_3[0] = T.shift_right(first_3[0] + last_3[0], T.int64(1))
                                        i = T.allocate([1], "int64", "local")
                                        j = T.allocate([1], "int64", "local")
                                        i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                        i_1[0] = T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0] + first_3[0]
                                        j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                        j_1[0] = T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0] + T.Cast("int64", threadIdx_x * 4) - last_3[0]
                                        for i_6 in range(T.Cast("int32", T.min(T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size)) - (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0]), T.int64(1024)) + T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)), vocab_size)) - (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0]), T.int64(1024)) - T.Cast("int64", threadIdx_x * 4), T.int64(4)))):
                                            if i_1[0] < T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0] + T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size)) - (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0]), T.int64(1024)) and j_1[0] < T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0] + T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)), vocab_size)) - (T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size) + blockIdx_x * 1024) - last_1[0]), T.int64(1024)):
                                                if value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)] <= value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]:
                                                    value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_6)) % vocab_size] = value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                    out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_6)) % vocab_size] = out_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                    i_1[0] = i_1[0] + T.int64(1)
                                                else:
                                                    value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_6)) % vocab_size] = value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                    out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_6)) % vocab_size] = out_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                    j_1[0] = j_1[0] + T.int64(1)
                                            else:
                                                if i_1[0] < T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0] + T.min(T.Cast("int64", T.min(T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) // 2, vocab_size)) - (T.Cast("int64", T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size)) + first_1[0]), T.int64(1024)):
                                                    value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_6)) % vocab_size] = value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                    out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_6)) % vocab_size] = out_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + i_1[0]) % T.Cast("int64", vocab_size)]
                                                    i_1[0] = i_1[0] + T.int64(1)
                                                else:
                                                    value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_6)) % vocab_size] = value_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                    out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(2, i_0 + (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) * (blockIdx_y // batch_size) + blockIdx_x * 1024 + threadIdx_x * 4 + i_6)) % vocab_size] = out_swap_buf[(T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) // T.Cast("int64", vocab_size), (T.Cast("int64", blockIdx_y % batch_size * vocab_size) + j_1[0]) % T.Cast("int64", vocab_size)]
                                                    j_1[0] = j_1[0] + T.int64(1)
                    if T.if_then_else(T.bitwise_and(vocab_size, vocab_size - 1) == 0, 64 - (T.clz(vocab_size) - 32 + 64 - 64 + 64) - 1, 64 - (T.clz(vocab_size) - 32 + 64 - 64 + 64)) > 32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1 and (T.if_then_else(T.bitwise_and(vocab_size, vocab_size - 1) == 0, 64 - (T.clz(vocab_size) - 32 + 64 - 64 + 64) - 1, 64 - (T.clz(vocab_size) - 32 + 64 - 64 + 64)) - (32 - (T.clz(128) - 32 + 64 - 64 + 32) - 1)) % 2 == 1:
                        threadIdx_x = T.launch_thread("threadIdx.x", 256)
                        blockIdx_x = T.launch_thread("blockIdx.x", T.max(1, (vocab_size + 255) // 256))
                        blockIdx_y = T.launch_thread("blockIdx.y", T.max(1, batch_size))
                        if blockIdx_x * 256 + threadIdx_x < vocab_size:
                            value_buf[(blockIdx_y * vocab_size + (blockIdx_x * 256 + threadIdx_x)) // vocab_size, (blockIdx_y * vocab_size + (blockIdx_x * 256 + threadIdx_x)) % vocab_size] = value_swap_buf[(blockIdx_y * vocab_size + (blockIdx_x * 256 + threadIdx_x)) // vocab_size, (blockIdx_y * vocab_size + (blockIdx_x * 256 + threadIdx_x)) % vocab_size]
                            out_buf[(blockIdx_y * vocab_size + (blockIdx_x * 256 + threadIdx_x)) // vocab_size, (blockIdx_y * vocab_size + (blockIdx_x * 256 + threadIdx_x)) % vocab_size] = out_swap_buf[(blockIdx_y * vocab_size + (blockIdx_x * 256 + threadIdx_x)) // vocab_size, (blockIdx_y * vocab_size + (blockIdx_x * 256 + threadIdx_x)) % vocab_size]

    @T.prim_func
    def batch_decode_paged_kv(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        # Decode-step attention over a paged KV cache.
        # Each batch element has one query row: Q is (B, 16, 128) float16.
        # It attends over the K/V tokens stored in that element's pages.
        # Writes the attention output (B, 16, 128) float16 and one
        # log-sum-exp value per query head into lse (B, 16).
        #
        # Parameters, as used in this body:
        #   _0: not referenced in this function body.
        #   pages: (max_num_pages, 2, 2, 16, 128) float16.
        #     Index 1 picks K (0) or V (1).
        #     Index 2 is the KV head (`by`).
        #     Index 3 is the token slot within a 16-token page.
        #   page_table_indptr / page_table_values:
        #     The pages of batch element b are
        #     page_table_values[page_table_indptr[b] : page_table_indptr[b + 1]].
        #   length_info[b]: added to (num_pages - 1) * 16 to get the KV length.
        #     Presumably the number of valid tokens in the last page; confirm.
        #   k_rope_pos_offset[b]: rotary position of KV row 0.
        #   q_rope_position[b]: rotary position of the query.
        #   rotary_mode == 1: applies rotate-half RoPE to Q and K inside the kernel.
        #   attn_score_scaling_factor: extra multiplier applied to every q.k score.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        B = T.int32(is_size_var=True)
        Q = T.match_buffer(Q_handle, (B, 16, 128), "float16")
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(pages_handle, (max_num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (B,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1)
        output = T.match_buffer(output_handle, (B, 16, 128), "float16")
        lse = T.match_buffer(lse_handle, (B, 16))
        # with T.block("root"):
        # 0.12751743082459868 equals log2(e) / sqrt(128).
        # Presumably this is the 1/sqrt(head_dim) softmax scale folded together
        # with the ln -> log2 conversion, because the kernel uses exp2/log2 below.
        sm_scale: T.float32 = T.float32(0.12751743082459868)
        # Grid layout:
        #   blockIdx.x is the batch element.
        #   blockIdx.y is the KV head `by`. bz is always 0 because fused_by_bz < 2.
        # Threads per block are ty (8) x tx (32) x tz (2):
        #   ty selects one of the 8 query heads that share KV head `by`
        #     (query head = by * 8 + bz * 8 + ty).
        #   tx owns the 4 head-dim elements tx * 4 .. tx * 4 + 3.
        #   tz splits each 16-row KV tile into two halves of 8 rows.
        for bx in T.thread_binding(B, thread="blockIdx.x"):
            for fused_by_bz in T.thread_binding(2, thread="blockIdx.y"):
                for ty in T.thread_binding(8, thread="threadIdx.y"):
                    for tx in T.thread_binding(32, thread="threadIdx.x"):
                        for tz in T.thread_binding(2, thread="threadIdx.z"):
                            with T.block("attn"):
                                # The Q read region spans tx * 4 - 64 .. +132 because the
                                # rotate-half RoPE reads the partner element 64 lanes away.
                                T.reads(page_table_indptr[bx:bx + 2], length_info[bx], q_rope_position[bx], Q[bx, (fused_by_bz + 1) // 2 * 8 + ty, tx * 4 - 64:tx * 4 - 64 + 132])
                                T.writes(output[bx, fused_by_bz % 2 * 8 + fused_by_bz // 2 * 8 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 2 * 8 + fused_by_bz // 2 * 8 + ty])
                                # Scratch buffers: "local" buffers are per-thread, "shared" buffers
                                # are per-block.
                                # K_smem / V_smem hold one 16-row KV tile.
                                # O_allreduce / md_allreduce are indexed [tz, ty, ...] and exchange
                                # the two tz halves' partial (O, m, d) softmax states.
                                Q_local = T.alloc_buffer((4,), "float16", scope="local")
                                kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                                K_smem = T.alloc_buffer((16, 128), "float16", scope="shared")
                                V_smem = T.alloc_buffer((16, 128), "float16", scope="shared")
                                O_allreduce = T.alloc_buffer((2, 8, 128), scope="shared")
                                md_allreduce = T.alloc_buffer((2, 8, 2), scope="shared")
                                S_reduce_local = T.alloc_buffer((1,), scope="local")
                                t0 = T.alloc_buffer((1,), scope="local")
                                S_local = T.alloc_buffer((8,), scope="local")
                                QK_local = T.alloc_buffer((4,), scope="local")
                                V_local = T.alloc_buffer((4,), "float16", scope="local")
                                m_prev = T.alloc_buffer((1,), scope="local")
                                d_prev = T.alloc_buffer((1,), scope="local")
                                other_m = T.alloc_buffer((1,), scope="local")
                                other_d = T.alloc_buffer((1,), scope="local")
                                exp_mprev = T.alloc_buffer((1,), scope="local")
                                exp_otherm = T.alloc_buffer((1,), scope="local")
                                other_o = T.alloc_buffer((4,), scope="local")
                                st_m = T.alloc_buffer((1,), scope="local")
                                st_d = T.alloc_buffer((1,), scope="local")
                                O_local = T.alloc_buffer((4,), scope="local")
                                by: T.int32 = fused_by_bz % 2
                                bz: T.int32 = fused_by_bz // 2
                                batch_idx: T.int32 = bx
                                cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx]
                                cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1]
                                # KV length = 16 tokens for every page except the last,
                                # plus length_info for the last page. It is 0 when there are no pages.
                                kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[batch_idx], 0)
                                # Online-softmax state: running max st_m and denominator st_d.
                                # -50000 acts as the "minus infinity" sentinel.
                                # The initial st_d = 1.0 is later multiplied by
                                # exp2(m_prev - st_m), which underflows to ~0 once a real score
                                # exceeds the sentinel.
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                # Load this thread's 4 query elements into Q_local.
                                # If rotary_mode == 1, first apply rotate-half RoPE:
                                #   q' = cos(f) * q[d] + sin(f) * (d < 64 ? -q[d + 64] : q[d - 64])
                                #   f  = q_rope_position * rope_scale / rope_theta ** (((2 * d) % 128) / 128)
                                for vec in T.vectorized(4):
                                    freq = T.float32()
                                    Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec + 64] * T.float16(-1.0), Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec])
                                # Walk the KV sequence in tiles of 16 rows.
                                # Global row of this thread's load = iterator * 16 + tz * 8 + ty.
                                # NOTE(review): there is no explicit barrier between this tile's
                                # reads of K_smem/V_smem and the next iteration's writes to them.
                                # Presumably TVM's shared-memory sync insertion adds one; confirm.
                                for iterator in range((kv_chunk_len[0] + 15) // 16):
                                    tile_start_s: T.int32 = tz * 8 + ty
                                    tile_start_g: T.int32 = (iterator * 2 + tz) * 8 + ty
                                    for j in range(1):
                                        with T.block("KV_load"):
                                            T.reads()
                                            T.writes()
                                            row_g: T.int32 = tile_start_g + j
                                            if row_g < kv_chunk_len[0]:
                                                # Map the logical row to (page id, slot in 16-token page).
                                                seq_offset: T.int32 = row_g
                                                page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16]
                                                page_offset: T.int32 = seq_offset % 16
                                                # K gets the same rotate-half RoPE as Q, using position
                                                # k_rope_pos_offset + row_g. V is copied unchanged.
                                                for vec in T.vectorized(4):
                                                    freq = T.float32()
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, pages[page_no, 0, by, page_offset, tx * 4 + vec + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, tx * 4 + vec])
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec]
                                            else:
                                                # Rows past the end of the sequence are zero-filled.
                                                for vec in T.vectorized(4):
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    m_prev[0] = st_m[0]
                                    # Score each of this tz-half's 8 rows.
                                    # Each thread forms a 4-element partial dot product, which is
                                    # then summed across the 32 tx threads by tvm_thread_allreduce.
                                    for j in range(8):
                                        for vec in T.vectorized(4):
                                            QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz * 8 + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale
                                        S_reduce_local[0] = T.float32(0.0)
                                        for vec in T.unroll(4):
                                            S_reduce_local[0] = S_reduce_local[0] + QK_local[vec]
                                        with T.block("block_cross_thread"):
                                            T.reads(S_reduce_local[0])
                                            T.writes(t0[0])
                                            T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                            T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx)
                                        # Out-of-range rows keep the -50000 sentinel score.
                                        S_local[j] = T.float32(-50000.0)
                                        if (iterator * 2 + tz) * 8 + j < kv_chunk_len[0]:
                                            S_local[j] = t0[0]
                                        st_m[0] = T.max(st_m[0], S_local[j])
                                    # Rescale the previous denominator and accumulator to the new
                                    # running max, then add this tile's exp2-weighted V rows.
                                    o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0])
                                    st_d[0] = st_d[0] * o_scale
                                    for j in range(8):
                                        S_local[j] = T.exp2(S_local[j] - st_m[0])
                                        st_d[0] = st_d[0] + S_local[j]
                                    for j in T.vectorized(4):
                                        O_local[j] = O_local[j] * o_scale
                                    for j in range(8):
                                        for vec in T.vectorized(4):
                                            V_local[vec] = V_smem[tz * 8 + j, tx * 4 + vec]
                                        for vec in T.vectorized(4):
                                            O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * S_local[j]
                                # Publish this tz-half's partial (O, m, d) state to shared memory.
                                for vec in T.vectorized(4):
                                    O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec]
                                md_allreduce[tz, ty, 0] = st_m[0]
                                md_allreduce[tz, ty, 1] = st_d[0]
                                T.tvm_storage_sync("shared")
                                # Merge the two tz partial states, starting from the sentinel state.
                                # Every tz thread computes the same merged result for its (ty, tx).
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                for j in range(2):
                                    m_prev[0] = st_m[0]
                                    d_prev[0] = st_d[0]
                                    other_m[0] = md_allreduce[j, ty, 0]
                                    other_d[0] = md_allreduce[j, ty, 1]
                                    for vec in T.vectorized(4):
                                        other_o[vec] = O_allreduce[j, ty, tx * 4 + vec]
                                    st_m[0] = T.max(st_m[0], other_m[0])
                                    st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0])
                                    exp_mprev[0] = T.exp2(m_prev[0] - st_m[0])
                                    exp_otherm[0] = T.exp2(other_m[0] - st_m[0])
                                    for vec in T.vectorized(4):
                                        O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0]
                                # Normalize, then write the float16 output.
                                # lse = m + log2(d). It is in the base-2 / sm_scale-scaled domain
                                # because scores were exponentiated with exp2.
                                for vec in T.vectorized(4):
                                    O_local[vec] = O_local[vec] / st_d[0]
                                for vec in T.vectorized(4):
                                    output[batch_idx, by * 8 + bz * 8 + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec])
                                lse[batch_idx, by * 8 + bz * 8 + ty] = st_m[0] + T.log2(st_d[0])

    @T.prim_func
    def batch_decode_paged_kv_sliding_window(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        # Decode-step attention over a paged KV cache, sliding-window variant.
        # Each batch element has one query row: Q is (B, 16, 128) float16.
        # It attends over the remapped K/V tokens in that element's pages.
        # Writes the attention output (B, 16, 128) float16 and per-head
        # log-sum-exp values into lse (B, 16).
        #
        # The only differences from the non-window decode kernel are:
        #   - length_info is (3, B) instead of (B,).
        #   - the KV length formula uses all three length_info rows.
        #   - logical KV rows are remapped to physical page slots (see seq_offset).
        #
        # Parameters, as used in this body:
        #   _0: not referenced in this function body.
        #   pages: (max_num_pages, 2, 2, 16, 128) float16.
        #     Index 1 picks K (0) or V (1).
        #     Index 2 is the KV head (`by`).
        #     Index 3 is the token slot within a 16-token page.
        #   page_table_indptr / page_table_values:
        #     The pages of batch element b are
        #     page_table_values[page_table_indptr[b] : page_table_indptr[b + 1]].
        #   length_info[0:3, b]: see the kv_chunk_len and seq_offset comments.
        #   rotary_mode == 1: applies rotate-half RoPE to Q and K inside the kernel.
        #   attn_score_scaling_factor: extra multiplier applied to every q.k score.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        B = T.int32(is_size_var=True)
        Q = T.match_buffer(Q_handle, (B, 16, 128), "float16")
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(pages_handle, (max_num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (3, B), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1)
        output = T.match_buffer(output_handle, (B, 16, 128), "float16")
        lse = T.match_buffer(lse_handle, (B, 16))
        # with T.block("root"):
        # 0.12751743082459868 equals log2(e) / sqrt(128).
        # Presumably this is the 1/sqrt(head_dim) softmax scale folded together
        # with the ln -> log2 conversion, because the kernel uses exp2/log2 below.
        sm_scale: T.float32 = T.float32(0.12751743082459868)
        # Grid layout:
        #   blockIdx.x is the batch element.
        #   blockIdx.y is the KV head `by`. bz is always 0 because fused_by_bz < 2.
        # Threads per block are ty (8) x tx (32) x tz (2):
        #   ty selects one of the 8 query heads that share KV head `by`
        #     (query head = by * 8 + bz * 8 + ty).
        #   tx owns the 4 head-dim elements tx * 4 .. tx * 4 + 3.
        #   tz splits each 16-row KV tile into two halves of 8 rows.
        for bx in T.thread_binding(B, thread="blockIdx.x"):
            for fused_by_bz in T.thread_binding(2, thread="blockIdx.y"):
                for ty in T.thread_binding(8, thread="threadIdx.y"):
                    for tx in T.thread_binding(32, thread="threadIdx.x"):
                        for tz in T.thread_binding(2, thread="threadIdx.z"):
                            with T.block("attn"):
                                # The Q read region spans tx * 4 - 64 .. +132 because the
                                # rotate-half RoPE reads the partner element 64 lanes away.
                                T.reads(page_table_indptr[bx:bx + 2], length_info[0:3, bx], q_rope_position[bx], Q[bx, (fused_by_bz + 1) // 2 * 8 + ty, tx * 4 - 64:tx * 4 - 64 + 132])
                                T.writes(output[bx, fused_by_bz % 2 * 8 + fused_by_bz // 2 * 8 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 2 * 8 + fused_by_bz // 2 * 8 + ty])
                                # Scratch buffers: "local" buffers are per-thread, "shared" buffers
                                # are per-block.
                                # K_smem / V_smem hold one 16-row KV tile.
                                # O_allreduce / md_allreduce are indexed [tz, ty, ...] and exchange
                                # the two tz halves' partial (O, m, d) softmax states.
                                Q_local = T.alloc_buffer((4,), "float16", scope="local")
                                kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                                K_smem = T.alloc_buffer((16, 128), "float16", scope="shared")
                                V_smem = T.alloc_buffer((16, 128), "float16", scope="shared")
                                O_allreduce = T.alloc_buffer((2, 8, 128), scope="shared")
                                md_allreduce = T.alloc_buffer((2, 8, 2), scope="shared")
                                S_reduce_local = T.alloc_buffer((1,), scope="local")
                                t0 = T.alloc_buffer((1,), scope="local")
                                S_local = T.alloc_buffer((8,), scope="local")
                                QK_local = T.alloc_buffer((4,), scope="local")
                                V_local = T.alloc_buffer((4,), "float16", scope="local")
                                m_prev = T.alloc_buffer((1,), scope="local")
                                d_prev = T.alloc_buffer((1,), scope="local")
                                other_m = T.alloc_buffer((1,), scope="local")
                                other_d = T.alloc_buffer((1,), scope="local")
                                exp_mprev = T.alloc_buffer((1,), scope="local")
                                exp_otherm = T.alloc_buffer((1,), scope="local")
                                other_o = T.alloc_buffer((4,), scope="local")
                                st_m = T.alloc_buffer((1,), scope="local")
                                st_d = T.alloc_buffer((1,), scope="local")
                                O_local = T.alloc_buffer((4,), scope="local")
                                by: T.int32 = fused_by_bz % 2
                                bz: T.int32 = fused_by_bz // 2
                                batch_idx: T.int32 = bx
                                cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx]
                                cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1]
                                # KV length = 16 tokens for every page except the last, plus
                                # length_info[0] - length_info[1] + length_info[2].
                                # It is 0 when there are no pages.
                                # Presumed meaning of the rows (TODO confirm against the caller):
                                #   [0] = tokens in the last page
                                #   [1] = positions skipped by the sliding window
                                #   [2] = leading "sink" tokens that are kept
                                kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, batch_idx] - length_info[1, batch_idx] + length_info[2, batch_idx], 0)
                                # Online-softmax state: running max st_m and denominator st_d.
                                # -50000 acts as the "minus infinity" sentinel.
                                # The initial st_d = 1.0 is later multiplied by
                                # exp2(m_prev - st_m), which underflows to ~0 once a real score
                                # exceeds the sentinel.
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                # Load this thread's 4 query elements into Q_local.
                                # If rotary_mode == 1, first apply rotate-half RoPE:
                                #   q' = cos(f) * q[d] + sin(f) * (d < 64 ? -q[d + 64] : q[d - 64])
                                #   f  = q_rope_position * rope_scale / rope_theta ** (((2 * d) % 128) / 128)
                                for vec in T.vectorized(4):
                                    freq = T.float32()
                                    Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec + 64] * T.float16(-1.0), Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec])
                                # Walk the KV sequence in tiles of 16 rows.
                                # Logical row of this thread's load = iterator * 16 + tz * 8 + ty.
                                # NOTE(review): there is no explicit barrier between this tile's
                                # reads of K_smem/V_smem and the next iteration's writes to them.
                                # Presumably TVM's shared-memory sync insertion adds one; confirm.
                                for iterator in range((kv_chunk_len[0] + 15) // 16):
                                    tile_start_s: T.int32 = tz * 8 + ty
                                    tile_start_g: T.int32 = (iterator * 2 + tz) * 8 + ty
                                    for j in range(1):
                                        with T.block("KV_load"):
                                            T.reads()
                                            T.writes()
                                            row_g: T.int32 = tile_start_g + j
                                            if row_g < kv_chunk_len[0]:
                                                # Map the logical row row_g to a physical token slot:
                                                #   rows < length_info[2] map to themselves;
                                                #   later rows move forward by
                                                #   length_info[1] - length_info[2], skipping those slots.
                                                seq_offset: T.int32 = T.if_then_else(row_g < length_info[2, batch_idx], row_g, row_g - length_info[2, batch_idx] + length_info[1, batch_idx])
                                                page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16]
                                                page_offset: T.int32 = seq_offset % 16
                                                # K gets rotate-half RoPE at position
                                                # k_rope_pos_offset + row_g. Note that this uses the
                                                # logical row, not seq_offset. V is copied unchanged.
                                                for vec in T.vectorized(4):
                                                    freq = T.float32()
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, pages[page_no, 0, by, page_offset, tx * 4 + vec + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, tx * 4 + vec])
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec]
                                            else:
                                                # Rows past the end of the sequence are zero-filled.
                                                for vec in T.vectorized(4):
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    m_prev[0] = st_m[0]
                                    # Score each of this tz-half's 8 rows.
                                    # Each thread forms a 4-element partial dot product, which is
                                    # then summed across the 32 tx threads by tvm_thread_allreduce.
                                    for j in range(8):
                                        for vec in T.vectorized(4):
                                            QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz * 8 + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale
                                        S_reduce_local[0] = T.float32(0.0)
                                        for vec in T.unroll(4):
                                            S_reduce_local[0] = S_reduce_local[0] + QK_local[vec]
                                        with T.block("block_cross_thread"):
                                            T.reads(S_reduce_local[0])
                                            T.writes(t0[0])
                                            T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                            T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx)
                                        # Out-of-range rows keep the -50000 sentinel score.
                                        S_local[j] = T.float32(-50000.0)
                                        if (iterator * 2 + tz) * 8 + j < kv_chunk_len[0]:
                                            S_local[j] = t0[0]
                                        st_m[0] = T.max(st_m[0], S_local[j])
                                    # Rescale the previous denominator and accumulator to the new
                                    # running max, then add this tile's exp2-weighted V rows.
                                    o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0])
                                    st_d[0] = st_d[0] * o_scale
                                    for j in range(8):
                                        S_local[j] = T.exp2(S_local[j] - st_m[0])
                                        st_d[0] = st_d[0] + S_local[j]
                                    for j in T.vectorized(4):
                                        O_local[j] = O_local[j] * o_scale
                                    for j in range(8):
                                        for vec in T.vectorized(4):
                                            V_local[vec] = V_smem[tz * 8 + j, tx * 4 + vec]
                                        for vec in T.vectorized(4):
                                            O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * S_local[j]
                                # Publish this tz-half's partial (O, m, d) state to shared memory.
                                for vec in T.vectorized(4):
                                    O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec]
                                md_allreduce[tz, ty, 0] = st_m[0]
                                md_allreduce[tz, ty, 1] = st_d[0]
                                T.tvm_storage_sync("shared")
                                # Merge the two tz partial states, starting from the sentinel state.
                                # Every tz thread computes the same merged result for its (ty, tx).
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                for j in range(2):
                                    m_prev[0] = st_m[0]
                                    d_prev[0] = st_d[0]
                                    other_m[0] = md_allreduce[j, ty, 0]
                                    other_d[0] = md_allreduce[j, ty, 1]
                                    for vec in T.vectorized(4):
                                        other_o[vec] = O_allreduce[j, ty, tx * 4 + vec]
                                    st_m[0] = T.max(st_m[0], other_m[0])
                                    st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0])
                                    exp_mprev[0] = T.exp2(m_prev[0] - st_m[0])
                                    exp_otherm[0] = T.exp2(other_m[0] - st_m[0])
                                    for vec in T.vectorized(4):
                                        O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0]
                                # Normalize, then write the float16 output.
                                # lse = m + log2(d). It is in the base-2 / sm_scale-scaled domain
                                # because scores were exponentiated with exp2.
                                for vec in T.vectorized(4):
                                    O_local[vec] = O_local[vec] / st_d[0]
                                for vec in T.vectorized(4):
                                    output[batch_idx, by * 8 + bz * 8 + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec])
                                lse[batch_idx, by * 8 + bz * 8 + ty] = st_m[0] + T.log2(st_d[0])

    # Batched prefill attention over a paged KV cache (scheduled TIR, Vulkan target).
    #
    # Layouts, as fixed by the buffer shapes and indexing in this function:
    #   q / output: (total_len, 16, 128) float16 -- packed query tokens x 16 query heads x
    #               head_dim 128; sequence b owns tokens q_indptr[b] .. q_indptr[b + 1] - 1.
    #   pages:      (max_num_pages, 2, 2, 16, 128) float16, indexed as
    #               pages[page_no, 0 (K) / 1 (V), kv_head, slot_in_page, dim]; 16 slots per page.
    #   page_values[page_indptr[b] : page_indptr[b + 1]]: page ids of sequence b, in order.
    #   length_info[b]: number of valid slots in sequence b's last page.
    #   lse:        (total_len, 16) float32, written as m + log2(d) (base-2 log-sum-exp).
    # Head grouping: blockIdx.y (`by`, 2 values) selects KV head `by`, which serves query
    # heads by * 8 .. by * 8 + 7. A tile is 32 (token, head) rows; tile row r maps to
    # token (LH_start + r) // 8 of the sequence and query head by * 8 + (LH_start + r) % 8.
    # Work split: each of the 16 blocks along blockIdx.x starts at global tile bx and
    # advances 16 tiles per outer iteration, skipping across sequences as needed.
    # Softmax: online (running max / running denominator) recurrence in base 2; scores are
    # pre-scaled by 0.12751743082459868 == log2(e) / sqrt(128) so exp2 can be used.
    # Parameters not evident from the layouts above:
    #   _0: not referenced in this body.
    #   causal: when > 0, applies a causal mask aligned to the end of the KV sequence.
    #   rotary_mode: when == 1, rotary embedding is applied to Q and K while loading them.
    #   rope_scale, rope_theta: freq = pos * rope_scale / rope_theta ** ((2 * j % 128) / 128).
    #   attn_score_scaling_factor: extra multiplier applied to every Q.K score.
    @T.prim_func
    def batch_prefill_paged_kv(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (batch_size,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (total_len, 16))
        # with T.block("root"):
        # Launch: 16 x 2 blocks (blockIdx.y = KV head), 4 x 32 = 128 threads per block.
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            # Per-thread scalars (scope="local") for the tile walk.
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            # NOTE(review): `iterator` is allocated but never used in this function.
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            # Shared-memory tiles: 32 Q rows, a 32-token K/V chunk, and the 32x32 score tile.
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            # Unnormalized output accumulator for the 32 tile rows.
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            # Per-row online-softmax state: running max m, the max before the
                            # current chunk (m_prev_smem), and running denominator d.
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            # Start at global tile bx, positioned in sequence 0. A sequence has
                            # (query tokens * 8 heads) rows, i.e. ceil(rows / 32) tiles.
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                # Skip whole sequences until tile_id is a local tile index
                                # inside sequence batch_idx.
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    # KV length = 16 slots for every page except the last, plus
                                    # length_info[b_idx] slots in the last page; 0 if no pages.
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[b_idx], 0)
                                    T.tvm_storage_sync("shared")
                                    # Reset per-row state. -50000 is a finite stand-in for -inf;
                                    # with d = 1, the first rescale multiplies d by
                                    # exp2(-50000 - m_new), which presumably underflows to 0
                                    # for realistic scores -- confirm if scores can be very negative.
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    # Zero the output accumulator.
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Load the Q tile. When rotary_mode == 1, apply the half-split
                                    # rotation (dim j paired with j +/- 64) at position
                                    # q_rope_position[cur_L]. Rows past the sequence end are zero.
                                    for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_0, lj_0_0 in T.grid(2, 4):
                                                for lj_1 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, li_0 * 16 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 8)
                                                        j = T.axis.spatial(128, lj_0_0 * 32 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 8 * 4 + lj_1)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Walk the KV sequence in chunks of 32 tokens.
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        # Load the K chunk from pages[page_no, 0, by, slot, :]. Rotary
                                        # (rotary_mode == 1) uses position k_rope_pos_offset[b_idx] + cur_L.
                                        # Tokens past kv_chunk_len are zero-filled.
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                freq = T.float32()
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, pages[page_no, 0, by, page_offset, j + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, j - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, j])
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        # Load the V chunk from pages[page_no, 1, by, slot, :]; zero-fill past kv_chunk_len.
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        # S = Q . K^T * attn_score_scaling_factor * 0.12751743082459868,
                                        # where the constant equals log2(e) / sqrt(128), so exp2(S)
                                        # below equals exp(Q.K * attn_score_scaling_factor / sqrt(128)).
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(2):
                                                        for lj_1_0_init in T.unroll(1):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("S_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                                    j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(S_local[i, j])
                                                                    S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0 in range(8):
                                                        for li_1 in T.unroll(2):
                                                            for lj_1_0 in T.unroll(1):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    for lk_1 in range(16):
                                                                        with T.block("S_gemm_update"):
                                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                            k = T.axis.reduce(128, lk_0 * 16 + lk_1)
                                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                                            T.writes(S_local[i, j])
                                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        # Publish scores to shared memory so a single thread can scan each row.
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1 in range(2):
                                                    for lj_1_0 in T.unroll(1):
                                                        for lj_1_1 in T.vectorized(4):
                                                            with T.block("S_store"):
                                                                i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                T.reads(S_local[i, j])
                                                                T.writes(S_smem[i, j])
                                                                S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        # Online softmax, step 1. One thread per tile row: threads with
                                        # ty * 32 + tx < 32, i.e. only ty == 0. Compute the new running
                                        # max over unmasked columns and rescale the old denominator.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    # Query token index (within the sequence) of this tile row.
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        # Mask: with causal > 0, KV index k = L_kv_start + j is visible
                                                        # iff k <= kv_chunk_len - qo_len + row_ (causal mask aligned to
                                                        # the end of the KV sequence, qo_len = q_indptr[b+1] - q_indptr[b]);
                                                        # otherwise visible iff k < kv_chunk_len.
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        # Step 2: overwrite scores with probabilities exp2(s - m_new); masked
                                        # entries become exp2(-50000 - m_new) (same mask as step 1).
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        # Step 3: add this chunk's probabilities to the denominator and
                                        # publish the new max, denominator and previous max for the O update.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        # O = O * exp2(m_prev - m) + P . V. First rescale the accumulator
                                        # to the new running max, then add this chunk's contribution.
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(4):
                                                        for lj_1_0_init in T.unroll(2):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("O_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(O_local[i, j])
                                                                    O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1 in T.grid(2, 16):
                                                        for li_1 in T.unroll(4):
                                                            for lj_1_0 in T.unroll(2):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    with T.block("O_gemm_update"):
                                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                                        j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                                        k = T.axis.reduce(32, lk_0 * 16 + lk_1)
                                                                        T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j])
                                                                        T.writes(O_local[i, j])
                                                                        O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
                                    # Normalize by the final denominator and store rows that belong to
                                    # this sequence (tile rows past the last token are dropped).
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_store"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                            T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                            cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                            if cur_L < q_indptr[b_idx + 1]:
                                                                output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    # lse = m + log2(d) = log2(sum_k exp2(s_k)) over the pre-scaled scores.
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    # Advance to this block's next tile (16 = number of blocks along blockIdx.x).
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func
    def batch_prefill_paged_kv_sliding_window(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (3, batch_size), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (total_len, 16))
        # with T.block("root"):
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, b_idx] - length_info[1, b_idx] + length_info[2, b_idx], 0)
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_0, lj_0_0 in T.grid(2, 4):
                                                for lj_1 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, li_0 * 16 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 8)
                                                        j = T.axis.spatial(128, lj_0_0 * 32 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 8 * 4 + lj_1)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, b_idx] + length_info[1, b_idx])
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                freq = T.float32()
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, pages[page_no, 0, by, page_offset, j + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, j - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, j])
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, b_idx] + length_info[1, b_idx])
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(2):
                                                        for lj_1_0_init in T.unroll(1):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("S_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                                    j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(S_local[i, j])
                                                                    S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0 in range(8):
                                                        for li_1 in T.unroll(2):
                                                            for lj_1_0 in T.unroll(1):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    for lk_1 in range(16):
                                                                        with T.block("S_gemm_update"):
                                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                            k = T.axis.reduce(128, lk_0 * 16 + lk_1)
                                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                                            T.writes(S_local[i, j])
                                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1 in range(2):
                                                    for lj_1_0 in T.unroll(1):
                                                        for lj_1_1 in T.vectorized(4):
                                                            with T.block("S_store"):
                                                                i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                T.reads(S_local[i, j])
                                                                T.writes(S_smem[i, j])
                                                                S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(4):
                                                        for lj_1_0_init in T.unroll(2):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("O_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(O_local[i, j])
                                                                    O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1 in T.grid(2, 16):
                                                        for li_1 in T.unroll(4):
                                                            for lj_1_0 in T.unroll(2):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    with T.block("O_gemm_update"):
                                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                                        j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                                        k = T.axis.reduce(32, lk_0 * 16 + lk_1)
                                                                        T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j])
                                                                        T.writes(O_local[i, j])
                                                                        O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_store"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                            T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                            cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                            if cur_L < q_indptr[b_idx + 1]:
                                                                output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func
    def batch_prefill_ragged_kv(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_k_rope_pos_offset: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        qo_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (qo_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        kv_len = T.int32(is_size_var=True)
        k = T.match_buffer(var_k, (kv_len, 2, 128), "float16")
        v = T.match_buffer(var_v, (kv_len, 2, 128), "float16")
        kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (qo_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (qo_len, 16))
        # with T.block("root"):
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((128, 32), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_0, lj_0_0 in T.grid(2, 4):
                                                for lj_1 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, li_0 * 16 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 8)
                                                        j = T.axis.spatial(128, lj_0_0 * 32 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 8 * 4 + lj_1)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        L_kv_base: T.int32 = kv_indptr[b_idx]
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads(kv_chunk_len[0], k_rope_pos_offset[b_idx], k[L_kv_base + L_kv_start + i, by, j - 64:j - 64 + 129])
                                                            T.writes(K_smem[j, i])
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                freq = T.float32()
                                                                K_smem[j, i] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", k[L_kv_base + cur_L, by, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, k[L_kv_base + cur_L, by, j + 64] * T.float16(-1.0), k[L_kv_base + cur_L, by, j - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), k[L_kv_base + cur_L, by, j])
                                                            else:
                                                                K_smem[j, i] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                V_smem[i, j] = v[L_kv_base + cur_L, by, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:128, 0:32])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(2):
                                                        for lj_1_0_init in T.unroll(1):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("S_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                                    j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(S_local[i, j])
                                                                    S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0 in range(8):
                                                        for li_1 in T.unroll(2):
                                                            for lj_1_0 in T.unroll(1):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    for lk_1 in range(16):
                                                                        with T.block("S_gemm_update"):
                                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                            k_1 = T.axis.reduce(128, lk_0 * 16 + lk_1)
                                                                            T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[k_1, j])
                                                                            T.writes(S_local[i, j])
                                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * T.Cast("float32", K_smem[k_1, j]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1 in range(2):
                                                    for lj_1_0 in T.unroll(1):
                                                        for lj_1_1 in T.vectorized(4):
                                                            with T.block("S_store"):
                                                                i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                T.reads(S_local[i, j])
                                                                T.writes(S_smem[i, j])
                                                                S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(4):
                                                        for lj_1_0_init in T.unroll(2):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("O_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(O_local[i, j])
                                                                    O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1 in T.grid(2, 16):
                                                        for li_1 in T.unroll(4):
                                                            for lj_1_0 in T.unroll(2):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    with T.block("O_gemm_update"):
                                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                                        j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                                        k_1 = T.axis.reduce(32, lk_0 * 16 + lk_1)
                                                                        T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j])
                                                                        T.writes(O_local[i, j])
                                                                        O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_store"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                            T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                            cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                            if cur_L < q_indptr[b_idx + 1]:
                                                                output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func
    def batch_tree_attn(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_mn_indptr: T.handle, var_mask: T.handle, var_output: T.handle, var_lse: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32, batch_size: T.int32):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        qo_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (qo_len, 16, 128), "float16")
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        kv_len = T.int32(is_size_var=True)
        k = T.match_buffer(var_k, (kv_len, 2, 128), "float16")
        v = T.match_buffer(var_v, (kv_len, 2, 128), "float16")
        kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1)
        mn_indptr = T.match_buffer(var_mn_indptr, (batch_size + 1,), "int32", offset_factor=1)
        tree_size = T.int32(is_size_var=True)
        mask = T.match_buffer(var_mask, (tree_size, 2), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (qo_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (qo_len, 16))
        # with T.block("root"):
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_init"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads()
                                                    T.writes(O_local[i, j])
                                                    O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_lj_fused_0 in range(8):
                                        for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_lj_fused_3 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 128)
                                                        j = T.axis.spatial(128, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 128)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        L_kv_base: T.int32 = kv_indptr[b_idx]
                                        for lz_ly_fused_0 in range(8):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("KV_load"):
                                                            i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 128)
                                                            j = T.axis.spatial(128, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 128)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_base + L_kv_start + i
                                                            if L_kv_start + i < kv_chunk_len[0]:
                                                                freq = T.float32()
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", k[cur_L, by, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, k[cur_L, by, j + 64] * T.float16(-1.0), k[cur_L, by, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), k[cur_L, by, j])
                                                                V_smem[i, j] = v[cur_L, by, j]
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(2, 4):
                                                        with T.block("S_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_init)
                                                            T.reads()
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, li_1, lj_1, lk_1 in T.grid(16, 2, 4, 8):
                                                        with T.block("S_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                            k_1 = T.axis.reduce(128, lk_0 * 8 + lk_1)
                                                            T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[j, k_1])
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * T.Cast("float32", K_smem[j, k_1]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1, lj_1 in T.grid(2, 4):
                                                    with T.block("S_store"):
                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                        j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                        T.reads(S_local[i, j])
                                                        T.writes(S_smem[i, j])
                                                        S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], mn_indptr[b_idx:b_idx + 2], mask[T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]):T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) or mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 0] and mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 1]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], mn_indptr[b_idx:b_idx + 2], mask[T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]):T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) or mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 0] and mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 1]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(4, 8):
                                                        with T.block("O_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_init)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1, li_1, lj_1 in T.grid(4, 8, 4, 8):
                                                        with T.block("O_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                            k_1 = T.axis.reduce(32, lk_0 * 8 + lk_1)
                                                            T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j])
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_store"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                    T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func
    def batch_verify_on_gpu_single_kernel(var_draft_probs: T.handle, var_draft_tokens: T.handle, var_model_probs: T.handle, var_token_tree_first_child: T.handle, var_token_tree_next_sibling: T.handle, var_uniform_samples: T.handle, var_token_tree_parent_ptr: T.handle):
        # Walk a token tree per batch element and accept/reject draft tokens
        # against model probabilities, using one thread block per element `b`.
        #
        # Starting at node token_tree_parent_ptr[b], each child of the current
        # parent is tried in turn (first child, then the next-sibling chain).
        # Thread 0 accepts the child's draft token `tok` when
        #     model_probs[parent, tok] >= uniform_samples[child] * draft_probs[child, tok]
        # and broadcasts the decision to the block through `pred_shared`.
        #   * accept: descend (parent <- child, child <- first_child[child]).
        #   * reject: t0 = sum_v max(model_probs[parent, v] - draft_probs[child, v], 0),
        #     computed with a block-wide allreduce.
        #       - If t0 < 1e-7, the child is treated as accepted (descend).
        #       - Otherwise model_probs[parent, :] is overwritten IN PLACE with
        #         the normalized residual max(p - q, 0) / t0, and the next
        #         sibling is tried.
        # The walk stops when child_ptr == -1. Thread 0 then writes the node
        # where the walk stopped back into token_tree_parent_ptr[b] (in place).
        #
        # Buffer shapes (from the match_buffer calls below):
        #   draft_probs / model_probs: (num_nodes, vocab_size), float32 (default dtype)
        #   draft_tokens, tree arrays, uniform_samples: (num_nodes,)
        #   token_tree_parent_ptr: (nbatch,)
        # NOTE(review): -1 in token_tree_first_child / token_tree_next_sibling
        # presumably means "no such node". This is inferred from the
        # `child_ptr[0] == -1` exit test; confirm with the tree's producer.
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        num_nodes, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        draft_probs = T.match_buffer(var_draft_probs, (num_nodes, vocab_size))
        draft_tokens = T.match_buffer(var_draft_tokens, (num_nodes,), "int32")
        model_probs = T.match_buffer(var_model_probs, (num_nodes, vocab_size))
        token_tree_first_child = T.match_buffer(var_token_tree_first_child, (num_nodes,), "int32")
        token_tree_next_sibling = T.match_buffer(var_token_tree_next_sibling, (num_nodes,), "int32")
        uniform_samples = T.match_buffer(var_uniform_samples, (num_nodes,))
        nbatch = T.int32(is_size_var=True)
        token_tree_parent_ptr = T.match_buffer(var_token_tree_parent_ptr, (nbatch,), "int32")
        # with T.block("root"):
        # Per-thread scalar state (scope="local"). pred_shared (scope="shared")
        # is the single flag that broadcasts thread 0's decision to the block.
        child_ptr = T.alloc_buffer((1,), "int32", scope="local")
        parent_ptr = T.alloc_buffer((1,), "int32", scope="local")
        child_token = T.alloc_buffer((1,), "int32", scope="local")
        done = T.alloc_buffer((1,), "bool", scope="local")
        psum = T.alloc_buffer((1,), scope="local")
        t0 = T.alloc_buffer((1,), scope="local")
        model_prob_local = T.alloc_buffer((1,), scope="local")
        draft_prob_local = T.alloc_buffer((1,), scope="local")
        p_child = T.alloc_buffer((1,), scope="local")
        q_child = T.alloc_buffer((1,), scope="local")
        uniform_sample = T.alloc_buffer((1,), scope="local")
        pred_shared = T.alloc_buffer((1,), "bool", scope="shared")
        pred_local = T.alloc_buffer((1,), "bool", scope="local")
        for _bx in T.thread_binding(nbatch, thread="blockIdx.x"):
            # One block per batch element; its 1024 threads share the vocab loops.
            for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("CTA"):
                    b, tx = T.axis.remap("SS", [_bx, _tx])
                    T.reads(token_tree_parent_ptr[b], token_tree_first_child[T.min(parent_ptr[0], child_ptr[0]):T.min(parent_ptr[0], child_ptr[0]) + (T.max(parent_ptr[0], child_ptr[0]) + 1 - T.min(parent_ptr[0], child_ptr[0]))], parent_ptr[0], done[0], child_ptr[0], draft_tokens[child_ptr[0]], model_probs[parent_ptr[0], T.min(child_token[0], tx):T.min(child_token[0], tx) + (T.max(child_token[0], (vocab_size + 1023) // 1024 * 1024 + tx - 1024) + 1 - T.min(child_token[0], tx))], child_token[0], draft_probs[child_ptr[0], T.min(child_token[0], tx):T.min(child_token[0], tx) + (T.max(child_token[0], (vocab_size + 1023) // 1024 * 1024 + tx - 1024) + 1 - T.min(child_token[0], tx))], uniform_samples[child_ptr[0]], p_child[0], uniform_sample[0], q_child[0], pred_shared[0], pred_local[0], model_prob_local[0], draft_prob_local[0], psum[0], t0[0], token_tree_next_sibling[child_ptr[0]])
                    T.writes(parent_ptr[0], child_ptr[0], done[0], child_token[0], p_child[0], q_child[0], uniform_sample[0], pred_shared[0], pred_local[0], psum[0], model_prob_local[0], draft_prob_local[0], t0[0], model_probs[parent_ptr[0], tx:tx + ((vocab_size + 1023) // 1024 * 1024 - 1023)], token_tree_parent_ptr[b])
                    parent_ptr[0] = token_tree_parent_ptr[b]
                    child_ptr[0] = token_tree_first_child[parent_ptr[0]]
                    done[0] = T.bool(False)
                    while not done[0]:
                        T.tvm_storage_sync("shared")
                        if child_ptr[0] == -1:
                            # No child left to try: end the walk.
                            done[0] = T.bool(True)
                            T.tvm_storage_sync("shared")
                        else:
                            if tx == 0:
                                # Acceptance test for the child's draft token,
                                # evaluated once by thread 0.
                                child_token[0] = draft_tokens[child_ptr[0]]
                                p_child[0] = model_probs[parent_ptr[0], child_token[0]]
                                q_child[0] = draft_probs[child_ptr[0], child_token[0]]
                                uniform_sample[0] = uniform_samples[child_ptr[0]]
                                pred_shared[0] = p_child[0] >= uniform_sample[0] * q_child[0]
                            # Make thread 0's decision visible to every thread.
                            T.tvm_storage_sync("shared")
                            pred_local[0] = pred_shared[0]
                            if pred_local[0]:
                                # Accepted: descend to the child's first child.
                                parent_ptr[0] = child_ptr[0]
                                child_ptr[0] = token_tree_first_child[child_ptr[0]]
                            else:
                                # Rejected: each thread sums its share of the
                                # residual max(p - q, 0), stepping by 1024.
                                psum[0] = T.float32(0.0)
                                for i in range((vocab_size + 1023) // 1024):
                                    if i * 1024 + tx < vocab_size:
                                        model_prob_local[0] = model_probs[parent_ptr[0], i * 1024 + tx]
                                        draft_prob_local[0] = draft_probs[child_ptr[0], i * 1024 + tx]
                                        model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], T.float32(0.0))
                                        psum[0] = psum[0] + model_prob_local[0]
                                # Block-wide sum of psum into t0; all threads read t0 below.
                                with T.block("block_cross_thread"):
                                    T.reads(psum[0])
                                    T.writes(t0[0])
                                    T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                    T.tvm_thread_allreduce(T.uint32(1), psum[0], T.bool(True), t0[0], tx)
                                if t0[0] < T.float32(9.9999999999999995e-08):
                                    # Residual mass is ~0: treat the child as accepted.
                                    parent_ptr[0] = child_ptr[0]
                                    child_ptr[0] = token_tree_first_child[child_ptr[0]]
                                else:
                                    # Overwrite the parent's row of model_probs with the
                                    # normalized residual, then try the next sibling.
                                    for i in range((vocab_size + 1023) // 1024):
                                        if i * 1024 + tx < vocab_size:
                                            model_prob_local[0] = model_probs[parent_ptr[0], i * 1024 + tx]
                                            draft_prob_local[0] = draft_probs[child_ptr[0], i * 1024 + tx]
                                            model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], T.float32(0.0))
                                            model_probs[parent_ptr[0], i * 1024 + tx] = model_prob_local[0] / t0[0]
                                    child_ptr[0] = token_tree_next_sibling[child_ptr[0]]
                    if tx == 0:
                        # Write the node where the walk stopped back for this batch element.
                        token_tree_parent_ptr[b] = parent_ptr[0]

    @T.prim_func
    def chunk_lse(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle):
        # Per-(row, chunk) partial log-sum-exp of temperature-scaled values of A.
        #
        # Row v0 of A is split into chunks of 4096 entries; chunk v1 covers
        # A[v0, v1*4096 : v1*4096 + 4096]. Entries whose index is >= vocab_size
        # are masked to -3.4028e38 (float32 lowest), so they never win the max
        # and contribute 0 to the sum. For each (v0, v1):
        #   x = A / temperature[v0]  if temperature[v0] > 1e-5  else A
        #   chunked_max[v0, v1] = max(x)
        #   if temperature[v0] > 1e-5:
        #       chunked_sum[v0, v1] = log(sum(exp(x - max(x))))
        #   else:
        #       chunked_sum[v0, v1] = count(x == max(x))   # raw count, no log
        # NOTE(review): presumably num_chunks == ceil(vocab_size / 4096) and a
        # later kernel merges the chunks into a full-row result; neither is
        # enforced here.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        A = T.match_buffer(var_A, (batch_size, vocab_size))
        temperature = T.match_buffer(var_temperature, (batch_size,))
        num_chunks = T.int32(is_size_var=True)
        chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks))
        chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks))
        # with T.block("root"):
        # Block-wide accumulators (scope="shared") for the chunk max and sum.
        temp_max_shared = T.alloc_buffer((batch_size, num_chunks), scope="shared")
        temp_sum_shared = T.alloc_buffer((batch_size, num_chunks), scope="shared")
        # One block per (row, chunk) pair: v0 = fused // num_chunks, v1 = fused % num_chunks.
        for ax0_ax1_fused in T.thread_binding(batch_size * num_chunks, thread="blockIdx.x"):
            # Pass 1: max over the chunk (64 threads x 64 serial steps = 4096 entries).
            for ax0, ax1 in T.grid(1, 1):
                for ax2_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                    for ax2_fused_0 in T.serial(64, annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                        with T.block("max"):
                            v0 = T.axis.spatial(batch_size, ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0)
                            v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks + ax1)
                            v2 = T.axis.reduce(4096, ax2_fused_0 * 64 + ax2_fused_1)
                            T.reads(temperature[v0], A[v0, v1 * 4096 + v2])
                            T.writes(temp_max_shared[v0, v1])
                            with T.init():
                                temp_max_shared[v0, v1] = T.float32(-340282346638528859811704183484516925440.0)
                            temp_max_shared[v0, v1] = T.max(temp_max_shared[v0, v1], T.if_then_else(v1 * 4096 + v2 < vocab_size, T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), A[v0, v1 * 4096 + v2] / temperature[v0], A[v0, v1 * 4096 + v2]), T.float32(-340282346638528859811704183484516925440.0)))
            # Pass 2: sum of exp(x - max) for temperature > 1e-5; otherwise the
            # number of entries equal to the max.
            for ax0, ax1 in T.grid(1, 1):
                for ax2_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                    for ax2_fused_0 in T.serial(64, annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                        with T.block("sum_exp"):
                            v0 = T.axis.spatial(batch_size, ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0)
                            v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks + ax1)
                            v2 = T.axis.reduce(4096, ax2_fused_0 * 64 + ax2_fused_1)
                            T.reads(temperature[v0], A[v0, v1 * 4096 + v2], temp_max_shared[v0, v1])
                            T.writes(temp_sum_shared[v0, v1])
                            with T.init():
                                temp_sum_shared[v0, v1] = T.float32(0.0)
                            temp_sum_shared[v0, v1] = temp_sum_shared[v0, v1] + T.if_then_else(v1 * 4096 + v2 < vocab_size, T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(T.if_then_else(v1 * 4096 + v2 < vocab_size, T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), A[v0, v1 * 4096 + v2] / temperature[v0], A[v0, v1 * 4096 + v2]), T.float32(-340282346638528859811704183484516925440.0)) - temp_max_shared[v0, v1]), T.Cast("float32", T.if_then_else(v1 * 4096 + v2 < vocab_size, T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), A[v0, v1 * 4096 + v2] / temperature[v0], A[v0, v1 * 4096 + v2]), T.float32(-340282346638528859811704183484516925440.0)) == temp_max_shared[v0, v1])), T.float32(0.0))
            # Output stage: only thread 0 (T.where(... < 1)) writes the results.
            for ax2_1 in T.thread_binding(64, thread="threadIdx.x"):
                for ax2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                    with T.block("log"):
                        v0 = T.axis.spatial(batch_size, ax0_ax1_fused % (num_chunks * batch_size) // num_chunks)
                        v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks)
                        v2 = T.axis.spatial(1, ax2_0 * 64 + ax2_1)
                        T.where(ax2_0 * 64 + ax2_1 < 1)
                        T.reads(temperature[v0], temp_sum_shared[v0, v1], temp_max_shared[v0, v1])
                        T.writes(chunked_sum[v0, v1], chunked_max[v0, v1])
                        chunked_sum[v0, v1] = T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.log(temp_sum_shared[v0, v1]), temp_sum_shared[v0, v1])
                        chunked_max[v0, v1] = temp_max_shared[v0, v1]

    @T.prim_func
    def compact_kv_copy(var_pages: T.handle, var_copy_length_indptr: T.handle, var_copy_src_dst_pos: T.handle, batch_size: T.int32):
        # Copy entries between slot positions inside the paged buffer `pages`.
        #
        # `pages` is (num_pages, 2, 2, 16, 128) float16. A flat position `pos`
        # maps to page pos // 16 and in-page slot pos % 16. For batch b, each
        # pair (copy_src_dst_pos[0, k], copy_src_dst_pos[1, k]) with
        # k in [copy_length_indptr[b], copy_length_indptr[b + 1]) gives a
        # source and destination position. Both entries of axis 1 are copied.
        # NOTE(review): the axis meanings (axis 1 = K/V, axis 2 = heads,
        # axis 4 = head dim) are inferred from the shape only; confirm.
        #
        # Each thread owns one (b, h, d) triple (flat index b*256 + h*128 + d)
        # and performs that batch's copies sequentially, in k order. Copies for
        # different b run in different threads with no synchronization.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        num_pages = T.int32()
        pages = T.match_buffer(var_pages, (num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        copy_length_indptr = T.match_buffer(var_copy_length_indptr, (batch_size + 1,), "int32", offset_factor=1)
        total_copy_length = T.int32()
        copy_src_dst_pos = T.match_buffer(var_copy_src_dst_pos, (2, total_copy_length), "int32", offset_factor=1)
        with T.block("root"):
            T.reads()
            T.writes()
            for bhd_o in T.thread_binding((batch_size * 256 + 1023) // 1024, thread="blockIdx.x"):
                for bhd_i in T.thread_binding(1024, thread="threadIdx.x"):
                    # Decompose the flat thread index into (batch, axis-2 index, last-axis index).
                    b: T.int32 = (bhd_o * 1024 + bhd_i) // 256
                    h: T.int32 = (bhd_o * 1024 + bhd_i) // 128 % 2
                    d: T.int32 = (bhd_o * 1024 + bhd_i) % 128
                    # Mask threads in the tail block beyond batch_size * 256.
                    if bhd_o * 1024 + bhd_i < batch_size * 2 * 128:
                        for i in range(copy_length_indptr[b + 1] - copy_length_indptr[b]):
                            src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i]
                            dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i]
                            pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[src_pos // 16, 0, h, src_pos % 16, d]
                            pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[src_pos // 16, 1, h, src_pos % 16, d]

    @T.prim_func
    def copy_single_page(var_pages: T.handle, src_page_id: T.int32, tgt_page_id: T.int32, copy_length: T.int32):
        # Copy the first `copy_length` in-page slots of page `src_page_id` into
        # page `tgt_page_id`. Both entries of axis 1, both indices of axis 2
        # (vh) and all 128 entries of the last axis (vd) are copied.
        # Flat thread index = vh * (copy_length * 128) + vp * 128 + vd; threads
        # past copy_length * 2 * 128 are masked by T.where.
        # NOTE(review): nothing here checks copy_length <= page_size;
        # presumably the caller guarantees it.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        num_pages, page_size = T.int32(), T.int32()
        pages = T.match_buffer(var_pages, (num_pages, 2, 2, page_size, 128), "float16", offset_factor=1)
        # with T.block("root"):
        for b in T.thread_binding((copy_length * 256 + 1023) // 1024, thread="blockIdx.x"):
            for t in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("copy"):
                    vh = T.axis.spatial(2, (b * 1024 + t) // (copy_length * 128))
                    vp = T.axis.spatial(copy_length, (b * 1024 + t) % (copy_length * 128) // 128)
                    vd = T.axis.spatial(128, (b * 1024 + t) % 128)
                    T.where(b * 1024 + t < copy_length * 2 * 128)
                    T.reads(pages[src_page_id, 0:2, vh, vp, vd])
                    T.writes(pages[tgt_page_id, 0:2, vh, vp, vd])
                    pages[tgt_page_id, 0, vh, vp, vd] = pages[src_page_id, 0, vh, vp, vd]
                    pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd]

    @T.prim_func
    def full(var_result: T.handle, value: T.int32):
        # Fill every element of the (batch_size, 1) int32 buffer `result` with
        # `value`. Each thread writes one element; blocks have 1024 threads,
        # and T.where masks the unused tail of the last block.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        batch_size = T.int32(is_size_var=True)
        result = T.match_buffer(var_result, (batch_size, 1), "int32")
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding((batch_size + 1023) // 1024, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    v0 = T.axis.spatial(batch_size, ax0_fused_0 * 1024 + ax0_fused_1)
                    T.where(ax0_fused_0 * 1024 + ax0_fused_1 < batch_size)
                    T.reads()
                    T.writes(result[v0, 0])
                    result[v0, 0] = value

    @T.prim_func
    def fuse_add_norm_decode(pA: T.handle, pB: T.handle, C: T.Buffer((2048,), "float16"), pO: T.handle, pAdd: T.handle):
        # Fused residual add + RMS normalization for (batch_size, 1, 2048)
        # float16 inputs. For each batch row bx:
        #   add[bx, 0, :] = A[bx, 0, :] + B[bx, 0, :]            (float16 add)
        #   O[bx, 0, h]   = float16(rsqrt(sum_h(add^2) / 2048 + 1e-6)
        #                           * add[bx, 0, h] * C[h])
        # The sum of squares is accumulated in float32
        # (0.00048828125 == 1/2048). The hidden size 2048 is hard-coded in the
        # buffer shapes and index math.
        #
        # One block of 1024 threads per batch row. Thread tx handles
        # h = tx and h = 1024 + tx and keeps those two sums in the per-thread
        # buffer `add_local`, so the final stage reads them from there rather
        # than from `add`.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int32()
        A = T.match_buffer(pA, (batch_size, 1, 2048), "float16")
        B = T.match_buffer(pB, (batch_size, 1, 2048), "float16")
        O = T.match_buffer(pO, (batch_size, 1, 2048), "float16")
        add = T.match_buffer(pAdd, (batch_size, 1, 2048), "float16")
        # with T.block("root"):
        # add_local: this thread's two fp16 sums. sum_local: per-thread partial
        # sums of squares. sum_shared: the per-row total shared by the block.
        add_local = T.alloc_buffer((2,), "float16", scope="local")
        sum_shared = T.alloc_buffer((batch_size, 1), scope="shared")
        sum_local = T.alloc_buffer((1024, batch_size, 1), scope="local")
        for v_bx in T.thread_binding(batch_size, thread="blockIdx.x"):
            for v_tx in T.thread_binding(1024, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                # Stage 1: residual add, kept in add_local and written to `add`.
                for i in range(2):
                    with T.block("T_add"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        h = T.axis.spatial(2048, i * 1024 + v_tx)
                        T.reads(A[bx, 0, h], B[bx, 0, h])
                        T.writes(add_local[h // 1024])
                        add_local[h // 1024] = A[bx, 0, h] + B[bx, 0, h]
                    with T.block("T_write_back"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        v_ax1 = T.axis.spatial(1, 0)
                        h = T.axis.spatial(2048, i * 1024 + v_tx)
                        T.reads(add_local[h // 1024])
                        T.writes(add[bx, v_ax1, h])
                        add[bx, v_ax1, h] = add_local[h // 1024]
                # Stage 2: this thread's float32 partial sum of squares over its two elements.
                with T.block("T_multiply_red_rf_init"):
                    tx, bx = T.axis.remap("SS", [v_tx, v_bx])
                    T.reads()
                    T.writes(sum_local[tx, bx, 0])
                    sum_local[tx, bx, 0] = T.float32(0.0)
                for v_i, _j in T.grid(2, 1):
                    with T.block("T_multiply_red_rf_update"):
                        tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                        T.reads(sum_local[tx, bx, 0], add_local[i])
                        T.writes(sum_local[tx, bx, 0])
                        sum_local[tx, bx, 0] = sum_local[tx, bx, 0] + T.Cast("float32", add_local[i]) * T.Cast("float32", add_local[i])
            # Stage 3: cross-thread reduction of the 1024 partials into sum_shared[bx, 0].
            for _j in range(1):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                        T.reads(sum_local[tx, bx, 0])
                        T.writes(sum_shared[bx, 0])
                        with T.init():
                            sum_shared[bx, 0] = T.float32(0.0)
                        sum_shared[bx, 0] = sum_shared[bx, 0] + sum_local[tx, bx, 0]
            # Stage 4: normalize by rsqrt(mean + 1e-6), scale by C, cast to float16.
            for i in range(2):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        h = T.axis.spatial(2048, i * 1024 + v_tx_2)
                        T.reads(sum_shared[bx, 0], add_local[h // 1024], C[h])
                        T.writes(O[bx, 0, h])
                        O[bx, 0, h] = T.Cast("float16", T.rsqrt(sum_shared[bx, 0] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07)) * T.Cast("float32", add_local[h // 1024]) * T.Cast("float32", C[h]))

    @T.prim_func
    def fuse_add_norm_prefill(pA: T.handle, pB: T.handle, C: T.Buffer((2048,), "float16"), pO: T.handle, pAdd: T.handle):
        # Fused residual add + RMS normalization for (1, seq_len, 2048) float16
        # inputs. For each position bx along axis 1:
        #   add[0, bx, :] = A[0, bx, :] + B[0, bx, :]            (float16 add)
        #   O[0, bx, h]   = float16(rsqrt(sum_h(add^2) / 2048 + 1e-6)
        #                           * add[0, bx, h] * C[h])
        # The sum of squares is accumulated in float32
        # (0.00048828125 == 1/2048). The hidden size 2048 is hard-coded in the
        # buffer shapes and index math.
        #
        # One block of 1024 threads per position. Thread tx handles
        # h = tx and h = 1024 + tx and keeps those two sums in the per-thread
        # buffer `add_local`.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        A = T.match_buffer(pA, (1, seq_len, 2048), "float16")
        B = T.match_buffer(pB, (1, seq_len, 2048), "float16")
        O = T.match_buffer(pO, (1, seq_len, 2048), "float16")
        add = T.match_buffer(pAdd, (1, seq_len, 2048), "float16")
        # with T.block("root"):
        # add_local: this thread's two fp16 sums. sum_local: per-thread partial
        # sums of squares. sum_shared: the per-position total shared by the block.
        add_local = T.alloc_buffer((2,), "float16", scope="local")
        sum_shared = T.alloc_buffer((1, seq_len), scope="shared")
        sum_local = T.alloc_buffer((1024, 1, seq_len), scope="local")
        for v_bx in T.thread_binding(seq_len, thread="blockIdx.x"):
            for v_tx in T.thread_binding(1024, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                # Stage 1: residual add, kept in add_local and written to `add`.
                for v_i in range(2):
                    with T.block("T_add"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        h = T.axis.spatial(2048, v_i * 1024 + v_tx)
                        T.reads(A[0, bx, h], B[0, bx, h])
                        T.writes(add_local[h // 1024])
                        add_local[h // 1024] = A[0, bx, h] + B[0, bx, h]
                    with T.block("T_write_back"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        h = T.axis.spatial(2048, v_i * 1024 + v_tx)
                        T.reads(add_local[h // 1024])
                        T.writes(add[0, bx, h])
                        add[0, bx, h] = add_local[h // 1024]
                # Stage 2: this thread's float32 partial sum of squares over its two elements.
                with T.block("T_multiply_red_rf_init"):
                    tx, bx = T.axis.remap("SS", [v_tx, v_bx])
                    T.reads()
                    T.writes(sum_local[tx, 0, bx])
                    sum_local[tx, 0, bx] = T.float32(0.0)
                for v_i, _j in T.grid(2, 1):
                    with T.block("T_multiply_red_rf_update"):
                        tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                        T.reads(sum_local[tx, 0, bx], add_local[i])
                        T.writes(sum_local[tx, 0, bx])
                        sum_local[tx, 0, bx] = sum_local[tx, 0, bx] + T.Cast("float32", add_local[i]) * T.Cast("float32", add_local[i])
            # Stage 3: cross-thread reduction of the 1024 partials into sum_shared[0, bx].
            for _j in range(1):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                        T.reads(sum_local[tx, 0, bx])
                        T.writes(sum_shared[0, bx])
                        with T.init():
                            sum_shared[0, bx] = T.float32(0.0)
                        sum_shared[0, bx] = sum_shared[0, bx] + sum_local[tx, 0, bx]
            # Stage 4: normalize by rsqrt(mean + 1e-6), scale by C, cast to float16.
            for v_i in range(2):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        v1 = T.axis.spatial(2048, v_i * 1024 + v_tx_2)
                        T.reads(sum_shared[0, bx], add_local[v1 // 1024], C[v1])
                        T.writes(O[0, bx, v1])
                        O[0, bx, v1] = T.Cast("float16", T.rsqrt(sum_shared[0, bx] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07)) * T.Cast("float32", add_local[v1 // 1024]) * T.Cast("float32", C[v1]))

    @T.prim_func
    def fused_dequantize1_fused_NT_matmul10_add2(model_layers_0_self_attn_c_attn_q_weight2: T.Buffer((2560, 256), "uint32"), model_layers_0_self_attn_c_attn_q_scale2: T.Buffer((2560, 64), "float16"), rms_norm73: T.Buffer((1, 1, 2048), "float16"), model_layers_0_self_attn_c_attn_bias2: T.Buffer((2560,), "float16"), T_add_intermediate_intermediate: T.Buffer((1, 1, 2560), "float16")):
        # Fused 4-bit dequantize + matrix-vector product + bias add for a single
        # input row (rms_norm73 has shape (1, 1, 2048)):
        #
        #   out[0, 0, n] = sum_k rms_norm73[0, 0, k] * W[n, k] + bias2[n]
        #   W[n, k] = (((q_weight2[n, k // 8] >> (k % 8 * 4)) & 15) - 7) * q_scale2[n, k // 32]
        #
        # for n in [0, 2560) and k in [0, 2048): each uint32 weight word packs
        # eight 4-bit values, and each float16 scale covers 32 consecutive k.
        # Every floating-point intermediate/accumulator buffer below is float16.
        #
        # Thread layout: blockIdx.x (640) x threadIdx.y (4) picks the output row
        # n = blockIdx.x * 4 + threadIdx.y; the 32 threadIdx.x lanes split the
        # k reduction, whose partial sums are combined in stages 4-5 below.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Final dot product for each output row, before the bias is added.
        NT_matmul_intermediate_local = T.alloc_buffer((1, 1, 2560), "float16", scope="local")
        # Per-thread partial sums, indexed [threadIdx.x * 4 + vector_lane, 0, 0, n].
        NT_matmul_intermediate_rf_local = T.alloc_buffer((128, 1, 1, 2560), "float16", scope="local")
        # Per-lane partial sums after folding the 4 vector lanes, indexed [threadIdx.x, 0, 0, n].
        NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((32, 1, 1, 2560), "float16", scope="local")
        # Local-scope copy of the packed weight word each lane consumes per k chunk.
        model_layers_0_self_attn_c_attn_q_weight2_local = T.alloc_buffer((2560, 256), "uint32", scope="local")
        # Shared-memory copy of the 2048-element input vector, read by every thread of the block.
        rms_norm73_shared = T.alloc_buffer((1, 1, 2048), "float16", scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(640, thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                    # Stage 1: copy the input vector into shared memory. Element
                    # index = ax2_0 * 512 + threadIdx.y * 128 + threadIdx.x * 4 + vec,
                    # so 4 iterations x 128 threads x 4 elements cover all 2048.
                    # The inner loops re-bind threadIdx.y / threadIdx.x, which the
                    # enclosing loops already bind.
                    # NOTE(review): no explicit barrier appears between this copy and
                    # the shared reads in stage 3; presumably inserted during
                    # lowering — confirm.
                    # NOTE(review): this loop is annotated pragma_unroll_explicit=256,
                    # unlike the pragma_auto_unroll_max_step=256 /
                    # pragma_unroll_explicit=1 pair on the other serial loops in this
                    # function — confirm intended.
                    for ax0, ax1 in T.grid(1, 1):
                        for ax2_0 in T.serial(4, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(4, thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(4):
                                        with T.block("rms_norm73_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(2048, ax2_0 * 512 + ax2_1 * 128 + ax2_2 * 4 + ax2_3)
                                            T.reads(rms_norm73[v0, v1, v2])
                                            T.writes(rms_norm73_shared[v0, v1, v2])
                                            rms_norm73_shared[v0, v1, v2] = rms_norm73[v0, v1, v2]
                    # Stage 2: zero this thread's 4 vectorized partial sums
                    # rf_local[threadIdx.x * 4 + lane, 0, 0, n].
                    for u_fused_ax0_fused_fused_2_init in range(1):
                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(4):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(128, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
                                v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0.0)
                    # Stage 3: walk k in 8 chunks of 256. In each chunk every lane
                    # loads one packed weight word (word column chunk * 32 +
                    # threadIdx.x, i.e. 8 nibbles) and accumulates the 8 products for
                    # k = chunk * 256 + threadIdx.x * 8 + half * 4 + lane, with
                    # half = ax1_0_fused_ax1_1_fused_2 in {0, 1} and lane in 0..3.
                    for ax1_0_fused_ax1_1_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax0_ax1_fused_0 in range(1):
                            for ax0_ax1_fused_1 in T.vectorized(1):
                                with T.block("model_layers_0_self_attn_c_attn_q_weight2_local"):
                                    v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1)
                                    v1 = T.axis.spatial(256, ax1_0_fused_ax1_1_fused_0 * 32 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
                                    T.reads(model_layers_0_self_attn_c_attn_q_weight2[v0, v1])
                                    T.writes(model_layers_0_self_attn_c_attn_q_weight2_local[v0, v1])
                                    model_layers_0_self_attn_c_attn_q_weight2_local[v0, v1] = model_layers_0_self_attn_c_attn_q_weight2[v0, v1]
                        # The word index below adds vax1_0_fused_ax1_1_fused_2 // 2,
                        # which is always 0 here (that axis has extent 2), so both
                        # halves read the word loaded just above; the nibble is
                        # selected by the shift k % 8 * 4.
                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 2):
                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(4):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(128, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
                                    v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
                                    T.reads(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], rms_norm73_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4], model_layers_0_self_attn_c_attn_q_weight2_local[v0, vax1_0_fused_ax1_1_fused_0 * 32 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + vax1_0_fused_ax1_1_fused_2 // 2], model_layers_0_self_attn_c_attn_q_scale2[v0, (vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) // 32])
                                    T.writes(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                    NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + rms_norm73_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4] * ((T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight2_local[v0, vax1_0_fused_ax1_1_fused_0 * 32 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + vax1_0_fused_ax1_1_fused_2 // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale2[v0, (vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) // 32])
            # Stage 4: per lane, sum the 4 vector partials from stage 3 into
            # rf_local_1[threadIdx.x, 0, 0, n] (block names reuse NT_matmul_rf_*).
            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                    for ax2_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_2_1 in T.vectorized(1):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                T.reads()
                                T.writes(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0.0)
                            for ax1 in range(4):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                    T.reads(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
                                    T.writes(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                    NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
            # Stage 5: reduce the 32 per-lane partials. The reduction axis of block
            # "NT_matmul" is bound to threadIdx.x, so this is a cross-thread
            # reduction (presumably lowered to a thread allreduce — confirm).
            for ax1_fused_2 in range(1):
                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                    for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(32, ax0)
                            v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 4 + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2)
                            T.reads(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                            T.writes(NT_matmul_intermediate_local[0, 0, v0])
                            with T.init():
                                NT_matmul_intermediate_local[0, 0, v0] = T.float16(0.0)
                            NT_matmul_intermediate_local[0, 0, v0] = NT_matmul_intermediate_local[0, 0, v0] + NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]
            # Stage 6: add the bias and write out[0, 0, n]; this block is bound to
            # threadIdx.y only.
            for ax0_fused_0_ax0_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                for ax0_fused_2 in range(1):
                    with T.block("T_add"):
                        v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 4 + ax0_fused_0_ax0_fused_1_fused + ax0_fused_2)
                        T.reads(NT_matmul_intermediate_local[0, 0, v0], model_layers_0_self_attn_c_attn_bias2[v0])
                        T.writes(T_add_intermediate_intermediate[0, 0, v0])
                        T_add_intermediate_intermediate[0, 0, v0] = NT_matmul_intermediate_local[0, 0, v0] + model_layers_0_self_attn_c_attn_bias2[v0]

    @T.prim_func
    def fused_dequantize1_fused_NT_matmul5_add1(model_layers_0_self_attn_c_attn_q_weight3: T.Buffer((2560, 256), "uint32"), model_layers_0_self_attn_c_attn_q_scale3: T.Buffer((2560, 64), "float16"), p_rms_norm146: T.handle, model_layers_0_self_attn_c_attn_bias3: T.Buffer((2560,), "float16"), p_output0: T.handle):
        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        rms_norm146 = T.match_buffer(p_rms_norm146, (1, seq_len, 2048), "float16")
        T_add_intermediate_intermediate = T.match_buffer(p_output0, (1, seq_len, 2560), "float16")
        # with T.block("root"):
        if T.tvm_thread_invariant(seq_len <= 2):
            with T.block("root"):
                T.reads()
                T.writes()
                dequantize_intermediate_local = T.alloc_buffer((2560, 2048), "float16", scope="local")
                NT_matmul_intermediate_pad_local = T.alloc_buffer((1, (seq_len + 1) // 2 * 2, 2560), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((128, 1, (seq_len + 1) // 2 * 2, 2560), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((32, 1, (seq_len + 1) // 2 * 2, 2560), "float16", scope="local")
                for ax0_0 in T.thread_binding((seq_len + 1) // 2, thread="blockIdx.y"):
                    for ax1_fused_0 in T.thread_binding(320, thread="blockIdx.x"):
                        for ax1_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                            for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax0_1_init, ax1_fused_2_init in T.grid(2, 2):
                                    for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(4):
                                        with T.block("NT_matmul_rf_init"):
                                            vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_init)
                                            v0 = T.axis.spatial((seq_len + 1) // 2 * 2, ax0_0 * 2 + ax0_1_init)
                                            v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2_init)
                                            T.reads()
                                            T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1])
                                            NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] = T.float16(0.0)
                                for ax2_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax0_0_1, ax1 in T.grid(2, 8):
                                        for ax0_1 in T.vectorized(1):
                                            with T.block("dequantize"):
                                                v0 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax0_0_1 + ax0_1)
                                                v1 = T.axis.spatial(2048, ax2_fused_0 * 256 + ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax1)
                                                T.reads(model_layers_0_self_attn_c_attn_q_weight3[v0, v1 // 8], model_layers_0_self_attn_c_attn_q_scale3[v0, v1 // 32])
                                                T.writes(dequantize_intermediate_local[v0, v1])
                                                dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight3[v0, v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale3[v0, v1 // 32]
                                    for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(2, 2, 2):
                                        for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(4):
                                            with T.block("NT_matmul_rf_update"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1)
                                                v0 = T.axis.spatial((seq_len + 1) // 2 * 2, ax0_0 * 2 + ax0_1)
                                                v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2)
                                                vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1], rms_norm146[0, v0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] + T.if_then_else(v0 < seq_len, rms_norm146[0, v0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4]
                        for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                            for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax3_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax2 in range(2):
                                        for ax3_fused_2_1 in T.vectorized(2):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                                v0 = T.axis.spatial((seq_len + 1) // 2 * 2, ax0_0 * 2 + ax2)
                                                v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                                NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] = T.float16(0.0)
                                            for ax1 in range(4):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                    v0 = T.axis.spatial((seq_len + 1) // 2 * 2, ax0_0 * 2 + ax2)
                                                    v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                    T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, 0, v0, v1])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, 0, v0, v1]
                        for ax2_fused_2, ax1 in T.grid(2, 2):
                            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                    with T.block("NT_matmul"):
                                        vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(32, ax0)
                                        v0 = T.axis.spatial((seq_len + 1) // 2 * 2, ax0_0 * 2 + ax1)
                                        v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax2_fused_0_ax2_fused_1_fused * 2 + ax2_fused_2)
                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                        T.writes(NT_matmul_intermediate_pad_local[0, v0, v1])
                                        with T.init():
                                            NT_matmul_intermediate_pad_local[0, v0, v1] = T.float16(0.0)
                                        NT_matmul_intermediate_pad_local[0, v0, v1] = NT_matmul_intermediate_pad_local[0, v0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1]
                        for ax0 in range(2):
                            for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                for ax1_fused_2 in range(2):
                                    with T.block("NT_matmul_intermediate_pad"):
                                        v0 = T.axis.spatial(seq_len, ax0_0 * 2 + ax0)
                                        v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax1_fused_0_ax1_fused_1_fused * 2 + ax1_fused_2)
                                        T.where((ax0_0 - (seq_len + 1) // 2 < 0 or ax0_0 == 0) and ax0_0 * 2 + ax0 < seq_len)
                                        T.reads(NT_matmul_intermediate_pad_local[0, v0, v1], model_layers_0_self_attn_c_attn_bias3[v1])
                                        T.writes(T_add_intermediate_intermediate[0, v0, v1])
                                        T_add_intermediate_intermediate[0, v0, v1] = NT_matmul_intermediate_pad_local[0, v0, v1] + model_layers_0_self_attn_c_attn_bias3[v1]
        else:
            if T.tvm_thread_invariant(seq_len <= 8):
                with T.block("root"):
                    T.reads()
                    T.writes()
                    dequantize_intermediate_local = T.alloc_buffer((2560, 2048), "float16", scope="local")
                    NT_matmul_intermediate_pad_local = T.alloc_buffer((1, (seq_len + 3) // 4 * 4, 2560), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((128, 1, (seq_len + 3) // 4 * 4, 2560), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((32, 1, (seq_len + 3) // 4 * 4, 2560), "float16", scope="local")
                    for ax0_0 in T.thread_binding((seq_len + 3) // 4, thread="blockIdx.y"):
                        for ax1_fused_0 in T.thread_binding(320, thread="blockIdx.x"):
                            for ax1_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax0_1_init, ax1_fused_2_init in T.grid(4, 2):
                                        for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(4):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_init)
                                                v0 = T.axis.spatial((seq_len + 3) // 4 * 4, ax0_0 * 4 + ax0_1_init)
                                                v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2_init)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] = T.float16(0.0)
                                    for ax2_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax0_0_1, ax1 in T.grid(2, 8):
                                            for ax0_1 in T.vectorized(1):
                                                with T.block("dequantize"):
                                                    v0 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax0_0_1 + ax0_1)
                                                    v1 = T.axis.spatial(2048, ax2_fused_0 * 256 + ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax1)
                                                    T.reads(model_layers_0_self_attn_c_attn_q_weight3[v0, v1 // 8], model_layers_0_self_attn_c_attn_q_scale3[v0, v1 // 32])
                                                    T.writes(dequantize_intermediate_local[v0, v1])
                                                    dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight3[v0, v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale3[v0, v1 // 32]
                                        for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(4, 2, 2):
                                            for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(4):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1)
                                                    v0 = T.axis.spatial((seq_len + 3) // 4 * 4, ax0_0 * 4 + ax0_1)
                                                    v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2)
                                                    vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                    T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1], rms_norm146[0, v0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] + T.if_then_else(v0 < seq_len, rms_norm146[0, v0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4]
                            for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax3_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax2 in range(4):
                                            for ax3_fused_2_1 in T.vectorized(2):
                                                with T.block("NT_matmul_rf_init"):
                                                    vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                                    v0 = T.axis.spatial((seq_len + 3) // 4 * 4, ax0_0 * 4 + ax2)
                                                    v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                    T.reads()
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] = T.float16(0.0)
                                                for ax1 in range(4):
                                                    with T.block("NT_matmul_rf_update"):
                                                        vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                        v0 = T.axis.spatial((seq_len + 3) // 4 * 4, ax0_0 * 4 + ax2)
                                                        v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, 0, v0, v1])
                                                        T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                                        NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, 0, v0, v1]
                            for ax2_fused_2, ax1 in T.grid(2, 4):
                                for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                    for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                        with T.block("NT_matmul"):
                                            vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(32, ax0)
                                            v0 = T.axis.spatial((seq_len + 3) // 4 * 4, ax0_0 * 4 + ax1)
                                            v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax2_fused_0_ax2_fused_1_fused * 2 + ax2_fused_2)
                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                            T.writes(NT_matmul_intermediate_pad_local[0, v0, v1])
                                            with T.init():
                                                NT_matmul_intermediate_pad_local[0, v0, v1] = T.float16(0.0)
                                            NT_matmul_intermediate_pad_local[0, v0, v1] = NT_matmul_intermediate_pad_local[0, v0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1]
                            for ax0 in range(4):
                                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                    for ax1_fused_2 in range(2):
                                        with T.block("NT_matmul_intermediate_pad"):
                                            v0 = T.axis.spatial(seq_len, ax0_0 * 4 + ax0)
                                            v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax1_fused_0_ax1_fused_1_fused * 2 + ax1_fused_2)
                                            T.where((ax0_0 - (seq_len + 3) // 4 < 0 or ax0_0 == 0) and ax0_0 * 4 + ax0 < seq_len)
                                            T.reads(NT_matmul_intermediate_pad_local[0, v0, v1], model_layers_0_self_attn_c_attn_bias3[v1])
                                            T.writes(T_add_intermediate_intermediate[0, v0, v1])
                                            T_add_intermediate_intermediate[0, v0, v1] = NT_matmul_intermediate_pad_local[0, v0, v1] + model_layers_0_self_attn_c_attn_bias3[v1]
            else:
                with T.block("root"):
                    T.reads()
                    T.writes()
                    NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (seq_len + 31) // 32 * 32, 2560), "float16", scope="local")
                    rms_norm146_reindex_pad_shared = T.alloc_buffer((1, (seq_len + 31) // 32 * 32, 2048), "float16", scope="shared")
                    dequantize_intermediate_reindex_shared = T.alloc_buffer((1, 2560, 2048), "float16", scope="shared")
                    for ax0_ax2_0_fused in T.thread_binding(80, thread="blockIdx.y"):
                        for ax1_0 in T.thread_binding((seq_len + 31) // 32, thread="blockIdx.x"):
                            for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                                for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                                    for ax2_2 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                            for ax1_3_init, ax2_3_0_init in T.grid(4, 4):
                                                for ax2_3_1_init in T.vectorized(1):
                                                    with T.block("NT_matmul_init"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial((seq_len + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
                                                        v2 = T.axis.spatial(2560, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_0_init + ax2_3_1_init)
                                                        T.reads()
                                                        T.writes(NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                        NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0.0)
                                            for ax3_0 in range(256):
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(4):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                                with T.block("rms_norm146_reindex_pad_shared"):
                                                                    v0 = T.axis.spatial(1, 0)
                                                                    v1 = T.axis.spatial((seq_len + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                                    v2 = T.axis.spatial(2048, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                                    T.reads(rms_norm146[v0, v1, v2])
                                                                    T.writes(rms_norm146_reindex_pad_shared[v0, v1, v2])
                                                                    rms_norm146_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, rms_norm146[v0, v1, v2], T.float16(0.0))
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(4):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                                with T.block("dequantize_intermediate_reindex_shared"):
                                                                    v0 = T.axis.spatial(1, 0)
                                                                    v1 = T.axis.spatial(2560, ax0_ax2_0_fused * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                                    v2 = T.axis.spatial(2048, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                                    T.reads(model_layers_0_self_attn_c_attn_q_weight3[v1, v2 // 8], model_layers_0_self_attn_c_attn_q_scale3[v1, v2 // 32])
                                                                    T.writes(dequantize_intermediate_reindex_shared[v0, v1, v2])
                                                                    dequantize_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight3[v1, v2 // 8], T.Cast("uint32", v2 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale3[v1, v2 // 32]
                                                for ax3_1, ax1_3, ax2_3_0 in T.grid(8, 4, 4):
                                                    for ax2_3_1 in T.vectorized(1):
                                                        with T.block("NT_matmul_update"):
                                                            v0 = T.axis.spatial(1, 0)
                                                            v1 = T.axis.spatial((seq_len + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
                                                            v2 = T.axis.spatial(2560, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_0 + ax2_3_1)
                                                            v3 = T.axis.reduce(2048, ax3_0 * 8 + ax3_1)
                                                            T.reads(NT_matmul_intermediate_reindex_pad_local[0, v1, v2], rms_norm146_reindex_pad_shared[0, v1, v3], dequantize_intermediate_reindex_shared[0, v2, v3])
                                                            T.writes(NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                            NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + rms_norm146_reindex_pad_shared[0, v1, v3] * dequantize_intermediate_reindex_shared[0, v2, v3]
                                            for ax0, ax1, ax2_0 in T.grid(1, 4, 4):
                                                for ax2_1_1 in T.vectorized(1):
                                                    with T.block("NT_matmul_intermediate_reindex_pad_local"):
                                                        v0 = T.axis.spatial(1, ax0)
                                                        v1 = T.axis.spatial((seq_len + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                                        v2 = T.axis.spatial(2560, ax0_ax2_0_fused * 32 + ax2_2 * 4 + ax2_0 + ax2_1_1)
                                                        T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < seq_len)
                                                        T.reads(NT_matmul_intermediate_reindex_pad_local[v0, v1, v2], model_layers_0_self_attn_c_attn_bias3[v2])
                                                        T.writes(T_add_intermediate_intermediate[0, v1, v2])
                                                        T_add_intermediate_intermediate[0, v1, v2] = NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] + model_layers_0_self_attn_c_attn_bias3[v2]

    @T.prim_func
    def fused_dequantize1_fused_NT_matmul_add(model_layers_0_self_attn_c_attn_q_weight4: T.Buffer((2560, 256), "uint32"), model_layers_0_self_attn_c_attn_q_scale4: T.Buffer((2560, 64), "float16"), p_rms_norm219: T.handle, model_layers_0_self_attn_c_attn_bias4: T.Buffer((2560,), "float16"), p_output0: T.handle):
        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int32()
        rms_norm219 = T.match_buffer(p_rms_norm219, (batch_size, 1, 2048), "float16")
        T_add_intermediate_intermediate = T.match_buffer(p_output0, (batch_size, 1, 2560), "float16")
        # with T.block("root"):
        if T.tvm_thread_invariant(batch_size <= 2):
            with T.block("root"):
                T.reads()
                T.writes()
                dequantize_intermediate_local = T.alloc_buffer((2560, 2048), "float16", scope="local")
                NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + 1) // 2 * 2, 1, 2560), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((128, (batch_size + 1) // 2 * 2, 1, 2560), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((32, (batch_size + 1) // 2 * 2, 1, 2560), "float16", scope="local")
                for ax0_0 in T.thread_binding((batch_size + 1) // 2, thread="blockIdx.y"):
                    for ax1_fused_0 in T.thread_binding(320, thread="blockIdx.x"):
                        for ax1_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                            for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax0_1_init, ax1_fused_2_init in T.grid(2, 2):
                                    for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(4):
                                        with T.block("NT_matmul_rf_init"):
                                            vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_init)
                                            v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax0_1_init)
                                            v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2_init)
                                            T.reads()
                                            T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1])
                                            NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = T.float16(0.0)
                                for ax2_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax0_0_1, ax1 in T.grid(2, 8):
                                        for ax0_1 in T.vectorized(1):
                                            with T.block("dequantize"):
                                                v0 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax0_0_1 + ax0_1)
                                                v1 = T.axis.spatial(2048, ax2_fused_0 * 256 + ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax1)
                                                T.reads(model_layers_0_self_attn_c_attn_q_weight4[v0, v1 // 8], model_layers_0_self_attn_c_attn_q_scale4[v0, v1 // 32])
                                                T.writes(dequantize_intermediate_local[v0, v1])
                                                dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight4[v0, v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale4[v0, v1 // 32]
                                    for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(2, 2, 2):
                                        for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(4):
                                            with T.block("NT_matmul_rf_update"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1)
                                                v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax0_1)
                                                v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2)
                                                vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1], rms_norm219[v0, 0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] + T.if_then_else(v0 < batch_size, rms_norm219[v0, 0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4]
                        for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                            for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax3_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax2 in range(2):
                                        for ax3_fused_2_1 in T.vectorized(2):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                                v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax2)
                                                v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                                NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = T.float16(0.0)
                                            for ax1 in range(4):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                    v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax2)
                                                    v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                    T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1]
                        for ax2_fused_2, ax1 in T.grid(2, 2):
                            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                    with T.block("NT_matmul"):
                                        vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(32, ax0)
                                        v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax1)
                                        v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax2_fused_0_ax2_fused_1_fused * 2 + ax2_fused_2)
                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                        T.writes(NT_matmul_intermediate_pad_local[v0, 0, v1])
                                        with T.init():
                                            NT_matmul_intermediate_pad_local[v0, 0, v1] = T.float16(0.0)
                                        NT_matmul_intermediate_pad_local[v0, 0, v1] = NT_matmul_intermediate_pad_local[v0, 0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1]
                        for ax0 in range(2):
                            for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                for ax1_fused_2 in range(2):
                                    with T.block("NT_matmul_intermediate_pad"):
                                        v0 = T.axis.spatial(batch_size, ax0_0 * 2 + ax0)
                                        v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax1_fused_0_ax1_fused_1_fused * 2 + ax1_fused_2)
                                        T.where((ax0_0 - (batch_size + 1) // 2 < 0 or ax0_0 == 0) and ax0_0 * 2 + ax0 < batch_size)
                                        T.reads(NT_matmul_intermediate_pad_local[v0, 0, v1], model_layers_0_self_attn_c_attn_bias4[v1])
                                        T.writes(T_add_intermediate_intermediate[v0, 0, v1])
                                        T_add_intermediate_intermediate[v0, 0, v1] = NT_matmul_intermediate_pad_local[v0, 0, v1] + model_layers_0_self_attn_c_attn_bias4[v1]
        else:
            if T.tvm_thread_invariant(batch_size <= 8):
                with T.block("root"):
                    T.reads()
                    T.writes()
                    dequantize_intermediate_local = T.alloc_buffer((2560, 2048), "float16", scope="local")
                    NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + 3) // 4 * 4, 1, 2560), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((128, (batch_size + 3) // 4 * 4, 1, 2560), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((32, (batch_size + 3) // 4 * 4, 1, 2560), "float16", scope="local")
                    for ax0_0 in T.thread_binding((batch_size + 3) // 4, thread="blockIdx.y"):
                        for ax1_fused_0 in T.thread_binding(320, thread="blockIdx.x"):
                            for ax1_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax0_1_init, ax1_fused_2_init in T.grid(4, 2):
                                        for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(4):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_init)
                                                v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax0_1_init)
                                                v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2_init)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = T.float16(0.0)
                                    for ax2_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax0_0_1, ax1 in T.grid(2, 8):
                                            for ax0_1 in T.vectorized(1):
                                                with T.block("dequantize"):
                                                    v0 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax0_0_1 + ax0_1)
                                                    v1 = T.axis.spatial(2048, ax2_fused_0 * 256 + ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax1)
                                                    T.reads(model_layers_0_self_attn_c_attn_q_weight4[v0, v1 // 8], model_layers_0_self_attn_c_attn_q_scale4[v0, v1 // 32])
                                                    T.writes(dequantize_intermediate_local[v0, v1])
                                                    dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight4[v0, v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale4[v0, v1 // 32]
                                        for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(4, 2, 2):
                                            for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(4):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1)
                                                    v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax0_1)
                                                    v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2)
                                                    vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                    T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1], rms_norm219[v0, 0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1])
                                                    NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] + T.if_then_else(v0 < batch_size, rms_norm219[v0, 0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4]
                            for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax3_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax2 in range(4):
                                            for ax3_fused_2_1 in T.vectorized(2):
                                                with T.block("NT_matmul_rf_init"):
                                                    vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                                    v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax2)
                                                    v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                    T.reads()
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = T.float16(0.0)
                                                for ax1 in range(4):
                                                    with T.block("NT_matmul_rf_update"):
                                                        vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                        v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax2)
                                                        v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1])
                                                        T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                                        NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1]
                            for ax2_fused_2, ax1 in T.grid(2, 4):
                                for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                    for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                        with T.block("NT_matmul"):
                                            vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(32, ax0)
                                            v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax1)
                                            v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax2_fused_0_ax2_fused_1_fused * 2 + ax2_fused_2)
                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                            T.writes(NT_matmul_intermediate_pad_local[v0, 0, v1])
                                            with T.init():
                                                NT_matmul_intermediate_pad_local[v0, 0, v1] = T.float16(0.0)
                                            NT_matmul_intermediate_pad_local[v0, 0, v1] = NT_matmul_intermediate_pad_local[v0, 0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1]
                            for ax0 in range(4):
                                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                    for ax1_fused_2 in range(2):
                                        with T.block("NT_matmul_intermediate_pad"):
                                            v0 = T.axis.spatial(batch_size, ax0_0 * 4 + ax0)
                                            v1 = T.axis.spatial(2560, ax1_fused_0 * 8 + ax1_fused_0_ax1_fused_1_fused * 2 + ax1_fused_2)
                                            T.where((ax0_0 - (batch_size + 3) // 4 < 0 or ax0_0 == 0) and ax0_0 * 4 + ax0 < batch_size)
                                            T.reads(NT_matmul_intermediate_pad_local[v0, 0, v1], model_layers_0_self_attn_c_attn_bias4[v1])
                                            T.writes(T_add_intermediate_intermediate[v0, 0, v1])
                                            T_add_intermediate_intermediate[v0, 0, v1] = NT_matmul_intermediate_pad_local[v0, 0, v1] + model_layers_0_self_attn_c_attn_bias4[v1]
            else:
                with T.block("root"):
                    T.reads()
                    T.writes()
                    NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (batch_size + 31) // 32 * 32, 2560), "float16", scope="local")
                    rms_norm219_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 31) // 32 * 32, 2048), "float16", scope="shared")
                    dequantize_intermediate_reindex_shared = T.alloc_buffer((1, 2560, 2048), "float16", scope="shared")
                    for ax0_ax2_0_fused in T.thread_binding(80, thread="blockIdx.y"):
                        for ax1_0 in T.thread_binding((batch_size + 31) // 32, thread="blockIdx.x"):
                            for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                                for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                                    for ax2_2 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                            for ax1_3_init, ax2_3_0_init in T.grid(4, 4):
                                                for ax2_3_1_init in T.vectorized(1):
                                                    with T.block("NT_matmul_init"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
                                                        v2 = T.axis.spatial(2560, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_0_init + ax2_3_1_init)
                                                        T.reads()
                                                        T.writes(NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                        NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0.0)
                                            for ax3_0 in range(256):
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(4):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                                with T.block("rms_norm219_reindex_pad_shared"):
                                                                    v0 = T.axis.spatial(1, 0)
                                                                    v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                                    v2 = T.axis.spatial(2048, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                                    T.reads(rms_norm219[v1, 0, v2])
                                                                    T.writes(rms_norm219_reindex_pad_shared[v0, v1, v2])
                                                                    rms_norm219_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < batch_size, rms_norm219[v1, 0, v2], T.float16(0.0))
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(4):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                                with T.block("dequantize_intermediate_reindex_shared"):
                                                                    v0 = T.axis.spatial(1, 0)
                                                                    v1 = T.axis.spatial(2560, ax0_ax2_0_fused * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                                    v2 = T.axis.spatial(2048, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                                    T.reads(model_layers_0_self_attn_c_attn_q_weight4[v1, v2 // 8], model_layers_0_self_attn_c_attn_q_scale4[v1, v2 // 32])
                                                                    T.writes(dequantize_intermediate_reindex_shared[v0, v1, v2])
                                                                    dequantize_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight4[v1, v2 // 8], T.Cast("uint32", v2 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale4[v1, v2 // 32]
                                                for ax3_1, ax1_3, ax2_3_0 in T.grid(8, 4, 4):
                                                    for ax2_3_1 in T.vectorized(1):
                                                        with T.block("NT_matmul_update"):
                                                            v0 = T.axis.spatial(1, 0)
                                                            v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
                                                            v2 = T.axis.spatial(2560, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_0 + ax2_3_1)
                                                            v3 = T.axis.reduce(2048, ax3_0 * 8 + ax3_1)
                                                            T.reads(NT_matmul_intermediate_reindex_pad_local[0, v1, v2], rms_norm219_reindex_pad_shared[0, v1, v3], dequantize_intermediate_reindex_shared[0, v2, v3])
                                                            T.writes(NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                            NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + rms_norm219_reindex_pad_shared[0, v1, v3] * dequantize_intermediate_reindex_shared[0, v2, v3]
                                            for ax0, ax1, ax2_0 in T.grid(1, 4, 4):
                                                for ax2_1_1 in T.vectorized(1):
                                                    with T.block("NT_matmul_intermediate_reindex_pad_local"):
                                                        v0 = T.axis.spatial(1, ax0)
                                                        v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                                        v2 = T.axis.spatial(2560, ax0_ax2_0_fused * 32 + ax2_2 * 4 + ax2_0 + ax2_1_1)
                                                        T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < batch_size)
                                                        T.reads(NT_matmul_intermediate_reindex_pad_local[v0, v1, v2], model_layers_0_self_attn_c_attn_bias4[v2])
                                                        T.writes(T_add_intermediate_intermediate[v1, 0, v2])
                                                        T_add_intermediate_intermediate[v1, 0, v2] = NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] + model_layers_0_self_attn_c_attn_bias4[v2]

    @T.prim_func
    def fused_dequantize2_NT_matmul1(model_layers_0_self_attn_o_proj_q_weight4: T.Buffer((2048, 256), "uint32"), model_layers_0_self_attn_o_proj_q_scale4: T.Buffer((2048, 64), "float16"), p_reshape435: T.handle, p_output0: T.handle):
        # Batched matmul against a 4-bit packed weight that is dequantized on the fly:
        #   NT_matmul_intermediate[b, 0, n] = sum_k reshape435[b, 0, k] * W[n, k]
        #   W[n, k] = (float16((q_weight4[n, k // 8] >> (k % 8 * 4)) & 15) - 7)
        #             * q_scale4[n, k // 32]
        # with n, k in [0, 2048): eight 4-bit codes per uint32 word, offset by 7,
        # and one float16 scale per group of 32 consecutive k values.
        # One of three GPU schedules is selected from batch_size at runtime:
        #   batch_size <= 2      -> 2 batch rows per blockIdx.y, K split over threadIdx.x
        #   2 < batch_size <= 8  -> same schedule with 4 batch rows per blockIdx.y
        #   batch_size > 8       -> 32x32 output tiles staged through shared memory
        # In every path, padded rows (>= batch_size) are fed zeros and are never
        # stored to the output.
        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int32()
        reshape435 = T.match_buffer(p_reshape435, (batch_size, 1, 2048), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, 1, 2048), "float16")
        # with T.block("root"):
        if T.tvm_thread_invariant(batch_size <= 2):
            # Small-batch path (batch_size <= 2). Launch layout:
            #   blockIdx.y : ceil(batch_size / 2) tiles of 2 batch rows
            #   blockIdx.x : 128 tiles of 16 output columns (128 * 16 = 2048)
            #   threadIdx.y: 8 threads, 2 output columns each
            #   threadIdx.x: 32 threads that split the K = 2048 reduction
            # The reduction has three stages: per-thread partial sums in
            # NT_matmul_intermediate_pad_rf_local (128 slots = 32 threads x 4
            # vector lanes), a per-thread fold of the 4 lanes into
            # NT_matmul_intermediate_pad_rf_local_1 (32 slots), and a final
            # reduction across the 32 threadIdx.x threads.
            with T.block("root"):
                T.reads()
                T.writes()
                # NOTE(review): these local buffers are declared at full logical size;
                # presumably later TVM passes compact them to the per-thread footprint — confirm.
                dequantize_intermediate_local = T.alloc_buffer((2048, 2048), "float16", scope="local")
                NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + 1) // 2 * 2, 1, 2048), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((128, (batch_size + 1) // 2 * 2, 1, 2048), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((32, (batch_size + 1) // 2 * 2, 1, 2048), "float16", scope="local")
                for ax0_0 in T.thread_binding((batch_size + 1) // 2, thread="blockIdx.y"):
                    for ax1_fused_0 in T.thread_binding(128, thread="blockIdx.x"):
                        for ax1_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
                            for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                                # Stage 1a: zero this thread's 2 rows x 2 columns x 4 lanes of partial sums.
                                for ax0_1_init, ax1_fused_2_init in T.grid(2, 2):
                                    for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(4):
                                        with T.block("NT_matmul_rf_init"):
                                            vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_init)
                                            v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax0_1_init)
                                            v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax1_fused_2_init)
                                            T.reads()
                                            T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1])
                                            NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = T.float16(0.0)
                                # Stage 1b: 8 outer K steps of 256. In step s, threadIdx.x = x
                                # covers k = s * 256 + x * 8 + [0, 8).
                                for ax2_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    # Dequantize the 2 weight rows x 8 k-values this thread needs for this step.
                                    for ax0_0_1, ax1 in T.grid(2, 8):
                                        for ax0_1 in T.vectorized(1):
                                            with T.block("dequantize"):
                                                v0 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax0_0_1 + ax0_1)
                                                v1 = T.axis.spatial(2048, ax2_fused_0 * 256 + ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax1)
                                                T.reads(model_layers_0_self_attn_o_proj_q_weight4[v0, v1 // 8], model_layers_0_self_attn_o_proj_q_scale4[v0, v1 // 32])
                                                T.writes(dequantize_intermediate_local[v0, v1])
                                                dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight4[v0, v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale4[v0, v1 // 32]
                                    # Multiply-accumulate those 8 k-values into the 4 lane slots.
                                    # Rows v0 >= batch_size contribute 0 through the if_then_else guard.
                                    for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(2, 2, 2):
                                        for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(4):
                                            with T.block("NT_matmul_rf_update"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1)
                                                v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax0_1)
                                                v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax1_fused_2)
                                                vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1], reshape435[v0, 0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] + T.if_then_else(v0 < batch_size, reshape435[v0, 0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4]
                        # Stage 2: each thread folds its 4 lane partials (rf_local slots x*4 .. x*4+3)
                        # into rf_local_1[x]. The thread -> (slot, column) mapping matches stage 1,
                        # so every thread reads only slots it wrote itself.
                        for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                            for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax3_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax2 in range(2):
                                        for ax3_fused_2_1 in T.vectorized(2):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                                v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax2)
                                                v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                                NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = T.float16(0.0)
                                            for ax1 in range(4):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                    v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax2)
                                                    v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                    T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1]
                        # Stage 3: the reduce axis is bound to threadIdx.x, so this sums
                        # rf_local_1 across the 32 threads (a cross-thread reduction).
                        for ax2_fused_2, ax1 in T.grid(2, 2):
                            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                    with T.block("NT_matmul"):
                                        vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(32, ax0)
                                        v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax1)
                                        v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax2_fused_0_ax2_fused_1_fused * 2 + ax2_fused_2)
                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                        T.writes(NT_matmul_intermediate_pad_local[v0, 0, v1])
                                        with T.init():
                                            NT_matmul_intermediate_pad_local[v0, 0, v1] = T.float16(0.0)
                                        NT_matmul_intermediate_pad_local[v0, 0, v1] = NT_matmul_intermediate_pad_local[v0, 0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1]
                        # Stage 4: copy the padded result to the output for valid rows only.
                        # NOTE(review): this loop is not bound to threadIdx.x, so all 32 x-threads
                        # issue the same store; presumably each holds the same reduced value — confirm.
                        for ax0 in range(2):
                            for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                                for ax1_fused_2 in range(2):
                                    with T.block("NT_matmul_intermediate_pad"):
                                        v0 = T.axis.spatial(batch_size, ax0_0 * 2 + ax0)
                                        v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_0_ax1_fused_1_fused * 2 + ax1_fused_2)
                                        # The first clause always holds because ax0_0 < (batch_size + 1) // 2
                                        # (its launch extent). The effective guard is row < batch_size.
                                        T.where((ax0_0 - (batch_size + 1) // 2 < 0 or ax0_0 == 0) and ax0_0 * 2 + ax0 < batch_size)
                                        T.reads(NT_matmul_intermediate_pad_local[v0, 0, v1])
                                        T.writes(NT_matmul_intermediate[v0, 0, v1])
                                        NT_matmul_intermediate[v0, 0, v1] = NT_matmul_intermediate_pad_local[v0, 0, v1]
        else:
            if T.tvm_thread_invariant(batch_size <= 8):
                # Medium-batch path (2 < batch_size <= 8). This is the same four-stage
                # schedule as the batch_size <= 2 path, except that each blockIdx.y
                # handles 4 batch rows:
                #   blockIdx.y : ceil(batch_size / 4) tiles of 4 batch rows
                #   blockIdx.x : 128 tiles of 16 output columns
                #   threadIdx.y: 8 threads, 2 output columns each
                #   threadIdx.x: 32 threads that split the K = 2048 reduction
                with T.block("root"):
                    T.reads()
                    T.writes()
                    # NOTE(review): these local buffers are declared at full logical size;
                    # presumably later TVM passes compact them to the per-thread footprint — confirm.
                    dequantize_intermediate_local = T.alloc_buffer((2048, 2048), "float16", scope="local")
                    NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + 3) // 4 * 4, 1, 2048), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((128, (batch_size + 3) // 4 * 4, 1, 2048), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((32, (batch_size + 3) // 4 * 4, 1, 2048), "float16", scope="local")
                    for ax0_0 in T.thread_binding((batch_size + 3) // 4, thread="blockIdx.y"):
                        for ax1_fused_0 in T.thread_binding(128, thread="blockIdx.x"):
                            for ax1_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
                                for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                                    # Stage 1a: zero this thread's 4 rows x 2 columns x 4 lanes of partial sums.
                                    for ax0_1_init, ax1_fused_2_init in T.grid(4, 2):
                                        for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(4):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_init)
                                                v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax0_1_init)
                                                v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax1_fused_2_init)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = T.float16(0.0)
                                    # Stage 1b: 8 outer K steps of 256. In step s, threadIdx.x = x
                                    # covers k = s * 256 + x * 8 + [0, 8).
                                    for ax2_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        # Dequantize the 2 weight rows x 8 k-values this thread needs for this step.
                                        for ax0_0_1, ax1 in T.grid(2, 8):
                                            for ax0_1 in T.vectorized(1):
                                                with T.block("dequantize"):
                                                    v0 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax0_0_1 + ax0_1)
                                                    v1 = T.axis.spatial(2048, ax2_fused_0 * 256 + ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax1)
                                                    T.reads(model_layers_0_self_attn_o_proj_q_weight4[v0, v1 // 8], model_layers_0_self_attn_o_proj_q_scale4[v0, v1 // 32])
                                                    T.writes(dequantize_intermediate_local[v0, v1])
                                                    dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight4[v0, v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale4[v0, v1 // 32]
                                        # Multiply-accumulate for 4 batch rows. Rows v0 >= batch_size
                                        # contribute 0 through the if_then_else guard.
                                        for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(4, 2, 2):
                                            for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(4):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1)
                                                    v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax0_1)
                                                    v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax1_fused_2)
                                                    vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                    T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1], reshape435[v0, 0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1])
                                                    NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] + T.if_then_else(v0 < batch_size, reshape435[v0, 0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4]
                            # Stage 2: each thread folds its 4 lane partials into rf_local_1[x],
                            # using the same thread -> (slot, column) mapping as stage 1.
                            for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax3_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax2 in range(4):
                                            for ax3_fused_2_1 in T.vectorized(2):
                                                with T.block("NT_matmul_rf_init"):
                                                    vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                                    v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax2)
                                                    v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                    T.reads()
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = T.float16(0.0)
                                                for ax1 in range(4):
                                                    with T.block("NT_matmul_rf_update"):
                                                        vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                        v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax2)
                                                        v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1])
                                                        T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                                        NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1]
                            # Stage 3: cross-thread reduction over the 32 threadIdx.x threads
                            # (the reduce axis is bound to threadIdx.x).
                            for ax2_fused_2, ax1 in T.grid(2, 4):
                                for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                                    for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                        with T.block("NT_matmul"):
                                            vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(32, ax0)
                                            v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax1)
                                            v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax2_fused_0_ax2_fused_1_fused * 2 + ax2_fused_2)
                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                            T.writes(NT_matmul_intermediate_pad_local[v0, 0, v1])
                                            with T.init():
                                                NT_matmul_intermediate_pad_local[v0, 0, v1] = T.float16(0.0)
                                            NT_matmul_intermediate_pad_local[v0, 0, v1] = NT_matmul_intermediate_pad_local[v0, 0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1]
                            # Stage 4: copy valid rows (< batch_size) to the output.
                            # NOTE(review): not bound to threadIdx.x, so every x-thread issues the same store — confirm intended.
                            for ax0 in range(4):
                                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                                    for ax1_fused_2 in range(2):
                                        with T.block("NT_matmul_intermediate_pad"):
                                            v0 = T.axis.spatial(batch_size, ax0_0 * 4 + ax0)
                                            v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_0_ax1_fused_1_fused * 2 + ax1_fused_2)
                                            # The first clause always holds because ax0_0 < (batch_size + 3) // 4;
                                            # the effective guard is row < batch_size.
                                            T.where((ax0_0 - (batch_size + 3) // 4 < 0 or ax0_0 == 0) and ax0_0 * 4 + ax0 < batch_size)
                                            T.reads(NT_matmul_intermediate_pad_local[v0, 0, v1])
                                            T.writes(NT_matmul_intermediate[v0, 0, v1])
                                            NT_matmul_intermediate[v0, 0, v1] = NT_matmul_intermediate_pad_local[v0, 0, v1]
            else:
                # Large-batch path (batch_size > 8): GEMM tiled through shared memory.
                #   blockIdx.y : 64 tiles of 32 output columns (64 * 32 = 2048)
                #   blockIdx.x : ceil(batch_size / 32) tiles of 32 batch rows
                #   vthread.x/y: extent 1
                #   threadIdx.y x threadIdx.x: 8 x 8 threads, each owning a 4 (rows) x 4
                #   (columns) block of outputs in a local accumulator.
                # K = 2048 is walked in 256 steps of 8.
                with T.block("root"):
                    T.reads()
                    T.writes()
                    NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (batch_size + 31) // 32 * 32, 2048), "float16", scope="local")
                    reshape435_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 31) // 32 * 32, 2048), "float16", scope="shared")
                    dequantize_intermediate_reindex_shared = T.alloc_buffer((1, 2048, 2048), "float16", scope="shared")
                    for ax0_ax2_0_fused in T.thread_binding(64, thread="blockIdx.y"):
                        for ax1_0 in T.thread_binding((batch_size + 31) // 32, thread="blockIdx.x"):
                            for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                                for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                                    for ax2_2 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                            # Zero this thread's 4x4 accumulator.
                                            for ax1_3_init, ax2_3_0_init in T.grid(4, 4):
                                                for ax2_3_1_init in T.vectorized(1):
                                                    with T.block("NT_matmul_init"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
                                                        v2 = T.axis.spatial(2048, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_0_init + ax2_3_1_init)
                                                        T.reads()
                                                        T.writes(NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                        NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0.0)
                                            # Each K step: the 64 threads cooperatively stage a 32x8 activation tile
                                            # (zero-padded for rows >= batch_size) and a 32x8 dequantized weight tile in
                                            # shared memory (8 * 8 threads * 4 elements = 256 = 32 * 8 each), then every
                                            # thread accumulates its 4x4 outputs over those 8 k-values.
                                            # NOTE(review): no explicit barriers appear here; presumably inserted during lowering.
                                            for ax3_0 in range(256):
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(4):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                                with T.block("reshape435_reindex_pad_shared"):
                                                                    v0 = T.axis.spatial(1, 0)
                                                                    v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                                    v2 = T.axis.spatial(2048, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                                    T.reads(reshape435[v1, 0, v2])
                                                                    T.writes(reshape435_reindex_pad_shared[v0, v1, v2])
                                                                    reshape435_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < batch_size, reshape435[v1, 0, v2], T.float16(0.0))
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(4):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                                with T.block("dequantize_intermediate_reindex_shared"):
                                                                    v0 = T.axis.spatial(1, 0)
                                                                    v1 = T.axis.spatial(2048, ax0_ax2_0_fused * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                                    v2 = T.axis.spatial(2048, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                                    T.reads(model_layers_0_self_attn_o_proj_q_weight4[v1, v2 // 8], model_layers_0_self_attn_o_proj_q_scale4[v1, v2 // 32])
                                                                    T.writes(dequantize_intermediate_reindex_shared[v0, v1, v2])
                                                                    dequantize_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight4[v1, v2 // 8], T.Cast("uint32", v2 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale4[v1, v2 // 32]
                                                for ax3_1, ax1_3, ax2_3_0 in T.grid(8, 4, 4):
                                                    for ax2_3_1 in T.vectorized(1):
                                                        with T.block("NT_matmul_update"):
                                                            v0 = T.axis.spatial(1, 0)
                                                            v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
                                                            v2 = T.axis.spatial(2048, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_0 + ax2_3_1)
                                                            v3 = T.axis.reduce(2048, ax3_0 * 8 + ax3_1)
                                                            T.reads(NT_matmul_intermediate_reindex_pad_local[0, v1, v2], reshape435_reindex_pad_shared[0, v1, v3], dequantize_intermediate_reindex_shared[0, v2, v3])
                                                            T.writes(NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                            NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + reshape435_reindex_pad_shared[0, v1, v3] * dequantize_intermediate_reindex_shared[0, v2, v3]
                                            # Write the 4x4 accumulator to the output, skipping padded rows (>= batch_size).
                                            for ax0, ax1, ax2_0 in T.grid(1, 4, 4):
                                                for ax2_1_1 in T.vectorized(1):
                                                    with T.block("NT_matmul_intermediate_reindex_pad_local"):
                                                        v0 = T.axis.spatial(1, ax0)
                                                        v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                                        v2 = T.axis.spatial(2048, ax0_ax2_0_fused * 32 + ax2_2 * 4 + ax2_0 + ax2_1_1)
                                                        T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < batch_size)
                                                        T.reads(NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                                        T.writes(NT_matmul_intermediate[v1, 0, v2])
                                                        NT_matmul_intermediate[v1, 0, v2] = NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]

    # Fused 4-bit dequantize + matrix-vector product for a single input row.
    #
    # For every output element n in [0, 2048) this computes
    #   NT_matmul_intermediate[0, 0, n] = sum_{k < 2048} lv218[0, 0, k] * W[n, k]
    # with the dequantized weight
    #   W[n, k] = (float16((q_weight2[n, k // 8] >> (k % 8 * 4)) & 15) - 7.0) * q_scale2[n, k // 32]
    # (q_weight2 / q_scale2 abbreviate model_layers_0_self_attn_o_proj_q_weight2 / _q_scale2).
    # So each uint32 weight word packs eight 4-bit values (low nibble first), and each
    # float16 scale is shared by a group of 32 consecutive k.
    #
    # Schedule, as written below:
    #   * blockIdx.x (256) x threadIdx.y (8) = 2048: each (block, threadIdx.y) pair owns one
    #     output element v0 = blockIdx.x * 8 + threadIdx.y.
    #   * The 32 threadIdx.x lanes split the K = 2048 reduction in four stages:
    #     (1) copy lv218 into shared memory;
    #     (2) accumulate 4 partial sums per lane;
    #     (3) fold those into 1 per lane;
    #     (4) reduce across the 32 lanes.
    @T.prim_func
    def fused_dequantize2_NT_matmul11(model_layers_0_self_attn_o_proj_q_weight2: T.Buffer((2048, 256), "uint32"), model_layers_0_self_attn_o_proj_q_scale2: T.Buffer((2048, 64), "float16"), lv218: T.Buffer((1, 1, 2048), "float16"), NT_matmul_intermediate: T.Buffer((1, 1, 2048), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Local/shared buffers are declared with their full logical shapes. Presumably later TIR
        # passes compact them to the per-thread / per-block footprint (not shown here).
        # Stage-2 partial sums: 128 slots = 32 threadIdx.x lanes x 4 vector slots (index lane * 4 + slot).
        NT_matmul_intermediate_rf_local = T.alloc_buffer((128, 1, 1, 2048), "float16", scope="local")
        # Stage-3 partial sums: one slot per threadIdx.x lane.
        NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((32, 1, 1, 2048), "float16", scope="local")
        # Local-scope staging copy of the packed weight words each thread reads in stage 2.
        model_layers_0_self_attn_o_proj_q_weight2_local = T.alloc_buffer((2048, 256), "uint32", scope="local")
        # Shared-memory copy of the input vector, read by every thread of the block.
        lv218_shared = T.alloc_buffer((1, 1, 2048), "float16", scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(256, thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                    # Stage 1: cooperatively copy all 2048 input elements into shared memory.
                    # 2 (serial) x 8 (threadIdx.y) x 32 (threadIdx.x) x 4 (vectorized) = 2048.
                    # The inner loops bind the same thread axes as the enclosing loops.
                    # No explicit barrier appears before stage 2 reads lv218_shared.
                    # NOTE(review): presumably inserted during lowering; TODO confirm.
                    for ax0, ax1 in T.grid(1, 1):
                        for ax2_0 in T.serial(2, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(8, thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(4):
                                        with T.block("lv218_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(2048, ax2_0 * 1024 + ax2_1 * 128 + ax2_2 * 4 + ax2_3)
                                            T.reads(lv218[v0, v1, v2])
                                            T.writes(lv218_shared[v0, v1, v2])
                                            lv218_shared[v0, v1, v2] = lv218[v0, v1, v2]
                    # Stage 2a: zero this lane's 4 partial-sum slots (lane * 4 + 0..3) for output v0.
                    for u_fused_ax0_fused_fused_2_init in range(1):
                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(4):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(128, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
                                v0 = T.axis.spatial(2048, u_fused_ax0_fused_fused_0 * 8 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0.0)
                    # Stage 2b: 8 serial reduction steps, each covering 256 consecutive k (8 * 256 = 2048).
                    for ax1_0_fused_ax1_1_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        # Stage the single packed uint32 word (8 nibbles) this lane needs for this step.
                        # Its weight column is step * 32 + lane.
                        for ax0_ax1_fused_0 in range(1):
                            for ax0_ax1_fused_1 in T.vectorized(1):
                                with T.block("model_layers_0_self_attn_o_proj_q_weight2_local"):
                                    v0 = T.axis.spatial(2048, u_fused_ax0_fused_fused_0 * 8 + u_fused_ax0_fused_fused_1)
                                    v1 = T.axis.spatial(256, ax1_0_fused_ax1_1_fused_0 * 32 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
                                    T.reads(model_layers_0_self_attn_o_proj_q_weight2[v0, v1])
                                    T.writes(model_layers_0_self_attn_o_proj_q_weight2_local[v0, v1])
                                    model_layers_0_self_attn_o_proj_q_weight2_local[v0, v1] = model_layers_0_self_attn_o_proj_q_weight2[v0, v1]
                        # Accumulate 2 x 4 = 8 products into slots lane * 4 + s, where
                        #   k = step * 256 + lane * 8 + j * 4 + s   (j = ax1_0_fused_ax1_1_fused_2, s = vector slot).
                        # Since j * 4 + s < 8, k // 8 == step * 32 + lane, so every k uses the word staged
                        # above. The weight index's "+ j // 2" term is always 0 here.
                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 2):
                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(4):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(128, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
                                    v0 = T.axis.spatial(2048, u_fused_ax0_fused_fused_0 * 8 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
                                    T.reads(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv218_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4], model_layers_0_self_attn_o_proj_q_weight2_local[v0, vax1_0_fused_ax1_1_fused_0 * 32 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + vax1_0_fused_ax1_1_fused_2 // 2], model_layers_0_self_attn_o_proj_q_scale2[v0, (vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) // 32])
                                    T.writes(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                    # partial += x[k] * ((nibble(k) - 7) * scale[n, k // 32])
                                    NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv218_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4] * ((T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight2_local[v0, vax1_0_fused_ax1_1_fused_0 * 32 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + vax1_0_fused_ax1_1_fused_2 // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale2[v0, (vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) // 32])
            # Stage 3: each lane (threadIdx.x) folds its 4 slots rf_local[lane * 4 + 0..3] into rf_local_1[lane].
            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                    for ax2_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_2_1 in T.vectorized(1):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                v0 = T.axis.spatial(2048, u_fused_ax0_fused_fused_0 * 8 + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                T.reads()
                                T.writes(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0.0)
                            for ax1 in range(4):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(2048, u_fused_ax0_fused_fused_0 * 8 + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                    T.reads(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
                                    T.writes(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                    NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
            # Stage 4: reduce the 32 per-lane sums into the output element. The reduce axis is bound
            # to threadIdx.x, so this is a cross-thread reduction.
            # NOTE(review): presumably lowered by TVM's cross-thread-reduction pass; not visible here.
            for ax1_fused_2 in range(1):
                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                    for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(32, ax0)
                            v0 = T.axis.spatial(2048, u_fused_ax0_fused_fused_0 * 8 + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2)
                            T.reads(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                            T.writes(NT_matmul_intermediate[0, 0, v0])
                            with T.init():
                                NT_matmul_intermediate[0, 0, v0] = T.float16(0.0)
                            NT_matmul_intermediate[0, 0, v0] = NT_matmul_intermediate[0, 0, v0] + NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]

    @T.prim_func
    def fused_dequantize2_NT_matmul6(model_layers_0_self_attn_o_proj_q_weight3: T.Buffer((2048, 256), "uint32"), model_layers_0_self_attn_o_proj_q_scale3: T.Buffer((2048, 64), "float16"), p_reshape291: T.handle, p_output0: T.handle):
        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        reshape291 = T.match_buffer(p_reshape291, (1, seq_len, 2048), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (1, seq_len, 2048), "float16")
        # with T.block("root"):
        if T.tvm_thread_invariant(seq_len <= 2):
            with T.block("root"):
                T.reads()
                T.writes()
                dequantize_intermediate_local = T.alloc_buffer((2048, 2048), "float16", scope="local")
                NT_matmul_intermediate_pad_local = T.alloc_buffer((1, (seq_len + 1) // 2 * 2, 2048), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((128, 1, (seq_len + 1) // 2 * 2, 2048), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((32, 1, (seq_len + 1) // 2 * 2, 2048), "float16", scope="local")
                for ax0_0 in T.thread_binding((seq_len + 1) // 2, thread="blockIdx.y"):
                    for ax1_fused_0 in T.thread_binding(128, thread="blockIdx.x"):
                        for ax1_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
                            for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax0_1_init, ax1_fused_2_init in T.grid(2, 2):
                                    for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(4):
                                        with T.block("NT_matmul_rf_init"):
                                            vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_init)
                                            v0 = T.axis.spatial((seq_len + 1) // 2 * 2, ax0_0 * 2 + ax0_1_init)
                                            v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax1_fused_2_init)
                                            T.reads()
                                            T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1])
                                            NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] = T.float16(0.0)
                                for ax2_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax0_0_1, ax1 in T.grid(2, 8):
                                        for ax0_1 in T.vectorized(1):
                                            with T.block("dequantize"):
                                                v0 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax0_0_1 + ax0_1)
                                                v1 = T.axis.spatial(2048, ax2_fused_0 * 256 + ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax1)
                                                T.reads(model_layers_0_self_attn_o_proj_q_weight3[v0, v1 // 8], model_layers_0_self_attn_o_proj_q_scale3[v0, v1 // 32])
                                                T.writes(dequantize_intermediate_local[v0, v1])
                                                dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight3[v0, v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale3[v0, v1 // 32]
                                    for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(2, 2, 2):
                                        for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(4):
                                            with T.block("NT_matmul_rf_update"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1)
                                                v0 = T.axis.spatial((seq_len + 1) // 2 * 2, ax0_0 * 2 + ax0_1)
                                                v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax1_fused_2)
                                                vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1], reshape291[0, v0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] + T.if_then_else(v0 < seq_len, reshape291[0, v0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4]
                        for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                            for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax3_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax2 in range(2):
                                        for ax3_fused_2_1 in T.vectorized(2):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                                v0 = T.axis.spatial((seq_len + 1) // 2 * 2, ax0_0 * 2 + ax2)
                                                v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                                NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] = T.float16(0.0)
                                            for ax1 in range(4):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                    v0 = T.axis.spatial((seq_len + 1) // 2 * 2, ax0_0 * 2 + ax2)
                                                    v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                    T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, 0, v0, v1])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, 0, v0, v1]
                        for ax2_fused_2, ax1 in T.grid(2, 2):
                            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                    with T.block("NT_matmul"):
                                        vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(32, ax0)
                                        v0 = T.axis.spatial((seq_len + 1) // 2 * 2, ax0_0 * 2 + ax1)
                                        v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax2_fused_0_ax2_fused_1_fused * 2 + ax2_fused_2)
                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                        T.writes(NT_matmul_intermediate_pad_local[0, v0, v1])
                                        with T.init():
                                            NT_matmul_intermediate_pad_local[0, v0, v1] = T.float16(0.0)
                                        NT_matmul_intermediate_pad_local[0, v0, v1] = NT_matmul_intermediate_pad_local[0, v0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1]
                        for ax0 in range(2):
                            for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                                for ax1_fused_2 in range(2):
                                    with T.block("NT_matmul_intermediate_pad"):
                                        v0 = T.axis.spatial(seq_len, ax0_0 * 2 + ax0)
                                        v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_0_ax1_fused_1_fused * 2 + ax1_fused_2)
                                        T.where((ax0_0 - (seq_len + 1) // 2 < 0 or ax0_0 == 0) and ax0_0 * 2 + ax0 < seq_len)
                                        T.reads(NT_matmul_intermediate_pad_local[0, v0, v1])
                                        T.writes(NT_matmul_intermediate[0, v0, v1])
                                        NT_matmul_intermediate[0, v0, v1] = NT_matmul_intermediate_pad_local[0, v0, v1]
        else:
            if T.tvm_thread_invariant(seq_len <= 8):
                with T.block("root"):
                    T.reads()
                    T.writes()
                    dequantize_intermediate_local = T.alloc_buffer((2048, 2048), "float16", scope="local")
                    NT_matmul_intermediate_pad_local = T.alloc_buffer((1, (seq_len + 3) // 4 * 4, 2048), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((128, 1, (seq_len + 3) // 4 * 4, 2048), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((32, 1, (seq_len + 3) // 4 * 4, 2048), "float16", scope="local")
                    for ax0_0 in T.thread_binding((seq_len + 3) // 4, thread="blockIdx.y"):
                        for ax1_fused_0 in T.thread_binding(128, thread="blockIdx.x"):
                            for ax1_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
                                for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax0_1_init, ax1_fused_2_init in T.grid(4, 2):
                                        for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(4):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_init)
                                                v0 = T.axis.spatial((seq_len + 3) // 4 * 4, ax0_0 * 4 + ax0_1_init)
                                                v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax1_fused_2_init)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] = T.float16(0.0)
                                    for ax2_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax0_0_1, ax1 in T.grid(2, 8):
                                            for ax0_1 in T.vectorized(1):
                                                with T.block("dequantize"):
                                                    v0 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax0_0_1 + ax0_1)
                                                    v1 = T.axis.spatial(2048, ax2_fused_0 * 256 + ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax1)
                                                    T.reads(model_layers_0_self_attn_o_proj_q_weight3[v0, v1 // 8], model_layers_0_self_attn_o_proj_q_scale3[v0, v1 // 32])
                                                    T.writes(dequantize_intermediate_local[v0, v1])
                                                    dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight3[v0, v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale3[v0, v1 // 32]
                                        for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(4, 2, 2):
                                            for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(4):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1)
                                                    v0 = T.axis.spatial((seq_len + 3) // 4 * 4, ax0_0 * 4 + ax0_1)
                                                    v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax1_fused_2)
                                                    vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                    T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1], reshape291[0, v0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] + T.if_then_else(v0 < seq_len, reshape291[0, v0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4]
                            for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax3_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax2 in range(4):
                                            for ax3_fused_2_1 in T.vectorized(2):
                                                with T.block("NT_matmul_rf_init"):
                                                    vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                                    v0 = T.axis.spatial((seq_len + 3) // 4 * 4, ax0_0 * 4 + ax2)
                                                    v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                    T.reads()
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] = T.float16(0.0)
                                                for ax1 in range(4):
                                                    with T.block("NT_matmul_rf_update"):
                                                        vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                        v0 = T.axis.spatial((seq_len + 3) // 4 * 4, ax0_0 * 4 + ax2)
                                                        v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, 0, v0, v1])
                                                        T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                                        NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, 0, v0, v1]
                            for ax2_fused_2, ax1 in T.grid(2, 4):
                                for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                                    for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                        with T.block("NT_matmul"):
                                            vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(32, ax0)
                                            v0 = T.axis.spatial((seq_len + 3) // 4 * 4, ax0_0 * 4 + ax1)
                                            v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax2_fused_0_ax2_fused_1_fused * 2 + ax2_fused_2)
                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                            T.writes(NT_matmul_intermediate_pad_local[0, v0, v1])
                                            with T.init():
                                                NT_matmul_intermediate_pad_local[0, v0, v1] = T.float16(0.0)
                                            NT_matmul_intermediate_pad_local[0, v0, v1] = NT_matmul_intermediate_pad_local[0, v0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1]
                            for ax0 in range(4):
                                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                                    for ax1_fused_2 in range(2):
                                        with T.block("NT_matmul_intermediate_pad"):
                                            v0 = T.axis.spatial(seq_len, ax0_0 * 4 + ax0)
                                            v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_0_ax1_fused_1_fused * 2 + ax1_fused_2)
                                            T.where((ax0_0 - (seq_len + 3) // 4 < 0 or ax0_0 == 0) and ax0_0 * 4 + ax0 < seq_len)
                                            T.reads(NT_matmul_intermediate_pad_local[0, v0, v1])
                                            T.writes(NT_matmul_intermediate[0, v0, v1])
                                            NT_matmul_intermediate[0, v0, v1] = NT_matmul_intermediate_pad_local[0, v0, v1]
            else:
                with T.block("root"):
                    T.reads()
                    T.writes()
                    NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (seq_len + 31) // 32 * 32, 2048), "float16", scope="local")
                    reshape291_reindex_pad_shared = T.alloc_buffer((1, (seq_len + 31) // 32 * 32, 2048), "float16", scope="shared")
                    dequantize_intermediate_reindex_shared = T.alloc_buffer((1, 2048, 2048), "float16", scope="shared")
                    for ax0_ax2_0_fused in T.thread_binding(64, thread="blockIdx.y"):
                        for ax1_0 in T.thread_binding((seq_len + 31) // 32, thread="blockIdx.x"):
                            for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                                for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                                    for ax2_2 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                            for ax1_3_init, ax2_3_0_init in T.grid(4, 4):
                                                for ax2_3_1_init in T.vectorized(1):
                                                    with T.block("NT_matmul_init"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial((seq_len + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
                                                        v2 = T.axis.spatial(2048, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_0_init + ax2_3_1_init)
                                                        T.reads()
                                                        T.writes(NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                        NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0.0)
                                            for ax3_0 in range(256):
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(4):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                                with T.block("reshape291_reindex_pad_shared"):
                                                                    v0 = T.axis.spatial(1, 0)
                                                                    v1 = T.axis.spatial((seq_len + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                                    v2 = T.axis.spatial(2048, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                                    T.reads(reshape291[v0, v1, v2])
                                                                    T.writes(reshape291_reindex_pad_shared[v0, v1, v2])
                                                                    reshape291_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, reshape291[v0, v1, v2], T.float16(0.0))
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(4):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                                with T.block("dequantize_intermediate_reindex_shared"):
                                                                    v0 = T.axis.spatial(1, 0)
                                                                    v1 = T.axis.spatial(2048, ax0_ax2_0_fused * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                                    v2 = T.axis.spatial(2048, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                                    T.reads(model_layers_0_self_attn_o_proj_q_weight3[v1, v2 // 8], model_layers_0_self_attn_o_proj_q_scale3[v1, v2 // 32])
                                                                    T.writes(dequantize_intermediate_reindex_shared[v0, v1, v2])
                                                                    dequantize_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight3[v1, v2 // 8], T.Cast("uint32", v2 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale3[v1, v2 // 32]
                                                for ax3_1, ax1_3, ax2_3_0 in T.grid(8, 4, 4):
                                                    for ax2_3_1 in T.vectorized(1):
                                                        with T.block("NT_matmul_update"):
                                                            v0 = T.axis.spatial(1, 0)
                                                            v1 = T.axis.spatial((seq_len + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
                                                            v2 = T.axis.spatial(2048, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_0 + ax2_3_1)
                                                            v3 = T.axis.reduce(2048, ax3_0 * 8 + ax3_1)
                                                            T.reads(NT_matmul_intermediate_reindex_pad_local[0, v1, v2], reshape291_reindex_pad_shared[0, v1, v3], dequantize_intermediate_reindex_shared[0, v2, v3])
                                                            T.writes(NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                            NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + reshape291_reindex_pad_shared[0, v1, v3] * dequantize_intermediate_reindex_shared[0, v2, v3]
                                            for ax0, ax1, ax2_0 in T.grid(1, 4, 4):
                                                for ax2_1_1 in T.vectorized(1):
                                                    with T.block("NT_matmul_intermediate_reindex_pad_local"):
                                                        v0 = T.axis.spatial(1, ax0)
                                                        v1 = T.axis.spatial((seq_len + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                                        v2 = T.axis.spatial(2048, ax0_ax2_0_fused * 32 + ax2_2 * 4 + ax2_0 + ax2_1_1)
                                                        T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < seq_len)
                                                        T.reads(NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                                        T.writes(NT_matmul_intermediate[0, v1, v2])
                                                        NT_matmul_intermediate[0, v1, v2] = NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]

    @T.prim_func
    def fused_dequantize3_NT_matmul12(model_layers_0_mlp_gate_up_proj_q_weight2: T.Buffer((22016, 256), "uint32"), model_layers_0_mlp_gate_up_proj_q_scale2: T.Buffer((22016, 64), "float16"), rms_norm74: T.Buffer((1, 1, 2048), "float16"), NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")):
        # Fused dequantize + matrix-vector product for a single 2048-wide input row:
        #   NT_matmul_intermediate[0, 0, n] = sum_{k < 2048} rms_norm74[0, 0, k] * W[n, k]
        # with the weight dequantized on the fly as
        #   W[n, k] = (float16((q_weight2[n, k // 8] >> (k % 8 * 4)) & 15) - 7) * q_scale2[n, k // 32]
        # i.e. each uint32 word holds eight 4-bit fields and each scale entry covers 32 columns of k.
        # NOTE(review): the constant 7 looks like a fixed zero-point of the 4-bit quantization
        # scheme -- confirm against the quantizer that produced q_weight2/q_scale2.
        #
        # The k-reduction is split into three levels (rfactor-style):
        #   1. 128 partial sums per output element  -> NT_matmul_intermediate_rf_local
        #   2. folded 4:1 into 32 partial sums       -> NT_matmul_intermediate_rf_local_1
        #   3. summed over the 32 threadIdx.x lanes   -> NT_matmul_intermediate
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        NT_matmul_intermediate_rf_local = T.alloc_buffer((128, 1, 1, 22016), "float16", scope="local")
        NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((32, 1, 1, 22016), "float16", scope="local")
        # NOTE(review): declared at the full weight shape, but each thread writes/reads only one
        # element of it per outer k-step below; presumably later lowering compacts it -- confirm.
        model_layers_0_mlp_gate_up_proj_q_weight2_local = T.alloc_buffer((22016, 256), "uint32", scope="local")
        rms_norm74_shared = T.alloc_buffer((1, 1, 2048), "float16", scope="shared")
        # Output element handled by a thread row: n = blockIdx.x * 4 + threadIdx.y
        # (5504 blocks * 4 rows = 22016 output elements). threadIdx.x (32 lanes) splits k.
        for u_fused_ax0_fused_fused_0 in T.thread_binding(5504, thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                    # Cooperative copy of the whole 2048-element input row into shared scope:
                    # 4 serial steps * 4 (threadIdx.y) * 32 (threadIdx.x) * 4-wide vector = 2048.
                    for ax0, ax1 in T.grid(1, 1):
                        for ax2_0 in T.serial(4, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(4, thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(4):
                                        with T.block("rms_norm74_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(2048, ax2_0 * 512 + ax2_1 * 128 + ax2_2 * 4 + ax2_3)
                                            T.reads(rms_norm74[v0, v1, v2])
                                            T.writes(rms_norm74_shared[v0, v1, v2])
                                            rms_norm74_shared[v0, v1, v2] = rms_norm74[v0, v1, v2]
                    # Level 1 init: each lane zeroes its 4 partial-sum slots
                    # (slot index = threadIdx.x * 4 + vector lane, 128 slots total).
                    for u_fused_ax0_fused_fused_2_init in range(1):
                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(4):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(128, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
                                v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0.0)
                    # Level 1 accumulate: 8 outer steps of 256 k-values each.
                    for ax1_0_fused_ax1_1_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        # Each lane fetches one packed uint32 word (8 k-values):
                        # word column = step * 32 + threadIdx.x.
                        for ax0_ax1_fused_0 in range(1):
                            for ax0_ax1_fused_1 in T.vectorized(1):
                                with T.block("model_layers_0_mlp_gate_up_proj_q_weight2_local"):
                                    v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1)
                                    v1 = T.axis.spatial(256, ax1_0_fused_ax1_1_fused_0 * 32 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
                                    T.reads(model_layers_0_mlp_gate_up_proj_q_weight2[v0, v1])
                                    T.writes(model_layers_0_mlp_gate_up_proj_q_weight2_local[v0, v1])
                                    model_layers_0_mlp_gate_up_proj_q_weight2_local[v0, v1] = model_layers_0_mlp_gate_up_proj_q_weight2[v0, v1]
                        # The 8 nibbles of that word are consumed as 2 (serial) * 4 (vector) products.
                        # k = step * 256 + threadIdx.x * 8 + half * 4 + vector lane, so k // 8 is
                        # exactly the word fetched above (half // 2 is always 0 here).
                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 2):
                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(4):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(128, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
                                    v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
                                    T.reads(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], rms_norm74_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4], model_layers_0_mlp_gate_up_proj_q_weight2_local[v0, vax1_0_fused_ax1_1_fused_0 * 32 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + vax1_0_fused_ax1_1_fused_2 // 2], model_layers_0_mlp_gate_up_proj_q_scale2[v0, (vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) // 32])
                                    T.writes(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                    NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + rms_norm74_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4] * ((T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_gate_up_proj_q_weight2_local[v0, vax1_0_fused_ax1_1_fused_0 * 32 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + vax1_0_fused_ax1_1_fused_2 // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_gate_up_proj_q_scale2[v0, (vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) // 32])
            # Level 2: each threadIdx.x lane folds its 4 level-1 partials (slots lane*4 .. lane*4+3)
            # into one entry of NT_matmul_intermediate_rf_local_1. These blocks reuse the names
            # "NT_matmul_rf_init"/"NT_matmul_rf_update" but target the _1 buffer.
            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                    for ax2_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_2_1 in T.vectorized(1):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                T.reads()
                                T.writes(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0.0)
                            for ax1 in range(4):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                    T.reads(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
                                    T.writes(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                    NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
            # Level 3: reduce the 32 per-lane partials (reduce axis bound to threadIdx.x) into the
            # final output element n = blockIdx.x * 4 + threadIdx.y, starting from 0 via T.init().
            for ax1_fused_2 in range(1):
                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                    for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(32, ax0)
                            v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2)
                            T.reads(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                            T.writes(NT_matmul_intermediate[0, 0, v0])
                            with T.init():
                                NT_matmul_intermediate[0, 0, v0] = T.float16(0.0)
                            NT_matmul_intermediate[0, 0, v0] = NT_matmul_intermediate[0, 0, v0] + NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]

    @T.prim_func
    def fused_dequantize3_NT_matmul2(model_layers_0_mlp_gate_up_proj_q_weight4: T.Buffer((22016, 256), "uint32"), model_layers_0_mlp_gate_up_proj_q_scale4: T.Buffer((22016, 64), "float16"), p_rms_norm220: T.handle, p_output0: T.handle):
        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int32()
        rms_norm220 = T.match_buffer(p_rms_norm220, (batch_size, 1, 2048), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, 1, 22016), "float16")
        # with T.block("root"):
        if T.tvm_thread_invariant(batch_size <= 2):
            with T.block("root"):
                T.reads()
                T.writes()
                dequantize_intermediate_local = T.alloc_buffer((22016, 2048), "float16", scope="local")
                NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + 1) // 2 * 2, 1, 22016), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((128, (batch_size + 1) // 2 * 2, 1, 22016), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((32, (batch_size + 1) // 2 * 2, 1, 22016), "float16", scope="local")
                for ax0_0 in T.thread_binding((batch_size + 1) // 2, thread="blockIdx.y"):
                    for ax1_fused_0 in T.thread_binding(2752, thread="blockIdx.x"):
                        for ax1_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                            for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax0_1_init, ax1_fused_2_init in T.grid(2, 2):
                                    for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(4):
                                        with T.block("NT_matmul_rf_init"):
                                            vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_init)
                                            v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax0_1_init)
                                            v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2_init)
                                            T.reads()
                                            T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1])
                                            NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = T.float16(0.0)
                                for ax2_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax0_0_1, ax1 in T.grid(2, 8):
                                        for ax0_1 in T.vectorized(1):
                                            with T.block("dequantize"):
                                                v0 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax0_0_1 + ax0_1)
                                                v1 = T.axis.spatial(2048, ax2_fused_0 * 256 + ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax1)
                                                T.reads(model_layers_0_mlp_gate_up_proj_q_weight4[v0, v1 // 8], model_layers_0_mlp_gate_up_proj_q_scale4[v0, v1 // 32])
                                                T.writes(dequantize_intermediate_local[v0, v1])
                                                dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_gate_up_proj_q_weight4[v0, v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_gate_up_proj_q_scale4[v0, v1 // 32]
                                    for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(2, 2, 2):
                                        for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(4):
                                            with T.block("NT_matmul_rf_update"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1)
                                                v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax0_1)
                                                v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2)
                                                vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1], rms_norm220[v0, 0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] + T.if_then_else(v0 < batch_size, rms_norm220[v0, 0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4]
                        for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                            for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax3_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax2 in range(2):
                                        for ax3_fused_2_1 in T.vectorized(2):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                                v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax2)
                                                v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                                NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = T.float16(0.0)
                                            for ax1 in range(4):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                    v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax2)
                                                    v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                    T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1]
                        for ax2_fused_2, ax1 in T.grid(2, 2):
                            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                    with T.block("NT_matmul"):
                                        vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(32, ax0)
                                        v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax1)
                                        v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax2_fused_0_ax2_fused_1_fused * 2 + ax2_fused_2)
                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                        T.writes(NT_matmul_intermediate_pad_local[v0, 0, v1])
                                        with T.init():
                                            NT_matmul_intermediate_pad_local[v0, 0, v1] = T.float16(0.0)
                                        NT_matmul_intermediate_pad_local[v0, 0, v1] = NT_matmul_intermediate_pad_local[v0, 0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1]
                        for ax0 in range(2):
                            for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                for ax1_fused_2 in range(2):
                                    with T.block("NT_matmul_intermediate_pad"):
                                        v0 = T.axis.spatial(batch_size, ax0_0 * 2 + ax0)
                                        v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax1_fused_0_ax1_fused_1_fused * 2 + ax1_fused_2)
                                        T.where((ax0_0 - (batch_size + 1) // 2 < 0 or ax0_0 == 0) and ax0_0 * 2 + ax0 < batch_size)
                                        T.reads(NT_matmul_intermediate_pad_local[v0, 0, v1])
                                        T.writes(NT_matmul_intermediate[v0, 0, v1])
                                        NT_matmul_intermediate[v0, 0, v1] = NT_matmul_intermediate_pad_local[v0, 0, v1]
        else:
            if T.tvm_thread_invariant(batch_size <= 8):
                with T.block("root"):
                    T.reads()
                    T.writes()
                    dequantize_intermediate_local = T.alloc_buffer((22016, 2048), "float16", scope="local")
                    NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + 3) // 4 * 4, 1, 22016), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((128, (batch_size + 3) // 4 * 4, 1, 22016), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((32, (batch_size + 3) // 4 * 4, 1, 22016), "float16", scope="local")
                    for ax0_0 in T.thread_binding((batch_size + 3) // 4, thread="blockIdx.y"):
                        for ax1_fused_0 in T.thread_binding(2752, thread="blockIdx.x"):
                            for ax1_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax0_1_init, ax1_fused_2_init in T.grid(4, 2):
                                        for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(4):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_init)
                                                v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax0_1_init)
                                                v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2_init)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = T.float16(0.0)
                                    for ax2_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax0_0_1, ax1 in T.grid(2, 8):
                                            for ax0_1 in T.vectorized(1):
                                                with T.block("dequantize"):
                                                    v0 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax0_0_1 + ax0_1)
                                                    v1 = T.axis.spatial(2048, ax2_fused_0 * 256 + ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax1)
                                                    T.reads(model_layers_0_mlp_gate_up_proj_q_weight4[v0, v1 // 8], model_layers_0_mlp_gate_up_proj_q_scale4[v0, v1 // 32])
                                                    T.writes(dequantize_intermediate_local[v0, v1])
                                                    dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_gate_up_proj_q_weight4[v0, v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_gate_up_proj_q_scale4[v0, v1 // 32]
                                        for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(4, 2, 2):
                                            for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(4):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1)
                                                    v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax0_1)
                                                    v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2)
                                                    vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                    T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1], rms_norm220[v0, 0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1])
                                                    NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] + T.if_then_else(v0 < batch_size, rms_norm220[v0, 0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4]
                            for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax3_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax2 in range(4):
                                            for ax3_fused_2_1 in T.vectorized(2):
                                                with T.block("NT_matmul_rf_init"):
                                                    vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                                    v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax2)
                                                    v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                    T.reads()
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = T.float16(0.0)
                                                for ax1 in range(4):
                                                    with T.block("NT_matmul_rf_update"):
                                                        vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                        v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax2)
                                                        v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1])
                                                        T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                                        NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1]
                            for ax2_fused_2, ax1 in T.grid(2, 4):
                                for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                    for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                        with T.block("NT_matmul"):
                                            vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(32, ax0)
                                            v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax1)
                                            v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax2_fused_0_ax2_fused_1_fused * 2 + ax2_fused_2)
                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                            T.writes(NT_matmul_intermediate_pad_local[v0, 0, v1])
                                            with T.init():
                                                NT_matmul_intermediate_pad_local[v0, 0, v1] = T.float16(0.0)
                                            NT_matmul_intermediate_pad_local[v0, 0, v1] = NT_matmul_intermediate_pad_local[v0, 0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1]
                            for ax0 in range(4):
                                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                    for ax1_fused_2 in range(2):
                                        with T.block("NT_matmul_intermediate_pad"):
                                            v0 = T.axis.spatial(batch_size, ax0_0 * 4 + ax0)
                                            v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax1_fused_0_ax1_fused_1_fused * 2 + ax1_fused_2)
                                            T.where((ax0_0 - (batch_size + 3) // 4 < 0 or ax0_0 == 0) and ax0_0 * 4 + ax0 < batch_size)
                                            T.reads(NT_matmul_intermediate_pad_local[v0, 0, v1])
                                            T.writes(NT_matmul_intermediate[v0, 0, v1])
                                            NT_matmul_intermediate[v0, 0, v1] = NT_matmul_intermediate_pad_local[v0, 0, v1]
            else:
                with T.block("root"):
                    T.reads()
                    T.writes()
                    NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (batch_size + 31) // 32 * 32, 22016), "float16", scope="local")
                    rms_norm220_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 31) // 32 * 32, 2048), "float16", scope="shared")
                    dequantize_intermediate_reindex_shared = T.alloc_buffer((1, 22016, 2048), "float16", scope="shared")
                    for ax0_ax2_0_fused in T.thread_binding(688, thread="blockIdx.y"):
                        for ax1_0 in T.thread_binding((batch_size + 31) // 32, thread="blockIdx.x"):
                            for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                                for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                                    for ax2_2 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                            for ax1_3_init, ax2_3_0_init in T.grid(4, 4):
                                                for ax2_3_1_init in T.vectorized(1):
                                                    with T.block("NT_matmul_init"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
                                                        v2 = T.axis.spatial(22016, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_0_init + ax2_3_1_init)
                                                        T.reads()
                                                        T.writes(NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                        NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0.0)
                                            for ax3_0 in range(256):
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(4):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                                with T.block("rms_norm220_reindex_pad_shared"):
                                                                    v0 = T.axis.spatial(1, 0)
                                                                    v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                                    v2 = T.axis.spatial(2048, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                                    T.reads(rms_norm220[v1, 0, v2])
                                                                    T.writes(rms_norm220_reindex_pad_shared[v0, v1, v2])
                                                                    rms_norm220_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < batch_size, rms_norm220[v1, 0, v2], T.float16(0.0))
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(4):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                                with T.block("dequantize_intermediate_reindex_shared"):
                                                                    v0 = T.axis.spatial(1, 0)
                                                                    v1 = T.axis.spatial(22016, ax0_ax2_0_fused * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                                    v2 = T.axis.spatial(2048, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                                    T.reads(model_layers_0_mlp_gate_up_proj_q_weight4[v1, v2 // 8], model_layers_0_mlp_gate_up_proj_q_scale4[v1, v2 // 32])
                                                                    T.writes(dequantize_intermediate_reindex_shared[v0, v1, v2])
                                                                    dequantize_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_gate_up_proj_q_weight4[v1, v2 // 8], T.Cast("uint32", v2 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_gate_up_proj_q_scale4[v1, v2 // 32]
                                                for ax3_1, ax1_3, ax2_3_0 in T.grid(8, 4, 4):
                                                    for ax2_3_1 in T.vectorized(1):
                                                        with T.block("NT_matmul_update"):
                                                            v0 = T.axis.spatial(1, 0)
                                                            v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
                                                            v2 = T.axis.spatial(22016, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_0 + ax2_3_1)
                                                            v3 = T.axis.reduce(2048, ax3_0 * 8 + ax3_1)
                                                            T.reads(NT_matmul_intermediate_reindex_pad_local[0, v1, v2], rms_norm220_reindex_pad_shared[0, v1, v3], dequantize_intermediate_reindex_shared[0, v2, v3])
                                                            T.writes(NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                            NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + rms_norm220_reindex_pad_shared[0, v1, v3] * dequantize_intermediate_reindex_shared[0, v2, v3]
                                            for ax0, ax1, ax2_0 in T.grid(1, 4, 4):
                                                for ax2_1_1 in T.vectorized(1):
                                                    with T.block("NT_matmul_intermediate_reindex_pad_local"):
                                                        v0 = T.axis.spatial(1, ax0)
                                                        v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                                        v2 = T.axis.spatial(22016, ax0_ax2_0_fused * 32 + ax2_2 * 4 + ax2_0 + ax2_1_1)
                                                        T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < batch_size)
                                                        T.reads(NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                                        T.writes(NT_matmul_intermediate[v1, 0, v2])
                                                        NT_matmul_intermediate[v1, 0, v2] = NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]

    @T.prim_func
    def fused_dequantize3_NT_matmul7(model_layers_0_mlp_gate_up_proj_q_weight3: T.Buffer((22016, 256), "uint32"), model_layers_0_mlp_gate_up_proj_q_scale3: T.Buffer((22016, 64), "float16"), p_rms_norm147: T.handle, p_output0: T.handle):
        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        rms_norm147 = T.match_buffer(p_rms_norm147, (1, seq_len, 2048), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (1, seq_len, 22016), "float16")
        # with T.block("root"):
        if T.tvm_thread_invariant(seq_len <= 2):
            with T.block("root"):
                T.reads()
                T.writes()
                dequantize_intermediate_local = T.alloc_buffer((22016, 2048), "float16", scope="local")
                NT_matmul_intermediate_pad_local = T.alloc_buffer((1, (seq_len + 1) // 2 * 2, 22016), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((128, 1, (seq_len + 1) // 2 * 2, 22016), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((32, 1, (seq_len + 1) // 2 * 2, 22016), "float16", scope="local")
                for ax0_0 in T.thread_binding((seq_len + 1) // 2, thread="blockIdx.y"):
                    for ax1_fused_0 in T.thread_binding(2752, thread="blockIdx.x"):
                        for ax1_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                            for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax0_1_init, ax1_fused_2_init in T.grid(2, 2):
                                    for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(4):
                                        with T.block("NT_matmul_rf_init"):
                                            vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_init)
                                            v0 = T.axis.spatial((seq_len + 1) // 2 * 2, ax0_0 * 2 + ax0_1_init)
                                            v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2_init)
                                            T.reads()
                                            T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1])
                                            NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] = T.float16(0.0)
                                for ax2_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax0_0_1, ax1 in T.grid(2, 8):
                                        for ax0_1 in T.vectorized(1):
                                            with T.block("dequantize"):
                                                v0 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax0_0_1 + ax0_1)
                                                v1 = T.axis.spatial(2048, ax2_fused_0 * 256 + ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax1)
                                                T.reads(model_layers_0_mlp_gate_up_proj_q_weight3[v0, v1 // 8], model_layers_0_mlp_gate_up_proj_q_scale3[v0, v1 // 32])
                                                T.writes(dequantize_intermediate_local[v0, v1])
                                                dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_gate_up_proj_q_weight3[v0, v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_gate_up_proj_q_scale3[v0, v1 // 32]
                                    for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(2, 2, 2):
                                        for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(4):
                                            with T.block("NT_matmul_rf_update"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1)
                                                v0 = T.axis.spatial((seq_len + 1) // 2 * 2, ax0_0 * 2 + ax0_1)
                                                v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2)
                                                vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1], rms_norm147[0, v0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] + T.if_then_else(v0 < seq_len, rms_norm147[0, v0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4]
                        for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                            for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax3_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax2 in range(2):
                                        for ax3_fused_2_1 in T.vectorized(2):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                                v0 = T.axis.spatial((seq_len + 1) // 2 * 2, ax0_0 * 2 + ax2)
                                                v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                                NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] = T.float16(0.0)
                                            for ax1 in range(4):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                    v0 = T.axis.spatial((seq_len + 1) // 2 * 2, ax0_0 * 2 + ax2)
                                                    v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                    T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, 0, v0, v1])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, 0, v0, v1]
                        for ax2_fused_2, ax1 in T.grid(2, 2):
                            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                    with T.block("NT_matmul"):
                                        vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(32, ax0)
                                        v0 = T.axis.spatial((seq_len + 1) // 2 * 2, ax0_0 * 2 + ax1)
                                        v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax2_fused_0_ax2_fused_1_fused * 2 + ax2_fused_2)
                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                        T.writes(NT_matmul_intermediate_pad_local[0, v0, v1])
                                        with T.init():
                                            NT_matmul_intermediate_pad_local[0, v0, v1] = T.float16(0.0)
                                        NT_matmul_intermediate_pad_local[0, v0, v1] = NT_matmul_intermediate_pad_local[0, v0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1]
                        for ax0 in range(2):
                            for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                for ax1_fused_2 in range(2):
                                    with T.block("NT_matmul_intermediate_pad"):
                                        v0 = T.axis.spatial(seq_len, ax0_0 * 2 + ax0)
                                        v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax1_fused_0_ax1_fused_1_fused * 2 + ax1_fused_2)
                                        T.where((ax0_0 - (seq_len + 1) // 2 < 0 or ax0_0 == 0) and ax0_0 * 2 + ax0 < seq_len)
                                        T.reads(NT_matmul_intermediate_pad_local[0, v0, v1])
                                        T.writes(NT_matmul_intermediate[0, v0, v1])
                                        NT_matmul_intermediate[0, v0, v1] = NT_matmul_intermediate_pad_local[0, v0, v1]
        else:
            if T.tvm_thread_invariant(seq_len <= 8):
                with T.block("root"):
                    T.reads()
                    T.writes()
                    dequantize_intermediate_local = T.alloc_buffer((22016, 2048), "float16", scope="local")
                    NT_matmul_intermediate_pad_local = T.alloc_buffer((1, (seq_len + 3) // 4 * 4, 22016), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((128, 1, (seq_len + 3) // 4 * 4, 22016), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((32, 1, (seq_len + 3) // 4 * 4, 22016), "float16", scope="local")
                    for ax0_0 in T.thread_binding((seq_len + 3) // 4, thread="blockIdx.y"):
                        for ax1_fused_0 in T.thread_binding(2752, thread="blockIdx.x"):
                            for ax1_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax0_1_init, ax1_fused_2_init in T.grid(4, 2):
                                        for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(4):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_init)
                                                v0 = T.axis.spatial((seq_len + 3) // 4 * 4, ax0_0 * 4 + ax0_1_init)
                                                v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2_init)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] = T.float16(0.0)
                                    for ax2_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax0_0_1, ax1 in T.grid(2, 8):
                                            for ax0_1 in T.vectorized(1):
                                                with T.block("dequantize"):
                                                    v0 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax0_0_1 + ax0_1)
                                                    v1 = T.axis.spatial(2048, ax2_fused_0 * 256 + ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax1)
                                                    T.reads(model_layers_0_mlp_gate_up_proj_q_weight3[v0, v1 // 8], model_layers_0_mlp_gate_up_proj_q_scale3[v0, v1 // 32])
                                                    T.writes(dequantize_intermediate_local[v0, v1])
                                                    dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_gate_up_proj_q_weight3[v0, v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_gate_up_proj_q_scale3[v0, v1 // 32]
                                        for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(4, 2, 2):
                                            for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(4):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1)
                                                    v0 = T.axis.spatial((seq_len + 3) // 4 * 4, ax0_0 * 4 + ax0_1)
                                                    v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2)
                                                    vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                    T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1], rms_norm147[0, v0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] + T.if_then_else(v0 < seq_len, rms_norm147[0, v0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4]
                            for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax3_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax2 in range(4):
                                            for ax3_fused_2_1 in T.vectorized(2):
                                                with T.block("NT_matmul_rf_init"):
                                                    vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                                    v0 = T.axis.spatial((seq_len + 3) // 4 * 4, ax0_0 * 4 + ax2)
                                                    v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                    T.reads()
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] = T.float16(0.0)
                                                for ax1 in range(4):
                                                    with T.block("NT_matmul_rf_update"):
                                                        vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                        v0 = T.axis.spatial((seq_len + 3) // 4 * 4, ax0_0 * 4 + ax2)
                                                        v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, 0, v0, v1])
                                                        T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                                        NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, 0, v0, v1]
                            for ax2_fused_2, ax1 in T.grid(2, 4):
                                for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                    for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                        with T.block("NT_matmul"):
                                            vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(32, ax0)
                                            v0 = T.axis.spatial((seq_len + 3) // 4 * 4, ax0_0 * 4 + ax1)
                                            v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax2_fused_0_ax2_fused_1_fused * 2 + ax2_fused_2)
                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                            T.writes(NT_matmul_intermediate_pad_local[0, v0, v1])
                                            with T.init():
                                                NT_matmul_intermediate_pad_local[0, v0, v1] = T.float16(0.0)
                                            NT_matmul_intermediate_pad_local[0, v0, v1] = NT_matmul_intermediate_pad_local[0, v0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1]
                            for ax0 in range(4):
                                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                    for ax1_fused_2 in range(2):
                                        with T.block("NT_matmul_intermediate_pad"):
                                            v0 = T.axis.spatial(seq_len, ax0_0 * 4 + ax0)
                                            v1 = T.axis.spatial(22016, ax1_fused_0 * 8 + ax1_fused_0_ax1_fused_1_fused * 2 + ax1_fused_2)
                                            T.where((ax0_0 - (seq_len + 3) // 4 < 0 or ax0_0 == 0) and ax0_0 * 4 + ax0 < seq_len)
                                            T.reads(NT_matmul_intermediate_pad_local[0, v0, v1])
                                            T.writes(NT_matmul_intermediate[0, v0, v1])
                                            NT_matmul_intermediate[0, v0, v1] = NT_matmul_intermediate_pad_local[0, v0, v1]
            else:
                with T.block("root"):
                    T.reads()
                    T.writes()
                    NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (seq_len + 31) // 32 * 32, 22016), "float16", scope="local")
                    rms_norm147_reindex_pad_shared = T.alloc_buffer((1, (seq_len + 31) // 32 * 32, 2048), "float16", scope="shared")
                    dequantize_intermediate_reindex_shared = T.alloc_buffer((1, 22016, 2048), "float16", scope="shared")
                    for ax0_ax2_0_fused in T.thread_binding(688, thread="blockIdx.y"):
                        for ax1_0 in T.thread_binding((seq_len + 31) // 32, thread="blockIdx.x"):
                            for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                                for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                                    for ax2_2 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                            for ax1_3_init, ax2_3_0_init in T.grid(4, 4):
                                                for ax2_3_1_init in T.vectorized(1):
                                                    with T.block("NT_matmul_init"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial((seq_len + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
                                                        v2 = T.axis.spatial(22016, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_0_init + ax2_3_1_init)
                                                        T.reads()
                                                        T.writes(NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                        NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0.0)
                                            for ax3_0 in range(256):
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(4):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                                with T.block("rms_norm147_reindex_pad_shared"):
                                                                    v0 = T.axis.spatial(1, 0)
                                                                    v1 = T.axis.spatial((seq_len + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                                    v2 = T.axis.spatial(2048, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                                    T.reads(rms_norm147[v0, v1, v2])
                                                                    T.writes(rms_norm147_reindex_pad_shared[v0, v1, v2])
                                                                    rms_norm147_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, rms_norm147[v0, v1, v2], T.float16(0.0))
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(4):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                                with T.block("dequantize_intermediate_reindex_shared"):
                                                                    v0 = T.axis.spatial(1, 0)
                                                                    v1 = T.axis.spatial(22016, ax0_ax2_0_fused * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                                    v2 = T.axis.spatial(2048, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                                    T.reads(model_layers_0_mlp_gate_up_proj_q_weight3[v1, v2 // 8], model_layers_0_mlp_gate_up_proj_q_scale3[v1, v2 // 32])
                                                                    T.writes(dequantize_intermediate_reindex_shared[v0, v1, v2])
                                                                    dequantize_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_gate_up_proj_q_weight3[v1, v2 // 8], T.Cast("uint32", v2 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_gate_up_proj_q_scale3[v1, v2 // 32]
                                                for ax3_1, ax1_3, ax2_3_0 in T.grid(8, 4, 4):
                                                    for ax2_3_1 in T.vectorized(1):
                                                        with T.block("NT_matmul_update"):
                                                            v0 = T.axis.spatial(1, 0)
                                                            v1 = T.axis.spatial((seq_len + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
                                                            v2 = T.axis.spatial(22016, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_0 + ax2_3_1)
                                                            v3 = T.axis.reduce(2048, ax3_0 * 8 + ax3_1)
                                                            T.reads(NT_matmul_intermediate_reindex_pad_local[0, v1, v2], rms_norm147_reindex_pad_shared[0, v1, v3], dequantize_intermediate_reindex_shared[0, v2, v3])
                                                            T.writes(NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                            NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + rms_norm147_reindex_pad_shared[0, v1, v3] * dequantize_intermediate_reindex_shared[0, v2, v3]
                                            for ax0, ax1, ax2_0 in T.grid(1, 4, 4):
                                                for ax2_1_1 in T.vectorized(1):
                                                    with T.block("NT_matmul_intermediate_reindex_pad_local"):
                                                        v0 = T.axis.spatial(1, ax0)
                                                        v1 = T.axis.spatial((seq_len + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                                        v2 = T.axis.spatial(22016, ax0_ax2_0_fused * 32 + ax2_2 * 4 + ax2_0 + ax2_1_1)
                                                        T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < seq_len)
                                                        T.reads(NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                                        T.writes(NT_matmul_intermediate[0, v1, v2])
                                                        NT_matmul_intermediate[0, v1, v2] = NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]

    @T.prim_func
    def fused_dequantize4_NT_matmul13(model_layers_0_mlp_down_proj_q_weight2: T.Buffer((2048, 1376), "uint32"), model_layers_0_mlp_down_proj_q_scale2: T.Buffer((2048, 344), "float16"), lv219: T.Buffer((1, 1, 11008), "float16"), NT_matmul_intermediate: T.Buffer((1, 1, 2048), "float16")):
        # Fused dequantize + matrix-vector product for a single (1, 1, 11008) input row:
        #   NT_matmul_intermediate[0, 0, n] = sum_k lv219[0, 0, k] * W[n, k]   (n < 2048, k < 11008)
        # where W is dequantized on the fly from the packed weight and its per-group scale:
        #   W[n, k] = (float16((q_weight2[n, k // 8] >> (k % 8 * 4)) & 15) - 7) * q_scale2[n, k // 32]
        # i.e. eight 4-bit fields per uint32 word (1376 * 8 = 11008 columns) and one
        # float16 scale per 32 consecutive k (344 * 32 = 11008).
        #
        # Schedule, as encoded by the loop bindings below:
        #   * blockIdx.x (256) x threadIdx.y (8): one output element v0 per (block, ty).
        #   * threadIdx.x (32) x 4 vector slots: 128 partial sums per output
        #     (NT_matmul_intermediate_rf_local). These are folded to 32 per output
        #     (NT_matmul_intermediate_rf_local_1) and then reduced across threadIdx.x
        #     into NT_matmul_intermediate.
        # Every accumulator buffer is declared "float16", so the whole dot product is
        # accumulated in float16.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        # The intermediate buffers below are declared at full logical size. Each thread
        # only touches a small slice; presumably later lowering compacts them to the
        # accessed region (TODO confirm).
        NT_matmul_intermediate_rf_local = T.alloc_buffer((128, 1, 1, 2048), "float16", scope="local")
        NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((32, 1, 1, 2048), "float16", scope="local")
        model_layers_0_mlp_down_proj_q_weight2_local = T.alloc_buffer((2048, 1376), "uint32", scope="local")
        lv219_shared = T.alloc_buffer((1, 1, 11008), "float16", scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(256, thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                    # Stage 1: stage the whole input row into shared memory. The inner
                    # threadIdx.y / threadIdx.x loops re-bind the same thread axes as the
                    # enclosing loops, so the block's 8 x 32 threads jointly copy 256
                    # elements per iteration (43 * 256 = 11008).
                    for ax0, ax1 in T.grid(1, 1):
                        # NOTE(review): this loop's annotation keys differ from the other
                        # serial loops in this function (pragma_unroll_explicit=256 and
                        # pragma_vectorize=1, versus pragma_auto_unroll_max_step=256 and
                        # pragma_unroll_explicit=1). Presumably emitted as-is by the
                        # scheduler; confirm that this is intended.
                        for ax2_0 in T.serial(43, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(8, thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(1):
                                        with T.block("lv219_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(11008, ax2_0 * 256 + ax2_1 * 32 + ax2_2 + ax2_3)
                                            T.reads(lv219[v0, v1, v2])
                                            T.writes(lv219_shared[v0, v1, v2])
                                            lv219_shared[v0, v1, v2] = lv219[v0, v1, v2]
                    # Stage 2: zero this thread's 4 first-level partial sums
                    # (rf slots threadIdx.x * 4 + 0..3) for output v0 = block * 8 + threadIdx.y.
                    for u_fused_ax0_fused_fused_2_init in range(1):
                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(4):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(128, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
                                v0 = T.axis.spatial(2048, u_fused_ax0_fused_fused_0 * 8 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0.0)
                    # Stage 3: main reduction. 43 iterations cover 256 k-values each
                    # (43 * 256 = 11008).
                    for ax1_0_fused_ax1_1_fused_0 in T.serial(43, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        # Load one packed uint32 word per thread per iteration into local
                        # scope. The word index is iter * 32 + threadIdx.x.
                        for ax0_ax1_fused_0 in range(1):
                            for ax0_ax1_fused_1 in T.vectorized(1):
                                with T.block("model_layers_0_mlp_down_proj_q_weight2_local"):
                                    v0 = T.axis.spatial(2048, u_fused_ax0_fused_fused_0 * 8 + u_fused_ax0_fused_fused_1)
                                    v1 = T.axis.spatial(1376, ax1_0_fused_ax1_1_fused_0 * 32 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
                                    T.reads(model_layers_0_mlp_down_proj_q_weight2[v0, v1])
                                    T.writes(model_layers_0_mlp_down_proj_q_weight2_local[v0, v1])
                                    model_layers_0_mlp_down_proj_q_weight2_local[v0, v1] = model_layers_0_mlp_down_proj_q_weight2[v0, v1]
                        # Consume all 8 nibbles of that word: 2 (ax1_0_fused_ax1_1_fused_2) x
                        # 4 vector lanes. The reduction index is
                        #   k = iter * 256 + threadIdx.x * 8 + f2 * 4 + lane,
                        # so k // 8 == iter * 32 + threadIdx.x, which is exactly the word
                        # loaded above. The "// 2" term on the word index is always 0
                        # because f2 < 2. Each lane accumulates into its own rf slot
                        # (threadIdx.x * 4 + lane).
                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 2):
                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(4):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(128, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
                                    v0 = T.axis.spatial(2048, u_fused_ax0_fused_fused_0 * 8 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
                                    T.reads(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv219_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4], model_layers_0_mlp_down_proj_q_weight2_local[v0, vax1_0_fused_ax1_1_fused_0 * 32 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + vax1_0_fused_ax1_1_fused_2 // 2], model_layers_0_mlp_down_proj_q_scale2[v0, (vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) // 32])
                                    T.writes(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                    # partial += x[k] * ((nibble(word, k % 8) - 7) * scale[v0, k // 32])
                                    NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv219_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4] * ((T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_down_proj_q_weight2_local[v0, vax1_0_fused_ax1_1_fused_0 * 32 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + vax1_0_fused_ax1_1_fused_2 // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_down_proj_q_scale2[v0, (vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) // 32])
            # Stage 4: each (threadIdx.y, threadIdx.x) thread folds its 4 first-level
            # partials (rf slots threadIdx.x * 4 + 0..3) into one second-level partial
            # (rf_local_1 slot threadIdx.x). This leaves 32 partials per output.
            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                    for ax2_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_2_1 in T.vectorized(1):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                v0 = T.axis.spatial(2048, u_fused_ax0_fused_fused_0 * 8 + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                T.reads()
                                T.writes(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0.0)
                            for ax1 in range(4):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(2048, u_fused_ax0_fused_fused_0 * 8 + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                    T.reads(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
                                    T.writes(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                    NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
            # Stage 5: final reduction over the 32 second-level partials. The reduce axis
            # is bound to threadIdx.x, so this is a cross-thread reduction that writes
            # one output element per (block, threadIdx.y).
            for ax1_fused_2 in range(1):
                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                    for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(32, ax0)
                            v0 = T.axis.spatial(2048, u_fused_ax0_fused_fused_0 * 8 + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2)
                            T.reads(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                            T.writes(NT_matmul_intermediate[0, 0, v0])
                            with T.init():
                                NT_matmul_intermediate[0, 0, v0] = T.float16(0.0)
                            NT_matmul_intermediate[0, 0, v0] = NT_matmul_intermediate[0, 0, v0] + NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]

    @T.prim_func
    def fused_dequantize4_NT_matmul3(model_layers_0_mlp_down_proj_q_weight4: T.Buffer((2048, 1376), "uint32"), model_layers_0_mlp_down_proj_q_scale4: T.Buffer((2048, 344), "float16"), p_lv1: T.handle, p_output0: T.handle):
        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int32()
        lv1 = T.match_buffer(p_lv1, (batch_size, 1, 11008), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, 1, 2048), "float16")
        # with T.block("root"):
        if T.tvm_thread_invariant(batch_size <= 2):
            with T.block("root"):
                T.reads()
                T.writes()
                dequantize_intermediate_local = T.alloc_buffer((2048, 11008), "float16", scope="local")
                NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + 1) // 2 * 2, 1, 2048), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((128, (batch_size + 1) // 2 * 2, 1, 2048), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((32, (batch_size + 1) // 2 * 2, 1, 2048), "float16", scope="local")
                for ax0_0 in T.thread_binding((batch_size + 1) // 2, thread="blockIdx.y"):
                    for ax1_fused_0 in T.thread_binding(128, thread="blockIdx.x"):
                        for ax1_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
                            for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax0_1_init, ax1_fused_2_init in T.grid(2, 2):
                                    for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(4):
                                        with T.block("NT_matmul_rf_init"):
                                            vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_init)
                                            v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax0_1_init)
                                            v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax1_fused_2_init)
                                            T.reads()
                                            T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1])
                                            NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = T.float16(0.0)
                                for ax2_fused_0 in T.serial(43, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax0_0_1, ax1 in T.grid(2, 8):
                                        for ax0_1 in T.vectorized(1):
                                            with T.block("dequantize"):
                                                v0 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax0_0_1 + ax0_1)
                                                v1 = T.axis.spatial(11008, ax2_fused_0 * 256 + ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax1)
                                                T.reads(model_layers_0_mlp_down_proj_q_weight4[v0, v1 // 8], model_layers_0_mlp_down_proj_q_scale4[v0, v1 // 32])
                                                T.writes(dequantize_intermediate_local[v0, v1])
                                                dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_down_proj_q_weight4[v0, v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_down_proj_q_scale4[v0, v1 // 32]
                                    for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(2, 2, 2):
                                        for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(4):
                                            with T.block("NT_matmul_rf_update"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1)
                                                v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax0_1)
                                                v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax1_fused_2)
                                                vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1], lv1[v0, 0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] + T.if_then_else(v0 < batch_size, lv1[v0, 0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4]
                        for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                            for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax3_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax2 in range(2):
                                        for ax3_fused_2_1 in T.vectorized(2):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                                v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax2)
                                                v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                                NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = T.float16(0.0)
                                            for ax1 in range(4):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                    v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax2)
                                                    v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                    T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1]
                        for ax2_fused_2, ax1 in T.grid(2, 2):
                            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                    with T.block("NT_matmul"):
                                        vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(32, ax0)
                                        v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax1)
                                        v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax2_fused_0_ax2_fused_1_fused * 2 + ax2_fused_2)
                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                        T.writes(NT_matmul_intermediate_pad_local[v0, 0, v1])
                                        with T.init():
                                            NT_matmul_intermediate_pad_local[v0, 0, v1] = T.float16(0.0)
                                        NT_matmul_intermediate_pad_local[v0, 0, v1] = NT_matmul_intermediate_pad_local[v0, 0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1]
                        for ax0 in range(2):
                            for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                                for ax1_fused_2 in range(2):
                                    with T.block("NT_matmul_intermediate_pad"):
                                        v0 = T.axis.spatial(batch_size, ax0_0 * 2 + ax0)
                                        v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_0_ax1_fused_1_fused * 2 + ax1_fused_2)
                                        T.where((ax0_0 - (batch_size + 1) // 2 < 0 or ax0_0 == 0) and ax0_0 * 2 + ax0 < batch_size)
                                        T.reads(NT_matmul_intermediate_pad_local[v0, 0, v1])
                                        T.writes(NT_matmul_intermediate[v0, 0, v1])
                                        NT_matmul_intermediate[v0, 0, v1] = NT_matmul_intermediate_pad_local[v0, 0, v1]
        else:
            if T.tvm_thread_invariant(batch_size <= 8):
                with T.block("root"):
                    T.reads()
                    T.writes()
                    dequantize_intermediate_local = T.alloc_buffer((2048, 11008), "float16", scope="local")
                    NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + 3) // 4 * 4, 1, 2048), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((128, (batch_size + 3) // 4 * 4, 1, 2048), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((32, (batch_size + 3) // 4 * 4, 1, 2048), "float16", scope="local")
                    for ax0_0 in T.thread_binding((batch_size + 3) // 4, thread="blockIdx.y"):
                        for ax1_fused_0 in T.thread_binding(128, thread="blockIdx.x"):
                            for ax1_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
                                for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax0_1_init, ax1_fused_2_init in T.grid(4, 2):
                                        for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(4):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_init)
                                                v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax0_1_init)
                                                v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax1_fused_2_init)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = T.float16(0.0)
                                    for ax2_fused_0 in T.serial(43, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax0_0_1, ax1 in T.grid(2, 8):
                                            for ax0_1 in T.vectorized(1):
                                                with T.block("dequantize"):
                                                    v0 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax0_0_1 + ax0_1)
                                                    v1 = T.axis.spatial(11008, ax2_fused_0 * 256 + ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax1)
                                                    T.reads(model_layers_0_mlp_down_proj_q_weight4[v0, v1 // 8], model_layers_0_mlp_down_proj_q_scale4[v0, v1 // 32])
                                                    T.writes(dequantize_intermediate_local[v0, v1])
                                                    dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_down_proj_q_weight4[v0, v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_down_proj_q_scale4[v0, v1 // 32]
                                        for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(4, 2, 2):
                                            for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(4):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1)
                                                    v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax0_1)
                                                    v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax1_fused_2)
                                                    vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                    T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1], lv1[v0, 0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1])
                                                    NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] + T.if_then_else(v0 < batch_size, lv1[v0, 0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4]
                            for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax3_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax2 in range(4):
                                            for ax3_fused_2_1 in T.vectorized(2):
                                                with T.block("NT_matmul_rf_init"):
                                                    vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                                    v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax2)
                                                    v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                    T.reads()
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = T.float16(0.0)
                                                for ax1 in range(4):
                                                    with T.block("NT_matmul_rf_update"):
                                                        vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                        v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax2)
                                                        v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1])
                                                        T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                                        NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1]
                            for ax2_fused_2, ax1 in T.grid(2, 4):
                                for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                                    for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                        with T.block("NT_matmul"):
                                            vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(32, ax0)
                                            v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax1)
                                            v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax2_fused_0_ax2_fused_1_fused * 2 + ax2_fused_2)
                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                            T.writes(NT_matmul_intermediate_pad_local[v0, 0, v1])
                                            with T.init():
                                                NT_matmul_intermediate_pad_local[v0, 0, v1] = T.float16(0.0)
                                            NT_matmul_intermediate_pad_local[v0, 0, v1] = NT_matmul_intermediate_pad_local[v0, 0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1]
                            for ax0 in range(4):
                                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                                    for ax1_fused_2 in range(2):
                                        with T.block("NT_matmul_intermediate_pad"):
                                            v0 = T.axis.spatial(batch_size, ax0_0 * 4 + ax0)
                                            v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_0_ax1_fused_1_fused * 2 + ax1_fused_2)
                                            T.where((ax0_0 - (batch_size + 3) // 4 < 0 or ax0_0 == 0) and ax0_0 * 4 + ax0 < batch_size)
                                            T.reads(NT_matmul_intermediate_pad_local[v0, 0, v1])
                                            T.writes(NT_matmul_intermediate[v0, 0, v1])
                                            NT_matmul_intermediate[v0, 0, v1] = NT_matmul_intermediate_pad_local[v0, 0, v1]
            else:
                with T.block("root"):
                    T.reads()
                    T.writes()
                    NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (batch_size + 31) // 32 * 32, 2048), "float16", scope="local")
                    lv1_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 31) // 32 * 32, 11008), "float16", scope="shared")
                    dequantize_intermediate_reindex_shared = T.alloc_buffer((1, 2048, 11008), "float16", scope="shared")
                    for ax0_ax2_0_fused in T.thread_binding(64, thread="blockIdx.y"):
                        for ax1_0 in T.thread_binding((batch_size + 31) // 32, thread="blockIdx.x"):
                            for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                                for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                                    for ax2_2 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                            for ax1_3_init, ax2_3_0_init in T.grid(4, 4):
                                                for ax2_3_1_init in T.vectorized(1):
                                                    with T.block("NT_matmul_init"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
                                                        v2 = T.axis.spatial(2048, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_0_init + ax2_3_1_init)
                                                        T.reads()
                                                        T.writes(NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                        NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0.0)
                                            for ax3_0 in range(1376):
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(4):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                                with T.block("lv1_reindex_pad_shared"):
                                                                    v0 = T.axis.spatial(1, 0)
                                                                    v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                                    v2 = T.axis.spatial(11008, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                                    T.reads(lv1[v1, 0, v2])
                                                                    T.writes(lv1_reindex_pad_shared[v0, v1, v2])
                                                                    lv1_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < batch_size, lv1[v1, 0, v2], T.float16(0.0))
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(4):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                                with T.block("dequantize_intermediate_reindex_shared"):
                                                                    v0 = T.axis.spatial(1, 0)
                                                                    v1 = T.axis.spatial(2048, ax0_ax2_0_fused * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                                    v2 = T.axis.spatial(11008, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                                    T.reads(model_layers_0_mlp_down_proj_q_weight4[v1, v2 // 8], model_layers_0_mlp_down_proj_q_scale4[v1, v2 // 32])
                                                                    T.writes(dequantize_intermediate_reindex_shared[v0, v1, v2])
                                                                    dequantize_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_down_proj_q_weight4[v1, v2 // 8], T.Cast("uint32", v2 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_down_proj_q_scale4[v1, v2 // 32]
                                                for ax3_1, ax1_3, ax2_3_0 in T.grid(8, 4, 4):
                                                    for ax2_3_1 in T.vectorized(1):
                                                        with T.block("NT_matmul_update"):
                                                            v0 = T.axis.spatial(1, 0)
                                                            v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
                                                            v2 = T.axis.spatial(2048, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_0 + ax2_3_1)
                                                            v3 = T.axis.reduce(11008, ax3_0 * 8 + ax3_1)
                                                            T.reads(NT_matmul_intermediate_reindex_pad_local[0, v1, v2], lv1_reindex_pad_shared[0, v1, v3], dequantize_intermediate_reindex_shared[0, v2, v3])
                                                            T.writes(NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                            NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + lv1_reindex_pad_shared[0, v1, v3] * dequantize_intermediate_reindex_shared[0, v2, v3]
                                            for ax0, ax1, ax2_0 in T.grid(1, 4, 4):
                                                for ax2_1_1 in T.vectorized(1):
                                                    with T.block("NT_matmul_intermediate_reindex_pad_local"):
                                                        v0 = T.axis.spatial(1, ax0)
                                                        v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                                        v2 = T.axis.spatial(2048, ax0_ax2_0_fused * 32 + ax2_2 * 4 + ax2_0 + ax2_1_1)
                                                        T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < batch_size)
                                                        T.reads(NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                                        T.writes(NT_matmul_intermediate[v1, 0, v2])
                                                        NT_matmul_intermediate[v1, 0, v2] = NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]

    @T.prim_func
    def fused_dequantize4_NT_matmul8(model_layers_0_mlp_down_proj_q_weight3: T.Buffer((2048, 1376), "uint32"), model_layers_0_mlp_down_proj_q_scale3: T.Buffer((2048, 344), "float16"), p_lv73: T.handle, p_output0: T.handle):
        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        lv73 = T.match_buffer(p_lv73, (1, seq_len, 11008), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (1, seq_len, 2048), "float16")
        # with T.block("root"):
        if T.tvm_thread_invariant(seq_len <= 2):
            with T.block("root"):
                T.reads()
                T.writes()
                dequantize_intermediate_local = T.alloc_buffer((2048, 11008), "float16", scope="local")
                NT_matmul_intermediate_pad_local = T.alloc_buffer((1, (seq_len + 1) // 2 * 2, 2048), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((128, 1, (seq_len + 1) // 2 * 2, 2048), "float16", scope="local")
                NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((32, 1, (seq_len + 1) // 2 * 2, 2048), "float16", scope="local")
                for ax0_0 in T.thread_binding((seq_len + 1) // 2, thread="blockIdx.y"):
                    for ax1_fused_0 in T.thread_binding(128, thread="blockIdx.x"):
                        for ax1_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
                            for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax0_1_init, ax1_fused_2_init in T.grid(2, 2):
                                    for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(4):
                                        with T.block("NT_matmul_rf_init"):
                                            vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_init)
                                            v0 = T.axis.spatial((seq_len + 1) // 2 * 2, ax0_0 * 2 + ax0_1_init)
                                            v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax1_fused_2_init)
                                            T.reads()
                                            T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1])
                                            NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] = T.float16(0.0)
                                for ax2_fused_0 in T.serial(43, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax0_0_1, ax1 in T.grid(2, 8):
                                        for ax0_1 in T.vectorized(1):
                                            with T.block("dequantize"):
                                                v0 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax0_0_1 + ax0_1)
                                                v1 = T.axis.spatial(11008, ax2_fused_0 * 256 + ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax1)
                                                T.reads(model_layers_0_mlp_down_proj_q_weight3[v0, v1 // 8], model_layers_0_mlp_down_proj_q_scale3[v0, v1 // 32])
                                                T.writes(dequantize_intermediate_local[v0, v1])
                                                dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_down_proj_q_weight3[v0, v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_down_proj_q_scale3[v0, v1 // 32]
                                    for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(2, 2, 2):
                                        for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(4):
                                            with T.block("NT_matmul_rf_update"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1)
                                                v0 = T.axis.spatial((seq_len + 1) // 2 * 2, ax0_0 * 2 + ax0_1)
                                                v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax1_fused_2)
                                                vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1], lv73[0, v0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] + T.if_then_else(v0 < seq_len, lv73[0, v0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4]
                        for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                            for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax3_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax2 in range(2):
                                        for ax3_fused_2_1 in T.vectorized(2):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                                v0 = T.axis.spatial((seq_len + 1) // 2 * 2, ax0_0 * 2 + ax2)
                                                v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                                NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] = T.float16(0.0)
                                            for ax1 in range(4):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                    v0 = T.axis.spatial((seq_len + 1) // 2 * 2, ax0_0 * 2 + ax2)
                                                    v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                    T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, 0, v0, v1])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, 0, v0, v1]
                        for ax2_fused_2, ax1 in T.grid(2, 2):
                            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                    with T.block("NT_matmul"):
                                        vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(32, ax0)
                                        v0 = T.axis.spatial((seq_len + 1) // 2 * 2, ax0_0 * 2 + ax1)
                                        v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax2_fused_0_ax2_fused_1_fused * 2 + ax2_fused_2)
                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                        T.writes(NT_matmul_intermediate_pad_local[0, v0, v1])
                                        with T.init():
                                            NT_matmul_intermediate_pad_local[0, v0, v1] = T.float16(0.0)
                                        NT_matmul_intermediate_pad_local[0, v0, v1] = NT_matmul_intermediate_pad_local[0, v0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1]
                        for ax0 in range(2):
                            for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                                for ax1_fused_2 in range(2):
                                    with T.block("NT_matmul_intermediate_pad"):
                                        v0 = T.axis.spatial(seq_len, ax0_0 * 2 + ax0)
                                        v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_0_ax1_fused_1_fused * 2 + ax1_fused_2)
                                        T.where((ax0_0 - (seq_len + 1) // 2 < 0 or ax0_0 == 0) and ax0_0 * 2 + ax0 < seq_len)
                                        T.reads(NT_matmul_intermediate_pad_local[0, v0, v1])
                                        T.writes(NT_matmul_intermediate[0, v0, v1])
                                        NT_matmul_intermediate[0, v0, v1] = NT_matmul_intermediate_pad_local[0, v0, v1]
        else:
            if T.tvm_thread_invariant(seq_len <= 8):
                with T.block("root"):
                    T.reads()
                    T.writes()
                    dequantize_intermediate_local = T.alloc_buffer((2048, 11008), "float16", scope="local")
                    NT_matmul_intermediate_pad_local = T.alloc_buffer((1, (seq_len + 3) // 4 * 4, 2048), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((128, 1, (seq_len + 3) // 4 * 4, 2048), "float16", scope="local")
                    NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((32, 1, (seq_len + 3) // 4 * 4, 2048), "float16", scope="local")
                    for ax0_0 in T.thread_binding((seq_len + 3) // 4, thread="blockIdx.y"):
                        for ax1_fused_0 in T.thread_binding(128, thread="blockIdx.x"):
                            for ax1_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
                                for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax0_1_init, ax1_fused_2_init in T.grid(4, 2):
                                        for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(4):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_init)
                                                v0 = T.axis.spatial((seq_len + 3) // 4 * 4, ax0_0 * 4 + ax0_1_init)
                                                v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax1_fused_2_init)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] = T.float16(0.0)
                                    for ax2_fused_0 in T.serial(43, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax0_0_1, ax1 in T.grid(2, 8):
                                            for ax0_1 in T.vectorized(1):
                                                with T.block("dequantize"):
                                                    v0 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax0_0_1 + ax0_1)
                                                    v1 = T.axis.spatial(11008, ax2_fused_0 * 256 + ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax1)
                                                    T.reads(model_layers_0_mlp_down_proj_q_weight3[v0, v1 // 8], model_layers_0_mlp_down_proj_q_scale3[v0, v1 // 32])
                                                    T.writes(dequantize_intermediate_local[v0, v1])
                                                    dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_down_proj_q_weight3[v0, v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_down_proj_q_scale3[v0, v1 // 32]
                                        for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(4, 2, 2):
                                            for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(4):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1)
                                                    v0 = T.axis.spatial((seq_len + 3) // 4 * 4, ax0_0 * 4 + ax0_1)
                                                    v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_1 * 2 + ax1_fused_2)
                                                    vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                    T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1], lv73[0, v0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] + T.if_then_else(v0 < seq_len, lv73[0, v0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], T.float16(0.0)) * dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4]
                            for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax3_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax2 in range(4):
                                            for ax3_fused_2_1 in T.vectorized(2):
                                                with T.block("NT_matmul_rf_init"):
                                                    vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                                    v0 = T.axis.spatial((seq_len + 3) // 4 * 4, ax0_0 * 4 + ax2)
                                                    v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                    T.reads()
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] = T.float16(0.0)
                                                for ax1 in range(4):
                                                    with T.block("NT_matmul_rf_update"):
                                                        vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                        v0 = T.axis.spatial((seq_len + 3) // 4 * 4, ax0_0 * 4 + ax2)
                                                        v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, 0, v0, v1])
                                                        T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                                        NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, 0, v0, v1]
                            for ax2_fused_2, ax1 in T.grid(2, 4):
                                for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                                    for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                        with T.block("NT_matmul"):
                                            vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(32, ax0)
                                            v0 = T.axis.spatial((seq_len + 3) // 4 * 4, ax0_0 * 4 + ax1)
                                            v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax2_fused_0_ax2_fused_1_fused * 2 + ax2_fused_2)
                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                            T.writes(NT_matmul_intermediate_pad_local[0, v0, v1])
                                            with T.init():
                                                NT_matmul_intermediate_pad_local[0, v0, v1] = T.float16(0.0)
                                            NT_matmul_intermediate_pad_local[0, v0, v1] = NT_matmul_intermediate_pad_local[0, v0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1]
                            for ax0 in range(4):
                                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(8, thread="threadIdx.y"):
                                    for ax1_fused_2 in range(2):
                                        with T.block("NT_matmul_intermediate_pad"):
                                            v0 = T.axis.spatial(seq_len, ax0_0 * 4 + ax0)
                                            v1 = T.axis.spatial(2048, ax1_fused_0 * 16 + ax1_fused_0_ax1_fused_1_fused * 2 + ax1_fused_2)
                                            T.where((ax0_0 - (seq_len + 3) // 4 < 0 or ax0_0 == 0) and ax0_0 * 4 + ax0 < seq_len)
                                            T.reads(NT_matmul_intermediate_pad_local[0, v0, v1])
                                            T.writes(NT_matmul_intermediate[0, v0, v1])
                                            NT_matmul_intermediate[0, v0, v1] = NT_matmul_intermediate_pad_local[0, v0, v1]
            else:
                with T.block("root"):
                    T.reads()
                    T.writes()
                    NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (seq_len + 31) // 32 * 32, 2048), "float16", scope="local")
                    lv73_reindex_pad_shared = T.alloc_buffer((1, (seq_len + 31) // 32 * 32, 11008), "float16", scope="shared")
                    dequantize_intermediate_reindex_shared = T.alloc_buffer((1, 2048, 11008), "float16", scope="shared")
                    for ax0_ax2_0_fused in T.thread_binding(64, thread="blockIdx.y"):
                        for ax1_0 in T.thread_binding((seq_len + 31) // 32, thread="blockIdx.x"):
                            for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                                for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                                    for ax2_2 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                            for ax1_3_init, ax2_3_0_init in T.grid(4, 4):
                                                for ax2_3_1_init in T.vectorized(1):
                                                    with T.block("NT_matmul_init"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial((seq_len + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
                                                        v2 = T.axis.spatial(2048, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_0_init + ax2_3_1_init)
                                                        T.reads()
                                                        T.writes(NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                        NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0.0)
                                            for ax3_0 in range(1376):
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(4):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                                with T.block("lv73_reindex_pad_shared"):
                                                                    v0 = T.axis.spatial(1, 0)
                                                                    v1 = T.axis.spatial((seq_len + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                                    v2 = T.axis.spatial(11008, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                                    T.reads(lv73[v0, v1, v2])
                                                                    T.writes(lv73_reindex_pad_shared[v0, v1, v2])
                                                                    lv73_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, lv73[v0, v1, v2], T.float16(0.0))
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(4):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                                with T.block("dequantize_intermediate_reindex_shared"):
                                                                    v0 = T.axis.spatial(1, 0)
                                                                    v1 = T.axis.spatial(2048, ax0_ax2_0_fused * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                                    v2 = T.axis.spatial(11008, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                                    T.reads(model_layers_0_mlp_down_proj_q_weight3[v1, v2 // 8], model_layers_0_mlp_down_proj_q_scale3[v1, v2 // 32])
                                                                    T.writes(dequantize_intermediate_reindex_shared[v0, v1, v2])
                                                                    dequantize_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_down_proj_q_weight3[v1, v2 // 8], T.Cast("uint32", v2 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_layers_0_mlp_down_proj_q_scale3[v1, v2 // 32]
                                                for ax3_1, ax1_3, ax2_3_0 in T.grid(8, 4, 4):
                                                    for ax2_3_1 in T.vectorized(1):
                                                        with T.block("NT_matmul_update"):
                                                            v0 = T.axis.spatial(1, 0)
                                                            v1 = T.axis.spatial((seq_len + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
                                                            v2 = T.axis.spatial(2048, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_0 + ax2_3_1)
                                                            v3 = T.axis.reduce(11008, ax3_0 * 8 + ax3_1)
                                                            T.reads(NT_matmul_intermediate_reindex_pad_local[0, v1, v2], lv73_reindex_pad_shared[0, v1, v3], dequantize_intermediate_reindex_shared[0, v2, v3])
                                                            T.writes(NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                            NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + lv73_reindex_pad_shared[0, v1, v3] * dequantize_intermediate_reindex_shared[0, v2, v3]
                                            for ax0, ax1, ax2_0 in T.grid(1, 4, 4):
                                                for ax2_1_1 in T.vectorized(1):
                                                    with T.block("NT_matmul_intermediate_reindex_pad_local"):
                                                        v0 = T.axis.spatial(1, ax0)
                                                        v1 = T.axis.spatial((seq_len + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                                        v2 = T.axis.spatial(2048, ax0_ax2_0_fused * 32 + ax2_2 * 4 + ax2_0 + ax2_1_1)
                                                        T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < seq_len)
                                                        T.reads(NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                                        T.writes(NT_matmul_intermediate[0, v1, v2])
                                                        NT_matmul_intermediate[0, v1, v2] = NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]

    @T.prim_func
    def fused_dequantize_NT_matmul14(model_embed_tokens_q_weight2: T.Buffer((151936, 256), "uint32"), model_embed_tokens_q_scale2: T.Buffer((151936, 64), "float16"), rms_norm145: T.Buffer((1, 1, 2048), "float16"), NT_matmul_intermediate: T.Buffer((1, 1, 151936), "float32")):
        # Single-row GEMV against a 4-bit group-quantized weight matrix:
        #
        #   NT_matmul_intermediate[0, 0, n] =
        #       sum_{k in [0, 2048)} f32(rms_norm145[0, 0, k]) * f32(dequant(n, k))
        #
        #   dequant(n, k) = (f16((q_weight2[n, k // 8] >> (k % 8 * 4)) & 15) - 7) * q_scale2[n, k // 32]
        #
        # That is, eight 4-bit codes are packed per uint32 and one float16 scale
        # covers 32 consecutive k. Accumulation is done in float32.
        # NOTE(review): the buffer names suggest this is the single-token LM-head
        # projection over the embedding table. That is inferred from names only.
        #
        # Schedule:
        #   * blockIdx.x (37984 blocks) x threadIdx.y (4) -> one output row n per
        #     (block, threadIdx.y); 37984 * 4 == 151936.
        #   * The 2048-long input row is staged cooperatively in shared memory.
        #   * Stage 1: each of the 32 threadIdx.x lanes keeps 4 vectorized partial
        #     sums, i.e. 128 partials per row in NT_matmul_intermediate_rf_local.
        #   * Stage 2: each lane folds its 4 partials into
        #     NT_matmul_intermediate_rf_local_1 (32 partials per row).
        #   * Stage 3: the 32 partials are reduced over the threadIdx.x-bound
        #     reduce axis into NT_matmul_intermediate.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        # 128 per-lane partial sums per output row (32 lanes x 4 vector slots).
        NT_matmul_intermediate_rf_local = T.alloc_buffer((128, 1, 1, 151936), scope="local")
        # 32 per-lane partial sums per output row, after folding the 4 vector slots.
        NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((32, 1, 1, 151936), scope="local")
        # Per-thread copy of the packed uint32 weights it consumes.
        model_embed_tokens_q_weight2_local = T.alloc_buffer((151936, 256), "uint32", scope="local")
        # Shared copy of the 2048-element input row, reused by all rows in the block.
        rms_norm145_shared = T.alloc_buffer((1, 1, 2048), "float16", scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(37984, thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                    # Cooperative copy of the input row into shared memory:
                    # 4 (serial) * 4 (threadIdx.y) * 32 (threadIdx.x) * 4 (vector) = 2048 elements.
                    for ax0, ax1 in T.grid(1, 1):
                        for ax2_0 in T.serial(4, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(4, thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(4):
                                        with T.block("rms_norm145_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(2048, ax2_0 * 512 + ax2_1 * 128 + ax2_2 * 4 + ax2_3)
                                            T.reads(rms_norm145[v0, v1, v2])
                                            T.writes(rms_norm145_shared[v0, v1, v2])
                                            rms_norm145_shared[v0, v1, v2] = rms_norm145[v0, v1, v2]
                    # Stage 1 init: zero this lane's 4 partial sums for row v0.
                    for u_fused_ax0_fused_fused_2_init in range(1):
                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(4):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(128, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
                                v0 = T.axis.spatial(151936, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float32(0.0)
                    # Stage 1 main loop: 8 outer steps of 256 k each (8 * 256 = 2048).
                    for ax1_0_fused_ax1_1_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        # Each lane loads one packed uint32 (8 nibbles) of row v0 per outer step.
                        for ax0_ax1_fused_0 in range(1):
                            for ax0_ax1_fused_1 in T.vectorized(1):
                                with T.block("model_embed_tokens_q_weight2_local"):
                                    v0 = T.axis.spatial(151936, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1)
                                    v1 = T.axis.spatial(256, ax1_0_fused_ax1_1_fused_0 * 32 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
                                    T.reads(model_embed_tokens_q_weight2[v0, v1])
                                    T.writes(model_embed_tokens_q_weight2_local[v0, v1])
                                    model_embed_tokens_q_weight2_local[v0, v1] = model_embed_tokens_q_weight2[v0, v1]
                        # Here k = f0 * 256 + (lane_slot // 4) * 8 + f2 * 4 + lane_slot % 4, and
                        # the weight word index f0 * 32 + lane_slot // 4 + f2 // 2 equals k // 8.
                        # Each term is exactly f32(input[k]) * f32(dequant(v0, k)), with no
                        # extra per-k weighting, matching the batched fused_dequantize_NT_matmul4.
                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 2):
                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(4):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(128, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
                                    v0 = T.axis.spatial(151936, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
                                    T.reads(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], rms_norm145_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4], model_embed_tokens_q_weight2_local[v0, vax1_0_fused_ax1_1_fused_0 * 32 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + vax1_0_fused_ax1_1_fused_2 // 2], model_embed_tokens_q_scale2[v0, (vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) // 32])
                                    T.writes(NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                    NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + T.Cast("float32", rms_norm145_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4]) * T.Cast("float32", (T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight2_local[v0, vax1_0_fused_ax1_1_fused_0 * 32 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + vax1_0_fused_ax1_1_fused_2 // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_embed_tokens_q_scale2[v0, (vax1_0_fused_ax1_1_fused_0 * 256 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) // 32])
            # Stage 2: each lane sums its 4 vector-slot partials (rf_local[lane * 4 + j])
            # into a single per-lane partial in rf_local_1[lane].
            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                    for ax2_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_2_1 in T.vectorized(1):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                v0 = T.axis.spatial(151936, u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                T.reads()
                                T.writes(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float32(0.0)
                            for ax1 in range(4):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(151936, u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                    T.reads(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
                                    T.writes(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                    NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
            # Stage 3: reduce the 32 per-lane partials. The reduce axis is bound to
            # threadIdx.x, so this is a cross-thread reduction writing the final output.
            for ax1_fused_2 in range(1):
                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                    for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(32, ax0)
                            v0 = T.axis.spatial(151936, u_fused_ax0_fused_fused_0 * 4 + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2)
                            T.reads(NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                            T.writes(NT_matmul_intermediate[0, 0, v0])
                            with T.init():
                                NT_matmul_intermediate[0, 0, v0] = T.float32(0.0)
                            NT_matmul_intermediate[0, 0, v0] = NT_matmul_intermediate[0, 0, v0] + NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]

    @T.prim_func
    def fused_dequantize_NT_matmul4(model_embed_tokens_q_weight4: T.Buffer((151936, 256), "uint32"), model_embed_tokens_q_scale4: T.Buffer((151936, 64), "float16"), p_rms_norm291: T.handle, p_output0: T.handle):
        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int32()
        rms_norm291 = T.match_buffer(p_rms_norm291, (batch_size, 1, 2048), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, 1, 151936))
        # with T.block("root"):
        if T.tvm_thread_invariant(batch_size <= 2):
            with T.block("root"):
                T.reads()
                T.writes()
                dequantize_intermediate_local = T.alloc_buffer((151936, 2048), "float16", scope="local")
                NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + 1) // 2 * 2, 1, 151936), scope="local")
                NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((128, (batch_size + 1) // 2 * 2, 1, 151936), scope="local")
                NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((32, (batch_size + 1) // 2 * 2, 1, 151936), scope="local")
                for ax0_0 in T.thread_binding((batch_size + 1) // 2, thread="blockIdx.y"):
                    for ax1_fused_0 in T.thread_binding(18992, thread="blockIdx.x"):
                        for ax1_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                            for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax0_1_init, ax1_fused_2_init in T.grid(2, 2):
                                    for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(4):
                                        with T.block("NT_matmul_rf_init"):
                                            vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_init)
                                            v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax0_1_init)
                                            v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2_init)
                                            T.reads()
                                            T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1])
                                            NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = T.float32(0.0)
                                for ax2_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax0_0_1, ax1 in T.grid(2, 8):
                                        for ax0_1 in T.vectorized(1):
                                            with T.block("dequantize"):
                                                v0 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax0_0_1 + ax0_1)
                                                v1 = T.axis.spatial(2048, ax2_fused_0 * 256 + ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax1)
                                                T.reads(model_embed_tokens_q_weight4[v0, v1 // 8], model_embed_tokens_q_scale4[v0, v1 // 32])
                                                T.writes(dequantize_intermediate_local[v0, v1])
                                                dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight4[v0, v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_embed_tokens_q_scale4[v0, v1 // 32]
                                    for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(2, 2, 2):
                                        for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(4):
                                            with T.block("NT_matmul_rf_update"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1)
                                                v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax0_1)
                                                v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2)
                                                vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1], rms_norm291[v0, 0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] + T.Cast("float32", T.if_then_else(v0 < batch_size, rms_norm291[v0, 0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], T.float16(0.0))) * T.Cast("float32", dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                        for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                            for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax3_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax2 in range(2):
                                        for ax3_fused_2_1 in T.vectorized(2):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                                v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax2)
                                                v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                                NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = T.float32(0.0)
                                            for ax1 in range(4):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                    v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax2)
                                                    v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                    T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1]
                        for ax2_fused_2, ax1 in T.grid(2, 2):
                            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                    with T.block("NT_matmul"):
                                        vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(32, ax0)
                                        v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax1)
                                        v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax2_fused_0_ax2_fused_1_fused * 2 + ax2_fused_2)
                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                        T.writes(NT_matmul_intermediate_pad_local[v0, 0, v1])
                                        with T.init():
                                            NT_matmul_intermediate_pad_local[v0, 0, v1] = T.float32(0.0)
                                        NT_matmul_intermediate_pad_local[v0, 0, v1] = NT_matmul_intermediate_pad_local[v0, 0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1]
                        for ax0 in range(2):
                            for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                for ax1_fused_2 in range(2):
                                    with T.block("NT_matmul_intermediate_pad"):
                                        v0 = T.axis.spatial(batch_size, ax0_0 * 2 + ax0)
                                        v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax1_fused_0_ax1_fused_1_fused * 2 + ax1_fused_2)
                                        T.where((ax0_0 - (batch_size + 1) // 2 < 0 or ax0_0 == 0) and ax0_0 * 2 + ax0 < batch_size)
                                        T.reads(NT_matmul_intermediate_pad_local[v0, 0, v1])
                                        T.writes(NT_matmul_intermediate[v0, 0, v1])
                                        NT_matmul_intermediate[v0, 0, v1] = NT_matmul_intermediate_pad_local[v0, 0, v1]
        else:
            if T.tvm_thread_invariant(batch_size <= 8):
                with T.block("root"):
                    T.reads()
                    T.writes()
                    dequantize_intermediate_local = T.alloc_buffer((151936, 2048), "float16", scope="local")
                    NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + 3) // 4 * 4, 1, 151936), scope="local")
                    NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((128, (batch_size + 3) // 4 * 4, 1, 151936), scope="local")
                    NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((32, (batch_size + 3) // 4 * 4, 1, 151936), scope="local")
                    for ax0_0 in T.thread_binding((batch_size + 3) // 4, thread="blockIdx.y"):
                        for ax1_fused_0 in T.thread_binding(18992, thread="blockIdx.x"):
                            for ax1_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax0_1_init, ax1_fused_2_init in T.grid(4, 2):
                                        for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(4):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_init)
                                                v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax0_1_init)
                                                v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2_init)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = T.float32(0.0)
                                    for ax2_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax0_0_1, ax1 in T.grid(2, 8):
                                            for ax0_1 in T.vectorized(1):
                                                with T.block("dequantize"):
                                                    v0 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax0_0_1 + ax0_1)
                                                    v1 = T.axis.spatial(2048, ax2_fused_0 * 256 + ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax1)
                                                    T.reads(model_embed_tokens_q_weight4[v0, v1 // 8], model_embed_tokens_q_scale4[v0, v1 // 32])
                                                    T.writes(dequantize_intermediate_local[v0, v1])
                                                    dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight4[v0, v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_embed_tokens_q_scale4[v0, v1 // 32]
                                        for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(4, 2, 2):
                                            for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(4):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1)
                                                    v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax0_1)
                                                    v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2)
                                                    vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                    T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1], rms_norm291[v0, 0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1])
                                                    NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] + T.Cast("float32", T.if_then_else(v0 < batch_size, rms_norm291[v0, 0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], T.float16(0.0))) * T.Cast("float32", dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                            for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax3_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax2 in range(4):
                                            for ax3_fused_2_1 in T.vectorized(2):
                                                with T.block("NT_matmul_rf_init"):
                                                    vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                                    v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax2)
                                                    v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                    T.reads()
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = T.float32(0.0)
                                                for ax1 in range(4):
                                                    with T.block("NT_matmul_rf_update"):
                                                        vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                        v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax2)
                                                        v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1])
                                                        T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                                        NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1]
                            for ax2_fused_2, ax1 in T.grid(2, 4):
                                for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                    for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                        with T.block("NT_matmul"):
                                            vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(32, ax0)
                                            v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax1)
                                            v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax2_fused_0_ax2_fused_1_fused * 2 + ax2_fused_2)
                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1])
                                            T.writes(NT_matmul_intermediate_pad_local[v0, 0, v1])
                                            with T.init():
                                                NT_matmul_intermediate_pad_local[v0, 0, v1] = T.float32(0.0)
                                            NT_matmul_intermediate_pad_local[v0, 0, v1] = NT_matmul_intermediate_pad_local[v0, 0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1]
                            for ax0 in range(4):
                                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                    for ax1_fused_2 in range(2):
                                        with T.block("NT_matmul_intermediate_pad"):
                                            v0 = T.axis.spatial(batch_size, ax0_0 * 4 + ax0)
                                            v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax1_fused_0_ax1_fused_1_fused * 2 + ax1_fused_2)
                                            T.where((ax0_0 - (batch_size + 3) // 4 < 0 or ax0_0 == 0) and ax0_0 * 4 + ax0 < batch_size)
                                            T.reads(NT_matmul_intermediate_pad_local[v0, 0, v1])
                                            T.writes(NT_matmul_intermediate[v0, 0, v1])
                                            NT_matmul_intermediate[v0, 0, v1] = NT_matmul_intermediate_pad_local[v0, 0, v1]
            else:
                with T.block("root"):
                    T.reads()
                    T.writes()
                    NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (batch_size + 31) // 32 * 32, 151936), scope="local")
                    rms_norm291_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 31) // 32 * 32, 2048), "float16", scope="shared")
                    dequantize_intermediate_reindex_shared = T.alloc_buffer((1, 151936, 2048), "float16", scope="shared")
                    for ax0_ax2_0_fused in T.thread_binding(4748, thread="blockIdx.y"):
                        for ax1_0 in T.thread_binding((batch_size + 31) // 32, thread="blockIdx.x"):
                            for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                                for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                                    for ax2_2 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                            for ax1_3_init, ax2_3_0_init in T.grid(4, 4):
                                                for ax2_3_1_init in T.vectorized(1):
                                                    with T.block("NT_matmul_init"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
                                                        v2 = T.axis.spatial(151936, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_0_init + ax2_3_1_init)
                                                        T.reads()
                                                        T.writes(NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                        NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float32(0.0)
                                            for ax3_0 in range(256):
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(4):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                                with T.block("rms_norm291_reindex_pad_shared"):
                                                                    v0 = T.axis.spatial(1, 0)
                                                                    v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                                    v2 = T.axis.spatial(2048, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                                    T.reads(rms_norm291[v1, 0, v2])
                                                                    T.writes(rms_norm291_reindex_pad_shared[v0, v1, v2])
                                                                    rms_norm291_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < batch_size, rms_norm291[v1, 0, v2], T.float16(0.0))
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(4):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                                with T.block("dequantize_intermediate_reindex_shared"):
                                                                    v0 = T.axis.spatial(1, 0)
                                                                    v1 = T.axis.spatial(151936, ax0_ax2_0_fused * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                                    v2 = T.axis.spatial(2048, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                                    T.reads(model_embed_tokens_q_weight4[v1, v2 // 8], model_embed_tokens_q_scale4[v1, v2 // 32])
                                                                    T.writes(dequantize_intermediate_reindex_shared[v0, v1, v2])
                                                                    dequantize_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight4[v1, v2 // 8], T.Cast("uint32", v2 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_embed_tokens_q_scale4[v1, v2 // 32]
                                                for ax3_1, ax1_3, ax2_3_0 in T.grid(8, 4, 4):
                                                    for ax2_3_1 in T.vectorized(1):
                                                        with T.block("NT_matmul_update"):
                                                            v0 = T.axis.spatial(1, 0)
                                                            v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
                                                            v2 = T.axis.spatial(151936, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_0 + ax2_3_1)
                                                            v3 = T.axis.reduce(2048, ax3_0 * 8 + ax3_1)
                                                            T.reads(NT_matmul_intermediate_reindex_pad_local[0, v1, v2], rms_norm291_reindex_pad_shared[0, v1, v3], dequantize_intermediate_reindex_shared[0, v2, v3])
                                                            T.writes(NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                            NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + T.Cast("float32", rms_norm291_reindex_pad_shared[0, v1, v3]) * T.Cast("float32", dequantize_intermediate_reindex_shared[0, v2, v3])
                                            for ax0, ax1, ax2_0 in T.grid(1, 4, 4):
                                                for ax2_1_1 in T.vectorized(1):
                                                    with T.block("NT_matmul_intermediate_reindex_pad_local"):
                                                        v0 = T.axis.spatial(1, ax0)
                                                        v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                                        v2 = T.axis.spatial(151936, ax0_ax2_0_fused * 32 + ax2_2 * 4 + ax2_0 + ax2_1_1)
                                                        T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < batch_size)
                                                        T.reads(NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                                        T.writes(NT_matmul_intermediate[v1, 0, v2])
                                                        NT_matmul_intermediate[v1, 0, v2] = NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]

    @T.prim_func
    def fused_dequantize_NT_matmul9(model_embed_tokens_q_weight3: T.Buffer((151936, 256), "uint32"), model_embed_tokens_q_scale3: T.Buffer((151936, 64), "float16"), p_take1: T.handle, p_output0: T.handle):
        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int32()
        take1 = T.match_buffer(p_take1, (1, batch_size, 2048), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (1, batch_size, 151936))
        # with T.block("root"):
        if T.tvm_thread_invariant(batch_size <= 2):
            with T.block("root"):
                T.reads()
                T.writes()
                dequantize_intermediate_local = T.alloc_buffer((151936, 2048), "float16", scope="local")
                NT_matmul_intermediate_pad_local = T.alloc_buffer((1, (batch_size + 1) // 2 * 2, 151936), scope="local")
                NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((128, 1, (batch_size + 1) // 2 * 2, 151936), scope="local")
                NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((32, 1, (batch_size + 1) // 2 * 2, 151936), scope="local")
                for ax0_0 in T.thread_binding((batch_size + 1) // 2, thread="blockIdx.y"):
                    for ax1_fused_0 in T.thread_binding(18992, thread="blockIdx.x"):
                        for ax1_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                            for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax0_1_init, ax1_fused_2_init in T.grid(2, 2):
                                    for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(4):
                                        with T.block("NT_matmul_rf_init"):
                                            vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_init)
                                            v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax0_1_init)
                                            v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2_init)
                                            T.reads()
                                            T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1])
                                            NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] = T.float32(0.0)
                                for ax2_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax0_0_1, ax1 in T.grid(2, 8):
                                        for ax0_1 in T.vectorized(1):
                                            with T.block("dequantize"):
                                                v0 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax0_0_1 + ax0_1)
                                                v1 = T.axis.spatial(2048, ax2_fused_0 * 256 + ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax1)
                                                T.reads(model_embed_tokens_q_weight3[v0, v1 // 8], model_embed_tokens_q_scale3[v0, v1 // 32])
                                                T.writes(dequantize_intermediate_local[v0, v1])
                                                dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight3[v0, v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_embed_tokens_q_scale3[v0, v1 // 32]
                                    for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(2, 2, 2):
                                        for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(4):
                                            with T.block("NT_matmul_rf_update"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1)
                                                v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax0_1)
                                                v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2)
                                                vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1], take1[0, v0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] + T.Cast("float32", T.if_then_else(v0 < batch_size, take1[0, v0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], T.float16(0.0))) * T.Cast("float32", dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                        for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                            for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax3_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                    for ax2 in range(2):
                                        for ax3_fused_2_1 in T.vectorized(2):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                                v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax2)
                                                v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                                NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] = T.float32(0.0)
                                            for ax1 in range(4):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                    v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax2)
                                                    v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                    T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, 0, v0, v1])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, 0, v0, v1]
                        for ax2_fused_2, ax1 in T.grid(2, 2):
                            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                    with T.block("NT_matmul"):
                                        vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(32, ax0)
                                        v0 = T.axis.spatial((batch_size + 1) // 2 * 2, ax0_0 * 2 + ax1)
                                        v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax2_fused_0_ax2_fused_1_fused * 2 + ax2_fused_2)
                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                        T.writes(NT_matmul_intermediate_pad_local[0, v0, v1])
                                        with T.init():
                                            NT_matmul_intermediate_pad_local[0, v0, v1] = T.float32(0.0)
                                        NT_matmul_intermediate_pad_local[0, v0, v1] = NT_matmul_intermediate_pad_local[0, v0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1]
                        for ax0 in range(2):
                            for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                for ax1_fused_2 in range(2):
                                    with T.block("NT_matmul_intermediate_pad"):
                                        v0 = T.axis.spatial(batch_size, ax0_0 * 2 + ax0)
                                        v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax1_fused_0_ax1_fused_1_fused * 2 + ax1_fused_2)
                                        T.where((ax0_0 - (batch_size + 1) // 2 < 0 or ax0_0 == 0) and ax0_0 * 2 + ax0 < batch_size)
                                        T.reads(NT_matmul_intermediate_pad_local[0, v0, v1])
                                        T.writes(NT_matmul_intermediate[0, v0, v1])
                                        NT_matmul_intermediate[0, v0, v1] = NT_matmul_intermediate_pad_local[0, v0, v1]
        else:
            if T.tvm_thread_invariant(batch_size <= 8):
                with T.block("root"):
                    T.reads()
                    T.writes()
                    dequantize_intermediate_local = T.alloc_buffer((151936, 2048), "float16", scope="local")
                    NT_matmul_intermediate_pad_local = T.alloc_buffer((1, (batch_size + 3) // 4 * 4, 151936), scope="local")
                    NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((128, 1, (batch_size + 3) // 4 * 4, 151936), scope="local")
                    NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((32, 1, (batch_size + 3) // 4 * 4, 151936), scope="local")
                    for ax0_0 in T.thread_binding((batch_size + 3) // 4, thread="blockIdx.y"):
                        for ax1_fused_0 in T.thread_binding(18992, thread="blockIdx.x"):
                            for ax1_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax0_1_init, ax1_fused_2_init in T.grid(4, 2):
                                        for ax2_fused_1_ax2_fused_3_fused_1_init in T.vectorized(4):
                                            with T.block("NT_matmul_rf_init"):
                                                vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_init)
                                                v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax0_1_init)
                                                v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2_init)
                                                T.reads()
                                                T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1])
                                                NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] = T.float32(0.0)
                                    for ax2_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax0_0_1, ax1 in T.grid(2, 8):
                                            for ax0_1 in T.vectorized(1):
                                                with T.block("dequantize"):
                                                    v0 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax0_0_1 + ax0_1)
                                                    v1 = T.axis.spatial(2048, ax2_fused_0 * 256 + ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax1)
                                                    T.reads(model_embed_tokens_q_weight3[v0, v1 // 8], model_embed_tokens_q_scale3[v0, v1 // 32])
                                                    T.writes(dequantize_intermediate_local[v0, v1])
                                                    dequantize_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight3[v0, v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_embed_tokens_q_scale3[v0, v1 // 32]
                                        for ax0_1, ax1_fused_2, ax2_fused_2 in T.grid(4, 2, 2):
                                            for ax2_fused_1_ax2_fused_3_fused_1 in T.vectorized(4):
                                                with T.block("NT_matmul_rf_update"):
                                                    vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(128, ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1)
                                                    v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax0_1)
                                                    v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax1_fused_1 * 2 + ax1_fused_2)
                                                    vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2])
                                                    T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1], take1[0, v0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                                                    T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, 0, v0, v1] + T.Cast("float32", T.if_then_else(v0 < batch_size, take1[0, v0, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4], T.float16(0.0))) * T.Cast("float32", dequantize_intermediate_local[v1, vax2_fused_0 * 256 + vax2_fused_1_ax2_fused_3_fused // 4 * 8 + vax2_fused_2 * 4 + vax2_fused_1_ax2_fused_3_fused % 4])
                            for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                    for ax3_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                        for ax2 in range(4):
                                            for ax3_fused_2_1 in T.vectorized(2):
                                                with T.block("NT_matmul_rf_init"):
                                                    vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(32, ax0)
                                                    v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax2)
                                                    v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                    T.reads()
                                                    T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                                    NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] = T.float32(0.0)
                                                for ax1 in range(4):
                                                    with T.block("NT_matmul_rf_update"):
                                                        vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                                        v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax2)
                                                        v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax3_fused_0_ax3_fused_1_fused * 2 + ax3_fused_2_0 * 2 + ax3_fused_2_1)
                                                        T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, 0, v0, v1])
                                                        T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                                        NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 4 + vax2_fused_1_ax2_fused_3_fused_1, 0, v0, v1]
                            for ax2_fused_2, ax1 in T.grid(2, 4):
                                for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                    for ax0 in T.thread_binding(32, thread="threadIdx.x"):
                                        with T.block("NT_matmul"):
                                            vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(32, ax0)
                                            v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax1)
                                            v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax2_fused_0_ax2_fused_1_fused * 2 + ax2_fused_2)
                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1])
                                            T.writes(NT_matmul_intermediate_pad_local[0, v0, v1])
                                            with T.init():
                                                NT_matmul_intermediate_pad_local[0, v0, v1] = T.float32(0.0)
                                            NT_matmul_intermediate_pad_local[0, v0, v1] = NT_matmul_intermediate_pad_local[0, v0, v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, 0, v0, v1]
                            for ax0 in range(4):
                                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(4, thread="threadIdx.y"):
                                    for ax1_fused_2 in range(2):
                                        with T.block("NT_matmul_intermediate_pad"):
                                            v0 = T.axis.spatial(batch_size, ax0_0 * 4 + ax0)
                                            v1 = T.axis.spatial(151936, ax1_fused_0 * 8 + ax1_fused_0_ax1_fused_1_fused * 2 + ax1_fused_2)
                                            T.where((ax0_0 - (batch_size + 3) // 4 < 0 or ax0_0 == 0) and ax0_0 * 4 + ax0 < batch_size)
                                            T.reads(NT_matmul_intermediate_pad_local[0, v0, v1])
                                            T.writes(NT_matmul_intermediate[0, v0, v1])
                                            NT_matmul_intermediate[0, v0, v1] = NT_matmul_intermediate_pad_local[0, v0, v1]
            else:
                with T.block("root"):
                    T.reads()
                    T.writes()
                    NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (batch_size + 31) // 32 * 32, 151936), scope="local")
                    take1_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 31) // 32 * 32, 2048), "float16", scope="shared")
                    dequantize_intermediate_reindex_shared = T.alloc_buffer((1, 151936, 2048), "float16", scope="shared")
                    for ax0_ax2_0_fused in T.thread_binding(4748, thread="blockIdx.y"):
                        for ax1_0 in T.thread_binding((batch_size + 31) // 32, thread="blockIdx.x"):
                            for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                                for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                                    for ax2_2 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                            for ax1_3_init, ax2_3_0_init in T.grid(4, 4):
                                                for ax2_3_1_init in T.vectorized(1):
                                                    with T.block("NT_matmul_init"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
                                                        v2 = T.axis.spatial(151936, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_0_init + ax2_3_1_init)
                                                        T.reads()
                                                        T.writes(NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                        NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float32(0.0)
                                            for ax3_0 in range(256):
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(4):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                                with T.block("take1_reindex_pad_shared"):
                                                                    v0 = T.axis.spatial(1, 0)
                                                                    v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                                    v2 = T.axis.spatial(2048, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                                    T.reads(take1[v0, v1, v2])
                                                                    T.writes(take1_reindex_pad_shared[v0, v1, v2])
                                                                    take1_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < batch_size, take1[v0, v1, v2], T.float16(0.0))
                                                for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                        for ax0_ax1_ax2_fused_2 in range(4):
                                                            for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                                with T.block("dequantize_intermediate_reindex_shared"):
                                                                    v0 = T.axis.spatial(1, 0)
                                                                    v1 = T.axis.spatial(151936, ax0_ax2_0_fused * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                                    v2 = T.axis.spatial(2048, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                                    T.reads(model_embed_tokens_q_weight3[v1, v2 // 8], model_embed_tokens_q_scale3[v1, v2 // 32])
                                                                    T.writes(dequantize_intermediate_reindex_shared[v0, v1, v2])
                                                                    dequantize_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight3[v1, v2 // 8], T.Cast("uint32", v2 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_embed_tokens_q_scale3[v1, v2 // 32]
                                                for ax3_1, ax1_3, ax2_3_0 in T.grid(8, 4, 4):
                                                    for ax2_3_1 in T.vectorized(1):
                                                        with T.block("NT_matmul_update"):
                                                            v0 = T.axis.spatial(1, 0)
                                                            v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
                                                            v2 = T.axis.spatial(151936, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_0 + ax2_3_1)
                                                            v3 = T.axis.reduce(2048, ax3_0 * 8 + ax3_1)
                                                            T.reads(NT_matmul_intermediate_reindex_pad_local[0, v1, v2], take1_reindex_pad_shared[0, v1, v3], dequantize_intermediate_reindex_shared[0, v2, v3])
                                                            T.writes(NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                            NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + T.Cast("float32", take1_reindex_pad_shared[0, v1, v3]) * T.Cast("float32", dequantize_intermediate_reindex_shared[0, v2, v3])
                                            for ax0, ax1, ax2_0 in T.grid(1, 4, 4):
                                                for ax2_1_1 in T.vectorized(1):
                                                    with T.block("NT_matmul_intermediate_reindex_pad_local"):
                                                        v0 = T.axis.spatial(1, ax0)
                                                        v1 = T.axis.spatial((batch_size + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                                        v2 = T.axis.spatial(151936, ax0_ax2_0_fused * 32 + ax2_2 * 4 + ax2_0 + ax2_1_1)
                                                        T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < batch_size)
                                                        T.reads(NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                                        T.writes(NT_matmul_intermediate[0, v1, v2])
                                                        NT_matmul_intermediate[0, v1, v2] = NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]

    @T.prim_func
    def fused_dequantize_take1(model_embed_tokens_q_weight: T.Buffer((151936, 256), "uint32"), model_embed_tokens_q_scale: T.Buffer((151936, 64), "float16"), p_input_ids: T.handle, p_output0: T.handle):
        # Gather the weight-table rows selected by input_ids and dequantize them
        # from packed 4-bit values to float16:
        #   out[i, j] = (nibble(weight[input_ids[i], j // 8], j % 8) - 7.0)
        #               * scale[input_ids[i], j // 32]
        # Each uint32 word packs eight 4-bit values (extracted by shifting right
        # (j % 8) * 4 bits and masking with 15), so 256 words per row expand to
        # 2048 output columns. One float16 scale covers 32 consecutive columns
        # (64 scales per row).
        # NOTE(review): input_ids entries are used as row indices without a
        # bounds check; presumably the caller guarantees 0 <= id < 151936.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        input_ids = T.match_buffer(p_input_ids, (seq_len,), "int32")
        T_take_intermediate = T.match_buffer(p_output0, (seq_len, 2048), "float16")
        # with T.block("root"):
        # seq_len * 2 blocks x 1024 threads = seq_len * 2048 threads: exactly one
        # thread per output element, so no T.where guard is needed.
        for ax0_ax1_fused_0 in T.thread_binding(seq_len * 2, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_take"):
                    # Split the flat thread index into (token row v0, column v1).
                    v0 = T.axis.spatial(seq_len, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) // 2048)
                    v1 = T.axis.spatial(2048, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % 2048)
                    T.reads(input_ids[v0], model_embed_tokens_q_weight[input_ids[v0], v1 // 8], model_embed_tokens_q_scale[input_ids[v0], v1 // 32])
                    T.writes(T_take_intermediate[v0, v1])
                    T_take_intermediate[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight[input_ids[v0], v1 // 8], T.Cast("uint32", v1 % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * model_embed_tokens_q_scale[input_ids[v0], v1 // 32]

    @T.prim_func
    def fused_reshape10_reshape11(lv184: T.Buffer((1, 16, 128), "float16"), T_reshape_intermediate_1: T.Buffer((1, 1, 2048), "float16")):
        # Flatten a (1, 16, 128) float16 tensor into (1, 1, 2048) in row-major
        # order: out[0, 0, i] = lv184[0, i // 128, i % 128].
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        # 2 blocks x 1024 threads = 2048 threads, one per output element, so no
        # T.where guard is needed.
        for ax0_fused_0 in T.thread_binding(2, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_reshape_1"):
                    v0 = T.axis.spatial(2048, ax0_fused_0 * 1024 + ax0_fused_1)
                    T.reads(lv184[0, v0 // 128, v0 % 128])
                    T.writes(T_reshape_intermediate_1[0, 0, v0])
                    T_reshape_intermediate_1[0, 0, v0] = lv184[0, v0 // 128, v0 % 128]

    @T.prim_func
    def fused_reshape8_reshape9(add108: T.Buffer((1, 1, 2560), "float16"), T_reshape_intermediate_1: T.Buffer((1, 20, 128), "float16")):
        # Reshape a flat (1, 1, 2560) float16 vector into (1, 20, 128):
        #   out[0, h, d] = add108[0, 0, h * 128 + d].
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        # 3 blocks x 1024 threads = 3072 threads for 2560 elements; the T.where
        # below masks off the 512 surplus threads.
        for ax0_ax1_fused_0 in T.thread_binding(3, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_reshape_1"):
                    v0 = T.axis.spatial(20, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) // 128)
                    v1 = T.axis.spatial(128, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % 128)
                    T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < 2560)
                    T.reads(add108[0, 0, v0 * 128 + v1])
                    T.writes(T_reshape_intermediate_1[0, v0, v1])
                    T_reshape_intermediate_1[0, v0, v1] = add108[0, 0, v0 * 128 + v1]

    @T.prim_func
    def fused_rope(var_qkv: T.handle, var_position_map: T.handle, var_q: T.handle, var_k: T.handle, var_v: T.handle, apply_rope: T.int32):
        # Split a packed per-token qkv tensor (seq_len, 20, 128) into
        #   q: heads 0..15  -> (seq_len, 16, 128)
        #   k: heads 16..17 -> (seq_len, 2, 128)
        #   v: heads 18..19 -> (seq_len, 2, 128)
        # Rotary position embedding is applied to q and k when apply_rope > 0.
        # Otherwise they are plain copies. v is always copied unchanged.
        #
        # For element d of a q/k head at token t (computed in float32, result
        # cast back to float16):
        #   freq    = position_map[t] / 1000000.0 ** ((d * 2 % 128) / 128)
        #   partner = -x[d + 64] if d < 64 else x[d - 64]
        #   out[d]  = cos(freq) * x[d] + sin(freq) * partner
        # That is, element d is paired with d +/- 64, across the two halves of
        # the 128-wide head.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        qkv = T.match_buffer(var_qkv, (seq_len, 20, 128), "float16")
        # One int32 position per token, cast to float32 when computing freq.
        position_map = T.match_buffer(var_position_map, (seq_len,), "int32", offset_factor=1)
        q = T.match_buffer(var_q, (seq_len, 16, 128), "float16")
        k = T.match_buffer(var_k, (seq_len, 2, 128), "float16")
        v = T.match_buffer(var_v, (seq_len, 2, 128), "float16")
        # with T.block("root"):
        # ceil(seq_len * 2560 / 1024) blocks of 1024 threads. The T.where masks
        # threads past the seq_len * 20 * 128 elements.
        for ax0_ax1_ax2_fused_0 in T.thread_binding((seq_len * 2560 + 1023) // 1024, thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("llama_fused_rope"):
                    # (token v0, packed head v1 in [0, 20), head element v2).
                    v0 = T.axis.spatial(seq_len, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2560)
                    v1 = T.axis.spatial(20, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2560 // 128)
                    v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 128)
                    T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < seq_len * 2560)
                    # The read region spans v2 - 64 .. v2 + 64 to cover the rotary partner.
                    T.reads(position_map[v0], qkv[v0, v1, v2 + -64:v2 + -64 + 129])
                    # Over-approximated write set: each thread writes only one of q/k/v.
                    T.writes(q[v0, v1, v2], k[v0, v1 + -16, v2], v[v0, v1 + -18, v2])
                    if v1 < 16:
                        # q heads. `freq` is the variable bound by T.Let's where= clause.
                        # `v2 < 128` is always true here, since v2's extent is 128.
                        freq = T.float32()
                        q[v0, v1, v2] = T.if_then_else(apply_rope > 0 and v2 < 128, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", qkv[v0, v1, v2]) + T.sin(freq) * T.Cast("float32", T.if_then_else(v2 < 64, qkv[v0, v1, v2 + 64] * T.float16(-1.0), qkv[v0, v1, v2 + -64]))), where={freq: T.Cast("float32", position_map[v0]) / T.pow(T.float32(1000000.0), T.Cast("float32", v2 * 2 % 128) / T.float32(128.0))}), qkv[v0, v1, v2])
                    else:
                        if v1 < 18:
                            # k heads (packed heads 16..17 -> k heads 0..1). Same rotation as q.
                            freq = T.float32()
                            k[v0, v1 + -16, v2] = T.if_then_else(apply_rope > 0 and v2 < 128, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", qkv[v0, v1, v2]) + T.sin(freq) * T.Cast("float32", T.if_then_else(v2 < 64, qkv[v0, v1, v2 + 64] * T.float16(-1.0), qkv[v0, v1, v2 + -64]))), where={freq: T.Cast("float32", position_map[v0]) / T.pow(T.float32(1000000.0), T.Cast("float32", v2 * 2 % 128) / T.float32(128.0))}), qkv[v0, v1, v2])
                        else:
                            # v heads (packed heads 18..19): copied without rotation.
                            v[v0, v1 + -18, v2] = qkv[v0, v1, v2]

    @T.prim_func
    def fused_split1_silu1_multiply1(p_lv147: T.handle, p_output0: T.handle):
        # SiLU-gated product over a packed (1, seq_len, 22016) float16 input:
        #   out[0, t, j] = silu(x[0, t, j]) * x[0, t, j + 11008],  0 <= j < 11008
        # where silu(a) = a * sigmoid(a). The first half of the last axis is the
        # SiLU-activated operand and the second half is the multiplier.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        lv147 = T.match_buffer(p_lv147, (1, seq_len, 22016), "float16")
        T_multiply_intermediate_1 = T.match_buffer(p_output0, (1, seq_len, 11008), "float16")
        # with T.block("root"):
        # ceil(seq_len * 11008 / 1024) blocks; the T.where masks surplus threads.
        for ax0_ax1_fused_0 in T.thread_binding((seq_len * 11008 + 1023) // 1024, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_multiply_1"):
                    v0 = T.axis.spatial(seq_len, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) // 11008)
                    v1 = T.axis.spatial(11008, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % 11008)
                    T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < seq_len * 11008)
                    # Read region v1 .. v1 + 11008 covers both operands.
                    T.reads(lv147[0, v0, v1:v1 + 11009])
                    T.writes(T_multiply_intermediate_1[0, v0, v1])
                    T_multiply_intermediate_1[0, v0, v1] = lv147[0, v0, v1] * T.sigmoid(lv147[0, v0, v1]) * lv147[0, v0, v1 + 11008]

    @T.prim_func
    def fused_split2_silu2_multiply2(lv437: T.Buffer((1, 1, 22016), "float16"), T_multiply_intermediate_1: T.Buffer((1, 1, 11008), "float16")):
        # Single-token, static-shape SiLU-gated product:
        #   out[0, 0, j] = silu(x[0, 0, j]) * x[0, 0, j + 11008],  0 <= j < 11008
        # where silu(a) = a * sigmoid(a).
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        # 11 blocks x 1024 threads = 11264 threads for 11008 elements; the
        # T.where masks the 256 surplus threads.
        for ax0_fused_0 in T.thread_binding(11, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_multiply_1"):
                    v0 = T.axis.spatial(11008, ax0_fused_0 * 1024 + ax0_fused_1)
                    T.where(ax0_fused_0 * 1024 + ax0_fused_1 < 11008)
                    # Read region v0 .. v0 + 11008 covers both operands.
                    T.reads(lv437[0, 0, v0:v0 + 11009])
                    T.writes(T_multiply_intermediate_1[0, 0, v0])
                    T_multiply_intermediate_1[0, 0, v0] = lv437[0, 0, v0] * T.sigmoid(lv437[0, 0, v0]) * lv437[0, 0, v0 + 11008]

    @T.prim_func
    def fused_split_silu_multiply(p_lv2: T.handle, p_output0: T.handle):
        # Batched SiLU-gated product over a packed (batch_size, 1, 22016) input:
        #   out[b, 0, j] = silu(x[b, 0, j]) * x[b, 0, j + 11008],  0 <= j < 11008
        # where silu(a) = a * sigmoid(a).
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int32()
        lv2 = T.match_buffer(p_lv2, (batch_size, 1, 22016), "float16")
        T_multiply_intermediate_1 = T.match_buffer(p_output0, (batch_size, 1, 11008), "float16")
        # with T.block("root"):
        # ceil(batch_size * 11008 / 1024) blocks; the T.where masks surplus threads.
        for ax0_ax1_fused_0 in T.thread_binding((batch_size * 11008 + 1023) // 1024, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_multiply_1"):
                    v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) // 11008)
                    v1 = T.axis.spatial(11008, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % 11008)
                    T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < batch_size * 11008)
                    # Read region v1 .. v1 + 11008 covers both operands.
                    T.reads(lv2[v0, 0, v1:v1 + 11009])
                    T.writes(T_multiply_intermediate_1[v0, 0, v1])
                    T_multiply_intermediate_1[v0, 0, v1] = lv2[v0, 0, v1] * T.sigmoid(lv2[v0, 0, v1]) * lv2[v0, 0, v1 + 11008]

    @T.prim_func
    def gather_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle):
        # Row gather: dst[i, :] = src[indices[i], :] for 0 <= i < batch_size.
        # NOTE(review): indices entries are not bounds-checked against m;
        # presumably the caller guarantees 0 <= indices[i] < m.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        m, n = T.int32(is_size_var=True), T.int32(is_size_var=True)
        src = T.match_buffer(var_src, (m, n))
        batch_size = T.int32(is_size_var=True)
        indices = T.match_buffer(var_indices, (batch_size,), "int32")
        dst = T.match_buffer(var_dst, (batch_size, n))
        # with T.block("root"):
        # One thread per dst element; the T.where masks surplus threads.
        for ax0_ax1_fused_0 in T.thread_binding((batch_size * n + 1023) // 1024, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("gather_2d"):
                    # The `% (n * batch_size)` is a no-op for threads that pass
                    # the T.where guard (their flat index is already < batch_size * n).
                    v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % (n * batch_size) // n)
                    v1 = T.axis.spatial(n, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % n)
                    T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < batch_size * n)
                    T.reads(src[indices[v0], v1], indices[v0])
                    T.writes(dst[v0, v1])
                    dst[v0, v1] = src[indices[v0], v1]

    @T.prim_func
    def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle):
        # Map a uniform sample to an index by comparing it against the
        # cumulative sums in cumsum_sorted. There is one thread per (output row
        # v0, column v1).
        # Let row = sample_indices[v0, 0] and u = usample[v0, 0], and write
        #   c[j] = cumsum_sorted[row, j] / renorm_prob[row, 0].
        # Column v1 then writes output_index[v0, 0] = indices[row, v1] when:
        #   - v1 == 0:              u < c[0]  (or unconditionally if vocab_size == 1)
        #   - 0 < v1 < last:        c[v1 - 1] <= u < c[v1]
        #   - v1 == vocab_size - 1: c[v1 - 1] <= u  (fallback, even if u >= c[v1])
        # NOTE(review): this assumes each cumsum_sorted row is non-decreasing, so
        # at most one column per v0 writes. If no condition holds (e.g. u is NaN
        # and vocab_size > 1), output_index[v0, 0] is not written by this kernel.
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        batch, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
        indices = T.match_buffer(B, (batch, vocab_size), "int32")
        renorm_prob = T.match_buffer(C, (batch, 1))
        out_batch = T.int32(is_size_var=True)
        usample = T.match_buffer(D, (out_batch, 1))
        # Row of cumsum_sorted / indices / renorm_prob used by each output row.
        sample_indices = T.match_buffer(E, (out_batch, 1), "int32")
        output_index = T.match_buffer(F, (out_batch, 1), "int32")
        # with T.block("root"):
        # One thread per (v0, v1) pair; the T.where masks surplus threads.
        for ax0_ax1_fused_0 in T.thread_binding((out_batch * vocab_size + 1023) // 1024, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_get_index_from_sorted"):
                    v0 = T.axis.spatial(out_batch, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % (vocab_size * out_batch) // vocab_size)
                    v1 = T.axis.spatial(vocab_size, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % vocab_size)
                    T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < out_batch * vocab_size)
                    # Reads cumsum columns v1 - 1 .. v1 and indices columns 0 .. v1.
                    T.reads(usample[v0, 0], cumsum_sorted[sample_indices[v0, 0], v1 - 1:v1 - 1 + 2], sample_indices[v0, 0], renorm_prob[sample_indices[v0, 0], 0], indices[sample_indices[v0, 0], T.min(0, v1):T.min(0, v1) + (v1 + 1)])
                    T.writes(output_index[v0, 0])
                    # Upper bound: u below c[v1], or this is the last column (fallback).
                    if usample[v0, 0] < cumsum_sorted[sample_indices[v0, 0], v1] / renorm_prob[sample_indices[v0, 0], 0] or v1 + 1 == vocab_size:
                        if v1 == 0:
                            output_index[v0, 0] = indices[sample_indices[v0, 0], 0]
                        else:
                            # Lower bound: u at or above the previous cumulative sum.
                            if usample[v0, 0] >= cumsum_sorted[sample_indices[v0, 0], v1 - 1] / renorm_prob[sample_indices[v0, 0], 0]:
                                output_index[v0, 0] = indices[sample_indices[v0, 0], v1]

    @T.prim_func
    def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
        # Compute, per row, the cumulative value at the top-p / top-k cutoff and
        # store it in renorm_prob[row, 0].
        #   A -> cumsum_sorted (batch, vocab_size) float32
        #   B -> top_p         (batch, 1) float32
        #   C -> top_k         (batch, 1) int32
        #   D -> renorm_prob   (batch, 1) float32 (output)
        # A position v1 is "kept" while cumsum_sorted[v1] < top_p and v1 + 1 < top_k.
        # The result is cumsum_sorted at the first non-kept position, or at the last
        # position if every position is kept.
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        batch, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
        top_p = T.match_buffer(B, (batch, 1))
        top_k = T.match_buffer(C, (batch, 1), "int32")
        renorm_prob = T.match_buffer(D, (batch, 1))
        # with T.block("root"):
        # One thread per (row v0, position v1); T.where masks the tail block.
        for ax0_ax1_fused_0 in T.thread_binding((batch * vocab_size + 1023) // 1024, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_get_renorm_prob"):
                    v0 = T.axis.spatial(batch, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % (vocab_size * batch) // vocab_size)
                    v1 = T.axis.spatial(vocab_size, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % vocab_size)
                    T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < batch * vocab_size)
                    # Conservative read region covering indices 0, v1 and v1 + 1.
                    T.reads(cumsum_sorted[v0, T.min(T.min(0, v1), v1 + 1):T.min(T.min(0, v1), v1 + 1) + (v1 + 2)], top_p[v0, 0], top_k[v0, 0])
                    T.writes(renorm_prob[v0, 0])
                    if not (cumsum_sorted[v0, 0] < top_p[v0, 0] and top_k[v0, 0] > 1):
                        # Position 0 is already not kept: only the first entry survives.
                        # Every thread of the row writes this same value (benign
                        # redundant stores).
                        renorm_prob[v0, 0] = cumsum_sorted[v0, 0]
                    else:
                        if cumsum_sorted[v0, v1] < top_p[v0, 0] and v1 + 1 < top_k[v0, 0]:
                            # v1 is kept; the thread at the kept/not-kept boundary writes.
                            if v1 + 1 == vocab_size:
                                # Every position was kept.
                                renorm_prob[v0, 0] = cumsum_sorted[v0, v1]
                            else:
                                if not (cumsum_sorted[v0, v1 + 1] < top_p[v0, 0] and v1 + 1 + 1 < top_k[v0, 0]):
                                    # v1 + 1 is the first non-kept position; include it.
                                    # NOTE(review): a single writer per row assumes the
                                    # kept predicate holds for a prefix only (true if
                                    # cumsum_sorted is non-decreasing) - not checked here.
                                    renorm_prob[v0, 0] = cumsum_sorted[v0, v1 + 1]

    @T.prim_func
    def gpu_2d_continuous_cumsum1(var_a: T.handle, var_out: T.handle, var_Tmp: T.handle):
        # Inclusive prefix sum of each row of A (m, n) into Out (m, n), using Tmp
        # (m, n) as scratch for per-tile totals. Works on 512-element tiles
        # (2^9, hence the "// 9" and "9 * (...)" below) in four phases:
        #   1. scan every 512-element tile of A into Out; store each tile total in
        #      Tmp[by, bx] (level 0, ceil(n / 512) entries);
        #   2. for i in [0, total_rounds): scan level i of Tmp in place (stored at
        #      offset i * ceil(n / 512)) and write its per-tile totals to level i + 1;
        #   3. walk the levels back down, adding the preceding tile's scanned total
        #      from level real_idx + 1 into level real_idx;
        #   4. add Tmp[by, bx - 1] (scanned level 0) to every element of tile bx of Out.
        T.func_attr({"tir.is_scheduled": 1})
        m, n = T.int32(), T.int32()
        A = T.match_buffer(var_a, (m, n))
        Out = T.match_buffer(var_out, (m, n))
        Tmp = T.match_buffer(var_Tmp, (m, n))
        # with T.block("root"):
        ceil_log2: T.int32 = T.Cast("int32", T.ceil(T.log2(T.Cast("float32", n))))
        total_rounds: T.int32 = ceil_log2 // 9
        # Phase 1: per-tile scan. One block of 4 x 32 threads per (row, tile);
        # each thread owns 4 consecutive elements.
        for by in T.thread_binding(m, thread="blockIdx.y"):
            for bx in T.thread_binding((n + 511) // 512, thread="blockIdx.x"):
                with T.block(""):
                    T.reads(A[by, bx * 512:bx * 512 + 512])
                    T.writes(Out[by, bx * 512:bx * 512 + 512], Tmp[by, bx])
                    local_buf = T.alloc_buffer((4,), scope="local")
                    shared_buf = T.alloc_buffer((512,), scope="shared")
                    for ty in T.thread_binding(4, thread="threadIdx.y"):
                        for tx in T.thread_binding(32, thread="threadIdx.x"):
                            tx_idx: T.int32 = bx * 512 + ty * 128 + tx * 4
                            # Load 4 elements, zero-padding past the end of the row.
                            for i in T.vectorized(4):
                                local_buf[i] = T.if_then_else(tx_idx + i < n, T.Cast("float32", A[by, tx_idx + i]), T.Cast("float32", 0))
                            # Serial inclusive scan of the thread's 4 elements.
                            for i in T.unroll(1, 4):
                                local_buf[i] = local_buf[i] + local_buf[i - 1]
                            for i in T.vectorized(4):
                                shared_buf[ty * 128 + tx * 4 + i] = local_buf[i]
                            # Log-step scan across the 32 threads of each 128-element
                            # row: at step i, thread tx adds the running total (last
                            # element) of thread tx - 2^i.
                            # NOTE(review): this updates shared_buf in place while reading
                            # an element another thread writes in the same step;
                            # correctness depends on lockstep execution or on barriers
                            # inserted by later lowering passes - confirm for this target.
                            for i in T.unroll(5):
                                for j in T.vectorized(4):
                                    idx: T.int32 = ty * 128 + tx * 4
                                    if tx >= T.shift_left(1, i):
                                        shared_buf[idx + j] = shared_buf[idx + j] + shared_buf[idx - T.shift_left(1, i) * 4 + 4 - 1]
                            # Propagate across the four 128-element rows: the ty == 0
                            # threads add the last element of row i - 1 into row i,
                            # sequentially for i = 1..3.
                            for i in T.unroll(1, 4):
                                for j in T.vectorized(4):
                                    if ty == 0:
                                        idx: T.int32 = i * 128 + tx * 4
                                        shared_buf[idx + j] = shared_buf[idx + j] + shared_buf[i * 128 - 1]
                            for i in T.vectorized(4):
                                idx: T.int32 = ty * 128 + tx * 4 + i
                                if bx * 512 + idx < n:
                                    Out[by, bx * 512 + idx] = shared_buf[idx]
                            # Tile total -> level 0 of Tmp (the vectorized loop stores the
                            # same scalar 4 times).
                            if tx == 0 and ty == 0:
                                for i in T.vectorized(4):
                                    Tmp[by, bx] = shared_buf[511]
        # Phase 2: scan level i of Tmp (cur_len entries) in place and emit the next
        # level's tile totals at offset (i + 1) * ceil(n / 512).
        for i in range(total_rounds):
            cur_len: T.int32 = (n + T.shift_left(1, 9 * (i + 1)) - 1) // T.shift_left(1, 9 * (i + 1))
            for by in T.thread_binding(m, thread="blockIdx.y"):
                for bx in T.thread_binding((cur_len + 511) // 512, thread="blockIdx.x"):
                    with T.block(""):
                        T.reads(Tmp[by, bx * 512 + i * ((n + 511) // 512):bx * 512 + i * ((n + 511) // 512) + 512])
                        T.writes(Tmp[by, T.min(bx * 512 + i * ((n + 511) // 512), (i + 1) * ((n + 511) // 512) + bx):T.min(bx * 512 + i * ((n + 511) // 512), (i + 1) * ((n + 511) // 512) + bx) + (T.max(bx * 512 + i * ((n + 511) // 512) + 511, (i + 1) * ((n + 511) // 512) + bx) + 1 - T.min(bx * 512 + i * ((n + 511) // 512), (i + 1) * ((n + 511) // 512) + bx))])
                        local_buf = T.alloc_buffer((4,), scope="local")
                        shared_buf = T.alloc_buffer((512,), scope="shared")
                        # Same tile-scan structure as phase 1, applied to Tmp level i.
                        for ty in T.thread_binding(4, thread="threadIdx.y"):
                            for tx in T.thread_binding(32, thread="threadIdx.x"):
                                tx_idx: T.int32 = bx * 512 + ty * 128 + tx * 4
                                for i_1 in T.vectorized(4):
                                    local_buf[i_1] = T.if_then_else(tx_idx + i_1 < cur_len, T.Cast("float32", Tmp[by, i * ((n + 512 - 1) // 512) + tx_idx + i_1]), T.Cast("float32", 0))
                                for i_1 in T.unroll(1, 4):
                                    local_buf[i_1] = local_buf[i_1] + local_buf[i_1 - 1]
                                for i_1 in T.vectorized(4):
                                    shared_buf[ty * 128 + tx * 4 + i_1] = local_buf[i_1]
                                for i_1 in T.unroll(5):
                                    for j in T.vectorized(4):
                                        idx: T.int32 = ty * 128 + tx * 4
                                        if tx >= T.shift_left(1, i_1):
                                            shared_buf[idx + j] = shared_buf[idx + j] + shared_buf[idx - T.shift_left(1, i_1) * 4 + 4 - 1]
                                for i_1 in T.unroll(1, 4):
                                    for j in T.vectorized(4):
                                        if ty == 0:
                                            idx: T.int32 = i_1 * 128 + tx * 4
                                            shared_buf[idx + j] = shared_buf[idx + j] + shared_buf[i_1 * 128 - 1]
                                for i_1 in T.vectorized(4):
                                    idx: T.int32 = ty * 128 + tx * 4 + i_1
                                    if bx * 512 + idx < cur_len:
                                        Tmp[by, i * ((n + 512 - 1) // 512) + bx * 512 + idx] = shared_buf[idx]
                                if tx == 0 and ty == 0:
                                    for i_1 in T.vectorized(4):
                                        Tmp[by, (i + 1) * ((n + 512 - 1) // 512) + bx] = shared_buf[511]
        # Phase 3: down-sweep from level total_rounds - 2 to level 0, adding the
        # scanned total of the preceding tile (from the level above) to each entry.
        for i in range(total_rounds - 1):
            real_idx: T.int32 = total_rounds - 1 - i - 1
            cur_len: T.int32 = (n + T.shift_left(1, 9 * (real_idx + 1)) - 1) // T.shift_left(1, 9 * (real_idx + 1))
            for by in T.thread_binding(m, thread="blockIdx.y"):
                for bx in T.thread_binding((cur_len + 511) // 512, thread="blockIdx.x"):
                    for ty in T.thread_binding(4, thread="threadIdx.y"):
                        for tx in T.thread_binding(32, thread="threadIdx.x"):
                            for i_1 in range(4):
                                idx: T.int32 = bx * 512 + ty * 128 + i_1 * 32 + tx
                                if idx < cur_len:
                                    Tmp[by, real_idx * ((n + 512 - 1) // 512) + idx] = Tmp[by, real_idx * ((n + 512 - 1) // 512) + idx] + T.if_then_else(bx > 0, Tmp[by, (real_idx + 1) * ((n + 512 - 1) // 512) + bx - 1], T.float32(0.0))
        # Phase 4: add the scanned total of all preceding tiles (Tmp[by, bx - 1])
        # to every element of tile bx of Out; tile 0 is already final.
        for by in T.thread_binding(m, thread="blockIdx.y"):
            for bx in T.thread_binding((n + 511) // 512, thread="blockIdx.x"):
                for ty in T.thread_binding(4, thread="threadIdx.y"):
                    for tx in T.thread_binding(32, thread="threadIdx.x"):
                        for i in range(4):
                            idx: T.int32 = bx * 512 + ty * 128 + i * 32 + tx
                            if idx < n:
                                Out[by, idx] = Out[by, idx] + T.if_then_else(bx > 0, Tmp[by, bx - 1], T.float32(0.0))

    @T.prim_func
    def index(var_rms_norm72: T.handle, index: T.Buffer((1, 1, 2048), "float16")):
        # Copy the last position along the sequence axis of rms_norm72
        # (1, seq_len, 2048) float16 into index (1, 1, 2048).
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        rms_norm72 = T.match_buffer(var_rms_norm72, (1, seq_len, 2048), "float16")
        # with T.block("root"):
        # 2 blocks x 1024 threads cover the 2048 features exactly, so no T.where.
        for ax0_fused_0 in T.thread_binding(2, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("index"):
                    v0 = T.axis.spatial(2048, ax0_fused_0 * 1024 + ax0_fused_1)
                    T.reads(rms_norm72[0, seq_len - 1, v0])
                    T.writes(index[0, 0, v0])
                    index[0, 0, v0] = rms_norm72[0, seq_len - 1, v0]

    @T.prim_func
    def merge_state_inplace(v: T.handle, s: T.handle, v_other: T.handle, s_other: T.handle):
        # Merge (V_other, S_other) into (V, S) in place, where S holds base-2
        # log-sum-exp style scores (the body uses exp2 / log2):
        #   w       = 2^(S - max), w_other = 2^(S_other - max)
        #   V      <- V * w / (w + w_other) + V_other * w_other / (w + w_other)
        #   S      <- log2(w + w_other) + max
        # Presumably combines partial attention outputs - confirm with callers.
        #   V, V_other (N, H, D) float16;  S, S_other (N, H) float32.
        # NOTE(review): threads cover heads ty + by * 16 with blockIdx.y extent 1
        # (16 heads) and features tx * 4 .. tx * 4 + 3 for 32 threads (128 features),
        # with no bounds checks against H or D; this assumes H == 16 and D == 128.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        N, H, D = T.int32(is_size_var=True), T.int32(is_size_var=True), T.int32(is_size_var=True)
        V = T.match_buffer(v, (N, H, D), "float16")
        S = T.match_buffer(s, (N, H))
        V_other = T.match_buffer(v_other, (N, H, D), "float16")
        S_other = T.match_buffer(s_other, (N, H))
        # with T.block("root"):
        for bx in T.thread_binding(N, thread="blockIdx.x"):
            for by in T.thread_binding(1, thread="blockIdx.y"):
                for ty in T.thread_binding(16, thread="threadIdx.y"):
                    for tx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("merge"):
                            T.reads(S[bx, ty + by * 16], S_other[bx, ty + by * 16], V[bx, ty + by * 16, tx * 4:tx * 4 + 4], V_other[bx, ty + by * 16, tx * 4:tx * 4 + 4])
                            T.writes(V[bx, ty + by * 16, tx * 4:tx * 4 + 4], S[bx, ty + by * 16])
                            s_val = T.alloc_buffer((1,), scope="local")
                            s_other_val = T.alloc_buffer((1,), scope="local")
                            s_max = T.alloc_buffer((1,), scope="local")
                            scale = T.alloc_buffer((1,), scope="local")
                            other_scale = T.alloc_buffer((1,), scope="local")
                            v_vec = T.alloc_buffer((4,), "float16", scope="local")
                            v_other_vec = T.alloc_buffer((4,), "float16", scope="local")
                            s_val[0] = S[bx, ty + by * 16]
                            s_other_val[0] = S_other[bx, ty + by * 16]
                            # Subtract the max before exp2 to keep the exponentials bounded.
                            s_max[0] = T.max(s_val[0], s_other_val[0])
                            s_val[0] = T.exp2(s_val[0] - s_max[0])
                            s_other_val[0] = T.exp2(s_other_val[0] - s_max[0])
                            scale[0] = s_val[0] / (s_val[0] + s_other_val[0])
                            other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0])
                            for vec in T.vectorized(4):
                                v_vec[vec] = V[bx, ty + by * 16, tx * 4 + vec]
                            for vec in T.vectorized(4):
                                v_other_vec[vec] = V_other[bx, ty + by * 16, tx * 4 + vec]
                            # Blend in float32, then store back as float16.
                            for vec in range(4):
                                v_vec[vec] = T.Cast("float16", T.Cast("float32", v_vec[vec]) * scale[0] + T.Cast("float32", v_other_vec[vec]) * other_scale[0])
                            for vec in T.vectorized(4):
                                V[bx, ty + by * 16, tx * 4 + vec] = v_vec[vec]
                            S[bx, ty + by * 16] = T.log2(s_val[0] + s_other_val[0]) + s_max[0]

    @T.prim_func
    def parallel_sampling_from_prob(var_prob: T.handle, var_uniform_samples: T.handle, var_row_indices: T.handle, var_sampled_token_ids: T.handle):
        # Inverse-CDF sampling. For each output row bx, take row
        # prob[row_indices[bx, 0], :] and return the first token at which the running
        # sum of its positive probabilities reaches u - 1e-6, where
        # u = uniform_samples[bx, 0].
        #   prob            (n, vocab_size) float32
        #   uniform_samples (batch_size, 1) float32
        #   row_indices     (batch_size, 1) int32
        #   token_ids       (batch_size, 1) int32 (output)
        # Layout:
        #   - one 4 x 32 thread block per output row;
        #   - each while-loop step covers 512 vocab entries (4 per thread);
        #   - the loop stops once the running total reaches u - 1e-6 or the
        #     vocabulary is exhausted.
        # If the threshold is never reached, or the crossing entry is not valid,
        # the result falls back to vocab_size - 1.
        T.func_attr({"tir.is_scheduled": 1})
        n, vocab_size = T.int32(), T.int32()
        prob = T.match_buffer(var_prob, (n, vocab_size))
        batch_size = T.int32()
        uniform_samples = T.match_buffer(var_uniform_samples, (batch_size, 1))
        row_indices = T.match_buffer(var_row_indices, (batch_size, 1), "int32")
        token_ids = T.match_buffer(var_sampled_token_ids, (batch_size, 1), "int32")
        # with T.block("root"):
        aggregate = T.alloc_buffer((), scope="local")
        sample_id_local = T.alloc_buffer((), "int32", scope="local")
        step_iter = T.alloc_buffer((), "int32", scope="local")
        for bx in T.thread_binding(batch_size, thread="blockIdx.x"):
            row_idx: T.int32 = row_indices[bx, 0]
            for ty in T.thread_binding(4, thread="threadIdx.y"):
                for tx in T.thread_binding(32, thread="threadIdx.x"):
                    u: T.float32 = uniform_samples[bx, 0]
                    aggregate[()] = T.Cast("float32", 0)
                    step_iter[()] = 0
                    # sample_id_local is otherwise written only inside the
                    # threshold-crossing branch below. If a row's positive mass never
                    # reaches u - 1e-6 (float rounding, or an all-zero row), that branch
                    # never runs. Without this initialization an uninitialized value
                    # would be stored into token_ids. Use the same fallback as the
                    # in-step min-reduction.
                    sample_id_local[()] = T.Cast("int32", vocab_size - 1)
                    # Runs at least once; stops when the running total reaches the
                    # threshold or all ceil(vocab_size / 512) steps are done.
                    # T.tvm_thread_invariant marks the condition as uniform across the
                    # block's threads.
                    while T.tvm_thread_invariant((step_iter[()] == 0 or aggregate[()] < u - T.float32(9.9999999999999995e-07)) and T.Cast("int64", step_iter[()]) < T.Cast("int64", (vocab_size + 512 - 1) // 512)):
                        with T.block(""):
                            T.reads(step_iter[()], prob[row_idx, step_iter[()] * 512 + ty * 128 + tx * 4:step_iter[()] * 512 + ty * 128 + tx * 4 + 4], aggregate[()])
                            T.writes(sample_id_local[()], aggregate[()])
                            prob_gt_threshold = T.alloc_buffer((4,), scope="local")
                            cumsum = T.alloc_buffer((512,), scope="shared")
                            greater_than_u = T.alloc_buffer((4,), "bool", scope="local")
                            mask = T.alloc_buffer((4,), "bool", scope="local")
                            valid = T.alloc_buffer((4,), "bool", scope="local")
                            indices = T.alloc_buffer((4,), "int32", scope="local")
                            step_aggregate = T.alloc_buffer((), scope="local")
                            # Load this thread's 4 entries. Non-positive or out-of-range
                            # entries contribute 0 and are marked invalid.
                            for v in T.unroll(4):
                                idx: T.int32 = step_iter[()] * 512 + ty * 128 + tx * 4 + v
                                prob_local: T.float32 = T.if_then_else(idx < vocab_size, prob[row_idx, idx], T.Cast("float32", 0))
                                prob_gt_threshold[v] = T.if_then_else(prob_local > T.float32(0.0), prob_local, T.Cast("float32", 0))
                                valid[v] = prob_local > T.float32(0.0) and idx < vocab_size
                            # Block-wide sum of the step's 512 entries (tree reduction
                            # over 128 per-thread partial sums).
                            with T.block(""):
                                T.reads(prob_gt_threshold[0:4])
                                T.writes(step_aggregate[()])
                                local_sum = T.alloc_buffer((), scope="local")
                                shared_buf = T.alloc_buffer((128,), scope="shared")
                                idx: T.int32 = ty * 32 + tx
                                local_sum[()] = T.Cast("float32", 0)
                                for i in T.unroll(4):
                                    local_sum[()] = local_sum[()] + prob_gt_threshold[i]
                                shared_buf[idx] = local_sum[()]
                                for i in T.unroll(7):
                                    if idx % T.shift_left(1, i + 1) == 0:
                                        shared_buf[idx] = shared_buf[idx] + shared_buf[idx + T.shift_left(1, i)]
                                step_aggregate[()] = shared_buf[0]
                            # The threshold is crossed inside this step: locate the entry.
                            if T.tvm_thread_invariant(aggregate[()] + step_aggregate[()] >= u - T.float32(9.9999999999999995e-07)):
                                # Inclusive scan of the 512 entries into `cumsum`:
                                # per-thread serial scan, log-step scan across the 32
                                # threads of a row, then row-to-row carry by ty == 0.
                                for i in T.unroll(1, 4):
                                    prob_gt_threshold[i] = prob_gt_threshold[i] + prob_gt_threshold[i - 1]
                                for i in T.vectorized(4):
                                    cumsum[ty * 128 + tx * 4 + i] = prob_gt_threshold[i]
                                for i in T.unroll(5):
                                    for j in T.vectorized(4):
                                        idx: T.int32 = ty * 128 + tx * 4
                                        if tx >= T.shift_left(1, i):
                                            cumsum[idx + j] = cumsum[idx + j] + cumsum[idx - T.shift_left(1, i) * 4 + 4 - 1]
                                for i in T.unroll(1, 4):
                                    for j in T.vectorized(4):
                                        if ty == 0:
                                            idx: T.int32 = i * 128 + tx * 4
                                            cumsum[idx + j] = cumsum[idx + j] + cumsum[i * 128 - 1]
                                for v in T.unroll(4):
                                    greater_than_u[v] = cumsum[ty * 128 + tx * 4 + v] + aggregate[()] >= u - T.float32(9.9999999999999995e-07)
                                # Mark positions where the flag differs from the previous
                                # entry; with a non-decreasing scan this is the first
                                # entry at or above the threshold.
                                with T.block(""):
                                    T.reads(greater_than_u[0:4])
                                    T.writes(mask[0:4])
                                    shared_buf = T.alloc_buffer((128,), "bool", scope="shared")
                                    tx_idx: T.int32 = ty * 32 + tx
                                    shared_buf[tx_idx] = greater_than_u[3]
                                    mask[0] = T.if_then_else(tx_idx != 0, T.Cast("int8", greater_than_u[0]) != T.Cast("int8", shared_buf[tx_idx - 1]), greater_than_u[0])
                                    for i in T.unroll(1, 4):
                                        mask[i] = T.Cast("int8", greater_than_u[i]) != T.Cast("int8", greater_than_u[i - 1])
                                for v in T.unroll(4):
                                    mask[v] = mask[v] and valid[v]
                                    indices[v] = T.Cast("int32", T.Cast("int64", step_iter[()]) * T.int64(512) + T.Cast("int64", ty * 128) + T.Cast("int64", tx * 4) + T.Cast("int64", v))
                                # Block-wide min over masked token indices; defaults to
                                # vocab_size - 1 when no entry is masked.
                                with T.block(""):
                                    T.reads(mask[0:4], indices[0:4])
                                    T.writes(sample_id_local[()])
                                    local_sum = T.alloc_buffer((), "int32", scope="local")
                                    shared_buf = T.alloc_buffer((128,), "int32", scope="shared")
                                    idx: T.int32 = ty * 32 + tx
                                    local_sum[()] = T.Cast("int32", vocab_size - 1)
                                    for i in T.unroll(4):
                                        if mask[i]:
                                            local_sum[()] = T.min(local_sum[()], indices[i])
                                    shared_buf[idx] = local_sum[()]
                                    for i in T.unroll(7):
                                        if idx % T.shift_left(1, i + 1) == 0:
                                            shared_buf[idx] = T.min(shared_buf[idx], shared_buf[idx + T.shift_left(1, i)])
                                    sample_id_local[()] = shared_buf[0]
                            aggregate[()] = aggregate[()] + step_aggregate[()]
                        step_iter[()] = step_iter[()] + 1
                    if tx == 0 and ty == 0:
                        token_ids[bx, 0] = sample_id_local[()]

    @T.prim_func
    def reshape(var_add324: T.handle, var_T_reshape: T.handle):
        # Split the last axis of add324 (batch_size, 1, 2560) into (20, 128):
        # T_reshape[b, 0, h, d] = add324[b, 0, h * 128 + d], all float16.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int32()
        add324 = T.match_buffer(var_add324, (batch_size, 1, 2560), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, 1, 20, 128), "float16")
        # with T.block("root"):
        # One thread per element over batch_size * 2560; T.where masks the tail block.
        for ax0_ax1_ax2_fused_0 in T.thread_binding((batch_size * 2560 + 1023) // 1024, thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_reshape"):
                    v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2560)
                    v1 = T.axis.spatial(20, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2560 // 128)
                    v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 128)
                    T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < batch_size * 2560)
                    T.reads(add324[v0, 0, v1 * 128 + v2])
                    T.writes(T_reshape[v0, 0, v1, v2])
                    T_reshape[v0, 0, v1, v2] = add324[v0, 0, v1 * 128 + v2]

    @T.prim_func
    def reshape1(var_reshape432: T.handle, var_T_reshape: T.handle):
        # Drop the unit axis 1: (batch_size, 1, 20, 128) -> (batch_size, 20, 128),
        # float16, element for element.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int32()
        reshape432 = T.match_buffer(var_reshape432, (batch_size, 1, 20, 128), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, 20, 128), "float16")
        # with T.block("root"):
        # One thread per element over batch_size * 2560; T.where masks the tail block.
        for ax0_ax1_ax2_fused_0 in T.thread_binding((batch_size * 2560 + 1023) // 1024, thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_reshape"):
                    v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2560)
                    v1 = T.axis.spatial(20, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2560 // 128)
                    v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 128)
                    T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < batch_size * 2560)
                    T.reads(reshape432[v0, 0, v1, v2])
                    T.writes(T_reshape[v0, v1, v2])
                    T_reshape[v0, v1, v2] = reshape432[v0, 0, v1, v2]

    @T.prim_func
    def reshape2(var_lv546: T.handle, var_T_reshape: T.handle):
        # Insert a unit axis 1: (batch_size, 16, 128) -> (batch_size, 1, 16, 128),
        # float16, element for element.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int32()
        lv546 = T.match_buffer(var_lv546, (batch_size, 16, 128), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, 1, 16, 128), "float16")
        # with T.block("root"):
        # batch_size * 2 blocks x 1024 threads == batch_size * 2048 elements exactly,
        # so no T.where guard is needed.
        for ax0_ax1_ax2_fused_0 in T.thread_binding(batch_size * 2, thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_reshape"):
                    v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2048)
                    v1 = T.axis.spatial(16, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2048 // 128)
                    v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 128)
                    T.reads(lv546[v0, v1, v2])
                    T.writes(T_reshape[v0, 0, v1, v2])
                    T_reshape[v0, 0, v1, v2] = lv546[v0, v1, v2]

    @T.prim_func
    def reshape3(var_reshape434: T.handle, var_T_reshape: T.handle):
        # Merge the last two axes: (batch_size, 1, 16, 128) -> (batch_size, 1, 2048),
        # T_reshape[b, 0, k] = reshape434[b, 0, k // 128, k % 128], float16.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int32()
        reshape434 = T.match_buffer(var_reshape434, (batch_size, 1, 16, 128), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, 1, 2048), "float16")
        # with T.block("root"):
        # batch_size * 2 blocks x 1024 threads cover batch_size * 2048 elements exactly.
        for ax0_ax1_fused_0 in T.thread_binding(batch_size * 2, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_reshape"):
                    v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) // 2048)
                    v1 = T.axis.spatial(2048, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % 2048)
                    T.reads(reshape434[v0, 0, v1 // 128, v1 % 128])
                    T.writes(T_reshape[v0, 0, v1])
                    T_reshape[v0, 0, v1] = reshape434[v0, 0, v1 // 128, v1 % 128]

    @T.prim_func
    def reshape4(var_add216: T.handle, var_T_reshape: T.handle):
        # Split the last axis of add216 (1, seq_len, 2560) into (20, 128):
        # T_reshape[0, s, h, d] = add216[0, s, h * 128 + d], all float16.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        add216 = T.match_buffer(var_add216, (1, seq_len, 2560), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (1, seq_len, 20, 128), "float16")
        # with T.block("root"):
        # One thread per element over seq_len * 2560; T.where masks the tail block.
        for ax0_ax1_ax2_fused_0 in T.thread_binding((seq_len * 2560 + 1023) // 1024, thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_reshape"):
                    v0 = T.axis.spatial(seq_len, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2560)
                    v1 = T.axis.spatial(20, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2560 // 128)
                    v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 128)
                    T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < seq_len * 2560)
                    T.reads(add216[0, v0, v1 * 128 + v2])
                    T.writes(T_reshape[0, v0, v1, v2])
                    T_reshape[0, v0, v1, v2] = add216[0, v0, v1 * 128 + v2]

    @T.prim_func
    def reshape5(var_reshape288: T.handle, var_T_reshape: T.handle):
        # Drop the leading unit axis: (1, seq_len, 20, 128) -> (seq_len, 20, 128),
        # float16, element for element.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        reshape288 = T.match_buffer(var_reshape288, (1, seq_len, 20, 128), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (seq_len, 20, 128), "float16")
        # with T.block("root"):
        # One thread per element over seq_len * 2560; T.where masks the tail block.
        for ax0_ax1_ax2_fused_0 in T.thread_binding((seq_len * 2560 + 1023) // 1024, thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_reshape"):
                    v0 = T.axis.spatial(seq_len, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2560)
                    v1 = T.axis.spatial(20, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2560 // 128)
                    v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 128)
                    T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < seq_len * 2560)
                    T.reads(reshape288[0, v0, v1, v2])
                    T.writes(T_reshape[v0, v1, v2])
                    T_reshape[v0, v1, v2] = reshape288[0, v0, v1, v2]

    @T.prim_func
    def reshape6(var_lv365: T.handle, var_T_reshape: T.handle):
        # Insert a leading unit axis: (seq_len, 16, 128) -> (1, seq_len, 16, 128),
        # float16, element for element.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        lv365 = T.match_buffer(var_lv365, (seq_len, 16, 128), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (1, seq_len, 16, 128), "float16")
        # with T.block("root"):
        # seq_len * 2 blocks x 1024 threads == seq_len * 2048 elements exactly,
        # so no T.where guard is needed.
        for ax0_ax1_ax2_fused_0 in T.thread_binding(seq_len * 2, thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_reshape"):
                    v0 = T.axis.spatial(seq_len, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2048)
                    v1 = T.axis.spatial(16, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2048 // 128)
                    v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 128)
                    T.reads(lv365[v0, v1, v2])
                    T.writes(T_reshape[0, v0, v1, v2])
                    T_reshape[0, v0, v1, v2] = lv365[v0, v1, v2]

    @T.prim_func
    def reshape7(var_reshape290: T.handle, var_T_reshape: T.handle):
        # Merge the last two axes: (1, seq_len, 16, 128) -> (1, seq_len, 2048),
        # T_reshape[0, s, k] = reshape290[0, s, k // 128, k % 128], float16.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        reshape290 = T.match_buffer(var_reshape290, (1, seq_len, 16, 128), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (1, seq_len, 2048), "float16")
        # with T.block("root"):
        # seq_len * 2 blocks x 1024 threads cover seq_len * 2048 elements exactly.
        for ax0_ax1_fused_0 in T.thread_binding(seq_len * 2, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_reshape"):
                    v0 = T.axis.spatial(seq_len, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) // 2048)
                    v1 = T.axis.spatial(2048, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % 2048)
                    T.reads(reshape290[0, v0, v1 // 128, v1 % 128])
                    T.writes(T_reshape[0, v0, v1])
                    T_reshape[0, v0, v1] = reshape290[0, v0, v1 // 128, v1 % 128]

    @T.prim_func
    def rms_norm(var_input_embeds: T.handle, model_layers_0_input_layernorm_weight4: T.Buffer((2048,), "float16"), var_T_cast: T.handle):
        # RMS normalization of each row of input_embeds (batch_size, 1, 2048), float16 in/out:
        #   T_cast[b, 0, i] = f16(rsqrt(sum_k f32(x[b, 0, k])^2 * (1/2048) + 1e-6)
        #                         * f32(x[b, 0, i]) * f32(weight[i]))
        # 0.00048828125 below is exactly 1/2048 (mean over the 2048 features) and
        # 9.9999999999999995e-07 is the float literal for the 1e-6 epsilon.
        # Schedule: one thread block per batch row (blockIdx.x), 64 threads (threadIdx.x) per block.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int32()
        input_embeds = T.match_buffer(var_input_embeds, (batch_size, 1, 2048), "float16")
        T_cast = T.match_buffer(var_T_cast, (batch_size, 1, 2048), "float16")
        # with T.block("root"):
        # Row-wide sum of squares, one entry per row, in "shared" scope.
        T_multiply_red_shared = T.alloc_buffer((batch_size, 1), scope="shared")
        # Per-thread partial sums of squares (64 partials per row), in "local" scope.
        T_multiply_red_rf_local = T.alloc_buffer((64, batch_size, 1), scope="local")
        for ax0_fused in T.thread_binding(batch_size, thread="blockIdx.x"):
            # Stage 1: thread t zeroes its partial, then accumulates float32 squares of
            # elements t, t + 64, ..., t + 31 * 64 (32 elements x 64 threads = 2048).
            for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                with T.block("T_multiply_red_rf_init"):
                    vax1_fused_1, v0 = T.axis.remap("SS", [ax1_fused_1, ax0_fused])
                    T.reads()
                    T.writes(T_multiply_red_rf_local[vax1_fused_1, v0, 0])
                    T_multiply_red_rf_local[vax1_fused_1, v0, 0] = T.float32(0.0)
                for ax1_fused_0, u in T.grid(32, 1):
                    with T.block("T_multiply_red_rf_update"):
                        vax1_fused_1, v0, vax1_fused_0 = T.axis.remap("SSR", [ax1_fused_1, ax0_fused, ax1_fused_0])
                        T.reads(T_multiply_red_rf_local[vax1_fused_1, v0, 0], input_embeds[v0, 0, vax1_fused_0 * 64 + vax1_fused_1])
                        T.writes(T_multiply_red_rf_local[vax1_fused_1, v0, 0])
                        T_multiply_red_rf_local[vax1_fused_1, v0, 0] = T_multiply_red_rf_local[vax1_fused_1, v0, 0] + T.Cast("float32", input_embeds[v0, 0, vax1_fused_0 * 64 + vax1_fused_1]) * T.Cast("float32", input_embeds[v0, 0, vax1_fused_0 * 64 + vax1_fused_1])
            # Stage 2: reduce the 64 per-thread partials (reduction axis bound to threadIdx.x)
            # into the row's shared sum.
            for ax1_fused in range(1):
                for ax0 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax0_fused])
                        T.reads(T_multiply_red_rf_local[vax1_fused_1, v0, 0])
                        T.writes(T_multiply_red_shared[v0, 0])
                        with T.init():
                            T_multiply_red_shared[v0, 0] = T.float32(0.0)
                        T_multiply_red_shared[v0, 0] = T_multiply_red_shared[v0, 0] + T_multiply_red_rf_local[vax1_fused_1, v0, 0]
            # Stage 3: each thread writes 32 normalized, weighted outputs at stride 64.
            for ax0_ax1_fused_0 in range(32):
                for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        v0 = T.axis.spatial(batch_size, ax0_fused)
                        v1 = T.axis.spatial(2048, ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1)
                        T.reads(T_multiply_red_shared[v0, 0], input_embeds[v0, 0, v1], model_layers_0_input_layernorm_weight4[v1])
                        T.writes(T_cast[v0, 0, v1])
                        T_cast[v0, 0, v1] = T.Cast("float16", T.rsqrt(T_multiply_red_shared[v0, 0] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07)) * T.Cast("float32", input_embeds[v0, 0, v1]) * T.Cast("float32", model_layers_0_input_layernorm_weight4[v1]))

    @T.prim_func
    def rms_norm1(var_input_embeds: T.handle, model_layers_0_input_layernorm_weight3: T.Buffer((2048,), "float16"), var_T_cast: T.handle):
        # RMS normalization of each token of input_embeds (1, seq_len, 2048), float16 in/out:
        #   T_cast[0, s, i] = f16(rsqrt(sum_k f32(x[0, s, k])^2 * (1/2048) + 1e-6)
        #                         * f32(x[0, s, i]) * f32(weight[i]))
        # 0.00048828125 below is exactly 1/2048 and 9.9999999999999995e-07 is the 1e-6 epsilon.
        # Same three-stage schedule as rms_norm, but one thread block per token (blockIdx.x
        # over seq_len) instead of per batch row.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        input_embeds = T.match_buffer(var_input_embeds, (1, seq_len, 2048), "float16")
        T_cast = T.match_buffer(var_T_cast, (1, seq_len, 2048), "float16")
        # with T.block("root"):
        # Per-token sum of squares ("shared" scope) and 64 per-thread partials ("local" scope).
        T_multiply_red_shared = T.alloc_buffer((1, seq_len), scope="shared")
        T_multiply_red_rf_local = T.alloc_buffer((64, 1, seq_len), scope="local")
        for ax0_fused in T.thread_binding(seq_len, thread="blockIdx.x"):
            # Stage 1: thread t accumulates float32 squares of elements t, t + 64, ..., t + 31 * 64.
            for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                with T.block("T_multiply_red_rf_init"):
                    vax1_fused_1, v0 = T.axis.remap("SS", [ax1_fused_1, ax0_fused])
                    T.reads()
                    T.writes(T_multiply_red_rf_local[vax1_fused_1, 0, v0])
                    T_multiply_red_rf_local[vax1_fused_1, 0, v0] = T.float32(0.0)
                for ax1_fused_0, u in T.grid(32, 1):
                    with T.block("T_multiply_red_rf_update"):
                        vax1_fused_1, v0, vax1_fused_0 = T.axis.remap("SSR", [ax1_fused_1, ax0_fused, ax1_fused_0])
                        T.reads(T_multiply_red_rf_local[vax1_fused_1, 0, v0], input_embeds[0, v0, vax1_fused_0 * 64 + vax1_fused_1])
                        T.writes(T_multiply_red_rf_local[vax1_fused_1, 0, v0])
                        T_multiply_red_rf_local[vax1_fused_1, 0, v0] = T_multiply_red_rf_local[vax1_fused_1, 0, v0] + T.Cast("float32", input_embeds[0, v0, vax1_fused_0 * 64 + vax1_fused_1]) * T.Cast("float32", input_embeds[0, v0, vax1_fused_0 * 64 + vax1_fused_1])
            # Stage 2: reduce the 64 partials across threadIdx.x into the shared per-token sum.
            for ax1_fused in range(1):
                for ax0 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax0_fused])
                        T.reads(T_multiply_red_rf_local[vax1_fused_1, 0, v0])
                        T.writes(T_multiply_red_shared[0, v0])
                        with T.init():
                            T_multiply_red_shared[0, v0] = T.float32(0.0)
                        T_multiply_red_shared[0, v0] = T_multiply_red_shared[0, v0] + T_multiply_red_rf_local[vax1_fused_1, 0, v0]
            # Stage 3: each thread writes 32 normalized, weighted outputs at stride 64.
            for ax0_ax1_fused_0 in range(32):
                for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        v0 = T.axis.spatial(seq_len, ax0_fused)
                        v1 = T.axis.spatial(2048, ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1)
                        T.reads(T_multiply_red_shared[0, v0], input_embeds[0, v0, v1], model_layers_0_input_layernorm_weight3[v1])
                        T.writes(T_cast[0, v0, v1])
                        T_cast[0, v0, v1] = T.Cast("float16", T.rsqrt(T_multiply_red_shared[0, v0] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07)) * T.Cast("float32", input_embeds[0, v0, v1]) * T.Cast("float32", model_layers_0_input_layernorm_weight3[v1]))

    @T.prim_func
    def rms_norm2(input_embed: T.Buffer((1, 1, 2048), "float16"), model_layers_0_input_layernorm_weight2: T.Buffer((2048,), "float16"), T_cast: T.Buffer((1, 1, 2048), "float16")):
        # RMS normalization of a single statically-shaped (1, 1, 2048) float16 row:
        #   T_cast[0, 0, i] = f16(rsqrt(sum_k f32(x[0, 0, k])^2 * (1/2048) + 1e-6)
        #                         * f32(x[0, 0, i]) * f32(weight[i]))
        # 0.00048828125 below is exactly 1/2048 and 9.9999999999999995e-07 is the 1e-6 epsilon.
        # Same three-stage schedule as rms_norm with a single thread block (blockIdx.x extent 1).
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Row sum of squares ("shared" scope) and 64 per-thread partials ("local" scope).
        T_multiply_red_shared = T.alloc_buffer((1, 1), scope="shared")
        T_multiply_red_rf_local = T.alloc_buffer((64, 1, 1), scope="local")
        for ax0_fused in T.thread_binding(1, thread="blockIdx.x"):
            # Stage 1: thread t accumulates float32 squares of elements t, t + 64, ..., t + 31 * 64.
            for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                with T.block("T_multiply_red_rf_init"):
                    vax1_fused_1 = T.axis.spatial(64, ax1_fused_1)
                    v0 = T.axis.spatial(1, 0)
                    T.reads()
                    T.writes(T_multiply_red_rf_local[vax1_fused_1, 0, 0])
                    T_multiply_red_rf_local[vax1_fused_1, 0, 0] = T.float32(0.0)
                for ax1_fused_0, u in T.grid(32, 1):
                    with T.block("T_multiply_red_rf_update"):
                        vax1_fused_1 = T.axis.spatial(64, ax1_fused_1)
                        v0 = T.axis.spatial(1, 0)
                        vax1_fused_0 = T.axis.reduce(32, ax1_fused_0)
                        T.reads(T_multiply_red_rf_local[vax1_fused_1, 0, 0], input_embed[0, 0, vax1_fused_0 * 64 + vax1_fused_1])
                        T.writes(T_multiply_red_rf_local[vax1_fused_1, 0, 0])
                        T_multiply_red_rf_local[vax1_fused_1, 0, 0] = T_multiply_red_rf_local[vax1_fused_1, 0, 0] + T.Cast("float32", input_embed[0, 0, vax1_fused_0 * 64 + vax1_fused_1]) * T.Cast("float32", input_embed[0, 0, vax1_fused_0 * 64 + vax1_fused_1])
            # Stage 2: reduce the 64 partials across threadIdx.x into the shared row sum.
            for ax1_fused in range(1):
                for ax0 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        vax1_fused_1 = T.axis.reduce(64, ax0)
                        v0 = T.axis.spatial(1, 0)
                        T.reads(T_multiply_red_rf_local[vax1_fused_1, 0, 0])
                        T.writes(T_multiply_red_shared[0, 0])
                        with T.init():
                            T_multiply_red_shared[0, 0] = T.float32(0.0)
                        T_multiply_red_shared[0, 0] = T_multiply_red_shared[0, 0] + T_multiply_red_rf_local[vax1_fused_1, 0, 0]
            # Stage 3: each thread writes 32 normalized, weighted outputs at stride 64.
            for ax0_fused_0 in range(32):
                for ax0_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        v0 = T.axis.spatial(2048, ax0_fused_0 * 64 + ax0_fused_1)
                        T.reads(T_multiply_red_shared[0, 0], input_embed[0, 0, v0], model_layers_0_input_layernorm_weight2[v0])
                        T.writes(T_cast[0, 0, v0])
                        T_cast[0, 0, v0] = T.Cast("float16", T.rsqrt(T_multiply_red_shared[0, 0] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07)) * T.Cast("float32", input_embed[0, 0, v0]) * T.Cast("float32", model_layers_0_input_layernorm_weight2[v0]))

    @T.prim_func
    def sampler_take_probs_tir(var_unsorted_probs: T.handle, var_sorted_indices: T.handle, var_sample_indices: T.handle, var_sampling_results: T.handle, var_top_prob_offsets: T.handle, var_top_prob_indices: T.handle):
        # Gather probabilities for two independent outputs in a single launch.
        #
        # Thread v0 in [0, num_positions + num_samples) produces one value:
        #   * v0 < num_positions: top_prob_offsets[v0] is a flat offset into the
        #     (batch_size, vocab_size) sorted_indices matrix. It is split into (row, col), then
        #       top_prob_indices[v0] = sorted_indices[row, col]
        #       top_prob_probs[v0]   = unsorted_probs[row, sorted_indices[row, col]]
        #   * otherwise, with vj = v0 - num_positions:
        #       sampled_values[vj] = unsorted_probs[sample_indices[vj], sampling_results[vj]]
        # Index values read from top_prob_offsets / sample_indices / sampling_results are not
        # range-checked here; callers must supply in-bounds values.
        # Buffers matched without a dtype (unsorted_probs, sampled_values, top_prob_probs) use
        # TVMScript's default float32.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        unsorted_probs = T.match_buffer(var_unsorted_probs, (batch_size, vocab_size))
        sorted_indices = T.match_buffer(var_sorted_indices, (batch_size, vocab_size), "int32")
        num_samples = T.int32(is_size_var=True)
        sample_indices = T.match_buffer(var_sample_indices, (num_samples,), "int32")
        sampling_results = T.match_buffer(var_sampling_results, (num_samples,), "int32")
        num_positions = T.int32(is_size_var=True)
        top_prob_offsets = T.match_buffer(var_top_prob_offsets, (num_positions,), "int32")
        sampled_values = T.match_buffer(var_sampled_values, (num_samples,))
        top_prob_probs = T.match_buffer(var_top_prob_probs, (num_positions,))
        top_prob_indices = T.match_buffer(var_top_prob_indices, (num_positions,), "int32")
        # with T.block("root"):
        # One thread per output; the grid is rounded up to a multiple of 1024 and the surplus
        # threads are masked off by the T.where predicate.
        for ax0_fused_0 in T.thread_binding((num_positions + num_samples + 1023) // 1024, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    v0 = T.axis.spatial(num_positions + num_samples, ax0_fused_0 * 1024 + ax0_fused_1)
                    T.where(ax0_fused_0 * 1024 + ax0_fused_1 < num_positions + num_samples)
                    # The read/write regions below are a conservative union covering both
                    # branches of the if/else that follows.
                    T.reads(top_prob_offsets[v0], sorted_indices[top_prob_offsets[v0] // vocab_size, top_prob_offsets[v0] % vocab_size], unsorted_probs[T.min(top_prob_offsets[v0] // vocab_size, sample_indices[v0 + (0 - num_positions)]):T.min(top_prob_offsets[v0] // vocab_size, sample_indices[v0 + (0 - num_positions)]) + (T.max(top_prob_offsets[v0] // vocab_size, sample_indices[v0 - num_positions]) + 1 - T.min(top_prob_offsets[v0] // vocab_size, sample_indices[v0 - num_positions])), T.min(sorted_indices[top_prob_offsets[v0] // vocab_size, top_prob_offsets[v0] % vocab_size], sampling_results[v0 + (0 - num_positions)]):T.min(sorted_indices[top_prob_offsets[v0] // vocab_size, top_prob_offsets[v0] % vocab_size], sampling_results[v0 + (0 - num_positions)]) + (T.max(sorted_indices[top_prob_offsets[v0] // vocab_size, top_prob_offsets[v0] % vocab_size], sampling_results[v0 - num_positions]) + 1 - T.min(sorted_indices[top_prob_offsets[v0] // vocab_size, top_prob_offsets[v0] % vocab_size], sampling_results[v0 - num_positions]))], sample_indices[v0 + (0 - num_positions)], sampling_results[v0 + (0 - num_positions)])
                    T.writes(top_prob_indices[v0], top_prob_probs[v0], sampled_values[v0 + (0 - num_positions)])
                    if v0 < num_positions:
                        # Split the flat offset into (row, col) of the sorted_indices matrix.
                        row: T.int32 = top_prob_offsets[v0] // vocab_size
                        col: T.int32 = top_prob_offsets[v0] % vocab_size
                        top_prob_indices[v0] = sorted_indices[row, col]
                        top_prob_probs[v0] = unsorted_probs[row, sorted_indices[row, col]]
                    else:
                        vj: T.int32 = v0 - num_positions
                        sampled_values[vj] = unsorted_probs[sample_indices[vj], sampling_results[vj]]

    @T.prim_func
    def scatter_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle):
        # Row scatter: dst[indices[i], j] = src[i, j] for i in [0, batch_size), j in [0, n).
        # Rows of dst not named in `indices` are left untouched.
        # Nothing here orders the writes. If `indices` contains duplicates, this kernel
        # does not determine which source row ends up in dst.
        # indices values are not range-checked against m.
        # One thread per src element; the grid is rounded up to a multiple of 1024 and the
        # surplus threads are masked off by the T.where predicate.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size, n = T.int32(is_size_var=True), T.int32(is_size_var=True)
        src = T.match_buffer(var_src, (batch_size, n))
        indices = T.match_buffer(var_indices, (batch_size,), "int32")
        m = T.int32(is_size_var=True)
        dst = T.match_buffer(var_dst, (m, n))
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding((batch_size * n + 1023) // 1024, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("scatter_2d"):
                    v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % (n * batch_size) // n)
                    v1 = T.axis.spatial(n, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % n)
                    T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < batch_size * n)
                    T.reads(src[v0, v1], indices[v0])
                    T.writes(dst[indices[v0], v1])
                    dst[indices[v0], v1] = src[v0, v1]

    @T.prim_func
    def shape_func(H: T.Buffer((T.int64(5),), "int64")):
        # Host-side (tir.is_host_func) shape computation over the int64 vector H.
        # Reads H[1] and fills:
        #   H[3] = ceil(H[1] / 4096)
        #   H[2] = 320 * ceil(H[1] / 4096)
        #   H[4] = 320 * H[1]
        # NOTE(review): the meaning of each H slot is fixed by whoever builds this shape vector.
        # 4096 matches the chunk width used by softmax_with_chunked_sum in this module, so H[3]
        # is presumably a chunk count. Confirm against the caller.
        T.func_attr({"tir.is_host_func": 1})
        H[T.int64(3)] = (H[T.int64(1)] + T.int64(4096) - T.int64(1)) // T.int64(4096)
        H[T.int64(2)] = T.int64(320) * ((H[T.int64(1)] + T.int64(4096) - T.int64(1)) // T.int64(4096))
        H[T.int64(4)] = T.int64(320) * H[T.int64(1)]

    @T.prim_func
    def shape_func1(H: T.Buffer((T.int64(4),), "int64")):
        # Host-side (tir.is_host_func) shape computation over the int64 vector H:
        # reads H[1] and writes H[3] = 320 * H[1].
        # NOTE(review): slot meanings are defined by the caller that builds H — confirm there.
        T.func_attr({"tir.is_host_func": 1})
        H[T.int64(3)] = T.int64(320) * H[T.int64(1)]

    @T.prim_func
    def shape_func2(H: T.Buffer((T.int64(3),), "int64")):
        # Host-side (tir.is_host_func) shape computation over the int64 vector H:
        # reads H[1] and writes H[2] = 80 * H[1] * 4 (numerically 320 * H[1]).
        # NOTE(review): slot meanings are defined by the caller that builds H — confirm there.
        T.func_attr({"tir.is_host_func": 1})
        H[T.int64(2)] = T.int64(80) * H[T.int64(1)] * T.int64(4)

    @T.prim_func
    def shape_func3(H: T.Buffer((T.int64(4),), "int64")):
        # Host-side (tir.is_host_func) shape computation over the int64 vector H.
        # Reads H[1] and writes:
        #   H[2] = 80 * H[1] * 4
        #   H[3] = 320 * H[1]
        # Both slots therefore hold the same value.
        # NOTE(review): slot meanings are defined by the caller that builds H — confirm there.
        T.func_attr({"tir.is_host_func": 1})
        H[T.int64(2)] = T.int64(80) * H[T.int64(1)] * T.int64(4)
        H[T.int64(3)] = T.int64(320) * H[T.int64(1)]

    @T.prim_func
    def softmax_with_chunked_sum(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle, var_softmax: T.handle):
        # Finish a chunked softmax. A row of A is split into num_chunks chunks of 4096 columns,
        # and per-chunk statistics are supplied in chunked_max / chunked_sum.
        # Each thread block handles one (row, chunk) pair (blockIdx.x over batch_size * num_chunks):
        #   1. "max":     temp_max_shared[row] = max over chunks of chunked_max[row, :]
        #   2. "sum_exp": temp_sum_shared[row] = sum over chunks of
        #        temperature > 1e-5:  exp(chunked_sum + chunked_max - temp_max)
        #        otherwise:           (chunked_max == temp_max) * chunked_sum
        #   3. "log_pad": for the 4096 columns of this block's chunk that are < vocab_size,
        #        temperature > 1e-5:  softmax = exp(A / temperature - (log(sum) + max))
        #        otherwise:           softmax = (A == max) / sum
        # Stages 1 and 2 reduce over all chunks of the row, so every block of a row recomputes
        # the same row-wide max/sum redundantly.
        # NOTE(review): the formulas imply that chunked_max is already temperature-scaled and
        # chunked_sum is a log-sum relative to chunked_max when temperature > 1e-5. In the
        # low-temperature case they imply chunked_max is the raw max of A and chunked_sum counts
        # elements equal to it. The producer must also use the same 4096-column chunk width.
        # Confirm against the kernel that writes chunked_sum / chunked_max.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        A = T.match_buffer(var_A, (batch_size, vocab_size))
        temperature = T.match_buffer(var_temperature, (batch_size,))
        num_chunks = T.int32(is_size_var=True)
        chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks))
        chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks))
        softmax = T.match_buffer(var_softmax, (batch_size, vocab_size))
        # with T.block("root"):
        # Row-wide max and sum, held in "shared" scope for use by all threads of the block.
        temp_max_shared = T.alloc_buffer((batch_size,), scope="shared")
        temp_sum_shared = T.alloc_buffer((batch_size,), scope="shared")
        for l0_l1_fused in T.thread_binding(batch_size * num_chunks, thread="blockIdx.x"):
            # Stage 1: 32 threads stride over the chunks to reduce the per-chunk maxima.
            for ax0_1 in T.thread_binding(32, thread="threadIdx.x"):
                for ax0_0 in T.serial((num_chunks + 31) // 32, annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                    with T.block("max"):
                        v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks)
                        v1 = T.axis.reduce(num_chunks, ax0_0 * 32 + ax0_1)
                        T.where(ax0_0 * 32 + ax0_1 < num_chunks)
                        T.reads(chunked_max[v0, v1])
                        T.writes(temp_max_shared[v0])
                        with T.init():
                            # -FLT_MAX: identity for the max reduction.
                            temp_max_shared[v0] = T.float32(-340282346638528859811704183484516925440.0)
                        temp_max_shared[v0] = T.max(temp_max_shared[v0], chunked_max[v0, v1])
            # Stage 2: 32 threads stride over the chunks, rescaling each chunk's sum to the row max.
            for ax0_1 in T.thread_binding(32, thread="threadIdx.x"):
                for ax0_0 in T.serial((num_chunks + 31) // 32, annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                    with T.block("sum_exp"):
                        v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks)
                        v1 = T.axis.reduce(num_chunks, ax0_0 * 32 + ax0_1)
                        T.where(ax0_0 * 32 + ax0_1 < num_chunks)
                        T.reads(temperature[v0], chunked_sum[v0, v1], chunked_max[v0, v1], temp_max_shared[v0])
                        T.writes(temp_sum_shared[v0])
                        with T.init():
                            temp_sum_shared[v0] = T.float32(0.0)
                        temp_sum_shared[v0] = temp_sum_shared[v0] + T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(chunked_sum[v0, v1] + chunked_max[v0, v1] - temp_max_shared[v0]), T.Cast("float32", chunked_max[v0, v1] == temp_max_shared[v0]) * chunked_sum[v0, v1])
            # Stage 3: 4 x 32 x 32 = 4096 output columns for this block's chunk; columns past
            # vocab_size (the padded tail of the last chunk) are skipped.
            for l2_0 in T.serial(4, annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                for l2_1 in T.thread_binding(32, thread="threadIdx.y"):
                    for l2_2 in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("log_pad"):
                            v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks)
                            v1 = T.axis.spatial(num_chunks, l0_l1_fused % num_chunks)
                            v2 = T.axis.spatial(4096, l2_0 * 1024 + l2_1 * 32 + l2_2)
                            T.reads(temperature[v0], A[v0, v1 * 4096 + v2], temp_sum_shared[v0], temp_max_shared[v0])
                            T.writes(softmax[v0, v1 * 4096 + v2])
                            if v1 * 4096 + v2 < vocab_size:
                                softmax[v0, v1 * 4096 + v2] = T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(A[v0, v1 * 4096 + v2] / temperature[v0] - (T.log(temp_sum_shared[v0]) + temp_max_shared[v0])), T.Cast("float32", A[v0, v1 * 4096 + v2] == temp_max_shared[v0]) / temp_sum_shared[v0])

    @T.prim_func
    def take(var_rms_norm218: T.handle, var_logit_positions: T.handle, var_T_take: T.handle):
        # Gather selected tokens along the sequence axis, float16 in/out:
        #   T_take[0, i, :] = rms_norm218[0, logit_positions[i], :]   for i in [0, batch_size)
        # logit_positions values are not range-checked against seq_len.
        # Launch: batch_size * 2 blocks x 1024 threads = batch_size * 2048 threads, exactly one
        # per output element, so no T.where bounds predicate is needed.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        rms_norm218 = T.match_buffer(var_rms_norm218, (1, seq_len, 2048), "float16")
        batch_size = T.int32()
        logit_positions = T.match_buffer(var_logit_positions, (batch_size,), "int32")
        T_take = T.match_buffer(var_T_take, (1, batch_size, 2048), "float16")
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding(batch_size * 2, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_take"):
                    v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) // 2048)
                    v1 = T.axis.spatial(2048, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % 2048)
                    T.reads(rms_norm218[0, logit_positions[v0], v1], logit_positions[v0])
                    T.writes(T_take[0, v0, v1])
                    T_take[0, v0, v1] = rms_norm218[0, logit_positions[v0], v1]

    @T.prim_func
    def take_sorted_probs(var_probs: T.handle, var_lv1: T.handle, var_take_sorted_probs: T.handle):
        # Per-row gather: take_sorted_probs[i, j] = probs[i, lv1[i, j]].
        # Iteration covers the output's shape (batch_size_1, vocab_size_1). The inputs are
        # matched with separate shape vars (batch_size, vocab_size), and nothing in this
        # function asserts that the two pairs are equal. lv1 values are not range-checked.
        # One thread per output element; the grid is rounded up to a multiple of 1024 and the
        # surplus threads are masked off by the T.where predicate.
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(), T.int32()
        probs = T.match_buffer(var_probs, (batch_size, vocab_size))
        lv1 = T.match_buffer(var_lv1, (batch_size, vocab_size), "int32")
        batch_size_1, vocab_size_1 = T.int32(is_size_var=True), T.int32(is_size_var=True)
        take_sorted_probs = T.match_buffer(var_take_sorted_probs, (batch_size_1, vocab_size_1))
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding((batch_size_1 * vocab_size_1 + 1023) // 1024, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("take_sorted_probs"):
                    v0 = T.axis.spatial(batch_size_1, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % (vocab_size_1 * batch_size_1) // vocab_size_1)
                    v1 = T.axis.spatial(vocab_size_1, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % vocab_size_1)
                    T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < batch_size_1 * vocab_size_1)
                    T.reads(probs[v0, lv1[v0, v1]], lv1[v0, v1])
                    T.writes(take_sorted_probs[v0, v1])
                    take_sorted_probs[v0, v1] = probs[v0, lv1[v0, v1]]

    @T.prim_func
    def tir_kv_cache_debug_get_kv(var_pages: T.handle, var_position_map: T.handle, var_k_data: T.handle, var_v_data: T.handle, layer_id: T.int32):
        # Copy K and V entries for one layer out of the paged cache into dense buffers.
        # For p in [0, seqlen), h in [0, 2), d in [0, 128), with pos = position_map[p]:
        #   k_data[layer_id, p, h, d] = pages[pos // page_size, 0, h, pos % page_size, d]
        #   v_data[layer_id, p, h, d] = pages[pos // page_size, 1, h, pos % page_size, d]
        # So pages axis 1 selects K (0) or V (1), axis 2 is the head, and axis 3 is the slot in a page.
        # Unlike tir_kv_cache_transpose_append, position_map entries are not checked for -1 here.
        # layer_id is not range-checked against the leading extent 36 of k_data / v_data.
        # NOTE(review): 36 is presumably the model's layer count — confirm.
        # One thread per (p, h, d); the grid is rounded up to a multiple of 1024 and the surplus
        # threads are masked off by the T.where predicate.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        num_pages, page_size = T.int32(), T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (num_pages, 2, 2, page_size, 128), "float16", offset_factor=1)
        seqlen = T.int32(is_size_var=True)
        position_map = T.match_buffer(var_position_map, (seqlen,), "int32", offset_factor=1)
        k_data = T.match_buffer(var_k_data, (36, seqlen, 2, 128), "float16")
        v_data = T.match_buffer(var_v_data, (36, seqlen, 2, 128), "float16")
        # with T.block("root"):
        for p_h_d_fused_0 in T.thread_binding((seqlen * 256 + 1023) // 1024, thread="blockIdx.x"):
            for p_h_d_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("copy0"):
                    vp = T.axis.spatial(seqlen, (p_h_d_fused_0 * 1024 + p_h_d_fused_1) // 256)
                    vh = T.axis.spatial(2, (p_h_d_fused_0 * 1024 + p_h_d_fused_1) % 256 // 128)
                    vd = T.axis.spatial(128, (p_h_d_fused_0 * 1024 + p_h_d_fused_1) % 128)
                    T.where(p_h_d_fused_0 * 1024 + p_h_d_fused_1 < seqlen * 256)
                    T.reads(position_map[vp], pages[position_map[vp] // page_size, 0:2, vh, position_map[vp] % page_size, vd])
                    T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd])
                    position: T.int32 = position_map[vp]
                    k_data[layer_id, vp, vh, vd] = pages[position // page_size, 0, vh, position % page_size, vd]
                    v_data[layer_id, vp, vh, vd] = pages[position // page_size, 1, vh, position % page_size, vd]

    @T.prim_func
    def tir_kv_cache_transpose_append(var_pages: T.handle, var_k_data: T.handle, var_v_data: T.handle, var_position_map: T.handle):
        # Scatter freshly computed K/V entries into the paged KV cache.
        # For t in [0, ntoken), h in [0, 2), f in [0, 128), with p = position_map[t]:
        #   pages[p // 16, 0, h, p % 16, f] = k_data[t, h, f]
        #   pages[p // 16, 1, h, p % 16, f] = v_data[t, h, f]
        # Pages hold 16 slots here. Tokens whose position_map entry is -1 are skipped.
        # One thread per (t, h, f) triple: ceil(ntoken * 256 / 1024) blocks of 1024 threads.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        num_pages = T.int32()
        pages = T.match_buffer(var_pages, (num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        ntoken = T.int32(is_size_var=True)
        k_data = T.match_buffer(var_k_data, (ntoken, 2, 128), "float16")
        v_data = T.match_buffer(var_v_data, (ntoken, 2, 128), "float16")
        position_map = T.match_buffer(var_position_map, (ntoken,), "int32", offset_factor=1)
        # with T.block("root"):
        for global_pos_h_f_fused_0 in T.thread_binding((ntoken * 256 + 1023) // 1024, thread="blockIdx.x"):
            for global_pos_h_f_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                # The grid is rounded up to a multiple of 1024 threads, so trailing threads map to
                # token indices >= ntoken.
                # The blocks' T.where predicates only take effect inside the position_map test
                # below. Without this guard, that test would read up to 3 entries past the end of
                # position_map.
                # This is a separate `if` rather than an `and`, so the bounds check is evaluated
                # before the load.
                if global_pos_h_f_fused_0 * 1024 + global_pos_h_f_fused_1 < ntoken * 256:
                    # -1 in position_map marks a token that must not be written to the cache.
                    if position_map[(global_pos_h_f_fused_0 * 1024 + global_pos_h_f_fused_1) // 256] != -1:
                        with T.block("k_transpose_append"):
                            vgpos = T.axis.spatial(ntoken, (global_pos_h_f_fused_0 * 1024 + global_pos_h_f_fused_1) // 256)
                            vh = T.axis.spatial(2, (global_pos_h_f_fused_0 * 1024 + global_pos_h_f_fused_1) % 256 // 128)
                            vf = T.axis.spatial(128, (global_pos_h_f_fused_0 * 1024 + global_pos_h_f_fused_1) % 128)
                            T.where(global_pos_h_f_fused_0 * 1024 + global_pos_h_f_fused_1 < ntoken * 256)
                            T.reads(position_map[vgpos], k_data[vgpos, vh, vf])
                            T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf])
                            position: T.int32 = position_map[vgpos]
                            # pages axis 1 == 0 holds K.
                            pages[position // 16, 0, vh, position % 16, vf] = k_data[vgpos, vh, vf]
                        with T.block("v_transpose_append"):
                            vgpos = T.axis.spatial(ntoken, (global_pos_h_f_fused_0 * 1024 + global_pos_h_f_fused_1) // 256)
                            vh = T.axis.spatial(2, (global_pos_h_f_fused_0 * 1024 + global_pos_h_f_fused_1) % 256 // 128)
                            vf = T.axis.spatial(128, (global_pos_h_f_fused_0 * 1024 + global_pos_h_f_fused_1) % 128)
                            T.where(global_pos_h_f_fused_0 * 1024 + global_pos_h_f_fused_1 < ntoken * 256)
                            T.reads(position_map[vgpos], v_data[vgpos, vh, vf])
                            T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf])
                            position: T.int32 = position_map[vgpos]
                            # pages axis 1 == 1 holds V.
                            pages[position // 16, 1, vh, position % 16, vf] = v_data[vgpos, vh, vf]

    @T.prim_func
    def top_p_pivot_cutoff(var_prob: T.handle, var_top_p_arr: T.handle, var_init_pivots: T.handle, var_final_pivot: T.handle, var_final_lsum: T.handle):
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        B, N = T.int32(is_size_var=True), T.int32(is_size_var=True)
        prob = T.match_buffer(var_prob, (B, N))
        top_p_arr = T.match_buffer(var_top_p_arr, (B,))
        init_pivots = T.match_buffer(var_init_pivots, (B, 3))
        final_pivot = T.match_buffer(var_final_pivot, (B,))
        final_lsum = T.match_buffer(var_final_lsum, (B,))
        # with T.block("root"):
        pivot = T.alloc_buffer((3,), scope="local")
        top_p = T.alloc_buffer((1,), scope="local")
        L = T.alloc_buffer((1,), scope="shared")
        R_1 = T.alloc_buffer((1,), scope="shared")
        L_local = T.alloc_buffer((1,), scope="local")
        R_local = T.alloc_buffer((1,), scope="local")
        q = T.alloc_buffer((1,), scope="local")
        lsum = T.alloc_buffer((3,), scope="local")
        lmin_broadcast = T.alloc_buffer((1,), scope="shared")
        lmin_broadcast_local = T.alloc_buffer((1,), scope="local")
        lmin = T.alloc_buffer((3,), scope="local")
        cmin = T.alloc_buffer((3,), "int32", scope="local")
        total_sum = T.alloc_buffer((1,), scope="local")
        it = T.alloc_buffer((1,), "int32", scope="local")
        es_local = T.alloc_buffer((1,), "bool", scope="local")
        es = T.alloc_buffer((1,), "bool", scope="shared")
        find_pivot_local = T.alloc_buffer((1,), "bool", scope="local")
        find_pivot = T.alloc_buffer((1,), "bool", scope="shared")
        total_sum_reduce = T.alloc_buffer((1,), scope="local")
        lsum_reduce = T.alloc_buffer((1,), scope="local")
        lmin_reduce = T.alloc_buffer((1,), scope="local")
        cmin_reduce = T.alloc_buffer((1,), "int32", scope="local")
        for _bx in T.thread_binding(B, thread="blockIdx.x"):
            for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("CTA"):
                    b, tx = T.axis.remap("SS", [_bx, _tx])
                    T.reads(top_p_arr[b], top_p[0], L[0], R_1[0], init_pivots[b, 0:3], L_local[0], R_local[0], find_pivot_local[0], it[0], es_local[0], prob[b, it[0] * 1024 + tx], total_sum[0], q[0], pivot[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], lsum[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], lmin[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], cmin[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], total_sum_reduce[0], es[0], lmin_reduce[0], lmin_broadcast[0], lmin_broadcast_local[0], lsum_reduce[0], cmin_reduce[0], find_pivot[0])
                    T.writes(top_p[0], L[0], R_1[0], find_pivot[0], L_local[0], R_local[0], pivot[0:3], find_pivot_local[0], final_lsum[b], final_pivot[b], lsum[0:3], lmin[0:3], cmin[0:3], total_sum[0], it[0], es_local[0], q[0], total_sum_reduce[0], es[0], lsum_reduce[0], lmin_reduce[0], lmin_broadcast[0], lmin_broadcast_local[0], cmin_reduce[0])
                    top_p[0] = top_p_arr[b]
                    if tx == 0:
                        L[0] = T.float32(1.0) - top_p[0]
                        R_1[0] = T.float32(9.9999999999999995e-08)
                        find_pivot[0] = T.bool(False)
                    T.tvm_storage_sync("shared")
                    L_local[0] = L[0]
                    R_local[0] = R_1[0]
                    for i in T.unroll(3):
                        pivot[i] = init_pivots[b, i]
                    find_pivot_local[0] = T.bool(False)
                    if L_local[0] - R_local[0] <= T.float32(9.9999999999999995e-08):
                        if tx == 0:
                            final_lsum[b] = T.float32(1.0)
                            final_pivot[b] = T.float32(0.0)
                        find_pivot_local[0] = T.bool(True)
                    while T.tvm_thread_invariant(L_local[0] - R_local[0] > T.float32(9.9999999999999995e-08) and not find_pivot_local[0]):
                        T.tvm_storage_sync("shared")
                        for pidx in T.unroll(3):
                            lsum[pidx] = T.float32(0.0)
                            lmin[pidx] = T.float32(340282346638528859811704183484516925440.0)
                            cmin[pidx] = 0
                        total_sum[0] = T.float32(0.0)
                        it[0] = 0
                        es_local[0] = T.bool(False)
                        while it[0] < (N + 1024 - 1) // 1024 and not es_local[0]:
                            q[0] = T.if_then_else(it[0] * 1024 + tx < N, prob[b, it[0] * 1024 + tx], T.float32(0.0))
                            total_sum[0] = total_sum[0] + q[0]
                            for pidx in T.unroll(3):
                                if q[0] >= pivot[pidx]:
                                    lsum[pidx] = lsum[pidx] + q[0]
                                    if lmin[pidx] > q[0]:
                                        lmin[pidx] = q[0]
                                        cmin[pidx] = 1
                                    else:
                                        if lmin[pidx] == q[0]:
                                            cmin[pidx] = cmin[pidx] + 1
                            it[0] = it[0] + 1
                            if it[0] % 32 == 0:
                                with T.block("block_cross_thread"):
                                    T.reads(total_sum[0])
                                    T.writes(total_sum_reduce[0])
                                    T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                    T.tvm_thread_allreduce(T.uint32(1), total_sum[0], T.bool(True), total_sum_reduce[0], tx)
                                if tx == 0:
                                    es[0] = T.float32(1.0) - total_sum_reduce[0] < pivot[2]
                                T.tvm_storage_sync("shared")
                                es_local[0] = es[0]
                        T.tvm_storage_sync("shared")
                        for pidx in range(3):
                            with T.block("block_cross_thread"):
                                T.reads(lsum[pidx])
                                T.writes(lsum_reduce[0])
                                T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), lsum[pidx], T.bool(True), lsum_reduce[0], tx)
                            with T.block("block_cross_thread"):
                                T.reads(lmin[pidx])
                                T.writes(lmin_reduce[0])
                                T.attr(T.comm_reducer(lambda x0, y0: T.min(x0, y0), [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), lmin[pidx], T.bool(True), lmin_reduce[0], tx)
                            if tx == 0:
                                lmin_broadcast[0] = lmin_reduce[0]
                            T.tvm_storage_sync("shared")
                            lmin_broadcast_local[0] = lmin_broadcast[0]
                            if lmin[pidx] > lmin_broadcast_local[0]:
                                cmin[pidx] = 0
                            if tx == 0:
                                lsum[pidx] = lsum_reduce[0]
                                lmin[pidx] = lmin_reduce[0]
                            with T.block("block_cross_thread"):
                                T.reads(cmin[pidx])
                                T.writes(cmin_reduce[0])
                                T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [0]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), cmin[pidx], T.bool(True), cmin_reduce[0], tx)
                            if tx == 0:
                                cmin[pidx] = cmin_reduce[0]
                        T.tvm_storage_sync("shared")
                        if tx == 0:
                            it[0] = 0
                            while it[0] < 3 and not find_pivot_local[0]:
                                if lsum[it[0]] >= top_p[0] and top_p[0] > lsum[it[0]] - T.Cast("float32", cmin[it[0]]) * lmin[it[0]]:
                                    find_pivot[0] = T.bool(True)
                                    find_pivot_local[0] = T.bool(True)
                                    final_pivot[b] = pivot[it[0]]
                                    final_lsum[b] = lsum[it[0]]
                                else:
                                    if lsum[it[0]] - lmin[it[0]] * T.Cast("float32", cmin[it[0]]) >= top_p[0]:
                                        R_1[0] = pivot[it[0]]
                                        final_lsum[b] = lsum[it[0]]
                                    else:
                                        if lsum[it[0]] < top_p[0]:
                                            L[0] = pivot[it[0]]
                                it[0] = it[0] + 1
                        T.tvm_storage_sync("shared")
                        L_local[0] = L[0]
                        R_local[0] = R_1[0]
                        find_pivot_local[0] = find_pivot[0]
                        for pidx in T.unroll(3):
                            pivot[pidx] = L[0] - T.Cast("float32", pidx + 1) * (L_local[0] - R_local[0]) / T.float32(4.0)
                    if tx == 0:
                        if not find_pivot_local[0]:
                            final_pivot[b] = R_local[0]
                            if R_local[0] == T.float32(9.9999999999999995e-08):
                                final_lsum[b] = lsum[2]

    @T.prim_func
    def top_p_renorm_after_cutoff(var_prob: T.handle, var_final_pivot: T.handle, var_final_lsum: T.handle, var_renorm_prob: T.handle):
        # Apply a per-row probability cutoff and renormalize the survivors:
        #   renorm_prob[b, j] = prob[b, j] / final_lsum[b]   if prob[b, j] >= final_pivot[b]
        #   renorm_prob[b, j] = 0.0                          otherwise
        # final_pivot / final_lsum are per-row inputs of shape (B,). Presumably they
        # are the cutoff threshold and the kept probability mass found by a
        # preceding top-p pivot-search kernel -- confirm against the caller.
        # NOTE(review): lsum is used as a divisor without a zero check; a row with
        # final_lsum[b] == 0 would produce inf/NaN for every kept entry.
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # B rows x N columns, both symbolic sizes. No dtype is given, so the
        # buffers use the default float32, matching the T.float32 literal below.
        B, N = T.int32(is_size_var=True), T.int32(is_size_var=True)
        prob = T.match_buffer(var_prob, (B, N))
        final_pivot = T.match_buffer(var_final_pivot, (B,))
        final_lsum = T.match_buffer(var_final_lsum, (B,))
        renorm_prob = T.match_buffer(var_renorm_prob, (B, N))
        # with T.block("root"):
        # One-element scratch buffers (scope="local") holding the current row's
        # cutoff and normalizer.
        pivot = T.alloc_buffer((1,), scope="local")
        lsum = T.alloc_buffer((1,), scope="local")
        # Launch layout:
        # - blockIdx.y selects the row.
        # - Each row gets 511 // B + 1 (== ceil(512 / B)) blocks along
        #   blockIdx.x, with 1024 threads per block.
        # - B * (511 // B + 1) >= 512, so at least 512 blocks are launched in
        #   total, presumably to keep the device busy when B is small.
        for _by in T.thread_binding(B, thread="blockIdx.y"):
            for _bx in T.thread_binding(511 // B + 1, thread="blockIdx.x"):
                for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("CTA"):
                        by, bx, tx = T.axis.remap("SSS", [_by, _bx, _tx])
                        # Per-thread access regions: span from the thread's first
                        # column to the last strided column the loop below can
                        # reach. Columns past N are excluded at runtime by the
                        # guard, not here.
                        T.reads(final_pivot[by], final_lsum[by], prob[by, bx * 1024 + tx:bx * 1024 + tx + (((511 // B * 1024 + N + 1023) // (511 // B * 1024 + 1024) - 1) * (511 // B + 1) * 1024 + 1)], pivot[0], lsum[0])
                        T.writes(pivot[0], lsum[0], renorm_prob[by, bx * 1024 + tx:bx * 1024 + tx + (((511 // B * 1024 + N + 1023) // (511 // B * 1024 + 1024) - 1) * (511 // B + 1) * 1024 + 1)])
                        # Stage this row's cutoff and normalizer once per thread.
                        pivot[0] = final_pivot[by]
                        lsum[0] = final_lsum[by]
                        # Strided sweep over the row:
                        # - Let S = (511 // B + 1) * 1024. Per iteration, the
                        #   row's x-blocks together cover S consecutive columns.
                        # - The trip count below is ceil(N / S).
                        # - (512 + B - 1) // B in the index equals 511 // B + 1,
                        #   the number of x-blocks per row.
                        for i in range((511 // B * 1024 + N + 1023) // (511 // B * 1024 + 1024)):
                            # Only the final iteration can step past column N - 1;
                            # skip those out-of-range columns.
                            if i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx < N:
                                # Keep entries at or above the pivot, scaled by
                                # 1 / lsum; zero out everything below it.
                                renorm_prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] = T.if_then_else(prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] >= pivot[0], prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] / lsum[0], T.float32(0.0))

    @T.prim_func
    def tree_attn_paged_kv(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32, tree_order_indptr_handle: T.handle, tree_order_handle: T.handle):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 2, 16, 128), "float16")
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (batch_size,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (total_len, 16))
        tree_order_indptr = T.match_buffer(tree_order_indptr_handle, (batch_size + 1,), "int32", offset_factor=1)
        total_tree_order_len = T.int32(is_size_var=True)
        tree_order = T.match_buffer(tree_order_handle, (total_tree_order_len, 2), "int32", offset_factor=1)
        # with T.block("root"):
        assert rotary_mode == 0, "Inline rotary mode is not supported in tree attention."
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[b_idx], 0)
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_init"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads()
                                                    T.writes(O_local[i, j])
                                                    O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_lj_fused_0 in range(8):
                                        for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_lj_fused_3 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 128)
                                                        j = T.axis.spatial(128, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 128)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        for lz_ly_fused_0 in range(8):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 128)
                                                            j = T.axis.spatial(128, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 128)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                K_smem[i, j] = pages[page_no, 0, by, page_offset, j]
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        for lz_ly_fused_0 in range(8):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 128)
                                                            j = T.axis.spatial(128, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 128)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(2, 4):
                                                        with T.block("S_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_init)
                                                            T.reads()
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, li_1, lj_1, lk_1 in T.grid(16, 2, 4, 8):
                                                        with T.block("S_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                            k = T.axis.reduce(128, lk_0 * 8 + lk_1)
                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1, lj_1 in T.grid(2, 4):
                                                    with T.block("S_store"):
                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                        j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                        T.reads(S_local[i, j])
                                                        T.writes(S_smem[i, j])
                                                        S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], tree_order_indptr[b_idx:b_idx + 2], tree_order[T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]):T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) or tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 0] and tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 1]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], tree_order_indptr[b_idx:b_idx + 2], tree_order[T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]):T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) or tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 0] and tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 1]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(4, 8):
                                                        with T.block("O_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_init)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1, li_1, lj_1 in T.grid(4, 8, 4, 8):
                                                        with T.block("O_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                            k = T.axis.reduce(32, lk_0 * 8 + lk_1)
                                                            T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j])
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_store"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                    T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    @R.function
    def _metadata() -> R.Object:
        shape_heap: R.Object = R.null_value()
        return R.str("{\"model_type\": \"qwen2\", \"quantization\": \"q4f16_1\", \"context_window_size\": 32768, \"sliding_window_size\": -1, \"attention_sink_size\": -1, \"prefill_chunk_size\": 2048, \"tensor_parallel_shards\": 1, \"pipeline_parallel_stages\": 1, \"disaggregation\": false, \"kv_state_kind\": \"kv_cache\", \"max_batch_size\": 80, \"params\": [{\"name\": \"model.embed_tokens.q_weight\", \"shape\": [151936, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.embed_tokens.q_scale\", \"shape\": [151936, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.0.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.0.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.0.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.0.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.0.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.0.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.0.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.0.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.0.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": 
\"model.layers.0.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.0.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.1.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.1.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.1.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.1.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.1.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.1.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.1.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.1.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.1.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.1.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.1.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.2.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": 
\"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.2.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.2.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.2.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.2.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.2.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.2.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.2.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.2.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.2.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.2.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.3.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.3.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.3.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": 
\"model.layers.3.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.3.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.3.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.3.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.3.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.3.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.3.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.3.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.4.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.4.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.4.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.4.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.4.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.4.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": 
\"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.4.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.4.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.4.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.4.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.4.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.5.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.5.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.5.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.5.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.5.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.5.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.5.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.5.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": 
\"model.layers.5.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.5.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.5.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.6.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.6.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.6.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.6.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.6.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.6.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.6.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.6.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.6.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.6.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.6.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": 
\"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.7.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.7.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.7.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.7.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.7.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.7.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.7.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.7.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.7.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.7.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.7.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.8.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.8.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": 
\"model.layers.8.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.8.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.8.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.8.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.8.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.8.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.8.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.8.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.8.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.9.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.9.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.9.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.9.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.9.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": 
\"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.9.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.9.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.9.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.9.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.9.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.9.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.10.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.10.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.10.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.10.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.10.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.10.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.10.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": 
\"model.layers.10.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.10.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.10.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.10.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.11.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.11.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.11.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.11.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.11.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.11.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.11.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.11.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.11.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.11.input_layernorm.weight\", \"shape\": [2048], \"dtype\": 
\"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.11.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.12.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.12.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.12.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.12.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.12.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.12.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.12.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.12.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.12.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.12.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.12.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.13.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, 
{\"name\": \"model.layers.13.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.13.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.13.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.13.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.13.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.13.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.13.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.13.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.13.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.13.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.14.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.14.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.14.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.14.self_attn.o_proj.q_weight\", \"shape\": [2048, 
256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.14.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.14.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.14.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.14.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.14.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.14.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.14.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.15.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.15.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.15.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.15.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.15.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.15.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": 
[0]}, {\"name\": \"model.layers.15.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.15.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.15.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.15.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.15.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.16.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.16.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.16.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.16.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.16.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.16.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.16.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.16.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.16.mlp.down_proj.q_scale\", \"shape\": 
[2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.16.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.16.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.17.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.17.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.17.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.17.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.17.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.17.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.17.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.17.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.17.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.17.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.17.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], 
\"pipeline_stages\": [0]}, {\"name\": \"model.layers.18.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.18.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.18.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.18.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.18.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.18.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.18.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.18.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.18.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.18.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.18.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.19.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.19.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": 
\"model.layers.19.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.19.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.19.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.19.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.19.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.19.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.19.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.19.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.19.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.20.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.20.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.20.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.20.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.20.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], 
\"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.20.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.20.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.20.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.20.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.20.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.20.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.21.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.21.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.21.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.21.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.21.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.21.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.21.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": 
[0]}, {\"name\": \"model.layers.21.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.21.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.21.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.21.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.22.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.22.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.22.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.22.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.22.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.22.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.22.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.22.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.22.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.22.input_layernorm.weight\", \"shape\": 
[2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.22.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.23.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.23.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.23.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.23.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.23.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.23.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.23.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.23.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.23.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.23.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.23.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.24.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], 
\"pipeline_stages\": [0]}, {\"name\": \"model.layers.24.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.24.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.24.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.24.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.24.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.24.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.24.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.24.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.24.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.24.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.25.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.25.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.25.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": 
\"model.layers.25.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.25.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.25.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.25.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.25.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.25.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.25.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.25.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.26.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.26.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.26.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.26.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.26.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.26.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 
256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.26.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.26.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.26.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.26.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.26.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.27.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.27.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.27.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.27.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.27.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.27.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.27.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.27.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": 
[0]}, {\"name\": \"model.layers.27.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.27.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.27.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.28.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.28.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.28.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.28.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.28.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.28.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.28.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.28.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.28.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.28.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.28.post_attention_layernorm.weight\", 
\"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.29.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.29.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.29.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.29.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.29.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.29.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.29.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.29.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.29.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.29.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.29.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.30.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.30.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], 
\"pipeline_stages\": [0]}, {\"name\": \"model.layers.30.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.30.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.30.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.30.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.30.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.30.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.30.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.30.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.30.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.31.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.31.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.31.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.31.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": 
\"model.layers.31.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.31.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.31.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.31.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.31.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.31.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.31.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.32.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.32.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.32.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.32.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.32.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.32.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.32.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 
64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.32.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.32.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.32.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.32.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.33.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.33.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.33.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.33.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.33.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.33.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.33.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.33.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.33.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, 
{\"name\": \"model.layers.33.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.33.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.34.self_attn.c_attn.q_weight\", \"shape\": [2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.34.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.34.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.34.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.34.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.34.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.34.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.34.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.34.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.34.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.34.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.35.self_attn.c_attn.q_weight\", \"shape\": 
[2560, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.35.self_attn.c_attn.q_scale\", \"shape\": [2560, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.35.self_attn.c_attn.bias\", \"shape\": [2560], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.35.self_attn.o_proj.q_weight\", \"shape\": [2048, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.35.self_attn.o_proj.q_scale\", \"shape\": [2048, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.35.mlp.gate_up_proj.q_weight\", \"shape\": [22016, 256], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.35.mlp.gate_up_proj.q_scale\", \"shape\": [22016, 64], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.35.mlp.down_proj.q_weight\", \"shape\": [2048, 1376], \"dtype\": \"uint32\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.35.mlp.down_proj.q_scale\", \"shape\": [2048, 344], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.35.input_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.layers.35.post_attention_layernorm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}, {\"name\": \"model.norm.weight\", \"shape\": [2048], \"dtype\": \"float16\", \"preprocs\": [], \"pipeline_stages\": [0]}], \"kv_cache\": {\"num_hidden_layers\": 36, \"num_attention_heads\": 16, \"num_key_value_heads\": 2, \"head_dim\": 128}, \"memory_usage\": {\"alloc_embedding_tensor\": 8388608, \"argsort_probs\": 0, \"batch_decode\": 57753600, \"batch_prefill\": 282779648, \"batch_verify\": 1478492160, 
\"create_tir_paged_kv_cache\": 0, \"decode\": 721920, \"embed\": 8388608, \"multinomial_from_uniform\": 320, \"prefill\": 234444288, \"renormalize_by_top_p\": 640, \"sample_with_top_p\": 640, \"sampler_take_probs\": 4160, \"sampler_verify_draft_tokens\": 0, \"softmax_with_temperature\": 0}}")

    @R.function
    def alloc_embedding_tensor() -> R.Tensor((2048, 2048), dtype="float16"):
        # Allocate and return a fixed-shape (2048, 2048) float16 tensor.
        # Nothing in this function writes to the tensor, so its contents are
        # whatever the allocator returns; callers presumably fill it
        # (NOTE(review): confirm at the call site).
        # "relax.force_pure" presumably lets Relax treat this function as pure
        # even though it calls the VM allocation builtins below.
        R.func_attr({"relax.force_pure": True})
        # Every shape here is a compile-time constant, so no shape heap is
        # needed; this binding is never used.
        shape_heap: R.Object = R.null_value()
        # 8388608 bytes == 2048 * 2048 * 2 bytes (float16 element size).
        # This matches the "alloc_embedding_tensor" entry in the module's
        # memory_usage metadata.
        storage: R.Object = R.vm.alloc_storage(R.shape([8388608]), R.prim_value(0), R.dtype("uint8"), R.str("global"))
        # View the whole storage (byte offset 0) as the (2048, 2048) float16 result.
        gv: R.Tensor((2048, 2048), dtype="float16") = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape([2048, 2048]), R.dtype("float16"))
        # Drop the local storage handle. The returned tensor presumably keeps
        # the underlying storage alive (NOTE(review): confirm VM ref semantics).
        R.vm.kill_object(storage)
        return gv

    @R.function
    def argsort_probs(probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32")) -> R.Tuple(R.Tensor(("batch_size", "vocab_size"), dtype="float32"), R.Tensor(("batch_size", "vocab_size"), dtype="int32")):
        # Return a (float32, int32) pair of (batch_size, vocab_size) tensors:
        # element 0 is filled by take_sorted_probs and element 1 by argsort1.
        # Judging by the callee names, these are presumably the per-row sorted
        # probabilities and their sort indices. The sort direction is not
        # visible here (NOTE(review): confirm in argsort1).
        batch_size = T.int64(is_size_var=True)
        vocab_size = T.int64(is_size_var=True)
        # Compile-time hints: vocab_size is non-negative and batch_size is
        # bounded by 80. The other bound entries are not referenced in this
        # function.
        R.func_attr({"relax.force_pure": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "num_positions": 480, "num_samples": 80}})
        cls = Module
        # 4-slot int64 shape heap for symbolic sizes. As used below:
        #   slots 0/1 are read back as the (batch_size, vocab_size) dims;
        #   slots 2/3 are read back as storage byte sizes.
        shape_heap: R.Tensor(dtype="int64", ndim=1) = R.call_builtin_with_ctx("vm.builtin.alloc_shape_heap", (R.prim_value(4),), sinfo_args=(R.Tensor(dtype="int64", ndim=1),))
        # Runtime check that probs is a 2-D float32 tensor.
        R.call_packed("vm.builtin.check_tensor_info", probs, R.prim_value(2), R.dtype("float32"), R.str("ErrorContext(fn=argsort_probs, loc=param[0], param=probs, annotation=R.Tensor((batch_size, vocab_size), dtype=\"float32\")) "), sinfo_args=(R.Tuple,))
        # ndim=2, followed by (code, heap-index) pairs (1, 0) and (1, 1).
        # These presumably store probs.shape[0] and probs.shape[1] into
        # shape_heap[0] and shape_heap[1]. That reading is consistent with the
        # code-3 checks against the same slots at the return
        # (NOTE(review): confirm against the VM match_shape builtin codes).
        R.call_packed("vm.builtin.match_shape", probs, shape_heap, R.prim_value(2), R.prim_value(1), R.prim_value(0), R.prim_value(1), R.prim_value(1), R.str("ErrorContext(fn=argsort_probs, loc=param[0], param=probs, annotation=R.Tensor((batch_size, vocab_size), dtype=\"float32\")) "), sinfo_args=(R.Tuple,))
        # Presumably derives the storage byte sizes read below from
        # shape_heap[2] and shape_heap[3]. shape_func3 is not visible here.
        cls.shape_func3(shape_heap)
        # Output/index buffer: int32 (batch_size, vocab_size); storage size from shape_heap[2].
        gv1051: R.Shape(ndim=1) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(1), R.prim_value(1), R.prim_value(2), sinfo_args=(R.Shape(ndim=1),))
        storage30: R.Object = R.vm.alloc_storage(gv1051, R.prim_value(0), R.dtype("uint8"), R.str("global"))
        gv1052: R.Shape(ndim=2) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(2), R.prim_value(1), R.prim_value(0), R.prim_value(1), R.prim_value(1), sinfo_args=(R.Shape(ndim=2),))
        alloc1104: R.Tensor(dtype="int32", ndim=2) = R.vm.alloc_tensor(storage30, R.prim_value(0), gv1052, R.dtype("int32"))
        R.vm.kill_object(storage30)
        # The next three tensors take their storage size from shape_heap[3],
        # not [2]. They are passed to argsort1 and then killed without being
        # read again here, so they look like scratch space (TODO confirm).
        gv1053: R.Shape(ndim=1) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(1), R.prim_value(1), R.prim_value(3), sinfo_args=(R.Shape(ndim=1),))
        storage31: R.Object = R.vm.alloc_storage(gv1053, R.prim_value(0), R.dtype("uint8"), R.str("global"))
        gv1054: R.Shape(ndim=2) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(2), R.prim_value(1), R.prim_value(0), R.prim_value(1), R.prim_value(1), sinfo_args=(R.Shape(ndim=2),))
        alloc1105: R.Tensor(dtype="float32", ndim=2) = R.vm.alloc_tensor(storage31, R.prim_value(0), gv1054, R.dtype("float32"))
        R.vm.kill_object(storage31)
        gv1055: R.Shape(ndim=1) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(1), R.prim_value(1), R.prim_value(3), sinfo_args=(R.Shape(ndim=1),))
        storage32: R.Object = R.vm.alloc_storage(gv1055, R.prim_value(0), R.dtype("uint8"), R.str("global"))
        gv1056: R.Shape(ndim=2) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(2), R.prim_value(1), R.prim_value(0), R.prim_value(1), R.prim_value(1), sinfo_args=(R.Shape(ndim=2),))
        alloc1106: R.Tensor(dtype="float32", ndim=2) = R.vm.alloc_tensor(storage32, R.prim_value(0), gv1056, R.dtype("float32"))
        R.vm.kill_object(storage32)
        gv1057: R.Shape(ndim=1) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(1), R.prim_value(1), R.prim_value(3), sinfo_args=(R.Shape(ndim=1),))
        storage33: R.Object = R.vm.alloc_storage(gv1057, R.prim_value(0), R.dtype("uint8"), R.str("global"))
        gv1058: R.Shape(ndim=2) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(2), R.prim_value(1), R.prim_value(0), R.prim_value(1), R.prim_value(1), sinfo_args=(R.Shape(ndim=2),))
        alloc1107: R.Tensor(dtype="int32", ndim=2) = R.vm.alloc_tensor(storage33, R.prim_value(0), gv1058, R.dtype("int32"))
        R.vm.kill_object(storage33)
        # Destination-passing call: argsort1 writes into the four
        # preallocated buffers. Only alloc1104 is used afterwards.
        cls.argsort1(probs, alloc1104, alloc1105, alloc1106, alloc1107)
        # Tuple binding of argsort1's outputs. It is not referenced later in
        # this function.
        lv: R.Tuple(R.Tensor(dtype="int32", ndim=2), R.Tensor(dtype="float32", ndim=2), R.Tensor(dtype="float32", ndim=2), R.Tensor(dtype="int32", ndim=2)) = alloc1104, alloc1105, alloc1106, alloc1107
        # Release the local handles to the three buffers that are no longer needed.
        R.vm.kill_object(alloc1105)
        R.vm.kill_object(alloc1106)
        R.vm.kill_object(alloc1107)
        # Result buffer: float32 (batch_size, vocab_size); storage size from shape_heap[2].
        gv1059: R.Shape(ndim=1) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(1), R.prim_value(1), R.prim_value(2), sinfo_args=(R.Shape(ndim=1),))
        storage34: R.Object = R.vm.alloc_storage(gv1059, R.prim_value(0), R.dtype("uint8"), R.str("global"))
        gv1060: R.Shape(ndim=2) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(2), R.prim_value(1), R.prim_value(0), R.prim_value(1), R.prim_value(1), sinfo_args=(R.Shape(ndim=2),))
        alloc1108: R.Tensor(dtype="float32", ndim=2) = R.vm.alloc_tensor(storage34, R.prim_value(0), gv1060, R.dtype("float32"))
        R.vm.kill_object(storage34)
        # Presumably gathers probs according to the indices in alloc1104 and
        # writes the result into alloc1108 (inferred from the name; callee
        # not visible).
        cls.take_sorted_probs(probs, alloc1104, alloc1108)
        # Result tuple: (float32 buffer, int32 index buffer), matching the
        # declared return type.
        gv1: R.Tuple(R.Tensor(dtype="float32", ndim=2), R.Tensor(dtype="int32", ndim=2)) = alloc1108, alloc1104
        # Drop the local handles. The tuple presumably keeps both tensors alive
        # (NOTE(review): confirm VM ref semantics).
        R.vm.kill_object(alloc1104)
        R.vm.kill_object(alloc1108)
        # Return-value checks: code 3 against heap slots 0 and 1. These
        # presumably assert that each output's dims equal the
        # (batch_size, vocab_size) recorded from probs on entry.
        gv1061: R.Tensor(dtype="float32", ndim=2) = gv1[0]
        R.call_packed("vm.builtin.match_shape", gv1061, shape_heap, R.prim_value(2), R.prim_value(3), R.prim_value(0), R.prim_value(3), R.prim_value(1), R.str("ErrorContext(fn=argsort_probs, loc=return, annotation=R.Tuple(R.Tensor((batch_size, vocab_size), dtype=\"float32\"), R.Tensor((batch_size, vocab_size), dtype=\"int32\"))) "), sinfo_args=(R.Tuple,))
        gv1062: R.Tensor(dtype="int32", ndim=2) = gv1[1]
        R.call_packed("vm.builtin.match_shape", gv1062, shape_heap, R.prim_value(2), R.prim_value(3), R.prim_value(0), R.prim_value(3), R.prim_value(1), R.str("ErrorContext(fn=argsort_probs, loc=return, annotation=R.Tuple(R.Tensor((batch_size, vocab_size), dtype=\"float32\"), R.Tensor((batch_size, vocab_size), dtype=\"int32\"))) "), sinfo_args=(R.Tuple,))
        return gv1

    @R.function
    def batch_decode(input_embeds: R.Tensor(("batch_size", 1, 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), 
R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 
64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), 
R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), 
R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), 
dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"))) -> R.Tuple(R.Tensor(("batch_size", 1, 151936), dtype="float32"), R.Object):
        batch_size = T.int64()
        R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.force_pure": True, "relax.rewrite_cuda_graph.capture_symbolic_vars": ["batch_size"], "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 32768}})
        cls = Module
        shape_heap: R.Tensor(dtype="int64", ndim=1) = R.call_builtin_with_ctx("vm.builtin.alloc_shape_heap", (R.prim_value(2),), sinfo_args=(R.Tensor(dtype="int64", ndim=1),))
        R.call_packed("vm.builtin.check_tensor_info", input_embeds, R.prim_value(3), R.dtype("float16"), R.str("ErrorContext(fn=batch_decode, loc=param[0], param=input_embeds, annotation=R.Tensor((batch_size, 1, 2048), dtype=\"float16\")) "), sinfo_args=(R.Tuple,))
        R.call_packed("vm.builtin.check_tuple_info", packed_params, R.prim_value(399), R.str("ErrorContext(fn=batch_decode, loc=param[2], param=packed_params, annotation=R.Tuple(R.Tensor((151936, 256), dtype=\"uint32\"), R.Tensor((151936, 64), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), 
dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), 
dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), 
dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), 
R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), 
R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), 
R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), 
R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 256), dtype=\"uint32\"), 
R.Tensor((2560, 64), dtype=\"float16\"), R.Tensor((2560,), dtype=\"float16\"), R.Tensor((2048, 256), dtype=\"uint32\"), R.Tensor((2048, 64), dtype=\"float16\"), R.Tensor((22016, 256), dtype=\"uint32\"), R.Tensor((22016, 64), dtype=\"float16\"), R.Tensor((2048, 1376), dtype=\"uint32\"), R.Tensor((2048, 344), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"))) "), sinfo_args=(R.Tuple,))
        R.call_packed("vm.builtin.match_shape", input_embeds, shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), R.str("ErrorContext(fn=batch_decode, loc=param[0], param=input_embeds, annotation=R.Tensor((batch_size, 1, 2048), dtype=\"float16\")) "), sinfo_args=(R.Tuple,))
        model_layers_0_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[11]
        storage35: R.Object = R.vm.alloc_storage(R.shape([3522560]), R.prim_value(0), R.dtype("uint8"), R.str("global"))
        gv1063: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1109: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1063, R.dtype("float16"))
        cls.rms_norm(input_embeds, model_layers_0_input_layernorm_weight4, alloc1109)
        R.vm.kill_object(model_layers_0_input_layernorm_weight4)
        model_layers_0_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[2]
        model_layers_0_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[3]
        model_layers_0_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[4]
        storage36: R.Object = R.vm.alloc_storage(R.shape([3522560]), R.prim_value(0), R.dtype("uint8"), R.str("global"))
        gv1064: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2560), sinfo_args=(R.Shape(ndim=3),))
        alloc1110: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1064, R.dtype("float16"))
        cls.fused_dequantize1_fused_NT_matmul_add(model_layers_0_self_attn_c_attn_q_weight4, model_layers_0_self_attn_c_attn_q_scale4, alloc1109, model_layers_0_self_attn_c_attn_bias4, alloc1110)
        R.vm.kill_object(alloc1109)
        R.vm.kill_object(model_layers_0_self_attn_c_attn_q_weight4)
        R.vm.kill_object(model_layers_0_self_attn_c_attn_q_scale4)
        R.vm.kill_object(model_layers_0_self_attn_c_attn_bias4)
        gv1065: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape432: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1110, gv1065, sinfo_args=(R.Tensor((batch_size, 1, 20, 128), dtype="float16"),))
        R.vm.kill_object(alloc1110)
        gv1066: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        reshape433: R.Tensor((batch_size, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape432, gv1066, sinfo_args=(R.Tensor((batch_size, 20, 128), dtype="float16"),))
        R.vm.kill_object(reshape432)
        gv1067: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        alloc1111: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1067, R.dtype("float16"))
        _888: R.Object = R.call_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape433, alloc1111)
        R.vm.kill_object(reshape433)
        gv1068: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape434: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1111, gv1068, sinfo_args=(R.Tensor((batch_size, 1, 16, 128), dtype="float16"),))
        R.vm.kill_object(alloc1111)
        gv1069: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        reshape435: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape434, gv1069, sinfo_args=(R.Tensor((batch_size, 1, 2048), dtype="float16"),))
        R.vm.kill_object(reshape434)
        model_layers_0_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[5]
        model_layers_0_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[6]
        gv1070: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1112: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1070, R.dtype("float16"))
        cls.fused_dequantize2_NT_matmul1(model_layers_0_self_attn_o_proj_q_weight4, model_layers_0_self_attn_o_proj_q_scale4, reshape435, alloc1112)
        R.vm.kill_object(reshape435)
        R.vm.kill_object(model_layers_0_self_attn_o_proj_q_weight4)
        R.vm.kill_object(model_layers_0_self_attn_o_proj_q_scale4)
        model_layers_0_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[12]
        gv1071: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1113: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1071, R.dtype("float16"))
        storage37: R.Object = R.vm.alloc_storage(R.shape([1761280]), R.prim_value(0), R.dtype("uint8"), R.str("global"))
        gv1072: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1114: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1072, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1112, input_embeds, model_layers_0_post_attention_layernorm_weight4, alloc1113, alloc1114)
        R.vm.kill_object(alloc1112)
        R.vm.kill_object(model_layers_0_post_attention_layernorm_weight4)
        lv: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1113, alloc1114
        model_layers_0_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[7]
        model_layers_0_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[8]
        gv1073: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(22016), sinfo_args=(R.Shape(ndim=3),))
        alloc1115: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1073, R.dtype("float16"))
        cls.fused_dequantize3_NT_matmul2(model_layers_0_mlp_gate_up_proj_q_weight4, model_layers_0_mlp_gate_up_proj_q_scale4, alloc1113, alloc1115)
        R.vm.kill_object(alloc1113)
        R.vm.kill_object(model_layers_0_mlp_gate_up_proj_q_weight4)
        R.vm.kill_object(model_layers_0_mlp_gate_up_proj_q_scale4)
        gv1074: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(11008), sinfo_args=(R.Shape(ndim=3),))
        alloc1116: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1074, R.dtype("float16"))
        cls.fused_split_silu_multiply(alloc1115, alloc1116)
        R.vm.kill_object(alloc1115)
        model_layers_0_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[9]
        model_layers_0_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[10]
        gv1075: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1117: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1075, R.dtype("float16"))
        cls.fused_dequantize4_NT_matmul3(model_layers_0_mlp_down_proj_q_weight4, model_layers_0_mlp_down_proj_q_scale4, alloc1116, alloc1117)
        R.vm.kill_object(alloc1116)
        R.vm.kill_object(model_layers_0_mlp_down_proj_q_weight4)
        R.vm.kill_object(model_layers_0_mlp_down_proj_q_scale4)
        model_layers_1_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[22]
        gv1076: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1118: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1076, R.dtype("float16"))
        storage38: R.Object = R.vm.alloc_storage(R.shape([327680]), R.prim_value(0), R.dtype("uint8"), R.str("global"))
        gv1077: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1119: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1077, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1117, alloc1114, model_layers_1_input_layernorm_weight4, alloc1118, alloc1119)
        R.vm.kill_object(alloc1114)
        R.vm.kill_object(alloc1117)
        R.vm.kill_object(model_layers_1_input_layernorm_weight4)
        lv2: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1118, alloc1119
        model_layers_1_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[13]
        model_layers_1_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[14]
        model_layers_1_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[15]
        gv1078: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2560), sinfo_args=(R.Shape(ndim=3),))
        alloc1120: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1078, R.dtype("float16"))
        cls.fused_dequantize1_fused_NT_matmul_add(model_layers_1_self_attn_c_attn_q_weight4, model_layers_1_self_attn_c_attn_q_scale4, alloc1118, model_layers_1_self_attn_c_attn_bias4, alloc1120)
        R.vm.kill_object(alloc1118)
        R.vm.kill_object(model_layers_1_self_attn_c_attn_q_weight4)
        R.vm.kill_object(model_layers_1_self_attn_c_attn_q_scale4)
        R.vm.kill_object(model_layers_1_self_attn_c_attn_bias4)
        gv1079: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape436: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1120, gv1079, sinfo_args=(R.Tensor((batch_size, 1, 20, 128), dtype="float16"),))
        R.vm.kill_object(alloc1120)
        gv1080: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        reshape437: R.Tensor((batch_size, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape436, gv1080, sinfo_args=(R.Tensor((batch_size, 20, 128), dtype="float16"),))
        R.vm.kill_object(reshape436)
        gv1081: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        alloc1121: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1081, R.dtype("float16"))
        _896: R.Object = R.call_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape437, alloc1121)
        R.vm.kill_object(reshape437)
        gv1082: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape438: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1121, gv1082, sinfo_args=(R.Tensor((batch_size, 1, 16, 128), dtype="float16"),))
        R.vm.kill_object(alloc1121)
        gv1083: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        reshape439: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape438, gv1083, sinfo_args=(R.Tensor((batch_size, 1, 2048), dtype="float16"),))
        R.vm.kill_object(reshape438)
        model_layers_1_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[16]
        model_layers_1_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[17]
        gv1084: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1122: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1084, R.dtype("float16"))
        cls.fused_dequantize2_NT_matmul1(model_layers_1_self_attn_o_proj_q_weight4, model_layers_1_self_attn_o_proj_q_scale4, reshape439, alloc1122)
        R.vm.kill_object(reshape439)
        R.vm.kill_object(model_layers_1_self_attn_o_proj_q_weight4)
        R.vm.kill_object(model_layers_1_self_attn_o_proj_q_scale4)
        model_layers_1_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[23]
        gv1085: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1123: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1085, R.dtype("float16"))
        gv1086: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1124: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1086, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1122, alloc1119, model_layers_1_post_attention_layernorm_weight4, alloc1123, alloc1124)
        R.vm.kill_object(alloc1119)
        R.vm.kill_object(alloc1122)
        R.vm.kill_object(model_layers_1_post_attention_layernorm_weight4)
        lv4: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1123, alloc1124
        model_layers_1_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[18]
        model_layers_1_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[19]
        gv1087: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(22016), sinfo_args=(R.Shape(ndim=3),))
        alloc1125: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1087, R.dtype("float16"))
        cls.fused_dequantize3_NT_matmul2(model_layers_1_mlp_gate_up_proj_q_weight4, model_layers_1_mlp_gate_up_proj_q_scale4, alloc1123, alloc1125)
        R.vm.kill_object(alloc1123)
        R.vm.kill_object(model_layers_1_mlp_gate_up_proj_q_weight4)
        R.vm.kill_object(model_layers_1_mlp_gate_up_proj_q_scale4)
        gv1088: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(11008), sinfo_args=(R.Shape(ndim=3),))
        alloc1126: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1088, R.dtype("float16"))
        cls.fused_split_silu_multiply(alloc1125, alloc1126)
        R.vm.kill_object(alloc1125)
        model_layers_1_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[20]
        model_layers_1_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[21]
        gv1089: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1127: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1089, R.dtype("float16"))
        cls.fused_dequantize4_NT_matmul3(model_layers_1_mlp_down_proj_q_weight4, model_layers_1_mlp_down_proj_q_scale4, alloc1126, alloc1127)
        R.vm.kill_object(alloc1126)
        R.vm.kill_object(model_layers_1_mlp_down_proj_q_weight4)
        R.vm.kill_object(model_layers_1_mlp_down_proj_q_scale4)
        model_layers_2_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[33]
        gv1090: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1128: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1090, R.dtype("float16"))
        gv1091: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1129: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1091, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1127, alloc1124, model_layers_2_input_layernorm_weight4, alloc1128, alloc1129)
        R.vm.kill_object(alloc1124)
        R.vm.kill_object(alloc1127)
        R.vm.kill_object(model_layers_2_input_layernorm_weight4)
        lv6: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1128, alloc1129
        model_layers_2_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[24]
        model_layers_2_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[25]
        model_layers_2_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[26]
        gv1092: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2560), sinfo_args=(R.Shape(ndim=3),))
        alloc1130: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1092, R.dtype("float16"))
        cls.fused_dequantize1_fused_NT_matmul_add(model_layers_2_self_attn_c_attn_q_weight4, model_layers_2_self_attn_c_attn_q_scale4, alloc1128, model_layers_2_self_attn_c_attn_bias4, alloc1130)
        R.vm.kill_object(alloc1128)
        R.vm.kill_object(model_layers_2_self_attn_c_attn_q_weight4)
        R.vm.kill_object(model_layers_2_self_attn_c_attn_q_scale4)
        R.vm.kill_object(model_layers_2_self_attn_c_attn_bias4)
        gv1093: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape440: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1130, gv1093, sinfo_args=(R.Tensor((batch_size, 1, 20, 128), dtype="float16"),))
        R.vm.kill_object(alloc1130)
        gv1094: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        reshape441: R.Tensor((batch_size, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape440, gv1094, sinfo_args=(R.Tensor((batch_size, 20, 128), dtype="float16"),))
        R.vm.kill_object(reshape440)
        gv1095: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        alloc1131: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1095, R.dtype("float16"))
        _904: R.Object = R.call_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape441, alloc1131)
        R.vm.kill_object(reshape441)
        gv1096: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape442: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1131, gv1096, sinfo_args=(R.Tensor((batch_size, 1, 16, 128), dtype="float16"),))
        R.vm.kill_object(alloc1131)
        gv1097: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        reshape443: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape442, gv1097, sinfo_args=(R.Tensor((batch_size, 1, 2048), dtype="float16"),))
        R.vm.kill_object(reshape442)
        model_layers_2_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[27]
        model_layers_2_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[28]
        gv1098: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1132: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1098, R.dtype("float16"))
        cls.fused_dequantize2_NT_matmul1(model_layers_2_self_attn_o_proj_q_weight4, model_layers_2_self_attn_o_proj_q_scale4, reshape443, alloc1132)
        R.vm.kill_object(reshape443)
        R.vm.kill_object(model_layers_2_self_attn_o_proj_q_weight4)
        R.vm.kill_object(model_layers_2_self_attn_o_proj_q_scale4)
        model_layers_2_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[34]
        gv1099: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1133: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1099, R.dtype("float16"))
        gv1100: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1134: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1100, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1132, alloc1129, model_layers_2_post_attention_layernorm_weight4, alloc1133, alloc1134)
        R.vm.kill_object(alloc1129)
        R.vm.kill_object(alloc1132)
        R.vm.kill_object(model_layers_2_post_attention_layernorm_weight4)
        lv8: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1133, alloc1134
        model_layers_2_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[29]
        model_layers_2_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[30]
        gv1101: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(22016), sinfo_args=(R.Shape(ndim=3),))
        alloc1135: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1101, R.dtype("float16"))
        cls.fused_dequantize3_NT_matmul2(model_layers_2_mlp_gate_up_proj_q_weight4, model_layers_2_mlp_gate_up_proj_q_scale4, alloc1133, alloc1135)
        R.vm.kill_object(alloc1133)
        R.vm.kill_object(model_layers_2_mlp_gate_up_proj_q_weight4)
        R.vm.kill_object(model_layers_2_mlp_gate_up_proj_q_scale4)
        gv1102: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(11008), sinfo_args=(R.Shape(ndim=3),))
        alloc1136: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1102, R.dtype("float16"))
        cls.fused_split_silu_multiply(alloc1135, alloc1136)
        R.vm.kill_object(alloc1135)
        model_layers_2_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[31]
        model_layers_2_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[32]
        gv1103: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1137: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1103, R.dtype("float16"))
        cls.fused_dequantize4_NT_matmul3(model_layers_2_mlp_down_proj_q_weight4, model_layers_2_mlp_down_proj_q_scale4, alloc1136, alloc1137)
        R.vm.kill_object(alloc1136)
        R.vm.kill_object(model_layers_2_mlp_down_proj_q_weight4)
        R.vm.kill_object(model_layers_2_mlp_down_proj_q_scale4)
        model_layers_3_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[44]
        gv1104: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1138: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1104, R.dtype("float16"))
        gv1105: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1139: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1105, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1137, alloc1134, model_layers_3_input_layernorm_weight4, alloc1138, alloc1139)
        R.vm.kill_object(alloc1134)
        R.vm.kill_object(alloc1137)
        R.vm.kill_object(model_layers_3_input_layernorm_weight4)
        lv10: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1138, alloc1139
        model_layers_3_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[35]
        model_layers_3_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[36]
        model_layers_3_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[37]
        gv1106: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2560), sinfo_args=(R.Shape(ndim=3),))
        alloc1140: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1106, R.dtype("float16"))
        cls.fused_dequantize1_fused_NT_matmul_add(model_layers_3_self_attn_c_attn_q_weight4, model_layers_3_self_attn_c_attn_q_scale4, alloc1138, model_layers_3_self_attn_c_attn_bias4, alloc1140)
        R.vm.kill_object(alloc1138)
        R.vm.kill_object(model_layers_3_self_attn_c_attn_q_weight4)
        R.vm.kill_object(model_layers_3_self_attn_c_attn_q_scale4)
        R.vm.kill_object(model_layers_3_self_attn_c_attn_bias4)
        gv1107: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape444: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1140, gv1107, sinfo_args=(R.Tensor((batch_size, 1, 20, 128), dtype="float16"),))
        R.vm.kill_object(alloc1140)
        gv1108: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        reshape445: R.Tensor((batch_size, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape444, gv1108, sinfo_args=(R.Tensor((batch_size, 20, 128), dtype="float16"),))
        R.vm.kill_object(reshape444)
        gv1109: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        alloc1141: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1109, R.dtype("float16"))
        _912: R.Object = R.call_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape445, alloc1141)
        R.vm.kill_object(reshape445)
        gv1110: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape446: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1141, gv1110, sinfo_args=(R.Tensor((batch_size, 1, 16, 128), dtype="float16"),))
        R.vm.kill_object(alloc1141)
        gv1111: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        reshape447: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape446, gv1111, sinfo_args=(R.Tensor((batch_size, 1, 2048), dtype="float16"),))
        R.vm.kill_object(reshape446)
        model_layers_3_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[38]
        model_layers_3_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[39]
        gv1112: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1142: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1112, R.dtype("float16"))
        cls.fused_dequantize2_NT_matmul1(model_layers_3_self_attn_o_proj_q_weight4, model_layers_3_self_attn_o_proj_q_scale4, reshape447, alloc1142)
        R.vm.kill_object(reshape447)
        R.vm.kill_object(model_layers_3_self_attn_o_proj_q_weight4)
        R.vm.kill_object(model_layers_3_self_attn_o_proj_q_scale4)
        model_layers_3_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[45]
        gv1113: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1143: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1113, R.dtype("float16"))
        gv1114: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1144: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1114, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1142, alloc1139, model_layers_3_post_attention_layernorm_weight4, alloc1143, alloc1144)
        R.vm.kill_object(alloc1139)
        R.vm.kill_object(alloc1142)
        R.vm.kill_object(model_layers_3_post_attention_layernorm_weight4)
        lv12: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1143, alloc1144
        model_layers_3_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[40]
        model_layers_3_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[41]
        gv1115: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(22016), sinfo_args=(R.Shape(ndim=3),))
        alloc1145: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1115, R.dtype("float16"))
        cls.fused_dequantize3_NT_matmul2(model_layers_3_mlp_gate_up_proj_q_weight4, model_layers_3_mlp_gate_up_proj_q_scale4, alloc1143, alloc1145)
        R.vm.kill_object(alloc1143)
        R.vm.kill_object(model_layers_3_mlp_gate_up_proj_q_weight4)
        R.vm.kill_object(model_layers_3_mlp_gate_up_proj_q_scale4)
        gv1116: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(11008), sinfo_args=(R.Shape(ndim=3),))
        alloc1146: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1116, R.dtype("float16"))
        cls.fused_split_silu_multiply(alloc1145, alloc1146)
        R.vm.kill_object(alloc1145)
        model_layers_3_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[42]
        model_layers_3_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[43]
        gv1117: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1147: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1117, R.dtype("float16"))
        cls.fused_dequantize4_NT_matmul3(model_layers_3_mlp_down_proj_q_weight4, model_layers_3_mlp_down_proj_q_scale4, alloc1146, alloc1147)
        R.vm.kill_object(alloc1146)
        R.vm.kill_object(model_layers_3_mlp_down_proj_q_weight4)
        R.vm.kill_object(model_layers_3_mlp_down_proj_q_scale4)
        model_layers_4_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[55]
        gv1118: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1148: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1118, R.dtype("float16"))
        gv1119: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1149: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1119, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1147, alloc1144, model_layers_4_input_layernorm_weight4, alloc1148, alloc1149)
        R.vm.kill_object(alloc1144)
        R.vm.kill_object(alloc1147)
        R.vm.kill_object(model_layers_4_input_layernorm_weight4)
        lv14: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1148, alloc1149
        model_layers_4_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[46]
        model_layers_4_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[47]
        model_layers_4_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[48]
        gv1120: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2560), sinfo_args=(R.Shape(ndim=3),))
        alloc1150: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1120, R.dtype("float16"))
        cls.fused_dequantize1_fused_NT_matmul_add(model_layers_4_self_attn_c_attn_q_weight4, model_layers_4_self_attn_c_attn_q_scale4, alloc1148, model_layers_4_self_attn_c_attn_bias4, alloc1150)
        R.vm.kill_object(alloc1148)
        R.vm.kill_object(model_layers_4_self_attn_c_attn_q_weight4)
        R.vm.kill_object(model_layers_4_self_attn_c_attn_q_scale4)
        R.vm.kill_object(model_layers_4_self_attn_c_attn_bias4)
        gv1121: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape448: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1150, gv1121, sinfo_args=(R.Tensor((batch_size, 1, 20, 128), dtype="float16"),))
        R.vm.kill_object(alloc1150)
        gv1122: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        reshape449: R.Tensor((batch_size, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape448, gv1122, sinfo_args=(R.Tensor((batch_size, 20, 128), dtype="float16"),))
        R.vm.kill_object(reshape448)
        gv1123: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        alloc1151: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1123, R.dtype("float16"))
        _920: R.Object = R.call_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape449, alloc1151)
        R.vm.kill_object(reshape449)
        gv1124: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape450: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1151, gv1124, sinfo_args=(R.Tensor((batch_size, 1, 16, 128), dtype="float16"),))
        R.vm.kill_object(alloc1151)
        gv1125: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        reshape451: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape450, gv1125, sinfo_args=(R.Tensor((batch_size, 1, 2048), dtype="float16"),))
        R.vm.kill_object(reshape450)
        model_layers_4_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[49]
        model_layers_4_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[50]
        gv1126: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1152: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1126, R.dtype("float16"))
        cls.fused_dequantize2_NT_matmul1(model_layers_4_self_attn_o_proj_q_weight4, model_layers_4_self_attn_o_proj_q_scale4, reshape451, alloc1152)
        R.vm.kill_object(reshape451)
        R.vm.kill_object(model_layers_4_self_attn_o_proj_q_weight4)
        R.vm.kill_object(model_layers_4_self_attn_o_proj_q_scale4)
        model_layers_4_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[56]
        gv1127: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1153: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1127, R.dtype("float16"))
        gv1128: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1154: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1128, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1152, alloc1149, model_layers_4_post_attention_layernorm_weight4, alloc1153, alloc1154)
        R.vm.kill_object(alloc1149)
        R.vm.kill_object(alloc1152)
        R.vm.kill_object(model_layers_4_post_attention_layernorm_weight4)
        lv16: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1153, alloc1154
        model_layers_4_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[51]
        model_layers_4_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[52]
        gv1129: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(22016), sinfo_args=(R.Shape(ndim=3),))
        alloc1155: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1129, R.dtype("float16"))
        cls.fused_dequantize3_NT_matmul2(model_layers_4_mlp_gate_up_proj_q_weight4, model_layers_4_mlp_gate_up_proj_q_scale4, alloc1153, alloc1155)
        R.vm.kill_object(alloc1153)
        R.vm.kill_object(model_layers_4_mlp_gate_up_proj_q_weight4)
        R.vm.kill_object(model_layers_4_mlp_gate_up_proj_q_scale4)
        gv1130: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(11008), sinfo_args=(R.Shape(ndim=3),))
        alloc1156: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1130, R.dtype("float16"))
        cls.fused_split_silu_multiply(alloc1155, alloc1156)
        R.vm.kill_object(alloc1155)
        model_layers_4_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[53]
        model_layers_4_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[54]
        gv1131: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1157: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1131, R.dtype("float16"))
        cls.fused_dequantize4_NT_matmul3(model_layers_4_mlp_down_proj_q_weight4, model_layers_4_mlp_down_proj_q_scale4, alloc1156, alloc1157)
        R.vm.kill_object(alloc1156)
        R.vm.kill_object(model_layers_4_mlp_down_proj_q_weight4)
        R.vm.kill_object(model_layers_4_mlp_down_proj_q_scale4)
        model_layers_5_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[66]
        gv1132: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1158: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1132, R.dtype("float16"))
        gv1133: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1159: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1133, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1157, alloc1154, model_layers_5_input_layernorm_weight4, alloc1158, alloc1159)
        R.vm.kill_object(alloc1154)
        R.vm.kill_object(alloc1157)
        R.vm.kill_object(model_layers_5_input_layernorm_weight4)
        lv18: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1158, alloc1159
        model_layers_5_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[57]
        model_layers_5_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[58]
        model_layers_5_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[59]
        gv1134: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2560), sinfo_args=(R.Shape(ndim=3),))
        alloc1160: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1134, R.dtype("float16"))
        cls.fused_dequantize1_fused_NT_matmul_add(model_layers_5_self_attn_c_attn_q_weight4, model_layers_5_self_attn_c_attn_q_scale4, alloc1158, model_layers_5_self_attn_c_attn_bias4, alloc1160)
        R.vm.kill_object(alloc1158)
        R.vm.kill_object(model_layers_5_self_attn_c_attn_q_weight4)
        R.vm.kill_object(model_layers_5_self_attn_c_attn_q_scale4)
        R.vm.kill_object(model_layers_5_self_attn_c_attn_bias4)
        gv1135: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape452: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1160, gv1135, sinfo_args=(R.Tensor((batch_size, 1, 20, 128), dtype="float16"),))
        R.vm.kill_object(alloc1160)
        gv1136: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        reshape453: R.Tensor((batch_size, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape452, gv1136, sinfo_args=(R.Tensor((batch_size, 20, 128), dtype="float16"),))
        R.vm.kill_object(reshape452)
        gv1137: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        alloc1161: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1137, R.dtype("float16"))
        _928: R.Object = R.call_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape453, alloc1161)
        R.vm.kill_object(reshape453)
        gv1138: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape454: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1161, gv1138, sinfo_args=(R.Tensor((batch_size, 1, 16, 128), dtype="float16"),))
        R.vm.kill_object(alloc1161)
        gv1139: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        reshape455: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape454, gv1139, sinfo_args=(R.Tensor((batch_size, 1, 2048), dtype="float16"),))
        R.vm.kill_object(reshape454)
        model_layers_5_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[60]
        model_layers_5_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[61]
        gv1140: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1162: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1140, R.dtype("float16"))
        cls.fused_dequantize2_NT_matmul1(model_layers_5_self_attn_o_proj_q_weight4, model_layers_5_self_attn_o_proj_q_scale4, reshape455, alloc1162)
        R.vm.kill_object(reshape455)
        R.vm.kill_object(model_layers_5_self_attn_o_proj_q_weight4)
        R.vm.kill_object(model_layers_5_self_attn_o_proj_q_scale4)
        model_layers_5_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[67]
        gv1141: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1163: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1141, R.dtype("float16"))
        gv1142: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1164: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1142, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1162, alloc1159, model_layers_5_post_attention_layernorm_weight4, alloc1163, alloc1164)
        R.vm.kill_object(alloc1159)
        R.vm.kill_object(alloc1162)
        R.vm.kill_object(model_layers_5_post_attention_layernorm_weight4)
        lv20: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1163, alloc1164
        model_layers_5_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[62]
        model_layers_5_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[63]
        gv1143: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(22016), sinfo_args=(R.Shape(ndim=3),))
        alloc1165: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1143, R.dtype("float16"))
        cls.fused_dequantize3_NT_matmul2(model_layers_5_mlp_gate_up_proj_q_weight4, model_layers_5_mlp_gate_up_proj_q_scale4, alloc1163, alloc1165)
        R.vm.kill_object(alloc1163)
        R.vm.kill_object(model_layers_5_mlp_gate_up_proj_q_weight4)
        R.vm.kill_object(model_layers_5_mlp_gate_up_proj_q_scale4)
        gv1144: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(11008), sinfo_args=(R.Shape(ndim=3),))
        alloc1166: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1144, R.dtype("float16"))
        cls.fused_split_silu_multiply(alloc1165, alloc1166)
        R.vm.kill_object(alloc1165)
        model_layers_5_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[64]
        model_layers_5_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[65]
        gv1145: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1167: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1145, R.dtype("float16"))
        cls.fused_dequantize4_NT_matmul3(model_layers_5_mlp_down_proj_q_weight4, model_layers_5_mlp_down_proj_q_scale4, alloc1166, alloc1167)
        R.vm.kill_object(alloc1166)
        R.vm.kill_object(model_layers_5_mlp_down_proj_q_weight4)
        R.vm.kill_object(model_layers_5_mlp_down_proj_q_scale4)
        model_layers_6_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[77]
        gv1146: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1168: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1146, R.dtype("float16"))
        gv1147: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1169: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1147, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1167, alloc1164, model_layers_6_input_layernorm_weight4, alloc1168, alloc1169)
        R.vm.kill_object(alloc1164)
        R.vm.kill_object(alloc1167)
        R.vm.kill_object(model_layers_6_input_layernorm_weight4)
        lv22: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1168, alloc1169
        model_layers_6_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[68]
        model_layers_6_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[69]
        model_layers_6_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[70]
        gv1148: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2560), sinfo_args=(R.Shape(ndim=3),))
        alloc1170: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1148, R.dtype("float16"))
        cls.fused_dequantize1_fused_NT_matmul_add(model_layers_6_self_attn_c_attn_q_weight4, model_layers_6_self_attn_c_attn_q_scale4, alloc1168, model_layers_6_self_attn_c_attn_bias4, alloc1170)
        R.vm.kill_object(alloc1168)
        R.vm.kill_object(model_layers_6_self_attn_c_attn_q_weight4)
        R.vm.kill_object(model_layers_6_self_attn_c_attn_q_scale4)
        R.vm.kill_object(model_layers_6_self_attn_c_attn_bias4)
        gv1149: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape456: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1170, gv1149, sinfo_args=(R.Tensor((batch_size, 1, 20, 128), dtype="float16"),))
        R.vm.kill_object(alloc1170)
        gv1150: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        reshape457: R.Tensor((batch_size, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape456, gv1150, sinfo_args=(R.Tensor((batch_size, 20, 128), dtype="float16"),))
        R.vm.kill_object(reshape456)
        gv1151: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        alloc1171: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1151, R.dtype("float16"))
        _936: R.Object = R.call_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape457, alloc1171)
        R.vm.kill_object(reshape457)
        gv1152: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape458: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1171, gv1152, sinfo_args=(R.Tensor((batch_size, 1, 16, 128), dtype="float16"),))
        R.vm.kill_object(alloc1171)
        gv1153: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        reshape459: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape458, gv1153, sinfo_args=(R.Tensor((batch_size, 1, 2048), dtype="float16"),))
        R.vm.kill_object(reshape458)
        model_layers_6_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[71]
        model_layers_6_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[72]
        gv1154: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1172: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1154, R.dtype("float16"))
        cls.fused_dequantize2_NT_matmul1(model_layers_6_self_attn_o_proj_q_weight4, model_layers_6_self_attn_o_proj_q_scale4, reshape459, alloc1172)
        R.vm.kill_object(reshape459)
        R.vm.kill_object(model_layers_6_self_attn_o_proj_q_weight4)
        R.vm.kill_object(model_layers_6_self_attn_o_proj_q_scale4)
        model_layers_6_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[78]
        gv1155: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1173: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1155, R.dtype("float16"))
        gv1156: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1174: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1156, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1172, alloc1169, model_layers_6_post_attention_layernorm_weight4, alloc1173, alloc1174)
        R.vm.kill_object(alloc1169)
        R.vm.kill_object(alloc1172)
        R.vm.kill_object(model_layers_6_post_attention_layernorm_weight4)
        lv24: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1173, alloc1174
        model_layers_6_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[73]
        model_layers_6_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[74]
        gv1157: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(22016), sinfo_args=(R.Shape(ndim=3),))
        alloc1175: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1157, R.dtype("float16"))
        cls.fused_dequantize3_NT_matmul2(model_layers_6_mlp_gate_up_proj_q_weight4, model_layers_6_mlp_gate_up_proj_q_scale4, alloc1173, alloc1175)
        R.vm.kill_object(alloc1173)
        R.vm.kill_object(model_layers_6_mlp_gate_up_proj_q_weight4)
        R.vm.kill_object(model_layers_6_mlp_gate_up_proj_q_scale4)
        gv1158: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(11008), sinfo_args=(R.Shape(ndim=3),))
        alloc1176: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1158, R.dtype("float16"))
        cls.fused_split_silu_multiply(alloc1175, alloc1176)
        R.vm.kill_object(alloc1175)
        model_layers_6_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[75]
        model_layers_6_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[76]
        gv1159: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1177: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1159, R.dtype("float16"))
        cls.fused_dequantize4_NT_matmul3(model_layers_6_mlp_down_proj_q_weight4, model_layers_6_mlp_down_proj_q_scale4, alloc1176, alloc1177)
        R.vm.kill_object(alloc1176)
        R.vm.kill_object(model_layers_6_mlp_down_proj_q_weight4)
        R.vm.kill_object(model_layers_6_mlp_down_proj_q_scale4)
        model_layers_7_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[88]
        gv1160: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1178: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1160, R.dtype("float16"))
        gv1161: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1179: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1161, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1177, alloc1174, model_layers_7_input_layernorm_weight4, alloc1178, alloc1179)
        R.vm.kill_object(alloc1174)
        R.vm.kill_object(alloc1177)
        R.vm.kill_object(model_layers_7_input_layernorm_weight4)
        lv26: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1178, alloc1179
        model_layers_7_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[79]
        model_layers_7_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[80]
        model_layers_7_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[81]
        gv1162: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2560), sinfo_args=(R.Shape(ndim=3),))
        alloc1180: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1162, R.dtype("float16"))
        cls.fused_dequantize1_fused_NT_matmul_add(model_layers_7_self_attn_c_attn_q_weight4, model_layers_7_self_attn_c_attn_q_scale4, alloc1178, model_layers_7_self_attn_c_attn_bias4, alloc1180)
        R.vm.kill_object(alloc1178)
        R.vm.kill_object(model_layers_7_self_attn_c_attn_q_weight4)
        R.vm.kill_object(model_layers_7_self_attn_c_attn_q_scale4)
        R.vm.kill_object(model_layers_7_self_attn_c_attn_bias4)
        gv1163: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape460: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1180, gv1163, sinfo_args=(R.Tensor((batch_size, 1, 20, 128), dtype="float16"),))
        R.vm.kill_object(alloc1180)
        gv1164: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        reshape461: R.Tensor((batch_size, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape460, gv1164, sinfo_args=(R.Tensor((batch_size, 20, 128), dtype="float16"),))
        R.vm.kill_object(reshape460)
        gv1165: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        alloc1181: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1165, R.dtype("float16"))
        _944: R.Object = R.call_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape461, alloc1181)
        R.vm.kill_object(reshape461)
        gv1166: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape462: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1181, gv1166, sinfo_args=(R.Tensor((batch_size, 1, 16, 128), dtype="float16"),))
        R.vm.kill_object(alloc1181)
        gv1167: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        reshape463: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape462, gv1167, sinfo_args=(R.Tensor((batch_size, 1, 2048), dtype="float16"),))
        R.vm.kill_object(reshape462)
        model_layers_7_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[82]
        model_layers_7_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[83]
        gv1168: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1182: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1168, R.dtype("float16"))
        cls.fused_dequantize2_NT_matmul1(model_layers_7_self_attn_o_proj_q_weight4, model_layers_7_self_attn_o_proj_q_scale4, reshape463, alloc1182)
        R.vm.kill_object(reshape463)
        R.vm.kill_object(model_layers_7_self_attn_o_proj_q_weight4)
        R.vm.kill_object(model_layers_7_self_attn_o_proj_q_scale4)
        model_layers_7_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[89]
        gv1169: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1183: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1169, R.dtype("float16"))
        gv1170: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1184: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1170, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1182, alloc1179, model_layers_7_post_attention_layernorm_weight4, alloc1183, alloc1184)
        R.vm.kill_object(alloc1179)
        R.vm.kill_object(alloc1182)
        R.vm.kill_object(model_layers_7_post_attention_layernorm_weight4)
        lv28: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1183, alloc1184
        model_layers_7_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[84]
        model_layers_7_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[85]
        gv1171: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(22016), sinfo_args=(R.Shape(ndim=3),))
        alloc1185: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1171, R.dtype("float16"))
        cls.fused_dequantize3_NT_matmul2(model_layers_7_mlp_gate_up_proj_q_weight4, model_layers_7_mlp_gate_up_proj_q_scale4, alloc1183, alloc1185)
        R.vm.kill_object(alloc1183)
        R.vm.kill_object(model_layers_7_mlp_gate_up_proj_q_weight4)
        R.vm.kill_object(model_layers_7_mlp_gate_up_proj_q_scale4)
        gv1172: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(11008), sinfo_args=(R.Shape(ndim=3),))
        alloc1186: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1172, R.dtype("float16"))
        cls.fused_split_silu_multiply(alloc1185, alloc1186)
        R.vm.kill_object(alloc1185)
        model_layers_7_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[86]
        model_layers_7_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[87]
        gv1173: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1187: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1173, R.dtype("float16"))
        cls.fused_dequantize4_NT_matmul3(model_layers_7_mlp_down_proj_q_weight4, model_layers_7_mlp_down_proj_q_scale4, alloc1186, alloc1187)
        R.vm.kill_object(alloc1186)
        R.vm.kill_object(model_layers_7_mlp_down_proj_q_weight4)
        R.vm.kill_object(model_layers_7_mlp_down_proj_q_scale4)
        model_layers_8_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[99]
        gv1174: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1188: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1174, R.dtype("float16"))
        gv1175: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1189: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1175, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1187, alloc1184, model_layers_8_input_layernorm_weight4, alloc1188, alloc1189)
        R.vm.kill_object(alloc1184)
        R.vm.kill_object(alloc1187)
        R.vm.kill_object(model_layers_8_input_layernorm_weight4)
        lv30: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1188, alloc1189
        model_layers_8_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[90]
        model_layers_8_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[91]
        model_layers_8_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[92]
        gv1176: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2560), sinfo_args=(R.Shape(ndim=3),))
        alloc1190: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1176, R.dtype("float16"))
        cls.fused_dequantize1_fused_NT_matmul_add(model_layers_8_self_attn_c_attn_q_weight4, model_layers_8_self_attn_c_attn_q_scale4, alloc1188, model_layers_8_self_attn_c_attn_bias4, alloc1190)
        R.vm.kill_object(alloc1188)
        R.vm.kill_object(model_layers_8_self_attn_c_attn_q_weight4)
        R.vm.kill_object(model_layers_8_self_attn_c_attn_q_scale4)
        R.vm.kill_object(model_layers_8_self_attn_c_attn_bias4)
        gv1177: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape464: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1190, gv1177, sinfo_args=(R.Tensor((batch_size, 1, 20, 128), dtype="float16"),))
        R.vm.kill_object(alloc1190)
        gv1178: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        reshape465: R.Tensor((batch_size, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape464, gv1178, sinfo_args=(R.Tensor((batch_size, 20, 128), dtype="float16"),))
        R.vm.kill_object(reshape464)
        gv1179: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        alloc1191: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1179, R.dtype("float16"))
        _952: R.Object = R.call_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape465, alloc1191)
        R.vm.kill_object(reshape465)
        gv1180: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape466: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1191, gv1180, sinfo_args=(R.Tensor((batch_size, 1, 16, 128), dtype="float16"),))
        R.vm.kill_object(alloc1191)
        gv1181: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        reshape467: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape466, gv1181, sinfo_args=(R.Tensor((batch_size, 1, 2048), dtype="float16"),))
        R.vm.kill_object(reshape466)
        model_layers_8_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[93]
        model_layers_8_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[94]
        gv1182: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1192: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1182, R.dtype("float16"))
        cls.fused_dequantize2_NT_matmul1(model_layers_8_self_attn_o_proj_q_weight4, model_layers_8_self_attn_o_proj_q_scale4, reshape467, alloc1192)
        R.vm.kill_object(reshape467)
        R.vm.kill_object(model_layers_8_self_attn_o_proj_q_weight4)
        R.vm.kill_object(model_layers_8_self_attn_o_proj_q_scale4)
        model_layers_8_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[100]
        gv1183: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1193: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1183, R.dtype("float16"))
        gv1184: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1194: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1184, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1192, alloc1189, model_layers_8_post_attention_layernorm_weight4, alloc1193, alloc1194)
        R.vm.kill_object(alloc1189)
        R.vm.kill_object(alloc1192)
        R.vm.kill_object(model_layers_8_post_attention_layernorm_weight4)
        lv32: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1193, alloc1194
        model_layers_8_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[95]
        model_layers_8_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[96]
        gv1185: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(22016), sinfo_args=(R.Shape(ndim=3),))
        alloc1195: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1185, R.dtype("float16"))
        cls.fused_dequantize3_NT_matmul2(model_layers_8_mlp_gate_up_proj_q_weight4, model_layers_8_mlp_gate_up_proj_q_scale4, alloc1193, alloc1195)
        R.vm.kill_object(alloc1193)
        R.vm.kill_object(model_layers_8_mlp_gate_up_proj_q_weight4)
        R.vm.kill_object(model_layers_8_mlp_gate_up_proj_q_scale4)
        gv1186: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(11008), sinfo_args=(R.Shape(ndim=3),))
        alloc1196: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1186, R.dtype("float16"))
        cls.fused_split_silu_multiply(alloc1195, alloc1196)
        R.vm.kill_object(alloc1195)
        model_layers_8_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[97]
        model_layers_8_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[98]
        gv1187: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1197: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1187, R.dtype("float16"))
        cls.fused_dequantize4_NT_matmul3(model_layers_8_mlp_down_proj_q_weight4, model_layers_8_mlp_down_proj_q_scale4, alloc1196, alloc1197)
        R.vm.kill_object(alloc1196)
        R.vm.kill_object(model_layers_8_mlp_down_proj_q_weight4)
        R.vm.kill_object(model_layers_8_mlp_down_proj_q_scale4)
        model_layers_9_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[110]
        gv1188: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1198: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1188, R.dtype("float16"))
        gv1189: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1199: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1189, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1197, alloc1194, model_layers_9_input_layernorm_weight4, alloc1198, alloc1199)
        R.vm.kill_object(alloc1194)
        R.vm.kill_object(alloc1197)
        R.vm.kill_object(model_layers_9_input_layernorm_weight4)
        lv34: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1198, alloc1199
        model_layers_9_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[101]
        model_layers_9_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[102]
        model_layers_9_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[103]
        gv1190: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2560), sinfo_args=(R.Shape(ndim=3),))
        alloc1200: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1190, R.dtype("float16"))
        cls.fused_dequantize1_fused_NT_matmul_add(model_layers_9_self_attn_c_attn_q_weight4, model_layers_9_self_attn_c_attn_q_scale4, alloc1198, model_layers_9_self_attn_c_attn_bias4, alloc1200)
        R.vm.kill_object(alloc1198)
        R.vm.kill_object(model_layers_9_self_attn_c_attn_q_weight4)
        R.vm.kill_object(model_layers_9_self_attn_c_attn_q_scale4)
        R.vm.kill_object(model_layers_9_self_attn_c_attn_bias4)
        gv1191: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape468: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1200, gv1191, sinfo_args=(R.Tensor((batch_size, 1, 20, 128), dtype="float16"),))
        R.vm.kill_object(alloc1200)
        gv1192: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        reshape469: R.Tensor((batch_size, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape468, gv1192, sinfo_args=(R.Tensor((batch_size, 20, 128), dtype="float16"),))
        R.vm.kill_object(reshape468)
        gv1193: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        alloc1201: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1193, R.dtype("float16"))
        _960: R.Object = R.call_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape469, alloc1201)
        R.vm.kill_object(reshape469)
        gv1194: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape470: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1201, gv1194, sinfo_args=(R.Tensor((batch_size, 1, 16, 128), dtype="float16"),))
        R.vm.kill_object(alloc1201)
        gv1195: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        reshape471: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape470, gv1195, sinfo_args=(R.Tensor((batch_size, 1, 2048), dtype="float16"),))
        R.vm.kill_object(reshape470)
        model_layers_9_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[104]
        model_layers_9_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[105]
        gv1196: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1202: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1196, R.dtype("float16"))
        cls.fused_dequantize2_NT_matmul1(model_layers_9_self_attn_o_proj_q_weight4, model_layers_9_self_attn_o_proj_q_scale4, reshape471, alloc1202)
        R.vm.kill_object(reshape471)
        R.vm.kill_object(model_layers_9_self_attn_o_proj_q_weight4)
        R.vm.kill_object(model_layers_9_self_attn_o_proj_q_scale4)
        model_layers_9_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[111]
        gv1197: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1203: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1197, R.dtype("float16"))
        gv1198: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1204: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1198, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1202, alloc1199, model_layers_9_post_attention_layernorm_weight4, alloc1203, alloc1204)
        R.vm.kill_object(alloc1199)
        R.vm.kill_object(alloc1202)
        R.vm.kill_object(model_layers_9_post_attention_layernorm_weight4)
        lv36: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1203, alloc1204
        model_layers_9_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[106]
        model_layers_9_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[107]
        gv1199: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(22016), sinfo_args=(R.Shape(ndim=3),))
        alloc1205: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1199, R.dtype("float16"))
        cls.fused_dequantize3_NT_matmul2(model_layers_9_mlp_gate_up_proj_q_weight4, model_layers_9_mlp_gate_up_proj_q_scale4, alloc1203, alloc1205)
        R.vm.kill_object(alloc1203)
        R.vm.kill_object(model_layers_9_mlp_gate_up_proj_q_weight4)
        R.vm.kill_object(model_layers_9_mlp_gate_up_proj_q_scale4)
        gv1200: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(11008), sinfo_args=(R.Shape(ndim=3),))
        alloc1206: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1200, R.dtype("float16"))
        cls.fused_split_silu_multiply(alloc1205, alloc1206)
        R.vm.kill_object(alloc1205)
        model_layers_9_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[108]
        model_layers_9_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[109]
        gv1201: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1207: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1201, R.dtype("float16"))
        cls.fused_dequantize4_NT_matmul3(model_layers_9_mlp_down_proj_q_weight4, model_layers_9_mlp_down_proj_q_scale4, alloc1206, alloc1207)
        R.vm.kill_object(alloc1206)
        R.vm.kill_object(model_layers_9_mlp_down_proj_q_weight4)
        R.vm.kill_object(model_layers_9_mlp_down_proj_q_scale4)
        model_layers_10_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[121]
        gv1202: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1208: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1202, R.dtype("float16"))
        gv1203: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1209: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1203, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1207, alloc1204, model_layers_10_input_layernorm_weight4, alloc1208, alloc1209)
        R.vm.kill_object(alloc1204)
        R.vm.kill_object(alloc1207)
        R.vm.kill_object(model_layers_10_input_layernorm_weight4)
        lv38: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1208, alloc1209
        model_layers_10_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[112]
        model_layers_10_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[113]
        model_layers_10_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[114]
        gv1204: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2560), sinfo_args=(R.Shape(ndim=3),))
        alloc1210: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1204, R.dtype("float16"))
        cls.fused_dequantize1_fused_NT_matmul_add(model_layers_10_self_attn_c_attn_q_weight4, model_layers_10_self_attn_c_attn_q_scale4, alloc1208, model_layers_10_self_attn_c_attn_bias4, alloc1210)
        R.vm.kill_object(alloc1208)
        R.vm.kill_object(model_layers_10_self_attn_c_attn_q_weight4)
        R.vm.kill_object(model_layers_10_self_attn_c_attn_q_scale4)
        R.vm.kill_object(model_layers_10_self_attn_c_attn_bias4)
        gv1205: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape472: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1210, gv1205, sinfo_args=(R.Tensor((batch_size, 1, 20, 128), dtype="float16"),))
        R.vm.kill_object(alloc1210)
        gv1206: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        reshape473: R.Tensor((batch_size, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape472, gv1206, sinfo_args=(R.Tensor((batch_size, 20, 128), dtype="float16"),))
        R.vm.kill_object(reshape472)
        gv1207: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        alloc1211: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1207, R.dtype("float16"))
        _968: R.Object = R.call_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape473, alloc1211)
        R.vm.kill_object(reshape473)
        gv1208: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape474: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1211, gv1208, sinfo_args=(R.Tensor((batch_size, 1, 16, 128), dtype="float16"),))
        R.vm.kill_object(alloc1211)
        gv1209: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        reshape475: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape474, gv1209, sinfo_args=(R.Tensor((batch_size, 1, 2048), dtype="float16"),))
        R.vm.kill_object(reshape474)
        model_layers_10_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[115]
        model_layers_10_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[116]
        gv1210: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1212: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1210, R.dtype("float16"))
        cls.fused_dequantize2_NT_matmul1(model_layers_10_self_attn_o_proj_q_weight4, model_layers_10_self_attn_o_proj_q_scale4, reshape475, alloc1212)
        R.vm.kill_object(reshape475)
        R.vm.kill_object(model_layers_10_self_attn_o_proj_q_weight4)
        R.vm.kill_object(model_layers_10_self_attn_o_proj_q_scale4)
        model_layers_10_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[122]
        gv1211: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1213: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1211, R.dtype("float16"))
        gv1212: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1214: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1212, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1212, alloc1209, model_layers_10_post_attention_layernorm_weight4, alloc1213, alloc1214)
        R.vm.kill_object(alloc1209)
        R.vm.kill_object(alloc1212)
        R.vm.kill_object(model_layers_10_post_attention_layernorm_weight4)
        lv40: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1213, alloc1214
        model_layers_10_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[117]
        model_layers_10_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[118]
        gv1213: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(22016), sinfo_args=(R.Shape(ndim=3),))
        alloc1215: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1213, R.dtype("float16"))
        cls.fused_dequantize3_NT_matmul2(model_layers_10_mlp_gate_up_proj_q_weight4, model_layers_10_mlp_gate_up_proj_q_scale4, alloc1213, alloc1215)
        R.vm.kill_object(alloc1213)
        R.vm.kill_object(model_layers_10_mlp_gate_up_proj_q_weight4)
        R.vm.kill_object(model_layers_10_mlp_gate_up_proj_q_scale4)
        gv1214: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(11008), sinfo_args=(R.Shape(ndim=3),))
        alloc1216: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1214, R.dtype("float16"))
        cls.fused_split_silu_multiply(alloc1215, alloc1216)
        R.vm.kill_object(alloc1215)
        model_layers_10_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[119]
        model_layers_10_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[120]
        gv1215: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1217: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1215, R.dtype("float16"))
        cls.fused_dequantize4_NT_matmul3(model_layers_10_mlp_down_proj_q_weight4, model_layers_10_mlp_down_proj_q_scale4, alloc1216, alloc1217)
        R.vm.kill_object(alloc1216)
        R.vm.kill_object(model_layers_10_mlp_down_proj_q_weight4)
        R.vm.kill_object(model_layers_10_mlp_down_proj_q_scale4)
        model_layers_11_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[132]
        gv1216: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1218: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1216, R.dtype("float16"))
        gv1217: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1219: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1217, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1217, alloc1214, model_layers_11_input_layernorm_weight4, alloc1218, alloc1219)
        R.vm.kill_object(alloc1214)
        R.vm.kill_object(alloc1217)
        R.vm.kill_object(model_layers_11_input_layernorm_weight4)
        lv42: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1218, alloc1219
        model_layers_11_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[123]
        model_layers_11_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[124]
        model_layers_11_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[125]
        gv1218: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2560), sinfo_args=(R.Shape(ndim=3),))
        alloc1220: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1218, R.dtype("float16"))
        cls.fused_dequantize1_fused_NT_matmul_add(model_layers_11_self_attn_c_attn_q_weight4, model_layers_11_self_attn_c_attn_q_scale4, alloc1218, model_layers_11_self_attn_c_attn_bias4, alloc1220)
        R.vm.kill_object(alloc1218)
        R.vm.kill_object(model_layers_11_self_attn_c_attn_q_weight4)
        R.vm.kill_object(model_layers_11_self_attn_c_attn_q_scale4)
        R.vm.kill_object(model_layers_11_self_attn_c_attn_bias4)
        gv1219: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape476: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1220, gv1219, sinfo_args=(R.Tensor((batch_size, 1, 20, 128), dtype="float16"),))
        R.vm.kill_object(alloc1220)
        gv1220: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        reshape477: R.Tensor((batch_size, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape476, gv1220, sinfo_args=(R.Tensor((batch_size, 20, 128), dtype="float16"),))
        R.vm.kill_object(reshape476)
        gv1221: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        alloc1221: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1221, R.dtype("float16"))
        _976: R.Object = R.call_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape477, alloc1221)
        R.vm.kill_object(reshape477)
        gv1222: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape478: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1221, gv1222, sinfo_args=(R.Tensor((batch_size, 1, 16, 128), dtype="float16"),))
        R.vm.kill_object(alloc1221)
        gv1223: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        reshape479: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape478, gv1223, sinfo_args=(R.Tensor((batch_size, 1, 2048), dtype="float16"),))
        R.vm.kill_object(reshape478)
        model_layers_11_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[126]
        model_layers_11_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[127]
        gv1224: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1222: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1224, R.dtype("float16"))
        cls.fused_dequantize2_NT_matmul1(model_layers_11_self_attn_o_proj_q_weight4, model_layers_11_self_attn_o_proj_q_scale4, reshape479, alloc1222)
        R.vm.kill_object(reshape479)
        R.vm.kill_object(model_layers_11_self_attn_o_proj_q_weight4)
        R.vm.kill_object(model_layers_11_self_attn_o_proj_q_scale4)
        model_layers_11_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[133]
        gv1225: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1223: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1225, R.dtype("float16"))
        gv1226: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1224: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1226, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1222, alloc1219, model_layers_11_post_attention_layernorm_weight4, alloc1223, alloc1224)
        R.vm.kill_object(alloc1219)
        R.vm.kill_object(alloc1222)
        R.vm.kill_object(model_layers_11_post_attention_layernorm_weight4)
        lv44: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1223, alloc1224
        model_layers_11_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[128]
        model_layers_11_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[129]
        gv1227: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(22016), sinfo_args=(R.Shape(ndim=3),))
        alloc1225: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1227, R.dtype("float16"))
        cls.fused_dequantize3_NT_matmul2(model_layers_11_mlp_gate_up_proj_q_weight4, model_layers_11_mlp_gate_up_proj_q_scale4, alloc1223, alloc1225)
        R.vm.kill_object(alloc1223)
        R.vm.kill_object(model_layers_11_mlp_gate_up_proj_q_weight4)
        R.vm.kill_object(model_layers_11_mlp_gate_up_proj_q_scale4)
        gv1228: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(11008), sinfo_args=(R.Shape(ndim=3),))
        alloc1226: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1228, R.dtype("float16"))
        cls.fused_split_silu_multiply(alloc1225, alloc1226)
        R.vm.kill_object(alloc1225)
        model_layers_11_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[130]
        model_layers_11_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[131]
        gv1229: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1227: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1229, R.dtype("float16"))
        cls.fused_dequantize4_NT_matmul3(model_layers_11_mlp_down_proj_q_weight4, model_layers_11_mlp_down_proj_q_scale4, alloc1226, alloc1227)
        R.vm.kill_object(alloc1226)
        R.vm.kill_object(model_layers_11_mlp_down_proj_q_weight4)
        R.vm.kill_object(model_layers_11_mlp_down_proj_q_scale4)
        model_layers_12_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[143]
        gv1230: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1228: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1230, R.dtype("float16"))
        gv1231: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1229: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1231, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1227, alloc1224, model_layers_12_input_layernorm_weight4, alloc1228, alloc1229)
        R.vm.kill_object(alloc1224)
        R.vm.kill_object(alloc1227)
        R.vm.kill_object(model_layers_12_input_layernorm_weight4)
        lv46: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1228, alloc1229
        model_layers_12_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[134]
        model_layers_12_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[135]
        model_layers_12_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[136]
        gv1232: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2560), sinfo_args=(R.Shape(ndim=3),))
        alloc1230: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1232, R.dtype("float16"))
        cls.fused_dequantize1_fused_NT_matmul_add(model_layers_12_self_attn_c_attn_q_weight4, model_layers_12_self_attn_c_attn_q_scale4, alloc1228, model_layers_12_self_attn_c_attn_bias4, alloc1230)
        R.vm.kill_object(alloc1228)
        R.vm.kill_object(model_layers_12_self_attn_c_attn_q_weight4)
        R.vm.kill_object(model_layers_12_self_attn_c_attn_q_scale4)
        R.vm.kill_object(model_layers_12_self_attn_c_attn_bias4)
        gv1233: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape480: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1230, gv1233, sinfo_args=(R.Tensor((batch_size, 1, 20, 128), dtype="float16"),))
        R.vm.kill_object(alloc1230)
        gv1234: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        reshape481: R.Tensor((batch_size, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape480, gv1234, sinfo_args=(R.Tensor((batch_size, 20, 128), dtype="float16"),))
        R.vm.kill_object(reshape480)
        gv1235: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        alloc1231: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1235, R.dtype("float16"))
        _984: R.Object = R.call_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape481, alloc1231)
        R.vm.kill_object(reshape481)
        gv1236: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape482: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1231, gv1236, sinfo_args=(R.Tensor((batch_size, 1, 16, 128), dtype="float16"),))
        R.vm.kill_object(alloc1231)
        gv1237: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        reshape483: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape482, gv1237, sinfo_args=(R.Tensor((batch_size, 1, 2048), dtype="float16"),))
        R.vm.kill_object(reshape482)
        model_layers_12_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[137]
        model_layers_12_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[138]
        gv1238: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1232: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1238, R.dtype("float16"))
        cls.fused_dequantize2_NT_matmul1(model_layers_12_self_attn_o_proj_q_weight4, model_layers_12_self_attn_o_proj_q_scale4, reshape483, alloc1232)
        R.vm.kill_object(reshape483)
        R.vm.kill_object(model_layers_12_self_attn_o_proj_q_weight4)
        R.vm.kill_object(model_layers_12_self_attn_o_proj_q_scale4)
        model_layers_12_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[144]
        gv1239: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1233: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1239, R.dtype("float16"))
        gv1240: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1234: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1240, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1232, alloc1229, model_layers_12_post_attention_layernorm_weight4, alloc1233, alloc1234)
        R.vm.kill_object(alloc1229)
        R.vm.kill_object(alloc1232)
        R.vm.kill_object(model_layers_12_post_attention_layernorm_weight4)
        lv48: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1233, alloc1234
        model_layers_12_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[139]
        model_layers_12_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[140]
        gv1241: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(22016), sinfo_args=(R.Shape(ndim=3),))
        alloc1235: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1241, R.dtype("float16"))
        cls.fused_dequantize3_NT_matmul2(model_layers_12_mlp_gate_up_proj_q_weight4, model_layers_12_mlp_gate_up_proj_q_scale4, alloc1233, alloc1235)
        R.vm.kill_object(alloc1233)
        R.vm.kill_object(model_layers_12_mlp_gate_up_proj_q_weight4)
        R.vm.kill_object(model_layers_12_mlp_gate_up_proj_q_scale4)
        gv1242: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(11008), sinfo_args=(R.Shape(ndim=3),))
        alloc1236: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1242, R.dtype("float16"))
        cls.fused_split_silu_multiply(alloc1235, alloc1236)
        R.vm.kill_object(alloc1235)
        model_layers_12_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[141]
        model_layers_12_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[142]
        gv1243: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1237: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1243, R.dtype("float16"))
        cls.fused_dequantize4_NT_matmul3(model_layers_12_mlp_down_proj_q_weight4, model_layers_12_mlp_down_proj_q_scale4, alloc1236, alloc1237)
        R.vm.kill_object(alloc1236)
        R.vm.kill_object(model_layers_12_mlp_down_proj_q_weight4)
        R.vm.kill_object(model_layers_12_mlp_down_proj_q_scale4)
        model_layers_13_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[154]
        gv1244: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1238: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1244, R.dtype("float16"))
        gv1245: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1239: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1245, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1237, alloc1234, model_layers_13_input_layernorm_weight4, alloc1238, alloc1239)
        R.vm.kill_object(alloc1234)
        R.vm.kill_object(alloc1237)
        R.vm.kill_object(model_layers_13_input_layernorm_weight4)
        lv50: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1238, alloc1239
        model_layers_13_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[145]
        model_layers_13_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[146]
        model_layers_13_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[147]
        gv1246: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2560), sinfo_args=(R.Shape(ndim=3),))
        alloc1240: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1246, R.dtype("float16"))
        cls.fused_dequantize1_fused_NT_matmul_add(model_layers_13_self_attn_c_attn_q_weight4, model_layers_13_self_attn_c_attn_q_scale4, alloc1238, model_layers_13_self_attn_c_attn_bias4, alloc1240)
        R.vm.kill_object(alloc1238)
        R.vm.kill_object(model_layers_13_self_attn_c_attn_q_weight4)
        R.vm.kill_object(model_layers_13_self_attn_c_attn_q_scale4)
        R.vm.kill_object(model_layers_13_self_attn_c_attn_bias4)
        gv1247: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape484: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1240, gv1247, sinfo_args=(R.Tensor((batch_size, 1, 20, 128), dtype="float16"),))
        R.vm.kill_object(alloc1240)
        gv1248: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        reshape485: R.Tensor((batch_size, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape484, gv1248, sinfo_args=(R.Tensor((batch_size, 20, 128), dtype="float16"),))
        R.vm.kill_object(reshape484)
        gv1249: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        alloc1241: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1249, R.dtype("float16"))
        _992: R.Object = R.call_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape485, alloc1241)
        R.vm.kill_object(reshape485)
        gv1250: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape486: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1241, gv1250, sinfo_args=(R.Tensor((batch_size, 1, 16, 128), dtype="float16"),))
        R.vm.kill_object(alloc1241)
        gv1251: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        reshape487: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape486, gv1251, sinfo_args=(R.Tensor((batch_size, 1, 2048), dtype="float16"),))
        R.vm.kill_object(reshape486)
        model_layers_13_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[148]
        model_layers_13_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[149]
        gv1252: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1242: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1252, R.dtype("float16"))
        cls.fused_dequantize2_NT_matmul1(model_layers_13_self_attn_o_proj_q_weight4, model_layers_13_self_attn_o_proj_q_scale4, reshape487, alloc1242)
        R.vm.kill_object(reshape487)
        R.vm.kill_object(model_layers_13_self_attn_o_proj_q_weight4)
        R.vm.kill_object(model_layers_13_self_attn_o_proj_q_scale4)
        model_layers_13_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[155]
        gv1253: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1243: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1253, R.dtype("float16"))
        gv1254: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1244: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1254, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1242, alloc1239, model_layers_13_post_attention_layernorm_weight4, alloc1243, alloc1244)
        R.vm.kill_object(alloc1239)
        R.vm.kill_object(alloc1242)
        R.vm.kill_object(model_layers_13_post_attention_layernorm_weight4)
        lv52: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1243, alloc1244
        model_layers_13_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[150]
        model_layers_13_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[151]
        gv1255: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(22016), sinfo_args=(R.Shape(ndim=3),))
        alloc1245: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1255, R.dtype("float16"))
        cls.fused_dequantize3_NT_matmul2(model_layers_13_mlp_gate_up_proj_q_weight4, model_layers_13_mlp_gate_up_proj_q_scale4, alloc1243, alloc1245)
        R.vm.kill_object(alloc1243)
        R.vm.kill_object(model_layers_13_mlp_gate_up_proj_q_weight4)
        R.vm.kill_object(model_layers_13_mlp_gate_up_proj_q_scale4)
        gv1256: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(11008), sinfo_args=(R.Shape(ndim=3),))
        alloc1246: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1256, R.dtype("float16"))
        cls.fused_split_silu_multiply(alloc1245, alloc1246)
        R.vm.kill_object(alloc1245)
        model_layers_13_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[152]
        model_layers_13_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[153]
        gv1257: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1247: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1257, R.dtype("float16"))
        cls.fused_dequantize4_NT_matmul3(model_layers_13_mlp_down_proj_q_weight4, model_layers_13_mlp_down_proj_q_scale4, alloc1246, alloc1247)
        R.vm.kill_object(alloc1246)
        R.vm.kill_object(model_layers_13_mlp_down_proj_q_weight4)
        R.vm.kill_object(model_layers_13_mlp_down_proj_q_scale4)
        model_layers_14_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[165]
        gv1258: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1248: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1258, R.dtype("float16"))
        gv1259: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1249: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1259, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1247, alloc1244, model_layers_14_input_layernorm_weight4, alloc1248, alloc1249)
        R.vm.kill_object(alloc1244)
        R.vm.kill_object(alloc1247)
        R.vm.kill_object(model_layers_14_input_layernorm_weight4)
        lv54: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1248, alloc1249
        model_layers_14_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[156]
        model_layers_14_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[157]
        model_layers_14_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[158]
        gv1260: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2560), sinfo_args=(R.Shape(ndim=3),))
        alloc1250: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1260, R.dtype("float16"))
        cls.fused_dequantize1_fused_NT_matmul_add(model_layers_14_self_attn_c_attn_q_weight4, model_layers_14_self_attn_c_attn_q_scale4, alloc1248, model_layers_14_self_attn_c_attn_bias4, alloc1250)
        R.vm.kill_object(alloc1248)
        R.vm.kill_object(model_layers_14_self_attn_c_attn_q_weight4)
        R.vm.kill_object(model_layers_14_self_attn_c_attn_q_scale4)
        R.vm.kill_object(model_layers_14_self_attn_c_attn_bias4)
        gv1261: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape488: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1250, gv1261, sinfo_args=(R.Tensor((batch_size, 1, 20, 128), dtype="float16"),))
        R.vm.kill_object(alloc1250)
        gv1262: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        reshape489: R.Tensor((batch_size, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape488, gv1262, sinfo_args=(R.Tensor((batch_size, 20, 128), dtype="float16"),))
        R.vm.kill_object(reshape488)
        gv1263: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        alloc1251: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1263, R.dtype("float16"))
        _1000: R.Object = R.call_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape489, alloc1251)
        R.vm.kill_object(reshape489)
        gv1264: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape490: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1251, gv1264, sinfo_args=(R.Tensor((batch_size, 1, 16, 128), dtype="float16"),))
        R.vm.kill_object(alloc1251)
        gv1265: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        reshape491: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape490, gv1265, sinfo_args=(R.Tensor((batch_size, 1, 2048), dtype="float16"),))
        R.vm.kill_object(reshape490)
        model_layers_14_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[159]
        model_layers_14_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[160]
        gv1266: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1252: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1266, R.dtype("float16"))
        cls.fused_dequantize2_NT_matmul1(model_layers_14_self_attn_o_proj_q_weight4, model_layers_14_self_attn_o_proj_q_scale4, reshape491, alloc1252)
        R.vm.kill_object(reshape491)
        R.vm.kill_object(model_layers_14_self_attn_o_proj_q_weight4)
        R.vm.kill_object(model_layers_14_self_attn_o_proj_q_scale4)
        model_layers_14_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[166]
        gv1267: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1253: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1267, R.dtype("float16"))
        gv1268: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1254: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1268, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1252, alloc1249, model_layers_14_post_attention_layernorm_weight4, alloc1253, alloc1254)
        R.vm.kill_object(alloc1249)
        R.vm.kill_object(alloc1252)
        R.vm.kill_object(model_layers_14_post_attention_layernorm_weight4)
        lv56: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1253, alloc1254
        model_layers_14_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[161]
        model_layers_14_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[162]
        gv1269: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(22016), sinfo_args=(R.Shape(ndim=3),))
        alloc1255: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1269, R.dtype("float16"))
        cls.fused_dequantize3_NT_matmul2(model_layers_14_mlp_gate_up_proj_q_weight4, model_layers_14_mlp_gate_up_proj_q_scale4, alloc1253, alloc1255)
        R.vm.kill_object(alloc1253)
        R.vm.kill_object(model_layers_14_mlp_gate_up_proj_q_weight4)
        R.vm.kill_object(model_layers_14_mlp_gate_up_proj_q_scale4)
        gv1270: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(11008), sinfo_args=(R.Shape(ndim=3),))
        alloc1256: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1270, R.dtype("float16"))
        cls.fused_split_silu_multiply(alloc1255, alloc1256)
        R.vm.kill_object(alloc1255)
        model_layers_14_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[163]
        model_layers_14_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[164]
        gv1271: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1257: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1271, R.dtype("float16"))
        cls.fused_dequantize4_NT_matmul3(model_layers_14_mlp_down_proj_q_weight4, model_layers_14_mlp_down_proj_q_scale4, alloc1256, alloc1257)
        R.vm.kill_object(alloc1256)
        R.vm.kill_object(model_layers_14_mlp_down_proj_q_weight4)
        R.vm.kill_object(model_layers_14_mlp_down_proj_q_scale4)
        model_layers_15_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[176]
        gv1272: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1258: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1272, R.dtype("float16"))
        gv1273: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1259: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1273, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1257, alloc1254, model_layers_15_input_layernorm_weight4, alloc1258, alloc1259)
        R.vm.kill_object(alloc1254)
        R.vm.kill_object(alloc1257)
        R.vm.kill_object(model_layers_15_input_layernorm_weight4)
        lv58: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1258, alloc1259
        model_layers_15_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[167]
        model_layers_15_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[168]
        model_layers_15_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[169]
        gv1274: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2560), sinfo_args=(R.Shape(ndim=3),))
        alloc1260: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1274, R.dtype("float16"))
        cls.fused_dequantize1_fused_NT_matmul_add(model_layers_15_self_attn_c_attn_q_weight4, model_layers_15_self_attn_c_attn_q_scale4, alloc1258, model_layers_15_self_attn_c_attn_bias4, alloc1260)
        R.vm.kill_object(alloc1258)
        R.vm.kill_object(model_layers_15_self_attn_c_attn_q_weight4)
        R.vm.kill_object(model_layers_15_self_attn_c_attn_q_scale4)
        R.vm.kill_object(model_layers_15_self_attn_c_attn_bias4)
        gv1275: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape492: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1260, gv1275, sinfo_args=(R.Tensor((batch_size, 1, 20, 128), dtype="float16"),))
        R.vm.kill_object(alloc1260)
        gv1276: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        reshape493: R.Tensor((batch_size, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape492, gv1276, sinfo_args=(R.Tensor((batch_size, 20, 128), dtype="float16"),))
        R.vm.kill_object(reshape492)
        gv1277: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        alloc1261: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1277, R.dtype("float16"))
        _1008: R.Object = R.call_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape493, alloc1261)
        R.vm.kill_object(reshape493)
        gv1278: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape494: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1261, gv1278, sinfo_args=(R.Tensor((batch_size, 1, 16, 128), dtype="float16"),))
        R.vm.kill_object(alloc1261)
        gv1279: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        reshape495: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape494, gv1279, sinfo_args=(R.Tensor((batch_size, 1, 2048), dtype="float16"),))
        R.vm.kill_object(reshape494)
        model_layers_15_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[170]
        model_layers_15_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[171]
        gv1280: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1262: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1280, R.dtype("float16"))
        cls.fused_dequantize2_NT_matmul1(model_layers_15_self_attn_o_proj_q_weight4, model_layers_15_self_attn_o_proj_q_scale4, reshape495, alloc1262)
        R.vm.kill_object(reshape495)
        R.vm.kill_object(model_layers_15_self_attn_o_proj_q_weight4)
        R.vm.kill_object(model_layers_15_self_attn_o_proj_q_scale4)
        model_layers_15_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[177]
        gv1281: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1263: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1281, R.dtype("float16"))
        gv1282: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1264: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1282, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1262, alloc1259, model_layers_15_post_attention_layernorm_weight4, alloc1263, alloc1264)
        R.vm.kill_object(alloc1259)
        R.vm.kill_object(alloc1262)
        R.vm.kill_object(model_layers_15_post_attention_layernorm_weight4)
        lv60: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1263, alloc1264
        model_layers_15_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[172]
        model_layers_15_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[173]
        gv1283: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(22016), sinfo_args=(R.Shape(ndim=3),))
        alloc1265: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1283, R.dtype("float16"))
        cls.fused_dequantize3_NT_matmul2(model_layers_15_mlp_gate_up_proj_q_weight4, model_layers_15_mlp_gate_up_proj_q_scale4, alloc1263, alloc1265)
        R.vm.kill_object(alloc1263)
        R.vm.kill_object(model_layers_15_mlp_gate_up_proj_q_weight4)
        R.vm.kill_object(model_layers_15_mlp_gate_up_proj_q_scale4)
        gv1284: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(11008), sinfo_args=(R.Shape(ndim=3),))
        alloc1266: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1284, R.dtype("float16"))
        cls.fused_split_silu_multiply(alloc1265, alloc1266)
        R.vm.kill_object(alloc1265)
        model_layers_15_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[174]
        model_layers_15_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[175]
        gv1285: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1267: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1285, R.dtype("float16"))
        cls.fused_dequantize4_NT_matmul3(model_layers_15_mlp_down_proj_q_weight4, model_layers_15_mlp_down_proj_q_scale4, alloc1266, alloc1267)
        R.vm.kill_object(alloc1266)
        R.vm.kill_object(model_layers_15_mlp_down_proj_q_weight4)
        R.vm.kill_object(model_layers_15_mlp_down_proj_q_scale4)
        model_layers_16_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[187]
        gv1286: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1268: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1286, R.dtype("float16"))
        gv1287: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1269: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1287, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1267, alloc1264, model_layers_16_input_layernorm_weight4, alloc1268, alloc1269)
        R.vm.kill_object(alloc1264)
        R.vm.kill_object(alloc1267)
        R.vm.kill_object(model_layers_16_input_layernorm_weight4)
        lv62: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1268, alloc1269
        model_layers_16_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[178]
        model_layers_16_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[179]
        model_layers_16_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[180]
        gv1288: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2560), sinfo_args=(R.Shape(ndim=3),))
        alloc1270: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1288, R.dtype("float16"))
        cls.fused_dequantize1_fused_NT_matmul_add(model_layers_16_self_attn_c_attn_q_weight4, model_layers_16_self_attn_c_attn_q_scale4, alloc1268, model_layers_16_self_attn_c_attn_bias4, alloc1270)
        R.vm.kill_object(alloc1268)
        R.vm.kill_object(model_layers_16_self_attn_c_attn_q_weight4)
        R.vm.kill_object(model_layers_16_self_attn_c_attn_q_scale4)
        R.vm.kill_object(model_layers_16_self_attn_c_attn_bias4)
        gv1289: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape496: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1270, gv1289, sinfo_args=(R.Tensor((batch_size, 1, 20, 128), dtype="float16"),))
        R.vm.kill_object(alloc1270)
        gv1290: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        reshape497: R.Tensor((batch_size, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape496, gv1290, sinfo_args=(R.Tensor((batch_size, 20, 128), dtype="float16"),))
        R.vm.kill_object(reshape496)
        gv1291: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        alloc1271: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1291, R.dtype("float16"))
        _1016: R.Object = R.call_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1.0)), reshape497, alloc1271)
        R.vm.kill_object(reshape497)
        gv1292: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape498: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1271, gv1292, sinfo_args=(R.Tensor((batch_size, 1, 16, 128), dtype="float16"),))
        R.vm.kill_object(alloc1271)
        gv1293: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        reshape499: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape498, gv1293, sinfo_args=(R.Tensor((batch_size, 1, 2048), dtype="float16"),))
        R.vm.kill_object(reshape498)
        model_layers_16_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[181]
        model_layers_16_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[182]
        gv1294: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1272: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1294, R.dtype("float16"))
        cls.fused_dequantize2_NT_matmul1(model_layers_16_self_attn_o_proj_q_weight4, model_layers_16_self_attn_o_proj_q_scale4, reshape499, alloc1272)
        R.vm.kill_object(reshape499)
        R.vm.kill_object(model_layers_16_self_attn_o_proj_q_weight4)
        R.vm.kill_object(model_layers_16_self_attn_o_proj_q_scale4)
        model_layers_16_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[188]
        gv1295: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1273: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1295, R.dtype("float16"))
        gv1296: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1274: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1296, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1272, alloc1269, model_layers_16_post_attention_layernorm_weight4, alloc1273, alloc1274)
        R.vm.kill_object(alloc1269)
        R.vm.kill_object(alloc1272)
        R.vm.kill_object(model_layers_16_post_attention_layernorm_weight4)
        lv64: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1273, alloc1274
        model_layers_16_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[183]
        model_layers_16_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[184]
        gv1297: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(22016), sinfo_args=(R.Shape(ndim=3),))
        alloc1275: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1297, R.dtype("float16"))
        cls.fused_dequantize3_NT_matmul2(model_layers_16_mlp_gate_up_proj_q_weight4, model_layers_16_mlp_gate_up_proj_q_scale4, alloc1273, alloc1275)
        R.vm.kill_object(alloc1273)
        R.vm.kill_object(model_layers_16_mlp_gate_up_proj_q_weight4)
        R.vm.kill_object(model_layers_16_mlp_gate_up_proj_q_scale4)
        gv1298: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(11008), sinfo_args=(R.Shape(ndim=3),))
        alloc1276: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1298, R.dtype("float16"))
        cls.fused_split_silu_multiply(alloc1275, alloc1276)
        R.vm.kill_object(alloc1275)
        model_layers_16_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[185]
        model_layers_16_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[186]
        gv1299: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1277: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1299, R.dtype("float16"))
        cls.fused_dequantize4_NT_matmul3(model_layers_16_mlp_down_proj_q_weight4, model_layers_16_mlp_down_proj_q_scale4, alloc1276, alloc1277)
        R.vm.kill_object(alloc1276)
        R.vm.kill_object(model_layers_16_mlp_down_proj_q_weight4)
        R.vm.kill_object(model_layers_16_mlp_down_proj_q_scale4)
        model_layers_17_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[198]
        gv1300: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1278: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1300, R.dtype("float16"))
        gv1301: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1279: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1301, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1277, alloc1274, model_layers_17_input_layernorm_weight4, alloc1278, alloc1279)
        R.vm.kill_object(alloc1274)
        R.vm.kill_object(alloc1277)
        R.vm.kill_object(model_layers_17_input_layernorm_weight4)
        lv66: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1278, alloc1279
        model_layers_17_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[189]
        model_layers_17_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[190]
        model_layers_17_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[191]
        gv1302: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2560), sinfo_args=(R.Shape(ndim=3),))
        alloc1280: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1302, R.dtype("float16"))
        cls.fused_dequantize1_fused_NT_matmul_add(model_layers_17_self_attn_c_attn_q_weight4, model_layers_17_self_attn_c_attn_q_scale4, alloc1278, model_layers_17_self_attn_c_attn_bias4, alloc1280)
        R.vm.kill_object(alloc1278)
        R.vm.kill_object(model_layers_17_self_attn_c_attn_q_weight4)
        R.vm.kill_object(model_layers_17_self_attn_c_attn_q_scale4)
        R.vm.kill_object(model_layers_17_self_attn_c_attn_bias4)
        gv1303: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape500: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1280, gv1303, sinfo_args=(R.Tensor((batch_size, 1, 20, 128), dtype="float16"),))
        R.vm.kill_object(alloc1280)
        gv1304: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        reshape501: R.Tensor((batch_size, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape500, gv1304, sinfo_args=(R.Tensor((batch_size, 20, 128), dtype="float16"),))
        R.vm.kill_object(reshape500)
        gv1305: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        alloc1281: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1305, R.dtype("float16"))
        _1024: R.Object = R.call_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1.0)), reshape501, alloc1281)
        R.vm.kill_object(reshape501)
        gv1306: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape502: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1281, gv1306, sinfo_args=(R.Tensor((batch_size, 1, 16, 128), dtype="float16"),))
        R.vm.kill_object(alloc1281)
        gv1307: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        reshape503: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape502, gv1307, sinfo_args=(R.Tensor((batch_size, 1, 2048), dtype="float16"),))
        R.vm.kill_object(reshape502)
        model_layers_17_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[192]
        model_layers_17_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[193]
        gv1308: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1282: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1308, R.dtype("float16"))
        cls.fused_dequantize2_NT_matmul1(model_layers_17_self_attn_o_proj_q_weight4, model_layers_17_self_attn_o_proj_q_scale4, reshape503, alloc1282)
        R.vm.kill_object(reshape503)
        R.vm.kill_object(model_layers_17_self_attn_o_proj_q_weight4)
        R.vm.kill_object(model_layers_17_self_attn_o_proj_q_scale4)
        model_layers_17_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[199]
        gv1309: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1283: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1309, R.dtype("float16"))
        gv1310: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1284: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1310, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1282, alloc1279, model_layers_17_post_attention_layernorm_weight4, alloc1283, alloc1284)
        R.vm.kill_object(alloc1279)
        R.vm.kill_object(alloc1282)
        R.vm.kill_object(model_layers_17_post_attention_layernorm_weight4)
        lv68: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1283, alloc1284
        model_layers_17_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[194]
        model_layers_17_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[195]
        gv1311: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(22016), sinfo_args=(R.Shape(ndim=3),))
        alloc1285: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1311, R.dtype("float16"))
        cls.fused_dequantize3_NT_matmul2(model_layers_17_mlp_gate_up_proj_q_weight4, model_layers_17_mlp_gate_up_proj_q_scale4, alloc1283, alloc1285)
        R.vm.kill_object(alloc1283)
        R.vm.kill_object(model_layers_17_mlp_gate_up_proj_q_weight4)
        R.vm.kill_object(model_layers_17_mlp_gate_up_proj_q_scale4)
        gv1312: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(11008), sinfo_args=(R.Shape(ndim=3),))
        alloc1286: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1312, R.dtype("float16"))
        cls.fused_split_silu_multiply(alloc1285, alloc1286)
        R.vm.kill_object(alloc1285)
        model_layers_17_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[196]
        model_layers_17_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[197]
        gv1313: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1287: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1313, R.dtype("float16"))
        cls.fused_dequantize4_NT_matmul3(model_layers_17_mlp_down_proj_q_weight4, model_layers_17_mlp_down_proj_q_scale4, alloc1286, alloc1287)
        R.vm.kill_object(alloc1286)
        R.vm.kill_object(model_layers_17_mlp_down_proj_q_weight4)
        R.vm.kill_object(model_layers_17_mlp_down_proj_q_scale4)
        model_layers_18_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[209]
        gv1314: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1288: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1314, R.dtype("float16"))
        gv1315: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1289: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1315, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1287, alloc1284, model_layers_18_input_layernorm_weight4, alloc1288, alloc1289)
        R.vm.kill_object(alloc1284)
        R.vm.kill_object(alloc1287)
        R.vm.kill_object(model_layers_18_input_layernorm_weight4)
        lv70: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1288, alloc1289
        model_layers_18_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[200]
        model_layers_18_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[201]
        model_layers_18_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[202]
        gv1316: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2560), sinfo_args=(R.Shape(ndim=3),))
        alloc1290: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1316, R.dtype("float16"))
        cls.fused_dequantize1_fused_NT_matmul_add(model_layers_18_self_attn_c_attn_q_weight4, model_layers_18_self_attn_c_attn_q_scale4, alloc1288, model_layers_18_self_attn_c_attn_bias4, alloc1290)
        R.vm.kill_object(alloc1288)
        R.vm.kill_object(model_layers_18_self_attn_c_attn_q_weight4)
        R.vm.kill_object(model_layers_18_self_attn_c_attn_q_scale4)
        R.vm.kill_object(model_layers_18_self_attn_c_attn_bias4)
        gv1317: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape504: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1290, gv1317, sinfo_args=(R.Tensor((batch_size, 1, 20, 128), dtype="float16"),))
        R.vm.kill_object(alloc1290)
        gv1318: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        reshape505: R.Tensor((batch_size, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape504, gv1318, sinfo_args=(R.Tensor((batch_size, 20, 128), dtype="float16"),))
        R.vm.kill_object(reshape504)
        gv1319: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        alloc1291: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1319, R.dtype("float16"))
        _1032: R.Object = R.call_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1.0)), reshape505, alloc1291)
        R.vm.kill_object(reshape505)
        gv1320: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape506: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1291, gv1320, sinfo_args=(R.Tensor((batch_size, 1, 16, 128), dtype="float16"),))
        R.vm.kill_object(alloc1291)
        gv1321: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        reshape507: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape506, gv1321, sinfo_args=(R.Tensor((batch_size, 1, 2048), dtype="float16"),))
        R.vm.kill_object(reshape506)
        model_layers_18_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[203]
        model_layers_18_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[204]
        gv1322: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1292: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1322, R.dtype("float16"))
        cls.fused_dequantize2_NT_matmul1(model_layers_18_self_attn_o_proj_q_weight4, model_layers_18_self_attn_o_proj_q_scale4, reshape507, alloc1292)
        R.vm.kill_object(reshape507)
        R.vm.kill_object(model_layers_18_self_attn_o_proj_q_weight4)
        R.vm.kill_object(model_layers_18_self_attn_o_proj_q_scale4)
        model_layers_18_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[210]
        gv1323: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1293: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1323, R.dtype("float16"))
        gv1324: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1294: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1324, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1292, alloc1289, model_layers_18_post_attention_layernorm_weight4, alloc1293, alloc1294)
        R.vm.kill_object(alloc1289)
        R.vm.kill_object(alloc1292)
        R.vm.kill_object(model_layers_18_post_attention_layernorm_weight4)
        lv72: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1293, alloc1294
        model_layers_18_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[205]
        model_layers_18_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[206]
        gv1325: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(22016), sinfo_args=(R.Shape(ndim=3),))
        alloc1295: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1325, R.dtype("float16"))
        cls.fused_dequantize3_NT_matmul2(model_layers_18_mlp_gate_up_proj_q_weight4, model_layers_18_mlp_gate_up_proj_q_scale4, alloc1293, alloc1295)
        R.vm.kill_object(alloc1293)
        R.vm.kill_object(model_layers_18_mlp_gate_up_proj_q_weight4)
        R.vm.kill_object(model_layers_18_mlp_gate_up_proj_q_scale4)
        gv1326: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(11008), sinfo_args=(R.Shape(ndim=3),))
        alloc1296: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1326, R.dtype("float16"))
        cls.fused_split_silu_multiply(alloc1295, alloc1296)
        R.vm.kill_object(alloc1295)
        model_layers_18_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[207]
        model_layers_18_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[208]
        gv1327: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1297: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1327, R.dtype("float16"))
        cls.fused_dequantize4_NT_matmul3(model_layers_18_mlp_down_proj_q_weight4, model_layers_18_mlp_down_proj_q_scale4, alloc1296, alloc1297)
        R.vm.kill_object(alloc1296)
        R.vm.kill_object(model_layers_18_mlp_down_proj_q_weight4)
        R.vm.kill_object(model_layers_18_mlp_down_proj_q_scale4)
        model_layers_19_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[220]
        gv1328: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1298: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1328, R.dtype("float16"))
        gv1329: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1299: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1329, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1297, alloc1294, model_layers_19_input_layernorm_weight4, alloc1298, alloc1299)
        R.vm.kill_object(alloc1294)
        R.vm.kill_object(alloc1297)
        R.vm.kill_object(model_layers_19_input_layernorm_weight4)
        lv74: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1298, alloc1299
        model_layers_19_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[211]
        model_layers_19_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[212]
        model_layers_19_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[213]
        gv1330: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2560), sinfo_args=(R.Shape(ndim=3),))
        alloc1300: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1330, R.dtype("float16"))
        cls.fused_dequantize1_fused_NT_matmul_add(model_layers_19_self_attn_c_attn_q_weight4, model_layers_19_self_attn_c_attn_q_scale4, alloc1298, model_layers_19_self_attn_c_attn_bias4, alloc1300)
        R.vm.kill_object(alloc1298)
        R.vm.kill_object(model_layers_19_self_attn_c_attn_q_weight4)
        R.vm.kill_object(model_layers_19_self_attn_c_attn_q_scale4)
        R.vm.kill_object(model_layers_19_self_attn_c_attn_bias4)
        gv1331: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape508: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1300, gv1331, sinfo_args=(R.Tensor((batch_size, 1, 20, 128), dtype="float16"),))
        R.vm.kill_object(alloc1300)
        gv1332: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(20), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        reshape509: R.Tensor((batch_size, 20, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape508, gv1332, sinfo_args=(R.Tensor((batch_size, 20, 128), dtype="float16"),))
        R.vm.kill_object(reshape508)
        gv1333: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=3),))
        alloc1301: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1333, R.dtype("float16"))
        _1040: R.Object = R.call_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1.0)), reshape509, alloc1301)
        R.vm.kill_object(reshape509)
        gv1334: R.Shape(ndim=4) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(4), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(16), R.prim_value(0), R.prim_value(128), sinfo_args=(R.Shape(ndim=4),))
        reshape510: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.call_packed("vm.builtin.reshape", alloc1301, gv1334, sinfo_args=(R.Tensor((batch_size, 1, 16, 128), dtype="float16"),))
        R.vm.kill_object(alloc1301)
        gv1335: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        reshape511: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.call_packed("vm.builtin.reshape", reshape510, gv1335, sinfo_args=(R.Tensor((batch_size, 1, 2048), dtype="float16"),))
        R.vm.kill_object(reshape510)
        model_layers_19_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[214]
        model_layers_19_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[215]
        gv1336: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1302: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1336, R.dtype("float16"))
        cls.fused_dequantize2_NT_matmul1(model_layers_19_self_attn_o_proj_q_weight4, model_layers_19_self_attn_o_proj_q_scale4, reshape511, alloc1302)
        R.vm.kill_object(reshape511)
        R.vm.kill_object(model_layers_19_self_attn_o_proj_q_weight4)
        R.vm.kill_object(model_layers_19_self_attn_o_proj_q_scale4)
        model_layers_19_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[221]
        gv1337: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1303: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1337, R.dtype("float16"))
        gv1338: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1304: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1338, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1302, alloc1299, model_layers_19_post_attention_layernorm_weight4, alloc1303, alloc1304)
        R.vm.kill_object(alloc1299)
        R.vm.kill_object(alloc1302)
        R.vm.kill_object(model_layers_19_post_attention_layernorm_weight4)
        lv76: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1303, alloc1304
        model_layers_19_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[216]
        model_layers_19_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[217]
        gv1339: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(22016), sinfo_args=(R.Shape(ndim=3),))
        alloc1305: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1339, R.dtype("float16"))
        cls.fused_dequantize3_NT_matmul2(model_layers_19_mlp_gate_up_proj_q_weight4, model_layers_19_mlp_gate_up_proj_q_scale4, alloc1303, alloc1305)
        R.vm.kill_object(alloc1303)
        R.vm.kill_object(model_layers_19_mlp_gate_up_proj_q_weight4)
        R.vm.kill_object(model_layers_19_mlp_gate_up_proj_q_scale4)
        gv1340: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(11008), sinfo_args=(R.Shape(ndim=3),))
        alloc1306: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1340, R.dtype("float16"))
        cls.fused_split_silu_multiply(alloc1305, alloc1306)
        R.vm.kill_object(alloc1305)
        model_layers_19_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[218]
        model_layers_19_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[219]
        gv1341: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1307: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage38, R.prim_value(0), gv1341, R.dtype("float16"))
        cls.fused_dequantize4_NT_matmul3(model_layers_19_mlp_down_proj_q_weight4, model_layers_19_mlp_down_proj_q_scale4, alloc1306, alloc1307)
        R.vm.kill_object(alloc1306)
        R.vm.kill_object(model_layers_19_mlp_down_proj_q_weight4)
        R.vm.kill_object(model_layers_19_mlp_down_proj_q_scale4)
        model_layers_20_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[231]
        gv1342: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1308: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage37, R.prim_value(0), gv1342, R.dtype("float16"))
        gv1343: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2048), sinfo_args=(R.Shape(ndim=3),))
        alloc1309: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage35, R.prim_value(0), gv1343, R.dtype("float16"))
        cls.fuse_add_norm_decode(alloc1307, alloc1304, model_layers_20_input_layernorm_weight4, alloc1308, alloc1309)
        R.vm.kill_object(alloc1304)
        R.vm.kill_object(alloc1307)
        R.vm.kill_object(model_layers_20_input_layernorm_weight4)
        lv78: R.Tuple(R.Tensor(dtype="float16", ndim=3), R.Tensor(dtype="float16", ndim=3)) = alloc1308, alloc1309
        model_layers_20_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[222]
        model_layers_20_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[223]
        model_layers_20_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[224]
        gv1344: R.Shape(ndim=3) = R.call_packed("vm.builtin.make_shape", shape_heap, R.prim_value(3), R.prim_value(1), R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0), R.prim_value(2560), sinfo_args=(R.Shape(ndim=3),))
        alloc1310: R.Tensor(dtype="float16", ndim=3) = R.vm.alloc_tensor(storage36, R.prim_value(0), gv1344, R.dtype("float16"))
        cls.fused_dequantize1_fused_NT_matmul_add(model_layers_20_self_attn_c_attn_q_weight4, model_layers_20_self_attn_c_attn_q_scale4, alloc1308, model_layers_20_self_attn_c_attn_bias4, alloc1310)
        R.vm.kill_object(alloc1308)
        R.vm.kill_object(model_layers_20_self_attn_c_attn_q_weight4)
        R.vm.kill_object(model_layers_20_self_attn_c_attn_q_scale4)
        R.vm.kill