# Source code for brevitas.core.scaling.int_scaling

# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from torch import Tensor

import brevitas
from brevitas.function.ops import max_int
from brevitas.function.ops import min_int


class IntScaling(brevitas.jit.ScriptModule):
    """Script module that turns a bit width into an integer-range bound.

    For a signed configuration the result is the negation of
    ``min_int(signed, narrow_range, bit_width)``; for an unsigned one it is
    ``max_int(signed, narrow_range, bit_width)``.

    Args:
        signed: Selects which of the two bounds is returned and is forwarded
            to ``min_int`` / ``max_int``.
        narrow_range: Forwarded unchanged to ``min_int`` / ``max_int``.
    """

    # Names listed here are exposed as constants when the module is scripted.
    __constants__ = ['signed', 'narrow_range']

    def __init__(self, signed: bool, narrow_range: bool):
        super(IntScaling, self).__init__()
        self.signed = signed
        self.narrow_range = narrow_range

    @brevitas.jit.script_method
    def forward(self, bit_width: Tensor) -> Tensor:
        """Return the integer bound associated with ``bit_width``.

        Args:
            bit_width: Tensor passed straight through to ``min_int`` /
                ``max_int``.

        Returns:
            ``max_int(...)`` when the module is unsigned, otherwise the
            negated ``min_int(...)``.
        """
        # Unsigned case first: the upper bound is used as-is.
        if not self.signed:
            return max_int(self.signed, self.narrow_range, bit_width)
        # Signed case: flip the sign of the lower bound.
        # NOTE(review): presumably min_int is negative here, making the result
        # a positive magnitude -- confirm against brevitas.function.ops.
        lower_bound = min_int(self.signed, self.narrow_range, bit_width)
        return -lower_bound
class PowerOfTwoIntScaling(brevitas.jit.ScriptModule):
    """Script module returning ``max_int(signed, False, bit_width) + 1``.

    ``narrow_range`` is always passed to ``max_int`` as ``False``.

    Args:
        signed: Forwarded to ``max_int`` as its signedness flag.
    """

    # Names listed here are exposed as constants when the module is scripted.
    __constants__ = ['signed']

    def __init__(self, signed: bool):
        super(PowerOfTwoIntScaling, self).__init__()
        self.signed = signed

    @brevitas.jit.script_method
    def forward(self, bit_width: Tensor) -> Tensor:
        """Return one more than the non-narrow-range maximum integer.

        Args:
            bit_width: Tensor passed straight through to ``max_int``.

        Returns:
            ``max_int(self.signed, False, bit_width) + 1``.
        """
        # NOTE(review): the class name suggests the +1 rounds the maximum up
        # to a power of two -- confirm against max_int's definition.
        upper_bound = max_int(self.signed, False, bit_width)
        return upper_bound + 1