Add custom callback to avoid azureml 100 params limitation
This commit is contained in:
Родитель
b5ed0dee2c
Коммит
a083781f26
|
@ -10,6 +10,7 @@ from transformers import (
|
|||
HfArgumentParser,
|
||||
IntervalStrategy,
|
||||
)
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
|
||||
import torch
|
||||
import nltk
|
||||
|
@ -108,6 +109,21 @@ def compute_metrics(eval_preds, tokenizer, metric):
|
|||
return result
|
||||
|
||||
|
||||
class CustomCallback(TrainerCallback):
|
||||
"""A [`TrainerCallback`] that sends the logs to [AzureML](https://pypi.org/project/azureml-sdk/).
|
||||
|
||||
This is a hotfix for the issue raised here:
|
||||
https://github.com/huggingface/transformers/issues/18870
|
||||
"""
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
metrics = {}
|
||||
for k, v in logs.items():
|
||||
if isinstance(v, (int, float)):
|
||||
metrics[k] = v
|
||||
mlflow.log_metrics(metrics=metrics, step=state.global_step)
|
||||
|
||||
|
||||
def main():
|
||||
# Setup logging
|
||||
logger = logging.getLogger()
|
||||
|
@ -119,6 +135,9 @@ def main():
|
|||
console_handler.setFormatter(formatter)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
# initialize the mlflow session
|
||||
mlflow.start_run()
|
||||
|
||||
parser = HfArgumentParser((ModelArgs, DataArgs, Seq2SeqTrainingArguments))
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
logger.info(f"Running with arguments: {model_args}, {data_args}, {training_args}")
|
||||
|
@ -188,7 +207,7 @@ def main():
|
|||
training_args.save_strategy = "epoch"
|
||||
training_args.evaluation_strategy = IntervalStrategy.EPOCH
|
||||
training_args.predict_with_generate = True
|
||||
training_args.report_to = ["mlflow"]
|
||||
training_args.report_to = [] # use our own callback
|
||||
logger.info(f"training args: {training_args}")
|
||||
|
||||
# Initialize our Trainer
|
||||
|
@ -200,6 +219,7 @@ def main():
|
|||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=lambda preds : compute_metrics(preds, tokenizer, metric),
|
||||
callbacks=[CustomMetricCallback]
|
||||
)
|
||||
|
||||
# Start the actual training (to include evaluation use --do-eval)
|
||||
|
|
Загрузка…
Ссылка в новой задаче