Skip to content

Commit

Permalink
fix: support sagemaker batch transform for clip (#6171)
Browse files Browse the repository at this point in the history
  • Loading branch information
zac-li authored Jun 13, 2024
1 parent 12e2a94 commit 98429b0
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 2 deletions.
9 changes: 7 additions & 2 deletions jina/serve/runtimes/worker/http_csp_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,14 @@ def construct_model_from_line(
)
else:
parsed_fields[field_name] = parsed_list
# Handle direct assignment for basic types
# General parsing attempt for other types
else:
parsed_fields[field_name] = field_info.type_(field_str)
if field_str:
try:
parsed_fields[field_name] = field_info.type_(field_str)
except (ValueError, TypeError):
# Fallback to parse_obj_as when type is more complex, e.g., AnyUrl or ImageBytes
parsed_fields[field_name] = parse_obj_as(field_info.type_, field_str)

return model(**parsed_fields)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# SampleClipExecutor

Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Executor configuration for the sample CLIP executor.
# NOTE(review): presumably this is the `SampleClipExecutor/config.yml` that the
# sagemaker tests pass through `--uses` -- confirm against the file tree.
jtype: SampleClipExecutor
# Python source file(s) shipped with this configuration.
py_modules:
- executor.py
metas:
name: SampleClipExecutor
# The remaining metadata entries are left empty in this sample.
description:
url:
keywords: []
43 changes: 43 additions & 0 deletions tests/integration/docarray_v2/csp/SampleClipExecutor/executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import Optional

import numpy as np
from docarray import BaseDoc, DocList
from docarray.typing import NdArray
from docarray.typing.bytes import ImageBytes
from docarray.typing.url import AnyUrl
from jina import Executor, requests
from pydantic import Field


class TextAndImageDoc(BaseDoc):
    # Input document for the sample executor. Every payload field is optional
    # and defaults to None, so a document may carry any subset of them.
    # (Documented with comments rather than a docstring on purpose, so the
    # generated model schema description stays exactly as before.)

    # Plain-text payload.
    text: Optional[str] = None

    # Payload reference, typed as a docarray AnyUrl.
    url: Optional[AnyUrl] = None

    # Raw payload, typed as docarray ImageBytes. The field name shadows the
    # builtin `bytes` inside this class body; it is kept because it is part of
    # the document's external schema.
    bytes: Optional[ImageBytes] = None


class EmbeddingResponseModel(TextAndImageDoc):
    # Response document: the input fields inherited from TextAndImageDoc plus
    # an `embeddings` array, which defaults to an empty list.
    embeddings: NdArray = Field(default=[], description="The embedding of the texts")

    class Config(BaseDoc.Config):
        # Serialize NdArray values as plain nested lists in JSON output.
        json_encoders = {NdArray: lambda array: array.tolist()}
        arbitrary_types_allowed = True
        allow_population_by_field_name = True


class SampleClipExecutor(Executor):
    """Sample executor that echoes each document back with a random embedding."""

    @requests(on="/encode")
    def foo(
        self, docs: DocList[TextAndImageDoc], **kwargs
    ) -> DocList[EmbeddingResponseModel]:
        """Build one response document per input document.

        Each response copies ``id``, ``text``, ``url`` and ``bytes`` from its
        input and attaches a freshly drawn random array of shape ``(1, 64)``
        as ``embeddings``.

        :param docs: the incoming documents
        :param kwargs: extra keyword arguments, unused here
        :return: the response documents, in the same order as ``docs``
        """
        return DocList[EmbeddingResponseModel](
            [
                EmbeddingResponseModel(
                    id=item.id,
                    text=item.text,
                    url=item.url,
                    bytes=item.bytes,
                    embeddings=np.random.random((1, 64)),
                )
                for item in docs
            ]
        )
Empty file.
89 changes: 89 additions & 0 deletions tests/integration/docarray_v2/csp/test_sagemaker_clip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import csv
import io
import os

import requests
from jina.orchestrate.pods import Pod
from jina.parsers import set_pod_parser

# Port on which the tests below reach the pod's `/ping` and `/invocations`
# HTTP endpoints.
# NOTE(review): presumably the fixed port the sagemaker provider serves on --
# confirm against the provider implementation.
sagemaker_port = 8080


def test_provider_sagemaker_pod_rank():
    """Check `/ping` and JSON `/invocations` on a sagemaker-provider pod.

    The pod is started with `--provider sagemaker --provider-endpoint encode`
    and serves the SampleClipExecutor, whose only endpoint is `/encode`.

    NOTE(review): despite the `rank` in its name, this test exercises the
    encode endpoint; the name is kept so test selection by name is unaffected.
    """
    # Bound every HTTP call so an unresponsive pod fails the test instead of
    # hanging the whole test run.
    request_timeout = 30  # seconds
    base_url = f"http://localhost:{sagemaker_port}"
    config_path = os.path.join(
        os.path.dirname(__file__), "SampleClipExecutor", "config.yml"
    )

    args, _ = set_pod_parser().parse_known_args(
        [
            "--uses",
            config_path,
            "--provider",
            "sagemaker",
            "--provider-endpoint",
            "encode",
            "serve",  # This is added by sagemaker
        ]
    )
    with Pod(args):
        # Test the `GET /ping` endpoint (added by jina for sagemaker)
        resp = requests.get(f"{base_url}/ping", timeout=request_timeout)
        assert resp.status_code == 200
        assert resp.json() == {}

        # Test the `POST /invocations` endpoint for inference. The request is
        # expected to succeed and to come back with the executor's embedding.
        resp = requests.post(
            f"{base_url}/invocations",
            json={
                "data": [
                    {"url": "http://google.com"},
                ]
            },
            timeout=request_timeout,
        )
        assert resp.status_code == 200
        resp_json = resp.json()
        assert len(resp_json["data"]) == 1
        # The sample executor returns one row of 64 values per document.
        assert len(resp_json["data"][0]["embeddings"][0]) == 64
        assert resp_json["data"][0]["url"] == "http://google.com"


def test_provider_sagemaker_pod_batch_transform_valid():
    """Check CSV batch-transform input on `POST /invocations`.

    The content of `valid_clip_input.csv` is posted as `text/csv`; the
    response must contain three documents whose `text`, `url` and `bytes`
    fields carry the values asserted below.
    """
    # Bound the HTTP call so an unresponsive pod fails the test instead of
    # hanging the whole test run.
    request_timeout = 30  # seconds
    here = os.path.dirname(__file__)

    # Read the fixture up front so a missing file fails before a pod is started.
    with open(
        os.path.join(here, "valid_clip_input.csv"), "r", encoding="utf-8"
    ) as f:
        csv_data = f.read()

    args, _ = set_pod_parser().parse_known_args(
        [
            "--uses",
            os.path.join(here, "SampleClipExecutor", "config.yml"),
            "--provider",
            "sagemaker",
            "serve",  # This is added by sagemaker
        ]
    )
    with Pod(args):
        # Test `POST /invocations` endpoint for batch-transform with valid input
        resp = requests.post(
            f"http://localhost:{sagemaker_port}/invocations",
            headers={
                "accept": "application/json",
                "content-type": "text/csv",
            },
            data=csv_data,
            timeout=request_timeout,
        )
        assert resp.status_code == 200
        resp_json = resp.json()
        assert len(resp_json["data"]) == 3
        assert resp_json["data"][0]["text"] == "the cat is in my house"
        assert (
            resp_json["data"][1]["url"]
            == "https://dummyimage3.com/333/000/fff.jpg&text=embed+this"
        )
        assert "hWjj1RNtNftP" in resp_json["data"][2]["bytes"]
3 changes: 3 additions & 0 deletions tests/integration/docarray_v2/csp/valid_clip_input.csv

Large diffs are not rendered by default.

0 comments on commit 98429b0

Please sign in to comment.