"""
This file specifies how TAPML's BERT parameters map from other formats, for example HuggingFace
PyTorch, HuggingFace safetensors.
"""

import functools

import numpy as np

from tapml.loader import ExternMapping
from tapml.quantization import Quantization

from .bert_model import BertConfig, BertModel


def huggingface(model_config: BertConfig, quantization: Quantization) -> ExternMapping:
    """Build the mapping from TAPML BERT parameter names to HuggingFace PyTorch names.

    For every encoder layer, the fused ``qkv`` projection (both its ``weight`` and its
    ``bias``) is assembled by concatenating the HuggingFace ``query``, ``key`` and
    ``value`` tensors along axis 0. Every other TAPML parameter is sourced one-to-one
    from the HuggingFace parameter with the identical name. In all cases the source
    data is cast to the dtype of the TAPML parameter it populates.

    Parameters
    ----------
    model_config : BertConfig
        The configuration of the BERT model.

    quantization : Quantization
        The quantization configuration. When it is not ``None``, the model is moved to
        ``quantization.model_dtype`` before its parameters are exported. The dtypes
        recorded in the mapping are therefore read after that conversion.

    Returns
    -------
    param_map : ExternMapping
        The parameter mapping from TAPML to HuggingFace PyTorch.
    """

    def _concat_qkv(q, k, v, dtype):
        # Stack the separate Q/K/V tensors into the single fused TAPML tensor.
        return np.concatenate([q, k, v], axis=0).astype(dtype)

    def _cast(x, dtype):
        return x.astype(dtype)

    model = BertModel(model_config)
    if quantization is not None:
        model.to(quantization.model_dtype)
    _, exported_params, _ = model.export_tvm(  # type: ignore[misc]
        spec=model.get_default_spec(),
        allow_extern=True,
    )
    params = dict(exported_params)

    mapping = ExternMapping()

    # Fused attention projections: one TAPML tensor is built from three HF tensors.
    for layer_idx in range(model_config.num_hidden_layers):
        prefix = f"encoder.layer.{layer_idx}.attention.self"
        for suffix in ("weight", "bias"):
            target = f"{prefix}.qkv.{suffix}"
            target_dtype = params[target].dtype
            sources = [f"{prefix}.{proj}.{suffix}" for proj in ("query", "key", "value")]
            mapping.add_mapping(
                target,
                sources,
                functools.partial(_concat_qkv, dtype=target_dtype),
            )

    # Any parameter not handled above shares its name with the HF checkpoint.
    for name, param in params.items():
        if name in mapping.param_map:
            continue
        mapping.add_mapping(
            name,
            [name],
            functools.partial(_cast, dtype=param.dtype),
        )

    return mapping
