# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
from typing import List, Optional, Tuple
import torch
import brevitas
# Attribute name under which the wrapper modules in this file register their
# wrapped tensor (as a buffer in StatelessBuffer, as a parameter in
# ParameterWrapper); also used to build the state-dict key `prefix + 'value'`.
VALUE_ATTR_NAME = 'value'
@torch.jit.ignore
def inplace_tensor_add(tensor: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """Add ``value`` into ``tensor`` in place and return that same tensor object.

    Marked ``torch.jit.ignore`` so TorchScript leaves it as a plain Python call.
    """
    # Tensor.add_ mutates its receiver and returns it, so the result *is* `tensor`.
    return tensor.add_(value)
@torch.jit.ignore
def inplace_tensor_mul(tensor: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """Multiply ``tensor`` by ``value`` in place and return that same tensor object.

    Marked ``torch.jit.ignore`` so TorchScript leaves it as a plain Python call.
    """
    # Tensor.mul_ mutates its receiver and returns it, so the result *is* `tensor`.
    return tensor.mul_(value)
@torch.jit.ignore
def inplace_momentum_update(
        tensor: torch.Tensor,
        update: torch.Tensor,
        momentum: Optional[float],
        counter: int,
        new_counter: int) -> torch.Tensor:
    """Blend ``update`` into the running statistic ``tensor`` in place.

    Two modes:

    * ``momentum is None``: cumulative mean. ``tensor`` is first rescaled by
      ``counter / new_counter`` and then ``update / new_counter`` is added.
    * otherwise: exponential moving average,
      ``tensor * (1 - momentum) + momentum * update``.

    ``counter`` and ``new_counter`` are only read in the cumulative-mean mode.
    The mutated ``tensor`` itself is returned, not a copy.
    """
    # Each branch scales first and adds second. In `a.mul_(...).add_(arg)` Python
    # evaluates the multiply before `arg`, preserving that order even if
    # `update` aliases `tensor`.
    if momentum is not None:
        return tensor.mul_(1 - momentum).add_(momentum * update)
    return tensor.mul_(counter / new_counter).add_(update / new_counter)
class StatelessBuffer(brevitas.jit.ScriptModule):
    """Hold a constant tensor as a buffer that is left out of the state dict.

    The tensor is registered as a buffer named ``VALUE_ATTR_NAME``. It is
    stripped from :meth:`state_dict` output, and its absence is not reported
    as a missing key when loading.
    """

    def __init__(self, value: torch.Tensor):
        super().__init__()
        self.register_buffer(VALUE_ATTR_NAME, value)

    @brevitas.jit.script_method
    def forward(self):
        # Return the stored constant detached from the autograd graph.
        return self.value.detach()

    def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
            error_msgs):
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        # This buffer is never saved (see state_dict below), so a checkpoint
        # without it is expected: drop its key from the missing-keys report.
        buffer_key = prefix + VALUE_ATTR_NAME
        if buffer_key in missing_keys:
            missing_keys.remove(buffer_key)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        # Let the base class serialize everything, then remove this buffer's
        # entry (raises KeyError if the base class did not emit it).
        snapshot = super().state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars)
        snapshot.pop(prefix + VALUE_ATTR_NAME)
        return snapshot
class SingleArgStatelessBuffer(brevitas.jit.ScriptModule):
    """Module that ignores its single input and returns a stored constant.

    The constant is held in a :class:`StatelessBuffer`, so it is not written
    to the state dict.

    Args:
        value: the constant, either a tensor or anything ``torch.tensor``
            accepts (e.g. a Python float). It is always copied and detached
            from autograd before being stored.
    """

    def __init__(self, value: torch.Tensor):
        super().__init__()
        if isinstance(value, torch.Tensor):
            # torch.tensor(<Tensor>) emits a UserWarning recommending
            # clone/detach; this yields the same detached copy without it.
            const = value.detach().clone()
        else:
            const = torch.tensor(value)
        self.const = StatelessBuffer(const)

    @brevitas.jit.script_method
    def forward(self, placeholder):
        # `placeholder` is deliberately unused; presumably it exists so this
        # module can stand in where a one-argument module is expected.
        return self.const()
class ParameterWrapper(brevitas.jit.ScriptModule):
    """Expose a single learnable tensor through ``forward()``.

    Args:
        value: tensor registered as parameter ``VALUE_ATTR_NAME``. An
            ``nn.Parameter`` (or ``None``) is registered as-is, keeping object
            identity. Any other tensor is first wrapped in ``nn.Parameter``
            with that class's default ``requires_grad=True``.
    """

    def __init__(self, value: torch.Tensor):
        super().__init__()
        # register_parameter only accepts nn.Parameter or None; wrap plain
        # tensors so the declared `torch.Tensor` argument type actually works.
        if value is not None and not isinstance(value, torch.nn.Parameter):
            value = torch.nn.Parameter(value)
        self.register_parameter(VALUE_ATTR_NAME, value)

    @brevitas.jit.script_method
    def forward(self):
        return self.value
class SliceTensor(brevitas.jit.ScriptModule):
    """Crop a tensor using per-dimension ``(start, end)`` bounds.

    Entry ``d`` of ``subtensor_slice_list`` gives the bounds for dimension ``d``.
    ``None`` keeps that dimension whole, and dimensions beyond the list's length
    are not touched. The list starts as ``[None]``, i.e. no cropping.
    """

    def __init__(self) -> None:
        super().__init__()
        self.subtensor_slice_list = brevitas.jit.Attribute([None], List[Optional[Tuple[int, int]]])

    @torch.jit.ignore
    def eager_forward(self, x: torch.Tensor) -> torch.Tensor:
        # Build one Python slice per listed dimension and index in a single step.
        index = []
        for bounds in self.subtensor_slice_list:
            index.append(slice(None) if bounds is None else slice(*bounds))
        return x[tuple(index)]

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.jit.is_scripting():
            # TorchScript path: narrow one dimension at a time.
            for dim, bounds in enumerate(self.subtensor_slice_list):
                if bounds is not None:
                    x = x.slice(dim, bounds[0], bounds[1])
        else:
            x = self.eager_forward(x)
        return x