"""Debug compiled models with TVM instrument"""

import argparse
import json
import random
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Union

import numpy as np
import tvm
from tvm import relax, runtime
from tvm.contrib import tvmjs
from tvm.relax.testing.lib_comparator import LibCompareVMInstrument
from tvm.runtime import Device, Module, Object, ShapeTuple
from tvm.runtime.relax_vm import VirtualMachine

from tapml.conversation_template import ConvTemplateRegistry
from tapml.protocol.tapml_chat_config import TapmlChatConfig
from tapml.serve import data, engine_utils
from tapml.support.style import green, red
from tapml.tokenizers import Tokenizer


def _extract_metadata(mod: Module):
    """Return the JSON metadata dict embedded in a compiled model library.

    A throw-away CPU ``VirtualMachine`` is built over ``mod`` solely to call
    its ``_metadata`` function; the string it returns is parsed as JSON.
    """
    cpu_vm = VirtualMachine(mod, tvm.runtime.device("cpu"))
    metadata_json = cpu_vm["_metadata"]()
    return json.loads(metadata_json)


def _load_params(
    model_weight_path: str, device: Device, model_metadata: Dict[str, Any]
) -> List[tvm.nd.NDArray]:
    """Load model weights from a ``tvmjs`` NDArray cache in metadata order.

    Parameters
    ----------
    model_weight_path : str
        Location of the NDArray cache passed to ``tvmjs.load_ndarray_cache``.
    device : Device
        Device the weights are loaded onto.
    model_metadata : Dict[str, Any]
        Metadata dict from the compiled library; each entry of its
        ``"params"`` list carries a ``"name"`` key.

    Returns
    -------
    List[tvm.nd.NDArray]
        One array per entry of ``model_metadata["params"]``, in that order.

    Raises
    ------
    ValueError
        If the cache's ``ParamSize`` differs from the number of params listed
        in the metadata.
    KeyError
        If a parameter named in the metadata is absent from the cache.
    """
    params, meta = tvmjs.load_ndarray_cache(model_weight_path, device)
    param_names = [param["name"] for param in model_metadata["params"]]
    # An explicit raise rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let a mismatched cache slip through silently.
    if len(param_names) != meta["ParamSize"]:
        raise ValueError(
            f"Weight cache at {model_weight_path} holds {meta['ParamSize']} params, "
            f"but the model library expects {len(param_names)}."
        )
    return [params[param_name] for param_name in param_names]


def _get_tvm_module(
    model_weight_path: str,
    lib_path: str,
    device: Device,
    instrument: Callable,
):
    """Load a compiled model library on ``device`` with ``instrument`` attached.

    Returns
    -------
    tuple
        ``(vm_module, params, metadata)``: the relax VM's module (with the
        instrument installed), the weights loaded from ``model_weight_path``,
        and the metadata dict embedded in the library.
    """
    executable = tvm.runtime.load_module(lib_path)
    machine = relax.VirtualMachine(executable, device)
    machine.set_instrument(instrument)
    model_metadata = _extract_metadata(executable)
    weights = _load_params(model_weight_path, device, model_metadata)
    return machine.module, weights, model_metadata


