This commit is contained in:
jiahangxu 2022-04-07 23:07:03 -04:00
Родитель 5721afccad
Коммит db6dc89af2
2 изменённых файлов: 3 добавлений и 2 удалений

Просмотреть файл

@ -53,8 +53,7 @@ class BaseTestCase:
for op, model in testcase.items():
model_path = os.path.join(self.workspace_path, self.name + '_' + op)
# import pdb; pdb.set_trace()
save_model(model, model_path, self.implement)
model_path = save_model(model, model_path, self.implement)
testcase[op]['model'] = model_path
return testcase

Просмотреть файл

@ -142,6 +142,7 @@ def save_model(model, model_path, implement):
from nn_meter.builder.nn_generator.tf_networks.utils import get_tensor_by_shapes
model['model'](get_tensor_by_shapes(model['shapes']))
keras.models.save_model(model['model'], model_path)
return model_path
elif implement == 'torch':
import torch
from nn_meter.builder.nn_generator.torch_networks.utils import get_inputs_by_shapes
@ -156,6 +157,7 @@ def save_model(model, model_path, implement):
opset_version=12,
do_constant_folding=True,
)
return model_path + '.onnx'
else:
import pdb; pdb.set_trace()