change the backend profile argument (#78)
This commit is contained in:
Родитель
28af7704b5
Коммит
681f4ae86e
|
@ -70,7 +70,7 @@ class BaseBackend:
|
|||
self.parser_kwargs = {}
|
||||
self.profiler_kwargs = {}
|
||||
|
||||
def convert_model(self, model_path, save_path, input_shape=None):
|
||||
def convert_model(self, model_path, save_path, input_shape = None):
|
||||
""" convert the Keras model instance to the type required by the backend inference.
|
||||
|
||||
@params:
|
||||
|
@ -86,7 +86,7 @@ class BaseBackend:
|
|||
converted_model = model_path
|
||||
return converted_model
|
||||
|
||||
def profile(self, converted_model, metrics = ['latency'], input_shape = None, **kwargs):
|
||||
def profile(self, converted_model, metrics = ['latency'], **kwargs):
|
||||
"""
|
||||
run the model on the backend, return required metrics of the running results. nn-Meter only support latency
|
||||
for metric by now. Users may provide other metrics in their customized backend.
|
||||
|
@ -112,7 +112,7 @@ class BaseBackend:
|
|||
generated and used
|
||||
"""
|
||||
converted_model = self.convert_model(model_path, save_path, input_shape)
|
||||
res = self.profile(converted_model, metrics, input_shape, **kwargs)
|
||||
res = self.profile(converted_model, metrics, input_shape=input_shape, **kwargs)
|
||||
return res
|
||||
|
||||
def test_connection(self):
|
||||
|
|
|
@ -28,7 +28,7 @@ class OpenVINOProfiler(BaseProfiler):
|
|||
self._graph_path = graph_path
|
||||
self._dst_graph_path = dst_graph_path
|
||||
|
||||
def profile(self, shapes, retry=2):
|
||||
def profile(self, shapes, retry = 2, **kwargs):
|
||||
interpreter_path = os.path.join(self._venv, 'bin/python')
|
||||
pyver = get_pyver(interpreter_path)
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ class TFLiteProfiler(BaseProfiler):
|
|||
self._num_runs = num_runs
|
||||
self._warm_ups = warm_ups
|
||||
|
||||
def profile(self, graph_path, preserve = False, clean = True, taskset = '70', close_xnnpack = False):
|
||||
def profile(self, graph_path, preserve = False, clean = True, taskset = '70', close_xnnpack = False, **kwargs):
|
||||
"""
|
||||
@params:
|
||||
preserve: tflite file exists in remote dir. No need to push it again.
|
||||
|
|
Загрузка…
Ссылка в новой задаче