зеркало из https://github.com/microsoft/MMdnn.git
cntk parser and MXNetEmit
This commit is contained in:
Коммит
b764afb56f
|
@ -25,7 +25,7 @@ class TestKit(object):
|
|||
'resnet' : [(22, 11.756789), (147, 8.5718527), (24, 6.1751032), (88, 4.3121386), (141, 4.1778097)],
|
||||
'resnet_v1_101' : [(21, 14.384739), (23, 14.262486), (144, 14.068737), (94, 12.17205), (134, 12.064575)],
|
||||
'inception_v3' : [(22, 9.4921198), (24, 4.0932288), (25, 3.700398), (23, 3.3715961), (147, 3.3620636)],
|
||||
'mobilenet' : [(22, 16.223597), (24, 14.54775), (147, 13.173758), (145, 11.36431), (728, 11.083847)]
|
||||
'mobilenet_v1_1.0' : [(22, 16.223597), (24, 14.54775), (147, 13.173758), (145, 11.36431), (728, 11.083847)]
|
||||
},
|
||||
'keras' : {
|
||||
'vgg16' : [(21, 0.81199354), (562, 0.019326132), (23, 0.018279659), (144, 0.012460723), (22, 0.012429929)],
|
||||
|
@ -81,7 +81,12 @@ class TestKit(object):
|
|||
'resnet_v2_152' : lambda path : TestKit.Standard(path, 299),
|
||||
'resnet_v2_200' : lambda path : TestKit.Standard(path, 299),
|
||||
'resnet152' : lambda path : TestKit.Standard(path, 299),
|
||||
'mobilenet' : lambda path : TestKit.Standard(path, 224)
|
||||
'mobilenet_v1_1.0' : lambda path : TestKit.Standard(path, 224),
|
||||
'mobilenet_v1_0.50' : lambda path : TestKit.Standard(path, 224),
|
||||
'mobilenet_v1_0.25' : lambda path : TestKit.Standard(path, 224),
|
||||
'mobilenet' : lambda path : TestKit.Standard(path, 224),
|
||||
'nasnet-a_large' : lambda path : TestKit.Standard(path, 331),
|
||||
'inception_resnet_v2' : lambda path : TestKit.Standard(path, 299),
|
||||
},
|
||||
|
||||
'keras' : {
|
||||
|
|
|
@ -6,10 +6,14 @@
|
|||
from __future__ import absolute_import
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.slim.python.slim.nets import vgg
|
||||
from tensorflow.contrib.slim.python.slim.nets import inception
|
||||
from tensorflow.contrib.slim.python.slim.nets import resnet_v1
|
||||
from tensorflow.contrib.slim.python.slim.nets import resnet_v2
|
||||
|
||||
from tensorflow.contrib.slim.nets import vgg
|
||||
from tensorflow.contrib.slim.nets import inception
|
||||
from tensorflow.contrib.slim.nets import resnet_v1
|
||||
from tensorflow.contrib.slim.nets import resnet_v2
|
||||
from mmdnn.conversion.examples.tensorflow.models import inception_resnet_v2
|
||||
from mmdnn.conversion.examples.tensorflow.models import mobilenet_v1
|
||||
from mmdnn.conversion.examples.tensorflow.models import nasnet
|
||||
slim = tf.contrib.slim
|
||||
|
||||
from mmdnn.conversion.examples.imagenet_test import TestKit
|
||||
|
@ -84,6 +88,30 @@ class tensorflow_extractor(base_extractor):
|
|||
'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
|
||||
'num_classes' : 1001,
|
||||
},
|
||||
'mobilenet_v1_1.0' : {
|
||||
'url' : 'http://download.tensorflow.org/models/mobilenet_v1_1.0_224_2017_06_14.tar.gz',
|
||||
'filename' : 'mobilenet_v1_1.0_224.ckpt',
|
||||
'builder' : lambda : mobilenet_v1.mobilenet_v1,
|
||||
'arg_scope' : mobilenet_v1.mobilenet_v1_arg_scope,
|
||||
'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 224, 224, 3]),
|
||||
'num_classes' : 1001,
|
||||
},
|
||||
'inception_resnet_v2' : {
|
||||
'url' : 'http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz',
|
||||
'filename' : 'inception_resnet_v2_2016_08_30.ckpt',
|
||||
'builder' : lambda : inception_resnet_v2.inception_resnet_v2,
|
||||
'arg_scope' : inception_resnet_v2.inception_resnet_v2_arg_scope,
|
||||
'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
|
||||
'num_classes' : 1001,
|
||||
},
|
||||
'nasnet-a_large' : {
|
||||
'url' : 'https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_large_04_10_2017.tar.gz',
|
||||
'filename' : 'model.ckpt',
|
||||
'builder' : lambda : nasnet.build_nasnet_large,
|
||||
'arg_scope' : nasnet.nasnet_large_arg_scope,
|
||||
'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 331, 331, 3]),
|
||||
'num_classes' : 1001,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
|
@ -0,0 +1,397 @@
|
|||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Contains the definition of the Inception Resnet V2 architecture.
|
||||
|
||||
As described in http://arxiv.org/abs/1602.07261.
|
||||
|
||||
Inception-v4, Inception-ResNet and the Impact of Residual Connections
|
||||
on Learning
|
||||
Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
slim = tf.contrib.slim
|
||||
|
||||
|
||||
def block35(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
|
||||
"""Builds the 35x35 resnet block."""
|
||||
with tf.variable_scope(scope, 'Block35', [net], reuse=reuse):
|
||||
with tf.variable_scope('Branch_0'):
|
||||
tower_conv = slim.conv2d(net, 32, 1, scope='Conv2d_1x1')
|
||||
with tf.variable_scope('Branch_1'):
|
||||
tower_conv1_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
|
||||
tower_conv1_1 = slim.conv2d(tower_conv1_0, 32, 3, scope='Conv2d_0b_3x3')
|
||||
with tf.variable_scope('Branch_2'):
|
||||
tower_conv2_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
|
||||
tower_conv2_1 = slim.conv2d(tower_conv2_0, 48, 3, scope='Conv2d_0b_3x3')
|
||||
tower_conv2_2 = slim.conv2d(tower_conv2_1, 64, 3, scope='Conv2d_0c_3x3')
|
||||
mixed = tf.concat(axis=3, values=[tower_conv, tower_conv1_1, tower_conv2_2])
|
||||
up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
|
||||
activation_fn=None, scope='Conv2d_1x1')
|
||||
scaled_up = up * scale
|
||||
if activation_fn == tf.nn.relu6:
|
||||
# Use clip_by_value to simulate bandpass activation.
|
||||
scaled_up = tf.clip_by_value(scaled_up, -6.0, 6.0)
|
||||
|
||||
net += scaled_up
|
||||
if activation_fn:
|
||||
net = activation_fn(net)
|
||||
return net
|
||||
|
||||
|
||||
def block17(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
|
||||
"""Builds the 17x17 resnet block."""
|
||||
with tf.variable_scope(scope, 'Block17', [net], reuse=reuse):
|
||||
with tf.variable_scope('Branch_0'):
|
||||
tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1')
|
||||
with tf.variable_scope('Branch_1'):
|
||||
tower_conv1_0 = slim.conv2d(net, 128, 1, scope='Conv2d_0a_1x1')
|
||||
tower_conv1_1 = slim.conv2d(tower_conv1_0, 160, [1, 7],
|
||||
scope='Conv2d_0b_1x7')
|
||||
tower_conv1_2 = slim.conv2d(tower_conv1_1, 192, [7, 1],
|
||||
scope='Conv2d_0c_7x1')
|
||||
mixed = tf.concat(axis=3, values=[tower_conv, tower_conv1_2])
|
||||
up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
|
||||
activation_fn=None, scope='Conv2d_1x1')
|
||||
|
||||
scaled_up = up * scale
|
||||
if activation_fn == tf.nn.relu6:
|
||||
# Use clip_by_value to simulate bandpass activation.
|
||||
scaled_up = tf.clip_by_value(scaled_up, -6.0, 6.0)
|
||||
|
||||
net += scaled_up
|
||||
if activation_fn:
|
||||
net = activation_fn(net)
|
||||
return net
|
||||
|
||||
|
||||
def block8(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
|
||||
"""Builds the 8x8 resnet block."""
|
||||
with tf.variable_scope(scope, 'Block8', [net], reuse=reuse):
|
||||
with tf.variable_scope('Branch_0'):
|
||||
tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1')
|
||||
with tf.variable_scope('Branch_1'):
|
||||
tower_conv1_0 = slim.conv2d(net, 192, 1, scope='Conv2d_0a_1x1')
|
||||
tower_conv1_1 = slim.conv2d(tower_conv1_0, 224, [1, 3],
|
||||
scope='Conv2d_0b_1x3')
|
||||
tower_conv1_2 = slim.conv2d(tower_conv1_1, 256, [3, 1],
|
||||
scope='Conv2d_0c_3x1')
|
||||
mixed = tf.concat(axis=3, values=[tower_conv, tower_conv1_2])
|
||||
up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
|
||||
activation_fn=None, scope='Conv2d_1x1')
|
||||
|
||||
scaled_up = up * scale
|
||||
if activation_fn == tf.nn.relu6:
|
||||
# Use clip_by_value to simulate bandpass activation.
|
||||
scaled_up = tf.clip_by_value(scaled_up, -6.0, 6.0)
|
||||
|
||||
net += scaled_up
|
||||
if activation_fn:
|
||||
net = activation_fn(net)
|
||||
return net
|
||||
|
||||
|
||||
def inception_resnet_v2_base(inputs,
|
||||
final_endpoint='Conv2d_7b_1x1',
|
||||
output_stride=16,
|
||||
align_feature_maps=False,
|
||||
scope=None,
|
||||
activation_fn=tf.nn.relu):
|
||||
"""Inception model from http://arxiv.org/abs/1602.07261.
|
||||
|
||||
Constructs an Inception Resnet v2 network from inputs to the given final
|
||||
endpoint. This method can construct the network up to the final inception
|
||||
block Conv2d_7b_1x1.
|
||||
|
||||
Args:
|
||||
inputs: a tensor of size [batch_size, height, width, channels].
|
||||
final_endpoint: specifies the endpoint to construct the network up to. It
|
||||
can be one of ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
|
||||
'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3',
|
||||
'Mixed_5b', 'Mixed_6a', 'PreAuxLogits', 'Mixed_7a', 'Conv2d_7b_1x1']
|
||||
output_stride: A scalar that specifies the requested ratio of input to
|
||||
output spatial resolution. Only supports 8 and 16.
|
||||
align_feature_maps: When true, changes all the VALID paddings in the network
|
||||
to SAME padding so that the feature maps are aligned.
|
||||
scope: Optional variable_scope.
|
||||
activation_fn: Activation function for block scopes.
|
||||
|
||||
Returns:
|
||||
tensor_out: output tensor corresponding to the final_endpoint.
|
||||
end_points: a set of activations for external use, for example summaries or
|
||||
losses.
|
||||
|
||||
Raises:
|
||||
ValueError: if final_endpoint is not set to one of the predefined values,
|
||||
or if the output_stride is not 8 or 16, or if the output_stride is 8 and
|
||||
we request an end point after 'PreAuxLogits'.
|
||||
"""
|
||||
if output_stride != 8 and output_stride != 16:
|
||||
raise ValueError('output_stride must be 8 or 16.')
|
||||
|
||||
padding = 'SAME' if align_feature_maps else 'VALID'
|
||||
|
||||
end_points = {}
|
||||
|
||||
def add_and_check_final(name, net):
|
||||
end_points[name] = net
|
||||
return name == final_endpoint
|
||||
|
||||
with tf.variable_scope(scope, 'InceptionResnetV2', [inputs]):
|
||||
with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
|
||||
stride=1, padding='SAME'):
|
||||
# 149 x 149 x 32
|
||||
net = slim.conv2d(inputs, 32, 3, stride=2, padding=padding,
|
||||
scope='Conv2d_1a_3x3')
|
||||
if add_and_check_final('Conv2d_1a_3x3', net): return net, end_points
|
||||
|
||||
# 147 x 147 x 32
|
||||
net = slim.conv2d(net, 32, 3, padding=padding,
|
||||
scope='Conv2d_2a_3x3')
|
||||
if add_and_check_final('Conv2d_2a_3x3', net): return net, end_points
|
||||
# 147 x 147 x 64
|
||||
net = slim.conv2d(net, 64, 3, scope='Conv2d_2b_3x3')
|
||||
if add_and_check_final('Conv2d_2b_3x3', net): return net, end_points
|
||||
# 73 x 73 x 64
|
||||
net = slim.max_pool2d(net, 3, stride=2, padding=padding,
|
||||
scope='MaxPool_3a_3x3')
|
||||
if add_and_check_final('MaxPool_3a_3x3', net): return net, end_points
|
||||
# 73 x 73 x 80
|
||||
net = slim.conv2d(net, 80, 1, padding=padding,
|
||||
scope='Conv2d_3b_1x1')
|
||||
if add_and_check_final('Conv2d_3b_1x1', net): return net, end_points
|
||||
# 71 x 71 x 192
|
||||
net = slim.conv2d(net, 192, 3, padding=padding,
|
||||
scope='Conv2d_4a_3x3')
|
||||
if add_and_check_final('Conv2d_4a_3x3', net): return net, end_points
|
||||
# 35 x 35 x 192
|
||||
net = slim.max_pool2d(net, 3, stride=2, padding=padding,
|
||||
scope='MaxPool_5a_3x3')
|
||||
if add_and_check_final('MaxPool_5a_3x3', net): return net, end_points
|
||||
|
||||
# 35 x 35 x 320
|
||||
with tf.variable_scope('Mixed_5b'):
|
||||
with tf.variable_scope('Branch_0'):
|
||||
tower_conv = slim.conv2d(net, 96, 1, scope='Conv2d_1x1')
|
||||
with tf.variable_scope('Branch_1'):
|
||||
tower_conv1_0 = slim.conv2d(net, 48, 1, scope='Conv2d_0a_1x1')
|
||||
tower_conv1_1 = slim.conv2d(tower_conv1_0, 64, 5,
|
||||
scope='Conv2d_0b_5x5')
|
||||
with tf.variable_scope('Branch_2'):
|
||||
tower_conv2_0 = slim.conv2d(net, 64, 1, scope='Conv2d_0a_1x1')
|
||||
tower_conv2_1 = slim.conv2d(tower_conv2_0, 96, 3,
|
||||
scope='Conv2d_0b_3x3')
|
||||
tower_conv2_2 = slim.conv2d(tower_conv2_1, 96, 3,
|
||||
scope='Conv2d_0c_3x3')
|
||||
with tf.variable_scope('Branch_3'):
|
||||
tower_pool = slim.avg_pool2d(net, 3, stride=1, padding='SAME',
|
||||
scope='AvgPool_0a_3x3')
|
||||
tower_pool_1 = slim.conv2d(tower_pool, 64, 1,
|
||||
scope='Conv2d_0b_1x1')
|
||||
net = tf.concat(
|
||||
[tower_conv, tower_conv1_1, tower_conv2_2, tower_pool_1], 3)
|
||||
|
||||
if add_and_check_final('Mixed_5b', net): return net, end_points
|
||||
# TODO(alemi): Register intermediate endpoints
|
||||
net = slim.repeat(net, 10, block35, scale=0.17,
|
||||
activation_fn=activation_fn)
|
||||
|
||||
# 17 x 17 x 1088 if output_stride == 8,
|
||||
# 33 x 33 x 1088 if output_stride == 16
|
||||
use_atrous = output_stride == 8
|
||||
|
||||
with tf.variable_scope('Mixed_6a'):
|
||||
with tf.variable_scope('Branch_0'):
|
||||
tower_conv = slim.conv2d(net, 384, 3, stride=1 if use_atrous else 2,
|
||||
padding=padding,
|
||||
scope='Conv2d_1a_3x3')
|
||||
with tf.variable_scope('Branch_1'):
|
||||
tower_conv1_0 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
|
||||
tower_conv1_1 = slim.conv2d(tower_conv1_0, 256, 3,
|
||||
scope='Conv2d_0b_3x3')
|
||||
tower_conv1_2 = slim.conv2d(tower_conv1_1, 384, 3,
|
||||
stride=1 if use_atrous else 2,
|
||||
padding=padding,
|
||||
scope='Conv2d_1a_3x3')
|
||||
with tf.variable_scope('Branch_2'):
|
||||
tower_pool = slim.max_pool2d(net, 3, stride=1 if use_atrous else 2,
|
||||
padding=padding,
|
||||
scope='MaxPool_1a_3x3')
|
||||
net = tf.concat([tower_conv, tower_conv1_2, tower_pool], 3)
|
||||
|
||||
if add_and_check_final('Mixed_6a', net): return net, end_points
|
||||
|
||||
# TODO(alemi): register intermediate endpoints
|
||||
with slim.arg_scope([slim.conv2d], rate=2 if use_atrous else 1):
|
||||
net = slim.repeat(net, 20, block17, scale=0.10,
|
||||
activation_fn=activation_fn)
|
||||
if add_and_check_final('PreAuxLogits', net): return net, end_points
|
||||
|
||||
if output_stride == 8:
|
||||
# TODO(gpapan): Properly support output_stride for the rest of the net.
|
||||
raise ValueError('output_stride==8 is only supported up to the '
|
||||
'PreAuxlogits end_point for now.')
|
||||
|
||||
# 8 x 8 x 2080
|
||||
with tf.variable_scope('Mixed_7a'):
|
||||
with tf.variable_scope('Branch_0'):
|
||||
tower_conv = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
|
||||
tower_conv_1 = slim.conv2d(tower_conv, 384, 3, stride=2,
|
||||
padding=padding,
|
||||
scope='Conv2d_1a_3x3')
|
||||
with tf.variable_scope('Branch_1'):
|
||||
tower_conv1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
|
||||
tower_conv1_1 = slim.conv2d(tower_conv1, 288, 3, stride=2,
|
||||
padding=padding,
|
||||
scope='Conv2d_1a_3x3')
|
||||
with tf.variable_scope('Branch_2'):
|
||||
tower_conv2 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
|
||||
tower_conv2_1 = slim.conv2d(tower_conv2, 288, 3,
|
||||
scope='Conv2d_0b_3x3')
|
||||
tower_conv2_2 = slim.conv2d(tower_conv2_1, 320, 3, stride=2,
|
||||
padding=padding,
|
||||
scope='Conv2d_1a_3x3')
|
||||
with tf.variable_scope('Branch_3'):
|
||||
tower_pool = slim.max_pool2d(net, 3, stride=2,
|
||||
padding=padding,
|
||||
scope='MaxPool_1a_3x3')
|
||||
net = tf.concat(
|
||||
[tower_conv_1, tower_conv1_1, tower_conv2_2, tower_pool], 3)
|
||||
|
||||
if add_and_check_final('Mixed_7a', net): return net, end_points
|
||||
|
||||
# TODO(alemi): register intermediate endpoints
|
||||
net = slim.repeat(net, 9, block8, scale=0.20, activation_fn=activation_fn)
|
||||
net = block8(net, activation_fn=None)
|
||||
|
||||
# 8 x 8 x 1536
|
||||
net = slim.conv2d(net, 1536, 1, scope='Conv2d_7b_1x1')
|
||||
if add_and_check_final('Conv2d_7b_1x1', net): return net, end_points
|
||||
|
||||
raise ValueError('final_endpoint (%s) not recognized', final_endpoint)
|
||||
|
||||
|
||||
def inception_resnet_v2(inputs, num_classes=1001, is_training=True,
|
||||
dropout_keep_prob=0.8,
|
||||
reuse=None,
|
||||
scope='InceptionResnetV2',
|
||||
create_aux_logits=True,
|
||||
activation_fn=tf.nn.relu):
|
||||
"""Creates the Inception Resnet V2 model.
|
||||
|
||||
Args:
|
||||
inputs: a 4-D tensor of size [batch_size, height, width, 3].
|
||||
Dimension batch_size may be undefined. If create_aux_logits is false,
|
||||
also height and width may be undefined.
|
||||
num_classes: number of predicted classes. If 0 or None, the logits layer
|
||||
is omitted and the input features to the logits layer (before dropout)
|
||||
are returned instead.
|
||||
is_training: whether is training or not.
|
||||
dropout_keep_prob: float, the fraction to keep before final layer.
|
||||
reuse: whether or not the network and its variables should be reused. To be
|
||||
able to reuse 'scope' must be given.
|
||||
scope: Optional variable_scope.
|
||||
create_aux_logits: Whether to include the auxilliary logits.
|
||||
activation_fn: Activation function for conv2d.
|
||||
|
||||
Returns:
|
||||
net: the output of the logits layer (if num_classes is a non-zero integer),
|
||||
or the non-dropped-out input to the logits layer (if num_classes is 0 or
|
||||
None).
|
||||
end_points: the set of end_points from the inception model.
|
||||
"""
|
||||
end_points = {}
|
||||
|
||||
with tf.variable_scope(scope, 'InceptionResnetV2', [inputs],
|
||||
reuse=reuse) as scope:
|
||||
with slim.arg_scope([slim.batch_norm, slim.dropout],
|
||||
is_training=is_training):
|
||||
|
||||
net, end_points = inception_resnet_v2_base(inputs, scope=scope,
|
||||
activation_fn=activation_fn)
|
||||
|
||||
if create_aux_logits and num_classes:
|
||||
with tf.variable_scope('AuxLogits'):
|
||||
aux = end_points['PreAuxLogits']
|
||||
aux = slim.avg_pool2d(aux, 5, stride=3, padding='VALID',
|
||||
scope='Conv2d_1a_3x3')
|
||||
aux = slim.conv2d(aux, 128, 1, scope='Conv2d_1b_1x1')
|
||||
aux = slim.conv2d(aux, 768, aux.get_shape()[1:3],
|
||||
padding='VALID', scope='Conv2d_2a_5x5')
|
||||
aux = slim.flatten(aux)
|
||||
aux = slim.fully_connected(aux, num_classes, activation_fn=None,
|
||||
scope='Logits')
|
||||
end_points['AuxLogits'] = aux
|
||||
|
||||
with tf.variable_scope('Logits'):
|
||||
# TODO(sguada,arnoegw): Consider adding a parameter global_pool which
|
||||
# can be set to False to disable pooling here (as in resnet_*()).
|
||||
kernel_size = net.get_shape()[1:3]
|
||||
if kernel_size.is_fully_defined():
|
||||
net = slim.avg_pool2d(net, kernel_size, padding='VALID',
|
||||
scope='AvgPool_1a_8x8')
|
||||
else:
|
||||
net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool')
|
||||
end_points['global_pool'] = net
|
||||
if not num_classes:
|
||||
return net, end_points
|
||||
net = slim.flatten(net)
|
||||
net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
|
||||
scope='Dropout')
|
||||
end_points['PreLogitsFlatten'] = net
|
||||
logits = slim.fully_connected(net, num_classes, activation_fn=None,
|
||||
scope='Logits')
|
||||
end_points['Logits'] = logits
|
||||
end_points['Predictions'] = tf.nn.softmax(logits, name='Predictions')
|
||||
|
||||
return logits, end_points
|
||||
inception_resnet_v2.default_image_size = 299
|
||||
|
||||
|
||||
def inception_resnet_v2_arg_scope(weight_decay=0.00004,
|
||||
batch_norm_decay=0.9997,
|
||||
batch_norm_epsilon=0.001,
|
||||
activation_fn=tf.nn.relu):
|
||||
"""Returns the scope with the default parameters for inception_resnet_v2.
|
||||
|
||||
Args:
|
||||
weight_decay: the weight decay for weights variables.
|
||||
batch_norm_decay: decay for the moving average of batch_norm momentums.
|
||||
batch_norm_epsilon: small float added to variance to avoid dividing by zero.
|
||||
activation_fn: Activation function for conv2d.
|
||||
|
||||
Returns:
|
||||
a arg_scope with the parameters needed for inception_resnet_v2.
|
||||
"""
|
||||
# Set weight_decay for weights in conv2d and fully_connected layers.
|
||||
with slim.arg_scope([slim.conv2d, slim.fully_connected],
|
||||
weights_regularizer=slim.l2_regularizer(weight_decay),
|
||||
biases_regularizer=slim.l2_regularizer(weight_decay)):
|
||||
|
||||
batch_norm_params = {
|
||||
'decay': batch_norm_decay,
|
||||
'epsilon': batch_norm_epsilon,
|
||||
'fused': None, # Use fused batch norm if possible.
|
||||
}
|
||||
# Set activation_fn and parameters for batch_norm.
|
||||
with slim.arg_scope([slim.conv2d], activation_fn=activation_fn,
|
||||
normalizer_fn=slim.batch_norm,
|
||||
normalizer_params=batch_norm_params) as scope:
|
||||
return scope
|
|
@ -0,0 +1,428 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""MobileNet v1.
|
||||
|
||||
MobileNet is a general architecture and can be used for multiple use cases.
|
||||
Depending on the use case, it can use different input layer size and different
|
||||
head (for example: embeddings, localization and classification).
|
||||
|
||||
As described in https://arxiv.org/abs/1704.04861.
|
||||
|
||||
MobileNets: Efficient Convolutional Neural Networks for
|
||||
Mobile Vision Applications
|
||||
Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang,
|
||||
Tobias Weyand, Marco Andreetto, Hartwig Adam
|
||||
|
||||
100% Mobilenet V1 (base) with input size 224x224:
|
||||
|
||||
See mobilenet_v1()
|
||||
|
||||
Layer params macs
|
||||
--------------------------------------------------------------------------------
|
||||
MobilenetV1/Conv2d_0/Conv2D: 864 10,838,016
|
||||
MobilenetV1/Conv2d_1_depthwise/depthwise: 288 3,612,672
|
||||
MobilenetV1/Conv2d_1_pointwise/Conv2D: 2,048 25,690,112
|
||||
MobilenetV1/Conv2d_2_depthwise/depthwise: 576 1,806,336
|
||||
MobilenetV1/Conv2d_2_pointwise/Conv2D: 8,192 25,690,112
|
||||
MobilenetV1/Conv2d_3_depthwise/depthwise: 1,152 3,612,672
|
||||
MobilenetV1/Conv2d_3_pointwise/Conv2D: 16,384 51,380,224
|
||||
MobilenetV1/Conv2d_4_depthwise/depthwise: 1,152 903,168
|
||||
MobilenetV1/Conv2d_4_pointwise/Conv2D: 32,768 25,690,112
|
||||
MobilenetV1/Conv2d_5_depthwise/depthwise: 2,304 1,806,336
|
||||
MobilenetV1/Conv2d_5_pointwise/Conv2D: 65,536 51,380,224
|
||||
MobilenetV1/Conv2d_6_depthwise/depthwise: 2,304 451,584
|
||||
MobilenetV1/Conv2d_6_pointwise/Conv2D: 131,072 25,690,112
|
||||
MobilenetV1/Conv2d_7_depthwise/depthwise: 4,608 903,168
|
||||
MobilenetV1/Conv2d_7_pointwise/Conv2D: 262,144 51,380,224
|
||||
MobilenetV1/Conv2d_8_depthwise/depthwise: 4,608 903,168
|
||||
MobilenetV1/Conv2d_8_pointwise/Conv2D: 262,144 51,380,224
|
||||
MobilenetV1/Conv2d_9_depthwise/depthwise: 4,608 903,168
|
||||
MobilenetV1/Conv2d_9_pointwise/Conv2D: 262,144 51,380,224
|
||||
MobilenetV1/Conv2d_10_depthwise/depthwise: 4,608 903,168
|
||||
MobilenetV1/Conv2d_10_pointwise/Conv2D: 262,144 51,380,224
|
||||
MobilenetV1/Conv2d_11_depthwise/depthwise: 4,608 903,168
|
||||
MobilenetV1/Conv2d_11_pointwise/Conv2D: 262,144 51,380,224
|
||||
MobilenetV1/Conv2d_12_depthwise/depthwise: 4,608 225,792
|
||||
MobilenetV1/Conv2d_12_pointwise/Conv2D: 524,288 25,690,112
|
||||
MobilenetV1/Conv2d_13_depthwise/depthwise: 9,216 451,584
|
||||
MobilenetV1/Conv2d_13_pointwise/Conv2D: 1,048,576 51,380,224
|
||||
--------------------------------------------------------------------------------
|
||||
Total: 3,185,088 567,716,352
|
||||
|
||||
|
||||
75% Mobilenet V1 (base) with input size 128x128:
|
||||
|
||||
See mobilenet_v1_075()
|
||||
|
||||
Layer params macs
|
||||
--------------------------------------------------------------------------------
|
||||
MobilenetV1/Conv2d_0/Conv2D: 648 2,654,208
|
||||
MobilenetV1/Conv2d_1_depthwise/depthwise: 216 884,736
|
||||
MobilenetV1/Conv2d_1_pointwise/Conv2D: 1,152 4,718,592
|
||||
MobilenetV1/Conv2d_2_depthwise/depthwise: 432 442,368
|
||||
MobilenetV1/Conv2d_2_pointwise/Conv2D: 4,608 4,718,592
|
||||
MobilenetV1/Conv2d_3_depthwise/depthwise: 864 884,736
|
||||
MobilenetV1/Conv2d_3_pointwise/Conv2D: 9,216 9,437,184
|
||||
MobilenetV1/Conv2d_4_depthwise/depthwise: 864 221,184
|
||||
MobilenetV1/Conv2d_4_pointwise/Conv2D: 18,432 4,718,592
|
||||
MobilenetV1/Conv2d_5_depthwise/depthwise: 1,728 442,368
|
||||
MobilenetV1/Conv2d_5_pointwise/Conv2D: 36,864 9,437,184
|
||||
MobilenetV1/Conv2d_6_depthwise/depthwise: 1,728 110,592
|
||||
MobilenetV1/Conv2d_6_pointwise/Conv2D: 73,728 4,718,592
|
||||
MobilenetV1/Conv2d_7_depthwise/depthwise: 3,456 221,184
|
||||
MobilenetV1/Conv2d_7_pointwise/Conv2D: 147,456 9,437,184
|
||||
MobilenetV1/Conv2d_8_depthwise/depthwise: 3,456 221,184
|
||||
MobilenetV1/Conv2d_8_pointwise/Conv2D: 147,456 9,437,184
|
||||
MobilenetV1/Conv2d_9_depthwise/depthwise: 3,456 221,184
|
||||
MobilenetV1/Conv2d_9_pointwise/Conv2D: 147,456 9,437,184
|
||||
MobilenetV1/Conv2d_10_depthwise/depthwise: 3,456 221,184
|
||||
MobilenetV1/Conv2d_10_pointwise/Conv2D: 147,456 9,437,184
|
||||
MobilenetV1/Conv2d_11_depthwise/depthwise: 3,456 221,184
|
||||
MobilenetV1/Conv2d_11_pointwise/Conv2D: 147,456 9,437,184
|
||||
MobilenetV1/Conv2d_12_depthwise/depthwise: 3,456 55,296
|
||||
MobilenetV1/Conv2d_12_pointwise/Conv2D: 294,912 4,718,592
|
||||
MobilenetV1/Conv2d_13_depthwise/depthwise: 6,912 110,592
|
||||
MobilenetV1/Conv2d_13_pointwise/Conv2D: 589,824 9,437,184
|
||||
--------------------------------------------------------------------------------
|
||||
Total: 1,800,144 106,002,432
|
||||
|
||||
"""
|
||||
|
||||
# Tensorflow mandates these.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
import functools
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
slim = tf.contrib.slim
|
||||
|
||||
# Conv and DepthSepConv namedtuple define layers of the MobileNet architecture
|
||||
# Conv defines 3x3 convolution layers
|
||||
# DepthSepConv defines 3x3 depthwise convolution followed by 1x1 convolution.
|
||||
# stride is the stride of the convolution
|
||||
# depth is the number of channels or filters in a layer
|
||||
Conv = namedtuple('Conv', ['kernel', 'stride', 'depth'])
|
||||
DepthSepConv = namedtuple('DepthSepConv', ['kernel', 'stride', 'depth'])
|
||||
|
||||
# _CONV_DEFS specifies the MobileNet body
|
||||
_CONV_DEFS = [
|
||||
Conv(kernel=[3, 3], stride=2, depth=32),
|
||||
DepthSepConv(kernel=[3, 3], stride=1, depth=64),
|
||||
DepthSepConv(kernel=[3, 3], stride=2, depth=128),
|
||||
DepthSepConv(kernel=[3, 3], stride=1, depth=128),
|
||||
DepthSepConv(kernel=[3, 3], stride=2, depth=256),
|
||||
DepthSepConv(kernel=[3, 3], stride=1, depth=256),
|
||||
DepthSepConv(kernel=[3, 3], stride=2, depth=512),
|
||||
DepthSepConv(kernel=[3, 3], stride=1, depth=512),
|
||||
DepthSepConv(kernel=[3, 3], stride=1, depth=512),
|
||||
DepthSepConv(kernel=[3, 3], stride=1, depth=512),
|
||||
DepthSepConv(kernel=[3, 3], stride=1, depth=512),
|
||||
DepthSepConv(kernel=[3, 3], stride=1, depth=512),
|
||||
DepthSepConv(kernel=[3, 3], stride=2, depth=1024),
|
||||
DepthSepConv(kernel=[3, 3], stride=1, depth=1024)
|
||||
]
|
||||
|
||||
|
||||
def mobilenet_v1_base(inputs,
|
||||
final_endpoint='Conv2d_13_pointwise',
|
||||
min_depth=8,
|
||||
depth_multiplier=1.0,
|
||||
conv_defs=None,
|
||||
output_stride=None,
|
||||
scope=None):
|
||||
"""Mobilenet v1.
|
||||
|
||||
Constructs a Mobilenet v1 network from inputs to the given final endpoint.
|
||||
|
||||
Args:
|
||||
inputs: a tensor of shape [batch_size, height, width, channels].
|
||||
final_endpoint: specifies the endpoint to construct the network up to. It
|
||||
can be one of ['Conv2d_0', 'Conv2d_1_pointwise', 'Conv2d_2_pointwise',
|
||||
'Conv2d_3_pointwise', 'Conv2d_4_pointwise', 'Conv2d_5'_pointwise,
|
||||
'Conv2d_6_pointwise', 'Conv2d_7_pointwise', 'Conv2d_8_pointwise',
|
||||
'Conv2d_9_pointwise', 'Conv2d_10_pointwise', 'Conv2d_11_pointwise',
|
||||
'Conv2d_12_pointwise', 'Conv2d_13_pointwise'].
|
||||
min_depth: Minimum depth value (number of channels) for all convolution ops.
|
||||
Enforced when depth_multiplier < 1, and not an active constraint when
|
||||
depth_multiplier >= 1.
|
||||
depth_multiplier: Float multiplier for the depth (number of channels)
|
||||
for all convolution ops. The value must be greater than zero. Typical
|
||||
usage will be to set this value in (0, 1) to reduce the number of
|
||||
parameters or computation cost of the model.
|
||||
conv_defs: A list of ConvDef namedtuples specifying the net architecture.
|
||||
output_stride: An integer that specifies the requested ratio of input to
|
||||
output spatial resolution. If not None, then we invoke atrous convolution
|
||||
if necessary to prevent the network from reducing the spatial resolution
|
||||
of the activation maps. Allowed values are 8 (accurate fully convolutional
|
||||
mode), 16 (fast fully convolutional mode), 32 (classification mode).
|
||||
scope: Optional variable_scope.
|
||||
|
||||
Returns:
|
||||
tensor_out: output tensor corresponding to the final_endpoint.
|
||||
end_points: a set of activations for external use, for example summaries or
|
||||
losses.
|
||||
|
||||
Raises:
|
||||
ValueError: if final_endpoint is not set to one of the predefined values,
|
||||
or depth_multiplier <= 0, or the target output_stride is not
|
||||
allowed.
|
||||
"""
|
||||
depth = lambda d: max(int(d * depth_multiplier), min_depth)
|
||||
end_points = {}
|
||||
|
||||
# Used to find thinned depths for each layer.
|
||||
if depth_multiplier <= 0:
|
||||
raise ValueError('depth_multiplier is not greater than zero.')
|
||||
|
||||
if conv_defs is None:
|
||||
conv_defs = _CONV_DEFS
|
||||
|
||||
if output_stride is not None and output_stride not in [8, 16, 32]:
|
||||
raise ValueError('Only allowed output_stride values are 8, 16, 32.')
|
||||
|
||||
with tf.variable_scope(scope, 'MobilenetV1', [inputs]):
|
||||
with slim.arg_scope([slim.conv2d, slim.separable_conv2d], padding='SAME'):
|
||||
# The current_stride variable keeps track of the output stride of the
|
||||
# activations, i.e., the running product of convolution strides up to the
|
||||
# current network layer. This allows us to invoke atrous convolution
|
||||
# whenever applying the next convolution would result in the activations
|
||||
# having output stride larger than the target output_stride.
|
||||
current_stride = 1
|
||||
|
||||
# The atrous convolution rate parameter.
|
||||
rate = 1
|
||||
|
||||
net = inputs
|
||||
for i, conv_def in enumerate(conv_defs):
|
||||
end_point_base = 'Conv2d_%d' % i
|
||||
|
||||
if output_stride is not None and current_stride == output_stride:
|
||||
# If we have reached the target output_stride, then we need to employ
|
||||
# atrous convolution with stride=1 and multiply the atrous rate by the
|
||||
# current unit's stride for use in subsequent layers.
|
||||
layer_stride = 1
|
||||
layer_rate = rate
|
||||
rate *= conv_def.stride
|
||||
else:
|
||||
layer_stride = conv_def.stride
|
||||
layer_rate = 1
|
||||
current_stride *= conv_def.stride
|
||||
|
||||
if isinstance(conv_def, Conv):
|
||||
end_point = end_point_base
|
||||
net = slim.conv2d(net, depth(conv_def.depth), conv_def.kernel,
|
||||
stride=conv_def.stride,
|
||||
normalizer_fn=slim.batch_norm,
|
||||
scope=end_point)
|
||||
end_points[end_point] = net
|
||||
if end_point == final_endpoint:
|
||||
return net, end_points
|
||||
|
||||
elif isinstance(conv_def, DepthSepConv):
|
||||
end_point = end_point_base + '_depthwise'
|
||||
|
||||
# By passing filters=None
|
||||
# separable_conv2d produces only a depthwise convolution layer
|
||||
net = slim.separable_conv2d(net, None, conv_def.kernel,
|
||||
depth_multiplier=1,
|
||||
stride=layer_stride,
|
||||
rate=layer_rate,
|
||||
normalizer_fn=slim.batch_norm,
|
||||
scope=end_point)
|
||||
|
||||
end_points[end_point] = net
|
||||
if end_point == final_endpoint:
|
||||
return net, end_points
|
||||
|
||||
end_point = end_point_base + '_pointwise'
|
||||
|
||||
net = slim.conv2d(net, depth(conv_def.depth), [1, 1],
|
||||
stride=1,
|
||||
normalizer_fn=slim.batch_norm,
|
||||
scope=end_point)
|
||||
|
||||
end_points[end_point] = net
|
||||
if end_point == final_endpoint:
|
||||
return net, end_points
|
||||
else:
|
||||
raise ValueError('Unknown convolution type %s for layer %d'
|
||||
% (conv_def.ltype, i))
|
||||
raise ValueError('Unknown final endpoint %s' % final_endpoint)
|
||||
|
||||
|
||||
def mobilenet_v1(inputs,
|
||||
num_classes=1000,
|
||||
dropout_keep_prob=0.999,
|
||||
is_training=True,
|
||||
min_depth=8,
|
||||
depth_multiplier=1.0,
|
||||
conv_defs=None,
|
||||
prediction_fn=tf.contrib.layers.softmax,
|
||||
spatial_squeeze=True,
|
||||
reuse=None,
|
||||
scope='MobilenetV1',
|
||||
global_pool=False):
|
||||
"""Mobilenet v1 model for classification.
|
||||
|
||||
Args:
|
||||
inputs: a tensor of shape [batch_size, height, width, channels].
|
||||
num_classes: number of predicted classes. If 0 or None, the logits layer
|
||||
is omitted and the input features to the logits layer (before dropout)
|
||||
are returned instead.
|
||||
dropout_keep_prob: the percentage of activation values that are retained.
|
||||
is_training: whether is training or not.
|
||||
min_depth: Minimum depth value (number of channels) for all convolution ops.
|
||||
Enforced when depth_multiplier < 1, and not an active constraint when
|
||||
depth_multiplier >= 1.
|
||||
depth_multiplier: Float multiplier for the depth (number of channels)
|
||||
for all convolution ops. The value must be greater than zero. Typical
|
||||
usage will be to set this value in (0, 1) to reduce the number of
|
||||
parameters or computation cost of the model.
|
||||
conv_defs: A list of ConvDef namedtuples specifying the net architecture.
|
||||
prediction_fn: a function to get predictions out of logits.
|
||||
spatial_squeeze: if True, logits is of shape is [B, C], if false logits is
|
||||
of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
|
||||
reuse: whether or not the network and its variables should be reused. To be
|
||||
able to reuse 'scope' must be given.
|
||||
scope: Optional variable_scope.
|
||||
global_pool: Optional boolean flag to control the avgpooling before the
|
||||
logits layer. If false or unset, pooling is done with a fixed window
|
||||
that reduces default-sized inputs to 1x1, while larger inputs lead to
|
||||
larger outputs. If true, any input size is pooled down to 1x1.
|
||||
|
||||
Returns:
|
||||
net: a 2D Tensor with the logits (pre-softmax activations) if num_classes
|
||||
is a non-zero integer, or the non-dropped-out input to the logits layer
|
||||
if num_classes is 0 or None.
|
||||
end_points: a dictionary from components of the network to the corresponding
|
||||
activation.
|
||||
|
||||
Raises:
|
||||
ValueError: Input rank is invalid.
|
||||
"""
|
||||
input_shape = inputs.get_shape().as_list()
|
||||
if len(input_shape) != 4:
|
||||
raise ValueError('Invalid input tensor rank, expected 4, was: %d' %
|
||||
len(input_shape))
|
||||
|
||||
with tf.variable_scope(scope, 'MobilenetV1', [inputs], reuse=reuse) as scope:
|
||||
with slim.arg_scope([slim.batch_norm, slim.dropout],
|
||||
is_training=is_training):
|
||||
net, end_points = mobilenet_v1_base(inputs, scope=scope,
|
||||
min_depth=min_depth,
|
||||
depth_multiplier=depth_multiplier,
|
||||
conv_defs=conv_defs)
|
||||
with tf.variable_scope('Logits'):
|
||||
if global_pool:
|
||||
# Global average pooling.
|
||||
net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool')
|
||||
end_points['global_pool'] = net
|
||||
else:
|
||||
# Pooling with a fixed kernel size.
|
||||
kernel_size = _reduced_kernel_size_for_small_input(net, [7, 7])
|
||||
net = slim.avg_pool2d(net, kernel_size, padding='VALID',
|
||||
scope='AvgPool_1a')
|
||||
end_points['AvgPool_1a'] = net
|
||||
if not num_classes:
|
||||
return net, end_points
|
||||
# 1 x 1 x 1024
|
||||
net = slim.dropout(net, keep_prob=dropout_keep_prob, scope='Dropout_1b')
|
||||
logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
|
||||
normalizer_fn=None, scope='Conv2d_1c_1x1')
|
||||
if spatial_squeeze:
|
||||
logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')
|
||||
end_points['Logits'] = logits
|
||||
if prediction_fn:
|
||||
end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
|
||||
return logits, end_points
|
||||
|
||||
mobilenet_v1.default_image_size = 224
|
||||
|
||||
|
||||
def wrapped_partial(func, *args, **kwargs):
|
||||
partial_func = functools.partial(func, *args, **kwargs)
|
||||
functools.update_wrapper(partial_func, func)
|
||||
return partial_func
|
||||
|
||||
|
||||
mobilenet_v1_075 = wrapped_partial(mobilenet_v1, depth_multiplier=0.75)
|
||||
mobilenet_v1_050 = wrapped_partial(mobilenet_v1, depth_multiplier=0.50)
|
||||
mobilenet_v1_025 = wrapped_partial(mobilenet_v1, depth_multiplier=0.25)
|
||||
|
||||
|
||||
def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
|
||||
"""Define kernel size which is automatically reduced for small input.
|
||||
|
||||
If the shape of the input images is unknown at graph construction time this
|
||||
function assumes that the input images are large enough.
|
||||
|
||||
Args:
|
||||
input_tensor: input tensor of size [batch_size, height, width, channels].
|
||||
kernel_size: desired kernel size of length 2: [kernel_height, kernel_width]
|
||||
|
||||
Returns:
|
||||
a tensor with the kernel size.
|
||||
"""
|
||||
shape = input_tensor.get_shape().as_list()
|
||||
if shape[1] is None or shape[2] is None:
|
||||
kernel_size_out = kernel_size
|
||||
else:
|
||||
kernel_size_out = [min(shape[1], kernel_size[0]),
|
||||
min(shape[2], kernel_size[1])]
|
||||
return kernel_size_out
|
||||
|
||||
|
||||
def mobilenet_v1_arg_scope(is_training=True,
|
||||
weight_decay=0.00004,
|
||||
stddev=0.09,
|
||||
regularize_depthwise=False):
|
||||
"""Defines the default MobilenetV1 arg scope.
|
||||
|
||||
Args:
|
||||
is_training: Whether or not we're training the model.
|
||||
weight_decay: The weight decay to use for regularizing the model.
|
||||
stddev: The standard deviation of the trunctated normal weight initializer.
|
||||
regularize_depthwise: Whether or not apply regularization on depthwise.
|
||||
|
||||
Returns:
|
||||
An `arg_scope` to use for the mobilenet v1 model.
|
||||
"""
|
||||
batch_norm_params = {
|
||||
'is_training': is_training,
|
||||
'center': True,
|
||||
'scale': True,
|
||||
'decay': 0.9997,
|
||||
'epsilon': 0.001,
|
||||
}
|
||||
|
||||
# Set weight_decay for weights in Conv and DepthSepConv layers.
|
||||
weights_init = tf.truncated_normal_initializer(stddev=stddev)
|
||||
regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
|
||||
if regularize_depthwise:
|
||||
depthwise_regularizer = regularizer
|
||||
else:
|
||||
depthwise_regularizer = None
|
||||
with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
|
||||
weights_initializer=weights_init,
|
||||
activation_fn=tf.nn.relu6, normalizer_fn=slim.batch_norm):
|
||||
with slim.arg_scope([slim.batch_norm], **batch_norm_params):
|
||||
with slim.arg_scope([slim.conv2d], weights_regularizer=regularizer):
|
||||
with slim.arg_scope([slim.separable_conv2d],
|
||||
weights_regularizer=depthwise_regularizer) as sc:
|
||||
return sc
|
|
@ -0,0 +1,513 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Contains the definition for the NASNet classification networks.
|
||||
|
||||
Paper: https://arxiv.org/abs/1707.07012
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from . import nasnet_utils
|
||||
|
||||
arg_scope = tf.contrib.framework.arg_scope
|
||||
slim = tf.contrib.slim
|
||||
|
||||
|
||||
# Notes for training NASNet Cifar Model
|
||||
# -------------------------------------
|
||||
# batch_size: 32
|
||||
# learning rate: 0.025
|
||||
# cosine (single period) learning rate decay
|
||||
# auxiliary head loss weighting: 0.4
|
||||
# clip global norm of all gradients by 5
|
||||
def _cifar_config(is_training=True):
|
||||
drop_path_keep_prob = 1.0 if not is_training else 0.6
|
||||
return tf.contrib.training.HParams(
|
||||
stem_multiplier=3.0,
|
||||
drop_path_keep_prob=drop_path_keep_prob,
|
||||
num_cells=18,
|
||||
use_aux_head=1,
|
||||
num_conv_filters=32,
|
||||
dense_dropout_keep_prob=1.0,
|
||||
filter_scaling_rate=2.0,
|
||||
num_reduction_layers=2,
|
||||
data_format='NHWC',
|
||||
skip_reduction_layer_input=0,
|
||||
# 600 epochs with a batch size of 32
|
||||
# This is used for the drop path probabilities since it needs to increase
|
||||
# the drop out probability over the course of training.
|
||||
total_training_steps=937500,
|
||||
)
|
||||
|
||||
|
||||
# Notes for training large NASNet model on ImageNet
|
||||
# -------------------------------------
|
||||
# batch size (per replica): 16
|
||||
# learning rate: 0.015 * 100
|
||||
# learning rate decay factor: 0.97
|
||||
# num epochs per decay: 2.4
|
||||
# sync sgd with 100 replicas
|
||||
# auxiliary head loss weighting: 0.4
|
||||
# label smoothing: 0.1
|
||||
# clip global norm of all gradients by 10
|
||||
def _large_imagenet_config(is_training=True):
|
||||
drop_path_keep_prob = 1.0 if not is_training else 0.7
|
||||
return tf.contrib.training.HParams(
|
||||
stem_multiplier=3.0,
|
||||
dense_dropout_keep_prob=0.5,
|
||||
num_cells=18,
|
||||
filter_scaling_rate=2.0,
|
||||
num_conv_filters=168,
|
||||
drop_path_keep_prob=drop_path_keep_prob,
|
||||
use_aux_head=1,
|
||||
num_reduction_layers=2,
|
||||
data_format='NHWC',
|
||||
skip_reduction_layer_input=1,
|
||||
total_training_steps=250000,
|
||||
)
|
||||
|
||||
|
||||
# Notes for training the mobile NASNet ImageNet model
|
||||
# -------------------------------------
|
||||
# batch size (per replica): 32
|
||||
# learning rate: 0.04 * 50
|
||||
# learning rate scaling factor: 0.97
|
||||
# num epochs per decay: 2.4
|
||||
# sync sgd with 50 replicas
|
||||
# auxiliary head weighting: 0.4
|
||||
# label smoothing: 0.1
|
||||
# clip global norm of all gradients by 10
|
||||
def _mobile_imagenet_config():
|
||||
return tf.contrib.training.HParams(
|
||||
stem_multiplier=1.0,
|
||||
dense_dropout_keep_prob=0.5,
|
||||
num_cells=12,
|
||||
filter_scaling_rate=2.0,
|
||||
drop_path_keep_prob=1.0,
|
||||
num_conv_filters=44,
|
||||
use_aux_head=1,
|
||||
num_reduction_layers=2,
|
||||
data_format='NHWC',
|
||||
skip_reduction_layer_input=0,
|
||||
total_training_steps=250000,
|
||||
)
|
||||
|
||||
|
||||
def nasnet_cifar_arg_scope(weight_decay=5e-4,
|
||||
batch_norm_decay=0.9,
|
||||
batch_norm_epsilon=1e-5):
|
||||
"""Defines the default arg scope for the NASNet-A Cifar model.
|
||||
|
||||
Args:
|
||||
weight_decay: The weight decay to use for regularizing the model.
|
||||
batch_norm_decay: Decay for batch norm moving average.
|
||||
batch_norm_epsilon: Small float added to variance to avoid dividing by zero
|
||||
in batch norm.
|
||||
|
||||
Returns:
|
||||
An `arg_scope` to use for the NASNet Cifar Model.
|
||||
"""
|
||||
batch_norm_params = {
|
||||
# Decay for the moving averages.
|
||||
'decay': batch_norm_decay,
|
||||
# epsilon to prevent 0s in variance.
|
||||
'epsilon': batch_norm_epsilon,
|
||||
'scale': True,
|
||||
'fused': True,
|
||||
}
|
||||
weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
|
||||
weights_initializer = tf.contrib.layers.variance_scaling_initializer(
|
||||
mode='FAN_OUT')
|
||||
with arg_scope([slim.fully_connected, slim.conv2d, slim.separable_conv2d],
|
||||
weights_regularizer=weights_regularizer,
|
||||
weights_initializer=weights_initializer):
|
||||
with arg_scope([slim.fully_connected],
|
||||
activation_fn=None, scope='FC'):
|
||||
with arg_scope([slim.conv2d, slim.separable_conv2d],
|
||||
activation_fn=None, biases_initializer=None):
|
||||
with arg_scope([slim.batch_norm], **batch_norm_params) as sc:
|
||||
return sc
|
||||
|
||||
|
||||
def nasnet_mobile_arg_scope(weight_decay=4e-5,
|
||||
batch_norm_decay=0.9997,
|
||||
batch_norm_epsilon=1e-3):
|
||||
"""Defines the default arg scope for the NASNet-A Mobile ImageNet model.
|
||||
|
||||
Args:
|
||||
weight_decay: The weight decay to use for regularizing the model.
|
||||
batch_norm_decay: Decay for batch norm moving average.
|
||||
batch_norm_epsilon: Small float added to variance to avoid dividing by zero
|
||||
in batch norm.
|
||||
|
||||
Returns:
|
||||
An `arg_scope` to use for the NASNet Mobile Model.
|
||||
"""
|
||||
batch_norm_params = {
|
||||
# Decay for the moving averages.
|
||||
'decay': batch_norm_decay,
|
||||
# epsilon to prevent 0s in variance.
|
||||
'epsilon': batch_norm_epsilon,
|
||||
'scale': True,
|
||||
'fused': True,
|
||||
}
|
||||
weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
|
||||
weights_initializer = tf.contrib.layers.variance_scaling_initializer(
|
||||
mode='FAN_OUT')
|
||||
with arg_scope([slim.fully_connected, slim.conv2d, slim.separable_conv2d],
|
||||
weights_regularizer=weights_regularizer,
|
||||
weights_initializer=weights_initializer):
|
||||
with arg_scope([slim.fully_connected],
|
||||
activation_fn=None, scope='FC'):
|
||||
with arg_scope([slim.conv2d, slim.separable_conv2d],
|
||||
activation_fn=None, biases_initializer=None):
|
||||
with arg_scope([slim.batch_norm], **batch_norm_params) as sc:
|
||||
return sc
|
||||
|
||||
|
||||
def nasnet_large_arg_scope(weight_decay=5e-5,
|
||||
batch_norm_decay=0.9997,
|
||||
batch_norm_epsilon=1e-3):
|
||||
"""Defines the default arg scope for the NASNet-A Large ImageNet model.
|
||||
|
||||
Args:
|
||||
weight_decay: The weight decay to use for regularizing the model.
|
||||
batch_norm_decay: Decay for batch norm moving average.
|
||||
batch_norm_epsilon: Small float added to variance to avoid dividing by zero
|
||||
in batch norm.
|
||||
|
||||
Returns:
|
||||
An `arg_scope` to use for the NASNet Large Model.
|
||||
"""
|
||||
batch_norm_params = {
|
||||
# Decay for the moving averages.
|
||||
'decay': batch_norm_decay,
|
||||
# epsilon to prevent 0s in variance.
|
||||
'epsilon': batch_norm_epsilon,
|
||||
'scale': True,
|
||||
'fused': True,
|
||||
}
|
||||
weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
|
||||
weights_initializer = tf.contrib.layers.variance_scaling_initializer(
|
||||
mode='FAN_OUT')
|
||||
with arg_scope([slim.fully_connected, slim.conv2d, slim.separable_conv2d],
|
||||
weights_regularizer=weights_regularizer,
|
||||
weights_initializer=weights_initializer):
|
||||
with arg_scope([slim.fully_connected],
|
||||
activation_fn=None, scope='FC'):
|
||||
with arg_scope([slim.conv2d, slim.separable_conv2d],
|
||||
activation_fn=None, biases_initializer=None):
|
||||
with arg_scope([slim.batch_norm], **batch_norm_params) as sc:
|
||||
return sc
|
||||
|
||||
|
||||
def _build_aux_head(net, end_points, num_classes, hparams, scope):
|
||||
"""Auxiliary head used for all models across all datasets."""
|
||||
with tf.variable_scope(scope):
|
||||
aux_logits = tf.identity(net)
|
||||
with tf.variable_scope('aux_logits'):
|
||||
aux_logits = slim.avg_pool2d(
|
||||
aux_logits, [5, 5], stride=3, padding='VALID')
|
||||
aux_logits = slim.conv2d(aux_logits, 128, [1, 1], scope='proj')
|
||||
aux_logits = slim.batch_norm(aux_logits, scope='aux_bn0')
|
||||
aux_logits = tf.nn.relu(aux_logits)
|
||||
# Shape of feature map before the final layer.
|
||||
shape = aux_logits.shape
|
||||
if hparams.data_format == 'NHWC':
|
||||
shape = shape[1:3]
|
||||
else:
|
||||
shape = shape[2:4]
|
||||
aux_logits = slim.conv2d(aux_logits, 768, shape, padding='VALID')
|
||||
aux_logits = slim.batch_norm(aux_logits, scope='aux_bn1')
|
||||
aux_logits = tf.nn.relu(aux_logits)
|
||||
aux_logits = tf.contrib.layers.flatten(aux_logits)
|
||||
aux_logits = slim.fully_connected(aux_logits, num_classes)
|
||||
end_points['AuxLogits'] = aux_logits
|
||||
|
||||
|
||||
def _imagenet_stem(inputs, hparams, stem_cell):
|
||||
"""Stem used for models trained on ImageNet."""
|
||||
num_stem_cells = 2
|
||||
|
||||
# 149 x 149 x 32
|
||||
num_stem_filters = int(32 * hparams.stem_multiplier)
|
||||
net = slim.conv2d(
|
||||
inputs, num_stem_filters, [3, 3], stride=2, scope='conv0',
|
||||
padding='VALID')
|
||||
net = slim.batch_norm(net, scope='conv0_bn')
|
||||
|
||||
# Run the reduction cells
|
||||
cell_outputs = [None, net]
|
||||
filter_scaling = 1.0 / (hparams.filter_scaling_rate**num_stem_cells)
|
||||
for cell_num in range(num_stem_cells):
|
||||
net = stem_cell(
|
||||
net,
|
||||
scope='cell_stem_{}'.format(cell_num),
|
||||
filter_scaling=filter_scaling,
|
||||
stride=2,
|
||||
prev_layer=cell_outputs[-2],
|
||||
cell_num=cell_num)
|
||||
cell_outputs.append(net)
|
||||
filter_scaling *= hparams.filter_scaling_rate
|
||||
return net, cell_outputs
|
||||
|
||||
|
||||
def _cifar_stem(inputs, hparams):
|
||||
"""Stem used for models trained on Cifar."""
|
||||
num_stem_filters = int(hparams.num_conv_filters * hparams.stem_multiplier)
|
||||
net = slim.conv2d(
|
||||
inputs,
|
||||
num_stem_filters,
|
||||
3,
|
||||
scope='l1_stem_3x3')
|
||||
net = slim.batch_norm(net, scope='l1_stem_bn')
|
||||
return net, [None, net]
|
||||
|
||||
|
||||
def build_nasnet_cifar(
|
||||
images, num_classes, is_training=True):
|
||||
"""Build NASNet model for the Cifar Dataset."""
|
||||
hparams = _cifar_config(is_training=is_training)
|
||||
|
||||
if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
|
||||
tf.logging.info('A GPU is available on the machine, consider using NCHW '
|
||||
'data format for increased speed on GPU.')
|
||||
|
||||
if hparams.data_format == 'NCHW':
|
||||
images = tf.transpose(images, [0, 3, 1, 2])
|
||||
|
||||
# Calculate the total number of cells in the network
|
||||
# Add 2 for the reduction cells
|
||||
total_num_cells = hparams.num_cells + 2
|
||||
|
||||
normal_cell = nasnet_utils.NasNetANormalCell(
|
||||
hparams.num_conv_filters, hparams.drop_path_keep_prob,
|
||||
total_num_cells, hparams.total_training_steps)
|
||||
reduction_cell = nasnet_utils.NasNetAReductionCell(
|
||||
hparams.num_conv_filters, hparams.drop_path_keep_prob,
|
||||
total_num_cells, hparams.total_training_steps)
|
||||
with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
|
||||
is_training=is_training):
|
||||
with arg_scope([slim.avg_pool2d,
|
||||
slim.max_pool2d,
|
||||
slim.conv2d,
|
||||
slim.batch_norm,
|
||||
slim.separable_conv2d,
|
||||
nasnet_utils.factorized_reduction,
|
||||
nasnet_utils.global_avg_pool,
|
||||
nasnet_utils.get_channel_index,
|
||||
nasnet_utils.get_channel_dim],
|
||||
data_format=hparams.data_format):
|
||||
return _build_nasnet_base(images,
|
||||
normal_cell=normal_cell,
|
||||
reduction_cell=reduction_cell,
|
||||
num_classes=num_classes,
|
||||
hparams=hparams,
|
||||
is_training=is_training,
|
||||
stem_type='cifar')
|
||||
build_nasnet_cifar.default_image_size = 32
|
||||
|
||||
|
||||
def build_nasnet_mobile(images, num_classes,
|
||||
is_training=True,
|
||||
final_endpoint=None):
|
||||
"""Build NASNet Mobile model for the ImageNet Dataset."""
|
||||
hparams = _mobile_imagenet_config()
|
||||
|
||||
if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
|
||||
tf.logging.info('A GPU is available on the machine, consider using NCHW '
|
||||
'data format for increased speed on GPU.')
|
||||
|
||||
if hparams.data_format == 'NCHW':
|
||||
images = tf.transpose(images, [0, 3, 1, 2])
|
||||
|
||||
# Calculate the total number of cells in the network
|
||||
# Add 2 for the reduction cells
|
||||
total_num_cells = hparams.num_cells + 2
|
||||
# If ImageNet, then add an additional two for the stem cells
|
||||
total_num_cells += 2
|
||||
|
||||
normal_cell = nasnet_utils.NasNetANormalCell(
|
||||
hparams.num_conv_filters, hparams.drop_path_keep_prob,
|
||||
total_num_cells, hparams.total_training_steps)
|
||||
reduction_cell = nasnet_utils.NasNetAReductionCell(
|
||||
hparams.num_conv_filters, hparams.drop_path_keep_prob,
|
||||
total_num_cells, hparams.total_training_steps)
|
||||
with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
|
||||
is_training=is_training):
|
||||
with arg_scope([slim.avg_pool2d,
|
||||
slim.max_pool2d,
|
||||
slim.conv2d,
|
||||
slim.batch_norm,
|
||||
slim.separable_conv2d,
|
||||
nasnet_utils.factorized_reduction,
|
||||
nasnet_utils.global_avg_pool,
|
||||
nasnet_utils.get_channel_index,
|
||||
nasnet_utils.get_channel_dim],
|
||||
data_format=hparams.data_format):
|
||||
return _build_nasnet_base(images,
|
||||
normal_cell=normal_cell,
|
||||
reduction_cell=reduction_cell,
|
||||
num_classes=num_classes,
|
||||
hparams=hparams,
|
||||
is_training=is_training,
|
||||
stem_type='imagenet',
|
||||
final_endpoint=final_endpoint)
|
||||
build_nasnet_mobile.default_image_size = 224
|
||||
|
||||
|
||||
def build_nasnet_large(images, num_classes,
|
||||
is_training=True,
|
||||
final_endpoint=None):
|
||||
"""Build NASNet Large model for the ImageNet Dataset."""
|
||||
hparams = _large_imagenet_config(is_training=is_training)
|
||||
|
||||
if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
|
||||
tf.logging.info('A GPU is available on the machine, consider using NCHW '
|
||||
'data format for increased speed on GPU.')
|
||||
|
||||
if hparams.data_format == 'NCHW':
|
||||
images = tf.transpose(images, [0, 3, 1, 2])
|
||||
|
||||
# Calculate the total number of cells in the network
|
||||
# Add 2 for the reduction cells
|
||||
total_num_cells = hparams.num_cells + 2
|
||||
# If ImageNet, then add an additional two for the stem cells
|
||||
total_num_cells += 2
|
||||
|
||||
normal_cell = nasnet_utils.NasNetANormalCell(
|
||||
hparams.num_conv_filters, hparams.drop_path_keep_prob,
|
||||
total_num_cells, hparams.total_training_steps)
|
||||
reduction_cell = nasnet_utils.NasNetAReductionCell(
|
||||
hparams.num_conv_filters, hparams.drop_path_keep_prob,
|
||||
total_num_cells, hparams.total_training_steps)
|
||||
with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
|
||||
is_training=is_training):
|
||||
with arg_scope([slim.avg_pool2d,
|
||||
slim.max_pool2d,
|
||||
slim.conv2d,
|
||||
slim.batch_norm,
|
||||
slim.separable_conv2d,
|
||||
nasnet_utils.factorized_reduction,
|
||||
nasnet_utils.global_avg_pool,
|
||||
nasnet_utils.get_channel_index,
|
||||
nasnet_utils.get_channel_dim],
|
||||
data_format=hparams.data_format):
|
||||
return _build_nasnet_base(images,
|
||||
normal_cell=normal_cell,
|
||||
reduction_cell=reduction_cell,
|
||||
num_classes=num_classes,
|
||||
hparams=hparams,
|
||||
is_training=is_training,
|
||||
stem_type='imagenet',
|
||||
final_endpoint=final_endpoint)
|
||||
build_nasnet_large.default_image_size = 331
|
||||
|
||||
|
||||
def _build_nasnet_base(images,
|
||||
normal_cell,
|
||||
reduction_cell,
|
||||
num_classes,
|
||||
hparams,
|
||||
is_training,
|
||||
stem_type,
|
||||
final_endpoint=None):
|
||||
"""Constructs a NASNet image model."""
|
||||
|
||||
end_points = {}
|
||||
def add_and_check_endpoint(endpoint_name, net):
|
||||
end_points[endpoint_name] = net
|
||||
return final_endpoint and (endpoint_name == final_endpoint)
|
||||
|
||||
# Find where to place the reduction cells or stride normal cells
|
||||
reduction_indices = nasnet_utils.calc_reduction_layers(
|
||||
hparams.num_cells, hparams.num_reduction_layers)
|
||||
stem_cell = reduction_cell
|
||||
|
||||
if stem_type == 'imagenet':
|
||||
stem = lambda: _imagenet_stem(images, hparams, stem_cell)
|
||||
elif stem_type == 'cifar':
|
||||
stem = lambda: _cifar_stem(images, hparams)
|
||||
else:
|
||||
raise ValueError('Unknown stem_type: ', stem_type)
|
||||
net, cell_outputs = stem()
|
||||
if add_and_check_endpoint('Stem', net): return net, end_points
|
||||
|
||||
# Setup for building in the auxiliary head.
|
||||
aux_head_cell_idxes = []
|
||||
if len(reduction_indices) >= 2:
|
||||
aux_head_cell_idxes.append(reduction_indices[1] - 1)
|
||||
|
||||
# Run the cells
|
||||
filter_scaling = 1.0
|
||||
# true_cell_num accounts for the stem cells
|
||||
true_cell_num = 2 if stem_type == 'imagenet' else 0
|
||||
for cell_num in range(hparams.num_cells):
|
||||
stride = 1
|
||||
if hparams.skip_reduction_layer_input:
|
||||
prev_layer = cell_outputs[-2]
|
||||
if cell_num in reduction_indices:
|
||||
filter_scaling *= hparams.filter_scaling_rate
|
||||
net = reduction_cell(
|
||||
net,
|
||||
scope='reduction_cell_{}'.format(reduction_indices.index(cell_num)),
|
||||
filter_scaling=filter_scaling,
|
||||
stride=2,
|
||||
prev_layer=cell_outputs[-2],
|
||||
cell_num=true_cell_num)
|
||||
if add_and_check_endpoint(
|
||||
'Reduction_Cell_{}'.format(reduction_indices.index(cell_num)), net):
|
||||
return net, end_points
|
||||
true_cell_num += 1
|
||||
cell_outputs.append(net)
|
||||
if not hparams.skip_reduction_layer_input:
|
||||
prev_layer = cell_outputs[-2]
|
||||
net = normal_cell(
|
||||
net,
|
||||
scope='cell_{}'.format(cell_num),
|
||||
filter_scaling=filter_scaling,
|
||||
stride=stride,
|
||||
prev_layer=prev_layer,
|
||||
cell_num=true_cell_num)
|
||||
|
||||
if add_and_check_endpoint('Cell_{}'.format(cell_num), net):
|
||||
return net, end_points
|
||||
true_cell_num += 1
|
||||
if (hparams.use_aux_head and cell_num in aux_head_cell_idxes and
|
||||
num_classes and is_training):
|
||||
aux_net = tf.nn.relu(net)
|
||||
_build_aux_head(aux_net, end_points, num_classes, hparams,
|
||||
scope='aux_{}'.format(cell_num))
|
||||
cell_outputs.append(net)
|
||||
|
||||
# Final softmax layer
|
||||
with tf.variable_scope('final_layer'):
|
||||
net = tf.nn.relu(net)
|
||||
net = nasnet_utils.global_avg_pool(net)
|
||||
if add_and_check_endpoint('global_pool', net) or num_classes is None:
|
||||
return net, end_points
|
||||
net = slim.dropout(net, hparams.dense_dropout_keep_prob, scope='dropout')
|
||||
logits = slim.fully_connected(net, num_classes)
|
||||
|
||||
if add_and_check_endpoint('Logits', logits):
|
||||
return net, end_points
|
||||
|
||||
predictions = tf.nn.softmax(logits, name='predictions')
|
||||
if add_and_check_endpoint('Predictions', predictions):
|
||||
return net, end_points
|
||||
return logits, end_points
|
|
@ -0,0 +1,477 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A custom module for some common operations used by NASNet.
|
||||
|
||||
Functions exposed in this file:
|
||||
- calc_reduction_layers
|
||||
- get_channel_index
|
||||
- get_channel_dim
|
||||
- global_avg_pool
|
||||
- factorized_reduction
|
||||
- drop_path
|
||||
|
||||
Classes exposed in this file:
|
||||
- NasNetABaseCell
|
||||
- NasNetANormalCell
|
||||
- NasNetAReductionCell
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
arg_scope = tf.contrib.framework.arg_scope
|
||||
slim = tf.contrib.slim
|
||||
|
||||
DATA_FORMAT_NCHW = 'NCHW'
|
||||
DATA_FORMAT_NHWC = 'NHWC'
|
||||
INVALID = 'null'
|
||||
|
||||
|
||||
def calc_reduction_layers(num_cells, num_reduction_layers):
|
||||
"""Figure out what layers should have reductions."""
|
||||
reduction_layers = []
|
||||
for pool_num in range(1, num_reduction_layers + 1):
|
||||
layer_num = (float(pool_num) / (num_reduction_layers + 1)) * num_cells
|
||||
layer_num = int(layer_num)
|
||||
reduction_layers.append(layer_num)
|
||||
return reduction_layers
|
||||
|
||||
|
||||
@tf.contrib.framework.add_arg_scope
|
||||
def get_channel_index(data_format=INVALID):
|
||||
assert data_format != INVALID
|
||||
axis = 3 if data_format == 'NHWC' else 1
|
||||
return axis
|
||||
|
||||
|
||||
@tf.contrib.framework.add_arg_scope
|
||||
def get_channel_dim(shape, data_format=INVALID):
|
||||
assert data_format != INVALID
|
||||
assert len(shape) == 4
|
||||
if data_format == 'NHWC':
|
||||
return int(shape[3])
|
||||
elif data_format == 'NCHW':
|
||||
return int(shape[1])
|
||||
else:
|
||||
raise ValueError('Not a valid data_format', data_format)
|
||||
|
||||
|
||||
@tf.contrib.framework.add_arg_scope
|
||||
def global_avg_pool(x, data_format=INVALID):
|
||||
"""Average pool away the height and width spatial dimensions of x."""
|
||||
assert data_format != INVALID
|
||||
assert data_format in ['NHWC', 'NCHW']
|
||||
assert x.shape.ndims == 4
|
||||
if data_format == 'NHWC':
|
||||
return tf.reduce_mean(x, [1, 2])
|
||||
else:
|
||||
return tf.reduce_mean(x, [2, 3])
|
||||
|
||||
|
||||
@tf.contrib.framework.add_arg_scope
|
||||
def factorized_reduction(net, output_filters, stride, data_format=INVALID):
|
||||
"""Reduces the shape of net without information loss due to striding."""
|
||||
assert output_filters % 2 == 0, (
|
||||
'Need even number of filters when using this factorized reduction.')
|
||||
assert data_format != INVALID
|
||||
if stride == 1:
|
||||
net = slim.conv2d(net, output_filters, 1, scope='path_conv')
|
||||
net = slim.batch_norm(net, scope='path_bn')
|
||||
return net
|
||||
if data_format == 'NHWC':
|
||||
stride_spec = [1, stride, stride, 1]
|
||||
else:
|
||||
stride_spec = [1, 1, stride, stride]
|
||||
|
||||
# Skip path 1
|
||||
path1 = tf.nn.avg_pool(
|
||||
net, [1, 1, 1, 1], stride_spec, 'VALID', data_format=data_format)
|
||||
path1 = slim.conv2d(path1, int(output_filters / 2), 1, scope='path1_conv')
|
||||
|
||||
# Skip path 2
|
||||
# First pad with 0's on the right and bottom, then shift the filter to
|
||||
# include those 0's that were added.
|
||||
if data_format == 'NHWC':
|
||||
pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
|
||||
path2 = tf.pad(net, pad_arr)[:, 1:, 1:, :]
|
||||
concat_axis = 3
|
||||
else:
|
||||
pad_arr = [[0, 0], [0, 0], [0, 1], [0, 1]]
|
||||
path2 = tf.pad(net, pad_arr)[:, :, 1:, 1:]
|
||||
concat_axis = 1
|
||||
|
||||
path2 = tf.nn.avg_pool(
|
||||
path2, [1, 1, 1, 1], stride_spec, 'VALID', data_format=data_format)
|
||||
path2 = slim.conv2d(path2, int(output_filters / 2), 1, scope='path2_conv')
|
||||
|
||||
# Concat and apply BN
|
||||
final_path = tf.concat(values=[path1, path2], axis=concat_axis)
|
||||
final_path = slim.batch_norm(final_path, scope='final_path_bn')
|
||||
return final_path
|
||||
|
||||
|
||||
@tf.contrib.framework.add_arg_scope
|
||||
def drop_path(net, keep_prob, is_training=True):
|
||||
"""Drops out a whole example hiddenstate with the specified probability."""
|
||||
if is_training:
|
||||
batch_size = tf.shape(net)[0]
|
||||
noise_shape = [batch_size, 1, 1, 1]
|
||||
random_tensor = keep_prob
|
||||
random_tensor += tf.random_uniform(noise_shape, dtype=tf.float32)
|
||||
binary_tensor = tf.floor(random_tensor)
|
||||
net = tf.div(net, keep_prob) * binary_tensor
|
||||
return net
|
||||
|
||||
|
||||
def _operation_to_filter_shape(operation):
|
||||
splitted_operation = operation.split('x')
|
||||
filter_shape = int(splitted_operation[0][-1])
|
||||
assert filter_shape == int(
|
||||
splitted_operation[1][0]), 'Rectangular filters not supported.'
|
||||
return filter_shape
|
||||
|
||||
|
||||
def _operation_to_num_layers(operation):
|
||||
splitted_operation = operation.split('_')
|
||||
if 'x' in splitted_operation[-1]:
|
||||
return 1
|
||||
return int(splitted_operation[-1])
|
||||
|
||||
|
||||
def _operation_to_info(operation):
|
||||
"""Takes in operation name and returns meta information.
|
||||
|
||||
An example would be 'separable_3x3_4' -> (3, 4).
|
||||
|
||||
Args:
|
||||
operation: String that corresponds to convolution operation.
|
||||
|
||||
Returns:
|
||||
Tuple of (filter shape, num layers).
|
||||
"""
|
||||
num_layers = _operation_to_num_layers(operation)
|
||||
filter_shape = _operation_to_filter_shape(operation)
|
||||
return num_layers, filter_shape
|
||||
|
||||
|
||||
def _stacked_separable_conv(net, stride, operation, filter_size):
|
||||
"""Takes in an operations and parses it to the correct sep operation."""
|
||||
num_layers, kernel_size = _operation_to_info(operation)
|
||||
for layer_num in range(num_layers - 1):
|
||||
net = tf.nn.relu(net)
|
||||
net = slim.separable_conv2d(
|
||||
net,
|
||||
filter_size,
|
||||
kernel_size,
|
||||
depth_multiplier=1,
|
||||
scope='separable_{0}x{0}_{1}'.format(kernel_size, layer_num + 1),
|
||||
stride=stride)
|
||||
net = slim.batch_norm(
|
||||
net, scope='bn_sep_{0}x{0}_{1}'.format(kernel_size, layer_num + 1))
|
||||
stride = 1
|
||||
net = tf.nn.relu(net)
|
||||
net = slim.separable_conv2d(
|
||||
net,
|
||||
filter_size,
|
||||
kernel_size,
|
||||
depth_multiplier=1,
|
||||
scope='separable_{0}x{0}_{1}'.format(kernel_size, num_layers),
|
||||
stride=stride)
|
||||
net = slim.batch_norm(
|
||||
net, scope='bn_sep_{0}x{0}_{1}'.format(kernel_size, num_layers))
|
||||
return net
|
||||
|
||||
|
||||
def _operation_to_pooling_type(operation):
|
||||
"""Takes in the operation string and returns the pooling type."""
|
||||
splitted_operation = operation.split('_')
|
||||
return splitted_operation[0]
|
||||
|
||||
|
||||
def _operation_to_pooling_shape(operation):
|
||||
"""Takes in the operation string and returns the pooling kernel shape."""
|
||||
splitted_operation = operation.split('_')
|
||||
shape = splitted_operation[-1]
|
||||
assert 'x' in shape
|
||||
filter_height, filter_width = shape.split('x')
|
||||
assert filter_height == filter_width
|
||||
return int(filter_height)
|
||||
|
||||
|
||||
def _operation_to_pooling_info(operation):
|
||||
"""Parses the pooling operation string to return its type and shape."""
|
||||
pooling_type = _operation_to_pooling_type(operation)
|
||||
pooling_shape = _operation_to_pooling_shape(operation)
|
||||
return pooling_type, pooling_shape
|
||||
|
||||
|
||||
def _pooling(net, stride, operation):
|
||||
"""Parses operation and performs the correct pooling operation on net."""
|
||||
padding = 'SAME'
|
||||
pooling_type, pooling_shape = _operation_to_pooling_info(operation)
|
||||
if pooling_type == 'avg':
|
||||
net = slim.avg_pool2d(net, pooling_shape, stride=stride, padding=padding)
|
||||
elif pooling_type == 'max':
|
||||
net = slim.max_pool2d(net, pooling_shape, stride=stride, padding=padding)
|
||||
else:
|
||||
raise NotImplementedError('Unimplemented pooling type: ', pooling_type)
|
||||
return net
|
||||
|
||||
|
||||
class NasNetABaseCell(object):
|
||||
"""NASNet Cell class that is used as a 'layer' in image architectures.
|
||||
|
||||
Args:
|
||||
num_conv_filters: The number of filters for each convolution operation.
|
||||
operations: List of operations that are performed in the NASNet Cell in
|
||||
order.
|
||||
used_hiddenstates: Binary array that signals if the hiddenstate was used
|
||||
within the cell. This is used to determine what outputs of the cell
|
||||
should be concatenated together.
|
||||
hiddenstate_indices: Determines what hiddenstates should be combined
|
||||
together with the specified operations to create the NASNet cell.
|
||||
"""
|
||||
|
||||
def __init__(self, num_conv_filters, operations, used_hiddenstates,
|
||||
hiddenstate_indices, drop_path_keep_prob, total_num_cells,
|
||||
total_training_steps):
|
||||
self._num_conv_filters = num_conv_filters
|
||||
self._operations = operations
|
||||
self._used_hiddenstates = used_hiddenstates
|
||||
self._hiddenstate_indices = hiddenstate_indices
|
||||
self._drop_path_keep_prob = drop_path_keep_prob
|
||||
self._total_num_cells = total_num_cells
|
||||
self._total_training_steps = total_training_steps
|
||||
|
||||
def _reduce_prev_layer(self, prev_layer, curr_layer):
|
||||
"""Matches dimension of prev_layer to the curr_layer."""
|
||||
# Set the prev layer to the current layer if it is none
|
||||
if prev_layer is None:
|
||||
return curr_layer
|
||||
curr_num_filters = self._filter_size
|
||||
prev_num_filters = get_channel_dim(prev_layer.shape)
|
||||
curr_filter_shape = int(curr_layer.shape[2])
|
||||
prev_filter_shape = int(prev_layer.shape[2])
|
||||
if curr_filter_shape != prev_filter_shape:
|
||||
prev_layer = tf.nn.relu(prev_layer)
|
||||
prev_layer = factorized_reduction(
|
||||
prev_layer, curr_num_filters, stride=2)
|
||||
elif curr_num_filters != prev_num_filters:
|
||||
prev_layer = tf.nn.relu(prev_layer)
|
||||
prev_layer = slim.conv2d(
|
||||
prev_layer, curr_num_filters, 1, scope='prev_1x1')
|
||||
prev_layer = slim.batch_norm(prev_layer, scope='prev_bn')
|
||||
return prev_layer
|
||||
|
||||
def _cell_base(self, net, prev_layer):
|
||||
"""Runs the beginning of the conv cell before the predicted ops are run."""
|
||||
num_filters = self._filter_size
|
||||
|
||||
# Check to be sure prev layer stuff is setup correctly
|
||||
prev_layer = self._reduce_prev_layer(prev_layer, net)
|
||||
|
||||
net = tf.nn.relu(net)
|
||||
net = slim.conv2d(net, num_filters, 1, scope='1x1')
|
||||
net = slim.batch_norm(net, scope='beginning_bn')
|
||||
split_axis = get_channel_index()
|
||||
net = tf.split(
|
||||
axis=split_axis, num_or_size_splits=1, value=net)
|
||||
for split in net:
|
||||
assert int(split.shape[split_axis] == int(self._num_conv_filters *
|
||||
self._filter_scaling))
|
||||
net.append(prev_layer)
|
||||
return net
|
||||
|
||||
def __call__(self, net, scope=None, filter_scaling=1, stride=1,
|
||||
prev_layer=None, cell_num=-1):
|
||||
"""Runs the conv cell."""
|
||||
self._cell_num = cell_num
|
||||
self._filter_scaling = filter_scaling
|
||||
self._filter_size = int(self._num_conv_filters * filter_scaling)
|
||||
|
||||
i = 0
|
||||
with tf.variable_scope(scope):
|
||||
net = self._cell_base(net, prev_layer)
|
||||
for iteration in range(5):
|
||||
with tf.variable_scope('comb_iter_{}'.format(iteration)):
|
||||
left_hiddenstate_idx, right_hiddenstate_idx = (
|
||||
self._hiddenstate_indices[i],
|
||||
self._hiddenstate_indices[i + 1])
|
||||
original_input_left = left_hiddenstate_idx < 2
|
||||
original_input_right = right_hiddenstate_idx < 2
|
||||
h1 = net[left_hiddenstate_idx]
|
||||
h2 = net[right_hiddenstate_idx]
|
||||
|
||||
operation_left = self._operations[i]
|
||||
operation_right = self._operations[i+1]
|
||||
i += 2
|
||||
# Apply conv operations
|
||||
with tf.variable_scope('left'):
|
||||
h1 = self._apply_conv_operation(h1, operation_left,
|
||||
stride, original_input_left)
|
||||
with tf.variable_scope('right'):
|
||||
h2 = self._apply_conv_operation(h2, operation_right,
|
||||
stride, original_input_right)
|
||||
|
||||
# Combine hidden states using 'add'.
|
||||
with tf.variable_scope('combine'):
|
||||
h = h1 + h2
|
||||
|
||||
# Add hiddenstate to the list of hiddenstates we can choose from
|
||||
net.append(h)
|
||||
|
||||
with tf.variable_scope('cell_output'):
|
||||
net = self._combine_unused_states(net)
|
||||
|
||||
return net
|
||||
|
||||
def _apply_conv_operation(self, net, operation,
|
||||
stride, is_from_original_input):
|
||||
"""Applies the predicted conv operation to net."""
|
||||
# Dont stride if this is not one of the original hiddenstates
|
||||
if stride > 1 and not is_from_original_input:
|
||||
stride = 1
|
||||
input_filters = get_channel_dim(net.shape)
|
||||
filter_size = self._filter_size
|
||||
if 'separable' in operation:
|
||||
net = _stacked_separable_conv(net, stride, operation, filter_size)
|
||||
elif operation in ['none']:
|
||||
# Check if a stride is needed, then use a strided 1x1 here
|
||||
if stride > 1 or (input_filters != filter_size):
|
||||
net = tf.nn.relu(net)
|
||||
net = slim.conv2d(net, filter_size, 1, stride=stride, scope='1x1')
|
||||
net = slim.batch_norm(net, scope='bn_1')
|
||||
elif 'pool' in operation:
|
||||
net = _pooling(net, stride, operation)
|
||||
if input_filters != filter_size:
|
||||
net = slim.conv2d(net, filter_size, 1, stride=1, scope='1x1')
|
||||
net = slim.batch_norm(net, scope='bn_1')
|
||||
else:
|
||||
raise ValueError('Unimplemented operation', operation)
|
||||
|
||||
if operation != 'none':
|
||||
net = self._apply_drop_path(net)
|
||||
return net
|
||||
|
||||
def _combine_unused_states(self, net):
|
||||
"""Concatenate the unused hidden states of the cell."""
|
||||
used_hiddenstates = self._used_hiddenstates
|
||||
|
||||
final_height = int(net[-1].shape[2])
|
||||
final_num_filters = get_channel_dim(net[-1].shape)
|
||||
assert len(used_hiddenstates) == len(net)
|
||||
for idx, used_h in enumerate(used_hiddenstates):
|
||||
curr_height = int(net[idx].shape[2])
|
||||
curr_num_filters = get_channel_dim(net[idx].shape)
|
||||
|
||||
# Determine if a reduction should be applied to make the number of
|
||||
# filters match.
|
||||
should_reduce = final_num_filters != curr_num_filters
|
||||
should_reduce = (final_height != curr_height) or should_reduce
|
||||
should_reduce = should_reduce and not used_h
|
||||
if should_reduce:
|
||||
stride = 2 if final_height != curr_height else 1
|
||||
with tf.variable_scope('reduction_{}'.format(idx)):
|
||||
net[idx] = factorized_reduction(
|
||||
net[idx], final_num_filters, stride)
|
||||
|
||||
states_to_combine = (
|
||||
[h for h, is_used in zip(net, used_hiddenstates) if not is_used])
|
||||
|
||||
# Return the concat of all the states
|
||||
concat_axis = get_channel_index()
|
||||
net = tf.concat(values=states_to_combine, axis=concat_axis)
|
||||
return net
|
||||
|
||||
def _apply_drop_path(self, net):
|
||||
"""Apply drop_path regularization to net."""
|
||||
drop_path_keep_prob = self._drop_path_keep_prob
|
||||
if drop_path_keep_prob < 1.0:
|
||||
# Scale keep prob by layer number
|
||||
assert self._cell_num != -1
|
||||
# The added 2 is for the reduction cells
|
||||
num_cells = self._total_num_cells
|
||||
layer_ratio = (self._cell_num + 1)/float(num_cells)
|
||||
with tf.device('/cpu:0'):
|
||||
tf.summary.scalar('layer_ratio', layer_ratio)
|
||||
drop_path_keep_prob = 1 - layer_ratio * (1 - drop_path_keep_prob)
|
||||
# Decrease the keep probability over time
|
||||
current_step = tf.cast(tf.train.get_or_create_global_step(),
|
||||
tf.float32)
|
||||
drop_path_burn_in_steps = self._total_training_steps
|
||||
current_ratio = (
|
||||
current_step / drop_path_burn_in_steps)
|
||||
current_ratio = tf.minimum(1.0, current_ratio)
|
||||
with tf.device('/cpu:0'):
|
||||
tf.summary.scalar('current_ratio', current_ratio)
|
||||
drop_path_keep_prob = (
|
||||
1 - current_ratio * (1 - drop_path_keep_prob))
|
||||
with tf.device('/cpu:0'):
|
||||
tf.summary.scalar('drop_path_keep_prob', drop_path_keep_prob)
|
||||
net = drop_path(net, drop_path_keep_prob)
|
||||
return net
|
||||
|
||||
|
||||
class NasNetANormalCell(NasNetABaseCell):
|
||||
"""NASNetA Normal Cell."""
|
||||
|
||||
def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells,
|
||||
total_training_steps):
|
||||
operations = ['separable_5x5_2',
|
||||
'separable_3x3_2',
|
||||
'separable_5x5_2',
|
||||
'separable_3x3_2',
|
||||
'avg_pool_3x3',
|
||||
'none',
|
||||
'avg_pool_3x3',
|
||||
'avg_pool_3x3',
|
||||
'separable_3x3_2',
|
||||
'none']
|
||||
used_hiddenstates = [1, 0, 0, 0, 0, 0, 0]
|
||||
hiddenstate_indices = [0, 1, 1, 1, 0, 1, 1, 1, 0, 0]
|
||||
super(NasNetANormalCell, self).__init__(num_conv_filters, operations,
|
||||
used_hiddenstates,
|
||||
hiddenstate_indices,
|
||||
drop_path_keep_prob,
|
||||
total_num_cells,
|
||||
total_training_steps)
|
||||
|
||||
|
||||
class NasNetAReductionCell(NasNetABaseCell):
|
||||
"""NASNetA Reduction Cell."""
|
||||
|
||||
def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells,
|
||||
total_training_steps):
|
||||
operations = ['separable_5x5_2',
|
||||
'separable_7x7_2',
|
||||
'max_pool_3x3',
|
||||
'separable_7x7_2',
|
||||
'avg_pool_3x3',
|
||||
'separable_5x5_2',
|
||||
'none',
|
||||
'avg_pool_3x3',
|
||||
'separable_3x3_2',
|
||||
'max_pool_3x3']
|
||||
used_hiddenstates = [1, 1, 1, 0, 0, 0, 0]
|
||||
hiddenstate_indices = [0, 1, 0, 1, 0, 1, 3, 2, 2, 0]
|
||||
super(NasNetAReductionCell, self).__init__(num_conv_filters, operations,
|
||||
used_hiddenstates,
|
||||
hiddenstate_indices,
|
||||
drop_path_keep_prob,
|
||||
total_num_cells,
|
||||
total_training_steps)
|
|
@ -383,12 +383,13 @@ def predict(model, labels, url):
|
|||
else:
|
||||
num_filter = IR_node.IR_layer.attr["kernel_shape"].list.i[-1]
|
||||
|
||||
no_bias = not IR_node.IR_layer.attr["use_bias"].b
|
||||
if not no_bias and self.weight_loaded:
|
||||
use_bias = IR_node.get_attr('use_bias', False)
|
||||
if use_bias and self.weight_loaded:
|
||||
self.output_weights[IR_node.name + "_bias"] = weight_dict['bias']
|
||||
|
||||
if pattern == "DepthwiseConv":
|
||||
num_group = num_filter
|
||||
num_group = IR_node.IR_layer.attr["kernel_shape"].list.i[-2]
|
||||
num_filter = num_filter * num_group
|
||||
pattern = "Convolution"
|
||||
else:
|
||||
num_group = IR_node.get_attr('group', 1)
|
||||
|
@ -404,6 +405,8 @@ def predict(model, labels, url):
|
|||
if self.weight_loaded:
|
||||
# if layout not in MXNetEmitter.channels_last:
|
||||
weights = MXNetEmitter.transpose(weights, dim)
|
||||
if num_group > 1:
|
||||
weights = np.swapaxes(weights, 0, 1)
|
||||
self.output_weights[IR_node.name + "_weight"] = weights
|
||||
|
||||
code = ""
|
||||
|
@ -418,7 +421,7 @@ def predict(model, labels, url):
|
|||
tuple(pad),
|
||||
num_filter,
|
||||
num_group,
|
||||
no_bias,
|
||||
not use_bias,
|
||||
layout,
|
||||
IR_node.name)
|
||||
else:
|
||||
|
@ -432,7 +435,7 @@ def predict(model, labels, url):
|
|||
dilate,
|
||||
num_filter,
|
||||
num_group,
|
||||
no_bias,
|
||||
not use_bias,
|
||||
layout,
|
||||
IR_node.name)
|
||||
|
||||
|
|
|
@ -440,23 +440,18 @@ class TensorflowParser(Parser):
|
|||
if self.weight_loaded:
|
||||
self.set_weight(source_node.name, 'weights', self.ckpt_data[W.name])
|
||||
|
||||
if source_node.out_edges:
|
||||
add_node = self.tf_graph.get_node(source_node.out_edges[0])
|
||||
if add_node.type == 'Add':
|
||||
add_node.covered = True
|
||||
add_node.real_name = source_node.real_name
|
||||
# FullyConnected Layer
|
||||
# name, op
|
||||
TensorflowParser._copy_and_reop(source_node, IR_node, 'FullyConnected')
|
||||
if source_node.out_edges and self.tf_graph.get_node(source_node.out_edges[0]).type == 'Add':
|
||||
add_node.covered = True
|
||||
add_node.real_name = source_node.real_name
|
||||
# FullyConnected Layer
|
||||
# name, op
|
||||
TensorflowParser._copy_and_reop(source_node, IR_node, 'FullyConnected')
|
||||
|
||||
# get Bias
|
||||
B = self.tf_graph.get_node(self.tf_graph.get_node(source_node.out_edges[0]).in_edges[1]).in_edges[0]
|
||||
if self.weight_loaded:
|
||||
self.set_weight(source_node.name, 'bias', self.ckpt_data[B])
|
||||
IR_node.attr['use_bias'].b = True
|
||||
|
||||
else:
|
||||
raise NotImplementedError("Not implemented yet. Please submit a issue in github and provide your models for reproduce.")
|
||||
# get Bias
|
||||
B = self.tf_graph.get_node(self.tf_graph.get_node(source_node.out_edges[0]).in_edges[1]).in_edges[0]
|
||||
if self.weight_loaded:
|
||||
self.set_weight(source_node.name, 'bias', self.ckpt_data[B])
|
||||
IR_node.attr['use_bias'].b = True
|
||||
|
||||
else:
|
||||
# Matmul Layer
|
||||
|
|
|
@ -3,7 +3,6 @@ import sys
|
|||
import six
|
||||
import unittest
|
||||
import numpy as np
|
||||
from six.moves import reload_module
|
||||
import tensorflow as tf
|
||||
from mmdnn.conversion.examples.imagenet_test import TestKit
|
||||
|
||||
|
@ -383,7 +382,10 @@ class TestModels(CorrectnessTest):
|
|||
'resnet_v1_152' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit
|
||||
'resnet_v2_50' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit
|
||||
'resnet_v2_152' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit
|
||||
},
|
||||
'mobilenet_v1_1.0' : [TensorflowEmit, KerasEmit],
|
||||
# 'inception_resnet_v2' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit], # TODO
|
||||
# 'nasnet-a_large' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
@ -400,8 +402,8 @@ class TestModels(CorrectnessTest):
|
|||
|
||||
IR_file = TestModels.tmpdir + original_framework + '_' + network_name + "_converted"
|
||||
for emit in self.test_table[original_framework][network_name]:
|
||||
# print('Testing conversion {} from {} to {}.'.format(network_name, original_framework, emit.__func__.__name__[:-4]), file=sys.stderr, flush=True)
|
||||
print('Testing conversion {} from {} to {}.'.format(network_name, original_framework, emit.__func__.__name__[:-4]))
|
||||
# print('Testing {} from {} to {}.'.format(network_name, original_framework, emit.__func__.__name__[:-4]), file=sys.stderr, flush=True)
|
||||
print('Testing {} from {} to {}.'.format(network_name, original_framework, emit.__func__.__name__[:-4]))
|
||||
converted_predict = emit.__func__(
|
||||
original_framework,
|
||||
network_name,
|
||||
|
|
Загрузка…
Ссылка в новой задаче