"""
This file specifies how TAPML's GPTBigCode parameters map from other formats, for example HuggingFace
PyTorch and HuggingFace safetensors.
"""

import functools

from tapml.loader import ExternMapping
from tapml.quantization import Quantization

from .gpt_bigcode_model import GPTBigCodeConfig, GPTBigCodeForCausalLM


def huggingface(model_config: GPTBigCodeConfig, quantization: Quantization) -> ExternMapping:
    """Build the name mapping from TAPML GPTBigCode parameters to HuggingFace PyTorch ones.

    A GPTBigCode model is instantiated and exported so that its full set of named
    parameters can be enumerated. Each TAPML parameter is then mapped to the source
    parameter carrying the *same* name. The only transformation applied while
    loading is an ``astype`` cast to the dtype of the corresponding TAPML parameter.

    Parameters
    ----------
    model_config : GPTBigCodeConfig
        Configuration used to instantiate the GPTBigCode model whose parameters are mapped.

    quantization : Quantization
        The quantization configuration. When it is not ``None``, the model is first
        converted to ``quantization.model_dtype``, so the cast targets that dtype.
        When it is ``None``, the model keeps the dtype it was constructed with.

    Returns
    -------
    param_map : ExternMapping
        The parameter mapping from TAPML to HuggingFace PyTorch.
    """

    def _cast(x, dtype):
        # Convert a loaded source tensor to the dtype the TAPML parameter expects.
        return x.astype(dtype)

    causal_lm = GPTBigCodeForCausalLM(model_config)
    if quantization is not None:
        causal_lm.to(quantization.model_dtype)

    # Exporting is only done to obtain the named parameters; the other outputs are unused.
    _, exported_params, _ = causal_lm.export_tvm(  # type: ignore[misc]
        spec=causal_lm.get_default_spec(),
        allow_extern=True,
    )
    params_by_name = dict(exported_params)

    param_mapping = ExternMapping()

    for name, param in params_by_name.items():
        if name in param_mapping.param_map:
            # Never overwrite an entry that is already registered.
            continue
        # Identity name mapping: the source tensor has the same name as the TAPML parameter.
        param_mapping.add_mapping(name, [name], functools.partial(_cast, dtype=param.dtype))

    return param_mapping
