fix get_attr and new tensorflow saver

This commit is contained in:
namizzz 2018-05-23 19:04:00 +08:00
Родитель 67cb6e0dd0
Коммит dec36a9ae0
5 изменённых файлов: 72 добавлений и 25 удалений

Просмотреть файл

@ -60,8 +60,13 @@ class IRGraphNode(GraphNode):
attr = self.layer.attr[name]
field = attr.WhichOneof('value')
val = getattr(attr, field) if field else default_value
if not val:
return val
if isinstance(val, AttrValue.ListValue):
return list(val.ListFields()[0][1])
if val.ListFields():
return list(val.ListFields()[0][1])
else:
return val.ListFields()
else:
return val.decode('utf-8') if isinstance(val, bytes) else val
else:

Просмотреть файл

@ -174,9 +174,6 @@ class tensorflow_extractor(base_extractor):
init = tf.global_variables_initializer()
with tf.Session() as sess:
# tf.train.export_meta_graph("kit.meta", as_text=True)
# writer = tf.summary.FileWriter('./graphs', sess.graph)
# writer.close()
sess.run(init)
saver = tf.train.Saver()
saver.restore(sess, path + cls.architecture_map[architecture]['filename'])

Просмотреть файл

@ -59,10 +59,31 @@ class TestTF(TestKit):
def dump(self, path = None):
if path is None: path = self.args.dump
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
saver = tf.train.Saver()
save_path = saver.save(sess, self.args.dump)
sess.run(tf.global_variables_initializer())
builder = tf.saved_model.builder.SavedModelBuilder(path)
tensor_info_input = tf.saved_model.utils.build_tensor_info(self.input)
tensor_info_output = tf.saved_model.utils.build_tensor_info(self.model)
prediction_signature = (
tf.saved_model.signature_def_utils.build_signature_def(
inputs={'input': tensor_info_input},
outputs={'output': tensor_info_output},
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
)
)
builder.add_meta_graph_and_variables(
sess,
[tf.saved_model.tag_constants.TRAINING],
signature_def_map={
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature
}
)
save_path = builder.save()
print ('Tensorflow file is saved as [{}], generated by [{}.py] and [{}].'.format(
save_path, self.args.n, self.args.w))

Просмотреть файл

@ -4,9 +4,30 @@ import tensorflow as tf
def save_model(MainModel, network_filepath, weight_filepath, dump_filepath):
input, model = MainModel.KitModel(weight_filepath)
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
saver = tf.train.Saver()
save_path = saver.save(sess, dump_filepath)
sess.run(tf.global_variables_initializer())
builder = tf.saved_model.builder.SavedModelBuilder(path)
tensor_info_input = tf.saved_model.utils.build_tensor_info(input)
tensor_info_output = tf.saved_model.utils.build_tensor_info(model)
prediction_signature = (
tf.saved_model.signature_def_utils.build_signature_def(
inputs={'input': tensor_info_input},
outputs={'output': tensor_info_output},
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
)
)
builder.add_meta_graph_and_variables(
sess,
[tf.saved_model.tag_constants.TRAINING],
signature_def_map={
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature
}
)
save_path = builder.save()
print('Tensorflow file is saved as [{}], generated by [{}.py] and [{}].'.format(
save_path, network_filepath, weight_filepath))

Просмотреть файл

@ -689,24 +689,27 @@ class TestModels(CorrectnessTest):
os.remove(converted_file + '.npy')
# In case of odd number add the extra padding at the end for SAME_UPPER(eg. pads:[0, 2, 2, 0, 0, 3, 3, 0]) and at the beginning for SAME_LOWER(eg. pads:[0, 3, 3, 0, 0, 2, 2, 0])
exception_tabel = {
'cntk_Keras_resnet18', # Cntk Padding is SAME_UPPER, but Keras Padding is SAME_LOWER, in first convolution layer.
'cntk_Keras_resnet152', # Cntk Padding is SAME_UPPER, but Keras Padding is SAME_LOWER, in first convolution layer.
'cntk_Tensorflow_resnet18', # Cntk Padding is SAME_UPPER, but Keras Padding is SAME_LOWER, in first convolution layer.
'cntk_Tensorflow_resnet152', # Cntk Padding is SAME_UPPER, but Keras Padding is SAME_LOWER, in first convolution layer.
'tensorflow_Cntk_inception_v1', # TODO
'tensorflow_Cntk_resnet_v1_50', # TODO
'tensorflow_Cntk_resnet_v2_50', # TODO
'tensorflow_Cntk_resnet_v1_152', # TODO
'tensorflow_Cntk_resnet_v2_152', # TODO
'tensorflow_Cntk_mobilenet_v1_1.0', # TODO
'tensorflow_frozen_MXNet_inception_v1', # TODO
'cntk_Keras_resnet18', # Cntk Padding is SAME_LOWER, but Keras Padding is SAME_UPPER, in first convolution layer.
'cntk_Keras_resnet152', # Cntk Padding is SAME_LOWER, but Keras Padding is SAME_UPPER, in first convolution layer.
'cntk_Tensorflow_resnet18', # Cntk Padding is SAME_LOWER, but Keras Padding is SAME_UPPER, in first convolution layer.
'cntk_Tensorflow_resnet152', # Cntk Padding is SAME_LOWER, but Keras Padding is SAME_UPPER, in first convolution layer.
'tensorflow_Cntk_inception_v1', # Cntk Padding is SAME_LOWER, but Tensorflow Padding is SAME_UPPER, in first convolution layer.
'tensorflow_Cntk_resnet_v1_50', # Cntk Padding is SAME_LOWER, but Tensorflow Padding is SAME_UPPER, in first convolution layer.
'tensorflow_Cntk_resnet_v2_50', # Cntk Padding is SAME_LOWER, but Tensorflow Padding is SAME_UPPER, in first convolution layer.
'tensorflow_Cntk_resnet_v1_152', # Cntk Padding is SAME_LOWER, but Tensorflow Padding is SAME_UPPER, in first convolution layer.
'tensorflow_Cntk_resnet_v2_152', # Cntk Padding is SAME_LOWER, but Tensorflow Padding is SAME_UPPER, in first convolution layer.
'tensorflow_Cntk_mobilenet_v1_1.0', # Cntk Padding is SAME_LOWER, but Tensorflow Padding is SAME_UPPER, in first convolution layer.
'tensorflow_frozen_MXNet_inception_v1', # different after AvgPool. AVG POOL padding difference between these two framework. MXNet AVGPooling Padding is SAME_LOWER, Tensorflow AVGPooling Padding is SAME_UPPER
'tensorflow_MXNet_inception_v3', # different after "InceptionV3/InceptionV3/Mixed_5b/Branch_3/AvgPool_0a_3x3/AvgPool". AVG POOL padding difference between these two framework.
'caffe_Pytorch_inception_v1', # TODO
'caffe_Pytorch_alexnet', # TODO
'caffe_Pytorch_inception_v4', # TODO, same with caffe_Cntk_inception_v4
'darknet_Keras_yolov2', # TODO,
'darknet_Keras_yolov3', # TODO,
'darknet_Keras_yolov2', # accumulation of small difference
'darknet_Keras_yolov3', # accumulation of small difference
}