better support for non-CUDA devices (CPU, MPS) #192

Merged: 4 commits, merged on Aug 12, 2024
Changes from all commits
78 changes: 57 additions & 21 deletions notebooks/automatic_mask_generator_example.ipynb

Large diffs are not rendered by default.

61 changes: 45 additions & 16 deletions notebooks/image_predictor_example.ipynb

Large diffs are not rendered by default.

395 changes: 70 additions & 325 deletions notebooks/video_predictor_example.ipynb

Large diffs are not rendered by default.
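The three notebook diffs are not rendered inline here. As a rough sketch of the kind of device-selection cell a non-CUDA setup needs (an illustration, not the notebooks' actual contents):

```python
import torch

# Pick the best available backend; names and messages here are illustrative.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

if device.type == "cuda":
    # TF32 speeds up matmuls on Ampere+ GPUs; it has no effect elsewhere.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    # MPS support is preliminary; unsupported ops can fall back to CPU if
    # PYTORCH_ENABLE_MPS_FALLBACK=1 is set before importing torch.
    print("MPS support is preliminary; results may differ slightly from CUDA.")
```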

4 changes: 3 additions & 1 deletion sam2/automatic_mask_generator.py
@@ -284,7 +284,9 @@ def _process_batch(
         orig_h, orig_w = orig_size
 
         # Run model on this batch
-        points = torch.as_tensor(points, device=self.predictor.device)
+        points = torch.as_tensor(
+            points, dtype=torch.float32, device=self.predictor.device
+        )
         in_points = self.predictor._transforms.transform_coords(
             points, normalize=normalize, orig_hw=im_size
         )
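For context, a standalone sketch of why the explicit dtype matters. One plausible motivation (an inference, not stated in the diff) is that the point grid arrives as a float64 NumPy array, and the MPS backend does not support float64, so casting to float32 at tensor-creation time keeps the call working on CPU, CUDA, and MPS alike:

```python
import numpy as np
import torch

# Illustrative stand-ins for the generator's point grid and device.
points = np.array([[0.5, 0.5], [0.25, 0.75]])  # NumPy default dtype is float64
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Without dtype=torch.float32 the tensor would inherit float64, which MPS rejects.
in_points = torch.as_tensor(points, dtype=torch.float32, device=device)
print(in_points.dtype, in_points.device)
```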
7 changes: 6 additions & 1 deletion sam2/modeling/position_encoding.py
@@ -211,6 +211,11 @@ def apply_rotary_enc(
     # repeat freqs along seq_len dim to match k seq_len
     if repeat_freqs_k:
         r = xk_.shape[-2] // xq_.shape[-2]
-        freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
+        if freqs_cis.is_cuda:
+            freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
+        else:
+            # torch.repeat on complex numbers may not be supported on non-CUDA devices
+            # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
+            freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
     return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
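A small self-contained check (assumed shapes, not from the diff) showing that the expand + flatten path produces the same tiling as repeat along dim 2 for a 4-D complex tensor, which is why the two branches are interchangeable:

```python
import torch

# Assumed layout for illustration: (batch, heads, seq, dim // 2) complex frequencies.
B, H, S, D, r = 1, 2, 3, 4, 2
freqs_cis = torch.randn(B, H, S, D, dtype=torch.complex64)

# CUDA-style path: tile the sequence dimension with repeat.
tiled_repeat = freqs_cis.repeat(1, 1, r, 1)

# Non-CUDA workaround: insert a new dim, expand it, then flatten it back in.
tiled_expand = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)

# Both yield the same (B, H, r * S, D) tensor (checked on CPU, where both paths work).
assert torch.equal(tiled_repeat, tiled_expand)
```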
4 changes: 2 additions & 2 deletions sam2/modeling/sam2_base.py
@@ -567,10 +567,10 @@ def _prepare_memory_conditioned_features(
                     continue  # skip padding frames
                 # "maskmem_features" might have been offloaded to CPU in demo use cases,
                 # so we load it back to GPU (it's a no-op if it's already on GPU).
-                feats = prev["maskmem_features"].cuda(non_blocking=True)
+                feats = prev["maskmem_features"].to(device, non_blocking=True)
                 to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
                 # Spatial positional encoding (it might have been offloaded to CPU in eval)
-                maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
+                maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
                 maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
                 # Temporal positional encoding
                 maskmem_enc = (
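To make the intent concrete, a minimal sketch (assumed shapes and device choice, not part of the diff) of the offload/reload pattern these two lines now express in a device-agnostic way:

```python
import torch

# Memory features may live on a CPU "storage" device when offloaded...
storage_device = torch.device("cpu")
compute_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

maskmem_features = torch.randn(1, 64, 32, 32, device=storage_device)  # assumed shape

# ...and .to(device) replaces .cuda(): it is a no-op if the tensor is already
# there, and it works identically for CPU, CUDA, and MPS targets.
feats = maskmem_features.to(compute_device, non_blocking=True)
feats = feats.flatten(2).permute(2, 0, 1)  # (H*W, B, C), as in the surrounding code
```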
12 changes: 8 additions & 4 deletions sam2/sam2_video_predictor.py
@@ -45,11 +45,13 @@ def init_state(
         async_loading_frames=False,
     ):
         """Initialize a inference state."""
+        compute_device = self.device  # device of the model
         images, video_height, video_width = load_video_frames(
             video_path=video_path,
             image_size=self.image_size,
             offload_video_to_cpu=offload_video_to_cpu,
             async_loading_frames=async_loading_frames,
+            compute_device=compute_device,
         )
         inference_state = {}
         inference_state["images"] = images
@@ -65,11 +67,11 @@
         # the original video height and width, used for resizing final output scores
         inference_state["video_height"] = video_height
         inference_state["video_width"] = video_width
-        inference_state["device"] = torch.device("cuda")
+        inference_state["device"] = compute_device
         if offload_state_to_cpu:
             inference_state["storage_device"] = torch.device("cpu")
         else:
-            inference_state["storage_device"] = torch.device("cuda")
+            inference_state["storage_device"] = compute_device
         # inputs on each frame
         inference_state["point_inputs_per_obj"] = {}
         inference_state["mask_inputs_per_obj"] = {}
@@ -270,7 +272,8 @@ def add_new_points_or_box(
         prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
 
         if prev_out is not None and prev_out["pred_masks"] is not None:
-            prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
+            device = inference_state["device"]
+            prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
             # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
             prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
         current_out, _ = self._run_single_frame_inference(
@@ -793,7 +796,8 @@ def _get_image_feature(self, inference_state, frame_idx, batch_size):
         )
         if backbone_out is None:
             # Cache miss -- we will run inference on a single image
-            image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
+            device = inference_state["device"]
+            image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
             backbone_out = self.forward_image(image)
             # Cache the most recent frame's feature (for repeated interactions with
             # a frame; we can use an LRU cache for more frames in the future).
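A hypothetical end-to-end sketch (config name, checkpoint path, and video directory are placeholders) of how these changes let the video predictor run without assuming CUDA: the model's device becomes inference_state["device"], and offload_state_to_cpu still decides where per-frame state is stored.

```python
import torch
from sam2.build_sam import build_sam2_video_predictor

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Placeholder config/checkpoint names for illustration.
predictor = build_sam2_video_predictor(
    "sam2_hiera_l.yaml", "./checkpoints/sam2_hiera_large.pt", device=device
)

# inference_state["device"] now follows the model; the storage device is CPU
# only when offload_state_to_cpu=True, otherwise it matches the compute device.
state = predictor.init_state(
    video_path="./videos/bedroom",  # placeholder directory of JPEG frames
    offload_state_to_cpu=True,
)
```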
27 changes: 21 additions & 6 deletions sam2/utils/misc.py
@@ -106,7 +106,15 @@ class AsyncVideoFrameLoader:
     A list of video frames to be load asynchronously without blocking session start.
     """
 
-    def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
+    def __init__(
+        self,
+        img_paths,
+        image_size,
+        offload_video_to_cpu,
+        img_mean,
+        img_std,
+        compute_device,
+    ):
         self.img_paths = img_paths
         self.image_size = image_size
         self.offload_video_to_cpu = offload_video_to_cpu
@@ -119,6 +127,7 @@ def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_st
         # video_height and video_width be filled when loading the first image
         self.video_height = None
         self.video_width = None
+        self.compute_device = compute_device
 
         # load the first frame to fill video_height and video_width and also
         # to cache it (since it's most likely where the user will click)
@@ -152,7 +161,7 @@ def __getitem__(self, index):
         img -= self.img_mean
         img /= self.img_std
         if not self.offload_video_to_cpu:
-            img = img.cuda(non_blocking=True)
+            img = img.to(self.compute_device, non_blocking=True)
         self.images[index] = img
         return img
 
@@ -167,6 +176,7 @@ def load_video_frames(
     img_mean=(0.485, 0.456, 0.406),
     img_std=(0.229, 0.224, 0.225),
     async_loading_frames=False,
+    compute_device=torch.device("cuda"),
 ):
     """
     Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
@@ -196,17 +206,22 @@
 
     if async_loading_frames:
         lazy_images = AsyncVideoFrameLoader(
-            img_paths, image_size, offload_video_to_cpu, img_mean, img_std
+            img_paths,
+            image_size,
+            offload_video_to_cpu,
+            img_mean,
+            img_std,
+            compute_device,
         )
         return lazy_images, lazy_images.video_height, lazy_images.video_width
 
     images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
     for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
         images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
     if not offload_video_to_cpu:
-        images = images.cuda()
-        img_mean = img_mean.cuda()
-        img_std = img_std.cuda()
+        images = images.to(compute_device)
+        img_mean = img_mean.to(compute_device)
+        img_std = img_std.to(compute_device)
     # normalize by mean and std
     images -= img_mean
     images /= img_std
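And a short usage sketch (paths and sizes are illustrative) of the new compute_device parameter, which defaults to CUDA for backward compatibility but can point at any backend:

```python
import torch
from sam2.utils.misc import load_video_frames

compute_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# "./videos/bedroom" stands in for a directory of "<frame_index>.jpg" files.
images, video_height, video_width = load_video_frames(
    video_path="./videos/bedroom",
    image_size=1024,                # illustrative; matches the model's input size
    offload_video_to_cpu=False,     # keep the normalized frames on the compute device
    compute_device=compute_device,
)
print(images.shape, images.device)
```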