# Source code for brevitas.core.scaling.runtime

# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import List, Optional, Tuple, Union

import torch
from torch.nn import Module
from torch.nn import Parameter

import brevitas
import brevitas.config as config
from brevitas.core.function_wrapper import Identity
from brevitas.core.restrict_val import _RestrictClampValue
from brevitas.core.restrict_val import FloatRestrictValue
from brevitas.core.stats import _ParameterListStats
from brevitas.core.stats import _RuntimeStats
from brevitas.core.stats import DEFAULT_MOMENTUM
from brevitas.core.utils import ParameterWrapper
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops_ste import abs_binary_sign_grad


class StatsFromParameterScaling(brevitas.jit.ScriptModule):
    """Scale computed from statistics over a list of tracked parameters.

    The module chains two sub-modules: a ``_ParameterListStats`` that yields the
    statistics (it is called without any input tensor), and a ``_StatsScaling``
    that turns those statistics into the returned scale.

    Args:
        scaling_stats_impl: Module forwarded to ``_ParameterListStats``.
        scaling_stats_input_view_shape_impl: Module forwarded to
            ``_ParameterListStats``.
        scaling_stats_input_concat_dim: Integer forwarded to
            ``_ParameterListStats``.
        tracked_parameter_list: Parameters forwarded to ``_ParameterListStats``.
        scaling_shape: Shape forwarded to both sub-modules.
        restrict_scaling_impl: Module forwarded to ``_StatsScaling``.
            NOTE: the default ``FloatRestrictValue()`` instance is created once,
            when the class body is evaluated, and is shared by every instance
            that relies on the default.
        affine_rescaling: Forwarded to ``_StatsScaling``.
        affine_shift_scale: Forwarded to ``_StatsScaling``.
        scaling_min_val: Forwarded to ``_StatsScaling``.
        dtype: Forwarded to ``_StatsScaling``.
        device: Forwarded to ``_StatsScaling``.
    """

    def __init__(
            self,
            scaling_stats_impl: Module,
            scaling_stats_input_view_shape_impl: Module,
            scaling_stats_input_concat_dim: int,
            tracked_parameter_list: List[torch.nn.Parameter],
            scaling_shape: Tuple[int, ...],
            restrict_scaling_impl: Module = FloatRestrictValue(),
            affine_rescaling: bool = False,
            affine_shift_scale: bool = False,
            scaling_min_val: Optional[float] = None,
            dtype: Optional[torch.dtype] = None,
            device: Optional[torch.device] = None) -> None:
        super().__init__()
        # Statistics source: reads only the tracked parameters.
        self.parameter_list_stats = _ParameterListStats(
            scaling_stats_impl,
            scaling_shape,
            scaling_stats_input_view_shape_impl,
            scaling_stats_input_concat_dim,
            tracked_parameter_list)
        # Post-processing of the statistics into the final scale.
        self.stats_scaling_impl = _StatsScaling(
            restrict_scaling_impl=restrict_scaling_impl,
            scaling_shape=scaling_shape,
            scaling_min_val=scaling_min_val,
            affine_rescaling=affine_rescaling,
            affine_shift_scale=affine_shift_scale,
            dtype=dtype,
            device=device)

    @brevitas.jit.script_method
    def forward(
            self, ignored: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the scale derived from the tracked parameters.

        Args:
            ignored: Unused; the statistics come from the tracked parameters.
            threshold: Optional threshold handed to ``_StatsScaling``. When it
                is ``None`` a one-element tensor of ones, cast with ``type_as``
                to match the statistics, is used instead.

        Returns:
            The output of the wrapped ``_StatsScaling`` module.
        """
        param_stats = self.parameter_list_stats()
        if threshold is None:
            threshold = torch.ones(1).type_as(param_stats)
        scale = self.stats_scaling_impl(param_stats, threshold)
        return scale
class _StatsScaling(brevitas.jit.ScriptModule):
    """Turn raw statistics into a scale.

    The forward pass applies, in order: ``restrict_init_module()`` of
    ``restrict_scaling_impl`` to both threshold and statistics,
    ``restrict_scaling_impl.combine_scale_threshold``, the optional affine
    rescaling, and finally ``_RestrictClampValue``.

    Args:
        restrict_scaling_impl: Module providing ``restrict_init_module()`` and
            ``combine_scale_threshold``; also forwarded to
            ``_RestrictClampValue``.
        scaling_shape: Shape of the affine parameters (used only when
            ``affine_rescaling`` is True).
        scaling_min_val: Forwarded to ``_RestrictClampValue``.
        affine_rescaling: If True an ``_AffineRescaling`` is applied, otherwise
            an ``Identity``.
        affine_shift_scale: If True the affine rescaling uses a learned bias.
            Requires ``affine_rescaling`` to be True.
        dtype: dtype of the affine parameters.
        device: device of the affine parameters.

    Raises:
        RuntimeError: If ``affine_shift_scale`` is True while
            ``affine_rescaling`` is False.
    """

    def __init__(
            self,
            restrict_scaling_impl: Module,
            scaling_shape: Tuple[int, ...],
            scaling_min_val: Optional[float],
            affine_rescaling: bool,
            affine_shift_scale: bool,
            dtype: Optional[torch.dtype],
            device: Optional[torch.device]) -> None:
        super(_StatsScaling, self).__init__()

        # The shift (bias) lives inside the affine rescaling module, so it cannot
        # be requested on its own. The message previously said "Disabling", which
        # contradicted the condition being checked here.
        if affine_shift_scale and not affine_rescaling:
            raise RuntimeError(
                "Enabling shifting of the scale requires to enable affine rescaling first.")

        if affine_rescaling:
            self.affine_rescaling = _AffineRescaling(
                scaling_shape, affine_shift_scale, dtype, device)
        else:
            self.affine_rescaling = Identity()
        self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
        self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
        self.restrict_scaling_impl = restrict_scaling_impl

    @brevitas.jit.script_method
    def forward(
            self, stats: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Convert ``stats`` into a scale.

        Args:
            stats: Statistics tensor to convert.
            threshold: Optional threshold combined with the statistics. When it
                is ``None`` a one-element tensor of ones, cast with ``type_as``
                to match ``stats``, is used instead.

        Returns:
            The restricted, optionally affine-rescaled and clamped scale.
        """
        if threshold is None:
            threshold = torch.ones(1).type_as(stats)
        # Threshold and statistics go through the same pre-restriction module
        # before they are combined.
        threshold = self.restrict_scaling_pre(threshold)
        stats = self.restrict_scaling_pre(stats)
        stats = self.restrict_scaling_impl.combine_scale_threshold(stats, threshold)
        stats = self.affine_rescaling(stats)
        stats = self.restrict_clamp_scaling(stats)
        return stats
class RuntimeStatsScaling(brevitas.jit.ScriptModule):
    """Scale computed from statistics of the tensor seen at runtime.

    The module chains a ``_RuntimeStats`` (statistics of the input tensor) with
    a ``_StatsScaling`` (conversion of those statistics into a scale).

    Args:
        scaling_stats_impl: Module forwarded to ``_RuntimeStats``.
        scaling_stats_input_view_shape_impl: Module forwarded to
            ``_RuntimeStats``.
        scaling_shape: Shape forwarded to both sub-modules.
        affine_rescaling: Forwarded to ``_StatsScaling``.
        affine_shift_scale: Forwarded to ``_StatsScaling``.
        restrict_scaling_impl: Module forwarded to ``_StatsScaling``.
            NOTE: the default ``FloatRestrictValue()`` instance is created once,
            when the class body is evaluated, and is shared by every instance
            that relies on the default.
        scaling_stats_momentum: Forwarded to ``_RuntimeStats``; defaults to
            ``DEFAULT_MOMENTUM``.
        scaling_min_val: Forwarded to ``_StatsScaling``.
        dtype: Forwarded to both sub-modules.
        device: Forwarded to both sub-modules.
    """

    def __init__(
            self,
            scaling_stats_impl: Module,
            scaling_stats_input_view_shape_impl: Module,
            scaling_shape: Tuple[int, ...],
            affine_rescaling: bool = False,
            affine_shift_scale: bool = False,
            restrict_scaling_impl: Module = FloatRestrictValue(),
            scaling_stats_momentum: float = DEFAULT_MOMENTUM,
            scaling_min_val: Optional[float] = None,
            dtype: Optional[torch.dtype] = None,
            device: Optional[torch.device] = None) -> None:
        super().__init__()
        # Statistics source: consumes the tensor given to forward().
        self.runtime_stats = _RuntimeStats(
            scaling_stats_impl,
            scaling_shape,
            scaling_stats_input_view_shape_impl,
            scaling_stats_momentum,
            dtype,
            device)
        # Post-processing of the statistics into the final scale.
        self.stats_scaling_impl = _StatsScaling(
            restrict_scaling_impl=restrict_scaling_impl,
            scaling_shape=scaling_shape,
            scaling_min_val=scaling_min_val,
            affine_rescaling=affine_rescaling,
            affine_shift_scale=affine_shift_scale,
            dtype=dtype,
            device=device)

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the scale derived from ``x``.

        Args:
            x: Tensor whose statistics are computed by ``_RuntimeStats``.
            threshold: Optional threshold passed unchanged (possibly ``None``)
                to the wrapped ``_StatsScaling``.

        Returns:
            The output of the wrapped ``_StatsScaling`` module.
        """
        observed_stats = self.runtime_stats(x)
        scale = self.stats_scaling_impl(observed_stats, threshold)
        return scale
class _AffineRescaling(brevitas.jit.ScriptModule):
    """Learned affine transform ``x * weight + bias`` followed by ``abs_binary_sign_grad``.

    Args:
        scaling_shape: Shape of the weight (and of the bias when it is learned).
        shift_scale: If truthy the bias is a ``ParameterWrapper`` initialised to
            zeros of ``scaling_shape``; otherwise it is a ``StatelessBuffer``
            holding a scalar ``0.``.
        dtype: dtype used to create the weight and bias tensors.
        device: device used to create the weight and bias tensors.
    """

    def __init__(
            self,
            scaling_shape,
            shift_scale,
            dtype: Optional[torch.dtype],
            device: Optional[torch.device]):
        super().__init__()
        # Weight starts at one so that, together with a zero bias, the affine
        # step leaves its input unchanged at initialisation.
        initial_weight = torch.ones(scaling_shape, dtype=dtype, device=device)
        self.affine_weight = Parameter(initial_weight)
        if shift_scale:
            bias_holder = ParameterWrapper(torch.zeros(scaling_shape, dtype=dtype, device=device))
        else:
            bias_holder = StatelessBuffer(torch.tensor(0., dtype=dtype, device=device))
        self.affine_bias = bias_holder

    @brevitas.jit.script_method
    def forward(self, x):
        """Apply the affine transform to ``x`` and pass it through ``abs_binary_sign_grad``."""
        # affine_bias is a module in both configurations, hence the call.
        shifted = x * self.affine_weight + self.affine_bias()
        return abs_binary_sign_grad(shifted)

    def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
            error_msgs):
        """Load state, then drop this module's keys from ``missing_keys`` if configured.

        After the parent implementation has run, the ``affine_weight`` and
        ``affine_bias`` keys (with ``prefix``) are removed from ``missing_keys``
        whenever ``config.IGNORE_MISSING_KEYS`` is truthy.
        """
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        for attr_name in ('affine_weight', 'affine_bias'):
            full_key = prefix + attr_name
            if config.IGNORE_MISSING_KEYS and full_key in missing_keys:
                missing_keys.remove(full_key)
class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule):
    """Scale computed on the fly from a re-viewed input tensor.

    The input is reshaped by ``input_view_impl``, reduced by
    ``scaling_stats_impl``, divided by the threshold and finally passed through
    ``_RestrictClampValue``.

    Args:
        group_size: Stored as ``self.group_size``; not read by the methods of
            this class.
        group_dim: Stored as ``self.group_dim``; not read by the methods of
            this class.
        input_view_impl: Module applied to the input before the statistics.
        scaling_stats_impl: Module computing the statistics of the viewed input.
        scaling_min_val: Stored as ``self.scaling_min_val`` and forwarded to
            ``_RestrictClampValue``.
        restrict_scaling_impl: Module forwarded to ``_RestrictClampValue``.
            NOTE: the default ``FloatRestrictValue()`` instance is created once,
            when the class body is evaluated, and is shared by every instance
            that relies on the default.
    """

    def __init__(
            self,
            group_size: int,
            group_dim: int,
            input_view_impl: Module,
            scaling_stats_impl: Module,
            scaling_min_val: Optional[float],
            restrict_scaling_impl: Module = FloatRestrictValue()) -> None:
        super().__init__()
        # Plain configuration values.
        self.group_size = group_size
        self.group_dim = group_dim
        self.scaling_min_val = scaling_min_val
        # Sub-modules, registered in the same relative order as before.
        self.scaling_stats_impl = scaling_stats_impl
        self.input_view_impl = input_view_impl
        self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)

    @brevitas.jit.script_method
    def forward(
            self,
            stats_input: torch.Tensor,
            threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the scale for ``stats_input``.

        Args:
            stats_input: Tensor the statistics are computed from.
            threshold: Optional divisor applied to the statistics. When it is
                ``None`` a one-element tensor of ones, cast with ``type_as`` to
                match ``stats_input``, is used instead.

        Returns:
            The statistics divided by the threshold, after
            ``_RestrictClampValue``.
        """
        if threshold is None:
            threshold = torch.ones(1).type_as(stats_input)
        viewed_input = self.input_view_impl(stats_input)
        group_stats = self.scaling_stats_impl(viewed_input)
        scale = group_stats / threshold
        # Scaling min val
        return self.restrict_clamp_scaling(scale)