Skip to content

Commit

Permalink
[Feature] Support ONNX and TensorRT exportation of RTMO models (#2597)
Browse files Browse the repository at this point in the history
* support ONNX&TensorRT exportation of RTMO

* add configs for rtmo

* replace bbox expansion factor with parameter bbox_padding

* refine code

* refine comment

* apply model.switch_to_deploy in BaseTask.build_pytorch_model

* fix lint

* add rtmo into regression test

* add rtmo with trt backend into regression test

* add rtmo into supported model list
  • Loading branch information
Ben-Louis authored Dec 14, 2023
1 parent 6ff3c93 commit 1e3d06d
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 2 deletions.
25 changes: 25 additions & 0 deletions configs/mmpose/pose-detection_rtmo_onnxruntime_dynamic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
_base_ = ['./pose-detection_static.py', '../_base_/backends/onnxruntime.py']

onnx_config = dict(
output_names=['dets', 'keypoints'],
dynamic_axes={
'input': {
0: 'batch',
},
'dets': {
0: 'batch',
},
'keypoints': {
0: 'batch'
}
})

codebase_config = dict(
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=2000,
keep_top_k=50,
background_label_id=-1,
))
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
_base_ = ['./pose-detection_static.py', '../_base_/backends/tensorrt-fp16.py']

onnx_config = dict(
output_names=['dets', 'keypoints'],
dynamic_axes={
'input': {
0: 'batch',
},
'dets': {
0: 'batch',
},
'keypoints': {
0: 'batch'
}
})

backend_config = dict(
common_config=dict(max_workspace_size=1 << 30),
model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 640, 640],
opt_shape=[1, 3, 640, 640],
max_shape=[1, 3, 640, 640])))
])

