# Source code for brevitas.core.restrict_val

# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import math
from typing import Callable, Optional, Union

import torch
from torch import Tensor
from torch.nn import Module

import brevitas
from brevitas.core.function_wrapper import Identity
from brevitas.core.function_wrapper import InplaceLogTwo
from brevitas.core.function_wrapper import LogTwo
from brevitas.core.function_wrapper import PowerOfTwo
from brevitas.core.function_wrapper import RoundSte
from brevitas.core.function_wrapper import ScalarClampMinSte
from brevitas.inject.enum import FloatToIntImplType  # retrocompatibility
from brevitas.inject.enum import RestrictValueType

assert RestrictValueType  # prevent removal of unused import
assert FloatToIntImplType


class _RestrictClampValue(brevitas.jit.ScriptModule):
    """Apply a value restriction, then clamp the result from below.

    Args:
        scaling_min_val: lower bound passed to ``ScalarClampMinSte``. When it is
            ``None`` or ``0``, no clamp module is built and ``Identity`` is used.
        restrict_value_impl: module applied first in ``forward``. ``None`` means
            ``Identity``, i.e. no restriction.
    """

    def __init__(self, scaling_min_val: Optional[float], restrict_value_impl: Optional[Module]):
        super(_RestrictClampValue, self).__init__()
        # clamp_min_ste is registered before restrict_value_impl. Submodule
        # registration order is visible through named_modules()/state_dict().
        clamp_enabled = scaling_min_val is not None and scaling_min_val != 0
        self.clamp_min_ste = ScalarClampMinSte(scaling_min_val) if clamp_enabled else Identity()
        self.restrict_value_impl = (
            Identity() if restrict_value_impl is None else restrict_value_impl)

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor):
        # Restriction first, clamp second.
        restricted = self.restrict_value_impl(x)
        return self.clamp_min_ste(restricted)


class _RestrictValue(brevitas.jit.ScriptModule):
    """Thin wrapper that applies an optional value-restriction module.

    Args:
        restrict_value_impl: module applied to the input in ``forward``. When it
            is ``None``, an ``Identity`` is used and the input passes through.
    """

    def __init__(self, restrict_value_impl: Optional[Module]):
        super(_RestrictValue, self).__init__()
        self.restrict_value_impl = (
            Identity() if restrict_value_impl is None else restrict_value_impl)

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor):
        return self.restrict_value_impl(x)


class _ClampValue(brevitas.jit.ScriptModule):
    """Clamp values from below with ``ScalarClampMinSte``, unless disabled.

    Args:
        scaling_min_val: lower bound for the clamp. ``None`` or ``0`` disables
            clamping, and ``Identity`` is used instead.
    """

    def __init__(self, scaling_min_val: Optional[float]):
        super(_ClampValue, self).__init__()
        clamp_disabled = scaling_min_val is None or scaling_min_val == 0
        self.clamp_min_ste = Identity() if clamp_disabled else ScalarClampMinSte(scaling_min_val)
        # The raw threshold is kept as an attribute. forward() here does not
        # read it; presumably other code inspects it (verify before removing).
        self.min_val = scaling_min_val

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor):
        return self.clamp_min_ste(x)


class FloatRestrictValue(brevitas.jit.ScriptModule):
    """Restriction that leaves values unconstrained (plain floating point).

    Every ``restrict_init_*`` hook passes its input through unchanged. Every
    module-returning hook returns an ``Identity``, and ``forward`` returns its
    input as-is.
    """

    def __init__(self) -> None:
        super(FloatRestrictValue, self).__init__()

    def restrict_init_float(self, x: float) -> float:
        """Return the Python float ``x`` unchanged."""
        return x

    def restrict_init_tensor(self, x: Tensor) -> Tensor:
        """Return the tensor ``x`` unchanged."""
        return x

    def restrict_init_module(self) -> Module:
        """Return an ``Identity`` module (no init-time transform)."""
        return Identity()

    def restrict_init_inplace_module(self) -> Module:
        """Return an ``Identity`` module (no in-place init-time transform)."""
        return Identity()

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor) -> Tensor:
        return x
