Update Distributed prediction component (#3767)
Co-authored-by: Ankush Bhatia <[email protected]>
Co-authored-by: vizhur <[email protected]>
3 people authored Jan 17, 2025
1 parent 9f1ba24 commit 2fc7d49
Showing 2 changed files with 37 additions and 15 deletions.
@@ -1,6 +1,6 @@
$schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json
name: model_prediction_with_container
version: 0.0.4
version: 0.0.5
type: command
display_name: Distributed Model Prediction
description: "Optimized Distributed inference component for LLMs."
@@ -89,6 +89,7 @@ def postprocess(self, result):
"""
y_pred_df, y_test_df, perf_df, y_pred_proba_df = pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
for y_pred, y_test, perf, pred_probas in result:
logger.info(f"Type here as well: {type(y_test)}")
y_pred_df = pd.concat([y_pred_df, y_pred], axis=0)
y_test_df = pd.concat([y_test_df, y_test], axis=0)
perf_df = pd.concat([perf_df, perf], axis=0)
@@ -121,8 +122,8 @@ def _make_chat_completion_data(self, input_df, last_chats, col_name):
input_rows = input_df.values.tolist()
for ind, datarow in enumerate(input_rows):
conversation = datarow[0]
conversation.append({"role":"assistant", "content":last_chats[ind]})
appended_data[col_name].append(conversation)
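# Build a new list instead of appending in place, so the caller's conversation objects in input_df are not mutated.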
updated_conversation = conversation + [{"role":"assistant", "content":last_chats[ind]}]
appended_data[col_name].append(updated_conversation)
return pd.DataFrame(appended_data)


@@ -153,7 +154,10 @@ def predict_single(self, data):
else:
input_texts = [i[0] if len(i) == 1 else [j.strip() for j in i] for i in input_texts]
if self.task_type == SupportedTask.TEXT_GENERATION:
self.extra_params.update({"return_full_text": False})
if "return_full_text" not in self.extra_params:
self.extra_params["return_full_text"] = False
if self.task_type == SupportedTask.QnA:
    self.extra_params.update({"truncation": "longest_first"})
data = {
"input_data": {
"input_string": input_texts,
@@ -162,6 +166,17 @@
}
payload = MIRPayload.from_dict(data)
payload.update_params(get_generator_params(payload.params))
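# The engine may reject a payload whose inputs still exceed the model's context length; fall back through stricter truncation strategies.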
try:
    inference_results = self.engine.run(payload)
except Exception:
    try:
        logger.info("Inference failed with truncation 'longest_first'; retrying with 'only_second'")
        payload.params["truncation"] = "only_second"
        inference_results = self.engine.run(payload)
    except Exception:
        logger.info("Inference failed with truncation 'only_second'; retrying with 'only_first'")
        payload.params["truncation"] = "only_first"
        inference_results = self.engine.run(payload)

@@ -181,8 +196,19 @@
start_ms = time.time() * 1000
inference_results = self.engine.run(payload)
end_ms = time.time() * 1000
outputs = [res.response for i, res in enumerate(inference_results)]
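# Text-generation models often echo the prompt in the response; keep only the newly generated continuation.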
if self.task_type == SupportedTask.TEXT_GENERATION:
    outputs = []
    for gt, res in zip(input_texts, inference_results):
        # Strip the prompt only when the response actually begins with it.
        if res.response.startswith(gt):
            outputs.append(res.response[len(gt):])
        else:
            outputs.append(res.response)
else:
    outputs = [res.response for res in inference_results]
pred_probas = [res.scores for res in inference_results]

perf_data = [{
PerformanceColumns.BATCH_SIZE_COLUMN_NAME: len(input_texts),
PerformanceColumns.START_TIME_COLUMN_NAME: datetime.fromtimestamp(start_ms / 1000, timezone.utc).isoformat(),
@@ -195,12 +221,15 @@ def predict_single(self, data):
} for gt, pred in zip(input_texts, outputs)]
pred_proba_df = pd.DataFrame(pred_probas, index=X_test.index)
perf_data = pd.DataFrame(perf_data)

if self.task_type == SupportedTask.CHAT_COMPLETION or self.task_type == TaskType.CONVERSATIONAL:
pred_df = self._make_chat_completion_data(X_test, outputs,
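# Deep-copy X_test so _make_chat_completion_data cannot mutate the shared test frame.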
pred_df = self._make_chat_completion_data(X_test.copy(deep=True), outputs,
col_name=ChatCompletionConstants.OUTPUT_FULL_CONVERSATION)
pred_df[ChatCompletionConstants.OUTPUT] = outputs
y_test = self._make_chat_completion_data(X_test, y_test, col_name="ground_truth")
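# Keep the ground truth as a plain DataFrame aligned to X_test rather than re-running the chat-completion formatting.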
y_test = pd.DataFrame(y_test, columns=["ground_truth"], index=X_test.index)
return pred_df, y_test, perf_data, pred_proba_df

pred_df = pd.DataFrame(outputs, index=X_test.index, columns=["prediction"])
if isinstance(y_test, pd.Series):
y_test = y_test.to_frame()
@@ -460,15 +489,11 @@ def main():
data_path = args.data

logger.info(f"Torch Current Device Count:{torch.cuda.device_count()}")

logger.info(f"Got Params: {args.parameters}")
logger.info(f"Params type: {type(args.parameters)}")
#logger.info(f"Evaled params: {eval(args.parameters)}")
extra_params.update(json.loads(args.parameters))

logger.info(f"Got Model Path: {args.mlflow_model}")

task_type = args.task

input_column_names, label_column_name, extra_y_test_cols = validate_and_get_columns(vars(args))

try:
@@ -555,15 +580,12 @@ def main():
predictor = Predictor(g_fmscorer, task_type, extra_params, num_replicas, label_column_name, tokenizer, extra_y_test_cols)
collated_res = [{} for _ in range(distributed_state.num_processes)]
with distributed_state.split_between_processes(full_data) as proc_data:
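# Each process scores only its shard of full_data; per-rank results are gathered into collated_res below.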
y_pred_proc, y_test_proc, y_perf_proc, y_pred_proba = predictor.predict(proc_data)
#logger.info(f"Indices: {indices}")
proc_res = {"predictions": y_pred_proc, "ground_truth": y_test_proc, "perf": y_perf_proc, "pred_probas": y_pred_proba}
dist.all_gather_object(object_list=collated_res, obj=proc_res)
logger.info("Waiting for all processes.....")
distributed_state.wait_for_everyone()
logger.info(f"Collated Results Lengths: {[len(i) for i in collated_res]}")
logger.info(f"Type of each key: {[(k, type(v), len(v)) for k, v in collated_res[0].items()]}")
y_pred_df, y_test_df, y_perf_df, y_pred_proba_df = _gather_predictions(collated_res)

if task_type != SupportedTask.CHAT_COMPLETION and task_type != TaskType.CONVERSATIONAL:
