Source code for brevitas.core.quant.binary

# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Tuple

import torch
from torch import Tensor
from torch.nn import Module

import brevitas
from brevitas.core.bit_width import BitWidthConst
from brevitas.core.function_wrapper import TensorClamp
from brevitas.core.quant.delay import DelayWrapper
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops_ste import binary_sign_ste


class BinaryQuant(brevitas.jit.ScriptModule):
    """
    ScriptModule that implements scaled uniform binary quantization of an input tensor.
    Quantization is performed with :func:`~brevitas.function.ops_ste.binary_sign_ste`.

    Args:
        scaling_impl (Module): Module that returns a scale factor.
        signed (bool): Whether the quantized output is signed. Only True is supported. Default: True
        quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0

    Returns:
        Tuple[Tensor, Tensor, Tensor, Tensor]: Quantized output in de-quantized format, scale,
        zero-point, bit_width.

    Examples:
        >>> from brevitas.core.scaling import ConstScaling
        >>> binary_quant = BinaryQuant(ConstScaling(0.1))
        >>> inp = torch.Tensor([0.04, -0.6, 3.3])
        >>> out, scale, zero_point, bit_width = binary_quant(inp)
        >>> out
        tensor([ 0.1000, -0.1000,  0.1000])
        >>> scale
        tensor(0.1000)
        >>> zero_point
        tensor(0.)
        >>> bit_width
        tensor(1.)

    Note:
        Maps to quant_type == QuantType.BINARY == 'BINARY' == 'binary' when applied to weights
        in higher-level APIs.

    Note:
        Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
    """

    def __init__(self, scaling_impl: Module, signed: bool = True, quant_delay_steps: int = 0):
        super(BinaryQuant, self).__init__()
        assert signed, "Unsigned binary quant not supported"
        self.scaling_impl = scaling_impl
        self.bit_width = BitWidthConst(1)
        self.zero_point = StatelessBuffer(torch.tensor(0.0))
        self.delay_wrapper = DelayWrapper(quant_delay_steps)

    @brevitas.jit.script_method
    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        scale = self.scaling_impl(x)
        y = binary_sign_ste(x) * scale
        y = self.delay_wrapper(x, y)
        return y, scale, self.zero_point(), self.bit_width()
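

# A minimal usage sketch, added here for illustration (not part of the Brevitas
# source file). It binarizes a small tensor with the constant scale from the
# docstring example: every output element is exactly +scale or -scale, and
# because binary_sign_ste is a straight-through estimator, gradients pass
# through the sign and are scaled only by the (constant) scale factor.
if __name__ == "__main__":
    from brevitas.core.scaling import ConstScaling

    quant = BinaryQuant(ConstScaling(0.1))
    w = torch.tensor([0.04, -0.6, 3.3], requires_grad=True)
    w_q, scale, zero_point, bit_width = quant(w)
    assert torch.all(w_q.abs() == scale)  # outputs are exactly +/- scale
    w_q.sum().backward()
    # Expected: tensor([0.1000, 0.1000, 0.1000]) -- the STE passes gradients
    # through unchanged, so no gradient is zeroed (contrast with the clamped
    # variant below).
    print(w.grad)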


class ClampedBinaryQuant(brevitas.jit.ScriptModule):
    """
    ScriptModule that implements scaled uniform binary quantization of an input tensor.
    Before quantization, the input tensor is clamped between (-scale, scale), which on the
    backward pass zeroes the gradients corresponding to inputs outside that range.
    Quantization is performed with :func:`~brevitas.function.ops_ste.binary_sign_ste`.

    Args:
        scaling_impl (Module): Module that returns a scale factor.
        tensor_clamp_impl (Module): Module that performs tensor-wise clamping. Default: TensorClamp()
        quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0

    Returns:
        Tuple[Tensor, Tensor, Tensor, Tensor]: Quantized output in de-quantized format, scale,
        zero-point, bit_width.

    Examples:
        >>> from brevitas.core.scaling import ConstScaling
        >>> binary_quant = ClampedBinaryQuant(ConstScaling(0.1))
        >>> inp = torch.Tensor([0.04, -0.6, 3.3]).requires_grad_(True)
        >>> out, scale, zero_point, bit_width = binary_quant(inp)
        >>> out
        tensor([ 0.1000, -0.1000,  0.1000], grad_fn=<MulBackward0>)
        >>> out.backward(torch.Tensor([1.0, 1.0, 1.0]))
        >>> inp.grad
        tensor([0.1000, 0.0000, 0.0000])
        >>> scale
        tensor(0.1000)
        >>> zero_point
        tensor(0.)
        >>> bit_width
        tensor(1.)

    Note:
        Maps to quant_type == QuantType.BINARY == 'BINARY' == 'binary' when applied to
        activations in higher-level APIs.

    Note:
        Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
    """

    def __init__(
            self,
            scaling_impl: Module,
            tensor_clamp_impl: Module = TensorClamp(),
            quant_delay_steps: int = 0):
        super(ClampedBinaryQuant, self).__init__()
        self.scaling_impl = scaling_impl
        self.bit_width = BitWidthConst(1)
        self.zero_point = StatelessBuffer(torch.tensor(0.0))
        self.delay_wrapper = DelayWrapper(quant_delay_steps)
        self.tensor_clamp_impl = tensor_clamp_impl

    @brevitas.jit.script_method
    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        scale = self.scaling_impl(x)
        y = self.tensor_clamp_impl(x, -scale, scale)
        y = binary_sign_ste(y) * scale
        y = self.delay_wrapper(x, y)
        return y, scale, self.zero_point(), self.bit_width()
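

# A minimal sketch, added for illustration (not part of the Brevitas source
# file), reproducing the docstring example above: the clamp zeroes the
# gradient for the two inputs outside (-scale, scale), whereas plain
# BinaryQuant would have passed all three gradients straight through.
if __name__ == "__main__":
    from brevitas.core.scaling import ConstScaling

    clamped = ClampedBinaryQuant(ConstScaling(0.1))
    x = torch.tensor([0.04, -0.6, 3.3], requires_grad=True)
    y, scale, zero_point, bit_width = clamped(x)
    y.backward(torch.ones_like(y))
    # Only x[0] lies inside (-0.1, 0.1); the clamp blocks the other gradients.
    print(x.grad)  # tensor([0.1000, 0.0000, 0.0000])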