move lstm seqcla to row vector
Parent: 8c6b8eb97a
Commit: 09d7e74fa7
@@ -29,7 +29,7 @@ class Last(C.ComputationNode):
         super(Last, self).__init__(params=['x'], op_name=op_name, name=name)
         self.x = x
         self.params_with_defaults = []
 
         self.rank = x.rank
 
 def lstm_layer(output_dim, cell_dim, x, input_dim):
@@ -47,23 +47,23 @@ def lstm_layer(output_dim, cell_dim, x, input_dim):
 def lstm_func(output_dim, cell_dim, x, input_dim, prev_state_h, prev_state_c):
 
     # input gate (t)
-    it_w = C.times(C.parameter((cell_dim, input_dim)), x)
-    it_b = C.parameter((cell_dim))
-    it_h = C.times(C.parameter((cell_dim, output_dim)), prev_state_h)
-    it_c = C.parameter((cell_dim)) * prev_state_c
+    it_w = C.times(x,C.parameter((input_dim, cell_dim)))
+    it_b = C.parameter((1,cell_dim))
+    it_h = C.times(prev_state_h,C.parameter((output_dim, cell_dim)))
+    it_c = C.parameter((1,cell_dim)) * prev_state_c
     it = C.sigmoid((it_w + it_b + it_h + it_c), name='it')
 
     # applied to tanh of input
-    bit_w = C.times(C.parameter((cell_dim, input_dim)), x)
-    bit_h = C.times(C.parameter((cell_dim, output_dim)), prev_state_h)
-    bit_b = C.parameter((cell_dim))
+    bit_w = C.times(x,C.parameter((input_dim,cell_dim)))
+    bit_h = C.times(prev_state_h,C.parameter((output_dim,cell_dim)))
+    bit_b = C.parameter((1,cell_dim))
     bit = it * C.tanh(bit_w + (bit_h + bit_b))
 
     # forget-me-not gate (t)
-    ft_w = C.times(C.parameter((cell_dim, input_dim)), x)
-    ft_b = C.parameter((cell_dim))
-    ft_h = C.times(C.parameter((cell_dim, output_dim)), prev_state_h)
-    ft_c = C.parameter((cell_dim)) * prev_state_c
+    ft_w = C.times(x, C.parameter((input_dim,cell_dim)))
+    ft_b = C.parameter((1,cell_dim))
+    ft_h = C.times(prev_state_h,C.parameter((output_dim,cell_dim)))
+    ft_c = C.parameter((1,cell_dim)) * prev_state_c
     ft = C.sigmoid((ft_w + ft_b + ft_h + ft_c), name='ft')
 
     # applied to cell(t-1)
@@ -73,10 +73,10 @@ def lstm_func(output_dim, cell_dim, x, input_dim, prev_state_h, prev_state_c):
     ct = bft + bit
 
     # output gate
-    ot_w = C.times(C.parameter((cell_dim, input_dim)), x)
-    ot_b = C.parameter((cell_dim))
-    ot_h = C.times(C.parameter((cell_dim, output_dim)), prev_state_h)
-    ot_c = C.parameter((cell_dim)) * prev_state_c
+    ot_w = C.times(x, C.parameter((input_dim,cell_dim)))
+    ot_b = C.parameter((1,cell_dim))
+    ot_h = C.times(prev_state_h,C.parameter((output_dim,cell_dim)))
+    ot_c = C.parameter((1,cell_dim)) * prev_state_c
     ot = C.sigmoid((ot_w + ot_b + ot_h + ot_c), name='ot')
 
     # applied to tanh(cell(t))
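Note: the gate rewrites above all follow the same row-vector pattern. The sample x and the previous hidden state become rows, the weights move to the right-hand side of C.times with transposed shapes, and the biases become (1, cell_dim) rows so they broadcast over the result. A minimal numpy sketch of that shape arithmetic (numpy stands in for CNTK here, and the dimension values are illustrative):

import numpy as np

input_dim, output_dim, cell_dim = 4, 3, 3
x = np.random.randn(1, input_dim)            # one time step as a row vector
prev_h = np.random.randn(1, output_dim)      # previous hidden state as a row vector
W_x = np.random.randn(input_dim, cell_dim)   # plays the role of C.parameter((input_dim, cell_dim))
W_h = np.random.randn(output_dim, cell_dim)  # plays the role of C.parameter((output_dim, cell_dim))
b = np.zeros((1, cell_dim))                  # plays the role of C.parameter((1, cell_dim))

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

it = sigmoid(x @ W_x + prev_h @ W_h + b)     # same shapes as it_w + it_h + it_b above
assert it.shape == (1, cell_dim)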
@@ -107,19 +107,19 @@ def seqcla():
     train_reader = C.CNTKTextFormatReader(train_file)
 
     # setup embedding matrix
-    embedding = C.parameter((embed_dim, vocab), learning_rate_multiplier=0.0,
+    embedding = C.parameter((vocab, embed_dim), learning_rate_multiplier=0.0,
                             init_from_file_path=embedding_file)
 
     # get the vector representing the word
-    sequence = C.times(embedding, features, name='sequence')
+    sequence = C.times(features, embedding, name='sequence')
 
     # add an LSTM layer
     L = lstm_layer(output_dim, cell_dim, sequence, input_dim)
 
     # add a softmax layer on top
-    w = C.parameter((num_labels, output_dim), name='w')
-    b = C.parameter((num_labels), name='b')
-    z = C.times(w, L) + b
+    w = C.parameter((output_dim, num_labels), name='w')
+    b = C.parameter((1,num_labels), name='b')
+    z = C.times(L, w) + b
     z.name='z'
     z.tag = "output"
 
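Note: with the LSTM output L treated as a (1, output_dim) row, the classifier weight flips to (output_dim, num_labels) and the bias to (1, num_labels), so C.times(L, w) + b yields one logit row per label. A quick numpy shape check of the same layout (illustrative sizes, numpy instead of CNTK):

import numpy as np

output_dim, num_labels = 128, 5
L = np.random.randn(1, output_dim)           # LSTM output as a row vector
w = np.random.randn(output_dim, num_labels)
b = np.zeros((1, num_labels))
z = L @ w + b                                # mirrors z = C.times(L, w) + b
assert z.shape == (1, num_labels)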
@@ -132,6 +132,7 @@ def seqcla():
     my_sgd = C.SGDParams(epoch_size=0, minibatch_size=10, learning_rates_per_mb=0.1, max_epochs=3)
 
     with C.LocalExecutionContext('seqcla') as ctx:
+        ctx.clean_up=False
         # train the model
         ctx.train(root_nodes=[ce], training_params=my_sgd, input_map=train_reader.map(
             features, alias='x', dim=vocab, format='Sparse').map(
@@ -144,7 +145,6 @@ def seqcla():
 
     # do some manual accuracy testing
     acc = calc_accuracy(train_file, ctx.output_filename_base)
 
     # and test for the same number...
     TOLERANCE_ABSOLUTE = 1E-02
     assert np.allclose(acc, 0.6006415396952687, atol=TOLERANCE_ABSOLUTE)
@@ -77,7 +77,7 @@ def train_eval_mnist_onelayer_from_file(criterion_name=None, eval_name=None):
                         learning_rates_per_mb=0.1, max_epochs=30, momentum_per_mb=0)
 
     # Create a context or re-use if already there
-    with C.LocalExecutionContext('mnist_one_layer', clean_up=False) as ctx:
+    with C.LocalExecutionContext('mnist_one_layer', clean_up=True) as ctx:
         # CNTK actions
         ctx.train(
             root_nodes=[ec, eval],
@@ -773,7 +773,7 @@ def cond(flag, value_if_true, value_if_false, name=None):
 # recurrent ops
 ################################################################################
 
-def future_value(dims, x, time_step=1, default_hidden_activation=0.1, name=None):
+def future_value(shape, x, time_step=1, default_hidden_activation=0.1, name=None):
     """
     This function returns the future value wrt `x`. It is most often used when
     creating RNNs. The resulting tensor has the same shape as the input but is
@@ -793,7 +793,7 @@ def future_value(dims, x, time_step=1, default_hidden_activation=0.1, name=None):
                [ 0.1, 0.1, 0.1, 0.1]])]
 
     Args:
-        dims: dimensions of the input `x`
+        shape: dimensions of the input `x`
         x: the tensor from which the future value is obtained
         time_step: the number of time steps to look into the future (default 1)
        default_hidden_activation: the default value to use when no future value
@@ -803,11 +803,11 @@ def future_value(dims, x, time_step=1, default_hidden_activation=0.1, name=None):
     """
 
     from cntk.ops.cntk1 import FutureValue
-    op = FutureValue(dims, x, time_step, default_hidden_activation, name = name)
-    op.rank = x.rank
+    op = FutureValue(shape, x, time_step, default_hidden_activation, name = name)
+    op.rank = 0 if np.isscalar(shape) else len(shape)
     return op
 
-def past_value(dims, x, time_step=1, default_hidden_activation=0.1, name=None):
+def past_value(shape, x, time_step=1, default_hidden_activation=0.1, name=None):
     """
     This function returns the past value wrt `x`. It is most often used when
     creating RNNs. The resulting tensor has the same shape as the input but is
@@ -827,7 +827,7 @@ def past_value(dims, x, time_step=1, default_hidden_activation=0.1, name=None):
                [ 5. , 6. , 7. , 8. ]])]
 
     Args:
-        dims: dimensions of the input `x`
+        shape: dimensions of the input `x`
         x: the tensor from which the past value is obtained
         time_step: the number of time steps to look into the past (default 1)
         default_hidden_activation: the default value to use when no past value
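Note: past_value (and future_value) take the shape of the state being delayed as their first argument, which is why the parameter is now called shape. As a hedged sketch only, not code from this commit, the previous-state inputs of lstm_func could be wired up roughly like this, where lstm_state_h and lstm_state_c are hypothetical names for the layer's recurrent outputs:

import cntk as C  # assumes the same module alias used in the example code

# Illustrative only: delay the hidden and cell state by one time step so they
# can be fed back into lstm_func as prev_state_h / prev_state_c.
prev_state_h = C.past_value(output_dim, lstm_state_h, time_step=1,
                            default_hidden_activation=0.1)
prev_state_c = C.past_value(cell_dim, lstm_state_c, time_step=1,
                            default_hidden_activation=0.1)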
@@ -837,8 +837,8 @@ def past_value(dims, x, time_step=1, default_hidden_activation=0.1, name=None):
     """
 
     from cntk.ops.cntk1 import PastValue
-    op = PastValue(dims, x, time_step, default_hidden_activation, name = name)
-    op.rank = x.rank
+    op = PastValue(shape, x, time_step, default_hidden_activation, name = name)
+    op.rank = 0 if np.isscalar(shape) else len(shape)
     return op
 
 ################################################################################
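Note: both ops now derive their rank from the shape argument instead of copying it from x: a scalar shape means rank 0, otherwise the rank is the number of entries in the shape tuple. The rule in isolation:

import numpy as np

def rank_of(shape):
    # mirrors: op.rank = 0 if np.isscalar(shape) else len(shape)
    return 0 if np.isscalar(shape) else len(shape)

assert rank_of(5) == 0        # e.g. past_value(5, x) -> rank 0
assert rank_of((3, 4)) == 2   # e.g. past_value((3, 4), x) -> rank 2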
@@ -1118,6 +1118,9 @@ def sparse_input(shape, dynamic_axis='', name=None):
     """
 
     from cntk.ops.cntk1 import SparseInput
+    if not np.isscalar(shape):
+        # cntk uses column major, thus we reverse the shape
+        shape = tuple(reversed(shape))
     op = SparseInput(shape, dynamicAxis=dynamic_axis, name=name)
     op.rank = 0 if np.isscalar(shape) else len(shape)
     return op
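Note: the lines added to sparse_input convert the Python-side (row-major) shape into CNTK's column-major convention by reversing it, while scalar shapes pass through untouched. The conversion on its own (the helper name is illustrative):

import numpy as np

def to_cntk_shape(shape):
    # reverse non-scalar shapes because CNTK stores tensors column major
    if not np.isscalar(shape):
        shape = tuple(reversed(shape))
    return shape

assert to_cntk_shape((2, 3, 4)) == (4, 3, 2)
assert to_cntk_shape(10) == 10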
@@ -421,7 +421,9 @@ class LazySparseInputReader(_LazyInputReaderBase):
         self.indices = indices
         self.values = values
 
-        self.shape = shape
+        # cntk uses column major, thus we reverse the shape
+        self.shape = tuple(reversed(shape))
+
 
         self.param_dict = {}
         self.param_dict['dim'] = np.multiply.reduce(self.shape)
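Note: LazySparseInputReader now applies the same column-major reversal before computing the flattened dim it reports via np.multiply.reduce. A small arithmetic sketch with an illustrative shape:

import numpy as np

shape = (2, 3, 4)                     # Python-side (row-major) shape
cntk_shape = tuple(reversed(shape))   # (4, 3, 2), what ends up in self.shape
dim = np.multiply.reduce(cntk_shape)  # 24, the value placed in param_dict['dim']
assert dim == 24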