Add DeepSparse backend for CLIP inference #323

Merged (8 commits) on Jan 6, 2024
4 changes: 4 additions & 0 deletions README.md
@@ -189,6 +189,10 @@ clip_inference turn a set of text+image into clip embeddings
* **slurm_cache_path**, cache path to use for slurm-related tasks. (default *None*)
* **slurm_verbose_wait=False**, whether to print the status of your slurm job (default *False*)

#### DeepSparse Backend

[DeepSparse](https://github.com/neuralmagic/deepsparse) is an inference runtime for fast sparse model inference on CPUs. A DeepSparse backend is available in clip-retrieval: install it with `pip install deepsparse-nightly[clip]` and pass a `clip_model` prefixed with `"nm:"`, such as [`"nm:neuralmagic/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K-quant-ds"`](https://huggingface.co/neuralmagic/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K-quant-ds) or [`"nm:mgoin/CLIP-ViT-B-32-laion2b_s34b_b79k-ds"`](https://huggingface.co/mgoin/CLIP-ViT-B-32-laion2b_s34b_b79k-ds).

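The backend can also be exercised directly from Python through `load_clip_without_warmup`, the loader this PR extends in `clip_retrieval/load_clip.py`. The snippet below is a minimal sketch, assuming `deepsparse-nightly[clip]` is installed and the model repository can be downloaded from the Hugging Face Hub; the image path is a placeholder.

```python
from PIL import Image
import torch

from clip_retrieval.load_clip import load_clip_without_warmup

# The "nm:" prefix routes model loading to the DeepSparse backend.
model, preprocess = load_clip_without_warmup(
    clip_model="nm:mgoin/CLIP-ViT-B-32-laion2b_s34b_b79k-ds",
    use_jit=False,   # JIT compilation does not apply to the DeepSparse runtime
    device="cpu",    # DeepSparse targets CPU inference
    clip_cache_path=None,
)

# "cat.jpg" is a placeholder path used for illustration only.
image = preprocess(Image.open("cat.jpg")).unsqueeze(0)
with torch.no_grad():
    embedding = model.encode_image(image)
print(embedding.shape)
```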
### Inference Worker

If you wish to have more control over how inference is run, you can create and call workers directly using `clip-retrieval inference.worker`
79 changes: 79 additions & 0 deletions clip_retrieval/load_clip.py
@@ -5,6 +5,7 @@
import clip
from PIL import Image
import time
import numpy as np


class HFClipWrapper(nn.Module):
@@ -95,6 +96,81 @@ def load_open_clip(clip_model, use_jit=True, device="cuda", clip_cache_path=None
return model, preprocess


class DeepSparseWrapper(nn.Module):
"""
Wrap DeepSparse for managing input types
"""

def __init__(self, model_path):
super().__init__()

import deepsparse # pylint: disable=import-outside-toplevel

##### Fix for two-input models
from deepsparse.clip import CLIPTextPipeline # pylint: disable=import-outside-toplevel

def custom_process_inputs(self, inputs):
if not isinstance(inputs.text, list):
# Always wrap in a list
inputs.text = [inputs.text]
if not isinstance(inputs.text[0], str):
# If not a string, assume it's already been tokenized
tokens = np.stack(inputs.text, axis=0, dtype=np.int32)
return [tokens, np.array(tokens.shape[0] * [tokens.shape[1] - 1])]
else:
tokens = [np.array(t).astype(np.int32) for t in self.tokenizer(inputs.text)]
tokens = np.stack(tokens, axis=0)
return [tokens, np.array(tokens.shape[0] * [tokens.shape[1] - 1])]

# This overrides the process_inputs function globally for all CLIPTextPipeline classes
CLIPTextPipeline.process_inputs = custom_process_inputs
####

self.textual_model_path = model_path + "/textual.onnx"
self.visual_model_path = model_path + "/visual.onnx"

self.textual_model = deepsparse.Pipeline.create(task="clip_text", model_path=self.textual_model_path)
self.visual_model = deepsparse.Pipeline.create(task="clip_visual", model_path=self.visual_model_path)

def encode_image(self, image):
image = [np.array(image)]
embeddings = self.visual_model(images=image).image_embeddings[0]
return torch.from_numpy(embeddings)

def encode_text(self, text):
text = [t.numpy() for t in text]
embeddings = self.textual_model(text=text).text_embeddings[0]
return torch.from_numpy(embeddings)

def forward(self, *args, **kwargs): # pylint: disable=unused-argument
return NotImplemented


def load_deepsparse(clip_model):
"""load deepsparse"""

from huggingface_hub import snapshot_download # pylint: disable=import-outside-toplevel

# Download the model from HF
model_folder = snapshot_download(repo_id=clip_model)
# Compile the model with DeepSparse
model = DeepSparseWrapper(model_path=model_folder)

from deepsparse.clip.constants import CLIP_RGB_MEANS, CLIP_RGB_STDS # pylint: disable=import-outside-toplevel

def process_image(image):
image = model.visual_model._preprocess_transforms(image.convert("RGB")) # pylint: disable=protected-access
image_array = np.array(image)
image_array = image_array.transpose(2, 0, 1).astype("float32")
image_array /= 255.0
image_array = (image_array - np.array(CLIP_RGB_MEANS).reshape((3, 1, 1))) / np.array(CLIP_RGB_STDS).reshape(
(3, 1, 1)
)
return torch.from_numpy(np.ascontiguousarray(image_array, dtype=np.float32))

return model, process_image


@lru_cache(maxsize=None)
def get_tokenizer(clip_model):
"""Load clip"""
@@ -116,6 +192,9 @@ def load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path):
elif clip_model.startswith("hf_clip:"):
clip_model = clip_model[len("hf_clip:") :]
model, preprocess = load_hf_clip(clip_model, device)
elif clip_model.startswith("nm:"):
clip_model = clip_model[len("nm:") :]
model, preprocess = load_deepsparse(clip_model)
else:
model, preprocess = clip.load(clip_model, device=device, jit=use_jit, download_root=clip_cache_path)
return model, preprocess
3 changes: 2 additions & 1 deletion requirements-test.txt
@@ -8,4 +8,5 @@ pytest==7.0.1
types-setuptools
types-requests
types-certifi
pyspark
pyspark
deepsparse-nightly[clip]
10 changes: 9 additions & 1 deletion tests/test_clip_inference/test_mapper.py
@@ -6,7 +6,15 @@
from clip_retrieval.clip_inference.mapper import ClipMapper


@pytest.mark.parametrize("model", ["ViT-B/32", "open_clip:ViT-B-32-quickgelu", "hf_clip:patrickjohncyh/fashion-clip"])
@pytest.mark.parametrize(
"model",
[
"ViT-B/32",
"open_clip:ViT-B-32-quickgelu",
"hf_clip:patrickjohncyh/fashion-clip",
"nm:mgoin/CLIP-ViT-B-32-laion2b_s34b_b79k-ds",
],
)
def test_mapper(model):
os.environ["CUDA_VISIBLE_DEVICES"] = ""
