Adding telemetry for datagen
This commit is contained in:
Родитель
a9aeff7100
Коммит
77284f3523
|
@ -15,6 +15,7 @@ from azureml.acft.common_components import get_logger_app
|
||||||
from azureml.core import Run, Workspace
|
from azureml.core import Run, Workspace
|
||||||
from azureml.core.run import _OfflineRun
|
from azureml.core.run import _OfflineRun
|
||||||
from typing import List, Tuple, Union
|
from typing import List, Tuple, Union
|
||||||
|
from timeit import default_timer as timer
|
||||||
|
|
||||||
from common.constants import (
|
from common.constants import (
|
||||||
REQUESTS_RETRY_DELAY,
|
REQUESTS_RETRY_DELAY,
|
||||||
|
@ -348,3 +349,45 @@ def validate_student_model_details(model_asset_id: str) -> Tuple[str, str, str]:
|
||||||
Tuple[str, str, str]: Tuple containing registry name, model name and model version
|
Tuple[str, str, str]: Tuple containing registry name, model name and model version
|
||||||
"""
|
"""
|
||||||
return _get_model_details(model_asset_id, SUPPORTED_STUDENT_MODEL_MAP)
|
return _get_model_details(model_asset_id, SUPPORTED_STUDENT_MODEL_MAP)
|
||||||
|
|
||||||
|
|
||||||
|
def format_time(sec):
    """Render a duration given in seconds using the smallest fitting unit.

    Args:
        sec: Duration in seconds.

    Returns:
        str: The value formatted as ``%8.2f`` followed by a unit suffix:
        ``ns`` below 1e-6 s, ``mks`` below 1e-3 s, ``ms`` below 1 s,
        otherwise ``s``.
    """
    # Each row: (exclusive upper bound in seconds, multiplier, output template).
    # Rows are ordered from the smallest bound up, so the first match wins.
    scales = (
        (1e-6, 1e9, '%8.2f ns'),
        (1e-3, 1e6, '%8.2f mks'),
        (1, 1e3, '%8.2f ms'),
    )
    for upper_bound, multiplier, template in scales:
        if sec < upper_bound:
            return template % (sec * multiplier)
    # One second or longer: report plain seconds, unscaled.
    return '%8.2f s' % sec
|
||||||
|
|
||||||
|
|
||||||
|
# Lookup of display-unit name -> formatter turning a duration in seconds into text.
# 'auto' defers to format_time, which picks the unit from the magnitude; every
# other entry always renders in the unit it is named after.
time_formatters = dict(
    auto=format_time,
    ns=lambda sec: '%8.2f ns' % (sec * 1e9),
    mks=lambda sec: '%8.2f mks' % (sec * 1e6),
    ms=lambda sec: '%8.2f ms' % (sec * 1e3),
    s=lambda sec: '%8.2f s' % sec,
)
|
||||||
|
|
||||||
|
|
||||||
|
class log_durations(object):
    """Context manager that reports how long the wrapped block took to run.

    The elapsed wall-clock time is measured between ``__enter__`` and
    ``__exit__`` and, when it reaches ``threshold``, handed to ``print_func``
    as a formatted string (prefixed with ``label`` when one is given).
    """

    def __init__(self, print_func, label=None, unit='auto', threshold=-1, repr_len=25):
        """Configure the timer.

        Args:
            print_func: Where the report goes. Either a callable taking the
                message string (e.g. ``print`` or ``logger.info``), or a
                logger-like object exposing ``.info(message)``.
            label: Optional text placed before the duration as ``"label : duration"``.
            unit: Key into ``time_formatters`` selecting the display unit.
            threshold: Minimum duration in seconds that gets reported; the
                default of -1 reports every measurement.
            repr_len: Unused; kept so existing call signatures stay valid.

        Raises:
            ValueError: If ``unit`` is not a key of ``time_formatters``.
        """
        self.print_func = print_func
        self.label = label
        if unit not in time_formatters:
            raise ValueError('Unknown time unit: %s. It should be ns, mks, ms, s or auto.' % unit)
        self.format_time = time_formatters[unit]
        self.threshold = threshold

    def __enter__(self):
        """Start the clock and return this timer."""
        self.start = timer()
        return self

    def __exit__(self, *exc):
        """Stop the clock and report the duration if it reaches the threshold.

        Returns None, so an exception raised inside the block propagates; the
        duration is still reported in that case.
        """
        duration = timer() - self.start
        if duration >= self.threshold:
            duration_str = self.format_time(duration)
            message = "%s : %s" % (self.label, duration_str) if self.label else duration_str
            self._emit(message)

    def _emit(self, message):
        """Deliver ``message`` to the configured sink."""
        sink = self.print_func
        if callable(sink):
            sink(message)
        else:
            # Logger objects are not callable; without this branch passing one
            # would raise TypeError here, after the timed work already ran.
            sink.info(message)
|
||||||
|
|
|
@ -51,6 +51,7 @@ from common.utils import (
|
||||||
get_endpoint_details,
|
get_endpoint_details,
|
||||||
validate_teacher_model_details,
|
validate_teacher_model_details,
|
||||||
retry,
|
retry,
|
||||||
|
log_durations
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -197,7 +198,7 @@ def get_parser():
|
||||||
|
|
||||||
|
|
||||||
@retry(3)
|
@retry(3)
|
||||||
def _invoke_endpoint(url: str, key: str, data: dict) -> Response:
|
def _invoke_endpoint(url: str, key: str, data: dict, log_metadata: dict = None) -> Response:
|
||||||
"""Invoke endpoint with payload data.
|
"""Invoke endpoint with payload data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -212,8 +213,10 @@ def _invoke_endpoint(url: str, key: str, data: dict) -> Response:
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {key}"
|
"Authorization": f"Bearer {key}"
|
||||||
}
|
}
|
||||||
response = requests.post(url, headers=request_headers, data=json.dumps(data))
|
log_metadata = log_metadata or {}
|
||||||
return response
|
with log_durations(logger, f"POST request to teacher model endpoint: {log_metadata}"):
|
||||||
|
response = requests.post(url, headers=request_headers, data=json.dumps(data))
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
def _validate_file_paths_with_supported_formats(file_paths: List[Optional[str]]):
|
def _validate_file_paths_with_supported_formats(file_paths: List[Optional[str]]):
|
||||||
|
@ -337,7 +340,7 @@ def generate_synthetic_data(
|
||||||
}
|
}
|
||||||
messages = normalize_messages(messages)
|
messages = normalize_messages(messages)
|
||||||
synthetic_responses = []
|
synthetic_responses = []
|
||||||
for message in messages:
|
for turn_id, message in enumerate(messages):
|
||||||
role = message['role']
|
role = message['role']
|
||||||
if role == 'system':
|
if role == 'system':
|
||||||
synthetic_responses.append(process_system_prompt(message))
|
synthetic_responses.append(process_system_prompt(message))
|
||||||
|
@ -349,7 +352,8 @@ def generate_synthetic_data(
|
||||||
data_with_inference_parameters[key] = value
|
data_with_inference_parameters[key] = value
|
||||||
# replace the assistant content from the model
|
# replace the assistant content from the model
|
||||||
response: Response = _invoke_endpoint(url=url, key=endpoint_key,
|
response: Response = _invoke_endpoint(url=url, key=endpoint_key,
|
||||||
data=data_with_inference_parameters)
|
data=data_with_inference_parameters,
|
||||||
|
log_metadata={"idx": idx, 'turn': turn_id})
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
break
|
break
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
|
|
Загрузка…
Ссылка в новой задаче