brevitas.core.function_wrapper package#

Submodules#

brevitas.core.function_wrapper.clamp module#

ScriptModule wrappers for several clamping variants.

class brevitas.core.function_wrapper.clamp.ClampMin(min_val)[source]#

Bases: Module

ScriptModule wrapper for clamp_min().

Examples

>>> clamp_min = ClampMin(min_val=-2.0)
>>> clamp_min(torch.tensor(-3.0))
tensor(-2.)
forward(x)[source]#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward pass must be defined within this function, call the Module instance rather than forward() directly: the instance call runs the registered hooks, whereas a direct forward() call silently ignores them.
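The difference can be observed with a forward hook. The following is a minimal sketch in plain `torch`; the `Doubler` module and the `calls` list are hypothetical names introduced only for illustration.

```python
import torch
from torch import nn


class Doubler(nn.Module):
    """Hypothetical module used only to illustrate hook behaviour."""

    def forward(self, x):
        return 2 * x


calls = []  # records every time the forward hook fires
module = Doubler()
# A forward hook runs after forward() when the module instance is called.
# The lambda returns None, so the module output is left unchanged.
module.register_forward_hook(lambda mod, inputs, output: calls.append(output))

x = torch.tensor(3.0)
via_call = module(x)             # goes through __call__, so the hook fires
via_forward = module.forward(x)  # bypasses __call__, so the hook is skipped

print(len(calls))  # the hook fired once, for module(x) only
```

Both calls return the same value; only the instance call triggers the hook.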

training: bool#
class brevitas.core.function_wrapper.clamp.ScalarClamp(min_val, max_val)[source]#

Bases: Module

ScriptModule wrapper for clamp().

Examples

>>> scalar_clamp = ScalarClamp(min_val=-2.0, max_val=2.0)
>>> scalar_clamp(torch.tensor([-3.0, 3.0]))
tensor([-2.,  2.])
forward(x)[source]#

training: bool#
class brevitas.core.function_wrapper.clamp.TensorClamp[source]#

Bases: Module

ScriptModule wrapper for tensor_clamp().

Examples

>>> tensor_clamp = TensorClamp()
>>> min_val = torch.tensor(-2.0)
>>> max_val = torch.tensor(2.0)
>>> tensor_clamp(torch.tensor([-3.0, 3.0]), min_val, max_val)
tensor([-2.,  2.])
forward(x, min_val, max_val)[source]#

training: bool#

brevitas.core.function_wrapper.misc module#

A collection of miscellaneous ScriptModules used in various quantizers.

class brevitas.core.function_wrapper.misc.Identity[source]#

Bases: Module

Identity ScriptModule.

Examples

>>> identity = Identity()
>>> x = torch.randn(size=[10,])
>>> y = identity(x)
>>> y is x
True
forward(x)[source]#

Return type:

Tensor

training: bool#
class brevitas.core.function_wrapper.misc.InplaceLogTwo[source]#

Bases: Module

Module wrapper for log2_().

Examples

>>> inplace_log_two = InplaceLogTwo()
>>> x = torch.tensor(8.0)
>>> _ = inplace_log_two(x)
>>> x
tensor(3.)

Notes

In-place operations can be problematic in TorchScript, so compilation is disabled for this module.

forward(x)[source]#

Return type:

Tensor

training: bool#
class brevitas.core.function_wrapper.misc.LogTwo[source]#

Bases: Module

ScriptModule wrapper for log2().

Examples

>>> log_two = LogTwo()
>>> x = torch.tensor(8.0)
>>> log_two(x)
tensor(3.)
forward(x)[source]#

Return type:

Tensor

training: bool#
class brevitas.core.function_wrapper.misc.PowerOfTwo[source]#

Bases: Module

ScriptModule implementation of 2.0 ** x.

Examples

>>> power_of_two = PowerOfTwo()
>>> x = torch.tensor(5.0)
>>> power_of_two(x)
tensor(32.)
forward(x)[source]#

Return type:

Tensor

training: bool#

brevitas.core.function_wrapper.ops_ste module#

ScriptModule wrappers of various functions defined in ops_ste.
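The `_ste` suffix refers to the straight-through estimator: the forward pass applies a non-differentiable operation such as rounding, while the backward pass lets the gradient through unchanged. A minimal sketch of the idea in plain `torch` follows. It uses the detach trick rather than a custom autograd function, and `round_ste_sketch` is a hypothetical name, not the Brevitas implementation.

```python
import torch


def round_ste_sketch(x: torch.Tensor) -> torch.Tensor:
    """Round in the forward pass, act as identity in the backward pass."""
    # (round(x) - x) is detached, so autograd only sees the "x +" term
    # and the gradient with respect to x is exactly 1.
    return x + (torch.round(x) - x).detach()


x = torch.tensor([-1.7, 0.2, 2.6], requires_grad=True)
y = round_ste_sketch(x)
y.sum().backward()

print(y.detach())  # rounded values
print(x.grad)      # all ones: the gradient ignores the rounding
```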

class brevitas.core.function_wrapper.ops_ste.CeilSte[source]#

Bases: Module

ScriptModule wrapper for ceil_ste().

forward(x)[source]#

training: bool#
class brevitas.core.function_wrapper.ops_ste.DPURoundSte[source]#

Bases: Module

ScriptModule wrapper for dpu_round_ste().

forward(x)[source]#

training: bool#
class brevitas.core.function_wrapper.ops_ste.FloorSte[source]#

Bases: Module

ScriptModule wrapper for floor_ste().

forward(x)[source]#

training: bool#
class brevitas.core.function_wrapper.ops_ste.InplaceTensorClampSte[source]#

Bases: Module

ScriptModule wrapper for tensor_clamp_ste_().

forward(x, min_val, max_val)[source]#

training: bool#
class brevitas.core.function_wrapper.ops_ste.RoundSte[source]#

Bases: Module

ScriptModule wrapper for round_ste().

forward(x)[source]#

training: bool#
class brevitas.core.function_wrapper.ops_ste.RoundToZeroSte[source]#

Bases: Module

ScriptModule wrapper for round_to_zero_ste().

forward(x)[source]#

training: bool#
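In the forward pass, CeilSte, FloorSte, RoundSte and RoundToZeroSte differ only in rounding mode. Assuming they follow the usual definitions of ceiling, floor, round-to-nearest and round-toward-zero, the plain `torch` equivalents behave as sketched below. The gradient handling of the `_ste` variants is not shown, and DPURoundSte is omitted because its rounding rule is not described here.

```python
import torch

x = torch.tensor([-1.5, -0.4, 0.4, 1.5])

floor_x = torch.floor(x)  # toward negative infinity
ceil_x = torch.ceil(x)    # toward positive infinity
round_x = torch.round(x)  # nearest integer, ties to even
trunc_x = torch.trunc(x)  # toward zero

for name, value in [("floor", floor_x), ("ceil", ceil_x),
                    ("round", round_x), ("to zero", trunc_x)]:
    print(f"{name:>8}: {value.tolist()}")
```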
class brevitas.core.function_wrapper.ops_ste.ScalarClampMinSte(min_val)[source]#

Bases: Module

ScriptModule wrapper for scalar_clamp_min_ste().

forward(x)[source]#

training: bool#
class brevitas.core.function_wrapper.ops_ste.TensorClampSte[source]#

Bases: Module

ScriptModule wrapper for tensor_clamp_ste().

forward(x, min_val, max_val)[source]#

training: bool#
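A clamp with a straight-through gradient clips values in the forward pass but still lets the gradient reach `x` where clipping occurred, unlike `torch.clamp`, whose gradient is zero for clipped elements. The sketch below shows the contrast in plain `torch`; `tensor_clamp_ste_sketch` is a hypothetical helper based on the detach trick, not the Brevitas implementation.

```python
import torch


def tensor_clamp_ste_sketch(x, min_val, max_val):
    """Clamp to tensor bounds in the forward pass; identity gradient w.r.t. x."""
    clamped = torch.minimum(torch.maximum(x, min_val), max_val)
    # Detaching the correction term hides the clamp from autograd.
    return x + (clamped - x).detach()


min_val, max_val = torch.tensor(-2.0), torch.tensor(2.0)

x_ste = torch.tensor([-3.0, 0.5, 3.0], requires_grad=True)
y_ste = tensor_clamp_ste_sketch(x_ste, min_val, max_val)
y_ste.sum().backward()

x_ref = torch.tensor([-3.0, 0.5, 3.0], requires_grad=True)
y_ref = torch.clamp(x_ref, min=-2.0, max=2.0)
y_ref.sum().backward()

print(y_ste.detach(), x_ste.grad)  # clipped values, gradient of ones
print(y_ref.detach(), x_ref.grad)  # same values, zero gradient where clipped
```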

brevitas.core.function_wrapper.shape module#

ScriptModule classes that compute a view of a tensor according to various criteria.
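The output shape of each view depends only on the input shape. The pure-Python sketch below reproduces the shapes shown in the examples in this module, assuming a 4-D input and no dimension permutation; `view_shapes` is a hypothetical helper, not part of Brevitas.

```python
from math import prod


def view_shapes(shape):
    """Return the shape each view produces for an input of the given shape."""
    first, second, rest = shape[0], shape[1], shape[2:]
    return {
        # flatten everything into one dimension
        "over_tensor": (prod(shape),),
        # keep dim 0 (output channels), flatten the rest
        "over_output_channels": (first, prod(shape[1:])),
        # keep dim 0 (batch), flatten the rest
        "over_batch_over_tensor": (first, prod(shape[1:])),
        # keep dims 0 and 1 (batch, channels), flatten the rest
        "over_batch_over_output_channels": (first, second, prod(rest)),
    }


shapes = view_shapes((8, 10, 5, 5))
for name, out_shape in shapes.items():
    print(f"{name}: {out_shape}")
```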

class brevitas.core.function_wrapper.shape.OverBatchOverOutputChannelView[source]#

Bases: Module

ScriptModule to compute the over_batch_over_output_channels() view of an input tensor.

Examples

>>> view_module = OverBatchOverOutputChannelView()
>>> y = view_module(torch.empty(size=[8, 10, 5, 5]))
>>> y.shape
torch.Size([8, 10, 25])
forward(x)[source]#

training: bool#
class brevitas.core.function_wrapper.shape.OverBatchOverTensorView[source]#

Bases: Module

ScriptModule to compute the over_batch_over_tensor() view of an input tensor.

Examples

>>> view_module = OverBatchOverTensorView()
>>> y = view_module(torch.empty(size=[8, 10, 5, 5]))
>>> y.shape
torch.Size([8, 250])
forward(x)[source]#

training: bool#
class brevitas.core.function_wrapper.shape.OverOutputChannelView(permute_dims)[source]#

Bases: Module

ScriptModule to compute the over_output_channels() view of an input tensor.

Examples

>>> view_module = OverOutputChannelView(permute_dims=None)
>>> y = view_module(torch.empty(size=[16, 8, 5, 5]))
>>> y.shape
torch.Size([16, 200])
forward(x)[source]#

training: bool#
class brevitas.core.function_wrapper.shape.OverTensorView[source]#

Bases: Module

ScriptModule to compute the over_tensor() view of an input tensor.

Examples

>>> view_module = OverTensorView()
>>> y = view_module(torch.empty(size=[16, 6, 5, 5]))
>>> y.shape
torch.Size([2400])
forward(x)[source]#

training: bool#
class brevitas.core.function_wrapper.shape.PermuteDims(permute_dims)[source]#

Bases: Module

forward(x)[source]#

training: bool#
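PermuteDims carries no description here. Judging by its name and its `permute_dims` argument, it presumably reorders the dimensions of its input in the manner of `torch.Tensor.permute`; that reading is an assumption. Under it, the effect on a tensor's shape would be as sketched below in plain `torch`.

```python
import torch

x = torch.empty(size=[8, 10, 5, 5])
permute_dims = (1, 0, 2, 3)  # hypothetical choice: swap the first two dims

# Tensor.permute returns a view whose dimensions follow the given order.
y = x.permute(*permute_dims)

print(y.shape)
```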
class brevitas.core.function_wrapper.shape.StatsInputViewShapeImpl[source]#

Bases: object

Enum-like object to collect pointers to variants of ScriptModules that perform a view on a tensor. All adhere to the same interface.

OVER_BATCH_OVER_OUTPUT_CHANNELS#

alias of OverBatchOverOutputChannelView

OVER_BATCH_OVER_TENSOR#

alias of OverBatchOverTensorView

OVER_OUTPUT_CHANNELS#

alias of OverOutputChannelView

OVER_TENSOR#

alias of OverTensorView
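Because each attribute is an alias of a class rather than an instance, a view is selected by attribute and then instantiated. The sketch below mimics this pattern with stand-in classes in plain `torch`; `ViewImplSketch` and both `*Sketch` modules are hypothetical and simplified, not the Brevitas classes.

```python
import torch
from torch import nn


class OverTensorViewSketch(nn.Module):
    def forward(self, x):
        return x.reshape(-1)  # flatten to 1-D


class OverBatchOverTensorViewSketch(nn.Module):
    def forward(self, x):
        return x.reshape(x.shape[0], -1)  # keep dim 0, flatten the rest


class ViewImplSketch:
    """Enum-like holder: attributes point to classes, not instances."""
    OVER_TENSOR = OverTensorViewSketch
    OVER_BATCH_OVER_TENSOR = OverBatchOverTensorViewSketch


# Pick the implementation by name, then instantiate it.
view_cls = ViewImplSketch.OVER_BATCH_OVER_TENSOR
view_module = view_cls()
y = view_module(torch.empty(size=[8, 10, 5, 5]))

print(y.shape)
```

Storing classes instead of instances lets the caller decide when, and with which arguments, the chosen view module is constructed.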

Module contents#