Add the option to select the openclip model #284

Closed
wants to merge 3 commits
README.md (6 changes: 3 additions & 3 deletions)
@@ -169,7 +169,7 @@ clip_inference turn a set of text+image into clip embeddings
* **write_batch_size** Write batch size (default *10**6*)
* **wds_image_key** Key to use for images in webdataset. (default *jpg*)
* **wds_caption_key** Key to use for captions in webdataset. (default *txt*)
- * **clip_model** CLIP model to load (default *ViT-B/32*). Specify it as `"open_clip:ViT-B-32-quickgelu"` to use the [open_clip](https://github.com/mlfoundations/open_clip).
+ * **clip_model** CLIP model to load (default *ViT-B/32*). Specify it as `"open_clip:ViT-B-32-quickgelu"` to use the [open_clip](https://github.com/mlfoundations/open_clip). You can also select a specific pretrained open_clip checkpoint (downloaded on first use), for example `"open_clip:ViT-L-14 | datacomp_xl_s13b_b90k"`. To list the available open_clip models and checkpoints, run `import open_clip; print(open_clip.list_pretrained())` (a short discovery sketch follows this option list).
* **mclip_model** MCLIP model to load (default *sentence-transformers/clip-ViT-B-32-multilingual-v1*)
* **use_mclip** If False it performs the inference using CLIP; MCLIP otherwise (default *False*)
* **use_jit** uses jit for the clip model (default *True*)
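For quick reference, here is a minimal discovery sketch (not part of this PR's diff) that prints every architecture/checkpoint pair open_clip ships, in the same `"open_clip:<model> | <checkpoint>"` form the `clip_model` option expects. It assumes the `open_clip_torch` package is installed:

```python
import open_clip

# list_pretrained() yields (model_name, checkpoint_tag) tuples,
# e.g. ("ViT-L-14", "datacomp_xl_s13b_b90k").
for model_name, checkpoint in open_clip.list_pretrained():
    print(f"open_clip:{model_name} | {checkpoint}")
```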
@@ -223,7 +223,7 @@ The API is very similar to `clip-retrieval inference` with some minor changes:
* **enable_metadata** Enable metadata processing (default *False*)
* **wds_image_key** Key to use for images in webdataset. (default *jpg*)
* **wds_caption_key** Key to use for captions in webdataset. (default *txt*)
- * **clip_model** CLIP model to load (default *ViT-B/32*). Specify it as `"open_clip:ViT-B-32-quickgelu"` to use the [open_clip](https://github.com/mlfoundations/open_clip).
+ * **clip_model** CLIP model to load (default *ViT-B/32*). Specify it as `"open_clip:ViT-B-32-quickgelu"` to use the [open_clip](https://github.com/mlfoundations/open_clip). You can also select a specific pretrained open_clip checkpoint (downloaded on first use), for example `"open_clip:ViT-L-14 | datacomp_xl_s13b_b90k"`. To list the available open_clip models and checkpoints, run `import open_clip; print(open_clip.list_pretrained())`.
* **mclip_model** MCLIP model to load (default *sentence-transformers/clip-ViT-B-32-multilingual-v1*)
* **use_mclip** If False it performs the inference using CLIP; MCLIP otherwise (default *False*)
* **use_jit** uses jit for the clip model (default *True*)
@@ -304,7 +304,7 @@ clip-retrieval back --port 1234 --indices-paths indices_paths.json

Options:
* `--use_jit True` uses jit for the clip model
- * `--clip_model "ViT-B/32"` allows choosing the clip model to use. Prefix with `"open_clip:"` to use an [open_clip](https://github.com/mlfoundations/open_clip) model.
+ * `--clip_model "ViT-B/32"` allows choosing the clip model to use. Prefix with `"open_clip:"` to use an [open_clip](https://github.com/mlfoundations/open_clip) model. You can also select a specific pretrained open_clip checkpoint (downloaded on first use), for example `"open_clip:ViT-L-14 | datacomp_xl_s13b_b90k"`. To list the available open_clip models and checkpoints, run `import open_clip; print(open_clip.list_pretrained())` (a quick load check is sketched after this options list).
* `--enable_mclip_option True` loads the mclip model, making it possible to search in any language.
* `--columns_to_return='["url", "image_path", "caption", "NSFW"]` allows you to specify which columns should be fetched from the metadata and returned by the backend. It's useful to specify less in case of hdf5 caching to speed up the queries.
* `--enable_faiss_memory_mapping=True` option can be passed to use an index with memory mapping.
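As referenced above, here is a hedged sanity check (not part of this PR's diff): load the model string once on CPU before starting the backend, so a typo in the model or checkpoint name fails fast. It assumes `clip-retrieval` is installed and reuses `load_clip_without_warmup` from `clip_retrieval/load_clip.py` as modified below; note that the checkpoint is downloaded on first use and can be large:

```python
from clip_retrieval.load_clip import load_clip_without_warmup

# Same string you would pass to --clip_model; the part after " | " selects the checkpoint.
model, preprocess = load_clip_without_warmup(
    "open_clip:ViT-L-14 | datacomp_xl_s13b_b90k",
    use_jit=False,  # jit is unnecessary for a smoke test
    device="cpu",   # CPU is enough to validate the name and fetch the weights
    clip_cache_path=None,
)
print("loaded:", type(model).__name__)
```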
clip_retrieval/load_clip.py (10 changes: 8 additions & 2 deletions)
@@ -44,8 +44,10 @@ def load_open_clip(clip_model, use_jit=True, device="cuda", clip_cache_path=None

torch.backends.cuda.matmul.allow_tf32 = True

- pretrained = dict(open_clip.list_pretrained())
- checkpoint = pretrained[clip_model]
+ clip_model = clip_model.split(" | ")
+ checkpoint = dict(open_clip.list_pretrained())[clip_model[0]] if len(clip_model) < 2 else clip_model[1]
+ clip_model = clip_model[0]
+ print(f"Loading OpenClip model {clip_model} with {checkpoint} checkpoint")
model, _, preprocess = open_clip.create_model_and_transforms(
clip_model, pretrained=checkpoint, device=device, jit=use_jit, cache_dir=clip_cache_path
)
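To spell out what the added lines do, here is a standalone sketch that mirrors the parsing above (it is not part of the diff, and the helper name `resolve_open_clip` is made up for illustration): a plain architecture name keeps the previous behaviour of looking up the checkpoint that `open_clip.list_pretrained()` pairs with it, while a `"model | checkpoint"` string selects an explicit checkpoint:

```python
import open_clip

def resolve_open_clip(clip_model: str):
    """Hypothetical helper mirroring the parsing added in load_open_clip."""
    parts = clip_model.split(" | ")
    if len(parts) < 2:
        # No explicit checkpoint: fall back to the tag dict(list_pretrained()) holds for this model.
        checkpoint = dict(open_clip.list_pretrained())[parts[0]]
    else:
        # Explicit checkpoint given after " | ".
        checkpoint = parts[1]
    return parts[0], checkpoint

print(resolve_open_clip("ViT-B-32-quickgelu"))                # checkpoint taken from list_pretrained()
print(resolve_open_clip("ViT-L-14 | datacomp_xl_s13b_b90k"))  # ('ViT-L-14', 'datacomp_xl_s13b_b90k')
```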
@@ -61,6 +63,8 @@ def get_tokenizer(clip_model):
import open_clip # pylint: disable=import-outside-toplevel

clip_model = clip_model[len("open_clip:") :]
+ clip_model = clip_model.split(" | ")
+ clip_model = clip_model[0]
return open_clip.get_tokenizer(clip_model)
else:
return lambda t: clip.tokenize(t, truncate=True)
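A short usage sketch (not part of the diff) showing why the tokenizer path also needs the split: `open_clip.get_tokenizer` only takes the architecture name, so the checkpoint part of the string is dropped. It assumes `clip-retrieval` and `open_clip_torch` are installed:

```python
from clip_retrieval.load_clip import get_tokenizer

# The " | datacomp_xl_s13b_b90k" suffix is stripped before open_clip.get_tokenizer is called.
tokenize = get_tokenizer("open_clip:ViT-L-14 | datacomp_xl_s13b_b90k")
tokens = tokenize(["a photo of a cat", "a photo of a dog"])
print(tokens.shape)  # e.g. torch.Size([2, 77]) for CLIP-style tokenizers
```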
@@ -71,6 +75,8 @@ def load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path):
"""Load clip"""
if clip_model.startswith("open_clip:"):
clip_model = clip_model[len("open_clip:") :]
+ clip_model = clip_model.split(" | ")
+ clip_model = clip_model[0]
model, preprocess = load_open_clip(clip_model, use_jit, device, clip_cache_path)
else:
model, preprocess = clip.load(clip_model, device=device, jit=use_jit, download_root=clip_cache_path)