Source code for brevitas.core.bit_width.parameter

# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Optional

import torch
from torch import Tensor
from torch.nn import Module
from torch.nn import Parameter

import brevitas
import brevitas.config as config
from brevitas.core.function_wrapper import RoundSte
from brevitas.core.restrict_val import IntRestrictValue
from brevitas.function import abs_binary_sign_grad

MIN_INT_BIT_WIDTH = 2
NON_ZERO_EPSILON = 1e-6
REMOVE_ZERO_BIT_WIDTH = 0.1


class BitWidthParameter(brevitas.jit.ScriptModule):
    """
    ScriptModule that returns a learnable bit-width wrapped in a float torch.Tensor.

    Args:
        bit_width (int): value to initialize the output learned bit-width.
        min_bit_width (int): lower bound for the output learned bit-width. Default: 2.
        restrict_bit_width_impl: restrict the learned bit-width to a subset of values. Default: IntRestrictValue(RoundSte()).
        override_pretrained_bit_width (bool): ignore pretrained bit-width loaded from a state dict. Default: False.

    Returns:
        Tensor: bit-width wrapped in a float torch.Tensor and backed by a learnable torch.nn.Parameter.

    Raises:
        RuntimeError: if bit_width or min_bit_width is below the minimum of 2, or if bit_width < min_bit_width.

    Examples:
        >>> bit_width_parameter = BitWidthParameter(8)
        >>> bit_width_parameter()
        tensor(8., grad_fn=<RoundSteFnBackward>)

    Note:
        Set env variable BREVITAS_IGNORE_MISSING_KEYS=1 to avoid errors when retraining
        from a floating point state dict.

    Note:
        Maps to bit_width_impl_type == BitWidthImplType.PARAMETER == 'PARAMETER' == 'parameter'
        in higher-level APIs.
    """
    __constants__ = ['bit_width_base', 'override_pretrained']

    def __init__(
            self,
            bit_width: int,
            min_bit_width: int = MIN_INT_BIT_WIDTH,
            restrict_bit_width_impl: Module = IntRestrictValue(RoundSte()),
            override_pretrained_bit_width: bool = False,
            dtype: Optional[torch.dtype] = None,
            device: Optional[torch.device] = None) -> None:
        super(BitWidthParameter, self).__init__()

        if bit_width < MIN_INT_BIT_WIDTH:
            raise RuntimeError(
                "Int bit width has to be at least {}, instead is {}.".format(
                    MIN_INT_BIT_WIDTH, bit_width))
        if min_bit_width < MIN_INT_BIT_WIDTH:
            raise RuntimeError(
                "Min int bit width has to be at least {}, instead is {}.".format(
                    MIN_INT_BIT_WIDTH, min_bit_width))
        if bit_width < min_bit_width:
            raise RuntimeError(
                "Int bit width has to be at least the min bit width {}, instead is {}.".format(
                    min_bit_width, bit_width))

        bit_width = float(int(bit_width))
        min_bit_width = float(int(min_bit_width))
        bit_width_base = restrict_bit_width_impl.restrict_init_float(min_bit_width)
        bit_width = restrict_bit_width_impl.restrict_init_float(bit_width)
        bit_width_offset_init = bit_width - bit_width_base
        self.bit_width_offset = Parameter(
            torch.tensor(bit_width_offset_init, dtype=dtype, device=device))
        self.bit_width_base = bit_width_base
        self.restrict_bit_width_impl = restrict_bit_width_impl
        self.override_pretrained = override_pretrained_bit_width

    @brevitas.jit.script_method
    def forward(self) -> Tensor:
        bit_width = abs_binary_sign_grad(self.bit_width_offset) + self.bit_width_base
        bit_width = self.restrict_bit_width_impl(bit_width)
        return bit_width
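
    # Worked example (illustrative, assuming the default IntRestrictValue(RoundSte())
    # restriction): with BitWidthParameter(bit_width=8, min_bit_width=2),
    # bit_width_base == 2.0 and bit_width_offset is initialized to 6.0, so
    # forward() returns round(|6.0| + 2.0) == tensor(8.).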

    def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
            error_msgs):
        bit_width_const_key = prefix + 'bit_width'
        bit_width_offset_key = prefix + 'bit_width_offset'
        if bit_width_const_key in state_dict:
            assert bit_width_offset_key not in state_dict, \
                "State dict should not contain both a constant and a learned bit-width."
            bit_width = state_dict[bit_width_const_key]
            state_dict[bit_width_offset_key] = bit_width - self.bit_width_base
            del state_dict[bit_width_const_key]
        if self.override_pretrained and bit_width_offset_key in state_dict:
            del state_dict[bit_width_offset_key]
        super(BitWidthParameter, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        if (config.IGNORE_MISSING_KEYS or
                self.override_pretrained) and bit_width_offset_key in missing_keys:
            missing_keys.remove(bit_width_offset_key)
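
# Usage sketch (illustrative, not part of the library source): the learned
# bit-width can be optimized directly, since bit_width_offset is a regular
# torch.nn.Parameter and RoundSte passes gradients straight through the
# rounding. The 4-bit target below is an arbitrary value chosen for the example.
#
#   bit_width_parameter = BitWidthParameter(bit_width=8)
#   optimizer = torch.optim.SGD(bit_width_parameter.parameters(), lr=0.1)
#   loss = (bit_width_parameter() - 4.0) ** 2  # pull the learned bit-width towards 4
#   loss.backward()  # gradient reaches bit_width_offset via the straight-through estimator
#   optimizer.step()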

class RemoveBitwidthParameter(brevitas.jit.ScriptModule):
    __constants__ = ['non_zero_epsilon', 'override_pretrained']

    def __init__(
            self,
            bit_width_to_remove: int,
            override_pretrained_bit_width: bool = False,
            non_zero_epsilon: float = NON_ZERO_EPSILON,
            remove_zero_bit_width=REMOVE_ZERO_BIT_WIDTH,
            dtype: Optional[torch.dtype] = None,
            device: Optional[torch.device] = None):
        super(RemoveBitwidthParameter, self).__init__()
        if bit_width_to_remove < 0:
            raise RuntimeError(
                "Bit width to remove has to be >= 0, instead is {}.".format(bit_width_to_remove))
        elif bit_width_to_remove == 0:
            bit_width_coeff_init = 1 / remove_zero_bit_width
        else:
            bit_width_coeff_init = 1 / bit_width_to_remove
        self.bit_width_coeff = Parameter(
            torch.tensor(bit_width_coeff_init, dtype=dtype, device=device))
        self.non_zero_epsilon = non_zero_epsilon
        self.override_pretrained = override_pretrained_bit_width

    @brevitas.jit.script_method
    def forward(self) -> Tensor:
        bit_width_to_remove = 1.0 / (self.non_zero_epsilon + torch.abs(self.bit_width_coeff))
        return bit_width_to_remove

    def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
            error_msgs):
        bit_width_coeff_key = prefix + 'bit_width_coeff'
        if self.override_pretrained and bit_width_coeff_key in state_dict:
            del state_dict[bit_width_coeff_key]
        super(RemoveBitwidthParameter, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        if (config.IGNORE_MISSING_KEYS or
                self.override_pretrained) and bit_width_coeff_key in missing_keys:
            missing_keys.remove(bit_width_coeff_key)
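
# Usage sketch (illustrative, not part of the library source): the number of
# bits to remove is parameterized as the reciprocal of a learnable coefficient,
# so the output stays strictly positive. Outputs below are approximate.
#
#   remove_bit_width = RemoveBitwidthParameter(bit_width_to_remove=2)
#   remove_bit_width()  # ~tensor(2.), since 1 / (1e-6 + |0.5|) ≈ 2
#   RemoveBitwidthParameter(bit_width_to_remove=0)()  # ~tensor(0.1), from REMOVE_ZERO_BIT_WIDTH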