[NNVM][TESTING] Add two testing symbols: dqn and dcgan (#1294)
This commit is contained in:
Родитель
21ece7526a
Коммит
42b189cbc0
|
@ -7,4 +7,6 @@ from . import mobilenet
|
||||||
from . import mlp
|
from . import mlp
|
||||||
from . import resnet
|
from . import resnet
|
||||||
from . import vgg
|
from . import vgg
|
||||||
|
from . import dcgan
|
||||||
|
from . import dqn
|
||||||
from . import yolo2_detection
|
from . import yolo2_detection
|
||||||
|
|
|
@ -0,0 +1,90 @@
|
||||||
|
# pylint: disable=unused-argument
|
||||||
|
"""
|
||||||
|
Symbol of the generator of DCGAN
|
||||||
|
|
||||||
|
Adapted from:
|
||||||
|
https://github.com/tqchen/mxnet-gan/blob/master/mxgan/generator.py
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
Radford, Alec, Luke Metz, and Soumith Chintala.
|
||||||
|
"Unsupervised representation learning with deep convolutional generative adversarial networks."
|
||||||
|
arXiv preprint arXiv:1511.06434 (2015).
|
||||||
|
"""
|
||||||
|
from .. import symbol as sym
|
||||||
|
from . utils import create_workload
|
||||||
|
|
||||||
|
def deconv2d(data, ishape, oshape, kshape, name, stride=(2, 2)):
    """Create a bias-free transposed convolution sized from ``oshape``.

    Parameters
    ----------
    data : symbol
        Input passed to ``sym.conv2d_transpose``.
    ishape : tuple
        Input shape. Accepted so call sites can state it, but not read here.
    oshape : tuple
        Requested output shape. ``oshape[0]`` is used as the channel count
        and the last two entries as the target (y, x) extent.
    kshape : tuple
        Kernel size as (y, x).
    name : str
        Name given to the created operator.
    stride : tuple, optional
        Stride as (y, x).

    Returns
    -------
    net : symbol
        The result of ``sym.conv2d_transpose``.
    """
    out_y, out_x = oshape[-2], oshape[-1]
    kernel_y, kernel_x = kshape[0], kshape[1]

    # Padding is half of (kernel - 1), rounded down, on each axis.
    padding = ((kernel_y - 1) // 2, (kernel_x - 1) // 2)
    # Output padding is the remainder of the padded target extent minus the
    # kernel, taken modulo the stride.
    output_padding = ((out_y + 2 * padding[0] - kernel_y) % stride[0],
                      (out_x + 2 * padding[1] - kernel_x) % stride[1])

    return sym.conv2d_transpose(data,
                                kernel_size=kshape,
                                strides=stride,
                                channels=oshape[0],
                                padding=padding,
                                output_padding=output_padding,
                                use_bias=False,
                                name=name)
|
||||||
|
|
||||||
|
def deconv2d_bn_relu(data, prefix, **kwargs):
    """Stack a transposed convolution, batch norm and ReLU under one prefix.

    Parameters
    ----------
    data : symbol
        Input to the block.
    prefix : str
        Name prefix; the three layers are named ``<prefix>_deconv``,
        ``<prefix>_bn`` and ``<prefix>_act``.
    **kwargs
        Forwarded unchanged to ``deconv2d``.

    Returns
    -------
    net : symbol
        Output of the final ReLU.
    """
    bn_epsilon = 1e-5 + 1e-12
    upsampled = deconv2d(data, name="%s_deconv" % prefix, **kwargs)
    normalized = sym.batch_norm(upsampled, epsilon=bn_epsilon, name="%s_bn" % prefix)
    return sym.relu(normalized, name="%s_act" % prefix)
|
||||||
|
|
||||||
|
def get_symbol(oshape, ngf=128, code=None):
    """Build the symbol of the DCGAN generator for 32x32 images.

    Parameters
    ----------
    oshape : tuple
        Output image shape. The last two entries must both be 32; the last
        three entries are handed to the final deconvolution as its output
        shape.
    ngf : int, optional
        Channel count of the feature map feeding the final deconvolution;
        earlier stages use multiples of it.
    code : symbol, optional
        Input code. When None, a variable named "data" is created.

    Returns
    -------
    net : symbol
        The tanh output of the generator.
    """
    assert oshape[-1] == 32 and oshape[-2] == 32, "Only support 32x32 image"

    if code is None:
        code = sym.Variable("data")

    # Project the code, then fold it into a 4 x 4 feature map.
    net = sym.dense(code, name="g1", units=4 * 4 * ngf * 4, use_bias=False)
    net = sym.relu(net)
    net = sym.reshape(net, shape=(-1, ngf * 4, 4, 4))

    # Two upsampling stages: 4 x 4 -> 8 x 8 -> 16 x 16.
    stages = (
        ("g2", (ngf * 4, 4, 4), (ngf * 2, 8, 8)),
        ("g3", (ngf * 2, 8, 8), (ngf, 16, 16)),
    )
    for stage_prefix, stage_in, stage_out in stages:
        net = deconv2d_bn_relu(
            net, ishape=stage_in, oshape=stage_out, kshape=(4, 4), prefix=stage_prefix)

    # Final upsampling to 32 x 32, without batch norm, followed by tanh.
    net = deconv2d(
        net, ishape=(ngf, 16, 16), oshape=oshape[-3:], kshape=(4, 4), name="g4_deconv")
    return sym.tanh(net)
|
||||||
|
|
||||||
|
|
||||||
|
def get_workload(batch_size, oshape=(3, 32, 32), ngf=128, random_len=100, dtype="float32"):
    """Create the benchmark workload of a DCGAN generator.

    Parameters
    ----------
    batch_size : int
        Batch size used in the model.
    oshape : tuple, optional
        Shape of the output image, layout="CHW".
    ngf : int, optional
        Number of final feature maps in the generator.
    random_len : int, optional
        Length of the random input vector.
    dtype : str, optional
        The data type.

    Returns
    -------
    net : nnvm.symbol
        The computational graph.
    params : dict of str to NDArray
        The parameters.
    """
    generator = get_symbol(oshape=oshape, ngf=ngf)
    # The per-sample input is a flat vector of random_len values.
    input_shape = (random_len,)
    return create_workload(generator, batch_size, input_shape, dtype)
|
|
@ -0,0 +1,71 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Symbol of Nature DQN
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
Mnih, Volodymyr, et al. "Human-level control through deep reinforcement learning."
|
||||||
|
Nature 518.7540 (2015): 529.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .. import symbol as sym
|
||||||
|
from . utils import create_workload
|
||||||
|
|
||||||
|
def get_symbol(num_actions=18):
    """Build the symbol of the Nature DQN network.

    Parameters
    ----------
    num_actions : int, optional
        Number of units of the final dense layer.

    Returns
    -------
    net : symbol
        Output of the final dense layer ``fc5``.
    """
    net = sym.Variable(name='data')

    # (kernel, stride, channels) of the three unpadded conv + relu stages.
    conv_stages = ((8, 4, 32), (4, 2, 64), (3, 1, 64))
    for index, (kernel, step, channels) in enumerate(conv_stages, 1):
        net = sym.conv2d(net, kernel_size=(kernel, kernel), strides=(step, step),
                         padding=(0, 0), channels=channels, name='conv%d' % index)
        net = sym.relu(net, name='relu%d' % index)

    net = sym.flatten(net, name='flatten')
    net = sym.dense(net, units=512, name='fc4')
    net = sym.relu(net, name='relu4')
    return sym.dense(net, units=num_actions, name='fc5')
|
||||||
|
|
||||||
|
|
||||||
|
def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"):
    """Create the benchmark workload of a Deep Q Network.

    Parameters
    ----------
    batch_size : int
        Batch size used in the model.
    num_actions : int, optional
        Number of actions.
    image_shape : tuple, optional
        The input image shape.
    dtype : str, optional
        The data type.

    Returns
    -------
    net : nnvm.symbol
        The computational graph.
    params : dict of str to NDArray
        The parameters.
    """
    q_network = get_symbol(num_actions=num_actions)
    return create_workload(q_network, batch_size, image_shape, dtype)
|
|
@ -1,6 +1,6 @@
|
||||||
"""MXNet and NNVM model zoo."""
|
"""MXNet and NNVM model zoo."""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from . import mlp, resnet, vgg
|
from . import mlp, resnet, vgg, dqn, dcgan
|
||||||
import nnvm.testing
|
import nnvm.testing
|
||||||
|
|
||||||
__all__ = ['mx_mlp', 'nnvm_mlp', 'mx_resnet', 'nnvm_resnet', 'mx_vgg', 'nnvm_vgg']
|
__all__ = ['mx_mlp', 'nnvm_mlp', 'mx_resnet', 'nnvm_resnet', 'mx_vgg', 'nnvm_vgg']
|
||||||
|
@ -26,3 +26,11 @@ for num_layer in [11, 13, 16, 19]:
|
||||||
mx_vgg[num_layer] = vgg.get_symbol(_num_class, num_layer)
|
mx_vgg[num_layer] = vgg.get_symbol(_num_class, num_layer)
|
||||||
nnvm_vgg[num_layer] = nnvm.testing.vgg.get_workload(
|
nnvm_vgg[num_layer] = nnvm.testing.vgg.get_workload(
|
||||||
1, _num_class, num_layers=num_layer)[0]
|
1, _num_class, num_layers=num_layer)[0]
|
||||||
|
|
||||||
|
# dqn: MXNet symbol paired with the NNVM symbol ([0]) of a batch-size-1 workload
|
||||||
|
mx_dqn = dqn.get_symbol()
|
||||||
|
nnvm_dqn = nnvm.testing.dqn.get_workload(1)[0]
|
||||||
|
|
||||||
|
# dcgan generator: MXNet symbol paired with the NNVM symbol ([0]) of a batch-size-1 workload
|
||||||
|
mx_dcgan = dcgan.get_symbol()
|
||||||
|
nnvm_dcgan = nnvm.testing.dcgan.get_workload(1)[0]
|
||||||
|
|
|
@ -0,0 +1,63 @@
|
||||||
|
# pylint: disable=unused-argument
|
||||||
|
"""
|
||||||
|
The MXNet symbol of DCGAN generator
|
||||||
|
|
||||||
|
Adapted from:
|
||||||
|
https://github.com/tqchen/mxnet-gan/blob/master/mxgan/generator.py
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
Radford, Alec, Luke Metz, and Soumith Chintala.
|
||||||
|
"Unsupervised representation learning with deep convolutional generative adversarial networks."
|
||||||
|
arXiv preprint arXiv:1511.06434 (2015).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mxnet as mx
|
||||||
|
|
||||||
|
def deconv2d(data, ishape, oshape, kshape, name, stride=(2, 2)):
    """Create a bias-free MXNet deconvolution sized from ``oshape``.

    Parameters
    ----------
    data : symbol
        Input passed to ``mx.sym.Deconvolution``.
    ishape : tuple
        Input shape. Accepted so call sites can state it, but not read here.
    oshape : tuple
        Requested output shape. ``oshape[0]`` is used as the filter count
        and the last two entries as the target (y, x) extent.
    kshape : tuple
        Kernel size as (y, x).
    name : str
        Name given to the created operator.
    stride : tuple, optional
        Stride as (y, x).

    Returns
    -------
    net : symbol
        The result of ``mx.sym.Deconvolution``.
    """
    target = (oshape[-2], oshape[-1])
    axes = (0, 1)

    # Padding is half of (kernel - 1), rounded down, on each axis.
    pad = tuple((kshape[axis] - 1) // 2 for axis in axes)
    # The adjustment is the remainder of the padded target extent minus the
    # kernel, taken modulo the stride.
    adj = tuple(
        (target[axis] + 2 * pad[axis] - kshape[axis]) % stride[axis] for axis in axes)

    return mx.sym.Deconvolution(data,
                                kernel=kshape,
                                stride=stride,
                                pad=pad,
                                adj=adj,
                                num_filter=oshape[0],
                                no_bias=True,
                                name=name)
|
||||||
|
|
||||||
|
def deconv2d_bn_relu(data, prefix, **kwargs):
    """Stack a deconvolution, batch norm and ReLU under one name prefix.

    Parameters
    ----------
    data : symbol
        Input to the block.
    prefix : str
        Name prefix; the three layers are named ``<prefix>_deconv``,
        ``<prefix>_bn`` and ``<prefix>_act``.
    **kwargs
        Forwarded unchanged to ``deconv2d``.

    Returns
    -------
    net : symbol
        Output of the ReLU activation.
    """
    bn_eps = 1e-5 + 1e-12
    upsampled = deconv2d(data, name="%s_deconv" % prefix, **kwargs)
    normalized = mx.sym.BatchNorm(upsampled, eps=bn_eps, name="%s_bn" % prefix)
    return mx.sym.Activation(normalized, name="%s_act" % prefix, act_type='relu')
|
||||||
|
|
||||||
|
def get_symbol(oshape=(3, 32, 32), ngf=128, code=None):
    """Build the MXNet symbol of the DCGAN generator for 32x32 images.

    Parameters
    ----------
    oshape : tuple, optional
        Output image shape. The last two entries must both be 32; the last
        three entries are handed to the final deconvolution as its output
        shape.
    ngf : int, optional
        Filter count of the feature map feeding the final deconvolution;
        earlier stages use multiples of it.
    code : symbol, optional
        Input code. When None, a variable named "data" is created.

    Returns
    -------
    net : symbol
        The tanh output of the generator.
    """
    assert oshape[-1] == 32 and oshape[-2] == 32, "Only support 32x32 image"

    if code is None:
        code = mx.sym.Variable("data")

    # Project the code, then fold it into a 4 x 4 feature map.
    net = mx.sym.FullyConnected(
        code, name="g1", num_hidden=4 * 4 * ngf * 4, no_bias=True, flatten=False)
    net = mx.sym.Activation(net, act_type='relu')
    net = mx.sym.reshape(net, shape=(-1, ngf * 4, 4, 4))

    # Two upsampling stages: 4 x 4 -> 8 x 8 -> 16 x 16.
    stages = (
        ("g2", (ngf * 4, 4, 4), (ngf * 2, 8, 8)),
        ("g3", (ngf * 2, 8, 8), (ngf, 16, 16)),
    )
    for stage_prefix, stage_in, stage_out in stages:
        net = deconv2d_bn_relu(
            net, ishape=stage_in, oshape=stage_out, kshape=(4, 4), prefix=stage_prefix)

    # Final upsampling to 32 x 32, without batch norm, followed by tanh.
    net = deconv2d(
        net, ishape=(ngf, 16, 16), oshape=oshape[-3:], kshape=(4, 4), name="g4_deconv")
    return mx.sym.Activation(net, act_type='tanh')
|
|
@ -0,0 +1,27 @@
|
||||||
|
"""
|
||||||
|
The mxnet symbol of Nature DQN
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
Mnih, Volodymyr, et al.
|
||||||
|
"Human-level control through deep reinforcement learning."
|
||||||
|
Nature 518.7540 (2015): 529.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mxnet as mx
|
||||||
|
|
||||||
|
def get_symbol(num_action=18):
    """Build the MXNet symbol of the Nature DQN network.

    Parameters
    ----------
    num_action : int, optional
        Number of hidden units of the final fully connected layer.

    Returns
    -------
    net : symbol
        Output of the final fully connected layer ``fc5``.
    """
    net = mx.sym.Variable(name='data')

    # (kernel, stride, filters) of the three conv + relu stages, in order.
    conv_stages = ((8, 4, 32), (4, 2, 64), (3, 1, 64))
    for index, (kernel, step, filters) in enumerate(conv_stages, 1):
        net = mx.sym.Convolution(net, kernel=(kernel, kernel), stride=(step, step),
                                 num_filter=filters, name='conv%d' % index)
        net = mx.sym.Activation(net, act_type='relu', name='relu%d' % index)

    # fc4 uses the operator's default flatten setting; fc5 disables it.
    net = mx.sym.FullyConnected(net, num_hidden=512, name='fc4')
    net = mx.sym.Activation(net, act_type='relu', name='relu4')
    return mx.sym.FullyConnected(net, num_hidden=num_action, name='fc5', flatten=False)
|
|
@ -32,6 +32,18 @@ def test_resnet():
|
||||||
nnvm_sym = model_zoo.nnvm_resnet[n]
|
nnvm_sym = model_zoo.nnvm_resnet[n]
|
||||||
compare_graph(from_mx_sym, nnvm_sym)
|
compare_graph(from_mx_sym, nnvm_sym)
|
||||||
|
|
||||||
|
def test_dqn():
    """Convert the MXNet DQN symbol and compare it with the NNVM DQN symbol."""
    converted, _params = nnvm.frontend.from_mxnet(model_zoo.mx_dqn)
    reference = model_zoo.nnvm_dqn
    compare_graph(converted, reference)
|
||||||
|
|
||||||
|
def test_dcgan():
    """Convert the MXNet DCGAN symbol and compare it with the NNVM DCGAN symbol."""
    converted, _params = nnvm.frontend.from_mxnet(model_zoo.mx_dcgan)
    reference = model_zoo.nnvm_dcgan
    compare_graph(converted, reference)
|
||||||
|
|
||||||
def test_multi_outputs():
|
def test_multi_outputs():
|
||||||
def compose(F, **kwargs):
|
def compose(F, **kwargs):
|
||||||
x = F.sym.Variable('x')
|
x = F.sym.Variable('x')
|
||||||
|
@ -48,3 +60,5 @@ if __name__ == '__main__':
|
||||||
test_vgg()
|
test_vgg()
|
||||||
test_resnet()
|
test_resnet()
|
||||||
test_multi_outputs()
|
test_multi_outputs()
|
||||||
|
test_dqn()
|
||||||
|
test_dcgan()
|
||||||
|
|
Загрузка…
Ссылка в новой задаче