Skip to content

Commit

Permalink
added normalized centroid association (#1247)
Browse files Browse the repository at this point in the history
Add centroid-based cost option
  • Loading branch information
mikel-brostrom authored Jan 12, 2024
1 parent e325453 commit 0363b48
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 15 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ jobs:
IMG: ./assets/MOT17-mini/train/MOT17-05-FRCNN/img1/000001.jpg
run: |
# deepocsort fro all supported yolo models
python examples/track.py --tracking-method deepocsort --source $IMG --imgsz 320 --reid-model examples/weights/clip_market1501.pt
python examples/track.py --tracking-method deepocsort --source $IMG --imgsz 320
python examples/track.py --yolo-model yolo_nas_s --tracking-method deepocsort --source $IMG --imgsz 320
python examples/track.py --yolo-model yolox_n --tracking-method deepocsort --source $IMG --imgsz 320
# python examples/track.py --yolo-model yolox_n --tracking-method deepocsort --source $IMG --imgsz 320
# hybridsort
python examples/track.py --tracking-method hybridsort --source $IMG --imgsz 320
Expand Down Expand Up @@ -93,7 +93,7 @@ jobs:
# test exported reid model
python examples/track.py --reid-model examples/weights/osnet_x0_25_msmt17.torchscript --source $IMG --imgsz 320
python examples/track.py --reid-model examples/weights/osnet_x0_25_msmt17.onnx --source $IMG --imgsz 320
python examples/track.py --reid-model examples/weights/osnet_x0_25_msmt17_saved_model/osnet_x0_25_msmt17_float16.tflite --source $IMG --imgsz 320
#python examples/track.py --reid-model examples/weights/osnet_x0_25_msmt17_saved_model/osnet_x0_25_msmt17_float16.tflite --source $IMG --imgsz 320
python examples/track.py --reid-model examples/weights/osnet_x0_25_msmt17_openvino_model --source $IMG --imgsz 320
- name: Test tracking with seg models
Expand Down
7 changes: 4 additions & 3 deletions boxmot/appearance/reid_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def export_torchscript(model, im, file, optimize):
def export_onnx(model, im, file, opset, dynamic, fp16, simplify):
# ONNX export
try:
__tr.check_packages(("onnx",))
# required by onnx2tf
__tr.check_packages(("onnx==1.14.0",))
import onnx

f = file.with_suffix(".onnx")
Expand Down Expand Up @@ -107,7 +108,7 @@ def export_onnx(model, im, file, opset, dynamic, fp16, simplify):

def export_openvino(file, half):
__tr.check_packages(
("openvino-dev",)
("openvino-dev>=2023.0",)
) # requires openvino-dev: https://pypi.org/project/openvino-dev/
import openvino.runtime as ov # noqa
from openvino.tools import mo # noqa
Expand Down Expand Up @@ -258,7 +259,7 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose
parser.add_argument(
"--weights",
type=Path,
default=WEIGHTS / "mobilenetv2_x1_4_dukemtmcreid.pt",
default=WEIGHTS / "osnet_x0_25_msmt17.pt",
help="model.pt path(s)",
)
parser.add_argument(
Expand Down
3 changes: 2 additions & 1 deletion boxmot/appearance/reid_multibackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def __init__(
elif self.onnx: # ONNX Runtime
LOGGER.info(f"Loading {w} for ONNX Runtime inference...")
cuda = torch.cuda.is_available() and device.type != "cpu"
tr.check_packages(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime", ))
# https://onnxruntime.ai/docs/reference/compatibility.html
tr.check_packages(("onnx", "onnxruntime-gpu==1.16.3" if cuda else "onnxruntime==1.16.3", ))
import onnxruntime

providers = (
Expand Down
28 changes: 28 additions & 0 deletions boxmot/utils/iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,40 @@ def ciou_batch(bboxes1, bboxes2):
return (ciou + 1) / 2.0 # resize from (-1,1) to (0,1)


def centroid_batch(bboxes1, bboxes2, w, h):
"""
Computes the normalized centroid distance between two sets of bounding boxes.
Bounding boxes are in the format [x1, y1, x2, y2].
`normalize_scale` is a tuple (width, height) to normalize the distance.
"""

# Calculate centroids
centroids1 = np.stack(((bboxes1[..., 0] + bboxes1[..., 2]) / 2,
(bboxes1[..., 1] + bboxes1[..., 3]) / 2), axis=-1)
centroids2 = np.stack(((bboxes2[..., 0] + bboxes2[..., 2]) / 2,
(bboxes2[..., 1] + bboxes2[..., 3]) / 2), axis=-1)

# Expand dimensions for broadcasting
centroids1 = np.expand_dims(centroids1, 1)
centroids2 = np.expand_dims(centroids2, 0)

# Calculate Euclidean distances
distances = np.sqrt(np.sum((centroids1 - centroids2) ** 2, axis=-1))

# Normalize distances
norm_factor = np.sqrt(w**2 + h**2)
normalized_distances = distances / norm_factor

return 1 - normalized_distances


def get_asso_func(asso_mode):
ASSO_FUNCS = {
"iou": iou_batch,
"giou": giou_batch,
"ciou": ciou_batch,
"diou": diou_batch,
"centroid": centroid_batch
}

return ASSO_FUNCS[asso_mode]
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

filterpy>=1.4.5 # OCSORT & DeepOCSORT

ftfy>=6.1.1 #clip
gdown>=4.7.1 # google drive model download
gdown==4.6.3 # google drive model download
GitPython>=3.1.0 # track eval cloning
lapx>=0.5.4
loguru>=0.7.0
Expand All @@ -13,7 +12,8 @@ opencv-python>=4.6.0
pandas>=1.1.4 # export matrix
pre-commit>=3.3.3
PyYAML>=5.3.1 # read tracker configs
regex>=2023.6.3 #clip
regex>=2023.6.3 # clip
ftfy>=6.1.1 # clip

scikit-learn>=1.3.0 # gsi
tensorboard>=2.13.0
Expand Down
11 changes: 6 additions & 5 deletions tests/test_exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ def test_export_openvino():
assert f is not None


def test_export_tflite(enabled=False):
f = export_tflite(
file=ONNX_WEIGHTS,
)
assert f is not None
# def test_export_tflite():
# f = export_tflite(
# file=ONNX_WEIGHTS,
# )
# print(f)
# assert f is not None

0 comments on commit 0363b48

Please sign in to comment.