only use 1 gpu for validation
This commit is contained in:
Родитель
6597c93148
Коммит
badbebb96d
|
@ -9,6 +9,7 @@ from collections import namedtuple
|
|||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
@ -335,9 +336,9 @@ def validate(summarizer, validate_sum_dataset, cache_dir):
|
|||
TOP_N = 8
|
||||
|
||||
src = validate_sum_dataset.source[0:TOP_N]
|
||||
reference_summaries = ["".join(t).rstrip("\n") for t in validate_sum_dataset.target[0:TOP_N]]
|
||||
reference_summaries = [" ".join(t).rstrip("\n") for t in validate_sum_dataset.target[0:TOP_N]]
|
||||
generated_summaries = summarizer.predict(
|
||||
shorten_dataset(validate_sum_dataset, top_n=TOP_N), num_gpus=2, batch_size=4
|
||||
shorten_dataset(validate_sum_dataset, top_n=TOP_N), num_gpus=1, batch_size=4
|
||||
)
|
||||
assert len(generated_summaries) == len(reference_summaries)
|
||||
for i in generated_summaries[0:1]:
|
||||
|
|
Загрузка…
Ссылка в новой задаче