diff --git a/annolid/gui/app.py b/annolid/gui/app.py
index b016573..f178a2e 100644
--- a/annolid/gui/app.py
+++ b/annolid/gui/app.py
@@ -70,6 +70,7 @@ class FlexibleWorker(QtCore.QObject):
     start = QtCore.Signal()
     finished = QtCore.Signal()
     return_value = QtCore.Signal(object)
+    stop_signal = QtCore.Signal()
 
     def __init__(self, function, *args, **kwargs):
         super(FlexibleWorker, self).__init__()
@@ -77,12 +78,22 @@ def __init__(self, function, *args, **kwargs):
         self.function = function
         self.args = args
         self.kwargs = kwargs
+        self.stopped = False
+
+        self.stop_signal.connect(self.stop)
 
     def run(self):
+        self.stopped = False
         result = self.function(*self.args, **self.kwargs)
         self.return_value.emit(result)
         self.finished.emit()
 
+    def stop(self):
+        self.stopped = True
+
+    def is_stopped(self):
+        return self.stopped
+
 
 class LoadFrameThread(QtCore.QObject):
     """Thread for loading video frames.
@@ -214,6 +225,9 @@ def __init__(self,
         self.step_size = 1
         self.stepSizeWidget = StepSizeWidget()
         self.prev_shapes = None
+        self.pred_worker = None
+        # Initialize a flag to control thread termination
+        self.stop_prediction_flag = False
 
         self.canvas = self.labelList.canvas = Canvas(
             epsilon=self._config["epsilon"],
@@ -678,6 +692,16 @@ def closeFile(self, _value=False):
         self.saveButton = None
         self.playButton = None
         self.timer = None
+        self.filename = None
+        self.canvas.pixmap = None
+        self.event_type = None
+        self.highlighted_mark = None
+        self.stepSizeWidget = StepSizeWidget()
+        self.prev_shapes = None
+        self.pred_worker = None
+        self.stop_prediction_flag = False
+        self.imageData = None
+        self.frame_loader = LoadFrameThread()
 
     def toolbar(self, title, actions=None):
         toolbar = ToolBar(title)
@@ -907,6 +931,16 @@ def _select_sam_model_name(self):
 
         return model_name
 
+    def stop_prediction(self):
+        # Emit the stop signal to tell the prediction worker to stop
+        self.pred_worker.stop_signal.emit()
+        self.stepSizeWidget.predict_button.setText(
+            "Pred")  # Change button text
+        self.stepSizeWidget.predict_button.setStyleSheet(
+            "background-color: green; color: white;")
+
+        self.stop_prediction_flag = False
+
     def predict_from_next_frame(self,
                                 to_frame=60):
         if len(self.canvas.shapes) <= 0:
@@ -915,52 +949,61 @@ def predict_from_next_frame(
                 "No Shapes or Labeled Frames",
                 f"Please label this frame")
             return
-
-        model_name = self._select_sam_model_name()
-
-        if self.video_file:
-            self.video_processor = VideoProcessor(
-                self.video_file,
-                model_name=model_name,
-                save_image_to_disk=False
-            )
-            self.seg_pred_thread.start()
-            if self.step_size < 0:
-                end_frame = self.num_frames + self.step_size
-            else:
-                end_frame = self.frame_number + to_frame * self.step_size
-            if end_frame >= self.num_frames:
-                end_frame = self.num_frames - 1
-            if self.step_size < 0:
-                self.step_size = -self.step_size
-            self.pred_worker = FlexibleWorker(
-                function=self.video_processor.process_video_frames,
-                start_frame=self.frame_number+1,
-                end_frame=end_frame,
-                step=self.step_size,
-                is_cutie=model_name == "Cutie_VOS",
-                mem_every=self.step_size if self.step_size > 1 else 5
-            )
-            self.frame_number += 1
-            self.stepSizeWidget.predict_button.setEnabled(False)
-            self.pred_worker.moveToThread(self.seg_pred_thread)
-            self.pred_worker.start.connect(self.pred_worker.run)
-            self.seg_pred_thread.started.connect(self.pred_worker.start)
-            self.pred_worker.return_value.connect(self.lost_tracking_instance)
-            self.pred_worker.finished.connect(self.predict_is_ready)
-            self.seg_pred_thread.finished.connect(
-                self.seg_pred_thread.quit)
-            self.pred_worker.start.emit()
+        if self.pred_worker and self.stop_prediction_flag:
+            # If prediction is running, stop the prediction
+            self.stop_prediction()
+        else:
+            model_name = self._select_sam_model_name()
+            if self.video_file:
+                self.video_processor = VideoProcessor(
+                    self.video_file,
+                    model_name=model_name,
+                    save_image_to_disk=False
+                )
+                self.seg_pred_thread.start()
+                if self.step_size < 0:
+                    end_frame = self.num_frames + self.step_size
+                else:
+                    end_frame = self.frame_number + to_frame * self.step_size
+                if end_frame >= self.num_frames:
+                    end_frame = self.num_frames - 1
+                if self.step_size < 0:
+                    self.step_size = -self.step_size
+                self.pred_worker = FlexibleWorker(
+                    function=self.video_processor.process_video_frames,
+                    start_frame=self.frame_number+1,
+                    end_frame=end_frame,
+                    step=self.step_size,
+                    is_cutie=True,
+                    mem_every=self.step_size
+                )
+                self.video_processor.set_pred_worker(self.pred_worker)
+                self.frame_number += 1
+                self.stepSizeWidget.predict_button.setText(
+                    "Stop")  # Change button text
+                self.stepSizeWidget.predict_button.setStyleSheet(
+                    "background-color: red; color: white;")
+                self.stop_prediction_flag = True
+                self.pred_worker.moveToThread(self.seg_pred_thread)
+                self.pred_worker.start.connect(self.pred_worker.run)
+                self.seg_pred_thread.started.connect(self.pred_worker.start)
+                self.pred_worker.return_value.connect(
+                    self.lost_tracking_instance)
+                self.pred_worker.finished.connect(self.predict_is_ready)
+                self.seg_pred_thread.finished.connect(
+                    self.seg_pred_thread.quit)
+                self.pred_worker.start.emit()
 
     def lost_tracking_instance(self, message):
         if message is None:
             return
         message, current_frame_index = message.split("#")
         current_frame_index = int(current_frame_index)
-        QtWidgets.QMessageBox.information(
-            self, "Stop early",
-            message
-        )
+        if "missing instance(s)" in message:
+            QtWidgets.QMessageBox.information(
+                self, "Stop early",
+                message
+            )
         self.stepSizeWidget.predict_button.setEnabled(True)
         self.set_frame_number(current_frame_index)
diff --git a/annolid/gui/widgets/step_size_widget.py b/annolid/gui/widgets/step_size_widget.py
index b11d451..c1acbf2 100644
--- a/annolid/gui/widgets/step_size_widget.py
+++ b/annolid/gui/widgets/step_size_widget.py
@@ -19,6 +19,8 @@ def __init__(self, value=1):
 
         # Predict Button
         self.predict_button = QtWidgets.QPushButton("Pred")
+        self.predict_button.setStyleSheet(
+            "background-color: green; color: white;")
 
         # Connect valueChanged signal of QSpinBox to self.valueChanged
         self.step_size_spin_box.valueChanged.connect(self.emit_value_changed)
diff --git a/annolid/segmentation/SAM/edge_sam_bg.py b/annolid/segmentation/SAM/edge_sam_bg.py
index eebdbc4..fcb5048 100644
--- a/annolid/segmentation/SAM/edge_sam_bg.py
+++ b/annolid/segmentation/SAM/edge_sam_bg.py
@@ -184,6 +184,10 @@ def __init__(self,
         self.num_center_points = num_center_points
         self.center_points_dict = defaultdict()
         self.save_image_to_disk = save_image_to_disk
+        self.pred_worker = None
+
+    def set_pred_worker(self, pred_worker):
+        self.pred_worker = pred_worker
 
     def load_shapes(self, label_json_file):
         with open(label_json_file, 'r') as json_file:
@@ -191,7 +195,9 @@ def load_shapes(self, label_json_file):
             shapes = data.get('shapes', [])
         return shapes
 
-    def process_video_with_cutite(self, frames_to_propagate=100, mem_every=5):
+    def process_video_with_cutite(self, frames_to_propagate=100,
+                                  mem_every=5
+                                  ):
         self.most_recent_file = self.get_most_recent_file()
         label_name_to_value = {"_background_": 0}
         frame_number = int(
@@ -215,7 +221,8 @@ def process_video_with_cutite(self, frames_to_propagate=100, mem_every=5):
             mask,
             frames_to_propagate=frames_to_propagate,
             visualize_every=20,
-            labels_dict=label_name_to_value
+            labels_dict=label_name_to_value,
+            pred_worker=self.pred_worker
         )
         return message
 
@@ -422,17 +429,20 @@ def process_video_frames(self,
         - end_frame (int): Ending frame number.
         - step (int): Step between frames.
         """
-        if is_cutie:
-            # always predict to the end of the video
-            end_frame = self.num_frames
-            message = self.process_video_with_cutite(
-                frames_to_propagate=end_frame, mem_every=mem_every)
-            return message
-        else:
-            if end_frame is None:
+        while not self.pred_worker.is_stopped():
+            if is_cutie:
+                # always predict to the end of the video
                 end_frame = self.num_frames
-            for i in range(start_frame, end_frame + 1, step):
-                self.process_frame(i)
+                message = self.process_video_with_cutite(
+                    frames_to_propagate=end_frame,
+                    mem_every=mem_every
+                )
+                return message
+            else:
+                if end_frame is None:
+                    end_frame = self.num_frames
+                for i in range(start_frame, end_frame + 1, step):
+                    self.process_frame(i)
 
     def get_most_recent_file(self):
         """
diff --git a/annolid/segmentation/cutie_vos/predict.py b/annolid/segmentation/cutie_vos/predict.py
index c60790e..ccaa9de 100644
--- a/annolid/segmentation/cutie_vos/predict.py
+++ b/annolid/segmentation/cutie_vos/predict.py
@@ -20,6 +20,7 @@
 from pathlib import Path
 import gdown
 from annolid.utils.devices import get_device
+from labelme.logger import logger
 
 """
 References:
@@ -67,6 +68,8 @@ def _initialize_model(self):
         with open_dict(cfg):
             cfg['weights'] = model_path
             cfg['mem_every'] = self.mem_every
+        logger.info(
+            f"Saving into working memory every {self.mem_every} frames.")
         cutie_model = CUTIE(cfg).to(self.device).eval()
         model_weights = torch.load(
             cfg.weights, map_location=self.device)
@@ -93,7 +96,9 @@ def process_video_with_mask(self, frame_number=0,
                                 mask=None,
                                 frames_to_propagate=60,
                                 visualize_every=30,
-                                labels_dict=None):
+                                labels_dict=None,
+                                pred_worker=None,
+                                ):
         if mask is not None:
             num_objects = len(np.unique(mask)) - 1
             self.num_tracking_instances = num_objects
@@ -101,6 +106,8 @@ def process_video_with_mask(self, frame_number=0,
         cap = cv2.VideoCapture(self.video_name)
         value_to_label_names = {
             v: k for k, v in labels_dict.items()} if labels_dict else {}
+        instance_names = set(labels_dict.keys())
+        instance_names.remove('_background_')
         # Get the total number of frames
         total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
         if frame_number == total_frames - 1:
@@ -110,52 +117,69 @@ def process_video_with_mask(self, frame_number=0,
         end_frame_number = frame_number + frames_to_propagate
         current_frame_index = frame_number
 
+        delimiter = '#'
+
         with torch.inference_mode():
             with torch.cuda.amp.autocast(enabled=self.device == 'cuda'):
                 while cap.isOpened():
-                    cap.set(cv2.CAP_PROP_POS_FRAMES, current_frame_index)
-                    _, frame = cap.read()
-                    if frame is None or current_frame_index > end_frame_number:
-                        break
-                    frame_torch = image_to_torch(frame, device=self.device)
-                    if (current_frame_index == 0 or
-                        (current_frame_index == frame_number == 1) or
-                        (frame_number > 1 and
-                         current_frame_index % frame_number == 0)):
-                        mask_torch = index_numpy_to_one_hot_torch(
-                            mask, num_objects + 1).to(self.device)
-                        prediction = self.processor.step(
-                            frame_torch, mask_torch[1:], idx_mask=False)
-                    else:
-                        prediction = self.processor.step(frame_torch)
-                    prediction = torch_prob_to_numpy_mask(prediction)
-                    filename = self.video_folder / \
-                        (self.video_folder.name +
-                         f"_{current_frame_index:0>{9}}.json")
-                    mask_dict = {value_to_label_names.get(label_id, str(label_id)): (prediction == label_id)
-                                 for label_id in np.unique(prediction)[1:]}
-                    self._save_annotation(filename, mask_dict, frame.shape)
-                    # if we lost tracking one of the instances, return the current frame number
-                    num_instances_in_current_frame = mask_dict.keys()
-                    if len(num_instances_in_current_frame) < self.num_tracking_instances:
-                        delimiter = '#'
-                        message = (
-                            f"There are {self.num_tracking_instances - len(num_instances_in_current_frame)} "
-                            f"missing instance(s) in the current frame ({current_frame_index}).\n\n"
-                            f"Here is the list of instances detected in the current frame:\n"
-                            f"{', '.join(str(instance) for instance in num_instances_in_current_frame)}"
-                        )
-                        message_with_index = message + \
-                            delimiter + str(current_frame_index)
-                        return message_with_index
-                    if self.debug and current_frame_index % visualize_every == 0:
-                        visualization = overlay_davis(frame, prediction)
-                        plt.imshow(
-                            cv2.cvtColor(visualization, cv2.COLOR_BGR2RGB))
-                        plt.title(str(current_frame_index))
-                        plt.axis('off')
-                        plt.show()
-                    current_frame_index += 1
+                    while not pred_worker.is_stopped():
+                        cap.set(cv2.CAP_PROP_POS_FRAMES, current_frame_index)
+                        _, frame = cap.read()
+                        if frame is None or current_frame_index > end_frame_number:
+                            break
+                        frame_torch = image_to_torch(frame, device=self.device)
+                        if (current_frame_index == 0 or
+                            (current_frame_index == frame_number == 1) or
+                            (frame_number > 1 and
+                             current_frame_index % frame_number == 0)):
+                            mask_torch = index_numpy_to_one_hot_torch(
+                                mask, num_objects + 1).to(self.device)
+                            prediction = self.processor.step(
+                                frame_torch, mask_torch[1:], idx_mask=False)
+                        else:
+                            prediction = self.processor.step(frame_torch)
+                        prediction = torch_prob_to_numpy_mask(prediction)
+                        filename = self.video_folder / \
+                            (self.video_folder.name +
+                             f"_{current_frame_index:0>{9}}.json")
+                        mask_dict = {value_to_label_names.get(label_id, str(label_id)): (prediction == label_id)
+                                     for label_id in np.unique(prediction)[1:]}
+                        self._save_annotation(filename, mask_dict, frame.shape)
+                        # if we lost tracking one of the instances, return the current frame number
+                        num_instances_in_current_frame = mask_dict.keys()
+                        if len(num_instances_in_current_frame) < self.num_tracking_instances:
+                            missing_instances = instance_names - \
+                                set(num_instances_in_current_frame)
+                            num_missing_instances = self.num_tracking_instances - \
+                                len(num_instances_in_current_frame)
+                            message = (
+                                f"There are {num_missing_instances} missing instance(s) in the current frame ({current_frame_index}).\n\n"
+                                f"Here is the list of instances missing or occluded in the current frame:\n"
+                                f"Some occluded instances may be recovered automatically in later frames:\n"
+                                f"{', '.join(str(instance) for instance in missing_instances)}"
+                            )
+                            message_with_index = message + \
+                                delimiter + str(current_frame_index)
+                            logger.info(message)
+                            # If more than half of the instances are missing, stop the prediction.
+                            if len(num_instances_in_current_frame) < self.num_tracking_instances / 2:
+                                return message_with_index
+
+                        if self.debug and current_frame_index % visualize_every == 0:
+                            visualization = overlay_davis(frame, prediction)
+                            plt.imshow(
+                                cv2.cvtColor(visualization, cv2.COLOR_BGR2RGB))
+                            plt.title(str(current_frame_index))
+                            plt.axis('off')
+                            plt.show()
+                        current_frame_index += 1
+                    break
+
+        message = ("Stopped at frame:\n") + \
+            delimiter + str(current_frame_index-1)
+        # Release the video capture object
+        cap.release()
+        return message
 
 
 if __name__ == '__main__':
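
Reviewer note (not part of the patch): the cooperative-stop pattern above is easier to see outside the GUI. Below is a minimal, self-contained sketch of the same idea, assuming a Qt5-style binding through qtpy (which labelme/annolid already depend on); SketchWorker, the step count, and the one-second timer are illustrative choices, not annolid APIs. The sketch connects stop_signal with a DirectConnection so the flag is set from the emitting thread even while run() keeps the worker thread busy; the loop then notices the flag at its next is_stopped() check, mirroring how process_video_frames and process_video_with_mask poll pred_worker.is_stopped() between frames.

# sketch_stop_worker.py -- illustrative only; names here are hypothetical.
import sys
import time

from qtpy import QtCore


class SketchWorker(QtCore.QObject):
    finished = QtCore.Signal()
    stop_signal = QtCore.Signal()

    def __init__(self, n_steps=100):
        super().__init__()
        self.n_steps = n_steps
        self.stopped = False
        # DirectConnection: the flag is flipped in the emitter's thread,
        # so it takes effect even while run() occupies the worker thread.
        self.stop_signal.connect(self.stop, QtCore.Qt.DirectConnection)

    def run(self):
        self.stopped = False
        for i in range(self.n_steps):
            if self.is_stopped():   # cooperative cancellation point
                break
            time.sleep(0.05)        # stands in for processing one frame
            print(f"processed step {i}")
        self.finished.emit()

    def stop(self):
        self.stopped = True

    def is_stopped(self):
        return self.stopped


if __name__ == "__main__":
    app = QtCore.QCoreApplication(sys.argv)
    thread = QtCore.QThread()
    worker = SketchWorker()
    worker.moveToThread(thread)

    thread.started.connect(worker.run)
    worker.finished.connect(thread.quit)
    worker.finished.connect(app.quit)

    # Ask the worker to stop after ~1 second, as the "Stop" button would.
    QtCore.QTimer.singleShot(1000, worker.stop_signal.emit)

    thread.start()
    app.exec_()
    thread.wait()

In the patch itself the connection type is left at the default; the sketch pins it to DirectConnection only so the single-file demo stops deterministically while the worker loop is busy.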