brevitas.core package

Submodules

brevitas.core.bit_width module

class brevitas.core.bit_width.BitWidthConst(bit_width_init, restrict_bit_width_type)

Bases: torch.jit.ScriptModule

class brevitas.core.bit_width.BitWidthImplType

Bases: brevitas.utils.python_utils.AutoName

An enumeration.

CONST = 'CONST'
PARAMETER = 'PARAMETER'

class brevitas.core.bit_width.BitWidthParameter(bit_width_init, min_overall_bit_width, max_overall_bit_width, restrict_bit_width_type, override_pretrained)

Bases: torch.jit.ScriptModule

class brevitas.core.bit_width.IdentityBitWidth(optimize=None, _qualified_name=None, _compilation_unit=None, _cpp_module=None)

Bases: torch.jit.ScriptModule

class brevitas.core.bit_width.LsbTruncParameterBitWidth(ls_bit_width_to_trunc, trunc_at_least_init_val, min_overall_bit_width, max_overall_bit_width, bit_width_impl_type, override_pretrained)

Bases: torch.jit.ScriptModule

class brevitas.core.bit_width.MsbClampParameterBitWidth(ms_bit_width_to_clamp, clamp_at_least_init_val, min_overall_bit_width, max_overall_bit_width, bit_width_impl_type, override_pretrained)

Bases: torch.jit.ScriptModule

class brevitas.core.bit_width.RemoveBitwidthParameter(bit_width_to_remove, remove_at_least_init_val, restrict_bit_width_impl, override_pretrained)

Bases: torch.jit.ScriptModule

class brevitas.core.bit_width.ZeroLsbTruncBitWidth(optimize=None, _qualified_name=None, _compilation_unit=None, _cpp_module=None)

Bases: torch.jit.ScriptModule

forward(input_bit_width, zero_hw_sentinel)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
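
BitWidthImplType.CONST and BitWidthImplType.PARAMETER select between a fixed bit width (BitWidthConst) and a bit width learned during training (BitWidthParameter). The plain-PyTorch sketch below illustrates the two policies these TorchScript modules encapsulate; the class names and the offset-above-a-minimum parameterization are illustrative assumptions, not the brevitas implementation.

    import torch
    import torch.nn as nn

    class ConstBitWidthSketch(nn.Module):
        """Always returns the same bit width, in the spirit of BitWidthConst."""
        def __init__(self, bit_width_init):
            super().__init__()
            self.register_buffer('bit_width', torch.tensor(float(bit_width_init)))

        def forward(self):
            return self.bit_width

    class ParameterBitWidthSketch(nn.Module):
        """Learns the bit width, in the spirit of BitWidthParameter."""
        def __init__(self, bit_width_init, min_overall_bit_width=2):
            super().__init__()
            self.min_bit_width = float(min_overall_bit_width)
            # Learn a non-negative offset above the minimum overall bit width.
            self.bit_width_offset = nn.Parameter(
                torch.tensor(float(bit_width_init) - self.min_bit_width))

        def forward(self):
            return self.min_bit_width + torch.abs(self.bit_width_offset)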

brevitas.core.function_wrapper module

class brevitas.core.function_wrapper.CeilSte

Bases: torch.jit.ScriptModule

class brevitas.core.function_wrapper.ClampMin(min_val)

Bases: torch.jit.ScriptModule

class brevitas.core.function_wrapper.ConstScalarClamp(min_val, max_val)

Bases: torch.jit.ScriptModule

class brevitas.core.function_wrapper.FloorSte

Bases: torch.jit.ScriptModule

class brevitas.core.function_wrapper.Identity

Bases: torch.jit.ScriptModule

class brevitas.core.function_wrapper.LogTwo

Bases: torch.jit.ScriptModule

class brevitas.core.function_wrapper.OverBatchOverOutputChannelView

Bases: torch.jit.ScriptModule

class brevitas.core.function_wrapper.OverBatchOverTensorView

Bases: torch.jit.ScriptModule

class brevitas.core.function_wrapper.OverOutputChannelView

Bases: torch.jit.ScriptModule

class brevitas.core.function_wrapper.OverTensorView

Bases: torch.jit.ScriptModule

class brevitas.core.function_wrapper.PowerOfTwo

Bases: torch.jit.ScriptModule

class brevitas.core.function_wrapper.RoundSte

Bases: torch.jit.ScriptModule

class brevitas.core.function_wrapper.TensorClamp

Bases: torch.jit.ScriptModule

class brevitas.core.function_wrapper.TensorClampSte

Bases: torch.jit.ScriptModule
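
CeilSte, FloorSte, RoundSte and TensorClampSte wrap non-differentiable ops with a straight-through estimator (STE), while the Over*View classes reshape their input for statistics collection. The sketch below shows the STE pattern for rounding in plain PyTorch; the names _RoundSteFn and round_ste are illustrative, not the brevitas implementation.

    import torch

    class _RoundSteFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # Non-differentiable op in the forward pass.
            return torch.round(x)

        @staticmethod
        def backward(ctx, grad_output):
            # Straight-through: pass the gradient through unchanged.
            return grad_output

    def round_ste(x):
        return _RoundSteFn.apply(x)

    x = torch.randn(4, requires_grad=True)
    round_ste(x).sum().backward()
    print(x.grad)  # all ones: gradients flow straight through the rounding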

brevitas.core.quant module

class brevitas.core.quant.IdentityQuant

Bases: torch.jit.ScriptModule

forward(x, zero_hw_sentinel)

class brevitas.core.quant.ClampedBinaryQuant(scaling_impl)

Bases: torch.jit.ScriptModule

class brevitas.core.quant.IntQuant(narrow_range, signed, float_to_int_impl, tensor_clamp_impl)

Bases: torch.jit.ScriptModule

class brevitas.core.quant.PrescaledRestrictIntQuantWithInputBitWidth(narrow_range, signed, tensor_clamp_impl, msb_clamp_bit_width_impl, float_to_int_impl)

Bases: torch.jit.ScriptModule

class brevitas.core.quant.IdentityPrescaledIntQuant

Bases: torch.jit.ScriptModule

class brevitas.core.quant.QuantType

Bases: brevitas.utils.python_utils.AutoName

An enumeration.

BINARY = 'BINARY'
FP = 'FP'
INT = 'INT'
TERNARY = 'TERNARY'

class brevitas.core.quant.BinaryQuant(scaling_impl)

Bases: torch.jit.ScriptModule

class brevitas.core.quant.TernaryQuant(scaling_impl, threshold)

Bases: torch.jit.ScriptModule

class brevitas.core.quant.RescalingIntQuant(narrow_range, runtime, signed, scaling_impl, int_scaling_impl, tensor_clamp_impl, msb_clamp_bit_width_impl, float_to_int_impl)

Bases: torch.jit.ScriptModule

static scaling_init_from_min_max(min_val_init, max_val_init)

Return type: Tensor

class brevitas.core.quant.PrescaledRestrictIntQuant(narrow_range, signed, tensor_clamp_impl, msb_clamp_bit_width_impl, float_to_int_impl)

Bases: torch.jit.ScriptModule
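
IntQuant and the RescalingIntQuant family compose a scaling implementation, a float-to-int implementation, a tensor clamp and a bit-width implementation into uniform integer quantization. The function below is a hedged, plain-PyTorch sketch of that composition, assuming narrow_range and signed fix the integer grid as their names suggest; it is not the brevitas implementation.

    import torch

    def int_quant_sketch(x, scale, bit_width, signed=True, narrow_range=False):
        if signed:
            max_int = 2 ** (bit_width - 1) - 1
            min_int = -(2 ** (bit_width - 1)) + (1 if narrow_range else 0)
        else:
            max_int = 2 ** bit_width - 1
            min_int = 0
        # Scale, round to the integer grid, clamp, then rescale.
        int_val = torch.clamp(torch.round(x / scale), min_int, max_int)
        return int_val * scale

    x = torch.randn(8)
    print(int_quant_sketch(x, scale=x.abs().max() / 127, bit_width=8))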

brevitas.core.restrict_val module

class brevitas.core.restrict_val.FloatToIntImplType

Bases: brevitas.utils.python_utils.AutoName

An enumeration.

CEIL = 'CEIL'
FLOOR = 'FLOOR'
ROUND = 'ROUND'

class brevitas.core.restrict_val.RestrictValue(restrict_value_type, float_to_int_impl_type, min_val)

Bases: torch.jit.ScriptModule

static restrict_value_op(restrict_value_type, restrict_value_op_impl_type)

class brevitas.core.restrict_val.RestrictValueOpImplType

Bases: brevitas.utils.python_utils.AutoName

An enumeration.

MATH = 'MATH'
TORCH_FN = 'TORCH_FN'
TORCH_MODULE = 'TORCH_MODULE'

class brevitas.core.restrict_val.RestrictValueType

Bases: brevitas.utils.python_utils.AutoName

An enumeration.

FP = 'FP'
INT = 'INT'
LOG_FP = 'LOG_FP'
POWER_OF_TWO = 'POWER_OF_TWO'
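
RestrictValue constrains a value (typically a scale factor or bit width) to the domain named by RestrictValueType, with FloatToIntImplType choosing the rounding used when snapping to a grid. The sketch below illustrates the POWER_OF_TWO case in plain PyTorch; it is an illustration, not the brevitas implementation.

    import torch

    def restrict_to_power_of_two(x, float_to_int='ROUND'):
        # Snap each value to a power of two by rounding its log2.
        log2_x = torch.log2(x)
        if float_to_int == 'ROUND':
            exp = torch.round(log2_x)
        elif float_to_int == 'FLOOR':
            exp = torch.floor(log2_x)
        else:  # 'CEIL'
            exp = torch.ceil(log2_x)
        return 2.0 ** exp

    print(restrict_to_power_of_two(torch.tensor([0.3, 1.7, 5.0])))
    # tensor([0.2500, 2.0000, 4.0000])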

brevitas.core.scaling module

class brevitas.core.scaling.AffineRescaling(affine_shape)

Bases: torch.jit.ScriptModule

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.scaling.IntScaling(narrow_range, signed, restrict_scaling_type)

Bases: torch.jit.ScriptModule

class brevitas.core.scaling.ParameterStatsScaling(stats_op, restrict_scaling_type, stats_input_view_shape_impl, stats_output_shape, stats_input_concat_dim, sigma, scaling_min_val, stats_reduce_dim, tracked_parameter_list, affine)

Bases: torch.jit.ScriptModule

class brevitas.core.scaling.PowerOfTwoIntScale(signed)

Bases: torch.jit.ScriptModule

class brevitas.core.scaling.RuntimeStatsScaling(stats_op, restrict_scaling_type, stats_input_view_shape_impl, stats_output_shape, sigma, scaling_min_val, stats_reduce_dim, stats_permute_dims, stats_buffer_momentum, stats_buffer_init, affine)

Bases: torch.jit.ScriptModule

class brevitas.core.scaling.ScalingImplType

Bases: brevitas.utils.python_utils.AutoName

An enumeration.

AFFINE_STATS = 'AFFINE_STATS'
CONST = 'CONST'
HE = 'HE'
OVERRIDE = 'OVERRIDE'
PARAMETER = 'PARAMETER'
PARAMETER_FROM_STATS = 'PARAMETER_FROM_STATS'
STATS = 'STATS'

class brevitas.core.scaling.SignedFpIntScale(narrow_range)

Bases: torch.jit.ScriptModule

class brevitas.core.scaling.StandaloneScaling(scaling_init, is_parameter, parameter_shape, scaling_min_val, restrict_scaling_type)

Bases: torch.jit.ScriptModule

class brevitas.core.scaling.StatsScaling(stats_op, restrict_scaling_type, stats_output_shape, scaling_min_val, affine)

Bases: torch.jit.ScriptModule

class brevitas.core.scaling.UnsignedFpIntScale

Bases: torch.jit.ScriptModule
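
ScalingImplType selects how the quantization scale is obtained: a constant, a learned parameter, statistics of the tracked tensors (STATS / AFFINE_STATS), or a parameter initialized from statistics. The function below is a hedged sketch of a MAX-statistics scale, per tensor or per output channel, floored by scaling_min_val; it is a simplification of what StatsScaling / ParameterStatsScaling compute, not the brevitas implementation.

    import torch

    def max_stats_scaling(weight, per_channel=False, scaling_min_val=1e-8):
        if per_channel:
            # One scale per output channel, reduced over the remaining dims.
            stats = weight.abs().reshape(weight.shape[0], -1).max(dim=1).values
        else:
            stats = weight.abs().max()
        return torch.clamp(stats, min=scaling_min_val)

    w = torch.randn(16, 3, 3, 3)
    print(max_stats_scaling(w, per_channel=True).shape)  # torch.Size([16])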

brevitas.core.stats module

class brevitas.core.stats.StatsInputViewShapeImpl

Bases: object

OVER_BATCH_OVER_OUTPUT_CHANNELS

alias of brevitas.core.function_wrapper.OverBatchOverOutputChannelView

OVER_BATCH_OVER_TENSOR

alias of brevitas.core.function_wrapper.OverBatchOverTensorView

OVER_OUTPUT_CHANNELS

alias of brevitas.core.function_wrapper.OverOutputChannelView

OVER_TENSOR

alias of brevitas.core.function_wrapper.OverTensorView

class brevitas.core.stats.StatsOp

Bases: brevitas.utils.python_utils.AutoName

An enumeration.

AVE = 'AVE'
MAX = 'MAX'
MAX_AVE = 'MAX_AVE'
MEAN_LEARN_SIGMA_STD = 'MEAN_LEARN_SIGMA_STD'
MEAN_SIGMA_STD = 'MEAN_SIGMA_STD'

class brevitas.core.stats.ParameterListStats(stats_op, stats_input_view_shape_impl, stats_reduce_dim, stats_input_concat_dim, stats_output_shape, tracked_parameter_list, sigma)

Bases: torch.jit.ScriptModule
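
ParameterListStats concatenates one or more tracked parameters, reshapes them with the chosen StatsInputViewShapeImpl alias and reduces them with the chosen StatsOp. The sketch below illustrates a per-output-channel view followed by a few of the reductions, assuming MEAN_SIGMA_STD combines the mean and sigma standard deviations of the absolute values; it is an illustration, not the brevitas implementation.

    import torch

    def over_output_channels_view(w):
        # In the spirit of OverOutputChannelView: one row per output channel.
        return w.reshape(w.shape[0], -1)

    def stats_reduce(view, op, sigma=3.0):
        if op == 'MAX':
            return view.abs().max(dim=1).values
        if op == 'AVE':
            return view.abs().mean(dim=1)
        if op == 'MEAN_SIGMA_STD':
            return view.abs().mean(dim=1) + sigma * view.abs().std(dim=1)
        raise ValueError(op)

    w = torch.randn(8, 4, 3, 3)
    print(stats_reduce(over_output_channels_view(w), 'MAX').shape)  # torch.Size([8])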

Module contents