Skip to content

Commit

Permalink
Introduce 'visible' Attribute to Shape Type for Improved Visualization
Browse files Browse the repository at this point in the history
This commit adds a 'visible' attribute to the Shape type, allowing points to be marked as visible or invisible. When a point is not visible, it is displayed as a circle with filled colors. Additionally, this commit removes a line of the segmentation prediction thread to prevent the background worker from starting twice.
  • Loading branch information
healthonrails committed Mar 11, 2024
1 parent c95e685 commit d2b0a3e
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 55 deletions.
3 changes: 2 additions & 1 deletion annolid/annotation/keypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import cv2
import numpy as np
import pandas as pd
from labelme import label_file
from annolid.gui import label_file
from labelme.shape import Shape


Expand Down Expand Up @@ -41,6 +41,7 @@ def format_shape(s):
group_id=s.group_id,
shape_type=s.shape_type,
flags=s.flags,
visible=s.visible
)
)
return data
Expand Down
123 changes: 70 additions & 53 deletions annolid/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def __init__(self,

self._selectAiModelComboBox.clear()
self.custom_ai_model_names = [
'SAM_HQ', 'Cutie_VOS', "EfficientVit_SAM"]
'SAM_HQ', 'Cutie_VOS', "EfficientVit_SAM", "CoTracker"]
model_names = [model.name for model in MODELS] + \
self.custom_ai_model_names
self._selectAiModelComboBox.addItems(model_names)
Expand Down Expand Up @@ -674,48 +674,49 @@ def closeFile(self, _value=False):
self.uniqLabelList.clear()
# clear the file list
self.fileListWidget.clear()
if self.video_loader is not None:
self.video_loader = None
self.num_frames = None
self.video_file = None
if self.audio_widget:
self.audio_widget.close()
self.audio_widget = None
if self.audio_dock:
self.audio_dock.close()
self.audio_dock = None
self.annotation_dir = None
self.statusBar().removeWidget(self.seekbar)
self.statusBar().removeWidget(self.saveButton)
self.statusBar().removeWidget(self.playButton)
self.seekbar = None
self._df = None
self.label_stats = {}
self.shape_hash_ids = {}
self.changed_json_stats = {}
self._pred_res_folder_suffix = '_tracking_results_labelme'
self.frame_number = 0
self.step_size = 5
self.video_results_folder = None
self.timestamp_dict = dict()
self.isPlaying = False
self._time_stamp = ''
self.saveButton = None
self.playButton = None
self.timer = None
self.filename = None
self.canvas.pixmap = None
self.event_type = None
self.highlighted_mark = None
self.stepSizeWidget = StepSizeWidget()
self.prev_shapes = None
self.pred_worker = None
self.stop_prediction_flag = False
self.imageData = None
self.frame_loader = LoadFrameThread()
if self.video_processor is not None:
self.video_processor.cutie_processor = None
self.video_processor = None
# if self.video_loader is not None:
self.video_loader = None
self.num_frames = None
self.video_file = None
if self.audio_widget:
self.audio_widget.close()
self.audio_widget = None
if self.audio_dock:
self.audio_dock.close()
self.audio_dock = None
self.annotation_dir = None
self.statusBar().removeWidget(self.seekbar)
self.statusBar().removeWidget(self.saveButton)
self.statusBar().removeWidget(self.playButton)
self.seekbar = None
self._df = None
self.label_stats = {}
self.shape_hash_ids = {}
self.changed_json_stats = {}
self._pred_res_folder_suffix = '_tracking_results_labelme'
self.frame_number = 0
self.step_size = 5
self.video_results_folder = None
self.timestamp_dict = dict()
self.isPlaying = False
self._time_stamp = ''
self.saveButton = None
self.playButton = None
self.timer = None
self.filename = None
self.canvas.pixmap = None
self.event_type = None
self.highlighted_mark = None
self.stepSizeWidget = StepSizeWidget()
self.prev_shapes = None
self.pred_worker = None
self.stop_prediction_flag = False
self.imageData = None
self.frame_loader = LoadFrameThread()
if self.video_processor is not None:
self.video_processor.cutie_processor = None
self.video_processor = None
self.fps = None

def toolbar(self, title, actions=None):
toolbar = ToolBar(title)
Expand Down Expand Up @@ -830,7 +831,10 @@ def _update_shape_color(self, shape):
self.uniqLabelList.setItemLabel(item, shape.label, rgb)
r, g, b = self._get_rgb_by_label(shape.label)
shape.line_color = QtGui.QColor(r, g, b)
shape.vertex_fill_color = QtGui.QColor(r, g, b)
if not shape.visible:
shape.vertex_fill_color = QtGui.QColor(r, g, b, 0)
else:
shape.vertex_fill_color = QtGui.QColor(r, g, b)
shape.hvertex_fill_color = QtGui.QColor(255, 255, 255)
shape.fill_color = QtGui.QColor(r, g, b, 128)
shape.select_line_color = QtGui.QColor(255, 255, 255)
Expand Down Expand Up @@ -936,7 +940,8 @@ def _select_sam_model_name(self):
model_names = {
"SAM_HQ": "sam_hq",
"EfficientVit_SAM": "efficientvit_sam",
"Cutie_VOS": "Cutie_VOS"
"Cutie_VOS": "Cutie_VOS",
"CoTracker": "CoTracker"
}
default_model_name = "Segment-Anything (Edge)"

