From 4b4a86a8560b36fc2bdb5aae3c5d2f941897ad33 Mon Sep 17 00:00:00 2001 From: Aaron Date: Sun, 24 Mar 2024 13:34:42 -0400 Subject: [PATCH 01/14] add fewshot predictor with tests and cli --- examples/fewshot_predict.py | 102 ++++++++++++++++++++ nanoowl/fewshot_predictor.py | 166 +++++++++++++++++++++++++++++++++ test/test_fewshot_predictor.py | 64 +++++++++++++ 3 files changed, 332 insertions(+) create mode 100644 examples/fewshot_predict.py create mode 100644 nanoowl/fewshot_predictor.py create mode 100644 test/test_fewshot_predictor.py diff --git a/examples/fewshot_predict.py b/examples/fewshot_predict.py new file mode 100644 index 0000000..5fe6429 --- /dev/null +++ b/examples/fewshot_predict.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import os.path +import time + +import numpy as np +import PIL.Image +import torch +from nanoowl.fewshot_predictor import FewshotPredictor +from nanoowl.owl_drawing import draw_owl_output +from nanoowl.owl_predictor import OwlPredictor + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--image", type=str, default="../assets/cat_image.jpg") + parser.add_argument( + "--query-image", + metavar="N", + type=str, + nargs="+", + help="an example of what to look for in the image", + default=["../assets/frog.jpg", "../assets/cat_query_image.jpg"], + ) + parser.add_argument( + "--query-label", + metavar="N", + type=str, + nargs="+", + help="a text label for each query image", + default=["a frog", "a cat"], + ) + parser.add_argument("--threshold", type=str, default="0.1,0.7") + parser.add_argument("--output", type=str, default="../data/fewshot_predict_out.jpg") + parser.add_argument("--model", type=str, default="google/owlvit-base-patch32") + parser.add_argument( + "--image_encoder_engine", + type=str, + default="../data/owl_image_encoder_patch32.engine", + ) + parser.add_argument("--profile", action="store_true") + parser.add_argument("--num_profiling_runs", type=int, default=30) + args = parser.parse_args() + + image = PIL.Image.open(args.image) + + query_images = [] + for image_file in args.query_image: + if not os.path.isfile(image_file): + raise FileNotFoundError(f"File missing from {os.path.abspath(image_file)}") + else: + query_images.append(PIL.Image.open(image_file)) + + query_labels = args.query_label + + thresholds = args.threshold.strip("][()") + thresholds = thresholds.split(",") + if len(thresholds) == 1: + thresholds = float(thresholds[0]) + else: + thresholds = [float(x) for x in thresholds] + + predictor = FewshotPredictor( + owl_predictor=OwlPredictor( + args.model, image_encoder_engine=args.image_encoder_engine + ) + ) + + query_embeddings = [ + predictor.encode_query_image(image=query_image, text=query_labels) + for query_image in query_images + ] + + output = predictor.predict(image, query_embeddings, threshold=thresholds) + + if 
args.profile: + torch.cuda.current_stream().synchronize() + t0 = time.perf_counter_ns() + for i in range(args.num_profiling_runs): + output = predictor.predict(image, query_embeddings, threshold=thresholds) + torch.cuda.current_stream().synchronize() + t1 = time.perf_counter_ns() + dt = (t1 - t0) / 1e9 + print(f"PROFILING FPS: {args.num_profiling_runs/dt}") + + image = draw_owl_output(image, output, text=query_labels, draw_text=True) + + image.save(args.output) diff --git a/nanoowl/fewshot_predictor.py b/nanoowl/fewshot_predictor.py new file mode 100644 index 0000000..ccc8f9b --- /dev/null +++ b/nanoowl/fewshot_predictor.py @@ -0,0 +1,166 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Union + +import PIL.Image +import torch + +from .image_preprocessor import ImagePreprocessor +from .owl_predictor import ( + OwlDecodeOutput, + OwlEncodeImageOutput, + OwlEncodeTextOutput, + OwlPredictor, +) + + +class FewshotPredictor(torch.nn.Module): + def __init__( + self, + owl_predictor: Optional[OwlPredictor] = None, + image_preprocessor: Optional[ImagePreprocessor] = None, + device: str = "cuda", + ): + super().__init__() + self.owl_predictor = OwlPredictor() if owl_predictor is None else owl_predictor + self.image_preprocessor = ( + ImagePreprocessor().to(device).eval() + if image_preprocessor is None + else image_preprocessor + ) + + @torch.no_grad() + def predict( + self, + image: PIL.Image, + query_embeddings: List, + threshold: Union[int, float, List[Union[int, float]]] = 0.1, + pad_square: bool = True, + ) -> OwlDecodeOutput: + image_tensor = self.image_preprocessor.preprocess_pil_image(image) + + rois = torch.tensor( + [[0, 0, image.width, image.height]], + dtype=image_tensor.dtype, + device=image_tensor.device, + ) + + image_encodings = self.owl_predictor.encode_rois( + image_tensor, rois, pad_square=pad_square + ) + + return self.decode(image_encodings, query_embeddings, threshold) + + def decode( + self, + image_output: OwlEncodeImageOutput, + query_embeds, + threshold: Union[int, float, List[Union[int, float]]] = 0.1, + ) -> OwlDecodeOutput: + num_input_images = image_output.image_class_embeds.shape[0] + print(f"{num_input_images=}") + + image_class_embeds = image_output.image_class_embeds + image_class_embeds = image_class_embeds / ( + torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6 + ) + + if isinstance(threshold, (int, float)): + threshold = [threshold] * len( + query_embeds + ) # apply single threshold to all labels + + query_embeds = torch.concat(query_embeds, dim=0) + query_embeds = query_embeds / ( + torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6 + ) + logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds) + logits = (logits + image_output.logit_shift) * image_output.logit_scale + + scores_sigmoid = torch.sigmoid(logits) + scores_max = 
scores_sigmoid.max(dim=-1) + labels = scores_max.indices + scores = scores_max.values + masks = [] + for i, thresh in enumerate(threshold): + label_mask = labels == i + score_mask = scores > thresh + obj_mask = torch.logical_and(label_mask, score_mask) + masks.append(obj_mask) + mask = masks[0] + for mask_t in masks[1:]: + mask = torch.logical_or(mask, mask_t) + + input_indices = torch.arange( + 0, num_input_images, dtype=labels.dtype, device=labels.device + ) + input_indices = input_indices[:, None].repeat(1, self.owl_predictor.num_patches) + + return OwlDecodeOutput( + labels=labels[mask], + scores=scores[mask], + boxes=image_output.pred_boxes[mask], + input_indices=input_indices[mask], + ) + + def encode_query_image( + self, + image: PIL.Image, + text: str, + pad_square: bool = True, + ) -> torch.Tensor: + image_tensor = self.image_preprocessor.preprocess_pil_image(image) + + text_encodings = self.encode_text([text]) + + rois = torch.tensor( + [[0, 0, image.width, image.height]], + dtype=image_tensor.dtype, + device=image_tensor.device, + ) + + image_encodings = self.owl_predictor.encode_rois( + image_tensor, rois, pad_square=pad_square + ) + + return self.find_best_encoding(image_encodings, text_encodings) + + def encode_text(self, text) -> OwlEncodeTextOutput: + return self.owl_predictor.encode_text(text) + + @staticmethod + def find_best_encoding( + image_output: OwlEncodeImageOutput, + text_output: OwlEncodeTextOutput, + ) -> torch.Tensor: + image_class_embeds = image_output.image_class_embeds + image_class_embeds = image_class_embeds / ( + torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6 + ) + query_embeds = text_output.text_embeds + query_embeds = query_embeds / ( + torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6 + ) + logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds) + logits = (logits + image_output.logit_shift) * image_output.logit_scale + + scores_sigmoid = torch.sigmoid(logits) + scores_max = scores_sigmoid.max(dim=-1) + scores = scores_max.values + best = torch.argmax(scores).item() + best_embed = image_class_embeds[:, best] + return best_embed diff --git a/test/test_fewshot_predictor.py b/test/test_fewshot_predictor.py new file mode 100644 index 0000000..4571fdb --- /dev/null +++ b/test/test_fewshot_predictor.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+import PIL.Image
+import pytest
+import torch
+from nanoowl.fewshot_predictor import FewshotPredictor
+from nanoowl.tree import Tree
+from nanoowl.tree_predictor import TreePredictor
+
+
+def test_encode_query_images():
+    predictor = FewshotPredictor()
+
+    query_image = PIL.Image.open("assets/frog.jpg")
+
+    query_encoding = predictor.encode_query_image(query_image, "a frog")
+
+    assert len(query_encoding.shape) == 2
+    assert query_encoding.shape[0] == 1
+    assert query_encoding.shape[1] == 512
+
+
+def test_encode_labels():
+    predictor = FewshotPredictor()
+
+    labels = ["a frog", "an owl", "mice", "405943069245", ""]
+
+    text_encodings = predictor.encode_text(labels).text_embeds
+
+    assert len(text_encodings.shape) == 2
+    assert text_encodings.shape[0] == len(labels)
+    assert text_encodings.shape[1] == 512
+
+
+def test_fewshot_predictor_predict():
+    predictor = FewshotPredictor()
+
+    image = PIL.Image.open("assets/cat_query_image.jpg")
+
+    query_image = PIL.Image.open("assets/cat_image.jpg")
+
+    query_label = "a cat"
+
+    thresholds = 0.7
+
+    query_embedding = predictor.encode_query_image(image=query_image, text=query_label)
+
+    detections = predictor.predict(image, [query_embedding], threshold=thresholds)
+
+    print(detections)

From 91a401c21c228e0620de6c198edb7ee371853c4c Mon Sep 17 00:00:00 2001
From: Aaron
Date: Mon, 25 Mar 2024 10:16:11 -0400
Subject: [PATCH 02/14] fix fewshot_predict script to assign one label per
 query image

---
 examples/fewshot_predict.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/fewshot_predict.py b/examples/fewshot_predict.py
index 5fe6429..77b1360 100644
--- a/examples/fewshot_predict.py
+++ b/examples/fewshot_predict.py
@@ -81,8 +81,8 @@
     )
 
     query_embeddings = [
-        predictor.encode_query_image(image=query_image, text=query_labels)
-        for query_image in query_images
+        predictor.encode_query_image(image=query_image, text=query_label)
+        for query_image, query_label in zip(query_images, query_labels)
     ]
 
     output = predictor.predict(image, query_embeddings, threshold=thresholds)

From 0a1b765710d423d728fcd2bb59a1c9a245bed002 Mon Sep 17 00:00:00 2001
From: Aaron
Date: Tue, 26 Mar 2024 10:56:10 -0400
Subject: [PATCH 03/14] add usage example to readme

---
 README.md | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/README.md b/README.md
index d86482c..da7bf06 100644
--- a/README.md
+++ b/README.md
@@ -213,6 +213,22 @@
 live-edited text prompts. To run the example
 
+### Example 4 - Few-shot prediction
+
+This example replicates the Image-Conditioned Detection example in the original OWL-ViT repository. To run the example
+
+    ```bash
+    cd examples
+    python3 fewshot_predict.py \
+        --threshold="0.7,0.1" \
+        --image_encoder_engine=../data/owl_image_encoder_patch32.engine \
+        --query-image ../assets/cat_query_image.jpg ../assets/frog.jpg \
+        --query-label "a cat" "a frog"
+    ```
+
+By default, the output will be saved to ``data/fewshot_predict_out.jpg``.
+ + ## 👏 Acknowledgement From 1df24da87701d5d086459c3c0b694cf752eb7e68 Mon Sep 17 00:00:00 2001 From: manuel cuevas Date: Mon, 29 Jul 2024 07:03:59 -0700 Subject: [PATCH 04/14] added owlv2 support --- nanoowl/owl_predictor.py | 163 ++++++++++++++++++++++----------------- 1 file changed, 91 insertions(+), 72 deletions(-) diff --git a/nanoowl/owl_predictor.py b/nanoowl/owl_predictor.py index 1afb897..3fae698 100644 --- a/nanoowl/owl_predictor.py +++ b/nanoowl/owl_predictor.py @@ -21,6 +21,9 @@ import tempfile import os from torchvision.ops import roi_align +from transformers.models.owlv2.modeling_owlv2 import Owlv2ForObjectDetection +from transformers.models.owlv2.processing_owlv2 import Owlv2Processor + from transformers.models.owlvit.modeling_owlvit import OwlViTForObjectDetection from transformers.models.owlvit.processing_owlvit import OwlViTProcessor from dataclasses import dataclass @@ -39,9 +42,9 @@ def _owl_center_to_corners_format_torch(bboxes_center): center_x, center_y, width, height = bboxes_center.unbind(-1) bbox_corners = torch.stack( [ - (center_x - 0.5 * width), - (center_y - 0.5 * height), - (center_x + 0.5 * width), + (center_x - 0.5 * width), + (center_y - 0.5 * height), + (center_x + 0.5 * width), (center_y + 0.5 * height) ], dim=-1, @@ -50,11 +53,12 @@ def _owl_center_to_corners_format_torch(bboxes_center): def _owl_get_image_size(hf_name: str): - image_sizes = { "google/owlvit-base-patch32": 768, "google/owlvit-base-patch16": 768, "google/owlvit-large-patch14": 840, + "google/owlv2-base-patch16-ensemble": 960, + "google/owlv2-large-patch14-ensemble": 1008, } return image_sizes[hf_name] @@ -66,6 +70,8 @@ def _owl_get_patch_size(hf_name: str): "google/owlvit-base-patch32": 32, "google/owlvit-base-patch16": 16, "google/owlvit-large-patch14": 14, + "google/owlv2-base-patch16-ensemble": 16, + "google/owlv2-large-patch14-ensemble": 14, } return patch_sizes[hf_name] @@ -141,25 +147,35 @@ class OwlDecodeOutput: class OwlPredictor(torch.nn.Module): - + def __init__(self, - model_name: str = "google/owlvit-base-patch32", - device: str = "cuda", - image_encoder_engine: Optional[str] = None, - image_encoder_engine_max_batch_size: int = 1, - image_preprocessor: Optional[ImagePreprocessor] = None - ): + model_name: str = "google/owlvit-base-patch32", + device: str = "cuda", + image_encoder_engine: Optional[str] = None, + image_encoder_engine_max_batch_size: int = 1, + image_preprocessor: Optional[ImagePreprocessor] = None + ): super().__init__() self.image_size = _owl_get_image_size(model_name) self.device = device - self.model = OwlViTForObjectDetection.from_pretrained(model_name).to(self.device).eval() - self.processor = OwlViTProcessor.from_pretrained(model_name) + + model_type = model_name.split("/")[1].split('-')[0] + if model_type == 'owlv2': + self.model = Owlv2ForObjectDetection.from_pretrained(model_name).to(self.device).eval() + self.processor = Owlv2Processor.from_pretrained(model_name) + self.base_model = self.model.owlv2 + + else: + self.model = OwlViTForObjectDetection.from_pretrained(model_name).to(self.device).eval() + self.processor = OwlViTProcessor.from_pretrained(model_name) + self.base_model = self.model.owlvit + self.patch_size = _owl_get_patch_size(model_name) self.num_patches_per_side = self.image_size // self.patch_size self.box_bias = _owl_compute_box_bias(self.num_patches_per_side).to(self.device) - self.num_patches = (self.num_patches_per_side)**2 + self.num_patches = (self.num_patches_per_side) ** 2 self.mesh_grid = torch.stack( torch.meshgrid( 
torch.linspace(0., 1., self.image_size), @@ -168,33 +184,35 @@ def __init__(self, ).to(self.device).float() self.image_encoder_engine = None if image_encoder_engine is not None: - image_encoder_engine = OwlPredictor.load_image_encoder_engine(image_encoder_engine, image_encoder_engine_max_batch_size) + image_encoder_engine = OwlPredictor.load_image_encoder_engine(image_encoder_engine, + image_encoder_engine_max_batch_size) self.image_encoder_engine = image_encoder_engine - self.image_preprocessor = image_preprocessor.to(self.device).eval() if image_preprocessor else ImagePreprocessor().to(self.device).eval() + self.image_preprocessor = image_preprocessor.to( + self.device).eval() if image_preprocessor else ImagePreprocessor().to(self.device).eval() def get_num_patches(self): return self.num_patches def get_device(self): return self.device - + def get_image_size(self): return (self.image_size, self.image_size) - + def encode_text(self, text: List[str]) -> OwlEncodeTextOutput: text_input = self.processor(text=text, return_tensors="pt") input_ids = text_input['input_ids'].to(self.device) attention_mask = text_input['attention_mask'].to(self.device) - text_outputs = self.model.owlvit.text_model(input_ids, attention_mask) + text_outputs = self.base_model.text_model(input_ids, attention_mask) text_embeds = text_outputs[1] - text_embeds = self.model.owlvit.text_projection(text_embeds) + text_embeds = self.base_model.text_projection(text_embeds) return OwlEncodeTextOutput(text_embeds=text_embeds) def encode_image_torch(self, image: torch.Tensor) -> OwlEncodeImageOutput: - - vision_outputs = self.model.owlvit.vision_model(image) + + vision_outputs = self.base_model.vision_model(image) last_hidden_state = vision_outputs[0] - image_embeds = self.model.owlvit.vision_model.post_layernorm(last_hidden_state) + image_embeds = self.base_model.vision_model.post_layernorm(last_hidden_state) class_token_out = image_embeds[:, :1, :] image_embeds = image_embeds[:, 1:, :] * class_token_out image_embeds = self.model.layer_norm(image_embeds) # 768 dim @@ -220,7 +238,7 @@ def encode_image_torch(self, image: torch.Tensor) -> OwlEncodeImageOutput: ) return output - + def encode_image_trt(self, image: torch.Tensor) -> OwlEncodeImageOutput: return self.image_encoder_engine(image) @@ -230,7 +248,8 @@ def encode_image(self, image: torch.Tensor) -> OwlEncodeImageOutput: else: return self.encode_image_torch(image) - def extract_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool = True, padding_scale: float = 1.0): + def extract_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool = True, + padding_scale: float = 1.0): if len(rois) == 0: return torch.empty( (0, image.shape[1], self.image_size, self.image_size), @@ -244,13 +263,15 @@ def extract_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool cx = (rois[..., 0] + rois[..., 2]) / 2 cy = (rois[..., 1] + rois[..., 3]) / 2 s = torch.max(w, h) - rois = torch.stack([cx-s, cy-s, cx+s, cy+s], dim=-1) + rois = torch.stack([cx - s, cy - s, cx + s, cy + s], dim=-1) # compute mask pad_x = (s - w) / (2 * s) pad_y = (s - h) / (2 * s) - mask_x = (self.mesh_grid[1][None, ...] > pad_x[..., None, None]) & (self.mesh_grid[1][None, ...] < (1. - pad_x[..., None, None])) - mask_y = (self.mesh_grid[0][None, ...] > pad_y[..., None, None]) & (self.mesh_grid[0][None, ...] < (1. - pad_y[..., None, None])) + mask_x = (self.mesh_grid[1][None, ...] > pad_x[..., None, None]) & ( + self.mesh_grid[1][None, ...] < (1. 
- pad_x[..., None, None])) + mask_y = (self.mesh_grid[0][None, ...] > pad_y[..., None, None]) & ( + self.mesh_grid[0][None, ...] < (1. - pad_y[..., None, None])) mask = (mask_x & mask_y) # extract rois @@ -261,8 +282,8 @@ def extract_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool roi_images = (roi_images * mask[:, None, :, :]) return roi_images, rois - - def encode_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool = True, padding_scale: float=1.0): + + def encode_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool = True, padding_scale: float = 1.0): # with torch_timeit_sync("extract rois"): roi_images, rois = self.extract_rois(image, rois, pad_square, padding_scale) # with torch_timeit_sync("encode images"): @@ -271,14 +292,14 @@ def encode_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool output.pred_boxes = pred_boxes return output - def decode(self, - image_output: OwlEncodeImageOutput, - text_output: OwlEncodeTextOutput, - threshold: Union[int, float, List[Union[int, float]]] = 0.1, - ) -> OwlDecodeOutput: + def decode(self, + image_output: OwlEncodeImageOutput, + text_output: OwlEncodeTextOutput, + threshold: Union[int, float, List[Union[int, float]]] = 0.1, + ) -> OwlDecodeOutput: if isinstance(threshold, (int, float)): - threshold = [threshold] * len(text_output.text_embeds) #apply single threshold to all labels + threshold = [threshold] * len(text_output.text_embeds) # apply single threshold to all labels num_input_images = image_output.image_class_embeds.shape[0] @@ -288,7 +309,7 @@ def decode(self, query_embeds = query_embeds / (torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6) logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds) logits = (logits + image_output.logit_shift) * image_output.logit_scale - + scores_sigmoid = torch.sigmoid(logits) scores_max = scores_sigmoid.max(dim=-1) labels = scores_max.indices @@ -297,9 +318,9 @@ def decode(self, for i, thresh in enumerate(threshold): label_mask = labels == i score_mask = scores > thresh - obj_mask = torch.logical_and(label_mask,score_mask) - masks.append(obj_mask) - + obj_mask = torch.logical_and(label_mask, score_mask) + masks.append(obj_mask) + mask = masks[0] for mask_t in masks[1:]: mask = torch.logical_or(mask, mask_t) @@ -329,18 +350,18 @@ def get_image_encoder_output_names(): ] return names + def export_image_encoder_onnx(self, + output_path: str, + use_dynamic_axes: bool = True, + batch_size: int = 1, + onnx_opset=17 + ): - def export_image_encoder_onnx(self, - output_path: str, - use_dynamic_axes: bool = True, - batch_size: int = 1, - onnx_opset=17 - ): - class TempModule(torch.nn.Module): def __init__(self, parent): super().__init__() self.parent = parent + def forward(self, image): output = self.parent.encode_image_torch(image) return ( @@ -354,13 +375,13 @@ def forward(self, image): data = torch.randn(batch_size, 3, self.image_size, self.image_size).to(self.device) if use_dynamic_axes: - dynamic_axes = { + dynamic_axes = { "image": {0: "batch"}, "image_embeds": {0: "batch"}, "image_class_embeds": {0: "batch"}, "logit_shift": {0: "batch"}, "logit_scale": {0: "batch"}, - "pred_boxes": {0: "batch"} + "pred_boxes": {0: "batch"} } else: dynamic_axes = {} @@ -368,15 +389,15 @@ def forward(self, image): model = TempModule(self) torch.onnx.export( - model, - data, - output_path, - input_names=self.get_image_encoder_input_names(), + model, + data, + output_path, + input_names=self.get_image_encoder_input_names(), 
output_names=self.get_image_encoder_output_names(), dynamic_axes=dynamic_axes, opset_version=onnx_opset ) - + @staticmethod def load_image_encoder_engine(engine_path: str, max_batch_size: int = 1): import tensorrt as trt @@ -401,7 +422,6 @@ def __init__(self, base_module: TRTModule, max_batch_size: int): @torch.no_grad() def forward(self, image): - b = image.shape[0] results = [] @@ -427,13 +447,13 @@ def forward(self, image): return image_encoder - def build_image_encoder_engine(self, - engine_path: str, - max_batch_size: int = 1, - fp16_mode = True, - onnx_path: Optional[str] = None, - onnx_opset: int = 17 - ): + def build_image_encoder_engine(self, + engine_path: str, + max_batch_size: int = 1, + fp16_mode=True, + onnx_path: Optional[str] = None, + onnx_opset: int = 17 + ): if onnx_path is None: onnx_dir = tempfile.mkdtemp() @@ -441,7 +461,7 @@ def build_image_encoder_engine(self, self.export_image_encoder_onnx(onnx_path, onnx_opset=onnx_opset) args = ["/usr/src/tensorrt/bin/trtexec"] - + args.append(f"--onnx={onnx_path}") args.append(f"--saveEngine={engine_path}") @@ -454,14 +474,14 @@ def build_image_encoder_engine(self, return self.load_image_encoder_engine(engine_path, max_batch_size) - def predict(self, - image: PIL.Image, - text: List[str], - text_encodings: Optional[OwlEncodeTextOutput], - threshold: Union[int, float, List[Union[int, float]]] = 0.1, - pad_square: bool = True, - - ) -> OwlDecodeOutput: + def predict(self, + image: PIL.Image, + text: List[str], + text_encodings: Optional[OwlEncodeTextOutput], + threshold: Union[int, float, List[Union[int, float]]] = 0.1, + pad_square: bool = True, + + ) -> OwlDecodeOutput: image_tensor = self.image_preprocessor.preprocess_pil_image(image) @@ -473,4 +493,3 @@ def predict(self, image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square) return self.decode(image_encodings, text_encodings, threshold) - From 91baacd619bcf8c15bff1a0e985fc074bd7a202b Mon Sep 17 00:00:00 2001 From: manuel cuevas Date: Mon, 29 Jul 2024 11:59:32 -0700 Subject: [PATCH 05/14] align_rois argument added --- nanoowl/build_image_encoder_engine.py | 4 +++- nanoowl/owl_predictor.py | 15 ++++++++++++--- setup.py | 2 +- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/nanoowl/build_image_encoder_engine.py b/nanoowl/build_image_encoder_engine.py index 6c5910b..8bcdceb 100644 --- a/nanoowl/build_image_encoder_engine.py +++ b/nanoowl/build_image_encoder_engine.py @@ -25,10 +25,12 @@ parser.add_argument("--model_name", type=str, default="google/owlvit-base-patch32") parser.add_argument("--fp16_mode", type=bool, default=True) parser.add_argument("--onnx_opset", type=int, default=16) + parser.add_argument("--align_rois", type=bool, default=True) args = parser.parse_args() predictor = OwlPredictor( - model_name=args.model_name + model_name=args.model_name, + align_rois =args.align_rois, ) predictor.build_image_encoder_engine( diff --git a/nanoowl/owl_predictor.py b/nanoowl/owl_predictor.py index 3fae698..594e7c2 100644 --- a/nanoowl/owl_predictor.py +++ b/nanoowl/owl_predictor.py @@ -65,7 +65,6 @@ def _owl_get_image_size(hf_name: str): def _owl_get_patch_size(hf_name: str): - patch_sizes = { "google/owlvit-base-patch32": 32, "google/owlvit-base-patch16": 16, @@ -153,11 +152,13 @@ def __init__(self, device: str = "cuda", image_encoder_engine: Optional[str] = None, image_encoder_engine_max_batch_size: int = 1, - image_preprocessor: Optional[ImagePreprocessor] = None + image_preprocessor: Optional[ImagePreprocessor] = None, + align_rois=True, ): 
         super().__init__()
+
+        self.align_rois = align_rois
         self.image_size = _owl_get_image_size(model_name)
         self.device = device
 
@@ -275,7 +276,15 @@
         mask = (mask_x & mask_y)
 
         # extract rois
-        roi_images = roi_align(image, [rois], output_size=self.get_image_size())
+        if self.align_rois:
+            roi_images = roi_align(image, [rois], output_size=self.get_image_size())
+        else:
+            # Crop each region from the image tensor (BxCxHxW) and resize it
+            roi_images = torch.cat([
+                torch.nn.functional.interpolate(
+                    image[..., max(int(r[1]), 0):int(r[3]), max(int(r[0]), 0):int(r[2])],
+                    size=self.get_image_size(), mode="bilinear",
+                ) for r in rois], dim=0)
 
         # mask rois
         if pad_square:
diff --git a/setup.py b/setup.py
index 27230b0..a7de8dd 100644
--- a/setup.py
+++ b/setup.py
@@ -3,6 +3,6 @@
 
 setup(
     name="nanoowl",
-    version="0.0.0",
+    version="0.0.1",
     packages=find_packages()
 )
\ No newline at end of file

From 967fbd392a1ddb48a7d789200511d1910bfd7743 Mon Sep 17 00:00:00 2001
From: manuel cuevas
Date: Tue, 30 Jul 2024 12:04:22 -0700
Subject: [PATCH 06/14] setup version update

---
 setup.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index a7de8dd..bba1abe 100644
--- a/setup.py
+++ b/setup.py
@@ -3,6 +3,9 @@
 
 setup(
     name="nanoowl",
-    version="0.0.1",
+    version="0.0.2",
+    description='NanoOWL is a project that optimizes OWL-ViT to run '
+                '🔥 real-time 🔥 on NVIDIA Jetson Orin Platforms with '
+                'NVIDIA TensorRT',
     packages=find_packages()
 )
\ No newline at end of file

From 403e20cabd1d572c5c34bc6e5c5e65214c6117aa Mon Sep 17 00:00:00 2001
From: Aaron
Date: Fri, 2 Aug 2024 09:44:54 -0400
Subject: [PATCH 07/14] support cpu if cuda not available

---
 nanoowl/fewshot_predictor.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/nanoowl/fewshot_predictor.py b/nanoowl/fewshot_predictor.py
index ccc8f9b..23f7ca1 100644
--- a/nanoowl/fewshot_predictor.py
+++ b/nanoowl/fewshot_predictor.py
@@ -33,10 +33,13 @@ def __init__(
         self,
         owl_predictor: Optional[OwlPredictor] = None,
         image_preprocessor: Optional[ImagePreprocessor] = None,
-        device: str = "cuda",
+        device: Optional[str] = None,
     ):
         super().__init__()
-        self.owl_predictor = OwlPredictor() if owl_predictor is None else owl_predictor
+        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.owl_predictor = (
+            OwlPredictor(device=device) if owl_predictor is None else owl_predictor
+        )
         self.image_preprocessor = (
             ImagePreprocessor().to(device).eval()
             if image_preprocessor is None
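A quick way to exercise the CPU fallback added in the patch above — a minimal sketch, assuming the series is applied up to this point and the Hugging Face weights can be fetched; nothing here is part of the series itself:

```python
import torch

from nanoowl.fewshot_predictor import FewshotPredictor

# With no explicit device, the predictor picks "cuda" when available
# and falls back to "cpu" otherwise (the behavior this patch adds).
predictor = FewshotPredictor()

# An explicitly requested device is honored either way, which is what
# CI runners without a GPU rely on.
cpu_predictor = FewshotPredictor(device="cpu")

assert predictor.owl_predictor.device == ("cuda" if torch.cuda.is_available() else "cpu")
assert cpu_predictor.owl_predictor.device == "cpu"
```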
From 5f212114085b3532cc0b96ed1abfc1cae67038d4 Mon Sep 17 00:00:00 2001
From: Aaron
Date: Fri, 13 Sep 2024 11:40:23 -0400
Subject: [PATCH 08/14] allow multiple text hints

---
 nanoowl/fewshot_predictor.py   | 8 ++++----
 test/test_fewshot_predictor.py | 8 ++------
 2 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/nanoowl/fewshot_predictor.py b/nanoowl/fewshot_predictor.py
index ccc8f9b..41a05b9 100644
--- a/nanoowl/fewshot_predictor.py
+++ b/nanoowl/fewshot_predictor.py
@@ -120,12 +120,12 @@
     def encode_query_image(
         self,
         image: PIL.Image,
-        text: str,
+        text_hints: List[str],
         pad_square: bool = True,
     ) -> torch.Tensor:
         image_tensor = self.image_preprocessor.preprocess_pil_image(image)
 
-        text_encodings = self.encode_text([text])
+        text_encodings = self.encode_text(text_hints)
 
         rois = torch.tensor(
             [[0, 0, image.width, image.height]],
@@ -139,8 +139,8 @@
 
         return self.find_best_encoding(image_encodings, text_encodings)
 
-    def encode_text(self, text) -> OwlEncodeTextOutput:
-        return self.owl_predictor.encode_text(text)
+    def encode_text(self, texts: List[str]) -> OwlEncodeTextOutput:
+        return self.owl_predictor.encode_text(texts)
 
     @staticmethod
     def find_best_encoding(
diff --git a/test/test_fewshot_predictor.py b/test/test_fewshot_predictor.py
index 4571fdb..2f3b5fc 100644
--- a/test/test_fewshot_predictor.py
+++ b/test/test_fewshot_predictor.py
@@ -15,19 +15,15 @@
 
 
 import PIL.Image
-import pytest
-import torch
 from nanoowl.fewshot_predictor import FewshotPredictor
-from nanoowl.tree import Tree
-from nanoowl.tree_predictor import TreePredictor
 
 
 def test_encode_query_images():
-    predictor = FewshotPredictor()
+    predictor = FewshotPredictor(device="cpu")
 
     query_image = PIL.Image.open("assets/frog.jpg")
 
-    query_encoding = predictor.encode_query_image(query_image, "a frog")
+    query_encoding = predictor.encode_query_image(query_image, ["a frog"])
 
     assert len(query_encoding.shape) == 2
     assert query_encoding.shape[0] == 1

From 774c8220be546b40c144108780fe942bf96b7293 Mon Sep 17 00:00:00 2001
From: Aaron
Date: Fri, 13 Sep 2024 11:41:28 -0400
Subject: [PATCH 09/14] allow multiple text hints

---
 examples/fewshot_predict.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/fewshot_predict.py b/examples/fewshot_predict.py
index 77b1360..b78481a 100644
--- a/examples/fewshot_predict.py
+++ b/examples/fewshot_predict.py
@@ -81,7 +81,7 @@
 
     query_embeddings = [
-        predictor.encode_query_image(image=query_image, text=query_label)
+        predictor.encode_query_image(image=query_image, text_hints=[query_label])
         for query_image, query_label in zip(query_images, query_labels)
     ]
 
     output = predictor.predict(image, query_embeddings, threshold=thresholds)

From 713b9c8bea3514b905e16f162c36900fa4406b32 Mon Sep 17 00:00:00 2001
From: Aaron
Date: Fri, 13 Sep 2024 12:05:11 -0400
Subject: [PATCH 10/14] minor fixes

---
 examples/fewshot_predict.py    |  1 -
 nanoowl/owl_drawing.py         | 26 +++++++++++---------------
 setup.py                       | 13 ++++++-------
 test/test_fewshot_predictor.py |  8 +++++---
 4 files changed, 22 insertions(+), 26 deletions(-)

diff --git a/examples/fewshot_predict.py b/examples/fewshot_predict.py
index b78481a..017d9c1 100644
--- a/examples/fewshot_predict.py
+++ b/examples/fewshot_predict.py
@@ -18,7 +18,6 @@
 import os.path
 import time
 
-import numpy as np
 import PIL.Image
 import torch
 from nanoowl.fewshot_predictor import FewshotPredictor
diff --git a/nanoowl/owl_drawing.py b/nanoowl/owl_drawing.py
index d580398..47edf83 100644
--- a/nanoowl/owl_drawing.py
+++ b/nanoowl/owl_drawing.py
@@ -14,13 +14,15 @@
 # limitations under the License. 
-import PIL.Image -import PIL.ImageDraw +from typing import List + import cv2 -from .owl_predictor import OwlDecodeOutput import matplotlib.pyplot as plt import numpy as np -from typing import List +import PIL.Image +import PIL.ImageDraw + +from .owl_predictor import OwlDecodeOutput def get_colors(count: int): @@ -36,7 +38,7 @@ def get_colors(count: int): def draw_owl_output(image, output: OwlDecodeOutput, text: List[str], draw_text=True): is_pil = not isinstance(image, np.ndarray) if is_pil: - image = np.asarray(image) + image = np.asarray(image).copy() font = cv2.FONT_HERSHEY_SIMPLEX font_scale = 0.75 colors = get_colors(len(text)) @@ -48,13 +50,7 @@ def draw_owl_output(image, output: OwlDecodeOutput, text: List[str], draw_text=T box = [int(x) for x in box] pt0 = (box[0], box[1]) pt1 = (box[2], box[3]) - cv2.rectangle( - image, - pt0, - pt1, - colors[label_index], - 4 - ) + cv2.rectangle(image, pt0, pt1, colors[label_index], 4) if draw_text: offset_y = 12 offset_x = 0 @@ -66,9 +62,9 @@ def draw_owl_output(image, output: OwlDecodeOutput, text: List[str], draw_text=T font, font_scale, colors[label_index], - 2,# thickness - cv2.LINE_AA + 2, # thickness + cv2.LINE_AA, ) if is_pil: image = PIL.Image.fromarray(image) - return image \ No newline at end of file + return image diff --git a/setup.py b/setup.py index bba1abe..889239f 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,10 @@ from setuptools import find_packages, setup - setup( name="nanoowl", - version="0.0.2", - description='NanoOWL is a project that optimizes OWL-ViT to run ' - '🔥 real-time 🔥 on NVIDIA Jetson Orin Platforms with ' - 'NVIDIA TensorRT', - packages=find_packages() -) \ No newline at end of file + version="0.0.4", + description="NanoOWL is a project that optimizes OWL-ViT to run " + "🔥 real-time 🔥 on NVIDIA Jetson Orin Platforms with " + "NVIDIA TensorRT", + packages=find_packages(), +) diff --git a/test/test_fewshot_predictor.py b/test/test_fewshot_predictor.py index 2f3b5fc..1dc9432 100644 --- a/test/test_fewshot_predictor.py +++ b/test/test_fewshot_predictor.py @@ -45,15 +45,17 @@ def test_encode_labels(): def test_fewshot_predictor_predict(): predictor = FewshotPredictor() - image = PIL.Image.open("assets/cat_query_image.jpg") + image = PIL.Image.open("../assets/cat_query_image.jpg") - query_image = PIL.Image.open("assets/cat_image.jpg") + query_image = PIL.Image.open("../assets/cat_image.jpg") query_label = "a cat" thresholds = 0.7 - query_embedding = predictor.encode_query_image(image=query_image, text=query_label) + query_embedding = predictor.encode_query_image( + image=query_image, text_hints=[query_label] + ) detections = predictor.predict(image, [query_embedding], threshold=thresholds) From 08fb01535f6f15789f32edfed828e42ce0613ca1 Mon Sep 17 00:00:00 2001 From: Aaron Date: Fri, 13 Sep 2024 12:15:40 -0400 Subject: [PATCH 11/14] fewshot_predict example runs without trt engine --- examples/fewshot_predict.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/examples/fewshot_predict.py b/examples/fewshot_predict.py index 017d9c1..2389c94 100644 --- a/examples/fewshot_predict.py +++ b/examples/fewshot_predict.py @@ -73,10 +73,18 @@ else: thresholds = [float(x) for x in thresholds] - predictor = FewshotPredictor( - owl_predictor=OwlPredictor( - args.model, image_encoder_engine=args.image_encoder_engine + engine_path = ( + args.image_encoder_engine if os.path.isfile(args.image_encoder_engine) else None + ) + if not os.path.isfile(args.image_encoder_engine): + print( + f"No image 
encoder engine found at",
+            f"{os.path.abspath(args.image_encoder_engine)}.",
+            "Continuing without TensorRT...",
+        )
+
+    predictor = FewshotPredictor(
+        owl_predictor=OwlPredictor(args.model, image_encoder_engine=engine_path)
     )
 
     query_embeddings = [

From 73aacfe9514a6865bb5e9d0bee48e5c31e46ac19 Mon Sep 17 00:00:00 2001
From: Aaron
Date: Fri, 8 Nov 2024 09:41:26 -0500
Subject: [PATCH 12/14] remove print statement from src code

---
 nanoowl/fewshot_predictor.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/nanoowl/fewshot_predictor.py b/nanoowl/fewshot_predictor.py
index cf7701b..2a848b9 100644
--- a/nanoowl/fewshot_predictor.py
+++ b/nanoowl/fewshot_predictor.py
@@ -75,7 +75,6 @@
         threshold: Union[int, float, List[Union[int, float]]] = 0.1,
     ) -> OwlDecodeOutput:
         num_input_images = image_output.image_class_embeds.shape[0]
-        print(f"{num_input_images=}")
 
         image_class_embeds = image_output.image_class_embeds
         image_class_embeds = image_class_embeds / (

From 0cceb0cca99b0f62c67cbae133683bdee9c54323 Mon Sep 17 00:00:00 2001
From: Aaron
Date: Fri, 8 Nov 2024 09:49:21 -0500
Subject: [PATCH 13/14] update version

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 889239f..c7c044b 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
 
 setup(
     name="nanoowl",
-    version="0.0.4",
+    version="0.0.4-73aacfe",
     description="NanoOWL is a project that optimizes OWL-ViT to run "
     "🔥 real-time 🔥 on NVIDIA Jetson Orin Platforms with "
     "NVIDIA TensorRT",

From 380cfe8abd79c46a2622f4a6bdc374b65434ae5f Mon Sep 17 00:00:00 2001
From: Aaron
Date: Fri, 8 Nov 2024 09:49:21 -0500
Subject: [PATCH 14/14] update version

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index c7c044b..c38eb04 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
 
 setup(
     name="nanoowl",
-    version="0.0.4-73aacfe",
+    version="0.0.4.dev0+73aacfe",
     description="NanoOWL is a project that optimizes OWL-ViT to run "
     "🔥 real-time 🔥 on NVIDIA Jetson Orin Platforms with "
     "NVIDIA TensorRT",
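For anyone picking up the branch, a short end-to-end sketch of the few-shot API as it stands at the end of the series — the paths, labels, and output filename are illustrative, not part of any patch:

```python
import PIL.Image

from nanoowl.fewshot_predictor import FewshotPredictor
from nanoowl.owl_drawing import draw_owl_output

# Resolves to CUDA when available, CPU otherwise (PATCH 07).
predictor = FewshotPredictor()

# One embedding per query image, each guided by its own text hint (PATCH 08/09).
labels = ["a cat", "a frog"]
query_paths = ["assets/cat_query_image.jpg", "assets/frog.jpg"]
query_embeddings = [
    predictor.encode_query_image(image=PIL.Image.open(path), text_hints=[label])
    for path, label in zip(query_paths, labels)
]

# Per-label thresholds, ordered to match the query embeddings.
image = PIL.Image.open("assets/cat_image.jpg")
output = predictor.predict(image, query_embeddings, threshold=[0.7, 0.1])

drawn = draw_owl_output(image, output, text=labels, draw_text=True)
drawn.save("fewshot_predict_out.jpg")
```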