Skip to content

Commit

Permalink
Initial changes for phi medium support
Browse files Browse the repository at this point in the history
  • Loading branch information
sanchez-alex committed Jan 31, 2025
1 parent 180c597 commit 49cd760
Show file tree
Hide file tree
Showing 16 changed files with 237 additions and 49 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
name: oss_distillation_generate_data
version: 0.0.8
version: 0.0.9.test1
type: command

is_deterministic: True
Expand Down Expand Up @@ -121,6 +121,10 @@ inputs:
type: uri_file
description: Validation status.
mode: rw_mount

model_asset_id:
type: string
description: Student model to use

outputs:
generated_train_file_path:
Expand Down Expand Up @@ -152,5 +156,6 @@ command: >-
$[[--enable_chain_of_density ${{inputs.enable_chain_of_density}}]]
$[[--max_len_summary ${{inputs.max_len_summary}}]]
--data_generation_task_type ${{inputs.data_generation_task_type}}
--model_asset_id ${{inputs.model_asset_id}}
--generated_train_file_path ${{outputs.generated_train_file_path}}
--generated_validation_file_path ${{outputs.generated_validation_file_path}}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
$schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json
name: oss_distillation_batchscoring_datagen_pipeline
version: 0.0.1
version: 0.0.1.test1
type: pipeline


Expand Down Expand Up @@ -160,6 +160,9 @@ inputs:
4. MATH: Generate Math data for numerical responses
5. SUMMARIZATION: Generate Key Summary for an Article
model_asset_id:
type: string
description: The student model to finetune

# Output of validation component.
validation_info:
Expand Down Expand Up @@ -256,7 +259,7 @@ outputs:
jobs:
oss_distillation_generate_data_batch_preprocess:
type: command
component: azureml:oss_distillation_generate_data_batch_preprocess:0.0.1
component: azureml:oss_distillation_generate_data_batch_preprocess:0.0.1.test1
compute: '${{parent.inputs.compute_data_generation}}'
resources:
instance_type: '${{parent.inputs.instance_type_data_generation}}'
Expand Down Expand Up @@ -296,7 +299,7 @@ jobs:
# Config generator job
oss_distillation_generate_data_config_generator:
type: command
component: azureml:batch_benchmark_config_generator:0.0.9
component: azureml://registries/azureml/components/batch_benchmark_config_generator/versions/0.0.9
compute: '${{parent.inputs.compute_pipeline_validation}}'
resources:
instance_type: '${{parent.inputs.instance_type_pipeline_validation}}'
Expand All @@ -322,7 +325,7 @@ jobs:
# Batch score job
oss_distillation_train_data_batch_score:
type: parallel
component: azureml:batch_score_oss:0.0.1
component: azureml://registries/azureml/components/batch_score_oss/versions/0.0.1
compute: '${{parent.inputs.compute_data_generation}}'
identity:
type: user_identity
Expand All @@ -349,7 +352,7 @@ jobs:

validation_file_path_exists:
type: command
component: azureml:oss_distillation_data_generation_validation_file_checker:0.0.1
component: azureml:oss_distillation_data_generation_validation_file_checker:0.0.1.test1
compute: '${{parent.inputs.compute_pipeline_validation}}'
resources:
instance_type: '${{parent.inputs.instance_type_pipeline_validation}}'
Expand All @@ -366,7 +369,7 @@ jobs:
# Batch score job
oss_distillation_validation_data_batch_score:
type: parallel
component: azureml:batch_score_oss:0.0.1
component: azureml://registries/azureml/components/batch_score_oss/versions/0.0.1
compute: '${{parent.inputs.compute_data_generation}}'
identity:
type: user_identity
Expand All @@ -393,7 +396,7 @@ jobs:

oss_distillation_generate_data_batch_postprocess:
type: command
component: azureml:oss_distillation_generate_data_batch_postprocess:0.0.1
component: azureml:oss_distillation_generate_data_batch_postprocess:0.0.1.test1
compute: '${{parent.inputs.compute_data_generation}}'
resources:
instance_type: '${{parent.inputs.instance_type_data_generation}}'
Expand All @@ -410,6 +413,7 @@ jobs:
enable_chain_of_density: '${{parent.inputs.enable_chain_of_density}}'
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
model_asset_id: '${{parent.inputs.model_asset_id}}'
connection_config_file: ${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.batch_config_connection}}
outputs:
generated_batch_train_file_path: '${{parent.outputs.generated_batch_train_file_path}}'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
name: oss_distillation_generate_data_batch_postprocess
version: 0.0.1
version: 0.0.1.test1
type: command

is_deterministic: False
Expand Down Expand Up @@ -82,6 +82,10 @@ inputs:
type: uri_file
description: Connection config file for batch scoring

model_asset_id:
type: string
description: The student model to finetune

