Skip to content

Commit

Permalink
Merge pull request #1057 from mikel-brostrom/fix-tflite-batched-infer…
Browse files Browse the repository at this point in the history
…ence

fix tflite dynamic batch inference
  • Loading branch information
mikel-brostrom authored Aug 4, 2023
2 parents 65c2186 + 35aeff4 commit eef475f
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 23 deletions.
9 changes: 8 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ jobs:
- name: Pytest tests # after tracking options as this does not download models
shell: bash # for Windows compatibility
run: |
# needed in TFLite export
wget https://github.com/PINTO0309/onnx2tf/releases/download/1.7.3/flatc.tar.gz
tar -zxvf flatc.tar.gz
sudo chmod +x flatc
sudo mv flatc /usr/bin/
pytest --cov=$PACKAGE_DIR --cov-report=html -v tests
coverage report --fail-under=$COVERAGE_FAIL_UNDER
Expand All @@ -83,7 +90,7 @@ jobs:
# test exported reid model
python examples/track.py --reid-model examples/weights/osnet_x0_25_msmt17.torchscript --source $IMG --imgsz 320
python examples/track.py --reid-model examples/weights/osnet_x0_25_msmt17.onnx --source $IMG --imgsz 320
# python examples/track.py --reid-model examples/weights/osnet_x0_25_msmt17_saved_model/osnet_x0_25_msmt17_float16.tflite --source $IMG --imgsz 320
python examples/track.py --reid-model examples/weights/osnet_x0_25_msmt17_saved_model/osnet_x0_25_msmt17_float16.tflite --source $IMG --imgsz 320
python examples/track.py --reid-model examples/weights/osnet_x0_25_msmt17_openvino_model --source $IMG --imgsz 320
- name: Test tracking with seg models
Expand Down
25 changes: 10 additions & 15 deletions boxmot/appearance/reid_multibackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,27 +167,20 @@ def __init__(
network = ie.read_model(model=w, weights=Path(w).with_suffix(".bin"))
if network.get_parameters()[0].get_layout().empty:
network.get_parameters()[0].set_layout(Layout("NCWH"))
# batch_dim = get_batch(network)
# if batch_dim.is_static:
# batch_size = batch_dim.get_length()
self.executable_network = ie.compile_model(
network, device_name="CPU"
) # device_name="MYRIAD" for Intel NCS2
self.output_layer = next(iter(self.executable_network.outputs))

elif self.tflite:
LOGGER.info(f"Loading {w} for TensorFlow Lite inference...")

import tensorflow as tf
interpreter = tf.lite.Interpreter(model_path=str(w))
print(interpreter.get_signature_list())
self.tf_lite_model = interpreter.get_signature_runner()
inputs = {
'images': np.ones([5, 256, 128, 3], dtype=np.float32),
}
tf_lite_output = self.tf_lite_model(**inputs)
print(f"[TFLite] Model Predictions shape: {tf_lite_output['output'].shape}")
print("[TFLite] Model Predictions:")
try:
self.tf_lite_model = interpreter.get_signature_runner()
except Exception as e:
LOGGER.error(f'{e}. If SignatureDef error. Export you model with the official onn2tf docker')
exit()
else:
LOGGER.error("This model framework is not supported yet!")
exit()
Expand Down Expand Up @@ -239,10 +232,12 @@ def forward(self, im_batch):
{self.session.get_inputs()[0].name: im_batch},
)[0]
elif self.tflite:
print(im_batch.shape)
im_batch = im_batch.cpu().numpy()
features = self.tf_lite_model(im_batch)
print(features)
inputs = {
'images': im_batch,
}
tf_lite_output = self.tf_lite_model(**inputs)
features = tf_lite_output['output']

elif self.engine: # TensorRT
if True and im_batch.shape != self.bindings["images"].shape:
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def get_version():
'nvidia-pyindex', # TensorRT export
'nvidia-tensorrt', # TensorRT export
'openvino-dev>=2022.3', # OpenVINO export
'onnx2tf>=1.10.0' # TFLite export
'onnx2tf>=1.10.0', # TFLite export
'onnx_graphsurgeon', # TFLite export
'sng4onnx', # TFLite export
],
'evolve': [
'optuna', # ONNX export
Expand Down
12 changes: 6 additions & 6 deletions tests/test_exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from boxmot.appearance.backbones import build_model
from boxmot.appearance.reid_export import (export_onnx, export_openvino,
export_torchscript)
export_tflite, export_torchscript)
from boxmot.appearance.reid_model_factory import (get_model_name,
load_pretrained_weights)
from boxmot.utils import WEIGHTS
Expand Down Expand Up @@ -55,8 +55,8 @@ def test_export_openvino():
assert f is not None


# def test_export_tflite(enabled=False):
# f = export_tflite(
# file=ONNX_WEIGHTS,
# )
# assert f is not None
def test_export_tflite(enabled=False):
f = export_tflite(
file=ONNX_WEIGHTS,
)
assert f is not None

0 comments on commit eef475f

Please sign in to comment.