brevitas.nn package
Submodules
brevitas.nn.hadamard_classifier module
- class brevitas.nn.hadamard_classifier.HadamardClassifier(in_channels, out_channels, fixed_scale=False, compute_output_scale=False, compute_output_bit_width=False, return_quant_tensor=False)
  Bases: brevitas.nn.quant_layer.QuantLayer, torch.nn.modules.module.Module
- forward(x)
  Defines the computation performed at every call.
  Should be overridden by all subclasses.
  Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- max_output_bit_width(input_bit_width)
- state_dict(destination=None, prefix='', keep_vars=False)
  Returns a dictionary containing the whole state of the module.
  Both parameters and persistent buffers (e.g. running averages) are included. Keys are the corresponding parameter and buffer names.
  Returns: a dictionary containing the whole state of the module
  Return type: dict
  Example:
  >>> module.state_dict().keys()
  ['bias', 'weight']
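As a usage sketch (not part of the original docstring): HadamardClassifier stands in for a learned final fully-connected layer, projecting features onto class logits through a fixed Hadamard matrix. The example below assumes the class is importable from its module path as documented; argument values are illustrative.

import torch
from brevitas.nn.hadamard_classifier import HadamardClassifier

# Fixed (non-learned) Hadamard projection from 256 features to 10 classes.
clf = HadamardClassifier(in_channels=256, out_channels=10, fixed_scale=False)
logits = clf(torch.randn(8, 256))  # shape: (8, 10)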
brevitas.nn.quant_accumulator module
- class brevitas.nn.quant_accumulator.ClampQuantAccumulator(ms_bit_width_to_clamp=0, signed=True, narrow_range=True, min_overall_bit_width=2, max_overall_bit_width=32, quant_type=QuantType.INT, msb_clamp_bit_width_impl_type=BitWidthImplType.CONST, per_elem_ops=None, clamp_at_least_init_val=False, override_pretrained_bit_width=False)
- class brevitas.nn.quant_accumulator.QuantAccumulator
  Bases: brevitas.nn.quant_layer.QuantLayer, torch.nn.modules.module.Module
- property acc_quant_proxy
- forward(input)
  Defines the computation performed at every call.
  Should be overridden by all subclasses.
  Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class brevitas.nn.quant_accumulator.TruncQuantAccumulator(ls_bit_width_to_trunc=0, signed=True, min_overall_bit_width=2, max_overall_bit_width=32, quant_type=QuantType.INT, lsb_trunc_bit_width_impl_type=BitWidthImplType.CONST, trunc_at_least_init_val=False, explicit_rescaling=False, override_pretrained_bit_width=False)
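The accumulator layers are typically composed inside other quantized layers (e.g. average pooling) rather than used standalone. As an illustrative sketch only, assuming QuantType is importable from brevitas.core.quant as in the release this reference documents:

from brevitas.nn.quant_accumulator import TruncQuantAccumulator
from brevitas.core.quant import QuantType

# Drop the two least-significant bits of an integer accumulator,
# keeping the overall bit-width within [2, 32].
trunc_acc = TruncQuantAccumulator(
    ls_bit_width_to_trunc=2,
    signed=True,
    quant_type=QuantType.INT)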
brevitas.nn.quant_activation module
- class brevitas.nn.quant_activation.QuantActivation(return_quant_tensor)
  Bases: brevitas.nn.quant_layer.QuantLayer, torch.nn.modules.module.Module
- property act_quant_proxy
- forward(input)
  Defines the computation performed at every call.
  Should be overridden by all subclasses.
  Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- quant_act_scale()
- class brevitas.nn.quant_activation.QuantHardTanh(bit_width, min_val=-1.0, max_val=1.0, narrow_range=False, quant_type=QuantType.FP, float_to_int_impl_type=FloatToIntImplType.ROUND, scaling_impl_type=ScalingImplType.PARAMETER, scaling_override=None, scaling_per_channel=False, scaling_stats_sigma=3.0, scaling_stats_op=StatsOp.MEAN_LEARN_SIGMA_STD, scaling_stats_buffer_momentum=0.1, scaling_stats_permute_dims=(1, 0, 2, 3), per_channel_broadcastable_shape=None, min_overall_bit_width=2, max_overall_bit_width=None, bit_width_impl_override=None, bit_width_impl_type=BitWidthImplType.CONST, restrict_bit_width_type=RestrictValueType.INT, restrict_scaling_type=RestrictValueType.LOG_FP, scaling_min_val=1.52587890625e-05, override_pretrained_bit_width=False, return_quant_tensor=False)
- class brevitas.nn.quant_activation.QuantReLU(bit_width, max_val, quant_type=QuantType.FP, float_to_int_impl_type=FloatToIntImplType.ROUND, scaling_impl_type=ScalingImplType.PARAMETER, scaling_override=None, scaling_per_channel=False, scaling_min_val=1.52587890625e-05, scaling_stats_sigma=2.0, scaling_stats_op=StatsOp.MEAN_LEARN_SIGMA_STD, scaling_stats_buffer_momentum=0.1, scaling_stats_permute_dims=(1, 0, 2, 3), per_channel_broadcastable_shape=None, min_overall_bit_width=2, max_overall_bit_width=None, bit_width_impl_override=None, bit_width_impl_type=BitWidthImplType.CONST, restrict_bit_width_type=RestrictValueType.INT, restrict_scaling_type=RestrictValueType.LOG_FP, override_pretrained_bit_width=False, return_quant_tensor=False)
- class brevitas.nn.quant_activation.QuantSigmoid(bit_width, narrow_range=False, quant_type=QuantType.FP, float_to_int_impl_type=FloatToIntImplType.ROUND, min_overall_bit_width=2, max_overall_bit_width=None, bit_width_impl_override=None, bit_width_impl_type=BitWidthImplType.CONST, restrict_bit_width_type=RestrictValueType.INT, restrict_scaling_type=RestrictValueType.LOG_FP, scaling_min_val=1.52587890625e-05, override_pretrained_bit_width=False, return_quant_tensor=False)
- class brevitas.nn.quant_activation.QuantTanh(bit_width, narrow_range=False, quant_type=QuantType.FP, float_to_int_impl_type=FloatToIntImplType.ROUND, min_overall_bit_width=2, max_overall_bit_width=None, bit_width_impl_override=None, bit_width_impl_type=BitWidthImplType.CONST, restrict_bit_width_type=RestrictValueType.INT, restrict_scaling_type=RestrictValueType.LOG_FP, scaling_min_val=1.52587890625e-05, override_pretrained_bit_width=False, return_quant_tensor=False)
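A minimal sketch of a 4-bit ReLU and a signed 8-bit hard-tanh built from the signatures above. The QuantType import path is an assumption based on the Brevitas release this reference documents; note that the default quant_type=FP performs no quantization.

import torch
from brevitas.nn import QuantReLU, QuantHardTanh
from brevitas.core.quant import QuantType

# Unsigned 4-bit activation; max_val sets the initial upper bound of the range.
relu4 = QuantReLU(bit_width=4, max_val=6.0, quant_type=QuantType.INT)

# Signed 8-bit activation over [-1, 1].
htanh8 = QuantHardTanh(bit_width=8, min_val=-1.0, max_val=1.0,
                       quant_type=QuantType.INT)

y = relu4(torch.randn(2, 16))
z = htanh8(y)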
brevitas.nn.quant_avg_pool module
- class brevitas.nn.quant_avg_pool.QuantAvgPool2d(kernel_size, stride=None, signed=True, min_overall_bit_width=2, max_overall_bit_width=32, quant_type=QuantType.FP, lsb_trunc_bit_width_impl_type=BitWidthImplType.CONST)
  Bases: brevitas.nn.quant_layer.QuantLayer, torch.nn.modules.pooling.AvgPool2d
- forward(input)
  Defines the computation performed at every call.
  Should be overridden by all subclasses.
  Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- max_output_bit_width(input_bit_width)
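A minimal sketch of the pooling layer. With the default quant_type=FP it behaves like a plain torch.nn.AvgPool2d; with QuantType.INT it additionally truncates the integer accumulator, which (an assumption based on the surrounding API) requires quantization metadata from a preceding layer constructed with return_quant_tensor=True.

import torch
from brevitas.nn import QuantAvgPool2d

# Default quant_type=FP: behaves like torch.nn.AvgPool2d.
pool = QuantAvgPool2d(kernel_size=2, stride=2)
y = pool(torch.randn(1, 8, 16, 16))  # shape: (1, 8, 8, 8)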
brevitas.nn.quant_bn module
- class brevitas.nn.quant_bn.BatchNorm2dToQuantScaleBias(num_features, eps=1e-05, bias_quant_type=QuantType.FP, bias_narrow_range=False, bias_bit_width=None, weight_quant_type=QuantType.FP, weight_quant_override=None, weight_narrow_range=False, weight_scaling_override=None, weight_bit_width=32, weight_scaling_impl_type=ScalingImplType.STATS, weight_scaling_const=None, weight_scaling_stats_op=StatsOp.MAX, weight_scaling_per_output_channel=False, weight_restrict_scaling_type=RestrictValueType.LOG_FP, weight_scaling_stats_sigma=3.0, weight_scaling_min_val=1.52587890625e-05, compute_output_scale=False, compute_output_bit_width=False, return_quant_tensor=False)
- brevitas.nn.quant_bn.mul_add_from_bn(bn_mean, bn_var, bn_eps, bn_weight, bn_bias, affine_only)
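The docstring does not spell out the computation, but a helper with this signature conventionally folds batch-norm statistics into a multiplier and an additive term. The sketch below shows that standard arithmetic (the affine_only flag is omitted); it is an assumption about the intent, not the verified implementation.

import torch

def mul_add_from_bn_sketch(bn_mean, bn_var, bn_eps, bn_weight, bn_bias):
    # BN(x) = bn_weight * (x - bn_mean) / sqrt(bn_var + bn_eps) + bn_bias
    # is equivalent to mul * x + add with:
    mul = bn_weight / torch.sqrt(bn_var + bn_eps)
    add = bn_bias - bn_mean * mul
    return mul, add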
brevitas.nn.quant_conv module
- class brevitas.nn.quant_conv.QuantConv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, padding_type=PaddingType.STANDARD, dilation=1, groups=1, bias=True, bias_quant_type=QuantType.FP, bias_narrow_range=False, bias_bit_width=None, weight_quant_override=None, weight_quant_type=QuantType.FP, weight_narrow_range=False, weight_scaling_override=None, weight_bit_width_impl_override=None, weight_bit_width_impl_type=BitWidthImplType.CONST, weight_restrict_bit_width_type=RestrictValueType.INT, weight_bit_width=32, weight_min_overall_bit_width=2, weight_max_overall_bit_width=None, weight_scaling_impl_type=ScalingImplType.STATS, weight_scaling_const=None, weight_scaling_stats_op=StatsOp.MAX, weight_scaling_per_output_channel=False, weight_ternary_threshold=0.5, weight_restrict_scaling_type=RestrictValueType.LOG_FP, weight_scaling_stats_sigma=3.0, weight_scaling_min_val=1.52587890625e-05, weight_override_pretrained_bit_width=False, compute_output_scale=False, compute_output_bit_width=False, return_quant_tensor=False)
  Bases: brevitas.nn.quant_layer.QuantLayer, torch.nn.modules.conv.Conv2d
Parameters:
- weight_bit_width (int) – The bit-width to which weights are quantized. If weight_bit_width_impl_type is set to PARAMETER, this value is used for initialization. If weight_quant_type is set to FP, this value is ignored.
- weight_quant_type (QuantType) – Type of quantization. If set to FP, no quantization is performed.
- weight_narrow_range (bool) – Restrict the range of quantized values to a symmetrical interval around 0. For example, given weight_bit_width set to 8 and quant_type set to INT, if weight_narrow_range is set to True, the range of quantized values is [-127, 127]; if set to False, it is [-128, 127].
- weight_restrict_scaling_type (RestrictValueType) – Type of restriction imposed on the values of the scaling factor of the quantized weights.
- weight_scaling_const (Optional[float]) – If weight_scaling_impl_type is set to CONST, this value is used as the scaling factor across all relevant dimensions. Ignored otherwise.
- weight_scaling_stats_op (StatsOp) – Type of statistical operation performed for scaling, if required. If weight_scaling_impl_type is set to STATS or AFFINE_STATS, the operation is part of the compute graph and back-propagated through. If weight_scaling_impl_type is set to PARAMETER_FROM_STATS, the operation is used only for computing the initialization of the parameter, possibly across some dimensions. Ignored otherwise.
- weight_scaling_impl_type (ScalingImplType) – Type of strategy adopted for scaling the quantized weights.
- weight_scaling_min_val (float) – Minimum value that the scaling factors can reach. This has precedence over anything else, including weight_scaling_const when weight_scaling_impl_type is set to CONST. Useful in case of numerical instabilities. If set to None, no minimum is imposed.
- weight_bit_width_impl_type (BitWidthImplType) – Type of strategy adopted for the precision at which the weights are quantized when weight_quant_type is set to INT. Ignored otherwise.
- weight_restrict_bit_width_type (RestrictValueType) – If weight_bit_width_impl_type is set to PARAMETER and weight_quant_type is set to INT, this value constrains or relaxes the bit-width value that can be learned. Ignored otherwise.
- weight_min_overall_bit_width (Optional[int]) – If weight_bit_width_impl_type is set to PARAMETER and weight_quant_type is set to INT, this value imposes a lower bound on the learned value. Ignored otherwise.
- weight_max_overall_bit_width (Optional[int]) – If weight_bit_width_impl_type is set to PARAMETER and weight_quant_type is set to INT, this value imposes an upper bound on the learned value. Ignored otherwise.
- weight_bit_width_impl_override (Union[BitWidthConst, BitWidthParameter, None]) – Override the bit-width implementation with an implementation defined elsewhere. Accepts Modules of type BitWidthConst or BitWidthParameter. Useful for sharing the same learned bit-width between different layers.
- weight_ternary_threshold (float) – Value to be used as a threshold when weight_quant_type is set to TERNARY. Ignored otherwise.
- weight_scaling_stats_sigma (float) – Value to be used as sigma if weight_scaling_impl_type is set to STATS, AFFINE_STATS or PARAMETER_FROM_STATS and weight_scaling_stats_op is set to AVE_SIGMA_STD or AVE_LEARN_SIGMA_STD. Ignored otherwise. When weight_scaling_impl_type is set to STATS or AFFINE_STATS, and weight_scaling_stats_op is set to AVE_LEARN_SIGMA_STD, the value is used for initialization.
- weight_override_pretrained_bit_width (bool) – If set to True, when loading a pre-trained model that includes a learned bit-width, the pre-trained value is ignored and replaced by the value specified by weight_bit_width.
- conv2d(x, weight, bias)
- conv2d_same_padding(x, weight, bias)
- forward(input)
  Defines the computation performed at every call.
  Should be overridden by all subclasses.
  Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- property int_weight
- max_output_bit_width(input_bit_width, weight_bit_width)
- merge_bn_in(bn, affine_only)
- property per_output_channel_broadcastable_shape
- property quant_weight_scale
  Returns the scale factor of the quantized weights, with scalar () shape or (self.out_channels, 1, 1, 1) shape depending on whether scaling is per-layer or per-channel.
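Putting the parameters above together, here is a minimal sketch of an 8-bit, per-channel weight-quantized convolution. The enum import path is an assumption for the Brevitas release this reference documents; values are illustrative.

import torch
from brevitas.nn import QuantConv2d
from brevitas.core.quant import QuantType

conv = QuantConv2d(
    in_channels=3, out_channels=16, kernel_size=3, padding=1,
    weight_quant_type=QuantType.INT,          # enable integer quantization
    weight_bit_width=8,                       # 8-bit weights
    weight_narrow_range=True,                 # symmetric range [-127, 127]
    weight_scaling_per_output_channel=True)   # one scale per output channel

out = conv(torch.randn(1, 3, 32, 32))
# With per-channel scaling, quant_weight_scale has shape (16, 1, 1, 1).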
brevitas.nn.quant_layer module
brevitas.nn.quant_linear module
- class brevitas.nn.quant_linear.QuantLinear(in_features, out_features, bias, bias_quant_type=QuantType.FP, bias_narrow_range=False, bias_bit_width=None, weight_quant_override=None, weight_quant_type=QuantType.FP, weight_narrow_range=False, weight_bit_width_impl_override=None, weight_bit_width_impl_type=BitWidthImplType.CONST, weight_restrict_bit_width_type=RestrictValueType.INT, weight_bit_width=32, weight_min_overall_bit_width=2, weight_max_overall_bit_width=None, weight_scaling_override=None, weight_scaling_impl_type=ScalingImplType.STATS, weight_scaling_const=None, weight_scaling_stats_op=StatsOp.MAX, weight_scaling_per_output_channel=False, weight_scaling_min_val=1.52587890625e-05, weight_ternary_threshold=0.5, weight_restrict_scaling_type=RestrictValueType.LOG_FP, weight_scaling_stats_sigma=3.0, weight_override_pretrained_bit_width=False, compute_output_scale=False, compute_output_bit_width=False, return_quant_tensor=False)
  Bases: brevitas.nn.quant_layer.QuantLayer, torch.nn.modules.linear.Linear
Parameters:
- weight_bit_width (int) – The bit-width to which weights are quantized. If weight_bit_width_impl_type is set to PARAMETER, this value is used for initialization. If weight_quant_type is set to FP, this value is ignored.
- weight_quant_type (QuantType) – Type of quantization. If set to FP, no quantization is performed.
- weight_narrow_range (bool) – Restrict the range of quantized values to a symmetrical interval around 0. For example, given weight_bit_width set to 8 and quant_type set to INT, if weight_narrow_range is set to True, the range of quantized values is [-127, 127]; if set to False, it is [-128, 127].
- weight_restrict_scaling_type (RestrictValueType) – Type of restriction imposed on the values of the scaling factor of the quantized weights.
- weight_scaling_const (Optional[float]) – If weight_scaling_impl_type is set to CONST, this value is used as the scaling factor across all relevant dimensions. Ignored otherwise.
- weight_scaling_stats_op (StatsOp) – Type of statistical operation performed for scaling, if required. If weight_scaling_impl_type is set to STATS or AFFINE_STATS, the operation is part of the compute graph and back-propagated through. If weight_scaling_impl_type is set to PARAMETER_FROM_STATS, the operation is used only for computing the initialization of the parameter, possibly across some dimensions. Ignored otherwise.
- weight_scaling_impl_type (ScalingImplType) – Type of strategy adopted for scaling the quantized weights.
- weight_scaling_min_val (float) – Minimum value that the scaling factors can reach. This has precedence over anything else, including weight_scaling_const when weight_scaling_impl_type is set to CONST. Useful in case of numerical instabilities. If set to None, no minimum is imposed.
- weight_bit_width_impl_type (BitWidthImplType) – Type of strategy adopted for the precision at which the weights are quantized when weight_quant_type is set to INT. Ignored otherwise.
- weight_restrict_bit_width_type (RestrictValueType) – If weight_bit_width_impl_type is set to PARAMETER and weight_quant_type is set to INT, this value constrains or relaxes the bit-width value that can be learned. Ignored otherwise.
- weight_min_overall_bit_width (Optional[int]) – If weight_bit_width_impl_type is set to PARAMETER and weight_quant_type is set to INT, this value imposes a lower bound on the learned value. Ignored otherwise.
- weight_max_overall_bit_width (Optional[int]) – If weight_bit_width_impl_type is set to PARAMETER and weight_quant_type is set to INT, this value imposes an upper bound on the learned value. Ignored otherwise.
- weight_bit_width_impl_override (Union[BitWidthConst, BitWidthParameter, None]) – Override the bit-width implementation with an implementation defined elsewhere. Accepts Modules of type BitWidthConst or BitWidthParameter. Useful for sharing the same learned bit-width between different layers.
- weight_ternary_threshold (float) – Value to be used as a threshold when weight_quant_type is set to TERNARY. Ignored otherwise.
- weight_scaling_stats_sigma (float) – Value to be used as sigma if weight_scaling_impl_type is set to STATS, AFFINE_STATS or PARAMETER_FROM_STATS and weight_scaling_stats_op is set to AVE_SIGMA_STD or AVE_LEARN_SIGMA_STD. Ignored otherwise. When weight_scaling_impl_type is set to STATS or AFFINE_STATS, and weight_scaling_stats_op is set to AVE_LEARN_SIGMA_STD, the value is used for initialization.
- weight_override_pretrained_bit_width (bool) – If set to True, when loading a pre-trained model that includes a learned bit-width, the pre-trained value is ignored and replaced by the value specified by weight_bit_width.
- forward(input)
  Defines the computation performed at every call.
  Should be overridden by all subclasses.
  Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- property int_weight
- max_output_bit_width(input_bit_width, weight_bit_width)
- property quant_weight_scale
  Returns the scale factor of the quantized weights, with scalar () shape or (self.out_channels, 1) shape depending on whether scaling is per-layer or per-channel.
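A minimal sketch of a weight-quantized fully-connected layer, analogous to the QuantConv2d example above; the enum import path is an assumption and values are illustrative.

import torch
from brevitas.nn import QuantLinear
from brevitas.core.quant import QuantType

fc = QuantLinear(
    in_features=128, out_features=10, bias=True,
    weight_quant_type=QuantType.INT,
    weight_bit_width=4)  # 4-bit weights

logits = fc(torch.randn(32, 128))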
brevitas.nn.quant_scale_bias module
- class brevitas.nn.quant_scale_bias.ScaleBias(num_features)
  Bases: torch.nn.modules.module.Module
- forward(x)
  Defines the computation performed at every call.
  Should be overridden by all subclasses.
  Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
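A sketch of the unquantized ScaleBias module, which applies a learned per-feature scale and bias; QuantScaleBias below adds quantization on top of the same interface. The 4D input layout is an assumption, based on BatchNorm2dToQuantScaleBias above targeting BatchNorm2d-style feature maps.

import torch
from brevitas.nn.quant_scale_bias import ScaleBias

# Learned per-channel scale and bias, e.g. as a cheap stand-in for BatchNorm2d.
sb = ScaleBias(num_features=64)
y = sb(torch.randn(8, 64, 16, 16))  # y = weight * x + bias, broadcast per channel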
- class brevitas.nn.quant_scale_bias.QuantScaleBias(num_features, bias_quant_type=QuantType.FP, bias_narrow_range=False, bias_bit_width=None, weight_quant_type=QuantType.FP, weight_quant_override=None, weight_narrow_range=False, weight_scaling_override=None, weight_bit_width=32, weight_scaling_impl_type=ScalingImplType.STATS, weight_scaling_const=None, weight_scaling_stats_op=StatsOp.MAX, weight_scaling_per_output_channel=False, weight_restrict_scaling_type=RestrictValueType.LOG_FP, weight_scaling_stats_sigma=3.0, weight_scaling_min_val=1.52587890625e-05, compute_output_scale=False, compute_output_bit_width=False, return_quant_tensor=False)
  Bases: brevitas.nn.quant_layer.QuantLayer, brevitas.nn.quant_scale_bias.ScaleBias
- forward(quant_tensor)
  Defines the computation performed at every call.
  Should be overridden by all subclasses.
  Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.