diff --git a/notebooks/quantized_recurrent.ipynb b/notebooks/quantized_recurrent.ipynb index f032c442c..766e82745 100644 --- a/notebooks/quantized_recurrent.ipynb +++ b/notebooks/quantized_recurrent.ipynb @@ -636,6 +636,7 @@ "from torch.nn import RNN\n", "from brevitas.nn import QuantRNN\n", "from brevitas import config\n", + "ATOL = 1e-6\n", "\n", "config.IGNORE_MISSING_KEYS = True\n", "torch.manual_seed(123456)\n", @@ -648,12 +649,10 @@ "\n", "# Generate random input\n", "inp = torch.randn(5, 2, 10)\n", - "\n", "# Check outputs are the same\n", - "assert torch.isclose(quant_rnn(inp)[0], float_rnn(inp)[0]).all().item()\n", - "\n", + "assert torch.allclose(quant_rnn(inp)[0], float_rnn(inp)[0], atol=ATOL)\n", "# Check hidden states are the same\n", - "assert torch.isclose(quant_rnn(inp)[1], float_rnn(inp)[1]).all().item()" + "assert torch.allclose(quant_rnn(inp)[1], float_rnn(inp)[1], atol=ATOL)" ] }, {