# Source code for brevitas.core.function_wrapper.ops_ste

# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

"""
ScriptModule wrappers of various functions defined in :obj:`~brevitas.function.ops_ste`.
"""

import torch

import brevitas
from brevitas.function.ops_ste import *


class RoundSte(brevitas.jit.ScriptModule):
    """
    Stateless ScriptModule that forwards its input to
    :func:`~brevitas.function.ops_ste.round_ste`.
    """

    def __init__(self) -> None:
        super().__init__()

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``round_ste(x)``."""
        return round_ste(x)
class FloorSte(brevitas.jit.ScriptModule):
    """
    Stateless ScriptModule that forwards its input to
    :func:`~brevitas.function.ops_ste.floor_ste`.
    """

    def __init__(self) -> None:
        super().__init__()

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``floor_ste(x)``."""
        return floor_ste(x)
class RoundToZeroSte(brevitas.jit.ScriptModule):
    """
    Stateless ScriptModule that forwards its input to
    :func:`~brevitas.function.ops_ste.round_to_zero_ste`.
    """

    def __init__(self) -> None:
        super().__init__()

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``round_to_zero_ste(x)``."""
        return round_to_zero_ste(x)
class DPURoundSte(brevitas.jit.ScriptModule):
    """
    Stateless ScriptModule that forwards its input to
    :func:`~brevitas.function.ops_ste.dpu_round_ste`.
    """

    def __init__(self) -> None:
        super().__init__()

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dpu_round_ste(x)``."""
        return dpu_round_ste(x)
class CeilSte(brevitas.jit.ScriptModule):
    """
    Stateless ScriptModule that forwards its input to
    :func:`~brevitas.function.ops_ste.ceil_ste`.
    """

    def __init__(self) -> None:
        super().__init__()

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``ceil_ste(x)``."""
        return ceil_ste(x)
class ScalarClampMinSte(brevitas.jit.ScriptModule):
    """
    ScriptModule that forwards its input, together with a fixed scalar lower
    bound, to :func:`~brevitas.function.ops_ste.scalar_clamp_min_ste`.

    Args:
        min_val: lower bound captured at construction time.
    """

    # Listing min_val here lets TorchScript treat it as a module constant.
    __constants__ = ['min_val']

    def __init__(self, min_val: float) -> None:
        super().__init__()
        self.min_val = min_val

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``scalar_clamp_min_ste(x, self.min_val)``."""
        return scalar_clamp_min_ste(x, self.min_val)
class TensorClampSte(brevitas.jit.ScriptModule):
    """
    Stateless ScriptModule that forwards its input and per-call tensor bounds
    to :func:`~brevitas.function.ops_ste.tensor_clamp_ste`.
    """

    def __init__(self) -> None:
        super().__init__()

    @brevitas.jit.script_method
    def forward(
            self, x: torch.Tensor, min_val: torch.Tensor,
            max_val: torch.Tensor) -> torch.Tensor:
        """Return ``tensor_clamp_ste(x, min_val, max_val)``."""
        return tensor_clamp_ste(x, min_val, max_val)
class InplaceTensorClampSte(brevitas.jit.ScriptModule):
    """
    Stateless ScriptModule that forwards its input and per-call tensor bounds
    to :func:`~brevitas.function.ops_ste.tensor_clamp_ste_`.

    The trailing underscore follows PyTorch's in-place naming convention, so
    ``x`` is presumably modified in place. Confirm against the wrapped
    function.
    """

    def __init__(self) -> None:
        super().__init__()

    @brevitas.jit.script_method
    def forward(
            self, x: torch.Tensor, min_val: torch.Tensor,
            max_val: torch.Tensor) -> torch.Tensor:
        """Return ``tensor_clamp_ste_(x, min_val, max_val)``."""
        return tensor_clamp_ste_(x, min_val, max_val)