зеркало из https://github.com/py-why/EconML.git
Merge branch 'main' into kebatt/debugNotebooks
This commit is contained in:
Коммит
8e685533e8
|
@ -19,7 +19,7 @@ import warnings
|
|||
from collections.abc import Iterable
|
||||
from scipy.stats import norm
|
||||
from econml.sklearn_extensions.model_selection import WeightedKFold, WeightedStratifiedKFold
|
||||
from econml.utilities import ndim, shape, reshape, _safe_norm_ppf
|
||||
from econml.utilities import ndim, shape, reshape, _safe_norm_ppf, check_input_arrays
|
||||
from sklearn import clone
|
||||
from sklearn.linear_model import LinearRegression, LassoCV, MultiTaskLassoCV, Lasso, MultiTaskLasso
|
||||
from sklearn.metrics import r2_score
|
||||
|
@ -1683,6 +1683,8 @@ class StatsModelsLinearRegression(_StatsModelsWrapper):
|
|||
def _check_input(self, X, y, sample_weight, freq_weight, sample_var):
|
||||
"""Check dimensions and other assertions."""
|
||||
|
||||
X, y, sample_weight, freq_weight, sample_var = check_input_arrays(
|
||||
X, y, sample_weight, freq_weight, sample_var, dtype='numeric')
|
||||
if X is None:
|
||||
X = np.empty((y.shape[0], 0))
|
||||
if self.fit_intercept:
|
||||
|
|
|
@ -327,6 +327,25 @@ class TestStatsModels(unittest.TestCase):
|
|||
assert np.all(np.abs(est.intercept_ - lr.intercept_) <
|
||||
1e-12), "{}, {}".format(est.intercept_, lr.intercept_)
|
||||
|
||||
def test_o_dtype(self):
    """Test that the models still work when the numpy arrays are of 'O' dtype.

    Fits ``OLS`` and ``LinearRegression`` on the same object-dtype data, once
    with the default settings and once with ``fit_intercept=False``, and
    asserts that the fitted coefficients and intercepts agree to within 1e-12.
    """
    np.random.seed(123)
    n = 1000
    d = 3

    # Draw ordinary float data, then cast it to object dtype so that the
    # estimators receive arrays whose dtype is 'O'.
    X = np.random.normal(size=(n, d)).astype('O')
    y = np.random.normal(size=n).astype('O')

    est = OLS().fit(X, y)
    lr = LinearRegression().fit(X, y)
    assert np.all(np.abs(est.coef_ - lr.coef_) < 1e-12), "{}, {}".format(est.coef_, lr.coef_)
    # On failure, report the two intercepts that were actually compared.
    assert np.all(np.abs(est.intercept_ - lr.intercept_) <
                  1e-12), "{}, {}".format(est.intercept_, lr.intercept_)

    est = OLS(fit_intercept=False).fit(X, y)
    lr = LinearRegression(fit_intercept=False).fit(X, y)
    assert np.all(np.abs(est.coef_ - lr.coef_) < 1e-12), "{}, {}".format(est.coef_, lr.coef_)
    # On failure, report the two intercepts that were actually compared.
    assert np.all(np.abs(est.intercept_ - lr.intercept_) <
                  1e-12), "{}, {}".format(est.intercept_, lr.intercept_)
|
||||
|
||||
def test_inference(self):
|
||||
""" Testing that we recover the expected standard errors and confidence intervals in a known example """
|
||||
|
||||
|
|
|
@ -514,7 +514,7 @@ def check_inputs(Y, T, X, W=None, multi_output_T=True, multi_output_Y=True):
|
|||
return Y, T, X, W
|
||||
|
||||
|
||||
def check_input_arrays(*args, validate_len=True, force_all_finite=True):
|
||||
def check_input_arrays(*args, validate_len=True, force_all_finite=True, dtype=None):
|
||||
"""Cast input sequences into numpy arrays.
|
||||
|
||||
Only inputs that are sequence-like will be converted, all other inputs will be left as is.
|
||||
|
@ -531,6 +531,13 @@ def check_input_arrays(*args, validate_len=True, force_all_finite=True):
|
|||
force_all_finite : bool (default=True)
|
||||
Whether to allow inf and nan in input arrays.
|
||||
|
||||
dtype : 'numeric', type, list of type or None (default=None)
|
||||
Argument passed to sklearn.utils.check_array.
|
||||
Specifies data type of result. If None, the dtype of the input is preserved.
|
||||
If "numeric", dtype is preserved unless array.dtype is object.
|
||||
If dtype is a list of types, conversion on the first type is only
|
||||
performed if the dtype of the input is not in the list.
|
||||
|
||||
Returns
|
||||
-------
|
||||
args: array-like
|
||||
|
@ -541,7 +548,7 @@ def check_input_arrays(*args, validate_len=True, force_all_finite=True):
|
|||
args = list(args)
|
||||
for i, arg in enumerate(args):
|
||||
if np.ndim(arg) > 0:
|
||||
new_arg = check_array(arg, dtype=None, ensure_2d=False, accept_sparse=True,
|
||||
new_arg = check_array(arg, dtype=dtype, ensure_2d=False, accept_sparse=True,
|
||||
force_all_finite=force_all_finite)
|
||||
if not force_all_finite:
|
||||
# For when checking input values is disabled
|
||||
|
|
Загрузка…
Ссылка в новой задаче