From 784a8d45a5fe17d9cd476751a8814ab1fbfbf74f Mon Sep 17 00:00:00 2001 From: Keith Battocchi Date: Thu, 27 Jun 2024 13:29:55 -0400 Subject: [PATCH] Enable support for scikit-learn 1.5 Signed-off-by: Keith Battocchi --- econml/_ensemble/_ensemble.py | 7 ++++++- econml/solutions/causal_analysis/_causal_analysis.py | 7 ++++++- pyproject.toml | 2 +- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/econml/_ensemble/_ensemble.py b/econml/_ensemble/_ensemble.py index cdc23da4..fc37bac5 100644 --- a/econml/_ensemble/_ensemble.py +++ b/econml/_ensemble/_ensemble.py @@ -13,9 +13,14 @@ import numbers import numpy as np from abc import ABCMeta, abstractmethod from sklearn.base import BaseEstimator, clone -from sklearn.utils import _print_elapsed_time from sklearn.utils import check_random_state from joblib import effective_n_jobs +from packaging.version import parse +import sklearn +if parse(sklearn.__version__) < parse("1.5"): + from sklearn.utils import _print_elapsed_time +else: + from sklearn.utils._user_interface import _print_elapsed_time def _fit_single_estimator(estimator, X, y, sample_weight=None, diff --git a/econml/solutions/causal_analysis/_causal_analysis.py b/econml/solutions/causal_analysis/_causal_analysis.py index 422572b2..7a1e72df 100644 --- a/econml/solutions/causal_analysis/_causal_analysis.py +++ b/econml/solutions/causal_analysis/_causal_analysis.py @@ -30,7 +30,12 @@ from ...utilities import _RegressionWrapper, get_feature_names_or_default, inver # TODO: this utility is documented but internal; reimplement? from sklearn.utils import _safe_indexing # TODO: this utility is even less public... -from sklearn.utils import _get_column_indices +from packaging.version import parse +import sklearn +if parse(sklearn.__version__) < parse("1.5"): + from sklearn.utils import _get_column_indices +else: + from sklearn.utils._indexing import _get_column_indices class _CausalInsightsConstants: diff --git a/pyproject.toml b/pyproject.toml index b7114e94..a780b08d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ dependencies = [ "numpy", "scipy > 1.4.0", - "scikit-learn >= 1.0, < 1.5", + "scikit-learn >= 1.0, < 1.6", "sparse", "joblib >= 0.13.0", "statsmodels >= 0.10",