From 966de1288eac946557ce880e3f8562b2422eaa79 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 15 Nov 2023 13:29:52 +0100 Subject: [PATCH] Fix (notebook): increase atol for asserts (#759) --- notebooks/quantized_recurrent.ipynb | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/notebooks/quantized_recurrent.ipynb b/notebooks/quantized_recurrent.ipynb index f032c442c..766e82745 100644 --- a/notebooks/quantized_recurrent.ipynb +++ b/notebooks/quantized_recurrent.ipynb @@ -636,6 +636,7 @@ "from torch.nn import RNN\n", "from brevitas.nn import QuantRNN\n", "from brevitas import config\n", + "ATOL = 1e-6\n", "\n", "config.IGNORE_MISSING_KEYS = True\n", "torch.manual_seed(123456)\n", @@ -648,12 +649,10 @@ "\n", "# Generate random input\n", "inp = torch.randn(5, 2, 10)\n", - "\n", "# Check outputs are the same\n", - "assert torch.isclose(quant_rnn(inp)[0], float_rnn(inp)[0]).all().item()\n", - "\n", + "assert torch.allclose(quant_rnn(inp)[0], float_rnn(inp)[0], atol=ATOL)\n", "# Check hidden states are the same\n", - "assert torch.isclose(quant_rnn(inp)[1], float_rnn(inp)[1]).all().item()" + "assert torch.allclose(quant_rnn(inp)[1], float_rnn(inp)[1], atol=ATOL)" ] }, {