Merge branch 'blis/ops14todo' of https://github.com/Microsoft/CNTK into blis/ops14todo

This commit is contained in:
Willi Richert 2016-05-04 12:45:47 +02:00
Родитель d91c8bc767 aa383f486b
Коммит 4945bb8504
3 изменённых файлов: 81 добавлений и 39 удалений

Просмотреть файл

@ -4,75 +4,99 @@
# for full license information. # for full license information.
# ============================================================================== # ==============================================================================
# TODO: re-write the example using the new facade
""" """
MNIST Example, one hidden layer neural network MNIST Example, one hidden layer neural network using training and testing data
converted to the CNTKTextFormatReader format with `uci_to_cntk_text_format_converter.py
<https://github.com/Microsoft/CNTK/blob/master/Source/Readers/CNTKTextFormatReader/uci_to_cntk_text_format_converter.py>`_.
""" """
import sys import sys
import os import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..')) sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..'))
from cntk import * import numpy as np
import cntk as C
def add_dnn_sigmoid_layer(in_dim, out_dim, x, param_scale): def add_dnn_sigmoid_layer(in_dim, out_dim, x, param_scale):
W = LearnableParameter(out_dim, in_dim, initValueScale=param_scale) W = C.parameter((out_dim, in_dim), init_value_scale=param_scale)
b = LearnableParameter(out_dim, 1, initValueScale=param_scale) b = C.parameter((out_dim, 1), init_value_scale=param_scale)
t = Times(W, x) t = C.times(W, x)
z = Plus(t, b) z = C.plus(t, b)
return Sigmoid(z) return C.sigmoid(z)
def add_dnn_layer(in_dim, out_dim, x, param_scale): def add_dnn_layer(in_dim, out_dim, x, param_scale):
W = LearnableParameter(out_dim, in_dim, initValueScale=param_scale) W = C.parameter((out_dim, in_dim), init_value_scale=param_scale)
b = LearnableParameter(out_dim, 1, initValueScale=param_scale) b = C.parameter((out_dim, 1), init_value_scale=param_scale)
t = Times(W, x) t = C.times(W, x)
return Plus(t, b) return C.plus(t, b)
if (__name__ == "__main__"): def train_eval_mnist_onelayer_from_file(criterion_name=None, eval_name=None):
# Network definition # Network definition
feat_dim = 784 feat_dim = 784
label_dim = 10 label_dim = 10
hidden_dim = 200 hidden_dim = 200
training_filename = os.path.join("Data", "Train-28x28.txt") cur_dir = os.path.dirname(__file__)
test_filename = os.path.join("Data", "Test-28x28.txt")
features = Input(feat_dim) training_filename = os.path.join(cur_dir, "Data", "Train-28x28_text.txt")
test_filename = os.path.join(cur_dir, "Data", "Test-28x28_text.txt")
features = C.input(feat_dim)
features.name = 'features' features.name = 'features'
feat_scale = Constant(0.00390625) feat_scale = C.constant(0.00390625)
feats_scaled = Scale(feat_scale, features) feats_scaled = C.element_times(features, feat_scale)
labels = Input(label_dim) labels = C.input(label_dim)
labels.tag = 'label' labels.tag = 'label'
labels.name = 'labels' labels.name = 'labels'
f_reader = UCIFastReader(training_filename, 1, feat_dim) traning_reader = C.CNTKTextFormatReader(training_filename)
l_reader = UCIFastReader(training_filename, 0, 1, label_dim, test_reader = C.CNTKTextFormatReader(test_filename)
os.path.join("Data", "labelsmap.txt"))
f_reader_t = UCIFastReader(test_filename, 1, feat_dim)
l_reader_t = UCIFastReader(test_filename, 0, 1, label_dim,
os.path.join("Data", "labelsmap.txt"))
h1 = add_dnn_sigmoid_layer(feat_dim, hidden_dim, feats_scaled, 1) h1 = add_dnn_sigmoid_layer(feat_dim, hidden_dim, feats_scaled, 1)
out = add_dnn_layer(hidden_dim, label_dim, h1, 1) out = add_dnn_layer(hidden_dim, label_dim, h1, 1)
out.tag = 'output' out.tag = 'output'
ec = CrossEntropyWithSoftmax(labels, out) ec = C.cross_entropy_with_softmax(labels, out)
ec.name = criterion_name
ec.tag = 'criterion' ec.tag = 'criterion'
eval = C.ops.square_error(labels, out)
eval.name = eval_name
eval.tag = 'eval'
# Specify the training parameters (settings are scaled down) # Specify the training parameters (settings are scaled down)
my_sgd = SGDParams(epoch_size=600, minibatch_size=32, my_sgd = C.SGDParams(epoch_size=600, minibatch_size=32,
learning_ratesPerMB=0.1, max_epochs=5, momentum_per_mb=0) learning_rates_per_mb=0.1, max_epochs=5, momentum_per_mb=0)
# Create a context or re-use if already there # Create a context or re-use if already there
with LocalExecutionContext('mnist_one_layer', clean_up=True) as ctx: with C.LocalExecutionContext('mnist_one_layer', clean_up=True) as ctx:
# CNTK actions # CNTK actions
ctx.train(ec, my_sgd, {features: f_reader, labels: l_reader}) ctx.train(
ctx.write({features: f_reader_t, labels: l_reader_t}) root_nodes=[ec, eval],
print(ctx.test({features: f_reader_t, labels: l_reader_t})) training_params=my_sgd,
input_map=traning_reader.map(labels, alias='labels', dim=label_dim).map(features, alias='features', dim=feat_dim))
result = ctx.test(
root_nodes=[ec, eval],
input_map=test_reader.map(labels, alias='labels', dim=label_dim).map(features, alias='features', dim=feat_dim))
return result
def test_mnist_onelayer_from_file():
result = train_eval_mnist_onelayer_from_file('crit_node', 'eval_node')
TOLERANCE_ABSOLUTE = 1E-06
assert result['SamplesSeen'] == 10000
assert np.allclose(result['Perplexity'], 7.6323031, atol=TOLERANCE_ABSOLUTE)
assert np.allclose(result['crit_node'], 2.0323896, atol=TOLERANCE_ABSOLUTE)
assert np.allclose(result['eval_node'], 1.9882504, atol=TOLERANCE_ABSOLUTE)
if __name__ == "__main__":
print(train_eval_mnist_onelayer_from_file('crit_node', 'eval_node'))

