From ee0beba06d619ff1c08bf1b7b25665d353abca20 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 6 Jan 2025 13:00:47 +0000 Subject: [PATCH] Fix more tests --- tests/brevitas/core/test_scaling_quant.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/brevitas/core/test_scaling_quant.py b/tests/brevitas/core/test_scaling_quant.py index ba3d8ef7c..d57ffbe42 100644 --- a/tests/brevitas/core/test_scaling_quant.py +++ b/tests/brevitas/core/test_scaling_quant.py @@ -112,13 +112,13 @@ def hook_scale(module, inp): inp = inp[0] quant_scale, scale, zp, bit_width = module.float_to_int_impl(inp) assert bit_width == SCALE_BIT_WIDTH - assert torch.allclose(quant_scale / scale, torch.round(quant_scale / scale)) + assert torch.allclose(quant_scale, torch.round(quant_scale)) def hook_zp(module, inp): inp = inp[0] quant_scale, scale, zp, bit_width = module.zp_int_quant(inp) assert bit_width == ZP_BIT_WIDTH - assert torch.allclose(quant_scale / scale, torch.round(quant_scale / scale)) + assert torch.allclose(quant_scale, torch.round(quant_scale)) linear = qnn.QuantLinear(64, 768, weight_quant=QuantScaleQuantZPInt8WeightPerTensorFloat) for module in linear.modules():