brevitas.core.function_wrapper package
Submodules
brevitas.core.function_wrapper.clamp module
ScriptModule wrappers for various variants of clamping.
- class brevitas.core.function_wrapper.clamp.ClampMin(min_val)
Bases: Module
ScriptModule wrapper for clamp_min().
Examples
>>> import torch
>>> from brevitas.core.function_wrapper.clamp import ClampMin
>>> clamp_min = ClampMin(min_val=-2.0)
>>> clamp_min(torch.tensor(-3.0))
tensor(-2.)
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
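ClampMin computes the same result as calling torch.clamp_min directly. The plain-PyTorch sketch below (no Brevitas import; input values are illustrative) shows one consequence worth keeping in mind: a regular clamp passes no gradient to the elements it saturates. Assuming the usual straight-through definition, this is what distinguishes it from ScalarClampMinSte in ops_ste.

```python
import torch

# One input below the lower bound, one above it (illustrative values).
x = torch.tensor([-3.0, 1.0], requires_grad=True)

# The computation ClampMin(min_val=-2.0) wraps: elementwise max(x, -2.0).
y = torch.clamp_min(x, -2.0)
y.sum().backward()

# y == [-2., 1.]
# x.grad == [0., 1.]: the clamped element receives no gradient.
```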
- class brevitas.core.function_wrapper.clamp.ScalarClamp(min_val, max_val)
Bases: Module
ScriptModule wrapper for clamp().
Examples
>>> import torch
>>> from brevitas.core.function_wrapper.clamp import ScalarClamp
>>> scalar_clamp = ScalarClamp(min_val=-2.0, max_val=2.0)
>>> scalar_clamp(torch.tensor([-3.0, 3.0]))
tensor([-2., 2.])
- forward(x)
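Scalar bounds are typically fixed constants such as the limits of an integer grid. As an illustration, the same torch.clamp call that ScalarClamp wraps can saturate rounded values to the signed 8-bit range. The int8 limits and the explicit rounding step are assumptions for this sketch and are not part of ScalarClamp itself.

```python
import torch

# Illustrative values, some outside the signed 8-bit range.
x = torch.tensor([-300.0, -1.6, 5.4, 200.0])

# Round to the integer grid, then saturate with fixed scalar bounds, i.e. what
# ScalarClamp(min_val=-128.0, max_val=127.0) would compute on the rounded tensor.
q = torch.clamp(torch.round(x), min=-128.0, max=127.0)
# q == [-128., -2., 5., 127.]
```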
- class brevitas.core.function_wrapper.clamp.TensorClamp
Bases: Module
ScriptModule wrapper for tensor_clamp().
Examples
>>> import torch
>>> from brevitas.core.function_wrapper.clamp import TensorClamp
>>> tensor_clamp = TensorClamp()
>>> min_val = torch.tensor(-2.0)
>>> max_val = torch.tensor(2.0)
>>> tensor_clamp(torch.tensor([-3.0, 3.0]), min_val, max_val)
tensor([-2., 2.])
- forward(x, min_val, max_val)
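Unlike ScalarClamp, TensorClamp receives its bounds as tensors at call time, so they can differ per element or per channel. The function below, tensor_clamp_sketch, is a hypothetical plain-PyTorch re-implementation built on torch.where. It assumes the bounds broadcast against x and is not the Brevitas source.

```python
import torch

def tensor_clamp_sketch(x: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: clamp x elementwise to broadcastable tensor bounds."""
    out = torch.where(x > max_val, max_val, x)      # saturate from above
    out = torch.where(out < min_val, min_val, out)  # saturate from below
    return out

# Two channels (rows) with different ranges; bounds have shape [2, 1] and broadcast over columns.
x = torch.tensor([[-3.0, 0.0, 3.0],
                  [-3.0, 0.0, 3.0]])
min_val = torch.tensor([[-1.0], [-2.0]])
max_val = torch.tensor([[1.0], [2.0]])

y = tensor_clamp_sketch(x, min_val, max_val)
# y == [[-1., 0., 1.], [-2., 0., 2.]]
```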
brevitas.core.function_wrapper.misc module
A collection of miscellaneous ScriptModules used in various quantizers.
- class brevitas.core.function_wrapper.misc.Identity
Bases: Module
Identity ScriptModule.
Examples
>>> import torch
>>> from brevitas.core.function_wrapper.misc import Identity
>>> identity = Identity()
>>> x = torch.randn(size=[10,])
>>> y = identity(x)
>>> y is x
True
- forward(x)
- class brevitas.core.function_wrapper.misc.InplaceLogTwo
Bases: Module
Module wrapper for log2_().
Examples
>>> import torch
>>> from brevitas.core.function_wrapper.misc import InplaceLogTwo
>>> inplace_log_two = InplaceLogTwo()
>>> x = torch.tensor(8.0)
>>> inplace_log_two(x)
>>> x
tensor(3.)
Notes
In-place operations can be problematic in TorchScript, so compilation is disabled for this module.
- forward(x)
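The practical difference between LogTwo and InplaceLogTwo is whether the input tensor is overwritten. The plain-PyTorch sketch below shows torch.log2 leaving its argument untouched while Tensor.log2_ mutates it in place, which is why the InplaceLogTwo example inspects x after the call.

```python
import torch

x = torch.tensor(8.0)

# Out-of-place (what LogTwo wraps): returns a new tensor, x keeps its value.
y = torch.log2(x)
# y == 3.0, x == 8.0

# In-place (what InplaceLogTwo wraps): overwrites the storage of x.
x.log2_()
# x == 3.0; y is a separate tensor and still holds 3.0
```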
- class brevitas.core.function_wrapper.misc.LogTwo
Bases: Module
ScriptModule wrapper for log2().
Examples
>>> import torch
>>> from brevitas.core.function_wrapper.misc import LogTwo
>>> log_two = LogTwo()
>>> x = torch.tensor(8.0)
>>> log_two(x)
tensor(3.)
- forward(x)
- class brevitas.core.function_wrapper.misc.PowerOfTwo
Bases: Module
ScriptModule implementation of 2.0 ** x.
Examples
>>> import torch
>>> from brevitas.core.function_wrapper.misc import PowerOfTwo
>>> power_of_two = PowerOfTwo()
>>> x = torch.tensor(5.0)
>>> power_of_two(x)
tensor(32.)
- forward(x)
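LogTwo and PowerOfTwo are inverses, which makes them convenient building blocks when a scale factor must be restricted to a power of two. A minimal plain-PyTorch sketch of that pattern, assuming the exponent is rounded to the nearest integer in the log domain. The rounding choice and the scale values are illustrative assumptions.

```python
import torch

# Illustrative floating-point scale factors.
scale = torch.tensor([0.1, 0.3, 5.0])

# log2 -> round the exponent -> 2 ** exponent, i.e. LogTwo followed by PowerOfTwo
# with a rounding step in between.
exponent = torch.round(torch.log2(scale))
po2_scale = 2.0 ** exponent
# exponent == [-3., -2., 2.], po2_scale == [0.125, 0.25, 4.]
```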
brevitas.core.function_wrapper.ops_ste module
ScriptModule wrappers of various functions defined in ops_ste.
- class brevitas.core.function_wrapper.ops_ste.CeilSte
Bases: Module
ScriptModule wrapper for ceil_ste().
- forward(x)
- class brevitas.core.function_wrapper.ops_ste.DPURoundSte
Bases: Module
ScriptModule wrapper for dpu_round_ste().
- forward(x)
- class brevitas.core.function_wrapper.ops_ste.FloorSte
Bases: Module
ScriptModule wrapper for floor_ste().
- forward(x)
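The wrappers in this module share the straight-through estimator (STE) idea: the forward pass applies a non-differentiable operation, while the backward pass treats it as the identity. Brevitas implements this in its ops_ste functions. The sketch below instead uses the common detach trick in plain PyTorch, with a hypothetical helper named ste, to illustrate the intended forward and backward behaviour of CeilSte and FloorSte.

```python
import torch
from typing import Callable

def ste(fn: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: forward returns fn(x), backward passes the gradient through unchanged."""
    return x + (fn(x) - x).detach()

x = torch.tensor([-1.5, 0.25, 2.75], requires_grad=True)

y_ceil = ste(torch.ceil, x)    # forward values: [-1., 1., 3.]
y_floor = ste(torch.floor, x)  # forward values: [-2., 0., 2.]

(y_ceil.sum() + y_floor.sum()).backward()
# Each output contributes a gradient of 1 per element, so x.grad == [2., 2., 2.].
# Plain torch.ceil / torch.floor would instead give zero gradients everywhere.
```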
- class brevitas.core.function_wrapper.ops_ste.InplaceTensorClampSte
Bases: Module
ScriptModule wrapper for tensor_clamp_ste_().
- forward(x, min_val, max_val)
- class brevitas.core.function_wrapper.ops_ste.RoundSte
Bases: Module
ScriptModule wrapper for round_ste().
- forward(x)
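round_ste performs nearest-integer rounding in the forward pass. Assuming it builds on torch.round (an assumption worth checking against the ops_ste source), ties are resolved to the nearest even integer rather than away from zero. The plain-PyTorch sketch below shows that tie behaviour together with the identity gradient an STE provides.

```python
import torch

x = torch.tensor([0.5, 1.5, 2.5, 3.25], requires_grad=True)

# Forward: torch.round rounds half to even, so 0.5 -> 0, 1.5 -> 2, 2.5 -> 2, 3.25 -> 3.
# Backward: the detach trick hides the rounding step from autograd.
y = x + (torch.round(x) - x).detach()
y.sum().backward()
# y == [0., 2., 2., 3.], x.grad == [1., 1., 1., 1.]
```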
- class brevitas.core.function_wrapper.ops_ste.RoundToZeroSte
Bases: Module
ScriptModule wrapper for round_to_zero_ste().
- forward(x)
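Rounding to zero truncates the fractional part: positive values move down and negative values move up, which differs from FloorSte for negative inputs. A minimal plain-PyTorch sketch, assuming the forward value is sign(x) * floor(|x|) (the same values as torch.trunc), with the detach trick supplying the straight-through gradient.

```python
import torch

x = torch.tensor([-2.75, -0.5, 0.5, 2.75], requires_grad=True)

# Truncate toward zero: sign(x) * floor(|x|), numerically equal to torch.trunc(x).
truncated = torch.sign(x) * torch.floor(torch.abs(x))

# Straight-through estimator: forward uses the truncated values, backward is the identity.
y = x + (truncated - x).detach()
y.sum().backward()
# y == [-2., 0., 0., 2.], whereas torch.floor(x) would give [-3., -1., 0., 2.]
```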
- class brevitas.core.function_wrapper.ops_ste.ScalarClampMinSte(min_val)
Bases: Module
ScriptModule wrapper for scalar_clamp_min_ste().
- forward(x)
- class brevitas.core.function_wrapper.ops_ste.TensorClampSte
Bases: Module
ScriptModule wrapper for tensor_clamp_ste().
- forward(x, min_val, max_val)
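The clamp variants in this module follow the same pattern. Under the usual straight-through definition (assumed here; the actual functions live in ops_ste), elements outside [min_val, max_val] are saturated in the forward pass but still receive the incoming gradient, whereas a regular clamp such as the one TensorClamp wraps blocks it. A plain-PyTorch comparison with per-channel tensor bounds:

```python
import torch

min_val = torch.tensor([[-1.0], [-2.0]])  # one lower bound per row/channel
max_val = torch.tensor([[1.0], [2.0]])    # one upper bound per row/channel

# Regular clamp: saturated elements receive zero gradient.
x_ref = torch.tensor([[-3.0, 0.5, 3.0], [-3.0, 0.5, 3.0]], requires_grad=True)
torch.maximum(torch.minimum(x_ref, max_val), min_val).sum().backward()
# x_ref.grad == [[0., 1., 0.], [0., 1., 0.]]

# Straight-through clamp: same forward values, identity backward.
x_ste = torch.tensor([[-3.0, 0.5, 3.0], [-3.0, 0.5, 3.0]], requires_grad=True)
clamped = torch.maximum(torch.minimum(x_ste, max_val), min_val)
y = x_ste + (clamped - x_ste).detach()
y.sum().backward()
# y == [[-1., 0.5, 1.], [-2., 0.5, 2.]], x_ste.grad is all ones
```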
brevitas.core.function_wrapper.shape module
ScriptModule classes to compute the view of a tensor according to various criteria.
- class brevitas.core.function_wrapper.shape.OverBatchOverOutputChannelView
Bases: Module
ScriptModule to compute the over_batch_over_output_channels() view of an input tensor.
Examples
>>> import torch
>>> from brevitas.core.function_wrapper.shape import OverBatchOverOutputChannelView
>>> view_module = OverBatchOverOutputChannelView()
>>> y = view_module(torch.empty(size=[8, 10, 5, 5]))
>>> y.shape
torch.Size([8, 10, 25])
- forward(x)
- class brevitas.core.function_wrapper.shape.OverBatchOverTensorView
Bases: Module
ScriptModule to compute the over_batch_over_tensor() view of an input tensor.
Examples
>>> import torch
>>> from brevitas.core.function_wrapper.shape import OverBatchOverTensorView
>>> view_module = OverBatchOverTensorView()
>>> y = view_module(torch.empty(size=[8, 10, 5, 5]))
>>> y.shape
torch.Size([8, 250])
- forward(x)
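The over-batch views keep the batch dimension (and, for OverBatchOverOutputChannelView, the channel dimension) and flatten everything else, so reducing the last dimension yields one statistic per sample, or per sample and channel. The plain-PyTorch sketch below reproduces the example shapes with reshape and reduces them with amax. Using the absolute maximum as the statistic is an illustrative assumption.

```python
import torch

x = torch.randn(8, 10, 5, 5)  # illustrative activation batch: N, C, H, W

# Same shapes as OverBatchOverTensorView and OverBatchOverOutputChannelView.
per_sample = x.reshape(x.shape[0], -1)                       # [8, 250]
per_sample_channel = x.reshape(x.shape[0], x.shape[1], -1)   # [8, 10, 25]

# Reduce the flattened dimension to get one statistic per slice of the view.
max_per_sample = per_sample.abs().amax(dim=-1)                  # [8]
max_per_sample_channel = per_sample_channel.abs().amax(dim=-1)  # [8, 10]
```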
- class brevitas.core.function_wrapper.shape.OverOutputChannelView(permute_dims)
Bases: Module
ScriptModule to compute the over_output_channels() view of an input tensor.
Examples
>>> import torch
>>> from brevitas.core.function_wrapper.shape import OverOutputChannelView
>>> view_module = OverOutputChannelView(permute_dims=None)
>>> y = view_module(torch.empty(size=[16, 8, 5, 5]))
>>> y.shape
torch.Size([16, 200])
- forward(x)
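For weights, the over-output-channel view flattens every dimension except the first, which is where torch.nn.Conv2d stores output channels. A typical use of such a view is a per-channel statistic. The sketch below computes a hypothetical per-channel absolute-maximum scale for an 8-bit grid in plain PyTorch; the formula (abs max divided by 127) is an illustrative assumption, not the Brevitas scaling rule.

```python
import torch

weight = torch.randn(16, 8, 5, 5)  # Conv2d weight layout: out_channels, in_channels, kH, kW

# Same shape as OverOutputChannelView(permute_dims=None): [16, 200].
flat = weight.reshape(weight.shape[0], -1)

# Hypothetical per-channel scale: the absolute maximum mapped to the int8 limit 127.
scale = flat.abs().amax(dim=1) / 127.0  # [16]

# Reshape the scale so it broadcasts back against the original weight.
normalized = weight / scale.view(-1, 1, 1, 1)
```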
- class brevitas.core.function_wrapper.shape.OverTensorView
Bases: Module
ScriptModule to compute the over_tensor() view of an input tensor.
Examples
>>> import torch
>>> from brevitas.core.function_wrapper.shape import OverTensorView
>>> view_module = OverTensorView()
>>> y = view_module(torch.empty(size=[16, 6, 5, 5]))
>>> y.shape
torch.Size([2400])
- forward(x)
- class brevitas.core.function_wrapper.shape.PermuteDims(permute_dims)
Bases: Module
- forward(x)
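PermuteDims has no description here. Judging from its permute_dims argument, which OverOutputChannelView also accepts, it presumably reorders dimensions before a view is taken. One layout where that matters is torch.nn.ConvTranspose2d, whose weight stores output channels in dimension 1. The plain-PyTorch sketch below assumes a (1, 0, 2, 3) ordering for that layout and also shows why reshape, rather than view, is needed after a permute.

```python
import torch

# ConvTranspose2d weight layout: in_channels, out_channels, kH, kW -> [8, 16, 3, 3].
weight = torch.nn.ConvTranspose2d(in_channels=8, out_channels=16, kernel_size=3).weight.detach()

# Assumed ordering: move output channels to the front.
permuted = weight.permute(1, 0, 2, 3)  # [16, 8, 3, 3], no longer contiguous

# Per-output-channel flattening; .view would raise here because of the permuted strides.
flat = permuted.reshape(permuted.shape[0], -1)  # [16, 72]
```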
- class brevitas.core.function_wrapper.shape.StatsInputViewShapeImpl
Bases: object
Enum-like object to collect pointers to variants of ScriptModules that perform a view on a tensor. All adhere to the same interface.
- OVER_BATCH_OVER_OUTPUT_CHANNELS
alias of OverBatchOverOutputChannelView
- OVER_BATCH_OVER_TENSOR
alias of OverBatchOverTensorView
- OVER_OUTPUT_CHANNELS
alias of OverOutputChannelView
- OVER_TENSOR
alias of OverTensorView
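Because every entry of StatsInputViewShapeImpl shares the same call interface, a caller can pick a view by name without knowing which variant it received. The sketch below mimics that enum-like pattern with a hypothetical ViewShapes class whose attributes are plain functions (illustrative stand-ins, not Brevitas code) and reports the resulting shapes for one input.

```python
import torch

class ViewShapes:
    """Hypothetical enum-like holder mirroring StatsInputViewShapeImpl: each attribute is a view."""
    OVER_TENSOR = staticmethod(lambda x: x.reshape(-1))
    OVER_OUTPUT_CHANNELS = staticmethod(lambda x: x.reshape(x.shape[0], -1))
    OVER_BATCH_OVER_TENSOR = staticmethod(lambda x: x.reshape(x.shape[0], -1))
    OVER_BATCH_OVER_OUTPUT_CHANNELS = staticmethod(lambda x: x.reshape(x.shape[0], x.shape[1], -1))

x = torch.empty(8, 10, 5, 5)
names = ("OVER_TENSOR", "OVER_OUTPUT_CHANNELS", "OVER_BATCH_OVER_TENSOR", "OVER_BATCH_OVER_OUTPUT_CHANNELS")

# Same call for every variant; only the resulting shape differs.
shapes = {name: tuple(getattr(ViewShapes, name)(x).shape) for name in names}
# shapes == {"OVER_TENSOR": (2000,), "OVER_OUTPUT_CHANNELS": (8, 250),
#            "OVER_BATCH_OVER_TENSOR": (8, 250), "OVER_BATCH_OVER_OUTPUT_CHANNELS": (8, 10, 25)}
```

In Brevitas the attributes are the module classes themselves, which are instantiated before use; plain functions keep the sketch short. OVER_OUTPUT_CHANNELS and OVER_BATCH_OVER_TENSOR produce the same shape here: they differ in whether the first dimension is read as output channels (weights) or as the batch (activations).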