From 2fc7d490ba9e8ed9f4fa110e0f90ad04d6d9c549 Mon Sep 17 00:00:00 2001
From: Ankush Bhatia
Date: Fri, 17 Jan 2025 11:00:06 +0530
Subject: [PATCH] Update Distributed prediction component (#3767)

Co-authored-by: Ankush Bhatia
Co-authored-by: vizhur
---
 .../distributed_model_prediction/spec.yaml |  2 +-
 .../src_distributed/model_prediction.py    | 50 +++++++++++++------
 2 files changed, 37 insertions(+), 15 deletions(-)

diff --git a/assets/training/model_evaluation/components/distributed_model_prediction/spec.yaml b/assets/training/model_evaluation/components/distributed_model_prediction/spec.yaml
index 215f391200..e489b69cec 100644
--- a/assets/training/model_evaluation/components/distributed_model_prediction/spec.yaml
+++ b/assets/training/model_evaluation/components/distributed_model_prediction/spec.yaml
@@ -1,6 +1,6 @@
 $schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json
 name: model_prediction_with_container
-version: 0.0.4
+version: 0.0.5
 type: command
 display_name: Distributed Model Prediction
 description: "Optimized Distributed inference component for LLMs."
diff --git a/assets/training/model_evaluation/src_distributed/model_prediction.py b/assets/training/model_evaluation/src_distributed/model_prediction.py
index 7d819b9f6a..c2e8d69f2b 100644
--- a/assets/training/model_evaluation/src_distributed/model_prediction.py
+++ b/assets/training/model_evaluation/src_distributed/model_prediction.py
@@ -89,6 +89,7 @@ def postprocess(self, result):
         """
         y_pred_df, y_test_df, perf_df, y_pred_proba_df = pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
         for y_pred, y_test, perf, pred_probas in result:
+            logger.info(f"Type here as well: {type(y_test)}")
             y_pred_df = pd.concat([y_pred_df, y_pred], axis=0)
             y_test_df = pd.concat([y_test_df, y_test], axis=0)
             perf_df = pd.concat([perf_df, perf], axis=0)
@@ -121,8 +122,8 @@ def _make_chat_completion_data(self, input_df, last_chats, col_name):
         input_rows = input_df.values.tolist()
         for ind, datarow in enumerate(input_rows):
             conversation = datarow[0]
-            conversation.append({"role":"assistant", "content":last_chats[ind]})
-            appended_data[col_name].append(conversation)
+            updated_conversation = conversation + [{"role":"assistant", "content":last_chats[ind]}]
+            appended_data[col_name].append(updated_conversation)
         return pd.DataFrame(appended_data)
 
@@ -153,7 +154,10 @@ def predict_single(self, data):
             else:
                 input_texts = [i[0] if len(i) == 1 else [j.strip() for j in i] for i in input_texts]
                 if self.task_type == SupportedTask.TEXT_GENERATION:
-                    self.extra_params.update({"return_full_text": False})
+                    if "return_full_text" not in self.extra_params:
+                        self.extra_params["return_full_text"] = False
+                if self.task_type == SupportedTask.QnA:
+                    self.extra_params.update({"truncation":"longest_first"})
             data = {
                 "input_data": {
                     "input_string": input_texts,
@@ -162,6 +166,17 @@
             }
             payload = MIRPayload.from_dict(data)
             payload.update_params(get_generator_params(payload.params))
+            try:
+                inference_results = self.engine.run(payload)
+            except:
+                try:
+                    logger.info("Failed with longest_first")
+                    payload.params["truncation"] = "only_second"
+                    inference_results = self.engine.run(payload)
+                except:
+                    logger.info("Failed with only first")
+                    payload.params["truncation"] = "only_first"
+                    inference_results = self.engine.run(payload)
@@ -181,8 +196,19 @@ def predict_single(self, data):
             start_ms = time.time() * 1000
             inference_results = self.engine.run(payload)
             end_ms = time.time() * 1000
-            outputs = [res.response for i, res in enumerate(inference_results)]
+            if self.task_type == SupportedTask.TEXT_GENERATION:
+                outputs = []
+                for gt, res in zip(input_texts, inference_results):
+                    if gt in res.response:
+                        outputs.append(res.response[len(gt):])
+                    else:
+                        outputs.append(res.response)
+            else:
+                outputs = [res.response for i, res in enumerate(inference_results)]
             pred_probas = [res.scores for res in inference_results]
+
+
+
             perf_data = [{
                 PerformanceColumns.BATCH_SIZE_COLUMN_NAME: len(input_texts),
                 PerformanceColumns.START_TIME_COLUMN_NAME: datetime.fromtimestamp(start_ms / 1000, timezone.utc).isoformat(),
@@ -195,12 +221,15 @@
             } for gt, pred in zip(input_texts, outputs)]
             pred_proba_df = pd.DataFrame(pred_probas, index=X_test.index)
             perf_data = pd.DataFrame(perf_data)
+
             if self.task_type == SupportedTask.CHAT_COMPLETION or self.task_type == TaskType.CONVERSATIONAL:
-                pred_df = self._make_chat_completion_data(X_test, outputs,
+                pred_df = self._make_chat_completion_data(X_test.copy(deep=True), outputs,
                                                           col_name=ChatCompletionConstants.OUTPUT_FULL_CONVERSATION)
                 pred_df[ChatCompletionConstants.OUTPUT] = outputs
-                y_test = self._make_chat_completion_data(X_test, y_test, col_name="ground_truth")
+                y_test = pd.DataFrame(y_test, columns=["ground_truth"], index=X_test.index)
+                # y_test = self._make_chat_completion_data(X_test.copy(deep=True), y_test, col_name="ground_truth")
                 return pred_df, y_test, perf_data, pred_proba_df
+
             pred_df = pd.DataFrame(outputs, index=X_test.index, columns=["prediction"])
             if isinstance(y_test, pd.Series):
                 y_test = y_test.to_frame()
@@ -460,15 +489,11 @@ def main():
     data_path = args.data
 
     logger.info(f"Torch Current Device Count:{torch.cuda.device_count()}")
-    logger.info(f"Got Params: {args.parameters}")
-    logger.info(f"Params type: {type(args.parameters)}")
-    #logger.info(f"Evaled params: {eval(args.parameters)}")
     extra_params.update(json.loads(args.parameters))
+
     logger.info(f"Got Model Path: {args.mlflow_model}")
-
     task_type = args.task
-
     input_column_names, label_column_name, extra_y_test_cols = validate_and_get_columns(vars(args))
 
     try:
@@ -555,15 +580,12 @@ def main():
     predictor = Predictor(g_fmscorer, task_type, extra_params, num_replicas, label_column_name, tokenizer, extra_y_test_cols)
     collated_res = [{} for i in range(distributed_state.num_processes)]
     with distributed_state.split_between_processes(full_data) as proc_data:
-        #indices = proc_data[0].index
         y_pred_proc, y_test_proc, y_perf_proc, y_pred_proba = predictor.predict(proc_data)
-        #logger.info(f"Indices: {indices}")
         proc_res = {"predictions": y_pred_proc, "ground_truth": y_test_proc, "perf": y_perf_proc, "pred_probas": y_pred_proba}
         dist.all_gather_object(object_list=collated_res, obj=proc_res)
     logger.info("Waiting for all processes.....")
     distributed_state.wait_for_everyone()
     logger.info(f"Collated Results Lengths: {[len(i) for i in collated_res]}")
-    logger.info(f"Type of each key: {[(k, type(v), len(v)) for k, v in collated_res[0].items()]}")
     y_pred_df, y_test_df, y_perf_df, y_pred_proba_df = _gather_predictions(collated_res)
 
     if task_type != SupportedTask.CHAT_COMPLETION and task_type != TaskType.CONVERSATIONAL: