Code changes for moving to Batch Scoring for Teacher Model (#3507)

* Code changes for moving to Batch Scoring for Teacher Model

* Addition of copyright header

* Lengthy line formatting fix

* Fix for unterminated f-string

* Remove unwanted imports

* File formatting fixes

* File formatting fixes

* Fix for noqa

* Docstring for copy file contents function

* Fix for closing the connection in post processing

* Changes to fix pipeline issues

* Version changes

* Merge conflict fixes

* Environment version fixes
Authored by visahan-24 on 2024-11-04 18:41:37 +05:30; committed via GitHub
Parent: 4f983518a6
Commit: 04eff530e2
No key matching this signature was found
GPG key ID: B5690EEEBB952194
29 changed files: 2745 additions and 49 deletions

@ -0,0 +1,61 @@
# OSS Distillation Batch Score Data Generation Pipeline
This component generates data from a teacher model endpoint by invoking it in batch mode. It is part of the OSS Distillation pipeline.
## Component Details
- **Name**: `oss_distillation_batchscoring_datagen_pipeline`
- **Version**: `0.0.1`
- **Type**: `pipeline`
- **Display Name**: `OSS Distillation Batch Score Data Generation Pipeline`
- **Description**: Component to generate data from teacher model endpoint by invoking it in batch.
## Inputs
| Name | Type | Optional | Default | Description |
|----------------------------------|----------|----------|----------------------------------|-------------------------------------------------------------------------------------------------------------|
| instance_type_pipeline_validation| string | True | | Instance type to be used for validation component. The parameter compute_pipeline_validation must be set to 'serverless' for instance_type to be used. |
| instance_type_data_generation    | string   | True     | Standard_D4as_v4                 | Instance type to be used for the data generation component in case of virtual cluster compute. |
| instance_type_data_import | string | True | Singularity.ND96amrs_A100_v4 | Instance type to be used for data_import component in case of virtual cluster compute. |
| instance_type_finetune | string | True | Singularity.ND96amrs_A100_v4 | Instance type to be used for finetune component in case of virtual cluster compute. |
| compute_pipeline_validation | string | True | serverless | Compute to be used for validation component. |
| compute_data_generation          | string   | True     | serverless                       | Compute to be used for data generation. |
| compute_data_import              | string   | True     | serverless                       | Compute to be used for data import. |
| compute_finetune | string | True | serverless | Compute to be used for finetune. |
| train_file_path | uri_file | False | | Path to the registered training data asset. |
| validation_file_path | uri_file | True | | Path to the registered validation data asset. |
| teacher_model_endpoint_url | string | True | | Teacher model endpoint URL. |
| teacher_model_asset_id | string | True | | Teacher model Asset Id. |
| teacher_model_endpoint_name | string | True | | Teacher model endpoint name. |
| teacher_model_max_new_tokens | integer | True | 128 | Teacher model max_new_tokens inference parameter. |
| teacher_model_temperature | number | True | 0.2 | Teacher model temperature inference parameter. |
| teacher_model_top_p | number | True | 0.1 | Teacher model top_p inference parameter. |
| teacher_model_frequency_penalty | number | True | 0.0 | Teacher model frequency penalty inference parameter. |
| teacher_model_presence_penalty | number | True | 0.0 | Teacher model presence penalty inference parameter. |
| teacher_model_stop | string | True | | Teacher model stop inference parameter. |
| min_endpoint_success_ratio | number | True | 0.7 | Minimum value of (successful_requests / total_requests) required for classifying inference as successful. |
| enable_chain_of_thought | string | True | false | Enable Chain of thought for data generation. |
| enable_chain_of_density | string | True | false | Enable Chain of density for text summarization. |
| max_len_summary | integer | True | 80 | Maximum Length Summary for text summarization. |
| data_generation_task_type | string | False | | Data generation task type. Supported values: NLI, CONVERSATION, NLU_QA, MATH, SUMMARIZATION. |
| num_train_epochs | integer | True | 1 | Training epochs. |
| per_device_train_batch_size | integer | True | 1 | Train batch size. |
| learning_rate | number | True | 3e-04 | Start learning rate. |
| authentication_type | string | False | azureml_workspace_connection | Authentication type for endpoint. Supported values: azureml_workspace_connection, managed_identity. |
| connection_name | string | True | | Connection name to be used for authentication. |
| additional_headers | string | True | | JSON serialized string expressing additional headers to be added to each request. |
| debug_mode | boolean | False | False | Enable debug mode to print all the debug logs in the score step. |
| ensure_ascii | boolean | False | False | If set to true, the output is guaranteed to have all incoming non-ASCII characters escaped. |
| max_retry_time_interval | integer | True | | The maximum time (in seconds) spent retrying a payload. |
| initial_worker_count | integer | False | 5 | The initial number of workers to use for scoring. |
| max_worker_count | integer | False | 200 | Overrides `initial_worker_count` if necessary. |
| instance_count | integer | False | 1 | Number of nodes in a compute cluster we will run the batch score step on. |
| max_concurrency_per_instance | integer | False | 1 | Number of processes that will be run concurrently on any given node. |
| mini_batch_size | string | True | 100KB | The mini batch size for parallel run. |
## Outputs
| Name | Type | Description |
|----------------------------------|----------|-------------------------------------------------------------------------------------------------------------|
| generated_batch_train_file_path | uri_file | Generated train data |
| generated_batch_validation_file_path | uri_file | Generated validation data |
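
For reference, below is a minimal sketch of how this pipeline component could be consumed from the azure-ai-ml Python SDK. The workspace and registry identifiers, data asset names, and endpoint values are placeholders, and the exact set of required inputs (for example `validation_info`, which the parent OSS Distillation pipeline wires in from its validation step) should be taken from the spec of the component version you actually resolve.

```python
# Hedged sketch: wiring oss_distillation_batchscoring_datagen_pipeline into a pipeline
# job with the azure-ai-ml SDK. All identifiers and paths below are placeholders.
from azure.ai.ml import Input, MLClient
from azure.ai.ml.dsl import pipeline
from azure.identity import DefaultAzureCredential

credential = DefaultAzureCredential()
ml_client = MLClient(
    credential,
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<workspace>",
)
# Assumes the component is resolvable from the azureml registry once published.
registry_client = MLClient(credential, registry_name="azureml")
datagen = registry_client.components.get(
    name="oss_distillation_batchscoring_datagen_pipeline", version="0.0.1"
)


@pipeline()
def batch_datagen(train_data: Input, validation_data: Input, validation_info: Input):
    node = datagen(
        train_file_path=train_data,
        validation_file_path=validation_data,
        validation_info=validation_info,
        teacher_model_endpoint_url="<teacher-endpoint-url>",
        teacher_model_endpoint_name="<teacher-endpoint-name>",
        data_generation_task_type="NLI",
    )
    return {
        "train": node.outputs.generated_batch_train_file_path,
        "validation": node.outputs.generated_batch_validation_file_path,
    }


job = batch_datagen(
    train_data=Input(type="uri_file", path="azureml:<train-data-asset>:1"),
    validation_data=Input(type="uri_file", path="azureml:<validation-data-asset>:1"),
    validation_info=Input(type="uri_file", path="azureml:<validation-status-file>:1"),
)
ml_client.jobs.create_or_update(job, experiment_name="oss-distillation-batch-datagen")
```

In the full `oss_distillation_pipeline`, the same component is invoked as a conditional sub-pipeline, as shown in the spec changes at the end of this diff.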

@ -0,0 +1,3 @@
type: component
spec: spec.yaml
categories: ["Foundational Models", "Finetune", "Distillation","Batch Scoring"]

@ -0,0 +1,417 @@
$schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json
name: oss_distillation_batchscoring_datagen_pipeline
version: 0.0.1
type: pipeline
display_name: OSS Distillation Batch Score Data Generation Pipeline
description: Component to generate data from teacher model endpoint by invoking it in batch.
inputs:
# Compute parameters
instance_type_pipeline_validation:
type: string
optional: True
description: Instance type to be used for validation component. The parameter compute_pipeline_validation must be set to 'serverless' for instance_type to be used.
instance_type_data_generation:
type: string
optional: true
default: Standard_D4as_v4
description: Instance type to be used for finetune component in case of virtual cluster compute, eg. Singularity.ND40_v2. The parameter compute_finetune must be set to 'serverless' for instance_type to be used
instance_type_data_import:
type: string
optional: true
default: Singularity.ND96amrs_A100_v4
description: Instance type to be used for data_import component in case of virtual cluster compute, eg. Singularity.D8_v3. The parameter compute_data_import must be set to 'serverless' for instance_type to be used
instance_type_finetune:
type: string
optional: true
default: Singularity.ND96amrs_A100_v4
description: Instance type to be used for finetune component in case of virtual cluster compute, eg. Singularity.ND40_v2. The parameter compute_finetune must be set to 'serverless' for instance_type to be used
compute_pipeline_validation:
type: string
optional: True
default: 'serverless'
description: compute to be used for validation component
compute_data_generation:
type: string
optional: true
default: 'serverless'
description: >-
compute to be used for model_import eg. provide 'FT-Cluster' if
your compute is named 'FT-Cluster'. Special characters like \ and ' are invalid in the parameter value.
If compute cluster name is provided, instance_type field will be ignored and the respective cluster will be used
compute_data_import:
type: string
optional: true
default: 'serverless'
description: >-
compute to be used for model_import eg. provide 'FT-Cluster' if
your compute is named 'FT-Cluster'. Special characters like \ and ' are invalid in the parameter value.
If compute cluster name is provided, instance_type field will be ignored and the respective cluster will be used
compute_finetune:
type: string
optional: true
default: 'serverless'
description: >-
compute to be used for finetune eg. provide 'FT-Cluster' if your
compute is named 'FT-Cluster'. Special characters like \ and ' are invalid in the parameter value.
If compute cluster name is provided, instance_type field will be ignored and the respective cluster will be used
# ########################### Data Generator Component ########################### #
train_file_path:
type: uri_file
description: Path to the registered training data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
mode: rw_mount
validation_file_path:
type: uri_file
optional: true
description: Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
mode: rw_mount
teacher_model_endpoint_url:
type: string
optional: true
description: Teacher model endpoint URL
teacher_model_endpoint_name:
type: string
optional: true
description: Teacher model endpoint name
teacher_model_endpoint_key:
type: string
optional: true
description: Teacher model endpoint key
teacher_model_max_new_tokens:
type: integer
default: 128
description: Teacher model max_new_tokens inference parameter
teacher_model_temperature:
type: number
default: 0.2
description: Teacher model temperature inference parameter
teacher_model_top_p:
type: number
default: 0.1
description: Teacher model top_p inference parameter
teacher_model_frequency_penalty:
type: number
default: 0.0
description: Teacher model frequency penalty inference parameter
teacher_model_presence_penalty:
type: number
default: 0.0
description: Teacher model presence penalty inference parameter
teacher_model_stop:
type: string
optional: true
description: Teacher model stop inference parameter
min_endpoint_success_ratio:
type: number
default: 0.7
description: >
The minimum value of (successful_requests / total_requests) required for classifying inference as successful.
If (successful_requests / total_requests) < min_endpoint_success_ratio, the experiment will be marked as failed.
By default it is 0.7 (0 means all requests are allowed to fail while 1 means no request should fail.)
enable_chain_of_thought:
type: string
optional: true
default: "false"
description: Enable Chain of thought for data generation
enable_chain_of_density:
type: string
optional: true
default: "false"
description: Enable Chain of density for text summarization
max_len_summary:
type: integer
optional: true
default: 80
description: Maximum Length Summary for text summarization
data_generation_task_type:
type: string
enum:
- NLI
- CONVERSATION
- NLU_QA
- MATH
- SUMMARIZATION
description: >
Data generation task type. Supported values are:
1. NLI: Generate Natural Language Inference data
2. CONVERSATION: Generate conversational data (multi/single turn)
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
4. MATH: Generate Math data for numerical responses
5. SUMMARIZATION: Generate Key Summary for an Article
# Output of validation component.
validation_info:
type: uri_file
description: Validation status.
mode: rw_mount
# Training parameters
num_train_epochs:
type: integer
default: 1
optional: true
description: training epochs
per_device_train_batch_size:
type: integer
default: 1
optional: true
description: Train batch size
learning_rate:
type: number
default: 3e-04
optional: true
description: Start learning rate.
# ########################### Batch Score Component ########################### #
authentication_type:
type: string
optional: False
description: Authentication type for endpoint. Either `azureml_workspace_connection` or `managed_identity`.
default: azureml_workspace_connection
enum:
- azureml_workspace_connection
- managed_identity
configuration_file:
type: string
optional: true
description: Config file path that contains deployment configurations
additional_headers:
type: string
optional: True
description: JSON serialized string expressing additional headers to be added to each request.
debug_mode:
type: boolean
optional: False
default: False
description: Enable debug mode to print all the debug logs in the score step.
ensure_ascii:
type: boolean
optional: False
default: False
description: If set to true, the output is guaranteed to have all incoming non-ASCII characters escaped. If set to false, these characters will be output as-is. More detailed information can be found at https://docs.python.org/3/library/json.html
max_retry_time_interval:
type: integer
optional: True
description: The maximum time (in seconds) spent retrying a payload. If unspecified, payloads are retried for unlimited time.
initial_worker_count:
type: integer
optional: False
default: 5
description: The initial number of workers to use for scoring.
max_worker_count:
type: integer
optional: False
default: 200
description: Overrides `initial_worker_count` if necessary.
instance_count:
type: integer
default: 1
description: Number of nodes in a compute cluster we will run the batch score step on.
max_concurrency_per_instance:
type: integer
default: 1
description: Number of processes that will be run concurrently on any given node. This number should not be larger than 1/2 of the number of cores in an individual node in the specified cluster.
mini_batch_size:
type: string
optional: true
default: 100KB
description: The mini batch size for parallel run.
outputs:
generated_batch_train_file_path:
type: uri_file
description: Generated train data
mode: rw_mount
generated_batch_validation_file_path:
type: uri_file
description: Generated validation data
mode: rw_mount
jobs:
oss_distillation_generate_data_batch_preprocess:
type: command
component: azureml:oss_distillation_generate_data_batch_preprocess:0.0.1
compute: '${{parent.inputs.compute_data_generation}}'
resources:
instance_type: '${{parent.inputs.instance_type_data_generation}}'
identity:
type: user_identity
inputs:
train_file_path: '${{parent.inputs.train_file_path}}'
validation_file_path: '${{parent.inputs.validation_file_path}}'
teacher_model_endpoint_name: '${{parent.inputs.teacher_model_endpoint_name}}'
teacher_model_endpoint_url: '${{parent.inputs.teacher_model_endpoint_url}}'
teacher_model_endpoint_key: '${{parent.inputs.teacher_model_endpoint_key}}'
enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
enable_chain_of_density: '${{parent.inputs.enable_chain_of_density}}'
max_len_summary: '${{parent.inputs.max_len_summary}}'
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
teacher_model_max_new_tokens: '${{parent.inputs.teacher_model_max_new_tokens}}'
teacher_model_temperature: '${{parent.inputs.teacher_model_temperature}}'
teacher_model_top_p: '${{parent.inputs.teacher_model_top_p}}'
teacher_model_frequency_penalty: '${{parent.inputs.teacher_model_frequency_penalty}}'
teacher_model_presence_penalty: '${{parent.inputs.teacher_model_presence_penalty}}'
validation_info: '${{parent.inputs.validation_info}}'
outputs:
generated_train_payload_path:
type: mltable
generated_validation_payload_path:
type: mltable
hash_train_data:
type: uri_file
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.jsonl
hash_validation_data:
type: uri_file
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.jsonl
batch_config_connection:
type: uri_file
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.jsonl
# Config generator job
oss_distillation_generate_data_config_generator:
type: command
component: azureml:batch_benchmark_config_generator:0.0.8
compute: '${{parent.inputs.compute_pipeline_validation}}'
resources:
instance_type: '${{parent.inputs.instance_type_pipeline_validation}}'
identity:
type: user_identity
inputs:
scoring_url: ${{parent.inputs.teacher_model_endpoint_url}}
deployment_name: ${{parent.inputs.teacher_model_endpoint_name}}
authentication_type: ${{parent.inputs.authentication_type}}
configuration_file: ${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.batch_config_connection}}
additional_headers: ${{parent.inputs.additional_headers}}
debug_mode: ${{parent.inputs.debug_mode}}
ensure_ascii: ${{parent.inputs.ensure_ascii}}
max_retry_time_interval: ${{parent.inputs.max_retry_time_interval}}
initial_worker_count: ${{parent.inputs.initial_worker_count}}
max_worker_count: ${{parent.inputs.max_worker_count}}
model_type: oss
outputs:
batch_score_config:
type: uri_file
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.json
# Batch score job
oss_distillation_train_data_batch_score:
type: parallel
component: azureml:batch_score_oss:0.0.1
compute: '${{parent.inputs.compute_data_generation}}'
identity:
type: user_identity
inputs:
async_mode: False
data_input_table: ${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.generated_train_payload_path}}
configuration_file: ${{parent.jobs.oss_distillation_generate_data_config_generator.outputs.batch_score_config}}
outputs:
job_output_path:
type: uri_file
mini_batch_results_output_directory:
type: uri_folder
resources:
instance_count: ${{parent.inputs.instance_count}}
max_concurrency_per_instance: ${{parent.inputs.max_concurrency_per_instance}}
mini_batch_size: ${{parent.inputs.mini_batch_size}}
retry_settings:
timeout: 6000
max_retries: 10
environment_variables:
BATCH_SCORE_INITIAL_REQUEST_TIMEOUT: '180'
BATCH_SCORE_DELAY_AFTER_SUCCESSFUL_REQUEST: 'False'
BATCH_SCORE_MAX_REQUEST_TIMEOUT: '300'
validation_file_path_exists:
type: command
component: azureml:oss_distillation_data_generation_validation_file_checker:0.0.1
compute: '${{parent.inputs.compute_pipeline_validation}}'
resources:
instance_type: '${{parent.inputs.instance_type_pipeline_validation}}'
identity:
type: user_identity
inputs:
validation_file_path: '${{parent.inputs.validation_file_path}}'
validation_succeeded:
type: if_else
condition: ${{parent.jobs.validation_file_path_exists.outputs.output}}
true_block: ${{parent.jobs.oss_distillation_validation_data_batch_score}}
# Batch score job
oss_distillation_validation_data_batch_score:
type: parallel
component: azureml:batch_score_oss:0.0.1
compute: '${{parent.inputs.compute_data_generation}}'
identity:
type: user_identity
inputs:
async_mode: False
data_input_table: ${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.generated_validation_payload_path}}
configuration_file: ${{parent.jobs.oss_distillation_generate_data_config_generator.outputs.batch_score_config}}
outputs:
job_output_path:
type: uri_file
mini_batch_results_output_directory:
type: uri_folder
resources:
instance_count: ${{parent.inputs.instance_count}}
max_concurrency_per_instance: ${{parent.inputs.max_concurrency_per_instance}}
mini_batch_size: ${{parent.inputs.mini_batch_size}}
retry_settings:
timeout: 6000
max_retries: 10
environment_variables:
BATCH_SCORE_INITIAL_REQUEST_TIMEOUT: '180'
BATCH_SCORE_DELAY_AFTER_SUCCESSFUL_REQUEST: 'False'
BATCH_SCORE_MAX_REQUEST_TIMEOUT: '300'
oss_distillation_generate_data_batch_postprocess:
type: command
component: azureml:oss_distillation_generate_data_batch_postprocess:0.0.1
compute: '${{parent.inputs.compute_data_generation}}'
resources:
instance_type: '${{parent.inputs.instance_type_data_generation}}'
identity:
type: user_identity
inputs:
train_file_path: '${{parent.inputs.train_file_path}}'
validation_file_path: '${{parent.inputs.validation_file_path}}'
batch_score_train_result: '${{parent.jobs.oss_distillation_train_data_batch_score.outputs.mini_batch_results_output_directory}}'
batch_score_validation_result: '${{parent.jobs.oss_distillation_validation_data_batch_score.outputs.mini_batch_results_output_directory}}'
hash_train_data: '${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.hash_train_data}}'
hash_validation_data: '${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.hash_validation_data}}'
enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
enable_chain_of_density: '${{parent.inputs.enable_chain_of_density}}'
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
connection_config_file: ${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.batch_config_connection}}
outputs:
generated_batch_train_file_path: '${{parent.outputs.generated_batch_train_file_path}}'
generated_batch_validation_file_path: '${{parent.outputs.generated_batch_validation_file_path}}'

@ -0,0 +1,39 @@
# OSS Distillation Generate Data Batch Scoring Preprocess
## Description
This component prepares data to invoke the teacher model endpoint in batch. It supports various data formats such as `jsonl`, `json`, `csv`, `tsv`, and `parquet`.
## Environment
The component uses the following environment:
- `azureml://registries/azureml/environments/acft-hf-nlp-gpu/labels/latest`
## Inputs
The component accepts the following inputs:
| Name | Type | Description | Required | Default |
|-------------------------------|-----------|-------------------------------------------------------------------------------------------------|----------|---------|
| `train_file_path` | uri_file | Path to the registered training data asset. Supported formats: `jsonl`, `json`, `csv`, `tsv`, `parquet`. | Yes | |
| `validation_file_path` | uri_file | Path to the registered validation data asset. Supported formats: `jsonl`, `json`, `csv`, `tsv`, `parquet`. | No | |
| `teacher_model_endpoint_url` | string | The URL of the teacher model endpoint. | Yes | |
| `teacher_model_asset_id` | string | The asset ID of the teacher model. | Yes | |
| `teacher_model_max_new_tokens`| integer | Teacher model max_new_tokens inference parameter. | Yes | 128 |
| `teacher_model_temperature` | number | Teacher model temperature inference parameter. | Yes | 0.2 |
| `teacher_model_top_p` | number | Teacher model top_p inference parameter. | Yes | 0.1 |
| `teacher_model_frequency_penalty` | number | Teacher model frequency penalty inference parameter. | Yes | 0.0 |
| `teacher_model_presence_penalty` | number | Teacher model presence penalty inference parameter. | Yes | 0.0 |
| `teacher_model_stop` | string | Teacher model stop inference parameter. | No | |
| `enable_chain_of_thought` | string | Enable Chain of thought for data generation. | No | "false" |
| `enable_chain_of_density` | string | Enable Chain of density for text summarization. | No | "false" |
| `max_len_summary` | integer | Maximum Length Summary for text summarization. | No | 80 |
| `data_generation_task_type` | string | Specifies the type of data generation task. Supported values: `NLI`, `CONVERSATION`, `NLU_QA`, `MATH`, `SUMMARIZATION`. | Yes | |
| `validation_output` | uri_file | Validation status. | Yes | |
## Outputs
The component produces the following outputs:
| Name | Type | Description |
|----------------------------------|-----------|---------------------------------------------------------------|
| `generated_train_payload_path` | mltable | Directory containing the payload to be sent to the model. |
| `generated_validation_payload_path` | mltable | Directory containing the payload to be sent to the model. |
| `hash_train_data` | uri_file | JSONL file containing the hash for each payload. |
| `hash_validation_data` | uri_file | JSONL file containing the hash for each payload. |
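
The outputs above include `hash_train_data` and `hash_validation_data`, JSONL files with one hash per payload, which the downstream postprocess component consumes alongside the batch-score results. The following is a minimal sketch of producing such a file; it assumes a SHA-256 over each canonically serialized record, which is an illustration rather than the exact scheme implemented in `generate_data_preprocess.py`.

```python
# Illustrative only: one way to emit a "hash per payload" JSONL file such as
# hash_train_data. The actual hashing scheme used by generate_data_preprocess.py
# may differ (field selection, normalization, algorithm).
import hashlib
import json
from pathlib import Path


def write_payload_hashes(payload_jsonl: Path, hash_jsonl: Path) -> None:
    """Write one JSON line per payload record containing its content hash."""
    with payload_jsonl.open() as src, hash_jsonl.open("w") as dst:
        for idx, line in enumerate(src):
            record = json.loads(line)
            # Canonical serialization so logically equal records hash identically.
            canonical = json.dumps(record, sort_keys=True, ensure_ascii=False)
            digest = hashlib.sha256(canonical.encode("utf-8")).hexdigest()
            dst.write(json.dumps({"index": idx, "hash": digest}) + "\n")


if __name__ == "__main__":
    write_payload_hashes(Path("train_payload.jsonl"), Path("hash_train_data.jsonl"))
```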

@ -0,0 +1,3 @@
type: component
spec: spec.yaml
categories: ["Foundational Models", "Finetune", "Distillation","Batch Scoring Postprocess"]

@ -0,0 +1,110 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
name: oss_distillation_generate_data_batch_postprocess
version: 0.0.1
type: command
is_deterministic: False
display_name: OSS Distillation Generate Data Postprocess Batch Scoring
description: Component to prepare data returned from the teacher model endpoint in batch
environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/76
inputs:
# Inputs
train_file_path:
type: uri_file
description: Path to the registered training data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
mode: rw_mount
validation_file_path:
type: uri_file
optional: true
description: Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
mode: rw_mount
hash_train_data:
type: uri_file
optional: false
description: jsonl file containing the hash for each payload.
hash_validation_data:
type: uri_file
optional: true
description: jsonl file containing the hash for each payload.
batch_score_train_result:
type: uri_folder
description: Path to the directory containing jsonl file(s) that have the result for each payload.
batch_score_validation_result:
type: uri_folder
optional: true
description: Path to the directory containing jsonl file(s) that have the result for each payload.
min_endpoint_success_ratio:
type: number
default: 0.7
description: >
The minimum value of (successful_requests / total_requests) required for classifying inference as successful.
If (successful_requests / total_requests) < min_endpoint_success_ratio, the experiment will be marked as failed.
By default it is 0.7 (0 means all requests are allowed to fail while 1 means no request should fail.)
enable_chain_of_thought:
type: string
optional: true
default: "false"
description: Enable Chain of thought for data generation
enable_chain_of_density:
type: string
optional: true
default: "false"
description: Enable Chain of density for text summarization
data_generation_task_type:
type: string
enum:
- NLI
- CONVERSATION
- NLU_QA
- MATH
- SUMMARIZATION
description: >
Data generation task type. Supported values are:
1. NLI: Generate Natural Language Inference data
2. CONVERSATION: Generate conversational data (multi/single turn)
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
4. MATH: Generate Math data for numerical responses
5. SUMMARIZATION: Generate Key Summary for an Article
connection_config_file:
type: uri_file
description: Connection config file for batch scoring
outputs:
generated_batch_train_file_path:
type: uri_file
description: Generated train data
mode: rw_mount
generated_batch_validation_file_path:
type: uri_file
description: Generated validation data
mode: rw_mount
code: ../../src
command: >-
python generate_data_postprocess.py
--train_file_path ${{inputs.train_file_path}}
$[[--validation_file_path ${{inputs.validation_file_path}}]]
--hash_train_data ${{inputs.hash_train_data}}
$[[--hash_validation_data ${{inputs.hash_validation_data}}]]
--batch_score_train_result ${{inputs.batch_score_train_result}}
$[[--batch_score_validation_result ${{inputs.batch_score_validation_result}}]]
--min_endpoint_success_ratio ${{inputs.min_endpoint_success_ratio}}
$[[--enable_chain_of_thought ${{inputs.enable_chain_of_thought}}]]
$[[--enable_chain_of_density ${{inputs.enable_chain_of_density}}]]
--data_generation_task_type ${{inputs.data_generation_task_type}}
--connection_config_file ${{inputs.connection_config_file}}
--generated_batch_train_file_path ${{outputs.generated_batch_train_file_path}}
--generated_batch_validation_file_path ${{outputs.generated_batch_validation_file_path}}
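
The `min_endpoint_success_ratio` input above gates whether the run is considered successful: if successful_requests / total_requests falls below it, the experiment is marked as failed. A minimal sketch of such a gate over the mini-batch result files is shown below; it assumes each result row is a JSON object with a `status` field, which is an illustrative schema rather than the one actually parsed by `generate_data_postprocess.py`.

```python
# Illustrative gate for min_endpoint_success_ratio. Assumes each batch-score
# result row is a JSON object with a "status" field; the real result schema
# consumed by generate_data_postprocess.py may differ.
import json
from pathlib import Path


def endpoint_success_ratio(result_dir: Path) -> float:
    """Compute successful_requests / total_requests across result JSONL files."""
    total = succeeded = 0
    for jsonl_file in result_dir.glob("*.jsonl"):
        with jsonl_file.open() as fh:
            for line in fh:
                row = json.loads(line)
                total += 1
                if str(row.get("status", "")).lower() == "success":
                    succeeded += 1
    return succeeded / total if total else 0.0


def enforce_min_success_ratio(result_dir: Path, min_ratio: float = 0.7) -> None:
    """Raise when the observed success ratio is below the configured minimum."""
    ratio = endpoint_success_ratio(result_dir)
    if ratio < min_ratio:
        raise RuntimeError(
            f"Endpoint success ratio {ratio:.2f} is below the required {min_ratio:.2f}; "
            "marking the data generation run as failed."
        )
```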

@ -0,0 +1,39 @@
# OSS Distillation Generate Data Batch Scoring Preprocess
## Description
This component prepares data to invoke the teacher model endpoint in batch. It supports various data formats such as `jsonl`, `json`, `csv`, `tsv`, and `parquet`.
## Environment
The component uses the following environment:
- `azureml://registries/azureml/environments/acft-hf-nlp-gpu/labels/latest`
## Inputs
The component accepts the following inputs:
| Name | Type | Description | Required | Default |
|-------------------------------|-----------|-------------------------------------------------------------------------------------------------|----------|---------|
| `train_file_path` | uri_file | Path to the registered training data asset. Supported formats: `jsonl`, `json`, `csv`, `tsv`, `parquet`. | Yes | |
| `validation_file_path` | uri_file | Path to the registered validation data asset. Supported formats: `jsonl`, `json`, `csv`, `tsv`, `parquet`. | No | |
| `teacher_model_endpoint_url` | string | The URL of the teacher model endpoint. | Yes | |
| `teacher_model_asset_id` | string | The asset ID of the teacher model. | Yes | |
| `teacher_model_max_new_tokens`| integer | Teacher model max_new_tokens inference parameter. | Yes | 128 |
| `teacher_model_temperature` | number | Teacher model temperature inference parameter. | Yes | 0.2 |
| `teacher_model_top_p` | number | Teacher model top_p inference parameter. | Yes | 0.1 |
| `teacher_model_frequency_penalty` | number | Teacher model frequency penalty inference parameter. | Yes | 0.0 |
| `teacher_model_presence_penalty` | number | Teacher model presence penalty inference parameter. | Yes | 0.0 |
| `teacher_model_stop` | string | Teacher model stop inference parameter. | No | |
| `enable_chain_of_thought` | string | Enable Chain of thought for data generation. | No | "false" |
| `enable_chain_of_density` | string | Enable Chain of density for text summarization. | No | "false" |
| `max_len_summary` | integer | Maximum Length Summary for text summarization. | No | 80 |
| `data_generation_task_type` | string | Specifies the type of data generation task. Supported values: `NLI`, `CONVERSATION`, `NLU_QA`, `MATH`, `SUMMARIZATION`. | Yes | |
| `validation_output` | uri_file | Validation status. | Yes | |
## Outputs
The component produces the following outputs:
| Name | Type | Description |
|----------------------------------|-----------|---------------------------------------------------------------|
| `generated_train_payload_path` | mltable | Directory containing the payload to be sent to the model. |
| `generated_validation_payload_path` | mltable | Directory containing the payload to be sent to the model. |
| `hash_train_data` | uri_file | JSONL file containing the hash for each payload. |
| `hash_validation_data` | uri_file | JSONL file containing the hash for each payload. |
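
The preprocess component packages each training record together with the teacher-model inference parameters into payload rows that the `batch_score_oss` component sends to the endpoint. Below is a hedged illustration of what one such request body might look like, assuming a chat-completions style schema; the exact payload format is defined by the batch scoring component, not by this README.

```python
# Illustrative only: building a chat-completions style request body from the
# inference parameters documented above. The exact payload schema expected by
# the batch_score_oss component is defined by that component and may differ.
import json


def build_payload_row(system_prompt: str, user_prompt: str) -> str:
    body = {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        # Values mirror the component defaults documented above.
        "max_new_tokens": 128,
        "temperature": 0.2,
        "top_p": 0.1,
        "frequency_penalty": 0.0,
        "presence_penalty": 0.0,
    }
    return json.dumps(body, ensure_ascii=False)


print(build_payload_row("You are a careful NLI annotator.", "Premise: ... Hypothesis: ..."))
```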

@ -0,0 +1,3 @@
type: component
spec: spec.yaml
categories: ["Foundational Models", "Finetune", "Distillation","Batch Scoring Preprocess"]

@ -0,0 +1,152 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
name: oss_distillation_generate_data_batch_preprocess
version: 0.0.1
type: command
is_deterministic: False
display_name: OSS Distillation Generate Data Batch Scoring Preprocess
description: Component to prepare data to invoke the teacher model endpoint in batch
environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/76
inputs:
# Inputs
train_file_path:
type: uri_file
description: Path to the registered training data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
mode: rw_mount
validation_file_path:
type: uri_file
optional: true
description: Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
mode: rw_mount
teacher_model_endpoint_name:
type: string
optional: true
description: Teacher model endpoint name
teacher_model_endpoint_url:
type: string
optional: true
description: Teacher model endpoint URL
teacher_model_endpoint_key:
type: string
optional: true
description: Teacher model endpoint key
teacher_model_max_new_tokens:
type: integer
default: 128
description: Teacher model max_new_tokens inference parameter
teacher_model_temperature:
type: number
default: 0.2
description: Teacher model temperature inference parameter
teacher_model_top_p:
type: number
default: 0.1
description: Teacher model top_p inference parameter
teacher_model_frequency_penalty:
type: number
default: 0.0
description: Teacher model frequency penalty inference parameter
teacher_model_presence_penalty:
type: number
default: 0.0
description: Teacher model presence penalty inference parameter
teacher_model_stop:
type: string
optional: true
description: Teacher model stop inference parameter
enable_chain_of_thought:
type: string
optional: true
default: "false"
description: Enable Chain of thought for data generation
enable_chain_of_density:
type: string
optional: true
default: "false"
description: Enable Chain of density for text summarization
max_len_summary:
type: integer
optional: true
default: 80
description: Maximum Length Summary for text summarization
data_generation_task_type:
type: string
enum:
- NLI
- CONVERSATION
- NLU_QA
- MATH
- SUMMARIZATION
description: >
Data generation task type. Supported values are:
1. NLI: Generate Natural Language Inference data
2. CONVERSATION: Generate conversational data (multi/single turn)
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
4. MATH: Generate Math data for numerical responses
5. SUMMARIZATION: Generate Key Summary for an Article
# Output of validation component.
validation_info:
type: uri_file
optional: true
description: Validation status.
mode: rw_mount
outputs:
generated_train_payload_path:
type: mltable
description: directory containing the payload to be sent to the model.
generated_validation_payload_path:
type: mltable
description: directory containing the payload to be sent to the model.
hash_train_data:
type: uri_file
description: jsonl file containing the hash for each payload.
hash_validation_data:
type: uri_file
description: jsonl file containing the hash for each payload.
batch_config_connection:
type: uri_file
description: Config file path that contains deployment configurations
code: ../../src
command: >-
python generate_data_preprocess.py
--train_file_path ${{inputs.train_file_path}}
$[[--validation_file_path ${{inputs.validation_file_path}}]]
$[[--teacher_model_endpoint_name ${{inputs.teacher_model_endpoint_name}}]]
$[[--teacher_model_endpoint_url ${{inputs.teacher_model_endpoint_url}}]]
$[[--teacher_model_endpoint_key ${{inputs.teacher_model_endpoint_key}}]]
--teacher_model_max_new_tokens ${{inputs.teacher_model_max_new_tokens}}
--teacher_model_temperature ${{inputs.teacher_model_temperature}}
--teacher_model_top_p ${{inputs.teacher_model_top_p}}
--teacher_model_frequency_penalty ${{inputs.teacher_model_frequency_penalty}}
--teacher_model_presence_penalty ${{inputs.teacher_model_presence_penalty}}
$[[--teacher_model_stop ${{inputs.teacher_model_stop}}]]
$[[--enable_chain_of_thought ${{inputs.enable_chain_of_thought}}]]
$[[--enable_chain_of_density ${{inputs.enable_chain_of_density}}]]
$[[--max_len_summary ${{inputs.max_len_summary}}]]
--data_generation_task_type ${{inputs.data_generation_task_type}}
--generated_train_payload_path ${{outputs.generated_train_payload_path}}
--generated_validation_payload_path ${{outputs.generated_validation_payload_path}}
--hash_train_data ${{outputs.hash_train_data}}
--hash_validation_data ${{outputs.hash_validation_data}}
--batch_config_connection ${{outputs.batch_config_connection}}
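
Note that the arguments wrapped in `$[[...]]` above are only appended to the command when the corresponding optional input is supplied, so `generate_data_preprocess.py` must tolerate their absence. A minimal argparse sketch of that pattern follows; the flag names mirror the command above, while the handling shown is illustrative only.

```python
# Minimal sketch of parsing a command line where optional $[[...]] arguments may
# be absent entirely. Flag names follow the component command above; the real
# generate_data_preprocess.py accepts many more arguments.
import argparse


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="OSS distillation batch preprocess (sketch)")
    parser.add_argument("--train_file_path", required=True)
    # Optional inputs: only present on the command line when the pipeline supplies them.
    parser.add_argument("--validation_file_path", default=None)
    parser.add_argument("--teacher_model_endpoint_name", default=None)
    parser.add_argument("--teacher_model_endpoint_url", default=None)
    parser.add_argument("--teacher_model_stop", default=None)
    parser.add_argument("--teacher_model_max_new_tokens", type=int, default=128)
    args = parser.parse_args()
    if args.validation_file_path is None:
        print("No validation file provided; only train payloads will be generated.")
    return args


if __name__ == "__main__":
    parse_args()
```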

@ -0,0 +1,23 @@
# OSS Distillation Generate Data Path Selection
## Description
This component selects the path for data generation based on the task type. It supports various data generation tasks such as Natural Language Inference (NLI), Conversation, Natural Language Understanding for Question Answering (NLU_QA), Math, and Summarization.
## Environment
The component uses the following environment:
- `azureml://registries/azureml/environments/model-evaluation/labels/latest`
## Inputs
The component accepts the following inputs:
- `data_generation_task_type` (string): Specifies the type of data generation task. Supported values are:
- `NLI`: Generate Natural Language Inference data
- `CONVERSATION`: Generate conversational data (multi/single turn)
- `NLU_QA`: Generate Natural Language Understanding data for Question Answering data
- `MATH`: Generate Math data for numerical responses
- `SUMMARIZATION`: Generate Key Summary for an Article
## Outputs
The component produces the following output:
- `output` (boolean): A control output indicating the success or failure of the data path selection.

@ -0,0 +1,3 @@
type: component
spec: spec.yaml
categories: ["Foundational Models", "Finetune", "Distillation","Batch or Sequence Scoring Selector"]

@ -0,0 +1,41 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
name: oss_distillation_data_generation_batch_scoring_selector
version: 0.0.1
type: command
is_deterministic: True
display_name: OSS Distillation Batch Scoring Selector Component
description: Component to select between batch and sequential scoring based on the data generation task type
environment: azureml://registries/azureml/environments/model-evaluation/labels/latest
inputs:
# Inputs
data_generation_task_type:
type: string
enum:
- NLI
- CONVERSATION
- NLU_QA
- MATH
- SUMMARIZATION
description: >
Data generation task type. Supported values are:
1. NLI: Generate Natural Language Inference data
2. CONVERSATION: Generate conversational data (multi/single turn)
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
4. MATH: Generate Math data for numerical responses
5. SUMMARIZATION: Generate Key Summary for an Article
outputs:
output:
type: boolean
is_control: true
code: ../../src
command: >-
mldesigner execute --source generate_data_batch_scoring_selection.py --name validate
--inputs data_generation_task_type=${{inputs.data_generation_task_type}}
--outputs output='${{outputs.output}}'
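
The command runs `mldesigner execute` against a function named `validate` in `generate_data_batch_scoring_selection.py`, which emits the boolean control output declared above. A hedged sketch of such a function is shown below, assuming the `mldesigner` `command_component` pattern with a control output; the set of task types routed to batch scoring is a placeholder, since the real decision rule lives in the actual script.

```python
# Sketch of a boolean control-output component in the style invoked by the
# command above. The real decision rule lives in
# generate_data_batch_scoring_selection.py; the set below is a placeholder.
from mldesigner import command_component, Output

# Placeholder: which task types take the batch-scoring path is decided by the
# actual component, not by this sketch.
BATCH_SCORING_TASK_TYPES = {"NLI", "NLU_QA", "MATH", "SUMMARIZATION", "CONVERSATION"}


@command_component(name="validate")
def validate(data_generation_task_type: str) -> Output(type="boolean", is_control=True):
    """Return True when the task type should be routed to batch scoring."""
    return data_generation_task_type.upper() in BATCH_SCORING_TASK_TYPES
```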

@ -0,0 +1,23 @@
# OSS Distillation Generate Data Path Selection
## Description
This component selects the path for data generation based on the task type. It supports various data generation tasks such as Natural Language Inference (NLI), Conversation, Natural Language Understanding for Question Answering (NLU_QA), Math, and Summarization.
## Environment
The component uses the following environment:
- `azureml://registries/azureml/environments/model-evaluation/labels/latest`
## Inputs
The component accepts the following inputs:
- `data_generation_task_type` (string): Specifies the type of data generation task. Supported values are:
- `NLI`: Generate Natural Language Inference data
- `CONVERSATION`: Generate conversational data (multi/single turn)
- `NLU_QA`: Generate Natural Language Understanding data for Question Answering data
- `MATH`: Generate Math data for numerical responses
- `SUMMARIZATION`: Generate Key Summary for an Article
## Outputs
The component produces the following output:
- `output` (boolean): A control output indicating the success or failure of the data path selection.

@ -0,0 +1,3 @@
type: component
spec: spec.yaml
categories: ["Foundational Models", "Finetune", "Distillation"]

@ -0,0 +1,54 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
name: oss_distillation_data_generation_file_selector
version: 0.0.1
type: command
is_deterministic: True
tags:
codegenBy: dsl.condition_output
display_name: OSS Distillation Fine-Tuning Input File Selector
description: Component to select the fine-tuning input files based on the scoring-path condition
environment: azureml://registries/azureml/environments/model-evaluation/labels/latest
inputs:
generated_batch_train_file_path:
type: uri_folder
optional: true
mode: ro_mount
generated_batch_validation_file_path:
type: uri_folder
optional: true
mode: ro_mount
generated_train_file_path:
type: uri_folder
optional: true
mode: ro_mount
generated_validation_file_path:
type: uri_folder
optional: true
mode: ro_mount
condition:
type: boolean
default: false
optional: false
outputs:
ft_input_train_file_path:
type: uri_file
mode: rw_mount
ft_input_validation_file_path:
type: uri_file
mode: rw_mount
code: ../../src
command: >-
python dsl_condition_output.py
$[[--generated_batch_train_file_path ${{inputs.generated_batch_train_file_path}}]]
$[[--generated_batch_validation_file_path ${{inputs.generated_batch_validation_file_path}}]]
$[[--generated_train_file_path ${{inputs.generated_train_file_path}}]]
$[[--generated_validation_file_path ${{inputs.generated_validation_file_path}}]]
--condition ${{inputs.condition}}
--ft_input_train_file_path ${{outputs.ft_input_train_file_path}}
--ft_input_validation_file_path ${{outputs.ft_input_validation_file_path}}
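
Based on the boolean `condition`, this component forwards either the batch-scoring outputs or the sequentially generated ones to the fine-tuning input paths. The sketch below captures that selection logic under simplifying assumptions (plain file copies, condition passed as a string); the real `dsl_condition_output.py` may handle folder inputs and edge cases differently.

```python
# Illustrative version of the condition-driven selection performed by
# dsl_condition_output.py: pick batch-scoring outputs when condition is True,
# otherwise the sequentially generated ones, and copy them to the fine-tuning
# input paths.
import argparse
import shutil


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--generated_batch_train_file_path", default=None)
    parser.add_argument("--generated_batch_validation_file_path", default=None)
    parser.add_argument("--generated_train_file_path", default=None)
    parser.add_argument("--generated_validation_file_path", default=None)
    parser.add_argument("--condition", required=True)
    parser.add_argument("--ft_input_train_file_path", required=True)
    parser.add_argument("--ft_input_validation_file_path", required=True)
    args = parser.parse_args()

    use_batch = str(args.condition).lower() == "true"
    train_src = args.generated_batch_train_file_path if use_batch else args.generated_train_file_path
    val_src = (
        args.generated_batch_validation_file_path if use_batch else args.generated_validation_file_path
    )

    shutil.copy(train_src, args.ft_input_train_file_path)
    if val_src:  # validation data is optional
        shutil.copy(val_src, args.ft_input_validation_file_path)


if __name__ == "__main__":
    main()
```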

@ -0,0 +1,3 @@
type: component
spec: spec.yaml
categories: ["Foundational Models", "Finetune", "Distillation"]

@ -0,0 +1,241 @@
$schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json
name: oss_distillation_seq_scoring_pipeline
version: 0.0.1
type: pipeline
display_name: OSS Distillation Sequence Scoring Pipeline
description: Component to generate data from the teacher model endpoint (sequentially) and finetune the student model on the generated dataset
inputs:
# Compute parameters
instance_type_pipeline_validation:
type: string
optional: True
description: Instance type to be used for validation component. The parameter compute_pipeline_validation must be set to 'serverless' for instance_type to be used.
instance_type_data_generation:
type: string
optional: true
default: Standard_D4as_v4
description: Instance type to be used for finetune component in case of virtual cluster compute, eg. Singularity.ND40_v2. The parameter compute_finetune must be set to 'serverless' for instance_type to be used
instance_type_data_import:
type: string
optional: true
default: Singularity.ND96amrs_A100_v4
description: Instance type to be used for data_import component in case of virtual cluster compute, eg. Singularity.D8_v3. The parameter compute_data_import must be set to 'serverless' for instance_type to be used
instance_type_finetune:
type: string
optional: true
default: Singularity.ND96amrs_A100_v4
description: Instance type to be used for finetune component in case of virtual cluster compute, eg. Singularity.ND40_v2. The parameter compute_finetune must be set to 'serverless' for instance_type to be used
compute_pipeline_validation:
type: string
optional: True
default: 'serverless'
description: compute to be used for validation component
compute_data_generation:
type: string
optional: true
default: 'serverless'
description: >-
compute to be used for model_import eg. provide 'FT-Cluster' if
your compute is named 'FT-Cluster'. Special characters like \ and ' are invalid in the parameter value.
If compute cluster name is provided, instance_type field will be ignored and the respective cluster will be used
compute_data_import:
type: string
optional: true
default: 'serverless'
description: >-
compute to be used for model_import eg. provide 'FT-Cluster' if
your compute is named 'FT-Cluster'. Special characters like \ and ' are invalid in the parameter value.
If compute cluster name is provided, instance_type field will be ignored and the respective cluster will be used
compute_finetune:
type: string
optional: true
default: 'serverless'
description: >-
compute to be used for finetune eg. provide 'FT-Cluster' if your
compute is named 'FT-Cluster'. Special characters like \ and ' are invalid in the parameter value.
If compute cluster name is provided, instance_type field will be ignored and the respective cluster will be used
# ########################### Data Generator Component ########################### #
train_file_path:
type: uri_file
description: Path to the registered training data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
mode: rw_mount
validation_file_path:
type: uri_file
optional: true
description: Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
mode: rw_mount
validation_info:
type: uri_file
optional: true
description: Validation status.
mode: rw_mount
teacher_model_endpoint_name:
type: string
optional: true
description: Teacher model endpoint name
teacher_model_endpoint_url:
type: string
optional: true
description: Teacher model endpoint URL
teacher_model_endpoint_key:
type: string
optional: true
description: Teacher model endpoint key
teacher_model_max_new_tokens:
type: integer
default: 128
description: Teacher model max_new_tokens inference parameter
teacher_model_temperature:
type: number
default: 0.2
description: Teacher model temperature inference parameter
teacher_model_top_p:
type: number
default: 0.1
description: Teacher model top_p inference parameter
teacher_model_frequency_penalty:
type: number
default: 0.0
description: Teacher model frequency penalty inference parameter
teacher_model_presence_penalty:
type: number
default: 0.0
description: Teacher model presence penalty inference parameter
teacher_model_stop:
type: string
optional: true
description: Teacher model stop inference parameter
request_batch_size:
type: integer
default: 10
description: Number of data records sent to the teacher model endpoint in one request
min_endpoint_success_ratio:
type: number
default: 0.7
description: >
The minimum value of (successful_requests / total_requests) required for classifying inference as successful.
If (successful_requests / total_requests) < min_endpoint_success_ratio, the experiment will be marked as failed.
By default it is 0.7 (0 means all requests are allowed to fail while 1 means no request should fail.)
enable_chain_of_thought:
type: string
optional: true
default: "false"
description: Enable Chain of thought for data generation
enable_chain_of_density:
type: string
optional: true
default: "false"
description: Enable Chain of density for text summarization
max_len_summary:
type: integer
optional: true
default: 80
description: Maximum Length Summary for text summarization
data_generation_task_type:
type: string
enum:
- NLI
- CONVERSATION
- NLU_QA
- MATH
- SUMMARIZATION
description: >
Data generation task type. Supported values are:
1. NLI: Generate Natural Language Inference data
2. CONVERSATION: Generate conversational data (multi/single turn)
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
4. MATH: Generate Math data for numerical responses
5. SUMMARIZATION: Generate Key Summary for an Article
# Training parameters
num_train_epochs:
type: integer
default: 1
optional: true
description: training epochs
per_device_train_batch_size:
type: integer
default: 1
optional: true
description: Train batch size
learning_rate:
type: number
default: 3e-04
optional: true
description: Start learning rate.
# Output of validation component.
validation_output:
type: uri_file
optional: true
description: Validation status.
mode: rw_mount
outputs:
generated_train_file_path:
type: uri_file
description: Generated train data
mode: rw_mount
generated_validation_file_path:
type: uri_file
description: Generated validation data
mode: rw_mount
jobs:
oss_distillation_generate_data:
type: command
component: azureml:oss_distillation_generate_data:0.0.8
compute: '${{parent.inputs.compute_data_generation}}'
resources:
instance_type: '${{parent.inputs.instance_type_data_generation}}'
identity:
type: user_identity
inputs:
train_file_path: '${{parent.inputs.train_file_path}}'
validation_file_path: '${{parent.inputs.validation_file_path}}'
teacher_model_endpoint_name: '${{parent.inputs.teacher_model_endpoint_name}}'
teacher_model_endpoint_url: '${{parent.inputs.teacher_model_endpoint_url}}'
teacher_model_endpoint_key: '${{parent.inputs.teacher_model_endpoint_key}}'
enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
enable_chain_of_density: '${{parent.inputs.enable_chain_of_density}}'
max_len_summary: '${{parent.inputs.max_len_summary}}'
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
teacher_model_max_new_tokens: '${{parent.inputs.teacher_model_max_new_tokens}}'
teacher_model_temperature: '${{parent.inputs.teacher_model_temperature}}'
teacher_model_top_p: '${{parent.inputs.teacher_model_top_p}}'
teacher_model_frequency_penalty: '${{parent.inputs.teacher_model_frequency_penalty}}'
teacher_model_presence_penalty: '${{parent.inputs.teacher_model_presence_penalty}}'
request_batch_size: '${{parent.inputs.request_batch_size}}'
min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
validation_output: '${{parent.inputs.validation_output}}'
outputs:
generated_train_file_path: '${{parent.outputs.generated_train_file_path}}'
generated_validation_file_path: '${{parent.outputs.generated_validation_file_path}}'

@ -0,0 +1,23 @@
# OSS Distillation Data Generation Validation File Checker
This component checks whether a validation file is present in the OSS Distillation pipeline and that the provided validation file path is valid and accessible.
## Description
The OSS Distillation Data Generation Validation File Checker is a command component that verifies the presence of a validation file. It supports various data formats including `jsonl`, `json`, `csv`, `tsv`, and `parquet`.
## Environment
- **Environment**: `azureml://registries/azureml/environments/model-evaluation/labels/latest`
## Inputs
| Name | Type | Optional | Description |
|-----------------------|----------|----------|-------------------------------------------------------------------------------------------------------------|
| validation_file_path | uri_file | Yes | Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv`, and `parquet`. |
## Outputs
| Name | Type | Description |
|--------|---------|------------------------------------|
| output | boolean | Indicates if the validation file is present or not. |

@ -0,0 +1,3 @@
type: component
spec: spec.yaml
categories: ["Foundational Models", "Finetune", "Distillation","Validation File Checker"]

@ -0,0 +1,31 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
name: oss_distillation_data_generation_validation_file_checker
version: 0.0.1
type: command
is_deterministic: True
display_name: OSS Distillation Validation File Checker Component
description: Component to Check if the validation file is present or not
environment: azureml://registries/azureml/environments/model-evaluation/labels/latest
inputs:
# Inputs
validation_file_path:
type: uri_file
optional: true
description: Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
mode: rw_mount
outputs:
output:
type: boolean
is_control: true
code: ../../src
command: >-
mldesigner execute --source generate_data_validation_file_check.py --name validate
--inputs $[[validation_file_path=${{inputs.validation_file_path}}]]
--outputs output='${{outputs.output}}'
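
The boolean control output produced here drives the `validation_succeeded` if_else node in the batch-scoring pipeline, so validation-data batch scoring only runs when a validation file was actually supplied. A minimal, illustrative presence check is sketched below; the real `generate_data_validation_file_check.py` may apply stricter validation.

```python
# Illustrative presence check in the spirit of generate_data_validation_file_check.py.
# The component's boolean control output feeds the validation_succeeded if_else node
# in the batch-scoring pipeline.
import os
from typing import Optional


def validation_file_exists(validation_file_path: Optional[str]) -> bool:
    """Return True when a non-empty validation file was supplied."""
    return bool(
        validation_file_path
        and os.path.isfile(validation_file_path)
        and os.path.getsize(validation_file_path) > 0
    )
```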

@ -1,6 +1,6 @@
$schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json
name: oss_distillation_pipeline
version: 0.0.10  # bumped from 0.0.9
type: pipeline
@ -171,6 +171,58 @@ inputs:
    4. MATH: Generate Math data for numerical responses
    5. SUMMARIZATION: Generate Key Summary for an Article
# ########################### Batch Score Component ########################### #
authentication_type:
type: string
optional: False
description: Authentication type for endpoint. Either `azureml_workspace_connection` or `managed_identity`.
default: azureml_workspace_connection
enum:
- azureml_workspace_connection
- managed_identity
additional_headers:
type: string
optional: True
description: JSON serialized string expressing additional headers to be added to each request.
debug_mode:
type: boolean
optional: False
default: False
description: Enable debug mode to print all the debug logs in the score step.
ensure_ascii:
type: boolean
optional: False
default: False
description: If set to true, the output is guaranteed to have all incoming non-ASCII characters escaped. If set to false, these characters will be output as-is. More detailed information can be found at https://docs.python.org/3/library/json.html
max_retry_time_interval:
type: integer
optional: True
description: The maximum time (in seconds) spent retrying a payload. If unspecified, payloads are retried for unlimited time.
initial_worker_count:
type: integer
optional: False
default: 5
description: The initial number of workers to use for scoring.
max_worker_count:
type: integer
optional: False
default: 200
description: Overrides `initial_worker_count` if necessary.
instance_count:
type: integer
default: 1
description: Number of nodes in a compute cluster we will run the batch score step on.
max_concurrency_per_instance:
type: integer
default: 1
description: Number of processes that will be run concurrently on any given node. This number should not be larger than 1/2 of the number of cores in an individual node in the specified cluster.
mini_batch_size:
type: string
optional: true
default: 100KB
description: The mini batch size for parallel run.
# ########################### Finetuning Component ########################### #
  number_of_gpu_to_use_finetuning:
@ -230,6 +282,12 @@ inputs:
    optional: true
    description: Name of the registered model
validation_info:
type: uri_file
optional: true
description: Validation status.
mode: rw_mount
outputs:
  output_model:
    type: uri_folder
@ -270,31 +328,103 @@ jobs:
      type: uri_file
      path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.json
  # The former oss_distillation_generate_data job (azureml:oss_distillation_generate_data:0.0.8,
  # run on compute_data_generation) is replaced by the selector and conditional sub-pipelines below.
  data_generation_batch_scoring_selector:
    type: command
    component: azureml:oss_distillation_data_generation_batch_scoring_selector:0.0.1
    compute: '${{parent.inputs.compute_pipeline_validation}}'
    resources:
      instance_type: '${{parent.inputs.instance_type_pipeline_validation}}'
    identity:
      type: user_identity
    inputs:
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
validation_succeeded:
type: if_else
condition: ${{parent.jobs.data_generation_batch_scoring_selector.outputs.output}}
true_block: ${{parent.jobs.oss_distillation_batchscoring_datagen_pipeline}}
false_block: ${{parent.jobs.oss_distillation_seq_scoring_pipeline}}
oss_distillation_batchscoring_datagen_pipeline:
type: pipeline
component: azureml:oss_distillation_batchscoring_datagen_pipeline:0.0.1
inputs:
instance_type_pipeline_validation: '${{parent.inputs.instance_type_pipeline_validation}}'
instance_type_data_generation: '${{parent.inputs.instance_type_data_generation}}'
instance_type_data_import: '${{parent.inputs.instance_type_data_import}}'
instance_type_finetune: '${{parent.inputs.instance_type_finetune}}'
compute_pipeline_validation: '${{parent.inputs.compute_pipeline_validation}}'
compute_data_generation: '${{parent.inputs.compute_data_generation}}'
compute_data_import: '${{parent.inputs.compute_data_import}}'
compute_finetune: '${{parent.inputs.compute_finetune}}'
      train_file_path: '${{parent.inputs.train_file_path}}'
      validation_file_path: '${{parent.inputs.validation_file_path}}'
      teacher_model_endpoint_url: '${{parent.inputs.teacher_model_endpoint_url}}'
      teacher_model_endpoint_name: '${{parent.inputs.teacher_model_endpoint_name}}'
      teacher_model_endpoint_key: '${{parent.inputs.teacher_model_endpoint_key}}'
      teacher_model_max_new_tokens: '${{parent.inputs.teacher_model_max_new_tokens}}'
      teacher_model_temperature: '${{parent.inputs.teacher_model_temperature}}'
      teacher_model_top_p: '${{parent.inputs.teacher_model_top_p}}'
      teacher_model_frequency_penalty: '${{parent.inputs.teacher_model_frequency_penalty}}'
      teacher_model_presence_penalty: '${{parent.inputs.teacher_model_presence_penalty}}'
      teacher_model_stop: '${{parent.inputs.teacher_model_stop}}'
      min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
      enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
      enable_chain_of_density: '${{parent.inputs.enable_chain_of_density}}'
      max_len_summary: '${{parent.inputs.max_len_summary}}'
      data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
num_train_epochs: '${{parent.inputs.num_train_epochs}}'
per_device_train_batch_size: '${{parent.inputs.per_device_train_batch_size}}'
learning_rate: '${{parent.inputs.learning_rate}}'
authentication_type: '${{parent.inputs.authentication_type}}'
additional_headers: '${{parent.inputs.additional_headers}}'
debug_mode: '${{parent.inputs.debug_mode}}'
ensure_ascii: '${{parent.inputs.ensure_ascii}}'
max_retry_time_interval: '${{parent.inputs.max_retry_time_interval}}'
initial_worker_count: '${{parent.inputs.initial_worker_count}}'
max_worker_count: '${{parent.inputs.max_worker_count}}'
instance_count: '${{parent.inputs.instance_count}}'
max_concurrency_per_instance: '${{parent.inputs.max_concurrency_per_instance}}'
mini_batch_size: '${{parent.inputs.mini_batch_size}}'
validation_info: '${{parent.jobs.oss_distillation_validate_pipeline.outputs.validation_info}}'
outputs:
generated_batch_train_file_path:
type: uri_file
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.jsonl
generated_batch_validation_file_path:
type: uri_file
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.jsonl
oss_distillation_seq_scoring_pipeline:
type: pipeline
component: azureml:oss_distillation_seq_scoring_pipeline:0.0.1
inputs:
instance_type_pipeline_validation: '${{parent.inputs.instance_type_pipeline_validation}}'
instance_type_data_generation: '${{parent.inputs.instance_type_data_generation}}'
instance_type_data_import: '${{parent.inputs.instance_type_data_import}}'
instance_type_finetune: '${{parent.inputs.instance_type_finetune}}'
compute_pipeline_validation: '${{parent.inputs.compute_pipeline_validation}}'
compute_data_generation: '${{parent.inputs.compute_data_generation}}'
compute_data_import: '${{parent.inputs.compute_data_import}}'
compute_finetune: '${{parent.inputs.compute_finetune}}'
train_file_path: '${{parent.inputs.train_file_path}}'
validation_file_path: '${{parent.inputs.validation_file_path}}'
teacher_model_endpoint_name: '${{parent.inputs.teacher_model_endpoint_name}}'
teacher_model_endpoint_url: '${{parent.inputs.teacher_model_endpoint_url}}'
teacher_model_endpoint_key: '${{parent.inputs.teacher_model_endpoint_key}}'
teacher_model_max_new_tokens: '${{parent.inputs.teacher_model_max_new_tokens}}'
teacher_model_temperature: '${{parent.inputs.teacher_model_temperature}}'
teacher_model_top_p: '${{parent.inputs.teacher_model_top_p}}'
teacher_model_frequency_penalty: '${{parent.inputs.teacher_model_frequency_penalty}}'
teacher_model_presence_penalty: '${{parent.inputs.teacher_model_presence_penalty}}'
teacher_model_stop: '${{parent.inputs.teacher_model_stop}}'
request_batch_size: '${{parent.inputs.request_batch_size}}'
min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
enable_chain_of_density: '${{parent.inputs.enable_chain_of_density}}'
max_len_summary: '${{parent.inputs.max_len_summary}}'
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
validation_output: '${{parent.jobs.oss_distillation_validate_pipeline.outputs.validation_info}}'
outputs:
generated_train_file_path:
@ -304,6 +434,30 @@ jobs:
type: uri_file
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.jsonl
oss_distillation_train_data_generation_file_selector:
type: command
component: azureml:oss_distillation_data_generation_file_selector:0.0.1
compute: '${{parent.inputs.compute_pipeline_validation}}'
resources:
instance_type: '${{parent.inputs.instance_type_pipeline_validation}}'
identity:
type: user_identity
inputs:
generated_batch_train_file_path: '${{parent.jobs.oss_distillation_batchscoring_datagen_pipeline.outputs.generated_batch_train_file_path}}'
generated_batch_validation_file_path: '${{parent.jobs.oss_distillation_batchscoring_datagen_pipeline.outputs.generated_batch_validation_file_path}}'
generated_train_file_path: '${{parent.jobs.oss_distillation_seq_scoring_pipeline.outputs.generated_train_file_path}}'
generated_validation_file_path: '${{parent.jobs.oss_distillation_seq_scoring_pipeline.outputs.generated_validation_file_path}}'
condition: '${{parent.jobs.data_generation_batch_scoring_selector.outputs.output}}'
outputs:
ft_input_train_file_path:
type: uri_file
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.jsonl
ft_input_validation_file_path:
type: uri_file
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.jsonl
oss_text_generation_data_import:
type: command
component: azureml:oss_text_generation_data_import:0.0.24
@ -318,8 +472,8 @@ jobs:
environment_variables:
_AZUREML_CR_ENABLE_ITP_CAP: "false"
inputs:
train_file_path: '${{parent.jobs.oss_distillation_train_data_generation_file_selector.outputs.ft_input_train_file_path}}'
validation_file_path: '${{parent.jobs.oss_distillation_train_data_generation_file_selector.outputs.ft_input_validation_file_path}}'
system_properties: '${{parent.inputs.system_properties}}'
oss_chat_completion_finetune:


@ -74,6 +74,9 @@ DEFAULT_TEMPERATURE = 0.2
# TEXT SUMMARIZATION DEFAULT OUTPUT WORD COUNT
DEFAULT_MAX_LEN_SUMMARY = 80
STATUS_SUCCESS = "SUCCESS"
FINISH_REASON_STOP = "stop"
class InferenceMode:
"""Supported inference modes."""
@ -112,6 +115,10 @@ class TelemetryConstants:
INVOKE_MODEL_ENDPOINT = "invoke_model_endpoint"
BATCH_PROCESS_TRAINING_DATA = "batch_process_training_data"
BATCH_PROCESS_VALIDATION_DATA = "batch_process_validation_data"
PRE_PROCESS_TRAINING_DATA = "pre_process_training_data"
PRE_PROCESS_VALIDATION_DATA = "pre_process_validation_data"
POST_PROCESS_TRAINING_DATA = "post_process_training_data"
POST_PROCESS_VALIDATION_DATA = "post_process_validation_data"
PROCESS_DATASET_RECORD = "process_dataset_record"
VALIDATOR = "validator"
@ -123,6 +130,28 @@ class TelemetryConstants:
VALIDATE_TRAINING_DATA = "validate_training_data"
VALIDATE_VALIDATION_DATA = "validate_validation_data"
VALIDATE_MODEL_INFERENCE = "validate_model_inference"
VERSION_SELECTION = "version_selection"
class PayloadField:
"""Payload fields."""
# Payload fields that will be sent to the model.
MESSAGES = "messages"
ROLE = "role"
CONTENT = "content"
SYSTEM = "system"
USER = "user"
# Payload fields that will be received from the model.
RESPONSE = "response"
REQUEST = "request"
class HashField:
"""Hash fields."""
HASH = "hash"
class BackoffConstants:
@ -160,13 +189,16 @@ class SystemPrompt:
@classmethod
def default_cot_prompt(cls):
"""Get the default chain of thought prompt."""
return cls.DEFAULT_COT_SYSTEM_PROMPT.format(
keys=cls.DEFAULT_KEYS, additional_instructions=""
)
@classmethod
def math_cot_prompt(cls):
"""Get the math chain of thought prompt for datasets expecting numeric answers."""
return cls.DEFAULT_COT_SYSTEM_PROMPT.format(
keys=cls.MATH_NUMERICAL_KEYS,
additional_instructions=cls.MATH_ADDITIONAL_INSTRUCTIONS,
)
@classmethod


@ -0,0 +1,109 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Contains helper functions for input/output operations."""
from typing import Any, List, Dict
import json
import os
from azureml.acft.common_components.utils.error_handling.exceptions import (
ACFTValidationException,
)
from azureml.acft.common_components.utils.error_handling.error_definitions import (
ACFTUserError,
)
from azureml._common._error_definition.azureml_error import AzureMLError
def _filter_files_with_given_extension(
file_paths: List[str], extension: str
) -> List[str]:
"""Filter and return the list of files with given extension."""
return [file_path for file_path in file_paths if file_path.endswith(extension)]
def _get_file_paths_from_dir(dir: str) -> List[str]:
"""
Get sorted file paths from directory.
Args:
dir (str): Directory path.
Returns:
file_paths (List[str]): List of sorted file paths.
"""
file_paths = []
for root, _, files in os.walk(dir):
for file in files:
file_paths.append(os.path.join(root, file))
file_paths.sort()
return file_paths
def _resolve_io_path(path: str) -> List[str]:
"""Resolve input/output path as a list of file paths.
It can handle the following cases for `path` argument:
- `uri_file`: `path` points to a single file.
- `uri_folder`: `path` points to a directory containing multiple files.
Args:
path (str): Path to the file or folder.
Returns:
paths (List[str]): List of file paths.
"""
if not os.path.isfile(path):
return _get_file_paths_from_dir(path)
return [path]
def read_jsonl_files(path: str) -> List[Dict[str, Any]]:
"""
Read jsonl file/files and return a list of dictionaries.
If `path` points to a file without extension, try to read it as \
a jsonl file. This is done to support `uri_file` without extension scenario.
Args:
path (str): Path to a jsonl file or a directory containing jsonl files.
Returns:
data (List[Dict[str, Any]]): List of dictionaries.
"""
file_paths: List[str] = _resolve_io_path(path)
if len(file_paths) > 1:
file_paths = _filter_files_with_given_extension(file_paths, ".jsonl")
data_dicts = []
for file_path in file_paths:
with open(file_path, "r") as file:
for i, line in enumerate(file):
try:
data_dicts.append(json.loads(line))
except json.JSONDecodeError:
mssg = f"Invalid JSON format in line {i + 1} of file '{file_path}'."
raise ACFTValidationException._with_error(
AzureMLError.create(ACFTUserError, pii_safe_message=mssg)
)
if not data_dicts:
mssg = f"No data found in {file_paths}."
raise ACFTValidationException._with_error(
AzureMLError.create(ACFTUserError, pii_safe_message=mssg)
)
return data_dicts
def write_jsonl_file(file_path: str, data: List[Dict[str, Any]]) -> None:
"""Write data to a `.jsonl` file.
Args:
file_path (str): Path to the file.
data (List[Dict[str, Any]]): Data to be written to the file.
"""
with open(file_path, "w") as file:
for line in data:
file.write(json.dumps(line) + "\n")
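
A minimal usage sketch of the helpers above; the folder, file names, and filter shown are illustrative and not part of the component contract:

```python
# Hypothetical round trip: read every .jsonl file under a folder (or a single
# uri_file), drop records without messages, and write the rest back out.
from common.io import read_jsonl_files, write_jsonl_file

records = read_jsonl_files("./payloads")            # uri_file or uri_folder path
kept = [r for r in records if r.get("messages")]    # illustrative filter
write_jsonl_file("./filtered.jsonl", kept)
```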


@ -5,32 +5,43 @@
import os
import time
from typing import List, Tuple, Union, Optional, Callable, Any, Dict
from urllib.parse import urlparse
from abc import ABC, abstractmethod
from azure.ai.ml import MLClient
from azure.ai.ml.identity import AzureMLOnBehalfOfCredential
from azure.ai.ml.entities import (
ManagedOnlineEndpoint,
ManagedOnlineDeployment,
ServerlessEndpoint,
)
from azure.identity import AzureCliCredential, ManagedIdentityCredential
from azureml.acft.common_components.utils.error_handling.exceptions import (
ACFTValidationException,
)
from azureml.acft.common_components.utils.error_handling.error_definitions import (
ACFTUserError,
)
from azureml._common._error_definition.azureml_error import AzureMLError
from azureml.acft.common_components import get_logger_app
from azureml.core import Run, Workspace
from azureml.core.run import _OfflineRun
import hashlib
import json
from common.constants import (
REQUESTS_RETRY_DELAY,
REGISTRY_MODEL_PATTERN,
SUPPORTED_STUDENT_MODEL_MAP,
SUPPORTED_TEACHER_MODEL_MAP,
BackoffConstants,
)
logger = get_logger_app(
"azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import"
)
current_run: Run = Run.get_context()
@ -57,7 +68,9 @@ def retry(times: int):
time.sleep(REQUESTS_RETRY_DELAY)
else:
logger.warning(
"Retried {} times when calling {}, now giving up!".format(
times, func.__name__
)
)
raise
@ -114,7 +127,7 @@ def get_workspace_mlclient(workspace: Workspace = None) -> MLClient:
credential,
subscription_id=workspace.subscription_id,
resource_group_name=workspace.resource_group,
workspace_name=workspace.name,
)
raise Exception("Error creating MLClient. No credentials or workspace found")
@ -153,9 +166,13 @@ class ServerlessEndpointDetails(EndpointDetails):
""" """
self._mlclient: MLClient = mlclient_ws self._mlclient: MLClient = mlclient_ws
try: try:
self._endpoint: ServerlessEndpoint = self._mlclient.serverless_endpoints.get(endpoint_name) self._endpoint: ServerlessEndpoint = (
self._mlclient.serverless_endpoints.get(endpoint_name)
)
except Exception as e: except Exception as e:
raise Exception(f"Serverless endpoint fetch details failed with exception: {e}") raise Exception(
f"Serverless endpoint fetch details failed with exception: {e}"
)
# ensure endpoint is healthy # ensure endpoint is healthy
logger.info(f"Endpoint provisioning state: {self._endpoint.provisioning_state}") logger.info(f"Endpoint provisioning state: {self._endpoint.provisioning_state}")
@ -172,9 +189,13 @@ class ServerlessEndpointDetails(EndpointDetails):
str: endpoint primary key for serverless deployment.
"""
try:
return self._mlclient.serverless_endpoints.get_keys(
self._endpoint.name
).primary_key
except Exception as e:
raise Exception(
f"Failed to get endpoint keys for endpoint: {self._endpoint.name}. Exception: {e}"
)
def get_endpoint_url(self) -> str:
"""Get URL for managed online endpoint."""
@ -202,7 +223,9 @@ class OnlineEndpointDetails(EndpointDetails):
""" """
self._mlclient: MLClient = mlclient_ws self._mlclient: MLClient = mlclient_ws
try: try:
self._endpoint: ManagedOnlineEndpoint = self._mlclient.online_endpoints.get(endpoint_name) self._endpoint: ManagedOnlineEndpoint = self._mlclient.online_endpoints.get(
endpoint_name
)
except Exception as e: except Exception as e:
raise Exception(f"Online endpoint fetch details failed with exception: {e}") raise Exception(f"Online endpoint fetch details failed with exception: {e}")
@ -210,9 +233,12 @@ class OnlineEndpointDetails(EndpointDetails):
# fetch deployment with 100% traffic
deployments = [
deployment
for deployment in all_deployments
if deployment.name
in [
deployment_name
for deployment_name, traffic_percent in self._endpoint.traffic.items()
if traffic_percent == 100
]
]
@ -226,10 +252,16 @@ class OnlineEndpointDetails(EndpointDetails):
self._deployment = deployments[0]
# ensure endpoint and deployment is healthy
logger.info(f"Endpoint provisioning state: {self._endpoint.provisioning_state}")
logger.info(
f"Deployment provisioning state: {self._deployment.provisioning_state}"
)
if not (
self._endpoint.provisioning_state.lower() == "succeeded"
and self._deployment.provisioning_state.lower() == "succeeded"
):
raise Exception(
f"Endpoint {self._endpoint.name} or deployment {self._deployment.name} is unhealthy."
)
def get_endpoint_key(self):
"""Get endpoint primary key for managed online deployment.
@ -241,9 +273,13 @@ class OnlineEndpointDetails(EndpointDetails):
str: endpoint primary key for managed online deployment.
"""
try:
return self._mlclient.online_endpoints.get_keys(
self._endpoint.name
).primary_key
except Exception as e:
raise Exception(
f"Failed to get endpoint keys for endpoint: {self._endpoint.name}. Exception: {e}"
)
def get_endpoint_url(self) -> str:
"""Get URL for managed online endpoint."""
@ -260,11 +296,14 @@ class OnlineEndpointDetails(EndpointDetails):
List[ManagedOnlineDeployment]: List of deployments
"""
try:
self._deployments: List[ManagedOnlineDeployment] = (
self._mlclient.online_deployments.list(self._endpoint.name)
)
return self._deployments
except Exception as e:
logger.error(
f"Could not fetch deployments for endpoint: {self._endpoint.name}. Exception => {e}"
)
return None
@ -305,7 +344,11 @@ def _get_model_id_from_run_details():
def _get_model_details(model_asset_id, supported_model_map) -> Tuple[str, str, str]:
# try matching registry model pattern
if match := REGISTRY_MODEL_PATTERN.match(model_asset_id):
registry, model_name, model_version = (
match.group("registry"),
match.group("model"),
match.group("version"),
)
# check if model_name exists in supported list
if model_name not in supported_model_map:
raise Exception(
@ -412,7 +455,7 @@ def exponential_backoff(
ACFTUserError,
pii_safe_message=(
f"Encountered unknown status code: {status_code}. ",
),
)
)
@ -435,10 +478,25 @@ def exponential_backoff(
ACFTUserError,
pii_safe_message=(
f"Request failed after multiple tries with status code: {status_code}. ",
),
)
)
return wrapper
return decorator
def get_hash_value(data: Union[Dict[str, Any], str]) -> str:
"""
Get hash value for the data.
Args:
data (Dict[str, Any]): Data for which hash value needs to be computed.
Returns:
hash_value (str): Hash value.
"""
if isinstance(data, str):
return hashlib.sha256(data.encode()).hexdigest()
return hashlib.sha256(json.dumps(data).encode()).hexdigest()
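
A small sketch of how this hash pairs records across the batch scoring hop; the message contents below are placeholders, and the field names mirror the HashField and PayloadField constants used elsewhere in this change:

```python
# Pre-processing stores {hash(user_message): system_message}; post-processing
# recomputes the hash on the echoed user message to recover the system prompt.
user_message = {"role": "user", "content": "Classify the premise/hypothesis pair."}
system_message = {"role": "system", "content": "You are a careful NLI annotator."}

hash_record = {"hash": get_hash_value(user_message), "system": system_message}
lookup = {hash_record["hash"]: hash_record}

recovered = lookup[get_hash_value(user_message)]["system"]
assert recovered == system_message
```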


@ -0,0 +1,62 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This module links the run output conditionally.
If condition is `true`, the outputs are copied from `generated_batch_train_file_path` and `generated_batch_validation_file_path`.
If condition is `false`, the outputs are copied from `generated_train_file_path` and `generated_validation_file_path`.
"""
import argparse
def copy_file_contents(input_src1, ft_input_train_file_path):
"""
Copy the contents of one file to another.
Parameters:
input_src1 (str): The path to the source file.
ft_input_train_file_path (str): The path to the destination file.
Returns:
None
"""
# Read the contents of input_src1
with open(input_src1, "r") as src_file:
contents = src_file.read()
# Write the contents to ft_input_train_file_path
with open(ft_input_train_file_path, "w") as dest_file:
dest_file.write(contents)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--generated_batch_train_file_path", type=str)
parser.add_argument("--generated_batch_validation_file_path", type=str)
parser.add_argument("--generated_train_file_path", type=str)
parser.add_argument("--generated_validation_file_path", type=str)
parser.add_argument("--condition", type=str)
parser.add_argument("--ft_input_train_file_path", type=str)
parser.add_argument("--ft_input_validation_file_path", type=str)
args, _ = parser.parse_known_args()
print(f"Condition output component received args: {args}.")
if (
args.generated_batch_train_file_path is None
and args.generated_train_file_path is None
):
raise Exception(
"Got 'generated_batch_train_file_path' and 'generated_train_file_path' both be None."
)
condition = args.condition.lower() == "true"
input_src1 = args.generated_train_file_path
input_src2 = args.generated_validation_file_path
ft_input_train_file_path = args.ft_input_train_file_path
ft_input_validation_file_path = args.ft_input_validation_file_path
if condition:
input_src1 = args.generated_batch_train_file_path
input_src2 = args.generated_batch_validation_file_path
copy_file_contents(input_src1, ft_input_train_file_path)
copy_file_contents(input_src2, ft_input_validation_file_path)
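
A sketch of the selection behaviour using the copy_file_contents helper above; the file names are placeholders for the pipeline outputs wired up in the YAML:

```python
# condition mirrors the boolean control output of data_generation_batch_scoring_selector.
condition = "True"

train_src = "batch_train.jsonl" if condition.lower() == "true" else "seq_train.jsonl"
valid_src = "batch_validation.jsonl" if condition.lower() == "true" else "seq_validation.jsonl"

copy_file_contents(train_src, "ft_input_train.jsonl")
copy_file_contents(valid_src, "ft_input_validation.jsonl")
```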


@ -0,0 +1,42 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""File containing function for FTaaS data import component."""
from azureml.acft.common_components import (
get_logger_app,
)
from azureml.acft.common_components.utils.error_handling.swallow_all_exceptions_decorator import (
swallow_all_exceptions,
)
from azureml.telemetry.activity import log_activity
from azure.ai.ml import Input
from mldesigner import Output, command_component
from common.constants import (
DataGenerationTaskType,
TelemetryConstants,
)
logger = get_logger_app(
"azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import"
)
@command_component
@swallow_all_exceptions(logger)
def validate(
data_generation_task_type: Input(type="string", optional=False), # noqa: F821
) -> Output(type="boolean", is_control=True): # noqa: F821
"""Entry function of model validation script."""
with log_activity(
logger,
TelemetryConstants.VERSION_SELECTION,
{"data_generation_task_type": data_generation_task_type},
):
logger.info("Validating arguments: " + repr(data_generation_task_type))
if data_generation_task_type == DataGenerationTaskType.CONVERSATION:
return False
else:
return True


@ -0,0 +1,383 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""File containing function for FTaaS data import component."""
import json
import logging
import argparse
from argparse import Namespace
from pathlib import Path
from typing import List, Dict, Any
from azureml.acft.contrib.hf import VERSION, PROJECT_NAME
from azureml.acft.contrib.hf.nlp.constants.constants import (
LOGS_TO_BE_FILTERED_IN_APPINSIGHTS,
)
from azureml.acft.common_components import (
get_logger_app,
set_logging_parameters,
LoggingLiterals,
)
from azureml.acft.common_components.utils.error_handling.swallow_all_exceptions_decorator import (
swallow_all_exceptions,
)
from azureml.telemetry.activity import log_activity
from common.io import read_jsonl_files
from common.constants import (
COMPONENT_NAME,
DEFAULT_SUCCESS_RATIO,
DataGenerationTaskType,
HashField,
TelemetryConstants,
SystemPrompt,
PayloadField,
STATUS_SUCCESS,
FINISH_REASON_STOP,
)
from common.utils import (
get_hash_value,
get_workspace_mlclient,
)
logger = get_logger_app(
"azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import"
)
def get_parser():
"""
Add arguments and returns the parser. Here we add all the arguments for all the tasks.
Those arguments that are not relevant for the input task should be ignored.
"""
parser = argparse.ArgumentParser(
description="Model selector for hugging face models", allow_abbrev=False
)
# File I/O
parser.add_argument(
"--batch_score_train_result",
type=str,
help="Path to the directory containing jsonl file(s) that have the result for each payload.",
)
parser.add_argument(
"--batch_score_validation_result",
default=None,
type=str,
help="Path to the directory containing jsonl file(s) that have the result for each payload.",
)
parser.add_argument(
"--hash_train_data",
type=str,
required=True,
help="Path tho the jsonl file containing the hash for each payload.",
)
parser.add_argument(
"--hash_validation_data",
type=str,
default=None,
help="Path tho the jsonl file containing the hash for each payload.",
)
parser.add_argument(
"--generated_batch_train_file_path",
type=Path,
default=None,
help="file to save the generated training data",
)
parser.add_argument(
"--generated_batch_validation_file_path",
type=Path,
default=None,
help="file to save the generated validation data",
)
# File I/O
parser.add_argument(
"--train_file_path",
type=str,
help="Input train file path",
)
parser.add_argument(
"--validation_file_path",
default=None,
type=str,
help="Input validation file path",
)
parser.add_argument(
"--enable_chain_of_thought",
type=str,
required=False,
default="false",
help="This enables Chain of Thought",
)
parser.add_argument(
"--enable_chain_of_density",
type=str,
required=False,
default="false",
help="This enables Chain of Density for Summarization",
)
parser.add_argument(
"--data_generation_task_type",
type=str,
required=True,
help="""Data generation task type. Supported values are:
1. NLI: Generate Natural Language Inference data
2. CONVERSATION: Generate conversational data (multi/single turn)
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
4. MATH: Generate Math data for numerical responses
5. SUMMARIZATION: Generate Text Summary for Article
""",
choices=[v.value for v in DataGenerationTaskType],
)
parser.add_argument(
"--min_endpoint_success_ratio",
type=float,
required=False,
default=DEFAULT_SUCCESS_RATIO,
help=(
f"The minimum value of "
"(successful_requests / total_requests) required for classifying inference as successful. "
"If (successful_requests / total_requests) < min_endpoint_success_ratio, "
"the experiment will be marked as failed. "
f"By default it is {DEFAULT_SUCCESS_RATIO}. "
"(0 means all requests are allowed to fail while 1 means no request should fail.)"
),
)
parser.add_argument(
"--connection_config_file",
type=str,
required=False,
default=None,
help="A config file path that contains deployment configurations.",
)
return parser
def delete_connection(config_file: str):
"""Delete the connection configuration file.
Args:
config_file (str): The path to the connection configuration file.
"""
if config_file:
try:
config_from_file = {}
with open(config_file) as file:
config_from_file = json.load(file)
batch_connection_name = config_from_file.get("connection_name", None)
if batch_connection_name:
mlclient_ws = get_workspace_mlclient()
if not mlclient_ws:
raise Exception("Could not create MLClient for current workspace")
mlclient_ws.connections.delete(batch_connection_name)
except Exception as e:
msg = f"Error deleting connection: {e}"
logger.error(msg)
raise Exception(msg)
def postprocess_data(
batch_score_res_path: str,
input_file_path: str,
enable_cot: bool,
enable_cod: bool,
data_generation_task_type: str,
min_endpoint_success_ratio: float,
output_file_path: str,
hash_data: str,
):
"""Generate and save synthentic data under output_dataset.
Args:
batch_score_res_path (str): Path containing jsonl file(s) that have the result for each payload.
input_file_path (str): Input JSONL file path.
enable_cot (bool): Enable Chain of Thought
enable_cod (bool): Enable Chain of Density
data_generation_task_type (str): Data generation task type
min_endpoint_success_ratio (float): Minimum success ratio below which run will be considered a failure
output_file_path (str): Output JSONL file path.
hash_data (str): Path to the jsonl file containing the hash for each payload.
"""
error_count = 0
output_data = []
if input_file_path is None:
logger.info(
f"No input file path provided. Skipping data postprocessing for {input_file_path}."
)
return
hash_data_list: List[Dict[str, Any]] = read_jsonl_files(hash_data)
# Build a hash-to-record lookup so each batch score result can be matched back to its payload's system prompt
hash_data_dict = {
hash_data[HashField.HASH]: hash_data for hash_data in hash_data_list
}
total_rows = len(hash_data_list)
batch_score_res_list: List[Dict[str, Any]] = read_jsonl_files(batch_score_res_path)
try:
for idx, batch_score_dict in enumerate(batch_score_res_list):
status = batch_score_dict.get("status", "")
if status == STATUS_SUCCESS:
request_dict = batch_score_dict.get("request", {})
if request_dict:
synthetic_responses = []
messages = request_dict.pop("messages", [])
for message in messages:
role = message["role"]
if role == PayloadField.USER:
hash_val = get_hash_value(message)
system_message = hash_data_dict[hash_val].get(
PayloadField.SYSTEM, {}
)
synthetic_responses.append(system_message)
synthetic_responses.append(message)
break
response_data = batch_score_dict.get("response", {})
finish_reason = response_data["choices"][0]["finish_reason"]
if finish_reason == FINISH_REASON_STOP:
prediction_result = response_data["choices"][0]["message"][
"content"
].strip()
# For CoT prompts, need to remove the reasoning and only use the answer
if (
enable_cot
and data_generation_task_type
!= DataGenerationTaskType.CONVERSATION
):
key = SystemPrompt.get_response_key(
data_generation_task_type
)
prediction_result = json.loads(prediction_result)[key]
if (
enable_cod
and data_generation_task_type
== DataGenerationTaskType.SUMMARIZATION
):
result = json.loads(prediction_result)
prediction_result = result[-1]["Denser_Summary"]
synthetic_responses.append(
{"role": "assistant", "content": str(prediction_result)}
)
output_data.append({"messages": synthetic_responses})
else:
error_count += 1
else:
error_count += 1
except Exception as e:
logger.error(f"Error in postprocessing {idx} data: {e}")
raise e
success_ratio = float(total_rows - error_count) / total_rows
logger.info(f"Success ratio: {success_ratio}")
if success_ratio < min_endpoint_success_ratio:
msg = f"Success ratio for dataset {input_file_path}: {success_ratio} < {min_endpoint_success_ratio}."
raise Exception(msg)
with open(output_file_path, "w") as f:
for record in output_data:
f.write(json.dumps(record) + "\n")
def data_import(args: Namespace):
"""Copy the user data to output dir."""
train_file_path = args.train_file_path
validation_file_path = args.validation_file_path
batch_score_train_path = args.batch_score_train_result
batch_score_validation_path = args.batch_score_validation_result
generated_batch_train_file_path = args.generated_batch_train_file_path
generated_batch_validation_file_path = args.generated_batch_validation_file_path
enable_cot_str = args.enable_chain_of_thought
enable_cod_str = args.enable_chain_of_density
min_endpoint_success_ratio = args.min_endpoint_success_ratio
data_generation_task_type = args.data_generation_task_type
hash_train_data = args.hash_train_data
hash_validation_data = args.hash_validation_data
connection_config_file = args.connection_config_file
enable_cot = True if enable_cot_str.lower() == "true" else False
enable_cod = True if enable_cod_str.lower() == "true" else False
with log_activity(
logger=logger, activity_name=TelemetryConstants.POST_PROCESS_TRAINING_DATA
):
logger.info(
"Deleting batch configuration connection used for teacher model invocation."
)
delete_connection(connection_config_file)
logger.info(
"Running data postprocessing for train file path: %s", train_file_path
)
postprocess_data(
batch_score_res_path=batch_score_train_path,
input_file_path=train_file_path,
enable_cot=enable_cot,
enable_cod=enable_cod,
data_generation_task_type=data_generation_task_type,
min_endpoint_success_ratio=min_endpoint_success_ratio,
output_file_path=generated_batch_train_file_path,
hash_data=hash_train_data,
)
if validation_file_path:
with log_activity(
logger=logger,
activity_name=TelemetryConstants.POST_PROCESS_VALIDATION_DATA,
):
logger.info(
"Running data postprocessing for validation file path: %s",
validation_file_path,
)
postprocess_data(
batch_score_res_path=batch_score_validation_path,
input_file_path=validation_file_path,
enable_cot=enable_cot,
enable_cod=enable_cod,
data_generation_task_type=data_generation_task_type,
min_endpoint_success_ratio=min_endpoint_success_ratio,
output_file_path=generated_batch_validation_file_path,
hash_data=hash_validation_data,
)
else:
Path(generated_batch_validation_file_path.parent).mkdir(
exist_ok=True, parents=True
)
# create an empty file if validation file is not provided
open(generated_batch_validation_file_path, "w").close()
@swallow_all_exceptions(time_delay=5)
def main():
"""Parse args and import model."""
parser = get_parser()
args, _ = parser.parse_known_args()
set_logging_parameters(
task_type="ChatCompletion",
acft_custom_dimensions={
LoggingLiterals.PROJECT_NAME: PROJECT_NAME,
LoggingLiterals.PROJECT_VERSION_NUMBER: VERSION,
LoggingLiterals.COMPONENT_NAME: COMPONENT_NAME,
},
azureml_pkg_denylist_logging_patterns=LOGS_TO_BE_FILTERED_IN_APPINSIGHTS,
log_level=logging.INFO,
)
data_import(args)
if __name__ == "__main__":
main()
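
For reference, a hedged sketch of the result record shape that postprocess_data parses above; the field names are taken from the parsing logic, while the values and any extra fields emitted by the batch scoring component are illustrative:

```python
# One line of the batch score result file, as consumed by postprocess_data.
example_result = {
    "status": "SUCCESS",                      # STATUS_SUCCESS
    "request": {
        "messages": [
            {"role": "user", "content": "Premise: ... Hypothesis: ..."},
        ],
    },
    "response": {
        "choices": [
            {
                "finish_reason": "stop",      # FINISH_REASON_STOP
                "message": {"content": "neutral"},
            }
        ]
    },
}
```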


@ -0,0 +1,540 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""File containing function for FTaaS data import component."""
import json
import logging
import argparse
from argparse import Namespace
from pathlib import Path
import os
import uuid
from azureml.acft.contrib.hf import VERSION, PROJECT_NAME
from azureml.acft.contrib.hf.nlp.constants.constants import (
LOGS_TO_BE_FILTERED_IN_APPINSIGHTS,
)
from azureml.acft.common_components import (
get_logger_app,
set_logging_parameters,
LoggingLiterals,
)
from azureml.acft.common_components.utils.error_handling.swallow_all_exceptions_decorator import (
swallow_all_exceptions,
)
from azureml.telemetry.activity import log_activity
from azure.ai.ml.entities import ServerlessConnection
from common.io import read_jsonl_files
from mltable import from_json_lines_files
from common.constants import (
COMPONENT_NAME,
DEFAULT_MAX_NEW_TOKENS,
DEFAULT_SUMMARY_MAX_NEW_TOKENS,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
FREQUENCY_PENALTY,
PRESENCE_PENALTY,
MAX_NEW_TOKENS,
TEMPERATURE,
TOP_P,
STOP_TOKEN,
VLLM_CHAT_SCORE_PATH,
DataGenerationTaskType,
HashField,
TelemetryConstants,
SystemPrompt,
PayloadField,
DEFAULT_MAX_LEN_SUMMARY,
)
from common.utils import (
get_workspace_mlclient,
get_endpoint_details,
get_hash_value,
validate_teacher_model_details,
)
from common.validation import validate_file_paths_with_supported_formats
logger = get_logger_app(
"azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import"
)
def get_parser():
"""
Add arguments and returns the parser. Here we add all the arguments for all the tasks.
Those arguments that are not relevant for the input task should be ignored.
"""
parser = argparse.ArgumentParser(
description="Model selector for hugging face models", allow_abbrev=False
)
# File I/O
parser.add_argument(
"--train_file_path",
type=str,
help="Input train file path",
)
parser.add_argument(
"--validation_file_path",
default=None,
type=str,
help="Input validation file path",
)
parser.add_argument(
"--generated_train_file_path",
type=Path,
default=None,
help="file to save the generated training data",
)
parser.add_argument(
"--generated_validation_file_path",
type=Path,
default=None,
help="file to save the generated validation data",
)
# add optional data-generator params
parser.add_argument(
"--teacher_model_endpoint_name",
type=str,
required=False,
help="Teacher model endpoint name",
)
parser.add_argument(
"--teacher_model_endpoint_key",
type=str,
required=False,
help="Teacher model endpoint key",
)
parser.add_argument(
"--teacher_model_endpoint_url",
type=str,
required=True,
help="Teacher model endpoint URL",
)
parser.add_argument(
"--teacher_model_max_new_tokens",
type=int,
required=False,
default=DEFAULT_MAX_NEW_TOKENS,
help="Teacher model max_tokens parameter",
)
parser.add_argument(
"--teacher_model_temperature",
type=float,
required=False,
default=DEFAULT_TEMPERATURE,
help="Teacher model temperature parameter",
)
parser.add_argument(
"--teacher_model_top_p",
type=float,
required=False,
default=DEFAULT_TOP_P,
help="Teacher model top-p parameter",
)
parser.add_argument(
"--teacher_model_frequency_penalty",
type=float,
required=False,
help="Teacher model frequency parameter",
)
parser.add_argument(
"--teacher_model_presence_penalty",
type=float,
required=False,
help="Teacher model presense penalty",
)
parser.add_argument(
"--teacher_model_stop", type=str, required=False, help="Teacher model stop "
)
parser.add_argument(
"--enable_chain_of_thought",
type=str,
required=False,
default="false",
help="This enables Chain of Thought",
)
parser.add_argument(
"--enable_chain_of_density",
type=str,
required=False,
default="false",
help="This enables Chain of Density for Summarization",
)
parser.add_argument(
"--max_len_summary",
type=int,
required=False,
default=DEFAULT_MAX_LEN_SUMMARY,
help="Maximum word count for text summarization ",
)
parser.add_argument(
"--data_generation_task_type",
type=str,
required=True,
help="""Data generation task type. Supported values are:
1. NLI: Generate Natural Language Inference data
2. CONVERSATION: Generate conversational data (multi/single turn)
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
4. MATH: Generate Math data for numerical responses
5. SUMMARIZATION: Generate Text Summary for Article
""",
choices=[v.value for v in DataGenerationTaskType],
)
parser.add_argument(
"--generated_train_payload_path",
type=str,
help="file to save the generated training payload data",
)
parser.add_argument(
"--generated_validation_payload_path",
type=str,
help="file to save the generated validation payload data",
)
parser.add_argument(
"--hash_train_data",
type=str,
required=True,
help="Path tho the jsonl file where the hash for each payload will be dumped.",
)
parser.add_argument(
"--hash_validation_data",
type=str,
required=True,
help="Path tho the jsonl file where the hash for each payload will be dumped.",
)
parser.add_argument(
"--batch_config_connection",
type=str,
required=True,
help="",
)
return parser
def preprocess_data(
inference_params: dict,
enable_cot: bool,
enable_cod: bool,
max_len_summary: int,
data_generation_task_type: str,
generated_train_payload_path: str,
generated_validation_payload_path: str,
hash_train_data: str,
hash_validation_data: str,
train_file_path: Path,
validation_file_path: Path = None,
):
"""Generate and save synthentic data under output_dataset.
Args:
inference_params (dict): Inference params to hit endpoint with
enable_cot (bool): Enable Chain of Thought processing
enable_cod (bool): Enable Chain of Density processing for text summarization task
max_len_summary (int): Maximum word count for text summarization
data_generation_task_type (str): Data generation task type
generated_train_payload_path (str): Path to save the generated training payload data
generated_validation_payload_path (str): Path to save the generated validation payload data
hash_train_data (str): Path to the jsonl file where the hash for each payload will be dumped.
hash_validation_data (str): Path to the jsonl file where the hash for each payload will be dumped.
train_file_path (Path): Train JSONL file path
validation_file_path (Path, optional): Validation JSONL file path. Defaults to None.
"""
def process_system_prompt(message: dict) -> dict:
"""Update the system prompt depending on the task type and the flag enable_cot.
The original message unchanged if enable_cot is False or task type is conversation.
Args:
message (dict): System message
Returns:
message (dict): System message with updated content
"""
if (
enable_cot
and data_generation_task_type != DataGenerationTaskType.CONVERSATION
):
cot_prompt = SystemPrompt.get_cot_prompt(data_generation_task_type)
cot_system_message = {"role": "system", "content": cot_prompt}
return cot_system_message
elif (
enable_cod
and data_generation_task_type == DataGenerationTaskType.SUMMARIZATION
):
cod_prompt = SystemPrompt.get_cod_prompt(max_len_summary)
cod_system_message = {"role": "system", "content": cod_prompt}
return cod_system_message
else:
return message
def pre_process_data(
input_file_path: Path,
output_file_path: Path,
hash_file_path: str,
) -> None:
"""Batch process data and do a bulk request to teacher model endpoint.
Args:
input_file_path (Path): Input data file path
output_file_path (Path): Path to output directory
hash_file_path (str): Path to the jsonl file where the hash for each payload will be dumped.
"""
input_data = read_jsonl_files(input_file_path)
output_data = []
output_hash_data = []
try:
for idx, record in enumerate(input_data):
# Basic validation for the input data
messages = record.pop("messages", [])
if not messages: # empty messages
logger.error(f"Failed with exception:{idx} Empty messages")
return
first_message = messages[0]
if first_message["role"] != PayloadField.SYSTEM:
logger.error(
f"row {idx} failed with exception: First message should be system, "
f"but got {first_message['role']}"
)
for message in messages[1:]:
role = message["role"]
if role not in ("assistant", "user"):
logger.error(
f"row {idx} failed with exception: role should be system or user, but got {role}"
)
inference_data = []
system_message = {}
for message in messages:
role = message["role"]
if role == PayloadField.SYSTEM:
system_message[PayloadField.SYSTEM] = message
inference_data.append(process_system_prompt(message))
elif role == PayloadField.USER:
inference_data.append(message)
hash_data = {
HashField.HASH: get_hash_value(message),
**system_message,
}
output_data.append(
{
"messages": inference_data,
**inference_params,
}
)
output_hash_data.append(hash_data)
except Exception as e:
logger.error(f"idx: {idx}. exception: {e}")
payload_jsonl_path = os.path.join(output_file_path, "payload.jsonl")
logger.info("payload_jsonl_path: %s", payload_jsonl_path)
with open(payload_jsonl_path, "w") as payload_file:
for entry in output_data:
payload_file.write(json.dumps(entry) + "\n")
logger.info("hash_file_path: %s", hash_file_path)
with open(hash_file_path, "w") as hash_file:
for entry in output_hash_data:
hash_file.write(json.dumps(entry) + "\n")
output_file_path = str(output_file_path)
mltable = from_json_lines_files(paths=[{"file": payload_jsonl_path}])
logger.info("output_file_path type before saving: %s", output_file_path)
mltable.save(output_file_path)
with log_activity(
logger=logger, activity_name=TelemetryConstants.PRE_PROCESS_TRAINING_DATA
):
logger.info("PreProcessing train file")
pre_process_data(train_file_path, generated_train_payload_path, hash_train_data)
logger.info("Data generated and saved for train file")
if validation_file_path:
with log_activity(
logger=logger,
activity_name=TelemetryConstants.PRE_PROCESS_VALIDATION_DATA,
):
logger.info("PreProcessing validation file")
pre_process_data(
validation_file_path,
generated_validation_payload_path,
hash_validation_data,
)
logger.info("Data generated and saved for validation file")
else:
hash_validation_data = Path(hash_validation_data)
Path(hash_validation_data.parent).mkdir(exist_ok=True, parents=True)
# create an empty file if validation file is not provided
open(hash_validation_data, "w").close()
def data_import(args: Namespace):
"""Copy the user data to output dir."""
train_file_path = args.train_file_path
validation_file_path = args.validation_file_path
generated_train_payload_path = args.generated_train_payload_path
generated_validation_payload_path = args.generated_validation_payload_path
teacher_model_endpoint_name = args.teacher_model_endpoint_name
teacher_model_endpoint_url = args.teacher_model_endpoint_url
teacher_model_endpoint_key = args.teacher_model_endpoint_key
# add optional data-generator params
teacher_model_max_new_tokens = args.teacher_model_max_new_tokens
teacher_model_temperature = args.teacher_model_temperature
teacher_model_top_p = args.teacher_model_top_p
teacher_model_frequency_penalty = args.teacher_model_frequency_penalty
teacher_model_presence_penalty = args.teacher_model_presence_penalty
teacher_model_stop = args.teacher_model_stop
enable_cot_str = args.enable_chain_of_thought
enable_cod_str = args.enable_chain_of_density
max_len_summary = args.max_len_summary
data_generation_task_type = args.data_generation_task_type
hash_train_data = args.hash_train_data
hash_validation_data = args.hash_validation_data
batch_config_connection = args.batch_config_connection
# validate file formats
validate_file_paths_with_supported_formats(
[args.train_file_path, args.validation_file_path]
)
logger.info("File format validation successful.")
enable_cot = True if enable_cot_str.lower() == "true" else False
enable_cod = True if enable_cod_str.lower() == "true" else False
mlclient_ws = get_workspace_mlclient()
if not mlclient_ws:
raise Exception("Could not create MLClient for current workspace")
if teacher_model_endpoint_name:
endpoint_details = get_endpoint_details(
mlclient_ws, teacher_model_endpoint_name
)
teacher_model_endpoint_key = endpoint_details.get_endpoint_key()
teacher_model_endpoint_url = endpoint_details.get_endpoint_url()
teacher_model_asset_id = endpoint_details.get_deployed_model_id()
validate_teacher_model_details(teacher_model_asset_id)
if not teacher_model_endpoint_url:
raise Exception("Endpoint URL is a requried parameter for data generation")
if not teacher_model_endpoint_key:
raise Exception("Endpoint key is a requried parameter for data generation")
if teacher_model_top_p < 0 or teacher_model_top_p > 1:
raise Exception(
f"Invalid teacher_model_top_p. Value should be 0<=val<=1, but it is {teacher_model_top_p}"
)
if teacher_model_temperature < 0 or teacher_model_temperature > 1:
raise Exception(
f"Invalid teacher_model_temperature. Value should be 0<=val<=1, but it is {teacher_model_temperature}"
)
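# For summarization runs that kept the generic default, fall back to the larger
# summarization-specific max_new_tokens default; otherwise honour the user's value.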
inference_params = {
MAX_NEW_TOKENS: (
DEFAULT_SUMMARY_MAX_NEW_TOKENS
if data_generation_task_type == "SUMMARIZATION"
and teacher_model_max_new_tokens == DEFAULT_MAX_NEW_TOKENS
else teacher_model_max_new_tokens
),
TEMPERATURE: teacher_model_temperature,
TOP_P: teacher_model_top_p,
}
if teacher_model_frequency_penalty:
inference_params[FREQUENCY_PENALTY] = teacher_model_frequency_penalty
if teacher_model_presence_penalty:
inference_params[PRESENCE_PENALTY] = teacher_model_presence_penalty
if teacher_model_stop:
inference_params[STOP_TOKEN] = teacher_model_stop
if VLLM_CHAT_SCORE_PATH not in teacher_model_endpoint_url:
teacher_model_endpoint_url += VLLM_CHAT_SCORE_PATH
logger.info(f"Teacher Endpoint : {teacher_model_endpoint_url}")
try:
guid = uuid.uuid4()
short_guid = str(guid)[:8]
connection_name = f"distillation-ws-connection-{short_guid}"
mlclient_ws.connections.create_or_update(
ServerlessConnection(
name=connection_name,
endpoint=teacher_model_endpoint_url,
api_key=teacher_model_endpoint_key,
)
)
logger.info(f"Connection created with name: {connection_name}")
config = {}
config["scoring_url"] = teacher_model_endpoint_url
config["connection_name"] = connection_name
with open(batch_config_connection, "w") as f:
json.dump(config, f)
except Exception as e:
logger.error(
f"Failed to create connection for teacher model batch score invocation : {e}"
)
raise Exception(
"Failed to create workspace connection for teacher model batch score invocation "
)
logger.info("Running data preprocessing")
preprocess_data(
inference_params=inference_params,
enable_cot=enable_cot,
enable_cod=enable_cod,
max_len_summary=max_len_summary,
generated_train_payload_path=generated_train_payload_path,
generated_validation_payload_path=generated_validation_payload_path,
train_file_path=train_file_path,
data_generation_task_type=data_generation_task_type,
validation_file_path=validation_file_path,
hash_train_data=hash_train_data,
hash_validation_data=hash_validation_data,
)
@swallow_all_exceptions(time_delay=5)
def main():
"""Parse args and import model."""
parser = get_parser()
args, _ = parser.parse_known_args()
set_logging_parameters(
task_type="ChatCompletion",
acft_custom_dimensions={
LoggingLiterals.PROJECT_NAME: PROJECT_NAME,
LoggingLiterals.PROJECT_VERSION_NUMBER: VERSION,
LoggingLiterals.COMPONENT_NAME: COMPONENT_NAME,
},
azureml_pkg_denylist_logging_patterns=LOGS_TO_BE_FILTERED_IN_APPINSIGHTS,
log_level=logging.INFO,
)
data_import(args)
if __name__ == "__main__":
main()
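
A hedged sketch of the artifacts pre_process_data writes above: one line of payload.jsonl (the MLTable source) and its companion hash record. The exact parameter keys come from common.constants; the keys and values shown here are placeholders:

```python
# One payload line sent to the teacher endpoint via batch scoring ...
payload_line = {
    "messages": [
        {"role": "system", "content": "You are a careful NLI annotator."},
        {"role": "user", "content": "Premise: ... Hypothesis: ..."},
    ],
    "max_new_tokens": 128,   # assumed key for MAX_NEW_TOKENS
    "temperature": 0.2,      # assumed key for TEMPERATURE
    "top_p": 0.1,            # assumed key for TOP_P
}

# ... and the matching record written to the hash file for post-processing.
hash_line = {
    "hash": "<sha256 of the user message>",
    "system": {"role": "system", "content": "You are a careful NLI annotator."},
}
```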


@ -0,0 +1,41 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""File containing function for FTaaS data import component."""
from azureml.acft.common_components import (
get_logger_app,
)
from azureml.acft.common_components.utils.error_handling.swallow_all_exceptions_decorator import (
swallow_all_exceptions,
)
from azureml.telemetry.activity import log_activity
from azure.ai.ml import Input
from mldesigner import Output, command_component
from common.constants import (
TelemetryConstants,
)
logger = get_logger_app(
"azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import"
)
@command_component
@swallow_all_exceptions(logger)
def validate(
validation_file_path: Input(type="string", optional=True), # noqa: F821
) -> Output(type="boolean", is_control=True): # noqa: F821
"""Entry function of model validation script."""
with log_activity(
logger,
TelemetryConstants.VERSION_SELECTION,
{"validation_file_path": validation_file_path},
):
logger.info("Validating arguments: " + repr(validation_file_path))
if validation_file_path:
return True
else:
return False