edits to loader test
This commit is contained in:
Родитель
0bae4ff59c
Коммит
ed4e09b9b3
|
@@ -31,13 +31,18 @@ def csv_file(tmpdir):
|
||||||
|
|
||||||
def test_dask_csv_loader(csv_file):
|
def test_dask_csv_loader(csv_file):
|
||||||
num_batches = 500
|
num_batches = 500
|
||||||
batch_size = 10
|
batch_size = 12
|
||||||
|
num_partitions = 4
|
||||||
|
|
||||||
|
loader = DaskCSVLoader(
|
||||||
|
csv_file, header=None, block_size=5 * int(UNIF1["n"] / num_partitions)
|
||||||
|
)
|
||||||
|
|
||||||
loader = DaskCSVLoader(csv_file, header=None)
|
|
||||||
sample = []
|
sample = []
|
||||||
for batch in loader.get_random_batches(num_batches, batch_size):
|
for batch in loader.get_random_batches(num_batches, batch_size):
|
||||||
sample.append(list(batch.iloc[:, 1]))
|
sample.append(list(batch.iloc[:, 1]))
|
||||||
sample = np.concatenate(sample)
|
sample = np.concatenate(sample)
|
||||||
print(sample.mean())
|
|
||||||
|
assert loader.df.npartitions == num_partitions
|
||||||
assert sample.mean().round() == UNIF1["a"] + UNIF1["b"] / 2
|
assert sample.mean().round() == UNIF1["a"] + UNIF1["b"] / 2
|
||||||
assert len(sample) == num_batches * batch_size
|
assert len(sample) <= num_batches * batch_size
|
||||||
|
|
Загрузка…
Ссылка в новой задаче