diff --git a/annolid/tracker/cotracker/track.py b/annolid/tracker/cotracker/track.py
index 54c9ea7..db4e41a 100644
--- a/annolid/tracker/cotracker/track.py
+++ b/annolid/tracker/cotracker/track.py
@@ -9,6 +9,8 @@
 from annolid.gui.shape import Shape
 from annolid.tracker.cotracker.visualizer import Visualizer
 from annolid.utils.logger import logger
+from annolid.tracker.cotracker.visualizer import read_video_from_path
+
 """
 @article{karaev2023cotracker,
   title={CoTracker: It is Better to Track Together},
@@ -21,9 +23,9 @@ class CoTrackerProcessor:
-    def __init__(self, video_path, json_path=None):
+    def __init__(self, video_path, json_path=None, is_online=True):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.model = self.load_model()
+        self.model = self.load_model(is_online=is_online)
         self.video_path = video_path
         self.video_result_folder = Path(self.video_path).with_suffix('')
         if not self.video_result_folder.exists():
@@ -32,6 +34,7 @@ def __init__(self, video_path, json_path=None):
         self.queries = self.load_queries(json_path)
         self.video_height = None
         self.video_width = None
+        self.is_online = is_online
 
     def get_frame_number(self, json_file):
         # assume json file name pattern as
@@ -44,9 +47,13 @@ def get_frame_number(self, json_file):
         frame_number = int(frame_number_str)
         return frame_number
 
-    def load_model(self):
-        return torch.hub.load("facebookresearch/co-tracker",
-                              "cotracker2_online").to(self.device)
+    def load_model(self, is_online=True):
+        if is_online:
+            return torch.hub.load("facebookresearch/co-tracker",
+                                  "cotracker2_online").to(self.device)
+        else:
+            return torch.hub.load("facebookresearch/co-tracker",
+                                  "cotracker2").to(self.device)
 
     def load_queries(self, json_path):
         if json_path is None:
@@ -91,42 +98,76 @@ def process_step(self,
             grid_size=grid_size,
             grid_query_frame=grid_query_frame)
 
-    def process_video(self, grid_size=10, grid_query_frame=0, need_visualize=True):
+    def process_video(self,
+                      grid_size=10,
+                      grid_query_frame=0,
+                      need_visualize=True):
         if not os.path.isfile(self.video_path):
             raise ValueError("Video file does not exist")
-
-        window_frames = []
-        is_first_step = True
-
-        for i, frame in enumerate(iio.imiter(self.video_path, plugin="FFMPEG")):
-            if self.video_height is None or self.video_width is None:
-                self.video_height, self.video_width, _ = frame.shape
-            if i % self.model.step == 0 and i != 0:
-                pred_tracks, pred_visibility = self.process_step(
-                    window_frames, is_first_step, grid_size, grid_query_frame)
-                if pred_tracks is not None:
-                    logger.info(
-                        f"Tracking frame {i}, {pred_tracks.shape}, {pred_visibility.shape}")
-                is_first_step = False
-            window_frames.append(frame)
-
-        pred_tracks, pred_visibility = self.process_step(
-            window_frames[-(i % self.model.step) - self.model.step - 1:],
-            is_first_step, grid_size, grid_query_frame)
+        if self.is_online:
+            window_frames = []
+            is_first_step = True
+
+            for i, frame in enumerate(iio.imiter(self.video_path, plugin="FFMPEG")):
+                if self.video_height is None or self.video_width is None:
+                    self.video_height, self.video_width, _ = frame.shape
+                if i % self.model.step == 0 and i != 0:
+                    pred_tracks, pred_visibility = self.process_step(
+                        window_frames, is_first_step, grid_size, grid_query_frame)
+                    if pred_tracks is not None:
+                        logger.info(
+                            f"Tracking frame {i}, {pred_tracks.shape}, {pred_visibility.shape}")
+                    is_first_step = False
+                window_frames.append(frame)
+
+            pred_tracks, pred_visibility = self.process_step(
+                window_frames[-(i % self.model.step) - self.model.step - 1:],
+                is_first_step, grid_size, grid_query_frame)
+        else:
+            pred_tracks, pred_visibility, video = self._process_video_bidirectional(
+                grid_size, grid_query_frame)
         logger.info("Tracks are computed")
         message = self.extract_frame_points(
             pred_tracks, pred_visibility, query_frame=0)
         if need_visualize:
-            video = torch.tensor(np.stack(window_frames),
-                                 device=self.device).permute(0, 3, 1, 2)[None]
-            vis = Visualizer(save_dir="./saved_videos", pad_value=120,
-                             linewidth=3, tracks_leave_trace=-1)
-            vis.visualize(video, pred_tracks, pred_visibility,
-                          query_frame=grid_query_frame)
+            vis_video_name = f'{self.video_result_folder.name}_tracked'
+            vis = Visualizer(
+                save_dir=str(self.video_result_folder.parent),
+                linewidth=6,
+                mode='cool',
+                tracks_leave_trace=-1
+            )
+            if self.is_online:
+                video = torch.tensor(np.stack(window_frames),
+                                     device=self.device).permute(0, 3, 1, 2)[None]
+                vis.visualize(video, pred_tracks, pred_visibility,
+                              query_frame=grid_query_frame,
+                              filename=vis_video_name)
+            else:
+                vis.visualize(
+                    video=video,
+                    tracks=pred_tracks,
+                    visibility=pred_visibility,
+                    filename=vis_video_name
+                )
         return message
 
+    def _process_video_bidirectional(self,
+                                     grid_size=10,
+                                     grid_query_frame=0):
+        logger.info(
+            f"grid_size: {grid_size}, grid_query_frame: {grid_query_frame}")
+
+        video = read_video_from_path(self.video_path)
+        video = torch.from_numpy(video).permute(
+            0, 3, 1, 2)[None].float().to(self.device)
+        pred_tracks, pred_visibility = self.model(
+            video, queries=self.queries[None],
+            backward_tracking=True)
+        return pred_tracks, pred_visibility, video
+
     def save_current_frame_tracked_points_to_json(self, frame_number, points):
         json_file_path = self.video_result_folder / \
             (self.video_result_folder.name +
@@ -187,5 +228,6 @@ def extract_frame_points(
         help="Compute dense and grid tracks starting from this frame")
     args = parser.parse_args()
 
-    tracker_processor = CoTrackerProcessor(args.video_path, args.json_path)
+    tracker_processor = CoTrackerProcessor(
+        args.video_path, args.json_path)
     tracker_processor.process_video(args.grid_size, args.grid_query_frame)
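
For reference, a minimal usage sketch of the new `is_online` switch. The constructor and `process_video` signatures are taken from the diff above; the video and JSON file names are hypothetical.

```python
from annolid.tracker.cotracker.track import CoTrackerProcessor

# Online (default): frames are streamed in windows of model.step frames,
# so memory stays bounded even for long videos.
processor = CoTrackerProcessor("clip.mp4", json_path="clip_000000000.json")
processor.process_video(grid_size=10, grid_query_frame=0)

# Offline: the whole clip is loaded into memory and tracked with
# backward_tracking=True, so query points can also be followed on frames
# that precede their query frame.
processor = CoTrackerProcessor("clip.mp4", json_path="clip_000000000.json",
                               is_online=False)
processor.process_video()
```

Note that the CLI entry point at the bottom of the diff still constructs `CoTrackerProcessor` without passing `is_online`, so the offline path is currently reachable only through the Python API.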