class DebugChat:  # pylint: disable=too-many-instance-attributes, too-few-public-methods
    """Generate text with a golden model library while validating a buggy one.

    The golden library's VM drives embedding, prefill and decode. A
    :class:`LibCompare` instrument attached to that VM re-runs the
    instrumented calls on the buggy library and reports numerical
    mismatches. One paged KV cache is created per library, and both are
    advanced together so KV-cache attention builtins can be compared as well.

    Parameters
    ----------
    model : str
        Model directory containing the weights, tokenizer files and the chat
        config JSON (``mlc-chat-config.json`` or ``tapml-chat-config.json``).
    golden_lib : str
        Path to the reference compiled model library.
    golden_device : str
        Device string passed to ``tvm.device`` for the golden library.
    buggy_lib : str
        Path to the compiled model library under test.
    buggy_device : str
        Device string passed to ``tvm.device`` for the buggy library.
    skip_visited : bool
        Forwarded to :class:`LibCompare`: validate each function at most once
        per prefill/decode step.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        model: str,
        golden_lib: str,
        golden_device: str,
        buggy_lib: str,
        buggy_device: str,
        skip_visited: bool = True,
    ):
        self.golden_device = tvm.device(golden_device)
        self.buggy_device = tvm.device(buggy_device)
        self.buggy_runtime_lib = tvm.runtime.load_module(buggy_lib)
        # Within this class the buggy module is only used for its KV-cache
        # constructor; its kernels are reached through the instrument below.
        self.buggy_mod = relax.VirtualMachine(
            self.buggy_runtime_lib, self.buggy_device
        ).module
        self.instrument = LibCompare(
            self.buggy_runtime_lib,
            self.buggy_device,
            skip_visited=skip_visited,
        )
        self.golden_mod, self.params, self.metadata = _get_tvm_module(
            model, golden_lib, self.golden_device, self.instrument
        )
        self.model_path = Path(model)
        self.config_file_path = self._find_chat_config(self.model_path)
        with open(self.config_file_path, mode="rt", encoding="utf-8") as file:
            self.chat_config = TapmlChatConfig.model_validate_json(file.read())

        conv_template = self.chat_config.conv_template
        self.conversation = (
            ConvTemplateRegistry.get_conv_template(conv_template)
            if isinstance(conv_template, str)
            else conv_template
        )
        self.tokenizer = Tokenizer(str(self.model_path))

        self.add_sequence_func = tvm.get_global_func("vm.builtin.kv_state_add_sequence")
        self.begin_forward_func = tvm.get_global_func(
            "vm.builtin.kv_state_begin_forward"
        )
        self.end_forward_func = tvm.get_global_func("vm.builtin.kv_state_end_forward")
        self.nd_view_func = tvm.get_global_func("vm.builtin.reshape")
        self.sample_topp_from_prob_func = tvm.get_global_func(
            "vm.builtin.sample_top_p_from_prob"
        )

        self.embed_func = self.golden_mod["embed"]
        self.prefill_func = self.golden_mod["prefill"]
        self.decode_func = self.golden_mod["decode"]

        self.golden_create_kv_cache_func = self._get_kv_cache_func(self.golden_mod)
        self.buggy_create_kv_cache_func = self._get_kv_cache_func(self.buggy_mod)

        # token id -> number of times it has been sampled; read by the
        # presence/frequency penalty.
        self.appeared_token_freq: Dict[int, int] = {}
        self._init_kv_cache()

    @staticmethod
    def _find_chat_config(model_path: Path) -> Path:
        """Return the chat config JSON path inside ``model_path``.

        ``mlc-chat-config.json`` takes precedence; ``tapml-chat-config.json``
        (the name given in the CLI help) is accepted as a fallback. If neither
        file exists, the ``mlc-chat-config.json`` path is returned, so opening
        it raises a ``FileNotFoundError`` that names that path.
        """
        candidates = ("mlc-chat-config.json", "tapml-chat-config.json")
        for candidate in candidates:
            path = model_path / candidate
            if path.is_file():
                return path
        return model_path / candidates[0]

    @staticmethod
    def _get_kv_cache_func(mod: tvm.runtime.Module) -> tvm.runtime.PackedFunc:
        """Return the paged-KV-cache constructor exported by ``mod``.

        The FlashInfer variant is preferred over the TIR one when both exist.

        Raises
        ------
        RuntimeError
            If ``mod`` exports neither constructor.
        """
        for func_name in (
            "create_flashinfer_paged_kv_cache",
            "create_tir_paged_kv_cache",
        ):
            if mod.implements_function(func_name):
                return mod[func_name]
        raise RuntimeError("Unsupported KV cache type")

    def _init_kv_cache(self):
        """Create one paged KV cache per library and add sequence 0 to each.

        Both caches use the same fixed capacity settings, so they can be
        advanced in lockstep.
        """
        max_num_sequence = 1
        page_size = 16
        # A falsy chat-config value (e.g. None or 0) falls back to the value
        # recorded in the compiled library's metadata.
        sliding_window_size = (
            self.chat_config.sliding_window_size
            if self.chat_config.sliding_window_size
            else self.metadata["sliding_window_size"]
        )
        prefill_chunk_size = 128
        max_total_sequence_length = 512
        # A size of -1 turns sliding-window support off.
        support_sliding_window = int(sliding_window_size != -1)
        self.golden_kv_caches = self.golden_create_kv_cache_func(
            ShapeTuple([max_num_sequence]),
            ShapeTuple([max_total_sequence_length]),
            ShapeTuple([prefill_chunk_size]),
            ShapeTuple([page_size]),
            ShapeTuple([support_sliding_window]),
        )
        self.buggy_kv_caches = self.buggy_create_kv_cache_func(
            ShapeTuple([max_num_sequence]),
            ShapeTuple([max_total_sequence_length]),
            ShapeTuple([prefill_chunk_size]),
            ShapeTuple([page_size]),
            ShapeTuple([support_sliding_window]),
        )
        self.add_sequence_func(self.golden_kv_caches, 0)
        self.add_sequence_func(self.buggy_kv_caches, 0)

    def _preprocess_prompts(self, prompt: str) -> List[List[int]]:
        """Render ``prompt`` with the conversation template and tokenize it.

        The user message and an empty assistant slot are appended to
        ``self.conversation``, so messages accumulate across calls.

        Returns
        -------
        List[List[int]]
            Token-id chunks from ``engine_utils.process_prompts``. The
            template's system prefix token ids, if any, are prepended to the
            first chunk.
        """
        print(
            "======================= Starts Tokenization & Embedding ======================="
        )
        # Step 0. Generate prompt string using conversation template
        self.conversation.messages.append(("user", prompt))
        self.conversation.messages.append(("assistant", None))

        with open(self.config_file_path, "r", encoding="utf-8") as file:
            config = json.load(file)
        parsed_prompt = self.conversation.as_prompt(config)
        print(
            "Parsed prompt using conversation template "
            f"{self.conversation.name}: {parsed_prompt}"
        )
        tokens = engine_utils.process_prompts(parsed_prompt, self.tokenizer.encode)  # type: ignore

        if self.conversation.system_prefix_token_ids is not None:
            tokens[0] = self.conversation.system_prefix_token_ids + tokens[0]

        return tokens

    def _embed(
        self, data_inputs: List[Union[List[int], data.ImageData]]
    ) -> Tuple[tvm.nd.NDArray, int]:
        """Embed every input with the golden ``embed`` function and concatenate.

        Each input is converted to an int32 token-id array, so only token-id
        lists are handled here. The 2-D concatenation is viewed as
        ``(1, shape[0], shape[1])``.

        Returns
        -------
        Tuple[tvm.nd.NDArray, int]
            The 3-D embedding on the golden device and its length along axis 1.
        """
        embeddings = []
        for data_input in data_inputs:
            # Process token data
            data_input = tvm.nd.array(
                np.array(data_input).astype("int32"), device=self.golden_device
            )
            # ``numpy()`` rather than the deprecated ``asnumpy()`` alias.
            embeddings.append(self.embed_func(data_input, self.params).numpy())

        # Concatenate
        concat_embeddings = tvm.nd.array(
            np.concatenate(embeddings, axis=0), device=self.golden_device
        )
        concat_embeddings = self.nd_view_func(
            concat_embeddings,
            ShapeTuple([1, concat_embeddings.shape[0], concat_embeddings.shape[1]]),
        )
        input_len = concat_embeddings.shape[1]

        return concat_embeddings, input_len

    def _prefill(self, embedding: tvm.nd.NDArray, input_len: int):
        """Run the golden prefill over ``input_len`` positions of sequence 0.

        Both KV caches are opened for the same forward step, so the buggy
        cache can stand in for the golden one inside the instrument.
        """
        print("======================= Starts Prefill =======================")
        seq_len_shape = ShapeTuple([input_len])

        self.begin_forward_func(self.golden_kv_caches, ShapeTuple([0]), seq_len_shape)
        self.begin_forward_func(self.buggy_kv_caches, ShapeTuple([0]), seq_len_shape)

        logits, _ = self.prefill_func(embedding, self.golden_kv_caches, self.params)

        self.end_forward_func(self.golden_kv_caches)
        self.end_forward_func(self.buggy_kv_caches)
        return logits

    def _decode(self, token: int):
        """Embed ``token`` and run one golden decode step on sequence 0."""
        embedding, _ = self._embed([[token]])
        self.begin_forward_func(self.golden_kv_caches, ShapeTuple([0]), ShapeTuple([1]))
        self.begin_forward_func(self.buggy_kv_caches, ShapeTuple([0]), ShapeTuple([1]))

        logits, _ = self.decode_func(embedding, self.golden_kv_caches, self.params)

        self.end_forward_func(self.golden_kv_caches)
        self.end_forward_func(self.buggy_kv_caches)
        return logits

    def _softmax_with_temperature(self, logits: np.ndarray, temperature: float):
        """Return ``softmax(logits / temperature)`` along the last axis.

        The row max is subtracted before exponentiating for numerical
        stability. ``temperature`` must be non-zero.
        """
        # Adjust logits based on the temperature
        logits = np.array(logits) / temperature
        logits -= np.max(logits, axis=-1, keepdims=True)

        exp_logits = np.exp(logits, logits)
        exp_logits /= np.sum(exp_logits, axis=-1, keepdims=True)
        return exp_logits

    def _apply_presence_and_freq_penalty(
        self, logits: np.ndarray, presence_penalty: float, freq_penalty: float
    ):
        """Penalize previously sampled tokens in ``logits`` in place.

        Each seen token loses ``freq * freq_penalty + presence_penalty``.
        ``logits`` is indexed as a 3-D array with token ids on the last axis.
        """
        for token_id, freq in self.appeared_token_freq.items():
            logits[:, :, token_id] -= freq * freq_penalty + presence_penalty

    def _sample_token_from_logits(
        self,
        logits: tvm.nd.NDArray,
        *,
        temperature=1.0,
        top_p=1.0,
        presence_penalty=0.0,
        frequency_penalty=0.0,
    ):
        """Sample the next token id from ``logits`` with top-p sampling.

        Optional penalties are applied first, then a temperature softmax. The
        probabilities are written back into ``logits`` (mutating it) before
        sampling. The sampled token is recorded in ``appeared_token_freq`` so
        later penalties take it into account.
        """
        logits_np = logits.numpy()

        if presence_penalty != 0.0 or frequency_penalty != 0.0:
            self._apply_presence_and_freq_penalty(
                logits_np, presence_penalty, frequency_penalty
            )

        logits_np = self._softmax_with_temperature(logits_np, temperature)

        # The sampler is fed probabilities, so the softmax output overwrites
        # the logits buffer in place.
        logits = logits.copyfrom(logits_np)
        next_token = self.sample_topp_from_prob_func(logits, top_p, random.random())
        # Track sampled tokens; without this the penalties would never apply.
        token_id = int(next_token)
        self.appeared_token_freq[token_id] = self.appeared_token_freq.get(token_id, 0) + 1
        return next_token

    def generate(
        self,
        prompt: str,
        generate_length: int,
    ):
        """Generates the response from the model given a user prompt. User will need to
        specify the generation length for debugging purpose. For example, a generation
        length of 3 will include 1 prefill step and 2 decode steps. Generation stops
        early once a stop token of the conversation template is sampled, including
        when prefill itself produces one.

        Parameters
        ----------
        prompt : str
            The user input prompt.

        generate_length : int
            How many tokens to generate.
        """
        out_tokens = []

        self.instrument.reset(self.golden_kv_caches, self.buggy_kv_caches)
        data_inputs = self._preprocess_prompts(prompt)
        print(f"{green('Data inputs')}: {data_inputs}")
        embedding, input_len = self._embed(data_inputs)
        logits = self._prefill(embedding, input_len)
        next_token = self._sample_token_from_logits(logits)
        out_tokens.append(next_token)

        print("======================= Starts Decode =======================")
        for _ in range(generate_length - 1):
            # Checked before every decode so that a stop token emitted by
            # prefill also ends generation.
            if next_token in self.conversation.stop_token_ids:
                break
            # Fresh visited set/counter for each step.
            self.instrument.reset(self.golden_kv_caches, self.buggy_kv_caches)
            logits = self._decode(next_token)
            next_token = self._sample_token_from_logits(logits)
            out_tokens.append(next_token)
        print(f"{green('Output tokens')}: {out_tokens}")


class LibCompare(LibCompareVMInstrument):
    """VM instrument that validates a "buggy" library against a golden run.

    The instrument is meant to be attached (``vm.set_instrument``) to the VM
    executing the golden library. After each instrumented call returns:

    1. its arguments are copied into fresh arrays on ``target_device``,
       staged through host memory;
    2. the same-named function is run on those copies; it is looked up in
       ``mod``, or as a global function for ``vm.builtin.*`` names;
    3. the last argument, treated as the output buffer, is compared against
       the golden output.

    An argument equal to the golden KV cache registered through :meth:`reset`
    is replaced by the buggy KV cache.

    ``vm.builtin.*`` calls other than those in ``attention_builtin`` are not
    validated, nor are ``shape_func*`` calls.

    Parameters
    ----------
    mod: runtime.Module
        The module under test (the buggy library).

    target_device: runtime.Device
        The device to run ``mod`` on.

    skip_visited: bool
        If True, validate each function name at most once between calls
        to :meth:`reset`.
    """

    def __init__(
        self,
        mod: runtime.Module,
        target_device: runtime.Device,
        skip_visited: bool = True,
    ):
        # Third positional argument: the base class's verbose flag (``verbose``
        # in TVM's LibCompareVMInstrument), so per-call results are printed.
        super().__init__(mod, target_device, True)
        self.visited: Set[str] = set()
        self.skip_visited = skip_visited
        # Set by ``reset``; the golden cache is swapped for the buggy one when
        # it appears among a call's arguments.
        self.golden_kv_cache = None
        self.buggy_kv_cache = None
        # Builtins validated even though they live under ``vm.builtin.``.
        self.attention_builtin = [
            "vm.builtin.paged_attention_kv_cache_create_reduced",
            "vm.builtin.attention_kv_cache_attention_with_fused_qkv",
        ]

    def reset(self, golden_kv_cache: Object, buggy_kv_cache: Object):
        """Start a new validation step.

        Clears the visited-function set and the call counter, and records the
        KV-cache pair used to redirect stateful calls to the buggy module.
        """
        self.visited = set()
        self.counter = 0
        self.golden_kv_cache = golden_kv_cache
        self.buggy_kv_cache = buggy_kv_cache

    def skip_instrument(  # pylint: disable=unused-argument
        self, func, name, before_run, ret_val, *args
    ):
        """Return True if the call ``name`` should not be validated.

        ``shape_func*`` calls are always skipped. When ``skip_visited`` is set,
        a name already seen since the last :meth:`reset` is skipped too. Every
        name that is not skipped is marked as visited.
        """
        if name.startswith("shape_func"):
            return True
        if self.skip_visited and name in self.visited:
            return True
        self.visited.add(name)
        return False

    @staticmethod
    def _error_stats(
        buggy_val: np.ndarray, golden_val: np.ndarray, atol: float
    ) -> Tuple[float, float]:
        """Return ``(max_abs_error, max_rel_error)`` between two arrays.

        The relative error is taken only over elements whose absolute error
        exceeds ``atol``. It is 0 when there are none (e.g. when the mismatch
        is due to NaNs only) and ``inf`` where the golden value is 0.
        """
        # Subtract in float64: unsigned-int arrays would wrap around and bool
        # arrays do not support ``-`` at all in their native dtype.
        golden = golden_val.astype("float64")
        abs_diff = np.abs(buggy_val.astype("float64") - golden)
        abs_error = np.max(abs_diff)
        exceeds = abs_diff > atol
        if np.count_nonzero(exceeds) == 0:
            return abs_error, 0
        # A zero (or infinite) golden value gives inf/nan; report it silently.
        with np.errstate(divide="ignore", invalid="ignore"):
            relative_error = np.max(abs_diff[exceeds] / np.abs(golden[exceeds]))
        return abs_error, relative_error

    def compare(
        self,
        name: str,
        ref_args: Union[List[tvm.nd.NDArray], Tuple[tvm.nd.NDArray, ...]],
        new_args: Union[List[tvm.nd.NDArray], Tuple[tvm.nd.NDArray, ...]],
        ret_indices: Iterable[int],
    ):
        """Run ``name`` from the module under test and compare its outputs.

        Parameters
        ----------
        name: str
            Name of the function.

        ref_args:
            The golden call's arguments, including its computed outputs.

        new_args:
            Arguments on the target device passed to the function under test.

        ret_indices:
            Indices of the arguments holding outputs to compare.
        """
        if "matmul" in name or "gemv" in name or "gemm" in name:
            # Looser tolerances for matrix-multiply kernels, presumably because
            # their long reductions accumulate more rounding error.
            atol, rtol = 5e-1, 1e-2
        else:
            atol, rtol = 1e-2, 1e-3
        my_func = (
            tvm.get_global_func(name)
            if name.startswith("vm.builtin")
            else self.mod.get_function(name, query_imports=True)
        )
        if self.verbose:
            print(f"[{self.counter}] Validating {name} ...", end="")
        my_func(*new_args)
        error_info = []
        for rindex in ret_indices:
            if not isinstance(new_args[rindex], tvm.nd.NDArray):
                continue
            buggy_val = new_args[rindex].numpy()
            golden_val = ref_args[rindex].numpy()
            if not np.allclose(buggy_val, golden_val, atol=atol, rtol=rtol):
                error_info.append(self._error_stats(buggy_val, golden_val, atol))

        if self.verbose:
            if not error_info:
                print(f"{green('Passed. ✅')}")
            else:
                print(f"{red('Failed. ❌')}")
                for max_abs_error, max_abs_error_ratio in error_info:
                    print(
                        f"  - Max absolute error: {max_abs_error:.4f}, "
                        f"Max absolute error ratio: {max_abs_error_ratio * 100:.2f}%"
                    )
        self.counter += 1

    def __call__(self, func, name, before_run, ret_val, *args):
        """Instrument hook: validate ``name`` after the golden call returns."""
        if before_run:
            return
        if name.startswith("vm.builtin.") and name not in self.attention_builtin:
            return

        if self.skip_instrument(func, name, before_run, ret_val, *args):
            return

        new_args = []
        # not always true, true for most ops.
        ret_indices = (len(args) - 1,)
        # Host copies must stay alive until the device sync below completes.
        temp_args = []
        for i, arg in enumerate(args):
            if isinstance(arg, tvm.nd.NDArray):
                arr = tvm.nd.empty(arg.shape, arg.dtype, device=self.device)
                # Inputs are copied via the CPU since the golden and target
                # devices may differ; the output buffer is left uninitialized.
                if i not in ret_indices:
                    temp_cpu = arg.copyto(tvm.cpu())
                    temp_args.append(temp_cpu)
                    arr.copyfrom(temp_cpu)
                new_args.append(arr)
            elif self.golden_kv_cache is not None and self.golden_kv_cache == arg:
                new_args.append(self.buggy_kv_cache)
            else:
                new_args.append(arg)
        # wait until all copy complete before we release temp_cpu
        self.device.sync()
        self.compare(name, args, new_args, ret_indices)


def main():
    """Parse command-line options and run a single debug generation pass."""
    parser = argparse.ArgumentParser("TapML Chat Debug Tool")
    parser.add_argument(
        "model",
        type=str,
        help="An TapML model directory that contains `tapml-chat-config.json`",
    )
    # Mandatory (library, device) options for the reference ("golden") build
    # and the build under test ("buggy"), in help-output order.
    required_options = (
        (
            "--golden-lib",
            "The full path to the golden model library file to use (e.g. a ``.so`` file).",
        ),
        ("--golden-device", "The device to run the golden model on."),
        (
            "--buggy-lib",
            "The full path to the buggy model library file to use (e.g. a ``.so`` file).",
        ),
        ("--buggy-device", "The device to run the buggy model on."),
    )
    for flag, help_text in required_options:
        parser.add_argument(flag, type=str, help=help_text, required=True)
    parser.add_argument(
        "--prompt", type=str, help="The user input prompt.", default="Hello, world!"
    )
    parser.add_argument(
        "--skip-visited",
        action="store_true",
        help="Whether to skip visited functions.",
    )
    parser.add_argument(
        "--generate-len",
        type=int,
        help="The length of the generated text.",
        default=2,  # one prefill step; the remainder are decode steps
    )
    cli = parser.parse_args()
    debugger = DebugChat(
        model=cli.model,
        golden_lib=cli.golden_lib,
        golden_device=cli.golden_device,
        buggy_lib=cli.buggy_lib,
        buggy_device=cli.buggy_device,
        skip_visited=cli.skip_visited,
    )
    debugger.generate(cli.prompt, cli.generate_len)


if __name__ == "__main__":
    main()
