Skip to content

Commit

Permalink
Enhance tracking accuracy by utilizing the past three polygon centers…
Browse files Browse the repository at this point in the history
… as prompts.
  • Loading branch information
healthonrails committed Jan 4, 2024
1 parent 54f237d commit ae66e7f
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 13 deletions.
Binary file added annolid/segmentation/SAM/edge_sam_3x_decoder.onnx
Binary file not shown.
Binary file added annolid/segmentation/SAM/edge_sam_3x_encoder.onnx
Binary file not shown.
44 changes: 31 additions & 13 deletions annolid/segmentation/SAM/edge_sam_bg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,18 @@
from annolid.gui.shape import Shape
from annolid.annotation.keypoints import save_labels
import numpy as np
from collections import deque


class MaxSizeQueue(deque):
def __init__(self, max_size):
super().__init__(maxlen=max_size)

def enqueue(self, item):
self.append(item)

def to_numpy(self):
return np.array(list(self))


def calculate_polygon_center(polygon_vertices):
Expand All @@ -21,33 +33,40 @@ class VideoProcessor:
A class for processing video frames using the Segment-Anything model.
"""

def __init__(self, video_path, encoder_path, decoder_path):
def __init__(self,
video_path,
num_center_points=3
):
"""
Initialize the VideoProcessor.
Parameters:
- video_path (str): Path to the video file.
- encoder_path (str): Path to the encoder model file.
- decoder_path (str): Path to the decoder model file.
- num_center_points (int): number of center points for prompt.
"""
self.video_path = video_path
self.video_folder = Path(video_path).with_suffix("")
self.video_loader = CV2Video(video_path)
self.edge_sam = self.get_model(encoder_path, decoder_path)
self.edge_sam = self.get_model()
self.num_frames = self.video_loader.total_frames()
self.center_points = MaxSizeQueue(max_size=num_center_points)

def get_model(self, encoder_path, decoder_path):
def get_model(self,
encoder_path="edge_sam_3x_encoder.onnx",
decoder_path="edge_sam_3x_decoder.onnx",
name="Segment-Anything (Edge)"
):
"""
Load the Segment-Anything model.
Parameters:
- encoder_path (str): Path to the encoder model file.
- decoder_path (str): Path to the decoder model file.
- name (str): name of the SAM model
Returns:
- SegmentAnythingModel: The loaded model.
"""
name = "Segment-Anything (Edge)"
model = SegmentAnythingModel(name, encoder_path, decoder_path)
return model

Expand Down Expand Up @@ -96,6 +115,8 @@ def process_frame(self, frame_number):
for label, points in points_dict.items():
self.edge_sam.set_image(cur_frame)
points = calculate_polygon_center(points)
self.center_points.enqueue(points[0])
points = self.center_points.to_numpy()
point_labels = [1] * len(points)
polygon = self.edge_sam.predict_polygon_from_points(
points, point_labels)
Expand All @@ -116,7 +137,7 @@ def process_frame(self, frame_number):
save_labels(filename=filename, imagePath=img_filename, label_list=label_list,
height=height, width=width)

def process_video_frames(self, start_frame, end_frame=None, step=10):
def process_video_frames(self, start_frame=0, end_frame=None, step=10):
"""
Process multiple frames of the video.
Expand All @@ -142,9 +163,6 @@ def get_most_recent_file(self):

if __name__ == '__main__':
# Usage
video_path = "animal.mp4"
encoder_path = "edge_sam_3x_encoder.onnx"
decoder_path = "edge_sam_3x_decoder.onnx"

video_processor = VideoProcessor(video_path, encoder_path, decoder_path)
video_processor.process_video_frames(0, 500, 1)
video_path = "squirrel.mp4"
video_processor = VideoProcessor(video_path)
video_processor.process_video_frames()

0 comments on commit ae66e7f

Please sign in to comment.