# Source code for brevitas.core.function_wrapper.clamp
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
"""
ScriptModule wrappers for various variants of clamping.
"""
import torch
from torch import Tensor
import brevitas
from brevitas.function import tensor_clamp
class TensorClamp(brevitas.jit.ScriptModule):
    """
    ScriptModule that clamps a tensor between two tensor-valued bounds by
    delegating to :func:`~brevitas.function.ops.tensor_clamp`.

    The module holds no state of its own: both bounds are passed in on every
    call to :meth:`forward`.

    Examples:
        >>> tensor_clamp = TensorClamp()
        >>> min_val = torch.tensor(-2.0)
        >>> max_val = torch.tensor(2.0)
        >>> tensor_clamp(torch.tensor([-3.0, 3.0]), min_val, max_val)
        tensor([-2.,  2.])
    """

    def __init__(self) -> None:
        super().__init__()

    @brevitas.jit.script_method
    def forward(self, x: Tensor, min_val: Tensor, max_val: Tensor):
        """
        Clamp ``x`` between ``min_val`` and ``max_val``.

        Args:
            x: Tensor to clamp.
            min_val: Tensor holding the lower bound.
            max_val: Tensor holding the upper bound.

        Returns:
            Whatever ``tensor_clamp`` returns for the given arguments.
        """
        clamped = tensor_clamp(x, min_val=min_val, max_val=max_val)
        return clamped
class ScalarClamp(brevitas.jit.ScriptModule):
    """
    ScriptModule wrapper for :func:`~torch.clamp` with fixed scalar bounds.

    Both bounds are set once at construction time and are listed in
    ``__constants__``.

    Args:
        min_val: Lower bound passed to :func:`~torch.clamp` as ``min``.
        max_val: Upper bound passed to :func:`~torch.clamp` as ``max``.

    Examples:
        >>> scalar_clamp = ScalarClamp(min_val=-2.0, max_val=2.0)
        >>> scalar_clamp(torch.tensor([-3.0, 3.0]))
        tensor([-2.,  2.])
    """
    __constants__ = ['min_val', 'max_val']

    # Annotated as float to match the sibling ClampMin constructor and the
    # documented usage; the annotations are hints only and are not enforced.
    def __init__(self, min_val: float, max_val: float) -> None:
        super(ScalarClamp, self).__init__()
        self.min_val = min_val
        self.max_val = max_val

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        """
        Clamp ``x`` to the bounds stored on the module.

        Args:
            x: Tensor to clamp.

        Returns:
            The result of ``torch.clamp(x, min=self.min_val, max=self.max_val)``.
        """
        return torch.clamp(x, min=self.min_val, max=self.max_val)
class ClampMin(brevitas.jit.ScriptModule):
    """
    ScriptModule that applies a fixed scalar lower bound to its input, wrapping
    :func:`~torch.clamp_min`.

    The bound is set at construction time and is listed in ``__constants__``.

    Examples:
        >>> clamp_min = ClampMin(min_val=-2.0)
        >>> clamp_min(torch.tensor(-3.0))
        tensor(-2.)
    """
    __constants__ = ['min_val']

    def __init__(self, min_val: float) -> None:
        super().__init__()
        self.min_val = min_val

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        """
        Return ``x`` clamped from below at ``self.min_val``.

        Args:
            x: Tensor to clamp.

        Returns:
            The result of ``x.clamp_min(self.min_val)``.
        """
        lower_bounded = x.clamp_min(self.min_val)
        return lower_bounded