[examples/run_s2s] remove task_specific_params and update rouge computation (#10133)
* fix rouge metrics and task specific params * fix typo * round metrics * typo * remove task_specific_params
This commit is contained in:
Родитель
31245775e5
Коммит
f51188cbe7
|
@ -25,10 +25,12 @@ import sys
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import nltk # Here to have a nice missing dependency error message early on
|
||||
import numpy as np
|
||||
from datasets import load_dataset, load_metric
|
||||
|
||||
import transformers
|
||||
from filelock import FileLock
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSeq2SeqLM,
|
||||
|
@ -44,6 +46,10 @@ from transformers import (
|
|||
from transformers.trainer_utils import get_last_checkpoint, is_main_process
|
||||
|
||||
|
||||
with FileLock(".lock") as lock:
|
||||
nltk.download("punkt", quiet=True)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -110,10 +116,22 @@ class DataTrainingArguments:
|
|||
default=None,
|
||||
metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
|
||||
)
|
||||
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
|
||||
train_file: Optional[str] = field(
|
||||
default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
|
||||
)
|
||||
validation_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
||||
metadata={
|
||||
"help": "An optional input evaluation data file to evaluate the metrics (rouge/sacreblue) on "
|
||||
"(a jsonlines or csv file)."
|
||||
},
|
||||
)
|
||||
test_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "An optional input test data file to evaluate the metrics (rouge/sacreblue) on "
|
||||
"(a jsonlines or csv file)."
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
|
@ -298,6 +316,9 @@ def main():
|
|||
if data_args.validation_file is not None:
|
||||
data_files["validation"] = data_args.validation_file
|
||||
extension = data_args.validation_file.split(".")[-1]
|
||||
if data_args.test_file is not None:
|
||||
data_files["test"] = data_args.test_file
|
||||
extension = data_args.test_file.split(".")[-1]
|
||||
datasets = load_dataset(extension, data_files=data_files)
|
||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||
|
@ -335,15 +356,7 @@ def main():
|
|||
if model.config.decoder_start_token_id is None:
|
||||
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
||||
|
||||
# Get the default prefix if None is passed.
|
||||
if data_args.source_prefix is None:
|
||||
task_specific_params = model.config.task_specific_params
|
||||
if task_specific_params is not None:
|
||||
prefix = task_specific_params.get("prefix", "")
|
||||
else:
|
||||
prefix = ""
|
||||
else:
|
||||
prefix = data_args.source_prefix
|
||||
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
||||
|
||||
# Preprocessing the datasets.
|
||||
# We need to tokenize inputs and targets.
|
||||
|
@ -487,6 +500,19 @@ def main():
|
|||
metric_name = "rouge" if data_args.task.startswith("summarization") else "sacrebleu"
|
||||
metric = load_metric(metric_name)
|
||||
|
||||
def postprocess_text(preds, labels):
|
||||
preds = [pred.strip() for pred in preds]
|
||||
labels = [label.strip() for label in labels]
|
||||
|
||||
# rougeLSum expects newline after each sentence
|
||||
if metric_name == "rouge":
|
||||
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
|
||||
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
|
||||
else: # sacrebleu
|
||||
labels = [[label] for label in labels]
|
||||
|
||||
return preds, labels
|
||||
|
||||
def compute_metrics(eval_preds):
|
||||
preds, labels = eval_preds
|
||||
if isinstance(preds, tuple):
|
||||
|
@ -498,22 +524,19 @@ def main():
|
|||
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||
|
||||
# Some simple post-processing
|
||||
decoded_preds = [pred.strip() for pred in decoded_preds]
|
||||
decoded_labels = [label.strip() for label in decoded_labels]
|
||||
if metric_name == "sacrebleu":
|
||||
decoded_labels = [[label] for label in decoded_labels]
|
||||
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
|
||||
|
||||
result = metric.compute(predictions=decoded_preds, references=decoded_labels)
|
||||
|
||||
# Extract a few results from ROUGE
|
||||
if metric_name == "rouge":
|
||||
result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
|
||||
# Extract a few results from ROUGE
|
||||
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
|
||||
else:
|
||||
result = metric.compute(predictions=decoded_preds, references=decoded_labels)
|
||||
result = {"bleu": result["score"]}
|
||||
|
||||
prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
|
||||
result["gen_len"] = np.mean(prediction_lens)
|
||||
|
||||
result = {k: round(v, 4) for k, v in result.items()}
|
||||
return result
|
||||
|
||||
# Initialize our Trainer
|
||||
|
@ -555,6 +578,7 @@ def main():
|
|||
logger.info("*** Evaluate ***")
|
||||
|
||||
results = trainer.evaluate(max_length=data_args.val_max_target_length, num_beams=data_args.num_beams)
|
||||
results = {k: round(v, 4) for k, v in results.items()}
|
||||
|
||||
output_eval_file = os.path.join(training_args.output_dir, "eval_results_seq2seq.txt")
|
||||
if trainer.is_world_process_zero():
|
||||
|
@ -574,6 +598,7 @@ def main():
|
|||
num_beams=data_args.num_beams,
|
||||
)
|
||||
test_metrics = test_results.metrics
|
||||
test_metrics["test_loss"] = round(test_metrics["test_loss"], 4)
|
||||
|
||||
output_test_result_file = os.path.join(training_args.output_dir, "test_results_seq2seq.txt")
|
||||
if trainer.is_world_process_zero():
|
||||
|
|
Загрузка…
Ссылка в новой задаче