feat: introduce validation component for distillation pipeline (#3284)
This commit is contained in:
Родитель
326b05d104
Коммит
7c566f9458
|
@ -0,0 +1,45 @@
|
|||
## Data Generation Component
|
||||
|
||||
### Name
|
||||
|
||||
oss_distillation_generate_data
|
||||
|
||||
### Version
|
||||
|
||||
0.0.5
|
||||
|
||||
### Type
|
||||
|
||||
command
|
||||
|
||||
### Description
|
||||
|
||||
Component to generate data from teacher model endpoint
|
||||
|
||||
## Inputs
|
||||
|
||||
| Name | Description | Type | Optional |
|
||||
|--------------------| ----------------------------------------------------------------------------------- | ------- | ------- |
|
||||
| train_file_path | Path to the registered training data set in `jsonl, json, csv, tsv and parquet` format. | uri_file | True |
|
||||
| validation_file_path | Path to the registered validation data set in `jsonl, json, csv, tsv and parquet` format. | uri_file | True |
|
||||
| teacher_model_endpoint_name | Teacher model endpoint name. | string | True
|
||||
| teacher_model_endpoint_url | Teacher model endpoint URL. | string | True
|
||||
| teacher_model_endpoint_key | Teacher model endpoint key. | string | True
|
||||
| teacher_model_max_new_tokens | Teacher model max_new_tokens inference parameter. | integer | True
|
||||
| teacher_model_temperature | Teacher model temperature inference parameter. | number | True
|
||||
| teacher_model_top_p | Teacher model top_p inference parameter. | number | True |
|
||||
| teacher_model_frequency_penalty | Teacher model frequency penalty inference parameter. | number | True |
|
||||
| teacher_model_presence_penalty | Teacher model presence penalty inference parameter. | number | True
|
||||
| teacher_model_stop | Teacher model stop inference parameter. | string | True
|
||||
| request_batch_size | No of data records to hit teacher model endpoint in one go. | integer | True
|
||||
| min_endpoint_success_ratio | The minimum value of (successful_requests / total_requests) required for classifying inference as successful. | number | True
|
||||
| enable_chain_of_thought | Enable Chain of thought for data generation. | string | True
|
||||
| data_generation_task_type | Data generation task types, supported values - NLI, CONVERSATION, NLU_QA. | string | False
|
||||
| validation_output | Validation status from validation component. | uri_file | True
|
||||
|
||||
## Outputs
|
||||
|
||||
| Name | Description | Type |
|
||||
| -------------------- | -------------------------------------------------------- | ------------ |
|
||||
| generated_train_file_path | Generated training data. | uri_file |
|
||||
| generated_validation_file_path | Generated validation data. | uri_file |
|
|
@ -1,6 +1,6 @@
|
|||
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
|
||||
name: oss_distillation_generate_data
|
||||
version: 0.0.4
|
||||
version: 0.0.5
|
||||
type: command
|
||||
|
||||
is_deterministic: True
|
||||
|
@ -8,7 +8,7 @@ is_deterministic: True
|
|||
display_name: OSS Distillation Generate Data
|
||||
description: Component to generate data from teacher model enpoint
|
||||
|
||||
environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/63
|
||||
environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/66
|
||||
|
||||
inputs:
|
||||
# Inputs
|
||||
|
@ -97,7 +97,14 @@ inputs:
|
|||
1. NLI: Generate Natural Language Inference data
|
||||
2. CONVERSATION: Generate conversational data (multi/single turn)
|
||||
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
|
||||
|
||||
|
||||
# Output of validation component.
|
||||
validation_output:
|
||||
type: uri_file
|
||||
optional: true
|
||||
description: Validation status.
|
||||
mode: rw_mount
|
||||
|
||||
outputs:
|
||||
generated_train_file_path:
|
||||
type: uri_file
|
||||
|
@ -108,7 +115,7 @@ outputs:
|
|||
description: Generated validation data
|
||||
mode: rw_mount
|
||||
|
||||
code: src/
|
||||
code: ../../src
|
||||
command: >-
|
||||
python generate_data.py
|
||||
--train_file_path ${{inputs.train_file_path}}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
$schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json
|
||||
name: oss_distillation_pipeline
|
||||
version: 0.0.5
|
||||
version: 0.0.6
|
||||
type: pipeline
|
||||
|
||||
|
||||
|
@ -9,6 +9,10 @@ description: Component to generate data from teacher model enpoint and finetune
|
|||
|
||||
inputs:
|
||||
# Compute parameters
|
||||
instance_type_pipeline_validation:
|
||||
type: string
|
||||
optional: True
|
||||
description: Instance type to be used for validation component. The parameter compute_pipeline_validation must be set to 'serverless' for instance_type to be used.
|
||||
instance_type_data_generation:
|
||||
type: string
|
||||
optional: true
|
||||
|
@ -25,6 +29,12 @@ inputs:
|
|||
default: Singularity.ND96amrs_A100_v4
|
||||
description: Instance type to be used for finetune component in case of virtual cluster compute, eg. Singularity.ND40_v2. The parameter compute_finetune must be set to 'serverless' for instance_type to be used
|
||||
|
||||
compute_pipeline_validation:
|
||||
type: string
|
||||
optional: True
|
||||
default: 'serverless'
|
||||
description: compute to be used for validation component
|
||||
|
||||
compute_data_generation:
|
||||
type: string
|
||||
optional: true
|
||||
|
@ -50,8 +60,8 @@ inputs:
|
|||
compute is named 'FT-Cluster'. Special characters like \ and ' are invalid in the parameter value.
|
||||
If compute cluster name is provided, instance_type field will be ignored and the respective cluster will be used
|
||||
|
||||
# ########################### Data Generator Component ########################### #
|
||||
|
||||
## OSS Data generator Input Parameters
|
||||
train_file_path:
|
||||
type: uri_file
|
||||
description: Path to the registered training data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
|
||||
|
@ -138,7 +148,8 @@ inputs:
|
|||
2. CONVERSATION: Generate conversational data (multi/single turn)
|
||||
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
|
||||
|
||||
## OSS Finetune Input Parameters
|
||||
# ########################### Finetuning Component ########################### #
|
||||
|
||||
number_of_gpu_to_use_finetuning:
|
||||
type: integer
|
||||
default: 1
|
||||
|
@ -203,6 +214,37 @@ outputs:
|
|||
mode: rw_mount
|
||||
|
||||
jobs:
|
||||
oss_distillation_validate_pipeline:
|
||||
type: command
|
||||
component: azureml:oss_distillation_validate_pipeline:0.0.1
|
||||
compute: '${{parent.inputs.compute_pipeline_validation}}'
|
||||
resources:
|
||||
instance_type: '${{parent.inputs.instance_type_pipeline_validation}}'
|
||||
identity:
|
||||
type: user_identity
|
||||
inputs:
|
||||
train_file_path: '${{parent.inputs.train_file_path}}'
|
||||
validation_file_path: '${{parent.inputs.validation_file_path}}'
|
||||
teacher_model_endpoint_name: '${{parent.inputs.teacher_model_endpoint_name}}'
|
||||
teacher_model_endpoint_url: '${{parent.inputs.teacher_model_endpoint_url}}'
|
||||
teacher_model_endpoint_key: '${{parent.inputs.teacher_model_endpoint_key}}'
|
||||
enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
|
||||
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
|
||||
teacher_model_max_new_tokens: '${{parent.inputs.teacher_model_max_new_tokens}}'
|
||||
teacher_model_temperature: '${{parent.inputs.teacher_model_temperature}}'
|
||||
teacher_model_top_p: '${{parent.inputs.teacher_model_top_p}}'
|
||||
teacher_model_frequency_penalty: '${{parent.inputs.teacher_model_frequency_penalty}}'
|
||||
teacher_model_presence_penalty: '${{parent.inputs.teacher_model_presence_penalty}}'
|
||||
request_batch_size: '${{parent.inputs.request_batch_size}}'
|
||||
min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
|
||||
num_train_epochs: '${{parent.inputs.num_train_epochs}}'
|
||||
per_device_train_batch_size: '${{parent.inputs.per_device_train_batch_size}}'
|
||||
learning_rate: '${{parent.inputs.learning_rate}}'
|
||||
outputs:
|
||||
validation_info:
|
||||
type: uri_file
|
||||
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.json
|
||||
|
||||
oss_distillation_generate_data:
|
||||
type: command
|
||||
component: azureml:oss_distillation_generate_data:0.0.4
|
||||
|
@ -226,6 +268,7 @@ jobs:
|
|||
teacher_model_presence_penalty: '${{parent.inputs.teacher_model_presence_penalty}}'
|
||||
request_batch_size: '${{parent.inputs.request_batch_size}}'
|
||||
min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
|
||||
validation_output: '${{parent.jobs.oss_distillation_validate_pipeline.outputs.validation_info}}'
|
||||
outputs:
|
||||
generated_train_file_path:
|
||||
type: uri_file
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
## Pipeline Validation Component
|
||||
|
||||
### Name
|
||||
|
||||
oss_distillation_validate_pipeline
|
||||
|
||||
### Version
|
||||
|
||||
0.0.1
|
||||
|
||||
### Type
|
||||
|
||||
command
|
||||
|
||||
### Description
|
||||
|
||||
Component to validate all inputs to the distillation pipeline.
|
||||
|
||||
## Inputs
|
||||
|
||||
| Name | Description | Type | Optional |
|
||||
|--------------------| ----------------------------------------------------------------------------------- | ------- | ------- |
|
||||
| train_file_path | Path to the registered training data set in `jsonl, json, csv, tsv and parquet` format. | uri_file | False |
|
||||
| validation_file_path | Path to the registered validation data set in `jsonl, json, csv, tsv and parquet` format. | uri_file | True |
|
||||
| teacher_model_endpoint_name | Teacher model endpoint name. | string | True
|
||||
| teacher_model_endpoint_url | Teacher model endpoint URL. | string | True
|
||||
| teacher_model_endpoint_key | Teacher model endpoint key. | string | True
|
||||
| teacher_model_max_new_tokens | Teacher model max_new_tokens inference parameter. | integer | True
|
||||
| teacher_model_temperature | Teacher model temperature inference parameter. | number | True
|
||||
| teacher_model_top_p | Teacher model top_p inference parameter. | number | True |
|
||||
| teacher_model_frequency_penalty | Teacher model frequency penalty inference parameter. | number | True |
|
||||
| teacher_model_presence_penalty | Teacher model presence penalty inference parameter. | number | True
|
||||
| teacher_model_stop | Teacher model stop inference parameter. | string | True
|
||||
| request_batch_size | No of data records to hit teacher model endpoint in one go. | integer | True
|
||||
| min_endpoint_success_ratio | The minimum value of (successful_requests / total_requests) required for classifying inference as successful. | number | True
|
||||
| enable_chain_of_thought | Enable Chain of thought for data generation. | string | True
|
||||
| num_train_epochs | Number of training epochs. | integer | True |
|
||||
| data_generation_task_type | Data generation task types, supported values - NLI, CONVERSATION, NLU_QA. | string | False
|
||||
| per_device_train_batch_size | Train batch size. | integer | True
|
||||
| learning_rate | Start learning rate. | number | True
|
||||
|
||||
## Outputs
|
||||
|
||||
| Name | Description | Type |
|
||||
| -------------------- | -------------------------------------------------------- | ------------ |
|
||||
| validation_info | Validation status file. | uri_file |
|
|
@ -0,0 +1,3 @@
|
|||
type: component
|
||||
spec: spec.yaml
|
||||
categories: ["Foundational Models", "Finetune", "Distillation"]
|
|
@ -0,0 +1,147 @@
|
|||
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
|
||||
name: oss_distillation_validate_pipeline
|
||||
version: 0.0.1
|
||||
type: command
|
||||
|
||||
is_deterministic: true
|
||||
|
||||
display_name: OSS Distillation Validate Pipeline
|
||||
description: Component to validate inputs to the distillation pipeline
|
||||
|
||||
environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/66
|
||||
|
||||
code: ../../src
|
||||
|
||||
inputs:
|
||||
# Inputs
|
||||
train_file_path:
|
||||
type: uri_file
|
||||
description: Path to the registered training data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
|
||||
mode: rw_mount
|
||||
|
||||
validation_file_path:
|
||||
type: uri_file
|
||||
optional: true
|
||||
description: Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
|
||||
mode: rw_mount
|
||||
|
||||
teacher_model_endpoint_name:
|
||||
type: string
|
||||
optional: true
|
||||
description: Teacher model endpoint name
|
||||
|
||||
teacher_model_endpoint_url:
|
||||
type: string
|
||||
optional: true
|
||||
description: Teacher model endpoint URL
|
||||
|
||||
teacher_model_endpoint_key:
|
||||
type: string
|
||||
optional: true
|
||||
description: Teacher model endpoint key
|
||||
|
||||
teacher_model_max_new_tokens:
|
||||
type: integer
|
||||
default: 128
|
||||
description: Teacher model max_new_tokens inference parameter
|
||||
|
||||
teacher_model_temperature:
|
||||
type: number
|
||||
default: 0.2
|
||||
description: Teacher model temperature inference parameter
|
||||
|
||||
teacher_model_top_p:
|
||||
type: number
|
||||
default: 0.1
|
||||
description: Teacher model top_p inference parameter
|
||||
|
||||
teacher_model_frequency_penalty:
|
||||
type: number
|
||||
default: 0.0
|
||||
description: Teacher model frequency penalty inference parameter
|
||||
|
||||
teacher_model_presence_penalty:
|
||||
type: number
|
||||
default: 0.0
|
||||
description: Teacher model presence penalty inference parameter
|
||||
|
||||
teacher_model_stop:
|
||||
type: string
|
||||
optional: true
|
||||
description: Teacher model stop inference parameter
|
||||
|
||||
request_batch_size:
|
||||
type: integer
|
||||
default: 10
|
||||
description: No of data records to hit teacher model endpoint in one go
|
||||
|
||||
min_endpoint_success_ratio:
|
||||
type: number
|
||||
default: 0.7
|
||||
description: >
|
||||
The minimum value of (successful_requests / total_requests) required for classifying inference as successful.
|
||||
If (successful_requests / total_requests) < min_endpoint_success_ratio, the experiment will be marked as failed.
|
||||
By default it is 0.7 (0 means all requests are allowed to fail while 1 means no request should fail.)
|
||||
|
||||
enable_chain_of_thought:
|
||||
type: string
|
||||
default: "true"
|
||||
description: Enable Chain of thought for data generation
|
||||
|
||||
data_generation_task_type:
|
||||
type: string
|
||||
enum:
|
||||
- NLI
|
||||
- CONVERSATION
|
||||
- NLU_QA
|
||||
description: >
|
||||
Data generation task type. Supported values are:
|
||||
1. NLI: Generate Natural Language Inference data
|
||||
2. CONVERSATION: Generate conversational data (multi/single turn)
|
||||
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
|
||||
|
||||
num_train_epochs:
|
||||
type: integer
|
||||
default: 1
|
||||
optional: true
|
||||
description: training epochs
|
||||
|
||||
per_device_train_batch_size:
|
||||
type: integer
|
||||
default: 1
|
||||
optional: true
|
||||
description: Train batch size
|
||||
|
||||
learning_rate:
|
||||
type: number
|
||||
default: 3e-04
|
||||
optional: true
|
||||
description: Start learning rate.
|
||||
|
||||
outputs:
|
||||
validation_info:
|
||||
type: uri_file
|
||||
description: Validation status.
|
||||
mode: rw_mount
|
||||
|
||||
command: >-
|
||||
python validate_pipeline.py
|
||||
--train_file_path ${{inputs.train_file_path}}
|
||||
$[[--validation_file_path ${{inputs.validation_file_path}}]]
|
||||
$[[--teacher_model_endpoint_name ${{inputs.teacher_model_endpoint_name}}]]
|
||||
$[[--teacher_model_endpoint_url ${{inputs.teacher_model_endpoint_url}}]]
|
||||
$[[--teacher_model_endpoint_key ${{inputs.teacher_model_endpoint_key}}]]
|
||||
--teacher_model_max_new_tokens ${{inputs.teacher_model_max_new_tokens}}
|
||||
--teacher_model_temperature ${{inputs.teacher_model_temperature}}
|
||||
--teacher_model_top_p ${{inputs.teacher_model_top_p}}
|
||||
--teacher_model_frequency_penalty ${{inputs.teacher_model_frequency_penalty}}
|
||||
--teacher_model_presence_penalty ${{inputs.teacher_model_presence_penalty}}
|
||||
$[[--teacher_model_stop ${{inputs.teacher_model_stop}}]]
|
||||
--request_batch_size ${{inputs.request_batch_size}}
|
||||
--min_endpoint_success_ratio ${{inputs.min_endpoint_success_ratio}}
|
||||
--enable_chain_of_thought ${{inputs.enable_chain_of_thought}}
|
||||
--data_generation_task_type ${{inputs.data_generation_task_type}}
|
||||
$[[--num_train_epochs ${{inputs.num_train_epochs}}]]
|
||||
$[[--per_device_train_batch_size ${{inputs.per_device_train_batch_size}}]]
|
||||
$[[--learning_rate ${{inputs.learning_rate}}]]
|
||||
--validation_info ${{outputs.validation_info}}
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
import re
|
||||
from enum import EnumMeta, Enum
|
||||
|
||||
# COMPONENT META
|
||||
COMPONENT_NAME = "oss_distillation_generate_data"
|
||||
|
||||
|
@ -52,6 +53,7 @@ HFTV2_TEXT_GEN_SCORE_PATH = "/score"
|
|||
DEFAULT_SUCCESS_RATIO = 0.7
|
||||
DEFAULT_REQUEST_BATCH_SIZE = 10
|
||||
MAX_BATCH_SIZE = 100
|
||||
MIN_RECORDS_FOR_FT = 65
|
||||
|
||||
# VLLM INFERENCE KEYS
|
||||
TOP_P = "top_p"
|
||||
|
@ -114,3 +116,24 @@ class TelemetryConstants:
|
|||
BATCH_PROCESS_TRAINING_DATA = "batch_process_training_data"
|
||||
BATCH_PROCESS_VALIDATION_DATA = "batch_process_validation_data"
|
||||
PROCESS_DATASET_RECORD = "process_dataset_record"
|
||||
|
||||
VALIDATOR = "validator"
|
||||
ML_CLIENT_INITIALISATION = "ml_client_initialisation"
|
||||
VALIDATE_DATA_GENERATION_INPUTS = "validate_data_generation_inputs"
|
||||
VALIDATE_FILE_PATH = "validate_file_path"
|
||||
VALIDATE_TEACHER_MODEL_ENDPOINT = "validate_teacher_model_endpoint"
|
||||
VALIDATE_INFERENCE_PARAMETERS = "validate_inference_parameters"
|
||||
VALIDATE_TRAINING_DATA = "validate_training_data"
|
||||
VALIDATE_VALIDATION_DATA = "validate_validation_data"
|
||||
VALIDATE_MODEL_INFERENCE = "validate_model_inference"
|
||||
|
||||
|
||||
class BackoffConstants:
    """Default settings used when retrying a call with exponential backoff."""

    # Status codes regarded as retryable. ``None`` is a member, so a lookup
    # with a missing status code is treated as retryable as well.
    RETRYABLE_STATUS_CODES = {413, 429, 500, 502, 503, 504, None}

    # Retry budget and delay schedule (delays are in seconds).
    MAX_RETRIES = 3
    BASE_DELAY = 10
    BACKOFF_FACTOR = 2
    MAX_DELAY = 600

    # NOTE(review): presumably a per-request timeout in seconds; it is not
    # referenced in the code visible here - confirm against its callers.
    MAX_TIMEOUT_SEC = 180
|
|
@ -5,22 +5,28 @@
|
|||
|
||||
import os
|
||||
import time
|
||||
from typing import List, Tuple, Union, Optional, Callable
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from azure.ai.ml import MLClient
|
||||
from azure.ai.ml.identity import AzureMLOnBehalfOfCredential
|
||||
from azure.ai.ml.entities import ManagedOnlineEndpoint, ManagedOnlineDeployment, ServerlessEndpoint
|
||||
from azure.identity import AzureCliCredential, ManagedIdentityCredential
|
||||
from azureml.acft.common_components.utils.error_handling.exceptions import ACFTValidationException
|
||||
from azureml.acft.common_components.utils.error_handling.error_definitions import ACFTUserError
|
||||
from azureml._common._error_definition.azureml_error import AzureMLError
|
||||
from azureml.acft.common_components import get_logger_app
|
||||
from azureml.core import Run, Workspace
|
||||
from azureml.core.run import _OfflineRun
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
|
||||
from common.constants import (
|
||||
REQUESTS_RETRY_DELAY,
|
||||
REGISTRY_MODEL_PATTERN,
|
||||
SUPPORTED_STUDENT_MODEL_MAP,
|
||||
SUPPORTED_TEACHER_MODEL_MAP,
|
||||
BackoffConstants
|
||||
)
|
||||
|
||||
|
||||
|
@ -348,3 +354,91 @@ def validate_student_model_details(model_asset_id: str) -> Tuple[str, str, str]:
|
|||
Tuple[str, str, str]: Tuple containing registry name, model name and model version
|
||||
"""
|
||||
return _get_model_details(model_asset_id, SUPPORTED_STUDENT_MODEL_MAP)
|
||||
|
||||
|
||||
def get_base_url(url: str) -> str:
    """Return the ``scheme://netloc`` prefix of a URL.

    A falsy ``url`` (empty string or ``None``) is handed back unchanged.

    :param url: URL to reduce to its base.
    :return: Base URL, or the input itself when it is falsy.
    """
    if url:
        parse_result = urlparse(url)
        return f"{parse_result.scheme}://{parse_result.netloc}"
    return url
|
||||
|
||||
|
||||
def _get_status_code(e: Exception) -> Optional[int]:
|
||||
"""
|
||||
Get the status code from the exception.
|
||||
|
||||
:param e: Exception.
|
||||
:return: Status code.
|
||||
"""
|
||||
status_code = getattr(e, "status_code", None)
|
||||
if status_code is None and getattr(e, "response", None) is not None:
|
||||
status_code = getattr(e.response, "status_code", None)
|
||||
return status_code
|
||||
|
||||
|
||||
def exponential_backoff(
    max_retries: int = BackoffConstants.MAX_RETRIES,
    base_delay: int = BackoffConstants.BASE_DELAY,
    max_delay: int = BackoffConstants.MAX_DELAY,
    backoff_factor: int = BackoffConstants.BACKOFF_FACTOR,
) -> Callable:
    """
    Implement exponential backoff for retrying a function for a HTTP request.

    Use this function as a decorator.

    :param max_retries: Maximum number of retries.
    :param base_delay: Base delay in seconds before the first retry.
    :param max_delay: Maximum delay in seconds between retries.
    :param backoff_factor: Factor the delay is multiplied by after each retry.
    :return: Decorated function.
    :raises ACFTValidationException: When the wrapped call fails with a status
        code outside ``BackoffConstants.RETRYABLE_STATUS_CODES``, or when it
        still fails after ``max_retries`` retries.
    """
    # Imported locally so the block does not depend on a file-level import.
    from functools import wraps

    def decorator(func: Callable) -> Callable:
        @wraps(func)  # keep the wrapped function's name/docstring for logs and tooling
        def wrapper(*args, **kwargs):
            retries = 0
            delay = base_delay
            while retries <= max_retries:
                tick = time.time()
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    tock = time.time()
                    status_code = _get_status_code(e)
                    if status_code not in BackoffConstants.RETRYABLE_STATUS_CODES:
                        # The message must be a plain string; a trailing comma
                        # inside the parentheses would make it a tuple.
                        raise ACFTValidationException._with_error(
                            AzureMLError.create(
                                ACFTUserError,
                                pii_safe_message=(
                                    f"Encountered unknown status code: {status_code}. "
                                ),
                            )
                        ) from e

                    retries += 1
                    if retries <= max_retries:
                        # Never wait longer than max_delay, however far the
                        # exponential schedule has progressed.
                        backoff_delay = min(delay, max_delay)
                        logger.info(
                            (
                                f"Retrying method `{func.__name__}` after {backoff_delay} sec. "
                                f"Retry attempt: {retries}/{max_retries}. "
                                f"Time spent: {round(tock - tick)} sec. "
                                f"Error details: {e}"
                            )
                        )
                        time.sleep(backoff_delay)
                        delay *= backoff_factor
                    else:
                        raise ACFTValidationException._with_error(
                            AzureMLError.create(
                                ACFTUserError,
                                pii_safe_message=(
                                    f"Request failed after multiple tries with status code: {status_code}. "
                                ),
                            )
                        ) from e

        return wrapper

    return decorator
|
|
@ -0,0 +1,133 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Component validation utils."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from azureml.acft.common_components.utils.error_handling.exceptions import (
|
||||
ACFTValidationException,
|
||||
)
|
||||
from azureml.acft.common_components.utils.error_handling.error_definitions import (
|
||||
ACFTUserError,
|
||||
)
|
||||
from azureml._common._error_definition.azureml_error import AzureMLError
|
||||
|
||||
from common.constants import SUPPORTED_FILE_FORMATS, MAX_BATCH_SIZE
|
||||
|
||||
|
||||
def validate_file_paths_with_supported_formats(file_paths: List[Optional[str]]):
    """Ensure every given path carries a supported file extension.

    Falsy entries (``None`` or empty string) and paths without any extension
    are skipped.

    Args:
        file_paths (List[Optional[str]]): Paths to check.

    Raises:
        ACFTValidationException: For the first path whose extension is not in
            ``SUPPORTED_FILE_FORMATS``.
    """
    for file_path in file_paths:
        if not file_path:
            continue

        # Lower-case the suffix and drop anything that follows a "?".
        file_ext = Path(file_path).suffix.lower().partition("?")[0]
        if not file_ext or file_ext in SUPPORTED_FILE_FORMATS:
            continue

        raise ACFTValidationException._with_error(
            AzureMLError.create(
                ACFTUserError,
                pii_safe_message=(
                    f"{file_path} is not in list of supported file formats. "
                    f"Supported file formats: {SUPPORTED_FILE_FORMATS}"
                ),
            )
        )
|
||||
|
||||
|
||||
def validate_file_exists(file_paths: List[Optional[str]]):
    """Ensure every given path exists.

    Falsy entries (``None`` or empty string) are skipped.

    Args:
        file_paths (List[Optional[str]]): Paths to check.

    Raises:
        ACFTValidationException: For the first path that does not exist.
    """
    for file_path in file_paths:
        if not file_path or Path(file_path).exists():
            continue

        raise ACFTValidationException._with_error(
            AzureMLError.create(
                ACFTUserError,
                pii_safe_message=(f"File {file_path} does not exist."),
            )
        )
|
||||
|
||||
|
||||
def validate_model_temperature(temperature: float):
    """Validate that the teacher model temperature lies within [0, 1].

    Args:
        temperature (float): Value of the teacher_model_temperature input.

    Raises:
        ACFTValidationException: If the value is below 0 or above 1.
    """
    if temperature < 0 or temperature > 1:
        raise ACFTValidationException._with_error(
            AzureMLError.create(
                ACFTUserError,
                # Adjacent literals concatenate into one string; separating
                # them with commas would turn the message into a tuple.
                pii_safe_message=(
                    "Invalid teacher_model_temperature. "
                    f"Value should be 0<=val<=1, but is {temperature}"
                ),
            )
        )
|
||||
|
||||
|
||||
def validate_model_top_p(top_p: float):
    """Validate that the teacher model top_p lies within [0, 1].

    Args:
        top_p (float): Value of the teacher_model_top_p input.

    Raises:
        ACFTValidationException: If the value is below 0 or above 1.
    """
    if top_p < 0 or top_p > 1:
        raise ACFTValidationException._with_error(
            AzureMLError.create(
                ACFTUserError,
                # Adjacent literals concatenate into one string; separating
                # them with commas would turn the message into a tuple.
                pii_safe_message=(
                    "Invalid teacher_model_top_p. "
                    f"Value should be 0<=val<=1, but is {top_p}"
                ),
            )
        )
|
||||
|
||||
|
||||
def validate_model_frequency_penalty(val: float):
    """Validate that the teacher model frequency penalty lies within [0, 2].

    A missing value (``None``) is accepted without checks.

    Args:
        val (float): Value of the teacher_model_frequency_penalty input.

    Raises:
        ACFTValidationException: If the value is below 0 or above 2.
    """
    if val is not None and (val < 0 or val > 2):
        raise ACFTValidationException._with_error(
            AzureMLError.create(
                ACFTUserError,
                # Adjacent literals concatenate into one string; separating
                # them with commas would turn the message into a tuple.
                pii_safe_message=(
                    "Invalid teacher_model_frequency_penalty. "
                    f"Value should be 0<=val<=2, but is {val}"
                ),
            )
        )
|
||||
|
||||
|
||||
def validate_model_presence_penalty(val: float):
    """Validate that the teacher model presence penalty lies within [0, 2].

    A missing value (``None``) is accepted without checks.

    Args:
        val (float): Value of the teacher_model_presence_penalty input.

    Raises:
        ACFTValidationException: If the value is below 0 or above 2.
    """
    if val is not None and (val < 0 or val > 2):
        raise ACFTValidationException._with_error(
            AzureMLError.create(
                ACFTUserError,
                # Adjacent literals concatenate into one string; separating
                # them with commas would turn the message into a tuple.
                pii_safe_message=(
                    "Invalid teacher_model_presence_penalty. "
                    f"Value should be 0<=val<=2, but is {val}"
                ),
            )
        )
|
||||
|
||||
|
||||
def validate_request_batch_size(val: int):
    """Validate that the request batch size lies within (0, MAX_BATCH_SIZE].

    A missing value (``None``) is accepted without checks. Zero is rejected:
    a truthiness guard would let it slip past the ``val <= 0`` check.

    Args:
        val (int): Value of the request_batch_size input.

    Raises:
        ACFTValidationException: If the value is not positive or exceeds
            ``MAX_BATCH_SIZE``.
    """
    if val is None:
        return

    if val <= 0 or val > MAX_BATCH_SIZE:
        raise ACFTValidationException._with_error(
            AzureMLError.create(
                ACFTUserError,
                # Adjacent literals concatenate into one string; separating
                # them with commas would turn the message into a tuple.
                pii_safe_message=(
                    "Invalid request_batch_size. "
                    f"Value should be 0<val<={MAX_BATCH_SIZE}, but is {val}"
                ),
            )
        )
|
||||
|
||||
|
||||
def validate_min_endpoint_success_ratio(val: float):
    """Validate that the minimum endpoint success ratio lies within [0, 1].

    A missing value (``None``) is accepted without checks.

    Args:
        val (float): Value of the min_endpoint_success_ratio input.

    Raises:
        ACFTValidationException: If the value is below 0 or above 1.
    """
    if val is not None and (val < 0 or val > 1):
        raise ACFTValidationException._with_error(
            AzureMLError.create(
                ACFTUserError,
                # Adjacent literals concatenate into one string; separating
                # them with commas would turn the message into a tuple.
                pii_safe_message=(
                    "Invalid min_endpoint_success_ratio. "
                    f"Value should be 0<=val<=1, but is {val}"
                ),
            )
        )
|
|
@ -12,17 +12,14 @@ import requests
|
|||
from argparse import Namespace
|
||||
from requests import Response
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
from azureml.acft.contrib.hf import VERSION, PROJECT_NAME
|
||||
from azureml.acft.contrib.hf.nlp.constants.constants import LOGS_TO_BE_FILTERED_IN_APPINSIGHTS
|
||||
from azureml.acft.common_components import get_logger_app, set_logging_parameters, LoggingLiterals
|
||||
from azureml.acft.common_components.utils.error_handling.exceptions import ACFTValidationException
|
||||
from azureml.acft.common_components.utils.error_handling.error_definitions import ACFTUserError
|
||||
from azureml.acft.common_components.utils.error_handling.swallow_all_exceptions_decorator import (
|
||||
swallow_all_exceptions,
|
||||
)
|
||||
from azureml._common._error_definition.azureml_error import AzureMLError
|
||||
from azureml.telemetry.activity import log_activity, monitor_with_activity
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
@ -42,7 +39,6 @@ from common.constants import (
|
|||
TEMPERATURE,
|
||||
TOP_P,
|
||||
STOP_TOKEN,
|
||||
SUPPORTED_FILE_FORMATS,
|
||||
VLLM_CHAT_SCORE_PATH,
|
||||
DataGenerationTaskType,
|
||||
TelemetryConstants
|
||||
|
@ -55,6 +51,9 @@ from common.utils import (
|
|||
retry,
|
||||
)
|
||||
|
||||
from common.validation import (
|
||||
validate_file_paths_with_supported_formats
|
||||
)
|
||||
|
||||
logger = get_logger_app("azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import")
|
||||
|
||||
|
@ -229,24 +228,6 @@ def _invoke_endpoint(url: str, key: str, data: dict, log_entry: dict = None) ->
|
|||
return requests.post(url, headers=request_headers, data=json.dumps(data))
|
||||
|
||||
|
||||
def _validate_file_paths_with_supported_formats(file_paths: List[Optional[str]]):
|
||||
"""Check if the file path is in the list of supported formats."""
|
||||
for file_path in file_paths:
|
||||
if file_path:
|
||||
file_suffix = Path(file_path).suffix.lower()
|
||||
file_ext = file_suffix.split('?')[0]
|
||||
if file_ext and file_ext not in SUPPORTED_FILE_FORMATS:
|
||||
raise ACFTValidationException._with_error(
|
||||
AzureMLError.create(
|
||||
ACFTUserError,
|
||||
pii_safe_message=(
|
||||
f"{file_path} is not in list of supported file formats. "
|
||||
f"Supported file formats: {SUPPORTED_FILE_FORMATS}"
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def generate_synthetic_data(
|
||||
teacher_model_endpoint_url: str,
|
||||
teacher_model_endpoint_key: str,
|
||||
|
@ -504,7 +485,7 @@ def data_import(args: Namespace):
|
|||
data_generation_task_type = args.data_generation_task_type
|
||||
|
||||
# validate file formats
|
||||
_validate_file_paths_with_supported_formats([args.train_file_path, args.validation_file_path])
|
||||
validate_file_paths_with_supported_formats([args.train_file_path, args.validation_file_path])
|
||||
logger.info("File format validation successful.")
|
||||
|
||||
enable_cot = True if enable_cot_str.lower() == "true" else False
|
|
@ -0,0 +1,432 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Script for validating distillation pipeline arguments."""
|
||||
import logging
|
||||
import requests
|
||||
import pandas as pd
|
||||
import json
|
||||
from argparse import Namespace
|
||||
|
||||
from azureml.acft.contrib.hf import VERSION, PROJECT_NAME
|
||||
from azureml.acft.contrib.hf.nlp.constants.constants import (
|
||||
LOGS_TO_BE_FILTERED_IN_APPINSIGHTS,
|
||||
)
|
||||
from azureml.acft.common_components import (
|
||||
get_logger_app,
|
||||
set_logging_parameters,
|
||||
LoggingLiterals,
|
||||
)
|
||||
from azureml.acft.common_components.utils.error_handling.exceptions import (
|
||||
ACFTValidationException,
|
||||
)
|
||||
from azureml.acft.common_components.utils.error_handling.error_definitions import (
|
||||
ACFTUserError,
|
||||
)
|
||||
from azureml.telemetry.activity import log_activity
|
||||
from azureml._common._error_definition.azureml_error import AzureMLError
|
||||
from azureml.acft.common_components.utils.error_handling.swallow_all_exceptions_decorator import (
|
||||
swallow_all_exceptions,
|
||||
)
|
||||
|
||||
from generate_data import get_parser
|
||||
|
||||
from common.constants import (
|
||||
DataGenerationTaskType,
|
||||
TelemetryConstants,
|
||||
MAX_NEW_TOKENS,
|
||||
TEMPERATURE,
|
||||
TOP_P,
|
||||
VLLM_CHAT_SCORE_PATH,
|
||||
MIN_RECORDS_FOR_FT,
|
||||
)
|
||||
|
||||
from common.utils import (
|
||||
get_endpoint_details,
|
||||
get_workspace_mlclient,
|
||||
get_base_url,
|
||||
validate_teacher_model_details,
|
||||
exponential_backoff,
|
||||
)
|
||||
|
||||
from common.validation import (
|
||||
validate_file_paths_with_supported_formats,
|
||||
validate_file_exists,
|
||||
validate_model_temperature,
|
||||
validate_model_top_p,
|
||||
validate_model_frequency_penalty,
|
||||
validate_model_presence_penalty,
|
||||
validate_request_batch_size,
|
||||
validate_min_endpoint_success_ratio,
|
||||
)
|
||||
|
||||
logger = get_logger_app(
|
||||
"azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import"
|
||||
)
|
||||
|
||||
COMPONENT_NAME = "oss_distillation_validate_pipeline"
|
||||
|
||||
|
||||
class PipelineInputsValidator:
|
||||
"""Dataclass for validating inputs to distillation pipeline."""
|
||||
|
||||
def __init__(self, args: Namespace) -> None:
    """Keep the parsed flags, create the workspace client and run all validations.

    Construction has side effects: the full validation sequence is executed
    before the constructor returns.

    Args:
        args (Namespace): Input flags to validate.

    Raises:
        Exception: If no MLClient could be obtained for the current workspace.
    """
    self._args = args

    client_activity = TelemetryConstants.ML_CLIENT_INITIALISATION
    with log_activity(logger=logger, activity_name=client_activity):
        client = get_workspace_mlclient()
        if client:
            self._mlclient = client
        else:
            raise Exception("Could not create MLClient for current workspace")

    inputs_activity = TelemetryConstants.VALIDATE_DATA_GENERATION_INPUTS
    with log_activity(logger=logger, activity_name=inputs_activity):
        self._validate_data_generation_inputs()
|
||||
|
||||
def _get_dataframe(self, file_path: str):
|
||||
return pd.read_json(
|
||||
file_path, lines=True, chunksize=self._args.request_batch_size
|
||||
)
|
||||
|
||||
def _get_inference_request_headers(self) -> dict:
|
||||
key = self._args.teacher_model_endpoint_key
|
||||
return {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
|
||||
|
||||
def _get_cot_status(self) -> bool:
|
||||
cot_enabled = self._args.enable_chain_of_thought
|
||||
return cot_enabled.lower() == "true"
|
||||
|
||||
def _validate_model_endpoint_args(self):
    """Resolve and check the teacher model endpoint URL and key.

    When an endpoint name is supplied, the URL and key on the stored arguments
    are overwritten with the values looked up for that endpoint, and the
    deployed model id is passed to ``validate_teacher_model_details``.

    Raises:
        ACFTValidationException: If the endpoint URL or key is still missing.
    """
    args = self._args
    name = args.teacher_model_endpoint_name
    if name:
        details = get_endpoint_details(mlclient_ws=self._mlclient, endpoint_name=name)
        args.teacher_model_endpoint_url = details.get_endpoint_url()
        args.teacher_model_endpoint_key = details.get_endpoint_key()
        validate_teacher_model_details(details.get_deployed_model_id())

    # Both values must be present, whether given directly or resolved above.
    if args.teacher_model_endpoint_url and args.teacher_model_endpoint_key:
        return

    message = "Endpoint URL and key are required fields for data generation."
    raise ACFTValidationException._with_error(
        AzureMLError.create(ACFTUserError, pii_safe_message=message)
    )
|
||||
|
||||
@exponential_backoff()
def _validate_model_endpoint(self):
    """Validate model endpoint availability by retrieving its details.

    Sends a GET request to ``{base_url}/info`` using the inference request
    headers and logs the ``model_name`` found in the JSON response.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
        requests.Timeout: If the endpoint does not answer within the timeout.
    """
    base_url = get_base_url(self._args.teacher_model_endpoint_url)
    request_headers = self._get_inference_request_headers()

    # https://learn.microsoft.com/en-us/azure/machine-learning/reference-model-inference-info
    # requests waits forever unless a timeout is given; bound the wait as
    # (connect, read) seconds so an unresponsive endpoint fails validation
    # instead of hanging the pipeline.
    response = requests.get(
        url=f"{base_url}/info", headers=request_headers, timeout=(10, 60)
    )
    response.raise_for_status()
    response_data = response.json()
    model_name = response_data.get("model_name")
    logger.info(f"Model validated, model name - {model_name}")
|
||||
|
||||
@exponential_backoff()
def _validate_model_inference(self):
    """Validate a sample inference call.

    Takes the first record of the training file, merges it with the
    configured inference parameters and posts it to the teacher model
    endpoint.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
        requests.Timeout: If the endpoint does not answer within the timeout.
    """
    # Prep data.
    reader = self._get_dataframe(file_path=self._args.train_file_path)
    try:
        batch = next(reader)
    finally:
        # Only the first batch is needed; close the chunked reader so the
        # underlying file handle is not leaked.
        reader.close()
    record = batch.iloc[0].to_dict()

    # Build inference payload; keys coming from the record are applied last.
    inference_params = {
        MAX_NEW_TOKENS: self._args.teacher_model_max_new_tokens,
        TEMPERATURE: self._args.teacher_model_temperature,
        TOP_P: self._args.teacher_model_top_p,
        **record,
    }

    headers = self._get_inference_request_headers()
    url = self._args.teacher_model_endpoint_url
    # Append the chat scoring path unless the URL already contains it.
    url = url if VLLM_CHAT_SCORE_PATH in url else f"{url}{VLLM_CHAT_SCORE_PATH}"
    logger.info(f"Model endpoint: {url}")
    # requests waits forever unless a timeout is given; (connect, read)
    # seconds, with a generous read budget for generation.
    response = requests.post(
        url=url,
        headers=headers,
        data=json.dumps(inference_params),
        timeout=(10, 300),
    )
    response.raise_for_status()
|
||||
|
||||
def _validate_inference_parameters(self):
    """Run every inference-parameter validator against its input flag.

    The validators are applied in a fixed order, so the first offending flag
    in that order is the one reported.
    """
    checks = (
        (validate_model_temperature, "teacher_model_temperature"),
        (validate_model_top_p, "teacher_model_top_p"),
        (validate_model_presence_penalty, "teacher_model_presence_penalty"),
        (validate_model_frequency_penalty, "teacher_model_frequency_penalty"),
        (validate_request_batch_size, "request_batch_size"),
        (validate_min_endpoint_success_ratio, "min_endpoint_success_ratio"),
    )
    for check, flag_name in checks:
        # Read each flag only when its validator runs.
        check(getattr(self._args, flag_name))
|
||||
|
||||
def _validate_number_of_records(self, size: int):
    """Ensure the dataset holds at least ``MIN_RECORDS_FOR_FT`` records.

    Args:
        size (int): Number of records found in the dataset.

    Raises:
        ACFTValidationException: If ``size`` is below the required minimum.
    """
    if not size < MIN_RECORDS_FOR_FT:
        return

    message = (
        "Number of records in the dataset are less than the minimum required for fine-tuning."
        f" Minimum records required: {MIN_RECORDS_FOR_FT}, but got {size}."
    )
    raise ACFTValidationException._with_error(
        AzureMLError.create(ACFTUserError, pii_safe_message=message)
    )
|
||||
|
||||
def _validate_record_for_type_conversation(self, record: list) -> str:
    """Check a record against the rules of the CONVERSATION task type.

    Does nothing when the configured task type is not CONVERSATION.

    Args:
        record (list): Sequence of messages.

    Returns:
        str: The validation error message, or None when no problem is found.
    """
    task_type = DataGenerationTaskType.CONVERSATION
    if self._args.data_generation_task_type != task_type:
        return None

    if self._get_cot_status():
        return f"Chain of thought is not supported for task type {task_type}"

    if len(record) < 3:
        # Adjacent literals instead of a backslash continuation: a backslash
        # inside the string pulls the next line's indentation into the message.
        return (
            f"Dataset is not matching expected schema for task type {task_type}. "
            "Expected format: [system, user, assistant]"
        )

    return None
|
||||
|
||||
def _validate_record_for_type_NLI(self, record: list) -> str:
    """Check a record against the rules of the NLI task type.

    Does nothing when the configured task type is not NLI.

    Args:
        record (list): Sequence of messages.

    Returns:
        str: The validation error message, or None when no problem is found.
    """
    task_type = DataGenerationTaskType.NLI
    if self._args.data_generation_task_type != task_type:
        return None

    if len(record) > 2:
        # Adjacent literals instead of a backslash continuation: a backslash
        # inside the string pulls the next line's indentation into the message.
        return (
            f"Chat cannot be of type multi-turn for task type {task_type}. "
            "Expected format: [system, user]"
        )

    return None
|
||||
|
||||
def _validate_record_for_type_NLU_QA(self, record: list) -> str:
    """Check a record against the rules of the NLU question-answering task type.

    Does nothing when the configured task type is not NLU_QUESTION_ANSWERING.

    Args:
        record (list): Sequence of messages.

    Returns:
        str: The validation error message, or None when no problem is found.
    """
    task_type = DataGenerationTaskType.NLU_QUESTION_ANSWERING
    if self._args.data_generation_task_type != task_type:
        return None

    if len(record) > 2:
        # Adjacent literals instead of a backslash continuation: a backslash
        # inside the string pulls the next line's indentation into the message.
        # The sentence separator matches the wording of the NLI check.
        return (
            f"Chat cannot be of type multi-turn for task type {task_type}. "
            "Expected format: [system, user]"
        )

    return None
|
||||
|
||||
def _validate_record_by_task(self, record: list) -> dict:
    """
    Run the task-type specific checks on a record.

    The checks are tried in a fixed order and evaluation stops at the first
    one that reports a problem.

    Args:
        record (list): Sequence of messages

    Returns:
        dict: ``{"exception": <message>}`` for the first problem found,
        otherwise None.
    """
    checks = (
        self._validate_record_for_type_NLI,
        self._validate_record_for_type_conversation,
        self._validate_record_for_type_NLU_QA,
    )
    # Lazy generator: later checks are not invoked once a problem is found.
    findings = (check(record=record) for check in checks)
    problem = next((finding for finding in findings if finding), None)
    if problem:
        return {"exception": problem}
    return None
|
||||
|
||||
def _validate_message(self, id: int, message: dict) -> dict:
|
||||
"""
|
||||
Validate individual message in the dataset.
|
||||
|
||||
Returns dictionary containing exception, if any validation error is found.
|
||||
|
||||
Args:
|
||||
id (int): id of the message in sequence of messages.
|
||||
message (dict): Message object in sequence of messages.
|
||||
"""
|
||||
allowed_roles = ["system", "user", "assistant"]
|
||||
if "role" not in message:
|
||||
return f"Message at index {id} is missing 'role'."
|
||||
|
||||
if message["role"] not in allowed_roles:
|
||||
return f"Invalid 'role' at index {id}."
|
||||
|
||||
if "content" not in message:
|
||||
return f"Message at index {id} is missing 'content'."
|
||||
|
||||
def _validate_record_content(self, record: list) -> dict:
|
||||
"""
|
||||
Validate content of a record and ensures messages are in the expected format.
|
||||
|
||||
Currently functional only for task type `CONVERSATION`, `NLI` & `NLU`.
|
||||
Returns dictionary containing exception, if any validation error is found.
|
||||
|
||||
Args:
|
||||
record (list): Sequence of messages
|
||||
"""
|
||||
try:
|
||||
if record[0].get("role") != "system":
|
||||
role = record[0].get("role")
|
||||
return {
|
||||
"exception": f"First message should be of role 'system' but got {role}."
|
||||
}
|
||||
|
||||
expected_roles = ["user", "assistant"]
|
||||
for id, message in enumerate(record[1:], start=1):
|
||||
if not isinstance(message, dict):
|
||||
return {
|
||||
"exception": f"Message at index {id} should be a dictionary."
|
||||
}
|
||||
|
||||
err = self._validate_message(id=id, message=message)
|
||||
if err:
|
||||
return {"exception": err}
|
||||
|
||||
expected_role = expected_roles[(id - 1) % 2]
|
||||
if message.get("role") != expected_role:
|
||||
return {
|
||||
"exception": f"Role at index {id} should be {expected_role}."
|
||||
}
|
||||
|
||||
task_type = self._args.data_generation_task_type
|
||||
if task_type == DataGenerationTaskType.CONVERSATION and (
|
||||
len(record[1:]) % 2 != 0
|
||||
):
|
||||
return {
|
||||
"exception": "There is an incomplete pair of 'user' and 'assistant' messages."
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {"exception": e}
|
||||
|
||||
def _validate_dataset_record(self, record: list) -> str:
|
||||
"""Validate a record in the dataset. Returns the validation error if found.
|
||||
|
||||
Args:
|
||||
record (list): Sequence of messages
|
||||
"""
|
||||
if not record:
|
||||
return "Chat cannot be empty."
|
||||
|
||||
err = self._validate_record_by_task(record=record)
|
||||
if err and ("exception" in err):
|
||||
return err["exception"]
|
||||
|
||||
err = self._validate_record_content(record=record)
|
||||
if err and ("exception" in err):
|
||||
return err["exception"]
|
||||
|
||||
def _validate_dataset(self, file_path: str):
    """Validate training/validation dataset passed to the data-generation component.

    Reads the file batch by batch, validates the value in the first column of
    every row as a record, and finally checks the total number of rows
    against the minimum required.

    Args:
        file_path (str): Path to the dataset

    Raises:
        ACFTValidationException: If a record fails validation or the dataset
            has too few records.
    """
    reader = self._get_dataframe(file_path=file_path)
    total_rows = 0
    try:
        for batch in reader:
            total_rows += len(batch)
            for idx, row in batch.iterrows():
                # NOTE(review): assumes the chat is stored in the first column
                # of each row - confirm against the dataset schema.
                record = row.iloc[0]
                err = self._validate_dataset_record(record=record)
                if err:
                    raise ACFTValidationException._with_error(
                        AzureMLError.create(
                            ACFTUserError,
                            pii_safe_message=(
                                f"Error validating dataset record, context({idx}): {err}"
                            ),
                        )
                    )
    finally:
        # Release the file handle even when validation aborts mid-iteration.
        reader.close()

    self._validate_number_of_records(size=total_rows)
|
||||
|
||||
def _validate_data_generation_inputs(self):
    """Run every validation of the data-generation inputs, in order.

    Each step runs inside its own telemetry activity; a failure stops the
    sequence, so later steps can rely on earlier ones having passed.

    1. Training/validation file paths have a supported format and exist.
    2. Teacher model endpoint arguments are present and the endpoint responds.
    3. Inference parameters pass their validators.
    4. Training dataset is well formed.
    5. Validation dataset is well formed (only when a path is given).
    6. A single inference call to the teacher model succeeds.
    """

    def run_step(activity_name, *actions):
        # Execute the given actions sequentially under one telemetry activity.
        with log_activity(logger=logger, activity_name=activity_name):
            for action in actions:
                action()

    def check_file_paths():
        files = [self._args.train_file_path, self._args.validation_file_path]
        validate_file_paths_with_supported_formats(file_paths=files)
        validate_file_exists(file_paths=files)

    run_step(TelemetryConstants.VALIDATE_FILE_PATH, check_file_paths)

    run_step(
        TelemetryConstants.VALIDATE_TEACHER_MODEL_ENDPOINT,
        self._validate_model_endpoint_args,
        self._validate_model_endpoint,
    )

    run_step(
        TelemetryConstants.VALIDATE_INFERENCE_PARAMETERS,
        self._validate_inference_parameters,
    )

    run_step(
        TelemetryConstants.VALIDATE_TRAINING_DATA,
        lambda: self._validate_dataset(self._args.train_file_path),
    )

    if self._args.validation_file_path:
        run_step(
            TelemetryConstants.VALIDATE_VALIDATION_DATA,
            lambda: self._validate_dataset(self._args.validation_file_path),
        )

    run_step(
        TelemetryConstants.VALIDATE_MODEL_INFERENCE,
        self._validate_model_inference,
    )
|
||||
|
||||
|
||||
@swallow_all_exceptions(time_delay=5)
def main():
    """Parse the component inputs, run the validator and record the outcome.

    On success a small JSON status document is written to the path given by
    ``--validation_info``.
    """
    # Start from the data generation parser and add this component's own flag.
    parser = get_parser()
    parser.add_argument("--validation_info", required=True, help="Validation status")
    args = parser.parse_known_args()[0]

    custom_dimensions = {
        LoggingLiterals.PROJECT_NAME: PROJECT_NAME,
        LoggingLiterals.PROJECT_VERSION_NUMBER: VERSION,
        LoggingLiterals.COMPONENT_NAME: COMPONENT_NAME,
    }
    set_logging_parameters(
        task_type="DistillationPipelineValidation",
        acft_custom_dimensions=custom_dimensions,
        azureml_pkg_denylist_logging_patterns=LOGS_TO_BE_FILTERED_IN_APPINSIGHTS,
        log_level=logging.INFO,
    )

    # Constructing the validator runs every validation step.
    with log_activity(logger=logger, activity_name=TelemetryConstants.VALIDATOR):
        PipelineInputsValidator(args=args)

    if not args.validation_info:
        return

    status = {"validation_status": "ok"}
    with open(args.validation_info, "w") as f:
        f.write(json.dumps(status))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Загрузка…
Ссылка в новой задаче