From 457b4661944012b142313b3b85039b1932c12f3f Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 21 Jan 2025 16:18:51 +0100 Subject: [PATCH] Fix (runtime_act): fix negative group_dim handling (#1157) --- src/brevitas/export/inference/handler.py | 4 ++-- src/brevitas/proxy/groupwise_float_parameter_quant.py | 2 +- src/brevitas/proxy/groupwise_float_runtime_quant.py | 2 +- src/brevitas/proxy/groupwise_int_parameter_quant.py | 2 +- src/brevitas/proxy/groupwise_int_runtime_quant.py | 2 +- src/brevitas/quant/solver/common.py | 2 +- src/brevitas/quant_tensor/groupwise_float_quant_tensor.py | 2 +- src/brevitas/quant_tensor/groupwise_int_quant_tensor.py | 2 +- src/brevitas/utils/quant_utils.py | 5 ++--- src/brevitas_examples/common/generative/quantize.py | 4 ++-- 10 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index 59944c2b0..d6646c346 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -131,7 +131,7 @@ def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]: # If we skip quant tensor, we return the flattened version of the groupwise tensor if self.skip_create_quant_tensor: - start_dim = self.group_dim if self.group_dim != -1 else -2 + start_dim = self.group_dim if self.group_dim >= 0 else self.group_dim - 1 x = x.flatten(start_dim, start_dim + 1) output_args = tuple([x] + list(other)) return output_args @@ -278,7 +278,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor]: # If we skip quant tensor, we return the flattened version of the groupwise tensor if self.skip_create_quant_tensor: - start_dim = self.group_dim if self.group_dim != -1 else -2 + start_dim = self.group_dim if self.group_dim >= 0 else self.group_dim - 1 x = x.flatten(start_dim, start_dim + 1) output_args = tuple([x] + list(other)) return output_args diff --git a/src/brevitas/proxy/groupwise_float_parameter_quant.py b/src/brevitas/proxy/groupwise_float_parameter_quant.py index 206e983b5..f50341acd 100644 --- a/src/brevitas/proxy/groupwise_float_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_float_parameter_quant.py @@ -30,7 +30,7 @@ def group_size(self): def apply_input_view(self, x): x = super().apply_input_view(x) - start_dim = self.group_dim if self.group_dim != -1 else -2 + start_dim = self.group_dim if self.group_dim >= 0 else self.group_dim - 1 return x.flatten(start_dim, start_dim + 1) def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseFloatQuantTensor: diff --git a/src/brevitas/proxy/groupwise_float_runtime_quant.py b/src/brevitas/proxy/groupwise_float_runtime_quant.py index 5d76e4635..537f5a9ff 100644 --- a/src/brevitas/proxy/groupwise_float_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_float_runtime_quant.py @@ -23,7 +23,7 @@ def group_size(self): def apply_input_view(self, x): x = super().apply_input_view(x) - start_dim = self.group_dim if self.group_dim != -1 else -2 + start_dim = self.group_dim if self.group_dim >= 0 else self.group_dim - 1 return x.flatten(start_dim, start_dim + 1) def create_quant_tensor( diff --git a/src/brevitas/proxy/groupwise_int_parameter_quant.py b/src/brevitas/proxy/groupwise_int_parameter_quant.py index 51ff97c28..d515e023f 100644 --- a/src/brevitas/proxy/groupwise_int_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_int_parameter_quant.py @@ -30,7 +30,7 @@ def group_size(self): def apply_input_view(self, x): x = super().apply_input_view(x) - start_dim = self.group_dim if self.group_dim != -1 else -2 + start_dim = self.group_dim if self.group_dim >= 0 else self.group_dim - 1 return x.flatten(start_dim, start_dim + 1) def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseIntQuantTensor: diff --git a/src/brevitas/proxy/groupwise_int_runtime_quant.py b/src/brevitas/proxy/groupwise_int_runtime_quant.py index 96d047808..de19c6b0e 100644 --- a/src/brevitas/proxy/groupwise_int_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_int_runtime_quant.py @@ -23,7 +23,7 @@ def group_size(self): def apply_input_view(self, x): x = super().apply_input_view(x) - start_dim = self.group_dim if self.group_dim != -1 else -2 + start_dim = self.group_dim if self.group_dim >= 0 else self.group_dim - 1 return x.flatten(start_dim, start_dim + 1) def create_quant_tensor( diff --git a/src/brevitas/quant/solver/common.py b/src/brevitas/quant/solver/common.py index 69b4c9438..f5d30e65e 100644 --- a/src/brevitas/quant/solver/common.py +++ b/src/brevitas/quant/solver/common.py @@ -181,7 +181,7 @@ def stats_reduce_dim(scaling_stats_op, scaling_per_output, group_dim=None): elif scaling_per_output == ScalingPerOutputType.TENSOR: return None elif scaling_per_output == ScalingPerOutputType.GROUP: - reduce_dim = group_dim + 1 if group_dim != -1 else -1 + reduce_dim = group_dim + 1 if group_dim >= 0 else group_dim return reduce_dim @value diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index 60c5ba84f..70a1fa865 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -97,7 +97,7 @@ def expand(self): @staticmethod def from_expanded(value, group_size, group_dim, compress=False): - group_dim = group_dim if group_dim != -1 else -2 + group_dim = group_dim if group_dim >= 0 else group_dim - 1 size = list(value.shape) assert size[group_dim] % group_size == 0, 'Input channel is not divisible by group size' if compress: diff --git a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py index fa7e8438e..6c6b8f8bb 100644 --- a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py @@ -83,7 +83,7 @@ def expand(self): @staticmethod def from_expanded(value, group_size, group_dim, compress=False): - group_dim = group_dim if group_dim != -1 else -2 + group_dim = group_dim if group_dim >= 0 else group_dim - 1 size = list(value.shape) assert size[group_dim] % group_size == 0, 'Input channel is not divisible by group size' if compress: diff --git a/src/brevitas/utils/quant_utils.py b/src/brevitas/utils/quant_utils.py index d0d245089..551d1c9ea 100644 --- a/src/brevitas/utils/quant_utils.py +++ b/src/brevitas/utils/quant_utils.py @@ -220,9 +220,8 @@ def float_to_int_impl_to_enum(module): def groupwise_dequant_expand(value_, scale_, zero_point_, group_dim, dequant_shape): - final_shape = dequant_shape curr_shape = value_.shape - start_dim = group_dim if group_dim != -1 else -2 + start_dim = group_dim if group_dim >= 0 else group_dim - 1 new_value = value_.flatten(start_dim, start_dim + 1) if scale_.shape != (): new_scale = scale_.expand(curr_shape).flatten(start_dim, start_dim + 1) @@ -237,7 +236,7 @@ def groupwise_dequant_expand(value_, scale_, zero_point_, group_dim, dequant_sha # First, we compute how much we padded along the group_dim shape # Then, we unbind the tensor along the group_dim shape, and drop the padded columns # Finally, we stack the remaining tensors - unpadding_shape = final_shape[group_dim] + unpadding_shape = dequant_shape[group_dim] residual = new_value.shape[group_dim] - unpadding_shape if residual > 0: diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index d845f58a6..949a93065 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -395,8 +395,8 @@ def generate_quantizers( 'group_dim': -1, 'group_size': input_group_size}) k_transposed_quant = sym_input_quant.let( **{ - 'group_dim': -1, 'group_size': input_group_size}) - v_quant = q_scaled_quant + 'group_dim': -2, 'group_size': input_group_size}) + v_quant = k_transposed_quant attn_output_weights_quant = q_scaled_quant else: q_scaled_quant = v_quant = k_transposed_quant = attn_output_weights_quant = sym_input_quant