Allow splitting of data in get next minibatch
This commit is contained in:
Родитель
81c97d24bd
Коммит
a6c74a28d8
|
@ -143,7 +143,7 @@ class MinibatchSource(cntk_py.MinibatchSource):
|
|||
|
||||
@typemap
|
||||
def next_minibatch(self, minibatch_size_in_samples,
|
||||
input_map=None, device=None):
|
||||
input_map=None, device=None, num_data_partitions=None, partition_index=None):
|
||||
'''
|
||||
Reads a minibatch that contains data for all input streams. The
|
||||
minibatch size is specified in terms of #samples and/or #sequences for the
|
||||
|
@ -159,17 +159,26 @@ class MinibatchSource(cntk_py.MinibatchSource):
|
|||
to :class:`StreamInformation` which will be used to convert the
|
||||
returned data.
|
||||
device (`DeviceDescriptor`, defaults to `None`): CNTK DeviceDescriptor
|
||||
num_data_partitions: Used for distributed training, indicates into how many partitions
|
||||
the source should split the data.
|
||||
partition_index: Used for distributed training, indicates which partition to take the data from.
|
||||
|
||||
Returns:
|
||||
A mapping of :class:`StramInformation` to :class:`MinibatchData` if
|
||||
a mapping of :class:`StreamInformation` to :class:`MinibatchData` if
|
||||
``input_map`` was not specified. Otherwise, the returned value will
|
||||
be a mapping of :class:`~cntk.ops.variabls.Variable` to class:`MinibatchData`.
|
||||
be a mapping of :class:`~cntk.ops.variables.Variable` to :class:`MinibatchData`.
|
||||
'''
|
||||
if device is None:
|
||||
device = use_default_device()
|
||||
|
||||
mb = super(MinibatchSource, self).get_next_minibatch(
|
||||
minibatch_size_in_samples, device)
|
||||
if num_data_partitions is None:
|
||||
num_data_partitions = 1
|
||||
|
||||
if partition_index is None:
|
||||
partition_index = 0
|
||||
|
||||
mb = super(MinibatchSource, self).get_next_minibatch(0,
|
||||
minibatch_size_in_samples, num_data_partitions, partition_index, device)
|
||||
|
||||
if input_map:
|
||||
if not mb:
|
||||
|
|
|
@ -76,6 +76,101 @@ def run_distributed_training(tmpdir, create_func):
|
|||
assert isinstance(trainer.model, Function)
|
||||
assert trainer.model.__doc__
|
||||
|
||||
|
||||
def test_distributed_mb_source(tmpdir):
|
||||
input_dim = 69
|
||||
|
||||
ctf_data = '''\
|
||||
0 |S0 3:1 |# <s> |S1 3:1 |# <s>
|
||||
0 |S0 4:1 |# A |S1 32:1 |# ~AH
|
||||
0 |S0 5:1 |# B |S1 36:1 |# ~B
|
||||
0 |S0 4:1 |# A |S1 31:1 |# ~AE
|
||||
0 |S0 7:1 |# D |S1 38:1 |# ~D
|
||||
0 |S0 12:1 |# I |S1 47:1 |# ~IY
|
||||
0 |S0 1:1 |# </s> |S1 1:1 |# </s>
|
||||
2 |S0 60:1 |# <s> |S1 3:1 |# <s>
|
||||
2 |S0 61:1 |# A |S1 32:1 |# ~AH
|
||||
2 |S0 61:1 |# A |S1 32:1 |# ~AH
|
||||
3 |S0 60:1 |# <s> |S1 3:1 |# <s>
|
||||
3 |S0 61:1 |# A |S1 32:1 |# ~AH
|
||||
3 |S0 61:1 |# A |S1 32:1 |# ~AH
|
||||
3 |S0 61:1 |# A |S1 32:1 |# ~AH
|
||||
4 |S0 60:1 |# <s> |S1 3:1 |# <s>
|
||||
5 |S0 60:1 |# <s> |S1 3:1 |# <s>
|
||||
5 |S0 61:1 |# A |S1 32:1 |# ~AH
|
||||
6 |S0 60:1 |# <s> |S1 3:1 |# <s>
|
||||
6 |S0 61:1 |# A |S1 32:1 |# ~AH
|
||||
7 |S0 60:1 |# <s> |S1 3:1 |# <s>
|
||||
8 |S0 60:1 |# <s> |S1 3:1 |# <s>
|
||||
8 |S0 61:1 |# A |S1 32:1 |# ~AH
|
||||
9 |S0 60:1 |# <s> |S1 3:1 |# <s>
|
||||
9 |S0 61:1 |# A |S1 32:1 |# ~AH
|
||||
10 |S0 61:1 |# A |S1 32:1 |# ~AH
|
||||
'''
|
||||
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, FULL_DATA_SWEEP
|
||||
|
||||
ctf_file = str(tmpdir/'2seqtest.txt')
|
||||
with open(ctf_file, 'w') as f:
|
||||
f.write(ctf_data)
|
||||
|
||||
# No randomization
|
||||
|
||||
mb0 = MinibatchSource(CTFDeserializer(ctf_file, StreamDefs(
|
||||
features = StreamDef(field='S0', shape=input_dim, is_sparse=True),
|
||||
labels = StreamDef(field='S1', shape=input_dim, is_sparse=True)
|
||||
)),
|
||||
randomize=False, epoch_size=FULL_DATA_SWEEP)
|
||||
mb1 = MinibatchSource(CTFDeserializer(ctf_file, StreamDefs(
|
||||
features = StreamDef(field='S0', shape=input_dim, is_sparse=True),
|
||||
labels = StreamDef(field='S1', shape=input_dim, is_sparse=True)
|
||||
)),
|
||||
randomize=False, epoch_size=FULL_DATA_SWEEP)
|
||||
input = input_variable(shape=(input_dim,))
|
||||
label = input_variable(shape=(input_dim,))
|
||||
input_map = {
|
||||
input : mb0.streams.features,
|
||||
label : mb0.streams.labels
|
||||
}
|
||||
|
||||
data = mb0.next_minibatch(minibatch_size_in_samples=10, input_map=input_map, num_data_partitions=2, partition_index=0)
|
||||
assert(data[input].num_samples == 7)
|
||||
|
||||
data = mb0.next_minibatch(minibatch_size_in_samples=10, input_map=input_map, num_data_partitions=2, partition_index=0)
|
||||
assert(data[input].num_samples == 4)
|
||||
|
||||
data = mb0.next_minibatch(minibatch_size_in_samples=10, input_map=input_map, num_data_partitions=2, partition_index=0)
|
||||
assert(data[input].num_samples == 5)
|
||||
|
||||
data = mb1.next_minibatch(minibatch_size_in_samples=10, input_map=input_map, num_data_partitions=2, partition_index=1)
|
||||
assert(data[input].num_samples == 3)
|
||||
|
||||
data = mb1.next_minibatch(minibatch_size_in_samples=10, input_map=input_map, num_data_partitions=2, partition_index=1)
|
||||
assert(data[input].num_samples == 5)
|
||||
|
||||
# Randomization
|
||||
|
||||
mb3 = MinibatchSource(CTFDeserializer(ctf_file, StreamDefs(
|
||||
features = StreamDef(field='S0', shape=input_dim, is_sparse=True),
|
||||
labels = StreamDef(field='S1', shape=input_dim, is_sparse=True)
|
||||
)),
|
||||
randomize=True, epoch_size=FULL_DATA_SWEEP)
|
||||
|
||||
mb4 = MinibatchSource(CTFDeserializer(ctf_file, StreamDefs(
|
||||
features = StreamDef(field='S0', shape=input_dim, is_sparse=True),
|
||||
labels = StreamDef(field='S1', shape=input_dim, is_sparse=True)
|
||||
)),
|
||||
randomize=True, epoch_size=FULL_DATA_SWEEP)
|
||||
|
||||
data = mb3.next_minibatch(minibatch_size_in_samples=10, input_map=input_map, num_data_partitions=2, partition_index=0)
|
||||
assert(data[input].num_samples == 5)
|
||||
|
||||
data = mb3.next_minibatch(minibatch_size_in_samples=10, input_map=input_map, num_data_partitions=2, partition_index=0)
|
||||
assert(data[input].num_samples == 4)
|
||||
|
||||
data = mb4.next_minibatch(minibatch_size_in_samples=10, input_map=input_map, num_data_partitions=2, partition_index=1)
|
||||
assert(len(data) == 0)
|
||||
|
||||
|
||||
def test_distributed(tmpdir, is_1bit_sgd):
|
||||
quantized=(True if is_1bit_sgd==1 else False)
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче