Fix conv2d bug when having conv2d over the static axes of sequences.
This commit is contained in:
Родитель
438f61985e
Коммит
51b9c0e4c2
|
@ -165,7 +165,7 @@ class Crosstalk(object):
|
|||
attr : attributes for the variable that would be used when getting/setting values. Could be one of Conv2DAttr/EmbedAttr/RnnAttr
|
||||
'''
|
||||
if name in self.vars.keys():
|
||||
raise Exception('var with name {} already exists')
|
||||
raise Exception('var with name {} already exists'.format(name))
|
||||
self.vars[name] = _VarInfo(var, var_type if var_type else type(var), attr)
|
||||
|
||||
def register_funcs(self, var_type, setter=None, getter=None):
|
||||
|
|
|
@ -67,11 +67,16 @@ def _variable_getter(sess, data):
|
|||
def _conv2d_getter(sess):
    '''
    Build a getter that reads conv2d parameters through ``sess``.

    Args:
        sess: session object forwarded to ``_trainable_getter``.

    Returns:
        A function ``_get(pd, attr)`` that returns ``cstk.Conv2DArgs`` with the
        weights permuted on their trailing three axes and the bias (when there
        is one) reshaped to ``(attr.num_filters,)``.
    '''
    def _get(pd, attr):
        W = _trainable_getter(sess)(pd.W)
        # Handling input with a sequence axis:
        # the transpose from tf [H, W, C] to cntk's [C, H, W] happens at the
        # trailing axes, excluding the leading dynamic axes (batch and sequence
        # axes) in the data format. range() of a non-positive count is empty,
        # so a rank-3 W gets the plain (2, 0, 1) permutation.
        W_rank = len(W.shape)
        leading_axes = list(range(W_rank - 3))
        axis_perm = leading_axes + [i + W_rank - 3 for i in (2, 0, 1)]

        # The bias is optional: only fetch and reshape it when it is present,
        # instead of calling .reshape on None.
        if pd.b is not None:
            b = _trainable_getter(sess)(pd.b).reshape(attr.num_filters,)
        else:
            b = None
        return cstk.Conv2DArgs(W=W.transpose(axis_perm), b=b)
    return _get
|
||||
|
||||
def _conv2d_setter(sess):
|
||||
|
@ -93,10 +98,16 @@ def _rnn_trainable_in_scope(scope):
|
|||
bw_M=find_trainable('Matrix', scope=scope+'/BW')
|
||||
bw_b=find_trainable('Bias', scope=scope+'/BW')
|
||||
elif tf.VERSION.startswith('1'):
|
||||
fw_M=find_trainable('weights', scope=scope+'/fw')
|
||||
fw_b=find_trainable('biases', scope=scope+'/fw')
|
||||
bw_M=find_trainable('weights', scope=scope+'/bw')
|
||||
bw_b=find_trainable('biases', scope=scope+'/bw')
|
||||
if tf.VERSION.startswith('1.1'):
|
||||
fw_M=find_trainable('weights', scope=scope+'/fw')
|
||||
fw_b=find_trainable('biases', scope=scope+'/fw')
|
||||
bw_M=find_trainable('weights', scope=scope+'/bw')
|
||||
bw_b=find_trainable('biases', scope=scope+'/bw')
|
||||
else: # the following changes started with version '1.2' until as of version 1.7 for now
|
||||
fw_M = find_trainable('kernel', scope=scope + '/fw')
|
||||
fw_b = find_trainable('bias', scope=scope + '/fw')
|
||||
bw_M = find_trainable('kernel', scope=scope + '/bw')
|
||||
bw_b = find_trainable('bias', scope=scope + '/bw')
|
||||
else:
|
||||
raise Exception('only supports 0.12.* and 1.*')
|
||||
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
import numpy as np
|
||||
from cntk.contrib import crosstalk as cstk
|
||||
import tempfile
|
||||
# Directory handed to every crosstalk instance via set_workdir().
workdir = tempfile.gettempdir()

# Problem dimensions.
batch_size = 20
seq_len = 4
num_chars = 16
char_emb_dim = 8
filter_width = 5
num_filters = 100

# Shape of one sequence step, and the shape of the convolution filter.
sample_shape = (num_chars, char_emb_dim)
filter_shape = (filter_width, char_emb_dim)

# Random float32 input laid out as (batch_size, seq_len, num_chars, char_emb_dim).
input_data = np.random.random(
    (batch_size, seq_len) + sample_shape
).astype(np.float32)
|
||||
|
||||
def cntk_baseline_conv2d():
    '''
    Produce the conv2d baseline with CNTK only.

    Builds a Convolution2D over a sequence input, registers its parameters
    ('conv2d') and its output ('conv2d_out') with the CNTK crosstalk instance,
    and fetches both with ``save=True`` after pointing the instance at
    ``workdir``.
    '''
    import cntk as C
    import cntk.contrib.crosstalk.crosstalk_cntk as crct
    ci = crct.instance

    input_var = C.sequence.input_variable(shape=sample_shape)
    # Add a leading axis of size 1 in front of the static sample axes.
    input_reshaped = C.reshape(input_var, (1,) + sample_shape)
    conv_out = C.layers.Convolution2D(
        filter_shape, num_filters, init_bias=C.glorot_uniform())(input_reshaped)

    try:
        ci.watch(conv_out, 'conv2d', var_type=cstk.Conv2DAttr,
                 attr=cstk.Conv2DAttr(filter_shape=filter_shape, num_filters=num_filters))
        ci.watch(conv_out, 'conv2d_out')

        data = {input_var: input_data}
        ci.set_data(data)
        ci.set_workdir(workdir)
        ci.fetch('conv2d', save=True)
        ci.fetch('conv2d_out', save=True)
    finally:
        # crct.instance is shared module state; reset it even on failure so a
        # later watch() of the same names does not hit "already exists".
        ci.reset()
|
||||
|
||||
def tf_baseline_conv2d():
    '''
    Produce the conv2d baseline with TensorFlow.

    Builds a conv1d over the two trailing axes of a
    [batch_size, seq_len, num_chars, char_emb_dim] placeholder, registers the
    filter/bias pair ('conv2d') and the reshaped output ('conv2d_out') with the
    TensorFlow crosstalk instance, fetches both with ``save=True``, then
    re-assigns 'conv2d' with ``load=True`` and asserts the output still
    compares equal.
    '''
    import tensorflow as tf
    import cntk.contrib.crosstalk.crosstalk_tensorflow as crtf
    ci = crtf.instance

    tf.reset_default_graph()

    x = tf.placeholder(tf.float32, [batch_size, seq_len, num_chars, char_emb_dim])
    filter_bank = tf.get_variable("char_filter_bank",
                                  shape=[filter_width, char_emb_dim, num_filters],
                                  dtype=tf.float32)
    bias = tf.get_variable("char_filter_biases", shape=[num_filters], dtype=tf.float32)

    # Fold batch and sequence axes together so conv1d sees a rank-3 input.
    x_reshape = tf.reshape(x, [-1] + x.get_shape().as_list()[-2:])
    conv = tf.nn.conv1d(x_reshape, filter_bank, stride=1, padding='VALID') + bias
    char_conv = tf.expand_dims(tf.transpose(conv, perm=[0, 2, 1]), -1)
    # Split the folded leading axis back out with seq_len as the second axis.
    char_conv = tf.reshape(char_conv, [-1, seq_len] + char_conv.shape.as_list()[-3:])

    try:
        ci.watch(cstk.Conv2DArgs(W=crtf.find_trainable('char_filter_bank'),
                                 b=crtf.find_trainable('char_filter_biases')),
                 'conv2d', var_type=cstk.Conv2DAttr,
                 attr=cstk.Conv2DAttr(filter_shape=(filter_width, char_emb_dim,),
                                      num_filters=num_filters))
        ci.watch(char_conv, 'conv2d_out', var_type=crtf.VariableType)  # note the output is transposed to NCHW

        # The context manager closes the session; no explicit close() is needed.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            data = {x: input_data}
            ci.set_workdir(workdir)
            ci.set_data(sess, data)
            ci.fetch('conv2d_out', save=True)
            ci.fetch('conv2d', save=True)
            ci.assign('conv2d', load=True)
            assert ci.compare('conv2d_out')
    finally:
        # crtf.instance is shared module state; reset it even when the
        # comparison fails so later tests start from a clean instance.
        ci.reset()
|
||||
|
||||
def test_cntk_conv2d():
    '''
    Check that conv2d parameters saved by a baseline load into CNTK correctly.

    The baseline comes from TensorFlow when it can be imported, otherwise from
    CNTK itself. A fresh CNTK Convolution2D is then built, the saved 'conv2d'
    parameters are assigned with ``load=True``, and 'conv2d_out' is compared
    with tolerances. Finally, assigning 'conv2d' from an explicit value is
    exercised.
    '''
    # TensorFlow is optional. Only a failed import selects the CNTK baseline;
    # other exceptions (e.g. KeyboardInterrupt) are not swallowed.
    try:
        import tensorflow  # noqa: F401
        has_tensorflow = True
    except ImportError:
        has_tensorflow = False

    if has_tensorflow:
        tf_baseline_conv2d()
    else:
        cntk_baseline_conv2d()

    import cntk as C
    import cntk.contrib.crosstalk.crosstalk_cntk as crct
    ci = crct.instance

    input_var = C.sequence.input_variable(shape=sample_shape)
    # Add a leading axis of size 1 in front of the static sample axes.
    input_reshaped = C.reshape(input_var, (1,) + sample_shape)
    conv_out = C.layers.Convolution2D(filter_shape, num_filters, activation=None)(input_reshaped)

    try:
        ci.watch(conv_out, 'conv2d', var_type=cstk.Conv2DAttr,
                 attr=cstk.Conv2DAttr(filter_shape=filter_shape, num_filters=num_filters))
        ci.watch(conv_out, 'conv2d_out')

        data = {input_var: input_data}
        ci.set_data(data)
        ci.set_workdir(workdir)
        # The result is not used; the evaluation is kept as in the original,
        # presumably to check the graph runs before parameters are replaced.
        conv_out.eval(data)

        # load parameters from crosstalk and verify results are the same
        ci.assign('conv2d', load=True)
        assert ci.compare('conv2d_out', rtol=1e-4, atol=1e-6)

        # test assign with value
        ci.assign('conv2d', value=cstk.Conv2DArgs(
            W=np.random.random((num_filters,) + filter_shape).astype(np.float32),
            b=np.random.random((num_filters,)).astype(np.float32)))
    finally:
        # crct.instance is shared module state; reset it even when an
        # assertion fails so later tests start from a clean instance.
        ci.reset()
|
Загрузка…
Ссылка в новой задаче