зеркало из https://github.com/py-why/EconML.git
enable treatment featurization (#615)
Co-authored-by: Keith Battocchi <kebatt@microsoft.com>
This commit is contained in:
Родитель
b4191e8735
Коммит
deb564fafa
|
@ -99,8 +99,8 @@ jobs:
|
|||
# Work around https://github.com/pypa/pip/issues/9542
|
||||
- script: 'pip install -U numpy~=1.21.0'
|
||||
displayName: 'Upgrade numpy'
|
||||
|
||||
- script: 'pip install pytest pytest-runner jupyter jupyter-client nbconvert nbformat seaborn xgboost tqdm && pip list && python setup.py pytest'
|
||||
|
||||
- script: 'pip install pytest pytest-runner jupyter jupyter-client nbconvert nbformat seaborn xgboost tqdm py && pip list && python setup.py pytest'
|
||||
displayName: 'Unit tests'
|
||||
env:
|
||||
PYTEST_ADDOPTS: '-m "notebook"'
|
||||
|
@ -128,7 +128,7 @@ jobs:
|
|||
- script: 'pip install -U numpy~=1.21.0'
|
||||
displayName: 'Upgrade numpy'
|
||||
|
||||
- script: 'pip install pytest pytest-runner jupyter jupyter-client nbconvert nbformat seaborn xgboost tqdm && python setup.py pytest'
|
||||
- script: 'pip install pytest pytest-runner jupyter jupyter-client nbconvert nbformat seaborn xgboost tqdm py && python setup.py pytest'
|
||||
displayName: 'Unit tests'
|
||||
env:
|
||||
PYTEST_ADDOPTS: '-m "notebook"'
|
||||
|
@ -199,10 +199,10 @@ jobs:
|
|||
condition: eq(dependencies.EvalChanges.outputs['output.testCode'], 'True')
|
||||
displayName: 'Run tests (main)'
|
||||
steps:
|
||||
- script: 'pip install pytest pytest-runner "coverage<6.4.1;python_version==''3.6''" "coverage;python_version>''3.6''" && python setup.py pytest'
|
||||
- script: 'pip install pytest pytest-runner "coverage<6.4.1;python_version==''3.6''" "coverage;python_version>''3.6''" py && python setup.py pytest'
|
||||
displayName: 'Unit tests'
|
||||
env:
|
||||
PYTEST_ADDOPTS: '-m "not (notebook or automl or dml or serial or cate_api)" -n 2'
|
||||
PYTEST_ADDOPTS: '-m "not (notebook or automl or dml or serial or cate_api or treatment_featurization)" -n 2'
|
||||
COVERAGE_PROCESS_START: 'setup.cfg'
|
||||
- task: PublishTestResults@2
|
||||
displayName: 'Publish Test Results **/test-results.xml'
|
||||
|
@ -226,7 +226,7 @@ jobs:
|
|||
condition: eq(dependencies.EvalChanges.outputs['output.testCode'], 'True')
|
||||
displayName: 'Run tests (DML)'
|
||||
steps:
|
||||
- script: 'pip install pytest pytest-runner "coverage<6.4.1;python_version==''3.6''" "coverage;python_version>''3.6''" && python setup.py pytest'
|
||||
- script: 'pip install pytest pytest-runner "coverage<6.4.1;python_version==''3.6''" "coverage;python_version>''3.6''" py && python setup.py pytest'
|
||||
displayName: 'Unit tests'
|
||||
env:
|
||||
PYTEST_ADDOPTS: '-m "dml"'
|
||||
|
@ -255,7 +255,7 @@ jobs:
|
|||
condition: eq(dependencies.EvalChanges.outputs['output.testCode'], 'True')
|
||||
displayName: 'Run tests (Serial)'
|
||||
steps:
|
||||
- script: 'pip install pytest pytest-runner "coverage<6.4.1;python_version==''3.6''" "coverage;python_version>''3.6''" && python setup.py pytest'
|
||||
- script: 'pip install pytest pytest-runner "coverage<6.4.1;python_version==''3.6''" "coverage;python_version>''3.6''" py && python setup.py pytest'
|
||||
displayName: 'Unit tests'
|
||||
env:
|
||||
PYTEST_ADDOPTS: '-m "serial" -n 1'
|
||||
|
@ -282,7 +282,7 @@ jobs:
|
|||
condition: eq(dependencies.EvalChanges.outputs['output.testCode'], 'True')
|
||||
displayName: 'Run tests (Other)'
|
||||
steps:
|
||||
- script: 'pip install pytest pytest-runner "coverage<6.4.1;python_version==''3.6''" "coverage;python_version>''3.6''"'
|
||||
- script: 'pip install pytest pytest-runner "coverage<6.4.1;python_version==''3.6''" "coverage;python_version>''3.6''" py'
|
||||
displayName: 'Install pytest'
|
||||
- script: 'python setup.py pytest'
|
||||
displayName: 'CATE Unit tests'
|
||||
|
@ -296,6 +296,35 @@ jobs:
|
|||
testRunTitle: 'Python $(python.version), image $(imageName)'
|
||||
condition: succeededOrFailed()
|
||||
|
||||
- task: PublishCodeCoverageResults@1
|
||||
displayName: 'Publish Code Coverage Results'
|
||||
inputs:
|
||||
codeCoverageTool: Cobertura
|
||||
summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml'
|
||||
|
||||
- template: azure-pipelines-steps.yml
|
||||
parameters:
|
||||
package: '-e .[tf,plt]'
|
||||
job:
|
||||
job: Tests_treatment_featurization
|
||||
dependsOn: 'EvalChanges'
|
||||
condition: eq(dependencies.EvalChanges.outputs['output.testCode'], 'True')
|
||||
displayName: 'Run tests (Treatment Featurization)'
|
||||
steps:
|
||||
- script: 'pip install pytest pytest-runner "coverage<6.4.1;python_version==''3.6''" "coverage;python_version>''3.6''" py'
|
||||
displayName: 'Install pytest'
|
||||
- script: 'python setup.py pytest'
|
||||
displayName: 'Treatment Featurization Unit tests'
|
||||
env:
|
||||
PYTEST_ADDOPTS: '-m "treatment_featurization" -n auto'
|
||||
COVERAGE_PROCESS_START: 'setup.cfg'
|
||||
- task: PublishTestResults@2
|
||||
displayName: 'Publish Test Results **/test-results.xml'
|
||||
inputs:
|
||||
testResultsFiles: '**/test-results.xml'
|
||||
testRunTitle: 'Python $(python.version), image $(imageName)'
|
||||
condition: succeededOrFailed()
|
||||
|
||||
- task: PublishCodeCoverageResults@1
|
||||
displayName: 'Publish Code Coverage Results'
|
||||
inputs:
|
||||
|
|
|
@ -167,6 +167,9 @@ The base class of all the methods in our API has the following signature:
|
|||
Linear in Treatment CATE Estimators
|
||||
-----------------------------------
|
||||
|
||||
.. rubric::
|
||||
Constant Marginal Effects
|
||||
|
||||
In many settings, we might want to make further structural assumptions on the form of the data generating process.
|
||||
One particular prevalent assumption is that the outcome :math:`y` is linear in the treatment vector and therefore that the marginal effect is constant across treatments, i.e.:
|
||||
|
||||
|
@ -188,13 +191,45 @@ Hence, the marginal CATE is independent of :math:`\vec{t}`. In these settings, w
|
|||
.. math ::
|
||||
\theta(\vec{x}) = \E[H(X, W) | X=\vec{x}] \tag{constant marginal CATE}
|
||||
|
||||
.. rubric::
|
||||
Constant Marginal Effects and Marginal Effects Given Treatment Featurization
|
||||
|
||||
Additionally, we may be interested in cases where the outcome depends linearly on a transformation of the treatment vector (via some featurizer :math:`\phi`).
|
||||
Some estimators provide support for passing such a featurizer :math:`\phi` directly to the estimator, in which case the outcome would be modeled as follows:
|
||||
|
||||
.. math ::
|
||||
|
||||
Y = H(X, W) \cdot \phi(T) + g(X, W, \epsilon)
|
||||
|
||||
We can then get constant marginal effects in the featurized treatment space:
|
||||
|
||||
.. math ::
|
||||
|
||||
\tau(\phi(\vec{t_0}), \phi(\vec{t_1}), \vec{x}) =~& \E[H(X, W) | X=\vec{x}] \cdot (\phi(\vec{t_1}) - \phi(\vec{t_0}))
|
||||
|
||||
\partial \tau(\phi(\vec{t}), \vec{x}) =~& \E[H(X, W) | X=\vec{x}]
|
||||
|
||||
\theta(\vec{x}) =~& \E[H(X, W) | X=\vec{x}]
|
||||
|
||||
|
||||
Finally, we can recover the marginal effect with respect to the original treatment space by multiplying the constant marginal effect (which is in featurized treatment space) with the jacobian of the treatment featurizer at :math:`\vec{t}`.
|
||||
|
||||
.. math ::
|
||||
\partial \tau(\vec{t}, \vec{x}) = \theta(\vec{x}) \nabla \phi(\vec{t}) \tag{marginal CATE}
|
||||
|
||||
where :math:`\nabla \phi(\vec{t})` is the :math:`d_{ft} \times d_{t}` jacobian matrix, and :math:`d_{ft}` and :math:`d_{t}` are the dimensions of the featurized treatment and the original treatment, respectively.
|
||||
|
||||
.. rubric::
|
||||
API for Linear in Treatment CATE Estimators
|
||||
|
||||
Given the prevalence of linear treatment effect assumptions, we will create a generic LinearCateEstimator, which will support a method that returns the constant marginal CATE
|
||||
and constant marginal CATE interval at any target feature vector :math:`\vec{x}`.
|
||||
and constant marginal CATE interval at any target feature vector :math:`\vec{x}`, as well as calculating marginal effects in the original treatment space when a treatment featurizer is provided.
|
||||
|
||||
.. code-block:: python3
|
||||
:caption: Linear CATE Estimator Class
|
||||
|
||||
class LinearCateEstimator(BaseCateEstimator):
|
||||
self.treatment_featurizer = None
|
||||
|
||||
def const_marginal_effect(self, X=None):
|
||||
''' Calculates the constant marginal CATE θ(·) conditional on a vector of
|
||||
|
@ -204,8 +239,9 @@ and constant marginal CATE interval at any target feature vector :math:`\vec{x}`
|
|||
X: optional (m × d_x) matrix of features for each sample
|
||||
|
||||
Returns:
|
||||
theta: (m × d_y × d_t) matrix of constant marginal CATE of each treatment
|
||||
on each outcome for each sample
|
||||
theta: (m × d_y × d_f_t) matrix of constant marginal CATE of each treatment on each outcome
|
||||
for each sample, where d_f_t is the dimension of the featurized treatment.
|
||||
If treatment_featurizer is None, d_f_t = d_t
|
||||
'''
|
||||
|
||||
def const_marginal_effect_interval(self, X=None, *, alpha=0.05):
|
||||
|
@ -222,13 +258,25 @@ and constant marginal CATE interval at any target feature vector :math:`\vec{x}`
|
|||
'''
|
||||
|
||||
def effect(self, X=None, *, T0, T1,):
|
||||
return const_marginal_effect(X) * (T1 - T0)
|
||||
if self.treatment_featurizer:
|
||||
return const_marginal_effect(X) * (T1 - T0)
|
||||
else:
|
||||
dt = self.treatment_featurizer.transform(T1) - self.treatment_featurizer.transform(T0)
|
||||
return const_marginal_effect(X) * dt
|
||||
|
||||
def marginal_effect(self, T, X=None)
|
||||
return const_marginal_effect(X)
|
||||
if self.treatment_featurizer is None:
|
||||
return const_marginal_effect(X)
|
||||
else:
|
||||
# for every observation X_i, T_i,
|
||||
# calculate jacobian at T_i and multiply with const_marginal_effect at X_i
|
||||
|
||||
def marginal_effect_interval(self, T, X=None, *, alpha=0.05):
|
||||
return const_marginal_effect_interval(X, alpha=alpha)
|
||||
if self.treatment_featurizer is None:
|
||||
return const_marginal_effect_interval(X, alpha=alpha)
|
||||
else:
|
||||
# perform separate treatment featurization inference logic
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -508,7 +508,33 @@ Usage FAQs
|
|||
The method is going to assume that each of these treatments enters linearly into the model. So it cannot capture complementarities or substitutabilities
|
||||
of the different treatments. For that you can also create composite treatments that look like the product
|
||||
of two base treatments. Then these product will enter in the model and an effect for that product will be estimated.
|
||||
This effect will be the substitute/complement effect of both treatments being present, i.e.:
|
||||
This effect will be the substitute/complement effect of both treatments being present. See below for more examples.
|
||||
|
||||
If you have too many treatments, then you can use the :class:`.SparseLinearDML`. However,
|
||||
this method will essentially impose a regularization that only a small subset of your featurized treatments has any effect.
|
||||
|
||||
- **What if my treatments are continuous and don't have a linear effect on the outcome?**
|
||||
|
||||
You can impose a particular form of non-linearity by specifying a `treatment_featurizer` to the estimator.
|
||||
For example, one can use the sklearn `PolynomialFeatures` transformer as a `treatment_featurizer` in order to learn
|
||||
higher-order polynomial treatment effects.
|
||||
|
||||
Using the `treatment_featurizer` argument additionally has the benefit of calculating marginal effects with respect to the original treatment dimension,
|
||||
as opposed to featurizing the treatment yourself before passing to the estimator.
|
||||
|
||||
.. testcode::
|
||||
|
||||
from econml.dml import LinearDML
|
||||
from sklearn.preprocessing import PolynomialFeatures
|
||||
poly = PolynomialFeatures(degree=2, interaction_only=True, include_bias=False)
|
||||
est = LinearDML(treatment_featurizer=poly)
|
||||
est.fit(y, T, X=X, W=W)
|
||||
point = est.const_marginal_effect(X)
|
||||
est.effect(X, T0=T0, T1=T1)
|
||||
est.marginal_effect(T, X)
|
||||
|
||||
|
||||
Alternatively, you can still create composite treatments and add them as extra treatment variables:
|
||||
|
||||
.. testcode::
|
||||
|
||||
|
@ -521,14 +547,6 @@ Usage FAQs
|
|||
point = est.const_marginal_effect(X)
|
||||
est.effect(X, T0=poly.transform(T0), T1=poly.transform(T1))
|
||||
|
||||
If your treatments are too many, then you can use the :class:`.SparseLinearDML`. However,
|
||||
this method will essentially impose a regularization that only a small subset of them has any effect.
|
||||
|
||||
- **What if my treatments are continuous and don't have a linear effect on the outcome?**
|
||||
|
||||
You can create composite treatments and add them as extra treatment variables (see above). This would require
|
||||
imposing a particular form of non-linearity.
|
||||
|
||||
- **What if my treatment is categorical/binary?**
|
||||
|
||||
You can simply set `discrete_treatment=True` in the parameters of the class. Then use any classifier for
|
||||
|
@ -691,14 +709,15 @@ We can even create a Pipeline or Union of featurizers that will apply multiply f
|
|||
|
||||
.. rubric:: Single Outcome, Multiple Treatments
|
||||
|
||||
Suppose that we believed that our treatment was affecting the outcome in a non-linear manner.
|
||||
Then we could expand the treatment vector to contain also polynomial features:
|
||||
Suppose we want to estimate treatment effects for multiple continuous treatments at the same time.
|
||||
Then we can simply concatenate them before passing them to the estimator.
|
||||
|
||||
.. testcode::
|
||||
|
||||
import numpy as np
|
||||
est = LinearDML()
|
||||
est.fit(y, np.concatenate((T, T**2), axis=1), X=X, W=W)
|
||||
est.fit(y, np.concatenate((T0, T1), axis=1), X=X, W=W)
|
||||
|
||||
|
||||
.. rubric:: Multiple Outcome, Multiple Treatments
|
||||
|
||||
|
|
|
@ -511,8 +511,8 @@ Usage FAQs
|
|||
|
||||
|
||||
|
||||
Usage FAQs
|
||||
==========
|
||||
Usage Examples
|
||||
==============
|
||||
|
||||
Check out the following Jupyter notebooks:
|
||||
|
||||
|
|
|
@ -9,8 +9,8 @@ from functools import wraps
|
|||
from copy import deepcopy
|
||||
from warnings import warn
|
||||
from .inference import BootstrapInference
|
||||
from .utilities import (tensordot, ndim, reshape, shape, parse_final_model_params,
|
||||
inverse_onehot, Summary, get_input_columns, check_input_arrays)
|
||||
from .utilities import (tensordot, ndim, reshape, shape, parse_final_model_params, get_feature_names_or_default,
|
||||
inverse_onehot, Summary, get_input_columns, check_input_arrays, jacify_featurizer)
|
||||
from .inference import StatsModelsInference, StatsModelsInferenceDiscrete, LinearModelFinalInference,\
|
||||
LinearModelFinalInferenceDiscrete, NormalInferenceResults, GenericSingleTreatmentModelFinalInference,\
|
||||
GenericModelFinalInferenceDiscrete
|
||||
|
@ -320,14 +320,16 @@ class BaseCateEstimator(metaclass=abc.ABCMeta):
|
|||
"""
|
||||
return (X,) + Ts
|
||||
|
||||
def _use_inference_method(self, name, *args, **kwargs):
|
||||
if self._inference is not None:
|
||||
return getattr(self._inference, name)(*args, **kwargs)
|
||||
else:
|
||||
raise AttributeError("Can't call '%s' because 'inference' is None" % name)
|
||||
|
||||
def _defer_to_inference(m):
|
||||
@wraps(m)
|
||||
def call(self, *args, **kwargs):
|
||||
name = m.__name__
|
||||
if self._inference is not None:
|
||||
return getattr(self._inference, name)(*args, **kwargs)
|
||||
else:
|
||||
raise AttributeError("Can't call '%s' because 'inference' is None" % name)
|
||||
return self._use_inference_method(m.__name__, *args, **kwargs)
|
||||
return call
|
||||
|
||||
@_defer_to_inference
|
||||
|
@ -534,7 +536,11 @@ class BaseCateEstimator(metaclass=abc.ABCMeta):
|
|||
|
||||
|
||||
class LinearCateEstimator(BaseCateEstimator):
|
||||
"""Base class for all CATE estimators with linear treatment effects in this package."""
|
||||
"""
|
||||
Base class for all CATE estimators in this package where the outcome is linear given
|
||||
some user-defined treatment featurization.
|
||||
"""
|
||||
_original_treatment_featurizer = None
|
||||
|
||||
@abc.abstractmethod
|
||||
def const_marginal_effect(self, X=None):
|
||||
|
@ -551,10 +557,11 @@ class LinearCateEstimator(BaseCateEstimator):
|
|||
|
||||
Returns
|
||||
-------
|
||||
theta: (m, d_y, d_t) matrix or (d_y, d_t) matrix if X is None
|
||||
Constant marginal CATE of each treatment on each outcome for each sample X[i].
|
||||
Note that when Y or T is a vector rather than a 2-dimensional array,
|
||||
the corresponding singleton dimensions in the output will be collapsed
|
||||
theta: (m, d_y, d_f_t) matrix or (d_y, d_f_t) matrix if X is None where d_f_t is \
|
||||
the dimension of the featurized treatment. If treatment_featurizer is None, d_f_t = d_t.
|
||||
Constant marginal CATE of each featurized treatment on each outcome for each sample X[i].
|
||||
Note that when Y or featurized-T (or T if treatment_featurizer is None) is a vector
|
||||
rather than a 2-dimensional array, the corresponding singleton dimensions in the output will be collapsed
|
||||
(e.g. if both are vectors, then the output of this method will also be a vector)
|
||||
"""
|
||||
pass
|
||||
|
@ -607,7 +614,8 @@ class LinearCateEstimator(BaseCateEstimator):
|
|||
|
||||
The marginal effect is calculated around a base treatment
|
||||
point conditional on a vector of features on a set of m test samples :math:`\\{T_i, X_i\\}`.
|
||||
Since this class assumes a linear model, the base treatment is ignored in this calculation.
|
||||
If treatment_featurizer is None, the base treatment is ignored in this calculation and the result
|
||||
is equivalent to const_marginal_effect.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -624,24 +632,49 @@ class LinearCateEstimator(BaseCateEstimator):
|
|||
the corresponding singleton dimensions in the output will be collapsed
|
||||
(e.g. if both are vectors, then the output of this method will also be a vector)
|
||||
"""
|
||||
X, T = self._expand_treatments(X, T)
|
||||
X, T = self._expand_treatments(X, T, transform=False)
|
||||
eff = self.const_marginal_effect(X)
|
||||
return np.repeat(eff, shape(T)[0], axis=0) if X is None else eff
|
||||
|
||||
if X is None:
|
||||
eff = np.repeat(eff, shape(T)[0], axis=0)
|
||||
|
||||
if self._original_treatment_featurizer:
|
||||
feat_T = self.transformer.transform(T)
|
||||
jac_T = self.transformer.jac(T)
|
||||
|
||||
einsum_str = 'myf, mtf->myt'
|
||||
|
||||
if ndim(T) == 1:
|
||||
einsum_str = einsum_str.replace('t', '')
|
||||
if ndim(feat_T) == 1:
|
||||
einsum_str = einsum_str.replace('f', '')
|
||||
if (ndim(eff) == ndim(feat_T)):
|
||||
einsum_str = einsum_str.replace('y', '')
|
||||
return np.einsum(einsum_str, eff, jac_T)
|
||||
|
||||
else:
|
||||
return eff
|
||||
|
||||
def marginal_effect_interval(self, T, X=None, *, alpha=0.05):
|
||||
X, T = self._expand_treatments(X, T)
|
||||
effs = self.const_marginal_effect_interval(X=X, alpha=alpha)
|
||||
if X is None: # need to repeat by the number of rows of T to ensure the right shape
|
||||
effs = tuple(np.repeat(eff, shape(T)[0], axis=0) for eff in effs)
|
||||
return effs
|
||||
if self._original_treatment_featurizer:
|
||||
return self._use_inference_method('marginal_effect_interval', T, X, alpha=alpha)
|
||||
else:
|
||||
X, T = self._expand_treatments(X, T)
|
||||
effs = self.const_marginal_effect_interval(X=X, alpha=alpha)
|
||||
if X is None: # need to repeat by the number of rows of T to ensure the right shape
|
||||
effs = tuple(np.repeat(eff, shape(T)[0], axis=0) for eff in effs)
|
||||
return effs
|
||||
marginal_effect_interval.__doc__ = BaseCateEstimator.marginal_effect_interval.__doc__
|
||||
|
||||
def marginal_effect_inference(self, T, X=None):
|
||||
X, T = self._expand_treatments(X, T)
|
||||
cme_inf = self.const_marginal_effect_inference(X=X)
|
||||
if X is None:
|
||||
cme_inf = cme_inf._expand_outputs(shape(T)[0])
|
||||
return cme_inf
|
||||
if self._original_treatment_featurizer:
|
||||
return self._use_inference_method('marginal_effect_inference', T, X)
|
||||
else:
|
||||
X, T = self._expand_treatments(X, T)
|
||||
cme_inf = self.const_marginal_effect_inference(X=X)
|
||||
if X is None:
|
||||
cme_inf = cme_inf._expand_outputs(shape(T)[0])
|
||||
return cme_inf
|
||||
marginal_effect_inference.__doc__ = BaseCateEstimator.marginal_effect_inference.__doc__
|
||||
|
||||
@BaseCateEstimator._defer_to_inference
|
||||
|
@ -697,11 +730,12 @@ class LinearCateEstimator(BaseCateEstimator):
|
|||
|
||||
Returns
|
||||
-------
|
||||
theta: (d_y, d_t) matrix
|
||||
theta: (d_y, d_f_t) matrix where d_f_t is the dimension of the featurized treatment. \
|
||||
If treatment_featurizer is None, d_f_t = d_t.
|
||||
Average constant marginal CATE of each treatment on each outcome.
|
||||
Note that when Y or T is a vector rather than a 2-dimensional array,
|
||||
the corresponding singleton dimensions in the output will be collapsed
|
||||
(e.g. if both are vectors, then the output of this method will be a scalar)
|
||||
Note that when Y or featurized-T (or T if treatment_featurizer is None) is a vector
|
||||
rather than a 2-dimensional array, the corresponding singleton dimensions in the output will be collapsed
|
||||
(e.g. if both are vectors, then the output of this method will also be a scalar)
|
||||
"""
|
||||
return np.mean(self.const_marginal_effect(X=X), axis=0)
|
||||
|
||||
|
@ -748,15 +782,17 @@ class LinearCateEstimator(BaseCateEstimator):
|
|||
pass
|
||||
|
||||
def marginal_ate(self, T, X=None):
|
||||
return self.const_marginal_ate(X=X)
|
||||
return np.mean(self.marginal_effect(T, X=X), axis=0)
|
||||
marginal_ate.__doc__ = BaseCateEstimator.marginal_ate.__doc__
|
||||
|
||||
@BaseCateEstimator._defer_to_inference
|
||||
def marginal_ate_interval(self, T, X=None, *, alpha=0.05):
|
||||
return self.const_marginal_ate_interval(X=X, alpha=alpha)
|
||||
pass
|
||||
marginal_ate_interval.__doc__ = BaseCateEstimator.marginal_ate_interval.__doc__
|
||||
|
||||
@BaseCateEstimator._defer_to_inference
|
||||
def marginal_ate_inference(self, T, X=None):
|
||||
return self.const_marginal_ate_inference(X=X)
|
||||
pass
|
||||
marginal_ate_inference.__doc__ = BaseCateEstimator.marginal_ate_inference.__doc__
|
||||
|
||||
def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None, background_samples=100):
|
||||
|
@ -769,7 +805,7 @@ class LinearCateEstimator(BaseCateEstimator):
|
|||
feature_names: optional None or list of strings of length X.shape[1] (Default=None)
|
||||
The names of input features.
|
||||
treatment_names: optional None or list (Default=None)
|
||||
The name of treatment. In discrete treatment scenario, the name should not include the name of
|
||||
The name of featurized treatment. In discrete treatment scenario, the name should not include the name of
|
||||
the baseline treatment (i.e. the control treatment, which by default is the alphabetically smaller)
|
||||
output_names: optional None or list (Default=None)
|
||||
The name of the outcome.
|
||||
|
@ -793,9 +829,13 @@ class LinearCateEstimator(BaseCateEstimator):
|
|||
|
||||
|
||||
class TreatmentExpansionMixin(BaseCateEstimator):
|
||||
"""Mixin which automatically handles promotions of scalar treatments to the appropriate shape."""
|
||||
"""
|
||||
Mixin which automatically handles promotions of scalar treatments to the appropriate shape,
|
||||
as well as treatment featurization for discrete treatments and user-specified treatment transformers
|
||||
"""
|
||||
|
||||
transformer = None
|
||||
_original_treatment_featurizer = None
|
||||
|
||||
def _prefit(self, Y, T, *args, **kwargs):
|
||||
super()._prefit(Y, T, *args, **kwargs)
|
||||
|
@ -808,7 +848,7 @@ class TreatmentExpansionMixin(BaseCateEstimator):
|
|||
if self.transformer:
|
||||
self._set_transformed_treatment_names()
|
||||
|
||||
def _expand_treatments(self, X=None, *Ts):
|
||||
def _expand_treatments(self, X=None, *Ts, transform=True):
|
||||
X, *Ts = check_input_arrays(X, *Ts)
|
||||
n_rows = 1 if X is None else shape(X)[0]
|
||||
outTs = []
|
||||
|
@ -820,23 +860,29 @@ class TreatmentExpansionMixin(BaseCateEstimator):
|
|||
if ndim(T) == 0:
|
||||
T = np.full((n_rows,) + self._d_t_in, T)
|
||||
|
||||
if self.transformer:
|
||||
T = self.transformer.transform(reshape(T, (-1, 1)))
|
||||
if self.transformer and transform:
|
||||
if not self._original_treatment_featurizer:
|
||||
T = T.reshape(-1, 1)
|
||||
T = self.transformer.transform(T)
|
||||
outTs.append(T)
|
||||
|
||||
return (X,) + tuple(outTs)
|
||||
|
||||
def _set_transformed_treatment_names(self):
|
||||
"""Works with sklearn OHEs"""
|
||||
"""
|
||||
Extracts treatment names from sklearn transformers.
|
||||
Or, if transformer does not have a get_feature_names method, sets default treatment names.
|
||||
"""
|
||||
|
||||
if hasattr(self, "_input_names"):
|
||||
self._input_names["treatment_names"] = self.transformer.get_feature_names(
|
||||
self._input_names["treatment_names"]).tolist()
|
||||
ret = get_feature_names_or_default(self.transformer, self._input_names["treatment_names"], prefix='T')
|
||||
self._input_names["treatment_names"] = list(ret) if ret is not None else ret
|
||||
|
||||
def cate_treatment_names(self, treatment_names=None):
|
||||
"""
|
||||
Get treatment names.
|
||||
|
||||
If the treatment is discrete, it will return expanded treatment names.
|
||||
If the treatment is discrete or featurized, it will return expanded treatment names.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -851,7 +897,8 @@ class TreatmentExpansionMixin(BaseCateEstimator):
|
|||
"""
|
||||
if treatment_names is not None:
|
||||
if self.transformer:
|
||||
return self.transformer.get_feature_names(treatment_names).tolist()
|
||||
ret = get_feature_names_or_default(self.transformer, treatment_names)
|
||||
return list(ret) if ret is not None else None
|
||||
return treatment_names
|
||||
# Treatment names is None, default to BaseCateEstimator
|
||||
return super().cate_treatment_names()
|
||||
|
@ -891,6 +938,7 @@ class LinearModelFinalCateEstimatorMixin(BaseCateEstimator):
|
|||
Whether the CATE model's intercept is contained in the final model's ``coef_`` rather
|
||||
than as a separate ``intercept_``
|
||||
"""
|
||||
featurizer = None
|
||||
|
||||
def _get_inference_options(self):
|
||||
options = super()._get_inference_options()
|
||||
|
@ -1030,14 +1078,31 @@ class LinearModelFinalCateEstimatorMixin(BaseCateEstimator):
|
|||
output_names = self.cate_output_names(output_names)
|
||||
# Summary
|
||||
smry = Summary()
|
||||
smry.add_extra_txt(["<sub>A linear parametric conditional average treatment effect (CATE) model was fitted:",
|
||||
"$Y = \\Theta(X)\\cdot T + g(X, W) + \\epsilon$",
|
||||
"where for every outcome $i$ and treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:",
|
||||
"$\\Theta_{ij}(X) = \\phi(X)' coef_{ij} + cate\\_intercept_{ij}$",
|
||||
"where $\\phi(X)$ is the output of the `featurizer` or $X$ if `featurizer`=None. "
|
||||
"Coefficient Results table portrays the $coef_{ij}$ parameter vector for "
|
||||
"each outcome $i$ and treatment $j$. "
|
||||
"Intercept Results table portrays the $cate\\_intercept_{ij}$ parameter.</sub>"])
|
||||
|
||||
extra_txt = ["<sub>A linear parametric conditional average treatment effect (CATE) model was fitted:"]
|
||||
|
||||
if self._original_treatment_featurizer:
|
||||
extra_txt.append("$Y = \\Theta(X)\\cdot \\psi(T) + g(X, W) + \\epsilon$")
|
||||
extra_txt.append("where $\\psi(T)$ is the output of the `treatment_featurizer")
|
||||
extra_txt.append(
|
||||
"and for every outcome $i$ and featurized treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:")
|
||||
else:
|
||||
extra_txt.append("$Y = \\Theta(X)\\cdot T + g(X, W) + \\epsilon$")
|
||||
extra_txt.append(
|
||||
"where for every outcome $i$ and treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:")
|
||||
|
||||
if self.featurizer:
|
||||
extra_txt.append("$\\Theta_{ij}(X) = \\phi(X)' coef_{ij} + cate\\_intercept_{ij}$")
|
||||
extra_txt.append("where $\\phi(X)$ is the output of the `featurizer`")
|
||||
else:
|
||||
extra_txt.append("$\\Theta_{ij}(X) = X' coef_{ij} + cate\\_intercept_{ij}$")
|
||||
|
||||
extra_txt.append("Coefficient Results table portrays the $coef_{ij}$ parameter vector for "
|
||||
"each outcome $i$ and treatment $j$. "
|
||||
"Intercept Results table portrays the $cate\\_intercept_{ij}$ parameter.</sub>")
|
||||
|
||||
smry.add_extra_txt(extra_txt)
|
||||
|
||||
d_t = self._d_t[0] if self._d_t else 1
|
||||
d_y = self._d_y[0] if self._d_y else 1
|
||||
try:
|
||||
|
|
|
@ -44,7 +44,7 @@ from ._cate_estimator import (BaseCateEstimator, LinearCateEstimator,
|
|||
from .inference import BootstrapInference
|
||||
from .utilities import (_deprecate_positional, check_input_arrays,
|
||||
cross_product, filter_none_kwargs,
|
||||
inverse_onehot, ndim, reshape, shape, transpose)
|
||||
inverse_onehot, jacify_featurizer, ndim, reshape, shape, transpose)
|
||||
|
||||
|
||||
def _crossfit(model, folds, *args, **kwargs):
|
||||
|
@ -256,6 +256,11 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
|
|||
discrete_treatment: bool
|
||||
Whether the treatment values should be treated as categorical, rather than continuous, quantities
|
||||
|
||||
treatment_featurizer : :term:`transformer` or None
|
||||
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
|
||||
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
|
||||
If featurizer=None, then CATE is trained on T.
|
||||
|
||||
discrete_instrument: bool
|
||||
Whether the instrument values should be treated as categorical, rather than continuous, quantities
|
||||
|
||||
|
@ -337,8 +342,8 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
|
|||
np.random.seed(123)
|
||||
X = np.random.normal(size=(100, 3))
|
||||
y = X[:, 0] + X[:, 1] + np.random.normal(0, 0.1, size=(100,))
|
||||
est = OrthoLearner(cv=2, discrete_treatment=False, discrete_instrument=False,
|
||||
categories='auto', random_state=None)
|
||||
est = OrthoLearner(cv=2, discrete_treatment=False, treatment_featurizer=None,
|
||||
discrete_instrument=False, categories='auto', random_state=None)
|
||||
est.fit(y, X[:, 0], W=X[:, 1:])
|
||||
|
||||
>>> est.score_
|
||||
|
@ -396,7 +401,7 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
|
|||
T = np.random.binomial(1, scipy.special.expit(W[:, 0]))
|
||||
y = T + W[:, 0] + np.random.normal(0, 0.01, size=(100,))
|
||||
est = OrthoLearner(cv=2, discrete_treatment=True, discrete_instrument=False,
|
||||
categories='auto', random_state=None)
|
||||
treatment_featurizer=None, categories='auto', random_state=None)
|
||||
est.fit(y, T, W=W)
|
||||
|
||||
>>> est.score_
|
||||
|
@ -427,10 +432,12 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
|
|||
"""
|
||||
|
||||
def __init__(self, *,
|
||||
discrete_treatment, discrete_instrument, categories, cv, random_state,
|
||||
discrete_treatment, treatment_featurizer,
|
||||
discrete_instrument, categories, cv, random_state,
|
||||
mc_iters=None, mc_agg='mean'):
|
||||
self.cv = cv
|
||||
self.discrete_treatment = discrete_treatment
|
||||
self.treatment_featurizer = treatment_featurizer
|
||||
self.discrete_instrument = discrete_instrument
|
||||
self.random_state = random_state
|
||||
self.categories = categories
|
||||
|
@ -595,6 +602,8 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
|
|||
self._random_state = check_random_state(self.random_state)
|
||||
assert (freq_weight is None) == (
|
||||
sample_var is None), "Sample variances and frequency weights must be provided together!"
|
||||
assert not (self.discrete_treatment and self.treatment_featurizer), "Treatment featurization " \
|
||||
"is not supported when treatment is discrete"
|
||||
if check_input:
|
||||
Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups = check_input_arrays(
|
||||
Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups)
|
||||
|
@ -609,6 +618,11 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
|
|||
self.transformer = OneHotEncoder(categories=categories, sparse=False, drop='first')
|
||||
self.transformer.fit(reshape(T, (-1, 1)))
|
||||
self._d_t = (len(self.transformer.categories_[0]) - 1,)
|
||||
elif self.treatment_featurizer:
|
||||
self._original_treatment_featurizer = clone(self.treatment_featurizer, safe=False)
|
||||
self.transformer = jacify_featurizer(self.treatment_featurizer)
|
||||
output_T = self.transformer.fit_transform(T)
|
||||
self._d_t = np.shape(output_T)[1:]
|
||||
else:
|
||||
self.transformer = None
|
||||
|
||||
|
@ -675,10 +689,21 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
|
|||
# _d_t is altered by fit nuisances to what prefit does. So we need to perform the same
|
||||
# alteration even when we only want to fit_final.
|
||||
if self.transformer is not None:
|
||||
self._d_t = (len(self.transformer.categories_[0]) - 1,)
|
||||
if self.discrete_treatment:
|
||||
self._d_t = (len(self.transformer.categories_[0]) - 1,)
|
||||
else:
|
||||
output_T = self.transformer.fit_transform(T)
|
||||
self._d_t = np.shape(output_T)[1:]
|
||||
|
||||
final_T = T
|
||||
if self.transformer:
|
||||
if (self.discrete_treatment):
|
||||
final_T = self.transformer.transform(final_T.reshape(-1, 1))
|
||||
else: # treatment featurizer case
|
||||
final_T = output_T
|
||||
|
||||
self._fit_final(Y=Y,
|
||||
T=self.transformer.transform(T.reshape((-1, 1))) if self.transformer is not None else T,
|
||||
T=final_T,
|
||||
X=X, W=W, Z=Z,
|
||||
nuisances=nuisances,
|
||||
sample_weight=sample_weight,
|
||||
|
@ -733,8 +758,10 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
|
|||
if strata is None:
|
||||
strata = T # always safe to pass T as second arg to split even if we're not actually stratifying
|
||||
|
||||
if self.discrete_treatment:
|
||||
T = self.transformer.transform(reshape(T, (-1, 1)))
|
||||
if self.transformer:
|
||||
if self.discrete_treatment:
|
||||
T = reshape(T, (-1, 1))
|
||||
T = self.transformer.transform(T)
|
||||
|
||||
if self.discrete_instrument:
|
||||
Z = self.z_transformer.transform(reshape(Z, (-1, 1)))
|
||||
|
|
|
@ -145,6 +145,11 @@ class _RLearner(_OrthoLearner):
|
|||
discrete_treatment: bool
|
||||
Whether the treatment values should be treated as categorical, rather than continuous, quantities
|
||||
|
||||
treatment_featurizer : :term:`transformer` or None
|
||||
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
|
||||
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
|
||||
If featurizer=None, then CATE is trained on T.
|
||||
|
||||
categories: 'auto' or list
|
||||
The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
|
||||
The first category will be treated as the control treatment.
|
||||
|
@ -218,7 +223,8 @@ class _RLearner(_OrthoLearner):
|
|||
np.random.seed(123)
|
||||
X = np.random.normal(size=(1000, 3))
|
||||
y = X[:, 0] + X[:, 1] + np.random.normal(0, 0.01, size=(1000,))
|
||||
est = RLearner(cv=2, discrete_treatment=False, categories='auto', random_state=None)
|
||||
est = RLearner(cv=2, discrete_treatment=False,
|
||||
treatment_featurizer=None, categories='auto', random_state=None)
|
||||
est.fit(y, X[:, 0], X=np.ones((X.shape[0], 1)), W=X[:, 1:])
|
||||
|
||||
>>> est.const_marginal_effect(np.ones((1,1)))
|
||||
|
@ -265,8 +271,10 @@ class _RLearner(_OrthoLearner):
|
|||
is multidimensional, then the average of the MSEs for each dimension of Y is returned.
|
||||
"""
|
||||
|
||||
def __init__(self, *, discrete_treatment, categories, cv, random_state, mc_iters=None, mc_agg='mean'):
|
||||
def __init__(self, *, discrete_treatment, treatment_featurizer, categories,
|
||||
cv, random_state, mc_iters=None, mc_agg='mean'):
|
||||
super().__init__(discrete_treatment=discrete_treatment,
|
||||
treatment_featurizer=treatment_featurizer,
|
||||
discrete_instrument=False, # no instrument, so doesn't matter
|
||||
categories=categories,
|
||||
cv=cv,
|
||||
|
|
|
@ -197,6 +197,64 @@ class _GenericSingleOutcomeModelFinalWithCovInference(Inference):
|
|||
return NormalInferenceResults(d_t=None, d_y=self.d_y, pred=pred,
|
||||
pred_stderr=pred_stderr, mean_pred_stderr=None, inf_type='effect')
|
||||
|
||||
def marginal_effect_interval(self, T, X, alpha=0.05):
|
||||
return self.marginal_effect_inference(T, X).conf_int(alpha=alpha)
|
||||
|
||||
def marginal_effect_inference(self, T, X):
|
||||
if X is None:
|
||||
raise ValueError("This inference method currently does not support X=None!")
|
||||
if not self._est._original_treatment_featurizer:
|
||||
return self.const_marginal_effect_inference(X)
|
||||
X, T = self._est._expand_treatments(X, T, transform=False)
|
||||
if self.featurizer is not None:
|
||||
X = self.featurizer.transform(X)
|
||||
|
||||
feat_T = self._est.transformer.transform(T)
|
||||
jac_T = self._est.transformer.jac(T)
|
||||
|
||||
d_t_orig = T.shape[1:]
|
||||
d_t_orig = d_t_orig[0] if d_t_orig else 1
|
||||
|
||||
d_y = self._d_y[0] if self._d_y else 1
|
||||
d_t = self._d_t[0] if self._d_t else 1
|
||||
|
||||
output_shape = [X.shape[0]]
|
||||
if self._d_y:
|
||||
output_shape.append(self._d_y[0])
|
||||
if T.shape[1:]:
|
||||
output_shape.append(T.shape[1])
|
||||
me_pred = np.zeros(shape=output_shape)
|
||||
me_stderr = np.zeros(shape=output_shape)
|
||||
|
||||
for i in range(d_t_orig):
|
||||
# conditionally index multiple dimensions depending on shapes of T, Y and feat_T
|
||||
jac_index = [slice(None)]
|
||||
me_index = [slice(None)]
|
||||
if self._d_y:
|
||||
me_index.append(slice(None))
|
||||
if T.shape[1:]:
|
||||
jac_index.append(i)
|
||||
me_index.append(i)
|
||||
if feat_T.shape[1:]: # if featurized T is not a vector
|
||||
jac_index.append(slice(None))
|
||||
|
||||
jac_slice = jac_T[tuple(jac_index)]
|
||||
if jac_slice.ndim == 1:
|
||||
jac_slice.reshape((-1, 1))
|
||||
|
||||
e_pred, e_var = self.model_final.predict_projection_and_var(X, jac_slice)
|
||||
e_stderr = np.sqrt(e_var)
|
||||
|
||||
if not self._d_y:
|
||||
e_pred = e_pred.squeeze(axis=1)
|
||||
e_stderr = e_stderr.squeeze(axis=1)
|
||||
|
||||
me_pred[tuple(me_index)] = e_pred
|
||||
me_stderr[tuple(me_index)] = e_stderr
|
||||
|
||||
return NormalInferenceResults(d_t=d_t_orig, d_y=self.d_y, pred=me_pred,
|
||||
pred_stderr=me_stderr, mean_pred_stderr=None, inf_type='effect')
|
||||
|
||||
|
||||
class CausalForestDML(_BaseDML):
|
||||
"""A Causal Forest [cfdml1]_ combined with double machine learning based residualization of the treatment
|
||||
|
@ -227,6 +285,11 @@ class CausalForestDML(_BaseDML):
|
|||
It is ignored if X is None. The final CATE will be trained on the outcome of featurizer.fit_transform(X).
|
||||
If featurizer=None, then CATE is trained on X.
|
||||
|
||||
treatment_featurizer : :term:`transformer`, optional
|
||||
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
|
||||
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
|
||||
If featurizer=None, then CATE is trained on T.
|
||||
|
||||
discrete_treatment: bool, optional (default is ``False``)
|
||||
Whether the treatment values should be treated as categorical, rather than continuous, quantities
|
||||
|
||||
|
@ -513,6 +576,7 @@ class CausalForestDML(_BaseDML):
|
|||
model_y='auto',
|
||||
model_t='auto',
|
||||
featurizer=None,
|
||||
treatment_featurizer=None,
|
||||
discrete_treatment=False,
|
||||
categories='auto',
|
||||
cv=2,
|
||||
|
@ -567,6 +631,7 @@ class CausalForestDML(_BaseDML):
|
|||
self.n_jobs = n_jobs
|
||||
self.verbose = verbose
|
||||
super().__init__(discrete_treatment=discrete_treatment,
|
||||
treatment_featurizer=treatment_featurizer,
|
||||
categories=categories,
|
||||
cv=cv,
|
||||
mc_iters=mc_iters,
|
||||
|
|
|
@ -360,6 +360,11 @@ class DML(LinearModelFinalCateEstimatorMixin, _BaseDML):
|
|||
It is ignored if X is None. The final CATE will be trained on the outcome of featurizer.fit_transform(X).
|
||||
If featurizer=None, then CATE is trained on X.
|
||||
|
||||
treatment_featurizer : :term:`transformer`, optional, default None
|
||||
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
|
||||
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
|
||||
If featurizer=None, then CATE is trained on T.
|
||||
|
||||
fit_cate_intercept : bool, optional, default True
|
||||
Whether the linear CATE model should have a constant term.
|
||||
|
||||
|
@ -452,6 +457,7 @@ class DML(LinearModelFinalCateEstimatorMixin, _BaseDML):
|
|||
def __init__(self, *,
|
||||
model_y, model_t, model_final,
|
||||
featurizer=None,
|
||||
treatment_featurizer=None,
|
||||
fit_cate_intercept=True,
|
||||
linear_first_stages=False,
|
||||
discrete_treatment=False,
|
||||
|
@ -469,6 +475,7 @@ class DML(LinearModelFinalCateEstimatorMixin, _BaseDML):
|
|||
self.model_t = clone(model_t, safe=False)
|
||||
self.model_final = clone(model_final, safe=False)
|
||||
super().__init__(discrete_treatment=discrete_treatment,
|
||||
treatment_featurizer=treatment_featurizer,
|
||||
categories=categories,
|
||||
cv=cv,
|
||||
mc_iters=mc_iters,
|
||||
|
@ -585,6 +592,11 @@ class LinearDML(StatsModelsCateEstimatorMixin, DML):
|
|||
It is ignored if X is None. The final CATE will be trained on the outcome of featurizer.fit_transform(X).
|
||||
If featurizer=None, then CATE is trained on X.
|
||||
|
||||
treatment_featurizer : :term:`transformer`, optional
|
||||
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
|
||||
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
|
||||
If featurizer=None, then CATE is trained on T.
|
||||
|
||||
fit_cate_intercept : bool, optional, default True
|
||||
Whether the linear CATE model should have a constant term.
|
||||
|
||||
|
@ -670,6 +682,7 @@ class LinearDML(StatsModelsCateEstimatorMixin, DML):
|
|||
def __init__(self, *,
|
||||
model_y='auto', model_t='auto',
|
||||
featurizer=None,
|
||||
treatment_featurizer=None,
|
||||
fit_cate_intercept=True,
|
||||
linear_first_stages=True,
|
||||
discrete_treatment=False,
|
||||
|
@ -682,6 +695,7 @@ class LinearDML(StatsModelsCateEstimatorMixin, DML):
|
|||
model_t=model_t,
|
||||
model_final=None,
|
||||
featurizer=featurizer,
|
||||
treatment_featurizer=treatment_featurizer,
|
||||
fit_cate_intercept=fit_cate_intercept,
|
||||
linear_first_stages=linear_first_stages,
|
||||
discrete_treatment=discrete_treatment,
|
||||
|
@ -811,6 +825,11 @@ class SparseLinearDML(DebiasedLassoCateEstimatorMixin, DML):
|
|||
It is ignored if X is None. The final CATE will be trained on the outcome of featurizer.fit_transform(X).
|
||||
If featurizer=None, then CATE is trained on X.
|
||||
|
||||
treatment_featurizer : :term:`transformer`, optional, default None
|
||||
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
|
||||
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
|
||||
If featurizer=None, then CATE is trained on T.
|
||||
|
||||
fit_cate_intercept : bool, optional, default True
|
||||
Whether the linear CATE model should have a constant term.
|
||||
|
||||
|
@ -902,6 +921,7 @@ class SparseLinearDML(DebiasedLassoCateEstimatorMixin, DML):
|
|||
tol=1e-4,
|
||||
n_jobs=None,
|
||||
featurizer=None,
|
||||
treatment_featurizer=None,
|
||||
fit_cate_intercept=True,
|
||||
linear_first_stages=True,
|
||||
discrete_treatment=False,
|
||||
|
@ -921,6 +941,7 @@ class SparseLinearDML(DebiasedLassoCateEstimatorMixin, DML):
|
|||
model_t=model_t,
|
||||
model_final=None,
|
||||
featurizer=featurizer,
|
||||
treatment_featurizer=treatment_featurizer,
|
||||
fit_cate_intercept=fit_cate_intercept,
|
||||
linear_first_stages=linear_first_stages,
|
||||
discrete_treatment=discrete_treatment,
|
||||
|
@ -1041,6 +1062,11 @@ class KernelDML(DML):
|
|||
discrete_treatment: bool, optional (default is ``False``)
|
||||
Whether the treatment values should be treated as categorical, rather than continuous, quantities
|
||||
|
||||
treatment_featurizer : :term:`transformer`, optional
|
||||
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
|
||||
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
|
||||
If featurizer=None, then CATE is trained on T.
|
||||
|
||||
categories: 'auto' or list, default 'auto'
|
||||
The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
|
||||
The first category will be treated as the control treatment.
|
||||
|
@ -1101,7 +1127,9 @@ class KernelDML(DML):
|
|||
"""
|
||||
|
||||
def __init__(self, model_y='auto', model_t='auto',
|
||||
discrete_treatment=False, categories='auto',
|
||||
discrete_treatment=False,
|
||||
treatment_featurizer=None,
|
||||
categories='auto',
|
||||
fit_cate_intercept=True,
|
||||
dim=20,
|
||||
bw=1.0,
|
||||
|
@ -1114,6 +1142,7 @@ class KernelDML(DML):
|
|||
model_t=model_t,
|
||||
model_final=None,
|
||||
featurizer=None,
|
||||
treatment_featurizer=treatment_featurizer,
|
||||
fit_cate_intercept=fit_cate_intercept,
|
||||
discrete_treatment=discrete_treatment,
|
||||
categories=categories,
|
||||
|
@ -1212,6 +1241,11 @@ class NonParamDML(_BaseDML):
|
|||
discrete_treatment: bool, optional (default is ``False``)
|
||||
Whether the treatment values should be treated as categorical, rather than continuous, quantities
|
||||
|
||||
treatment_featurizer : :term:`transformer`, optional
|
||||
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
|
||||
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
|
||||
If featurizer=None, then CATE is trained on T.
|
||||
|
||||
categories: 'auto' or list, default 'auto'
|
||||
The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
|
||||
The first category will be treated as the control treatment.
|
||||
|
@ -1282,6 +1316,7 @@ class NonParamDML(_BaseDML):
|
|||
model_y, model_t, model_final,
|
||||
featurizer=None,
|
||||
discrete_treatment=False,
|
||||
treatment_featurizer=None,
|
||||
categories='auto',
|
||||
cv=2,
|
||||
mc_iters=None,
|
||||
|
@ -1295,6 +1330,7 @@ class NonParamDML(_BaseDML):
|
|||
self.featurizer = clone(featurizer, safe=False)
|
||||
self.model_final = clone(model_final, safe=False)
|
||||
super().__init__(discrete_treatment=discrete_treatment,
|
||||
treatment_featurizer=treatment_featurizer,
|
||||
categories=categories,
|
||||
cv=cv,
|
||||
mc_iters=mc_iters,
|
||||
|
|
|
@ -44,7 +44,7 @@ from sklearn.linear_model import (LassoCV, LinearRegression,
|
|||
from sklearn.ensemble import RandomForestRegressor
|
||||
|
||||
from .._ortho_learner import _OrthoLearner
|
||||
from .._cate_estimator import (DebiasedLassoCateEstimatorDiscreteMixin,
|
||||
from .._cate_estimator import (DebiasedLassoCateEstimatorDiscreteMixin, BaseCateEstimator,
|
||||
ForestModelFinalCateEstimatorDiscreteMixin,
|
||||
StatsModelsCateEstimatorDiscreteMixin, LinearCateEstimator)
|
||||
from ..inference import GenericModelFinalInferenceDiscrete
|
||||
|
@ -420,10 +420,54 @@ class DRLearner(_OrthoLearner):
|
|||
mc_iters=mc_iters,
|
||||
mc_agg=mc_agg,
|
||||
discrete_treatment=True,
|
||||
treatment_featurizer=None, # treatment featurization not supported with discrete treatment
|
||||
discrete_instrument=False, # no instrument, so doesn't matter
|
||||
categories=categories,
|
||||
random_state=random_state)
|
||||
|
||||
# override only so that we can exclude treatment featurization verbiage in docstring
|
||||
def const_marginal_effect(self, X=None):
|
||||
"""
|
||||
Calculate the constant marginal CATE :math:`\\theta(·)`.
|
||||
|
||||
The marginal effect is conditional on a vector of
|
||||
features on a set of m test samples X[i].
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X: optional (m, d_x) matrix or None (Default=None)
|
||||
Features for each sample.
|
||||
|
||||
Returns
|
||||
-------
|
||||
theta: (m, d_y, d_t) matrix or (d_y, d_t) matrix if X is None
|
||||
Constant marginal CATE of each treatment on each outcome for each sample X[i].
|
||||
Note that when Y or T is a vector rather than a 2-dimensional array,
|
||||
the corresponding singleton dimensions in the output will be collapsed
|
||||
(e.g. if both are vectors, then the output of this method will also be a vector)
|
||||
"""
|
||||
return super().const_marginal_effect(X=X)
|
||||
|
||||
# override only so that we can exclude treatment featurization verbiage in docstring
|
||||
def const_marginal_ate(self, X=None):
|
||||
"""
|
||||
Calculate the average constant marginal CATE :math:`E_X[\\theta(X)]`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X: optional (m, d_x) matrix or None (Default=None)
|
||||
Features for each sample.
|
||||
|
||||
Returns
|
||||
-------
|
||||
theta: (d_y, d_t) matrix
|
||||
Average constant marginal CATE of each treatment on each outcome.
|
||||
Note that when Y or T is a vector rather than a 2-dimensional array,
|
||||
the corresponding singleton dimensions in the output will be collapsed
|
||||
(e.g. if both are vectors, then the output of this method will be a scalar)
|
||||
"""
|
||||
return super().const_marginal_ate(X=X)
|
||||
|
||||
def _get_inference_options(self):
|
||||
options = super()._get_inference_options()
|
||||
if not self.multitask_model_final:
|
||||
|
|
|
@ -469,6 +469,7 @@ class DynamicDML(LinearModelFinalCateEstimatorMixin, _OrthoLearner):
|
|||
self.model_y = clone(model_y, safe=False)
|
||||
self.model_t = clone(model_t, safe=False)
|
||||
super().__init__(discrete_treatment=discrete_treatment,
|
||||
treatment_featurizer=None,
|
||||
discrete_instrument=False,
|
||||
categories=categories,
|
||||
cv=GroupKFold(cv) if isinstance(cv, int) else cv,
|
||||
|
@ -476,6 +477,49 @@ class DynamicDML(LinearModelFinalCateEstimatorMixin, _OrthoLearner):
|
|||
mc_agg=mc_agg,
|
||||
random_state=random_state)
|
||||
|
||||
# override only so that we can exclude treatment featurization verbiage in docstring
|
||||
def const_marginal_effect(self, X=None):
|
||||
"""
|
||||
Calculate the constant marginal CATE :math:`\\theta(·)`.
|
||||
|
||||
The marginal effect is conditional on a vector of
|
||||
features on a set of m test samples X[i].
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X: optional (m, d_x) matrix or None (Default=None)
|
||||
Features for each sample.
|
||||
|
||||
Returns
|
||||
-------
|
||||
theta: (m, d_y, d_t) matrix or (d_y, d_t) matrix if X is None
|
||||
Constant marginal CATE of each treatment on each outcome for each sample X[i].
|
||||
Note that when Y or T is a vector rather than a 2-dimensional array,
|
||||
the corresponding singleton dimensions in the output will be collapsed
|
||||
(e.g. if both are vectors, then the output of this method will also be a vector)
|
||||
"""
|
||||
return super().const_marginal_effect(X=X)
|
||||
|
||||
# override only so that we can exclude treatment featurization verbiage in docstring
|
||||
def const_marginal_ate(self, X=None):
|
||||
"""
|
||||
Calculate the average constant marginal CATE :math:`E_X[\\theta(X)]`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X: optional (m, d_x) matrix or None (Default=None)
|
||||
Features for each sample.
|
||||
|
||||
Returns
|
||||
-------
|
||||
theta: (d_y, d_t) matrix
|
||||
Average constant marginal CATE of each treatment on each outcome.
|
||||
Note that when Y or T is a vector rather than a 2-dimensional array,
|
||||
the corresponding singleton dimensions in the output will be collapsed
|
||||
(e.g. if both are vectors, then the output of this method will be a scalar)
|
||||
"""
|
||||
return super().const_marginal_ate(X=X)
|
||||
|
||||
def _gen_featurizer(self):
|
||||
return clone(self.featurizer, safe=False)
|
||||
|
||||
|
@ -705,13 +749,13 @@ class DynamicDML(LinearModelFinalCateEstimatorMixin, _OrthoLearner):
|
|||
return feature_names
|
||||
return get_feature_names_or_default(self.original_featurizer, feature_names)
|
||||
|
||||
def _expand_treatments(self, X, *Ts):
|
||||
def _expand_treatments(self, X, *Ts, transform=True):
|
||||
# Expand treatments for each time period
|
||||
outTs = []
|
||||
base_expand_treatments = super()._expand_treatments
|
||||
for T in Ts:
|
||||
if ndim(T) == 0:
|
||||
one_T = base_expand_treatments(X, T)[1]
|
||||
one_T = base_expand_treatments(X, T, transform=transform)[1]
|
||||
one_T = one_T.reshape(-1, 1) if ndim(one_T) == 1 else one_T
|
||||
T = np.tile(one_T, (1, self._n_periods, ))
|
||||
else:
|
||||
|
@ -720,7 +764,7 @@ class DynamicDML(LinearModelFinalCateEstimatorMixin, _OrthoLearner):
|
|||
if self.transformer:
|
||||
T = np.hstack([
|
||||
base_expand_treatments(
|
||||
X, T[:, [t]])[1] for t in range(self._n_periods)
|
||||
X, T[:, [t]], transform=transform)[1] for t in range(self._n_periods)
|
||||
])
|
||||
outTs.append(T)
|
||||
return (X,) + tuple(outTs)
|
||||
|
|
|
@ -15,7 +15,7 @@ from ._bootstrap import BootstrapEstimator
|
|||
from ..sklearn_extensions.linear_model import StatsModelsLinearRegression
|
||||
from ..utilities import (Summary, _safe_norm_ppf, broadcast_unit_treatments,
|
||||
cross_product, inverse_onehot, ndim,
|
||||
parse_final_model_params,
|
||||
parse_final_model_params, jacify_featurizer,
|
||||
reshape_treatmentwise_effects, shape, filter_none_kwargs)
|
||||
|
||||
"""Options for performing inference in estimators."""
|
||||
|
@ -201,6 +201,42 @@ class GenericSingleTreatmentModelFinalInference(GenericModelFinalInference):
|
|||
feature_names=self._est.cate_feature_names(),
|
||||
output_names=self._est.cate_output_names())
|
||||
|
||||
def marginal_effect_inference(self, T, X):
|
||||
X, T = self._est._expand_treatments(X, T, transform=False)
|
||||
|
||||
cme_inf = self.const_marginal_effect_inference(X)
|
||||
if not self._est._original_treatment_featurizer:
|
||||
return cme_inf
|
||||
|
||||
feat_T = self._est.transformer.transform(T)
|
||||
|
||||
cme_pred = cme_inf.point_estimate
|
||||
cme_stderr = cme_inf.stderr
|
||||
|
||||
jac_T = self._est.transformer.jac(T)
|
||||
|
||||
einsum_str = 'myf, mtf->myt'
|
||||
if ndim(T) == 1:
|
||||
einsum_str = einsum_str.replace('t', '')
|
||||
if ndim(feat_T) == 1:
|
||||
einsum_str = einsum_str.replace('f', '')
|
||||
# y is a vector, rather than a 2D array
|
||||
if (ndim(cme_pred) == ndim(feat_T)):
|
||||
einsum_str = einsum_str.replace('y', '')
|
||||
e_pred = np.einsum(einsum_str, cme_pred, jac_T)
|
||||
e_stderr = np.einsum(einsum_str, cme_stderr, np.abs(jac_T)) if cme_stderr is not None else None
|
||||
d_y = self._d_y[0] if self._d_y else 1
|
||||
d_t = self._d_t[0] if self._d_t else 1
|
||||
d_t_orig = T.shape[1:][0] if T.shape[1:] else 1
|
||||
|
||||
return NormalInferenceResults(d_t=d_t_orig, d_y=d_y, pred=e_pred,
|
||||
pred_stderr=e_stderr, mean_pred_stderr=None, inf_type='effect',
|
||||
feature_names=self._est.cate_feature_names(),
|
||||
output_names=self._est.cate_output_names())
|
||||
|
||||
def marginal_effect_interval(self, T, X, *, alpha=0.05):
|
||||
return self.marginal_effect_inference(T, X).conf_int(alpha=alpha)
|
||||
|
||||
|
||||
class LinearModelFinalInference(GenericModelFinalInference):
|
||||
"""
|
||||
|
@ -274,6 +310,68 @@ class LinearModelFinalInference(GenericModelFinalInference):
|
|||
inf_res.mean_pred_stderr = np.squeeze(mean_pred_stderr, axis=0)
|
||||
return inf_res
|
||||
|
||||
def marginal_effect_inference(self, T, X):
|
||||
X, T = self._est._expand_treatments(X, T, transform=False)
|
||||
if not self._est._original_treatment_featurizer:
|
||||
return self.const_marginal_effect_inference(X)
|
||||
|
||||
if X is None:
|
||||
X = np.ones((T.shape[0], 1))
|
||||
elif self.featurizer is not None:
|
||||
X = self.featurizer.transform(X)
|
||||
|
||||
feat_T = self._est.transformer.transform(T)
|
||||
|
||||
jac_T = self._est.transformer.jac(T)
|
||||
|
||||
d_t_orig = T.shape[1:]
|
||||
d_t_orig = d_t_orig[0] if d_t_orig else 1
|
||||
|
||||
d_y = self._d_y[0] if self._d_y else 1
|
||||
d_t = self._d_t[0] if self._d_t else 1
|
||||
|
||||
output_shape = [X.shape[0]]
|
||||
if self._d_y:
|
||||
output_shape.append(self._d_y[0])
|
||||
if T.shape[1:]:
|
||||
output_shape.append(T.shape[1])
|
||||
me_pred = np.zeros(shape=output_shape)
|
||||
me_stderr = np.zeros(shape=output_shape)
|
||||
mean_pred_stderr_res = np.zeros(shape=output_shape[1:])
|
||||
for i in range(d_t_orig):
|
||||
# conditionally index multiple dimensions depending on shapes of T, Y and feat_T
|
||||
jac_index = [slice(None)]
|
||||
me_index = [slice(None)]
|
||||
if self._d_y:
|
||||
me_index.append(slice(None))
|
||||
if T.shape[1:]:
|
||||
jac_index.append(i)
|
||||
me_index.append(i)
|
||||
if feat_T.shape[1:]: # if featurized T is not a vector
|
||||
jac_index.append(slice(None))
|
||||
|
||||
XT = cross_product(X, jac_T[tuple(jac_index)])
|
||||
e_pred = self._predict(XT).reshape(X.shape[:1] + self._d_y) # enforce output shape
|
||||
e_stderr = self._prediction_stderr(XT).reshape(X.shape[:1] + self._d_y)
|
||||
|
||||
mean_XT = XT.mean(axis=0, keepdims=True)
|
||||
mean_pred_stderr = self._prediction_stderr(mean_XT) # shape[0] will always be 1 here
|
||||
# squeeze the first axis
|
||||
mean_pred_stderr = np.squeeze(mean_pred_stderr, axis=0) if mean_pred_stderr is not None else None
|
||||
if mean_pred_stderr is not None:
|
||||
mean_pred_stderr_res[tuple(me_index[1:])] = mean_pred_stderr
|
||||
|
||||
me_pred[tuple(me_index)] = e_pred
|
||||
me_stderr[tuple(me_index)] = e_stderr
|
||||
|
||||
return NormalInferenceResults(d_t=d_t_orig, d_y=d_y, pred=me_pred,
|
||||
pred_stderr=me_stderr, mean_pred_stderr=mean_pred_stderr_res, inf_type='effect',
|
||||
feature_names=self._est.cate_feature_names(),
|
||||
output_names=self._est.cate_output_names())
|
||||
|
||||
def marginal_effect_interval(self, T, X, *, alpha=0.05):
|
||||
return self.marginal_effect_inference(T, X).conf_int(alpha=alpha)
|
||||
|
||||
def coef__interval(self, *, alpha=0.05):
|
||||
lo, hi = self.model_final.coef__interval(alpha)
|
||||
lo_int, hi_int = self.model_final.intercept__interval(alpha)
|
||||
|
@ -817,7 +915,7 @@ class InferenceResults(metaclass=abc.ABCMeta):
|
|||
The mean value of the metric you'd like to test under null hypothesis.
|
||||
decimals: optinal int (default=3)
|
||||
Number of decimal places to round each column to.
|
||||
tol: optinal float (default=0.001)
|
||||
tol: optional float (default=0.001)
|
||||
The stopping criterion. The iterations will stop when the outcome is less than ``tol``
|
||||
output_names: optional list of strings or None (default is None)
|
||||
The names of the outputs
|
||||
|
|
|
@ -237,6 +237,11 @@ class OrthoIV(LinearModelFinalCateEstimatorMixin, _OrthoLearner):
|
|||
discrete_treatment: bool, optional, default False
|
||||
Whether the treatment values should be treated as categorical, rather than continuous, quantities
|
||||
|
||||
treatment_featurizer : :term:`transformer`, optional
|
||||
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
|
||||
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
|
||||
If featurizer=None, then CATE is trained on T.
|
||||
|
||||
discrete_instrument: bool, optional, default False
|
||||
Whether the instrument values should be treated as categorical, rather than continuous, quantities
|
||||
|
||||
|
@ -338,6 +343,7 @@ class OrthoIV(LinearModelFinalCateEstimatorMixin, _OrthoLearner):
|
|||
featurizer=None,
|
||||
fit_cate_intercept=True,
|
||||
discrete_treatment=False,
|
||||
treatment_featurizer=None,
|
||||
discrete_instrument=False,
|
||||
categories='auto',
|
||||
cv=2,
|
||||
|
@ -354,6 +360,7 @@ class OrthoIV(LinearModelFinalCateEstimatorMixin, _OrthoLearner):
|
|||
|
||||
super().__init__(discrete_instrument=discrete_instrument,
|
||||
discrete_treatment=discrete_treatment,
|
||||
treatment_featurizer=treatment_featurizer,
|
||||
categories=categories,
|
||||
cv=cv,
|
||||
mc_iters=mc_iters,
|
||||
|
@ -1039,6 +1046,11 @@ class DMLIV(_BaseDMLIV):
|
|||
discrete_treatment: bool, optional, default False
|
||||
Whether the treatment values should be treated as categorical, rather than continuous, quantities
|
||||
|
||||
treatment_featurizer : :term:`transformer`, optional
|
||||
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
|
||||
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
|
||||
If featurizer=None, then CATE is trained on T.
|
||||
|
||||
categories: 'auto' or list, default 'auto'
|
||||
The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
|
||||
The first category will be treated as the control treatment.
|
||||
|
@ -1129,6 +1141,7 @@ class DMLIV(_BaseDMLIV):
|
|||
featurizer=None,
|
||||
fit_cate_intercept=True,
|
||||
discrete_treatment=False,
|
||||
treatment_featurizer=None,
|
||||
discrete_instrument=False,
|
||||
categories='auto',
|
||||
cv=2,
|
||||
|
@ -1142,6 +1155,7 @@ class DMLIV(_BaseDMLIV):
|
|||
self.featurizer = clone(featurizer, safe=False)
|
||||
self.fit_cate_intercept = fit_cate_intercept
|
||||
super().__init__(discrete_treatment=discrete_treatment,
|
||||
treatment_featurizer=treatment_featurizer,
|
||||
discrete_instrument=discrete_instrument,
|
||||
categories=categories,
|
||||
cv=cv,
|
||||
|
@ -1282,14 +1296,30 @@ class DMLIV(_BaseDMLIV):
|
|||
feature_names = self.cate_feature_names(feature_names)
|
||||
# Summary
|
||||
smry = Summary()
|
||||
smry.add_extra_txt(["<sub>A linear parametric conditional average treatment effect (CATE) model was fitted:",
|
||||
"$Y = \\Theta(X)\\cdot T + g(X, W) + \\epsilon$",
|
||||
"where for every outcome $i$ and treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:",
|
||||
"$\\Theta_{ij}(X) = \\phi(X)' coef_{ij} + cate\\_intercept_{ij}$",
|
||||
"where $\\phi(X)$ is the output of the `featurizer` or $X$ if `featurizer`=None. "
|
||||
"Coefficient Results table portrays the $coef_{ij}$ parameter vector for "
|
||||
"each outcome $i$ and treatment $j$. "
|
||||
"Intercept Results table portrays the $cate\\_intercept_{ij}$ parameter.</sub>"])
|
||||
|
||||
extra_txt = ["<sub>A linear parametric conditional average treatment effect (CATE) model was fitted:"]
|
||||
|
||||
if self._original_treatment_featurizer:
|
||||
extra_txt.append("$Y = \\Theta(X)\\cdot \\psi(T) + g(X, W) + \\epsilon$")
|
||||
extra_txt.append("where $\\psi(T)$ is the output of the `treatment_featurizer")
|
||||
extra_txt.append(
|
||||
"and for every outcome $i$ and featurized treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:")
|
||||
else:
|
||||
extra_txt.append("$Y = \\Theta(X)\\cdot T + g(X, W) + \\epsilon$")
|
||||
extra_txt.append(
|
||||
"where for every outcome $i$ and treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:")
|
||||
|
||||
if self.featurizer:
|
||||
extra_txt.append("$\\Theta_{ij}(X) = \\phi(X)' coef_{ij} + cate\\_intercept_{ij}$")
|
||||
extra_txt.append("where $\\phi(X)$ is the output of the `featurizer`")
|
||||
else:
|
||||
extra_txt.append("$\\Theta_{ij}(X) = X' coef_{ij} + cate\\_intercept_{ij}$")
|
||||
|
||||
extra_txt.append("Coefficient Results table portrays the $coef_{ij}$ parameter vector for "
|
||||
"each outcome $i$ and treatment $j$. "
|
||||
"Intercept Results table portrays the $cate\\_intercept_{ij}$ parameter.</sub>")
|
||||
|
||||
smry.add_extra_txt(extra_txt)
|
||||
d_t = self._d_t[0] if self._d_t else 1
|
||||
d_y = self._d_y[0] if self._d_y else 1
|
||||
|
||||
|
@ -1403,6 +1433,11 @@ class NonParamDMLIV(_BaseDMLIV):
|
|||
discrete_treatment: bool, optional, default False
|
||||
Whether the treatment values should be treated as categorical, rather than continuous, quantities
|
||||
|
||||
treatment_featurizer : :term:`transformer`, optional
|
||||
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
|
||||
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
|
||||
If featurizer=None, then CATE is trained on T.
|
||||
|
||||
discrete_instrument: bool, optional, default False
|
||||
Whether the instrument values should be treated as categorical, rather than continuous, quantities
|
||||
|
||||
|
@ -1495,6 +1530,7 @@ class NonParamDMLIV(_BaseDMLIV):
|
|||
model_t_xwz="auto",
|
||||
model_final,
|
||||
discrete_treatment=False,
|
||||
treatment_featurizer=None,
|
||||
discrete_instrument=False,
|
||||
featurizer=None,
|
||||
categories='auto',
|
||||
|
@ -1509,6 +1545,7 @@ class NonParamDMLIV(_BaseDMLIV):
|
|||
self.featurizer = clone(featurizer, safe=False)
|
||||
super().__init__(discrete_treatment=discrete_treatment,
|
||||
discrete_instrument=discrete_instrument,
|
||||
treatment_featurizer=treatment_featurizer,
|
||||
categories=categories,
|
||||
cv=cv,
|
||||
mc_iters=mc_iters,
|
||||
|
|
|
@ -303,6 +303,7 @@ class _BaseDRIV(_OrthoLearner):
|
|||
opt_reweighted=False,
|
||||
discrete_instrument=False,
|
||||
discrete_treatment=False,
|
||||
treatment_featurizer=None,
|
||||
categories='auto',
|
||||
cv=2,
|
||||
mc_iters=None,
|
||||
|
@ -315,6 +316,7 @@ class _BaseDRIV(_OrthoLearner):
|
|||
self.opt_reweighted = opt_reweighted
|
||||
super().__init__(discrete_instrument=discrete_instrument,
|
||||
discrete_treatment=discrete_treatment,
|
||||
treatment_featurizer=treatment_featurizer,
|
||||
categories=categories,
|
||||
cv=cv,
|
||||
mc_iters=mc_iters,
|
||||
|
@ -344,6 +346,8 @@ class _BaseDRIV(_OrthoLearner):
|
|||
if len(T1.shape) > 1 and T1.shape[1] > 1:
|
||||
if self.discrete_treatment:
|
||||
raise AttributeError("DRIV only supports binary treatments")
|
||||
elif self.treatment_featurizer: # defer possible failure to downstream logic
|
||||
pass
|
||||
else:
|
||||
raise AttributeError("DRIV only supports single-dimensional continuous treatments")
|
||||
if len(Z1.shape) > 1 and Z1.shape[1] > 1:
|
||||
|
@ -548,6 +552,7 @@ class _DRIV(_BaseDRIV):
|
|||
opt_reweighted=False,
|
||||
discrete_instrument=False,
|
||||
discrete_treatment=False,
|
||||
treatment_featurizer=None,
|
||||
categories='auto',
|
||||
cv=2,
|
||||
mc_iters=None,
|
||||
|
@ -567,6 +572,7 @@ class _DRIV(_BaseDRIV):
|
|||
opt_reweighted=opt_reweighted,
|
||||
discrete_instrument=discrete_instrument,
|
||||
discrete_treatment=discrete_treatment,
|
||||
treatment_featurizer=treatment_featurizer,
|
||||
categories=categories,
|
||||
cv=cv,
|
||||
mc_iters=mc_iters,
|
||||
|
@ -733,6 +739,11 @@ class DRIV(_DRIV):
|
|||
discrete_treatment: bool, optional, default False
|
||||
Whether the treatment values should be treated as categorical, rather than continuous, quantities
|
||||
|
||||
treatment_featurizer : :term:`transformer`, optional
|
||||
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
|
||||
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
|
||||
If featurizer=None, then CATE is trained on T.
|
||||
|
||||
categories: 'auto' or list, default 'auto'
|
||||
The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
|
||||
The first category will be treated as the control treatment.
|
||||
|
@ -828,6 +839,7 @@ class DRIV(_DRIV):
|
|||
opt_reweighted=False,
|
||||
discrete_instrument=False,
|
||||
discrete_treatment=False,
|
||||
treatment_featurizer=None,
|
||||
categories='auto',
|
||||
cv=2,
|
||||
mc_iters=None,
|
||||
|
@ -855,6 +867,7 @@ class DRIV(_DRIV):
|
|||
opt_reweighted=opt_reweighted,
|
||||
discrete_instrument=discrete_instrument,
|
||||
discrete_treatment=discrete_treatment,
|
||||
treatment_featurizer=treatment_featurizer,
|
||||
categories=categories,
|
||||
cv=cv,
|
||||
mc_iters=mc_iters,
|
||||
|
@ -1177,6 +1190,11 @@ class LinearDRIV(StatsModelsCateEstimatorMixin, DRIV):
|
|||
discrete_treatment: bool, optional, default False
|
||||
Whether the treatment values should be treated as categorical, rather than continuous, quantities
|
||||
|
||||
treatment_featurizer : :term:`transformer`, optional
|
||||
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
|
||||
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
|
||||
If featurizer=None, then CATE is trained on T.
|
||||
|
||||
categories: 'auto' or list, default 'auto'
|
||||
The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
|
||||
The first category will be treated as the control treatment.
|
||||
|
@ -1283,6 +1301,7 @@ class LinearDRIV(StatsModelsCateEstimatorMixin, DRIV):
|
|||
opt_reweighted=False,
|
||||
discrete_instrument=False,
|
||||
discrete_treatment=False,
|
||||
treatment_featurizer=None,
|
||||
categories='auto',
|
||||
cv=2,
|
||||
mc_iters=None,
|
||||
|
@ -1305,6 +1324,7 @@ class LinearDRIV(StatsModelsCateEstimatorMixin, DRIV):
|
|||
opt_reweighted=opt_reweighted,
|
||||
discrete_instrument=discrete_instrument,
|
||||
discrete_treatment=discrete_treatment,
|
||||
treatment_featurizer=treatment_featurizer,
|
||||
categories=categories,
|
||||
cv=cv,
|
||||
mc_iters=mc_iters,
|
||||
|
@ -1493,6 +1513,11 @@ class SparseLinearDRIV(DebiasedLassoCateEstimatorMixin, DRIV):
|
|||
discrete_treatment: bool, optional, default False
|
||||
Whether the treatment values should be treated as categorical, rather than continuous, quantities
|
||||
|
||||
treatment_featurizer : :term:`transformer`, optional
|
||||
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
|
||||
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
|
||||
If featurizer=None, then CATE is trained on T.
|
||||
|
||||
categories: 'auto' or list, default 'auto'
|
||||
The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
|
||||
The first category will be treated as the control treatment.
|
||||
|
@ -1606,6 +1631,7 @@ class SparseLinearDRIV(DebiasedLassoCateEstimatorMixin, DRIV):
|
|||
opt_reweighted=False,
|
||||
discrete_instrument=False,
|
||||
discrete_treatment=False,
|
||||
treatment_featurizer=None,
|
||||
categories='auto',
|
||||
cv=2,
|
||||
mc_iters=None,
|
||||
|
@ -1635,6 +1661,7 @@ class SparseLinearDRIV(DebiasedLassoCateEstimatorMixin, DRIV):
|
|||
opt_reweighted=opt_reweighted,
|
||||
discrete_instrument=discrete_instrument,
|
||||
discrete_treatment=discrete_treatment,
|
||||
treatment_featurizer=treatment_featurizer,
|
||||
categories=categories,
|
||||
cv=cv,
|
||||
mc_iters=mc_iters,
|
||||
|
@ -1897,6 +1924,11 @@ class ForestDRIV(ForestModelFinalCateEstimatorMixin, DRIV):
|
|||
discrete_treatment: bool, optional, default False
|
||||
Whether the treatment values should be treated as categorical, rather than continuous, quantities
|
||||
|
||||
treatment_featurizer : :term:`transformer`, optional
|
||||
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
|
||||
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
|
||||
If featurizer=None, then CATE is trained on T.
|
||||
|
||||
categories: 'auto' or list, default 'auto'
|
||||
The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
|
||||
The first category will be treated as the control treatment.
|
||||
|
@ -2006,6 +2038,7 @@ class ForestDRIV(ForestModelFinalCateEstimatorMixin, DRIV):
|
|||
opt_reweighted=False,
|
||||
discrete_instrument=False,
|
||||
discrete_treatment=False,
|
||||
treatment_featurizer=None,
|
||||
categories='auto',
|
||||
cv=2,
|
||||
mc_iters=None,
|
||||
|
@ -2041,6 +2074,7 @@ class ForestDRIV(ForestModelFinalCateEstimatorMixin, DRIV):
|
|||
opt_reweighted=opt_reweighted,
|
||||
discrete_instrument=discrete_instrument,
|
||||
discrete_treatment=discrete_treatment,
|
||||
treatment_featurizer=treatment_featurizer,
|
||||
categories=categories,
|
||||
cv=cv,
|
||||
mc_iters=mc_iters,
|
||||
|
|
|
@ -40,8 +40,8 @@ from ._causal_tree import CausalTree
|
|||
from ..inference import NormalInferenceResults
|
||||
from ..inference._inference import Inference
|
||||
from ..utilities import (reshape, reshape_Y_T, MAX_RAND_SEED, check_inputs, _deprecate_positional,
|
||||
cross_product, inverse_onehot, check_input_arrays,
|
||||
_RegressionWrapper, deprecated)
|
||||
cross_product, inverse_onehot, check_input_arrays, jacify_featurizer,
|
||||
_RegressionWrapper, deprecated, ndim)
|
||||
from sklearn.model_selection import check_cv
|
||||
# TODO: consider working around relying on sklearn implementation details
|
||||
from ..sklearn_extensions.model_selection import _cross_val_predict
|
||||
|
@ -209,6 +209,7 @@ class BaseOrthoForest(TreatmentExpansionMixin, LinearCateEstimator):
|
|||
second_stage_parameter_estimator,
|
||||
moment_and_mean_gradient_estimator,
|
||||
discrete_treatment=False,
|
||||
treatment_featurizer=None,
|
||||
categories='auto',
|
||||
n_trees=500,
|
||||
min_leaf_size=10, max_depth=10,
|
||||
|
@ -244,6 +245,7 @@ class BaseOrthoForest(TreatmentExpansionMixin, LinearCateEstimator):
|
|||
# Fit check
|
||||
self.model_is_fitted = False
|
||||
self.discrete_treatment = discrete_treatment
|
||||
self.treatment_featurizer = treatment_featurizer
|
||||
self.backend = backend
|
||||
self.verbose = verbose
|
||||
self.batch_size = batch_size
|
||||
|
@ -312,7 +314,8 @@ class BaseOrthoForest(TreatmentExpansionMixin, LinearCateEstimator):
|
|||
|
||||
Returns
|
||||
-------
|
||||
Theta : matrix , shape (n, d_t)
|
||||
Theta : matrix , shape (n, d_f_t) where d_f_t is \
|
||||
the dimension of the featurized treatment. If treatment_featurizer is None, d_f_t = d_t
|
||||
Constant marginal CATE of each treatment for each sample.
|
||||
"""
|
||||
# TODO: Check performance
|
||||
|
@ -501,6 +504,11 @@ class DMLOrthoForest(BaseOrthoForest):
|
|||
one-hot-encoded and the model_T is treated as a classifier that must have a predict_proba
|
||||
method.
|
||||
|
||||
treatment_featurizer : :term:`transformer`, optional
|
||||
Must support fit_transform and transform. Used to create composite treatment in the final CATE regression.
|
||||
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
|
||||
If featurizer=None, then CATE is trained on T.
|
||||
|
||||
categories : array like or 'auto', optional (default='auto')
|
||||
A list of pre-specified treatment categories. If 'auto' then categories are automatically
|
||||
recognized at fit time.
|
||||
|
@ -540,6 +548,7 @@ class DMLOrthoForest(BaseOrthoForest):
|
|||
global_residualization=False,
|
||||
global_res_cv=2,
|
||||
discrete_treatment=False,
|
||||
treatment_featurizer=None,
|
||||
categories='auto',
|
||||
n_jobs=-1,
|
||||
backend='loky',
|
||||
|
@ -567,6 +576,7 @@ class DMLOrthoForest(BaseOrthoForest):
|
|||
self.random_state = check_random_state(random_state)
|
||||
self.global_residualization = global_residualization
|
||||
self.global_res_cv = global_res_cv
|
||||
self.treatment_featurizer = treatment_featurizer
|
||||
# Define nuisance estimators
|
||||
nuisance_estimator = _DMLOrthoForest_nuisance_estimator_generator(
|
||||
self.model_T, self.model_Y, self.random_state, second_stage=False,
|
||||
|
@ -580,6 +590,7 @@ class DMLOrthoForest(BaseOrthoForest):
|
|||
self.lambda_reg)
|
||||
# Define
|
||||
moment_and_mean_gradient_estimator = _DMLOrthoForest_moment_and_mean_gradient_estimator_func
|
||||
|
||||
super().__init__(
|
||||
nuisance_estimator,
|
||||
second_stage_nuisance_estimator,
|
||||
|
@ -596,6 +607,7 @@ class DMLOrthoForest(BaseOrthoForest):
|
|||
verbose=verbose,
|
||||
batch_size=batch_size,
|
||||
discrete_treatment=discrete_treatment,
|
||||
treatment_featurizer=treatment_featurizer,
|
||||
categories=categories,
|
||||
random_state=self.random_state)
|
||||
|
||||
|
@ -635,6 +647,9 @@ class DMLOrthoForest(BaseOrthoForest):
|
|||
"""
|
||||
self._set_input_names(Y, T, X, set_flag=True)
|
||||
Y, T, X, W = check_inputs(Y, T, X, W)
|
||||
assert not (self.discrete_treatment and self.treatment_featurizer), "Treatment featurization " \
|
||||
"is not supported when treatment is discrete"
|
||||
|
||||
if self.discrete_treatment:
|
||||
categories = self.categories
|
||||
if categories != 'auto':
|
||||
|
@ -643,6 +658,12 @@ class DMLOrthoForest(BaseOrthoForest):
|
|||
d_t_in = T.shape[1:]
|
||||
T = self.transformer.fit_transform(T.reshape(-1, 1))
|
||||
self._d_t = T.shape[1:]
|
||||
elif self.treatment_featurizer:
|
||||
self._original_treatment_featurizer = clone(self.treatment_featurizer, safe=False)
|
||||
self.transformer = jacify_featurizer(self.treatment_featurizer)
|
||||
d_t_in = T.shape[1:]
|
||||
T = self.transformer.fit_transform(T)
|
||||
self._d_t = np.shape(T)[1:]
|
||||
|
||||
if self.global_residualization:
|
||||
cv = check_cv(self.global_res_cv, y=T, classifier=self.discrete_treatment)
|
||||
|
@ -654,7 +675,7 @@ class DMLOrthoForest(BaseOrthoForest):
|
|||
|
||||
# weirdness of wrap_fit. We need to store d_t_in. But because wrap_fit decorates the parent
|
||||
# fit, we need to set explicitly d_t_in here after super fit is called.
|
||||
if self.discrete_treatment:
|
||||
if self.discrete_treatment or self.treatment_featurizer:
|
||||
self._d_t_in = d_t_in
|
||||
return self
|
||||
|
||||
|
@ -995,12 +1016,44 @@ class DROrthoForest(BaseOrthoForest):
|
|||
self._d_t_in = d_t_in
|
||||
return self
|
||||
|
||||
# override only so that we can exclude treatment featurization verbiage in docstring
|
||||
def const_marginal_effect(self, X):
|
||||
"""Calculate the constant marginal CATE θ(·) conditional on a vector of features X.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array-like, shape (n, d_x)
|
||||
Feature vector that captures heterogeneity.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Theta : matrix , shape (n, d_t)
|
||||
Constant marginal CATE of each treatment for each sample.
|
||||
"""
|
||||
X = check_array(X)
|
||||
# Override to flatten output if T is flat
|
||||
effects = super().const_marginal_effect(X=X)
|
||||
return effects.reshape((-1,) + self._d_y + self._d_t)
|
||||
const_marginal_effect.__doc__ = BaseOrthoForest.const_marginal_effect.__doc__
|
||||
|
||||
# override only so that we can exclude treatment featurization verbiage in docstring
|
||||
def const_marginal_ate(self, X=None):
|
||||
"""
|
||||
Calculate the average constant marginal CATE :math:`E_X[\\theta(X)]`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X: optional (m, d_x) matrix or None (Default=None)
|
||||
Features for each sample.
|
||||
|
||||
Returns
|
||||
-------
|
||||
theta: (d_y, d_t) matrix
|
||||
Average constant marginal CATE of each treatment on each outcome.
|
||||
Note that when Y or T is a vector rather than a 2-dimensional array,
|
||||
the corresponding singleton dimensions in the output will be collapsed
|
||||
(e.g. if both are vectors, then the output of this method will be a scalar)
|
||||
"""
|
||||
return super().const_marginal_ate(X=X)
|
||||
|
||||
@staticmethod
|
||||
def nuisance_estimator_generator(propensity_model, model_Y, random_state=None, second_stage=False):
|
||||
|
@ -1162,6 +1215,10 @@ class BLBInference(Inference):
|
|||
This is called after the estimator's fit.
|
||||
"""
|
||||
self._estimator = estimator
|
||||
self._d_t = estimator._d_t
|
||||
self._d_y = estimator._d_y
|
||||
self.d_t = self._d_t[0] if self._d_t else 1
|
||||
self.d_y = self._d_y[0] if self._d_y else 1
|
||||
# Test whether the input estimator is supported
|
||||
if not hasattr(self._estimator, "_predict"):
|
||||
raise TypeError("Unsupported estimator of type {}.".format(self._estimator.__class__.__name__) +
|
||||
|
@ -1300,6 +1357,81 @@ class BLBInference(Inference):
|
|||
output_names=self._estimator.cate_output_names(),
|
||||
treatment_names=self._estimator.cate_treatment_names())
|
||||
|
||||
def _marginal_effect_inference_helper(self, T, X):
|
||||
if not self._estimator._original_treatment_featurizer:
|
||||
return self.const_marginal_effect_inference(X)
|
||||
|
||||
X, T = check_input_arrays(X, T)
|
||||
X, T = self._estimator._expand_treatments(X, T, transform=False)
|
||||
|
||||
feat_T = self._estimator.transformer.transform(T)
|
||||
|
||||
jac_T = self._estimator.transformer.jac(T)
|
||||
|
||||
params, cov = zip(*(self._predict_wrapper(X)))
|
||||
params = np.array(params)
|
||||
cov = np.array(cov)
|
||||
|
||||
eff_einsum_str = 'mf, mtf-> mt'
|
||||
|
||||
# conditionally expand jacobian dimensions to align with einsum str
|
||||
jac_index = [slice(None), slice(None), slice(None)]
|
||||
if ndim(T) == 1:
|
||||
jac_index[1] = None
|
||||
if ndim(feat_T) == 1:
|
||||
jac_index[2] = None
|
||||
|
||||
# Calculate the effects
|
||||
eff = np.einsum(eff_einsum_str, params, jac_T[tuple(jac_index)])
|
||||
|
||||
# Calculate the standard deviations for the effects
|
||||
d_t_orig = T.shape[1:]
|
||||
d_t_orig = d_t_orig[0] if d_t_orig else 1
|
||||
self.d_t_orig = d_t_orig
|
||||
output_shape = [X.shape[0]]
|
||||
if T.shape[1:]:
|
||||
output_shape.append(T.shape[1])
|
||||
scales = np.zeros(shape=output_shape)
|
||||
|
||||
for i in range(d_t_orig):
|
||||
# conditionally index multiple dimensions depending on shapes of T, Y and feat_T
|
||||
jac_index = [slice(None)]
|
||||
me_index = [slice(None)]
|
||||
if T.shape[1:]:
|
||||
jac_index.append(i)
|
||||
me_index.append(i)
|
||||
|
||||
if feat_T.shape[1:]: # if featurized T is not a vector
|
||||
jac_index.append(slice(None))
|
||||
else:
|
||||
jac_index.append(None)
|
||||
|
||||
jac = jac_T[tuple(jac_index)]
|
||||
final = np.einsum('mj, mjk, mk -> m', jac, cov, jac)
|
||||
scales[tuple(me_index)] = final
|
||||
|
||||
eff = eff.reshape((-1,) + self._d_y + T.shape[1:])
|
||||
scales = scales.reshape((-1,) + self._d_y + T.shape[1:])
|
||||
return eff, scales
|
||||
|
||||
def marginal_effect_inference(self, T, X):
|
||||
if self._estimator._original_treatment_featurizer is None:
|
||||
return self.const_marginal_effect_inference(X)
|
||||
|
||||
eff, scales = self._marginal_effect_inference_helper(T, X)
|
||||
|
||||
d_y = self._d_y[0] if self._d_y else 1
|
||||
d_t = self._d_t[0] if self._d_t else 1
|
||||
|
||||
return NormalInferenceResults(d_t=self.d_t_orig, d_y=d_y,
|
||||
pred=eff, pred_stderr=scales, mean_pred_stderr=None, inf_type='effect',
|
||||
feature_names=self._estimator.cate_feature_names(),
|
||||
output_names=self._estimator.cate_output_names(),
|
||||
treatment_names=self._estimator.cate_treatment_names())
|
||||
|
||||
def marginal_effect_interval(self, T, X, *, alpha=0.05):
|
||||
return self.marginal_effect_inference(T, X).conf_int(alpha=alpha)
|
||||
|
||||
def _predict_wrapper(self, X=None):
|
||||
return self._estimator._predict(X, stderr=True)
|
||||
|
||||
|
|
|
@ -1988,7 +1988,11 @@ class StatsModels2SLS(_StatsModelsWrapper):
|
|||
|
||||
# check dimension of instruments is more than dimension of treatments
|
||||
if Z.shape[1] < T.shape[1]:
|
||||
raise AssertionError("The number of treatments couldn't be larger than the number of instruments!")
|
||||
raise AssertionError("The number of treatments couldn't be larger than the number of instruments!" +
|
||||
" If you are using a treatment featurizer, make sure the number of featurized" +
|
||||
" treatments is less than or equal to the number of instruments. You can either" +
|
||||
" featurize the instrument yourself, or consider using projection = True" +
|
||||
" along with a flexible model_t_xwz.")
|
||||
|
||||
# weight X and y
|
||||
weighted_Z = Z * np.sqrt(sample_weight).reshape(-1, 1)
|
||||
|
|
|
@ -40,6 +40,10 @@ def rand_sol(A, b):
|
|||
class TestDML(unittest.TestCase):
|
||||
|
||||
def test_cate_api(self):
|
||||
treatment_featurizations = [None]
|
||||
self._test_cate_api(treatment_featurizations)
|
||||
|
||||
def _test_cate_api(self, treatment_featurizations):
|
||||
"""Test that we correctly implement the CATE API."""
|
||||
n_c = 20 # number of rows for continuous models
|
||||
n_d = 30 # number of rows for discrete models
|
||||
|
@ -63,284 +67,293 @@ class TestDML(unittest.TestCase):
|
|||
|
||||
for d_t in [2, 1, -1]:
|
||||
for is_discrete in [True, False] if d_t <= 1 else [False]:
|
||||
for d_y in [3, 1, -1]:
|
||||
for d_x in [2, None]:
|
||||
for d_w in [2, None]:
|
||||
n = n_d if is_discrete else n_c
|
||||
W, X, Y, T = [make_random(n, is_discrete, d)
|
||||
for is_discrete, d in [(False, d_w),
|
||||
(False, d_x),
|
||||
(False, d_y),
|
||||
(is_discrete, d_t)]]
|
||||
for treatment_featurizer in treatment_featurizations:
|
||||
for d_y in [3, 1, -1]:
|
||||
for d_x in [2, None]:
|
||||
for d_w in [2, None]:
|
||||
n = n_d if is_discrete else n_c
|
||||
W, X, Y, T = [make_random(n, is_discrete, d)
|
||||
for is_discrete, d in [(False, d_w),
|
||||
(False, d_x),
|
||||
(False, d_y),
|
||||
(is_discrete, d_t)]]
|
||||
|
||||
for featurizer, fit_cate_intercept in\
|
||||
[(None, True),
|
||||
(PolynomialFeatures(degree=2, include_bias=False), True),
|
||||
(PolynomialFeatures(degree=2, include_bias=True), False)]:
|
||||
for featurizer, fit_cate_intercept in\
|
||||
[(None, True),
|
||||
(PolynomialFeatures(degree=2, include_bias=False), True),
|
||||
(PolynomialFeatures(degree=2, include_bias=True), False)]:
|
||||
|
||||
d_t_final = 2 if is_discrete else d_t
|
||||
|
||||
effect_shape = (n,) + ((d_y,) if d_y > 0 else ())
|
||||
effect_summaryframe_shape = (n * (d_y if d_y > 0 else 1), 6)
|
||||
marginal_effect_shape = ((n,) +
|
||||
((d_y,) if d_y > 0 else ()) +
|
||||
((d_t_final,) if d_t_final > 0 else ()))
|
||||
marginal_effect_summaryframe_shape = (n * (d_y if d_y > 0 else 1) *
|
||||
(d_t_final if d_t_final > 0 else 1), 6)
|
||||
|
||||
# since T isn't passed to const_marginal_effect, defaults to one row if X is None
|
||||
const_marginal_effect_shape = ((n if d_x else 1,) +
|
||||
((d_y,) if d_y > 0 else ()) +
|
||||
((d_t_final,) if d_t_final > 0 else ()))
|
||||
const_marginal_effect_summaryframe_shape = (
|
||||
(n if d_x else 1) * (d_y if d_y > 0 else 1) *
|
||||
(d_t_final if d_t_final > 0 else 1), 6)
|
||||
|
||||
fd_x = featurizer.fit_transform(X).shape[1:] if featurizer and d_x\
|
||||
else ((d_x,) if d_x else (0,))
|
||||
coef_shape = Y.shape[1:] + (T.shape[1:] if not is_discrete else (2,)) + fd_x
|
||||
|
||||
coef_summaryframe_shape = (
|
||||
(d_y if d_y > 0 else 1) * (fd_x[0] if fd_x[0] >
|
||||
0 else 1) * (d_t_final if d_t_final > 0 else 1), 6)
|
||||
intercept_shape = Y.shape[1:] + (T.shape[1:] if not is_discrete else (2,))
|
||||
intercept_summaryframe_shape = (
|
||||
(d_y if d_y > 0 else 1) * (d_t_final if d_t_final > 0 else 1), 6)
|
||||
|
||||
model_t = LogisticRegression() if is_discrete else Lasso()
|
||||
|
||||
all_infs = [None, 'auto', BootstrapInference(2)]
|
||||
|
||||
for est, multi, infs in\
|
||||
[(DML(model_y=Lasso(),
|
||||
model_t=model_t,
|
||||
model_final=Lasso(alpha=0.1, fit_intercept=False),
|
||||
featurizer=featurizer,
|
||||
fit_cate_intercept=fit_cate_intercept,
|
||||
discrete_treatment=is_discrete),
|
||||
True,
|
||||
[None] +
|
||||
([BootstrapInference(n_bootstrap_samples=20)] if not is_discrete else [])),
|
||||
(DML(model_y=Lasso(),
|
||||
model_t=model_t,
|
||||
model_final=StatsModelsRLM(fit_intercept=False),
|
||||
featurizer=featurizer,
|
||||
fit_cate_intercept=fit_cate_intercept,
|
||||
discrete_treatment=is_discrete),
|
||||
True,
|
||||
['auto']),
|
||||
(LinearDML(model_y=Lasso(),
|
||||
model_t='auto',
|
||||
featurizer=featurizer,
|
||||
fit_cate_intercept=fit_cate_intercept,
|
||||
discrete_treatment=is_discrete),
|
||||
True,
|
||||
all_infs),
|
||||
(SparseLinearDML(model_y=WeightedLasso(),
|
||||
model_t=model_t,
|
||||
featurizer=featurizer,
|
||||
fit_cate_intercept=fit_cate_intercept,
|
||||
discrete_treatment=is_discrete),
|
||||
True,
|
||||
[None, 'auto'] +
|
||||
([BootstrapInference(n_bootstrap_samples=20)] if not is_discrete else [])),
|
||||
(KernelDML(model_y=WeightedLasso(),
|
||||
model_t=model_t,
|
||||
fit_cate_intercept=fit_cate_intercept,
|
||||
discrete_treatment=is_discrete),
|
||||
False,
|
||||
[None]),
|
||||
(CausalForestDML(model_y=WeightedLasso(),
|
||||
model_t=model_t,
|
||||
featurizer=featurizer,
|
||||
n_estimators=4,
|
||||
n_jobs=1,
|
||||
discrete_treatment=is_discrete),
|
||||
True,
|
||||
['auto', 'blb'])]:
|
||||
|
||||
if not (multi) and d_y > 1:
|
||||
if is_discrete and treatment_featurizer:
|
||||
continue
|
||||
|
||||
if X is None and isinstance(est, CausalForestDML):
|
||||
continue
|
||||
d_t_final = 2 if is_discrete else d_t
|
||||
|
||||
# ensure we can serialize the unfit estimator
|
||||
pickle.dumps(est)
|
||||
effect_shape = (n,) + ((d_y,) if d_y > 0 else ())
|
||||
effect_summaryframe_shape = (n * (d_y if d_y > 0 else 1), 6)
|
||||
marginal_effect_shape = ((n,) +
|
||||
((d_y,) if d_y > 0 else ()) +
|
||||
((d_t_final,) if d_t_final > 0 else ()))
|
||||
marginal_effect_summaryframe_shape = (n * (d_y if d_y > 0 else 1) *
|
||||
(d_t_final if d_t_final > 0 else 1), 6)
|
||||
|
||||
for inf in infs:
|
||||
with self.subTest(d_w=d_w, d_x=d_x, d_y=d_y, d_t=d_t,
|
||||
is_discrete=is_discrete, est=est, inf=inf):
|
||||
# since T isn't passed to const_marginal_effect, defaults to one row if X is None
|
||||
const_marginal_effect_shape = ((n if d_x else 1,) +
|
||||
((d_y,) if d_y > 0 else ()) +
|
||||
((d_t_final,) if d_t_final > 0 else ()))
|
||||
const_marginal_effect_summaryframe_shape = (
|
||||
(n if d_x else 1) * (d_y if d_y > 0 else 1) *
|
||||
(d_t_final if d_t_final > 0 else 1), 6)
|
||||
|
||||
if X is None and (not fit_cate_intercept):
|
||||
with pytest.raises(AttributeError):
|
||||
est.fit(Y, T, X=X, W=W, inference=inf)
|
||||
continue
|
||||
fd_x = featurizer.fit_transform(X).shape[1:] if featurizer and d_x\
|
||||
else ((d_x,) if d_x else (0,))
|
||||
coef_shape = Y.shape[1:] + (T.shape[1:] if not is_discrete else (2,)) + fd_x
|
||||
|
||||
est.fit(Y, T, X=X, W=W, inference=inf)
|
||||
coef_summaryframe_shape = (
|
||||
(d_y if d_y > 0 else 1) * (fd_x[0] if fd_x[0] >
|
||||
0 else 1) * (d_t_final if d_t_final > 0 else 1), 6)
|
||||
intercept_shape = Y.shape[1:] + (T.shape[1:] if not is_discrete else (2,))
|
||||
intercept_summaryframe_shape = (
|
||||
(d_y if d_y > 0 else 1) * (d_t_final if d_t_final > 0 else 1), 6)
|
||||
|
||||
# ensure we can pickle the fit estimator
|
||||
pickle.dumps(est)
|
||||
model_t = LogisticRegression() if is_discrete else Lasso()
|
||||
|
||||
# make sure we can call the marginal_effect and effect methods
|
||||
const_marg_eff = est.const_marginal_effect(X)
|
||||
marg_eff = est.marginal_effect(T, X)
|
||||
self.assertEqual(shape(marg_eff), marginal_effect_shape)
|
||||
self.assertEqual(shape(const_marg_eff), const_marginal_effect_shape)
|
||||
all_infs = [None, 'auto', BootstrapInference(2)]
|
||||
|
||||
np.testing.assert_allclose(
|
||||
marg_eff if d_x else marg_eff[0:1], const_marg_eff)
|
||||
for est, multi, infs in\
|
||||
[(DML(model_y=Lasso(),
|
||||
model_t=model_t,
|
||||
model_final=Lasso(alpha=0.1, fit_intercept=False),
|
||||
featurizer=featurizer,
|
||||
fit_cate_intercept=fit_cate_intercept,
|
||||
discrete_treatment=is_discrete,
|
||||
treatment_featurizer=treatment_featurizer),
|
||||
True,
|
||||
[None] +
|
||||
([BootstrapInference(n_bootstrap_samples=20)] if not is_discrete else [])),
|
||||
(DML(model_y=Lasso(),
|
||||
model_t=model_t,
|
||||
model_final=StatsModelsRLM(fit_intercept=False),
|
||||
featurizer=featurizer,
|
||||
fit_cate_intercept=fit_cate_intercept,
|
||||
discrete_treatment=is_discrete,
|
||||
treatment_featurizer=treatment_featurizer),
|
||||
True,
|
||||
['auto']),
|
||||
(LinearDML(model_y=Lasso(),
|
||||
model_t='auto',
|
||||
featurizer=featurizer,
|
||||
fit_cate_intercept=fit_cate_intercept,
|
||||
discrete_treatment=is_discrete,
|
||||
treatment_featurizer=treatment_featurizer),
|
||||
True,
|
||||
all_infs),
|
||||
(SparseLinearDML(model_y=WeightedLasso(),
|
||||
model_t=model_t,
|
||||
featurizer=featurizer,
|
||||
fit_cate_intercept=fit_cate_intercept,
|
||||
discrete_treatment=is_discrete,
|
||||
treatment_featurizer=treatment_featurizer),
|
||||
True,
|
||||
[None, 'auto'] +
|
||||
([BootstrapInference(n_bootstrap_samples=20)] if not is_discrete else [])),
|
||||
(KernelDML(model_y=WeightedLasso(),
|
||||
model_t=model_t,
|
||||
fit_cate_intercept=fit_cate_intercept,
|
||||
discrete_treatment=is_discrete,
|
||||
treatment_featurizer=treatment_featurizer),
|
||||
False,
|
||||
[None]),
|
||||
(CausalForestDML(model_y=WeightedLasso(),
|
||||
model_t=model_t,
|
||||
featurizer=featurizer,
|
||||
n_estimators=4,
|
||||
n_jobs=1,
|
||||
discrete_treatment=is_discrete),
|
||||
True,
|
||||
['auto', 'blb'])]:
|
||||
|
||||
assert isinstance(est.score_, float)
|
||||
for score_list in est.nuisance_scores_y:
|
||||
for score in score_list:
|
||||
assert isinstance(score, float)
|
||||
for score_list in est.nuisance_scores_t:
|
||||
for score in score_list:
|
||||
assert isinstance(score, float)
|
||||
if not (multi) and d_y > 1:
|
||||
continue
|
||||
|
||||
T0 = np.full_like(T, 'a') if is_discrete else np.zeros_like(T)
|
||||
eff = est.effect(X, T0=T0, T1=T)
|
||||
self.assertEqual(shape(eff), effect_shape)
|
||||
if X is None and isinstance(est, CausalForestDML):
|
||||
continue
|
||||
|
||||
if ((not isinstance(est, KernelDML)) and
|
||||
(not isinstance(est, CausalForestDML))):
|
||||
self.assertEqual(shape(est.coef_), coef_shape)
|
||||
if fit_cate_intercept:
|
||||
self.assertEqual(shape(est.intercept_), intercept_shape)
|
||||
else:
|
||||
# ensure we can serialize the unfit estimator
|
||||
pickle.dumps(est)
|
||||
|
||||
for inf in infs:
|
||||
with self.subTest(d_w=d_w, d_x=d_x, d_y=d_y, d_t=d_t,
|
||||
is_discrete=is_discrete, est=est, inf=inf):
|
||||
|
||||
if X is None and (not fit_cate_intercept):
|
||||
with pytest.raises(AttributeError):
|
||||
self.assertEqual(shape(est.intercept_), intercept_shape)
|
||||
est.fit(Y, T, X=X, W=W, inference=inf)
|
||||
continue
|
||||
|
||||
est.fit(Y, T, X=X, W=W, inference=inf)
|
||||
|
||||
# ensure we can pickle the fit estimator
|
||||
pickle.dumps(est)
|
||||
|
||||
# make sure we can call the marginal_effect and effect methods
|
||||
const_marg_eff = est.const_marginal_effect(X)
|
||||
marg_eff = est.marginal_effect(T, X)
|
||||
self.assertEqual(shape(marg_eff), marginal_effect_shape)
|
||||
self.assertEqual(shape(const_marg_eff), const_marginal_effect_shape)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
marg_eff if d_x else marg_eff[0:1], const_marg_eff)
|
||||
|
||||
assert isinstance(est.score_, float)
|
||||
for score_list in est.nuisance_scores_y:
|
||||
for score in score_list:
|
||||
assert isinstance(score, float)
|
||||
for score_list in est.nuisance_scores_t:
|
||||
for score in score_list:
|
||||
assert isinstance(score, float)
|
||||
|
||||
T0 = np.full_like(T, 'a') if is_discrete else np.zeros_like(T)
|
||||
eff = est.effect(X, T0=T0, T1=T)
|
||||
self.assertEqual(shape(eff), effect_shape)
|
||||
|
||||
if inf is not None:
|
||||
const_marg_eff_int = est.const_marginal_effect_interval(X)
|
||||
marg_eff_int = est.marginal_effect_interval(T, X)
|
||||
self.assertEqual(shape(marg_eff_int),
|
||||
(2,) + marginal_effect_shape)
|
||||
self.assertEqual(shape(const_marg_eff_int),
|
||||
(2,) + const_marginal_effect_shape)
|
||||
self.assertEqual(shape(est.effect_interval(X, T0=T0, T1=T)),
|
||||
(2,) + effect_shape)
|
||||
if ((not isinstance(est, KernelDML)) and
|
||||
(not isinstance(est, CausalForestDML))):
|
||||
self.assertEqual(shape(est.coef__interval()),
|
||||
(2,) + coef_shape)
|
||||
self.assertEqual(shape(est.coef_), coef_shape)
|
||||
if fit_cate_intercept:
|
||||
self.assertEqual(shape(est.intercept__interval()),
|
||||
(2,) + intercept_shape)
|
||||
self.assertEqual(shape(est.intercept_), intercept_shape)
|
||||
else:
|
||||
with pytest.raises(AttributeError):
|
||||
self.assertEqual(shape(est.intercept_), intercept_shape)
|
||||
|
||||
if inf is not None:
|
||||
const_marg_eff_int = est.const_marginal_effect_interval(X)
|
||||
marg_eff_int = est.marginal_effect_interval(T, X)
|
||||
self.assertEqual(shape(marg_eff_int),
|
||||
(2,) + marginal_effect_shape)
|
||||
self.assertEqual(shape(const_marg_eff_int),
|
||||
(2,) + const_marginal_effect_shape)
|
||||
self.assertEqual(shape(est.effect_interval(X, T0=T0, T1=T)),
|
||||
(2,) + effect_shape)
|
||||
if ((not isinstance(est, KernelDML)) and
|
||||
(not isinstance(est, CausalForestDML))):
|
||||
self.assertEqual(shape(est.coef__interval()),
|
||||
(2,) + coef_shape)
|
||||
if fit_cate_intercept:
|
||||
self.assertEqual(shape(est.intercept__interval()),
|
||||
(2,) + intercept_shape)
|
||||
else:
|
||||
with pytest.raises(AttributeError):
|
||||
self.assertEqual(shape(est.intercept__interval()),
|
||||
(2,) + intercept_shape)
|
||||
|
||||
const_marg_effect_inf = est.const_marginal_effect_inference(X)
|
||||
T1 = np.full_like(T, 'b') if is_discrete else T
|
||||
effect_inf = est.effect_inference(X, T0=T0, T1=T1)
|
||||
marg_effect_inf = est.marginal_effect_inference(T, X)
|
||||
# test const marginal inference
|
||||
self.assertEqual(shape(const_marg_effect_inf.summary_frame()),
|
||||
const_marginal_effect_summaryframe_shape)
|
||||
self.assertEqual(shape(const_marg_effect_inf.point_estimate),
|
||||
const_marginal_effect_shape)
|
||||
self.assertEqual(shape(const_marg_effect_inf.stderr),
|
||||
const_marginal_effect_shape)
|
||||
self.assertEqual(shape(const_marg_effect_inf.var),
|
||||
const_marginal_effect_shape)
|
||||
self.assertEqual(shape(const_marg_effect_inf.pvalue()),
|
||||
const_marginal_effect_shape)
|
||||
self.assertEqual(shape(const_marg_effect_inf.zstat()),
|
||||
const_marginal_effect_shape)
|
||||
self.assertEqual(shape(const_marg_effect_inf.conf_int()),
|
||||
(2,) + const_marginal_effect_shape)
|
||||
np.testing.assert_array_almost_equal(
|
||||
const_marg_effect_inf.conf_int()[0],
|
||||
const_marg_eff_int[0], decimal=5)
|
||||
const_marg_effect_inf.population_summary()._repr_html_()
|
||||
const_marg_effect_inf = est.const_marginal_effect_inference(X)
|
||||
T1 = np.full_like(T, 'b') if is_discrete else T
|
||||
effect_inf = est.effect_inference(X, T0=T0, T1=T1)
|
||||
marg_effect_inf = est.marginal_effect_inference(T, X)
|
||||
# test const marginal inference
|
||||
self.assertEqual(shape(const_marg_effect_inf.summary_frame()),
|
||||
const_marginal_effect_summaryframe_shape)
|
||||
self.assertEqual(shape(const_marg_effect_inf.point_estimate),
|
||||
const_marginal_effect_shape)
|
||||
self.assertEqual(shape(const_marg_effect_inf.stderr),
|
||||
const_marginal_effect_shape)
|
||||
self.assertEqual(shape(const_marg_effect_inf.var),
|
||||
const_marginal_effect_shape)
|
||||
self.assertEqual(shape(const_marg_effect_inf.pvalue()),
|
||||
const_marginal_effect_shape)
|
||||
self.assertEqual(shape(const_marg_effect_inf.zstat()),
|
||||
const_marginal_effect_shape)
|
||||
self.assertEqual(shape(const_marg_effect_inf.conf_int()),
|
||||
(2,) + const_marginal_effect_shape)
|
||||
np.testing.assert_array_almost_equal(
|
||||
const_marg_effect_inf.conf_int()[0],
|
||||
const_marg_eff_int[0], decimal=5)
|
||||
const_marg_effect_inf.population_summary()._repr_html_()
|
||||
|
||||
# test effect inference
|
||||
self.assertEqual(shape(effect_inf.summary_frame()),
|
||||
effect_summaryframe_shape)
|
||||
self.assertEqual(shape(effect_inf.point_estimate),
|
||||
effect_shape)
|
||||
self.assertEqual(shape(effect_inf.stderr),
|
||||
effect_shape)
|
||||
self.assertEqual(shape(effect_inf.var),
|
||||
effect_shape)
|
||||
self.assertEqual(shape(effect_inf.pvalue()),
|
||||
effect_shape)
|
||||
self.assertEqual(shape(effect_inf.zstat()),
|
||||
effect_shape)
|
||||
self.assertEqual(shape(effect_inf.conf_int()),
|
||||
(2,) + effect_shape)
|
||||
np.testing.assert_array_almost_equal(
|
||||
effect_inf.conf_int()[0],
|
||||
est.effect_interval(X, T0=T0, T1=T1)[0], decimal=5)
|
||||
effect_inf.population_summary()._repr_html_()
|
||||
# test effect inference
|
||||
self.assertEqual(shape(effect_inf.summary_frame()),
|
||||
effect_summaryframe_shape)
|
||||
self.assertEqual(shape(effect_inf.point_estimate),
|
||||
effect_shape)
|
||||
self.assertEqual(shape(effect_inf.stderr),
|
||||
effect_shape)
|
||||
self.assertEqual(shape(effect_inf.var),
|
||||
effect_shape)
|
||||
self.assertEqual(shape(effect_inf.pvalue()),
|
||||
effect_shape)
|
||||
self.assertEqual(shape(effect_inf.zstat()),
|
||||
effect_shape)
|
||||
self.assertEqual(shape(effect_inf.conf_int()),
|
||||
(2,) + effect_shape)
|
||||
np.testing.assert_array_almost_equal(
|
||||
effect_inf.conf_int()[0],
|
||||
est.effect_interval(X, T0=T0, T1=T1)[0], decimal=5)
|
||||
effect_inf.population_summary()._repr_html_()
|
||||
|
||||
# test marginal effect inference
|
||||
self.assertEqual(shape(marg_effect_inf.summary_frame()),
|
||||
marginal_effect_summaryframe_shape)
|
||||
self.assertEqual(shape(marg_effect_inf.point_estimate),
|
||||
marginal_effect_shape)
|
||||
self.assertEqual(shape(marg_effect_inf.stderr),
|
||||
marginal_effect_shape)
|
||||
self.assertEqual(shape(marg_effect_inf.var),
|
||||
marginal_effect_shape)
|
||||
self.assertEqual(shape(marg_effect_inf.pvalue()),
|
||||
marginal_effect_shape)
|
||||
self.assertEqual(shape(marg_effect_inf.zstat()),
|
||||
marginal_effect_shape)
|
||||
self.assertEqual(shape(marg_effect_inf.conf_int()),
|
||||
(2,) + marginal_effect_shape)
|
||||
np.testing.assert_array_almost_equal(
|
||||
marg_effect_inf.conf_int()[0], marg_eff_int[0], decimal=5)
|
||||
marg_effect_inf.population_summary()._repr_html_()
|
||||
# test marginal effect inference
|
||||
self.assertEqual(shape(marg_effect_inf.summary_frame()),
|
||||
marginal_effect_summaryframe_shape)
|
||||
self.assertEqual(shape(marg_effect_inf.point_estimate),
|
||||
marginal_effect_shape)
|
||||
self.assertEqual(shape(marg_effect_inf.stderr),
|
||||
marginal_effect_shape)
|
||||
self.assertEqual(shape(marg_effect_inf.var),
|
||||
marginal_effect_shape)
|
||||
self.assertEqual(shape(marg_effect_inf.pvalue()),
|
||||
marginal_effect_shape)
|
||||
self.assertEqual(shape(marg_effect_inf.zstat()),
|
||||
marginal_effect_shape)
|
||||
self.assertEqual(shape(marg_effect_inf.conf_int()),
|
||||
(2,) + marginal_effect_shape)
|
||||
np.testing.assert_array_almost_equal(
|
||||
marg_effect_inf.conf_int()[0], marg_eff_int[0], decimal=5)
|
||||
marg_effect_inf.population_summary()._repr_html_()
|
||||
|
||||
# test coef__inference and intercept__inference
|
||||
if ((not isinstance(est, KernelDML)) and
|
||||
(not isinstance(est, CausalForestDML))):
|
||||
if X is not None:
|
||||
self.assertEqual(
|
||||
shape(est.coef__inference().summary_frame()),
|
||||
coef_summaryframe_shape)
|
||||
np.testing.assert_array_almost_equal(
|
||||
est.coef__inference().conf_int()
|
||||
[0], est.coef__interval()[0], decimal=5)
|
||||
# test coef__inference and intercept__inference
|
||||
if ((not isinstance(est, KernelDML)) and
|
||||
(not isinstance(est, CausalForestDML))):
|
||||
if X is not None:
|
||||
self.assertEqual(
|
||||
shape(est.coef__inference().summary_frame()),
|
||||
coef_summaryframe_shape)
|
||||
np.testing.assert_array_almost_equal(
|
||||
est.coef__inference().conf_int()
|
||||
[0], est.coef__interval()[0], decimal=5)
|
||||
|
||||
if fit_cate_intercept:
|
||||
cm = ExitStack()
|
||||
# ExitStack can be used as a "do nothing" ContextManager
|
||||
else:
|
||||
cm = pytest.raises(AttributeError)
|
||||
with cm:
|
||||
self.assertEqual(shape(est.intercept__inference().
|
||||
summary_frame()),
|
||||
intercept_summaryframe_shape)
|
||||
np.testing.assert_array_almost_equal(
|
||||
est.intercept__inference().conf_int()
|
||||
[0], est.intercept__interval()[0], decimal=5)
|
||||
if fit_cate_intercept:
|
||||
cm = ExitStack()
|
||||
# ExitStack can be used as a "do nothing" ContextManager
|
||||
else:
|
||||
cm = pytest.raises(AttributeError)
|
||||
with cm:
|
||||
self.assertEqual(shape(est.intercept__inference().
|
||||
summary_frame()),
|
||||
intercept_summaryframe_shape)
|
||||
np.testing.assert_array_almost_equal(
|
||||
est.intercept__inference().conf_int()
|
||||
[0], est.intercept__interval()[0], decimal=5)
|
||||
|
||||
est.summary()
|
||||
est.summary()
|
||||
|
||||
est.score(Y, T, X, W)
|
||||
est.score(Y, T, X, W)
|
||||
|
||||
if isinstance(est, CausalForestDML):
|
||||
np.testing.assert_array_equal(est.feature_importances_.shape,
|
||||
((d_y,) if d_y > 0 else ()) + fd_x)
|
||||
if isinstance(est, CausalForestDML):
|
||||
np.testing.assert_array_equal(est.feature_importances_.shape,
|
||||
((d_y,) if d_y > 0 else ()) + fd_x)
|
||||
|
||||
# make sure we can call effect with implied scalar treatments,
|
||||
# no matter the dimensions of T, and also that we warn when there
|
||||
# are multiple treatments
|
||||
if d_t > 1:
|
||||
cm = self.assertWarns(Warning)
|
||||
else:
|
||||
# ExitStack can be used as a "do nothing" ContextManager
|
||||
cm = ExitStack()
|
||||
with cm:
|
||||
effect_shape2 = (n if d_x else 1,) + ((d_y,) if d_y > 0 else ())
|
||||
eff = est.effect(X) if not is_discrete else est.effect(
|
||||
X, T0='a', T1='b')
|
||||
self.assertEqual(shape(eff), effect_shape2)
|
||||
# make sure we can call effect with implied scalar treatments,
|
||||
# no matter the dimensions of T, and also that we warn when there
|
||||
# are multiple treatments
|
||||
if d_t > 1:
|
||||
cm = self.assertWarns(Warning)
|
||||
else:
|
||||
# ExitStack can be used as a "do nothing" ContextManager
|
||||
cm = ExitStack()
|
||||
with cm:
|
||||
effect_shape2 = (n if d_x else 1,) + ((d_y,) if d_y > 0 else ())
|
||||
eff = est.effect(X) if not is_discrete else est.effect(
|
||||
X, T0='a', T1='b')
|
||||
self.assertEqual(shape(eff), effect_shape2)
|
||||
|
||||
def test_cate_api_nonparam(self):
|
||||
"""Test that we correctly implement the CATE API."""
|
||||
|
@ -1183,3 +1196,38 @@ class TestDML(unittest.TestCase):
|
|||
est = LinearDML(model_y=LassoCV(cv=5), model_t=LassoCV(cv=5), cv=GroupKFold(2))
|
||||
with pytest.raises(Exception):
|
||||
est.fit(y, t, groups=groups)
|
||||
|
||||
def test_treatment_names(self):
|
||||
Y = np.random.normal(size=(100, 1))
|
||||
T = np.random.binomial(n=1, p=0.5, size=(100, 1))
|
||||
X = np.random.normal(size=(100, 3))
|
||||
|
||||
Ts = [
|
||||
T,
|
||||
pd.DataFrame(T, columns=[0])
|
||||
]
|
||||
|
||||
init_args_list = [
|
||||
{'discrete_treatment': True},
|
||||
{'treatment_featurizer': PolynomialFeatures(degree=2, include_bias=False)},
|
||||
{'treatment_featurizer': FunctionTransformer(lambda x: np.hstack([x, np.sqrt(x)]))},
|
||||
]
|
||||
|
||||
for T in Ts:
|
||||
for init_args in init_args_list:
|
||||
est = LinearDML(**init_args).fit(Y=Y, T=T, X=X)
|
||||
t_name = '0' if isinstance(T, pd.DataFrame) else 'T0' # default treatment name
|
||||
postfixes = ['_1'] if 'discrete_treatment' in init_args else ['', '^2'] # transformer postfixes
|
||||
|
||||
# Try default, integer, and new user-passed treatment name
|
||||
for new_treatment_name in [None, [999], ['NewTreatmentName']]:
|
||||
|
||||
# FunctionTransformers are agnostic to passed treatment names
|
||||
if isinstance(init_args.get('treatment_featurizer'), FunctionTransformer):
|
||||
assert (est.cate_treatment_names(new_treatment_name) == ['feat(T)0', 'feat(T)1'])
|
||||
|
||||
# Expected treatment names are the sums of user-passed prefixes and transformer-specific postfixes
|
||||
else:
|
||||
expected_prefix = str(new_treatment_name[0]) if new_treatment_name is not None else t_name
|
||||
assert (est.cate_treatment_names(new_treatment_name) == [
|
||||
expected_prefix + postfix for postfix in postfixes])
|
||||
|
|
|
@ -220,67 +220,70 @@ class TestOrthoForest(unittest.TestCase):
|
|||
lb, _ = est.marginal_effect_inference(1, X[:3]).conf_int()
|
||||
assert lb.shape == (3, 1, 2), "Marginal Effect interval dimension incorrect"
|
||||
|
||||
from sklearn.preprocessing import FunctionTransformer
|
||||
from sklearn.dummy import DummyClassifier, DummyRegressor
|
||||
for global_residualization in [False, True]:
|
||||
est = DMLOrthoForest(n_trees=10, model_Y=DummyRegressor(strategy='mean'),
|
||||
model_T=DummyRegressor(strategy='mean'),
|
||||
global_residualization=global_residualization,
|
||||
n_jobs=1)
|
||||
est.fit(y.reshape(-1, 1), T.reshape(-1, 1), X=X)
|
||||
assert est.const_marginal_effect(X[:3]).shape == (3, 1, 1), "Const Marginal Effect dimension incorrect"
|
||||
assert est.marginal_effect(1, X[:3]).shape == (3, 1, 1), "Marginal Effect dimension incorrect"
|
||||
assert est.effect(X[:3]).shape == (3, 1), "Effect dimension incorrect"
|
||||
assert est.effect(X[:3], T0=0, T1=2).shape == (3, 1), "Effect dimension incorrect"
|
||||
assert est.effect(X[:3], T0=1, T1=2).shape == (3, 1), "Effect dimension incorrect"
|
||||
lb, _ = est.effect_interval(X[:3], T0=1, T1=2)
|
||||
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
|
||||
lb, _ = est.effect_inference(X[:3], T0=1, T1=2).conf_int()
|
||||
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
|
||||
lb, _ = est.const_marginal_effect_interval(X[:3])
|
||||
assert lb.shape == (3, 1, 1), "Const Marginal Effect interval dimension incorrect"
|
||||
lb, _ = est.const_marginal_effect_inference(X[:3]).conf_int()
|
||||
assert lb.shape == (3, 1, 1), "Const Marginal Effect interval dimension incorrect"
|
||||
lb, _ = est.marginal_effect_interval(1, X[:3])
|
||||
assert lb.shape == (3, 1, 1), "Marginal Effect interval dimension incorrect"
|
||||
lb, _ = est.marginal_effect_inference(1, X[:3]).conf_int()
|
||||
assert lb.shape == (3, 1, 1), "Marginal Effect interval dimension incorrect"
|
||||
est.fit(y.reshape(-1, 1), T, X=X)
|
||||
assert est.const_marginal_effect(X[:3]).shape == (3, 1), "Const Marginal Effect dimension incorrect"
|
||||
assert est.marginal_effect(1, X[:3]).shape == (3, 1), "Marginal Effect dimension incorrect"
|
||||
assert est.effect(X[:3]).shape == (3, 1), "Effect dimension incorrect"
|
||||
assert est.effect(X[:3], T0=0, T1=2).shape == (3, 1), "Effect dimension incorrect"
|
||||
assert est.effect(X[:3], T0=1, T1=2).shape == (3, 1), "Effect dimension incorrect"
|
||||
lb, _ = est.effect_interval(X[:3], T0=1, T1=2)
|
||||
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
|
||||
lb, _ = est.effect_inference(X[:3], T0=1, T1=2).conf_int()
|
||||
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
|
||||
lb, _ = est.const_marginal_effect_interval(X[:3])
|
||||
print(lb.shape)
|
||||
assert lb.shape == (3, 1), "Const Marginal Effect interval dimension incorrect"
|
||||
lb, _ = est.const_marginal_effect_inference(X[:3]).conf_int()
|
||||
assert lb.shape == (3, 1), "Const Marginal Effect interval dimension incorrect"
|
||||
lb, _ = est.marginal_effect_interval(1, X[:3])
|
||||
assert lb.shape == (3, 1), "Marginal Effect interval dimension incorrect"
|
||||
lb, _ = est.marginal_effect_inference(1, X[:3]).conf_int()
|
||||
assert lb.shape == (3, 1), "Marginal Effect interval dimension incorrect"
|
||||
est.fit(y, T, X=X)
|
||||
assert est.const_marginal_effect(X[:3]).shape == (3,), "Const Marginal Effect dimension incorrect"
|
||||
assert est.marginal_effect(1, X[:3]).shape == (3,), "Marginal Effect dimension incorrect"
|
||||
assert est.effect(X[:3]).shape == (3,), "Effect dimension incorrect"
|
||||
assert est.effect(X[:3], T0=0, T1=2).shape == (3,), "Effect dimension incorrect"
|
||||
assert est.effect(X[:3], T0=1, T1=2).shape == (3,), "Effect dimension incorrect"
|
||||
lb, _ = est.effect_interval(X[:3], T0=1, T1=2)
|
||||
assert lb.shape == (3,), "Effect interval dimension incorrect"
|
||||
lb, _ = est.effect_inference(X[:3], T0=1, T1=2).conf_int()
|
||||
assert lb.shape == (3,), "Effect interval dimension incorrect"
|
||||
lb, _ = est.const_marginal_effect_interval(X[:3])
|
||||
assert lb.shape == (3,), "Const Marginal Effect interval dimension incorrect"
|
||||
lb, _ = est.const_marginal_effect_inference(X[:3]).conf_int()
|
||||
assert lb.shape == (3,), "Const Marginal Effect interval dimension incorrect"
|
||||
lb, _ = est.marginal_effect_interval(1, X[:3])
|
||||
assert lb.shape == (3,), "Marginal Effect interval dimension incorrect"
|
||||
lb, _ = est.marginal_effect_inference(1, X[:3]).conf_int()
|
||||
assert lb.shape == (3,), "Marginal Effect interval dimension incorrect"
|
||||
for treatment_featurization in [None, FunctionTransformer()]:
|
||||
est = DMLOrthoForest(n_trees=10, model_Y=DummyRegressor(strategy='mean'),
|
||||
model_T=DummyRegressor(strategy='mean'),
|
||||
global_residualization=global_residualization,
|
||||
treatment_featurizer=treatment_featurization,
|
||||
n_jobs=1)
|
||||
est.fit(y.reshape(-1, 1), T.reshape(-1, 1), X=X)
|
||||
assert est.const_marginal_effect(X[:3]).shape == (3, 1, 1), "Const Marginal Effect dimension incorrect"
|
||||
assert est.marginal_effect(1, X[:3]).shape == (3, 1, 1), "Marginal Effect dimension incorrect"
|
||||
assert est.effect(X[:3]).shape == (3, 1), "Effect dimension incorrect"
|
||||
assert est.effect(X[:3], T0=0, T1=2).shape == (3, 1), "Effect dimension incorrect"
|
||||
assert est.effect(X[:3], T0=1, T1=2).shape == (3, 1), "Effect dimension incorrect"
|
||||
lb, _ = est.effect_interval(X[:3], T0=1, T1=2)
|
||||
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
|
||||
lb, _ = est.effect_inference(X[:3], T0=1, T1=2).conf_int()
|
||||
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
|
||||
lb, _ = est.const_marginal_effect_interval(X[:3])
|
||||
assert lb.shape == (3, 1, 1), "Const Marginal Effect interval dimension incorrect"
|
||||
lb, _ = est.const_marginal_effect_inference(X[:3]).conf_int()
|
||||
assert lb.shape == (3, 1, 1), "Const Marginal Effect interval dimension incorrect"
|
||||
lb, _ = est.marginal_effect_interval(1, X[:3])
|
||||
assert lb.shape == (3, 1, 1), "Marginal Effect interval dimension incorrect"
|
||||
lb, _ = est.marginal_effect_inference(1, X[:3]).conf_int()
|
||||
assert lb.shape == (3, 1, 1), "Marginal Effect interval dimension incorrect"
|
||||
est.fit(y.reshape(-1, 1), T, X=X)
|
||||
assert est.const_marginal_effect(X[:3]).shape == (3, 1), "Const Marginal Effect dimension incorrect"
|
||||
assert est.marginal_effect(1, X[:3]).shape == (3, 1), "Marginal Effect dimension incorrect"
|
||||
assert est.effect(X[:3]).shape == (3, 1), "Effect dimension incorrect"
|
||||
assert est.effect(X[:3], T0=0, T1=2).shape == (3, 1), "Effect dimension incorrect"
|
||||
assert est.effect(X[:3], T0=1, T1=2).shape == (3, 1), "Effect dimension incorrect"
|
||||
lb, _ = est.effect_interval(X[:3], T0=1, T1=2)
|
||||
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
|
||||
lb, _ = est.effect_inference(X[:3], T0=1, T1=2).conf_int()
|
||||
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
|
||||
lb, _ = est.const_marginal_effect_interval(X[:3])
|
||||
print(lb.shape)
|
||||
assert lb.shape == (3, 1), "Const Marginal Effect interval dimension incorrect"
|
||||
lb, _ = est.const_marginal_effect_inference(X[:3]).conf_int()
|
||||
assert lb.shape == (3, 1), "Const Marginal Effect interval dimension incorrect"
|
||||
lb, _ = est.marginal_effect_interval(1, X[:3])
|
||||
assert lb.shape == (3, 1), "Marginal Effect interval dimension incorrect"
|
||||
lb, _ = est.marginal_effect_inference(1, X[:3]).conf_int()
|
||||
assert lb.shape == (3, 1), "Marginal Effect interval dimension incorrect"
|
||||
est.fit(y, T, X=X)
|
||||
assert est.const_marginal_effect(X[:3]).shape == (3,), "Const Marginal Effect dimension incorrect"
|
||||
assert est.marginal_effect(1, X[:3]).shape == (3,), "Marginal Effect dimension incorrect"
|
||||
assert est.effect(X[:3]).shape == (3,), "Effect dimension incorrect"
|
||||
assert est.effect(X[:3], T0=0, T1=2).shape == (3,), "Effect dimension incorrect"
|
||||
assert est.effect(X[:3], T0=1, T1=2).shape == (3,), "Effect dimension incorrect"
|
||||
lb, _ = est.effect_interval(X[:3], T0=1, T1=2)
|
||||
assert lb.shape == (3,), "Effect interval dimension incorrect"
|
||||
lb, _ = est.effect_inference(X[:3], T0=1, T1=2).conf_int()
|
||||
assert lb.shape == (3,), "Effect interval dimension incorrect"
|
||||
lb, _ = est.const_marginal_effect_interval(X[:3])
|
||||
assert lb.shape == (3,), "Const Marginal Effect interval dimension incorrect"
|
||||
lb, _ = est.const_marginal_effect_inference(X[:3]).conf_int()
|
||||
assert lb.shape == (3,), "Const Marginal Effect interval dimension incorrect"
|
||||
lb, _ = est.marginal_effect_interval(1, X[:3])
|
||||
assert lb.shape == (3,), "Marginal Effect interval dimension incorrect"
|
||||
lb, _ = est.marginal_effect_inference(1, X[:3]).conf_int()
|
||||
assert lb.shape == (3,), "Marginal Effect interval dimension incorrect"
|
||||
|
||||
def test_nuisance_model_has_weights(self):
|
||||
"""Test whether the correct exception is being raised if model_final doesn't have weights."""
|
||||
|
|
|
@ -170,8 +170,8 @@ class TestOrthoLearner(unittest.TestCase):
|
|||
X = np.random.normal(size=(10000, 3))
|
||||
sigma = 0.1
|
||||
y = X[:, 0] + X[:, 1] + np.random.normal(0, sigma, size=(10000,))
|
||||
est = OrthoLearner(cv=2, discrete_treatment=False, discrete_instrument=False, categories='auto',
|
||||
random_state=None)
|
||||
est = OrthoLearner(cv=2, discrete_treatment=False, treatment_featurizer=None,
|
||||
discrete_instrument=False, categories='auto', random_state=None)
|
||||
est.fit(y, X[:, 0], W=X[:, 1:])
|
||||
np.testing.assert_almost_equal(est.const_marginal_effect(), 1, decimal=3)
|
||||
np.testing.assert_array_almost_equal(est.effect(), np.ones(1), decimal=3)
|
||||
|
@ -187,7 +187,7 @@ class TestOrthoLearner(unittest.TestCase):
|
|||
X = np.random.normal(size=(10000, 3))
|
||||
sigma = 0.1
|
||||
y = X[:, 0] + X[:, 1] + np.random.normal(0, sigma, size=(10000,))
|
||||
est = OrthoLearner(cv=2, discrete_treatment=False, discrete_instrument=False,
|
||||
est = OrthoLearner(cv=2, discrete_treatment=False, treatment_featurizer=None, discrete_instrument=False,
|
||||
categories='auto', random_state=None)
|
||||
# test non-array inputs
|
||||
est.fit(list(y), list(X[:, 0]), X=None, W=X[:, 1:])
|
||||
|
@ -204,7 +204,7 @@ class TestOrthoLearner(unittest.TestCase):
|
|||
sigma = 0.1
|
||||
y = X[:, 0] + X[:, 1] + np.random.normal(0, sigma, size=(10000,))
|
||||
est = OrthoLearner(cv=KFold(n_splits=3),
|
||||
discrete_treatment=False, discrete_instrument=False,
|
||||
discrete_treatment=False, treatment_featurizer=None, discrete_instrument=False,
|
||||
categories='auto', random_state=None)
|
||||
est.fit(y, X[:, 0], X=None, W=X[:, 1:])
|
||||
np.testing.assert_almost_equal(est.const_marginal_effect(), 1, decimal=3)
|
||||
|
@ -220,7 +220,7 @@ class TestOrthoLearner(unittest.TestCase):
|
|||
sigma = 0.1
|
||||
y = X[:, 0] + X[:, 1] + np.random.normal(0, sigma, size=(10000,))
|
||||
folds = [(np.arange(X.shape[0] // 2), np.arange(X.shape[0] // 2, X.shape[0]))]
|
||||
est = OrthoLearner(cv=folds, discrete_treatment=False,
|
||||
est = OrthoLearner(cv=folds, discrete_treatment=False, treatment_featurizer=None,
|
||||
discrete_instrument=False, categories='auto', random_state=None)
|
||||
est.fit(y, X[:, 0], X=None, W=X[:, 1:])
|
||||
np.testing.assert_almost_equal(est.const_marginal_effect(), 1, decimal=2)
|
||||
|
@ -268,7 +268,7 @@ class TestOrthoLearner(unittest.TestCase):
|
|||
X = np.random.normal(size=(10000, 3))
|
||||
sigma = 0.1
|
||||
y = X[:, 0] + X[:, 1] + np.random.normal(0, sigma, size=(10000,))
|
||||
est = OrthoLearner(cv=2, discrete_treatment=False, discrete_instrument=False,
|
||||
est = OrthoLearner(cv=2, discrete_treatment=False, treatment_featurizer=None, discrete_instrument=False,
|
||||
categories='auto', random_state=None)
|
||||
est.fit(y, X[:, 0], W=X[:, 1:])
|
||||
np.testing.assert_almost_equal(est.const_marginal_effect(), 1, decimal=3)
|
||||
|
@ -318,7 +318,7 @@ class TestOrthoLearner(unittest.TestCase):
|
|||
X = np.random.normal(size=(10000, 3))
|
||||
sigma = 0.1
|
||||
y = X[:, 0] + X[:, 1] + np.random.normal(0, sigma, size=(10000,))
|
||||
est = OrthoLearner(cv=2, discrete_treatment=False, discrete_instrument=False,
|
||||
est = OrthoLearner(cv=2, discrete_treatment=False, treatment_featurizer=None, discrete_instrument=False,
|
||||
categories='auto', random_state=None)
|
||||
est.fit(y, X[:, 0], W=X[:, 1:])
|
||||
np.testing.assert_almost_equal(est.const_marginal_effect(), 1, decimal=3)
|
||||
|
@ -380,7 +380,7 @@ class TestOrthoLearner(unittest.TestCase):
|
|||
T = np.random.binomial(1, scipy.special.expit(X[:, 0]))
|
||||
sigma = 0.01
|
||||
y = T + X[:, 0] + np.random.normal(0, sigma, size=(10000,))
|
||||
est = OrthoLearner(cv=2, discrete_treatment=True, discrete_instrument=False,
|
||||
est = OrthoLearner(cv=2, discrete_treatment=True, treatment_featurizer=None, discrete_instrument=False,
|
||||
categories='auto', random_state=None)
|
||||
est.fit(y, T, W=X)
|
||||
np.testing.assert_almost_equal(est.const_marginal_effect(), 1, decimal=3)
|
||||
|
|
|
@ -0,0 +1,580 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
import pytest
|
||||
import unittest
|
||||
import numpy as np
|
||||
from sklearn.preprocessing import PolynomialFeatures
|
||||
from sklearn.linear_model import LinearRegression, LogisticRegression
|
||||
from sklearn.ensemble import RandomForestRegressor
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
from econml._ortho_learner import _OrthoLearner
|
||||
from econml.dml import LinearDML, SparseLinearDML, KernelDML, CausalForestDML, NonParamDML
|
||||
from econml.iv.dml import OrthoIV, DMLIV, NonParamDMLIV
|
||||
from econml.iv.dr import DRIV, LinearDRIV, SparseLinearDRIV, ForestDRIV
|
||||
from econml.orf import DMLOrthoForest
|
||||
from sklearn.preprocessing import OneHotEncoder, FunctionTransformer
|
||||
from econml.sklearn_extensions.linear_model import StatsModelsLinearRegression
|
||||
|
||||
from econml.utilities import jacify_featurizer
|
||||
from econml.iv.sieve import DPolynomialFeatures
|
||||
|
||||
from econml.tests.test_dml import TestDML
|
||||
|
||||
from copy import deepcopy
|
||||
from econml.tests.test_dml import TestDML
|
||||
|
||||
|
||||
class DGP():
|
||||
def __init__(self,
|
||||
n=1000,
|
||||
d_t=1,
|
||||
d_y=1,
|
||||
d_x=5,
|
||||
d_z=None,
|
||||
squeeze_T=False,
|
||||
squeeze_Y=False,
|
||||
nuisance_Y=None,
|
||||
nuisance_T=None,
|
||||
nuisance_TZ=None,
|
||||
theta=None,
|
||||
y_of_t=None,
|
||||
x_eps=1,
|
||||
y_eps=1,
|
||||
t_eps=1
|
||||
):
|
||||
self.n = n
|
||||
self.d_t = d_t
|
||||
self.d_y = d_y
|
||||
self.d_x = d_x
|
||||
self.d_z = d_z
|
||||
|
||||
self.squeeze_T = squeeze_T
|
||||
self.squeeze_Y = squeeze_Y
|
||||
|
||||
self.nuisance_Y = nuisance_Y if nuisance_Y else lambda X: 0
|
||||
self.nuisance_T = nuisance_T if nuisance_T else lambda X: 0
|
||||
self.nuisance_TZ = nuisance_TZ if nuisance_TZ else lambda X: 0
|
||||
self.theta = theta if theta else lambda X: 1
|
||||
self.y_of_t = y_of_t if y_of_t else lambda X: 0
|
||||
|
||||
self.x_eps = x_eps
|
||||
self.y_eps = y_eps
|
||||
self.t_eps = t_eps
|
||||
|
||||
def gen_Y(self):
|
||||
noise = np.random.normal(size=(self.n, self.d_y), scale=self.y_eps)
|
||||
self.Y = self.theta(self.X) * self.y_of_t(self.T) + self.nuisance_Y(self.X) + noise
|
||||
return self.Y
|
||||
|
||||
def gen_X(self):
|
||||
self.X = np.random.normal(size=(self.n, self.d_x), scale=self.x_eps)
|
||||
return self.X
|
||||
|
||||
def gen_T(self):
|
||||
noise = np.random.normal(size=(self.n, self.d_t), scale=self.t_eps)
|
||||
self.T_noise = noise
|
||||
self.T = noise + self.nuisance_T(self.X) + self.nuisance_TZ(self.Z)
|
||||
return self.T
|
||||
|
||||
def gen_Z(self):
|
||||
if self.d_z:
|
||||
Z_noise = np.random.normal(size=(self.n, self.d_z), loc=3, scale=3)
|
||||
self.Z = Z_noise
|
||||
return self.Z
|
||||
|
||||
else:
|
||||
self.Z = None
|
||||
return self.Z
|
||||
|
||||
def gen_data(self):
|
||||
X = self.gen_X()
|
||||
Z = self.gen_Z()
|
||||
T = self.gen_T()
|
||||
Y = self.gen_Y()
|
||||
|
||||
if self.squeeze_T:
|
||||
T = T.squeeze()
|
||||
if self.squeeze_Y:
|
||||
Y = Y.squeeze()
|
||||
|
||||
data_dict = {
|
||||
'Y': Y,
|
||||
'T': T,
|
||||
'X': X
|
||||
}
|
||||
|
||||
if self.d_z:
|
||||
data_dict['Z'] = Z
|
||||
|
||||
return data_dict
|
||||
|
||||
|
||||
def actual_effect(y_of_t, T0, T1):
|
||||
return y_of_t(T1) - y_of_t(T0)
|
||||
|
||||
|
||||
def nuisance_T(X):
|
||||
return -0.3 * X[:, [1]]
|
||||
|
||||
|
||||
def nuisance_Y(X):
|
||||
return 0.2 * X[:, [0]]
|
||||
|
||||
|
||||
# identity featurization effect functions
|
||||
def identity_y_of_t(T):
|
||||
return T
|
||||
|
||||
|
||||
def identity_actual_marginal(T):
|
||||
return np.ones(shape=(T.shape))
|
||||
|
||||
|
||||
def identity_actual_cme():
|
||||
return 1
|
||||
|
||||
|
||||
identity_treatment_featurizer = FunctionTransformer()
|
||||
|
||||
|
||||
# polynomial featurization effect functions
|
||||
def poly_y_of_t(T):
|
||||
return 0.5 * T**2
|
||||
|
||||
|
||||
def poly_actual_marginal(t):
|
||||
return t
|
||||
|
||||
|
||||
def poly_actual_cme():
|
||||
return np.array([0, 0.5])
|
||||
|
||||
|
||||
def poly_func_transform(x):
|
||||
x = x.reshape(-1, 1)
|
||||
return np.hstack([x, x**2])
|
||||
|
||||
|
||||
polynomial_treatment_featurizer = FunctionTransformer(func=poly_func_transform)
|
||||
|
||||
|
||||
# 1d polynomial featurization functions
|
||||
def poly_1d_actual_cme():
|
||||
return 0.5
|
||||
|
||||
|
||||
def poly_1d_func_transform(x):
|
||||
return x**2
|
||||
|
||||
|
||||
polynomial_1d_treatment_featurizer = FunctionTransformer(func=poly_1d_func_transform)
|
||||
|
||||
|
||||
# 2d-to-1d featurization functions
|
||||
|
||||
def sum_y_of_t(T):
|
||||
return 0.5 * T.sum(axis=1, keepdims=True)
|
||||
|
||||
|
||||
def sum_actual_cme():
|
||||
return 0.5
|
||||
|
||||
|
||||
def sum_actual_marginal(t):
|
||||
return np.ones(shape=t.shape) * 0.5
|
||||
|
||||
|
||||
def sum_func_transform(x):
|
||||
return x.sum(axis=1, keepdims=True)
|
||||
|
||||
|
||||
sum_treatment_featurizer = FunctionTransformer(func=sum_func_transform)
|
||||
|
||||
|
||||
# 2d-to-1d vector featurization functions
|
||||
def sum_squeeze_func_transform(x):
|
||||
return x.sum(axis=1, keepdims=False)
|
||||
|
||||
|
||||
sum_squeeze_treatment_featurizer = FunctionTransformer(func=sum_squeeze_func_transform)
|
||||
|
||||
|
||||
@pytest.mark.treatment_featurization
|
||||
class TestTreatmentFeaturization(unittest.TestCase):
|
||||
|
||||
def test_featurization(self):
|
||||
identity_config = {
|
||||
'DGP_params': {
|
||||
'n': 2000,
|
||||
'd_t': 1,
|
||||
'd_y': 1,
|
||||
'd_x': 5,
|
||||
'squeeze_T': False,
|
||||
'squeeze_Y': False,
|
||||
'nuisance_Y': nuisance_Y,
|
||||
'nuisance_T': nuisance_T,
|
||||
'theta': None,
|
||||
'y_of_t': identity_y_of_t,
|
||||
'x_eps': 1,
|
||||
'y_eps': 1,
|
||||
't_eps': 1
|
||||
},
|
||||
|
||||
'treatment_featurizer': identity_treatment_featurizer,
|
||||
'actual_marginal': identity_actual_marginal,
|
||||
'actual_cme': identity_actual_cme,
|
||||
'squeeze_Ts': [False, True],
|
||||
'squeeze_Ys': [False, True],
|
||||
'est_dicts': [
|
||||
{'class': LinearDML, 'init_args': {}},
|
||||
{'class': CausalForestDML, 'init_args': {}},
|
||||
{'class': SparseLinearDML, 'init_args': {}},
|
||||
{'class': KernelDML, 'init_args': {}},
|
||||
]
|
||||
}
|
||||
|
||||
poly_config = {
|
||||
'DGP_params': {
|
||||
'n': 2000,
|
||||
'd_t': 1,
|
||||
'd_y': 1,
|
||||
'd_x': 5,
|
||||
'squeeze_T': False,
|
||||
'squeeze_Y': False,
|
||||
'nuisance_Y': nuisance_Y,
|
||||
'nuisance_T': nuisance_T,
|
||||
'theta': None,
|
||||
'y_of_t': poly_y_of_t,
|
||||
'x_eps': 1,
|
||||
'y_eps': 1,
|
||||
't_eps': 1
|
||||
},
|
||||
|
||||
'treatment_featurizer': polynomial_treatment_featurizer,
|
||||
'actual_marginal': poly_actual_marginal,
|
||||
'actual_cme': poly_actual_cme,
|
||||
'squeeze_Ts': [False, True],
|
||||
'squeeze_Ys': [False, True],
|
||||
'est_dicts': [
|
||||
{'class': LinearDML, 'init_args': {}},
|
||||
{'class': CausalForestDML, 'init_args': {}},
|
||||
{'class': SparseLinearDML, 'init_args': {}},
|
||||
{'class': KernelDML, 'init_args': {}},
|
||||
]
|
||||
}
|
||||
|
||||
poly_config_scikit = deepcopy(poly_config)
|
||||
poly_config_scikit['treatment_featurizer'] = PolynomialFeatures(degree=2, include_bias=False)
|
||||
poly_config_scikit['squeeze_Ts'] = [False]
|
||||
|
||||
poly_IV_config = deepcopy(poly_config)
|
||||
poly_IV_config['DGP_params']['d_z'] = 1
|
||||
poly_IV_config['DGP_params']['nuisance_TZ'] = lambda Z: Z
|
||||
poly_IV_config['est_dicts'] = [
|
||||
{'class': OrthoIV, 'init_args': {
|
||||
'model_t_xwz': RandomForestRegressor(random_state=1), 'projection': True}},
|
||||
{'class': DMLIV, 'init_args': {'model_t_xwz': RandomForestRegressor(random_state=1)}},
|
||||
]
|
||||
|
||||
poly_1d_config = deepcopy(poly_config)
|
||||
poly_1d_config['treatment_featurizer'] = polynomial_1d_treatment_featurizer
|
||||
poly_1d_config['actual_cme'] = poly_1d_actual_cme
|
||||
poly_1d_config['est_dicts'].append({
|
||||
'class': NonParamDML,
|
||||
'init_args': {
|
||||
'model_y': LinearRegression(),
|
||||
'model_t': LinearRegression(),
|
||||
'model_final': StatsModelsLinearRegression()}})
|
||||
|
||||
poly_1d_IV_config = deepcopy(poly_IV_config)
|
||||
poly_1d_IV_config['treatment_featurizer'] = polynomial_1d_treatment_featurizer
|
||||
poly_1d_IV_config['actual_cme'] = poly_1d_actual_cme
|
||||
poly_1d_IV_config['est_dicts'] = [
|
||||
{'class': NonParamDMLIV, 'init_args': {'model_final': StatsModelsLinearRegression()}},
|
||||
{'class': DRIV, 'init_args': {'fit_cate_intercept': True}},
|
||||
{'class': LinearDRIV, 'init_args': {}},
|
||||
{'class': SparseLinearDRIV, 'init_args': {}},
|
||||
{'class': ForestDRIV, 'init_args': {}},
|
||||
]
|
||||
|
||||
sum_IV_config = {
|
||||
'DGP_params': {
|
||||
'n': 2000,
|
||||
'd_t': 2,
|
||||
'd_y': 1,
|
||||
'd_x': 5,
|
||||
'd_z': 1,
|
||||
'squeeze_T': False,
|
||||
'squeeze_Y': False,
|
||||
'nuisance_Y': nuisance_Y,
|
||||
'nuisance_T': nuisance_T,
|
||||
'nuisance_TZ': lambda Z: Z,
|
||||
'theta': None,
|
||||
'y_of_t': sum_y_of_t,
|
||||
'x_eps': 1,
|
||||
'y_eps': 1,
|
||||
't_eps': 1
|
||||
},
|
||||
|
||||
'treatment_featurizer': sum_treatment_featurizer,
|
||||
'actual_marginal': sum_actual_marginal,
|
||||
'actual_cme': sum_actual_cme,
|
||||
'squeeze_Ts': [False],
|
||||
'squeeze_Ys': [False, True],
|
||||
'est_dicts': [
|
||||
{'class': NonParamDMLIV, 'init_args': {'model_final': StatsModelsLinearRegression()}},
|
||||
{'class': DRIV, 'init_args': {'fit_cate_intercept': True}},
|
||||
{'class': LinearDRIV, 'init_args': {}},
|
||||
{'class': SparseLinearDRIV, 'init_args': {}},
|
||||
{'class': ForestDRIV, 'init_args': {}},
|
||||
]
|
||||
}
|
||||
|
||||
sum_squeeze_IV_config = deepcopy(sum_IV_config)
|
||||
sum_squeeze_IV_config['treatment_featurizer'] = sum_squeeze_treatment_featurizer
|
||||
|
||||
sum_config = deepcopy(sum_IV_config)
|
||||
sum_config['DGP_params']['d_z'] = None
|
||||
sum_config['DGP_params']['nuisance_TZ'] = None
|
||||
sum_config['est_dicts'] = deepcopy(poly_1d_config['est_dicts'])
|
||||
|
||||
sum_squeeze_config = deepcopy(sum_config)
|
||||
sum_squeeze_config['treatment_featurizer'] = sum_squeeze_treatment_featurizer
|
||||
|
||||
configs = [
|
||||
identity_config,
|
||||
poly_config,
|
||||
poly_config_scikit,
|
||||
poly_IV_config,
|
||||
poly_1d_config,
|
||||
poly_1d_IV_config,
|
||||
sum_IV_config,
|
||||
sum_squeeze_IV_config,
|
||||
sum_config,
|
||||
sum_squeeze_config
|
||||
]
|
||||
|
||||
for config in configs:
|
||||
for squeeze_Y in config['squeeze_Ys']:
|
||||
for squeeze_T in config['squeeze_Ts']:
|
||||
config['DGP_params']['squeeze_Y'] = squeeze_Y
|
||||
config['DGP_params']['squeeze_T'] = squeeze_T
|
||||
dgp = DGP(**config['DGP_params'])
|
||||
data_dict = dgp.gen_data()
|
||||
Y = data_dict['Y']
|
||||
T = data_dict['T']
|
||||
X = data_dict['X']
|
||||
feat_T = config['treatment_featurizer'].fit_transform(T)
|
||||
|
||||
data_dict_outside_feat = deepcopy(data_dict)
|
||||
data_dict_outside_feat['T'] = feat_T
|
||||
|
||||
est_dicts = config['est_dicts']
|
||||
|
||||
for est_dict in est_dicts:
|
||||
estClass = est_dict['class']
|
||||
init_args = deepcopy(est_dict['init_args'])
|
||||
init_args['treatment_featurizer'] = config['treatment_featurizer']
|
||||
init_args['random_state'] = 1
|
||||
|
||||
est = estClass(**init_args)
|
||||
est.fit(**data_dict)
|
||||
|
||||
init_args_outside_feat = deepcopy(est_dict['init_args'])
|
||||
init_args_outside_feat['random_state'] = 1
|
||||
est_outside_feat = estClass(**init_args_outside_feat)
|
||||
est_outside_feat.fit(**data_dict_outside_feat)
|
||||
|
||||
# test that treatment names are assigned for the featurized treatment
|
||||
assert (est.cate_treatment_names() is not None)
|
||||
|
||||
if hasattr(est, 'summary'):
|
||||
est.summary()
|
||||
|
||||
# expected shapes
|
||||
expected_eff_shape = (config['DGP_params']['n'],) + Y.shape[1:]
|
||||
expected_cme_shape = (config['DGP_params']['n'],) + Y.shape[1:] + feat_T.shape[1:]
|
||||
expected_me_shape = (config['DGP_params']['n'],) + Y.shape[1:] + T.shape[1:]
|
||||
expected_marginal_ate_shape = expected_me_shape[1:]
|
||||
|
||||
# check effects
|
||||
T0 = np.ones(shape=T.shape) * 5
|
||||
T1 = np.ones(shape=T.shape) * 10
|
||||
eff = est.effect(X=X, T0=T0, T1=T1)
|
||||
assert (eff.shape == expected_eff_shape)
|
||||
outside_feat = config['treatment_featurizer']
|
||||
eff_outside_feat = est_outside_feat.effect(
|
||||
X=X, T0=outside_feat.fit_transform(T0), T1=outside_feat.fit_transform(T1))
|
||||
np.testing.assert_almost_equal(eff, eff_outside_feat)
|
||||
actual_eff = actual_effect(config['DGP_params']['y_of_t'], T0, T1)
|
||||
|
||||
cme = est.const_marginal_effect(X=X)
|
||||
assert (cme.shape == expected_cme_shape)
|
||||
cme_outside_feat = est_outside_feat.const_marginal_effect(X=X)
|
||||
np.testing.assert_almost_equal(cme, cme_outside_feat)
|
||||
actual_cme = config['actual_cme']()
|
||||
|
||||
me = est.marginal_effect(T=T, X=X)
|
||||
assert (me.shape == expected_me_shape)
|
||||
actual_me = config['actual_marginal'](T).reshape(me.shape)
|
||||
|
||||
# ate
|
||||
m_ate = est.marginal_ate(T, X=X)
|
||||
assert (m_ate.shape == expected_marginal_ate_shape)
|
||||
|
||||
if isinstance(est, (LinearDML, SparseLinearDML, LinearDRIV, SparseLinearDRIV)):
|
||||
d_f_t = feat_T.shape[1] if feat_T.shape[1:] else 1
|
||||
expected_coef_inference_shape = (
|
||||
config['DGP_params']['d_y'] * config['DGP_params']['d_x'] * d_f_t, 6)
|
||||
assert est.coef__inference().summary_frame().shape == expected_coef_inference_shape
|
||||
|
||||
expected_intercept_inf_shape = (
|
||||
config['DGP_params']['d_y'] * d_f_t, 6)
|
||||
assert est.intercept__inference().summary_frame().shape == expected_intercept_inf_shape
|
||||
|
||||
# loose inference checks
|
||||
# temporarily skip LinearDRIV and SparseLinearDRIV for weird effect shape reasons
|
||||
if isinstance(est, (KernelDML, LinearDRIV, SparseLinearDRIV)):
|
||||
continue
|
||||
|
||||
if est._inference is None:
|
||||
continue
|
||||
|
||||
# effect inference
|
||||
eff_inf = est.effect_inference(X=X, T0=T0, T1=T1)
|
||||
eff_lb, eff_ub = eff_inf.conf_int(alpha=0.01)
|
||||
assert (eff.shape == eff_lb.shape)
|
||||
proportion_in_interval = ((eff_lb < actual_eff) & (actual_eff < eff_ub)).mean()
|
||||
np.testing.assert_array_less(0.50, proportion_in_interval)
|
||||
np.testing.assert_almost_equal(eff, eff_inf.point_estimate)
|
||||
|
||||
# marginal effect inference
|
||||
me_inf = est.marginal_effect_inference(T, X=X)
|
||||
me_lb, me_ub = me_inf.conf_int(alpha=0.01)
|
||||
assert (me.shape == me_lb.shape)
|
||||
proportion_in_interval = ((me_lb < actual_me) & (actual_me < me_ub)).mean()
|
||||
np.testing.assert_array_less(0.50, proportion_in_interval)
|
||||
np.testing.assert_almost_equal(me, me_inf.point_estimate)
|
||||
|
||||
# const marginal effect inference
|
||||
cme_inf = est.const_marginal_effect_inference(X=X)
|
||||
cme_lb, cme_ub = cme_inf.conf_int(alpha=0.01)
|
||||
assert (cme.shape == cme_lb.shape)
|
||||
proportion_in_interval = ((cme_lb < actual_cme) & (actual_cme < cme_ub)).mean()
|
||||
np.testing.assert_array_less(0.50, proportion_in_interval)
|
||||
np.testing.assert_almost_equal(cme, cme_inf.point_estimate)
|
||||
|
||||
def test_jac(self):
|
||||
def func_transform(x):
|
||||
x = x.reshape(-1, 1)
|
||||
return np.hstack([x, x**2])
|
||||
|
||||
def calc_expected_jacobian(T):
|
||||
jac = DPolynomialFeatures(degree=2, include_bias=False).fit_transform(T)
|
||||
return jac
|
||||
|
||||
treatment_featurizers = [
|
||||
PolynomialFeatures(degree=2, include_bias=False),
|
||||
FunctionTransformer(func=func_transform)
|
||||
]
|
||||
|
||||
n = 10000
|
||||
d_t = 1
|
||||
T = np.random.normal(size=(n, d_t))
|
||||
|
||||
for treatment_featurizer in treatment_featurizers:
|
||||
# fit a dummy estimator first so the featurizer can be fit to the treatment
|
||||
dummy_est = LinearDML(treatment_featurizer=treatment_featurizer)
|
||||
dummy_est.fit(Y=T, T=T, X=T)
|
||||
expected_jac = calc_expected_jacobian(T)
|
||||
jac_T = dummy_est.transformer.jac(T)
|
||||
np.testing.assert_almost_equal(jac_T, expected_jac)
|
||||
|
||||
def test_fail_discrete_treatment_and_treatment_featurizer(self):
|
||||
class OrthoLearner(_OrthoLearner):
|
||||
def _gen_ortho_learner_model_nuisance(self):
|
||||
pass
|
||||
|
||||
def _gen_ortho_learner_model_final(self):
|
||||
pass
|
||||
|
||||
est_and_params = [
|
||||
{
|
||||
'estimator': OrthoLearner,
|
||||
'params': {
|
||||
'cv': 2,
|
||||
'discrete_treatment': False,
|
||||
'treatment_featurizer': None,
|
||||
'discrete_instrument': False,
|
||||
'categories': 'auto',
|
||||
'random_state': None
|
||||
}
|
||||
},
|
||||
{'estimator': LinearDML, 'params': {}},
|
||||
{'estimator': CausalForestDML, 'params': {}},
|
||||
{'estimator': SparseLinearDML, 'params': {}},
|
||||
{'estimator': KernelDML, 'params': {}},
|
||||
{'estimator': DMLOrthoForest, 'params': {}}
|
||||
|
||||
]
|
||||
|
||||
dummy_vec = np.random.normal(size=(100, 1))
|
||||
|
||||
for est_and_param in est_and_params:
|
||||
params = est_and_param['params']
|
||||
params['discrete_treatment'] = True
|
||||
params['treatment_featurizer'] = True
|
||||
est = est_and_param['estimator'](**params)
|
||||
with self.assertRaises(AssertionError, msg='Estimator fit did not fail when passed '
|
||||
'both discrete treatment and treatment featurizer'):
|
||||
est.fit(Y=dummy_vec, T=dummy_vec, X=dummy_vec)
|
||||
|
||||
def test_cate_treatment_names_edge_cases(self):
|
||||
Y = np.random.normal(size=(100, 1))
|
||||
T = np.random.binomial(n=2, p=0.5, size=(100, 1))
|
||||
X = np.random.normal(size=(100, 3))
|
||||
|
||||
# edge case with transformer that only takes a vector treatment
|
||||
# so far will always return None for cate_treatment_names
|
||||
def weird_func(x):
|
||||
assert np.ndim(x) == 1
|
||||
return x
|
||||
est = LinearDML(treatment_featurizer=FunctionTransformer(weird_func)).fit(Y=Y, T=T.squeeze(), X=X)
|
||||
assert est.cate_treatment_names() is None
|
||||
assert est.cate_treatment_names(['too', 'many', 'feature_names']) is None
|
||||
|
||||
# assert proper handling of improper feature names passed to certain transformers
|
||||
est = LinearDML(discrete_treatment=True).fit(Y=Y, T=T, X=X)
|
||||
assert est.cate_treatment_names() == ['T0_1', 'T0_2']
|
||||
assert est.cate_treatment_names(['too', 'many', 'feature_names']) is None
|
||||
|
||||
est = LinearDML(treatment_featurizer=PolynomialFeatures(degree=2, include_bias=False)).fit(Y=Y, T=T, X=X)
|
||||
assert est.cate_treatment_names() == ['T0', 'T0^2']
|
||||
# depending on sklearn version, bad feature names either throws error or only uses first relevant name
|
||||
assert est.cate_treatment_names(['too', 'many', 'feature_names']) in [None, ['too', 'too^2']]
|
||||
|
||||
def test_alpha_passthrough(self):
|
||||
X = np.random.normal(size=(100, 3))
|
||||
T = np.random.normal(size=(100, 1)) + X[:, [0]]
|
||||
Y = np.random.normal(size=(100, 1)) + T + X[:, [0]]
|
||||
|
||||
est = LinearDML(model_y=LinearRegression(), model_t=LinearRegression(),
|
||||
treatment_featurizer=FunctionTransformer())
|
||||
est.fit(Y=Y, T=T, X=X)
|
||||
|
||||
# ensure alpha is passed
|
||||
lb, ub = est.marginal_effect_interval(T, X, alpha=1)
|
||||
assert (lb == ub).all()
|
||||
|
||||
lb, ub = est.marginal_effect_interval(T, X)
|
||||
assert (lb != ub).all()
|
||||
|
||||
lb1, ub1 = est.marginal_effect_interval(T, X, alpha=0.01)
|
||||
lb2, ub2 = est.marginal_effect_interval(T, X, alpha=0.1)
|
||||
|
||||
assert (lb1 < lb2).all() and (ub1 > ub2).all()
|
||||
|
||||
def test_identity_feat_with_cate_api(self):
|
||||
treatment_featurizations = [FunctionTransformer()]
|
||||
TestDML()._test_cate_api(treatment_featurizations)
|
|
@ -9,6 +9,7 @@ import scipy.sparse
|
|||
import sparse as sp
|
||||
import itertools
|
||||
import inspect
|
||||
import types
|
||||
from operator import getitem
|
||||
from collections import defaultdict, Counter
|
||||
from sklearn import clone
|
||||
|
@ -17,6 +18,7 @@ from sklearn.linear_model import LassoCV, MultiTaskLassoCV, Lasso, MultiTaskLass
|
|||
from functools import reduce, wraps
|
||||
from sklearn.utils import check_array, check_X_y
|
||||
from sklearn.utils.validation import assert_all_finite
|
||||
from sklearn.preprocessing import PolynomialFeatures
|
||||
import warnings
|
||||
from warnings import warn
|
||||
from sklearn.model_selection import KFold, StratifiedKFold, GroupKFold
|
||||
|
@ -46,7 +48,7 @@ class IdentityFeatures(TransformerMixin):
|
|||
|
||||
def parse_final_model_params(coef, intercept, d_y, d_t, d_t_in, bias_part_of_coef, fit_cate_intercept):
|
||||
dt = d_t
|
||||
if (d_t_in != d_t) and (d_t[0] == 1): # binary treatment
|
||||
if (d_t_in != d_t) and (d_t and d_t[0] == 1): # binary treatment or single dim featurized treatment
|
||||
dt = ()
|
||||
cate_intercept = None
|
||||
if bias_part_of_coef:
|
||||
|
@ -598,12 +600,41 @@ def get_input_columns(X, prefix="X"):
|
|||
pd.Series: lambda x: [x.name]
|
||||
}
|
||||
if type(X) in type_to_func:
|
||||
return type_to_func[type(X)](X)
|
||||
column_names = type_to_func[type(X)](X)
|
||||
|
||||
# if not all column names are strings
|
||||
if not all(isinstance(item, str) for item in column_names):
|
||||
warnings.warn("Not all column names are strings. Coercing to strings for now.", UserWarning)
|
||||
|
||||
return [str(item) for item in column_names]
|
||||
|
||||
len_X = 1 if np.ndim(X) == 1 else np.asarray(X).shape[1]
|
||||
return [f"{prefix}{i}" for i in range(len_X)]
|
||||
|
||||
|
||||
def get_feature_names_or_default(featurizer, feature_names):
|
||||
def get_feature_names_or_default(featurizer, feature_names, prefix="feat(X)"):
|
||||
"""
|
||||
Extract feature names from sklearn transformers. Otherwise attempts to assign default feature names.
|
||||
|
||||
Designed to be compatible with old and new sklearn versions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
featurizer : featurizer to extract feature names from
|
||||
feature_names : input features
|
||||
prefix : output prefix in the event where we assign default feature names
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature_names_out : a list of strings (feature names)
|
||||
|
||||
"""
|
||||
|
||||
# coerce feature names to be strings
|
||||
if not all(isinstance(item, str) for item in feature_names):
|
||||
warnings.warn("Not all feature names are strings. Coercing to strings for now.", UserWarning)
|
||||
feature_names = [str(item) for item in feature_names]
|
||||
|
||||
# Prefer sklearn 1.0's get_feature_names_out method to deprecated get_feature_names method
|
||||
if hasattr(featurizer, "get_feature_names_out"):
|
||||
try:
|
||||
|
@ -612,17 +643,21 @@ def get_feature_names_or_default(featurizer, feature_names):
|
|||
# Some featurizers will throw, such as a pipeline with a transformer that doesn't itself support names
|
||||
pass
|
||||
if hasattr(featurizer, 'get_feature_names'):
|
||||
# Get number of arguments, some sklearn featurizer don't accept feature_names
|
||||
arg_no = len(inspect.getfullargspec(featurizer.get_feature_names).args)
|
||||
if arg_no == 1:
|
||||
return featurizer.get_feature_names()
|
||||
elif arg_no == 2:
|
||||
return featurizer.get_feature_names(feature_names)
|
||||
try:
|
||||
# Get number of arguments, some sklearn featurizer don't accept feature_names
|
||||
arg_no = len(inspect.getfullargspec(featurizer.get_feature_names).args)
|
||||
if arg_no == 1:
|
||||
return featurizer.get_feature_names()
|
||||
elif arg_no == 2:
|
||||
return featurizer.get_feature_names(feature_names)
|
||||
except Exception:
|
||||
# Handles cases where the passed feature names create issues
|
||||
pass
|
||||
# Featurizer doesn't have 'get_feature_names' or has atypical 'get_feature_names'
|
||||
try:
|
||||
# Get feature names using featurizer
|
||||
dummy_X = np.ones((1, len(feature_names)))
|
||||
return get_input_columns(featurizer.transform(dummy_X), prefix="feat(X)")
|
||||
return get_input_columns(featurizer.transform(dummy_X), prefix=prefix)
|
||||
except Exception:
|
||||
# All attempts at retrieving transformed feature names have failed
|
||||
# Delegate handling to downstream logic
|
||||
|
@ -1391,6 +1426,85 @@ class _RegressionWrapper:
|
|||
return self._clf.predict_proba(X)[:, 1:]
|
||||
|
||||
|
||||
class _TransformerWrapper:
|
||||
"""Wrapper that takes a featurizer as input and adds jacobian calculation functionality"""
|
||||
|
||||
def __init__(self, featurizer):
|
||||
self.featurizer = featurizer
|
||||
|
||||
def fit(self, X):
|
||||
return self.featurizer.fit(X)
|
||||
|
||||
def transform(self, X):
|
||||
return self.featurizer.transform(X)
|
||||
|
||||
def fit_transform(self, X):
|
||||
return self.featurizer.fit_transform(X)
|
||||
|
||||
def get_feature_names_out(self, feature_names):
|
||||
return get_feature_names_or_default(self.featurizer, feature_names, prefix="feat(T)")
|
||||
|
||||
def jac(self, X, epsilon=0.001):
|
||||
if hasattr(self.featurizer, 'jac'):
|
||||
return self.featurizer.jac(X)
|
||||
elif (isinstance(self.featurizer, PolynomialFeatures)):
|
||||
powers = self.featurizer.powers_
|
||||
result = np.zeros(X.shape + (self.featurizer.n_output_features_,))
|
||||
for i in range(X.shape[1]):
|
||||
p = powers.copy()
|
||||
c = powers[:, i]
|
||||
p[:, i] -= 1
|
||||
M = np.float_power(X[:, np.newaxis, :], p[np.newaxis, :, :])
|
||||
result[:, i, :] = c[np.newaxis, :] * np.prod(M, axis=-1)
|
||||
return result
|
||||
|
||||
else:
|
||||
squeeze = []
|
||||
|
||||
n = X.shape[0]
|
||||
d_t = X.shape[-1] if ndim(X) > 1 else 1
|
||||
X_out = self.transform(X)
|
||||
d_f_t = X_out.shape[-1] if ndim(X_out) > 1 else 1
|
||||
|
||||
jacob = np.zeros((n, d_t, d_f_t))
|
||||
|
||||
if ndim(X) == 1:
|
||||
squeeze.append(1)
|
||||
X = X[:, np.newaxis]
|
||||
if ndim(X_out) == 1:
|
||||
squeeze.append(2)
|
||||
|
||||
# for every dimension of the treatment add some epsilon and observe change in featurized treatment
|
||||
for k in range(d_t):
|
||||
eps_matrix = np.zeros(shape=X.shape)
|
||||
eps_matrix[:, k] = epsilon
|
||||
|
||||
X_in_plus = X + eps_matrix
|
||||
X_in_plus = X_in_plus.squeeze(axis=1) if 1 in squeeze else X_in_plus
|
||||
X_out_plus = self.transform(X_in_plus)
|
||||
X_out_plus = X_out_plus[:, np.newaxis] if 2 in squeeze else X_out_plus
|
||||
|
||||
X_in_minus = X - eps_matrix
|
||||
X_in_minus = X_in_minus.squeeze(axis=1) if 1 in squeeze else X_in_minus
|
||||
X_out_minus = self.transform(X_in_minus)
|
||||
X_out_minus = X_out_minus[:, np.newaxis] if 2 in squeeze else X_out_minus
|
||||
|
||||
diff = X_out_plus - X_out_minus
|
||||
deriv = diff / (2 * epsilon)
|
||||
|
||||
jacob[:, k, :] = deriv
|
||||
|
||||
return jacob.squeeze(axis=tuple(squeeze))
|
||||
|
||||
|
||||
def jacify_featurizer(featurizer):
|
||||
"""
|
||||
Function that takes a featurizer as input and returns a wrapper class that includes
|
||||
a function for calculating the jacobian
|
||||
"""
|
||||
return _TransformerWrapper(featurizer)
|
||||
|
||||
|
||||
@deprecated("This class will be removed from a future version of this package; "
|
||||
"please use econml.sklearn_extensions.linear_model.WeightedLassoCV instead.")
|
||||
class LassoCVWrapper:
|
||||
|
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -16,5 +16,6 @@ markers = [
|
|||
"automl",
|
||||
"dml",
|
||||
"serial",
|
||||
"cate_api"
|
||||
"cate_api",
|
||||
"treatment_featurization"
|
||||
]
|
Загрузка…
Ссылка в новой задаче