Improve prediction button behavior:

- Show the prediction button in green by default.
- While a prediction is running, turn it into a red Stop button.
- Clicking Stop halts the prediction and moves to the last predicted frame (a minimal sketch of this mechanism follows below).
- Display a pop-up confirming that the stop succeeded and that the predictions are ready.

Additionally:
- Properly close video-related resources when closing a file to work on a new video.
- Log messages listing which prediction instances are missing relative to the provided labeled instances.
- Note that occluded instances may be recovered automatically by the model in later frames.
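For context, here is a minimal, self-contained sketch of the toggle-and-stop pattern this commit introduces. It assumes the qtpy binding used by annolid's GUI stack; Worker and Demo are hypothetical stand-ins rather than the app's classes, and the real code routes the stop request through a dedicated stop_signal instead of a direct method call:

from qtpy import QtCore, QtWidgets


class Worker(QtCore.QObject):
    """Hypothetical long-running task with cooperative cancellation."""
    finished = QtCore.Signal()

    def __init__(self):
        super().__init__()
        self.stopped = False

    def run(self):
        # Poll the flag between units of work; the event loop of this thread
        # is blocked while run() executes, so polling is required.
        for _ in range(500):
            if self.stopped:
                break
            QtCore.QThread.msleep(10)  # stand-in for predicting one frame
        self.finished.emit()

    def stop(self):
        # A plain attribute write is safe to read from the worker thread
        # under the GIL; the commit uses a stop_signal for the same effect.
        self.stopped = True


class Demo(QtWidgets.QWidget):
    def __init__(self):
        super().__init__()
        self.button = QtWidgets.QPushButton("Pred", self)
        self.button.setStyleSheet("background-color: green; color: white;")
        self.button.clicked.connect(self.toggle)
        self.running = False

    def toggle(self):
        if self.running:
            self.worker.stop()  # halt at the next poll
            return
        self.thread = QtCore.QThread()
        self.worker = Worker()
        self.worker.moveToThread(self.thread)
        self.thread.started.connect(self.worker.run)
        self.worker.finished.connect(self.on_finished)
        self.button.setText("Stop")
        self.button.setStyleSheet("background-color: red; color: white;")
        self.running = True
        self.thread.start()

    def on_finished(self):
        self.thread.quit()
        self.thread.wait()
        self.button.setText("Pred")
        self.button.setStyleSheet("background-color: green; color: white;")
        self.running = False


if __name__ == "__main__":
    app = QtWidgets.QApplication([])
    widget = Demo()
    widget.show()
    app.exec_()
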
healthonrails committed Feb 29, 2024
1 parent de35f2a commit 1040a1f
Showing 4 changed files with 175 additions and 96 deletions.
123 changes: 83 additions & 40 deletions annolid/gui/app.py
@@ -70,19 +70,30 @@ class FlexibleWorker(QtCore.QObject):
start = QtCore.Signal()
finished = QtCore.Signal()
return_value = QtCore.Signal(object)
stop_signal = QtCore.Signal()

def __init__(self, function, *args, **kwargs):
super(FlexibleWorker, self).__init__()

self.function = function
self.args = args
self.kwargs = kwargs
self.stopped = False

self.stop_signal.connect(self.stop)

def run(self):
self.stopped = False
result = self.function(*self.args, **self.kwargs)
self.return_value.emit(result)
self.finished.emit()

def stop(self):
self.stopped = True

def is_stopped(self):
return self.stopped


class LoadFrameThread(QtCore.QObject):
"""Thread for loading video frames.
@@ -214,6 +225,9 @@ def __init__(self,
self.step_size = 1
self.stepSizeWidget = StepSizeWidget()
self.prev_shapes = None
self.pred_worker = None
# Initialize a flag to control thread termination
self.stop_prediction_flag = False

self.canvas = self.labelList.canvas = Canvas(
epsilon=self._config["epsilon"],
@@ -678,6 +692,16 @@ def closeFile(self, _value=False):
self.saveButton = None
self.playButton = None
self.timer = None
self.filename = None
self.canvas.pixmap = None
self.event_type = None
self.highlighted_mark = None
self.stepSizeWidget = StepSizeWidget()
self.prev_shapes = None
self.pred_worker = None
self.stop_prediction_flag = False
self.imageData = None
self.frame_loader = LoadFrameThread()

def toolbar(self, title, actions=None):
toolbar = ToolBar(title)
@@ -907,6 +931,16 @@ def _select_sam_model_name(self):

return model_name

def stop_prediction(self):
# Emit the stop signal to ask the prediction worker to stop
self.pred_worker.stop_signal.emit()
self.stepSizeWidget.predict_button.setText(
"Pred") # Change button text
self.stepSizeWidget.predict_button.setStyleSheet(
"background-color: green; color: white;")

self.stop_prediction_flag = False

def predict_from_next_frame(self,
to_frame=60):
if len(self.canvas.shapes) <= 0:
@@ -915,52 +949,61 @@ def predict_from_next_frame(self,
"No Shapes or Labeled Frames",
f"Please label this frame")
return

model_name = self._select_sam_model_name()

if self.video_file:
self.video_processor = VideoProcessor(
self.video_file,
model_name=model_name,
save_image_to_disk=False
)
self.seg_pred_thread.start()
if self.step_size < 0:
end_frame = self.num_frames + self.step_size
else:
end_frame = self.frame_number + to_frame * self.step_size
if end_frame >= self.num_frames:
end_frame = self.num_frames - 1
if self.step_size < 0:
self.step_size = -self.step_size
self.pred_worker = FlexibleWorker(
function=self.video_processor.process_video_frames,
start_frame=self.frame_number+1,
end_frame=end_frame,
step=self.step_size,
is_cutie=model_name == "Cutie_VOS",
mem_every=self.step_size if self.step_size > 1 else 5
)
self.frame_number += 1
self.stepSizeWidget.predict_button.setEnabled(False)
self.pred_worker.moveToThread(self.seg_pred_thread)
self.pred_worker.start.connect(self.pred_worker.run)
self.seg_pred_thread.started.connect(self.pred_worker.start)
self.pred_worker.return_value.connect(self.lost_tracking_instance)
self.pred_worker.finished.connect(self.predict_is_ready)
self.seg_pred_thread.finished.connect(
self.seg_pred_thread.quit)
self.pred_worker.start.emit()
if self.pred_worker and self.stop_prediction_flag:
# If prediction is running, stop the prediction
self.stop_prediction()
else:
model_name = self._select_sam_model_name()
if self.video_file:
self.video_processor = VideoProcessor(
self.video_file,
model_name=model_name,
save_image_to_disk=False
)
self.seg_pred_thread.start()
if self.step_size < 0:
end_frame = self.num_frames + self.step_size
else:
end_frame = self.frame_number + to_frame * self.step_size
if end_frame >= self.num_frames:
end_frame = self.num_frames - 1
if self.step_size < 0:
self.step_size = -self.step_size
self.pred_worker = FlexibleWorker(
function=self.video_processor.process_video_frames,
start_frame=self.frame_number+1,
end_frame=end_frame,
step=self.step_size,
is_cutie=True,
mem_every=self.step_size
)
self.video_processor.set_pred_worker(self.pred_worker)
self.frame_number += 1
self.stepSizeWidget.predict_button.setText(
"Stop") # Change button text
self.stepSizeWidget.predict_button.setStyleSheet(
"background-color: red; color: white;")
self.stop_prediction_flag = True
self.pred_worker.moveToThread(self.seg_pred_thread)
self.pred_worker.start.connect(self.pred_worker.run)
self.seg_pred_thread.started.connect(self.pred_worker.start)
self.pred_worker.return_value.connect(
self.lost_tracking_instance)
self.pred_worker.finished.connect(self.predict_is_ready)
self.seg_pred_thread.finished.connect(
self.seg_pred_thread.quit)
self.pred_worker.start.emit()

def lost_tracking_instance(self, message):
if message is None:
return
message, current_frame_index = message.split("#")
current_frame_index = int(current_frame_index)
QtWidgets.QMessageBox.information(
self, "Stop early",
message
)
if "missing instance(s)" in message:
QtWidgets.QMessageBox.information(
self, "Stop early",
message
)
self.stepSizeWidget.predict_button.setEnabled(True)
self.set_frame_number(current_frame_index)
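
For reference, a worked example of the end-frame arithmetic in predict_from_next_frame above, followed by a commented sketch of how the GUI wires the worker to the video processor (values are illustrative; class and method names are the ones in this diff):

# Worked example of the end-frame computation (illustrative values).
frame_number, to_frame, step_size, num_frames = 100, 60, 5, 350

if step_size < 0:
    end_frame = num_frames + step_size               # walk back from the end
else:
    end_frame = frame_number + to_frame * step_size  # 100 + 60 * 5 = 400
if end_frame >= num_frames:
    end_frame = num_frames - 1                       # clamped to 349
step_size = abs(step_size)                           # prediction steps forward

# Sketch of the wiring (assumes the classes defined in this diff):
# processor = VideoProcessor(video_file, model_name="Cutie_VOS",
#                            save_image_to_disk=False)
# worker = FlexibleWorker(function=processor.process_video_frames,
#                         start_frame=frame_number + 1, end_frame=end_frame,
#                         step=step_size, is_cutie=True, mem_every=5)
# processor.set_pred_worker(worker)  # lets the processor poll is_stopped()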

2 changes: 2 additions & 0 deletions annolid/gui/widgets/step_size_widget.py
@@ -19,6 +19,8 @@ def __init__(self, value=1):

# Predict Button
self.predict_button = QtWidgets.QPushButton("Pred")
self.predict_button.setStyleSheet(
"background-color: green; color: white;")

# Connect valueChanged signal of QSpinBox to self.valueChanged
self.step_size_spin_box.valueChanged.connect(self.emit_value_changed)
34 changes: 22 additions & 12 deletions annolid/segmentation/SAM/edge_sam_bg.py
@@ -184,14 +184,20 @@ def __init__(self,
self.num_center_points = num_center_points
self.center_points_dict = defaultdict()
self.save_image_to_disk = save_image_to_disk
self.pred_worker = None

def set_pred_worker(self, pred_worker):
self.pred_worker = pred_worker

def load_shapes(self, label_json_file):
with open(label_json_file, 'r') as json_file:
data = json.load(json_file)
shapes = data.get('shapes', [])
return shapes

def process_video_with_cutite(self, frames_to_propagate=100, mem_every=5):
def process_video_with_cutite(self, frames_to_propagate=100,
mem_every=5
):
self.most_recent_file = self.get_most_recent_file()
label_name_to_value = {"_background_": 0}
frame_number = int(
Expand All @@ -215,7 +221,8 @@ def process_video_with_cutite(self, frames_to_propagate=100, mem_every=5):
mask,
frames_to_propagate=frames_to_propagate,
visualize_every=20,
labels_dict=label_name_to_value
labels_dict=label_name_to_value,
pred_worker=self.pred_worker
)
return message

@@ -422,17 +429,20 @@ def process_video_frames(self,
- end_frame (int): Ending frame number.
- step (int): Step between frames.
"""
if is_cutie:
# always predict to the end of the video
end_frame = self.num_frames
message = self.process_video_with_cutite(
frames_to_propagate=end_frame, mem_every=mem_every)
return message
else:
if end_frame is None:
while not self.pred_worker.is_stopped():
if is_cutie:
# always predict to the end of the video
end_frame = self.num_frames
for i in range(start_frame, end_frame + 1, step):
self.process_frame(i)
message = self.process_video_with_cutite(
frames_to_propagate=end_frame,
mem_every=mem_every
)
return message
else:
if end_frame is None:
end_frame = self.num_frames
for i in range(start_frame, end_frame + 1, step):
self.process_frame(i)

def get_most_recent_file(self):
"""
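
The loop above shows the cooperative half of the stop mechanism: the processor, not the GUI, decides when to bail out between frames. A stripped-down, runnable sketch of the same pattern (WorkerStub is a hypothetical stand-in for FlexibleWorker; only the polling interface matters):

class WorkerStub:
    """Stands in for FlexibleWorker; exposes only the polling interface."""

    def __init__(self):
        self.stopped = False

    def is_stopped(self):
        return self.stopped


def process_video_frames(start_frame, end_frame, step=1, pred_worker=None):
    for i in range(start_frame, end_frame + 1, step):
        if pred_worker is not None and pred_worker.is_stopped():
            # Report the last completed frame so the GUI can jump to it.
            return "Stop at frame:\n#" + str(i - 1)
        print(f"predicting frame {i}")  # stand-in for the per-frame work
    return None


worker = WorkerStub()
print(process_video_frames(0, 9, pred_worker=worker))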
112 changes: 68 additions & 44 deletions annolid/segmentation/cutie_vos/predict.py
@@ -20,6 +20,7 @@
from pathlib import Path
import gdown
from annolid.utils.devices import get_device
from labelme.logger import logger

"""
References:
@@ -67,6 +68,8 @@ def _initialize_model(self):
with open_dict(cfg):
cfg['weights'] = model_path
cfg['mem_every'] = self.mem_every
logger.info(
f"Saving into working memory every {self.mem_every} frames.")
cutie_model = CUTIE(cfg).to(self.device).eval()
model_weights = torch.load(
cfg.weights, map_location=self.device)
@@ -93,14 +96,18 @@ def process_video_with_mask(self, frame_number=0,
mask=None,
frames_to_propagate=60,
visualize_every=30,
labels_dict=None):
labels_dict=None,
pred_worker=None,
):
if mask is not None:
num_objects = len(np.unique(mask)) - 1
self.num_tracking_instances = num_objects
self.processor = InferenceCore(self.cutie, cfg=self.cfg)
cap = cv2.VideoCapture(self.video_name)
value_to_label_names = {
v: k for k, v in labels_dict.items()} if labels_dict else {}
instance_names = set(labels_dict.keys())
instance_names.remove('_background_')
# Get the total number of frames
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
if frame_number == total_frames - 1:
@@ -110,52 +117,69 @@
end_frame_number = frame_number + frames_to_propagate
current_frame_index = frame_number

delimiter = '#'

with torch.inference_mode():
with torch.cuda.amp.autocast(enabled=self.device == 'cuda'):
while cap.isOpened():
cap.set(cv2.CAP_PROP_POS_FRAMES, current_frame_index)
_, frame = cap.read()
if frame is None or current_frame_index > end_frame_number:
break
frame_torch = image_to_torch(frame, device=self.device)
if (current_frame_index == 0 or
(current_frame_index == frame_number == 1) or
(frame_number > 1 and
current_frame_index % frame_number == 0)):
mask_torch = index_numpy_to_one_hot_torch(
mask, num_objects + 1).to(self.device)
prediction = self.processor.step(
frame_torch, mask_torch[1:], idx_mask=False)
else:
prediction = self.processor.step(frame_torch)
prediction = torch_prob_to_numpy_mask(prediction)
filename = self.video_folder / \
(self.video_folder.name +
f"_{current_frame_index:0>{9}}.json")
mask_dict = {value_to_label_names.get(label_id, str(label_id)): (prediction == label_id)
for label_id in np.unique(prediction)[1:]}
self._save_annotation(filename, mask_dict, frame.shape)
# if we lost tracking one of the instances, return the current frame number
num_instances_in_current_frame = mask_dict.keys()
if len(num_instances_in_current_frame) < self.num_tracking_instances:
delimiter = '#'
message = (
f"There are {self.num_tracking_instances - len(num_instances_in_current_frame)} "
f"missing instance(s) in the current frame ({current_frame_index}).\n\n"
f"Here is the list of instances detected in the current frame:\n"
f"{', '.join(str(instance) for instance in num_instances_in_current_frame)}"
)
message_with_index = message + \
delimiter + str(current_frame_index)
return message_with_index
if self.debug and current_frame_index % visualize_every == 0:
visualization = overlay_davis(frame, prediction)
plt.imshow(
cv2.cvtColor(visualization, cv2.COLOR_BGR2RGB))
plt.title(str(current_frame_index))
plt.axis('off')
plt.show()
current_frame_index += 1
while not pred_worker.is_stopped():
cap.set(cv2.CAP_PROP_POS_FRAMES, current_frame_index)
_, frame = cap.read()
if frame is None or current_frame_index > end_frame_number:
break
frame_torch = image_to_torch(frame, device=self.device)
if (current_frame_index == 0 or
(current_frame_index == frame_number == 1) or
(frame_number > 1 and
current_frame_index % frame_number == 0)):
mask_torch = index_numpy_to_one_hot_torch(
mask, num_objects + 1).to(self.device)
prediction = self.processor.step(
frame_torch, mask_torch[1:], idx_mask=False)
else:
prediction = self.processor.step(frame_torch)
prediction = torch_prob_to_numpy_mask(prediction)
filename = self.video_folder / \
(self.video_folder.name +
f"_{current_frame_index:0>{9}}.json")
mask_dict = {value_to_label_names.get(label_id, str(label_id)): (prediction == label_id)
for label_id in np.unique(prediction)[1:]}
self._save_annotation(filename, mask_dict, frame.shape)
# if we lost track of one of the instances, return the current frame number
num_instances_in_current_frame = mask_dict.keys()
if len(num_instances_in_current_frame) < self.num_tracking_instances:
missing_instances = instance_names - \
set(num_instances_in_current_frame)
num_missing_instances = self.num_tracking_instances - \
len(num_instances_in_current_frame)
message = (
f"There are {num_missing_instances} missing instance(s) in the current frame ({current_frame_index}).\n\n"
f"Instances missing or occluded in the current frame:\n"
f"{', '.join(str(instance) for instance in missing_instances)}\n\n"
f"Some occluded instances may be recovered automatically by the model in later frames."
)
message_with_index = message + \
delimiter + str(current_frame_index)
logger.info(message)
# Stop the prediction if more than half of the instances are missing.
if len(num_instances_in_current_frame) < self.num_tracking_instances / 2:
return message_with_index

if self.debug and current_frame_index % visualize_every == 0:
visualization = overlay_davis(frame, prediction)
plt.imshow(
cv2.cvtColor(visualization, cv2.COLOR_BGR2RGB))
plt.title(str(current_frame_index))
plt.axis('off')
plt.show()
current_frame_index += 1
break

message = ("Stop at frame:\n") + \
delimiter + str(current_frame_index-1)
# Release the video capture object
cap.release()
return message


if __name__ == '__main__':
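
Finally, the predictor and the GUI share a tiny string protocol: process_video_with_mask appends the '#' delimiter plus the frame index to its message, and lost_tracking_instance in app.py splits it back apart. Schematically (the frame index 248 is illustrative):

delimiter = "#"

# Predictor side (predict.py): append the index of the frame reached.
message_with_index = "Stop at frame:\n" + delimiter + str(248)

# GUI side (app.py, lost_tracking_instance): recover both parts.
message, current_frame_index = message_with_index.split(delimiter)
current_frame_index = int(current_frame_index)  # 248 -> set_frame_number(...)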
