# Source code for brevitas.core.function_wrapper.clamp

# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

"""
ScriptModule wrappers for various variants of clamping.
"""

import torch
from torch import Tensor

import brevitas
from brevitas.function import tensor_clamp


class TensorClamp(brevitas.jit.ScriptModule):
    """
    ScriptModule wrapper for :func:`~brevitas.function.ops.tensor_clamp`.

    Unlike :class:`ScalarClamp`, the bounds are not stored on the module: they
    are supplied as tensors on every call to ``forward``.

    Examples:
        >>> clamp = TensorClamp()
        >>> lower = torch.tensor(-2.0)
        >>> upper = torch.tensor(2.0)
        >>> clamp(torch.tensor([-3.0, 3.0]), lower, upper)
        tensor([-2.,  2.])
    """

    def __init__(self) -> None:
        super().__init__()

    @brevitas.jit.script_method
    def forward(self, x: Tensor, min_val: Tensor, max_val: Tensor) -> Tensor:
        """Clamp ``x`` between ``min_val`` and ``max_val`` via ``tensor_clamp``."""
        clamped = tensor_clamp(x, min_val=min_val, max_val=max_val)
        return clamped
class ScalarClamp(brevitas.jit.ScriptModule):
    """
    ScriptModule wrapper for :func:`~torch.clamp`.

    The lower and upper bounds are fixed when the module is constructed, so
    ``forward`` only takes the tensor to clamp.

    Examples:
        >>> clamp = ScalarClamp(min_val=-2.0, max_val=2.0)
        >>> clamp(torch.tensor([-3.0, 3.0]))
        tensor([-2.,  2.])
    """
    # Listed in __constants__ so TorchScript treats the bounds as constants
    # when the module is scripted.
    __constants__ = ['min_val', 'max_val']

    def __init__(self, min_val, max_val) -> None:
        super().__init__()
        self.max_val = max_val
        self.min_val = min_val

    @brevitas.jit.script_method
    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` clamped to the stored [``min_val``, ``max_val``] range."""
        # Method form of torch.clamp; dispatches to the same operator.
        return x.clamp(min=self.min_val, max=self.max_val)
class ClampMin(brevitas.jit.ScriptModule):
    """
    ScriptModule wrapper for :func:`~torch.clamp_min`.

    Only a lower bound is applied; it is fixed when the module is constructed.

    Examples:
        >>> floor = ClampMin(min_val=-2.0)
        >>> floor(torch.tensor(-3.0))
        tensor(-2.)
    """
    # Listed in __constants__ so TorchScript treats the bound as a constant
    # when the module is scripted.
    __constants__ = ['min_val']

    def __init__(self, min_val: float) -> None:
        super().__init__()
        self.min_val = min_val

    @brevitas.jit.script_method
    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` with every value below ``min_val`` raised to ``min_val``."""
        # Function form of Tensor.clamp_min; dispatches to the same operator.
        return torch.clamp_min(x, self.min_val)