CNTK v2 library: Fixed some python wrappers

This commit is contained in:
Amit Agarwal 2016-09-05 13:22:51 -07:00
Parent d610b2c639
Commit 3b763f7d53
8 changed files with 215 additions and 47 deletions

View file

@@ -18,7 +18,7 @@ namespace CNTK
     }

     /*static*/ std::atomic<bool> DeviceDescriptor::s_defaultDeviceFrozen(false);
-    /*static*/ std::shared_ptr<DeviceDescriptor> DeviceDescriptor::s_defaultDevice(new DeviceDescriptor(DeviceDescriptor::GPUDevice(0)));
+    /*static*/ std::shared_ptr<DeviceDescriptor> DeviceDescriptor::s_defaultDevice(new DeviceDescriptor(DeviceDescriptor::CPUDevice()));

     /*static*/ DeviceDescriptor DeviceDescriptor::DefaultDevice()
     {
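This hunk flips the library's built-in default from GPU 0 to the CPU, so scripts that relied on the implicit GPU default must now pick a device themselves. A minimal sketch using the Python names that appear later in this commit (DeviceDescriptor.gpudevice, DeviceDescriptor.cpudevice, DeviceDescriptor.set_default_device):

    # Sketch: select the compute device explicitly instead of relying on the
    # library default, which this commit changes from GPU 0 to CPU.
    from cntk import DeviceDescriptor

    target_device = DeviceDescriptor.gpudevice(0)  # or DeviceDescriptor.cpudevice()
    DeviceDescriptor.set_default_device(target_device)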

View file

@@ -35,61 +35,61 @@ def combine(operands, name=''):

 ################################################################################
 # evaluation ops
 ################################################################################

-def cross_entropy_with_softmax(target_vector, output_vector, name=''):
+def cross_entropy_with_softmax(output_vector, target_vector, name=''):
     '''
     This operation computes the cross entropy over the softmax of the `output_vector`.
     It expects the `output_vector` to be unscaled, and it computes softmax over
     the `output_vector` internally. Any `output_vector` input over which softmax is
     already computed before passing to this operator will be incorrect.

-    :math:`cross\_entropy\_with\_softmax(t, o) = {-{\sum_{i \in \{1,len(t)\}} t_i \log(softmax(o_i)) }}`
+    :math:`cross\_entropy\_with\_softmax(o, t) = {-{\sum_{i \in \{1,len(t)\}} t_i \log(softmax(o_i)) }}`

     Example:
-        >>> C.eval(C.cross_entropy_with_softmax([0., 0., 0., 1.], [1., 1., 1., 50.]))
+        >>> C.eval(C.cross_entropy_with_softmax([1., 1., 1., 50.], [0., 0., 0., 1.]))
         #[0.]

-        >>> C.eval(C.cross_entropy_with_softmax([0.35, 0.15, 0.05, 0.45], [1., 2., 3., 4.]))
+        >>> C.eval(C.cross_entropy_with_softmax([1., 2., 3., 4.], [0.35, 0.15, 0.05, 0.45]))
         #[1.84]

     Args:
+        output_vector: the unscaled computed output values from the network
         target_vector: usually it is a one-hot vector where the hot bit corresponds to the label index.
            But it can be any probability distribution over the labels.
-        output_vector: the unscaled computed output values from the network
         name (str): the name of the node in the network
     Returns:
         :class:`cntk.Function`
     '''
     from cntk import cross_entropy_with_softmax
-    target_vector = sanitize_input(target_vector, get_data_type(output_vector))
     output_vector = sanitize_input(output_vector, get_data_type(target_vector))
-    return cross_entropy_with_softmax(target_vector, output_vector, name).output()
+    target_vector = sanitize_input(target_vector, get_data_type(output_vector))
+    return cross_entropy_with_softmax(output_vector, target_vector, name).output()

-def square_error(target_matrix, output_matrix, name=''):
+def squared_error(output_matrix, target_matrix, name=''):
     '''
     This operation computes the sum of the squared difference between elements
     in the two input matrices. The result is a scalar (i.e., one by one matrix).
     This is often used as a training criterion node.

     Example:
-        >>> C.eval(C.square_error([4., 6.], [2., 1.]))
+        >>> C.eval(C.squared_error([2., 1.], [4., 6.]))
         #[29.]

         >>> C.eval(C.squared_error([1., 2.], [1., 2.]))
         #[0.]

     Args:
-        target_matrix: target matrix, it is usually a one-hot vector where the hot bit corresponds to the label index
         output_matrix: the output values from the network
+        target_matrix: target matrix, it is usually a one-hot vector where the hot bit corresponds to the label index
         name (str): the name of the node in the network
     Returns:
         :class:`cntk.Function`
     '''
-    from cntk import square_error
-    target_matrix = sanitize_input(target_matrix, get_data_type(output_matrix))
+    from cntk import squared_error
     output_matrix = sanitize_input(output_matrix, get_data_type(target_matrix))
-    return square_error(target_matrix, output_matrix, name).output()
+    target_matrix = sanitize_input(target_matrix, get_data_type(output_matrix))
+    return squared_error(output_matrix, target_matrix, name).output()

-def classification_error(target_vector, output_vector, name=''):
+def classification_error(output_vector, target_vector, name=''):
     '''
     This operation computes the prediction error. It finds the index of the highest
     value in the output_vector and compares it to the actual ground truth label
@@ -99,23 +99,23 @@ def classification_error(target_vector, output_vector, name=''):
     defined for it.

     Example:
-        >>> C.eval(C.classification_error([0., 0., 0., 1.], [1., 2., 3., 4.]))
+        >>> C.eval(C.classification_error([1., 2., 3., 4.], [0., 0., 0., 1.]))
         #[0.]

-        >>> C.eval(C.classification_error([0., 0., 1., 0.], [1., 2., 3., 4.]))
+        >>> C.eval(C.classification_error([1., 2., 3., 4.], [0., 0., 1., 0.]))
         #[1.]

     Args:
-        target_vector: it is a one-hot vector where the hot bit corresponds to the label index
         output_vector: the output values from the network
+        target_vector: it is a one-hot vector where the hot bit corresponds to the label index
         name (str): the name of the node in the network
     Returns:
         :class:`cntk.Function`
     '''
     from cntk import classification_error
-    target_vector = sanitize_input(target_vector, get_data_type(output_vector))
     output_vector = sanitize_input(output_vector, get_data_type(target_vector))
-    return classification_error(target_vector, output_vector, name).output()
+    target_vector = sanitize_input(target_vector, get_data_type(output_vector))
+    return classification_error(output_vector, target_vector, name).output()

 ################################################################################
 # convolution ops
 ################################################################################
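Taken together, the three wrappers above now take their arguments in (output, target) order instead of (target, output), matching the order used by the underlying library calls. A quick sketch of the new calling convention, reusing the values from the docstrings (assumes import cntk as C):

    import cntk as C

    # Network output comes first, ground-truth target second.
    C.eval(C.cross_entropy_with_softmax([1., 1., 1., 50.], [0., 0., 0., 1.]))  # ~[0.]
    C.eval(C.squared_error([2., 1.], [4., 6.]))                                # [29.]
    C.eval(C.classification_error([1., 2., 3., 4.], [0., 0., 1., 0.]))         # [1.]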
@@ -903,7 +903,7 @@ def cond(flag, value_if_true, value_if_false, name=''):

 # TODO: add default value for initial_state. It should be a constant scalar
 # (0.0), using the default device
-def future_value(initial_state, x, time_step=1, name=''):
+def future_value(x, initial_state=None, time_step=1, name=''):
     '''
     This function returns the future value w.r.t. `x`. It is most often used when
     creating RNNs. The resulting tensor has the same shape as the input but is
@@ -915,20 +915,26 @@ def future_value(initial_state, x, time_step=1, name=''):

     Example:
         TBA

     Args:
+        x: the tensor (or its name) from which the future value is obtained.
         initial_state: tensor or scalar representing the initial value to be
         used when the input tensor is shifted in time.
-        x: the tensor (or its name) from which the future value is obtained.
         time_step (int): the number of time steps to look into the future (default 1)
         name (str): the name of the node in the network
     Returns:
         :class:`cntk.Function`
     '''
+    from ..utils import sanitize_dtype_cntk
+    from ..cntk_py import Constant
     from cntk import future_value
+
+    if initial_state is None:
+        initial_state = Constant.scalar(sanitize_dtype_cntk(np.float32), 0.0)
+
     x = sanitize_input(x)
-    return future_value(initial_state, x, time_step, name).output()
+    return future_value(x, initial_state, time_step, name).output()

-def past_value(initial_state, x, time_step=1, default_hidden_activation=0.1, name=''):
+def past_value(x, initial_state=None, time_step=1, name=''):
     '''
     This function returns the past value w.r.t. `x`. It is most often used when
     creating RNNs. The resulting tensor has the same shape as the input but is
@@ -940,18 +946,24 @@ def past_value(initial_state, x, time_step=1, default_hidden_activation=0.1, name=''):

     Example:
         TBA

     Args:
+        x: the tensor (or its name) from which the past value is obtained
         initial_state: tensor or scalar representing the initial value to be
         used when the input tensor is shifted in time.
-        x: the tensor (or its name) from which the past value is obtained
         time_step (int): the number of time steps to look into the past (default 1)
         name (str): the name of the node in the network
     Returns:
         :class:`cntk.Function`
     '''
+    from ..utils import sanitize_dtype_cntk
+    from ..cntk_py import Constant
     from cntk import past_value
+
+    if initial_state is None:
+        initial_state = Constant.scalar(sanitize_dtype_cntk(np.float32), 0.0)
+
     x = sanitize_input(x)
-    return past_value(initial_state, x, time_step, name).output()
+    return past_value(x, initial_state, time_step, name).output()

 ################################################################################
 # reshaping ops
 ################################################################################
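With initial_state now optional on both ops, recurrences no longer have to thread an explicit zero constant through every call. A minimal sketch of the new signatures (x is a hypothetical input tensor; the default initial state is the scalar constant 0.0 constructed above):

    h_prev  = past_value(x)                 # initial_state defaults to a 0.0 scalar
    h_prev2 = past_value(x, time_step=2)    # look two steps into the past
    h_next  = future_value(x)               # same default in the other direction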

View file

@@ -7,7 +7,7 @@
 import numpy as np
 import sys
 import os
-from cntk import learning_rates_per_sample, Trainer, sgdlearner, create_minibatch_source, get_train_loss, get_train_eval_criterion, cntk_device
+from cntk import learning_rates_per_sample, Trainer, sgdlearner, create_minibatch_source, get_train_loss, get_train_eval_criterion, DeviceDescriptor
 from cntk.ops import input_variable, constant, parameter, cross_entropy_with_softmax, combine, classification_error, times, pooling, AVG_POOLING
 from examples.common.nn import conv_bn_relu_layer, conv_bn_layer, resnet_node2, resnet_node2_inc
@@ -103,7 +103,7 @@ def cifar_resnet():
     # Input variables denoting the features and label data
     image_input = input_variable((num_channels, image_height, image_width), features_si.m_element_type)
-    label_var = input_variable((num_classes), features_si.m_element_type, needs_gradient=False)
+    label_var = input_variable((num_classes), features_si.m_element_type)

     # Instantiate the resnet classification model
     classifier_output = resnet_classifer(image_input, num_classes)
@@ -124,14 +124,13 @@ def cifar_resnet():
         # Specify the mapping of input variables in the model to actual minibatch data to be trained with
         arguments = {image_input : mb[features_si].m_data, label_var : mb[labels_si].m_data}
         trainer.train_minibatch(arguments)
         print_training_progress(training_progress_output_freq, i, trainer)

 if __name__=='__main__':
     # Specify the target device to be used for computing
-    target_device = DeviceDescriptor.gpu_device(0)
+    target_device = DeviceDescriptor.gpudevice(0)
     DeviceDescriptor.set_default_device(target_device)

     cifar_resnet()

View file

@@ -19,11 +19,11 @@ def simple_mnist():
     hidden_layers_dim = 200

     # Input variables denoting the features and label data
-    input = input_variable(input_dim, np.float32, needs_gradient=False, name="features")
-    label = input_variable(num_output_classes, np.float32, needs_gradient=False, name="labels")
+    input = input_variable(input_dim, np.float32)
+    label = input_variable(num_output_classes, np.float32)
+    scaled_input = element_times(constant((), 0.00390625), input)

     # Instantiate the feedforward classification model
-    scaled_input = element_times(constant((), 0.00390625), input)
     netout = fully_connected_classifier_net(scaled_input, num_output_classes, hidden_layers_dim, num_hidden_layers, sigmoid)

     ce = cross_entropy_with_softmax(netout, label)
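One detail worth noting in the hunk above: the scaling constant 0.00390625 is exactly 1/256, so element_times(constant((), 0.00390625), input) maps byte-valued pixels in [0, 255] into [0, 1). A one-line sanity check:

    assert 0.00390625 == 1.0 / 256.0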
@@ -61,7 +61,7 @@ def simple_mnist():

 if __name__=='__main__':
     # Specify the target device to be used for computing
-    target_device = DeviceDescriptor_cpudevice()
+    target_device = DeviceDescriptor.gpudevice(0)
     DeviceDescriptor.set_default_device(target_device)

     simple_mnist()
simple_mnist()

View file

@@ -0,0 +1,151 @@
+# Copyright (c) Microsoft. All rights reserved.
+# Licensed under the MIT license. See LICENSE.md file in the project root
+# for full license information.
+# ==============================================================================
+
+import numpy as np
+import sys
+import os
+import time
+
+import cntk
+from cntk import learning_rates_per_sample, DeviceDescriptor, Trainer, sgdlearner, Axis, get_train_loss, get_train_eval_criterion
+from cntk.ops import variable, input_variable, cross_entropy_with_softmax, classification_error
+from examples.common.nn import LSTMP_component_with_self_stabilization, embedding, fully_connected_linear_layer, select_last
+
+# Creates and trains a sequence to sequence translation model
+def train_sequence_to_sequence_translator():
+
+    input_vocab_dim = 69
+    label_vocab_dim = 69
+
+    hidden_dim = 512
+    num_layers = 2
+
+    # Source and target inputs to the model
+    input_dynamic_axes = [ Axis('inputAxis'), Axis.default_batch_axis() ]
+    raw_input = input_variable(shape=(input_vocab_dim), dynamic_axes = input_dynamic_axes)
+
+    label_dynamic_axes = [ Axis('labelAxis'), Axis.default_batch_axis() ]
+    raw_labels = input_variable(shape=(label_vocab_dim), dynamic_axes = label_dynamic_axes)
+
+    input_sequence = raw_input
+
+    # Drop the sentence start token from the label, for decoder training
+    label_sequence = cntk.ops.slice(raw_labels, label_dynamic_axes[0], 1, 0)
+    label_sentence_start = Sequence.first(raw_labels)
+
+    is_first_label = Sequence.is_first(label_sequence)
+    label_sentence_start_scattered = Sequence.scatter(label_sentence_start, is_first_label)
+    # Encoder
+    encoderOutputH = Stabilize<float>(inputEmbedding, device);
+    futureValueRecurrenceHook = [](const Variable& x) { return FutureValue(x); };
+    for (size_t i = 0; i < num_layers; ++i)
+        std::tie(encoderOutputH, encoderOutputC) = LSTMPComponentWithSelfStabilization<float>(encoderOutputH, hidden_dim, hidden_dim, futureValueRecurrenceHook, futureValueRecurrenceHook, device);
+
+    thoughtVectorH = Sequence::First(encoderOutputH);
+    thoughtVectorC = Sequence::First(encoderOutputC);
+
+    thoughtVectorBroadcastH = Sequence::BroadcastAs(thoughtVectorH, labelEmbedding);
+    thoughtVectorBroadcastC = Sequence::BroadcastAs(thoughtVectorC, labelEmbedding);
+
+    /* Decoder */
+    bool addBeamSearchReorderingHook = false;
+    beamSearchReorderHook = Constant({ 1, 1 }, 1.0f);
+    decoderHistoryFromGroundTruth = labelEmbedding;
+    decoderInput = ElementSelect(is_first_label, label_sentence_startEmbeddedScattered, PastValue(decoderHistoryFromGroundTruth));
+
+    decoderOutputH = Stabilize<float>(decoderInput, device);
+    FunctionPtr decoderOutputC;
+    pastValueRecurrenceHookWithBeamSearchReordering = [addBeamSearchReorderingHook, beamSearchReorderHook](const FunctionPtr& operand) {
+        return PastValue(addBeamSearchReorderingHook ? Times(operand, beamSearchReorderHook) : operand);
+    };
+
+    for (size_t i = 0; i < num_layers; ++i)
+    {
+        std::function<FunctionPtr(const Variable&)> recurrenceHookH, recurrenceHookC;
+        if (i == 0)
+        {
+            recurrenceHookH = pastValueRecurrenceHookWithBeamSearchReordering;
+            recurrenceHookC = pastValueRecurrenceHookWithBeamSearchReordering;
+        }
+        else
+        {
+            isFirst = Sequence::IsFirst(labelEmbedding);
+            recurrenceHookH = [labelEmbedding, thoughtVectorBroadcastH, isFirst, addBeamSearchReorderingHook, beamSearchReorderHook](const FunctionPtr& operand) {
+                return ElementSelect(isFirst, thoughtVectorBroadcastH, PastValue(addBeamSearchReorderingHook ? Times(operand, beamSearchReorderHook) : operand));
+            };
+            recurrenceHookC = [labelEmbedding, thoughtVectorBroadcastC, isFirst, addBeamSearchReorderingHook, beamSearchReorderHook](const FunctionPtr& operand) {
+                return ElementSelect(isFirst, thoughtVectorBroadcastC, PastValue(addBeamSearchReorderingHook ? Times(operand, beamSearchReorderHook) : operand));
+            };
+        }
+
+        std::tie(decoderOutputH, decoderOutputC) = LSTMPComponentWithSelfStabilization<float>(decoderOutputH, hidden_dim, hidden_dim, recurrenceHookH, recurrenceHookC, device);
+    }
+
+    decoderOutput = decoderOutputH;
+    decoderDim = hidden_dim;
+
+    /* Softmax output layer */
+    outputLayerProjWeights = Parameter(NDArrayView::RandomUniform<float>({ label_vocab_dim, decoderDim }, -0.05, 0.05, 1, device));
+    biasWeights = Parameter({ label_vocab_dim }, 0.0f, device);
+
+    z = Plus(Times(outputLayerProjWeights, Stabilize<float>(decoderOutput, device)), biasWeights, L"classifierOutput");
+    ce = CrossEntropyWithSoftmax(z, label_sequence, L"lossFunction");
+    errs = ClassificationError(z, label_sequence, L"classificationError");
+    input_dim = 2000
+    cell_dim = 25
+    hidden_dim = 25
+    embedding_dim = 50
+    num_output_classes = 5
+
+    # Input variables denoting the features and label data
+    features = variable(shape=input_dim, is_sparse=True, name="features")
+    label = variable(num_output_classes, dynamic_axes = [Axis.default_batch_axis()], name="labels")
+
+    # Instantiate the sequence classification model
+    classifier_output = LSTM_sequence_classifer_net(features, num_output_classes, embedding_dim, hidden_dim, cell_dim)
+
+    ce = cross_entropy_with_softmax(classifier_output, label)
+    pe = classification_error(classifier_output, label)
+
+    rel_path = r"../../../../Tests/EndToEndTests/Text/SequenceClassification/Data/Train.ctf"
+    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), rel_path)
+
+    mb_source = text_minibatch_source(path, [ ( 'features', input_dim, True, 'x' ), ( 'labels', num_output_classes, False, 'y' ) ], 0)
+    features_si = mb_source.stream_info(features)
+    labels_si = mb_source.stream_info(label)
+
+    # Instantiate the trainer object to drive the model training
+    lr = learning_rates_per_sample(0.0005)
+    trainer = Trainer(classifier_output, ce, pe, [sgdlearner(classifier_output.owner.parameters(), lr)])
+
+    # Get minibatches of sequences to train with and perform model training
+    minibatch_size = 200
+    training_progress_output_freq = 1
+    i = 0
+    while True:
+        mb = mb_source.get_next_minibatch(minibatch_size)
+        if len(mb) == 0:
+            break
+
+        # Specify the mapping of input variables in the model to actual minibatch data to be trained with
+        arguments = {features : mb[features_si].m_data, label : mb[labels_si].m_data}
+        trainer.train_minibatch(arguments)
+
+        print_training_progress(training_progress_output_freq, i, trainer)
+        i += 1
+
+if __name__=='__main__':
+    # Specify the target device to be used for computing
+    target_device = DeviceDescriptor.cpudevice()
+    DeviceDescriptor.set_default_device(target_device)
+
+    train_sequence_to_sequence_translator()
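The decoder section of this new file is still the C++ reference code. As a rough, purely illustrative idea of how its first step might look in the Python API, using only ops that appear in this commit (cond from the ops module and the new past_value default); the variable choices here are assumptions, not part of the commit:

    # Hypothetical port sketch, not in this commit: take the sentence-start
    # token on the first step, otherwise the previous decoder-history value.
    decoder_input = cond(is_first_label,
                         label_sentence_start_scattered,
                         past_value(label_sequence))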

View file

@@ -27,11 +27,10 @@ def train_sequence_classifier():
     hidden_dim = 25
     embedding_dim = 50
     num_output_classes = 5
-    feature_stream_name = 'features'
-    labels_stream_name = 'labels'

     # Input variables denoting the features and label data
-    features = input_variable(shape=input_dim, is_sparse=True, name=feature_stream_name)
-    label = input_variable(num_output_classes, dynamic_axes = [Axis.default_batch_axis()], name=labels_stream_name)
+    features = input_variable(shape=input_dim, is_sparse=True)
+    label = input_variable(num_output_classes, dynamic_axes = [Axis.default_batch_axis()])

     # Instantiate the sequence classification model
     classifier_output = LSTM_sequence_classifer_net(features, num_output_classes, embedding_dim, hidden_dim, cell_dim)
@@ -41,12 +40,13 @@ def train_sequence_classifier():
     rel_path = r"../../../../Tests/EndToEndTests/Text/SequenceClassification/Data/Train.ctf"
     path = os.path.join(os.path.dirname(os.path.abspath(__file__)), rel_path)

+    feature_stream_name = 'features'
+    labels_stream_name = 'labels'
     mb_source = text_format_minibatch_source(path, [
         StreamConfiguration( feature_stream_name, input_dim, True, 'x' ),
         StreamConfiguration( labels_stream_name, num_output_classes, False, 'y')], 0)
     features_si = mb_source.stream_info(features)
     labels_si = mb_source.stream_info(label)
@@ -66,13 +66,14 @@ def train_sequence_classifier():
         # Specify the mapping of input variables in the model to actual minibatch data to be trained with
         arguments = {features : mb[features_si].m_data, label : mb[labels_si].m_data}
         trainer.train_minibatch(arguments)
         print_training_progress(i, trainer, training_progress_output_freq)
         i += 1

 if __name__=='__main__':
     # Specify the target device to be used for computing
-    target_device = DeviceDescriptor_cpudevice()
+    target_device = DeviceDescriptor.cpudevice()
     DeviceDescriptor.set_default_device(target_device)

     train_sequence_classifier()

View file

@@ -33,7 +33,6 @@ def ffnet():
     feature_stream_name = 'features'
     labels_stream_name = 'labels'

     mb_source = text_format_minibatch_source(path, [
         StreamConfiguration( feature_stream_name, input_dim ),
         StreamConfiguration( labels_stream_name, num_output_classes)])
@@ -58,11 +57,9 @@ def ffnet():
         trainer.train_minibatch(arguments)
         print_training_progress(i, trainer, training_progress_output_freq)

 if __name__=='__main__':
     # Specify the target device to be used for computing
-    target_device = DeviceDescriptor_cpudevice()
+    target_device = DeviceDescriptor.cpudevice()
     DeviceDescriptor.set_default_device(target_device)

     ffnet()

View file

@@ -81,6 +81,14 @@ def embedding(input, embedding_dim):
 def select_last(operand):
     return slice(operand, Axis.default_dynamic_axis(), -1, 0)

+def stabilize(operand):
+    scalar_constant = 4.0
+    f = Constant.scalar(scalar_constant)
+    fInv = Constant.scalar(f.get_data_type(), 1.0 / scalar_constant)
+
+    beta = element_times(fInv, log(Constant.scalar(f.get_data_type(), 1.0) + exp(element_times(f, parameter(shape=(), dtype=f.get_data_type(), init_value=0.99537863)))))
+    return element_times(beta, operand)
+
 def LSTMP_cell_with_self_stabilization(input, prev_output, prev_cell_state):
     input_dim = input.shape()[0]
     output_dim = prev_output.shape()[0]
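The stabilize helper added above computes a learned scalar multiplier beta = (1/f) * log(1 + exp(f * p)), with f = 4 and a learnable parameter p initialized to 0.99537863. That initial value is chosen so beta starts out at roughly 1.0, which makes the stabilizer an identity map at the beginning of training. A quick numeric check (plain numpy, independent of CNTK):

    import numpy as np
    f, p = 4.0, 0.99537863
    beta0 = (1.0 / f) * np.log(1.0 + np.exp(f * p))
    print(beta0)  # ~1.00006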
@@ -173,8 +181,8 @@ def LSTMP_component_with_self_stabilization(input, output_dim, cell_dim):
     dc = placeholder_variable(shape=(cell_dim))

     LSTMCell = LSTMP_cell_with_self_stabilization(input, dh, dc)
-    actualDh = past_value(LSTMCell[0], constant((), 0.0), 1);
-    actualDc = past_value(LSTMCell[1], constant((), 0.0), 1);
+    actualDh = past_value(LSTMCell[0])
+    actualDc = past_value(LSTMCell[1])

     # Form the recurrence loop by replacing the dh and dc placeholders with the actualDh and actualDc
     return LSTMCell[0].owner.replace_placeholders({ dh : actualDh, dc : actualDc})