Said Bleik 2019-06-21 16:33:37 -04:00
Parent 0bae4ff59c
Commit ed4e09b9b3
1 changed file with 9 additions and 4 deletions


@@ -31,13 +31,18 @@ def csv_file(tmpdir):
 def test_dask_csv_loader(csv_file):
     num_batches = 500
-    batch_size = 10
-    loader = DaskCSVLoader(csv_file, header=None)
+    batch_size = 12
+    num_partitions = 4
+    loader = DaskCSVLoader(
+        csv_file, header=None, block_size=5 * int(UNIF1["n"] / num_partitions)
+    )
     sample = []
     for batch in loader.get_random_batches(num_batches, batch_size):
         sample.append(list(batch.iloc[:, 1]))
     sample = np.concatenate(sample)
+    print(sample.mean())
+    assert loader.df.npartitions == num_partitions
     assert sample.mean().round() == UNIF1["a"] + UNIF1["b"] / 2
-    assert len(sample) == num_batches * batch_size
+    assert len(sample) <= num_batches * batch_size
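For context on the new npartitions assertion: a minimal sketch of the same partition arithmetic using dask.dataframe directly, under two assumptions this diff does not confirm — that DaskCSVLoader forwards block_size to dask's blocksize parameter (bytes per block), and that each row written by the csv_file fixture is roughly 5 bytes wide (which is what the factor of 5 suggests). num_rows and "data.csv" are illustrative stand-ins for UNIF1["n"] and the fixture path.

# Sketch only: assumes DaskCSVLoader passes block_size through to
# dask.dataframe.read_csv's blocksize argument (bytes per block) and
# that each CSV row is ~5 bytes, so one block holds roughly
# num_rows / num_partitions rows.
import dask.dataframe as dd

num_rows = 1000        # illustrative stand-in for UNIF1["n"]
num_partitions = 4
block_size = 5 * int(num_rows / num_partitions)   # bytes per block

df = dd.read_csv("data.csv", header=None, blocksize=block_size)

# dask splits the file into ceil(file_size / blocksize) blocks, one
# partition per block, so a ~5-bytes-per-row file yields num_partitions
assert df.npartitions == num_partitions

The relaxed final assertion (<= instead of ==) is presumably a consequence of the same change: a random batch drawn from a partition holding fewer than batch_size rows can come back short, so the total sample size is only bounded above by num_batches * batch_size rather than equal to it.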