retry onnx integration test
This commit is contained in:
Родитель
0b48ecbf5a
Коммит
2725459433
|
@ -52,3 +52,5 @@ jobs:
|
|||
- name: Diff result with reference
|
||||
run: diff tests/reference_result_nni_based_torch.txt tests/test_result_nni_based_torch.txt
|
||||
|
||||
- name: clean env
|
||||
run: rm tests/test_result_nni_based_torch.txt
|
||||
|
|
|
@ -60,7 +60,7 @@ def model_file_to_graph(filename: str, model_type: str, input_shape=(1, 3, 224,
|
|||
'inception_v3': 'models.inception_v3()',
|
||||
'googlenet': 'models.googlenet()',
|
||||
'shufflenet_v2': 'models.shufflenet_v2_x1_0()',
|
||||
'mobilenet_v2': 'models.mobilenet_v2()', # noqa: F841
|
||||
'mobilenet_v2': 'models.mobilenet_v2()',
|
||||
'resnext50_32x4d': 'models.resnext50_32x4d()',
|
||||
'wide_resnet50_2': 'models.wide_resnet50_2()',
|
||||
'mnasnet': 'models.mnasnet1_0()',
|
||||
|
@ -69,7 +69,7 @@ def model_file_to_graph(filename: str, model_type: str, input_shape=(1, 3, 224,
|
|||
model = eval(torchvision_zoo_dict[filename])
|
||||
else:
|
||||
suppost_list = ", ".join([k for k in torchvision_zoo_dict])
|
||||
raise ValueError(f"Unsupported model name in torchvision. Supporting list: {suppost_list}")
|
||||
raise ValueError(f"Unsupported model name: {filename} in torchvision. Supporting list: {suppost_list}")
|
||||
return torch_model_to_graph(model, input_shape, apply_nni)
|
||||
|
||||
else:
|
||||
|
|
|
@ -37,6 +37,9 @@ def integration_test_onnx_based_torch(model_type, model_list, output_name = "tes
|
|||
if not os.path.isfile(output_name):
|
||||
with open(output_name,"w") as f:
|
||||
f.write('model_name, model_type, predictor, predictor_version, latency\n')
|
||||
else:
|
||||
print(f"Found exist file {output_name}")
|
||||
os.system(f'cat {output_name}')
|
||||
|
||||
# start testing
|
||||
for pred_name, pred_version in get_predictors():
|
||||
|
@ -51,6 +54,7 @@ def integration_test_onnx_based_torch(model_type, model_list, output_name = "tes
|
|||
print('Complete os.system run')
|
||||
runtime = time.time() - since
|
||||
print(runtime)
|
||||
os.system(f'cat {output_name}')
|
||||
except NotImplementedError:
|
||||
logging.error(f"Meets ERROR when checking --torchvision {model_list} --predictor {pred_name} --predictor-version {pred_version}")
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче