Merge branch 'blis/ops14todo' of https://github.com/Microsoft/CNTK into blis/ops14todo

Willi Richert 2016-05-04 12:45:47 +02:00
Parents: d91c8bc767 aa383f486b
Commit: 4945bb8504
3 changed files with 81 additions and 39 deletions

View file

@@ -4,75 +4,99 @@
 # for full license information.
 # ==============================================================================
 # TODO: re-write the example using the new facade
 """
-MNIST Example, one hidden layer neural network
+MNIST Example, one hidden layer neural network using training and testing data
+converted to the CNTKTextFormatReader format with `uci_to_cntk_text_format_converter.py
+<https://github.com/Microsoft/CNTK/blob/master/Source/Readers/CNTKTextFormatReader/uci_to_cntk_text_format_converter.py>`_.
 """
 import sys
 import os
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..'))
-from cntk import *
+import numpy as np
+import cntk as C
 
 def add_dnn_sigmoid_layer(in_dim, out_dim, x, param_scale):
-    W = LearnableParameter(out_dim, in_dim, initValueScale=param_scale)
-    b = LearnableParameter(out_dim, 1, initValueScale=param_scale)
-    t = Times(W, x)
-    z = Plus(t, b)
-    return Sigmoid(z)
+    W = C.parameter((out_dim, in_dim), init_value_scale=param_scale)
+    b = C.parameter((out_dim, 1), init_value_scale=param_scale)
+    t = C.times(W, x)
+    z = C.plus(t, b)
+    return C.sigmoid(z)
 
 def add_dnn_layer(in_dim, out_dim, x, param_scale):
-    W = LearnableParameter(out_dim, in_dim, initValueScale=param_scale)
-    b = LearnableParameter(out_dim, 1, initValueScale=param_scale)
-    t = Times(W, x)
-    return Plus(t, b)
+    W = C.parameter((out_dim, in_dim), init_value_scale=param_scale)
+    b = C.parameter((out_dim, 1), init_value_scale=param_scale)
+    t = C.times(W, x)
+    return C.plus(t, b)
 
-if (__name__ == "__main__"):
+def train_eval_mnist_onelayer_from_file(criterion_name=None, eval_name=None):
     # Network definition
     feat_dim = 784
     label_dim = 10
     hidden_dim = 200
 
-    training_filename = os.path.join("Data", "Train-28x28.txt")
-    test_filename = os.path.join("Data", "Test-28x28.txt")
+    cur_dir = os.path.dirname(__file__)
+    training_filename = os.path.join(cur_dir, "Data", "Train-28x28_text.txt")
+    test_filename = os.path.join(cur_dir, "Data", "Test-28x28_text.txt")
 
-    features = Input(feat_dim)
+    features = C.input(feat_dim)
     features.name = 'features'
 
-    feat_scale = Constant(0.00390625)
-    feats_scaled = Scale(feat_scale, features)
+    feat_scale = C.constant(0.00390625)
+    feats_scaled = C.element_times(features, feat_scale)
 
-    labels = Input(label_dim)
+    labels = C.input(label_dim)
     labels.tag = 'label'
     labels.name = 'labels'
 
-    f_reader = UCIFastReader(training_filename, 1, feat_dim)
-    l_reader = UCIFastReader(training_filename, 0, 1, label_dim,
-                             os.path.join("Data", "labelsmap.txt"))
-    f_reader_t = UCIFastReader(test_filename, 1, feat_dim)
-    l_reader_t = UCIFastReader(test_filename, 0, 1, label_dim,
-                               os.path.join("Data", "labelsmap.txt"))
+    training_reader = C.CNTKTextFormatReader(training_filename)
+    test_reader = C.CNTKTextFormatReader(test_filename)
 
     h1 = add_dnn_sigmoid_layer(feat_dim, hidden_dim, feats_scaled, 1)
     out = add_dnn_layer(hidden_dim, label_dim, h1, 1)
     out.tag = 'output'
 
-    ec = CrossEntropyWithSoftmax(labels, out)
+    ec = C.cross_entropy_with_softmax(labels, out)
     ec.name = criterion_name
     ec.tag = 'criterion'
 
+    eval = C.ops.square_error(labels, out)
+    eval.name = eval_name
+    eval.tag = 'eval'
 
     # Specify the training parameters (settings are scaled down)
-    my_sgd = SGDParams(epoch_size=600, minibatch_size=32,
-                       learning_ratesPerMB=0.1, max_epochs=5, momentum_per_mb=0)
+    my_sgd = C.SGDParams(epoch_size=600, minibatch_size=32,
+                         learning_rates_per_mb=0.1, max_epochs=5, momentum_per_mb=0)
 
     # Create a context or re-use if already there
-    with LocalExecutionContext('mnist_one_layer', clean_up=True) as ctx:
+    with C.LocalExecutionContext('mnist_one_layer', clean_up=True) as ctx:
         # CNTK actions
-        ctx.train(ec, my_sgd, {features: f_reader, labels: l_reader})
-        ctx.write({features: f_reader_t, labels: l_reader_t})
-        print(ctx.test({features: f_reader_t, labels: l_reader_t}))
+        ctx.train(
+            root_nodes=[ec, eval],
+            training_params=my_sgd,
+            input_map=training_reader.map(labels, alias='labels', dim=label_dim).map(features, alias='features', dim=feat_dim))
+        result = ctx.test(
+            root_nodes=[ec, eval],
+            input_map=test_reader.map(labels, alias='labels', dim=label_dim).map(features, alias='features', dim=feat_dim))
+        return result
 
+def test_mnist_onelayer_from_file():
+    result = train_eval_mnist_onelayer_from_file('crit_node', 'eval_node')
+    TOLERANCE_ABSOLUTE = 1E-06
+    assert result['SamplesSeen'] == 10000
+    assert np.allclose(result['Perplexity'], 7.6323031, atol=TOLERANCE_ABSOLUTE)
+    assert np.allclose(result['crit_node'], 2.0323896, atol=TOLERANCE_ABSOLUTE)
+    assert np.allclose(result['eval_node'], 1.9882504, atol=TOLERANCE_ABSOLUTE)
 
+if __name__ == "__main__":
+    print(train_eval_mnist_onelayer_from_file('crit_node', 'eval_node'))
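
The key change in this file is the move from per-column UCIFastReader instances to a single CNTKTextFormatReader whose chained map() calls bind each input node to a named stream alias. A minimal sketch of that wiring, using only calls that appear in this diff (the path 'mnist.txt' is a hypothetical placeholder):

    import cntk as C

    # Input nodes for a 784-dim image and a 10-dim one-hot label.
    features = C.input(784)
    labels = C.input(10)

    # One reader serves both streams; each map() call ties a node to the
    # alias under which its values appear in the CNTKTextFormat file.
    reader = C.CNTKTextFormatReader('mnist.txt')  # hypothetical path
    input_map = reader.map(labels, alias='labels', dim=10) \
                      .map(features, alias='features', dim=784)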

View file

@@ -518,8 +518,14 @@ def future_value(dims, x, time_step=1, default_hidden_activation=0.1, name=None):
     value is returned which is 0.1 by default.
 
     Example:
-        >>> future_value(0, [[1, 2], [3, 4], [5,6]], 1, 0.5)
-        # [[3, 4], [5, 6], [0.5, 0.5]]
+        >>> data = np.array([[1,2,3,4],[5,6,7,8],[9,10,11,12]])
+        >>> t = C.dynamic_axis(name='t')
+        >>> x = C.input_numpy([data], dynamic_axis=t)
+        >>> with C.LocalExecutionContext('future_value') as ctx:
+        ...     print(ctx.eval(C.future_value(0, x)))
+        [array([[  5. ,   6. ,   7. ,   8. ],
+                [  9. ,  10. ,  11. ,  12. ],
+                [  0.1,   0.1,   0.1,   0.1]])]
 
     Args:
         dims: dimensions of the input `x`
@@ -544,8 +550,14 @@ def past_value(dims, x, time_step=1, default_hidden_activation=0.1, name=None):
     value is returned which is 0.1 by default.
 
     Example:
-        >>> past_value(0, [[1, 2], [3, 4], [5,6]], 1, 0.5)
-        # [[0.5, 0.5], [1, 2], [3, 4]]
+        >>> data = np.array([[1,2,3,4],[5,6,7,8],[9,10,11,12]])
+        >>> t = C.dynamic_axis(name='t')
+        >>> x = C.input_numpy([data], dynamic_axis=t)
+        >>> with C.LocalExecutionContext('past_value') as ctx:
+        ...     print(ctx.eval(C.past_value(0, x)))
+        [array([[  0.1,   0.1,   0.1,   0.1],
+                [  1. ,   2. ,   3. ,   4. ],
+                [  5. ,   6. ,   7. ,   8. ]])]
 
     Args:
         dims: dimensions of the input `x`
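
Both ops shift a sequence along its dynamic axis; positions shifted past the sequence boundary are filled with `default_hidden_activation` (0.1 by default). A small sketch built from the same calls as the doctests above, exercising the optional `time_step` and `default_hidden_activation` parameters from the documented signature ('past_value_demo' is just an arbitrary context name):

    import numpy as np
    import cntk as C

    data = np.array([[1, 2], [3, 4], [5, 6]])
    t = C.dynamic_axis(name='t')
    x = C.input_numpy([data], dynamic_axis=t)

    # Shift two steps into the past and fill the boundary rows with 0.5
    # instead of the default 0.1.
    op = C.past_value(0, x, time_step=2, default_hidden_activation=0.5)
    with C.LocalExecutionContext('past_value_demo') as ctx:
        print(ctx.eval(op))  # expected: [[0.5, 0.5], [0.5, 0.5], [1., 2.]]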
@@ -755,6 +767,9 @@ def reconcile_dynamic_axis(data_input, layout_input, name=None):
     of `layout_input`. It allows these two tensors to be properly compared using, e.g.,
     a criterion node.
 
+    Example:
+        See Examples/LSTM/seqcla.py for a use of :func:`cntk.ops.reconcile_dynamic_axis`.
+
     Args:
         data_input: the tensor to have its dynamic axis layout adapted
         layout_input: the tensor layout to use for adapting `data_input`'s layout
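
The pattern referenced in seqcla.py is to adapt a tensor that lacks the sequence's dynamic axis to the layout of the network output before feeding both into a criterion node. A sketch under the documented signature, assuming `input` accepts a `dynamic_axis` argument analogous to `input_numpy` above (all node names and dimensions here are illustrative, not taken from seqcla.py):

    import cntk as C

    t = C.dynamic_axis(name='t')
    features = C.input(2, dynamic_axis=t)  # sequence input along axis t
    labels = C.input(3)                    # labels without the dynamic axis

    # A trivial stand-in network: project each step's features to 3 outputs.
    W = C.parameter((3, 2), init_value_scale=1)
    out = C.times(W, features)

    # Adapt the dynamic-axis layout of `labels` to that of `out` so a
    # criterion node can compare the two.
    labels_r = C.reconcile_dynamic_axis(labels, out)
    ce = C.cross_entropy_with_softmax(labels_r, out)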

View file

@@ -281,3 +281,6 @@ the criterion node that adds a softmax and then implements the cross entropy loss
 we add the criterion node, however, we call :func:`cntk.ops.reconcile_dynamic_axis` which will ensure
 that the minibatch layout for the labels and the data with dynamic axes is compatible.
 
+For the full explanation of how ``lstm_layer()`` is defined, please see the full example in the
+Examples section.
+
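
As a reminder of what the combined criterion node computes, here is a small plain-numpy check of a softmax followed by the cross entropy loss (independent of CNTK; the values are illustrative):

    import numpy as np

    out = np.array([2.0, 1.0, 0.1])     # raw network outputs (logits)
    labels = np.array([1.0, 0.0, 0.0])  # one-hot target

    softmax = np.exp(out) / np.exp(out).sum()
    cross_entropy = -np.sum(labels * np.log(softmax))
    print(cross_entropy)  # ~0.4170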