bug fix
This commit is contained in:
Родитель
74f459b7fe
Коммит
481e9aadec
|
@ -8,24 +8,25 @@ import pytest
|
|||
|
||||
from utils_nlp.dataset.data_loaders import DaskCSVLoader
|
||||
|
||||
# Parameters of the uniform integer distribution used to generate the test
# CSV: each row's value is drawn from [a, b], and the file has n rows.
# (The stale pre-fix definition {"a": 0, "b": 10, "n": 1000} is removed —
# only the post-fix values are kept.)
UNIF1 = {"a": 4, "b": 6, "n": 10000}  # some uniform distribution
row_size = 5  # each row is "a,b\n" (5 bytes); used to size Dask blocks
|
||||
|
||||
|
||||
@pytest.fixture()
def csv_file(tmpdir):
    """Create a deterministic two-column CSV file and return its path.

    Writes UNIF1["n"] rows of the form "<label>,<value>" where label is a
    random 0/1 and value is a uniform integer in [UNIF1["a"], UNIF1["b"]].
    random.seed(0) makes the file contents reproducible across runs.

    Returns:
        str: Path to the generated CSV file inside pytest's tmpdir.
    """
    random.seed(0)
    f = tmpdir.mkdir("test_loaders").join("tl_data.csv")
    # NOTE(review): py.path.local.write() opens the file in "w" mode and
    # truncates it, so the diff's `for i in range(1000)` wrapper rewrote the
    # same-size file 1000 times with no observable effect; a single write is
    # sufficient and behaviorally equivalent for consumers of this fixture.
    f.write(
        "\n".join(
            [
                "{},{}".format(
                    random.randint(0, 1),
                    random.randint(UNIF1["a"], UNIF1["b"]),
                )
                for x in range(UNIF1["n"])
            ]
        )
    )
    return str(f)
|
||||
|
||||
|
||||
|
@ -34,8 +35,14 @@ def test_dask_csv_rnd_loader(csv_file):
|
|||
batch_size = 12
|
||||
num_partitions = 4
|
||||
|
||||
import os
|
||||
|
||||
print("size:", os.stat(csv_file).st_size)
|
||||
loader = DaskCSVLoader(
|
||||
csv_file, header=None, block_size=5 * int(UNIF1["n"] / num_partitions)
|
||||
csv_file,
|
||||
header=None,
|
||||
block_size=row_size * int(UNIF1["n"] / num_partitions),
|
||||
random_seed=0,
|
||||
)
|
||||
|
||||
sample = []
|
||||
|
@ -44,7 +51,7 @@ def test_dask_csv_rnd_loader(csv_file):
|
|||
sample = np.concatenate(sample)
|
||||
|
||||
assert loader.df.npartitions == num_partitions
|
||||
assert sample.mean().round() == UNIF1["a"] + UNIF1["b"] / 2
|
||||
assert sample.mean().round() == (UNIF1["a"] + UNIF1["b"]) / 2
|
||||
assert len(sample) <= num_batches * batch_size
|
||||
|
||||
|
||||
|
@ -53,7 +60,9 @@ def test_dask_csv_seq_loader(csv_file):
|
|||
num_partitions = 4
|
||||
|
||||
loader = DaskCSVLoader(
|
||||
csv_file, header=None, block_size=5 * int(UNIF1["n"] / num_partitions)
|
||||
csv_file,
|
||||
header=None,
|
||||
block_size=row_size * int(UNIF1["n"] / num_partitions),
|
||||
)
|
||||
|
||||
sample = []
|
||||
|
@ -62,5 +71,5 @@ def test_dask_csv_seq_loader(csv_file):
|
|||
sample = np.concatenate(sample)
|
||||
|
||||
assert loader.df.npartitions == num_partitions
|
||||
assert sample.mean().round() == UNIF1["a"] + UNIF1["b"] / 2
|
||||
assert sample.mean().round() == (UNIF1["a"] + UNIF1["b"]) / 2
|
||||
assert len(sample) == UNIF1["n"]
|
||||
|
|
Загрузка…
Ссылка в новой задаче