diff --git a/nn_meter/builder/backends/interface.py b/nn_meter/builder/backends/interface.py index 6c5a5d5..fd19c07 100644 --- a/nn_meter/builder/backends/interface.py +++ b/nn_meter/builder/backends/interface.py @@ -70,7 +70,7 @@ class BaseBackend: self.parser_kwargs = {} self.profiler_kwargs = {} - def convert_model(self, model_path, save_path, input_shape=None): + def convert_model(self, model_path, save_path, input_shape = None): """ convert the Keras model instance to the type required by the backend inference. @params: @@ -86,7 +86,7 @@ class BaseBackend: converted_model = model_path return converted_model - def profile(self, converted_model, metrics = ['latency'], input_shape = None, **kwargs): + def profile(self, converted_model, metrics = ['latency'], **kwargs): """ run the model on the backend, return required metrics of the running results. nn-Meter only support latency for metric by now. Users may provide other metrics in their customized backend. @@ -112,7 +112,7 @@ class BaseBackend: generated and used """ converted_model = self.convert_model(model_path, save_path, input_shape) - res = self.profile(converted_model, metrics, input_shape, **kwargs) + res = self.profile(converted_model, metrics, input_shape=input_shape, **kwargs) return res def test_connection(self): diff --git a/nn_meter/builder/backends/openvino/openvino_profiler.py b/nn_meter/builder/backends/openvino/openvino_profiler.py index 5c6c1dd..167c059 100644 --- a/nn_meter/builder/backends/openvino/openvino_profiler.py +++ b/nn_meter/builder/backends/openvino/openvino_profiler.py @@ -28,7 +28,7 @@ class OpenVINOProfiler(BaseProfiler): self._graph_path = graph_path self._dst_graph_path = dst_graph_path - def profile(self, shapes, retry=2): + def profile(self, shapes, retry = 2, **kwargs): interpreter_path = os.path.join(self._venv, 'bin/python') pyver = get_pyver(interpreter_path) diff --git a/nn_meter/builder/backends/tflite/tflite_profiler.py b/nn_meter/builder/backends/tflite/tflite_profiler.py index 9539ea8..ac240de 100644 --- a/nn_meter/builder/backends/tflite/tflite_profiler.py +++ b/nn_meter/builder/backends/tflite/tflite_profiler.py @@ -24,7 +24,7 @@ class TFLiteProfiler(BaseProfiler): self._num_runs = num_runs self._warm_ups = warm_ups - def profile(self, graph_path, preserve = False, clean = True, taskset = '70', close_xnnpack = False): + def profile(self, graph_path, preserve = False, clean = True, taskset = '70', close_xnnpack = False, **kwargs): """ @params: preserve: tflite file exists in remote dir. No need to push it again.