diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py
index bc4e2a3b7..7c73803e4 100644
--- a/src/transformers/integrations.py
+++ b/src/transformers/integrations.py
@@ -707,12 +707,13 @@ class MLflowCallback(TrainerCallback):
     A :class:`~transformers.TrainerCallback` that sends the logs to `MLflow <https://www.mlflow.org/>`__.
     """
 
-    MAX_LOG_SIZE = 100
-
     def __init__(self):
         assert is_mlflow_available(), "MLflowCallback requires mlflow to be installed. Run `pip install mlflow`."
         import mlflow
 
+        self._MAX_PARAM_VAL_LENGTH = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
+        self._MAX_PARAMS_TAGS_PER_BATCH = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH
+
         self._initialized = False
         self._log_artifacts = False
         self._ml_flow = mlflow
@@ -738,10 +739,21 @@ class MLflowCallback(TrainerCallback):
             if hasattr(model, "config") and model.config is not None:
                 model_config = model.config.to_dict()
                 combined_dict = {**model_config, **combined_dict}
+            # remove params that are too long for MLflow
+            for name, value in list(combined_dict.items()):
+                # internally, all values are converted to str in MLflow
+                if len(str(value)) > self._MAX_PARAM_VAL_LENGTH:
+                    logger.warning(
+                        f"Trainer is attempting to log a value of "
+                        f'"{value}" for key "{name}" as a parameter. '
+                        f"MLflow's log_param() only accepts values no longer than "
+                        f"250 characters so we dropped this attribute."
+                    )
+                    del combined_dict[name]
             # MLflow cannot log more than 100 values in one go, so we have to split it
             combined_dict_items = list(combined_dict.items())
-            for i in range(0, len(combined_dict_items), MLflowCallback.MAX_LOG_SIZE):
-                self._ml_flow.log_params(dict(combined_dict_items[i : i + MLflowCallback.MAX_LOG_SIZE]))
+            for i in range(0, len(combined_dict_items), self._MAX_PARAMS_TAGS_PER_BATCH):
+                self._ml_flow.log_params(dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]))
         self._initialized = True
 
     def on_train_begin(self, args, state, control, model=None, **kwargs):
@@ -757,13 +769,10 @@ class MLflowCallback(TrainerCallback):
                     self._ml_flow.log_metric(k, v, step=state.global_step)
                 else:
                     logger.warning(
-                        "Trainer is attempting to log a value of "
-                        '"%s" of type %s for key "%s" as a metric. '
-                        "MLflow's log_metric() only accepts float and "
-                        "int types so we dropped this attribute.",
-                        v,
-                        type(v),
-                        k,
+                        f"Trainer is attempting to log a value of "
+                        f'"{v}" of type {type(v)} for key "{k}" as a metric. '
+                        f"MLflow's log_metric() only accepts float and "
+                        f"int types so we dropped this attribute."
                     )
 
     def on_train_end(self, args, state, control, **kwargs):
@@ -771,13 +780,12 @@ class MLflowCallback(TrainerCallback):
         if self._log_artifacts:
             logger.info("Logging artifacts. This may take time.")
             self._ml_flow.log_artifacts(args.output_dir)
-        self._ml_flow.end_run()
 
     def __del__(self):
         # if the previous run is not terminated correctly, the fluent API will
         # not let you start a new run before the previous one is killed
         if self._ml_flow.active_run is not None:
-            self._ml_flow.end_run(status="KILLED")
+            self._ml_flow.end_run()
 
 
 INTEGRATION_TO_CALLBACK = {
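
The first two hunks replace the hard-coded `MAX_LOG_SIZE = 100` with the limits MLflow itself enforces, read from `mlflow.utils.validation`, and drop any parameter whose stringified value would fail MLflow's length check. A minimal standalone sketch of the same filtering step, assuming mlflow is installed; the helper name `filter_long_params` is illustrative, not part of the patch:

    import logging

    from mlflow.utils.validation import MAX_PARAM_VAL_LENGTH

    logger = logging.getLogger(__name__)

    def filter_long_params(params):
        """Drop entries whose stringified value exceeds MLflow's per-param limit."""
        kept = {}
        for name, value in params.items():
            # MLflow converts every value to str before validating,
            # so the length check is done on str(value)
            if len(str(value)) > MAX_PARAM_VAL_LENGTH:
                logger.warning("Dropping %r: value longer than %d characters", name, MAX_PARAM_VAL_LENGTH)
                continue
            kept[name] = value
        return kept

Reading the constant from MLflow rather than hard-coding 250 means the callback keeps working if a future MLflow release changes the limit.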
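
The batching loop in the second hunk works the same way outside the Trainer: `log_params()` rejects batches larger than `MAX_PARAMS_TAGS_PER_BATCH`, so the combined dict is logged in slices of at most that size. A sketch under default local tracking settings; `log_params_in_batches` is an illustrative name:

    import mlflow
    from mlflow.utils.validation import MAX_PARAMS_TAGS_PER_BATCH

    def log_params_in_batches(params):
        # slice the items into chunks no larger than MLflow's batch limit
        items = list(params.items())
        for i in range(0, len(items), MAX_PARAMS_TAGS_PER_BATCH):
            mlflow.log_params(dict(items[i : i + MAX_PARAMS_TAGS_PER_BATCH]))

    with mlflow.start_run():
        # 250 params go out as batches of 100 + 100 + 50
        log_params_in_batches({f"param_{i}": i for i in range(250)})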
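
The last hunk leans on MLflow's fluent run lifecycle: only one fluent run can be active at a time, `active_run()` reports it (it is a function, so it must be called to get the current run or `None`), and `end_run()` marks the run `FINISHED` by default. That default is why the patch drops `status="KILLED"` from `__del__`: ending the run on garbage collection no longer records a normal shutdown as a killed run. A sketch under default tracking settings:

    import mlflow

    mlflow.start_run()
    assert mlflow.active_run() is not None   # a fluent run is now active

    # end_run() defaults to status="FINISHED"; passing status="KILLED",
    # as the old __del__ did, would mark every interpreter exit as killed
    mlflow.end_run()
    assert mlflow.active_run() is None       # a new run may be started now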