Skip to content

Commit

Permalink
[0.31.0-dlc] Fix parsing adapters list in inference request (#2600)
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 authored Nov 22, 2024
1 parent 014608a commit 9b1ba5b
Showing 1 changed file with 28 additions and 34 deletions.
62 changes: 28 additions & 34 deletions engines/python/setup/djl_python/input_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,56 +203,50 @@ def parse_adapters(request_input: TextInput, input_item: Input,
if adapter_registry:
input_len = len(request_input.input_text) if isinstance(
request_input.input_text, list) else 1
adapters_per_item = _fetch_adapters_from_input(input_map, input_item)
if adapters_per_item:
_validate_adapters(adapters_per_item,
kwargs.get("adapter_registry"))
else:
# inference with just base model.
adapters_per_item = [("", "")] * input_len

if input_len != len(adapters_per_item):
raise ValueError(
f"Number of adapters is not equal to the number of inputs")
# lookup the adapter registry to get the adapter details of the registered adapter.
adapters_data = [
kwargs.get("adapter_registry").get(adapter[0], None)
for adapter in adapters_per_item
]
if len(adapters_data) == 1:
adapters_data = adapters_data[0]

request_input.adapters = adapters_data


def _fetch_adapters_from_input(input_map: dict, input_item: Input):
adapters_data = _fetch_adapters_from_input(input_map, input_item,
adapter_registry)
if adapters_data:
if input_len != len(adapters_data):
raise ValueError(
f"Number of adapters is not equal to the number of inputs")
if len(adapters_data) == 1:
adapters_data = adapters_data[0]
request_input.adapters = adapters_data


def _fetch_adapters_from_input(input_map: dict, input_item: Input,
adapter_registry):
adapters_per_item = []
if "adapters" in input_map:
adapters_per_item = (input_map.pop("adapters"), None)
adapters_per_item = input_map.pop("adapters", [])

# check content, possible in workflow approach
if input_item.contains_key("adapter"):
adapters_per_item = (input_item.get_as_string("adapter"), None)
adapters_per_item = input_item.get_as_string("adapter")

# check properties, possible from header
adapter_alias = None
if "X-Amzn-SageMaker-Adapter-Identifier" in input_item.get_properties():
adapters_per_item = (
input_item.get_property("X-Amzn-SageMaker-Adapter-Identifier"),
input_item.get_property("X-Amzn-SageMaker-Adapter-Alias"))
adapters_per_item = input_item.get_property(
"X-Amzn-SageMaker-Adapter-Identifier")
adapter_alias = input_item.get_property(
"X-Amzn-SageMaker-Adapter-Alias")

logging.debug(f"Using adapter {adapter_alias or adapters_per_item}")

if not isinstance(adapters_per_item, list):
adapters_per_item = [adapters_per_item]

logging.debug(f"Using adapter {adapters_per_item}")
return adapters_per_item


def _validate_adapters(adapters_per_item, adapter_registry):
for adapter_name, adapter_alias in adapters_per_item:
adapters_data = []
for adapter_name in adapters_per_item:
if adapter_name and adapter_name not in adapter_registry:
raise ValueError(
f"Adapter {adapter_alias or adapter_name} is not registered")

# lookup the adapter registry to get the adapter details of the registered adapter.
adapters_data.append(adapter_registry.get(adapter_name))
return adapters_data


def parse_lmi_default_request_rolling_batch(payload):
if not isinstance(payload, dict):
Expand Down

0 comments on commit 9b1ba5b

Please sign in to comment.