Fewshot / Image-Conditioned Detection #28

Open · wants to merge 3 commits into `main`
16 changes: 16 additions & 0 deletions README.md
@@ -213,6 +213,22 @@ live-edited text prompts. To run the example



### Example 4 - Fewshot prediction

This example replicates the Image-Conditioned Detection example from the original OWL-ViT repository: instead of text prompts alone, each detection query is an example image paired with a short text label. Pass one ``--query-label`` per ``--query-image``, and optionally one detection threshold per query via the comma-separated ``--threshold`` flag. To run the example

```bash
cd examples
python3 fewshot_predict.py \
--threshold="0.7,0.1" \
--image_encoder_engine=../data/owl_image_encoder_patch32.engine \
--query-image ../assets/cat_query_image.jpg ../assets/frog.jpg \
--query-label "a cat" "a frog"
```

By default the output will be saved to ``data/fewshot_predict_out.jpg``.
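
For programmatic use, the same flow is available through the ``FewshotPredictor`` API added in this PR. A minimal sketch, assuming you run from the repository root and have already built the image-encoder engine (the paths below are placeholders; adjust them to your setup):

```python
import PIL.Image

from nanoowl.fewshot_predictor import FewshotPredictor
from nanoowl.owl_drawing import draw_owl_output
from nanoowl.owl_predictor import OwlPredictor

# Build the predictor; the engine path is a placeholder for the engine
# produced by the repo's engine-build step.
predictor = FewshotPredictor(
    owl_predictor=OwlPredictor(
        "google/owlvit-base-patch32",
        image_encoder_engine="data/owl_image_encoder_patch32.engine",
    )
)

# Encode a (query image, label) pair into a single query embedding.
query_embedding = predictor.encode_query_image(
    image=PIL.Image.open("assets/cat_query_image.jpg"), text="a cat"
)

# Detect matches in the target image and draw the results.
image = PIL.Image.open("assets/cat_image.jpg")
output = predictor.predict(image, [query_embedding], threshold=0.7)
draw_owl_output(image, output, text=["a cat"], draw_text=True).save("out.jpg")
```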


<a id="acknowledgement"></a>
## 👏 Acknowledgement

102 changes: 102 additions & 0 deletions examples/fewshot_predict.py
@@ -0,0 +1,102 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import os.path
import time

import numpy as np
import PIL.Image
import torch
from nanoowl.fewshot_predictor import FewshotPredictor
from nanoowl.owl_drawing import draw_owl_output
from nanoowl.owl_predictor import OwlPredictor

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--image", type=str, default="../assets/cat_image.jpg")
    parser.add_argument(
        "--query-image",
        metavar="IMAGE",
        type=str,
        nargs="+",
        help="one or more example images of what to look for in the image",
        default=["../assets/frog.jpg", "../assets/cat_query_image.jpg"],
    )
    parser.add_argument(
        "--query-label",
        metavar="LABEL",
        type=str,
        nargs="+",
        help="a text label for each query image",
        default=["a frog", "a cat"],
    )
parser.add_argument("--threshold", type=str, default="0.1,0.7")
parser.add_argument("--output", type=str, default="../data/fewshot_predict_out.jpg")
parser.add_argument("--model", type=str, default="google/owlvit-base-patch32")
parser.add_argument(
"--image_encoder_engine",
type=str,
default="../data/owl_image_encoder_patch32.engine",
)
parser.add_argument("--profile", action="store_true")
parser.add_argument("--num_profiling_runs", type=int, default=30)
args = parser.parse_args()

image = PIL.Image.open(args.image)

query_images = []
    for image_file in args.query_image:
        if not os.path.isfile(image_file):
            raise FileNotFoundError(
                f"Query image not found: {os.path.abspath(image_file)}"
            )
        query_images.append(PIL.Image.open(image_file))

    query_labels = args.query_label

    if len(query_images) != len(query_labels):
        raise ValueError("Expected one --query-label per --query-image")

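    # Parse thresholds: a single value applies to every query; a
    # comma-separated list supplies one threshold per query label.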
thresholds = args.threshold.strip("][()")
thresholds = thresholds.split(",")
if len(thresholds) == 1:
thresholds = float(thresholds[0])
else:
thresholds = [float(x) for x in thresholds]

predictor = FewshotPredictor(
owl_predictor=OwlPredictor(
args.model, image_encoder_engine=args.image_encoder_engine
)
)

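    # Encode each (query image, label) pair into a single query embedding.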
query_embeddings = [
predictor.encode_query_image(image=query_image, text=query_label)
for query_image, query_label in zip(query_images, query_labels)
]

output = predictor.predict(image, query_embeddings, threshold=thresholds)

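    # Optional profiling: rerun prediction repeatedly and report throughput (FPS).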
if args.profile:
torch.cuda.current_stream().synchronize()
t0 = time.perf_counter_ns()
for i in range(args.num_profiling_runs):
output = predictor.predict(image, query_embeddings, threshold=thresholds)
torch.cuda.current_stream().synchronize()
t1 = time.perf_counter_ns()
dt = (t1 - t0) / 1e9
print(f"PROFILING FPS: {args.num_profiling_runs/dt}")

image = draw_owl_output(image, output, text=query_labels, draw_text=True)

image.save(args.output)
166 changes: 166 additions & 0 deletions nanoowl/fewshot_predictor.py
@@ -0,0 +1,166 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Union

import PIL.Image
import torch

from .image_preprocessor import ImagePreprocessor
from .owl_predictor import (
OwlDecodeOutput,
OwlEncodeImageOutput,
OwlEncodeTextOutput,
OwlPredictor,
)


class FewshotPredictor(torch.nn.Module):
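    """Detect objects by matching image patches against example (query) images.

    ``encode_query_image`` turns a query image (plus a text label used to pick
    the most relevant patch) into an embedding; ``predict`` then scores the
    target image's patches against those query embeddings.
    """
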
def __init__(
self,
owl_predictor: Optional[OwlPredictor] = None,
image_preprocessor: Optional[ImagePreprocessor] = None,
device: str = "cuda",
):
super().__init__()
self.owl_predictor = OwlPredictor() if owl_predictor is None else owl_predictor
self.image_preprocessor = (
ImagePreprocessor().to(device).eval()
if image_preprocessor is None
else image_preprocessor
)

@torch.no_grad()
def predict(
self,
        image: PIL.Image.Image,
        query_embeddings: List[torch.Tensor],
threshold: Union[int, float, List[Union[int, float]]] = 0.1,
pad_square: bool = True,
) -> OwlDecodeOutput:
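        # Encode the full image as a single ROI, then score its patch
        # embeddings against the query embeddings (see decode).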
image_tensor = self.image_preprocessor.preprocess_pil_image(image)

rois = torch.tensor(
[[0, 0, image.width, image.height]],
dtype=image_tensor.dtype,
device=image_tensor.device,
)

image_encodings = self.owl_predictor.encode_rois(
image_tensor, rois, pad_square=pad_square
)

return self.decode(image_encodings, query_embeddings, threshold)

def decode(
self,
image_output: OwlEncodeImageOutput,
        query_embeds: List[torch.Tensor],
threshold: Union[int, float, List[Union[int, float]]] = 0.1,
) -> OwlDecodeOutput:
        num_input_images = image_output.image_class_embeds.shape[0]

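        # L2-normalize patch and query embeddings so the einsum below yields
        # cosine similarities between every image patch and every query.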
image_class_embeds = image_output.image_class_embeds
image_class_embeds = image_class_embeds / (
torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6
)

if isinstance(threshold, (int, float)):
threshold = [threshold] * len(
query_embeds
) # apply single threshold to all labels

query_embeds = torch.concat(query_embeds, dim=0)
query_embeds = query_embeds / (
torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6
)
logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds)
logits = (logits + image_output.logit_shift) * image_output.logit_scale

scores_sigmoid = torch.sigmoid(logits)
scores_max = scores_sigmoid.max(dim=-1)
labels = scores_max.indices
scores = scores_max.values
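        # Apply each query's threshold to the detections it best matches,
        # then OR the per-query masks into one final selection mask.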
masks = []
for i, thresh in enumerate(threshold):
label_mask = labels == i
score_mask = scores > thresh
obj_mask = torch.logical_and(label_mask, score_mask)
masks.append(obj_mask)
mask = masks[0]
for mask_t in masks[1:]:
mask = torch.logical_or(mask, mask_t)

input_indices = torch.arange(
0, num_input_images, dtype=labels.dtype, device=labels.device
)
input_indices = input_indices[:, None].repeat(1, self.owl_predictor.num_patches)

return OwlDecodeOutput(
labels=labels[mask],
scores=scores[mask],
boxes=image_output.pred_boxes[mask],
input_indices=input_indices[mask],
)

def encode_query_image(
self,
        image: PIL.Image.Image,
text: str,
pad_square: bool = True,
) -> torch.Tensor:
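        # Encode the query image and use the text label to select the single
        # patch embedding that best matches it (see find_best_encoding).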
image_tensor = self.image_preprocessor.preprocess_pil_image(image)

text_encodings = self.encode_text([text])

rois = torch.tensor(
[[0, 0, image.width, image.height]],
dtype=image_tensor.dtype,
device=image_tensor.device,
)

image_encodings = self.owl_predictor.encode_rois(
image_tensor, rois, pad_square=pad_square
)

return self.find_best_encoding(image_encodings, text_encodings)

    def encode_text(self, text: List[str]) -> OwlEncodeTextOutput:
return self.owl_predictor.encode_text(text)

@staticmethod
def find_best_encoding(
image_output: OwlEncodeImageOutput,
text_output: OwlEncodeTextOutput,
) -> torch.Tensor:
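        # Score every patch embedding against the text embedding (cosine
        # similarity, OWL-ViT logit shift/scale, sigmoid) and return the
        # patch embedding with the highest score.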
image_class_embeds = image_output.image_class_embeds
image_class_embeds = image_class_embeds / (
torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6
)
query_embeds = text_output.text_embeds
query_embeds = query_embeds / (
torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6
)
logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds)
logits = (logits + image_output.logit_shift) * image_output.logit_scale

scores_sigmoid = torch.sigmoid(logits)
scores_max = scores_sigmoid.max(dim=-1)
scores = scores_max.values
best = torch.argmax(scores).item()
best_embed = image_class_embeds[:, best]
return best_embed
64 changes: 64 additions & 0 deletions test/test_fewshot_predictor.py
@@ -0,0 +1,64 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import PIL.Image

from nanoowl.fewshot_predictor import FewshotPredictor
from nanoowl.owl_predictor import OwlDecodeOutput


def test_encode_query_images():
predictor = FewshotPredictor()

query_image = PIL.Image.open("assets/frog.jpg")

query_encoding = predictor.encode_query_image(query_image, "a frog")

assert len(query_encoding.shape) == 2
assert query_encoding.shape[0] == 1
assert query_encoding.shape[1] == 512


def test_encode_labels():
predictor = FewshotPredictor()

labels = ["a frog", "an owl", "mice", "405943069245", ""]

text_encodings = predictor.encode_text(labels).text_embeds

assert len(text_encodings.shape) == 2
assert text_encodings.shape[0] == len(labels)
assert text_encodings.shape[1] == 512


def test_fewshot_predictor_predict():
predictor = FewshotPredictor()

image = PIL.Image.open("assets/cat_query_image.jpg")

query_image = PIL.Image.open("assets/cat_image.jpg")

query_label = "a cat"

thresholds = 0.7

query_embedding = predictor.encode_query_image(image=query_image, text=query_label)

detections = predictor.predict(image, [query_embedding], threshold=thresholds)

    assert isinstance(detections, OwlDecodeOutput)