outputs:
generated_batch_train_file_path:
type: uri_file
Expand All @@ -104,6 +108,7 @@ command: >-
--min_endpoint_success_ratio ${{inputs.min_endpoint_success_ratio}}
$[[--enable_chain_of_thought ${{inputs.enable_chain_of_thought}}]]
$[[--enable_chain_of_density ${{inputs.enable_chain_of_density}}]]
--model_asset_id ${{inputs.model_asset_id}}
--data_generation_task_type ${{inputs.data_generation_task_type}}
--connection_config_file ${{inputs.connection_config_file}}
--generated_batch_train_file_path ${{outputs.generated_batch_train_file_path}}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
name: oss_distillation_generate_data_batch_preprocess
version: 0.0.1
version: 0.0.1.test1
type: command

is_deterministic: False
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
name: oss_distillation_data_generation_batch_scoring_selector
version: 0.0.1
version: 0.0.1.test1
type: command

is_deterministic: True
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
name: oss_distillation_data_generation_file_selector
version: 0.0.1
version: 0.0.1.test1
type: command

is_deterministic: True
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
$schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json
name: oss_distillation_seq_scoring_pipeline
version: 0.0.1
version: 0.0.1.test1
type: pipeline


Expand Down Expand Up @@ -172,6 +172,10 @@ inputs:
4. MATH: Generate Math data for numerical responses
5. SUMMARIZATION: Generate Key Summary for an Article
model_asset_id:
type: string
description: The student model asset id
optional: false

# Training parameters
num_train_epochs:
Expand Down Expand Up @@ -212,7 +216,7 @@ outputs:
jobs:
oss_distillation_generate_data:
type: command
component: azureml:oss_distillation_generate_data:0.0.8
component: azureml:oss_distillation_generate_data:0.0.9.test1
compute: '${{parent.inputs.compute_data_generation}}'
resources:
instance_type: '${{parent.inputs.instance_type_data_generation}}'
Expand All @@ -236,6 +240,7 @@ jobs:
request_batch_size: '${{parent.inputs.request_batch_size}}'
min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
validation_output: '${{parent.inputs.validation_output}}'
model_asset_id: '${{parent.inputs.model_asset_id}}'
outputs:
generated_train_file_path: '${{parent.outputs.generated_train_file_path}}'
generated_validation_file_path: '${{parent.outputs.generated_validation_file_path}}'
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
name: oss_distillation_data_generation_validation_file_checker
version: 0.0.1
version: 0.0.1.test1
type: command

is_deterministic: True
Expand Down
27 changes: 15 additions & 12 deletions assets/training/distillation/components/pipeline/spec.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
$schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json
name: oss_distillation_pipeline
version: 0.0.10
version: 0.0.10.test1
type: pipeline


Expand Down Expand Up @@ -270,11 +270,11 @@ inputs:
optional: true
description: Validation parameters propagated from pipeline.

# Model parameters
# Student Model parameters
model_asset_id:
type: string
optional: false
description: Asset id of model
description: Asset id of the student model

# Model registration
registered_model_name:
Expand All @@ -297,7 +297,7 @@ outputs:
jobs:
oss_distillation_validate_pipeline:
type: command
component: azureml:oss_distillation_validate_pipeline:0.0.5
component: azureml:oss_distillation_validate_pipeline:0.0.5.test1
compute: '${{parent.inputs.compute_pipeline_validation}}'
resources:
instance_type: '${{parent.inputs.instance_type_pipeline_validation}}'
Expand All @@ -323,14 +323,15 @@ jobs:
num_train_epochs: '${{parent.inputs.num_train_epochs}}'
per_device_train_batch_size: '${{parent.inputs.per_device_train_batch_size}}'
learning_rate: '${{parent.inputs.learning_rate}}'
model_asset_id: '${{parent.inputs.model_asset_id}}'
outputs:
validation_info:
type: uri_file
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.json

data_generation_batch_scoring_selector:
type: command
component: azureml:oss_distillation_data_generation_batch_scoring_selector:0.0.1
component: azureml:oss_distillation_data_generation_batch_scoring_selector:0.0.1.test1
compute: '${{parent.inputs.compute_pipeline_validation}}'
resources:
instance_type: '${{parent.inputs.instance_type_pipeline_validation}}'
Expand All @@ -347,7 +348,7 @@ jobs:

oss_distillation_batchscoring_datagen_pipeline:
type: pipeline
component: azureml:oss_distillation_batchscoring_datagen_pipeline:0.0.1
component: azureml:oss_distillation_batchscoring_datagen_pipeline:0.0.1.test1
inputs:
instance_type_pipeline_validation: '${{parent.inputs.instance_type_pipeline_validation}}'
instance_type_data_generation: '${{parent.inputs.instance_type_data_generation}}'
Expand Down Expand Up @@ -387,6 +388,7 @@ jobs:
max_concurrency_per_instance: '${{parent.inputs.max_concurrency_per_instance}}'
mini_batch_size: '${{parent.inputs.mini_batch_size}}'
validation_info: '${{parent.jobs.oss_distillation_validate_pipeline.outputs.validation_info}}'
model_asset_id: '${{parent.inputs.model_asset_id}}'

