-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: support sagemaker batch transform for clip (#6171)
- Loading branch information
Showing
7 changed files
with
152 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 2 additions & 0 deletions
2
tests/integration/docarray_v2/csp/SampleClipExecutor/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# SampleClipExecutor | ||
|
8 changes: 8 additions & 0 deletions
8
tests/integration/docarray_v2/csp/SampleClipExecutor/config.yml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Executor configuration: load SampleClipExecutor from executor.py.
jtype: SampleClipExecutor
py_modules:
  - executor.py
metas:
  name: SampleClipExecutor
  description:
  url:
  keywords: []
43 changes: 43 additions & 0 deletions
43
tests/integration/docarray_v2/csp/SampleClipExecutor/executor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from typing import Optional | ||
|
||
import numpy as np | ||
from docarray import BaseDoc, DocList | ||
from docarray.typing import NdArray | ||
from docarray.typing.bytes import ImageBytes | ||
from docarray.typing.url import AnyUrl | ||
from jina import Executor, requests | ||
from pydantic import Field | ||
|
||
|
||
class TextAndImageDoc(BaseDoc):
    """Input document carrying optional text, URL and raw-bytes payloads.

    All three fields are optional and default to ``None``.
    """

    # Free-form text payload.
    text: Optional[str] = None
    # Remote location of the content (presumably an image, given the class name).
    url: Optional[AnyUrl] = None
    # Inline content, annotated as ImageBytes.
    bytes: Optional[ImageBytes] = None
|
||
|
||
class EmbeddingResponseModel(TextAndImageDoc):
    """Response document: the inherited input fields plus ``embeddings``."""

    embeddings: NdArray = Field(default=[], description="The embedding of the texts")

    class Config(BaseDoc.Config):
        """Model settings layered on top of ``BaseDoc.Config``."""

        # NdArray values are emitted as plain (nested) lists when JSON-encoding.
        json_encoders = {NdArray: lambda v: v.tolist()}
        arbitrary_types_allowed = True
        allow_population_by_field_name = True
|
||
|
||
class SampleClipExecutor(Executor):
    """Test executor that answers ``/encode`` with random embeddings."""

    @requests(on="/encode")
    def foo(
        self, docs: DocList[TextAndImageDoc], **kwargs
    ) -> DocList[EmbeddingResponseModel]:
        """Return one response document per input document.

        Each response copies ``id``, ``text``, ``url`` and ``bytes`` from its
        input and attaches a freshly drawn random array of shape ``(1, 64)``
        as ``embeddings``.

        :param docs: the incoming documents
        :param kwargs: extra arguments passed by the caller; unused here
        :return: the response documents, in input order
        """
        encoded = [
            EmbeddingResponseModel(
                id=item.id,
                text=item.text,
                url=item.url,
                bytes=item.bytes,
                embeddings=np.random.random((1, 64)),
            )
            for item in docs
        ]
        return DocList[EmbeddingResponseModel](encoded)
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import csv | ||
import io | ||
import os | ||
|
||
import requests | ||
from jina.orchestrate.pods import Pod | ||
from jina.parsers import set_pod_parser | ||
|
||
# Port on which the tests below reach the Pod's HTTP endpoints
# (`/ping` and `/invocations`).
sagemaker_port = 8080
|
||
|
||
def test_provider_sagemaker_pod_rank():
    """Check ``GET /ping`` and a JSON ``POST /invocations`` on a sagemaker Pod.

    The Pod is started with ``--provider sagemaker`` and
    ``--provider-endpoint encode`` on top of ``SampleClipExecutor``. The test
    asserts that ``/ping`` answers 200 with an empty JSON object, and that a
    single-document JSON request comes back with the same ``url`` and an
    embedding whose first row has 64 values.
    """
    config_path = os.path.join(
        os.path.dirname(__file__), "SampleClipExecutor", "config.yml"
    )
    args, _ = set_pod_parser().parse_known_args(
        [
            "--uses",
            config_path,
            "--provider",
            "sagemaker",
            "--provider-endpoint",
            "encode",
            "serve",  # This is added by sagemaker
        ]
    )
    base_url = f"http://localhost:{sagemaker_port}"
    # Bound every HTTP call so an unresponsive Pod fails the test instead of
    # hanging the whole run.
    request_timeout = 30

    with Pod(args):
        # Test the `GET /ping` endpoint (added by jina for sagemaker)
        resp = requests.get(f"{base_url}/ping", timeout=request_timeout)
        assert resp.status_code == 200
        assert resp.json() == {}

        # Test the `POST /invocations` endpoint for inference.
        # `--provider-endpoint encode` is passed above, so the request is
        # presumably served by the executor's `/encode` handler; the 64-wide
        # embedding asserted below matches what that handler produces.
        resp = requests.post(
            f"{base_url}/invocations",
            json={
                "data": [
                    {"url": "http://google.com"},
                ]
            },
            timeout=request_timeout,
        )
        assert resp.status_code == 200
        resp_json = resp.json()
        assert len(resp_json["data"]) == 1
        assert len(resp_json["data"][0]["embeddings"][0]) == 64
        assert resp_json["data"][0]["url"] == "http://google.com"
|
||
|
||
def test_provider_sagemaker_pod_batch_transform_valid():
    """Check ``POST /invocations`` with a ``text/csv`` body on a sagemaker Pod.

    The Pod is started with ``--provider sagemaker`` on top of
    ``SampleClipExecutor`` (no explicit ``--provider-endpoint``). The contents
    of ``valid_clip_input.csv`` are posted verbatim, and the JSON response is
    expected to hold three documents: one with text, one with a URL and one
    with bytes.
    """
    test_dir = os.path.dirname(__file__)
    args, _ = set_pod_parser().parse_known_args(
        [
            "--uses",
            os.path.join(test_dir, "SampleClipExecutor", "config.yml"),
            "--provider",
            "sagemaker",
            "serve",  # This is added by sagemaker
        ]
    )
    with Pod(args):
        # Test `POST /invocations` endpoint for batch-transform with valid input
        csv_path = os.path.join(test_dir, "valid_clip_input.csv")
        with open(csv_path, "r", encoding="utf-8") as f:
            csv_data = f.read()

        # The raw CSV text is sent as-is; the server is responsible for
        # parsing it, so no client-side parsing is done here.
        resp = requests.post(
            f"http://localhost:{sagemaker_port}/invocations",
            headers={
                "accept": "application/json",
                "content-type": "text/csv",
            },
            data=csv_data,
            # Fail instead of hanging forever if the Pod never answers.
            timeout=30,
        )
        assert resp.status_code == 200
        resp_json = resp.json()
        assert len(resp_json["data"]) == 3
        assert resp_json["data"][0]["text"] == "the cat is in my house"
        assert (
            resp_json["data"][1]["url"]
            == "https://dummyimage3.com/333/000/fff.jpg&text=embed+this"
        )
        assert "hWjj1RNtNftP" in resp_json["data"][2]["bytes"]
Large diffs are not rendered by default.
Oops, something went wrong.