[Fix bug] Modify profiling arguments (#86)
This commit is contained in:
Родитель
29b7f9a073
Коммит
0f00a22a2d
|
@ -38,10 +38,16 @@ class TFLiteCPULatencyParser(BaseParser):
|
|||
if re.search(end_regex, line):
|
||||
flag = False
|
||||
|
||||
# import pandas as pd
|
||||
# df = pd.DataFrame(columns=('node_type', 'avg', 'name'))
|
||||
# for node in nodes:
|
||||
# # print({'node_type': node['node_type'], 'avg': node['avg'], 'name': node['name']})
|
||||
# df.loc[len(df)] = [node['node_type'], node['avg'], node['name']]
|
||||
|
||||
return nodes
|
||||
|
||||
def _parse_total_latency(self, content):
|
||||
total_latency_regex = r'Timings \(microseconds\): count=[\d.e-]+ first=[\d.e-]+ curr=[\d.e-]+ min=[\d.e-]+ max=[\d.e-]+ avg=([\d.e-]+) std=([\d.e-]+)'
|
||||
total_latency_regex = r'Timings \(microseconds\): count=[\d.e-]+ first=[\d.e-]+ curr=[\d.e-]+ min=[\d.e-]+ max=[\d.e-]+ avg=([\d.\+e-]+) std=([\d.\+e-]+)'
|
||||
|
||||
total_latency = Latency()
|
||||
match = re.search(total_latency_regex, content, re.MULTILINE)
|
||||
|
|
|
@ -63,7 +63,7 @@ class TFLiteProfiler(BaseProfiler):
|
|||
if clean:
|
||||
if self._serial:
|
||||
os.system(f"adb -s {self._serial} shell rm {remote_graph_path}")
|
||||
os.remove(graph_path)
|
||||
# os.remove(graph_path)
|
||||
else:
|
||||
os.system(f"adb shell rm {remote_graph_path}")
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ def convert_models(backend, models, mode = 'predbuild', broken_point_mode = Fals
|
|||
continue
|
||||
try:
|
||||
model_path = model['model']
|
||||
converted_model = backend.convert_model(model_path, model_save_path, model['shapes'])
|
||||
converted_model = backend.convert_model(model_path, model_save_path, input_shape=model['shapes'])
|
||||
model['converted_model'] = converted_model
|
||||
count += 1
|
||||
except Exception as e:
|
||||
|
@ -136,7 +136,7 @@ def profile_models(backend, models, mode = 'ruletest', metrics = ["latency"], sa
|
|||
try:
|
||||
model_path = model['converted_model']
|
||||
signal.alarm(time_threshold)
|
||||
profiled_res = backend.profile(model_path, metrics, model['shapes'], **kwargs)
|
||||
profiled_res = backend.profile(model_path, metrics, input_shape=model['shapes'], **kwargs)
|
||||
signal.alarm(0)
|
||||
for metric in metrics:
|
||||
model[metric] = profiled_res[metric]
|
||||
|
@ -148,7 +148,7 @@ def profile_models(backend, models, mode = 'ruletest', metrics = ["latency"], sa
|
|||
try:
|
||||
model_path = model['model']
|
||||
signal.alarm(time_threshold)
|
||||
profiled_res = backend.profile_model_file(model_path, model_save_path, model['shapes'], metrics, **kwargs)
|
||||
profiled_res = backend.profile_model_file(model_path, model_save_path, input_shape=model['shapes'], metrics=metrics, **kwargs)
|
||||
signal.alarm(0)
|
||||
for metric in metrics:
|
||||
model[metric] = profiled_res[metric]
|
||||
|
|
Загрузка…
Ссылка в новой задаче