Skip to content

Commit

Permalink
Fix (tests): fix TruncAvgPoolTest
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 14, 2023
1 parent 809fc62 commit 8ec2f77
Showing 1 changed file with 18 additions and 17 deletions.
35 changes: 18 additions & 17 deletions tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,16 @@

import numpy as np
import pytest
from qonnx.core.datatype import DataType
from qonnx.core.modelwrapper import ModelWrapper
import qonnx.core.onnx_exec as oxe
from qonnx.transformation.infer_datatypes import InferDataTypes
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.util.basic import gen_finn_dt_tensor
import torch

from brevitas.export import export_qonnx
from brevitas.nn import TruncAvgPool2d
from brevitas.quant_tensor import QuantTensor
from brevitas.nn.quant_activation import QuantIdentity
from brevitas.nn.quant_activation import QuantReLU

export_onnx_path = "test_brevitas_avg_pool_export.onnx"

Expand All @@ -29,36 +28,38 @@
@pytest.mark.parametrize("idim", [7, 8])
def test_brevitas_avg_pool_export(
kernel_size, stride, signed, bit_width, input_bit_width, channels, idim, request):
if signed:
quant_node = QuantIdentity(
bit_width=input_bit_width,
return_quant_tensor=True,
)
else:
quant_node = QuantReLU(
bit_width=input_bit_width,
return_quant_tensor=True,
)

quant_avgpool = TruncAvgPool2d(
kernel_size=kernel_size, stride=stride, bit_width=bit_width, float_to_int_impl_type='floor')
quant_avgpool.eval()
model_brevitas = torch.nn.Sequential(quant_node, quant_avgpool)
model_brevitas.eval()

# determine input
prefix = 'INT' if signed else 'UINT'
dt_name = prefix + str(input_bit_width)
dtype = DataType[dt_name]
input_shape = (1, channels, idim, idim)
input_array = gen_finn_dt_tensor(dtype, input_shape)
scale_array = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(np.float32)
input_tensor = torch.from_numpy(input_array * scale_array).float()
scale_tensor = torch.from_numpy(scale_array).float()
zp = torch.tensor(0.)
input_quant_tensor = QuantTensor(
input_tensor, scale_tensor, zp, input_bit_width, signed, training=False)
inp = torch.randn(input_shape)

# export
test_id = request.node.callspec.id
export_path = test_id + '_' + export_onnx_path
export_qonnx(quant_avgpool, export_path=export_path, input_t=input_quant_tensor)
export_qonnx(model_brevitas, export_path=export_path, input_t=inp)
model = ModelWrapper(export_path)
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())

# reference brevitas output
ref_output_array = quant_avgpool(input_quant_tensor).tensor.detach().numpy()
ref_output_array = model_brevitas(inp).tensor.detach().numpy()
# finn output
idict = {model.graph.input[0].name: input_array}
idict = {model.graph.input[0].name: inp.detach().numpy()}
odict = oxe.execute_onnx(model, idict, True)
finn_output = odict[model.graph.output[0].name]
# compare outputs
Expand Down

0 comments on commit 8ec2f77

Please sign in to comment.