From 23bf3753966aba96773bb7d7be7848d18c2f71dc Mon Sep 17 00:00:00 2001 From: healthonrails Date: Tue, 5 Mar 2024 10:24:06 -0500 Subject: [PATCH] Add 'has_occlusion' parameter with default value True This commit introduces a new parameter 'has_occlusion' with a default value of True. The tracking mechanism now stops early under two conditions: when more than half of the tracking instances are lost, or when there is no occlusion detected in the video and at least one instance is lost during tracking. In other cases, the system will log a message indicating the event. --- annolid/segmentation/cutie_vos/predict.py | 24 ++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/annolid/segmentation/cutie_vos/predict.py b/annolid/segmentation/cutie_vos/predict.py index 255771a..068e60c 100644 --- a/annolid/segmentation/cutie_vos/predict.py +++ b/annolid/segmentation/cutie_vos/predict.py @@ -16,7 +16,6 @@ from hydra import compose, initialize from annolid.segmentation.cutie_vos.model.cutie import CUTIE from annolid.segmentation.cutie_vos.inference.inference_core import InferenceCore -from annolid.segmentation.cutie_vos.inference.utils.args_utils import get_dataset_cfg from pathlib import Path import gdown from annolid.utils.devices import get_device @@ -108,7 +107,8 @@ def process_video_with_mask(self, frame_number=0, labels_dict=None, pred_worker=None, recording=True, - output_video_path=None + output_video_path=None, + has_occlusion=True, ): if mask is not None: num_objects = len(np.unique(mask)) - 1 @@ -183,16 +183,18 @@ def process_video_with_mask(self, frame_number=0, message_with_index = message + \ delimiter + str(current_frame_index) logger.info(message) - # If half of the instances are missing, then stop the prediction. - if len(num_instances_in_current_frame) < self.num_tracking_instances / 2: + # Stop the prediction if more than half of the instances are missing, + # or when there is no occlusion in the video and one instance loses tracking. + if (not has_occlusion or + len(num_instances_in_current_frame) < self.num_tracking_instances / 2 + ): + pred_worker.stop_signal.emit() + # Release the video capture object + cap.release() + # Release the video writer if recording is set to True + if recording: + self.video_writer.release() return message_with_index - pred_worker.stop_signal.emit() - # Release the video capture object - cap.release() - # Release the video writer if recording is set to True - if recording: - self.video_writer.release() - return message_with_index if recording: visualization = overlay_davis(frame, prediction) # Write the frame to the video file