# Source code for brevitas.core.quant.ternary

# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Optional
from typing import Tuple

import torch
from torch import Tensor
from torch.nn import Module

import brevitas
from brevitas.core.bit_width import BitWidthConst
from brevitas.core.quant.delay import DelayWrapper
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops_ste import ternary_sign_ste


class TernaryQuant(brevitas.jit.ScriptModule):
    """
    ScriptModule that implements scaled uniform ternary quantization of an input tensor.
    Quantization is performed with :func:`~brevitas.function.ops_ste.ternary_sign_ste`.

    Elements whose magnitude is strictly greater than ``threshold * scale`` are mapped to
    ``ternary_sign_ste(x) * scale``. All other elements are mapped to zero. The output keeps
    the dtype of the input, up to type promotion with ``scale``.

    Args:
        scaling_impl (Module): Module that returns a scale factor.
        threshold (float): Ternarization threshold, relative to the scale factor.
        quant_delay_steps (Optional[int]): Number of training steps to delay quantization
            for. Forwarded unchanged to
            :class:`~brevitas.core.quant.delay.DelayWrapper`. Default: None (intended to
            behave like 0, i.e. no delay).

    Returns:
        Tuple[Tensor, Tensor, Tensor, Tensor]: Quantized output in de-quantized format,
        scale, zero-point, bit_width.

    Examples:
        >>> from brevitas.core.scaling import ConstScaling
        >>> ternary_quant = TernaryQuant(ConstScaling(1.0), 0.5)
        >>> inp = torch.Tensor([0.04, -0.6, 3.3])
        >>> out, scale, zero_point, bit_width = ternary_quant(inp)
        >>> out
        tensor([ 0., -1.,  1.])
        >>> scale
        tensor(1.)
        >>> zero_point
        tensor(0.)
        >>> bit_width
        tensor(2.)

    Note:
        Maps to quant_type == QuantType.TERNARY == 'TERNARY' == 'ternary' in higher-level APIs.

    Note:
        Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
    """
    __constants__ = ['threshold']

    def __init__(
            self, scaling_impl: Module, threshold: float, quant_delay_steps: Optional[int] = None):
        super(TernaryQuant, self).__init__()
        self.scaling_impl = scaling_impl
        self.threshold = threshold
        # Ternary values {-1, 0, +1} are reported with a constant bit width of 2.
        self.bit_width = BitWidthConst(2)
        # Ternary quantization here is symmetric, so the zero-point is a constant 0.
        self.zero_point = StatelessBuffer(torch.tensor(0.0))
        self.delay_wrapper = DelayWrapper(quant_delay_steps)

    @brevitas.jit.script_method
    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Ternarize ``x``.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tuple[Tensor, Tensor, Tensor, Tensor]: (de-quantized output, scale, zero_point,
            bit_width).
        """
        scale = self.scaling_impl(x)
        # True where the element survives ternarization (|x| > threshold * scale).
        mask = x.abs().gt(self.threshold * scale)
        # Cast the boolean mask to x's dtype instead of hard-coding float32. Hard-coding it
        # would silently promote float16/bfloat16 inputs to float32. It would also make the
        # output dtype differ between the delayed phase (x returned as-is) and the
        # quantized phase.
        y = mask.to(x.dtype) * ternary_sign_ste(x)
        y = y * scale
        # While quantization is delayed, the wrapper may return x instead of y.
        y = self.delay_wrapper(x, y)
        return y, scale, self.zero_point(), self.bit_width()