Add example bbox input usage to EfficientViT-SAM
healthonrails committed Feb 8, 2024
1 parent 60dae3d commit d781580
Showing 1 changed file with 61 additions and 53 deletions.
114 changes: 61 additions & 53 deletions annolid/segmentation/SAM/efficientvit_sam.py
@@ -37,6 +37,12 @@
--return-single-mask
```
Example usage for bboxes input:
```
python efficientvit_sam.py --model xl1 --encoder_model xl1_encoder.onnx \
--decoder_model xl1_decoder.onnx --mode boxes \
--boxes "[[16,8,220,180],[230,190,440,400]]"
```
"""


@@ -119,6 +125,56 @@ def show_box(box, ax):
facecolor=(0, 0, 0, 0), lw=2))


def preprocess(x, img_size):
"""
Preprocess the input image.
"""
pixel_mean = [123.675 / 255, 116.28 / 255, 103.53 / 255]
pixel_std = [58.395 / 255, 57.12 / 255, 57.375 / 255]

x = torch.tensor(x)
resize_transform = SamResize(img_size)
x = resize_transform(x).float() / 255
x = transforms.Normalize(mean=pixel_mean, std=pixel_std)(x)

h, w = x.shape[-2:]
th, tw = img_size, img_size
assert th >= h and tw >= w
x = F.pad(x, (0, tw - w, 0, th - h), value=0).unsqueeze(0).numpy()

return x
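A quick shape check for this helper (a sketch only; it assumes `SamResize` accepts an HxWxC uint8 array and returns a CxHxW tensor with the longest side scaled to `img_size`, which is how the rest of the function treats its output):

```python
import numpy as np

# Hypothetical 480x640 RGB frame.
frame = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)

batch = preprocess(frame, img_size=1024)
# The long side (640) scales to 1024 and the short side to 768; the result is
# zero-padded on the bottom/right to a square, single-image batch.
assert batch.shape == (1, 3, 1024, 1024)
assert batch.dtype == np.float32
```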


def resize_longest_image_size(input_image_size: torch.Tensor,
longest_side: int) -> torch.Tensor:
input_image_size = input_image_size.to(torch.float32)
scale = longest_side / torch.max(input_image_size)
transformed_size = scale * input_image_size
transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
return transformed_size
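For example, a small check of the scaling and rounding behaviour:

```python
import torch

# A 480x640 image whose longest side should become 1024:
# scale = 1024 / 640 = 1.6, so (480, 640) -> (768, 1024).
size = resize_longest_image_size(torch.tensor([480, 640]), longest_side=1024)
assert size.tolist() == [768, 1024]
```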


def mask_postprocessing(masks: torch.Tensor,
orig_im_size: torch.Tensor) -> torch.Tensor:
img_size = 1024
masks = torch.tensor(masks)
orig_im_size = torch.tensor(orig_im_size)
masks = F.interpolate(
masks,
size=(img_size, img_size),
mode="bilinear",
align_corners=False,
)

prepadded_size = resize_longest_image_size(orig_im_size, img_size)
masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])]
orig_im_size = orig_im_size.to(torch.int64)
h, w = orig_im_size[0], orig_im_size[1]
masks = F.interpolate(masks, size=(
h, w), mode="bilinear", align_corners=False)
return masks
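And a shape-level sketch of the whole postprocessing path (assuming the decoder emits 256x256 low-resolution mask logits, which is typical for SAM-style decoders but not shown in this hunk):

```python
import numpy as np

# Hypothetical low-resolution logits for one mask of one image.
low_res_masks = np.zeros((1, 1, 256, 256), dtype=np.float32)

# Upsample to 1024x1024, crop away the padded region, then resize to the frame size.
full_masks = mask_postprocessing(low_res_masks, orig_im_size=(480, 640))
assert tuple(full_masks.shape) == (1, 1, 480, 640)
```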


class SamEncoder:
"""
Encoder for EfficientViTSAM model.
@@ -241,56 +297,6 @@ def apply_boxes(self, boxes, original_size, new_size):
return boxes


[Removed from this location: the original definitions of preprocess, resize_longest_image_size, and mask_postprocessing, identical to the copies added above SamEncoder earlier in this diff.]


class EfficientViTSAM:
"""
EfficientViTSAM model for image segmentation.
@@ -340,7 +346,7 @@ def run_inference(self, cv_image, bboxes, point=None):
point_coords=point_coords,
point_labels=point_labels,
)
return masks.squeeze().cpu().numpy()
return masks.cpu().numpy()

elif self.mode == "boxes":
boxes = np.array(bboxes, dtype=np.float32)
@@ -349,7 +355,7 @@
origin_image_size=origin_image_size,
boxes=boxes,
)
return masks.squeeze().cpu().numpy()
return masks.cpu().numpy()
else:
return []

@@ -377,6 +383,7 @@ def run_inference(self, cv_image, bboxes, point=None):

encoder = SamEncoder(model_path=args.encoder_model)
decoder = SamDecoder(model_path=args.decoder_model)
eff_sam = EfficientViTSAM()

raw_img = cv2.cvtColor(cv2.imread(args.img_path), cv2.COLOR_BGR2RGB)
origin_image_size = raw_img.shape[:2]
@@ -429,6 +436,7 @@ def run_inference(self, cv_image, bboxes, point=None):
plt.savefig(args.out_path, bbox_inches="tight",
dpi=300, pad_inches=0.0)
print(f"Result saved in {args.out_path}")

_masks = eff_sam.run_inference(raw_img, boxes)
print(_masks.shape, _masks)
else:
raise NotImplementedError