class LogFloatRestrictValue(brevitas.jit.ScriptModule):
    """Restriction for values kept in the log2 domain.

    The ``restrict_init_*`` hooks map an initial value into log2 space, and
    ``forward`` maps it back through ``PowerOfTwo``. ``PowerOfTwo`` is
    presumably ``2 ** x``, the inverse of the init hooks; confirm against
    ``brevitas.core.function_wrapper``.
    """

    def __init__(self) -> None:
        super(LogFloatRestrictValue, self).__init__()
        self.power_of_two: Module = PowerOfTwo()

    def restrict_init_float(self, x: float) -> float:
        """Return ``log2(x)``. ``math.log2`` raises ``ValueError`` for ``x <= 0``."""
        return math.log2(x)

    def restrict_init_tensor(self, x: torch.Tensor) -> torch.Tensor:
        """Return the element-wise ``log2`` of ``x``.

        Non-positive entries yield ``-inf``/``nan``.
        """
        return torch.log2(x)

    def restrict_init_module(self) -> Module:
        """Return a module applying log2 at init time."""
        return LogTwo()

    def restrict_init_inplace_module(self) -> Module:
        """Return the in-place counterpart of ``restrict_init_module``."""
        return InplaceLogTwo()

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor):
        x = self.power_of_two(x)
        return x
class IntRestrictValue(brevitas.jit.ScriptModule):
    """Restriction that maps values onto integers via a float-to-int module.

    Args:
        restrict_value_float_to_int_impl: module applied in ``forward``. When
            omitted or ``None``, a fresh ``RoundSte()`` is created for this
            instance. RoundSte is presumably round with a straight-through
            estimator; confirm.
    """

    def __init__(self, restrict_value_float_to_int_impl: Optional[Module] = None):
        super(IntRestrictValue, self).__init__()
        if restrict_value_float_to_int_impl is None:
            # Build a new module per instance. A `RoundSte()` default argument
            # would be evaluated once and the same Module shared by all instances.
            restrict_value_float_to_int_impl = RoundSte()
        self.float_to_int_impl = restrict_value_float_to_int_impl

    def restrict_init_float(self, x: float) -> float:
        """Return ``x`` unchanged (integer restriction is applied in ``forward``)."""
        return x

    def restrict_init_tensor(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` unchanged."""
        return x

    def restrict_init_module(self) -> Module:
        """Return an ``Identity`` module (no init-time transform)."""
        return Identity()

    def restrict_init_inplace_module(self) -> Module:
        """Return an ``Identity`` module (no in-place init-time transform)."""
        return Identity()

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor):
        x = self.float_to_int_impl(x)
        return x
class PowerOfTwoRestrictValue(brevitas.jit.ScriptModule):
    """Restriction to powers of two: integer exponents in the log2 domain.

    The ``restrict_init_*`` hooks map an initial value into log2 space.
    ``forward`` first applies the float-to-int module to the exponent, then
    maps it back through ``PowerOfTwo`` (presumably ``2 ** x``; confirm).

    Args:
        restrict_value_float_to_int_impl: module applied to the exponent in
            ``forward``. When omitted or ``None``, a fresh ``RoundSte()`` is
            created for this instance.
    """

    def __init__(self, restrict_value_float_to_int_impl: Optional[Module] = None):
        super(PowerOfTwoRestrictValue, self).__init__()
        if restrict_value_float_to_int_impl is None:
            # Build a new module per instance. A `RoundSte()` default argument
            # would be evaluated once and the same Module shared by all instances.
            restrict_value_float_to_int_impl = RoundSte()
        self.float_to_int_impl = restrict_value_float_to_int_impl
        self.power_of_two: Module = PowerOfTwo()

    def restrict_init_float(self, x: float) -> float:
        """Return ``log2(x)``. ``math.log2`` raises ``ValueError`` for ``x <= 0``."""
        return math.log2(x)

    def restrict_init_tensor(self, x: torch.Tensor) -> torch.Tensor:
        """Return the element-wise ``log2`` of ``x``.

        Non-positive entries yield ``-inf``/``nan``.
        """
        return torch.log2(x)

    def restrict_init_module(self) -> Module:
        """Return a module applying log2 at init time."""
        return LogTwo()

    def restrict_init_inplace_module(self) -> Module:
        """Return the in-place counterpart of ``restrict_init_module``."""
        return InplaceLogTwo()

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor):
        # Snap the log2-domain exponent to an integer, then exponentiate.
        x = self.float_to_int_impl(x)
        x = self.power_of_two(x)
        return x