diff --git a/tests/brevitas/core/test_scaling_quant.py b/tests/brevitas/core/test_scaling_quant.py index ba3d8ef7c..d57ffbe42 100644 --- a/tests/brevitas/core/test_scaling_quant.py +++ b/tests/brevitas/core/test_scaling_quant.py @@ -112,13 +112,13 @@ def hook_scale(module, inp): inp = inp[0] quant_scale, scale, zp, bit_width = module.float_to_int_impl(inp) assert bit_width == SCALE_BIT_WIDTH - assert torch.allclose(quant_scale / scale, torch.round(quant_scale / scale)) + assert torch.allclose(quant_scale, torch.round(quant_scale)) def hook_zp(module, inp): inp = inp[0] quant_scale, scale, zp, bit_width = module.zp_int_quant(inp) assert bit_width == ZP_BIT_WIDTH - assert torch.allclose(quant_scale / scale, torch.round(quant_scale / scale)) + assert torch.allclose(quant_scale, torch.round(quant_scale)) linear = qnn.QuantLinear(64, 768, weight_quant=QuantScaleQuantZPInt8WeightPerTensorFloat) for module in linear.modules():