Refactor the function to remove redundant code. (#3245)

* Refactor the function to remove redundant code.
* Refactored the code and removed logging of user data.
* Fix linting issues.
* Minor bug fix when input data does not match the data generation task type.
* Fix linting issues.
* Updated documentation strings.
* chore: introduce minor telemetry to distillation component (#3268)
* chore: introduce minor telemetry
* conditionally log network calls
* revert improper use of lambda
* add more context to the logger in decorator
* consume azureml telemetry loggers
* nit: EOF line
* Fix linting issues.
* Fix documentation strings.

---------

Co-authored-by: ShobithNandakumar <53436371+ShobithNandakumar@users.noreply.github.com>
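Of the bullets above, "conditionally log network calls" is the one with a behavioral subtlety: rather than logging every endpoint request, the component samples them. A minimal sketch of that sampling rule as a standalone predicate (should_log is a hypothetical name; the diff below inlines the condition directly in _invoke_endpoint):

    def should_log(idx: int, turn: int) -> bool:
        # Log roughly every 10th dataset record, and within it only
        # even-numbered conversation turns, to avoid overwhelming the logs.
        return idx % 10 == 0 and turn % 2 == 0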
Parent: be338ec53a
Commit: 598dbefa0b
@@ -105,3 +105,12 @@ class DataGenerationTaskType(str, Enum, metaclass=MetaEnum):
     NLI = "NLI"
     CONVERSATION = "CONVERSATION"
     NLU_QUESTION_ANSWERING = "NLU_QA"
+
+
+class TelemetryConstants:
+    """Telemetry constants that describe various activities performed by the distillation components."""
+
+    INVOKE_MODEL_ENDPOINT = "invoke_model_endpoint"
+    BATCH_PROCESS_TRAINING_DATA = "batch_process_training_data"
+    BATCH_PROCESS_VALIDATION_DATA = "batch_process_validation_data"
+    PROCESS_DATASET_RECORD = "process_dataset_record"
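These constants serve as activity names for the AzureML telemetry helpers used later in this diff. A minimal usage sketch, assuming a plain logging.Logger is acceptable to log_activity (the component itself uses its AzureML ACFT logger wiring):

    import logging

    from azureml.telemetry.activity import log_activity

    logger = logging.getLogger(__name__)  # illustrative stand-in for the component logger

    # log_activity is a context manager that emits start/completion events,
    # including duration, for the wrapped block.
    with log_activity(logger=logger, activity_name=TelemetryConstants.BATCH_PROCESS_TRAINING_DATA):
        pass  # batch processing would happen here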
@@ -23,6 +23,7 @@ from azureml.acft.common_components.utils.error_handling.swallow_all_exceptions_
     swallow_all_exceptions,
 )
 from azureml._common._error_definition.azureml_error import AzureMLError
+from azureml.telemetry.activity import log_activity, monitor_with_activity

 from concurrent.futures import ThreadPoolExecutor, as_completed

@@ -43,7 +44,8 @@ from common.constants import (
     STOP_TOKEN,
     SUPPORTED_FILE_FORMATS,
     VLLM_CHAT_SCORE_PATH,
-    DataGenerationTaskType
+    DataGenerationTaskType,
+    TelemetryConstants
 )

 from common.utils import (
@@ -197,13 +199,14 @@ def get_parser():


 @retry(3)
-def _invoke_endpoint(url: str, key: str, data: dict) -> Response:
+def _invoke_endpoint(url: str, key: str, data: dict, log_entry: dict = None) -> Response:
     """Invoke endpoint with payload data.

     Args:
         url (str): Endpoint URL
         key (str): Endpoint key
-        data dict): Payload dictionary
+        data (dict): Payload dictionary
+        log_entry (dict): Metadata used for logging

     Returns:
         Response: Response from invocation
@@ -212,8 +215,18 @@ def _invoke_endpoint(url: str, key: str, data: dict) -> Response:
         "Content-Type": "application/json",
         "Authorization": f"Bearer {key}"
     }
-    response = requests.post(url, headers=request_headers, data=json.dumps(data))
-    return response
+
+    log_entry = log_entry or {}
+    idx = log_entry.get("idx", -1)
+    turn = log_entry.get("turn", -1)
+
+    # We don't want to log every request. Conditionally log some, to avoid overwhelming logs.
+    if idx % 10 == 0 and turn % 2 == 0:
+        custom_logger_activity_name = f"{TelemetryConstants.INVOKE_MODEL_ENDPOINT}_idx({idx})_turn({turn})"
+        with log_activity(logger=logger,
+                          activity_name=custom_logger_activity_name):
+            return requests.post(url, headers=request_headers, data=json.dumps(data))
+    return requests.post(url, headers=request_headers, data=json.dumps(data))


 def _validate_file_paths_with_supported_formats(file_paths: List[Optional[str]]):
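An illustrative call site for the new signature (endpoint values are hypothetical): only entries satisfying the sampling rule open a log_activity span; all other requests are posted without telemetry.

    # idx=20, turn=2 satisfies idx % 10 == 0 and turn % 2 == 0, so this call is logged.
    response = _invoke_endpoint(url="https://example-endpoint.example.com/score",  # hypothetical URL
                                key="<endpoint-key>",                              # hypothetical key
                                data={"messages": []},
                                log_entry={"idx": 20, "turn": 2})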
@@ -260,52 +273,42 @@ def generate_synthetic_data(
         train_file_path (Path): Train JSONL file path
         validation_file_path (Path, optional): Validation JSONL file path. Defaults to None.
     """
-    def process_request(idx: str, enable_cot: bool, data: dict, url: str, endpoint_key: str) -> dict:
-        """Process a single request.
+    def process_system_prompt(message: dict) -> dict:
+        """Update the system prompt depending on the task type and the flag enable_cot.
+
+        The original message unchanged if enable_cot is False or task type is conversation.

         Args:
-            idx (str): Row index in Input data.
-            enable_cot (bool): If CoT is enabled
-            data (dict): Payload dict
-            url (str): Endpoint URL
-            endpoint_key (str): key to authenticate endpoint request
+            message (dict): System message

         Returns:
-            dict: result dictionary
+            message (dict): System message with updated content
         """
-        try:
-            response: Response = _invoke_endpoint(url=url, key=endpoint_key, data=data)
-            response_data = response.json()
-
-            # use jsonpath or regex to capture prediction result
-            prediction_result = (
-                None if response.status_code != 200
-                # response content should be structured as below for a successful vllm response
-                else response_data['choices'][0]["message"]["content"].strip()
-            )
-
-            if enable_cot:
-                # Try loading JSON answer and filter 'answer_choice'
-                # if JSON loading fails, exception will be caught
-                # And this specific row would not be part of generated data
-                prediction_result = json.loads(prediction_result)['answer_choice']
-
-            return {
-                "idx": idx,
-                "status_code": response.status_code,
-                "text": prediction_result,
-                "exception": None,
-            }
-        except Exception as e:
-            logger.error(f"idx: {idx}. exception: {e}")
-            return {
-                "idx": idx,
-                "status_code": None,
-                "text": None,
-                "exception": e,
-            }
-
-    def process_conversational_request(idx: str, data: dict, url: str, endpoint_key: str):
+        if enable_cot and data_generation_task_type != DataGenerationTaskType.CONVERSATION:
+            cot_system_message = {'role': 'system', 'content': COT_SYSTEM_PROMPT}
+            return cot_system_message
+        else:
+            return message
+
+    def normalize_messages(messages: List[dict]) -> List[dict]:
+        """Add dummy assistant turn if not present in the messages list.
+
+        This will help in normalizing the input data before generating synthetic data.
+
+        Args:
+            messages (list[dict]): List of conversation turns
+
+        Returns:
+            messages (list[dict]): List of conversation turns with dummy assistant turn added
+        """
+        if data_generation_task_type != DataGenerationTaskType.CONVERSATION:
+            if messages[-1]['role'] != 'assistant':
+                messages.append({'role': 'assistant', 'content': ''})
+        return messages
+
+    @monitor_with_activity(logger=logger, activity_name=TelemetryConstants.PROCESS_DATASET_RECORD)
+    def process_request(idx: str, data: dict, url: str, endpoint_key: str):
         """Process a single conversational request.

         Args:
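To illustrate normalize_messages for a non-CONVERSATION task type (record contents are hypothetical):

    messages = [{'role': 'system', 'content': 'You are a helpful assistant.'},
                {'role': 'user', 'content': 'Does the premise entail the hypothesis?'}]
    # normalize_messages appends a dummy assistant turn, so the record becomes:
    # [{'role': 'system', ...}, {'role': 'user', ...}, {'role': 'assistant', 'content': ''}]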
@@ -345,34 +348,39 @@ def generate_synthetic_data(
                     "messages": [],
                     "exception": f"Incorrect format.\nRole should be assistant or user, but got {role}"
                 }

+            messages = normalize_messages(messages)
+            last_status_code = None
             synthetic_responses = []
-            for message in messages:
+            for turn_id, message in enumerate(messages):
                 role = message['role']
-                if role in ('system', 'user'):
+                if role == 'system':
+                    synthetic_responses.append(process_system_prompt(message))
+                elif role == 'user':
                     synthetic_responses.append(message)
                 else:
                     data_with_inference_parameters = {"messages": synthetic_responses}
                     for key, value in data.items():
                         data_with_inference_parameters[key] = value
                     # replace the assistant content from the model
+                    log_entry = {"idx": idx, "turn": turn_id}
                     response: Response = _invoke_endpoint(url=url, key=endpoint_key,
-                                                          data=data_with_inference_parameters)
-                    if response.status_code != 200:
+                                                          data=data_with_inference_parameters,
+                                                          log_entry=log_entry)
+                    last_status_code = response.status_code
+                    if last_status_code != 200:
                         break
                     response_data = response.json()
-                    prediction_result = (
-                        None if response.status_code != 200
-                        # response content should be structured as below for a successful vllm response
-                        else response_data['choices'][0]["message"]["content"].strip()
-                    )
+                    # response content should be structured as below for a successful vllm response
+                    prediction_result = response_data['choices'][0]["message"]["content"].strip()
                     synthetic_responses.append({'role': 'assistant', 'content': prediction_result})
+            is_success = (last_status_code == 200)
+            logger.info(f"Processing idx: {idx} - {is_success}")
             return {
                 "idx": idx,
-                "status_code": response.status_code,
+                "status_code": last_status_code,
                 "messages": synthetic_responses,
                 "exception": (f"Not able to generate synthetic response for all turns for idx: {idx}"
-                              if response.status_code != 200
+                              if not is_success
                               else
                               None),
             }
@@ -385,12 +393,7 @@ def generate_synthetic_data(
                 "exception": e,
             }

-    def replace_cot_system_message(messages: List[dict]) -> List[dict]:
-        # Replace the system message without changing the original messages list
-        cot_system_message = {'role': 'system', 'content': COT_SYSTEM_PROMPT}
-        return [(cot_system_message if message['role'] == 'system' else message) for message in messages]
-
-    def batch_process_conversation_data(input_file_path: Path, output_file_path: Path, batch_size: int) -> None:
+    def batch_process_data(input_file_path: Path, output_file_path: Path, batch_size: int) -> None:
         """Batch process data and do a bulk request to teacher model endpoint.

         Args:
@@ -421,7 +424,7 @@ def generate_synthetic_data(
                 }
                 futures.append(
                     executor.submit(
-                        process_conversational_request,
+                        process_request,
                         idx,
                         request_data,
                         teacher_model_endpoint_url,
@@ -466,107 +469,16 @@ def generate_synthetic_data(
             msg = f"Success ratio for dataset {input_file_path}: {success_ratio} < {min_endpoint_success_ratio}."
             raise Exception(msg)

-    def batch_process_data(input_file_path: Path, output_file_path: Path, batch_size: int) -> None:
-        """Batch process data and do a bulk request to teacher model endpoint.
-
-        Args:
-            input_file_path (Path): Input data file path
-            output_file_path (Path): Path to output directory
-            batch_size (int): Input batch size for processing rows in train and validation dataset
-
-        Raises:
-            Exception: if success ratio is less than min_endpoint_success_ratio
-        """
-        train_df = pd.read_json(input_file_path, lines=True, chunksize=batch_size)
-        total_rows = 0
-        error_count = 0
-        output_data = []
-        error_map = {}
-        ERROR = "error"
-
-        for batch in train_df:
-            total_rows += len(batch)
-            futures = []
-
-            with ThreadPoolExecutor() as executor:
-                for idx, row in batch.iterrows():
-                    messages = row.iloc[0]
-                    messages = replace_cot_system_message(messages) if enable_cot else messages
-                    request_data = {
-                        "messages": messages,
-                        **inference_params,
-                    }
-
-                    futures.append(
-                        executor.submit(
-                            process_request,
-                            idx,
-                            enable_cot,
-                            request_data,
-                            teacher_model_endpoint_url,
-                            teacher_model_endpoint_key
-                        )
-                    )
-
-            # wait for results to complete
-            future_results = {
-                result["idx"]: result
-                for result in [future.result() for future in as_completed(futures)]
-            }
-
-            idx = 0
-            for idx, row in batch.iterrows():
-                future_result = future_results.get(idx)
-                if future_result['exception']:
-                    logger.error(f"row {idx} failed with exception: {future_result['exception']}")
-                    error_map[ERROR] = error_map.get(ERROR, 0) + 1
-                elif future_result['status_code'] != 200:
-                    logger.warning(f"row {idx} request status_code: {future_result['status_code']} != 200")
-                    error_map[future_result['status_code']] = error_map.get(future_result['status_code'], 0) + 1
-                else:
-                    new_row = row.copy().iloc[0]
-                    answer = future_result['text']
-
-                    new_row.append(
-                        {
-                            "role": "assistant",
-                            "content": answer,
-                        }
-                    )
-                    output_data.append({"messages": new_row})
-
-        Path(output_file_path.parent).mkdir(exist_ok=True, parents=True)
-        with open(output_file_path, 'w') as f:
-            for entry in output_data:
-                f.write(json.dumps(entry) + '\n')
-
-        if error_map:
-            logger.info("Error summary. With key donating non-200 status code or some other error.")
-            for k, v in error_map.items():
-                error_count += v
-                logger.warning(f"{k} => {v}")
-
-        success_ratio = float(total_rows - error_count) / total_rows
-        logger.info(f"Success rate was {success_ratio} for {input_file_path}")
-        if success_ratio < min_endpoint_success_ratio:
-            msg = f"Success ratio for dataset {input_file_path}: {success_ratio} < {min_endpoint_success_ratio}."
-            raise Exception(msg)
-    logger.info("Processing train file")

-    if data_generation_task_type == DataGenerationTaskType.CONVERSATION:
-        batch_process_conversation_data(train_file_path, generated_train_file_path, request_batch_size)
-    else:
+    with log_activity(logger=logger, activity_name=TelemetryConstants.BATCH_PROCESS_TRAINING_DATA):
+        logger.info("Processing train file")
         batch_process_data(train_file_path, generated_train_file_path, request_batch_size)
-
-    logger.info("Data generated and saved for train file")
+        logger.info("Data generated and saved for train file")

     if validation_file_path:
-        logger.info("Processing validation file")
-        if data_generation_task_type == DataGenerationTaskType.CONVERSATION:
-            batch_process_conversation_data(validation_file_path, generated_validation_file_path, request_batch_size)
-        else:
+        with log_activity(logger=logger, activity_name=TelemetryConstants.BATCH_PROCESS_VALIDATION_DATA):
+            logger.info("Processing validation file")
             batch_process_data(validation_file_path, generated_validation_file_path, request_batch_size)
-        logger.info("Data generated and saved for validation file")
+            logger.info("Data generated and saved for validation file")


 def data_import(args: Namespace):
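Taken together, the telemetry added in this commit nests roughly as follows at runtime (activity names from TelemetryConstants; the innermost spans are sampled as described above):

    batch_process_training_data / batch_process_validation_data  # log_activity around each dataset pass
        process_dataset_record                                   # monitor_with_activity on process_request, per record
            invoke_model_endpoint_idx(i)_turn(t)                 # sampled log_activity inside _invoke_endpoint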