add distributed test for bertsum

This commit is contained in:
Daisy Deng 2020-03-20 19:15:55 +00:00
Родитель 946ae06fb6
Коммит c44920b988
2 изменённых файлов: 93 добавлений и 0 удалений

Просмотреть файл

@ -24,6 +24,25 @@ from utils_nlp.azureml import azureml_utils
from azureml.core.webservice import Webservice
@pytest.fixture(scope="module")
def scripts():
folder_notebooks = path_notebooks()
paths = {
"ddp_bertsumext": os.path.join(
folder_notebooks,
"text_summarization",
"extractive_summarization_cnndm_distributed_train.py",
),
"ddp_bertsumabs": os.path.join(
folder_notebooks,
"text_summarization",
"abstractive_summarization_bertsum_cnndm_distributed_train.py",
),
}
return paths
@pytest.fixture(scope="module")
def notebooks():
folder_notebooks = path_notebooks()

Просмотреть файл

@ -0,0 +1,74 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
@pytest.mark.gpu
@pytest.mark.integration
def test_ddp_extractive_summarization_cnndm_transformers(scripts, tmp):
script = scripts["ddp_bertsumext"]
summary_filename = "bertsumext_prediction.txt"
import subprocess
process = subprocess.Popen(
[
"python",
script,
"--data_dir",
tmp,
"--cache_dir",
tmp,
"--output_dir",
tmp,
"--quick_run",
"true",
"--summary_filename",
summary_filename,
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
stdout, stderr = process.communicate()
print(stdout)
if process.returncode:
print(stdout)
print(stderr)
assert False
assert os.path.exists(os.path.join(tmp, summary_filename))
@pytest.mark.gpu
@pytest.mark.integration
def test_ddp_abstractive_summarization_cnndm_transformers(scripts, tmp):
script = scripts["ddp_bertsumabs"]
summary_filename = "bertsumext_prediction.txt"
import subprocess
process = subprocess.Popen(
[
"python",
script,
"--data_dir",
tmp,
"--cache_dir",
tmp,
"--output_dir",
tmp,
"--quick_run",
"true",
"--summary_filename",
summary_filename,
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
stdout, stderr = process.communicate()
print(stdout)
if process.returncode:
print(stdout)
print(stderr)
assert False
raise RuntimeError("something bad happened")
assert os.path.exists(os.path.join(tmp, summary_filename))