# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
"""
Implementation of various torch.autograd.Function with straight-through estimators.
"""
from typing import Tuple
import torch
from torch import Tensor
from torch.autograd import Function
from brevitas.function.ops import binary_sign
from brevitas.function.ops import dpu_round
from brevitas.function.ops import round_to_zero
from brevitas.function.ops import tensor_clamp
from brevitas.function.ops import tensor_clamp_
# Public API of this module: the straight-through-estimator autograd Function
# classes and the ``*_impl`` aliases of their ``apply`` methods defined below.
__all__ = [
'ScalarClampSteFn',
'ScalarClampMinSteFn',
'TensorClampSteFn',
'InplaceTensorClampSteFn',
'RoundToZeroSteFn',
'CeilSteFn',
'FloorSteFn',
'BinarySignSteFn',
'TernarySignSteFn',
'RoundSteFn',
'AbsBinarySignGradFn',
'DPURoundSteFn',
'round_ste_impl',
'binary_sign_ste_impl',
'ternary_sign_ste_impl',
'floor_ste_impl',
'ceil_ste_impl',
'round_to_zero_ste_impl',
'scalar_clamp_min_ste_impl',
'scalar_clamp_ste_impl',
'tensor_clamp_ste_impl',
'abs_binary_sign_grad_impl',
'dpu_round_ste_impl']
class ScalarClampSteFn(Function):
    """
    Autograd function that implements ``torch.clamp`` with a straight-through gradient estimator
    for the gradient of y w.r.t. to x, while the gradients of y w.r.t. to ``min_val`` and
    ``max_val`` are always ``None``.

    ``ScalarClampSteFn.apply(*args)`` is first aliased to :func:`scalar_clamp_ste_impl(*args)
    <brevitas.ops.autograd_ste_ops.scalar_clamp_ste_impl>` and then wrapped by
    :func:`~brevitas.function.ops_ste.scalar_clamp_ste` and invoked when env ``BREVITAS_JIT=0``.
    See :func:`~brevitas.function.ops_ste.scalar_clamp_ste` for details on the interface and
    examples.
    """

    @staticmethod
    def forward(ctx, x: Tensor, min_val: float, max_val: float) -> Tensor:
        y = torch.clamp(x, min_val, max_val)
        return y

    @staticmethod
    def backward(ctx, grad_y: Tensor) -> Tuple[Tensor, None, None]:
        # Straight-through: pass the incoming gradient unchanged to x;
        # min_val and max_val are non-differentiable scalars, hence None.
        return grad_y, None, None

    @staticmethod
    def symbolic(g, x: Tensor, min_val: float, max_val: float):
        # ONNX Clip (opset >= 11) takes min/max as tensor inputs.
        y = g.op('Clip', x, torch.tensor(min_val), torch.tensor(max_val))
        return y
class ScalarClampMinSteFn(Function):
    """
    Autograd function that implements ``torch.clamp_min`` with a straight-through gradient estimator
    for the gradient of y w.r.t. to x, while the gradient of y w.r.t. to ``min_val`` is always
    ``None``.

    ``ScalarClampMinSteFn.apply(*args)`` is first aliased to :func:`scalar_clamp_min_ste_impl(*args)
    <brevitas.ops.autograd_ste_ops.scalar_clamp_min_ste_impl>` and then wrapped by
    :func:`~brevitas.function.ops_ste.scalar_clamp_min_ste` and invoked when env ``BREVITAS_JIT=0``.
    See :func:`~brevitas.function.ops_ste.scalar_clamp_min_ste` for details on the interface and
    examples.
    """

    @staticmethod
    def forward(ctx, x: Tensor, min_val: float) -> Tensor:
        y = torch.clamp_min(x, min_val)
        return y

    @staticmethod
    def backward(ctx, grad_y: Tensor) -> Tuple[Tensor, None]:
        # Straight-through: gradient flows unchanged to x; None for min_val.
        return grad_y, None

    @staticmethod
    def symbolic(g, x: Tensor, min_val: float):
        # ONNX Clip with only the min input bound (no max).
        y = g.op('Clip', x, torch.tensor(min_val))
        return y
