Source code for brevitas.core.utils
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Optional
import torch
import brevitas
VALUE_ATTR_NAME = 'value'
[docs]@torch.jit.ignore
def inplace_tensor_add(tensor: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
tensor.add_(value)
return tensor
[docs]@torch.jit.ignore
def inplace_tensor_mul(tensor: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
tensor.mul_(value)
return tensor
[docs]@torch.jit.ignore
def inplace_momentum_update(
tensor: torch.Tensor,
update: torch.Tensor,
momentum: Optional[float],
counter: int,
new_counter: int) -> torch.Tensor:
if momentum is None:
tensor.mul_(counter / new_counter)
tensor.add_(update / new_counter)
else:
tensor.mul_(1 - momentum)
tensor.add_(momentum * update)
return tensor
[docs]class StatelessBuffer(brevitas.jit.ScriptModule):
def __init__(self, value: torch.Tensor):
super(StatelessBuffer, self).__init__()
self.register_buffer(VALUE_ATTR_NAME, value)
[docs] @brevitas.jit.script_method
def forward(self):
return self.value.detach()
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
super(StatelessBuffer, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
value_key = prefix + VALUE_ATTR_NAME
if value_key in missing_keys:
missing_keys.remove(value_key)
[docs] def state_dict(self, destination=None, prefix='', keep_vars=False):
output_dict = super(StatelessBuffer, self).state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars)
del output_dict[prefix + VALUE_ATTR_NAME]
return output_dict
[docs]class SingleArgStatelessBuffer(brevitas.jit.ScriptModule):
def __init__(self, value: torch.Tensor):
super(SingleArgStatelessBuffer, self).__init__()
self.const = StatelessBuffer(torch.tensor(value))
[docs] @brevitas.jit.script_method
def forward(self, placeholder):
return self.const()
[docs]class ParameterWrapper(brevitas.jit.ScriptModule):
def __init__(self, value: torch.Tensor):
super(ParameterWrapper, self).__init__()
self.register_parameter(VALUE_ATTR_NAME, value)
[docs] @brevitas.jit.script_method
def forward(self):
return self.value