From 9b1ba5b39b2bba884afeea869c5e6c1d7fc5f559 Mon Sep 17 00:00:00 2001 From: Xin Yang <105740670+xyang16@users.noreply.github.com> Date: Thu, 21 Nov 2024 19:48:18 -0800 Subject: [PATCH] [0.31.0-dlc] Fix parsing adapters list in inference request (#2600) --- .../python/setup/djl_python/input_parser.py | 62 +++++++++---------- 1 file changed, 28 insertions(+), 34 deletions(-) diff --git a/engines/python/setup/djl_python/input_parser.py b/engines/python/setup/djl_python/input_parser.py index d38add852..6b16775ce 100644 --- a/engines/python/setup/djl_python/input_parser.py +++ b/engines/python/setup/djl_python/input_parser.py @@ -203,56 +203,50 @@ def parse_adapters(request_input: TextInput, input_item: Input, if adapter_registry: input_len = len(request_input.input_text) if isinstance( request_input.input_text, list) else 1 - adapters_per_item = _fetch_adapters_from_input(input_map, input_item) - if adapters_per_item: - _validate_adapters(adapters_per_item, - kwargs.get("adapter_registry")) - else: - # inference with just base model. - adapters_per_item = [("", "")] * input_len - - if input_len != len(adapters_per_item): - raise ValueError( - f"Number of adapters is not equal to the number of inputs") - # lookup the adapter registry to get the adapter details of the registered adapter. - adapters_data = [ - kwargs.get("adapter_registry").get(adapter[0], None) - for adapter in adapters_per_item - ] - if len(adapters_data) == 1: - adapters_data = adapters_data[0] - - request_input.adapters = adapters_data - - -def _fetch_adapters_from_input(input_map: dict, input_item: Input): + adapters_data = _fetch_adapters_from_input(input_map, input_item, + adapter_registry) + if adapters_data: + if input_len != len(adapters_data): + raise ValueError( + f"Number of adapters is not equal to the number of inputs") + if len(adapters_data) == 1: + adapters_data = adapters_data[0] + request_input.adapters = adapters_data + + +def _fetch_adapters_from_input(input_map: dict, input_item: Input, + adapter_registry): adapters_per_item = [] if "adapters" in input_map: - adapters_per_item = (input_map.pop("adapters"), None) + adapters_per_item = input_map.pop("adapters", []) # check content, possible in workflow approach if input_item.contains_key("adapter"): - adapters_per_item = (input_item.get_as_string("adapter"), None) + adapters_per_item = input_item.get_as_string("adapter") # check properties, possible from header + adapter_alias = None if "X-Amzn-SageMaker-Adapter-Identifier" in input_item.get_properties(): - adapters_per_item = ( - input_item.get_property("X-Amzn-SageMaker-Adapter-Identifier"), - input_item.get_property("X-Amzn-SageMaker-Adapter-Alias")) + adapters_per_item = input_item.get_property( + "X-Amzn-SageMaker-Adapter-Identifier") + adapter_alias = input_item.get_property( + "X-Amzn-SageMaker-Adapter-Alias") + + logging.debug(f"Using adapter {adapter_alias or adapters_per_item}") if not isinstance(adapters_per_item, list): adapters_per_item = [adapters_per_item] - logging.debug(f"Using adapter {adapters_per_item}") - return adapters_per_item - - -def _validate_adapters(adapters_per_item, adapter_registry): - for adapter_name, adapter_alias in adapters_per_item: + adapters_data = [] + for adapter_name in adapters_per_item: if adapter_name and adapter_name not in adapter_registry: raise ValueError( f"Adapter {adapter_alias or adapter_name} is not registered") + # lookup the adapter registry to get the adapter details of the registered adapter. + adapters_data.append(adapter_registry.get(adapter_name)) + return adapters_data + def parse_lmi_default_request_rolling_batch(payload): if not isinstance(payload, dict):