# Source code for brevitas.core.function_wrapper.misc
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
"""
A collection of miscellaneous ScriptModules used in various quantizers.
"""
import torch
import brevitas
class Identity(brevitas.jit.ScriptModule):
    """
    Identity ScriptModule: ``forward`` hands its input back unchanged.

    Examples:
        >>> identity = Identity()
        >>> x = torch.randn(size=[10,])
        >>> y = identity(x)
        >>> y is x
        True
    """

    def __init__(self) -> None:
        super(Identity, self).__init__()

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return the input tensor as-is.

        Args:
            x: Input tensor.

        Returns:
            The same tensor ``x``; no copy or computation is performed here.
        """
        return x
class PowerOfTwo(brevitas.jit.ScriptModule):
    """
    ScriptModule implementation of ``2.0 ** x``.

    Examples:
        >>> power_of_two = PowerOfTwo()
        >>> x = torch.tensor(5.0)
        >>> power_of_two(x)
        tensor(32.)
    """

    def __init__(self) -> None:
        super(PowerOfTwo, self).__init__()

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Raise 2.0 to the power of the input.

        Args:
            x: Tensor of exponents.

        Returns:
            A tensor holding ``2.0 ** x``.
        """
        return 2.0 ** x
class LogTwo(brevitas.jit.ScriptModule):
    """
    ScriptModule wrapper for :func:`~torch.log2`.

    Examples:
        >>> log_two = LogTwo()
        >>> x = torch.tensor(8.0)
        >>> log_two(x)
        tensor(3.)
    """

    def __init__(self) -> None:
        super(LogTwo, self).__init__()

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the base-2 logarithm of the input.

        Args:
            x: Input tensor.

        Returns:
            The result of ``torch.log2(x)``; ``x`` itself is not modified.
        """
        return torch.log2(x)
class InplaceLogTwo(torch.nn.Module):
    """
    Module wrapper for :func:`~torch.log2_`.

    Examples:
        >>> inplace_log_two = InplaceLogTwo()
        >>> x = torch.tensor(8.0)
        >>> y = inplace_log_two(x)
        >>> y is x
        True
        >>> x
        tensor(3.)

    Notes:
        Inplace operations in TorchScript can be problematic, compilation is disabled.
    """

    def __init__(self) -> None:
        super(InplaceLogTwo, self).__init__()

    @torch.jit.ignore
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overwrite the input with its base-2 logarithm.

        Args:
            x: Tensor to transform; its contents are modified in place.

        Returns:
            The same tensor object ``x``, now holding ``log2`` of its
            previous values.
        """
        # In-place op: mutates the caller's tensor instead of allocating a new one.
        x.log2_()
        return x