This commit is contained in:
babu-namburi 2024-08-13 11:28:29 +05:30
Родитель a9aeff7100
Коммит 77284f3523
2 изменённых файлов: 52 добавлений и 5 удалений

Просмотреть файл

@ -15,6 +15,7 @@ from azureml.acft.common_components import get_logger_app
from azureml.core import Run, Workspace
from azureml.core.run import _OfflineRun
from typing import List, Tuple, Union
from timeit import default_timer as timer
from common.constants import (
REQUESTS_RETRY_DELAY,
@ -348,3 +349,45 @@ def validate_student_model_details(model_asset_id: str) -> Tuple[str, str, str]:
Tuple[str, str, str]: Tuple containing registry name, model name and model version
"""
return _get_model_details(model_asset_id, SUPPORTED_STUDENT_MODEL_MAP)
def format_time(sec):
if sec < 1e-6:
return '%8.2f ns' % (sec * 1e9)
elif sec < 1e-3:
return '%8.2f mks' % (sec * 1e6)
elif sec < 1:
return '%8.2f ms' % (sec * 1e3)
else:
return '%8.2f s' % sec
time_formatters = {
'auto': format_time,
'ns': lambda sec: '%8.2f ns' % (sec * 1e9),
'mks': lambda sec: '%8.2f mks' % (sec * 1e6),
'ms': lambda sec: '%8.2f ms' % (sec * 1e3),
's': lambda sec: '%8.2f s' % sec,
}
class log_durations(object):
"""Times each function call or block execution."""
def __init__(self, print_func, label=None, unit='auto', threshold=-1, repr_len=25):
self.print_func = print_func
self.label = label
if unit not in time_formatters:
raise ValueError('Unknown time unit: %s. It should be ns, mks, ms, s or auto.' % unit)
self.format_time = time_formatters[unit]
self.threshold = threshold
def __enter__(self):
self.start = timer()
return self
def __exit__(self, *exc):
duration = timer() - self.start
if duration >= self.threshold:
duration_str = self.format_time(duration)
self.print_func("%s : %s" % (self.label, duration_str) if self.label else duration_str)

Просмотреть файл

@ -51,6 +51,7 @@ from common.utils import (
get_endpoint_details,
validate_teacher_model_details,
retry,
log_durations
)
@ -197,7 +198,7 @@ def get_parser():
@retry(3)
def _invoke_endpoint(url: str, key: str, data: dict) -> Response:
def _invoke_endpoint(url: str, key: str, data: dict, log_metadata: dict = None) -> Response:
"""Invoke endpoint with payload data.
Args:
@ -212,6 +213,8 @@ def _invoke_endpoint(url: str, key: str, data: dict) -> Response:
"Content-Type": "application/json",
"Authorization": f"Bearer {key}"
}
log_metadata = log_metadata or {}
with log_durations(logger, f"POST request to teacher model endpoint: {log_metadata}"):
response = requests.post(url, headers=request_headers, data=json.dumps(data))
return response
@ -337,7 +340,7 @@ def generate_synthetic_data(
}
messages = normalize_messages(messages)
synthetic_responses = []
for message in messages:
for turn_id, message in enumerate(messages):
role = message['role']
if role == 'system':
synthetic_responses.append(process_system_prompt(message))
@ -349,7 +352,8 @@ def generate_synthetic_data(
data_with_inference_parameters[key] = value
# replace the assistant content from the model
response: Response = _invoke_endpoint(url=url, key=endpoint_key,
data=data_with_inference_parameters)
data=data_with_inference_parameters,
log_metadata={"idx": idx, 'turn': turn_id})
if response.status_code != 200:
break
response_data = response.json()