brevitas.nn package

Submodules

brevitas.nn.hadamard_classifier module

class brevitas.nn.hadamard_classifier.HadamardClassifier(in_channels, out_channels, fixed_scale=False, compute_output_scale=False, compute_output_bit_width=False, return_quant_tensor=False)

Bases: brevitas.nn.quant_layer.QuantLayer, torch.nn.modules.module.Module

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this function, since the former takes care of running the registered hooks while the latter silently ignores them.

max_output_bit_width(input_bit_width)
state_dict(destination=None, prefix='', keep_vars=False)

Returns a dictionary containing the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names.

Returns

a dictionary containing the whole state of the module

Return type

dict

Example:

>>> module.state_dict().keys()
['bias', 'weight']
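
A minimal usage sketch of HadamardClassifier as a classification head (illustrative; it assumes a flattened feature input of size in_channels):

>>> import torch
>>> from brevitas.nn.hadamard_classifier import HadamardClassifier
>>> head = HadamardClassifier(in_channels=512, out_channels=10)
>>> logits = head(torch.randn(4, 512))  # shape (4, 10)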

brevitas.nn.quant_accumulator module

class brevitas.nn.quant_accumulator.ClampQuantAccumulator(ms_bit_width_to_clamp=0, signed=True, narrow_range=True, min_overall_bit_width=2, max_overall_bit_width=32, quant_type=<QuantType.INT: 'INT'>, msb_clamp_bit_width_impl_type=<BitWidthImplType.CONST: 'CONST'>, per_elem_ops=None, clamp_at_least_init_val=False, override_pretrained_bit_width=False)

Bases: brevitas.nn.quant_accumulator.QuantAccumulator

class brevitas.nn.quant_accumulator.QuantAccumulator

Bases: brevitas.nn.quant_layer.QuantLayer, torch.nn.modules.module.Module

property acc_quant_proxy
forward(input)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this function, since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.nn.quant_accumulator.TruncQuantAccumulator(ls_bit_width_to_trunc=0, signed=True, min_overall_bit_width=2, max_overall_bit_width=32, quant_type=<QuantType.INT: 'INT'>, lsb_trunc_bit_width_impl_type=<BitWidthImplType.CONST: 'CONST'>, trunc_at_least_init_val=False, explicit_rescaling=False, override_pretrained_bit_width=False)

Bases: brevitas.nn.quant_accumulator.QuantAccumulator
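
Conceptually, TruncQuantAccumulator drops least-significant bits from an integer accumulator, while ClampQuantAccumulator saturates its most-significant bits. A plain-Python sketch of that arithmetic (illustrative only, not the Brevitas implementation):

>>> def trunc_lsbs(acc_int, ls_bit_width_to_trunc):
...     # dropping the k least-significant bits divides by 2**k
...     return acc_int >> ls_bit_width_to_trunc
...
>>> def clamp_msbs(acc_int, acc_bit_width, ms_bit_width_to_clamp, narrow_range=True):
...     # saturate to the range representable after removing the top bits
...     out_bit_width = acc_bit_width - ms_bit_width_to_clamp
...     n = 2 ** (out_bit_width - 1)
...     lo = -n + 1 if narrow_range else -n
...     return max(lo, min(n - 1, acc_int))
...
>>> trunc_lsbs(200, ls_bit_width_to_trunc=4)
12
>>> clamp_msbs(200, acc_bit_width=16, ms_bit_width_to_clamp=8)
127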

brevitas.nn.quant_activation module

class brevitas.nn.quant_activation.QuantActivation(return_quant_tensor)

Bases: brevitas.nn.quant_layer.QuantLayer, torch.nn.modules.module.Module

property act_quant_proxy
forward(input)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this function, since the former takes care of running the registered hooks while the latter silently ignores them.

quant_act_scale()
class brevitas.nn.quant_activation.QuantHardTanh(bit_width, min_val=-1.0, max_val=1.0, narrow_range=False, quant_type=<QuantType.FP: 'FP'>, float_to_int_impl_type=<FloatToIntImplType.ROUND: 'ROUND'>, scaling_impl_type=<ScalingImplType.PARAMETER: 'PARAMETER'>, scaling_override=None, scaling_per_channel=False, scaling_stats_sigma=3.0, scaling_stats_op=<StatsOp.MEAN_LEARN_SIGMA_STD: 'MEAN_LEARN_SIGMA_STD'>, scaling_stats_buffer_momentum=0.1, scaling_stats_permute_dims=(1, 0, 2, 3), per_channel_broadcastable_shape=None, min_overall_bit_width=2, max_overall_bit_width=None, bit_width_impl_override=None, bit_width_impl_type=<BitWidthImplType.CONST: 'CONST'>, restrict_bit_width_type=<RestrictValueType.INT: 'INT'>, restrict_scaling_type=<RestrictValueType.LOG_FP: 'LOG_FP'>, scaling_min_val=1.52587890625e-05, override_pretrained_bit_width=False, return_quant_tensor=False)

Bases: brevitas.nn.quant_activation.QuantActivation
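
With the default scaling_impl_type=PARAMETER, min_val and max_val initialize a learned clipping range. A usage sketch (illustrative; the QuantType import path is an assumption):

>>> import torch
>>> from brevitas.nn import QuantHardTanh
>>> from brevitas.core.quant import QuantType  # import path assumed
>>> act = QuantHardTanh(bit_width=2, min_val=-1.0, max_val=1.0,
...                     quant_type=QuantType.INT, narrow_range=True)
>>> y = act(torch.randn(8, 16))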

class brevitas.nn.quant_activation.QuantReLU(bit_width, max_val, quant_type=<QuantType.FP: 'FP'>, float_to_int_impl_type=<FloatToIntImplType.ROUND: 'ROUND'>, scaling_impl_type=<ScalingImplType.PARAMETER: 'PARAMETER'>, scaling_override=None, scaling_per_channel=False, scaling_min_val=1.52587890625e-05, scaling_stats_sigma=2.0, scaling_stats_op=<StatsOp.MEAN_LEARN_SIGMA_STD: 'MEAN_LEARN_SIGMA_STD'>, scaling_stats_buffer_momentum=0.1, scaling_stats_permute_dims=(1, 0, 2, 3), per_channel_broadcastable_shape=None, min_overall_bit_width=2, max_overall_bit_width=None, bit_width_impl_override=None, bit_width_impl_type=<BitWidthImplType.CONST: 'CONST'>, restrict_bit_width_type=<RestrictValueType.INT: 'INT'>, restrict_scaling_type=<RestrictValueType.LOG_FP: 'LOG_FP'>, override_pretrained_bit_width=False, return_quant_tensor=False)

Bases: brevitas.nn.quant_activation.QuantActivation
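
A usage sketch of QuantReLU as an unsigned quantized activation (illustrative; same import-path assumption as above):

>>> import torch
>>> from brevitas.nn import QuantReLU
>>> from brevitas.core.quant import QuantType  # import path assumed
>>> act = QuantReLU(bit_width=4, max_val=6.0, quant_type=QuantType.INT)
>>> y = act(torch.randn(8, 16))  # plain tensor, since return_quant_tensor=False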

class brevitas.nn.quant_activation.QuantSigmoid(bit_width, narrow_range=False, quant_type=<QuantType.FP: 'FP'>, float_to_int_impl_type=<FloatToIntImplType.ROUND: 'ROUND'>, min_overall_bit_width=2, max_overall_bit_width=None, bit_width_impl_override=None, bit_width_impl_type=<BitWidthImplType.CONST: 'CONST'>, restrict_bit_width_type=<RestrictValueType.INT: 'INT'>, restrict_scaling_type=<RestrictValueType.LOG_FP: 'LOG_FP'>, scaling_min_val=1.52587890625e-05, override_pretrained_bit_width=False, return_quant_tensor=False)

Bases: brevitas.nn.quant_activation.QuantActivation

class brevitas.nn.quant_activation.QuantTanh(bit_width, narrow_range=False, quant_type=<QuantType.FP: 'FP'>, float_to_int_impl_type=<FloatToIntImplType.ROUND: 'ROUND'>, min_overall_bit_width=2, max_overall_bit_width=None, bit_width_impl_override=None, bit_width_impl_type=<BitWidthImplType.CONST: 'CONST'>, restrict_bit_width_type=<RestrictValueType.INT: 'INT'>, restrict_scaling_type=<RestrictValueType.LOG_FP: 'LOG_FP'>, scaling_min_val=1.52587890625e-05, override_pretrained_bit_width=False, return_quant_tensor=False)

Bases: brevitas.nn.quant_activation.QuantActivation

brevitas.nn.quant_avg_pool module

class brevitas.nn.quant_avg_pool.QuantAvgPool2d(kernel_size, stride=None, signed=True, min_overall_bit_width=2, max_overall_bit_width=32, quant_type=<QuantType.FP: 'FP'>, lsb_trunc_bit_width_impl_type=<BitWidthImplType.CONST: 'CONST'>)

Bases: brevitas.nn.quant_layer.QuantLayer, torch.nn.modules.pooling.AvgPool2d

forward(input)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this function, since the former takes care of running the registered hooks while the latter silently ignores them.

max_output_bit_width(input_bit_width)
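
With the default quant_type=FP, QuantAvgPool2d behaves like a standard AvgPool2d; setting quant_type=INT additionally truncates the pooling accumulator. A usage sketch (illustrative):

>>> import torch
>>> from brevitas.nn import QuantAvgPool2d
>>> pool = QuantAvgPool2d(kernel_size=2, stride=2)
>>> y = pool(torch.randn(1, 16, 8, 8))  # shape (1, 16, 4, 4)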

brevitas.nn.quant_bn module

class brevitas.nn.quant_bn.BatchNorm2dToQuantScaleBias(num_features, eps=1e-05, bias_quant_type=<QuantType.FP: 'FP'>, bias_narrow_range=False, bias_bit_width=None, weight_quant_type=<QuantType.FP: 'FP'>, weight_quant_override=None, weight_narrow_range=False, weight_scaling_override=None, weight_bit_width=32, weight_scaling_impl_type=<ScalingImplType.STATS: 'STATS'>, weight_scaling_const=None, weight_scaling_stats_op=<StatsOp.MAX: 'MAX'>, weight_scaling_per_output_channel=False, weight_restrict_scaling_type=<RestrictValueType.LOG_FP: 'LOG_FP'>, weight_scaling_stats_sigma=3.0, weight_scaling_min_val=1.52587890625e-05, compute_output_scale=False, compute_output_bit_width=False, return_quant_tensor=False)

Bases: brevitas.nn.quant_scale_bias.QuantScaleBias

brevitas.nn.quant_bn.mul_add_from_bn(bn_mean, bn_var, bn_eps, bn_weight, bn_bias, affine_only)
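
The name suggests the standard batch-norm folding identity y = mul * x + add. A sketch of that identity (illustrative; the handling of the affine_only flag is not shown):

>>> import torch
>>> def mul_add_from_bn_sketch(bn_mean, bn_var, bn_eps, bn_weight, bn_bias):
...     # gamma * (x - mean) / sqrt(var + eps) + beta  ==  mul * x + add
...     mul = bn_weight / torch.sqrt(bn_var + bn_eps)
...     add = bn_bias - bn_mean * mul
...     return mul, add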

brevitas.nn.quant_conv module

class brevitas.nn.quant_conv.QuantConv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, padding_type=<PaddingType.STANDARD: 'STANDARD'>, dilation=1, groups=1, bias=True, bias_quant_type=<QuantType.FP: 'FP'>, bias_narrow_range=False, bias_bit_width=None, weight_quant_override=None, weight_quant_type=<QuantType.FP: 'FP'>, weight_narrow_range=False, weight_scaling_override=None, weight_bit_width_impl_override=None, weight_bit_width_impl_type=<BitWidthImplType.CONST: 'CONST'>, weight_restrict_bit_width_type=<RestrictValueType.INT: 'INT'>, weight_bit_width=32, weight_min_overall_bit_width=2, weight_max_overall_bit_width=None, weight_scaling_impl_type=<ScalingImplType.STATS: 'STATS'>, weight_scaling_const=None, weight_scaling_stats_op=<StatsOp.MAX: 'MAX'>, weight_scaling_per_output_channel=False, weight_ternary_threshold=0.5, weight_restrict_scaling_type=<RestrictValueType.LOG_FP: 'LOG_FP'>, weight_scaling_stats_sigma=3.0, weight_scaling_min_val=1.52587890625e-05, weight_override_pretrained_bit_width=False, compute_output_scale=False, compute_output_bit_width=False, return_quant_tensor=False)

Bases: brevitas.nn.quant_layer.QuantLayer, torch.nn.modules.conv.Conv2d

Parameters
  • weight_bit_width (int) – The bit-width to which weights are quantized. If weight_bit_width_impl_type is set to PARAMETER, this value is used for initialization. If weight_quant_type is set to FP, this value is ignored.

  • weight_quant_type (QuantType) – Type of quantization. If set to FP, no quantization is performed.

  • weight_narrow_range (bool) – Restrict the range of quantized values to a symmetrical interval around 0. For example, given weight_bit_width set to 8 and weight_quant_type set to INT, if weight_narrow_range is set to True, the range of quantized values is [-127, 127]; if set to False, it's [-128, 127].

  • weight_restrict_scaling_type (RestrictValueType) – Type of restriction imposed on the values of the scaling factor of the quantized weights.

  • weight_scaling_const (Optional[float]) – If weight_scaling_impl_type is set to CONST, this value is used as the scaling factor across all relevant dimensions. Ignored otherwise.

  • weight_scaling_stats_op (StatsOp) – Type of statistical operation performed for scaling, if required. If weight_scaling_impl_type is set to STATS or AFFINE_STATS, the operation is part of the compute graph and back-propagated through. If weight_scaling_impl_type is set to PARAMETER_FROM_STATS, the operation is used only for computing the initialization of the parameter, possibly across some dimensions. Ignored otherwise.

  • weight_scaling_impl_type (ScalingImplType) – Type of strategy adopted for scaling the quantized weights.

  • weight_scaling_min_val (float) – Minimum value that the scaling factors can reach. This has precedence over anything else, including weight_scaling_const when weight_scaling_impl_type is set to CONST. Useful in case of numerical instabilities. If set to None, no minimum is imposed.

  • weight_bit_width_impl_type (BitWidthImplType) – Type of strategy adopted for the precision at which the weights are quantized when weight_quant_type is set to INT. Ignored otherwise.

  • weight_restrict_bit_width_type (RestrictValueType) – If weight_bit_width_impl_type is set to PARAMETER and weight_quant_type is set to INT, this value constrains or relaxes the bit-width values that can be learned. Ignored otherwise.

  • weight_min_overall_bit_width (Optional[int]) – If weight_bit_width_impl_type is set to PARAMETER and weight_quant_type is set to INT, this value imposes a lower bound on the learned value. Ignored otherwise.

  • weight_max_overall_bit_width (Optional[int]) – If weight_bit_width_impl_type is set to PARAMETER and weight_quant_type is set to INT, this value imposes an upper bound on the learned value. Ignored otherwise.

  • weight_bit_width_impl_override (Union[BitWidthConst, BitWidthParameter, None]) – Override the bit-width implementation with an implementation defined elsewhere. Accepts BitWidthConst or BitWidthParameter type of Modules. Useful for sharing the same learned bit-width between different layers.

  • weight_ternary_threshold (float) – Value to be used as a threshold when weight_quant_type is set to TERNARY. Ignored otherwise.

  • weight_scaling_stats_sigma (float) – Value to be used as sigma when weight_scaling_impl_type is set to STATS, AFFINE_STATS or PARAMETER_FROM_STATS and weight_scaling_stats_op is set to MEAN_SIGMA_STD or MEAN_LEARN_SIGMA_STD. Ignored otherwise. When weight_scaling_impl_type is set to STATS or AFFINE_STATS and weight_scaling_stats_op is set to MEAN_LEARN_SIGMA_STD, the value is used for initialization.

  • weight_override_pretrained_bit_width (bool) – If set to True, when loading a pre-trained model that includes a learned bit-width, the pre-trained value is ignored and replaced by the value specified by weight_bit_width.

conv2d(x, weight, bias)
conv2d_same_padding(x, weight, bias)
forward(input)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this function, since the former takes care of running the registered hooks while the latter silently ignores them.

property int_weight
max_output_bit_width(input_bit_width, weight_bit_width)
merge_bn_in(bn, affine_only)
property per_output_channel_broadcastable_shape
property quant_weight_scale

Returns the scale factor of the quantized weights, with scalar () shape or (self.out_channels, 1, 1, 1) shape depending on whether scaling is per-layer or per-channel.
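
A construction sketch tying the parameters above together (illustrative; the enum import paths are assumptions):

>>> import torch
>>> from brevitas.nn import QuantConv2d
>>> from brevitas.core.quant import QuantType  # import path assumed
>>> conv = QuantConv2d(3, 16, kernel_size=3, padding=1,
...                    weight_quant_type=QuantType.INT,
...                    weight_bit_width=4,
...                    weight_narrow_range=True,  # [-7, 7] at 4 bits
...                    weight_scaling_per_output_channel=True)
>>> y = conv(torch.randn(1, 3, 32, 32))
>>> conv.quant_weight_scale.shape  # per-channel scaling, as documented above
torch.Size([16, 1, 1, 1])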

brevitas.nn.quant_layer module

class brevitas.nn.quant_layer.QuantLayer(compute_output_scale, compute_output_bit_width, return_quant_tensor)

Bases: object

pack_output(output, output_scale, output_bit_width)
unpack_input(input)
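
The mixin contract suggested by these two methods: unpack_input normalizes the incoming value (a plain tensor or a QuantTensor) into a tensor plus optional scale and bit-width, and pack_output re-wraps the result according to return_quant_tensor. A sketch of a custom layer built on that contract (illustrative; the exact return values of unpack_input are assumed from pack_output's signature):

>>> import torch
>>> from brevitas.nn.quant_layer import QuantLayer
>>> class DoubleQuantLayer(QuantLayer, torch.nn.Module):
...     def __init__(self):
...         QuantLayer.__init__(self, compute_output_scale=False,
...                             compute_output_bit_width=False,
...                             return_quant_tensor=False)
...         torch.nn.Module.__init__(self)
...     def forward(self, input):
...         x, scale, bit_width = self.unpack_input(input)
...         return self.pack_output(2.0 * x, scale, bit_width)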

brevitas.nn.quant_linear module

class brevitas.nn.quant_linear.QuantLinear(in_features, out_features, bias, bias_quant_type=<QuantType.FP: 'FP'>, bias_narrow_range=False, bias_bit_width=None, weight_quant_override=None, weight_quant_type=<QuantType.FP: 'FP'>, weight_narrow_range=False, weight_bit_width_impl_override=None, weight_bit_width_impl_type=<BitWidthImplType.CONST: 'CONST'>, weight_restrict_bit_width_type=<RestrictValueType.INT: 'INT'>, weight_bit_width=32, weight_min_overall_bit_width=2, weight_max_overall_bit_width=None, weight_scaling_override=None, weight_scaling_impl_type=<ScalingImplType.STATS: 'STATS'>, weight_scaling_const=None, weight_scaling_stats_op=<StatsOp.MAX: 'MAX'>, weight_scaling_per_output_channel=False, weight_scaling_min_val=1.52587890625e-05, weight_ternary_threshold=0.5, weight_restrict_scaling_type=<RestrictValueType.LOG_FP: 'LOG_FP'>, weight_scaling_stats_sigma=3.0, weight_override_pretrained_bit_width=False, compute_output_scale=False, compute_output_bit_width=False, return_quant_tensor=False)

Bases: brevitas.nn.quant_layer.QuantLayer, torch.nn.modules.linear.Linear

Parameters
  • weight_bit_width (int) – The bit-width to which weights are quantized. If weight_bit_width_impl_type is set to PARAMETER, this value is used for initialization. If weight_quant_type is set to FP, this value is ignored.

  • weight_quant_type (QuantType) – Type of quantization. If set to FP, no quantization is performed.

  • weight_narrow_range (bool) – Restrict the range of quantized values to a symmetrical interval around 0. For example, given weight_bit_width set to 8 and weight_quant_type set to INT, if weight_narrow_range is set to True, the range of quantized values is [-127, 127]; if set to False, it's [-128, 127].

  • weight_restrict_scaling_type (RestrictValueType) – Type of restriction imposed on the values of the scaling factor of the quantized weights.

  • weight_scaling_const (Optional[float]) – If weight_scaling_impl_type is set to CONST, this value is used as the scaling factor across all relevant dimensions. Ignored otherwise.

  • weight_scaling_stats_op (StatsOp) – Type of statistical operation performed for scaling, if required. If weight_scaling_impl_type is set to STATS or AFFINE_STATS, the operation is part of the compute graph and back-propagated through. If weight_scaling_impl_type is set to PARAMETER_FROM_STATS, the operation is used only for computing the initialization of the parameter, possibly across some dimensions. Ignored otherwise.

  • weight_scaling_impl_type (ScalingImplType) – Type of strategy adopted for scaling the quantized weights.

  • weight_scaling_min_val (float) – Minimum value that the scaling factors can reach. This has precedence over anything else, including weight_scaling_const when weight_scaling_impl_type is set to CONST. Useful in case of numerical instabilities. If set to None, no minimum is imposed.

  • weight_bit_width_impl_type (BitWidthImplType) – Type of strategy adopted for the precision at which the weights are quantized when weight_quant_type is set to INT. Ignored otherwise.

  • weight_restrict_bit_width_type (RestrictValueType) – If weight_bit_width_impl_type is set to PARAMETER and weight_quant_type is set to INT, this value constrains or relaxes the bit-width values that can be learned. Ignored otherwise.

  • weight_min_overall_bit_width (Optional[int]) – If weight_bit_width_impl_type is set to PARAMETER and weight_quant_type is set to INT, this value imposes a lower bound on the learned value. Ignored otherwise.

  • weight_max_overall_bit_width (Optional[int]) – If weight_bit_width_impl_type is set to PARAMETER and weight_quant_type is set to INT, this value imposes an upper bound on the learned value. Ignored otherwise.

  • weight_bit_width_impl_override (Union[BitWidthConst, BitWidthParameter, None]) – Override the bit-width implementation with an implementation defined elsewhere. Accepts BitWidthConst or BitWidthParameter type of Modules. Useful for sharing the same learned bit-width between different layers.

  • weight_ternary_threshold (float) – Value to be used as a threshold when weight_quant_type is set to TERNARY. Ignored otherwise.

  • weight_scaling_stats_sigma (float) – Value to be used as sigma when weight_scaling_impl_type is set to STATS, AFFINE_STATS or PARAMETER_FROM_STATS and weight_scaling_stats_op is set to MEAN_SIGMA_STD or MEAN_LEARN_SIGMA_STD. Ignored otherwise. When weight_scaling_impl_type is set to STATS or AFFINE_STATS and weight_scaling_stats_op is set to MEAN_LEARN_SIGMA_STD, the value is used for initialization.

  • weight_override_pretrained_bit_width (bool) – If set to True, when loading a pre-trained model that includes a learned bit-width, the pre-trained value is ignored and replaced by the value specified by weight_bit_width.

forward(input)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this function, since the former takes care of running the registered hooks while the latter silently ignores them.

property int_weight
max_output_bit_width(input_bit_width, weight_bit_width)
property quant_weight_scale

Returns the scale factor of the quantized weights, with scalar () shape or (self.out_channels, 1) shape depending on whether scaling is per-layer or per-channel.
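
A construction sketch analogous to the QuantConv2d example above (illustrative; the QuantType import path is an assumption):

>>> import torch
>>> from brevitas.nn import QuantLinear
>>> from brevitas.core.quant import QuantType  # import path assumed
>>> fc = QuantLinear(128, 10, bias=True,
...                  weight_quant_type=QuantType.INT,
...                  weight_bit_width=8)
>>> y = fc(torch.randn(4, 128))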

brevitas.nn.quant_scale_bias module

class brevitas.nn.quant_scale_bias.ScaleBias(num_features)

Bases: torch.nn.modules.module.Module

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this function, since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.nn.quant_scale_bias.QuantScaleBias(num_features, bias_quant_type=<QuantType.FP: 'FP'>, bias_narrow_range=False, bias_bit_width=None, weight_quant_type=<QuantType.FP: 'FP'>, weight_quant_override=None, weight_narrow_range=False, weight_scaling_override=None, weight_bit_width=32, weight_scaling_impl_type=<ScalingImplType.STATS: 'STATS'>, weight_scaling_const=None, weight_scaling_stats_op=<StatsOp.MAX: 'MAX'>, weight_scaling_per_output_channel=False, weight_restrict_scaling_type=<RestrictValueType.LOG_FP: 'LOG_FP'>, weight_scaling_stats_sigma=3.0, weight_scaling_min_val=1.52587890625e-05, compute_output_scale=False, compute_output_bit_width=False, return_quant_tensor=False)

Bases: brevitas.nn.quant_layer.QuantLayer, brevitas.nn.quant_scale_bias.ScaleBias

forward(quant_tensor)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this function, since the former takes care of running the registered hooks while the latter silently ignores them.
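
Conceptually, ScaleBias applies a learnable per-channel affine transform, QuantScaleBias quantizes its weight and bias, and BatchNorm2dToQuantScaleBias (above) derives the per-channel multiplier and offset from batch-norm statistics. A sketch of the underlying transform (illustrative only):

>>> import torch
>>> x = torch.randn(2, 8, 4, 4)                   # (N, C, H, W)
>>> weight, bias = torch.ones(8), torch.zeros(8)  # one (scale, bias) pair per channel
>>> y = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)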

Module contents