# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def NT_matmul(var_rms_norm219: T.handle, lv545: T.Buffer((T.int64(2560), T.int64(2048)), "float16"), var_NT_matmul: T.handle):
        # Matmul against a weight indexed [n, k] (i.e. the weight is used transposed),
        # over a symbolic batch:
        #   NT_matmul[b, 0, n] = sum_k rms_norm219[b, 0, k] * lv545[n, k]
        # Shapes: (batch_size, 1, 2048) x (2560, 2048) -> (batch_size, 1, 2560), all float16.
        # The running sum lives in the float16 output buffer (no wider accumulator).
        T.func_attr({"tir.noalias": T.bool(True)})
        # batch_size is a symbolic extent bound by the match_buffer shapes below.
        batch_size = T.int64()
        rms_norm219 = T.match_buffer(var_rms_norm219, (batch_size, T.int64(1), T.int64(2048)), "float16")
        NT_matmul = T.match_buffer(var_NT_matmul, (batch_size, T.int64(1), T.int64(2560)), "float16")
        # with T.block("root"):
        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(2560), T.int64(2048)):
            with T.block("NT_matmul"):
                # "SSSR": i0, i1, i2 are spatial (output) axes; k is the reduction axis.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm219[v_i0, v_i1, v_k], lv545[v_i2, v_k])
                T.writes(NT_matmul[v_i0, v_i1, v_i2])
                with T.init():
                    # Zero the output element before the first reduction step.
                    NT_matmul[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + rms_norm219[v_i0, v_i1, v_k] * lv545[v_i2, v_k]

    @T.prim_func(private=True)
    def NT_matmul1(var_reshape435: T.handle, lv547: T.Buffer((T.int64(2048), T.int64(2048)), "float16"), var_NT_matmul: T.handle):
        # Matmul against a weight indexed [n, k] (used transposed), over a symbolic batch:
        #   NT_matmul[b, 0, n] = sum_k reshape435[b, 0, k] * lv547[n, k]
        # Shapes: (batch_size, 1, 2048) x (2048, 2048) -> (batch_size, 1, 2048), all float16.
        # The running sum lives in the float16 output buffer (no wider accumulator).
        T.func_attr({"tir.noalias": T.bool(True)})
        # batch_size is a symbolic extent bound by the match_buffer shapes below.
        batch_size = T.int64()
        reshape435 = T.match_buffer(var_reshape435, (batch_size, T.int64(1), T.int64(2048)), "float16")
        NT_matmul = T.match_buffer(var_NT_matmul, (batch_size, T.int64(1), T.int64(2048)), "float16")
        # with T.block("root"):
        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(2048), T.int64(2048)):
            with T.block("NT_matmul"):
                # "SSSR": i0, i1, i2 are spatial (output) axes; k is the reduction axis.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(reshape435[v_i0, v_i1, v_k], lv547[v_i2, v_k])
                T.writes(NT_matmul[v_i0, v_i1, v_i2])
                with T.init():
                    # Zero the output element before the first reduction step.
                    NT_matmul[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + reshape435[v_i0, v_i1, v_k] * lv547[v_i2, v_k]

    @T.prim_func(private=True)
    def NT_matmul10(rms_norm73: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), lv183: T.Buffer((T.int64(2560), T.int64(2048)), "float16"), NT_matmul: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16")):
        # Fixed-shape (batch 1, length 1) matmul against a weight indexed [n, k]:
        #   NT_matmul[0, 0, n] = sum_k rms_norm73[0, 0, k] * lv183[n, k]
        # Shapes: (1, 1, 2048) x (2560, 2048) -> (1, 1, 2560), all float16.
        # The running sum lives in the float16 output buffer (no wider accumulator).
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(2048)):
            with T.block("NT_matmul"):
                # "SSSR": i0, i1, i2 are spatial (output) axes; k is the reduction axis.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm73[v_i0, v_i1, v_k], lv183[v_i2, v_k])
                T.writes(NT_matmul[v_i0, v_i1, v_i2])
                with T.init():
                    # Zero the output element before the first reduction step.
                    NT_matmul[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + rms_norm73[v_i0, v_i1, v_k] * lv183[v_i2, v_k]

    @T.prim_func(private=True)
    def NT_matmul11(reshape147: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), lv185: T.Buffer((T.int64(2048), T.int64(2048)), "float16"), NT_matmul: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")):
        # Fixed-shape (batch 1, length 1) matmul against a weight indexed [n, k]:
        #   NT_matmul[0, 0, n] = sum_k reshape147[0, 0, k] * lv185[n, k]
        # Shapes: (1, 1, 2048) x (2048, 2048) -> (1, 1, 2048), all float16.
        # The running sum lives in the float16 output buffer (no wider accumulator).
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2048), T.int64(2048)):
            with T.block("NT_matmul"):
                # "SSSR": i0, i1, i2 are spatial (output) axes; k is the reduction axis.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(reshape147[v_i0, v_i1, v_k], lv185[v_i2, v_k])
                T.writes(NT_matmul[v_i0, v_i1, v_i2])
                with T.init():
                    # Zero the output element before the first reduction step.
                    NT_matmul[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + reshape147[v_i0, v_i1, v_k] * lv185[v_i2, v_k]

    @T.prim_func(private=True)
    def NT_matmul12(rms_norm74: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), lv186: T.Buffer((T.int64(22016), T.int64(2048)), "float16"), NT_matmul: T.Buffer((T.int64(1), T.int64(1), T.int64(22016)), "float16")):
        # Fixed-shape (batch 1, length 1) matmul against a weight indexed [n, k]:
        #   NT_matmul[0, 0, n] = sum_k rms_norm74[0, 0, k] * lv186[n, k]
        # Shapes: (1, 1, 2048) x (22016, 2048) -> (1, 1, 22016), all float16.
        # The running sum lives in the float16 output buffer (no wider accumulator).
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(22016), T.int64(2048)):
            with T.block("NT_matmul"):
                # "SSSR": i0, i1, i2 are spatial (output) axes; k is the reduction axis.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm74[v_i0, v_i1, v_k], lv186[v_i2, v_k])
                T.writes(NT_matmul[v_i0, v_i1, v_i2])
                with T.init():
                    # Zero the output element before the first reduction step.
                    NT_matmul[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + rms_norm74[v_i0, v_i1, v_k] * lv186[v_i2, v_k]

    @T.prim_func(private=True)
    def NT_matmul13(mul36: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv187: T.Buffer((T.int64(2048), T.int64(11008)), "float16"), NT_matmul: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")):
        # Fixed-shape (batch 1, length 1) matmul against a weight indexed [n, k]:
        #   NT_matmul[0, 0, n] = sum_k mul36[0, 0, k] * lv187[n, k]
        # Shapes: (1, 1, 11008) x (2048, 11008) -> (1, 1, 2048), all float16.
        # The running sum lives in the float16 output buffer (no wider accumulator).
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2048), T.int64(11008)):
            with T.block("NT_matmul"):
                # "SSSR": i0, i1, i2 are spatial (output) axes; k is the reduction axis.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(mul36[v_i0, v_i1, v_k], lv187[v_i2, v_k])
                T.writes(NT_matmul[v_i0, v_i1, v_i2])
                with T.init():
                    # Zero the output element before the first reduction step.
                    NT_matmul[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + mul36[v_i0, v_i1, v_k] * lv187[v_i2, v_k]

    @T.prim_func(private=True)
    def NT_matmul14(rms_norm145: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), lv363: T.Buffer((T.int64(151936), T.int64(2048)), "float16"), NT_matmul: T.Buffer((T.int64(1), T.int64(1), T.int64(151936)), "float32")):
        # Fixed-shape (batch 1, length 1) projection 2048 -> 151936 against a weight
        # indexed [n, k] (used transposed):
        #   NT_matmul[0, 0, n] = sum_k float32(rms_norm145[0, 0, k]) * float32(lv363[n, k])
        # Inputs are float16; each product is upcast and accumulated in the float32 output.
        # Every term is a plain product with no per-k weighting. This matches the batched
        # variants NT_matmul4 / NT_matmul9, which compute the same projection.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(151936), T.int64(2048)):
            with T.block("NT_matmul"):
                # "SSSR": i0, i1, i2 are spatial (output) axes; k is the reduction axis.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm145[v_i0, v_i1, v_k], lv363[v_i2, v_k])
                T.writes(NT_matmul[v_i0, v_i1, v_i2])
                with T.init():
                    # Zero the float32 accumulator before the first reduction step.
                    NT_matmul[v_i0, v_i1, v_i2] = T.float32(0.0)
                NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + T.Cast("float32", rms_norm145[v_i0, v_i1, v_k]) * T.Cast("float32", lv363[v_i2, v_k])

    @T.prim_func(private=True)
    def NT_matmul2(var_rms_norm220: T.handle, lv548: T.Buffer((T.int64(22016), T.int64(2048)), "float16"), var_NT_matmul: T.handle):
        # Matmul against a weight indexed [n, k] (used transposed), over a symbolic batch:
        #   NT_matmul[b, 0, n] = sum_k rms_norm220[b, 0, k] * lv548[n, k]
        # Shapes: (batch_size, 1, 2048) x (22016, 2048) -> (batch_size, 1, 22016), all float16.
        # The running sum lives in the float16 output buffer (no wider accumulator).
        T.func_attr({"tir.noalias": T.bool(True)})
        # batch_size is a symbolic extent bound by the match_buffer shapes below.
        batch_size = T.int64()
        rms_norm220 = T.match_buffer(var_rms_norm220, (batch_size, T.int64(1), T.int64(2048)), "float16")
        NT_matmul = T.match_buffer(var_NT_matmul, (batch_size, T.int64(1), T.int64(22016)), "float16")
        # with T.block("root"):
        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(22016), T.int64(2048)):
            with T.block("NT_matmul"):
                # "SSSR": i0, i1, i2 are spatial (output) axes; k is the reduction axis.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm220[v_i0, v_i1, v_k], lv548[v_i2, v_k])
                T.writes(NT_matmul[v_i0, v_i1, v_i2])
                with T.init():
                    # Zero the output element before the first reduction step.
                    NT_matmul[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + rms_norm220[v_i0, v_i1, v_k] * lv548[v_i2, v_k]

    @T.prim_func(private=True)
    def NT_matmul3(var_mul108: T.handle, lv549: T.Buffer((T.int64(2048), T.int64(11008)), "float16"), var_NT_matmul: T.handle):
        # Matmul against a weight indexed [n, k] (used transposed), over a symbolic batch:
        #   NT_matmul[b, 0, n] = sum_k mul108[b, 0, k] * lv549[n, k]
        # Shapes: (batch_size, 1, 11008) x (2048, 11008) -> (batch_size, 1, 2048), all float16.
        # The running sum lives in the float16 output buffer (no wider accumulator).
        T.func_attr({"tir.noalias": T.bool(True)})
        # batch_size is a symbolic extent bound by the match_buffer shapes below.
        batch_size = T.int64()
        mul108 = T.match_buffer(var_mul108, (batch_size, T.int64(1), T.int64(11008)), "float16")
        NT_matmul = T.match_buffer(var_NT_matmul, (batch_size, T.int64(1), T.int64(2048)), "float16")
        # with T.block("root"):
        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(2048), T.int64(11008)):
            with T.block("NT_matmul"):
                # "SSSR": i0, i1, i2 are spatial (output) axes; k is the reduction axis.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(mul108[v_i0, v_i1, v_k], lv549[v_i2, v_k])
                T.writes(NT_matmul[v_i0, v_i1, v_i2])
                with T.init():
                    # Zero the output element before the first reduction step.
                    NT_matmul[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + mul108[v_i0, v_i1, v_k] * lv549[v_i2, v_k]

    @T.prim_func(private=True)
    def NT_matmul4(var_rms_norm291: T.handle, lv725: T.Buffer((T.int64(151936), T.int64(2048)), "float16"), var_NT_matmul: T.handle):
        # Projection 2048 -> 151936 against a weight indexed [n, k] (used transposed),
        # over a symbolic batch:
        #   NT_matmul[b, 0, n] = sum_k float32(rms_norm291[b, 0, k]) * float32(lv725[n, k])
        # Inputs are float16; each product is upcast and accumulated in a float32 output.
        T.func_attr({"tir.noalias": T.bool(True)})
        # batch_size is a symbolic extent bound by the match_buffer shapes below.
        batch_size = T.int64()
        rms_norm291 = T.match_buffer(var_rms_norm291, (batch_size, T.int64(1), T.int64(2048)), "float16")
        # No dtype argument, so the buffer takes TVM's default dtype (float32). This is
        # consistent with the T.float32 init and the float32 casts below.
        NT_matmul = T.match_buffer(var_NT_matmul, (batch_size, T.int64(1), T.int64(151936)))
        # with T.block("root"):
        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(151936), T.int64(2048)):
            with T.block("NT_matmul"):
                # "SSSR": i0, i1, i2 are spatial (output) axes; k is the reduction axis.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm291[v_i0, v_i1, v_k], lv725[v_i2, v_k])
                T.writes(NT_matmul[v_i0, v_i1, v_i2])
                with T.init():
                    # Zero the float32 accumulator before the first reduction step.
                    NT_matmul[v_i0, v_i1, v_i2] = T.float32(0.0)
                NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + T.Cast("float32", rms_norm291[v_i0, v_i1, v_k]) * T.Cast("float32", lv725[v_i2, v_k])

    @T.prim_func(private=True)
    def NT_matmul5(var_rms_norm146: T.handle, lv364: T.Buffer((T.int64(2560), T.int64(2048)), "float16"), var_NT_matmul: T.handle):
        # Matmul against a weight indexed [n, k] (used transposed), for batch 1 and a
        # symbolic sequence length:
        #   NT_matmul[0, s, n] = sum_k rms_norm146[0, s, k] * lv364[n, k]
        # Shapes: (1, seq_len, 2048) x (2560, 2048) -> (1, seq_len, 2560), all float16.
        # The running sum lives in the float16 output buffer (no wider accumulator).
        T.func_attr({"tir.noalias": T.bool(True)})
        # seq_len is a symbolic extent bound by the match_buffer shapes below.
        seq_len = T.int64()
        rms_norm146 = T.match_buffer(var_rms_norm146, (T.int64(1), seq_len, T.int64(2048)), "float16")
        NT_matmul = T.match_buffer(var_NT_matmul, (T.int64(1), seq_len, T.int64(2560)), "float16")
        # with T.block("root"):
        for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(2560), T.int64(2048)):
            with T.block("NT_matmul"):
                # "SSSR": i0, i1, i2 are spatial (output) axes; k is the reduction axis.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm146[v_i0, v_i1, v_k], lv364[v_i2, v_k])
                T.writes(NT_matmul[v_i0, v_i1, v_i2])
                with T.init():
                    # Zero the output element before the first reduction step.
                    NT_matmul[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + rms_norm146[v_i0, v_i1, v_k] * lv364[v_i2, v_k]

    @T.prim_func(private=True)
    def NT_matmul6(var_reshape291: T.handle, lv366: T.Buffer((T.int64(2048), T.int64(2048)), "float16"), var_NT_matmul: T.handle):
        # Matmul against a weight indexed [n, k] (used transposed), for batch 1 and a
        # symbolic sequence length:
        #   NT_matmul[0, s, n] = sum_k reshape291[0, s, k] * lv366[n, k]
        # Shapes: (1, seq_len, 2048) x (2048, 2048) -> (1, seq_len, 2048), all float16.
        # The running sum lives in the float16 output buffer (no wider accumulator).
        T.func_attr({"tir.noalias": T.bool(True)})
        # seq_len is a symbolic extent bound by the match_buffer shapes below.
        seq_len = T.int64()
        reshape291 = T.match_buffer(var_reshape291, (T.int64(1), seq_len, T.int64(2048)), "float16")
        NT_matmul = T.match_buffer(var_NT_matmul, (T.int64(1), seq_len, T.int64(2048)), "float16")
        # with T.block("root"):
        for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(2048), T.int64(2048)):
            with T.block("NT_matmul"):
                # "SSSR": i0, i1, i2 are spatial (output) axes; k is the reduction axis.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(reshape291[v_i0, v_i1, v_k], lv366[v_i2, v_k])
                T.writes(NT_matmul[v_i0, v_i1, v_i2])
                with T.init():
                    # Zero the output element before the first reduction step.
                    NT_matmul[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + reshape291[v_i0, v_i1, v_k] * lv366[v_i2, v_k]

    @T.prim_func(private=True)
    def NT_matmul7(var_rms_norm147: T.handle, lv367: T.Buffer((T.int64(22016), T.int64(2048)), "float16"), var_NT_matmul: T.handle):
        # Matmul against a weight indexed [n, k] (used transposed), for batch 1 and a
        # symbolic sequence length:
        #   NT_matmul[0, s, n] = sum_k rms_norm147[0, s, k] * lv367[n, k]
        # Shapes: (1, seq_len, 2048) x (22016, 2048) -> (1, seq_len, 22016), all float16.
        # The running sum lives in the float16 output buffer (no wider accumulator).
        T.func_attr({"tir.noalias": T.bool(True)})
        # seq_len is a symbolic extent bound by the match_buffer shapes below.
        seq_len = T.int64()
        rms_norm147 = T.match_buffer(var_rms_norm147, (T.int64(1), seq_len, T.int64(2048)), "float16")
        NT_matmul = T.match_buffer(var_NT_matmul, (T.int64(1), seq_len, T.int64(22016)), "float16")
        # with T.block("root"):
        for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(22016), T.int64(2048)):
            with T.block("NT_matmul"):
                # "SSSR": i0, i1, i2 are spatial (output) axes; k is the reduction axis.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm147[v_i0, v_i1, v_k], lv367[v_i2, v_k])
                T.writes(NT_matmul[v_i0, v_i1, v_i2])
                with T.init():
                    # Zero the output element before the first reduction step.
                    NT_matmul[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + rms_norm147[v_i0, v_i1, v_k] * lv367[v_i2, v_k]

    @T.prim_func(private=True)
    def NT_matmul8(var_mul72: T.handle, lv368: T.Buffer((T.int64(2048), T.int64(11008)), "float16"), var_NT_matmul: T.handle):
        # Matmul against a weight indexed [n, k] (used transposed), for batch 1 and a
        # symbolic sequence length:
        #   NT_matmul[0, s, n] = sum_k mul72[0, s, k] * lv368[n, k]
        # Shapes: (1, seq_len, 11008) x (2048, 11008) -> (1, seq_len, 2048), all float16.
        # The running sum lives in the float16 output buffer (no wider accumulator).
        T.func_attr({"tir.noalias": T.bool(True)})
        # seq_len is a symbolic extent bound by the match_buffer shapes below.
        seq_len = T.int64()
        mul72 = T.match_buffer(var_mul72, (T.int64(1), seq_len, T.int64(11008)), "float16")
        NT_matmul = T.match_buffer(var_NT_matmul, (T.int64(1), seq_len, T.int64(2048)), "float16")
        # with T.block("root"):
        for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(2048), T.int64(11008)):
            with T.block("NT_matmul"):
                # "SSSR": i0, i1, i2 are spatial (output) axes; k is the reduction axis.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(mul72[v_i0, v_i1, v_k], lv368[v_i2, v_k])
                T.writes(NT_matmul[v_i0, v_i1, v_i2])
                with T.init():
                    # Zero the output element before the first reduction step.
                    NT_matmul[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + mul72[v_i0, v_i1, v_k] * lv368[v_i2, v_k]

    @T.prim_func(private=True)
    def NT_matmul9(var_take1: T.handle, lv544: T.Buffer((T.int64(151936), T.int64(2048)), "float16"), var_NT_matmul: T.handle):
        # Projection 2048 -> 151936 against a weight indexed [n, k] (used transposed).
        # The symbolic extent is named batch_size but sits on axis 1 of a leading-1 layout:
        #   NT_matmul[0, b, n] = sum_k float32(take1[0, b, k]) * float32(lv544[n, k])
        # Inputs are float16; each product is upcast and accumulated in a float32 output.
        T.func_attr({"tir.noalias": T.bool(True)})
        # batch_size is a symbolic extent bound by the match_buffer shapes below.
        batch_size = T.int64()
        take1 = T.match_buffer(var_take1, (T.int64(1), batch_size, T.int64(2048)), "float16")
        # No dtype argument, so the buffer takes TVM's default dtype (float32). This is
        # consistent with the T.float32 init and the float32 casts below.
        NT_matmul = T.match_buffer(var_NT_matmul, (T.int64(1), batch_size, T.int64(151936)))
        # with T.block("root"):
        for i0, i1, i2, k in T.grid(T.int64(1), batch_size, T.int64(151936), T.int64(2048)):
            with T.block("NT_matmul"):
                # "SSSR": i0, i1, i2 are spatial (output) axes; k is the reduction axis.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(take1[v_i0, v_i1, v_k], lv544[v_i2, v_k])
                T.writes(NT_matmul[v_i0, v_i1, v_i2])
                with T.init():
                    # Zero the float32 accumulator before the first reduction step.
                    NT_matmul[v_i0, v_i1, v_i2] = T.float32(0.0)
                NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + T.Cast("float32", take1[v_i0, v_i1, v_k]) * T.Cast("float32", lv544[v_i2, v_k])

    @T.prim_func
    def apply_bitmask_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_bitmask: T.handle):
        # In-place token masking. For every listed sequence s and vocab id v:
        #   row = seq_ids[s]
        #   keep logits[row, v] if bit (v % 32) of bitmask[row, v // 32] is 1,
        #   otherwise overwrite it with -340282346638528859811704183484516925440.0
        #   (the most negative finite float32 value).
        # Both logits and bitmask are indexed by the same row seq_ids[s].
        # The kernel is already scheduled ("tir.is_scheduled") for the Vulkan target given in func_attr.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        # No dtype argument, so logits uses TVM's default dtype (float32).
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_seq = T.int32(is_size_var=True)
        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
        # One int32 word per 32 vocab ids: ceil(vocab_size / 32) words per row.
        bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), "int32")
        # with T.block("root"):
        # The (num_seq x vocab_size) index space is flattened to 1-D and split into blocks
        # of 1024 threads (ceil-divided). Each thread handles one (s, v) pair.
        for fused_s_v_0 in T.thread_binding((num_seq * vocab_size + 1023) // 1024, thread="blockIdx.x"):
            for fused_s_v_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    vs = T.axis.spatial(num_seq, (fused_s_v_0 * 1024 + fused_s_v_1) // vocab_size)
                    vv = T.axis.spatial(vocab_size, (fused_s_v_0 * 1024 + fused_s_v_1) % vocab_size)
                    # Skip threads past the end of the flattened range in the last block.
                    T.where(fused_s_v_0 * 1024 + fused_s_v_1 < num_seq * vocab_size)
                    T.reads(bitmask[seq_ids[vs], vv // 32], seq_ids[vs], logits[seq_ids[vs], vv])
                    T.writes(logits[seq_ids[vs], vv])
                    logits[seq_ids[vs], vv] = T.if_then_else(T.bitwise_and(T.shift_right(bitmask[seq_ids[vs], vv // 32], vv % 32), 1) == 1, logits[seq_ids[vs], vv], T.float32(-340282346638528859811704183484516925440.0))

    @T.prim_func
    def apply_logit_bias_inplace(var_logits: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_logit_bias: T.handle):
        # In-place sparse bias. For every entry p in [0, num_token):
        #   logits[pos2seq_id[p], token_ids[p]] += logit_bias[p]
        # The kernel is already scheduled ("tir.is_scheduled") for the Vulkan target given in func_attr.
        # NOTE(review): each thread does a plain (non-atomic) read-modify-write. If two
        # entries name the same (row, token) pair, one update can be lost. Presumably
        # callers guarantee unique pairs; confirm.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        # No dtype argument, so logits and logit_bias use TVM's default dtype (float32).
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_token = T.int32(is_size_var=True)
        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
        logit_bias = T.match_buffer(var_logit_bias, (num_token,))
        # with T.block("root"):
        # One thread per entry p, in blocks of 1024 threads (ceil-divided).
        for p0 in T.thread_binding((num_token + 1023) // 1024, thread="blockIdx.x"):
            for p1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    vp = T.axis.spatial(num_token, p0 * 1024 + p1)
                    # Skip threads past num_token in the last block.
                    T.where(p0 * 1024 + p1 < num_token)
                    T.reads(logits[pos2seq_id[vp], token_ids[vp]], pos2seq_id[vp], token_ids[vp], logit_bias[vp])
                    T.writes(logits[pos2seq_id[vp], token_ids[vp]])
                    logits[pos2seq_id[vp], token_ids[vp]] = logits[pos2seq_id[vp], token_ids[vp]] + logit_bias[vp]

    @T.prim_func
    def apply_penalty_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_token_cnt: T.handle, var_penalties: T.handle):
        # In-place penalties. For every entry p in [0, num_token), with
        #   i   = pos2seq_id[p]            (index into seq_ids and penalties)
        #   row = seq_ids[i]               (row of logits)
        #   x   = logits[row, token_ids[p]]
        # the kernel first computes
        #   x -= penalties[i, 0] + token_cnt[p] * penalties[i, 1]
        # and then x = x * penalties[i, 2] if x < 0, else x / penalties[i, 2].
        # Presumably the three columns are presence, frequency and repetition penalties;
        # confirm against the caller.
        # The kernel is already scheduled ("tir.is_scheduled") for the Vulkan target given in func_attr.
        # NOTE(review): non-atomic read-modify-write per thread. Presumably (row, token)
        # pairs are unique across entries; confirm.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        # No dtype argument, so logits and penalties use TVM's default dtype (float32).
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_seq = T.int32(is_size_var=True)
        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
        num_token = T.int32(is_size_var=True)
        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
        token_cnt = T.match_buffer(var_token_cnt, (num_token,), "int32")
        penalties = T.match_buffer(var_penalties, (num_seq, 3))
        # with T.block("root"):
        # One thread per entry p, in blocks of 1024 threads (ceil-divided).
        for p0 in T.thread_binding((num_token + 1023) // 1024, thread="blockIdx.x"):
            for p1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    vp = T.axis.spatial(num_token, p0 * 1024 + p1)
                    # Skip threads past num_token in the last block.
                    T.where(p0 * 1024 + p1 < num_token)
                    T.reads(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]], seq_ids[pos2seq_id[vp]], pos2seq_id[vp], token_ids[vp], penalties[pos2seq_id[vp], 0:3], token_cnt[vp])
                    T.writes(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]])
                    # Additive penalty: constant term plus a term scaled by the token count.
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] - (penalties[pos2seq_id[vp], 0] + T.Cast("float32", token_cnt[vp]) * penalties[pos2seq_id[vp], 1])
                    # Multiplicative penalty applied on the already-adjusted value: negative logits
                    # are multiplied and non-negative ones divided by penalties[i, 2].
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] < T.float32(0.0), logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] * penalties[pos2seq_id[vp], 2], logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] / penalties[pos2seq_id[vp], 2])

    @T.prim_func
    def batch_decode_paged_kv(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        # Single-query attention over a paged K/V store, for each of B batch entries.
        # What the code fixes:
        #   * Q / output are (B, 16, 128) float16: 16 query heads of width 128.
        #   * pages is (max_num_pages, 2, 2, 16, 128). Axis 1 selects K (0) or V (1);
        #     axis 2 is indexed by `by` (one of 2 K/V heads); axis 3 is the slot within a
        #     16-entry page.
        #   * Query head = by * 8 + bz * 8 + ty, so each `by` serves 8 query heads that
        #     share that K/V head.
        #   * Batch entry b owns page ids page_table_values[page_table_indptr[b] :
        #     page_table_indptr[b + 1]].
        #   * Its K/V length is (num_pages - 1) * 16 + length_info[b]. Presumably
        #     length_info[b] is the number of filled slots in the last page; confirm.
        # Softmax is computed in base 2: sm_scale = log2(e) / sqrt(128), so exp2() of the
        # scaled scores equals exp() of the natural-scale scores. lse is therefore written
        # as a base-2 log-sum-exp: st_m + log2(st_d).
        # When rotary_mode == 1, rotary embedding is applied on the fly to Q (position
        # q_rope_position[b]) and to K (position k_rope_pos_offset[b] + row).
        # `_0` is not referenced in the body.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        B = T.int32(is_size_var=True)
        Q = T.match_buffer(Q_handle, (B, 16, 128), "float16")
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(pages_handle, (max_num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (B,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1)
        output = T.match_buffer(output_handle, (B, 16, 128), "float16")
        # No dtype argument, so lse uses TVM's default dtype (float32).
        lse = T.match_buffer(lse_handle, (B, 16))
        # with T.block("root"):
        # 1 / sqrt(128) * log2(e); see the header comment.
        sm_scale: T.float32 = T.float32(0.12751743082459868)
        # Launch layout:
        #   blockIdx.x  = bx    -> batch entry
        #   blockIdx.y  = fused_by_bz (2 values) -> by = K/V head; bz is always 0 here
        #   threadIdx.y = ty    -> query head within the group of 8
        #   threadIdx.x = tx    -> 4-element slice [tx*4, tx*4+4) of the 128-wide head dim
        #   threadIdx.z = tz    -> which half of each 16-row K/V tile this thread scores
        for bx in T.thread_binding(B, thread="blockIdx.x"):
            for fused_by_bz in T.thread_binding(2, thread="blockIdx.y"):
                for ty in T.thread_binding(8, thread="threadIdx.y"):
                    for tx in T.thread_binding(32, thread="threadIdx.x"):
                        for tz in T.thread_binding(2, thread="threadIdx.z"):
                            with T.block("attn"):
                                T.reads(page_table_indptr[bx:bx + 2], length_info[bx], q_rope_position[bx], Q[bx, (fused_by_bz + 1) // 2 * 8 + ty, tx * 4 - 64:tx * 4 - 64 + 132])
                                T.writes(output[bx, fused_by_bz % 2 * 8 + fused_by_bz // 2 * 8 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 2 * 8 + fused_by_bz // 2 * 8 + ty])
                                # Per-thread registers ("local") and block-shared tiles ("shared").
                                # Buffers with no dtype argument are float32.
                                Q_local = T.alloc_buffer((4,), "float16", scope="local")
                                kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                                # One 16-row K/V tile: row tz * 8 + ty is loaded by threads (tz, ty, *).
                                K_smem = T.alloc_buffer((16, 128), "float16", scope="shared")
                                V_smem = T.alloc_buffer((16, 128), "float16", scope="shared")
                                # Per-tz partial results, merged after the main loop.
                                O_allreduce = T.alloc_buffer((2, 8, 128), scope="shared")
                                md_allreduce = T.alloc_buffer((2, 8, 2), scope="shared")
                                S_reduce_local = T.alloc_buffer((1,), scope="local")
                                t0 = T.alloc_buffer((1,), scope="local")
                                S_local = T.alloc_buffer((8,), scope="local")
                                QK_local = T.alloc_buffer((4,), scope="local")
                                V_local = T.alloc_buffer((4,), "float16", scope="local")
                                m_prev = T.alloc_buffer((1,), scope="local")
                                d_prev = T.alloc_buffer((1,), scope="local")
                                other_m = T.alloc_buffer((1,), scope="local")
                                other_d = T.alloc_buffer((1,), scope="local")
                                exp_mprev = T.alloc_buffer((1,), scope="local")
                                exp_otherm = T.alloc_buffer((1,), scope="local")
                                other_o = T.alloc_buffer((4,), scope="local")
                                # Online-softmax state: running max (st_m), running denominator (st_d),
                                # and running unnormalized output slice (O_local).
                                st_m = T.alloc_buffer((1,), scope="local")
                                st_d = T.alloc_buffer((1,), scope="local")
                                O_local = T.alloc_buffer((4,), scope="local")
                                by: T.int32 = fused_by_bz % 2
                                bz: T.int32 = fused_by_bz // 2
                                batch_idx: T.int32 = bx
                                cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx]
                                cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1]
                                # Number of valid K/V rows for this batch entry (0 if it owns no pages).
                                kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[batch_idx], 0)
                                # -50000 acts as "minus infinity" for scores. With st_d = 1 the initial state
                                # contributes ~exp2(-50000 - m) ~= 0 once a real score arrives. With no K/V
                                # rows at all, output = 0 and lse = -50000.
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                # Load this thread's 4-element slice of q, rotated when rotary_mode == 1.
                                # Rotation pairs dim d with d +/- 64 (negated partner for d < 64), with
                                # freq = pos * rope_scale / rope_theta ** ((2 * d % 128) / 128).
                                for vec in T.vectorized(4):
                                    freq = T.float32()
                                    Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec + 64] * T.float16(-1.0), Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec])
                                # Walk the K/V sequence in tiles of 16 rows (ceil-divided).
                                for iterator in range((kv_chunk_len[0] + 15) // 16):
                                    # This thread loads tile row tile_start_s, i.e. sequence row tile_start_g.
                                    tile_start_s: T.int32 = tz * 8 + ty
                                    tile_start_g: T.int32 = (iterator * 2 + tz) * 8 + ty
                                    for j in range(1):
                                        with T.block("KV_load"):
                                            T.reads()
                                            T.writes()
                                            row_g: T.int32 = tile_start_g + j
                                            if row_g < kv_chunk_len[0]:
                                                # Translate the sequence row to (page id, slot in page).
                                                seq_offset: T.int32 = row_g
                                                page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16]
                                                page_offset: T.int32 = seq_offset % 16
                                                for vec in T.vectorized(4):
                                                    freq = T.float32()
                                                    # K gets the same optional rotation as q, at position k_rope_pos_offset[b] + row.
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, pages[page_no, 0, by, page_offset, tx * 4 + vec + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, tx * 4 + vec])
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec]
                                            else:
                                                # Past the end of the sequence: zero-fill so later reads are defined.
                                                for vec in T.vectorized(4):
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                    # Make the whole tile visible to every thread in the block.
                                    T.tvm_storage_sync("shared")
                                    m_prev[0] = st_m[0]
                                    # Score the 8 tile rows owned by this tz against q.
                                    for j in range(8):
                                        for vec in T.vectorized(4):
                                            QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz * 8 + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale
                                        S_reduce_local[0] = T.float32(0.0)
                                        for vec in T.unroll(4):
                                            S_reduce_local[0] = S_reduce_local[0] + QK_local[vec]
                                        # Sum the 32 per-tx partial dot products (32 x 4 = 128 dims) into t0.
                                        with T.block("block_cross_thread"):
                                            T.reads(S_reduce_local[0])
                                            T.writes(t0[0])
                                            T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                            T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx)
                                        # Rows beyond kv_chunk_len keep the -50000 sentinel score.
                                        S_local[j] = T.float32(-50000.0)
                                        if (iterator * 2 + tz) * 8 + j < kv_chunk_len[0]:
                                            S_local[j] = t0[0]
                                        st_m[0] = T.max(st_m[0], S_local[j])
                                    # Rescale the previous denominator and output to the new running max.
                                    o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0])
                                    st_d[0] = st_d[0] * o_scale
                                    for j in range(8):
                                        S_local[j] = T.exp2(S_local[j] - st_m[0])
                                        st_d[0] = st_d[0] + S_local[j]
                                    for j in T.vectorized(4):
                                        O_local[j] = O_local[j] * o_scale
                                    # Accumulate p_j * V_j for this thread's 4 output dims.
                                    for j in range(8):
                                        for vec in T.vectorized(4):
                                            V_local[vec] = V_smem[tz * 8 + j, tx * 4 + vec]
                                        for vec in T.vectorized(4):
                                            O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * S_local[j]
                                # Publish each tz partial (O, m, d) to shared memory, then merge the two
                                # partials of the same (ty, tx) with the standard online-softmax combine.
                                for vec in T.vectorized(4):
                                    O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec]
                                md_allreduce[tz, ty, 0] = st_m[0]
                                md_allreduce[tz, ty, 1] = st_d[0]
                                T.tvm_storage_sync("shared")
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                for j in range(2):
                                    m_prev[0] = st_m[0]
                                    d_prev[0] = st_d[0]
                                    other_m[0] = md_allreduce[j, ty, 0]
                                    other_d[0] = md_allreduce[j, ty, 1]
                                    for vec in T.vectorized(4):
                                        other_o[vec] = O_allreduce[j, ty, tx * 4 + vec]
                                    st_m[0] = T.max(st_m[0], other_m[0])
                                    st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0])
                                    exp_mprev[0] = T.exp2(m_prev[0] - st_m[0])
                                    exp_otherm[0] = T.exp2(other_m[0] - st_m[0])
                                    for vec in T.vectorized(4):
                                        O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0]
                                # Normalize by the merged denominator and store (float16 output, float32 lse).
                                for vec in T.vectorized(4):
                                    O_local[vec] = O_local[vec] / st_d[0]
                                for vec in T.vectorized(4):
                                    output[batch_idx, by * 8 + bz * 8 + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec])
                                # Every (tx, tz) thread of a head writes the same lse value.
                                lse[batch_idx, by * 8 + bz * 8 + ty] = st_m[0] + T.log2(st_d[0])

    @T.prim_func
    def batch_decode_paged_kv_sliding_window(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        # Batched single-token decode attention over a paged KV cache (sliding-window variant).
        # Layout facts visible in this function:
        #   Q / output: (B, 16, 128) float16, indexed [batch, query head, head dim]; lse: (B, 16).
        #   pages: (max_num_pages, 2, 2, 16, 128); axis 1 selects K (0) or V (1), axis 2 is
        #   indexed by `by` (presumably the KV head), axis 3 is the in-page offset (page size 16),
        #   axis 4 is the 128-wide head dimension.
        #   page_table_indptr[b] .. page_table_indptr[b + 1] delimits sequence b's entries in
        #   page_table_values.
        #   length_info: (3, B); its three rows enter kv_chunk_len and the seq_offset remap below.
        #   Presumably row 0 = valid length of the last page, row 1 = sliding-window start offset,
        #   row 2 = number of attention-sink tokens -- TODO confirm against the caller.
        # When rotary_mode == 1, rotate-half RoPE is applied in-kernel to Q and K.
        # Scores use base-2 exponentials (T.exp2), so the stored lse is st_m + log2(st_d).
        # `_0` is not referenced in the body.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        B = T.int32(is_size_var=True)
        Q = T.match_buffer(Q_handle, (B, 16, 128), "float16")
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(pages_handle, (max_num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (3, B), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1)
        output = T.match_buffer(output_handle, (B, 16, 128), "float16")
        lse = T.match_buffer(lse_handle, (B, 16))
        # with T.block("root"):
        # 0.12751743082459868 == log2(e) / sqrt(128): the 1/sqrt(head_dim) softmax scale
        # pre-multiplied by log2(e), so the T.exp2 calls below yield natural-base exponentials.
        sm_scale: T.float32 = T.float32(0.12751743082459868)
        # Grid: blockIdx.x = batch row; blockIdx.y = fused (by, bz), where bz is always 0
        # because fused_by_bz only ranges over 2.
        # Block: threadIdx.x (32) x 4 vector lanes cover the 128-wide head dimension,
        # threadIdx.y (8) picks one of the 8 query heads served by KV head `by`
        # (presumably grouped-query attention, 16 query heads / 2 KV heads), and
        # threadIdx.z (2) splits the KV rows into two partial softmax states merged at the end.
        for bx in T.thread_binding(B, thread="blockIdx.x"):
            for fused_by_bz in T.thread_binding(2, thread="blockIdx.y"):
                for ty in T.thread_binding(8, thread="threadIdx.y"):
                    for tx in T.thread_binding(32, thread="threadIdx.x"):
                        for tz in T.thread_binding(2, thread="threadIdx.z"):
                            with T.block("attn"):
                                # These region annotations look auto-generated from the accesses below;
                                # the tx * 4 - 64 offset on Q covers the rotate-half partner read in RoPE.
                                T.reads(page_table_indptr[bx:bx + 2], length_info[0:3, bx], q_rope_position[bx], Q[bx, (fused_by_bz + 1) // 2 * 8 + ty, tx * 4 - 64:tx * 4 - 64 + 132])
                                T.writes(output[bx, fused_by_bz % 2 * 8 + fused_by_bz // 2 * 8 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 2 * 8 + fused_by_bz // 2 * 8 + ty])
                                Q_local = T.alloc_buffer((4,), "float16", scope="local")
                                kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                                # Per-iteration K/V tile: 16 rows = 2 tz groups x 8 rows (one row per ty).
                                K_smem = T.alloc_buffer((16, 128), "float16", scope="shared")
                                V_smem = T.alloc_buffer((16, 128), "float16", scope="shared")
                                # Exchange buffers for merging the two tz partials: [tz, ty, lane] and [tz, ty, (m, d)].
                                O_allreduce = T.alloc_buffer((2, 8, 128), scope="shared")
                                md_allreduce = T.alloc_buffer((2, 8, 2), scope="shared")
                                S_reduce_local = T.alloc_buffer((1,), scope="local")
                                t0 = T.alloc_buffer((1,), scope="local")
                                S_local = T.alloc_buffer((8,), scope="local")
                                QK_local = T.alloc_buffer((4,), scope="local")
                                V_local = T.alloc_buffer((4,), "float16", scope="local")
                                m_prev = T.alloc_buffer((1,), scope="local")
                                d_prev = T.alloc_buffer((1,), scope="local")
                                other_m = T.alloc_buffer((1,), scope="local")
                                other_d = T.alloc_buffer((1,), scope="local")
                                exp_mprev = T.alloc_buffer((1,), scope="local")
                                exp_otherm = T.alloc_buffer((1,), scope="local")
                                other_o = T.alloc_buffer((4,), scope="local")
                                st_m = T.alloc_buffer((1,), scope="local")
                                st_d = T.alloc_buffer((1,), scope="local")
                                O_local = T.alloc_buffer((4,), scope="local")
                                by: T.int32 = fused_by_bz % 2
                                bz: T.int32 = fused_by_bz // 2
                                batch_idx: T.int32 = bx
                                cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx]
                                cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1]
                                # Number of KV rows attended: every page but the last contributes 16 rows,
                                # plus length_info[0] - length_info[1] + length_info[2]; 0 when the
                                # sequence owns no pages.
                                kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, batch_idx] - length_info[1, batch_idx] + length_info[2, batch_idx], 0)
                                # Running online-softmax state for this thread's (query head, tz) partial:
                                # st_m = running max score, st_d = running denominator,
                                # O_local = unnormalized output for this thread's 4 head-dim lanes.
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                # Load 4 Q lanes. When rotary_mode == 1, apply rotate-half RoPE in fp32:
                                #   x[i] * cos(freq) + r[i] * sin(freq), r[i] = -x[i + 64] if i < 64 else x[i - 64],
                                #   freq = q_rope_position * rope_scale / rope_theta ** ((2 * i % 128) / 128).
                                for vec in T.vectorized(4):
                                    freq = T.float32()
                                    Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec + 64] * T.float16(-1.0), Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec])
                                # Each iteration consumes 16 logical KV rows: rows (iterator * 2 + tz) * 8 + [0, 8).
                                for iterator in range((kv_chunk_len[0] + 15) // 16):
                                    tile_start_s: T.int32 = tz * 8 + ty
                                    tile_start_g: T.int32 = (iterator * 2 + tz) * 8 + ty
                                    # Each thread stages one logical row (row_g) into shared row tz * 8 + ty;
                                    # rows at or beyond kv_chunk_len are zero-filled.
                                    for j in range(1):
                                        with T.block("KV_load"):
                                            T.reads()
                                            T.writes()
                                            row_g: T.int32 = tile_start_g + j
                                            if row_g < kv_chunk_len[0]:
                                                # Logical row -> position in the page list: the first length_info[2]
                                                # rows map to themselves; later rows are shifted by
                                                # length_info[1] - length_info[2] (presumably skipping tokens between
                                                # the sink and the sliding window -- confirm).
                                                seq_offset: T.int32 = T.if_then_else(row_g < length_info[2, batch_idx], row_g, row_g - length_info[2, batch_idx] + length_info[1, batch_idx])
                                                page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16]
                                                page_offset: T.int32 = seq_offset % 16
                                                for vec in T.vectorized(4):
                                                    freq = T.float32()
                                                    # K gets the same rotate-half RoPE as Q when rotary_mode == 1.
                                                    # NOTE(review): the RoPE position is k_rope_pos_offset + row_g (the logical
                                                    # row), not seq_offset; confirm this is intended under the window remap.
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, pages[page_no, 0, by, page_offset, tx * 4 + vec + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, tx * 4 + vec])
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec]
                                            else:
                                                for vec in T.vectorized(4):
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    m_prev[0] = st_m[0]
                                    # Scores for the 8 rows of this tz group: each thread dots its 4 lanes,
                                    # then tvm_thread_allreduce sums over the 32 tx threads (128 lanes) into t0.
                                    for j in range(8):
                                        for vec in T.vectorized(4):
                                            QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz * 8 + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale
                                        S_reduce_local[0] = T.float32(0.0)
                                        for vec in T.unroll(4):
                                            S_reduce_local[0] = S_reduce_local[0] + QK_local[vec]
                                        with T.block("block_cross_thread"):
                                            T.reads(S_reduce_local[0])
                                            T.writes(t0[0])
                                            T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                            T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx)
                                        # Rows at or beyond kv_chunk_len keep -50000.0 so exp2 drives them to ~0.
                                        S_local[j] = T.float32(-50000.0)
                                        if (iterator * 2 + tz) * 8 + j < kv_chunk_len[0]:
                                            S_local[j] = t0[0]
                                        st_m[0] = T.max(st_m[0], S_local[j])
                                    # Online-softmax rescale: scale the previous denominator and output by
                                    # exp2(m_prev - m_new), then add this tile's exponentiated scores.
                                    o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0])
                                    st_d[0] = st_d[0] * o_scale
                                    for j in range(8):
                                        S_local[j] = T.exp2(S_local[j] - st_m[0])
                                        st_d[0] = st_d[0] + S_local[j]
                                    for j in T.vectorized(4):
                                        O_local[j] = O_local[j] * o_scale
                                    # O += p_j * V_j over the 8 rows of this tz group.
                                    # NOTE(review): no explicit barrier separates these V_smem reads from the
                                    # K_smem/V_smem writes at the start of the next iteration; presumably a
                                    # later TVM sync-insertion pass covers it -- confirm.
                                    for j in range(8):
                                        for vec in T.vectorized(4):
                                            V_local[vec] = V_smem[tz * 8 + j, tx * 4 + vec]
                                        for vec in T.vectorized(4):
                                            O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * S_local[j]
                                # Publish this tz group's partial (O, m, d) to shared memory, then every
                                # thread merges the 2 partials for its query head with the usual
                                # max-rescaled combination.
                                for vec in T.vectorized(4):
                                    O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec]
                                md_allreduce[tz, ty, 0] = st_m[0]
                                md_allreduce[tz, ty, 1] = st_d[0]
                                T.tvm_storage_sync("shared")
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                for j in range(2):
                                    m_prev[0] = st_m[0]
                                    d_prev[0] = st_d[0]
                                    other_m[0] = md_allreduce[j, ty, 0]
                                    other_d[0] = md_allreduce[j, ty, 1]
                                    for vec in T.vectorized(4):
                                        other_o[vec] = O_allreduce[j, ty, tx * 4 + vec]
                                    st_m[0] = T.max(st_m[0], other_m[0])
                                    st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0])
                                    exp_mprev[0] = T.exp2(m_prev[0] - st_m[0])
                                    exp_otherm[0] = T.exp2(other_m[0] - st_m[0])
                                    for vec in T.vectorized(4):
                                        O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0]
                                # Normalize by the merged denominator and write fp16 output.
                                for vec in T.vectorized(4):
                                    O_local[vec] = O_local[vec] / st_d[0]
                                for vec in T.vectorized(4):
                                    output[batch_idx, by * 8 + bz * 8 + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec])
                                # Log-sum-exp of the (base-2 scaled) scores, stored in base 2.
                                lse[batch_idx, by * 8 + bz * 8 + ty] = st_m[0] + T.log2(st_d[0])

    @T.prim_func
    def batch_prefill_paged_kv(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (batch_size,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (total_len, 16))
        # with T.block("root"):
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[b_idx], 0)
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_0, lj_0_0 in T.grid(2, 4):
                                                for lj_1 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, li_0 * 16 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 8)
                                                        j = T.axis.spatial(128, lj_0_0 * 32 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 8 * 4 + lj_1)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                freq = T.float32()
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, pages[page_no, 0, by, page_offset, j + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, j - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, j])
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(2):
                                                        for lj_1_0_init in T.unroll(1):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("S_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                                    j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(S_local[i, j])
                                                                    S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0 in range(8):
                                                        for li_1 in T.unroll(2):
                                                            for lj_1_0 in T.unroll(1):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    for lk_1 in range(16):
                                                                        with T.block("S_gemm_update"):
                                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                            k = T.axis.reduce(128, lk_0 * 16 + lk_1)
                                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                                            T.writes(S_local[i, j])
                                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1 in range(2):
                                                    for lj_1_0 in T.unroll(1):
                                                        for lj_1_1 in T.vectorized(4):
                                                            with T.block("S_store"):
                                                                i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                T.reads(S_local[i, j])
                                                                T.writes(S_smem[i, j])
                                                                S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(4):
                                                        for lj_1_0_init in T.unroll(2):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("O_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(O_local[i, j])
                                                                    O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1 in T.grid(2, 16):
                                                        for li_1 in T.unroll(4):
                                                            for lj_1_0 in T.unroll(2):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    with T.block("O_gemm_update"):
                                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                                        j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                                        k = T.axis.reduce(32, lk_0 * 16 + lk_1)
                                                                        T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j])
                                                                        T.writes(O_local[i, j])
                                                                        O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_store"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                            T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                            cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                            if cur_L < q_indptr[b_idx + 1]:
                                                                output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func
    def batch_prefill_paged_kv_sliding_window(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (3, batch_size), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (total_len, 16))
        # with T.block("root"):
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, b_idx] - length_info[1, b_idx] + length_info[2, b_idx], 0)
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_0, lj_0_0 in T.grid(2, 4):
                                                for lj_1 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, li_0 * 16 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 8)
                                                        j = T.axis.spatial(128, lj_0_0 * 32 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 8 * 4 + lj_1)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, b_idx] + length_info[1, b_idx])
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                freq = T.float32()
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, pages[page_no, 0, by, page_offset, j + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, j - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, j])
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, b_idx] + length_info[1, b_idx])
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(2):
                                                        for lj_1_0_init in T.unroll(1):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("S_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                                    j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(S_local[i, j])
                                                                    S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0 in range(8):
                                                        for li_1 in T.unroll(2):
                                                            for lj_1_0 in T.unroll(1):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    for lk_1 in range(16):
                                                                        with T.block("S_gemm_update"):
                                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                            k = T.axis.reduce(128, lk_0 * 16 + lk_1)
                                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                                            T.writes(S_local[i, j])
                                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1 in range(2):
                                                    for lj_1_0 in T.unroll(1):
                                                        for lj_1_1 in T.vectorized(4):
                                                            with T.block("S_store"):
                                                                i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                T.reads(S_local[i, j])
                                                                T.writes(S_smem[i, j])
                                                                S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(4):
                                                        for lj_1_0_init in T.unroll(2):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("O_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(O_local[i, j])
                                                                    O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1 in T.grid(2, 16):
                                                        for li_1 in T.unroll(4):
                                                            for lj_1_0 in T.unroll(2):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    with T.block("O_gemm_update"):
                                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                                        j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                                        k = T.axis.reduce(32, lk_0 * 16 + lk_1)
                                                                        T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j])
                                                                        T.writes(O_local[i, j])
                                                                        O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_store"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                            T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                            cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                            if cur_L < q_indptr[b_idx + 1]:
                                                                output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func
    def batch_prefill_ragged_kv(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_k_rope_pos_offset: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        # Batched prefill attention where Q and K/V of every sequence are stored
        # contiguously and addressed through q_indptr / kv_indptr (ragged layout).
        #
        # Buffer layout, as fixed by the match_buffer calls below:
        #   q, output : (qo_len, 16, 128) float16  -- 16 query heads, head dim 128
        #   k, v      : (kv_len, 2, 128) float16   -- 2 KV heads; blockIdx.y (`by`)
        #               picks the KV head and query heads by*8 .. by*8+7 share it
        #   lse       : (qo_len, 16), default dtype -- receives m + log2(d)
        #   q_indptr / kv_indptr : (batch_size + 1,) start offsets of each sequence
        #   q_rope_position      : (qo_len,) RoPE position of each query token
        #   k_rope_pos_offset    : (batch_size,) RoPE position of each sequence's key 0
        # Scalars:
        #   causal > 0          -> causal mask (see the update blocks below)
        #   rotary_mode == 1    -> RoPE applied to Q and K while they are loaded
        #   attn_score_scaling_factor multiplies every Q.K score
        #
        # Schedule: grid 16 (x) * 2 (y), block 32 (x) * 4 (y) = 128 threads.
        # Each sequence contributes q_len * 8 "rows" (query position x query head
        # in the KV group), processed in 32-row tiles. Block bx handles tiles
        # bx, bx + 16, bx + 32, ... walking through the sequences in order.
        # Per tile, an online softmax in base 2 runs over 32-key chunks:
        # m = running row max, d = running denominator, O = unnormalized output.
        #
        # NOTE(review): this looks like printed/generated TVMScript
        # ("tir.is_scheduled"); fixes probably belong in the generator -- confirm
        # before hand-editing. Barrier placement below is load-bearing.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        qo_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (qo_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        kv_len = T.int32(is_size_var=True)
        k = T.match_buffer(var_k, (kv_len, 2, 128), "float16")
        v = T.match_buffer(var_v, (kv_len, 2, 128), "float16")
        kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (qo_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (qo_len, 16))
        # with T.block("root"):
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            # Per-thread scheduler state (scope="local").
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            # NOTE(review): `iterator` is never used; the KV loop below uses `iterator_1`.
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            # Shared tiles: Q[row, dim], K transposed as K[dim, key], V[key, dim],
                            # scores / probabilities S[row, key].
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((128, 32), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            # Per-row online-softmax state shared across the block.
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            # Start at tile `bx` of sequence 0; a sequence has q_len * 8 rows.
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            # Loop conditions are wrapped in tvm_thread_invariant, which presumably
                            # tells codegen every thread takes the same path (the body contains barriers).
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                # Skip whole sequences until tile_id falls inside the current one.
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    # First flattened row (query position * 8 + head-in-group) of this tile.
                                    LH_start: T.int32 = tile_id[0] * 32
                                    # Number of keys in this sequence's whole KV segment.
                                    kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
                                    # The previous tile's O_store / lse_store read m_smem and d_smem;
                                    # wait for them before re-initializing.
                                    T.tvm_storage_sync("shared")
                                    # One thread per row (row = ty * 32 + tx, only rows < 32 act):
                                    # m = -50000 serves as -inf. d starts at 1.0, not 0: the first
                                    # rescale exp2(m_prev - m_new) normally shrinks it to ~0, and if
                                    # kv_chunk_len == 0 the final O / d divides by 1 instead of 0.
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    # Zero the output accumulator; each thread owns a 4 x 8 slice of the
                                    # 32 x 128 tile (same mapping as O_gemm and O_store below).
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Load the 32 x 128 Q tile. Tile row i maps to query token
                                    # q_indptr_val + (LH_start + i) // 8 and query head by * 8 + (LH_start + i) % 8.
                                    # Rows past this sequence are zero-filled. With rotary_mode == 1, RoPE is
                                    # applied in rotate-half form (dim j paired with j +/- 64).
                                    # NOTE(review): Q_load and V_load declare empty T.reads/T.writes while
                                    # K_load declares its regions -- generator inconsistency, confirm intended.
                                    for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_0, lj_0_0 in T.grid(2, 4):
                                                for lj_1 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, li_0 * 16 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 8)
                                                        j = T.axis.spatial(128, lj_0_0 * 32 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 8 * 4 + lj_1)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Walk this sequence's keys in chunks of 32.
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        L_kv_base: T.int32 = kv_indptr[b_idx]
                                        # Load the K chunk transposed into K_smem[dim, key]. Keys past kv_chunk_len
                                        # are zero-filled. RoPE position is k_rope_pos_offset[b_idx] + key index.
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            # Read region spans j - 64 .. j + 64 to cover the RoPE partner dim.
                                                            T.reads(kv_chunk_len[0], k_rope_pos_offset[b_idx], k[L_kv_base + L_kv_start + i, by, j - 64:j - 64 + 129])
                                                            T.writes(K_smem[j, i])
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                freq = T.float32()
                                                                K_smem[j, i] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", k[L_kv_base + cur_L, by, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, k[L_kv_base + cur_L, by, j + 64] * T.float16(-1.0), k[L_kv_base + cur_L, by, j - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), k[L_kv_base + cur_L, by, j])
                                                            else:
                                                                K_smem[j, i] = T.float16(0.0)
                                        # This barrier also guarantees every thread has finished the previous
                                        # chunk's O_gemm (which reads V_smem) before V_smem is overwritten.
                                        T.tvm_storage_sync("shared")
                                        # Load the V chunk into V_smem[key, dim], zero-filling past kv_chunk_len.
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                V_smem[i, j] = v[L_kv_base + cur_L, by, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        # S = Q @ K^T * attn_score_scaling_factor * 0.1275...; the constant equals
                                        # log2(e) / sqrt(128): 1/sqrt(head_dim) folded with the base change so the
                                        # softmax below can use exp2. Each thread computes a 2 x 4 slice of S.
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:128, 0:32])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(2):
                                                        for lj_1_0_init in T.unroll(1):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("S_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                                    j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(S_local[i, j])
                                                                    S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0 in range(8):
                                                        for li_1 in T.unroll(2):
                                                            for lj_1_0 in T.unroll(1):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    for lk_1 in range(16):
                                                                        with T.block("S_gemm_update"):
                                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                            k_1 = T.axis.reduce(128, lk_0 * 16 + lk_1)
                                                                            T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[k_1, j])
                                                                            T.writes(S_local[i, j])
                                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * T.Cast("float32", K_smem[k_1, j]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        # Publish scores to shared memory so one thread can reduce each full row.
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1 in range(2):
                                                    for lj_1_0 in T.unroll(1):
                                                        for lj_1_1 in T.vectorized(4):
                                                            with T.block("S_store"):
                                                                i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                T.reads(S_local[i, j])
                                                                T.writes(S_smem[i, j])
                                                                S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        # Online softmax, one thread per row (only threads with ty == 0 have row < 32).
                                        # Step 1: new running max over this chunk's unmasked keys, and rescale the
                                        # old denominator by exp2(m_prev - m_new). row_ is the query position
                                        # within the sequence. Mask:
                                        #   causal > 0 : key index < kv_len - q_len + row_ + 1
                                        #   otherwise  : key index < kv_len (drops chunk padding only)
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        # Step 2: overwrite S with probabilities exp2(score - m_new). Masked entries
                                        # get exp2(-50000 - m_new), which underflows to 0 once m_new is a real score.
                                        # NOTE(review): if a row has seen no unmasked key yet, m_new is still -50000 and
                                        # masked entries become exp2(0) = 1, polluting d and O. This appears to rely on
                                        # every stored row having an unmasked key in its first chunk (e.g. kv_len >= q_len
                                        # under causal) -- confirm.
                                        # NOTE(review): unlike the neighbouring blocks, the `row < 32` guard is inside
                                        # this block, so its declared S_smem[row, 0:32] regions name rows >= 32 for
                                        # threads with ty > 0. The actual accesses are guarded.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        # Step 3: add this chunk's probabilities to the denominator and publish the
                                        # new max, the denominator and the previous max for the O rescale below.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        # O = O * exp2(m_prev - m_new) + P @ V for this chunk. The "O_gemm_init" step
                                        # rescales the existing accumulator to the new max rather than zeroing it.
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(4):
                                                        for lj_1_0_init in T.unroll(2):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("O_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(O_local[i, j])
                                                                    O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1 in T.grid(2, 16):
                                                        for li_1 in T.unroll(4):
                                                            for lj_1_0 in T.unroll(2):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    with T.block("O_gemm_update"):
                                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                                        j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                                        k_1 = T.axis.reduce(32, lk_0 * 16 + lk_1)
                                                                        T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j])
                                                                        T.writes(O_local[i, j])
                                                                        O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j])
                                    # After all chunks: write O / d (cast to float16) for rows that belong to this
                                    # sequence. Tile rows past the sequence end are skipped.
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_store"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                            T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                            cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                            if cur_L < q_indptr[b_idx + 1]:
                                                                output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    # lse = m + log2(d) per row. T.where restricts this to the first 32 threads.
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    # Advance to this block's next tile; the stride is the 16 blockIdx.x blocks.
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func
    def batch_tree_attn(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_mn_indptr: T.handle, var_mask: T.handle, var_output: T.handle, var_lse: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32, batch_size: T.int32):
        # Tree-masked attention over a batch of ragged sequences, already scheduled
        # for the Vulkan target. KV is consumed in chunks of 32 positions while a
        # running max (m) and running denominator (d) are kept per output row;
        # all exponentials/logs are base 2 (T.exp2 / T.log2).
        #
        # Buffers (shapes from the match_buffer calls below):
        #   q, output:       (qo_len, 16, 128) float16.
        #   k, v:            (kv_len, 2, 128) float16. blockIdx.y (`by`) selects KV
        #                    head `by`, which serves query heads by * 8 .. by * 8 + 7.
        #   q_indptr, kv_indptr, mn_indptr: (batch_size + 1,) int32; entries b and
        #                    b + 1 bound batch element b's rows in q/output/lse,
        #                    k/v, and mask respectively.
        #   q_rope_position: (qo_len,) int32 positions fed to the rotary embedding.
        #   mask:            (tree_size, 2) int32; used by the visibility test in
        #                    block "update1" / "update".
        #   lse:             (qo_len, 16), default dtype (float32); gets m + log2(d).
        # Scalars: rotary embedding is applied to Q and K only when rotary_mode == 1;
        # attn_score_scaling_factor multiplies every Q.K score.
        #
        # Work decomposition: a "row" is one (query token, query head) pair. A batch
        # element with n query tokens has n * 8 rows, split into ceil(n * 8 / 32)
        # tiles of 32 rows. Row LH_start + i of a tile is token
        # q_indptr[b] + (LH_start + i) // 8 and head by * 8 + (LH_start + i) % 8.
        # Block bx (of 16 along blockIdx.x) processes the flattened tile list at
        # tiles bx, bx + 16, bx + 32, ...
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        qo_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (qo_len, 16, 128), "float16")
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        kv_len = T.int32(is_size_var=True)
        k = T.match_buffer(var_k, (kv_len, 2, 128), "float16")
        v = T.match_buffer(var_v, (kv_len, 2, 128), "float16")
        kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1)
        mn_indptr = T.match_buffer(var_mn_indptr, (batch_size + 1,), "int32", offset_factor=1)
        tree_size = T.int32(is_size_var=True)
        mask = T.match_buffer(var_mask, (tree_size, 2), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (qo_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (qo_len, 16))
        # with T.block("root"):
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            # Start at this block's first tile, measured from batch element 0.
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                # Skip whole batch elements until tile_id lands inside one;
                                # tile_id becomes relative to that batch element.
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
                                    T.tvm_storage_sync("shared")
                                    # Reset the running max / denominator of the tile's 32 rows
                                    # (row < 32 holds only for threads with ty == 0).
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    # Zero the output accumulator.
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_init"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads()
                                                    T.writes(O_local[i, j])
                                                    O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Load the 32 x 128 Q tile. When rotary_mode == 1, element j is
                                    # cos(freq) * q[j] + sin(freq) * (j < 64 ? -q[j + 64] : q[j - 64])
                                    # with freq = pos * rope_scale / rope_theta ** ((j * 2 % 128) / 128).
                                    # Rows past the end of this batch element are zero-filled.
                                    for li_lj_fused_0 in range(8):
                                        for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_lj_fused_3 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 128)
                                                        j = T.axis.spatial(128, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 128)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Walk this batch element's KV range in chunks of 32 positions.
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        L_kv_base: T.int32 = kv_indptr[b_idx]
                                        # Load K (rotated like Q when rotary_mode == 1) and V for KV head `by`;
                                        # positions past kv_chunk_len are zero-filled.
                                        # NOTE(review): K's rotary position is read from q_rope_position[cur_L],
                                        # where cur_L is a k/v row index but q_rope_position has qo_len entries.
                                        # Confirm that callers guarantee these indices line up (e.g. kv_len <= qo_len).
                                        for lz_ly_fused_0 in range(8):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("KV_load"):
                                                            i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 128)
                                                            j = T.axis.spatial(128, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 128)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_base + L_kv_start + i
                                                            if L_kv_start + i < kv_chunk_len[0]:
                                                                freq = T.float32()
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", k[cur_L, by, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, k[cur_L, by, j + 64] * T.float16(-1.0), k[cur_L, by, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), k[cur_L, by, j])
                                                                V_smem[i, j] = v[cur_L, by, j]
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        # S = Q K^T * attn_score_scaling_factor * 0.12751743082459868.
                                        # The constant equals log2(e) / sqrt(128), which matches the exp2 / log2 used below.
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(2, 4):
                                                        with T.block("S_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_init)
                                                            T.reads()
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, li_1, lj_1, lk_1 in T.grid(16, 2, 4, 8):
                                                        with T.block("S_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                            k_1 = T.axis.reduce(128, lk_0 * 8 + lk_1)
                                                            T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[j, k_1])
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * T.Cast("float32", K_smem[j, k_1]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        # Copy the scores to shared memory so the thread owning each row can read all 32 columns.
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1, lj_1 in T.grid(2, 4):
                                                    with T.block("S_store"):
                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                        j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                        T.reads(S_local[i, j])
                                                        T.writes(S_smem[i, j])
                                                        S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        # Step 1, one thread per row (ty == 0): m_new = max(m_prev, visible scores
                                        # of this chunk) and d_new = d * exp2(m_prev - m_new).
                                        # Visibility of KV position p = L_kv_start + j, with tree_n = mn_indptr[b + 1] - mn_indptr[b]:
                                        #   p < kv_chunk_len  and  (p < kv_chunk_len - tree_n  or
                                        #   mask[q_node, 0] in [mask[kv_node, 0], mask[kv_node, 1])),
                                        # where q_node = mn_indptr[b] + row_ + tree_n - (query count of b) and
                                        # kv_node = mn_indptr[b] + p - (kv_chunk_len - tree_n).
                                        # NOTE(review): for padding rows or query rows before the last tree_n
                                        # tokens, q_node falls outside [mn_indptr[b], mn_indptr[b + 1]).
                                        # Whether that mask read happens depends on and/or evaluation in
                                        # codegen, so confirm the indices stay in bounds for real inputs.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], mn_indptr[b_idx:b_idx + 2], mask[T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]):T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) or mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 0] and mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 1]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        # Step 2: replace each score with exp2(score - m_new). Invisible positions
                                        # are scored as -50000.0.
                                        # NOTE(review): while m is still the initial -50000.0 (no visible key seen
                                        # yet), invisible positions evaluate to exp2(0) = 1. Those terms are only
                                        # rescaled away once a visible key raises m.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], mn_indptr[b_idx:b_idx + 2], mask[T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]):T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) or mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 0] and mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 1]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        # Step 3: add this chunk's exp2 terms to d_new, then publish m_new, d_new and
                                        # the previous max (m_prev_smem) for the O rescale below.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        # O = O * exp2(m_prev - m_new) + S @ V for this chunk.
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(4, 8):
                                                        with T.block("O_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_init)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1, li_1, lj_1 in T.grid(4, 8, 4, 8):
                                                        with T.block("O_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                            k_1 = T.axis.reduce(32, lk_0 * 8 + lk_1)
                                                            T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j])
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j])
                                    # Write output = O / d (cast to float16) for rows inside this batch element.
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_store"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                    T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    # Write lse = m + log2(d), a base-2 log-sum-exp of the scaled scores.
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    # Advance to this block's next tile (stride = 16 blockIdx.x blocks).
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func(private=True)
    def batch_verify_on_gpu_single_kernel(var_draft_probs: T.handle, var_draft_tokens: T.handle, var_model_probs: T.handle, var_token_tree_first_child: T.handle, var_token_tree_next_sibling: T.handle, var_uniform_samples: T.handle, var_token_tree_parent_ptr: T.handle):
        # Walk a token tree and test draft tokens against model probabilities. One
        # thread block of 1024 threads (threadIdx.x) handles each entry of
        # token_tree_parent_ptr (nbatch entries).
        #
        # Per node n (num_nodes total): draft_tokens[n] is used as a column index
        # into the vocab_size dimension of draft_probs / model_probs (both
        # (num_nodes, vocab_size), default dtype float32); uniform_samples[n] is the
        # sample used when testing n. Tree links: token_tree_first_child[n] and
        # token_tree_next_sibling[n]; -1 ends a child list.
        #
        # Starting from parent = token_tree_parent_ptr[b], each child is tested with
        #   model_probs[parent, tok] >= uniform_samples[child] * draft_probs[child, tok]
        #   * pass: descend (parent = child, child = first child of child).
        #   * fail: r = max(model_probs[parent, :] - draft_probs[child, :], 0).
        #       - If sum(r) < 1e-7, descend as on pass.
        #       - Otherwise overwrite model_probs[parent, :] in place with r / sum(r)
        #         and test the next sibling.
        # When the child list runs out, thread 0 stores the final parent into
        # token_tree_parent_ptr[b] (this buffer is both input and output).
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        num_nodes, vocab_size = T.int32(is_size_var=True), T.int64(is_size_var=True)
        draft_probs = T.match_buffer(var_draft_probs, (num_nodes, vocab_size))
        draft_tokens = T.match_buffer(var_draft_tokens, (num_nodes,), "int32")
        model_probs = T.match_buffer(var_model_probs, (num_nodes, vocab_size))
        token_tree_first_child = T.match_buffer(var_token_tree_first_child, (num_nodes,), "int32")
        token_tree_next_sibling = T.match_buffer(var_token_tree_next_sibling, (num_nodes,), "int32")
        uniform_samples = T.match_buffer(var_uniform_samples, (num_nodes,))
        nbatch = T.int32(is_size_var=True)
        token_tree_parent_ptr = T.match_buffer(var_token_tree_parent_ptr, (nbatch,), "int32")
        # with T.block("root"):
        child_ptr = T.alloc_buffer((1,), "int32", scope="local")
        parent_ptr = T.alloc_buffer((1,), "int32", scope="local")
        child_token = T.alloc_buffer((1,), "int32", scope="local")
        done = T.alloc_buffer((1,), "bool", scope="local")
        psum = T.alloc_buffer((1,), scope="local")
        t0 = T.alloc_buffer((1,), scope="local")
        model_prob_local = T.alloc_buffer((1,), scope="local")
        draft_prob_local = T.alloc_buffer((1,), scope="local")
        p_child = T.alloc_buffer((1,), scope="local")
        q_child = T.alloc_buffer((1,), scope="local")
        uniform_sample = T.alloc_buffer((1,), scope="local")
        pred_shared = T.alloc_buffer((1,), "bool", scope="shared")
        pred_local = T.alloc_buffer((1,), "bool", scope="local")
        for _bx in T.thread_binding(nbatch, thread="blockIdx.x"):
            for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("CTA"):
                    b, tx = T.axis.remap("SS", [_bx, _tx])
                    T.reads(token_tree_parent_ptr[b], token_tree_first_child[T.min(parent_ptr[0], child_ptr[0]):T.min(parent_ptr[0], child_ptr[0]) + (T.max(parent_ptr[0], child_ptr[0]) + 1 - T.min(parent_ptr[0], child_ptr[0]))], parent_ptr[0], done[0], child_ptr[0], draft_tokens[child_ptr[0]], model_probs[parent_ptr[0], T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)):T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)) + (T.max(T.Cast("int64", child_token[0]), (vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) + T.Cast("int64", tx) - T.int64(1024)) + T.int64(1) - T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)))], child_token[0], draft_probs[child_ptr[0], T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)):T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)) + (T.max(T.Cast("int64", child_token[0]), (vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) + T.Cast("int64", tx) - T.int64(1024)) + T.int64(1) - T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)))], uniform_samples[child_ptr[0]], p_child[0], uniform_sample[0], q_child[0], pred_shared[0], pred_local[0], model_prob_local[0], draft_prob_local[0], psum[0], t0[0], token_tree_next_sibling[child_ptr[0]])
                    T.writes(parent_ptr[0], child_ptr[0], done[0], child_token[0], p_child[0], q_child[0], uniform_sample[0], pred_shared[0], pred_local[0], psum[0], model_prob_local[0], draft_prob_local[0], t0[0], model_probs[parent_ptr[0], T.Cast("int64", tx):T.Cast("int64", tx) + ((vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) - T.int64(1023))], token_tree_parent_ptr[b])
                    parent_ptr[0] = token_tree_parent_ptr[b]
                    child_ptr[0] = token_tree_first_child[parent_ptr[0]]
                    done[0] = T.bool(False)
                    # parent_ptr / child_ptr / done are per-thread locals. The loop relies on every
                    # thread making the same branch decisions: pred comes from shared memory
                    # (written by thread 0) and t0 from the cross-thread allreduce.
                    while not done[0]:
                        T.tvm_storage_sync("shared")
                        if child_ptr[0] == -1:
                            done[0] = T.bool(True)
                            T.tvm_storage_sync("shared")
                        else:
                            # Thread 0 evaluates the test for the current child and broadcasts it.
                            # NOTE(review): after a failed test, the row model_probs[parent_ptr, :] was
                            # rewritten by all threads. Here thread 0 reads an element of it that
                            # another thread may have written. Visibility relies on the "shared"
                            # storage sync at the loop top; confirm that this also orders global
                            # (storage buffer) memory on the Vulkan target.
                            if tx == 0:
                                child_token[0] = draft_tokens[child_ptr[0]]
                                p_child[0] = model_probs[parent_ptr[0], child_token[0]]
                                q_child[0] = draft_probs[child_ptr[0], child_token[0]]
                                uniform_sample[0] = uniform_samples[child_ptr[0]]
                                pred_shared[0] = p_child[0] >= uniform_sample[0] * q_child[0]
                            T.tvm_storage_sync("shared")
                            pred_local[0] = pred_shared[0]
                            if pred_local[0]:
                                # Test passed: descend into the child's subtree.
                                parent_ptr[0] = child_ptr[0]
                                child_ptr[0] = token_tree_first_child[child_ptr[0]]
                            else:
                                # Test failed: each thread sums its stride-1024 slice of the clipped
                                # residual max(model - draft, 0); the block-wide total lands in t0.
                                psum[0] = T.float32(0.0)
                                for i in range((vocab_size + T.int64(1023)) // T.int64(1024)):
                                    if i * T.int64(1024) + T.Cast("int64", tx) < vocab_size:
                                        model_prob_local[0] = model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                        draft_prob_local[0] = draft_probs[child_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                        model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], T.float32(0.0))
                                        psum[0] = psum[0] + model_prob_local[0]
                                with T.block("block_cross_thread"):
                                    T.reads(psum[0])
                                    T.writes(t0[0])
                                    T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                    T.tvm_thread_allreduce(T.uint32(1), psum[0], T.bool(True), t0[0], tx)
                                if t0[0] < T.float32(9.9999999999999995e-08):
                                    # Residual mass below 1e-7: descend into the child as if the test passed.
                                    parent_ptr[0] = child_ptr[0]
                                    child_ptr[0] = token_tree_first_child[child_ptr[0]]
                                else:
                                    # Replace the parent's row with the normalized residual, then test
                                    # the next sibling against it.
                                    for i in range((vocab_size + T.int64(1023)) // T.int64(1024)):
                                        if i * T.int64(1024) + T.Cast("int64", tx) < vocab_size:
                                            model_prob_local[0] = model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                            draft_prob_local[0] = draft_probs[child_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                            model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], T.float32(0.0))
                                            model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)] = model_prob_local[0] / t0[0]
                                    child_ptr[0] = token_tree_next_sibling[child_ptr[0]]
                    # Store the final parent node for this batch entry.
                    if tx == 0:
                        token_tree_parent_ptr[b] = parent_ptr[0]

    # Per-chunk log-sum-exp of temperature-scaled logits.
    #
    # A is (batch_size, vocab_size); its vocab axis is viewed as num_chunks chunks of
    # 4096 entries. For each (row, chunk) this writes:
    #   chunked_max[row, chunk] = max of the chunk's (scaled) logits
    #   chunked_sum[row, chunk] = log(sum(exp(x - max)))          if temperature[row] > 1e-5
    #                           = number of entries equal to max    otherwise
    # When temperature[row] <= 1e-5 the logits are used unscaled (no division).
    # NOTE(review): assumes the caller passes num_chunks >= ceil(vocab_size / 4096);
    # entries at or past num_chunks * 4096 are never read — confirm with caller.
    @T.prim_func
    def chunk_lse(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        A = T.match_buffer(var_A, (batch_size, vocab_size))
        temperature = T.match_buffer(var_temperature, (batch_size,))
        num_chunks = T.int64(is_size_var=True)
        chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks))
        chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks))
        # with T.block("root"):
        # Intermediates: padded/scaled logits, plus per-chunk max and exp-sum.
        A_pad = T.alloc_buffer((batch_size, num_chunks, T.int64(4096)))
        temp_max = T.alloc_buffer((batch_size, num_chunks))
        temp_sum = T.alloc_buffer((batch_size, num_chunks))
        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
            with T.block("pad"):
                v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
                T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2])
                T.writes(A_pad[v0, v1, v2])
                # Divide by temperature only when it exceeds 1e-5. Positions past vocab_size
                # get the lowest finite float32 so they never become the chunk max.
                A_pad[v0, v1, v2] = T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), A[v0, v1 * T.int64(4096) + v2] / temperature[v0], A[v0, v1 * T.int64(4096) + v2]), T.float32(-340282346638528859811704183484516925440.0))
        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
            with T.block("max"):
                v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
                T.reads(A_pad[v0, v1, v2])
                T.writes(temp_max[v0, v1])
                with T.init():
                    temp_max[v0, v1] = T.float32(-340282346638528859811704183484516925440.0)
                temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1, v2])
        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
            with T.block("sum_exp"):
                v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
                T.reads(temperature[v0], A_pad[v0, v1, v2], temp_max[v0, v1])
                T.writes(temp_sum[v0, v1])
                with T.init():
                    temp_sum[v0, v1] = T.float32(0.0)
                # The vocab_size guard keeps padded positions out of the sum. In an
                # all-padding chunk, pad == max, so without it each padded entry would
                # add exp(0) = 1, or be counted as equal to the max.
                # With temperature <= 1e-5 the term counts entries equal to the max
                # instead of summing exponentials.
                temp_sum[v0, v1] = temp_sum[v0, v1] + T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(A_pad[v0, v1, v2] - temp_max[v0, v1]), T.Cast("float32", A_pad[v0, v1, v2] == temp_max[v0, v1])), T.float32(0.0))
        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)):
            with T.block("log"):
                v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
                T.reads(temperature[v0], temp_sum[v0, v1], temp_max[v0, v1])
                T.writes(chunked_sum[v0, v1], chunked_max[v0, v1])
                # The log-transformed sum is kept only on the temperature > 1e-5 path;
                # otherwise the raw count from sum_exp is emitted.
                chunked_sum[v0, v1] = T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.log(temp_sum[v0, v1]), temp_sum[v0, v1])
                chunked_max[v0, v1] = temp_max[v0, v1]

    # Batched copy of cache entries between positions in paged storage (a compaction
    # step, per the function name).
    #
    # pages is (num_pages, 2, 2, 16, 128). The dims are: page, a 2-way slot (presumably
    # K/V — confirm), h (2 values), position within the page (16), and d (128).
    # For sequence b, columns copy_length_indptr[b] .. copy_length_indptr[b + 1] - 1 of
    # copy_src_dst_pos hold flat (src, dst) positions: row 0 is src, row 1 is dst.
    # Flat position p lives at page p // 16, offset p % 16.
    # There is one thread per (b, h, d), i.e. 2 * 128 = 256 threads per sequence. Each
    # thread copies its (h, d) lane for both slots at every listed position, in list order.
    @T.prim_func
    def compact_kv_copy(var_pages: T.handle, var_copy_length_indptr: T.handle, var_copy_src_dst_pos: T.handle, batch_size: T.int32):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        num_pages = T.int32()
        pages = T.match_buffer(var_pages, (num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        copy_length_indptr = T.match_buffer(var_copy_length_indptr, (batch_size + 1,), "int32", offset_factor=1)
        total_copy_length = T.int32()
        copy_src_dst_pos = T.match_buffer(var_copy_src_dst_pos, (2, total_copy_length), "int32", offset_factor=1)
        with T.block("root"):
            T.reads()
            T.writes()
            for bhd_o in T.thread_binding((batch_size * 256 + 1023) // 1024, thread="blockIdx.x"):
                for bhd_i in T.thread_binding(1024, thread="threadIdx.x"):
                    # Decompose the flat thread id into (sequence b, h, d).
                    b: T.int32 = (bhd_o * 1024 + bhd_i) // 256
                    h: T.int32 = (bhd_o * 1024 + bhd_i) // 128 % 2
                    d: T.int32 = (bhd_o * 1024 + bhd_i) % 128
                    # Mask the threads of the last block that fall past batch_size * 256.
                    if bhd_o * 1024 + bhd_i < batch_size * 2 * 128:
                        for i in range(copy_length_indptr[b + 1] - copy_length_indptr[b]):
                            src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i]
                            dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i]
                            pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[src_pos // 16, 0, h, src_pos % 16, d]
                            pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[src_pos // 16, 1, h, src_pos % 16, d]

    # Copy the first copy_length positions of page src_page_id into page tgt_page_id.
    # pages is (num_pages, 2, 2, page_size, 128). Both entries of dim 1 are copied, for
    # every vh (2 values, presumably heads) and every vd (128).
    # The flat thread id enumerates (vh, vp, vd) over 2 * copy_length * 128 elements;
    # T.where masks the overshoot in the last thread block.
    # NOTE(review): assumes copy_length <= page_size — not checked here.
    @T.prim_func
    def copy_single_page(var_pages: T.handle, src_page_id: T.int64, tgt_page_id: T.int64, copy_length: T.int64):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        num_pages, page_size = T.int32(), T.int64()
        pages = T.match_buffer(var_pages, (num_pages, 2, 2, page_size, 128), "float16", offset_factor=1)
        # with T.block("root"):
        for b in T.thread_binding((copy_length * T.int64(256) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"):
            for t in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("copy"):
                    # Flat id = (vh * copy_length + vp) * 128 + vd.
                    vh = T.axis.spatial(2, T.Cast("int32", (b * T.int64(1024) + T.Cast("int64", t)) // (copy_length * T.int64(128))))
                    vp = T.axis.spatial(copy_length, (b * T.int64(1024) + T.Cast("int64", t)) % (copy_length * T.int64(128)) // T.int64(128))
                    vd = T.axis.spatial(128, T.Cast("int32", (b * T.int64(1024) + T.Cast("int64", t)) % T.int64(128)))
                    T.where(b * T.int64(1024) + T.Cast("int64", t) < copy_length * T.int64(2) * T.int64(128))
                    T.reads(pages[src_page_id, 0:2, vh, vp, vd])
                    T.writes(pages[tgt_page_id, 0:2, vh, vp, vd])
                    pages[tgt_page_id, 0, vh, vp, vd] = pages[src_page_id, 0, vh, vp, vd]
                    pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd]

    # Dequantize a 4-bit packed weight into float16. Per the parameter names, this is
    # the token-embedding table, (151936, 2048).
    # Each uint32 of q_weight packs 8 unsigned 4-bit values: element (r, c) is nibble
    # c % 8 (bits 4*(c%8) .. 4*(c%8)+3) of word q_weight[r, c // 8].
    # The result is (nibble - 7) * q_scale[r, c // 32], i.e. one float16 scale per
    # group of 32 columns.
    @T.prim_func(private=True)
    def dequantize(model_embed_tokens_q_weight: T.Buffer((T.int64(151936), T.int64(256)), "uint32"), model_embed_tokens_q_scale: T.Buffer((T.int64(151936), T.int64(64)), "float16"), dequantize: T.Buffer((T.int64(151936), T.int64(2048)), "float16")):
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(151936), T.int64(2048)), "float16")
        for i0, i1 in T.grid(T.int64(151936), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_embed_tokens_q_weight[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                # Extract nibble (c % 8) of the packed word and widen it to float16.
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        for i0, i1 in T.grid(T.int64(151936), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_embed_tokens_q_scale[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize[v_i0, v_i1])
                # Recenter the unsigned nibble (0..15) by 7, then apply the group scale.
                dequantize[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_embed_tokens_q_scale[v_i0, v_i1 // T.int64(32)]

    # Dequantize a 4-bit packed (2560, 2048) weight into float16. Per the parameter
    # names, this is the attention c_attn projection.
    # Element (r, c) is nibble c % 8 of uint32 word q_weight[r, c // 8]; the output is
    # (nibble - 7) * q_scale[r, c // 32] (one float16 scale per 32 columns).
    @T.prim_func(private=True)
    def dequantize1(model_layers_0_self_attn_c_attn_q_weight1: T.Buffer((T.int64(2560), T.int64(256)), "uint32"), model_layers_0_self_attn_c_attn_q_scale1: T.Buffer((T.int64(2560), T.int64(64)), "float16"), dequantize: T.Buffer((T.int64(2560), T.int64(2048)), "float16")):
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(2560), T.int64(2048)), "float16")
        for i0, i1 in T.grid(T.int64(2560), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_self_attn_c_attn_q_weight1[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                # Extract nibble (c % 8) of the packed word and widen it to float16.
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight1[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        for i0, i1 in T.grid(T.int64(2560), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_self_attn_c_attn_q_scale1[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize[v_i0, v_i1])
                # Recenter the unsigned nibble (0..15) by 7, then apply the group scale.
                dequantize[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale1[v_i0, v_i1 // T.int64(32)]

    # Dequantize a 4-bit packed (2048, 2048) weight into float16. Per the parameter
    # names, this is the attention o_proj.
    # Element (r, c) is nibble c % 8 of uint32 word q_weight[r, c // 8]; the output is
    # (nibble - 7) * q_scale[r, c // 32] (one float16 scale per 32 columns).
    @T.prim_func(private=True)
    def dequantize2(model_layers_0_self_attn_o_proj_q_weight1: T.Buffer((T.int64(2048), T.int64(256)), "uint32"), model_layers_0_self_attn_o_proj_q_scale1: T.Buffer((T.int64(2048), T.int64(64)), "float16"), dequantize: T.Buffer((T.int64(2048), T.int64(2048)), "float16")):
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(2048), T.int64(2048)), "float16")
        for i0, i1 in T.grid(T.int64(2048), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_self_attn_o_proj_q_weight1[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                # Extract nibble (c % 8) of the packed word and widen it to float16.
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight1[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        for i0, i1 in T.grid(T.int64(2048), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_self_attn_o_proj_q_scale1[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize[v_i0, v_i1])
                # Recenter the unsigned nibble (0..15) by 7, then apply the group scale.
                dequantize[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale1[v_i0, v_i1 // T.int64(32)]

    # Dequantize a 4-bit packed (22016, 2048) weight into float16. Per the parameter
    # names, this is the MLP gate_up projection.
    # Element (r, c) is nibble c % 8 of uint32 word q_weight[r, c // 8]; the output is
    # (nibble - 7) * q_scale[r, c // 32] (one float16 scale per 32 columns).
    @T.prim_func(private=True)
    def dequantize3(model_layers_0_mlp_gate_up_proj_q_weight1: T.Buffer((T.int64(22016), T.int64(256)), "uint32"), model_layers_0_mlp_gate_up_proj_q_scale1: T.Buffer((T.int64(22016), T.int64(64)), "float16"), dequantize: T.Buffer((T.int64(22016), T.int64(2048)), "float16")):
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(22016), T.int64(2048)), "float16")
        for i0, i1 in T.grid(T.int64(22016), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_mlp_gate_up_proj_q_weight1[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                # Extract nibble (c % 8) of the packed word and widen it to float16.
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_gate_up_proj_q_weight1[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        for i0, i1 in T.grid(T.int64(22016), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_mlp_gate_up_proj_q_scale1[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize[v_i0, v_i1])
                # Recenter the unsigned nibble (0..15) by 7, then apply the group scale.
                dequantize[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_mlp_gate_up_proj_q_scale1[v_i0, v_i1 // T.int64(32)]

    # Dequantize a 4-bit packed (2048, 11008) weight into float16. Per the parameter
    # names, this is the MLP down projection.
    # The 11008 columns come from 1376 uint32 words per row (8 nibbles each) and
    # 344 scales per row (one per 32 columns).
    # Element (r, c) is nibble c % 8 of q_weight[r, c // 8]; the output is
    # (nibble - 7) * q_scale[r, c // 32].
    @T.prim_func(private=True)
    def dequantize4(model_layers_0_mlp_down_proj_q_weight1: T.Buffer((T.int64(2048), T.int64(1376)), "uint32"), model_layers_0_mlp_down_proj_q_scale1: T.Buffer((T.int64(2048), T.int64(344)), "float16"), dequantize: T.Buffer((T.int64(2048), T.int64(11008)), "float16")):
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(2048), T.int64(11008)), "float16")
        for i0, i1 in T.grid(T.int64(2048), T.int64(11008)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_mlp_down_proj_q_weight1[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                # Extract nibble (c % 8) of the packed word and widen it to float16.
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_down_proj_q_weight1[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        for i0, i1 in T.grid(T.int64(2048), T.int64(11008)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_mlp_down_proj_q_scale1[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize[v_i0, v_i1])
                # Recenter the unsigned nibble (0..15) by 7, then apply the group scale.
                dequantize[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_mlp_down_proj_q_scale1[v_i0, v_i1 // T.int64(32)]

    # Fill the single column of a (batch_size, 1) int32 buffer with the scalar `value`.
    @T.prim_func
    def full(var_result: T.handle, value: T.int32):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1})})
        batch_size = T.int32(is_size_var=True)
        result = T.match_buffer(var_result, (batch_size, 1), "int32")
        # with T.block("root"):
        # One serial iteration per row; `range` in TIR lowers to this same serial loop.
        for row in T.serial(batch_size):
            with T.block("block"):
                v_row = T.axis.spatial(batch_size, row)
                T.reads()
                T.writes(result[v_row, 0])
                # Broadcast the scalar argument into the row's only slot.
                result[v_row, 0] = value

    # Fused residual add + RMS norm for batched single-token input, (batch_size, 1, 2048):
    #   add = A + B                                   (float16, written to pAdd)
    #   O   = add * rsqrt(mean(add^2) + 1e-6) * C     (math in float32, stored as float16)
    # There is one thread block per batch row. Its 1024 threads each own 2 of the 2048
    # hidden elements (h = i * 1024 + tx).
    # Note that A + B is rounded to float16 (the dtype of add_local) before being squared.
    @T.prim_func(private=True)
    def fuse_add_norm_decode(pA: T.handle, pB: T.handle, C: T.Buffer((2048,), "float16"), pO: T.handle, pAdd: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int32()
        A = T.match_buffer(pA, (batch_size, 1, 2048), "float16")
        B = T.match_buffer(pB, (batch_size, 1, 2048), "float16")
        O = T.match_buffer(pO, (batch_size, 1, 2048), "float16")
        add = T.match_buffer(pAdd, (batch_size, 1, 2048), "float16")
        # with T.block("root"):
        # add_local: this thread's 2 summed elements (local scope).
        # sum_local: per-thread partial sum of squares (local scope).
        # sum_shared: per-row total sum of squares (shared scope).
        add_local = T.alloc_buffer((2,), "float16", scope="local")
        sum_shared = T.alloc_buffer((batch_size, 1), scope="shared")
        sum_local = T.alloc_buffer((1024, batch_size, 1), scope="local")
        for v_bx in T.thread_binding(batch_size, thread="blockIdx.x"):
            for v_tx in T.thread_binding(1024, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                # Residual add: keep the sum in registers and also write it back to `add`.
                for i in range(2):
                    with T.block("T_add"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        h = T.axis.spatial(2048, i * 1024 + v_tx)
                        T.reads(A[bx, 0, h], B[bx, 0, h])
                        T.writes(add_local[h // 1024])
                        add_local[h // 1024] = A[bx, 0, h] + B[bx, 0, h]
                    with T.block("T_write_back"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        v_ax1 = T.axis.spatial(1, 0)
                        h = T.axis.spatial(2048, i * 1024 + v_tx)
                        T.reads(add_local[h // 1024])
                        T.writes(add[bx, v_ax1, h])
                        add[bx, v_ax1, h] = add_local[h // 1024]
                # Per-thread partial sum of squares, accumulated in float32.
                with T.block("T_multiply_red_rf_init"):
                    tx, bx = T.axis.remap("SS", [v_tx, v_bx])
                    T.reads()
                    T.writes(sum_local[tx, bx, 0])
                    sum_local[tx, bx, 0] = T.float32(0.0)
                for v_i, _j in T.grid(2, 1):
                    with T.block("T_multiply_red_rf_update"):
                        tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                        T.reads(sum_local[tx, bx, 0], add_local[i])
                        T.writes(sum_local[tx, bx, 0])
                        sum_local[tx, bx, 0] = sum_local[tx, bx, 0] + T.Cast("float32", add_local[i]) * T.Cast("float32", add_local[i])
            # Cross-thread reduction of the 1024 partial sums into sum_shared[bx, 0].
            for _j in range(1):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                        T.reads(sum_local[tx, bx, 0])
                        T.writes(sum_shared[bx, 0])
                        with T.init():
                            sum_shared[bx, 0] = T.float32(0.0)
                        sum_shared[bx, 0] = sum_shared[bx, 0] + sum_local[tx, bx, 0]
            # Normalize and scale by the weight C. 0.00048828125 == 1/2048 turns the sum
            # into a mean; 9.99...e-07 is eps = 1e-6.
            for i in range(2):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        h = T.axis.spatial(2048, i * 1024 + v_tx_2)
                        T.reads(sum_shared[bx, 0], add_local[h // 1024], C[h])
                        T.writes(O[bx, 0, h])
                        O[bx, 0, h] = T.Cast("float16", T.rsqrt(sum_shared[bx, 0] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07)) * T.Cast("float32", add_local[h // 1024]) * T.Cast("float32", C[h]))

    # Fused residual add + RMS norm over a (1, seq_len, 2048) sequence:
    #   add = A + B                                   (float16, written to pAdd)
    #   O   = add * rsqrt(mean(add^2) + 1e-6) * C     (math in float32, stored as float16)
    # There is one thread block per sequence position. Its 1024 threads each own 2 of
    # the 2048 hidden elements (h = v_i * 1024 + tx).
    # The structure mirrors fuse_add_norm_decode, with the batch axis replaced by seq_len.
    @T.prim_func(private=True)
    def fuse_add_norm_prefill(pA: T.handle, pB: T.handle, C: T.Buffer((2048,), "float16"), pO: T.handle, pAdd: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        A = T.match_buffer(pA, (1, seq_len, 2048), "float16")
        B = T.match_buffer(pB, (1, seq_len, 2048), "float16")
        O = T.match_buffer(pO, (1, seq_len, 2048), "float16")
        add = T.match_buffer(pAdd, (1, seq_len, 2048), "float16")
        # with T.block("root"):
        # add_local: this thread's 2 summed elements (local scope).
        # sum_local: per-thread partial sum of squares (local scope).
        # sum_shared: per-position total sum of squares (shared scope).
        add_local = T.alloc_buffer((2,), "float16", scope="local")
        sum_shared = T.alloc_buffer((1, seq_len), scope="shared")
        sum_local = T.alloc_buffer((1024, 1, seq_len), scope="local")
        for v_bx in T.thread_binding(seq_len, thread="blockIdx.x"):
            for v_tx in T.thread_binding(1024, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                # Residual add: keep the sum in registers and also write it back to `add`.
                for v_i in range(2):
                    with T.block("T_add"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        h = T.axis.spatial(2048, v_i * 1024 + v_tx)
                        T.reads(A[0, bx, h], B[0, bx, h])
                        T.writes(add_local[h // 1024])
                        add_local[h // 1024] = A[0, bx, h] + B[0, bx, h]
                    with T.block("T_write_back"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        h = T.axis.spatial(2048, v_i * 1024 + v_tx)
                        T.reads(add_local[h // 1024])
                        T.writes(add[0, bx, h])
                        add[0, bx, h] = add_local[h // 1024]
                # Per-thread partial sum of squares, accumulated in float32.
                with T.block("T_multiply_red_rf_init"):
                    tx, bx = T.axis.remap("SS", [v_tx, v_bx])
                    T.reads()
                    T.writes(sum_local[tx, 0, bx])
                    sum_local[tx, 0, bx] = T.float32(0.0)
                for v_i, _j in T.grid(2, 1):
                    with T.block("T_multiply_red_rf_update"):
                        tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                        T.reads(sum_local[tx, 0, bx], add_local[i])
                        T.writes(sum_local[tx, 0, bx])
                        sum_local[tx, 0, bx] = sum_local[tx, 0, bx] + T.Cast("float32", add_local[i]) * T.Cast("float32", add_local[i])
            # Cross-thread reduction of the 1024 partial sums into sum_shared[0, bx].
            for _j in range(1):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                        T.reads(sum_local[tx, 0, bx])
                        T.writes(sum_shared[0, bx])
                        with T.init():
                            sum_shared[0, bx] = T.float32(0.0)
                        sum_shared[0, bx] = sum_shared[0, bx] + sum_local[tx, 0, bx]
            # Normalize and scale by the weight C. 0.00048828125 == 1/2048 turns the sum
            # into a mean; 9.99...e-07 is eps = 1e-6.
            for v_i in range(2):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        v1 = T.axis.spatial(2048, v_i * 1024 + v_tx_2)
                        T.reads(sum_shared[0, bx], add_local[v1 // 1024], C[v1])
                        T.writes(O[0, bx, v1])
                        O[0, bx, v1] = T.Cast("float16", T.rsqrt(sum_shared[0, bx] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07)) * T.Cast("float32", add_local[v1 // 1024]) * T.Cast("float32", C[v1]))

    # Split packed QKV (seq_len, 20, 128) into q (heads 0-15), k (heads 16-17) and
    # v (heads 18-19). When apply_rope > 0, rotary position embedding is applied to q and k.
    # The rotation pairs dim d with dim d + 64 (rotate-half): the partner value is
    # -qkv[d + 64] for d < 64, else qkv[d - 64]. The angle is
    #   freq = position_map[s] / 1e6 ** ((2 * d % 128) / 128)
    # so dims d and d + 64 share a frequency. v is copied unchanged.
    @T.prim_func
    def fused_rope(var_qkv: T.handle, var_position_map: T.handle, var_q: T.handle, var_k: T.handle, var_v: T.handle, apply_rope: T.int32):
        T.func_attr({"op_pattern": 8, "target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        qkv = T.match_buffer(var_qkv, (seq_len, 20, 128), "float16")
        position_map = T.match_buffer(var_position_map, (seq_len,), "int32", offset_factor=1)
        q = T.match_buffer(var_q, (seq_len, 16, 128), "float16")
        k = T.match_buffer(var_k, (seq_len, 2, 128), "float16")
        v = T.match_buffer(var_v, (seq_len, 2, 128), "float16")
        # with T.block("root"):
        for iters_0, iters_1, iters_2 in T.grid(seq_len, 20, 128):
            with T.block("llama_fused_rope"):
                s, h, d = T.axis.remap("SSS", [iters_0, iters_1, iters_2])
                T.reads(position_map[s], qkv[s, h, d - 64:d - 64 + 129])
                T.writes(q[s, h, d], k[s, h - 16, d], v[s, h - 18, d])
                # `d < 128` always holds for d in [0, 128), so apply_rope > 0 is the
                # effective switch; otherwise the head is copied through unchanged.
                if h < 16:
                    freq = T.float32()
                    q[s, h, d] = T.if_then_else(apply_rope > 0 and d < 128, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", qkv[s, h, d]) + T.sin(freq) * T.Cast("float32", T.if_then_else(d < 64, qkv[s, h, d + 64] * T.float16(-1.0), qkv[s, h, d - 64]))), where={freq: T.Cast("float32", position_map[s]) / T.pow(T.float32(1000000.0), T.Cast("float32", d * 2 % 128) / T.float32(128.0))}), qkv[s, h, d])
                else:
                    if h < 18:
                        freq = T.float32()
                        k[s, h - 16, d] = T.if_then_else(apply_rope > 0 and d < 128, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", qkv[s, h, d]) + T.sin(freq) * T.Cast("float32", T.if_then_else(d < 64, qkv[s, h, d + 64] * T.float16(-1.0), qkv[s, h, d - 64]))), where={freq: T.Cast("float32", position_map[s]) / T.pow(T.float32(1000000.0), T.Cast("float32", d * 2 % 128) / T.float32(128.0))}), qkv[s, h, d])
                    else:
                        v[s, h - 18, d] = qkv[s, h, d]

    # Row gather: dst[r, :] = src[indices[r], :] for every r in [0, batch_size).
    @T.prim_func
    def gather_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        m, n = T.int32(is_size_var=True), T.int32(is_size_var=True)
        src = T.match_buffer(var_src, (m, n))
        batch_size = T.int32(is_size_var=True)
        indices = T.match_buffer(var_indices, (batch_size,), "int32")
        dst = T.match_buffer(var_dst, (batch_size, n))
        # with T.block("root"):
        # Explicitly nested serial loops; equivalent to T.grid(batch_size, n).
        for row in range(batch_size):
            for col in range(n):
                with T.block("gather_2d"):
                    v_row, v_col = T.axis.remap("SS", [row, col])
                    T.reads(src[indices[v_row], v_col], indices[v_row])
                    T.writes(dst[v_row, v_col])
                    # indices[v_row] selects which source row feeds this output row.
                    dst[v_row, v_col] = src[indices[v_row], v_col]

    # Map uniform samples to token indices by inverse-CDF lookup over a cumulative sum.
    # For output row r, let p = sample_indices[r, 0] (the source row) and
    # u = usample[r, 0]. The chosen column c is the first column with
    #   u < cumsum_sorted[p, c] / renorm_prob[p, 0]
    # falling back to the last column. The result is output_index[r, 0] = indices[p, c].
    # Every (r, c) iteration checks independently whether it is that first column.
    # NOTE(review): exactly one column writes per row only if each cumsum_sorted row is
    # non-decreasing (presumably a cumulative sum of sorted probabilities) — confirm.
    @T.prim_func(private=True)
    def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle):
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1})})
        batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
        indices = T.match_buffer(B, (batch, vocab_size), "int32")
        renorm_prob = T.match_buffer(C, (batch, 1))
        out_batch = T.int64(is_size_var=True)
        usample = T.match_buffer(D, (out_batch, 1))
        sample_indices = T.match_buffer(E, (out_batch, 1), "int32")
        output_index = T.match_buffer(F, (out_batch, 1), "int32")
        # with T.block("root"):
        for ax0, ax1 in T.grid(out_batch, vocab_size):
            with T.block("T_get_index_from_sorted"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(usample[v_ax0, T.int64(0)], cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1):v_ax1 - T.int64(1) + T.int64(2)], sample_indices[v_ax0, T.int64(0)], renorm_prob[sample_indices[v_ax0, T.int64(0)], 0], indices[sample_indices[v_ax0, T.int64(0)], T.min(T.int64(0), v_ax1):T.min(T.int64(0), v_ax1) + (T.max(T.int64(0), v_ax1) + T.int64(1) - T.min(T.int64(0), v_ax1))])
                T.writes(output_index[v_ax0, 0])
                # Column v_ax1 qualifies if u falls below its normalized cumulative mass,
                # or if it is the last column.
                if usample[v_ax0, T.int64(0)] < cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1] / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0] or v_ax1 + T.int64(1) == vocab_size:
                    # Write only if this is the first qualifying column: it is column 0,
                    # or the previous column did not qualify.
                    if v_ax1 == T.int64(0):
                        output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], 0]
                    else:
                        if usample[v_ax0, T.int64(0)] >= cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1)] / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]:
                            output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], v_ax1]

    # Compute the cumulative mass retained by combined top-p / top-k filtering.
    # cumsum_sorted is presumably the running sum of probabilities sorted in descending
    # order (TODO confirm). renorm_prob[row, 0] is set to cumsum_sorted[row, k], where:
    #   - k = 0 if NOT (cumsum_sorted[row, 0] < top_p and top_k > 1);
    #   - otherwise k = c + 1, where column c "is kept" (cumsum_sorted[row, c] < top_p
    #     and c + 1 < top_k) but column c + 1 is not;
    #   - or k = c, if the kept column c is the last column.
    # NOTE(review): a single writer per row relies on cumsum_sorted being non-decreasing.
    @T.prim_func(private=True)
    def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1})})
        batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
        top_p = T.match_buffer(B, (batch, 1))
        top_k = T.match_buffer(C, (batch, 1), "int32")
        renorm_prob = T.match_buffer(D, (batch, 1))
        # with T.block("root"):
        for ax0, ax1 in T.grid(batch, vocab_size):
            with T.block("T_get_renorm_prob"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(cumsum_sorted[v_ax0, T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + (T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + T.int64(1) - T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)))], top_p[v_ax0, 0], top_k[v_ax0, 0])
                T.writes(renorm_prob[v_ax0, 0])
                # Only the first entry survives filtering. Every column of the row writes
                # this same value.
                if not (cumsum_sorted[v_ax0, 0] < top_p[v_ax0, 0] and top_k[v_ax0, 0] > 1):
                    renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, 0]
                else:
                    # Column v_ax1 is kept; look for the boundary with the next column.
                    if cumsum_sorted[v_ax0, v_ax1] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) < T.Cast("int64", top_k[v_ax0, 0]):
                        if v_ax1 + T.int64(1) == vocab_size:
                            renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1]
                        else:
                            if not (cumsum_sorted[v_ax0, v_ax1 + T.int64(1)] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) + T.int64(1) < T.Cast("int64", top_k[v_ax0, 0])):
                                renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1 + T.int64(1)]

    @T.prim_func(private=True)
    def index(var_rms_norm72: T.handle, index: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")):
        # Select the hidden state of the final sequence position:
        #   index[0, 0, h] = rms_norm72[0, seq_len - 1, h]   for h in [0, 2048)
        # Both leading output axes have extent 1, so this is a single 2048-wide row copy.
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        seq_len = T.int64()
        rms_norm72 = T.match_buffer(var_rms_norm72, (T.int64(1), seq_len, T.int64(2048)), "float16")
        for b in range(T.int64(1)):
            for s in range(T.int64(1)):
                for h in range(T.int64(2048)):
                    with T.block("index"):
                        v_b, v_s, v_h = T.axis.remap("SSS", [b, s, h])
                        T.reads(rms_norm72[v_b, seq_len - T.int64(1), v_h])
                        T.writes(index[v_b, v_s, v_h])
                        index[v_b, v_s, v_h] = rms_norm72[v_b, seq_len - T.int64(1), v_h]

    @T.prim_func
    def merge_state_inplace(v: T.handle, s: T.handle, v_other: T.handle, s_other: T.handle):
        # Merge two partial attention states in place into (V, S):
        #   m       = max(S, S_other)
        #   w       = exp2(S - m),  w_other = exp2(S_other - m)
        #   V      <- (V * w + V_other * w_other) / (w + w_other)   (arithmetic in float32)
        #   S      <- log2(w + w_other) + m
        # The S update is consistent with S / S_other being base-2 log-sum-exp values.
        # Launch: blockIdx.x = n; head index = ty + by * 16; threadIdx.x handles the four
        # contiguous elements tx * 4 .. tx * 4 + 3 of the last axis.
        # NOTE(review): the launch covers exactly 16 heads and 32 * 4 = 128 elements while H and D
        # are symbolic in the buffer shapes; presumably only valid for H == 16, D == 128 -- confirm.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        N, H, D = T.int32(is_size_var=True), T.int32(is_size_var=True), T.int32(is_size_var=True)
        V = T.match_buffer(v, (N, H, D), "float16")
        S = T.match_buffer(s, (N, H))
        V_other = T.match_buffer(v_other, (N, H, D), "float16")
        S_other = T.match_buffer(s_other, (N, H))
        # with T.block("root"):
        for bx in T.thread_binding(N, thread="blockIdx.x"):
            for by in T.thread_binding(1, thread="blockIdx.y"):
                for ty in T.thread_binding(16, thread="threadIdx.y"):
                    for tx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("merge"):
                            T.reads(S[bx, ty + by * 16], S_other[bx, ty + by * 16], V[bx, ty + by * 16, tx * 4:tx * 4 + 4], V_other[bx, ty + by * 16, tx * 4:tx * 4 + 4])
                            T.writes(V[bx, ty + by * 16, tx * 4:tx * 4 + 4], S[bx, ty + by * 16])
                            # Per-thread scratch in "local" scope.
                            s_val = T.alloc_buffer((1,), scope="local")
                            s_other_val = T.alloc_buffer((1,), scope="local")
                            s_max = T.alloc_buffer((1,), scope="local")
                            scale = T.alloc_buffer((1,), scope="local")
                            other_scale = T.alloc_buffer((1,), scope="local")
                            v_vec = T.alloc_buffer((4,), "float16", scope="local")
                            v_other_vec = T.alloc_buffer((4,), "float16", scope="local")
                            s_val[0] = S[bx, ty + by * 16]
                            s_other_val[0] = S_other[bx, ty + by * 16]
                            # Subtract the max before exp2 so the larger weight is exactly 1.
                            s_max[0] = T.max(s_val[0], s_other_val[0])
                            s_val[0] = T.exp2(s_val[0] - s_max[0])
                            s_other_val[0] = T.exp2(s_other_val[0] - s_max[0])
                            scale[0] = s_val[0] / (s_val[0] + s_other_val[0])
                            other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0])
                            for vec in T.vectorized(4):
                                v_vec[vec] = V[bx, ty + by * 16, tx * 4 + vec]
                            for vec in T.vectorized(4):
                                v_other_vec[vec] = V_other[bx, ty + by * 16, tx * 4 + vec]
                            # Weighted average computed in float32, stored back as float16.
                            for vec in range(4):
                                v_vec[vec] = T.Cast("float16", T.Cast("float32", v_vec[vec]) * scale[0] + T.Cast("float32", v_other_vec[vec]) * other_scale[0])
                            for vec in T.vectorized(4):
                                V[bx, ty + by * 16, tx * 4 + vec] = v_vec[vec]
                            S[bx, ty + by * 16] = T.log2(s_val[0] + s_other_val[0]) + s_max[0]

    @T.prim_func
    def sampler_take_probs_tir(var_unsorted_probs: T.handle, var_sorted_indices: T.handle, var_sample_indices: T.handle, var_sampling_results: T.handle, var_top_prob_offsets: T.handle, var_sampled_values: T.handle, var_top_prob_probs: T.handle, var_top_prob_indices: T.handle):
        # One iteration per index vi in [0, num_positions + num_samples):
        #   vi <  num_positions:
        #     Decode the flat offset top_prob_offsets[vi] into (row, col) of the
        #     (batch_size, vocab_size) sorted_indices matrix. Then write
        #     top_prob_indices[vi] = sorted_indices[row, col] and
        #     top_prob_probs[vi]   = unsorted_probs[row, sorted_indices[row, col]].
        #   vi >= num_positions (vj = vi - num_positions):
        #     sampled_values[vj] = unsorted_probs[sample_indices[vj], sampling_results[vj]]
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1})})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        unsorted_probs = T.match_buffer(var_unsorted_probs, (batch_size, vocab_size))
        sorted_indices = T.match_buffer(var_sorted_indices, (batch_size, vocab_size), "int32")
        num_samples = T.int32(is_size_var=True)
        sample_indices = T.match_buffer(var_sample_indices, (num_samples,), "int32")
        sampling_results = T.match_buffer(var_sampling_results, (num_samples,), "int32")
        num_positions = T.int32(is_size_var=True)
        top_prob_offsets = T.match_buffer(var_top_prob_offsets, (num_positions,), "int32")
        sampled_values = T.match_buffer(var_sampled_values, (num_samples,))
        top_prob_probs = T.match_buffer(var_top_prob_probs, (num_positions,))
        top_prob_indices = T.match_buffer(var_top_prob_indices, (num_positions,), "int32")
        # with T.block("root"):
        for i in range(num_positions + num_samples):
            with T.block("block"):
                vi = T.axis.spatial(num_positions + num_samples, i)
                # Generated read/write regions: the union of what both branches below access.
                T.reads(top_prob_offsets[vi], sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], unsorted_probs[T.min(top_prob_offsets[vi] // vocab_size, sample_indices[vi - num_positions]):T.min(top_prob_offsets[vi] // vocab_size, sample_indices[vi - num_positions]) + (T.max(top_prob_offsets[vi] // vocab_size, sample_indices[vi - num_positions]) + 1 - T.min(top_prob_offsets[vi] // vocab_size, sample_indices[vi - num_positions])), T.min(sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], sampling_results[vi - num_positions]):T.min(sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], sampling_results[vi - num_positions]) + (T.max(sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], sampling_results[vi - num_positions]) + 1 - T.min(sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], sampling_results[vi - num_positions]))], sample_indices[vi - num_positions], sampling_results[vi - num_positions])
                T.writes(top_prob_indices[vi], top_prob_probs[vi], sampled_values[vi - num_positions])
                if vi < num_positions:
                    # Flat offset -> (row, col) in the sorted-index matrix.
                    row: T.int32 = top_prob_offsets[vi] // vocab_size
                    col: T.int32 = top_prob_offsets[vi] % vocab_size
                    top_prob_indices[vi] = sorted_indices[row, col]
                    top_prob_probs[vi] = unsorted_probs[row, sorted_indices[row, col]]
                else:
                    vj: T.int32 = vi - num_positions
                    sampled_values[vj] = unsorted_probs[sample_indices[vj], sampling_results[vj]]

    @T.prim_func
    def scatter_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle):
        # Scatter whole rows: row r of `src` is copied into row indices[r] of `dst`.
        #   dst[indices[r], c] = src[r, c]   for r in [0, batch_size), c in [0, n)
        # Rows of `dst` that no entry of `indices` names are left untouched.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        batch_size, n = T.int32(is_size_var=True), T.int32(is_size_var=True)
        src = T.match_buffer(var_src, (batch_size, n))
        indices = T.match_buffer(var_indices, (batch_size,), "int32")
        m = T.int32(is_size_var=True)
        dst = T.match_buffer(var_dst, (m, n))
        for row in range(batch_size):
            for col in range(n):
                with T.block("scatter_2d"):
                    v_row, v_col = T.axis.remap("SS", [row, col])
                    T.reads(src[v_row, v_col], indices[v_row])
                    T.writes(dst[indices[v_row], v_col])
                    dst[indices[v_row], v_col] = src[v_row, v_col]

    @T.prim_func
    def softmax_with_chunked_sum(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle, var_softmax: T.handle):
        # Final pass of a chunked softmax. One thread block per (row, chunk) pair; a chunk is
        # 4096 columns (v1 * 4096 + v2 with v2 in [0, 4096)).
        #   1. temp_max_shared[row] = max over chunks of chunked_max[row, :]
        #   2. temp_sum_shared[row] = sum over chunks of
        #        temperature > 1e-5 : exp(chunked_sum + chunked_max - row max)
        #        otherwise          : (chunked_max == row max) * chunked_sum
        #   3. For the block's own chunk, with columns past vocab_size skipped:
        #        temperature > 1e-5 : softmax = exp(A / temperature - (log(row sum) + row max))
        #        otherwise          : softmax = (A == row max) / row sum
        # Steps 1-2 are recomputed by every chunk-block of the same row.
        # NOTE(review): chunked_sum is added inside exp() in the temperature path but used
        # linearly in the other path; presumably the producing kernel stores different
        # quantities per mode -- confirm.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        A = T.match_buffer(var_A, (batch_size, vocab_size))
        temperature = T.match_buffer(var_temperature, (batch_size,))
        num_chunks = T.int64(is_size_var=True)
        chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks))
        chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks))
        softmax = T.match_buffer(var_softmax, (batch_size, vocab_size))
        # with T.block("root"):
        temp_max_shared = T.alloc_buffer((batch_size,), scope="shared")
        temp_sum_shared = T.alloc_buffer((batch_size,), scope="shared")
        for l0_l1_fused in T.thread_binding(batch_size * num_chunks, thread="blockIdx.x"):
            # Step 1: 32 threads reduce the per-chunk maxima of this row.
            for ax0_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                for ax0_0 in T.serial((num_chunks + T.int64(31)) // T.int64(32), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                    with T.block("max"):
                        v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks)
                        v1 = T.axis.reduce(num_chunks, ax0_0 * T.int64(32) + ax0_1)
                        T.where(ax0_0 * T.int64(32) + ax0_1 < num_chunks)
                        T.reads(chunked_max[v0, v1])
                        T.writes(temp_max_shared[v0])
                        with T.init():
                            # -FLT_MAX: identity for the max reduction.
                            temp_max_shared[v0] = T.float32(-340282346638528859811704183484516925440.0)
                        temp_max_shared[v0] = T.max(temp_max_shared[v0], chunked_max[v0, v1])
            # Step 2: 32 threads rescale and sum the per-chunk sums to the row max.
            for ax0_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                for ax0_0 in T.serial((num_chunks + T.int64(31)) // T.int64(32), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                    with T.block("sum_exp"):
                        v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks)
                        v1 = T.axis.reduce(num_chunks, ax0_0 * T.int64(32) + ax0_1)
                        T.where(ax0_0 * T.int64(32) + ax0_1 < num_chunks)
                        T.reads(temperature[v0], chunked_sum[v0, v1], chunked_max[v0, v1], temp_max_shared[v0])
                        T.writes(temp_sum_shared[v0])
                        with T.init():
                            temp_sum_shared[v0] = T.float32(0.0)
                        temp_sum_shared[v0] = temp_sum_shared[v0] + T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(chunked_sum[v0, v1] + chunked_max[v0, v1] - temp_max_shared[v0]), T.Cast("float32", chunked_max[v0, v1] == temp_max_shared[v0]) * chunked_sum[v0, v1])
            # Step 3: 4 x 32 x 32 = 4096 elements of this block's chunk.
            for l2_0 in T.serial(T.int64(4), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                for l2_1 in T.thread_binding(T.int64(32), thread="threadIdx.y"):
                    for l2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                        with T.block("log_pad"):
                            v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks)
                            v1 = T.axis.spatial(num_chunks, l0_l1_fused % num_chunks)
                            v2 = T.axis.spatial(T.int64(4096), l2_0 * T.int64(1024) + l2_1 * T.int64(32) + l2_2)
                            T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2], temp_sum_shared[v0], temp_max_shared[v0])
                            T.writes(softmax[v0, v1 * T.int64(4096) + v2])
                            # The last chunk may extend past vocab_size; skip those columns.
                            if v1 * T.int64(4096) + v2 < vocab_size:
                                softmax[v0, v1 * T.int64(4096) + v2] = T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(A[v0, v1 * T.int64(4096) + v2] / temperature[v0] - (T.log(temp_sum_shared[v0]) + temp_max_shared[v0])), T.Cast("float32", A[v0, v1 * T.int64(4096) + v2] == temp_max_shared[v0]) / temp_sum_shared[v0])

    @T.prim_func(private=True)
    def take_sorted_probs(var_probs: T.handle, var_lv1: T.handle, var_take_sorted_probs: T.handle):
        # Row-wise gather through an index matrix:
        #   take_sorted_probs[r, k] = probs[r, lv1[r, k]]
        # NOTE(review): lv1 is presumably the per-row argsort of probs -- verify against caller.
        # NOTE(review): the output extents (batch_size_1, vocab_size_1) are independent
        # variables, so the signature does not force them to match the input extents.
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(), T.int64()
        probs = T.match_buffer(var_probs, (batch_size, vocab_size))
        lv1 = T.match_buffer(var_lv1, (batch_size, vocab_size), "int32")
        batch_size_1, vocab_size_1 = T.int64(is_size_var=True), T.int64(is_size_var=True)
        take_sorted_probs = T.match_buffer(var_take_sorted_probs, (batch_size_1, vocab_size_1))
        for r in range(batch_size_1):
            for k in range(vocab_size_1):
                with T.block("take_sorted_probs"):
                    v_r, v_k = T.axis.remap("SS", [r, k])
                    T.reads(probs[v_r, lv1[v_r, v_k]], lv1[v_r, v_k])
                    T.writes(take_sorted_probs[v_r, v_k])
                    take_sorted_probs[v_r, v_k] = probs[v_r, lv1[v_r, v_k]]

    @T.prim_func
    def tir_kv_cache_debug_get_kv(var_pages: T.handle, var_position_map: T.handle, var_k_data: T.handle, var_v_data: T.handle, layer_id: T.int64):
        # Copy one layer's K/V entries out of the paged cache (debug helper):
        #   pos = position_map[p]
        #   k_data[layer_id, p, h, d] = pages[pos // page_size, 0, h, pos % page_size, d]
        #   v_data[layer_id, p, h, d] = pages[pos // page_size, 1, h, pos % page_size, d]
        # Axis 1 of `pages` selects K (0) or V (1).
        # NOTE(review): the leading extent 36 of k_data/v_data is presumably the layer count -- confirm.
        # NOTE(review): position_map entries are used without a -1 check here -- confirm the caller
        # never passes unmapped positions.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        num_pages, page_size = T.int64(), T.int64(is_size_var=True)
        pages = T.match_buffer(var_pages, (num_pages, 2, 2, page_size, 128), "float16", offset_factor=1)
        seqlen = T.int64(is_size_var=True)
        position_map = T.match_buffer(var_position_map, (seqlen,), "int32", offset_factor=1)
        k_data = T.match_buffer(var_k_data, (36, seqlen, 2, 128), "float16")
        v_data = T.match_buffer(var_v_data, (36, seqlen, 2, 128), "float16")
        # with T.block("root"):
        for p, h, d in T.grid(seqlen, 2, 128):
            with T.block("copy0"):
                vp, vh, vd = T.axis.remap("SSS", [p, h, d])
                T.reads(position_map[vp], pages[T.Cast("int64", position_map[vp]) // page_size, 0:2, vh, T.Cast("int64", position_map[vp]) % page_size, vd])
                T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd])
                # Flat slot index -> (page, offset within page).
                position: T.int32 = position_map[vp]
                k_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 0, vh, T.Cast("int64", position) % page_size, vd]
                v_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 1, vh, T.Cast("int64", position) % page_size, vd]

    @T.prim_func
    def tir_kv_cache_transpose_append(var_pages: T.handle, var_k_data: T.handle, var_v_data: T.handle, var_position_map: T.handle):
        # Scatter new K/V tokens into the paged cache (page size fixed at 16):
        #   pos = position_map[t]   (tokens with pos == -1 are skipped)
        #   pages[pos // 16, 0, h, pos % 16, f] = k_data[t, h, f]
        #   pages[pos // 16, 1, h, pos % 16, f] = v_data[t, h, f]
        # Axis 1 of `pages` selects K (0) or V (1).
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        num_pages = T.int64()
        pages = T.match_buffer(var_pages, (num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        ntoken = T.int64(is_size_var=True)
        k_data = T.match_buffer(var_k_data, (ntoken, 2, 128), "float16")
        v_data = T.match_buffer(var_v_data, (ntoken, 2, 128), "float16")
        position_map = T.match_buffer(var_position_map, (ntoken,), "int32", offset_factor=1)
        # with T.block("root"):
        for global_pos, h, f in T.grid(ntoken, 2, 128):
            # -1 marks a token that must not be written to the cache.
            if position_map[global_pos] != -1:
                with T.block("k_transpose_append"):
                    vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
                    T.reads(position_map[vgpos], k_data[vgpos, vh, vf])
                    T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf])
                    position: T.int32 = position_map[vgpos]
                    pages[position // 16, 0, vh, position % 16, vf] = k_data[vgpos, vh, vf]
                with T.block("v_transpose_append"):
                    vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
                    T.reads(position_map[vgpos], v_data[vgpos, vh, vf])
                    T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf])
                    position: T.int32 = position_map[vgpos]
                    pages[position // 16, 1, vh, position % 16, vf] = v_data[vgpos, vh, vf]

    @T.prim_func(private=True)
    def top_p_pivot_cutoff(var_prob: T.handle, var_top_p_arr: T.handle, var_init_pivots: T.handle, var_final_pivot: T.handle, var_final_lsum: T.handle):
        # Search, per row b of `prob`, for a probability threshold (pivot) used by top-p
        # sampling. The search narrows an interval [R, L] with three probe pivots per round.
        # One thread block of 1024 threads handles one row.
        # Inputs:
        #   prob        (B, N) float32
        #   top_p_arr   (B,)   float32, per-row top-p value
        #   init_pivots (B, 3) float32, probe pivots for the first round
        # Outputs (written by thread 0 only):
        #   final_pivot (B,)   the chosen threshold
        #   final_lsum  (B,)   the lsum recorded with it (see the branches below)
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        B, N = T.int32(is_size_var=True), T.int32(is_size_var=True)
        prob = T.match_buffer(var_prob, (B, N))
        top_p_arr = T.match_buffer(var_top_p_arr, (B,))
        init_pivots = T.match_buffer(var_init_pivots, (B, 3))
        final_pivot = T.match_buffer(var_final_pivot, (B,))
        final_lsum = T.match_buffer(var_final_lsum, (B,))
        # with T.block("root"):
        pivot = T.alloc_buffer((3,), scope="local")
        top_p = T.alloc_buffer((1,), scope="local")
        L = T.alloc_buffer((1,), scope="shared")
        R_1 = T.alloc_buffer((1,), scope="shared")
        L_local = T.alloc_buffer((1,), scope="local")
        R_local = T.alloc_buffer((1,), scope="local")
        q = T.alloc_buffer((1,), scope="local")
        lsum = T.alloc_buffer((3,), scope="local")
        lmin_broadcast = T.alloc_buffer((1,), scope="shared")
        lmin_broadcast_local = T.alloc_buffer((1,), scope="local")
        lmin = T.alloc_buffer((3,), scope="local")
        cmin = T.alloc_buffer((3,), "int32", scope="local")
        total_sum = T.alloc_buffer((1,), scope="local")
        it = T.alloc_buffer((1,), "int32", scope="local")
        es_local = T.alloc_buffer((1,), "bool", scope="local")
        es = T.alloc_buffer((1,), "bool", scope="shared")
        find_pivot_local = T.alloc_buffer((1,), "bool", scope="local")
        find_pivot = T.alloc_buffer((1,), "bool", scope="shared")
        total_sum_reduce = T.alloc_buffer((1,), scope="local")
        lsum_reduce = T.alloc_buffer((1,), scope="local")
        lmin_reduce = T.alloc_buffer((1,), scope="local")
        cmin_reduce = T.alloc_buffer((1,), "int32", scope="local")
        for _bx in T.thread_binding(B, thread="blockIdx.x"):
            for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("CTA"):
                    b, tx = T.axis.remap("SS", [_bx, _tx])
                    T.reads(top_p_arr[b], top_p[0], L[0], R_1[0], init_pivots[b, 0:3], L_local[0], R_local[0], find_pivot_local[0], it[0], es_local[0], prob[b, it[0] * 1024 + tx], total_sum[0], q[0], pivot[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], lsum[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], lmin[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], cmin[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], total_sum_reduce[0], es[0], lmin_reduce[0], lmin_broadcast[0], lmin_broadcast_local[0], lsum_reduce[0], cmin_reduce[0], find_pivot[0])
                    T.writes(top_p[0], L[0], R_1[0], find_pivot[0], L_local[0], R_local[0], pivot[0:3], find_pivot_local[0], final_lsum[b], final_pivot[b], lsum[0:3], lmin[0:3], cmin[0:3], total_sum[0], it[0], es_local[0], q[0], total_sum_reduce[0], es[0], lsum_reduce[0], lmin_reduce[0], lmin_broadcast[0], lmin_broadcast_local[0], cmin_reduce[0])
                    top_p[0] = top_p_arr[b]
                    # Thread 0 initialises the shared interval: L = 1 - top_p, R = 1e-7.
                    if tx == 0:
                        L[0] = T.float32(1.0) - top_p[0]
                        R_1[0] = T.float32(9.9999999999999995e-08)
                        find_pivot[0] = T.bool(False)
                    T.tvm_storage_sync("shared")
                    L_local[0] = L[0]
                    R_local[0] = R_1[0]
                    for i in T.unroll(3):
                        pivot[i] = init_pivots[b, i]
                    find_pivot_local[0] = T.bool(False)
                    # Degenerate interval: report pivot 0.0 / lsum 1.0 and skip the search.
                    if L_local[0] - R_local[0] <= T.float32(9.9999999999999995e-08):
                        if tx == 0:
                            final_lsum[b] = T.float32(1.0)
                            final_pivot[b] = T.float32(0.0)
                        find_pivot_local[0] = T.bool(True)
                    while T.tvm_thread_invariant(L_local[0] - R_local[0] > T.float32(9.9999999999999995e-08) and not find_pivot_local[0]):
                        T.tvm_storage_sync("shared")
                        # Reset per-thread partial statistics for the three pivots.
                        for pidx in T.unroll(3):
                            lsum[pidx] = T.float32(0.0)
                            lmin[pidx] = T.float32(340282346638528859811704183484516925440.0)
                            cmin[pidx] = 0
                        total_sum[0] = T.float32(0.0)
                        it[0] = 0
                        es_local[0] = T.bool(False)
                        # Each thread visits columns it * 1024 + tx; columns past N read as 0.0.
                        # For each pivot it accumulates lsum (sum of q >= pivot), lmin (smallest
                        # such q) and cmin (how often lmin occurs).
                        while it[0] < (N + 1024 - 1) // 1024 and not es_local[0]:
                            q[0] = T.if_then_else(it[0] * 1024 + tx < N, prob[b, it[0] * 1024 + tx], T.float32(0.0))
                            total_sum[0] = total_sum[0] + q[0]
                            for pidx in T.unroll(3):
                                if q[0] >= pivot[pidx]:
                                    lsum[pidx] = lsum[pidx] + q[0]
                                    if lmin[pidx] > q[0]:
                                        lmin[pidx] = q[0]
                                        cmin[pidx] = 1
                                    else:
                                        if lmin[pidx] == q[0]:
                                            cmin[pidx] = cmin[pidx] + 1
                            it[0] = it[0] + 1
                            # Every 32 steps, sum total_sum across the block and stop early once
                            # 1 - (sum seen so far) < pivot[2].
                            # NOTE(review): this early exit assumes each row of prob sums to 1 -- confirm.
                            if it[0] % 32 == 0:
                                with T.block("block_cross_thread"):
                                    T.reads(total_sum[0])
                                    T.writes(total_sum_reduce[0])
                                    T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                    T.tvm_thread_allreduce(T.uint32(1), total_sum[0], T.bool(True), total_sum_reduce[0], tx)
                                if tx == 0:
                                    es[0] = T.float32(1.0) - total_sum_reduce[0] < pivot[2]
                                T.tvm_storage_sync("shared")
                                es_local[0] = es[0]
                        T.tvm_storage_sync("shared")
                        # Block-wide reduction of the per-thread statistics, one pivot at a time.
                        for pidx in range(3):
                            with T.block("block_cross_thread"):
                                T.reads(lsum[pidx])
                                T.writes(lsum_reduce[0])
                                T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), lsum[pidx], T.bool(True), lsum_reduce[0], tx)
                            with T.block("block_cross_thread"):
                                T.reads(lmin[pidx])
                                T.writes(lmin_reduce[0])
                                # Identity of a min-reducer must be the largest float32 (the same
                                # value lmin is initialised with); 0.0 would drive any lane that
                                # contributes the identity to a minimum of 0.
                                T.attr(T.comm_reducer(lambda x0, y0: T.min(x0, y0), [T.float32(340282346638528859811704183484516925440.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), lmin[pidx], T.bool(True), lmin_reduce[0], tx)
                            # Broadcast the block-wide minimum through shared memory. Threads whose
                            # local minimum is larger never saw it, so their count is dropped.
                            if tx == 0:
                                lmin_broadcast[0] = lmin_reduce[0]
                            T.tvm_storage_sync("shared")
                            lmin_broadcast_local[0] = lmin_broadcast[0]
                            if lmin[pidx] > lmin_broadcast_local[0]:
                                cmin[pidx] = 0
                            if tx == 0:
                                lsum[pidx] = lsum_reduce[0]
                                lmin[pidx] = lmin_reduce[0]
                            with T.block("block_cross_thread"):
                                T.reads(cmin[pidx])
                                T.writes(cmin_reduce[0])
                                T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [0]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), cmin[pidx], T.bool(True), cmin_reduce[0], tx)
                            if tx == 0:
                                cmin[pidx] = cmin_reduce[0]
                        T.tvm_storage_sync("shared")
                        # Thread 0 classifies the pivots in order:
                        #   lsum >= top_p > lsum - cmin * lmin : pivot found, record it
                        #   lsum - lmin * cmin >= top_p         : raise R to this pivot
                        #   lsum < top_p                        : lower L to this pivot
                        if tx == 0:
                            it[0] = 0
                            while it[0] < 3 and not find_pivot_local[0]:
                                if lsum[it[0]] >= top_p[0] and top_p[0] > lsum[it[0]] - T.Cast("float32", cmin[it[0]]) * lmin[it[0]]:
                                    find_pivot[0] = T.bool(True)
                                    find_pivot_local[0] = T.bool(True)
                                    final_pivot[b] = pivot[it[0]]
                                    final_lsum[b] = lsum[it[0]]
                                else:
                                    if lsum[it[0]] - lmin[it[0]] * T.Cast("float32", cmin[it[0]]) >= top_p[0]:
                                        R_1[0] = pivot[it[0]]
                                        final_lsum[b] = lsum[it[0]]
                                    else:
                                        if lsum[it[0]] < top_p[0]:
                                            L[0] = pivot[it[0]]
                                it[0] = it[0] + 1
                        T.tvm_storage_sync("shared")
                        L_local[0] = L[0]
                        R_local[0] = R_1[0]
                        find_pivot_local[0] = find_pivot[0]
                        # Next round probes L - k * (L - R) / 4 for k = 1, 2, 3.
                        for pidx in T.unroll(3):
                            pivot[pidx] = L[0] - T.Cast("float32", pidx + 1) * (L_local[0] - R_local[0]) / T.float32(4.0)
                    # No exact hit: fall back to the lower bound R.
                    if tx == 0:
                        if not find_pivot_local[0]:
                            final_pivot[b] = R_local[0]
                            # R still equals its initial 1e-7, so the R branch never wrote
                            # final_lsum; use the last round's lsum for pivot 2.
                            if R_local[0] == T.float32(9.9999999999999995e-08):
                                final_lsum[b] = lsum[2]

    @T.prim_func(private=True)
    def top_p_renorm_after_cutoff(var_prob: T.handle, var_final_pivot: T.handle, var_final_lsum: T.handle, var_renorm_prob: T.handle):
        # Apply a per-row cutoff and renormalise:
        #   renorm_prob[b, j] = prob[b, j] / final_lsum[b]   if prob[b, j] >= final_pivot[b]
        #                       0.0                           otherwise
        # (final_pivot / final_lsum are presumably the outputs of the top-p pivot search -- confirm.)
        # Launch: blockIdx.y = row; blockIdx.x has extent 511 // B + 1 with 1024 threads per block.
        # Each thread walks its row with stride ((512 + B - 1) // B) * 1024. For B >= 1,
        # 511 // B + 1 == (512 + B - 1) // B, so the stride equals the total number of x-threads.
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        B, N = T.int32(is_size_var=True), T.int32(is_size_var=True)
        prob = T.match_buffer(var_prob, (B, N))
        final_pivot = T.match_buffer(var_final_pivot, (B,))
        final_lsum = T.match_buffer(var_final_lsum, (B,))
        renorm_prob = T.match_buffer(var_renorm_prob, (B, N))
        # with T.block("root"):
        pivot = T.alloc_buffer((1,), scope="local")
        lsum = T.alloc_buffer((1,), scope="local")
        for _by in T.thread_binding(B, thread="blockIdx.y"):
            for _bx in T.thread_binding(511 // B + 1, thread="blockIdx.x"):
                for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("CTA"):
                        by, bx, tx = T.axis.remap("SSS", [_by, _bx, _tx])
                        T.reads(final_pivot[by], final_lsum[by], prob[by, bx * 1024 + tx:bx * 1024 + tx + (((511 // B * 1024 + N + 1023) // (511 // B * 1024 + 1024) - 1) * (511 // B + 1) * 1024 + 1)], pivot[0], lsum[0])
                        T.writes(pivot[0], lsum[0], renorm_prob[by, bx * 1024 + tx:bx * 1024 + tx + (((511 // B * 1024 + N + 1023) // (511 // B * 1024 + 1024) - 1) * (511 // B + 1) * 1024 + 1)])
                        pivot[0] = final_pivot[by]
                        lsum[0] = final_lsum[by]
                        # Trip count = ceil(N / (x-blocks * 1024)); out-of-range columns are skipped.
                        for i in range((511 // B * 1024 + N + 1023) // (511 // B * 1024 + 1024)):
                            if i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx < N:
                                renorm_prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] = T.if_then_else(prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] >= pivot[0], prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] / lsum[0], T.float32(0.0))

    @T.prim_func
    def tree_attn_paged_kv(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32, tree_order_indptr_handle: T.handle, tree_order_handle: T.handle):
        # Tree-masked attention of query tokens against a paged KV cache.
        # It processes 32-row query tiles against 32-entry KV chunks and keeps
        # an online softmax in base 2: m = running row max, d = running row
        # denominator.
        #
        # Layouts, as fixed by the match_buffer shapes and indexing below:
        #   q, output:  (total_len, 16, 128) float16 -- token, query head, head dim.
        #   pages:      (max_num_pages, 2, 2, 16, 128) float16.
        #               - dim 1 selects K (0, see K_load) or V (1, see V_load).
        #               - dim 2 is indexed by blockIdx.y (`by`).
        #               - dim 3 is the in-page offset (page size 16).
        #               - dim 4 is the head dim.
        #   q_indptr, page_indptr, tree_order_indptr: (batch_size + 1,) offsets;
        #               sequence b owns [ptr[b], ptr[b + 1]).
        #   page_values: page ids, addressed through page_indptr offsets.
        #   length_info[b]: used slots in sequence b's last page, so
        #               kv_chunk_len = (num_pages - 1) * 16 + length_info[b].
        #   lse:        (total_len, 16) float32, written as m + log2(d).
        #   tree_order: (total_tree_order_len, 2) int32 per tree node.
        #
        # Tree mask:
        #   - Only the last tree_len = tree_order_indptr[b+1] - tree_order_indptr[b]
        #     KV positions are masked; every query sees all earlier KV positions.
        #   - Within the tree part, query token t maps to node t + tree_len - qlen
        #     and KV position p maps to node p - (kv_chunk_len - tree_len).
        #   - KV is visible iff tree_order[q_node, 0] lies in
        #     [tree_order[kv_node, 0], tree_order[kv_node, 1]).
        #   - Presumably these are DFS enter/exit indices, making this an
        #     "ancestor-or-self" test (TODO confirm).
        #
        # Heads: query head = by * 8 + r for r in 0..7, and both K and V are
        # read from pages[..., by, ...]. So each of the 2 `by` values serves 8
        # query heads that share one K/V head (grouped-query layout). Each
        # query tile row i encodes (token, head) as
        # ((LH_start + i) // 8, by * 8 + (LH_start + i) % 8).
        #
        # NOTE(review): `_0` and `k_rope_pos_offset` are never read in this body.
        # The assert below pins rotary_mode to 0, so the rotary branch in Q_load
        # (rotary_mode == 1) and therefore `q_rope_position`, `rope_scale` and
        # `rope_theta` are never selected at runtime.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 2, 16, 128), "float16")
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (batch_size,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (total_len, 16))
        tree_order_indptr = T.match_buffer(tree_order_indptr_handle, (batch_size + 1,), "int32", offset_factor=1)
        total_tree_order_len = T.int32(is_size_var=True)
        tree_order = T.match_buffer(tree_order_handle, (total_tree_order_len, 2), "int32", offset_factor=1)
        # with T.block("root"):
        assert rotary_mode == 0, "Inline rotary mode is not supported in tree attention."
        # Grid and thread layout:
        #   - 16 blocks in x sweep query tiles; tile_id advances by 16 at the end.
        #   - 2 blocks in y select the K/V head (`by`).
        #   - 4 x 32 = 128 threads per block.
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            # Scheduling state: current tile and batch, plus
                            # that batch's row and tile counts.
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            # Tile buffers:
                            #   - Q_smem: 32 query rows; K_smem/V_smem: 32 KV entries.
                            #   - S_smem: 32 x 32 score tile.
                            #   - O_local: 32 x 128 float32 output accumulator.
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            # Per-row online-softmax state shared by the block:
                            # running max m, previous max, running denominator d.
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            # Each query token contributes 8 rows (one per query
                            # head under this `by`), split into ceil(rows / 32) tiles.
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                # Convert the global tile_id into (batch_idx, tile
                                # within that batch) by skipping whole batches.
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    # Total KV length: full pages plus the used part of
                                    # the last page; 0 when the sequence has no pages.
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[b_idx], 0)
                                    T.tvm_storage_sync("shared")
                                    # Reset softmax state for the 32 tile rows.
                                    # -50000.0 serves as a finite stand-in for -inf.
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    # Zero the output accumulator.
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_init"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads()
                                                    T.writes(O_local[i, j])
                                                    O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Load the 32 x 128 query tile. Rows past this
                                    # sequence's last query token are zero-filled.
                                    for li_lj_fused_0 in range(8):
                                        for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_lj_fused_3 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 128)
                                                        j = T.axis.spatial(128, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 128)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            # rotary_mode is asserted == 0 above, so this
                                                            # resolves to a plain copy of q.
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Sweep the KV sequence in chunks of 32 entries.
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        # Gather K (pages[..., 0, by, ...]) for this chunk
                                        # through the page table; zero past kv_chunk_len.
                                        for lz_ly_fused_0 in range(8):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 128)
                                                            j = T.axis.spatial(128, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 128)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                K_smem[i, j] = pages[page_no, 0, by, page_offset, j]
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        # Same gather for V (pages[..., 1, by, ...]).
                                        for lz_ly_fused_0 in range(8):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 128)
                                                            j = T.axis.spatial(128, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 128)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        # S = Q @ K^T * attn_score_scaling_factor * 0.1275...
                                        # 0.12751743082459868 == log2(e) / sqrt(128), so the
                                        # exp2 calls below equal exp(score / sqrt(128)).
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(2, 4):
                                                        with T.block("S_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_init)
                                                            T.reads()
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, li_1, lj_1, lk_1 in T.grid(16, 2, 4, 8):
                                                        with T.block("S_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                            k = T.axis.reduce(128, lk_0 * 8 + lk_1)
                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        # Publish the score tile to shared memory for the
                                        # per-row softmax steps.
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1, lj_1 in T.grid(2, 4):
                                                    with T.block("S_store"):
                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                        j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                        T.reads(S_local[i, j])
                                                        T.writes(S_smem[i, j])
                                                        S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        # Softmax step 1, one thread per tile row (row < 32):
                                        #   - new max = max(old max, unmasked scores in this chunk);
                                        #   - rescale the old denominator by exp2(m_prev - m_new).
                                        # Unmasked means within kv_chunk_len and either before the
                                        # tree part or passing the tree_order interval test.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], tree_order_indptr[b_idx:b_idx + 2], tree_order[T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]):T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    # Query token index within this sequence.
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) or tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 0] and tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 1]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        # Softmax step 2: replace scores with exp2(s - m_new).
                                        # Masked entries use exp2(-50000 - m_new), which is
                                        # effectively zero. Same mask as step 1.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], tree_order_indptr[b_idx:b_idx + 2], tree_order[T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]):T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) or tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 0] and tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 1]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        # Softmax step 3: add this chunk's probabilities to the
                                        # denominator. Then publish m, d and the previous max so
                                        # the O update can rescale.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        # O = O * exp2(m_prev - m) + P @ V, where P is the
                                        # exponentiated score tile in S_smem.
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(4, 8):
                                                        with T.block("O_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_init)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1, li_1, lj_1 in T.grid(4, 8, 4, 8):
                                                        with T.block("O_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                            k = T.axis.reduce(32, lk_0 * 8 + lk_1)
                                                            T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j])
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
                                    # Normalize by the final denominator and write the float16
                                    # output for rows that map to a real query token.
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_store"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                    T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    # Base-2 log-sum-exp per (token, head): m + log2(d).
                                    # The T.where limits this to the 32 tile rows.
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    # Advance by the number of x-blocks (16) to this
                                    # block's next tile.
                                    tile_id[0] = tile_id[0] + 16

    @R.function
    def alloc_embedding_tensor() -> R.Tensor((2048, 2048), dtype="float16"):
        # Allocate and return a (2048, 2048) float16 tensor; no values are
        # written here.
        # The R.builtin.alloc_tensor arguments, assuming the Relax signature
        # (shape, dtype, runtime_device_index, storage_scope), are:
        #   - runtime device index 0;
        #   - "global" storage scope.
        # The 2048 width matches the hidden size used elsewhere in this module
        # (e.g. the (batch, 1, 2048) inputs); presumably this buffer receives
        # token embeddings (TODO confirm against the caller).
        R.func_attr({"relax.memory_plan_dynamic_func_output": True})
        gv: R.Tensor((2048, 2048), dtype="float16") = R.builtin.alloc_tensor(R.shape([2048, 2048]), R.dtype("float16"), R.prim_value(0), R.str("global"))
        return gv

    @R.function
    def argsort_probs(probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32")) -> R.Tuple(R.Tensor(("batch_size", "vocab_size"), dtype="float32"), R.Tensor(("batch_size", "vocab_size"), dtype="int32")):
        # Sort each row of `probs` in descending order along the last axis.
        # Returns a tuple:
        #   - element 0: the float32 output of take_sorted_probs(probs, indices);
        #   - element 1: the int32 sort indices.
        # take_sorted_probs is defined elsewhere in this module; presumably it
        # gathers probs[b, indices[b, k]] (TODO confirm).
        batch_size = T.int64(is_size_var=True)
        vocab_size = T.int64(is_size_var=True)
        # tir_var_upper_bound caps batch_size at 80. The num_positions and
        # num_samples entries name variables that do not appear in this
        # function's signature.
        R.func_attr({"relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "num_positions": 480, "num_samples": 80}})
        cls = Module
        with R.dataflow():
            # Descending argsort per row (axis=-1), returned as int32 indices.
            lv1: R.Tensor((batch_size, vocab_size), dtype="int32") = R.argsort(probs, axis=-1, descending=True, dtype="int32")
            lv2 = R.call_tir(cls.take_sorted_probs, (probs, lv1), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="float32"))
            gv1: R.Tuple(R.Tensor((batch_size, vocab_size), dtype="float32"), R.Tensor((batch_size, vocab_size), dtype="int32")) = lv2, lv1
            R.output(gv1)
        return gv1

    @R.function
    def batch_decode(input_embeds: R.Tensor(("batch_size", 1, 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), 
R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 
64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), 
R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), 
R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), 
dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"))) -> R.Tuple(R.Tensor(("batch_size", 1, 151936), dtype="float32"), R.Object):
        batch_size = T.int64()
        R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "relax.rewrite_cuda_graph.capture_symbolic_vars": ["batch_size"], "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 32768}})
        cls = Module
        with R.dataflow():
            model_embed_tokens_q_weight4: R.Tensor((151936, 256), dtype="uint32") = packed_params[0]
            model_embed_tokens_q_scale4: R.Tensor((151936, 64), dtype="float16") = packed_params[1]
            model_layers_0_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[2]
            model_layers_0_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[3]
            model_layers_0_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[4]
            model_layers_0_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[5]
            model_layers_0_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[6]
            model_layers_0_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[7]
            model_layers_0_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[8]
            model_layers_0_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[9]
            model_layers_0_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[10]
            model_layers_0_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[11]
            model_layers_0_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[12]
            model_layers_1_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[13]
            model_layers_1_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[14]
            model_layers_1_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[15]
            model_layers_1_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[16]
            model_layers_1_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[17]
            model_layers_1_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[18]
            model_layers_1_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[19]
            model_layers_1_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[20]
            model_layers_1_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[21]
            model_layers_1_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[22]
            model_layers_1_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[23]
            model_layers_2_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[24]
            model_layers_2_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[25]
            model_layers_2_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[26]
            model_layers_2_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[27]
            model_layers_2_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[28]
            model_layers_2_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[29]
            model_layers_2_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[30]
            model_layers_2_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[31]
            model_layers_2_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[32]
            model_layers_2_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[33]
            model_layers_2_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[34]
            model_layers_3_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[35]
            model_layers_3_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[36]
            model_layers_3_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[37]
            model_layers_3_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[38]
            model_layers_3_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[39]
            model_layers_3_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[40]
            model_layers_3_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[41]
            model_layers_3_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[42]
            model_layers_3_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[43]
            model_layers_3_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[44]
            model_layers_3_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[45]
            model_layers_4_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[46]
            model_layers_4_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[47]
            model_layers_4_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[48]
            model_layers_4_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[49]
            model_layers_4_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[50]
            model_layers_4_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[51]
            model_layers_4_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[52]
            model_layers_4_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[53]
            model_layers_4_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[54]
            model_layers_4_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[55]
            model_layers_4_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[56]
            model_layers_5_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[57]
            model_layers_5_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[58]
            model_layers_5_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[59]
            model_layers_5_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[60]
            model_layers_5_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[61]
            model_layers_5_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[62]
            model_layers_5_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[63]
            model_layers_5_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[64]
            model_layers_5_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[65]
            model_layers_5_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[66]
            model_layers_5_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[67]
            model_layers_6_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[68]
            model_layers_6_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[69]
            model_layers_6_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[70]
            model_layers_6_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[71]
            model_layers_6_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[72]
            model_layers_6_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[73]
            model_layers_6_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[74]
            model_layers_6_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[75]
            model_layers_6_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[76]
            model_layers_6_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[77]
            model_layers_6_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[78]
            model_layers_7_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[79]
            model_layers_7_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[80]
            model_layers_7_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[81]
            model_layers_7_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[82]
            model_layers_7_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[83]
            model_layers_7_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[84]
            model_layers_7_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[85]
            model_layers_7_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[86]
            model_layers_7_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[87]
            model_layers_7_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[88]
            model_layers_7_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[89]
            model_layers_8_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[90]
            model_layers_8_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[91]
            model_layers_8_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[92]
            model_layers_8_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[93]
            model_layers_8_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[94]
            model_layers_8_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[95]
            model_layers_8_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[96]
            model_layers_8_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[97]
            model_layers_8_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[98]
            model_layers_8_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[99]
            model_layers_8_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[100]
            model_layers_9_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[101]
            model_layers_9_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[102]
            model_layers_9_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[103]
            model_layers_9_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[104]
            model_layers_9_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[105]
            model_layers_9_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[106]
            model_layers_9_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[107]
            model_layers_9_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[108]
            model_layers_9_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[109]
            model_layers_9_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[110]
            model_layers_9_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[111]
            model_layers_10_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[112]
            model_layers_10_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[113]
            model_layers_10_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[114]
            model_layers_10_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[115]
            model_layers_10_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[116]
            model_layers_10_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[117]
            model_layers_10_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[118]
            model_layers_10_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[119]
            model_layers_10_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[120]
            model_layers_10_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[121]
            model_layers_10_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[122]
            model_layers_11_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[123]
            model_layers_11_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[124]
            model_layers_11_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[125]
            model_layers_11_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[126]
            model_layers_11_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[127]
            model_layers_11_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[128]
            model_layers_11_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[129]
            model_layers_11_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[130]
            model_layers_11_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[131]
            model_layers_11_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[132]
            model_layers_11_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[133]
            model_layers_12_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[134]
            model_layers_12_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[135]
            model_layers_12_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[136]
            model_layers_12_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[137]
            model_layers_12_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[138]
            model_layers_12_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[139]
            model_layers_12_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[140]
            model_layers_12_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[141]
            model_layers_12_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[142]
            model_layers_12_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[143]
            model_layers_12_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[144]
            model_layers_13_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[145]
            model_layers_13_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[146]
            model_layers_13_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[147]
            model_layers_13_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[148]
            model_layers_13_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[149]
            model_layers_13_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[150]
            model_layers_13_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[151]
            model_layers_13_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[152]
            model_layers_13_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[153]
            model_layers_13_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[154]
            model_layers_13_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[155]
            model_layers_14_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[156]
            model_layers_14_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[157]
            model_layers_14_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[158]
            model_layers_14_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[159]
            model_layers_14_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[160]
            model_layers_14_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[161]
            model_layers_14_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[162]
            model_layers_14_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[163]
            model_layers_14_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[164]
            model_layers_14_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[165]
            model_layers_14_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[166]
            model_layers_15_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[167]
            model_layers_15_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[168]
            model_layers_15_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[169]
            model_layers_15_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[170]
            model_layers_15_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[171]
            model_layers_15_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[172]
            model_layers_15_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[173]
            model_layers_15_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[174]
            model_layers_15_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[175]
            model_layers_15_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[176]
            model_layers_15_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[177]
            model_layers_16_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[178]
            model_layers_16_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[179]
            model_layers_16_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[180]
            model_layers_16_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[181]
            model_layers_16_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[182]
            model_layers_16_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[183]
            model_layers_16_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[184]
            model_layers_16_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[185]
            model_layers_16_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[186]
            model_layers_16_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[187]
            model_layers_16_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[188]
            model_layers_17_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[189]
            model_layers_17_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[190]
            model_layers_17_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[191]
            model_layers_17_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[192]
            model_layers_17_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[193]
            model_layers_17_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[194]
            model_layers_17_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[195]
            model_layers_17_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[196]
            model_layers_17_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[197]
            model_layers_17_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[198]
            model_layers_17_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[199]
            model_layers_18_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[200]
            model_layers_18_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[201]
            model_layers_18_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[202]
            model_layers_18_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[203]
            model_layers_18_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[204]
            model_layers_18_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[205]
            model_layers_18_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[206]
            model_layers_18_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[207]
            model_layers_18_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[208]
            model_layers_18_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[209]
            model_layers_18_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[210]
            model_layers_19_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[211]
            model_layers_19_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[212]
            model_layers_19_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[213]
            model_layers_19_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[214]
            model_layers_19_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[215]
            model_layers_19_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[216]
            model_layers_19_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[217]
            model_layers_19_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[218]
            model_layers_19_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[219]
            model_layers_19_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[220]
            model_layers_19_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[221]
            model_layers_20_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[222]
            model_layers_20_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[223]
            model_layers_20_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[224]
            model_layers_20_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[225]
            model_layers_20_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[226]
            model_layers_20_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[227]
            model_layers_20_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[228]
            model_layers_20_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[229]
            model_layers_20_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[230]
            model_layers_20_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[231]
            model_layers_20_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[232]
            model_layers_21_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[233]
            model_layers_21_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[234]
            model_layers_21_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[235]
            model_layers_21_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[236]
            model_layers_21_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[237]
            model_layers_21_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[238]
            model_layers_21_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[239]
            model_layers_21_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[240]
            model_layers_21_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[241]
            model_layers_21_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[242]
            model_layers_21_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[243]
            model_layers_22_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[244]
            model_layers_22_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[245]
            model_layers_22_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[246]
            model_layers_22_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[247]
            model_layers_22_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[248]
            model_layers_22_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[249]
            model_layers_22_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[250]
            model_layers_22_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[251]
            model_layers_22_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[252]
            model_layers_22_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[253]
            model_layers_22_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[254]
            model_layers_23_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[255]
            model_layers_23_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[256]
            model_layers_23_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[257]
            model_layers_23_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[258]
            model_layers_23_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[259]
            model_layers_23_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[260]
            model_layers_23_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[261]
            model_layers_23_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[262]
            model_layers_23_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[263]
            model_layers_23_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[264]
            model_layers_23_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[265]
            model_layers_24_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[266]
            model_layers_24_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[267]
            model_layers_24_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[268]
            model_layers_24_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[269]
            model_layers_24_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[270]
            model_layers_24_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[271]
            model_layers_24_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[272]
            model_layers_24_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[273]
            model_layers_24_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[274]
            model_layers_24_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[275]
            model_layers_24_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[276]
            model_layers_25_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[277]
            model_layers_25_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[278]
            model_layers_25_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[279]
            model_layers_25_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[280]
            model_layers_25_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[281]
            model_layers_25_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[282]
            model_layers_25_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[283]
            model_layers_25_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[284]
            model_layers_25_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[285]
            model_layers_25_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[286]
            model_layers_25_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[287]
            model_layers_26_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[288]
            model_layers_26_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[289]
            model_layers_26_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[290]
            model_layers_26_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[291]
            model_layers_26_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[292]
            model_layers_26_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[293]
            model_layers_26_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[294]
            model_layers_26_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[295]
            model_layers_26_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[296]
            model_layers_26_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[297]
            model_layers_26_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[298]
            model_layers_27_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[299]
            model_layers_27_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[300]
            model_layers_27_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[301]
            model_layers_27_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[302]
            model_layers_27_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[303]
            model_layers_27_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[304]
            model_layers_27_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[305]
            model_layers_27_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[306]
            model_layers_27_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[307]
            model_layers_27_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[308]
            model_layers_27_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[309]
            model_layers_28_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[310]
            model_layers_28_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[311]
            model_layers_28_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[312]
            model_layers_28_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[313]
            model_layers_28_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[314]
            model_layers_28_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[315]
            model_layers_28_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[316]
            model_layers_28_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[317]
            model_layers_28_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[318]
            model_layers_28_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[319]
            model_layers_28_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[320]
            model_layers_29_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[321]
            model_layers_29_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[322]
            model_layers_29_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[323]
            model_layers_29_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[324]
            model_layers_29_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[325]
            model_layers_29_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[326]
            model_layers_29_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[327]
            model_layers_29_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[328]
            model_layers_29_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[329]
            model_layers_29_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[330]
            model_layers_29_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[331]
            model_layers_30_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[332]
            model_layers_30_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[333]
            model_layers_30_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[334]
            model_layers_30_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[335]
            model_layers_30_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[336]
            model_layers_30_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[337]
            model_layers_30_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[338]
            model_layers_30_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[339]
            model_layers_30_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[340]
            model_layers_30_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[341]
            model_layers_30_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[342]
            model_layers_31_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[343]
            model_layers_31_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[344]
            model_layers_31_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[345]
            model_layers_31_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[346]
            model_layers_31_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[347]
            model_layers_31_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[348]
            model_layers_31_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[349]
            model_layers_31_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[350]
            model_layers_31_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[351]
            model_layers_31_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[352]
            model_layers_31_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[353]
            model_layers_32_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[354]
            model_layers_32_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[355]
            model_layers_32_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[356]
            model_layers_32_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[357]
            model_layers_32_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[358]
            model_layers_32_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[359]
            model_layers_32_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[360]
            model_layers_32_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[361]
            model_layers_32_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[362]
            model_layers_32_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[363]
            model_layers_32_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[364]
            model_layers_33_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[365]
            model_layers_33_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[366]
            model_layers_33_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[367]
            model_layers_33_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[368]
            model_layers_33_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[369]
            model_layers_33_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[370]
            model_layers_33_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[371]
            model_layers_33_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[372]
            model_layers_33_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[373]
            model_layers_33_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[374]
            model_layers_33_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[375]
            model_layers_34_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[376]
            model_layers_34_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[377]
            model_layers_34_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[378]
            model_layers_34_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[379]
            model_layers_34_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[380]
            model_layers_34_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[381]
            model_layers_34_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[382]
            model_layers_34_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[383]
            model_layers_34_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[384]
            model_layers_34_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[385]
            model_layers_34_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[386]
            model_layers_35_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[387]
            model_layers_35_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[388]
            model_layers_35_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[389]
            model_layers_35_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[390]
            model_layers_35_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[391]
            model_layers_35_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[392]
            model_layers_35_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[393]
            model_layers_35_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[394]
            model_layers_35_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[395]
            model_layers_35_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[396]
            model_layers_35_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[397]
            model_norm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[398]
            rms_norm219: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(input_embeds, model_layers_0_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv545 = R.call_tir(cls.dequantize1, (model_layers_0_self_attn_c_attn_q_weight4, model_layers_0_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv = R.call_tir(cls.NT_matmul, (rms_norm219, lv545), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add324: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv, model_layers_0_self_attn_c_attn_bias4)
            reshape432: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add324, R.shape([batch_size, 1, 20, 128]))
            reshape433: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape432, R.shape([batch_size, 20, 128]))
            lv546 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape433), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape434: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv546, R.shape([batch_size, 1, 16, 128]))
            reshape435: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape434, R.shape([batch_size, 1, 2048]))
            lv547 = R.call_tir(cls.dequantize2, (model_layers_0_self_attn_o_proj_q_weight4, model_layers_0_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv1 = R.call_tir(cls.NT_matmul1, (reshape435, lv547), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv_1 = R.call_tir(cls.fuse_add_norm_decode, (lv1, input_embeds, model_layers_0_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv1_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv_1[1]
            rms_norm220: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv_1[0]
            lv548 = R.call_tir(cls.dequantize3, (model_layers_0_mlp_gate_up_proj_q_weight4, model_layers_0_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv2 = R.call_tir(cls.NT_matmul2, (rms_norm220, lv548), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split108: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv2, indices_or_sections=2, axis=-1)
            split_0108: R.Tensor((batch_size, 1, 11008), dtype="float16") = split108[0]
            split_1108: R.Tensor((batch_size, 1, 11008), dtype="float16") = split108[1]
            silu108: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0108)
            mul108: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu108, split_1108)
            lv549 = R.call_tir(cls.dequantize4, (model_layers_0_mlp_down_proj_q_weight4, model_layers_0_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv3 = R.call_tir(cls.NT_matmul3, (mul108, lv549), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv2_1 = R.call_tir(cls.fuse_add_norm_decode, (lv3, lv1_1, model_layers_1_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv3_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv2_1[1]
            rms_norm221: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv2_1[0]
            lv550 = R.call_tir(cls.dequantize1, (model_layers_1_self_attn_c_attn_q_weight4, model_layers_1_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv4 = R.call_tir(cls.NT_matmul, (rms_norm221, lv550), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add327: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv4, model_layers_1_self_attn_c_attn_bias4)
            reshape436: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add327, R.shape([batch_size, 1, 20, 128]))
            reshape437: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape436, R.shape([batch_size, 20, 128]))
            lv551 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape437), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape438: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv551, R.shape([batch_size, 1, 16, 128]))
            reshape439: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape438, R.shape([batch_size, 1, 2048]))
            lv552 = R.call_tir(cls.dequantize2, (model_layers_1_self_attn_o_proj_q_weight4, model_layers_1_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv5 = R.call_tir(cls.NT_matmul1, (reshape439, lv552), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv4_1 = R.call_tir(cls.fuse_add_norm_decode, (lv5, lv3_1, model_layers_1_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv5_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv4_1[1]
            rms_norm222: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv4_1[0]
            lv553 = R.call_tir(cls.dequantize3, (model_layers_1_mlp_gate_up_proj_q_weight4, model_layers_1_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv6 = R.call_tir(cls.NT_matmul2, (rms_norm222, lv553), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split109: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv6, indices_or_sections=2, axis=-1)
            split_0109: R.Tensor((batch_size, 1, 11008), dtype="float16") = split109[0]
            split_1109: R.Tensor((batch_size, 1, 11008), dtype="float16") = split109[1]
            silu109: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0109)
            mul109: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu109, split_1109)
            lv554 = R.call_tir(cls.dequantize4, (model_layers_1_mlp_down_proj_q_weight4, model_layers_1_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv7 = R.call_tir(cls.NT_matmul3, (mul109, lv554), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv6_1 = R.call_tir(cls.fuse_add_norm_decode, (lv7, lv5_1, model_layers_2_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv7_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv6_1[1]
            rms_norm223: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv6_1[0]
            lv555 = R.call_tir(cls.dequantize1, (model_layers_2_self_attn_c_attn_q_weight4, model_layers_2_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv8 = R.call_tir(cls.NT_matmul, (rms_norm223, lv555), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add330: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv8, model_layers_2_self_attn_c_attn_bias4)
            reshape440: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add330, R.shape([batch_size, 1, 20, 128]))
            reshape441: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape440, R.shape([batch_size, 20, 128]))
            lv556 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape441), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape442: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv556, R.shape([batch_size, 1, 16, 128]))
            reshape443: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape442, R.shape([batch_size, 1, 2048]))
            lv557 = R.call_tir(cls.dequantize2, (model_layers_2_self_attn_o_proj_q_weight4, model_layers_2_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv9 = R.call_tir(cls.NT_matmul1, (reshape443, lv557), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv8_1 = R.call_tir(cls.fuse_add_norm_decode, (lv9, lv7_1, model_layers_2_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv9_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv8_1[1]
            rms_norm224: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv8_1[0]
            lv558 = R.call_tir(cls.dequantize3, (model_layers_2_mlp_gate_up_proj_q_weight4, model_layers_2_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv10 = R.call_tir(cls.NT_matmul2, (rms_norm224, lv558), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split110: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv10, indices_or_sections=2, axis=-1)
            split_0110: R.Tensor((batch_size, 1, 11008), dtype="float16") = split110[0]
            split_1110: R.Tensor((batch_size, 1, 11008), dtype="float16") = split110[1]
            silu110: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0110)
            mul110: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu110, split_1110)
            lv559 = R.call_tir(cls.dequantize4, (model_layers_2_mlp_down_proj_q_weight4, model_layers_2_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv11 = R.call_tir(cls.NT_matmul3, (mul110, lv559), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv10_1 = R.call_tir(cls.fuse_add_norm_decode, (lv11, lv9_1, model_layers_3_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv11_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv10_1[1]
            rms_norm225: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv10_1[0]
            lv560 = R.call_tir(cls.dequantize1, (model_layers_3_self_attn_c_attn_q_weight4, model_layers_3_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv12 = R.call_tir(cls.NT_matmul, (rms_norm225, lv560), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add333: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv12, model_layers_3_self_attn_c_attn_bias4)
            reshape444: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add333, R.shape([batch_size, 1, 20, 128]))
            reshape445: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape444, R.shape([batch_size, 20, 128]))
            lv561 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape445), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape446: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv561, R.shape([batch_size, 1, 16, 128]))
            reshape447: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape446, R.shape([batch_size, 1, 2048]))
            lv562 = R.call_tir(cls.dequantize2, (model_layers_3_self_attn_o_proj_q_weight4, model_layers_3_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv13 = R.call_tir(cls.NT_matmul1, (reshape447, lv562), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv12_1 = R.call_tir(cls.fuse_add_norm_decode, (lv13, lv11_1, model_layers_3_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv13_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv12_1[1]
            rms_norm226: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv12_1[0]
            lv563 = R.call_tir(cls.dequantize3, (model_layers_3_mlp_gate_up_proj_q_weight4, model_layers_3_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv14 = R.call_tir(cls.NT_matmul2, (rms_norm226, lv563), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split111: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv14, indices_or_sections=2, axis=-1)
            split_0111: R.Tensor((batch_size, 1, 11008), dtype="float16") = split111[0]
            split_1111: R.Tensor((batch_size, 1, 11008), dtype="float16") = split111[1]
            silu111: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0111)
            mul111: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu111, split_1111)
            lv564 = R.call_tir(cls.dequantize4, (model_layers_3_mlp_down_proj_q_weight4, model_layers_3_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv15 = R.call_tir(cls.NT_matmul3, (mul111, lv564), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv14_1 = R.call_tir(cls.fuse_add_norm_decode, (lv15, lv13_1, model_layers_4_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv15_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv14_1[1]
            rms_norm227: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv14_1[0]
            lv565 = R.call_tir(cls.dequantize1, (model_layers_4_self_attn_c_attn_q_weight4, model_layers_4_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv16 = R.call_tir(cls.NT_matmul, (rms_norm227, lv565), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add336: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv16, model_layers_4_self_attn_c_attn_bias4)
            reshape448: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add336, R.shape([batch_size, 1, 20, 128]))
            reshape449: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape448, R.shape([batch_size, 20, 128]))
            lv566 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape449), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape450: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv566, R.shape([batch_size, 1, 16, 128]))
            reshape451: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape450, R.shape([batch_size, 1, 2048]))
            lv567 = R.call_tir(cls.dequantize2, (model_layers_4_self_attn_o_proj_q_weight4, model_layers_4_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv17 = R.call_tir(cls.NT_matmul1, (reshape451, lv567), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv16_1 = R.call_tir(cls.fuse_add_norm_decode, (lv17, lv15_1, model_layers_4_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv17_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv16_1[1]
            rms_norm228: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv16_1[0]
            lv568 = R.call_tir(cls.dequantize3, (model_layers_4_mlp_gate_up_proj_q_weight4, model_layers_4_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv18 = R.call_tir(cls.NT_matmul2, (rms_norm228, lv568), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split112: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv18, indices_or_sections=2, axis=-1)
            split_0112: R.Tensor((batch_size, 1, 11008), dtype="float16") = split112[0]
            split_1112: R.Tensor((batch_size, 1, 11008), dtype="float16") = split112[1]
            silu112: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0112)
            mul112: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu112, split_1112)
            lv569 = R.call_tir(cls.dequantize4, (model_layers_4_mlp_down_proj_q_weight4, model_layers_4_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv19 = R.call_tir(cls.NT_matmul3, (mul112, lv569), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv18_1 = R.call_tir(cls.fuse_add_norm_decode, (lv19, lv17_1, model_layers_5_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv19_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv18_1[1]
            rms_norm229: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv18_1[0]
            lv570 = R.call_tir(cls.dequantize1, (model_layers_5_self_attn_c_attn_q_weight4, model_layers_5_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv20 = R.call_tir(cls.NT_matmul, (rms_norm229, lv570), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add339: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv20, model_layers_5_self_attn_c_attn_bias4)
            reshape452: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add339, R.shape([batch_size, 1, 20, 128]))
            reshape453: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape452, R.shape([batch_size, 20, 128]))
            lv571 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape453), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape454: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv571, R.shape([batch_size, 1, 16, 128]))
            reshape455: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape454, R.shape([batch_size, 1, 2048]))
            lv572 = R.call_tir(cls.dequantize2, (model_layers_5_self_attn_o_proj_q_weight4, model_layers_5_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv21 = R.call_tir(cls.NT_matmul1, (reshape455, lv572), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv20_1 = R.call_tir(cls.fuse_add_norm_decode, (lv21, lv19_1, model_layers_5_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv21_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv20_1[1]
            rms_norm230: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv20_1[0]
            lv573 = R.call_tir(cls.dequantize3, (model_layers_5_mlp_gate_up_proj_q_weight4, model_layers_5_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv22 = R.call_tir(cls.NT_matmul2, (rms_norm230, lv573), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split113: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv22, indices_or_sections=2, axis=-1)
            split_0113: R.Tensor((batch_size, 1, 11008), dtype="float16") = split113[0]
            split_1113: R.Tensor((batch_size, 1, 11008), dtype="float16") = split113[1]
            silu113: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0113)
            mul113: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu113, split_1113)
            lv574 = R.call_tir(cls.dequantize4, (model_layers_5_mlp_down_proj_q_weight4, model_layers_5_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv23 = R.call_tir(cls.NT_matmul3, (mul113, lv574), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv22_1 = R.call_tir(cls.fuse_add_norm_decode, (lv23, lv21_1, model_layers_6_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv23_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv22_1[1]
            rms_norm231: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv22_1[0]
            lv575 = R.call_tir(cls.dequantize1, (model_layers_6_self_attn_c_attn_q_weight4, model_layers_6_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv24 = R.call_tir(cls.NT_matmul, (rms_norm231, lv575), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add342: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv24, model_layers_6_self_attn_c_attn_bias4)
            reshape456: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add342, R.shape([batch_size, 1, 20, 128]))
            reshape457: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape456, R.shape([batch_size, 20, 128]))
            lv576 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape457), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape458: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv576, R.shape([batch_size, 1, 16, 128]))
            reshape459: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape458, R.shape([batch_size, 1, 2048]))
            lv577 = R.call_tir(cls.dequantize2, (model_layers_6_self_attn_o_proj_q_weight4, model_layers_6_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv25 = R.call_tir(cls.NT_matmul1, (reshape459, lv577), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv24_1 = R.call_tir(cls.fuse_add_norm_decode, (lv25, lv23_1, model_layers_6_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv25_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv24_1[1]
            rms_norm232: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv24_1[0]
            lv578 = R.call_tir(cls.dequantize3, (model_layers_6_mlp_gate_up_proj_q_weight4, model_layers_6_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv26 = R.call_tir(cls.NT_matmul2, (rms_norm232, lv578), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split114: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv26, indices_or_sections=2, axis=-1)
            split_0114: R.Tensor((batch_size, 1, 11008), dtype="float16") = split114[0]
            split_1114: R.Tensor((batch_size, 1, 11008), dtype="float16") = split114[1]
            silu114: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0114)
            mul114: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu114, split_1114)
            lv579 = R.call_tir(cls.dequantize4, (model_layers_6_mlp_down_proj_q_weight4, model_layers_6_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv27 = R.call_tir(cls.NT_matmul3, (mul114, lv579), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv26_1 = R.call_tir(cls.fuse_add_norm_decode, (lv27, lv25_1, model_layers_7_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv27_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv26_1[1]
            rms_norm233: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv26_1[0]
            lv580 = R.call_tir(cls.dequantize1, (model_layers_7_self_attn_c_attn_q_weight4, model_layers_7_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv28 = R.call_tir(cls.NT_matmul, (rms_norm233, lv580), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add345: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv28, model_layers_7_self_attn_c_attn_bias4)
            reshape460: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add345, R.shape([batch_size, 1, 20, 128]))
            reshape461: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape460, R.shape([batch_size, 20, 128]))
            lv581 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape461), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape462: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv581, R.shape([batch_size, 1, 16, 128]))
            reshape463: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape462, R.shape([batch_size, 1, 2048]))
            lv582 = R.call_tir(cls.dequantize2, (model_layers_7_self_attn_o_proj_q_weight4, model_layers_7_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv29 = R.call_tir(cls.NT_matmul1, (reshape463, lv582), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv28_1 = R.call_tir(cls.fuse_add_norm_decode, (lv29, lv27_1, model_layers_7_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv29_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv28_1[1]
            rms_norm234: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv28_1[0]
            lv583 = R.call_tir(cls.dequantize3, (model_layers_7_mlp_gate_up_proj_q_weight4, model_layers_7_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv30 = R.call_tir(cls.NT_matmul2, (rms_norm234, lv583), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split115: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv30, indices_or_sections=2, axis=-1)
            split_0115: R.Tensor((batch_size, 1, 11008), dtype="float16") = split115[0]
            split_1115: R.Tensor((batch_size, 1, 11008), dtype="float16") = split115[1]
            silu115: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0115)
            mul115: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu115, split_1115)
            lv584 = R.call_tir(cls.dequantize4, (model_layers_7_mlp_down_proj_q_weight4, model_layers_7_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv31 = R.call_tir(cls.NT_matmul3, (mul115, lv584), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv30_1 = R.call_tir(cls.fuse_add_norm_decode, (lv31, lv29_1, model_layers_8_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv31_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv30_1[1]
            rms_norm235: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv30_1[0]
            lv585 = R.call_tir(cls.dequantize1, (model_layers_8_self_attn_c_attn_q_weight4, model_layers_8_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv32 = R.call_tir(cls.NT_matmul, (rms_norm235, lv585), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add348: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv32, model_layers_8_self_attn_c_attn_bias4)
            reshape464: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add348, R.shape([batch_size, 1, 20, 128]))
            reshape465: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape464, R.shape([batch_size, 20, 128]))
            lv586 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape465), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape466: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv586, R.shape([batch_size, 1, 16, 128]))
            reshape467: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape466, R.shape([batch_size, 1, 2048]))
            lv587 = R.call_tir(cls.dequantize2, (model_layers_8_self_attn_o_proj_q_weight4, model_layers_8_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv33 = R.call_tir(cls.NT_matmul1, (reshape467, lv587), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv32_1 = R.call_tir(cls.fuse_add_norm_decode, (lv33, lv31_1, model_layers_8_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv33_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv32_1[1]
            rms_norm236: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv32_1[0]
            lv588 = R.call_tir(cls.dequantize3, (model_layers_8_mlp_gate_up_proj_q_weight4, model_layers_8_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv34 = R.call_tir(cls.NT_matmul2, (rms_norm236, lv588), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split116: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv34, indices_or_sections=2, axis=-1)
            split_0116: R.Tensor((batch_size, 1, 11008), dtype="float16") = split116[0]
            split_1116: R.Tensor((batch_size, 1, 11008), dtype="float16") = split116[1]
            silu116: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0116)
            mul116: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu116, split_1116)
            lv589 = R.call_tir(cls.dequantize4, (model_layers_8_mlp_down_proj_q_weight4, model_layers_8_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv35 = R.call_tir(cls.NT_matmul3, (mul116, lv589), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv34_1 = R.call_tir(cls.fuse_add_norm_decode, (lv35, lv33_1, model_layers_9_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv35_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv34_1[1]
            rms_norm237: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv34_1[0]
            lv590 = R.call_tir(cls.dequantize1, (model_layers_9_self_attn_c_attn_q_weight4, model_layers_9_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv36 = R.call_tir(cls.NT_matmul, (rms_norm237, lv590), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add351: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv36, model_layers_9_self_attn_c_attn_bias4)
            reshape468: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add351, R.shape([batch_size, 1, 20, 128]))
            reshape469: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape468, R.shape([batch_size, 20, 128]))
            lv591 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape469), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape470: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv591, R.shape([batch_size, 1, 16, 128]))
            reshape471: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape470, R.shape([batch_size, 1, 2048]))
            lv592 = R.call_tir(cls.dequantize2, (model_layers_9_self_attn_o_proj_q_weight4, model_layers_9_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv37 = R.call_tir(cls.NT_matmul1, (reshape471, lv592), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv36_1 = R.call_tir(cls.fuse_add_norm_decode, (lv37, lv35_1, model_layers_9_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv37_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv36_1[1]
            rms_norm238: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv36_1[0]
            lv593 = R.call_tir(cls.dequantize3, (model_layers_9_mlp_gate_up_proj_q_weight4, model_layers_9_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv38 = R.call_tir(cls.NT_matmul2, (rms_norm238, lv593), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split117: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv38, indices_or_sections=2, axis=-1)
            split_0117: R.Tensor((batch_size, 1, 11008), dtype="float16") = split117[0]
            split_1117: R.Tensor((batch_size, 1, 11008), dtype="float16") = split117[1]
            silu117: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0117)
            mul117: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu117, split_1117)
            lv594 = R.call_tir(cls.dequantize4, (model_layers_9_mlp_down_proj_q_weight4, model_layers_9_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv39 = R.call_tir(cls.NT_matmul3, (mul117, lv594), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv38_1 = R.call_tir(cls.fuse_add_norm_decode, (lv39, lv37_1, model_layers_10_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv39_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv38_1[1]
            rms_norm239: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv38_1[0]
            lv595 = R.call_tir(cls.dequantize1, (model_layers_10_self_attn_c_attn_q_weight4, model_layers_10_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv40 = R.call_tir(cls.NT_matmul, (rms_norm239, lv595), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add354: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv40, model_layers_10_self_attn_c_attn_bias4)
            reshape472: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add354, R.shape([batch_size, 1, 20, 128]))
            reshape473: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape472, R.shape([batch_size, 20, 128]))
            lv596 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape473), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape474: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv596, R.shape([batch_size, 1, 16, 128]))
            reshape475: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape474, R.shape([batch_size, 1, 2048]))
            lv597 = R.call_tir(cls.dequantize2, (model_layers_10_self_attn_o_proj_q_weight4, model_layers_10_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv41 = R.call_tir(cls.NT_matmul1, (reshape475, lv597), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv40_1 = R.call_tir(cls.fuse_add_norm_decode, (lv41, lv39_1, model_layers_10_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv41_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv40_1[1]
            rms_norm240: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv40_1[0]
            lv598 = R.call_tir(cls.dequantize3, (model_layers_10_mlp_gate_up_proj_q_weight4, model_layers_10_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv42 = R.call_tir(cls.NT_matmul2, (rms_norm240, lv598), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split118: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv42, indices_or_sections=2, axis=-1)
            split_0118: R.Tensor((batch_size, 1, 11008), dtype="float16") = split118[0]
            split_1118: R.Tensor((batch_size, 1, 11008), dtype="float16") = split118[1]
            silu118: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0118)
            mul118: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu118, split_1118)
            lv599 = R.call_tir(cls.dequantize4, (model_layers_10_mlp_down_proj_q_weight4, model_layers_10_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv43 = R.call_tir(cls.NT_matmul3, (mul118, lv599), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv42_1 = R.call_tir(cls.fuse_add_norm_decode, (lv43, lv41_1, model_layers_11_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv43_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv42_1[1]
            rms_norm241: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv42_1[0]
            lv600 = R.call_tir(cls.dequantize1, (model_layers_11_self_attn_c_attn_q_weight4, model_layers_11_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv44 = R.call_tir(cls.NT_matmul, (rms_norm241, lv600), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add357: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv44, model_layers_11_self_attn_c_attn_bias4)
            reshape476: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add357, R.shape([batch_size, 1, 20, 128]))
            reshape477: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape476, R.shape([batch_size, 20, 128]))
            lv601 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape477), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape478: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv601, R.shape([batch_size, 1, 16, 128]))
            reshape479: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape478, R.shape([batch_size, 1, 2048]))
            lv602 = R.call_tir(cls.dequantize2, (model_layers_11_self_attn_o_proj_q_weight4, model_layers_11_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv45 = R.call_tir(cls.NT_matmul1, (reshape479, lv602), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv44_1 = R.call_tir(cls.fuse_add_norm_decode, (lv45, lv43_1, model_layers_11_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv45_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv44_1[1]
            rms_norm242: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv44_1[0]
            lv603 = R.call_tir(cls.dequantize3, (model_layers_11_mlp_gate_up_proj_q_weight4, model_layers_11_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv46 = R.call_tir(cls.NT_matmul2, (rms_norm242, lv603), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split119: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv46, indices_or_sections=2, axis=-1)
            split_0119: R.Tensor((batch_size, 1, 11008), dtype="float16") = split119[0]
            split_1119: R.Tensor((batch_size, 1, 11008), dtype="float16") = split119[1]
            silu119: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0119)
            mul119: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu119, split_1119)
            lv604 = R.call_tir(cls.dequantize4, (model_layers_11_mlp_down_proj_q_weight4, model_layers_11_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv47 = R.call_tir(cls.NT_matmul3, (mul119, lv604), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv46_1 = R.call_tir(cls.fuse_add_norm_decode, (lv47, lv45_1, model_layers_12_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv47_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv46_1[1]
            rms_norm243: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv46_1[0]
            lv605 = R.call_tir(cls.dequantize1, (model_layers_12_self_attn_c_attn_q_weight4, model_layers_12_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv48 = R.call_tir(cls.NT_matmul, (rms_norm243, lv605), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add360: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv48, model_layers_12_self_attn_c_attn_bias4)
            reshape480: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add360, R.shape([batch_size, 1, 20, 128]))
            reshape481: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape480, R.shape([batch_size, 20, 128]))
            lv606 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape481), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape482: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv606, R.shape([batch_size, 1, 16, 128]))
            reshape483: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape482, R.shape([batch_size, 1, 2048]))
            lv607 = R.call_tir(cls.dequantize2, (model_layers_12_self_attn_o_proj_q_weight4, model_layers_12_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv49 = R.call_tir(cls.NT_matmul1, (reshape483, lv607), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv48_1 = R.call_tir(cls.fuse_add_norm_decode, (lv49, lv47_1, model_layers_12_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv49_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv48_1[1]
            rms_norm244: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv48_1[0]
            lv608 = R.call_tir(cls.dequantize3, (model_layers_12_mlp_gate_up_proj_q_weight4, model_layers_12_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv50 = R.call_tir(cls.NT_matmul2, (rms_norm244, lv608), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split120: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv50, indices_or_sections=2, axis=-1)
            split_0120: R.Tensor((batch_size, 1, 11008), dtype="float16") = split120[0]
            split_1120: R.Tensor((batch_size, 1, 11008), dtype="float16") = split120[1]
            silu120: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0120)
            mul120: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu120, split_1120)
            lv609 = R.call_tir(cls.dequantize4, (model_layers_12_mlp_down_proj_q_weight4, model_layers_12_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv51 = R.call_tir(cls.NT_matmul3, (mul120, lv609), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv50_1 = R.call_tir(cls.fuse_add_norm_decode, (lv51, lv49_1, model_layers_13_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv51_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv50_1[1]
            rms_norm245: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv50_1[0]
            lv610 = R.call_tir(cls.dequantize1, (model_layers_13_self_attn_c_attn_q_weight4, model_layers_13_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv52 = R.call_tir(cls.NT_matmul, (rms_norm245, lv610), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add363: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv52, model_layers_13_self_attn_c_attn_bias4)
            reshape484: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add363, R.shape([batch_size, 1, 20, 128]))
            reshape485: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape484, R.shape([batch_size, 20, 128]))
            lv611 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape485), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape486: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv611, R.shape([batch_size, 1, 16, 128]))
            reshape487: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape486, R.shape([batch_size, 1, 2048]))
            lv612 = R.call_tir(cls.dequantize2, (model_layers_13_self_attn_o_proj_q_weight4, model_layers_13_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv53 = R.call_tir(cls.NT_matmul1, (reshape487, lv612), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv52_1 = R.call_tir(cls.fuse_add_norm_decode, (lv53, lv51_1, model_layers_13_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv53_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv52_1[1]
            rms_norm246: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv52_1[0]
            lv613 = R.call_tir(cls.dequantize3, (model_layers_13_mlp_gate_up_proj_q_weight4, model_layers_13_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv54 = R.call_tir(cls.NT_matmul2, (rms_norm246, lv613), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split121: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv54, indices_or_sections=2, axis=-1)
            split_0121: R.Tensor((batch_size, 1, 11008), dtype="float16") = split121[0]
            split_1121: R.Tensor((batch_size, 1, 11008), dtype="float16") = split121[1]
            silu121: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0121)
            mul121: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu121, split_1121)
            lv614 = R.call_tir(cls.dequantize4, (model_layers_13_mlp_down_proj_q_weight4, model_layers_13_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv55 = R.call_tir(cls.NT_matmul3, (mul121, lv614), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv54_1 = R.call_tir(cls.fuse_add_norm_decode, (lv55, lv53_1, model_layers_14_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv55_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv54_1[1]
            rms_norm247: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv54_1[0]
            lv615 = R.call_tir(cls.dequantize1, (model_layers_14_self_attn_c_attn_q_weight4, model_layers_14_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv56 = R.call_tir(cls.NT_matmul, (rms_norm247, lv615), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add366: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv56, model_layers_14_self_attn_c_attn_bias4)
            reshape488: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add366, R.shape([batch_size, 1, 20, 128]))
            reshape489: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape488, R.shape([batch_size, 20, 128]))
            lv616 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape489), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape490: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv616, R.shape([batch_size, 1, 16, 128]))
            reshape491: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape490, R.shape([batch_size, 1, 2048]))
            lv617 = R.call_tir(cls.dequantize2, (model_layers_14_self_attn_o_proj_q_weight4, model_layers_14_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv57 = R.call_tir(cls.NT_matmul1, (reshape491, lv617), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv56_1 = R.call_tir(cls.fuse_add_norm_decode, (lv57, lv55_1, model_layers_14_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv57_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv56_1[1]
            rms_norm248: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv56_1[0]
            lv618 = R.call_tir(cls.dequantize3, (model_layers_14_mlp_gate_up_proj_q_weight4, model_layers_14_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv58 = R.call_tir(cls.NT_matmul2, (rms_norm248, lv618), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split122: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv58, indices_or_sections=2, axis=-1)
            split_0122: R.Tensor((batch_size, 1, 11008), dtype="float16") = split122[0]
            split_1122: R.Tensor((batch_size, 1, 11008), dtype="float16") = split122[1]
            silu122: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0122)
            mul122: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu122, split_1122)
            lv619 = R.call_tir(cls.dequantize4, (model_layers_14_mlp_down_proj_q_weight4, model_layers_14_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv59 = R.call_tir(cls.NT_matmul3, (mul122, lv619), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv58_1 = R.call_tir(cls.fuse_add_norm_decode, (lv59, lv57_1, model_layers_15_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv59_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv58_1[1]
            rms_norm249: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv58_1[0]
            lv620 = R.call_tir(cls.dequantize1, (model_layers_15_self_attn_c_attn_q_weight4, model_layers_15_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv60 = R.call_tir(cls.NT_matmul, (rms_norm249, lv620), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add369: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv60, model_layers_15_self_attn_c_attn_bias4)
            reshape492: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add369, R.shape([batch_size, 1, 20, 128]))
            reshape493: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape492, R.shape([batch_size, 20, 128]))
            lv621 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape493), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape494: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv621, R.shape([batch_size, 1, 16, 128]))
            reshape495: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape494, R.shape([batch_size, 1, 2048]))
            lv622 = R.call_tir(cls.dequantize2, (model_layers_15_self_attn_o_proj_q_weight4, model_layers_15_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv61 = R.call_tir(cls.NT_matmul1, (reshape495, lv622), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv60_1 = R.call_tir(cls.fuse_add_norm_decode, (lv61, lv59_1, model_layers_15_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv61_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv60_1[1]
            rms_norm250: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv60_1[0]
            lv623 = R.call_tir(cls.dequantize3, (model_layers_15_mlp_gate_up_proj_q_weight4, model_layers_15_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv62 = R.call_tir(cls.NT_matmul2, (rms_norm250, lv623), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split123: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv62, indices_or_sections=2, axis=-1)
            split_0123: R.Tensor((batch_size, 1, 11008), dtype="float16") = split123[0]
            split_1123: R.Tensor((batch_size, 1, 11008), dtype="float16") = split123[1]
            silu123: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0123)
            mul123: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu123, split_1123)
            lv624 = R.call_tir(cls.dequantize4, (model_layers_15_mlp_down_proj_q_weight4, model_layers_15_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv63 = R.call_tir(cls.NT_matmul3, (mul123, lv624), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv62_1 = R.call_tir(cls.fuse_add_norm_decode, (lv63, lv61_1, model_layers_16_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv63_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv62_1[1]
            rms_norm251: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv62_1[0]
            lv625 = R.call_tir(cls.dequantize1, (model_layers_16_self_attn_c_attn_q_weight4, model_layers_16_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv64 = R.call_tir(cls.NT_matmul, (rms_norm251, lv625), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add372: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv64, model_layers_16_self_attn_c_attn_bias4)
            reshape496: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add372, R.shape([batch_size, 1, 20, 128]))
            reshape497: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape496, R.shape([batch_size, 20, 128]))
            lv626 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1.0)), reshape497), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape498: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv626, R.shape([batch_size, 1, 16, 128]))
            reshape499: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape498, R.shape([batch_size, 1, 2048]))
            lv627 = R.call_tir(cls.dequantize2, (model_layers_16_self_attn_o_proj_q_weight4, model_layers_16_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv65 = R.call_tir(cls.NT_matmul1, (reshape499, lv627), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv64_1 = R.call_tir(cls.fuse_add_norm_decode, (lv65, lv63_1, model_layers_16_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv65_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv64_1[1]
            rms_norm252: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv64_1[0]
            lv628 = R.call_tir(cls.dequantize3, (model_layers_16_mlp_gate_up_proj_q_weight4, model_layers_16_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv66 = R.call_tir(cls.NT_matmul2, (rms_norm252, lv628), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split124: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv66, indices_or_sections=2, axis=-1)
            split_0124: R.Tensor((batch_size, 1, 11008), dtype="float16") = split124[0]
            split_1124: R.Tensor((batch_size, 1, 11008), dtype="float16") = split124[1]
            silu124: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0124)
            mul124: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu124, split_1124)
            lv629 = R.call_tir(cls.dequantize4, (model_layers_16_mlp_down_proj_q_weight4, model_layers_16_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv67 = R.call_tir(cls.NT_matmul3, (mul124, lv629), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv66_1 = R.call_tir(cls.fuse_add_norm_decode, (lv67, lv65_1, model_layers_17_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv67_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv66_1[1]
            rms_norm253: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv66_1[0]
            lv630 = R.call_tir(cls.dequantize1, (model_layers_17_self_attn_c_attn_q_weight4, model_layers_17_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv68 = R.call_tir(cls.NT_matmul, (rms_norm253, lv630), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add375: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv68, model_layers_17_self_attn_c_attn_bias4)
            reshape500: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add375, R.shape([batch_size, 1, 20, 128]))
            reshape501: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape500, R.shape([batch_size, 20, 128]))
            lv631 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1.0)), reshape501), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape502: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv631, R.shape([batch_size, 1, 16, 128]))
            reshape503: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape502, R.shape([batch_size, 1, 2048]))
            lv632 = R.call_tir(cls.dequantize2, (model_layers_17_self_attn_o_proj_q_weight4, model_layers_17_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv69 = R.call_tir(cls.NT_matmul1, (reshape503, lv632), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv68_1 = R.call_tir(cls.fuse_add_norm_decode, (lv69, lv67_1, model_layers_17_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv69_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv68_1[1]
            rms_norm254: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv68_1[0]
            lv633 = R.call_tir(cls.dequantize3, (model_layers_17_mlp_gate_up_proj_q_weight4, model_layers_17_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv70 = R.call_tir(cls.NT_matmul2, (rms_norm254, lv633), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split125: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv70, indices_or_sections=2, axis=-1)
            split_0125: R.Tensor((batch_size, 1, 11008), dtype="float16") = split125[0]
            split_1125: R.Tensor((batch_size, 1, 11008), dtype="float16") = split125[1]
            silu125: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0125)
            mul125: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu125, split_1125)
            lv634 = R.call_tir(cls.dequantize4, (model_layers_17_mlp_down_proj_q_weight4, model_layers_17_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv71 = R.call_tir(cls.NT_matmul3, (mul125, lv634), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv70_1 = R.call_tir(cls.fuse_add_norm_decode, (lv71, lv69_1, model_layers_18_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv71_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv70_1[1]
            rms_norm255: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv70_1[0]
            lv635 = R.call_tir(cls.dequantize1, (model_layers_18_self_attn_c_attn_q_weight4, model_layers_18_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv72 = R.call_tir(cls.NT_matmul, (rms_norm255, lv635), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add378: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv72, model_layers_18_self_attn_c_attn_bias4)
            reshape504: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add378, R.shape([batch_size, 1, 20, 128]))
            reshape505: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape504, R.shape([batch_size, 20, 128]))
            lv636 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1.0)), reshape505), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape506: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv636, R.shape([batch_size, 1, 16, 128]))
            reshape507: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape506, R.shape([batch_size, 1, 2048]))
            lv637 = R.call_tir(cls.dequantize2, (model_layers_18_self_attn_o_proj_q_weight4, model_layers_18_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv73 = R.call_tir(cls.NT_matmul1, (reshape507, lv637), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv72_1 = R.call_tir(cls.fuse_add_norm_decode, (lv73, lv71_1, model_layers_18_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv73_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv72_1[1]
            rms_norm256: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv72_1[0]
            lv638 = R.call_tir(cls.dequantize3, (model_layers_18_mlp_gate_up_proj_q_weight4, model_layers_18_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv74 = R.call_tir(cls.NT_matmul2, (rms_norm256, lv638), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split126: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv74, indices_or_sections=2, axis=-1)
            split_0126: R.Tensor((batch_size, 1, 11008), dtype="float16") = split126[0]
            split_1126: R.Tensor((batch_size, 1, 11008), dtype="float16") = split126[1]
            silu126: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0126)
            mul126: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu126, split_1126)
            lv639 = R.call_tir(cls.dequantize4, (model_layers_18_mlp_down_proj_q_weight4, model_layers_18_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv75 = R.call_tir(cls.NT_matmul3, (mul126, lv639), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv74_1 = R.call_tir(cls.fuse_add_norm_decode, (lv75, lv73_1, model_layers_19_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv75_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv74_1[1]
            rms_norm257: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv74_1[0]
            lv640 = R.call_tir(cls.dequantize1, (model_layers_19_self_attn_c_attn_q_weight4, model_layers_19_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv76 = R.call_tir(cls.NT_matmul, (rms_norm257, lv640), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add381: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv76, model_layers_19_self_attn_c_attn_bias4)
            reshape508: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add381, R.shape([batch_size, 1, 20, 128]))
            reshape509: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape508, R.shape([batch_size, 20, 128]))
            lv641 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1.0)), reshape509), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape510: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv641, R.shape([batch_size, 1, 16, 128]))
            reshape511: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape510, R.shape([batch_size, 1, 2048]))
            lv642 = R.call_tir(cls.dequantize2, (model_layers_19_self_attn_o_proj_q_weight4, model_layers_19_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv77 = R.call_tir(cls.NT_matmul1, (reshape511, lv642), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv76_1 = R.call_tir(cls.fuse_add_norm_decode, (lv77, lv75_1, model_layers_19_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv77_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv76_1[1]
            rms_norm258: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv76_1[0]
            lv643 = R.call_tir(cls.dequantize3, (model_layers_19_mlp_gate_up_proj_q_weight4, model_layers_19_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv78 = R.call_tir(cls.NT_matmul2, (rms_norm258, lv643), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split127: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv78, indices_or_sections=2, axis=-1)
            split_0127: R.Tensor((batch_size, 1, 11008), dtype="float16") = split127[0]
            split_1127: R.Tensor((batch_size, 1, 11008), dtype="float16") = split127[1]
            silu127: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0127)
            mul127: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu127, split_1127)
            lv644 = R.call_tir(cls.dequantize4, (model_layers_19_mlp_down_proj_q_weight4, model_layers_19_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv79 = R.call_tir(cls.NT_matmul3, (mul127, lv644), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv78_1 = R.call_tir(cls.fuse_add_norm_decode, (lv79, lv77_1, model_layers_20_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv79_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv78_1[1]
            rms_norm259: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv78_1[0]
            lv645 = R.call_tir(cls.dequantize1, (model_layers_20_self_attn_c_attn_q_weight4, model_layers_20_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv80 = R.call_tir(cls.NT_matmul, (rms_norm259, lv645), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add384: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv80, model_layers_20_self_attn_c_attn_bias4)
            reshape512: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add384, R.shape([batch_size, 1, 20, 128]))
            reshape513: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape512, R.shape([batch_size, 20, 128]))
            lv646 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1.0)), reshape513), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape514: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv646, R.shape([batch_size, 1, 16, 128]))
            reshape515: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape514, R.shape([batch_size, 1, 2048]))
            lv647 = R.call_tir(cls.dequantize2, (model_layers_20_self_attn_o_proj_q_weight4, model_layers_20_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv81 = R.call_tir(cls.NT_matmul1, (reshape515, lv647), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv80_1 = R.call_tir(cls.fuse_add_norm_decode, (lv81, lv79_1, model_layers_20_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv81_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv80_1[1]
            rms_norm260: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv80_1[0]
            lv648 = R.call_tir(cls.dequantize3, (model_layers_20_mlp_gate_up_proj_q_weight4, model_layers_20_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv82 = R.call_tir(cls.NT_matmul2, (rms_norm260, lv648), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split128: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv82, indices_or_sections=2, axis=-1)
            split_0128: R.Tensor((batch_size, 1, 11008), dtype="float16") = split128[0]
            split_1128: R.Tensor((batch_size, 1, 11008), dtype="float16") = split128[1]
            silu128: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0128)
            mul128: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu128, split_1128)
            lv649 = R.call_tir(cls.dequantize4, (model_layers_20_mlp_down_proj_q_weight4, model_layers_20_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv83 = R.call_tir(cls.NT_matmul3, (mul128, lv649), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv82_1 = R.call_tir(cls.fuse_add_norm_decode, (lv83, lv81_1, model_layers_21_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv83_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv82_1[1]
            rms_norm261: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv82_1[0]
            lv650 = R.call_tir(cls.dequantize1, (model_layers_21_self_attn_c_attn_q_weight4, model_layers_21_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv84 = R.call_tir(cls.NT_matmul, (rms_norm261, lv650), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add387: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv84, model_layers_21_self_attn_c_attn_bias4)
            reshape516: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add387, R.shape([batch_size, 1, 20, 128]))
            reshape517: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape516, R.shape([batch_size, 20, 128]))
            lv651 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1.0)), reshape517), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape518: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv651, R.shape([batch_size, 1, 16, 128]))
            reshape519: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape518, R.shape([batch_size, 1, 2048]))
            lv652 = R.call_tir(cls.dequantize2, (model_layers_21_self_attn_o_proj_q_weight4, model_layers_21_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv85 = R.call_tir(cls.NT_matmul1, (reshape519, lv652), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv84_1 = R.call_tir(cls.fuse_add_norm_decode, (lv85, lv83_1, model_layers_21_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv85_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv84_1[1]
            rms_norm262: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv84_1[0]
            lv653 = R.call_tir(cls.dequantize3, (model_layers_21_mlp_gate_up_proj_q_weight4, model_layers_21_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv86 = R.call_tir(cls.NT_matmul2, (rms_norm262, lv653), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split129: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv86, indices_or_sections=2, axis=-1)
            split_0129: R.Tensor((batch_size, 1, 11008), dtype="float16") = split129[0]
            split_1129: R.Tensor((batch_size, 1, 11008), dtype="float16") = split129[1]
            silu129: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0129)
            mul129: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu129, split_1129)
            lv654 = R.call_tir(cls.dequantize4, (model_layers_21_mlp_down_proj_q_weight4, model_layers_21_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv87 = R.call_tir(cls.NT_matmul3, (mul129, lv654), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv86_1 = R.call_tir(cls.fuse_add_norm_decode, (lv87, lv85_1, model_layers_22_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv87_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv86_1[1]
            rms_norm263: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv86_1[0]
            lv655 = R.call_tir(cls.dequantize1, (model_layers_22_self_attn_c_attn_q_weight4, model_layers_22_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv88 = R.call_tir(cls.NT_matmul, (rms_norm263, lv655), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add390: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv88, model_layers_22_self_attn_c_attn_bias4)
            reshape520: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add390, R.shape([batch_size, 1, 20, 128]))
            reshape521: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape520, R.shape([batch_size, 20, 128]))
            lv656 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1.0)), reshape521), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape522: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv656, R.shape([batch_size, 1, 16, 128]))
            reshape523: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape522, R.shape([batch_size, 1, 2048]))
            lv657 = R.call_tir(cls.dequantize2, (model_layers_22_self_attn_o_proj_q_weight4, model_layers_22_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv89 = R.call_tir(cls.NT_matmul1, (reshape523, lv657), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv88_1 = R.call_tir(cls.fuse_add_norm_decode, (lv89, lv87_1, model_layers_22_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv89_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv88_1[1]
            rms_norm264: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv88_1[0]
            lv658 = R.call_tir(cls.dequantize3, (model_layers_22_mlp_gate_up_proj_q_weight4, model_layers_22_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv90 = R.call_tir(cls.NT_matmul2, (rms_norm264, lv658), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split130: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv90, indices_or_sections=2, axis=-1)
            split_0130: R.Tensor((batch_size, 1, 11008), dtype="float16") = split130[0]
            split_1130: R.Tensor((batch_size, 1, 11008), dtype="float16") = split130[1]
            silu130: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0130)
            mul130: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu130, split_1130)
            lv659 = R.call_tir(cls.dequantize4, (model_layers_22_mlp_down_proj_q_weight4, model_layers_22_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv91 = R.call_tir(cls.NT_matmul3, (mul130, lv659), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv90_1 = R.call_tir(cls.fuse_add_norm_decode, (lv91, lv89_1, model_layers_23_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv91_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv90_1[1]
            rms_norm265: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv90_1[0]
            lv660 = R.call_tir(cls.dequantize1, (model_layers_23_self_attn_c_attn_q_weight4, model_layers_23_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv92 = R.call_tir(cls.NT_matmul, (rms_norm265, lv660), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add393: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv92, model_layers_23_self_attn_c_attn_bias4)
            reshape524: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add393, R.shape([batch_size, 1, 20, 128]))
            reshape525: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape524, R.shape([batch_size, 20, 128]))
            lv661 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1.0)), reshape525), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape526: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv661, R.shape([batch_size, 1, 16, 128]))
            reshape527: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape526, R.shape([batch_size, 1, 2048]))
            lv662 = R.call_tir(cls.dequantize2, (model_layers_23_self_attn_o_proj_q_weight4, model_layers_23_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv93 = R.call_tir(cls.NT_matmul1, (reshape527, lv662), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv92_1 = R.call_tir(cls.fuse_add_norm_decode, (lv93, lv91_1, model_layers_23_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv93_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv92_1[1]
            rms_norm266: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv92_1[0]
            lv663 = R.call_tir(cls.dequantize3, (model_layers_23_mlp_gate_up_proj_q_weight4, model_layers_23_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv94 = R.call_tir(cls.NT_matmul2, (rms_norm266, lv663), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split131: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv94, indices_or_sections=2, axis=-1)
            split_0131: R.Tensor((batch_size, 1, 11008), dtype="float16") = split131[0]
            split_1131: R.Tensor((batch_size, 1, 11008), dtype="float16") = split131[1]
            silu131: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0131)
            mul131: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu131, split_1131)
            lv664 = R.call_tir(cls.dequantize4, (model_layers_23_mlp_down_proj_q_weight4, model_layers_23_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv95 = R.call_tir(cls.NT_matmul3, (mul131, lv664), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv94_1 = R.call_tir(cls.fuse_add_norm_decode, (lv95, lv93_1, model_layers_24_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv95_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv94_1[1]
            rms_norm267: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv94_1[0]
            lv665 = R.call_tir(cls.dequantize1, (model_layers_24_self_attn_c_attn_q_weight4, model_layers_24_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv96 = R.call_tir(cls.NT_matmul, (rms_norm267, lv665), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add396: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv96, model_layers_24_self_attn_c_attn_bias4)
            reshape528: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add396, R.shape([batch_size, 1, 20, 128]))
            reshape529: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape528, R.shape([batch_size, 20, 128]))
            lv666 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1.0)), reshape529), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape530: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv666, R.shape([batch_size, 1, 16, 128]))
            reshape531: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape530, R.shape([batch_size, 1, 2048]))
            lv667 = R.call_tir(cls.dequantize2, (model_layers_24_self_attn_o_proj_q_weight4, model_layers_24_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv97 = R.call_tir(cls.NT_matmul1, (reshape531, lv667), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv96_1 = R.call_tir(cls.fuse_add_norm_decode, (lv97, lv95_1, model_layers_24_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv97_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv96_1[1]
            rms_norm268: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv96_1[0]
            lv668 = R.call_tir(cls.dequantize3, (model_layers_24_mlp_gate_up_proj_q_weight4, model_layers_24_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv98 = R.call_tir(cls.NT_matmul2, (rms_norm268, lv668), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split132: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv98, indices_or_sections=2, axis=-1)
            split_0132: R.Tensor((batch_size, 1, 11008), dtype="float16") = split132[0]
            split_1132: R.Tensor((batch_size, 1, 11008), dtype="float16") = split132[1]
            silu132: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0132)
            mul132: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu132, split_1132)
            lv669 = R.call_tir(cls.dequantize4, (model_layers_24_mlp_down_proj_q_weight4, model_layers_24_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv99 = R.call_tir(cls.NT_matmul3, (mul132, lv669), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv98_1 = R.call_tir(cls.fuse_add_norm_decode, (lv99, lv97_1, model_layers_25_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv99_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv98_1[1]
            rms_norm269: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv98_1[0]
            lv670 = R.call_tir(cls.dequantize1, (model_layers_25_self_attn_c_attn_q_weight4, model_layers_25_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv100 = R.call_tir(cls.NT_matmul, (rms_norm269, lv670), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add399: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv100, model_layers_25_self_attn_c_attn_bias4)
            reshape532: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add399, R.shape([batch_size, 1, 20, 128]))
            reshape533: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape532, R.shape([batch_size, 20, 128]))
            lv671 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1.0)), reshape533), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape534: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv671, R.shape([batch_size, 1, 16, 128]))
            reshape535: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape534, R.shape([batch_size, 1, 2048]))
            lv672 = R.call_tir(cls.dequantize2, (model_layers_25_self_attn_o_proj_q_weight4, model_layers_25_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv101 = R.call_tir(cls.NT_matmul1, (reshape535, lv672), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv100_1 = R.call_tir(cls.fuse_add_norm_decode, (lv101, lv99_1, model_layers_25_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv101_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv100_1[1]
            rms_norm270: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv100_1[0]
            lv673 = R.call_tir(cls.dequantize3, (model_layers_25_mlp_gate_up_proj_q_weight4, model_layers_25_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv102 = R.call_tir(cls.NT_matmul2, (rms_norm270, lv673), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split133: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv102, indices_or_sections=2, axis=-1)
            split_0133: R.Tensor((batch_size, 1, 11008), dtype="float16") = split133[0]
            split_1133: R.Tensor((batch_size, 1, 11008), dtype="float16") = split133[1]
            silu133: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0133)
            mul133: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu133, split_1133)
            lv674 = R.call_tir(cls.dequantize4, (model_layers_25_mlp_down_proj_q_weight4, model_layers_25_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv103 = R.call_tir(cls.NT_matmul3, (mul133, lv674), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv102_1 = R.call_tir(cls.fuse_add_norm_decode, (lv103, lv101_1, model_layers_26_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv103_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv102_1[1]
            rms_norm271: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv102_1[0]
            lv675 = R.call_tir(cls.dequantize1, (model_layers_26_self_attn_c_attn_q_weight4, model_layers_26_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv104 = R.call_tir(cls.NT_matmul, (rms_norm271, lv675), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add402: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv104, model_layers_26_self_attn_c_attn_bias4)
            reshape536: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add402, R.shape([batch_size, 1, 20, 128]))
            reshape537: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape536, R.shape([batch_size, 20, 128]))
            lv676 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1.0)), reshape537), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape538: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv676, R.shape([batch_size, 1, 16, 128]))
            reshape539: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape538, R.shape([batch_size, 1, 2048]))
            lv677 = R.call_tir(cls.dequantize2, (model_layers_26_self_attn_o_proj_q_weight4, model_layers_26_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv105 = R.call_tir(cls.NT_matmul1, (reshape539, lv677), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv104_1 = R.call_tir(cls.fuse_add_norm_decode, (lv105, lv103_1, model_layers_26_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv105_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv104_1[1]
            rms_norm272: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv104_1[0]
            lv678 = R.call_tir(cls.dequantize3, (model_layers_26_mlp_gate_up_proj_q_weight4, model_layers_26_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv106 = R.call_tir(cls.NT_matmul2, (rms_norm272, lv678), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split134: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv106, indices_or_sections=2, axis=-1)
            split_0134: R.Tensor((batch_size, 1, 11008), dtype="float16") = split134[0]
            split_1134: R.Tensor((batch_size, 1, 11008), dtype="float16") = split134[1]
            silu134: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0134)
            mul134: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu134, split_1134)
            lv679 = R.call_tir(cls.dequantize4, (model_layers_26_mlp_down_proj_q_weight4, model_layers_26_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv107 = R.call_tir(cls.NT_matmul3, (mul134, lv679), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv106_1 = R.call_tir(cls.fuse_add_norm_decode, (lv107, lv105_1, model_layers_27_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv107_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv106_1[1]
            rms_norm273: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv106_1[0]
            lv680 = R.call_tir(cls.dequantize1, (model_layers_27_self_attn_c_attn_q_weight4, model_layers_27_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv108 = R.call_tir(cls.NT_matmul, (rms_norm273, lv680), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add405: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv108, model_layers_27_self_attn_c_attn_bias4)
            reshape540: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add405, R.shape([batch_size, 1, 20, 128]))
            reshape541: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape540, R.shape([batch_size, 20, 128]))
            lv681 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1.0)), reshape541), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape542: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv681, R.shape([batch_size, 1, 16, 128]))
            reshape543: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape542, R.shape([batch_size, 1, 2048]))
            lv682 = R.call_tir(cls.dequantize2, (model_layers_27_self_attn_o_proj_q_weight4, model_layers_27_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv109 = R.call_tir(cls.NT_matmul1, (reshape543, lv682), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv108_1 = R.call_tir(cls.fuse_add_norm_decode, (lv109, lv107_1, model_layers_27_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv109_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv108_1[1]
            rms_norm274: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv108_1[0]
            lv683 = R.call_tir(cls.dequantize3, (model_layers_27_mlp_gate_up_proj_q_weight4, model_layers_27_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv110 = R.call_tir(cls.NT_matmul2, (rms_norm274, lv683), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split135: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv110, indices_or_sections=2, axis=-1)
            split_0135: R.Tensor((batch_size, 1, 11008), dtype="float16") = split135[0]
            split_1135: R.Tensor((batch_size, 1, 11008), dtype="float16") = split135[1]
            silu135: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0135)
            mul135: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu135, split_1135)
            lv684 = R.call_tir(cls.dequantize4, (model_layers_27_mlp_down_proj_q_weight4, model_layers_27_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv111 = R.call_tir(cls.NT_matmul3, (mul135, lv684), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv110_1 = R.call_tir(cls.fuse_add_norm_decode, (lv111, lv109_1, model_layers_28_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv111_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv110_1[1]
            rms_norm275: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv110_1[0]
            lv685 = R.call_tir(cls.dequantize1, (model_layers_28_self_attn_c_attn_q_weight4, model_layers_28_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv112 = R.call_tir(cls.NT_matmul, (rms_norm275, lv685), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add408: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv112, model_layers_28_self_attn_c_attn_bias4)
            reshape544: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add408, R.shape([batch_size, 1, 20, 128]))
            reshape545: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape544, R.shape([batch_size, 20, 128]))
            lv686 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1.0)), reshape545), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape546: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv686, R.shape([batch_size, 1, 16, 128]))
            reshape547: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape546, R.shape([batch_size, 1, 2048]))
            lv687 = R.call_tir(cls.dequantize2, (model_layers_28_self_attn_o_proj_q_weight4, model_layers_28_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv113 = R.call_tir(cls.NT_matmul1, (reshape547, lv687), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv112_1 = R.call_tir(cls.fuse_add_norm_decode, (lv113, lv111_1, model_layers_28_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv113_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv112_1[1]
            rms_norm276: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv112_1[0]
            lv688 = R.call_tir(cls.dequantize3, (model_layers_28_mlp_gate_up_proj_q_weight4, model_layers_28_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv114 = R.call_tir(cls.NT_matmul2, (rms_norm276, lv688), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split136: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv114, indices_or_sections=2, axis=-1)
            split_0136: R.Tensor((batch_size, 1, 11008), dtype="float16") = split136[0]
            split_1136: R.Tensor((batch_size, 1, 11008), dtype="float16") = split136[1]
            silu136: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0136)
            mul136: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu136, split_1136)
            lv689 = R.call_tir(cls.dequantize4, (model_layers_28_mlp_down_proj_q_weight4, model_layers_28_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv115 = R.call_tir(cls.NT_matmul3, (mul136, lv689), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv114_1 = R.call_tir(cls.fuse_add_norm_decode, (lv115, lv113_1, model_layers_29_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv115_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv114_1[1]
            rms_norm277: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv114_1[0]
            lv690 = R.call_tir(cls.dequantize1, (model_layers_29_self_attn_c_attn_q_weight4, model_layers_29_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv116 = R.call_tir(cls.NT_matmul, (rms_norm277, lv690), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add411: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv116, model_layers_29_self_attn_c_attn_bias4)
            reshape548: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add411, R.shape([batch_size, 1, 20, 128]))
            reshape549: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape548, R.shape([batch_size, 20, 128]))
            lv691 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1.0)), reshape549), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape550: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv691, R.shape([batch_size, 1, 16, 128]))
            reshape551: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape550, R.shape([batch_size, 1, 2048]))
            lv692 = R.call_tir(cls.dequantize2, (model_layers_29_self_attn_o_proj_q_weight4, model_layers_29_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv117 = R.call_tir(cls.NT_matmul1, (reshape551, lv692), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv116_1 = R.call_tir(cls.fuse_add_norm_decode, (lv117, lv115_1, model_layers_29_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv117_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv116_1[1]
            rms_norm278: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv116_1[0]
            lv693 = R.call_tir(cls.dequantize3, (model_layers_29_mlp_gate_up_proj_q_weight4, model_layers_29_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv118 = R.call_tir(cls.NT_matmul2, (rms_norm278, lv693), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split137: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv118, indices_or_sections=2, axis=-1)
            split_0137: R.Tensor((batch_size, 1, 11008), dtype="float16") = split137[0]
            split_1137: R.Tensor((batch_size, 1, 11008), dtype="float16") = split137[1]
            silu137: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0137)
            mul137: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu137, split_1137)
            lv694 = R.call_tir(cls.dequantize4, (model_layers_29_mlp_down_proj_q_weight4, model_layers_29_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv119 = R.call_tir(cls.NT_matmul3, (mul137, lv694), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv118_1 = R.call_tir(cls.fuse_add_norm_decode, (lv119, lv117_1, model_layers_30_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv119_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv118_1[1]
            rms_norm279: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv118_1[0]
            lv695 = R.call_tir(cls.dequantize1, (model_layers_30_self_attn_c_attn_q_weight4, model_layers_30_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv120 = R.call_tir(cls.NT_matmul, (rms_norm279, lv695), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add414: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv120, model_layers_30_self_attn_c_attn_bias4)
            reshape552: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add414, R.shape([batch_size, 1, 20, 128]))
            reshape553: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape552, R.shape([batch_size, 20, 128]))
            lv696 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1.0)), reshape553), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape554: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv696, R.shape([batch_size, 1, 16, 128]))
            reshape555: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape554, R.shape([batch_size, 1, 2048]))
            lv697 = R.call_tir(cls.dequantize2, (model_layers_30_self_attn_o_proj_q_weight4, model_layers_30_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv121 = R.call_tir(cls.NT_matmul1, (reshape555, lv697), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv120_1 = R.call_tir(cls.fuse_add_norm_decode, (lv121, lv119_1, model_layers_30_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv121_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv120_1[1]
            rms_norm280: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv120_1[0]
            lv698 = R.call_tir(cls.dequantize3, (model_layers_30_mlp_gate_up_proj_q_weight4, model_layers_30_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv122 = R.call_tir(cls.NT_matmul2, (rms_norm280, lv698), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split138: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv122, indices_or_sections=2, axis=-1)
            split_0138: R.Tensor((batch_size, 1, 11008), dtype="float16") = split138[0]
            split_1138: R.Tensor((batch_size, 1, 11008), dtype="float16") = split138[1]
            silu138: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0138)
            mul138: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu138, split_1138)
            lv699 = R.call_tir(cls.dequantize4, (model_layers_30_mlp_down_proj_q_weight4, model_layers_30_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv123 = R.call_tir(cls.NT_matmul3, (mul138, lv699), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv122_1 = R.call_tir(cls.fuse_add_norm_decode, (lv123, lv121_1, model_layers_31_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv123_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv122_1[1]
            rms_norm281: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv122_1[0]
            lv700 = R.call_tir(cls.dequantize1, (model_layers_31_self_attn_c_attn_q_weight4, model_layers_31_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv124 = R.call_tir(cls.NT_matmul, (rms_norm281, lv700), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add417: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv124, model_layers_31_self_attn_c_attn_bias4)
            reshape556: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add417, R.shape([batch_size, 1, 20, 128]))
            reshape557: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape556, R.shape([batch_size, 20, 128]))
            lv701 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1.0)), reshape557), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape558: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv701, R.shape([batch_size, 1, 16, 128]))
            reshape559: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape558, R.shape([batch_size, 1, 2048]))
            lv702 = R.call_tir(cls.dequantize2, (model_layers_31_self_attn_o_proj_q_weight4, model_layers_31_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv125 = R.call_tir(cls.NT_matmul1, (reshape559, lv702), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv124_1 = R.call_tir(cls.fuse_add_norm_decode, (lv125, lv123_1, model_layers_31_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv125_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv124_1[1]
            rms_norm282: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv124_1[0]
            lv703 = R.call_tir(cls.dequantize3, (model_layers_31_mlp_gate_up_proj_q_weight4, model_layers_31_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv126 = R.call_tir(cls.NT_matmul2, (rms_norm282, lv703), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split139: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv126, indices_or_sections=2, axis=-1)
            split_0139: R.Tensor((batch_size, 1, 11008), dtype="float16") = split139[0]
            split_1139: R.Tensor((batch_size, 1, 11008), dtype="float16") = split139[1]
            silu139: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0139)
            mul139: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu139, split_1139)
            lv704 = R.call_tir(cls.dequantize4, (model_layers_31_mlp_down_proj_q_weight4, model_layers_31_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv127 = R.call_tir(cls.NT_matmul3, (mul139, lv704), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv126_1 = R.call_tir(cls.fuse_add_norm_decode, (lv127, lv125_1, model_layers_32_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv127_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv126_1[1]
            rms_norm283: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv126_1[0]
            lv705 = R.call_tir(cls.dequantize1, (model_layers_32_self_attn_c_attn_q_weight4, model_layers_32_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv128 = R.call_tir(cls.NT_matmul, (rms_norm283, lv705), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add420: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv128, model_layers_32_self_attn_c_attn_bias4)
            reshape560: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add420, R.shape([batch_size, 1, 20, 128]))
            reshape561: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape560, R.shape([batch_size, 20, 128]))
            lv706 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(32), R.prim_value(T.float32(1.0)), reshape561), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape562: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv706, R.shape([batch_size, 1, 16, 128]))
            reshape563: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape562, R.shape([batch_size, 1, 2048]))
            lv707 = R.call_tir(cls.dequantize2, (model_layers_32_self_attn_o_proj_q_weight4, model_layers_32_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv129 = R.call_tir(cls.NT_matmul1, (reshape563, lv707), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv128_1 = R.call_tir(cls.fuse_add_norm_decode, (lv129, lv127_1, model_layers_32_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv129_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv128_1[1]
            rms_norm284: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv128_1[0]
            lv708 = R.call_tir(cls.dequantize3, (model_layers_32_mlp_gate_up_proj_q_weight4, model_layers_32_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv130 = R.call_tir(cls.NT_matmul2, (rms_norm284, lv708), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split140: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv130, indices_or_sections=2, axis=-1)
            split_0140: R.Tensor((batch_size, 1, 11008), dtype="float16") = split140[0]
            split_1140: R.Tensor((batch_size, 1, 11008), dtype="float16") = split140[1]
            silu140: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0140)
            mul140: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu140, split_1140)
            lv709 = R.call_tir(cls.dequantize4, (model_layers_32_mlp_down_proj_q_weight4, model_layers_32_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv131 = R.call_tir(cls.NT_matmul3, (mul140, lv709), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv130_1 = R.call_tir(cls.fuse_add_norm_decode, (lv131, lv129_1, model_layers_33_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv131_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv130_1[1]
            rms_norm285: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv130_1[0]
            lv710 = R.call_tir(cls.dequantize1, (model_layers_33_self_attn_c_attn_q_weight4, model_layers_33_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv132 = R.call_tir(cls.NT_matmul, (rms_norm285, lv710), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add423: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv132, model_layers_33_self_attn_c_attn_bias4)
            reshape564: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add423, R.shape([batch_size, 1, 20, 128]))
            reshape565: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape564, R.shape([batch_size, 20, 128]))
            lv711 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(33), R.prim_value(T.float32(1.0)), reshape565), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape566: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv711, R.shape([batch_size, 1, 16, 128]))
            reshape567: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape566, R.shape([batch_size, 1, 2048]))
            lv712 = R.call_tir(cls.dequantize2, (model_layers_33_self_attn_o_proj_q_weight4, model_layers_33_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv133 = R.call_tir(cls.NT_matmul1, (reshape567, lv712), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv132_1 = R.call_tir(cls.fuse_add_norm_decode, (lv133, lv131_1, model_layers_33_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv133_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv132_1[1]
            rms_norm286: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv132_1[0]
            lv713 = R.call_tir(cls.dequantize3, (model_layers_33_mlp_gate_up_proj_q_weight4, model_layers_33_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv134 = R.call_tir(cls.NT_matmul2, (rms_norm286, lv713), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split141: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv134, indices_or_sections=2, axis=-1)
            split_0141: R.Tensor((batch_size, 1, 11008), dtype="float16") = split141[0]
            split_1141: R.Tensor((batch_size, 1, 11008), dtype="float16") = split141[1]
            silu141: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0141)
            mul141: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu141, split_1141)
            lv714 = R.call_tir(cls.dequantize4, (model_layers_33_mlp_down_proj_q_weight4, model_layers_33_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv135 = R.call_tir(cls.NT_matmul3, (mul141, lv714), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv134_1 = R.call_tir(cls.fuse_add_norm_decode, (lv135, lv133_1, model_layers_34_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv135_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv134_1[1]
            rms_norm287: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv134_1[0]
            lv715 = R.call_tir(cls.dequantize1, (model_layers_34_self_attn_c_attn_q_weight4, model_layers_34_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv136 = R.call_tir(cls.NT_matmul, (rms_norm287, lv715), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add426: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv136, model_layers_34_self_attn_c_attn_bias4)
            reshape568: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add426, R.shape([batch_size, 1, 20, 128]))
            reshape569: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape568, R.shape([batch_size, 20, 128]))
            lv716 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(34), R.prim_value(T.float32(1.0)), reshape569), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape570: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv716, R.shape([batch_size, 1, 16, 128]))
            reshape571: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape570, R.shape([batch_size, 1, 2048]))
            lv717 = R.call_tir(cls.dequantize2, (model_layers_34_self_attn_o_proj_q_weight4, model_layers_34_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv137 = R.call_tir(cls.NT_matmul1, (reshape571, lv717), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv136_1 = R.call_tir(cls.fuse_add_norm_decode, (lv137, lv135_1, model_layers_34_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv137_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv136_1[1]
            rms_norm288: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv136_1[0]
            lv718 = R.call_tir(cls.dequantize3, (model_layers_34_mlp_gate_up_proj_q_weight4, model_layers_34_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv138 = R.call_tir(cls.NT_matmul2, (rms_norm288, lv718), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split142: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv138, indices_or_sections=2, axis=-1)
            split_0142: R.Tensor((batch_size, 1, 11008), dtype="float16") = split142[0]
            split_1142: R.Tensor((batch_size, 1, 11008), dtype="float16") = split142[1]
            silu142: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0142)
            mul142: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu142, split_1142)
            lv719 = R.call_tir(cls.dequantize4, (model_layers_34_mlp_down_proj_q_weight4, model_layers_34_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv139 = R.call_tir(cls.NT_matmul3, (mul142, lv719), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv138_1 = R.call_tir(cls.fuse_add_norm_decode, (lv139, lv137_1, model_layers_35_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv139_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv138_1[1]
            rms_norm289: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv138_1[0]
            lv720 = R.call_tir(cls.dequantize1, (model_layers_35_self_attn_c_attn_q_weight4, model_layers_35_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv140 = R.call_tir(cls.NT_matmul, (rms_norm289, lv720), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            add429: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(lv140, model_layers_35_self_attn_c_attn_bias4)
            reshape572: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add429, R.shape([batch_size, 1, 20, 128]))
            reshape573: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape572, R.shape([batch_size, 20, 128]))
            lv721 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(35), R.prim_value(T.float32(1.0)), reshape573), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape574: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv721, R.shape([batch_size, 1, 16, 128]))
            reshape575: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape574, R.shape([batch_size, 1, 2048]))
            lv722 = R.call_tir(cls.dequantize2, (model_layers_35_self_attn_o_proj_q_weight4, model_layers_35_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv141 = R.call_tir(cls.NT_matmul1, (reshape575, lv722), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv140_1 = R.call_tir(cls.fuse_add_norm_decode, (lv141, lv139_1, model_layers_35_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv141_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv140_1[1]
            rms_norm290: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv140_1[0]
            lv723 = R.call_tir(cls.dequantize3, (model_layers_35_mlp_gate_up_proj_q_weight4, model_layers_35_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv142 = R.call_tir(cls.NT_matmul2, (rms_norm290, lv723), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            split143: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(lv142, indices_or_sections=2, axis=-1)
            split_0143: R.Tensor((batch_size, 1, 11008), dtype="float16") = split143[0]
            split_1143: R.Tensor((batch_size, 1, 11008), dtype="float16") = split143[1]
            silu143: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0143)
            mul143: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu143, split_1143)
            lv724 = R.call_tir(cls.dequantize4, (model_layers_35_mlp_down_proj_q_weight4, model_layers_35_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv143 = R.call_tir(cls.NT_matmul3, (mul143, lv724), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv142_1 = R.call_tir(cls.fuse_add_norm_decode, (lv143, lv141_1, model_norm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            rms_norm291: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv142_1[0]
            lv725 = R.call_tir(cls.dequantize, (model_embed_tokens_q_weight4, model_embed_tokens_q_scale4), out_sinfo=R.Tensor((151936, 2048), dtype="float16"))
            lv144 = R.call_tir(cls.NT_matmul4, (rms_norm291, lv725), out_sinfo=R.Tensor((batch_size, 1, 151936), dtype="float32"))
            gv4: R.Tuple(R.Tensor((batch_size, 1, 151936), dtype="float32"), R.Object) = lv144, paged_kv_cache
            R.output(gv4)
        return gv4

    @R.function
    def batch_prefill(input_embeds: R.Tensor((1, "seq_len", 2048), dtype="float16"), logit_positions: R.Tensor(("batch_size",), dtype="int32"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), 
dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), 
R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), 
R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), 
dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), 
R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), 
R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), 
dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"))) -> R.Tuple(R.Tensor((1, "batch_size", 151936), dtype="float32"), R.Object):
        batch_size = T.int64()
        seq_len = T.int64()
        R.func_attr({"num_input": 3, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 32768}})
        cls = Module
        with R.dataflow():
            model_embed_tokens_q_weight3: R.Tensor((151936, 256), dtype="uint32") = packed_params[0]
            model_embed_tokens_q_scale3: R.Tensor((151936, 64), dtype="float16") = packed_params[1]
            model_layers_0_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[2]
            model_layers_0_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[3]
            model_layers_0_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[4]
            model_layers_0_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[5]
            model_layers_0_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[6]
            model_layers_0_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[7]
            model_layers_0_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[8]
            model_layers_0_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[9]
            model_layers_0_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[10]
            model_layers_0_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[11]
            model_layers_0_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[12]
            model_layers_1_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[13]
            model_layers_1_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[14]
            model_layers_1_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[15]
            model_layers_1_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[16]
            model_layers_1_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[17]
            model_layers_1_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[18]
            model_layers_1_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[19]
            model_layers_1_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[20]
            model_layers_1_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[21]
            model_layers_1_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[22]
            model_layers_1_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[23]
            model_layers_2_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[24]
            model_layers_2_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[25]
            model_layers_2_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[26]
            model_layers_2_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[27]
            model_layers_2_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[28]
            model_layers_2_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[29]
            model_layers_2_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[30]
            model_layers_2_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[31]
            model_layers_2_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[32]
            model_layers_2_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[33]
            model_layers_2_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[34]
            model_layers_3_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[35]
            model_layers_3_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[36]
            model_layers_3_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[37]
            model_layers_3_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[38]
            model_layers_3_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[39]
            model_layers_3_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[40]
            model_layers_3_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[41]
            model_layers_3_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[42]
            model_layers_3_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[43]
            model_layers_3_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[44]
            model_layers_3_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[45]
            model_layers_4_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[46]
            model_layers_4_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[47]
            model_layers_4_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[48]
            model_layers_4_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[49]
            model_layers_4_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[50]
            model_layers_4_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[51]
            model_layers_4_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[52]
            model_layers_4_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[53]
            model_layers_4_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[54]
            model_layers_4_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[55]
            model_layers_4_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[56]
            model_layers_5_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[57]
            model_layers_5_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[58]
            model_layers_5_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[59]
            model_layers_5_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[60]
            model_layers_5_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[61]
            model_layers_5_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[62]
            model_layers_5_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[63]
            model_layers_5_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[64]
            model_layers_5_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[65]
            model_layers_5_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[66]
            model_layers_5_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[67]
            model_layers_6_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[68]
            model_layers_6_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[69]
            model_layers_6_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[70]
            model_layers_6_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[71]
            model_layers_6_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[72]
            model_layers_6_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[73]
            model_layers_6_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[74]
            model_layers_6_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[75]
            model_layers_6_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[76]
            model_layers_6_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[77]
            model_layers_6_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[78]
            model_layers_7_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[79]
            model_layers_7_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[80]
            model_layers_7_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[81]
            model_layers_7_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[82]
            model_layers_7_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[83]
            model_layers_7_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[84]
            model_layers_7_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[85]
            model_layers_7_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[86]
            model_layers_7_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[87]
            model_layers_7_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[88]
            model_layers_7_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[89]
            model_layers_8_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[90]
            model_layers_8_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[91]
            model_layers_8_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[92]
            model_layers_8_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[93]
            model_layers_8_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[94]
            model_layers_8_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[95]
            model_layers_8_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[96]
            model_layers_8_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[97]
            model_layers_8_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[98]
            model_layers_8_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[99]
            model_layers_8_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[100]
            model_layers_9_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[101]
            model_layers_9_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[102]
            model_layers_9_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[103]
            model_layers_9_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[104]
            model_layers_9_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[105]
            model_layers_9_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[106]
            model_layers_9_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[107]
            model_layers_9_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[108]
            model_layers_9_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[109]
            model_layers_9_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[110]
            model_layers_9_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[111]
            model_layers_10_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[112]
            model_layers_10_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[113]
            model_layers_10_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[114]
            model_layers_10_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[115]
            model_layers_10_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[116]
            model_layers_10_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[117]
            model_layers_10_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[118]
            model_layers_10_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[119]
            model_layers_10_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[120]
            model_layers_10_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[121]
            model_layers_10_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[122]
            model_layers_11_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[123]
            model_layers_11_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[124]
            model_layers_11_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[125]
            model_layers_11_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[126]
            model_layers_11_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[127]
            model_layers_11_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[128]
            model_layers_11_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[129]
            model_layers_11_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[130]
            model_layers_11_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[131]
            model_layers_11_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[132]
            model_layers_11_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[133]
            model_layers_12_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[134]
            model_layers_12_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[135]
            model_layers_12_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[136]
            model_layers_12_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[137]
            model_layers_12_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[138]
            model_layers_12_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[139]
            model_layers_12_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[140]
            model_layers_12_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[141]
            model_layers_12_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[142]
            model_layers_12_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[143]
            model_layers_12_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[144]
            model_layers_13_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[145]
            model_layers_13_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[146]
            model_layers_13_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[147]
            model_layers_13_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[148]
            model_layers_13_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[149]
            model_layers_13_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[150]
            model_layers_13_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[151]
            model_layers_13_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[152]
            model_layers_13_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[153]
            model_layers_13_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[154]
            model_layers_13_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[155]
            model_layers_14_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[156]
            model_layers_14_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[157]
            model_layers_14_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[158]
            model_layers_14_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[159]
            model_layers_14_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[160]
            model_layers_14_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[161]
            model_layers_14_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[162]
            model_layers_14_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[163]
            model_layers_14_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[164]
            model_layers_14_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[165]
            model_layers_14_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[166]
            model_layers_15_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[167]
            model_layers_15_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[168]
            model_layers_15_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[169]
            model_layers_15_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[170]
            model_layers_15_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[171]
            model_layers_15_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[172]
            model_layers_15_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[173]
            model_layers_15_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[174]
            model_layers_15_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[175]
            model_layers_15_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[176]
            model_layers_15_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[177]
            model_layers_16_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[178]
            model_layers_16_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[179]
            model_layers_16_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[180]
            model_layers_16_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[181]
            model_layers_16_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[182]
            model_layers_16_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[183]
            model_layers_16_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[184]
            model_layers_16_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[185]
            model_layers_16_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[186]
            model_layers_16_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[187]
            model_layers_16_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[188]
            model_layers_17_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[189]
            model_layers_17_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[190]
            model_layers_17_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[191]
            model_layers_17_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[192]
            model_layers_17_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[193]
            model_layers_17_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[194]
            model_layers_17_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[195]
            model_layers_17_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[196]
            model_layers_17_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[197]
            model_layers_17_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[198]
            model_layers_17_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[199]
            model_layers_18_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[200]
            model_layers_18_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[201]
            model_layers_18_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[202]
            model_layers_18_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[203]
            model_layers_18_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[204]
            model_layers_18_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[205]
            model_layers_18_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[206]
            model_layers_18_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[207]
            model_layers_18_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[208]
            model_layers_18_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[209]
            model_layers_18_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[210]
            model_layers_19_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[211]
            model_layers_19_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[212]
            model_layers_19_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[213]
            model_layers_19_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[214]
            model_layers_19_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[215]
            model_layers_19_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[216]
            model_layers_19_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[217]
            model_layers_19_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[218]
            model_layers_19_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[219]
            model_layers_19_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[220]
            model_layers_19_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[221]
            model_layers_20_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[222]
            model_layers_20_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[223]
            model_layers_20_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[224]
            model_layers_20_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[225]
            model_layers_20_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[226]
            model_layers_20_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[227]
            model_layers_20_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[228]
            model_layers_20_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[229]
            model_layers_20_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[230]
            model_layers_20_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[231]
            model_layers_20_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[232]
            model_layers_21_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[233]
            model_layers_21_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[234]
            model_layers_21_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[235]
            model_layers_21_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[236]
            model_layers_21_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[237]
            model_layers_21_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[238]
            model_layers_21_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[239]
            model_layers_21_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[240]
            model_layers_21_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[241]
            model_layers_21_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[242]
            model_layers_21_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[243]
            model_layers_22_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[244]
            model_layers_22_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[245]
            model_layers_22_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[246]
            model_layers_22_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[247]
            model_layers_22_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[248]
            model_layers_22_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[249]
            model_layers_22_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[250]
            model_layers_22_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[251]
            model_layers_22_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[252]
            model_layers_22_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[253]
            model_layers_22_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[254]
            model_layers_23_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[255]
            model_layers_23_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[256]
            model_layers_23_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[257]
            model_layers_23_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[258]
            model_layers_23_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[259]
            model_layers_23_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[260]
            model_layers_23_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[261]
            model_layers_23_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[262]
            model_layers_23_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[263]
            model_layers_23_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[264]
            model_layers_23_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[265]
            model_layers_24_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[266]
            model_layers_24_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[267]
            model_layers_24_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[268]
            model_layers_24_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[269]
            model_layers_24_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[270]
            model_layers_24_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[271]
            model_layers_24_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[272]
            model_layers_24_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[273]
            model_layers_24_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[274]
            model_layers_24_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[275]
            model_layers_24_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[276]
            model_layers_25_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[277]
            model_layers_25_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[278]
            model_layers_25_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[279]
            model_layers_25_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[280]
            model_layers_25_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[281]
            model_layers_25_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[282]
            model_layers_25_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[283]
            model_layers_25_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[284]
            model_layers_25_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[285]
            model_layers_25_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[286]
            model_layers_25_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[287]
            model_layers_26_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[288]
            model_layers_26_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[289]
            model_layers_26_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[290]
            model_layers_26_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[291]
            model_layers_26_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[292]
            model_layers_26_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[293]
            model_layers_26_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[294]
            model_layers_26_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[295]
            model_layers_26_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[296]
            model_layers_26_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[297]
            model_layers_26_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[298]
            model_layers_27_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[299]
            model_layers_27_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[300]
            model_layers_27_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[301]
            model_layers_27_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[302]
            model_layers_27_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[303]
            model_layers_27_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[304]
            model_layers_27_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[305]
            model_layers_27_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[306]
            model_layers_27_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[307]
            model_layers_27_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[308]
            model_layers_27_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[309]
            model_layers_28_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[310]
            model_layers_28_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[311]
            model_layers_28_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[312]
            model_layers_28_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[313]
            model_layers_28_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[314]
            model_layers_28_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[315]
            model_layers_28_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[316]
            model_layers_28_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[317]
            model_layers_28_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[318]
            model_layers_28_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[319]
            model_layers_28_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[320]
            model_layers_29_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[321]
            model_layers_29_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[322]
            model_layers_29_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[323]
            model_layers_29_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[324]
            model_layers_29_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[325]
            model_layers_29_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[326]
            model_layers_29_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[327]
            model_layers_29_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[328]
            model_layers_29_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[329]
            model_layers_29_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[330]
            model_layers_29_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[331]
            model_layers_30_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[332]
            model_layers_30_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[333]
            model_layers_30_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[334]
            model_layers_30_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[335]
            model_layers_30_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[336]
            model_layers_30_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[337]
            model_layers_30_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[338]
            model_layers_30_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[339]
            model_layers_30_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[340]
            model_layers_30_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[341]
            model_layers_30_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[342]
            model_layers_31_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[343]
            model_layers_31_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[344]
            model_layers_31_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[345]
            model_layers_31_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[346]
            model_layers_31_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[347]
            model_layers_31_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[348]
            model_layers_31_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[349]
            model_layers_31_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[350]
            model_layers_31_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[351]
            model_layers_31_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[352]
            model_layers_31_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[353]
            model_layers_32_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[354]
            model_layers_32_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[355]
            model_layers_32_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[356]
            model_layers_32_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[357]
            model_layers_32_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[358]
            model_layers_32_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[359]
            model_layers_32_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[360]
            model_layers_32_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[361]
            model_layers_32_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[362]
            model_layers_32_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[363]
            model_layers_32_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[364]
            model_layers_33_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[365]
            model_layers_33_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[366]
            model_layers_33_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[367]
            model_layers_33_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[368]
            model_layers_33_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[369]
            model_layers_33_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[370]
            model_layers_33_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[371]
            model_layers_33_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[372]
            model_layers_33_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[373]
            model_layers_33_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[374]
            model_layers_33_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[375]
            model_layers_34_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[376]
            model_layers_34_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[377]
            model_layers_34_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[378]
            model_layers_34_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[379]
            model_layers_34_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[380]
            model_layers_34_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[381]
            model_layers_34_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[382]
            model_layers_34_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[383]
            model_layers_34_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[384]
            model_layers_34_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[385]
            model_layers_34_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[386]
            model_layers_35_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[387]
            model_layers_35_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[388]
            model_layers_35_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[389]
            model_layers_35_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[390]
            model_layers_35_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[391]
            model_layers_35_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[392]
            model_layers_35_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[393]
            model_layers_35_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[394]
            model_layers_35_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[395]
            model_layers_35_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[396]
            model_layers_35_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[397]
            model_norm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[398]
            rms_norm146: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(input_embeds, model_layers_0_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv364 = R.call_tir(cls.dequantize1, (model_layers_0_self_attn_c_attn_q_weight3, model_layers_0_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv145 = R.call_tir(cls.NT_matmul5, (rms_norm146, lv364), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add216: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv145, model_layers_0_self_attn_c_attn_bias3)
            reshape288: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add216, R.shape([1, seq_len, 20, 128]))
            reshape289: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape288, R.shape([seq_len, 20, 128]))
            lv365 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape289), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape290: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv365, R.shape([1, seq_len, 16, 128]))
            reshape291: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape290, R.shape([1, seq_len, 2048]))
            lv366 = R.call_tir(cls.dequantize2, (model_layers_0_self_attn_o_proj_q_weight3, model_layers_0_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv146 = R.call_tir(cls.NT_matmul6, (reshape291, lv366), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv144 = R.call_tir(cls.fuse_add_norm_prefill, (lv146, input_embeds, model_layers_0_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv145_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv144[1]
            rms_norm147: R.Tensor((1, seq_len, 2048), dtype="float16") = lv144[0]
            lv367 = R.call_tir(cls.dequantize3, (model_layers_0_mlp_gate_up_proj_q_weight3, model_layers_0_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv147 = R.call_tir(cls.NT_matmul7, (rms_norm147, lv367), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split72: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv147, indices_or_sections=2, axis=-1)
            split_072: R.Tensor((1, seq_len, 11008), dtype="float16") = split72[0]
            split_172: R.Tensor((1, seq_len, 11008), dtype="float16") = split72[1]
            silu72: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_072)
            mul72: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu72, split_172)
            lv368 = R.call_tir(cls.dequantize4, (model_layers_0_mlp_down_proj_q_weight3, model_layers_0_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv148 = R.call_tir(cls.NT_matmul8, (mul72, lv368), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv146_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv148, lv145_1, model_layers_1_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv147_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv146_1[1]
            rms_norm148: R.Tensor((1, seq_len, 2048), dtype="float16") = lv146_1[0]
            lv369 = R.call_tir(cls.dequantize1, (model_layers_1_self_attn_c_attn_q_weight3, model_layers_1_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv149 = R.call_tir(cls.NT_matmul5, (rms_norm148, lv369), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add219: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv149, model_layers_1_self_attn_c_attn_bias3)
            reshape292: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add219, R.shape([1, seq_len, 20, 128]))
            reshape293: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape292, R.shape([seq_len, 20, 128]))
            lv370 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape293), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape294: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv370, R.shape([1, seq_len, 16, 128]))
            reshape295: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape294, R.shape([1, seq_len, 2048]))
            lv371 = R.call_tir(cls.dequantize2, (model_layers_1_self_attn_o_proj_q_weight3, model_layers_1_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv150 = R.call_tir(cls.NT_matmul6, (reshape295, lv371), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv148_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv150, lv147_1, model_layers_1_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv149_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv148_1[1]
            rms_norm149: R.Tensor((1, seq_len, 2048), dtype="float16") = lv148_1[0]
            lv372 = R.call_tir(cls.dequantize3, (model_layers_1_mlp_gate_up_proj_q_weight3, model_layers_1_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv151 = R.call_tir(cls.NT_matmul7, (rms_norm149, lv372), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split73: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv151, indices_or_sections=2, axis=-1)
            split_073: R.Tensor((1, seq_len, 11008), dtype="float16") = split73[0]
            split_173: R.Tensor((1, seq_len, 11008), dtype="float16") = split73[1]
            silu73: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_073)
            mul73: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu73, split_173)
            lv373 = R.call_tir(cls.dequantize4, (model_layers_1_mlp_down_proj_q_weight3, model_layers_1_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv152 = R.call_tir(cls.NT_matmul8, (mul73, lv373), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv150_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv152, lv149_1, model_layers_2_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv151_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv150_1[1]
            rms_norm150: R.Tensor((1, seq_len, 2048), dtype="float16") = lv150_1[0]
            lv374 = R.call_tir(cls.dequantize1, (model_layers_2_self_attn_c_attn_q_weight3, model_layers_2_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv153 = R.call_tir(cls.NT_matmul5, (rms_norm150, lv374), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add222: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv153, model_layers_2_self_attn_c_attn_bias3)
            reshape296: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add222, R.shape([1, seq_len, 20, 128]))
            reshape297: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape296, R.shape([seq_len, 20, 128]))
            lv375 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape297), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape298: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv375, R.shape([1, seq_len, 16, 128]))
            reshape299: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape298, R.shape([1, seq_len, 2048]))
            lv376 = R.call_tir(cls.dequantize2, (model_layers_2_self_attn_o_proj_q_weight3, model_layers_2_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv154 = R.call_tir(cls.NT_matmul6, (reshape299, lv376), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv152_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv154, lv151_1, model_layers_2_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv153_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv152_1[1]
            rms_norm151: R.Tensor((1, seq_len, 2048), dtype="float16") = lv152_1[0]
            lv377 = R.call_tir(cls.dequantize3, (model_layers_2_mlp_gate_up_proj_q_weight3, model_layers_2_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv155 = R.call_tir(cls.NT_matmul7, (rms_norm151, lv377), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split74: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv155, indices_or_sections=2, axis=-1)
            split_074: R.Tensor((1, seq_len, 11008), dtype="float16") = split74[0]
            split_174: R.Tensor((1, seq_len, 11008), dtype="float16") = split74[1]
            silu74: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_074)
            mul74: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu74, split_174)
            lv378 = R.call_tir(cls.dequantize4, (model_layers_2_mlp_down_proj_q_weight3, model_layers_2_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv156 = R.call_tir(cls.NT_matmul8, (mul74, lv378), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv154_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv156, lv153_1, model_layers_3_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv155_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv154_1[1]
            rms_norm152: R.Tensor((1, seq_len, 2048), dtype="float16") = lv154_1[0]
            lv379 = R.call_tir(cls.dequantize1, (model_layers_3_self_attn_c_attn_q_weight3, model_layers_3_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv157 = R.call_tir(cls.NT_matmul5, (rms_norm152, lv379), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add225: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv157, model_layers_3_self_attn_c_attn_bias3)
            reshape300: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add225, R.shape([1, seq_len, 20, 128]))
            reshape301: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape300, R.shape([seq_len, 20, 128]))
            lv380 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape301), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape302: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv380, R.shape([1, seq_len, 16, 128]))
            reshape303: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape302, R.shape([1, seq_len, 2048]))
            lv381 = R.call_tir(cls.dequantize2, (model_layers_3_self_attn_o_proj_q_weight3, model_layers_3_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv158 = R.call_tir(cls.NT_matmul6, (reshape303, lv381), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv156_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv158, lv155_1, model_layers_3_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv157_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv156_1[1]
            rms_norm153: R.Tensor((1, seq_len, 2048), dtype="float16") = lv156_1[0]
            lv382 = R.call_tir(cls.dequantize3, (model_layers_3_mlp_gate_up_proj_q_weight3, model_layers_3_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv159 = R.call_tir(cls.NT_matmul7, (rms_norm153, lv382), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split75: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv159, indices_or_sections=2, axis=-1)
            split_075: R.Tensor((1, seq_len, 11008), dtype="float16") = split75[0]
            split_175: R.Tensor((1, seq_len, 11008), dtype="float16") = split75[1]
            silu75: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_075)
            mul75: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu75, split_175)
            lv383 = R.call_tir(cls.dequantize4, (model_layers_3_mlp_down_proj_q_weight3, model_layers_3_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv160 = R.call_tir(cls.NT_matmul8, (mul75, lv383), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv158_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv160, lv157_1, model_layers_4_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv159_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv158_1[1]
            rms_norm154: R.Tensor((1, seq_len, 2048), dtype="float16") = lv158_1[0]
            lv384 = R.call_tir(cls.dequantize1, (model_layers_4_self_attn_c_attn_q_weight3, model_layers_4_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv161 = R.call_tir(cls.NT_matmul5, (rms_norm154, lv384), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add228: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv161, model_layers_4_self_attn_c_attn_bias3)
            reshape304: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add228, R.shape([1, seq_len, 20, 128]))
            reshape305: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape304, R.shape([seq_len, 20, 128]))
            lv385 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape305), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape306: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv385, R.shape([1, seq_len, 16, 128]))
            reshape307: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape306, R.shape([1, seq_len, 2048]))
            lv386 = R.call_tir(cls.dequantize2, (model_layers_4_self_attn_o_proj_q_weight3, model_layers_4_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv162 = R.call_tir(cls.NT_matmul6, (reshape307, lv386), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv160_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv162, lv159_1, model_layers_4_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv161_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv160_1[1]
            rms_norm155: R.Tensor((1, seq_len, 2048), dtype="float16") = lv160_1[0]
            lv387 = R.call_tir(cls.dequantize3, (model_layers_4_mlp_gate_up_proj_q_weight3, model_layers_4_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv163 = R.call_tir(cls.NT_matmul7, (rms_norm155, lv387), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split76: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv163, indices_or_sections=2, axis=-1)
            split_076: R.Tensor((1, seq_len, 11008), dtype="float16") = split76[0]
            split_176: R.Tensor((1, seq_len, 11008), dtype="float16") = split76[1]
            silu76: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_076)
            mul76: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu76, split_176)
            lv388 = R.call_tir(cls.dequantize4, (model_layers_4_mlp_down_proj_q_weight3, model_layers_4_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv164 = R.call_tir(cls.NT_matmul8, (mul76, lv388), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv162_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv164, lv161_1, model_layers_5_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv163_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv162_1[1]
            rms_norm156: R.Tensor((1, seq_len, 2048), dtype="float16") = lv162_1[0]
            lv389 = R.call_tir(cls.dequantize1, (model_layers_5_self_attn_c_attn_q_weight3, model_layers_5_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv165 = R.call_tir(cls.NT_matmul5, (rms_norm156, lv389), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add231: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv165, model_layers_5_self_attn_c_attn_bias3)
            reshape308: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add231, R.shape([1, seq_len, 20, 128]))
            reshape309: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape308, R.shape([seq_len, 20, 128]))
            lv390 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape309), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape310: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv390, R.shape([1, seq_len, 16, 128]))
            reshape311: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape310, R.shape([1, seq_len, 2048]))
            lv391 = R.call_tir(cls.dequantize2, (model_layers_5_self_attn_o_proj_q_weight3, model_layers_5_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv166 = R.call_tir(cls.NT_matmul6, (reshape311, lv391), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv164_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv166, lv163_1, model_layers_5_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv165_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv164_1[1]
            rms_norm157: R.Tensor((1, seq_len, 2048), dtype="float16") = lv164_1[0]
            lv392 = R.call_tir(cls.dequantize3, (model_layers_5_mlp_gate_up_proj_q_weight3, model_layers_5_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv167 = R.call_tir(cls.NT_matmul7, (rms_norm157, lv392), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split77: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv167, indices_or_sections=2, axis=-1)
            split_077: R.Tensor((1, seq_len, 11008), dtype="float16") = split77[0]
            split_177: R.Tensor((1, seq_len, 11008), dtype="float16") = split77[1]
            silu77: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_077)
            mul77: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu77, split_177)
            lv393 = R.call_tir(cls.dequantize4, (model_layers_5_mlp_down_proj_q_weight3, model_layers_5_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv168 = R.call_tir(cls.NT_matmul8, (mul77, lv393), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv166_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv168, lv165_1, model_layers_6_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv167_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv166_1[1]
            rms_norm158: R.Tensor((1, seq_len, 2048), dtype="float16") = lv166_1[0]
            lv394 = R.call_tir(cls.dequantize1, (model_layers_6_self_attn_c_attn_q_weight3, model_layers_6_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv169 = R.call_tir(cls.NT_matmul5, (rms_norm158, lv394), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add234: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv169, model_layers_6_self_attn_c_attn_bias3)
            reshape312: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add234, R.shape([1, seq_len, 20, 128]))
            reshape313: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape312, R.shape([seq_len, 20, 128]))
            lv395 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape313), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape314: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv395, R.shape([1, seq_len, 16, 128]))
            reshape315: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape314, R.shape([1, seq_len, 2048]))
            lv396 = R.call_tir(cls.dequantize2, (model_layers_6_self_attn_o_proj_q_weight3, model_layers_6_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv170 = R.call_tir(cls.NT_matmul6, (reshape315, lv396), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv168_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv170, lv167_1, model_layers_6_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv169_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv168_1[1]
            rms_norm159: R.Tensor((1, seq_len, 2048), dtype="float16") = lv168_1[0]
            lv397 = R.call_tir(cls.dequantize3, (model_layers_6_mlp_gate_up_proj_q_weight3, model_layers_6_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv171 = R.call_tir(cls.NT_matmul7, (rms_norm159, lv397), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split78: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv171, indices_or_sections=2, axis=-1)
            split_078: R.Tensor((1, seq_len, 11008), dtype="float16") = split78[0]
            split_178: R.Tensor((1, seq_len, 11008), dtype="float16") = split78[1]
            silu78: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_078)
            mul78: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu78, split_178)
            lv398 = R.call_tir(cls.dequantize4, (model_layers_6_mlp_down_proj_q_weight3, model_layers_6_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv172 = R.call_tir(cls.NT_matmul8, (mul78, lv398), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv170_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv172, lv169_1, model_layers_7_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv171_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv170_1[1]
            rms_norm160: R.Tensor((1, seq_len, 2048), dtype="float16") = lv170_1[0]
            lv399 = R.call_tir(cls.dequantize1, (model_layers_7_self_attn_c_attn_q_weight3, model_layers_7_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv173 = R.call_tir(cls.NT_matmul5, (rms_norm160, lv399), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add237: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv173, model_layers_7_self_attn_c_attn_bias3)
            reshape316: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add237, R.shape([1, seq_len, 20, 128]))
            reshape317: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape316, R.shape([seq_len, 20, 128]))
            lv400 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape317), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape318: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv400, R.shape([1, seq_len, 16, 128]))
            reshape319: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape318, R.shape([1, seq_len, 2048]))
            lv401 = R.call_tir(cls.dequantize2, (model_layers_7_self_attn_o_proj_q_weight3, model_layers_7_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv174 = R.call_tir(cls.NT_matmul6, (reshape319, lv401), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv172_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv174, lv171_1, model_layers_7_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv173_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv172_1[1]
            rms_norm161: R.Tensor((1, seq_len, 2048), dtype="float16") = lv172_1[0]
            lv402 = R.call_tir(cls.dequantize3, (model_layers_7_mlp_gate_up_proj_q_weight3, model_layers_7_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv175 = R.call_tir(cls.NT_matmul7, (rms_norm161, lv402), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split79: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv175, indices_or_sections=2, axis=-1)
            split_079: R.Tensor((1, seq_len, 11008), dtype="float16") = split79[0]
            split_179: R.Tensor((1, seq_len, 11008), dtype="float16") = split79[1]
            silu79: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_079)
            mul79: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu79, split_179)
            lv403 = R.call_tir(cls.dequantize4, (model_layers_7_mlp_down_proj_q_weight3, model_layers_7_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv176 = R.call_tir(cls.NT_matmul8, (mul79, lv403), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv174_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv176, lv173_1, model_layers_8_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv175_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv174_1[1]
            rms_norm162: R.Tensor((1, seq_len, 2048), dtype="float16") = lv174_1[0]
            lv404 = R.call_tir(cls.dequantize1, (model_layers_8_self_attn_c_attn_q_weight3, model_layers_8_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv177 = R.call_tir(cls.NT_matmul5, (rms_norm162, lv404), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add240: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv177, model_layers_8_self_attn_c_attn_bias3)
            reshape320: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add240, R.shape([1, seq_len, 20, 128]))
            reshape321: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape320, R.shape([seq_len, 20, 128]))
            lv405 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape321), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape322: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv405, R.shape([1, seq_len, 16, 128]))
            reshape323: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape322, R.shape([1, seq_len, 2048]))
            lv406 = R.call_tir(cls.dequantize2, (model_layers_8_self_attn_o_proj_q_weight3, model_layers_8_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv178 = R.call_tir(cls.NT_matmul6, (reshape323, lv406), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv176_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv178, lv175_1, model_layers_8_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv177_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv176_1[1]
            rms_norm163: R.Tensor((1, seq_len, 2048), dtype="float16") = lv176_1[0]
            lv407 = R.call_tir(cls.dequantize3, (model_layers_8_mlp_gate_up_proj_q_weight3, model_layers_8_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv179 = R.call_tir(cls.NT_matmul7, (rms_norm163, lv407), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split80: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv179, indices_or_sections=2, axis=-1)
            split_080: R.Tensor((1, seq_len, 11008), dtype="float16") = split80[0]
            split_180: R.Tensor((1, seq_len, 11008), dtype="float16") = split80[1]
            silu80: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_080)
            mul80: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu80, split_180)
            lv408 = R.call_tir(cls.dequantize4, (model_layers_8_mlp_down_proj_q_weight3, model_layers_8_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv180 = R.call_tir(cls.NT_matmul8, (mul80, lv408), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv178_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv180, lv177_1, model_layers_9_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv179_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv178_1[1]
            rms_norm164: R.Tensor((1, seq_len, 2048), dtype="float16") = lv178_1[0]
            lv409 = R.call_tir(cls.dequantize1, (model_layers_9_self_attn_c_attn_q_weight3, model_layers_9_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv181 = R.call_tir(cls.NT_matmul5, (rms_norm164, lv409), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add243: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv181, model_layers_9_self_attn_c_attn_bias3)
            reshape324: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add243, R.shape([1, seq_len, 20, 128]))
            reshape325: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape324, R.shape([seq_len, 20, 128]))
            lv410 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape325), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape326: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv410, R.shape([1, seq_len, 16, 128]))
            reshape327: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape326, R.shape([1, seq_len, 2048]))
            lv411 = R.call_tir(cls.dequantize2, (model_layers_9_self_attn_o_proj_q_weight3, model_layers_9_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv182 = R.call_tir(cls.NT_matmul6, (reshape327, lv411), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv180_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv182, lv179_1, model_layers_9_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv181_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv180_1[1]
            rms_norm165: R.Tensor((1, seq_len, 2048), dtype="float16") = lv180_1[0]
            lv412 = R.call_tir(cls.dequantize3, (model_layers_9_mlp_gate_up_proj_q_weight3, model_layers_9_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv183 = R.call_tir(cls.NT_matmul7, (rms_norm165, lv412), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split81: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv183, indices_or_sections=2, axis=-1)
            split_081: R.Tensor((1, seq_len, 11008), dtype="float16") = split81[0]
            split_181: R.Tensor((1, seq_len, 11008), dtype="float16") = split81[1]
            silu81: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_081)
            mul81: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu81, split_181)
            lv413 = R.call_tir(cls.dequantize4, (model_layers_9_mlp_down_proj_q_weight3, model_layers_9_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv184 = R.call_tir(cls.NT_matmul8, (mul81, lv413), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv182_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv184, lv181_1, model_layers_10_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv183_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv182_1[1]
            rms_norm166: R.Tensor((1, seq_len, 2048), dtype="float16") = lv182_1[0]
            lv414 = R.call_tir(cls.dequantize1, (model_layers_10_self_attn_c_attn_q_weight3, model_layers_10_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv185 = R.call_tir(cls.NT_matmul5, (rms_norm166, lv414), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add246: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv185, model_layers_10_self_attn_c_attn_bias3)
            reshape328: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add246, R.shape([1, seq_len, 20, 128]))
            reshape329: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape328, R.shape([seq_len, 20, 128]))
            lv415 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape329), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape330: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv415, R.shape([1, seq_len, 16, 128]))
            reshape331: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape330, R.shape([1, seq_len, 2048]))
            lv416 = R.call_tir(cls.dequantize2, (model_layers_10_self_attn_o_proj_q_weight3, model_layers_10_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv186 = R.call_tir(cls.NT_matmul6, (reshape331, lv416), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv184_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv186, lv183_1, model_layers_10_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv185_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv184_1[1]
            rms_norm167: R.Tensor((1, seq_len, 2048), dtype="float16") = lv184_1[0]
            lv417 = R.call_tir(cls.dequantize3, (model_layers_10_mlp_gate_up_proj_q_weight3, model_layers_10_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv187 = R.call_tir(cls.NT_matmul7, (rms_norm167, lv417), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split82: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv187, indices_or_sections=2, axis=-1)
            split_082: R.Tensor((1, seq_len, 11008), dtype="float16") = split82[0]
            split_182: R.Tensor((1, seq_len, 11008), dtype="float16") = split82[1]
            silu82: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_082)
            mul82: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu82, split_182)
            lv418 = R.call_tir(cls.dequantize4, (model_layers_10_mlp_down_proj_q_weight3, model_layers_10_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv188 = R.call_tir(cls.NT_matmul8, (mul82, lv418), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv186_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv188, lv185_1, model_layers_11_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv187_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv186_1[1]
            rms_norm168: R.Tensor((1, seq_len, 2048), dtype="float16") = lv186_1[0]
            lv419 = R.call_tir(cls.dequantize1, (model_layers_11_self_attn_c_attn_q_weight3, model_layers_11_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv189 = R.call_tir(cls.NT_matmul5, (rms_norm168, lv419), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add249: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv189, model_layers_11_self_attn_c_attn_bias3)
            reshape332: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add249, R.shape([1, seq_len, 20, 128]))
            reshape333: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape332, R.shape([seq_len, 20, 128]))
            lv420 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape333), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape334: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv420, R.shape([1, seq_len, 16, 128]))
            reshape335: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape334, R.shape([1, seq_len, 2048]))
            lv421 = R.call_tir(cls.dequantize2, (model_layers_11_self_attn_o_proj_q_weight3, model_layers_11_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv190 = R.call_tir(cls.NT_matmul6, (reshape335, lv421), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv188_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv190, lv187_1, model_layers_11_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv189_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv188_1[1]
            rms_norm169: R.Tensor((1, seq_len, 2048), dtype="float16") = lv188_1[0]
            lv422 = R.call_tir(cls.dequantize3, (model_layers_11_mlp_gate_up_proj_q_weight3, model_layers_11_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv191 = R.call_tir(cls.NT_matmul7, (rms_norm169, lv422), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split83: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv191, indices_or_sections=2, axis=-1)
            split_083: R.Tensor((1, seq_len, 11008), dtype="float16") = split83[0]
            split_183: R.Tensor((1, seq_len, 11008), dtype="float16") = split83[1]
            silu83: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_083)
            mul83: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu83, split_183)
            lv423 = R.call_tir(cls.dequantize4, (model_layers_11_mlp_down_proj_q_weight3, model_layers_11_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv192 = R.call_tir(cls.NT_matmul8, (mul83, lv423), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv190_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv192, lv189_1, model_layers_12_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv191_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv190_1[1]
            rms_norm170: R.Tensor((1, seq_len, 2048), dtype="float16") = lv190_1[0]
            lv424 = R.call_tir(cls.dequantize1, (model_layers_12_self_attn_c_attn_q_weight3, model_layers_12_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv193 = R.call_tir(cls.NT_matmul5, (rms_norm170, lv424), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add252: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv193, model_layers_12_self_attn_c_attn_bias3)
            reshape336: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add252, R.shape([1, seq_len, 20, 128]))
            reshape337: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape336, R.shape([seq_len, 20, 128]))
            lv425 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape337), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape338: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv425, R.shape([1, seq_len, 16, 128]))
            reshape339: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape338, R.shape([1, seq_len, 2048]))
            lv426 = R.call_tir(cls.dequantize2, (model_layers_12_self_attn_o_proj_q_weight3, model_layers_12_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv194 = R.call_tir(cls.NT_matmul6, (reshape339, lv426), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv192_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv194, lv191_1, model_layers_12_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv193_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv192_1[1]
            rms_norm171: R.Tensor((1, seq_len, 2048), dtype="float16") = lv192_1[0]
            lv427 = R.call_tir(cls.dequantize3, (model_layers_12_mlp_gate_up_proj_q_weight3, model_layers_12_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv195 = R.call_tir(cls.NT_matmul7, (rms_norm171, lv427), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split84: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv195, indices_or_sections=2, axis=-1)
            split_084: R.Tensor((1, seq_len, 11008), dtype="float16") = split84[0]
            split_184: R.Tensor((1, seq_len, 11008), dtype="float16") = split84[1]
            silu84: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_084)
            mul84: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu84, split_184)
            lv428 = R.call_tir(cls.dequantize4, (model_layers_12_mlp_down_proj_q_weight3, model_layers_12_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv196 = R.call_tir(cls.NT_matmul8, (mul84, lv428), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv194_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv196, lv193_1, model_layers_13_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv195_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv194_1[1]
            rms_norm172: R.Tensor((1, seq_len, 2048), dtype="float16") = lv194_1[0]
            lv429 = R.call_tir(cls.dequantize1, (model_layers_13_self_attn_c_attn_q_weight3, model_layers_13_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv197 = R.call_tir(cls.NT_matmul5, (rms_norm172, lv429), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add255: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv197, model_layers_13_self_attn_c_attn_bias3)
            reshape340: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add255, R.shape([1, seq_len, 20, 128]))
            reshape341: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape340, R.shape([seq_len, 20, 128]))
            lv430 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape341), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape342: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv430, R.shape([1, seq_len, 16, 128]))
            reshape343: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape342, R.shape([1, seq_len, 2048]))
            lv431 = R.call_tir(cls.dequantize2, (model_layers_13_self_attn_o_proj_q_weight3, model_layers_13_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv198 = R.call_tir(cls.NT_matmul6, (reshape343, lv431), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv196_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv198, lv195_1, model_layers_13_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv197_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv196_1[1]
            rms_norm173: R.Tensor((1, seq_len, 2048), dtype="float16") = lv196_1[0]
            lv432 = R.call_tir(cls.dequantize3, (model_layers_13_mlp_gate_up_proj_q_weight3, model_layers_13_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv199 = R.call_tir(cls.NT_matmul7, (rms_norm173, lv432), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split85: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv199, indices_or_sections=2, axis=-1)
            split_085: R.Tensor((1, seq_len, 11008), dtype="float16") = split85[0]
            split_185: R.Tensor((1, seq_len, 11008), dtype="float16") = split85[1]
            silu85: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_085)
            mul85: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu85, split_185)
            lv433 = R.call_tir(cls.dequantize4, (model_layers_13_mlp_down_proj_q_weight3, model_layers_13_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv200 = R.call_tir(cls.NT_matmul8, (mul85, lv433), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv198_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv200, lv197_1, model_layers_14_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv199_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv198_1[1]
            rms_norm174: R.Tensor((1, seq_len, 2048), dtype="float16") = lv198_1[0]
            lv434 = R.call_tir(cls.dequantize1, (model_layers_14_self_attn_c_attn_q_weight3, model_layers_14_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv201 = R.call_tir(cls.NT_matmul5, (rms_norm174, lv434), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add258: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv201, model_layers_14_self_attn_c_attn_bias3)
            reshape344: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add258, R.shape([1, seq_len, 20, 128]))
            reshape345: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape344, R.shape([seq_len, 20, 128]))
            lv435 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape345), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape346: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv435, R.shape([1, seq_len, 16, 128]))
            reshape347: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape346, R.shape([1, seq_len, 2048]))
            lv436 = R.call_tir(cls.dequantize2, (model_layers_14_self_attn_o_proj_q_weight3, model_layers_14_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv202 = R.call_tir(cls.NT_matmul6, (reshape347, lv436), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv200_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv202, lv199_1, model_layers_14_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv201_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv200_1[1]
            rms_norm175: R.Tensor((1, seq_len, 2048), dtype="float16") = lv200_1[0]
            lv437 = R.call_tir(cls.dequantize3, (model_layers_14_mlp_gate_up_proj_q_weight3, model_layers_14_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv203 = R.call_tir(cls.NT_matmul7, (rms_norm175, lv437), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split86: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv203, indices_or_sections=2, axis=-1)
            split_086: R.Tensor((1, seq_len, 11008), dtype="float16") = split86[0]
            split_186: R.Tensor((1, seq_len, 11008), dtype="float16") = split86[1]
            silu86: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_086)
            mul86: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu86, split_186)
            lv438 = R.call_tir(cls.dequantize4, (model_layers_14_mlp_down_proj_q_weight3, model_layers_14_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv204 = R.call_tir(cls.NT_matmul8, (mul86, lv438), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv202_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv204, lv201_1, model_layers_15_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv203_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv202_1[1]
            rms_norm176: R.Tensor((1, seq_len, 2048), dtype="float16") = lv202_1[0]
            lv439 = R.call_tir(cls.dequantize1, (model_layers_15_self_attn_c_attn_q_weight3, model_layers_15_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv205 = R.call_tir(cls.NT_matmul5, (rms_norm176, lv439), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add261: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv205, model_layers_15_self_attn_c_attn_bias3)
            reshape348: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add261, R.shape([1, seq_len, 20, 128]))
            reshape349: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape348, R.shape([seq_len, 20, 128]))
            lv440 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape349), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape350: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv440, R.shape([1, seq_len, 16, 128]))
            reshape351: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape350, R.shape([1, seq_len, 2048]))
            lv441 = R.call_tir(cls.dequantize2, (model_layers_15_self_attn_o_proj_q_weight3, model_layers_15_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv206 = R.call_tir(cls.NT_matmul6, (reshape351, lv441), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv204_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv206, lv203_1, model_layers_15_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv205_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv204_1[1]
            rms_norm177: R.Tensor((1, seq_len, 2048), dtype="float16") = lv204_1[0]
            lv442 = R.call_tir(cls.dequantize3, (model_layers_15_mlp_gate_up_proj_q_weight3, model_layers_15_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv207 = R.call_tir(cls.NT_matmul7, (rms_norm177, lv442), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split87: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv207, indices_or_sections=2, axis=-1)
            split_087: R.Tensor((1, seq_len, 11008), dtype="float16") = split87[0]
            split_187: R.Tensor((1, seq_len, 11008), dtype="float16") = split87[1]
            silu87: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_087)
            mul87: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu87, split_187)
            lv443 = R.call_tir(cls.dequantize4, (model_layers_15_mlp_down_proj_q_weight3, model_layers_15_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv208 = R.call_tir(cls.NT_matmul8, (mul87, lv443), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv206_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv208, lv205_1, model_layers_16_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv207_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv206_1[1]
            rms_norm178: R.Tensor((1, seq_len, 2048), dtype="float16") = lv206_1[0]
            lv444 = R.call_tir(cls.dequantize1, (model_layers_16_self_attn_c_attn_q_weight3, model_layers_16_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv209 = R.call_tir(cls.NT_matmul5, (rms_norm178, lv444), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add264: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv209, model_layers_16_self_attn_c_attn_bias3)
            reshape352: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add264, R.shape([1, seq_len, 20, 128]))
            reshape353: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape352, R.shape([seq_len, 20, 128]))
            lv445 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1.0)), reshape353), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape354: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv445, R.shape([1, seq_len, 16, 128]))
            reshape355: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape354, R.shape([1, seq_len, 2048]))
            lv446 = R.call_tir(cls.dequantize2, (model_layers_16_self_attn_o_proj_q_weight3, model_layers_16_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv210 = R.call_tir(cls.NT_matmul6, (reshape355, lv446), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv208_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv210, lv207_1, model_layers_16_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv209_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv208_1[1]
            rms_norm179: R.Tensor((1, seq_len, 2048), dtype="float16") = lv208_1[0]
            lv447 = R.call_tir(cls.dequantize3, (model_layers_16_mlp_gate_up_proj_q_weight3, model_layers_16_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv211 = R.call_tir(cls.NT_matmul7, (rms_norm179, lv447), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split88: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv211, indices_or_sections=2, axis=-1)
            split_088: R.Tensor((1, seq_len, 11008), dtype="float16") = split88[0]
            split_188: R.Tensor((1, seq_len, 11008), dtype="float16") = split88[1]
            silu88: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_088)
            mul88: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu88, split_188)
            lv448 = R.call_tir(cls.dequantize4, (model_layers_16_mlp_down_proj_q_weight3, model_layers_16_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv212 = R.call_tir(cls.NT_matmul8, (mul88, lv448), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv210_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv212, lv209_1, model_layers_17_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv211_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv210_1[1]
            rms_norm180: R.Tensor((1, seq_len, 2048), dtype="float16") = lv210_1[0]
            lv449 = R.call_tir(cls.dequantize1, (model_layers_17_self_attn_c_attn_q_weight3, model_layers_17_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv213 = R.call_tir(cls.NT_matmul5, (rms_norm180, lv449), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add267: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv213, model_layers_17_self_attn_c_attn_bias3)
            reshape356: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add267, R.shape([1, seq_len, 20, 128]))
            reshape357: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape356, R.shape([seq_len, 20, 128]))
            lv450 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1.0)), reshape357), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape358: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv450, R.shape([1, seq_len, 16, 128]))
            reshape359: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape358, R.shape([1, seq_len, 2048]))
            lv451 = R.call_tir(cls.dequantize2, (model_layers_17_self_attn_o_proj_q_weight3, model_layers_17_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv214 = R.call_tir(cls.NT_matmul6, (reshape359, lv451), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv212_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv214, lv211_1, model_layers_17_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv213_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv212_1[1]
            rms_norm181: R.Tensor((1, seq_len, 2048), dtype="float16") = lv212_1[0]
            lv452 = R.call_tir(cls.dequantize3, (model_layers_17_mlp_gate_up_proj_q_weight3, model_layers_17_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv215 = R.call_tir(cls.NT_matmul7, (rms_norm181, lv452), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split89: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv215, indices_or_sections=2, axis=-1)
            split_089: R.Tensor((1, seq_len, 11008), dtype="float16") = split89[0]
            split_189: R.Tensor((1, seq_len, 11008), dtype="float16") = split89[1]
            silu89: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_089)
            mul89: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu89, split_189)
            lv453 = R.call_tir(cls.dequantize4, (model_layers_17_mlp_down_proj_q_weight3, model_layers_17_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv216 = R.call_tir(cls.NT_matmul8, (mul89, lv453), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv214_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv216, lv213_1, model_layers_18_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv215_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv214_1[1]
            rms_norm182: R.Tensor((1, seq_len, 2048), dtype="float16") = lv214_1[0]
            lv454 = R.call_tir(cls.dequantize1, (model_layers_18_self_attn_c_attn_q_weight3, model_layers_18_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv217 = R.call_tir(cls.NT_matmul5, (rms_norm182, lv454), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add270: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv217, model_layers_18_self_attn_c_attn_bias3)
            reshape360: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add270, R.shape([1, seq_len, 20, 128]))
            reshape361: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape360, R.shape([seq_len, 20, 128]))
            lv455 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1.0)), reshape361), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape362: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv455, R.shape([1, seq_len, 16, 128]))
            reshape363: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape362, R.shape([1, seq_len, 2048]))
            lv456 = R.call_tir(cls.dequantize2, (model_layers_18_self_attn_o_proj_q_weight3, model_layers_18_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv218 = R.call_tir(cls.NT_matmul6, (reshape363, lv456), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv216_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv218, lv215_1, model_layers_18_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv217_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv216_1[1]
            rms_norm183: R.Tensor((1, seq_len, 2048), dtype="float16") = lv216_1[0]
            lv457 = R.call_tir(cls.dequantize3, (model_layers_18_mlp_gate_up_proj_q_weight3, model_layers_18_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv219 = R.call_tir(cls.NT_matmul7, (rms_norm183, lv457), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split90: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv219, indices_or_sections=2, axis=-1)
            split_090: R.Tensor((1, seq_len, 11008), dtype="float16") = split90[0]
            split_190: R.Tensor((1, seq_len, 11008), dtype="float16") = split90[1]
            silu90: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_090)
            mul90: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu90, split_190)
            lv458 = R.call_tir(cls.dequantize4, (model_layers_18_mlp_down_proj_q_weight3, model_layers_18_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv220 = R.call_tir(cls.NT_matmul8, (mul90, lv458), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv218_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv220, lv217_1, model_layers_19_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv219_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv218_1[1]
            rms_norm184: R.Tensor((1, seq_len, 2048), dtype="float16") = lv218_1[0]
            lv459 = R.call_tir(cls.dequantize1, (model_layers_19_self_attn_c_attn_q_weight3, model_layers_19_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv221 = R.call_tir(cls.NT_matmul5, (rms_norm184, lv459), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add273: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv221, model_layers_19_self_attn_c_attn_bias3)
            reshape364: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add273, R.shape([1, seq_len, 20, 128]))
            reshape365: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape364, R.shape([seq_len, 20, 128]))
            lv460 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1.0)), reshape365), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape366: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv460, R.shape([1, seq_len, 16, 128]))
            reshape367: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape366, R.shape([1, seq_len, 2048]))
            lv461 = R.call_tir(cls.dequantize2, (model_layers_19_self_attn_o_proj_q_weight3, model_layers_19_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv222 = R.call_tir(cls.NT_matmul6, (reshape367, lv461), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv220_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv222, lv219_1, model_layers_19_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv221_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv220_1[1]
            rms_norm185: R.Tensor((1, seq_len, 2048), dtype="float16") = lv220_1[0]
            lv462 = R.call_tir(cls.dequantize3, (model_layers_19_mlp_gate_up_proj_q_weight3, model_layers_19_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv223 = R.call_tir(cls.NT_matmul7, (rms_norm185, lv462), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split91: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv223, indices_or_sections=2, axis=-1)
            split_091: R.Tensor((1, seq_len, 11008), dtype="float16") = split91[0]
            split_191: R.Tensor((1, seq_len, 11008), dtype="float16") = split91[1]
            silu91: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_091)
            mul91: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu91, split_191)
            lv463 = R.call_tir(cls.dequantize4, (model_layers_19_mlp_down_proj_q_weight3, model_layers_19_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv224 = R.call_tir(cls.NT_matmul8, (mul91, lv463), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv222_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv224, lv221_1, model_layers_20_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv223_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv222_1[1]
            rms_norm186: R.Tensor((1, seq_len, 2048), dtype="float16") = lv222_1[0]
            lv464 = R.call_tir(cls.dequantize1, (model_layers_20_self_attn_c_attn_q_weight3, model_layers_20_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv225 = R.call_tir(cls.NT_matmul5, (rms_norm186, lv464), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add276: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv225, model_layers_20_self_attn_c_attn_bias3)
            reshape368: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add276, R.shape([1, seq_len, 20, 128]))
            reshape369: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape368, R.shape([seq_len, 20, 128]))
            lv465 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1.0)), reshape369), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape370: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv465, R.shape([1, seq_len, 16, 128]))
            reshape371: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape370, R.shape([1, seq_len, 2048]))
            lv466 = R.call_tir(cls.dequantize2, (model_layers_20_self_attn_o_proj_q_weight3, model_layers_20_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv226 = R.call_tir(cls.NT_matmul6, (reshape371, lv466), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv224_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv226, lv223_1, model_layers_20_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv225_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv224_1[1]
            rms_norm187: R.Tensor((1, seq_len, 2048), dtype="float16") = lv224_1[0]
            lv467 = R.call_tir(cls.dequantize3, (model_layers_20_mlp_gate_up_proj_q_weight3, model_layers_20_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv227 = R.call_tir(cls.NT_matmul7, (rms_norm187, lv467), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split92: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv227, indices_or_sections=2, axis=-1)
            split_092: R.Tensor((1, seq_len, 11008), dtype="float16") = split92[0]
            split_192: R.Tensor((1, seq_len, 11008), dtype="float16") = split92[1]
            silu92: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_092)
            mul92: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu92, split_192)
            lv468 = R.call_tir(cls.dequantize4, (model_layers_20_mlp_down_proj_q_weight3, model_layers_20_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv228 = R.call_tir(cls.NT_matmul8, (mul92, lv468), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv226_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv228, lv225_1, model_layers_21_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv227_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv226_1[1]
            rms_norm188: R.Tensor((1, seq_len, 2048), dtype="float16") = lv226_1[0]
            lv469 = R.call_tir(cls.dequantize1, (model_layers_21_self_attn_c_attn_q_weight3, model_layers_21_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv229 = R.call_tir(cls.NT_matmul5, (rms_norm188, lv469), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add279: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv229, model_layers_21_self_attn_c_attn_bias3)
            reshape372: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add279, R.shape([1, seq_len, 20, 128]))
            reshape373: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape372, R.shape([seq_len, 20, 128]))
            lv470 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1.0)), reshape373), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape374: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv470, R.shape([1, seq_len, 16, 128]))
            reshape375: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape374, R.shape([1, seq_len, 2048]))
            lv471 = R.call_tir(cls.dequantize2, (model_layers_21_self_attn_o_proj_q_weight3, model_layers_21_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv230 = R.call_tir(cls.NT_matmul6, (reshape375, lv471), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv228_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv230, lv227_1, model_layers_21_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv229_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv228_1[1]
            rms_norm189: R.Tensor((1, seq_len, 2048), dtype="float16") = lv228_1[0]
            lv472 = R.call_tir(cls.dequantize3, (model_layers_21_mlp_gate_up_proj_q_weight3, model_layers_21_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv231 = R.call_tir(cls.NT_matmul7, (rms_norm189, lv472), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split93: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv231, indices_or_sections=2, axis=-1)
            split_093: R.Tensor((1, seq_len, 11008), dtype="float16") = split93[0]
            split_193: R.Tensor((1, seq_len, 11008), dtype="float16") = split93[1]
            silu93: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_093)
            mul93: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu93, split_193)
            lv473 = R.call_tir(cls.dequantize4, (model_layers_21_mlp_down_proj_q_weight3, model_layers_21_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv232 = R.call_tir(cls.NT_matmul8, (mul93, lv473), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv230_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv232, lv229_1, model_layers_22_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv231_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv230_1[1]
            rms_norm190: R.Tensor((1, seq_len, 2048), dtype="float16") = lv230_1[0]
            lv474 = R.call_tir(cls.dequantize1, (model_layers_22_self_attn_c_attn_q_weight3, model_layers_22_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv233 = R.call_tir(cls.NT_matmul5, (rms_norm190, lv474), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add282: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv233, model_layers_22_self_attn_c_attn_bias3)
            reshape376: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add282, R.shape([1, seq_len, 20, 128]))
            reshape377: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape376, R.shape([seq_len, 20, 128]))
            lv475 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1.0)), reshape377), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape378: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv475, R.shape([1, seq_len, 16, 128]))
            reshape379: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape378, R.shape([1, seq_len, 2048]))
            lv476 = R.call_tir(cls.dequantize2, (model_layers_22_self_attn_o_proj_q_weight3, model_layers_22_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv234 = R.call_tir(cls.NT_matmul6, (reshape379, lv476), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv232_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv234, lv231_1, model_layers_22_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv233_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv232_1[1]
            rms_norm191: R.Tensor((1, seq_len, 2048), dtype="float16") = lv232_1[0]
            lv477 = R.call_tir(cls.dequantize3, (model_layers_22_mlp_gate_up_proj_q_weight3, model_layers_22_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv235 = R.call_tir(cls.NT_matmul7, (rms_norm191, lv477), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split94: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv235, indices_or_sections=2, axis=-1)
            split_094: R.Tensor((1, seq_len, 11008), dtype="float16") = split94[0]
            split_194: R.Tensor((1, seq_len, 11008), dtype="float16") = split94[1]
            silu94: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_094)
            mul94: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu94, split_194)
            lv478 = R.call_tir(cls.dequantize4, (model_layers_22_mlp_down_proj_q_weight3, model_layers_22_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv236 = R.call_tir(cls.NT_matmul8, (mul94, lv478), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv234_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv236, lv233_1, model_layers_23_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv235_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv234_1[1]
            rms_norm192: R.Tensor((1, seq_len, 2048), dtype="float16") = lv234_1[0]
            lv479 = R.call_tir(cls.dequantize1, (model_layers_23_self_attn_c_attn_q_weight3, model_layers_23_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv237 = R.call_tir(cls.NT_matmul5, (rms_norm192, lv479), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add285: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv237, model_layers_23_self_attn_c_attn_bias3)
            reshape380: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add285, R.shape([1, seq_len, 20, 128]))
            reshape381: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape380, R.shape([seq_len, 20, 128]))
            lv480 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1.0)), reshape381), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape382: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv480, R.shape([1, seq_len, 16, 128]))
            reshape383: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape382, R.shape([1, seq_len, 2048]))
            lv481 = R.call_tir(cls.dequantize2, (model_layers_23_self_attn_o_proj_q_weight3, model_layers_23_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv238 = R.call_tir(cls.NT_matmul6, (reshape383, lv481), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv236_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv238, lv235_1, model_layers_23_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv237_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv236_1[1]
            rms_norm193: R.Tensor((1, seq_len, 2048), dtype="float16") = lv236_1[0]
            lv482 = R.call_tir(cls.dequantize3, (model_layers_23_mlp_gate_up_proj_q_weight3, model_layers_23_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv239 = R.call_tir(cls.NT_matmul7, (rms_norm193, lv482), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split95: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv239, indices_or_sections=2, axis=-1)
            split_095: R.Tensor((1, seq_len, 11008), dtype="float16") = split95[0]
            split_195: R.Tensor((1, seq_len, 11008), dtype="float16") = split95[1]
            silu95: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_095)
            mul95: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu95, split_195)
            lv483 = R.call_tir(cls.dequantize4, (model_layers_23_mlp_down_proj_q_weight3, model_layers_23_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv240 = R.call_tir(cls.NT_matmul8, (mul95, lv483), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv238_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv240, lv237_1, model_layers_24_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv239_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv238_1[1]
            rms_norm194: R.Tensor((1, seq_len, 2048), dtype="float16") = lv238_1[0]
            lv484 = R.call_tir(cls.dequantize1, (model_layers_24_self_attn_c_attn_q_weight3, model_layers_24_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv241 = R.call_tir(cls.NT_matmul5, (rms_norm194, lv484), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add288: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv241, model_layers_24_self_attn_c_attn_bias3)
            reshape384: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add288, R.shape([1, seq_len, 20, 128]))
            reshape385: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape384, R.shape([seq_len, 20, 128]))
            lv485 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1.0)), reshape385), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape386: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv485, R.shape([1, seq_len, 16, 128]))
            reshape387: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape386, R.shape([1, seq_len, 2048]))
            lv486 = R.call_tir(cls.dequantize2, (model_layers_24_self_attn_o_proj_q_weight3, model_layers_24_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv242 = R.call_tir(cls.NT_matmul6, (reshape387, lv486), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv240_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv242, lv239_1, model_layers_24_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv241_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv240_1[1]
            rms_norm195: R.Tensor((1, seq_len, 2048), dtype="float16") = lv240_1[0]
            lv487 = R.call_tir(cls.dequantize3, (model_layers_24_mlp_gate_up_proj_q_weight3, model_layers_24_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv243 = R.call_tir(cls.NT_matmul7, (rms_norm195, lv487), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split96: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv243, indices_or_sections=2, axis=-1)
            split_096: R.Tensor((1, seq_len, 11008), dtype="float16") = split96[0]
            split_196: R.Tensor((1, seq_len, 11008), dtype="float16") = split96[1]
            silu96: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_096)
            mul96: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu96, split_196)
            lv488 = R.call_tir(cls.dequantize4, (model_layers_24_mlp_down_proj_q_weight3, model_layers_24_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv244 = R.call_tir(cls.NT_matmul8, (mul96, lv488), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv242_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv244, lv241_1, model_layers_25_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv243_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv242_1[1]
            rms_norm196: R.Tensor((1, seq_len, 2048), dtype="float16") = lv242_1[0]
            lv489 = R.call_tir(cls.dequantize1, (model_layers_25_self_attn_c_attn_q_weight3, model_layers_25_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv245 = R.call_tir(cls.NT_matmul5, (rms_norm196, lv489), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add291: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv245, model_layers_25_self_attn_c_attn_bias3)
            reshape388: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add291, R.shape([1, seq_len, 20, 128]))
            reshape389: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape388, R.shape([seq_len, 20, 128]))
            lv490 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1.0)), reshape389), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape390: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv490, R.shape([1, seq_len, 16, 128]))
            reshape391: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape390, R.shape([1, seq_len, 2048]))
            lv491 = R.call_tir(cls.dequantize2, (model_layers_25_self_attn_o_proj_q_weight3, model_layers_25_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv246 = R.call_tir(cls.NT_matmul6, (reshape391, lv491), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv244_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv246, lv243_1, model_layers_25_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv245_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv244_1[1]
            rms_norm197: R.Tensor((1, seq_len, 2048), dtype="float16") = lv244_1[0]
            lv492 = R.call_tir(cls.dequantize3, (model_layers_25_mlp_gate_up_proj_q_weight3, model_layers_25_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv247 = R.call_tir(cls.NT_matmul7, (rms_norm197, lv492), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split97: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv247, indices_or_sections=2, axis=-1)
            split_097: R.Tensor((1, seq_len, 11008), dtype="float16") = split97[0]
            split_197: R.Tensor((1, seq_len, 11008), dtype="float16") = split97[1]
            silu97: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_097)
            mul97: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu97, split_197)
            lv493 = R.call_tir(cls.dequantize4, (model_layers_25_mlp_down_proj_q_weight3, model_layers_25_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv248 = R.call_tir(cls.NT_matmul8, (mul97, lv493), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv246_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv248, lv245_1, model_layers_26_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv247_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv246_1[1]
            rms_norm198: R.Tensor((1, seq_len, 2048), dtype="float16") = lv246_1[0]
            lv494 = R.call_tir(cls.dequantize1, (model_layers_26_self_attn_c_attn_q_weight3, model_layers_26_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv249 = R.call_tir(cls.NT_matmul5, (rms_norm198, lv494), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add294: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv249, model_layers_26_self_attn_c_attn_bias3)
            reshape392: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add294, R.shape([1, seq_len, 20, 128]))
            reshape393: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape392, R.shape([seq_len, 20, 128]))
            lv495 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1.0)), reshape393), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape394: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv495, R.shape([1, seq_len, 16, 128]))
            reshape395: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape394, R.shape([1, seq_len, 2048]))
            lv496 = R.call_tir(cls.dequantize2, (model_layers_26_self_attn_o_proj_q_weight3, model_layers_26_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv250 = R.call_tir(cls.NT_matmul6, (reshape395, lv496), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv248_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv250, lv247_1, model_layers_26_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv249_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv248_1[1]
            rms_norm199: R.Tensor((1, seq_len, 2048), dtype="float16") = lv248_1[0]
            lv497 = R.call_tir(cls.dequantize3, (model_layers_26_mlp_gate_up_proj_q_weight3, model_layers_26_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv251 = R.call_tir(cls.NT_matmul7, (rms_norm199, lv497), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split98: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv251, indices_or_sections=2, axis=-1)
            split_098: R.Tensor((1, seq_len, 11008), dtype="float16") = split98[0]
            split_198: R.Tensor((1, seq_len, 11008), dtype="float16") = split98[1]
            silu98: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_098)
            mul98: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu98, split_198)
            lv498 = R.call_tir(cls.dequantize4, (model_layers_26_mlp_down_proj_q_weight3, model_layers_26_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv252 = R.call_tir(cls.NT_matmul8, (mul98, lv498), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv250_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv252, lv249_1, model_layers_27_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv251_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv250_1[1]
            rms_norm200: R.Tensor((1, seq_len, 2048), dtype="float16") = lv250_1[0]
            lv499 = R.call_tir(cls.dequantize1, (model_layers_27_self_attn_c_attn_q_weight3, model_layers_27_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv253 = R.call_tir(cls.NT_matmul5, (rms_norm200, lv499), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add297: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv253, model_layers_27_self_attn_c_attn_bias3)
            reshape396: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add297, R.shape([1, seq_len, 20, 128]))
            reshape397: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape396, R.shape([seq_len, 20, 128]))
            lv500 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1.0)), reshape397), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape398: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv500, R.shape([1, seq_len, 16, 128]))
            reshape399: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape398, R.shape([1, seq_len, 2048]))
            lv501 = R.call_tir(cls.dequantize2, (model_layers_27_self_attn_o_proj_q_weight3, model_layers_27_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv254 = R.call_tir(cls.NT_matmul6, (reshape399, lv501), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv252_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv254, lv251_1, model_layers_27_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv253_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv252_1[1]
            rms_norm201: R.Tensor((1, seq_len, 2048), dtype="float16") = lv252_1[0]
            lv502 = R.call_tir(cls.dequantize3, (model_layers_27_mlp_gate_up_proj_q_weight3, model_layers_27_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv255 = R.call_tir(cls.NT_matmul7, (rms_norm201, lv502), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split99: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv255, indices_or_sections=2, axis=-1)
            split_099: R.Tensor((1, seq_len, 11008), dtype="float16") = split99[0]
            split_199: R.Tensor((1, seq_len, 11008), dtype="float16") = split99[1]
            silu99: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_099)
            mul99: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu99, split_199)
            lv503 = R.call_tir(cls.dequantize4, (model_layers_27_mlp_down_proj_q_weight3, model_layers_27_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv256 = R.call_tir(cls.NT_matmul8, (mul99, lv503), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv254_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv256, lv253_1, model_layers_28_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv255_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv254_1[1]
            rms_norm202: R.Tensor((1, seq_len, 2048), dtype="float16") = lv254_1[0]
            lv504 = R.call_tir(cls.dequantize1, (model_layers_28_self_attn_c_attn_q_weight3, model_layers_28_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv257 = R.call_tir(cls.NT_matmul5, (rms_norm202, lv504), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add300: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv257, model_layers_28_self_attn_c_attn_bias3)
            reshape400: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add300, R.shape([1, seq_len, 20, 128]))
            reshape401: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape400, R.shape([seq_len, 20, 128]))
            lv505 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1.0)), reshape401), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape402: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv505, R.shape([1, seq_len, 16, 128]))
            reshape403: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape402, R.shape([1, seq_len, 2048]))
            lv506 = R.call_tir(cls.dequantize2, (model_layers_28_self_attn_o_proj_q_weight3, model_layers_28_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv258 = R.call_tir(cls.NT_matmul6, (reshape403, lv506), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv256_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv258, lv255_1, model_layers_28_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv257_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv256_1[1]
            rms_norm203: R.Tensor((1, seq_len, 2048), dtype="float16") = lv256_1[0]
            lv507 = R.call_tir(cls.dequantize3, (model_layers_28_mlp_gate_up_proj_q_weight3, model_layers_28_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv259 = R.call_tir(cls.NT_matmul7, (rms_norm203, lv507), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split100: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv259, indices_or_sections=2, axis=-1)
            split_0100: R.Tensor((1, seq_len, 11008), dtype="float16") = split100[0]
            split_1100: R.Tensor((1, seq_len, 11008), dtype="float16") = split100[1]
            silu100: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0100)
            mul100: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu100, split_1100)
            lv508 = R.call_tir(cls.dequantize4, (model_layers_28_mlp_down_proj_q_weight3, model_layers_28_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv260 = R.call_tir(cls.NT_matmul8, (mul100, lv508), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv258_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv260, lv257_1, model_layers_29_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv259_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv258_1[1]
            rms_norm204: R.Tensor((1, seq_len, 2048), dtype="float16") = lv258_1[0]
            lv509 = R.call_tir(cls.dequantize1, (model_layers_29_self_attn_c_attn_q_weight3, model_layers_29_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv261 = R.call_tir(cls.NT_matmul5, (rms_norm204, lv509), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add303: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv261, model_layers_29_self_attn_c_attn_bias3)
            reshape404: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add303, R.shape([1, seq_len, 20, 128]))
            reshape405: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape404, R.shape([seq_len, 20, 128]))
            lv510 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1.0)), reshape405), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape406: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv510, R.shape([1, seq_len, 16, 128]))
            reshape407: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape406, R.shape([1, seq_len, 2048]))
            lv511 = R.call_tir(cls.dequantize2, (model_layers_29_self_attn_o_proj_q_weight3, model_layers_29_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv262 = R.call_tir(cls.NT_matmul6, (reshape407, lv511), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv260_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv262, lv259_1, model_layers_29_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv261_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv260_1[1]
            rms_norm205: R.Tensor((1, seq_len, 2048), dtype="float16") = lv260_1[0]
            lv512 = R.call_tir(cls.dequantize3, (model_layers_29_mlp_gate_up_proj_q_weight3, model_layers_29_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv263 = R.call_tir(cls.NT_matmul7, (rms_norm205, lv512), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split101: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv263, indices_or_sections=2, axis=-1)
            split_0101: R.Tensor((1, seq_len, 11008), dtype="float16") = split101[0]
            split_1101: R.Tensor((1, seq_len, 11008), dtype="float16") = split101[1]
            silu101: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0101)
            mul101: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu101, split_1101)
            lv513 = R.call_tir(cls.dequantize4, (model_layers_29_mlp_down_proj_q_weight3, model_layers_29_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv264 = R.call_tir(cls.NT_matmul8, (mul101, lv513), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv262_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv264, lv261_1, model_layers_30_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv263_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv262_1[1]
            rms_norm206: R.Tensor((1, seq_len, 2048), dtype="float16") = lv262_1[0]
            lv514 = R.call_tir(cls.dequantize1, (model_layers_30_self_attn_c_attn_q_weight3, model_layers_30_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv265 = R.call_tir(cls.NT_matmul5, (rms_norm206, lv514), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add306: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv265, model_layers_30_self_attn_c_attn_bias3)
            reshape408: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add306, R.shape([1, seq_len, 20, 128]))
            reshape409: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape408, R.shape([seq_len, 20, 128]))
            lv515 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1.0)), reshape409), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape410: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv515, R.shape([1, seq_len, 16, 128]))
            reshape411: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape410, R.shape([1, seq_len, 2048]))
            lv516 = R.call_tir(cls.dequantize2, (model_layers_30_self_attn_o_proj_q_weight3, model_layers_30_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv266 = R.call_tir(cls.NT_matmul6, (reshape411, lv516), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv264_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv266, lv263_1, model_layers_30_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv265_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv264_1[1]
            rms_norm207: R.Tensor((1, seq_len, 2048), dtype="float16") = lv264_1[0]
            lv517 = R.call_tir(cls.dequantize3, (model_layers_30_mlp_gate_up_proj_q_weight3, model_layers_30_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv267 = R.call_tir(cls.NT_matmul7, (rms_norm207, lv517), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split102: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv267, indices_or_sections=2, axis=-1)
            split_0102: R.Tensor((1, seq_len, 11008), dtype="float16") = split102[0]
            split_1102: R.Tensor((1, seq_len, 11008), dtype="float16") = split102[1]
            silu102: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0102)
            mul102: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu102, split_1102)
            lv518 = R.call_tir(cls.dequantize4, (model_layers_30_mlp_down_proj_q_weight3, model_layers_30_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv268 = R.call_tir(cls.NT_matmul8, (mul102, lv518), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv266_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv268, lv265_1, model_layers_31_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv267_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv266_1[1]
            rms_norm208: R.Tensor((1, seq_len, 2048), dtype="float16") = lv266_1[0]
            lv519 = R.call_tir(cls.dequantize1, (model_layers_31_self_attn_c_attn_q_weight3, model_layers_31_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv269 = R.call_tir(cls.NT_matmul5, (rms_norm208, lv519), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add309: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv269, model_layers_31_self_attn_c_attn_bias3)
            reshape412: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add309, R.shape([1, seq_len, 20, 128]))
            reshape413: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape412, R.shape([seq_len, 20, 128]))
            lv520 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1.0)), reshape413), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape414: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv520, R.shape([1, seq_len, 16, 128]))
            reshape415: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape414, R.shape([1, seq_len, 2048]))
            lv521 = R.call_tir(cls.dequantize2, (model_layers_31_self_attn_o_proj_q_weight3, model_layers_31_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv270 = R.call_tir(cls.NT_matmul6, (reshape415, lv521), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv268_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv270, lv267_1, model_layers_31_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv269_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv268_1[1]
            rms_norm209: R.Tensor((1, seq_len, 2048), dtype="float16") = lv268_1[0]
            lv522 = R.call_tir(cls.dequantize3, (model_layers_31_mlp_gate_up_proj_q_weight3, model_layers_31_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv271 = R.call_tir(cls.NT_matmul7, (rms_norm209, lv522), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split103: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv271, indices_or_sections=2, axis=-1)
            split_0103: R.Tensor((1, seq_len, 11008), dtype="float16") = split103[0]
            split_1103: R.Tensor((1, seq_len, 11008), dtype="float16") = split103[1]
            silu103: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0103)
            mul103: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu103, split_1103)
            lv523 = R.call_tir(cls.dequantize4, (model_layers_31_mlp_down_proj_q_weight3, model_layers_31_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv272 = R.call_tir(cls.NT_matmul8, (mul103, lv523), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv270_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv272, lv269_1, model_layers_32_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv271_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv270_1[1]
            rms_norm210: R.Tensor((1, seq_len, 2048), dtype="float16") = lv270_1[0]
            lv524 = R.call_tir(cls.dequantize1, (model_layers_32_self_attn_c_attn_q_weight3, model_layers_32_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv273 = R.call_tir(cls.NT_matmul5, (rms_norm210, lv524), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add312: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv273, model_layers_32_self_attn_c_attn_bias3)
            reshape416: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add312, R.shape([1, seq_len, 20, 128]))
            reshape417: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape416, R.shape([seq_len, 20, 128]))
            lv525 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(32), R.prim_value(T.float32(1.0)), reshape417), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape418: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv525, R.shape([1, seq_len, 16, 128]))
            reshape419: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape418, R.shape([1, seq_len, 2048]))
            lv526 = R.call_tir(cls.dequantize2, (model_layers_32_self_attn_o_proj_q_weight3, model_layers_32_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv274 = R.call_tir(cls.NT_matmul6, (reshape419, lv526), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv272_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv274, lv271_1, model_layers_32_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv273_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv272_1[1]
            rms_norm211: R.Tensor((1, seq_len, 2048), dtype="float16") = lv272_1[0]
            lv527 = R.call_tir(cls.dequantize3, (model_layers_32_mlp_gate_up_proj_q_weight3, model_layers_32_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv275 = R.call_tir(cls.NT_matmul7, (rms_norm211, lv527), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split104: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv275, indices_or_sections=2, axis=-1)
            split_0104: R.Tensor((1, seq_len, 11008), dtype="float16") = split104[0]
            split_1104: R.Tensor((1, seq_len, 11008), dtype="float16") = split104[1]
            silu104: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0104)
            mul104: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu104, split_1104)
            lv528 = R.call_tir(cls.dequantize4, (model_layers_32_mlp_down_proj_q_weight3, model_layers_32_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv276 = R.call_tir(cls.NT_matmul8, (mul104, lv528), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv274_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv276, lv273_1, model_layers_33_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv275_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv274_1[1]
            rms_norm212: R.Tensor((1, seq_len, 2048), dtype="float16") = lv274_1[0]
            lv529 = R.call_tir(cls.dequantize1, (model_layers_33_self_attn_c_attn_q_weight3, model_layers_33_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv277 = R.call_tir(cls.NT_matmul5, (rms_norm212, lv529), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add315: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv277, model_layers_33_self_attn_c_attn_bias3)
            reshape420: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add315, R.shape([1, seq_len, 20, 128]))
            reshape421: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape420, R.shape([seq_len, 20, 128]))
            lv530 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(33), R.prim_value(T.float32(1.0)), reshape421), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape422: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv530, R.shape([1, seq_len, 16, 128]))
            reshape423: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape422, R.shape([1, seq_len, 2048]))
            lv531 = R.call_tir(cls.dequantize2, (model_layers_33_self_attn_o_proj_q_weight3, model_layers_33_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv278 = R.call_tir(cls.NT_matmul6, (reshape423, lv531), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv276_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv278, lv275_1, model_layers_33_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv277_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv276_1[1]
            rms_norm213: R.Tensor((1, seq_len, 2048), dtype="float16") = lv276_1[0]
            lv532 = R.call_tir(cls.dequantize3, (model_layers_33_mlp_gate_up_proj_q_weight3, model_layers_33_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv279 = R.call_tir(cls.NT_matmul7, (rms_norm213, lv532), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split105: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv279, indices_or_sections=2, axis=-1)
            split_0105: R.Tensor((1, seq_len, 11008), dtype="float16") = split105[0]
            split_1105: R.Tensor((1, seq_len, 11008), dtype="float16") = split105[1]
            silu105: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0105)
            mul105: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu105, split_1105)
            lv533 = R.call_tir(cls.dequantize4, (model_layers_33_mlp_down_proj_q_weight3, model_layers_33_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv280 = R.call_tir(cls.NT_matmul8, (mul105, lv533), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv278_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv280, lv277_1, model_layers_34_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv279_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv278_1[1]
            rms_norm214: R.Tensor((1, seq_len, 2048), dtype="float16") = lv278_1[0]
            lv534 = R.call_tir(cls.dequantize1, (model_layers_34_self_attn_c_attn_q_weight3, model_layers_34_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv281 = R.call_tir(cls.NT_matmul5, (rms_norm214, lv534), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add318: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv281, model_layers_34_self_attn_c_attn_bias3)
            reshape424: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add318, R.shape([1, seq_len, 20, 128]))
            reshape425: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape424, R.shape([seq_len, 20, 128]))
            lv535 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(34), R.prim_value(T.float32(1.0)), reshape425), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape426: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv535, R.shape([1, seq_len, 16, 128]))
            reshape427: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape426, R.shape([1, seq_len, 2048]))
            lv536 = R.call_tir(cls.dequantize2, (model_layers_34_self_attn_o_proj_q_weight3, model_layers_34_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv282 = R.call_tir(cls.NT_matmul6, (reshape427, lv536), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv280_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv282, lv279_1, model_layers_34_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv281_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv280_1[1]
            rms_norm215: R.Tensor((1, seq_len, 2048), dtype="float16") = lv280_1[0]
            lv537 = R.call_tir(cls.dequantize3, (model_layers_34_mlp_gate_up_proj_q_weight3, model_layers_34_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv283 = R.call_tir(cls.NT_matmul7, (rms_norm215, lv537), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split106: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv283, indices_or_sections=2, axis=-1)
            split_0106: R.Tensor((1, seq_len, 11008), dtype="float16") = split106[0]
            split_1106: R.Tensor((1, seq_len, 11008), dtype="float16") = split106[1]
            silu106: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0106)
            mul106: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu106, split_1106)
            lv538 = R.call_tir(cls.dequantize4, (model_layers_34_mlp_down_proj_q_weight3, model_layers_34_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv284 = R.call_tir(cls.NT_matmul8, (mul106, lv538), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv282_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv284, lv281_1, model_layers_35_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv283_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv282_1[1]
            rms_norm216: R.Tensor((1, seq_len, 2048), dtype="float16") = lv282_1[0]
            lv539 = R.call_tir(cls.dequantize1, (model_layers_35_self_attn_c_attn_q_weight3, model_layers_35_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv285 = R.call_tir(cls.NT_matmul5, (rms_norm216, lv539), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add321: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv285, model_layers_35_self_attn_c_attn_bias3)
            reshape428: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add321, R.shape([1, seq_len, 20, 128]))
            reshape429: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape428, R.shape([seq_len, 20, 128]))
            lv540 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(35), R.prim_value(T.float32(1.0)), reshape429), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape430: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv540, R.shape([1, seq_len, 16, 128]))
            reshape431: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape430, R.shape([1, seq_len, 2048]))
            lv541 = R.call_tir(cls.dequantize2, (model_layers_35_self_attn_o_proj_q_weight3, model_layers_35_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv286 = R.call_tir(cls.NT_matmul6, (reshape431, lv541), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv284_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv286, lv283_1, model_layers_35_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv285_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv284_1[1]
            rms_norm217: R.Tensor((1, seq_len, 2048), dtype="float16") = lv284_1[0]
            lv542 = R.call_tir(cls.dequantize3, (model_layers_35_mlp_gate_up_proj_q_weight3, model_layers_35_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv287 = R.call_tir(cls.NT_matmul7, (rms_norm217, lv542), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split107: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv287, indices_or_sections=2, axis=-1)
            split_0107: R.Tensor((1, seq_len, 11008), dtype="float16") = split107[0]
            split_1107: R.Tensor((1, seq_len, 11008), dtype="float16") = split107[1]
            silu107: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0107)
            mul107: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu107, split_1107)
            lv543 = R.call_tir(cls.dequantize4, (model_layers_35_mlp_down_proj_q_weight3, model_layers_35_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv288 = R.call_tir(cls.NT_matmul8, (mul107, lv543), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv286_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv288, lv285_1, model_norm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            rms_norm218: R.Tensor((1, seq_len, 2048), dtype="float16") = lv286_1[0]
            take1: R.Tensor((1, batch_size, 2048), dtype="float16") = R.take(rms_norm218, logit_positions, axis=1)
            lv544 = R.call_tir(cls.dequantize, (model_embed_tokens_q_weight3, model_embed_tokens_q_scale3), out_sinfo=R.Tensor((151936, 2048), dtype="float16"))
            lv289 = R.call_tir(cls.NT_matmul9, (take1, lv544), out_sinfo=R.Tensor((1, batch_size, 151936), dtype="float32"))
            gv3: R.Tuple(R.Tensor((1, batch_size, 151936), dtype="float32"), R.Object) = lv289, paged_kv_cache
            R.output(gv3)
        return gv3

    @R.function
    def batch_verify(input_embeds: R.Tensor((1, "seq_len", 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), 
R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 
64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), 
R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), 
R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), 
dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"))) -> R.Tuple(R.Tensor((1, "seq_len", 151936), dtype="float32"), R.Object):
        seq_len = T.int64()
        R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "relax.rewrite_cuda_graph.capture_symbolic_vars": ["batch_size", "seq_len"], "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 32768}})
        cls = Module
        with R.dataflow():
            model_embed_tokens_q_weight5: R.Tensor((151936, 256), dtype="uint32") = packed_params[0]
            model_embed_tokens_q_scale5: R.Tensor((151936, 64), dtype="float16") = packed_params[1]
            model_layers_0_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[2]
            model_layers_0_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[3]
            model_layers_0_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[4]
            model_layers_0_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[5]
            model_layers_0_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[6]
            model_layers_0_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[7]
            model_layers_0_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[8]
            model_layers_0_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[9]
            model_layers_0_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[10]
            model_layers_0_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[11]
            model_layers_0_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[12]
            model_layers_1_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[13]
            model_layers_1_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[14]
            model_layers_1_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[15]
            model_layers_1_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[16]
            model_layers_1_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[17]
            model_layers_1_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[18]
            model_layers_1_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[19]
            model_layers_1_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[20]
            model_layers_1_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[21]
            model_layers_1_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[22]
            model_layers_1_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[23]
            model_layers_2_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[24]
            model_layers_2_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[25]
            model_layers_2_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[26]
            model_layers_2_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[27]
            model_layers_2_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[28]
            model_layers_2_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[29]
            model_layers_2_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[30]
            model_layers_2_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[31]
            model_layers_2_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[32]
            model_layers_2_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[33]
            model_layers_2_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[34]
            model_layers_3_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[35]
            model_layers_3_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[36]
            model_layers_3_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[37]
            model_layers_3_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[38]
            model_layers_3_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[39]
            model_layers_3_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[40]
            model_layers_3_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[41]
            model_layers_3_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[42]
            model_layers_3_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[43]
            model_layers_3_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[44]
            model_layers_3_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[45]
            model_layers_4_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[46]
            model_layers_4_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[47]
            model_layers_4_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[48]
            model_layers_4_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[49]
            model_layers_4_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[50]
            model_layers_4_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[51]
            model_layers_4_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[52]
            model_layers_4_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[53]
            model_layers_4_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[54]
            model_layers_4_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[55]
            model_layers_4_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[56]
            model_layers_5_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[57]
            model_layers_5_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[58]
            model_layers_5_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[59]
            model_layers_5_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[60]
            model_layers_5_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[61]
            model_layers_5_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[62]
            model_layers_5_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[63]
            model_layers_5_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[64]
            model_layers_5_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[65]
            model_layers_5_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[66]
            model_layers_5_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[67]
            model_layers_6_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[68]
            model_layers_6_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[69]
            model_layers_6_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[70]
            model_layers_6_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[71]
            model_layers_6_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[72]
            model_layers_6_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[73]
            model_layers_6_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[74]
            model_layers_6_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[75]
            model_layers_6_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[76]
            model_layers_6_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[77]
            model_layers_6_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[78]
            model_layers_7_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[79]
            model_layers_7_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[80]
            model_layers_7_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[81]
            model_layers_7_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[82]
            model_layers_7_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[83]
            model_layers_7_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[84]
            model_layers_7_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[85]
            model_layers_7_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[86]
            model_layers_7_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[87]
            model_layers_7_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[88]
            model_layers_7_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[89]
            model_layers_8_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[90]
            model_layers_8_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[91]
            model_layers_8_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[92]
            model_layers_8_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[93]
            model_layers_8_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[94]
            model_layers_8_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[95]
            model_layers_8_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[96]
            model_layers_8_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[97]
            model_layers_8_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[98]
            model_layers_8_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[99]
            model_layers_8_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[100]
            model_layers_9_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[101]
            model_layers_9_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[102]
            model_layers_9_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[103]
            model_layers_9_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[104]
            model_layers_9_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[105]
            model_layers_9_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[106]
            model_layers_9_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[107]
            model_layers_9_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[108]
            model_layers_9_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[109]
            model_layers_9_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[110]
            model_layers_9_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[111]
            model_layers_10_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[112]
            model_layers_10_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[113]
            model_layers_10_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[114]
            model_layers_10_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[115]
            model_layers_10_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[116]
            model_layers_10_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[117]
            model_layers_10_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[118]
            model_layers_10_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[119]
            model_layers_10_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[120]
            model_layers_10_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[121]
            model_layers_10_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[122]
            model_layers_11_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[123]
            model_layers_11_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[124]
            model_layers_11_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[125]
            model_layers_11_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[126]
            model_layers_11_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[127]
            model_layers_11_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[128]
            model_layers_11_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[129]
            model_layers_11_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[130]
            model_layers_11_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[131]
            model_layers_11_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[132]
            model_layers_11_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[133]
            model_layers_12_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[134]
            model_layers_12_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[135]
            model_layers_12_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[136]
            model_layers_12_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[137]
            model_layers_12_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[138]
            model_layers_12_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[139]
            model_layers_12_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[140]
            model_layers_12_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[141]
            model_layers_12_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[142]
            model_layers_12_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[143]
            model_layers_12_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[144]
            model_layers_13_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[145]
            model_layers_13_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[146]
            model_layers_13_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[147]
            model_layers_13_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[148]
            model_layers_13_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[149]
            model_layers_13_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[150]
            model_layers_13_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[151]
            model_layers_13_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[152]
            model_layers_13_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[153]
            model_layers_13_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[154]
            model_layers_13_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[155]
            model_layers_14_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[156]
            model_layers_14_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[157]
            model_layers_14_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[158]
            model_layers_14_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[159]
            model_layers_14_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[160]
            model_layers_14_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[161]
            model_layers_14_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[162]
            model_layers_14_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[163]
            model_layers_14_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[164]
            model_layers_14_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[165]
            model_layers_14_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[166]
            model_layers_15_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[167]
            model_layers_15_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[168]
            model_layers_15_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[169]
            model_layers_15_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[170]
            model_layers_15_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[171]
            model_layers_15_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[172]
            model_layers_15_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[173]
            model_layers_15_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[174]
            model_layers_15_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[175]
            model_layers_15_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[176]
            model_layers_15_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[177]
            model_layers_16_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[178]
            model_layers_16_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[179]
            model_layers_16_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[180]
            model_layers_16_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[181]
            model_layers_16_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[182]
            model_layers_16_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[183]
            model_layers_16_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[184]
            model_layers_16_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[185]
            model_layers_16_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[186]
            model_layers_16_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[187]
            model_layers_16_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[188]
            model_layers_17_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[189]
            model_layers_17_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[190]
            model_layers_17_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[191]
            model_layers_17_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[192]
            model_layers_17_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[193]
            model_layers_17_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[194]
            model_layers_17_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[195]
            model_layers_17_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[196]
            model_layers_17_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[197]
            model_layers_17_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[198]
            model_layers_17_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[199]
            model_layers_18_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[200]
            model_layers_18_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[201]
            model_layers_18_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[202]
            model_layers_18_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[203]
            model_layers_18_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[204]
            model_layers_18_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[205]
            model_layers_18_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[206]
            model_layers_18_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[207]
            model_layers_18_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[208]
            model_layers_18_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[209]
            model_layers_18_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[210]
            model_layers_19_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[211]
            model_layers_19_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[212]
            model_layers_19_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[213]
            model_layers_19_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[214]
            model_layers_19_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[215]
            model_layers_19_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[216]
            model_layers_19_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[217]
            model_layers_19_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[218]
            model_layers_19_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[219]
            model_layers_19_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[220]
            model_layers_19_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[221]
            model_layers_20_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[222]
            model_layers_20_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[223]
            model_layers_20_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[224]
            model_layers_20_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[225]
            model_layers_20_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[226]
            model_layers_20_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[227]
            model_layers_20_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[228]
            model_layers_20_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[229]
            model_layers_20_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[230]
            model_layers_20_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[231]
            model_layers_20_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[232]
            model_layers_21_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[233]
            model_layers_21_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[234]
            model_layers_21_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[235]
            model_layers_21_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[236]
            model_layers_21_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[237]
            model_layers_21_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[238]
            model_layers_21_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[239]
            model_layers_21_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[240]
            model_layers_21_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[241]
            model_layers_21_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[242]
            model_layers_21_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[243]
            model_layers_22_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[244]
            model_layers_22_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[245]
            model_layers_22_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[246]
            model_layers_22_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[247]
            model_layers_22_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[248]
            model_layers_22_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[249]
            model_layers_22_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[250]
            model_layers_22_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[251]
            model_layers_22_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[252]
            model_layers_22_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[253]
            model_layers_22_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[254]
            model_layers_23_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[255]
            model_layers_23_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[256]
            model_layers_23_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[257]
            model_layers_23_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[258]
            model_layers_23_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[259]
            model_layers_23_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[260]
            model_layers_23_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[261]
            model_layers_23_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[262]
            model_layers_23_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[263]
            model_layers_23_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[264]
            model_layers_23_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[265]
            model_layers_24_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[266]
            model_layers_24_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[267]
            model_layers_24_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[268]
            model_layers_24_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[269]
            model_layers_24_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[270]
            model_layers_24_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[271]
            model_layers_24_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[272]
            model_layers_24_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[273]
            model_layers_24_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[274]
            model_layers_24_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[275]
            model_layers_24_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[276]
            model_layers_25_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[277]
            model_layers_25_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[278]
            model_layers_25_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[279]
            model_layers_25_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[280]
            model_layers_25_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[281]
            model_layers_25_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[282]
            model_layers_25_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[283]
            model_layers_25_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[284]
            model_layers_25_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[285]
            model_layers_25_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[286]
            model_layers_25_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[287]
            model_layers_26_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[288]
            model_layers_26_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[289]
            model_layers_26_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[290]
            model_layers_26_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[291]
            model_layers_26_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[292]
            model_layers_26_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[293]
            model_layers_26_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[294]
            model_layers_26_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[295]
            model_layers_26_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[296]
            model_layers_26_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[297]
            model_layers_26_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[298]
            model_layers_27_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[299]
            model_layers_27_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[300]
            model_layers_27_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[301]
            model_layers_27_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[302]
            model_layers_27_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[303]
            model_layers_27_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[304]
            model_layers_27_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[305]
            model_layers_27_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[306]
            model_layers_27_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[307]
            model_layers_27_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[308]
            model_layers_27_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[309]
            model_layers_28_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[310]
            model_layers_28_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[311]
            model_layers_28_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[312]
            model_layers_28_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[313]
            model_layers_28_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[314]
            model_layers_28_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[315]
            model_layers_28_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[316]
            model_layers_28_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[317]
            model_layers_28_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[318]
            model_layers_28_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[319]
            model_layers_28_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[320]
            model_layers_29_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[321]
            model_layers_29_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[322]
            model_layers_29_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[323]
            model_layers_29_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[324]
            model_layers_29_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[325]
            model_layers_29_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[326]
            model_layers_29_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[327]
            model_layers_29_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[328]
            model_layers_29_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[329]
            model_layers_29_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[330]
            model_layers_29_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[331]
            model_layers_30_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[332]
            model_layers_30_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[333]
            model_layers_30_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[334]
            model_layers_30_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[335]
            model_layers_30_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[336]
            model_layers_30_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[337]
            model_layers_30_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[338]
            model_layers_30_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[339]
            model_layers_30_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[340]
            model_layers_30_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[341]
            model_layers_30_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[342]
            model_layers_31_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[343]
            model_layers_31_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[344]
            model_layers_31_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[345]
            model_layers_31_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[346]
            model_layers_31_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[347]
            model_layers_31_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[348]
            model_layers_31_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[349]
            model_layers_31_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[350]
            model_layers_31_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[351]
            model_layers_31_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[352]
            model_layers_31_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[353]
            model_layers_32_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[354]
            model_layers_32_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[355]
            model_layers_32_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[356]
            model_layers_32_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[357]
            model_layers_32_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[358]
            model_layers_32_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[359]
            model_layers_32_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[360]
            model_layers_32_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[361]
            model_layers_32_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[362]
            model_layers_32_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[363]
            model_layers_32_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[364]
            model_layers_33_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[365]
            model_layers_33_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[366]
            model_layers_33_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[367]
            model_layers_33_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[368]
            model_layers_33_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[369]
            model_layers_33_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[370]
            model_layers_33_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[371]
            model_layers_33_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[372]
            model_layers_33_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[373]
            model_layers_33_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[374]
            model_layers_33_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[375]
            model_layers_34_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[376]
            model_layers_34_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[377]
            model_layers_34_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[378]
            model_layers_34_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[379]
            model_layers_34_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[380]
            model_layers_34_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[381]
            model_layers_34_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[382]
            model_layers_34_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[383]
            model_layers_34_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[384]
            model_layers_34_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[385]
            model_layers_34_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[386]
            model_layers_35_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[387]
            model_layers_35_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[388]
            model_layers_35_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[389]
            model_layers_35_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[390]
            model_layers_35_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[391]
            model_layers_35_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[392]
            model_layers_35_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[393]
            model_layers_35_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[394]
            model_layers_35_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[395]
            model_layers_35_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[396]
            model_layers_35_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[397]
            model_norm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[398]
            rms_norm292: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(input_embeds, model_layers_0_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv726 = R.call_tir(cls.dequantize1, (model_layers_0_self_attn_c_attn_q_weight5, model_layers_0_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv290 = R.call_tir(cls.NT_matmul5, (rms_norm292, lv726), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add432: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv290, model_layers_0_self_attn_c_attn_bias5)
            reshape576: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add432, R.shape([1, seq_len, 20, 128]))
            reshape577: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape576, R.shape([seq_len, 20, 128]))
            lv727 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape577), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape578: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv727, R.shape([1, seq_len, 16, 128]))
            reshape579: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape578, R.shape([1, seq_len, 2048]))
            lv728 = R.call_tir(cls.dequantize2, (model_layers_0_self_attn_o_proj_q_weight5, model_layers_0_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv291 = R.call_tir(cls.NT_matmul6, (reshape579, lv728), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv288 = R.call_tir(cls.fuse_add_norm_prefill, (lv291, input_embeds, model_layers_0_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv289: R.Tensor((1, seq_len, 2048), dtype="float16") = lv288[1]
            rms_norm293: R.Tensor((1, seq_len, 2048), dtype="float16") = lv288[0]
            lv729 = R.call_tir(cls.dequantize3, (model_layers_0_mlp_gate_up_proj_q_weight5, model_layers_0_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv292 = R.call_tir(cls.NT_matmul7, (rms_norm293, lv729), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split144: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv292, indices_or_sections=2, axis=-1)
            split_0144: R.Tensor((1, seq_len, 11008), dtype="float16") = split144[0]
            split_1144: R.Tensor((1, seq_len, 11008), dtype="float16") = split144[1]
            silu144: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0144)
            mul144: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu144, split_1144)
            lv730 = R.call_tir(cls.dequantize4, (model_layers_0_mlp_down_proj_q_weight5, model_layers_0_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv293 = R.call_tir(cls.NT_matmul8, (mul144, lv730), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv290_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv293, lv289, model_layers_1_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv291_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv290_1[1]
            rms_norm294: R.Tensor((1, seq_len, 2048), dtype="float16") = lv290_1[0]
            lv731 = R.call_tir(cls.dequantize1, (model_layers_1_self_attn_c_attn_q_weight5, model_layers_1_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv294 = R.call_tir(cls.NT_matmul5, (rms_norm294, lv731), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add435: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv294, model_layers_1_self_attn_c_attn_bias5)
            reshape580: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add435, R.shape([1, seq_len, 20, 128]))
            reshape581: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape580, R.shape([seq_len, 20, 128]))
            lv732 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape581), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape582: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv732, R.shape([1, seq_len, 16, 128]))
            reshape583: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape582, R.shape([1, seq_len, 2048]))
            lv733 = R.call_tir(cls.dequantize2, (model_layers_1_self_attn_o_proj_q_weight5, model_layers_1_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv295 = R.call_tir(cls.NT_matmul6, (reshape583, lv733), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv292_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv295, lv291_1, model_layers_1_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv293_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv292_1[1]
            rms_norm295: R.Tensor((1, seq_len, 2048), dtype="float16") = lv292_1[0]
            lv734 = R.call_tir(cls.dequantize3, (model_layers_1_mlp_gate_up_proj_q_weight5, model_layers_1_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv296 = R.call_tir(cls.NT_matmul7, (rms_norm295, lv734), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split145: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv296, indices_or_sections=2, axis=-1)
            split_0145: R.Tensor((1, seq_len, 11008), dtype="float16") = split145[0]
            split_1145: R.Tensor((1, seq_len, 11008), dtype="float16") = split145[1]
            silu145: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0145)
            mul145: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu145, split_1145)
            lv735 = R.call_tir(cls.dequantize4, (model_layers_1_mlp_down_proj_q_weight5, model_layers_1_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv297 = R.call_tir(cls.NT_matmul8, (mul145, lv735), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv294_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv297, lv293_1, model_layers_2_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv295_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv294_1[1]
            rms_norm296: R.Tensor((1, seq_len, 2048), dtype="float16") = lv294_1[0]
            lv736 = R.call_tir(cls.dequantize1, (model_layers_2_self_attn_c_attn_q_weight5, model_layers_2_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv298 = R.call_tir(cls.NT_matmul5, (rms_norm296, lv736), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add438: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv298, model_layers_2_self_attn_c_attn_bias5)
            reshape584: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add438, R.shape([1, seq_len, 20, 128]))
            reshape585: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape584, R.shape([seq_len, 20, 128]))
            lv737 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape585), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape586: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv737, R.shape([1, seq_len, 16, 128]))
            reshape587: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape586, R.shape([1, seq_len, 2048]))
            lv738 = R.call_tir(cls.dequantize2, (model_layers_2_self_attn_o_proj_q_weight5, model_layers_2_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv299 = R.call_tir(cls.NT_matmul6, (reshape587, lv738), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv296_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv299, lv295_1, model_layers_2_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv297_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv296_1[1]
            rms_norm297: R.Tensor((1, seq_len, 2048), dtype="float16") = lv296_1[0]
            lv739 = R.call_tir(cls.dequantize3, (model_layers_2_mlp_gate_up_proj_q_weight5, model_layers_2_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv300 = R.call_tir(cls.NT_matmul7, (rms_norm297, lv739), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split146: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv300, indices_or_sections=2, axis=-1)
            split_0146: R.Tensor((1, seq_len, 11008), dtype="float16") = split146[0]
            split_1146: R.Tensor((1, seq_len, 11008), dtype="float16") = split146[1]
            silu146: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0146)
            mul146: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu146, split_1146)
            lv740 = R.call_tir(cls.dequantize4, (model_layers_2_mlp_down_proj_q_weight5, model_layers_2_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv301 = R.call_tir(cls.NT_matmul8, (mul146, lv740), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv298_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv301, lv297_1, model_layers_3_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv299_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv298_1[1]
            rms_norm298: R.Tensor((1, seq_len, 2048), dtype="float16") = lv298_1[0]
            lv741 = R.call_tir(cls.dequantize1, (model_layers_3_self_attn_c_attn_q_weight5, model_layers_3_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv302 = R.call_tir(cls.NT_matmul5, (rms_norm298, lv741), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add441: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv302, model_layers_3_self_attn_c_attn_bias5)
            reshape588: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add441, R.shape([1, seq_len, 20, 128]))
            reshape589: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape588, R.shape([seq_len, 20, 128]))
            lv742 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape589), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape590: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv742, R.shape([1, seq_len, 16, 128]))
            reshape591: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape590, R.shape([1, seq_len, 2048]))
            lv743 = R.call_tir(cls.dequantize2, (model_layers_3_self_attn_o_proj_q_weight5, model_layers_3_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv303 = R.call_tir(cls.NT_matmul6, (reshape591, lv743), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv300_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv303, lv299_1, model_layers_3_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv301_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv300_1[1]
            rms_norm299: R.Tensor((1, seq_len, 2048), dtype="float16") = lv300_1[0]
            lv744 = R.call_tir(cls.dequantize3, (model_layers_3_mlp_gate_up_proj_q_weight5, model_layers_3_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv304 = R.call_tir(cls.NT_matmul7, (rms_norm299, lv744), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split147: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv304, indices_or_sections=2, axis=-1)
            split_0147: R.Tensor((1, seq_len, 11008), dtype="float16") = split147[0]
            split_1147: R.Tensor((1, seq_len, 11008), dtype="float16") = split147[1]
            silu147: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0147)
            mul147: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu147, split_1147)
            lv745 = R.call_tir(cls.dequantize4, (model_layers_3_mlp_down_proj_q_weight5, model_layers_3_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv305 = R.call_tir(cls.NT_matmul8, (mul147, lv745), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv302_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv305, lv301_1, model_layers_4_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv303_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv302_1[1]
            rms_norm300: R.Tensor((1, seq_len, 2048), dtype="float16") = lv302_1[0]
            lv746 = R.call_tir(cls.dequantize1, (model_layers_4_self_attn_c_attn_q_weight5, model_layers_4_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv306 = R.call_tir(cls.NT_matmul5, (rms_norm300, lv746), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add444: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv306, model_layers_4_self_attn_c_attn_bias5)
            reshape592: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add444, R.shape([1, seq_len, 20, 128]))
            reshape593: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape592, R.shape([seq_len, 20, 128]))
            lv747 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape593), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape594: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv747, R.shape([1, seq_len, 16, 128]))
            reshape595: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape594, R.shape([1, seq_len, 2048]))
            lv748 = R.call_tir(cls.dequantize2, (model_layers_4_self_attn_o_proj_q_weight5, model_layers_4_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv307 = R.call_tir(cls.NT_matmul6, (reshape595, lv748), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv304_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv307, lv303_1, model_layers_4_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv305_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv304_1[1]
            rms_norm301: R.Tensor((1, seq_len, 2048), dtype="float16") = lv304_1[0]
            lv749 = R.call_tir(cls.dequantize3, (model_layers_4_mlp_gate_up_proj_q_weight5, model_layers_4_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv308 = R.call_tir(cls.NT_matmul7, (rms_norm301, lv749), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split148: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv308, indices_or_sections=2, axis=-1)
            split_0148: R.Tensor((1, seq_len, 11008), dtype="float16") = split148[0]
            split_1148: R.Tensor((1, seq_len, 11008), dtype="float16") = split148[1]
            silu148: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0148)
            mul148: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu148, split_1148)
            lv750 = R.call_tir(cls.dequantize4, (model_layers_4_mlp_down_proj_q_weight5, model_layers_4_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv309 = R.call_tir(cls.NT_matmul8, (mul148, lv750), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv306_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv309, lv305_1, model_layers_5_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv307_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv306_1[1]
            rms_norm302: R.Tensor((1, seq_len, 2048), dtype="float16") = lv306_1[0]
            lv751 = R.call_tir(cls.dequantize1, (model_layers_5_self_attn_c_attn_q_weight5, model_layers_5_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv310 = R.call_tir(cls.NT_matmul5, (rms_norm302, lv751), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add447: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv310, model_layers_5_self_attn_c_attn_bias5)
            reshape596: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add447, R.shape([1, seq_len, 20, 128]))
            reshape597: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape596, R.shape([seq_len, 20, 128]))
            lv752 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape597), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape598: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv752, R.shape([1, seq_len, 16, 128]))
            reshape599: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape598, R.shape([1, seq_len, 2048]))
            lv753 = R.call_tir(cls.dequantize2, (model_layers_5_self_attn_o_proj_q_weight5, model_layers_5_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv311 = R.call_tir(cls.NT_matmul6, (reshape599, lv753), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv308_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv311, lv307_1, model_layers_5_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv309_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv308_1[1]
            rms_norm303: R.Tensor((1, seq_len, 2048), dtype="float16") = lv308_1[0]
            lv754 = R.call_tir(cls.dequantize3, (model_layers_5_mlp_gate_up_proj_q_weight5, model_layers_5_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv312 = R.call_tir(cls.NT_matmul7, (rms_norm303, lv754), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split149: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv312, indices_or_sections=2, axis=-1)
            split_0149: R.Tensor((1, seq_len, 11008), dtype="float16") = split149[0]
            split_1149: R.Tensor((1, seq_len, 11008), dtype="float16") = split149[1]
            silu149: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0149)
            mul149: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu149, split_1149)
            lv755 = R.call_tir(cls.dequantize4, (model_layers_5_mlp_down_proj_q_weight5, model_layers_5_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv313 = R.call_tir(cls.NT_matmul8, (mul149, lv755), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv310_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv313, lv309_1, model_layers_6_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv311_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv310_1[1]
            rms_norm304: R.Tensor((1, seq_len, 2048), dtype="float16") = lv310_1[0]
            lv756 = R.call_tir(cls.dequantize1, (model_layers_6_self_attn_c_attn_q_weight5, model_layers_6_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv314 = R.call_tir(cls.NT_matmul5, (rms_norm304, lv756), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add450: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv314, model_layers_6_self_attn_c_attn_bias5)
            reshape600: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add450, R.shape([1, seq_len, 20, 128]))
            reshape601: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape600, R.shape([seq_len, 20, 128]))
            lv757 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape601), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape602: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv757, R.shape([1, seq_len, 16, 128]))
            reshape603: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape602, R.shape([1, seq_len, 2048]))
            lv758 = R.call_tir(cls.dequantize2, (model_layers_6_self_attn_o_proj_q_weight5, model_layers_6_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv315 = R.call_tir(cls.NT_matmul6, (reshape603, lv758), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv312_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv315, lv311_1, model_layers_6_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv313_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv312_1[1]
            rms_norm305: R.Tensor((1, seq_len, 2048), dtype="float16") = lv312_1[0]
            lv759 = R.call_tir(cls.dequantize3, (model_layers_6_mlp_gate_up_proj_q_weight5, model_layers_6_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv316 = R.call_tir(cls.NT_matmul7, (rms_norm305, lv759), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split150: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv316, indices_or_sections=2, axis=-1)
            split_0150: R.Tensor((1, seq_len, 11008), dtype="float16") = split150[0]
            split_1150: R.Tensor((1, seq_len, 11008), dtype="float16") = split150[1]
            silu150: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0150)
            mul150: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu150, split_1150)
            lv760 = R.call_tir(cls.dequantize4, (model_layers_6_mlp_down_proj_q_weight5, model_layers_6_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv317 = R.call_tir(cls.NT_matmul8, (mul150, lv760), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv314_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv317, lv313_1, model_layers_7_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv315_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv314_1[1]
            rms_norm306: R.Tensor((1, seq_len, 2048), dtype="float16") = lv314_1[0]
            lv761 = R.call_tir(cls.dequantize1, (model_layers_7_self_attn_c_attn_q_weight5, model_layers_7_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv318 = R.call_tir(cls.NT_matmul5, (rms_norm306, lv761), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add453: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv318, model_layers_7_self_attn_c_attn_bias5)
            reshape604: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add453, R.shape([1, seq_len, 20, 128]))
            reshape605: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape604, R.shape([seq_len, 20, 128]))
            lv762 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape605), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape606: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv762, R.shape([1, seq_len, 16, 128]))
            reshape607: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape606, R.shape([1, seq_len, 2048]))
            lv763 = R.call_tir(cls.dequantize2, (model_layers_7_self_attn_o_proj_q_weight5, model_layers_7_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv319 = R.call_tir(cls.NT_matmul6, (reshape607, lv763), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv316_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv319, lv315_1, model_layers_7_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv317_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv316_1[1]
            rms_norm307: R.Tensor((1, seq_len, 2048), dtype="float16") = lv316_1[0]
            lv764 = R.call_tir(cls.dequantize3, (model_layers_7_mlp_gate_up_proj_q_weight5, model_layers_7_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv320 = R.call_tir(cls.NT_matmul7, (rms_norm307, lv764), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split151: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv320, indices_or_sections=2, axis=-1)
            split_0151: R.Tensor((1, seq_len, 11008), dtype="float16") = split151[0]
            split_1151: R.Tensor((1, seq_len, 11008), dtype="float16") = split151[1]
            silu151: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0151)
            mul151: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu151, split_1151)
            lv765 = R.call_tir(cls.dequantize4, (model_layers_7_mlp_down_proj_q_weight5, model_layers_7_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv321 = R.call_tir(cls.NT_matmul8, (mul151, lv765), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv318_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv321, lv317_1, model_layers_8_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv319_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv318_1[1]
            rms_norm308: R.Tensor((1, seq_len, 2048), dtype="float16") = lv318_1[0]
            lv766 = R.call_tir(cls.dequantize1, (model_layers_8_self_attn_c_attn_q_weight5, model_layers_8_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv322 = R.call_tir(cls.NT_matmul5, (rms_norm308, lv766), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add456: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv322, model_layers_8_self_attn_c_attn_bias5)
            reshape608: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add456, R.shape([1, seq_len, 20, 128]))
            reshape609: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape608, R.shape([seq_len, 20, 128]))
            lv767 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape609), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape610: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv767, R.shape([1, seq_len, 16, 128]))
            reshape611: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape610, R.shape([1, seq_len, 2048]))
            lv768 = R.call_tir(cls.dequantize2, (model_layers_8_self_attn_o_proj_q_weight5, model_layers_8_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv323 = R.call_tir(cls.NT_matmul6, (reshape611, lv768), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv320_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv323, lv319_1, model_layers_8_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv321_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv320_1[1]
            rms_norm309: R.Tensor((1, seq_len, 2048), dtype="float16") = lv320_1[0]
            lv769 = R.call_tir(cls.dequantize3, (model_layers_8_mlp_gate_up_proj_q_weight5, model_layers_8_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv324 = R.call_tir(cls.NT_matmul7, (rms_norm309, lv769), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split152: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv324, indices_or_sections=2, axis=-1)
            split_0152: R.Tensor((1, seq_len, 11008), dtype="float16") = split152[0]
            split_1152: R.Tensor((1, seq_len, 11008), dtype="float16") = split152[1]
            silu152: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0152)
            mul152: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu152, split_1152)
            lv770 = R.call_tir(cls.dequantize4, (model_layers_8_mlp_down_proj_q_weight5, model_layers_8_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv325 = R.call_tir(cls.NT_matmul8, (mul152, lv770), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv322_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv325, lv321_1, model_layers_9_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv323_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv322_1[1]
            rms_norm310: R.Tensor((1, seq_len, 2048), dtype="float16") = lv322_1[0]
            lv771 = R.call_tir(cls.dequantize1, (model_layers_9_self_attn_c_attn_q_weight5, model_layers_9_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv326 = R.call_tir(cls.NT_matmul5, (rms_norm310, lv771), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add459: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv326, model_layers_9_self_attn_c_attn_bias5)
            reshape612: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add459, R.shape([1, seq_len, 20, 128]))
            reshape613: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape612, R.shape([seq_len, 20, 128]))
            lv772 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape613), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape614: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv772, R.shape([1, seq_len, 16, 128]))
            reshape615: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape614, R.shape([1, seq_len, 2048]))
            lv773 = R.call_tir(cls.dequantize2, (model_layers_9_self_attn_o_proj_q_weight5, model_layers_9_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv327 = R.call_tir(cls.NT_matmul6, (reshape615, lv773), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv324_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv327, lv323_1, model_layers_9_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv325_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv324_1[1]
            rms_norm311: R.Tensor((1, seq_len, 2048), dtype="float16") = lv324_1[0]
            lv774 = R.call_tir(cls.dequantize3, (model_layers_9_mlp_gate_up_proj_q_weight5, model_layers_9_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv328 = R.call_tir(cls.NT_matmul7, (rms_norm311, lv774), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split153: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv328, indices_or_sections=2, axis=-1)
            split_0153: R.Tensor((1, seq_len, 11008), dtype="float16") = split153[0]
            split_1153: R.Tensor((1, seq_len, 11008), dtype="float16") = split153[1]
            silu153: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0153)
            mul153: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu153, split_1153)
            lv775 = R.call_tir(cls.dequantize4, (model_layers_9_mlp_down_proj_q_weight5, model_layers_9_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv329 = R.call_tir(cls.NT_matmul8, (mul153, lv775), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv326_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv329, lv325_1, model_layers_10_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv327_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv326_1[1]
            rms_norm312: R.Tensor((1, seq_len, 2048), dtype="float16") = lv326_1[0]
            lv776 = R.call_tir(cls.dequantize1, (model_layers_10_self_attn_c_attn_q_weight5, model_layers_10_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv330 = R.call_tir(cls.NT_matmul5, (rms_norm312, lv776), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add462: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv330, model_layers_10_self_attn_c_attn_bias5)
            reshape616: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add462, R.shape([1, seq_len, 20, 128]))
            reshape617: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape616, R.shape([seq_len, 20, 128]))
            lv777 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape617), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape618: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv777, R.shape([1, seq_len, 16, 128]))
            reshape619: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape618, R.shape([1, seq_len, 2048]))
            lv778 = R.call_tir(cls.dequantize2, (model_layers_10_self_attn_o_proj_q_weight5, model_layers_10_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv331 = R.call_tir(cls.NT_matmul6, (reshape619, lv778), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv328_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv331, lv327_1, model_layers_10_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv329_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv328_1[1]
            rms_norm313: R.Tensor((1, seq_len, 2048), dtype="float16") = lv328_1[0]
            lv779 = R.call_tir(cls.dequantize3, (model_layers_10_mlp_gate_up_proj_q_weight5, model_layers_10_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv332 = R.call_tir(cls.NT_matmul7, (rms_norm313, lv779), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split154: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv332, indices_or_sections=2, axis=-1)
            split_0154: R.Tensor((1, seq_len, 11008), dtype="float16") = split154[0]
            split_1154: R.Tensor((1, seq_len, 11008), dtype="float16") = split154[1]
            silu154: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0154)
            mul154: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu154, split_1154)
            lv780 = R.call_tir(cls.dequantize4, (model_layers_10_mlp_down_proj_q_weight5, model_layers_10_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv333 = R.call_tir(cls.NT_matmul8, (mul154, lv780), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv330_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv333, lv329_1, model_layers_11_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv331_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv330_1[1]
            rms_norm314: R.Tensor((1, seq_len, 2048), dtype="float16") = lv330_1[0]
            lv781 = R.call_tir(cls.dequantize1, (model_layers_11_self_attn_c_attn_q_weight5, model_layers_11_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv334 = R.call_tir(cls.NT_matmul5, (rms_norm314, lv781), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add465: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv334, model_layers_11_self_attn_c_attn_bias5)
            reshape620: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add465, R.shape([1, seq_len, 20, 128]))
            reshape621: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape620, R.shape([seq_len, 20, 128]))
            lv782 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape621), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape622: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv782, R.shape([1, seq_len, 16, 128]))
            reshape623: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape622, R.shape([1, seq_len, 2048]))
            lv783 = R.call_tir(cls.dequantize2, (model_layers_11_self_attn_o_proj_q_weight5, model_layers_11_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv335 = R.call_tir(cls.NT_matmul6, (reshape623, lv783), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv332_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv335, lv331_1, model_layers_11_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv333_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv332_1[1]
            rms_norm315: R.Tensor((1, seq_len, 2048), dtype="float16") = lv332_1[0]
            lv784 = R.call_tir(cls.dequantize3, (model_layers_11_mlp_gate_up_proj_q_weight5, model_layers_11_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv336 = R.call_tir(cls.NT_matmul7, (rms_norm315, lv784), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split155: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv336, indices_or_sections=2, axis=-1)
            split_0155: R.Tensor((1, seq_len, 11008), dtype="float16") = split155[0]
            split_1155: R.Tensor((1, seq_len, 11008), dtype="float16") = split155[1]
            silu155: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0155)
            mul155: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu155, split_1155)
            lv785 = R.call_tir(cls.dequantize4, (model_layers_11_mlp_down_proj_q_weight5, model_layers_11_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv337 = R.call_tir(cls.NT_matmul8, (mul155, lv785), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv334_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv337, lv333_1, model_layers_12_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv335_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv334_1[1]
            rms_norm316: R.Tensor((1, seq_len, 2048), dtype="float16") = lv334_1[0]
            lv786 = R.call_tir(cls.dequantize1, (model_layers_12_self_attn_c_attn_q_weight5, model_layers_12_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv338 = R.call_tir(cls.NT_matmul5, (rms_norm316, lv786), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add468: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv338, model_layers_12_self_attn_c_attn_bias5)
            reshape624: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add468, R.shape([1, seq_len, 20, 128]))
            reshape625: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape624, R.shape([seq_len, 20, 128]))
            lv787 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape625), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape626: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv787, R.shape([1, seq_len, 16, 128]))
            reshape627: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape626, R.shape([1, seq_len, 2048]))
            lv788 = R.call_tir(cls.dequantize2, (model_layers_12_self_attn_o_proj_q_weight5, model_layers_12_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv339 = R.call_tir(cls.NT_matmul6, (reshape627, lv788), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv336_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv339, lv335_1, model_layers_12_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv337_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv336_1[1]
            rms_norm317: R.Tensor((1, seq_len, 2048), dtype="float16") = lv336_1[0]
            lv789 = R.call_tir(cls.dequantize3, (model_layers_12_mlp_gate_up_proj_q_weight5, model_layers_12_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv340 = R.call_tir(cls.NT_matmul7, (rms_norm317, lv789), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split156: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv340, indices_or_sections=2, axis=-1)
            split_0156: R.Tensor((1, seq_len, 11008), dtype="float16") = split156[0]
            split_1156: R.Tensor((1, seq_len, 11008), dtype="float16") = split156[1]
            silu156: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0156)
            mul156: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu156, split_1156)
            lv790 = R.call_tir(cls.dequantize4, (model_layers_12_mlp_down_proj_q_weight5, model_layers_12_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv341 = R.call_tir(cls.NT_matmul8, (mul156, lv790), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv338_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv341, lv337_1, model_layers_13_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv339_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv338_1[1]
            rms_norm318: R.Tensor((1, seq_len, 2048), dtype="float16") = lv338_1[0]
            lv791 = R.call_tir(cls.dequantize1, (model_layers_13_self_attn_c_attn_q_weight5, model_layers_13_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv342 = R.call_tir(cls.NT_matmul5, (rms_norm318, lv791), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add471: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv342, model_layers_13_self_attn_c_attn_bias5)
            reshape628: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add471, R.shape([1, seq_len, 20, 128]))
            reshape629: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape628, R.shape([seq_len, 20, 128]))
            lv792 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape629), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape630: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv792, R.shape([1, seq_len, 16, 128]))
            reshape631: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape630, R.shape([1, seq_len, 2048]))
            lv793 = R.call_tir(cls.dequantize2, (model_layers_13_self_attn_o_proj_q_weight5, model_layers_13_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv343 = R.call_tir(cls.NT_matmul6, (reshape631, lv793), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv340_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv343, lv339_1, model_layers_13_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv341_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv340_1[1]
            rms_norm319: R.Tensor((1, seq_len, 2048), dtype="float16") = lv340_1[0]
            lv794 = R.call_tir(cls.dequantize3, (model_layers_13_mlp_gate_up_proj_q_weight5, model_layers_13_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv344 = R.call_tir(cls.NT_matmul7, (rms_norm319, lv794), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split157: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv344, indices_or_sections=2, axis=-1)
            split_0157: R.Tensor((1, seq_len, 11008), dtype="float16") = split157[0]
            split_1157: R.Tensor((1, seq_len, 11008), dtype="float16") = split157[1]
            silu157: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0157)
            mul157: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu157, split_1157)
            lv795 = R.call_tir(cls.dequantize4, (model_layers_13_mlp_down_proj_q_weight5, model_layers_13_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv345 = R.call_tir(cls.NT_matmul8, (mul157, lv795), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv342_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv345, lv341_1, model_layers_14_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv343_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv342_1[1]
            rms_norm320: R.Tensor((1, seq_len, 2048), dtype="float16") = lv342_1[0]
            lv796 = R.call_tir(cls.dequantize1, (model_layers_14_self_attn_c_attn_q_weight5, model_layers_14_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv346 = R.call_tir(cls.NT_matmul5, (rms_norm320, lv796), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add474: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv346, model_layers_14_self_attn_c_attn_bias5)
            reshape632: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add474, R.shape([1, seq_len, 20, 128]))
            reshape633: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape632, R.shape([seq_len, 20, 128]))
            lv797 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape633), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape634: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv797, R.shape([1, seq_len, 16, 128]))
            reshape635: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape634, R.shape([1, seq_len, 2048]))
            lv798 = R.call_tir(cls.dequantize2, (model_layers_14_self_attn_o_proj_q_weight5, model_layers_14_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv347 = R.call_tir(cls.NT_matmul6, (reshape635, lv798), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv344_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv347, lv343_1, model_layers_14_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv345_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv344_1[1]
            rms_norm321: R.Tensor((1, seq_len, 2048), dtype="float16") = lv344_1[0]
            lv799 = R.call_tir(cls.dequantize3, (model_layers_14_mlp_gate_up_proj_q_weight5, model_layers_14_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv348 = R.call_tir(cls.NT_matmul7, (rms_norm321, lv799), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split158: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv348, indices_or_sections=2, axis=-1)
            split_0158: R.Tensor((1, seq_len, 11008), dtype="float16") = split158[0]
            split_1158: R.Tensor((1, seq_len, 11008), dtype="float16") = split158[1]
            silu158: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0158)
            mul158: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu158, split_1158)
            lv800 = R.call_tir(cls.dequantize4, (model_layers_14_mlp_down_proj_q_weight5, model_layers_14_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv349 = R.call_tir(cls.NT_matmul8, (mul158, lv800), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv346_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv349, lv345_1, model_layers_15_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv347_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv346_1[1]
            rms_norm322: R.Tensor((1, seq_len, 2048), dtype="float16") = lv346_1[0]
            lv801 = R.call_tir(cls.dequantize1, (model_layers_15_self_attn_c_attn_q_weight5, model_layers_15_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv350 = R.call_tir(cls.NT_matmul5, (rms_norm322, lv801), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add477: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv350, model_layers_15_self_attn_c_attn_bias5)
            reshape636: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add477, R.shape([1, seq_len, 20, 128]))
            reshape637: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape636, R.shape([seq_len, 20, 128]))
            lv802 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape637), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape638: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv802, R.shape([1, seq_len, 16, 128]))
            reshape639: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape638, R.shape([1, seq_len, 2048]))
            lv803 = R.call_tir(cls.dequantize2, (model_layers_15_self_attn_o_proj_q_weight5, model_layers_15_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv351 = R.call_tir(cls.NT_matmul6, (reshape639, lv803), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv348_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv351, lv347_1, model_layers_15_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv349_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv348_1[1]
            rms_norm323: R.Tensor((1, seq_len, 2048), dtype="float16") = lv348_1[0]
            lv804 = R.call_tir(cls.dequantize3, (model_layers_15_mlp_gate_up_proj_q_weight5, model_layers_15_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv352 = R.call_tir(cls.NT_matmul7, (rms_norm323, lv804), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split159: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv352, indices_or_sections=2, axis=-1)
            split_0159: R.Tensor((1, seq_len, 11008), dtype="float16") = split159[0]
            split_1159: R.Tensor((1, seq_len, 11008), dtype="float16") = split159[1]
            silu159: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0159)
            mul159: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu159, split_1159)
            lv805 = R.call_tir(cls.dequantize4, (model_layers_15_mlp_down_proj_q_weight5, model_layers_15_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv353 = R.call_tir(cls.NT_matmul8, (mul159, lv805), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv350_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv353, lv349_1, model_layers_16_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv351_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv350_1[1]
            rms_norm324: R.Tensor((1, seq_len, 2048), dtype="float16") = lv350_1[0]
            lv806 = R.call_tir(cls.dequantize1, (model_layers_16_self_attn_c_attn_q_weight5, model_layers_16_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv354 = R.call_tir(cls.NT_matmul5, (rms_norm324, lv806), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add480: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv354, model_layers_16_self_attn_c_attn_bias5)
            reshape640: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add480, R.shape([1, seq_len, 20, 128]))
            reshape641: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape640, R.shape([seq_len, 20, 128]))
            lv807 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1.0)), reshape641), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape642: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv807, R.shape([1, seq_len, 16, 128]))
            reshape643: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape642, R.shape([1, seq_len, 2048]))
            lv808 = R.call_tir(cls.dequantize2, (model_layers_16_self_attn_o_proj_q_weight5, model_layers_16_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv355 = R.call_tir(cls.NT_matmul6, (reshape643, lv808), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv352_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv355, lv351_1, model_layers_16_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv353_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv352_1[1]
            rms_norm325: R.Tensor((1, seq_len, 2048), dtype="float16") = lv352_1[0]
            lv809 = R.call_tir(cls.dequantize3, (model_layers_16_mlp_gate_up_proj_q_weight5, model_layers_16_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv356 = R.call_tir(cls.NT_matmul7, (rms_norm325, lv809), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split160: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv356, indices_or_sections=2, axis=-1)
            split_0160: R.Tensor((1, seq_len, 11008), dtype="float16") = split160[0]
            split_1160: R.Tensor((1, seq_len, 11008), dtype="float16") = split160[1]
            silu160: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0160)
            mul160: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu160, split_1160)
            lv810 = R.call_tir(cls.dequantize4, (model_layers_16_mlp_down_proj_q_weight5, model_layers_16_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv357 = R.call_tir(cls.NT_matmul8, (mul160, lv810), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv354_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv357, lv353_1, model_layers_17_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv355_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv354_1[1]
            rms_norm326: R.Tensor((1, seq_len, 2048), dtype="float16") = lv354_1[0]
            lv811 = R.call_tir(cls.dequantize1, (model_layers_17_self_attn_c_attn_q_weight5, model_layers_17_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv358 = R.call_tir(cls.NT_matmul5, (rms_norm326, lv811), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add483: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv358, model_layers_17_self_attn_c_attn_bias5)
            reshape644: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add483, R.shape([1, seq_len, 20, 128]))
            reshape645: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape644, R.shape([seq_len, 20, 128]))
            lv812 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1.0)), reshape645), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape646: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv812, R.shape([1, seq_len, 16, 128]))
            reshape647: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape646, R.shape([1, seq_len, 2048]))
            lv813 = R.call_tir(cls.dequantize2, (model_layers_17_self_attn_o_proj_q_weight5, model_layers_17_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv359 = R.call_tir(cls.NT_matmul6, (reshape647, lv813), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv356_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv359, lv355_1, model_layers_17_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv357_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv356_1[1]
            rms_norm327: R.Tensor((1, seq_len, 2048), dtype="float16") = lv356_1[0]
            lv814 = R.call_tir(cls.dequantize3, (model_layers_17_mlp_gate_up_proj_q_weight5, model_layers_17_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv360 = R.call_tir(cls.NT_matmul7, (rms_norm327, lv814), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split161: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv360, indices_or_sections=2, axis=-1)
            split_0161: R.Tensor((1, seq_len, 11008), dtype="float16") = split161[0]
            split_1161: R.Tensor((1, seq_len, 11008), dtype="float16") = split161[1]
            silu161: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0161)
            mul161: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu161, split_1161)
            lv815 = R.call_tir(cls.dequantize4, (model_layers_17_mlp_down_proj_q_weight5, model_layers_17_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv361 = R.call_tir(cls.NT_matmul8, (mul161, lv815), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv358_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv361, lv357_1, model_layers_18_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv359_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv358_1[1]
            rms_norm328: R.Tensor((1, seq_len, 2048), dtype="float16") = lv358_1[0]
            lv816 = R.call_tir(cls.dequantize1, (model_layers_18_self_attn_c_attn_q_weight5, model_layers_18_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv362 = R.call_tir(cls.NT_matmul5, (rms_norm328, lv816), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add486: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv362, model_layers_18_self_attn_c_attn_bias5)
            reshape648: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add486, R.shape([1, seq_len, 20, 128]))
            reshape649: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape648, R.shape([seq_len, 20, 128]))
            lv817 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1.0)), reshape649), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape650: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv817, R.shape([1, seq_len, 16, 128]))
            reshape651: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape650, R.shape([1, seq_len, 2048]))
            lv818 = R.call_tir(cls.dequantize2, (model_layers_18_self_attn_o_proj_q_weight5, model_layers_18_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv363 = R.call_tir(cls.NT_matmul6, (reshape651, lv818), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv360_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv363, lv359_1, model_layers_18_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv361_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv360_1[1]
            rms_norm329: R.Tensor((1, seq_len, 2048), dtype="float16") = lv360_1[0]
            lv819 = R.call_tir(cls.dequantize3, (model_layers_18_mlp_gate_up_proj_q_weight5, model_layers_18_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv364 = R.call_tir(cls.NT_matmul7, (rms_norm329, lv819), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split162: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv364, indices_or_sections=2, axis=-1)
            split_0162: R.Tensor((1, seq_len, 11008), dtype="float16") = split162[0]
            split_1162: R.Tensor((1, seq_len, 11008), dtype="float16") = split162[1]
            silu162: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0162)
            mul162: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu162, split_1162)
            lv820 = R.call_tir(cls.dequantize4, (model_layers_18_mlp_down_proj_q_weight5, model_layers_18_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv365 = R.call_tir(cls.NT_matmul8, (mul162, lv820), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv362_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv365, lv361_1, model_layers_19_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv363_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv362_1[1]
            rms_norm330: R.Tensor((1, seq_len, 2048), dtype="float16") = lv362_1[0]
            lv821 = R.call_tir(cls.dequantize1, (model_layers_19_self_attn_c_attn_q_weight5, model_layers_19_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv366 = R.call_tir(cls.NT_matmul5, (rms_norm330, lv821), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add489: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv366, model_layers_19_self_attn_c_attn_bias5)
            reshape652: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add489, R.shape([1, seq_len, 20, 128]))
            reshape653: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape652, R.shape([seq_len, 20, 128]))
            lv822 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1.0)), reshape653), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape654: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv822, R.shape([1, seq_len, 16, 128]))
            reshape655: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape654, R.shape([1, seq_len, 2048]))
            lv823 = R.call_tir(cls.dequantize2, (model_layers_19_self_attn_o_proj_q_weight5, model_layers_19_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv367 = R.call_tir(cls.NT_matmul6, (reshape655, lv823), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv364_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv367, lv363_1, model_layers_19_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv365_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv364_1[1]
            rms_norm331: R.Tensor((1, seq_len, 2048), dtype="float16") = lv364_1[0]
            lv824 = R.call_tir(cls.dequantize3, (model_layers_19_mlp_gate_up_proj_q_weight5, model_layers_19_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv368 = R.call_tir(cls.NT_matmul7, (rms_norm331, lv824), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split163: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv368, indices_or_sections=2, axis=-1)
            split_0163: R.Tensor((1, seq_len, 11008), dtype="float16") = split163[0]
            split_1163: R.Tensor((1, seq_len, 11008), dtype="float16") = split163[1]
            silu163: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0163)
            mul163: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu163, split_1163)
            lv825 = R.call_tir(cls.dequantize4, (model_layers_19_mlp_down_proj_q_weight5, model_layers_19_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv369 = R.call_tir(cls.NT_matmul8, (mul163, lv825), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv366_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv369, lv365_1, model_layers_20_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv367_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv366_1[1]
            rms_norm332: R.Tensor((1, seq_len, 2048), dtype="float16") = lv366_1[0]
            lv826 = R.call_tir(cls.dequantize1, (model_layers_20_self_attn_c_attn_q_weight5, model_layers_20_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv370 = R.call_tir(cls.NT_matmul5, (rms_norm332, lv826), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add492: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv370, model_layers_20_self_attn_c_attn_bias5)
            reshape656: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add492, R.shape([1, seq_len, 20, 128]))
            reshape657: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape656, R.shape([seq_len, 20, 128]))
            lv827 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1.0)), reshape657), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape658: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv827, R.shape([1, seq_len, 16, 128]))
            reshape659: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape658, R.shape([1, seq_len, 2048]))
            lv828 = R.call_tir(cls.dequantize2, (model_layers_20_self_attn_o_proj_q_weight5, model_layers_20_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv371 = R.call_tir(cls.NT_matmul6, (reshape659, lv828), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv368_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv371, lv367_1, model_layers_20_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv369_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv368_1[1]
            rms_norm333: R.Tensor((1, seq_len, 2048), dtype="float16") = lv368_1[0]
            lv829 = R.call_tir(cls.dequantize3, (model_layers_20_mlp_gate_up_proj_q_weight5, model_layers_20_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv372 = R.call_tir(cls.NT_matmul7, (rms_norm333, lv829), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split164: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv372, indices_or_sections=2, axis=-1)
            split_0164: R.Tensor((1, seq_len, 11008), dtype="float16") = split164[0]
            split_1164: R.Tensor((1, seq_len, 11008), dtype="float16") = split164[1]
            silu164: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0164)
            mul164: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu164, split_1164)
            lv830 = R.call_tir(cls.dequantize4, (model_layers_20_mlp_down_proj_q_weight5, model_layers_20_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv373 = R.call_tir(cls.NT_matmul8, (mul164, lv830), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv370_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv373, lv369_1, model_layers_21_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv371_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv370_1[1]
            rms_norm334: R.Tensor((1, seq_len, 2048), dtype="float16") = lv370_1[0]
            lv831 = R.call_tir(cls.dequantize1, (model_layers_21_self_attn_c_attn_q_weight5, model_layers_21_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv374 = R.call_tir(cls.NT_matmul5, (rms_norm334, lv831), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add495: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv374, model_layers_21_self_attn_c_attn_bias5)
            reshape660: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add495, R.shape([1, seq_len, 20, 128]))
            reshape661: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape660, R.shape([seq_len, 20, 128]))
            lv832 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1.0)), reshape661), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape662: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv832, R.shape([1, seq_len, 16, 128]))
            reshape663: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape662, R.shape([1, seq_len, 2048]))
            lv833 = R.call_tir(cls.dequantize2, (model_layers_21_self_attn_o_proj_q_weight5, model_layers_21_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv375 = R.call_tir(cls.NT_matmul6, (reshape663, lv833), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv372_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv375, lv371_1, model_layers_21_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv373_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv372_1[1]
            rms_norm335: R.Tensor((1, seq_len, 2048), dtype="float16") = lv372_1[0]
            lv834 = R.call_tir(cls.dequantize3, (model_layers_21_mlp_gate_up_proj_q_weight5, model_layers_21_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv376 = R.call_tir(cls.NT_matmul7, (rms_norm335, lv834), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split165: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv376, indices_or_sections=2, axis=-1)
            split_0165: R.Tensor((1, seq_len, 11008), dtype="float16") = split165[0]
            split_1165: R.Tensor((1, seq_len, 11008), dtype="float16") = split165[1]
            silu165: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0165)
            mul165: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu165, split_1165)
            lv835 = R.call_tir(cls.dequantize4, (model_layers_21_mlp_down_proj_q_weight5, model_layers_21_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv377 = R.call_tir(cls.NT_matmul8, (mul165, lv835), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv374_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv377, lv373_1, model_layers_22_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv375_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv374_1[1]
            rms_norm336: R.Tensor((1, seq_len, 2048), dtype="float16") = lv374_1[0]
            lv836 = R.call_tir(cls.dequantize1, (model_layers_22_self_attn_c_attn_q_weight5, model_layers_22_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv378 = R.call_tir(cls.NT_matmul5, (rms_norm336, lv836), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add498: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv378, model_layers_22_self_attn_c_attn_bias5)
            reshape664: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add498, R.shape([1, seq_len, 20, 128]))
            reshape665: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape664, R.shape([seq_len, 20, 128]))
            lv837 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1.0)), reshape665), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape666: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv837, R.shape([1, seq_len, 16, 128]))
            reshape667: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape666, R.shape([1, seq_len, 2048]))
            lv838 = R.call_tir(cls.dequantize2, (model_layers_22_self_attn_o_proj_q_weight5, model_layers_22_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv379 = R.call_tir(cls.NT_matmul6, (reshape667, lv838), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv376_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv379, lv375_1, model_layers_22_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv377_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv376_1[1]
            rms_norm337: R.Tensor((1, seq_len, 2048), dtype="float16") = lv376_1[0]
            lv839 = R.call_tir(cls.dequantize3, (model_layers_22_mlp_gate_up_proj_q_weight5, model_layers_22_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv380 = R.call_tir(cls.NT_matmul7, (rms_norm337, lv839), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split166: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv380, indices_or_sections=2, axis=-1)
            split_0166: R.Tensor((1, seq_len, 11008), dtype="float16") = split166[0]
            split_1166: R.Tensor((1, seq_len, 11008), dtype="float16") = split166[1]
            silu166: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0166)
            mul166: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu166, split_1166)
            lv840 = R.call_tir(cls.dequantize4, (model_layers_22_mlp_down_proj_q_weight5, model_layers_22_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv381 = R.call_tir(cls.NT_matmul8, (mul166, lv840), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv378_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv381, lv377_1, model_layers_23_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv379_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv378_1[1]
            rms_norm338: R.Tensor((1, seq_len, 2048), dtype="float16") = lv378_1[0]
            lv841 = R.call_tir(cls.dequantize1, (model_layers_23_self_attn_c_attn_q_weight5, model_layers_23_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv382 = R.call_tir(cls.NT_matmul5, (rms_norm338, lv841), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add501: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv382, model_layers_23_self_attn_c_attn_bias5)
            reshape668: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add501, R.shape([1, seq_len, 20, 128]))
            reshape669: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape668, R.shape([seq_len, 20, 128]))
            lv842 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1.0)), reshape669), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape670: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv842, R.shape([1, seq_len, 16, 128]))
            reshape671: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape670, R.shape([1, seq_len, 2048]))
            lv843 = R.call_tir(cls.dequantize2, (model_layers_23_self_attn_o_proj_q_weight5, model_layers_23_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv383 = R.call_tir(cls.NT_matmul6, (reshape671, lv843), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv380_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv383, lv379_1, model_layers_23_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv381_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv380_1[1]
            rms_norm339: R.Tensor((1, seq_len, 2048), dtype="float16") = lv380_1[0]
            lv844 = R.call_tir(cls.dequantize3, (model_layers_23_mlp_gate_up_proj_q_weight5, model_layers_23_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv384 = R.call_tir(cls.NT_matmul7, (rms_norm339, lv844), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split167: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv384, indices_or_sections=2, axis=-1)
            split_0167: R.Tensor((1, seq_len, 11008), dtype="float16") = split167[0]
            split_1167: R.Tensor((1, seq_len, 11008), dtype="float16") = split167[1]
            silu167: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0167)
            mul167: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu167, split_1167)
            lv845 = R.call_tir(cls.dequantize4, (model_layers_23_mlp_down_proj_q_weight5, model_layers_23_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv385 = R.call_tir(cls.NT_matmul8, (mul167, lv845), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv382_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv385, lv381_1, model_layers_24_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv383_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv382_1[1]
            rms_norm340: R.Tensor((1, seq_len, 2048), dtype="float16") = lv382_1[0]
            lv846 = R.call_tir(cls.dequantize1, (model_layers_24_self_attn_c_attn_q_weight5, model_layers_24_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv386 = R.call_tir(cls.NT_matmul5, (rms_norm340, lv846), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add504: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv386, model_layers_24_self_attn_c_attn_bias5)
            reshape672: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add504, R.shape([1, seq_len, 20, 128]))
            reshape673: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape672, R.shape([seq_len, 20, 128]))
            lv847 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1.0)), reshape673), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape674: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv847, R.shape([1, seq_len, 16, 128]))
            reshape675: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape674, R.shape([1, seq_len, 2048]))
            lv848 = R.call_tir(cls.dequantize2, (model_layers_24_self_attn_o_proj_q_weight5, model_layers_24_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv387 = R.call_tir(cls.NT_matmul6, (reshape675, lv848), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv384_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv387, lv383_1, model_layers_24_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv385_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv384_1[1]
            rms_norm341: R.Tensor((1, seq_len, 2048), dtype="float16") = lv384_1[0]
            lv849 = R.call_tir(cls.dequantize3, (model_layers_24_mlp_gate_up_proj_q_weight5, model_layers_24_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv388 = R.call_tir(cls.NT_matmul7, (rms_norm341, lv849), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split168: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv388, indices_or_sections=2, axis=-1)
            split_0168: R.Tensor((1, seq_len, 11008), dtype="float16") = split168[0]
            split_1168: R.Tensor((1, seq_len, 11008), dtype="float16") = split168[1]
            silu168: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0168)
            mul168: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu168, split_1168)
            lv850 = R.call_tir(cls.dequantize4, (model_layers_24_mlp_down_proj_q_weight5, model_layers_24_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv389 = R.call_tir(cls.NT_matmul8, (mul168, lv850), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv386_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv389, lv385_1, model_layers_25_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv387_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv386_1[1]
            rms_norm342: R.Tensor((1, seq_len, 2048), dtype="float16") = lv386_1[0]
            lv851 = R.call_tir(cls.dequantize1, (model_layers_25_self_attn_c_attn_q_weight5, model_layers_25_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv390 = R.call_tir(cls.NT_matmul5, (rms_norm342, lv851), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add507: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv390, model_layers_25_self_attn_c_attn_bias5)
            reshape676: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add507, R.shape([1, seq_len, 20, 128]))
            reshape677: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape676, R.shape([seq_len, 20, 128]))
            lv852 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1.0)), reshape677), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape678: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv852, R.shape([1, seq_len, 16, 128]))
            reshape679: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape678, R.shape([1, seq_len, 2048]))
            lv853 = R.call_tir(cls.dequantize2, (model_layers_25_self_attn_o_proj_q_weight5, model_layers_25_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv391 = R.call_tir(cls.NT_matmul6, (reshape679, lv853), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv388_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv391, lv387_1, model_layers_25_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv389_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv388_1[1]
            rms_norm343: R.Tensor((1, seq_len, 2048), dtype="float16") = lv388_1[0]
            lv854 = R.call_tir(cls.dequantize3, (model_layers_25_mlp_gate_up_proj_q_weight5, model_layers_25_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv392 = R.call_tir(cls.NT_matmul7, (rms_norm343, lv854), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split169: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv392, indices_or_sections=2, axis=-1)
            split_0169: R.Tensor((1, seq_len, 11008), dtype="float16") = split169[0]
            split_1169: R.Tensor((1, seq_len, 11008), dtype="float16") = split169[1]
            silu169: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0169)
            mul169: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu169, split_1169)
            lv855 = R.call_tir(cls.dequantize4, (model_layers_25_mlp_down_proj_q_weight5, model_layers_25_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv393 = R.call_tir(cls.NT_matmul8, (mul169, lv855), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv390_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv393, lv389_1, model_layers_26_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv391_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv390_1[1]
            rms_norm344: R.Tensor((1, seq_len, 2048), dtype="float16") = lv390_1[0]
            lv856 = R.call_tir(cls.dequantize1, (model_layers_26_self_attn_c_attn_q_weight5, model_layers_26_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv394 = R.call_tir(cls.NT_matmul5, (rms_norm344, lv856), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add510: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv394, model_layers_26_self_attn_c_attn_bias5)
            reshape680: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add510, R.shape([1, seq_len, 20, 128]))
            reshape681: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape680, R.shape([seq_len, 20, 128]))
            lv857 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1.0)), reshape681), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape682: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv857, R.shape([1, seq_len, 16, 128]))
            reshape683: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape682, R.shape([1, seq_len, 2048]))
            lv858 = R.call_tir(cls.dequantize2, (model_layers_26_self_attn_o_proj_q_weight5, model_layers_26_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv395 = R.call_tir(cls.NT_matmul6, (reshape683, lv858), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv392_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv395, lv391_1, model_layers_26_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv393_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv392_1[1]
            rms_norm345: R.Tensor((1, seq_len, 2048), dtype="float16") = lv392_1[0]
            lv859 = R.call_tir(cls.dequantize3, (model_layers_26_mlp_gate_up_proj_q_weight5, model_layers_26_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv396 = R.call_tir(cls.NT_matmul7, (rms_norm345, lv859), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split170: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv396, indices_or_sections=2, axis=-1)
            split_0170: R.Tensor((1, seq_len, 11008), dtype="float16") = split170[0]
            split_1170: R.Tensor((1, seq_len, 11008), dtype="float16") = split170[1]
            silu170: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0170)
            mul170: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu170, split_1170)
            lv860 = R.call_tir(cls.dequantize4, (model_layers_26_mlp_down_proj_q_weight5, model_layers_26_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv397 = R.call_tir(cls.NT_matmul8, (mul170, lv860), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv394_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv397, lv393_1, model_layers_27_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv395_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv394_1[1]
            rms_norm346: R.Tensor((1, seq_len, 2048), dtype="float16") = lv394_1[0]
            lv861 = R.call_tir(cls.dequantize1, (model_layers_27_self_attn_c_attn_q_weight5, model_layers_27_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv398 = R.call_tir(cls.NT_matmul5, (rms_norm346, lv861), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add513: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv398, model_layers_27_self_attn_c_attn_bias5)
            reshape684: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add513, R.shape([1, seq_len, 20, 128]))
            reshape685: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape684, R.shape([seq_len, 20, 128]))
            lv862 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1.0)), reshape685), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape686: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv862, R.shape([1, seq_len, 16, 128]))
            reshape687: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape686, R.shape([1, seq_len, 2048]))
            lv863 = R.call_tir(cls.dequantize2, (model_layers_27_self_attn_o_proj_q_weight5, model_layers_27_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv399 = R.call_tir(cls.NT_matmul6, (reshape687, lv863), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv396_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv399, lv395_1, model_layers_27_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv397_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv396_1[1]
            rms_norm347: R.Tensor((1, seq_len, 2048), dtype="float16") = lv396_1[0]
            lv864 = R.call_tir(cls.dequantize3, (model_layers_27_mlp_gate_up_proj_q_weight5, model_layers_27_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv400 = R.call_tir(cls.NT_matmul7, (rms_norm347, lv864), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split171: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv400, indices_or_sections=2, axis=-1)
            split_0171: R.Tensor((1, seq_len, 11008), dtype="float16") = split171[0]
            split_1171: R.Tensor((1, seq_len, 11008), dtype="float16") = split171[1]
            silu171: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0171)
            mul171: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu171, split_1171)
            lv865 = R.call_tir(cls.dequantize4, (model_layers_27_mlp_down_proj_q_weight5, model_layers_27_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv401 = R.call_tir(cls.NT_matmul8, (mul171, lv865), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv398_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv401, lv397_1, model_layers_28_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv399_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv398_1[1]
            rms_norm348: R.Tensor((1, seq_len, 2048), dtype="float16") = lv398_1[0]
            lv866 = R.call_tir(cls.dequantize1, (model_layers_28_self_attn_c_attn_q_weight5, model_layers_28_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv402 = R.call_tir(cls.NT_matmul5, (rms_norm348, lv866), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add516: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv402, model_layers_28_self_attn_c_attn_bias5)
            reshape688: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add516, R.shape([1, seq_len, 20, 128]))
            reshape689: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape688, R.shape([seq_len, 20, 128]))
            lv867 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1.0)), reshape689), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape690: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv867, R.shape([1, seq_len, 16, 128]))
            reshape691: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape690, R.shape([1, seq_len, 2048]))
            lv868 = R.call_tir(cls.dequantize2, (model_layers_28_self_attn_o_proj_q_weight5, model_layers_28_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv403 = R.call_tir(cls.NT_matmul6, (reshape691, lv868), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv400_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv403, lv399_1, model_layers_28_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv401_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv400_1[1]
            rms_norm349: R.Tensor((1, seq_len, 2048), dtype="float16") = lv400_1[0]
            lv869 = R.call_tir(cls.dequantize3, (model_layers_28_mlp_gate_up_proj_q_weight5, model_layers_28_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv404 = R.call_tir(cls.NT_matmul7, (rms_norm349, lv869), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split172: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv404, indices_or_sections=2, axis=-1)
            split_0172: R.Tensor((1, seq_len, 11008), dtype="float16") = split172[0]
            split_1172: R.Tensor((1, seq_len, 11008), dtype="float16") = split172[1]
            silu172: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0172)
            mul172: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu172, split_1172)
            lv870 = R.call_tir(cls.dequantize4, (model_layers_28_mlp_down_proj_q_weight5, model_layers_28_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv405 = R.call_tir(cls.NT_matmul8, (mul172, lv870), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv402_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv405, lv401_1, model_layers_29_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv403_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv402_1[1]
            rms_norm350: R.Tensor((1, seq_len, 2048), dtype="float16") = lv402_1[0]
            lv871 = R.call_tir(cls.dequantize1, (model_layers_29_self_attn_c_attn_q_weight5, model_layers_29_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv406 = R.call_tir(cls.NT_matmul5, (rms_norm350, lv871), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add519: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv406, model_layers_29_self_attn_c_attn_bias5)
            reshape692: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add519, R.shape([1, seq_len, 20, 128]))
            reshape693: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape692, R.shape([seq_len, 20, 128]))
            lv872 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1.0)), reshape693), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape694: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv872, R.shape([1, seq_len, 16, 128]))
            reshape695: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape694, R.shape([1, seq_len, 2048]))
            lv873 = R.call_tir(cls.dequantize2, (model_layers_29_self_attn_o_proj_q_weight5, model_layers_29_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv407 = R.call_tir(cls.NT_matmul6, (reshape695, lv873), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv404_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv407, lv403_1, model_layers_29_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv405_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv404_1[1]
            rms_norm351: R.Tensor((1, seq_len, 2048), dtype="float16") = lv404_1[0]
            lv874 = R.call_tir(cls.dequantize3, (model_layers_29_mlp_gate_up_proj_q_weight5, model_layers_29_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv408 = R.call_tir(cls.NT_matmul7, (rms_norm351, lv874), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split173: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv408, indices_or_sections=2, axis=-1)
            split_0173: R.Tensor((1, seq_len, 11008), dtype="float16") = split173[0]
            split_1173: R.Tensor((1, seq_len, 11008), dtype="float16") = split173[1]
            silu173: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0173)
            mul173: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu173, split_1173)
            lv875 = R.call_tir(cls.dequantize4, (model_layers_29_mlp_down_proj_q_weight5, model_layers_29_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv409 = R.call_tir(cls.NT_matmul8, (mul173, lv875), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv406_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv409, lv405_1, model_layers_30_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv407_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv406_1[1]
            rms_norm352: R.Tensor((1, seq_len, 2048), dtype="float16") = lv406_1[0]
            lv876 = R.call_tir(cls.dequantize1, (model_layers_30_self_attn_c_attn_q_weight5, model_layers_30_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv410 = R.call_tir(cls.NT_matmul5, (rms_norm352, lv876), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add522: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv410, model_layers_30_self_attn_c_attn_bias5)
            reshape696: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add522, R.shape([1, seq_len, 20, 128]))
            reshape697: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape696, R.shape([seq_len, 20, 128]))
            lv877 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1.0)), reshape697), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape698: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv877, R.shape([1, seq_len, 16, 128]))
            reshape699: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape698, R.shape([1, seq_len, 2048]))
            lv878 = R.call_tir(cls.dequantize2, (model_layers_30_self_attn_o_proj_q_weight5, model_layers_30_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv411 = R.call_tir(cls.NT_matmul6, (reshape699, lv878), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv408_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv411, lv407_1, model_layers_30_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv409_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv408_1[1]
            rms_norm353: R.Tensor((1, seq_len, 2048), dtype="float16") = lv408_1[0]
            lv879 = R.call_tir(cls.dequantize3, (model_layers_30_mlp_gate_up_proj_q_weight5, model_layers_30_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv412 = R.call_tir(cls.NT_matmul7, (rms_norm353, lv879), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split174: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv412, indices_or_sections=2, axis=-1)
            split_0174: R.Tensor((1, seq_len, 11008), dtype="float16") = split174[0]
            split_1174: R.Tensor((1, seq_len, 11008), dtype="float16") = split174[1]
            silu174: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0174)
            mul174: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu174, split_1174)
            lv880 = R.call_tir(cls.dequantize4, (model_layers_30_mlp_down_proj_q_weight5, model_layers_30_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv413 = R.call_tir(cls.NT_matmul8, (mul174, lv880), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv410_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv413, lv409_1, model_layers_31_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv411_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv410_1[1]
            rms_norm354: R.Tensor((1, seq_len, 2048), dtype="float16") = lv410_1[0]
            lv881 = R.call_tir(cls.dequantize1, (model_layers_31_self_attn_c_attn_q_weight5, model_layers_31_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv414 = R.call_tir(cls.NT_matmul5, (rms_norm354, lv881), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add525: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv414, model_layers_31_self_attn_c_attn_bias5)
            reshape700: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add525, R.shape([1, seq_len, 20, 128]))
            reshape701: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape700, R.shape([seq_len, 20, 128]))
            lv882 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1.0)), reshape701), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape702: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv882, R.shape([1, seq_len, 16, 128]))
            reshape703: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape702, R.shape([1, seq_len, 2048]))
            lv883 = R.call_tir(cls.dequantize2, (model_layers_31_self_attn_o_proj_q_weight5, model_layers_31_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv415 = R.call_tir(cls.NT_matmul6, (reshape703, lv883), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv412_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv415, lv411_1, model_layers_31_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv413_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv412_1[1]
            rms_norm355: R.Tensor((1, seq_len, 2048), dtype="float16") = lv412_1[0]
            lv884 = R.call_tir(cls.dequantize3, (model_layers_31_mlp_gate_up_proj_q_weight5, model_layers_31_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv416 = R.call_tir(cls.NT_matmul7, (rms_norm355, lv884), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split175: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv416, indices_or_sections=2, axis=-1)
            split_0175: R.Tensor((1, seq_len, 11008), dtype="float16") = split175[0]
            split_1175: R.Tensor((1, seq_len, 11008), dtype="float16") = split175[1]
            silu175: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0175)
            mul175: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu175, split_1175)
            lv885 = R.call_tir(cls.dequantize4, (model_layers_31_mlp_down_proj_q_weight5, model_layers_31_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv417 = R.call_tir(cls.NT_matmul8, (mul175, lv885), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv414_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv417, lv413_1, model_layers_32_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv415_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv414_1[1]
            rms_norm356: R.Tensor((1, seq_len, 2048), dtype="float16") = lv414_1[0]
            lv886 = R.call_tir(cls.dequantize1, (model_layers_32_self_attn_c_attn_q_weight5, model_layers_32_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv418 = R.call_tir(cls.NT_matmul5, (rms_norm356, lv886), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add528: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv418, model_layers_32_self_attn_c_attn_bias5)
            reshape704: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add528, R.shape([1, seq_len, 20, 128]))
            reshape705: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape704, R.shape([seq_len, 20, 128]))
            lv887 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(32), R.prim_value(T.float32(1.0)), reshape705), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape706: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv887, R.shape([1, seq_len, 16, 128]))
            reshape707: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape706, R.shape([1, seq_len, 2048]))
            lv888 = R.call_tir(cls.dequantize2, (model_layers_32_self_attn_o_proj_q_weight5, model_layers_32_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv419 = R.call_tir(cls.NT_matmul6, (reshape707, lv888), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv416_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv419, lv415_1, model_layers_32_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv417_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv416_1[1]
            rms_norm357: R.Tensor((1, seq_len, 2048), dtype="float16") = lv416_1[0]
            lv889 = R.call_tir(cls.dequantize3, (model_layers_32_mlp_gate_up_proj_q_weight5, model_layers_32_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv420 = R.call_tir(cls.NT_matmul7, (rms_norm357, lv889), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split176: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv420, indices_or_sections=2, axis=-1)
            split_0176: R.Tensor((1, seq_len, 11008), dtype="float16") = split176[0]
            split_1176: R.Tensor((1, seq_len, 11008), dtype="float16") = split176[1]
            silu176: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0176)
            mul176: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu176, split_1176)
            lv890 = R.call_tir(cls.dequantize4, (model_layers_32_mlp_down_proj_q_weight5, model_layers_32_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv421 = R.call_tir(cls.NT_matmul8, (mul176, lv890), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv418_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv421, lv417_1, model_layers_33_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv419_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv418_1[1]
            rms_norm358: R.Tensor((1, seq_len, 2048), dtype="float16") = lv418_1[0]
            lv891 = R.call_tir(cls.dequantize1, (model_layers_33_self_attn_c_attn_q_weight5, model_layers_33_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv422 = R.call_tir(cls.NT_matmul5, (rms_norm358, lv891), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add531: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv422, model_layers_33_self_attn_c_attn_bias5)
            reshape708: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add531, R.shape([1, seq_len, 20, 128]))
            reshape709: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape708, R.shape([seq_len, 20, 128]))
            lv892 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(33), R.prim_value(T.float32(1.0)), reshape709), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape710: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv892, R.shape([1, seq_len, 16, 128]))
            reshape711: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape710, R.shape([1, seq_len, 2048]))
            lv893 = R.call_tir(cls.dequantize2, (model_layers_33_self_attn_o_proj_q_weight5, model_layers_33_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv423 = R.call_tir(cls.NT_matmul6, (reshape711, lv893), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv420_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv423, lv419_1, model_layers_33_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv421_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv420_1[1]
            rms_norm359: R.Tensor((1, seq_len, 2048), dtype="float16") = lv420_1[0]
            lv894 = R.call_tir(cls.dequantize3, (model_layers_33_mlp_gate_up_proj_q_weight5, model_layers_33_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv424 = R.call_tir(cls.NT_matmul7, (rms_norm359, lv894), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split177: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv424, indices_or_sections=2, axis=-1)
            split_0177: R.Tensor((1, seq_len, 11008), dtype="float16") = split177[0]
            split_1177: R.Tensor((1, seq_len, 11008), dtype="float16") = split177[1]
            silu177: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0177)
            mul177: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu177, split_1177)
            lv895 = R.call_tir(cls.dequantize4, (model_layers_33_mlp_down_proj_q_weight5, model_layers_33_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv425 = R.call_tir(cls.NT_matmul8, (mul177, lv895), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv422_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv425, lv421_1, model_layers_34_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv423_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv422_1[1]
            rms_norm360: R.Tensor((1, seq_len, 2048), dtype="float16") = lv422_1[0]
            lv896 = R.call_tir(cls.dequantize1, (model_layers_34_self_attn_c_attn_q_weight5, model_layers_34_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv426 = R.call_tir(cls.NT_matmul5, (rms_norm360, lv896), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add534: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv426, model_layers_34_self_attn_c_attn_bias5)
            reshape712: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add534, R.shape([1, seq_len, 20, 128]))
            reshape713: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape712, R.shape([seq_len, 20, 128]))
            lv897 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(34), R.prim_value(T.float32(1.0)), reshape713), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape714: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv897, R.shape([1, seq_len, 16, 128]))
            reshape715: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape714, R.shape([1, seq_len, 2048]))
            lv898 = R.call_tir(cls.dequantize2, (model_layers_34_self_attn_o_proj_q_weight5, model_layers_34_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv427 = R.call_tir(cls.NT_matmul6, (reshape715, lv898), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv424_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv427, lv423_1, model_layers_34_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv425_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv424_1[1]
            rms_norm361: R.Tensor((1, seq_len, 2048), dtype="float16") = lv424_1[0]
            lv899 = R.call_tir(cls.dequantize3, (model_layers_34_mlp_gate_up_proj_q_weight5, model_layers_34_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv428 = R.call_tir(cls.NT_matmul7, (rms_norm361, lv899), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split178: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv428, indices_or_sections=2, axis=-1)
            split_0178: R.Tensor((1, seq_len, 11008), dtype="float16") = split178[0]
            split_1178: R.Tensor((1, seq_len, 11008), dtype="float16") = split178[1]
            silu178: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0178)
            mul178: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu178, split_1178)
            lv900 = R.call_tir(cls.dequantize4, (model_layers_34_mlp_down_proj_q_weight5, model_layers_34_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv429 = R.call_tir(cls.NT_matmul8, (mul178, lv900), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv426_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv429, lv425_1, model_layers_35_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv427_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv426_1[1]
            rms_norm362: R.Tensor((1, seq_len, 2048), dtype="float16") = lv426_1[0]
            lv901 = R.call_tir(cls.dequantize1, (model_layers_35_self_attn_c_attn_q_weight5, model_layers_35_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv430 = R.call_tir(cls.NT_matmul5, (rms_norm362, lv901), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            add537: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(lv430, model_layers_35_self_attn_c_attn_bias5)
            reshape716: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add537, R.shape([1, seq_len, 20, 128]))
            reshape717: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape716, R.shape([seq_len, 20, 128]))
            lv902 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(35), R.prim_value(T.float32(1.0)), reshape717), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape718: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv902, R.shape([1, seq_len, 16, 128]))
            reshape719: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape718, R.shape([1, seq_len, 2048]))
            lv903 = R.call_tir(cls.dequantize2, (model_layers_35_self_attn_o_proj_q_weight5, model_layers_35_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv431 = R.call_tir(cls.NT_matmul6, (reshape719, lv903), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv428_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv431, lv427_1, model_layers_35_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv429_1: R.Tensor((1, seq_len, 2048), dtype="float16") = lv428_1[1]
            rms_norm363: R.Tensor((1, seq_len, 2048), dtype="float16") = lv428_1[0]
            lv904 = R.call_tir(cls.dequantize3, (model_layers_35_mlp_gate_up_proj_q_weight5, model_layers_35_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv432 = R.call_tir(cls.NT_matmul7, (rms_norm363, lv904), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            split179: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(lv432, indices_or_sections=2, axis=-1)
            split_0179: R.Tensor((1, seq_len, 11008), dtype="float16") = split179[0]
            split_1179: R.Tensor((1, seq_len, 11008), dtype="float16") = split179[1]
            silu179: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0179)
            mul179: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu179, split_1179)
            lv905 = R.call_tir(cls.dequantize4, (model_layers_35_mlp_down_proj_q_weight5, model_layers_35_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv433 = R.call_tir(cls.NT_matmul8, (mul179, lv905), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv430_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv433, lv429_1, model_norm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            rms_norm364: R.Tensor((1, seq_len, 2048), dtype="float16") = lv430_1[0]
            lv906 = R.call_tir(cls.dequantize, (model_embed_tokens_q_weight5, model_embed_tokens_q_scale5), out_sinfo=R.Tensor((151936, 2048), dtype="float16"))
            lv434 = R.call_tir(cls.NT_matmul9, (rms_norm364, lv906), out_sinfo=R.Tensor((1, seq_len, 151936), dtype="float32"))
            gv5: R.Tuple(R.Tensor((1, seq_len, 151936), dtype="float32"), R.Object) = lv434, paged_kv_cache
            R.output(gv5)
        return gv5

    # Relax entry point that creates the paged KV cache object consumed by the
    # attention calls elsewhere in this module. It takes five shape-typed
    # arguments, each carrying a single symbolic int64 value, and returns the
    # opaque R.Object produced by the runtime builtin invoked below.
    @R.function
    def create_tir_paged_kv_cache(max_batch_size_: R.Shape(["max_batch_size"]), max_total_seq_len_: R.Shape(["max_total_seq_len"]), prefill_chunk_size_: R.Shape(["prefill_chunk_size"]), page_size_: R.Shape(["page_size"]), support_sliding_window_: R.Shape(["support_sliding_window"])) -> R.Object:
        # Symbolic int64 variables bound from the single-element shape
        # arguments declared in the signature.
        max_batch_size = T.int64()
        max_total_seq_len = T.int64()
        prefill_chunk_size = T.int64()
        page_size = T.int64()
        support_sliding_window = T.int64()
        # NOTE(review): this attribute dict matches the one on `decode`, minus
        # the extra keys `decode` adds. None of the variables it names
        # (vocab_size, batch_size, seq_len, total_seq_len) appear in this
        # function's signature, so they presumably have no effect here. Confirm.
        R.func_attr({"relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 32768}})
        # Self-reference to the enclosing IRModule, so the TIR kernels below can
        # be passed as `cls.<name>`.
        cls = Module
        # 0-d float16 zero tensor handed to the builtin. It presumably only
        # conveys the cache element dtype (float16) to the runtime. Verify.
        gv: R.Tensor((), dtype="float16") = R.zeros(R.shape([]), dtype="float16")
        # Create the cache through the packed function
        # "vm.builtin.paged_attention_kv_cache_create_reduced". R.call_pure_packed
        # invokes the packed function as a pure call, and sinfo_args declares the
        # result as an R.Object.
        # Positional arguments, in order:
        #   - the five signature values (max_batch_size, max_total_seq_len,
        #     prefill_chunk_size, page_size, support_sliding_window) as one shape;
        #   - R.shape([0, 36]): presumably the [begin, end) layer range. The
        #     layer ids this module passes to attention go up to 35. Verify.
        #   - 16, 2, 128: the attention calls use 20 fused QKV heads of size 128
        #     and produce 16 output heads. That is consistent with 16 query
        #     heads, 2 KV heads and head dim 128, presumably
        #     num_qo_heads / num_kv_heads / head_dim. Confirm against the
        #     runtime signature.
        #   - 1, 1, 1000000.0: meaning not visible here (possibly RoPE mode,
        #     RoPE scale and RoPE theta). TODO confirm.
        #   - gv (the dtype carrier above);
        #   - the TIR kernels the runtime calls into. Judging by their names,
        #     they cover: KV append, prefill/decode attention (with and without
        #     sliding window), ragged prefill, state merge, fused RoPE, page
        #     copy, debug readout, compaction and tree attention;
        #   - two trailing 0 prim values whose meaning is not visible here.
        paged_kv_cache: R.Object = R.call_pure_packed("vm.builtin.paged_attention_kv_cache_create_reduced", R.shape([max_batch_size, max_total_seq_len, prefill_chunk_size, page_size, support_sliding_window]), R.shape([0, 36]), R.prim_value(16), R.prim_value(2), R.prim_value(128), R.prim_value(1), R.prim_value(1), R.prim_value(T.float32(1000000.0)), gv, cls.tir_kv_cache_transpose_append, cls.batch_prefill_paged_kv, cls.batch_decode_paged_kv, cls.batch_prefill_paged_kv_sliding_window, cls.batch_decode_paged_kv_sliding_window, cls.batch_prefill_ragged_kv, cls.merge_state_inplace, cls.fused_rope, cls.copy_single_page, cls.tir_kv_cache_debug_get_kv, cls.compact_kv_copy, cls.batch_tree_attn, cls.tree_attn_paged_kv, R.prim_value(0), R.prim_value(0), sinfo_args=(R.Object,))
        return paged_kv_cache

    @R.function
    def decode(input_embed: R.Tensor((1, 1, 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), 
dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), 
R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), 
dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 
256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), 
R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), 
dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), 
dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"))) -> R.Tuple(R.Tensor((1, 1, 151936), dtype="float32"), R.Object):
        R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 32768}})
        cls = Module
        with R.dataflow():
            model_embed_tokens_q_weight2: R.Tensor((151936, 256), dtype="uint32") = packed_params[0]
            model_embed_tokens_q_scale2: R.Tensor((151936, 64), dtype="float16") = packed_params[1]
            model_layers_0_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[2]
            model_layers_0_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[3]
            model_layers_0_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[4]
            model_layers_0_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[5]
            model_layers_0_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[6]
            model_layers_0_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[7]
            model_layers_0_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[8]
            model_layers_0_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[9]
            model_layers_0_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[10]
            model_layers_0_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[11]
            model_layers_0_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[12]
            model_layers_1_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[13]
            model_layers_1_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[14]
            model_layers_1_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[15]
            model_layers_1_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[16]
            model_layers_1_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[17]
            model_layers_1_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[18]
            model_layers_1_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[19]
            model_layers_1_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[20]
            model_layers_1_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[21]
            model_layers_1_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[22]
            model_layers_1_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[23]
            model_layers_2_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[24]
            model_layers_2_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[25]
            model_layers_2_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[26]
            model_layers_2_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[27]
            model_layers_2_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[28]
            model_layers_2_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[29]
            model_layers_2_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[30]
            model_layers_2_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[31]
            model_layers_2_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[32]
            model_layers_2_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[33]
            model_layers_2_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[34]
            model_layers_3_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[35]
            model_layers_3_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[36]
            model_layers_3_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[37]
            model_layers_3_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[38]
            model_layers_3_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[39]
            model_layers_3_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[40]
            model_layers_3_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[41]
            model_layers_3_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[42]
            model_layers_3_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[43]
            model_layers_3_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[44]
            model_layers_3_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[45]
            model_layers_4_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[46]
            model_layers_4_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[47]
            model_layers_4_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[48]
            model_layers_4_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[49]
            model_layers_4_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[50]
            model_layers_4_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[51]
            model_layers_4_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[52]
            model_layers_4_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[53]
            model_layers_4_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[54]
            model_layers_4_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[55]
            model_layers_4_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[56]
            model_layers_5_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[57]
            model_layers_5_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[58]
            model_layers_5_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[59]
            model_layers_5_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[60]
            model_layers_5_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[61]
            model_layers_5_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[62]
            model_layers_5_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[63]
            model_layers_5_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[64]
            model_layers_5_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[65]
            model_layers_5_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[66]
            model_layers_5_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[67]
            model_layers_6_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[68]
            model_layers_6_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[69]
            model_layers_6_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[70]
            model_layers_6_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[71]
            model_layers_6_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[72]
            model_layers_6_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[73]
            model_layers_6_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[74]
            model_layers_6_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[75]
            model_layers_6_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[76]
            model_layers_6_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[77]
            model_layers_6_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[78]
            model_layers_7_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[79]
            model_layers_7_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[80]
            model_layers_7_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[81]
            model_layers_7_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[82]
            model_layers_7_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[83]
            model_layers_7_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[84]
            model_layers_7_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[85]
            model_layers_7_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[86]
            model_layers_7_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[87]
            model_layers_7_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[88]
            model_layers_7_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[89]
            model_layers_8_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[90]
            model_layers_8_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[91]
            model_layers_8_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[92]
            model_layers_8_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[93]
            model_layers_8_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[94]
            model_layers_8_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[95]
            model_layers_8_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[96]
            model_layers_8_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[97]
            model_layers_8_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[98]
            model_layers_8_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[99]
            model_layers_8_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[100]
            model_layers_9_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[101]
            model_layers_9_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[102]
            model_layers_9_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[103]
            model_layers_9_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[104]
            model_layers_9_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[105]
            model_layers_9_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[106]
            model_layers_9_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[107]
            model_layers_9_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[108]
            model_layers_9_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[109]
            model_layers_9_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[110]
            model_layers_9_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[111]
            model_layers_10_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[112]
            model_layers_10_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[113]
            model_layers_10_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[114]
            model_layers_10_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[115]
            model_layers_10_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[116]
            model_layers_10_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[117]
            model_layers_10_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[118]
            model_layers_10_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[119]
            model_layers_10_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[120]
            model_layers_10_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[121]
            model_layers_10_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[122]
            model_layers_11_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[123]
            model_layers_11_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[124]
            model_layers_11_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[125]
            model_layers_11_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[126]
            model_layers_11_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[127]
            model_layers_11_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[128]
            model_layers_11_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[129]
            model_layers_11_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[130]
            model_layers_11_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[131]
            model_layers_11_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[132]
            model_layers_11_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[133]
            model_layers_12_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[134]
            model_layers_12_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[135]
            model_layers_12_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[136]
            model_layers_12_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[137]
            model_layers_12_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[138]
            model_layers_12_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[139]
            model_layers_12_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[140]
            model_layers_12_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[141]
            model_layers_12_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[142]
            model_layers_12_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[143]
            model_layers_12_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[144]
            model_layers_13_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[145]
            model_layers_13_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[146]
            model_layers_13_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[147]
            model_layers_13_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[148]
            model_layers_13_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[149]
            model_layers_13_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[150]
            model_layers_13_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[151]
            model_layers_13_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[152]
            model_layers_13_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[153]
            model_layers_13_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[154]
            model_layers_13_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[155]
            model_layers_14_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[156]
            model_layers_14_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[157]
            model_layers_14_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[158]
            model_layers_14_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[159]
            model_layers_14_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[160]
            model_layers_14_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[161]
            model_layers_14_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[162]
            model_layers_14_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[163]
            model_layers_14_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[164]
            model_layers_14_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[165]
            model_layers_14_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[166]
            model_layers_15_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[167]
            model_layers_15_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[168]
            model_layers_15_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[169]
            model_layers_15_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[170]
            model_layers_15_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[171]
            model_layers_15_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[172]
            model_layers_15_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[173]
            model_layers_15_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[174]
            model_layers_15_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[175]
            model_layers_15_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[176]
            model_layers_15_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[177]
            model_layers_16_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[178]
            model_layers_16_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[179]
            model_layers_16_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[180]
            model_layers_16_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[181]
            model_layers_16_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[182]
            model_layers_16_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[183]
            model_layers_16_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[184]
            model_layers_16_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[185]
            model_layers_16_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[186]
            model_layers_16_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[187]
            model_layers_16_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[188]
            model_layers_17_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[189]
            model_layers_17_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[190]
            model_layers_17_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[191]
            model_layers_17_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[192]
            model_layers_17_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[193]
            model_layers_17_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[194]
            model_layers_17_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[195]
            model_layers_17_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[196]
            model_layers_17_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[197]
            model_layers_17_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[198]
            model_layers_17_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[199]
            model_layers_18_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[200]
            model_layers_18_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[201]
            model_layers_18_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[202]
            model_layers_18_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[203]
            model_layers_18_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[204]
            model_layers_18_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[205]
            model_layers_18_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[206]
            model_layers_18_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[207]
            model_layers_18_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[208]
            model_layers_18_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[209]
            model_layers_18_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[210]
            model_layers_19_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[211]
            model_layers_19_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[212]
            model_layers_19_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[213]
            model_layers_19_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[214]
            model_layers_19_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[215]
            model_layers_19_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[216]
            model_layers_19_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[217]
            model_layers_19_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[218]
            model_layers_19_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[219]
            model_layers_19_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[220]
            model_layers_19_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[221]
            model_layers_20_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[222]
            model_layers_20_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[223]
            model_layers_20_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[224]
            model_layers_20_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[225]
            model_layers_20_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[226]
            model_layers_20_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[227]
            model_layers_20_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[228]
            model_layers_20_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[229]
            model_layers_20_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[230]
            model_layers_20_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[231]
            model_layers_20_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[232]
            model_layers_21_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[233]
            model_layers_21_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[234]
            model_layers_21_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[235]
            model_layers_21_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[236]
            model_layers_21_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[237]
            model_layers_21_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[238]
            model_layers_21_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[239]
            model_layers_21_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[240]
            model_layers_21_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[241]
            model_layers_21_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[242]
            model_layers_21_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[243]
            model_layers_22_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[244]
            model_layers_22_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[245]
            model_layers_22_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[246]
            model_layers_22_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[247]
            model_layers_22_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[248]
            model_layers_22_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[249]
            model_layers_22_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[250]
            model_layers_22_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[251]
            model_layers_22_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[252]
            model_layers_22_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[253]
            model_layers_22_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[254]
            model_layers_23_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[255]
            model_layers_23_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[256]
            model_layers_23_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[257]
            model_layers_23_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[258]
            model_layers_23_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[259]
            model_layers_23_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[260]
            model_layers_23_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[261]
            model_layers_23_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[262]
            model_layers_23_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[263]
            model_layers_23_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[264]
            model_layers_23_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[265]
            model_layers_24_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[266]
            model_layers_24_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[267]
            model_layers_24_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[268]
            model_layers_24_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[269]
            model_layers_24_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[270]
            model_layers_24_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[271]
            model_layers_24_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[272]
            model_layers_24_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[273]
            model_layers_24_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[274]
            model_layers_24_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[275]
            model_layers_24_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[276]
            model_layers_25_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[277]
            model_layers_25_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[278]
            model_layers_25_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[279]
            model_layers_25_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[280]
            model_layers_25_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[281]
            model_layers_25_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[282]
            model_layers_25_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[283]
            model_layers_25_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[284]
            model_layers_25_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[285]
            model_layers_25_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[286]
            model_layers_25_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[287]
            model_layers_26_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[288]
            model_layers_26_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[289]
            model_layers_26_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[290]
            model_layers_26_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[291]
            model_layers_26_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[292]
            model_layers_26_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[293]
            model_layers_26_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[294]
            model_layers_26_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[295]
            model_layers_26_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[296]
            model_layers_26_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[297]
            model_layers_26_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[298]
            model_layers_27_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[299]
            model_layers_27_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[300]
            model_layers_27_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[301]
            model_layers_27_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[302]
            model_layers_27_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[303]
            model_layers_27_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[304]
            model_layers_27_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[305]
            model_layers_27_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[306]
            model_layers_27_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[307]
            model_layers_27_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[308]
            model_layers_27_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[309]
            model_layers_28_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[310]
            model_layers_28_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[311]
            model_layers_28_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[312]
            model_layers_28_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[313]
            model_layers_28_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[314]
            model_layers_28_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[315]
            model_layers_28_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[316]
            model_layers_28_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[317]
            model_layers_28_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[318]
            model_layers_28_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[319]
            model_layers_28_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[320]
            model_layers_29_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[321]
            model_layers_29_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[322]
            model_layers_29_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[323]
            model_layers_29_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[324]
            model_layers_29_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[325]
            model_layers_29_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[326]
            model_layers_29_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[327]
            model_layers_29_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[328]
            model_layers_29_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[329]
            model_layers_29_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[330]
            model_layers_29_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[331]
            model_layers_30_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[332]
            model_layers_30_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[333]
            model_layers_30_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[334]
            model_layers_30_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[335]
            model_layers_30_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[336]
            model_layers_30_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[337]
            model_layers_30_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[338]
            model_layers_30_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[339]
            model_layers_30_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[340]
            model_layers_30_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[341]
            model_layers_30_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[342]
            model_layers_31_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[343]
            model_layers_31_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[344]
            model_layers_31_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[345]
            model_layers_31_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[346]
            model_layers_31_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[347]
            model_layers_31_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[348]
            model_layers_31_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[349]
            model_layers_31_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[350]
            model_layers_31_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[351]
            model_layers_31_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[352]
            model_layers_31_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[353]
            model_layers_32_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[354]
            model_layers_32_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[355]
            model_layers_32_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[356]
            model_layers_32_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[357]
            model_layers_32_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[358]
            model_layers_32_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[359]
            model_layers_32_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[360]
            model_layers_32_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[361]
            model_layers_32_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[362]
            model_layers_32_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[363]
            model_layers_32_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[364]
            model_layers_33_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[365]
            model_layers_33_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[366]
            model_layers_33_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[367]
            model_layers_33_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[368]
            model_layers_33_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[369]
            model_layers_33_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[370]
            model_layers_33_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[371]
            model_layers_33_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[372]
            model_layers_33_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[373]
            model_layers_33_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[374]
            model_layers_33_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[375]
            model_layers_34_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[376]
            model_layers_34_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[377]
            model_layers_34_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[378]
            model_layers_34_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[379]
            model_layers_34_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[380]
            model_layers_34_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[381]
            model_layers_34_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[382]
            model_layers_34_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[383]
            model_layers_34_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[384]
            model_layers_34_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[385]
            model_layers_34_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[386]
            model_layers_35_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[387]
            model_layers_35_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[388]
            model_layers_35_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[389]
            model_layers_35_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[390]
            model_layers_35_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[391]
            model_layers_35_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[392]
            model_layers_35_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[393]
            model_layers_35_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[394]
            model_layers_35_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[395]
            model_layers_35_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[396]
            model_layers_35_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[397]
            model_norm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[398]
            rms_norm73: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(input_embed, model_layers_0_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv183 = R.call_tir(cls.dequantize1, (model_layers_0_self_attn_c_attn_q_weight2, model_layers_0_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv435 = R.call_tir(cls.NT_matmul10, (rms_norm73, lv183), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add108: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv435, model_layers_0_self_attn_c_attn_bias2)
            reshape144: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add108, R.shape([1, 1, 20, 128]))
            reshape145: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape144, R.shape([1, 20, 128]))
            lv184 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape145), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape146: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv184, R.shape([1, 1, 16, 128]))
            reshape147: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape146, R.shape([1, 1, 2048]))
            lv185 = R.call_tir(cls.dequantize2, (model_layers_0_self_attn_o_proj_q_weight2, model_layers_0_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv436 = R.call_tir(cls.NT_matmul11, (reshape147, lv185), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv432 = R.call_tir(cls.fuse_add_norm_prefill, (lv436, input_embed, model_layers_0_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv433: R.Tensor((1, 1, 2048), dtype="float16") = lv432[1]
            rms_norm74: R.Tensor((1, 1, 2048), dtype="float16") = lv432[0]
            lv186 = R.call_tir(cls.dequantize3, (model_layers_0_mlp_gate_up_proj_q_weight2, model_layers_0_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv437 = R.call_tir(cls.NT_matmul12, (rms_norm74, lv186), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split36: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv437, indices_or_sections=2, axis=-1)
            split_036: R.Tensor((1, 1, 11008), dtype="float16") = split36[0]
            split_136: R.Tensor((1, 1, 11008), dtype="float16") = split36[1]
            silu36: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_036)
            mul36: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu36, split_136)
            lv187 = R.call_tir(cls.dequantize4, (model_layers_0_mlp_down_proj_q_weight2, model_layers_0_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv438 = R.call_tir(cls.NT_matmul13, (mul36, lv187), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv434 = R.call_tir(cls.fuse_add_norm_prefill, (lv438, lv433, model_layers_1_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv435_1: R.Tensor((1, 1, 2048), dtype="float16") = lv434[1]
            rms_norm75: R.Tensor((1, 1, 2048), dtype="float16") = lv434[0]
            lv188 = R.call_tir(cls.dequantize1, (model_layers_1_self_attn_c_attn_q_weight2, model_layers_1_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv439 = R.call_tir(cls.NT_matmul10, (rms_norm75, lv188), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add111: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv439, model_layers_1_self_attn_c_attn_bias2)
            reshape148: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add111, R.shape([1, 1, 20, 128]))
            reshape149: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape148, R.shape([1, 20, 128]))
            lv189 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape149), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape150: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv189, R.shape([1, 1, 16, 128]))
            reshape151: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape150, R.shape([1, 1, 2048]))
            lv190 = R.call_tir(cls.dequantize2, (model_layers_1_self_attn_o_proj_q_weight2, model_layers_1_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv440 = R.call_tir(cls.NT_matmul11, (reshape151, lv190), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv436_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv440, lv435_1, model_layers_1_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv437_1: R.Tensor((1, 1, 2048), dtype="float16") = lv436_1[1]
            rms_norm76: R.Tensor((1, 1, 2048), dtype="float16") = lv436_1[0]
            lv191 = R.call_tir(cls.dequantize3, (model_layers_1_mlp_gate_up_proj_q_weight2, model_layers_1_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv441 = R.call_tir(cls.NT_matmul12, (rms_norm76, lv191), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split37: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv441, indices_or_sections=2, axis=-1)
            split_037: R.Tensor((1, 1, 11008), dtype="float16") = split37[0]
            split_137: R.Tensor((1, 1, 11008), dtype="float16") = split37[1]
            silu37: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_037)
            mul37: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu37, split_137)
            lv192 = R.call_tir(cls.dequantize4, (model_layers_1_mlp_down_proj_q_weight2, model_layers_1_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv442 = R.call_tir(cls.NT_matmul13, (mul37, lv192), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv438_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv442, lv437_1, model_layers_2_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv439_1: R.Tensor((1, 1, 2048), dtype="float16") = lv438_1[1]
            rms_norm77: R.Tensor((1, 1, 2048), dtype="float16") = lv438_1[0]
            lv193 = R.call_tir(cls.dequantize1, (model_layers_2_self_attn_c_attn_q_weight2, model_layers_2_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv443 = R.call_tir(cls.NT_matmul10, (rms_norm77, lv193), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add114: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv443, model_layers_2_self_attn_c_attn_bias2)
            reshape152: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add114, R.shape([1, 1, 20, 128]))
            reshape153: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape152, R.shape([1, 20, 128]))
            lv194 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape153), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape154: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv194, R.shape([1, 1, 16, 128]))
            reshape155: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape154, R.shape([1, 1, 2048]))
            lv195 = R.call_tir(cls.dequantize2, (model_layers_2_self_attn_o_proj_q_weight2, model_layers_2_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv444 = R.call_tir(cls.NT_matmul11, (reshape155, lv195), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv440_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv444, lv439_1, model_layers_2_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv441_1: R.Tensor((1, 1, 2048), dtype="float16") = lv440_1[1]
            rms_norm78: R.Tensor((1, 1, 2048), dtype="float16") = lv440_1[0]
            lv196 = R.call_tir(cls.dequantize3, (model_layers_2_mlp_gate_up_proj_q_weight2, model_layers_2_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv445 = R.call_tir(cls.NT_matmul12, (rms_norm78, lv196), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split38: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv445, indices_or_sections=2, axis=-1)
            split_038: R.Tensor((1, 1, 11008), dtype="float16") = split38[0]
            split_138: R.Tensor((1, 1, 11008), dtype="float16") = split38[1]
            silu38: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_038)
            mul38: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu38, split_138)
            lv197 = R.call_tir(cls.dequantize4, (model_layers_2_mlp_down_proj_q_weight2, model_layers_2_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv446 = R.call_tir(cls.NT_matmul13, (mul38, lv197), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv442_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv446, lv441_1, model_layers_3_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv443_1: R.Tensor((1, 1, 2048), dtype="float16") = lv442_1[1]
            rms_norm79: R.Tensor((1, 1, 2048), dtype="float16") = lv442_1[0]
            lv198 = R.call_tir(cls.dequantize1, (model_layers_3_self_attn_c_attn_q_weight2, model_layers_3_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv447 = R.call_tir(cls.NT_matmul10, (rms_norm79, lv198), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add117: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv447, model_layers_3_self_attn_c_attn_bias2)
            reshape156: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add117, R.shape([1, 1, 20, 128]))
            reshape157: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape156, R.shape([1, 20, 128]))
            lv199 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape157), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape158: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv199, R.shape([1, 1, 16, 128]))
            reshape159: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape158, R.shape([1, 1, 2048]))
            lv200 = R.call_tir(cls.dequantize2, (model_layers_3_self_attn_o_proj_q_weight2, model_layers_3_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv448 = R.call_tir(cls.NT_matmul11, (reshape159, lv200), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv444_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv448, lv443_1, model_layers_3_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv445_1: R.Tensor((1, 1, 2048), dtype="float16") = lv444_1[1]
            rms_norm80: R.Tensor((1, 1, 2048), dtype="float16") = lv444_1[0]
            lv201 = R.call_tir(cls.dequantize3, (model_layers_3_mlp_gate_up_proj_q_weight2, model_layers_3_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv449 = R.call_tir(cls.NT_matmul12, (rms_norm80, lv201), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split39: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv449, indices_or_sections=2, axis=-1)
            split_039: R.Tensor((1, 1, 11008), dtype="float16") = split39[0]
            split_139: R.Tensor((1, 1, 11008), dtype="float16") = split39[1]
            silu39: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_039)
            mul39: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu39, split_139)
            lv202 = R.call_tir(cls.dequantize4, (model_layers_3_mlp_down_proj_q_weight2, model_layers_3_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv450 = R.call_tir(cls.NT_matmul13, (mul39, lv202), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv446_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv450, lv445_1, model_layers_4_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv447_1: R.Tensor((1, 1, 2048), dtype="float16") = lv446_1[1]
            rms_norm81: R.Tensor((1, 1, 2048), dtype="float16") = lv446_1[0]
            lv203 = R.call_tir(cls.dequantize1, (model_layers_4_self_attn_c_attn_q_weight2, model_layers_4_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv451 = R.call_tir(cls.NT_matmul10, (rms_norm81, lv203), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add120: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv451, model_layers_4_self_attn_c_attn_bias2)
            reshape160: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add120, R.shape([1, 1, 20, 128]))
            reshape161: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape160, R.shape([1, 20, 128]))
            lv204 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape161), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape162: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv204, R.shape([1, 1, 16, 128]))
            reshape163: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape162, R.shape([1, 1, 2048]))
            lv205 = R.call_tir(cls.dequantize2, (model_layers_4_self_attn_o_proj_q_weight2, model_layers_4_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv452 = R.call_tir(cls.NT_matmul11, (reshape163, lv205), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv448_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv452, lv447_1, model_layers_4_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv449_1: R.Tensor((1, 1, 2048), dtype="float16") = lv448_1[1]
            rms_norm82: R.Tensor((1, 1, 2048), dtype="float16") = lv448_1[0]
            lv206 = R.call_tir(cls.dequantize3, (model_layers_4_mlp_gate_up_proj_q_weight2, model_layers_4_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv453 = R.call_tir(cls.NT_matmul12, (rms_norm82, lv206), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split40: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv453, indices_or_sections=2, axis=-1)
            split_040: R.Tensor((1, 1, 11008), dtype="float16") = split40[0]
            split_140: R.Tensor((1, 1, 11008), dtype="float16") = split40[1]
            silu40: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_040)
            mul40: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu40, split_140)
            lv207 = R.call_tir(cls.dequantize4, (model_layers_4_mlp_down_proj_q_weight2, model_layers_4_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv454 = R.call_tir(cls.NT_matmul13, (mul40, lv207), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv450_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv454, lv449_1, model_layers_5_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv451_1: R.Tensor((1, 1, 2048), dtype="float16") = lv450_1[1]
            rms_norm83: R.Tensor((1, 1, 2048), dtype="float16") = lv450_1[0]
            lv208 = R.call_tir(cls.dequantize1, (model_layers_5_self_attn_c_attn_q_weight2, model_layers_5_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv455 = R.call_tir(cls.NT_matmul10, (rms_norm83, lv208), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add123: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv455, model_layers_5_self_attn_c_attn_bias2)
            reshape164: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add123, R.shape([1, 1, 20, 128]))
            reshape165: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape164, R.shape([1, 20, 128]))
            lv209 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape165), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape166: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv209, R.shape([1, 1, 16, 128]))
            reshape167: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape166, R.shape([1, 1, 2048]))
            lv210 = R.call_tir(cls.dequantize2, (model_layers_5_self_attn_o_proj_q_weight2, model_layers_5_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv456 = R.call_tir(cls.NT_matmul11, (reshape167, lv210), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv452_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv456, lv451_1, model_layers_5_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv453_1: R.Tensor((1, 1, 2048), dtype="float16") = lv452_1[1]
            rms_norm84: R.Tensor((1, 1, 2048), dtype="float16") = lv452_1[0]
            lv211 = R.call_tir(cls.dequantize3, (model_layers_5_mlp_gate_up_proj_q_weight2, model_layers_5_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv457 = R.call_tir(cls.NT_matmul12, (rms_norm84, lv211), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split41: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv457, indices_or_sections=2, axis=-1)
            split_041: R.Tensor((1, 1, 11008), dtype="float16") = split41[0]
            split_141: R.Tensor((1, 1, 11008), dtype="float16") = split41[1]
            silu41: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_041)
            mul41: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu41, split_141)
            lv212 = R.call_tir(cls.dequantize4, (model_layers_5_mlp_down_proj_q_weight2, model_layers_5_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv458 = R.call_tir(cls.NT_matmul13, (mul41, lv212), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv454_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv458, lv453_1, model_layers_6_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv455_1: R.Tensor((1, 1, 2048), dtype="float16") = lv454_1[1]
            rms_norm85: R.Tensor((1, 1, 2048), dtype="float16") = lv454_1[0]
            lv213 = R.call_tir(cls.dequantize1, (model_layers_6_self_attn_c_attn_q_weight2, model_layers_6_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv459 = R.call_tir(cls.NT_matmul10, (rms_norm85, lv213), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add126: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv459, model_layers_6_self_attn_c_attn_bias2)
            reshape168: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add126, R.shape([1, 1, 20, 128]))
            reshape169: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape168, R.shape([1, 20, 128]))
            lv214 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape169), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape170: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv214, R.shape([1, 1, 16, 128]))
            reshape171: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape170, R.shape([1, 1, 2048]))
            lv215 = R.call_tir(cls.dequantize2, (model_layers_6_self_attn_o_proj_q_weight2, model_layers_6_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv460 = R.call_tir(cls.NT_matmul11, (reshape171, lv215), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv456_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv460, lv455_1, model_layers_6_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv457_1: R.Tensor((1, 1, 2048), dtype="float16") = lv456_1[1]
            rms_norm86: R.Tensor((1, 1, 2048), dtype="float16") = lv456_1[0]
            lv216 = R.call_tir(cls.dequantize3, (model_layers_6_mlp_gate_up_proj_q_weight2, model_layers_6_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv461 = R.call_tir(cls.NT_matmul12, (rms_norm86, lv216), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split42: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv461, indices_or_sections=2, axis=-1)
            split_042: R.Tensor((1, 1, 11008), dtype="float16") = split42[0]
            split_142: R.Tensor((1, 1, 11008), dtype="float16") = split42[1]
            silu42: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_042)
            mul42: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu42, split_142)
            lv217 = R.call_tir(cls.dequantize4, (model_layers_6_mlp_down_proj_q_weight2, model_layers_6_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv462 = R.call_tir(cls.NT_matmul13, (mul42, lv217), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv458_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv462, lv457_1, model_layers_7_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv459_1: R.Tensor((1, 1, 2048), dtype="float16") = lv458_1[1]
            rms_norm87: R.Tensor((1, 1, 2048), dtype="float16") = lv458_1[0]
            lv218 = R.call_tir(cls.dequantize1, (model_layers_7_self_attn_c_attn_q_weight2, model_layers_7_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv463 = R.call_tir(cls.NT_matmul10, (rms_norm87, lv218), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add129: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv463, model_layers_7_self_attn_c_attn_bias2)
            reshape172: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add129, R.shape([1, 1, 20, 128]))
            reshape173: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape172, R.shape([1, 20, 128]))
            lv219 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape173), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape174: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv219, R.shape([1, 1, 16, 128]))
            reshape175: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape174, R.shape([1, 1, 2048]))
            lv220 = R.call_tir(cls.dequantize2, (model_layers_7_self_attn_o_proj_q_weight2, model_layers_7_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv464 = R.call_tir(cls.NT_matmul11, (reshape175, lv220), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv460_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv464, lv459_1, model_layers_7_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv461_1: R.Tensor((1, 1, 2048), dtype="float16") = lv460_1[1]
            rms_norm88: R.Tensor((1, 1, 2048), dtype="float16") = lv460_1[0]
            lv221 = R.call_tir(cls.dequantize3, (model_layers_7_mlp_gate_up_proj_q_weight2, model_layers_7_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv465 = R.call_tir(cls.NT_matmul12, (rms_norm88, lv221), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split43: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv465, indices_or_sections=2, axis=-1)
            split_043: R.Tensor((1, 1, 11008), dtype="float16") = split43[0]
            split_143: R.Tensor((1, 1, 11008), dtype="float16") = split43[1]
            silu43: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_043)
            mul43: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu43, split_143)
            lv222 = R.call_tir(cls.dequantize4, (model_layers_7_mlp_down_proj_q_weight2, model_layers_7_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv466 = R.call_tir(cls.NT_matmul13, (mul43, lv222), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv462_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv466, lv461_1, model_layers_8_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv463_1: R.Tensor((1, 1, 2048), dtype="float16") = lv462_1[1]
            rms_norm89: R.Tensor((1, 1, 2048), dtype="float16") = lv462_1[0]
            lv223 = R.call_tir(cls.dequantize1, (model_layers_8_self_attn_c_attn_q_weight2, model_layers_8_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv467 = R.call_tir(cls.NT_matmul10, (rms_norm89, lv223), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add132: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv467, model_layers_8_self_attn_c_attn_bias2)
            reshape176: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add132, R.shape([1, 1, 20, 128]))
            reshape177: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape176, R.shape([1, 20, 128]))
            lv224 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape177), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape178: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv224, R.shape([1, 1, 16, 128]))
            reshape179: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape178, R.shape([1, 1, 2048]))
            lv225 = R.call_tir(cls.dequantize2, (model_layers_8_self_attn_o_proj_q_weight2, model_layers_8_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv468 = R.call_tir(cls.NT_matmul11, (reshape179, lv225), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv464_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv468, lv463_1, model_layers_8_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv465_1: R.Tensor((1, 1, 2048), dtype="float16") = lv464_1[1]
            rms_norm90: R.Tensor((1, 1, 2048), dtype="float16") = lv464_1[0]
            lv226 = R.call_tir(cls.dequantize3, (model_layers_8_mlp_gate_up_proj_q_weight2, model_layers_8_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv469 = R.call_tir(cls.NT_matmul12, (rms_norm90, lv226), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split44: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv469, indices_or_sections=2, axis=-1)
            split_044: R.Tensor((1, 1, 11008), dtype="float16") = split44[0]
            split_144: R.Tensor((1, 1, 11008), dtype="float16") = split44[1]
            silu44: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_044)
            mul44: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu44, split_144)
            lv227 = R.call_tir(cls.dequantize4, (model_layers_8_mlp_down_proj_q_weight2, model_layers_8_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv470 = R.call_tir(cls.NT_matmul13, (mul44, lv227), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv466_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv470, lv465_1, model_layers_9_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv467_1: R.Tensor((1, 1, 2048), dtype="float16") = lv466_1[1]
            rms_norm91: R.Tensor((1, 1, 2048), dtype="float16") = lv466_1[0]
            lv228 = R.call_tir(cls.dequantize1, (model_layers_9_self_attn_c_attn_q_weight2, model_layers_9_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv471 = R.call_tir(cls.NT_matmul10, (rms_norm91, lv228), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add135: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv471, model_layers_9_self_attn_c_attn_bias2)
            reshape180: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add135, R.shape([1, 1, 20, 128]))
            reshape181: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape180, R.shape([1, 20, 128]))
            lv229 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape181), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape182: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv229, R.shape([1, 1, 16, 128]))
            reshape183: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape182, R.shape([1, 1, 2048]))
            lv230 = R.call_tir(cls.dequantize2, (model_layers_9_self_attn_o_proj_q_weight2, model_layers_9_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv472 = R.call_tir(cls.NT_matmul11, (reshape183, lv230), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv468_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv472, lv467_1, model_layers_9_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv469_1: R.Tensor((1, 1, 2048), dtype="float16") = lv468_1[1]
            rms_norm92: R.Tensor((1, 1, 2048), dtype="float16") = lv468_1[0]
            lv231 = R.call_tir(cls.dequantize3, (model_layers_9_mlp_gate_up_proj_q_weight2, model_layers_9_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv473 = R.call_tir(cls.NT_matmul12, (rms_norm92, lv231), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split45: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv473, indices_or_sections=2, axis=-1)
            split_045: R.Tensor((1, 1, 11008), dtype="float16") = split45[0]
            split_145: R.Tensor((1, 1, 11008), dtype="float16") = split45[1]
            silu45: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_045)
            mul45: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu45, split_145)
            lv232 = R.call_tir(cls.dequantize4, (model_layers_9_mlp_down_proj_q_weight2, model_layers_9_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv474 = R.call_tir(cls.NT_matmul13, (mul45, lv232), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv470_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv474, lv469_1, model_layers_10_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv471_1: R.Tensor((1, 1, 2048), dtype="float16") = lv470_1[1]
            rms_norm93: R.Tensor((1, 1, 2048), dtype="float16") = lv470_1[0]
            lv233 = R.call_tir(cls.dequantize1, (model_layers_10_self_attn_c_attn_q_weight2, model_layers_10_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv475 = R.call_tir(cls.NT_matmul10, (rms_norm93, lv233), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add138: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv475, model_layers_10_self_attn_c_attn_bias2)
            reshape184: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add138, R.shape([1, 1, 20, 128]))
            reshape185: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape184, R.shape([1, 20, 128]))
            lv234 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape185), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape186: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv234, R.shape([1, 1, 16, 128]))
            reshape187: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape186, R.shape([1, 1, 2048]))
            lv235 = R.call_tir(cls.dequantize2, (model_layers_10_self_attn_o_proj_q_weight2, model_layers_10_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv476 = R.call_tir(cls.NT_matmul11, (reshape187, lv235), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv472_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv476, lv471_1, model_layers_10_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv473_1: R.Tensor((1, 1, 2048), dtype="float16") = lv472_1[1]
            rms_norm94: R.Tensor((1, 1, 2048), dtype="float16") = lv472_1[0]
            lv236 = R.call_tir(cls.dequantize3, (model_layers_10_mlp_gate_up_proj_q_weight2, model_layers_10_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv477 = R.call_tir(cls.NT_matmul12, (rms_norm94, lv236), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split46: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv477, indices_or_sections=2, axis=-1)
            split_046: R.Tensor((1, 1, 11008), dtype="float16") = split46[0]
            split_146: R.Tensor((1, 1, 11008), dtype="float16") = split46[1]
            silu46: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_046)
            mul46: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu46, split_146)
            lv237 = R.call_tir(cls.dequantize4, (model_layers_10_mlp_down_proj_q_weight2, model_layers_10_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv478 = R.call_tir(cls.NT_matmul13, (mul46, lv237), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv474_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv478, lv473_1, model_layers_11_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv475_1: R.Tensor((1, 1, 2048), dtype="float16") = lv474_1[1]
            rms_norm95: R.Tensor((1, 1, 2048), dtype="float16") = lv474_1[0]
            lv238 = R.call_tir(cls.dequantize1, (model_layers_11_self_attn_c_attn_q_weight2, model_layers_11_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv479 = R.call_tir(cls.NT_matmul10, (rms_norm95, lv238), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add141: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv479, model_layers_11_self_attn_c_attn_bias2)
            reshape188: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add141, R.shape([1, 1, 20, 128]))
            reshape189: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape188, R.shape([1, 20, 128]))
            lv239 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape189), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape190: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv239, R.shape([1, 1, 16, 128]))
            reshape191: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape190, R.shape([1, 1, 2048]))
            lv240 = R.call_tir(cls.dequantize2, (model_layers_11_self_attn_o_proj_q_weight2, model_layers_11_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv480 = R.call_tir(cls.NT_matmul11, (reshape191, lv240), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv476_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv480, lv475_1, model_layers_11_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv477_1: R.Tensor((1, 1, 2048), dtype="float16") = lv476_1[1]
            rms_norm96: R.Tensor((1, 1, 2048), dtype="float16") = lv476_1[0]
            lv241 = R.call_tir(cls.dequantize3, (model_layers_11_mlp_gate_up_proj_q_weight2, model_layers_11_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv481 = R.call_tir(cls.NT_matmul12, (rms_norm96, lv241), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split47: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv481, indices_or_sections=2, axis=-1)
            split_047: R.Tensor((1, 1, 11008), dtype="float16") = split47[0]
            split_147: R.Tensor((1, 1, 11008), dtype="float16") = split47[1]
            silu47: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_047)
            mul47: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu47, split_147)
            lv242 = R.call_tir(cls.dequantize4, (model_layers_11_mlp_down_proj_q_weight2, model_layers_11_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv482 = R.call_tir(cls.NT_matmul13, (mul47, lv242), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv478_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv482, lv477_1, model_layers_12_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv479_1: R.Tensor((1, 1, 2048), dtype="float16") = lv478_1[1]
            rms_norm97: R.Tensor((1, 1, 2048), dtype="float16") = lv478_1[0]
            lv243 = R.call_tir(cls.dequantize1, (model_layers_12_self_attn_c_attn_q_weight2, model_layers_12_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv483 = R.call_tir(cls.NT_matmul10, (rms_norm97, lv243), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add144: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv483, model_layers_12_self_attn_c_attn_bias2)
            reshape192: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add144, R.shape([1, 1, 20, 128]))
            reshape193: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape192, R.shape([1, 20, 128]))
            lv244 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape193), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape194: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv244, R.shape([1, 1, 16, 128]))
            reshape195: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape194, R.shape([1, 1, 2048]))
            lv245 = R.call_tir(cls.dequantize2, (model_layers_12_self_attn_o_proj_q_weight2, model_layers_12_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv484 = R.call_tir(cls.NT_matmul11, (reshape195, lv245), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv480_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv484, lv479_1, model_layers_12_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv481_1: R.Tensor((1, 1, 2048), dtype="float16") = lv480_1[1]
            rms_norm98: R.Tensor((1, 1, 2048), dtype="float16") = lv480_1[0]
            lv246 = R.call_tir(cls.dequantize3, (model_layers_12_mlp_gate_up_proj_q_weight2, model_layers_12_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv485 = R.call_tir(cls.NT_matmul12, (rms_norm98, lv246), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split48: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv485, indices_or_sections=2, axis=-1)
            split_048: R.Tensor((1, 1, 11008), dtype="float16") = split48[0]
            split_148: R.Tensor((1, 1, 11008), dtype="float16") = split48[1]
            silu48: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_048)
            mul48: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu48, split_148)
            lv247 = R.call_tir(cls.dequantize4, (model_layers_12_mlp_down_proj_q_weight2, model_layers_12_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv486 = R.call_tir(cls.NT_matmul13, (mul48, lv247), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv482_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv486, lv481_1, model_layers_13_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv483_1: R.Tensor((1, 1, 2048), dtype="float16") = lv482_1[1]
            rms_norm99: R.Tensor((1, 1, 2048), dtype="float16") = lv482_1[0]
            lv248 = R.call_tir(cls.dequantize1, (model_layers_13_self_attn_c_attn_q_weight2, model_layers_13_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv487 = R.call_tir(cls.NT_matmul10, (rms_norm99, lv248), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add147: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv487, model_layers_13_self_attn_c_attn_bias2)
            reshape196: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add147, R.shape([1, 1, 20, 128]))
            reshape197: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape196, R.shape([1, 20, 128]))
            lv249 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape197), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape198: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv249, R.shape([1, 1, 16, 128]))
            reshape199: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape198, R.shape([1, 1, 2048]))
            lv250 = R.call_tir(cls.dequantize2, (model_layers_13_self_attn_o_proj_q_weight2, model_layers_13_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv488 = R.call_tir(cls.NT_matmul11, (reshape199, lv250), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv484_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv488, lv483_1, model_layers_13_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv485_1: R.Tensor((1, 1, 2048), dtype="float16") = lv484_1[1]
            rms_norm100: R.Tensor((1, 1, 2048), dtype="float16") = lv484_1[0]
            lv251 = R.call_tir(cls.dequantize3, (model_layers_13_mlp_gate_up_proj_q_weight2, model_layers_13_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv489 = R.call_tir(cls.NT_matmul12, (rms_norm100, lv251), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split49: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv489, indices_or_sections=2, axis=-1)
            split_049: R.Tensor((1, 1, 11008), dtype="float16") = split49[0]
            split_149: R.Tensor((1, 1, 11008), dtype="float16") = split49[1]
            silu49: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_049)
            mul49: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu49, split_149)
            lv252 = R.call_tir(cls.dequantize4, (model_layers_13_mlp_down_proj_q_weight2, model_layers_13_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv490 = R.call_tir(cls.NT_matmul13, (mul49, lv252), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv486_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv490, lv485_1, model_layers_14_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv487_1: R.Tensor((1, 1, 2048), dtype="float16") = lv486_1[1]
            rms_norm101: R.Tensor((1, 1, 2048), dtype="float16") = lv486_1[0]
            lv253 = R.call_tir(cls.dequantize1, (model_layers_14_self_attn_c_attn_q_weight2, model_layers_14_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv491 = R.call_tir(cls.NT_matmul10, (rms_norm101, lv253), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add150: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv491, model_layers_14_self_attn_c_attn_bias2)
            reshape200: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add150, R.shape([1, 1, 20, 128]))
            reshape201: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape200, R.shape([1, 20, 128]))
            lv254 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape201), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape202: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv254, R.shape([1, 1, 16, 128]))
            reshape203: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape202, R.shape([1, 1, 2048]))
            lv255 = R.call_tir(cls.dequantize2, (model_layers_14_self_attn_o_proj_q_weight2, model_layers_14_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv492 = R.call_tir(cls.NT_matmul11, (reshape203, lv255), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv488_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv492, lv487_1, model_layers_14_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv489_1: R.Tensor((1, 1, 2048), dtype="float16") = lv488_1[1]
            rms_norm102: R.Tensor((1, 1, 2048), dtype="float16") = lv488_1[0]
            lv256 = R.call_tir(cls.dequantize3, (model_layers_14_mlp_gate_up_proj_q_weight2, model_layers_14_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv493 = R.call_tir(cls.NT_matmul12, (rms_norm102, lv256), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split50: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv493, indices_or_sections=2, axis=-1)
            split_050: R.Tensor((1, 1, 11008), dtype="float16") = split50[0]
            split_150: R.Tensor((1, 1, 11008), dtype="float16") = split50[1]
            silu50: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_050)
            mul50: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu50, split_150)
            lv257 = R.call_tir(cls.dequantize4, (model_layers_14_mlp_down_proj_q_weight2, model_layers_14_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv494 = R.call_tir(cls.NT_matmul13, (mul50, lv257), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv490_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv494, lv489_1, model_layers_15_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv491_1: R.Tensor((1, 1, 2048), dtype="float16") = lv490_1[1]
            rms_norm103: R.Tensor((1, 1, 2048), dtype="float16") = lv490_1[0]
            lv258 = R.call_tir(cls.dequantize1, (model_layers_15_self_attn_c_attn_q_weight2, model_layers_15_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv495 = R.call_tir(cls.NT_matmul10, (rms_norm103, lv258), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add153: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv495, model_layers_15_self_attn_c_attn_bias2)
            reshape204: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add153, R.shape([1, 1, 20, 128]))
            reshape205: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape204, R.shape([1, 20, 128]))
            lv259 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape205), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape206: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv259, R.shape([1, 1, 16, 128]))
            reshape207: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape206, R.shape([1, 1, 2048]))
            lv260 = R.call_tir(cls.dequantize2, (model_layers_15_self_attn_o_proj_q_weight2, model_layers_15_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv496 = R.call_tir(cls.NT_matmul11, (reshape207, lv260), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv492_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv496, lv491_1, model_layers_15_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv493_1: R.Tensor((1, 1, 2048), dtype="float16") = lv492_1[1]
            rms_norm104: R.Tensor((1, 1, 2048), dtype="float16") = lv492_1[0]
            lv261 = R.call_tir(cls.dequantize3, (model_layers_15_mlp_gate_up_proj_q_weight2, model_layers_15_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv497 = R.call_tir(cls.NT_matmul12, (rms_norm104, lv261), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split51: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv497, indices_or_sections=2, axis=-1)
            split_051: R.Tensor((1, 1, 11008), dtype="float16") = split51[0]
            split_151: R.Tensor((1, 1, 11008), dtype="float16") = split51[1]
            silu51: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_051)
            mul51: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu51, split_151)
            lv262 = R.call_tir(cls.dequantize4, (model_layers_15_mlp_down_proj_q_weight2, model_layers_15_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv498 = R.call_tir(cls.NT_matmul13, (mul51, lv262), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv494_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv498, lv493_1, model_layers_16_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv495_1: R.Tensor((1, 1, 2048), dtype="float16") = lv494_1[1]
            rms_norm105: R.Tensor((1, 1, 2048), dtype="float16") = lv494_1[0]
            lv263 = R.call_tir(cls.dequantize1, (model_layers_16_self_attn_c_attn_q_weight2, model_layers_16_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv499 = R.call_tir(cls.NT_matmul10, (rms_norm105, lv263), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add156: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv499, model_layers_16_self_attn_c_attn_bias2)
            reshape208: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add156, R.shape([1, 1, 20, 128]))
            reshape209: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape208, R.shape([1, 20, 128]))
            lv264 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1.0)), reshape209), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape210: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv264, R.shape([1, 1, 16, 128]))
            reshape211: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape210, R.shape([1, 1, 2048]))
            lv265 = R.call_tir(cls.dequantize2, (model_layers_16_self_attn_o_proj_q_weight2, model_layers_16_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv500 = R.call_tir(cls.NT_matmul11, (reshape211, lv265), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv496_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv500, lv495_1, model_layers_16_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv497_1: R.Tensor((1, 1, 2048), dtype="float16") = lv496_1[1]
            rms_norm106: R.Tensor((1, 1, 2048), dtype="float16") = lv496_1[0]
            lv266 = R.call_tir(cls.dequantize3, (model_layers_16_mlp_gate_up_proj_q_weight2, model_layers_16_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv501 = R.call_tir(cls.NT_matmul12, (rms_norm106, lv266), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split52: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv501, indices_or_sections=2, axis=-1)
            split_052: R.Tensor((1, 1, 11008), dtype="float16") = split52[0]
            split_152: R.Tensor((1, 1, 11008), dtype="float16") = split52[1]
            silu52: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_052)
            mul52: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu52, split_152)
            lv267 = R.call_tir(cls.dequantize4, (model_layers_16_mlp_down_proj_q_weight2, model_layers_16_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv502 = R.call_tir(cls.NT_matmul13, (mul52, lv267), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv498_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv502, lv497_1, model_layers_17_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv499_1: R.Tensor((1, 1, 2048), dtype="float16") = lv498_1[1]
            rms_norm107: R.Tensor((1, 1, 2048), dtype="float16") = lv498_1[0]
            lv268 = R.call_tir(cls.dequantize1, (model_layers_17_self_attn_c_attn_q_weight2, model_layers_17_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv503 = R.call_tir(cls.NT_matmul10, (rms_norm107, lv268), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add159: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv503, model_layers_17_self_attn_c_attn_bias2)
            reshape212: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add159, R.shape([1, 1, 20, 128]))
            reshape213: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape212, R.shape([1, 20, 128]))
            lv269 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1.0)), reshape213), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape214: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv269, R.shape([1, 1, 16, 128]))
            reshape215: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape214, R.shape([1, 1, 2048]))
            lv270 = R.call_tir(cls.dequantize2, (model_layers_17_self_attn_o_proj_q_weight2, model_layers_17_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv504 = R.call_tir(cls.NT_matmul11, (reshape215, lv270), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv500_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv504, lv499_1, model_layers_17_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv501_1: R.Tensor((1, 1, 2048), dtype="float16") = lv500_1[1]
            rms_norm108: R.Tensor((1, 1, 2048), dtype="float16") = lv500_1[0]
            lv271 = R.call_tir(cls.dequantize3, (model_layers_17_mlp_gate_up_proj_q_weight2, model_layers_17_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv505 = R.call_tir(cls.NT_matmul12, (rms_norm108, lv271), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split53: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv505, indices_or_sections=2, axis=-1)
            split_053: R.Tensor((1, 1, 11008), dtype="float16") = split53[0]
            split_153: R.Tensor((1, 1, 11008), dtype="float16") = split53[1]
            silu53: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_053)
            mul53: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu53, split_153)
            lv272 = R.call_tir(cls.dequantize4, (model_layers_17_mlp_down_proj_q_weight2, model_layers_17_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv506 = R.call_tir(cls.NT_matmul13, (mul53, lv272), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv502_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv506, lv501_1, model_layers_18_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv503_1: R.Tensor((1, 1, 2048), dtype="float16") = lv502_1[1]
            rms_norm109: R.Tensor((1, 1, 2048), dtype="float16") = lv502_1[0]
            lv273 = R.call_tir(cls.dequantize1, (model_layers_18_self_attn_c_attn_q_weight2, model_layers_18_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv507 = R.call_tir(cls.NT_matmul10, (rms_norm109, lv273), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add162: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv507, model_layers_18_self_attn_c_attn_bias2)
            reshape216: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add162, R.shape([1, 1, 20, 128]))
            reshape217: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape216, R.shape([1, 20, 128]))
            lv274 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1.0)), reshape217), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape218: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv274, R.shape([1, 1, 16, 128]))
            reshape219: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape218, R.shape([1, 1, 2048]))
            lv275 = R.call_tir(cls.dequantize2, (model_layers_18_self_attn_o_proj_q_weight2, model_layers_18_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv508 = R.call_tir(cls.NT_matmul11, (reshape219, lv275), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv504_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv508, lv503_1, model_layers_18_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv505_1: R.Tensor((1, 1, 2048), dtype="float16") = lv504_1[1]
            rms_norm110: R.Tensor((1, 1, 2048), dtype="float16") = lv504_1[0]
            lv276 = R.call_tir(cls.dequantize3, (model_layers_18_mlp_gate_up_proj_q_weight2, model_layers_18_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv509 = R.call_tir(cls.NT_matmul12, (rms_norm110, lv276), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split54: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv509, indices_or_sections=2, axis=-1)
            split_054: R.Tensor((1, 1, 11008), dtype="float16") = split54[0]
            split_154: R.Tensor((1, 1, 11008), dtype="float16") = split54[1]
            silu54: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_054)
            mul54: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu54, split_154)
            lv277 = R.call_tir(cls.dequantize4, (model_layers_18_mlp_down_proj_q_weight2, model_layers_18_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv510 = R.call_tir(cls.NT_matmul13, (mul54, lv277), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv506_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv510, lv505_1, model_layers_19_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv507_1: R.Tensor((1, 1, 2048), dtype="float16") = lv506_1[1]
            rms_norm111: R.Tensor((1, 1, 2048), dtype="float16") = lv506_1[0]
            lv278 = R.call_tir(cls.dequantize1, (model_layers_19_self_attn_c_attn_q_weight2, model_layers_19_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv511 = R.call_tir(cls.NT_matmul10, (rms_norm111, lv278), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add165: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv511, model_layers_19_self_attn_c_attn_bias2)
            reshape220: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add165, R.shape([1, 1, 20, 128]))
            reshape221: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape220, R.shape([1, 20, 128]))
            lv279 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1.0)), reshape221), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape222: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv279, R.shape([1, 1, 16, 128]))
            reshape223: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape222, R.shape([1, 1, 2048]))
            lv280 = R.call_tir(cls.dequantize2, (model_layers_19_self_attn_o_proj_q_weight2, model_layers_19_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv512 = R.call_tir(cls.NT_matmul11, (reshape223, lv280), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv508_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv512, lv507_1, model_layers_19_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv509_1: R.Tensor((1, 1, 2048), dtype="float16") = lv508_1[1]
            rms_norm112: R.Tensor((1, 1, 2048), dtype="float16") = lv508_1[0]
            lv281 = R.call_tir(cls.dequantize3, (model_layers_19_mlp_gate_up_proj_q_weight2, model_layers_19_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv513 = R.call_tir(cls.NT_matmul12, (rms_norm112, lv281), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split55: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv513, indices_or_sections=2, axis=-1)
            split_055: R.Tensor((1, 1, 11008), dtype="float16") = split55[0]
            split_155: R.Tensor((1, 1, 11008), dtype="float16") = split55[1]
            silu55: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_055)
            mul55: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu55, split_155)
            lv282 = R.call_tir(cls.dequantize4, (model_layers_19_mlp_down_proj_q_weight2, model_layers_19_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv514 = R.call_tir(cls.NT_matmul13, (mul55, lv282), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv510_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv514, lv509_1, model_layers_20_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv511_1: R.Tensor((1, 1, 2048), dtype="float16") = lv510_1[1]
            rms_norm113: R.Tensor((1, 1, 2048), dtype="float16") = lv510_1[0]
            lv283 = R.call_tir(cls.dequantize1, (model_layers_20_self_attn_c_attn_q_weight2, model_layers_20_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv515 = R.call_tir(cls.NT_matmul10, (rms_norm113, lv283), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add168: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv515, model_layers_20_self_attn_c_attn_bias2)
            reshape224: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add168, R.shape([1, 1, 20, 128]))
            reshape225: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape224, R.shape([1, 20, 128]))
            lv284 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1.0)), reshape225), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape226: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv284, R.shape([1, 1, 16, 128]))
            reshape227: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape226, R.shape([1, 1, 2048]))
            lv285 = R.call_tir(cls.dequantize2, (model_layers_20_self_attn_o_proj_q_weight2, model_layers_20_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv516 = R.call_tir(cls.NT_matmul11, (reshape227, lv285), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv512_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv516, lv511_1, model_layers_20_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv513_1: R.Tensor((1, 1, 2048), dtype="float16") = lv512_1[1]
            rms_norm114: R.Tensor((1, 1, 2048), dtype="float16") = lv512_1[0]
            lv286 = R.call_tir(cls.dequantize3, (model_layers_20_mlp_gate_up_proj_q_weight2, model_layers_20_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv517 = R.call_tir(cls.NT_matmul12, (rms_norm114, lv286), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split56: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv517, indices_or_sections=2, axis=-1)
            split_056: R.Tensor((1, 1, 11008), dtype="float16") = split56[0]
            split_156: R.Tensor((1, 1, 11008), dtype="float16") = split56[1]
            silu56: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_056)
            mul56: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu56, split_156)
            lv287 = R.call_tir(cls.dequantize4, (model_layers_20_mlp_down_proj_q_weight2, model_layers_20_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv518 = R.call_tir(cls.NT_matmul13, (mul56, lv287), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv514_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv518, lv513_1, model_layers_21_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv515_1: R.Tensor((1, 1, 2048), dtype="float16") = lv514_1[1]
            rms_norm115: R.Tensor((1, 1, 2048), dtype="float16") = lv514_1[0]
            lv288 = R.call_tir(cls.dequantize1, (model_layers_21_self_attn_c_attn_q_weight2, model_layers_21_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv519 = R.call_tir(cls.NT_matmul10, (rms_norm115, lv288), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add171: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv519, model_layers_21_self_attn_c_attn_bias2)
            reshape228: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add171, R.shape([1, 1, 20, 128]))
            reshape229: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape228, R.shape([1, 20, 128]))
            lv289 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1.0)), reshape229), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape230: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv289, R.shape([1, 1, 16, 128]))
            reshape231: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape230, R.shape([1, 1, 2048]))
            lv290 = R.call_tir(cls.dequantize2, (model_layers_21_self_attn_o_proj_q_weight2, model_layers_21_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv520 = R.call_tir(cls.NT_matmul11, (reshape231, lv290), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv516_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv520, lv515_1, model_layers_21_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv517_1: R.Tensor((1, 1, 2048), dtype="float16") = lv516_1[1]
            rms_norm116: R.Tensor((1, 1, 2048), dtype="float16") = lv516_1[0]
            lv291 = R.call_tir(cls.dequantize3, (model_layers_21_mlp_gate_up_proj_q_weight2, model_layers_21_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv521 = R.call_tir(cls.NT_matmul12, (rms_norm116, lv291), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split57: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv521, indices_or_sections=2, axis=-1)
            split_057: R.Tensor((1, 1, 11008), dtype="float16") = split57[0]
            split_157: R.Tensor((1, 1, 11008), dtype="float16") = split57[1]
            silu57: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_057)
            mul57: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu57, split_157)
            lv292 = R.call_tir(cls.dequantize4, (model_layers_21_mlp_down_proj_q_weight2, model_layers_21_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv522 = R.call_tir(cls.NT_matmul13, (mul57, lv292), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv518_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv522, lv517_1, model_layers_22_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv519_1: R.Tensor((1, 1, 2048), dtype="float16") = lv518_1[1]
            rms_norm117: R.Tensor((1, 1, 2048), dtype="float16") = lv518_1[0]
            lv293 = R.call_tir(cls.dequantize1, (model_layers_22_self_attn_c_attn_q_weight2, model_layers_22_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv523 = R.call_tir(cls.NT_matmul10, (rms_norm117, lv293), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add174: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv523, model_layers_22_self_attn_c_attn_bias2)
            reshape232: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add174, R.shape([1, 1, 20, 128]))
            reshape233: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape232, R.shape([1, 20, 128]))
            lv294 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1.0)), reshape233), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape234: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv294, R.shape([1, 1, 16, 128]))
            reshape235: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape234, R.shape([1, 1, 2048]))
            lv295 = R.call_tir(cls.dequantize2, (model_layers_22_self_attn_o_proj_q_weight2, model_layers_22_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv524 = R.call_tir(cls.NT_matmul11, (reshape235, lv295), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv520_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv524, lv519_1, model_layers_22_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv521_1: R.Tensor((1, 1, 2048), dtype="float16") = lv520_1[1]
            rms_norm118: R.Tensor((1, 1, 2048), dtype="float16") = lv520_1[0]
            lv296 = R.call_tir(cls.dequantize3, (model_layers_22_mlp_gate_up_proj_q_weight2, model_layers_22_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv525 = R.call_tir(cls.NT_matmul12, (rms_norm118, lv296), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split58: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv525, indices_or_sections=2, axis=-1)
            split_058: R.Tensor((1, 1, 11008), dtype="float16") = split58[0]
            split_158: R.Tensor((1, 1, 11008), dtype="float16") = split58[1]
            silu58: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_058)
            mul58: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu58, split_158)
            lv297 = R.call_tir(cls.dequantize4, (model_layers_22_mlp_down_proj_q_weight2, model_layers_22_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv526 = R.call_tir(cls.NT_matmul13, (mul58, lv297), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv522_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv526, lv521_1, model_layers_23_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv523_1: R.Tensor((1, 1, 2048), dtype="float16") = lv522_1[1]
            rms_norm119: R.Tensor((1, 1, 2048), dtype="float16") = lv522_1[0]
            lv298 = R.call_tir(cls.dequantize1, (model_layers_23_self_attn_c_attn_q_weight2, model_layers_23_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv527 = R.call_tir(cls.NT_matmul10, (rms_norm119, lv298), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add177: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv527, model_layers_23_self_attn_c_attn_bias2)
            reshape236: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add177, R.shape([1, 1, 20, 128]))
            reshape237: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape236, R.shape([1, 20, 128]))
            lv299 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1.0)), reshape237), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape238: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv299, R.shape([1, 1, 16, 128]))
            reshape239: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape238, R.shape([1, 1, 2048]))
            lv300 = R.call_tir(cls.dequantize2, (model_layers_23_self_attn_o_proj_q_weight2, model_layers_23_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv528 = R.call_tir(cls.NT_matmul11, (reshape239, lv300), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv524_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv528, lv523_1, model_layers_23_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv525_1: R.Tensor((1, 1, 2048), dtype="float16") = lv524_1[1]
            rms_norm120: R.Tensor((1, 1, 2048), dtype="float16") = lv524_1[0]
            lv301 = R.call_tir(cls.dequantize3, (model_layers_23_mlp_gate_up_proj_q_weight2, model_layers_23_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv529 = R.call_tir(cls.NT_matmul12, (rms_norm120, lv301), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split59: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv529, indices_or_sections=2, axis=-1)
            split_059: R.Tensor((1, 1, 11008), dtype="float16") = split59[0]
            split_159: R.Tensor((1, 1, 11008), dtype="float16") = split59[1]
            silu59: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_059)
            mul59: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu59, split_159)
            lv302 = R.call_tir(cls.dequantize4, (model_layers_23_mlp_down_proj_q_weight2, model_layers_23_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv530 = R.call_tir(cls.NT_matmul13, (mul59, lv302), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv526_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv530, lv525_1, model_layers_24_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv527_1: R.Tensor((1, 1, 2048), dtype="float16") = lv526_1[1]
            rms_norm121: R.Tensor((1, 1, 2048), dtype="float16") = lv526_1[0]
            lv303 = R.call_tir(cls.dequantize1, (model_layers_24_self_attn_c_attn_q_weight2, model_layers_24_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv531 = R.call_tir(cls.NT_matmul10, (rms_norm121, lv303), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add180: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv531, model_layers_24_self_attn_c_attn_bias2)
            reshape240: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add180, R.shape([1, 1, 20, 128]))
            reshape241: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape240, R.shape([1, 20, 128]))
            lv304 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1.0)), reshape241), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape242: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv304, R.shape([1, 1, 16, 128]))
            reshape243: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape242, R.shape([1, 1, 2048]))
            lv305 = R.call_tir(cls.dequantize2, (model_layers_24_self_attn_o_proj_q_weight2, model_layers_24_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv532 = R.call_tir(cls.NT_matmul11, (reshape243, lv305), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv528_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv532, lv527_1, model_layers_24_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv529_1: R.Tensor((1, 1, 2048), dtype="float16") = lv528_1[1]
            rms_norm122: R.Tensor((1, 1, 2048), dtype="float16") = lv528_1[0]
            lv306 = R.call_tir(cls.dequantize3, (model_layers_24_mlp_gate_up_proj_q_weight2, model_layers_24_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv533 = R.call_tir(cls.NT_matmul12, (rms_norm122, lv306), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split60: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv533, indices_or_sections=2, axis=-1)
            split_060: R.Tensor((1, 1, 11008), dtype="float16") = split60[0]
            split_160: R.Tensor((1, 1, 11008), dtype="float16") = split60[1]
            silu60: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_060)
            mul60: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu60, split_160)
            lv307 = R.call_tir(cls.dequantize4, (model_layers_24_mlp_down_proj_q_weight2, model_layers_24_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv534 = R.call_tir(cls.NT_matmul13, (mul60, lv307), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv530_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv534, lv529_1, model_layers_25_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv531_1: R.Tensor((1, 1, 2048), dtype="float16") = lv530_1[1]
            rms_norm123: R.Tensor((1, 1, 2048), dtype="float16") = lv530_1[0]
            lv308 = R.call_tir(cls.dequantize1, (model_layers_25_self_attn_c_attn_q_weight2, model_layers_25_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv535 = R.call_tir(cls.NT_matmul10, (rms_norm123, lv308), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add183: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv535, model_layers_25_self_attn_c_attn_bias2)
            reshape244: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add183, R.shape([1, 1, 20, 128]))
            reshape245: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape244, R.shape([1, 20, 128]))
            lv309 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1.0)), reshape245), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape246: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv309, R.shape([1, 1, 16, 128]))
            reshape247: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape246, R.shape([1, 1, 2048]))
            lv310 = R.call_tir(cls.dequantize2, (model_layers_25_self_attn_o_proj_q_weight2, model_layers_25_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv536 = R.call_tir(cls.NT_matmul11, (reshape247, lv310), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv532_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv536, lv531_1, model_layers_25_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv533_1: R.Tensor((1, 1, 2048), dtype="float16") = lv532_1[1]
            rms_norm124: R.Tensor((1, 1, 2048), dtype="float16") = lv532_1[0]
            lv311 = R.call_tir(cls.dequantize3, (model_layers_25_mlp_gate_up_proj_q_weight2, model_layers_25_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv537 = R.call_tir(cls.NT_matmul12, (rms_norm124, lv311), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split61: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv537, indices_or_sections=2, axis=-1)
            split_061: R.Tensor((1, 1, 11008), dtype="float16") = split61[0]
            split_161: R.Tensor((1, 1, 11008), dtype="float16") = split61[1]
            silu61: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_061)
            mul61: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu61, split_161)
            lv312 = R.call_tir(cls.dequantize4, (model_layers_25_mlp_down_proj_q_weight2, model_layers_25_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv538 = R.call_tir(cls.NT_matmul13, (mul61, lv312), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv534_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv538, lv533_1, model_layers_26_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv535_1: R.Tensor((1, 1, 2048), dtype="float16") = lv534_1[1]
            rms_norm125: R.Tensor((1, 1, 2048), dtype="float16") = lv534_1[0]
            lv313 = R.call_tir(cls.dequantize1, (model_layers_26_self_attn_c_attn_q_weight2, model_layers_26_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv539 = R.call_tir(cls.NT_matmul10, (rms_norm125, lv313), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add186: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv539, model_layers_26_self_attn_c_attn_bias2)
            reshape248: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add186, R.shape([1, 1, 20, 128]))
            reshape249: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape248, R.shape([1, 20, 128]))
            lv314 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1.0)), reshape249), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape250: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv314, R.shape([1, 1, 16, 128]))
            reshape251: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape250, R.shape([1, 1, 2048]))
            lv315 = R.call_tir(cls.dequantize2, (model_layers_26_self_attn_o_proj_q_weight2, model_layers_26_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv540 = R.call_tir(cls.NT_matmul11, (reshape251, lv315), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv536_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv540, lv535_1, model_layers_26_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv537_1: R.Tensor((1, 1, 2048), dtype="float16") = lv536_1[1]
            rms_norm126: R.Tensor((1, 1, 2048), dtype="float16") = lv536_1[0]
            lv316 = R.call_tir(cls.dequantize3, (model_layers_26_mlp_gate_up_proj_q_weight2, model_layers_26_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv541 = R.call_tir(cls.NT_matmul12, (rms_norm126, lv316), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split62: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv541, indices_or_sections=2, axis=-1)
            split_062: R.Tensor((1, 1, 11008), dtype="float16") = split62[0]
            split_162: R.Tensor((1, 1, 11008), dtype="float16") = split62[1]
            silu62: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_062)
            mul62: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu62, split_162)
            lv317 = R.call_tir(cls.dequantize4, (model_layers_26_mlp_down_proj_q_weight2, model_layers_26_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv542 = R.call_tir(cls.NT_matmul13, (mul62, lv317), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv538_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv542, lv537_1, model_layers_27_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv539_1: R.Tensor((1, 1, 2048), dtype="float16") = lv538_1[1]
            rms_norm127: R.Tensor((1, 1, 2048), dtype="float16") = lv538_1[0]
            lv318 = R.call_tir(cls.dequantize1, (model_layers_27_self_attn_c_attn_q_weight2, model_layers_27_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv543 = R.call_tir(cls.NT_matmul10, (rms_norm127, lv318), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add189: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv543, model_layers_27_self_attn_c_attn_bias2)
            reshape252: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add189, R.shape([1, 1, 20, 128]))
            reshape253: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape252, R.shape([1, 20, 128]))
            lv319 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1.0)), reshape253), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape254: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv319, R.shape([1, 1, 16, 128]))
            reshape255: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape254, R.shape([1, 1, 2048]))
            lv320 = R.call_tir(cls.dequantize2, (model_layers_27_self_attn_o_proj_q_weight2, model_layers_27_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv544 = R.call_tir(cls.NT_matmul11, (reshape255, lv320), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv540_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv544, lv539_1, model_layers_27_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv541_1: R.Tensor((1, 1, 2048), dtype="float16") = lv540_1[1]
            rms_norm128: R.Tensor((1, 1, 2048), dtype="float16") = lv540_1[0]
            lv321 = R.call_tir(cls.dequantize3, (model_layers_27_mlp_gate_up_proj_q_weight2, model_layers_27_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv545 = R.call_tir(cls.NT_matmul12, (rms_norm128, lv321), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split63: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv545, indices_or_sections=2, axis=-1)
            split_063: R.Tensor((1, 1, 11008), dtype="float16") = split63[0]
            split_163: R.Tensor((1, 1, 11008), dtype="float16") = split63[1]
            silu63: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_063)
            mul63: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu63, split_163)
            lv322 = R.call_tir(cls.dequantize4, (model_layers_27_mlp_down_proj_q_weight2, model_layers_27_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv546 = R.call_tir(cls.NT_matmul13, (mul63, lv322), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv542_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv546, lv541_1, model_layers_28_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv543_1: R.Tensor((1, 1, 2048), dtype="float16") = lv542_1[1]
            rms_norm129: R.Tensor((1, 1, 2048), dtype="float16") = lv542_1[0]
            lv323 = R.call_tir(cls.dequantize1, (model_layers_28_self_attn_c_attn_q_weight2, model_layers_28_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv547 = R.call_tir(cls.NT_matmul10, (rms_norm129, lv323), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add192: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv547, model_layers_28_self_attn_c_attn_bias2)
            reshape256: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add192, R.shape([1, 1, 20, 128]))
            reshape257: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape256, R.shape([1, 20, 128]))
            lv324 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1.0)), reshape257), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape258: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv324, R.shape([1, 1, 16, 128]))
            reshape259: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape258, R.shape([1, 1, 2048]))
            lv325 = R.call_tir(cls.dequantize2, (model_layers_28_self_attn_o_proj_q_weight2, model_layers_28_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv548 = R.call_tir(cls.NT_matmul11, (reshape259, lv325), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv544_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv548, lv543_1, model_layers_28_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv545_1: R.Tensor((1, 1, 2048), dtype="float16") = lv544_1[1]
            rms_norm130: R.Tensor((1, 1, 2048), dtype="float16") = lv544_1[0]
            lv326 = R.call_tir(cls.dequantize3, (model_layers_28_mlp_gate_up_proj_q_weight2, model_layers_28_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv549 = R.call_tir(cls.NT_matmul12, (rms_norm130, lv326), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split64: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv549, indices_or_sections=2, axis=-1)
            split_064: R.Tensor((1, 1, 11008), dtype="float16") = split64[0]
            split_164: R.Tensor((1, 1, 11008), dtype="float16") = split64[1]
            silu64: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_064)
            mul64: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu64, split_164)
            lv327 = R.call_tir(cls.dequantize4, (model_layers_28_mlp_down_proj_q_weight2, model_layers_28_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv550 = R.call_tir(cls.NT_matmul13, (mul64, lv327), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv546_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv550, lv545_1, model_layers_29_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv547_1: R.Tensor((1, 1, 2048), dtype="float16") = lv546_1[1]
            rms_norm131: R.Tensor((1, 1, 2048), dtype="float16") = lv546_1[0]
            lv328 = R.call_tir(cls.dequantize1, (model_layers_29_self_attn_c_attn_q_weight2, model_layers_29_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv551 = R.call_tir(cls.NT_matmul10, (rms_norm131, lv328), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add195: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv551, model_layers_29_self_attn_c_attn_bias2)
            reshape260: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add195, R.shape([1, 1, 20, 128]))
            reshape261: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape260, R.shape([1, 20, 128]))
            lv329 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1.0)), reshape261), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape262: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv329, R.shape([1, 1, 16, 128]))
            reshape263: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape262, R.shape([1, 1, 2048]))
            lv330 = R.call_tir(cls.dequantize2, (model_layers_29_self_attn_o_proj_q_weight2, model_layers_29_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv552 = R.call_tir(cls.NT_matmul11, (reshape263, lv330), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv548_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv552, lv547_1, model_layers_29_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv549_1: R.Tensor((1, 1, 2048), dtype="float16") = lv548_1[1]
            rms_norm132: R.Tensor((1, 1, 2048), dtype="float16") = lv548_1[0]
            lv331 = R.call_tir(cls.dequantize3, (model_layers_29_mlp_gate_up_proj_q_weight2, model_layers_29_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv553 = R.call_tir(cls.NT_matmul12, (rms_norm132, lv331), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split65: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv553, indices_or_sections=2, axis=-1)
            split_065: R.Tensor((1, 1, 11008), dtype="float16") = split65[0]
            split_165: R.Tensor((1, 1, 11008), dtype="float16") = split65[1]
            silu65: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_065)
            mul65: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu65, split_165)
            lv332 = R.call_tir(cls.dequantize4, (model_layers_29_mlp_down_proj_q_weight2, model_layers_29_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv554 = R.call_tir(cls.NT_matmul13, (mul65, lv332), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv550_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv554, lv549_1, model_layers_30_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv551_1: R.Tensor((1, 1, 2048), dtype="float16") = lv550_1[1]
            rms_norm133: R.Tensor((1, 1, 2048), dtype="float16") = lv550_1[0]
            lv333 = R.call_tir(cls.dequantize1, (model_layers_30_self_attn_c_attn_q_weight2, model_layers_30_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv555 = R.call_tir(cls.NT_matmul10, (rms_norm133, lv333), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add198: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv555, model_layers_30_self_attn_c_attn_bias2)
            reshape264: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add198, R.shape([1, 1, 20, 128]))
            reshape265: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape264, R.shape([1, 20, 128]))
            lv334 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1.0)), reshape265), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape266: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv334, R.shape([1, 1, 16, 128]))
            reshape267: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape266, R.shape([1, 1, 2048]))
            lv335 = R.call_tir(cls.dequantize2, (model_layers_30_self_attn_o_proj_q_weight2, model_layers_30_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv556 = R.call_tir(cls.NT_matmul11, (reshape267, lv335), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv552_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv556, lv551_1, model_layers_30_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv553_1: R.Tensor((1, 1, 2048), dtype="float16") = lv552_1[1]
            rms_norm134: R.Tensor((1, 1, 2048), dtype="float16") = lv552_1[0]
            lv336 = R.call_tir(cls.dequantize3, (model_layers_30_mlp_gate_up_proj_q_weight2, model_layers_30_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv557 = R.call_tir(cls.NT_matmul12, (rms_norm134, lv336), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split66: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv557, indices_or_sections=2, axis=-1)
            split_066: R.Tensor((1, 1, 11008), dtype="float16") = split66[0]
            split_166: R.Tensor((1, 1, 11008), dtype="float16") = split66[1]
            silu66: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_066)
            mul66: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu66, split_166)
            lv337 = R.call_tir(cls.dequantize4, (model_layers_30_mlp_down_proj_q_weight2, model_layers_30_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv558 = R.call_tir(cls.NT_matmul13, (mul66, lv337), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv554_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv558, lv553_1, model_layers_31_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv555_1: R.Tensor((1, 1, 2048), dtype="float16") = lv554_1[1]
            rms_norm135: R.Tensor((1, 1, 2048), dtype="float16") = lv554_1[0]
            lv338 = R.call_tir(cls.dequantize1, (model_layers_31_self_attn_c_attn_q_weight2, model_layers_31_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv559 = R.call_tir(cls.NT_matmul10, (rms_norm135, lv338), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add201: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv559, model_layers_31_self_attn_c_attn_bias2)
            reshape268: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add201, R.shape([1, 1, 20, 128]))
            reshape269: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape268, R.shape([1, 20, 128]))
            lv339 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1.0)), reshape269), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape270: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv339, R.shape([1, 1, 16, 128]))
            reshape271: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape270, R.shape([1, 1, 2048]))
            lv340 = R.call_tir(cls.dequantize2, (model_layers_31_self_attn_o_proj_q_weight2, model_layers_31_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv560 = R.call_tir(cls.NT_matmul11, (reshape271, lv340), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv556_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv560, lv555_1, model_layers_31_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv557_1: R.Tensor((1, 1, 2048), dtype="float16") = lv556_1[1]
            rms_norm136: R.Tensor((1, 1, 2048), dtype="float16") = lv556_1[0]
            lv341 = R.call_tir(cls.dequantize3, (model_layers_31_mlp_gate_up_proj_q_weight2, model_layers_31_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv561 = R.call_tir(cls.NT_matmul12, (rms_norm136, lv341), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split67: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv561, indices_or_sections=2, axis=-1)
            split_067: R.Tensor((1, 1, 11008), dtype="float16") = split67[0]
            split_167: R.Tensor((1, 1, 11008), dtype="float16") = split67[1]
            silu67: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_067)
            mul67: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu67, split_167)
            lv342 = R.call_tir(cls.dequantize4, (model_layers_31_mlp_down_proj_q_weight2, model_layers_31_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv562 = R.call_tir(cls.NT_matmul13, (mul67, lv342), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv558_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv562, lv557_1, model_layers_32_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv559_1: R.Tensor((1, 1, 2048), dtype="float16") = lv558_1[1]
            rms_norm137: R.Tensor((1, 1, 2048), dtype="float16") = lv558_1[0]
            lv343 = R.call_tir(cls.dequantize1, (model_layers_32_self_attn_c_attn_q_weight2, model_layers_32_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv563 = R.call_tir(cls.NT_matmul10, (rms_norm137, lv343), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add204: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv563, model_layers_32_self_attn_c_attn_bias2)
            reshape272: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add204, R.shape([1, 1, 20, 128]))
            reshape273: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape272, R.shape([1, 20, 128]))
            lv344 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(32), R.prim_value(T.float32(1.0)), reshape273), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape274: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv344, R.shape([1, 1, 16, 128]))
            reshape275: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape274, R.shape([1, 1, 2048]))
            lv345 = R.call_tir(cls.dequantize2, (model_layers_32_self_attn_o_proj_q_weight2, model_layers_32_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv564 = R.call_tir(cls.NT_matmul11, (reshape275, lv345), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv560_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv564, lv559_1, model_layers_32_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv561_1: R.Tensor((1, 1, 2048), dtype="float16") = lv560_1[1]
            rms_norm138: R.Tensor((1, 1, 2048), dtype="float16") = lv560_1[0]
            lv346 = R.call_tir(cls.dequantize3, (model_layers_32_mlp_gate_up_proj_q_weight2, model_layers_32_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv565 = R.call_tir(cls.NT_matmul12, (rms_norm138, lv346), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split68: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv565, indices_or_sections=2, axis=-1)
            split_068: R.Tensor((1, 1, 11008), dtype="float16") = split68[0]
            split_168: R.Tensor((1, 1, 11008), dtype="float16") = split68[1]
            silu68: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_068)
            mul68: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu68, split_168)
            lv347 = R.call_tir(cls.dequantize4, (model_layers_32_mlp_down_proj_q_weight2, model_layers_32_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv566 = R.call_tir(cls.NT_matmul13, (mul68, lv347), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv562_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv566, lv561_1, model_layers_33_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv563_1: R.Tensor((1, 1, 2048), dtype="float16") = lv562_1[1]
            rms_norm139: R.Tensor((1, 1, 2048), dtype="float16") = lv562_1[0]
            lv348 = R.call_tir(cls.dequantize1, (model_layers_33_self_attn_c_attn_q_weight2, model_layers_33_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv567 = R.call_tir(cls.NT_matmul10, (rms_norm139, lv348), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add207: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv567, model_layers_33_self_attn_c_attn_bias2)
            reshape276: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add207, R.shape([1, 1, 20, 128]))
            reshape277: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape276, R.shape([1, 20, 128]))
            lv349 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(33), R.prim_value(T.float32(1.0)), reshape277), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape278: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv349, R.shape([1, 1, 16, 128]))
            reshape279: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape278, R.shape([1, 1, 2048]))
            lv350 = R.call_tir(cls.dequantize2, (model_layers_33_self_attn_o_proj_q_weight2, model_layers_33_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv568 = R.call_tir(cls.NT_matmul11, (reshape279, lv350), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv564_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv568, lv563_1, model_layers_33_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv565_1: R.Tensor((1, 1, 2048), dtype="float16") = lv564_1[1]
            rms_norm140: R.Tensor((1, 1, 2048), dtype="float16") = lv564_1[0]
            lv351 = R.call_tir(cls.dequantize3, (model_layers_33_mlp_gate_up_proj_q_weight2, model_layers_33_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv569 = R.call_tir(cls.NT_matmul12, (rms_norm140, lv351), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split69: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv569, indices_or_sections=2, axis=-1)
            split_069: R.Tensor((1, 1, 11008), dtype="float16") = split69[0]
            split_169: R.Tensor((1, 1, 11008), dtype="float16") = split69[1]
            silu69: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_069)
            mul69: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu69, split_169)
            lv352 = R.call_tir(cls.dequantize4, (model_layers_33_mlp_down_proj_q_weight2, model_layers_33_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv570 = R.call_tir(cls.NT_matmul13, (mul69, lv352), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv566_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv570, lv565_1, model_layers_34_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv567_1: R.Tensor((1, 1, 2048), dtype="float16") = lv566_1[1]
            rms_norm141: R.Tensor((1, 1, 2048), dtype="float16") = lv566_1[0]
            lv353 = R.call_tir(cls.dequantize1, (model_layers_34_self_attn_c_attn_q_weight2, model_layers_34_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv571 = R.call_tir(cls.NT_matmul10, (rms_norm141, lv353), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add210: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv571, model_layers_34_self_attn_c_attn_bias2)
            reshape280: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add210, R.shape([1, 1, 20, 128]))
            reshape281: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape280, R.shape([1, 20, 128]))
            lv354 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(34), R.prim_value(T.float32(1.0)), reshape281), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape282: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv354, R.shape([1, 1, 16, 128]))
            reshape283: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape282, R.shape([1, 1, 2048]))
            lv355 = R.call_tir(cls.dequantize2, (model_layers_34_self_attn_o_proj_q_weight2, model_layers_34_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv572 = R.call_tir(cls.NT_matmul11, (reshape283, lv355), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv568_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv572, lv567_1, model_layers_34_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv569_1: R.Tensor((1, 1, 2048), dtype="float16") = lv568_1[1]
            rms_norm142: R.Tensor((1, 1, 2048), dtype="float16") = lv568_1[0]
            lv356 = R.call_tir(cls.dequantize3, (model_layers_34_mlp_gate_up_proj_q_weight2, model_layers_34_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv573 = R.call_tir(cls.NT_matmul12, (rms_norm142, lv356), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split70: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv573, indices_or_sections=2, axis=-1)
            split_070: R.Tensor((1, 1, 11008), dtype="float16") = split70[0]
            split_170: R.Tensor((1, 1, 11008), dtype="float16") = split70[1]
            silu70: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_070)
            mul70: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu70, split_170)
            lv357 = R.call_tir(cls.dequantize4, (model_layers_34_mlp_down_proj_q_weight2, model_layers_34_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv574 = R.call_tir(cls.NT_matmul13, (mul70, lv357), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv570_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv574, lv569_1, model_layers_35_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv571_1: R.Tensor((1, 1, 2048), dtype="float16") = lv570_1[1]
            rms_norm143: R.Tensor((1, 1, 2048), dtype="float16") = lv570_1[0]
            lv358 = R.call_tir(cls.dequantize1, (model_layers_35_self_attn_c_attn_q_weight2, model_layers_35_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            lv575 = R.call_tir(cls.NT_matmul10, (rms_norm143, lv358), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            add213: R.Tensor((1, 1, 2560), dtype="float16") = R.add(lv575, model_layers_35_self_attn_c_attn_bias2)
            reshape284: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add213, R.shape([1, 1, 20, 128]))
            reshape285: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape284, R.shape([1, 20, 128]))
            lv359 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(35), R.prim_value(T.float32(1.0)), reshape285), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape286: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv359, R.shape([1, 1, 16, 128]))
            reshape287: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape286, R.shape([1, 1, 2048]))
            lv360 = R.call_tir(cls.dequantize2, (model_layers_35_self_attn_o_proj_q_weight2, model_layers_35_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            lv576 = R.call_tir(cls.NT_matmul11, (reshape287, lv360), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv572_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv576, lv571_1, model_layers_35_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv573_1: R.Tensor((1, 1, 2048), dtype="float16") = lv572_1[1]
            rms_norm144: R.Tensor((1, 1, 2048), dtype="float16") = lv572_1[0]
            lv361 = R.call_tir(cls.dequantize3, (model_layers_35_mlp_gate_up_proj_q_weight2, model_layers_35_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            lv577 = R.call_tir(cls.NT_matmul12, (rms_norm144, lv361), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            split71: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(lv577, indices_or_sections=2, axis=-1)
            split_071: R.Tensor((1, 1, 11008), dtype="float16") = split71[0]
            split_171: R.Tensor((1, 1, 11008), dtype="float16") = split71[1]
            silu71: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_071)
            mul71: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu71, split_171)
            lv362 = R.call_tir(cls.dequantize4, (model_layers_35_mlp_down_proj_q_weight2, model_layers_35_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            lv578 = R.call_tir(cls.NT_matmul13, (mul71, lv362), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv574_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv578, lv573_1, model_norm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            rms_norm145: R.Tensor((1, 1, 2048), dtype="float16") = lv574_1[0]
            lv363 = R.call_tir(cls.dequantize, (model_embed_tokens_q_weight2, model_embed_tokens_q_scale2), out_sinfo=R.Tensor((151936, 2048), dtype="float16"))
            lv579 = R.call_tir(cls.NT_matmul14, (rms_norm145, lv363), out_sinfo=R.Tensor((1, 1, 151936), dtype="float32"))
            gv2: R.Tuple(R.Tensor((1, 1, 151936), dtype="float32"), R.Object) = lv579, paged_kv_cache
            R.output(gv2)
        return gv2

    @R.function
    def embed(input_ids: R.Tensor(("seq_len",), dtype="int32"), packed_params: R.Tuple(R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 
256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), 
R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), 
dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), 
dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dt