"""This file specifies how TAPML's Baichuan parameters are quantized using group quantization
or other formats."""

from typing import Tuple

from tvm.relax.frontend import nn

from tapml.loader import QuantizeMapping
from tapml.quantization import FTQuantize, GroupQuantize, NoQuantize

from .baichuan_model import BaichuanConfig, BaichuanForCausalLM


def group_quant(
    model_config: BaichuanConfig,
    quantization: GroupQuantize,
) -> Tuple[nn.Module, QuantizeMapping]:
    """Quantize a BaichuanLM-architecture model using group quantization."""
    model: nn.Module = BaichuanForCausalLM(model_config)
    model.to(quantization.model_dtype)
    quant_map = QuantizeMapping({}, {})
    quantization.tensor_parallel_shards = model_config.tensor_parallel_shards
    model = quantization.quantize_model(
        model,
        quant_map,
        "",
    )
    return model, quant_map


def ft_quant(
    model_config: BaichuanConfig,
    quantization: FTQuantize,
) -> Tuple[nn.Module, QuantizeMapping]:
    """Quantize a BaichuanLM-architecture model using FasterTransformer quantization."""
    model: nn.Module = BaichuanForCausalLM(model_config)
    model.to(quantization.model_dtype)
    quant_map = QuantizeMapping({}, {})
    model = quantization.quantize_model(
        model,
        quant_map,
        "",
    )
    return model, quant_map


def no_quant(
    model_config: BaichuanConfig,
    quantization: NoQuantize,
) -> Tuple[nn.Module, QuantizeMapping]:
    """Quantize a BaichuanLM model without quantization."""
    model: nn.Module = BaichuanForCausalLM(model_config)
    model.to(quantization.model_dtype)
    quant_map = QuantizeMapping({}, {})
    return model, quant_map
