Add an ability to limit the number of sweeps for training
* Add MinibatchSourceConfig struct that makes construction of a reader a little more straight-forward (esp. on the c++ side) * Refactor python implementation of MinibatchSource. * Add support for specifying the maximum number of sweeps over the input dataset when instating a minibatch source.
This commit is contained in:
Родитель
0ee09cf771
Коммит
503983b0d1
|
@ -59,7 +59,7 @@ def create_image_mb_source(map_file, is_training, total_number_of_samples):
|
|||
features = StreamDef(field='image', transforms=transforms), # first column in map file is referred to as 'image'
|
||||
labels = StreamDef(field='label', shape=num_classes))), # and second as 'label'
|
||||
randomize = is_training,
|
||||
epoch_size=total_number_of_samples,
|
||||
max_samples=total_number_of_samples,
|
||||
multithreaded_deserializer = True)
|
||||
|
||||
# Local Response Normalization layer. See Section 3.3 of the paper:
|
||||
|
|
|
@ -21,7 +21,7 @@ def create_reader(path, is_training, input_dim, label_dim):
|
|||
return cntk.io.MinibatchSource(cntk.io.CTFDeserializer(path, cntk.io.StreamDefs(
|
||||
features = cntk.io.StreamDef(field='features', shape=input_dim),
|
||||
labels = cntk.io.StreamDef(field='labels', shape=label_dim)
|
||||
)), randomize=is_training, epoch_size = cntk.io.INFINITELY_REPEAT if is_training else cntk.io.FULL_DATA_SWEEP)
|
||||
)), randomize=is_training, max_sweeps = cntk.io.INFINITELY_REPEAT if is_training else 1)
|
||||
|
||||
|
||||
# Creates and trains a feedforward classification model for MNIST images
|
||||
|
|
|
@ -14,7 +14,7 @@ import cntk.io.transforms as xforms
|
|||
|
||||
from cntk.layers import Convolution2D, MaxPooling, AveragePooling, Dropout, BatchNormalization, Dense, default_options, Placeholder, identity, Sequential, For
|
||||
from cntk.layers.typing import *
|
||||
from cntk.io import MinibatchSource, ImageDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP
|
||||
from cntk.io import MinibatchSource, ImageDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT
|
||||
from cntk import Trainer
|
||||
from cntk.learners import momentum_sgd, learning_rate_schedule, UnitType, momentum_as_time_constant_schedule
|
||||
from cntk import cross_entropy_with_softmax, classification_error, relu
|
||||
|
@ -60,7 +60,7 @@ def create_reader(map_file, mean_file, is_training):
|
|||
return MinibatchSource(ImageDeserializer(map_file, StreamDefs(
|
||||
features = StreamDef(field='image', transforms=transforms), # first column in map file is referred to as 'image'
|
||||
labels = StreamDef(field='label', shape=num_classes))), # and second as 'label'
|
||||
randomize=is_training, epoch_size = INFINITELY_REPEAT if is_training else FULL_DATA_SWEEP)
|
||||
randomize=is_training, max_sweeps = INFINITELY_REPEAT if is_training else 1)
|
||||
|
||||
########################
|
||||
# define the model #
|
||||
|
|
|
@ -50,7 +50,7 @@ def create_image_mb_source(map_file, mean_file, train, total_number_of_samples):
|
|||
features = cntk.io.StreamDef(field='image', transforms=transforms), # first column in map file is referred to as 'image'
|
||||
labels = cntk.io.StreamDef(field='label', shape=num_classes))), # and second as 'label'
|
||||
randomize=train,
|
||||
epoch_size=total_number_of_samples,
|
||||
max_samples=total_number_of_samples,
|
||||
multithreaded_deserializer = True)
|
||||
|
||||
# Create the network.
|
||||
|
|
|
@ -20,7 +20,7 @@ def create_reader(path, is_training, input_dim, label_dim):
|
|||
return cntk.io.MinibatchSource(cntk.io.CTFDeserializer(path, cntk.io.StreamDefs(
|
||||
features = cntk.io.StreamDef(field='features', shape=input_dim),
|
||||
labels = cntk.io.StreamDef(field='labels', shape=label_dim)
|
||||
)), randomize=is_training, epoch_size = cntk.io.INFINITELY_REPEAT if is_training else cntk.io.FULL_DATA_SWEEP)
|
||||
)), randomize=is_training, max_sweeps = cntk.io.INFINITELY_REPEAT if is_training else 1)
|
||||
|
||||
|
||||
# Creates and trains a feedforward classification model for MNIST images
|
||||
|
|
|
@ -64,7 +64,7 @@ def create_image_mb_source(map_file, mean_file, is_training, total_number_of_sam
|
|||
features = StreamDef(field='image', transforms=transforms), # first column in map file is referred to as 'image'
|
||||
labels = StreamDef(field='label', shape=num_classes))),
|
||||
randomize = is_training,
|
||||
epoch_size=total_number_of_samples,
|
||||
max_samples=total_number_of_samples,
|
||||
multithreaded_deserializer = True)
|
||||
|
||||
# Create the network.
|
||||
|
|
|
@ -64,7 +64,7 @@ def create_image_mb_source(map_file, mean_file, is_training, total_number_of_sam
|
|||
features = StreamDef(field='image', transforms=transforms), # first column in map file is referred to as 'image'
|
||||
labels = StreamDef(field='label', shape=num_classes))),
|
||||
randomize = is_training,
|
||||
epoch_size=total_number_of_samples,
|
||||
max_samples=total_number_of_samples,
|
||||
multithreaded_deserializer = True)
|
||||
|
||||
# Create the network.
|
||||
|
|
|
@ -9,7 +9,7 @@ import numpy as np
|
|||
import sys
|
||||
import os
|
||||
from cntk.train import Trainer, minibatch_size_schedule
|
||||
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP
|
||||
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT
|
||||
from cntk.device import cpu, try_set_default_device
|
||||
from cntk.learners import adadelta, learning_rate_schedule, UnitType
|
||||
from cntk.ops import input, relu, element_times, constant
|
||||
|
@ -33,7 +33,7 @@ def create_reader(path, is_training, input_dim, label_dim):
|
|||
return MinibatchSource(CTFDeserializer(path, StreamDefs(
|
||||
features = StreamDef(field='features', shape=input_dim, is_sparse=False),
|
||||
labels = StreamDef(field='labels', shape=label_dim, is_sparse=False)
|
||||
)), randomize=is_training, epoch_size = INFINITELY_REPEAT if is_training else FULL_DATA_SWEEP)
|
||||
)), randomize=is_training, max_sweeps = INFINITELY_REPEAT if is_training else 1)
|
||||
|
||||
|
||||
# Creates and trains a feedforward classification model for MNIST images
|
||||
|
|
|
@ -60,7 +60,7 @@ def create_image_mb_source(map_file, is_training, total_number_of_samples):
|
|||
features = StreamDef(field='image', transforms=transforms), # first column in map file is referred to as 'image'
|
||||
labels = StreamDef(field='label', shape=num_classes))), # and second as 'label'
|
||||
randomize = is_training,
|
||||
epoch_size=total_number_of_samples,
|
||||
max_samples=total_number_of_samples,
|
||||
multithreaded_deserializer = True)
|
||||
|
||||
# Create the network.
|
||||
|
|
|
@ -59,7 +59,7 @@ def create_image_mb_source(map_file, is_training, total_number_of_samples):
|
|||
features = StreamDef(field='image', transforms=transforms), # first column in map file is referred to as 'image'
|
||||
labels = StreamDef(field='label', shape=num_classes))), # and second as 'label'
|
||||
randomize = is_training,
|
||||
epoch_size=total_number_of_samples,
|
||||
max_samples=total_number_of_samples,
|
||||
multithreaded_deserializer = True)
|
||||
|
||||
# Create the network.
|
||||
|
|
|
@ -20,7 +20,7 @@ def create_reader(path, is_training, input_dim, label_dim):
|
|||
return cntk.io.MinibatchSource(cntk.io.CTFDeserializer(path, cntk.io.StreamDefs(
|
||||
features = cntk.io.StreamDef(field='features', shape=input_dim),
|
||||
labels = cntk.io.StreamDef(field='labels', shape=label_dim)
|
||||
)), randomize=is_training, epoch_size = cntk.io.INFINITELY_REPEAT if is_training else cntk.io.FULL_DATA_SWEEP)
|
||||
)), randomize=is_training, max_sweeps = cntk.io.INFINITELY_REPEAT if is_training else 1)
|
||||
|
||||
|
||||
# Trains and tests a simple auto encoder for MNIST images using deconvolution
|
||||
|
|
|
@ -9,7 +9,7 @@ import os
|
|||
import numpy as np
|
||||
from cntk import load_model
|
||||
from cntk.ops import combine
|
||||
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, FULL_DATA_SWEEP
|
||||
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs
|
||||
from PIL import Image
|
||||
from cntk import graph
|
||||
|
||||
|
@ -81,7 +81,7 @@ if __name__ == '__main__':
|
|||
minibatch_source = MinibatchSource(CTFDeserializer(data_file, StreamDefs(
|
||||
features = StreamDef(field='features', shape=(28*28)),
|
||||
labels = StreamDef(field='labels', shape=10)
|
||||
)), randomize=False, epoch_size = FULL_DATA_SWEEP)
|
||||
)), randomize=False, max_sweeps = 1)
|
||||
|
||||
# use this to print all node names in the model
|
||||
# print_all_node_names(model_file, use_brain_script_model)
|
||||
|
|
|
@ -10,7 +10,7 @@ import argparse
|
|||
import math
|
||||
from cntk.layers import * # Layers library
|
||||
from cntk.layers.typing import *
|
||||
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP
|
||||
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT
|
||||
from cntk import Trainer, Value
|
||||
from cntk.learners import fsadagrad, learning_rate_schedule, momentum_as_time_constant_schedule, UnitType
|
||||
from cntk import splice, relu
|
||||
|
@ -42,7 +42,7 @@ def create_reader(path, is_training):
|
|||
query = StreamDef(field='S0', shape=vocab_size, is_sparse=True),
|
||||
intent_labels = StreamDef(field='S1', shape=num_intents, is_sparse=True), # (used for intent classification variant)
|
||||
slot_labels = StreamDef(field='S2', shape=num_labels, is_sparse=True)
|
||||
)), randomize=is_training, epoch_size = INFINITELY_REPEAT if is_training else FULL_DATA_SWEEP)
|
||||
)), randomize=is_training, max_sweeps = INFINITELY_REPEAT if is_training else 1)
|
||||
|
||||
########################
|
||||
# define the model #
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
import sys
|
||||
import os
|
||||
from cntk import Trainer, Axis #, text_format_minibatch_source, StreamConfiguration
|
||||
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP
|
||||
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT
|
||||
from cntk.device import cpu, try_set_default_device
|
||||
from cntk.learners import sgd, learning_rate_schedule, UnitType
|
||||
from cntk.ops import input, sequence
|
||||
|
@ -23,7 +23,7 @@ def create_reader(path, is_training, input_dim, label_dim):
|
|||
return MinibatchSource(CTFDeserializer(path, StreamDefs(
|
||||
features = StreamDef(field='x', shape=input_dim, is_sparse=True),
|
||||
labels = StreamDef(field='y', shape=label_dim, is_sparse=False)
|
||||
)), randomize=is_training, epoch_size = INFINITELY_REPEAT if is_training else FULL_DATA_SWEEP)
|
||||
)), randomize=is_training, max_sweeps = INFINITELY_REPEAT if is_training else 1)
|
||||
|
||||
# Defines the LSTM model for classifying sequences
|
||||
def LSTM_sequence_classifer_net(feature, num_output_classes, embedding_dim, LSTM_dim, cell_dim):
|
||||
|
|
|
@ -8,7 +8,7 @@ from __future__ import print_function
|
|||
import numpy as np
|
||||
import os
|
||||
from cntk import Trainer, Axis
|
||||
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP
|
||||
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT
|
||||
from cntk.learners import momentum_sgd, fsadagrad, momentum_as_time_constant_schedule, learning_rate_schedule, UnitType
|
||||
from cntk import input, cross_entropy_with_softmax, classification_error, sequence, past_value, future_value, \
|
||||
element_select, alias, hardmax, placeholder, combine, parameter, times, plus
|
||||
|
@ -64,7 +64,7 @@ def create_reader(path, is_training):
|
|||
return MinibatchSource(CTFDeserializer(path, StreamDefs(
|
||||
features = StreamDef(field='S0', shape=input_vocab_dim, is_sparse=True),
|
||||
labels = StreamDef(field='S1', shape=label_vocab_dim, is_sparse=True)
|
||||
)), randomize = is_training, epoch_size = INFINITELY_REPEAT if is_training else FULL_DATA_SWEEP)
|
||||
)), randomize = is_training, max_sweeps = INFINITELY_REPEAT if is_training else 1)
|
||||
|
||||
########################
|
||||
# define the model #
|
||||
|
|
|
@ -36,7 +36,7 @@ def create_reader(path, randomize, input_vocab_dim, label_vocab_dim, size=INFINI
|
|||
return MinibatchSource(CTFDeserializer(path, StreamDefs(
|
||||
features = StreamDef(field='S0', shape=input_vocab_dim, is_sparse=True),
|
||||
labels = StreamDef(field='S1', shape=label_vocab_dim, is_sparse=True)
|
||||
)), randomize=randomize, epoch_size = size)
|
||||
)), randomize=randomize, max_samples = size)
|
||||
|
||||
def create_trainer(network, epoch_size, num_quantization_bits, block_size, warm_up):
|
||||
# Instantiate the trainer object to drive the model training
|
||||
|
|
|
@ -40,7 +40,7 @@ def create_mb_source(features_file, labels_file, label_mapping_filem, total_numb
|
|||
awesome_labels = StreamDef(shape=num_classes, mlf=labels_file)))
|
||||
|
||||
# Enabling BPTT with truncated_length > 0
|
||||
return MinibatchSource([fd,ld], truncation_length=250, epoch_size=total_number_of_samples)
|
||||
return MinibatchSource([fd,ld], truncation_length=250, max_samples=total_number_of_samples)
|
||||
|
||||
def create_recurrent_network():
|
||||
# Input variables denoting the features and label data
|
||||
|
|
|
@ -1455,7 +1455,7 @@ namespace CNTK
|
|||
///
|
||||
/// A type denoting a dictionary (keyed by Unicode strings) of serializable values (dynamically typed).
|
||||
///
|
||||
class Dictionary final
|
||||
class Dictionary
|
||||
{
|
||||
friend inline void AddConfigString(std::wstringstream& s, const DictionaryValue& value, size_t numIndentationSpaces);
|
||||
friend class CompositeMinibatchSource;
|
||||
|
@ -1507,11 +1507,12 @@ namespace CNTK
|
|||
CNTK_API bool operator==(const Dictionary& other) const;
|
||||
CNTK_API bool operator!=(const Dictionary& other) const;
|
||||
|
||||
typedef std::unordered_map<std::wstring, DictionaryValue>::iterator DictionaryIterator;
|
||||
typedef std::unordered_map<std::wstring, DictionaryValue>::const_iterator ConstDictionaryIterator;
|
||||
|
||||
ConstDictionaryIterator begin() const { return m_dictionaryData->begin(); }
|
||||
DictionaryIterator begin() const { return m_dictionaryData->begin(); }
|
||||
ConstDictionaryIterator cbegin() const { return m_dictionaryData->cbegin(); }
|
||||
ConstDictionaryIterator end() const { return m_dictionaryData->end(); }
|
||||
DictionaryIterator end() const { return m_dictionaryData->end(); }
|
||||
ConstDictionaryIterator cend() const { return m_dictionaryData->cend(); }
|
||||
|
||||
size_t Size() { return m_dictionaryData->size(); }
|
||||
|
@ -4665,11 +4666,8 @@ namespace CNTK
|
|||
class MinibatchSource : public std::enable_shared_from_this<MinibatchSource>
|
||||
{
|
||||
public:
|
||||
static const size_t InfinitelyRepeat = SIZE_MAX;
|
||||
static const size_t FullDataSweep = SIZE_MAX - 2; // An arbitrary sentinel value
|
||||
static const size_t InfiniteSamples = SIZE_MAX;
|
||||
static const size_t DefaultRandomizationWindow = SIZE_MAX - 2;
|
||||
|
||||
CNTK_API static const size_t InfinitelyRepeat;
|
||||
CNTK_API static const size_t FullDataSweep;
|
||||
CNTK_API static const size_t DefaultRandomizationWindowInChunks;
|
||||
|
||||
public:
|
||||
|
@ -4746,10 +4744,93 @@ namespace CNTK
|
|||
MinibatchSource() {}
|
||||
};
|
||||
|
||||
typedef Dictionary Deserializer;
|
||||
|
||||
///
|
||||
/// A configuration required to instantiate the CNTK built-in composite minibatch source.
|
||||
///
|
||||
struct MinibatchSourceConfig
|
||||
{
|
||||
// TODO: This is general enough and be hoisted out once there are specific use-cases outside of
|
||||
// configuring a MinibatchSource.
|
||||
enum TraceLevel : unsigned int
|
||||
{
|
||||
Error = 0,
|
||||
Warning = 1,
|
||||
Info = 2
|
||||
};
|
||||
|
||||
///
|
||||
/// Creates a new minibatch source configuration, with enabled randomization and
|
||||
/// the randomization window set to DefaultRandomizationWindowInChunks when 'randomize' is
|
||||
/// 'true' (default).
|
||||
///
|
||||
CNTK_API MinibatchSourceConfig(const std::vector<Deserializer>& deserializers, bool randomize = true);
|
||||
|
||||
///
|
||||
/// The maximum number of input samples (not 'label samples') the reader can produce
|
||||
/// (the default value is InfinitelyRepeat). After this number has been reached, the reader
|
||||
/// returns empty minibatches on subsequent calls to GetNextMinibatch(). 'maxSweeps' and 'maxSamples'
|
||||
/// are mutually exclusive, an exception will be raised if both have non-default values.
|
||||
///
|
||||
size_t maxSamples { MinibatchSource::InfinitelyRepeat };
|
||||
|
||||
///
|
||||
/// The maximum allowed number of sweeps over the input dataset. After this number has been reached,
|
||||
/// the reader returns empty minibatches on subsequent calls to GetNextMinibatch().
|
||||
/// 'maxSweeps' and 'maxSamples' are mutually exclusive, an exception will be raised if both have
|
||||
/// non-default values.
|
||||
///
|
||||
size_t maxSweeps { MinibatchSource::InfinitelyRepeat };
|
||||
|
||||
///
|
||||
/// Size of the randomization window in chunks, non-zero value enables randomization.
|
||||
/// 'randomizationWindowInChunks' and 'randomizationWindowInSamples' are mutually exclusive,
|
||||
/// an exception will be raised if both have non-zero values.
|
||||
///
|
||||
size_t randomizationWindowInChunks { MinibatchSource::DefaultRandomizationWindowInChunks };
|
||||
|
||||
///
|
||||
/// Size of the randomization window in samples, non-zero value enables randomization.
|
||||
/// 'randomizationWindowInChunks' and 'randomizationWindowInSamples' are mutually exclusive,
|
||||
/// an exception will be raised if both have non-zero values.
|
||||
///
|
||||
size_t randomizationWindowInSamples { 0 };
|
||||
|
||||
///
|
||||
/// Output verbosity level.
|
||||
///
|
||||
TraceLevel traceLevel { TraceLevel::Warning };
|
||||
|
||||
///
|
||||
/// Truncation length in samples, non-zero value enables the truncation (only applicable for BPTT,
|
||||
/// cannot be used in frame mode, an exception will be raised if frame mode is enabled and the
|
||||
/// truncation length is non-zero).
|
||||
///
|
||||
size_t truncationLength { 0 };
|
||||
|
||||
///
|
||||
/// Switches the frame mode on and off. If the frame mode is enabled the input data will be processed
|
||||
/// as individual frames ignoring all sequence information (this option cannot be used for BPTT,
|
||||
/// an exception will be raised if frame mode is enabled and the truncation length is non-zero).
|
||||
///
|
||||
bool isFrameModeEnabled { false };
|
||||
|
||||
///
|
||||
/// Specifies if the deserialization should be done on a single or multiple threads.
|
||||
///
|
||||
bool isMultithreaded { false };
|
||||
|
||||
///
|
||||
/// Deserializers to be used in the composite reader.
|
||||
///
|
||||
std::vector<Deserializer> deserializers;
|
||||
};
|
||||
|
||||
///
|
||||
/// Instantiate the CNTK built-in composite minibatch source.
|
||||
///
|
||||
CNTK_API MinibatchSourcePtr CreateCompositeMinibatchSource(const Dictionary& configuration);
|
||||
CNTK_API MinibatchSourcePtr CreateCompositeMinibatchSource(const MinibatchSourceConfig& configuration);
|
||||
|
||||
struct StreamConfiguration
|
||||
{
|
||||
|
@ -4777,51 +4858,6 @@ namespace CNTK
|
|||
bool m_broadcast;
|
||||
};
|
||||
|
||||
///
|
||||
/// Instantiate the CNTK built-in text format minibatch source
|
||||
///
|
||||
inline MinibatchSourcePtr TextFormatMinibatchSource(const std::wstring& dataFilePath, const std::vector<StreamConfiguration>& streamConfigs,
|
||||
size_t epochSize = MinibatchSource::InfinitelyRepeat,
|
||||
bool randomize = true,
|
||||
size_t randomizationWindow = MinibatchSource::DefaultRandomizationWindowInChunks,
|
||||
bool sampleBasedRandomizationWindow = false)
|
||||
{
|
||||
::CNTK::Dictionary minibatchSourceConfiguration;
|
||||
minibatchSourceConfiguration[L"epochSize"] = epochSize;
|
||||
|
||||
if (randomize)
|
||||
{
|
||||
minibatchSourceConfiguration[L"randomize"] = true;
|
||||
minibatchSourceConfiguration[L"randomizationWindow"] = randomizationWindow;
|
||||
minibatchSourceConfiguration[L"sampleBasedRandomizationWindow"] = sampleBasedRandomizationWindow;
|
||||
}
|
||||
|
||||
::CNTK::Dictionary deserializerConfiguration;
|
||||
deserializerConfiguration[L"type"] = L"CNTKTextFormatDeserializer";
|
||||
deserializerConfiguration[L"file"] = dataFilePath;
|
||||
|
||||
::CNTK::Dictionary inputStreamsConfig;
|
||||
for (auto streamConfig : streamConfigs)
|
||||
{
|
||||
std::wstring streamName = streamConfig.m_streamName;
|
||||
size_t streamDim = streamConfig.m_dim;
|
||||
bool isSparse = streamConfig.m_isSparse;
|
||||
std::wstring streamAlias = streamConfig.m_streamAlias;
|
||||
|
||||
::CNTK::Dictionary inputStreamConfig;
|
||||
inputStreamConfig[L"dim"] = streamDim;
|
||||
inputStreamConfig[L"format"] = isSparse ? L"sparse" : L"dense";
|
||||
if (!streamAlias.empty())
|
||||
inputStreamConfig[L"alias"] = streamAlias;
|
||||
|
||||
inputStreamsConfig[streamName] = inputStreamConfig;
|
||||
}
|
||||
|
||||
deserializerConfiguration[L"input"] = inputStreamsConfig;
|
||||
minibatchSourceConfiguration[L"deserializers"] = std::vector<::CNTK::DictionaryValue>({ deserializerConfiguration });
|
||||
return CreateCompositeMinibatchSource(minibatchSourceConfiguration);
|
||||
}
|
||||
|
||||
typedef Dictionary ImageTransform;
|
||||
|
||||
///
|
||||
|
@ -4849,9 +4885,6 @@ namespace CNTK
|
|||
CNTK_API ImageTransform ReaderColor(float brightnessRadius = 0.0f,
|
||||
float contrastRadius = 0.0f, float saturationRadius = 0.0f);
|
||||
|
||||
|
||||
typedef Dictionary Deserializer;
|
||||
|
||||
///
|
||||
/// Create an ImageDeserializer with the specified options
|
||||
///
|
||||
|
@ -4872,6 +4905,29 @@ namespace CNTK
|
|||
///
|
||||
CNTK_API Deserializer HTKMLFDeserializer(const std::wstring& streamName, const std::wstring& labelMappingFile, size_t dimension, const std::vector<std::wstring>& mlfFiles);
|
||||
|
||||
///
|
||||
/// Instantiate the CNTK built-in text format minibatch source
|
||||
///
|
||||
inline MinibatchSourcePtr TextFormatMinibatchSource(const std::wstring& dataFilePath, const std::vector<StreamConfiguration>& streamConfigs,
|
||||
size_t epochSize = MinibatchSource::InfinitelyRepeat,
|
||||
bool randomize = true,
|
||||
size_t randomizationWindow = MinibatchSource::DefaultRandomizationWindowInChunks,
|
||||
bool sampleBasedRandomizationWindow = false)
|
||||
{
|
||||
MinibatchSourceConfig config({ CTFDeserializer(dataFilePath, streamConfigs) }, randomize);
|
||||
config.maxSamples = epochSize;
|
||||
|
||||
if (randomize)
|
||||
{
|
||||
if (sampleBasedRandomizationWindow)
|
||||
config.randomizationWindowInSamples = randomizationWindow;
|
||||
else
|
||||
config.randomizationWindowInChunks = randomizationWindow;
|
||||
}
|
||||
|
||||
return CreateCompositeMinibatchSource(config);
|
||||
}
|
||||
|
||||
///
|
||||
/// Compute the per dimension means and variances for each of the specified streams using data from the specified minibatchSource.
|
||||
///
|
||||
|
|
|
@ -212,6 +212,8 @@ namespace CNTK
|
|||
class Accumulator;
|
||||
typedef std::shared_ptr<Accumulator> AccumulatorPtr;
|
||||
|
||||
struct MinibatchSourceConfig;
|
||||
|
||||
namespace Internal
|
||||
{
|
||||
CNTK_API FunctionPtr IsWithin(const Variable& operand, int offset, const std::wstring& name = L"");
|
||||
|
@ -290,6 +292,9 @@ namespace CNTK
|
|||
|
||||
CNTK_API size_t DefaultPackThresholdSizeInBytes();
|
||||
|
||||
// This is an internal API, needed for testing.
|
||||
CNTK_API Dictionary ToDictionary(const MinibatchSourceConfig& dict);
|
||||
|
||||
class VariableResolver;
|
||||
|
||||
///
|
||||
|
|
|
@ -21,6 +21,9 @@ using namespace Microsoft::MSR::CNTK;
|
|||
namespace CNTK
|
||||
{
|
||||
const size_t MinibatchSource::DefaultRandomizationWindowInChunks = g_4GB / g_32MB;
|
||||
const size_t MinibatchSource::InfinitelyRepeat = g_infinity;
|
||||
const size_t MinibatchSource::FullDataSweep = g_dataSweep;
|
||||
|
||||
|
||||
const std::unordered_map<StreamInformation, MinibatchData>& MinibatchSource::GetNextMinibatch(size_t minibatchSizeInSamples, const DeviceDescriptor& device /*= DeviceDescriptor::UseDefaultDevice()*/)
|
||||
{
|
||||
|
@ -32,6 +35,16 @@ namespace CNTK
|
|||
return GetNextMinibatch(minibatchSizeInSequences, minibatchSizeInSamples, 1, 0, device);
|
||||
}
|
||||
|
||||
MinibatchSourceConfig::MinibatchSourceConfig(const std::vector<Deserializer>& deserializers, bool randomize/* = true*/)
|
||||
: deserializers(deserializers)
|
||||
{
|
||||
if (!randomize)
|
||||
{
|
||||
randomizationWindowInChunks = 0;
|
||||
randomizationWindowInSamples = 0;
|
||||
}
|
||||
}
|
||||
|
||||
const StreamInformation& MinibatchSource::StreamInfo(const std::wstring& streamName)
|
||||
{
|
||||
std::unordered_set<const StreamInformation*> matchingStreamInfos;
|
||||
|
@ -71,66 +84,26 @@ namespace CNTK
|
|||
return *(*(matchingStreamInfos.begin()));
|
||||
}
|
||||
|
||||
MinibatchSourcePtr CreateCompositeMinibatchSource(const Dictionary& configuration)
|
||||
MinibatchSourcePtr CreateCompositeMinibatchSource(const MinibatchSourceConfig& configuration)
|
||||
{
|
||||
return MinibatchSourcePtr(new CompositeMinibatchSource(configuration));
|
||||
}
|
||||
|
||||
/*static*/ const std::wstring CompositeMinibatchSource::PositionAttributeName = L"minibatchSourcePosition";
|
||||
|
||||
CompositeMinibatchSource::CompositeMinibatchSource(const Dictionary& configuration)
|
||||
CompositeMinibatchSource::CompositeMinibatchSource(const MinibatchSourceConfig& configuration)
|
||||
: m_epochEndReached(false),
|
||||
m_prevMinibatchSize(0),
|
||||
m_maxNumSamplesToRead(MinibatchSource::InfinitelyRepeat),
|
||||
m_randomizedWindow(MinibatchSource::DefaultRandomizationWindow),
|
||||
m_maxNumSamplesToRead(configuration.maxSamples),
|
||||
m_maxNumSweepsToRead(configuration.maxSweeps),
|
||||
m_truncationLength(0),
|
||||
m_numWorkers(1),
|
||||
m_workerRank(0),
|
||||
m_restorePosition(0)
|
||||
{
|
||||
// The CNTK reader implementation requires for each deserializer both the module and deserializer type be specified
|
||||
// This is redundant and the V2 API users will just specify type from which the module is automatically inferred
|
||||
// TODO: This should be done in the same manner for CNTK exe as well.
|
||||
Dictionary augmentedConfiguration = configuration;
|
||||
auto& deserializerConfigurations = augmentedConfiguration[L"deserializers"].Value<std::vector<DictionaryValue>>();
|
||||
for (auto& deserializerConfig : deserializerConfigurations)
|
||||
{
|
||||
static const std::unordered_map<std::wstring, std::wstring> deserializerTypeNameToModuleNameMap = {
|
||||
{ L"CNTKTextFormatDeserializer", L"CNTKTextFormatReader" },
|
||||
{ L"ImageDeserializer", L"ImageReader" },
|
||||
{ L"HTKFeatureDeserializer", L"HTKDeserializers" },
|
||||
{ L"HTKMLFDeserializer", L"HTKDeserializers" },
|
||||
};
|
||||
m_truncationLength = configuration.truncationLength;
|
||||
|
||||
auto& deserializerConfigDict = deserializerConfig.Value<Dictionary>();
|
||||
auto deserializerTypeName = deserializerConfigDict[L"type"].Value<std::wstring>();
|
||||
if (deserializerTypeName == L"ImageDeserializer")
|
||||
{
|
||||
// Add a transpose transform since the image data in read in HWC (CWH in column major format) form while
|
||||
// the CNTK convolution engive supports WHC (in column-major format)
|
||||
auto& inputStreamsConfig = deserializerConfigDict[L"input"].Value<Dictionary>();
|
||||
auto& streamsMap = *(inputStreamsConfig.m_dictionaryData);
|
||||
for (auto& inputStreamEntry : streamsMap)
|
||||
{
|
||||
auto& inputStreamConfig = inputStreamEntry.second.Value<Dictionary>();
|
||||
if (inputStreamConfig.Contains(L"transforms"))
|
||||
{
|
||||
auto& transforms = inputStreamConfig[L"transforms"].Value<std::vector<DictionaryValue>>();
|
||||
|
||||
// Add the transpose transform
|
||||
Dictionary transposeTransform;
|
||||
transposeTransform[L"type"] = L"Transpose";
|
||||
transforms.push_back(transposeTransform);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if (deserializerTypeNameToModuleNameMap.find(deserializerTypeName) == deserializerTypeNameToModuleNameMap.end())
|
||||
InvalidArgument("Unknown deserializer type '%S' specified for CNTK built-in composite MinibatchSource construction.", deserializerTypeName.c_str());
|
||||
|
||||
deserializerConfigDict[L"module"] = deserializerTypeNameToModuleNameMap.at(deserializerTypeName);
|
||||
}
|
||||
auto augmentedConfiguration = Internal::ToDictionary(configuration);
|
||||
|
||||
ConfigParameters config;
|
||||
std::wstringstream s;
|
||||
|
@ -139,26 +112,6 @@ namespace CNTK
|
|||
|
||||
config.Parse(msra::strfun::utf8(s.str()));
|
||||
|
||||
const wchar_t* epochSizeConfigurationKey = L"epochSize";
|
||||
if (augmentedConfiguration.Contains(epochSizeConfigurationKey))
|
||||
m_maxNumSamplesToRead = augmentedConfiguration[epochSizeConfigurationKey].Value<size_t>();
|
||||
|
||||
const wchar_t* randomizedWindowConfigurationKey = L"randomizationWindow";
|
||||
if (augmentedConfiguration.Contains(randomizedWindowConfigurationKey))
|
||||
m_randomizedWindow = augmentedConfiguration[randomizedWindowConfigurationKey].Value<size_t>();
|
||||
|
||||
if (m_randomizedWindow == MinibatchSource::DefaultRandomizationWindow)
|
||||
m_randomizedWindow = randomizeAuto;
|
||||
|
||||
const wchar_t* truncatedConfigurationKey = L"truncated";
|
||||
const wchar_t* truncationLengthConfigurationKey = L"truncationLength";
|
||||
if (augmentedConfiguration.Contains(truncatedConfigurationKey) &&
|
||||
augmentedConfiguration[truncatedConfigurationKey].Value<bool>() &&
|
||||
augmentedConfiguration.Contains(truncationLengthConfigurationKey))
|
||||
{
|
||||
m_truncationLength = augmentedConfiguration[truncationLengthConfigurationKey].Value<size_t>();
|
||||
}
|
||||
|
||||
typedef Reader*(*CreateCompositeDataReaderProc)(const ConfigParameters* parameters);
|
||||
CreateCompositeDataReaderProc createReaderProc = (CreateCompositeDataReaderProc)Plugin().Load(L"CompositeDataReader", "CreateCompositeDataReader");
|
||||
std::shared_ptr<Microsoft::MSR::CNTK::Reader> compositeDataReader(createReaderProc(&config));
|
||||
|
@ -169,19 +122,6 @@ namespace CNTK
|
|||
|
||||
m_shim = std::shared_ptr<ReaderShim<float>>(new ReaderShim<float>(compositeDataReader), [](ReaderShim<float>* x) { x->Destroy(); });
|
||||
m_shim->Init(config);
|
||||
|
||||
const wchar_t* numWorkersConfigurationKey = L"numWorkers";
|
||||
if (configuration.Contains(numWorkersConfigurationKey))
|
||||
{
|
||||
m_numWorkers = configuration[numWorkersConfigurationKey].Value<size_t>();
|
||||
|
||||
const wchar_t* workerRankConfigurationKey = L"workerRank";
|
||||
if (configuration.Contains(workerRankConfigurationKey))
|
||||
m_workerRank = configuration[workerRankConfigurationKey].Value<size_t>();
|
||||
|
||||
if (m_workerRank > m_numWorkers - 1)
|
||||
LogicError("CompositeMinibatchSource: Invalid worker rank %lu (numWorkers %lu)", m_workerRank, m_numWorkers);
|
||||
}
|
||||
}
|
||||
|
||||
/*virtual*/ const std::unordered_map<StreamInformation, MinibatchData>&
|
||||
|
@ -226,6 +166,8 @@ namespace CNTK
|
|||
epochConfig.m_totalEpochSizeInSamples = m_maxNumSamplesToRead;
|
||||
}
|
||||
|
||||
epochConfig.m_totalEpochSizeInSweeps = m_maxNumSweepsToRead;
|
||||
|
||||
epochConfig.m_epochIndex = 0;
|
||||
|
||||
m_matrices.clear();
|
||||
|
@ -418,7 +360,10 @@ namespace CNTK
|
|||
{
|
||||
const auto& key = s.m_streamName;
|
||||
Dictionary stream;
|
||||
stream.Add(L"alias", s.m_streamAlias, L"dim", s.m_dim, L"format", s.m_isSparse ? L"sparse" : L"dense");
|
||||
stream[L"dim"] = s.m_dim;
|
||||
stream[L"format"] = s.m_isSparse ? L"sparse" : L"dense";
|
||||
if (!s.m_streamAlias.empty())
|
||||
stream[L"alias"] = s.m_streamAlias;
|
||||
input[key] = stream;
|
||||
}
|
||||
ctf.Add(L"type", L"CNTKTextFormatDeserializer", L"file", fileName, L"input", input);
|
||||
|
@ -459,4 +404,101 @@ namespace CNTK
|
|||
htk.Add(L"type", L"HTKMLFDeserializer", L"input", stream);
|
||||
return htk;
|
||||
}
|
||||
|
||||
namespace Internal
|
||||
{
|
||||
|
||||
void Validate(const MinibatchSourceConfig& configuration)
|
||||
{
|
||||
if (configuration.maxSamples != MinibatchSource::InfinitelyRepeat && configuration.maxSweeps != MinibatchSource::InfinitelyRepeat)
|
||||
LogicError("MinibatchSourceConfig: max samples and max sweeps are mutually exclusive options"
|
||||
" and cannot have non-default values at the same time.");
|
||||
|
||||
if (configuration.randomizationWindowInChunks != 0 && configuration.randomizationWindowInSamples != 0)
|
||||
LogicError("MinibatchSourceConfig: randomization window in chunks and randomization window in samples"
|
||||
" are mutually exclusive options and cannot have non-zero values at the same time.");
|
||||
|
||||
if (configuration.isFrameModeEnabled && configuration.truncationLength != 0)
|
||||
LogicError("MinibatchSourceConfig: truncation and frame mode are mutually exclusive options.");
|
||||
}
|
||||
|
||||
Dictionary ToDictionary(const ::CNTK::MinibatchSourceConfig& configuration)
|
||||
{
|
||||
Validate(configuration);
|
||||
|
||||
Dictionary augmentedConfiguration;
|
||||
|
||||
if (configuration.randomizationWindowInSamples != 0)
|
||||
{
|
||||
augmentedConfiguration[L"randomize"] = true;
|
||||
augmentedConfiguration[L"randomizationWindow"] = configuration.randomizationWindowInSamples;
|
||||
augmentedConfiguration[L"sampleBasedRandomizationWindow"] = true;
|
||||
}
|
||||
else if (configuration.randomizationWindowInChunks != 0)
|
||||
{
|
||||
augmentedConfiguration[L"randomize"] = true;
|
||||
augmentedConfiguration[L"randomizationWindow"] = configuration.randomizationWindowInChunks;
|
||||
augmentedConfiguration[L"sampleBasedRandomizationWindow"] = false;
|
||||
}
|
||||
else
|
||||
{
|
||||
augmentedConfiguration[L"randomize"] = false;
|
||||
}
|
||||
|
||||
if (configuration.truncationLength != 0)
|
||||
{
|
||||
augmentedConfiguration[L"truncated"] = true;
|
||||
augmentedConfiguration[L"truncationLength"] = configuration.truncationLength;
|
||||
}
|
||||
|
||||
augmentedConfiguration[L"frameMode"] = configuration.isFrameModeEnabled;
|
||||
augmentedConfiguration[L"multiThreadedDeserialization"] = configuration.isMultithreaded;
|
||||
augmentedConfiguration[L"traceLevel"] = static_cast<size_t>(configuration.traceLevel);
|
||||
|
||||
// The CNTK reader implementation requires for each deserializer both the module and deserializer type be specified
|
||||
// This is redundant and the V2 API users will just specify type from which the module is automatically inferred
|
||||
// TODO: This should be done in the same manner for CNTK exe as well.
|
||||
vector<DictionaryValue> deserializers;
|
||||
for (auto deserializerConfig : configuration.deserializers)
|
||||
{
|
||||
static const std::unordered_map<std::wstring, std::wstring> deserializerTypeNameToModuleNameMap = {
|
||||
{ L"CNTKTextFormatDeserializer", L"CNTKTextFormatReader" },
|
||||
{ L"ImageDeserializer", L"ImageReader" },
|
||||
{ L"HTKFeatureDeserializer", L"HTKDeserializers" },
|
||||
{ L"HTKMLFDeserializer", L"HTKDeserializers" },
|
||||
};
|
||||
|
||||
auto deserializerTypeName = deserializerConfig[L"type"].Value<std::wstring>();
|
||||
if (deserializerTypeName == L"ImageDeserializer")
|
||||
{
|
||||
// Add a transpose transform since the image data in read in HWC (CWH in column major format) form while
|
||||
// the CNTK convolution engive supports WHC (in column-major format)
|
||||
auto& inputStreamsConfig = deserializerConfig[L"input"].Value<Dictionary>();
|
||||
for (auto& inputStreamEntry : inputStreamsConfig)
|
||||
{
|
||||
auto& inputStreamConfig = inputStreamEntry.second.Value<Dictionary>();
|
||||
if (inputStreamConfig.Contains(L"transforms"))
|
||||
{
|
||||
auto& transforms = inputStreamConfig[L"transforms"].Value<std::vector<DictionaryValue>>();
|
||||
|
||||
// Add the transpose transform
|
||||
Dictionary transposeTransform;
|
||||
transposeTransform[L"type"] = L"Transpose";
|
||||
transforms.push_back(DictionaryValue(transposeTransform));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (deserializerTypeNameToModuleNameMap.find(deserializerTypeName) == deserializerTypeNameToModuleNameMap.end())
|
||||
InvalidArgument("Unknown deserializer type '%S' specified for CNTK built-in composite MinibatchSource construction.", deserializerTypeName.c_str());
|
||||
|
||||
deserializerConfig[L"module"] = deserializerTypeNameToModuleNameMap.at(deserializerTypeName);
|
||||
deserializers.push_back(deserializerConfig);
|
||||
}
|
||||
|
||||
augmentedConfiguration[L"deserializers"] = deserializers;
|
||||
|
||||
return augmentedConfiguration;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ namespace CNTK
|
|||
static const std::wstring DistributedAfterSampleCountAttributeName;
|
||||
|
||||
public:
|
||||
CompositeMinibatchSource(const Dictionary& configuration);
|
||||
CompositeMinibatchSource(const MinibatchSourceConfig& configuration);
|
||||
|
||||
virtual const std::unordered_set<StreamInformation>& StreamInfos() override { return m_streamInfos; }
|
||||
|
||||
|
@ -51,7 +51,7 @@ namespace CNTK
|
|||
size_t m_workerRank;
|
||||
size_t m_prevMinibatchSize;
|
||||
size_t m_maxNumSamplesToRead;
|
||||
size_t m_randomizedWindow;
|
||||
size_t m_maxNumSweepsToRead;
|
||||
size_t m_truncationLength;
|
||||
std::unordered_map<StreamInformation, MinibatchData> m_minibatchData;
|
||||
std::vector<Microsoft::MSR::CNTK::StreamDescriptionPtr> m_compositeDataReaderStreamDescs;
|
||||
|
|
|
@ -62,7 +62,12 @@ void BlockRandomizer::StartEpoch(const EpochConfiguration& config)
|
|||
m_currentWindowRange = ClosedOpenChunkInterval{};
|
||||
|
||||
m_config = config;
|
||||
if (config.m_totalEpochSizeInSamples == requestDataSize)
|
||||
|
||||
if (config.m_totalEpochSizeInSweeps != g_infinity)
|
||||
{
|
||||
m_epochSize = m_sweepSizeInSamples * config.m_totalEpochSizeInSweeps;
|
||||
}
|
||||
else if (config.m_totalEpochSizeInSamples == requestDataSize)
|
||||
{
|
||||
m_epochSize = m_sweepSizeInSamples;
|
||||
}
|
||||
|
|
|
@ -54,7 +54,11 @@ void NoRandomizer::StartEpoch(const EpochConfiguration& config)
|
|||
{
|
||||
m_config = config;
|
||||
|
||||
if (m_config.m_totalEpochSizeInSamples == requestDataSize)
|
||||
if (config.m_totalEpochSizeInSweeps != g_infinity)
|
||||
{
|
||||
m_config.m_totalEpochSizeInSamples = m_sweepSizeInSamples * config.m_totalEpochSizeInSweeps;
|
||||
}
|
||||
else if (m_config.m_totalEpochSizeInSamples == requestDataSize)
|
||||
m_config.m_totalEpochSizeInSamples = m_sweepSizeInSamples;
|
||||
|
||||
SetCurrentSamplePosition(m_config.m_totalEpochSizeInSamples * config.m_epochIndex);
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include <functional>
|
||||
#include "Sequences.h"
|
||||
#include "TensorShape.h"
|
||||
#include "ReaderConstants.h"
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
|
@ -44,6 +45,7 @@ struct ReaderConfiguration
|
|||
struct EpochConfiguration : public ReaderConfiguration
|
||||
{
|
||||
size_t m_totalEpochSizeInSamples; // Total size of the epoch in samples
|
||||
size_t m_totalEpochSizeInSweeps {g_infinity}; // Total size of the epoch in sweeps (default = no limit).
|
||||
size_t m_epochIndex; // Current epoch index [0 .. max number of epochs)
|
||||
};
|
||||
|
||||
|
|
|
@ -7,6 +7,10 @@
|
|||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
static size_t const g_infinity = SIZE_MAX;
|
||||
|
||||
static size_t const g_dataSweep = SIZE_MAX - 1;
|
||||
|
||||
static size_t const g_32MB = 32 * 1024 * 1024;
|
||||
|
||||
static size_t const g_4GB = 0x100000000L;
|
||||
|
|
|
@ -73,7 +73,7 @@ bool Is1bitSGDAvailable()
|
|||
return is1bitSGDAvailable;
|
||||
}
|
||||
|
||||
MinibatchSourcePtr CreateHTKMinibatchSource(size_t featureDim, size_t numOutputClasses, const Dictionary& readModeConfig, size_t epochSize, bool randomize = true)
|
||||
MinibatchSourceConfig GetHTKMinibatchSourceConfig(size_t featureDim, size_t numOutputClasses, size_t epochSize, bool randomize = true)
|
||||
{
|
||||
auto featuresFilePath = L"glob_0000.scp";
|
||||
auto labelsFilePath = L"glob_0000.mlf";
|
||||
|
@ -82,13 +82,7 @@ MinibatchSourcePtr CreateHTKMinibatchSource(size_t featureDim, size_t numOutputC
|
|||
Deserializer featureDeserializer = HTKFeatureDeserializer({ HTKFeatureConfiguration(L"features", featuresFilePath, featureDim, 0, 0, false) });
|
||||
Deserializer labelDeserializer = HTKMLFDeserializer(L"labels", labelMappingFile, numOutputClasses, { labelsFilePath });
|
||||
|
||||
Dictionary minibatchSourceConfiguration;
|
||||
if (randomize)
|
||||
minibatchSourceConfiguration[L"randomize"] = true;
|
||||
|
||||
minibatchSourceConfiguration[L"epochSize"] = epochSize;
|
||||
minibatchSourceConfiguration[L"deserializers"] = std::vector<DictionaryValue>({ featureDeserializer, labelDeserializer });
|
||||
minibatchSourceConfiguration.Add(readModeConfig);
|
||||
|
||||
return CreateCompositeMinibatchSource(minibatchSourceConfiguration);
|
||||
MinibatchSourceConfig config({ featureDeserializer, labelDeserializer }, randomize);
|
||||
config.maxSamples = epochSize;
|
||||
return config;
|
||||
}
|
||||
|
|
|
@ -639,4 +639,4 @@ inline void CompareFunctions(const FunctionPtr& first, const FunctionPtr& second
|
|||
}
|
||||
}
|
||||
|
||||
MinibatchSourcePtr CreateHTKMinibatchSource(size_t featureDim, size_t numOutputClasses, const Dictionary& readModeConfig, size_t epochSize, bool randomize = true);
|
||||
MinibatchSourceConfig GetHTKMinibatchSourceConfig(size_t featureDim, size_t numOutputClasses, size_t epochSize = MinibatchSource::InfinitelyRepeat, bool randomize = true);
|
||||
|
|
|
@ -52,11 +52,10 @@ MinibatchSourcePtr CreateCifarMinibatchSource(size_t epochSize)
|
|||
deserializerConfiguration[L"file"] = mapFilePath;
|
||||
deserializerConfiguration[L"input"] = inputStreamsConfig;
|
||||
|
||||
Dictionary minibatchSourceConfiguration;
|
||||
minibatchSourceConfiguration[L"epochSize"] = epochSize;
|
||||
minibatchSourceConfiguration[L"deserializers"] = std::vector<DictionaryValue>({ deserializerConfiguration });
|
||||
MinibatchSourceConfig config({ deserializerConfiguration });
|
||||
config.maxSamples = epochSize;
|
||||
|
||||
return CreateCompositeMinibatchSource(minibatchSourceConfiguration);
|
||||
return CreateCompositeMinibatchSource(config);
|
||||
}
|
||||
|
||||
Constant GetProjectionMap(size_t outputDim, size_t inputDim, const DeviceDescriptor& device)
|
||||
|
|
|
@ -33,9 +33,11 @@ void TrainTruncatedLSTMAcousticModelClassifier(const DeviceDescriptor& device, b
|
|||
auto labels = InputVariable({ numOutputClasses }, DataType::Float, L"labels");
|
||||
|
||||
const size_t numSamplesForFeatureStatistics = MinibatchSource::FullDataSweep;
|
||||
Dictionary frameModeConfig;
|
||||
frameModeConfig[L"frameMode"] = true;
|
||||
auto minibatchSource = CreateHTKMinibatchSource(baseFeaturesDim, numOutputClasses, frameModeConfig, numSamplesForFeatureStatistics, false);
|
||||
|
||||
auto config = GetHTKMinibatchSourceConfig(baseFeaturesDim, numOutputClasses, numSamplesForFeatureStatistics, false);
|
||||
config.isFrameModeEnabled = true;
|
||||
auto minibatchSource = CreateCompositeMinibatchSource(config);
|
||||
|
||||
auto featureStreamInfo = minibatchSource->StreamInfo(features);
|
||||
std::unordered_map<StreamInformation, std::pair<NDArrayViewPtr, NDArrayViewPtr>> featureMeansAndInvStdDevs = { { featureStreamInfo, { nullptr, nullptr } } };
|
||||
ComputeInputPerDimMeansAndInvStdDevs(minibatchSource, featureMeansAndInvStdDevs);
|
||||
|
@ -61,10 +63,10 @@ void TrainTruncatedLSTMAcousticModelClassifier(const DeviceDescriptor& device, b
|
|||
|
||||
const size_t numTrainingSamples = 81920;
|
||||
const size_t truncationLength = 20;
|
||||
Dictionary truncatedModeConfig;
|
||||
truncatedModeConfig[L"truncated"] = true;
|
||||
truncatedModeConfig[L"truncationLength"] = truncationLength;
|
||||
minibatchSource = CreateHTKMinibatchSource(baseFeaturesDim, numOutputClasses, truncatedModeConfig, numTrainingSamples);
|
||||
|
||||
config = GetHTKMinibatchSourceConfig(baseFeaturesDim, numOutputClasses, numTrainingSamples);
|
||||
config.truncationLength = truncationLength;
|
||||
minibatchSource = CreateCompositeMinibatchSource(config);
|
||||
|
||||
const size_t numberParallelSequencesPerMB1 = 16;
|
||||
const size_t numberParallelSequencesPerMB2 = 32;
|
||||
|
|
|
@ -3,7 +3,7 @@ ReasoNet model in CNTK
|
|||
@author penhe@microsoft.com
|
||||
"""
|
||||
import sys
|
||||
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, DEFAULT_RANDOMIZATION_WINDOW
|
||||
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs
|
||||
import cntk.ops as ops
|
||||
from cntk.layers.blocks import _INFERRED, Parameter
|
||||
#import cntk.internal.utils as utils
|
||||
|
@ -14,7 +14,7 @@ import cntk.learner as learner
|
|||
from .utils import *
|
||||
from .layers import *
|
||||
|
||||
def create_reader(path, vocab_dim, entity_dim, randomize, rand_size= DEFAULT_RANDOMIZATION_WINDOW, size=INFINITELY_REPEAT):
|
||||
def create_reader(path, vocab_dim, entity_dim, randomize):
|
||||
"""
|
||||
Create data reader for the model
|
||||
Args:
|
||||
|
|
|
@ -375,9 +375,14 @@ Test module "V2LibraryTests" has passed with:
|
|||
Test case "LoadLegacyModelSuite/LoadLegacyModelWithPrecomputeInGPU" has passed
|
||||
|
||||
Test suite "MinibatchSourceSuite" has passed with:
|
||||
5 test cases out of 5 passed
|
||||
6 test cases out of 6 passed
|
||||
48 assertions out of 48 passed
|
||||
|
||||
Test case "MinibatchSourceSuite/EndOfSweepFlagIsSetCorrectly" has passed
|
||||
Test case "MinibatchSourceSuite/TestThatEndOfSweepFlagIsSetCorrectly" has passed with:
|
||||
12 assertions out of 12 passed
|
||||
|
||||
Test case "MinibatchSourceSuite/TestSettingMaximumNumberOfSweepsToRead" has passed with:
|
||||
36 assertions out of 36 passed
|
||||
|
||||
Test case "MinibatchSourceSuite/NoRandomizedMinibatchSourceWarmStart" has passed
|
||||
|
||||
|
|
|
@ -744,4 +744,4 @@ inline void CompareFunctions(const FunctionPtr& first, const FunctionPtr& second
|
|||
}
|
||||
}
|
||||
|
||||
MinibatchSourcePtr CreateHTKMinibatchSource(size_t featureDim, size_t numOutputClasses, const Dictionary& readModeConfig, size_t epochSize, bool randomize = true);
|
||||
MinibatchSourceConfig GetHTKMinibatchSourceConfig(size_t featureDim, size_t numOutputClasses, size_t epochSize = MinibatchSource::InfinitelyRepeat, bool randomize = true);
|
||||
|
|
|
@ -38,9 +38,9 @@ void TestLoadLegacyModelWithPrecompute(const DeviceDescriptor& device)
|
|||
FunctionPtr loss = FindVariableByName(outputs, L"CrossEntropyWithSoftmax");
|
||||
FunctionPtr eval = FindVariableByName(outputs, L"EvalClassificationError");
|
||||
|
||||
Dictionary frameModeConfig;
|
||||
frameModeConfig[L"frameMode"] = true;
|
||||
auto minibatchSource = CreateHTKMinibatchSource(baseFeaturesDim, numOutputClasses, frameModeConfig, MinibatchSource::InfinitelyRepeat, true);
|
||||
auto config = GetHTKMinibatchSourceConfig(baseFeaturesDim, numOutputClasses);
|
||||
config.isFrameModeEnabled = true;
|
||||
auto minibatchSource = CreateCompositeMinibatchSource(config);
|
||||
|
||||
const size_t minbatchSize = 256;
|
||||
size_t numMinibatches = 10;
|
||||
|
|
|
@ -85,43 +85,6 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
MinibatchSourcePtr CreateTextFormatMinibatchSource(const std::wstring& dataFilePath, const std::vector<StreamConfiguration>& streamConfigs, size_t epochSize, bool randomize, size_t chunkSizeInBytes)
|
||||
{
|
||||
::CNTK::Dictionary minibatchSourceConfiguration;
|
||||
minibatchSourceConfiguration[L"epochSize"] = epochSize;
|
||||
|
||||
if (randomize)
|
||||
minibatchSourceConfiguration[L"randomize"] = true;
|
||||
else
|
||||
minibatchSourceConfiguration[L"randomize"] = false;
|
||||
|
||||
::CNTK::Dictionary deserializerConfiguration;
|
||||
deserializerConfiguration[L"type"] = L"CNTKTextFormatDeserializer";
|
||||
deserializerConfiguration[L"file"] = dataFilePath;
|
||||
|
||||
::CNTK::Dictionary inputStreamsConfig;
|
||||
for (auto streamConfig : streamConfigs)
|
||||
{
|
||||
std::wstring streamName = streamConfig.m_streamName;
|
||||
size_t streamDim = streamConfig.m_dim;
|
||||
bool isSparse = streamConfig.m_isSparse;
|
||||
std::wstring streamAlias = streamConfig.m_streamAlias;
|
||||
|
||||
::CNTK::Dictionary inputStreamConfig;
|
||||
inputStreamConfig[L"dim"] = streamDim;
|
||||
inputStreamConfig[L"format"] = isSparse ? L"sparse" : L"dense";
|
||||
if (!streamAlias.empty())
|
||||
inputStreamConfig[L"alias"] = streamAlias;
|
||||
|
||||
inputStreamsConfig[streamName] = inputStreamConfig;
|
||||
}
|
||||
|
||||
deserializerConfiguration[L"input"] = inputStreamsConfig;
|
||||
deserializerConfiguration[L"chunkSizeInBytes"] = chunkSizeInBytes;
|
||||
minibatchSourceConfiguration[L"deserializers"] = std::vector<::CNTK::DictionaryValue>({ deserializerConfiguration });
|
||||
return CreateCompositeMinibatchSource(minibatchSourceConfiguration);
|
||||
}
|
||||
|
||||
void TestMinibatchSourceWarmStart(size_t minibatchSize, size_t warmStartSamples, bool randomize, size_t chunkSizeInBytes, bool expectNoData = false)
|
||||
{
|
||||
// TODO: Currently this test is based on the number of samples.
|
||||
|
@ -134,23 +97,19 @@ void TestMinibatchSourceWarmStart(size_t minibatchSize, size_t warmStartSamples,
|
|||
|
||||
const size_t numberOfSamplesInSweep = 10000;
|
||||
|
||||
auto ctf = CTFDeserializer(L"SimpleDataTrain_cntk_text.txt", { { featureStreamName, inputDim },{ labelsStreamName, numOutputClasses } });
|
||||
ctf[L"chunkSizeInBytes"] = chunkSizeInBytes;
|
||||
MinibatchSourceConfig config({ ctf }, randomize);
|
||||
config.maxSamples = numberOfSamplesInSweep;
|
||||
|
||||
|
||||
// Let's create two workers.
|
||||
auto minibatchSource = CreateTextFormatMinibatchSource(
|
||||
L"SimpleDataTrain_cntk_text.txt",
|
||||
{ { featureStreamName, inputDim }, { labelsStreamName, numOutputClasses } },
|
||||
numberOfSamplesInSweep,
|
||||
randomize,
|
||||
chunkSizeInBytes);
|
||||
auto minibatchSource = CreateCompositeMinibatchSource(config);
|
||||
|
||||
auto featureStreamInfo = minibatchSource->StreamInfo(featureStreamName);
|
||||
auto labelStreamInfo = minibatchSource->StreamInfo(labelsStreamName);
|
||||
|
||||
auto minibatchSource2 = CreateTextFormatMinibatchSource(
|
||||
L"SimpleDataTrain_cntk_text.txt",
|
||||
{ { featureStreamName, inputDim }, { labelsStreamName, numOutputClasses } },
|
||||
numberOfSamplesInSweep,
|
||||
randomize,
|
||||
chunkSizeInBytes);
|
||||
auto minibatchSource2 = CreateCompositeMinibatchSource(config);
|
||||
|
||||
size_t totalSamples = 0;
|
||||
bool hasData = true;
|
||||
|
@ -219,18 +178,22 @@ void TestMinibatchSourceWarmStart(size_t minibatchSize, size_t warmStartSamples,
|
|||
(int)totalSamples);
|
||||
}
|
||||
|
||||
|
||||
void TestEndOfSweepFlag(size_t maxSamples, size_t mbSize, bool randomize)
|
||||
{
|
||||
const size_t sweepSize = 603;
|
||||
auto ctfInput = L"SimpleDataTest_cntk_text.txt";
|
||||
std::vector<StreamConfiguration> streamConfig{ { L"features", 2 } };
|
||||
auto cpuDevice = DeviceDescriptor::CPUDevice();
|
||||
auto src = TextFormatMinibatchSource(ctfInput, streamConfig, maxSamples, randomize);
|
||||
|
||||
MinibatchSourceConfig config({ CTFDeserializer(ctfInput, streamConfig) }, randomize);
|
||||
config.maxSamples = maxSamples;
|
||||
auto src = CreateCompositeMinibatchSource(config);
|
||||
|
||||
maxSamples = (maxSamples == MinibatchSource::FullDataSweep) ? sweepSize : maxSamples;
|
||||
|
||||
bool reachedEndOfEpoch = false;
|
||||
size_t sampleCount = 0;
|
||||
auto cpuDevice = DeviceDescriptor::CPUDevice();
|
||||
|
||||
while (sampleCount < maxSamples)
|
||||
{
|
||||
|
@ -271,10 +234,47 @@ void TestEndOfSweepFlag(size_t maxSamples, size_t mbSize, bool randomize)
|
|||
}
|
||||
|
||||
auto& emptyDataMap = src->GetNextMinibatch(mbSize, cpuDevice);
|
||||
assert(emptyDataMap.empty());
|
||||
BOOST_TEST(emptyDataMap.empty());
|
||||
}
|
||||
|
||||
void TestThatEndOfSweepFlagIsSetCorrectly()
|
||||
void TestMaxSweeps(size_t maxSweeps, size_t mbSize, bool randomize)
|
||||
{
|
||||
const size_t sweepSize = 603;
|
||||
auto ctfInput = L"SimpleDataTest_cntk_text.txt";
|
||||
std::vector<StreamConfiguration> streamConfig{ { L"features", 2 } };
|
||||
|
||||
MinibatchSourceConfig config({ CTFDeserializer(ctfInput, streamConfig) }, randomize);
|
||||
config.maxSweeps = maxSweeps;
|
||||
auto src = CreateCompositeMinibatchSource(config);
|
||||
|
||||
auto maxSamples = sweepSize * maxSweeps;
|
||||
|
||||
size_t sampleCount = 0;
|
||||
size_t sweepCount = 0;
|
||||
auto cpuDevice = DeviceDescriptor::CPUDevice();
|
||||
|
||||
while (sampleCount < maxSamples)
|
||||
{
|
||||
const auto& dataMap = src->GetNextMinibatch(mbSize, cpuDevice);
|
||||
const auto& data = dataMap.at(src->StreamInfo(L"features"));
|
||||
|
||||
sampleCount += data.numberOfSamples;
|
||||
if (data.sweepEnd)
|
||||
sweepCount++;
|
||||
}
|
||||
|
||||
BOOST_TEST(sampleCount == maxSamples);
|
||||
BOOST_TEST(sweepCount == maxSweeps);
|
||||
|
||||
auto& emptyDataMap = src->GetNextMinibatch(mbSize, cpuDevice);
|
||||
BOOST_TEST(emptyDataMap.empty());
|
||||
}
|
||||
|
||||
|
||||
|
||||
BOOST_AUTO_TEST_SUITE(MinibatchSourceSuite)
|
||||
|
||||
BOOST_AUTO_TEST_CASE(TestThatEndOfSweepFlagIsSetCorrectly)
|
||||
{
|
||||
for (auto randomize : { false, true })
|
||||
{
|
||||
|
@ -288,11 +288,18 @@ void TestThatEndOfSweepFlagIsSetCorrectly()
|
|||
}
|
||||
}
|
||||
|
||||
BOOST_AUTO_TEST_SUITE(MinibatchSourceSuite)
|
||||
|
||||
BOOST_AUTO_TEST_CASE(EndOfSweepFlagIsSetCorrectly)
|
||||
BOOST_AUTO_TEST_CASE(TestSettingMaximumNumberOfSweepsToRead)
|
||||
{
|
||||
TestThatEndOfSweepFlagIsSetCorrectly();
|
||||
for (auto randomize : { false, true })
|
||||
{
|
||||
TestMaxSweeps(2, 100, randomize);
|
||||
TestMaxSweeps(2, 603, randomize);
|
||||
TestMaxSweeps(2, 1000, randomize);
|
||||
|
||||
TestMaxSweeps(3, 30, randomize);
|
||||
TestMaxSweeps(3, 500, randomize);
|
||||
TestMaxSweeps(3, 301, randomize);
|
||||
}
|
||||
}
|
||||
|
||||
BOOST_AUTO_TEST_CASE(NoRandomizedMinibatchSourceWarmStart)
|
||||
|
|
|
@ -95,7 +95,7 @@
|
|||
"import cntk as C\n",
|
||||
"from cntk import UnitType\n",
|
||||
"from cntk.io import CTFDeserializer, MinibatchSource, StreamDef, StreamDefs\n",
|
||||
"from cntk.io import INFINITELY_REPEAT, FULL_DATA_SWEEP\n",
|
||||
"from cntk.io import INFINITELY_REPEAT\n",
|
||||
"from cntk.initializer import glorot_uniform\n",
|
||||
"from cntk.layers import default_options, Input, Dense\n",
|
||||
"\n",
|
||||
|
@ -164,7 +164,7 @@
|
|||
" return MinibatchSource(CTFDeserializer(path, StreamDefs(\n",
|
||||
" labels = StreamDef(field='labels', shape=num_label_classes, is_sparse=False),\n",
|
||||
" features = StreamDef(field='features', shape=input_dim, is_sparse=False)\n",
|
||||
" )), randomize = is_training, epoch_size = INFINITELY_REPEAT if is_training else FULL_DATA_SWEEP)"
|
||||
" )), randomize = is_training, max_sweeps = INFINITELY_REPEAT if is_training else 1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -85,7 +85,7 @@
|
|||
"import cntk as C\n",
|
||||
"from cntk.device import try_set_default_device, gpu, cpu\n",
|
||||
"from cntk.layers import default_options, Input, Dense\n",
|
||||
"from cntk.io import StreamConfiguration, StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP\n",
|
||||
"from cntk.io import StreamConfiguration, StreamDef, StreamDefs, INFINITELY_REPEAT\n",
|
||||
"from cntk.io import MinibatchSource, CTFDeserializer\n",
|
||||
"\n",
|
||||
"%matplotlib inline"
|
||||
|
@ -169,7 +169,7 @@
|
|||
" return MinibatchSource(CTFDeserializer(path, StreamDefs(\n",
|
||||
" labels_viz = StreamDef(field='labels', shape=num_label_classes, is_sparse=False),\n",
|
||||
" features = StreamDef(field='features', shape=input_dim, is_sparse=False)\n",
|
||||
" )), randomize = is_training, epoch_size = INFINITELY_REPEAT if is_training else FULL_DATA_SWEEP)"
|
||||
" )), randomize = is_training, max_sweeps = INFINITELY_REPEAT if is_training else 1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -126,7 +126,7 @@
|
|||
"\n",
|
||||
"from cntk.logging import ProgressPrinter, log_number_of_parameters\n",
|
||||
"from cntk.io import MinibatchSource, CTFDeserializer\n",
|
||||
"from cntk.io import StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP\n",
|
||||
"from cntk.io import StreamDef, StreamDefs, INFINITELY_REPEAT\n",
|
||||
"from cntk import *\n",
|
||||
"from cntk.learners import fsadagrad, learning_rate_schedule\n",
|
||||
"from cntk.layers import * # CNTK Layers library\n",
|
||||
|
@ -342,7 +342,7 @@
|
|||
" query = StreamDef(field='S0', shape=vocab_size, is_sparse=True),\n",
|
||||
" intent_unused = StreamDef(field='S1', shape=num_intents, is_sparse=True), \n",
|
||||
" slot_labels = StreamDef(field='S2', shape=num_labels, is_sparse=True)\n",
|
||||
" )), randomize=is_training, epoch_size = INFINITELY_REPEAT if is_training else FULL_DATA_SWEEP)"
|
||||
" )), randomize=is_training, max_sweeps = INFINITELY_REPEAT if is_training else 1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -182,7 +182,7 @@
|
|||
"import numpy as np\n",
|
||||
"import os\n",
|
||||
"from cntk import Trainer, Axis\n",
|
||||
"from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP\n",
|
||||
"from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT\n",
|
||||
"from cntk.learners import momentum_sgd, fsadagrad, momentum_as_time_constant_schedule, learning_rate_schedule, UnitType\n",
|
||||
"from cntk import input_variable, cross_entropy_with_softmax, classification_error, sequence, past_value, future_value, \\\n",
|
||||
" element_select, alias, hardmax, placeholder_variable, combine, parameter, times, plus\n",
|
||||
|
@ -313,7 +313,7 @@
|
|||
" return MinibatchSource(CTFDeserializer(path, StreamDefs(\n",
|
||||
" features = StreamDef(field='S0', shape=input_vocab_dim, is_sparse=True),\n",
|
||||
" labels = StreamDef(field='S1', shape=label_vocab_dim, is_sparse=True)\n",
|
||||
" )), randomize = is_training, epoch_size = INFINITELY_REPEAT if is_training else FULL_DATA_SWEEP)"
|
||||
" )), randomize = is_training, max_sweeps = INFINITELY_REPEAT if is_training else 1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -368,7 +368,7 @@
|
|||
" return MinibatchSource(CTFDeserializer(path, StreamDefs(\n",
|
||||
" features = StreamDef(field='S0', shape=input_vocab_dim, is_sparse=True),\n",
|
||||
" labels = StreamDef(field='S1', shape=label_vocab_dim, is_sparse=True)\n",
|
||||
" )), randomize = is_training, epoch_size = INFINITELY_REPEAT if is_training else FULL_DATA_SWEEP)\n",
|
||||
" )), randomize = is_training, max_sweeps = INFINITELY_REPEAT if is_training else 1)\n",
|
||||
"\n",
|
||||
"# Train data reader\n",
|
||||
"train_reader = create_reader(TRAINING_DATA, True)\n",
|
||||
|
|
|
@ -44,7 +44,7 @@
|
|||
"from cntk.device import try_set_default_device, gpu, cpu\n",
|
||||
"from cntk.initializer import xavier\n",
|
||||
"from cntk.io import (MinibatchSource, CTFDeserializer, StreamDef, StreamDefs,\n",
|
||||
" INFINITELY_REPEAT, FULL_DATA_SWEEP)\n",
|
||||
" INFINITELY_REPEAT)\n",
|
||||
"from cntk.layers import Dense, default_options, Input\n",
|
||||
"from cntk.learners import (fsadagrad, UnitType, sgd, learning_rate_schedule,\n",
|
||||
" momentum_as_time_constant_schedule)\n",
|
||||
|
@ -166,7 +166,7 @@
|
|||
" return MinibatchSource(\n",
|
||||
" deserializers = deserializer,\n",
|
||||
" randomize = is_training,\n",
|
||||
" epoch_size = INFINITELY_REPEAT if is_training else FULL_DATA_SWEEP\n",
|
||||
" max_sweeps = INFINITELY_REPEAT if is_training else 1\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
|
|
|
@ -43,7 +43,7 @@
|
|||
"from cntk.device import set_default_device, gpu, cpu\n",
|
||||
"from cntk.initializer import normal\n",
|
||||
"from cntk.io import (MinibatchSource, CTFDeserializer, StreamDef, StreamDefs,\n",
|
||||
" INFINITELY_REPEAT, FULL_DATA_SWEEP)\n",
|
||||
" INFINITELY_REPEAT)\n",
|
||||
"from cntk.layers import Dense, Convolution2D, ConvolutionTranspose2D, BatchNormalization\n",
|
||||
"from cntk.learners import (adam, UnitType, learning_rate_schedule,\n",
|
||||
" momentum_as_time_constant_schedule, momentum_schedule)\n",
|
||||
|
@ -177,7 +177,7 @@
|
|||
" return MinibatchSource(\n",
|
||||
" deserializers = deserializer,\n",
|
||||
" randomize = is_training,\n",
|
||||
" epoch_size = INFINITELY_REPEAT if is_training else FULL_DATA_SWEEP\n",
|
||||
" max_sweeps = INFINITELY_REPEAT if is_training else 1\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
|
|
|
@ -263,6 +263,7 @@ SWIG_STD_VECTOR_ENHANCED(CNTK::DeviceDescriptor)
|
|||
%ignore_struct CNTK::MinibatchData;
|
||||
%ignore_class CNTK::MinibatchSource;
|
||||
%ignore_struct CNTK::MinibatchInfo;
|
||||
%ignore_struct CNTK::MinibatchSourceConfig;
|
||||
|
||||
%ignore_function CNTK::CreateCompositeMinibatchSource;
|
||||
%ignore_struct CNTK::StreamConfiguration;
|
||||
|
@ -344,6 +345,7 @@ SWIG_STD_VECTOR_ENHANCED(CNTK::DeviceDescriptor)
|
|||
%ignore_function CNTK::Internal::AreEqual;
|
||||
%ignore_function CNTK::PrintBuiltInfo;
|
||||
%ignore_function CNTK::Internal::DefaultPackThresholdSizeInBytes;
|
||||
%ignore_function CNTK::Internal::ToDictionary;
|
||||
|
||||
%ignore_class CNTK::Internal::TensorBoardFileWriter;
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ from .. import cntk_py, Value
|
|||
from ..tensor import ArrayMixin
|
||||
from cntk.internal import typemap
|
||||
from cntk.device import use_default_device
|
||||
from enum import Enum, unique
|
||||
|
||||
import numpy as np
|
||||
import uuid
|
||||
|
@ -16,8 +17,6 @@ INFINITELY_REPEAT = cntk_py.MinibatchSource.infinitely_repeat
|
|||
'''int: constant used to specify a minibatch scheduling unit to equal the size of the full data sweep.'''
|
||||
|
||||
FULL_DATA_SWEEP = cntk_py.MinibatchSource.full_data_sweep
|
||||
INFINITE_SAMPLES = cntk_py.MinibatchSource.infinite_samples
|
||||
DEFAULT_RANDOMIZATION_WINDOW = cntk_py.MinibatchSource.default_randomization_window
|
||||
DEFAULT_RANDOMIZATION_WINDOW_IN_CHUNKS = cntk_py.MinibatchSource.default_randomization_window_in_chunks
|
||||
|
||||
class MinibatchData(cntk_py.MinibatchData, ArrayMixin):
|
||||
|
@ -82,79 +81,176 @@ class MinibatchData(cntk_py.MinibatchData, ArrayMixin):
|
|||
def __len__(self):
|
||||
return self.num_sequences
|
||||
|
||||
@unique
|
||||
class TraceLevel(Enum):
|
||||
|
||||
Error = cntk_py.MinibatchSourceConfig.Error
|
||||
Warning = cntk_py.MinibatchSourceConfig.Warning
|
||||
Info = cntk_py.MinibatchSourceConfig.Info
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, TraceLevel):
|
||||
return self.value == other.value
|
||||
return self.value == other
|
||||
|
||||
def __ne__(self, other):
|
||||
return not (self == other)
|
||||
|
||||
class MinibatchSource(cntk_py.MinibatchSource):
|
||||
'''MinibatchSource(deserializers=None, randomize=True, randomization_window=cntk.io.DEFAULT_RANDOMIZATION_WINDOW, epoch_size=cntk.io.INFINITELY_REPEAT, distributed_after=cntk.io.INFINITE_SAMPLES, multithreaded_deserializer=None)
|
||||
A `MinibatchSource` can be indexed by the stream name, which will return a
|
||||
Parent class of all minibatch sources. A `MinibatchSource` can be indexed by the stream name, which will return a
|
||||
:class:`MinibatchData` object that can be passed e.g. to the
|
||||
:func:`~cntk.train.trainer.Trainer.train_minibatch` function.
|
||||
'''
|
||||
MinibatchSource(deserializers, max_samples=cntk.io.INFINITELY_REPEAT, max_sweeps=cntk.io.INFINITELY_REPEAT,
|
||||
randomization_window_in_chunks=cntk.io.DEFAULT_RANDOMIZATION_WINDOW, randomization_window_in_samples=0,
|
||||
trace_level=cntk.io.TraceLevel.Warning, multithreaded_deserializer=False, frame_mode=False,
|
||||
truncation_length=0, randomize=None, randomization_window=None, sample_based_randomization_window=None,
|
||||
epoch_size=None)
|
||||
|
||||
Args:
|
||||
deserializers (`list`, defaults to empty): list of deserializers
|
||||
randomize (`bool`, defaults to `True`): randomize before every epoch
|
||||
randomization_window (int): size of window that reader will shuffle, ignored if `randomize`
|
||||
is `False`
|
||||
sample_based_randomization_window (`bool`, defaults to `False`): specifies how to interpret
|
||||
`randomization_window`. If `True`, the size of the randomization window is interpreted as a certain
|
||||
number of samples, otherwise -- as a number of chunks. Similarly to `randomization_window`,
|
||||
this parameter is ignored, when `randomize` is `False`
|
||||
epoch_size (`int`, defaults to :const:`~cntk.io.INFINITELY_REPEAT`): number of samples as a scheduling unit.
|
||||
Parameters in the schedule change their values every `epoch_size`
|
||||
samples. If no `epoch_size` is provided, this parameter is substituted
|
||||
by the size of the full data sweep with infinite repeat, in which case the scheduling unit is
|
||||
the entire data sweep (as indicated by the MinibatchSource) and parameters
|
||||
change their values on the sweep-by-sweep basis specified by the schedule.
|
||||
**Important:**
|
||||
Click `here <https://github.com/Microsoft/CNTK/wiki/BrainScript-epochSize-and-Python-epoch_size-in-CNTK>`__ for a full description of this parameter.
|
||||
multithreaded_deserializer (`bool`, defaults to `None`): using multi threaded deserializer
|
||||
frame_mode (`bool`, defaults to `False`): Specifies if data should be randomized and returned at the frame
|
||||
or sequence level. When true , input sequence are split into frames.
|
||||
truncation_length (`int`): Specifies the truncation length in samples for BPTT (positive integer). If greater than zero
|
||||
`frame_mode` cannot be used at the same time.
|
||||
deserializers (a single deserializer or a `list`): deserializers to be used in the composite reader
|
||||
max_samples (`int`, defaults to :const:`cntk.io.INFINITELY_REPEAT`): The maximum number of input samples
|
||||
(not 'label samples') the reader can produce. After this number has been reached, the reader
|
||||
returns empty minibatches on subsequent calls to GetNextMinibatch(). `max_samples` and `max_sweeps`
|
||||
are mutually exclusive, an exception will be raised if both have non-default values.
|
||||
**Important:**
|
||||
`See <https://github.com/Microsoft/CNTK/wiki/BrainScript-epochSize-and-Python-epoch_size-in-CNTK>`__
|
||||
for a description of input and label samples.
|
||||
max_sweeps (`int`, defaults to :const:`cntk.io.INFINITELY_REPEAT`): The maximum number of of sweeps over
|
||||
the input dataset After this number has been reached, the reader returns empty minibatches on
|
||||
subsequent calls to GetNextMinibatch(). `max_samples` and `max_sweeps` are mutually exclusive,
|
||||
an exception will be raised if both have non-default values.
|
||||
randomization_window_in_chunks (`int`, defaults to :const:`cntk.io.DEFAULT_RANDOMIZATION_WINDOW_IN_CHUNKS`):
|
||||
size of the randomization window in chunks, non-zero value enables randomization.
|
||||
`randomization_window_in_chunks` and `randomization_window_in_samples` are mutually exclusive,
|
||||
an exception will be raised if both have non-zero values.
|
||||
randomization_window_in_samples (`int`, defaults to `0`): size of the randomization window in samples,
|
||||
non-zero value enables randomization.
|
||||
`randomization_window_in_chunks` and `randomization_window_in_samples` are mutually exclusive,
|
||||
an exception will be raised if both have non-zero values.
|
||||
trace_level (an instance of :class:`cntk.io.TraceLevel`, defaults to `TraceLevel.Warning`):
|
||||
the output verbosity level.
|
||||
multithreaded_deserializer (`bool`, defaults to `False`): specifies if the deserialization should be
|
||||
done on a single or multiple threads.
|
||||
frame_mode (`bool`, defaults to `False`): switches the frame mode on and off. If the frame mode
|
||||
is enabled the input data will be processed as individual frames ignoring all sequence information
|
||||
(this option cannot be used for BPTT, an exception will be raised if frame mode is enabled and the
|
||||
truncation length is non-zero).
|
||||
truncation_length (`int`, defaults to `0`): truncation length in samples, non-zero value enables
|
||||
the truncation (only applicable for BPTT, cannot be used in frame mode, an exception will be raised
|
||||
if frame mode is enabled and the truncation length is non-zero).
|
||||
|
||||
randomize (`bool`, defaults to `None`): !DEPRECATED! please use randomization_window_in_chunks or
|
||||
randomization_window_in_samples instead
|
||||
randomization_window (int, defaults to `None`): !DEPRECATED! please use randomization_window_in_chunks or
|
||||
randomization_window_in_samples instead
|
||||
sample_based_randomization_window (`bool`, defaults to `None`): !DEPRECATED! please use
|
||||
randomization_window_in_chunks or randomization_window_in_samples instead
|
||||
epoch_size (`int`, defaults to `None`): !DEPRECATED! please use max_samples or max_sweeps instead
|
||||
'''
|
||||
def __init__(self,
|
||||
deserializers=None,
|
||||
randomize=True,
|
||||
randomization_window=DEFAULT_RANDOMIZATION_WINDOW_IN_CHUNKS,
|
||||
sample_based_randomization_window=False,
|
||||
epoch_size=INFINITELY_REPEAT,
|
||||
distributed_after=INFINITE_SAMPLES,
|
||||
multithreaded_deserializer=None,
|
||||
def __init__(self,
|
||||
deserializers,
|
||||
max_samples = INFINITELY_REPEAT,
|
||||
max_sweeps = INFINITELY_REPEAT,
|
||||
randomization_window_in_chunks = DEFAULT_RANDOMIZATION_WINDOW_IN_CHUNKS,
|
||||
randomization_window_in_samples = 0,
|
||||
trace_level = TraceLevel.Warning,
|
||||
multithreaded_deserializer=False,
|
||||
frame_mode=False,
|
||||
truncation_length=0):
|
||||
truncation_length=0,
|
||||
# all parameters below are deprecated
|
||||
randomize=None,
|
||||
randomization_window=None,
|
||||
sample_based_randomization_window=None,
|
||||
epoch_size=None,
|
||||
distributed_after=None):
|
||||
|
||||
if not isinstance(deserializers, (list,tuple)):
|
||||
deserializers = [deserializers] # allow passing a single item or a list
|
||||
reader_config = _ReaderConfig(
|
||||
deserializers=deserializers,
|
||||
randomize=randomize,
|
||||
randomization_window=randomization_window,
|
||||
sample_based_randomization_window=sample_based_randomization_window,
|
||||
epoch_size=epoch_size,
|
||||
distributed_after=distributed_after,
|
||||
multithreaded_deserializer=multithreaded_deserializer,
|
||||
frame_mode=frame_mode,
|
||||
truncation_length=truncation_length)
|
||||
source = reader_config.minibatch_source()
|
||||
deserializers = [ deserializers ]
|
||||
|
||||
config = cntk_py.MinibatchSourceConfig(deserializers)
|
||||
config.max_samples = max_samples
|
||||
config.max_sweeps = max_sweeps
|
||||
config.randomization_window_in_chunks = randomization_window_in_chunks
|
||||
config.randomization_window_in_samples = randomization_window_in_samples
|
||||
config.is_multithreaded = multithreaded_deserializer
|
||||
config.is_frame_mode_enabled = frame_mode
|
||||
config.truncation_length = truncation_length
|
||||
|
||||
if isinstance(trace_level, TraceLevel):
|
||||
trace_level = trace_level.value
|
||||
|
||||
config.trace_level = trace_level
|
||||
|
||||
# the following deals with deprecated parameters.
|
||||
import warnings
|
||||
# TODO: 'randomize=False' is the only legacy option that still makes sense
|
||||
# (as a shortcut to randomization_window_in_chunks=0 and
|
||||
# randomization_window_in_samples=0), maybe we should keep it?
|
||||
if randomize is not None and randomize:
|
||||
warnings.warn('"randomize" parameter is deprecated and will be removed '
|
||||
'in future versions. Please specify "randomization_window_in_chunks" or '
|
||||
'"randomization_window_in_samples" instead', DeprecationWarning)
|
||||
elif randomize is None:
|
||||
randomize = True # previously default value
|
||||
|
||||
if randomization_window is not None:
|
||||
warnings.warn('"randomization_window" parameter is deprecated and will be removed '
|
||||
'in future versions. Please specify "randomization_window_in_chunks" or '
|
||||
'"randomization_window_in_samples" instead', DeprecationWarning)
|
||||
else:
|
||||
randomization_window = DEFAULT_RANDOMIZATION_WINDOW_IN_CHUNKS # previously default value
|
||||
|
||||
if sample_based_randomization_window is not None:
|
||||
warnings.warn('"sample_based_randomization_window" parameter is deprecated and will be removed '
|
||||
'in future versions. Please specify "randomization_window_in_chunks" or '
|
||||
'"randomization_window_in_samples" instead', DeprecationWarning)
|
||||
else:
|
||||
sample_based_randomization_window = False # previously default value
|
||||
|
||||
if (randomize and sample_based_randomization_window):
|
||||
config.randomization_window_in_samples = randomization_window
|
||||
config.randomization_window_in_chunks = 0
|
||||
elif (randomize and not sample_based_randomization_window):
|
||||
config.randomization_window_in_chunks = randomization_window
|
||||
config.randomization_window_in_samples = 0
|
||||
elif not randomize:
|
||||
config.randomization_window_in_chunks = 0
|
||||
config.randomization_window_in_samples = 0
|
||||
|
||||
if (epoch_size is not None):
|
||||
warnings.warn('"epoch_size" parameter is deprecated and will be removed '
|
||||
'in future versions. Please specify "max_samples" or '
|
||||
'"max_sweeps" instead', DeprecationWarning)
|
||||
config.max_samples = epoch_size
|
||||
|
||||
source = cntk_py.create_composite_minibatch_source(config)
|
||||
# transplant into this class instance
|
||||
self.__dict__ = source.__dict__
|
||||
# transplant all members of deserializers into a record called streams
|
||||
streams = {}
|
||||
for si in self.stream_infos():
|
||||
streams[si.m_name] = si
|
||||
from ..variables import Record
|
||||
self.streams = Record(**streams)
|
||||
self._streams = None
|
||||
|
||||
|
||||
def stream_infos(self):
|
||||
'''
|
||||
Describes the stream that this source produces.
|
||||
Describes the streams 'this' minibatch source produces.
|
||||
|
||||
Returns:
|
||||
dict:
|
||||
A `dict` mapping input names to the stream information
|
||||
A `list` of instances of :class:`~cntk.cntk_py.StreamInformation`
|
||||
'''
|
||||
return super(MinibatchSource, self).stream_infos()
|
||||
|
||||
@property
|
||||
def streams(self):
|
||||
'''
|
||||
Describes the streams 'this' minibatch source produces.
|
||||
|
||||
Returns:
|
||||
A `dict` mapping input names to instances of
|
||||
:class:`~cntk.cntk_py.StreamInformation`
|
||||
'''
|
||||
if self._streams is None:
|
||||
from cntk.variables import Record
|
||||
self._streams = Record(**dict((info.m_name, info) for info in self.stream_infos()))
|
||||
|
||||
return self._streams
|
||||
|
||||
def stream_info(self, name):
|
||||
'''
|
||||
Gets the description of the stream with given name.
|
||||
|
@ -214,14 +310,14 @@ class MinibatchSource(cntk_py.MinibatchSource):
|
|||
mb = super(MinibatchSource, self).get_next_minibatch(0,
|
||||
minibatch_size_in_samples, num_data_partitions, partition_index, device)
|
||||
|
||||
if input_map:
|
||||
if not mb:
|
||||
return {}
|
||||
else:
|
||||
return { key : mb[value] for (key, value) in input_map.items() }
|
||||
else:
|
||||
if not mb:
|
||||
return mb
|
||||
|
||||
if not input_map:
|
||||
return mb
|
||||
|
||||
return { key : mb[value] for (key, value) in input_map.items() }
|
||||
|
||||
def get_checkpoint_state(self):
|
||||
'''
|
||||
Gets the checkpoint state of the MinibatchSource.
|
||||
|
@ -268,6 +364,7 @@ class MinibatchSource(cntk_py.MinibatchSource):
|
|||
'''
|
||||
self.restore_from_checkpoint(position)
|
||||
|
||||
|
||||
def _py_dict_to_cntk_dict(py_dict):
|
||||
'''
|
||||
Converts a Python dictionary into a CNTK Dictionary whose values are CNTK DictionaryValue instances.
|
||||
|
@ -277,7 +374,7 @@ def _py_dict_to_cntk_dict(py_dict):
|
|||
|
||||
Returns:
|
||||
cntk_py.Dictionary:
|
||||
A :class:`~cntk_py.Dictionary` that has been converted from the input `dict`
|
||||
A :class:`~cntk.cntk_py.Dictionary` that has been converted from the input `dict`
|
||||
'''
|
||||
res = cntk_py.Dictionary()
|
||||
for k, v in py_dict.items():
|
||||
|
@ -291,93 +388,6 @@ def _py_dict_to_cntk_dict(py_dict):
|
|||
return res
|
||||
|
||||
|
||||
# TODO: This should be a private function; use MinibatchSource(deserializer, ...).
|
||||
@typemap
|
||||
def _minibatch_source(config):
|
||||
'''
|
||||
Instantiate the CNTK built-in composite minibatch source which is used to stream data into the network.
|
||||
|
||||
Args:
|
||||
config (dict): a dictionary containing all the key-value configuration entries.
|
||||
|
||||
Returns:
|
||||
cntk.io.MinibatchSource:
|
||||
The :class:`MinibatchSource` used to stream data into the network
|
||||
'''
|
||||
cntk_dict = _py_dict_to_cntk_dict(config)
|
||||
return cntk_py.create_composite_minibatch_source(cntk_dict)
|
||||
|
||||
class _ReaderConfig(dict):
|
||||
'''
|
||||
Reader configuration.
|
||||
|
||||
Args:
|
||||
deserializers ('list', defaults to `None`): list of deserializers
|
||||
(:class:`ImageDeserializer` for now).
|
||||
randomize (`bool`, defaults to `True`): randomize images before every epoch
|
||||
randomization_window (int): size of window that reader will shuffle, ignored if `randomize`
|
||||
is `False`
|
||||
sample_based_randomization_window (bool, defaults to `False`): specifies how to interpret
|
||||
`randomization_range`. If `True`, the size of the randomization window is interpreted as a certain
|
||||
number of samples, otherwise -- as a number of chunks. Similarly to `randomization_window`,
|
||||
this parameter is ignored, when `randomize` is `False`
|
||||
epoch_size (`int`, defaults to `cntk.io.INFINITELY_REPEAT`): number of samples as a scheduling unit.
|
||||
Parameters in the schedule change their values every `epoch_size`
|
||||
samples. If no `epoch_size` is provided, this parameter is substituted
|
||||
by the size of the full data sweep with infinite repeat, in which case the scheduling unit is
|
||||
the entire data sweep (as indicated by the MinibatchSource) and parameters
|
||||
change their values on the sweep-by-sweep basis specified by the schedule.
|
||||
**Important:**
|
||||
Click `here <https://github.com/Microsoft/CNTK/wiki/BrainScript-epochSize-and-Python-epoch_size-in-CNTK>`__ for a full description of this parameter.
|
||||
distributed_after (int, defaults to `cntk.io.INFINITE_SAMPLES`): sample count after which reader becomes distributed
|
||||
multithreaded_deserializer (`bool`, defaults to `None`): using multi threaded deserializer
|
||||
frame_mode (`bool`, defaults to `False`): Specifies if data should be randomized and returned at the frame
|
||||
or sequence level. When true , input sequence are split into frames.
|
||||
truncation_length (`int`): Specifies the truncation length in samples for BPTT (positive integer). When using truncation,
|
||||
frame mode cannot be used at the same time.
|
||||
'''
|
||||
def __init__(self,
|
||||
deserializers=None,
|
||||
randomize=True,
|
||||
randomization_window=DEFAULT_RANDOMIZATION_WINDOW_IN_CHUNKS,
|
||||
sample_based_randomization_window=False,
|
||||
epoch_size=INFINITELY_REPEAT,
|
||||
distributed_after=INFINITE_SAMPLES,
|
||||
multithreaded_deserializer=None,
|
||||
frame_mode=False,
|
||||
truncated=False,
|
||||
truncation_length=0):
|
||||
self['epochSize'] = cntk_py.SizeTWrapper(epoch_size) # force to store in size_t
|
||||
if not isinstance(deserializers, (list, tuple)):
|
||||
deserializers = [deserializers]
|
||||
self['deserializers'] = self.deserializers = deserializers or []
|
||||
self['randomize'] = randomize
|
||||
self['randomizationWindow'] = cntk_py.SizeTWrapper(randomization_window)
|
||||
self['sampleBasedRandomizationWindow'] = sample_based_randomization_window
|
||||
self['distributedAfterSampleCount'] = cntk_py.SizeTWrapper(distributed_after)
|
||||
if multithreaded_deserializer is not None:
|
||||
self['multiThreadedDeserialization'] = multithreaded_deserializer
|
||||
|
||||
if truncation_length > 0:
|
||||
self['truncated'] = True
|
||||
self['truncationLength'] = cntk_py.SizeTWrapper(truncation_length)
|
||||
if frame_mode:
|
||||
raise ValueError("FrameMode and truncated BPTT are mutually exclusive.")
|
||||
self['frameMode'] = frame_mode
|
||||
|
||||
@typemap
|
||||
def minibatch_source(self):
|
||||
'''
|
||||
Creates an instance of :class:`MinibatchSource` from this
|
||||
instance, which can be used to feed data into the `eval()` methods of
|
||||
the graph nodes or the `train_minibatch()` of :class:`~cntk.train.trainer.Trainer`.
|
||||
|
||||
Returns:
|
||||
cntk.io.MinibatchSource:
|
||||
An instance of :class:`MinibatchSource` from this instance.
|
||||
'''
|
||||
return _minibatch_source(self)
|
||||
|
||||
def HTKFeatureDeserializer(streams):
|
||||
'''
|
||||
Configures the HTK feature reader that reads speech data from scp files.
|
||||
|
|
|
@ -10,13 +10,32 @@ import numpy as np
|
|||
import pytest
|
||||
|
||||
from cntk.io import *
|
||||
from cntk.io import _ReaderConfig
|
||||
import cntk.io.transforms as xforms
|
||||
from cntk.cntk_py import to_dictionary
|
||||
from cntk.cntk_py import MinibatchSourceConfig
|
||||
|
||||
abs_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
AA = np.asarray
|
||||
|
||||
def create_temp_file(tmpdir):
|
||||
tmpfile = str(tmpdir/'mbtest.txt')
|
||||
with open(tmpfile, 'w') as f:
|
||||
f.write("|S0 1\n|S0 2\n|S0 3\n|S0 4")
|
||||
return tmpfile
|
||||
|
||||
def create_ctf_deserializer(tmpdir):
|
||||
tmpfile = create_temp_file(tmpdir)
|
||||
return CTFDeserializer(tmpfile, StreamDefs(features = StreamDef(field='S0', shape=1)))
|
||||
|
||||
def create_config(tmpdir):
|
||||
tmpfile = create_temp_file(tmpdir)
|
||||
return MinibatchSourceConfig() \
|
||||
.add_deserializer(
|
||||
CTFDeserializer(tmpfile,
|
||||
StreamDefs(features = StreamDef(field='S0', shape=1))))
|
||||
|
||||
|
||||
def test_text_format(tmpdir):
|
||||
mbdata = r'''0 |x 560:1 |y 1 0 0 0 0
|
||||
0 |x 0:1
|
||||
|
@ -82,6 +101,124 @@ def test_text_format(tmpdir):
|
|||
assert features.num_samples < 7
|
||||
assert labels.num_samples == 1
|
||||
|
||||
def check_default_config_keys(d):
|
||||
assert 5 <= len(d.keys())
|
||||
assert False == d['frameMode']
|
||||
assert False == d['multiThreadedDeserialization']
|
||||
assert TraceLevel.Warning == d['traceLevel']
|
||||
assert 'randomize' in d.keys()
|
||||
assert 'deserializers' in d.keys()
|
||||
|
||||
def test_minibatch_source_config_constructor(tmpdir):
|
||||
ctf = create_ctf_deserializer(tmpdir)
|
||||
|
||||
config = MinibatchSourceConfig([ctf], False)
|
||||
dictionary = to_dictionary(config)
|
||||
check_default_config_keys(dictionary)
|
||||
assert 5 == len(dictionary.keys())
|
||||
assert False == dictionary['randomize']
|
||||
|
||||
config = MinibatchSourceConfig([ctf], True)
|
||||
dictionary = to_dictionary(config)
|
||||
check_default_config_keys(dictionary)
|
||||
|
||||
assert 7 == len(dictionary.keys())
|
||||
assert True == dictionary['randomize']
|
||||
assert DEFAULT_RANDOMIZATION_WINDOW_IN_CHUNKS == dictionary['randomizationWindow']
|
||||
assert False == dictionary['sampleBasedRandomizationWindow']
|
||||
|
||||
config = MinibatchSourceConfig([ctf]) # 'randomize' is omitted
|
||||
dictionary = to_dictionary(config)
|
||||
check_default_config_keys(dictionary)
|
||||
|
||||
assert 7 == len(dictionary.keys())
|
||||
assert True == dictionary['randomize']
|
||||
assert DEFAULT_RANDOMIZATION_WINDOW_IN_CHUNKS == dictionary['randomizationWindow']
|
||||
assert False == dictionary['sampleBasedRandomizationWindow']
|
||||
|
||||
def test_minibatch_source_config_sweeps_and_samples(tmpdir):
|
||||
ctf = create_ctf_deserializer(tmpdir)
|
||||
config = MinibatchSourceConfig([ctf])
|
||||
|
||||
assert INFINITELY_REPEAT == config.max_samples
|
||||
assert INFINITELY_REPEAT == config.max_sweeps
|
||||
|
||||
config.max_samples = 100
|
||||
config.max_sweeps = 3
|
||||
assert 100 == config.max_samples
|
||||
assert 3 == config.max_sweeps
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# to_dictionary will validate the config
|
||||
dictionary = to_dictionary(config)
|
||||
|
||||
config.max_samples = INFINITELY_REPEAT
|
||||
dictionary = to_dictionary(config)
|
||||
check_default_config_keys(dictionary)
|
||||
|
||||
def test_minibatch_source_config_randomization(tmpdir):
|
||||
ctf = create_ctf_deserializer(tmpdir)
|
||||
config = MinibatchSourceConfig([ctf])
|
||||
|
||||
dictionary = to_dictionary(config)
|
||||
check_default_config_keys(dictionary)
|
||||
assert True == dictionary['randomize']
|
||||
|
||||
config.randomization_window_in_chunks = 0
|
||||
dictionary = to_dictionary(config)
|
||||
check_default_config_keys(dictionary)
|
||||
assert False == dictionary['randomize']
|
||||
|
||||
config.randomization_window_in_chunks = 10
|
||||
dictionary = to_dictionary(config)
|
||||
check_default_config_keys(dictionary)
|
||||
assert True == dictionary['randomize']
|
||||
assert 10 == dictionary['randomizationWindow']
|
||||
assert False == dictionary['sampleBasedRandomizationWindow']
|
||||
|
||||
config.randomization_window_in_samples = 100
|
||||
with pytest.raises(Exception):
|
||||
# to_dictionary will validate the config
|
||||
dictionary = to_dictionary(config)
|
||||
|
||||
config.randomization_window_in_chunks = 0
|
||||
dictionary = to_dictionary(config)
|
||||
check_default_config_keys(dictionary)
|
||||
assert True == dictionary['randomize']
|
||||
assert 100 == dictionary['randomizationWindow']
|
||||
assert True == dictionary['sampleBasedRandomizationWindow']
|
||||
|
||||
def test_minibatch_source_config_other_properties(tmpdir):
|
||||
ctf = create_ctf_deserializer(tmpdir)
|
||||
config = MinibatchSourceConfig([ctf])
|
||||
|
||||
config.is_multithreaded = True
|
||||
config.trace_level = TraceLevel.Info.value
|
||||
config.is_frame_mode_enabled = True
|
||||
|
||||
dictionary = to_dictionary(config)
|
||||
assert 7 == len(dictionary.keys())
|
||||
assert TraceLevel.Info == dictionary['traceLevel']
|
||||
assert True == dictionary['frameMode']
|
||||
assert True == dictionary['multiThreadedDeserialization']
|
||||
|
||||
config.is_multithreaded = False
|
||||
config.trace_level = 0
|
||||
config.truncation_length = 123
|
||||
with pytest.raises(Exception):
|
||||
# to_dictionary will validate the config
|
||||
dictionary = to_dictionary(config)
|
||||
|
||||
config.is_frame_mode_enabled = False
|
||||
|
||||
dictionary = to_dictionary(config)
|
||||
assert 9 == len(dictionary.keys())
|
||||
assert 0 == dictionary['traceLevel']
|
||||
assert False == dictionary['frameMode']
|
||||
assert False == dictionary['multiThreadedDeserialization']
|
||||
assert True == dictionary['truncated']
|
||||
assert 123 == dictionary['truncationLength']
|
||||
|
||||
def test_image():
|
||||
map_file = "input.txt"
|
||||
mean_file = "mean.txt"
|
||||
|
@ -100,13 +237,10 @@ def test_image():
|
|||
xforms.mean(mean_file)]
|
||||
image = ImageDeserializer(map_file, StreamDefs(f = StreamDef(field='image', transforms=transforms), l = StreamDef(field='label', shape=num_classes)))
|
||||
|
||||
rc = _ReaderConfig(image, randomize=False, epoch_size=epoch_size)
|
||||
|
||||
assert rc['epochSize'].value == epoch_size
|
||||
assert rc['randomize'] == False
|
||||
assert rc['sampleBasedRandomizationWindow'] == False
|
||||
assert len(rc['deserializers']) == 1
|
||||
d = rc['deserializers'][0]
|
||||
config = to_dictionary(MinibatchSourceConfig([image], randomize=False))
|
||||
|
||||
assert len(config['deserializers']) == 1
|
||||
d = config['deserializers'][0]
|
||||
assert d['type'] == 'ImageDeserializer'
|
||||
assert d['file'] == map_file
|
||||
assert set(d['input'].keys()) == {label_name, feature_name}
|
||||
|
@ -116,7 +250,7 @@ def test_image():
|
|||
|
||||
f = d['input'][feature_name]
|
||||
assert set(f.keys()) == { 'transforms' }
|
||||
t0, t1, t2 = f['transforms']
|
||||
t0, t1, t2, _ = f['transforms']
|
||||
assert t0['type'] == 'Crop'
|
||||
assert t1['type'] == 'Scale'
|
||||
assert t2['type'] == 'Mean'
|
||||
|
@ -130,39 +264,16 @@ def test_image():
|
|||
assert t1['interpolations'] == 'linear'
|
||||
assert t2['meanFile'] == mean_file
|
||||
|
||||
rc = _ReaderConfig(image, randomize=False, randomization_window = 100,
|
||||
sample_based_randomization_window = True, epoch_size=epoch_size)
|
||||
|
||||
config = to_dictionary(MinibatchSourceConfig([image, image]))
|
||||
assert len(config['deserializers']) == 2
|
||||
|
||||
assert rc['epochSize'].value == epoch_size
|
||||
assert rc['randomize'] == False
|
||||
assert rc['sampleBasedRandomizationWindow'] == True
|
||||
assert len(rc['deserializers']) == 1
|
||||
d = rc['deserializers'][0]
|
||||
assert d['type'] == 'ImageDeserializer'
|
||||
assert d['file'] == map_file
|
||||
assert set(d['input'].keys()) == {label_name, feature_name}
|
||||
|
||||
l = d['input'][label_name]
|
||||
assert l['labelDim'] == num_classes
|
||||
|
||||
rc = _ReaderConfig(image, randomize=True, randomization_window = 100,
|
||||
sample_based_randomization_window = True, epoch_size=epoch_size)
|
||||
|
||||
assert rc['epochSize'].value == epoch_size
|
||||
assert rc['randomize'] == True
|
||||
assert rc['sampleBasedRandomizationWindow'] == True
|
||||
assert len(rc['deserializers']) == 1
|
||||
d = rc['deserializers'][0]
|
||||
assert d['type'] == 'ImageDeserializer'
|
||||
assert d['file'] == map_file
|
||||
assert set(d['input'].keys()) == {label_name, feature_name}
|
||||
|
||||
l = d['input'][label_name]
|
||||
assert l['labelDim'] == num_classes
|
||||
config = to_dictionary(MinibatchSourceConfig([image, image, image]))
|
||||
assert len(config['deserializers']) == 3
|
||||
|
||||
# TODO depends on ImageReader.dll
|
||||
'''
|
||||
mbs = rc.minibatch_source()
|
||||
mbs = config.create_minibatch_source()
|
||||
sis = mbs.stream_infos()
|
||||
assert set(sis.keys()) == { feature_name, label_name }
|
||||
'''
|
||||
|
@ -185,12 +296,13 @@ def test_full_sweep_minibatch(tmpdir):
|
|||
mb_source = MinibatchSource(CTFDeserializer(tmpfile, StreamDefs(
|
||||
features = StreamDef(field='S0', shape=1),
|
||||
labels = StreamDef(field='S1', shape=1))),
|
||||
randomize=False, epoch_size=FULL_DATA_SWEEP)
|
||||
randomization_window_in_chunks=0, max_sweeps=1)
|
||||
|
||||
features_si = mb_source.stream_info('features')
|
||||
labels_si = mb_source.stream_info('labels')
|
||||
|
||||
mb = mb_source.next_minibatch(1000)
|
||||
|
||||
assert mb[features_si].num_sequences == 2
|
||||
assert mb[labels_si].num_sequences == 2
|
||||
|
||||
|
@ -225,6 +337,90 @@ def test_full_sweep_minibatch(tmpdir):
|
|||
[[2, 1, 1],
|
||||
[2, 1, 0]])
|
||||
|
||||
def test_max_samples(tmpdir):
|
||||
mb_source = MinibatchSource(
|
||||
create_ctf_deserializer(tmpdir), max_samples=1)
|
||||
|
||||
input_map = {'features' : mb_source['features']}
|
||||
mb = mb_source.next_minibatch(10, input_map)
|
||||
|
||||
assert 'features' in mb
|
||||
assert mb['features'].num_samples == 1
|
||||
assert not mb['features'].end_of_sweep
|
||||
|
||||
mb = mb_source.next_minibatch(10, input_map)
|
||||
|
||||
assert not mb
|
||||
|
||||
def test_max_sweeps(tmpdir):
|
||||
# set max sweeps to 3 (12 samples altogether).
|
||||
mb_source = MinibatchSource(
|
||||
create_ctf_deserializer(tmpdir), max_sweeps=3)
|
||||
|
||||
input_map = {'features' : mb_source['features']}
|
||||
|
||||
for i in range(2):
|
||||
mb = mb_source.next_minibatch(5, input_map)
|
||||
|
||||
assert 'features' in mb
|
||||
assert mb['features'].num_samples == 5
|
||||
assert mb['features'].end_of_sweep
|
||||
|
||||
mb = mb_source.next_minibatch(5, input_map)
|
||||
|
||||
assert 'features' in mb
|
||||
assert mb['features'].num_samples == 2
|
||||
assert mb['features'].end_of_sweep
|
||||
|
||||
mb = mb_source.next_minibatch(1, input_map)
|
||||
|
||||
assert not mb
|
||||
|
||||
def test_max_samples_over_several_sweeps(tmpdir):
|
||||
mb_source = MinibatchSource(
|
||||
create_ctf_deserializer(tmpdir), max_samples=11)
|
||||
|
||||
input_map = {'features' : mb_source['features']}
|
||||
|
||||
for i in range(2):
|
||||
mb = mb_source.next_minibatch(5, input_map)
|
||||
|
||||
assert 'features' in mb
|
||||
assert mb['features'].num_samples == 5
|
||||
assert mb['features'].end_of_sweep
|
||||
|
||||
mb = mb_source.next_minibatch(5, input_map)
|
||||
|
||||
assert 'features' in mb
|
||||
assert mb['features'].num_samples == 1
|
||||
assert not mb['features'].end_of_sweep
|
||||
|
||||
mb = mb_source.next_minibatch(1, input_map)
|
||||
|
||||
assert not mb
|
||||
|
||||
def test_one_sweep(tmpdir):
|
||||
ctf = create_ctf_deserializer(tmpdir)
|
||||
sources = [ MinibatchSource(ctf, max_sweeps=1),
|
||||
MinibatchSource(ctf, max_samples=FULL_DATA_SWEEP),
|
||||
MinibatchSource(ctf, max_sweeps=1,
|
||||
max_samples=INFINITELY_REPEAT),
|
||||
MinibatchSource(ctf, max_samples=FULL_DATA_SWEEP,
|
||||
max_sweeps=INFINITELY_REPEAT) ]
|
||||
|
||||
for source in sources:
|
||||
input_map = {'features' : source['features']}
|
||||
|
||||
mb = source.next_minibatch(100, input_map)
|
||||
|
||||
assert 'features' in mb
|
||||
assert mb['features'].num_samples == 4
|
||||
assert mb['features'].end_of_sweep
|
||||
|
||||
mb = source.next_minibatch(100, input_map)
|
||||
|
||||
assert not mb
|
||||
|
||||
def test_large_minibatch(tmpdir):
|
||||
|
||||
mbdata = r'''0 |S0 0 |S1 0
|
||||
|
@ -243,7 +439,7 @@ def test_large_minibatch(tmpdir):
|
|||
mb_source = MinibatchSource(CTFDeserializer(tmpfile, StreamDefs(
|
||||
features = StreamDef(field='S0', shape=1),
|
||||
labels = StreamDef(field='S1', shape=1))),
|
||||
randomize=False)
|
||||
randomization_window_in_chunks=0)
|
||||
|
||||
features_si = mb_source.stream_info('features')
|
||||
labels_si = mb_source.stream_info('labels')
|
||||
|
|
|
@ -108,7 +108,7 @@ def test_distributed_mb_source(tmpdir):
|
|||
9 |S0 61:1 |# A |S1 32:1 |# ~AH
|
||||
10 |S0 61:1 |# A |S1 32:1 |# ~AH
|
||||
'''
|
||||
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, FULL_DATA_SWEEP
|
||||
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs
|
||||
|
||||
ctf_file = str(tmpdir/'2seqtest.txt')
|
||||
with open(ctf_file, 'w') as f:
|
||||
|
@ -120,12 +120,12 @@ def test_distributed_mb_source(tmpdir):
|
|||
features = StreamDef(field='S0', shape=input_dim, is_sparse=True),
|
||||
labels = StreamDef(field='S1', shape=input_dim, is_sparse=True)
|
||||
)),
|
||||
randomize=False, epoch_size=36) # A bit more than a sweep
|
||||
randomize=False, max_samples=36) # A bit more than a sweep
|
||||
mb1 = MinibatchSource(CTFDeserializer(ctf_file, StreamDefs(
|
||||
features = StreamDef(field='S0', shape=input_dim, is_sparse=True),
|
||||
labels = StreamDef(field='S1', shape=input_dim, is_sparse=True)
|
||||
)),
|
||||
randomize=False, epoch_size=36) # A bit more than a sweep
|
||||
randomize=False, max_samples=36) # A bit more than a sweep
|
||||
input = sequence.input(shape=(input_dim,))
|
||||
label = sequence.input(shape=(input_dim,))
|
||||
input_map = {
|
||||
|
@ -170,14 +170,12 @@ def test_distributed_mb_source(tmpdir):
|
|||
mb3 = MinibatchSource(CTFDeserializer(ctf_file, StreamDefs(
|
||||
features = StreamDef(field='S0', shape=input_dim, is_sparse=True),
|
||||
labels = StreamDef(field='S1', shape=input_dim, is_sparse=True)
|
||||
)),
|
||||
randomize=True, epoch_size=FULL_DATA_SWEEP)
|
||||
)), max_sweeps=1)
|
||||
|
||||
mb4 = MinibatchSource(CTFDeserializer(ctf_file, StreamDefs(
|
||||
features = StreamDef(field='S0', shape=input_dim, is_sparse=True),
|
||||
labels = StreamDef(field='S1', shape=input_dim, is_sparse=True)
|
||||
)),
|
||||
randomize=True, epoch_size=FULL_DATA_SWEEP)
|
||||
)), max_sweeps=1)
|
||||
|
||||
data = mb3.next_minibatch(minibatch_size_in_samples=10, input_map=input_map, num_data_partitions=2, partition_index=0)
|
||||
assert(data[input].num_samples == 5)
|
||||
|
|
|
@ -117,7 +117,7 @@ def test_eval_sparse_dense(tmpdir, device_id):
|
|||
mbs = MinibatchSource(CTFDeserializer(ctf_file, StreamDefs(
|
||||
features = StreamDef(field='S0', shape=input_vocab_dim, is_sparse=True),
|
||||
labels = StreamDef(field='S1', shape=label_vocab_dim, is_sparse=True)
|
||||
)), randomize=False, epoch_size = 2)
|
||||
)), randomize=False, max_samples = 2)
|
||||
|
||||
raw_input = sequence.input(shape=input_vocab_dim, sequence_axis=Axis('inputAxis'), name='raw_input', is_sparse=True)
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ ctf_data = '''\
|
|||
'''
|
||||
|
||||
|
||||
def mb_source(tmpdir, fileprefix, epoch_size=FULL_DATA_SWEEP):
|
||||
def mb_source(tmpdir, fileprefix, max_samples=FULL_DATA_SWEEP):
|
||||
ctf_file = str(tmpdir / (fileprefix + '2seqtest.txt'))
|
||||
with open(ctf_file, 'w') as f:
|
||||
f.write(ctf_data)
|
||||
|
@ -60,7 +60,7 @@ def mb_source(tmpdir, fileprefix, epoch_size=FULL_DATA_SWEEP):
|
|||
features=StreamDef(field='S0', shape=input_dim, is_sparse=True),
|
||||
labels=StreamDef(field='S1', shape=input_dim, is_sparse=True)
|
||||
)),
|
||||
randomize=False, epoch_size=epoch_size)
|
||||
randomize=False, max_samples=max_samples)
|
||||
return mbs
|
||||
|
||||
|
||||
|
@ -126,7 +126,7 @@ def test_session_sanity_check(tmpdir, device_id):
|
|||
def test_session_max_samples(tmpdir, device_id):
|
||||
device = cntk_device(device_id)
|
||||
t, feature, label = create_sample_model(device)
|
||||
mbs = mb_source(tmpdir, "training", epoch_size=INFINITELY_REPEAT)
|
||||
mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
|
||||
|
||||
input_map = {
|
||||
feature: mbs.streams.features,
|
||||
|
@ -146,7 +146,7 @@ def test_session_cross_validation_at_end(tmpdir, device_id):
|
|||
device = cntk_device(device_id)
|
||||
writer = MockProgressWriter(expected_test_summary=[[92, 25]])
|
||||
t, feature, label = create_sample_model(device, writer)
|
||||
mbs = mb_source(tmpdir, "training", epoch_size=INFINITELY_REPEAT)
|
||||
mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
|
||||
mbs1 = mb_source(tmpdir, "cv")
|
||||
|
||||
input_map = {
|
||||
|
@ -169,7 +169,7 @@ def test_session_cross_validation_3_times(tmpdir, device_id):
|
|||
device = cntk_device(device_id)
|
||||
writer = MockProgressWriter(expected_test_summary=[[92, 25], [92, 25], [92, 25]])
|
||||
t, feature, label = create_sample_model(device, writer)
|
||||
mbs = mb_source(tmpdir, "training", epoch_size=INFINITELY_REPEAT)
|
||||
mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
|
||||
mbs1 = mb_source(tmpdir, "cv")
|
||||
|
||||
input_map = {
|
||||
|
@ -195,7 +195,7 @@ def test_session_cross_validation_3_times_checkpoints_2_save_all(tmpdir, device_
|
|||
device = cntk_device(device_id)
|
||||
writer = MockProgressWriter(expected_test_summary=[[92, 25], [92, 25], [92, 25]])
|
||||
t, feature, label = create_sample_model(device, writer)
|
||||
mbs = mb_source(tmpdir, "training", epoch_size=INFINITELY_REPEAT)
|
||||
mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
|
||||
mbs1 = mb_source(tmpdir, "cv")
|
||||
|
||||
input_map = {
|
||||
|
@ -236,7 +236,7 @@ def test_session_progress_print(tmpdir, device_id):
|
|||
device = cntk_device(device_id)
|
||||
writer = MockProgressWriter()
|
||||
t, feature, label = create_sample_model(device, writer)
|
||||
mbs = mb_source(tmpdir, "training", epoch_size=INFINITELY_REPEAT)
|
||||
mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
|
||||
|
||||
input_map = {
|
||||
feature: mbs.streams.features,
|
||||
|
@ -263,7 +263,7 @@ def test_session_restart_from_checkpoint(tmpdir, device_id):
|
|||
device = cntk_device(device_id)
|
||||
writer = MockProgressWriter()
|
||||
t, feature, label = create_sample_model(device, writer)
|
||||
mbs = mb_source(tmpdir, "training", epoch_size=INFINITELY_REPEAT)
|
||||
mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
|
||||
|
||||
input_map = {
|
||||
feature: mbs.streams.features,
|
||||
|
@ -340,7 +340,7 @@ def test_session_cv_callback_3_times(tmpdir, device_id):
|
|||
|
||||
device = cntk_device(device_id)
|
||||
t, feature, label = create_sample_model(device)
|
||||
mbs = mb_source(tmpdir, "training", epoch_size=INFINITELY_REPEAT)
|
||||
mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
|
||||
|
||||
input_map = {
|
||||
feature: mbs.streams.features,
|
||||
|
@ -368,7 +368,7 @@ def test_session_cv_callback_3_times(tmpdir, device_id):
|
|||
def test_session_cv_callback_with_cross_validation_3_times(tmpdir, device_id):
|
||||
device = cntk_device(device_id)
|
||||
t, feature, label = create_sample_model(device)
|
||||
mbs = mb_source(tmpdir, "training", epoch_size=INFINITELY_REPEAT)
|
||||
mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
|
||||
cv_mbs = mb_source(tmpdir, "cv")
|
||||
|
||||
input_map = {
|
||||
|
@ -404,7 +404,7 @@ def test_session_cv_callback_early_exit(tmpdir, device_id):
|
|||
|
||||
device = cntk_device(device_id)
|
||||
t, feature, label = create_sample_model(device)
|
||||
mbs = mb_source(tmpdir, "training", epoch_size=INFINITELY_REPEAT)
|
||||
mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
|
||||
|
||||
input_map = {
|
||||
feature: mbs.streams.features,
|
||||
|
@ -434,7 +434,7 @@ def test_session_with_test(tmpdir, device_id):
|
|||
device = cntk_device(device_id)
|
||||
writer = MockProgressWriter(expected_test_summary=[[92, 25]])
|
||||
t, feature, label = create_sample_model(device, writer)
|
||||
mbs = mb_source(tmpdir, "training", epoch_size=INFINITELY_REPEAT)
|
||||
mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
|
||||
mbs1 = mb_source(tmpdir, "test")
|
||||
|
||||
input_map = {
|
||||
|
|
|
@ -2,7 +2,7 @@ import sys
|
|||
import os
|
||||
from cntk import Trainer, Axis
|
||||
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs,\
|
||||
INFINITELY_REPEAT, FULL_DATA_SWEEP
|
||||
INFINITELY_REPEAT
|
||||
from cntk.learners import sgd, learning_rate_schedule, UnitType
|
||||
from cntk import input, cross_entropy_with_softmax, \
|
||||
classification_error, sequence
|
||||
|
@ -20,7 +20,7 @@ def create_reader(path, is_training, input_dim, label_dim):
|
|||
features=StreamDef(field='x', shape=input_dim, is_sparse=True),
|
||||
labels=StreamDef(field='y', shape=label_dim, is_sparse=False)
|
||||
)), randomize=is_training,
|
||||
epoch_size=INFINITELY_REPEAT if is_training else FULL_DATA_SWEEP)
|
||||
max_sweeps=INFINITELY_REPEAT if is_training else 1)
|
||||
|
||||
|
||||
# Defines the LSTM model for classifying sequences
|
||||
|
|
Загрузка…
Ссылка в новой задаче