Allow splitting of data in get next minibatch

This commit is contained in:
Eldar Akchurin 2017-02-02 17:46:41 +01:00
Родитель 81c97d24bd
Коммит a6c74a28d8
2 изменённых файлов: 109 добавлений и 5 удалений

Просмотреть файл

@ -143,7 +143,7 @@ class MinibatchSource(cntk_py.MinibatchSource):
@typemap
def next_minibatch(self, minibatch_size_in_samples,
input_map=None, device=None):
input_map=None, device=None, num_data_partitions=None, partition_index=None):
'''
Reads a minibatch that contains data for all input streams. The
minibatch size is specified in terms of #samples and/or #sequences for the
@ -159,17 +159,26 @@ class MinibatchSource(cntk_py.MinibatchSource):
to :class:`StreamInformation` which will be used to convert the
returned data.
device (`DeviceDescriptor`, defaults to `None`): CNTK DeviceDescriptor
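num_data_partitions (`int`, defaults to `None`): Used for distributed training, indicates
into how many partitions the source should split the data. `None` is treated as 1.
partition_index (`int`, defaults to `None`): Used for distributed training, indicates
the partition from which to take data. `None` is treated as 0.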
Returns:
A mapping of :class:`StramInformation` to :class:`MinibatchData` if
a mapping of :class:`StreamInformation` to :class:`MinibatchData` if
``input_map`` was not specified. Otherwise, the returned value will
be a mapping of :class:`~cntk.ops.variabls.Variable` to class:`MinibatchData`.
be a mapping of :class:`~cntk.ops.variables.Variable` to :class:`MinibatchData`.
'''
if device is None:
device = use_default_device()
mb = super(MinibatchSource, self).get_next_minibatch(
minibatch_size_in_samples, device)
if num_data_partitions is None:
num_data_partitions = 1
if partition_index is None:
partition_index = 0
mb = super(MinibatchSource, self).get_next_minibatch(0,
minibatch_size_in_samples, num_data_partitions, partition_index, device)
if input_map:
if not mb:

Просмотреть файл

@ -76,6 +76,101 @@ def run_distributed_training(tmpdir, create_func):
assert isinstance(trainer.model, Function)
assert trainer.model.__doc__
def test_distributed_mb_source(tmpdir):
    '''
    Reads a small CTF file through ``MinibatchSource.next_minibatch`` with
    ``num_data_partitions=2`` and checks the number of samples returned for
    each ``partition_index``, first without and then with randomization.
    '''
    input_dim = 69

    ctf_data = '''\
0 |S0 3:1 |# <s> |S1 3:1 |# <s>
0 |S0 4:1 |# A |S1 32:1 |# ~AH
0 |S0 5:1 |# B |S1 36:1 |# ~B
0 |S0 4:1 |# A |S1 31:1 |# ~AE
0 |S0 7:1 |# D |S1 38:1 |# ~D
0 |S0 12:1 |# I |S1 47:1 |# ~IY
0 |S0 1:1 |# </s> |S1 1:1 |# </s>
2 |S0 60:1 |# <s> |S1 3:1 |# <s>
2 |S0 61:1 |# A |S1 32:1 |# ~AH
2 |S0 61:1 |# A |S1 32:1 |# ~AH
3 |S0 60:1 |# <s> |S1 3:1 |# <s>
3 |S0 61:1 |# A |S1 32:1 |# ~AH
3 |S0 61:1 |# A |S1 32:1 |# ~AH
3 |S0 61:1 |# A |S1 32:1 |# ~AH
4 |S0 60:1 |# <s> |S1 3:1 |# <s>
5 |S0 60:1 |# <s> |S1 3:1 |# <s>
5 |S0 61:1 |# A |S1 32:1 |# ~AH
6 |S0 60:1 |# <s> |S1 3:1 |# <s>
6 |S0 61:1 |# A |S1 32:1 |# ~AH
7 |S0 60:1 |# <s> |S1 3:1 |# <s>
8 |S0 60:1 |# <s> |S1 3:1 |# <s>
8 |S0 61:1 |# A |S1 32:1 |# ~AH
9 |S0 60:1 |# <s> |S1 3:1 |# <s>
9 |S0 61:1 |# A |S1 32:1 |# ~AH
10 |S0 61:1 |# A |S1 32:1 |# ~AH
'''
    from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, FULL_DATA_SWEEP

    ctf_file = str(tmpdir/'2seqtest.txt')
    with open(ctf_file, 'w') as f:
        f.write(ctf_data)

    def make_source(randomize):
        # Every source in this test reads the same file and declares the
        # same two sparse streams; only the randomization flag differs.
        stream_defs = StreamDefs(
            features=StreamDef(field='S0', shape=input_dim, is_sparse=True),
            labels=StreamDef(field='S1', shape=input_dim, is_sparse=True)
        )
        return MinibatchSource(CTFDeserializer(ctf_file, stream_defs),
                               randomize=randomize,
                               epoch_size=FULL_DATA_SWEEP)

    # No randomization
    mb0 = make_source(False)
    mb1 = make_source(False)

    features_var = input_variable(shape=(input_dim,))
    labels_var = input_variable(shape=(input_dim,))

    # The stream handles of mb0 are reused for every source below.
    input_map = {
        features_var: mb0.streams.features,
        labels_var: mb0.streams.labels
    }

    def read_partition(source, index):
        # One read of up to 10 samples from the given half of the data.
        return source.next_minibatch(minibatch_size_in_samples=10,
                                     input_map=input_map,
                                     num_data_partitions=2,
                                     partition_index=index)

    # Successive reads from partition 0, then from partition 1.
    for expected in (7, 4, 5):
        data = read_partition(mb0, 0)
        assert data[features_var].num_samples == expected

    for expected in (3, 5):
        data = read_partition(mb1, 1)
        assert data[features_var].num_samples == expected

    # Randomization
    mb3 = make_source(True)
    mb4 = make_source(True)

    for expected in (5, 4):
        data = read_partition(mb3, 0)
        assert data[features_var].num_samples == expected

    data = read_partition(mb4, 1)
    assert len(data) == 0
def test_distributed(tmpdir, is_1bit_sgd):
quantized=(True if is_1bit_sgd==1 else False)