Skip to content

Commit

Permalink
feat: Add YOLO video inference to GUI
Browse files Browse the repository at this point in the history
  • Loading branch information
healthonrails committed Nov 20, 2024
1 parent d5469b0 commit 03c3bf7
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 4 deletions.
16 changes: 14 additions & 2 deletions annolid/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ def __init__(self,
"CoTracker",
"sam2_hiera_s",
"sam2_hiera_l",
"YOLO",
]
model_names = [model.name for model in MODELS] + \
self.custom_ai_model_names
Expand Down Expand Up @@ -1026,6 +1027,7 @@ def _select_sam_model_name(self):
"CoTracker": "CoTracker",
"sam2_hiera_s": "sam2_hiera_s",
"sam2_hiera_l": "sam2_hiera_l",
"YOLO": "YOLO",
}
default_model_name = "Segment-Anything (Edge)"

Expand All @@ -1049,20 +1051,25 @@ def stop_prediction(self):

def predict_from_next_frame(self,
to_frame=60):
model_name = self._select_sam_model_name()
if self.pred_worker and self.stop_prediction_flag:
# If prediction is running, stop the prediction
self.stop_prediction()
elif len(self.canvas.shapes) <= 0:
elif len(self.canvas.shapes) <= 0 and not "YOLO" in model_name:
QtWidgets.QMessageBox.about(self,
"No Shapes or Labeled Frames",
f"Please label this frame")
return
else:
model_name = self._select_sam_model_name()
if self.video_file:
if "sam2_hiera" in model_name:
from annolid.segmentation.SAM.sam_v2 import process_video
self.video_processor = process_video
elif "YOLO" in model_name:
from annolid.segmentation.yolos import InferenceProcessor
self.video_processor = InferenceProcessor(model_name="yolo11n.pt",
model_type="yolo"
)
else:
self.video_processor = VideoProcessor(
self.video_file,
Expand Down Expand Up @@ -1095,6 +1102,11 @@ def predict_from_next_frame(self,
frame_idx=self.frame_number,
model_config='sam2.1_hiera_l.yaml' if 'hiera_l' in model_name else "sam2.1_hiera_s.yaml",
)
elif 'YOLO' in model_name:
self.pred_worker = FlexibleWorker(
task_function=self.video_processor.run_inference,
source=self.video_file
)
else:
self.pred_worker = FlexibleWorker(
task_function=self.video_processor.process_video_frames,
Expand Down
2 changes: 2 additions & 0 deletions annolid/postprocessing/tracking_results_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def load_zone_json(self):
if not os.path.exists(self.zone_file):
json_files = sorted(find_manual_labeled_json_files(
str(self.tracking_csv).replace('_tracking.csv', '')))
if len(json_files) < 1:
return
# assume the first file has the Zone or place info
self.zone_file = json_files[0]
with open(self.zone_file, 'r') as f:
Expand Down
5 changes: 3 additions & 2 deletions annolid/segmentation/yolos.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def run_inference(self, source):
self.save_yolo_to_labelme(
yolo_results, id_to_labels, frame_shape, output_directory
)
return f"Done#{self.frame_count}"

def extract_yolo_results(self, result):
"""
Expand Down Expand Up @@ -140,9 +141,9 @@ def save_yolo_to_labelme(self, yolo_results, id_to_labels, frame_shape,

# Example usage
if __name__ == "__main__":
video_path = "~/Downloads/mouse.mp4"
video_path = os.path.expanduser("~/Downloads/IMG_0769.MOV")

yolo_processor = InferenceProcessor(
"yolo11n.pt", model_type="yolo", class_names=["mouse", "teaball"]
"yolo11n.pt", model_type="yolo"
)
yolo_processor.run_inference(video_path)

0 comments on commit 03c3bf7

Please sign in to comment.