fix model path
This commit is contained in:
Родитель
5721afccad
Коммит
db6dc89af2
|
@ -53,8 +53,7 @@ class BaseTestCase:
|
|||
|
||||
for op, model in testcase.items():
|
||||
model_path = os.path.join(self.workspace_path, self.name + '_' + op)
|
||||
# import pdb; pdb.set_trace()
|
||||
save_model(model, model_path, self.implement)
|
||||
model_path = save_model(model, model_path, self.implement)
|
||||
testcase[op]['model'] = model_path
|
||||
|
||||
return testcase
|
||||
|
|
|
@ -142,6 +142,7 @@ def save_model(model, model_path, implement):
|
|||
from nn_meter.builder.nn_generator.tf_networks.utils import get_tensor_by_shapes
|
||||
model['model'](get_tensor_by_shapes(model['shapes']))
|
||||
keras.models.save_model(model['model'], model_path)
|
||||
return model_path
|
||||
elif implement == 'torch':
|
||||
import torch
|
||||
from nn_meter.builder.nn_generator.torch_networks.utils import get_inputs_by_shapes
|
||||
|
@ -156,6 +157,7 @@ def save_model(model, model_path, implement):
|
|||
opset_version=12,
|
||||
do_constant_folding=True,
|
||||
)
|
||||
return model_path + '.onnx'
|
||||
|
||||
else:
|
||||
import pdb; pdb.set_trace()
|
||||
|
|
Загрузка…
Ссылка в новой задаче