codebase_config = dict(
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=2000,
keep_top_k=50,
background_label_id=-1,
))
1 change: 1 addition & 0 deletions docs/en/04-supported-codebases/mmpose.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,4 @@ TODO
| [SimCC](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | Y |
| [RTMPose](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmpose) | PoseDetection | Y | Y | Y | N | Y |
| [YoloX-Pose](https://github.com/open-mmlab/mmpose/tree/main/projects/yolox_pose) | PoseDetection | Y | Y | N | N | Y |
| [RTMO](https://github.com/open-mmlab/mmpose/tree/dev-1.x/projects/rtmo) | PoseDetection | Y | Y | N | N | N |
1 change: 1 addition & 0 deletions docs/zh_cn/04-supported-codebases/mmpose.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,4 @@ task_processor.visualize(
| [SimCC](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | Y |
| [RTMPose](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmpose) | PoseDetection | Y | Y | Y | N | Y |
| [YoloX-Pose](https://github.com/open-mmlab/mmpose/tree/main/projects/yolox_pose) | PoseDetection | Y | Y | N | N | Y |
| [RTMO](https://github.com/open-mmlab/mmpose/tree/dev-1.x/projects/rtmo) | PoseDetection | Y | Y | N | N | N |
5 changes: 5 additions & 0 deletions mmdeploy/codebase/base/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ def build_pytorch_model(self,
if hasattr(model, 'backbone') and hasattr(model.backbone,
'switch_to_deploy'):
model.backbone.switch_to_deploy()

if hasattr(model, 'switch_to_deploy') and callable(
model.switch_to_deploy):
model.switch_to_deploy()

model = model.to(self.device)
model.eval()
return model
Expand Down
4 changes: 2 additions & 2 deletions mmdeploy/codebase/mmpose/models/heads/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import mspn_head, simcc_head, yolox_pose_head # noqa: F401,F403
from . import mspn_head, rtmo_head, simcc_head, yolox_pose_head

__all__ = ['mspn_head', 'yolox_pose_head', 'simcc_head']
__all__ = ['mspn_head', 'yolox_pose_head', 'simcc_head', 'rtmo_head']
100 changes: 100 additions & 0 deletions mmdeploy/codebase/mmpose/models/heads/rtmo_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple

import torch
from mmpose.structures.bbox import bbox_xyxy2cs
from torch import Tensor

from mmdeploy.codebase.mmdet import get_post_processing_params
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.mmcv.ops.nms import multiclass_nms
from mmdeploy.utils import Backend, get_backend


@FUNCTION_REWRITER.register_rewriter(
func_name='mmpose.models.heads.hybrid_heads.'
'rtmo_head.RTMOHead.forward')
def predict(self,
x: Tuple[Tensor],
batch_data_samples: List = [],
test_cfg: Optional[dict] = None):
"""Get predictions and transform to bbox and keypoints results.
Args:
x (Tuple[Tensor]): The input tensor from upstream network.
batch_data_samples: Batch image meta info. Defaults to None.
test_cfg: The runtime config for testing process.
Returns:
Tuple[Tensor]: Predict bbox and keypoint results.
- dets (Tensor): Predict bboxes and scores, which is a 3D Tensor,
has shape (batch_size, num_instances, 5), the last dimension 5
arrange as (x1, y1, x2, y2, score).
- pred_kpts (Tensor): Predict keypoints and scores, which is a 4D
Tensor, has shape (batch_size, num_instances, num_keypoints, 5),
the last dimension 3 arrange as (x, y, score).
"""

# deploy context
ctx = FUNCTION_REWRITER.get_context()
backend = get_backend(ctx.cfg)
deploy_cfg = ctx.cfg

cfg = self.test_cfg if test_cfg is None else test_cfg

# get predictions
cls_scores, bbox_preds, _, kpt_vis, pose_vecs = self.head_module(x)[:5]
assert len(cls_scores) == len(bbox_preds)
num_imgs = cls_scores[0].shape[0]

# flatten and concat predictions
scores = self._flatten_predictions(cls_scores).sigmoid()
flatten_bbox_preds = self._flatten_predictions(bbox_preds)
flatten_pose_vecs = self._flatten_predictions(pose_vecs)
flatten_kpt_vis = self._flatten_predictions(kpt_vis).sigmoid()
bboxes = self.decode_bbox(flatten_bbox_preds, self.flatten_priors,
self.flatten_stride)

if backend == Backend.TENSORRT:
# pad for batched_nms because its output index is filled with -1
bboxes = torch.cat(
[bboxes,
bboxes.new_zeros((bboxes.shape[0], 1, bboxes.shape[2]))],
dim=1)

scores = torch.cat(
[scores, scores.new_zeros((scores.shape[0], 1, 1))], dim=1)

# nms parameters
post_params = get_post_processing_params(deploy_cfg)
max_output_boxes_per_class = post_params.max_output_boxes_per_class
iou_threshold = cfg.get('nms_thr', post_params.iou_threshold)
score_threshold = cfg.get('score_thr', post_params.score_threshold)
pre_top_k = post_params.get('pre_top_k', -1)
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)

# do nms
_, _, nms_indices = multiclass_nms(
bboxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k,
output_index=True)

batch_inds = torch.arange(num_imgs, device=scores.device).view(-1, 1)

# filter predictions
dets = torch.cat([bboxes, scores], dim=2)
dets = dets[batch_inds, nms_indices, ...]
pose_vecs = flatten_pose_vecs[batch_inds, nms_indices, ...]
kpt_vis = flatten_kpt_vis[batch_inds, nms_indices, ...]
grids = self.flatten_priors[nms_indices, ...]

# decode keypoints
bbox_cs = torch.cat(bbox_xyxy2cs(dets[..., :4], self.bbox_padding), dim=-1)
keypoints = self.dcc.forward_test(pose_vecs, bbox_cs, grids)
pred_kpts = torch.cat([keypoints, kpt_vis.unsqueeze(-1)], dim=-1)

return dets, pred_kpts
14 changes: 14 additions & 0 deletions tests/regression/mmpose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,17 @@ models:
input_img: *img_human_pose
test_img: *img_human_pose
deploy_config: configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py

- name: RTMO
metafile: configs/body_2d_keypoint/rtmo/body7/rtmo_body7.yml
model_configs:
- configs/body_2d_keypoint/rtmo/body7/rtmo-s_8xb32-600e_body7-640x640.py
pipelines:
- convert_image:
input_img: *img_human_pose
test_img: *img_human_pose
deploy_config: configs/mmpose/pose-detection_rtmo_onnxruntime_dynamic.py
- convert_image:
input_img: *img_human_pose
test_img: *img_human_pose
deploy_config: configs/mmpose/pose-detection_rtmo_tensorrt-fp16_dynamic-640x640.py

0 comments on commit 1e3d06d

Please sign in to comment.