Skip to content

Commit

Permalink
Add the option to select the openclip model
Browse files Browse the repository at this point in the history
ability to select specific version of openclip

add option to select the openclip model
  • Loading branch information
barinov274 committed Jun 13, 2023
1 parent 049b4b3 commit 11fd256
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions clip_retrieval/load_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ def load_open_clip(clip_model, use_jit=True, device="cuda", clip_cache_path=None

torch.backends.cuda.matmul.allow_tf32 = True

pretrained = dict(open_clip.list_pretrained())
checkpoint = pretrained[clip_model]
clip_model = clip_model.split(" | ")
checkpoint = dict(open_clip.list_pretrained())[clip_model[0]] if len(clip_model)<2 else clip_model[1]
clip_model = clip_model[0]
print(f"Loading OpenClip model {clip_model} with {checkpoint} checkpoint")
model, _, preprocess = open_clip.create_model_and_transforms(
clip_model, pretrained=checkpoint, device=device, jit=use_jit, cache_dir=clip_cache_path
)
Expand All @@ -61,6 +63,8 @@ def get_tokenizer(clip_model):
import open_clip # pylint: disable=import-outside-toplevel

clip_model = clip_model[len("open_clip:") :]
clip_model = clip_model.split(" | ")
clip_model = clip_model[0]
return open_clip.get_tokenizer(clip_model)
else:
return lambda t: clip.tokenize(t, truncate=True)
Expand All @@ -71,6 +75,8 @@ def load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path):
"""Load clip"""
if clip_model.startswith("open_clip:"):
clip_model = clip_model[len("open_clip:") :]
clip_model = clip_model.split(" | ")
clip_model = clip_model[0]
model, preprocess = load_open_clip(clip_model, use_jit, device, clip_cache_path)
else:
model, preprocess = clip.load(clip_model, device=device, jit=use_jit, download_root=clip_cache_path)
Expand Down

0 comments on commit 11fd256

Please sign in to comment.