Add DeepSparse backend for CLIP inference (#323)
* Make an inference backend for DeepSparse

* Final fixes for MVP

* Update for test

* Change test

* Update README

* Lint

* lint

* add deepsparse nightly to test deps

---------

Co-authored-by: Romain Beaumont <[email protected]>
mgoin and rom1504 authored Jan 6, 2024
1 parent c14bf7a commit 36a4060
Showing 4 changed files with 94 additions and 2 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -189,6 +189,10 @@ clip_inference turn a set of text+image into clip embeddings
* **slurm_cache_path**, cache path to use for slurm-related tasks. (default *None*)
* **slurm_verbose_wait=False**, whether to print the status of your slurm job (default *False*)

#### DeepSparse Backend

[DeepSparse](https://github.com/neuralmagic/deepsparse) is an inference runtime for fast sparse model inference on CPUs. To use it as a clip-retrieval backend, install it with `pip install deepsparse-nightly[clip]` and prepend `"nm:"` to the `clip_model` name, such as [`"nm:neuralmagic/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K-quant-ds"`](https://huggingface.co/neuralmagic/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K-quant-ds) or [`"nm:mgoin/CLIP-ViT-B-32-laion2b_s34b_b79k-ds"`](https://huggingface.co/mgoin/CLIP-ViT-B-32-laion2b_s34b_b79k-ds).
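For example, assuming the standard `clip-retrieval inference` entry point (the folder paths here are placeholders), one might run `clip-retrieval inference --input_dataset image_folder --output_folder embedding_folder --clip_model "nm:mgoin/CLIP-ViT-B-32-laion2b_s34b_b79k-ds"`.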

### Inference Worker

If you wish to have more control over how inference is run, you can create and call workers directly using `clip-retrieval inference.worker`
79 changes: 79 additions & 0 deletions clip_retrieval/load_clip.py
@@ -5,6 +5,7 @@
import clip
from PIL import Image
import time
import numpy as np


class HFClipWrapper(nn.Module):
@@ -95,6 +96,81 @@ def load_open_clip(clip_model, use_jit=True, device="cuda", clip_cache_path=None
return model, preprocess


class DeepSparseWrapper(nn.Module):
"""
Wrap DeepSparse for managing input types
"""

def __init__(self, model_path):
super().__init__()

import deepsparse # pylint: disable=import-outside-toplevel

##### Fix for two-input models
from deepsparse.clip import CLIPTextPipeline # pylint: disable=import-outside-toplevel

def custom_process_inputs(self, inputs):
if not isinstance(inputs.text, list):
# Always wrap in a list
inputs.text = [inputs.text]
if not isinstance(inputs.text[0], str):
# If not a string, assume it's already been tokenized
tokens = np.stack(inputs.text, axis=0, dtype=np.int32)
return [tokens, np.array(tokens.shape[0] * [tokens.shape[1] - 1])]
else:
tokens = [np.array(t).astype(np.int32) for t in self.tokenizer(inputs.text)]
tokens = np.stack(tokens, axis=0)
return [tokens, np.array(tokens.shape[0] * [tokens.shape[1] - 1])]

# This overrides the process_inputs function globally for all CLIPTextPipeline classes
CLIPTextPipeline.process_inputs = custom_process_inputs
####

self.textual_model_path = model_path + "/textual.onnx"
self.visual_model_path = model_path + "/visual.onnx"

self.textual_model = deepsparse.Pipeline.create(task="clip_text", model_path=self.textual_model_path)
self.visual_model = deepsparse.Pipeline.create(task="clip_visual", model_path=self.visual_model_path)

def encode_image(self, image):
image = [np.array(image)]
embeddings = self.visual_model(images=image).image_embeddings[0]
return torch.from_numpy(embeddings)

def encode_text(self, text):
text = [t.numpy() for t in text]
embeddings = self.textual_model(text=text).text_embeddings[0]
return torch.from_numpy(embeddings)

def forward(self, *args, **kwargs): # pylint: disable=unused-argument
return NotImplemented


def load_deepsparse(clip_model):
"""load deepsparse"""

from huggingface_hub import snapshot_download # pylint: disable=import-outside-toplevel

# Download the model from HF
model_folder = snapshot_download(repo_id=clip_model)
# Compile the model with DeepSparse
model = DeepSparseWrapper(model_path=model_folder)

from deepsparse.clip.constants import CLIP_RGB_MEANS, CLIP_RGB_STDS # pylint: disable=import-outside-toplevel

def process_image(image):
image = model.visual_model._preprocess_transforms(image.convert("RGB")) # pylint: disable=protected-access
image_array = np.array(image)
image_array = image_array.transpose(2, 0, 1).astype("float32")
image_array /= 255.0
image_array = (image_array - np.array(CLIP_RGB_MEANS).reshape((3, 1, 1))) / np.array(CLIP_RGB_STDS).reshape(
(3, 1, 1)
)
return torch.from_numpy(np.ascontiguousarray(image_array, dtype=np.float32))

return model, process_image


@lru_cache(maxsize=None)
def get_tokenizer(clip_model):
"""Load clip"""
@@ -116,6 +192,9 @@ def load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path):
elif clip_model.startswith("hf_clip:"):
clip_model = clip_model[len("hf_clip:") :]
model, preprocess = load_hf_clip(clip_model, device)
elif clip_model.startswith("nm:"):
clip_model = clip_model[len("nm:") :]
model, preprocess = load_deepsparse(clip_model)
else:
model, preprocess = clip.load(clip_model, device=device, jit=use_jit, download_root=clip_cache_path)
return model, preprocess
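
Taken together, the new path can be exercised end to end. Below is a minimal usage sketch (not part of the diff), assuming `deepsparse-nightly[clip]` is installed, the Hugging Face Hub download succeeds, and `cat.jpg` stands in for a real image:

```python
from PIL import Image
import torch

from clip_retrieval.load_clip import load_clip_without_warmup

# The "nm:" prefix routes to load_deepsparse, which fetches textual.onnx and
# visual.onnx from the Hugging Face Hub and compiles them with DeepSparse.
model, preprocess = load_clip_without_warmup(
    "nm:mgoin/CLIP-ViT-B-32-laion2b_s34b_b79k-ds",
    use_jit=False,
    device="cpu",
    clip_cache_path=None,
)

# preprocess returns a 3xHxW float32 tensor; unsqueeze to a batch of one,
# mirroring how the mapper feeds batched tensors to encode_image.
image = preprocess(Image.open("cat.jpg")).unsqueeze(0)
with torch.no_grad():
    image_embedding = model.encode_image(image)  # torch tensor via torch.from_numpy
print(image_embedding.shape)
```

DeepSparse runs on CPU, so `device="cpu"` here mirrors how the test below pins `CUDA_VISIBLE_DEVICES` to empty.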
3 changes: 2 additions & 1 deletion requirements-test.txt
@@ -8,4 +8,5 @@ pytest==7.0.1
types-setuptools
types-requests
types-certifi
pyspark
deepsparse-nightly[clip]
10 changes: 9 additions & 1 deletion tests/test_clip_inference/test_mapper.py
@@ -6,7 +6,15 @@
from clip_retrieval.clip_inference.mapper import ClipMapper


@pytest.mark.parametrize("model", ["ViT-B/32", "open_clip:ViT-B-32-quickgelu", "hf_clip:patrickjohncyh/fashion-clip"])
@pytest.mark.parametrize(
"model",
[
"ViT-B/32",
"open_clip:ViT-B-32-quickgelu",
"hf_clip:patrickjohncyh/fashion-clip",
"nm:mgoin/CLIP-ViT-B-32-laion2b_s34b_b79k-ds",
],
)
def test_mapper(model):
os.environ["CUDA_VISIBLE_DEVICES"] = ""

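With `deepsparse-nightly[clip]` now in requirements-test.txt, the new case should be runnable on CPU with something like `pytest tests/test_clip_inference/test_mapper.py -k nm` (the `-k` filter matching the parametrized model id is an assumption, not part of the diff).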
