[examples/run_s2s] remove task_specific_params and update rouge computation (#10133)

* fix rouge metrics and task specific params

* fix typo

* round metrics

* typo

* remove task_specific_params
This commit is contained in:
Suraj Patil 2021-02-12 17:18:21 +05:30 коммит произвёл GitHub
Родитель 31245775e5
Коммит f51188cbe7
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 44 добавлений и 19 удалений

Просмотреть файл

@ -25,10 +25,12 @@ import sys
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional
import nltk # Here to have a nice missing dependency error message early on
import numpy as np import numpy as np
from datasets import load_dataset, load_metric from datasets import load_dataset, load_metric
import transformers import transformers
from filelock import FileLock
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
AutoModelForSeq2SeqLM, AutoModelForSeq2SeqLM,
@ -44,6 +46,10 @@ from transformers import (
from transformers.trainer_utils import get_last_checkpoint, is_main_process from transformers.trainer_utils import get_last_checkpoint, is_main_process
with FileLock(".lock") as lock:
nltk.download("punkt", quiet=True)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -110,10 +116,22 @@ class DataTrainingArguments:
default=None, default=None,
metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
) )
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) train_file: Optional[str] = field(
default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
)
validation_file: Optional[str] = field( validation_file: Optional[str] = field(
default=None, default=None,
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, metadata={
"help": "An optional input evaluation data file to evaluate the metrics (rouge/sacreblue) on "
"(a jsonlines or csv file)."
},
)
test_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input test data file to evaluate the metrics (rouge/sacreblue) on "
"(a jsonlines or csv file)."
},
) )
overwrite_cache: bool = field( overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
@ -298,6 +316,9 @@ def main():
if data_args.validation_file is not None: if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file data_files["validation"] = data_args.validation_file
extension = data_args.validation_file.split(".")[-1] extension = data_args.validation_file.split(".")[-1]
if data_args.test_file is not None:
data_files["test"] = data_args.test_file
extension = data_args.test_file.split(".")[-1]
datasets = load_dataset(extension, data_files=data_files) datasets = load_dataset(extension, data_files=data_files)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html. # https://huggingface.co/docs/datasets/loading_datasets.html.
@ -335,15 +356,7 @@ def main():
if model.config.decoder_start_token_id is None: if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
# Get the default prefix if None is passed. prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
if data_args.source_prefix is None:
task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
prefix = task_specific_params.get("prefix", "")
else:
prefix = ""
else:
prefix = data_args.source_prefix
# Preprocessing the datasets. # Preprocessing the datasets.
# We need to tokenize inputs and targets. # We need to tokenize inputs and targets.
@ -487,6 +500,19 @@ def main():
metric_name = "rouge" if data_args.task.startswith("summarization") else "sacrebleu" metric_name = "rouge" if data_args.task.startswith("summarization") else "sacrebleu"
metric = load_metric(metric_name) metric = load_metric(metric_name)
def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds]
labels = [label.strip() for label in labels]
# rougeLSum expects newline after each sentence
if metric_name == "rouge":
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
else: # sacrebleu
labels = [[label] for label in labels]
return preds, labels
def compute_metrics(eval_preds): def compute_metrics(eval_preds):
preds, labels = eval_preds preds, labels = eval_preds
if isinstance(preds, tuple): if isinstance(preds, tuple):
@ -498,22 +524,19 @@ def main():
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
# Some simple post-processing # Some simple post-processing
decoded_preds = [pred.strip() for pred in decoded_preds] decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
decoded_labels = [label.strip() for label in decoded_labels]
if metric_name == "sacrebleu":
decoded_labels = [[label] for label in decoded_labels]
result = metric.compute(predictions=decoded_preds, references=decoded_labels)
# Extract a few results from ROUGE
if metric_name == "rouge": if metric_name == "rouge":
result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
# Extract a few results from ROUGE
result = {key: value.mid.fmeasure * 100 for key, value in result.items()} result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
else: else:
result = metric.compute(predictions=decoded_preds, references=decoded_labels)
result = {"bleu": result["score"]} result = {"bleu": result["score"]}
prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
result["gen_len"] = np.mean(prediction_lens) result["gen_len"] = np.mean(prediction_lens)
result = {k: round(v, 4) for k, v in result.items()}
return result return result
# Initialize our Trainer # Initialize our Trainer
@ -555,6 +578,7 @@ def main():
logger.info("*** Evaluate ***") logger.info("*** Evaluate ***")
results = trainer.evaluate(max_length=data_args.val_max_target_length, num_beams=data_args.num_beams) results = trainer.evaluate(max_length=data_args.val_max_target_length, num_beams=data_args.num_beams)
results = {k: round(v, 4) for k, v in results.items()}
output_eval_file = os.path.join(training_args.output_dir, "eval_results_seq2seq.txt") output_eval_file = os.path.join(training_args.output_dir, "eval_results_seq2seq.txt")
if trainer.is_world_process_zero(): if trainer.is_world_process_zero():
@ -574,6 +598,7 @@ def main():
num_beams=data_args.num_beams, num_beams=data_args.num_beams,
) )
test_metrics = test_results.metrics test_metrics = test_results.metrics
test_metrics["test_loss"] = round(test_metrics["test_loss"], 4)
output_test_result_file = os.path.join(training_args.output_dir, "test_results_seq2seq.txt") output_test_result_file = os.path.join(training_args.output_dir, "test_results_seq2seq.txt")
if trainer.is_world_process_zero(): if trainer.is_world_process_zero():