This commit is contained in:
Daisy Deng 2020-02-24 21:27:31 +00:00
Родитель 6597c93148
Коммит badbebb96d
1 изменённых файлов: 3 добавлений и 2 удалений

Просмотреть файл

@ -9,6 +9,7 @@ from collections import namedtuple
import itertools import itertools
import logging import logging
import os import os
import pickle
import random import random
import numpy as np import numpy as np
@ -335,9 +336,9 @@ def validate(summarizer, validate_sum_dataset, cache_dir):
TOP_N = 8 TOP_N = 8
src = validate_sum_dataset.source[0:TOP_N] src = validate_sum_dataset.source[0:TOP_N]
reference_summaries = ["".join(t).rstrip("\n") for t in validate_sum_dataset.target[0:TOP_N]] reference_summaries = [" ".join(t).rstrip("\n") for t in validate_sum_dataset.target[0:TOP_N]]
generated_summaries = summarizer.predict( generated_summaries = summarizer.predict(
shorten_dataset(validate_sum_dataset, top_n=TOP_N), num_gpus=2, batch_size=4 shorten_dataset(validate_sum_dataset, top_n=TOP_N), num_gpus=1, batch_size=4
) )
assert len(generated_summaries) == len(reference_summaries) assert len(generated_summaries) == len(reference_summaries)
for i in generated_summaries[0:1]: for i in generated_summaries[0:1]: