Adding telemetry for datagen
This commit is contained in:
Родитель
a9aeff7100
Коммит
77284f3523
|
@ -15,6 +15,7 @@ from azureml.acft.common_components import get_logger_app
|
|||
from azureml.core import Run, Workspace
|
||||
from azureml.core.run import _OfflineRun
|
||||
from typing import List, Tuple, Union
|
||||
from timeit import default_timer as timer
|
||||
|
||||
from common.constants import (
|
||||
REQUESTS_RETRY_DELAY,
|
||||
|
@ -348,3 +349,45 @@ def validate_student_model_details(model_asset_id: str) -> Tuple[str, str, str]:
|
|||
Tuple[str, str, str]: Tuple containing registry name, model name and model version
|
||||
"""
|
||||
return _get_model_details(model_asset_id, SUPPORTED_STUDENT_MODEL_MAP)
|
||||
|
||||
|
||||
def format_time(sec):
    """Render a duration given in seconds using the most readable unit.

    The unit is selected from the magnitude of ``sec``, so a negative value
    is scaled like its positive counterpart instead of always being reported
    in nanoseconds.

    Args:
        sec: Duration in seconds.

    Returns:
        str: The scaled value formatted with ``%8.2f`` followed by a space and
        one of ``ns``, ``mks`` (microseconds), ``ms`` or ``s``.
    """
    # Compare on the absolute value: a signed comparison would send every
    # negative duration into the first (nanosecond) branch.
    magnitude = abs(sec)
    if magnitude < 1e-6:
        return '%8.2f ns' % (sec * 1e9)
    if magnitude < 1e-3:
        return '%8.2f mks' % (sec * 1e6)
    if magnitude < 1:
        return '%8.2f ms' % (sec * 1e3)
    return '%8.2f s' % sec
|
||||
|
||||
|
||||
# Maps a unit name to a callable that turns a duration in seconds into text.
# 'auto' delegates the choice of unit to format_time; every other entry always
# scales to its own fixed unit ('mks' stands for microseconds).
time_formatters = dict(
    auto=format_time,
    ns=lambda sec: '%8.2f ns' % (sec * 1e9),
    mks=lambda sec: '%8.2f mks' % (sec * 1e6),
    ms=lambda sec: '%8.2f ms' % (sec * 1e3),
    s=lambda sec: '%8.2f s' % sec,
)
|
||||
|
||||
|
||||
class log_durations(object):
    """Context manager that reports how long the enclosed block took.

    Usage::

        with log_durations(logger.info, "expensive step"):
            do_work()

    The elapsed time is measured with ``timer`` between ``__enter__`` and
    ``__exit__`` and is reported even when the block raises; the exception
    itself is never suppressed.

    Args:
        print_func: Callable that receives the final message as its single
            argument. A logger-like object (not callable itself, but exposing
            a callable ``info`` attribute) is also accepted, in which case its
            ``info`` method is used.
        label: Optional text placed before the duration as ``"<label> : <duration>"``.
            When empty or ``None`` only the duration is reported.
        unit: Key of ``time_formatters`` selecting how the duration is rendered.
        threshold: Minimum duration, in seconds, that gets reported. The
            default of ``-1`` reports every measurement.
        repr_len: Accepted for signature compatibility; not used by this class.

    Raises:
        ValueError: If ``unit`` is not a key of ``time_formatters``.
    """

    def __init__(self, print_func, label=None, unit='auto', threshold=-1, repr_len=25):
        # Passing a logger object directly would fail in __exit__ with
        # "'Logger' object is not callable"; route such objects through .info.
        if not callable(print_func) and callable(getattr(print_func, 'info', None)):
            print_func = print_func.info
        self.print_func = print_func
        self.label = label
        if unit not in time_formatters:
            raise ValueError('Unknown time unit: %s. It should be ns, mks, ms, s or auto.' % unit)
        self.format_time = time_formatters[unit]
        self.threshold = threshold

    def __enter__(self):
        self.start = timer()
        return self

    def __exit__(self, *exc):
        duration = timer() - self.start
        # Measurements below the threshold are dropped silently.
        if duration >= self.threshold:
            duration_str = self.format_time(duration)
            self.print_func("%s : %s" % (self.label, duration_str) if self.label else duration_str)
|
||||
|
|
|
@ -51,6 +51,7 @@ from common.utils import (
|
|||
get_endpoint_details,
|
||||
validate_teacher_model_details,
|
||||
retry,
|
||||
log_durations
|
||||
)
|
||||
|
||||
|
||||
|
@ -197,7 +198,7 @@ def get_parser():
|
|||
|
||||
|
||||
@retry(3)
|
||||
def _invoke_endpoint(url: str, key: str, data: dict) -> Response:
|
||||
def _invoke_endpoint(url: str, key: str, data: dict, log_metadata: dict = None) -> Response:
|
||||
"""Invoke endpoint with payload data.
|
||||
|
||||
Args:
|
||||
|
@ -212,8 +213,10 @@ def _invoke_endpoint(url: str, key: str, data: dict) -> Response:
|
|||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {key}"
|
||||
}
|
||||
response = requests.post(url, headers=request_headers, data=json.dumps(data))
|
||||
return response
|
||||
log_metadata = log_metadata or {}
|
||||
with log_durations(logger, f"POST request to teacher model endpoint: {log_metadata}"):
|
||||
response = requests.post(url, headers=request_headers, data=json.dumps(data))
|
||||
return response
|
||||
|
||||
|
||||
def _validate_file_paths_with_supported_formats(file_paths: List[Optional[str]]):
|
||||
|
@ -337,7 +340,7 @@ def generate_synthetic_data(
|
|||
}
|
||||
messages = normalize_messages(messages)
|
||||
synthetic_responses = []
|
||||
for message in messages:
|
||||
for turn_id, message in enumerate(messages):
|
||||
role = message['role']
|
||||
if role == 'system':
|
||||
synthetic_responses.append(process_system_prompt(message))
|
||||
|
@ -349,7 +352,8 @@ def generate_synthetic_data(
|
|||
data_with_inference_parameters[key] = value
|
||||
# replace the assistant content from the model
|
||||
response: Response = _invoke_endpoint(url=url, key=endpoint_key,
|
||||
data=data_with_inference_parameters)
|
||||
data=data_with_inference_parameters,
|
||||
log_metadata={"idx": idx, 'turn': turn_id})
|
||||
if response.status_code != 200:
|
||||
break
|
||||
response_data = response.json()
|
||||
|
|
Загрузка…
Ссылка в новой задаче