[Fix bug] Modify profiling arguments (#86)

This commit is contained in:
Jiahang Xu 2022-10-25 16:41:35 +08:00 коммит произвёл GitHub
Родитель 29b7f9a073
Коммит 0f00a22a2d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 12 добавлений и 6 удалений

Просмотреть файл

@ -38,10 +38,16 @@ class TFLiteCPULatencyParser(BaseParser):
if re.search(end_regex, line):
flag = False
# import pandas as pd
# df = pd.DataFrame(columns=('node_type', 'avg', 'name'))
# for node in nodes:
# # print({'node_type': node['node_type'], 'avg': node['avg'], 'name': node['name']})
# df.loc[len(df)] = [node['node_type'], node['avg'], node['name']]
return nodes
def _parse_total_latency(self, content):
total_latency_regex = r'Timings \(microseconds\): count=[\d.e-]+ first=[\d.e-]+ curr=[\d.e-]+ min=[\d.e-]+ max=[\d.e-]+ avg=([\d.e-]+) std=([\d.e-]+)'
total_latency_regex = r'Timings \(microseconds\): count=[\d.e-]+ first=[\d.e-]+ curr=[\d.e-]+ min=[\d.e-]+ max=[\d.e-]+ avg=([\d.\+e-]+) std=([\d.\+e-]+)'
total_latency = Latency()
match = re.search(total_latency_regex, content, re.MULTILINE)

Просмотреть файл

@ -63,7 +63,7 @@ class TFLiteProfiler(BaseProfiler):
if clean:
if self._serial:
os.system(f"adb -s {self._serial} shell rm {remote_graph_path}")
os.remove(graph_path)
# os.remove(graph_path)
else:
os.system(f"adb shell rm {remote_graph_path}")

Просмотреть файл

@ -48,7 +48,7 @@ def convert_models(backend, models, mode = 'predbuild', broken_point_mode = Fals
continue
try:
model_path = model['model']
converted_model = backend.convert_model(model_path, model_save_path, model['shapes'])
converted_model = backend.convert_model(model_path, model_save_path, input_shape=model['shapes'])
model['converted_model'] = converted_model
count += 1
except Exception as e:
@ -136,7 +136,7 @@ def profile_models(backend, models, mode = 'ruletest', metrics = ["latency"], sa
try:
model_path = model['converted_model']
signal.alarm(time_threshold)
profiled_res = backend.profile(model_path, metrics, model['shapes'], **kwargs)
profiled_res = backend.profile(model_path, metrics, input_shape=model['shapes'], **kwargs)
signal.alarm(0)
for metric in metrics:
model[metric] = profiled_res[metric]
@ -148,7 +148,7 @@ def profile_models(backend, models, mode = 'ruletest', metrics = ["latency"], sa
try:
model_path = model['model']
signal.alarm(time_threshold)
profiled_res = backend.profile_model_file(model_path, model_save_path, model['shapes'], metrics, **kwargs)
profiled_res = backend.profile_model_file(model_path, model_save_path, input_shape=model['shapes'], metrics=metrics, **kwargs)
signal.alarm(0)
for metric in metrics:
model[metric] = profiled_res[metric]