fix bert2bert test (#10063)
Parent: 31563e056d
Commit: 9e795eac88
@@ -24,15 +24,9 @@ if is_datasets_available():
 class Seq2seqTrainerTester(TestCasePlus):
     @slow
     @require_datasets
     @require_torch
     def test_finetune_bert2bert(self):
-        """
-        Currently fails with:
-
-        ImportError: To be able to use this metric, you need to install the following dependencies['absl', 'nltk', 'rouge_score']
-        """
-
         bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
         tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
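The removed docstring records the old failure: datasets.load_metric("rouge") raises at load time when the absl, nltk, and rouge_score packages are missing, so the rouge2-based metric could not run in a bare test environment. The model setup itself is independent of the metric; a minimal standalone sketch of the same warm-start, where the greedy-decode smoke check at the end is an illustration and not part of the test:

from transformers import BertTokenizer, EncoderDecoderModel

# Warm-start a tiny seq2seq model from two copies of bert-tiny; the decoder copy
# is instantiated with causal masking and cross-attention layers added.
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Hypothetical smoke check: greedy-decode a few tokens from the untrained model.
inputs = tokenizer("This is a test.", return_tensors="pt")
generated = bert2bert.generate(
    inputs.input_ids,
    decoder_start_token_id=tokenizer.cls_token_id,
    max_length=16,
)
print(tokenizer.decode(generated[0], skip_special_tokens=True))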
@@ -47,8 +41,6 @@ class Seq2seqTrainerTester(TestCasePlus):
         train_dataset = train_dataset.select(range(32))
         val_dataset = val_dataset.select(range(16))
 
-        rouge = datasets.load_metric("rouge")
-
         batch_size = 4
 
         def _map_to_encoder_decoder_inputs(batch):
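The body of _map_to_encoder_decoder_inputs falls outside this hunk. For context, a sketch of the preprocessing such a function performs, assuming CNN/DailyMail-style "article"/"highlights" columns and 512/128 max lengths; the exact body is not shown in this diff:

def _map_to_encoder_decoder_inputs(batch):
    # Assumed column names and lengths; the real body lives outside the hunk.
    inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512)
    outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128)

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask
    batch["decoder_input_ids"] = outputs.input_ids
    batch["decoder_attention_mask"] = outputs.attention_mask

    # Mask pad positions with -100 so they are ignored by the loss.
    batch["labels"] = [
        [-100 if token == tokenizer.pad_token_id else token for token in labels]
        for labels in outputs.input_ids
    ]
    return batch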
@@ -78,15 +70,9 @@ class Seq2seqTrainerTester(TestCasePlus):
             pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
             label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
 
-            rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[
-                "rouge2"
-            ].mid
+            accuracy = sum([int(pred_str[i] == label_str[i]) for i in range(len(pred_str))]) / len(pred_str)
 
-            return {
-                "rouge2_precision": round(rouge_output.precision, 4),
-                "rouge2_recall": round(rouge_output.recall, 4),
-                "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
-            }
+            return {"accuracy": accuracy}
 
         # map train dataset
         train_dataset = train_dataset.map(
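After the change, the metric function reduces to a dependency-free exact-match accuracy. A consolidated sketch of the surviving lines, where the enclosing name _compute_metrics and the pred.label_ids/pred.predictions fields follow the usual Trainer compute_metrics contract and are assumptions here:

def _compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # Decode to plain text, dropping special tokens on both sides.
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    # Exact string match per example; no absl/nltk/rouge_score required.
    accuracy = sum([int(pred_str[i] == label_str[i]) for i in range(len(pred_str))]) / len(pred_str)

    return {"accuracy": accuracy}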