Code changes for moving to Batch Scoring for Teacher Model (#3507)
* Code changes for moving to Batch Scoring for Teacher Model * Addition of copy right header * Lengthy line formating fix * Fix for unterminated f string * Remove unwanted imports * File formatting fixes * File formatting fixes * Fix for noqa * Doc String for copy file contents function * Fix for closing the connection in post processing * Making changes for fixing pipeline changes * Version changes * Merge conflict fixes * Environmental version fixes
This commit is contained in:
Родитель
4f983518a6
Коммит
04eff530e2
|
@ -0,0 +1,61 @@
|
||||||
|
# OSS Distillation Batch Score Data Generation Pipeline
|
||||||
|
|
||||||
|
This component generates data from a teacher model endpoint by invoking it in batch mode. It is part of the OSS Distillation pipeline.
|
||||||
|
|
||||||
|
## Component Details
|
||||||
|
|
||||||
|
- **Name**: `oss_distillation_batchscoring_datagen_pipeline`
|
||||||
|
- **Version**: `0.0.1`
|
||||||
|
- **Type**: `pipeline`
|
||||||
|
- **Display Name**: `OSS Distillation Batch Score Data Generation Pipeline`
|
||||||
|
- **Description**: Component to generate data from teacher model endpoint by invoking it in batch.
|
||||||
|
|
||||||
|
## Inputs
|
||||||
|
|
||||||
|
| Name | Type | Optional | Default | Description |
|
||||||
|
|----------------------------------|----------|----------|----------------------------------|-------------------------------------------------------------------------------------------------------------|
|
||||||
|
| instance_type_pipeline_validation| string | True | | Instance type to be used for validation component. The parameter compute_pipeline_validation must be set to 'serverless' for instance_type to be used. |
|
||||||
|
| instance_type_data_generation | string | True | Standard_D4as_v4 | Instance type to be used for finetune component in case of virtual cluster compute. |
|
||||||
|
| instance_type_data_import | string | True | Singularity.ND96amrs_A100_v4 | Instance type to be used for data_import component in case of virtual cluster compute. |
|
||||||
|
| instance_type_finetune | string | True | Singularity.ND96amrs_A100_v4 | Instance type to be used for finetune component in case of virtual cluster compute. |
|
||||||
|
| compute_pipeline_validation | string | True | serverless | Compute to be used for validation component. |
|
||||||
|
| compute_data_generation | string | True | serverless | Compute to be used for model_import. |
|
||||||
|
| compute_data_import | string | True | serverless | Compute to be used for model_import. |
|
||||||
|
| compute_finetune | string | True | serverless | Compute to be used for finetune. |
|
||||||
|
| train_file_path | uri_file | False | | Path to the registered training data asset. |
|
||||||
|
| validation_file_path | uri_file | True | | Path to the registered validation data asset. |
|
||||||
|
| teacher_model_endpoint_url | string | True | | Teacher model endpoint URL. |
|
||||||
|
| teacher_model_asset_id | string | True | | Teacher model Asset Id. |
|
||||||
|
| teacher_model_endpoint_name | string | True | | Teacher model endpoint name. |
|
||||||
|
| teacher_model_max_new_tokens | integer | True | 128 | Teacher model max_new_tokens inference parameter. |
|
||||||
|
| teacher_model_temperature | number | True | 0.2 | Teacher model temperature inference parameter. |
|
||||||
|
| teacher_model_top_p | number | True | 0.1 | Teacher model top_p inference parameter. |
|
||||||
|
| teacher_model_frequency_penalty | number | True | 0.0 | Teacher model frequency penalty inference parameter. |
|
||||||
|
| teacher_model_presence_penalty | number | True | 0.0 | Teacher model presence penalty inference parameter. |
|
||||||
|
| teacher_model_stop | string | True | | Teacher model stop inference parameter. |
|
||||||
|
| min_endpoint_success_ratio | number | True | 0.7 | Minimum value of (successful_requests / total_requests) required for classifying inference as successful. |
|
||||||
|
| enable_chain_of_thought | string | True | false | Enable Chain of thought for data generation. |
|
||||||
|
| enable_chain_of_density | string | True | false | Enable Chain of density for text summarization. |
|
||||||
|
| max_len_summary | integer | True | 80 | Maximum Length Summary for text summarization. |
|
||||||
|
| data_generation_task_type | string | False | | Data generation task type. Supported values: NLI, CONVERSATION, NLU_QA, MATH, SUMMARIZATION. |
|
||||||
|
| num_train_epochs | integer | True | 1 | Training epochs. |
|
||||||
|
| per_device_train_batch_size | integer | True | 1 | Train batch size. |
|
||||||
|
| learning_rate | number | True | 3e-04 | Start learning rate. |
|
||||||
|
| authentication_type | string | False | azureml_workspace_connection | Authentication type for endpoint. Supported values: azureml_workspace_connection, managed_identity. |
|
||||||
|
| connection_name | string | True | | Connection name to be used for authentication. |
|
||||||
|
| additional_headers | string | True | | JSON serialized string expressing additional headers to be added to each request. |
|
||||||
|
| debug_mode | boolean | False | False | Enable debug mode to print all the debug logs in the score step. |
|
||||||
|
| ensure_ascii | boolean | False | False | If set to true, the output is guaranteed to have all incoming non-ASCII characters escaped. |
|
||||||
|
| max_retry_time_interval | integer | True | | The maximum time (in seconds) spent retrying a payload. |
|
||||||
|
| initial_worker_count | integer | False | 5 | The initial number of workers to use for scoring. |
|
||||||
|
| max_worker_count | integer | False | 200 | Overrides `initial_worker_count` if necessary. |
|
||||||
|
| instance_count | integer | False | 1 | Number of nodes in a compute cluster we will run the batch score step on. |
|
||||||
|
| max_concurrency_per_instance | integer | False | 1 | Number of processes that will be run concurrently on any given node. |
|
||||||
|
| mini_batch_size | string | True | 100KB | The mini batch size for parallel run. |
|
||||||
|
|
||||||
|
## Outputs
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
|----------------------------------|----------|-------------------------------------------------------------------------------------------------------------|
|
||||||
|
| generated_batch_train_file_path | uri_file | Generated train data |
|
||||||
|
| generated_batch_validation_file_path | uri_file | Generated validation data |
|
|
@ -0,0 +1,3 @@
|
||||||
|
type: component
|
||||||
|
spec: spec.yaml
|
||||||
|
categories: ["Foundational Models", "Finetune", "Distillation","Batch Scoring"]
|
|
@ -0,0 +1,417 @@
|
||||||
|
$schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json
|
||||||
|
name: oss_distillation_batchscoring_datagen_pipeline
|
||||||
|
version: 0.0.1
|
||||||
|
type: pipeline
|
||||||
|
|
||||||
|
|
||||||
|
display_name: OSS Distillation Batch Score Data Generation Pipeline
|
||||||
|
description: Component to generate data from teacher model endpoint by invoking it in batch.
|
||||||
|
|
||||||
|
inputs:
|
||||||
|
# Compute parameters
|
||||||
|
instance_type_pipeline_validation:
|
||||||
|
type: string
|
||||||
|
optional: True
|
||||||
|
description: Instance type to be used for validation component. The parameter compute_pipeline_validation must be set to 'serverless' for instance_type to be used.
|
||||||
|
instance_type_data_generation:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: Standard_D4as_v4
|
||||||
|
description: Instance type to be used for finetune component in case of virtual cluster compute, eg. Singularity.ND40_v2. The parameter compute_finetune must be set to 'serverless' for instance_type to be used
|
||||||
|
instance_type_data_import:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: Singularity.ND96amrs_A100_v4
|
||||||
|
description: Instance type to be used for data_import component in case of virtual cluster compute, eg. Singularity.D8_v3. The parameter compute_data_import must be set to 'serverless' for instance_type to be used
|
||||||
|
instance_type_finetune:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: Singularity.ND96amrs_A100_v4
|
||||||
|
description: Instance type to be used for finetune component in case of virtual cluster compute, eg. Singularity.ND40_v2. The parameter compute_finetune must be set to 'serverless' for instance_type to be used
|
||||||
|
|
||||||
|
compute_pipeline_validation:
|
||||||
|
type: string
|
||||||
|
optional: True
|
||||||
|
default: 'serverless'
|
||||||
|
description: compute to be used for validation component
|
||||||
|
|
||||||
|
compute_data_generation:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: 'serverless'
|
||||||
|
description: >-
|
||||||
|
compute to be used for model_import eg. provide 'FT-Cluster' if
|
||||||
|
your compute is named 'FT-Cluster'. Special characters like \ and ' are invalid in the parameter value.
|
||||||
|
If compute cluster name is provided, instance_type field will be ignored and the respective cluster will be used
|
||||||
|
compute_data_import:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: 'serverless'
|
||||||
|
description: >-
|
||||||
|
compute to be used for model_import eg. provide 'FT-Cluster' if
|
||||||
|
your compute is named 'FT-Cluster'. Special characters like \ and ' are invalid in the parameter value.
|
||||||
|
If compute cluster name is provided, instance_type field will be ignored and the respective cluster will be used
|
||||||
|
compute_finetune:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: 'serverless'
|
||||||
|
description: >-
|
||||||
|
compute to be used for finetune eg. provide 'FT-Cluster' if your
|
||||||
|
compute is named 'FT-Cluster'. Special characters like \ and ' are invalid in the parameter value.
|
||||||
|
If compute cluster name is provided, instance_type field will be ignored and the respective cluster will be used
|
||||||
|
|
||||||
|
# ########################### Data Generator Component ########################### #
|
||||||
|
|
||||||
|
train_file_path:
|
||||||
|
type: uri_file
|
||||||
|
description: Path to the registered training data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
|
||||||
|
mode: rw_mount
|
||||||
|
|
||||||
|
validation_file_path:
|
||||||
|
type: uri_file
|
||||||
|
optional: true
|
||||||
|
description: Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
|
||||||
|
mode: rw_mount
|
||||||
|
|
||||||
|
teacher_model_endpoint_url:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
description: Teacher model endpoint URL
|
||||||
|
|
||||||
|
teacher_model_endpoint_name:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
description: Teacher model endpoint name
|
||||||
|
|
||||||
|
teacher_model_endpoint_key:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
description: Teacher model endpoint key
|
||||||
|
|
||||||
|
teacher_model_max_new_tokens:
|
||||||
|
type: integer
|
||||||
|
default: 128
|
||||||
|
description: Teacher model max_new_tokens inference parameter
|
||||||
|
|
||||||
|
teacher_model_temperature:
|
||||||
|
type: number
|
||||||
|
default: 0.2
|
||||||
|
description: Teacher model temperature inference parameter
|
||||||
|
|
||||||
|
teacher_model_top_p:
|
||||||
|
type: number
|
||||||
|
default: 0.1
|
||||||
|
description: Teacher model top_p inference parameter
|
||||||
|
|
||||||
|
teacher_model_frequency_penalty:
|
||||||
|
type: number
|
||||||
|
default: 0.0
|
||||||
|
description: Teacher model frequency penalty inference parameter
|
||||||
|
|
||||||
|
teacher_model_presence_penalty:
|
||||||
|
type: number
|
||||||
|
default: 0.0
|
||||||
|
description: Teacher model presence penalty inference parameter
|
||||||
|
|
||||||
|
teacher_model_stop:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
description: Teacher model stop inference parameter
|
||||||
|
|
||||||
|
min_endpoint_success_ratio:
|
||||||
|
type: number
|
||||||
|
default: 0.7
|
||||||
|
description: >
|
||||||
|
The minimum value of (successful_requests / total_requests) required for classifying inference as successful.
|
||||||
|
If (successful_requests / total_requests) < min_endpoint_success_ratio, the experiment will be marked as failed.
|
||||||
|
By default it is 0.7 (0 means all requests are allowed to fail while 1 means no request should fail.)
|
||||||
|
|
||||||
|
enable_chain_of_thought:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: "false"
|
||||||
|
description: Enable Chain of thought for data generation
|
||||||
|
|
||||||
|
enable_chain_of_density:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: "false"
|
||||||
|
description: Enable Chain of density for text summarization
|
||||||
|
|
||||||
|
max_len_summary:
|
||||||
|
type: integer
|
||||||
|
optional: true
|
||||||
|
default: 80
|
||||||
|
description: Maximum Length Summary for text summarization
|
||||||
|
|
||||||
|
data_generation_task_type:
|
||||||
|
type: string
|
||||||
|
enum:
|
||||||
|
- NLI
|
||||||
|
- CONVERSATION
|
||||||
|
- NLU_QA
|
||||||
|
- MATH
|
||||||
|
- SUMMARIZATION
|
||||||
|
description: >
|
||||||
|
Data generation task type. Supported values are:
|
||||||
|
1. NLI: Generate Natural Language Inference data
|
||||||
|
2. CONVERSATION: Generate conversational data (multi/single turn)
|
||||||
|
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
|
||||||
|
4. MATH: Generate Math data for numerical responses
|
||||||
|
5. SUMMARIZATION: Generate Key Summary for an Article
|
||||||
|
|
||||||
|
|
||||||
|
# Output of validation component.
|
||||||
|
validation_info:
|
||||||
|
type: uri_file
|
||||||
|
description: Validation status.
|
||||||
|
mode: rw_mount
|
||||||
|
|
||||||
|
# Training parameters
|
||||||
|
num_train_epochs:
|
||||||
|
type: integer
|
||||||
|
default: 1
|
||||||
|
optional: true
|
||||||
|
description: training epochs
|
||||||
|
|
||||||
|
per_device_train_batch_size:
|
||||||
|
type: integer
|
||||||
|
default: 1
|
||||||
|
optional: true
|
||||||
|
description: Train batch size
|
||||||
|
|
||||||
|
learning_rate:
|
||||||
|
type: number
|
||||||
|
default: 3e-04
|
||||||
|
optional: true
|
||||||
|
description: Start learning rate.
|
||||||
|
|
||||||
|
# ########################### Batch Score Component ########################### #
|
||||||
|
authentication_type:
|
||||||
|
type: string
|
||||||
|
optional: False
|
||||||
|
description: Authentication type for endpoint. Either `azureml_workspace_connection` or `managed_identity`.
|
||||||
|
default: azureml_workspace_connection
|
||||||
|
enum:
|
||||||
|
- azureml_workspace_connection
|
||||||
|
- managed_identity
|
||||||
|
configuration_file:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
description: Config file path that contains deployment configurations
|
||||||
|
additional_headers:
|
||||||
|
type: string
|
||||||
|
optional: True
|
||||||
|
description: JSON serialized string expressing additional headers to be added to each request.
|
||||||
|
debug_mode:
|
||||||
|
type: boolean
|
||||||
|
optional: False
|
||||||
|
default: False
|
||||||
|
description: Enable debug mode to print all the debug logs in the score step.
|
||||||
|
ensure_ascii:
|
||||||
|
type: boolean
|
||||||
|
optional: False
|
||||||
|
default: False
|
||||||
|
description: If set to true, the output is guaranteed to have all incoming non-ASCII characters escaped. If set to false, these characters will be output as-is. More detailed information can be found at https://docs.python.org/3/library/json.html
|
||||||
|
max_retry_time_interval:
|
||||||
|
type: integer
|
||||||
|
optional: True
|
||||||
|
description: The maximum time (in seconds) spent retrying a payload. If unspecified, payloads are retried for unlimited time.
|
||||||
|
initial_worker_count:
|
||||||
|
type: integer
|
||||||
|
optional: False
|
||||||
|
default: 5
|
||||||
|
description: The initial number of workers to use for scoring.
|
||||||
|
max_worker_count:
|
||||||
|
type: integer
|
||||||
|
optional: False
|
||||||
|
default: 200
|
||||||
|
description: Overrides `initial_worker_count` if necessary.
|
||||||
|
instance_count:
|
||||||
|
type: integer
|
||||||
|
default: 1
|
||||||
|
description: Number of nodes in a compute cluster we will run the batch score step on.
|
||||||
|
max_concurrency_per_instance:
|
||||||
|
type: integer
|
||||||
|
default: 1
|
||||||
|
description: Number of processes that will be run concurrently on any given node. This number should not be larger than 1/2 of the number of cores in an individual node in the specified cluster.
|
||||||
|
mini_batch_size:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: 100KB
|
||||||
|
description: The mini batch size for parallel run.
|
||||||
|
|
||||||
|
|
||||||
|
outputs:
|
||||||
|
generated_batch_train_file_path:
|
||||||
|
type: uri_file
|
||||||
|
description: Generated train data
|
||||||
|
mode: rw_mount
|
||||||
|
generated_batch_validation_file_path:
|
||||||
|
type: uri_file
|
||||||
|
description: Generated validation data
|
||||||
|
mode: rw_mount
|
||||||
|
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
oss_distillation_generate_data_batch_preprocess:
|
||||||
|
type: command
|
||||||
|
component: azureml:oss_distillation_generate_data_batch_preprocess:0.0.1
|
||||||
|
compute: '${{parent.inputs.compute_data_generation}}'
|
||||||
|
resources:
|
||||||
|
instance_type: '${{parent.inputs.instance_type_data_generation}}'
|
||||||
|
identity:
|
||||||
|
type: user_identity
|
||||||
|
inputs:
|
||||||
|
train_file_path: '${{parent.inputs.train_file_path}}'
|
||||||
|
validation_file_path: '${{parent.inputs.validation_file_path}}'
|
||||||
|
teacher_model_endpoint_name: '${{parent.inputs.teacher_model_endpoint_name}}'
|
||||||
|
teacher_model_endpoint_url: '${{parent.inputs.teacher_model_endpoint_url}}'
|
||||||
|
teacher_model_endpoint_key: '${{parent.inputs.teacher_model_endpoint_key}}'
|
||||||
|
enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
|
||||||
|
enable_chain_of_density: '${{parent.inputs.enable_chain_of_density}}'
|
||||||
|
max_len_summary: '${{parent.inputs.max_len_summary}}'
|
||||||
|
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
|
||||||
|
teacher_model_max_new_tokens: '${{parent.inputs.teacher_model_max_new_tokens}}'
|
||||||
|
teacher_model_temperature: '${{parent.inputs.teacher_model_temperature}}'
|
||||||
|
teacher_model_top_p: '${{parent.inputs.teacher_model_top_p}}'
|
||||||
|
teacher_model_frequency_penalty: '${{parent.inputs.teacher_model_frequency_penalty}}'
|
||||||
|
teacher_model_presence_penalty: '${{parent.inputs.teacher_model_presence_penalty}}'
|
||||||
|
validation_info: '${{parent.inputs.validation_info}}'
|
||||||
|
outputs:
|
||||||
|
generated_train_payload_path:
|
||||||
|
type: mltable
|
||||||
|
generated_validation_payload_path:
|
||||||
|
type: mltable
|
||||||
|
hash_train_data:
|
||||||
|
type: uri_file
|
||||||
|
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.jsonl
|
||||||
|
hash_validation_data:
|
||||||
|
type: uri_file
|
||||||
|
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.jsonl
|
||||||
|
batch_config_connection:
|
||||||
|
type: uri_file
|
||||||
|
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.jsonl
|
||||||
|
|
||||||
|
# Config generator job
|
||||||
|
oss_distillation_generate_data_config_generator:
|
||||||
|
type: command
|
||||||
|
component: azureml:batch_benchmark_config_generator:0.0.8
|
||||||
|
compute: '${{parent.inputs.compute_pipeline_validation}}'
|
||||||
|
resources:
|
||||||
|
instance_type: '${{parent.inputs.instance_type_pipeline_validation}}'
|
||||||
|
identity:
|
||||||
|
type: user_identity
|
||||||
|
inputs:
|
||||||
|
scoring_url: ${{parent.inputs.teacher_model_endpoint_url}}
|
||||||
|
deployment_name: ${{parent.inputs.teacher_model_endpoint_name}}
|
||||||
|
authentication_type: ${{parent.inputs.authentication_type}}
|
||||||
|
configuration_file: ${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.batch_config_connection}}
|
||||||
|
additional_headers: ${{parent.inputs.additional_headers}}
|
||||||
|
debug_mode: ${{parent.inputs.debug_mode}}
|
||||||
|
ensure_ascii: ${{parent.inputs.ensure_ascii}}
|
||||||
|
max_retry_time_interval: ${{parent.inputs.max_retry_time_interval}}
|
||||||
|
initial_worker_count: ${{parent.inputs.initial_worker_count}}
|
||||||
|
max_worker_count: ${{parent.inputs.max_worker_count}}
|
||||||
|
model_type: oss
|
||||||
|
outputs:
|
||||||
|
batch_score_config:
|
||||||
|
type: uri_file
|
||||||
|
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.json
|
||||||
|
|
||||||
|
# Batch score job
|
||||||
|
oss_distillation_train_data_batch_score:
|
||||||
|
type: parallel
|
||||||
|
component: azureml:batch_score_oss:0.0.1
|
||||||
|
compute: '${{parent.inputs.compute_data_generation}}'
|
||||||
|
identity:
|
||||||
|
type: user_identity
|
||||||
|
inputs:
|
||||||
|
async_mode: False
|
||||||
|
data_input_table: ${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.generated_train_payload_path}}
|
||||||
|
configuration_file: ${{parent.jobs.oss_distillation_generate_data_config_generator.outputs.batch_score_config}}
|
||||||
|
outputs:
|
||||||
|
job_output_path:
|
||||||
|
type: uri_file
|
||||||
|
mini_batch_results_output_directory:
|
||||||
|
type: uri_folder
|
||||||
|
resources:
|
||||||
|
instance_count: ${{parent.inputs.instance_count}}
|
||||||
|
max_concurrency_per_instance: ${{parent.inputs.max_concurrency_per_instance}}
|
||||||
|
mini_batch_size: ${{parent.inputs.mini_batch_size}}
|
||||||
|
retry_settings:
|
||||||
|
timeout: 6000
|
||||||
|
max_retries: 10
|
||||||
|
environment_variables:
|
||||||
|
BATCH_SCORE_INITIAL_REQUEST_TIMEOUT: '180'
|
||||||
|
BATCH_SCORE_DELAY_AFTER_SUCCESSFUL_REQUEST: 'False'
|
||||||
|
BATCH_SCORE_MAX_REQUEST_TIMEOUT: '300'
|
||||||
|
|
||||||
|
validation_file_path_exists:
|
||||||
|
type: command
|
||||||
|
component: azureml:oss_distillation_data_generation_validation_file_checker:0.0.1
|
||||||
|
compute: '${{parent.inputs.compute_pipeline_validation}}'
|
||||||
|
resources:
|
||||||
|
instance_type: '${{parent.inputs.instance_type_pipeline_validation}}'
|
||||||
|
identity:
|
||||||
|
type: user_identity
|
||||||
|
inputs:
|
||||||
|
validation_file_path: '${{parent.inputs.validation_file_path}}'
|
||||||
|
|
||||||
|
validation_succeeded:
|
||||||
|
type: if_else
|
||||||
|
condition: ${{parent.jobs.validation_file_path_exists.outputs.output}}
|
||||||
|
true_block: ${{parent.jobs.oss_distillation_validation_data_batch_score}}
|
||||||
|
|
||||||
|
# Batch score job
|
||||||
|
oss_distillation_validation_data_batch_score:
|
||||||
|
type: parallel
|
||||||
|
component: azureml:batch_score_oss:0.0.1
|
||||||
|
compute: '${{parent.inputs.compute_data_generation}}'
|
||||||
|
identity:
|
||||||
|
type: user_identity
|
||||||
|
inputs:
|
||||||
|
async_mode: False
|
||||||
|
data_input_table: ${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.generated_validation_payload_path}}
|
||||||
|
configuration_file: ${{parent.jobs.oss_distillation_generate_data_config_generator.outputs.batch_score_config}}
|
||||||
|
outputs:
|
||||||
|
job_output_path:
|
||||||
|
type: uri_file
|
||||||
|
mini_batch_results_output_directory:
|
||||||
|
type: uri_folder
|
||||||
|
resources:
|
||||||
|
instance_count: ${{parent.inputs.instance_count}}
|
||||||
|
max_concurrency_per_instance: ${{parent.inputs.max_concurrency_per_instance}}
|
||||||
|
mini_batch_size: ${{parent.inputs.mini_batch_size}}
|
||||||
|
retry_settings:
|
||||||
|
timeout: 6000
|
||||||
|
max_retries: 10
|
||||||
|
environment_variables:
|
||||||
|
BATCH_SCORE_INITIAL_REQUEST_TIMEOUT: '180'
|
||||||
|
BATCH_SCORE_DELAY_AFTER_SUCCESSFUL_REQUEST: 'False'
|
||||||
|
BATCH_SCORE_MAX_REQUEST_TIMEOUT: '300'
|
||||||
|
|
||||||
|
oss_distillation_generate_data_batch_postprocess:
|
||||||
|
type: command
|
||||||
|
component: azureml:oss_distillation_generate_data_batch_postprocess:0.0.1
|
||||||
|
compute: '${{parent.inputs.compute_data_generation}}'
|
||||||
|
resources:
|
||||||
|
instance_type: '${{parent.inputs.instance_type_data_generation}}'
|
||||||
|
identity:
|
||||||
|
type: user_identity
|
||||||
|
inputs:
|
||||||
|
train_file_path: '${{parent.inputs.train_file_path}}'
|
||||||
|
validation_file_path: '${{parent.inputs.validation_file_path}}'
|
||||||
|
batch_score_train_result: '${{parent.jobs.oss_distillation_train_data_batch_score.outputs.mini_batch_results_output_directory}}'
|
||||||
|
batch_score_validation_result: '${{parent.jobs.oss_distillation_validation_data_batch_score.outputs.mini_batch_results_output_directory}}'
|
||||||
|
hash_train_data: '${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.hash_train_data}}'
|
||||||
|
hash_validation_data: '${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.hash_validation_data}}'
|
||||||
|
enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
|
||||||
|
enable_chain_of_density: '${{parent.inputs.enable_chain_of_density}}'
|
||||||
|
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
|
||||||
|
min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
|
||||||
|
connection_config_file: ${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.batch_config_connection}}
|
||||||
|
outputs:
|
||||||
|
generated_batch_train_file_path: '${{parent.outputs.generated_batch_train_file_path}}'
|
||||||
|
generated_batch_validation_file_path: '${{parent.outputs.generated_batch_validation_file_path}}'
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
# OSS Distillation Generate Data Batch Scoring Preprocess
|
||||||
|
|
||||||
|
## Description
|
||||||
|
This component prepares data to invoke the teacher model endpoint in batch. It supports various data formats such as `jsonl`, `json`, `csv`, `tsv`, and `parquet`.
|
||||||
|
|
||||||
|
## Environment
|
||||||
|
The component uses the following environment:
|
||||||
|
- `azureml://registries/azureml/environments/acft-hf-nlp-gpu/labels/latest`
|
||||||
|
|
||||||
|
## Inputs
|
||||||
|
The component accepts the following inputs:
|
||||||
|
|
||||||
|
| Name | Type | Description | Required | Default |
|
||||||
|
|-------------------------------|-----------|-------------------------------------------------------------------------------------------------|----------|---------|
|
||||||
|
| `train_file_path` | uri_file | Path to the registered training data asset. Supported formats: `jsonl`, `json`, `csv`, `tsv`, `parquet`. | Yes | |
|
||||||
|
| `validation_file_path` | uri_file | Path to the registered validation data asset. Supported formats: `jsonl`, `json`, `csv`, `tsv`, `parquet`. | No | |
|
||||||
|
| `teacher_model_endpoint_url` | string | The URL of the teacher model endpoint. | Yes | |
|
||||||
|
| `teacher_model_asset_id` | string | The asset ID of the teacher model. | Yes | |
|
||||||
|
| `teacher_model_max_new_tokens`| integer | Teacher model max_new_tokens inference parameter. | Yes | 128 |
|
||||||
|
| `teacher_model_temperature` | number | Teacher model temperature inference parameter. | Yes | 0.2 |
|
||||||
|
| `teacher_model_top_p` | number | Teacher model top_p inference parameter. | Yes | 0.1 |
|
||||||
|
| `teacher_model_frequency_penalty` | number | Teacher model frequency penalty inference parameter. | Yes | 0.0 |
|
||||||
|
| `teacher_model_presence_penalty` | number | Teacher model presence penalty inference parameter. | Yes | 0.0 |
|
||||||
|
| `teacher_model_stop` | string | Teacher model stop inference parameter. | No | |
|
||||||
|
| `enable_chain_of_thought` | string | Enable Chain of thought for data generation. | No | "false" |
|
||||||
|
| `enable_chain_of_density` | string | Enable Chain of density for text summarization. | No | "false" |
|
||||||
|
| `max_len_summary` | integer | Maximum Length Summary for text summarization. | No | 80 |
|
||||||
|
| `data_generation_task_type` | string | Specifies the type of data generation task. Supported values: `NLI`, `CONVERSATION`, `NLU_QA`, `MATH`, `SUMMARIZATION`. | Yes | |
|
||||||
|
| `validation_output` | uri_file | Validation status. | Yes | |
|
||||||
|
|
||||||
|
## Outputs
|
||||||
|
The component produces the following outputs:
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
|----------------------------------|-----------|---------------------------------------------------------------|
|
||||||
|
| `generated_train_payload_path` | mltable | Directory containing the payload to be sent to the model. |
|
||||||
|
| `generated_validation_payload_path` | mltable | Directory containing the payload to be sent to the model. |
|
||||||
|
| `hash_train_data` | uri_file | JSONL file containing the hash for each payload. |
|
||||||
|
| `hash_validation_data` | uri_file | JSONL file containing the hash for each payload. |
|
|
@ -0,0 +1,3 @@
|
||||||
|
type: component
|
||||||
|
spec: spec.yaml
|
||||||
|
categories: ["Foundational Models", "Finetune", "Distillation","Batch Scoring Postprocess"]
|
|
@ -0,0 +1,110 @@
|
||||||
|
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
|
||||||
|
name: oss_distillation_generate_data_batch_postprocess
|
||||||
|
version: 0.0.1
|
||||||
|
type: command
|
||||||
|
|
||||||
|
is_deterministic: False
|
||||||
|
|
||||||
|
display_name: OSS Distillation Generate Data Postprocess Batch Scoring
|
||||||
|
description: Component to prepare data returned from teacher model enpoint in batch
|
||||||
|
|
||||||
|
environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/76
|
||||||
|
|
||||||
|
inputs:
|
||||||
|
# Inputs
|
||||||
|
train_file_path:
|
||||||
|
type: uri_file
|
||||||
|
description: Path to the registered training data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
|
||||||
|
mode: rw_mount
|
||||||
|
|
||||||
|
validation_file_path:
|
||||||
|
type: uri_file
|
||||||
|
optional: true
|
||||||
|
description: Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
|
||||||
|
mode: rw_mount
|
||||||
|
|
||||||
|
hash_train_data:
|
||||||
|
type: uri_file
|
||||||
|
optional: false
|
||||||
|
description: jsonl file containing the hash for each payload.
|
||||||
|
|
||||||
|
hash_validation_data:
|
||||||
|
type: uri_file
|
||||||
|
optional: true
|
||||||
|
description: jsonl file containing the hash for each payload.
|
||||||
|
|
||||||
|
batch_score_train_result:
|
||||||
|
type: uri_folder
|
||||||
|
description: Path to the directory containing jsonl file(s) that have the result for each payload.
|
||||||
|
|
||||||
|
batch_score_validation_result:
|
||||||
|
type: uri_folder
|
||||||
|
optional: true
|
||||||
|
description: Path to the directory containing jsonl file(s) that have the result for each payload.
|
||||||
|
|
||||||
|
min_endpoint_success_ratio:
|
||||||
|
type: number
|
||||||
|
default: 0.7
|
||||||
|
description: >
|
||||||
|
The minimum value of (successful_requests / total_requests) required for classifying inference as successful.
|
||||||
|
If (successful_requests / total_requests) < min_endpoint_success_ratio, the experiment will be marked as failed.
|
||||||
|
By default it is 0.7 (0 means all requests are allowed to fail while 1 means no request should fail.)
|
||||||
|
|
||||||
|
enable_chain_of_thought:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: "false"
|
||||||
|
description: Enable Chain of thought for data generation
|
||||||
|
|
||||||
|
enable_chain_of_density:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: "false"
|
||||||
|
description: Enable Chain of density for text summarization
|
||||||
|
|
||||||
|
data_generation_task_type:
|
||||||
|
type: string
|
||||||
|
enum:
|
||||||
|
- NLI
|
||||||
|
- CONVERSATION
|
||||||
|
- NLU_QA
|
||||||
|
- MATH
|
||||||
|
- SUMMARIZATION
|
||||||
|
description: >
|
||||||
|
Data generation task type. Supported values are:
|
||||||
|
1. NLI: Generate Natural Language Inference data
|
||||||
|
2. CONVERSATION: Generate conversational data (multi/single turn)
|
||||||
|
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
|
||||||
|
4. MATH: Generate Math data for numerical responses
|
||||||
|
5. SUMMARIZATION: Generate Key Summary for an Article
|
||||||
|
|
||||||
|
connection_config_file:
|
||||||
|
type: uri_file
|
||||||
|
description: Connection config file for batch scoring
|
||||||
|
|
||||||
|
outputs:
|
||||||
|
generated_batch_train_file_path:
|
||||||
|
type: uri_file
|
||||||
|
description: Generated train data
|
||||||
|
mode: rw_mount
|
||||||
|
generated_batch_validation_file_path:
|
||||||
|
type: uri_file
|
||||||
|
description: Generated validation data
|
||||||
|
mode: rw_mount
|
||||||
|
|
||||||
|
code: ../../src
|
||||||
|
command: >-
|
||||||
|
python generate_data_postprocess.py
|
||||||
|
--train_file_path ${{inputs.train_file_path}}
|
||||||
|
$[[--validation_file_path ${{inputs.validation_file_path}}]]
|
||||||
|
--hash_train_data ${{inputs.hash_train_data}}
|
||||||
|
$[[--hash_validation_data ${{inputs.hash_validation_data}}]]
|
||||||
|
--batch_score_train_result ${{inputs.batch_score_train_result}}
|
||||||
|
$[[--batch_score_validation_result ${{inputs.batch_score_validation_result}}]]
|
||||||
|
--min_endpoint_success_ratio ${{inputs.min_endpoint_success_ratio}}
|
||||||
|
$[[--enable_chain_of_thought ${{inputs.enable_chain_of_thought}}]]
|
||||||
|
$[[--enable_chain_of_density ${{inputs.enable_chain_of_density}}]]
|
||||||
|
--data_generation_task_type ${{inputs.data_generation_task_type}}
|
||||||
|
--connection_config_file ${{inputs.connection_config_file}}
|
||||||
|
--generated_batch_train_file_path ${{outputs.generated_batch_train_file_path}}
|
||||||
|
--generated_batch_validation_file_path ${{outputs.generated_batch_validation_file_path}}
|
|
@ -0,0 +1,39 @@
|
||||||
|
# OSS Distillation Generate Data Batch Scoring Preprocess
|
||||||
|
|
||||||
|
## Description
|
||||||
|
This component prepares data to invoke the teacher model endpoint in batch. It supports various data formats such as `jsonl`, `json`, `csv`, `tsv`, and `parquet`.
|
||||||
|
|
||||||
|
## Environment
|
||||||
|
The component uses the following environment:
|
||||||
|
- `azureml://registries/azureml/environments/acft-hf-nlp-gpu/labels/latest`
|
||||||
|
|
||||||
|
## Inputs
|
||||||
|
The component accepts the following inputs:
|
||||||
|
|
||||||
|
| Name | Type | Description | Required | Default |
|
||||||
|
|-------------------------------|-----------|-------------------------------------------------------------------------------------------------|----------|---------|
|
||||||
|
| `train_file_path` | uri_file | Path to the registered training data asset. Supported formats: `jsonl`, `json`, `csv`, `tsv`, `parquet`. | Yes | |
|
||||||
|
| `validation_file_path` | uri_file | Path to the registered validation data asset. Supported formats: `jsonl`, `json`, `csv`, `tsv`, `parquet`. | No | |
|
||||||
|
| `teacher_model_endpoint_url` | string | The URL of the teacher model endpoint. | Yes | |
|
||||||
|
| `teacher_model_asset_id` | string | The asset ID of the teacher model. | Yes | |
|
||||||
|
| `teacher_model_max_new_tokens`| integer | Teacher model max_new_tokens inference parameter. | Yes | 128 |
|
||||||
|
| `teacher_model_temperature` | number | Teacher model temperature inference parameter. | Yes | 0.2 |
|
||||||
|
| `teacher_model_top_p` | number | Teacher model top_p inference parameter. | Yes | 0.1 |
|
||||||
|
| `teacher_model_frequency_penalty` | number | Teacher model frequency penalty inference parameter. | Yes | 0.0 |
|
||||||
|
| `teacher_model_presence_penalty` | number | Teacher model presence penalty inference parameter. | Yes | 0.0 |
|
||||||
|
| `teacher_model_stop` | string | Teacher model stop inference parameter. | No | |
|
||||||
|
| `enable_chain_of_thought` | string | Enable Chain of thought for data generation. | No | "false" |
|
||||||
|
| `enable_chain_of_density` | string | Enable Chain of density for text summarization. | No | "false" |
|
||||||
|
| `max_len_summary` | integer | Maximum Length Summary for text summarization. | No | 80 |
|
||||||
|
| `data_generation_task_type` | string | Specifies the type of data generation task. Supported values: `NLI`, `CONVERSATION`, `NLU_QA`, `MATH`, `SUMMARIZATION`. | Yes | |
|
||||||
|
| `validation_output` | uri_file | Validation status. | Yes | |
|
||||||
|
|
||||||
|
## Outputs
|
||||||
|
The component produces the following outputs:
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
|----------------------------------|-----------|---------------------------------------------------------------|
|
||||||
|
| `generated_train_payload_path` | mltable | Directory containing the payload to be sent to the model. |
|
||||||
|
| `generated_validation_payload_path` | mltable | Directory containing the payload to be sent to the model. |
|
||||||
|
| `hash_train_data` | uri_file | JSONL file containing the hash for each payload. |
|
||||||
|
| `hash_validation_data` | uri_file | JSONL file containing the hash for each payload. |
|
|
@ -0,0 +1,3 @@
|
||||||
|
type: component
|
||||||
|
spec: spec.yaml
|
||||||
|
categories: ["Foundational Models", "Finetune", "Distillation","Batch Scoring Preprocess"]
|
|
@ -0,0 +1,152 @@
|
||||||
|
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
|
||||||
|
name: oss_distillation_generate_data_batch_preprocess
|
||||||
|
version: 0.0.1
|
||||||
|
type: command
|
||||||
|
|
||||||
|
is_deterministic: False
|
||||||
|
|
||||||
|
display_name: OSS Distillation Generate Data Batch Scoring Preprocess
|
||||||
|
description: Component to prepare data to invoke teacher model enpoint in batch
|
||||||
|
|
||||||
|
environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/76
|
||||||
|
|
||||||
|
inputs:
|
||||||
|
# Inputs
|
||||||
|
train_file_path:
|
||||||
|
type: uri_file
|
||||||
|
description: Path to the registered training data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
|
||||||
|
mode: rw_mount
|
||||||
|
|
||||||
|
validation_file_path:
|
||||||
|
type: uri_file
|
||||||
|
optional: true
|
||||||
|
description: Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
|
||||||
|
mode: rw_mount
|
||||||
|
|
||||||
|
teacher_model_endpoint_name:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
description: Teacher model endpoint name
|
||||||
|
|
||||||
|
teacher_model_endpoint_url:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
description: Teacher model endpoint url
|
||||||
|
|
||||||
|
teacher_model_endpoint_key:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
description: Teacher model endpoint key
|
||||||
|
|
||||||
|
teacher_model_max_new_tokens:
|
||||||
|
type: integer
|
||||||
|
default: 128
|
||||||
|
description: Teacher model max_new_tokens inference parameter
|
||||||
|
|
||||||
|
teacher_model_temperature:
|
||||||
|
type: number
|
||||||
|
default: 0.2
|
||||||
|
description: Teacher model temperature inference parameter
|
||||||
|
|
||||||
|
teacher_model_top_p:
|
||||||
|
type: number
|
||||||
|
default: 0.1
|
||||||
|
description: Teacher model top_p inference parameter
|
||||||
|
|
||||||
|
teacher_model_frequency_penalty:
|
||||||
|
type: number
|
||||||
|
default: 0.0
|
||||||
|
description: Teacher model frequency penalty inference parameter
|
||||||
|
|
||||||
|
teacher_model_presence_penalty:
|
||||||
|
type: number
|
||||||
|
default: 0.0
|
||||||
|
description: Teacher model presence penalty inference parameter
|
||||||
|
|
||||||
|
teacher_model_stop:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
description: Teacher model stop inference parameter
|
||||||
|
|
||||||
|
enable_chain_of_thought:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: "false"
|
||||||
|
description: Enable Chain of thought for data generation
|
||||||
|
|
||||||
|
enable_chain_of_density:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: "false"
|
||||||
|
description: Enable Chain of density for text summarization
|
||||||
|
|
||||||
|
max_len_summary:
|
||||||
|
type: integer
|
||||||
|
optional: true
|
||||||
|
default: 80
|
||||||
|
description: Maximum Length Summary for text summarization
|
||||||
|
|
||||||
|
data_generation_task_type:
|
||||||
|
type: string
|
||||||
|
enum:
|
||||||
|
- NLI
|
||||||
|
- CONVERSATION
|
||||||
|
- NLU_QA
|
||||||
|
- MATH
|
||||||
|
- SUMMARIZATION
|
||||||
|
description: >
|
||||||
|
Data generation task type. Supported values are:
|
||||||
|
1. NLI: Generate Natural Language Inference data
|
||||||
|
2. CONVERSATION: Generate conversational data (multi/single turn)
|
||||||
|
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
|
||||||
|
4. MATH: Generate Math data for numerical responses
|
||||||
|
5. SUMMARIZATION: Generate Key Summary for an Article
|
||||||
|
|
||||||
|
|
||||||
|
# Output of validation component.
|
||||||
|
validation_info:
|
||||||
|
type: uri_file
|
||||||
|
optional: true
|
||||||
|
description: Validation status.
|
||||||
|
mode: rw_mount
|
||||||
|
|
||||||
|
outputs:
|
||||||
|
generated_train_payload_path:
|
||||||
|
type: mltable
|
||||||
|
description: directory containing the payload to be sent to the model.
|
||||||
|
generated_validation_payload_path:
|
||||||
|
type: mltable
|
||||||
|
description: directory containing the payload to be sent to the model.
|
||||||
|
hash_train_data:
|
||||||
|
type: uri_file
|
||||||
|
description: jsonl file containing the hash for each payload.
|
||||||
|
hash_validation_data:
|
||||||
|
type: uri_file
|
||||||
|
description: jsonl file containing the hash for each payload.
|
||||||
|
batch_config_connection:
|
||||||
|
type: uri_file
|
||||||
|
description: Config file path that contains deployment configurations
|
||||||
|
|
||||||
|
code: ../../src
|
||||||
|
command: >-
|
||||||
|
python generate_data_preprocess.py
|
||||||
|
--train_file_path ${{inputs.train_file_path}}
|
||||||
|
$[[--validation_file_path ${{inputs.validation_file_path}}]]
|
||||||
|
$[[--teacher_model_endpoint_name ${{inputs.teacher_model_endpoint_name}}]]
|
||||||
|
$[[--teacher_model_endpoint_url ${{inputs.teacher_model_endpoint_url}}]]
|
||||||
|
$[[--teacher_model_endpoint_key ${{inputs.teacher_model_endpoint_key}}]]
|
||||||
|
--teacher_model_max_new_tokens ${{inputs.teacher_model_max_new_tokens}}
|
||||||
|
--teacher_model_temperature ${{inputs.teacher_model_temperature}}
|
||||||
|
--teacher_model_top_p ${{inputs.teacher_model_top_p}}
|
||||||
|
--teacher_model_frequency_penalty ${{inputs.teacher_model_frequency_penalty}}
|
||||||
|
--teacher_model_presence_penalty ${{inputs.teacher_model_presence_penalty}}
|
||||||
|
$[[--teacher_model_stop ${{inputs.teacher_model_stop}}]]
|
||||||
|
$[[--enable_chain_of_thought ${{inputs.enable_chain_of_thought}}]]
|
||||||
|
$[[--enable_chain_of_density ${{inputs.enable_chain_of_density}}]]
|
||||||
|
$[[--max_len_summary ${{inputs.max_len_summary}}]]
|
||||||
|
--data_generation_task_type ${{inputs.data_generation_task_type}}
|
||||||
|
--generated_train_payload_path ${{outputs.generated_train_payload_path}}
|
||||||
|
--generated_validation_payload_path ${{outputs.generated_validation_payload_path}}
|
||||||
|
--hash_train_data ${{outputs.hash_train_data}}
|
||||||
|
--hash_validation_data ${{outputs.hash_validation_data}}
|
||||||
|
--batch_config_connection ${{outputs.batch_config_connection}}
|
|
@ -0,0 +1,23 @@
|
||||||
|
# OSS Distillation Generate Data Path Selection
|
||||||
|
|
||||||
|
## Description
|
||||||
|
This component selects the path for data generation based on the task type. It supports various data generation tasks such as Natural Language Inference (NLI), Conversation, Natural Language Understanding for Question Answering (NLU_QA), Math, and Summarization.
|
||||||
|
|
||||||
|
## Environment
|
||||||
|
The component uses the following environment:
|
||||||
|
- `azureml://registries/azureml/environments/model-evaluation/labels/latest`
|
||||||
|
|
||||||
|
## Inputs
|
||||||
|
The component accepts the following inputs:
|
||||||
|
|
||||||
|
- `data_generation_task_type` (string): Specifies the type of data generation task. Supported values are:
|
||||||
|
- `NLI`: Generate Natural Language Inference data
|
||||||
|
- `CONVERSATION`: Generate conversational data (multi/single turn)
|
||||||
|
- `NLU_QA`: Generate Natural Language Understanding data for Question Answering data
|
||||||
|
- `MATH`: Generate Math data for numerical responses
|
||||||
|
- `SUMMARIZATION`: Generate Key Summary for an Article
|
||||||
|
|
||||||
|
## Outputs
|
||||||
|
The component produces the following output:
|
||||||
|
|
||||||
|
- `output` (boolean): A control output indicating the success or failure of the data path selection.
|
|
@ -0,0 +1,3 @@
|
||||||
|
type: component
|
||||||
|
spec: spec.yaml
|
||||||
|
categories: ["Foundational Models", "Finetune", "Distillation","Batch or Sequence Scoring Selector"]
|
|
@ -0,0 +1,41 @@
|
||||||
|
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
|
||||||
|
name: oss_distillation_data_generation_batch_scoring_selector
|
||||||
|
version: 0.0.1
|
||||||
|
type: command
|
||||||
|
|
||||||
|
is_deterministic: True
|
||||||
|
|
||||||
|
display_name: OSS Distillation Batch Scoring Selector Component
|
||||||
|
description: Component to select the Batch Scoring Selector based on the task type
|
||||||
|
|
||||||
|
environment: azureml://registries/azureml/environments/model-evaluation/labels/latest
|
||||||
|
|
||||||
|
inputs:
|
||||||
|
# Inputs
|
||||||
|
|
||||||
|
data_generation_task_type:
|
||||||
|
type: string
|
||||||
|
enum:
|
||||||
|
- NLI
|
||||||
|
- CONVERSATION
|
||||||
|
- NLU_QA
|
||||||
|
- MATH
|
||||||
|
- SUMMARIZATION
|
||||||
|
description: >
|
||||||
|
Data generation task type. Supported values are:
|
||||||
|
1. NLI: Generate Natural Language Inference data
|
||||||
|
2. CONVERSATION: Generate conversational data (multi/single turn)
|
||||||
|
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
|
||||||
|
4. MATH: Generate Math data for numerical responses
|
||||||
|
5. SUMMARIZATION: Generate Key Summary for an Article
|
||||||
|
outputs:
|
||||||
|
output:
|
||||||
|
type: boolean
|
||||||
|
is_control: true
|
||||||
|
|
||||||
|
|
||||||
|
code: ../../src
|
||||||
|
command: >-
|
||||||
|
mldesigner execute --source generate_data_batch_scoring_selection.py --name validate
|
||||||
|
--inputs data_generation_task_type=${{inputs.data_generation_task_type}}
|
||||||
|
--outputs output='${{outputs.output}}'
|
|
@ -0,0 +1,23 @@
|
||||||
|
# OSS Distillation Generate Data Path Selection
|
||||||
|
|
||||||
|
## Description
|
||||||
|
This component selects the path for data generation based on the task type. It supports various data generation tasks such as Natural Language Inference (NLI), Conversation, Natural Language Understanding for Question Answering (NLU_QA), Math, and Summarization.
|
||||||
|
|
||||||
|
## Environment
|
||||||
|
The component uses the following environment:
|
||||||
|
- `azureml://registries/azureml/environments/model-evaluation/labels/latest`
|
||||||
|
|
||||||
|
## Inputs
|
||||||
|
The component accepts the following inputs:
|
||||||
|
|
||||||
|
- `data_generation_task_type` (string): Specifies the type of data generation task. Supported values are:
|
||||||
|
- `NLI`: Generate Natural Language Inference data
|
||||||
|
- `CONVERSATION`: Generate conversational data (multi/single turn)
|
||||||
|
- `NLU_QA`: Generate Natural Language Understanding data for Question Answering data
|
||||||
|
- `MATH`: Generate Math data for numerical responses
|
||||||
|
- `SUMMARIZATION`: Generate Key Summary for an Article
|
||||||
|
|
||||||
|
## Outputs
|
||||||
|
The component produces the following output:
|
||||||
|
|
||||||
|
- `output` (boolean): A control output indicating the success or failure of the data path selection.
|
|
@ -0,0 +1,3 @@
|
||||||
|
type: component
|
||||||
|
spec: spec.yaml
|
||||||
|
categories: ["Foundational Models", "Finetune", "Distillation"]
|
|
@ -0,0 +1,54 @@
|
||||||
|
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
|
||||||
|
name: oss_distillation_data_generation_file_selector
|
||||||
|
version: 0.0.1
|
||||||
|
type: command
|
||||||
|
|
||||||
|
is_deterministic: True
|
||||||
|
tags:
|
||||||
|
codegenBy: dsl.condition_output
|
||||||
|
|
||||||
|
display_name: OSS Distillation Fine-Tuning Input File Selector
|
||||||
|
description: Component to select the Batch Scoring Selector based on the task type
|
||||||
|
|
||||||
|
environment: azureml://registries/azureml/environments/model-evaluation/labels/latest
|
||||||
|
|
||||||
|
inputs:
|
||||||
|
generated_batch_train_file_path:
|
||||||
|
type: uri_folder
|
||||||
|
optional: true
|
||||||
|
mode: ro_mount
|
||||||
|
generated_batch_validation_file_path:
|
||||||
|
type: uri_folder
|
||||||
|
optional: true
|
||||||
|
mode: ro_mount
|
||||||
|
generated_train_file_path:
|
||||||
|
type: uri_folder
|
||||||
|
optional: true
|
||||||
|
mode: ro_mount
|
||||||
|
generated_validation_file_path:
|
||||||
|
type: uri_folder
|
||||||
|
optional: true
|
||||||
|
mode: ro_mount
|
||||||
|
condition:
|
||||||
|
type: boolean
|
||||||
|
default: false
|
||||||
|
optional: false
|
||||||
|
outputs:
|
||||||
|
ft_input_train_file_path:
|
||||||
|
type: uri_file
|
||||||
|
mode: rw_mount
|
||||||
|
ft_input_validation_file_path:
|
||||||
|
type: uri_file
|
||||||
|
mode: rw_mount
|
||||||
|
|
||||||
|
|
||||||
|
code: ../../src
|
||||||
|
command: >-
|
||||||
|
python dsl_condition_output.py
|
||||||
|
$[[--generated_batch_train_file_path ${{inputs.generated_batch_train_file_path}}]]
|
||||||
|
$[[--generated_batch_validation_file_path ${{inputs.generated_batch_validation_file_path}}]]
|
||||||
|
$[[--generated_train_file_path ${{inputs.generated_train_file_path}}]]
|
||||||
|
$[[--generated_validation_file_path ${{inputs.generated_validation_file_path}}]]
|
||||||
|
--condition ${{inputs.condition}}
|
||||||
|
--ft_input_train_file_path ${{outputs.ft_input_train_file_path}}
|
||||||
|
--ft_input_validation_file_path ${{outputs.ft_input_validation_file_path}}
|
|
@ -0,0 +1,3 @@
|
||||||
|
type: component
|
||||||
|
spec: spec.yaml
|
||||||
|
categories: ["Foundational Models", "Finetune", "Distillation"]
|
|
@ -0,0 +1,241 @@
|
||||||
|
$schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json
|
||||||
|
name: oss_distillation_seq_scoring_pipeline
|
||||||
|
version: 0.0.1
|
||||||
|
type: pipeline
|
||||||
|
|
||||||
|
|
||||||
|
display_name: OSS Distillation Sequence Scoring Pipeline
|
||||||
|
description: Component to generate data from teacher model enpoint(sequentially) and finetune student model on generated dataset
|
||||||
|
|
||||||
|
inputs:
|
||||||
|
# Compute parameters
|
||||||
|
instance_type_pipeline_validation:
|
||||||
|
type: string
|
||||||
|
optional: True
|
||||||
|
description: Instance type to be used for validation component. The parameter compute_pipeline_validation must be set to 'serverless' for instance_type to be used.
|
||||||
|
instance_type_data_generation:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: Standard_D4as_v4
|
||||||
|
description: Instance type to be used for finetune component in case of virtual cluster compute, eg. Singularity.ND40_v2. The parameter compute_finetune must be set to 'serverless' for instance_type to be used
|
||||||
|
instance_type_data_import:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: Singularity.ND96amrs_A100_v4
|
||||||
|
description: Instance type to be used for data_import component in case of virtual cluster compute, eg. Singularity.D8_v3. The parameter compute_data_import must be set to 'serverless' for instance_type to be used
|
||||||
|
instance_type_finetune:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: Singularity.ND96amrs_A100_v4
|
||||||
|
description: Instance type to be used for finetune component in case of virtual cluster compute, eg. Singularity.ND40_v2. The parameter compute_finetune must be set to 'serverless' for instance_type to be used
|
||||||
|
|
||||||
|
compute_pipeline_validation:
|
||||||
|
type: string
|
||||||
|
optional: True
|
||||||
|
default: 'serverless'
|
||||||
|
description: compute to be used for validation component
|
||||||
|
|
||||||
|
compute_data_generation:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: 'serverless'
|
||||||
|
description: >-
|
||||||
|
compute to be used for model_import eg. provide 'FT-Cluster' if
|
||||||
|
your compute is named 'FT-Cluster'. Special characters like \ and ' are invalid in the parameter value.
|
||||||
|
If compute cluster name is provided, instance_type field will be ignored and the respective cluster will be used
|
||||||
|
compute_data_import:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: 'serverless'
|
||||||
|
description: >-
|
||||||
|
compute to be used for model_import eg. provide 'FT-Cluster' if
|
||||||
|
your compute is named 'FT-Cluster'. Special characters like \ and ' are invalid in the parameter value.
|
||||||
|
If compute cluster name is provided, instance_type field will be ignored and the respective cluster will be used
|
||||||
|
compute_finetune:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: 'serverless'
|
||||||
|
description: >-
|
||||||
|
compute to be used for finetune eg. provide 'FT-Cluster' if your
|
||||||
|
compute is named 'FT-Cluster'. Special characters like \ and ' are invalid in the parameter value.
|
||||||
|
If compute cluster name is provided, instance_type field will be ignored and the respective cluster will be used
|
||||||
|
|
||||||
|
# ########################### Data Generator Component ########################### #
|
||||||
|
|
||||||
|
train_file_path:
|
||||||
|
type: uri_file
|
||||||
|
description: Path to the registered training data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
|
||||||
|
mode: rw_mount
|
||||||
|
|
||||||
|
validation_file_path:
|
||||||
|
type: uri_file
|
||||||
|
optional: true
|
||||||
|
description: Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
|
||||||
|
mode: rw_mount
|
||||||
|
|
||||||
|
|
||||||
|
validation_info:
|
||||||
|
type: uri_file
|
||||||
|
optional: true
|
||||||
|
description: Validation status.
|
||||||
|
mode: rw_mount
|
||||||
|
|
||||||
|
teacher_model_endpoint_name:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
description: Teacher model endpoint name
|
||||||
|
|
||||||
|
teacher_model_endpoint_url:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
description: Teacher model endpoint URL
|
||||||
|
|
||||||
|
teacher_model_endpoint_key:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
description: Teacher model endpoint key
|
||||||
|
|
||||||
|
teacher_model_max_new_tokens:
|
||||||
|
type: integer
|
||||||
|
default: 128
|
||||||
|
description: Teacher model max_new_tokens inference parameter
|
||||||
|
|
||||||
|
teacher_model_temperature:
|
||||||
|
type: number
|
||||||
|
default: 0.2
|
||||||
|
description: Teacher model temperature inference parameter
|
||||||
|
|
||||||
|
teacher_model_top_p:
|
||||||
|
type: number
|
||||||
|
default: 0.1
|
||||||
|
description: Teacher model top_p inference parameter
|
||||||
|
|
||||||
|
teacher_model_frequency_penalty:
|
||||||
|
type: number
|
||||||
|
default: 0.0
|
||||||
|
description: Teacher model frequency penalty inference parameter
|
||||||
|
|
||||||
|
teacher_model_presence_penalty:
|
||||||
|
type: number
|
||||||
|
default: 0.0
|
||||||
|
description: Teacher model presence penalty inference parameter
|
||||||
|
|
||||||
|
teacher_model_stop:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
description: Teacher model stop inference parameter
|
||||||
|
|
||||||
|
request_batch_size:
|
||||||
|
type: integer
|
||||||
|
default: 10
|
||||||
|
description: No of data records to hit teacher model endpoint in one go
|
||||||
|
|
||||||
|
min_endpoint_success_ratio:
|
||||||
|
type: number
|
||||||
|
default: 0.7
|
||||||
|
description: >
|
||||||
|
The minimum value of (successful_requests / total_requests) required for classifying inference as successful.
|
||||||
|
If (successful_requests / total_requests) < min_endpoint_success_ratio, the experiment will be marked as failed.
|
||||||
|
By default it is 0.7 (0 means all requests are allowed to fail while 1 means no request should fail.)
|
||||||
|
|
||||||
|
enable_chain_of_thought:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: "false"
|
||||||
|
description: Enable Chain of thought for data generation
|
||||||
|
|
||||||
|
enable_chain_of_density:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: "false"
|
||||||
|
description: Enable Chain of density for text summarization
|
||||||
|
|
||||||
|
max_len_summary:
|
||||||
|
type: integer
|
||||||
|
optional: true
|
||||||
|
default: 80
|
||||||
|
description: Maximum Length Summary for text summarization
|
||||||
|
|
||||||
|
data_generation_task_type:
|
||||||
|
type: string
|
||||||
|
enum:
|
||||||
|
- NLI
|
||||||
|
- CONVERSATION
|
||||||
|
- NLU_QA
|
||||||
|
- MATH
|
||||||
|
- SUMMARIZATION
|
||||||
|
description: >
|
||||||
|
Data generation task type. Supported values are:
|
||||||
|
1. NLI: Generate Natural Language Inference data
|
||||||
|
2. CONVERSATION: Generate conversational data (multi/single turn)
|
||||||
|
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
|
||||||
|
4. MATH: Generate Math data for numerical responses
|
||||||
|
5. SUMMARIZATION: Generate Key Summary for an Article
|
||||||
|
|
||||||
|
|
||||||
|
# Training parameters
|
||||||
|
num_train_epochs:
|
||||||
|
type: integer
|
||||||
|
default: 1
|
||||||
|
optional: true
|
||||||
|
description: training epochs
|
||||||
|
|
||||||
|
per_device_train_batch_size:
|
||||||
|
type: integer
|
||||||
|
default: 1
|
||||||
|
optional: true
|
||||||
|
description: Train batch size
|
||||||
|
|
||||||
|
learning_rate:
|
||||||
|
type: number
|
||||||
|
default: 3e-04
|
||||||
|
optional: true
|
||||||
|
description: Start learning rate.
|
||||||
|
|
||||||
|
# Output of validation component.
|
||||||
|
validation_output:
|
||||||
|
type: uri_file
|
||||||
|
optional: true
|
||||||
|
description: Validation status.
|
||||||
|
mode: rw_mount
|
||||||
|
|
||||||
|
outputs:
|
||||||
|
generated_train_file_path:
|
||||||
|
type: uri_file
|
||||||
|
description: Generated train data
|
||||||
|
mode: rw_mount
|
||||||
|
generated_validation_file_path:
|
||||||
|
type: uri_file
|
||||||
|
description: Generated validation data
|
||||||
|
mode: rw_mount
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
oss_distillation_generate_data:
|
||||||
|
type: command
|
||||||
|
component: azureml:oss_distillation_generate_data:0.0.8
|
||||||
|
compute: '${{parent.inputs.compute_data_generation}}'
|
||||||
|
resources:
|
||||||
|
instance_type: '${{parent.inputs.instance_type_data_generation}}'
|
||||||
|
identity:
|
||||||
|
type: user_identity
|
||||||
|
inputs:
|
||||||
|
train_file_path: '${{parent.inputs.train_file_path}}'
|
||||||
|
validation_file_path: '${{parent.inputs.validation_file_path}}'
|
||||||
|
teacher_model_endpoint_name: '${{parent.inputs.teacher_model_endpoint_name}}'
|
||||||
|
teacher_model_endpoint_url: '${{parent.inputs.teacher_model_endpoint_url}}'
|
||||||
|
teacher_model_endpoint_key: '${{parent.inputs.teacher_model_endpoint_key}}'
|
||||||
|
enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
|
||||||
|
enable_chain_of_density: '${{parent.inputs.enable_chain_of_density}}'
|
||||||
|
max_len_summary: '${{parent.inputs.max_len_summary}}'
|
||||||
|
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
|
||||||
|
teacher_model_max_new_tokens: '${{parent.inputs.teacher_model_max_new_tokens}}'
|
||||||
|
teacher_model_temperature: '${{parent.inputs.teacher_model_temperature}}'
|
||||||
|
teacher_model_top_p: '${{parent.inputs.teacher_model_top_p}}'
|
||||||
|
teacher_model_frequency_penalty: '${{parent.inputs.teacher_model_frequency_penalty}}'
|
||||||
|
teacher_model_presence_penalty: '${{parent.inputs.teacher_model_presence_penalty}}'
|
||||||
|
request_batch_size: '${{parent.inputs.request_batch_size}}'
|
||||||
|
min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
|
||||||
|
validation_output: '${{parent.inputs.validation_output}}'
|
||||||
|
outputs:
|
||||||
|
generated_train_file_path: '${{parent.outputs.generated_train_file_path}}'
|
||||||
|
generated_validation_file_path: '${{parent.outputs.generated_validation_file_path}}'
|
|
@ -0,0 +1,23 @@
|
||||||
|
# OSS Distillation Data Generation Batch Scoring Selector
|
||||||
|
|
||||||
|
This component is designed to check if the validation file is present or not in the OSS Distillation pipeline. It ensures that the provided validation file path is valid and accessible.
|
||||||
|
|
||||||
|
## Description
|
||||||
|
|
||||||
|
The OSS Distillation Data Generation Batch Scoring Selector is a command component that verifies the presence of a validation file. It supports various data formats including `jsonl`, `json`, `csv`, `tsv`, and `parquet`.
|
||||||
|
|
||||||
|
## Environment
|
||||||
|
|
||||||
|
- **Environment**: `azureml://registries/azureml/environments/model-evaluation/labels/latest`
|
||||||
|
|
||||||
|
## Inputs
|
||||||
|
|
||||||
|
| Name | Type | Optional | Description |
|
||||||
|
|-----------------------|----------|----------|-------------------------------------------------------------------------------------------------------------|
|
||||||
|
| validation_file_path | uri_file | Yes | Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv`, and `parquet`. |
|
||||||
|
|
||||||
|
## Outputs
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
|--------|---------|------------------------------------|
|
||||||
|
| output | boolean | Indicates if the validation file is present or not. |
|
|
@ -0,0 +1,3 @@
|
||||||
|
type: component
|
||||||
|
spec: spec.yaml
|
||||||
|
categories: ["Foundational Models", "Finetune", "Distillation","Validation File Checker"]
|
|
@ -0,0 +1,31 @@
|
||||||
|
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
|
||||||
|
name: oss_distillation_data_generation_validation_file_checker
|
||||||
|
version: 0.0.1
|
||||||
|
type: command
|
||||||
|
|
||||||
|
is_deterministic: True
|
||||||
|
|
||||||
|
display_name: OSS Distillation Validation File Checker Component
|
||||||
|
description: Component to Check if the validation file is present or not
|
||||||
|
|
||||||
|
environment: azureml://registries/azureml/environments/model-evaluation/labels/latest
|
||||||
|
|
||||||
|
inputs:
|
||||||
|
# Inputs
|
||||||
|
validation_file_path:
|
||||||
|
type: uri_file
|
||||||
|
optional: true
|
||||||
|
description: Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
|
||||||
|
mode: rw_mount
|
||||||
|
|
||||||
|
outputs:
|
||||||
|
output:
|
||||||
|
type: boolean
|
||||||
|
is_control: true
|
||||||
|
|
||||||
|
|
||||||
|
code: ../../src
|
||||||
|
command: >-
|
||||||
|
mldesigner execute --source generate_data_validation_file_check.py --name validate
|
||||||
|
--inputs $[[validation_file_path=${{inputs.validation_file_path}}]]
|
||||||
|
--outputs output='${{outputs.output}}'
|
|
@ -1,6 +1,6 @@
|
||||||
$schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json
|
$schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json
|
||||||
name: oss_distillation_pipeline
|
name: oss_distillation_pipeline
|
||||||
version: 0.0.9
|
version: 0.0.10
|
||||||
type: pipeline
|
type: pipeline
|
||||||
|
|
||||||
|
|
||||||
|
@ -171,6 +171,58 @@ inputs:
|
||||||
4. MATH: Generate Math data for numerical responses
|
4. MATH: Generate Math data for numerical responses
|
||||||
5. SUMMARIZATION: Generate Key Summary for an Article
|
5. SUMMARIZATION: Generate Key Summary for an Article
|
||||||
|
|
||||||
|
|
||||||
|
# ########################### Batch Score Component ########################### #
|
||||||
|
authentication_type:
|
||||||
|
type: string
|
||||||
|
optional: False
|
||||||
|
description: Authentication type for endpoint. Either `azureml_workspace_connection` or `managed_identity`.
|
||||||
|
default: azureml_workspace_connection
|
||||||
|
enum:
|
||||||
|
- azureml_workspace_connection
|
||||||
|
- managed_identity
|
||||||
|
additional_headers:
|
||||||
|
type: string
|
||||||
|
optional: True
|
||||||
|
description: JSON serialized string expressing additional headers to be added to each request.
|
||||||
|
debug_mode:
|
||||||
|
type: boolean
|
||||||
|
optional: False
|
||||||
|
default: False
|
||||||
|
description: Enable debug mode to print all the debug logs in the score step.
|
||||||
|
ensure_ascii:
|
||||||
|
type: boolean
|
||||||
|
optional: False
|
||||||
|
default: False
|
||||||
|
description: If set to true, the output is guaranteed to have all incoming non-ASCII characters escaped. If set to false, these characters will be output as-is. More detailed information can be found at https://docs.python.org/3/library/json.html
|
||||||
|
max_retry_time_interval:
|
||||||
|
type: integer
|
||||||
|
optional: True
|
||||||
|
description: The maximum time (in seconds) spent retrying a payload. If unspecified, payloads are retried for unlimited time.
|
||||||
|
initial_worker_count:
|
||||||
|
type: integer
|
||||||
|
optional: False
|
||||||
|
default: 5
|
||||||
|
description: The initial number of workers to use for scoring.
|
||||||
|
max_worker_count:
|
||||||
|
type: integer
|
||||||
|
optional: False
|
||||||
|
default: 200
|
||||||
|
description: Overrides `initial_worker_count` if necessary.
|
||||||
|
instance_count:
|
||||||
|
type: integer
|
||||||
|
default: 1
|
||||||
|
description: Number of nodes in a compute cluster we will run the batch score step on.
|
||||||
|
max_concurrency_per_instance:
|
||||||
|
type: integer
|
||||||
|
default: 1
|
||||||
|
description: Number of processes that will be run concurrently on any given node. This number should not be larger than 1/2 of the number of cores in an individual node in the specified cluster.
|
||||||
|
mini_batch_size:
|
||||||
|
type: string
|
||||||
|
optional: true
|
||||||
|
default: 100KB
|
||||||
|
description: The mini batch size for parallel run.
|
||||||
|
|
||||||
# ########################### Finetuning Component ########################### #
|
# ########################### Finetuning Component ########################### #
|
||||||
|
|
||||||
number_of_gpu_to_use_finetuning:
|
number_of_gpu_to_use_finetuning:
|
||||||
|
@ -230,6 +282,12 @@ inputs:
|
||||||
optional: true
|
optional: true
|
||||||
description: Name of the registered model
|
description: Name of the registered model
|
||||||
|
|
||||||
|
validation_info:
|
||||||
|
type: uri_file
|
||||||
|
optional: true
|
||||||
|
description: Validation status.
|
||||||
|
mode: rw_mount
|
||||||
|
|
||||||
outputs:
|
outputs:
|
||||||
output_model:
|
output_model:
|
||||||
type: uri_folder
|
type: uri_folder
|
||||||
|
@ -270,31 +328,103 @@ jobs:
|
||||||
type: uri_file
|
type: uri_file
|
||||||
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.json
|
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.json
|
||||||
|
|
||||||
oss_distillation_generate_data:
|
data_generation_batch_scoring_selector:
|
||||||
type: command
|
type: command
|
||||||
component: azureml:oss_distillation_generate_data:0.0.8
|
component: azureml:oss_distillation_data_generation_batch_scoring_selector:0.0.1
|
||||||
compute: '${{parent.inputs.compute_data_generation}}'
|
compute: '${{parent.inputs.compute_pipeline_validation}}'
|
||||||
resources:
|
resources:
|
||||||
instance_type: '${{parent.inputs.instance_type_data_generation}}'
|
instance_type: '${{parent.inputs.instance_type_pipeline_validation}}'
|
||||||
identity:
|
identity:
|
||||||
type: user_identity
|
type: user_identity
|
||||||
inputs:
|
inputs:
|
||||||
|
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
|
||||||
|
|
||||||
|
validation_succeeded:
|
||||||
|
type: if_else
|
||||||
|
condition: ${{parent.jobs.data_generation_batch_scoring_selector.outputs.output}}
|
||||||
|
true_block: ${{parent.jobs.oss_distillation_batchscoring_datagen_pipeline}}
|
||||||
|
false_block: ${{parent.jobs.oss_distillation_seq_scoring_pipeline}}
|
||||||
|
|
||||||
|
oss_distillation_batchscoring_datagen_pipeline:
|
||||||
|
type: pipeline
|
||||||
|
component: azureml:oss_distillation_batchscoring_datagen_pipeline:0.0.1
|
||||||
|
inputs:
|
||||||
|
instance_type_pipeline_validation: '${{parent.inputs.instance_type_pipeline_validation}}'
|
||||||
|
instance_type_data_generation: '${{parent.inputs.instance_type_data_generation}}'
|
||||||
|
instance_type_data_import: '${{parent.inputs.instance_type_data_import}}'
|
||||||
|
instance_type_finetune: '${{parent.inputs.instance_type_finetune}}'
|
||||||
|
compute_pipeline_validation: '${{parent.inputs.compute_pipeline_validation}}'
|
||||||
|
compute_data_generation: '${{parent.inputs.compute_data_generation}}'
|
||||||
|
compute_data_import: '${{parent.inputs.compute_data_import}}'
|
||||||
|
compute_finetune: '${{parent.inputs.compute_finetune}}'
|
||||||
train_file_path: '${{parent.inputs.train_file_path}}'
|
train_file_path: '${{parent.inputs.train_file_path}}'
|
||||||
validation_file_path: '${{parent.inputs.validation_file_path}}'
|
validation_file_path: '${{parent.inputs.validation_file_path}}'
|
||||||
teacher_model_endpoint_name: '${{parent.inputs.teacher_model_endpoint_name}}'
|
|
||||||
teacher_model_endpoint_url: '${{parent.inputs.teacher_model_endpoint_url}}'
|
teacher_model_endpoint_url: '${{parent.inputs.teacher_model_endpoint_url}}'
|
||||||
|
teacher_model_endpoint_name: '${{parent.inputs.teacher_model_endpoint_name}}'
|
||||||
teacher_model_endpoint_key: '${{parent.inputs.teacher_model_endpoint_key}}'
|
teacher_model_endpoint_key: '${{parent.inputs.teacher_model_endpoint_key}}'
|
||||||
enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
|
|
||||||
enable_chain_of_density: '${{parent.inputs.enable_chain_of_density}}'
|
|
||||||
max_len_summary: '${{parent.inputs.max_len_summary}}'
|
|
||||||
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
|
|
||||||
teacher_model_max_new_tokens: '${{parent.inputs.teacher_model_max_new_tokens}}'
|
teacher_model_max_new_tokens: '${{parent.inputs.teacher_model_max_new_tokens}}'
|
||||||
teacher_model_temperature: '${{parent.inputs.teacher_model_temperature}}'
|
teacher_model_temperature: '${{parent.inputs.teacher_model_temperature}}'
|
||||||
teacher_model_top_p: '${{parent.inputs.teacher_model_top_p}}'
|
teacher_model_top_p: '${{parent.inputs.teacher_model_top_p}}'
|
||||||
teacher_model_frequency_penalty: '${{parent.inputs.teacher_model_frequency_penalty}}'
|
teacher_model_frequency_penalty: '${{parent.inputs.teacher_model_frequency_penalty}}'
|
||||||
teacher_model_presence_penalty: '${{parent.inputs.teacher_model_presence_penalty}}'
|
teacher_model_presence_penalty: '${{parent.inputs.teacher_model_presence_penalty}}'
|
||||||
|
teacher_model_stop: '${{parent.inputs.teacher_model_stop}}'
|
||||||
|
min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
|
||||||
|
enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
|
||||||
|
enable_chain_of_density: '${{parent.inputs.enable_chain_of_density}}'
|
||||||
|
max_len_summary: '${{parent.inputs.max_len_summary}}'
|
||||||
|
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
|
||||||
|
num_train_epochs: '${{parent.inputs.num_train_epochs}}'
|
||||||
|
per_device_train_batch_size: '${{parent.inputs.per_device_train_batch_size}}'
|
||||||
|
learning_rate: '${{parent.inputs.learning_rate}}'
|
||||||
|
authentication_type: '${{parent.inputs.authentication_type}}'
|
||||||
|
additional_headers: '${{parent.inputs.additional_headers}}'
|
||||||
|
debug_mode: '${{parent.inputs.debug_mode}}'
|
||||||
|
ensure_ascii: '${{parent.inputs.ensure_ascii}}'
|
||||||
|
max_retry_time_interval: '${{parent.inputs.max_retry_time_interval}}'
|
||||||
|
initial_worker_count: '${{parent.inputs.initial_worker_count}}'
|
||||||
|
max_worker_count: '${{parent.inputs.max_worker_count}}'
|
||||||
|
instance_count: '${{parent.inputs.instance_count}}'
|
||||||
|
max_concurrency_per_instance: '${{parent.inputs.max_concurrency_per_instance}}'
|
||||||
|
mini_batch_size: '${{parent.inputs.mini_batch_size}}'
|
||||||
|
validation_info: '${{parent.jobs.oss_distillation_validate_pipeline.outputs.validation_info}}'
|
||||||
|
|
||||||
|
outputs:
|
||||||
|
generated_batch_train_file_path:
|
||||||
|
type: uri_file
|
||||||
|
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.jsonl
|
||||||
|
generated_batch_validation_file_path:
|
||||||
|
type: uri_file
|
||||||
|
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.jsonl
|
||||||
|
|
||||||
|
oss_distillation_seq_scoring_pipeline:
|
||||||
|
type: pipeline
|
||||||
|
component: azureml:oss_distillation_seq_scoring_pipeline:0.0.1
|
||||||
|
inputs:
|
||||||
|
instance_type_pipeline_validation: '${{parent.inputs.instance_type_pipeline_validation}}'
|
||||||
|
instance_type_data_generation: '${{parent.inputs.instance_type_data_generation}}'
|
||||||
|
instance_type_data_import: '${{parent.inputs.instance_type_data_import}}'
|
||||||
|
instance_type_finetune: '${{parent.inputs.instance_type_finetune}}'
|
||||||
|
compute_pipeline_validation: '${{parent.inputs.compute_pipeline_validation}}'
|
||||||
|
compute_data_generation: '${{parent.inputs.compute_data_generation}}'
|
||||||
|
compute_data_import: '${{parent.inputs.compute_data_import}}'
|
||||||
|
compute_finetune: '${{parent.inputs.compute_finetune}}'
|
||||||
|
train_file_path: '${{parent.inputs.train_file_path}}'
|
||||||
|
validation_file_path: '${{parent.inputs.validation_file_path}}'
|
||||||
|
teacher_model_endpoint_name: '${{parent.inputs.teacher_model_endpoint_name}}'
|
||||||
|
teacher_model_endpoint_url: '${{parent.inputs.teacher_model_endpoint_url}}'
|
||||||
|
teacher_model_endpoint_key: '${{parent.inputs.teacher_model_endpoint_key}}'
|
||||||
|
teacher_model_max_new_tokens: '${{parent.inputs.teacher_model_max_new_tokens}}'
|
||||||
|
teacher_model_temperature: '${{parent.inputs.teacher_model_temperature}}'
|
||||||
|
teacher_model_top_p: '${{parent.inputs.teacher_model_top_p}}'
|
||||||
|
teacher_model_frequency_penalty: '${{parent.inputs.teacher_model_frequency_penalty}}'
|
||||||
|
teacher_model_presence_penalty: '${{parent.inputs.teacher_model_presence_penalty}}'
|
||||||
|
teacher_model_stop: '${{parent.inputs.teacher_model_stop}}'
|
||||||
request_batch_size: '${{parent.inputs.request_batch_size}}'
|
request_batch_size: '${{parent.inputs.request_batch_size}}'
|
||||||
min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
|
min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
|
||||||
|
enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
|
||||||
|
enable_chain_of_density: '${{parent.inputs.enable_chain_of_density}}'
|
||||||
|
max_len_summary: '${{parent.inputs.max_len_summary}}'
|
||||||
|
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
|
||||||
validation_output: '${{parent.jobs.oss_distillation_validate_pipeline.outputs.validation_info}}'
|
validation_output: '${{parent.jobs.oss_distillation_validate_pipeline.outputs.validation_info}}'
|
||||||
outputs:
|
outputs:
|
||||||
generated_train_file_path:
|
generated_train_file_path:
|
||||||
|
@ -304,6 +434,30 @@ jobs:
|
||||||
type: uri_file
|
type: uri_file
|
||||||
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.jsonl
|
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.jsonl
|
||||||
|
|
||||||
|
|
||||||
|
oss_distillation_train_data_generation_file_selector:
|
||||||
|
type: command
|
||||||
|
component: azureml:oss_distillation_data_generation_file_selector:0.0.1
|
||||||
|
compute: '${{parent.inputs.compute_pipeline_validation}}'
|
||||||
|
resources:
|
||||||
|
instance_type: '${{parent.inputs.instance_type_pipeline_validation}}'
|
||||||
|
identity:
|
||||||
|
type: user_identity
|
||||||
|
inputs:
|
||||||
|
generated_batch_train_file_path: '${{parent.jobs.oss_distillation_batchscoring_datagen_pipeline.outputs.generated_batch_train_file_path}}'
|
||||||
|
generated_batch_validation_file_path: '${{parent.jobs.oss_distillation_batchscoring_datagen_pipeline.outputs.generated_batch_validation_file_path}}'
|
||||||
|
generated_train_file_path: '${{parent.jobs.oss_distillation_seq_scoring_pipeline.outputs.generated_train_file_path}}'
|
||||||
|
generated_validation_file_path: '${{parent.jobs.oss_distillation_seq_scoring_pipeline.outputs.generated_validation_file_path}}'
|
||||||
|
condition: '${{parent.jobs.data_generation_batch_scoring_selector.outputs.output}}'
|
||||||
|
outputs:
|
||||||
|
ft_input_train_file_path:
|
||||||
|
type: uri_file
|
||||||
|
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.jsonl
|
||||||
|
ft_input_validation_file_path:
|
||||||
|
type: uri_file
|
||||||
|
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.jsonl
|
||||||
|
|
||||||
|
|
||||||
oss_text_generation_data_import:
|
oss_text_generation_data_import:
|
||||||
type: command
|
type: command
|
||||||
component: azureml:oss_text_generation_data_import:0.0.24
|
component: azureml:oss_text_generation_data_import:0.0.24
|
||||||
|
@ -318,8 +472,8 @@ jobs:
|
||||||
environment_variables:
|
environment_variables:
|
||||||
_AZUREML_CR_ENABLE_ITP_CAP: "false"
|
_AZUREML_CR_ENABLE_ITP_CAP: "false"
|
||||||
inputs:
|
inputs:
|
||||||
train_file_path: '${{parent.jobs.oss_distillation_generate_data.outputs.generated_train_file_path}}'
|
train_file_path: '${{parent.jobs.oss_distillation_train_data_generation_file_selector.outputs.ft_input_train_file_path}}'
|
||||||
validation_file_path: '${{parent.jobs.oss_distillation_generate_data.outputs.generated_validation_file_path}}'
|
validation_file_path: '${{parent.jobs.oss_distillation_train_data_generation_file_selector.outputs.ft_input_validation_file_path}}'
|
||||||
system_properties: '${{parent.inputs.system_properties}}'
|
system_properties: '${{parent.inputs.system_properties}}'
|
||||||
|
|
||||||
oss_chat_completion_finetune:
|
oss_chat_completion_finetune:
|
||||||
|
|
|
@ -74,6 +74,9 @@ DEFAULT_TEMPERATURE = 0.2
|
||||||
# TEXT SUMMARIZATION DEFAULT OUTPUT WORD COUNT
|
# TEXT SUMMARIZATION DEFAULT OUTPUT WORD COUNT
|
||||||
DEFAULT_MAX_LEN_SUMMARY = 80
|
DEFAULT_MAX_LEN_SUMMARY = 80
|
||||||
|
|
||||||
|
STATUS_SUCCESS = "SUCCESS"
|
||||||
|
FINISH_REASON_STOP = "stop"
|
||||||
|
|
||||||
|
|
||||||
class InferenceMode:
|
class InferenceMode:
|
||||||
"""Supported inference modes."""
|
"""Supported inference modes."""
|
||||||
|
@ -112,6 +115,10 @@ class TelemetryConstants:
|
||||||
INVOKE_MODEL_ENDPOINT = "invoke_model_endpoint"
|
INVOKE_MODEL_ENDPOINT = "invoke_model_endpoint"
|
||||||
BATCH_PROCESS_TRAINING_DATA = "batch_process_training_data"
|
BATCH_PROCESS_TRAINING_DATA = "batch_process_training_data"
|
||||||
BATCH_PROCESS_VALIDATION_DATA = "batch_process_validation_data"
|
BATCH_PROCESS_VALIDATION_DATA = "batch_process_validation_data"
|
||||||
|
PRE_PROCESS_TRAINING_DATA = "pre_process_training_data"
|
||||||
|
PRE_PROCESS_VALIDATION_DATA = "pre_process_validation_data"
|
||||||
|
POST_PROCESS_TRAINING_DATA = "post_process_training_data"
|
||||||
|
POST_PROCESS_VALIDATION_DATA = "post_process_validation_data"
|
||||||
PROCESS_DATASET_RECORD = "process_dataset_record"
|
PROCESS_DATASET_RECORD = "process_dataset_record"
|
||||||
|
|
||||||
VALIDATOR = "validator"
|
VALIDATOR = "validator"
|
||||||
|
@ -123,6 +130,28 @@ class TelemetryConstants:
|
||||||
VALIDATE_TRAINING_DATA = "validate_training_data"
|
VALIDATE_TRAINING_DATA = "validate_training_data"
|
||||||
VALIDATE_VALIDATION_DATA = "validate_validation_data"
|
VALIDATE_VALIDATION_DATA = "validate_validation_data"
|
||||||
VALIDATE_MODEL_INFERENCE = "validate_model_inference"
|
VALIDATE_MODEL_INFERENCE = "validate_model_inference"
|
||||||
|
VERSION_SELECTION = "version_selection"
|
||||||
|
|
||||||
|
|
||||||
|
class PayloadField:
|
||||||
|
"""Payload fields."""
|
||||||
|
|
||||||
|
# Payload fields that will be sent to the model.
|
||||||
|
MESSAGES = "messages"
|
||||||
|
ROLE = "role"
|
||||||
|
CONTENT = "content"
|
||||||
|
SYSTEM = "system"
|
||||||
|
USER = "user"
|
||||||
|
|
||||||
|
# Payload fields that will be received from the model.
|
||||||
|
RESPONSE = "response"
|
||||||
|
REQUEST = "request"
|
||||||
|
|
||||||
|
|
||||||
|
class HashField:
|
||||||
|
"""Hash fields."""
|
||||||
|
|
||||||
|
HASH = "hash"
|
||||||
|
|
||||||
|
|
||||||
class BackoffConstants:
|
class BackoffConstants:
|
||||||
|
@ -160,13 +189,16 @@ class SystemPrompt:
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_cot_prompt(cls):
|
def default_cot_prompt(cls):
|
||||||
"""Get the default chain of thought prompt."""
|
"""Get the default chain of thought prompt."""
|
||||||
return cls.DEFAULT_COT_SYSTEM_PROMPT.format(keys=cls.DEFAULT_KEYS, additional_instructions="")
|
return cls.DEFAULT_COT_SYSTEM_PROMPT.format(
|
||||||
|
keys=cls.DEFAULT_KEYS, additional_instructions=""
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def math_cot_prompt(cls):
|
def math_cot_prompt(cls):
|
||||||
"""Get the math chain of thought prompt for datasets expecting numeric answers."""
|
"""Get the math chain of thought prompt for datasets expecting numeric answers."""
|
||||||
return cls.DEFAULT_COT_SYSTEM_PROMPT.format(keys=cls.MATH_NUMERICAL_KEYS,
|
return cls.DEFAULT_COT_SYSTEM_PROMPT.format(
|
||||||
additional_instructions=cls.MATH_ADDITIONAL_INSTRUCTIONS
|
keys=cls.MATH_NUMERICAL_KEYS,
|
||||||
|
additional_instructions=cls.MATH_ADDITIONAL_INSTRUCTIONS,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -0,0 +1,109 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
"""Contains helper functions for input/output operations."""
|
||||||
|
|
||||||
|
from typing import Any, List, Dict
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from azureml.acft.common_components.utils.error_handling.exceptions import (
|
||||||
|
ACFTValidationException,
|
||||||
|
)
|
||||||
|
from azureml.acft.common_components.utils.error_handling.error_definitions import (
|
||||||
|
ACFTUserError,
|
||||||
|
)
|
||||||
|
from azureml._common._error_definition.azureml_error import AzureMLError
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_files_with_given_extension(
|
||||||
|
file_paths: List[str], extension: str
|
||||||
|
) -> List[str]:
|
||||||
|
"""Filter and return the list of files with given extension."""
|
||||||
|
return [file_path for file_path in file_paths if file_path.endswith(extension)]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_file_paths_from_dir(dir: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
Get sorted file paths from directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dir (str): Directory path.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
file_paths (List[str]): List of sorted file paths.
|
||||||
|
"""
|
||||||
|
file_paths = []
|
||||||
|
|
||||||
|
for root, _, files in os.walk(dir):
|
||||||
|
for file in files:
|
||||||
|
file_paths.append(os.path.join(root, file))
|
||||||
|
|
||||||
|
file_paths.sort()
|
||||||
|
return file_paths
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_io_path(path: str) -> List[str]:
|
||||||
|
"""Resolve input/output path as a list of file paths.
|
||||||
|
|
||||||
|
It can handle the following cases for `path` argument:
|
||||||
|
- `uri_file`: `path` points to a single file.
|
||||||
|
- `uri_folder`: `path` points to a directory containing multiple files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): Path to the file or folder.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
paths (List[str]): List of file paths.
|
||||||
|
"""
|
||||||
|
if not os.path.isfile(path):
|
||||||
|
return _get_file_paths_from_dir(path)
|
||||||
|
|
||||||
|
return [path]
|
||||||
|
|
||||||
|
|
||||||
|
def read_jsonl_files(path: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Read jsonl file/files and return a list of dictionaries.
|
||||||
|
|
||||||
|
If `path` points to a file without extension, try to read it as \
|
||||||
|
a jsonl file. This is done to support `uri_file` without extension scenario.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): Path to a jsonl file or a directory containing jsonl files.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
data (List[Dict[str, Any]]): List of dictionaries.
|
||||||
|
"""
|
||||||
|
file_paths: List[str] = _resolve_io_path(path)
|
||||||
|
if len(file_paths) > 1:
|
||||||
|
file_paths = _filter_files_with_given_extension(file_paths, ".jsonl")
|
||||||
|
data_dicts = []
|
||||||
|
for file_path in file_paths:
|
||||||
|
with open(file_path, "r") as file:
|
||||||
|
for i, line in enumerate(file):
|
||||||
|
try:
|
||||||
|
data_dicts.append(json.loads(line))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
mssg = f"Invalid JSON format in line {i + 1} of file '{file_path}'."
|
||||||
|
raise ACFTValidationException._with_error(
|
||||||
|
AzureMLError.create(ACFTUserError, pii_safe_message=mssg)
|
||||||
|
)
|
||||||
|
if not data_dicts:
|
||||||
|
mssg = f"No data found in {file_paths}."
|
||||||
|
raise ACFTValidationException._with_error(
|
||||||
|
AzureMLError.create(ACFTUserError, pii_safe_message=mssg)
|
||||||
|
)
|
||||||
|
return data_dicts
|
||||||
|
|
||||||
|
|
||||||
|
def write_jsonl_file(file_path: str, data: List[Dict[str, Any]]) -> None:
|
||||||
|
"""Write data to a `.jsonl` file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path (str): Path to the file.
|
||||||
|
data (List[Dict[str, Any]]): Data to be written to the file.
|
||||||
|
"""
|
||||||
|
with open(file_path, "w") as file:
|
||||||
|
for line in data:
|
||||||
|
file.write(json.dumps(line) + "\n")
|
|
@ -5,32 +5,43 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import List, Tuple, Union, Optional, Callable
|
from typing import List, Tuple, Union, Optional, Callable, Any, Dict
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from azure.ai.ml import MLClient
|
from azure.ai.ml import MLClient
|
||||||
from azure.ai.ml.identity import AzureMLOnBehalfOfCredential
|
from azure.ai.ml.identity import AzureMLOnBehalfOfCredential
|
||||||
from azure.ai.ml.entities import ManagedOnlineEndpoint, ManagedOnlineDeployment, ServerlessEndpoint
|
from azure.ai.ml.entities import (
|
||||||
|
ManagedOnlineEndpoint,
|
||||||
|
ManagedOnlineDeployment,
|
||||||
|
ServerlessEndpoint,
|
||||||
|
)
|
||||||
from azure.identity import AzureCliCredential, ManagedIdentityCredential
|
from azure.identity import AzureCliCredential, ManagedIdentityCredential
|
||||||
from azureml.acft.common_components.utils.error_handling.exceptions import ACFTValidationException
|
from azureml.acft.common_components.utils.error_handling.exceptions import (
|
||||||
from azureml.acft.common_components.utils.error_handling.error_definitions import ACFTUserError
|
ACFTValidationException,
|
||||||
|
)
|
||||||
|
from azureml.acft.common_components.utils.error_handling.error_definitions import (
|
||||||
|
ACFTUserError,
|
||||||
|
)
|
||||||
from azureml._common._error_definition.azureml_error import AzureMLError
|
from azureml._common._error_definition.azureml_error import AzureMLError
|
||||||
from azureml.acft.common_components import get_logger_app
|
from azureml.acft.common_components import get_logger_app
|
||||||
from azureml.core import Run, Workspace
|
from azureml.core import Run, Workspace
|
||||||
from azureml.core.run import _OfflineRun
|
from azureml.core.run import _OfflineRun
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
|
||||||
from common.constants import (
|
from common.constants import (
|
||||||
REQUESTS_RETRY_DELAY,
|
REQUESTS_RETRY_DELAY,
|
||||||
REGISTRY_MODEL_PATTERN,
|
REGISTRY_MODEL_PATTERN,
|
||||||
SUPPORTED_STUDENT_MODEL_MAP,
|
SUPPORTED_STUDENT_MODEL_MAP,
|
||||||
SUPPORTED_TEACHER_MODEL_MAP,
|
SUPPORTED_TEACHER_MODEL_MAP,
|
||||||
BackoffConstants
|
BackoffConstants,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger_app("azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import")
|
logger = get_logger_app(
|
||||||
|
"azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import"
|
||||||
|
)
|
||||||
current_run: Run = Run.get_context()
|
current_run: Run = Run.get_context()
|
||||||
|
|
||||||
|
|
||||||
|
@ -57,7 +68,9 @@ def retry(times: int):
|
||||||
time.sleep(REQUESTS_RETRY_DELAY)
|
time.sleep(REQUESTS_RETRY_DELAY)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Retried {} times when calling {}, now giving up!".format(times, func.__name__)
|
"Retried {} times when calling {}, now giving up!".format(
|
||||||
|
times, func.__name__
|
||||||
|
)
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@ -114,7 +127,7 @@ def get_workspace_mlclient(workspace: Workspace = None) -> MLClient:
|
||||||
credential,
|
credential,
|
||||||
subscription_id=workspace.subscription_id,
|
subscription_id=workspace.subscription_id,
|
||||||
resource_group_name=workspace.resource_group,
|
resource_group_name=workspace.resource_group,
|
||||||
workspace_name=workspace.name
|
workspace_name=workspace.name,
|
||||||
)
|
)
|
||||||
raise Exception("Error creating MLClient. No credentials or workspace found")
|
raise Exception("Error creating MLClient. No credentials or workspace found")
|
||||||
|
|
||||||
|
@ -153,9 +166,13 @@ class ServerlessEndpointDetails(EndpointDetails):
|
||||||
"""
|
"""
|
||||||
self._mlclient: MLClient = mlclient_ws
|
self._mlclient: MLClient = mlclient_ws
|
||||||
try:
|
try:
|
||||||
self._endpoint: ServerlessEndpoint = self._mlclient.serverless_endpoints.get(endpoint_name)
|
self._endpoint: ServerlessEndpoint = (
|
||||||
|
self._mlclient.serverless_endpoints.get(endpoint_name)
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"Serverless endpoint fetch details failed with exception: {e}")
|
raise Exception(
|
||||||
|
f"Serverless endpoint fetch details failed with exception: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
# ensure endpoint is healthy
|
# ensure endpoint is healthy
|
||||||
logger.info(f"Endpoint provisioning state: {self._endpoint.provisioning_state}")
|
logger.info(f"Endpoint provisioning state: {self._endpoint.provisioning_state}")
|
||||||
|
@ -172,9 +189,13 @@ class ServerlessEndpointDetails(EndpointDetails):
|
||||||
str: endpoint primary key for serverless deployment.
|
str: endpoint primary key for serverless deployment.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return self._mlclient.serverless_endpoints.get_keys(self._endpoint.name).primary_key
|
return self._mlclient.serverless_endpoints.get_keys(
|
||||||
|
self._endpoint.name
|
||||||
|
).primary_key
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"Failed to get endpoint keys for endpoint: {self._endpoint.name}. Exception: {e}")
|
raise Exception(
|
||||||
|
f"Failed to get endpoint keys for endpoint: {self._endpoint.name}. Exception: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
def get_endpoint_url(self) -> str:
|
def get_endpoint_url(self) -> str:
|
||||||
"""Get URL for managed online endpoint."""
|
"""Get URL for managed online endpoint."""
|
||||||
|
@ -202,7 +223,9 @@ class OnlineEndpointDetails(EndpointDetails):
|
||||||
"""
|
"""
|
||||||
self._mlclient: MLClient = mlclient_ws
|
self._mlclient: MLClient = mlclient_ws
|
||||||
try:
|
try:
|
||||||
self._endpoint: ManagedOnlineEndpoint = self._mlclient.online_endpoints.get(endpoint_name)
|
self._endpoint: ManagedOnlineEndpoint = self._mlclient.online_endpoints.get(
|
||||||
|
endpoint_name
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"Online endpoint fetch details failed with exception: {e}")
|
raise Exception(f"Online endpoint fetch details failed with exception: {e}")
|
||||||
|
|
||||||
|
@ -210,9 +233,12 @@ class OnlineEndpointDetails(EndpointDetails):
|
||||||
# fetch deployment with 100% traffic
|
# fetch deployment with 100% traffic
|
||||||
deployments = [
|
deployments = [
|
||||||
deployment
|
deployment
|
||||||
for deployment in all_deployments if deployment.name in [
|
for deployment in all_deployments
|
||||||
|
if deployment.name
|
||||||
|
in [
|
||||||
deployment_name
|
deployment_name
|
||||||
for deployment_name, traffic_percent in self._endpoint.traffic.items() if traffic_percent == 100
|
for deployment_name, traffic_percent in self._endpoint.traffic.items()
|
||||||
|
if traffic_percent == 100
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -226,10 +252,16 @@ class OnlineEndpointDetails(EndpointDetails):
|
||||||
self._deployment = deployments[0]
|
self._deployment = deployments[0]
|
||||||
# ensure endpoint and deployment is healthy
|
# ensure endpoint and deployment is healthy
|
||||||
logger.info(f"Endpoint provisioning state: {self._endpoint.provisioning_state}")
|
logger.info(f"Endpoint provisioning state: {self._endpoint.provisioning_state}")
|
||||||
logger.info(f"Deployment provisioning state: {self._deployment.provisioning_state}")
|
logger.info(
|
||||||
if not (self._endpoint.provisioning_state.lower() == "succeeded"
|
f"Deployment provisioning state: {self._deployment.provisioning_state}"
|
||||||
and self._deployment.provisioning_state.lower() == "succeeded"):
|
)
|
||||||
raise Exception(f"Endpoint {self._endpoint.name} or deployment {self._deployment.name} is unhealthy.")
|
if not (
|
||||||
|
self._endpoint.provisioning_state.lower() == "succeeded"
|
||||||
|
and self._deployment.provisioning_state.lower() == "succeeded"
|
||||||
|
):
|
||||||
|
raise Exception(
|
||||||
|
f"Endpoint {self._endpoint.name} or deployment {self._deployment.name} is unhealthy."
|
||||||
|
)
|
||||||
|
|
||||||
def get_endpoint_key(self):
|
def get_endpoint_key(self):
|
||||||
"""Get endpoint primary key for managed online deployment.
|
"""Get endpoint primary key for managed online deployment.
|
||||||
|
@ -241,9 +273,13 @@ class OnlineEndpointDetails(EndpointDetails):
|
||||||
str: endpoint primary key for managed online deployment.
|
str: endpoint primary key for managed online deployment.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return self._mlclient.online_endpoints.get_keys(self._endpoint.name).primary_key
|
return self._mlclient.online_endpoints.get_keys(
|
||||||
|
self._endpoint.name
|
||||||
|
).primary_key
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"Failed to get endpoint keys for endpoint: {self._endpoint.name}. Exception: {e}")
|
raise Exception(
|
||||||
|
f"Failed to get endpoint keys for endpoint: {self._endpoint.name}. Exception: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
def get_endpoint_url(self) -> str:
|
def get_endpoint_url(self) -> str:
|
||||||
"""Get URL for managed online endpoint."""
|
"""Get URL for managed online endpoint."""
|
||||||
|
@ -260,11 +296,14 @@ class OnlineEndpointDetails(EndpointDetails):
|
||||||
List[ManagedOnlineDeployment]: List of deployments
|
List[ManagedOnlineDeployment]: List of deployments
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
self._deployments: List[ManagedOnlineDeployment] = self._mlclient.online_deployments.list(
|
self._deployments: List[ManagedOnlineDeployment] = (
|
||||||
self._endpoint.name)
|
self._mlclient.online_deployments.list(self._endpoint.name)
|
||||||
|
)
|
||||||
return self._deployments
|
return self._deployments
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Could not fetch deployments for endpoint: {self._endpoint.name}. Exception => {e}")
|
logger.error(
|
||||||
|
f"Could not fetch deployments for endpoint: {self._endpoint.name}. Exception => {e}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@ -305,7 +344,11 @@ def _get_model_id_from_run_details():
|
||||||
def _get_model_details(model_asset_id, supported_model_map) -> Tuple[str, str, str]:
|
def _get_model_details(model_asset_id, supported_model_map) -> Tuple[str, str, str]:
|
||||||
# try matching registry model pattern
|
# try matching registry model pattern
|
||||||
if match := REGISTRY_MODEL_PATTERN.match(model_asset_id):
|
if match := REGISTRY_MODEL_PATTERN.match(model_asset_id):
|
||||||
registry, model_name, model_version = match.group("registry"), match.group("model"), match.group("version")
|
registry, model_name, model_version = (
|
||||||
|
match.group("registry"),
|
||||||
|
match.group("model"),
|
||||||
|
match.group("version"),
|
||||||
|
)
|
||||||
# check if model_name exists in supported list
|
# check if model_name exists in supported list
|
||||||
if model_name not in supported_model_map:
|
if model_name not in supported_model_map:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
@ -412,7 +455,7 @@ def exponential_backoff(
|
||||||
ACFTUserError,
|
ACFTUserError,
|
||||||
pii_safe_message=(
|
pii_safe_message=(
|
||||||
f"Encountered unknown status code: {status_code}. ",
|
f"Encountered unknown status code: {status_code}. ",
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -435,10 +478,25 @@ def exponential_backoff(
|
||||||
ACFTUserError,
|
ACFTUserError,
|
||||||
pii_safe_message=(
|
pii_safe_message=(
|
||||||
f"Request failed after multiple tries with status code: {status_code}. ",
|
f"Request failed after multiple tries with status code: {status_code}. ",
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def get_hash_value(data: Union[Dict[str, Any], str]) -> str:
|
||||||
|
"""
|
||||||
|
Get hash value for the data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (Dict[str, Any]): Data for which hash value needs to be computed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
hash_value (str): Hash value.
|
||||||
|
"""
|
||||||
|
if isinstance(data, str):
|
||||||
|
return hashlib.sha256(data.encode()).hexdigest()
|
||||||
|
return hashlib.sha256(json.dumps(data).encode()).hexdigest()
|
||||||
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This module link output on run conditionally.
|
||||||
|
|
||||||
|
If condition is `true`, link `output` to `generated_batch_train_file_path` and `generated_batch_validation_file_path`.
|
||||||
|
If condition is `false`, link `output` to `generated_train_file_path` and `generated_validation_file_path`.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
def copy_file_contents(input_src1, ft_input_train_file_path):
|
||||||
|
"""
|
||||||
|
Copy the contents of one file to another.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
input_src1 (str): The path to the source file.
|
||||||
|
ft_input_train_file_path (str): The path to the destination file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Read the contents of input_src1
|
||||||
|
with open(input_src1, "r") as src_file:
|
||||||
|
contents = src_file.read()
|
||||||
|
|
||||||
|
# Write the contents to ft_input_train_file_path
|
||||||
|
with open(ft_input_train_file_path, "w") as dest_file:
|
||||||
|
dest_file.write(contents)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--generated_batch_train_file_path", type=str)
|
||||||
|
parser.add_argument("--generated_batch_validation_file_path", type=str)
|
||||||
|
parser.add_argument("--generated_train_file_path", type=str)
|
||||||
|
parser.add_argument("--generated_validation_file_path", type=str)
|
||||||
|
parser.add_argument("--condition", type=str)
|
||||||
|
parser.add_argument("--ft_input_train_file_path", type=str)
|
||||||
|
parser.add_argument("--ft_input_validation_file_path", type=str)
|
||||||
|
|
||||||
|
args, _ = parser.parse_known_args()
|
||||||
|
print(f"Condition output component received args: {args}.")
|
||||||
|
if (
|
||||||
|
args.generated_batch_train_file_path is None
|
||||||
|
and args.generated_train_file_path is None
|
||||||
|
):
|
||||||
|
raise Exception(
|
||||||
|
"Got 'generated_batch_train_file_path' and 'generated_train_file_path' both be None."
|
||||||
|
)
|
||||||
|
|
||||||
|
condition = args.condition.lower() == "true"
|
||||||
|
input_src1 = args.generated_train_file_path
|
||||||
|
input_src2 = args.generated_validation_file_path
|
||||||
|
ft_input_train_file_path = args.ft_input_train_file_path
|
||||||
|
ft_input_validation_file_path = args.ft_input_validation_file_path
|
||||||
|
if condition:
|
||||||
|
input_src1 = args.generated_batch_train_file_path
|
||||||
|
input_src2 = args.generated_batch_validation_file_path
|
||||||
|
copy_file_contents(input_src1, ft_input_train_file_path)
|
||||||
|
copy_file_contents(input_src2, ft_input_validation_file_path)
|
|
@ -0,0 +1,42 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
"""File containing function for FTaaS data import component."""
|
||||||
|
|
||||||
|
from azureml.acft.common_components import (
|
||||||
|
get_logger_app,
|
||||||
|
)
|
||||||
|
from azureml.acft.common_components.utils.error_handling.swallow_all_exceptions_decorator import (
|
||||||
|
swallow_all_exceptions,
|
||||||
|
)
|
||||||
|
from azureml.telemetry.activity import log_activity
|
||||||
|
|
||||||
|
from azure.ai.ml import Input
|
||||||
|
from mldesigner import Output, command_component
|
||||||
|
|
||||||
|
from common.constants import (
|
||||||
|
DataGenerationTaskType,
|
||||||
|
TelemetryConstants,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = get_logger_app(
|
||||||
|
"azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@command_component
|
||||||
|
@swallow_all_exceptions(logger)
|
||||||
|
def validate(
|
||||||
|
data_generation_task_type: Input(type="string", optional=False), # noqa: F821
|
||||||
|
) -> Output(type="boolean", is_control=True): # noqa: F821
|
||||||
|
"""Entry function of model validation script."""
|
||||||
|
with log_activity(
|
||||||
|
logger,
|
||||||
|
TelemetryConstants.VERSION_SELECTION,
|
||||||
|
{"data_generation_task_type": data_generation_task_type},
|
||||||
|
):
|
||||||
|
logger.info("Validating arguments: " + repr(data_generation_task_type))
|
||||||
|
if data_generation_task_type == DataGenerationTaskType.CONVERSATION:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
|
@ -0,0 +1,383 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
"""File containing function for FTaaS data import component."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from argparse import Namespace
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
from azureml.acft.contrib.hf import VERSION, PROJECT_NAME
|
||||||
|
from azureml.acft.contrib.hf.nlp.constants.constants import (
|
||||||
|
LOGS_TO_BE_FILTERED_IN_APPINSIGHTS,
|
||||||
|
)
|
||||||
|
from azureml.acft.common_components import (
|
||||||
|
get_logger_app,
|
||||||
|
set_logging_parameters,
|
||||||
|
LoggingLiterals,
|
||||||
|
)
|
||||||
|
from azureml.acft.common_components.utils.error_handling.swallow_all_exceptions_decorator import (
|
||||||
|
swallow_all_exceptions,
|
||||||
|
)
|
||||||
|
from azureml.telemetry.activity import log_activity
|
||||||
|
from common.io import read_jsonl_files
|
||||||
|
from common.constants import (
|
||||||
|
COMPONENT_NAME,
|
||||||
|
DEFAULT_SUCCESS_RATIO,
|
||||||
|
DataGenerationTaskType,
|
||||||
|
HashField,
|
||||||
|
TelemetryConstants,
|
||||||
|
SystemPrompt,
|
||||||
|
PayloadField,
|
||||||
|
STATUS_SUCCESS,
|
||||||
|
FINISH_REASON_STOP,
|
||||||
|
)
|
||||||
|
|
||||||
|
from common.utils import (
|
||||||
|
get_hash_value,
|
||||||
|
get_workspace_mlclient,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = get_logger_app(
|
||||||
|
"azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
"""
|
||||||
|
Add arguments and returns the parser. Here we add all the arguments for all the tasks.
|
||||||
|
|
||||||
|
Those arguments that are not relevant for the input task should be ignored.
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Model selector for hugging face models", allow_abbrev=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# File I/O
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_score_train_result",
|
||||||
|
type=str,
|
||||||
|
help="Path to the directory containing jsonl file(s) that have the result for each payload.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_score_validation_result",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
help="Path to the directory containing jsonl file(s) that have the result for each payload.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--hash_train_data",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path tho the jsonl file containing the hash for each payload.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--hash_validation_data",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path tho the jsonl file containing the hash for each payload.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--generated_batch_train_file_path",
|
||||||
|
type=Path,
|
||||||
|
default=None,
|
||||||
|
help="file to save the generated training data",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--generated_batch_validation_file_path",
|
||||||
|
type=Path,
|
||||||
|
default=None,
|
||||||
|
help="file to save the generated validation data",
|
||||||
|
)
|
||||||
|
|
||||||
|
# File I/O
|
||||||
|
parser.add_argument(
|
||||||
|
"--train_file_path",
|
||||||
|
type=str,
|
||||||
|
help="Input train file path",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--validation_file_path",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
help="Input validation file path",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable_chain_of_thought",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
default="false",
|
||||||
|
help="This enables Chain of Thought",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable_chain_of_density",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
default="false",
|
||||||
|
help="This enables Chain of Density for Summarization",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--data_generation_task_type",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="""Data generation task type. Supported values are:
|
||||||
|
1. NLI: Generate Natural Language Inference data
|
||||||
|
2. CONVERSATION: Generate conversational data (multi/single turn)
|
||||||
|
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
|
||||||
|
4. MATH: Generate Math data for numerical responses
|
||||||
|
5. SUMMARIZATION: Generate Text Summary for Article
|
||||||
|
""",
|
||||||
|
choices=[v.value for v in DataGenerationTaskType],
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--min_endpoint_success_ratio",
|
||||||
|
type=float,
|
||||||
|
required=False,
|
||||||
|
default=DEFAULT_SUCCESS_RATIO,
|
||||||
|
help=(
|
||||||
|
f"The minimum value of "
|
||||||
|
"(successful_requests / total_requests) required for classifying inference as successful. "
|
||||||
|
"If (successful_requests / total_requests) < min_endpoint_success_ratio, "
|
||||||
|
"the experiment will be marked as failed. "
|
||||||
|
f"By default it is {DEFAULT_SUCCESS_RATIO}. "
|
||||||
|
"(0 means all requests are allowed to fail while 1 means no request should fail.)"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--connection_config_file",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
default=None,
|
||||||
|
help="A config file path that contains deployment configurations.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def delete_connection(config_file: str):
|
||||||
|
"""Delete the connection configuration file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_file (str): The path to the connection configuration file.
|
||||||
|
"""
|
||||||
|
if config_file:
|
||||||
|
try:
|
||||||
|
config_from_file = {}
|
||||||
|
with open(config_file) as file:
|
||||||
|
config_from_file = json.load(file)
|
||||||
|
batch_connection_name = config_from_file.get("connection_name", None)
|
||||||
|
if batch_connection_name:
|
||||||
|
mlclient_ws = get_workspace_mlclient()
|
||||||
|
if not mlclient_ws:
|
||||||
|
raise Exception("Could not create MLClient for current workspace")
|
||||||
|
mlclient_ws.connections.delete(batch_connection_name)
|
||||||
|
except Exception as e:
|
||||||
|
msg = f"Error deleting connection: {e}"
|
||||||
|
logger.error(msg)
|
||||||
|
raise Exception(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def postprocess_data(
|
||||||
|
batch_score_res_path: str,
|
||||||
|
input_file_path: str,
|
||||||
|
enable_cot: bool,
|
||||||
|
enable_cod: bool,
|
||||||
|
data_generation_task_type: str,
|
||||||
|
min_endpoint_success_ratio: float,
|
||||||
|
output_file_path: str,
|
||||||
|
hash_data: str,
|
||||||
|
):
|
||||||
|
"""Generate and save synthentic data under output_dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_score_res_path (str): Path containing jsonl file(s) that have the result for each payload.
|
||||||
|
input_file_path (str): Input JSONL file path.
|
||||||
|
enable_cot (bool): Enable Chain of Thought
|
||||||
|
enable_cod (bool): Enable Chain of Density
|
||||||
|
data_generation_task_type (str): Data generation task type
|
||||||
|
min_endpoint_success_ratio (float): Minimum success ratio below which run will be considered a failure
|
||||||
|
output_file_path (str): Output JSONL file path.
|
||||||
|
hash_data (str): Path to the jsonl file containing the hash for each payload.
|
||||||
|
"""
|
||||||
|
error_count = 0
|
||||||
|
output_data = []
|
||||||
|
if input_file_path is None:
|
||||||
|
logger.info(
|
||||||
|
f"No input file path provided. Skipping data postprocessing for {input_file_path}."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
hash_data_list: List[Dict[str, Any]] = read_jsonl_files(hash_data)
|
||||||
|
# Recreate hash_data_list respecting the order of batch_score_res_list
|
||||||
|
hash_data_dict = {
|
||||||
|
hash_data[HashField.HASH]: hash_data for hash_data in hash_data_list
|
||||||
|
}
|
||||||
|
total_rows = len(hash_data_list)
|
||||||
|
|
||||||
|
batch_score_res_list: List[Dict[str, Any]] = read_jsonl_files(batch_score_res_path)
|
||||||
|
try:
|
||||||
|
for idx, batch_score_dict in enumerate(batch_score_res_list):
|
||||||
|
status = batch_score_dict.get("status", "")
|
||||||
|
if status == STATUS_SUCCESS:
|
||||||
|
request_dict = batch_score_dict.get("request", {})
|
||||||
|
if request_dict:
|
||||||
|
synthetic_responses = []
|
||||||
|
messages = request_dict.pop("messages", [])
|
||||||
|
for message in messages:
|
||||||
|
role = message["role"]
|
||||||
|
if role == PayloadField.USER:
|
||||||
|
hash_val = get_hash_value(message)
|
||||||
|
system_message = hash_data_dict[hash_val].get(
|
||||||
|
PayloadField.SYSTEM, {}
|
||||||
|
)
|
||||||
|
synthetic_responses.append(system_message)
|
||||||
|
synthetic_responses.append(message)
|
||||||
|
break
|
||||||
|
response_data = batch_score_dict.get("response", {})
|
||||||
|
finish_reason = response_data["choices"][0]["finish_reason"]
|
||||||
|
if finish_reason == FINISH_REASON_STOP:
|
||||||
|
prediction_result = response_data["choices"][0]["message"][
|
||||||
|
"content"
|
||||||
|
].strip()
|
||||||
|
# For CoT prompts, need to remove the reasoning and only use the answer
|
||||||
|
if (
|
||||||
|
enable_cot
|
||||||
|
and data_generation_task_type
|
||||||
|
!= DataGenerationTaskType.CONVERSATION
|
||||||
|
):
|
||||||
|
key = SystemPrompt.get_response_key(
|
||||||
|
data_generation_task_type
|
||||||
|
)
|
||||||
|
prediction_result = json.loads(prediction_result)[key]
|
||||||
|
|
||||||
|
if (
|
||||||
|
enable_cod
|
||||||
|
and data_generation_task_type
|
||||||
|
== DataGenerationTaskType.SUMMARIZATION
|
||||||
|
):
|
||||||
|
result = json.loads(prediction_result)
|
||||||
|
prediction_result = result[-1]["Denser_Summary"]
|
||||||
|
|
||||||
|
synthetic_responses.append(
|
||||||
|
{"role": "assistant", "content": str(prediction_result)}
|
||||||
|
)
|
||||||
|
output_data.append({"messages": synthetic_responses})
|
||||||
|
else:
|
||||||
|
error_count += 1
|
||||||
|
else:
|
||||||
|
error_count += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in postprocessing {idx} data: {e}")
|
||||||
|
raise e
|
||||||
|
success_ratio = float(total_rows - error_count) / total_rows
|
||||||
|
print(success_ratio)
|
||||||
|
if success_ratio < min_endpoint_success_ratio:
|
||||||
|
msg = f"Success ratio for dataset {input_file_path}: {success_ratio} < {min_endpoint_success_ratio}."
|
||||||
|
raise Exception(msg)
|
||||||
|
with open(output_file_path, "w") as f:
|
||||||
|
for record in output_data:
|
||||||
|
f.write(json.dumps(record) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def data_import(args: Namespace):
|
||||||
|
"""Copy the user data to output dir."""
|
||||||
|
train_file_path = args.train_file_path
|
||||||
|
validation_file_path = args.validation_file_path
|
||||||
|
batch_score_train_path = args.batch_score_train_result
|
||||||
|
batch_score_validation_path = args.batch_score_validation_result
|
||||||
|
generated_batch_train_file_path = args.generated_batch_train_file_path
|
||||||
|
generated_batch_validation_file_path = args.generated_batch_validation_file_path
|
||||||
|
enable_cot_str = args.enable_chain_of_thought
|
||||||
|
enable_cod_str = args.enable_chain_of_density
|
||||||
|
min_endpoint_success_ratio = args.min_endpoint_success_ratio
|
||||||
|
data_generation_task_type = args.data_generation_task_type
|
||||||
|
hash_train_data = args.hash_train_data
|
||||||
|
hash_validation_data = args.hash_validation_data
|
||||||
|
connection_config_file = args.connection_config_file
|
||||||
|
|
||||||
|
enable_cot = True if enable_cot_str.lower() == "true" else False
|
||||||
|
enable_cod = True if enable_cod_str.lower() == "true" else False
|
||||||
|
|
||||||
|
with log_activity(
|
||||||
|
logger=logger, activity_name=TelemetryConstants.POST_PROCESS_TRAINING_DATA
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"Deleting batch configuration connection used for teacher model invocation."
|
||||||
|
)
|
||||||
|
delete_connection(connection_config_file)
|
||||||
|
logger.info(
|
||||||
|
"Running data postprocessing for train file path: %s", train_file_path
|
||||||
|
)
|
||||||
|
postprocess_data(
|
||||||
|
batch_score_res_path=batch_score_train_path,
|
||||||
|
input_file_path=train_file_path,
|
||||||
|
enable_cot=enable_cot,
|
||||||
|
enable_cod=enable_cod,
|
||||||
|
data_generation_task_type=data_generation_task_type,
|
||||||
|
min_endpoint_success_ratio=min_endpoint_success_ratio,
|
||||||
|
output_file_path=generated_batch_train_file_path,
|
||||||
|
hash_data=hash_train_data,
|
||||||
|
)
|
||||||
|
if validation_file_path:
|
||||||
|
with log_activity(
|
||||||
|
logger=logger,
|
||||||
|
activity_name=TelemetryConstants.POST_PROCESS_VALIDATION_DATA,
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"Running data postprocessing for validation file path: %s",
|
||||||
|
validation_file_path,
|
||||||
|
)
|
||||||
|
postprocess_data(
|
||||||
|
batch_score_res_path=batch_score_validation_path,
|
||||||
|
input_file_path=validation_file_path,
|
||||||
|
enable_cot=enable_cot,
|
||||||
|
enable_cod=enable_cod,
|
||||||
|
data_generation_task_type=data_generation_task_type,
|
||||||
|
min_endpoint_success_ratio=min_endpoint_success_ratio,
|
||||||
|
output_file_path=generated_batch_validation_file_path,
|
||||||
|
hash_data=hash_validation_data,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
Path(generated_batch_validation_file_path.parent).mkdir(
|
||||||
|
exist_ok=True, parents=True
|
||||||
|
)
|
||||||
|
# create an empty file if validation file is not provided
|
||||||
|
open(generated_batch_validation_file_path, "w").close()
|
||||||
|
|
||||||
|
|
||||||
|
@swallow_all_exceptions(time_delay=5)
|
||||||
|
def main():
|
||||||
|
"""Parse args and import model."""
|
||||||
|
parser = get_parser()
|
||||||
|
args, _ = parser.parse_known_args()
|
||||||
|
|
||||||
|
set_logging_parameters(
|
||||||
|
task_type="ChatCompletion",
|
||||||
|
acft_custom_dimensions={
|
||||||
|
LoggingLiterals.PROJECT_NAME: PROJECT_NAME,
|
||||||
|
LoggingLiterals.PROJECT_VERSION_NUMBER: VERSION,
|
||||||
|
LoggingLiterals.COMPONENT_NAME: COMPONENT_NAME,
|
||||||
|
},
|
||||||
|
azureml_pkg_denylist_logging_patterns=LOGS_TO_BE_FILTERED_IN_APPINSIGHTS,
|
||||||
|
log_level=logging.INFO,
|
||||||
|
)
|
||||||
|
|
||||||
|
data_import(args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -0,0 +1,540 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
"""File containing function for FTaaS data import component."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import argparse
|
||||||
|
from argparse import Namespace
|
||||||
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from azureml.acft.contrib.hf import VERSION, PROJECT_NAME
|
||||||
|
from azureml.acft.contrib.hf.nlp.constants.constants import (
|
||||||
|
LOGS_TO_BE_FILTERED_IN_APPINSIGHTS,
|
||||||
|
)
|
||||||
|
from azureml.acft.common_components import (
|
||||||
|
get_logger_app,
|
||||||
|
set_logging_parameters,
|
||||||
|
LoggingLiterals,
|
||||||
|
)
|
||||||
|
from azureml.acft.common_components.utils.error_handling.swallow_all_exceptions_decorator import (
|
||||||
|
swallow_all_exceptions,
|
||||||
|
)
|
||||||
|
from azureml.telemetry.activity import log_activity
|
||||||
|
from azure.ai.ml.entities import ServerlessConnection
|
||||||
|
|
||||||
|
from common.io import read_jsonl_files
|
||||||
|
|
||||||
|
from mltable import from_json_lines_files
|
||||||
|
|
||||||
|
from common.constants import (
|
||||||
|
COMPONENT_NAME,
|
||||||
|
DEFAULT_MAX_NEW_TOKENS,
|
||||||
|
DEFAULT_SUMMARY_MAX_NEW_TOKENS,
|
||||||
|
DEFAULT_TEMPERATURE,
|
||||||
|
DEFAULT_TOP_P,
|
||||||
|
FREQUENCY_PENALTY,
|
||||||
|
PRESENCE_PENALTY,
|
||||||
|
MAX_NEW_TOKENS,
|
||||||
|
TEMPERATURE,
|
||||||
|
TOP_P,
|
||||||
|
STOP_TOKEN,
|
||||||
|
VLLM_CHAT_SCORE_PATH,
|
||||||
|
DataGenerationTaskType,
|
||||||
|
HashField,
|
||||||
|
TelemetryConstants,
|
||||||
|
SystemPrompt,
|
||||||
|
PayloadField,
|
||||||
|
DEFAULT_MAX_LEN_SUMMARY,
|
||||||
|
)
|
||||||
|
|
||||||
|
from common.utils import (
|
||||||
|
get_workspace_mlclient,
|
||||||
|
get_endpoint_details,
|
||||||
|
get_hash_value,
|
||||||
|
validate_teacher_model_details,
|
||||||
|
)
|
||||||
|
|
||||||
|
from common.validation import validate_file_paths_with_supported_formats
|
||||||
|
|
||||||
|
logger = get_logger_app(
|
||||||
|
"azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
"""
|
||||||
|
Add arguments and returns the parser. Here we add all the arguments for all the tasks.
|
||||||
|
|
||||||
|
Those arguments that are not relevant for the input task should be ignored.
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Model selector for hugging face models", allow_abbrev=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# File I/O
|
||||||
|
parser.add_argument(
|
||||||
|
"--train_file_path",
|
||||||
|
type=str,
|
||||||
|
help="Input train file path",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--validation_file_path",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
help="Input validation file path",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--generated_train_file_path",
|
||||||
|
type=Path,
|
||||||
|
default=None,
|
||||||
|
help="file to save the generated training data",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--generated_validation_file_path",
|
||||||
|
type=Path,
|
||||||
|
default=None,
|
||||||
|
help="file to save the generated validation data",
|
||||||
|
)
|
||||||
|
|
||||||
|
# add optional data-generator params
|
||||||
|
parser.add_argument(
|
||||||
|
"--teacher_model_endpoint_name",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
help="Teacher model endpoint name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--teacher_model_endpoint_key",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
help="Teacher model endpoint key",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--teacher_model_endpoint_url",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Teacher model endpoint URL",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--teacher_model_max_new_tokens",
|
||||||
|
type=int,
|
||||||
|
required=False,
|
||||||
|
default=DEFAULT_MAX_NEW_TOKENS,
|
||||||
|
help="Teacher model max_tokens parameter",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--teacher_model_temperature",
|
||||||
|
type=float,
|
||||||
|
required=False,
|
||||||
|
default=DEFAULT_TEMPERATURE,
|
||||||
|
help="Teacher model temperature parameter",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--teacher_model_top_p",
|
||||||
|
type=float,
|
||||||
|
required=False,
|
||||||
|
default=DEFAULT_TOP_P,
|
||||||
|
help="Teacher model top-p parameter",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--teacher_model_frequency_penalty",
|
||||||
|
type=float,
|
||||||
|
required=False,
|
||||||
|
help="Teacher model frequency parameter",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--teacher_model_presence_penalty",
|
||||||
|
type=float,
|
||||||
|
required=False,
|
||||||
|
help="Teacher model presense penalty",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--teacher_model_stop", type=str, required=False, help="Teacher model stop "
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable_chain_of_thought",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
default="false",
|
||||||
|
help="This enables Chain of Thought",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable_chain_of_density",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
default="false",
|
||||||
|
help="This enables Chain of Density for Summarization",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_len_summary",
|
||||||
|
type=int,
|
||||||
|
required=False,
|
||||||
|
default=DEFAULT_MAX_LEN_SUMMARY,
|
||||||
|
help="Maximum word count for text summarization ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--data_generation_task_type",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="""Data generation task type. Supported values are:
|
||||||
|
1. NLI: Generate Natural Language Inference data
|
||||||
|
2. CONVERSATION: Generate conversational data (multi/single turn)
|
||||||
|
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
|
||||||
|
4. MATH: Generate Math data for numerical responses
|
||||||
|
5. SUMMARIZATION: Generate Text Summary for Article
|
||||||
|
""",
|
||||||
|
choices=[v.value for v in DataGenerationTaskType],
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--generated_train_payload_path",
|
||||||
|
type=str,
|
||||||
|
help="file to save the generated training payload data",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--generated_validation_payload_path",
|
||||||
|
type=str,
|
||||||
|
help="file to save the generated validation payload data",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--hash_train_data",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path tho the jsonl file where the hash for each payload will be dumped.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--hash_validation_data",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path tho the jsonl file where the hash for each payload will be dumped.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_config_connection",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_data(
|
||||||
|
inference_params: dict,
|
||||||
|
enable_cot: bool,
|
||||||
|
enable_cod: bool,
|
||||||
|
max_len_summary: int,
|
||||||
|
data_generation_task_type: str,
|
||||||
|
generated_train_payload_path: str,
|
||||||
|
generated_validation_payload_path: str,
|
||||||
|
hash_train_data: str,
|
||||||
|
hash_validation_data: str,
|
||||||
|
train_file_path: Path,
|
||||||
|
validation_file_path: Path = None,
|
||||||
|
):
|
||||||
|
"""Generate and save synthentic data under output_dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inference_params (dict): Inference params to hit endpoint with
|
||||||
|
enable_cot (bool): Enable Chain of Thought processing
|
||||||
|
enable_cod (bool): Enable Chain of Density processing for text summarization task
|
||||||
|
max_len_summary (int): Maximum word count for text summarization
|
||||||
|
data_generation_task_type (str): Data generation task type
|
||||||
|
generated_train_payload_path (str): Path to save the generated training payload data
|
||||||
|
generated_validation_payload_path (str): Path to save the generated validation payload data
|
||||||
|
hash_train_data (str): Path to the jsonl file where the hash for each payload will be dumped.
|
||||||
|
hash_validation_data (str): Path to the jsonl file where the hash for each payload will be dumped.
|
||||||
|
train_file_path (Path): Train JSONL file path
|
||||||
|
validation_file_path (Path, optional): Validation JSONL file path. Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def process_system_prompt(message: dict) -> dict:
|
||||||
|
"""Update the system prompt depending on the task type and the flag enable_cot.
|
||||||
|
|
||||||
|
The original message unchanged if enable_cot is False or task type is conversation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message (dict): System message
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
message (dict): System message with updated content
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
enable_cot
|
||||||
|
and data_generation_task_type != DataGenerationTaskType.CONVERSATION
|
||||||
|
):
|
||||||
|
cot_prompt = SystemPrompt.get_cot_prompt(data_generation_task_type)
|
||||||
|
cot_system_message = {"role": "system", "content": cot_prompt}
|
||||||
|
return cot_system_message
|
||||||
|
elif (
|
||||||
|
enable_cod
|
||||||
|
and data_generation_task_type == DataGenerationTaskType.SUMMARIZATION
|
||||||
|
):
|
||||||
|
cod_prompt = SystemPrompt.get_cod_prompt(max_len_summary)
|
||||||
|
cod_system_message = {"role": "system", "content": cod_prompt}
|
||||||
|
return cod_system_message
|
||||||
|
else:
|
||||||
|
return message
|
||||||
|
|
||||||
|
def pre_process_data(
|
||||||
|
input_file_path: Path,
|
||||||
|
output_file_path: Path,
|
||||||
|
hash_file_path: str,
|
||||||
|
) -> None:
|
||||||
|
"""Batch process data and do a bulk request to teacher model endpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_file_path (Path): Input data file path
|
||||||
|
output_file_path (Path): Path to output directory
|
||||||
|
hash_file_path (str): Path to the jsonl file where the hash for each payload will be dumped.
|
||||||
|
Raises:
|
||||||
|
Exception: if success ratio is less than min_endpoint_success_ratio
|
||||||
|
"""
|
||||||
|
input_data = read_jsonl_files(input_file_path)
|
||||||
|
output_data = []
|
||||||
|
output_hash_data = []
|
||||||
|
try:
|
||||||
|
for idx, record in enumerate(input_data):
|
||||||
|
# Basic validation for the input data
|
||||||
|
messages = record.pop("messages", [])
|
||||||
|
if not messages: # empty messages
|
||||||
|
logger.error(f"Failed with exception:{idx} Empty messages")
|
||||||
|
return
|
||||||
|
first_message = messages[0]
|
||||||
|
if first_message["role"] != PayloadField.SYSTEM:
|
||||||
|
logger.error(
|
||||||
|
f"row {idx} failed with exception: First message should be system, "
|
||||||
|
f"but got {first_message['role']}"
|
||||||
|
)
|
||||||
|
for message in messages[1:]:
|
||||||
|
role = message["role"]
|
||||||
|
if role not in ("assistant", "user"):
|
||||||
|
logger.error(
|
||||||
|
f"row {idx} failed with exception: role should be system or user, but got {role}"
|
||||||
|
)
|
||||||
|
inference_data = []
|
||||||
|
system_message = {}
|
||||||
|
for message in messages:
|
||||||
|
role = message["role"]
|
||||||
|
if role == PayloadField.SYSTEM:
|
||||||
|
system_message[PayloadField.SYSTEM] = message
|
||||||
|
inference_data.append(process_system_prompt(message))
|
||||||
|
elif role == PayloadField.USER:
|
||||||
|
inference_data.append(message)
|
||||||
|
hash_data = {
|
||||||
|
HashField.HASH: get_hash_value(message),
|
||||||
|
**system_message,
|
||||||
|
}
|
||||||
|
output_data.append(
|
||||||
|
{
|
||||||
|
"messages": inference_data,
|
||||||
|
**inference_params,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
output_hash_data.append(hash_data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"idx: {idx}. exception: {e}")
|
||||||
|
payload_jsonl_path = os.path.join(output_file_path, "payload.jsonl")
|
||||||
|
logger.info("payload_jsonl_path: %s", payload_jsonl_path)
|
||||||
|
with open(payload_jsonl_path, "w") as payload_file:
|
||||||
|
for entry in output_data:
|
||||||
|
payload_file.write(json.dumps(entry) + "\n")
|
||||||
|
logger.info("hash_file_path: %s", hash_file_path)
|
||||||
|
with open(hash_file_path, "w") as hash_file:
|
||||||
|
for entry in output_hash_data:
|
||||||
|
hash_file.write(json.dumps(entry) + "\n")
|
||||||
|
|
||||||
|
output_file_path = str(output_file_path)
|
||||||
|
mltable = from_json_lines_files(paths=[{"file": payload_jsonl_path}])
|
||||||
|
logger.info("output_file_path type before saving: %s", output_file_path)
|
||||||
|
mltable.save(output_file_path)
|
||||||
|
|
||||||
|
with log_activity(
|
||||||
|
logger=logger, activity_name=TelemetryConstants.PRE_PROCESS_TRAINING_DATA
|
||||||
|
):
|
||||||
|
logger.info("PreProcessing train file")
|
||||||
|
pre_process_data(train_file_path, generated_train_payload_path, hash_train_data)
|
||||||
|
logger.info("Data generated and saved for train file")
|
||||||
|
|
||||||
|
if validation_file_path:
|
||||||
|
with log_activity(
|
||||||
|
logger=logger,
|
||||||
|
activity_name=TelemetryConstants.PRE_PROCESS_VALIDATION_DATA,
|
||||||
|
):
|
||||||
|
logger.info("PreProcessing validation file")
|
||||||
|
pre_process_data(
|
||||||
|
validation_file_path,
|
||||||
|
generated_validation_payload_path,
|
||||||
|
hash_validation_data,
|
||||||
|
)
|
||||||
|
logger.info("Data generated and saved for validation file")
|
||||||
|
else:
|
||||||
|
hash_validation_data = Path(hash_validation_data)
|
||||||
|
Path(hash_validation_data.parent).mkdir(exist_ok=True, parents=True)
|
||||||
|
# create an empty file if validation file is not provided
|
||||||
|
open(hash_validation_data, "w").close()
|
||||||
|
|
||||||
|
|
||||||
|
def data_import(args: Namespace):
|
||||||
|
"""Copy the user data to output dir."""
|
||||||
|
train_file_path = args.train_file_path
|
||||||
|
validation_file_path = args.validation_file_path
|
||||||
|
generated_train_payload_path = args.generated_train_payload_path
|
||||||
|
generated_validation_payload_path = args.generated_validation_payload_path
|
||||||
|
teacher_model_endpoint_name = args.teacher_model_endpoint_name
|
||||||
|
teacher_model_endpoint_url = args.teacher_model_endpoint_url
|
||||||
|
teacher_model_endpoint_key = args.teacher_model_endpoint_key
|
||||||
|
# add optional data-generator params
|
||||||
|
teacher_model_max_new_tokens = args.teacher_model_max_new_tokens
|
||||||
|
teacher_model_temperature = args.teacher_model_temperature
|
||||||
|
teacher_model_top_p = args.teacher_model_top_p
|
||||||
|
teacher_model_frequency_penalty = args.teacher_model_frequency_penalty
|
||||||
|
teacher_model_presence_penalty = args.teacher_model_presence_penalty
|
||||||
|
teacher_model_stop = args.teacher_model_stop
|
||||||
|
enable_cot_str = args.enable_chain_of_thought
|
||||||
|
enable_cod_str = args.enable_chain_of_density
|
||||||
|
max_len_summary = args.max_len_summary
|
||||||
|
data_generation_task_type = args.data_generation_task_type
|
||||||
|
hash_train_data = args.hash_train_data
|
||||||
|
hash_validation_data = args.hash_validation_data
|
||||||
|
batch_config_connection = args.batch_config_connection
|
||||||
|
|
||||||
|
# validate file formats
|
||||||
|
validate_file_paths_with_supported_formats(
|
||||||
|
[args.train_file_path, args.validation_file_path]
|
||||||
|
)
|
||||||
|
logger.info("File format validation successful.")
|
||||||
|
|
||||||
|
enable_cot = True if enable_cot_str.lower() == "true" else False
|
||||||
|
enable_cod = True if enable_cod_str.lower() == "true" else False
|
||||||
|
|
||||||
|
mlclient_ws = get_workspace_mlclient()
|
||||||
|
if not mlclient_ws:
|
||||||
|
raise Exception("Could not create MLClient for current workspace")
|
||||||
|
|
||||||
|
if teacher_model_endpoint_name:
|
||||||
|
endpoint_details = get_endpoint_details(
|
||||||
|
mlclient_ws, teacher_model_endpoint_name
|
||||||
|
)
|
||||||
|
teacher_model_endpoint_key = endpoint_details.get_endpoint_key()
|
||||||
|
teacher_model_endpoint_url = endpoint_details.get_endpoint_url()
|
||||||
|
teacher_model_asset_id = endpoint_details.get_deployed_model_id()
|
||||||
|
validate_teacher_model_details(teacher_model_asset_id)
|
||||||
|
|
||||||
|
if not teacher_model_endpoint_url:
|
||||||
|
raise Exception("Endpoint URL is a requried parameter for data generation")
|
||||||
|
|
||||||
|
if not teacher_model_endpoint_key:
|
||||||
|
raise Exception("Endpoint key is a requried parameter for data generation")
|
||||||
|
|
||||||
|
if teacher_model_top_p < 0 or teacher_model_top_p > 1:
|
||||||
|
raise Exception(
|
||||||
|
f"Invalid teacher_model_top_p. Value should be 0<=val<=1, but it is {teacher_model_top_p}"
|
||||||
|
)
|
||||||
|
if teacher_model_temperature < 0 or teacher_model_temperature > 1:
|
||||||
|
raise Exception(
|
||||||
|
f"Invalid teacher_model_temperature. Value should be 0<=val<=1, but it is {teacher_model_temperature}"
|
||||||
|
)
|
||||||
|
inference_params = {
|
||||||
|
MAX_NEW_TOKENS: (
|
||||||
|
DEFAULT_SUMMARY_MAX_NEW_TOKENS
|
||||||
|
if data_generation_task_type == "SUMMARIZATION"
|
||||||
|
and teacher_model_max_new_tokens == DEFAULT_MAX_NEW_TOKENS
|
||||||
|
else teacher_model_max_new_tokens
|
||||||
|
),
|
||||||
|
TEMPERATURE: teacher_model_temperature,
|
||||||
|
TOP_P: teacher_model_top_p,
|
||||||
|
}
|
||||||
|
|
||||||
|
if teacher_model_frequency_penalty:
|
||||||
|
inference_params[FREQUENCY_PENALTY] = teacher_model_frequency_penalty
|
||||||
|
|
||||||
|
if teacher_model_presence_penalty:
|
||||||
|
inference_params[PRESENCE_PENALTY] = teacher_model_presence_penalty
|
||||||
|
|
||||||
|
if teacher_model_stop:
|
||||||
|
inference_params[STOP_TOKEN] = teacher_model_stop
|
||||||
|
|
||||||
|
if VLLM_CHAT_SCORE_PATH not in teacher_model_endpoint_url:
|
||||||
|
teacher_model_endpoint_url += VLLM_CHAT_SCORE_PATH
|
||||||
|
|
||||||
|
logger.info(f"Teacher Endpoint : {teacher_model_endpoint_url}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
guid = uuid.uuid4()
|
||||||
|
short_guid = str(guid)[:8]
|
||||||
|
connection_name = f"distillation-ws-connection-{short_guid}"
|
||||||
|
mlclient_ws.connections.create_or_update(
|
||||||
|
ServerlessConnection(
|
||||||
|
name=connection_name,
|
||||||
|
endpoint=teacher_model_endpoint_url,
|
||||||
|
api_key=teacher_model_endpoint_key,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.info(f"Connection created with name: {connection_name}")
|
||||||
|
config = {}
|
||||||
|
config["scoring_url"] = teacher_model_endpoint_url
|
||||||
|
config["connection_name"] = connection_name
|
||||||
|
with open(batch_config_connection, "w") as f:
|
||||||
|
json.dump(config, f)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to create connection for teacher model batch score invocation : {e}"
|
||||||
|
)
|
||||||
|
raise Exception(
|
||||||
|
"Failed to create workspace connection for teacher model batch score invocation "
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Running data preprocessing")
|
||||||
|
preprocess_data(
|
||||||
|
inference_params=inference_params,
|
||||||
|
enable_cot=enable_cot,
|
||||||
|
enable_cod=enable_cod,
|
||||||
|
max_len_summary=max_len_summary,
|
||||||
|
generated_train_payload_path=generated_train_payload_path,
|
||||||
|
generated_validation_payload_path=generated_validation_payload_path,
|
||||||
|
train_file_path=train_file_path,
|
||||||
|
data_generation_task_type=data_generation_task_type,
|
||||||
|
validation_file_path=validation_file_path,
|
||||||
|
hash_train_data=hash_train_data,
|
||||||
|
hash_validation_data=hash_validation_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@swallow_all_exceptions(time_delay=5)
|
||||||
|
def main():
|
||||||
|
"""Parse args and import model."""
|
||||||
|
parser = get_parser()
|
||||||
|
args, _ = parser.parse_known_args()
|
||||||
|
|
||||||
|
set_logging_parameters(
|
||||||
|
task_type="ChatCompletion",
|
||||||
|
acft_custom_dimensions={
|
||||||
|
LoggingLiterals.PROJECT_NAME: PROJECT_NAME,
|
||||||
|
LoggingLiterals.PROJECT_VERSION_NUMBER: VERSION,
|
||||||
|
LoggingLiterals.COMPONENT_NAME: COMPONENT_NAME,
|
||||||
|
},
|
||||||
|
azureml_pkg_denylist_logging_patterns=LOGS_TO_BE_FILTERED_IN_APPINSIGHTS,
|
||||||
|
log_level=logging.INFO,
|
||||||
|
)
|
||||||
|
|
||||||
|
data_import(args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -0,0 +1,41 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
"""File containing function for FTaaS data import component."""
|
||||||
|
|
||||||
|
from azureml.acft.common_components import (
|
||||||
|
get_logger_app,
|
||||||
|
)
|
||||||
|
from azureml.acft.common_components.utils.error_handling.swallow_all_exceptions_decorator import (
|
||||||
|
swallow_all_exceptions,
|
||||||
|
)
|
||||||
|
from azureml.telemetry.activity import log_activity
|
||||||
|
|
||||||
|
from azure.ai.ml import Input
|
||||||
|
from mldesigner import Output, command_component
|
||||||
|
|
||||||
|
from common.constants import (
|
||||||
|
TelemetryConstants,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = get_logger_app(
|
||||||
|
"azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@command_component
|
||||||
|
@swallow_all_exceptions(logger)
|
||||||
|
def validate(
|
||||||
|
validation_file_path: Input(type="string", optional=True), # noqa: F821
|
||||||
|
) -> Output(type="boolean", is_control=True): # noqa: F821
|
||||||
|
"""Entry function of model validation script."""
|
||||||
|
with log_activity(
|
||||||
|
logger,
|
||||||
|
TelemetryConstants.VERSION_SELECTION,
|
||||||
|
{"validation_file_path": validation_file_path},
|
||||||
|
):
|
||||||
|
logger.info("Validating arguments: " + repr(validation_file_path))
|
||||||
|
if validation_file_path:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
Загрузка…
Ссылка в новой задаче