From 2d8c01f7c1de0d042b0b68900ee318e77e0d103d Mon Sep 17 00:00:00 2001 From: healthonrails Date: Thu, 21 Nov 2024 14:03:16 -0500 Subject: [PATCH] feat: Support for YOLO11n and YOLO11x models in GUI --- annolid/gui/app.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/annolid/gui/app.py b/annolid/gui/app.py index ec4be1a..0ddb15d 100644 --- a/annolid/gui/app.py +++ b/annolid/gui/app.py @@ -509,7 +509,8 @@ def __init__(self, "CoTracker", "sam2_hiera_s", "sam2_hiera_l", - "YOLO", + "YOLO11n", + "YOLO11x", ] model_names = [model.name for model in MODELS] + \ self.custom_ai_model_names @@ -1027,7 +1028,8 @@ def _select_sam_model_name(self): "CoTracker": "CoTracker", "sam2_hiera_s": "sam2_hiera_s", "sam2_hiera_l": "sam2_hiera_l", - "YOLO": "YOLO", + "YOLO11n": "yolo11n.pt", + "YOLO11x": "yolo11x.pt", } default_model_name = "Segment-Anything (Edge)" @@ -1055,7 +1057,7 @@ def predict_from_next_frame(self, if self.pred_worker and self.stop_prediction_flag: # If prediction is running, stop the prediction self.stop_prediction() - elif len(self.canvas.shapes) <= 0 and not "YOLO" in model_name: + elif len(self.canvas.shapes) <= 0 and not "yolo" in model_name: QtWidgets.QMessageBox.about(self, "No Shapes or Labeled Frames", f"Please label this frame") @@ -1065,9 +1067,9 @@ def predict_from_next_frame(self, if "sam2_hiera" in model_name: from annolid.segmentation.SAM.sam_v2 import process_video self.video_processor = process_video - elif "YOLO" in model_name: + elif "yolo" in model_name: from annolid.segmentation.yolos import InferenceProcessor - self.video_processor = InferenceProcessor(model_name="yolo11n.pt", + self.video_processor = InferenceProcessor(model_name=model_name, model_type="yolo" ) else: @@ -1102,7 +1104,7 @@ def predict_from_next_frame(self, frame_idx=self.frame_number, model_config='sam2.1_hiera_l.yaml' if 'hiera_l' in model_name else "sam2.1_hiera_s.yaml", ) - elif 'YOLO' in model_name: + elif 'yolo' in model_name: self.pred_worker = FlexibleWorker( task_function=self.video_processor.run_inference, source=self.video_file