New aesthetic option (#147)
* New aesthetic option

* add requests to deps

* add aiohttp to deps

* disable jit in mapper test to work around mlfoundations/open_clip#95

* disable mclip in end2end test to work around hf outage
rom1504 authored May 21, 2022
1 parent 376f13f commit 3d8f650
Showing 7 changed files with 83 additions and 11 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -171,6 +171,7 @@ HDF5 caching makes it possible to use the metadata with almost no memory usage.
* `--reorder_metadata_by_ivf_index True` takes advantage of the data locality of knn IVF index results: it reorders the metadata collection to follow the order of the IVF clusters. Metadata retrieval then becomes much faster, since reads touch a few mostly-sequential parts of the metadata instead of many non-sequential ones. In practice that means retrieving 1M items in 1s, whereas only 1000 items can be retrieved in 1s without this method. The metadata is ordered using the first image index.
* `--provide_safety_model True` will automatically download and load a [safety model](https://github.com/LAION-AI/CLIP-based-NSFW-Detector). The optional `autokeras` dependency (`pip install autokeras`) is required for this to work.
* `--provide_violence_detector True` will load a [violence detector](https://github.com/ml-research/OffImgDetectionCLIP), [paper](https://arxiv.org/abs/2202.06675.pdf)
* `--provide_aesthetic_embeddings True` will load the [aesthetic embeddings](https://github.com/LAION-AI/aesthetic-predictor) and let users nudge the query toward a more aesthetic region of the CLIP embedding space

These options can also be provided in the config file to have different options for each index. Example:
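The diff view truncates the actual example, so the following is a hypothetical sketch: the option keys mirror the fields read by `dict_to_clip_options` in `clip_back.py`, and the index name and path are placeholders.

```json
{
    "example_index": {
        "indice_folder": "/path/to/example_index",
        "provide_safety_model": true,
        "provide_violence_detector": true,
        "provide_aesthetic_embeddings": true
    }
}
```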
59 changes: 58 additions & 1 deletion clip_retrieval/clip_back.py
@@ -7,6 +7,7 @@
from flask_cors import CORS
import faiss
from collections import defaultdict
from multiprocessing.pool import ThreadPool
import json
from io import BytesIO
from PIL import Image
@@ -22,6 +23,7 @@
from functools import lru_cache
from werkzeug.middleware.dispatcher import DispatcherMiddleware
import pyarrow as pa
import fsspec

import h5py
from tqdm import tqdm
@@ -178,7 +180,9 @@ def __init__(self, **kwargs):
super().__init__()
self.clip_resources = kwargs["clip_resources"]

def compute_query(self, clip_resource, text_input, image_input, image_url_input, use_mclip):
def compute_query(
self, clip_resource, text_input, image_input, image_url_input, use_mclip, aesthetic_score, aesthetic_weight
):
"""compute the query embedding"""
import torch # pylint: disable=import-outside-toplevel
import clip # pylint: disable=import-outside-toplevel
@@ -210,6 +214,11 @@ def compute_query(self, text_input, image_input, image_url_input,
image_features /= image_features.norm(dim=-1, keepdim=True)
query = image_features.cpu().detach().numpy().astype("float32")

if clip_resource.aesthetic_embeddings is not None and aesthetic_score is not None:
aesthetic_embedding = clip_resource.aesthetic_embeddings[aesthetic_score]
query = query + aesthetic_embedding * aesthetic_weight
query = query / np.linalg.norm(query)

return query
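The blending step above is worth seeing in isolation. A minimal standalone sketch, with random stand-ins for real CLIP embeddings and a hypothetical 512-dimensional embedding size; it assumes the query is already unit-norm, as produced earlier in `compute_query`:

```python
import numpy as np

def blend_with_aesthetic(query, aesthetic_embedding, weight):
    # move the query toward the aesthetic direction, then re-normalize so it
    # stays unit-norm for the cosine/inner-product knn search
    blended = query + aesthetic_embedding * weight
    return blended / np.linalg.norm(blended)

# hypothetical usage with random arrays in place of real embeddings
rng = np.random.default_rng(0)
query = rng.standard_normal((1, 512)).astype("float32")
query /= np.linalg.norm(query)
aesthetic = rng.standard_normal((1, 512)).astype("float32")
blended = blend_with_aesthetic(query, aesthetic, weight=0.5)
print(np.linalg.norm(blended))  # 1.0
```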

def hash_based_dedup(self, embeddings):
@@ -394,6 +403,8 @@ def query(
deduplicate=True,
use_safety_model=False,
use_violence_detector=False,
aesthetic_score=None,
aesthetic_weight=None,
):
"""implement the querying functionality of the knn service: from text and image to nearest neighbors"""

@@ -410,6 +421,8 @@
image_input=image_input,
image_url_input=image_url_input,
use_mclip=use_mclip,
aesthetic_score=aesthetic_score,
aesthetic_weight=aesthetic_weight,
)
distances, indices = self.knn_search(
query,
@@ -443,6 +456,10 @@ def post(self):
deduplicate = json_data.get("deduplicate", False)
use_safety_model = json_data.get("use_safety_model", False)
use_violence_detector = json_data.get("use_violence_detector", False)
aesthetic_score = json_data.get("aesthetic_score", "")
aesthetic_score = int(aesthetic_score) if aesthetic_score != "" else None
aesthetic_weight = json_data.get("aesthetic_weight", "")
aesthetic_weight = float(aesthetic_weight) if aesthetic_weight != "" else None
return self.query(
text_input,
image_input,
@@ -455,6 +472,8 @@
deduplicate,
use_safety_model,
use_violence_detector,
aesthetic_score,
aesthetic_weight,
)


@@ -618,6 +637,33 @@ def get_cache_folder(clip_model):
return cache_folder


# the embeddings are downloaded at load time; lru_cache memoizes them per model type
@lru_cache(maxsize=None)
def get_aesthetic_embedding(model_type):
"""get aesthetic embedding"""
if model_type == "ViT-B/32":
model_type = "vit_b_32"
elif model_type == "ViT-L/14":
model_type = "vit_l_14"

fs, _ = fsspec.core.url_to_fs(
f"https://github.com/LAION-AI/aesthetic-predictor/blob/main/{model_type}_embeddings/rating0.npy?raw=true"
)
embs = {}
with ThreadPool(10) as pool:

def get(k):
with fs.open(
f"https://github.com/LAION-AI/aesthetic-predictor/blob/main/{model_type}_embeddings/rating{k}.npy?raw=true",
"rb",
) as f:
embs[k] = np.load(f)

for _ in pool.imap_unordered(get, range(10)):
pass
return embs
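A hedged usage sketch of the helper above (it needs network access to GitHub, and the printed shape is an assumption based on ViT-B/32's 512-dimensional output):

```python
embs = get_aesthetic_embedding("ViT-B/32")
# one numpy embedding per aesthetic rating bucket 0..9
assert sorted(embs.keys()) == list(range(10))
print(embs[9].shape)  # presumably (1, 512) for ViT-B/32
```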


@lru_cache(maxsize=None)
def load_violence_detector(clip_model):
"""load violence detector for this clip model"""
@@ -701,6 +747,7 @@ class ClipResource:
ivf_old_to_new_mapping: Any
columns_to_return: List[str]
metadata_is_ordered_by_ivf: bool
aesthetic_embeddings: Any


@dataclass
@@ -718,6 +765,7 @@ class ClipOptions:
use_arrow: bool
provide_safety_model: bool
provide_violence_detector: bool
provide_aesthetic_embeddings: bool


def dict_to_clip_options(d, clip_options):
@@ -743,6 +791,9 @@ def dict_to_clip_options(d, clip_options):
provide_violence_detector=d["provide_violence_detector"]
if "provide_violence_detector" in d
else clip_options.provide_violence_detector,
provide_aesthetic_embeddings=d["provide_aesthetic_embeddings"]
if "provide_aesthetic_embeddings" in d
else clip_options.provide_aesthetic_embeddings,
)


@@ -780,6 +831,9 @@ def load_clip_index(clip_options):
violence_detector = (
load_violence_detector(clip_options.clip_model) if clip_options.provide_violence_detector else None
)
aesthetic_embeddings = (
get_aesthetic_embedding(clip_options.clip_model) if clip_options.provide_aesthetic_embeddings else None
)

image_present = os.path.exists(clip_options.indice_folder + "/image.index")
text_present = os.path.exists(clip_options.indice_folder + "/text.index")
@@ -820,6 +874,7 @@
ivf_old_to_new_mapping=ivf_old_to_new_mapping if clip_options.reorder_metadata_by_ivf_index else None,
columns_to_return=clip_options.columns_to_return,
metadata_is_ordered_by_ivf=clip_options.reorder_metadata_by_ivf_index,
aesthetic_embeddings=aesthetic_embeddings,
)


@@ -864,6 +919,7 @@ def clip_back(
use_arrow=False,
provide_safety_model=False,
provide_violence_detector=False,
provide_aesthetic_embeddings=True,
):
"""main entry point of clip back, start the endpoints"""
print("starting boot of clip back")
@@ -883,6 +939,7 @@
use_arrow=use_arrow,
provide_safety_model=provide_safety_model,
provide_violence_detector=provide_violence_detector,
provide_aesthetic_embeddings=provide_aesthetic_embeddings,
),
)
print("indices loaded")
20 changes: 14 additions & 6 deletions front/src/clip-front.js
@@ -59,6 +59,8 @@ class ClipFront extends LitElement {
this.imageUrl = imageUrl === null ? undefined : imageUrl
this.hideDuplicateUrls = true
this.hideDuplicateImages = true
this.aestheticScore = '9'
this.aestheticWeight = '0.5'
this.initIndices()
}

@@ -99,7 +101,9 @@
removeViolence: { type: Boolean },
hideDuplicateUrls: { type: Boolean },
hideDuplicateImages: { type: Boolean },
useMclip: { type: Boolean }
useMclip: { type: Boolean },
aestheticWeight: { type: String },
aestheticScore: { type: String }
}
}

@@ -149,7 +153,8 @@
}
}
if (_changedProperties.has('useMclip') || _changedProperties.has('modality') || _changedProperties.has('currentIndex') ||
_changedProperties.has('hideDuplicateUrls') || _changedProperties.has('hideDuplicateImages') || _changedProperties.has('safeMode') || _changedProperties.has('removeViolence')) {
_changedProperties.has('hideDuplicateUrls') || _changedProperties.has('hideDuplicateImages') || _changedProperties.has('safeMode') ||
_changedProperties.has('removeViolence') || _changedProperties.has('aestheticScore') || _changedProperties.has('aestheticWeight')) {
if (this.image !== undefined || this.text !== '' || this.imageUrl !== undefined) {
this.redoSearch()
}
@@ -233,7 +238,7 @@
const imageUrl = this.imageUrl === undefined ? null : this.imageUrl
const count = this.modality === 'image' && this.currentIndex === this.indices[0] ? 10000 : 100
const results = await this.service.callClipService(text, image, imageUrl, this.modality, count,
this.currentIndex, count, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence)
this.currentIndex, count, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence, this.aestheticScore, this.aestheticWeight)
downloadFile('clipsubset.json', JSON.stringify(results, null, 2))
}

Expand All @@ -244,7 +249,7 @@ class ClipFront extends LitElement {
this.image = undefined
this.imageUrl = undefined
const results = await this.service.callClipService(this.text, null, null, this.modality, this.numImages,
this.currentIndex, this.numResultIds, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence)
this.currentIndex, this.numResultIds, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence, this.aestheticScore, this.aestheticWeight)
console.log(results)
this.images = results
this.lastMetadataId = Math.min(this.numImages, results.length) - 1
@@ -257,7 +262,7 @@
this.text = ''
this.imageUrl = undefined
const results = await this.service.callClipService(null, this.image, null, this.modality, this.numImages,
this.currentIndex, this.numResultIds, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence)
this.currentIndex, this.numResultIds, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence, this.aestheticScore, this.aestheticWeight)
console.log(results)
this.images = results
this.lastMetadataId = Math.min(this.numImages, results.length) - 1
@@ -270,7 +275,7 @@
this.text = ''
this.image = undefined
const results = await this.service.callClipService(null, null, this.imageUrl, this.modality, this.numImages,
this.currentIndex, this.numResultIds, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence)
this.currentIndex, this.numResultIds, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence, this.aestheticScore, this.aestheticWeight)
console.log(results)
this.images = results
this.lastMetadataId = Math.min(this.numImages, results.length) - 1
@@ -547,6 +552,9 @@ class ClipFront extends LitElement {
<label>Remove violence<input type="checkbox" ?checked="${this.removeViolence}" @click=${() => { this.removeViolence = !this.removeViolence }} /></label><br />
<label>Hide duplicate urls<input type="checkbox" ?checked="${this.hideDuplicateUrls}" @click=${() => { this.hideDuplicateUrls = !this.hideDuplicateUrls }} /></label><br />
<label>Hide (near) duplicate images<input type="checkbox" ?checked="${this.hideDuplicateImages}" @click=${() => { this.hideDuplicateImages = !this.hideDuplicateImages }} /></label><br />
<label>Aesthetic score <select @input=${(e) => { this.aestheticScore = e.target.value }}>
${[...Array(10).keys()].map(i => html`<option ?selected="${this.aestheticScore === i.toString()}" value=${i}>${i}</option>`)}</select></label><br />
<label>Aesthetic weight<input type="input" value="${this.aestheticWeight}" @input=${(e) => { this.aestheticWeight = e.target.value }} /></label><br />
<label>Search over <select @input=${e => { this.modality = e.target.value }}>${['image', 'text'].map(modality =>
html`<option value=${modality} ?selected=${modality === this.modality}>${modality}</option>`)}</select><br />
<label>Search with multilingual clip <input type="checkbox" ?checked="${this.useMclip}" @click=${() => { this.useMclip = !this.useMclip }} /></label><br />
6 changes: 4 additions & 2 deletions front/src/clip-service.js
@@ -14,7 +14,7 @@ export default class ClipService {
return result
}

async callClipService (text, image, imageUrl, modality, numImages, indexName, numResultIds, useMclip, hideDuplicateImages, useSafetyModel, useViolenceDetector) {
async callClipService (text, image, imageUrl, modality, numImages, indexName, numResultIds, useMclip, hideDuplicateImages, useSafetyModel, useViolenceDetector, aestheticScore, aestheticWeight) {
console.log('calling', text, numImages)
const result = JsonBigint.parse(await (await fetch(this.backend + `/knn-service`, {
method: 'POST',
@@ -29,7 +29,9 @@
'use_mclip': useMclip,
'deduplicate': hideDuplicateImages,
'use_safety_model': useSafetyModel,
'use_violence_detector': useViolenceDetector
'use_violence_detector': useViolenceDetector,
'aesthetic_score': aestheticScore,
'aesthetic_weight': aestheticWeight
})
})).text())

2 changes: 2 additions & 0 deletions requirements.txt
@@ -19,3 +19,5 @@ fsspec==2022.1.0
sentence-transformers>=2.2.0,<3
wandb>=0.12.10,<0.13
open-clip-torch>=1.0.1,<2.0.0
requests>=2.27.1,<3
aiohttp>=3.8.1,<4
2 changes: 1 addition & 1 deletion tests/test_clip_inference/test_mapper.py
@@ -14,7 +14,7 @@ def test_mapper(model):
enable_metadata=False,
use_mclip=False,
clip_model=model,
use_jit=True,
use_jit=False,
mclip_model="",
)
current_dir = os.path.dirname(os.path.abspath(__file__))
4 changes: 3 additions & 1 deletion tests/test_end2end.py
@@ -83,7 +83,9 @@ def test_end2end():
f.write('{"example_index": "' + index_folder + '"}')

p = subprocess.Popen(
f"clip-retrieval back --port=1239 --indices_paths='{indice_path}'", shell=True, stdout=subprocess.PIPE
f"clip-retrieval back --port=1239 --indices_paths='{indice_path}' --enable_mclip_option=False",
shell=True,
stdout=subprocess.PIPE,
)
for i in range(8):
try:
