brevitas.core package
Submodules
brevitas.core.bit_width module
- class brevitas.core.bit_width.BitWidthConst(bit_width_init, restrict_bit_width_type)
  Bases: torch.jit.ScriptModule
- class brevitas.core.bit_width.BitWidthImplType
  Bases: brevitas.utils.python_utils.AutoName
  An enumeration.
  - CONST = 'CONST'
  - PARAMETER = 'PARAMETER'
- class brevitas.core.bit_width.BitWidthParameter(bit_width_init, min_overall_bit_width, max_overall_bit_width, restrict_bit_width_type, override_pretrained)
  Bases: torch.jit.ScriptModule
- class brevitas.core.bit_width.IdentityBitWidth(optimize=None, _qualified_name=None, _compilation_unit=None, _cpp_module=None)
  Bases: torch.jit.ScriptModule
- class brevitas.core.bit_width.LsbTruncParameterBitWidth(ls_bit_width_to_trunc, trunc_at_least_init_val, min_overall_bit_width, max_overall_bit_width, bit_width_impl_type, override_pretrained)
  Bases: torch.jit.ScriptModule
- class brevitas.core.bit_width.MsbClampParameterBitWidth(ms_bit_width_to_clamp, clamp_at_least_init_val, min_overall_bit_width, max_overall_bit_width, bit_width_impl_type, override_pretrained)
  Bases: torch.jit.ScriptModule
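LsbTruncParameterBitWidth and MsbClampParameterBitWidth take a number of bits to remove from an input bit width, together with min_overall_bit_width and max_overall_bit_width bounds. The following is a minimal plain-Python sketch of the bounding arithmetic. It assumes that the number of removed bits is clamped so that the remaining width stays within [min_overall_bit_width, max_overall_bit_width]. The helper clamp_bits_to_remove is hypothetical and is not the brevitas implementation, which learns the removed bits as a parameter.

```python
def clamp_bits_to_remove(input_bit_width: int, bits_to_remove: int,
                         min_overall_bit_width: int, max_overall_bit_width: int) -> int:
    """Hypothetical helper (not a brevitas API): bound the number of removed bits."""
    # Removing fewer bits than this would leave more than max_overall_bit_width.
    min_to_remove = input_bit_width - max_overall_bit_width
    # Removing more bits than this would leave fewer than min_overall_bit_width.
    max_to_remove = input_bit_width - min_overall_bit_width
    return max(min_to_remove, min(bits_to_remove, max_to_remove))


# A 16-bit input, asking to drop 4 bits, with the result kept within [2, 8] bits.
removed = clamp_bits_to_remove(16, 4, min_overall_bit_width=2, max_overall_bit_width=8)
print(removed, 16 - removed)  # 8 8
```

In this example the requested 4 bits are not enough to respect the 8-bit ceiling, so the sketch removes 8 bits instead.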
- class brevitas.core.bit_width.RemoveBitwidthParameter(bit_width_to_remove, remove_at_least_init_val, restrict_bit_width_impl, override_pretrained)
  Bases: torch.jit.ScriptModule
- class brevitas.core.bit_width.ZeroLsbTruncBitWidth(optimize=None, _qualified_name=None, _compilation_unit=None, _cpp_module=None)
  Bases: torch.jit.ScriptModule
  - forward(input_bit_width, zero_hw_sentinel)
    Defines the computation performed at every call.
    Should be overridden by all subclasses.
    Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
brevitas.core.function_wrapper module
- class brevitas.core.function_wrapper.CeilSte
  Bases: torch.jit.ScriptModule
- class brevitas.core.function_wrapper.ClampMin(min_val)
  Bases: torch.jit.ScriptModule
- class brevitas.core.function_wrapper.ConstScalarClamp(min_val, max_val)
  Bases: torch.jit.ScriptModule
- class brevitas.core.function_wrapper.FloorSte
  Bases: torch.jit.ScriptModule
- class brevitas.core.function_wrapper.Identity
  Bases: torch.jit.ScriptModule
- class brevitas.core.function_wrapper.LogTwo
  Bases: torch.jit.ScriptModule
- class brevitas.core.function_wrapper.OverBatchOverOutputChannelView
  Bases: torch.jit.ScriptModule
- class brevitas.core.function_wrapper.OverBatchOverTensorView
  Bases: torch.jit.ScriptModule
- class brevitas.core.function_wrapper.OverOutputChannelView
  Bases: torch.jit.ScriptModule
- class brevitas.core.function_wrapper.OverTensorView
  Bases: torch.jit.ScriptModule
- class brevitas.core.function_wrapper.PowerOfTwo
  Bases: torch.jit.ScriptModule
- class brevitas.core.function_wrapper.RoundSte
  Bases: torch.jit.ScriptModule
- class brevitas.core.function_wrapper.TensorClamp
  Bases: torch.jit.ScriptModule
- class brevitas.core.function_wrapper.TensorClampSte
  Bases: torch.jit.ScriptModule
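The *Ste wrappers (CeilSte, FloorSte, RoundSte, TensorClampSte) are named after the straight-through estimator. With this estimator, the forward pass applies a non-differentiable operation and the backward pass treats it as the identity, so gradients still flow. The following framework-free sketch writes the idea as explicit forward/backward methods instead of using torch.autograd. The StraightThroughOp class is hypothetical. The assumption that TensorClampSte also passes gradients through unchanged for clamped values is inferred from its name.

```python
import math
from typing import Callable, List


class StraightThroughOp:
    """Hypothetical elementwise op whose backward pass is the identity (STE)."""

    def __init__(self, fn: Callable[[float], float]):
        self.fn = fn

    def forward(self, xs: List[float]) -> List[float]:
        # Forward: apply the non-differentiable function as-is.
        return [self.fn(x) for x in xs]

    def backward(self, grad_out: List[float]) -> List[float]:
        # Backward: treat d fn(x)/dx as 1 and pass the upstream gradient through.
        return list(grad_out)


round_ste = StraightThroughOp(round)        # Python round() rounds half to even
ceil_ste = StraightThroughOp(math.ceil)
floor_ste = StraightThroughOp(math.floor)
clamp_ste = StraightThroughOp(lambda x: max(-1.0, min(x, 1.0)))  # clamp to [-1, 1]

xs = [-1.7, 0.4, 2.5]
print(round_ste.forward(xs))                # [-2, 0, 2]
print(clamp_ste.forward(xs))                # [-1.0, 0.4, 1.0]
print(round_ste.backward([0.1, 0.2, 0.3]))  # [0.1, 0.2, 0.3]
```

Quantization-aware training needs a surrogate like this because the true gradient of round, ceil and floor is zero almost everywhere.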
brevitas.core.quant module
- class brevitas.core.quant.IdentityQuant
  Bases: torch.jit.ScriptModule
  - forward(x, zero_hw_sentinel)
- class brevitas.core.quant.ClampedBinaryQuant(scaling_impl)
  Bases: torch.jit.ScriptModule
- class brevitas.core.quant.IntQuant(narrow_range, signed, float_to_int_impl, tensor_clamp_impl)
  Bases: torch.jit.ScriptModule
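IntQuant is parameterized by narrow_range and signed, which together fix the integer range that values are clamped to. The following minimal sketch assumes the conventional definitions:

- A signed format uses the two's-complement range, and narrow_range drops the most negative code so the range is symmetric.
- An unsigned format uses [0, 2^b - 1], and narrow_range drops the top code.

It further assumes that quantization computes round(x / scale), clamps the result into that range, and multiplies back by scale. The helpers int_range and fake_int_quant are hypothetical.

```python
from typing import Tuple


def int_range(bit_width: int, signed: bool, narrow_range: bool) -> Tuple[int, int]:
    """Hypothetical helper: return (min_int, max_int) for an integer format."""
    if signed:
        max_int = 2 ** (bit_width - 1) - 1
        # narrow_range drops the most negative code, making the range symmetric.
        min_int = -max_int if narrow_range else -(2 ** (bit_width - 1))
    else:
        min_int = 0
        # Assumption: for unsigned formats narrow_range drops the top code.
        max_int = 2 ** bit_width - 2 if narrow_range else 2 ** bit_width - 1
    return min_int, max_int


def fake_int_quant(x: float, scale: float, bit_width: int,
                   signed: bool = True, narrow_range: bool = False) -> float:
    """Hypothetical helper: round x / scale to a code, clamp it, and rescale."""
    lo, hi = int_range(bit_width, signed, narrow_range)
    code = min(max(round(x / scale), lo), hi)
    return code * scale


print(int_range(8, signed=True, narrow_range=False))  # (-128, 127)
print(int_range(8, signed=True, narrow_range=True))   # (-127, 127)
print(fake_int_quant(1.23, scale=0.25, bit_width=4))  # 1.25
print(fake_int_quant(3.0, scale=0.25, bit_width=4))   # 1.75 (code clamped to 7)
```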
- class brevitas.core.quant.PrescaledRestrictIntQuantWithInputBitWidth(narrow_range, signed, tensor_clamp_impl, msb_clamp_bit_width_impl, float_to_int_impl)
  Bases: torch.jit.ScriptModule
- class brevitas.core.quant.IdentityPrescaledIntQuant
  Bases: torch.jit.ScriptModule
- class brevitas.core.quant.QuantType
  Bases: brevitas.utils.python_utils.AutoName
  An enumeration.
  - BINARY = 'BINARY'
  - FP = 'FP'
  - INT = 'INT'
  - TERNARY = 'TERNARY'
- class brevitas.core.quant.BinaryQuant(scaling_impl)
  Bases: torch.jit.ScriptModule
- class brevitas.core.quant.TernaryQuant(scaling_impl, threshold)
  Bases: torch.jit.ScriptModule
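BinaryQuant takes only a scaling_impl, while TernaryQuant adds a threshold. The following minimal sketch shows the usual formulations. It assumes that binary quantization maps each value to plus or minus scale by its sign, with zero mapped to +scale here. It also assumes that ternary quantization zeroes values whose magnitude is below threshold * scale. Tie-breaking in brevitas may differ, and both functions are hypothetical.

```python
from typing import List


def binary_quant(xs: List[float], scale: float) -> List[float]:
    # Hypothetical helper: +scale or -scale by sign (0 maps to +scale here).
    return [scale if x >= 0 else -scale for x in xs]


def ternary_quant(xs: List[float], scale: float, threshold: float) -> List[float]:
    # Hypothetical helper: |x| below threshold * scale becomes 0, the rest +/-scale.
    cutoff = threshold * scale
    out = []
    for x in xs:
        if abs(x) < cutoff:
            out.append(0.0)
        else:
            out.append(scale if x > 0 else -scale)
    return out


xs = [-0.9, -0.1, 0.0, 0.3, 1.4]
print(binary_quant(xs, scale=0.5))                  # [-0.5, -0.5, 0.5, 0.5, 0.5]
print(ternary_quant(xs, scale=1.0, threshold=0.5))  # [-1.0, 0.0, 0.0, 0.0, 1.0]
```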
- class brevitas.core.quant.RescalingIntQuant(narrow_range, runtime, signed, scaling_impl, int_scaling_impl, tensor_clamp_impl, msb_clamp_bit_width_impl, float_to_int_impl)
  Bases: torch.jit.ScriptModule
  - static scaling_init_from_min_max(min_val_init, max_val_init)
    Return type: Tensor
  - static
- class brevitas.core.quant.PrescaledRestrictIntQuant(narrow_range, signed, tensor_clamp_impl, msb_clamp_bit_width_impl, float_to_int_impl)
  Bases: torch.jit.ScriptModule
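RescalingIntQuant combines a floating-point scale from scaling_impl with an integer scale from int_scaling_impl. Its static method scaling_init_from_min_max builds an initial scale from a [min, max] range. The following minimal sketch makes two assumptions:

- The initial scale is the larger magnitude of the two bounds.
- The effective step size is that scale divided by an integer scale of 2^(b-1) - 1, which corresponds to a signed, narrow-range format.

The helpers scale_from_min_max and rescaling_int_quant are hypothetical stand-ins, not the library's code.

```python
def scale_from_min_max(min_val_init: float, max_val_init: float) -> float:
    # Hypothetical helper: initial scale covers the larger magnitude of the bounds.
    return max(abs(float(min_val_init)), abs(float(max_val_init)))


def rescaling_int_quant(x: float, scale: float, bit_width: int) -> float:
    """Hypothetical helper: signed, narrow-range quantization with a split scale."""
    int_scale = 2 ** (bit_width - 1) - 1    # largest positive code
    step = scale / int_scale                # value represented by one integer step
    code = max(-int_scale, min(int_scale, round(x / step)))
    return code * step


scale = scale_from_min_max(-1.5, 3.0)
print(scale)                                # 3.0
print([rescaling_int_quant(v, scale, bit_width=3) for v in (1.4, 5.0, -4.2)])
# [1.0, 3.0, -3.0]
```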
brevitas.core.restrict_val module
- class brevitas.core.restrict_val.FloatToIntImplType
  Bases: brevitas.utils.python_utils.AutoName
  An enumeration.
  - CEIL = 'CEIL'
  - FLOOR = 'FLOOR'
  - ROUND = 'ROUND'
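FloatToIntImplType selects how a float is mapped to an integer. The following sketch shows the three choices using a stand-in Python Enum. The FloatToInt enum and the float_to_int helper are hypothetical and only mirror the member names listed above.

```python
import math
from enum import Enum


class FloatToInt(Enum):
    # Hypothetical stand-in mirroring the member names of FloatToIntImplType.
    CEIL = 'CEIL'
    FLOOR = 'FLOOR'
    ROUND = 'ROUND'


_IMPLS = {
    FloatToInt.CEIL: math.ceil,
    FloatToInt.FLOOR: math.floor,
    FloatToInt.ROUND: round,  # rounds half to even, like torch.round
}


def float_to_int(x: float, impl: FloatToInt) -> int:
    # Hypothetical helper: dispatch on the enum member.
    return _IMPLS[impl](x)


for impl in FloatToInt:
    print(impl.value, [float_to_int(v, impl) for v in (-1.5, 0.5, 2.5)])
# CEIL [-1, 1, 3]
# FLOOR [-2, 0, 2]
# ROUND [-2, 0, 2]
```

The ROUND row shows ties going to the nearest even integer. This matters for values that land exactly halfway between two codes.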
- class brevitas.core.restrict_val.RestrictValue(restrict_value_type, float_to_int_impl_type, min_val)
  Bases: torch.jit.ScriptModule
  - static restrict_value_op(restrict_value_type, restrict_value_op_impl_type)
  - static
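RestrictValue is configured with a restrict_value_type, a float_to_int_impl_type and a min_val. The function_wrapper module provides LogTwo and PowerOfTwo. The following sketch shows one plausible restriction, snapping a positive scale to a power of two. It assumes the value is first clamped to min_val, then moved to the log2 domain, rounded to an integer exponent, and mapped back with 2**. The helper restrict_to_power_of_two is hypothetical. This section does not list the restriction types that brevitas actually offers.

```python
import math


def restrict_to_power_of_two(value: float, min_val: float = 2.0 ** -16) -> float:
    """Hypothetical helper: snap a positive value to a power of two."""
    value = max(value, min_val)           # ClampMin-style guard so log2 is defined
    exponent = round(math.log2(value))    # LogTwo, then round to an integer exponent
    return 2.0 ** exponent                # PowerOfTwo maps back to the linear domain


print([restrict_to_power_of_two(v) for v in (0.3, 0.8, 3.0, 5.0)])
# [0.25, 1.0, 4.0, 4.0]
```

Rounding happens in the log domain, so 3.0 maps to 4.0 rather than 2.0: log2(3) is about 1.58, which is closer to 2 than to 1.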
- class brevitas.core.restrict_val.RestrictValueOpImplType
  Bases: brevitas.utils.python_utils.AutoName
  An enumeration.
  - MATH = 'MATH'
  - TORCH_FN = 'TORCH_FN'
  - TORCH_MODULE = 'TORCH_MODULE'
brevitas.core.scaling module
- class brevitas.core.scaling.AffineRescaling(affine_shape)
  Bases: torch.jit.ScriptModule
  - forward(x)
    Defines the computation performed at every call.
    Should be overridden by all subclasses.
    Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class brevitas.core.scaling.IntScaling(narrow_range, signed, restrict_scaling_type)
  Bases: torch.jit.ScriptModule
- class brevitas.core.scaling.ParameterStatsScaling(stats_op, restrict_scaling_type, stats_input_view_shape_impl, stats_output_shape, stats_input_concat_dim, sigma, scaling_min_val, stats_reduce_dim, tracked_parameter_list, affine)
  Bases: torch.jit.ScriptModule
- class brevitas.core.scaling.PowerOfTwoIntScale(signed)
  Bases: torch.jit.ScriptModule
- class brevitas.core.scaling.RuntimeStatsScaling(stats_op, restrict_scaling_type, stats_input_view_shape_impl, stats_output_shape, sigma, scaling_min_val, stats_reduce_dim, stats_permute_dims, stats_buffer_momentum, stats_buffer_init, affine)
  Bases: torch.jit.ScriptModule
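RuntimeStatsScaling collects statistics at run time and takes stats_buffer_momentum and stats_buffer_init arguments. The following minimal sketch shows a momentum-based running buffer. It assumes a BatchNorm-style update, new = (1 - momentum) * old + momentum * observed, starting from the initial buffer value. The RunningStat class is hypothetical, and this section does not specify the exact update rule that brevitas uses.

```python
class RunningStat:
    """Hypothetical exponential moving average of a per-batch statistic."""

    def __init__(self, momentum: float, init: float):
        self.momentum = momentum
        self.value = init  # plays the role of stats_buffer_init

    def update(self, batch_stat: float) -> float:
        # Blend the new observation into the buffer; larger momentum reacts faster.
        self.value = (1.0 - self.momentum) * self.value + self.momentum * batch_stat
        return self.value


stat = RunningStat(momentum=0.5, init=1.0)
for batch_max in (3.0, 3.0, 3.0):
    stat.update(batch_max)
print(stat.value)  # 2.75
```

With momentum 0.5, the buffer closes half of the remaining gap to the observed value at each step: 1.0, then 2.0, 2.5 and 2.75.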
- class brevitas.core.scaling.ScalingImplType
  Bases: brevitas.utils.python_utils.AutoName
  An enumeration.
  - AFFINE_STATS = 'AFFINE_STATS'
  - CONST = 'CONST'
  - HE = 'HE'
  - OVERRIDE = 'OVERRIDE'
  - PARAMETER = 'PARAMETER'
  - PARAMETER_FROM_STATS = 'PARAMETER_FROM_STATS'
  - STATS = 'STATS'
- class brevitas.core.scaling.SignedFpIntScale(narrow_range)
  Bases: torch.jit.ScriptModule
- class brevitas.core.scaling.StandaloneScaling(scaling_init, is_parameter, parameter_shape, scaling_min_val, restrict_scaling_type)
  Bases: torch.jit.ScriptModule
- class brevitas.core.scaling.StatsScaling(stats_op, restrict_scaling_type, stats_output_shape, scaling_min_val, affine)
  Bases: torch.jit.ScriptModule
- class brevitas.core.scaling.UnsignedFpIntScale
  Bases: torch.jit.ScriptModule
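IntScaling, SignedFpIntScale, UnsignedFpIntScale and PowerOfTwoIntScale turn a bit width into an integer scale factor. The following sketch gives plausible definitions, inferred only from the class names and constructor arguments:

- The signed scale is the magnitude of the most negative code: 2^(b-1), or 2^(b-1) - 1 with narrow_range.
- The unsigned scale is the largest code, 2^b - 1.
- The power-of-two variant is one more than the largest positive code: 2^(b-1) for signed formats and 2^b for unsigned ones.

All three helpers are hypothetical.

```python
def signed_fp_int_scale(bit_width: int, narrow_range: bool) -> int:
    # Hypothetical helper: magnitude of the most negative representable code.
    return 2 ** (bit_width - 1) - 1 if narrow_range else 2 ** (bit_width - 1)


def unsigned_fp_int_scale(bit_width: int) -> int:
    # Hypothetical helper: largest unsigned code.
    return 2 ** bit_width - 1


def power_of_two_int_scale(bit_width: int, signed: bool) -> int:
    # Hypothetical helper: one more than the largest positive code (a power of two).
    return 2 ** (bit_width - 1) if signed else 2 ** bit_width


for b in (2, 4, 8):
    print(b, signed_fp_int_scale(b, False), signed_fp_int_scale(b, True),
          unsigned_fp_int_scale(b), power_of_two_int_scale(b, True))
```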
brevitas.core.stats module
- class brevitas.core.stats.StatsInputViewShapeImpl
  Bases: object
  - OVER_BATCH_OVER_OUTPUT_CHANNELS
    alias of brevitas.core.function_wrapper.OverBatchOverOutputChannelView
  - OVER_BATCH_OVER_TENSOR
    alias of brevitas.core.function_wrapper.OverBatchOverTensorView
  - OVER_OUTPUT_CHANNELS
    alias of brevitas.core.function_wrapper.OverOutputChannelView
  - OVER_TENSOR
    alias of brevitas.core.function_wrapper.OverTensorView
- class brevitas.core.stats.StatsOp
  Bases: brevitas.utils.python_utils.AutoName
  An enumeration.
  - AVE = 'AVE'
  - MAX = 'MAX'
  - MAX_AVE = 'MAX_AVE'
  - MEAN_LEARN_SIGMA_STD = 'MEAN_LEARN_SIGMA_STD'
  - MEAN_SIGMA_STD = 'MEAN_SIGMA_STD'
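StatsOp names the statistic used to derive a scale from a tensor, and several scaling classes also take a sigma argument. The following sketch gives plausible meanings, inferred from the member names only:

- AVE: mean of absolute values.
- MAX: maximum absolute value.
- MAX_AVE: each row's maximum absolute value, averaged over rows (for example, over output channels).
- MEAN_SIGMA_STD: mean plus sigma times the sample standard deviation of the absolute values. Sample standard deviation is torch.std's default.
- MEAN_LEARN_SIGMA_STD: presumably makes sigma learnable, and is not shown.

All helper names are hypothetical.

```python
import statistics
from typing import List


def stat_max(xs: List[float]) -> float:
    # MAX (assumed): largest magnitude.
    return max(abs(x) for x in xs)


def stat_ave(xs: List[float]) -> float:
    # AVE (assumed): mean magnitude.
    return sum(abs(x) for x in xs) / len(xs)


def stat_max_ave(rows: List[List[float]]) -> float:
    # MAX_AVE (assumed): per-row max magnitude, averaged over rows.
    return sum(stat_max(row) for row in rows) / len(rows)


def stat_mean_sigma_std(xs: List[float], sigma: float) -> float:
    # MEAN_SIGMA_STD (assumed): mean + sigma * sample std of the magnitudes.
    mags = [abs(x) for x in xs]
    return statistics.mean(mags) + sigma * statistics.stdev(mags)


xs = [1.0, -3.0, 5.0]
print(stat_max(xs), stat_ave(xs), stat_mean_sigma_std(xs, sigma=1.0))  # 5.0 3.0 5.0
print(stat_max_ave([[1.0, -3.0], [2.0, -1.0]]))  # 2.5
```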
- class brevitas.core.stats.ParameterListStats(stats_op, stats_input_view_shape_impl, stats_reduce_dim, stats_input_concat_dim, stats_output_shape, tracked_parameter_list, sigma)
  Bases: torch.jit.ScriptModule