Skip to content

Commit

Permalink
Predict the polygons for the next 90 frames based on the current frame
Browse files Browse the repository at this point in the history
  • Loading branch information
healthonrails committed Jan 4, 2024
1 parent ae66e7f commit b52595b
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 4 deletions.
18 changes: 18 additions & 0 deletions annolid/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from annolid.postprocessing.quality_control import pred_dict_to_labelme
from annolid.annotation.keypoints import save_labels
from annolid.annotation.timestamps import convert_frame_number_to_time
from annolid.segmentation.SAM.edge_sam_bg import VideoProcessor
from labelme.ai import MODELS
__appname__ = 'Annolid'
__version__ = "1.1.3"
Expand Down Expand Up @@ -905,6 +906,23 @@ def format_shape(s):
otherData=self.otherData,
flags=flags,
)

if self.video_file:
self.video_processor = VideoProcessor(self.video_file)
self.seg_pred_thread.start()
end_frame = self.frame_number + 90 * self.step_size
if end_frame >= self.num_frames:
end_frame = self.num_frames - 1
self.pred_worker = FlexibleWorker(
function=self.video_processor.process_video_frames,
start_frame=self.frame_number,
end_frame=end_frame,
step=self.step_size
)
self.pred_worker.moveToThread(self.seg_pred_thread)
self.pred_worker.start.connect(self.pred_worker.run)
self.pred_worker.start.emit()

self.labelFile = lf
items = self.fileListWidget.findItems(
self.imagePath, Qt.MatchExactly
Expand Down
14 changes: 10 additions & 4 deletions annolid/segmentation/SAM/edge_sam_bg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import cv2
from pathlib import Path
from segment_anything import SegmentAnythingModel
from annolid.segmentation.SAM.segment_anything import SegmentAnythingModel
from annolid.data.videos import CV2Video
from annolid.utils.files import find_most_recent_file
import json
Expand Down Expand Up @@ -28,7 +28,7 @@ def calculate_polygon_center(polygon_vertices):
return np.array([(center_x, center_y)])


class VideoProcessor:
class VideoProcessor():
"""
A class for processing video frames using the Segment-Anything model.
"""
Expand All @@ -44,12 +44,14 @@ def __init__(self,
- video_path (str): Path to the video file.
- num_center_points (int): number of center points for prompt.
"""
super(VideoProcessor, self).__init__()
self.video_path = video_path
self.video_folder = Path(video_path).with_suffix("")
self.video_loader = CV2Video(video_path)
self.edge_sam = self.get_model()
self.num_frames = self.video_loader.total_frames()
self.center_points = MaxSizeQueue(max_size=num_center_points)
self.most_recent_file = self.get_most_recent_file()

def get_model(self,
encoder_path="edge_sam_3x_encoder.onnx",
Expand Down Expand Up @@ -107,8 +109,10 @@ def process_frame(self, frame_number):
cur_frame = self.video_loader.load_frame(frame_number)

height, width, _ = cur_frame.shape
if self.most_recent_file is None:
return

points_dict, _ = self.load_json_file(self.get_most_recent_file())
points_dict, _ = self.load_json_file(self.most_recent_file)
label_list = []

# Example usage of predict_polygon_from_points
Expand All @@ -131,6 +135,7 @@ def process_frame(self, frame_number):

filename = self.video_folder / \
(self.video_folder.name + f"_{frame_number:0>{9}}.json")
self.most_recent_file = filename
img_filename = str(filename.with_suffix('.png'))
cur_frame = cv2.cvtColor(cur_frame, cv2.COLOR_BGR2RGB)
cv2.imwrite(img_filename, cur_frame)
Expand Down Expand Up @@ -158,7 +163,8 @@ def get_most_recent_file(self):
Returns:
- str: Path to the most recent file.
"""
return find_most_recent_file(self.video_folder)
_recent_file = find_most_recent_file(self.video_folder)
return _recent_file


if __name__ == '__main__':
Expand Down
2 changes: 2 additions & 0 deletions annolid/utils/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

def find_most_recent_file(folder_path, file_ext=".json"):
# List all files in the folder
if not os.path.exists(folder_path):
return
all_files = os.listdir(folder_path)

# Filter out directories and get file paths
Expand Down

0 comments on commit b52595b

Please sign in to comment.