brevitas.proxy package
Submodules
brevitas.proxy.parameter_quant module
class brevitas.proxy.parameter_quant.WeightQuantProxy(bit_width, quant_type, narrow_range, scaling_override, restrict_scaling_type, scaling_const, scaling_stats_op, scaling_impl_type, scaling_stats_reduce_dim, scaling_shape, scaling_min_val, bit_width_impl_type, restrict_bit_width_type, min_overall_bit_width, max_overall_bit_width, tracked_parameter_list_init, bit_width_impl_override, scaling_stats_input_view_shape_impl, scaling_stats_input_concat_dim, ternary_threshold, scaling_stats_sigma, override_pretrained_bit_width)
Bases: brevitas.proxy.parameter_quant.ParameterQuantProxy
- Parameters
bit_width (Optional[int]) – The bit-width to which weights are quantized. If bit_width_impl_type is set to PARAMETER, this value is used for initialization. If quant_type is set to FP, this value is ignored.
quant_type (QuantType) – Type of quantization. If set to FP, no quantization is performed.
narrow_range (bool) – Restrict the range of quantized values to an interval symmetric around 0. For example, given bit_width set to 8 and quant_type set to INT, if narrow_range is set to True, the range of quantized values is [-127, 127]; if set to False, it is [-128, 127].
restrict_scaling_type (RestrictValueType) – Type of restriction imposed on the values of the scaling factor of the quantized weights.
scaling_const (Optional[float]) – If scaling_impl_type is set to CONST, this value is used as the scaling factor across all relevant dimensions. Ignored otherwise.
scaling_stats_op (StatsOp) – Type of statistical operation performed for scaling, if required. If scaling_impl_type is set to STATS or AFFINE_STATS, the operation is part of the compute graph and is back-propagated through. If scaling_impl_type is set to PARAMETER_FROM_STATS, the operation is used only to compute the initial value of the parameter, possibly across some dimensions. Ignored otherwise.
scaling_impl_type (ScalingImplType) – Type of strategy adopted for scaling the quantized weights.
scaling_stats_reduce_dim (Optional[int]) – Dimension within the shape determined by scaling_stats_input_view_shape_impl along which scaling_stats_op is applied. If set to None, scaling is assumed to be over the whole tensor. Ignored whenever scaling_stats_op is ignored.
scaling_shape (Tuple[int, …]) – Shape of the scaling factor tensor. It must be broadcastable with respect to the weight tensor to scale.
scaling_min_val (Optional[float]) – Minimum value that the scaling factors can reach. This takes precedence over anything else, including scaling_const when scaling_impl_type is set to CONST. Useful in case of numerical instabilities. If set to None, no minimum is imposed.
bit_width_impl_type (Optional[BitWidthImplType]) – Type of strategy adopted for the precision to which the weights are quantized when quant_type is set to INT. Ignored otherwise.
restrict_bit_width_type (Optional[RestrictValueType]) – If bit_width_impl_type is set to PARAMETER and quant_type is set to INT, this value constrains or relaxes the bit-width value that can be learned. Ignored otherwise.
min_overall_bit_width (Optional[int]) – If bit_width_impl_type is set to PARAMETER and quant_type is set to INT, this value imposes a lower bound on the learned value. Ignored otherwise.
max_overall_bit_width (Optional[int]) – If bit_width_impl_type is set to PARAMETER and quant_type is set to INT, this value imposes an upper bound on the learned value. Ignored otherwise.
tracked_parameter_list_init (Parameter) – PyTorch Parameter whose statistics are computed when scaling_impl_type is set to STATS, AFFINE_STATS or PARAMETER_FROM_STATS. This value initializes the list of parameters that are concatenated together when computing statistics.
bit_width_impl_override (Union[BitWidthConst, BitWidthParameter, None]) – Override the bit-width implementation with an implementation defined elsewhere. Accepts Modules of type BitWidthConst or BitWidthParameter. Useful for sharing the same learned bit-width between different layers.
scaling_stats_input_view_shape_impl (Module) – When scaling_impl_type is set to STATS, AFFINE_STATS or PARAMETER_FROM_STATS, this Module reshapes each tracked parameter before they are concatenated together and their statistics are computed.
scaling_stats_input_concat_dim (int) – When scaling_impl_type is set to STATS, AFFINE_STATS or PARAMETER_FROM_STATS, this value defines the dimension along which the tracked parameters are concatenated after scaling_stats_input_view_shape_impl is called, but before statistics are taken.
ternary_threshold (Optional[float]) – Value to be used as a threshold when quant_type is set to TERNARY. Ignored otherwise.
scaling_stats_sigma (Optional[float]) – Value to be used as sigma if scaling_impl_type is set to STATS, AFFINE_STATS or PARAMETER_FROM_STATS and scaling_stats_op is set to AVE_SIGMA_STD or AVE_LEARN_SIGMA_STD. Ignored otherwise. When scaling_impl_type is set to STATS or AFFINE_STATS, and scaling_stats_op is set to AVE_LEARN_SIGMA_STD, the value is used for initialization.
override_pretrained_bit_width (bool) – If set to True, when loading a pre-trained model that includes a learned bit-width, the pre-trained value is ignored and replaced by the value specified by bit_width.
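The effect of narrow_range on the integer grid can be sketched in plain Python. This is only an illustration of the ranges quoted for narrow_range above, assuming a signed two's-complement integer representation; the helper name signed_int_range is hypothetical and is not part of Brevitas.

```python
from typing import Tuple


def signed_int_range(bit_width: int, narrow_range: bool) -> Tuple[int, int]:
    """Hypothetical helper (not a Brevitas function).

    Returns the (min, max) integer values representable with a signed
    two's-complement grid of the given bit-width.
    """
    if bit_width < 2:
        raise ValueError("a signed range needs at least 2 bits")
    max_int = 2 ** (bit_width - 1) - 1
    # narrow_range drops the most negative value so the interval is symmetric
    min_int = -max_int if narrow_range else -max_int - 1
    return min_int, max_int


narrow = signed_int_range(8, narrow_range=True)   # symmetric interval
full = signed_int_range(8, narrow_range=False)    # full two's-complement interval
```

With bit_width set to 8, the two calls reproduce the intervals [-127, 127] and [-128, 127] given in the parameter description.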
add_tracked_parameter(x)
- Return type
None
forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- Return type
Tuple[Tensor, Tensor, Tensor]
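The note about calling the Module instance rather than forward can be checked with a small PyTorch sketch. The Doubler module below is a hypothetical stand-in, not a Brevitas proxy; it only shows that a registered forward hook runs when the instance is called and is skipped when forward is invoked directly.

```python
import torch
from torch import nn


class Doubler(nn.Module):
    """Hypothetical stand-in module used only for this illustration."""

    def forward(self, x):
        return x * 2


calls = []
module = Doubler()
# The hook returns None (list.append returns None), so the output is unchanged.
handle = module.register_forward_hook(
    lambda mod, inputs, output: calls.append("hook")
)

x = torch.ones(3)

module.forward(x)          # direct call: registered hooks are silently ignored
direct_calls = len(calls)

module(x)                  # calling the instance runs the registered hooks
hooked_calls = len(calls)

handle.remove()            # detach the hook once it is no longer needed
```

The same reasoning applies to every forward method documented on this page, since all the proxies derive from torch.nn.Module.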
int_weight(x)
re_init_tensor_quant()
class brevitas.proxy.parameter_quant.BiasQuantProxy(quant_type, bit_width, narrow_range)
Bases: brevitas.proxy.parameter_quant.ParameterQuantProxy
forward(x, input_scale, input_bit_width)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- Return type
Tuple[Tensor, Optional[Tensor], Optional[Tensor]]
brevitas.proxy.quant_proxy module
class brevitas.proxy.quant_proxy.QuantProxy
Bases: torch.nn.modules.module.Module
state_dict(destination=None, prefix='', keep_vars=False)
Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names.
- Returns
a dictionary containing a whole state of the module
- Return type
dict
Example:
>>> module.state_dict().keys()
['bias', 'weight']
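As a runnable complement to the example above, the following sketch shows that state_dict covers both parameters and persistent buffers. It uses torch.nn.BatchNorm1d purely as a stand-in module, chosen because it holds running averages as buffers; it is an assumption for illustration and does not reflect the contents of a QuantProxy state dictionary.

```python
from torch import nn

# BatchNorm1d is a stand-in: it has learnable parameters (weight, bias)
# and persistent buffers (running_mean, running_var, num_batches_tracked).
bn = nn.BatchNorm1d(4)

state_keys = list(bn.state_dict().keys())
param_names = [name for name, _ in bn.named_parameters()]
buffer_names = [name for name, _ in bn.named_buffers()]
```

Here state_keys contains every name from param_names and buffer_names, which is the behaviour the description of state_dict refers to.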
brevitas.proxy.runtime_quant module
class brevitas.proxy.runtime_quant.ActivationQuantProxy(activation_impl, bit_width, signed, narrow_range, min_val, max_val, quant_type, float_to_int_impl_type, scaling_override, scaling_impl_type, scaling_per_channel, scaling_min_val, scaling_stats_sigma, scaling_stats_op, scaling_stats_buffer_momentum, scaling_stats_permute_dims, per_channel_broadcastable_shape, min_overall_bit_width, max_overall_bit_width, bit_width_impl_override, bit_width_impl_type, restrict_bit_width_type, restrict_scaling_type, override_pretrained_bit_width)
Bases: brevitas.proxy.quant_proxy.QuantProxy
forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class
brevitas.proxy.runtime_quant.
ClampQuantProxy
(signed, narrow_range, quant_type, ms_bit_width_to_clamp, clamp_at_least_init_val, min_overall_bit_width, max_overall_bit_width, msb_clamp_bit_width_impl_type, override_pretrained_bit_width)¶ Bases:
brevitas.proxy.quant_proxy.QuantProxy
-
forward
(x, input_scale, input_bit_width)¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-
-
class
brevitas.proxy.runtime_quant.
FusedActivationQuantProxy
(activation_impl, tensor_quant)¶ Bases:
torch.jit.ScriptModule
-
class
brevitas.proxy.runtime_quant.
TruncQuantProxy
(signed, quant_type, ls_bit_width_to_trunc, trunc_at_least_init_val, min_overall_bit_width, max_overall_bit_width, lsb_trunc_bit_width_impl_type, explicit_rescaling, override_pretrained_bit_width)¶ Bases:
brevitas.proxy.quant_proxy.QuantProxy
-
forward
(x, input_scale, input_bit_width)¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-