outputs:
generated_batch_train_file_path:
Expand All @@ -398,7 +400,7 @@ jobs:

oss_distillation_seq_scoring_pipeline:
type: pipeline
component: azureml:oss_distillation_seq_scoring_pipeline:0.0.1
component: azureml:oss_distillation_seq_scoring_pipeline:0.0.1.test1
inputs:
instance_type_pipeline_validation: '${{parent.inputs.instance_type_pipeline_validation}}'
instance_type_data_generation: '${{parent.inputs.instance_type_data_generation}}'
Expand Down Expand Up @@ -426,6 +428,7 @@ jobs:
max_len_summary: '${{parent.inputs.max_len_summary}}'
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
validation_output: '${{parent.jobs.oss_distillation_validate_pipeline.outputs.validation_info}}'
model_asset_id: '${{parent.inputs.model_asset_id}}'
outputs:
generated_train_file_path:
type: uri_file
Expand All @@ -437,7 +440,7 @@ jobs:

oss_distillation_train_data_generation_file_selector:
type: command
component: azureml:oss_distillation_data_generation_file_selector:0.0.1
component: azureml:oss_distillation_data_generation_file_selector:0.0.1.test1
compute: '${{parent.inputs.compute_pipeline_validation}}'
resources:
instance_type: '${{parent.inputs.instance_type_pipeline_validation}}'
Expand All @@ -460,7 +463,7 @@ jobs:

oss_text_generation_data_import:
type: command
component: azureml:oss_text_generation_data_import:0.0.25
component: azureml://registries/azureml/components/oss_text_generation_data_import/versions/0.0.26
compute: '${{parent.inputs.compute_data_import}}'
resources:
instance_type: '${{parent.inputs.instance_type_data_import}}'
Expand All @@ -472,13 +475,13 @@ jobs:
environment_variables:
_AZUREML_CR_ENABLE_ITP_CAP: "false"
inputs:
train_file_path: '${{parent.jobs.oss_distillation_train_data_generation_file_selector.outputs.ft_input_train_file_path}}'
validation_file_path: '${{parent.jobs.oss_distillation_train_data_generation_file_selector.outputs.ft_input_validation_file_path}}'
train_file_path: '${{parent.jobs.oss_distillation_generate_data.outputs.generated_train_file_path}}'
validation_file_path: '${{parent.jobs.oss_distillation_generate_data.outputs.generated_validation_file_path}}'
system_properties: '${{parent.inputs.system_properties}}'

oss_chat_completion_finetune:
type: command
component: azureml:oss_chat_completion_finetune:0.0.25
component: azureml://registries/azureml/components/oss_chat_completion_finetune/versions/0.0.26
compute: '${{parent.inputs.compute_finetune}}'
resources:
instance_type: '${{parent.inputs.instance_type_finetune}}'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
name: oss_distillation_validate_pipeline
version: 0.0.5
version: 0.0.5.test1
type: command

is_deterministic: true
Expand Down Expand Up @@ -135,6 +135,11 @@ inputs:
optional: true
description: Start learning rate.

model_asset_id:
type: string
optional: false
description: The student model to finetune

outputs:
validation_info:
type: uri_file
Expand Down Expand Up @@ -163,4 +168,5 @@ command: >-
$[[--num_train_epochs ${{inputs.num_train_epochs}}]]
$[[--per_device_train_batch_size ${{inputs.per_device_train_batch_size}}]]
$[[--learning_rate ${{inputs.learning_rate}}]]
--model_asset_id '${{inputs.model_asset_id}}'
--validation_info ${{outputs.validation_info}}
11 changes: 1 addition & 10 deletions assets/training/distillation/src/common/constants.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Data generatior constants."""
"""Data generation constants."""

import re
from enum import EnumMeta, Enum
Expand Down Expand Up @@ -36,15 +36,6 @@
}
}

# SUPPORTED STUDENT MODEL
# MAP keys are model name in registry, which maps to specific model details like registry and supported versions
SUPPORTED_STUDENT_MODEL_MAP = {
"Meta-Llama-3.1-8B-Instruct": {
"supported_registries": ["azureml-meta"],
"supported_version_pattern": re.compile(r"\d+"),
}
}

# Scoring paths
VLLM_CHAT_SCORE_PATH = "/v1/chat/completions"
HFTV2_TEXT_GEN_SCORE_PATH = "/score"
Expand Down
Loading

0 comments on commit 49cd760

Please sign in to comment.