Refactor the function to remove redundant code. (#3245)

* Refactor the function to remove redundant code.

* Refactored the code and removed logging user data

* Fix linting issues

* Minor bug fix when input data is  not matching with data gen task type

* Fix linting issues

* Updated documentation strings.

* chore: introduce minor telemetry to distillation component (#3268)

* chore: introduce minor telemetry

* conditionally log network calls

* revert improper use of lambda

* add more context to the logger in decorator

* consume azureml telemetry loggers

* nit: EOF line

* Fix linting issues

* Fix documentation strings

---------

Co-authored-by: ShobithNandakumar <53436371+ShobithNandakumar@users.noreply.github.com>
This commit is contained in:
HarshaVardhanBabu 2024-08-20 15:30:46 +05:30 коммит произвёл GitHub
Родитель be338ec53a
Коммит 598dbefa0b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
2 изменённых файлов: 79 добавлений и 158 удалений

Просмотреть файл

@ -105,3 +105,12 @@ class DataGenerationTaskType(str, Enum, metaclass=MetaEnum):
NLI = "NLI"
CONVERSATION = "CONVERSATION"
NLU_QUESTION_ANSWERING = "NLU_QA"
class TelemetryConstants:
"""Telemetry constants that describe various activities performed by the distillation components."""
INVOKE_MODEL_ENDPOINT = "invoke_model_endpoint"
BATCH_PROCESS_TRAINING_DATA = "batch_process_training_data"
BATCH_PROCESS_VALIDATION_DATA = "batch_process_validation_data"
PROCESS_DATASET_RECORD = "process_dataset_record"

Просмотреть файл

@ -23,6 +23,7 @@ from azureml.acft.common_components.utils.error_handling.swallow_all_exceptions_
swallow_all_exceptions,
)
from azureml._common._error_definition.azureml_error import AzureMLError
from azureml.telemetry.activity import log_activity, monitor_with_activity
from concurrent.futures import ThreadPoolExecutor, as_completed
@ -43,7 +44,8 @@ from common.constants import (
STOP_TOKEN,
SUPPORTED_FILE_FORMATS,
VLLM_CHAT_SCORE_PATH,
DataGenerationTaskType
DataGenerationTaskType,
TelemetryConstants
)
from common.utils import (
@ -197,13 +199,14 @@ def get_parser():
@retry(3)
def _invoke_endpoint(url: str, key: str, data: dict) -> Response:
def _invoke_endpoint(url: str, key: str, data: dict, log_entry: dict = None) -> Response:
"""Invoke endpoint with payload data.
Args:
url (str): Endpoint URL
key (str): Endpoint key
data dict): Payload dictionary
data (dict): Payload dictionary
log_entry (dict): Metadata used for logging
Returns:
Response: Response from invocation
@ -212,8 +215,18 @@ def _invoke_endpoint(url: str, key: str, data: dict) -> Response:
"Content-Type": "application/json",
"Authorization": f"Bearer {key}"
}
response = requests.post(url, headers=request_headers, data=json.dumps(data))
return response
log_entry = log_entry or {}
idx = log_entry.get("idx", -1)
turn = log_entry.get("turn", -1)
# We don't want to log every request. Conditionally log some, to avoid overwhelming logs.
if idx % 10 == 0 and turn % 2 == 0:
custom_logger_activity_name = f"{TelemetryConstants.INVOKE_MODEL_ENDPOINT}_idx({idx})_turn({turn})"
with log_activity(logger=logger,
activity_name=custom_logger_activity_name):
return requests.post(url, headers=request_headers, data=json.dumps(data))
return requests.post(url, headers=request_headers, data=json.dumps(data))
def _validate_file_paths_with_supported_formats(file_paths: List[Optional[str]]):
@ -260,52 +273,42 @@ def generate_synthetic_data(
train_file_path (Path): Train JSONL file path
validation_file_path (Path, optional): Validation JSONL file path. Defaults to None.
"""
def process_request(idx: str, enable_cot: bool, data: dict, url: str, endpoint_key: str) -> dict:
"""Process a single request.
def process_system_prompt(message: dict) -> dict:
"""Update the system prompt depending on the task type and the flag enable_cot.
The original message unchanged if enable_cot is False or task type is conversation.
Args:
idx (str): Row index in Input data.
enable_cot (bool): If CoT is enabled
data (dict): Payload dict
url (str): Endpoint URL
endpoint_key (str): key to authenticate endpoint request
message (dict): System message
Returns:
dict: result dictionary
message (dict): System message with updated content
"""
try:
response: Response = _invoke_endpoint(url=url, key=endpoint_key, data=data)
response_data = response.json()
if enable_cot and data_generation_task_type != DataGenerationTaskType.CONVERSATION:
cot_system_message = {'role': 'system', 'content': COT_SYSTEM_PROMPT}
return cot_system_message
else:
return message
# use jsonpath or regex to capture prediction result
prediction_result = (
None if response.status_code != 200
# response content should be structured as below for a successful vllm response
else response_data['choices'][0]["message"]["content"].strip()
)
def normalize_messages(messages: List[dict]) -> List[dict]:
"""Add dummy assistant turn if not present in the messages list.
if enable_cot:
# Try loading JSON answer and filter 'answer_choice'
# if JSON loading fails, exception will be caught
# And this specific row would not be part of generated data
prediction_result = json.loads(prediction_result)['answer_choice']
This will help in normalizing the input data before generating synthetic data.
return {
"idx": idx,
"status_code": response.status_code,
"text": prediction_result,
"exception": None,
}
except Exception as e:
logger.error(f"idx: {idx}. exception: {e}")
return {
"idx": idx,
"status_code": None,
"text": None,
"exception": e,
}
Args:
messages (list[dict]): List of conversation turns
def process_conversational_request(idx: str, data: dict, url: str, endpoint_key: str):
Returns:
messages (list[dict]): List of conversation turns with dummy assistant turn added
"""
if data_generation_task_type != DataGenerationTaskType.CONVERSATION:
if messages[-1]['role'] != 'assistant':
messages.append({'role': 'assistant', 'content': ''})
return messages
@monitor_with_activity(logger=logger, activity_name=TelemetryConstants.PROCESS_DATASET_RECORD)
def process_request(idx: str, data: dict, url: str, endpoint_key: str):
"""Process a single conversational request.
Args:
@ -345,34 +348,39 @@ def generate_synthetic_data(
"messages": [],
"exception": f"Incorrect format.\nRole should be assistant or user, but got {role}"
}
messages = normalize_messages(messages)
last_status_code = None
synthetic_responses = []
for message in messages:
for turn_id, message in enumerate(messages):
role = message['role']
if role in ('system', 'user'):
if role == 'system':
synthetic_responses.append(process_system_prompt(message))
elif role == 'user':
synthetic_responses.append(message)
else:
data_with_inference_parameters = {"messages": synthetic_responses}
for key, value in data.items():
data_with_inference_parameters[key] = value
# replace the assistant content from the model
log_entry = {"idx": idx, "turn": turn_id}
response: Response = _invoke_endpoint(url=url, key=endpoint_key,
data=data_with_inference_parameters)
if response.status_code != 200:
data=data_with_inference_parameters,
log_entry=log_entry)
last_status_code = response.status_code
if last_status_code != 200:
break
response_data = response.json()
prediction_result = (
None if response.status_code != 200
# response content should be structured as below for a successful vllm response
else response_data['choices'][0]["message"]["content"].strip()
)
# response content should be structured as below for a successful vllm response
prediction_result = response_data['choices'][0]["message"]["content"].strip()
synthetic_responses.append({'role': 'assistant', 'content': prediction_result})
is_success = (last_status_code == 200)
logger.info(f"Processing idx: {idx} - {is_success}")
return {
"idx": idx,
"status_code": response.status_code,
"status_code": last_status_code,
"messages": synthetic_responses,
"exception": (f"Not able to generate synthetic response for all turns for idx: {idx}"
if response.status_code != 200
if not is_success
else
None),
}
@ -385,12 +393,7 @@ def generate_synthetic_data(
"exception": e,
}
def replace_cot_system_message(messages: List[dict]) -> List[dict]:
# Replace the system message without changing the original messages list
cot_system_message = {'role': 'system', 'content': COT_SYSTEM_PROMPT}
return [(cot_system_message if message['role'] == 'system' else message) for message in messages]
def batch_process_conversation_data(input_file_path: Path, output_file_path: Path, batch_size: int) -> None:
def batch_process_data(input_file_path: Path, output_file_path: Path, batch_size: int) -> None:
"""Batch process data and do a bulk request to teacher model endpoint.
Args:
@ -421,7 +424,7 @@ def generate_synthetic_data(
}
futures.append(
executor.submit(
process_conversational_request,
process_request,
idx,
request_data,
teacher_model_endpoint_url,
@ -466,107 +469,16 @@ def generate_synthetic_data(
msg = f"Success ratio for dataset {input_file_path}: {success_ratio} < {min_endpoint_success_ratio}."
raise Exception(msg)
def batch_process_data(input_file_path: Path, output_file_path: Path, batch_size: int) -> None:
"""Batch process data and do a bulk request to teacher model endpoint.
Args:
input_file_path (Path): Input data file path
output_file_path (Path): Path to output directory
batch_size (int): Input batch size for processing rows in train and validation dataset
Raises:
Exception: if success ratio is less than min_endpoint_success_ratio
"""
train_df = pd.read_json(input_file_path, lines=True, chunksize=batch_size)
total_rows = 0
error_count = 0
output_data = []
error_map = {}
ERROR = "error"
for batch in train_df:
total_rows += len(batch)
futures = []
with ThreadPoolExecutor() as executor:
for idx, row in batch.iterrows():
messages = row.iloc[0]
messages = replace_cot_system_message(messages) if enable_cot else messages
request_data = {
"messages": messages,
**inference_params,
}
futures.append(
executor.submit(
process_request,
idx,
enable_cot,
request_data,
teacher_model_endpoint_url,
teacher_model_endpoint_key
)
)
# wait for results to complete
future_results = {
result["idx"]: result
for result in [future.result() for future in as_completed(futures)]
}
idx = 0
for idx, row in batch.iterrows():
future_result = future_results.get(idx)
if future_result['exception']:
logger.error(f"row {idx} failed with exception: {future_result['exception']}")
error_map[ERROR] = error_map.get(ERROR, 0) + 1
elif future_result['status_code'] != 200:
logger.warning(f"row {idx} request status_code: {future_result['status_code']} != 200")
error_map[future_result['status_code']] = error_map.get(future_result['status_code'], 0) + 1
else:
new_row = row.copy().iloc[0]
answer = future_result['text']
new_row.append(
{
"role": "assistant",
"content": answer,
}
)
output_data.append({"messages": new_row})
Path(output_file_path.parent).mkdir(exist_ok=True, parents=True)
with open(output_file_path, 'w') as f:
for entry in output_data:
f.write(json.dumps(entry) + '\n')
if error_map:
logger.info("Error summary. With key donating non-200 status code or some other error.")
for k, v in error_map.items():
error_count += v
logger.warning(f"{k} => {v}")
success_ratio = float(total_rows - error_count) / total_rows
logger.info(f"Success rate was {success_ratio} for {input_file_path}")
if success_ratio < min_endpoint_success_ratio:
msg = f"Success ratio for dataset {input_file_path}: {success_ratio} < {min_endpoint_success_ratio}."
raise Exception(msg)
logger.info("Processing train file")
if data_generation_task_type == DataGenerationTaskType.CONVERSATION:
batch_process_conversation_data(train_file_path, generated_train_file_path, request_batch_size)
else:
with log_activity(logger=logger, activity_name=TelemetryConstants.BATCH_PROCESS_TRAINING_DATA):
logger.info("Processing train file")
batch_process_data(train_file_path, generated_train_file_path, request_batch_size)
logger.info("Data generated and saved for train file")
logger.info("Data generated and saved for train file")
if validation_file_path:
logger.info("Processing validation file")
if data_generation_task_type == DataGenerationTaskType.CONVERSATION:
batch_process_conversation_data(validation_file_path, generated_validation_file_path, request_batch_size)
else:
with log_activity(logger=logger, activity_name=TelemetryConstants.BATCH_PROCESS_VALIDATION_DATA):
logger.info("Processing validation file")
batch_process_data(validation_file_path, generated_validation_file_path, request_batch_size)
logger.info("Data generated and saved for validation file")
logger.info("Data generated and saved for validation file")
def data_import(args: Namespace):