Always use generated folds for model selection

Signed-off-by: Keith Battocchi <kebatt@microsoft.com>
This commit is contained in:
Keith Battocchi 2024-02-12 12:04:33 -05:00 коммит произвёл Keith Battocchi
Родитель 639d28cb09
Коммит d35c340331
10 изменённых файлов: 132 добавлений и 68 удалений

Просмотреть файл

@ -96,7 +96,7 @@ def _fit_fold(model, train_idxs, test_idxs, calculate_scores, args, kwargs):
kwargs_train = {key: var[train_idxs] for key, var in kwargs.items()}
kwargs_test = {key: var[test_idxs] for key, var in kwargs.items()}
model.train(False, *args_train, **kwargs_train)
model.train(False, None, *args_train, **kwargs_train)
nuisance_temp = model.predict(*args_test, **kwargs_test)
if not isinstance(nuisance_temp, tuple):
@ -120,10 +120,10 @@ def _crossfit(models: Union[ModelSelector, List[ModelSelector]], folds, use_ray,
----------
models : ModelSelector or List[ModelSelector]
One or more objects that have train and predict methods.
The train method must take an 'is_selecting' argument first, and then
accept positional arguments `args` and keyword arguments `kwargs`; the predict method
just takes those `args` and `kwargs`. The train
method selects or estimates a model of the nuisance function, based on the input
The train method must take an 'is_selecting' argument first, a set of folds second
(which will be None when not selecting) and then accept positional arguments `args`
and keyword arguments `kwargs`; the predict method just takes those `args` and `kwargs`.
The train method selects or estimates a model of the nuisance function, based on the input
data to fit. Predict evaluates the fitted nuisance function on the input
data to predict.
folds : list of tuple or None
@ -175,7 +175,7 @@ def _crossfit(models: Union[ModelSelector, List[ModelSelector]], folds, use_ray,
class Wrapper:
def __init__(self, model):
self._model = model
def train(self, is_selecting, X, y, W=None):
def train(self, is_selecting, folds, X, y, W=None):
self._model.fit(X, y)
return self
def predict(self, X, y, W=None):
@ -224,6 +224,7 @@ def _crossfit(models: Union[ModelSelector, List[ModelSelector]], folds, use_ray,
fitted_inds = np.concatenate((fitted_inds, test_idxs))
fitted_inds = np.sort(fitted_inds.astype(int))
else:
fold_vals = [(np.arange(n), np.arange(n))]
fitted_inds = np.arange(n)
accumulated_nuisances = ()
@ -245,7 +246,7 @@ def _crossfit(models: Union[ModelSelector, List[ModelSelector]], folds, use_ray,
# come first as positional arguments
accumulated_args = accumulated_nuisances + args
if model.needs_fit:
model.train(True, *accumulated_args, **kwargs)
model.train(True, fold_vals if folds is None else folds, *accumulated_args, **kwargs)
calculate_scores &= hasattr(model, 'score')
@ -253,7 +254,7 @@ def _crossfit(models: Union[ModelSelector, List[ModelSelector]], folds, use_ray,
if folds is None: # skip crossfitting
model_list[-1].append(clone(model, safe=False))
model_list[-1][0].train(False, *accumulated_args, **kwargs) # fit the selected model
model_list[-1][0].train(False, None, *accumulated_args, **kwargs) # fit the selected model
nuisances = model_list[-1][0].predict(*accumulated_args, **kwargs)
if not isinstance(nuisances, tuple):
nuisances = (nuisances,)
@ -445,7 +446,7 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
def __init__(self, model_t, model_y):
self._model_t = model_t
self._model_y = model_y
def train(self, is_selecting, Y, T, W=None):
def train(self, is_selecting, folds, Y, T, W=None):
self._model_t.fit(W, T)
self._model_y.fit(W, Y)
return self
@ -502,7 +503,7 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
def __init__(self, model_t, model_y):
self._model_t = model_t
self._model_y = model_y
def train(self, is_selecting, Y, T, W=None):
def train(self, is_selecting, folds, Y, T, W=None):
self._model_t.fit(W, np.matmul(T, np.arange(1, T.shape[1]+1)))
self._model_y.fit(W, Y)
return self
@ -606,7 +607,7 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
The selector(s) for fitting the nuisance function. The returned estimators must implement
`train` and `predict` methods that both have signatures::
model_nuisance.train(is_selecting, Y, T, X=X, W=W, Z=Z,
model_nuisance.train(is_selecting, folds, Y, T, X=X, W=W, Z=Z,
sample_weight=sample_weight)
model_nuisance.predict(Y, T, X=X, W=W, Z=Z,
sample_weight=sample_weight)

Просмотреть файл

@ -48,10 +48,12 @@ class _ModelNuisance(ModelSelector):
self._model_y = model_y
self._model_t = model_t
def train(self, is_selecting, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
assert Z is None, "Cannot accept instrument!"
self._model_t.train(is_selecting, X, W, T, **filter_none_kwargs(sample_weight=sample_weight, groups=groups))
self._model_y.train(is_selecting, X, W, Y, **filter_none_kwargs(sample_weight=sample_weight, groups=groups))
self._model_t.train(is_selecting, folds, X, W, T, **
filter_none_kwargs(sample_weight=sample_weight, groups=groups))
self._model_y.train(is_selecting, folds, X, W, Y, **
filter_none_kwargs(sample_weight=sample_weight, groups=groups))
return self
def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
@ -223,7 +225,7 @@ class _RLearner(_OrthoLearner):
class ModelSelector(SingleModelSelector):
def __init__(self, model):
self._model = ModelFirst(model)
def train(self, is_selecting, X, W, Y, sample_weight=None):
def train(self, is_selecting, folds, X, W, Y, sample_weight=None):
self._model.fit(X, W, Y, sample_weight=sample_weight)
return self
@property

Просмотреть файл

@ -82,7 +82,7 @@ class _FirstStageSelector(SingleModelSelector):
self._model = clone(model, safe=False)
self._discrete_target = discrete_target
def train(self, is_selecting, X, W, Target, sample_weight=None, groups=None):
def train(self, is_selecting, folds, X, W, Target, sample_weight=None, groups=None):
if self._discrete_target:
# In this case, the Target is the one-hot-encoding of the treatment variable
# We need to go back to the label representation of the one-hot so as to call
@ -92,7 +92,7 @@ class _FirstStageSelector(SingleModelSelector):
"don't contain all treatments")
Target = inverse_onehot(Target)
self._model.train(is_selecting, _combine(X, W, Target.shape[0]), Target,
self._model.train(is_selecting, folds, _combine(X, W, Target.shape[0]), Target,
**filter_none_kwargs(groups=groups, sample_weight=sample_weight))
return self

Просмотреть файл

@ -72,7 +72,7 @@ class _ModelNuisance(ModelSelector):
def _combine(self, X, W):
return np.hstack([arr for arr in [X, W] if arr is not None])
def train(self, is_selecting, Y, T, X=None, W=None, *, sample_weight=None, groups=None):
def train(self, is_selecting, folds, Y, T, X=None, W=None, *, sample_weight=None, groups=None):
if Y.ndim != 1 and (Y.ndim != 2 or Y.shape[1] != 1):
raise ValueError("The outcome matrix must be of shape ({0}, ) or ({0}, 1), "
"instead got {1}.".format(len(X), Y.shape))
@ -84,8 +84,8 @@ class _ModelNuisance(ModelSelector):
XW = self._combine(X, W)
filtered_kwargs = filter_none_kwargs(sample_weight=sample_weight)
self._model_propensity.train(is_selecting, XW, inverse_onehot(T), groups=groups, **filtered_kwargs)
self._model_regression.train(is_selecting, np.hstack([XW, T]), Y, groups=groups, **filtered_kwargs)
self._model_propensity.train(is_selecting, folds, XW, inverse_onehot(T), groups=groups, **filtered_kwargs)
self._model_regression.train(is_selecting, folds, np.hstack([XW, T]), Y, groups=groups, **filtered_kwargs)
return self
def score(self, Y, T, X=None, W=None, *, sample_weight=None, groups=None):

Просмотреть файл

@ -56,15 +56,16 @@ class _OrthoIVNuisanceSelector(ModelSelector):
else:
self._model_z_xw = model_z
def train(self, is_selecting, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
self._model_y_xw.train(is_selecting, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups)
self._model_t_xw.train(is_selecting, X=X, W=W, Target=T, sample_weight=sample_weight, groups=groups)
def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
self._model_y_xw.train(is_selecting, folds, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups)
self._model_t_xw.train(is_selecting, folds, X=X, W=W, Target=T, sample_weight=sample_weight, groups=groups)
if self._projection:
# concat W and Z
WZ = _combine(W, Z, Y.shape[0])
self._model_t_xwz.train(is_selecting, X=X, W=WZ, Target=T, sample_weight=sample_weight, groups=groups)
self._model_t_xwz.train(is_selecting, folds, X=X, W=WZ, Target=T,
sample_weight=sample_weight, groups=groups)
else:
self._model_z_xw.train(is_selecting, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups)
self._model_z_xw.train(is_selecting, folds, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups)
return self
def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
@ -729,12 +730,14 @@ class _BaseDMLIVNuisanceSelector(ModelSelector):
self._model_t_xw = model_t_xw
self._model_t_xwz = model_t_xwz
def train(self, is_selecting, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
self._model_y_xw.train(is_selecting, X, W, Y, **filter_none_kwargs(sample_weight=sample_weight, groups=groups))
self._model_t_xw.train(is_selecting, X, W, T, **filter_none_kwargs(sample_weight=sample_weight, groups=groups))
def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
self._model_y_xw.train(is_selecting, folds, X, W, Y, **
filter_none_kwargs(sample_weight=sample_weight, groups=groups))
self._model_t_xw.train(is_selecting, folds, X, W, T, **
filter_none_kwargs(sample_weight=sample_weight, groups=groups))
# concat W and Z
WZ = _combine(W, Z, Y.shape[0])
self._model_t_xwz.train(is_selecting, X, WZ, T,
self._model_t_xwz.train(is_selecting, folds, X, WZ, T,
**filter_none_kwargs(sample_weight=sample_weight, groups=groups))
return self

Просмотреть файл

@ -58,19 +58,20 @@ class _BaseDRIVNuisanceSelector(ModelSelector):
else:
self._model_z_xw = model_z
def train(self, is_selecting, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
# T and Z only allow single continuous or binary, keep the shape of (n,) for continuous and (n,1) for binary
T = T.ravel() if not self._discrete_treatment else T
Z = Z.ravel() if not self._discrete_instrument else Z
self._model_y_xw.train(is_selecting, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups)
self._model_t_xw.train(is_selecting, X=X, W=W, Target=T, sample_weight=sample_weight, groups=groups)
self._model_y_xw.train(is_selecting, folds, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups)
self._model_t_xw.train(is_selecting, folds, X=X, W=W, Target=T, sample_weight=sample_weight, groups=groups)
if self._projection:
WZ = _combine(W, Z, Y.shape[0])
self._model_t_xwz.train(is_selecting, X=X, W=WZ, Target=T, sample_weight=sample_weight, groups=groups)
self._model_t_xwz.train(is_selecting, folds, X=X, W=WZ, Target=T,
sample_weight=sample_weight, groups=groups)
else:
self._model_z_xw.train(is_selecting, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups)
self._model_z_xw.train(is_selecting, folds, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups)
# TODO: prel_model_effect could allow sample_var and freq_weight?
if self._discrete_instrument:
@ -220,12 +221,12 @@ class _BaseDRIVNuisanceCovarianceSelector(ModelSelector):
target = T * Z
return target
def train(self, is_selecting,
def train(self, is_selecting, folds,
prel_theta, Y_res, T_res, Z_res,
Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
# T and Z only allow single continuous or binary, keep the shape of (n,) for continuous and (n,1) for binary
target = self._get_target(T_res, Z_res, T, Z)
self._model_tz_xw.train(is_selecting, X=X, W=W, Target=target,
self._model_tz_xw.train(is_selecting, folds, X=X, W=W, Target=target,
sample_weight=sample_weight, groups=groups)
return self
@ -2402,12 +2403,12 @@ class _IntentToTreatDRIVNuisanceSelector(ModelSelector):
self._dummy_z = dummy_z
self._prel_model_effect = prel_model_effect
def train(self, is_selecting, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
self._model_y_xw.train(is_selecting, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups)
def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
self._model_y_xw.train(is_selecting, folds, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups)
# concat W and Z
WZ = _combine(W, Z, Y.shape[0])
self._model_t_xwz.train(is_selecting, X=X, W=WZ, Target=T, sample_weight=sample_weight, groups=groups)
self._dummy_z.train(is_selecting, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups)
self._model_t_xwz.train(is_selecting, folds, X=X, W=WZ, Target=T, sample_weight=sample_weight, groups=groups)
self._dummy_z.train(is_selecting, folds, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups)
# we need to undo the one-hot encoding for calling effect,
# since it expects raw values
self._prel_model_effect.fit(Y, inverse_onehot(T), Z=inverse_onehot(Z), X=X, W=W,

Просмотреть файл

@ -45,7 +45,7 @@ class _DynamicModelNuisanceSelector(ModelSelector):
self._model_t = model_t
self.n_periods = n_periods
def train(self, is_selecting, Y, T, X=None, W=None, sample_weight=None, groups=None):
def train(self, is_selecting, folds, Y, T, X=None, W=None, sample_weight=None, groups=None):
"""Fit a series of nuisance models for each period or period pairs."""
assert Y.shape[0] % self.n_periods == 0, \
"Length of training data should be an integer multiple of time periods."
@ -56,16 +56,46 @@ class _DynamicModelNuisanceSelector(ModelSelector):
self._model_t_trained = {j: {t: clone(self._model_t, safe=False)
for t in np.arange(j + 1)}
for j in np.arange(self.n_periods)}
# we have to filter the folds because they contain the indices in the original data not
# the indices in the period-filtered data
def _translate_inds(t, inds):
# translate the indices in a fold to the indices in the period-filtered data
# if groups was [3,3,4,4,5,5,6,6,1,1,2,2,0,0] (the group ids can be in any order, but the
# time periods for each group should be contguous), and we had [10,11,0,1] as the indices in a fold
# (so the fold is taking the entries corresponding to groups 2 and 3)
# then group_period_filter(0) is [0,2,4,6,8,10,12] and gpf(1) is [1,3,5,7,9,11,13]
# so for period 1, the fold should be [10,0] => [5,0] (the indices that return 10 and 0 in the t=0 data)
# and for period 2, the fold should be [11,1] => [5,0] again (the indices that return 11,1 in the t=1 data)
# filter to the indices for the time period
inds = inds[np.isin(inds, period_filters[t])]
# now find their index in the period-filtered data, which is always sorted
return np.searchsorted(period_filters[t], inds)
if folds is not None:
translated_folds = []
for (train, test) in folds:
translated_folds.append((_translate_inds(0, train), _translate_inds(0, test)))
# sanity check that the folds are the same no matter the time period
for t in range(1, self.n_periods):
assert np.array_equal(_translate_inds(t, train), _translate_inds(0, train))
assert np.array_equal(_translate_inds(t, test), _translate_inds(0, test))
else:
translated_folds = None
for t in np.arange(self.n_periods):
self._model_y_trained[t].train(
is_selecting,
is_selecting, translated_folds,
self._index_or_None(X, period_filters[t]),
self._index_or_None(
W, period_filters[t]),
Y[period_filters[self.n_periods - 1]])
for j in np.arange(t, self.n_periods):
self._model_t_trained[j][t].train(
is_selecting,
is_selecting, translated_folds,
self._index_or_None(X, period_filters[t]),
self._index_or_None(W, period_filters[t]),
T[period_filters[j]])

Просмотреть файл

@ -3,7 +3,9 @@
"""Collection of scikit-learn extensions for model selection techniques."""
from inspect import signature
import inspect
import numbers
from typing import List, Optional
import warnings
import abc
@ -280,9 +282,10 @@ class ModelSelector(metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def train(self, is_selecting: bool, *args, **kwargs):
def train(self, is_selecting: bool, folds: Optional[List], *args, **kwargs):
"""
Either selects a model or fits a model, depending on the value of `is_selecting`.
If `is_selecting` is `False`, then `folds` should not be provided because they are only during selection.
"""
raise NotImplementedError("Abstract method")
@ -386,19 +389,31 @@ def _fit_with_groups(model, X, y, *, sub_model=None, groups, **kwargs):
class FixedModelSelector(SingleModelSelector):
"""
Model selection class that always selects the given model
Model selection class that always selects the given sklearn-compatible model
"""
def __init__(self, model):
self.model = clone(model, safe=False)
def train(self, is_selecting, *args, groups=None, **kwargs):
# whether selecting or not, need to train the model on the data
_fit_with_groups(self.model, *args, groups=groups, **kwargs)
if is_selecting and hasattr(self.model, 'score'):
# TODO: we need to alter this to use out-of-sample score here, which
# will require cross-validation, but should respect grouping, stratifying, etc.
self._score = self.model.score(*args, **kwargs)
def train(self, is_selecting, folds: Optional[List], X, y, groups=None, **kwargs):
if is_selecting:
# since needs_fit is False, is_selecting will only be true if
# the score needs to be compared to another model's
# so we don't need to fit the model itself, just get the out-of-sample score
assert hasattr(self.model, 'score'), (f"Can't select between a fixed {type(self.model)} model and others "
"because it doesn't have a score method")
scores = []
for train, test in folds:
# use _fit_with_groups instead of just fit to handle nested grouping
_fit_with_groups(self.model, X[train], y[train],
groups=None if groups is None else groups[train],
**{key: val[train] for key, val in kwargs.items()})
scores.append(self.model.score(X[test], y[test]))
self._score = np.mean(scores)
else:
# we need to train the model on the data
_fit_with_groups(self.model, X, y, groups=groups, **kwargs)
return self
@property
@ -411,7 +426,7 @@ class FixedModelSelector(SingleModelSelector):
@property
def needs_fit(self):
return False # We have only a single model
return False # We have only a single model so we can skip the selection process
def _copy_to(m1, m2, attrs, insert_underscore=False):
@ -534,17 +549,28 @@ class SklearnCVSelector(SingleModelSelector):
converter = SklearnCVSelector._model_mapping()[known_type]
return converter(model, args, kwargs)
def train(self, is_selecting: bool, *args, groups=None, **kwargs):
def train(self, is_selecting: bool, folds: Optional[List], *args, groups=None, **kwargs):
if is_selecting:
sub_model = None
sub_model = self.searcher
if isinstance(self.searcher, Pipeline):
sub_model = self.searcher.steps[-1][1]
_fit_with_groups(self.searcher, *args, sub_model=sub_model, groups=groups, **kwargs)
init_params = inspect.signature(sub_model.__init__).parameters
if 'cv' in init_params:
default_cv = init_params['cv'].default
else:
# constructor takes cv as a positional or kwarg, just pull it out of a new instance
default_cv = type(sub_model)().cv
if sub_model.cv != default_cv:
warnings.warn(f"Model {sub_model} has a non-default cv attribute, which will be ignored")
sub_model.cv = folds
self.searcher.fit(*args, **kwargs)
self._best_model, self._best_score = self._convert_model(self.searcher, args, kwargs)
else:
# don't need to use _fit_with_groups here since none of these models support it
self.best_model.fit(*args, **kwargs)
return self
@ -578,18 +604,19 @@ class ListSelector(SingleModelSelector):
self.models = [clone(model, safe=False) for model in models]
self.unwrap = unwrap
def train(self, is_selecting, *args, **kwargs):
def train(self, is_selecting, folds: Optional[List], *args, **kwargs):
assert len(self.models) > 0, "ListSelector must have at least one model"
if is_selecting:
scores = []
for model in self.models:
model.train(is_selecting, *args, **kwargs)
model.train(is_selecting, folds, *args, **kwargs)
scores.append(model.best_score)
self._all_scores = scores
self._best_score = np.max(scores)
self._best_model = self.models[np.argmax(scores)]
else:
self._best_model.train(is_selecting, *args, **kwargs)
self._best_model.train(is_selecting, folds, *args, **kwargs)
@property
def best_model(self):

Просмотреть файл

@ -27,7 +27,7 @@ class ModelNuisance:
self._model_t = model_t
self._model_y = model_y
def train(self, is_selecting, Y, T, W=None):
def train(self, is_selecting, folds, Y, T, W=None):
self._model_t.fit(W, T)
self._model_y.fit(W, Y)
return self

Просмотреть файл

@ -27,7 +27,7 @@ class TestOrthoLearner(unittest.TestCase):
def __init__(self, model):
self._model = model
def train(self, is_selecting, X, y, Q, W=None):
def train(self, is_selecting, folds, X, y, Q, W=None):
self._model.fit(X, y)
return self
@ -111,7 +111,7 @@ class TestOrthoLearner(unittest.TestCase):
def __init__(self, model):
self._model = model
def train(self, is_selecting, X, y, W=None):
def train(self, is_selecting, folds, X, y, W=None):
self._model.fit(X, y)
return self
@ -185,7 +185,7 @@ class TestOrthoLearner(unittest.TestCase):
def __init__(self, model):
self._model = model
def train(self, is_selecting, X, y, Q, W=None):
def train(self, is_selecting, folds, X, y, Q, W=None):
self._model.fit(X, y)
return self
@ -229,7 +229,7 @@ class TestOrthoLearner(unittest.TestCase):
self._model_t = model_t
self._model_y = model_y
def train(self, is_selecting, Y, T, W=None):
def train(self, is_selecting, folds, Y, T, W=None):
self._model_t.fit(W, T)
self._model_y.fit(W, Y)
return self
@ -345,7 +345,7 @@ class TestOrthoLearner(unittest.TestCase):
self._model_t = model_t
self._model_y = model_y
def train(self, is_selecting, Y, T, W=None):
def train(self, is_selecting, folds, Y, T, W=None):
self._model_t.fit(W, T)
self._model_y.fit(W, Y)
return self
@ -397,7 +397,7 @@ class TestOrthoLearner(unittest.TestCase):
self._model_t = model_t
self._model_y = model_y
def train(self, is_selecting, Y, T, W=None):
def train(self, is_selecting, folds, Y, T, W=None):
self._model_t.fit(W, T)
self._model_y.fit(W, Y)
return self
@ -458,7 +458,7 @@ class TestOrthoLearner(unittest.TestCase):
self._model_t = model_t
self._model_y = model_y
def train(self, is_selecting, Y, T, W=None):
def train(self, is_selecting, folds, Y, T, W=None):
self._model_t.fit(W, np.matmul(T, np.arange(1, T.shape[1] + 1)))
self._model_y.fit(W, Y)
return self