Source code for brevitas.core.stats.stats_op

# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import math
from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Parameter

import brevitas
from brevitas import config
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops import max_int

from .stats_wrapper import SCALAR_SHAPE

DEFAULT_STD_DEV_EPSILON = 1e-8


class NegativeMinOrZero(brevitas.jit.ScriptModule):
    __constants__ = ['stats_reduce_dim']

    def __init__(self, stats_reduce_dim: Optional[int] = None) -> None:
        super(NegativeMinOrZero, self).__init__()
        self.stats_reduce_dim = stats_reduce_dim
        self.zero = StatelessBuffer(torch.tensor(0.0))

    @brevitas.jit.script_method
    def forward(self, x: Tensor) -> Tensor:
        if self.stats_reduce_dim is None:
            min_val = torch.min(x)
        else:
            min_val = torch.min(x, dim=self.stats_reduce_dim)[0]
        min_val = torch.where(
            min_val <= self.zero().to(min_val.dtype), min_val, self.zero().to(min_val.dtype))
        return min_val

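# Usage sketch (illustrative, not part of the library source): the collected
# stat is the minimum clamped from above at zero, per-tensor or along the
# reduction dim.
#
# >>> NegativeMinOrZero()(torch.tensor([1.0, 2.0]))
# tensor(0.)
# >>> NegativeMinOrZero(stats_reduce_dim=1)(torch.tensor([[-1.0, 2.0], [3.0, 4.0]]))
# tensor([-1., 0.])
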
class AbsPercentile(brevitas.jit.ScriptModule):
    __constants__ = ['q', 'stats_reduce_dim']

    def __init__(
            self, high_percentile_q: float, stats_reduce_dim: Optional[int], percentile_q=None):
        super(AbsPercentile, self).__init__()
        if percentile_q is not None:
            raise RuntimeError("percentile_q is deprecated, please pass high_percentile_q.")
        assert high_percentile_q <= 100, "q has to be a percentage"
        self.q = high_percentile_q
        self.stats_reduce_dim = stats_reduce_dim

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        if self.stats_reduce_dim is None:
            # k is 1-indexed, so round away from zero
            k = int(math.floor(.01 * self.q * x.numel() + 0.5))
            result = x.abs().view(-1).kthvalue(k).values
        else:
            # assuming x is two dimensional, get the other dimension
            assert len(x.size()) == 2, "Only 2-dim input is supported."
            other_dim = abs(self.stats_reduce_dim - 1)
            dim_slice = torch.narrow(x, dim=other_dim, start=0, length=1)
            # k is 1-indexed, so round away from zero
            k = int(math.floor(.01 * self.q * dim_slice.numel() + 0.5))
            result = x.abs().kthvalue(k, dim=self.stats_reduce_dim).values
        return result

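# Usage sketch (illustrative, not part of the library source): kthvalue is
# 1-indexed, so with q=50 over 10 elements k = floor(.01 * 50 * 10 + 0.5) = 5,
# i.e. the 5th smallest absolute value.
#
# >>> AbsPercentile(50., None)(torch.arange(1., 11.))
# tensor(5.)
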
class NegativePercentileOrZero(brevitas.jit.ScriptModule):
    __constants__ = ['stats_reduce_dim', 'q']

    def __init__(self, low_percentile_q, stats_reduce_dim: Optional[int] = None) -> None:
        super(NegativePercentileOrZero, self).__init__()
        self.stats_reduce_dim = stats_reduce_dim
        self.q = low_percentile_q
        self.zero = StatelessBuffer(torch.tensor(0.0))

    @brevitas.jit.script_method
    def forward(self, x: Tensor) -> Tensor:
        if self.stats_reduce_dim is None:
            # k is 1-indexed, so round away from zero
            k = int(math.ceil(.01 * self.q * x.numel()))
            result = x.view(-1).kthvalue(k).values
        else:
            # assuming x is two dimensional, get the other dimension
            assert len(x.size()) == 2, "Only 2-dim input is supported."
            other_dim = abs(self.stats_reduce_dim - 1)
            dim_slice = torch.narrow(x, dim=other_dim, start=0, length=1)
            # k is 1-indexed, so round away from zero
            k = int(math.ceil(.01 * self.q * dim_slice.numel()))
            result = x.kthvalue(k, dim=self.stats_reduce_dim).values
        result = torch.where(
            result <= self.zero().to(result.dtype), result, self.zero().to(result.dtype))
        return result

class PercentileInterval(brevitas.jit.ScriptModule):
    __constants__ = ['stats_reduce_dim', 'low_q', 'high_q']

    def __init__(
            self,
            low_percentile_q,
            high_percentile_q,
            stats_reduce_dim: Optional[int] = None) -> None:
        super(PercentileInterval, self).__init__()
        self.stats_reduce_dim = stats_reduce_dim
        self.low_q = low_percentile_q
        self.high_q = high_percentile_q

    @brevitas.jit.script_method
    def forward(self, x: Tensor) -> Tensor:
        if self.stats_reduce_dim is None:
            low_k = int(math.ceil(.01 * self.low_q * x.numel()))
            # k is 1-indexed, so round away from zero
            high_k = int(math.floor(.01 * self.high_q * x.numel() + 0.5))
            low_result = x.view(-1).kthvalue(low_k).values
            high_result = x.view(-1).kthvalue(high_k).values
        else:
            # assuming x is two dimensional, get the other dimension
            assert len(x.size()) == 2, "Only 2-dim input is supported."
            other_dim = abs(self.stats_reduce_dim - 1)
            dim_slice = torch.narrow(x, dim=other_dim, start=0, length=1)
            low_k = int(math.ceil(.01 * self.low_q * dim_slice.numel()))
            # k is 1-indexed, so round away from zero
            high_k = int(math.floor(.01 * self.high_q * dim_slice.numel() + 0.5))
            low_result = x.kthvalue(low_k, dim=self.stats_reduce_dim).values
            high_result = x.kthvalue(high_k, dim=self.stats_reduce_dim).values
        interval = high_result - low_result
        abs_interval = torch.abs(interval)
        return abs_interval

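# Usage sketch (illustrative, not part of the library source): over 100
# evenly spaced values, low_k = ceil(.01 * 1 * 100) = 1 and
# high_k = floor(.01 * 99 * 100 + 0.5) = 99, so the interval is 99 - 1 = 98.
#
# >>> PercentileInterval(1., 99.)(torch.arange(1., 101.))
# tensor(98.)
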
class AbsMax(brevitas.jit.ScriptModule):
    __constants__ = ['stats_reduce_dim']

    def __init__(self, stats_reduce_dim: Optional[int] = None) -> None:
        super(AbsMax, self).__init__()
        self.stats_reduce_dim = stats_reduce_dim

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        if self.stats_reduce_dim is None:
            return torch.max(torch.abs(x))
        else:
            return torch.max(torch.abs(x), dim=self.stats_reduce_dim)[0]

class AbsMinMax(brevitas.jit.ScriptModule):
    __constants__ = ['stats_reduce_dim']

    def __init__(self, stats_reduce_dim: Optional[int] = None) -> None:
        super(AbsMinMax, self).__init__()
        self.stats_reduce_dim = stats_reduce_dim

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        if self.stats_reduce_dim is None:
            return torch.abs(torch.max(x) - torch.min(x))
        else:
            max_val = torch.max(x, dim=self.stats_reduce_dim)[0]
            min_val = torch.min(x, dim=self.stats_reduce_dim)[0]
            return torch.abs(max_val - min_val)

class AbsMaxAve(brevitas.jit.ScriptModule):
    __constants__ = ['stats_reduce_dim']

    def __init__(self, stats_reduce_dim: int) -> None:
        super(AbsMaxAve, self).__init__()
        self.stats_reduce_dim = stats_reduce_dim

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        return torch.mean(torch.max(torch.abs(x), dim=self.stats_reduce_dim)[0])

class AbsMaxL2(brevitas.jit.ScriptModule):
    __constants__ = ['stats_reduce_dim']

    def __init__(self, stats_reduce_dim: int) -> None:
        super(AbsMaxL2, self).__init__()
        self.stats_reduce_dim = stats_reduce_dim

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor):
        per_channel_max = torch.max(torch.abs(x), dim=self.stats_reduce_dim)[0]
        out = torch.norm(per_channel_max, p=2)
        out = out / math.sqrt(per_channel_max.view(-1).shape[0])
        return out

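# Usage sketch (illustrative, not part of the library source): this is the
# root-mean-square of the per-channel absolute maxima. With per-channel
# maxima [4., 2.], out = sqrt(16 + 4) / sqrt(2) = sqrt(10).
#
# >>> AbsMaxL2(stats_reduce_dim=1)(torch.tensor([[3., -4.], [0., 2.]]))
# tensor(3.1623)
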
class AbsAve(brevitas.jit.ScriptModule):
    __constants__ = ['stats_reduce_dim']

    def __init__(self, stats_reduce_dim: Optional[int] = None) -> None:
        super(AbsAve, self).__init__()
        self.stats_reduce_dim = stats_reduce_dim

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        if self.stats_reduce_dim is None:
            return torch.mean(torch.abs(x))
        else:
            return torch.mean(torch.abs(x), dim=self.stats_reduce_dim)

class MeanSigmaStd(brevitas.jit.ScriptModule):

    def __init__(
            self,
            sigma: float,
            stats_reduce_dim: Optional[int] = None,
            std_dev_epsilon: float = DEFAULT_STD_DEV_EPSILON) -> None:
        super(MeanSigmaStd, self).__init__()
        self.impl = _MeanSigmaStdImpl(stats_reduce_dim, std_dev_epsilon)
        self.sigma = StatelessBuffer(torch.tensor(sigma))

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        sigma = self.sigma()
        out = self.impl(x, sigma)
        return out

class _MeanSigmaStdImpl(brevitas.jit.ScriptModule):
    __constants__ = ['stats_reduce_dim', 'output_shape', 'epsilon']

    def __init__(
            self,
            stats_reduce_dim: Optional[int] = None,
            std_dev_epsilon: float = DEFAULT_STD_DEV_EPSILON) -> None:
        super(_MeanSigmaStdImpl, self).__init__()
        self.stats_reduce_dim = stats_reduce_dim
        self.epsilon = std_dev_epsilon

    @brevitas.jit.script_method
    def forward(self, x: Tensor, sigma: Tensor):
        abs_val = torch.abs(x)
        if self.stats_reduce_dim is None:
            mean_val = torch.mean(abs_val)
            std_val = torch.sqrt(torch.var(abs_val) + self.epsilon)
        else:
            mean_val = torch.mean(torch.abs(x), dim=self.stats_reduce_dim)
            std_val = torch.sqrt(torch.var(abs_val, dim=self.stats_reduce_dim) + self.epsilon)
            mean_val = mean_val.view(-1)
            std_val = std_val.view(-1)
        return mean_val + sigma * std_val

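# Usage sketch (illustrative, not part of the library source): the collected
# stat is mean(|x|) + sigma * std(|x|). For x = [1., -1., 2., -2.],
# mean(|x|) = 1.5 and std(|x|) ~= 0.5774, so with sigma = 2 the result
# is ~2.6547.
#
# >>> MeanSigmaStd(sigma=2.0)(torch.tensor([1., -1., 2., -2.]))
# tensor(2.6547)
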
class MeanLearnedSigmaStd(brevitas.jit.ScriptModule):

    def __init__(
            self,
            sigma: float,
            stats_output_shape: Tuple[int, ...],
            stats_reduce_dim: Optional[int] = None,
            std_dev_epsilon: float = DEFAULT_STD_DEV_EPSILON) -> None:
        super(MeanLearnedSigmaStd, self).__init__()
        self.impl = _MeanSigmaStdImpl(stats_reduce_dim, std_dev_epsilon)
        # Learned sigma parameter, registered as `sigma` to match the state
        # dict key used by forward and _load_from_state_dict below
        if stats_output_shape == SCALAR_SHAPE:
            self.sigma = Parameter(torch.tensor(sigma))
        else:
            self.sigma = Parameter(torch.full(stats_output_shape, sigma))

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        sigma = self.sigma.view(self.sigma.shape)  # trick to get a tensor type
        out = self.impl(x, sigma)
        return out

    def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
            error_msgs):
        value_key = prefix + 'sigma'
        retrocomp_value_key = prefix + 'learned_sigma'
        if retrocomp_value_key in state_dict:  # retrocompatibility
            state_dict[value_key] = state_dict.pop(retrocomp_value_key)
        super(MeanLearnedSigmaStd, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        sigma_key = prefix + 'sigma'
        if config.IGNORE_MISSING_KEYS and sigma_key in missing_keys:
            missing_keys.remove(sigma_key)

class KLMinimizerThreshold(torch.nn.Module):
    """
    Based on:
    https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/contrib/quantization.py
    """

    def __init__(self, signed, bit_width_impl, num_bins=1000 + 1, smoothing_eps=0.0001):
        super(KLMinimizerThreshold, self).__init__()
        self.num_bins = num_bins
        self.smoothing_eps = smoothing_eps
        self.signed = signed
        self.bit_width_impl = bit_width_impl
        self.absmax_impl = AbsMax()

    def smooth_normalize_distribution(self, p, eps):
        is_zeros = (p == 0).float()
        is_nonzeros = (p != 0).float()
        n_zeros = is_zeros.sum()
        n_nonzeros = torch.numel(p) - n_zeros
        if not n_nonzeros:
            # An all-zeros histogram cannot be normalized into a distribution
            return None
        eps1 = eps * n_zeros / n_nonzeros
        hist = p.float()
        # Add eps to each zero bin and take eps1 off each nonzero bin, so the
        # total mass is preserved
        hist += eps * is_zeros + (-eps1) * is_nonzeros
        dist = torch.distributions.categorical.Categorical(logits=hist)
        return dist

    def forward(self, x: Tensor):
        absmax = self.absmax_impl(x)
        bit_width = self.bit_width_impl()
        num_quantized_bins = max_int(self.signed, False, bit_width).int()
        thresholds = torch.zeros(
            self.num_bins // 2 + 1 - num_quantized_bins // 2, device=x.device)
        divergence = torch.zeros_like(thresholds)
        quantized_bins = torch.zeros(num_quantized_bins, device=x.device)
        hist = torch.histc(x, bins=self.num_bins, min=-absmax, max=absmax).int()
        hist_edges = torch.linspace(-absmax, absmax, self.num_bins + 1)
        for i in range(num_quantized_bins // 2, self.num_bins // 2 + 1):
            p_bin_idx_start = self.num_bins // 2 - i
            p_bin_idx_stop = self.num_bins // 2 + i + 1
            thresholds[i - num_quantized_bins // 2] = hist_edges[p_bin_idx_stop]
            sliced_nd_hist = hist[p_bin_idx_start:p_bin_idx_stop]
            p = sliced_nd_hist.clone()
            left_outlier_count = torch.sum(hist[0:p_bin_idx_start])
            p[0] += left_outlier_count
            right_outlier_count = torch.sum(hist[p_bin_idx_stop:])
            p[-1] += right_outlier_count
            is_nonzeros = (sliced_nd_hist != 0).float()
            num_merged_bins = torch.numel(p) // num_quantized_bins
            for j in range(num_quantized_bins):
                start = j * num_merged_bins
                stop = start + num_merged_bins
                quantized_bins[j] = sliced_nd_hist[start:stop].sum()
            quantized_bins[-1] += sliced_nd_hist[num_quantized_bins * num_merged_bins:].sum()
            q = torch.zeros_like(p, dtype=torch.float32, device=x.device)
            for j in range(num_quantized_bins):
                start = j * num_merged_bins
                if j == num_quantized_bins - 1:
                    stop = -1
                else:
                    stop = start + num_merged_bins
                norm = is_nonzeros[start:stop].sum()
                if norm != 0:
                    q[start:stop] = quantized_bins[j] / norm
            q[sliced_nd_hist == 0] = 0.
            p = self.smooth_normalize_distribution(p, self.smoothing_eps)
            q = self.smooth_normalize_distribution(q, self.smoothing_eps)
            if q is None:
                divergence[i - num_quantized_bins // 2] = float('inf')
            else:
                divergence[i - num_quantized_bins // 2] = torch.distributions.kl.kl_divergence(p, q)
        min_divergence_idx = torch.argmin(divergence)
        opt_threshold = thresholds[min_divergence_idx]
        return opt_threshold

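# Usage sketch (illustrative, not part of the library source): bit_width_impl
# is any callable returning the target bit-width as a tensor; a plain lambda
# stands in for brevitas' bit-width modules here.
#
# >>> kl = KLMinimizerThreshold(signed=True, bit_width_impl=lambda: torch.tensor(8.))
# >>> threshold = kl(torch.randn(100000))  # KL-optimal clipping threshold
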
class L1Norm(brevitas.jit.ScriptModule):
    """ScriptModule implementation to collect per-channel L1 normalization stats
    for weight normalization-based quantization."""
    __constants__ = ['stats_reduce_dim']

    def __init__(self, stats_reduce_dim: Optional[int] = None) -> None:
        super(L1Norm, self).__init__()
        self.stats_reduce_dim = stats_reduce_dim

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        if self.stats_reduce_dim is None:
            # Need to be able to return the max per-channel L1 norm as a scalar
            raise NotImplementedError("L1 normalization is not supported per-tensor yet.")
        else:
            return x.norm(p=1, dim=self.stats_reduce_dim, keepdim=True)

class L2Norm(brevitas.jit.ScriptModule):
    """ScriptModule implementation to collect per-channel L2 normalization stats
    for weight normalization-based quantization."""
    __constants__ = ['stats_reduce_dim']

    def __init__(self, stats_reduce_dim: Optional[int] = None) -> None:
        super(L2Norm, self).__init__()
        self.stats_reduce_dim = stats_reduce_dim

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        if self.stats_reduce_dim is None:
            # Need to be able to return the max per-channel L2 norm as a scalar
            raise NotImplementedError("L2 normalization is not supported per-tensor yet.")
        else:
            return x.norm(p=2, dim=self.stats_reduce_dim, keepdim=True)
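
# Usage sketch (illustrative, not part of the library source): keepdim=True
# preserves the reduced dimension, so a (1, 2) input yields a (1, 1) stat.
#
# >>> L2Norm(stats_reduce_dim=1)(torch.tensor([[3., 4.]]))
# tensor([[5.]])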