This commit is contained in:
hlums 2020-02-27 22:12:57 +00:00
Parent 520d0dd937
Commit bbc1287783
1 changed file: 0 additions and 42 deletions

View file

@ -254,47 +254,6 @@ class S2SAbsSumProcessor:
input_lines = sorted(list(enumerate(input_lines)), key=lambda x: -len(x[1]))
return S2SAbsSumDataset(input_lines)
# def test_dataset_from_iterable_sum_ds():
# input_lines = []
# for src in sum_ds:
# example = {"src": src}
# input_lines.append(self._preprocess_test_src(example))
# input_lines = sorted(list(enumerate(input_lines)), key=lambda x: -len(x[1]))
# return S2SAbsSumDataset(input_lines)
# def test_dataset_from_sum_ds(self, sum_ds):
# input_lines = []
# for example in sum_ds:
# input_lines.append(self._preprocess_test_src(example))
# input_lines = sorted(list(enumerate(input_lines)), key=lambda x: -len(x[1]))
# return S2SAbsSumDataset(input_lines)
# def test_dataset_from_file(self, test_file):
# with open(test_file, encoding="utf-8", mode="r") as fin:
# input_lines = []
# for line in fin:
# example = json.loads(line)
# input_lines.append(self._preprocess_test_src(example))
# input_lines = sorted(list(enumerate(input_lines)), key=lambda x: -len(x[1]))
# return S2SAbsSumDataset(input_lines)
def _preprocess_test_src(self, example):
    """Turn one test example's "src" field into a list of source tokens.

    A pre-tokenized source (already a list) is used as-is; a raw string
    is tokenized with ``self.tokenizer``. For every model type other
    than roberta, each "[X_SEP]" placeholder token is replaced by the
    tokenizer's newline-like token, which is recovered by tokenizing a
    two-word probe string containing an embedded newline and taking the
    second piece.

    Args:
        example: mapping with a "src" key holding either a string or a
            list of tokens.

    Returns:
        list: the (possibly rewritten) source tokens.
    """
    raw_src = example["src"]
    if isinstance(raw_src, list):
        tokens = raw_src
    else:
        tokens = self.tokenizer.tokenize(raw_src)
    # roberta keeps "[X_SEP]" untouched; everyone else swaps it out.
    if self._model_type == "roberta":
        return tokens
    newline_token = self.tokenizer.tokenize("Enter\nToken")[1]
    return [newline_token if tok == "[X_SEP]" else tok for tok in tokens]
class S2SConfig:
"""This class contains some default decoding settings that the users usually
@ -556,7 +515,6 @@ class S2SAbstractiveSummarizer(Transformer):
scheduler=self.scheduler,
local_rank=local_rank,
fp16=fp16,
fp16_opt_level=fp16_opt_level,
amp=amp,
max_grad_norm=max_grad_norm,
verbose=verbose,