Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 14, 2023
1 parent d27799c commit 809fc62
Showing 1 changed file with 10 additions and 20 deletions.
30 changes: 10 additions & 20 deletions tests/brevitas_finn/brevitas_examples/test_bnn_pynq_finn_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
FC_INPUT_SIZE = (1, 1, 28, 28)
CNV_INPUT_SIZE = (1, 3, 32, 32)

MIN_INP_VAL = 0
MAX_INP_VAL = 255

MAX_WBITS = 2
MAX_ABITS = 2

Expand All @@ -48,12 +45,9 @@ def test_brevitas_fc_onnx_export_and_exec(size, wbits, abits, pretrained):
fc, _ = model_with_cfg(nname.lower(), pretrained=pretrained)
fc.eval()
# load a random int test vector
input_a = np.random.randint(MIN_INP_VAL, MAX_INP_VAL, size=FC_INPUT_SIZE).astype(np.float32)
scale = 1. / 255
input_t = torch.from_numpy(input_a * scale)
# input_qt = QuantTensor(
# input_t, scale=torch.tensor(scale), bit_width=torch.tensor(8.0), signed=False)
export_qonnx(fc, export_path=finn_onnx, input_t=input_t, input_names=['input'])
input = torch.randn(FC_INPUT_SIZE)

export_qonnx(fc, export_path=finn_onnx, input_t=input, input_names=['input'])
model = ModelWrapper(finn_onnx)
model = model.transform(GiveUniqueNodeNames())
model = model.transform(DoubleToSingleFloat())
Expand All @@ -62,11 +56,11 @@ def test_brevitas_fc_onnx_export_and_exec(size, wbits, abits, pretrained):
model = model.transform(RemoveStaticGraphInputs())

# run using FINN-based execution
input_dict = {'input': input_t}
input_dict = {'input': input.detach().numpy()}
output_dict = oxe.execute_onnx(model, input_dict)
produced = output_dict[list(output_dict.keys())[0]]
# do forward pass in PyTorch/Brevitas
expected = fc.forward(input_t).detach().numpy()
expected = fc.forward(input).detach().numpy()
assert np.isclose(produced, expected, atol=ATOL).all()


Expand All @@ -84,13 +78,9 @@ def test_brevitas_cnv_onnx_export_and_exec(wbits, abits, pretrained):
cnv, _ = model_with_cfg(nname.lower(), pretrained=pretrained)
cnv.eval()
# load a random int test vector
input_a = np.random.randint(MIN_INP_VAL, MAX_INP_VAL, size=CNV_INPUT_SIZE).astype(np.float32)
scale = 1. / 255
input_t = torch.from_numpy(input_a * scale)
# QONNX Export does not expect QuantTensor, only Tensor
input_qt = QuantTensor(
input_t, scale=torch.tensor(scale), bit_width=torch.tensor(8.0), signed=False)
export_qonnx(cnv, export_path=finn_onnx, input_t=input_qt, input_names=['input'])
input = torch.randn(CNV_INPUT_SIZE)

export_qonnx(cnv, export_path=finn_onnx, input_t=input, input_names=['input'])
model = ModelWrapper(finn_onnx)
model = model.transform(GiveUniqueNodeNames())
model = model.transform(DoubleToSingleFloat())
Expand All @@ -99,9 +89,9 @@ def test_brevitas_cnv_onnx_export_and_exec(wbits, abits, pretrained):
model = model.transform(RemoveStaticGraphInputs())

# run using FINN-based execution
input_dict = {"input": input_a}
input_dict = {"input": input.detach().numpy()}
output_dict = oxe.execute_onnx(model, input_dict)
produced = output_dict[list(output_dict.keys())[0]]
# do forward pass in PyTorch/Brevitas
expected = cnv(input_t).detach().numpy()
expected = cnv(input).detach().numpy()
assert np.isclose(produced, expected, atol=ATOL).all()

0 comments on commit 809fc62

Please sign in to comment.