Skip to content

Commit

Permalink
Fix (notebook): increase atol for asserts (#759)
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Nov 15, 2023
1 parent e207311 commit 966de12
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions notebooks/quantized_recurrent.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,7 @@
"from torch.nn import RNN\n",
"from brevitas.nn import QuantRNN\n",
"from brevitas import config\n",
"ATOL = 1e-6\n",
"\n",
"config.IGNORE_MISSING_KEYS = True\n",
"torch.manual_seed(123456)\n",
Expand All @@ -648,12 +649,10 @@
"\n",
"# Generate random input\n",
"inp = torch.randn(5, 2, 10)\n",
"\n",
"# Check outputs are the same\n",
"assert torch.isclose(quant_rnn(inp)[0], float_rnn(inp)[0]).all().item()\n",
"\n",
"assert torch.allclose(quant_rnn(inp)[0], float_rnn(inp)[0], atol=ATOL)\n",
"# Check hidden states are the same\n",
"assert torch.isclose(quant_rnn(inp)[1], float_rnn(inp)[1]).all().item()"
"assert torch.allclose(quant_rnn(inp)[1], float_rnn(inp)[1], atol=ATOL)"
]
},
{
Expand Down

0 comments on commit 966de12

Please sign in to comment.