only use 1 gpu for validation
This commit is contained in:
Родитель
6597c93148
Коммит
badbebb96d
|
@ -9,6 +9,7 @@ from collections import namedtuple
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -335,9 +336,9 @@ def validate(summarizer, validate_sum_dataset, cache_dir):
|
||||||
TOP_N = 8
|
TOP_N = 8
|
||||||
|
|
||||||
src = validate_sum_dataset.source[0:TOP_N]
|
src = validate_sum_dataset.source[0:TOP_N]
|
||||||
reference_summaries = ["".join(t).rstrip("\n") for t in validate_sum_dataset.target[0:TOP_N]]
|
reference_summaries = [" ".join(t).rstrip("\n") for t in validate_sum_dataset.target[0:TOP_N]]
|
||||||
generated_summaries = summarizer.predict(
|
generated_summaries = summarizer.predict(
|
||||||
shorten_dataset(validate_sum_dataset, top_n=TOP_N), num_gpus=2, batch_size=4
|
shorten_dataset(validate_sum_dataset, top_n=TOP_N), num_gpus=1, batch_size=4
|
||||||
)
|
)
|
||||||
assert len(generated_summaries) == len(reference_summaries)
|
assert len(generated_summaries) == len(reference_summaries)
|
||||||
for i in generated_summaries[0:1]:
|
for i in generated_summaries[0:1]:
|
||||||
|
|
Загрузка…
Ссылка в новой задаче