Skip to content

Commit

Permalink
Use dictionary unpacking to pass trainer function arguments
Browse files Browse the repository at this point in the history
Signed-off-by: Antonin Stefanutti <[email protected]>
  • Loading branch information
astefanutti committed Jan 9, 2025
1 parent be2e29e commit a1054ae
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions sdk_v2/kubeflow/training/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,17 @@ def get_args_using_train_func(

# Wrap function code to execute it from the file. For example:
# TODO (andreyvelich): Find a better way to run users' scripts.
# def train(parameters):
# def train(lr=0.001):
# print('Start Training...')
# train({'lr': 0.01})
if train_func_parameters is None:
func_code = f"{func_code}\n{train_func.__name__}()\n"
else:
func_code = f"{func_code}\n{train_func.__name__}({train_func_parameters})\n"
func_code = (
f"{func_code}\n"
f"kwargs={train_func_parameters}\n"
f"{train_func.__name__}(**kwargs)\n"
)

# Prepare the template to execute script.
# Currently, we override the file where the training function is defined.
Expand Down

0 comments on commit a1054ae

Please sign in to comment.