Skip to content

Commit

Permalink
Set SAM 2.1 as Default Model
Browse files Browse the repository at this point in the history
  • Loading branch information
healthonrails committed Nov 20, 2024
1 parent 19a6362 commit aa7849e
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion annolid/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,7 +1093,7 @@ def predict_from_next_frame(self,
task_function=process_video,
video_path=self.video_file,
frame_idx=self.frame_number,
model_config='sam2_hiera_l.yaml' if 'hiera_l' in model_name else "sam2_hiera_s.yaml",
model_config='sam2.1_hiera_l.yaml' if 'hiera_l' in model_name else "sam2.1_hiera_s.yaml",
)
else:
self.pred_worker = FlexibleWorker(
Expand Down
9 changes: 5 additions & 4 deletions annolid/segmentation/SAM/sam_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
class SAM2VideoProcessor:
def __init__(self, video_dir, id_to_labels,
checkpoint_path=None,
model_config="sam2_hiera_s.yaml",
model_config="sam2.1_hiera_s.yaml",
epsilon_for_polygon=2.0):
"""
Initializes the SAM2VideoProcessor with the given parameters.
Expand All @@ -32,14 +32,14 @@ def __init__(self, video_dir, id_to_labels,
# Set default checkpoint path if not provided
if checkpoint_path is None:
current_dir = os.path.dirname(os.path.abspath(__file__))
checkpoint = "sam2_hiera_small.pt" if 'hiera_s' in model_config else "sam2_hiera_large.pt"
checkpoint = "sam2.1_hiera_small.pt" if 'hiera_s' in model_config else "sam2.1_hiera_large.pt"
checkpoint_path = os.path.join(current_dir,
"segment-anything-2",
"checkpoints",
checkpoint
)

self.BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/"
self.BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/"

self.video_dir = video_dir
self.checkpoint_path = checkpoint_path
Expand All @@ -66,6 +66,7 @@ def _initialize_predictor(self):
sam2_checkpoint_url = f"{self.BASE_URL}{os.path.basename(self.checkpoint_path)}"
download_file(sam2_checkpoint_url, self.checkpoint_path)

print(self.model_config, self.checkpoint_path)
return build_sam2_video_predictor(self.model_config,
self.checkpoint_path,
device=self.device)
Expand Down Expand Up @@ -251,7 +252,7 @@ def run(self, annotations, frame_idx):
def process_video(video_path,
frame_idx=0,
checkpoint_path=None,
model_config="sam2_hiera_s.yaml",
model_config="sam2.1_hiera_s.yaml",
epsilon_for_polygon=2.0):
"""
Processes a video by extracting frames, loading annotations from multiple JSON files, and running analysis.
Expand Down
2 changes: 1 addition & 1 deletion annolid/segmentation/SAM/segment-anything-2

0 comments on commit aa7849e

Please sign in to comment.