Fix mlflow param overflow clean (#10071)

* Unify logging with f-strings

* Get limits from MLflow rather than hardcode

* Add a check for parameter length overflow

Also, the constants are marked as internal

* Don't stop run in on_train_end

This causes bad behaviour when there is a separate validation step:
the validation gets recorded as a separate run.

* Fix style
noise-field 2021-02-08 19:58:02 +03:00 committed by GitHub
Parent ece6c51458
Commit ddaafd78fb
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
1 changed file with 21 additions and 13 deletions


@@ -707,12 +707,13 @@ class MLflowCallback(TrainerCallback):
     A :class:`~transformers.TrainerCallback` that sends the logs to `MLflow <https://www.mlflow.org/>`__.
     """

-    MAX_LOG_SIZE = 100
-
     def __init__(self):
         assert is_mlflow_available(), "MLflowCallback requires mlflow to be installed. Run `pip install mlflow`."
         import mlflow

+        self._MAX_PARAM_VAL_LENGTH = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
+        self._MAX_PARAMS_TAGS_PER_BATCH = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH
+
         self._initialized = False
         self._log_artifacts = False
         self._ml_flow = mlflow
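
For reference, the two limits now come straight from MLflow's validation module instead of a hardcoded class constant. A minimal sketch to inspect them (the values shown are the ones shipped with mlflow 1.x at the time of this commit and may differ in later releases):

import mlflow.utils.validation

print(mlflow.utils.validation.MAX_PARAM_VAL_LENGTH)       # 250: max characters per param value
print(mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH)  # 100: max params/tags per log_params batch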
@@ -738,10 +739,21 @@
             if hasattr(model, "config") and model.config is not None:
                 model_config = model.config.to_dict()
                 combined_dict = {**model_config, **combined_dict}
+            # remove params that are too long for MLflow
+            for name, value in list(combined_dict.items()):
+                # internally, all values are converted to str in MLflow
+                if len(str(value)) > self._MAX_PARAM_VAL_LENGTH:
+                    logger.warning(
+                        f"Trainer is attempting to log a value of "
+                        f'"{value}" for key "{name}" as a parameter. '
+                        f"MLflow's log_param() only accepts values no longer than "
+                        f"250 characters so we dropped this attribute."
+                    )
+                    del combined_dict[name]
             # MLflow cannot log more than 100 values in one go, so we have to split it
             combined_dict_items = list(combined_dict.items())
-            for i in range(0, len(combined_dict_items), MLflowCallback.MAX_LOG_SIZE):
-                self._ml_flow.log_params(dict(combined_dict_items[i : i + MLflowCallback.MAX_LOG_SIZE]))
+            for i in range(0, len(combined_dict_items), self._MAX_PARAMS_TAGS_PER_BATCH):
+                self._ml_flow.log_params(dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]))
         self._initialized = True

     def on_train_begin(self, args, state, control, model=None, **kwargs):
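
A minimal sketch of the overflow this hunk guards against, assuming mlflow 1.x behaviour where an over-long parameter value makes log_param() raise: model configs routinely contain fields (e.g. id2label) whose string form exceeds 250 characters, and before this change a single such value aborted setup().

import mlflow
from mlflow.exceptions import MlflowException

long_value = str({i: f"LABEL_{i}" for i in range(100)})  # far longer than 250 chars

with mlflow.start_run():
    try:
        mlflow.log_param("id2label", long_value)  # raises on mlflow 1.x
    except MlflowException:
        print("rejected; the callback now drops such values with a warning instead")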
@@ -757,13 +769,10 @@
                     self._ml_flow.log_metric(k, v, step=state.global_step)
                 else:
                     logger.warning(
-                        "Trainer is attempting to log a value of "
-                        '"%s" of type %s for key "%s" as a metric. '
-                        "MLflow's log_metric() only accepts float and "
-                        "int types so we dropped this attribute.",
-                        v,
-                        type(v),
-                        k,
+                        f"Trainer is attempting to log a value of "
+                        f'"{v}" of type {type(v)} for key "{k}" as a metric. '
+                        f"MLflow's log_metric() only accepts float and "
+                        f"int types so we dropped this attribute."
                     )

     def on_train_end(self, args, state, control, **kwargs):
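
The hunk above is purely a style unification: the lazy %-style arguments and the f-string render the same message, the f-string just formats eagerly even when warnings are filtered out. A standalone illustration:

import logging

logging.basicConfig()
logger = logging.getLogger(__name__)

k, v = "eval_accuracy", "high"
logger.warning('value "%s" of type %s for key "%s" was dropped', v, type(v), k)  # lazy
logger.warning(f'value "{v}" of type {type(v)} for key "{k}" was dropped')       # eager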
@@ -771,13 +780,12 @@
             if self._log_artifacts:
                 logger.info("Logging artifacts. This may take time.")
                 self._ml_flow.log_artifacts(args.output_dir)
-            self._ml_flow.end_run()

     def __del__(self):
         # if the previous run is not terminated correctly, the fluent API will
         # not let you start a new run before the previous one is killed
         if self._ml_flow.active_run is not None:
-            self._ml_flow.end_run(status="KILLED")
+            self._ml_flow.end_run()


 INTEGRATION_TO_CALLBACK = {
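
For background on the removed end_run() call: MLflow's fluent API implicitly starts a new run when you log without an active one, so ending the run in on_train_end made a subsequent evaluation step show up as its own run. A minimal sketch of that behaviour:

import mlflow

mlflow.start_run()
mlflow.log_metric("train_loss", 0.1)
mlflow.end_run()                     # what on_train_end used to do

mlflow.log_metric("eval_loss", 0.2)  # no active run, so the fluent API opens a new one
print(mlflow.active_run().info.run_id)  # a different run than the training one
mlflow.end_run()

Note that mlflow.active_run is a function, so the guard in __del__ compares the function object itself (always truthy) rather than its result; calling it as self._ml_flow.active_run() would check the actual run state.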