add distributed test for bertsum
This commit is contained in:
Родитель
946ae06fb6
Коммит
c44920b988
|
@ -24,6 +24,25 @@ from utils_nlp.azureml import azureml_utils
|
|||
from azureml.core.webservice import Webservice
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def scripts():
|
||||
folder_notebooks = path_notebooks()
|
||||
paths = {
|
||||
"ddp_bertsumext": os.path.join(
|
||||
folder_notebooks,
|
||||
"text_summarization",
|
||||
"extractive_summarization_cnndm_distributed_train.py",
|
||||
),
|
||||
"ddp_bertsumabs": os.path.join(
|
||||
folder_notebooks,
|
||||
"text_summarization",
|
||||
"abstractive_summarization_bertsum_cnndm_distributed_train.py",
|
||||
),
|
||||
}
|
||||
|
||||
return paths
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def notebooks():
|
||||
folder_notebooks = path_notebooks()
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
@pytest.mark.integration
|
||||
def test_ddp_extractive_summarization_cnndm_transformers(scripts, tmp):
|
||||
script = scripts["ddp_bertsumext"]
|
||||
summary_filename = "bertsumext_prediction.txt"
|
||||
import subprocess
|
||||
|
||||
process = subprocess.Popen(
|
||||
[
|
||||
"python",
|
||||
script,
|
||||
"--data_dir",
|
||||
tmp,
|
||||
"--cache_dir",
|
||||
tmp,
|
||||
"--output_dir",
|
||||
tmp,
|
||||
"--quick_run",
|
||||
"true",
|
||||
"--summary_filename",
|
||||
summary_filename,
|
||||
],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = process.communicate()
|
||||
print(stdout)
|
||||
if process.returncode:
|
||||
print(stdout)
|
||||
print(stderr)
|
||||
assert False
|
||||
assert os.path.exists(os.path.join(tmp, summary_filename))
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
@pytest.mark.integration
|
||||
def test_ddp_abstractive_summarization_cnndm_transformers(scripts, tmp):
|
||||
script = scripts["ddp_bertsumabs"]
|
||||
summary_filename = "bertsumext_prediction.txt"
|
||||
import subprocess
|
||||
|
||||
process = subprocess.Popen(
|
||||
[
|
||||
"python",
|
||||
script,
|
||||
"--data_dir",
|
||||
tmp,
|
||||
"--cache_dir",
|
||||
tmp,
|
||||
"--output_dir",
|
||||
tmp,
|
||||
"--quick_run",
|
||||
"true",
|
||||
"--summary_filename",
|
||||
summary_filename,
|
||||
],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = process.communicate()
|
||||
print(stdout)
|
||||
if process.returncode:
|
||||
print(stdout)
|
||||
print(stderr)
|
||||
assert False
|
||||
raise RuntimeError("something bad happened")
|
||||
assert os.path.exists(os.path.join(tmp, summary_filename))
|
Загрузка…
Ссылка в новой задаче