diff --git a/utils_nlp/models/transformers/abssum.py b/utils_nlp/models/transformers/abssum.py index 80751c2..14ba566 100644 --- a/utils_nlp/models/transformers/abssum.py +++ b/utils_nlp/models/transformers/abssum.py @@ -9,6 +9,7 @@ from collections import namedtuple import itertools import logging import os +import pickle import random import numpy as np @@ -335,9 +336,9 @@ def validate(summarizer, validate_sum_dataset, cache_dir): TOP_N = 8 src = validate_sum_dataset.source[0:TOP_N] - reference_summaries = ["".join(t).rstrip("\n") for t in validate_sum_dataset.target[0:TOP_N]] + reference_summaries = [" ".join(t).rstrip("\n") for t in validate_sum_dataset.target[0:TOP_N]] generated_summaries = summarizer.predict( - shorten_dataset(validate_sum_dataset, top_n=TOP_N), num_gpus=2, batch_size=4 + shorten_dataset(validate_sum_dataset, top_n=TOP_N), num_gpus=1, batch_size=4 ) assert len(generated_summaries) == len(reference_summaries) for i in generated_summaries[0:1]: