This commit is contained in:
jiahangxu 2021-10-26 14:32:15 +08:00
Родитель 0b48ecbf5a
Коммит 2725459433
3 изменённых файлов: 8 добавлений и 2 удалений

Просмотреть файл

@ -52,3 +52,5 @@ jobs:
- name: Diff result with reference
run: diff tests/reference_result_nni_based_torch.txt tests/test_result_nni_based_torch.txt
- name: clean env
run: rm tests/test_result_nni_based_torch.txt

Просмотреть файл

@ -60,7 +60,7 @@ def model_file_to_graph(filename: str, model_type: str, input_shape=(1, 3, 224,
'inception_v3': 'models.inception_v3()',
'googlenet': 'models.googlenet()',
'shufflenet_v2': 'models.shufflenet_v2_x1_0()',
'mobilenet_v2': 'models.mobilenet_v2()', # noqa: F841
'mobilenet_v2': 'models.mobilenet_v2()',
'resnext50_32x4d': 'models.resnext50_32x4d()',
'wide_resnet50_2': 'models.wide_resnet50_2()',
'mnasnet': 'models.mnasnet1_0()',
@ -69,7 +69,7 @@ def model_file_to_graph(filename: str, model_type: str, input_shape=(1, 3, 224,
model = eval(torchvision_zoo_dict[filename])
else:
suppost_list = ", ".join([k for k in torchvision_zoo_dict])
raise ValueError(f"Unsupported model name in torchvision. Supporting list: {suppost_list}")
raise ValueError(f"Unsupported model name: {filename} in torchvision. Supporting list: {suppost_list}")
return torch_model_to_graph(model, input_shape, apply_nni)
else:

Просмотреть файл

@ -37,6 +37,9 @@ def integration_test_onnx_based_torch(model_type, model_list, output_name = "tes
if not os.path.isfile(output_name):
with open(output_name,"w") as f:
f.write('model_name, model_type, predictor, predictor_version, latency\n')
else:
print(f"Found exist file {output_name}")
os.system(f'cat {output_name}')
# start testing
for pred_name, pred_version in get_predictors():
@ -51,6 +54,7 @@ def integration_test_onnx_based_torch(model_type, model_list, output_name = "tes
print('Complete os.system run')
runtime = time.time() - since
print(runtime)
os.system(f'cat {output_name}')
except NotImplementedError:
logging.error(f"Meets ERROR when checking --torchvision {model_list} --predictor {pred_name} --predictor-version {pred_version}")