From badbebb96d43b9ba9cd5c7d11ed774af25a5f6a9 Mon Sep 17 00:00:00 2001 From: Daisy Deng Date: Mon, 24 Feb 2020 21:27:31 +0000 Subject: [PATCH] only use 1 gpu for validation --- utils_nlp/models/transformers/abssum.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/utils_nlp/models/transformers/abssum.py b/utils_nlp/models/transformers/abssum.py index 80751c2..14ba566 100644 --- a/utils_nlp/models/transformers/abssum.py +++ b/utils_nlp/models/transformers/abssum.py @@ -9,6 +9,7 @@ from collections import namedtuple import itertools import logging import os +import pickle import random import numpy as np @@ -335,9 +336,9 @@ def validate(summarizer, validate_sum_dataset, cache_dir): TOP_N = 8 src = validate_sum_dataset.source[0:TOP_N] - reference_summaries = ["".join(t).rstrip("\n") for t in validate_sum_dataset.target[0:TOP_N]] + reference_summaries = [" ".join(t).rstrip("\n") for t in validate_sum_dataset.target[0:TOP_N]] generated_summaries = summarizer.predict( - shorten_dataset(validate_sum_dataset, top_n=TOP_N), num_gpus=2, batch_size=4 + shorten_dataset(validate_sum_dataset, top_n=TOP_N), num_gpus=1, batch_size=4 ) assert len(generated_summaries) == len(reference_summaries) for i in generated_summaries[0:1]: