Source code for brevitas.core.scaling.runtime

# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import List, Optional, Tuple, Union

import torch
from torch.nn import Module
from torch.nn import Parameter

import brevitas
import brevitas.config as config
from brevitas.core.function_wrapper import Identity
from brevitas.core.restrict_val import _RestrictClampValue
from brevitas.core.stats import _ParameterListStats
from brevitas.core.stats import _RuntimeStats
from brevitas.core.stats import DEFAULT_MOMENTUM
from brevitas.core.utils import ParameterWrapper
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops_ste import abs_binary_sign_grad


class StatsFromParameterScaling(brevitas.jit.ScriptModule):

    def __init__(
            self,
            scaling_stats_impl: Module,
            scaling_stats_input_view_shape_impl: Module,
            scaling_stats_input_concat_dim: int,
            tracked_parameter_list: List[torch.nn.Parameter],
            restrict_scaling_impl: Module,
            scaling_shape: Tuple[int, ...],
            affine_rescaling: bool = False,
            affine_shift_scale: bool = False,
            scaling_min_val: Optional[float] = None) -> None:
        super(StatsFromParameterScaling, self).__init__()
        self.parameter_list_stats = _ParameterListStats(
            scaling_stats_impl,
            scaling_shape,
            scaling_stats_input_view_shape_impl,
            scaling_stats_input_concat_dim,
            tracked_parameter_list)
        self.stats_scaling_impl = _StatsScaling(
            restrict_scaling_impl,
            scaling_shape,
            scaling_min_val,
            affine_rescaling,
            affine_shift_scale)

    @brevitas.jit.script_method
    def forward(self, ignored: torch.Tensor) -> torch.Tensor:
        # The input is ignored; the scale is derived from the tracked parameters.
        stats = self.parameter_list_stats()
        return self.stats_scaling_impl(stats)
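# Illustrative only, not part of brevitas: a minimal sketch of deriving a
# per-tensor scale from a tracked parameter. AbsMax, FloatRestrictValue and
# OverTensorView are assumed importable as below; verify the import paths
# against your brevitas version.
def _example_stats_from_parameter_scaling():
    from brevitas.core.function_wrapper import OverTensorView
    from brevitas.core.restrict_val import FloatRestrictValue
    from brevitas.core.stats import AbsMax

    weight = torch.nn.Parameter(torch.randn(16, 8))
    scaling = StatsFromParameterScaling(
        scaling_stats_impl=AbsMax(),
        scaling_stats_input_view_shape_impl=OverTensorView(),
        scaling_stats_input_concat_dim=0,
        tracked_parameter_list=[weight],
        restrict_scaling_impl=FloatRestrictValue(),
        scaling_shape=())  # () means a single per-tensor scale
    # The forward input is ignored; stats come from the tracked parameter.
    return scaling(torch.empty(1))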
class _StatsScaling(brevitas.jit.ScriptModule):

    def __init__(
            self,
            restrict_scaling_impl: Module,
            scaling_shape: Tuple[int, ...],
            scaling_min_val: Optional[float],
            affine_rescaling: bool,
            affine_shift_scale: bool) -> None:
        super(_StatsScaling, self).__init__()
        if affine_shift_scale and not affine_rescaling:
            raise RuntimeError(
                "Shifting the scale (affine_shift_scale) requires affine rescaling to be enabled.")
        if affine_rescaling:
            self.affine_rescaling = _AffineRescaling(scaling_shape, affine_shift_scale)
        else:
            self.affine_rescaling = Identity()
        self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
        self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()

    @brevitas.jit.script_method
    def forward(self, stats: torch.Tensor) -> torch.Tensor:
        # Map stats to the restricted domain, apply the optional learned affine
        # correction, then restrict and clamp back to a valid scale.
        stats = self.restrict_scaling_pre(stats)
        stats = self.affine_rescaling(stats)
        stats = self.restrict_clamp_scaling(stats)
        return stats
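# Illustrative only, not part of brevitas: a sketch of how _StatsScaling
# post-processes a raw statistic. With PowerOfTwoRestrictValue the statistic is
# mapped to the log2 domain by restrict_init_module(), optionally
# affine-corrected, then restricted back to a power of two and clamped. The
# exact rounding behavior is an assumption; verify against your brevitas
# version.
def _example_stats_scaling():
    from brevitas.core.restrict_val import PowerOfTwoRestrictValue

    stats_scaling = _StatsScaling(
        restrict_scaling_impl=PowerOfTwoRestrictValue(),
        scaling_shape=(),
        scaling_min_val=1e-8,
        affine_rescaling=False,
        affine_shift_scale=False)
    # E.g. 0.3 -> log2(0.3) ~= -1.74 -> restricted to -2 -> 2 ** -2 = 0.25
    return stats_scaling(torch.tensor(0.3))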
class RuntimeStatsScaling(brevitas.jit.ScriptModule):

    def __init__(
            self,
            scaling_stats_impl: Module,
            scaling_stats_input_view_shape_impl: Module,
            restrict_scaling_impl: Module,
            scaling_shape: Tuple[int, ...],
            affine_rescaling: bool = False,
            affine_shift_scale: bool = False,
            scaling_stats_momentum: float = DEFAULT_MOMENTUM,
            scaling_min_val: Optional[float] = None) -> None:
        super(RuntimeStatsScaling, self).__init__()
        self.runtime_stats = _RuntimeStats(
            scaling_stats_impl,
            scaling_shape,
            scaling_stats_input_view_shape_impl,
            scaling_stats_momentum)
        self.stats_scaling_impl = _StatsScaling(
            restrict_scaling_impl,
            scaling_shape,
            scaling_min_val,
            affine_rescaling,
            affine_shift_scale)

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stats are computed from the current input at training time and
        # tracked with a moving average for use at inference time.
        stats = self.runtime_stats(x)
        return self.stats_scaling_impl(stats)
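# Illustrative only, not part of brevitas: runtime scaling computes stats from
# the input batch in training mode and keeps an exponential moving average
# (momentum DEFAULT_MOMENTUM) that is used in eval mode. Import paths are
# assumptions; check them against your brevitas version.
def _example_runtime_stats_scaling():
    from brevitas.core.function_wrapper import OverTensorView
    from brevitas.core.restrict_val import FloatRestrictValue
    from brevitas.core.stats import AbsMax

    scaling = RuntimeStatsScaling(
        scaling_stats_impl=AbsMax(),
        scaling_stats_input_view_shape_impl=OverTensorView(),
        restrict_scaling_impl=FloatRestrictValue(),
        scaling_shape=())
    scaling.train()
    _ = scaling(torch.randn(4, 16))  # updates the running average
    scaling.eval()
    return scaling(torch.randn(4, 16))  # uses the running average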
class _AffineRescaling(brevitas.jit.ScriptModule):

    def __init__(self, scaling_shape: Tuple[int, ...], shift_scale: bool) -> None:
        super(_AffineRescaling, self).__init__()
        self.affine_weight = Parameter(torch.ones(scaling_shape))
        if shift_scale:
            self.affine_bias = ParameterWrapper(torch.zeros(scaling_shape))
        else:
            self.affine_bias = StatelessBuffer(torch.tensor(0.))

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x * self.affine_weight + self.affine_bias()
        # Keep the rescaled value non-negative while letting gradients flow
        # through the sign via a straight-through estimator.
        out = abs_binary_sign_grad(out)
        return out

    def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys,
            unexpected_keys, error_msgs):
        super(_AffineRescaling, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys,
            unexpected_keys, error_msgs)
        affine_weight_key = prefix + 'affine_weight'
        affine_bias_key = prefix + 'affine_bias'
        if config.IGNORE_MISSING_KEYS and affine_weight_key in missing_keys:
            missing_keys.remove(affine_weight_key)
        if config.IGNORE_MISSING_KEYS and affine_bias_key in missing_keys:
            missing_keys.remove(affine_bias_key)
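# Illustrative only, not part of brevitas: abs_binary_sign_grad takes the
# absolute value in the forward pass while passing gradients through the sign
# (a straight-through estimator), which is how _AffineRescaling keeps the
# scale non-negative without killing gradients. The gradient values below
# assume sign(0) is treated as +1; verify against your brevitas version.
def _example_abs_binary_sign_grad():
    x = torch.tensor([-2.0, 0.0, 3.0], requires_grad=True)
    y = abs_binary_sign_grad(x)
    y.sum().backward()
    return y, x.grad  # y: [2., 0., 3.]; x.grad: [-1., 1., 1.]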