# Source code for brevitas.function.shape
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
"""
Implementation of various functions to compute shapes that induce flattening along certain
dimensions of a tensor.
"""
from typing import Tuple
from torch import Tensor
import brevitas
# Public names exported by this module (all are shape helpers defined below).
__all__ = [
    'over_tensor',
    'over_output_channels',
    'over_batch_over_tensor',
    'over_output_features',
    'over_batch_over_output_channels',
]
@brevitas.jit.script
def over_tensor(x: Tensor) -> int:
    """
    Computes the shape s such that x.view(s) is a flat tensor.

    Args:
        x (Tensor): Input tensor. Its contents and shape are not inspected;
            the result is the same for every input.

    Returns:
        The number -1, which as a view shape flattens the whole tensor.

    Examples:
        >>> over_tensor(torch.randn([2, 3, 4, 3]))
        -1
    """
    flat_shape: int = -1
    return flat_shape
@brevitas.jit.script
def over_output_channels(x: Tensor) -> Tuple[int, int]:
    """
    Computes the shape s such that x.view(s) is a 2-dim tensor with output channels
    at dimension 0 and any other feature at dimension 1.

    Args:
        x (Tensor): Input tensor with output channels at dimension 0.

    Returns:
        A tuple (x.shape[0], -1): dimension 0 is kept as-is and all remaining
        dimensions are collapsed into dimension 1.

    Examples:
        >>> over_output_channels(torch.randn([2, 3, 4, 3]))
        (2, -1)
    """
    num_channels = x.shape[0]
    return num_channels, -1
@brevitas.jit.script
def over_batch_over_tensor(x: Tensor) -> Tuple[int, int]:
    """
    Computes the shape s such that x.view(s) is a 2-dim tensor with batches
    at dimension 0 and any other feature at dimension 1.

    Args:
        x (Tensor): Input tensor with batches at dimension 0.

    Returns:
        A tuple (x.shape[0], -1): the batch dimension is preserved and every
        other dimension is flattened into dimension 1.

    Examples:
        >>> over_batch_over_tensor(torch.randn([2, 3, 4, 3]))
        (2, -1)
    """
    batch_size = x.shape[0]
    return (batch_size, -1)
@brevitas.jit.script
def over_batch_over_output_channels(x: Tensor) -> Tuple[int, int, int]:
    """
    Computes the shape s such that x.view(s) is a 3-dim tensor with batches
    at dimension 0, output channels at dimension 1, and any other feature at dimension 2.

    Args:
        x (Tensor): Input tensor with batches at dimension 0 and output channels at
            dimension 1. It must have at least 2 dimensions, since x.shape[1] is read.

    Returns:
        A tuple containing the 3-dim shape (x.shape[0], x.shape[1], -1).

    Examples:
        >>> over_batch_over_output_channels(torch.randn([2, 3, 4, 3]))
        (2, 3, -1)
    """
    # The explicit return annotation matches the sibling shape helpers and pins the
    # TorchScript return type instead of leaving it to inference.
    return x.shape[0], x.shape[1], -1
@brevitas.jit.script
def over_output_features(x: Tensor) -> Tuple[int, int]:
    """
    Computes the shape s such that x.view(s) is a 2-dim tensor with the last
    dimension of x at dimension 1 and all the other dimensions flattened into
    dimension 0.

    Args:
        x (Tensor): Input tensor whose last dimension is kept (presumably the
            output features, per the function name).

    Returns:
        A tuple (-1, x.shape[-1]) containing the 2-dim shape; -1 stands for the
        flattened leading dimensions.

    Examples:
        >>> over_output_features(torch.randn([2, 3, 4, 3]))
        (-1, 3)

        Viewing a [2, 3, 4, 3] tensor with this shape yields a [24, 3] tensor.
    """
    return -1, x.shape[-1]