Handle model profiling timeout (#76)
This commit is contained in:
Родитель
681f4ae86e
Коммит
7ac63167f4
|
@ -3,9 +3,10 @@
|
|||
import os
|
||||
import json
|
||||
import time
|
||||
import signal
|
||||
import logging
|
||||
from . import builder_config
|
||||
from .utils import save_profiled_results, merge_info
|
||||
from .utils import save_profiled_results, merge_info, handle_timeout
|
||||
from nn_meter.builder.backends import connect_backend
|
||||
logging = logging.getLogger("nn-Meter")
|
||||
|
||||
|
@ -69,7 +70,7 @@ def convert_models(backend, models, mode = 'predbuild', broken_point_mode = Fals
|
|||
|
||||
|
||||
def profile_models(backend, models, mode = 'ruletest', metrics = ["latency"], save_name = "profiled_results.json",
|
||||
have_converted = False, log_frequency = 50, broken_point_mode = False, **kwargs):
|
||||
have_converted = False, log_frequency = 50, broken_point_mode = False, time_threshold = 300, **kwargs):
|
||||
""" run models with given backend and return latency of testcase models
|
||||
|
||||
@params:
|
||||
|
@ -91,8 +92,12 @@ def profile_models(backend, models, mode = 'ruletest', metrics = ["latency"], sa
|
|||
broken_point_mode (boolean): broken_point_mode will check file in `<workspace>/<mode-folder>/results/<save-name>` (if the file exists)
|
||||
and skip all models already have attributes "latency"
|
||||
|
||||
time_threshold (int): the time threshold for profiling one single model. If the total profiling time of a model is longger than the
|
||||
`time_threshold` (second), nn-Meter will log a profiling timeout error for this model and step to profile the next model.
|
||||
|
||||
**kwargs: arguments for profiler, such as `taskset` and `close_xnnpack` in TFLite profiler
|
||||
"""
|
||||
signal.signal(signal.SIGALRM, handle_timeout)
|
||||
if isinstance(models, str):
|
||||
with open(models, 'r') as fp:
|
||||
models = json.load(fp)
|
||||
|
@ -130,7 +135,9 @@ def profile_models(backend, models, mode = 'ruletest', metrics = ["latency"], sa
|
|||
if have_converted: # the models have been converted for the backend
|
||||
try:
|
||||
model_path = model['converted_model']
|
||||
signal.alarm(time_threshold)
|
||||
profiled_res = backend.profile(model_path, metrics, model['shapes'], **kwargs)
|
||||
signal.alarm(0)
|
||||
for metric in metrics:
|
||||
model[metric] = profiled_res[metric]
|
||||
time.sleep(0.2)
|
||||
|
@ -140,7 +147,9 @@ def profile_models(backend, models, mode = 'ruletest', metrics = ["latency"], sa
|
|||
else: # the models have not been converted
|
||||
try:
|
||||
model_path = model['model']
|
||||
signal.alarm(time_threshold)
|
||||
profiled_res = backend.profile_model_file(model_path, model_save_path, model['shapes'], metrics, **kwargs)
|
||||
signal.alarm(0)
|
||||
for metric in metrics:
|
||||
model[metric] = profiled_res[metric]
|
||||
time.sleep(0.2)
|
||||
|
|
|
@ -52,3 +52,7 @@ def save_profiled_results(models, save_path, detail, metrics = ["latency"]):
|
|||
from .backend_meta.utils import dump_profiled_results
|
||||
with open(save_path, 'w') as fp:
|
||||
json.dump(dump_profiled_results(new_models, detail=detail, metrics=metrics), fp, indent=4)
|
||||
|
||||
|
||||
def handle_timeout(sig, frame):
|
||||
raise TimeoutError('Model profiling took too long (longer than the time threshold in the funciton `nn_meter.builder.profile_models`, default to be 300s)')
|
||||
|
|
Загрузка…
Ссылка в новой задаче