This commit is contained in:
Patrick von Platen 2021-02-08 16:04:28 +03:00 коммит произвёл GitHub
Родитель 31563e056d
Коммит 9e795eac88
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 3 добавлений и 17 удалений

Просмотреть файл

@ -24,15 +24,9 @@ if is_datasets_available():
class Seq2seqTrainerTester(TestCasePlus):
@slow
@require_datasets
@require_torch
@require_datasets
def test_finetune_bert2bert(self):
"""
Currently fails with:
ImportError: To be able to use this metric, you need to install the following dependencies['absl', 'nltk', 'rouge_score']
"""
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
@ -47,8 +41,6 @@ class Seq2seqTrainerTester(TestCasePlus):
train_dataset = train_dataset.select(range(32))
val_dataset = val_dataset.select(range(16))
rouge = datasets.load_metric("rouge")
batch_size = 4
def _map_to_encoder_decoder_inputs(batch):
@ -78,15 +70,9 @@ class Seq2seqTrainerTester(TestCasePlus):
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[
"rouge2"
].mid
accuracy = sum([int(pred_str[i] == label_str[i]) for i in range(len(pred_str))]) / len(pred_str)
return {
"rouge2_precision": round(rouge_output.precision, 4),
"rouge2_recall": round(rouge_output.recall, 4),
"rouge2_fmeasure": round(rouge_output.fmeasure, 4),
}
return {"accuracy": accuracy}
# map train dataset
train_dataset = train_dataset.map(