diff --git a/assets/training/distillation/components/data_generation/src/common/utils.py b/assets/training/distillation/components/data_generation/src/common/utils.py index 1ee1cbd748..c157d2bc93 100644 --- a/assets/training/distillation/components/data_generation/src/common/utils.py +++ b/assets/training/distillation/components/data_generation/src/common/utils.py @@ -15,6 +15,7 @@ from azureml.acft.common_components import get_logger_app from azureml.core import Run, Workspace from azureml.core.run import _OfflineRun from typing import List, Tuple, Union +from timeit import default_timer as timer from common.constants import ( REQUESTS_RETRY_DELAY, @@ -348,3 +349,45 @@ def validate_student_model_details(model_asset_id: str) -> Tuple[str, str, str]: Tuple[str, str, str]: Tuple containing registry name, model name and model version """ return _get_model_details(model_asset_id, SUPPORTED_STUDENT_MODEL_MAP) + + +def format_time(sec): + if sec < 1e-6: + return '%8.2f ns' % (sec * 1e9) + elif sec < 1e-3: + return '%8.2f mks' % (sec * 1e6) + elif sec < 1: + return '%8.2f ms' % (sec * 1e3) + else: + return '%8.2f s' % sec + + +time_formatters = { + 'auto': format_time, + 'ns': lambda sec: '%8.2f ns' % (sec * 1e9), + 'mks': lambda sec: '%8.2f mks' % (sec * 1e6), + 'ms': lambda sec: '%8.2f ms' % (sec * 1e3), + 's': lambda sec: '%8.2f s' % sec, +} + + +class log_durations(object): + """Times each function call or block execution.""" + + def __init__(self, print_func, label=None, unit='auto', threshold=-1, repr_len=25): + self.print_func = print_func + self.label = label + if unit not in time_formatters: + raise ValueError('Unknown time unit: %s. It should be ns, mks, ms, s or auto.' % unit) + self.format_time = time_formatters[unit] + self.threshold = threshold + + def __enter__(self): + self.start = timer() + return self + + def __exit__(self, *exc): + duration = timer() - self.start + if duration >= self.threshold: + duration_str = self.format_time(duration) + self.print_func("%s : %s" % (self.label, duration_str) if self.label else duration_str) diff --git a/assets/training/distillation/components/data_generation/src/generate_data.py b/assets/training/distillation/components/data_generation/src/generate_data.py index 4c61ddfbee..ef4f73355b 100644 --- a/assets/training/distillation/components/data_generation/src/generate_data.py +++ b/assets/training/distillation/components/data_generation/src/generate_data.py @@ -51,6 +51,7 @@ from common.utils import ( get_endpoint_details, validate_teacher_model_details, retry, + log_durations ) @@ -197,7 +198,7 @@ def get_parser(): @retry(3) -def _invoke_endpoint(url: str, key: str, data: dict) -> Response: +def _invoke_endpoint(url: str, key: str, data: dict, log_metadata: dict = None) -> Response: """Invoke endpoint with payload data. Args: @@ -212,8 +213,10 @@ def _invoke_endpoint(url: str, key: str, data: dict) -> Response: "Content-Type": "application/json", "Authorization": f"Bearer {key}" } - response = requests.post(url, headers=request_headers, data=json.dumps(data)) - return response + log_metadata = log_metadata or {} + with log_durations(logger, f"POST request to teacher model endpoint: {log_metadata}"): + response = requests.post(url, headers=request_headers, data=json.dumps(data)) + return response def _validate_file_paths_with_supported_formats(file_paths: List[Optional[str]]): @@ -337,7 +340,7 @@ def generate_synthetic_data( } messages = normalize_messages(messages) synthetic_responses = [] - for message in messages: + for turn_id, message in enumerate(messages): role = message['role'] if role == 'system': synthetic_responses.append(process_system_prompt(message)) @@ -349,7 +352,8 @@ def generate_synthetic_data( data_with_inference_parameters[key] = value # replace the assistant content from the model response: Response = _invoke_endpoint(url=url, key=endpoint_key, - data=data_with_inference_parameters) + data=data_with_inference_parameters, + log_metadata={"idx": idx, 'turn': turn_id}) if response.status_code != 200: break response_data = response.json()