class TensorClampSteFn(Function):
    """
    Autograd function that implements :func:`~brevitas.function.ops.tensor_clamp` with a
    straight-through gradient estimator for the gradient of y w.r.t. to x, while the gradient of y
    w.r.t. to min_val and max_val is always None.

    ``TensorClampSteFn.apply(*args)`` is first aliased to :func:`tensor_clamp_ste_impl(*args)
    <brevitas.ops.autograd_ste_ops.tensor_clamp_ste_impl>` and then wrapped by
    :func:`~brevitas.function.ops_ste.tensor_clamp` when env ``BREVITAS_JIT=0``.
    See :func:`~brevitas.function.ops_ste.tensor_clamp` for details on the interface and
    examples.
    """

    @staticmethod
    def forward(ctx, x: Tensor, min_val: Tensor, max_val: Tensor) -> Tensor:
        y = tensor_clamp(x, min_val, max_val)
        return y

    @staticmethod
    def backward(ctx, grad_y: Tensor) -> Tuple[Tensor, None, None]:
        # Straight-through: gradient flows unchanged to x; None for min/max tensors.
        return grad_y, None, None

    @staticmethod
    def symbolic(g, x: Tensor, min_val: Tensor, max_val: Tensor):
        # Elementwise clamp expressed with Greater/Less + Where, since ONNX Clip
        # does not broadcast per-element min/max tensors the way tensor_clamp does.
        upper_cond = g.op('Greater', x, max_val)
        y = g.op('Where', upper_cond, max_val, x)
        lower_cond = g.op('Less', y, min_val)
        y = g.op('Where', lower_cond, min_val, y)
        return y
class InplaceTensorClampSteFn(Function):
    """
    Autograd function that implements :func:`~brevitas.function.ops.tensor_clamp_` with a
    straight-through gradient estimator for the gradient of y w.r.t. to x, while the gradient of y
    w.r.t. to min_val and max_val is always None.

    ``InplaceTensorClampSteFn.apply(*args)`` is first aliased to
    :func:`tensor_clamp_ste_impl_(*args)
    <brevitas.ops.autograd_ste_ops.tensor_clamp_ste_impl_>` and then wrapped by
    :func:`~brevitas.function.ops_ste.tensor_clamp_` when env ``BREVITAS_JIT=0``.
    See :func:`~brevitas.function.ops_ste.tensor_clamp_` for details on the interface and
    examples.
    """

    @staticmethod
    def forward(ctx, x: Tensor, min_val: Tensor, max_val: Tensor) -> Tensor:
        # NOTE: clamps x in-place; the returned tensor aliases the input.
        y = tensor_clamp_(x, min_val, max_val)
        return y

    @staticmethod
    def backward(ctx, grad_y: Tensor) -> Tuple[Tensor, None, None]:
        # Straight-through: gradient flows unchanged to x; None for min/max tensors.
        return grad_y, None, None

    @staticmethod
    def symbolic(g, x: Tensor, min_val: Tensor, max_val: Tensor):
        # Same graph as TensorClampSteFn: ONNX has no in-place semantics.
        upper_cond = g.op('Greater', x, max_val)
        y = g.op('Where', upper_cond, max_val, x)
        lower_cond = g.op('Less', y, min_val)
        y = g.op('Where', lower_cond, min_val, y)
        return y
class RoundToZeroSteFn(Function):
    """
    Autograd function that implements :func:`~brevitas.function.ops.round_to_zero` with a
    straight-through gradient estimator.

    ``RoundToZeroSteFn.apply(*args)`` is first aliased to :func:`round_to_zero_ste_impl(*args)
    <brevitas.ops.autograd_ste_ops.round_to_zero_ste_impl>` and then wrapped by
    :func:`~brevitas.function.ops_ste.round_to_zero_ste` when env ``BREVITAS_JIT=0``.
    See :func:`~brevitas.function.ops_ste.round_to_zero_ste` for details on the interface and
    examples.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        y = round_to_zero(x)
        return y

    @staticmethod
    def backward(ctx, grad_y: Tensor) -> Tensor:
        # Straight-through: identity gradient.
        return grad_y

    @staticmethod
    def symbolic(g, x: Tensor):
        # round-to-zero == sign(x) * floor(|x|)
        abs_x = g.op('Abs', x)
        sign_x = g.op('Sign', x)
        floor_abs = g.op('Floor', abs_x)
        y = g.op('Mul', sign_x, floor_abs)
        return y
class DPURoundSteFn(Function):
    """
    Autograd function that implements :func:`~brevitas.function.ops.dpu_round` with a
    straight-through gradient estimator.

    ``DPURoundSteFn.apply(*args)`` is first aliased to :func:`dpu_round_ste_impl(*args)
    <brevitas.ops.autograd_ste_ops.dpu_round_ste_impl>` and then wrapped by
    :func:`~brevitas.function.ops_ste.dpu_round_ste` when env ``BREVITAS_JIT=0``.
    See :func:`~brevitas.function.ops_ste.dpu_round_ste` for details on the interface and
    examples.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        y = dpu_round(x)
        return y

    @staticmethod
    def backward(ctx, grad_y: Tensor) -> Tensor:
        # Straight-through: identity gradient.
        return grad_y

    @staticmethod
    def symbolic(g, x: Tensor):
        # ONNX export of dpu_round is deliberately unsupported.
        raise NotImplementedError
