From 7536488928239c891d72e986b00db18128db1253 Mon Sep 17 00:00:00 2001 From: saidbleik Date: Sun, 15 Sep 2019 07:24:24 +0000 Subject: [PATCH] reduce batch size for test --- .../test_notebooks_text_classification.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/integration/test_notebooks_text_classification.py b/tests/integration/test_notebooks_text_classification.py index 54f905c..07eee51 100644 --- a/tests/integration/test_notebooks_text_classification.py +++ b/tests/integration/test_notebooks_text_classification.py @@ -49,18 +49,19 @@ def test_tc_dac_bert_ar(notebooks, tmp): NUM_GPUS=1, DATA_FOLDER=tmp, BERT_CACHE_DIR=tmp, - BATCH_SIZE=32, + MAX_LEN=175, + BATCH_SIZE=16, NUM_EPOCHS=1, TRAIN_SIZE=0.8, - NUM_ROWS=15000, + NUM_ROWS=8000, RANDOM_STATE=0, ), ) result = sb.read_notebook(OUTPUT_NOTEBOOK).scraps.data_dict - assert pytest.approx(result["accuracy"], 0.93, abs=ABS_TOL) - assert pytest.approx(result["precision"], 0.91, abs=ABS_TOL) - assert pytest.approx(result["recall"], 0.91, abs=ABS_TOL) - assert pytest.approx(result["f1"], 0.91, abs=ABS_TOL) + assert pytest.approx(result["accuracy"], 0.871, abs=ABS_TOL) + assert pytest.approx(result["precision"], 0.865, abs=ABS_TOL) + assert pytest.approx(result["recall"], 0.852, abs=ABS_TOL) + assert pytest.approx(result["f1"], 0.845, abs=ABS_TOL) @pytest.mark.gpu