change the backend profile argument (#78)

This commit is contained in:
Jiahang Xu 2022-09-09 18:08:53 +08:00 коммит произвёл GitHub
Родитель 28af7704b5
Коммит 681f4ae86e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 5 добавлений и 5 удалений

Просмотреть файл

@ -70,7 +70,7 @@ class BaseBackend:
self.parser_kwargs = {}
self.profiler_kwargs = {}
def convert_model(self, model_path, save_path, input_shape=None):
def convert_model(self, model_path, save_path, input_shape = None):
""" convert the Keras model instance to the type required by the backend inference.
@params:
@ -86,7 +86,7 @@ class BaseBackend:
converted_model = model_path
return converted_model
def profile(self, converted_model, metrics = ['latency'], input_shape = None, **kwargs):
def profile(self, converted_model, metrics = ['latency'], **kwargs):
"""
run the model on the backend, return required metrics of the running results. nn-Meter only support latency
for metric by now. Users may provide other metrics in their customized backend.
@ -112,7 +112,7 @@ class BaseBackend:
generated and used
"""
converted_model = self.convert_model(model_path, save_path, input_shape)
res = self.profile(converted_model, metrics, input_shape, **kwargs)
res = self.profile(converted_model, metrics, input_shape=input_shape, **kwargs)
return res
def test_connection(self):

Просмотреть файл

@ -28,7 +28,7 @@ class OpenVINOProfiler(BaseProfiler):
self._graph_path = graph_path
self._dst_graph_path = dst_graph_path
def profile(self, shapes, retry=2):
def profile(self, shapes, retry = 2, **kwargs):
interpreter_path = os.path.join(self._venv, 'bin/python')
pyver = get_pyver(interpreter_path)

Просмотреть файл

@ -24,7 +24,7 @@ class TFLiteProfiler(BaseProfiler):
self._num_runs = num_runs
self._warm_ups = warm_ups
def profile(self, graph_path, preserve = False, clean = True, taskset = '70', close_xnnpack = False):
def profile(self, graph_path, preserve = False, clean = True, taskset = '70', close_xnnpack = False, **kwargs):
"""
@params:
preserve: tflite file exists in remote dir. No need to push it again.