class CeilSteFn(Function):
    """
    Autograd function that implements :func:`torch.ceil` with a straight-through gradient estimator.

    ``CeilSteFn.apply(*args)`` is first aliased to :func:`ceil_ste_impl(*args)
    <brevitas.ops.autograd_ste_ops.ceil_ste_impl>` and then wrapped by
    :func:`~brevitas.function.ops_ste.ceil_ste` when env ``BREVITAS_JIT=0``.
    See :func:`~brevitas.function.ops_ste.ceil_ste` for details on the interface and
    examples.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        y = torch.ceil(x)
        return y

    @staticmethod
    def backward(ctx, grad_y: Tensor) -> Tensor:
        # Straight-through: identity gradient.
        return grad_y

    @staticmethod
    def symbolic(g, x: Tensor):
        y = g.op('Ceil', x)
        return y
class FloorSteFn(Function):
    """
    Autograd function that implements :func:`torch.floor` with a straight-through gradient estimator.

    ``FloorSteFn.apply(*args)`` is first aliased to :func:`floor_ste_impl(*args)
    <brevitas.ops.autograd_ste_ops.floor_ste_impl>` and then wrapped by
    :func:`~brevitas.function.ops_ste.floor_ste` when env ``BREVITAS_JIT=0``.
    See :func:`~brevitas.function.ops_ste.floor_ste` for details on the interface and
    examples.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        y = torch.floor(x)
        return y

    @staticmethod
    def backward(ctx, grad_y: Tensor) -> Tensor:
        # Straight-through: identity gradient.
        return grad_y

    @staticmethod
    def symbolic(g, x: Tensor):
        y = g.op('Floor', x)
        return y
class BinarySignSteFn(Function):
    """
    Autograd function that implements :func:`~brevitas.function.ops.binary_sign` with a
    straight-through gradient estimator.

    ``BinarySignSteFn.apply(*args)`` is first aliased to
    :func:`binary_sign_ste_impl(*args)<brevitas.ops.autograd_ste_ops.binary_sign_ste_impl>`
    and then wrapped by :func:`~brevitas.function.ops_ste.binary_sign_ste` when env ``BREVITAS_JIT=0``.
    See :func:`~brevitas.function.ops_ste.binary_sign_ste` for details on the interface and
    examples.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        y = binary_sign(x)
        return y

    @staticmethod
    def backward(ctx, grad_y: Tensor) -> Tensor:
        # Straight-through: identity gradient.
        return grad_y

    @staticmethod
    def symbolic(g, x: Tensor):
        # requires ONNX opset >= 12 (GreaterOrEqual)
        # binary sign maps x >= 0 to +1 and x < 0 to -1, i.e. 1*[x>=0] - 1*[x<0].
        positive_mask = g.op('GreaterOrEqual', x, torch.tensor(0.))
        negative_mask = g.op('Less', x, torch.tensor(0.))
        positive_mask = g.op('Cast', positive_mask, to_i=torch.onnx.TensorProtoDataType.FLOAT)
        negative_mask = g.op('Cast', negative_mask, to_i=torch.onnx.TensorProtoDataType.FLOAT)
        y = g.op('Sub', positive_mask, negative_mask)
        return y
class TernarySignSteFn(Function):
    """
    Autograd function that implements :func:`torch.sign` with a straight-through gradient estimator.

    ``TernarySignSteFn.apply(*args)`` is first aliased to :func:`ternary_sign_ste_impl(*args)
    <brevitas.ops.autograd_ste_ops.ternary_sign_ste_impl>` and then wrapped by
    :func:`~brevitas.function.ops_ste.ternary_sign_ste` when env ``BREVITAS_JIT=0``.
    See :func:`~brevitas.function.ops_ste.ternary_sign_ste` for details on the interface and
    examples.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        # Ternary: torch.sign maps to {-1, 0, +1} (0 stays 0, unlike binary_sign).
        y = torch.sign(x)
        return y

    @staticmethod
    def backward(ctx, grad_y: Tensor) -> Tensor:
        # Straight-through: identity gradient.
        return grad_y

    @staticmethod
    def symbolic(g, x: Tensor):
        y = g.op('Sign', x)
        return y
