This commit is contained in:
Said Bleik 2019-07-08 12:38:21 -04:00
Parent 74f459b7fe
Commit 481e9aadec
1 changed file: 25 additions and 16 deletions

View file

@ -8,24 +8,25 @@ import pytest
from utils_nlp.dataset.data_loaders import DaskCSVLoader
# Parameters of the uniform integer distribution used to build the test CSV:
# each row's value is drawn uniformly from [a, b], and n rows are written.
UNIF1 = {"a": 4, "b": 6, "n": 10000}  # some uniform distribution
# Size in bytes of one CSV row; original note: "a,b\n (5 bytes)".
# NOTE(review): with single-digit label and value, "l,v\n" looks like 4 bytes —
# 5 presumably accounts for the platform/py.path line-ending width; confirm
# against os.stat(csv_file).st_size before relying on it elsewhere.
row_size = 5  # "a,b\n (5 bytes)"
@pytest.fixture()
def csv_file(tmpdir):
    """Create a temporary two-column CSV file of test data.

    Writes exactly UNIF1["n"] rows of the form "label,value" where
    label is a random 0/1 and value is drawn uniformly from
    [UNIF1["a"], UNIF1["b"]].  Seeded so the file contents — and the
    sample-mean assertions in the tests below — are deterministic.

    Args:
        tmpdir: pytest's built-in temporary-directory fixture
            (a py.path.local object).

    Returns:
        str: Path to the generated CSV file.
    """
    random.seed(0)  # deterministic data so mean-based asserts are stable
    f = tmpdir.mkdir("test_loaders").join("tl_data.csv")
    f.write(
        "\n".join(
            [
                "{},{}".format(
                    random.randint(0, 1),
                    random.randint(UNIF1["a"], UNIF1["b"]),
                )
                for x in range(UNIF1["n"])
            ]
        )
    )
    return str(f)
@ -34,8 +35,14 @@ def test_dask_csv_rnd_loader(csv_file):
batch_size = 12
num_partitions = 4
import os
print("size:", os.stat(csv_file).st_size)
loader = DaskCSVLoader(
    csv_file,
    header=None,
    block_size=row_size * int(UNIF1["n"] / num_partitions),
    random_seed=0,
)
sample = []
@ -44,7 +51,7 @@ def test_dask_csv_rnd_loader(csv_file):
sample = np.concatenate(sample)
assert loader.df.npartitions == num_partitions
assert sample.mean().round() == (UNIF1["a"] + UNIF1["b"]) / 2
assert len(sample) <= num_batches * batch_size
@ -53,7 +60,9 @@ def test_dask_csv_seq_loader(csv_file):
num_partitions = 4
loader = DaskCSVLoader(
    csv_file,
    header=None,
    block_size=row_size * int(UNIF1["n"] / num_partitions),
)
sample = []
@ -62,5 +71,5 @@ def test_dask_csv_seq_loader(csv_file):
sample = np.concatenate(sample)
assert loader.df.npartitions == num_partitions
assert sample.mean().round() == (UNIF1["a"] + UNIF1["b"]) / 2
assert len(sample) == UNIF1["n"]