Fix mlflow param overflow clean (#10071)
* Unify logging with f-strings
* Get limits from MLflow rather than hardcoding them
* Add a check for parameter length overflow; the limit constants are also marked as internal
* Don't stop the run in on_train_end: doing so causes bad behaviour when there is a separate validation step, since validation then gets recorded as a separate run
* Fix style
This commit is contained in:
Parent: ece6c51458
Commit: ddaafd78fb
@@ -707,12 +707,13 @@ class MLflowCallback(TrainerCallback):
     A :class:`~transformers.TrainerCallback` that sends the logs to `MLflow <https://www.mlflow.org/>`__.
     """
 
-    MAX_LOG_SIZE = 100
-
     def __init__(self):
         assert is_mlflow_available(), "MLflowCallback requires mlflow to be installed. Run `pip install mlflow`."
         import mlflow
 
+        self._MAX_PARAM_VAL_LENGTH = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
+        self._MAX_PARAMS_TAGS_PER_BATCH = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH
+
         self._initialized = False
         self._log_artifacts = False
         self._ml_flow = mlflow
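This hunk replaces the hardcoded class constant with the limits MLflow itself enforces. A minimal sketch of what the new constructor reads; the concrete values shown are those in MLflow at the time of this commit and may differ in other MLflow versions:

from mlflow.utils.validation import MAX_PARAM_VAL_LENGTH, MAX_PARAMS_TAGS_PER_BATCH

# Reading the limits from MLflow's own validation module keeps them
# in sync with whatever the installed MLflow version enforces.
print(MAX_PARAM_VAL_LENGTH)       # 250 at the time of this commit
print(MAX_PARAMS_TAGS_PER_BATCH)  # 100 at the time of this commit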
@@ -738,10 +739,21 @@
             if hasattr(model, "config") and model.config is not None:
                 model_config = model.config.to_dict()
                 combined_dict = {**model_config, **combined_dict}
+            # remove params that are too long for MLflow
+            for name, value in list(combined_dict.items()):
+                # internally, all values are converted to str in MLflow
+                if len(str(value)) > self._MAX_PARAM_VAL_LENGTH:
+                    logger.warning(
+                        f"Trainer is attempting to log a value of "
+                        f'"{value}" for key "{name}" as a parameter. '
+                        f"MLflow's log_param() only accepts values no longer than "
+                        f"250 characters so we dropped this attribute."
+                    )
+                    del combined_dict[name]
             # MLflow cannot log more than 100 values in one go, so we have to split it
             combined_dict_items = list(combined_dict.items())
-            for i in range(0, len(combined_dict_items), MLflowCallback.MAX_LOG_SIZE):
-                self._ml_flow.log_params(dict(combined_dict_items[i : i + MLflowCallback.MAX_LOG_SIZE]))
+            for i in range(0, len(combined_dict_items), self._MAX_PARAMS_TAGS_PER_BATCH):
+                self._ml_flow.log_params(dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]))
         self._initialized = True
 
     def on_train_begin(self, args, state, control, model=None, **kwargs):
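Taken together, the patched setup() now does two things: drop any value whose string form exceeds MLflow's per-value limit, then log the survivors in batches MLflow will accept. A standalone sketch of the same two steps, where log_params_safely is a hypothetical helper for illustration, not part of this patch:

import mlflow
from mlflow.utils.validation import MAX_PARAM_VAL_LENGTH, MAX_PARAMS_TAGS_PER_BATCH

def log_params_safely(params):
    # Hypothetical helper mirroring the patched setup() logic:
    # 1) drop values whose str() form exceeds MLflow's per-value limit,
    # 2) log the rest in batches no larger than MLflow's batch limit.
    kept = {k: v for k, v in params.items() if len(str(v)) <= MAX_PARAM_VAL_LENGTH}
    items = list(kept.items())
    for i in range(0, len(items), MAX_PARAMS_TAGS_PER_BATCH):
        mlflow.log_params(dict(items[i : i + MAX_PARAMS_TAGS_PER_BATCH]))

with mlflow.start_run():
    # The second value is 300 characters long, so it is silently dropped here
    # (the real callback logs a warning instead).
    log_params_safely({"learning_rate": 5e-5, "id2label": "x" * 300})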
@@ -757,13 +769,10 @@
                     self._ml_flow.log_metric(k, v, step=state.global_step)
                 else:
                     logger.warning(
-                        "Trainer is attempting to log a value of "
-                        '"%s" of type %s for key "%s" as a metric. '
-                        "MLflow's log_metric() only accepts float and "
-                        "int types so we dropped this attribute.",
-                        v,
-                        type(v),
-                        k,
+                        f"Trainer is attempting to log a value of "
+                        f'"{v}" of type {type(v)} for key "{k}" as a metric. '
+                        f"MLflow's log_metric() only accepts float and "
+                        f"int types so we dropped this attribute."
+                    )
 
     def on_train_end(self, args, state, control, **kwargs):
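This hunk only unifies formatting style: the %-style lazy logging call becomes an f-string, matching the new parameter warning above. Both forms produce the same message; a short sketch of the trade-off, with illustrative shortened messages:

import logging

logger = logging.getLogger(__name__)
v, k = [0.1, 0.2], "label_smoothing"

# %-style defers string formatting until the record is actually handled:
logger.warning('Dropped "%s" of type %s for key "%s".', v, type(v), k)
# An f-string formats eagerly but reads as a single literal,
# which is the style this commit standardizes on:
logger.warning(f'Dropped "{v}" of type {type(v)} for key "{k}".')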
@@ -771,13 +780,12 @@
         if self._log_artifacts:
             logger.info("Logging artifacts. This may take time.")
             self._ml_flow.log_artifacts(args.output_dir)
-        self._ml_flow.end_run()
 
     def __del__(self):
         # if the previous run is not terminated correctly, the fluent API will
         # not let you start a new run before the previous one is killed
         if self._ml_flow.active_run is not None:
-            self._ml_flow.end_run(status="KILLED")
+            self._ml_flow.end_run()
 
 
 INTEGRATION_TO_CALLBACK = {
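The removed end_run() in on_train_end is the heart of the fix: with MLflow's fluent API, once the active run is ended, the next logging call implicitly starts a fresh run, so an evaluation step after training would be recorded under a different run id. A minimal sketch of the failure mode, with illustrative metric names:

import mlflow

mlflow.start_run()
mlflow.log_metric("train_loss", 0.5, step=100)
mlflow.end_run()  # what on_train_end used to do

# With no active run, the fluent API auto-starts a brand-new run here,
# so the evaluation metric lands in a separate run from training.
mlflow.log_metric("eval_loss", 0.4)
mlflow.end_run()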