Expand All @@ -956,6 +961,7 @@ def stop_prediction(self):
"background-color: green; color: white;")

self.stop_prediction_flag = False
logger.info(f"Prediction was stopped.")

def predict_from_next_frame(self,
to_frame=60):
Expand Down Expand Up @@ -991,19 +997,21 @@ def predict_from_next_frame(self,
start_frame=self.frame_number+1,
end_frame=end_frame,
step=self.step_size,
is_cutie=True,
mem_every=self.step_size
is_cutie=False if model_name == "CoTracker" else True,
mem_every=self.step_size,
point_tracking=model_name == "CoTracker"
)
self.video_processor.set_pred_worker(self.pred_worker)
self.frame_number += 1
logger.info(
f"Prediction started from frame number: {self.frame_number}.")
self.stepSizeWidget.predict_button.setText(
"Stop") # Change button text
self.stepSizeWidget.predict_button.setStyleSheet(
"background-color: red; color: white;")
self.stop_prediction_flag = True
self.pred_worker.moveToThread(self.seg_pred_thread)
self.pred_worker.start.connect(self.pred_worker.run)
self.seg_pred_thread.started.connect(self.pred_worker.start)
self.pred_worker.return_value.connect(
self.lost_tracking_instance)
self.pred_worker.finished.connect(self.predict_is_ready)
Expand All @@ -1030,9 +1038,6 @@ def lost_tracking_instance(self, message):
self.stop_prediction_flag = False

def predict_is_ready(self):
QtWidgets.QMessageBox.information(
self, "Prediction Ready",
"Predictions for the video frames have been generated!")
self.stepSizeWidget.predict_button.setText(
"Pred") # Change button text
self.stepSizeWidget.predict_button.setStyleSheet(
Expand All @@ -1041,9 +1046,14 @@ def predict_is_ready(self):
self.stop_prediction_flag = False
if self.video_loader is not None:
num_json_files = count_json_files(self.video_results_folder)
logger.info(
f"Number of predicted frames: {num_json_files} in total {self.num_frames}")
if num_json_files == self.num_frames:
# convert json labels to csv file
self.convert_json_to_tracked_csv()
QtWidgets.QMessageBox.information(
self, "Prediction Ready",
"Predictions for the video frames have been generated!")

def saveLabels(self, filename):
lf = LabelFile()
Expand Down Expand Up @@ -1088,6 +1098,7 @@ def format_shape(s):
otherData=self.otherData,
flags=flags,
)
logger.info(f"Saved image and label json file: {filename}")

self.labelFile = lf
items = self.fileListWidget.findItems(
Expand Down Expand Up @@ -1738,8 +1749,12 @@ def load_tracking_results(self, cur_video_folder, video_filename):
self._load_labels(tr)

if tracking_csv_file:
self._df = pd.read_csv(tracking_csv_file)
self._df = self._df.drop(columns=['Unnamed: 0'], errors='ignore')
try:
self._df = pd.read_csv(tracking_csv_file)
self._df = self._df.drop(
columns=['Unnamed: 0'], errors='ignore')
except:
logger.info(f"Error loading file {tracking_csv_file}")

def _load_timestamps(self, timestamp_csv_file):
"""Load timestamps from the given CSV file and update timestamp_dict."""
Expand Down Expand Up @@ -1956,6 +1971,7 @@ def loadLabels(self, shapes):
group_id = shape["group_id"]
description = shape.get("description", "")
other_data = shape["other_data"]
visible = shape["visible"]

if not points:
# skip point-empty shape
Expand All @@ -1967,6 +1983,7 @@ def loadLabels(self, shapes):
group_id=group_id,
description=description,
mask=shape["mask"],
visible=visible
)
for x, y in points:
shape.addPoint(QtCore.QPointF(x, y))
Expand Down
2 changes: 2 additions & 0 deletions annolid/gui/label_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def load(self, filename):
"flags",
"description",
"mask",
"visible",
]
try:
with open(filename, "r") as f:
Expand Down Expand Up @@ -119,6 +120,7 @@ def load(self, filename):
group_id=s.get("group_id"),
mask=utils.img_b64_to_arr(
s["mask"]) if s.get("mask") else None,
visible=s.get("visible"),
other_data={k: v for k,
v in s.items() if k not in shape_keys},
)
Expand Down
2 changes: 2 additions & 0 deletions annolid/gui/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
group_id=None,
description=None,
mask=None,
visible=True,
):
self.label = label
self.group_id = group_id
Expand All @@ -70,6 +71,7 @@ def __init__(
self.description = description
self.other_data = {}
self.mask = mask
self.visible = visible

self._highlightIndex = None
self._highlightMode = self.NEAR_VERTEX
Expand Down
6 changes: 5 additions & 1 deletion annolid/gui/widgets/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,11 @@ def calculateOffsets(self, point):

def boundedMoveVertex(self, pos):
index, shape = self.hVertex, self.hShape
point = shape[index]
try:
point = shape[index]
except IndexError as e:
logger.inf(e)
return
if self.outOfPixmap(pos):
pos = self.intersectionPoint(point, pos)
shape.moveVertexBy(index, pos - point)
Expand Down

0 comments on commit d2b0a3e

Please sign in to comment.