From 7b1f46fa36ae1dc02abdfa5ceb9aee423c1f71b2 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 13 Nov 2023 13:25:29 +0000 Subject: [PATCH] Fix (tests): adapt tests to new logic --- tests/brevitas/core/test_stats.py | 6 ++++-- tests/brevitas/fx/test_tracer.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/brevitas/core/test_stats.py b/tests/brevitas/core/test_stats.py index 24586131e..224f323fc 100644 --- a/tests/brevitas/core/test_stats.py +++ b/tests/brevitas/core/test_stats.py @@ -63,10 +63,12 @@ def test_zero_percentile(self): def test_interval_percentile(self): values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - values = torch.tensor(values) + values = torch.tensor(values, dtype=torch.float32) interval_percentile = PercentileInterval(low_percentile_q=0.01, high_percentile_q=99.9) out = interval_percentile(values) range = self.compute_percentile(values, low_q=0.01, high_q=99.9) - expected_out = torch.abs(range[1] - range[0]) + # Clamp is to make sure the lower bound is not positive to align with zero-point statistics + low_result = torch.clamp(range[0], max=torch.tensor(0.0)) + expected_out = torch.abs(range[1] - low_result) assert torch.allclose(out, expected_out) diff --git a/tests/brevitas/fx/test_tracer.py b/tests/brevitas/fx/test_tracer.py index 23a8efc95..be5698d2c 100644 --- a/tests/brevitas/fx/test_tracer.py +++ b/tests/brevitas/fx/test_tracer.py @@ -238,4 +238,4 @@ def test_quant_module(module): out = mod(x) graph_model = value_trace(mod, value_args={'x': x_trace}) graph_out = graph_model(x) - assert graph_out.value.isclose(out.value).all().item() + assert graph_out.isclose(out).all().item()