# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.nn import Module
from torch.nn import Parameter
import brevitas
from brevitas import config
from brevitas.core.stats import _ParameterListStats
from brevitas.core.stats import DEFAULT_MOMENTUM
from brevitas.core.stats import SCALAR_SHAPE
from brevitas.function import abs_binary_sign_grad
from .utils import inplace_momentum_update
from .utils import inplace_tensor_add
from .utils import StatelessBuffer

__all__ = [
'ZeroZeroPoint',
'StatsFromParameterZeroPoint',
'ParameterFromRuntimeZeroPoint',
'ParameterZeroPoint']


class ZeroZeroPoint(brevitas.jit.ScriptModule):
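    """
    ScriptModule implementation of a constant, zero-valued zero-point.

    Example (illustrative; the input, scale and bit-width arguments are ignored):

        >>> zero_point = ZeroZeroPoint()
        >>> zero_point(torch.randn(3), torch.tensor(0.1), torch.tensor(8.))
        tensor(0.)
    """
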
def __init__(self) -> None:
super(ZeroZeroPoint, self).__init__()
self.zero_point = StatelessBuffer(torch.tensor(0.0))

    @brevitas.jit.script_method
def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
return self.zero_point()


class _ScaleShiftZeroPoint(brevitas.jit.ScriptModule):
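    """
    ScriptModule that rescales and shifts a zero-point into the integer range
    of a quantizer. The float zero-point is divided by the scale and offset by
    the minimum integer of the given bit-width; with quantize_zero_point=True
    the zero-point is instead quantized to the integer grid by int_quant.

    Illustrative example, assuming a signed, non-narrow 8-bit quantizer
    (min_int = -128): with zero_point=0.5 and scale=0.25, the float path
    returns 0.5 / 0.25 + (-128) = -126.
    """
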
__constants__ = ['quantize_zero_point']
def __init__(self, int_quant: Module, quantize_zero_point: bool) -> None:
super(_ScaleShiftZeroPoint, self).__init__()
self.int_quant = int_quant
self.quantize_zero_point = quantize_zero_point

    @brevitas.jit.script_method
def forward(self, zero_point: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
        min_int = self.int_quant.min_int(bit_width)
        if self.quantize_zero_point:
            # Quantize the zero-point itself onto the integer grid
            out = self.int_quant.to_int(scale, min_int, bit_width, zero_point)
        else:
            # Keep the zero-point in float: rescale it and shift it by min_int
            out = zero_point / scale + min_int
        return out


class StatsFromParameterZeroPoint(brevitas.jit.ScriptModule):
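    """
    ScriptModule implementation of a zero-point computed at runtime from
    statistics of a list of tracked parameters (e.g. weights). The collected
    statistic is negated and then rescaled and shifted into the integer range
    by _ScaleShiftZeroPoint.
    """
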
def __init__(
self,
int_quant: Module,
quantize_zero_point: bool,
zero_point_stats_input_view_shape_impl: Module,
zero_point_stats_input_concat_dim: int,
zero_point_stats_impl: Module,
zero_point_shape: Tuple[int, ...],
tracked_parameter_list: List[torch.nn.Parameter]) -> None:
super(StatsFromParameterZeroPoint, self).__init__()
self.parameter_list_stats = _ParameterListStats(
zero_point_stats_impl,
zero_point_shape,
zero_point_stats_input_view_shape_impl,
zero_point_stats_input_concat_dim,
tracked_parameter_list)
self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point)

    @brevitas.jit.script_method
def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> torch.Tensor:
stats = self.parameter_list_stats()
        # Negate the statistic so that e.g. a min-based stat maps the lowest
        # parameter value to the integer minimum after the scale-and-shift
        return self.scale_shift_zero_point(-stats, scale, bit_width)


class ParameterFromRuntimeZeroPoint(brevitas.jit.ScriptModule):
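    """
    ScriptModule implementation of a learned zero-point initialized from
    runtime statistics.

    For the first collect_stats_steps training steps, statistics of the input
    are accumulated into a buffer with a momentum-based running average and
    returned as the zero-point. Once collection is over, the buffer seeds
    value, a learned torch.nn.Parameter, which is returned from then on.
    """
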
__constants__ = ['stats_permute_dims', 'collect_stats_steps', 'zero_point_shape', 'momentum']
def __init__(
self,
collect_stats_steps: int,
int_quant: Module,
quantize_zero_point: bool,
            zero_point_stats_impl: Module,
zero_point_shape: Tuple[int, ...],
zero_point_stats_input_view_shape_impl: Module,
zero_point_stats_momentum: Optional[float] = DEFAULT_MOMENTUM) -> None:
super(ParameterFromRuntimeZeroPoint, self).__init__()
        assert collect_stats_steps > 0, 'Number of stats collection steps has to be > 0.'
self.collect_stats_steps = collect_stats_steps
self.counter: int = brevitas.jit.Attribute(0, int)
self.zero_point_shape = zero_point_shape
self.stats_input_view_shape_impl = zero_point_stats_input_view_shape_impl
self.momentum = zero_point_stats_momentum
self.value = Parameter(torch.full(zero_point_shape, 0.0))
self.register_buffer('buffer', torch.full(zero_point_shape, 0.0))
self.zero_point_stats_impl = zero_point_stats_impl
self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point)

    @brevitas.jit.script_method
    def training_forward(self, x: Tensor) -> Tensor:
if self.counter < self.collect_stats_steps:
stats_input = self.stats_input_view_shape_impl(x)
stats = self.zero_point_stats_impl(stats_input)
stats = stats.view(self.zero_point_shape)
            new_counter = self.counter + 1
            # Initialize the buffer on the first step, then keep a
            # momentum-weighted running average of the collected stats
            if self.counter == 0:
                inplace_tensor_add(self.buffer, stats.detach())
            else:
                inplace_momentum_update(
                    self.buffer, stats.detach(), self.momentum, self.counter, new_counter)
            self.counter = new_counter
            # Work around find_unused_parameters=True in DDP
            out = stats + 0. * self.value
        elif self.counter == self.collect_stats_steps:
            # Collection is over: seed the learned parameter with the accumulated stats
            inplace_tensor_add(self.value.detach(), self.buffer)
            self.counter = self.counter + 1
out = self.value
else:
out = self.value
return out

    @brevitas.jit.script_method
def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
if self.training:
out = self.training_forward(x)
        else:
            # At eval time, return the running stats until collection is over
            if self.counter <= self.collect_stats_steps:
out = self.buffer
else:
out = self.value
out = abs_binary_sign_grad(out)
out = self.scale_shift_zero_point(out, scale, bit_width)
return out

    def state_dict(self, destination=None, prefix='', keep_vars=False):
output_dict = super(ParameterFromRuntimeZeroPoint,
self).state_dict(destination, prefix, keep_vars)
# Avoid saving the buffer
del output_dict[prefix + 'buffer']
# Avoid saving the init value
if self.counter == 0:
del output_dict[prefix + 'value']
# Save buffer into value for any non-zero number of collection steps
elif self.counter <= self.collect_stats_steps:
output_dict[prefix + 'value'] = self.buffer
return output_dict

    def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
super(ParameterFromRuntimeZeroPoint, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
value_key = prefix + 'value'
buffer_key = prefix + 'buffer'
        # The buffer is deliberately excluded from the state dict above, so it's always missing
        missing_keys.remove(buffer_key)
        # PyTorch stores the training flag as a buffer when JIT is enabled
        training_key = prefix + 'training'
        if training_key in missing_keys:
            missing_keys.remove(training_key)
        # Disable stats collection when a pretrained value is loaded
if value_key not in missing_keys:
self.counter = self.collect_stats_steps + 1
if config.IGNORE_MISSING_KEYS and value_key in missing_keys:
missing_keys.remove(value_key)


class ParameterZeroPoint(brevitas.jit.ScriptModule):
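    """
    ScriptModule implementation of a zero-point learned directly as a
    torch.nn.Parameter, starting from zero_point_init. A scalar init is
    broadcast to zero_point_shape when one is given.
    """
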
__constants__ = ['stats_permute_dims', 'collect_stats_steps', 'momentum']
def __init__(
self,
zero_point_init: Union[float, torch.Tensor],
int_quant: Module,
quantize_zero_point: bool,
            zero_point_shape: Optional[Tuple[int, ...]] = None) -> None:
super(ParameterZeroPoint, self).__init__()
if (isinstance(zero_point_init, Tensor) and zero_point_shape is not None and
zero_point_init.shape != SCALAR_SHAPE and
zero_point_init.shape != zero_point_shape):
            raise RuntimeError(
                "zero_point_init.shape is non-scalar and does not match zero_point_shape.")
if isinstance(zero_point_init, Tensor):
zero_point_init = zero_point_init.detach()
else:
zero_point_init = torch.tensor(zero_point_init)
if zero_point_init.shape == SCALAR_SHAPE and zero_point_shape is not None:
zero_point_init = torch.full(zero_point_shape, zero_point_init)
self.value = Parameter(zero_point_init)
self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point)

    @brevitas.jit.script_method
def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
out = abs_binary_sign_grad(self.value)
out = self.scale_shift_zero_point(out, scale, bit_width)
return out

    def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
super(ParameterZeroPoint, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
value_key = prefix + 'value'
if config.IGNORE_MISSING_KEYS and value_key in missing_keys:
missing_keys.remove(value_key)
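
# Minimal usage sketch (illustrative only, not part of the module). It assumes
# brevitas.core.quant.IntQuant as the backing integer quantizer; the exact
# constructor arguments are assumptions and may differ across versions:
#
#   from brevitas.core.quant import IntQuant
#   from brevitas.core.function_wrapper import RoundSte, TensorClamp
#
#   int_quant = IntQuant(
#       narrow_range=False, signed=True,
#       float_to_int_impl=RoundSte(), tensor_clamp_impl=TensorClamp())
#   zero_point = ParameterZeroPoint(0., int_quant, quantize_zero_point=True)
#   zp = zero_point(x, scale, bit_width)  # x, scale, bit_width: Tensor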