Просмотреть файл

@ -518,8 +518,14 @@ def future_value(dims, x, time_step=1, default_hidden_activation=0.1, name=None)
value is returned which is 0.1 by default. value is returned which is 0.1 by default.
Example: Example:
>>> future_value(0, [[1, 2], [3, 4], [5,6]], 1, 0.5) >>> data = np.array([[1,2,3,4],[5,6,7,8],[9,10,11,12]])
# [[3, 4], [5, 6], [0.5, 0.5]] >>> t = C.dynamic_axis(name='t')
>>> x = C.input_numpy([data], dynamic_axis=t)
>>> with C.LocalExecutionContext('future_value') as ctx:
... print(ctx.eval(C.future_value(0, x)))
[array([[ 5. , 6. , 7. , 8. ],
[ 9. , 10. , 11. , 12. ],
[ 0.1, 0.1, 0.1, 0.1]])]
Args: Args:
dims: dimensions of the input `x` dims: dimensions of the input `x`
@ -544,8 +550,14 @@ def past_value(dims, x, time_step=1, default_hidden_activation=0.1, name=None):
value is returned which is 0.1 by default. value is returned which is 0.1 by default.
Example: Example:
>>> past_value(0, [[1, 2], [3, 4], [5,6]], 1, 0.5) >>> data = np.array([[1,2,3,4],[5,6,7,8],[9,10,11,12]])
# [[0.5, 0.5], [1, 2], [3, 4]] >>> t = C.dynamic_axis(name='t')
>>> x = C.input_numpy([data], dynamic_axis=t)
>>> with C.LocalExecutionContext('past_value') as ctx:
... print(ctx.eval(C.past_value(0, x)))
[array([[ 0.1, 0.1, 0.1, 0.1],
[ 1. , 2. , 3. , 4. ],
[ 5. , 6. , 7. , 8. ]])]
Args: Args:
dims: dimensions of the input `x` dims: dimensions of the input `x`
@ -755,6 +767,9 @@ def reconcile_dynamic_axis(data_input, layout_input, name=None):
of `layout_input`. It allows these two tensors to be properly compared using, e.g. of `layout_input`. It allows these two tensors to be properly compared using, e.g.
a criterion node. a criterion node.
Example:
See Examples/LSTM/seqcla.py for a use of :func:`cntk.ops.reconcile_dynamic_axis`.
Args: Args:
data_input: the tensor to have its dynamic axis layout adapted data_input: the tensor to have its dynamic axis layout adapted
layout_input: the tensor layout to use for adapting `data_input`'s layout layout_input: the tensor layout to use for adapting `data_input`'s layout

Просмотреть файл

@ -281,3 +281,6 @@ the criterion node that adds a softmax and then implements the cross entropy los
we add the criterion node, however, we call :func:`cntk.ops.reconcile_dynamic_axis` which will ensure we add the criterion node, however, we call :func:`cntk.ops.reconcile_dynamic_axis` which will ensure
that the minibatch layout for the labels and the data with dynamic axes is compatible. that the minibatch layout for the labels and the data with dynamic axes is compatible.
For the full explanation of how ``lstm_layer()`` is defined, please see the full example in the
Examples section.