brevitas.proxy package

Submodules

brevitas.proxy.parameter_quant module

class brevitas.proxy.parameter_quant.WeightQuantProxy(bit_width, quant_type, narrow_range, scaling_override, restrict_scaling_type, scaling_const, scaling_stats_op, scaling_impl_type, scaling_stats_reduce_dim, scaling_shape, scaling_min_val, bit_width_impl_type, restrict_bit_width_type, min_overall_bit_width, max_overall_bit_width, tracked_parameter_list_init, bit_width_impl_override, scaling_stats_input_view_shape_impl, scaling_stats_input_concat_dim, ternary_threshold, scaling_stats_sigma, override_pretrained_bit_width)

Bases: brevitas.proxy.parameter_quant.ParameterQuantProxy

Parameters
  • bit_width (Optional[int]) – The bit-width to which weights are quantized. If bit_width_impl_type is set to PARAMETER, this value is used for initialization. If quant_type is set to FP, this value is ignored.

  • quant_type (QuantType) – Type of quantization. If set to FP, no quantization is performed.

  • narrow_range (bool) – Restrict the range of quantized values to an interval symmetric around 0. For example, with bit_width set to 8 and quant_type set to INT, quantized values lie in [-127, 127] if narrow_range is True and in [-128, 127] if it is False.

  • restrict_scaling_type (RestrictValueType) – Type of restriction imposed on the values of the scaling factor of the quantized weights.

  • scaling_const (Optional[float]) – If scaling_impl_type is set to CONST, this value is used as the scaling factor across all relevant dimensions. Ignored otherwise.

  • scaling_stats_op (StatsOp) – Type of statistical operation performed for scaling, if required. If scaling_impl_type is set to STATS or AFFINE_STATS, the operation is part of the compute graph and back-propagated through. If scaling_impl_type is set to PARAMETER_FROM_STATS, the operation is used only for computing the initialization of the parameter, possibly across some dimensions. Ignored otherwise.

  • scaling_impl_type (ScalingImplType) – Type of strategy adopted for scaling the quantized weights.

  • scaling_stats_reduce_dim (Optional[int]) – Dimension within the shape determined by scaling_stats_input_view_shape_impl along which scaling_stats_op is applied. If set to None, scaling is assumed to be over the whole tensor. Ignored whenever scaling_stats_op is ignored.

  • scaling_shape (Tuple[int, …]) – Shape of the scaling factor tensor. This is required to be broadcastable w.r.t. the weight tensor to scale.

  • scaling_min_val (Optional[float]) – Minimum value that the scaling factors can reach. This has precedence over anything else, including scaling_const when scaling_impl_type is set to CONST. Useful in case of numerical instabilities. If set to None, no minimum is imposed.

  • bit_width_impl_type (Optional[BitWidthImplType]) – Strategy used to implement the bit-width of the quantized weights (for example, a constant or a learned PARAMETER) when quant_type is set to INT. Ignored otherwise.

  • restrict_bit_width_type (Optional[RestrictValueType]) – If bit_width_impl_type is set to PARAMETER and quant_type is set to INT, this value constrains or relaxes the values that the learned bit-width can take. Ignored otherwise.

  • min_overall_bit_width (Optional[int]) – If bit_width_impl_type is set to PARAMETER and quant_type is set to INT, this value imposes a lower bound on the learned value. Ignored otherwise.

  • max_overall_bit_width (Optional[int]) – If bit_width_impl_type is set to PARAMETER and quant_type is set to INT, this value imposes an upper bound on the learned value. Ignored otherwise.

  • tracked_parameter_list_init (Parameter) – PyTorch Parameter whose statistics are computed when scaling_impl_type is set to STATS, AFFINE_STATS or PARAMETER_FROM_STATS. This value initializes the list of parameters that are concatenated together when computing statistics.

  • bit_width_impl_override (Union[BitWidthConst, BitWidthParameter, None]) – Override the bit-width implementation with an implementation defined elsewhere. Accepts BitWidthConst or BitWidthParameter type of Modules. Useful for sharing the same learned bit-width between different layers.

  • scaling_stats_input_view_shape_impl (Module) – When scaling_impl_type is set to STATS, AFFINE_STATS or PARAMETER_FROM_STATS, this Module reshapes each tracked parameter before concatenating them together and computing their statistics.

  • scaling_stats_input_concat_dim (int) – When scaling_impl_type is set to STATS, AFFINE_STATS or PARAMETER_FROM_STATS, this value defines the dimension along which the tracked parameters are concatenated after scaling_stats_input_view_shape_impl is applied and before statistics are computed.

  • ternary_threshold (Optional[float]) – Value to be used as a threshold when quant_type is set to TERNARY. Ignored otherwise.

  • scaling_stats_sigma (Optional[float]) – Value to be used as sigma if scaling_impl_type is set to STATS, AFFINE_STATS or PARAMETER_FROM_STATS and scaling_stats_op is set to AVE_SIGMA_STD or AVE_LEARN_SIGMA_STD. Ignored otherwise. When scaling_impl_type is set to STATS or AFFINE_STATS, and scaling_stats_op is set to AVE_LEARN_SIGMA_STD, the value is used for initialization.

  • override_pretrained_bit_width (bool) – If set to True, when loading a pre-trained model that includes a learned bit-width, the pre-trained value is ignored and replaced by the value specified by bit_width.
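
The integer range selected by narrow_range, and the precedence of scaling_min_val over a computed scale, can be sketched in plain Python. This is a minimal illustration assuming signed INT quantization with round-to-nearest; int_range and fake_quant are hypothetical helpers, not part of the brevitas API.

```python
from typing import Optional, Tuple


def int_range(bit_width: int, narrow_range: bool) -> Tuple[int, int]:
    """Signed integer bounds for a given bit-width (hypothetical helper)."""
    max_int = 2 ** (bit_width - 1) - 1
    # narrow_range drops the most negative value so the range is symmetric.
    min_int = -max_int if narrow_range else -(2 ** (bit_width - 1))
    return min_int, max_int


def fake_quant(x: float, scale: float, bit_width: int, narrow_range: bool,
               scaling_min_val: Optional[float] = None) -> float:
    """Quantize-dequantize a single value (hypothetical helper)."""
    if scaling_min_val is not None:
        # scaling_min_val takes precedence over whatever scale was computed.
        scale = max(scale, scaling_min_val)
    min_int, max_int = int_range(bit_width, narrow_range)
    # Python's round() rounds halves to the nearest even integer.
    q = min(max(round(x / scale), min_int), max_int)
    return q * scale


print(int_range(8, narrow_range=True))   # (-127, 127)
print(int_range(8, narrow_range=False))  # (-128, 127)
print(fake_quant(-200.0, 1.0, 8, narrow_range=True))  # -127.0
```
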

add_tracked_parameter(x)
Return type

None
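
The interplay of tracked_parameter_list_init, add_tracked_parameter, scaling_stats_input_concat_dim and a max-style scaling_stats_op can be pictured with nested lists. This sketch assumes each tracked parameter has already been reshaped to (output_channels, -1) by scaling_stats_input_view_shape_impl, is concatenated along dimension 1, and is reduced along dimension 1 with the maximum absolute value. Dividing the statistic by the largest integer is an assumed convention for turning it into a scale; per_channel_max_abs is a hypothetical helper.

```python
from typing import List

Matrix = List[List[float]]  # one row per output channel


def per_channel_max_abs(tracked: List[Matrix]) -> List[float]:
    n_out = len(tracked[0])
    if any(len(p) != n_out for p in tracked):
        raise ValueError("tracked parameters must share the output-channel dimension")
    # Concatenate along dim 1: join the rows of every parameter, channel by channel.
    concat = [[v for p in tracked for v in p[c]] for c in range(n_out)]
    # Reduce along dim 1 with a max-of-absolute-values statistic.
    return [max(abs(v) for v in row) for row in concat]


w1 = [[0.5, -1.0], [0.2, 0.1]]
w2 = [[0.3], [-0.9]]
tracked = [w1]       # plays the role of tracked_parameter_list_init
tracked.append(w2)   # what add_tracked_parameter(w2) is assumed to do

stats = per_channel_max_abs(tracked)   # [1.0, 0.9]
scales = [s / 127 for s in stats]      # assumed: 8-bit, narrow_range=True
```
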

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Return type

Tuple[Tensor, Tensor, Tensor]

int_weight(x)
re_init_tensor_quant()
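
forward is documented as returning three tensors, and int_weight as producing integer weights. A plausible reading, assumed here and not confirmed by this page, is that forward returns (quantized weight, scale, bit-width) and that int_weight divides the quantized weight by its scale. The sketch models that relation with lists; weight_quant_forward and int_weight below are hypothetical stand-ins, not the brevitas methods.

```python
from typing import List, Tuple


def weight_quant_forward(w: List[float], scale: float, bit_width: int,
                         narrow_range: bool = True) -> Tuple[List[float], float, int]:
    max_int = 2 ** (bit_width - 1) - 1
    min_int = -max_int if narrow_range else -max_int - 1
    quant = [min(max(round(v / scale), min_int), max_int) * scale for v in w]
    return quant, scale, bit_width


def int_weight(w: List[float], scale: float, bit_width: int) -> List[int]:
    quant, scale, _ = weight_quant_forward(w, scale, bit_width)
    # Dividing out the scale recovers the integers; round() absorbs float error.
    return [round(q / scale) for q in quant]


quant, scale, bits = weight_quant_forward([0.5, -0.26, 3.0], scale=0.25, bit_width=4)
print(quant)                                                    # [0.5, -0.25, 1.75]
print(int_weight([0.5, -0.26, 3.0], scale=0.25, bit_width=4))  # [2, -1, 7]
```
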
class brevitas.proxy.parameter_quant.BiasQuantProxy(quant_type, bit_width, narrow_range)

Bases: brevitas.proxy.parameter_quant.ParameterQuantProxy

forward(x, input_scale, input_bit_width)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Return type

Tuple[Tensor, Optional[Tensor], Optional[Tensor]]
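
BiasQuantProxy receives its scale and bit-width from the caller instead of computing them. A minimal sketch, assuming the bias is rounded on the grid defined by input_scale and clamped to a signed input_bit_width range, and that the Optional outputs are None when no scale is supplied; bias_quant and that fallback are hypothetical.

```python
from typing import List, Optional, Tuple


def bias_quant(bias: List[float], input_scale: Optional[float],
               input_bit_width: Optional[int], narrow_range: bool = False
               ) -> Tuple[List[float], Optional[float], Optional[int]]:
    if input_scale is None or input_bit_width is None:
        # Assumed floating-point fallback: no quantization, no scale, no bit-width.
        return bias, None, None
    max_int = 2 ** (input_bit_width - 1) - 1
    min_int = -max_int if narrow_range else -max_int - 1
    ints = [min(max(round(b / input_scale), min_int), max_int) for b in bias]
    return [i * input_scale for i in ints], input_scale, input_bit_width


# Assumed: input_scale is the accumulator scale (input scale times weight scale).
q, s, bw = bias_quant([0.1, -0.05, 5.0], input_scale=0.01, input_bit_width=8)
```

The third entry saturates at 127 * 0.01, since 5.0 falls outside the 8-bit grid at that scale.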

brevitas.proxy.quant_proxy module

class brevitas.proxy.quant_proxy.QuantProxy

Bases: torch.nn.modules.module.Module

state_dict(destination=None, prefix='', keep_vars=False)

Returns a dictionary containing a whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names.

Returns

a dictionary containing a whole state of the module

Return type

dict

Example:

>>> module.state_dict().keys()
['bias', 'weight']

brevitas.proxy.runtime_quant module

class brevitas.proxy.runtime_quant.ActivationQuantProxy(activation_impl, bit_width, signed, narrow_range, min_val, max_val, quant_type, float_to_int_impl_type, scaling_override, scaling_impl_type, scaling_per_channel, scaling_min_val, scaling_stats_sigma, scaling_stats_op, scaling_stats_buffer_momentum, scaling_stats_permute_dims, per_channel_broadcastable_shape, min_overall_bit_width, max_overall_bit_width, bit_width_impl_override, bit_width_impl_type, restrict_bit_width_type, restrict_scaling_type, override_pretrained_bit_width)

Bases: brevitas.proxy.quant_proxy.QuantProxy

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
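
ActivationQuantProxy wraps an activation_impl together with a quantizer. As an illustration only, the sketch combines a ReLU with unsigned quantization whose constant scale is max_val divided by the largest unsigned integer; that scale rule, the absence of learned scaling, and the helper name relu_quant are assumptions, not a description of the brevitas internals.

```python
from typing import List


def relu_quant(x: List[float], max_val: float, bit_width: int) -> List[float]:
    max_int = 2 ** bit_width - 1          # unsigned range [0, max_int]
    scale = max_val / max_int             # assumed constant-scale rule
    out = []
    for v in x:
        v = max(v, 0.0)                   # activation_impl: ReLU
        q = min(round(v / scale), max_int)  # values above max_val saturate
        out.append(q * scale)
    return out


print(relu_quant([-1.0, 0.5, 10.0], max_val=6.0, bit_width=4))
```
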

class brevitas.proxy.runtime_quant.ClampQuantProxy(signed, narrow_range, quant_type, ms_bit_width_to_clamp, clamp_at_least_init_val, min_overall_bit_width, max_overall_bit_width, msb_clamp_bit_width_impl_type, override_pretrained_bit_width)

Bases: brevitas.proxy.quant_proxy.QuantProxy

forward(x, input_scale, input_bit_width)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
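
ClampQuantProxy takes the scale and bit-width of its input. Judging from its parameter names, it narrows the bit-width by clamping away most-significant bits. The sketch assumes ms_bit_width_to_clamp is the number of MSBs removed, that values saturate to the narrower signed range, and that the scale is left unchanged; msb_clamp is a hypothetical helper.

```python
from typing import List, Tuple


def msb_clamp(x: List[float], input_scale: float, input_bit_width: int,
              ms_bits_to_clamp: int, narrow_range: bool = False
              ) -> Tuple[List[float], float, int]:
    out_bit_width = input_bit_width - ms_bits_to_clamp
    max_int = 2 ** (out_bit_width - 1) - 1
    min_int = -max_int if narrow_range else -max_int - 1
    ints = [min(max(round(v / input_scale), min_int), max_int) for v in x]
    # The scale is unchanged; only the representable range shrinks.
    return [i * input_scale for i in ints], input_scale, out_bit_width


y, s, bw = msb_clamp([3.0, -20.0, 100.0], input_scale=1.0,
                     input_bit_width=8, ms_bits_to_clamp=2)
# y == [3.0, -20.0, 31.0], bw == 6
```
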

class brevitas.proxy.runtime_quant.FusedActivationQuantProxy(activation_impl, tensor_quant)

Bases: torch.jit.ScriptModule

class brevitas.proxy.runtime_quant.TruncQuantProxy(signed, quant_type, ls_bit_width_to_trunc, trunc_at_least_init_val, min_overall_bit_width, max_overall_bit_width, lsb_trunc_bit_width_impl_type, explicit_rescaling, override_pretrained_bit_width)

Bases: brevitas.proxy.quant_proxy.QuantProxy

forward(x, input_scale, input_bit_width)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
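
TruncQuantProxy is the least-significant-bit counterpart. A sketch, assuming ls_bit_width_to_trunc is the number of LSBs dropped, that dropping uses a floor (an arithmetic right shift), and that the output scale grows by 2 ** ls_bit_width_to_trunc so values keep their magnitude. How explicit_rescaling changes this is not documented on this page and is not modeled; lsb_trunc is a hypothetical helper.

```python
from typing import List, Tuple


def lsb_trunc(x: List[float], input_scale: float, input_bit_width: int,
              ls_bits_to_trunc: int) -> Tuple[List[float], float, int]:
    ints = [round(v / input_scale) for v in x]
    # Arithmetic right shift is floor division by 2 ** ls_bits_to_trunc.
    truncated = [i >> ls_bits_to_trunc for i in ints]
    out_scale = input_scale * 2 ** ls_bits_to_trunc
    return ([t * out_scale for t in truncated], out_scale,
            input_bit_width - ls_bits_to_trunc)


y, s, bw = lsb_trunc([13.0, -5.0], input_scale=1.0,
                     input_bit_width=8, ls_bits_to_trunc=2)
# y == [12.0, -8.0], s == 4.0, bw == 6
```

A floor keeps the operation a pure shift on integer hardware; round-to-nearest would need an extra addition before the shift.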

Module contents