feat: introduce validation component for distillation pipeline (#3284)

This commit is contained in:
ShobithNandakumar 2024-08-26 10:16:30 +05:30 committed by GitHub
Parent 326b05d104
Commit 7c566f9458
No key found matching this signature
GPG key ID: B5690EEEBB952194
12 changed files with 986 additions and 32 deletions

View file

@@ -0,0 +1,45 @@
## Data Generation Component
### Name
oss_distillation_generate_data
### Version
0.0.5
### Type
command
### Description
Component to generate data from the teacher model endpoint
## Inputs
| Name | Description | Type | Optional |
| ---- | ----------- | ---- | -------- |
| train_file_path | Path to the registered training data set in `jsonl`, `json`, `csv`, `tsv` or `parquet` format. | uri_file | False |
| validation_file_path | Path to the registered validation data set in `jsonl`, `json`, `csv`, `tsv` or `parquet` format. | uri_file | True |
| teacher_model_endpoint_name | Teacher model endpoint name. | string | True |
| teacher_model_endpoint_url | Teacher model endpoint URL. | string | True |
| teacher_model_endpoint_key | Teacher model endpoint key. | string | True |
| teacher_model_max_new_tokens | Teacher model max_new_tokens inference parameter. | integer | True |
| teacher_model_temperature | Teacher model temperature inference parameter. | number | True |
| teacher_model_top_p | Teacher model top_p inference parameter. | number | True |
| teacher_model_frequency_penalty | Teacher model frequency penalty inference parameter. | number | True |
| teacher_model_presence_penalty | Teacher model presence penalty inference parameter. | number | True |
| teacher_model_stop | Teacher model stop inference parameter. | string | True |
| request_batch_size | Number of data records to send to the teacher model endpoint in one request. | integer | True |
| min_endpoint_success_ratio | The minimum value of (successful_requests / total_requests) required for classifying inference as successful. | number | True |
| enable_chain_of_thought | Enable chain of thought for data generation. | string | True |
| data_generation_task_type | Data generation task type; supported values: NLI, CONVERSATION, NLU_QA. | string | False |
| validation_output | Validation status from the validation component. | uri_file | True |
## Outputs
| Name | Description | Type |
| -------------------- | -------------------------------------------------------- | ------------ |
| generated_train_file_path | Generated training data. | uri_file |
| generated_validation_file_path | Generated validation data. | uri_file |
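
A minimal sketch of pulling this component from the `azureml` registry with the `azure-ai-ml` SDK (credential setup is an assumption, not part of this commit):

```python
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential

# Client scoped to the azureml registry that hosts this component.
registry_client = MLClient(credential=DefaultAzureCredential(), registry_name="azureml")

# Fetch the data generation component by name and version.
generate_data = registry_client.components.get(
    name="oss_distillation_generate_data", version="0.0.5"
)
print(generate_data.display_name, generate_data.version)
```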

View file

@@ -1,6 +1,6 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
name: oss_distillation_generate_data
version: 0.0.4
version: 0.0.5
type: command
is_deterministic: True
@@ -8,7 +8,7 @@ is_deterministic: True
display_name: OSS Distillation Generate Data
description: Component to generate data from the teacher model endpoint
environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/63
environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/66
inputs:
# Inputs
@@ -97,7 +97,14 @@ inputs:
1. NLI: Generate Natural Language Inference data
2. CONVERSATION: Generate conversational data (multi/single turn)
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
# Output of validation component.
validation_output:
type: uri_file
optional: true
description: Validation status.
mode: rw_mount
outputs:
generated_train_file_path:
type: uri_file
@@ -108,7 +115,7 @@ outputs:
description: Generated validation data
mode: rw_mount
code: src/
code: ../../src
command: >-
python generate_data.py
--train_file_path ${{inputs.train_file_path}}

View file

@@ -1,6 +1,6 @@
$schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json
name: oss_distillation_pipeline
version: 0.0.5
version: 0.0.6
type: pipeline
@@ -9,6 +9,10 @@ description: Component to generate data from teacher model endpoint and finetune
inputs:
# Compute parameters
instance_type_pipeline_validation:
type: string
optional: true
description: Instance type to be used for validation component. The parameter compute_pipeline_validation must be set to 'serverless' for instance_type to be used.
instance_type_data_generation:
type: string
optional: true
@@ -25,6 +29,12 @@ inputs:
default: Singularity.ND96amrs_A100_v4
description: Instance type to be used for finetune component in case of virtual cluster compute, e.g. Singularity.ND40_v2. The parameter compute_finetune must be set to 'serverless' for instance_type to be used.
compute_pipeline_validation:
type: string
optional: true
default: 'serverless'
description: Compute to be used for the validation component
compute_data_generation:
type: string
optional: true
@@ -50,8 +60,8 @@ inputs:
compute is named 'FT-Cluster'. Special characters like \ and ' are invalid in the parameter value.
If compute cluster name is provided, instance_type field will be ignored and the respective cluster will be used
# ########################### Data Generator Component ########################### #
## OSS Data generator Input Parameters
train_file_path:
type: uri_file
description: Path to the registered training data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
@@ -138,7 +148,8 @@ inputs:
2. CONVERSATION: Generate conversational data (multi/single turn)
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
## OSS Finetune Input Parameters
# ########################### Finetuning Component ########################### #
number_of_gpu_to_use_finetuning:
type: integer
default: 1
@@ -203,6 +214,37 @@ outputs:
mode: rw_mount
jobs:
oss_distillation_validate_pipeline:
type: command
component: azureml:oss_distillation_validate_pipeline:0.0.1
compute: '${{parent.inputs.compute_pipeline_validation}}'
resources:
instance_type: '${{parent.inputs.instance_type_pipeline_validation}}'
identity:
type: user_identity
inputs:
train_file_path: '${{parent.inputs.train_file_path}}'
validation_file_path: '${{parent.inputs.validation_file_path}}'
teacher_model_endpoint_name: '${{parent.inputs.teacher_model_endpoint_name}}'
teacher_model_endpoint_url: '${{parent.inputs.teacher_model_endpoint_url}}'
teacher_model_endpoint_key: '${{parent.inputs.teacher_model_endpoint_key}}'
enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
teacher_model_max_new_tokens: '${{parent.inputs.teacher_model_max_new_tokens}}'
teacher_model_temperature: '${{parent.inputs.teacher_model_temperature}}'
teacher_model_top_p: '${{parent.inputs.teacher_model_top_p}}'
teacher_model_frequency_penalty: '${{parent.inputs.teacher_model_frequency_penalty}}'
teacher_model_presence_penalty: '${{parent.inputs.teacher_model_presence_penalty}}'
request_batch_size: '${{parent.inputs.request_batch_size}}'
min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
num_train_epochs: '${{parent.inputs.num_train_epochs}}'
per_device_train_batch_size: '${{parent.inputs.per_device_train_batch_size}}'
learning_rate: '${{parent.inputs.learning_rate}}'
outputs:
validation_info:
type: uri_file
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.json
oss_distillation_generate_data:
type: command
component: azureml:oss_distillation_generate_data:0.0.4
@@ -226,6 +268,7 @@ jobs:
teacher_model_presence_penalty: '${{parent.inputs.teacher_model_presence_penalty}}'
request_batch_size: '${{parent.inputs.request_batch_size}}'
min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
validation_output: '${{parent.jobs.oss_distillation_validate_pipeline.outputs.validation_info}}'
outputs:
generated_train_file_path:
type: uri_file

View file

@@ -0,0 +1,46 @@
## Pipeline Validation Component
### Name
oss_distillation_validate_pipeline
### Version
0.0.1
### Type
command
### Description
Component to validate all inputs to the distillation pipeline.
## Inputs
| Name | Description | Type | Optional |
| ---- | ----------- | ---- | -------- |
| train_file_path | Path to the registered training data set in `jsonl`, `json`, `csv`, `tsv` or `parquet` format. | uri_file | False |
| validation_file_path | Path to the registered validation data set in `jsonl`, `json`, `csv`, `tsv` or `parquet` format. | uri_file | True |
| teacher_model_endpoint_name | Teacher model endpoint name. | string | True |
| teacher_model_endpoint_url | Teacher model endpoint URL. | string | True |
| teacher_model_endpoint_key | Teacher model endpoint key. | string | True |
| teacher_model_max_new_tokens | Teacher model max_new_tokens inference parameter. | integer | True |
| teacher_model_temperature | Teacher model temperature inference parameter. | number | True |
| teacher_model_top_p | Teacher model top_p inference parameter. | number | True |
| teacher_model_frequency_penalty | Teacher model frequency penalty inference parameter. | number | True |
| teacher_model_presence_penalty | Teacher model presence penalty inference parameter. | number | True |
| teacher_model_stop | Teacher model stop inference parameter. | string | True |
| request_batch_size | Number of data records to send to the teacher model endpoint in one request. | integer | True |
| min_endpoint_success_ratio | The minimum value of (successful_requests / total_requests) required for classifying inference as successful. | number | True |
| enable_chain_of_thought | Enable chain of thought for data generation. | string | True |
| num_train_epochs | Number of training epochs. | integer | True |
| data_generation_task_type | Data generation task type; supported values: NLI, CONVERSATION, NLU_QA. | string | False |
| per_device_train_batch_size | Train batch size. | integer | True |
| learning_rate | Start learning rate. | number | True |
## Outputs
| Name | Description | Type |
| -------------------- | -------------------------------------------------------- | ------------ |
| validation_info | Validation status file. | uri_file |
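
On success the component writes a one-line JSON status to `validation_info`, which the data generation component consumes via its `validation_output` input. A sketch of reading it downstream (the local file name is illustrative, assuming the `uri_file` output has been mounted or downloaded):

```python
import json

# Illustrative path; in a pipeline run this is the mounted uri_file output.
with open("validation_info.json") as f:
    status = json.load(f)

# validate_pipeline.py writes {"validation_status": "ok"} when all checks pass.
assert status.get("validation_status") == "ok"
```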

View file

@@ -0,0 +1,3 @@
type: component
spec: spec.yaml
categories: ["Foundational Models", "Finetune", "Distillation"]

View file

@@ -0,0 +1,147 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
name: oss_distillation_validate_pipeline
version: 0.0.1
type: command
is_deterministic: true
display_name: OSS Distillation Validate Pipeline
description: Component to validate inputs to the distillation pipeline
environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/66
code: ../../src
inputs:
# Inputs
train_file_path:
type: uri_file
description: Path to the registered training data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
mode: rw_mount
validation_file_path:
type: uri_file
optional: true
description: Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
mode: rw_mount
teacher_model_endpoint_name:
type: string
optional: true
description: Teacher model endpoint name
teacher_model_endpoint_url:
type: string
optional: true
description: Teacher model endpoint URL
teacher_model_endpoint_key:
type: string
optional: true
description: Teacher model endpoint key
teacher_model_max_new_tokens:
type: integer
default: 128
description: Teacher model max_new_tokens inference parameter
teacher_model_temperature:
type: number
default: 0.2
description: Teacher model temperature inference parameter
teacher_model_top_p:
type: number
default: 0.1
description: Teacher model top_p inference parameter
teacher_model_frequency_penalty:
type: number
default: 0.0
description: Teacher model frequency penalty inference parameter
teacher_model_presence_penalty:
type: number
default: 0.0
description: Teacher model presence penalty inference parameter
teacher_model_stop:
type: string
optional: true
description: Teacher model stop inference parameter
request_batch_size:
type: integer
default: 10
description: Number of data records to send to the teacher model endpoint in one request
min_endpoint_success_ratio:
type: number
default: 0.7
description: >
The minimum value of (successful_requests / total_requests) required for classifying inference as successful.
If (successful_requests / total_requests) < min_endpoint_success_ratio, the experiment will be marked as failed.
By default it is 0.7 (0 means all requests are allowed to fail, while 1 means no request should fail).
enable_chain_of_thought:
type: string
default: "true"
description: Enable chain of thought for data generation
data_generation_task_type:
type: string
enum:
- NLI
- CONVERSATION
- NLU_QA
description: >
Data generation task type. Supported values are:
1. NLI: Generate Natural Language Inference data
2. CONVERSATION: Generate conversational data (multi/single turn)
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
num_train_epochs:
type: integer
default: 1
optional: true
description: Number of training epochs
per_device_train_batch_size:
type: integer
default: 1
optional: true
description: Train batch size
learning_rate:
type: number
default: 3e-04
optional: true
description: Start learning rate.
outputs:
validation_info:
type: uri_file
description: Validation status.
mode: rw_mount
command: >-
python validate_pipeline.py
--train_file_path ${{inputs.train_file_path}}
$[[--validation_file_path ${{inputs.validation_file_path}}]]
$[[--teacher_model_endpoint_name ${{inputs.teacher_model_endpoint_name}}]]
$[[--teacher_model_endpoint_url ${{inputs.teacher_model_endpoint_url}}]]
$[[--teacher_model_endpoint_key ${{inputs.teacher_model_endpoint_key}}]]
--teacher_model_max_new_tokens ${{inputs.teacher_model_max_new_tokens}}
--teacher_model_temperature ${{inputs.teacher_model_temperature}}
--teacher_model_top_p ${{inputs.teacher_model_top_p}}
--teacher_model_frequency_penalty ${{inputs.teacher_model_frequency_penalty}}
--teacher_model_presence_penalty ${{inputs.teacher_model_presence_penalty}}
$[[--teacher_model_stop ${{inputs.teacher_model_stop}}]]
--request_batch_size ${{inputs.request_batch_size}}
--min_endpoint_success_ratio ${{inputs.min_endpoint_success_ratio}}
--enable_chain_of_thought ${{inputs.enable_chain_of_thought}}
--data_generation_task_type ${{inputs.data_generation_task_type}}
$[[--num_train_epochs ${{inputs.num_train_epochs}}]]
$[[--per_device_train_batch_size ${{inputs.per_device_train_batch_size}}]]
$[[--learning_rate ${{inputs.learning_rate}}]]
--validation_info ${{outputs.validation_info}}

View file

@@ -5,6 +5,7 @@
import re
from enum import EnumMeta, Enum
# COMPONENT META
COMPONENT_NAME = "oss_distillation_generate_data"
@@ -52,6 +53,7 @@ HFTV2_TEXT_GEN_SCORE_PATH = "/score"
DEFAULT_SUCCESS_RATIO = 0.7
DEFAULT_REQUEST_BATCH_SIZE = 10
MAX_BATCH_SIZE = 100
MIN_RECORDS_FOR_FT = 65
# VLLM INFERENCE KEYS
TOP_P = "top_p"
@@ -114,3 +116,24 @@ class TelemetryConstants:
BATCH_PROCESS_TRAINING_DATA = "batch_process_training_data"
BATCH_PROCESS_VALIDATION_DATA = "batch_process_validation_data"
PROCESS_DATASET_RECORD = "process_dataset_record"
VALIDATOR = "validator"
ML_CLIENT_INITIALISATION = "ml_client_initialisation"
VALIDATE_DATA_GENERATION_INPUTS = "validate_data_generation_inputs"
VALIDATE_FILE_PATH = "validate_file_path"
VALIDATE_TEACHER_MODEL_ENDPOINT = "validate_teacher_model_endpoint"
VALIDATE_INFERENCE_PARAMETERS = "validate_inference_parameters"
VALIDATE_TRAINING_DATA = "validate_training_data"
VALIDATE_VALIDATION_DATA = "validate_validation_data"
VALIDATE_MODEL_INFERENCE = "validate_model_inference"
class BackoffConstants:
"""Defaults for retry with exponential backoff."""
MAX_RETRIES = 3
BASE_DELAY = 10
MAX_DELAY = 600
BACKOFF_FACTOR = 2
MAX_TIMEOUT_SEC = 180
RETRYABLE_STATUS_CODES = {413, 429, 500, 502, 503, 504, None}

View file

@@ -5,22 +5,28 @@
import os
import time
from typing import List, Tuple, Union, Optional, Callable
from urllib.parse import urlparse
from abc import ABC, abstractmethod
from azure.ai.ml import MLClient
from azure.ai.ml.identity import AzureMLOnBehalfOfCredential
from azure.ai.ml.entities import ManagedOnlineEndpoint, ManagedOnlineDeployment, ServerlessEndpoint
from azure.identity import AzureCliCredential, ManagedIdentityCredential
from azureml.acft.common_components.utils.error_handling.exceptions import ACFTValidationException
from azureml.acft.common_components.utils.error_handling.error_definitions import ACFTUserError
from azureml._common._error_definition.azureml_error import AzureMLError
from azureml.acft.common_components import get_logger_app
from azureml.core import Run, Workspace
from azureml.core.run import _OfflineRun
from typing import List, Tuple, Union
from common.constants import (
REQUESTS_RETRY_DELAY,
REGISTRY_MODEL_PATTERN,
SUPPORTED_STUDENT_MODEL_MAP,
SUPPORTED_TEACHER_MODEL_MAP,
BackoffConstants
)
@@ -348,3 +354,91 @@ def validate_student_model_details(model_asset_id: str) -> Tuple[str, str, str]:
Tuple[str, str, str]: Tuple containing registry name, model name and model version
"""
return _get_model_details(model_asset_id, SUPPORTED_STUDENT_MODEL_MAP)
def get_base_url(url: str) -> str:
"""Get base url."""
if not url:
return url
parse_result = urlparse(url)
return f"{parse_result.scheme}://{parse_result.netloc}"
def _get_status_code(e: Exception) -> Optional[int]:
"""
Get the status code from the exception.
:param e: Exception.
:return: Status code.
"""
status_code = getattr(e, "status_code", None)
if status_code is None and getattr(e, "response", None) is not None:
status_code = getattr(e.response, "status_code", None)
return status_code
def exponential_backoff(
max_retries: int = BackoffConstants.MAX_RETRIES,
base_delay: int = BackoffConstants.BASE_DELAY,
max_delay: int = BackoffConstants.MAX_DELAY,
backoff_factor: int = BackoffConstants.BACKOFF_FACTOR,
) -> Callable:
"""
Implement exponential backoff for retrying a function that makes an HTTP request.
Use this function as a decorator.
:param max_retries: Maximum number of retries.
:param base_delay: Base delay in seconds before the first retry.
:param max_delay: Maximum delay in seconds between retries.
:param backoff_factor: Multiplier applied to the delay after each retry.
:return: Decorated function.
"""
def decorator(func: Callable) -> Callable:
def wrapper(*args, **kwargs):
retries = 0
delay = base_delay
while retries <= max_retries:
try:
tick = time.time()
return func(*args, **kwargs)
except Exception as e:
tock = time.time()
status_code = _get_status_code(e)
if status_code not in BackoffConstants.RETRYABLE_STATUS_CODES:
raise ACFTValidationException._with_error(
AzureMLError.create(
ACFTUserError,
pii_safe_message=(
f"Encountered unknown status code: {status_code}. ",
)
)
)
retries += 1
if retries <= max_retries:
backoff_delay = min(delay, max_delay)
logger.info(
(
f"Retrying method `{func.__name__}` after {backoff_delay} sec. "
f"Retry attempt: {retries}/{max_retries}. "
f"Time spent: {round(tock - tick)} sec. "
f"Error details: {e}"
)
)
time.sleep(backoff_delay)
delay *= backoff_factor
else:
raise ACFTValidationException._with_error(
AzureMLError.create(
ACFTUserError,
pii_safe_message=(
f"Request failed after multiple tries with status code: {status_code}. ",
)
)
)
return wrapper
return decorator
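
A minimal usage sketch of the decorator (the probe function and URL are illustrative, not part of this commit): exceptions whose extracted status code falls in `BackoffConstants.RETRYABLE_STATUS_CODES` are retried with exponentially growing delay, anything else fails fast:

```python
import requests

from common.constants import BackoffConstants
from common.utils import exponential_backoff


@exponential_backoff(max_retries=3, base_delay=10)
def probe_endpoint(url: str) -> dict:
    # raise_for_status raises an HTTPError carrying the response, whose
    # status_code _get_status_code reads to decide retryability.
    response = requests.get(url, timeout=BackoffConstants.MAX_TIMEOUT_SEC)
    response.raise_for_status()
    return response.json()


# Illustrative call; substitute a real endpoint URL.
# info = probe_endpoint("https://example.com/info")
```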

View file

@@ -0,0 +1,133 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Component validation utils."""
from pathlib import Path
from typing import List, Optional
from azureml.acft.common_components.utils.error_handling.exceptions import (
ACFTValidationException,
)
from azureml.acft.common_components.utils.error_handling.error_definitions import (
ACFTUserError,
)
from azureml._common._error_definition.azureml_error import AzureMLError
from common.constants import SUPPORTED_FILE_FORMATS, MAX_BATCH_SIZE
def validate_file_paths_with_supported_formats(file_paths: List[Optional[str]]):
"""Check if the file path is in the list of supported formats."""
for file_path in file_paths:
if file_path:
file_suffix = Path(file_path).suffix.lower()
file_ext = file_suffix.split("?")[0]
if file_ext and file_ext not in SUPPORTED_FILE_FORMATS:
raise ACFTValidationException._with_error(
AzureMLError.create(
ACFTUserError,
pii_safe_message=(
f"{file_path} is not in list of supported file formats. "
f"Supported file formats: {SUPPORTED_FILE_FORMATS}"
),
)
)
def validate_file_exists(file_paths: List[Optional[str]]):
"""Check if the file paths exist."""
for file_path in file_paths:
if file_path:
file = Path(file_path)
if not file.exists():
raise ACFTValidationException._with_error(
AzureMLError.create(
ACFTUserError,
pii_safe_message=(f"File {file_path} does not exist."),
)
)
def validate_model_temperature(temperature: float):
"""Validate if model temperature is well within limits."""
if temperature < 0 or temperature > 1:
raise ACFTValidationException._with_error(
AzureMLError.create(
ACFTUserError,
pii_safe_message=(
"Invalid teacher_model_temperature. ",
f"Value should 0<=val<=1, but is {temperature}",
),
)
)
def validate_model_top_p(top_p: float):
"""Validate if model top_p is well within limits."""
if top_p < 0 or top_p > 1:
raise ACFTValidationException._with_error(
AzureMLError.create(
ACFTUserError,
pii_safe_message=(
"Invalid teacher_model_top_p. ",
f"Value should be 0<=val<=1, but is {top_p}",
),
)
)
def validate_model_frequency_penalty(val: float):
"""Validate if model frequency penalty is well within limits."""
if val and (val < 0 or val > 2):
raise ACFTValidationException._with_error(
AzureMLError.create(
ACFTUserError,
pii_safe_message=(
"Invalid teacher_model_frequency_penalty. ",
f"Value should be 0<=val<=2, but is {val}",
),
)
)
def validate_model_presence_penalty(val: float):
"""Validate if model presence penalty is well within limits."""
if val and (val < 0 or val > 2):
raise ACFTValidationException._with_error(
AzureMLError.create(
ACFTUserError,
pii_safe_message=(
"Invalid teacher_model_presence_penalty. ",
f"Value should be 0<=val<=2, but is {val}",
),
)
)
def validate_request_batch_size(val: int):
"""Validate if requested batch size is well within limits."""
if val and (val <= 0 or val > MAX_BATCH_SIZE):
raise ACFTValidationException._with_error(
AzureMLError.create(
ACFTUserError,
pii_safe_message=(
"Invalid request_batch_size. ",
f"Value should be 0<=val<={MAX_BATCH_SIZE}, but is {val}",
),
)
)
def validate_min_endpoint_success_ratio(val: int):
"""Validate if requested endpoint success ration is well within limits."""
if val and (val < 0 or val > 1):
raise ACFTValidationException._with_error(
AzureMLError.create(
ACFTUserError,
pii_safe_message=(
"Invalid min_endpoint_success_ration. ",
f"Value sould be 0<=val<=1, but is {val}",
),
)
)
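
A short sketch of exercising these helpers directly (values are illustrative): each validator raises `ACFTValidationException` on the first violation and returns `None` otherwise:

```python
from common.validation import (
    validate_file_paths_with_supported_formats,
    validate_model_temperature,
    validate_request_batch_size,
)

# None entries are skipped; an unsupported extension raises ACFTValidationException.
validate_file_paths_with_supported_formats(["train.jsonl", None])
validate_model_temperature(0.2)   # passes: 0 <= 0.2 <= 1
validate_request_batch_size(10)   # passes: 0 < 10 <= MAX_BATCH_SIZE
```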

View file

@@ -12,17 +12,14 @@ import requests
from argparse import Namespace
from requests import Response
from pathlib import Path
from typing import List, Optional
from typing import List
from azureml.acft.contrib.hf import VERSION, PROJECT_NAME
from azureml.acft.contrib.hf.nlp.constants.constants import LOGS_TO_BE_FILTERED_IN_APPINSIGHTS
from azureml.acft.common_components import get_logger_app, set_logging_parameters, LoggingLiterals
from azureml.acft.common_components.utils.error_handling.exceptions import ACFTValidationException
from azureml.acft.common_components.utils.error_handling.error_definitions import ACFTUserError
from azureml.acft.common_components.utils.error_handling.swallow_all_exceptions_decorator import (
swallow_all_exceptions,
)
from azureml._common._error_definition.azureml_error import AzureMLError
from azureml.telemetry.activity import log_activity, monitor_with_activity
from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -42,7 +39,6 @@ from common.constants import (
TEMPERATURE,
TOP_P,
STOP_TOKEN,
SUPPORTED_FILE_FORMATS,
VLLM_CHAT_SCORE_PATH,
DataGenerationTaskType,
TelemetryConstants
@@ -55,6 +51,9 @@ from common.utils import (
retry,
)
from common.validation import (
validate_file_paths_with_supported_formats
)
logger = get_logger_app("azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import")
@@ -229,24 +228,6 @@ def _invoke_endpoint(url: str, key: str, data: dict, log_entry: dict = None) ->
return requests.post(url, headers=request_headers, data=json.dumps(data))
def _validate_file_paths_with_supported_formats(file_paths: List[Optional[str]]):
"""Check if the file path is in the list of supported formats."""
for file_path in file_paths:
if file_path:
file_suffix = Path(file_path).suffix.lower()
file_ext = file_suffix.split('?')[0]
if file_ext and file_ext not in SUPPORTED_FILE_FORMATS:
raise ACFTValidationException._with_error(
AzureMLError.create(
ACFTUserError,
pii_safe_message=(
f"{file_path} is not in list of supported file formats. "
f"Supported file formats: {SUPPORTED_FILE_FORMATS}"
)
)
)
def generate_synthetic_data(
teacher_model_endpoint_url: str,
teacher_model_endpoint_key: str,
@@ -504,7 +485,7 @@ def data_import(args: Namespace):
data_generation_task_type = args.data_generation_task_type
# validate file formats
_validate_file_paths_with_supported_formats([args.train_file_path, args.validation_file_path])
validate_file_paths_with_supported_formats([args.train_file_path, args.validation_file_path])
logger.info("File format validation successful.")
enable_cot = enable_cot_str.lower() == "true"

View file

@@ -0,0 +1,432 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Script for validating distillation pipeline arguments."""
import logging
import requests
import pandas as pd
import json
from argparse import Namespace
from azureml.acft.contrib.hf import VERSION, PROJECT_NAME
from azureml.acft.contrib.hf.nlp.constants.constants import (
LOGS_TO_BE_FILTERED_IN_APPINSIGHTS,
)
from azureml.acft.common_components import (
get_logger_app,
set_logging_parameters,
LoggingLiterals,
)
from azureml.acft.common_components.utils.error_handling.exceptions import (
ACFTValidationException,
)
from azureml.acft.common_components.utils.error_handling.error_definitions import (
ACFTUserError,
)
from azureml.telemetry.activity import log_activity
from azureml._common._error_definition.azureml_error import AzureMLError
from azureml.acft.common_components.utils.error_handling.swallow_all_exceptions_decorator import (
swallow_all_exceptions,
)
from generate_data import get_parser
from common.constants import (
DataGenerationTaskType,
TelemetryConstants,
MAX_NEW_TOKENS,
TEMPERATURE,
TOP_P,
VLLM_CHAT_SCORE_PATH,
MIN_RECORDS_FOR_FT,
)
from common.utils import (
get_endpoint_details,
get_workspace_mlclient,
get_base_url,
validate_teacher_model_details,
exponential_backoff,
)
from common.validation import (
validate_file_paths_with_supported_formats,
validate_file_exists,
validate_model_temperature,
validate_model_top_p,
validate_model_frequency_penalty,
validate_model_presence_penalty,
validate_request_batch_size,
validate_min_endpoint_success_ratio,
)
logger = get_logger_app(
"azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import"
)
COMPONENT_NAME = "oss_distillation_validate_pipeline"
class PipelineInputsValidator:
"""Dataclass for validating inputs to distillation pipeline."""
def __init__(self, args: Namespace) -> None:
"""Initialise validator.
Args:
args (Namespace): Input flags to validate.
"""
self._args = args
with log_activity(
logger=logger, activity_name=TelemetryConstants.ML_CLIENT_INITIALISATION
):
ws_mlclient = get_workspace_mlclient()
if not ws_mlclient:
raise Exception("Could not create MLClient for current workspace")
self._mlclient = ws_mlclient
with log_activity(
logger=logger,
activity_name=TelemetryConstants.VALIDATE_DATA_GENERATION_INPUTS,
):
self._validate_data_generation_inputs()
def _get_dataframe(self, file_path: str):
return pd.read_json(
file_path, lines=True, chunksize=self._args.request_batch_size
)
def _get_inference_request_headers(self) -> dict:
key = self._args.teacher_model_endpoint_key
return {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
def _get_cot_status(self) -> bool:
cot_enabled = self._args.enable_chain_of_thought
return cot_enabled.lower() == "true"
def _validate_model_endpoint_args(self):
endpoint_name = self._args.teacher_model_endpoint_name
if endpoint_name:
endpoint_details = get_endpoint_details(
mlclient_ws=self._mlclient, endpoint_name=endpoint_name
)
self._args.teacher_model_endpoint_url = endpoint_details.get_endpoint_url()
self._args.teacher_model_endpoint_key = endpoint_details.get_endpoint_key()
model_asset_id = endpoint_details.get_deployed_model_id()
validate_teacher_model_details(model_asset_id)
if (
not self._args.teacher_model_endpoint_url
or not self._args.teacher_model_endpoint_key
):
raise ACFTValidationException._with_error(
AzureMLError.create(
ACFTUserError,
pii_safe_message=(
"Endpoint URL and key are required fields for data generation."
),
)
)
@exponential_backoff()
def _validate_model_endpoint(self):
"""Validate model endpoints availability by retrieving its details."""
base_url = get_base_url(self._args.teacher_model_endpoint_url)
request_headers = self._get_inference_request_headers()
# https://learn.microsoft.com/en-us/azure/machine-learning/reference-model-inference-info
response = requests.get(url=f"{base_url}/info", headers=request_headers)
response.raise_for_status()
response_data = response.json()
model_name = response_data.get("model_name")
logger.info(f"Model validated, model name - {model_name}")
@exponential_backoff()
def _validate_model_inference(self):
"""Validate a sample inference call.
Raises:
HTTPError: If one occurred.
"""
# Prep data.
df = self._get_dataframe(file_path=self._args.train_file_path)
batch = next(df)
record = batch.iloc[0].to_dict()
# Build inference payload
inference_params = {
MAX_NEW_TOKENS: self._args.teacher_model_max_new_tokens,
TEMPERATURE: self._args.teacher_model_temperature,
TOP_P: self._args.teacher_model_top_p,
**record,
}
headers = self._get_inference_request_headers()
url = self._args.teacher_model_endpoint_url
url = url if VLLM_CHAT_SCORE_PATH in url else f"{url}{VLLM_CHAT_SCORE_PATH}"
logger.info(f"Model endpoint: {url}")
response = requests.post(
url=url, headers=headers, data=json.dumps(inference_params)
)
response.raise_for_status()
def _validate_inference_parameters(self):
"""Validate all body parameters passed as part of inference."""
validate_model_temperature(self._args.teacher_model_temperature)
validate_model_top_p(self._args.teacher_model_top_p)
validate_model_presence_penalty(self._args.teacher_model_presence_penalty)
validate_model_frequency_penalty(self._args.teacher_model_frequency_penalty)
validate_request_batch_size(self._args.request_batch_size)
validate_min_endpoint_success_ratio(self._args.min_endpoint_success_ratio)
def _validate_number_of_records(self, size: int):
"""Validate number of records in the dataset."""
if size < MIN_RECORDS_FOR_FT:
raise ACFTValidationException._with_error(
AzureMLError.create(
ACFTUserError,
pii_safe_message=(
"Number of records in the dataset are less than the minimum required for fine-tuning."
f" Minimum records required: {MIN_RECORDS_FOR_FT}, but got {size}."
),
)
)
def _validate_record_for_type_conversation(self, record: list) -> str:
if self._args.data_generation_task_type != DataGenerationTaskType.CONVERSATION:
return
if self._get_cot_status():
return f"Chain of thought is not supported for task type {DataGenerationTaskType.CONVERSATION}"
if len(record) < 3:
return f"Dataset is not matching expected schema for task type {DataGenerationTaskType.CONVERSATION}. \
Expected format: [system, user, assistant]"
def _validate_record_for_type_NLI(self, record: list) -> str:
if self._args.data_generation_task_type != DataGenerationTaskType.NLI:
return
if len(record) > 2:
return f"Chat cannot be of type multi-turn for task type {DataGenerationTaskType.NLI}. \
Expected format: [system, user]"
def _validate_record_for_type_NLU_QA(self, record: list) -> str:
if (
self._args.data_generation_task_type
!= DataGenerationTaskType.NLU_QUESTION_ANSWERING
):
return
if len(record) > 2:
return f"Chat cannot be of type multi-turn for task type {DataGenerationTaskType.NLU_QUESTION_ANSWERING} \
Expected format: [system, user]"
def _validate_record_by_task(self, record: list) -> dict:
"""
Validate record in a dataset against the data generation task type.
Returns a dictionary containing exception if any validation error is found.
Args:
record (list): Sequence of messages
"""
validation_methods = [
self._validate_record_for_type_NLI,
self._validate_record_for_type_conversation,
self._validate_record_for_type_NLU_QA,
]
for method in validation_methods:
err = method(record=record)
if err:
return {"exception": err}
def _validate_message(self, id: int, message: dict) -> str:
"""
Validate an individual message in the dataset.
Returns an error string if any validation error is found.
Args:
id (int): id of the message in sequence of messages.
message (dict): Message object in sequence of messages.
"""
allowed_roles = ["system", "user", "assistant"]
if "role" not in message:
return f"Message at index {id} is missing 'role'."
if message["role"] not in allowed_roles:
return f"Invalid 'role' at index {id}."
if "content" not in message:
return f"Message at index {id} is missing 'content'."
def _validate_record_content(self, record: list) -> dict:
"""
Validate content of a record and ensure messages are in the expected format.
Currently functional only for task types `CONVERSATION`, `NLI` and `NLU_QA`.
Returns dictionary containing exception, if any validation error is found.
Args:
record (list): Sequence of messages
"""
try:
if record[0].get("role") != "system":
role = record[0].get("role")
return {
"exception": f"First message should be of role 'system' but got {role}."
}
expected_roles = ["user", "assistant"]
for id, message in enumerate(record[1:], start=1):
if not isinstance(message, dict):
return {
"exception": f"Message at index {id} should be a dictionary."
}
err = self._validate_message(id=id, message=message)
if err:
return {"exception": err}
expected_role = expected_roles[(id - 1) % 2]
if message.get("role") != expected_role:
return {
"exception": f"Role at index {id} should be {expected_role}."
}
task_type = self._args.data_generation_task_type
if task_type == DataGenerationTaskType.CONVERSATION and (
len(record[1:]) % 2 != 0
):
return {
"exception": "There is an incomplete pair of 'user' and 'assistant' messages."
}
except Exception as e:
return {"exception": e}
def _validate_dataset_record(self, record: list) -> str:
"""Validate a record in the dataset. Returns the validation error if found.
Args:
record (list): Sequence of messages
"""
if not record:
return "Chat cannot be empty."
err = self._validate_record_by_task(record=record)
if err and ("exception" in err):
return err["exception"]
err = self._validate_record_content(record=record)
if err and ("exception" in err):
return err["exception"]
def _validate_dataset(self, file_path: str):
"""Validate training/validation dataset passed to the data-generation component.
Args:
file_path (str): Path to the dataset
Raises:
ACFTValidationException: If a known validation error is caught
"""
df = self._get_dataframe(file_path=file_path)
total_rows = 0
for batch in df:
total_rows += len(batch)
for idx, row in batch.iterrows():
record = row.iloc[0]
err = self._validate_dataset_record(record=record)
if err:
raise ACFTValidationException._with_error(
AzureMLError.create(
ACFTUserError,
pii_safe_message=(
f"Error validating dataset record, context({idx}): {err}"
),
)
)
self._validate_number_of_records(size=total_rows)
def _validate_data_generation_inputs(self):
"""Validate all input flags to the data-generation component.
Sequentially performs a set of validations, each dependent on the previous validation.
1. Validate training/validation file paths and ensure files exist.
2. Validate teacher model endpoint arguments are passed for inference, and
authenticity of the endpoint.
3. Validate that the passed inference parameters are within limits.
4. Validate integrity of datasets.
5. Validate a single inference call to the teacher model.
"""
with log_activity(
logger=logger, activity_name=TelemetryConstants.VALIDATE_FILE_PATH
):
files = [self._args.train_file_path, self._args.validation_file_path]
validate_file_paths_with_supported_formats(file_paths=files)
validate_file_exists(file_paths=files)
with log_activity(
logger=logger,
activity_name=TelemetryConstants.VALIDATE_TEACHER_MODEL_ENDPOINT,
):
self._validate_model_endpoint_args()
self._validate_model_endpoint()
with log_activity(
logger=logger,
activity_name=TelemetryConstants.VALIDATE_INFERENCE_PARAMETERS,
):
self._validate_inference_parameters()
with log_activity(
logger=logger, activity_name=TelemetryConstants.VALIDATE_TRAINING_DATA
):
self._validate_dataset(self._args.train_file_path)
if self._args.validation_file_path:
with log_activity(
logger=logger, activity_name=TelemetryConstants.VALIDATE_VALIDATION_DATA
):
self._validate_dataset(self._args.validation_file_path)
with log_activity(
logger=logger, activity_name=TelemetryConstants.VALIDATE_MODEL_INFERENCE
):
self._validate_model_inference()
@swallow_all_exceptions(time_delay=5)
def main():
"""Run validation."""
# Get data generation component input parameters.
parser = get_parser()
parser.add_argument("--validation_info", required=True, help="Validation status")
args, _ = parser.parse_known_args()
set_logging_parameters(
task_type="DistillationPipelineValidation",
acft_custom_dimensions={
LoggingLiterals.PROJECT_NAME: PROJECT_NAME,
LoggingLiterals.PROJECT_VERSION_NUMBER: VERSION,
LoggingLiterals.COMPONENT_NAME: COMPONENT_NAME,
},
azureml_pkg_denylist_logging_patterns=LOGS_TO_BE_FILTERED_IN_APPINSIGHTS,
log_level=logging.INFO,
)
with log_activity(logger=logger, activity_name=TelemetryConstants.VALIDATOR):
PipelineInputsValidator(args=args)
if args.validation_info:
with open(args.validation_info, "w") as f:
f.write(json.dumps({"validation_status": "ok"}))
if __name__ == "__main__":
main()
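
For reference, a record that passes these dataset checks for `data_generation_task_type=CONVERSATION` (content is illustrative; chain of thought must be disabled for this task type) is a leading `system` message followed by complete, alternating `user`/`assistant` pairs:

```python
# Illustrative JSONL record for the CONVERSATION task type.
record = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is knowledge distillation?"},
    {"role": "assistant", "content": "Training a small student model on a teacher model's outputs."},
]
```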