Added draft of TF training info
This commit is contained in:
Родитель
6048f01240
Коммит
29714fdf9d
18
LICENSE
18
LICENSE
|
@ -19,3 +19,21 @@
|
|||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE
|
||||
|
||||
==========================================================================
|
||||
|
||||
Copyright 2017 Microsoft Corporation. All Rights Reserved.
|
||||
Copyright 2016 The Tensorflow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==========================================================================
|
|
@ -706,7 +706,9 @@
|
|||
"<a name=\"tf\"></a>\n",
|
||||
"### TensorFlow\n",
|
||||
"\n",
|
||||
"We made use of the [`tf-slim` API](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim) for Tensorflow, which provides pre-trained ResNet models and helpful scripts for retraining and scoring. Below, we convert our raw PNG images to the [TFRecords](https://www.tensorflow.org/how_tos/reading_data/#file_formats) files that those scripts expect as input. (Our evaluation images will be scored on Spark without conversion to TFRecord format.)"
|
||||
"We made use of the [`tf-slim` API](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim) for Tensorflow, which provides pre-trained ResNet models and helpful scripts for retraining and scoring. Below, we convert our raw PNG images to the [TFRecords](https://www.tensorflow.org/how_tos/reading_data/#file_formats) files that those scripts expect as input. (Our evaluation images will be scored on Spark without conversion to TFRecord format.)\n",
|
||||
"\n",
|
||||
"The following code was modified from the [Tensorflow models repo's slim subdirectory](https://github.com/tensorflow/models/tree/master/slim). To run it, you will need to clone that repo and copy the slim folder to the local directory."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -717,33 +719,106 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Original Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n",
|
||||
"# Modified 2017 by Microsoft Corporation.\n",
|
||||
"#\n",
|
||||
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
||||
"# you may not use this file except in compliance with the License.\n",
|
||||
"# You may obtain a copy of the License at\n",
|
||||
"#\n",
|
||||
"# http://www.apache.org/licenses/LICENSE-2.0\n",
|
||||
"#\n",
|
||||
"# Unless required by applicable law or agreed to in writing, software\n",
|
||||
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
||||
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
||||
"# See the License for the specific language governing permissions and\n",
|
||||
"# limitations under the License.\n",
|
||||
"# ==============================================================================\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import tensorflow as tf\n",
|
||||
"import pandas as pd\n",
|
||||
"from slim.datasets import dataset_utils\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"image_dir = 'E:\\\\combined\\\\train_subsampled'\n",
|
||||
"filenames = []\n",
|
||||
"for folder in os.listdir(image_dir):\n",
|
||||
" folder_path = os.path.join(image_dir, folder)\n",
|
||||
" if not os.path.isdir(folder_path):\n",
|
||||
" continue\n",
|
||||
" filenames.extend([os.path.join(folder_path, i) for i in os.listdir(folder_path)])\n",
|
||||
"filenames = np.random.permutation(filenames)\n",
|
||||
"np.random.seed(5318)\n",
|
||||
"\n",
|
||||
"class ImageReader(object):\n",
|
||||
" def __init__(self):\n",
|
||||
" # Initializes function that decodes RGB JPEG data.\n",
|
||||
" self._decode_png_data = tf.placeholder(dtype=tf.string)\n",
|
||||
" self._decode_png = tf.image.decode_png(self._decode_png_data, channels=3)\n",
|
||||
"\n",
|
||||
" def read_image_dims(self, sess, image_data):\n",
|
||||
" image = self.decode_png(sess, image_data)\n",
|
||||
" return image.shape[0], image.shape[1]\n",
|
||||
"\n",
|
||||
" def decode_png(self, sess, image_data):\n",
|
||||
" image = sess.run(self._decode_png,\n",
|
||||
" feed_dict={self._decode_png_data: image_data})\n",
|
||||
" assert len(image.shape) == 3\n",
|
||||
" assert image.shape[2] == 3\n",
|
||||
" return image\n",
|
||||
"\n",
|
||||
"def find_and_split_images(image_dir, validation_fraction=0.2):\n",
|
||||
" class_names = []\n",
|
||||
" training_filenames = []\n",
|
||||
" validation_filenames = []\n",
|
||||
" \n",
|
||||
"n = int(np.ceil(len(filenames) / 100))\n",
|
||||
"with tf.Graph().as_default():\n",
|
||||
" image_reader = ImageReader()\n",
|
||||
" with tf.Session('') as sess:\n",
|
||||
" for i in range(100):\n",
|
||||
" tfrecord_filename = os.path.join(image_dir, 'aerial_train_{:03d}.tfrecord'.format(i))\n",
|
||||
" with tf.python_io.TFRecordWriter(tfrecord_filename) as tfrecord_writer:\n",
|
||||
" for j in range(n*i, min(n*(i+1), len(filenames))):\n",
|
||||
" image_data = tf.gfile.FastGFile(filenames[j], 'r').read()\n",
|
||||
" height, width = image_reader.read_image_dims(sess, image_data)\n",
|
||||
" label = int(os.path.basename(os.path.dirname(filenames[j])))\n",
|
||||
" example = dataset_utils.image_to_tfexample(image_data, b'png', height, width, label)\n",
|
||||
" tfrecord_writer.write(example.SerializeToString())"
|
||||
" for folder in os.listdir(image_dir):\n",
|
||||
" folder_path = os.path.join(image_dir, folder)\n",
|
||||
" if not os.path.isdir(folder_path):\n",
|
||||
" continue\n",
|
||||
" ''' This is a new directory/label -- consider all images inside it '''\n",
|
||||
" class_names.append(folder)\n",
|
||||
" my_filenames = []\n",
|
||||
" for filename in os.listdir(folder_path):\n",
|
||||
" my_filenames.append(os.path.join(folder_path, filename))\n",
|
||||
" my_filenames = np.random.permutation(my_filenames)\n",
|
||||
" n_validation = int(np.ceil(validation_fraction * len(my_filenames)))\n",
|
||||
" validation_filenames.extend(my_filenames[:n_validation])\n",
|
||||
" training_filenames.extend(my_filenames[n_validation:])\n",
|
||||
" print('Found {} training and {} validation images'.format(len(training_filenames),\n",
|
||||
" len(validation_filenames)))\n",
|
||||
" return(sorted(class_names), training_filenames, validation_filenames)\n",
|
||||
" \n",
|
||||
"def write_dataset(dataset_name, split_name, my_filenames, class_names_to_ids, image_dir, n_shards=5):\n",
|
||||
" num_per_shard = int(np.ceil(len(my_filenames) / n_shards))\n",
|
||||
" records = []\n",
|
||||
" with tf.Graph().as_default():\n",
|
||||
" image_reader = ImageReader()\n",
|
||||
" with tf.Session('') as sess:\n",
|
||||
" for shard_idx in range(n_shards):\n",
|
||||
" shard_filename = os.path.join(image_dir,\n",
|
||||
" '{}_{}_{:05d}-of-{:05d}.tfrecord'.format(dataset_name,\n",
|
||||
" split_name,\n",
|
||||
" shard_idx,\n",
|
||||
" n_shards))\n",
|
||||
" with tf.python_io.TFRecordWriter(shard_filename) as tfrecord_writer:\n",
|
||||
" for image_idx in range(num_per_shard * shard_idx,\n",
|
||||
" min(num_per_shard * (shard_idx+1), len(my_filenames))):\n",
|
||||
" #print('>> Converting image {}/{} shard {}'.format(image_idx+1, len(my_filenames), shard_idx))\n",
|
||||
" image_data = tf.gfile.FastGFile(my_filenames[image_idx], 'r').read()\n",
|
||||
" height, width = image_reader.read_image_dims(sess, image_data)\n",
|
||||
" class_name = os.path.basename(os.path.dirname(my_filenames[image_idx]))\n",
|
||||
" class_id = class_names_to_ids[class_name]\n",
|
||||
" example = dataset_utils.image_to_tfexample(image_data, b'png', height, width, class_id)\n",
|
||||
" tfrecord_writer.write(example.SerializeToString())\n",
|
||||
" records.append([dataset_name, split_name, my_filenames[image_idx], shard_idx,\n",
|
||||
" image_idx, class_name, class_id])\n",
|
||||
" df = pd.DataFrame(records, columns=['dataset_name', 'split_name', 'filename', 'shard_idx', 'image_idx',\n",
|
||||
" 'class_name', 'class_id'])\n",
|
||||
" return(df)\n",
|
||||
" \n",
|
||||
"image_dir = 'E:\\\\combined\\\\train_subsample'\n",
|
||||
"class_names, training_filenames, validation_filenames = find_and_split_images(image_dir, 0.0)\n",
|
||||
"class_names_to_ids = dict(zip(class_names, list(range(len(class_names)))))\n",
|
||||
"df = write_dataset('aerial', 'train', training_filenames, class_names_to_ids, image_dir, n_shards=50)\n",
|
||||
"df.to_csv(os.path.join(image_dir, 'dataset_split_info.csv'), index=False)\n",
|
||||
"\n",
|
||||
"with open(os.path.join(image_dir, 'labels.txt')) as f:\n",
|
||||
" for i in range(len(class_names)):\n",
|
||||
" f.write('{0}:{0}\\n'.format(i))"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
|
@ -1,40 +0,0 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Training CNTK and TensorFlow models for image classification\n",
|
||||
"\n",
|
||||
"## Outline\n",
|
||||
"- [Provision an Azure N-Series GPU Deep Learning VM](#provision)\n",
|
||||
"- [Partition the image set for training and evaluation](#partition)\n",
|
||||
"- [Microsoft Cognitive Toolkit](#cntk)\n",
|
||||
" - [Training an 18-layer ResNet model](#cntktrain)\n",
|
||||
"- [TensorFlow](#tensorflow)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"anaconda-cloud": {},
|
||||
"kernelspec": {
|
||||
"display_name": "Python [conda env:python35]",
|
||||
"language": "python",
|
||||
"name": "conda-env-python35-py"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.5.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
|
@ -0,0 +1,180 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Training CNTK and TensorFlow models for image classification\n",
|
||||
"\n",
|
||||
"## Outline\n",
|
||||
"- [Provision an Azure N-Series GPU Deep Learning VM](#provision)\n",
|
||||
"- [Microsoft Cognitive Toolkit](#cntk)\n",
|
||||
"- [TensorFlow](#tensorflow)\n",
|
||||
" - [Training script](#tfscript)\n",
|
||||
" - [Model](#tfmodel)\n",
|
||||
" - [Running the training script](#tfrun)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<a name=\"provision\"></a>\n",
|
||||
"## Provision an Azure N-Series GPU Deep Learning VM"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Deploy a \"Deep Learning toolkit for the DSVM\" resource in a region that offers GPU VMs, such as East US. As of this writing (1/19), the DSVM deploys with CNTK 2.0.\n",
|
||||
"\n",
|
||||
"### Connecting to the VM by remote desktop\n",
|
||||
"\n",
|
||||
"To use remote desktop, click \"Connect\" on the VM's main pane to download an RDP file. When accessing, make sure that you specify the \"domain\" (VM name) as well as your username, e.g. \"mawahgpudsvm\\mawah\", so that the connection doesn't attempt to use your Microsoft domain.\n",
|
||||
"\n",
|
||||
"### Clone/download the contents of this repo\n",
|
||||
"\n",
|
||||
"Download the contents of this repo and copy the contents of the `tf` and `cntk` subfolders to appropriate locations. We have used locations on the temporary drive, e.g. `D:\\tf` and `D:\\cntk`.\n",
|
||||
"\n",
|
||||
"### Downloading the training and evaluation set locally\n",
|
||||
"\n",
|
||||
"During image set preparation, a training image set and descriptive files were created for use with CNTK and TensorFlow. Transfer these files to the GPU VM and store in an appropriate location. (We have used the `D:\\combined\\train_subsample` folder.) If you did not generate a larger training set earlier, you can use the small training set included in this git repo. You may need to regenerate the CNTK map file if image paths have been changed.\n",
|
||||
"\n",
|
||||
"### (Optional) Access the VM remotely via Jupyter Notebook\n",
|
||||
"\n",
|
||||
"Follow these steps if you wish to be able to access the notebook server remotely:\n",
|
||||
"1. In the [Azure Portal](https://portal.azure.com), navigate to the deployed VM's pane and determine its IP address.\n",
|
||||
"1. In the [Azure Portal](https://portal.azure.com), navigate to the deployed VM's Network Security Group's pane and add inbound/outbound rules permitting traffic on port 9999.\n",
|
||||
"1. While connected to the VM via remote desktop, launch a command prompt (Windows key + R) and type the following commands:\n",
|
||||
"\n",
|
||||
" ```\n",
|
||||
" cd C:\\dsvm\\tools\\setup\n",
|
||||
" JupyterSetPasswordAndStart.cmd\n",
|
||||
" ```\n",
|
||||
"\n",
|
||||
" Follow the prompts to set your remote access password.\n",
|
||||
" \n",
|
||||
"1. Connect to your VM remotely via Jupyter Notebooks using the IP address you determined earlier and port 9999, e.g. `https://[__.__.__.__]:9999`. The default directory on login will be `C:\\dsvm\\notebooks`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<a name=\"tensorflow\"></a>\n",
|
||||
"## Tensorflow\n",
|
||||
"\n",
|
||||
"<a name=\"tfscript\"></a>\n",
|
||||
"### Training script\n",
|
||||
"\n",
|
||||
"We made use of the [`tf-slim` API](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim) for Tensorflow, which provides pre-trained ResNet models and helpful scripts for retraining and scoring. During training set preparation, we converted raw PNG images to the [TFRecords](https://www.tensorflow.org/how_tos/reading_data/#file_formats) files that those scripts expect as input. (Our evaluation set images will be scored on Spark without conversion to TFRecord format.)\n",
|
||||
"\n",
|
||||
"Our training script is a modified version of `train_image_classifier.py` from the [Tensorflow models repo's slim subdirectory](https://github.com/tensorflow/models/tree/master/slim). Changes have also been made to some of that script's dependencies. We recommend that you clone this repo and transfer the `tf` subfolder, including dependencies, to a suitable location, e.g."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"repo_dir = 'D:\\\\tf'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<a name=\"tfmodel\"></a>\n",
|
||||
"### Model\n",
|
||||
"\n",
|
||||
"We will retrain the logits of a 152-layer ResNet pretrained on ImageNet. This model is highlighted in the [Tensorflow models repo's slim subdirectory](https://github.com/tensorflow/models/tree/master/slim). The pretrained model can be obtained and unpacked with the code snippet below:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import urllib.request\n",
|
||||
"import tarfile\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"urllib.request.urlretrieve('http://download.tensorflow.org/models/resnet_v1_152_2016_08_28.tar.gz',\n",
|
||||
" os.path.join(repo_dir, 'resnet_v1_152_2016_08_28.tar.gz'))\n",
|
||||
"with tarfile.open(os.path.join(repo_dir, 'resnet_v1_152_2016_08_28.tar.gz'), 'r:gz') as f:\n",
|
||||
" f.extractall(path=repo_dir)\n",
|
||||
"os.remove(os.path.join(repo_dir, 'resnet_v1_152_2016_08_28.tar.gz'))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<a name=\"tfrun\"></a>\n",
|
||||
"### Running the training script\n",
|
||||
"\n",
|
||||
"We recommend that you run the training script from an Anaconda prompt. The code cell below will help you generate the appropriate command based on your file locations."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# repo_dir was defined above\n",
|
||||
"\n",
|
||||
"# path where retrained model and logs will be saved during training\n",
|
||||
"train_dir = os.path.join(repo_dir, 'models')\n",
|
||||
"if not os.path.exists(train_dir):\n",
|
||||
" os.makedirs(train_dir)\n",
|
||||
" \n",
|
||||
"# location of the unpacked pretrained model\n",
|
||||
"checkpoint_path = os.path.join(repo_dir, 'resnet_v1_152.ckpt')\n",
|
||||
"\n",
|
||||
"# Location of the TFRecords and other files generated during image set preparation\n",
|
||||
"image_dir = 'D:\\\\combined\\\\train_subsample'\n",
|
||||
"\n",
|
||||
"command = '''activate py35\n",
|
||||
"python {0} --train_dir={1} --dataset_name=aerial --dataset_split_name=train --dataset_dir={2} --checkpoint_path={3}\n",
|
||||
"'''.format(os.path.join(repo_dir, 'retrain.py'),\n",
|
||||
" train_dir,\n",
|
||||
" dataset_dir,\n",
|
||||
" checkpoint_path)\n",
|
||||
"\n",
|
||||
"print(command)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"anaconda-cloud": {},
|
||||
"kernelspec": {
|
||||
"display_name": "Python [conda env:python35]",
|
||||
"language": "python",
|
||||
"name": "conda-env-python35-py"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.5.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
|
@ -0,0 +1 @@
|
|||
|
|
@ -0,0 +1,678 @@
|
|||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Deploy Slim models across multiple clones and replicas.
|
||||
|
||||
# TODO(sguada) docstring paragraph by (a) motivating the need for the file and
|
||||
# (b) defining clones.
|
||||
|
||||
# TODO(sguada) describe the high-level components of model deployment.
|
||||
# E.g. "each model deployment is composed of several parts: a DeploymentConfig,
|
||||
# which captures A, B and C, an input_fn which loads data.. etc
|
||||
|
||||
To easily train a model on multiple GPUs or across multiple machines this
|
||||
module provides a set of helper functions: `create_clones`,
|
||||
`optimize_clones` and `deploy`.
|
||||
|
||||
Usage:
|
||||
|
||||
g = tf.Graph()
|
||||
|
||||
# Set up DeploymentConfig
|
||||
config = model_deploy.DeploymentConfig(num_clones=2, clone_on_cpu=True)
|
||||
|
||||
# Create the global step on the device storing the variables.
|
||||
with tf.device(config.variables_device()):
|
||||
global_step = slim.create_global_step()
|
||||
|
||||
# Define the inputs
|
||||
with tf.device(config.inputs_device()):
|
||||
images, labels = LoadData(...)
|
||||
inputs_queue = slim.data.prefetch_queue((images, labels))
|
||||
|
||||
# Define the optimizer.
|
||||
with tf.device(config.optimizer_device()):
|
||||
optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum)
|
||||
|
||||
# Define the model including the loss.
|
||||
def model_fn(inputs_queue):
|
||||
images, labels = inputs_queue.dequeue()
|
||||
predictions = CreateNetwork(images)
|
||||
slim.losses.log_loss(predictions, labels)
|
||||
|
||||
model_dp = model_deploy.deploy(config, model_fn, [inputs_queue],
|
||||
optimizer=optimizer)
|
||||
|
||||
# Run training.
|
||||
slim.learning.train(model_dp.train_op, my_log_dir,
|
||||
summary_op=model_dp.summary_op)
|
||||
|
||||
The Clone namedtuple holds together the values associated with each call to
|
||||
model_fn:
|
||||
* outputs: The return values of the calls to `model_fn()`.
|
||||
* scope: The scope used to create the clone.
|
||||
* device: The device used to create the clone.
|
||||
|
||||
DeployedModel namedtuple, holds together the values needed to train multiple
|
||||
clones:
|
||||
* train_op: An operation that run the optimizer training op and include
|
||||
all the update ops created by `model_fn`. Present only if an optimizer
|
||||
was specified.
|
||||
* summary_op: An operation that run the summaries created by `model_fn`
|
||||
and process_gradients.
|
||||
* total_loss: A `Tensor` that contains the sum of all losses created by
|
||||
`model_fn` plus the regularization losses.
|
||||
* clones: List of `Clone` tuples returned by `create_clones()`.
|
||||
|
||||
DeploymentConfig parameters:
|
||||
* num_clones: Number of model clones to deploy in each replica.
|
||||
* clone_on_cpu: True if clones should be placed on CPU.
|
||||
* replica_id: Integer. Index of the replica for which the model is
|
||||
deployed. Usually 0 for the chief replica.
|
||||
* num_replicas: Number of replicas to use.
|
||||
* num_ps_tasks: Number of tasks for the `ps` job. 0 to not use replicas.
|
||||
* worker_job_name: A name for the worker job.
|
||||
* ps_job_name: A name for the parameter server job.
|
||||
|
||||
TODO(sguada):
|
||||
- describe side effect to the graph.
|
||||
- what happens to summaries and update_ops.
|
||||
- which graph collections are altered.
|
||||
- write a tutorial on how to use this.
|
||||
- analyze the possibility of calling deploy more than once.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
|
||||
slim = tf.contrib.slim
|
||||
|
||||
|
||||
__all__ = ['create_clones',
|
||||
'deploy',
|
||||
'optimize_clones',
|
||||
'DeployedModel',
|
||||
'DeploymentConfig',
|
||||
'Clone',
|
||||
]
|
||||
|
||||
|
||||
# Namedtuple used to represent a clone during deployment.
|
||||
Clone = collections.namedtuple('Clone',
|
||||
['outputs', # Whatever model_fn() returned.
|
||||
'scope', # The scope used to create it.
|
||||
'device', # The device used to create.
|
||||
])
|
||||
|
||||
# Namedtuple used to represent a DeployedModel, returned by deploy().
|
||||
DeployedModel = collections.namedtuple('DeployedModel',
|
||||
['train_op', # The `train_op`
|
||||
'summary_op', # The `summary_op`
|
||||
'total_loss', # The loss `Tensor`
|
||||
'clones', # A list of `Clones` tuples.
|
||||
])
|
||||
|
||||
# Default parameters for DeploymentConfig
|
||||
_deployment_params = {'num_clones': 1,
|
||||
'clone_on_cpu': False,
|
||||
'replica_id': 0,
|
||||
'num_replicas': 1,
|
||||
'num_ps_tasks': 0,
|
||||
'worker_job_name': 'worker',
|
||||
'ps_job_name': 'ps'}
|
||||
|
||||
|
||||
def create_clones(config, model_fn, args=None, kwargs=None):
|
||||
"""Creates multiple clones according to config using a `model_fn`.
|
||||
|
||||
The returned values of `model_fn(*args, **kwargs)` are collected along with
|
||||
the scope and device used to created it in a namedtuple
|
||||
`Clone(outputs, scope, device)`
|
||||
|
||||
Note: it is assumed that any loss created by `model_fn` is collected at
|
||||
the tf.GraphKeys.LOSSES collection.
|
||||
|
||||
To recover the losses, summaries or update_ops created by the clone use:
|
||||
```python
|
||||
losses = tf.get_collection(tf.GraphKeys.LOSSES, clone.scope)
|
||||
summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, clone.scope)
|
||||
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, clone.scope)
|
||||
```
|
||||
|
||||
The deployment options are specified by the config object and support
|
||||
deploying one or several clones on different GPUs and one or several replicas
|
||||
of such clones.
|
||||
|
||||
The argument `model_fn` is called `config.num_clones` times to create the
|
||||
model clones as `model_fn(*args, **kwargs)`.
|
||||
|
||||
If `config` specifies deployment on multiple replicas then the default
|
||||
tensorflow device is set appropriatly for each call to `model_fn` and for the
|
||||
slim variable creation functions: model and global variables will be created
|
||||
on the `ps` device, the clone operations will be on the `worker` device.
|
||||
|
||||
Args:
|
||||
config: A DeploymentConfig object.
|
||||
model_fn: A callable. Called as `model_fn(*args, **kwargs)`
|
||||
args: Optional list of arguments to pass to `model_fn`.
|
||||
kwargs: Optional list of keyword arguments to pass to `model_fn`.
|
||||
|
||||
Returns:
|
||||
A list of namedtuples `Clone`.
|
||||
"""
|
||||
clones = []
|
||||
args = args or []
|
||||
kwargs = kwargs or {}
|
||||
with slim.arg_scope([slim.model_variable, slim.variable],
|
||||
device=config.variables_device()):
|
||||
# Create clones.
|
||||
for i in range(0, config.num_clones):
|
||||
with tf.name_scope(config.clone_scope(i)) as clone_scope:
|
||||
clone_device = config.clone_device(i)
|
||||
with tf.device(clone_device):
|
||||
with tf.variable_scope(tf.get_variable_scope(),
|
||||
reuse=True if i > 0 else None):
|
||||
outputs = model_fn(*args, **kwargs)
|
||||
clones.append(Clone(outputs, clone_scope, clone_device))
|
||||
return clones
|
||||
|
||||
|
||||
def _gather_clone_loss(clone, num_clones, regularization_losses):
|
||||
"""Gather the loss for a single clone.
|
||||
|
||||
Args:
|
||||
clone: A Clone namedtuple.
|
||||
num_clones: The number of clones being deployed.
|
||||
regularization_losses: Possibly empty list of regularization_losses
|
||||
to add to the clone losses.
|
||||
|
||||
Returns:
|
||||
A tensor for the total loss for the clone. Can be None.
|
||||
"""
|
||||
# The return value.
|
||||
sum_loss = None
|
||||
# Individual components of the loss that will need summaries.
|
||||
clone_loss = None
|
||||
regularization_loss = None
|
||||
# Compute and aggregate losses on the clone device.
|
||||
with tf.device(clone.device):
|
||||
all_losses = []
|
||||
clone_losses = tf.get_collection(tf.GraphKeys.LOSSES, clone.scope)
|
||||
if clone_losses:
|
||||
clone_loss = tf.add_n(clone_losses, name='clone_loss')
|
||||
if num_clones > 1:
|
||||
clone_loss = tf.div(clone_loss, 1.0 * num_clones,
|
||||
name='scaled_clone_loss')
|
||||
all_losses.append(clone_loss)
|
||||
if regularization_losses:
|
||||
regularization_loss = tf.add_n(regularization_losses,
|
||||
name='regularization_loss')
|
||||
all_losses.append(regularization_loss)
|
||||
if all_losses:
|
||||
sum_loss = tf.add_n(all_losses)
|
||||
# Add the summaries out of the clone device block.
|
||||
if clone_loss is not None:
|
||||
tf.summary.scalar(clone.scope + '/clone_loss', clone_loss)
|
||||
if regularization_loss is not None:
|
||||
tf.summary.scalar('regularization_loss', regularization_loss)
|
||||
return sum_loss
|
||||
|
||||
|
||||
def _optimize_clone(optimizer, clone, num_clones, regularization_losses,
|
||||
**kwargs):
|
||||
"""Compute losses and gradients for a single clone.
|
||||
|
||||
Args:
|
||||
optimizer: A tf.Optimizer object.
|
||||
clone: A Clone namedtuple.
|
||||
num_clones: The number of clones being deployed.
|
||||
regularization_losses: Possibly empty list of regularization_losses
|
||||
to add to the clone losses.
|
||||
**kwargs: Dict of kwarg to pass to compute_gradients().
|
||||
|
||||
Returns:
|
||||
A tuple (clone_loss, clone_grads_and_vars).
|
||||
- clone_loss: A tensor for the total loss for the clone. Can be None.
|
||||
- clone_grads_and_vars: List of (gradient, variable) for the clone.
|
||||
Can be empty.
|
||||
"""
|
||||
sum_loss = _gather_clone_loss(clone, num_clones, regularization_losses)
|
||||
clone_grad = None
|
||||
if sum_loss is not None:
|
||||
with tf.device(clone.device):
|
||||
clone_grad = optimizer.compute_gradients(sum_loss, **kwargs)
|
||||
return sum_loss, clone_grad
|
||||
|
||||
|
||||
def optimize_clones(clones, optimizer,
|
||||
regularization_losses=None,
|
||||
**kwargs):
|
||||
"""Compute clone losses and gradients for the given list of `Clones`.
|
||||
|
||||
Note: The regularization_losses are added to the first clone losses.
|
||||
|
||||
Args:
|
||||
clones: List of `Clones` created by `create_clones()`.
|
||||
optimizer: An `Optimizer` object.
|
||||
regularization_losses: Optional list of regularization losses. If None it
|
||||
will gather them from tf.GraphKeys.REGULARIZATION_LOSSES. Pass `[]` to
|
||||
exclude them.
|
||||
**kwargs: Optional list of keyword arguments to pass to `compute_gradients`.
|
||||
|
||||
Returns:
|
||||
A tuple (total_loss, grads_and_vars).
|
||||
- total_loss: A Tensor containing the average of the clone losses including
|
||||
the regularization loss.
|
||||
- grads_and_vars: A List of tuples (gradient, variable) containing the sum
|
||||
of the gradients for each variable.
|
||||
|
||||
"""
|
||||
grads_and_vars = []
|
||||
clones_losses = []
|
||||
num_clones = len(clones)
|
||||
if regularization_losses is None:
|
||||
regularization_losses = tf.get_collection(
|
||||
tf.GraphKeys.REGULARIZATION_LOSSES)
|
||||
for clone in clones:
|
||||
with tf.name_scope(clone.scope):
|
||||
clone_loss, clone_grad = _optimize_clone(
|
||||
optimizer, clone, num_clones, regularization_losses, **kwargs)
|
||||
if clone_loss is not None:
|
||||
clones_losses.append(clone_loss)
|
||||
grads_and_vars.append(clone_grad)
|
||||
# Only use regularization_losses for the first clone
|
||||
regularization_losses = None
|
||||
# Compute the total_loss summing all the clones_losses.
|
||||
total_loss = tf.add_n(clones_losses, name='total_loss')
|
||||
# Sum the gradients accross clones.
|
||||
grads_and_vars = _sum_clones_gradients(grads_and_vars)
|
||||
return total_loss, grads_and_vars
|
||||
|
||||
|
||||
def deploy(config,
|
||||
model_fn,
|
||||
args=None,
|
||||
kwargs=None,
|
||||
optimizer=None,
|
||||
summarize_gradients=False):
|
||||
"""Deploys a Slim-constructed model across multiple clones.
|
||||
|
||||
The deployment options are specified by the config object and support
|
||||
deploying one or several clones on different GPUs and one or several replicas
|
||||
of such clones.
|
||||
|
||||
The argument `model_fn` is called `config.num_clones` times to create the
|
||||
model clones as `model_fn(*args, **kwargs)`.
|
||||
|
||||
The optional argument `optimizer` is an `Optimizer` object. If not `None`,
|
||||
the deployed model is configured for training with that optimizer.
|
||||
|
||||
If `config` specifies deployment on multiple replicas then the default
|
||||
tensorflow device is set appropriatly for each call to `model_fn` and for the
|
||||
slim variable creation functions: model and global variables will be created
|
||||
on the `ps` device, the clone operations will be on the `worker` device.
|
||||
|
||||
Args:
|
||||
config: A `DeploymentConfig` object.
|
||||
model_fn: A callable. Called as `model_fn(*args, **kwargs)`
|
||||
args: Optional list of arguments to pass to `model_fn`.
|
||||
kwargs: Optional list of keyword arguments to pass to `model_fn`.
|
||||
optimizer: Optional `Optimizer` object. If passed the model is deployed
|
||||
for training with that optimizer.
|
||||
summarize_gradients: Whether or not add summaries to the gradients.
|
||||
|
||||
Returns:
|
||||
A `DeployedModel` namedtuple.
|
||||
|
||||
"""
|
||||
# Gather initial summaries.
|
||||
summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
|
||||
|
||||
# Create Clones.
|
||||
clones = create_clones(config, model_fn, args, kwargs)
|
||||
first_clone = clones[0]
|
||||
|
||||
# Gather update_ops from the first clone. These contain, for example,
|
||||
# the updates for the batch_norm variables created by model_fn.
|
||||
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone.scope)
|
||||
|
||||
train_op = None
|
||||
total_loss = None
|
||||
with tf.device(config.optimizer_device()):
|
||||
if optimizer:
|
||||
# Place the global step on the device storing the variables.
|
||||
with tf.device(config.variables_device()):
|
||||
global_step = slim.get_or_create_global_step()
|
||||
|
||||
# Compute the gradients for the clones.
|
||||
total_loss, clones_gradients = optimize_clones(clones, optimizer)
|
||||
|
||||
if clones_gradients:
|
||||
if summarize_gradients:
|
||||
# Add summaries to the gradients.
|
||||
summaries |= set(_add_gradients_summaries(clones_gradients))
|
||||
|
||||
# Create gradient updates.
|
||||
grad_updates = optimizer.apply_gradients(clones_gradients,
|
||||
global_step=global_step)
|
||||
update_ops.append(grad_updates)
|
||||
|
||||
update_op = tf.group(*update_ops)
|
||||
train_op = control_flow_ops.with_dependencies([update_op], total_loss,
|
||||
name='train_op')
|
||||
else:
|
||||
clones_losses = []
|
||||
regularization_losses = tf.get_collection(
|
||||
tf.GraphKeys.REGULARIZATION_LOSSES)
|
||||
for clone in clones:
|
||||
with tf.name_scope(clone.scope):
|
||||
clone_loss = _gather_clone_loss(clone, len(clones),
|
||||
regularization_losses)
|
||||
if clone_loss is not None:
|
||||
clones_losses.append(clone_loss)
|
||||
# Only use regularization_losses for the first clone
|
||||
regularization_losses = None
|
||||
if clones_losses:
|
||||
total_loss = tf.add_n(clones_losses, name='total_loss')
|
||||
|
||||
# Add the summaries from the first clone. These contain the summaries
|
||||
# created by model_fn and either optimize_clones() or _gather_clone_loss().
|
||||
summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
|
||||
first_clone.scope))
|
||||
|
||||
if total_loss is not None:
|
||||
# Add total_loss to summary.
|
||||
summaries.add(tf.summary.scalar('total_loss', total_loss))
|
||||
|
||||
if summaries:
|
||||
# Merge all summaries together.
|
||||
summary_op = tf.summary.merge(list(summaries), name='summary_op')
|
||||
else:
|
||||
summary_op = None
|
||||
|
||||
return DeployedModel(train_op, summary_op, total_loss, clones)
|
||||
|
||||
|
||||
def _sum_clones_gradients(clone_grads):
|
||||
"""Calculate the sum gradient for each shared variable across all clones.
|
||||
|
||||
This function assumes that the clone_grads has been scaled appropriately by
|
||||
1 / num_clones.
|
||||
|
||||
Args:
|
||||
clone_grads: A List of List of tuples (gradient, variable), one list per
|
||||
`Clone`.
|
||||
|
||||
Returns:
|
||||
List of tuples of (gradient, variable) where the gradient has been summed
|
||||
across all clones.
|
||||
"""
|
||||
sum_grads = []
|
||||
for grad_and_vars in zip(*clone_grads):
|
||||
# Note that each grad_and_vars looks like the following:
|
||||
# ((grad_var0_clone0, var0), ... (grad_varN_cloneN, varN))
|
||||
grads = []
|
||||
var = grad_and_vars[0][1]
|
||||
for g, v in grad_and_vars:
|
||||
assert v == var
|
||||
if g is not None:
|
||||
grads.append(g)
|
||||
if grads:
|
||||
if len(grads) > 1:
|
||||
sum_grad = tf.add_n(grads, name=var.op.name + '/sum_grads')
|
||||
else:
|
||||
sum_grad = grads[0]
|
||||
sum_grads.append((sum_grad, var))
|
||||
return sum_grads
|
||||
|
||||
|
||||
def _add_gradients_summaries(grads_and_vars):
|
||||
"""Add histogram summaries to gradients.
|
||||
|
||||
Note: The summaries are also added to the SUMMARIES collection.
|
||||
|
||||
Args:
|
||||
grads_and_vars: A list of gradient to variable pairs (tuples).
|
||||
|
||||
Returns:
|
||||
The _list_ of the added summaries for grads_and_vars.
|
||||
"""
|
||||
summaries = []
|
||||
for grad, var in grads_and_vars:
|
||||
if grad is not None:
|
||||
if isinstance(grad, tf.IndexedSlices):
|
||||
grad_values = grad.values
|
||||
else:
|
||||
grad_values = grad
|
||||
summaries.append(tf.histogram_summary(var.op.name + ':gradient',
|
||||
grad_values))
|
||||
summaries.append(tf.histogram_summary(var.op.name + ':gradient_norm',
|
||||
tf.global_norm([grad_values])))
|
||||
else:
|
||||
tf.logging.info('Var %s has no gradient', var.op.name)
|
||||
return summaries
|
||||
|
||||
|
||||
class DeploymentConfig(object):
|
||||
"""Configuration for deploying a model with `deploy()`.
|
||||
|
||||
You can pass an instance of this class to `deploy()` to specify exactly
|
||||
how to deploy the model to build. If you do not pass one, an instance built
|
||||
from the default deployment_hparams will be used.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_clones=1,
|
||||
clone_on_cpu=False,
|
||||
replica_id=0,
|
||||
num_replicas=1,
|
||||
num_ps_tasks=0,
|
||||
worker_job_name='worker',
|
||||
ps_job_name='ps'):
|
||||
"""Create a DeploymentConfig.
|
||||
|
||||
The config describes how to deploy a model across multiple clones and
|
||||
replicas. The model will be replicated `num_clones` times in each replica.
|
||||
If `clone_on_cpu` is True, each clone will placed on CPU.
|
||||
|
||||
If `num_replicas` is 1, the model is deployed via a single process. In that
|
||||
case `worker_device`, `num_ps_tasks`, and `ps_device` are ignored.
|
||||
|
||||
If `num_replicas` is greater than 1, then `worker_device` and `ps_device`
|
||||
must specify TensorFlow devices for the `worker` and `ps` jobs and
|
||||
`num_ps_tasks` must be positive.
|
||||
|
||||
Args:
|
||||
num_clones: Number of model clones to deploy in each replica.
|
||||
clone_on_cpu: If True clones would be placed on CPU.
|
||||
replica_id: Integer. Index of the replica for which the model is
|
||||
deployed. Usually 0 for the chief replica.
|
||||
num_replicas: Number of replicas to use.
|
||||
num_ps_tasks: Number of tasks for the `ps` job. 0 to not use replicas.
|
||||
worker_job_name: A name for the worker job.
|
||||
ps_job_name: A name for the parameter server job.
|
||||
|
||||
Raises:
|
||||
ValueError: If the arguments are invalid.
|
||||
"""
|
||||
if num_replicas > 1:
|
||||
if num_ps_tasks < 1:
|
||||
raise ValueError('When using replicas num_ps_tasks must be positive')
|
||||
if num_replicas > 1 or num_ps_tasks > 0:
|
||||
if not worker_job_name:
|
||||
raise ValueError('Must specify worker_job_name when using replicas')
|
||||
if not ps_job_name:
|
||||
raise ValueError('Must specify ps_job_name when using parameter server')
|
||||
if replica_id >= num_replicas:
|
||||
raise ValueError('replica_id must be less than num_replicas')
|
||||
self._num_clones = num_clones
|
||||
self._clone_on_cpu = clone_on_cpu
|
||||
self._replica_id = replica_id
|
||||
self._num_replicas = num_replicas
|
||||
self._num_ps_tasks = num_ps_tasks
|
||||
self._ps_device = '/job:' + ps_job_name if num_ps_tasks > 0 else ''
|
||||
self._worker_device = '/job:' + worker_job_name if num_ps_tasks > 0 else ''
|
||||
|
||||
@property
|
||||
def num_clones(self):
|
||||
return self._num_clones
|
||||
|
||||
@property
|
||||
def clone_on_cpu(self):
|
||||
return self._clone_on_cpu
|
||||
|
||||
@property
|
||||
def replica_id(self):
|
||||
return self._replica_id
|
||||
|
||||
@property
|
||||
def num_replicas(self):
|
||||
return self._num_replicas
|
||||
|
||||
@property
|
||||
def num_ps_tasks(self):
|
||||
return self._num_ps_tasks
|
||||
|
||||
@property
|
||||
def ps_device(self):
|
||||
return self._ps_device
|
||||
|
||||
@property
|
||||
def worker_device(self):
|
||||
return self._worker_device
|
||||
|
||||
def caching_device(self):
|
||||
"""Returns the device to use for caching variables.
|
||||
|
||||
Variables are cached on the worker CPU when using replicas.
|
||||
|
||||
Returns:
|
||||
A device string or None if the variables do not need to be cached.
|
||||
"""
|
||||
if self._num_ps_tasks > 0:
|
||||
return lambda op: op.device
|
||||
else:
|
||||
return None
|
||||
|
||||
def clone_device(self, clone_index):
|
||||
"""Device used to create the clone and all the ops inside the clone.
|
||||
|
||||
Args:
|
||||
clone_index: Int, representing the clone_index.
|
||||
|
||||
Returns:
|
||||
A value suitable for `tf.device()`.
|
||||
|
||||
Raises:
|
||||
ValueError: if `clone_index` is greater or equal to the number of clones".
|
||||
"""
|
||||
if clone_index >= self._num_clones:
|
||||
raise ValueError('clone_index must be less than num_clones')
|
||||
device = ''
|
||||
if self._num_ps_tasks > 0:
|
||||
device += self._worker_device
|
||||
if self._clone_on_cpu:
|
||||
device += '/device:CPU:0'
|
||||
else:
|
||||
if self._num_clones > 1:
|
||||
device += '/device:GPU:%d' % clone_index
|
||||
return device
|
||||
|
||||
def clone_scope(self, clone_index):
|
||||
"""Name scope to create the clone.
|
||||
|
||||
Args:
|
||||
clone_index: Int, representing the clone_index.
|
||||
|
||||
Returns:
|
||||
A name_scope suitable for `tf.name_scope()`.
|
||||
|
||||
Raises:
|
||||
ValueError: if `clone_index` is greater or equal to the number of clones".
|
||||
"""
|
||||
if clone_index >= self._num_clones:
|
||||
raise ValueError('clone_index must be less than num_clones')
|
||||
scope = ''
|
||||
if self._num_clones > 1:
|
||||
scope = 'clone_%d' % clone_index
|
||||
return scope
|
||||
|
||||
def optimizer_device(self):
|
||||
"""Device to use with the optimizer.
|
||||
|
||||
Returns:
|
||||
A value suitable for `tf.device()`.
|
||||
"""
|
||||
if self._num_ps_tasks > 0 or self._num_clones > 0:
|
||||
return self._worker_device + '/device:CPU:0'
|
||||
else:
|
||||
return ''
|
||||
|
||||
def inputs_device(self):
|
||||
"""Device to use to build the inputs.
|
||||
|
||||
Returns:
|
||||
A value suitable for `tf.device()`.
|
||||
"""
|
||||
device = ''
|
||||
if self._num_ps_tasks > 0:
|
||||
device += self._worker_device
|
||||
device += '/device:CPU:0'
|
||||
return device
|
||||
|
||||
def variables_device(self):
|
||||
"""Returns the device to use for variables created inside the clone.
|
||||
|
||||
Returns:
|
||||
A value suitable for `tf.device()`.
|
||||
"""
|
||||
device = ''
|
||||
if self._num_ps_tasks > 0:
|
||||
device += self._ps_device
|
||||
device += '/device:CPU:0'
|
||||
|
||||
class _PSDeviceChooser(object):
|
||||
"""Slim device chooser for variables when using PS."""
|
||||
|
||||
def __init__(self, device, tasks):
|
||||
self._device = device
|
||||
self._tasks = tasks
|
||||
self._task = 0
|
||||
|
||||
def choose(self, op):
|
||||
if op.device:
|
||||
return op.device
|
||||
node_def = op if isinstance(op, tf.NodeDef) else op.node_def
|
||||
if node_def.op == 'Variable':
|
||||
t = self._task
|
||||
self._task = (self._task + 1) % self._tasks
|
||||
d = '%s/task:%d' % (self._device, t)
|
||||
return d
|
||||
else:
|
||||
return op.device
|
||||
|
||||
if not self._num_ps_tasks:
|
||||
return device
|
||||
else:
|
||||
chooser = _PSDeviceChooser(device, self._num_ps_tasks)
|
||||
return chooser.choose
|
|
@ -0,0 +1,565 @@
|
|||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for model_deploy."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from deployment import model_deploy
|
||||
|
||||
slim = tf.contrib.slim
|
||||
|
||||
|
||||
class DeploymentConfigTest(tf.test.TestCase):
|
||||
|
||||
def testDefaults(self):
|
||||
deploy_config = model_deploy.DeploymentConfig()
|
||||
|
||||
self.assertEqual(slim.get_variables(), [])
|
||||
self.assertEqual(deploy_config.caching_device(), None)
|
||||
self.assertDeviceEqual(deploy_config.clone_device(0), '')
|
||||
self.assertEqual(deploy_config.clone_scope(0), '')
|
||||
self.assertDeviceEqual(deploy_config.optimizer_device(), 'CPU:0')
|
||||
self.assertDeviceEqual(deploy_config.inputs_device(), 'CPU:0')
|
||||
self.assertDeviceEqual(deploy_config.variables_device(), 'CPU:0')
|
||||
|
||||
def testCPUonly(self):
|
||||
deploy_config = model_deploy.DeploymentConfig(clone_on_cpu=True)
|
||||
|
||||
self.assertEqual(deploy_config.caching_device(), None)
|
||||
self.assertDeviceEqual(deploy_config.clone_device(0), 'CPU:0')
|
||||
self.assertEqual(deploy_config.clone_scope(0), '')
|
||||
self.assertDeviceEqual(deploy_config.optimizer_device(), 'CPU:0')
|
||||
self.assertDeviceEqual(deploy_config.inputs_device(), 'CPU:0')
|
||||
self.assertDeviceEqual(deploy_config.variables_device(), 'CPU:0')
|
||||
|
||||
def testMultiGPU(self):
|
||||
deploy_config = model_deploy.DeploymentConfig(num_clones=2)
|
||||
|
||||
self.assertEqual(deploy_config.caching_device(), None)
|
||||
self.assertDeviceEqual(deploy_config.clone_device(0), 'GPU:0')
|
||||
self.assertDeviceEqual(deploy_config.clone_device(1), 'GPU:1')
|
||||
self.assertEqual(deploy_config.clone_scope(0), 'clone_0')
|
||||
self.assertEqual(deploy_config.clone_scope(1), 'clone_1')
|
||||
self.assertDeviceEqual(deploy_config.optimizer_device(), 'CPU:0')
|
||||
self.assertDeviceEqual(deploy_config.inputs_device(), 'CPU:0')
|
||||
self.assertDeviceEqual(deploy_config.variables_device(), 'CPU:0')
|
||||
|
||||
def testPS(self):
|
||||
deploy_config = model_deploy.DeploymentConfig(num_clones=1, num_ps_tasks=1)
|
||||
|
||||
self.assertDeviceEqual(deploy_config.clone_device(0),
|
||||
'/job:worker')
|
||||
self.assertEqual(deploy_config.clone_scope(0), '')
|
||||
self.assertDeviceEqual(deploy_config.optimizer_device(),
|
||||
'/job:worker/device:CPU:0')
|
||||
self.assertDeviceEqual(deploy_config.inputs_device(),
|
||||
'/job:worker/device:CPU:0')
|
||||
with tf.device(deploy_config.variables_device()):
|
||||
a = tf.Variable(0)
|
||||
b = tf.Variable(0)
|
||||
c = tf.no_op()
|
||||
d = slim.variable('a', [],
|
||||
caching_device=deploy_config.caching_device())
|
||||
self.assertDeviceEqual(a.device, '/job:ps/task:0/device:CPU:0')
|
||||
self.assertDeviceEqual(a.device, a.value().device)
|
||||
self.assertDeviceEqual(b.device, '/job:ps/task:0/device:CPU:0')
|
||||
self.assertDeviceEqual(b.device, b.value().device)
|
||||
self.assertDeviceEqual(c.device, '')
|
||||
self.assertDeviceEqual(d.device, '/job:ps/task:0/device:CPU:0')
|
||||
self.assertDeviceEqual(d.value().device, '')
|
||||
|
||||
def testMultiGPUPS(self):
|
||||
deploy_config = model_deploy.DeploymentConfig(num_clones=2, num_ps_tasks=1)
|
||||
|
||||
self.assertEqual(deploy_config.caching_device()(tf.no_op()), '')
|
||||
self.assertDeviceEqual(deploy_config.clone_device(0),
|
||||
'/job:worker/device:GPU:0')
|
||||
self.assertDeviceEqual(deploy_config.clone_device(1),
|
||||
'/job:worker/device:GPU:1')
|
||||
self.assertEqual(deploy_config.clone_scope(0), 'clone_0')
|
||||
self.assertEqual(deploy_config.clone_scope(1), 'clone_1')
|
||||
self.assertDeviceEqual(deploy_config.optimizer_device(),
|
||||
'/job:worker/device:CPU:0')
|
||||
self.assertDeviceEqual(deploy_config.inputs_device(),
|
||||
'/job:worker/device:CPU:0')
|
||||
|
||||
def testReplicasPS(self):
|
||||
deploy_config = model_deploy.DeploymentConfig(num_replicas=2,
|
||||
num_ps_tasks=2)
|
||||
|
||||
self.assertDeviceEqual(deploy_config.clone_device(0),
|
||||
'/job:worker')
|
||||
self.assertEqual(deploy_config.clone_scope(0), '')
|
||||
self.assertDeviceEqual(deploy_config.optimizer_device(),
|
||||
'/job:worker/device:CPU:0')
|
||||
self.assertDeviceEqual(deploy_config.inputs_device(),
|
||||
'/job:worker/device:CPU:0')
|
||||
|
||||
def testReplicasMultiGPUPS(self):
|
||||
deploy_config = model_deploy.DeploymentConfig(num_replicas=2,
|
||||
num_clones=2,
|
||||
num_ps_tasks=2)
|
||||
self.assertDeviceEqual(deploy_config.clone_device(0),
|
||||
'/job:worker/device:GPU:0')
|
||||
self.assertDeviceEqual(deploy_config.clone_device(1),
|
||||
'/job:worker/device:GPU:1')
|
||||
self.assertEqual(deploy_config.clone_scope(0), 'clone_0')
|
||||
self.assertEqual(deploy_config.clone_scope(1), 'clone_1')
|
||||
self.assertDeviceEqual(deploy_config.optimizer_device(),
|
||||
'/job:worker/device:CPU:0')
|
||||
self.assertDeviceEqual(deploy_config.inputs_device(),
|
||||
'/job:worker/device:CPU:0')
|
||||
|
||||
def testVariablesPS(self):
|
||||
deploy_config = model_deploy.DeploymentConfig(num_ps_tasks=2)
|
||||
|
||||
with tf.device(deploy_config.variables_device()):
|
||||
a = tf.Variable(0)
|
||||
b = tf.Variable(0)
|
||||
c = tf.no_op()
|
||||
d = slim.variable('a', [],
|
||||
caching_device=deploy_config.caching_device())
|
||||
|
||||
self.assertDeviceEqual(a.device, '/job:ps/task:0/device:CPU:0')
|
||||
self.assertDeviceEqual(a.device, a.value().device)
|
||||
self.assertDeviceEqual(b.device, '/job:ps/task:1/device:CPU:0')
|
||||
self.assertDeviceEqual(b.device, b.value().device)
|
||||
self.assertDeviceEqual(c.device, '')
|
||||
self.assertDeviceEqual(d.device, '/job:ps/task:0/device:CPU:0')
|
||||
self.assertDeviceEqual(d.value().device, '')
|
||||
|
||||
|
||||
def LogisticClassifier(inputs, labels, scope=None, reuse=None):
|
||||
with tf.variable_scope(scope, 'LogisticClassifier', [inputs, labels],
|
||||
reuse=reuse):
|
||||
predictions = slim.fully_connected(inputs, 1, activation_fn=tf.sigmoid,
|
||||
scope='fully_connected')
|
||||
slim.losses.log_loss(predictions, labels)
|
||||
return predictions
|
||||
|
||||
|
||||
def BatchNormClassifier(inputs, labels, scope=None, reuse=None):
|
||||
with tf.variable_scope(scope, 'BatchNormClassifier', [inputs, labels],
|
||||
reuse=reuse):
|
||||
inputs = slim.batch_norm(inputs, decay=0.1)
|
||||
predictions = slim.fully_connected(inputs, 1,
|
||||
activation_fn=tf.sigmoid,
|
||||
scope='fully_connected')
|
||||
slim.losses.log_loss(predictions, labels)
|
||||
return predictions
|
||||
|
||||
|
||||
class CreatecloneTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Create an easy training set:
|
||||
np.random.seed(0)
|
||||
|
||||
self._inputs = np.zeros((16, 4))
|
||||
self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
|
||||
self._logdir = self.get_temp_dir()
|
||||
|
||||
for i in range(16):
|
||||
j = int(2 * self._labels[i] + np.random.randint(0, 2))
|
||||
self._inputs[i, j] = 1
|
||||
|
||||
def testCreateLogisticClassifier(self):
|
||||
g = tf.Graph()
|
||||
with g.as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
model_fn = LogisticClassifier
|
||||
clone_args = (tf_inputs, tf_labels)
|
||||
deploy_config = model_deploy.DeploymentConfig(num_clones=1)
|
||||
|
||||
self.assertEqual(slim.get_variables(), [])
|
||||
clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
|
||||
clone = clones[0]
|
||||
self.assertEqual(len(slim.get_variables()), 2)
|
||||
for v in slim.get_variables():
|
||||
self.assertDeviceEqual(v.device, 'CPU:0')
|
||||
self.assertDeviceEqual(v.value().device, 'CPU:0')
|
||||
self.assertEqual(clone.outputs.op.name,
|
||||
'LogisticClassifier/fully_connected/Sigmoid')
|
||||
self.assertEqual(clone.scope, '')
|
||||
self.assertDeviceEqual(clone.device, '')
|
||||
self.assertEqual(len(slim.losses.get_losses()), 1)
|
||||
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
|
||||
self.assertEqual(update_ops, [])
|
||||
|
||||
def testCreateSingleclone(self):
|
||||
g = tf.Graph()
|
||||
with g.as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
model_fn = BatchNormClassifier
|
||||
clone_args = (tf_inputs, tf_labels)
|
||||
deploy_config = model_deploy.DeploymentConfig(num_clones=1)
|
||||
|
||||
self.assertEqual(slim.get_variables(), [])
|
||||
clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
|
||||
clone = clones[0]
|
||||
self.assertEqual(len(slim.get_variables()), 5)
|
||||
for v in slim.get_variables():
|
||||
self.assertDeviceEqual(v.device, 'CPU:0')
|
||||
self.assertDeviceEqual(v.value().device, 'CPU:0')
|
||||
self.assertEqual(clone.outputs.op.name,
|
||||
'BatchNormClassifier/fully_connected/Sigmoid')
|
||||
self.assertEqual(clone.scope, '')
|
||||
self.assertDeviceEqual(clone.device, '')
|
||||
self.assertEqual(len(slim.losses.get_losses()), 1)
|
||||
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
|
||||
self.assertEqual(len(update_ops), 2)
|
||||
|
||||
def testCreateMulticlone(self):
|
||||
g = tf.Graph()
|
||||
with g.as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
model_fn = BatchNormClassifier
|
||||
clone_args = (tf_inputs, tf_labels)
|
||||
num_clones = 4
|
||||
deploy_config = model_deploy.DeploymentConfig(num_clones=num_clones)
|
||||
|
||||
self.assertEqual(slim.get_variables(), [])
|
||||
clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
|
||||
self.assertEqual(len(slim.get_variables()), 5)
|
||||
for v in slim.get_variables():
|
||||
self.assertDeviceEqual(v.device, 'CPU:0')
|
||||
self.assertDeviceEqual(v.value().device, 'CPU:0')
|
||||
self.assertEqual(len(clones), num_clones)
|
||||
for i, clone in enumerate(clones):
|
||||
self.assertEqual(
|
||||
clone.outputs.op.name,
|
||||
'clone_%d/BatchNormClassifier/fully_connected/Sigmoid' % i)
|
||||
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, clone.scope)
|
||||
self.assertEqual(len(update_ops), 2)
|
||||
self.assertEqual(clone.scope, 'clone_%d/' % i)
|
||||
self.assertDeviceEqual(clone.device, 'GPU:%d' % i)
|
||||
|
||||
def testCreateOnecloneWithPS(self):
|
||||
g = tf.Graph()
|
||||
with g.as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
model_fn = BatchNormClassifier
|
||||
clone_args = (tf_inputs, tf_labels)
|
||||
deploy_config = model_deploy.DeploymentConfig(num_clones=1,
|
||||
num_ps_tasks=1)
|
||||
|
||||
self.assertEqual(slim.get_variables(), [])
|
||||
clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
|
||||
self.assertEqual(len(clones), 1)
|
||||
clone = clones[0]
|
||||
self.assertEqual(clone.outputs.op.name,
|
||||
'BatchNormClassifier/fully_connected/Sigmoid')
|
||||
self.assertDeviceEqual(clone.device, '/job:worker')
|
||||
self.assertEqual(clone.scope, '')
|
||||
self.assertEqual(len(slim.get_variables()), 5)
|
||||
for v in slim.get_variables():
|
||||
self.assertDeviceEqual(v.device, '/job:ps/task:0/CPU:0')
|
||||
self.assertDeviceEqual(v.device, v.value().device)
|
||||
|
||||
def testCreateMulticloneWithPS(self):
|
||||
g = tf.Graph()
|
||||
with g.as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
model_fn = BatchNormClassifier
|
||||
clone_args = (tf_inputs, tf_labels)
|
||||
deploy_config = model_deploy.DeploymentConfig(num_clones=2,
|
||||
num_ps_tasks=2)
|
||||
|
||||
self.assertEqual(slim.get_variables(), [])
|
||||
clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
|
||||
self.assertEqual(len(slim.get_variables()), 5)
|
||||
for i, v in enumerate(slim.get_variables()):
|
||||
t = i % 2
|
||||
self.assertDeviceEqual(v.device, '/job:ps/task:%d/device:CPU:0' % t)
|
||||
self.assertDeviceEqual(v.device, v.value().device)
|
||||
self.assertEqual(len(clones), 2)
|
||||
for i, clone in enumerate(clones):
|
||||
self.assertEqual(
|
||||
clone.outputs.op.name,
|
||||
'clone_%d/BatchNormClassifier/fully_connected/Sigmoid' % i)
|
||||
self.assertEqual(clone.scope, 'clone_%d/' % i)
|
||||
self.assertDeviceEqual(clone.device, '/job:worker/device:GPU:%d' % i)
|
||||
|
||||
|
||||
class OptimizeclonesTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Create an easy training set:
|
||||
np.random.seed(0)
|
||||
|
||||
self._inputs = np.zeros((16, 4))
|
||||
self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
|
||||
self._logdir = self.get_temp_dir()
|
||||
|
||||
for i in range(16):
|
||||
j = int(2 * self._labels[i] + np.random.randint(0, 2))
|
||||
self._inputs[i, j] = 1
|
||||
|
||||
def testCreateLogisticClassifier(self):
|
||||
g = tf.Graph()
|
||||
with g.as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
model_fn = LogisticClassifier
|
||||
clone_args = (tf_inputs, tf_labels)
|
||||
deploy_config = model_deploy.DeploymentConfig(num_clones=1)
|
||||
|
||||
self.assertEqual(slim.get_variables(), [])
|
||||
clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
|
||||
self.assertEqual(len(slim.get_variables()), 2)
|
||||
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
|
||||
self.assertEqual(update_ops, [])
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
total_loss, grads_and_vars = model_deploy.optimize_clones(clones,
|
||||
optimizer)
|
||||
self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
|
||||
self.assertEqual(total_loss.op.name, 'total_loss')
|
||||
for g, v in grads_and_vars:
|
||||
self.assertDeviceEqual(g.device, '')
|
||||
self.assertDeviceEqual(v.device, 'CPU:0')
|
||||
|
||||
def testCreateSingleclone(self):
|
||||
g = tf.Graph()
|
||||
with g.as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
model_fn = BatchNormClassifier
|
||||
clone_args = (tf_inputs, tf_labels)
|
||||
deploy_config = model_deploy.DeploymentConfig(num_clones=1)
|
||||
|
||||
self.assertEqual(slim.get_variables(), [])
|
||||
clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
|
||||
self.assertEqual(len(slim.get_variables()), 5)
|
||||
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
|
||||
self.assertEqual(len(update_ops), 2)
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
total_loss, grads_and_vars = model_deploy.optimize_clones(clones,
|
||||
optimizer)
|
||||
self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
|
||||
self.assertEqual(total_loss.op.name, 'total_loss')
|
||||
for g, v in grads_and_vars:
|
||||
self.assertDeviceEqual(g.device, '')
|
||||
self.assertDeviceEqual(v.device, 'CPU:0')
|
||||
|
||||
def testCreateMulticlone(self):
|
||||
g = tf.Graph()
|
||||
with g.as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
model_fn = BatchNormClassifier
|
||||
clone_args = (tf_inputs, tf_labels)
|
||||
num_clones = 4
|
||||
deploy_config = model_deploy.DeploymentConfig(num_clones=num_clones)
|
||||
|
||||
self.assertEqual(slim.get_variables(), [])
|
||||
clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
|
||||
self.assertEqual(len(slim.get_variables()), 5)
|
||||
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
|
||||
self.assertEqual(len(update_ops), num_clones * 2)
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
total_loss, grads_and_vars = model_deploy.optimize_clones(clones,
|
||||
optimizer)
|
||||
self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
|
||||
self.assertEqual(total_loss.op.name, 'total_loss')
|
||||
for g, v in grads_and_vars:
|
||||
self.assertDeviceEqual(g.device, '')
|
||||
self.assertDeviceEqual(v.device, 'CPU:0')
|
||||
|
||||
def testCreateMulticloneCPU(self):
|
||||
g = tf.Graph()
|
||||
with g.as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
model_fn = BatchNormClassifier
|
||||
model_args = (tf_inputs, tf_labels)
|
||||
num_clones = 4
|
||||
deploy_config = model_deploy.DeploymentConfig(num_clones=num_clones,
|
||||
clone_on_cpu=True)
|
||||
|
||||
self.assertEqual(slim.get_variables(), [])
|
||||
clones = model_deploy.create_clones(deploy_config, model_fn, model_args)
|
||||
self.assertEqual(len(slim.get_variables()), 5)
|
||||
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
|
||||
self.assertEqual(len(update_ops), num_clones * 2)
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
total_loss, grads_and_vars = model_deploy.optimize_clones(clones,
|
||||
optimizer)
|
||||
self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
|
||||
self.assertEqual(total_loss.op.name, 'total_loss')
|
||||
for g, v in grads_and_vars:
|
||||
self.assertDeviceEqual(g.device, '')
|
||||
self.assertDeviceEqual(v.device, 'CPU:0')
|
||||
|
||||
def testCreateOnecloneWithPS(self):
|
||||
g = tf.Graph()
|
||||
with g.as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
model_fn = BatchNormClassifier
|
||||
model_args = (tf_inputs, tf_labels)
|
||||
deploy_config = model_deploy.DeploymentConfig(num_clones=1,
|
||||
num_ps_tasks=1)
|
||||
|
||||
self.assertEqual(slim.get_variables(), [])
|
||||
clones = model_deploy.create_clones(deploy_config, model_fn, model_args)
|
||||
self.assertEqual(len(slim.get_variables()), 5)
|
||||
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
|
||||
self.assertEqual(len(update_ops), 2)
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
total_loss, grads_and_vars = model_deploy.optimize_clones(clones,
|
||||
optimizer)
|
||||
self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
|
||||
self.assertEqual(total_loss.op.name, 'total_loss')
|
||||
for g, v in grads_and_vars:
|
||||
self.assertDeviceEqual(g.device, '/job:worker')
|
||||
self.assertDeviceEqual(v.device, '/job:ps/task:0/CPU:0')
|
||||
|
||||
|
||||
class DeployTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Create an easy training set:
|
||||
np.random.seed(0)
|
||||
|
||||
self._inputs = np.zeros((16, 4))
|
||||
self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
|
||||
self._logdir = self.get_temp_dir()
|
||||
|
||||
for i in range(16):
|
||||
j = int(2 * self._labels[i] + np.random.randint(0, 2))
|
||||
self._inputs[i, j] = 1
|
||||
|
||||
def testLocalTrainOp(self):
|
||||
g = tf.Graph()
|
||||
with g.as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
model_fn = BatchNormClassifier
|
||||
model_args = (tf_inputs, tf_labels)
|
||||
deploy_config = model_deploy.DeploymentConfig(num_clones=2,
|
||||
clone_on_cpu=True)
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
self.assertEqual(slim.get_variables(), [])
|
||||
model = model_deploy.deploy(deploy_config, model_fn, model_args,
|
||||
optimizer=optimizer)
|
||||
|
||||
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
|
||||
self.assertEqual(len(update_ops), 4)
|
||||
self.assertEqual(len(model.clones), 2)
|
||||
self.assertEqual(model.total_loss.op.name, 'total_loss')
|
||||
self.assertEqual(model.summary_op.op.name, 'summary_op/summary_op')
|
||||
self.assertEqual(model.train_op.op.name, 'train_op')
|
||||
|
||||
with tf.Session() as sess:
|
||||
sess.run(tf.initialize_all_variables())
|
||||
moving_mean = tf.contrib.framework.get_variables_by_name(
|
||||
'moving_mean')[0]
|
||||
moving_variance = tf.contrib.framework.get_variables_by_name(
|
||||
'moving_variance')[0]
|
||||
initial_loss = sess.run(model.total_loss)
|
||||
initial_mean, initial_variance = sess.run([moving_mean,
|
||||
moving_variance])
|
||||
self.assertAllClose(initial_mean, [0.0, 0.0, 0.0, 0.0])
|
||||
self.assertAllClose(initial_variance, [1.0, 1.0, 1.0, 1.0])
|
||||
for _ in range(10):
|
||||
sess.run(model.train_op)
|
||||
final_loss = sess.run(model.total_loss)
|
||||
self.assertLess(final_loss, initial_loss / 10.0)
|
||||
|
||||
final_mean, final_variance = sess.run([moving_mean,
|
||||
moving_variance])
|
||||
self.assertAllClose(final_mean, [0.125, 0.25, 0.375, 0.25])
|
||||
self.assertAllClose(final_variance, [0.109375, 0.1875,
|
||||
0.234375, 0.1875])
|
||||
|
||||
def testNoSummariesOnGPU(self):
|
||||
with tf.Graph().as_default():
|
||||
deploy_config = model_deploy.DeploymentConfig(num_clones=2)
|
||||
|
||||
# clone function creates a fully_connected layer with a regularizer loss.
|
||||
def ModelFn():
|
||||
inputs = tf.constant(1.0, shape=(10, 20), dtype=tf.float32)
|
||||
reg = tf.contrib.layers.l2_regularizer(0.001)
|
||||
tf.contrib.layers.fully_connected(inputs, 30, weights_regularizer=reg)
|
||||
|
||||
model = model_deploy.deploy(
|
||||
deploy_config, ModelFn,
|
||||
optimizer=tf.train.GradientDescentOptimizer(1.0))
|
||||
# The model summary op should have a few summary inputs and all of them
|
||||
# should be on the CPU.
|
||||
self.assertTrue(model.summary_op.op.inputs)
|
||||
for inp in model.summary_op.op.inputs:
|
||||
self.assertEqual('/device:CPU:0', inp.device)
|
||||
|
||||
def testNoSummariesOnGPUForEvals(self):
|
||||
with tf.Graph().as_default():
|
||||
deploy_config = model_deploy.DeploymentConfig(num_clones=2)
|
||||
|
||||
# clone function creates a fully_connected layer with a regularizer loss.
|
||||
def ModelFn():
|
||||
inputs = tf.constant(1.0, shape=(10, 20), dtype=tf.float32)
|
||||
reg = tf.contrib.layers.l2_regularizer(0.001)
|
||||
tf.contrib.layers.fully_connected(inputs, 30, weights_regularizer=reg)
|
||||
|
||||
# No optimizer here, it's an eval.
|
||||
model = model_deploy.deploy(deploy_config, ModelFn)
|
||||
# The model summary op should have a few summary inputs and all of them
|
||||
# should be on the CPU.
|
||||
self.assertTrue(model.summary_op.op.inputs)
|
||||
for inp in model.summary_op.op.inputs:
|
||||
self.assertEqual('/device:CPU:0', inp.device)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -0,0 +1 @@
|
|||
|
|
@ -0,0 +1,109 @@
|
|||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Contains a factory for building various models."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import functools
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from nets import alexnet
|
||||
from nets import cifarnet
|
||||
from nets import inception
|
||||
from nets import lenet
|
||||
from nets import overfeat
|
||||
from nets import resnet_v1
|
||||
from nets import resnet_v2
|
||||
from nets import vgg
|
||||
|
||||
slim = tf.contrib.slim
|
||||
|
||||
networks_map = {'alexnet_v2': alexnet.alexnet_v2,
|
||||
'cifarnet': cifarnet.cifarnet,
|
||||
'overfeat': overfeat.overfeat,
|
||||
'vgg_a': vgg.vgg_a,
|
||||
'vgg_16': vgg.vgg_16,
|
||||
'vgg_19': vgg.vgg_19,
|
||||
'inception_v1': inception.inception_v1,
|
||||
'inception_v2': inception.inception_v2,
|
||||
'inception_v3': inception.inception_v3,
|
||||
'inception_v4': inception.inception_v4,
|
||||
'inception_resnet_v2': inception.inception_resnet_v2,
|
||||
'lenet': lenet.lenet,
|
||||
'resnet_v1_50': resnet_v1.resnet_v1_50,
|
||||
'resnet_v1_101': resnet_v1.resnet_v1_101,
|
||||
'resnet_v1_152': resnet_v1.resnet_v1_152,
|
||||
'resnet_v1_200': resnet_v1.resnet_v1_200,
|
||||
'resnet_v2_50': resnet_v2.resnet_v2_50,
|
||||
'resnet_v2_101': resnet_v2.resnet_v2_101,
|
||||
'resnet_v2_152': resnet_v2.resnet_v2_152,
|
||||
'resnet_v2_200': resnet_v2.resnet_v2_200,
|
||||
}
|
||||
|
||||
arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope,
|
||||
'cifarnet': cifarnet.cifarnet_arg_scope,
|
||||
'overfeat': overfeat.overfeat_arg_scope,
|
||||
'vgg_a': vgg.vgg_arg_scope,
|
||||
'vgg_16': vgg.vgg_arg_scope,
|
||||
'vgg_19': vgg.vgg_arg_scope,
|
||||
'inception_v1': inception.inception_v3_arg_scope,
|
||||
'inception_v2': inception.inception_v3_arg_scope,
|
||||
'inception_v3': inception.inception_v3_arg_scope,
|
||||
'inception_v4': inception.inception_v4_arg_scope,
|
||||
'inception_resnet_v2':
|
||||
inception.inception_resnet_v2_arg_scope,
|
||||
'lenet': lenet.lenet_arg_scope,
|
||||
'resnet_v1_50': resnet_v1.resnet_arg_scope,
|
||||
'resnet_v1_101': resnet_v1.resnet_arg_scope,
|
||||
'resnet_v1_152': resnet_v1.resnet_arg_scope,
|
||||
'resnet_v1_200': resnet_v1.resnet_arg_scope,
|
||||
'resnet_v2_50': resnet_v2.resnet_arg_scope,
|
||||
'resnet_v2_101': resnet_v2.resnet_arg_scope,
|
||||
'resnet_v2_152': resnet_v2.resnet_arg_scope,
|
||||
'resnet_v2_200': resnet_v2.resnet_arg_scope,
|
||||
}
|
||||
|
||||
|
||||
def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False):
|
||||
"""Returns a network_fn such as `logits, end_points = network_fn(images)`.
|
||||
|
||||
Args:
|
||||
name: The name of the network.
|
||||
num_classes: The number of classes to use for classification.
|
||||
weight_decay: The l2 coefficient for the model weights.
|
||||
is_training: `True` if the model is being used for training and `False`
|
||||
otherwise.
|
||||
|
||||
Returns:
|
||||
network_fn: A function that applies the model to a batch of images. It has
|
||||
the following signature:
|
||||
logits, end_points = network_fn(images)
|
||||
Raises:
|
||||
ValueError: If network `name` is not recognized.
|
||||
"""
|
||||
if name not in networks_map:
|
||||
raise ValueError('Name of network unknown %s' % name)
|
||||
arg_scope = arg_scopes_map[name](weight_decay=weight_decay)
|
||||
func = networks_map[name]
|
||||
@functools.wraps(func)
|
||||
def network_fn(images):
|
||||
with slim.arg_scope(arg_scope):
|
||||
return func(images, num_classes, is_training=is_training)
|
||||
if hasattr(func, 'default_image_size'):
|
||||
network_fn.default_image_size = func.default_image_size
|
||||
|
||||
return network_fn
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for slim.inception."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from nets import nets_factory
|
||||
|
||||
|
||||
class NetworksTest(tf.test.TestCase):
|
||||
|
||||
def testGetNetworkFn(self):
|
||||
batch_size = 5
|
||||
num_classes = 1000
|
||||
for net in nets_factory.networks_map:
|
||||
with self.test_session():
|
||||
net_fn = nets_factory.get_network_fn(net, num_classes)
|
||||
# Most networks use 224 as their default_image_size
|
||||
image_size = getattr(net_fn, 'default_image_size', 224)
|
||||
inputs = tf.random_uniform((batch_size, image_size, image_size, 3))
|
||||
logits, end_points = net_fn(inputs)
|
||||
self.assertTrue(isinstance(logits, tf.Tensor))
|
||||
self.assertTrue(isinstance(end_points, dict))
|
||||
self.assertEqual(logits.get_shape().as_list()[0], batch_size)
|
||||
self.assertEqual(logits.get_shape().as_list()[-1], num_classes)
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -0,0 +1,254 @@
|
|||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Contains building blocks for various versions of Residual Networks.
|
||||
|
||||
Residual networks (ResNets) were proposed in:
|
||||
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
|
||||
Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015
|
||||
|
||||
More variants were introduced in:
|
||||
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
|
||||
Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016
|
||||
|
||||
We can obtain different ResNet variants by changing the network depth, width,
|
||||
and form of residual unit. This module implements the infrastructure for
|
||||
building them. Concrete ResNet units and full ResNet networks are implemented in
|
||||
the accompanying resnet_v1.py and resnet_v2.py modules.
|
||||
|
||||
Compared to https://github.com/KaimingHe/deep-residual-networks, in the current
|
||||
implementation we subsample the output activations in the last residual unit of
|
||||
each block, instead of subsampling the input activations in the first residual
|
||||
unit of each block. The two implementations give identical results but our
|
||||
implementation is more memory efficient.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import tensorflow as tf
|
||||
|
||||
slim = tf.contrib.slim
|
||||
|
||||
|
||||
class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])):
|
||||
"""A named tuple describing a ResNet block.
|
||||
|
||||
Its parts are:
|
||||
scope: The scope of the `Block`.
|
||||
unit_fn: The ResNet unit function which takes as input a `Tensor` and
|
||||
returns another `Tensor` with the output of the ResNet unit.
|
||||
args: A list of length equal to the number of units in the `Block`. The list
|
||||
contains one (depth, depth_bottleneck, stride) tuple for each unit in the
|
||||
block to serve as argument to unit_fn.
|
||||
"""
|
||||
|
||||
|
||||
def subsample(inputs, factor, scope=None):
|
||||
"""Subsamples the input along the spatial dimensions.
|
||||
|
||||
Args:
|
||||
inputs: A `Tensor` of size [batch, height_in, width_in, channels].
|
||||
factor: The subsampling factor.
|
||||
scope: Optional variable_scope.
|
||||
|
||||
Returns:
|
||||
output: A `Tensor` of size [batch, height_out, width_out, channels] with the
|
||||
input, either intact (if factor == 1) or subsampled (if factor > 1).
|
||||
"""
|
||||
if factor == 1:
|
||||
return inputs
|
||||
else:
|
||||
return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope)
|
||||
|
||||
|
||||
def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None):
|
||||
"""Strided 2-D convolution with 'SAME' padding.
|
||||
|
||||
When stride > 1, then we do explicit zero-padding, followed by conv2d with
|
||||
'VALID' padding.
|
||||
|
||||
Note that
|
||||
|
||||
net = conv2d_same(inputs, num_outputs, 3, stride=stride)
|
||||
|
||||
is equivalent to
|
||||
|
||||
net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME')
|
||||
net = subsample(net, factor=stride)
|
||||
|
||||
whereas
|
||||
|
||||
net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME')
|
||||
|
||||
is different when the input's height or width is even, which is why we add the
|
||||
current function. For more details, see ResnetUtilsTest.testConv2DSameEven().
|
||||
|
||||
Args:
|
||||
inputs: A 4-D tensor of size [batch, height_in, width_in, channels].
|
||||
num_outputs: An integer, the number of output filters.
|
||||
kernel_size: An int with the kernel_size of the filters.
|
||||
stride: An integer, the output stride.
|
||||
rate: An integer, rate for atrous convolution.
|
||||
scope: Scope.
|
||||
|
||||
Returns:
|
||||
output: A 4-D tensor of size [batch, height_out, width_out, channels] with
|
||||
the convolution output.
|
||||
"""
|
||||
if stride == 1:
|
||||
return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate,
|
||||
padding='SAME', scope=scope)
|
||||
else:
|
||||
kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
|
||||
pad_total = kernel_size_effective - 1
|
||||
pad_beg = pad_total // 2
|
||||
pad_end = pad_total - pad_beg
|
||||
inputs = tf.pad(inputs,
|
||||
[[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])
|
||||
return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride,
|
||||
rate=rate, padding='VALID', scope=scope)
|
||||
|
||||
|
||||
@slim.add_arg_scope
|
||||
def stack_blocks_dense(net, blocks, output_stride=None,
|
||||
outputs_collections=None):
|
||||
"""Stacks ResNet `Blocks` and controls output feature density.
|
||||
|
||||
First, this function creates scopes for the ResNet in the form of
|
||||
'block_name/unit_1', 'block_name/unit_2', etc.
|
||||
|
||||
Second, this function allows the user to explicitly control the ResNet
|
||||
output_stride, which is the ratio of the input to output spatial resolution.
|
||||
This is useful for dense prediction tasks such as semantic segmentation or
|
||||
object detection.
|
||||
|
||||
Most ResNets consist of 4 ResNet blocks and subsample the activations by a
|
||||
factor of 2 when transitioning between consecutive ResNet blocks. This results
|
||||
to a nominal ResNet output_stride equal to 8. If we set the output_stride to
|
||||
half the nominal network stride (e.g., output_stride=4), then we compute
|
||||
responses twice.
|
||||
|
||||
Control of the output feature density is implemented by atrous convolution.
|
||||
|
||||
Args:
|
||||
net: A `Tensor` of size [batch, height, width, channels].
|
||||
blocks: A list of length equal to the number of ResNet `Blocks`. Each
|
||||
element is a ResNet `Block` object describing the units in the `Block`.
|
||||
output_stride: If `None`, then the output will be computed at the nominal
|
||||
network stride. If output_stride is not `None`, it specifies the requested
|
||||
ratio of input to output spatial resolution, which needs to be equal to
|
||||
the product of unit strides from the start up to some level of the ResNet.
|
||||
For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1,
|
||||
then valid values for the output_stride are 1, 2, 6, 24 or None (which
|
||||
is equivalent to output_stride=24).
|
||||
outputs_collections: Collection to add the ResNet block outputs.
|
||||
|
||||
Returns:
|
||||
net: Output tensor with stride equal to the specified output_stride.
|
||||
|
||||
Raises:
|
||||
ValueError: If the target output_stride is not valid.
|
||||
"""
|
||||
# The current_stride variable keeps track of the effective stride of the
|
||||
# activations. This allows us to invoke atrous convolution whenever applying
|
||||
# the next residual unit would result in the activations having stride larger
|
||||
# than the target output_stride.
|
||||
current_stride = 1
|
||||
|
||||
# The atrous convolution rate parameter.
|
||||
rate = 1
|
||||
|
||||
for block in blocks:
|
||||
with tf.variable_scope(block.scope, 'block', [net]) as sc:
|
||||
for i, unit in enumerate(block.args):
|
||||
if output_stride is not None and current_stride > output_stride:
|
||||
raise ValueError('The target output_stride cannot be reached.')
|
||||
|
||||
with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
|
||||
unit_depth, unit_depth_bottleneck, unit_stride = unit
|
||||
|
||||
# If we have reached the target output_stride, then we need to employ
|
||||
# atrous convolution with stride=1 and multiply the atrous rate by the
|
||||
# current unit's stride for use in subsequent layers.
|
||||
if output_stride is not None and current_stride == output_stride:
|
||||
net = block.unit_fn(net,
|
||||
depth=unit_depth,
|
||||
depth_bottleneck=unit_depth_bottleneck,
|
||||
stride=1,
|
||||
rate=rate)
|
||||
rate *= unit_stride
|
||||
|
||||
else:
|
||||
net = block.unit_fn(net,
|
||||
depth=unit_depth,
|
||||
depth_bottleneck=unit_depth_bottleneck,
|
||||
stride=unit_stride,
|
||||
rate=1)
|
||||
current_stride *= unit_stride
|
||||
net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)
|
||||
|
||||
if output_stride is not None and current_stride != output_stride:
|
||||
raise ValueError('The target output_stride cannot be reached.')
|
||||
|
||||
return net
|
||||
|
||||
|
||||
def resnet_arg_scope(weight_decay=0.0001,
|
||||
batch_norm_decay=0.997,
|
||||
batch_norm_epsilon=1e-5,
|
||||
batch_norm_scale=True):
|
||||
"""Defines the default ResNet arg scope.
|
||||
|
||||
TODO(gpapan): The batch-normalization related default values above are
|
||||
appropriate for use in conjunction with the reference ResNet models
|
||||
released at https://github.com/KaimingHe/deep-residual-networks. When
|
||||
training ResNets from scratch, they might need to be tuned.
|
||||
|
||||
Args:
|
||||
weight_decay: The weight decay to use for regularizing the model.
|
||||
batch_norm_decay: The moving average decay when estimating layer activation
|
||||
statistics in batch normalization.
|
||||
batch_norm_epsilon: Small constant to prevent division by zero when
|
||||
normalizing activations by their variance in batch normalization.
|
||||
batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
|
||||
activations in the batch normalization layer.
|
||||
|
||||
Returns:
|
||||
An `arg_scope` to use for the resnet models.
|
||||
"""
|
||||
batch_norm_params = {
|
||||
'decay': batch_norm_decay,
|
||||
'epsilon': batch_norm_epsilon,
|
||||
'scale': batch_norm_scale,
|
||||
'updates_collections': tf.GraphKeys.UPDATE_OPS,
|
||||
}
|
||||
|
||||
with slim.arg_scope(
|
||||
[slim.conv2d],
|
||||
weights_regularizer=slim.l2_regularizer(weight_decay),
|
||||
weights_initializer=slim.variance_scaling_initializer(),
|
||||
activation_fn=tf.nn.relu,
|
||||
normalizer_fn=slim.batch_norm,
|
||||
normalizer_params=batch_norm_params):
|
||||
with slim.arg_scope([slim.batch_norm], **batch_norm_params):
|
||||
# The following implies padding='SAME' for pool1, which makes feature
|
||||
# alignment easier for dense prediction tasks. This is also used in
|
||||
# https://github.com/facebook/fb.resnet.torch. However the accompanying
|
||||
# code of 'Deep Residual Learning for Image Recognition' uses
|
||||
# padding='VALID' for pool1. You can switch to that choice by setting
|
||||
# slim.arg_scope([slim.max_pool2d], padding='VALID').
|
||||
with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:
|
||||
return arg_sc
|
|
@ -0,0 +1,296 @@
|
|||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Contains definitions for the original form of Residual Networks.
|
||||
|
||||
The 'v1' residual networks (ResNets) implemented in this module were proposed
|
||||
by:
|
||||
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
|
||||
Deep Residual Learning for Image Recognition. arXiv:1512.03385
|
||||
|
||||
Other variants were introduced in:
|
||||
[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
|
||||
Identity Mappings in Deep Residual Networks. arXiv: 1603.05027
|
||||
|
||||
The networks defined in this module utilize the bottleneck building block of
|
||||
[1] with projection shortcuts only for increasing depths. They employ batch
|
||||
normalization *after* every weight layer. This is the architecture used by
|
||||
MSRA in the Imagenet and MSCOCO 2016 competition models ResNet-101 and
|
||||
ResNet-152. See [2; Fig. 1a] for a comparison between the current 'v1'
|
||||
architecture and the alternative 'v2' architecture of [2] which uses batch
|
||||
normalization *before* every weight layer in the so-called full pre-activation
|
||||
units.
|
||||
|
||||
Typical use:
|
||||
|
||||
from tensorflow.contrib.slim.nets import resnet_v1
|
||||
|
||||
ResNet-101 for image classification into 1000 classes:
|
||||
|
||||
# inputs has shape [batch, 224, 224, 3]
|
||||
with slim.arg_scope(resnet_v1.resnet_arg_scope()):
|
||||
net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False)
|
||||
|
||||
ResNet-101 for semantic segmentation into 21 classes:
|
||||
|
||||
# inputs has shape [batch, 513, 513, 3]
|
||||
with slim.arg_scope(resnet_v1.resnet_arg_scope()):
|
||||
net, end_points = resnet_v1.resnet_v1_101(inputs,
|
||||
21,
|
||||
is_training=False,
|
||||
global_pool=False,
|
||||
output_stride=16)
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from nets import resnet_utils
|
||||
|
||||
|
||||
resnet_arg_scope = resnet_utils.resnet_arg_scope
|
||||
slim = tf.contrib.slim
|
||||
|
||||
|
||||
@slim.add_arg_scope
|
||||
def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1,
|
||||
outputs_collections=None, scope=None):
|
||||
"""Bottleneck residual unit variant with BN after convolutions.
|
||||
|
||||
This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for
|
||||
its definition. Note that we use here the bottleneck variant which has an
|
||||
extra bottleneck layer.
|
||||
|
||||
When putting together two consecutive ResNet blocks that use this unit, one
|
||||
should use stride = 2 in the last unit of the first block.
|
||||
|
||||
Args:
|
||||
inputs: A tensor of size [batch, height, width, channels].
|
||||
depth: The depth of the ResNet unit output.
|
||||
depth_bottleneck: The depth of the bottleneck layers.
|
||||
stride: The ResNet unit's stride. Determines the amount of downsampling of
|
||||
the units output compared to its input.
|
||||
rate: An integer, rate for atrous convolution.
|
||||
outputs_collections: Collection to add the ResNet unit output.
|
||||
scope: Optional variable_scope.
|
||||
|
||||
Returns:
|
||||
The ResNet unit's output.
|
||||
"""
|
||||
with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc:
|
||||
depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
|
||||
if depth == depth_in:
|
||||
shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')
|
||||
else:
|
||||
shortcut = slim.conv2d(inputs, depth, [1, 1], stride=stride,
|
||||
activation_fn=None, scope='shortcut')
|
||||
|
||||
residual = slim.conv2d(inputs, depth_bottleneck, [1, 1], stride=1,
|
||||
scope='conv1')
|
||||
residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride,
|
||||
rate=rate, scope='conv2')
|
||||
residual = slim.conv2d(residual, depth, [1, 1], stride=1,
|
||||
activation_fn=None, scope='conv3')
|
||||
|
||||
output = tf.nn.relu(shortcut + residual)
|
||||
|
||||
return slim.utils.collect_named_outputs(outputs_collections,
|
||||
sc.original_name_scope,
|
||||
output)
|
||||
|
||||
|
||||
def resnet_v1(inputs,
|
||||
blocks,
|
||||
num_classes=None,
|
||||
is_training=True,
|
||||
global_pool=True,
|
||||
output_stride=None,
|
||||
include_root_block=True,
|
||||
reuse=None,
|
||||
scope=None):
|
||||
"""Generator for v1 ResNet models.
|
||||
|
||||
This function generates a family of ResNet v1 models. See the resnet_v1_*()
|
||||
methods for specific model instantiations, obtained by selecting different
|
||||
block instantiations that produce ResNets of various depths.
|
||||
|
||||
Training for image classification on Imagenet is usually done with [224, 224]
|
||||
inputs, resulting in [7, 7] feature maps at the output of the last ResNet
|
||||
block for the ResNets defined in [1] that have nominal stride equal to 32.
|
||||
However, for dense prediction tasks we advise that one uses inputs with
|
||||
spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In
|
||||
this case the feature maps at the ResNet output will have spatial shape
|
||||
[(height - 1) / output_stride + 1, (width - 1) / output_stride + 1]
|
||||
and corners exactly aligned with the input image corners, which greatly
|
||||
facilitates alignment of the features to the image. Using as input [225, 225]
|
||||
images results in [8, 8] feature maps at the output of the last ResNet block.
|
||||
|
||||
For dense prediction tasks, the ResNet needs to run in fully-convolutional
|
||||
(FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all
|
||||
have nominal stride equal to 32 and a good choice in FCN mode is to use
|
||||
output_stride=16 in order to increase the density of the computed features at
|
||||
small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915.
|
||||
|
||||
Args:
|
||||
inputs: A tensor of size [batch, height_in, width_in, channels].
|
||||
blocks: A list of length equal to the number of ResNet blocks. Each element
|
||||
is a resnet_utils.Block object describing the units in the block.
|
||||
num_classes: Number of predicted classes for classification tasks. If None
|
||||
we return the features before the logit layer.
|
||||
is_training: whether is training or not.
|
||||
global_pool: If True, we perform global average pooling before computing the
|
||||
logits. Set to True for image classification, False for dense prediction.
|
||||
output_stride: If None, then the output will be computed at the nominal
|
||||
network stride. If output_stride is not None, it specifies the requested
|
||||
ratio of input to output spatial resolution.
|
||||
include_root_block: If True, include the initial convolution followed by
|
||||
max-pooling, if False excludes it.
|
||||
reuse: whether or not the network and its variables should be reused. To be
|
||||
able to reuse 'scope' must be given.
|
||||
scope: Optional variable_scope.
|
||||
|
||||
Returns:
|
||||
net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
|
||||
If global_pool is False, then height_out and width_out are reduced by a
|
||||
factor of output_stride compared to the respective height_in and width_in,
|
||||
else both height_out and width_out equal one. If num_classes is None, then
|
||||
net is the output of the last ResNet block, potentially after global
|
||||
average pooling. If num_classes is not None, net contains the pre-softmax
|
||||
activations.
|
||||
end_points: A dictionary from components of the network to the corresponding
|
||||
activation.
|
||||
|
||||
Raises:
|
||||
ValueError: If the target output_stride is not valid.
|
||||
"""
|
||||
with tf.variable_scope(scope, 'resnet_v1', [inputs], reuse=reuse) as sc:
|
||||
end_points_collection = sc.name + '_end_points'
|
||||
with slim.arg_scope([slim.conv2d, bottleneck,
|
||||
resnet_utils.stack_blocks_dense],
|
||||
outputs_collections=end_points_collection):
|
||||
with slim.arg_scope([slim.batch_norm], is_training=is_training):
|
||||
net = inputs
|
||||
if include_root_block:
|
||||
if output_stride is not None:
|
||||
if output_stride % 4 != 0:
|
||||
raise ValueError('The output_stride needs to be a multiple of 4.')
|
||||
output_stride /= 4
|
||||
net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1')
|
||||
net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1')
|
||||
net = resnet_utils.stack_blocks_dense(net, blocks, output_stride)
|
||||
if global_pool:
|
||||
# Global average pooling.
|
||||
net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True)
|
||||
if num_classes is not None:
|
||||
net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
|
||||
normalizer_fn=None, scope='logits')
|
||||
net = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
|
||||
# Convert end_points_collection into a dictionary of end_points.
|
||||
end_points = slim.utils.convert_collection_to_dict(end_points_collection)
|
||||
if num_classes is not None:
|
||||
end_points['predictions'] = slim.softmax(net, scope='predictions')
|
||||
return net, end_points
|
||||
resnet_v1.default_image_size = 224
|
||||
|
||||
|
||||
def resnet_v1_50(inputs,
|
||||
num_classes=None,
|
||||
is_training=True,
|
||||
global_pool=True,
|
||||
output_stride=None,
|
||||
reuse=None,
|
||||
scope='resnet_v1_50'):
|
||||
"""ResNet-50 model of [1]. See resnet_v1() for arg and return description."""
|
||||
blocks = [
|
||||
resnet_utils.Block(
|
||||
'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block3', bottleneck, [(1024, 256, 1)] * 5 + [(1024, 256, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block4', bottleneck, [(2048, 512, 1)] * 3)
|
||||
]
|
||||
return resnet_v1(inputs, blocks, num_classes, is_training,
|
||||
global_pool=global_pool, output_stride=output_stride,
|
||||
include_root_block=True, reuse=reuse, scope=scope)
|
||||
|
||||
|
||||
def resnet_v1_101(inputs,
|
||||
num_classes=None,
|
||||
is_training=True,
|
||||
global_pool=True,
|
||||
output_stride=None,
|
||||
reuse=None,
|
||||
scope='resnet_v1_101'):
|
||||
"""ResNet-101 model of [1]. See resnet_v1() for arg and return description."""
|
||||
blocks = [
|
||||
resnet_utils.Block(
|
||||
'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block3', bottleneck, [(1024, 256, 1)] * 22 + [(1024, 256, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block4', bottleneck, [(2048, 512, 1)] * 3)
|
||||
]
|
||||
return resnet_v1(inputs, blocks, num_classes, is_training,
|
||||
global_pool=global_pool, output_stride=output_stride,
|
||||
include_root_block=True, reuse=reuse, scope=scope)
|
||||
|
||||
|
||||
def resnet_v1_152(inputs,
|
||||
num_classes=None,
|
||||
is_training=True,
|
||||
global_pool=True,
|
||||
output_stride=None,
|
||||
reuse=None,
|
||||
scope='resnet_v1_152'):
|
||||
"""ResNet-152 model of [1]. See resnet_v1() for arg and return description."""
|
||||
blocks = [
|
||||
resnet_utils.Block(
|
||||
'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block2', bottleneck, [(512, 128, 1)] * 7 + [(512, 128, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block4', bottleneck, [(2048, 512, 1)] * 3)]
|
||||
return resnet_v1(inputs, blocks, num_classes, is_training,
|
||||
global_pool=global_pool, output_stride=output_stride,
|
||||
include_root_block=True, reuse=reuse, scope=scope)
|
||||
|
||||
|
||||
def resnet_v1_200(inputs,
|
||||
num_classes=None,
|
||||
is_training=True,
|
||||
global_pool=True,
|
||||
output_stride=None,
|
||||
reuse=None,
|
||||
scope='resnet_v1_200'):
|
||||
"""ResNet-200 model of [2]. See resnet_v1() for arg and return description."""
|
||||
blocks = [
|
||||
resnet_utils.Block(
|
||||
'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block2', bottleneck, [(512, 128, 1)] * 23 + [(512, 128, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block4', bottleneck, [(2048, 512, 1)] * 3)]
|
||||
return resnet_v1(inputs, blocks, num_classes, is_training,
|
||||
global_pool=global_pool, output_stride=output_stride,
|
||||
include_root_block=True, reuse=reuse, scope=scope)
|
|
@ -0,0 +1,450 @@
|
|||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for slim.nets.resnet_v1."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from nets import resnet_utils
|
||||
from nets import resnet_v1
|
||||
|
||||
slim = tf.contrib.slim
|
||||
|
||||
|
||||
def create_test_input(batch_size, height, width, channels):
|
||||
"""Create test input tensor.
|
||||
|
||||
Args:
|
||||
batch_size: The number of images per batch or `None` if unknown.
|
||||
height: The height of each image or `None` if unknown.
|
||||
width: The width of each image or `None` if unknown.
|
||||
channels: The number of channels per image or `None` if unknown.
|
||||
|
||||
Returns:
|
||||
Either a placeholder `Tensor` of dimension
|
||||
[batch_size, height, width, channels] if any of the inputs are `None` or a
|
||||
constant `Tensor` with the mesh grid values along the spatial dimensions.
|
||||
"""
|
||||
if None in [batch_size, height, width, channels]:
|
||||
return tf.placeholder(tf.float32, (batch_size, height, width, channels))
|
||||
else:
|
||||
return tf.to_float(
|
||||
np.tile(
|
||||
np.reshape(
|
||||
np.reshape(np.arange(height), [height, 1]) +
|
||||
np.reshape(np.arange(width), [1, width]),
|
||||
[1, height, width, 1]),
|
||||
[batch_size, 1, 1, channels]))
|
||||
|
||||
|
||||
class ResnetUtilsTest(tf.test.TestCase):
|
||||
|
||||
def testSubsampleThreeByThree(self):
|
||||
x = tf.reshape(tf.to_float(tf.range(9)), [1, 3, 3, 1])
|
||||
x = resnet_utils.subsample(x, 2)
|
||||
expected = tf.reshape(tf.constant([0, 2, 6, 8]), [1, 2, 2, 1])
|
||||
with self.test_session():
|
||||
self.assertAllClose(x.eval(), expected.eval())
|
||||
|
||||
def testSubsampleFourByFour(self):
|
||||
x = tf.reshape(tf.to_float(tf.range(16)), [1, 4, 4, 1])
|
||||
x = resnet_utils.subsample(x, 2)
|
||||
expected = tf.reshape(tf.constant([0, 2, 8, 10]), [1, 2, 2, 1])
|
||||
with self.test_session():
|
||||
self.assertAllClose(x.eval(), expected.eval())
|
||||
|
||||
def testConv2DSameEven(self):
|
||||
n, n2 = 4, 2
|
||||
|
||||
# Input image.
|
||||
x = create_test_input(1, n, n, 1)
|
||||
|
||||
# Convolution kernel.
|
||||
w = create_test_input(1, 3, 3, 1)
|
||||
w = tf.reshape(w, [3, 3, 1, 1])
|
||||
|
||||
tf.get_variable('Conv/weights', initializer=w)
|
||||
tf.get_variable('Conv/biases', initializer=tf.zeros([1]))
|
||||
tf.get_variable_scope().reuse_variables()
|
||||
|
||||
y1 = slim.conv2d(x, 1, [3, 3], stride=1, scope='Conv')
|
||||
y1_expected = tf.to_float([[14, 28, 43, 26],
|
||||
[28, 48, 66, 37],
|
||||
[43, 66, 84, 46],
|
||||
[26, 37, 46, 22]])
|
||||
y1_expected = tf.reshape(y1_expected, [1, n, n, 1])
|
||||
|
||||
y2 = resnet_utils.subsample(y1, 2)
|
||||
y2_expected = tf.to_float([[14, 43],
|
||||
[43, 84]])
|
||||
y2_expected = tf.reshape(y2_expected, [1, n2, n2, 1])
|
||||
|
||||
y3 = resnet_utils.conv2d_same(x, 1, 3, stride=2, scope='Conv')
|
||||
y3_expected = y2_expected
|
||||
|
||||
y4 = slim.conv2d(x, 1, [3, 3], stride=2, scope='Conv')
|
||||
y4_expected = tf.to_float([[48, 37],
|
||||
[37, 22]])
|
||||
y4_expected = tf.reshape(y4_expected, [1, n2, n2, 1])
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(tf.initialize_all_variables())
|
||||
self.assertAllClose(y1.eval(), y1_expected.eval())
|
||||
self.assertAllClose(y2.eval(), y2_expected.eval())
|
||||
self.assertAllClose(y3.eval(), y3_expected.eval())
|
||||
self.assertAllClose(y4.eval(), y4_expected.eval())
|
||||
|
||||
def testConv2DSameOdd(self):
|
||||
n, n2 = 5, 3
|
||||
|
||||
# Input image.
|
||||
x = create_test_input(1, n, n, 1)
|
||||
|
||||
# Convolution kernel.
|
||||
w = create_test_input(1, 3, 3, 1)
|
||||
w = tf.reshape(w, [3, 3, 1, 1])
|
||||
|
||||
tf.get_variable('Conv/weights', initializer=w)
|
||||
tf.get_variable('Conv/biases', initializer=tf.zeros([1]))
|
||||
tf.get_variable_scope().reuse_variables()
|
||||
|
||||
y1 = slim.conv2d(x, 1, [3, 3], stride=1, scope='Conv')
|
||||
y1_expected = tf.to_float([[14, 28, 43, 58, 34],
|
||||
[28, 48, 66, 84, 46],
|
||||
[43, 66, 84, 102, 55],
|
||||
[58, 84, 102, 120, 64],
|
||||
[34, 46, 55, 64, 30]])
|
||||
y1_expected = tf.reshape(y1_expected, [1, n, n, 1])
|
||||
|
||||
y2 = resnet_utils.subsample(y1, 2)
|
||||
y2_expected = tf.to_float([[14, 43, 34],
|
||||
[43, 84, 55],
|
||||
[34, 55, 30]])
|
||||
y2_expected = tf.reshape(y2_expected, [1, n2, n2, 1])
|
||||
|
||||
y3 = resnet_utils.conv2d_same(x, 1, 3, stride=2, scope='Conv')
|
||||
y3_expected = y2_expected
|
||||
|
||||
y4 = slim.conv2d(x, 1, [3, 3], stride=2, scope='Conv')
|
||||
y4_expected = y2_expected
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(tf.initialize_all_variables())
|
||||
self.assertAllClose(y1.eval(), y1_expected.eval())
|
||||
self.assertAllClose(y2.eval(), y2_expected.eval())
|
||||
self.assertAllClose(y3.eval(), y3_expected.eval())
|
||||
self.assertAllClose(y4.eval(), y4_expected.eval())
|
||||
|
||||
def _resnet_plain(self, inputs, blocks, output_stride=None, scope=None):
|
||||
"""A plain ResNet without extra layers before or after the ResNet blocks."""
|
||||
with tf.variable_scope(scope, values=[inputs]):
|
||||
with slim.arg_scope([slim.conv2d], outputs_collections='end_points'):
|
||||
net = resnet_utils.stack_blocks_dense(inputs, blocks, output_stride)
|
||||
end_points = dict(tf.get_collection('end_points'))
|
||||
return net, end_points
|
||||
|
||||
def testEndPointsV1(self):
|
||||
"""Test the end points of a tiny v1 bottleneck network."""
|
||||
bottleneck = resnet_v1.bottleneck
|
||||
blocks = [resnet_utils.Block('block1', bottleneck, [(4, 1, 1), (4, 1, 2)]),
|
||||
resnet_utils.Block('block2', bottleneck, [(8, 2, 1), (8, 2, 1)])]
|
||||
inputs = create_test_input(2, 32, 16, 3)
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
_, end_points = self._resnet_plain(inputs, blocks, scope='tiny')
|
||||
expected = [
|
||||
'tiny/block1/unit_1/bottleneck_v1/shortcut',
|
||||
'tiny/block1/unit_1/bottleneck_v1/conv1',
|
||||
'tiny/block1/unit_1/bottleneck_v1/conv2',
|
||||
'tiny/block1/unit_1/bottleneck_v1/conv3',
|
||||
'tiny/block1/unit_2/bottleneck_v1/conv1',
|
||||
'tiny/block1/unit_2/bottleneck_v1/conv2',
|
||||
'tiny/block1/unit_2/bottleneck_v1/conv3',
|
||||
'tiny/block2/unit_1/bottleneck_v1/shortcut',
|
||||
'tiny/block2/unit_1/bottleneck_v1/conv1',
|
||||
'tiny/block2/unit_1/bottleneck_v1/conv2',
|
||||
'tiny/block2/unit_1/bottleneck_v1/conv3',
|
||||
'tiny/block2/unit_2/bottleneck_v1/conv1',
|
||||
'tiny/block2/unit_2/bottleneck_v1/conv2',
|
||||
'tiny/block2/unit_2/bottleneck_v1/conv3']
|
||||
self.assertItemsEqual(expected, end_points)
|
||||
|
||||
def _stack_blocks_nondense(self, net, blocks):
|
||||
"""A simplified ResNet Block stacker without output stride control."""
|
||||
for block in blocks:
|
||||
with tf.variable_scope(block.scope, 'block', [net]):
|
||||
for i, unit in enumerate(block.args):
|
||||
depth, depth_bottleneck, stride = unit
|
||||
with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
|
||||
net = block.unit_fn(net,
|
||||
depth=depth,
|
||||
depth_bottleneck=depth_bottleneck,
|
||||
stride=stride,
|
||||
rate=1)
|
||||
return net
|
||||
|
||||
def _atrousValues(self, bottleneck):
|
||||
"""Verify the values of dense feature extraction by atrous convolution.
|
||||
|
||||
Make sure that dense feature extraction by stack_blocks_dense() followed by
|
||||
subsampling gives identical results to feature extraction at the nominal
|
||||
network output stride using the simple self._stack_blocks_nondense() above.
|
||||
|
||||
Args:
|
||||
bottleneck: The bottleneck function.
|
||||
"""
|
||||
blocks = [
|
||||
resnet_utils.Block('block1', bottleneck, [(4, 1, 1), (4, 1, 2)]),
|
||||
resnet_utils.Block('block2', bottleneck, [(8, 2, 1), (8, 2, 2)]),
|
||||
resnet_utils.Block('block3', bottleneck, [(16, 4, 1), (16, 4, 2)]),
|
||||
resnet_utils.Block('block4', bottleneck, [(32, 8, 1), (32, 8, 1)])
|
||||
]
|
||||
nominal_stride = 8
|
||||
|
||||
# Test both odd and even input dimensions.
|
||||
height = 30
|
||||
width = 31
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
with slim.arg_scope([slim.batch_norm], is_training=False):
|
||||
for output_stride in [1, 2, 4, 8, None]:
|
||||
with tf.Graph().as_default():
|
||||
with self.test_session() as sess:
|
||||
tf.set_random_seed(0)
|
||||
inputs = create_test_input(1, height, width, 3)
|
||||
# Dense feature extraction followed by subsampling.
|
||||
output = resnet_utils.stack_blocks_dense(inputs,
|
||||
blocks,
|
||||
output_stride)
|
||||
if output_stride is None:
|
||||
factor = 1
|
||||
else:
|
||||
factor = nominal_stride // output_stride
|
||||
|
||||
output = resnet_utils.subsample(output, factor)
|
||||
# Make the two networks use the same weights.
|
||||
tf.get_variable_scope().reuse_variables()
|
||||
# Feature extraction at the nominal network rate.
|
||||
expected = self._stack_blocks_nondense(inputs, blocks)
|
||||
sess.run(tf.initialize_all_variables())
|
||||
output, expected = sess.run([output, expected])
|
||||
self.assertAllClose(output, expected, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def testAtrousValuesBottleneck(self):
|
||||
self._atrousValues(resnet_v1.bottleneck)
|
||||
|
||||
|
||||
class ResnetCompleteNetworkTest(tf.test.TestCase):
|
||||
"""Tests with complete small ResNet v1 networks."""
|
||||
|
||||
def _resnet_small(self,
|
||||
inputs,
|
||||
num_classes=None,
|
||||
is_training=True,
|
||||
global_pool=True,
|
||||
output_stride=None,
|
||||
include_root_block=True,
|
||||
reuse=None,
|
||||
scope='resnet_v1_small'):
|
||||
"""A shallow and thin ResNet v1 for faster tests."""
|
||||
bottleneck = resnet_v1.bottleneck
|
||||
blocks = [
|
||||
resnet_utils.Block(
|
||||
'block1', bottleneck, [(4, 1, 1)] * 2 + [(4, 1, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block2', bottleneck, [(8, 2, 1)] * 2 + [(8, 2, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block3', bottleneck, [(16, 4, 1)] * 2 + [(16, 4, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block4', bottleneck, [(32, 8, 1)] * 2)]
|
||||
return resnet_v1.resnet_v1(inputs, blocks, num_classes,
|
||||
is_training=is_training,
|
||||
global_pool=global_pool,
|
||||
output_stride=output_stride,
|
||||
include_root_block=include_root_block,
|
||||
reuse=reuse,
|
||||
scope=scope)
|
||||
|
||||
def testClassificationEndPoints(self):
|
||||
global_pool = True
|
||||
num_classes = 10
|
||||
inputs = create_test_input(2, 224, 224, 3)
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
logits, end_points = self._resnet_small(inputs, num_classes,
|
||||
global_pool=global_pool,
|
||||
scope='resnet')
|
||||
self.assertTrue(logits.op.name.startswith('resnet/logits'))
|
||||
self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
|
||||
self.assertTrue('predictions' in end_points)
|
||||
self.assertListEqual(end_points['predictions'].get_shape().as_list(),
|
||||
[2, 1, 1, num_classes])
|
||||
|
||||
def testClassificationShapes(self):
|
||||
global_pool = True
|
||||
num_classes = 10
|
||||
inputs = create_test_input(2, 224, 224, 3)
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
_, end_points = self._resnet_small(inputs, num_classes,
|
||||
global_pool=global_pool,
|
||||
scope='resnet')
|
||||
endpoint_to_shape = {
|
||||
'resnet/block1': [2, 28, 28, 4],
|
||||
'resnet/block2': [2, 14, 14, 8],
|
||||
'resnet/block3': [2, 7, 7, 16],
|
||||
'resnet/block4': [2, 7, 7, 32]}
|
||||
for endpoint in endpoint_to_shape:
|
||||
shape = endpoint_to_shape[endpoint]
|
||||
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
|
||||
|
||||
def testFullyConvolutionalEndpointShapes(self):
|
||||
global_pool = False
|
||||
num_classes = 10
|
||||
inputs = create_test_input(2, 321, 321, 3)
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
_, end_points = self._resnet_small(inputs, num_classes,
|
||||
global_pool=global_pool,
|
||||
scope='resnet')
|
||||
endpoint_to_shape = {
|
||||
'resnet/block1': [2, 41, 41, 4],
|
||||
'resnet/block2': [2, 21, 21, 8],
|
||||
'resnet/block3': [2, 11, 11, 16],
|
||||
'resnet/block4': [2, 11, 11, 32]}
|
||||
for endpoint in endpoint_to_shape:
|
||||
shape = endpoint_to_shape[endpoint]
|
||||
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
|
||||
|
||||
def testRootlessFullyConvolutionalEndpointShapes(self):
|
||||
global_pool = False
|
||||
num_classes = 10
|
||||
inputs = create_test_input(2, 128, 128, 3)
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
_, end_points = self._resnet_small(inputs, num_classes,
|
||||
global_pool=global_pool,
|
||||
include_root_block=False,
|
||||
scope='resnet')
|
||||
endpoint_to_shape = {
|
||||
'resnet/block1': [2, 64, 64, 4],
|
||||
'resnet/block2': [2, 32, 32, 8],
|
||||
'resnet/block3': [2, 16, 16, 16],
|
||||
'resnet/block4': [2, 16, 16, 32]}
|
||||
for endpoint in endpoint_to_shape:
|
||||
shape = endpoint_to_shape[endpoint]
|
||||
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
|
||||
|
||||
def testAtrousFullyConvolutionalEndpointShapes(self):
|
||||
global_pool = False
|
||||
num_classes = 10
|
||||
output_stride = 8
|
||||
inputs = create_test_input(2, 321, 321, 3)
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
_, end_points = self._resnet_small(inputs,
|
||||
num_classes,
|
||||
global_pool=global_pool,
|
||||
output_stride=output_stride,
|
||||
scope='resnet')
|
||||
endpoint_to_shape = {
|
||||
'resnet/block1': [2, 41, 41, 4],
|
||||
'resnet/block2': [2, 41, 41, 8],
|
||||
'resnet/block3': [2, 41, 41, 16],
|
||||
'resnet/block4': [2, 41, 41, 32]}
|
||||
for endpoint in endpoint_to_shape:
|
||||
shape = endpoint_to_shape[endpoint]
|
||||
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
|
||||
|
||||
def testAtrousFullyConvolutionalValues(self):
|
||||
"""Verify dense feature extraction with atrous convolution."""
|
||||
nominal_stride = 32
|
||||
for output_stride in [4, 8, 16, 32, None]:
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
with tf.Graph().as_default():
|
||||
with self.test_session() as sess:
|
||||
tf.set_random_seed(0)
|
||||
inputs = create_test_input(2, 81, 81, 3)
|
||||
# Dense feature extraction followed by subsampling.
|
||||
output, _ = self._resnet_small(inputs, None, is_training=False,
|
||||
global_pool=False,
|
||||
output_stride=output_stride)
|
||||
if output_stride is None:
|
||||
factor = 1
|
||||
else:
|
||||
factor = nominal_stride // output_stride
|
||||
output = resnet_utils.subsample(output, factor)
|
||||
# Make the two networks use the same weights.
|
||||
tf.get_variable_scope().reuse_variables()
|
||||
# Feature extraction at the nominal network rate.
|
||||
expected, _ = self._resnet_small(inputs, None, is_training=False,
|
||||
global_pool=False)
|
||||
sess.run(tf.initialize_all_variables())
|
||||
self.assertAllClose(output.eval(), expected.eval(),
|
||||
atol=1e-4, rtol=1e-4)
|
||||
|
||||
def testUnknownBatchSize(self):
|
||||
batch = 2
|
||||
height, width = 65, 65
|
||||
global_pool = True
|
||||
num_classes = 10
|
||||
inputs = create_test_input(None, height, width, 3)
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
logits, _ = self._resnet_small(inputs, num_classes,
|
||||
global_pool=global_pool,
|
||||
scope='resnet')
|
||||
self.assertTrue(logits.op.name.startswith('resnet/logits'))
|
||||
self.assertListEqual(logits.get_shape().as_list(),
|
||||
[None, 1, 1, num_classes])
|
||||
images = create_test_input(batch, height, width, 3)
|
||||
with self.test_session() as sess:
|
||||
sess.run(tf.initialize_all_variables())
|
||||
output = sess.run(logits, {inputs: images.eval()})
|
||||
self.assertEqual(output.shape, (batch, 1, 1, num_classes))
|
||||
|
||||
def testFullyConvolutionalUnknownHeightWidth(self):
|
||||
batch = 2
|
||||
height, width = 65, 65
|
||||
global_pool = False
|
||||
inputs = create_test_input(batch, None, None, 3)
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
output, _ = self._resnet_small(inputs, None, global_pool=global_pool)
|
||||
self.assertListEqual(output.get_shape().as_list(),
|
||||
[batch, None, None, 32])
|
||||
images = create_test_input(batch, height, width, 3)
|
||||
with self.test_session() as sess:
|
||||
sess.run(tf.initialize_all_variables())
|
||||
output = sess.run(output, {inputs: images.eval()})
|
||||
self.assertEqual(output.shape, (batch, 3, 3, 32))
|
||||
|
||||
def testAtrousFullyConvolutionalUnknownHeightWidth(self):
|
||||
batch = 2
|
||||
height, width = 65, 65
|
||||
global_pool = False
|
||||
output_stride = 8
|
||||
inputs = create_test_input(batch, None, None, 3)
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
output, _ = self._resnet_small(inputs,
|
||||
None,
|
||||
global_pool=global_pool,
|
||||
output_stride=output_stride)
|
||||
self.assertListEqual(output.get_shape().as_list(),
|
||||
[batch, None, None, 32])
|
||||
images = create_test_input(batch, height, width, 3)
|
||||
with self.test_session() as sess:
|
||||
sess.run(tf.initialize_all_variables())
|
||||
output = sess.run(output, {inputs: images.eval()})
|
||||
self.assertEqual(output.shape, (batch, 9, 9, 32))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -0,0 +1,302 @@
|
|||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Contains definitions for the preactivation form of Residual Networks.
|
||||
|
||||
Residual networks (ResNets) were originally proposed in:
|
||||
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
|
||||
Deep Residual Learning for Image Recognition. arXiv:1512.03385
|
||||
|
||||
The full preactivation 'v2' ResNet variant implemented in this module was
|
||||
introduced by:
|
||||
[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
|
||||
Identity Mappings in Deep Residual Networks. arXiv: 1603.05027
|
||||
|
||||
The key difference of the full preactivation 'v2' variant compared to the
|
||||
'v1' variant in [1] is the use of batch normalization before every weight layer.
|
||||
Another difference is that 'v2' ResNets do not include an activation function in
|
||||
the main pathway. Also see [2; Fig. 4e].
|
||||
|
||||
Typical use:
|
||||
|
||||
from tensorflow.contrib.slim.nets import resnet_v2
|
||||
|
||||
ResNet-101 for image classification into 1000 classes:
|
||||
|
||||
# inputs has shape [batch, 224, 224, 3]
|
||||
with slim.arg_scope(resnet_v2.resnet_arg_scope()):
|
||||
net, end_points = resnet_v2.resnet_v2_101(inputs, 1000, is_training=False)
|
||||
|
||||
ResNet-101 for semantic segmentation into 21 classes:
|
||||
|
||||
# inputs has shape [batch, 513, 513, 3]
|
||||
with slim.arg_scope(resnet_v2.resnet_arg_scope(is_training)):
|
||||
net, end_points = resnet_v2.resnet_v2_101(inputs,
|
||||
21,
|
||||
is_training=False,
|
||||
global_pool=False,
|
||||
output_stride=16)
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from nets import resnet_utils
|
||||
|
||||
slim = tf.contrib.slim
|
||||
resnet_arg_scope = resnet_utils.resnet_arg_scope
|
||||
|
||||
|
||||
@slim.add_arg_scope
|
||||
def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1,
|
||||
outputs_collections=None, scope=None):
|
||||
"""Bottleneck residual unit variant with BN before convolutions.
|
||||
|
||||
This is the full preactivation residual unit variant proposed in [2]. See
|
||||
Fig. 1(b) of [2] for its definition. Note that we use here the bottleneck
|
||||
variant which has an extra bottleneck layer.
|
||||
|
||||
When putting together two consecutive ResNet blocks that use this unit, one
|
||||
should use stride = 2 in the last unit of the first block.
|
||||
|
||||
Args:
|
||||
inputs: A tensor of size [batch, height, width, channels].
|
||||
depth: The depth of the ResNet unit output.
|
||||
depth_bottleneck: The depth of the bottleneck layers.
|
||||
stride: The ResNet unit's stride. Determines the amount of downsampling of
|
||||
the units output compared to its input.
|
||||
rate: An integer, rate for atrous convolution.
|
||||
outputs_collections: Collection to add the ResNet unit output.
|
||||
scope: Optional variable_scope.
|
||||
|
||||
Returns:
|
||||
The ResNet unit's output.
|
||||
"""
|
||||
with tf.variable_scope(scope, 'bottleneck_v2', [inputs]) as sc:
|
||||
depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
|
||||
preact = slim.batch_norm(inputs, activation_fn=tf.nn.relu, scope='preact')
|
||||
if depth == depth_in:
|
||||
shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')
|
||||
else:
|
||||
shortcut = slim.conv2d(preact, depth, [1, 1], stride=stride,
|
||||
normalizer_fn=None, activation_fn=None,
|
||||
scope='shortcut')
|
||||
|
||||
residual = slim.conv2d(preact, depth_bottleneck, [1, 1], stride=1,
|
||||
scope='conv1')
|
||||
residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride,
|
||||
rate=rate, scope='conv2')
|
||||
residual = slim.conv2d(residual, depth, [1, 1], stride=1,
|
||||
normalizer_fn=None, activation_fn=None,
|
||||
scope='conv3')
|
||||
|
||||
output = shortcut + residual
|
||||
|
||||
return slim.utils.collect_named_outputs(outputs_collections,
|
||||
sc.original_name_scope,
|
||||
output)
|
||||
|
||||
|
||||
def resnet_v2(inputs,
|
||||
blocks,
|
||||
num_classes=None,
|
||||
is_training=True,
|
||||
global_pool=True,
|
||||
output_stride=None,
|
||||
include_root_block=True,
|
||||
reuse=None,
|
||||
scope=None):
|
||||
"""Generator for v2 (preactivation) ResNet models.
|
||||
|
||||
This function generates a family of ResNet v2 models. See the resnet_v2_*()
|
||||
methods for specific model instantiations, obtained by selecting different
|
||||
block instantiations that produce ResNets of various depths.
|
||||
|
||||
Training for image classification on Imagenet is usually done with [224, 224]
|
||||
inputs, resulting in [7, 7] feature maps at the output of the last ResNet
|
||||
block for the ResNets defined in [1] that have nominal stride equal to 32.
|
||||
However, for dense prediction tasks we advise that one uses inputs with
|
||||
spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In
|
||||
this case the feature maps at the ResNet output will have spatial shape
|
||||
[(height - 1) / output_stride + 1, (width - 1) / output_stride + 1]
|
||||
and corners exactly aligned with the input image corners, which greatly
|
||||
facilitates alignment of the features to the image. Using as input [225, 225]
|
||||
images results in [8, 8] feature maps at the output of the last ResNet block.
|
||||
|
||||
For dense prediction tasks, the ResNet needs to run in fully-convolutional
|
||||
(FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all
|
||||
have nominal stride equal to 32 and a good choice in FCN mode is to use
|
||||
output_stride=16 in order to increase the density of the computed features at
|
||||
small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915.
|
||||
|
||||
Args:
|
||||
inputs: A tensor of size [batch, height_in, width_in, channels].
|
||||
blocks: A list of length equal to the number of ResNet blocks. Each element
|
||||
is a resnet_utils.Block object describing the units in the block.
|
||||
num_classes: Number of predicted classes for classification tasks. If None
|
||||
we return the features before the logit layer.
|
||||
is_training: whether is training or not.
|
||||
global_pool: If True, we perform global average pooling before computing the
|
||||
logits. Set to True for image classification, False for dense prediction.
|
||||
output_stride: If None, then the output will be computed at the nominal
|
||||
network stride. If output_stride is not None, it specifies the requested
|
||||
ratio of input to output spatial resolution.
|
||||
include_root_block: If True, include the initial convolution followed by
|
||||
max-pooling, if False excludes it. If excluded, `inputs` should be the
|
||||
results of an activation-less convolution.
|
||||
reuse: whether or not the network and its variables should be reused. To be
|
||||
able to reuse 'scope' must be given.
|
||||
scope: Optional variable_scope.
|
||||
|
||||
|
||||
Returns:
|
||||
net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
|
||||
If global_pool is False, then height_out and width_out are reduced by a
|
||||
factor of output_stride compared to the respective height_in and width_in,
|
||||
else both height_out and width_out equal one. If num_classes is None, then
|
||||
net is the output of the last ResNet block, potentially after global
|
||||
average pooling. If num_classes is not None, net contains the pre-softmax
|
||||
activations.
|
||||
end_points: A dictionary from components of the network to the corresponding
|
||||
activation.
|
||||
|
||||
Raises:
|
||||
ValueError: If the target output_stride is not valid.
|
||||
"""
|
||||
with tf.variable_scope(scope, 'resnet_v2', [inputs], reuse=reuse) as sc:
|
||||
end_points_collection = sc.name + '_end_points'
|
||||
with slim.arg_scope([slim.conv2d, bottleneck,
|
||||
resnet_utils.stack_blocks_dense],
|
||||
outputs_collections=end_points_collection):
|
||||
with slim.arg_scope([slim.batch_norm], is_training=is_training):
|
||||
net = inputs
|
||||
if include_root_block:
|
||||
if output_stride is not None:
|
||||
if output_stride % 4 != 0:
|
||||
raise ValueError('The output_stride needs to be a multiple of 4.')
|
||||
output_stride /= 4
|
||||
# We do not include batch normalization or activation functions in
|
||||
# conv1 because the first ResNet unit will perform these. Cf.
|
||||
# Appendix of [2].
|
||||
with slim.arg_scope([slim.conv2d],
|
||||
activation_fn=None, normalizer_fn=None):
|
||||
net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1')
|
||||
net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1')
|
||||
net = resnet_utils.stack_blocks_dense(net, blocks, output_stride)
|
||||
# This is needed because the pre-activation variant does not have batch
|
||||
# normalization or activation functions in the residual unit output. See
|
||||
# Appendix of [2].
|
||||
net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='postnorm')
|
||||
if global_pool:
|
||||
# Global average pooling.
|
||||
net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True)
|
||||
if num_classes is not None:
|
||||
net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
|
||||
normalizer_fn=None, scope='logits')
|
||||
# Convert end_points_collection into a dictionary of end_points.
|
||||
end_points = slim.utils.convert_collection_to_dict(end_points_collection)
|
||||
if num_classes is not None:
|
||||
end_points['predictions'] = slim.softmax(net, scope='predictions')
|
||||
return net, end_points
|
||||
resnet_v2.default_image_size = 224
|
||||
|
||||
|
||||
def resnet_v2_50(inputs,
|
||||
num_classes=None,
|
||||
is_training=True,
|
||||
global_pool=True,
|
||||
output_stride=None,
|
||||
reuse=None,
|
||||
scope='resnet_v2_50'):
|
||||
"""ResNet-50 model of [1]. See resnet_v2() for arg and return description."""
|
||||
blocks = [
|
||||
resnet_utils.Block(
|
||||
'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block3', bottleneck, [(1024, 256, 1)] * 5 + [(1024, 256, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block4', bottleneck, [(2048, 512, 1)] * 3)]
|
||||
return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
|
||||
global_pool=global_pool, output_stride=output_stride,
|
||||
include_root_block=True, reuse=reuse, scope=scope)
|
||||
|
||||
|
||||
def resnet_v2_101(inputs,
|
||||
num_classes=None,
|
||||
is_training=True,
|
||||
global_pool=True,
|
||||
output_stride=None,
|
||||
reuse=None,
|
||||
scope='resnet_v2_101'):
|
||||
"""ResNet-101 model of [1]. See resnet_v2() for arg and return description."""
|
||||
blocks = [
|
||||
resnet_utils.Block(
|
||||
'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block3', bottleneck, [(1024, 256, 1)] * 22 + [(1024, 256, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block4', bottleneck, [(2048, 512, 1)] * 3)]
|
||||
return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
|
||||
global_pool=global_pool, output_stride=output_stride,
|
||||
include_root_block=True, reuse=reuse, scope=scope)
|
||||
|
||||
|
||||
def resnet_v2_152(inputs,
|
||||
num_classes=None,
|
||||
is_training=True,
|
||||
global_pool=True,
|
||||
output_stride=None,
|
||||
reuse=None,
|
||||
scope='resnet_v2_152'):
|
||||
"""ResNet-152 model of [1]. See resnet_v2() for arg and return description."""
|
||||
blocks = [
|
||||
resnet_utils.Block(
|
||||
'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block2', bottleneck, [(512, 128, 1)] * 7 + [(512, 128, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block4', bottleneck, [(2048, 512, 1)] * 3)]
|
||||
return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
|
||||
global_pool=global_pool, output_stride=output_stride,
|
||||
include_root_block=True, reuse=reuse, scope=scope)
|
||||
|
||||
|
||||
def resnet_v2_200(inputs,
|
||||
num_classes=None,
|
||||
is_training=True,
|
||||
global_pool=True,
|
||||
output_stride=None,
|
||||
reuse=None,
|
||||
scope='resnet_v2_200'):
|
||||
"""ResNet-200 model of [2]. See resnet_v2() for arg and return description."""
|
||||
blocks = [
|
||||
resnet_utils.Block(
|
||||
'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block2', bottleneck, [(512, 128, 1)] * 23 + [(512, 128, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block4', bottleneck, [(2048, 512, 1)] * 3)]
|
||||
return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
|
||||
global_pool=global_pool, output_stride=output_stride,
|
||||
include_root_block=True, reuse=reuse, scope=scope)
|
|
@ -0,0 +1,453 @@
|
|||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for slim.nets.resnet_v2."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from nets import resnet_utils
|
||||
from nets import resnet_v2
|
||||
|
||||
slim = tf.contrib.slim
|
||||
|
||||
|
||||
def create_test_input(batch_size, height, width, channels):
|
||||
"""Create test input tensor.
|
||||
|
||||
Args:
|
||||
batch_size: The number of images per batch or `None` if unknown.
|
||||
height: The height of each image or `None` if unknown.
|
||||
width: The width of each image or `None` if unknown.
|
||||
channels: The number of channels per image or `None` if unknown.
|
||||
|
||||
Returns:
|
||||
Either a placeholder `Tensor` of dimension
|
||||
[batch_size, height, width, channels] if any of the inputs are `None` or a
|
||||
constant `Tensor` with the mesh grid values along the spatial dimensions.
|
||||
"""
|
||||
if None in [batch_size, height, width, channels]:
|
||||
return tf.placeholder(tf.float32, (batch_size, height, width, channels))
|
||||
else:
|
||||
return tf.to_float(
|
||||
np.tile(
|
||||
np.reshape(
|
||||
np.reshape(np.arange(height), [height, 1]) +
|
||||
np.reshape(np.arange(width), [1, width]),
|
||||
[1, height, width, 1]),
|
||||
[batch_size, 1, 1, channels]))
|
||||
|
||||
|
||||
class ResnetUtilsTest(tf.test.TestCase):
|
||||
|
||||
def testSubsampleThreeByThree(self):
|
||||
x = tf.reshape(tf.to_float(tf.range(9)), [1, 3, 3, 1])
|
||||
x = resnet_utils.subsample(x, 2)
|
||||
expected = tf.reshape(tf.constant([0, 2, 6, 8]), [1, 2, 2, 1])
|
||||
with self.test_session():
|
||||
self.assertAllClose(x.eval(), expected.eval())
|
||||
|
||||
def testSubsampleFourByFour(self):
|
||||
x = tf.reshape(tf.to_float(tf.range(16)), [1, 4, 4, 1])
|
||||
x = resnet_utils.subsample(x, 2)
|
||||
expected = tf.reshape(tf.constant([0, 2, 8, 10]), [1, 2, 2, 1])
|
||||
with self.test_session():
|
||||
self.assertAllClose(x.eval(), expected.eval())
|
||||
|
||||
def testConv2DSameEven(self):
|
||||
n, n2 = 4, 2
|
||||
|
||||
# Input image.
|
||||
x = create_test_input(1, n, n, 1)
|
||||
|
||||
# Convolution kernel.
|
||||
w = create_test_input(1, 3, 3, 1)
|
||||
w = tf.reshape(w, [3, 3, 1, 1])
|
||||
|
||||
tf.get_variable('Conv/weights', initializer=w)
|
||||
tf.get_variable('Conv/biases', initializer=tf.zeros([1]))
|
||||
tf.get_variable_scope().reuse_variables()
|
||||
|
||||
y1 = slim.conv2d(x, 1, [3, 3], stride=1, scope='Conv')
|
||||
y1_expected = tf.to_float([[14, 28, 43, 26],
|
||||
[28, 48, 66, 37],
|
||||
[43, 66, 84, 46],
|
||||
[26, 37, 46, 22]])
|
||||
y1_expected = tf.reshape(y1_expected, [1, n, n, 1])
|
||||
|
||||
y2 = resnet_utils.subsample(y1, 2)
|
||||
y2_expected = tf.to_float([[14, 43],
|
||||
[43, 84]])
|
||||
y2_expected = tf.reshape(y2_expected, [1, n2, n2, 1])
|
||||
|
||||
y3 = resnet_utils.conv2d_same(x, 1, 3, stride=2, scope='Conv')
|
||||
y3_expected = y2_expected
|
||||
|
||||
y4 = slim.conv2d(x, 1, [3, 3], stride=2, scope='Conv')
|
||||
y4_expected = tf.to_float([[48, 37],
|
||||
[37, 22]])
|
||||
y4_expected = tf.reshape(y4_expected, [1, n2, n2, 1])
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(tf.initialize_all_variables())
|
||||
self.assertAllClose(y1.eval(), y1_expected.eval())
|
||||
self.assertAllClose(y2.eval(), y2_expected.eval())
|
||||
self.assertAllClose(y3.eval(), y3_expected.eval())
|
||||
self.assertAllClose(y4.eval(), y4_expected.eval())
|
||||
|
||||
def testConv2DSameOdd(self):
|
||||
n, n2 = 5, 3
|
||||
|
||||
# Input image.
|
||||
x = create_test_input(1, n, n, 1)
|
||||
|
||||
# Convolution kernel.
|
||||
w = create_test_input(1, 3, 3, 1)
|
||||
w = tf.reshape(w, [3, 3, 1, 1])
|
||||
|
||||
tf.get_variable('Conv/weights', initializer=w)
|
||||
tf.get_variable('Conv/biases', initializer=tf.zeros([1]))
|
||||
tf.get_variable_scope().reuse_variables()
|
||||
|
||||
y1 = slim.conv2d(x, 1, [3, 3], stride=1, scope='Conv')
|
||||
y1_expected = tf.to_float([[14, 28, 43, 58, 34],
|
||||
[28, 48, 66, 84, 46],
|
||||
[43, 66, 84, 102, 55],
|
||||
[58, 84, 102, 120, 64],
|
||||
[34, 46, 55, 64, 30]])
|
||||
y1_expected = tf.reshape(y1_expected, [1, n, n, 1])
|
||||
|
||||
y2 = resnet_utils.subsample(y1, 2)
|
||||
y2_expected = tf.to_float([[14, 43, 34],
|
||||
[43, 84, 55],
|
||||
[34, 55, 30]])
|
||||
y2_expected = tf.reshape(y2_expected, [1, n2, n2, 1])
|
||||
|
||||
y3 = resnet_utils.conv2d_same(x, 1, 3, stride=2, scope='Conv')
|
||||
y3_expected = y2_expected
|
||||
|
||||
y4 = slim.conv2d(x, 1, [3, 3], stride=2, scope='Conv')
|
||||
y4_expected = y2_expected
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(tf.initialize_all_variables())
|
||||
self.assertAllClose(y1.eval(), y1_expected.eval())
|
||||
self.assertAllClose(y2.eval(), y2_expected.eval())
|
||||
self.assertAllClose(y3.eval(), y3_expected.eval())
|
||||
self.assertAllClose(y4.eval(), y4_expected.eval())
|
||||
|
||||
def _resnet_plain(self, inputs, blocks, output_stride=None, scope=None):
|
||||
"""A plain ResNet without extra layers before or after the ResNet blocks."""
|
||||
with tf.variable_scope(scope, values=[inputs]):
|
||||
with slim.arg_scope([slim.conv2d], outputs_collections='end_points'):
|
||||
net = resnet_utils.stack_blocks_dense(inputs, blocks, output_stride)
|
||||
end_points = dict(tf.get_collection('end_points'))
|
||||
return net, end_points
|
||||
|
||||
def testEndPointsV2(self):
|
||||
"""Test the end points of a tiny v2 bottleneck network."""
|
||||
bottleneck = resnet_v2.bottleneck
|
||||
blocks = [resnet_utils.Block('block1', bottleneck, [(4, 1, 1), (4, 1, 2)]),
|
||||
resnet_utils.Block('block2', bottleneck, [(8, 2, 1), (8, 2, 1)])]
|
||||
inputs = create_test_input(2, 32, 16, 3)
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
_, end_points = self._resnet_plain(inputs, blocks, scope='tiny')
|
||||
expected = [
|
||||
'tiny/block1/unit_1/bottleneck_v2/shortcut',
|
||||
'tiny/block1/unit_1/bottleneck_v2/conv1',
|
||||
'tiny/block1/unit_1/bottleneck_v2/conv2',
|
||||
'tiny/block1/unit_1/bottleneck_v2/conv3',
|
||||
'tiny/block1/unit_2/bottleneck_v2/conv1',
|
||||
'tiny/block1/unit_2/bottleneck_v2/conv2',
|
||||
'tiny/block1/unit_2/bottleneck_v2/conv3',
|
||||
'tiny/block2/unit_1/bottleneck_v2/shortcut',
|
||||
'tiny/block2/unit_1/bottleneck_v2/conv1',
|
||||
'tiny/block2/unit_1/bottleneck_v2/conv2',
|
||||
'tiny/block2/unit_1/bottleneck_v2/conv3',
|
||||
'tiny/block2/unit_2/bottleneck_v2/conv1',
|
||||
'tiny/block2/unit_2/bottleneck_v2/conv2',
|
||||
'tiny/block2/unit_2/bottleneck_v2/conv3']
|
||||
self.assertItemsEqual(expected, end_points)
|
||||
|
||||
def _stack_blocks_nondense(self, net, blocks):
|
||||
"""A simplified ResNet Block stacker without output stride control."""
|
||||
for block in blocks:
|
||||
with tf.variable_scope(block.scope, 'block', [net]):
|
||||
for i, unit in enumerate(block.args):
|
||||
depth, depth_bottleneck, stride = unit
|
||||
with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
|
||||
net = block.unit_fn(net,
|
||||
depth=depth,
|
||||
depth_bottleneck=depth_bottleneck,
|
||||
stride=stride,
|
||||
rate=1)
|
||||
return net
|
||||
|
||||
def _atrousValues(self, bottleneck):
|
||||
"""Verify the values of dense feature extraction by atrous convolution.
|
||||
|
||||
Make sure that dense feature extraction by stack_blocks_dense() followed by
|
||||
subsampling gives identical results to feature extraction at the nominal
|
||||
network output stride using the simple self._stack_blocks_nondense() above.
|
||||
|
||||
Args:
|
||||
bottleneck: The bottleneck function.
|
||||
"""
|
||||
blocks = [
|
||||
resnet_utils.Block('block1', bottleneck, [(4, 1, 1), (4, 1, 2)]),
|
||||
resnet_utils.Block('block2', bottleneck, [(8, 2, 1), (8, 2, 2)]),
|
||||
resnet_utils.Block('block3', bottleneck, [(16, 4, 1), (16, 4, 2)]),
|
||||
resnet_utils.Block('block4', bottleneck, [(32, 8, 1), (32, 8, 1)])
|
||||
]
|
||||
nominal_stride = 8
|
||||
|
||||
# Test both odd and even input dimensions.
|
||||
height = 30
|
||||
width = 31
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
with slim.arg_scope([slim.batch_norm], is_training=False):
|
||||
for output_stride in [1, 2, 4, 8, None]:
|
||||
with tf.Graph().as_default():
|
||||
with self.test_session() as sess:
|
||||
tf.set_random_seed(0)
|
||||
inputs = create_test_input(1, height, width, 3)
|
||||
# Dense feature extraction followed by subsampling.
|
||||
output = resnet_utils.stack_blocks_dense(inputs,
|
||||
blocks,
|
||||
output_stride)
|
||||
if output_stride is None:
|
||||
factor = 1
|
||||
else:
|
||||
factor = nominal_stride // output_stride
|
||||
|
||||
output = resnet_utils.subsample(output, factor)
|
||||
# Make the two networks use the same weights.
|
||||
tf.get_variable_scope().reuse_variables()
|
||||
# Feature extraction at the nominal network rate.
|
||||
expected = self._stack_blocks_nondense(inputs, blocks)
|
||||
sess.run(tf.initialize_all_variables())
|
||||
output, expected = sess.run([output, expected])
|
||||
self.assertAllClose(output, expected, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def testAtrousValuesBottleneck(self):
|
||||
self._atrousValues(resnet_v2.bottleneck)
|
||||
|
||||
|
||||
class ResnetCompleteNetworkTest(tf.test.TestCase):
|
||||
"""Tests with complete small ResNet v2 networks."""
|
||||
|
||||
def _resnet_small(self,
|
||||
inputs,
|
||||
num_classes=None,
|
||||
is_training=True,
|
||||
global_pool=True,
|
||||
output_stride=None,
|
||||
include_root_block=True,
|
||||
reuse=None,
|
||||
scope='resnet_v2_small'):
|
||||
"""A shallow and thin ResNet v2 for faster tests."""
|
||||
bottleneck = resnet_v2.bottleneck
|
||||
blocks = [
|
||||
resnet_utils.Block(
|
||||
'block1', bottleneck, [(4, 1, 1)] * 2 + [(4, 1, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block2', bottleneck, [(8, 2, 1)] * 2 + [(8, 2, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block3', bottleneck, [(16, 4, 1)] * 2 + [(16, 4, 2)]),
|
||||
resnet_utils.Block(
|
||||
'block4', bottleneck, [(32, 8, 1)] * 2)]
|
||||
return resnet_v2.resnet_v2(inputs, blocks, num_classes,
|
||||
is_training=is_training,
|
||||
global_pool=global_pool,
|
||||
output_stride=output_stride,
|
||||
include_root_block=include_root_block,
|
||||
reuse=reuse,
|
||||
scope=scope)
|
||||
|
||||
def testClassificationEndPoints(self):
|
||||
global_pool = True
|
||||
num_classes = 10
|
||||
inputs = create_test_input(2, 224, 224, 3)
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
logits, end_points = self._resnet_small(inputs, num_classes,
|
||||
global_pool=global_pool,
|
||||
scope='resnet')
|
||||
self.assertTrue(logits.op.name.startswith('resnet/logits'))
|
||||
self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
|
||||
self.assertTrue('predictions' in end_points)
|
||||
self.assertListEqual(end_points['predictions'].get_shape().as_list(),
|
||||
[2, 1, 1, num_classes])
|
||||
|
||||
def testClassificationShapes(self):
|
||||
global_pool = True
|
||||
num_classes = 10
|
||||
inputs = create_test_input(2, 224, 224, 3)
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
_, end_points = self._resnet_small(inputs, num_classes,
|
||||
global_pool=global_pool,
|
||||
scope='resnet')
|
||||
endpoint_to_shape = {
|
||||
'resnet/block1': [2, 28, 28, 4],
|
||||
'resnet/block2': [2, 14, 14, 8],
|
||||
'resnet/block3': [2, 7, 7, 16],
|
||||
'resnet/block4': [2, 7, 7, 32]}
|
||||
for endpoint in endpoint_to_shape:
|
||||
shape = endpoint_to_shape[endpoint]
|
||||
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
|
||||
|
||||
def testFullyConvolutionalEndpointShapes(self):
|
||||
global_pool = False
|
||||
num_classes = 10
|
||||
inputs = create_test_input(2, 321, 321, 3)
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
_, end_points = self._resnet_small(inputs, num_classes,
|
||||
global_pool=global_pool,
|
||||
scope='resnet')
|
||||
endpoint_to_shape = {
|
||||
'resnet/block1': [2, 41, 41, 4],
|
||||
'resnet/block2': [2, 21, 21, 8],
|
||||
'resnet/block3': [2, 11, 11, 16],
|
||||
'resnet/block4': [2, 11, 11, 32]}
|
||||
for endpoint in endpoint_to_shape:
|
||||
shape = endpoint_to_shape[endpoint]
|
||||
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
|
||||
|
||||
def testRootlessFullyConvolutionalEndpointShapes(self):
|
||||
global_pool = False
|
||||
num_classes = 10
|
||||
inputs = create_test_input(2, 128, 128, 3)
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
_, end_points = self._resnet_small(inputs, num_classes,
|
||||
global_pool=global_pool,
|
||||
include_root_block=False,
|
||||
scope='resnet')
|
||||
endpoint_to_shape = {
|
||||
'resnet/block1': [2, 64, 64, 4],
|
||||
'resnet/block2': [2, 32, 32, 8],
|
||||
'resnet/block3': [2, 16, 16, 16],
|
||||
'resnet/block4': [2, 16, 16, 32]}
|
||||
for endpoint in endpoint_to_shape:
|
||||
shape = endpoint_to_shape[endpoint]
|
||||
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
|
||||
|
||||
def testAtrousFullyConvolutionalEndpointShapes(self):
|
||||
global_pool = False
|
||||
num_classes = 10
|
||||
output_stride = 8
|
||||
inputs = create_test_input(2, 321, 321, 3)
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
_, end_points = self._resnet_small(inputs,
|
||||
num_classes,
|
||||
global_pool=global_pool,
|
||||
output_stride=output_stride,
|
||||
scope='resnet')
|
||||
endpoint_to_shape = {
|
||||
'resnet/block1': [2, 41, 41, 4],
|
||||
'resnet/block2': [2, 41, 41, 8],
|
||||
'resnet/block3': [2, 41, 41, 16],
|
||||
'resnet/block4': [2, 41, 41, 32]}
|
||||
for endpoint in endpoint_to_shape:
|
||||
shape = endpoint_to_shape[endpoint]
|
||||
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
|
||||
|
||||
def testAtrousFullyConvolutionalValues(self):
|
||||
"""Verify dense feature extraction with atrous convolution."""
|
||||
nominal_stride = 32
|
||||
for output_stride in [4, 8, 16, 32, None]:
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
with tf.Graph().as_default():
|
||||
with self.test_session() as sess:
|
||||
tf.set_random_seed(0)
|
||||
inputs = create_test_input(2, 81, 81, 3)
|
||||
# Dense feature extraction followed by subsampling.
|
||||
output, _ = self._resnet_small(inputs, None,
|
||||
is_training=False,
|
||||
global_pool=False,
|
||||
output_stride=output_stride)
|
||||
if output_stride is None:
|
||||
factor = 1
|
||||
else:
|
||||
factor = nominal_stride // output_stride
|
||||
output = resnet_utils.subsample(output, factor)
|
||||
# Make the two networks use the same weights.
|
||||
tf.get_variable_scope().reuse_variables()
|
||||
# Feature extraction at the nominal network rate.
|
||||
expected, _ = self._resnet_small(inputs, None,
|
||||
is_training=False,
|
||||
global_pool=False)
|
||||
sess.run(tf.initialize_all_variables())
|
||||
self.assertAllClose(output.eval(), expected.eval(),
|
||||
atol=1e-4, rtol=1e-4)
|
||||
|
||||
def testUnknownBatchSize(self):
|
||||
batch = 2
|
||||
height, width = 65, 65
|
||||
global_pool = True
|
||||
num_classes = 10
|
||||
inputs = create_test_input(None, height, width, 3)
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
logits, _ = self._resnet_small(inputs, num_classes,
|
||||
global_pool=global_pool,
|
||||
scope='resnet')
|
||||
self.assertTrue(logits.op.name.startswith('resnet/logits'))
|
||||
self.assertListEqual(logits.get_shape().as_list(),
|
||||
[None, 1, 1, num_classes])
|
||||
images = create_test_input(batch, height, width, 3)
|
||||
with self.test_session() as sess:
|
||||
sess.run(tf.initialize_all_variables())
|
||||
output = sess.run(logits, {inputs: images.eval()})
|
||||
self.assertEqual(output.shape, (batch, 1, 1, num_classes))
|
||||
|
||||
def testFullyConvolutionalUnknownHeightWidth(self):
|
||||
batch = 2
|
||||
height, width = 65, 65
|
||||
global_pool = False
|
||||
inputs = create_test_input(batch, None, None, 3)
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
output, _ = self._resnet_small(inputs, None,
|
||||
global_pool=global_pool)
|
||||
self.assertListEqual(output.get_shape().as_list(),
|
||||
[batch, None, None, 32])
|
||||
images = create_test_input(batch, height, width, 3)
|
||||
with self.test_session() as sess:
|
||||
sess.run(tf.initialize_all_variables())
|
||||
output = sess.run(output, {inputs: images.eval()})
|
||||
self.assertEqual(output.shape, (batch, 3, 3, 32))
|
||||
|
||||
def testAtrousFullyConvolutionalUnknownHeightWidth(self):
|
||||
batch = 2
|
||||
height, width = 65, 65
|
||||
global_pool = False
|
||||
output_stride = 8
|
||||
inputs = create_test_input(batch, None, None, 3)
|
||||
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
|
||||
output, _ = self._resnet_small(inputs,
|
||||
None,
|
||||
global_pool=global_pool,
|
||||
output_stride=output_stride)
|
||||
self.assertListEqual(output.get_shape().as_list(),
|
||||
[batch, None, None, 32])
|
||||
images = create_test_input(batch, height, width, 3)
|
||||
with self.test_session() as sess:
|
||||
sess.run(tf.initialize_all_variables())
|
||||
output = sess.run(output, {inputs: images.eval()})
|
||||
self.assertEqual(output.shape, (batch, 9, 9, 32))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -0,0 +1,244 @@
|
|||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Contains model definitions for versions of the Oxford VGG network.
|
||||
|
||||
These model definitions were introduced in the following technical report:
|
||||
|
||||
Very Deep Convolutional Networks For Large-Scale Image Recognition
|
||||
Karen Simonyan and Andrew Zisserman
|
||||
arXiv technical report, 2015
|
||||
PDF: http://arxiv.org/pdf/1409.1556.pdf
|
||||
ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf
|
||||
CC-BY-4.0
|
||||
|
||||
More information can be obtained from the VGG website:
|
||||
www.robots.ox.ac.uk/~vgg/research/very_deep/
|
||||
|
||||
Usage:
|
||||
with slim.arg_scope(vgg.vgg_arg_scope()):
|
||||
outputs, end_points = vgg.vgg_a(inputs)
|
||||
|
||||
with slim.arg_scope(vgg.vgg_arg_scope()):
|
||||
outputs, end_points = vgg.vgg_16(inputs)
|
||||
|
||||
@@vgg_a
|
||||
@@vgg_16
|
||||
@@vgg_19
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
slim = tf.contrib.slim
|
||||
|
||||
|
||||
def vgg_arg_scope(weight_decay=0.0005):
|
||||
"""Defines the VGG arg scope.
|
||||
|
||||
Args:
|
||||
weight_decay: The l2 regularization coefficient.
|
||||
|
||||
Returns:
|
||||
An arg_scope.
|
||||
"""
|
||||
with slim.arg_scope([slim.conv2d, slim.fully_connected],
|
||||
activation_fn=tf.nn.relu,
|
||||
weights_regularizer=slim.l2_regularizer(weight_decay),
|
||||
biases_initializer=tf.zeros_initializer):
|
||||
with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc:
|
||||
return arg_sc
|
||||
|
||||
|
||||
def vgg_a(inputs,
|
||||
num_classes=1000,
|
||||
is_training=True,
|
||||
dropout_keep_prob=0.5,
|
||||
spatial_squeeze=True,
|
||||
scope='vgg_a'):
|
||||
"""Oxford Net VGG 11-Layers version A Example.
|
||||
|
||||
Note: All the fully_connected layers have been transformed to conv2d layers.
|
||||
To use in classification mode, resize input to 224x224.
|
||||
|
||||
Args:
|
||||
inputs: a tensor of size [batch_size, height, width, channels].
|
||||
num_classes: number of predicted classes.
|
||||
is_training: whether or not the model is being trained.
|
||||
dropout_keep_prob: the probability that activations are kept in the dropout
|
||||
layers during training.
|
||||
spatial_squeeze: whether or not should squeeze the spatial dimensions of the
|
||||
outputs. Useful to remove unnecessary dimensions for classification.
|
||||
scope: Optional scope for the variables.
|
||||
|
||||
Returns:
|
||||
the last op containing the log predictions and end_points dict.
|
||||
"""
|
||||
with tf.variable_scope(scope, 'vgg_a', [inputs]) as sc:
|
||||
end_points_collection = sc.name + '_end_points'
|
||||
# Collect outputs for conv2d, fully_connected and max_pool2d.
|
||||
with slim.arg_scope([slim.conv2d, slim.max_pool2d],
|
||||
outputs_collections=end_points_collection):
|
||||
net = slim.repeat(inputs, 1, slim.conv2d, 64, [3, 3], scope='conv1')
|
||||
net = slim.max_pool2d(net, [2, 2], scope='pool1')
|
||||
net = slim.repeat(net, 1, slim.conv2d, 128, [3, 3], scope='conv2')
|
||||
net = slim.max_pool2d(net, [2, 2], scope='pool2')
|
||||
net = slim.repeat(net, 2, slim.conv2d, 256, [3, 3], scope='conv3')
|
||||
net = slim.max_pool2d(net, [2, 2], scope='pool3')
|
||||
net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv4')
|
||||
net = slim.max_pool2d(net, [2, 2], scope='pool4')
|
||||
net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv5')
|
||||
net = slim.max_pool2d(net, [2, 2], scope='pool5')
|
||||
# Use conv2d instead of fully_connected layers.
|
||||
net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6')
|
||||
net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
|
||||
scope='dropout6')
|
||||
net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
|
||||
net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
|
||||
scope='dropout7')
|
||||
net = slim.conv2d(net, num_classes, [1, 1],
|
||||
activation_fn=None,
|
||||
normalizer_fn=None,
|
||||
scope='fc8')
|
||||
# Convert end_points_collection into a end_point dict.
|
||||
end_points = slim.utils.convert_collection_to_dict(end_points_collection)
|
||||
if spatial_squeeze:
|
||||
net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
|
||||
end_points[sc.name + '/fc8'] = net
|
||||
return net, end_points
|
||||
vgg_a.default_image_size = 224
|
||||
|
||||
|
||||
def vgg_16(inputs,
|
||||
num_classes=1000,
|
||||
is_training=True,
|
||||
dropout_keep_prob=0.5,
|
||||
spatial_squeeze=True,
|
||||
scope='vgg_16'):
|
||||
"""Oxford Net VGG 16-Layers version D Example.
|
||||
|
||||
Note: All the fully_connected layers have been transformed to conv2d layers.
|
||||
To use in classification mode, resize input to 224x224.
|
||||
|
||||
Args:
|
||||
inputs: a tensor of size [batch_size, height, width, channels].
|
||||
num_classes: number of predicted classes.
|
||||
is_training: whether or not the model is being trained.
|
||||
dropout_keep_prob: the probability that activations are kept in the dropout
|
||||
layers during training.
|
||||
spatial_squeeze: whether or not should squeeze the spatial dimensions of the
|
||||
outputs. Useful to remove unnecessary dimensions for classification.
|
||||
scope: Optional scope for the variables.
|
||||
|
||||
Returns:
|
||||
the last op containing the log predictions and end_points dict.
|
||||
"""
|
||||
with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc:
|
||||
end_points_collection = sc.name + '_end_points'
|
||||
# Collect outputs for conv2d, fully_connected and max_pool2d.
|
||||
with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
|
||||
outputs_collections=end_points_collection):
|
||||
net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
|
||||
net = slim.max_pool2d(net, [2, 2], scope='pool1')
|
||||
net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
|
||||
net = slim.max_pool2d(net, [2, 2], scope='pool2')
|
||||
net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
|
||||
net = slim.max_pool2d(net, [2, 2], scope='pool3')
|
||||
net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
|
||||
net = slim.max_pool2d(net, [2, 2], scope='pool4')
|
||||
net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')
|
||||
net = slim.max_pool2d(net, [2, 2], scope='pool5')
|
||||
# Use conv2d instead of fully_connected layers.
|
||||
net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6')
|
||||
net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
|
||||
scope='dropout6')
|
||||
net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
|
||||
net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
|
||||
scope='dropout7')
|
||||
net = slim.conv2d(net, num_classes, [1, 1],
|
||||
activation_fn=None,
|
||||
normalizer_fn=None,
|
||||
scope='fc8')
|
||||
# Convert end_points_collection into a end_point dict.
|
||||
end_points = slim.utils.convert_collection_to_dict(end_points_collection)
|
||||
if spatial_squeeze:
|
||||
net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
|
||||
end_points[sc.name + '/fc8'] = net
|
||||
return net, end_points
|
||||
vgg_16.default_image_size = 224
|
||||
|
||||
|
||||
def vgg_19(inputs,
|
||||
num_classes=1000,
|
||||
is_training=True,
|
||||
dropout_keep_prob=0.5,
|
||||
spatial_squeeze=True,
|
||||
scope='vgg_19'):
|
||||
"""Oxford Net VGG 19-Layers version E Example.
|
||||
|
||||
Note: All the fully_connected layers have been transformed to conv2d layers.
|
||||
To use in classification mode, resize input to 224x224.
|
||||
|
||||
Args:
|
||||
inputs: a tensor of size [batch_size, height, width, channels].
|
||||
num_classes: number of predicted classes.
|
||||
is_training: whether or not the model is being trained.
|
||||
dropout_keep_prob: the probability that activations are kept in the dropout
|
||||
layers during training.
|
||||
spatial_squeeze: whether or not should squeeze the spatial dimensions of the
|
||||
outputs. Useful to remove unnecessary dimensions for classification.
|
||||
scope: Optional scope for the variables.
|
||||
|
||||
Returns:
|
||||
the last op containing the log predictions and end_points dict.
|
||||
"""
|
||||
with tf.variable_scope(scope, 'vgg_19', [inputs]) as sc:
|
||||
end_points_collection = sc.name + '_end_points'
|
||||
# Collect outputs for conv2d, fully_connected and max_pool2d.
|
||||
with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
|
||||
outputs_collections=end_points_collection):
|
||||
net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
|
||||
net = slim.max_pool2d(net, [2, 2], scope='pool1')
|
||||
net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
|
||||
net = slim.max_pool2d(net, [2, 2], scope='pool2')
|
||||
net = slim.repeat(net, 4, slim.conv2d, 256, [3, 3], scope='conv3')
|
||||
net = slim.max_pool2d(net, [2, 2], scope='pool3')
|
||||
net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv4')
|
||||
net = slim.max_pool2d(net, [2, 2], scope='pool4')
|
||||
net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv5')
|
||||
net = slim.max_pool2d(net, [2, 2], scope='pool5')
|
||||
# Use conv2d instead of fully_connected layers.
|
||||
net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6')
|
||||
net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
|
||||
scope='dropout6')
|
||||
net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
|
||||
net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
|
||||
scope='dropout7')
|
||||
net = slim.conv2d(net, num_classes, [1, 1],
|
||||
activation_fn=None,
|
||||
normalizer_fn=None,
|
||||
scope='fc8')
|
||||
# Convert end_points_collection into a end_point dict.
|
||||
end_points = slim.utils.convert_collection_to_dict(end_points_collection)
|
||||
if spatial_squeeze:
|
||||
net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
|
||||
end_points[sc.name + '/fc8'] = net
|
||||
return net, end_points
|
||||
vgg_19.default_image_size = 224
|
||||
|
||||
# Alias
|
||||
vgg_d = vgg_16
|
||||
vgg_e = vgg_19
|
|
@ -0,0 +1,455 @@
|
|||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for slim.nets.vgg."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from nets import vgg
|
||||
|
||||
slim = tf.contrib.slim
|
||||
|
||||
|
||||
class VGGATest(tf.test.TestCase):
|
||||
|
||||
def testBuild(self):
|
||||
batch_size = 5
|
||||
height, width = 224, 224
|
||||
num_classes = 1000
|
||||
with self.test_session():
|
||||
inputs = tf.random_uniform((batch_size, height, width, 3))
|
||||
logits, _ = vgg.vgg_a(inputs, num_classes)
|
||||
self.assertEquals(logits.op.name, 'vgg_a/fc8/squeezed')
|
||||
self.assertListEqual(logits.get_shape().as_list(),
|
||||
[batch_size, num_classes])
|
||||
|
||||
def testFullyConvolutional(self):
|
||||
batch_size = 1
|
||||
height, width = 256, 256
|
||||
num_classes = 1000
|
||||
with self.test_session():
|
||||
inputs = tf.random_uniform((batch_size, height, width, 3))
|
||||
logits, _ = vgg.vgg_a(inputs, num_classes, spatial_squeeze=False)
|
||||
self.assertEquals(logits.op.name, 'vgg_a/fc8/BiasAdd')
|
||||
self.assertListEqual(logits.get_shape().as_list(),
|
||||
[batch_size, 2, 2, num_classes])
|
||||
|
||||
def testEndPoints(self):
|
||||
batch_size = 5
|
||||
height, width = 224, 224
|
||||
num_classes = 1000
|
||||
with self.test_session():
|
||||
inputs = tf.random_uniform((batch_size, height, width, 3))
|
||||
_, end_points = vgg.vgg_a(inputs, num_classes)
|
||||
expected_names = ['vgg_a/conv1/conv1_1',
|
||||
'vgg_a/pool1',
|
||||
'vgg_a/conv2/conv2_1',
|
||||
'vgg_a/pool2',
|
||||
'vgg_a/conv3/conv3_1',
|
||||
'vgg_a/conv3/conv3_2',
|
||||
'vgg_a/pool3',
|
||||
'vgg_a/conv4/conv4_1',
|
||||
'vgg_a/conv4/conv4_2',
|
||||
'vgg_a/pool4',
|
||||
'vgg_a/conv5/conv5_1',
|
||||
'vgg_a/conv5/conv5_2',
|
||||
'vgg_a/pool5',
|
||||
'vgg_a/fc6',
|
||||
'vgg_a/fc7',
|
||||
'vgg_a/fc8'
|
||||
]
|
||||
self.assertSetEqual(set(end_points.keys()), set(expected_names))
|
||||
|
||||
def testModelVariables(self):
|
||||
batch_size = 5
|
||||
height, width = 224, 224
|
||||
num_classes = 1000
|
||||
with self.test_session():
|
||||
inputs = tf.random_uniform((batch_size, height, width, 3))
|
||||
vgg.vgg_a(inputs, num_classes)
|
||||
expected_names = ['vgg_a/conv1/conv1_1/weights',
|
||||
'vgg_a/conv1/conv1_1/biases',
|
||||
'vgg_a/conv2/conv2_1/weights',
|
||||
'vgg_a/conv2/conv2_1/biases',
|
||||
'vgg_a/conv3/conv3_1/weights',
|
||||
'vgg_a/conv3/conv3_1/biases',
|
||||
'vgg_a/conv3/conv3_2/weights',
|
||||
'vgg_a/conv3/conv3_2/biases',
|
||||
'vgg_a/conv4/conv4_1/weights',
|
||||
'vgg_a/conv4/conv4_1/biases',
|
||||
'vgg_a/conv4/conv4_2/weights',
|
||||
'vgg_a/conv4/conv4_2/biases',
|
||||
'vgg_a/conv5/conv5_1/weights',
|
||||
'vgg_a/conv5/conv5_1/biases',
|
||||
'vgg_a/conv5/conv5_2/weights',
|
||||
'vgg_a/conv5/conv5_2/biases',
|
||||
'vgg_a/fc6/weights',
|
||||
'vgg_a/fc6/biases',
|
||||
'vgg_a/fc7/weights',
|
||||
'vgg_a/fc7/biases',
|
||||
'vgg_a/fc8/weights',
|
||||
'vgg_a/fc8/biases',
|
||||
]
|
||||
model_variables = [v.op.name for v in slim.get_model_variables()]
|
||||
self.assertSetEqual(set(model_variables), set(expected_names))
|
||||
|
||||
def testEvaluation(self):
|
||||
batch_size = 2
|
||||
height, width = 224, 224
|
||||
num_classes = 1000
|
||||
with self.test_session():
|
||||
eval_inputs = tf.random_uniform((batch_size, height, width, 3))
|
||||
logits, _ = vgg.vgg_a(eval_inputs, is_training=False)
|
||||
self.assertListEqual(logits.get_shape().as_list(),
|
||||
[batch_size, num_classes])
|
||||
predictions = tf.argmax(logits, 1)
|
||||
self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
|
||||
|
||||
def testTrainEvalWithReuse(self):
|
||||
train_batch_size = 2
|
||||
eval_batch_size = 1
|
||||
train_height, train_width = 224, 224
|
||||
eval_height, eval_width = 256, 256
|
||||
num_classes = 1000
|
||||
with self.test_session():
|
||||
train_inputs = tf.random_uniform(
|
||||
(train_batch_size, train_height, train_width, 3))
|
||||
logits, _ = vgg.vgg_a(train_inputs)
|
||||
self.assertListEqual(logits.get_shape().as_list(),
|
||||
[train_batch_size, num_classes])
|
||||
tf.get_variable_scope().reuse_variables()
|
||||
eval_inputs = tf.random_uniform(
|
||||
(eval_batch_size, eval_height, eval_width, 3))
|
||||
logits, _ = vgg.vgg_a(eval_inputs, is_training=False,
|
||||
spatial_squeeze=False)
|
||||
self.assertListEqual(logits.get_shape().as_list(),
|
||||
[eval_batch_size, 2, 2, num_classes])
|
||||
logits = tf.reduce_mean(logits, [1, 2])
|
||||
predictions = tf.argmax(logits, 1)
|
||||
self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
|
||||
|
||||
def testForward(self):
|
||||
batch_size = 1
|
||||
height, width = 224, 224
|
||||
with self.test_session() as sess:
|
||||
inputs = tf.random_uniform((batch_size, height, width, 3))
|
||||
logits, _ = vgg.vgg_a(inputs)
|
||||
sess.run(tf.initialize_all_variables())
|
||||
output = sess.run(logits)
|
||||
self.assertTrue(output.any())
|
||||
|
||||
|
||||
class VGG16Test(tf.test.TestCase):
|
||||
|
||||
def testBuild(self):
|
||||
batch_size = 5
|
||||
height, width = 224, 224
|
||||
num_classes = 1000
|
||||
with self.test_session():
|
||||
inputs = tf.random_uniform((batch_size, height, width, 3))
|
||||
logits, _ = vgg.vgg_16(inputs, num_classes)
|
||||
self.assertEquals(logits.op.name, 'vgg_16/fc8/squeezed')
|
||||
self.assertListEqual(logits.get_shape().as_list(),
|
||||
[batch_size, num_classes])
|
||||
|
||||
def testFullyConvolutional(self):
|
||||
batch_size = 1
|
||||
height, width = 256, 256
|
||||
num_classes = 1000
|
||||
with self.test_session():
|
||||
inputs = tf.random_uniform((batch_size, height, width, 3))
|
||||
logits, _ = vgg.vgg_16(inputs, num_classes, spatial_squeeze=False)
|
||||
self.assertEquals(logits.op.name, 'vgg_16/fc8/BiasAdd')
|
||||
self.assertListEqual(logits.get_shape().as_list(),
|
||||
[batch_size, 2, 2, num_classes])
|
||||
|
||||
def testEndPoints(self):
|
||||
batch_size = 5
|
||||
height, width = 224, 224
|
||||
num_classes = 1000
|
||||
with self.test_session():
|
||||
inputs = tf.random_uniform((batch_size, height, width, 3))
|
||||
_, end_points = vgg.vgg_16(inputs, num_classes)
|
||||
expected_names = ['vgg_16/conv1/conv1_1',
|
||||
'vgg_16/conv1/conv1_2',
|
||||
'vgg_16/pool1',
|
||||
'vgg_16/conv2/conv2_1',
|
||||
'vgg_16/conv2/conv2_2',
|
||||
'vgg_16/pool2',
|
||||
'vgg_16/conv3/conv3_1',
|
||||
'vgg_16/conv3/conv3_2',
|
||||
'vgg_16/conv3/conv3_3',
|
||||
'vgg_16/pool3',
|
||||
'vgg_16/conv4/conv4_1',
|
||||
'vgg_16/conv4/conv4_2',
|
||||
'vgg_16/conv4/conv4_3',
|
||||
'vgg_16/pool4',
|
||||
'vgg_16/conv5/conv5_1',
|
||||
'vgg_16/conv5/conv5_2',
|
||||
'vgg_16/conv5/conv5_3',
|
||||
'vgg_16/pool5',
|
||||
'vgg_16/fc6',
|
||||
'vgg_16/fc7',
|
||||
'vgg_16/fc8'
|
||||
]
|
||||
self.assertSetEqual(set(end_points.keys()), set(expected_names))
|
||||
|
||||
def testModelVariables(self):
|
||||
batch_size = 5
|
||||
height, width = 224, 224
|
||||
num_classes = 1000
|
||||
with self.test_session():
|
||||
inputs = tf.random_uniform((batch_size, height, width, 3))
|
||||
vgg.vgg_16(inputs, num_classes)
|
||||
expected_names = ['vgg_16/conv1/conv1_1/weights',
|
||||
'vgg_16/conv1/conv1_1/biases',
|
||||
'vgg_16/conv1/conv1_2/weights',
|
||||
'vgg_16/conv1/conv1_2/biases',
|
||||
'vgg_16/conv2/conv2_1/weights',
|
||||
'vgg_16/conv2/conv2_1/biases',
|
||||
'vgg_16/conv2/conv2_2/weights',
|
||||
'vgg_16/conv2/conv2_2/biases',
|
||||
'vgg_16/conv3/conv3_1/weights',
|
||||
'vgg_16/conv3/conv3_1/biases',
|
||||
'vgg_16/conv3/conv3_2/weights',
|
||||
'vgg_16/conv3/conv3_2/biases',
|
||||
'vgg_16/conv3/conv3_3/weights',
|
||||
'vgg_16/conv3/conv3_3/biases',
|
||||
'vgg_16/conv4/conv4_1/weights',
|
||||
'vgg_16/conv4/conv4_1/biases',
|
||||
'vgg_16/conv4/conv4_2/weights',
|
||||
'vgg_16/conv4/conv4_2/biases',
|
||||
'vgg_16/conv4/conv4_3/weights',
|
||||
'vgg_16/conv4/conv4_3/biases',
|
||||
'vgg_16/conv5/conv5_1/weights',
|
||||
'vgg_16/conv5/conv5_1/biases',
|
||||
'vgg_16/conv5/conv5_2/weights',
|
||||
'vgg_16/conv5/conv5_2/biases',
|
||||
'vgg_16/conv5/conv5_3/weights',
|
||||
'vgg_16/conv5/conv5_3/biases',
|
||||
'vgg_16/fc6/weights',
|
||||
'vgg_16/fc6/biases',
|
||||
'vgg_16/fc7/weights',
|
||||
'vgg_16/fc7/biases',
|
||||
'vgg_16/fc8/weights',
|
||||
'vgg_16/fc8/biases',
|
||||
]
|
||||
model_variables = [v.op.name for v in slim.get_model_variables()]
|
||||
self.assertSetEqual(set(model_variables), set(expected_names))
|
||||
|
||||
def testEvaluation(self):
|
||||
batch_size = 2
|
||||
height, width = 224, 224
|
||||
num_classes = 1000
|
||||
with self.test_session():
|
||||
eval_inputs = tf.random_uniform((batch_size, height, width, 3))
|
||||
logits, _ = vgg.vgg_16(eval_inputs, is_training=False)
|
||||
self.assertListEqual(logits.get_shape().as_list(),
|
||||
[batch_size, num_classes])
|
||||
predictions = tf.argmax(logits, 1)
|
||||
self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
|
||||
|
||||
def testTrainEvalWithReuse(self):
|
||||
train_batch_size = 2
|
||||
eval_batch_size = 1
|
||||
train_height, train_width = 224, 224
|
||||
eval_height, eval_width = 256, 256
|
||||
num_classes = 1000
|
||||
with self.test_session():
|
||||
train_inputs = tf.random_uniform(
|
||||
(train_batch_size, train_height, train_width, 3))
|
||||
logits, _ = vgg.vgg_16(train_inputs)
|
||||
self.assertListEqual(logits.get_shape().as_list(),
|
||||
[train_batch_size, num_classes])
|
||||
tf.get_variable_scope().reuse_variables()
|
||||
eval_inputs = tf.random_uniform(
|
||||
(eval_batch_size, eval_height, eval_width, 3))
|
||||
logits, _ = vgg.vgg_16(eval_inputs, is_training=False,
|
||||
spatial_squeeze=False)
|
||||
self.assertListEqual(logits.get_shape().as_list(),
|
||||
[eval_batch_size, 2, 2, num_classes])
|
||||
logits = tf.reduce_mean(logits, [1, 2])
|
||||
predictions = tf.argmax(logits, 1)
|
||||
self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
|
||||
|
||||
def testForward(self):
|
||||
batch_size = 1
|
||||
height, width = 224, 224
|
||||
with self.test_session() as sess:
|
||||
inputs = tf.random_uniform((batch_size, height, width, 3))
|
||||
logits, _ = vgg.vgg_16(inputs)
|
||||
sess.run(tf.initialize_all_variables())
|
||||
output = sess.run(logits)
|
||||
self.assertTrue(output.any())
|
||||
|
||||
|
||||
class VGG19Test(tf.test.TestCase):
|
||||
|
||||
def testBuild(self):
|
||||
batch_size = 5
|
||||
height, width = 224, 224
|
||||
num_classes = 1000
|
||||
with self.test_session():
|
||||
inputs = tf.random_uniform((batch_size, height, width, 3))
|
||||
logits, _ = vgg.vgg_19(inputs, num_classes)
|
||||
self.assertEquals(logits.op.name, 'vgg_19/fc8/squeezed')
|
||||
self.assertListEqual(logits.get_shape().as_list(),
|
||||
[batch_size, num_classes])
|
||||
|
||||
def testFullyConvolutional(self):
|
||||
batch_size = 1
|
||||
height, width = 256, 256
|
||||
num_classes = 1000
|
||||
with self.test_session():
|
||||
inputs = tf.random_uniform((batch_size, height, width, 3))
|
||||
logits, _ = vgg.vgg_19(inputs, num_classes, spatial_squeeze=False)
|
||||
self.assertEquals(logits.op.name, 'vgg_19/fc8/BiasAdd')
|
||||
self.assertListEqual(logits.get_shape().as_list(),
|
||||
[batch_size, 2, 2, num_classes])
|
||||
|
||||
def testEndPoints(self):
|
||||
batch_size = 5
|
||||
height, width = 224, 224
|
||||
num_classes = 1000
|
||||
with self.test_session():
|
||||
inputs = tf.random_uniform((batch_size, height, width, 3))
|
||||
_, end_points = vgg.vgg_19(inputs, num_classes)
|
||||
expected_names = [
|
||||
'vgg_19/conv1/conv1_1',
|
||||
'vgg_19/conv1/conv1_2',
|
||||
'vgg_19/pool1',
|
||||
'vgg_19/conv2/conv2_1',
|
||||
'vgg_19/conv2/conv2_2',
|
||||
'vgg_19/pool2',
|
||||
'vgg_19/conv3/conv3_1',
|
||||
'vgg_19/conv3/conv3_2',
|
||||
'vgg_19/conv3/conv3_3',
|
||||
'vgg_19/conv3/conv3_4',
|
||||
'vgg_19/pool3',
|
||||
'vgg_19/conv4/conv4_1',
|
||||
'vgg_19/conv4/conv4_2',
|
||||
'vgg_19/conv4/conv4_3',
|
||||
'vgg_19/conv4/conv4_4',
|
||||
'vgg_19/pool4',
|
||||
'vgg_19/conv5/conv5_1',
|
||||
'vgg_19/conv5/conv5_2',
|
||||
'vgg_19/conv5/conv5_3',
|
||||
'vgg_19/conv5/conv5_4',
|
||||
'vgg_19/pool5',
|
||||
'vgg_19/fc6',
|
||||
'vgg_19/fc7',
|
||||
'vgg_19/fc8'
|
||||
]
|
||||
self.assertSetEqual(set(end_points.keys()), set(expected_names))
|
||||
|
||||
def testModelVariables(self):
|
||||
batch_size = 5
|
||||
height, width = 224, 224
|
||||
num_classes = 1000
|
||||
with self.test_session():
|
||||
inputs = tf.random_uniform((batch_size, height, width, 3))
|
||||
vgg.vgg_19(inputs, num_classes)
|
||||
expected_names = [
|
||||
'vgg_19/conv1/conv1_1/weights',
|
||||
'vgg_19/conv1/conv1_1/biases',
|
||||
'vgg_19/conv1/conv1_2/weights',
|
||||
'vgg_19/conv1/conv1_2/biases',
|
||||
'vgg_19/conv2/conv2_1/weights',
|
||||
'vgg_19/conv2/conv2_1/biases',
|
||||
'vgg_19/conv2/conv2_2/weights',
|
||||
'vgg_19/conv2/conv2_2/biases',
|
||||
'vgg_19/conv3/conv3_1/weights',
|
||||
'vgg_19/conv3/conv3_1/biases',
|
||||
'vgg_19/conv3/conv3_2/weights',
|
||||
'vgg_19/conv3/conv3_2/biases',
|
||||
'vgg_19/conv3/conv3_3/weights',
|
||||
'vgg_19/conv3/conv3_3/biases',
|
||||
'vgg_19/conv3/conv3_4/weights',
|
||||
'vgg_19/conv3/conv3_4/biases',
|
||||
'vgg_19/conv4/conv4_1/weights',
|
||||
'vgg_19/conv4/conv4_1/biases',
|
||||
'vgg_19/conv4/conv4_2/weights',
|
||||
'vgg_19/conv4/conv4_2/biases',
|
||||
'vgg_19/conv4/conv4_3/weights',
|
||||
'vgg_19/conv4/conv4_3/biases',
|
||||
'vgg_19/conv4/conv4_4/weights',
|
||||
'vgg_19/conv4/conv4_4/biases',
|
||||
'vgg_19/conv5/conv5_1/weights',
|
||||
'vgg_19/conv5/conv5_1/biases',
|
||||
'vgg_19/conv5/conv5_2/weights',
|
||||
'vgg_19/conv5/conv5_2/biases',
|
||||
'vgg_19/conv5/conv5_3/weights',
|
||||
'vgg_19/conv5/conv5_3/biases',
|
||||
'vgg_19/conv5/conv5_4/weights',
|
||||
'vgg_19/conv5/conv5_4/biases',
|
||||
'vgg_19/fc6/weights',
|
||||
'vgg_19/fc6/biases',
|
||||
'vgg_19/fc7/weights',
|
||||
'vgg_19/fc7/biases',
|
||||
'vgg_19/fc8/weights',
|
||||
'vgg_19/fc8/biases',
|
||||
]
|
||||
model_variables = [v.op.name for v in slim.get_model_variables()]
|
||||
self.assertSetEqual(set(model_variables), set(expected_names))
|
||||
|
||||
def testEvaluation(self):
|
||||
batch_size = 2
|
||||
height, width = 224, 224
|
||||
num_classes = 1000
|
||||
with self.test_session():
|
||||
eval_inputs = tf.random_uniform((batch_size, height, width, 3))
|
||||
logits, _ = vgg.vgg_19(eval_inputs, is_training=False)
|
||||
self.assertListEqual(logits.get_shape().as_list(),
|
||||
[batch_size, num_classes])
|
||||
predictions = tf.argmax(logits, 1)
|
||||
self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
|
||||
|
||||
def testTrainEvalWithReuse(self):
|
||||
train_batch_size = 2
|
||||
eval_batch_size = 1
|
||||
train_height, train_width = 224, 224
|
||||
eval_height, eval_width = 256, 256
|
||||
num_classes = 1000
|
||||
with self.test_session():
|
||||
train_inputs = tf.random_uniform(
|
||||
(train_batch_size, train_height, train_width, 3))
|
||||
logits, _ = vgg.vgg_19(train_inputs)
|
||||
self.assertListEqual(logits.get_shape().as_list(),
|
||||
[train_batch_size, num_classes])
|
||||
tf.get_variable_scope().reuse_variables()
|
||||
eval_inputs = tf.random_uniform(
|
||||
(eval_batch_size, eval_height, eval_width, 3))
|
||||
logits, _ = vgg.vgg_19(eval_inputs, is_training=False,
|
||||
spatial_squeeze=False)
|
||||
self.assertListEqual(logits.get_shape().as_list(),
|
||||
[eval_batch_size, 2, 2, num_classes])
|
||||
logits = tf.reduce_mean(logits, [1, 2])
|
||||
predictions = tf.argmax(logits, 1)
|
||||
self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
|
||||
|
||||
def testForward(self):
|
||||
batch_size = 1
|
||||
height, width = 224, 224
|
||||
with self.test_session() as sess:
|
||||
inputs = tf.random_uniform((batch_size, height, width, 3))
|
||||
logits, _ = vgg.vgg_19(inputs)
|
||||
sess.run(tf.initialize_all_variables())
|
||||
output = sess.run(logits)
|
||||
self.assertTrue(output.any())
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -0,0 +1,291 @@
|
|||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
# Modified 2017 Microsoft Corporation.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Generic training script that trains a model using a given dataset."""
|
||||
|
||||
import tensorflow as tf
|
||||
import pandas as pd
|
||||
import os
|
||||
import functools
|
||||
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from deployment import model_deploy
|
||||
from nets import resnet_v1 # Needed to be modified, see https://github.com/tensorflow/models/issues/533
|
||||
from tensorflow.contrib.training.python.training import evaluation
|
||||
|
||||
slim = tf.contrib.slim
|
||||
|
||||
''' Enumerate the flags '''
|
||||
tf.app.flags.DEFINE_string('train_dir',
|
||||
'D:\\tf\\models',
|
||||
'Directory where checkpoints and event logs are written to.')
|
||||
tf.app.flags.DEFINE_string('dataset_name', 'aerial', 'The name of the dataset to load.')
|
||||
tf.app.flags.DEFINE_string('dataset_dir',
|
||||
'D:\\combined\\train_subsample',
|
||||
'The directory where the dataset files are stored.')
|
||||
tf.app.flags.DEFINE_string('checkpoint_path',
|
||||
'D:\\tf\\resnet_v1_152.ckpt',
|
||||
'The path to a checkpoint from which to fine-tune.')
|
||||
|
||||
tf.app.flags.DEFINE_string('checkpoint_exclude_scopes', 'resnet_v1_152/logits',
|
||||
'Comma-separated list of scopes of variables to exclude when restoring '
|
||||
'from a checkpoint.')
|
||||
tf.app.flags.DEFINE_string('trainable_scopes', 'resnet_v1_152/logits',
|
||||
'Comma-separated list of scopes to filter the set of variables to train.'
|
||||
'By default, None would train all the variables.')
|
||||
|
||||
tf.app.flags.DEFINE_integer('num_clones', 1, 'Number of model clones to deploy.')
|
||||
tf.app.flags.DEFINE_boolean('clone_on_cpu', False, 'Use CPUs to deploy clones.')
|
||||
tf.app.flags.DEFINE_integer('num_readers', 4, 'The number of parallel readers that read data from the dataset.')
|
||||
tf.app.flags.DEFINE_integer('num_preprocessing_threads', 4, 'The number of threads used to create the batches.')
|
||||
tf.app.flags.DEFINE_integer('log_every_n_steps', 10, 'The frequency with which logs are printed.')
|
||||
tf.app.flags.DEFINE_integer('save_summaries_secs', 600, 'The frequency with which summaries are saved, in seconds.')
|
||||
tf.app.flags.DEFINE_integer('save_interval_secs', 600, 'The frequency with which the model is saved, in seconds.')
|
||||
|
||||
tf.app.flags.DEFINE_float('weight_decay', 0.00004, 'The weight decay on the model weights.')
|
||||
tf.app.flags.DEFINE_float('opt_epsilon', 1.0, 'Epsilon term for the optimizer.')
|
||||
tf.app.flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.')
|
||||
tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')
|
||||
tf.app.flags.DEFINE_float('learning_rate', 0.05, 'Initial learning rate.')
|
||||
tf.app.flags.DEFINE_float('label_smoothing', 0.0, 'The amount of label smoothing.')
|
||||
tf.app.flags.DEFINE_float('learning_rate_decay_factor', 0.9, 'Learning rate decay factor.')
|
||||
tf.app.flags.DEFINE_float('num_epochs_per_decay', 2.0, 'Number of epochs after which learning rate decays.')
|
||||
tf.app.flags.DEFINE_integer('replicas_to_aggregate', 1, 'The number of gradients to collect before updating params.')
|
||||
tf.app.flags.DEFINE_integer('batch_size', 32, 'The number of samples in each batch.')
|
||||
tf.app.flags.DEFINE_integer('max_number_of_steps', 4000, 'The maximum number of training steps.')
|
||||
|
||||
FLAGS = tf.app.flags.FLAGS
|
||||
|
||||
def get_image_and_class_count(dataset_dir, split_name):
|
||||
df = pd.read_csv(os.path.join(dataset_dir, 'dataset_split_info.csv'))
|
||||
image_count = len(df.loc[df['split_name'] == split_name].index)
|
||||
class_count = len(df['class_name'].unique())
|
||||
return(image_count, class_count)
|
||||
|
||||
def read_label_file(dataset_dir, filename='labels.txt'):
|
||||
labels_filename = os.path.join(dataset_dir, filename)
|
||||
with tf.gfile.Open(labels_filename, 'r') as f:
|
||||
lines = f.read().decode()
|
||||
lines = lines.split('\n')
|
||||
lines = filter(None, lines)
|
||||
|
||||
labels_to_class_names = {}
|
||||
for line in lines:
|
||||
index = line.index(':')
|
||||
labels_to_class_names[int(line[:index])] = line[index+1:]
|
||||
return(labels_to_class_names)
|
||||
|
||||
def mean_image_subtraction(image, means):
|
||||
if image.get_shape().ndims != 3:
|
||||
raise ValueError('Input must be of size [height, width, C>0]')
|
||||
num_channels = image.get_shape().as_list()[-1]
|
||||
if len(means) != num_channels:
|
||||
raise ValueError('len(means) must match the number of channels')
|
||||
|
||||
channels = tf.split(2, num_channels, image)
|
||||
for i in range(num_channels):
|
||||
channels[i] -= means[i]
|
||||
return(tf.concat(2, channels))
|
||||
|
||||
def get_preprocessing():
|
||||
def preprocessing_fn(image, output_height=224, output_width=224):
|
||||
''' Resize the image and subtract "mean" RGB values '''
|
||||
_R_MEAN = 123.68
|
||||
_G_MEAN = 116.78
|
||||
_B_MEAN = 103.94
|
||||
image = tf.expand_dims(image, 0)
|
||||
resized_image = tf.image.resize_bilinear(image, [output_height, output_width], align_corners=False)
|
||||
resized_image = tf.squeeze(resized_image)
|
||||
resized_image.set_shape([output_height, output_width, 3])
|
||||
image = tf.to_float(resized_image)
|
||||
return(mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN]))
|
||||
return(preprocessing_fn)
|
||||
|
||||
def get_network_fn(num_classes, weight_decay=0.0):
|
||||
arg_scope = resnet_v1.resnet_arg_scope(weight_decay=weight_decay)
|
||||
func = resnet_v1.resnet_v1_152
|
||||
@functools.wraps(func)
|
||||
def network_fn(images):
|
||||
with slim.arg_scope(arg_scope):
|
||||
return func(images, num_classes)
|
||||
if hasattr(func, 'default_image_size'):
|
||||
network_fn.default_image_size = func.default_image_size
|
||||
return(network_fn)
|
||||
|
||||
def _add_variables_summaries(learning_rate):
|
||||
summaries = []
|
||||
for variable in slim.get_model_variables():
|
||||
summaries.append(tf.summary.image(variable.op.name, variable))
|
||||
summaries.append(tf.summary.scalar(learning_rate, name='training/Learning Rate'))
|
||||
return(summaries)
|
||||
|
||||
def _get_init_fn():
|
||||
if (FLAGS.checkpoint_path is None) or (tf.train.latest_checkpoint(FLAGS.train_dir)):
|
||||
return None
|
||||
|
||||
exclusions = []
|
||||
if FLAGS.checkpoint_exclude_scopes:
|
||||
exclusions = [scope.strip() for scope in FLAGS.checkpoint_exclude_scopes.split(',')]
|
||||
|
||||
variables_to_restore = []
|
||||
for var in slim.get_model_variables():
|
||||
excluded = False
|
||||
for exclusion in exclusions:
|
||||
if var.op.name.startswith(exclusion):
|
||||
excluded = True
|
||||
break
|
||||
if not excluded:
|
||||
variables_to_restore.append(var)
|
||||
|
||||
if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
|
||||
checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
|
||||
else:
|
||||
checkpoint_path = FLAGS.checkpoint_path
|
||||
|
||||
tf.logging.info('Fine-tuning from {}'.format(checkpoint_path))
|
||||
|
||||
return(slim.assign_from_checkpoint_fn(checkpoint_path,
|
||||
variables_to_restore,
|
||||
ignore_missing_vars=False))
|
||||
|
||||
def _get_variables_to_train():
|
||||
scopes = [scope.strip() for scope in FLAGS.trainable_scopes.split(',')]
|
||||
variables_to_train = []
|
||||
for scope in scopes:
|
||||
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
|
||||
variables_to_train.extend(variables)
|
||||
return(variables_to_train)
|
||||
|
||||
def get_dataset(dataset_name, dataset_dir, image_count, class_count, split_name):
|
||||
slim = tf.contrib.slim
|
||||
items_to_descriptions = {'image': 'A color image.',
|
||||
'label': 'An integer in range(0, class_count)'}
|
||||
file_pattern = os.path.join(dataset_dir, '{}_{}_*.tfrecord'.format(dataset_name, split_name))
|
||||
reader = tf.TFRecordReader
|
||||
keys_to_features = {'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
|
||||
'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
|
||||
'image/class/label': tf.FixedLenFeature([], tf.int64,
|
||||
default_value=tf.zeros([], dtype=tf.int64))}
|
||||
items_to_handlers = {'image': slim.tfexample_decoder.Image(),
|
||||
'label': slim.tfexample_decoder.Tensor('image/class/label')}
|
||||
decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
|
||||
labels_to_names = read_label_file(dataset_dir)
|
||||
return(slim.dataset.Dataset(data_sources=file_pattern,
|
||||
reader=reader,
|
||||
decoder=decoder,
|
||||
num_samples=image_count,
|
||||
items_to_descriptions=items_to_descriptions,
|
||||
num_classes=class_count,
|
||||
labels_to_names=labels_to_names,
|
||||
shuffle=True))
|
||||
|
||||
def main(_):
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
with tf.Graph().as_default():
|
||||
deploy_config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
|
||||
clone_on_cpu=FLAGS.clone_on_cpu,
|
||||
replica_id=0,
|
||||
num_replicas=1,
|
||||
num_ps_tasks=0)
|
||||
|
||||
with tf.device(deploy_config.variables_device()):
|
||||
global_step = slim.create_global_step()
|
||||
|
||||
image_count, class_count = get_image_and_class_count(FLAGS.dataset_dir, 'train')
|
||||
dataset = get_dataset('aerial', FLAGS.dataset_dir, image_count, class_count, 'train')
|
||||
network_fn = get_network_fn(num_classes=(dataset.num_classes), weight_decay=FLAGS.weight_decay)
|
||||
image_preprocessing_fn = get_preprocessing()
|
||||
|
||||
with tf.device(deploy_config.inputs_device()):
|
||||
provider = slim.dataset_data_provider.DatasetDataProvider(dataset,
|
||||
num_readers=FLAGS.num_readers,
|
||||
common_queue_capacity=20 * FLAGS.batch_size,
|
||||
common_queue_min=10 * FLAGS.batch_size)
|
||||
[image, label] = provider.get(['image', 'label'])
|
||||
image = image_preprocessing_fn(image, 224, 224)
|
||||
images, labels = tf.train.batch([image, label],
|
||||
batch_size=FLAGS.batch_size,
|
||||
num_threads=FLAGS.num_preprocessing_threads,
|
||||
capacity=5 * FLAGS.batch_size)
|
||||
labels = slim.one_hot_encoding(labels, dataset.num_classes)
|
||||
batch_queue = slim.prefetch_queue.prefetch_queue([images, labels], capacity=2 * deploy_config.num_clones)
|
||||
|
||||
def clone_fn(batch_queue):
|
||||
images, labels = batch_queue.dequeue()
|
||||
logits, end_points = network_fn(images)
|
||||
logits = tf.squeeze(logits) # added -- does this help?
|
||||
slim.losses.softmax_cross_entropy(logits, labels, label_smoothing=FLAGS.label_smoothing, weights=1.0)
|
||||
return(end_points)
|
||||
|
||||
summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
|
||||
|
||||
clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
|
||||
first_clone_scope = deploy_config.clone_scope(0)
|
||||
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
|
||||
|
||||
end_points = clones[0].outputs
|
||||
for end_point in end_points:
|
||||
x = end_points[end_point]
|
||||
summaries.add(tf.summary.histogram('activations/' + end_point, x))
|
||||
summaries.add(tf.summary.scalar('sparsity/' + end_point, tf.nn.zero_fraction(x)))
|
||||
for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
|
||||
summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))
|
||||
for variable in slim.get_model_variables():
|
||||
summaries.add(tf.summary.histogram(variable.op.name, variable))
|
||||
|
||||
with tf.device(deploy_config.optimizer_device()):
|
||||
decay_steps = int(dataset.num_samples / FLAGS.batch_size * FLAGS.num_epochs_per_decay)
|
||||
learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
|
||||
global_step,
|
||||
decay_steps,
|
||||
FLAGS.learning_rate_decay_factor,
|
||||
staircase=True,
|
||||
name='exponential_decay_learning_rate')
|
||||
optimizer = tf.train.RMSPropOptimizer(learning_rate,
|
||||
decay=FLAGS.rmsprop_decay,
|
||||
momentum=FLAGS.rmsprop_momentum,
|
||||
epsilon=FLAGS.opt_epsilon)
|
||||
summaries.add(tf.summary.scalar('learning_rate', learning_rate))
|
||||
|
||||
|
||||
|
||||
variables_to_train = _get_variables_to_train()
|
||||
total_loss, clones_gradients = model_deploy.optimize_clones(clones, optimizer, var_list=variables_to_train)
|
||||
summaries.add(tf.summary.scalar('total_loss', total_loss))
|
||||
|
||||
grad_updates = optimizer.apply_gradients(clones_gradients, global_step=global_step)
|
||||
update_ops.append(grad_updates)
|
||||
|
||||
update_op = tf.group(*update_ops)
|
||||
train_tensor = control_flow_ops.with_dependencies([update_op], total_loss, name='train_op')
|
||||
|
||||
summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
|
||||
summary_op = tf.summary.merge(list(summaries), name='summary_op')
|
||||
|
||||
slim.learning.train(train_tensor,
|
||||
logdir=FLAGS.train_dir,
|
||||
master='',
|
||||
is_chief=True,
|
||||
init_fn=_get_init_fn(),
|
||||
summary_op=summary_op,
|
||||
number_of_steps=FLAGS.max_number_of_steps,
|
||||
log_every_n_steps=FLAGS.log_every_n_steps,
|
||||
save_summaries_secs=FLAGS.save_summaries_secs,
|
||||
save_interval_secs=FLAGS.save_interval_secs,
|
||||
sync_optimizer=None)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.app.run()
|
Загрузка…
Ссылка в новой задаче