From d5269b41c5489c1446d94dbb4cd504f2f4eb068e Mon Sep 17 00:00:00 2001
From: Pablo Monteagudo Lago
Date: Mon, 2 Dec 2024 11:14:16 +0000
Subject: [PATCH] Draft implementation of optimal scale

---
 src/brevitas/core/stats/stats_op.py   | 49 +++++++++++++++++++++
 tests/brevitas/core/test_opt_scale.py | 62 +++++++++++++++++++++++++++
 2 files changed, 111 insertions(+)
 create mode 100644 tests/brevitas/core/test_opt_scale.py

diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py
index ac520a707..c43c9e72e 100644
--- a/src/brevitas/core/stats/stats_op.py
+++ b/src/brevitas/core/stats/stats_op.py
@@ -201,6 +201,55 @@ def forward(self, x: Tensor):
        return torch.abs(max_val - min_val)


class OptimalIntSymmetricScale(brevitas.jit.ScriptModule):

    def __init__(self, N: int) -> None:
        super(OptimalIntSymmetricScale, self).__init__()
        # Possible quantized values are {-N, ..., 0, ..., N}
        self.N = N

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        # Number of elements in the vector
        P = len(x)
        # Sort absolute values in ascending order
        abs_x_sorted, _ = torch.sort(torch.abs(x))

        # Scales at which at least one element changes its optimal quantized value:
        # element i moves from level k to level k + 1 once the scale drops below
        # 2 * |w_i| / (2 * k + 1)
        transition_scales = 2 * abs_x_sorted.unsqueeze(0) / (
            2 * torch.arange(start=0, end=self.N).unsqueeze(1) + 1)

        # This sort can be optimised: each row of transition_scales is already sorted,
        # since it is a positive rescaling of the sorted abs_x_sorted, so merging the N
        # pre-sorted rows would reduce the cost from O(NP*log(NP)) to O(NP*log(N))
        _, scales_sorting_indices = torch.sort(transition_scales.view(-1))

        # Book-keeping values for determining the optimal scale
        sum_w_q = 0
        sum_q_squared = 0

        optimal_scale = None
        optimal_neg_error = float('-inf')

        # Sweep the transition scales in descending order, updating the running sums every
        # time a quantized assignment changes and keeping the scale with the lowest loss
        for j in reversed(range(P * self.N)):
            # Recover the (level, element) pair corresponding to this transition
            k, i = scales_sorting_indices[j] // P, scales_sorting_indices[j] % P
            # Update the running sums to account for element i moving from level k to k + 1
            sum_w_q -= abs_x_sorted[i] * k
            sum_q_squared -= torch.square(k)
            sum_w_q += abs_x_sorted[i] * (k + 1)
            sum_q_squared += torch.square(k + 1)

            # For a fixed assignment q, the optimal scale is sum(w * q) / sum(q ** 2) and
            # the residual error is ||w||^2 - (sum(w * q))^2 / sum(q ** 2), so maximizing
            # sum(w * q) / sqrt(sum(q ** 2)) minimizes the quantization error
            neg_error = sum_w_q / torch.sqrt(sum_q_squared)

            # If the current assignment maximizes the negative error, update the optimal scale
            if neg_error > optimal_neg_error:
                optimal_neg_error = neg_error
                optimal_scale = sum_w_q / sum_q_squared

        return optimal_scale


class AbsMaxAve(brevitas.jit.ScriptModule):
    __constants__ = ['stats_reduce_dim']

diff --git a/tests/brevitas/core/test_opt_scale.py b/tests/brevitas/core/test_opt_scale.py
new file mode 100644
index 000000000..55f83b28f
--- /dev/null
+++ b/tests/brevitas/core/test_opt_scale.py
@@ -0,0 +1,62 @@
# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import pytest
import pytest_cases
import torch

from brevitas.core.stats.stats_op import OptimalIntSymmetricScale
from tests.conftest import SEED

# Number of weights to generate
P = 100
GRID_SEARCH_ITERS = 1000
ATOL = 1. / GRID_SEARCH_ITERS
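
# Worked example of the transition-scale sweep in OptimalIntSymmetricScale
# (illustrative numbers, not taken from the patch): for x = [0.2, 0.9] and N = 1,
# the transition scales 2 * |w_i| / (2 * k + 1) are [0.4, 1.8]. Sweeping downwards,
# 0.9 flips from q = 0 to q = 1 at s = 1.8, giving candidate scale 0.9 / 1 = 0.9
# and score 0.9 / sqrt(1) = 0.9; 0.2 follows at s = 0.4, giving candidate scale
# 1.1 / 2 = 0.55 and score 1.1 / sqrt(2) ~= 0.78. The first candidate wins, so
# the optimal ternary scale for this toy vector is 0.9.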


class TestScale:

    def test_optimal_scale_ternary(self):
        # Quantized values are {-1, 0, 1}
        N = 1
        # Make the test deterministic
        torch.manual_seed(SEED)
        # Generate a vector of random weights
        x = torch.rand((P,), dtype=torch.float32)

        # The optimal scale in the ternary case admits a closed-form solution,
        # see https://arxiv.org/pdf/1707.04319
        abs_sorted_x, _ = torch.sort(torch.abs(x), descending=True)
        j_optimal = torch.argmax(
            torch.cumsum(abs_sorted_x, dim=-1) / torch.sqrt(torch.arange(start=1, end=P + 1)))
        gt_optimal_scale = torch.sum(abs_sorted_x[:j_optimal + 1]) / (j_optimal + 1)

        optimal_int_symmetric_scale = OptimalIntSymmetricScale(N=N)
        optimal_scale = optimal_int_symmetric_scale(x)

        # Compare scales
        assert torch.allclose(gt_optimal_scale, optimal_scale)

    @pytest_cases.parametrize("N", [2, 3, 5])
    def test_optimal_scale_grid_search(self, N):
        # Quantized values are {-N, ..., 0, ..., N}
        # Make the test deterministic
        torch.manual_seed(SEED)
        # Generate a vector of random weights
        x = torch.rand((P,), dtype=torch.float32)

        # Compute the optimal scale
        optimal_int_symmetric_scale = OptimalIntSymmetricScale(N=N)
        optimal_scale = optimal_int_symmetric_scale(x)

        # Compare with the scale obtained via grid search
        def error_closure(scale):
            return torch.sum(torch.square(x - scale * torch.clamp(torch.round(x / scale), -N, N)))

        gt_optimal_scale = None
        gt_optimal_error = float('inf')

        # Start at i = 1 to avoid dividing by a zero scale
        for i in range(1, GRID_SEARCH_ITERS):
            curr_scale = torch.tensor(i / GRID_SEARCH_ITERS, dtype=torch.float32)
            curr_error = error_closure(curr_scale)
            if curr_error < gt_optimal_error:
                gt_optimal_error = curr_error
                gt_optimal_scale = curr_scale

        assert torch.allclose(optimal_scale, gt_optimal_scale, atol=ATOL, rtol=1e-1)
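
A minimal usage sketch of the new module (illustrative only; it assumes eager-mode
execution and the 1-D weight tensor that forward expects):

    import torch
    from brevitas.core.stats.stats_op import OptimalIntSymmetricScale

    torch.manual_seed(0)
    w = torch.randn(256)

    # N = 3 corresponds to the signed symmetric levels {-3, ..., 3}
    opt_scale_module = OptimalIntSymmetricScale(N=3)
    s = opt_scale_module(w)

    # Quantize with the computed scale and inspect the squared error
    w_q = s * torch.clamp(torch.round(w / s), -3, 3)
    print(s, torch.sum(torch.square(w - w_q)))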