This commit is contained in:
miguelgfierro 2020-03-22 00:52:13 +00:00
Родитель 084229b9be
Коммит aa4ac5976e
1 изменённых файлов: 3 добавлений и 13 удалений

Просмотреть файл

@ -7,11 +7,8 @@ from nltk import tokenize
import os
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
torch.set_printoptions(threshold=5000)
from tempfile import TemporaryDirectory
from utils_nlp.models.transformers.datasets import SummarizationDataset
from utils_nlp.models.transformers.abstractive_summarization_bertsum import (
@ -19,8 +16,6 @@ from utils_nlp.models.transformers.abstractive_summarization_bertsum import (
BertSumAbsProcessor,
validate,
)
from utils_nlp.eval.evaluate_summarization import get_rouge
# @pytest.fixture()
def source_data():
@ -49,7 +44,7 @@ def target_data():
]
NUM_GPUS = 2
#NUM_GPUS = 2
os.environ["NCCL_IB_DISABLE"] = "0"
@ -120,7 +115,7 @@ def test_train_model(tmp_module, test_dataset_for_bertsumabs, batch_size=1):
learning_rate_dec=0.2,
warmup_steps_bert=20000,
warmup_steps_dec=10000,
num_gpus=NUM_GPUS,
num_gpus=None,
report_every=10,
save_every=100,
validation_function=this_validate,
@ -143,7 +138,6 @@ def test_finetuned_model(
test_dataset_for_bertsumabs,
top_n=8,
batch_size=1,
num_gpus=4,
):
CACHE_PATH = (
tmp_module
@ -169,14 +163,13 @@ def test_finetuned_model(
)
summarizer.model.load_checkpoint(checkpoint["model"])
top_n = 50 # len(test_sum_dataset)
shortened_dataset = test_sum_dataset.shorten(top_n)
reference_summaries = [
"".join(t).rstrip("\n") for t in shortened_dataset.get_target()
]
print("start prediction")
generated_summaries = summarizer.predict(
shortened_dataset, batch_size=batch_size, num_gpus=num_gpus
shortened_dataset, batch_size=batch_size, num_gpus=None
)
def _write_list_to_file(list_items, filename):
@ -189,6 +182,3 @@ def test_finetuned_model(
_write_list_to_file(generated_summaries, os.path.join(CACHE_PATH, "prediction.txt"))
assert len(generated_summaries) == len(reference_summaries)
RESULT_DIR = TemporaryDirectory().name
rouge_score = get_rouge(generated_summaries, reference_summaries, RESULT_DIR)
print(rouge_score)