Add: ONNX and DPO tests
ashvardanian committed Mar 28, 2024
1 parent e671fe4 commit 658c937
Showing 7 changed files with 142 additions and 56 deletions.
8 changes: 7 additions & 1 deletion .vscode/settings.json
@@ -1,13 +1,19 @@
{
"cSpell.words": [
"dtype",
"embs",
"huggingface",
"keepdim",
"logits",
"multimodal",
"ndarray",
"onnxruntime",
"preprocess",
"pretrained",
"probs",
"pypi",
"reranker",
"reranking",
"softmax",
"transfromers",
"uform",
@@ -18,4 +24,4 @@
"editor.defaultFormatter": "ms-python.black-formatter"
},
"python.formatting.provider": "none"
}
}
14 changes: 14 additions & 0 deletions README.md
@@ -335,6 +335,20 @@ Results for VQAv2 evaluation.
> ² Lacking a broad enough evaluation dataset, we translated the [COCO Karpathy test split](https://www.kaggle.com/datasets/shtvkumar/karpathy-splits) with multiple public and proprietary translation services, averaging the scores across all sets, and breaking them down in the bottom section. <br/>
> ³ We used `apple/DFN5B-CLIP-ViT-H-14-378` CLIP model.

## Size

PyTorch is a heavy dependency, and most models are too large to run on edge and IoT devices.
With the ONNX Runtime, one can significantly reduce both memory consumption and deployment latency.

```sh
$ conda create -n env_torch python=3.10 -y
$ conda create -n env_onnx python=3.10 -y
$ conda activate env_torch && pip install -e ".[torch]" && conda deactivate
$ conda activate env_onnx && pip install -e ".[onnx-gpu]" && conda deactivate
$ du -sh $(conda info --envs | grep 'env_torch' | awk '{print $2}')
$ du -sh $(conda info --envs | grep 'env_onnx' | awk '{print $2}')
```
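
To verify that the slimmer environment works end to end, one can load a model through the `get_model_onnx` entry point added in this commit. A minimal sketch, reusing the `unum-cloud/uform-vl-english-small` checkpoint and the CPU/fp32 combination from the test suite:

```python
import uform
from uform.onnx_models import ExecutionProviderError

try:
    # device is "cpu" or "gpu"; dtype is "fp32" or "fp16" (fp16 is gpu-only)
    model, processor = uform.get_model_onnx("unum-cloud/uform-vl-english-small", "cpu", "fp32")
    text_data = processor.preprocess_text("a small red panda in a zoo")
    text_embedding = model.encode_text(text_data, return_features=False)
    print(text_embedding.shape)  # (1, <embedding dimension>)
except ExecutionProviderError as error:
    print(f"Requested execution provider is unavailable: {error}")
```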

## Speed

On an Nvidia RTX 3090, the following performance is expected for text encoding.
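
For a rough self-measured figure, one can time the text encoder directly. A minimal sketch, assuming the `unum-cloud/uform-vl-english` Torch model and an arbitrary batch of 64 identical captions, with warm-up and averaging omitted for brevity:

```python
import time

import torch
import uform

model, processor = uform.get_model("unum-cloud/uform-vl-english")
texts = ["a small red panda in a zoo"] * 64  # arbitrary batch size
text_data = processor.preprocess_text(texts)

with torch.inference_mode():  # disable autograd overhead for a fairer timing
    start = time.perf_counter()
    text_embeddings = model.encode_text(text_data, return_features=False)
    elapsed = time.perf_counter() - start

print(f"{len(texts) / elapsed:.0f} captions/s")
```

On CUDA one would additionally move the model and inputs to the GPU and call `torch.cuda.synchronize()` before reading the clock.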
85 changes: 79 additions & 6 deletions python/scripts/test_embeddings.py
@@ -1,3 +1,5 @@
from typing import Tuple

import pytest
from PIL import Image
import uform
@@ -8,13 +10,17 @@
]

onnx_models_and_providers = [
("unum-cloud/uform-vl-english", "cpu"),
("unum-cloud/uform-vl-multilingual-v2", "cpu"),
("unum-cloud/uform-vl-english-large", "cpu", "fp32"),
("unum-cloud/uform-vl-english-small", "cpu", "fp32"),
("unum-cloud/uform-vl-english-large", "gpu", "fp32"),
("unum-cloud/uform-vl-english-small", "gpu", "fp32"),
("unum-cloud/uform-vl-english-large", "gpu", "fp16"),
("unum-cloud/uform-vl-english-small", "gpu", "fp16"),
]


@pytest.mark.parametrize("model_name", torch_models)
def test_one_embedding(model_name: str):
def test_torch_one_embedding(model_name: str):
model, processor = uform.get_model(model_name)
text = "a small red panda in a zoo"
image_path = "assets/unum.png"
@@ -23,16 +29,23 @@ def test_one_embedding(model_name: str):
image_data = processor.preprocess_image(image)
text_data = processor.preprocess_text(text)

_, image_embedding = model.encode_image(image_data, return_features=True)
_, text_embedding = model.encode_text(text_data, return_features=True)
image_features, image_embedding = model.encode_image(image_data, return_features=True)
text_features, text_embedding = model.encode_text(text_data, return_features=True)

assert image_embedding.shape[0] == 1, "Image embedding batch size is not 1"
assert text_embedding.shape[0] == 1, "Text embedding batch size is not 1"

# Test reranking
joint_embedding = model.encode_multimodal(
image_features=image_features, text_features=text_features, attention_mask=text_data["attention_mask"]
)
score = model.get_matching_scores(joint_embedding)
assert score.shape[0] == 1, "Matching score batch size is not 1"


@pytest.mark.parametrize("model_name", torch_models)
@pytest.mark.parametrize("batch_size", [1, 2])
def test_many_embeddings(model_name: str, batch_size: int):
def test_torch_many_embeddings(model_name: str, batch_size: int):
model, processor = uform.get_model(model_name)
texts = ["a small red panda in a zoo"] * batch_size
image_paths = ["assets/unum.png"] * batch_size
@@ -46,3 +59,63 @@ def test_many_embeddings(model_name: str, batch_size: int):

assert image_embeddings.shape[0] == batch_size, "Image embedding is unexpected"
assert text_embeddings.shape[0] == batch_size, "Text embedding is unexpected"


@pytest.mark.parametrize("model_specs", onnx_models_and_providers)
def test_onnx_one_embedding(model_specs: Tuple[str, str, str]):

from uform.onnx_models import ExecutionProviderError

try:

model, processor = uform.get_model_onnx(*model_specs)
text = "a small red panda in a zoo"
image_path = "assets/unum.png"

image = Image.open(image_path)
image_data = processor.preprocess_image(image)
text_data = processor.preprocess_text(text)

image_features, image_embedding = model.encode_image(image_data, return_features=True)
text_features, text_embedding = model.encode_text(text_data, return_features=True)

assert image_embedding.shape[0] == 1, "Image embedding batch size is not 1"
assert text_embedding.shape[0] == 1, "Text embedding batch size is not 1"

score, joint_embedding = model.encode_multimodal(
image_features=image_features,
text_features=text_features,
attention_mask=text_data["attention_mask"],
return_scores=True,
)
assert score.shape[0] == 1, "Matching score batch size is not 1"
assert joint_embedding.shape[0] == 1, "Joint embedding batch size is not 1"

except ExecutionProviderError as e:
pytest.skip(f"Execution provider error: {e}")


@pytest.mark.parametrize("model_specs", onnx_models_and_providers)
@pytest.mark.parametrize("batch_size", [1, 2])
def test_onnx_many_embeddings(model_specs: Tuple[str, str, str], batch_size: int):

from uform.onnx_models import ExecutionProviderError

try:

model, processor = uform.get_model_onnx(*model_specs)
texts = ["a small red panda in a zoo"] * batch_size
image_paths = ["assets/unum.png"] * batch_size

images = [Image.open(path) for path in image_paths]
image_data = processor.preprocess_image(images)
text_data = processor.preprocess_text(texts)

image_embeddings = model.encode_image(image_data, return_features=False)
text_embeddings = model.encode_text(text_data, return_features=False)

assert image_embeddings.shape[0] == batch_size, "Image embedding is unexpected"
assert text_embeddings.shape[0] == batch_size, "Text embedding is unexpected"

except ExecutionProviderError as e:
pytest.skip(f"Execution provider error: {e}")
1 change: 1 addition & 0 deletions python/scripts/test_generative.py
@@ -11,6 +11,7 @@

torch_hf_models = [
"unum-cloud/uform-gen2-qwen-500m",
"unum-cloud/uform-gen2-dpo",
]


4 changes: 2 additions & 2 deletions python/uform/__init__.py
@@ -34,7 +34,7 @@ def get_model(model_name: str, token: Optional[str] = None):

def get_model_onnx(model_name: str, device: str, dtype: str, token: Optional[str] = None):
from uform.onnx_models import VLM_ONNX
from uform.numpy_preprocessor import NumpyProcessor
from uform.numpy_preprocessor import NumPyProcessor

assert device in (
"cpu",
@@ -54,6 +54,6 @@ def get_model_onnx(model_name: str, device: str, dtype: str, token: Optional[str
config = load(f)

model = VLM_ONNX(model_path, config, device, dtype)
processor = NumpyProcessor(config, join(model_path, "tokenizer.json"))
processor = NumPyProcessor(config, join(model_path, "tokenizer.json"))

return model, processor
2 changes: 1 addition & 1 deletion python/uform/numpy_preprocessor.py
@@ -6,7 +6,7 @@
import numpy as np


class NumpyProcessor:
class NumPyProcessor:
def __init__(self, config: Dict, tokenizer_path: PathLike):
"""
:param config: model config
84 changes: 38 additions & 46 deletions python/uform/onnx_models.py
@@ -1,28 +1,40 @@
from os.path import join
from typing import Dict, Optional, Tuple, Union

import onnxruntime
import onnxruntime as ort
from numpy import ndarray


class ExecutionProviderError(Exception):
"""Exception raised when a requested execution provider is not available."""


def available_providers(device: str) -> Tuple[str, ...]:
available = ort.get_available_providers()
if device == "gpu":
if "CUDAExecutionProvider" not in available:
raise ExecutionProviderError(
"CUDAExecutionProvider is not available, consider installing `onnxruntime-gpu` and make sure the CUDA is available on your system."
)
return ("CUDAExecutionProvider",)

return ("CPUExecutionProvider", "CoreMLExecutionProvider")


class VisualEncoderONNX:
def __init__(self, model_path: str, device: str):
"""
:param model_path: Path to onnx model
:param device: Device name, either cpu or gpu
"""

sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = (
onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
)
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

self.session = onnxruntime.InferenceSession(
self.session = ort.InferenceSession(
model_path,
sess_options=sess_options,
providers=[
"CUDAExecutionProvider" if device == "gpu" else "CPUExecutionProvider"
],
sess_options=session_options,
providers=available_providers(device),
)

def __call__(self, images: ndarray) -> Tuple[ndarray, ndarray]:
@@ -37,33 +49,23 @@ def __init__(self, text_encoder_path: str, reranker_path: str, device: str):
:param device: Device name, either cpu or gpu
"""

sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = (
onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
)
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

self.text_encoder_session = onnxruntime.InferenceSession(
self.text_encoder_session = ort.InferenceSession(
text_encoder_path,
sess_options=sess_options,
providers=[
"CUDAExecutionProvider" if device == "gpu" else "CPUExecutionProvider"
],
sess_options=session_options,
providers=available_providers(device),
)

self.reranker_session = onnxruntime.InferenceSession(
self.reranker_session = ort.InferenceSession(
reranker_path,
sess_options=sess_options,
providers=[
"CUDAExecutionProvider" if device == "gpu" else "CPUExecutionProvider"
],
sess_options=session_options,
providers=available_providers(device),
)

def __call__(
self, input_ids: ndarray, attention_mask: ndarray
) -> Tuple[ndarray, ndarray]:
return self.text_encoder_session.run(
None, {"input_ids": input_ids, "attention_mask": attention_mask}
)
def __call__(self, input_ids: ndarray, attention_mask: ndarray) -> Tuple[ndarray, ndarray]:
return self.text_encoder_session.run(None, {"input_ids": input_ids, "attention_mask": attention_mask})

def forward_multimodal(
self, text_features: ndarray, attention_mask: ndarray, image_features: ndarray
@@ -90,9 +92,7 @@ def __init__(self, checkpoint_path: str, config: Dict, device: str, dtype: str):
), f"Invalid `dtype`: {dtype}. Must be either `fp32` or `fp16` (only for gpu)"
assert (
device == "cpu" and dtype == "fp32"
) or device == "gpu", (
"Combination `device`=`cpu` & `dtype=fp16` is not supported"
)
) or device == "gpu", "Combination `device`=`cpu` & `dtype=fp16` is not supported"

self.device = device
self.dtype = dtype
@@ -107,9 +107,7 @@ def __init__(self, checkpoint_path: str, config: Dict, device: str, dtype: str):
device,
)

self.image_encoder = VisualEncoderONNX(
join(checkpoint_path, f"image_encoder.onnx"), device
)
self.image_encoder = VisualEncoderONNX(join(checkpoint_path, f"image_encoder.onnx"), device)

def encode_image(
self,
@@ -160,23 +158,17 @@ def encode_multimodal(
preprocessed images (or precomputed image features) through the multimodal encoder to produce matching scores and, optionally, multimodal joint embeddings.
:param image: Preprocessed images
:param text: Preprocesses texts
:param text: Preprocessed texts
:param image_features: Precomputed images features
:param text_features: Precomputed text features
:param attention_mask: Attention masks, not required if pass `text` instead of text_features
"""

assert (
image is not None or image_features is not None
), "Either `image` or `image_features` should be non None"
assert (
text is not None or text_features is not None
), "Either `text_data` or `text_features` should be non None"
assert image is not None or image_features is not None, "Either `image` or `image_features` should be non None"
assert text is not None or text_features is not None, "Either `text_data` or `text_features` should be non None"

if text_features is not None:
assert (
attention_mask is not None
), "if `text_features` is not None, then you should pass `attention_mask`"
assert attention_mask is not None, "if `text_features` is not None, then you should pass `attention_mask`"

if image_features is None:
image_features = self.image_encoder(image)
