Adding tests to the layers library, removing previous non-layers constructors, and updating distributed seq2seq.
Parent: 4812079700
Commit: 100ae50494

7  CNTK.sln
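The recurring pattern in this commit: model-construction helpers from the deprecated Examples/common/nn.py are replaced by the Layers Library. A minimal sketch of the before/after, based on the SimpleMNIST change further down (the dimensions are illustrative, not taken from the diff):

```python
# Before: helper from the deprecated Examples/common/nn.py
# z = fully_connected_classifier_net(scaled_input, num_output_classes,
#                                    hidden_layers_dim, num_hidden_layers, relu)

# After: the same feed-forward classifier snapped together from Layers Library primitives
from cntk.layers import Dense, Sequential, For
from cntk.ops import relu

num_hidden_layers, hidden_layers_dim, num_output_classes = 2, 200, 10  # illustrative values

model = Sequential([For(range(num_hidden_layers),
                        lambda i: Dense(hidden_layers_dim, activation=relu)),
                    Dense(num_output_classes)])
# z = model(scaled_input)
```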
@@ -1040,12 +1040,6 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "HelloWorld-LogisticRegressi
		Tutorials\HelloWorld-LogisticRegression\Train_cntk_text.txt = Tutorials\HelloWorld-LogisticRegression\Train_cntk_text.txt
	EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "common", "common", "{9A7E977F-0A11-43F6-8FBB-D5697E7DC9B2}"
	ProjectSection(SolutionItems) = preProject
		Examples\common\__init__.py = Examples\common\__init__.py
		Examples\common\nn.py = Examples\common\nn.py
	EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Regression", "Regression", "{B1F7024F-51E9-43E7-99BE-8190CBC0B129}"
	ProjectSection(SolutionItems) = preProject
		Examples\Image\Regression\README.md = Examples\Image\Regression\README.md
@@ -2124,7 +2118,6 @@ Global
		{A22E7B97-B4D2-43EA-AD53-307FA767A38D} = {305456F0-D9DE-4452-87BE-1C9F3C34C14F}
		{2DD4DF97-4379-4D5F-9C1D-7AAC59E47796} = {305456F0-D9DE-4452-87BE-1C9F3C34C14F}
		{8D116405-E726-4BF9-B2F7-30CA52CD59C7} = {305456F0-D9DE-4452-87BE-1C9F3C34C14F}
		{9A7E977F-0A11-43F6-8FBB-D5697E7DC9B2} = {47755F2E-D674-4175-9E38-8EA053455072}
		{B1F7024F-51E9-43E7-99BE-8190CBC0B129} = {9BDFA4BE-790E-408F-915B-5979BB5078C6}
		{219A815E-11F4-4C01-9025-342B07768399} = {9BDFA4BE-790E-408F-915B-5979BB5078C6}
		{8A624183-63DB-4221-81CD-E8577AD72868} = {9BDFA4BE-790E-408F-915B-5979BB5078C6}
@@ -13,14 +13,13 @@ from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INF
from cntk.device import cpu, try_set_default_device
from cntk.learners import adadelta, learning_rate_schedule, UnitType
from cntk.ops import input, relu, element_times, constant
from cntk.layers import Dense, Sequential, For
from cntk.losses import cross_entropy_with_softmax
from cntk.metrics import classification_error
from cntk.train.training_session import *
from cntk.logging import ProgressPrinter, TensorBoardProgressWriter

abs_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(abs_path, "..", "..", "..", "..", "common"))
from nn import fully_connected_classifier_net

def check_path(path):
    if not os.path.exists(path):
@@ -50,8 +49,9 @@ def simple_mnist(tensorboard_logdir=None):

    # Instantiate the feedforward classification model
    scaled_input = element_times(constant(0.00390625), feature)
    z = fully_connected_classifier_net(
        scaled_input, num_output_classes, hidden_layers_dim, num_hidden_layers, relu)

    z = Sequential([For(range(num_hidden_layers), lambda i: Dense(hidden_layers_dim, activation=relu)),
                    Dense(num_output_classes)])(scaled_input)

    ce = cross_entropy_with_softmax(z, label)
    pe = classification_error(z, label)
@@ -11,12 +11,11 @@ from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INF
from cntk.device import cpu, try_set_default_device
from cntk.learners import sgd, learning_rate_schedule, UnitType
from cntk.ops import input, sequence
from cntk.logging import ProgressPrinter
from cntk.losses import cross_entropy_with_softmax
from cntk.metrics import classification_error

abs_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(abs_path, "..", "..", "..", "common"))
from nn import LSTMP_component_with_self_stabilization, embedding, linear_layer, print_training_progress
from cntk.layers import Sequential, Embedding, Recurrence, LSTM, Dense

# Creates the reader
def create_reader(path, is_training, input_dim, label_dim):
@@ -26,16 +25,15 @@ def create_reader(path, is_training, input_dim, label_dim):
    )), randomize=is_training, max_sweeps = INFINITELY_REPEAT if is_training else 1)

# Defines the LSTM model for classifying sequences
def LSTM_sequence_classifer_net(feature, num_output_classes, embedding_dim, LSTM_dim, cell_dim):
    embedding_function = embedding(feature, embedding_dim)
    LSTM_function = LSTMP_component_with_self_stabilization(
        embedding_function.output, LSTM_dim, cell_dim)[0]
    thought_vector = sequence.last(LSTM_function)

    return linear_layer(thought_vector, num_output_classes)
def LSTM_sequence_classifier_net(feature, num_output_classes, embedding_dim, LSTM_dim, cell_dim):
    lstm_classifier = Sequential([Embedding(embedding_dim),
                                  Recurrence(LSTM(LSTM_dim, cell_dim))[0],
                                  sequence.last,
                                  Dense(num_output_classes)])
    return lstm_classifier(feature)

# Creates and trains a LSTM sequence classification model
def train_sequence_classifier(debug_output=False):
def train_sequence_classifier():
    input_dim = 2000
    cell_dim = 25
    hidden_dim = 25
@@ -47,7 +45,7 @@ def train_sequence_classifier(debug_output=False):
    label = input(num_output_classes)

    # Instantiate the sequence classification model
    classifier_output = LSTM_sequence_classifer_net(
    classifier_output = LSTM_sequence_classifier_net(
        features, num_output_classes, embedding_dim, hidden_dim, cell_dim)

    ce = cross_entropy_with_softmax(classifier_output, label)
@@ -64,21 +62,19 @@ def train_sequence_classifier(debug_output=False):
    }

    lr_per_sample = learning_rate_schedule(0.0005, UnitType.sample)

    # Instantiate the trainer object to drive the model training
    progress_printer = ProgressPrinter(10)
    trainer = Trainer(classifier_output, (ce, pe),
                      sgd(classifier_output.parameters, lr=lr_per_sample))
                      sgd(classifier_output.parameters, lr=lr_per_sample),
                      progress_printer)

    # Get minibatches of sequences to train with and perform model training
    minibatch_size = 200
    training_progress_output_freq = 10

    if debug_output:
        training_progress_output_freq = training_progress_output_freq/3

    for i in range(251):
        mb = reader.next_minibatch(minibatch_size, input_map=input_map)
        trainer.train_minibatch(mb)
        print_training_progress(trainer, i, training_progress_output_freq)

    import copy
@@ -12,24 +12,25 @@ import argparse
import _cntk_py
import cntk

from cntk import Trainer, Axis
from cntk.device import try_set_default_device, gpu
from cntk.train.distributed import *
from cntk import Trainer
from cntk.train.distributed import Communicator, data_parallel_distributed_learner, block_momentum_distributed_learner
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP
from cntk.learners import learning_rate_schedule, UnitType, momentum_sgd, momentum_as_time_constant_schedule
from cntk import input, cross_entropy_with_softmax, classification_error, sequence, element_select, alias, hardmax
from cntk.ops.functions import CloneMethod
from cntk.learners import fsadagrad, learning_rate_schedule, UnitType, momentum_as_time_constant_schedule
from cntk.train.training_session import *
from cntk.logging import *

abs_path = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(abs_path, "Models")
sys.path.append(os.path.join(abs_path, "..", "..", "..", "common"))
from nn import LSTMP_component_with_self_stabilization, stabilize, linear_layer, print_training_progress

default_quantization_bits = 32

def create_reader(path, randomize, input_vocab_dim, label_vocab_dim, size=INFINITELY_REPEAT):
# model dimensions
input_vocab_dim = 69
label_vocab_dim = 69

use_attention = True

def create_reader(path, randomize, size=INFINITELY_REPEAT):
    if not os.path.exists(path):
        raise RuntimeError("File '%s' does not exist." % (path))
@@ -38,115 +39,31 @@ def create_reader(path, randomize, input_vocab_dim, label_vocab_dim, size=INFINI
        labels   = StreamDef(field='S1', shape=label_vocab_dim, is_sparse=True)
    )), randomize=randomize, max_samples = size)

def create_trainer(network, epoch_size, num_quantization_bits, block_size, warm_up, progress_printer):
    # Instantiate the trainer object to drive the model training
    lr_per_minibatch = learning_rate_schedule(0.5, UnitType.minibatch)
    momentum_time_constant = momentum_as_time_constant_schedule(1100)
    clipping_threshold_per_sample = 2.3
    gradient_clipping_with_truncation = True
def train_and_test(s2smodel, train_reader, test_reader, block_size, num_quantization_bits, max_epochs, epoch_size, minibatch_size, progress_printer, warm_up):
    from Sequence2Sequence import create_criterion_function, create_model_train
    model_train = create_model_train(s2smodel)
    criterion = create_criterion_function(model_train)

    # Create learner
    if block_size is not None and num_quantization_bits != default_quantization_bits:
        raise RuntimeError("Block momentum cannot be used with quantization, please remove quantized_bits option.")

    local_learner = momentum_sgd(network['output'].parameters,
                                 lr_per_minibatch, momentum_time_constant,
                                 gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,
                                 gradient_clipping_with_truncation=gradient_clipping_with_truncation)
    lr = 0.001 if use_attention else 0.005  # TODO: can we use the same value for both?
    local_learner = fsadagrad(model_train.parameters,
                              lr = learning_rate_schedule([lr]*2+[lr/2]*3+[lr/4], UnitType.sample, epoch_size),
                              momentum = momentum_as_time_constant_schedule(1100),
                              gradient_clipping_threshold_per_sample=2.3,
                              gradient_clipping_with_truncation=True)

    if block_size != None:
        learner = block_momentum_distributed_learner(local_learner, block_size=block_size)
    else:
        learner = data_parallel_distributed_learner(local_learner, num_quantization_bits=num_quantization_bits, distributed_after=warm_up)

    return Trainer(network['output'], (network['ce'], network['pe']), learner, progress_printer)
    trainer = Trainer(None, criterion, learner, progress_printer)

def create_network(input_vocab_dim, label_vocab_dim):
    # network complexity; initially low for faster testing
    hidden_dim = 256
    num_layers = 1

    # Source and target inputs to the model
    input_seq_axis = Axis('inputAxis')
    label_seq_axis = Axis('labelAxis')
    raw_input = sequence.input(shape=(input_vocab_dim), sequence_axis=input_seq_axis, name='raw_input')
    raw_labels = sequence.input(shape=(label_vocab_dim), sequence_axis=label_seq_axis, name='raw_labels')

    # Instantiate the sequence to sequence translation model
    input_sequence = raw_input

    # Drop the sentence start token from the label, for decoder training
    label_sequence = sequence.slice(raw_labels, 1, 0) # <s> A B C </s> --> A B C </s>
    label_sentence_start = sequence.first(raw_labels)  # <s>

    is_first_label = sequence.is_first(label_sequence)  # <s> 0 0 0 ...
    label_sentence_start_scattered = sequence.scatter(
        label_sentence_start, is_first_label)

    # Encoder
    encoder_outputH = stabilize(input_sequence)
    for i in range(0, num_layers):
        (encoder_outputH, encoder_outputC) = LSTMP_component_with_self_stabilization(
            encoder_outputH.output, hidden_dim, hidden_dim, sequence.future_value, sequence.future_value)

    thought_vectorH = sequence.first(encoder_outputH)
    thought_vectorC = sequence.first(encoder_outputC)

    thought_vector_broadcastH = sequence.broadcast_as(
        thought_vectorH, label_sequence)
    thought_vector_broadcastC = sequence.broadcast_as(
        thought_vectorC, label_sequence)

    # Decoder
    decoder_history_hook = alias(label_sequence, name='decoder_history_hook')  # copy label_sequence

    decoder_input = element_select(is_first_label, label_sentence_start_scattered, sequence.past_value(
        decoder_history_hook))

    decoder_outputH = stabilize(decoder_input)
    for i in range(0, num_layers):
        if (i > 0):
            recurrence_hookH = sequence.past_value
            recurrence_hookC = sequence.past_value
        else:
            isFirst = sequence.is_first(label_sequence)
            recurrence_hookH = lambda operand: element_select(
                isFirst, thought_vector_broadcastH, sequence.past_value(operand))
            recurrence_hookC = lambda operand: element_select(
                isFirst, thought_vector_broadcastC, sequence.past_value(operand))

        (decoder_outputH, encoder_outputC) = LSTMP_component_with_self_stabilization(
            decoder_outputH.output, hidden_dim, hidden_dim, recurrence_hookH, recurrence_hookC)

    decoder_output = decoder_outputH

    # Softmax output layer
    z = linear_layer(stabilize(decoder_output), label_vocab_dim)

    # Criterion nodes
    ce = cross_entropy_with_softmax(z, label_sequence)
    errs = classification_error(z, label_sequence)

    # network output for decoder history
    net_output = hardmax(z)

    # make a clone of the graph where the ground truth is replaced by the network output
    ng = z.clone(CloneMethod.share, {decoder_history_hook.output : net_output.output})

    return {
        'raw_input' : raw_input,
        'raw_labels' : raw_labels,
        'ce' : ce,
        'pe' : errs,
        'ng' : ng,
        'output': z
    }

def train_and_test(network, trainer, train_reader, test_reader, epoch_size, minibatch_size):
    train_bind = {
        network['raw_input']  : train_reader.streams.features,
        network['raw_labels'] : train_reader.streams.labels
    }
    train_bind = {criterion.arguments[0]: train_reader.streams.features,
                  criterion.arguments[1]: train_reader.streams.labels}

    training_session(
        mb_source = train_reader,
@@ -162,6 +79,10 @@ def train_and_test(network, trainer, train_reader, test_reader, epoch_size, mini

def sequence_to_sequence_translator(train_data, test_data, epoch_size=908241, num_quantization_bits=default_quantization_bits, block_size=3200, warm_up=0, minibatch_size=72, max_epochs=10, randomize_data=False, log_to_file=None, num_mbs_per_log=10, gen_heartbeat=False):
    cntk.debugging.set_computation_network_trace_level(0)
    from _cntk_py import set_fixed_random_seed
    set_fixed_random_seed(1)

    from Sequence2Sequence import create_model

    distributed_sync_report_freq = None
    if block_size is not None:
@@ -175,17 +96,13 @@ def sequence_to_sequence_translator(train_data, test_data, epoch_size=908241, nu
        num_epochs=max_epochs,
        distributed_freq=distributed_sync_report_freq)

    input_vocab_dim = 69
    label_vocab_dim = 69
    # create inputs and create model
    model = create_model()

    network = create_network(input_vocab_dim, label_vocab_dim)
    trainer = create_trainer(network, epoch_size, num_quantization_bits, block_size, warm_up, progress_printer)
    train_reader = create_reader(train_data, randomize_data, size=max_epochs*epoch_size)
    test_reader = create_reader(test_data, False, size=max_epochs*epoch_size*10)

    train_reader = create_reader(train_data, randomize_data, input_vocab_dim, label_vocab_dim, size=max_epochs*epoch_size)

    test_reader = create_reader(test_data, False, input_vocab_dim, label_vocab_dim, size=cntk.io.FULL_DATA_SWEEP)

    train_and_test(network, trainer, train_reader, test_reader, epoch_size, minibatch_size)
    train_and_test(model, train_reader, test_reader, block_size, num_quantization_bits, max_epochs, epoch_size, minibatch_size, progress_printer, warm_up)

if __name__ == '__main__':
    data_path = os.path.join(abs_path, "..", "Data")
@@ -209,7 +126,7 @@ if __name__ == '__main__':
    if args['outputdir'] is not None:
        model_path = args['outputdir'] + "/models"
    if args['device'] is not None:
        try_set_default_device(gpu(args['device']))
        cntk.device.try_set_default_device(cntk.device.gpu(args['device']))

    data_path = args['datadir']
@@ -1,7 +0,0 @@
# Copyright (c) Microsoft. All rights reserved.

# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================
@ -1,195 +0,0 @@
|
|||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
# Licensed under the MIT license. See LICENSE.md file in the project root
|
||||
# for full license information.
|
||||
# ==============================================================================
|
||||
|
||||
# Note: the files in the 'common' folder are deprecated and will be replaced by the layer library
|
||||
|
||||
import numpy as np
|
||||
from cntk.ops import *
|
||||
from cntk.initializer import glorot_uniform, he_normal
|
||||
|
||||
|
||||
def linear_layer(input_var, output_dim):
|
||||
times_param = parameter(shape=(list(input_var.shape)+[output_dim]), init=glorot_uniform())
|
||||
bias_param = parameter(shape=(output_dim), init=0)
|
||||
|
||||
t = times(input_var, times_param)
|
||||
return bias_param + t
|
||||
|
||||
|
||||
def fully_connected_layer(input, output_dim, nonlinearity):
|
||||
p = linear_layer(input, output_dim)
|
||||
return nonlinearity(p)
|
||||
|
||||
# Defines a multilayer feedforward classification model
|
||||
|
||||
|
||||
def fully_connected_classifier_net(input, num_output_classes, hidden_layer_dim, num_hidden_layers, nonlinearity):
|
||||
r = fully_connected_layer(input, hidden_layer_dim, nonlinearity)
|
||||
for i in range(1, num_hidden_layers):
|
||||
r = fully_connected_layer(r, hidden_layer_dim, nonlinearity)
|
||||
|
||||
return linear_layer(r, num_output_classes)
|
||||
|
||||
|
||||
def conv_bn_layer(input, out_feature_map_count, kernel_shape, strides, bn_time_const, b_value=0, sc_value=1):
|
||||
num_in_channels = input.shape[0]
|
||||
kernel_width = kernel_shape[0]
|
||||
kernel_height = kernel_shape[1]
|
||||
v_stride = strides[0]
|
||||
h_stride = strides[1]
|
||||
#TODO: use RandomNormal to initialize, needs to be exposed in the python api
|
||||
conv_params = parameter(shape=(out_feature_map_count, num_in_channels, kernel_height, kernel_width), init=he_normal())
|
||||
conv_func = convolution(conv_params, input, (num_in_channels, v_stride, h_stride))
|
||||
|
||||
#TODO: initialize using b_value and sc_value, needs to be exposed in the python api
|
||||
bias_params = parameter(shape=(out_feature_map_count), init=b_value)
|
||||
scale_params = parameter(shape=(out_feature_map_count), init=sc_value)
|
||||
running_mean = constant(0., (out_feature_map_count))
|
||||
running_invstd = constant(0., (out_feature_map_count))
|
||||
running_count = constant(0., (1))
|
||||
return batch_normalization(conv_func, scale_params, bias_params, running_mean, running_invstd, running_count=running_count, spatial=True,
|
||||
normalization_time_constant=bn_time_const, use_cudnn_engine=True)
|
||||
|
||||
|
||||
def conv_bn_relu_layer(input, out_feature_map_count, kernel_shape, strides, bn_time_const, b_value=0, sc_value=1):
|
||||
conv_bn_function = conv_bn_layer(input, out_feature_map_count, kernel_shape, strides, bn_time_const, b_value, sc_value)
|
||||
return relu(conv_bn_function)
|
||||
|
||||
|
||||
def embedding(input, embedding_dim):
|
||||
input_dim = input.shape[0]
|
||||
|
||||
embedding_parameters = parameter(shape=(input_dim, embedding_dim), init=glorot_uniform())
|
||||
return times(input, embedding_parameters)
|
||||
|
||||
|
||||
def select_last(operand):
|
||||
return slice(operand, Axis.default_dynamic_axis(), -1, 0)
|
||||
|
||||
|
||||
def stabilize(operand):
|
||||
scalar_constant = 4.0
|
||||
f = constant(scalar_constant)
|
||||
fInv = constant(1.0 / scalar_constant)
|
||||
|
||||
beta = element_times(fInv,
|
||||
log(1.0 + exp(element_times(f, parameter(init=0.99537863)))))
|
||||
return element_times(beta, operand)
|
||||
|
||||
|
||||
def LSTMP_cell_with_self_stabilization(input, prev_output, prev_cell_state):
|
||||
input_dim = input.shape[0]
|
||||
output_dim = prev_output.shape[0]
|
||||
cell_dim = prev_cell_state.shape[0]
|
||||
|
||||
Wxo = parameter(shape=(input_dim, cell_dim), init=glorot_uniform())
|
||||
Wxi = parameter(shape=(input_dim, cell_dim), init=glorot_uniform())
|
||||
Wxf = parameter(shape=(input_dim, cell_dim), init=glorot_uniform())
|
||||
Wxc = parameter(shape=(input_dim, cell_dim), init=glorot_uniform())
|
||||
|
||||
Bo = parameter(shape=(cell_dim), init=0)
|
||||
Bc = parameter(shape=(cell_dim), init=0)
|
||||
Bi = parameter(shape=(cell_dim), init=0)
|
||||
Bf = parameter(shape=(cell_dim), init=0)
|
||||
|
||||
Whi = parameter(shape=(output_dim, cell_dim), init=glorot_uniform())
|
||||
Wci = parameter(shape=(cell_dim), init=glorot_uniform())
|
||||
|
||||
Whf = parameter(shape=(output_dim, cell_dim), init=glorot_uniform())
|
||||
Wcf = parameter(shape=(cell_dim), init=glorot_uniform())
|
||||
|
||||
Who = parameter(shape=(output_dim, cell_dim), init=glorot_uniform())
|
||||
Wco = parameter(shape=(cell_dim), init=glorot_uniform())
|
||||
|
||||
Whc = parameter(shape=(output_dim, cell_dim), init=glorot_uniform())
|
||||
|
||||
Wmr = parameter(shape=(cell_dim, output_dim), init=glorot_uniform())
|
||||
|
||||
# Stabilization by routing input through an extra scalar parameter
|
||||
sWxo = parameter(init=0)
|
||||
sWxi = parameter(init=0)
|
||||
sWxf = parameter(init=0)
|
||||
sWxc = parameter(init=0)
|
||||
|
||||
sWhi = parameter(init=0)
|
||||
sWci = parameter(init=0)
|
||||
|
||||
sWhf = parameter(init=0)
|
||||
sWcf = parameter(init=0)
|
||||
sWho = parameter(init=0)
|
||||
sWco = parameter(init=0)
|
||||
sWhc = parameter(init=0)
|
||||
|
||||
sWmr = parameter(init=0)
|
||||
|
||||
expsWxo = exp(sWxo)
|
||||
expsWxi = exp(sWxi)
|
||||
expsWxf = exp(sWxf)
|
||||
expsWxc = exp(sWxc)
|
||||
|
||||
expsWhi = exp(sWhi)
|
||||
expsWci = exp(sWci)
|
||||
|
||||
expsWhf = exp(sWhf)
|
||||
expsWcf = exp(sWcf)
|
||||
expsWho = exp(sWho)
|
||||
expsWco = exp(sWco)
|
||||
expsWhc = exp(sWhc)
|
||||
|
||||
expsWmr = exp(sWmr)
|
||||
|
||||
Wxix = times(element_times(expsWxi, input), Wxi)
|
||||
Whidh = times(element_times(expsWhi, prev_output), Whi)
|
||||
Wcidc = element_times(Wci, element_times(expsWci, prev_cell_state))
|
||||
|
||||
it = sigmoid(Wxix + Bi + Whidh + Wcidc)
|
||||
Wxcx = times(element_times(expsWxc, input), Wxc)
|
||||
Whcdh = times(element_times(expsWhc, prev_output), Whc)
|
||||
bit = element_times(it, tanh(Wxcx + Whcdh + Bc))
|
||||
Wxfx = times(element_times(expsWxf, input), Wxf)
|
||||
Whfdh = times(element_times(expsWhf, prev_output), Whf)
|
||||
Wcfdc = element_times(Wcf, element_times(expsWcf, prev_cell_state))
|
||||
|
||||
ft = sigmoid(Wxfx + Bf + Whfdh + Wcfdc)
|
||||
bft = element_times(ft, prev_cell_state)
|
||||
|
||||
ct = bft + bit
|
||||
|
||||
Wxox = times(element_times(expsWxo, input), Wxo)
|
||||
Whodh = times(element_times(expsWho, prev_output), Who)
|
||||
Wcoct = element_times(Wco, element_times(expsWco, ct))
|
||||
|
||||
ot = sigmoid(Wxox + Bo + Whodh + Wcoct)
|
||||
|
||||
mt = element_times(ot, tanh(ct))
|
||||
return (times(element_times(expsWmr, mt), Wmr), ct)
|
||||
|
||||
|
||||
def LSTMP_component_with_self_stabilization(input, output_dim, cell_dim, recurrence_hookH=sequence.past_value, recurrence_hookC=sequence.past_value):
|
||||
dh = placeholder(
|
||||
shape=(output_dim), dynamic_axes=input.dynamic_axes)
|
||||
dc = placeholder(
|
||||
shape=(cell_dim), dynamic_axes=input.dynamic_axes)
|
||||
|
||||
LSTMCell = LSTMP_cell_with_self_stabilization(input, dh, dc)
|
||||
actualDh = recurrence_hookH(LSTMCell[0])
|
||||
actualDc = recurrence_hookC(LSTMCell[1])
|
||||
|
||||
# Form the recurrence loop by replacing the dh and dc placeholders with
|
||||
# the actualDh and actualDc
|
||||
LSTMCell[0].replace_placeholders(
|
||||
{dh: actualDh.output, dc: actualDc.output})
|
||||
|
||||
return (LSTMCell[0], LSTMCell[1])
|
||||
|
||||
|
||||
def print_training_progress(trainer, mb, frequency):
|
||||
|
||||
if mb % frequency == 0:
|
||||
training_loss = trainer.previous_minibatch_loss_average
|
||||
eval_crit = trainer.previous_minibatch_evaluation_average
|
||||
print("Minibatch: {}, Train Loss: {}, Train Evaluation Criterion: {}".format(
|
||||
mb, training_loss, eval_crit))
|
|
@@ -25,16 +25,16 @@ def test_sequence_to_sequence_distributed_1bitsgd(device_id):
    params = [ "-e", "2",
               "-datadir", cmudict_dataset_directory(),
               "-q", "1",
               "-ms", "72",
               "-ms", "100",
               "-es", "500",
               "-device", str(device_id) ]
    mpiexec_test(device_id, script_under_test, mpiexec_params, params, 0.8622, False, 0, 2E-2)
    mpiexec_test(device_id, script_under_test, mpiexec_params, params, 0.8625, False, 0, 2E-2)

def test_sequence_to_sequence_distributed_block_momentum(device_id):
    params = [ "-e", "2",
    params = [ "-e", "4",
               "-datadir", cmudict_dataset_directory(),
               "-ms", "72",
               "-es", "100",
               "-ms", "100",
               "-es", "1000",
               "-b", "3200",
               "-device", str(device_id) ]
    mpiexec_test(device_id, script_under_test, mpiexec_params, params, 0.97, False, 1, 2E-2)
    mpiexec_test(device_id, script_under_test, mpiexec_params, params, 0.8612, False, 1, 2E-2)
@@ -9,16 +9,13 @@ import sys
import os
from cntk.device import cpu, try_set_default_device
from cntk import Trainer
from cntk.layers import Dense, Sequential, For
from cntk.learners import sgd, learning_rate_schedule, UnitType
from cntk.ops import input, sigmoid
from cntk.losses import cross_entropy_with_softmax
from cntk.metrics import classification_error
from cntk.logging import ProgressPrinter

abs_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(abs_path, "..", "..", "Examples", "common"))
from nn import fully_connected_classifier_net

# make sure we get always the same "randomness"
np.random.seed(0)
@@ -47,9 +44,8 @@ def ffnet():
    feature = input((input_dim), np.float32)
    label = input((num_output_classes), np.float32)

    # Instantiate the feedforward classification model
    netout = fully_connected_classifier_net(
        feature, num_output_classes, hidden_layers_dim, num_hidden_layers, sigmoid)
    netout = Sequential([For(range(num_hidden_layers), lambda i: Dense(hidden_layers_dim, activation=sigmoid)),
                         Dense(num_output_classes)])(feature)

    ce = cross_entropy_with_softmax(netout, label)
    pe = classification_error(netout, label)
@@ -39,7 +39,7 @@ __doc__ = '''
In order to debug a graph one simply needs to wrap the root node as follows::

    # ... setting up the model in z
    from cntk.debug import debug_model
    from cntk.debugging import debug_model
    z = debug_model(z)

Then, when ``z`` is evaluated or trained (i.e. when either
@@ -68,7 +68,7 @@ def _get_initial_state_or_default(initial_state):

def BlockFunction(op_name, name):
    '''
    Decorator for defining a @Function as a BlockFunction. Same as @Function, but wrap the content into an as_block().
    Decorator for defining a @Function as a BlockFunction. Same as @Function, but wrap the content into an :func:`~cntk.ops.as_block`.
    '''
    return lambda f: Function(f, make_block=True, op_name=op_name, name=name)
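A minimal sketch of how this decorator is typically used, mirroring the Dropout implementation further down in this commit. The ``Scale`` factory and the import path are illustrative assumptions, not part of the diff:

```python
from cntk.layers.blocks import BlockFunction   # assumed import location (this file)
from cntk.ops import element_times

def Scale(factor, name=''):
    # The decorated function becomes a single 'Scale' block (as_block) in the graph.
    @BlockFunction('Scale', name)
    def scale(x):
        return element_times(factor, x)
    return scale
```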
@@ -86,7 +86,7 @@ def _inject_name(f, name):
def ForwardDeclaration(name='forward_declaration'):
    '''
    Helper for recurrent network declarations.
    Returns a placeholder variable with an added method resolve_to() to be called
    Returns a placeholder variable with an added method ``resolve_to()`` to be called
    at the end to close the loop.
    This is used for explicit graph building with recurrent connections.
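A hedged sketch of the intended usage, based only on the docstring above (declare the future value, use it, then close the loop); the running-sum recurrence itself is an illustrative assumption:

```python
from cntk.layers.blocks import ForwardDeclaration
from cntk.ops import plus, sequence

h_fwd = ForwardDeclaration()                    # stand-in for h at the previous step
x = sequence.input(shape=1)
h = plus(sequence.past_value(h_fwd), x)         # running sum over the sequence
h_fwd.resolve_to(h)                             # close the recurrent loop at the end
```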
@@ -9,6 +9,9 @@ higher_order_layers -- higher-order functions, like Sequential() and ResNetBlock
Note that sequential higher-order functions like Recurrence() are in sequence.py.
'''

from types import FunctionType
from inspect import getargspec

from ..variables import Record
from .blocks import *
from .blocks import _initializer_for, _get_initial_state_or_default, _INFERRED, _inject_name
@@ -141,18 +144,23 @@ def For(what_range, constructor, name=''):
     constructor (Python function/lambda with 1 or 0 arguments): lambda that constructs a layer

    Returns:
        cntk.ops.functions.Function:
        A function that accepts one argument and applies the layers as constructed by ``constructor`` one after another.
    '''
    # Python 2.7 support requires us to use getargspec() instead of inspect
    from inspect import getargspec
    takes_arg = len(getargspec(constructor).args) > 0

    # For Python 3, check if it is a python function/lambda
    if type(constructor) != FunctionType or not callable(constructor):
        raise ValueError("constructor must be a Python function/lambda")

    # helper to call the layer constructor
    def call(i):
        if takes_arg:
            return constructor(i)  # takes an arg: pass it
        else:
            return constructor()   # takes no arg: call without, that's fine too

    layers = [call(i) for i in what_range]
    sequential = Sequential(layers)
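For reference, a short sketch of the two constructor forms the docstring above accepts (consistent with the SimpleMNIST change earlier in this commit; the dimensions are illustrative):

```python
from cntk.layers import For, Dense

# constructor taking the layer index
stack_a = For(range(3), lambda i: Dense(2000))
# constructor taking no argument is accepted as well
stack_b = For(range(3), lambda: Dense(2000))
```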
@@ -171,9 +179,24 @@ def SequentialClique(functions, name=''):
    '''
    SequentialClique(functions, name='')

    Layer factory function to create a composite that applies a sequence of or any functions onto an input,
    Layer factory function to create a composite that applies a sequence of functions onto an input,
    with skip connections between all function. I.e. each function receives a sum of the input and all
    prior functions' outputs.

    Example:
     >>> from cntk.layers import *
     >>> from cntk.ops import abs, sqrt, square
     >>> x = input(2)
     >>> seq_clique = SequentialClique([abs, sqrt, square])
     >>> seq_clique(x).eval(np.array([2, 8], np.float32))  # 400 = square((8 + abs(8)) + sqrt(8 + abs(8)))
     array([[  36.,  400.]], dtype=float32)

    Args:
     functions (single or list of :class:`~cntk.ops.functions.Function`): functions to be applied.

    Returns:
        cntk.ops.functions.Function:
        A function that accepts one argument and applies the sequence of functions.
    '''
    def clique(x):
        for f in functions:
@@ -207,7 +230,7 @@ def ResNetBlock(f, name=''):
        the function to add the skip connection to.

    Returns:
        cntk.ops.functions.Function:
        A function that accepts one argument, applies ``f`` to it, and adds the original argument.
    '''
    def skip(x):
@@ -338,7 +338,7 @@ def Convolution(filter_shape,     # shape of receptive field, e.g. (3,3)
     reduction_rank (`int`, defaults to 1): set to 0 if input items are scalars (input has no depth axis), e.g. an audio signal or a black-and-white image
      that is stored with tensor shape (H,W) instead of (1,H,W)
     transpose_weight (bool, defaults to `False`): When this is `True` this is convolution, otherwise this is correlation (which is common for most toolkits)
     max_temp_mem_size_in_samples (int, defaults to 0): Limits the amount of memory for intermiadate convolution results. A value of 0 means, memory is automatically managed.
     max_temp_mem_size_in_samples (int, defaults to 0): Limits the amount of memory for intermediate convolution results. A value of 0 means, memory is automatically managed.
     name (str, defaults to ''): the name of the function instance in the network

    Returns:
@@ -680,7 +680,6 @@ def ConvolutionTranspose(filter_shape,      # shape of receptive field, e.g. (
    Returns:
        :class:`~cntk.ops.functions.Function` that accepts one argument and applies the convolution operation to it
    '''

    activation = get_default_override(ConvolutionTranspose, activation=activation)
    init       = get_default_override(ConvolutionTranspose, init=init)
    pad        = get_default_override(ConvolutionTranspose, pad=pad)
@@ -1055,11 +1054,14 @@ def Dropout(dropout_rate=None,
        A function that accepts one argument and applies the operation to it
    '''
    if dropout_rate is None and keep_prob is None:
        raise ValueError("Dense: either dropout_rate or keep_prob must be specified.")
        raise ValueError("Dropout: either dropout_rate or keep_prob must be specified.")
    elif dropout_rate is not None and keep_prob is not None:
        raise ValueError("Dense: dropout_rate and keep_prob cannot be specified at the same time.")
        raise ValueError("Dropout: dropout_rate and keep_prob cannot be specified at the same time.")
    elif keep_prob is not None:
        if keep_prob < 0.0 or keep_prob >= 1.0:
            raise ValueError("Dropout: keep_prob must be in the interval [0,1)")
        dropout_rate = 1-keep_prob

    @BlockFunction('Dropout', name)
    def dropout_f(x):
        return dropout(x, dropout_rate=dropout_rate, seed=seed)
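The two ways of specifying the rate are equivalent; a small sketch (rates are illustrative):

```python
from cntk.layers import Dropout

d1 = Dropout(0.25)             # dropout_rate: drop 25% of the activations
d2 = Dropout(keep_prob=0.75)   # keep_prob: keep 75%, i.e. the same drop rate
# Calling Dropout() with neither argument, or with both, raises ValueError (see the checks above).
```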
@ -0,0 +1,107 @@
|
|||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
# Licensed under the MIT license. See LICENSE.md file in the project root
|
||||
# for full license information.
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
|
||||
from cntk.ops import input, abs, square, sqrt, cos
|
||||
from cntk.layers import For, Dense, SequentialClique, ResNetBlock, Sequential
|
||||
|
||||
import pytest
|
||||
|
||||
@pytest.mark.parametrize("layers_count, dense_units", [(4,5), (6,9), (7, 10)])
|
||||
def test_for_constructor_layer(layers_count, dense_units):
|
||||
x = input(4)
|
||||
|
||||
network = For(range(layers_count), lambda i: Dense(dense_units))
|
||||
|
||||
expected_num_of_parameters = 2 * layers_count
|
||||
assert len(network.parameters) == expected_num_of_parameters
|
||||
|
||||
res = network(x)
|
||||
|
||||
expected_output_shape = (dense_units,)
|
||||
assert res.shape == expected_output_shape
|
||||
|
||||
def test_failing_for_constructor():
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
network = For(range(3), Dense(5))
|
||||
|
||||
class MyFunction:
|
||||
def __call__(self, x):
|
||||
return Dense(x)
|
||||
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
network = For(range(3), MyFunction())
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
network = For(range(3), MyFunction()(5))
|
||||
|
||||
INPUT_DATA = [[2, 8],[4, 7, 9], [5, 6, 10]]
|
||||
|
||||
@pytest.mark.parametrize("input_data", INPUT_DATA)
|
||||
def test_sequential_clique_with_functions(input_data):
|
||||
x = input(len(input_data))
|
||||
|
||||
seq_clique = SequentialClique([abs, sqrt, square])(x)
|
||||
|
||||
assert seq_clique.shape == x.shape
|
||||
|
||||
np_data = np.asarray(input_data, np.float32)
|
||||
res = seq_clique.eval(np_data)
|
||||
|
||||
expected_res = np.abs(np_data) + np_data
|
||||
expected_res += np.sqrt(expected_res)
|
||||
expected_res = np.square(expected_res)
|
||||
|
||||
expected_res.shape = (1,) + expected_res.shape
|
||||
|
||||
np.testing.assert_array_almost_equal(res, expected_res, decimal=4)
|
||||
|
||||
@pytest.mark.parametrize("input_elements, expected", [(5,360.0), (7,1344.0)])
|
||||
def test_sequential_clique_with_layers(input_elements, expected):
|
||||
x = input(input_elements)
|
||||
np_data = np.arange(input_elements, dtype=np.float32)
|
||||
|
||||
unit_dense = Dense(input_elements, activation=None, init=1)
|
||||
|
||||
seq_clique = SequentialClique([unit_dense, unit_dense, unit_dense])(x)
|
||||
|
||||
assert seq_clique.shape == x.shape
|
||||
|
||||
res = seq_clique.eval(np_data)
|
||||
|
||||
assert res[0].shape == (input_elements,)
|
||||
assert np.unique(res[0])[0] == expected
|
||||
|
||||
@pytest.mark.parametrize("input_data", INPUT_DATA)
|
||||
def test_sequential_constructor(input_data):
|
||||
x = input(len(input_data))
|
||||
np_data = np.asarray(input_data, np.float32)
|
||||
|
||||
seq_layers = Sequential([abs, sqrt, square, cos])(x)
|
||||
|
||||
assert seq_layers.shape == x.shape
|
||||
|
||||
res = seq_layers(np_data)
|
||||
|
||||
expected_res = np.cos(np.square(np.sqrt(np.abs(np_data))))
|
||||
|
||||
np.testing.assert_array_almost_equal(res[0], expected_res, decimal=4)
|
||||
|
||||
@pytest.mark.parametrize("input_data", [[3, 5],[9, 25, 13]])
|
||||
def test_resnet_block(input_data):
|
||||
x = input(len(input_data))
|
||||
|
||||
res_net = ResNetBlock(square)(x)
|
||||
|
||||
np_data = np.asarray(input_data, np.float32)
|
||||
|
||||
actual_res = res_net.eval(np_data)
|
||||
|
||||
expected_res = np.square(np_data) + np_data
|
||||
expected_res.shape = (1,) + expected_res.shape
|
||||
|
||||
np.testing.assert_array_equal(actual_res, expected_res)
|
||||
|
|
@ -5,27 +5,37 @@
|
|||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
from cntk import *
|
||||
from cntk.layers import *
|
||||
from cntk.layers.typing import *
|
||||
from cntk import Axis, input, reshape, sigmoid, element_max, Function, Constant, greater, default_options, default_options_for, \
|
||||
get_default_override, default_override_or
|
||||
from cntk.layers import BlockFunction, Convolution, Convolution1D, Convolution2D, Convolution3D, Dense, Embedding, Fold, For, \
|
||||
MaxPooling, MaxUnpooling, LSTM, GRU, RNNUnit, Sequential, Stabilizer, Dropout, Recurrence, \
|
||||
RecurrenceFrom, LayerNormalization, ConvolutionTranspose
|
||||
from cntk.layers.typing import Sequence, Signature, Tensor, SequenceOver
|
||||
|
||||
import pytest
|
||||
|
||||
# Note: We do not test gradients here, assuming that those are tested elsewhere.
|
||||
# Forward outputs are tested to verify that the structure of the layer is as expected.
|
||||
|
||||
def test_layers_name(device_id):
|
||||
def test_layers_name():
|
||||
from cntk import placeholder
|
||||
I = placeholder(name='input')
|
||||
p = Dense(10, name='dense10')(I)
|
||||
|
||||
assert(p.name == 'dense10')
|
||||
assert(I.name == 'input')
|
||||
assert(p.root_function.name == 'dense10')
|
||||
|
||||
q = Convolution((3, 3), 3, name='conv33')(I)
|
||||
assert(q.name == 'conv33')
|
||||
assert(q.root_function.name == 'conv33')
|
||||
|
||||
e = Embedding(0, name='emb')(I)
|
||||
assert(e.name == 'emb')
|
||||
assert(e.root_function.name == 'emb')
|
||||
|
||||
e = Embedding(0, name='')(I)
|
||||
assert(e.name == '')
|
||||
assert(e.root_function.name == '')
|
||||
|
||||
def assert_list_of_arrays_equal(r, exp, err_msg):
|
||||
|
@ -56,7 +66,7 @@ def test_default_options():
|
|||
# @Function, @BlockFunction, types
|
||||
####################################
|
||||
|
||||
def test_Function(device_id):
|
||||
def test_Function():
|
||||
|
||||
####################################################
|
||||
# Test 1: BlockFunction()
|
||||
|
@ -88,7 +98,7 @@ def test_Function(device_id):
|
|||
# . syntax for name lookup
|
||||
####################################
|
||||
|
||||
def test_lookup(device_id):
|
||||
def test_lookup():
|
||||
model = Sequential([ Dense(3, init=1, name='first'), Dense(2, init=2, name='second')])
|
||||
model.update_signature((2,))
|
||||
W1 = model.first.W.value
|
||||
|
@ -144,7 +154,7 @@ def test_recurrence():
|
|||
# recurrence (Fold()) over regular function
|
||||
####################################
|
||||
|
||||
def test_recurrence_fun(device_id):
|
||||
def test_recurrence_fun():
|
||||
from cntk.layers import Recurrence
|
||||
from cntk.ops import plus
|
||||
|
||||
|
@ -172,7 +182,7 @@ def test_recurrence_fun(device_id):
|
|||
# UnfoldFrom()
|
||||
####################################
|
||||
|
||||
def test_unfold(device_id):
|
||||
def test_unfold():
|
||||
from cntk.layers import UnfoldFrom
|
||||
|
||||
@Function
|
||||
|
@ -205,10 +215,56 @@ def test_unfold(device_id):
|
|||
r = FU(x)
|
||||
exp = [[[ 2 ], [ 4 ], [ 8 ], [ 16 ], [ 32 ]], # tests length_increase
|
||||
[[ 2 ], [ 4 ], [ 8 ], [ 16 ], [ 32 ], [ 64 ]]] # tests early cut-off due to until_predicate
|
||||
print(r)
|
||||
print(exp)
|
||||
|
||||
assert_list_of_arrays_equal(r, exp, err_msg='Error in UnfoldFrom(..., until_predicate, length_increase, ...) forward')
|
||||
|
||||
####################################
|
||||
# Test LSTM recurrence
|
||||
####################################
|
||||
|
||||
|
||||
RECURRENT_BLOCK_DATA = [ # block_type, block_outputs_count, block_size, W_mult, H_mult, outputs_count
|
||||
# expected_res
|
||||
(LSTM, 2, 5, 4, 4,
|
||||
[[ 0.21532 , 0.21532 , 0.21532 , 0.21532 , 0.21532 ],
|
||||
[ 0.760161, 0.760161, 0.760161, 0.760161, 0.760161],
|
||||
[ 0.95975 , 0.95975 , 0.95975 , 0.95975 , 0.95975 ],
|
||||
[ 0.993661, 0.993661, 0.993661, 0.993661, 0.993661]]),
|
||||
(GRU, 1, 5, 3, 2,
|
||||
[[ 0.1903 , 0.1903 , 0.1903 , 0.1903 , 0.1903 ],
|
||||
[ 0.262537, 0.262537, 0.262537, 0.262537, 0.262537],
|
||||
[ 0.276712, 0.276712, 0.276712, 0.276712, 0.276712],
|
||||
[ 0.279545, 0.279545, 0.279545, 0.279545, 0.279545]]),
|
||||
(RNNUnit, 1, 5, 1, 1,
|
||||
[[ 0.645656, 0.645656, 0.645656, 0.645656, 0.645656],
|
||||
[ 0.925727, 0.925727, 0.925727, 0.925727, 0.925727],
|
||||
[ 0.986114, 0.986114, 0.986114, 0.986114, 0.986114],
|
||||
[ 0.997249, 0.997249, 0.997249, 0.997249, 0.997249]]),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("block_type, block_outputs_count, block_size, W_mult, H_mult, expected_res", RECURRENT_BLOCK_DATA)
|
||||
def test_recurrent_block(block_type, block_outputs_count, block_size, W_mult, H_mult, expected_res):
|
||||
input_shape = 4
|
||||
|
||||
sequenceAxis = Axis('sequenceAxis')
|
||||
|
||||
y = input(input_shape, dynamic_axes=[Axis.default_batch_axis(), sequenceAxis])
|
||||
data = np.reshape(np.arange(0,16, dtype=np.float32), (1,4,4))
|
||||
|
||||
rnn_block = block_type(block_size, init=0.1)
|
||||
|
||||
assert len(rnn_block.outputs) == block_outputs_count
|
||||
rnn_net = Recurrence(rnn_block)(y)
|
||||
|
||||
assert rnn_net.b.shape == (W_mult*block_size,)
|
||||
assert rnn_net.W.shape == (input_shape, W_mult*block_size)
|
||||
assert rnn_net.H.shape == (block_size, H_mult*block_size)
|
||||
|
||||
res = rnn_net.eval(data)
|
||||
expected = np.asarray(expected_res, dtype=np.float32)
|
||||
|
||||
np.testing.assert_array_almost_equal(res[0], expected, decimal=6)
|
||||
|
||||
####################################
|
||||
# Test dense layer for correctness
|
||||
####################################
|
||||
|
@ -224,8 +280,7 @@ def test_layers_dense(device_id):
|
|||
res = p(y).eval({y: dat})
|
||||
|
||||
npout = np.matrix(dat[0]) * p.foo.W.value + p.foo.b.value
|
||||
print(res[0])
|
||||
print(npout)
|
||||
|
||||
np.testing.assert_array_equal(res, npout, err_msg='Error in dense layer')
|
||||
|
||||
####################################################
|
||||
|
@ -238,8 +293,6 @@ def test_layers_dense(device_id):
|
|||
return 1./(1 + np.exp(-x))
|
||||
|
||||
npout = _sigmoid(np.matrix(dat[0]) * p.foo.W.value + p.foo.b.value)
|
||||
print(res[0])
|
||||
print(npout)
|
||||
|
||||
np.testing.assert_array_almost_equal(res, npout, decimal=7, err_msg='Error in dense layer with sigmoid')
|
||||
|
||||
|
@ -255,10 +308,17 @@ def test_layers_dense(device_id):
|
|||
|
||||
np.testing.assert_array_almost_equal(res, npout, decimal=7, err_msg='Error in 2-dense layer')
|
||||
|
||||
####################################################
|
||||
# Test 4: Failing configuration
|
||||
####################################################
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Dense(2, input_rank=1, map_rank=1) # input_rank and map_rank can be specified at the same time
|
||||
|
||||
########################################
|
||||
# Test Embedding layer for correctness
|
||||
########################################
|
||||
def test_layers_embedding(device_id):
|
||||
def test_layers_embedding():
|
||||
embDim = 3
|
||||
y = input(2)
|
||||
|
||||
|
@ -289,6 +349,16 @@ def test_layers_embedding(device_id):
|
|||
npout = np.matrix(dat[0]) * e.E.value
|
||||
np.testing.assert_array_equal(res, npout, err_msg='Error in constant embedding layer')
|
||||
|
||||
# Failing calls
|
||||
with pytest.raises(ValueError):
|
||||
Embedding(shape=None, init=1, weights=[1., 2., 3.])
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Embedding(3, weights=[1., 2., 3.])
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Embedding(name="embedding")
|
||||
|
||||
########################################
|
||||
# Test Convolutional layer for shape correctness
|
||||
########################################
|
||||
|
@ -313,7 +383,7 @@ def _getConvOutShape(inDim, kernelDim, zeroPad, strides):
|
|||
else:
|
||||
raise ValueError("Stride must be a non-zero positive number")
|
||||
|
||||
def test_layers_convolution_shape(device_id):
|
||||
def test_layers_convolution_shape():
|
||||
# Get the output shape
|
||||
# i: input dimension
|
||||
# k: kernel dimension
|
||||
|
@ -383,7 +453,7 @@ def test_layers_convolution_shape(device_id):
|
|||
expected_shape = (out_num_filters,
|
||||
_getConvOutShape(inH, in_filter_shape[0], zeropad, in_strides),
|
||||
_getConvOutShape(inW, in_filter_shape[1], zeropad, in_strides))
|
||||
print(expected_shape)
|
||||
|
||||
np.testing.assert_array_equal(model_shape, expected_shape, \
|
||||
"Error in convolution with stride = 1 and padding")
|
||||
|
||||
|
@ -409,8 +479,7 @@ def test_layers_convolution_shape(device_id):
|
|||
np.testing.assert_array_equal(model_shape, expected_shape, \
|
||||
"Error in convolution with stride > 1 and padding")
|
||||
|
||||
def test_layers_convolution_value(device_id):
|
||||
|
||||
def test_layers_convolution_value():
|
||||
# Common parameters
|
||||
inC, inH, inW = 1, 3, 3
|
||||
in_filter_shape = (3, 3)
|
||||
|
@ -500,7 +569,7 @@ def test_layers_convolution_value(device_id):
|
|||
##########################################################
|
||||
# Test convolutional 3D layer for correctness (p=False s = 1)
|
||||
##########################################################
|
||||
def test_layers_convolution_3d(device_id):
|
||||
def test_layers_convolution_3d():
|
||||
inC, inH, inW, inD = 1, 3, 3, 3
|
||||
y = input((inC,inH, inW, inD))
|
||||
dat = np.ones([1, inC, inH, inW, inD], dtype = np.float32)
|
||||
|
@ -526,7 +595,7 @@ def test_layers_convolution_3d(device_id):
|
|||
##########################################################
|
||||
# Test convolutional 2D layer for correctness (p=False s = 1)
|
||||
##########################################################
|
||||
def test_layers_convolution_2d(device_id):
|
||||
def test_layers_convolution_2d():
|
||||
inC, inH, inW = 1, 3, 3
|
||||
y = input((inC,inH, inW))
|
||||
|
||||
|
@ -553,7 +622,7 @@ def test_layers_convolution_2d(device_id):
|
|||
# sequential convolution without reduction dimension
|
||||
####################################
|
||||
|
||||
def test_sequential_convolution_without_reduction_dim(device_id):
|
||||
def test_sequential_convolution_without_reduction_dim():
|
||||
c = Convolution(3, init=np.array([4., 2., 1.], dtype=np.float32), sequential=True, pad=False, reduction_rank=0, bias=False)
|
||||
c.update_signature(Sequence[Tensor[()]]) # input is a sequence of scalars
|
||||
data = [np.array([2., 6., 4., 8., 6.])] # like a short audio sequence, in the dynamic dimension
|
||||
|
@ -588,7 +657,7 @@ def test_sequential_convolution_without_reduction_dim(device_id):
|
|||
# 1D convolution without reduction dimension
|
||||
####################################
|
||||
|
||||
def test_1D_convolution_without_reduction_dim(device_id):
|
||||
def test_1D_convolution_without_reduction_dim():
|
||||
c = Convolution1D(3, init=np.array([4, 2, 1]), pad=True, reduction_rank=0, bias=False)
|
||||
c.update_signature(5)
|
||||
data = [np.array([[2, 6, 4, 8, 6]])] # like a audio sequence, in a static dimension
|
||||
|
@ -596,18 +665,22 @@ def test_1D_convolution_without_reduction_dim(device_id):
|
|||
exp = [[10, 24, 40, 38, 44]]
|
||||
np.testing.assert_array_equal(out, exp, err_msg='Error in 1D convolution without reduction dimension')
|
||||
|
||||
# Failing call
|
||||
with pytest.raises(ValueError):
|
||||
Convolution1D((2,3))
|
||||
|
||||
##########################################################
|
||||
# Test Deconvolution layer for correctness
|
||||
##########################################################
|
||||
# TESTTODO: Add the test for deconvolution once current bug with lower/upper pad is fixed
|
||||
def test_layers_deconvolution(device_id):
|
||||
def test_layers_deconvolution():
|
||||
pass
|
||||
|
||||
##########################################################
|
||||
# Test Conv/Pooling/Unpooling/Deconvolution and layer for correctness
|
||||
##########################################################
|
||||
# TESTTODO: Add the test for deconvolution once current bug with lower/upper pad is fixed
|
||||
def test_layers_conv_pool_unpool_deconv(device_id):
|
||||
def test_layers_conv_pool_unpool_deconv():
|
||||
pass
|
||||
# inC, inH, inW = 1,4,4
|
||||
#
|
||||
|
@ -649,7 +722,7 @@ def test_layers_conv_pool_unpool_deconv(device_id):
|
|||
##########################################################
|
||||
# Test for dropout
|
||||
##########################################################
|
||||
def test_layers_dropout(device_id):
|
||||
def test_layers_dropout():
|
||||
dat = np.array([[1., 1., 1., 1.]], dtype=np.float32)
|
||||
y = input(4)
|
||||
p = Dense(1, activation=None, name='foo')(y)
|
||||
|
@ -661,10 +734,21 @@ def test_layers_dropout(device_id):
|
|||
np.testing.assert_array_almost_equal(res, expected_res, decimal=7, \
|
||||
err_msg="Error in dropout computation")
|
||||
|
||||
z = Dropout(keep_prob=0.25, name='bar')(p)
|
||||
res = z(y).eval({y: dat})
|
||||
np.testing.assert_array_almost_equal(res, expected_res, decimal=7, \
|
||||
err_msg="Error in dropout computation with keep_prob")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
z = Dropout(keep_prob=-1.5, name='bar')(p)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
z = Dropout(1.5, name='bar')(p)
|
||||
|
||||
##########################################################
|
||||
# Test for Stabilizer
|
||||
##########################################################
|
||||
def test_layers_stabilizer(device_id):
|
||||
def test_layers_stabilizer():
|
||||
y = input(4)
|
||||
p = Stabilizer()(y)
|
||||
|
||||
|
@ -678,7 +762,7 @@ def test_layers_stabilizer(device_id):
|
|||
##########################################################
|
||||
# Test for LayerNormalization
|
||||
##########################################################
|
||||
def test_layers_layer_normalization(device_id):
|
||||
def test_layers_layer_normalization():
|
||||
y = input(4)
|
||||
p = LayerNormalization(name='foo')(y)
|
||||
|
||||
|
@ -697,7 +781,7 @@ def test_layers_layer_normalization(device_id):
|
|||
# Test for BatchNormalization
|
||||
##########################################################
|
||||
# TESTTODO: Currently the result doesn't match the expected result
|
||||
def test_layers_batch_normalization(device_id):
|
||||
def test_layers_batch_normalization():
|
||||
pass
|
||||
# dat = np.array([[1.0,0.5,1.0,0.5]], dtype=np.float32)
|
||||
# y = input(4)
|
||||
|
|
|
@@ -0,0 +1,23 @@
# Copyright (c) Microsoft. All rights reserved.

# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================

import numpy as np
from cntk import *
from cntk.layers import *
from cntk.layers.typing import *

import pytest

def test_attention_model():
    attention_dim = 128
    attention_span = 20
    attention_axis = -3

    att_model = AttentionModel(attention_dim, attention_span, attention_axis, name='attention_model')

    expected_num_of_inputs = 142

    assert len(att_model.inputs) == expected_num_of_inputs
@@ -396,4 +396,17 @@ def test_constant_data_type_mismatch():

    with pytest.raises(ValueError):
        b.eval({i:[[np.asarray(np.random.rand(5,5),dtype=np.float32)]]})


def test_update_signature():
    from cntk.layers.typing import Tensor

    input_dim = 14

    @Function
    def f(x):
        return x*x

    f.update_signature(Tensor[input_dim])

    assert f.outputs[0].shape == (input_dim,)
    assert f.x.shape == (input_dim,)
@@ -1,4 +1,4 @@
Getting started
===============
You can optionally try the `tutorials <https://notebooks.azure.com/cntk/libraries/tutorials>`__ with pre-installed CNTK running in Azure Notebook hosted environment (for free) if you have not installed the toolkit in your own machine.
@@ -8,12 +8,12 @@ you can start using CNTK from Python right away (don't forget to ``activate`` yo
>>> import cntk
>>> cntk.__version__
'2.0rc1+'


>>> cntk.minus([1, 2, 3], [4, 5, 6]).eval()
array([-3., -3., -3.], dtype=float32)

The above makes use of the CNTK ``minus`` node with two array constants. Every operator has an ``eval()`` method that can be called which runs a forward
pass for that node using its inputs, and returns the result of the forward pass. A slightly more interesting example that uses input variables (the
more common case) is as follows:

>>> import numpy as np
@@ -24,8 +24,8 @@ more common case) is as follows:
>>> cntk.squared_error(x, y).eval({x:x0, y:y0})
array([ 29.], dtype=float32)

In the above example we are first setting up two input variables with shape ``(1, 2)``. We then setup a ``squared_error`` node with those two variables as
inputs. Within the ``eval()`` method we can setup the input-mapping of the data for those two variables. In this case we pass in two numpy arrays.
The squared error is then of course ``(2-4)**2 + (1-6)**2 = 29``.

Most of the data containers like parameters, constants, values, etc. implement
@ -48,27 +48,25 @@ where every NumPy arrays has the shape of the static axes of ``var``.
Overview and first run
----------------------

CNTK2 is a major overhaul of CNTK in that one now has full control over the data and how it is read in, the training and testing loops, and minibatch
construction. The Python bindings provide direct access to the created network graph, and data can be manipulated outside of the readers, not only
for more powerful and complex networks, but also for interactive Python sessions while a model is being created and debugged.

CNTK2 also includes a number of ready-to-extend examples and a layers library. The latter allows one to build a powerful deep network simply by
snapping together building blocks such as convolution layers, recurrent neural net layers (LSTMs, etc.), and fully connected layers. To begin, we will take a
look at a standard fully connected deep network in our first basic use.
First basic use
~~~~~~~~~~~~~~~

The first step in training or running a network in CNTK is to decide which device it should be run on. If you have access to a GPU, training time
can be vastly improved. To explicitly set the device to GPU, set the target device as follows::

    from cntk.device import set_default_device, gpu
    set_default_device(gpu(0))

Now let's set up a network that will learn a classifier based on the example fully connected classifier network
(``nn.fully_connected_classifier_net``). This is defined, along with several other simple and more complex DNN building blocks, in
``Examples/common/nn.py``. Go to the ``[CNTK root]/Examples/common/`` directory and create a ``simplenet.py`` file with the
following contents:
Now let's set up a network that will learn a classifier with fully connected layers, using only the functions :func:`~cntk.layers.layers.Sequential`
and :func:`~cntk.layers.layers.Dense` from the Layers Library. Create a ``simplenet.py`` file with the following contents:

.. literalinclude:: simplenet.py
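
(``simplenet.py`` itself is pulled in by the ``literalinclude`` directive above and is not reproduced in this diff. As a rough, hedged sketch only, not the file's actual contents, the model-definition part described in the prose below could look like this, assuming the names ``input_dim``, ``num_output_classes``, ``feature`` and ``label``)::

    from cntk import input
    from cntk.layers import Dense, Sequential
    from cntk.ops import relu

    input_dim = 2             # assumed toy dimensions; the real values live in simplenet.py
    num_output_classes = 2
    hidden_layers_dim = 50    # "50 hidden dimensions per layer", as described below

    # Two input variables: one for the data, one for the one-hot labels.
    feature = input(input_dim)
    label = input(num_output_classes)

    # Two fully connected hidden layers of 50 units each, plus a linear output layer.
    z = Sequential([Dense(hidden_layers_dim, activation=relu),
                    Dense(hidden_layers_dim, activation=relu),
                    Dense(num_output_classes)])(feature)
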
@ -90,22 +88,22 @@ Running ``python simplenet.py`` (using the correct python environment) will gene
error rate on an unseen minibatch: 0.0

The example above sets up a 2-layer fully connected deep neural network with 50 hidden dimensions per layer. We first set up two input variables, one for
the input data and one for the labels. We then called the fully connected classifier network model function, which simply sets up the required weights,
biases, and activation functions for each layer.

We set two root nodes in the network: ``ce`` is the cross entropy, which defines our model's loss function, and ``pe`` is the classification error. We
set up a trainer object with the root nodes of the network and a learner. In this case we pass in the standard SGD learner with default parameters and a
learning rate of 0.02.
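
(Continuing the hedged sketch above, the loss, metric, and trainer described in this paragraph could be wired up roughly as follows; this is an illustration, not the exact contents of ``simplenet.py``)::

    from cntk import Trainer, cross_entropy_with_softmax, classification_error
    from cntk.learners import sgd, learning_rate_schedule, UnitType

    # Root nodes: the loss and the evaluation metric.
    ce = cross_entropy_with_softmax(z, label)
    pe = classification_error(z, label)

    # Standard SGD learner with a fixed per-minibatch learning rate of 0.02.
    lr = learning_rate_schedule(0.02, UnitType.minibatch)
    trainer = Trainer(z, (ce, pe), [sgd(z.parameters, lr)])
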
Finally, we manually perform the training loop. We run through the data for the specified number of minibatches (``num_minibatches_to_train``), get the ``features``
and ``labels`` that will be used during this training step, and call the trainer's ``train_minibatch`` function, which maps the input and label variables that
we set up previously to the current ``features`` and ``labels`` data (NumPy arrays) that we are using in this minibatch. We use the convenience function
``print_training_progress`` to display our loss and error every 20 steps, and then finally we test our network again using the ``trainer`` object. It's
as easy as that!
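
(To round out the hedged sketch, the loop described above could look roughly like this. It assumes a helper ``generate_random_data(minibatch_size, input_dim, num_output_classes)`` and the ``print_training_progress`` convenience function from ``Tutorials/NumpyInterop/FeedForwardNet.py``, neither of which is reproduced in this diff, and the minibatch counts are placeholders)::

    minibatch_size = 25
    num_minibatches_to_train = 1024

    for i in range(num_minibatches_to_train):
        features, labels = generate_random_data(minibatch_size, input_dim, num_output_classes)
        # Map the input and label variables set up earlier to the current NumPy arrays.
        trainer.train_minibatch({feature: features, label: labels})
        print_training_progress(trainer, i, 20)

    # Test once more on an unseen minibatch.
    test_features, test_labels = generate_random_data(minibatch_size, input_dim, num_output_classes)
    avg_error = trainer.test_minibatch({feature: test_features, label: test_labels})
    print("error rate on an unseen minibatch: {}".format(avg_error))
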
Now that we've seen some of the basics of setting up and training a network using the CNTK Python API, let's look at a more interesting deep
learning problem in more detail (for the full example above, along with the function to generate random data, please see
``Tutorials/NumpyInterop/FeedForwardNet.py``).
@ -176,10 +176,10 @@ classes for our sequences. As before, we define two input variables: one for the
that input through an LSTM recurrent neural network layer, and returns a fixed-size output from the LSTM by selecting the last hidden state of the
LSTM::

    embedded_inputs = embedding(input, embedding_dim)
    lstm_outputs = simple_lstm(embedded_inputs, LSTM_dim, cell_dim)[0]
    thought_vector = sequence.last(lstm_outputs)
    return linear_layer(thought_vector, num_output_classes)
    lstm_classifier = Sequential([Embedding(embedding_dim),
                                  Recurrence(LSTM(LSTM_dim, cell_dim))[0],
                                  sequence.last,
                                  Dense(num_output_classes)])

That is the entire network definition. In the second line above we select the first output from the LSTM. In
this implementation of the LSTM, this is the actual output, while the second output is the state of the LSTM.
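
(As a hedged illustration only: the composed ``lstm_classifier`` above is a function object, so in the sequence-classification example it is bound to a sequence input variable and fed into the criterion nodes roughly as follows, assuming ``features`` and ``label`` input variables defined as before)::

    z = lstm_classifier(features)
    ce = cross_entropy_with_softmax(z, label)
    pe = classification_error(z, label)
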
@ -7,12 +7,7 @@ from cntk.learners import sgd, learning_rate_schedule, UnitType
from cntk import input, cross_entropy_with_softmax, \
    classification_error, sequence
from cntk.logging import ProgressPrinter

abs_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(abs_path, "..", "..", "..", "Examples", "common"))
from nn import LSTMP_component_with_self_stabilization as simple_lstm
from nn import embedding, linear_layer
from cntk.layers import Sequential, Embedding, Recurrence, LSTM, Dense

# Creates the reader
def create_reader(path, is_training, input_dim, label_dim):
@ -24,16 +19,17 @@ def create_reader(path, is_training, input_dim, label_dim):
# Defines the LSTM model for classifying sequences
def LSTM_sequence_classifer_net(input, num_output_classes, embedding_dim,
def LSTM_sequence_classifier_net(input, num_output_classes, embedding_dim,
                                 LSTM_dim, cell_dim):
    embedded_inputs = embedding(input, embedding_dim)
    lstm_outputs = simple_lstm(embedded_inputs, LSTM_dim, cell_dim)[0]
    thought_vector = sequence.last(lstm_outputs)
    return linear_layer(thought_vector, num_output_classes)
    lstm_classifier = Sequential([Embedding(embedding_dim),
                                  Recurrence(LSTM(LSTM_dim, cell_dim))[0],
                                  sequence.last,
                                  Dense(num_output_classes)])
    return lstm_classifier(input)


# Creates and trains a LSTM sequence classification model
def train_sequence_classifier(debug_output=False):
def train_sequence_classifier():
    input_dim = 2000
    cell_dim = 25
    hidden_dim = 25
@ -45,7 +41,7 @@ def train_sequence_classifier(debug_output=False):
    label = input(num_output_classes)

    # Instantiate the sequence classification model
    classifier_output = LSTM_sequence_classifer_net(
    classifier_output = LSTM_sequence_classifier_net(
        features, num_output_classes, embedding_dim, hidden_dim, cell_dim)

    ce = cross_entropy_with_softmax(classifier_output, label)