[NNVM][TESTING] Add two testing symbols: dqn and dcgan (#1294)

This commit is contained in:
Lianmin Zheng 2018-06-18 03:38:23 +08:00 коммит произвёл Tianqi Chen
Родитель 21ece7526a
Коммит 42b189cbc0
7 изменённых файлов: 276 добавлений и 1 удалений

Просмотреть файл

@ -7,4 +7,6 @@ from . import mobilenet
from . import mlp
from . import resnet
from . import vgg
from . import dcgan
from . import dqn
from . import yolo2_detection

Просмотреть файл

@ -0,0 +1,90 @@
# pylint: disable=unused-argument
"""
Symbol of the generator of DCGAN
Adopted from:
https://github.com/tqchen/mxnet-gan/blob/master/mxgan/generator.py
Reference:
Radford, Alec, Luke Metz, and Soumith Chintala.
"Unsupervised representation learning with deep convolutional generative adversarial networks."
arXiv preprint arXiv:1511.06434 (2015).
"""
from .. import symbol as sym
from . utils import create_workload
def deconv2d(data, ishape, oshape, kshape, name, stride=(2, 2)):
"""a deconv layer that enlarges the feature map"""
target_shape = (oshape[-2], oshape[-1])
pad_y = (kshape[0] - 1) // 2
pad_x = (kshape[1] - 1) // 2
adj_y = (target_shape[0] + 2 * pad_y - kshape[0]) % stride[0]
adj_x = (target_shape[1] + 2 * pad_x - kshape[1]) % stride[1]
net = sym.conv2d_transpose(data,
kernel_size=kshape,
strides=stride,
channels=oshape[0],
padding=(pad_y, pad_x),
output_padding=(adj_y, adj_x),
use_bias=False,
name=name)
return net
def deconv2d_bn_relu(data, prefix, **kwargs):
"""a block of deconv + batch norm + relu"""
eps = 1e-5 + 1e-12
net = deconv2d(data, name="%s_deconv" % prefix, **kwargs)
net = sym.batch_norm(net, epsilon=eps, name="%s_bn" % prefix)
net = sym.relu(net, name="%s_act" % prefix)
return net
def get_symbol(oshape, ngf=128, code=None):
"""get symbol of dcgan generator"""
assert oshape[-1] == 32, "Only support 32x32 image"
assert oshape[-2] == 32, "Only support 32x32 image"
code = sym.Variable("data") if code is None else code
net = sym.dense(code, name="g1", units=4*4*ngf*4, use_bias=False)
net = sym.relu(net)
# 4 x 4
net = sym.reshape(net, shape=(-1, ngf * 4, 4, 4))
# 8 x 8
net = deconv2d_bn_relu(
net, ishape=(ngf * 4, 4, 4), oshape=(ngf * 2, 8, 8), kshape=(4, 4), prefix="g2")
# 16x16
net = deconv2d_bn_relu(
net, ishape=(ngf * 2, 8, 8), oshape=(ngf, 16, 16), kshape=(4, 4), prefix="g3")
# 32x32
net = deconv2d(
net, ishape=(ngf, 16, 16), oshape=oshape[-3:], kshape=(4, 4), name="g4_deconv")
net = sym.tanh(net)
return net
def get_workload(batch_size, oshape=(3, 32, 32), ngf=128, random_len=100, dtype="float32"):
"""Get benchmark workload for a DCGAN generator
Parameters
----------
batch_size : int
The batch size used in the model
oshape : tuple, optional
The shape of output image, layout="CHW"
ngf: int, optional
The number of final feature maps in the generator
random_len : int, optional
The length of random input
dtype : str, optional
The data type
Returns
-------
net : nnvm.symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_symbol(oshape=oshape, ngf=ngf)
return create_workload(net, batch_size, (random_len, ), dtype)

Просмотреть файл

@ -0,0 +1,71 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Symbol of Nature DQN
Reference:
Mnih, Volodymyr, et al. "Human-level control through deep reinforcement learning."
Nature 518.7540 (2015): 529.
"""
from .. import symbol as sym
from . utils import create_workload
def get_symbol(num_actions=18):
"""get symbol of nature dqn"""
data = sym.Variable(name='data')
net = sym.conv2d(data, kernel_size=(8, 8), strides=(4, 4), padding=(0, 0),
channels=32, name='conv1')
net = sym.relu(net, name='relu1')
net = sym.conv2d(net, kernel_size=(4, 4), strides=(2, 2), padding=(0, 0),
channels=64, name='conv2')
net = sym.relu(net, name='relu2')
net = sym.conv2d(net, kernel_size=(3, 3), strides=(1, 1), padding=(0, 0),
channels=64, name='conv3')
net = sym.relu(net, name='relu3')
net = sym.flatten(net, name='flatten')
net = sym.dense(net, units=512, name='fc4')
net = sym.relu(net, name='relu4')
net = sym.dense(net, units=num_actions, name='fc5')
return net
def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"):
"""Get benchmark workload for a Deep Q Network
Parameters
----------
batch_size : int
The batch size used in the model
num_actions : int, optional
Number of actions
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
Returns
-------
net : nnvm.symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_symbol(num_actions=num_actions)
return create_workload(net, batch_size, image_shape, dtype)

Просмотреть файл

@ -1,6 +1,6 @@
"""MXNet and NNVM model zoo."""
from __future__ import absolute_import
from . import mlp, resnet, vgg
from . import mlp, resnet, vgg, dqn, dcgan
import nnvm.testing
__all__ = ['mx_mlp', 'nnvm_mlp', 'mx_resnet', 'nnvm_resnet', 'mx_vgg', 'nnvm_vgg']
@ -26,3 +26,11 @@ for num_layer in [11, 13, 16, 19]:
mx_vgg[num_layer] = vgg.get_symbol(_num_class, num_layer)
nnvm_vgg[num_layer] = nnvm.testing.vgg.get_workload(
1, _num_class, num_layers=num_layer)[0]
# dqn
mx_dqn = dqn.get_symbol()
nnvm_dqn = nnvm.testing.dqn.get_workload(1)[0]
# dcgan generator
mx_dcgan = dcgan.get_symbol()
nnvm_dcgan = nnvm.testing.dcgan.get_workload(1)[0]

Просмотреть файл

@ -0,0 +1,63 @@
# pylint: disable=unused-argument
"""
The MXNet symbol of DCGAN generator
Adopted from:
https://github.com/tqchen/mxnet-gan/blob/master/mxgan/generator.py
Reference:
Radford, Alec, Luke Metz, and Soumith Chintala.
"Unsupervised representation learning with deep convolutional generative adversarial networks."
arXiv preprint arXiv:1511.06434 (2015).
"""
import mxnet as mx
def deconv2d(data, ishape, oshape, kshape, name, stride=(2, 2)):
"""a deconv layer that enlarges the feature map"""
target_shape = (oshape[-2], oshape[-1])
pad_y = (kshape[0] - 1) // 2
pad_x = (kshape[1] - 1) // 2
adj_y = (target_shape[0] + 2 * pad_y - kshape[0]) % stride[0]
adj_x = (target_shape[1] + 2 * pad_x - kshape[1]) % stride[1]
net = mx.sym.Deconvolution(data,
kernel=kshape,
stride=stride,
pad=(pad_y, pad_x),
adj=(adj_y, adj_x),
num_filter=oshape[0],
no_bias=True,
name=name)
return net
def deconv2d_bn_relu(data, prefix, **kwargs):
"""a block of deconv + batch norm + relu"""
eps = 1e-5 + 1e-12
net = deconv2d(data, name="%s_deconv" % prefix, **kwargs)
net = mx.sym.BatchNorm(net, eps=eps, name="%s_bn" % prefix)
net = mx.sym.Activation(net, name="%s_act" % prefix, act_type='relu')
return net
def get_symbol(oshape=(3, 32, 32), ngf=128, code=None):
"""get symbol of dcgan generator"""
assert oshape[-1] == 32, "Only support 32x32 image"
assert oshape[-2] == 32, "Only support 32x32 image"
code = mx.sym.Variable("data") if code is None else code
net = mx.sym.FullyConnected(code, name="g1", num_hidden=4*4*ngf*4, no_bias=True, flatten=False)
net = mx.sym.Activation(net, act_type='relu')
# 4 x 4
net = mx.sym.reshape(net, shape=(-1, ngf * 4, 4, 4))
# 8 x 8
net = deconv2d_bn_relu(
net, ishape=(ngf * 4, 4, 4), oshape=(ngf * 2, 8, 8), kshape=(4, 4), prefix="g2")
# 16x16
net = deconv2d_bn_relu(
net, ishape=(ngf * 2, 8, 8), oshape=(ngf, 16, 16), kshape=(4, 4), prefix="g3")
# 32x32
net = deconv2d(
net, ishape=(ngf, 16, 16), oshape=oshape[-3:], kshape=(4, 4), name="g4_deconv")
net = mx.sym.Activation(net, act_type='tanh')
return net

Просмотреть файл

@ -0,0 +1,27 @@
"""
The mxnet symbol of Nature DQN
Reference:
Mnih, Volodymyr, et al.
"Human-level control through deep reinforcement learning."
Nature 518.7540 (2015): 529.
"""
import mxnet as mx
def get_symbol(num_action=18):
data = mx.sym.Variable(name='data')
net = mx.sym.Convolution(data, kernel=(8, 8), stride=(4, 4),
num_filter=32, name='conv1')
net = mx.sym.Activation(net, act_type='relu', name='relu1')
net = mx.sym.Convolution(net, kernel=(4, 4), stride=(2, 2),
num_filter=64, name='conv2')
net = mx.sym.Activation(net, act_type='relu', name='relu2')
net = mx.sym.Convolution(net, kernel=(3, 3), stride=(1, 1),
num_filter=64, name='conv3')
net = mx.sym.Activation(net, act_type='relu', name='relu3')
net = mx.sym.FullyConnected(net, num_hidden=512, name='fc4')
net = mx.sym.Activation(net, act_type='relu', name='relu4')
net = mx.sym.FullyConnected(net, num_hidden=num_action, name='fc5', flatten=False)
return net

Просмотреть файл

@ -32,6 +32,18 @@ def test_resnet():
nnvm_sym = model_zoo.nnvm_resnet[n]
compare_graph(from_mx_sym, nnvm_sym)
def test_dqn():
mx_sym = model_zoo.mx_dqn
from_mx_sym, _ = nnvm.frontend.from_mxnet(mx_sym)
nnvm_sym = model_zoo.nnvm_dqn
compare_graph(from_mx_sym, nnvm_sym)
def test_dcgan():
mx_sym = model_zoo.mx_dcgan
from_mx_sym, _ = nnvm.frontend.from_mxnet(mx_sym)
nnvm_sym = model_zoo.nnvm_dcgan
compare_graph(from_mx_sym, nnvm_sym)
def test_multi_outputs():
def compose(F, **kwargs):
x = F.sym.Variable('x')
@ -48,3 +60,5 @@ if __name__ == '__main__':
test_vgg()
test_resnet()
test_multi_outputs()
test_dqn()
test_dcgan()