class RoundSteFn(Function):
    """
    Autograd function that implements :func:`torch.round` with a straight-through gradient
    estimator.

    ``RoundSteFn.apply(*args)`` is first aliased to :func:`round_ste_impl(*args)
    <brevitas.ops.autograd_ste_ops.round_ste_impl>` and then wrapped by
    :func:`~brevitas.function.ops_ste.round_ste` when env ``BREVITAS_JIT=0``.
    See :func:`~brevitas.function.ops_ste.round_ste` for details on the interface and examples.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        y = torch.round(x)
        return y

    @staticmethod
    def backward(ctx, grad_y: Tensor) -> Tensor:
        # Straight-through: identity gradient.
        return grad_y

    @staticmethod
    def symbolic(g, x: Tensor):
        # Requires ONNX opset >= 11
        y = g.op('Round', x)
        return y
class AbsBinarySignGradFn(Function):
    """
    Autograd function that implements :func:`torch.abs` with a binary-sign backward, in order to
    have subgradient 1 in 0. Compare with :func:`torch.abs`' subgradient of 0 in 0.

    ``AbsBinarySignGradFn.apply(*args)`` is first aliased to :func:`abs_binary_sign_grad(*args)
    <brevitas.ops.autograd_ste_ops.abs_binary_sign_grad_impl>` and then wrapped by
    :func:`~brevitas.function.ops_ste.abs_binary_sign_grad` when env ``BREVITAS_JIT=0``.
    See :func:`~brevitas.function.ops_ste.abs_binary_sign_grad` for details on the interface and
    examples.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        # Store the sign as int8 rather than float to save some memory.
        ctx.save_for_backward(binary_sign(x).type(torch.int8))
        y = torch.abs(x)
        return y

    @staticmethod
    def backward(ctx, grad_y: Tensor) -> Tensor:
        # Renamed from `binary_sign` to avoid shadowing the imported helper.
        saved_sign, = ctx.saved_tensors
        return saved_sign.float() * grad_y

    @staticmethod
    def symbolic(g, x: Tensor):
        y = g.op('Abs', x)
        return y
# Module-level aliases of each Function's ``apply``, ordered to match the
# class definitions above.

#: Alias for :class:`ScalarClampSteFn.apply(*args)
#: <brevitas.ops.autograd_ste_ops.ScalarClampSteFn>`
scalar_clamp_ste_impl = ScalarClampSteFn.apply

#: Alias for :class:`ScalarClampMinSteFn.apply(*args)
#: <brevitas.ops.autograd_ste_ops.ScalarClampMinSteFn>`
scalar_clamp_min_ste_impl = ScalarClampMinSteFn.apply

#: Alias for :class:`TensorClampSteFn.apply(*args)
#: <brevitas.ops.autograd_ste_ops.TensorClampSteFn>`
tensor_clamp_ste_impl = TensorClampSteFn.apply

#: Alias for :class:`InplaceTensorClampSteFn.apply(*args)
#: <brevitas.ops.autograd_ste_ops.InplaceTensorClampSteFn>`
tensor_clamp_ste_impl_ = InplaceTensorClampSteFn.apply

#: Alias for :class:`RoundToZeroSteFn.apply(*args)
#: <brevitas.ops.autograd_ste_ops.RoundToZeroSteFn>`
round_to_zero_ste_impl = RoundToZeroSteFn.apply

#: Alias for :class:`DPURoundSteFn.apply(*args)
#: <brevitas.ops.autograd_ste_ops.DPURoundSteFn>`
dpu_round_ste_impl = DPURoundSteFn.apply

#: Alias for :class:`CeilSteFn.apply(*args)
#: <brevitas.ops.autograd_ste_ops.CeilSteFn>`
ceil_ste_impl = CeilSteFn.apply

#: Alias for :class:`FloorSteFn.apply(*args)
#: <brevitas.ops.autograd_ste_ops.FloorSteFn>`
floor_ste_impl = FloorSteFn.apply

#: Alias for :class:`BinarySignSteFn.apply(*args)
#: <brevitas.ops.autograd_ste_ops.BinarySignSteFn>`
binary_sign_ste_impl = BinarySignSteFn.apply

#: Alias for :class:`TernarySignSteFn.apply(*args)
#: <brevitas.ops.autograd_ste_ops.TernarySignSteFn>`
ternary_sign_ste_impl = TernarySignSteFn.apply

#: Alias for :class:`RoundSteFn.apply(*args)
#: <brevitas.ops.autograd_ste_ops.RoundSteFn>`
round_ste_impl = RoundSteFn.apply

#: Alias for :class:`AbsBinarySignGradFn.apply(*args)
#: <brevitas.ops.autograd_ste_ops.AbsBinarySignGradFn>`
abs_binary_sign_grad_impl = AbsBinarySignGradFn.apply