Enhancement: Integrate CoTracker Offline Version for Backward Tracking
healthonrails committed Mar 11, 2024
1 parent f7072da commit 58fe7b5
Showing 1 changed file with 74 additions and 32 deletions.
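In short, `CoTrackerProcessor` now takes an `is_online` flag: the default `True` keeps the sliding-window `cotracker2_online` model, while `False` loads the offline `cotracker2` model, which reads the whole clip at once and can track points backward from the query frame. A minimal usage sketch, with names taken from the diff below (the video path and labeled-frame JSON path are placeholders):

    from annolid.tracker.cotracker.track import CoTrackerProcessor

    # Default online mode: frames stream through the model in sliding windows.
    online = CoTrackerProcessor("video.mp4", "video_000000000.json")
    online.process_video(grid_size=10, grid_query_frame=0)

    # New offline mode: the whole clip is loaded up front so points can also
    # be tracked backward from their query frame.
    offline = CoTrackerProcessor("video.mp4", "video_000000000.json",
                                 is_online=False)
    offline.process_video()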
106 changes: 74 additions & 32 deletions annolid/tracker/cotracker/track.py
@@ -9,6 +9,8 @@
 from annolid.gui.shape import Shape
 from annolid.tracker.cotracker.visualizer import Visualizer
 from annolid.utils.logger import logger
+from annolid.tracker.cotracker.visualizer import read_video_from_path
+
 """
 @article{karaev2023cotracker,
   title={CoTracker: It is Better to Track Together},
@@ -21,9 +23,9 @@
 
 
 class CoTrackerProcessor:
-    def __init__(self, video_path, json_path=None):
+    def __init__(self, video_path, json_path=None, is_online=True):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.model = self.load_model()
+        self.model = self.load_model(is_online=is_online)
         self.video_path = video_path
         self.video_result_folder = Path(self.video_path).with_suffix('')
         if not self.video_result_folder.exists():
@@ -32,6 +34,7 @@ def __init__(self, video_path, json_path=None):
         self.queries = self.load_queries(json_path)
         self.video_height = None
         self.video_width = None
+        self.is_online = is_online
 
     def get_frame_number(self, json_file):
         # assume json file name pattern as
@@ -44,9 +47,13 @@ def get_frame_number(self, json_file):
         frame_number = int(frame_number_str)
         return frame_number
 
-    def load_model(self):
-        return torch.hub.load("facebookresearch/co-tracker",
-                              "cotracker2_online").to(self.device)
+    def load_model(self, is_online=True):
+        if is_online:
+            return torch.hub.load("facebookresearch/co-tracker",
+                                  "cotracker2_online").to(self.device)
+        else:
+            return torch.hub.load("facebookresearch/co-tracker",
+                                  "cotracker2").to(self.device)
 
     def load_queries(self, json_path):
         if json_path is None:
@@ -91,42 +98,76 @@ def process_step(self,
                           grid_size=grid_size,
                           grid_query_frame=grid_query_frame)
 
-    def process_video(self, grid_size=10, grid_query_frame=0, need_visualize=True):
+    def process_video(self,
+                      grid_size=10,
+                      grid_query_frame=0,
+                      need_visualize=True):
         if not os.path.isfile(self.video_path):
             raise ValueError("Video file does not exist")
 
-        window_frames = []
-        is_first_step = True
-
-        for i, frame in enumerate(iio.imiter(self.video_path, plugin="FFMPEG")):
-            if self.video_height is None or self.video_width is None:
-                self.video_height, self.video_width, _ = frame.shape
-            if i % self.model.step == 0 and i != 0:
-                pred_tracks, pred_visibility = self.process_step(
-                    window_frames, is_first_step, grid_size, grid_query_frame)
-                if pred_tracks is not None:
-                    logger.info(
-                        f"Tracking frame {i}, {pred_tracks.shape}, {pred_visibility.shape}")
-                is_first_step = False
-            window_frames.append(frame)
-
-        pred_tracks, pred_visibility = self.process_step(
-            window_frames[-(i % self.model.step) - self.model.step - 1:],
-            is_first_step, grid_size, grid_query_frame)
+        if self.is_online:
+            window_frames = []
+            is_first_step = True
+
+            for i, frame in enumerate(iio.imiter(self.video_path, plugin="FFMPEG")):
+                if self.video_height is None or self.video_width is None:
+                    self.video_height, self.video_width, _ = frame.shape
+                if i % self.model.step == 0 and i != 0:
+                    pred_tracks, pred_visibility = self.process_step(
+                        window_frames, is_first_step, grid_size, grid_query_frame)
+                    if pred_tracks is not None:
+                        logger.info(
+                            f"Tracking frame {i}, {pred_tracks.shape}, {pred_visibility.shape}")
+                    is_first_step = False
+                window_frames.append(frame)
+
+            pred_tracks, pred_visibility = self.process_step(
+                window_frames[-(i % self.model.step) - self.model.step - 1:],
+                is_first_step, grid_size, grid_query_frame)
+        else:
+            pred_tracks, pred_visibility, video = self._process_video_bidirectional()
 
         logger.info("Tracks are computed")
         message = self.extract_frame_points(
             pred_tracks, pred_visibility, query_frame=0)
 
         if need_visualize:
-            video = torch.tensor(np.stack(window_frames),
-                                 device=self.device).permute(0, 3, 1, 2)[None]
-            vis = Visualizer(save_dir="./saved_videos", pad_value=120,
-                             linewidth=3, tracks_leave_trace=-1)
-            vis.visualize(video, pred_tracks, pred_visibility,
-                          query_frame=grid_query_frame)
+            vis_video_name = f'{self.video_result_folder.name}_tracked'
+            vis = Visualizer(
+                save_dir=str(self.video_result_folder.parent),
+                linewidth=6,
+                mode='cool',
+                tracks_leave_trace=-1
+            )
+            if self.is_online:
+                video = torch.tensor(np.stack(window_frames),
+                                     device=self.device).permute(0, 3, 1, 2)[None]
+                vis.visualize(video, pred_tracks, pred_visibility,
+                              query_frame=grid_query_frame,
+                              filename=vis_video_name)
+            else:
+                vis.visualize(
+                    video=video,
+                    tracks=pred_tracks,
+                    visibility=pred_visibility,
+                    filename=vis_video_name
+                )
         return message
 
+    def _process_video_bidirectional(self,
+                                     grid_size=10,
+                                     grid_query_frame=0):
+        logger.info(
+            f"grid_size: {grid_size}, grid_query_frame: {grid_query_frame}")
+
+        video = read_video_from_path(self.video_path)
+        video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()
+        pred_tracks, pred_visibility = self.model(
+            video, queries=self.queries[None],
+            backward_tracking=True)
+        return pred_tracks, pred_visibility, video
+
     def save_current_frame_tracked_points_to_json(self, frame_number, points):
         json_file_path = self.video_result_folder / \
             (self.video_result_folder.name +
@@ -187,5 +228,6 @@ def extract_frame_points(
                         help="Compute dense and grid tracks starting from this frame")
     args = parser.parse_args()
 
-    tracker_processor = CoTrackerProcessor(args.video_path, args.json_path)
+    tracker_processor = CoTrackerProcessor(
+        args.video_path, args.json_path)
     tracker_processor.process_video(args.grid_size, args.grid_query_frame)
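
For readers unfamiliar with the offline API, here is a self-contained sketch of the call that the new `_process_video_bidirectional` method makes. It assumes CoTracker's documented conventions: a float video tensor of shape `(B, T, C, H, W)` and queries of shape `(B, N, 3)` whose rows are `(frame, x, y)`; the dummy tensor sizes are placeholders.

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.hub.load("facebookresearch/co-tracker",
                           "cotracker2").to(device)

    # Dummy 48-frame clip and two query points, both queried at frame 10.
    video = torch.zeros(1, 48, 3, 480, 640, device=device)
    queries = torch.tensor([[10., 300., 200.],
                            [10., 120., 80.]], device=device)[None]

    # backward_tracking=True also propagates each point to frames *before*
    # its query frame, which the online sliding-window model cannot do;
    # that limitation is what motivates the offline branch in this commit.
    pred_tracks, pred_visibility = model(video, queries=queries,
                                         backward_tracking=True)

    # Expected shapes: tracks (1, 48, 2, 2) and visibility (1, 48, 2).
    print(pred_tracks.shape, pred_visibility.shape)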
