diff --git a/tests/unit/test_data_loaders.py b/tests/unit/test_data_loaders.py
index 6aca5e6..aeda8bb 100644
--- a/tests/unit/test_data_loaders.py
+++ b/tests/unit/test_data_loaders.py
@@ -31,13 +31,18 @@ def csv_file(tmpdir):
 
 def test_dask_csv_loader(csv_file):
     num_batches = 500
-    batch_size = 10
+    batch_size = 12
+    num_partitions = 4
+
+    loader = DaskCSVLoader(
+        csv_file, header=None, block_size=5 * int(UNIF1["n"] / num_partitions)
+    )
 
-    loader = DaskCSVLoader(csv_file, header=None)
     sample = []
     for batch in loader.get_random_batches(num_batches, batch_size):
         sample.append(list(batch.iloc[:, 1]))
     sample = np.concatenate(sample)
-    print(sample.mean())
+
+    assert loader.df.npartitions == num_partitions
     assert sample.mean().round() == UNIF1["a"] + UNIF1["b"] / 2
-    assert len(sample) == num_batches * batch_size
+    assert len(sample) <= num_batches * batch_size