# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
"""
Implementation of various functions with a straight-through gradient estimators, dispatched to
either a native just-in-time compiled backend (when env ``BREVITAS_JIT=1``) or to an autograd
Function implemented in :obj:`~brevitas.ops.autograd_ste_ops` (when env ``BREVITAS_JIT=0``).
The native backend is enabled when ``BREVITAS_JIT`` is enabled to allow for end-to-end compilation
of the built-in quantizers, since as of Pytorch 1.8.1 a torch.autograd.Function is not supported by
the compiler.
"""
import torch
from torch import Tensor
import brevitas
from brevitas.function.ops import binary_sign
from brevitas.function.ops import dpu_round
from brevitas.function.ops import round_to_zero
from brevitas.function.ops import tensor_clamp
from brevitas.function.ops import tensor_clamp_
__all__ = [
'round_ste',
'ceil_ste',
'floor_ste',
'tensor_clamp_ste',
'tensor_clamp_ste_',
'scalar_clamp_ste',
'scalar_clamp_min_ste',
'binary_sign_ste',
'ternary_sign_ste',
'round_to_zero_ste',
'dpu_round_ste',
'abs_binary_sign_grad']
if brevitas.NATIVE_STE_BACKEND_LOADED:
fn_prefix = torch
script_flag = brevitas.jit.script
else:
fn_prefix = brevitas
script_flag = torch.jit.ignore
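
# Dispatch summary (restating the module docstring): when the native backend is
# loaded, ``fn_prefix.ops.autograd_ste_ops.<name>`` resolves to the registered
# ``torch.ops.autograd_ste_ops`` extension ops and the wrappers below are scripted
# with ``brevitas.jit.script``; otherwise the Python autograd Functions in
# ``brevitas.ops.autograd_ste_ops`` are called directly and scripting is skipped
# via ``torch.jit.ignore``.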


@script_flag
def round_ste(x: Tensor) -> Tensor:
"""
Function that implements :func:`torch.round` with a straight-through gradient estimator.
Notes:
Wrapper for either :func:`~brevitas.ops.autograd_ste_ops.round_ste_impl` (with env
``BREVITAS_JIT=0``) or its native just-in-time compiled variant (with ``BREVITAS_JIT=1``).
Examples:
>>> x = torch.tensor([1.7, -1.7], requires_grad=True)
>>> y = round_ste(x)
>>> y
tensor([ 2., -2.], grad_fn=<RoundSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
"""
if torch._C._get_tracing_state():
return torch.round(x)
return fn_prefix.ops.autograd_ste_ops.round_ste_impl(x)


@script_flag
def ceil_ste(x: Tensor) -> Tensor:
"""
Function that implements :func:`torch.ceil` with a straight-through gradient estimator.
Notes:
Wrapper for either :func:`~brevitas.ops.autograd_ste_ops.ceil_ste_impl` (with env
``BREVITAS_JIT=0``) or its native just-in-time compiled variant (with ``BREVITAS_JIT=1``).
Examples:
>>> x = torch.tensor([1.7, -1.7], requires_grad=True)
>>> y = ceil_ste(x)
>>> y
tensor([ 2., -1.], grad_fn=<CeilSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
"""
if torch._C._get_tracing_state():
return torch.ceil(x)
return fn_prefix.ops.autograd_ste_ops.ceil_ste_impl(x)


@script_flag
def floor_ste(x: Tensor) -> Tensor:
"""
Function that implements :func:`torch.floor` with a straight-through gradient estimator.
Notes:
Wrapper for either :func:`~brevitas.ops.autograd_ste_ops.floor_ste_impl` (with env
``BREVITAS_JIT=0``) or its native just-in-time compiled variant (with ``BREVITAS_JIT=1``).
Examples:
>>> x = torch.tensor([1.7, -1.7], requires_grad=True)
>>> y = floor_ste(x)
>>> y
tensor([ 1., -2.], grad_fn=<FloorSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
"""
if torch._C._get_tracing_state():
return torch.floor(x)
return fn_prefix.ops.autograd_ste_ops.floor_ste_impl(x)


@script_flag
def tensor_clamp_ste(x: Tensor, min_val: Tensor, max_val: Tensor) -> Tensor:
"""
Function that implements :func:`~brevitas.function.ops.tensor_clamp` with a straight-through
gradient estimator for the gradient of ``y`` w.r.t. ``x``, while the gradient of ``y`` w.r.t.
``min_val`` and ``max_val`` is always ``None``.
Notes:
Wrapper for either :func:`~brevitas.ops.autograd_ste_ops.tensor_clamp_ste_impl` (with
env ``BREVITAS_JIT=0``) or its native just-in-time compiled variant (with
``BREVITAS_JIT=1``).
Examples:
>>> x = torch.tensor([1.5, 0.4, -1.5], requires_grad=True)
>>> y = tensor_clamp_ste(x, torch.tensor([-1.0, -0.5, -0.5]), torch.tensor([1.0, 0.5, 0.5]))
>>> y
tensor([ 1.0000, 0.4000, -0.5000], grad_fn=<TensorClampSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1, 0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
"""
if torch._C._get_tracing_state():
return tensor_clamp(x, min_val, max_val)
output = fn_prefix.ops.autograd_ste_ops.tensor_clamp_ste_impl(x, min_val, max_val)
return output


@script_flag
def tensor_clamp_ste_(x: Tensor, min_val: Tensor, max_val: Tensor) -> Tensor:
"""
Function that implements :func:`~brevitas.function.ops.tensor_clamp_` with a straight-through
gradient estimator for the gradient of ``y`` w.r.t. ``x``, while the gradient of ``y`` w.r.t.
``min_val`` and ``max_val`` is always ``None``.
Notes:
Wrapper for either :func:`~brevitas.ops.autograd_ste_ops.tensor_clamp_ste_impl_` (with
env ``BREVITAS_JIT=0``) or its C++ just-in-time compiled variant (with ``BREVITAS_JIT=1``).
Examples:
>>> x = torch.tensor([1.5, 0.4, -1.5], requires_grad=True)
>>> y = tensor_clamp_ste_(x, torch.tensor([-1.0, -0.5, -0.5]), torch.tensor([1.0, 0.5, 0.5]))
>>> y
tensor([ 1.0000, 0.4000, -0.5000], grad_fn=<InplaceTensorClampSteFnBackward>)
>>> (y == x).all().item()
True
>>> grad = torch.tensor([0.1, -0.1, 0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
"""
if torch._C._get_tracing_state():
return tensor_clamp_(x, min_val, max_val)
output = fn_prefix.ops.autograd_ste_ops.tensor_clamp_ste_impl_(x, min_val, max_val)
return output


@script_flag
def scalar_clamp_ste(x: Tensor, min_val: float, max_val: float) -> Tensor:
"""
Function that implements :func:`torch.clamp` with a straight-through gradient estimator
for the gradient of ``y`` w.r.t. ``x``, while the gradient of ``y`` w.r.t. ``min_val``
and ``max_val`` is always ``None``.
Args:
x: input tensor to clamp.
min_val: scalar value to use as lower bound for the input tensor.
max_val: scalar value to use as upper bound for the input tensor.
Returns:
Tensor: clamped output tensor.
Notes:
Wrapper for either :func:`~brevitas.ops.autograd_ste_ops.scalar_clamp_ste_impl`
(with env ``BREVITAS_JIT=0``) or its C++ just-in-time compiled variant
(with ``BREVITAS_JIT=1``).
Examples:
>>> x = torch.tensor([1.5, 0.4, -1.5], requires_grad=True)
>>> y = scalar_clamp_ste(x, -1.0, 1.0)
>>> y
tensor([ 1.0000, 0.4000, -1.0000], grad_fn=<ScalarClampSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1, 0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
"""
if torch._C._get_tracing_state():
return torch.clamp(x, min_val, max_val)
return fn_prefix.ops.autograd_ste_ops.scalar_clamp_ste_impl(x, min_val, max_val)


@script_flag
def scalar_clamp_min_ste(x: Tensor, min_val: float) -> Tensor:
"""
Function that implements :func:`torch.clamp_min` with a straight-through gradient estimator
for the gradient of ``y`` w.r.t. ``x``, while the gradient of ``y`` w.r.t. ``min_val`` is
always ``None``.
Args:
x: input tensor to clamp.
min_val: scalar value to use as lower bound for the input tensor.
Returns:
Tensor: clamped output tensor.
Notes:
Wrapper for either :func:`~brevitas.ops.autograd_ste_ops.scalar_clamp_min_ste_impl`
(with env ``BREVITAS_JIT=0``) or its C++ just-in-time compiled variant
(with ``BREVITAS_JIT=1``).
Examples:
>>> x = torch.tensor([1.5, 0.4, -1.5], requires_grad=True)
>>> y = scalar_clamp_min_ste(x, -1.0)
>>> y
tensor([ 1.5000, 0.4000, -1.0000], grad_fn=<ScalarClampMinSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1, 0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
"""
if torch._C._get_tracing_state():
return torch.clamp_min(x, min_val)
return fn_prefix.ops.autograd_ste_ops.scalar_clamp_min_ste_impl(x, min_val)


@script_flag
def binary_sign_ste(x: Tensor) -> Tensor:
"""
Function that implements :func:`~brevitas.function.ops.binary_sign` with a straight-through
gradient estimator.
Notes:
Wrapper for either :func:`~brevitas.ops.autograd_ste_ops.binary_sign_ste_impl` (with
env ``BREVITAS_JIT=0``) or its native just-in-time compiled variant (with
``BREVITAS_JIT=1``).
Examples:
>>> x = torch.tensor([1.7, 0.0, -0.5], requires_grad=True)
>>> y = binary_sign_ste(x)
>>> y
tensor([ 1., 1., -1.], grad_fn=<BinarySignSteFnBackward>)
>>> grad = torch.tensor([0.1, 0.2, -0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
"""
if torch._C._get_tracing_state():
return binary_sign(x)
return fn_prefix.ops.autograd_ste_ops.binary_sign_ste_impl(x)


@script_flag
def ternary_sign_ste(x: Tensor) -> Tensor:
"""
Function that implements :func:`torch.sign` with a straight-through gradient estimator.
Notes:
Wrapper for either :func:`~brevitas.ops.autograd_ste_ops.ternary_sign_ste_impl` (with
env ``BREVITAS_JIT=0``) or its native just-in-time compiled variant (with
``BREVITAS_JIT=1``).
Examples:
>>> x = torch.tensor([1.7, 0.0, -0.5], requires_grad=True)
>>> y = ternary_sign_ste(x)
>>> y
tensor([ 1., 0., -1.], grad_fn=<TernarySignSteFnBackward>)
>>> grad = torch.tensor([0.1, 0.2, -0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
"""
if torch._C._get_tracing_state():
return torch.sign(x)
return fn_prefix.ops.autograd_ste_ops.ternary_sign_ste_impl(x)


@script_flag
def round_to_zero_ste(x: Tensor) -> Tensor:
"""
Function that implements :func:`~brevitas.function.ops.round_to_zero` with a straight-through
gradient estimator.
Notes:
Wrapper for either :func:`~brevitas.ops.autograd_ste_ops.round_to_zero_ste_impl` (with
env ``BREVITAS_JIT=0``) or its native just-in-time compiled variant (with
``BREVITAS_JIT=1``).
Examples:
>>> x = torch.tensor([1.7, -1.7], requires_grad=True)
>>> y = round_to_zero_ste(x)
>>> y
tensor([ 1., -1.], grad_fn=<RoundToZeroSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
"""
if torch._C._get_tracing_state():
return round_to_zero(x)
return fn_prefix.ops.autograd_ste_ops.round_to_zero_ste_impl(x)


@script_flag
def dpu_round_ste(x: Tensor) -> Tensor:
"""
Function that implements :func:`~brevitas.function.ops.dpu_round` with a straight-through
gradient estimator.
Notes:
Wrapper for either :func:`~brevitas.ops.autograd_ste_ops.dpu_round_ste_impl` (with
env ``BREVITAS_JIT=0``) or its native just-in-time compiled variant (with
``BREVITAS_JIT=1``).
Examples:
>>> x = torch.tensor([1.7, -1.7], requires_grad=True)
>>> y = dpu_round_ste(x)
>>> y
tensor([ 2., -2.], grad_fn=<DPURoundSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
"""
if torch._C._get_tracing_state():
return dpu_round(x)
return fn_prefix.ops.autograd_ste_ops.dpu_round_ste_impl(x)


@script_flag
def abs_binary_sign_grad(x: Tensor) -> Tensor:
"""
Function that implements :func:`torch.abs` with a binary-sign backward, so that the
subgradient at 0 is 1, rather than the subgradient of 0 that :func:`torch.abs` defines at 0.
Notes:
Wrapper for either :func:`~brevitas.ops.autograd_ste_ops.abs_binary_sign_grad_impl`
(with env ``BREVITAS_JIT=0``) or its native just-in-time compiled variant (with
``BREVITAS_JIT=1``).
Examples:
>>> x = torch.tensor([0.0], requires_grad=True)
>>> y = abs_binary_sign_grad(x)
>>> y
tensor([0.], grad_fn=<AbsBinarySignGradFnBackward>)
>>> grad = torch.tensor([0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
"""
if torch._C._get_tracing_state():
return torch.abs(x)
return fn_prefix.ops.autograd_ste_ops.abs_binary_sign_grad_impl(x)
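

# A minimal composition sketch (illustrative only, not the library's built-in
# quantizer implementation): the primitives above can be combined into a symmetric
# uniform fake-quantizer whose gradient w.r.t. ``x`` passes straight through the
# rounding and the clamping, e.g.:
#
#   def fake_quantize(x: Tensor, scale: Tensor, bit_width: int) -> Tensor:
#       n = 2 ** (bit_width - 1) - 1
#       min_val = torch.tensor(float(-n - 1))
#       max_val = torch.tensor(float(n))
#       return tensor_clamp_ste(round_ste(x / scale), min_val, max_val) * scale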