From 8ec2f77342fbfffc80d8feff0d6bd6777c049383 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 14 Nov 2023 15:44:36 +0000 Subject: [PATCH] Fix (tests): fix TruncAvgPoolTest --- .../brevitas/test_brevitas_avg_pool_export.py | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py index c5bff8c57..aa0d28faf 100644 --- a/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py +++ b/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py @@ -5,17 +5,16 @@ import numpy as np import pytest -from qonnx.core.datatype import DataType from qonnx.core.modelwrapper import ModelWrapper import qonnx.core.onnx_exec as oxe from qonnx.transformation.infer_datatypes import InferDataTypes from qonnx.transformation.infer_shapes import InferShapes -from qonnx.util.basic import gen_finn_dt_tensor import torch from brevitas.export import export_qonnx from brevitas.nn import TruncAvgPool2d -from brevitas.quant_tensor import QuantTensor +from brevitas.nn.quant_activation import QuantIdentity +from brevitas.nn.quant_activation import QuantReLU export_onnx_path = "test_brevitas_avg_pool_export.onnx" @@ -29,36 +28,38 @@ @pytest.mark.parametrize("idim", [7, 8]) def test_brevitas_avg_pool_export( kernel_size, stride, signed, bit_width, input_bit_width, channels, idim, request): + if signed: + quant_node = QuantIdentity( + bit_width=input_bit_width, + return_quant_tensor=True, + ) + else: + quant_node = QuantReLU( + bit_width=input_bit_width, + return_quant_tensor=True, + ) quant_avgpool = TruncAvgPool2d( kernel_size=kernel_size, stride=stride, bit_width=bit_width, float_to_int_impl_type='floor') - quant_avgpool.eval() + model_brevitas = torch.nn.Sequential(quant_node, quant_avgpool) + model_brevitas.eval() # determine input - prefix = 'INT' if signed else 'UINT' - dt_name = prefix + str(input_bit_width) - dtype = DataType[dt_name] input_shape = (1, channels, idim, idim) - input_array = gen_finn_dt_tensor(dtype, input_shape) - scale_array = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(np.float32) - input_tensor = torch.from_numpy(input_array * scale_array).float() - scale_tensor = torch.from_numpy(scale_array).float() - zp = torch.tensor(0.) - input_quant_tensor = QuantTensor( - input_tensor, scale_tensor, zp, input_bit_width, signed, training=False) + inp = torch.randn(input_shape) # export test_id = request.node.callspec.id export_path = test_id + '_' + export_onnx_path - export_qonnx(quant_avgpool, export_path=export_path, input_t=input_quant_tensor) + export_qonnx(model_brevitas, export_path=export_path, input_t=inp) model = ModelWrapper(export_path) model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) # reference brevitas output - ref_output_array = quant_avgpool(input_quant_tensor).tensor.detach().numpy() + ref_output_array = model_brevitas(inp).tensor.detach().numpy() # finn output - idict = {model.graph.input[0].name: input_array} + idict = {model.graph.input[0].name: inp.detach().numpy()} odict = oxe.execute_onnx(model, idict, True) finn_output = odict[model.graph.output[0].name] # compare outputs