enable treatment featurization (#615)

Co-authored-by: Keith Battocchi <kebatt@microsoft.com>
fverac 2022-10-28 13:15:25 -04:00 committed by GitHub
Parent b4191e8735
Commit deb564fafa
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
23 changed files: 3167 additions and 434 deletions

View file

@ -99,8 +99,8 @@ jobs:
# Work around https://github.com/pypa/pip/issues/9542
- script: 'pip install -U numpy~=1.21.0'
displayName: 'Upgrade numpy'
- script: 'pip install pytest pytest-runner jupyter jupyter-client nbconvert nbformat seaborn xgboost tqdm && pip list && python setup.py pytest'
- script: 'pip install pytest pytest-runner jupyter jupyter-client nbconvert nbformat seaborn xgboost tqdm py && pip list && python setup.py pytest'
displayName: 'Unit tests'
env:
PYTEST_ADDOPTS: '-m "notebook"'
@ -128,7 +128,7 @@ jobs:
- script: 'pip install -U numpy~=1.21.0'
displayName: 'Upgrade numpy'
- script: 'pip install pytest pytest-runner jupyter jupyter-client nbconvert nbformat seaborn xgboost tqdm && python setup.py pytest'
- script: 'pip install pytest pytest-runner jupyter jupyter-client nbconvert nbformat seaborn xgboost tqdm py && python setup.py pytest'
displayName: 'Unit tests'
env:
PYTEST_ADDOPTS: '-m "notebook"'
@ -199,10 +199,10 @@ jobs:
condition: eq(dependencies.EvalChanges.outputs['output.testCode'], 'True')
displayName: 'Run tests (main)'
steps:
- script: 'pip install pytest pytest-runner "coverage<6.4.1;python_version==''3.6''" "coverage;python_version>''3.6''" && python setup.py pytest'
- script: 'pip install pytest pytest-runner "coverage<6.4.1;python_version==''3.6''" "coverage;python_version>''3.6''" py && python setup.py pytest'
displayName: 'Unit tests'
env:
PYTEST_ADDOPTS: '-m "not (notebook or automl or dml or serial or cate_api)" -n 2'
PYTEST_ADDOPTS: '-m "not (notebook or automl or dml or serial or cate_api or treatment_featurization)" -n 2'
COVERAGE_PROCESS_START: 'setup.cfg'
- task: PublishTestResults@2
displayName: 'Publish Test Results **/test-results.xml'
@ -226,7 +226,7 @@ jobs:
condition: eq(dependencies.EvalChanges.outputs['output.testCode'], 'True')
displayName: 'Run tests (DML)'
steps:
- script: 'pip install pytest pytest-runner "coverage<6.4.1;python_version==''3.6''" "coverage;python_version>''3.6''" && python setup.py pytest'
- script: 'pip install pytest pytest-runner "coverage<6.4.1;python_version==''3.6''" "coverage;python_version>''3.6''" py && python setup.py pytest'
displayName: 'Unit tests'
env:
PYTEST_ADDOPTS: '-m "dml"'
@ -255,7 +255,7 @@ jobs:
condition: eq(dependencies.EvalChanges.outputs['output.testCode'], 'True')
displayName: 'Run tests (Serial)'
steps:
- script: 'pip install pytest pytest-runner "coverage<6.4.1;python_version==''3.6''" "coverage;python_version>''3.6''" && python setup.py pytest'
- script: 'pip install pytest pytest-runner "coverage<6.4.1;python_version==''3.6''" "coverage;python_version>''3.6''" py && python setup.py pytest'
displayName: 'Unit tests'
env:
PYTEST_ADDOPTS: '-m "serial" -n 1'
@ -282,7 +282,7 @@ jobs:
condition: eq(dependencies.EvalChanges.outputs['output.testCode'], 'True')
displayName: 'Run tests (Other)'
steps:
- script: 'pip install pytest pytest-runner "coverage<6.4.1;python_version==''3.6''" "coverage;python_version>''3.6''"'
- script: 'pip install pytest pytest-runner "coverage<6.4.1;python_version==''3.6''" "coverage;python_version>''3.6''" py'
displayName: 'Install pytest'
- script: 'python setup.py pytest'
displayName: 'CATE Unit tests'
@ -296,6 +296,35 @@ jobs:
testRunTitle: 'Python $(python.version), image $(imageName)'
condition: succeededOrFailed()
- task: PublishCodeCoverageResults@1
displayName: 'Publish Code Coverage Results'
inputs:
codeCoverageTool: Cobertura
summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml'
- template: azure-pipelines-steps.yml
parameters:
package: '-e .[tf,plt]'
job:
job: Tests_treatment_featurization
dependsOn: 'EvalChanges'
condition: eq(dependencies.EvalChanges.outputs['output.testCode'], 'True')
displayName: 'Run tests (Treatment Featurization)'
steps:
- script: 'pip install pytest pytest-runner "coverage<6.4.1;python_version==''3.6''" "coverage;python_version>''3.6''" py'
displayName: 'Install pytest'
- script: 'python setup.py pytest'
displayName: 'Treatment Featurization Unit tests'
env:
PYTEST_ADDOPTS: '-m "treatment_featurization" -n auto'
COVERAGE_PROCESS_START: 'setup.cfg'
- task: PublishTestResults@2
displayName: 'Publish Test Results **/test-results.xml'
inputs:
testResultsFiles: '**/test-results.xml'
testRunTitle: 'Python $(python.version), image $(imageName)'
condition: succeededOrFailed()
- task: PublishCodeCoverageResults@1
displayName: 'Publish Code Coverage Results'
inputs:

View file

@ -167,6 +167,9 @@ The base class of all the methods in our API has the following signature:
Linear in Treatment CATE Estimators
-----------------------------------
.. rubric:: Constant Marginal Effects
In many settings, we might want to make further structural assumptions on the form of the data generating process.
One particularly prevalent assumption is that the outcome :math:`y` is linear in the treatment vector, and therefore that the marginal effect is constant across treatments, i.e.:
@ -188,13 +191,45 @@ Hence, the marginal CATE is independent of :math:`\vec{t}`. In these settings, w
.. math ::
\theta(\vec{x}) = \E[H(X, W) | X=\vec{x}] \tag{constant marginal CATE}
.. rubric:: Constant Marginal Effects and Marginal Effects Given Treatment Featurization
Additionally, we may be interested in cases where the outcome depends linearly on a transformation of the treatment vector (via some featurizer :math:`\phi`).
Some estimators provide support for passing such a featurizer :math:`\phi` directly to the estimator, in which case the outcome would be modeled as follows:
.. math ::
Y = H(X, W) \cdot \phi(T) + g(X, W, \epsilon)
We can then get constant marginal effects in the featurized treatment space:
.. math ::
\tau(\phi(\vec{t_0}), \phi(\vec{t_1}), \vec{x}) =~& \E[H(X, W) | X=\vec{x}] \cdot (\phi(\vec{t_1}) - \phi(\vec{t_0}))
\partial \tau(\phi(\vec{t}), \vec{x}) =~& \E[H(X, W) | X=\vec{x}]
\theta(\vec{x}) =~& \E[H(X, W) | X=\vec{x}]
Finally, we can recover the marginal effect with respect to the original treatment space by multiplying the constant marginal effect (which lives in the featurized treatment space) by the jacobian of the treatment featurizer at :math:`\vec{t}`.
.. math ::
\partial \tau(\vec{t}, \vec{x}) = \theta(\vec{x}) \nabla \phi(\vec{t}) \tag{marginal CATE}
where :math:`\nabla \phi(\vec{t})` is the :math:`d_{ft} \times d_{t}` jacobian matrix, and :math:`d_{ft}` and :math:`d_{t}` are the dimensions of the featurized treatment and the original treatment, respectively.
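As a concrete illustration (a minimal numerical sketch, not library code), suppose the featurizer is the hypothetical quadratic map :math:`\phi(t) = (t, t^2)` for a scalar treatment; the marginal CATE in the original treatment space is then the product of :math:`\theta(\vec{x})` with the :math:`2 \times 1` jacobian:

.. code-block:: python3
    :caption: Recovering the marginal CATE via the featurizer jacobian (illustrative sketch)

    import numpy as np

    def phi(t):
        # hypothetical quadratic treatment featurizer: phi(t) = [t, t**2]
        return np.array([t, t ** 2])

    def phi_jac(t):
        # jacobian of phi at t, with shape (d_ft, d_t) = (2, 1)
        return np.array([[1.0], [2.0 * t]])

    theta_x = np.array([[0.5, -0.1]])     # (d_y, d_ft) constant marginal CATE at some example x
    t = 3.0
    marginal_cate = theta_x @ phi_jac(t)  # (d_y, d_t): 0.5 * 1 + (-0.1) * (2 * 3) = -0.1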
.. rubric:: API for Linear in Treatment CATE Estimators
Given the prevalence of linear treatment effect assumptions, we will create a generic LinearCateEstimator, which will support a method that returns the constant marginal CATE
and constant marginal CATE interval at any target feature vector :math:`\vec{x}`.
and constant marginal CATE interval at any target feature vector :math:`\vec{x}`, and that can also calculate marginal effects in the original treatment space when a treatment featurizer is provided.
.. code-block:: python3
:caption: Linear CATE Estimator Class
class LinearCateEstimator(BaseCateEstimator):
self.treatment_featurizer = None
def const_marginal_effect(self, X=None):
''' Calculates the constant marginal CATE θ(·) conditional on a vector of
@ -204,8 +239,9 @@ and constant marginal CATE interval at any target feature vector :math:`\vec{x}`
X: optional (m × d_x) matrix of features for each sample
Returns:
theta: (m × d_y × d_t) matrix of constant marginal CATE of each treatment
on each outcome for each sample
theta: (m × d_y × d_f_t) matrix of constant marginal CATE of each treatment on each outcome
for each sample, where d_f_t is the dimension of the featurized treatment.
If treatment_featurizer is None, d_f_t = d_t
'''
def const_marginal_effect_interval(self, X=None, *, alpha=0.05):
@ -222,13 +258,25 @@ and constant marginal CATE interval at any target feature vector :math:`\vec{x}`
'''
def effect(self, X=None, *, T0, T1,):
return const_marginal_effect(X) * (T1 - T0)
if self.treatment_featurizer is None:
return const_marginal_effect(X) * (T1 - T0)
else:
dt = self.treatment_featurizer.transform(T1) - self.treatment_featurizer.transform(T0)
return const_marginal_effect(X) * dt
def marginal_effect(self, T, X=None):
return const_marginal_effect(X)
if self.treatment_featurizer is None:
return const_marginal_effect(X)
else:
# for every observation X_i, T_i,
# calculate jacobian at T_i and multiply with const_marginal_effect at X_i
def marginal_effect_interval(self, T, X=None, *, alpha=0.05):
return const_marginal_effect_interval(X, alpha=alpha)
if self.treatment_featurizer is None:
return const_marginal_effect_interval(X, alpha=alpha)
else:
# perform separate treatment featurization inference logic

View file

@ -508,7 +508,33 @@ Usage FAQs
The method is going to assume that each of these treatments enters linearly into the model. So it cannot capture complementarities or substitutabilities
of the different treatments. For that you can also create composite treatments that look like the product
of two base treatments. Then these products will enter the model and an effect for each product will be estimated.
This effect will be the substitute/complement effect of both treatments being present, i.e.:
This effect will be the substitute/complement effect of both treatments being present. See below for more examples.
If you have too many treatments, then you can use the :class:`.SparseLinearDML`. However,
this method essentially imposes a regularization under which only a small subset of your featurized treatments has any effect (see the sketch below).
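A minimal sketch of that option (assuming the same `y`, `T`, `X`, `W` as in the surrounding examples):

.. code-block:: python3

    from econml.dml import SparseLinearDML
    from sklearn.preprocessing import PolynomialFeatures

    # many featurized treatments: the debiased-lasso final stage shrinks most of their coefficients
    est = SparseLinearDML(treatment_featurizer=PolynomialFeatures(degree=2, include_bias=False))
    est.fit(y, T, X=X, W=W)
    point = est.const_marginal_effect(X)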
- **What if my treatments are continuous and don't have a linear effect on the outcome?**
You can impose a particular form of non-linearity by specifying a `treatment_featurizer` to the estimator.
For example, one can use the sklearn `PolynomialFeatures` transformer as a `treatment_featurizer` in order to learn
higher-order polynomial treatment effects.
Passing a `treatment_featurizer` also has the advantage that marginal effects are calculated with respect to the original treatment dimensions,
which would not be the case if you featurized the treatment yourself before passing it to the estimator.
.. testcode::
from econml.dml import LinearDML
from sklearn.preprocessing import PolynomialFeatures
poly = PolynomialFeatures(degree=2, interaction_only=True, include_bias=False)
est = LinearDML(treatment_featurizer=poly)
est.fit(y, T, X=X, W=W)
point = est.const_marginal_effect(X)
est.effect(X, T0=T0, T1=T1)
est.marginal_effect(T, X)
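For intuition about what the quadratic featurizer above feeds into the final CATE regression, here is a small standalone sketch (assuming standard sklearn `PolynomialFeatures` behavior) for a two-column treatment:

.. code-block:: python3

    import numpy as np
    from sklearn.preprocessing import PolynomialFeatures

    T = np.array([[1.0, 2.0],
                  [3.0, 4.0]])
    poly = PolynomialFeatures(degree=2, interaction_only=True, include_bias=False)
    poly.fit_transform(T)
    # columns are T1, T2 and the interaction T1*T2:
    # array([[ 1.,  2.,  2.],
    #        [ 3.,  4., 12.]])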
Alternatively, you can still create composite treatments and add them as extra treatment variables:
.. testcode::
@ -521,14 +547,6 @@ Usage FAQs
point = est.const_marginal_effect(X)
est.effect(X, T0=poly.transform(T0), T1=poly.transform(T1))
If your treatments are too many, then you can use the :class:`.SparseLinearDML`. However,
this method will essentially impose a regularization that only a small subset of them has any effect.
- **What if my treatments are continuous and don't have a linear effect on the outcome?**
You can create composite treatments and add them as extra treatment variables (see above). This would require
imposing a particular form of non-linearity.
- **What if my treatment is categorical/binary?**
You can simply set `discrete_treatment=True` in the parameters of the class. Then use any classifier for
@ -691,14 +709,15 @@ We can even create a Pipeline or Union of featurizers that will apply multiply f
.. rubric:: Single Outcome, Multiple Treatments
Suppose that we believed that our treatment was affecting the outcome in a non-linear manner.
Then we could expand the treatment vector to contain also polynomial features:
Suppose we want to estimate treatment effects for multiple continuous treatments at the same time.
Then we can simply concatenate them before passing them to the estimator.
.. testcode::
import numpy as np
est = LinearDML()
est.fit(y, np.concatenate((T, T**2), axis=1), X=X, W=W)
est.fit(y, np.concatenate((T0, T1), axis=1), X=X, W=W)
.. rubric:: Multiple Outcome, Multiple Treatments

View file

@ -511,8 +511,8 @@ Usage FAQs
Usage FAQs
==========
Usage Examples
==============
Check out the following Jupyter notebooks:

View file

@ -9,8 +9,8 @@ from functools import wraps
from copy import deepcopy
from warnings import warn
from .inference import BootstrapInference
from .utilities import (tensordot, ndim, reshape, shape, parse_final_model_params,
inverse_onehot, Summary, get_input_columns, check_input_arrays)
from .utilities import (tensordot, ndim, reshape, shape, parse_final_model_params, get_feature_names_or_default,
inverse_onehot, Summary, get_input_columns, check_input_arrays, jacify_featurizer)
from .inference import StatsModelsInference, StatsModelsInferenceDiscrete, LinearModelFinalInference,\
LinearModelFinalInferenceDiscrete, NormalInferenceResults, GenericSingleTreatmentModelFinalInference,\
GenericModelFinalInferenceDiscrete
@ -320,14 +320,16 @@ class BaseCateEstimator(metaclass=abc.ABCMeta):
"""
return (X,) + Ts
def _use_inference_method(self, name, *args, **kwargs):
if self._inference is not None:
return getattr(self._inference, name)(*args, **kwargs)
else:
raise AttributeError("Can't call '%s' because 'inference' is None" % name)
def _defer_to_inference(m):
@wraps(m)
def call(self, *args, **kwargs):
name = m.__name__
if self._inference is not None:
return getattr(self._inference, name)(*args, **kwargs)
else:
raise AttributeError("Can't call '%s' because 'inference' is None" % name)
return self._use_inference_method(m.__name__, *args, **kwargs)
return call
@_defer_to_inference
@ -534,7 +536,11 @@ class BaseCateEstimator(metaclass=abc.ABCMeta):
class LinearCateEstimator(BaseCateEstimator):
"""Base class for all CATE estimators with linear treatment effects in this package."""
"""
Base class for all CATE estimators in this package where the outcome is linear given
some user-defined treatment featurization.
"""
_original_treatment_featurizer = None
@abc.abstractmethod
def const_marginal_effect(self, X=None):
@ -551,10 +557,11 @@ class LinearCateEstimator(BaseCateEstimator):
Returns
-------
theta: (m, d_y, d_t) matrix or (d_y, d_t) matrix if X is None
Constant marginal CATE of each treatment on each outcome for each sample X[i].
Note that when Y or T is a vector rather than a 2-dimensional array,
the corresponding singleton dimensions in the output will be collapsed
theta: (m, d_y, d_f_t) matrix or (d_y, d_f_t) matrix if X is None where d_f_t is \
the dimension of the featurized treatment. If treatment_featurizer is None, d_f_t = d_t.
Constant marginal CATE of each featurized treatment on each outcome for each sample X[i].
Note that when Y or featurized-T (or T if treatment_featurizer is None) is a vector
rather than a 2-dimensional array, the corresponding singleton dimensions in the output will be collapsed
(e.g. if both are vectors, then the output of this method will also be a vector)
"""
pass
@ -607,7 +614,8 @@ class LinearCateEstimator(BaseCateEstimator):
The marginal effect is calculated around a base treatment
point conditional on a vector of features on a set of m test samples :math:`\\{T_i, X_i\\}`.
Since this class assumes a linear model, the base treatment is ignored in this calculation.
If treatment_featurizer is None, the base treatment is ignored in this calculation and the result
is equivalent to const_marginal_effect.
Parameters
----------
@ -624,24 +632,49 @@ class LinearCateEstimator(BaseCateEstimator):
the corresponding singleton dimensions in the output will be collapsed
(e.g. if both are vectors, then the output of this method will also be a vector)
"""
X, T = self._expand_treatments(X, T)
X, T = self._expand_treatments(X, T, transform=False)
eff = self.const_marginal_effect(X)
return np.repeat(eff, shape(T)[0], axis=0) if X is None else eff
if X is None:
eff = np.repeat(eff, shape(T)[0], axis=0)
if self._original_treatment_featurizer:
feat_T = self.transformer.transform(T)
jac_T = self.transformer.jac(T)
# contract the featurized-treatment axis of the constant marginal effect with the
# jacobian of the featurizer to map effects back to the original treatment space
einsum_str = 'myf, mtf->myt'
if ndim(T) == 1:
einsum_str = einsum_str.replace('t', '')
if ndim(feat_T) == 1:
einsum_str = einsum_str.replace('f', '')
if (ndim(eff) == ndim(feat_T)):
einsum_str = einsum_str.replace('y', '')
return np.einsum(einsum_str, eff, jac_T)
else:
return eff
def marginal_effect_interval(self, T, X=None, *, alpha=0.05):
X, T = self._expand_treatments(X, T)
effs = self.const_marginal_effect_interval(X=X, alpha=alpha)
if X is None: # need to repeat by the number of rows of T to ensure the right shape
effs = tuple(np.repeat(eff, shape(T)[0], axis=0) for eff in effs)
return effs
if self._original_treatment_featurizer:
return self._use_inference_method('marginal_effect_interval', T, X, alpha=alpha)
else:
X, T = self._expand_treatments(X, T)
effs = self.const_marginal_effect_interval(X=X, alpha=alpha)
if X is None: # need to repeat by the number of rows of T to ensure the right shape
effs = tuple(np.repeat(eff, shape(T)[0], axis=0) for eff in effs)
return effs
marginal_effect_interval.__doc__ = BaseCateEstimator.marginal_effect_interval.__doc__
def marginal_effect_inference(self, T, X=None):
X, T = self._expand_treatments(X, T)
cme_inf = self.const_marginal_effect_inference(X=X)
if X is None:
cme_inf = cme_inf._expand_outputs(shape(T)[0])
return cme_inf
if self._original_treatment_featurizer:
return self._use_inference_method('marginal_effect_inference', T, X)
else:
X, T = self._expand_treatments(X, T)
cme_inf = self.const_marginal_effect_inference(X=X)
if X is None:
cme_inf = cme_inf._expand_outputs(shape(T)[0])
return cme_inf
marginal_effect_inference.__doc__ = BaseCateEstimator.marginal_effect_inference.__doc__
@BaseCateEstimator._defer_to_inference
@ -697,11 +730,12 @@ class LinearCateEstimator(BaseCateEstimator):
Returns
-------
theta: (d_y, d_t) matrix
theta: (d_y, d_f_t) matrix where d_f_t is the dimension of the featurized treatment. \
If treatment_featurizer is None, d_f_t = d_t.
Average constant marginal CATE of each treatment on each outcome.
Note that when Y or T is a vector rather than a 2-dimensional array,
the corresponding singleton dimensions in the output will be collapsed
(e.g. if both are vectors, then the output of this method will be a scalar)
Note that when Y or featurized-T (or T if treatment_featurizer is None) is a vector
rather than a 2-dimensional array, the corresponding singleton dimensions in the output will be collapsed
(e.g. if both are vectors, then the output of this method will also be a scalar)
"""
return np.mean(self.const_marginal_effect(X=X), axis=0)
@ -748,15 +782,17 @@ class LinearCateEstimator(BaseCateEstimator):
pass
def marginal_ate(self, T, X=None):
return self.const_marginal_ate(X=X)
return np.mean(self.marginal_effect(T, X=X), axis=0)
marginal_ate.__doc__ = BaseCateEstimator.marginal_ate.__doc__
@BaseCateEstimator._defer_to_inference
def marginal_ate_interval(self, T, X=None, *, alpha=0.05):
return self.const_marginal_ate_interval(X=X, alpha=alpha)
pass
marginal_ate_interval.__doc__ = BaseCateEstimator.marginal_ate_interval.__doc__
@BaseCateEstimator._defer_to_inference
def marginal_ate_inference(self, T, X=None):
return self.const_marginal_ate_inference(X=X)
pass
marginal_ate_inference.__doc__ = BaseCateEstimator.marginal_ate_inference.__doc__
def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None, background_samples=100):
@ -769,7 +805,7 @@ class LinearCateEstimator(BaseCateEstimator):
feature_names: optional None or list of strings of length X.shape[1] (Default=None)
The names of input features.
treatment_names: optional None or list (Default=None)
The name of treatment. In discrete treatment scenario, the name should not include the name of
The name of featurized treatment. In discrete treatment scenario, the name should not include the name of
the baseline treatment (i.e. the control treatment, which by default is the alphabetically smaller)
output_names: optional None or list (Default=None)
The name of the outcome.
@ -793,9 +829,13 @@ class LinearCateEstimator(BaseCateEstimator):
class TreatmentExpansionMixin(BaseCateEstimator):
"""Mixin which automatically handles promotions of scalar treatments to the appropriate shape."""
"""
Mixin which automatically handles promotions of scalar treatments to the appropriate shape,
as well as treatment featurization for discrete treatments and user-specified treatment transformers
"""
transformer = None
_original_treatment_featurizer = None
def _prefit(self, Y, T, *args, **kwargs):
super()._prefit(Y, T, *args, **kwargs)
@ -808,7 +848,7 @@ class TreatmentExpansionMixin(BaseCateEstimator):
if self.transformer:
self._set_transformed_treatment_names()
def _expand_treatments(self, X=None, *Ts):
def _expand_treatments(self, X=None, *Ts, transform=True):
X, *Ts = check_input_arrays(X, *Ts)
n_rows = 1 if X is None else shape(X)[0]
outTs = []
@ -820,23 +860,29 @@ class TreatmentExpansionMixin(BaseCateEstimator):
if ndim(T) == 0:
T = np.full((n_rows,) + self._d_t_in, T)
if self.transformer:
T = self.transformer.transform(reshape(T, (-1, 1)))
if self.transformer and transform:
if not self._original_treatment_featurizer:
T = T.reshape(-1, 1)
T = self.transformer.transform(T)
outTs.append(T)
return (X,) + tuple(outTs)
def _set_transformed_treatment_names(self):
"""Works with sklearn OHEs"""
"""
Extracts treatment names from sklearn transformers.
Or, if transformer does not have a get_feature_names method, sets default treatment names.
"""
if hasattr(self, "_input_names"):
self._input_names["treatment_names"] = self.transformer.get_feature_names(
self._input_names["treatment_names"]).tolist()
ret = get_feature_names_or_default(self.transformer, self._input_names["treatment_names"], prefix='T')
self._input_names["treatment_names"] = list(ret) if ret is not None else ret
def cate_treatment_names(self, treatment_names=None):
"""
Get treatment names.
If the treatment is discrete, it will return expanded treatment names.
If the treatment is discrete or featurized, it will return expanded treatment names.
Parameters
----------
@ -851,7 +897,8 @@ class TreatmentExpansionMixin(BaseCateEstimator):
"""
if treatment_names is not None:
if self.transformer:
return self.transformer.get_feature_names(treatment_names).tolist()
ret = get_feature_names_or_default(self.transformer, treatment_names)
return list(ret) if ret is not None else None
return treatment_names
# Treatment names is None, default to BaseCateEstimator
return super().cate_treatment_names()
@ -891,6 +938,7 @@ class LinearModelFinalCateEstimatorMixin(BaseCateEstimator):
Whether the CATE model's intercept is contained in the final model's ``coef_`` rather
than as a separate ``intercept_``
"""
featurizer = None
def _get_inference_options(self):
options = super()._get_inference_options()
@ -1030,14 +1078,31 @@ class LinearModelFinalCateEstimatorMixin(BaseCateEstimator):
output_names = self.cate_output_names(output_names)
# Summary
smry = Summary()
smry.add_extra_txt(["<sub>A linear parametric conditional average treatment effect (CATE) model was fitted:",
"$Y = \\Theta(X)\\cdot T + g(X, W) + \\epsilon$",
"where for every outcome $i$ and treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:",
"$\\Theta_{ij}(X) = \\phi(X)' coef_{ij} + cate\\_intercept_{ij}$",
"where $\\phi(X)$ is the output of the `featurizer` or $X$ if `featurizer`=None. "
"Coefficient Results table portrays the $coef_{ij}$ parameter vector for "
"each outcome $i$ and treatment $j$. "
"Intercept Results table portrays the $cate\\_intercept_{ij}$ parameter.</sub>"])
extra_txt = ["<sub>A linear parametric conditional average treatment effect (CATE) model was fitted:"]
if self._original_treatment_featurizer:
extra_txt.append("$Y = \\Theta(X)\\cdot \\psi(T) + g(X, W) + \\epsilon$")
extra_txt.append("where $\\psi(T)$ is the output of the `treatment_featurizer")
extra_txt.append(
"and for every outcome $i$ and featurized treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:")
else:
extra_txt.append("$Y = \\Theta(X)\\cdot T + g(X, W) + \\epsilon$")
extra_txt.append(
"where for every outcome $i$ and treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:")
if self.featurizer:
extra_txt.append("$\\Theta_{ij}(X) = \\phi(X)' coef_{ij} + cate\\_intercept_{ij}$")
extra_txt.append("where $\\phi(X)$ is the output of the `featurizer`")
else:
extra_txt.append("$\\Theta_{ij}(X) = X' coef_{ij} + cate\\_intercept_{ij}$")
extra_txt.append("Coefficient Results table portrays the $coef_{ij}$ parameter vector for "
"each outcome $i$ and treatment $j$. "
"Intercept Results table portrays the $cate\\_intercept_{ij}$ parameter.</sub>")
smry.add_extra_txt(extra_txt)
d_t = self._d_t[0] if self._d_t else 1
d_y = self._d_y[0] if self._d_y else 1
try:

View file

@ -44,7 +44,7 @@ from ._cate_estimator import (BaseCateEstimator, LinearCateEstimator,
from .inference import BootstrapInference
from .utilities import (_deprecate_positional, check_input_arrays,
cross_product, filter_none_kwargs,
inverse_onehot, ndim, reshape, shape, transpose)
inverse_onehot, jacify_featurizer, ndim, reshape, shape, transpose)
def _crossfit(model, folds, *args, **kwargs):
@ -256,6 +256,11 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
discrete_treatment: bool
Whether the treatment values should be treated as categorical, rather than continuous, quantities
treatment_featurizer : :term:`transformer` or None
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
If featurizer=None, then CATE is trained on T.
discrete_instrument: bool
Whether the instrument values should be treated as categorical, rather than continuous, quantities
@ -337,8 +342,8 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
np.random.seed(123)
X = np.random.normal(size=(100, 3))
y = X[:, 0] + X[:, 1] + np.random.normal(0, 0.1, size=(100,))
est = OrthoLearner(cv=2, discrete_treatment=False, discrete_instrument=False,
categories='auto', random_state=None)
est = OrthoLearner(cv=2, discrete_treatment=False, treatment_featurizer=None,
discrete_instrument=False, categories='auto', random_state=None)
est.fit(y, X[:, 0], W=X[:, 1:])
>>> est.score_
@ -396,7 +401,7 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
T = np.random.binomial(1, scipy.special.expit(W[:, 0]))
y = T + W[:, 0] + np.random.normal(0, 0.01, size=(100,))
est = OrthoLearner(cv=2, discrete_treatment=True, discrete_instrument=False,
categories='auto', random_state=None)
treatment_featurizer=None, categories='auto', random_state=None)
est.fit(y, T, W=W)
>>> est.score_
@ -427,10 +432,12 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
"""
def __init__(self, *,
discrete_treatment, discrete_instrument, categories, cv, random_state,
discrete_treatment, treatment_featurizer,
discrete_instrument, categories, cv, random_state,
mc_iters=None, mc_agg='mean'):
self.cv = cv
self.discrete_treatment = discrete_treatment
self.treatment_featurizer = treatment_featurizer
self.discrete_instrument = discrete_instrument
self.random_state = random_state
self.categories = categories
@ -595,6 +602,8 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
self._random_state = check_random_state(self.random_state)
assert (freq_weight is None) == (
sample_var is None), "Sample variances and frequency weights must be provided together!"
assert not (self.discrete_treatment and self.treatment_featurizer), "Treatment featurization " \
"is not supported when treatment is discrete"
if check_input:
Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups = check_input_arrays(
Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups)
@ -609,6 +618,11 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
self.transformer = OneHotEncoder(categories=categories, sparse=False, drop='first')
self.transformer.fit(reshape(T, (-1, 1)))
self._d_t = (len(self.transformer.categories_[0]) - 1,)
elif self.treatment_featurizer:
self._original_treatment_featurizer = clone(self.treatment_featurizer, safe=False)
self.transformer = jacify_featurizer(self.treatment_featurizer)
output_T = self.transformer.fit_transform(T)
self._d_t = np.shape(output_T)[1:]
else:
self.transformer = None
@ -675,10 +689,21 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
# _d_t is altered by fit nuisances to what prefit does. So we need to perform the same
# alteration even when we only want to fit_final.
if self.transformer is not None:
self._d_t = (len(self.transformer.categories_[0]) - 1,)
if self.discrete_treatment:
self._d_t = (len(self.transformer.categories_[0]) - 1,)
else:
output_T = self.transformer.fit_transform(T)
self._d_t = np.shape(output_T)[1:]
final_T = T
if self.transformer:
if (self.discrete_treatment):
final_T = self.transformer.transform(final_T.reshape(-1, 1))
else: # treatment featurizer case
final_T = output_T
self._fit_final(Y=Y,
T=self.transformer.transform(T.reshape((-1, 1))) if self.transformer is not None else T,
T=final_T,
X=X, W=W, Z=Z,
nuisances=nuisances,
sample_weight=sample_weight,
@ -733,8 +758,10 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
if strata is None:
strata = T # always safe to pass T as second arg to split even if we're not actually stratifying
if self.discrete_treatment:
T = self.transformer.transform(reshape(T, (-1, 1)))
if self.transformer:
if self.discrete_treatment:
T = reshape(T, (-1, 1))
T = self.transformer.transform(T)
if self.discrete_instrument:
Z = self.z_transformer.transform(reshape(Z, (-1, 1)))

View file

@ -145,6 +145,11 @@ class _RLearner(_OrthoLearner):
discrete_treatment: bool
Whether the treatment values should be treated as categorical, rather than continuous, quantities
treatment_featurizer : :term:`transformer` or None
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
If featurizer=None, then CATE is trained on T.
categories: 'auto' or list
The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
The first category will be treated as the control treatment.
@ -218,7 +223,8 @@ class _RLearner(_OrthoLearner):
np.random.seed(123)
X = np.random.normal(size=(1000, 3))
y = X[:, 0] + X[:, 1] + np.random.normal(0, 0.01, size=(1000,))
est = RLearner(cv=2, discrete_treatment=False, categories='auto', random_state=None)
est = RLearner(cv=2, discrete_treatment=False,
treatment_featurizer=None, categories='auto', random_state=None)
est.fit(y, X[:, 0], X=np.ones((X.shape[0], 1)), W=X[:, 1:])
>>> est.const_marginal_effect(np.ones((1,1)))
@ -265,8 +271,10 @@ class _RLearner(_OrthoLearner):
is multidimensional, then the average of the MSEs for each dimension of Y is returned.
"""
def __init__(self, *, discrete_treatment, categories, cv, random_state, mc_iters=None, mc_agg='mean'):
def __init__(self, *, discrete_treatment, treatment_featurizer, categories,
cv, random_state, mc_iters=None, mc_agg='mean'):
super().__init__(discrete_treatment=discrete_treatment,
treatment_featurizer=treatment_featurizer,
discrete_instrument=False, # no instrument, so doesn't matter
categories=categories,
cv=cv,

View file

@ -197,6 +197,64 @@ class _GenericSingleOutcomeModelFinalWithCovInference(Inference):
return NormalInferenceResults(d_t=None, d_y=self.d_y, pred=pred,
pred_stderr=pred_stderr, mean_pred_stderr=None, inf_type='effect')
def marginal_effect_interval(self, T, X, alpha=0.05):
return self.marginal_effect_inference(T, X).conf_int(alpha=alpha)
def marginal_effect_inference(self, T, X):
if X is None:
raise ValueError("This inference method currently does not support X=None!")
if not self._est._original_treatment_featurizer:
return self.const_marginal_effect_inference(X)
X, T = self._est._expand_treatments(X, T, transform=False)
if self.featurizer is not None:
X = self.featurizer.transform(X)
feat_T = self._est.transformer.transform(T)
jac_T = self._est.transformer.jac(T)
d_t_orig = T.shape[1:]
d_t_orig = d_t_orig[0] if d_t_orig else 1
d_y = self._d_y[0] if self._d_y else 1
d_t = self._d_t[0] if self._d_t else 1
output_shape = [X.shape[0]]
if self._d_y:
output_shape.append(self._d_y[0])
if T.shape[1:]:
output_shape.append(T.shape[1])
me_pred = np.zeros(shape=output_shape)
me_stderr = np.zeros(shape=output_shape)
for i in range(d_t_orig):
# conditionally index multiple dimensions depending on shapes of T, Y and feat_T
jac_index = [slice(None)]
me_index = [slice(None)]
if self._d_y:
me_index.append(slice(None))
if T.shape[1:]:
jac_index.append(i)
me_index.append(i)
if feat_T.shape[1:]: # if featurized T is not a vector
jac_index.append(slice(None))
jac_slice = jac_T[tuple(jac_index)]
if jac_slice.ndim == 1:
jac_slice = jac_slice.reshape((-1, 1))
e_pred, e_var = self.model_final.predict_projection_and_var(X, jac_slice)
e_stderr = np.sqrt(e_var)
if not self._d_y:
e_pred = e_pred.squeeze(axis=1)
e_stderr = e_stderr.squeeze(axis=1)
me_pred[tuple(me_index)] = e_pred
me_stderr[tuple(me_index)] = e_stderr
return NormalInferenceResults(d_t=d_t_orig, d_y=self.d_y, pred=me_pred,
pred_stderr=me_stderr, mean_pred_stderr=None, inf_type='effect')
class CausalForestDML(_BaseDML):
"""A Causal Forest [cfdml1]_ combined with double machine learning based residualization of the treatment
@ -227,6 +285,11 @@ class CausalForestDML(_BaseDML):
It is ignored if X is None. The final CATE will be trained on the outcome of featurizer.fit_transform(X).
If featurizer=None, then CATE is trained on X.
treatment_featurizer : :term:`transformer`, optional
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
If featurizer=None, then CATE is trained on T.
discrete_treatment: bool, optional (default is ``False``)
Whether the treatment values should be treated as categorical, rather than continuous, quantities
@ -513,6 +576,7 @@ class CausalForestDML(_BaseDML):
model_y='auto',
model_t='auto',
featurizer=None,
treatment_featurizer=None,
discrete_treatment=False,
categories='auto',
cv=2,
@ -567,6 +631,7 @@ class CausalForestDML(_BaseDML):
self.n_jobs = n_jobs
self.verbose = verbose
super().__init__(discrete_treatment=discrete_treatment,
treatment_featurizer=treatment_featurizer,
categories=categories,
cv=cv,
mc_iters=mc_iters,

View file

@ -360,6 +360,11 @@ class DML(LinearModelFinalCateEstimatorMixin, _BaseDML):
It is ignored if X is None. The final CATE will be trained on the outcome of featurizer.fit_transform(X).
If featurizer=None, then CATE is trained on X.
treatment_featurizer : :term:`transformer`, optional, default None
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
If featurizer=None, then CATE is trained on T.
fit_cate_intercept : bool, optional, default True
Whether the linear CATE model should have a constant term.
@ -452,6 +457,7 @@ class DML(LinearModelFinalCateEstimatorMixin, _BaseDML):
def __init__(self, *,
model_y, model_t, model_final,
featurizer=None,
treatment_featurizer=None,
fit_cate_intercept=True,
linear_first_stages=False,
discrete_treatment=False,
@ -469,6 +475,7 @@ class DML(LinearModelFinalCateEstimatorMixin, _BaseDML):
self.model_t = clone(model_t, safe=False)
self.model_final = clone(model_final, safe=False)
super().__init__(discrete_treatment=discrete_treatment,
treatment_featurizer=treatment_featurizer,
categories=categories,
cv=cv,
mc_iters=mc_iters,
@ -585,6 +592,11 @@ class LinearDML(StatsModelsCateEstimatorMixin, DML):
It is ignored if X is None. The final CATE will be trained on the outcome of featurizer.fit_transform(X).
If featurizer=None, then CATE is trained on X.
treatment_featurizer : :term:`transformer`, optional
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
If featurizer=None, then CATE is trained on T.
fit_cate_intercept : bool, optional, default True
Whether the linear CATE model should have a constant term.
@ -670,6 +682,7 @@ class LinearDML(StatsModelsCateEstimatorMixin, DML):
def __init__(self, *,
model_y='auto', model_t='auto',
featurizer=None,
treatment_featurizer=None,
fit_cate_intercept=True,
linear_first_stages=True,
discrete_treatment=False,
@ -682,6 +695,7 @@ class LinearDML(StatsModelsCateEstimatorMixin, DML):
model_t=model_t,
model_final=None,
featurizer=featurizer,
treatment_featurizer=treatment_featurizer,
fit_cate_intercept=fit_cate_intercept,
linear_first_stages=linear_first_stages,
discrete_treatment=discrete_treatment,
@ -811,6 +825,11 @@ class SparseLinearDML(DebiasedLassoCateEstimatorMixin, DML):
It is ignored if X is None. The final CATE will be trained on the outcome of featurizer.fit_transform(X).
If featurizer=None, then CATE is trained on X.
treatment_featurizer : :term:`transformer`, optional, default None
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
If featurizer=None, then CATE is trained on T.
fit_cate_intercept : bool, optional, default True
Whether the linear CATE model should have a constant term.
@ -902,6 +921,7 @@ class SparseLinearDML(DebiasedLassoCateEstimatorMixin, DML):
tol=1e-4,
n_jobs=None,
featurizer=None,
treatment_featurizer=None,
fit_cate_intercept=True,
linear_first_stages=True,
discrete_treatment=False,
@ -921,6 +941,7 @@ class SparseLinearDML(DebiasedLassoCateEstimatorMixin, DML):
model_t=model_t,
model_final=None,
featurizer=featurizer,
treatment_featurizer=treatment_featurizer,
fit_cate_intercept=fit_cate_intercept,
linear_first_stages=linear_first_stages,
discrete_treatment=discrete_treatment,
@ -1041,6 +1062,11 @@ class KernelDML(DML):
discrete_treatment: bool, optional (default is ``False``)
Whether the treatment values should be treated as categorical, rather than continuous, quantities
treatment_featurizer : :term:`transformer`, optional
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
If featurizer=None, then CATE is trained on T.
categories: 'auto' or list, default 'auto'
The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
The first category will be treated as the control treatment.
@ -1101,7 +1127,9 @@ class KernelDML(DML):
"""
def __init__(self, model_y='auto', model_t='auto',
discrete_treatment=False, categories='auto',
discrete_treatment=False,
treatment_featurizer=None,
categories='auto',
fit_cate_intercept=True,
dim=20,
bw=1.0,
@ -1114,6 +1142,7 @@ class KernelDML(DML):
model_t=model_t,
model_final=None,
featurizer=None,
treatment_featurizer=treatment_featurizer,
fit_cate_intercept=fit_cate_intercept,
discrete_treatment=discrete_treatment,
categories=categories,
@ -1212,6 +1241,11 @@ class NonParamDML(_BaseDML):
discrete_treatment: bool, optional (default is ``False``)
Whether the treatment values should be treated as categorical, rather than continuous, quantities
treatment_featurizer : :term:`transformer`, optional
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
If featurizer=None, then CATE is trained on T.
categories: 'auto' or list, default 'auto'
The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
The first category will be treated as the control treatment.
@ -1282,6 +1316,7 @@ class NonParamDML(_BaseDML):
model_y, model_t, model_final,
featurizer=None,
discrete_treatment=False,
treatment_featurizer=None,
categories='auto',
cv=2,
mc_iters=None,
@ -1295,6 +1330,7 @@ class NonParamDML(_BaseDML):
self.featurizer = clone(featurizer, safe=False)
self.model_final = clone(model_final, safe=False)
super().__init__(discrete_treatment=discrete_treatment,
treatment_featurizer=treatment_featurizer,
categories=categories,
cv=cv,
mc_iters=mc_iters,

View file

@ -44,7 +44,7 @@ from sklearn.linear_model import (LassoCV, LinearRegression,
from sklearn.ensemble import RandomForestRegressor
from .._ortho_learner import _OrthoLearner
from .._cate_estimator import (DebiasedLassoCateEstimatorDiscreteMixin,
from .._cate_estimator import (DebiasedLassoCateEstimatorDiscreteMixin, BaseCateEstimator,
ForestModelFinalCateEstimatorDiscreteMixin,
StatsModelsCateEstimatorDiscreteMixin, LinearCateEstimator)
from ..inference import GenericModelFinalInferenceDiscrete
@ -420,10 +420,54 @@ class DRLearner(_OrthoLearner):
mc_iters=mc_iters,
mc_agg=mc_agg,
discrete_treatment=True,
treatment_featurizer=None, # treatment featurization not supported with discrete treatment
discrete_instrument=False, # no instrument, so doesn't matter
categories=categories,
random_state=random_state)
# override only so that we can exclude treatment featurization verbiage in docstring
def const_marginal_effect(self, X=None):
"""
Calculate the constant marginal CATE :math:`\\theta(·)`.
The marginal effect is conditional on a vector of
features on a set of m test samples X[i].
Parameters
----------
X: optional (m, d_x) matrix or None (Default=None)
Features for each sample.
Returns
-------
theta: (m, d_y, d_t) matrix or (d_y, d_t) matrix if X is None
Constant marginal CATE of each treatment on each outcome for each sample X[i].
Note that when Y or T is a vector rather than a 2-dimensional array,
the corresponding singleton dimensions in the output will be collapsed
(e.g. if both are vectors, then the output of this method will also be a vector)
"""
return super().const_marginal_effect(X=X)
# override only so that we can exclude treatment featurization verbiage in docstring
def const_marginal_ate(self, X=None):
"""
Calculate the average constant marginal CATE :math:`E_X[\\theta(X)]`.
Parameters
----------
X: optional (m, d_x) matrix or None (Default=None)
Features for each sample.
Returns
-------
theta: (d_y, d_t) matrix
Average constant marginal CATE of each treatment on each outcome.
Note that when Y or T is a vector rather than a 2-dimensional array,
the corresponding singleton dimensions in the output will be collapsed
(e.g. if both are vectors, then the output of this method will be a scalar)
"""
return super().const_marginal_ate(X=X)
def _get_inference_options(self):
options = super()._get_inference_options()
if not self.multitask_model_final:

View file

@ -469,6 +469,7 @@ class DynamicDML(LinearModelFinalCateEstimatorMixin, _OrthoLearner):
self.model_y = clone(model_y, safe=False)
self.model_t = clone(model_t, safe=False)
super().__init__(discrete_treatment=discrete_treatment,
treatment_featurizer=None,
discrete_instrument=False,
categories=categories,
cv=GroupKFold(cv) if isinstance(cv, int) else cv,
@ -476,6 +477,49 @@ class DynamicDML(LinearModelFinalCateEstimatorMixin, _OrthoLearner):
mc_agg=mc_agg,
random_state=random_state)
# override only so that we can exclude treatment featurization verbiage in docstring
def const_marginal_effect(self, X=None):
"""
Calculate the constant marginal CATE :math:`\\theta(·)`.
The marginal effect is conditional on a vector of
features on a set of m test samples X[i].
Parameters
----------
X: optional (m, d_x) matrix or None (Default=None)
Features for each sample.
Returns
-------
theta: (m, d_y, d_t) matrix or (d_y, d_t) matrix if X is None
Constant marginal CATE of each treatment on each outcome for each sample X[i].
Note that when Y or T is a vector rather than a 2-dimensional array,
the corresponding singleton dimensions in the output will be collapsed
(e.g. if both are vectors, then the output of this method will also be a vector)
"""
return super().const_marginal_effect(X=X)
# override only so that we can exclude treatment featurization verbiage in docstring
def const_marginal_ate(self, X=None):
"""
Calculate the average constant marginal CATE :math:`E_X[\\theta(X)]`.
Parameters
----------
X: optional (m, d_x) matrix or None (Default=None)
Features for each sample.
Returns
-------
theta: (d_y, d_t) matrix
Average constant marginal CATE of each treatment on each outcome.
Note that when Y or T is a vector rather than a 2-dimensional array,
the corresponding singleton dimensions in the output will be collapsed
(e.g. if both are vectors, then the output of this method will be a scalar)
"""
return super().const_marginal_ate(X=X)
def _gen_featurizer(self):
return clone(self.featurizer, safe=False)
@ -705,13 +749,13 @@ class DynamicDML(LinearModelFinalCateEstimatorMixin, _OrthoLearner):
return feature_names
return get_feature_names_or_default(self.original_featurizer, feature_names)
def _expand_treatments(self, X, *Ts):
def _expand_treatments(self, X, *Ts, transform=True):
# Expand treatments for each time period
outTs = []
base_expand_treatments = super()._expand_treatments
for T in Ts:
if ndim(T) == 0:
one_T = base_expand_treatments(X, T)[1]
one_T = base_expand_treatments(X, T, transform=transform)[1]
one_T = one_T.reshape(-1, 1) if ndim(one_T) == 1 else one_T
T = np.tile(one_T, (1, self._n_periods, ))
else:
@ -720,7 +764,7 @@ class DynamicDML(LinearModelFinalCateEstimatorMixin, _OrthoLearner):
if self.transformer:
T = np.hstack([
base_expand_treatments(
X, T[:, [t]])[1] for t in range(self._n_periods)
X, T[:, [t]], transform=transform)[1] for t in range(self._n_periods)
])
outTs.append(T)
return (X,) + tuple(outTs)

View file

@ -15,7 +15,7 @@ from ._bootstrap import BootstrapEstimator
from ..sklearn_extensions.linear_model import StatsModelsLinearRegression
from ..utilities import (Summary, _safe_norm_ppf, broadcast_unit_treatments,
cross_product, inverse_onehot, ndim,
parse_final_model_params,
parse_final_model_params, jacify_featurizer,
reshape_treatmentwise_effects, shape, filter_none_kwargs)
"""Options for performing inference in estimators."""
@ -201,6 +201,42 @@ class GenericSingleTreatmentModelFinalInference(GenericModelFinalInference):
feature_names=self._est.cate_feature_names(),
output_names=self._est.cate_output_names())
def marginal_effect_inference(self, T, X):
X, T = self._est._expand_treatments(X, T, transform=False)
cme_inf = self.const_marginal_effect_inference(X)
if not self._est._original_treatment_featurizer:
return cme_inf
feat_T = self._est.transformer.transform(T)
cme_pred = cme_inf.point_estimate
cme_stderr = cme_inf.stderr
jac_T = self._est.transformer.jac(T)
einsum_str = 'myf, mtf->myt'
if ndim(T) == 1:
einsum_str = einsum_str.replace('t', '')
if ndim(feat_T) == 1:
einsum_str = einsum_str.replace('f', '')
# y is a vector, rather than a 2D array
if (ndim(cme_pred) == ndim(feat_T)):
einsum_str = einsum_str.replace('y', '')
e_pred = np.einsum(einsum_str, cme_pred, jac_T)
e_stderr = np.einsum(einsum_str, cme_stderr, np.abs(jac_T)) if cme_stderr is not None else None
d_y = self._d_y[0] if self._d_y else 1
d_t = self._d_t[0] if self._d_t else 1
d_t_orig = T.shape[1:][0] if T.shape[1:] else 1
return NormalInferenceResults(d_t=d_t_orig, d_y=d_y, pred=e_pred,
pred_stderr=e_stderr, mean_pred_stderr=None, inf_type='effect',
feature_names=self._est.cate_feature_names(),
output_names=self._est.cate_output_names())
def marginal_effect_interval(self, T, X, *, alpha=0.05):
return self.marginal_effect_inference(T, X).conf_int(alpha=alpha)
class LinearModelFinalInference(GenericModelFinalInference):
"""
@ -274,6 +310,68 @@ class LinearModelFinalInference(GenericModelFinalInference):
inf_res.mean_pred_stderr = np.squeeze(mean_pred_stderr, axis=0)
return inf_res
def marginal_effect_inference(self, T, X):
X, T = self._est._expand_treatments(X, T, transform=False)
if not self._est._original_treatment_featurizer:
return self.const_marginal_effect_inference(X)
if X is None:
X = np.ones((T.shape[0], 1))
elif self.featurizer is not None:
X = self.featurizer.transform(X)
feat_T = self._est.transformer.transform(T)
jac_T = self._est.transformer.jac(T)
d_t_orig = T.shape[1:]
d_t_orig = d_t_orig[0] if d_t_orig else 1
d_y = self._d_y[0] if self._d_y else 1
d_t = self._d_t[0] if self._d_t else 1
output_shape = [X.shape[0]]
if self._d_y:
output_shape.append(self._d_y[0])
if T.shape[1:]:
output_shape.append(T.shape[1])
me_pred = np.zeros(shape=output_shape)
me_stderr = np.zeros(shape=output_shape)
mean_pred_stderr_res = np.zeros(shape=output_shape[1:])
for i in range(d_t_orig):
# conditionally index multiple dimensions depending on shapes of T, Y and feat_T
jac_index = [slice(None)]
me_index = [slice(None)]
if self._d_y:
me_index.append(slice(None))
if T.shape[1:]:
jac_index.append(i)
me_index.append(i)
if feat_T.shape[1:]: # if featurized T is not a vector
jac_index.append(slice(None))
XT = cross_product(X, jac_T[tuple(jac_index)])
e_pred = self._predict(XT).reshape(X.shape[:1] + self._d_y) # enforce output shape
e_stderr = self._prediction_stderr(XT).reshape(X.shape[:1] + self._d_y)
mean_XT = XT.mean(axis=0, keepdims=True)
mean_pred_stderr = self._prediction_stderr(mean_XT) # shape[0] will always be 1 here
# squeeze the first axis
mean_pred_stderr = np.squeeze(mean_pred_stderr, axis=0) if mean_pred_stderr is not None else None
if mean_pred_stderr is not None:
mean_pred_stderr_res[tuple(me_index[1:])] = mean_pred_stderr
me_pred[tuple(me_index)] = e_pred
me_stderr[tuple(me_index)] = e_stderr
return NormalInferenceResults(d_t=d_t_orig, d_y=d_y, pred=me_pred,
pred_stderr=me_stderr, mean_pred_stderr=mean_pred_stderr_res, inf_type='effect',
feature_names=self._est.cate_feature_names(),
output_names=self._est.cate_output_names())
def marginal_effect_interval(self, T, X, *, alpha=0.05):
return self.marginal_effect_inference(T, X).conf_int(alpha=alpha)
def coef__interval(self, *, alpha=0.05):
lo, hi = self.model_final.coef__interval(alpha)
lo_int, hi_int = self.model_final.intercept__interval(alpha)
@ -817,7 +915,7 @@ class InferenceResults(metaclass=abc.ABCMeta):
The mean value of the metric you'd like to test under null hypothesis.
decimals: optional int (default=3)
Number of decimal places to round each column to.
tol: optinal float (default=0.001)
tol: optional float (default=0.001)
The stopping criterion. The iterations will stop when the outcome is less than ``tol``
output_names: optional list of strings or None (default is None)
The names of the outputs

View file

@ -237,6 +237,11 @@ class OrthoIV(LinearModelFinalCateEstimatorMixin, _OrthoLearner):
discrete_treatment: bool, optional, default False
Whether the treatment values should be treated as categorical, rather than continuous, quantities
treatment_featurizer : :term:`transformer`, optional
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
If featurizer=None, then CATE is trained on T.
discrete_instrument: bool, optional, default False
Whether the instrument values should be treated as categorical, rather than continuous, quantities
@ -338,6 +343,7 @@ class OrthoIV(LinearModelFinalCateEstimatorMixin, _OrthoLearner):
featurizer=None,
fit_cate_intercept=True,
discrete_treatment=False,
treatment_featurizer=None,
discrete_instrument=False,
categories='auto',
cv=2,
@ -354,6 +360,7 @@ class OrthoIV(LinearModelFinalCateEstimatorMixin, _OrthoLearner):
super().__init__(discrete_instrument=discrete_instrument,
discrete_treatment=discrete_treatment,
treatment_featurizer=treatment_featurizer,
categories=categories,
cv=cv,
mc_iters=mc_iters,
@ -1039,6 +1046,11 @@ class DMLIV(_BaseDMLIV):
discrete_treatment: bool, optional, default False
Whether the treatment values should be treated as categorical, rather than continuous, quantities
treatment_featurizer : :term:`transformer`, optional
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
If featurizer=None, then CATE is trained on T.
categories: 'auto' or list, default 'auto'
The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
The first category will be treated as the control treatment.
@ -1129,6 +1141,7 @@ class DMLIV(_BaseDMLIV):
featurizer=None,
fit_cate_intercept=True,
discrete_treatment=False,
treatment_featurizer=None,
discrete_instrument=False,
categories='auto',
cv=2,
@ -1142,6 +1155,7 @@ class DMLIV(_BaseDMLIV):
self.featurizer = clone(featurizer, safe=False)
self.fit_cate_intercept = fit_cate_intercept
super().__init__(discrete_treatment=discrete_treatment,
treatment_featurizer=treatment_featurizer,
discrete_instrument=discrete_instrument,
categories=categories,
cv=cv,
@ -1282,14 +1296,30 @@ class DMLIV(_BaseDMLIV):
feature_names = self.cate_feature_names(feature_names)
# Summary
smry = Summary()
smry.add_extra_txt(["<sub>A linear parametric conditional average treatment effect (CATE) model was fitted:",
"$Y = \\Theta(X)\\cdot T + g(X, W) + \\epsilon$",
"where for every outcome $i$ and treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:",
"$\\Theta_{ij}(X) = \\phi(X)' coef_{ij} + cate\\_intercept_{ij}$",
"where $\\phi(X)$ is the output of the `featurizer` or $X$ if `featurizer`=None. "
"Coefficient Results table portrays the $coef_{ij}$ parameter vector for "
"each outcome $i$ and treatment $j$. "
"Intercept Results table portrays the $cate\\_intercept_{ij}$ parameter.</sub>"])
extra_txt = ["<sub>A linear parametric conditional average treatment effect (CATE) model was fitted:"]
if self._original_treatment_featurizer:
extra_txt.append("$Y = \\Theta(X)\\cdot \\psi(T) + g(X, W) + \\epsilon$")
extra_txt.append("where $\\psi(T)$ is the output of the `treatment_featurizer")
extra_txt.append(
"and for every outcome $i$ and featurized treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:")
else:
extra_txt.append("$Y = \\Theta(X)\\cdot T + g(X, W) + \\epsilon$")
extra_txt.append(
"where for every outcome $i$ and treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:")
if self.featurizer:
extra_txt.append("$\\Theta_{ij}(X) = \\phi(X)' coef_{ij} + cate\\_intercept_{ij}$")
extra_txt.append("where $\\phi(X)$ is the output of the `featurizer`")
else:
extra_txt.append("$\\Theta_{ij}(X) = X' coef_{ij} + cate\\_intercept_{ij}$")
extra_txt.append("Coefficient Results table portrays the $coef_{ij}$ parameter vector for "
"each outcome $i$ and treatment $j$. "
"Intercept Results table portrays the $cate\\_intercept_{ij}$ parameter.</sub>")
smry.add_extra_txt(extra_txt)
d_t = self._d_t[0] if self._d_t else 1
d_y = self._d_y[0] if self._d_y else 1
@ -1403,6 +1433,11 @@ class NonParamDMLIV(_BaseDMLIV):
discrete_treatment: bool, optional, default False
Whether the treatment values should be treated as categorical, rather than continuous, quantities
treatment_featurizer : :term:`transformer`, optional
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
If featurizer=None, then CATE is trained on T.
discrete_instrument: bool, optional, default False
Whether the instrument values should be treated as categorical, rather than continuous, quantities
@ -1495,6 +1530,7 @@ class NonParamDMLIV(_BaseDMLIV):
model_t_xwz="auto",
model_final,
discrete_treatment=False,
treatment_featurizer=None,
discrete_instrument=False,
featurizer=None,
categories='auto',
@ -1509,6 +1545,7 @@ class NonParamDMLIV(_BaseDMLIV):
self.featurizer = clone(featurizer, safe=False)
super().__init__(discrete_treatment=discrete_treatment,
discrete_instrument=discrete_instrument,
treatment_featurizer=treatment_featurizer,
categories=categories,
cv=cv,
mc_iters=mc_iters,

View file

@ -303,6 +303,7 @@ class _BaseDRIV(_OrthoLearner):
opt_reweighted=False,
discrete_instrument=False,
discrete_treatment=False,
treatment_featurizer=None,
categories='auto',
cv=2,
mc_iters=None,
@ -315,6 +316,7 @@ class _BaseDRIV(_OrthoLearner):
self.opt_reweighted = opt_reweighted
super().__init__(discrete_instrument=discrete_instrument,
discrete_treatment=discrete_treatment,
treatment_featurizer=treatment_featurizer,
categories=categories,
cv=cv,
mc_iters=mc_iters,
@ -344,6 +346,8 @@ class _BaseDRIV(_OrthoLearner):
if len(T1.shape) > 1 and T1.shape[1] > 1:
if self.discrete_treatment:
raise AttributeError("DRIV only supports binary treatments")
elif self.treatment_featurizer: # defer possible failure to downstream logic
pass
else:
raise AttributeError("DRIV only supports single-dimensional continuous treatments")
if len(Z1.shape) > 1 and Z1.shape[1] > 1:
@ -548,6 +552,7 @@ class _DRIV(_BaseDRIV):
opt_reweighted=False,
discrete_instrument=False,
discrete_treatment=False,
treatment_featurizer=None,
categories='auto',
cv=2,
mc_iters=None,
@ -567,6 +572,7 @@ class _DRIV(_BaseDRIV):
opt_reweighted=opt_reweighted,
discrete_instrument=discrete_instrument,
discrete_treatment=discrete_treatment,
treatment_featurizer=treatment_featurizer,
categories=categories,
cv=cv,
mc_iters=mc_iters,
@ -733,6 +739,11 @@ class DRIV(_DRIV):
discrete_treatment: bool, optional, default False
Whether the treatment values should be treated as categorical, rather than continuous, quantities
treatment_featurizer : :term:`transformer`, optional
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
The final CATE will be trained on the outcome of treatment_featurizer.fit_transform(T).
If treatment_featurizer=None, then CATE is trained on T.
categories: 'auto' or list, default 'auto'
The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
The first category will be treated as the control treatment.
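A hypothetical usage sketch tying this parameter to the dimension check in `_BaseDRIV` above: a multi-column continuous treatment is accepted only because the featurizer collapses it to a single column before the final stage. The data and featurizer below are illustrative, loosely following the sum-featurizer test configuration added by this commit:

```python
# Hypothetical sketch: a 2-column continuous treatment is accepted only because the
# featurizer collapses it to one column, psi(T) = T1 + T2, before the final stage.
import numpy as np
from sklearn.preprocessing import FunctionTransformer
from econml.iv.dr import DRIV

n = 2000
X = np.random.normal(size=(n, 5))
Z = np.random.normal(size=(n, 1))
T = np.hstack([Z, Z]) + np.random.normal(size=(n, 2))
Y = 0.5 * T.sum(axis=1) + 0.2 * X[:, 0] + np.random.normal(size=n)

sum_featurizer = FunctionTransformer(func=lambda x: x.sum(axis=1, keepdims=True))
est = DRIV(fit_cate_intercept=True, treatment_featurizer=sum_featurizer)
est.fit(Y=Y, T=T, X=X, Z=Z)
est.marginal_effect(T[:3], X[:3])  # shape (3, 2): d/dT1 and d/dT2, both roughly 0.5 here
# Without treatment_featurizer the same two-column T raises:
#   AttributeError: DRIV only supports single-dimensional continuous treatments
```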
@ -828,6 +839,7 @@ class DRIV(_DRIV):
opt_reweighted=False,
discrete_instrument=False,
discrete_treatment=False,
treatment_featurizer=None,
categories='auto',
cv=2,
mc_iters=None,
@ -855,6 +867,7 @@ class DRIV(_DRIV):
opt_reweighted=opt_reweighted,
discrete_instrument=discrete_instrument,
discrete_treatment=discrete_treatment,
treatment_featurizer=treatment_featurizer,
categories=categories,
cv=cv,
mc_iters=mc_iters,
@ -1177,6 +1190,11 @@ class LinearDRIV(StatsModelsCateEstimatorMixin, DRIV):
discrete_treatment: bool, optional, default False
Whether the treatment values should be treated as categorical, rather than continuous, quantities
treatment_featurizer : :term:`transformer`, optional
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
The final CATE will be trained on the outcome of treatment_featurizer.fit_transform(T).
If treatment_featurizer=None, then CATE is trained on T.
categories: 'auto' or list, default 'auto'
The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
The first category will be treated as the control treatment.
@ -1283,6 +1301,7 @@ class LinearDRIV(StatsModelsCateEstimatorMixin, DRIV):
opt_reweighted=False,
discrete_instrument=False,
discrete_treatment=False,
treatment_featurizer=None,
categories='auto',
cv=2,
mc_iters=None,
@ -1305,6 +1324,7 @@ class LinearDRIV(StatsModelsCateEstimatorMixin, DRIV):
opt_reweighted=opt_reweighted,
discrete_instrument=discrete_instrument,
discrete_treatment=discrete_treatment,
treatment_featurizer=treatment_featurizer,
categories=categories,
cv=cv,
mc_iters=mc_iters,
@ -1493,6 +1513,11 @@ class SparseLinearDRIV(DebiasedLassoCateEstimatorMixin, DRIV):
discrete_treatment: bool, optional, default False
Whether the treatment values should be treated as categorical, rather than continuous, quantities
treatment_featurizer : :term:`transformer`, optional
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
The final CATE will be trained on the outcome of treatment_featurizer.fit_transform(T).
If treatment_featurizer=None, then CATE is trained on T.
categories: 'auto' or list, default 'auto'
The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
The first category will be treated as the control treatment.
@ -1606,6 +1631,7 @@ class SparseLinearDRIV(DebiasedLassoCateEstimatorMixin, DRIV):
opt_reweighted=False,
discrete_instrument=False,
discrete_treatment=False,
treatment_featurizer=None,
categories='auto',
cv=2,
mc_iters=None,
@ -1635,6 +1661,7 @@ class SparseLinearDRIV(DebiasedLassoCateEstimatorMixin, DRIV):
opt_reweighted=opt_reweighted,
discrete_instrument=discrete_instrument,
discrete_treatment=discrete_treatment,
treatment_featurizer=treatment_featurizer,
categories=categories,
cv=cv,
mc_iters=mc_iters,
@ -1897,6 +1924,11 @@ class ForestDRIV(ForestModelFinalCateEstimatorMixin, DRIV):
discrete_treatment: bool, optional, default False
Whether the treatment values should be treated as categorical, rather than continuous, quantities
treatment_featurizer : :term:`transformer`, optional
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
The final CATE will be trained on the outcome of treatment_featurizer.fit_transform(T).
If treatment_featurizer=None, then CATE is trained on T.
categories: 'auto' or list, default 'auto'
The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
The first category will be treated as the control treatment.
@ -2006,6 +2038,7 @@ class ForestDRIV(ForestModelFinalCateEstimatorMixin, DRIV):
opt_reweighted=False,
discrete_instrument=False,
discrete_treatment=False,
treatment_featurizer=None,
categories='auto',
cv=2,
mc_iters=None,
@ -2041,6 +2074,7 @@ class ForestDRIV(ForestModelFinalCateEstimatorMixin, DRIV):
opt_reweighted=opt_reweighted,
discrete_instrument=discrete_instrument,
discrete_treatment=discrete_treatment,
treatment_featurizer=treatment_featurizer,
categories=categories,
cv=cv,
mc_iters=mc_iters,


@ -40,8 +40,8 @@ from ._causal_tree import CausalTree
from ..inference import NormalInferenceResults
from ..inference._inference import Inference
from ..utilities import (reshape, reshape_Y_T, MAX_RAND_SEED, check_inputs, _deprecate_positional,
cross_product, inverse_onehot, check_input_arrays,
_RegressionWrapper, deprecated)
cross_product, inverse_onehot, check_input_arrays, jacify_featurizer,
_RegressionWrapper, deprecated, ndim)
from sklearn.model_selection import check_cv
# TODO: consider working around relying on sklearn implementation details
from ..sklearn_extensions.model_selection import _cross_val_predict
@ -209,6 +209,7 @@ class BaseOrthoForest(TreatmentExpansionMixin, LinearCateEstimator):
second_stage_parameter_estimator,
moment_and_mean_gradient_estimator,
discrete_treatment=False,
treatment_featurizer=None,
categories='auto',
n_trees=500,
min_leaf_size=10, max_depth=10,
@ -244,6 +245,7 @@ class BaseOrthoForest(TreatmentExpansionMixin, LinearCateEstimator):
# Fit check
self.model_is_fitted = False
self.discrete_treatment = discrete_treatment
self.treatment_featurizer = treatment_featurizer
self.backend = backend
self.verbose = verbose
self.batch_size = batch_size
@ -312,7 +314,8 @@ class BaseOrthoForest(TreatmentExpansionMixin, LinearCateEstimator):
Returns
-------
Theta : matrix , shape (n, d_t)
Theta : matrix , shape (n, d_f_t) where d_f_t is \
the dimension of the featurized treatment. If treatment_featurizer is None, d_f_t = d_t
Constant marginal CATE of each treatment for each sample.
"""
# TODO: Check performance
@ -501,6 +504,11 @@ class DMLOrthoForest(BaseOrthoForest):
one-hot-encoded and the model_T is treated as a classifier that must have a predict_proba
method.
treatment_featurizer : :term:`transformer`, optional
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
The final CATE will be trained on the outcome of treatment_featurizer.fit_transform(T).
If treatment_featurizer=None, then CATE is trained on T.
categories : array like or 'auto', optional (default='auto')
A list of pre-specified treatment categories. If 'auto' then categories are automatically
recognized at fit time.
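A minimal sketch of the new parameter on DMLOrthoForest, mirroring the shapes exercised by the ortho-forest tests later in this commit; the dummy first-stage models and data are placeholders chosen only to keep the example cheap to run:

```python
# Illustrative sketch: DMLOrthoForest with an identity treatment featurizer, psi(T) = T.
import numpy as np
from sklearn.dummy import DummyRegressor
from sklearn.preprocessing import FunctionTransformer
from econml.orf import DMLOrthoForest

n = 500
X = np.random.normal(size=(n, 3))
T = np.random.normal(size=n)
y = T * X[:, 0] + np.random.normal(size=n)

est = DMLOrthoForest(n_trees=10,
                     model_Y=DummyRegressor(strategy='mean'),
                     model_T=DummyRegressor(strategy='mean'),
                     treatment_featurizer=FunctionTransformer(),   # identity
                     n_jobs=1)
est.fit(y.reshape(-1, 1), T, X=X)
est.const_marginal_effect(X[:3])   # shape (3, 1): one column per featurized treatment
est.marginal_effect(1, X[:3])      # shape (3, 1)
```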
@ -540,6 +548,7 @@ class DMLOrthoForest(BaseOrthoForest):
global_residualization=False,
global_res_cv=2,
discrete_treatment=False,
treatment_featurizer=None,
categories='auto',
n_jobs=-1,
backend='loky',
@ -567,6 +576,7 @@ class DMLOrthoForest(BaseOrthoForest):
self.random_state = check_random_state(random_state)
self.global_residualization = global_residualization
self.global_res_cv = global_res_cv
self.treatment_featurizer = treatment_featurizer
# Define nuisance estimators
nuisance_estimator = _DMLOrthoForest_nuisance_estimator_generator(
self.model_T, self.model_Y, self.random_state, second_stage=False,
@ -580,6 +590,7 @@ class DMLOrthoForest(BaseOrthoForest):
self.lambda_reg)
# Define the moment and mean gradient estimator
moment_and_mean_gradient_estimator = _DMLOrthoForest_moment_and_mean_gradient_estimator_func
super().__init__(
nuisance_estimator,
second_stage_nuisance_estimator,
@ -596,6 +607,7 @@ class DMLOrthoForest(BaseOrthoForest):
verbose=verbose,
batch_size=batch_size,
discrete_treatment=discrete_treatment,
treatment_featurizer=treatment_featurizer,
categories=categories,
random_state=self.random_state)
@ -635,6 +647,9 @@ class DMLOrthoForest(BaseOrthoForest):
"""
self._set_input_names(Y, T, X, set_flag=True)
Y, T, X, W = check_inputs(Y, T, X, W)
assert not (self.discrete_treatment and self.treatment_featurizer), "Treatment featurization " \
"is not supported when treatment is discrete"
if self.discrete_treatment:
categories = self.categories
if categories != 'auto':
@ -643,6 +658,12 @@ class DMLOrthoForest(BaseOrthoForest):
d_t_in = T.shape[1:]
T = self.transformer.fit_transform(T.reshape(-1, 1))
self._d_t = T.shape[1:]
elif self.treatment_featurizer:
self._original_treatment_featurizer = clone(self.treatment_featurizer, safe=False)
self.transformer = jacify_featurizer(self.treatment_featurizer)
d_t_in = T.shape[1:]
T = self.transformer.fit_transform(T)
self._d_t = np.shape(T)[1:]
if self.global_residualization:
cv = check_cv(self.global_res_cv, y=T, classifier=self.discrete_treatment)
@ -654,7 +675,7 @@ class DMLOrthoForest(BaseOrthoForest):
# weirdness of wrap_fit. We need to store d_t_in. But because wrap_fit decorates the parent
# fit, we need to set explicitly d_t_in here after super fit is called.
if self.discrete_treatment:
if self.discrete_treatment or self.treatment_featurizer:
self._d_t_in = d_t_in
return self
@ -995,12 +1016,44 @@ class DROrthoForest(BaseOrthoForest):
self._d_t_in = d_t_in
return self
# override only so that we can exclude treatment featurization verbiage in docstring
def const_marginal_effect(self, X):
"""Calculate the constant marginal CATE θ(·) conditional on a vector of features X.
Parameters
----------
X : array-like, shape (n, d_x)
Feature vector that captures heterogeneity.
Returns
-------
Theta : matrix , shape (n, d_t)
Constant marginal CATE of each treatment for each sample.
"""
X = check_array(X)
# Override to flatten output if T is flat
effects = super().const_marginal_effect(X=X)
return effects.reshape((-1,) + self._d_y + self._d_t)
const_marginal_effect.__doc__ = BaseOrthoForest.const_marginal_effect.__doc__
# override only so that we can exclude treatment featurization verbiage in docstring
def const_marginal_ate(self, X=None):
"""
Calculate the average constant marginal CATE :math:`E_X[\\theta(X)]`.
Parameters
----------
X: optional (m, d_x) matrix or None (Default=None)
Features for each sample.
Returns
-------
theta: (d_y, d_t) matrix
Average constant marginal CATE of each treatment on each outcome.
Note that when Y or T is a vector rather than a 2-dimensional array,
the corresponding singleton dimensions in the output will be collapsed
(e.g. if both are vectors, then the output of this method will be a scalar)
"""
return super().const_marginal_ate(X=X)
@staticmethod
def nuisance_estimator_generator(propensity_model, model_Y, random_state=None, second_stage=False):
@ -1162,6 +1215,10 @@ class BLBInference(Inference):
This is called after the estimator's fit.
"""
self._estimator = estimator
self._d_t = estimator._d_t
self._d_y = estimator._d_y
self.d_t = self._d_t[0] if self._d_t else 1
self.d_y = self._d_y[0] if self._d_y else 1
# Test whether the input estimator is supported
if not hasattr(self._estimator, "_predict"):
raise TypeError("Unsupported estimator of type {}.".format(self._estimator.__class__.__name__) +
@ -1300,6 +1357,81 @@ class BLBInference(Inference):
output_names=self._estimator.cate_output_names(),
treatment_names=self._estimator.cate_treatment_names())
def _marginal_effect_inference_helper(self, T, X):
if not self._estimator._original_treatment_featurizer:
return self.const_marginal_effect_inference(X)
X, T = check_input_arrays(X, T)
X, T = self._estimator._expand_treatments(X, T, transform=False)
feat_T = self._estimator.transformer.transform(T)
jac_T = self._estimator.transformer.jac(T)
params, cov = zip(*(self._predict_wrapper(X)))
params = np.array(params)
cov = np.array(cov)
eff_einsum_str = 'mf, mtf-> mt'
# conditionally expand jacobian dimensions to align with einsum str
jac_index = [slice(None), slice(None), slice(None)]
if ndim(T) == 1:
jac_index[1] = None
if ndim(feat_T) == 1:
jac_index[2] = None
# Calculate the effects
eff = np.einsum(eff_einsum_str, params, jac_T[tuple(jac_index)])
# Calculate the standard deviations for the effects
d_t_orig = T.shape[1:]
d_t_orig = d_t_orig[0] if d_t_orig else 1
self.d_t_orig = d_t_orig
output_shape = [X.shape[0]]
if T.shape[1:]:
output_shape.append(T.shape[1])
scales = np.zeros(shape=output_shape)
for i in range(d_t_orig):
# conditionally index multiple dimensions depending on shapes of T, Y and feat_T
jac_index = [slice(None)]
me_index = [slice(None)]
if T.shape[1:]:
jac_index.append(i)
me_index.append(i)
if feat_T.shape[1:]: # if featurized T is not a vector
jac_index.append(slice(None))
else:
jac_index.append(None)
jac = jac_T[tuple(jac_index)]
final = np.einsum('mj, mjk, mk -> m', jac, cov, jac)
scales[tuple(me_index)] = final
eff = eff.reshape((-1,) + self._d_y + T.shape[1:])
scales = scales.reshape((-1,) + self._d_y + T.shape[1:])
return eff, scales
def marginal_effect_inference(self, T, X):
if self._estimator._original_treatment_featurizer is None:
return self.const_marginal_effect_inference(X)
eff, scales = self._marginal_effect_inference_helper(T, X)
d_y = self._d_y[0] if self._d_y else 1
d_t = self._d_t[0] if self._d_t else 1
return NormalInferenceResults(d_t=self.d_t_orig, d_y=d_y,
pred=eff, pred_stderr=scales, mean_pred_stderr=None, inf_type='effect',
feature_names=self._estimator.cate_feature_names(),
output_names=self._estimator.cate_output_names(),
treatment_names=self._estimator.cate_treatment_names())
def marginal_effect_interval(self, T, X, *, alpha=0.05):
return self.marginal_effect_inference(T, X).conf_int(alpha=alpha)
def _predict_wrapper(self, X=None):
return self._estimator._predict(X, stderr=True)
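The two einsum contractions above are the heart of the featurized marginal-effect inference: the point estimate contracts the final-stage parameters with the Jacobian of $\psi$, and the variance is a per-sample quadratic form of that Jacobian with the parameter covariance. A standalone numpy sketch of the same contractions, with made-up shapes for $\psi(T) = [T, T^2]$:

```python
# Standalone numpy sketch of the contractions used above (illustrative shapes only):
#   params[m, f]   -> final-stage parameters theta_f(X_m), one per featurized treatment
#   jac_T[m, t, f] -> d psi_f / d T_t evaluated at T_m
#   marginal effect[m, t] = sum_f params[m, f] * jac_T[m, t, f]
import numpy as np

m = 4                                       # samples
T = np.random.normal(size=(m, 1))           # single raw treatment
params = np.random.normal(size=(m, 2))      # theta(X) for psi(T) = [T, T^2]
jac_T = np.stack([np.ones_like(T), 2 * T], axis=-1)       # shape (m, 1, 2)

marg_eff = np.einsum('mf, mtf-> mt', params, jac_T)       # shape (m, 1)
# for this psi the contraction is just theta_1 + 2 * T * theta_2
np.testing.assert_allclose(marg_eff, params[:, [0]] + 2 * T * params[:, [1]])

# per-sample variance of the marginal effect: jac @ Cov(theta) @ jac^T
cov = np.stack([np.eye(2)] * m)             # shape (m, 2, 2), placeholder covariance
var = np.einsum('mj, mjk, mk -> m', jac_T[:, 0, :], cov, jac_T[:, 0, :])
```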


@ -1988,7 +1988,11 @@ class StatsModels2SLS(_StatsModelsWrapper):
# check dimension of instruments is more than dimension of treatments
if Z.shape[1] < T.shape[1]:
raise AssertionError("The number of treatments couldn't be larger than the number of instruments!")
raise AssertionError("The number of treatments couldn't be larger than the number of instruments!" +
" If you are using a treatment featurizer, make sure the number of featurized" +
" treatments is less than or equal to the number of instruments. You can either" +
" featurize the instrument yourself, or consider using projection = True" +
" along with a flexible model_t_xwz.")
# weight X and y
weighted_Z = Z * np.sqrt(sample_weight).reshape(-1, 1)
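The expanded message points to two ways out when the featurized treatment has more columns than the instrument. A hypothetical sketch of the second option (projection with a flexible model_t_xwz), loosely mirroring the OrthoIV configuration used in the treatment-featurization tests added by this commit; the data here are illustrative:

```python
# Illustrative sketch: psi(T) = [T, T^2] has two columns but there is only one instrument,
# so project the featurized treatment on (X, W, Z) with a flexible model instead of
# solving the final 2SLS directly (which would otherwise hit the check above).
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import PolynomialFeatures
from econml.iv.dml import OrthoIV

n = 2000
X = np.random.normal(size=(n, 5))
Z = np.random.normal(size=(n, 1))
T = Z[:, 0] + np.random.normal(size=n)
Y = 0.5 * T**2 + 0.2 * X[:, 0] + np.random.normal(size=n)

est = OrthoIV(projection=True,
              model_t_xwz=RandomForestRegressor(random_state=1),
              treatment_featurizer=PolynomialFeatures(degree=2, include_bias=False))
est.fit(Y=Y, T=T, X=X, Z=Z)
est.const_marginal_effect(X[:3])   # one column per featurized treatment
```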


@ -40,6 +40,10 @@ def rand_sol(A, b):
class TestDML(unittest.TestCase):
def test_cate_api(self):
treatment_featurizations = [None]
self._test_cate_api(treatment_featurizations)
def _test_cate_api(self, treatment_featurizations):
"""Test that we correctly implement the CATE API."""
n_c = 20 # number of rows for continuous models
n_d = 30 # number of rows for discrete models
@ -63,284 +67,293 @@ class TestDML(unittest.TestCase):
for d_t in [2, 1, -1]:
for is_discrete in [True, False] if d_t <= 1 else [False]:
for d_y in [3, 1, -1]:
for d_x in [2, None]:
for d_w in [2, None]:
n = n_d if is_discrete else n_c
W, X, Y, T = [make_random(n, is_discrete, d)
for is_discrete, d in [(False, d_w),
(False, d_x),
(False, d_y),
(is_discrete, d_t)]]
for treatment_featurizer in treatment_featurizations:
for d_y in [3, 1, -1]:
for d_x in [2, None]:
for d_w in [2, None]:
n = n_d if is_discrete else n_c
W, X, Y, T = [make_random(n, is_discrete, d)
for is_discrete, d in [(False, d_w),
(False, d_x),
(False, d_y),
(is_discrete, d_t)]]
for featurizer, fit_cate_intercept in\
[(None, True),
(PolynomialFeatures(degree=2, include_bias=False), True),
(PolynomialFeatures(degree=2, include_bias=True), False)]:
for featurizer, fit_cate_intercept in\
[(None, True),
(PolynomialFeatures(degree=2, include_bias=False), True),
(PolynomialFeatures(degree=2, include_bias=True), False)]:
d_t_final = 2 if is_discrete else d_t
effect_shape = (n,) + ((d_y,) if d_y > 0 else ())
effect_summaryframe_shape = (n * (d_y if d_y > 0 else 1), 6)
marginal_effect_shape = ((n,) +
((d_y,) if d_y > 0 else ()) +
((d_t_final,) if d_t_final > 0 else ()))
marginal_effect_summaryframe_shape = (n * (d_y if d_y > 0 else 1) *
(d_t_final if d_t_final > 0 else 1), 6)
# since T isn't passed to const_marginal_effect, defaults to one row if X is None
const_marginal_effect_shape = ((n if d_x else 1,) +
((d_y,) if d_y > 0 else ()) +
((d_t_final,) if d_t_final > 0 else ()))
const_marginal_effect_summaryframe_shape = (
(n if d_x else 1) * (d_y if d_y > 0 else 1) *
(d_t_final if d_t_final > 0 else 1), 6)
fd_x = featurizer.fit_transform(X).shape[1:] if featurizer and d_x\
else ((d_x,) if d_x else (0,))
coef_shape = Y.shape[1:] + (T.shape[1:] if not is_discrete else (2,)) + fd_x
coef_summaryframe_shape = (
(d_y if d_y > 0 else 1) * (fd_x[0] if fd_x[0] >
0 else 1) * (d_t_final if d_t_final > 0 else 1), 6)
intercept_shape = Y.shape[1:] + (T.shape[1:] if not is_discrete else (2,))
intercept_summaryframe_shape = (
(d_y if d_y > 0 else 1) * (d_t_final if d_t_final > 0 else 1), 6)
model_t = LogisticRegression() if is_discrete else Lasso()
all_infs = [None, 'auto', BootstrapInference(2)]
for est, multi, infs in\
[(DML(model_y=Lasso(),
model_t=model_t,
model_final=Lasso(alpha=0.1, fit_intercept=False),
featurizer=featurizer,
fit_cate_intercept=fit_cate_intercept,
discrete_treatment=is_discrete),
True,
[None] +
([BootstrapInference(n_bootstrap_samples=20)] if not is_discrete else [])),
(DML(model_y=Lasso(),
model_t=model_t,
model_final=StatsModelsRLM(fit_intercept=False),
featurizer=featurizer,
fit_cate_intercept=fit_cate_intercept,
discrete_treatment=is_discrete),
True,
['auto']),
(LinearDML(model_y=Lasso(),
model_t='auto',
featurizer=featurizer,
fit_cate_intercept=fit_cate_intercept,
discrete_treatment=is_discrete),
True,
all_infs),
(SparseLinearDML(model_y=WeightedLasso(),
model_t=model_t,
featurizer=featurizer,
fit_cate_intercept=fit_cate_intercept,
discrete_treatment=is_discrete),
True,
[None, 'auto'] +
([BootstrapInference(n_bootstrap_samples=20)] if not is_discrete else [])),
(KernelDML(model_y=WeightedLasso(),
model_t=model_t,
fit_cate_intercept=fit_cate_intercept,
discrete_treatment=is_discrete),
False,
[None]),
(CausalForestDML(model_y=WeightedLasso(),
model_t=model_t,
featurizer=featurizer,
n_estimators=4,
n_jobs=1,
discrete_treatment=is_discrete),
True,
['auto', 'blb'])]:
if not (multi) and d_y > 1:
if is_discrete and treatment_featurizer:
continue
if X is None and isinstance(est, CausalForestDML):
continue
d_t_final = 2 if is_discrete else d_t
# ensure we can serialize the unfit estimator
pickle.dumps(est)
effect_shape = (n,) + ((d_y,) if d_y > 0 else ())
effect_summaryframe_shape = (n * (d_y if d_y > 0 else 1), 6)
marginal_effect_shape = ((n,) +
((d_y,) if d_y > 0 else ()) +
((d_t_final,) if d_t_final > 0 else ()))
marginal_effect_summaryframe_shape = (n * (d_y if d_y > 0 else 1) *
(d_t_final if d_t_final > 0 else 1), 6)
for inf in infs:
with self.subTest(d_w=d_w, d_x=d_x, d_y=d_y, d_t=d_t,
is_discrete=is_discrete, est=est, inf=inf):
# since T isn't passed to const_marginal_effect, defaults to one row if X is None
const_marginal_effect_shape = ((n if d_x else 1,) +
((d_y,) if d_y > 0 else ()) +
((d_t_final,) if d_t_final > 0 else ()))
const_marginal_effect_summaryframe_shape = (
(n if d_x else 1) * (d_y if d_y > 0 else 1) *
(d_t_final if d_t_final > 0 else 1), 6)
if X is None and (not fit_cate_intercept):
with pytest.raises(AttributeError):
est.fit(Y, T, X=X, W=W, inference=inf)
continue
fd_x = featurizer.fit_transform(X).shape[1:] if featurizer and d_x\
else ((d_x,) if d_x else (0,))
coef_shape = Y.shape[1:] + (T.shape[1:] if not is_discrete else (2,)) + fd_x
est.fit(Y, T, X=X, W=W, inference=inf)
coef_summaryframe_shape = (
(d_y if d_y > 0 else 1) * (fd_x[0] if fd_x[0] >
0 else 1) * (d_t_final if d_t_final > 0 else 1), 6)
intercept_shape = Y.shape[1:] + (T.shape[1:] if not is_discrete else (2,))
intercept_summaryframe_shape = (
(d_y if d_y > 0 else 1) * (d_t_final if d_t_final > 0 else 1), 6)
# ensure we can pickle the fit estimator
pickle.dumps(est)
model_t = LogisticRegression() if is_discrete else Lasso()
# make sure we can call the marginal_effect and effect methods
const_marg_eff = est.const_marginal_effect(X)
marg_eff = est.marginal_effect(T, X)
self.assertEqual(shape(marg_eff), marginal_effect_shape)
self.assertEqual(shape(const_marg_eff), const_marginal_effect_shape)
all_infs = [None, 'auto', BootstrapInference(2)]
np.testing.assert_allclose(
marg_eff if d_x else marg_eff[0:1], const_marg_eff)
for est, multi, infs in\
[(DML(model_y=Lasso(),
model_t=model_t,
model_final=Lasso(alpha=0.1, fit_intercept=False),
featurizer=featurizer,
fit_cate_intercept=fit_cate_intercept,
discrete_treatment=is_discrete,
treatment_featurizer=treatment_featurizer),
True,
[None] +
([BootstrapInference(n_bootstrap_samples=20)] if not is_discrete else [])),
(DML(model_y=Lasso(),
model_t=model_t,
model_final=StatsModelsRLM(fit_intercept=False),
featurizer=featurizer,
fit_cate_intercept=fit_cate_intercept,
discrete_treatment=is_discrete,
treatment_featurizer=treatment_featurizer),
True,
['auto']),
(LinearDML(model_y=Lasso(),
model_t='auto',
featurizer=featurizer,
fit_cate_intercept=fit_cate_intercept,
discrete_treatment=is_discrete,
treatment_featurizer=treatment_featurizer),
True,
all_infs),
(SparseLinearDML(model_y=WeightedLasso(),
model_t=model_t,
featurizer=featurizer,
fit_cate_intercept=fit_cate_intercept,
discrete_treatment=is_discrete,
treatment_featurizer=treatment_featurizer),
True,
[None, 'auto'] +
([BootstrapInference(n_bootstrap_samples=20)] if not is_discrete else [])),
(KernelDML(model_y=WeightedLasso(),
model_t=model_t,
fit_cate_intercept=fit_cate_intercept,
discrete_treatment=is_discrete,
treatment_featurizer=treatment_featurizer),
False,
[None]),
(CausalForestDML(model_y=WeightedLasso(),
model_t=model_t,
featurizer=featurizer,
n_estimators=4,
n_jobs=1,
discrete_treatment=is_discrete),
True,
['auto', 'blb'])]:
assert isinstance(est.score_, float)
for score_list in est.nuisance_scores_y:
for score in score_list:
assert isinstance(score, float)
for score_list in est.nuisance_scores_t:
for score in score_list:
assert isinstance(score, float)
if not (multi) and d_y > 1:
continue
T0 = np.full_like(T, 'a') if is_discrete else np.zeros_like(T)
eff = est.effect(X, T0=T0, T1=T)
self.assertEqual(shape(eff), effect_shape)
if X is None and isinstance(est, CausalForestDML):
continue
if ((not isinstance(est, KernelDML)) and
(not isinstance(est, CausalForestDML))):
self.assertEqual(shape(est.coef_), coef_shape)
if fit_cate_intercept:
self.assertEqual(shape(est.intercept_), intercept_shape)
else:
# ensure we can serialize the unfit estimator
pickle.dumps(est)
for inf in infs:
with self.subTest(d_w=d_w, d_x=d_x, d_y=d_y, d_t=d_t,
is_discrete=is_discrete, est=est, inf=inf):
if X is None and (not fit_cate_intercept):
with pytest.raises(AttributeError):
self.assertEqual(shape(est.intercept_), intercept_shape)
est.fit(Y, T, X=X, W=W, inference=inf)
continue
est.fit(Y, T, X=X, W=W, inference=inf)
# ensure we can pickle the fit estimator
pickle.dumps(est)
# make sure we can call the marginal_effect and effect methods
const_marg_eff = est.const_marginal_effect(X)
marg_eff = est.marginal_effect(T, X)
self.assertEqual(shape(marg_eff), marginal_effect_shape)
self.assertEqual(shape(const_marg_eff), const_marginal_effect_shape)
np.testing.assert_allclose(
marg_eff if d_x else marg_eff[0:1], const_marg_eff)
assert isinstance(est.score_, float)
for score_list in est.nuisance_scores_y:
for score in score_list:
assert isinstance(score, float)
for score_list in est.nuisance_scores_t:
for score in score_list:
assert isinstance(score, float)
T0 = np.full_like(T, 'a') if is_discrete else np.zeros_like(T)
eff = est.effect(X, T0=T0, T1=T)
self.assertEqual(shape(eff), effect_shape)
if inf is not None:
const_marg_eff_int = est.const_marginal_effect_interval(X)
marg_eff_int = est.marginal_effect_interval(T, X)
self.assertEqual(shape(marg_eff_int),
(2,) + marginal_effect_shape)
self.assertEqual(shape(const_marg_eff_int),
(2,) + const_marginal_effect_shape)
self.assertEqual(shape(est.effect_interval(X, T0=T0, T1=T)),
(2,) + effect_shape)
if ((not isinstance(est, KernelDML)) and
(not isinstance(est, CausalForestDML))):
self.assertEqual(shape(est.coef__interval()),
(2,) + coef_shape)
self.assertEqual(shape(est.coef_), coef_shape)
if fit_cate_intercept:
self.assertEqual(shape(est.intercept__interval()),
(2,) + intercept_shape)
self.assertEqual(shape(est.intercept_), intercept_shape)
else:
with pytest.raises(AttributeError):
self.assertEqual(shape(est.intercept_), intercept_shape)
if inf is not None:
const_marg_eff_int = est.const_marginal_effect_interval(X)
marg_eff_int = est.marginal_effect_interval(T, X)
self.assertEqual(shape(marg_eff_int),
(2,) + marginal_effect_shape)
self.assertEqual(shape(const_marg_eff_int),
(2,) + const_marginal_effect_shape)
self.assertEqual(shape(est.effect_interval(X, T0=T0, T1=T)),
(2,) + effect_shape)
if ((not isinstance(est, KernelDML)) and
(not isinstance(est, CausalForestDML))):
self.assertEqual(shape(est.coef__interval()),
(2,) + coef_shape)
if fit_cate_intercept:
self.assertEqual(shape(est.intercept__interval()),
(2,) + intercept_shape)
else:
with pytest.raises(AttributeError):
self.assertEqual(shape(est.intercept__interval()),
(2,) + intercept_shape)
const_marg_effect_inf = est.const_marginal_effect_inference(X)
T1 = np.full_like(T, 'b') if is_discrete else T
effect_inf = est.effect_inference(X, T0=T0, T1=T1)
marg_effect_inf = est.marginal_effect_inference(T, X)
# test const marginal inference
self.assertEqual(shape(const_marg_effect_inf.summary_frame()),
const_marginal_effect_summaryframe_shape)
self.assertEqual(shape(const_marg_effect_inf.point_estimate),
const_marginal_effect_shape)
self.assertEqual(shape(const_marg_effect_inf.stderr),
const_marginal_effect_shape)
self.assertEqual(shape(const_marg_effect_inf.var),
const_marginal_effect_shape)
self.assertEqual(shape(const_marg_effect_inf.pvalue()),
const_marginal_effect_shape)
self.assertEqual(shape(const_marg_effect_inf.zstat()),
const_marginal_effect_shape)
self.assertEqual(shape(const_marg_effect_inf.conf_int()),
(2,) + const_marginal_effect_shape)
np.testing.assert_array_almost_equal(
const_marg_effect_inf.conf_int()[0],
const_marg_eff_int[0], decimal=5)
const_marg_effect_inf.population_summary()._repr_html_()
const_marg_effect_inf = est.const_marginal_effect_inference(X)
T1 = np.full_like(T, 'b') if is_discrete else T
effect_inf = est.effect_inference(X, T0=T0, T1=T1)
marg_effect_inf = est.marginal_effect_inference(T, X)
# test const marginal inference
self.assertEqual(shape(const_marg_effect_inf.summary_frame()),
const_marginal_effect_summaryframe_shape)
self.assertEqual(shape(const_marg_effect_inf.point_estimate),
const_marginal_effect_shape)
self.assertEqual(shape(const_marg_effect_inf.stderr),
const_marginal_effect_shape)
self.assertEqual(shape(const_marg_effect_inf.var),
const_marginal_effect_shape)
self.assertEqual(shape(const_marg_effect_inf.pvalue()),
const_marginal_effect_shape)
self.assertEqual(shape(const_marg_effect_inf.zstat()),
const_marginal_effect_shape)
self.assertEqual(shape(const_marg_effect_inf.conf_int()),
(2,) + const_marginal_effect_shape)
np.testing.assert_array_almost_equal(
const_marg_effect_inf.conf_int()[0],
const_marg_eff_int[0], decimal=5)
const_marg_effect_inf.population_summary()._repr_html_()
# test effect inference
self.assertEqual(shape(effect_inf.summary_frame()),
effect_summaryframe_shape)
self.assertEqual(shape(effect_inf.point_estimate),
effect_shape)
self.assertEqual(shape(effect_inf.stderr),
effect_shape)
self.assertEqual(shape(effect_inf.var),
effect_shape)
self.assertEqual(shape(effect_inf.pvalue()),
effect_shape)
self.assertEqual(shape(effect_inf.zstat()),
effect_shape)
self.assertEqual(shape(effect_inf.conf_int()),
(2,) + effect_shape)
np.testing.assert_array_almost_equal(
effect_inf.conf_int()[0],
est.effect_interval(X, T0=T0, T1=T1)[0], decimal=5)
effect_inf.population_summary()._repr_html_()
# test effect inference
self.assertEqual(shape(effect_inf.summary_frame()),
effect_summaryframe_shape)
self.assertEqual(shape(effect_inf.point_estimate),
effect_shape)
self.assertEqual(shape(effect_inf.stderr),
effect_shape)
self.assertEqual(shape(effect_inf.var),
effect_shape)
self.assertEqual(shape(effect_inf.pvalue()),
effect_shape)
self.assertEqual(shape(effect_inf.zstat()),
effect_shape)
self.assertEqual(shape(effect_inf.conf_int()),
(2,) + effect_shape)
np.testing.assert_array_almost_equal(
effect_inf.conf_int()[0],
est.effect_interval(X, T0=T0, T1=T1)[0], decimal=5)
effect_inf.population_summary()._repr_html_()
# test marginal effect inference
self.assertEqual(shape(marg_effect_inf.summary_frame()),
marginal_effect_summaryframe_shape)
self.assertEqual(shape(marg_effect_inf.point_estimate),
marginal_effect_shape)
self.assertEqual(shape(marg_effect_inf.stderr),
marginal_effect_shape)
self.assertEqual(shape(marg_effect_inf.var),
marginal_effect_shape)
self.assertEqual(shape(marg_effect_inf.pvalue()),
marginal_effect_shape)
self.assertEqual(shape(marg_effect_inf.zstat()),
marginal_effect_shape)
self.assertEqual(shape(marg_effect_inf.conf_int()),
(2,) + marginal_effect_shape)
np.testing.assert_array_almost_equal(
marg_effect_inf.conf_int()[0], marg_eff_int[0], decimal=5)
marg_effect_inf.population_summary()._repr_html_()
# test marginal effect inference
self.assertEqual(shape(marg_effect_inf.summary_frame()),
marginal_effect_summaryframe_shape)
self.assertEqual(shape(marg_effect_inf.point_estimate),
marginal_effect_shape)
self.assertEqual(shape(marg_effect_inf.stderr),
marginal_effect_shape)
self.assertEqual(shape(marg_effect_inf.var),
marginal_effect_shape)
self.assertEqual(shape(marg_effect_inf.pvalue()),
marginal_effect_shape)
self.assertEqual(shape(marg_effect_inf.zstat()),
marginal_effect_shape)
self.assertEqual(shape(marg_effect_inf.conf_int()),
(2,) + marginal_effect_shape)
np.testing.assert_array_almost_equal(
marg_effect_inf.conf_int()[0], marg_eff_int[0], decimal=5)
marg_effect_inf.population_summary()._repr_html_()
# test coef__inference and intercept__inference
if ((not isinstance(est, KernelDML)) and
(not isinstance(est, CausalForestDML))):
if X is not None:
self.assertEqual(
shape(est.coef__inference().summary_frame()),
coef_summaryframe_shape)
np.testing.assert_array_almost_equal(
est.coef__inference().conf_int()
[0], est.coef__interval()[0], decimal=5)
# test coef__inference and intercept__inference
if ((not isinstance(est, KernelDML)) and
(not isinstance(est, CausalForestDML))):
if X is not None:
self.assertEqual(
shape(est.coef__inference().summary_frame()),
coef_summaryframe_shape)
np.testing.assert_array_almost_equal(
est.coef__inference().conf_int()
[0], est.coef__interval()[0], decimal=5)
if fit_cate_intercept:
cm = ExitStack()
# ExitStack can be used as a "do nothing" ContextManager
else:
cm = pytest.raises(AttributeError)
with cm:
self.assertEqual(shape(est.intercept__inference().
summary_frame()),
intercept_summaryframe_shape)
np.testing.assert_array_almost_equal(
est.intercept__inference().conf_int()
[0], est.intercept__interval()[0], decimal=5)
if fit_cate_intercept:
cm = ExitStack()
# ExitStack can be used as a "do nothing" ContextManager
else:
cm = pytest.raises(AttributeError)
with cm:
self.assertEqual(shape(est.intercept__inference().
summary_frame()),
intercept_summaryframe_shape)
np.testing.assert_array_almost_equal(
est.intercept__inference().conf_int()
[0], est.intercept__interval()[0], decimal=5)
est.summary()
est.summary()
est.score(Y, T, X, W)
est.score(Y, T, X, W)
if isinstance(est, CausalForestDML):
np.testing.assert_array_equal(est.feature_importances_.shape,
((d_y,) if d_y > 0 else ()) + fd_x)
if isinstance(est, CausalForestDML):
np.testing.assert_array_equal(est.feature_importances_.shape,
((d_y,) if d_y > 0 else ()) + fd_x)
# make sure we can call effect with implied scalar treatments,
# no matter the dimensions of T, and also that we warn when there
# are multiple treatments
if d_t > 1:
cm = self.assertWarns(Warning)
else:
# ExitStack can be used as a "do nothing" ContextManager
cm = ExitStack()
with cm:
effect_shape2 = (n if d_x else 1,) + ((d_y,) if d_y > 0 else ())
eff = est.effect(X) if not is_discrete else est.effect(
X, T0='a', T1='b')
self.assertEqual(shape(eff), effect_shape2)
# make sure we can call effect with implied scalar treatments,
# no matter the dimensions of T, and also that we warn when there
# are multiple treatments
if d_t > 1:
cm = self.assertWarns(Warning)
else:
# ExitStack can be used as a "do nothing" ContextManager
cm = ExitStack()
with cm:
effect_shape2 = (n if d_x else 1,) + ((d_y,) if d_y > 0 else ())
eff = est.effect(X) if not is_discrete else est.effect(
X, T0='a', T1='b')
self.assertEqual(shape(eff), effect_shape2)
def test_cate_api_nonparam(self):
"""Test that we correctly implement the CATE API."""
@ -1183,3 +1196,38 @@ class TestDML(unittest.TestCase):
est = LinearDML(model_y=LassoCV(cv=5), model_t=LassoCV(cv=5), cv=GroupKFold(2))
with pytest.raises(Exception):
est.fit(y, t, groups=groups)
def test_treatment_names(self):
Y = np.random.normal(size=(100, 1))
T = np.random.binomial(n=1, p=0.5, size=(100, 1))
X = np.random.normal(size=(100, 3))
Ts = [
T,
pd.DataFrame(T, columns=[0])
]
init_args_list = [
{'discrete_treatment': True},
{'treatment_featurizer': PolynomialFeatures(degree=2, include_bias=False)},
{'treatment_featurizer': FunctionTransformer(lambda x: np.hstack([x, np.sqrt(x)]))},
]
for T in Ts:
for init_args in init_args_list:
est = LinearDML(**init_args).fit(Y=Y, T=T, X=X)
t_name = '0' if isinstance(T, pd.DataFrame) else 'T0' # default treatment name
postfixes = ['_1'] if 'discrete_treatment' in init_args else ['', '^2'] # transformer postfixes
# Try default, integer, and new user-passed treatment name
for new_treatment_name in [None, [999], ['NewTreatmentName']]:
# FunctionTransformers are agnostic to passed treatment names
if isinstance(init_args.get('treatment_featurizer'), FunctionTransformer):
assert (est.cate_treatment_names(new_treatment_name) == ['feat(T)0', 'feat(T)1'])
# Expected treatment names are the sums of user-passed prefixes and transformer-specific postfixes
else:
expected_prefix = str(new_treatment_name[0]) if new_treatment_name is not None else t_name
assert (est.cate_treatment_names(new_treatment_name) == [
expected_prefix + postfix for postfix in postfixes])
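A short sketch of the naming convention the assertions above encode, with the expected outputs shown as comments. The data mirror the test; the treatment name 'price' is an illustrative stand-in for a user-passed name:

```python
# Sketch of the naming behavior asserted above (expected values shown as comments).
import numpy as np
from sklearn.preprocessing import FunctionTransformer, PolynomialFeatures
from econml.dml import LinearDML

Y = np.random.normal(size=(100, 1))
T = np.random.binomial(n=1, p=0.5, size=(100, 1))
X = np.random.normal(size=(100, 3))

est = LinearDML(treatment_featurizer=PolynomialFeatures(degree=2, include_bias=False))
est.fit(Y=Y, T=T, X=X)
est.cate_treatment_names()            # ['T0', 'T0^2'] with the default treatment name
est.cate_treatment_names(['price'])   # ['price', 'price^2'] with a user-passed name

est = LinearDML(treatment_featurizer=FunctionTransformer(lambda x: np.hstack([x, np.sqrt(x)])))
est.fit(Y=Y, T=T, X=X)
est.cate_treatment_names()            # ['feat(T)0', 'feat(T)1']; FunctionTransformer ignores passed names
```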


@ -220,67 +220,70 @@ class TestOrthoForest(unittest.TestCase):
lb, _ = est.marginal_effect_inference(1, X[:3]).conf_int()
assert lb.shape == (3, 1, 2), "Marginal Effect interval dimension incorrect"
from sklearn.preprocessing import FunctionTransformer
from sklearn.dummy import DummyClassifier, DummyRegressor
for global_residualization in [False, True]:
est = DMLOrthoForest(n_trees=10, model_Y=DummyRegressor(strategy='mean'),
model_T=DummyRegressor(strategy='mean'),
global_residualization=global_residualization,
n_jobs=1)
est.fit(y.reshape(-1, 1), T.reshape(-1, 1), X=X)
assert est.const_marginal_effect(X[:3]).shape == (3, 1, 1), "Const Marginal Effect dimension incorrect"
assert est.marginal_effect(1, X[:3]).shape == (3, 1, 1), "Marginal Effect dimension incorrect"
assert est.effect(X[:3]).shape == (3, 1), "Effect dimension incorrect"
assert est.effect(X[:3], T0=0, T1=2).shape == (3, 1), "Effect dimension incorrect"
assert est.effect(X[:3], T0=1, T1=2).shape == (3, 1), "Effect dimension incorrect"
lb, _ = est.effect_interval(X[:3], T0=1, T1=2)
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
lb, _ = est.effect_inference(X[:3], T0=1, T1=2).conf_int()
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_interval(X[:3])
assert lb.shape == (3, 1, 1), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_inference(X[:3]).conf_int()
assert lb.shape == (3, 1, 1), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_interval(1, X[:3])
assert lb.shape == (3, 1, 1), "Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_inference(1, X[:3]).conf_int()
assert lb.shape == (3, 1, 1), "Marginal Effect interval dimension incorrect"
est.fit(y.reshape(-1, 1), T, X=X)
assert est.const_marginal_effect(X[:3]).shape == (3, 1), "Const Marginal Effect dimension incorrect"
assert est.marginal_effect(1, X[:3]).shape == (3, 1), "Marginal Effect dimension incorrect"
assert est.effect(X[:3]).shape == (3, 1), "Effect dimension incorrect"
assert est.effect(X[:3], T0=0, T1=2).shape == (3, 1), "Effect dimension incorrect"
assert est.effect(X[:3], T0=1, T1=2).shape == (3, 1), "Effect dimension incorrect"
lb, _ = est.effect_interval(X[:3], T0=1, T1=2)
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
lb, _ = est.effect_inference(X[:3], T0=1, T1=2).conf_int()
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_interval(X[:3])
print(lb.shape)
assert lb.shape == (3, 1), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_inference(X[:3]).conf_int()
assert lb.shape == (3, 1), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_interval(1, X[:3])
assert lb.shape == (3, 1), "Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_inference(1, X[:3]).conf_int()
assert lb.shape == (3, 1), "Marginal Effect interval dimension incorrect"
est.fit(y, T, X=X)
assert est.const_marginal_effect(X[:3]).shape == (3,), "Const Marginal Effect dimension incorrect"
assert est.marginal_effect(1, X[:3]).shape == (3,), "Marginal Effect dimension incorrect"
assert est.effect(X[:3]).shape == (3,), "Effect dimension incorrect"
assert est.effect(X[:3], T0=0, T1=2).shape == (3,), "Effect dimension incorrect"
assert est.effect(X[:3], T0=1, T1=2).shape == (3,), "Effect dimension incorrect"
lb, _ = est.effect_interval(X[:3], T0=1, T1=2)
assert lb.shape == (3,), "Effect interval dimension incorrect"
lb, _ = est.effect_inference(X[:3], T0=1, T1=2).conf_int()
assert lb.shape == (3,), "Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_interval(X[:3])
assert lb.shape == (3,), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_inference(X[:3]).conf_int()
assert lb.shape == (3,), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_interval(1, X[:3])
assert lb.shape == (3,), "Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_inference(1, X[:3]).conf_int()
assert lb.shape == (3,), "Marginal Effect interval dimension incorrect"
for treatment_featurization in [None, FunctionTransformer()]:
est = DMLOrthoForest(n_trees=10, model_Y=DummyRegressor(strategy='mean'),
model_T=DummyRegressor(strategy='mean'),
global_residualization=global_residualization,
treatment_featurizer=treatment_featurization,
n_jobs=1)
est.fit(y.reshape(-1, 1), T.reshape(-1, 1), X=X)
assert est.const_marginal_effect(X[:3]).shape == (3, 1, 1), "Const Marginal Effect dimension incorrect"
assert est.marginal_effect(1, X[:3]).shape == (3, 1, 1), "Marginal Effect dimension incorrect"
assert est.effect(X[:3]).shape == (3, 1), "Effect dimension incorrect"
assert est.effect(X[:3], T0=0, T1=2).shape == (3, 1), "Effect dimension incorrect"
assert est.effect(X[:3], T0=1, T1=2).shape == (3, 1), "Effect dimension incorrect"
lb, _ = est.effect_interval(X[:3], T0=1, T1=2)
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
lb, _ = est.effect_inference(X[:3], T0=1, T1=2).conf_int()
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_interval(X[:3])
assert lb.shape == (3, 1, 1), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_inference(X[:3]).conf_int()
assert lb.shape == (3, 1, 1), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_interval(1, X[:3])
assert lb.shape == (3, 1, 1), "Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_inference(1, X[:3]).conf_int()
assert lb.shape == (3, 1, 1), "Marginal Effect interval dimension incorrect"
est.fit(y.reshape(-1, 1), T, X=X)
assert est.const_marginal_effect(X[:3]).shape == (3, 1), "Const Marginal Effect dimension incorrect"
assert est.marginal_effect(1, X[:3]).shape == (3, 1), "Marginal Effect dimension incorrect"
assert est.effect(X[:3]).shape == (3, 1), "Effect dimension incorrect"
assert est.effect(X[:3], T0=0, T1=2).shape == (3, 1), "Effect dimension incorrect"
assert est.effect(X[:3], T0=1, T1=2).shape == (3, 1), "Effect dimension incorrect"
lb, _ = est.effect_interval(X[:3], T0=1, T1=2)
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
lb, _ = est.effect_inference(X[:3], T0=1, T1=2).conf_int()
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_interval(X[:3])
print(lb.shape)
assert lb.shape == (3, 1), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_inference(X[:3]).conf_int()
assert lb.shape == (3, 1), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_interval(1, X[:3])
assert lb.shape == (3, 1), "Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_inference(1, X[:3]).conf_int()
assert lb.shape == (3, 1), "Marginal Effect interval dimension incorrect"
est.fit(y, T, X=X)
assert est.const_marginal_effect(X[:3]).shape == (3,), "Const Marginal Effect dimension incorrect"
assert est.marginal_effect(1, X[:3]).shape == (3,), "Marginal Effect dimension incorrect"
assert est.effect(X[:3]).shape == (3,), "Effect dimension incorrect"
assert est.effect(X[:3], T0=0, T1=2).shape == (3,), "Effect dimension incorrect"
assert est.effect(X[:3], T0=1, T1=2).shape == (3,), "Effect dimension incorrect"
lb, _ = est.effect_interval(X[:3], T0=1, T1=2)
assert lb.shape == (3,), "Effect interval dimension incorrect"
lb, _ = est.effect_inference(X[:3], T0=1, T1=2).conf_int()
assert lb.shape == (3,), "Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_interval(X[:3])
assert lb.shape == (3,), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_inference(X[:3]).conf_int()
assert lb.shape == (3,), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_interval(1, X[:3])
assert lb.shape == (3,), "Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_inference(1, X[:3]).conf_int()
assert lb.shape == (3,), "Marginal Effect interval dimension incorrect"
def test_nuisance_model_has_weights(self):
"""Test whether the correct exception is being raised if model_final doesn't have weights."""


@ -170,8 +170,8 @@ class TestOrthoLearner(unittest.TestCase):
X = np.random.normal(size=(10000, 3))
sigma = 0.1
y = X[:, 0] + X[:, 1] + np.random.normal(0, sigma, size=(10000,))
est = OrthoLearner(cv=2, discrete_treatment=False, discrete_instrument=False, categories='auto',
random_state=None)
est = OrthoLearner(cv=2, discrete_treatment=False, treatment_featurizer=None,
discrete_instrument=False, categories='auto', random_state=None)
est.fit(y, X[:, 0], W=X[:, 1:])
np.testing.assert_almost_equal(est.const_marginal_effect(), 1, decimal=3)
np.testing.assert_array_almost_equal(est.effect(), np.ones(1), decimal=3)
@ -187,7 +187,7 @@ class TestOrthoLearner(unittest.TestCase):
X = np.random.normal(size=(10000, 3))
sigma = 0.1
y = X[:, 0] + X[:, 1] + np.random.normal(0, sigma, size=(10000,))
est = OrthoLearner(cv=2, discrete_treatment=False, discrete_instrument=False,
est = OrthoLearner(cv=2, discrete_treatment=False, treatment_featurizer=None, discrete_instrument=False,
categories='auto', random_state=None)
# test non-array inputs
est.fit(list(y), list(X[:, 0]), X=None, W=X[:, 1:])
@ -204,7 +204,7 @@ class TestOrthoLearner(unittest.TestCase):
sigma = 0.1
y = X[:, 0] + X[:, 1] + np.random.normal(0, sigma, size=(10000,))
est = OrthoLearner(cv=KFold(n_splits=3),
discrete_treatment=False, discrete_instrument=False,
discrete_treatment=False, treatment_featurizer=None, discrete_instrument=False,
categories='auto', random_state=None)
est.fit(y, X[:, 0], X=None, W=X[:, 1:])
np.testing.assert_almost_equal(est.const_marginal_effect(), 1, decimal=3)
@ -220,7 +220,7 @@ class TestOrthoLearner(unittest.TestCase):
sigma = 0.1
y = X[:, 0] + X[:, 1] + np.random.normal(0, sigma, size=(10000,))
folds = [(np.arange(X.shape[0] // 2), np.arange(X.shape[0] // 2, X.shape[0]))]
est = OrthoLearner(cv=folds, discrete_treatment=False,
est = OrthoLearner(cv=folds, discrete_treatment=False, treatment_featurizer=None,
discrete_instrument=False, categories='auto', random_state=None)
est.fit(y, X[:, 0], X=None, W=X[:, 1:])
np.testing.assert_almost_equal(est.const_marginal_effect(), 1, decimal=2)
@ -268,7 +268,7 @@ class TestOrthoLearner(unittest.TestCase):
X = np.random.normal(size=(10000, 3))
sigma = 0.1
y = X[:, 0] + X[:, 1] + np.random.normal(0, sigma, size=(10000,))
est = OrthoLearner(cv=2, discrete_treatment=False, discrete_instrument=False,
est = OrthoLearner(cv=2, discrete_treatment=False, treatment_featurizer=None, discrete_instrument=False,
categories='auto', random_state=None)
est.fit(y, X[:, 0], W=X[:, 1:])
np.testing.assert_almost_equal(est.const_marginal_effect(), 1, decimal=3)
@ -318,7 +318,7 @@ class TestOrthoLearner(unittest.TestCase):
X = np.random.normal(size=(10000, 3))
sigma = 0.1
y = X[:, 0] + X[:, 1] + np.random.normal(0, sigma, size=(10000,))
est = OrthoLearner(cv=2, discrete_treatment=False, discrete_instrument=False,
est = OrthoLearner(cv=2, discrete_treatment=False, treatment_featurizer=None, discrete_instrument=False,
categories='auto', random_state=None)
est.fit(y, X[:, 0], W=X[:, 1:])
np.testing.assert_almost_equal(est.const_marginal_effect(), 1, decimal=3)
@ -380,7 +380,7 @@ class TestOrthoLearner(unittest.TestCase):
T = np.random.binomial(1, scipy.special.expit(X[:, 0]))
sigma = 0.01
y = T + X[:, 0] + np.random.normal(0, sigma, size=(10000,))
est = OrthoLearner(cv=2, discrete_treatment=True, discrete_instrument=False,
est = OrthoLearner(cv=2, discrete_treatment=True, treatment_featurizer=None, discrete_instrument=False,
categories='auto', random_state=None)
est.fit(y, T, W=X)
np.testing.assert_almost_equal(est.const_marginal_effect(), 1, decimal=3)


@ -0,0 +1,580 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import pytest
import unittest
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.ensemble import RandomForestRegressor
from joblib import Parallel, delayed
from econml._ortho_learner import _OrthoLearner
from econml.dml import LinearDML, SparseLinearDML, KernelDML, CausalForestDML, NonParamDML
from econml.iv.dml import OrthoIV, DMLIV, NonParamDMLIV
from econml.iv.dr import DRIV, LinearDRIV, SparseLinearDRIV, ForestDRIV
from econml.orf import DMLOrthoForest
from sklearn.preprocessing import OneHotEncoder, FunctionTransformer
from econml.sklearn_extensions.linear_model import StatsModelsLinearRegression
from econml.utilities import jacify_featurizer
from econml.iv.sieve import DPolynomialFeatures
from econml.tests.test_dml import TestDML
from copy import deepcopy
class DGP():
def __init__(self,
n=1000,
d_t=1,
d_y=1,
d_x=5,
d_z=None,
squeeze_T=False,
squeeze_Y=False,
nuisance_Y=None,
nuisance_T=None,
nuisance_TZ=None,
theta=None,
y_of_t=None,
x_eps=1,
y_eps=1,
t_eps=1
):
self.n = n
self.d_t = d_t
self.d_y = d_y
self.d_x = d_x
self.d_z = d_z
self.squeeze_T = squeeze_T
self.squeeze_Y = squeeze_Y
self.nuisance_Y = nuisance_Y if nuisance_Y else lambda X: 0
self.nuisance_T = nuisance_T if nuisance_T else lambda X: 0
self.nuisance_TZ = nuisance_TZ if nuisance_TZ else lambda X: 0
self.theta = theta if theta else lambda X: 1
self.y_of_t = y_of_t if y_of_t else lambda X: 0
self.x_eps = x_eps
self.y_eps = y_eps
self.t_eps = t_eps
def gen_Y(self):
noise = np.random.normal(size=(self.n, self.d_y), scale=self.y_eps)
self.Y = self.theta(self.X) * self.y_of_t(self.T) + self.nuisance_Y(self.X) + noise
return self.Y
def gen_X(self):
self.X = np.random.normal(size=(self.n, self.d_x), scale=self.x_eps)
return self.X
def gen_T(self):
noise = np.random.normal(size=(self.n, self.d_t), scale=self.t_eps)
self.T_noise = noise
self.T = noise + self.nuisance_T(self.X) + self.nuisance_TZ(self.Z)
return self.T
def gen_Z(self):
if self.d_z:
Z_noise = np.random.normal(size=(self.n, self.d_z), loc=3, scale=3)
self.Z = Z_noise
return self.Z
else:
self.Z = None
return self.Z
def gen_data(self):
X = self.gen_X()
Z = self.gen_Z()
T = self.gen_T()
Y = self.gen_Y()
if self.squeeze_T:
T = T.squeeze()
if self.squeeze_Y:
Y = Y.squeeze()
data_dict = {
'Y': Y,
'T': T,
'X': X
}
if self.d_z:
data_dict['Z'] = Z
return data_dict
def actual_effect(y_of_t, T0, T1):
return y_of_t(T1) - y_of_t(T0)
def nuisance_T(X):
return -0.3 * X[:, [1]]
def nuisance_Y(X):
return 0.2 * X[:, [0]]
# identity featurization effect functions
def identity_y_of_t(T):
return T
def identity_actual_marginal(T):
return np.ones(shape=(T.shape))
def identity_actual_cme():
return 1
identity_treatment_featurizer = FunctionTransformer()
# polynomial featurization effect functions
def poly_y_of_t(T):
return 0.5 * T**2
def poly_actual_marginal(t):
return t
def poly_actual_cme():
return np.array([0, 0.5])
def poly_func_transform(x):
x = x.reshape(-1, 1)
return np.hstack([x, x**2])
polynomial_treatment_featurizer = FunctionTransformer(func=poly_func_transform)
# 1d polynomial featurization functions
def poly_1d_actual_cme():
return 0.5
def poly_1d_func_transform(x):
return x**2
polynomial_1d_treatment_featurizer = FunctionTransformer(func=poly_1d_func_transform)
# 2d-to-1d featurization functions
def sum_y_of_t(T):
return 0.5 * T.sum(axis=1, keepdims=True)
def sum_actual_cme():
return 0.5
def sum_actual_marginal(t):
return np.ones(shape=t.shape) * 0.5
def sum_func_transform(x):
return x.sum(axis=1, keepdims=True)
sum_treatment_featurizer = FunctionTransformer(func=sum_func_transform)
# 2d-to-1d vector featurization functions
def sum_squeeze_func_transform(x):
return x.sum(axis=1, keepdims=False)
sum_squeeze_treatment_featurizer = FunctionTransformer(func=sum_squeeze_func_transform)
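A brief worked check of the expectations encoded above for the polynomial configuration: with $Y = 0.5\,T^2$ and $\psi(T) = [T, T^2]$, the coefficients on $\psi$ are $\theta = [0, 0.5]$ (poly_actual_cme), and by the chain rule the marginal effect with respect to the raw treatment is $\theta \cdot \partial\psi/\partial T = 0\cdot 1 + 0.5\cdot 2T = T$ (poly_actual_marginal):

```python
# Worked check (illustrative): chain rule for psi(T) = [T, T^2] with theta = [0, 0.5].
import numpy as np

t = np.linspace(-2.0, 2.0, 5).reshape(-1, 1)
theta = np.array([0.0, 0.5])
jac = np.hstack([np.ones_like(t), 2 * t])            # d[T, T^2]/dT
np.testing.assert_allclose(jac @ theta, t.ravel())   # marginal effect equals T itself
```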
@pytest.mark.treatment_featurization
class TestTreatmentFeaturization(unittest.TestCase):
def test_featurization(self):
identity_config = {
'DGP_params': {
'n': 2000,
'd_t': 1,
'd_y': 1,
'd_x': 5,
'squeeze_T': False,
'squeeze_Y': False,
'nuisance_Y': nuisance_Y,
'nuisance_T': nuisance_T,
'theta': None,
'y_of_t': identity_y_of_t,
'x_eps': 1,
'y_eps': 1,
't_eps': 1
},
'treatment_featurizer': identity_treatment_featurizer,
'actual_marginal': identity_actual_marginal,
'actual_cme': identity_actual_cme,
'squeeze_Ts': [False, True],
'squeeze_Ys': [False, True],
'est_dicts': [
{'class': LinearDML, 'init_args': {}},
{'class': CausalForestDML, 'init_args': {}},
{'class': SparseLinearDML, 'init_args': {}},
{'class': KernelDML, 'init_args': {}},
]
}
poly_config = {
'DGP_params': {
'n': 2000,
'd_t': 1,
'd_y': 1,
'd_x': 5,
'squeeze_T': False,
'squeeze_Y': False,
'nuisance_Y': nuisance_Y,
'nuisance_T': nuisance_T,
'theta': None,
'y_of_t': poly_y_of_t,
'x_eps': 1,
'y_eps': 1,
't_eps': 1
},
'treatment_featurizer': polynomial_treatment_featurizer,
'actual_marginal': poly_actual_marginal,
'actual_cme': poly_actual_cme,
'squeeze_Ts': [False, True],
'squeeze_Ys': [False, True],
'est_dicts': [
{'class': LinearDML, 'init_args': {}},
{'class': CausalForestDML, 'init_args': {}},
{'class': SparseLinearDML, 'init_args': {}},
{'class': KernelDML, 'init_args': {}},
]
}
poly_config_scikit = deepcopy(poly_config)
poly_config_scikit['treatment_featurizer'] = PolynomialFeatures(degree=2, include_bias=False)
poly_config_scikit['squeeze_Ts'] = [False]
poly_IV_config = deepcopy(poly_config)
poly_IV_config['DGP_params']['d_z'] = 1
poly_IV_config['DGP_params']['nuisance_TZ'] = lambda Z: Z
poly_IV_config['est_dicts'] = [
{'class': OrthoIV, 'init_args': {
'model_t_xwz': RandomForestRegressor(random_state=1), 'projection': True}},
{'class': DMLIV, 'init_args': {'model_t_xwz': RandomForestRegressor(random_state=1)}},
]
poly_1d_config = deepcopy(poly_config)
poly_1d_config['treatment_featurizer'] = polynomial_1d_treatment_featurizer
poly_1d_config['actual_cme'] = poly_1d_actual_cme
poly_1d_config['est_dicts'].append({
'class': NonParamDML,
'init_args': {
'model_y': LinearRegression(),
'model_t': LinearRegression(),
'model_final': StatsModelsLinearRegression()}})
poly_1d_IV_config = deepcopy(poly_IV_config)
poly_1d_IV_config['treatment_featurizer'] = polynomial_1d_treatment_featurizer
poly_1d_IV_config['actual_cme'] = poly_1d_actual_cme
poly_1d_IV_config['est_dicts'] = [
{'class': NonParamDMLIV, 'init_args': {'model_final': StatsModelsLinearRegression()}},
{'class': DRIV, 'init_args': {'fit_cate_intercept': True}},
{'class': LinearDRIV, 'init_args': {}},
{'class': SparseLinearDRIV, 'init_args': {}},
{'class': ForestDRIV, 'init_args': {}},
]
sum_IV_config = {
'DGP_params': {
'n': 2000,
'd_t': 2,
'd_y': 1,
'd_x': 5,
'd_z': 1,
'squeeze_T': False,
'squeeze_Y': False,
'nuisance_Y': nuisance_Y,
'nuisance_T': nuisance_T,
'nuisance_TZ': lambda Z: Z,
'theta': None,
'y_of_t': sum_y_of_t,
'x_eps': 1,
'y_eps': 1,
't_eps': 1
},
'treatment_featurizer': sum_treatment_featurizer,
'actual_marginal': sum_actual_marginal,
'actual_cme': sum_actual_cme,
'squeeze_Ts': [False],
'squeeze_Ys': [False, True],
'est_dicts': [
{'class': NonParamDMLIV, 'init_args': {'model_final': StatsModelsLinearRegression()}},
{'class': DRIV, 'init_args': {'fit_cate_intercept': True}},
{'class': LinearDRIV, 'init_args': {}},
{'class': SparseLinearDRIV, 'init_args': {}},
{'class': ForestDRIV, 'init_args': {}},
]
}
sum_squeeze_IV_config = deepcopy(sum_IV_config)
sum_squeeze_IV_config['treatment_featurizer'] = sum_squeeze_treatment_featurizer
sum_config = deepcopy(sum_IV_config)
sum_config['DGP_params']['d_z'] = None
sum_config['DGP_params']['nuisance_TZ'] = None
sum_config['est_dicts'] = deepcopy(poly_1d_config['est_dicts'])
sum_squeeze_config = deepcopy(sum_config)
sum_squeeze_config['treatment_featurizer'] = sum_squeeze_treatment_featurizer
configs = [
identity_config,
poly_config,
poly_config_scikit,
poly_IV_config,
poly_1d_config,
poly_1d_IV_config,
sum_IV_config,
sum_squeeze_IV_config,
sum_config,
sum_squeeze_config
]
for config in configs:
for squeeze_Y in config['squeeze_Ys']:
for squeeze_T in config['squeeze_Ts']:
config['DGP_params']['squeeze_Y'] = squeeze_Y
config['DGP_params']['squeeze_T'] = squeeze_T
dgp = DGP(**config['DGP_params'])
data_dict = dgp.gen_data()
Y = data_dict['Y']
T = data_dict['T']
X = data_dict['X']
feat_T = config['treatment_featurizer'].fit_transform(T)
data_dict_outside_feat = deepcopy(data_dict)
data_dict_outside_feat['T'] = feat_T
est_dicts = config['est_dicts']
for est_dict in est_dicts:
estClass = est_dict['class']
init_args = deepcopy(est_dict['init_args'])
init_args['treatment_featurizer'] = config['treatment_featurizer']
init_args['random_state'] = 1
est = estClass(**init_args)
est.fit(**data_dict)
init_args_outside_feat = deepcopy(est_dict['init_args'])
init_args_outside_feat['random_state'] = 1
est_outside_feat = estClass(**init_args_outside_feat)
est_outside_feat.fit(**data_dict_outside_feat)
# test that treatment names are assigned for the featurized treatment
assert (est.cate_treatment_names() is not None)
if hasattr(est, 'summary'):
est.summary()
# expected shapes
expected_eff_shape = (config['DGP_params']['n'],) + Y.shape[1:]
expected_cme_shape = (config['DGP_params']['n'],) + Y.shape[1:] + feat_T.shape[1:]
expected_me_shape = (config['DGP_params']['n'],) + Y.shape[1:] + T.shape[1:]
expected_marginal_ate_shape = expected_me_shape[1:]
# check effects
T0 = np.ones(shape=T.shape) * 5
T1 = np.ones(shape=T.shape) * 10
eff = est.effect(X=X, T0=T0, T1=T1)
assert (eff.shape == expected_eff_shape)
outside_feat = config['treatment_featurizer']
eff_outside_feat = est_outside_feat.effect(
X=X, T0=outside_feat.fit_transform(T0), T1=outside_feat.fit_transform(T1))
np.testing.assert_almost_equal(eff, eff_outside_feat)
actual_eff = actual_effect(config['DGP_params']['y_of_t'], T0, T1)
cme = est.const_marginal_effect(X=X)
assert (cme.shape == expected_cme_shape)
cme_outside_feat = est_outside_feat.const_marginal_effect(X=X)
np.testing.assert_almost_equal(cme, cme_outside_feat)
actual_cme = config['actual_cme']()
me = est.marginal_effect(T=T, X=X)
assert (me.shape == expected_me_shape)
actual_me = config['actual_marginal'](T).reshape(me.shape)
# ate
m_ate = est.marginal_ate(T, X=X)
assert (m_ate.shape == expected_marginal_ate_shape)
if isinstance(est, (LinearDML, SparseLinearDML, LinearDRIV, SparseLinearDRIV)):
d_f_t = feat_T.shape[1] if feat_T.shape[1:] else 1
expected_coef_inference_shape = (
config['DGP_params']['d_y'] * config['DGP_params']['d_x'] * d_f_t, 6)
assert est.coef__inference().summary_frame().shape == expected_coef_inference_shape
expected_intercept_inf_shape = (
config['DGP_params']['d_y'] * d_f_t, 6)
assert est.intercept__inference().summary_frame().shape == expected_intercept_inf_shape
# loose inference checks
# temporarily skip LinearDRIV and SparseLinearDRIV due to inconsistent effect shapes
if isinstance(est, (KernelDML, LinearDRIV, SparseLinearDRIV)):
continue
if est._inference is None:
continue
# effect inference
eff_inf = est.effect_inference(X=X, T0=T0, T1=T1)
eff_lb, eff_ub = eff_inf.conf_int(alpha=0.01)
assert (eff.shape == eff_lb.shape)
proportion_in_interval = ((eff_lb < actual_eff) & (actual_eff < eff_ub)).mean()
np.testing.assert_array_less(0.50, proportion_in_interval)
np.testing.assert_almost_equal(eff, eff_inf.point_estimate)
# marginal effect inference
me_inf = est.marginal_effect_inference(T, X=X)
me_lb, me_ub = me_inf.conf_int(alpha=0.01)
assert (me.shape == me_lb.shape)
proportion_in_interval = ((me_lb < actual_me) & (actual_me < me_ub)).mean()
np.testing.assert_array_less(0.50, proportion_in_interval)
np.testing.assert_almost_equal(me, me_inf.point_estimate)
# const marginal effect inference
cme_inf = est.const_marginal_effect_inference(X=X)
cme_lb, cme_ub = cme_inf.conf_int(alpha=0.01)
assert (cme.shape == cme_lb.shape)
proportion_in_interval = ((cme_lb < actual_cme) & (actual_cme < cme_ub)).mean()
np.testing.assert_array_less(0.50, proportion_in_interval)
np.testing.assert_almost_equal(cme, cme_inf.point_estimate)
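# A minimal sketch of the equivalence exercised above, using LinearDML and
# PolynomialFeatures purely for illustration: featurizing the treatment inside
# the estimator should match fitting on an externally featurized treatment.
#
#   poly = PolynomialFeatures(degree=2, include_bias=False)
#   est_internal = LinearDML(treatment_featurizer=poly, random_state=1).fit(Y=Y, T=T, X=X)
#   est_external = LinearDML(random_state=1).fit(Y=Y, T=poly.fit_transform(T), X=X)
#   np.testing.assert_almost_equal(est_internal.const_marginal_effect(X),
#                                  est_external.const_marginal_effect(X))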
def test_jac(self):
def func_transform(x):
x = x.reshape(-1, 1)
return np.hstack([x, x**2])
def calc_expected_jacobian(T):
jac = DPolynomialFeatures(degree=2, include_bias=False).fit_transform(T)
return jac
treatment_featurizers = [
PolynomialFeatures(degree=2, include_bias=False),
FunctionTransformer(func=func_transform)
]
n = 10000
d_t = 1
T = np.random.normal(size=(n, d_t))
for treatment_featurizer in treatment_featurizers:
# fit a dummy estimator first so the featurizer can be fit to the treatment
dummy_est = LinearDML(treatment_featurizer=treatment_featurizer)
dummy_est.fit(Y=T, T=T, X=T)
expected_jac = calc_expected_jacobian(T)
jac_T = dummy_est.transformer.jac(T)
np.testing.assert_almost_equal(jac_T, expected_jac)
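# Worked example of what test_jac verifies, assuming a scalar treatment t and
# PolynomialFeatures(degree=2, include_bias=False): the featurization maps
# t -> [t, t**2], so the jacobian row for each sample is [1, 2*t]; the wrapper's
# jac should reproduce this analytically for PolynomialFeatures and approximate
# it via central differences for the equivalent FunctionTransformer.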
def test_fail_discrete_treatment_and_treatment_featurizer(self):
class OrthoLearner(_OrthoLearner):
def _gen_ortho_learner_model_nuisance(self):
pass
def _gen_ortho_learner_model_final(self):
pass
est_and_params = [
{
'estimator': OrthoLearner,
'params': {
'cv': 2,
'discrete_treatment': False,
'treatment_featurizer': None,
'discrete_instrument': False,
'categories': 'auto',
'random_state': None
}
},
{'estimator': LinearDML, 'params': {}},
{'estimator': CausalForestDML, 'params': {}},
{'estimator': SparseLinearDML, 'params': {}},
{'estimator': KernelDML, 'params': {}},
{'estimator': DMLOrthoForest, 'params': {}}
]
dummy_vec = np.random.normal(size=(100, 1))
for est_and_param in est_and_params:
params = est_and_param['params']
params['discrete_treatment'] = True
params['treatment_featurizer'] = True
est = est_and_param['estimator'](**params)
with self.assertRaises(AssertionError, msg='Estimator fit did not fail when passed '
'both discrete treatment and treatment featurizer'):
est.fit(Y=dummy_vec, T=dummy_vec, X=dummy_vec)
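# Sketch of the configuration this test expects to be rejected (variable names
# are illustrative only):
#
#   est = LinearDML(discrete_treatment=True,
#                   treatment_featurizer=PolynomialFeatures(degree=2))
#   est.fit(Y=y, T=t, X=x)  # should raise an AssertionError at fit time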
def test_cate_treatment_names_edge_cases(self):
Y = np.random.normal(size=(100, 1))
T = np.random.binomial(n=2, p=0.5, size=(100, 1))
X = np.random.normal(size=(100, 3))
# edge case with transformer that only takes a vector treatment
# such a transformer will currently always return None for cate_treatment_names
def weird_func(x):
assert np.ndim(x) == 1
return x
est = LinearDML(treatment_featurizer=FunctionTransformer(weird_func)).fit(Y=Y, T=T.squeeze(), X=X)
assert est.cate_treatment_names() is None
assert est.cate_treatment_names(['too', 'many', 'feature_names']) is None
# assert proper handling of improper feature names passed to certain transformers
est = LinearDML(discrete_treatment=True).fit(Y=Y, T=T, X=X)
assert est.cate_treatment_names() == ['T0_1', 'T0_2']
assert est.cate_treatment_names(['too', 'many', 'feature_names']) is None
est = LinearDML(treatment_featurizer=PolynomialFeatures(degree=2, include_bias=False)).fit(Y=Y, T=T, X=X)
assert est.cate_treatment_names() == ['T0', 'T0^2']
# depending on the sklearn version, bad feature names either raise an error or result in only the first relevant name being used
assert est.cate_treatment_names(['too', 'many', 'feature_names']) in [None, ['too', 'too^2']]
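# For illustration: with the single-column treatment and polynomial featurizer
# above, a valid custom name is expected to propagate through the featurizer,
# e.g. est.cate_treatment_names(['price']) giving ['price', 'price^2']
# (the exact formatting can vary with the sklearn version).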
def test_alpha_passthrough(self):
X = np.random.normal(size=(100, 3))
T = np.random.normal(size=(100, 1)) + X[:, [0]]
Y = np.random.normal(size=(100, 1)) + T + X[:, [0]]
est = LinearDML(model_y=LinearRegression(), model_t=LinearRegression(),
treatment_featurizer=FunctionTransformer())
est.fit(Y=Y, T=T, X=X)
# ensure alpha is passed
lb, ub = est.marginal_effect_interval(T, X, alpha=1)
assert (lb == ub).all()
lb, ub = est.marginal_effect_interval(T, X)
assert (lb != ub).all()
lb1, ub1 = est.marginal_effect_interval(T, X, alpha=0.01)
lb2, ub2 = est.marginal_effect_interval(T, X, alpha=0.1)
assert (lb1 < lb2).all() and (ub1 > ub2).all()
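# Note: alpha is the miscoverage level, so marginal_effect_interval(..., alpha=a)
# is a (1 - a) confidence interval; alpha=1 collapses the interval to the point
# estimate, which is why lb == ub in the first check above.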
def test_identity_feat_with_cate_api(self):
treatment_featurizations = [FunctionTransformer()]
TestDML()._test_cate_api(treatment_featurizations)

View file

@@ -9,6 +9,7 @@ import scipy.sparse
import sparse as sp
import itertools
import inspect
import types
from operator import getitem
from collections import defaultdict, Counter
from sklearn import clone
@@ -17,6 +18,7 @@ from sklearn.linear_model import LassoCV, MultiTaskLassoCV, Lasso, MultiTaskLass
from functools import reduce, wraps
from sklearn.utils import check_array, check_X_y
from sklearn.utils.validation import assert_all_finite
from sklearn.preprocessing import PolynomialFeatures
import warnings
from warnings import warn
from sklearn.model_selection import KFold, StratifiedKFold, GroupKFold
@@ -46,7 +48,7 @@ class IdentityFeatures(TransformerMixin):
def parse_final_model_params(coef, intercept, d_y, d_t, d_t_in, bias_part_of_coef, fit_cate_intercept):
dt = d_t
if (d_t_in != d_t) and (d_t[0] == 1): # binary treatment
if (d_t_in != d_t) and (d_t and d_t[0] == 1): # binary treatment or single dim featurized treatment
dt = ()
cate_intercept = None
if bias_part_of_coef:
@@ -598,12 +600,41 @@ def get_input_columns(X, prefix="X"):
pd.Series: lambda x: [x.name]
}
if type(X) in type_to_func:
return type_to_func[type(X)](X)
column_names = type_to_func[type(X)](X)
# if not all column names are strings
if not all(isinstance(item, str) for item in column_names):
warnings.warn("Not all column names are strings. Coercing to strings for now.", UserWarning)
return [str(item) for item in column_names]
len_X = 1 if np.ndim(X) == 1 else np.asarray(X).shape[1]
return [f"{prefix}{i}" for i in range(len_X)]
def get_feature_names_or_default(featurizer, feature_names):
def get_feature_names_or_default(featurizer, feature_names, prefix="feat(X)"):
"""
Extract feature names from sklearn transformers; otherwise, attempt to assign default feature names.
Designed to be compatible with old and new sklearn versions.
Parameters
----------
featurizer : featurizer to extract feature names from
feature_names : input features
prefix : output prefix used when default feature names are assigned
Returns
----------
feature_names_out : a list of strings (feature names)
"""
# coerce feature names to be strings
if not all(isinstance(item, str) for item in feature_names):
warnings.warn("Not all feature names are strings. Coercing to strings for now.", UserWarning)
feature_names = [str(item) for item in feature_names]
# Prefer sklearn 1.0's get_feature_names_out method to deprecated get_feature_names method
if hasattr(featurizer, "get_feature_names_out"):
try:
@@ -612,17 +643,21 @@ def get_feature_names_or_default(featurizer, feature_names):
# Some featurizers will throw, such as a pipeline with a transformer that doesn't itself support names
pass
if hasattr(featurizer, 'get_feature_names'):
# Get number of arguments, some sklearn featurizers don't accept feature_names
arg_no = len(inspect.getfullargspec(featurizer.get_feature_names).args)
if arg_no == 1:
return featurizer.get_feature_names()
elif arg_no == 2:
return featurizer.get_feature_names(feature_names)
try:
# Get number of arguments, some sklearn featurizers don't accept feature_names
arg_no = len(inspect.getfullargspec(featurizer.get_feature_names).args)
if arg_no == 1:
return featurizer.get_feature_names()
elif arg_no == 2:
return featurizer.get_feature_names(feature_names)
except Exception:
# Handles cases where the passed feature names create issues
pass
# Featurizer doesn't have 'get_feature_names' or has atypical 'get_feature_names'
try:
# Get feature names using featurizer
dummy_X = np.ones((1, len(feature_names)))
return get_input_columns(featurizer.transform(dummy_X), prefix="feat(X)")
return get_input_columns(featurizer.transform(dummy_X), prefix=prefix)
except Exception:
# All attempts at retrieving transformed feature names have failed
# Delegate handling to downstream logic
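# Illustrative behavior of the fallbacks above (a sketch, not exhaustive):
# a featurizer exposing get_feature_names_out simply reports sklearn's names,
# e.g. PolynomialFeatures(degree=2, include_bias=False) fit on columns named
# ['a', 'b'] is expected to yield ['a', 'b', 'a^2', 'a b', 'b^2'], while a plain
# FunctionTransformer with no name support falls through to the transform-based
# default and gets prefixed names such as ['feat(X)0', 'feat(X)1'].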
@@ -1391,6 +1426,85 @@ class _RegressionWrapper:
return self._clf.predict_proba(X)[:, 1:]
class _TransformerWrapper:
"""Wrapper that takes a featurizer as input and adds jacobian calculation functionality"""
def __init__(self, featurizer):
self.featurizer = featurizer
def fit(self, X):
return self.featurizer.fit(X)
def transform(self, X):
return self.featurizer.transform(X)
def fit_transform(self, X):
return self.featurizer.fit_transform(X)
def get_feature_names_out(self, feature_names):
return get_feature_names_or_default(self.featurizer, feature_names, prefix="feat(T)")
def jac(self, X, epsilon=0.001):
if hasattr(self.featurizer, 'jac'):
return self.featurizer.jac(X)
elif (isinstance(self.featurizer, PolynomialFeatures)):
powers = self.featurizer.powers_
result = np.zeros(X.shape + (self.featurizer.n_output_features_,))
for i in range(X.shape[1]):
p = powers.copy()
c = powers[:, i]
p[:, i] -= 1
M = np.float_power(X[:, np.newaxis, :], p[np.newaxis, :, :])
result[:, i, :] = c[np.newaxis, :] * np.prod(M, axis=-1)
return result
else:
squeeze = []
n = X.shape[0]
d_t = X.shape[-1] if ndim(X) > 1 else 1
X_out = self.transform(X)
d_f_t = X_out.shape[-1] if ndim(X_out) > 1 else 1
jacob = np.zeros((n, d_t, d_f_t))
if ndim(X) == 1:
squeeze.append(1)
X = X[:, np.newaxis]
if ndim(X_out) == 1:
squeeze.append(2)
# for every dimension of the treatment add some epsilon and observe change in featurized treatment
for k in range(d_t):
eps_matrix = np.zeros(shape=X.shape)
eps_matrix[:, k] = epsilon
X_in_plus = X + eps_matrix
X_in_plus = X_in_plus.squeeze(axis=1) if 1 in squeeze else X_in_plus
X_out_plus = self.transform(X_in_plus)
X_out_plus = X_out_plus[:, np.newaxis] if 2 in squeeze else X_out_plus
X_in_minus = X - eps_matrix
X_in_minus = X_in_minus.squeeze(axis=1) if 1 in squeeze else X_in_minus
X_out_minus = self.transform(X_in_minus)
X_out_minus = X_out_minus[:, np.newaxis] if 2 in squeeze else X_out_minus
diff = X_out_plus - X_out_minus
deriv = diff / (2 * epsilon)
jacob[:, k, :] = deriv
return jacob.squeeze(axis=tuple(squeeze))
def jacify_featurizer(featurizer):
"""
Function that takes a featurizer as input and returns a wrapper object that includes
a method for calculating the jacobian
"""
return _TransformerWrapper(featurizer)
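# A rough usage sketch of the wrapper above, assuming a treatment array t of
# shape (n, 1): for featurizers without an analytic rule, jac falls back to the
# central difference (f(t + eps) - f(t - eps)) / (2 * eps), applied one treatment
# dimension at a time, and returns an array of shape (n, d_t, d_f_t).
#
#   wrapped = jacify_featurizer(FunctionTransformer(lambda t: np.hstack([t, t ** 3])))
#   wrapped.fit(t)
#   wrapped.jac(t)  # approximately np.stack([np.ones_like(t), 3 * t ** 2], axis=-1)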
@deprecated("This class will be removed from a future version of this package; "
"please use econml.sklearn_extensions.linear_model.WeightedLassoCV instead.")
class LassoCVWrapper:

File diffs are hidden because one or more lines are too long

View file

@@ -16,5 +16,6 @@ markers = [
"automl",
"dml",
"serial",
"cate_api"
"cate_api",
"treatment_featurization"
]
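Registering the new "treatment_featurization" marker lets these tests be selected or
excluded independently (e.g. with pytest -m treatment_featurization), assuming the
relevant test functions or classes are decorated with it, along the lines of:

import pytest

@pytest.mark.treatment_featurization
def test_polynomial_featurization():
    ...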