From 848cc7d443fe1244700feb54a10a2f7d382323f6 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 20 Dec 2024 19:12:55 +0000 Subject: [PATCH] missing return --- src/brevitas/proxy/float_parameter_quant.py | 16 ++++++++-------- .../proxy/groupwise_float_parameter_quant.py | 4 ++-- .../proxy/groupwise_int_parameter_quant.py | 4 ++-- src/brevitas/proxy/parameter_quant.py | 6 +++--- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/brevitas/proxy/float_parameter_quant.py b/src/brevitas/proxy/float_parameter_quant.py index de455db31..027625f11 100644 --- a/src/brevitas/proxy/float_parameter_quant.py +++ b/src/brevitas/proxy/float_parameter_quant.py @@ -22,28 +22,28 @@ def bit_width(self): return bit_width def scale(self): - self.retrieve_attribute('scale') + return self.retrieve_attribute('scale') def zero_point(self): - self.retrieve_attribute('zero_point') + return self.retrieve_attribute('zero_point') def exponent_bit_width(self): - self.retrieve_attribute('exponent_bit_width') + return self.retrieve_attribute('exponent_bit_width') def mantissa_bit_width(self): - self.retrieve_attribute('mantissa_bit_width') + return self.retrieve_attribute('mantissa_bit_width') def exponent_bias(self): - self.retrieve_attribute('exponent_bias') + return self.retrieve_attribute('exponent_bias') def is_saturating(self): - self.retrieve_attribute('is_saturating') + return self.retrieve_attribute('saturating') def inf_values(self): - self.retrieve_attribute('inf_values') + return self.retrieve_attribute('inf_values') def nan_values(self): - self.retrieve_attribute('nan_values') + return self.retrieve_attribute('nan_values') @property def is_ocp(self): diff --git a/src/brevitas/proxy/groupwise_float_parameter_quant.py b/src/brevitas/proxy/groupwise_float_parameter_quant.py index 7c55c1958..12aacd23b 100644 --- a/src/brevitas/proxy/groupwise_float_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_float_parameter_quant.py @@ -15,10 +15,10 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: self.cache_class = _CachedIOGroupwiseFloat def scale_(self): - self.retrieve_attribute('scale_') + return self.retrieve_attribute('scale_') def zero_point_(self): - self.retrieve_attribute('zero_point_') + return self.retrieve_attribute('zero_point_') @property def group_dim(self): diff --git a/src/brevitas/proxy/groupwise_int_parameter_quant.py b/src/brevitas/proxy/groupwise_int_parameter_quant.py index a9fd169f7..905e50c52 100644 --- a/src/brevitas/proxy/groupwise_int_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_int_parameter_quant.py @@ -15,10 +15,10 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: self.cache_class = _CachedIOGroupwiseInt def scale_(self): - self.retrieve_attribute('scale_') + return self.retrieve_attribute('scale_') def zero_point_(self): - self.retrieve_attribute('zero_point_') + return self.retrieve_attribute('zero_point_') @property def group_dim(self): diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 85b74296d..2ca0afe92 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -202,13 +202,13 @@ def requires_quant_input(self): return False def scale(self): - self.retrieve_attribute('scale') + return self.retrieve_attribute('scale') def zero_point(self): - self.retrieve_attribute('zero_point') + return self.retrieve_attribute('zero_point') def bit_width(self): - self.retrieve_attribute('bit_width') + return self.retrieve_attribute('bit_width') def create_quant_tensor(self, qt_args: Tuple[Any]) -> IntQuantTensor: return IntQuantTensor(*qt_args, self.is_signed, self.training)