Add interface for box prompt in SAM 2 video predictor #174

Merged 3 commits on Aug 7, 2024
README.md (6 changes: 3 additions & 3 deletions)
````diff
@@ -92,14 +92,14 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
     state = predictor.init_state(<your_video>)
 
     # add new prompts and instantly get the output on the same frame
-    frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>)
+    frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>)
 
     # propagate the prompts to get masklets throughout the video
     for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
         ...
 ```
````

```diff
-Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects in videos.
+Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add click or box prompts, make refinements, and track multiple objects in videos.
```
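
For a concrete end-to-end sketch of the new box interface, here is a minimal, self-contained example (the config name, checkpoint path, frames directory, and box coordinates are illustrative placeholders; the call signature follows the diff to `sam2/sam2_video_predictor.py` below):

```python
import torch
from sam2.build_sam import build_sam2_video_predictor

# illustrative config/checkpoint paths -- substitute your own
predictor = build_sam2_video_predictor("sam2_hiera_l.yaml", "./checkpoints/sam2_hiera_large.pt")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    state = predictor.init_state("./videos/bedroom")  # directory of JPEG frames

    # a single (x_min, y_min, x_max, y_max) box prompt on frame 0 for object 1
    frame_idx, object_ids, masks = predictor.add_new_points_or_box(
        state, frame_idx=0, obj_id=1, box=[200.0, 150.0, 400.0, 350.0]
    )

    # propagate the prompt to get masklets throughout the video
    for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
        ...
```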

## Load from 🤗 Hugging Face

```diff
@@ -130,7 +130,7 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
     state = predictor.init_state(<your_video>)
 
     # add new prompts and instantly get the output on the same frame
-    frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>)
+    frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>)
 
     # propagate the prompts to get masklets throughout the video
     for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
```
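
Since this hunk sits in the README's Hugging Face section, here is a companion sketch that combines the Hub loader with the new `box` argument plus a refinement click in the same call (the model ID is the published `facebook/sam2-hiera-large`; the frames directory and coordinates are illustrative). A box can be combined with point prompts in a single call because the box corners are simply prepended to the point list:

```python
import torch
from sam2.sam2_video_predictor import SAM2VideoPredictor

predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    state = predictor.init_state("./videos/bedroom")  # directory of JPEG frames

    # box prompt plus one negative click in the same call; the box corners are
    # prepended internally, so the click refines the region inside the box
    frame_idx, object_ids, masks = predictor.add_new_points_or_box(
        state,
        frame_idx=0,
        obj_id=1,
        box=[200.0, 150.0, 400.0, 350.0],
        points=[[320.0, 250.0]],
        labels=[0],  # 0 = negative click
    )
```
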
notebooks/video_predictor_example.ipynb (480 changes: 400 additions & 80 deletions)

Large diffs are not rendered by default.

sam2/sam2_video_predictor.py (54 changes: 48 additions & 6 deletions)
```diff
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import warnings
 from collections import OrderedDict
 
 import torch
```
```diff
@@ -163,29 +164,66 @@ def _get_obj_num(self, inference_state):
         return len(inference_state["obj_idx_to_id"])
 
     @torch.inference_mode()
-    def add_new_points(
+    def add_new_points_or_box(
         self,
         inference_state,
         frame_idx,
         obj_id,
-        points,
-        labels,
+        points=None,
+        labels=None,
         clear_old_points=True,
         normalize_coords=True,
+        box=None,
     ):
         """Add new points to a frame."""
         obj_idx = self._obj_id_to_idx(inference_state, obj_id)
         point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
         mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
 
-        if not isinstance(points, torch.Tensor):
+        if (points is not None) != (labels is not None):
+            raise ValueError("points and labels must be provided together")
+        if points is None and box is None:
+            raise ValueError("at least one of points or box must be provided as input")
+
+        if points is None:
+            points = torch.zeros(0, 2, dtype=torch.float32)
+        elif not isinstance(points, torch.Tensor):
             points = torch.tensor(points, dtype=torch.float32)
-        if not isinstance(labels, torch.Tensor):
+        if labels is None:
+            labels = torch.zeros(0, dtype=torch.int32)
+        elif not isinstance(labels, torch.Tensor):
             labels = torch.tensor(labels, dtype=torch.int32)
         if points.dim() == 2:
             points = points.unsqueeze(0)  # add batch dimension
         if labels.dim() == 1:
             labels = labels.unsqueeze(0)  # add batch dimension
+
+        # If `box` is provided, we add it as the first two points with labels 2 and 3
+        # along with the user-provided points (consistent with how SAM 2 is trained).
+        if box is not None:
+            if not clear_old_points:
+                raise ValueError(
+                    "cannot add box without clearing old points, since "
+                    "box prompt must be provided before any point prompt "
+                    "(please use clear_old_points=True instead)"
+                )
+            if inference_state["tracking_has_started"]:
+                warnings.warn(
+                    "You are adding a box after tracking starts. SAM 2 may not always be "
+                    "able to incorporate a box prompt for *refinement*. If you intend to "
+                    "use box prompt as an *initial* input before tracking, please call "
+                    "'reset_state' on the inference state to restart from scratch.",
+                    category=UserWarning,
+                    stacklevel=2,
+                )
+            if not isinstance(box, torch.Tensor):
+                box = torch.tensor(box, dtype=torch.float32, device=points.device)
+            box_coords = box.reshape(1, 2, 2)
+            box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
+            box_labels = box_labels.reshape(1, 2)
+            points = torch.cat([box_coords, points], dim=1)
+            labels = torch.cat([box_labels, labels], dim=1)
+
         if normalize_coords:
             video_H = inference_state["video_height"]
             video_W = inference_state["video_width"]
```
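
To make the box-to-points encoding above concrete, here is a standalone sketch of just that tensor manipulation, mirroring the branch in the diff (the box and click values are illustrative):

```python
import torch

# an illustrative (x_min, y_min, x_max, y_max) box and one positive click
box = torch.tensor([200.0, 150.0, 400.0, 350.0])
points = torch.tensor([[300.0, 250.0]], dtype=torch.float32).unsqueeze(0)  # (1, 1, 2)
labels = torch.tensor([1], dtype=torch.int32).unsqueeze(0)                 # (1, 1)

# the box becomes two corner points: label 2 marks the top-left corner,
# label 3 the bottom-right corner
box_coords = box.reshape(1, 2, 2)                                   # (1, 2, 2)
box_labels = torch.tensor([2, 3], dtype=torch.int32).reshape(1, 2)  # (1, 2)

# the corners are prepended so the box precedes any point prompts
points = torch.cat([box_coords, points], dim=1)  # [[[200., 150.], [400., 350.], [300., 250.]]]
labels = torch.cat([box_labels, labels], dim=1)  # [[2, 3, 1]]
```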
```diff
@@ -268,6 +306,10 @@ def add_new_points(
         )
         return frame_idx, obj_ids, video_res_masks
 
+    def add_new_points(self, *args, **kwargs):
+        """Deprecated method. Please use `add_new_points_or_box` instead."""
+        return self.add_new_points_or_box(*args, **kwargs)
+
     @torch.inference_mode()
     def add_new_mask(
         self,
```
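
With the shim above, pre-existing point-only call sites keep working unchanged; for example, continuing the earlier README sketch (arguments illustrative):

```python
# the deprecated name simply forwards to add_new_points_or_box
frame_idx, object_ids, masks = predictor.add_new_points(
    state,
    frame_idx=0,
    obj_id=2,
    points=[[120.0, 80.0]],
    labels=[1],  # 1 = positive click
)
```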
```diff
@@ -548,7 +590,7 @@ def propagate_in_video_preflight(self, inference_state):
         storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
         # Find all the frames that contain temporary outputs for any objects
         # (these should be the frames that have just received clicks for mask inputs
-        # via `add_new_points` or `add_new_mask`)
+        # via `add_new_points_or_box` or `add_new_mask`)
         temp_frame_inds = set()
         for obj_temp_output_dict in temp_output_dict_per_obj.values():
             temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
```