Cython implementation of GRF and CausalForestDML (#341)

* added backend option in ORF, added verbosity, restructured static functions

* added cython grf module that implements generalized random forests

* added Cython version of causal forest and causal forest dml

* deprecating older CausalForest

* updates to CF and ORF notebook

* restructured dml into a folder; deprecated ForestDML in favor of CausalForestDML

* Removed two legacy files in our main folder.

* deprecating ensemble.SubsampledHonestForest

* made drlearner use the non-deprecated regression forest

* Enable setuptools build process

* fixed flaky random_state test

* fixed tests and api consistency

* updated tables and library flow chart

* enforce sklearn 0.24.

* fixed _cross_val_predict

* added option for max background samples to shap to make computation more reasonable

* fixed error_score param in gcvlist due to sklearn upgrade

* added shap cells in DML notebook

* added shap values to GRF notebook

* fixed bug in the way input_feature_names were used in summary; enabled shap to use input feature names

* updated readme; removed autoreload from notebooks

* added shap specific notebook

* updated dowhy notebook
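
For orientation, the headline change is the replacement of `ForestDML` by `CausalForestDML`. Below is a minimal sketch of the renamed entry point, following the API shown in the README diff further down; the data variables are purely illustrative:

```Python
import numpy as np
from econml.dml import CausalForestDML
from sklearn.ensemble import GradientBoostingRegressor

# Illustrative data; any (Y, T, X, W) with matching shapes works.
X = np.random.normal(size=(1000, 5))
W = np.random.normal(size=(1000, 3))
T = X[:, 0] + np.random.normal(size=1000)
Y = 2 * T * (X[:, 0] > 0) + np.random.normal(size=1000)

est = CausalForestDML(model_y=GradientBoostingRegressor(),
                      model_t=GradientBoostingRegressor())
est.fit(Y, T, X=X, W=W)
point = est.effect(X[:10])
lb, ub = est.effect_interval(X[:10], alpha=0.05)  # Bootstrap-of-Little-Bags intervals
```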
This commit is contained in:
vsyrgkanis 2021-01-08 22:29:56 -05:00 committed by GitHub
Parent 3df959d120
Commit bb042d541d
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
85 changed files: 14882 additions and 3994 deletions

LICENSE

@ -19,3 +19,40 @@
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
Parts of this software, in particular code contained in the modules econml.tree and
econml.grf contain files that are forks from the scikit-learn git repository, or code
snippets from that repository:
https://github.com/scikit-learn/scikit-learn
published under the following License.
BSD 3-Clause License
Copyright (c) 2007-2020 The scikit-learn developers.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


@ -118,20 +118,7 @@ To install from source, see [For Developers](#for-developers) section below.
treatment_effects = est.effect(X_test)
lb, ub = est.effect_interval(X_test, alpha=0.05) # Confidence intervals via debiased lasso
```
* Forest last stage
```Python
from econml.dml import ForestDML
from sklearn.ensemble import GradientBoostingRegressor
est = ForestDML(model_y=GradientBoostingRegressor(), model_t=GradientBoostingRegressor())
est.fit(Y, T, X=X, W=W)
treatment_effects = est.effect(X_test)
# Confidence intervals via Bootstrap-of-Little-Bags for forests
lb, ub = est.effect_interval(X_test, alpha=0.05)
```
* Generic Machine Learning last stage
```Python
@ -152,16 +139,16 @@ To install from source, see [For Developers](#for-developers) section below.
<summary>Causal Forests (click to expand)</summary>
```Python
from econml.causal_forest import CausalForest
from econml.dml import CausalForestDML
from sklearn.linear_model import LassoCV
# Use defaults
est = CausalForest()
est = CausalForestDML()
# Or specify hyperparameters
est = CausalForest(n_trees=500, min_leaf_size=10,
max_depth=10, subsample_ratio=0.7,
lambda_reg=0.01,
discrete_treatment=False,
model_T=LassoCV(), model_Y=LassoCV())
est = CausalForestDML(criterion='het', n_estimators=500,
min_samples_leaf=10,
max_depth=10, max_samples=0.5,
discrete_treatment=False,
model_t=LassoCV(), model_y=LassoCV())
est.fit(Y, T, X=X, W=W)
treatment_effects = est.effect(X_test)
# Confidence intervals via Bootstrap-of-Little-Bags for forests
@ -354,7 +341,7 @@ treatment_effects = est.effect(X_test)
<details>
<summary>Policy Interpreter of the CATE model (click to expand)</summary>
```Python
from econml.cate_interpreter import SingleTreePolicyInterpreter
# We find a tree-based treatment policy based on the CATE model
@ -366,7 +353,21 @@ treatment_effects = est.effect(X_test)
plt.show()
```
![image](notebooks/images/dr_policy_tree.png)
</details>
<details>
<summary>SHAP values for the CATE model (click to expand)</summary>
```Python
import shap
from econml.dml import CausalForestDML
est = CausalForestDML()
est.fit(Y, T, X=X, W=W)
shap_values = est.shap_values(X)
shap.summary_plot(shap_values['Y0']['T0'])
```
</details>
### Inference


@ -5,7 +5,7 @@
parameters:
body: []
package: '.'
package: '-e .'
steps:
- task: UsePythonVersion@0
@ -24,7 +24,7 @@ steps:
condition: and(succeeded(), eq(variables['Agent.OS'], 'Linux'))
# Install the package
- script: 'python -m pip install --upgrade pip && pip install --upgrade setuptools wheel && pip install ${{ parameters.package }}'
- script: 'python -m pip install --upgrade pip && pip install --upgrade setuptools wheel Cython && pip install ${{ parameters.package }}'
displayName: 'Install dependencies'
- ${{ parameters.body }}


@ -69,9 +69,6 @@ jobs:
- script: 'pip install --force-reinstall --no-cache-dir shap'
displayName: 'Install public shap'
- script: 'pip install --force-reinstall scikit-learn==0.23.2'
displayName: 'Install public old sklearn'
- script: 'python setup.py build_sphinx -W'
displayName: 'Build documentation'
@ -81,7 +78,7 @@ jobs:
- script: 'python setup.py build_sphinx -b doctest'
displayName: 'Run doctests'
package: '.[automl]'
package: '-e .[automl]'
- job: 'Notebooks'
dependsOn: 'EvalChanges'


@ -211,7 +211,7 @@ epub_exclude_files = ['search.html']
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {'python': ('https://docs.python.org/3', None),
'numpy': ('https://docs.scipy.org/doc/numpy/', None),
'sklearn': ('https://scikit-learn.org/0.23/', None),
'sklearn': ('https://scikit-learn.org/stable/', None),
'matplotlib': ('https://matplotlib.org/', None)}
# -- Options for todo extension ----------------------------------------------


@ -171,8 +171,8 @@
<tspan x="31.92" y="17">Experimentation</tspan>
</text>
<rect x="470" y="461" width="226" height="62" stroke="#FFFFFF" stroke-width="2" stroke-miterlimit="8" fill="#70AD47"/>
<a href="_autosummary/econml.dml.html#econml.dml.SparseLinearDMLCateEstimator" target="_parent">
<text fill="#FFFFFF" font-family="Calibri,Calibri_MSFontService,sans-serif" font-weight="400" font-size="15" transform="translate(489.258 488)">SparseLinearDMLCateEstimator</text>
<a href="_autosummary/econml.dml.html#econml.dml.SparseLinearDML" target="_parent">
<text fill="#FFFFFF" font-family="Calibri,Calibri_MSFontService,sans-serif" font-weight="400" font-size="15" transform="translate(530 488)">SparseLinearDML</text>
</a>
<path d="M489.012 490.461 536.012 490.461 583.012 490.461 630.012 490.461 677.012 490.461 677.012 491.461 630.012 491.461 583.012 491.461 536.012 491.461 489.012 491.461Z" fill="#FFFFFF" fill-rule="evenodd"/>
<a href="_autosummary/econml.drlearner.html#econml.drlearner.SparseLinearDRLearner" target="_parent">
@ -180,8 +180,8 @@
</a>
<path d="M514.012 507.461 560.345 507.461 606.678 507.461 653.012 507.461 653.012 508.461 606.678 508.461 560.345 508.461 514.012 508.461Z" fill="#FFFFFF" fill-rule="evenodd"/>
<rect x="471" y="549" width="226" height="51" stroke="#FFFFFF" stroke-width="2" stroke-miterlimit="8" fill="#70AD47"/>
<a href="_autosummary/econml.dml.html#econml.dml.LinearDMLCateEstimator" target="_parent">
<text fill="#FFFFFF" font-family="Calibri,Calibri_MSFontService,sans-serif" font-weight="400" font-size="15" transform="translate(509.292 571)">LinearDMLCateEstimator</text>
<a href="_autosummary/econml.dml.html#econml.dml.LinearDML" target="_parent">
<text fill="#FFFFFF" font-family="Calibri,Calibri_MSFontService,sans-serif" font-weight="400" font-size="15" transform="translate(550 571)">LinearDML</text>
</a>
<path d="M509.292 573.095 558.959 573.095 608.626 573.095 658.292 573.095 658.292 574.095 608.626 574.095 558.959 574.095 509.292 574.095Z" fill="#FFFFFF" fill-rule="evenodd"/>
<a href="_autosummary/econml.drlearner.html#econml.drlearner.LinearDRLearner" target="_parent">
@ -189,20 +189,20 @@
</a>
<path d="M534.292 590.095 583.792 590.095 633.292 590.095 633.292 591.095 583.792 591.095 534.292 591.095Z" fill="#FFFFFF" fill-rule="evenodd"/>
<rect x="471" y="614" width="226" height="90" stroke="#FFFFFF" stroke-width="2" stroke-miterlimit="8" fill="#70AD47"/>
<a href="_autosummary/econml.ortho_forest.html#econml.ortho_forest.ContinuousTreatmentOrthoForest" target="_parent">
<text fill="#FFFFFF" font-family="Calibri,Calibri_MSFontService,sans-serif" font-weight="400" font-size="15" transform="translate(482.126 638)">ContinuousTreatmentOrthoForest</text>
<a href="_autosummary/econml.ortho_forest.html#econml.ortho_forest.DMLOrthoForest" target="_parent">
<text fill="#FFFFFF" font-family="Calibri,Calibri_MSFontService,sans-serif" font-weight="400" font-size="15" transform="translate(535 638)">DMLOrthoForest</text>
</a>
<path d="M482.292 639.866 533.042 639.866 583.792 639.866 634.542 639.866 685.292 639.866 685.292 640.866 634.542 640.866 583.792 640.866 533.042 640.866 482.292 640.866Z" fill="#FFFFFF" fill-rule="evenodd"/>
<a href="_autosummary/econml.ortho_forest.html#econml.ortho_forest.DiscreteTreatmentOrthoForest" target="_parent">
<text fill="#FFFFFF" font-family="Calibri,Calibri_MSFontService,sans-serif" font-weight="400" font-size="15" transform="translate(491.539 655)">DiscreteTreatmentOrthoForest</text>
<a href="_autosummary/econml.ortho_forest.html#econml.ortho_forest.DROrthoForest" target="_parent">
<text fill="#FFFFFF" font-family="Calibri,Calibri_MSFontService,sans-serif" font-weight="400" font-size="15" transform="translate(540 655)">DROrthoForest</text>
</a>
<path d="M491.292 656.866 537.292 656.866 583.292 656.866 629.292 656.866 675.292 656.866 675.292 657.866 629.292 657.866 583.292 657.866 537.292 657.866 491.292 657.866Z" fill="#FFFFFF" fill-rule="evenodd"/>
<a href="_autosummary/econml.drlearner.html#econml.drlearner.ForestDRLearner" target="_parent">
<text fill="#FFFFFF" font-family="Calibri,Calibri_MSFontService,sans-serif" font-weight="400" font-size="15" transform="translate(533.539 673)">ForestDRLearner</text>
</a>
<path d="M533.292 674.866 583.292 674.866 633.292 674.866 633.292 675.866 583.292 675.866 533.292 675.866Z" fill="#FFFFFF" fill-rule="evenodd"/>
<a href="_autosummary/econml.dml.html#econml.dml.ForestDMLCateEstimator" target="_parent">
<text fill="#FFFFFF" font-family="Calibri,Calibri_MSFontService,sans-serif" font-weight="400" font-size="15" transform="translate(508.959 691)">ForestDMLCateEstimator</text>
<a href="_autosummary/econml.dml.html#econml.dml.CausalForestDML" target="_parent">
<text fill="#FFFFFF" font-family="Calibri,Calibri_MSFontService,sans-serif" font-weight="400" font-size="15" transform="translate(531 691)">CausalForestDML</text>
</a>
<path d="M509.292 692.866 558.959 692.866 608.626 692.866 658.292 692.866 658.292 693.866 608.626 693.866 558.959 693.866 509.292 693.866Z" fill="#FFFFFF" fill-rule="evenodd"/>
<rect x="471" y="716" width="226" height="131" stroke="#FFFFFF" stroke-width="2" stroke-miterlimit="8" fill="#70AD47"/>
@ -214,12 +214,12 @@
<text fill="#FFFFFF" font-family="Calibri,Calibri_MSFontService,sans-serif" font-weight="400" font-size="15" transform="translate(552.209 769)">DRLearner</text>
</a>
<path d="M552.296 770.513 583.796 770.513 615.296 770.513 615.296 771.513 583.796 771.513 552.296 771.513Z" fill="#FFFFFF" fill-rule="evenodd"/>
<a href="_autosummary/econml.dml.html#econml.dml.DMLCateEstimator" target="_parent">
<text fill="#FFFFFF" font-family="Calibri,Calibri_MSFontService,sans-serif" font-weight="400" font-size="15" transform="translate(527.629 787)">DMLCateEstimator</text>
<a href="_autosummary/econml.dml.html#econml.dml.DML" target="_parent">
<text fill="#FFFFFF" font-family="Calibri,Calibri_MSFontService,sans-serif" font-weight="400" font-size="15" transform="translate(570 787)">DML</text>
</a>
<path d="M527.296 788.513 564.629 788.513 601.962 788.513 639.296 788.513 639.296 789.513 601.962 789.513 564.629 789.513 527.296 789.513Z" fill="#FFFFFF" fill-rule="evenodd"/>
<a href="_autosummary/econml.dml.html#econml.dml.NonParamDMLCateEstimator" target="_parent">
<text fill="#FFFFFF" font-family="Calibri,Calibri_MSFontService,sans-serif" font-weight="400" font-size="15" transform="translate(496.042 805)">NonParamDMLCateEstimator</text>
<a href="_autosummary/econml.dml.html#econml.dml.NonParamDML" target="_parent">
<text fill="#FFFFFF" font-family="Calibri,Calibri_MSFontService,sans-serif" font-weight="400" font-size="15" transform="translate(536 805)">NonParamDML</text>
</a>
<path d="M496.296 806.513 540.046 806.513 583.796 806.513 627.546 806.513 671.296 806.513 671.296 807.513 627.546 807.513 583.796 807.513 540.046 807.513 496.296 807.513Z" fill="#FFFFFF" fill-rule="evenodd"/>
<text fill="#C5E0B4" font-family="Calibri,Calibri_MSFontService,sans-serif" font-style="italic" font-weight="400" font-size="15" transform="translate(499.796 822)">



@ -5,14 +5,11 @@ Public Module Reference
:toctree: _autosummary
econml.bootstrap
econml.cate_estimator
econml.cate_interpreter
econml.causal_forest
econml.causal_tree
econml.deepiv
econml.dgp
econml.dml
econml.drlearner
econml.grf
econml.inference
econml.metalearners
econml.ortho_forest
@ -27,7 +24,12 @@ Private Module Reference
:toctree: _autosummary
econml._ortho_learner
econml._rlearner
econml._cate_estimator
econml._causal_tree
econml.dml._rlearner
econml.grf._base_grf
econml.grf._base_grftree
econml.grf._criterion
Scikit-Learn Extensions
=======================
@ -37,4 +39,3 @@ Scikit-Learn Extensions
econml.sklearn_extensions.linear_model
econml.sklearn_extensions.model_selection
econml.sklearn_extensions.ensemble


@ -27,10 +27,6 @@ The latter translates to estimating a local gradient around a treatment vector c
\partial\tau(\vec{t}, \vec{x}) = \E\left[\nabla_{\vec{t}} Y(\vec{t}) | X=\vec{x}\right] \tag{marginal CATE}
We will refer to the latter as the *heterogeneous marginal effect*. [1]_
Finally, we might not only be interested in the effect but also in the actual *counterfactual prediction*, i.e. estimating the quantity:
.. math ::
\mu(\vec{t}, \vec{x}) = \E\left[Y(\vec{t}) | X=\vec{x}\right] \tag{counterfactual prediction}
We assume we have data that are generated from some collection policy. In particular, we assume that we have data of the form:
:math:`\{Y_i(T_i), T_i, X_i, W_i, Z_i\}`, where :math:`Y_i(T_i)` is the observed outcome for the chosen treatment,
@ -43,6 +39,19 @@ The variables :math:`X_i` can also be thought of as *control* variables, but the
they are a subset of the controls with respect to which we want to measure treatment effect heterogeneity.
We will refer to them as *features*.
Finally, sometimes we might not only be interested in the effect but also in the actual *counterfactual prediction*, i.e. estimating the quantity:
.. math ::
\mu(\vec{t}, \vec{x}) = \E\left[Y(\vec{t}) | X=\vec{x}\right] \tag{counterfactual prediction}
Our package does not offer support for counterfactual prediction. However, for most of our estimators (the ones
assuming a linear-in-treatment model), counterfactual prediction can be easily constructed by combining any baseline predictive model
with our causal effect model, i.e. train any machine learning model :math:`b(\vec{t}, \vec{x})` to solve the regression/classification
problem :math:`\E[Y | T=\vec{t}, X=\vec{x}]`, and then set :math:`\mu(\vec{t}, \vec{x}) = \tau(\vec{t}, T, \vec{x}) + b(T, \vec{x})`,
where :math:`T` is either the observed treatment for that sample under the observational policy or the treatment
that the observational policy would have assigned to that sample. These auxiliary ML models can be trained
with any machine learning package outside of EconML.
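A minimal sketch of that construction under those assumptions (variable names are illustrative; it assumes arrays `Y`, `T`, `X`, `W` as in the surrounding examples, and any regressor can serve as the baseline model :math:`b`):

```Python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from econml.dml import LinearDML

# CATE model tau and baseline model b are trained separately on (Y, T, X, W).
est = LinearDML()
est.fit(Y, T, X=X, W=W)
baseline = GradientBoostingRegressor().fit(np.column_stack([T.reshape(-1, 1), X]), Y)

def counterfactual_prediction(t_new, T_obs, X):
    # mu(t, x) = tau(t, T_obs, x) + b(T_obs, x), valid for linear-in-treatment models
    b_hat = baseline.predict(np.column_stack([T_obs.reshape(-1, 1), X]))
    return est.effect(X, T0=T_obs, T1=t_new) + b_hat
```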
.. rubric::
Structural Equation Formulation


@ -19,13 +19,13 @@ Detailed estimator comparison
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :class:`.LinearDRLearner` | Categorical | | Yes | | Projected | | Yes | |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :class:`.ForestDML` | 1-d/Binary | | Yes | Yes | | Yes | | Yes |
| :class:`.CausalForestDML` | Any | | Yes | Yes | | Yes | Yes | Yes |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :class:`.ForestDRLearner` | Categorical | | Yes | | | | Yes | Yes |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :class:`.ContinuousTreatmentOrthoForest` | Continuous | | Yes | Yes | | | Yes | Yes |
| :class:`.DMLOrthoForest` | Any | | Yes | Yes | | | Yes | Yes |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :class:`.DiscreteTreatmentOrthoForest` | Categorical | | Yes | | | | Yes | Yes |
| :class:`.DROrthoForest` | Categorical | | Yes | | | | Yes | Yes |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :mod:`~econml.metalearners` | Categorical | | | | | Yes | Yes | Yes |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+


@ -34,7 +34,7 @@ What are the relevant estimator classes?
This section describes the methodology implemented in the classes, :class:`._RLearner`,
:class:`.DML`, :class:`.LinearDML`,
:class:`.SparseLinearDML`, :class:`.KernelDML`, :class:`.NonParamDML`,
:class:`.ForestDML`.
:class:`.CausalForestDML`.
Click on each of these links for a detailed module documentation and input parameters of each class.
@ -72,7 +72,7 @@ linear on some pre-defined; potentially high-dimensional; featurization). These
:class:`.DML`, :class:`.LinearDML`,
:class:`.SparseLinearDML`, :class:`.KernelDML`.
For fully non-parametric heterogeneous treatment effect models, check out the :class:`.NonParamDML`
and the :class:`.ForestDML`. For more options of non-parametric CATE estimators,
and the :class:`.CausalForestDML`. For more options of non-parametric CATE estimators,
check out the :ref:`Forest Estimators User Guide <orthoforestuserguide>`
and the :ref:`Meta Learners User Guide <metalearnersuserguide>`.
@ -155,10 +155,10 @@ Class Hierarchy Structure
In this library we implement variants of several of the approaches mentioned in the last section. The hierarchy
structure of the implemented CATE estimators is as follows.
.. inheritance-diagram:: econml.dml.LinearDML econml.dml.SparseLinearDML econml.dml.KernelDML econml.dml.NonParamDML econml.dml.ForestDML
.. inheritance-diagram:: econml.dml.LinearDML econml.dml.SparseLinearDML econml.dml.KernelDML econml.dml.NonParamDML econml.dml.CausalForestDML
:parts: 1
:private-bases:
:top-classes: econml._rlearner._RLearner, econml.cate_estimator.StatsModelsCateEstimatorMixin, econml.cate_estimator.DebiasedLassoCateEstimatorMixin
:top-classes: econml._rlearner._RLearner, econml._cate_estimator.StatsModelsCateEstimatorMixin, econml._cate_estimator.DebiasedLassoCateEstimatorMixin
Below we give a brief description of each of these classes:
@ -267,24 +267,24 @@ Below we give a brief description of each of these classes:
estimator is also a *Meta-Learner*, since all steps of the estimation use out-of-the-box ML algorithms. For more information,
check out :ref:`Meta Learners User Guide <metalearnersuserguide>`.
- **ForestDML.** This is a child of the :class:`.NonParamDML` that uses a Subsampled Honest Forest regressor
as a final model (see [Wager2018]_ and [Athey2019]_). The subsampled honest forest is implemented in the library as a scikit-learn extension
of the :class:`~sklearn.ensemble.RandomForestRegressor`, in the class :class:`.SubsampledHonestForest`. This estimator
offers confidence intervals via the Bootstrap-of-Little-Bags as described in [Athey2019]_. Using this functionality we can
also construct confidence intervals for the CATE:
* **CausalForestDML.** This is a child of the :class:`._RLearner` that uses a Causal Forest
as a final model (see [Wager2018]_ and [Athey2019]_). The Causal Forest is implemented in the library as a scikit-learn
predictor, in the class :class:`.CausalForest`. This estimator
offers confidence intervals via the Bootstrap-of-Little-Bags as described in [Athey2019]_.
Using this functionality we can also construct confidence intervals for the CATE:
.. testcode::
from econml.dml import ForestDML
.. testcode::
from econml.dml import CausalForestDML
from sklearn.ensemble import GradientBoostingRegressor
est = ForestDML(model_y=GradientBoostingRegressor(),
model_t=GradientBoostingRegressor())
est = CausalForestDML(model_y=GradientBoostingRegressor(),
model_t=GradientBoostingRegressor())
est.fit(y, t, X=X, W=W)
point = est.effect(X, T0=t0, T1=t1)
lb, ub = est.effect_interval(X, T0=t0, T1=t1, alpha=0.05)
Check out :ref:`Forest Estimators User Guide <orthoforestuserguide>` for more information on forest based CATE models and other
alternatives to the :class:`.ForestDML`.
Check out :ref:`Forest Estimators User Guide <orthoforestuserguide>` for more information on forest based CATE models and other
alternatives to the :class:`.CausalForestDML`.
* **_RLearner.** The internal private class :class:`._RLearner` is a parent of the :class:`.DML`
and allows the user to specify any way of fitting a final model that takes as input the residual :math:`\tilde{T}`,
@ -320,7 +320,7 @@ Usage FAQs
lb, ub = est.effect_interval(X, T0=T0, T1=T1, alpha=.05)
If you have a single dimensional continuous treatment or a binary treatment, then you can also fit non-linear
models and have confidence intervals by using the :class:`.ForestDML`. This class will also
models and have confidence intervals by using the :class:`.CausalForestDML`. This class will also
perform well with high dimensional features, as long as only a few of these features are actually relevant.
- **Why not just run a simple big linear regression with all the treatments, features and controls?**
@ -356,7 +356,7 @@ Usage FAQs
1) If effect heterogeneity does not have a linear form, then this approach is not valid.
One might want to then create more complex featurization, in which case the problem could
become too high-dimensional for OLS. Our :class:`.SparseLinearDML`
can handle such settings via the use of the debiased Lasso. Our :class:`.ForestDML` does not
can handle such settings via the use of the debiased Lasso. Our :class:`.CausalForestDML` does not
even need explicit featurization and learns non-linear forest based CATE models, automatically. Also see the
:ref:`Forest Estimators User Guide <orthoforestuserguide>` and the :ref:`Meta Learners User Guide <metalearnersuserguide>`,
if you want even more flexible CATE models.
@ -378,15 +378,15 @@ Usage FAQs
est.fit(y, T, X=X, W=W)
lb, ub = est.const_marginal_effect_interval(X, alpha=.05)
Alternatively, you can also use a forest based estimator such as :class:`.ForestDML`. This
Alternatively, you can also use a forest based estimator such as :class:`.CausalForestDML`. This
estimator can also handle many features, albeit typically a smaller number of features than the sparse linear DML.
Moreover, this estimator essentially performs automatic featurization and can fit non-linear models.
.. testcode::
from econml.dml import ForestDML
from econml.dml import CausalForestDML
from sklearn.ensemble import GradientBoostingRegressor
est = ForestDML(model_y=GradientBoostingRegressor(),
est = CausalForestDML(model_y=GradientBoostingRegressor(),
model_t=GradientBoostingRegressor())
est.fit(y, t, X=X, W=W)
lb, ub = est.const_marginal_effect_interval(X, alpha=.05)
@ -396,7 +396,7 @@ Usage FAQs
- **What if I have too many features that can create heterogeneity?**
Use the :class:`.SparseLinearDML` or :class:`.ForestDML` (see above).
Use the :class:`.SparseLinearDML` or :class:`.CausalForestDML` (see above).
- **What if I have too many features I want to control for?**


@ -198,7 +198,7 @@ structure of the implemented CATE estimators is as follows.
.. inheritance-diagram:: econml.drlearner.DRLearner econml.drlearner.LinearDRLearner econml.drlearner.SparseLinearDRLearner econml.drlearner.ForestDRLearner
:parts: 1
:private-bases:
:top-classes: econml._ortho_learner._OrthoLearner, econml.cate_estimator.StatsModelsCateEstimatorDiscreteMixin, econml.cate_estimator.DebiasedLassoCateEstimatorDiscreteMixin
:top-classes: econml._ortho_learner._OrthoLearner, econml._cate_estimator.StatsModelsCateEstimatorDiscreteMixin, econml._cate_estimator.DebiasedLassoCateEstimatorDiscreteMixin
Below we give a brief description of each of these classes:


@ -12,15 +12,15 @@ This section describes the different estimation methods provided in the package
to model the treatment effect heterogeneity. We collect these methods in a single user guide to better illustrate
their comparisons and differences. Currently, our package offers three such estimation methods:
* The Orthogonal Random Forest Estimator (see :class:`.ContinuousTreatmentOrthoForest`, :class:`.DiscreteTreatmentOrthoForest`)
* The Forest Double Machine Learning Estimator (aka Causal Forest) (see :class:`.ForestDML`)
* The Orthogonal Random Forest Estimator (see :class:`.DMLOrthoForest`, :class:`.DROrthoForest`)
* The Forest Double Machine Learning Estimator (aka Causal Forest) (see :class:`.CausalForestDML`)
* The Forest Doubly Robust Estimator (see :class:`.ForestDRLearner`).
These estimators, similar to those in the DML and DR sections, require the unconfoundedness assumption, i.e. that all potential
variables that could simultaneously have affected the treatment and the outcome are observed.
There are many commonalities among these estimators. In particular the :class:`.ContinuousTreatmentOrthoForest` shares
many similarities with the :class:`.ForestDML` and the :class:`.DiscreteTreatmentOrthoForest` shares
There are many commonalities among these estimators. In particular the :class:`.DMLOrthoForest` shares
many similarities with the :class:`.CausalForestDML` and the :class:`.DROrthoForest` shares
many similarities with the :class:`.ForestDRLearner`. Specifically, the corresponding classes use the same estimating (moment)
equations to identify the heterogeneous treatment effect. However, they differ in a substantial manner in how they
estimate the first stage regression/classification (nuisance) models. In particular, the OrthoForest methods fit
@ -40,9 +40,9 @@ and uses the same metric for the local fitting of the nuisances as well as the f
What are the relevant estimator classes?
========================================
This section describes the methodology implemented in the classes, :class:`.ContinuousTreatmentOrthoForest`,
:class:`.DiscreteTreatmentOrthoForest`,
:class:`.ForestDML`, :class:`.ForestDRLearner`.
This section describes the methodology implemented in the classes, :class:`.DMLOrthoForest`,
:class:`.DROrthoForest`,
:class:`.CausalForestDML`, :class:`.ForestDRLearner`.
Click on each of these links for a detailed module documentation and input parameters of each class.
@ -69,7 +69,7 @@ the heterogeneous treatment effect :math:`\theta(X)`, on a lower dimensional set
Moreover, the estimates are asymptotically normal and hence have theoretical properties
that render bootstrap based confidence intervals asymptotically valid.
In the case of continuous treatments (see :class:`.ContinuousTreatmentOrthoForest`) the method estimates :math:`\theta(x)`
For continuous or discrete treatments (see :class:`.DMLOrthoForest`) the method estimates :math:`\theta(x)`
for some target :math:`x` by solving the same set of moment equations as the ones used in the Double Machine Learning
framework, albeit, it tries to solve them locally for every possible :math:`X=x`. The method makes the following
structural equations assumptions on the data generating process:
@ -165,15 +165,15 @@ some extensions to the scikit-learn library that enable sample weights, such as
.. testcode:: intro
:hide:
from econml.ortho_forest import ContinuousTreatmentOrthoForest
from econml.ortho_forest import DMLOrthoForest
from econml.sklearn_extensions.linear_model import WeightedLasso
.. doctest:: intro
>>> est = ContinuousTreatmentOrthoForest(model_Y=WeightedLasso(), model_T=WeightedLasso())
>>> est = DMLOrthoForest(model_Y=WeightedLasso(), model_T=WeightedLasso())
In the case of discrete treatments (see :class:`.DiscreteTreatmentOrthoForest`) the
In the case of discrete treatments (see :class:`.DROrthoForest`) the
method estimates :math:`\theta(x)` for some target :math:`x` by solving a slightly different
set of equations, similar to the Doubly Robust Learner (see [Oprescu2019]_ for a theoretical exposition of why a different set of
estimating equations is used). In particular, suppose that the treatment :math:`T` takes
@ -213,14 +213,14 @@ a multi-class classification model and should support :code:`predict_proba`.
For more details on the input parameters of the orthogonal forest classes and how to customize
the estimator checkout the two modules:
- :class:`.DiscreteTreatmentOrthoForest`
- :class:`.ContinuousTreatmentOrthoForest`
- :class:`.DROrthoForest`
- :class:`.DMLOrthoForest`
CausalForest (aka Forest Double Machine Learning)
--------------------------------------------------
In this package we implement the double machine learning version of Causal Forests/Generalized Random Forests (see [Wager2018]_, [Athey2019]_)
as for instance described in Section 6.1.1 of [Athey2019]_. This version follows a similar structure to the ContinuousTreatmentOrthoForest approach,
as for instance described in Section 6.1.1 of [Athey2019]_. This version follows a similar structure to the DMLOrthoForest approach,
in that the estimation is based on solving a local residual on residual moment condition:
.. math::
@ -239,23 +239,12 @@ This difference can potentially lead to an improvement in the estimation error o
Causal Forest. However, it does add significant computation cost, as a nuisance function needs to be estimated locally
for each target prediction.
Our implementation of a Causal Forest is restricted to binary treatment or single-dimensional continuous treatment
and is based on an extra observation that for such settings, we can view the local square loss above as a normal regression
square loss with sample weights, i.e.:
Our implementation of a Causal Forest allows for any number of continuous treatments or a multi-valued discrete
treatment. The causal forest is implemented in :class:`.CausalForest` in a high-performance Cython implementation
as a scikit-learn predictor.
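A hedged sketch of calling the new low-level forest directly, assuming the `econml.grf.CausalForest` fit/predict interface introduced in this PR (in practice `CausalForestDML` handles the residualization and cross-fitting for you):

```Python
import numpy as np
from econml.grf import CausalForest

# In the DML wrapper these would be residualized treatments/outcomes.
X = np.random.normal(size=(2000, 4))
T = np.random.normal(size=(2000, 1))
y = T[:, 0] * (X[:, 0] > 0) + np.random.normal(size=2000)

forest = CausalForest(n_estimators=400, criterion='het', min_samples_leaf=10)
forest.fit(X, T, y)
theta = forest.predict(X[:5])  # local effect estimates theta(x)
```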
.. math::
\hat{\theta}(x) = \argmin_{\theta} \sum_{i=1}^n K_x(X_i)\cdot \tilde{T}_i^2 \cdot \left( \tilde{Y}_i/\tilde{T}_i - \theta\right)^2
where :math:`\tilde{T}_i = T_i - \hat{f}(X_i, W_i)` and :math:`\tilde{Y}_i = Y_i - \hat{q}(X_i, W_i)`. Thus we can apply
a normal regression forest to estimate the :math:`\theta`. Albeit for valid confidence intervals we need a forest
that is based on subsampling and uses honesty to define the leaf estimates. Thus we can re-use the splitting machinery
of a scikit-learn regressor and augment it with honesty and subsampling capabilities. We implement this in our
:class:`.SubsampledHonestForest` scikit-learn extension.
The causal criterion that is implicit in the above reduction approach is slightly different than the one
proposed in [Athey2019]_. However, the exact criterion is not crucial for the theoretical developments and the
validity of the confidence intervals is maintained. The difference can potentially lead to small finite sample
Apart from the criterion proposed in [Athey2019]_ we also implemented an MSE criterion that penalizes splits
with low variance in the treatment. The difference can potentially lead to small finite sample
differences. In particular, suppose that we want to decide how to split a node in two subsets of samples :math:`S_1`
and :math:`S_2` and let :math:`\theta_1` and :math:`\theta_2` be the estimates on each of these partitions.
Then the criterion implicit in the reduction is the weighted mean squared error, which boils down to
@ -268,30 +257,13 @@ Then the criterion implicit in the reduction is the weighted mean squared error,
where :math:`Var_n` denotes the empirical variance. Essentially, this criterion tries to maximize heterogeneity
(as captured by maximizing the sum of squares of the two estimates), while penalizing splits that create nodes
with small variation in the treatment. On the contrary, the criterion proposed in [Athey2019]_ ignores the within
child variation of the treatment and solely maximizes the heterogeneity, i.e. :math:`\max_{S_1, S_2} \theta_1^2 + \theta_2^2`.
Moreover, a subtle point is that in order to mirror the Generalized Random Forest algorithm, our final prediction is not just
the average of the tree estimates. Instead we use the tree to define sample weights as described in [Athey2019]_ and then
calculate the solution to the weighted moment equation or equivalently the minimizer of the square loss, which boils down to:
child variation of the treatment and solely maximizes the heterogeneity, i.e.
.. math::
\hat{\theta}(x) = \frac{\sum_{i=1}^{n} K_x(X_i) \cdot \tilde{Y}_i \cdot \tilde{T}_i}{\sum_{i=1}^n K_x(X_i) \cdot \tilde{T}_i^2}
\max_{S_1, S_2} \theta_1^2 + \theta_2^2
From our reduction perspective, this is equivalent to saying that we will train a regression forest with sample weights
:math:`k_i`, features :math:`X_i` and labels :math:`Y_i` and then in the end, we will define the overall estimate at some target :math:`x`, as:
.. math::
\hat{\theta}(x) =~& \frac{\sum_{b=1}^B \sum_{i=1}^n w_{bi}\cdot Y_i}{\sum_{b=1}^B \sum_{i=1}^n w_{bi}}\\
w_{bi} =~& \frac{k_i\cdot 1\{i \in L_{b}(x)\}}{|L_b(x)|}
where :math:`L_b(x)` is the leaf the sample :math:`x` falls into in the :math:`b`-th tree of the forest.
This is exactly what is implemented in the SubsampledHonestForest (see :class:`.SubsampledHonestForest`). Combining
these ideas leads to a "reduction-based" implementation of the Causal Forest that re-uses and only slightly modifies
existing implementations of regression forests.
For more details on Double Machine Learning and how the :class:`.ForestDML` fits into our overall
For more details on Double Machine Learning and how the :class:`.CausalForestDML` fits into our overall
set of DML based CATE estimators, check out the :ref:`Double Machine Learning User Guide <dmluserguide>`.
Forest Doubly Robust Learner
@ -302,7 +274,7 @@ The Forest Doubly Robust Learner is a variant of the Generalized Random Forest a
to the double machine learning moments (see the :ref:`Doubly Robust Learning User Guide <druserguide>`).
The method only applies for categorical treatments.
Essentially, it is an analogue of the :class:`.DiscreteTreatmentOrthoForest`, that instead of local nuisance estimation
Essentially, it is an analogue of the :class:`.DROrthoForest`, that instead of local nuisance estimation
it conducts global nuisance estimation and does not couple the implicit similarity metric used for the nuisance
estimates, with the final stage similarity metric.
@ -325,22 +297,22 @@ manner (see e.g. :class:`._OrthoLearner` for more details on cross fitting).
The similarity metric :math:`K_x(X_i)` is trained in a data-adaptive manner by constructing a Subsampled Honest Random Regression Forest
where the target label is :math:`Y_{i, t}^{DR} - Y_{i, 0}^{DR}` and the features are :math:`X` and roughly calculating
how frequently sample :math:`x` falls in the same leaf as
sample :math:`X_i`. This is implemented in the SubsampledHonestForest (see :class:`.SubsampledHonestForest`).
sample :math:`X_i`. This is implemented in the RegressionForest (see :class:`.RegressionForest`).
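A hedged sketch of the corresponding estimator call (assuming the `ForestDRLearner` API documented in this guide; `y`, `T`, `X`, `W` are as in the surrounding examples and `T` must be categorical):

```Python
from econml.drlearner import ForestDRLearner
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor

est = ForestDRLearner(model_regression=GradientBoostingRegressor(),
                      model_propensity=GradientBoostingClassifier())
est.fit(y, T, X=X, W=W)
point = est.effect(X)
lb, ub = est.effect_interval(X, alpha=0.05)  # Bootstrap-of-Little-Bags intervals
```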
Class Hierarchy Structure
=========================
.. inheritance-diagram:: econml.ortho_forest.ContinuousTreatmentOrthoForest econml.ortho_forest.DiscreteTreatmentOrthoForest econml.drlearner.ForestDRLearner econml.dml.ForestDML
.. inheritance-diagram:: econml.ortho_forest.DMLOrthoForest econml.ortho_forest.DROrthoForest econml.drlearner.ForestDRLearner econml.dml.CausalForestDML
:parts: 1
:private-bases:
:top-classes: econml._ortho_learner._OrthoLearner, econml.ortho_forest.BaseOrthoForest, econml.cate_estimator.LinearCateEstimator
:top-classes: econml._ortho_learner._OrthoLearner, econml.ortho_forest.BaseOrthoForest, econml._cate_estimator.LinearCateEstimator
Usage Examples
==================================
Here is a simple example of how to call :class:`.ContinuousTreatmentOrthoForest`
Here is a simple example of how to call :class:`.DMLOrthoForest`
and what the returned values correspond to in a simple data generating process.
For more examples check out our
`OrthoForest Jupyter notebook <https://github.com/Microsoft/EconML/blob/master/notebooks/Orthogonal%20Random%20Forest%20Examples.ipynb>`_
@ -351,30 +323,30 @@ and the `ForestLearners Jupyter notebook <https://github.com/microsoft/EconML/bl
import numpy as np
import sklearn
from econml.ortho_forest import ContinuousTreatmentOrthoForest, DiscreteTreatmentOrthoForest
from econml.ortho_forest import DMLOrthoForest, DROrthoForest
np.random.seed(123)
>>> T = np.array([0, 1]*60)
>>> W = np.array([0, 1, 1, 0]*30).reshape(-1, 1)
>>> Y = (.2 * W[:, 0] + 1) * T + .5
>>> est = ContinuousTreatmentOrthoForest(n_trees=1, max_depth=1, subsample_ratio=1,
... model_T=sklearn.linear_model.LinearRegression(),
... model_Y=sklearn.linear_model.LinearRegression())
>>> est = DMLOrthoForest(n_trees=1, max_depth=1, subsample_ratio=1,
... model_T=sklearn.linear_model.LinearRegression(),
... model_Y=sklearn.linear_model.LinearRegression())
>>> est.fit(Y, T, X=W, W=W)
<econml.ortho_forest.ContinuousTreatmentOrthoForest object at 0x...>
<econml.ortho_forest.DMLOrthoForest object at 0x...>
>>> print(est.effect(W[:2]))
[1.00... 1.19...]
Similarly, we can call :class:`.DiscreteTreatmentOrthoForest`:
Similarly, we can call :class:`.DROrthoForest`:
>>> T = np.array([0, 1]*60)
>>> W = np.array([0, 1, 1, 0]*30).reshape(-1, 1)
>>> Y = (.2 * W[:, 0] + 1) * T + .5
>>> est = DiscreteTreatmentOrthoForest(n_trees=1, max_depth=1, subsample_ratio=1,
... propensity_model=sklearn.linear_model.LogisticRegression(),
... model_Y=sklearn.linear_model.LinearRegression())
>>> est = DROrthoForest(n_trees=1, max_depth=1, subsample_ratio=1,
... propensity_model=sklearn.linear_model.LogisticRegression(),
... model_Y=sklearn.linear_model.LinearRegression())
>>> est.fit(Y, T, X=W, W=W)
<econml.ortho_forest.DiscreteTreatmentOrthoForest object at 0x...>
<econml.ortho_forest.DROrthoForest object at 0x...>
>>> print(est.effect(W[:2]))
[0.99... 1.35...]
@ -383,8 +355,8 @@ and with more realistic noisy data. In this case we can just use the default par
of the class, which specify the use of the :class:`~sklearn.linear_model.LassoCV` for
both the treatment and the outcome regressions, in the case of continuous treatments.
>>> from econml.ortho_forest import ContinuousTreatmentOrthoForest
>>> from econml.ortho_forest import ContinuousTreatmentOrthoForest
>>> from econml.ortho_forest import DMLOrthoForest
>>> from econml.ortho_forest import DMLOrthoForest
>>> from econml.sklearn_extensions.linear_model import WeightedLasso
>>> import matplotlib.pyplot as plt
>>> np.random.seed(123)
@ -393,12 +365,12 @@ both the treatment and the outcome regressions, in the case of continuous treatm
>>> support = np.random.choice(50, 4, replace=False)
>>> T = np.dot(W[:, support], np.random.normal(size=4)) + np.random.normal(size=4000)
>>> Y = np.exp(2*X[:, 0]) * T + np.dot(W[:, support], np.random.normal(size=4)) + .5
>>> est = ContinuousTreatmentOrthoForest(n_trees=100,
... max_depth=5,
... model_Y=WeightedLasso(alpha=0.01),
... model_T=WeightedLasso(alpha=0.01))
>>> est = DMLOrthoForest(n_trees=100,
... max_depth=5,
... model_Y=WeightedLasso(alpha=0.01),
... model_T=WeightedLasso(alpha=0.01))
>>> est.fit(Y, T, X=X, W=W)
<econml.ortho_forest.ContinuousTreatmentOrthoForest object at 0x...>
<econml.ortho_forest.DMLOrthoForest object at 0x...>
>>> X_test = np.linspace(-1, 1, 30).reshape(-1, 1)
>>> treatment_effects = est.effect(X_test)
>>> plt.plot(X_test[:, 0], treatment_effects, label='ORF estimate')


@ -138,7 +138,7 @@ Class Hierarchy Structure
.. inheritance-diagram:: econml.metalearners.SLearner econml.metalearners.TLearner econml.metalearners.XLearner econml.metalearners.DomainAdaptationLearner econml.drlearner.DRLearner econml.dml.DML
:parts: 1
:private-bases:
:top-classes: econml._ortho_learner._OrthoLearner, econml.cate_estimator.LinearCateEstimator, econml.cate_estimator.TreatmentExpansionMixin
:top-classes: econml._ortho_learner._OrthoLearner, econml._cate_estimator.LinearCateEstimator, econml._cate_estimator.TreatmentExpansionMixin
Usage Examples


@ -110,16 +110,16 @@ Subsampled Honest Forest Inference
For estimators where the final stage CATE estimate is a non-parametric model based on a Random Forest, we offer
confidence intervals via the bootstrap-of-little-bags approach (see [Athey2019]_) for estimating the uncertainty of
an Honest Random Forest. This for instance holds for the :class:`.ForestDML`
an Honest Random Forest. This for instance holds for the :class:`.CausalForestDML`
and the :class:`.ForestDRLearner`. Such intervals are enabled by leaving inference at its default setting of ``'auto'``
or by explicitly setting ``inference='blb'``, e.g.:
.. testcode::
from econml.dml import ForestDML
from econml.dml import CausalForestDML
from sklearn.ensemble import RandomForestRegressor
est = ForestDML(model_y=RandomForestRegressor(n_estimators=10, min_samples_leaf=10),
model_t=RandomForestRegressor(n_estimators=10, min_samples_leaf=10))
est = CausalForestDML(model_y=RandomForestRegressor(n_estimators=10, min_samples_leaf=10),
model_t=RandomForestRegressor(n_estimators=10, min_samples_leaf=10))
est.fit(y, t, X=X, W=W)
point = est.const_marginal_effect(X)
lb, ub = est.const_marginal_effect_interval(X, alpha=0.05)
@ -141,19 +141,19 @@ This inference is enabled by our implementation of the :class:`.SubsampledHonest
OrthoForest Bootstrap of Little Bags Inference
==============================================
For the Orthogonal Random Forest estimators (see :class:`.ContinuousTreatmentOrthoForest`, :class:`.DiscreteTreatmentOrthoForest`),
For the Orthogonal Random Forest estimators (see :class:`.DMLOrthoForest`, :class:`.DROrthoForest`),
we provide confidence intervals built via the bootstrap-of-little-bags approach ([Athey2019]_). This technique is well suited for
estimating the uncertainty of the honest causal forests underlying the OrthoForest estimators. Such intervals are enabled by leaving
inference at its default setting of ``'auto'`` or by explicitly setting ``inference='blb'``, e.g.:
.. testcode::
from econml.ortho_forest import ContinuousTreatmentOrthoForest
from econml.ortho_forest import DMLOrthoForest
from econml.sklearn_extensions.linear_model import WeightedLasso
est = ContinuousTreatmentOrthoForest(n_trees=10,
min_leaf_size=3,
model_T=WeightedLasso(alpha=0.01),
model_Y=WeightedLasso(alpha=0.01))
est = DMLOrthoForest(n_trees=10,
min_leaf_size=3,
model_T=WeightedLasso(alpha=0.01),
model_Y=WeightedLasso(alpha=0.01))
est.fit(y, t, X=X, W=W)
point = est.const_marginal_effect(X)
lb, ub = est.const_marginal_effect_interval(X, alpha=0.05)

econml/__init__.py Normal file

@ -0,0 +1,7 @@
__all__ = ['automated_ml', 'bootstrap',
'cate_interpreter', 'causal_forest',
'data', 'deepiv', 'dml', 'drlearner', 'inference',
'metalearners', 'ortho_forest', 'ortho_iv',
'sklearn_extensions', 'tree',
'two_stage_least_squares', 'utilities']


@ -0,0 +1,37 @@
"""
Utilities useful during the build.
"""
import os
import sklearn
import contextlib
from distutils.version import LooseVersion
CYTHON_MIN_VERSION = '0.28.5'
def _check_cython_version():
message = ('Please install Cython with a version >= {0} in order '
'to build a scikit-learn from source.').format(
CYTHON_MIN_VERSION)
try:
import Cython
except ModuleNotFoundError:
# Re-raise with more informative error message instead:
raise ModuleNotFoundError(message)
if LooseVersion(Cython.__version__) < CYTHON_MIN_VERSION:
message += (' The current version of Cython is {} installed in {}.'
.format(Cython.__version__, Cython.__path__))
raise ValueError(message)
def cythonize_extensions(top_path, config):
"""Check that a recent Cython is available and cythonize extensions"""
_check_cython_version()
from Cython.Build import cythonize
config.ext_modules = cythonize(
config.ext_modules,
compiler_directives={'language_level': 3})
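
For context, a hedged sketch of how a `numpy.distutils`-based `setup.py` might call this helper (illustrative only; the extension name and sources here are hypothetical, and the actual build configuration lives elsewhere in this PR):

```Python
# Hypothetical excerpt of a setup.py-style configuration using the helper above.
from numpy.distutils.misc_util import Configuration

def configuration(parent_package='', top_path=None):
    config = Configuration('econml', parent_package, top_path)
    # Register a Cython extension (name/sources are illustrative).
    config.add_extension('grf._criterion', sources=['econml/grf/_criterion.pyx'])
    cythonize_extensions(top_path, config)  # verifies Cython >= 0.28.5, then cythonizes
    return config
```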


@ -10,12 +10,11 @@ from copy import deepcopy
from warnings import warn
from .inference import BootstrapInference
from .utilities import (tensordot, ndim, reshape, shape, parse_final_model_params,
inverse_onehot, Summary, get_input_columns, broadcast_unit_treatments,
cross_product)
inverse_onehot, Summary, get_input_columns)
from .inference import StatsModelsInference, StatsModelsInferenceDiscrete, LinearModelFinalInference,\
LinearModelFinalInferenceDiscrete, NormalInferenceResults, GenericSingleTreatmentModelFinalInference,\
GenericModelFinalInferenceDiscrete
from .shap import _shap_explain_cme, _define_names, _shap_explain_joint_linear_model_cate
from ._shap import _shap_explain_cme, _shap_explain_joint_linear_model_cate
class BaseCateEstimator(metaclass=abc.ABCMeta):
@ -458,7 +457,7 @@ class LinearCateEstimator(BaseCateEstimator):
"""
pass
def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None):
def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None, background_samples=100):
""" Shap value for the final stage models (const_marginal_effect)
Parameters
@ -472,18 +471,23 @@ class LinearCateEstimator(BaseCateEstimator):
the baseline treatment (i.e. the control treatment, which by default is the alphabetically smaller)
output_names: optional None or list (Default=None)
The name of the outcome.
background_samples: int or None, (Default=100)
How many samples to use to compute the baseline effect. If None then all samples are used.
Returns
-------
shap_outs: nested dictionary of Explanation object
A nested dictionary by using each output name (e.g. "Y0" when `output_names=None`) and
each treatment name (e.g. "T0" when `treatment_names=None`) as key
and the shap_values explanation object as value.
A nested dictionary by using each output name (e.g. 'Y0', 'Y1', ... when `output_names=None`) and
each treatment name (e.g. 'T0', 'T1', ... when `treatment_names=None`) as key
and the shap_values explanation object as value. If the input data at fit time also contain metadata
(e.g. are pandas DataFrames), then the column metadata for the treatments, outcomes and features
are used instead of the above defaults (unless the user overrides by explicitly passing the
corresponding names).
"""
return _shap_explain_cme(self.const_marginal_effect, X, self._d_t, self._d_y, feature_names, treatment_names,
output_names)
return _shap_explain_cme(self.const_marginal_effect, X, self._d_t, self._d_y,
feature_names=feature_names, treatment_names=treatment_names,
output_names=output_names, input_names=self._input_names,
background_samples=background_samples)
class TreatmentExpansionMixin(BaseCateEstimator):
@ -667,7 +671,6 @@ class LinearModelFinalCateEstimatorMixin(BaseCateEstimator):
converted to various output formats.
"""
# Get input names
feature_names = self.cate_feature_names() if feature_names is None else feature_names
treatment_names = self._input_names["treatment_names"] if treatment_names is None else treatment_names
output_names = self._input_names["output_names"] if output_names is None else output_names
# Summary
@ -714,17 +717,16 @@ class LinearModelFinalCateEstimatorMixin(BaseCateEstimator):
if len(smry.tables) > 0:
return smry
def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None):
(dt, dy, treatment_names, output_names) = _define_names(self._d_t, self._d_y, treatment_names, output_names)
def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None, background_samples=100):
if hasattr(self, "featurizer") and self.featurizer is not None:
X = self.featurizer.transform(X)
X, T = broadcast_unit_treatments(X, dt)
d_x = X.shape[1]
X_new = cross_product(X, T)
feature_names = self.cate_feature_names(feature_names)
return _shap_explain_joint_linear_model_cate(self.model_final, X_new, T, dt, dy, self.fit_cate_intercept,
return _shap_explain_joint_linear_model_cate(self.model_final, X, self._d_t, self._d_y,
self.fit_cate_intercept,
feature_names=feature_names, treatment_names=treatment_names,
output_names=output_names)
output_names=output_names,
input_names=self._input_names,
background_samples=background_samples)
shap_values.__doc__ = LinearCateEstimator.shap_values.__doc__


@ -8,10 +8,7 @@ class :class:`Node` represents the core unit of the :class:`CausalTree` class.
"""
import numpy as np
import warnings
from sklearn.model_selection import train_test_split
from sklearn.utils import check_random_state
import scipy.special
class Node:
@ -96,18 +93,11 @@ class CausalTree:
"""
def __init__(self,
nuisance_estimator,
parameter_estimator,
moment_and_mean_gradient_estimator,
min_leaf_size=10,
max_depth=10,
n_proposals=1000,
balancedness_tol=.3,
random_state=None):
# Estimators
self.nuisance_estimator = nuisance_estimator
self.parameter_estimator = parameter_estimator
self.moment_and_mean_gradient_estimator = moment_and_mean_gradient_estimator
# Causal tree parameters
self.min_leaf_size = min_leaf_size
self.max_depth = max_depth
@ -117,7 +107,8 @@ class CausalTree:
# Tree structure
self.tree = None
def create_splits(self, Y, T, X, W):
def create_splits(self, Y, T, X, W,
nuisance_estimator, parameter_estimator, moment_and_mean_gradient_estimator):
"""
Recursively build a causal tree.
@ -158,17 +149,17 @@ class CausalTree:
node_size_est = node_X_estimate.shape[0]
# Compute nuisance estimates for the current node
nuisance_estimates = self.nuisance_estimator(node_Y, node_T, node_X, node_W)
nuisance_estimates = nuisance_estimator(node_Y, node_T, node_X, node_W)
if nuisance_estimates is None:
# Nuisance estimate cannot be calculated
continue
# Estimate parameter for current node
node_estimate = self.parameter_estimator(node_Y, node_T, node_X, nuisance_estimates)
node_estimate = parameter_estimator(node_Y, node_T, node_X, nuisance_estimates)
if node_estimate is None:
# Node estimate cannot be calculated
continue
# Calculate moments and gradient of moments for current data
moments, mean_grad = self.moment_and_mean_gradient_estimator(
moments, mean_grad = moment_and_mean_gradient_estimator(
node_Y, node_T, node_X, node_W,
nuisance_estimates,
node_estimate)


@ -34,8 +34,8 @@ from sklearn.preprocessing import (FunctionTransformer, LabelEncoder,
OneHotEncoder)
from sklearn.utils import check_random_state
from .cate_estimator import (BaseCateEstimator, LinearCateEstimator,
TreatmentExpansionMixin)
from ._cate_estimator import (BaseCateEstimator, LinearCateEstimator,
TreatmentExpansionMixin)
from .utilities import (_deprecate_positional, _EncoderWrapper, check_input_arrays,
cross_product, filter_none_kwargs,
inverse_onehot, ndim, reshape, shape, transpose)
@ -291,7 +291,7 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`cv splitter`
- :term:`CV splitter`
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the treatment is discrete


@ -4,9 +4,12 @@
import shap
from collections import defaultdict
import numpy as np
from .utilities import broadcast_unit_treatments, cross_product
def _shap_explain_cme(cme_model, X, d_t, d_y, feature_names=None, treatment_names=None, output_names=None):
def _shap_explain_cme(cme_model, X, d_t, d_y,
feature_names=None, treatment_names=None, output_names=None,
input_names=None, background_samples=100):
"""
Method to explain `const_marginal_effect` function using shap Explainer().
@ -27,6 +30,10 @@ def _shap_explain_cme(cme_model, X, d_t, d_y, feature_names=None, treatment_name
the baseline treatment (i.e. the control treatment, which by default is the alphabetically smaller)
output_names: optional None or list (Default=None)
The name of the outcome.
input_names: dictionary or None
The parsed names of variables at fit input time of cate estimators
background_samples: int or None, (Default=100)
How many samples to use to compute the baseline effect. If None then all samples are used.
Returns
-------
@ -36,9 +43,11 @@ def _shap_explain_cme(cme_model, X, d_t, d_y, feature_names=None, treatment_name
and the shap_values explanation object as value.
"""
(dt, dy, treatment_names, output_names) = _define_names(d_t, d_y, treatment_names, output_names)
(dt, dy, treatment_names, output_names, feature_names) = _define_names(d_t, d_y, treatment_names, output_names,
feature_names, input_names)
# define masker by using entire dataset, otherwise Explainer will only sample 100 obs by default.
background = shap.maskers.Independent(X, max_samples=X.shape[0])
bg_samples = X.shape[0] if background_samples is None else min(background_samples, X.shape[0])
background = shap.maskers.Independent(X, max_samples=bg_samples)
shap_outs = defaultdict(dict)
for i in range(dy):
def cmd_func(X):
@ -65,10 +74,13 @@ def _shap_explain_cme(cme_model, X, d_t, d_y, feature_names=None, treatment_name
def _shap_explain_model_cate(cme_model, models, X, d_t, d_y, feature_names=None,
treatment_names=None, output_names=None):
treatment_names=None, output_names=None,
input_names=None, background_samples=100):
"""
Method to explain `model_cate` using shap Explainer(); will instead explain `const_marginal_effect`
if `model_cate` can't be parsed.
if `model_cate` can't be parsed. Models should be a list of length d_t. Each element in the list of
models represents the const_marginal_effect associated with each treatment and for all outcomes, i.e.
the output of the predict method of each model should be of length d_y.
Parameters
----------
@ -89,6 +101,10 @@ def _shap_explain_model_cate(cme_model, models, X, d_t, d_y, feature_names=None,
the baseline treatment (i.e. the control treatment, which by default is the alphabetically smaller)
output_names: optional None or list (Default=None)
The name of the outcome.
input_names: dictionary or None
The parsed names of variables at fit input time of cate estimators
background_samples: int or None, (Default=100)
How many samples to use to compute the baseline effect. If None then all samples are used.
Returns
-------
@ -97,13 +113,17 @@ def _shap_explain_model_cate(cme_model, models, X, d_t, d_y, feature_names=None,
each treatment name (e.g. "T0" when `treatment_names=None`) as key
and the shap_values explanation object as value.
"""
(dt, dy, treatment_names, output_names) = _define_names(d_t, d_y, treatment_names, output_names)
d_t_, d_y_ = d_t, d_y
feature_names_, treatment_names_ = feature_names, treatment_names,
output_names_, input_names_ = output_names, input_names
(dt, dy, treatment_names, output_names, feature_names) = _define_names(d_t, d_y, treatment_names, output_names,
feature_names, input_names)
if not isinstance(models, list):
models = [models]
assert len(models) == dt, "Number of final stage models doesn't equal the number of treatments!"
# define masker by using entire dataset, otherwise Explainer will only sample 100 obs by default.
background = shap.maskers.Independent(X, max_samples=X.shape[0])
bg_samples = X.shape[0] if background_samples is None else min(background_samples, X.shape[0])
background = shap.maskers.Independent(X, max_samples=bg_samples)
shap_outs = defaultdict(dict)
for i in range(dt):
@ -111,9 +131,13 @@ def _shap_explain_model_cate(cme_model, models, X, d_t, d_y, feature_names=None,
explainer = shap.Explainer(models[i], background,
feature_names=feature_names)
except Exception as e:
print("Final model can't be parsed, explain const_marginal_effect() instead!")
return _shap_explain_cme(cme_model, X, d_t, d_y, feature_names, treatment_names,
output_names)
print("Final model can't be parsed, explain const_marginal_effect() instead!", repr(e))
return _shap_explain_cme(cme_model, X, d_t_, d_y_,
feature_names=feature_names_,
treatment_names=treatment_names_,
output_names=output_names_,
input_names=input_names_,
background_samples=background_samples)
shap_out = explainer(X)
if dy > 1:
for j in range(dy):
@ -130,8 +154,9 @@ def _shap_explain_model_cate(cme_model, models, X, d_t, d_y, feature_names=None,
return shap_outs
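As an aside, a minimal standalone sketch of the background-sample capping pattern these helpers now share (the `capped_masker` helper name and the toy array are hypothetical; `shap.maskers.Independent` is the shap API used above):
import numpy as np
import shap

def capped_masker(X, background_samples=100):
    # Mirror the logic above: use all rows when background_samples is None,
    # otherwise cap the masker at background_samples rows.
    bg_samples = X.shape[0] if background_samples is None else min(background_samples, X.shape[0])
    return shap.maskers.Independent(X, max_samples=bg_samples)

masker = capped_masker(np.random.normal(size=(10000, 5)))  # keeps at most 100 background rows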
def _shap_explain_joint_linear_model_cate(model_final, X, T, d_t, d_y, fit_cate_intercept,
feature_names=None, treatment_names=None, output_names=None):
def _shap_explain_joint_linear_model_cate(model_final, X, d_t, d_y, fit_cate_intercept,
feature_names=None, treatment_names=None, output_names=None,
input_names=None, background_samples=100):
"""
Method to explain `model_cate` of parametric final stage that was fitted on the cross product of
`featurizer(X)` and T.
@ -141,13 +166,14 @@ def _shap_explain_joint_linear_model_cate(model_final, X, T, d_t, d_y, fit_cate_
model_final: a single estimator
the model's final stage model.
X: matrix
Intermediate X
T: matrix
Intermediate T
Featurized X
d_t: tuple of int
Tuple of number of treatment (exclude control in discrete treatment scenario).
d_y: tuple of int
Tuple of number of outcome.
fit_cate_intercept: bool
Whether the first entry of the coefficient of the joint linear model associated with
each treatment is an intercept.
feature_names: optional None or list of strings of length X.shape[1] (Default=None)
The names of input features.
treatment_names: optional None or list (Default=None)
@ -155,6 +181,10 @@ def _shap_explain_joint_linear_model_cate(model_final, X, T, d_t, d_y, fit_cate_
the baseline treatment (i.e. the control treatment, which by default is the alphabetically smaller)
output_names: optional None or list (Default=None)
The name of the outcome.
input_names: dictionary or None
The names of the variables parsed at fit time from the CATE estimator's inputs
background_samples: int or None, (Default=100)
How many samples to use to compute the baseline effect. If None then all samples are used.
Returns
-------
@ -163,7 +193,10 @@ def _shap_explain_joint_linear_model_cate(model_final, X, T, d_t, d_y, fit_cate_
each treatment name (e.g. "T0" when `treatment_names=None`) as key
and the shap_values explanation object as value.
"""
(d_t, d_y, treatment_names, output_names, feature_names) = _define_names(d_t, d_y, treatment_names, output_names,
feature_names, input_names)
X, T = broadcast_unit_treatments(X, d_t)
X = cross_product(X, T)
d_x = X.shape[1]
# define the index of d_x to filter for each given T
ind_x = np.arange(d_x).reshape(d_t, -1)
@ -174,7 +207,8 @@ def _shap_explain_joint_linear_model_cate(model_final, X, T, d_t, d_y, fit_cate_
# filter X after broadcast with T for each given T
X_sub = X[T[:, i] == 1]
# define masker by using entire dataset, otherwise Explainer will only sample 100 obs by default.
background = shap.maskers.Independent(X_sub, max_samples=X_sub.shape[0])
bg_samples = X_sub.shape[0] if background_samples is None else min(background_samples, X_sub.shape[0])
background = shap.maskers.Independent(X_sub, max_samples=bg_samples)
explainer = shap.Explainer(model_final, background)
shap_out = explainer(X_sub)
@ -182,7 +216,7 @@ def _shap_explain_joint_linear_model_cate(model_final, X, T, d_t, d_y, fit_cate_
if d_y > 1:
for j in range(d_y):
base_values = shap_out.base_values[..., j]
main_effects = shap_out.main_effects[..., ind_x[i], j]
main_effects = None if shap_out.main_effects is None else shap_out.main_effects[..., ind_x[i], j]
values = shap_out.values[..., ind_x[i], j]
shap_out_new = shap.Explanation(values, base_values=base_values, data=data, main_effects=main_effects,
feature_names=feature_names)
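As a side note, a small sketch of the featurization this function relies on (toy shapes, illustrative only; `broadcast_unit_treatments` and `cross_product` are the econml.utilities helpers imported above, assumed to tile X once per unit treatment and form treatment-specific feature blocks):
import numpy as np
from econml.utilities import broadcast_unit_treatments, cross_product

X = np.arange(6.).reshape(3, 2)           # toy (n, d_x) features
Xb, T = broadcast_unit_treatments(X, 2)   # X tiled per treatment, one-hot unit treatments
XT = cross_product(Xb, T)                 # joint features with d_x * d_t columns
# XT carries one block of columns per treatment, which is what ind_x indexes above.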
@ -199,16 +233,20 @@ def _shap_explain_joint_linear_model_cate(model_final, X, T, d_t, d_y, fit_cate_
def _shap_explain_multitask_model_cate(cme_model, multitask_model_cate, X, d_t, d_y, feature_names=None,
treatment_names=None, output_names=None):
treatment_names=None, output_names=None,
input_names=None, background_samples=100):
"""
Method to explain `multitask_model_cate` for DRLearner
Method to explain a final cate model that is represented in a multi-task manner, i.e. the prediction
of the method is of dimension equal to the number of treatments and represents the const_marginal_effect
vector for all treatments.
Parameters
----------
cme_model: function
const_marginal_effect function.
multitask_model_cate: a single estimator
the model's final stage model.
multitask_model_cate: a single estimator or a list of estimators of length d_y if d_y > 1
the model's final stage model whose predict represents the const_marginal_effect for
all treatments (or list of models, one for each outcome)
X: (m, d_x) matrix
Features for each sample. Should be in the same shape of fitted X in final stage.
d_t: tuple of int
@ -222,6 +260,10 @@ def _shap_explain_multitask_model_cate(cme_model, multitask_model_cate, X, d_t,
the baseline treatment (i.e. the control treatment, which by default is the alphabetically smaller)
output_names: optional None or list (Default=None)
The name of the outcome.
input_names: dictionary or None
The parsed names of variables at fit input time of cate estimators
background_samples: int or None, (Default=100)
How many samples to use to compute the baseline effect. If None then all samples are used.
Returns
-------
@ -230,35 +272,47 @@ def _shap_explain_multitask_model_cate(cme_model, multitask_model_cate, X, d_t,
each treatment name (e.g. "T0" when `treatment_names=None`) as key
and the shap_values explanation object as value.
"""
(dt, dy, treatment_names, output_names) = _define_names(d_t, d_y, treatment_names, output_names)
d_t_, d_y_ = d_t, d_y
feature_names_, treatment_names_ = feature_names, treatment_names,
output_names_, input_names_ = output_names, input_names
(dt, dy, treatment_names, output_names, feature_names) = _define_names(d_t, d_y, treatment_names, output_names,
feature_names, input_names)
if dy == 1 and (not isinstance(multitask_model_cate, list)):
multitask_model_cate = [multitask_model_cate]
# define masker by using entire dataset, otherwise Explainer will only sample 100 obs by default.
background = shap.maskers.Independent(X, max_samples=X.shape[0])
bg_samples = X.shape[0] if background_samples is None else min(background_samples, X.shape[0])
background = shap.maskers.Independent(X, max_samples=bg_samples)
shap_outs = defaultdict(dict)
try:
explainer = shap.Explainer(multitask_model_cate, background,
feature_names=feature_names)
except Exception as e:
print("Final model can't be parsed, explain const_marginal_effect() instead!")
return _shap_explain_cme(cme_model, X, d_t, d_y, feature_names, treatment_names,
output_names)
for j in range(dy):
try:
explainer = shap.Explainer(multitask_model_cate[j], background,
feature_names=feature_names)
except Exception as e:
print("Final model can't be parsed, explain const_marginal_effect() instead!", repr(e))
return _shap_explain_cme(cme_model, X, d_t_, d_y_,
feature_names=feature_names_,
treatment_names=treatment_names_,
output_names=output_names_,
input_names=input_names_,
background_samples=background_samples)
shap_out = explainer(X)
if dt > 1:
for i in range(dt):
base_values = shap_out.base_values[..., i]
values = shap_out.values[..., i]
main_effects = None if shap_out.main_effects is not None else shap_out.main_effects[..., i]
shap_out_new = shap.Explanation(values, base_values=base_values,
data=shap_out.data, main_effects=main_effects,
feature_names=shap_out.feature_names)
shap_outs[output_names[0]][treatment_names[i]] = shap_out_new
else:
shap_outs[output_names[0]][treatment_names[0]] = shap_out
shap_out = explainer(X)
if dt > 1:
for i in range(dt):
base_values = shap_out.base_values[..., i]
values = shap_out.values[..., i]
main_effects = None if shap_out.main_effects is None else shap_out.main_effects[..., i]
shap_out_new = shap.Explanation(values, base_values=base_values,
data=shap_out.data, main_effects=main_effects,
feature_names=shap_out.feature_names)
shap_outs[output_names[j]][treatment_names[i]] = shap_out_new
else:
shap_outs[output_names[j]][treatment_names[0]] = shap_out
return shap_outs
def _define_names(d_t, d_y, treatment_names, output_names):
def _define_names(d_t, d_y, treatment_names, output_names, feature_names, input_names):
"""
Helper function to get treatment, output and feature names
@ -268,11 +322,15 @@ def _define_names(d_t, d_y, treatment_names, output_names):
Tuple of number of treatment (exclude control in discrete treatment scenario).
d_y: tuple of int
Tuple of number of outcome.
treatment_names: optional None or list (Default=None)
treatment_names: None or list
The name of treatment. In discrete treatment scenario, the name should not include the name of
the baseline treatment (i.e. the control treatment, which by default is the alphabetically smaller)
output_names: optional None or list (Default=None)
output_names: None or list
The name of the outcome.
feature_names: None or list
The user provided names of the features
input_names: dictionary
The names of the features, outputs and treatments parsed from the inputs at fit time.
Returns
-------
@ -280,12 +338,21 @@ def _define_names(d_t, d_y, treatment_names, output_names):
d_y: int
treatment_names: List
output_names: List
feature_names: List or None
"""
d_t = d_t[0] if d_t else 1
d_y = d_y[0] if d_y else 1
if treatment_names is None:
treatment_names = [f"T{i}" for i in range(d_t)]
if (input_names is None) or (input_names['treatment_names'] is None):
treatment_names = [f"T{i}" for i in range(d_t)]
else:
treatment_names = input_names['treatment_names']
if output_names is None:
output_names = [f"Y{i}" for i in range(d_y)]
return (d_t, d_y, treatment_names, output_names)
if (input_names is None) or (input_names['output_names'] is None):
output_names = [f"Y{i}" for i in range(d_y)]
else:
output_names = input_names['output_names']
if (feature_names is None) and (input_names is not None):
feature_names = input_names['feature_names']
return (d_t, d_y, treatment_names, output_names, feature_names)
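A hedged illustration of the precedence this helper now applies (explicitly passed names win, then names recorded at fit time, then generic defaults; the dictionary contents are hypothetical):
input_names = {"treatment_names": ["price"],
               "output_names": ["demand"],
               "feature_names": ["x0", "x1"]}
d_t, d_y, t_names, y_names, f_names = _define_names((1,), (1,), None, None, None, input_names)
# t_names == ["price"], y_names == ["demand"], f_names == ["x0", "x1"];
# with input_names=None the fallbacks would be ["T0"], ["Y0"] and the feature names passed in.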

View file

@ -76,7 +76,7 @@ class BootstrapEstimator:
The full signature of this method is the same as that of the wrapped object's `fit` method.
"""
from .cate_estimator import BaseCateEstimator # need to nest this here to avoid circular import
from ._cate_estimator import BaseCateEstimator # need to nest this here to avoid circular import
index_chunks = None
if isinstance(self._instances[0], BaseCateEstimator):
@ -182,7 +182,7 @@ class BootstrapEstimator:
def get_inference():
# can't import from econml.inference at top level without creating cyclical dependencies
from .inference import EmpiricalInferenceResults, NormalInferenceResults
from .cate_estimator import LinearModelFinalCateEstimatorDiscreteMixin
from ._cate_estimator import LinearModelFinalCateEstimatorDiscreteMixin
prefix = name[: - len("_inference")]

View file

@ -1,9 +1,26 @@
from .ortho_forest import DMLOrthoForest
from .utilities import LassoCVWrapper
from .utilities import LassoCVWrapper, deprecated
from sklearn.linear_model import LogisticRegressionCV
from .dml import CausalForestDML
class CausalForest(DMLOrthoForest):
@deprecated("The CausalForest class has been deprecated by the CausalForestDML; "
"an upcoming release will remove support for the old class")
def CausalForest(n_trees=500,
min_leaf_size=10,
max_depth=10,
subsample_ratio=0.7,
lambda_reg=0.01,
model_T='auto',
model_Y=LassoCVWrapper(cv=3),
cv=2,
discrete_treatment=False,
categories='auto',
n_jobs=-1,
backend='threading',
verbose=0,
batch_size='auto',
random_state=None):
"""CausalForest for continuous treatments. To apply to discrete
treatments, first one-hot-encode your treatments and then pass the one-hot-encoding.
@ -36,7 +53,7 @@ class CausalForest(DMLOrthoForest):
`fit` and `predict` methods.
cv : int, cross-validation generator or an iterable, optional (default=2)
The specification of the cv splitter to be used for cross-fitting, when constructing
The specification of the CV splitter to be used for cross-fitting, when constructing
the global residuals of Y and T.
discrete_treatment : bool, optional (default=False)
@ -53,6 +70,9 @@ class CausalForest(DMLOrthoForest):
``-1`` means using all processors. Since OrthoForest methods are
computationally heavy, it is recommended to set `n_jobs` to -1.
backend : 'threading' or 'multiprocessing'
The backend to use for parallelization with the joblib library.
random_state : int, :class:`~numpy.random.mtrand.RandomState` instance or None, optional (default=None)
If int, random_state is the seed used by the random number generator;
If :class:`~numpy.random.mtrand.RandomState` instance, random_state is the random number generator;
@ -61,32 +81,19 @@ class CausalForest(DMLOrthoForest):
"""
def __init__(self,
n_trees=500,
min_leaf_size=10,
max_depth=10,
subsample_ratio=0.7,
lambda_reg=0.01,
model_T='auto',
model_Y=LassoCVWrapper(cv=3),
cv=2,
discrete_treatment=False,
categories='auto',
n_jobs=-1,
random_state=None):
super().__init__(n_trees=n_trees,
min_leaf_size=min_leaf_size,
max_depth=max_depth,
subsample_ratio=subsample_ratio,
bootstrap=False,
lambda_reg=lambda_reg,
model_T=model_T,
model_Y=model_Y,
model_T_final=None,
model_Y_final=None,
global_residualization=True,
global_res_cv=cv,
discrete_treatment=discrete_treatment,
categories=categories,
n_jobs=n_jobs,
random_state=random_state)
return CausalForestDML(
model_t=model_T,
model_y=model_Y,
n_crossfit_splits=cv,
discrete_treatment=discrete_treatment,
categories=categories,
n_estimators=n_trees,
criterion='het',
min_samples_leaf=min_leaf_size,
max_depth=max_depth,
max_samples=subsample_ratio / 2,
min_balancedness_tol=.3,
n_jobs=n_jobs,
verbose=verbose,
random_state=random_state
)
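For users migrating off the deprecated wrapper, a usage sketch with hypothetical toy data that roughly matches what the shim above constructs (continuous treatment, 'het' criterion, halved subsample ratio):
import numpy as np
from econml.dml import CausalForestDML

np.random.seed(123)
X = np.random.normal(size=(500, 3))
T = X[:, 0] + np.random.normal(size=500)
Y = 2 * X[:, 0] * T + np.random.normal(size=500)

est = CausalForestDML(criterion='het', n_estimators=500,
                      min_samples_leaf=10, max_depth=10,
                      max_samples=0.35, min_balancedness_tol=.3,
                      random_state=123)
est.fit(Y, T, X=X)
est.effect(X[:5])  # heterogeneous effect of a unit change in T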

econml/data/__init__.py (new file, 0 lines added)
View file

View file

@ -5,7 +5,7 @@
import numpy as np
import keras
from .cate_estimator import BaseCateEstimator
from ._cate_estimator import BaseCateEstimator
from .utilities import deprecated
from keras import backend as K
import keras.layers as L

Просмотреть файл

@ -1,99 +0,0 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Data generating processes for correctness testing."""
import numpy as np
from itertools import product
from econml.utilities import cross_product
########################################################
# Perfect Data DGPs for Testing Correctness of Code
########################################################
def dgp_perfect_data_multiple_treatments(n_samples, n_cov, n_treatments, Alpha, beta, effect):
"""Generate data with carefully crafted controls and noise for perfect estimation."""
# Generate random control co-variates
X = np.random.randint(low=-2, high=2, size=(n_samples, n_cov))
# Create epsilon residual treatments that deterministically sum up to
# zero
epsilon = np.random.normal(size=(n_samples, n_treatments))
# Re-calibrate epsilon to make sure that empirical distribution of epsilon
# conditional on each co-variate vector is equal to zero
unique_X = np.unique(X, axis=0)
for u_row in unique_X:
# We simply subtract the conditional mean from the epsilons
epsilon[np.all(X == u_row, axis=1),
:] -= np.mean(epsilon[np.all(X == u_row, axis=1), :])
# Construct treatments as T = X*A + epsilon
T = np.dot(X, Alpha) + epsilon
# Construct outcomes as y = X*beta + T*effect
y = np.dot(X, beta) + np.dot(T, effect)
return y, T, X, epsilon
def dgp_perfect_data_multiple_treatments_and_features(n_samples, n_cov, feat_sizes, n_treatments, Alpha, beta, effect):
"""Generate data with carefully crafted controls and noise for perfect estimation."""
# Generate random control co-variates
X = np.random.randint(low=-2, high=2, size=(n_samples, n_cov))
X_f = [
c
for c in (np.arange(s - 1).reshape((1, s - 1)) ==
np.random.randint(s, size=(X.shape[0], 1))).astype(np.int)
for s in feat_sizes]
# Create epsilon residual treatments that deterministically sum up to
# zero
epsilon = np.random.normal(size=(n_samples, n_treatments))
# Re-calibrate epsilon to make sure that empirical distribution of epsilon
# conditional on each co-variate vector is equal to zero
unique_X = np.unique(X, axis=0)
for u_row in unique_X:
# We simply subtract the conditional mean from the epsilons
epsilon[np.all(X == u_row, axis=1),
:] -= np.mean(epsilon[np.all(X == u_row, axis=1), :])
# Construct treatments as T = X*A + epsilon
T = np.dot(X, Alpha) + epsilon
# Construct outcomes as y = X*beta + T*effect
y = np.dot(X, beta) + np.dot(T, effect)
return y, T, X, epsilon
def dgp_perfect_counterfactual_data_multiple_treatments(n_samples, n_cov, beta, effect, treatment_vector):
"""Generate data with carefully crafted controls and noise for perfect estimation."""
# Generate random control co-variates
X = np.random.randint(low=-2, high=2, size=(n_samples, n_cov))
# Construct treatments as T = X*A + epsilon
T = np.repeat(treatment_vector.reshape(1, -1), n_samples, axis=0)
# Construct outcomes as y = X*beta + T*effect
y = np.dot(X, beta) + np.dot(T, effect)
return y, T, X
def dgp_data_multiple_treatments(n_samples, n_cov, n_treatments, Alpha, beta, effect):
"""Generate data from a linear model using covariates drawn from gaussians and with gaussian noise."""
# Generate random control co-variates
X = np.random.normal(size=(n_samples, n_cov))
# Create epsilon residual treatments
epsilon = np.random.normal(size=(n_samples, n_treatments))
# Construct treatments as T = X*A + epsilon
T = np.dot(X, Alpha) + epsilon
# Construct outcomes as y = X*beta + T*effect + eta
y = np.dot(X, beta) + np.dot(T, effect) + np.random.normal(size=n_samples)
return y, T, X, epsilon
def dgp_counterfactual_data_multiple_treatments(n_samples, n_cov, beta, effect, treatment_vector):
"""Generate data with carefully crafted controls and noise for perfect estimation."""
# Generate random control co-variates
X = np.random.normal(size=(n_samples, n_cov))
# Use the same treatment vector for each row
T = np.repeat(treatment_vector.reshape(1, -1), n_samples, axis=0)
# Construct outcomes as y = X*beta + T*effect
y = np.dot(X, beta) + np.dot(T, effect) + np.random.normal(size=n_samples)
return y, T, X

econml/dml/__init__.py (new file, 11 lines added)
View file

@ -0,0 +1,11 @@
from .dml import (DML, LinearDML, SparseLinearDML,
KernelDML, NonParamDML, ForestDML)
from .causal_forest import CausalForestDML
__all__ = ["DML",
"LinearDML",
"SparseLinearDML",
"KernelDML",
"NonParamDML",
"ForestDML",
"CausalForestDML", ]

View file

@ -28,10 +28,10 @@ Chernozhukov et al. (2017). Double/debiased machine learning for treatment and s
import numpy as np
import copy
from warnings import warn
from .utilities import (shape, reshape, ndim, hstack, filter_none_kwargs, _deprecate_positional)
from ..utilities import (shape, reshape, ndim, hstack, filter_none_kwargs, _deprecate_positional)
from sklearn.linear_model import LinearRegression
from sklearn.base import clone
from ._ortho_learner import _OrthoLearner
from .._ortho_learner import _OrthoLearner
class _ModelNuisance:
@ -177,7 +177,7 @@ class _RLearner(_OrthoLearner):
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`cv splitter`
- :term:`CV splitter`
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the treatment is discrete
@ -205,7 +205,7 @@ class _RLearner(_OrthoLearner):
import numpy as np
from sklearn.linear_model import LinearRegression
from econml._rlearner import _RLearner
from econml.dml._rlearner import _RLearner
from sklearn.base import clone
class ModelFirst:
def __init__(self, model):

econml/dml/causal_forest.py (new file, 512 lines added)
View file

@ -0,0 +1,512 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import numpy as np
from .dml import _BaseDML
from .dml import _FirstStageWrapper, _FinalWrapper
from ..sklearn_extensions.linear_model import WeightedLassoCVWrapper
from ..sklearn_extensions.model_selection import WeightedStratifiedKFold
from ..inference import Inference, NormalInferenceResults
from sklearn.linear_model import LogisticRegressionCV
from sklearn.base import clone, BaseEstimator
from sklearn.preprocessing import FunctionTransformer
from sklearn.pipeline import Pipeline
from ..utilities import add_intercept, shape, check_inputs, _deprecate_positional
from ..grf import CausalForest, MultiOutputGRF
from .._cate_estimator import LinearCateEstimator
from .._shap import _shap_explain_multitask_model_cate
class _CausalForestFinalWrapper(_FinalWrapper):
def _combine(self, X, fitting=True):
if X is not None:
if self._featurizer is not None:
F = self._featurizer.fit_transform(X) if fitting else self._featurizer.transform(X)
else:
F = X
else:
raise AttributeError("Cannot use this method with X=None. Consider "
"using the LinearDML estimator.")
return F
def fit(self, X, T_res, Y_res, sample_weight=None, sample_var=None):
# Track training dimensions to see if Y or T is a vector instead of a 2-dimensional array
self._d_t = shape(T_res)[1:]
self._d_y = shape(Y_res)[1:]
fts = self._combine(X)
if sample_var is not None:
raise ValueError("This estimator does not support sample_var!")
if T_res.ndim == 1:
T_res = T_res.reshape((-1, 1))
if Y_res.ndim == 1:
Y_res = Y_res.reshape((-1, 1))
self._model.fit(fts, T_res, Y_res, sample_weight=sample_weight)
return self
def predict(self, X):
return self._model.predict(self._combine(X, fitting=False)).reshape((-1,) + self._d_y + self._d_t)
class _GenericSingleOutcomeModelFinalWithCovInference(Inference):
def prefit(self, estimator, *args, **kwargs):
self.model_final = estimator.model_final
self.featurizer = estimator.featurizer if hasattr(estimator, 'featurizer') else None
def fit(self, estimator, *args, **kwargs):
# once the estimator has been fit, it's kosher to store d_t here
# (which needs to have been expanded if there's a discrete treatment)
self._est = estimator
self._d_t = estimator._d_t
self._d_y = estimator._d_y
self.d_t = self._d_t[0] if self._d_t else 1
self.d_y = self._d_y[0] if self._d_y else 1
def const_marginal_effect_interval(self, X, *, alpha=0.1):
return self.const_marginal_effect_inference(X).conf_int(alpha=alpha)
def const_marginal_effect_inference(self, X):
if X is None:
raise ValueError("This inference method currently does not support X=None!")
if self.featurizer is not None:
X = self.featurizer.transform(X)
pred, pred_var = self.model_final.predict_and_var(X)
pred = pred.reshape((-1,) + self._d_y + self._d_t)
pred_stderr = np.sqrt(np.diagonal(pred_var, axis1=2, axis2=3).reshape((-1,) + self._d_y + self._d_t))
return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=pred,
pred_stderr=pred_stderr, inf_type='effect')
def effect_interval(self, X, *, T0, T1, alpha=0.1):
return self.effect_inference(X, T0=T0, T1=T1).conf_int(alpha=alpha)
def effect_inference(self, X, *, T0, T1):
if X is None:
raise ValueError("This inference method currently does not support X=None!")
X, T0, T1 = self._est._expand_treatments(X, T0, T1)
if self.featurizer is not None:
X = self.featurizer.transform(X)
dT = T1 - T0
if dT.ndim == 1:
dT = dT.reshape((-1, 1))
pred, pred_var = self.model_final.predict_projection_and_var(X, dT)
pred = pred.reshape((-1,) + self._d_y)
pred_stderr = np.sqrt(pred_var.reshape((-1,) + self._d_y))
return NormalInferenceResults(d_t=1, d_y=self.d_y, pred=pred,
pred_stderr=pred_stderr, inf_type='effect')
class CausalForestDML(_BaseDML):
"""A Causal Forest [1]_ combined with double machine learning based residualization of the treatment
and outcome variables. It fits a forest that solves the local moment equation problem:
.. code-block::
E[ (Y - E[Y|X, W] - <theta(x), T - E[T|X, W]> - beta(x)) (T;1) | X=x] = 0
where E[Y|X, W] and E[T|X, W] are fitted in a first stage in a cross-fitting manner.
Parameters
----------
model_y: estimator or 'auto', optional (default is 'auto')
The estimator for fitting the response to the features. Must implement
`fit` and `predict` methods.
If 'auto' :class:`.WeightedLassoCV`/:class:`.WeightedMultiTaskLassoCV` will be chosen.
model_t: estimator or 'auto', optional (default is 'auto')
The estimator for fitting the treatment to the features.
If estimator, it must implement `fit` and `predict` methods;
If 'auto', :class:`~sklearn.linear_model.LogisticRegressionCV` will be applied for discrete treatment,
and :class:`.WeightedLassoCV`/:class:`.WeightedMultiTaskLassoCV`
will be applied for continuous treatment.
featurizer : :term:`transformer`, optional, default None
Must support fit_transform and transform. Used to create composite features in the final CATE regression.
It is ignored if X is None. The final CATE will be trained on the outcome of featurizer.fit_transform(X).
If featurizer=None, then CATE is trained on X.
discrete_treatment: bool, optional (default is ``False``)
Whether the treatment values should be treated as categorical, rather than continuous, quantities
categories: 'auto' or list, default 'auto'
The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
The first category will be treated as the control treatment.
n_crossfit_splits: int, cross-validation generator or an iterable, optional (Default=2)
Determines the cross-validation splitting strategy.
Possible inputs for cv are:
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`CV splitter`
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the treatment is discrete
:class:`~sklearn.model_selection.StratifiedKFold` is used, else,
:class:`~sklearn.model_selection.KFold` is used
(with a random shuffle in either case).
Unless an iterable is used, we call `split(X,T)` to generate the splits.
n_estimators : int, default=100
Number of trees
criterion : {``"mse"``, ``"het"``}, default="mse"
The function to measure the quality of a split. Supported criteria
are ``"mse"`` for the mean squared error in a linear moment estimation tree and ``"het"`` for
heterogeneity score.
- The ``"mse"`` criterion finds splits that minimize the score
.. code-block::
sum_{child} E[(Y - <theta(child), T> - beta(child))^2 | X=child] weight(child)
Internally, for the case of more than two treatments or for the case of one treatment with
``fit_intercept=True``, this criterion is approximated by computationally simpler variants.
In particular, it is replaced by:
.. code-block::
sum_{child} weight(child) * rho(child).T @ E[(T;1) @ (T;1).T | X in child] @ rho(child)
where:
.. code-block::
rho(child) := E[(T;1) @ (T;1).T | X in parent]^{-1}
* E[(Y - <theta(x), T> - beta(x)) (T;1) | X in child]
This can be thought of as a heterogeneity-inducing score, but one that puts more weight on splits
with a large minimum eigenvalue of the child jacobian ``E[(T;1) @ (T;1).T | X in child]``,
which leads to smaller variance of the estimate and stronger identification of the parameters.
- The "het" criterion finds splits that maximize the pure parameter heterogeneity score
.. code-block::
sum_{child} weight(child) * rho(child)[:n_T].T @ rho(child)[:n_T]
This can be thought of as an approximation to the ideal heterogeneity score:
.. code-block::
weight(left) * weight(right) || theta(left) - theta(right)||_2^2 / weight(parent)^2
as outlined in [1]_
max_depth : int, default=None
The maximum depth of the tree. If None, then nodes are expanded until
all leaves are pure or until all leaves contain less than
min_samples_split samples.
min_samples_split : int or float, default=10
The minimum number of samples required to split an internal node:
- If int, then consider `min_samples_split` as the minimum number.
- If float, then `min_samples_split` is a fraction and `ceil(min_samples_split * n_samples)` are the minimum
number of samples for each split.
min_samples_leaf : int or float, default=5
The minimum number of samples required to be at a leaf node.
A split point at any depth will only be considered if it leaves at
least ``min_samples_leaf`` training samples in each of the left and
right branches. This may have the effect of smoothing the model,
especially in regression.
- If int, then consider `min_samples_leaf` as the minimum number.
- If float, then `min_samples_leaf` is a fraction and `ceil(min_samples_leaf * n_samples)` are the minimum
number of samples for each node.
min_weight_fraction_leaf : float, default=0.0
The minimum weighted fraction of the sum total of weights (of all
the input samples) required to be at a leaf node. Samples have
equal weight when sample_weight is not provided.
min_var_fraction_leaf : None or float in (0, 1], default=None
A constraint on some proxy of the variation of the treatment vector that should be contained within each
leaf as a percentage of the total variance of the treatment vector on the whole sample. This avoids
performing splits where the variance of the treatment is small, and hence the local parameter
is not well identified and has high variance. The proxy of the variance differs across criteria,
primarily for computational efficiency reasons. If ``criterion='het'``, then this constraint translates to::
for all i in {1, ..., T.shape[1]}:
Var(T[i] | X in leaf) > `min_var_fraction_leaf` * Var(T[i])
If ``criterion='mse'``, because the criterion stores more information about the leaf for
every candidate split, then this constraint imposes further constraints on the pairwise correlations
of different coordinates of each treatment, i.e.::
for all i neq j:
sqrt( Var(T[i]|X in leaf) * Var(T[j]|X in leaf)
* ( 1 - rho(T[i], T[j]| in leaf)^2 ) )
> `min_var_fraction_leaf` sqrt( Var(T[i]) * Var(T[j]) * (1 - rho(T[i], T[j])^2 ) )
where rho(X, Y) is the Pearson correlation coefficient of two random variables X, Y. Thus this
constraint also enforces that no pair of treatments is highly co-linear within a leaf. This
extra constraint primarily has bite in the case of more than two input treatments and also avoids
leaves where the parameter estimate has large variance due to local co-linearities of the treatments.
min_var_leaf_on_val : bool, default=False
Whether the `min_var_fraction_leaf` constraint should also be enforced on the validation set of the
honest split. If ``min_var_fraction_leaf=None`` then this flag does nothing. Setting this to True should
be done with caution, as this partially violates the honesty structure, since the treatment variable
of the validation set is used to inform the split structure of the tree. However, this is a benign
dependence as it only uses local correlation structure of the treatment T to decide whether
a split is feasible.
max_features : int, float or {"auto", "sqrt", "log2"}, default=None
The number of features to consider when looking for the best split:
- If int, then consider `max_features` features at each split.
- If float, then `max_features` is a fraction and `int(max_features * n_features)` features
are considered at each split.
- If "auto", then `max_features=n_features`.
- If "sqrt", then `max_features=sqrt(n_features)`.
- If "log2", then `max_features=log2(n_features)`.
- If None, then `max_features=n_features`.
Note: the search for a split does not stop until at least one
valid partition of the node samples is found, even if it requires to
effectively inspect more than ``max_features`` features.
min_impurity_decrease : float, default=0.0
A node will be split if this split induces a decrease of the impurity
greater than or equal to this value.
The weighted impurity decrease equation is the following::
N_t / N * (impurity - N_t_R / N_t * right_impurity
- N_t_L / N_t * left_impurity)
where ``N`` is the total number of samples, ``N_t`` is the number of
samples at the current node, ``N_t_L`` is the number of samples in the
left child, and ``N_t_R`` is the number of samples in the right child.
``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
if ``sample_weight`` is passed.
max_samples : int or float in (0, 1], default=.45,
The number of samples to use for each subsample that is used to train each tree:
- If int, then train each tree on `max_samples` samples, sampled without replacement from all the samples
- If float, then train each tree on `ceil(`max_samples` * `n_samples`)`, sampled without replacement
from all the samples.
If ``inference=True``, then `max_samples` must either be an integer smaller than `n_samples//2` or a float
less than or equal to .5.
min_balancedness_tol: float in [0, .5], default=.45
How imbalanced a split we can tolerate. This enforces that each split leaves at least
a (.5 - min_balancedness_tol) fraction of samples on each side of the split (or of the total
weight of samples, when sample_weight is not None). The default value ensures
that at least 5% of the parent node weight falls on each side of the split. Set it to 0.0 for no
balancedness constraint and to .5 for perfectly balanced splits. For the formal inference theory
to be valid, this has to be any positive constant bounded away from zero.
honest : bool, default=True
Whether each tree should be trained in an honest manner, i.e. the training set is split into two equal
sized subsets, the train and the val set. All samples in train are used to create the split structure
and all samples in val are used to calculate the value of each node in the tree.
inference : bool, default=True
Whether inference (i.e. confidence interval construction and uncertainty quantification of the estimates)
should be enabled. If ``inference=True``, then the estimator uses a bootstrap-of-little-bags approach
to calculate the covariance of the parameter vector, with an objective Bayesian debiasing correction
to ensure that variance quantities are positive.
fit_intercept : bool, default=True
Whether we should fit an intercept nuisance parameter beta(x).
subforest_size : int, default=4,
The number of trees in each sub-forest that is used in the bootstrap-of-little-bags calculation.
The parameter `n_estimators` must be divisible by `subforest_size`. Should typically be a small constant.
n_jobs : int or None, default=-1
The number of parallel jobs to be used for parallelism; follows joblib semantics.
`n_jobs=-1` means all available cpu cores. `n_jobs=None` means no parallelism.
random_state : int, RandomState instance or None, default=None
Controls the randomness of the estimator. The features are always
randomly permuted at each split. When ``max_features < n_features``, the algorithm will
select ``max_features`` at random at each split before finding the best
split among them. But the best found split may vary across different
runs, even if ``max_features=n_features``. That is the case, if the
improvement of the criterion is identical for several splits and one
split has to be selected at random. To obtain a deterministic behaviour
during fitting, ``random_state`` has to be fixed to an integer.
verbose : int, default=0
Controls the verbosity when fitting and predicting.
warm_start : bool, default=False
When set to ``True``, reuse the solution of the previous call to fit
and add more estimators to the ensemble, otherwise, just fit a whole
new forest.
Attributes
----------
feature_importances_ : ndarray of shape (n_features,)
The feature importances based on the amount of parameter heterogeneity they create.
The higher, the more important the feature.
The importance of a feature is computed as the (normalized) total heterogeneity that the feature
creates. Each split at which the feature was chosen adds::
parent_weight * (left_weight * right_weight)
* mean((value_left[k] - value_right[k])**2) / parent_weight**2
to the importance of the feature. Each such quantity is also weighted by the depth of the split.
By default, splits below `max_depth=4` are not used in this calculation, and each split at
depth `depth` is re-weighted by 1 / (1 + `depth`)**2.0. See the ``feature_importances``
method, which allows changing these defaults.
References
----------
.. [1] Athey, Susan, Julie Tibshirani, and Stefan Wager. "Generalized random forests."
The Annals of Statistics 47.2 (2019): 1148-1178
https://arxiv.org/pdf/1610.01271.pdf
"""
def __init__(self, *,
model_y='auto',
model_t='auto',
featurizer=None,
discrete_treatment=False,
categories='auto',
n_crossfit_splits=2,
n_estimators=100,
criterion="mse",
max_depth=None,
min_samples_split=10,
min_samples_leaf=5,
min_weight_fraction_leaf=0.,
min_var_fraction_leaf=None,
min_var_leaf_on_val=True,
max_features="auto",
min_impurity_decrease=0.,
max_samples=.45,
min_balancedness_tol=.45,
honest=True,
inference=True,
fit_intercept=True,
subforest_size=4,
n_jobs=-1,
random_state=None,
verbose=0,
warm_start=False):
# TODO: consider whether we need more care around stateful featurizers,
# since we clone it and fit separate copies
if model_y == 'auto':
model_y = WeightedLassoCVWrapper(random_state=random_state)
if model_t == 'auto':
if discrete_treatment:
model_t = LogisticRegressionCV(cv=WeightedStratifiedKFold(random_state=random_state),
random_state=random_state)
else:
model_t = WeightedLassoCVWrapper(random_state=random_state)
self.bias_part_of_coef = False
self.fit_cate_intercept = False
model_final = MultiOutputGRF(CausalForest(n_estimators=n_estimators,
criterion=criterion,
max_depth=max_depth,
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
min_weight_fraction_leaf=min_weight_fraction_leaf,
min_var_fraction_leaf=min_var_fraction_leaf,
min_var_leaf_on_val=min_var_leaf_on_val,
max_features=max_features,
min_impurity_decrease=min_impurity_decrease,
max_samples=max_samples,
min_balancedness_tol=min_balancedness_tol,
honest=honest,
inference=inference,
fit_intercept=fit_intercept,
subforest_size=subforest_size,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
warm_start=warm_start))
super().__init__(model_y=_FirstStageWrapper(model_y, True,
featurizer, False, discrete_treatment),
model_t=_FirstStageWrapper(model_t, False,
featurizer, False, discrete_treatment),
model_final=_CausalForestFinalWrapper(model_final, False, featurizer, False),
discrete_treatment=discrete_treatment,
categories=categories,
n_splits=n_crossfit_splits,
random_state=random_state)
def _get_inference_options(self):
options = super()._get_inference_options()
options.update(blb=_GenericSingleOutcomeModelFinalWithCovInference)
options.update(auto=_GenericSingleOutcomeModelFinalWithCovInference)
return options
# override only so that we can update the docstring to indicate support for `blb`
@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, groups=None, inference='auto'):
"""
Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), τ(·,·).
Parameters
----------
Y: (n × d_y) matrix or vector of length n
Outcomes for each sample
T: (n × dₜ) matrix or vector of length n
Treatments for each sample
X: (n × dₓ) matrix
Features for each sample
W: optional (n × d_w) matrix
Controls for each sample
sample_weight: optional (n,) vector
Weights for each row
inference: string, :class:`.Inference` instance, or None
Method for performing inference. This estimator supports 'bootstrap'
(or an instance of :class:`.BootstrapInference`), 'blb' or 'auto'
Returns
-------
self
"""
if sample_var is not None:
raise ValueError("This estimator does not support sample_var!")
if X is None:
raise ValueError("This estimator does not support X=None!")
Y, T, X, W = check_inputs(Y, T, X, W=W, multi_output_T=True, multi_output_Y=True)
return super().fit(Y, T, X=X, W=W, sample_weight=sample_weight, sample_var=sample_var, groups=groups,
inference=inference)
def feature_importances(self, max_depth=4, depth_decay_exponent=2.0):
imps = self.model_final.feature_importances(max_depth=max_depth, depth_decay_exponent=depth_decay_exponent)
return imps.reshape(self._d_y + (-1,))
def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None, background_samples=100):
if self.featurizer is not None:
F = self.featurizer.transform(X)
else:
F = X
feature_names = self.cate_feature_names(feature_names)
return _shap_explain_multitask_model_cate(self.const_marginal_effect, self.model_cate.estimators_, F,
self._d_t, self._d_y, feature_names=feature_names,
treatment_names=treatment_names,
output_names=output_names,
input_names=self._input_names,
background_samples=background_samples)
shap_values.__doc__ = LinearCateEstimator.shap_values.__doc__
@property
def feature_importances_(self):
return self.feature_importances()
def __len__(self):
"""Return the number of estimators in the ensemble."""
return self.model_cate.__len__()
def __getitem__(self, index):
"""Return the index'th estimator in the ensemble."""
return self.model_cate.__getitem__(index)
def __iter__(self):
"""Return iterator over estimators in the ensemble."""
return self.model_cate.__iter__()
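For context, a hedged sketch of how a fitted instance is typically interrogated (est and X are the hypothetical objects from the sketch above; the 'Y0'/'T0' keys are the defaults used when no names are passed at fit time):
imps = est.feature_importances_                 # heterogeneity-based importances, length n_features for a vector outcome
sv = est.shap_values(X[:100], background_samples=50)
sv["Y0"]["T0"].values                           # shap.Explanation values for outcome Y0, treatment T0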

View file

@ -49,23 +49,23 @@ from sklearn.utils import check_random_state
import copy
from ._rlearner import _RLearner
from .cate_estimator import (DebiasedLassoCateEstimatorMixin,
ForestModelFinalCateEstimatorMixin,
LinearModelFinalCateEstimatorMixin,
StatsModelsCateEstimatorMixin,
LinearCateEstimator)
from .inference import StatsModelsInference
from .sklearn_extensions.ensemble import SubsampledHonestForest
from .sklearn_extensions.linear_model import (MultiOutputDebiasedLasso,
StatsModelsLinearRegression,
WeightedLassoCVWrapper)
from .sklearn_extensions.model_selection import WeightedStratifiedKFold
from .utilities import (_deprecate_positional, add_intercept,
broadcast_unit_treatments, check_high_dimensional,
cross_product, deprecated, fit_with_groups,
hstack, inverse_onehot, ndim, reshape,
reshape_treatmentwise_effects, shape, transpose)
from .shap import _shap_explain_model_cate
from .._cate_estimator import (DebiasedLassoCateEstimatorMixin,
ForestModelFinalCateEstimatorMixin,
LinearModelFinalCateEstimatorMixin,
StatsModelsCateEstimatorMixin,
LinearCateEstimator)
from ..inference import StatsModelsInference
from ..sklearn_extensions.ensemble import SubsampledHonestForest
from ..sklearn_extensions.linear_model import (MultiOutputDebiasedLasso,
StatsModelsLinearRegression,
WeightedLassoCVWrapper)
from ..sklearn_extensions.model_selection import WeightedStratifiedKFold
from ..utilities import (_deprecate_positional, add_intercept,
broadcast_unit_treatments, check_high_dimensional,
cross_product, deprecated, fit_with_groups,
hstack, inverse_onehot, ndim, reshape,
reshape_treatmentwise_effects, shape, transpose)
from .._shap import _shap_explain_model_cate
class _FirstStageWrapper:
@ -105,6 +105,7 @@ class _FirstStageWrapper:
sample_weight=sample_weight)
else:
fit_with_groups(self._model, self._combine(X, W, Target.shape[0]), Target, groups=groups)
return self
def predict(self, X, W):
n_samples = X.shape[0] if X is not None else (W.shape[0] if W is not None else 1)
@ -213,6 +214,7 @@ class _FinalWrapper:
self._model.fit(F, target, sample_weight=T_res.flatten()**2)
else:
raise AttributeError("This combination is not a feasible one!")
return self
def predict(self, X):
X2, T = broadcast_unit_treatments(X if X is not None else np.empty((1, 0)),
@ -418,7 +420,7 @@ class DML(LinearModelFinalCateEstimatorMixin, _BaseDML):
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`cv splitter`
- :term:`CV splitter`
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the treatment is discrete
@ -543,7 +545,7 @@ class LinearDML(StatsModelsCateEstimatorMixin, DML):
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`cv splitter`
- :term:`CV splitter`
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the treatment is discrete
@ -685,7 +687,7 @@ class SparseLinearDML(DebiasedLassoCateEstimatorMixin, DML):
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`cv splitter`
- :term:`CV splitter`
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the treatment is discrete
@ -836,7 +838,7 @@ class KernelDML(DML):
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`cv splitter`
- :term:`CV splitter`
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the treatment is discrete
@ -903,7 +905,7 @@ class NonParamDML(_BaseDML):
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`cv splitter`
- :term:`CV splitter`
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the treatment is discrete
@ -942,7 +944,7 @@ class NonParamDML(_BaseDML):
n_splits=n_splits,
random_state=random_state)
def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None):
def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None, background_samples=100):
if self.featurizer is not None:
F = self.featurizer.transform(X)
else:
@ -951,11 +953,33 @@ class NonParamDML(_BaseDML):
return _shap_explain_model_cate(self.const_marginal_effect, self.model_cate, F, self._d_t, self._d_y,
feature_names=feature_names,
treatment_names=treatment_names, output_names=output_names)
treatment_names=treatment_names,
output_names=output_names,
input_names=self._input_names,
background_samples=background_samples)
shap_values.__doc__ = LinearCateEstimator.shap_values.__doc__
class ForestDML(ForestModelFinalCateEstimatorMixin, NonParamDML):
@deprecated("The ForestDML class has been deprecated by the CausalForestDML with parameter "
"`criterion='mse'`; an upcoming release will remove support for the old class")
def ForestDML(model_y, model_t,
discrete_treatment=False,
categories='auto',
n_crossfit_splits=2,
n_estimators=100,
criterion="mse",
max_depth=None,
min_samples_split=2,
min_samples_leaf=1,
min_weight_fraction_leaf=0.,
max_features="auto",
max_leaf_nodes=None,
min_impurity_decrease=0.,
subsample_fr='auto',
honest=True,
n_jobs=None,
verbose=0,
random_state=None):
""" Instance of NonParamDML with a
:class:`~econml.sklearn_extensions.ensemble.SubsampledHonestForest`
as a final model, so as to enable non-parametric inference.
@ -983,7 +1007,7 @@ class ForestDML(ForestModelFinalCateEstimatorMixin, NonParamDML):
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`cv splitter`
- :term:`CV splitter`
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the treatment is discrete
@ -1107,126 +1131,22 @@ class ForestDML(ForestModelFinalCateEstimatorMixin, NonParamDML):
If None, the random number generator is the :class:`~numpy.random.mtrand.RandomState` instance used
by :mod:`np.random<numpy.random>`.
"""
def __init__(self,
model_y, model_t,
discrete_treatment=False,
categories='auto',
n_crossfit_splits=2,
n_estimators=100,
criterion="mse",
max_depth=None,
min_samples_split=2,
min_samples_leaf=1,
min_weight_fraction_leaf=0.,
max_features="auto",
max_leaf_nodes=None,
min_impurity_decrease=0.,
subsample_fr='auto',
honest=True,
n_jobs=None,
verbose=0,
random_state=None):
model_final = SubsampledHonestForest(n_estimators=n_estimators,
criterion=criterion,
max_depth=max_depth,
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
min_weight_fraction_leaf=min_weight_fraction_leaf,
max_features=max_features,
max_leaf_nodes=max_leaf_nodes,
min_impurity_decrease=min_impurity_decrease,
subsample_fr=subsample_fr,
honest=honest,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose)
super().__init__(model_y=model_y, model_t=model_t,
model_final=model_final, featurizer=None,
discrete_treatment=discrete_treatment,
categories=categories,
n_splits=n_crossfit_splits, random_state=random_state)
@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, groups=None, inference='auto'):
"""
Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), τ(·,·).
Parameters
----------
Y: (n × d_y) matrix or vector of length n
Outcomes for each sample
T: (n × dₜ) matrix or vector of length n
Treatments for each sample
X: optional (n × dₓ) matrix
Features for each sample
W: optional (n × d_w) matrix
Controls for each sample
sample_weight: optional (n,) vector
Weights for each row
sample_var: optional (n, n_y) vector
Variance of sample, in case it corresponds to summary of many samples. Currently
not in use by this method (as inference method does not require sample variance info).
groups: (n,) vector, optional
All rows corresponding to the same group will be kept together during splitting.
If groups is not None, the n_splits argument passed to this class's initializer
must support a 'groups' argument to its split method.
inference: string, `Inference` instance, or None
Method for performing inference. This estimator supports 'bootstrap'
(or an instance of :class:`.BootstrapInference`) and 'blb'
(for Bootstrap-of-Little-Bags based inference)
Returns
-------
self
"""
return super().fit(Y, T, X=X, W=W,
sample_weight=sample_weight, sample_var=None, groups=groups,
inference=inference)
def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None):
# SubsampleHonestForest can't be recognized by SHAP, but the tree entries are consistent with a tree in
# a RandomForestRegressor, modify the class name in order to be identified as tree models.
model = copy.deepcopy(self.model_cate)
model.__class__ = RandomForestRegressor
return _shap_explain_model_cate(self.const_marginal_effect, model, X, self._d_t, self._d_y,
feature_names=feature_names,
treatment_names=treatment_names, output_names=output_names)
shap_values.__doc__ = LinearCateEstimator.shap_values.__doc__
@deprecated("The DMLCateEstimator class has been renamed to DML; "
"an upcoming release will remove support for the old name")
class DMLCateEstimator(DML):
pass
@deprecated("The LinearDMLCateEstimator class has been renamed to LinearDML; "
"an upcoming release will remove support for the old name")
class LinearDMLCateEstimator(LinearDML):
pass
@deprecated("The SparseLinearDMLCateEstimator class has been renamed to SparseLinearDML; "
"an upcoming release will remove support for the old name")
class SparseLinearDMLCateEstimator(SparseLinearDML):
pass
@deprecated("The KernelDMLCateEstimator class has been renamed to KernelDML; "
"an upcoming release will remove support for the old name")
class KernelDMLCateEstimator(KernelDML):
pass
@deprecated("The NonParamDMLCateEstimator class has been renamed to NonParamDML; "
"an upcoming release will remove support for the old name")
class NonParamDMLCateEstimator(NonParamDML):
pass
@deprecated("The ForestDMLCateEstimator class has been renamed to ForestDML; "
"an upcoming release will remove support for the old name")
class ForestDMLCateEstimator(ForestDML):
pass
from . import CausalForestDML
return CausalForestDML(model_y=model_y,
model_t=model_t,
discrete_treatment=discrete_treatment,
categories=categories,
n_crossfit_splits=n_crossfit_splits,
n_estimators=n_estimators,
criterion="mse",
max_depth=max_depth,
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
min_weight_fraction_leaf=min_weight_fraction_leaf,
max_features=max_features,
min_impurity_decrease=min_impurity_decrease,
max_samples=.45 if subsample_fr == 'auto' else subsample_fr / 2,
honest=honest,
n_jobs=n_jobs,
verbose=verbose,
random_state=random_state)
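Equivalently, a migration sketch (the first-stage models here are hypothetical): instead of the deprecated ForestDML call above, construct CausalForestDML directly with criterion='mse':
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from econml.dml import CausalForestDML

est = CausalForestDML(model_y=RandomForestRegressor(min_samples_leaf=20),
                      model_t=RandomForestClassifier(min_samples_leaf=20),
                      discrete_treatment=True,
                      criterion='mse',
                      n_estimators=1000,
                      min_samples_leaf=5,
                      max_samples=.45)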

View file

@ -25,6 +25,13 @@ Tsiatis AA (2006).
Semiparametric Theory and Missing Data.
New York: Springer; 2006.
.. testcode::
:hide:
import numpy as np
import scipy.special
np.set_printoptions(suppress=True)
"""
from warnings import warn
@ -37,16 +44,16 @@ from sklearn.linear_model import (LassoCV, LinearRegression,
from sklearn.ensemble import RandomForestRegressor
from ._ortho_learner import _OrthoLearner
from .cate_estimator import (DebiasedLassoCateEstimatorDiscreteMixin,
ForestModelFinalCateEstimatorDiscreteMixin,
StatsModelsCateEstimatorDiscreteMixin, LinearCateEstimator)
from ._cate_estimator import (DebiasedLassoCateEstimatorDiscreteMixin,
ForestModelFinalCateEstimatorDiscreteMixin,
StatsModelsCateEstimatorDiscreteMixin, LinearCateEstimator)
from .inference import GenericModelFinalInferenceDiscrete
from .sklearn_extensions.ensemble import SubsampledHonestForest
from .grf import RegressionForest
from .sklearn_extensions.linear_model import (
DebiasedLasso, StatsModelsLinearRegression, WeightedLassoCVWrapper)
from .utilities import (_deprecate_positional, check_high_dimensional,
filter_none_kwargs, fit_with_groups, inverse_onehot)
from .shap import _shap_explain_multitask_model_cate, _shap_explain_model_cate
from ._shap import _shap_explain_multitask_model_cate, _shap_explain_model_cate
class _ModelNuisance:
@ -143,7 +150,7 @@ class _ModelFinal:
return pred[:, np.newaxis, :]
return pred
else:
preds = np.array([mdl.predict(X) for mdl in self.models_cate])
preds = np.array([mdl.predict(X).reshape((-1,) + self.d_y) for mdl in self.models_cate])
return np.moveaxis(preds, 0, -1) # move treatment dim to end
def score(self, Y, T, X=None, W=None, *, nuisances, sample_weight=None, sample_var=None):
@ -250,7 +257,7 @@ class DRLearner(_OrthoLearner):
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`cv splitter`
- :term:`CV splitter`
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the treatment is discrete
@ -273,8 +280,6 @@ class DRLearner(_OrthoLearner):
.. testcode::
import numpy as np
import scipy.special
from econml.drlearner import DRLearner
np.random.seed(123)
@ -301,22 +306,20 @@ class DRLearner(_OrthoLearner):
>>> est.cate_feature_names()
<BLANKLINE>
>>> [mdl.coef_ for mdl in est.models_regression]
[array([ 1.472104...e+00, 1.984419...e-03, -1.103451...e-02, 6.984376...e-01,
2.049695...e+00]), array([ 1.455654..., -0.002110..., 0.005488..., 0.677090..., 1.998648...])]
[array([ 1.472..., 0.001..., -0.011..., 0.698..., 2.049...]),
array([ 1.455..., -0.002..., 0.005..., 0.677..., 1.998...])]
>>> [mdl.coef_ for mdl in est.models_propensity]
[array([[-0.747137..., 0.153419..., -0.018412...],
[ 0.083807..., -0.110360..., -0.076003...],
[ 0.663330..., -0.043058... , 0.094416...]]),
array([[-1.048348...e+00, 2.248997...e-04, 3.228087...e-02],
[ 1.911900...e-02, 1.241337...e-01, -8.196211...e-02],
[ 1.029229...e+00, -1.243586...e-01, 4.968123...e-02]])]
[array([[-0.747..., 0.153..., -0.018...],
[ 0.083..., -0.110..., -0.076...],
[ 0.663..., -0.043... , 0.094...]]),
array([[-1.048..., 0.000..., 0.032...],
[ 0.019..., 0.124..., -0.081...],
[ 1.029..., -0.124..., 0.049...]])]
Beyond default models:
.. testcode::
import scipy.special
import numpy as np
from sklearn.linear_model import LassoCV
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from econml.drlearner import DRLearner
@ -567,7 +570,7 @@ class DRLearner(_OrthoLearner):
else:
raise AttributeError("Featurizer does not have a method: get_feature_names!")
def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None):
def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None, background_samples=100):
if self.featurizer is not None:
F = self.featurizer.transform(X)
else:
@ -576,12 +579,20 @@ class DRLearner(_OrthoLearner):
if self._multitask_model_final:
return _shap_explain_multitask_model_cate(self.const_marginal_effect, self.multitask_model_cate, F,
self._d_t, self._d_y, feature_names,
treatment_names, output_names)
self._d_t, self._d_y,
feature_names=feature_names,
treatment_names=treatment_names,
output_names=output_names,
input_names=self._input_names,
background_samples=background_samples)
else:
return _shap_explain_model_cate(self.const_marginal_effect, super().model_final.models_cate,
F, self._d_t, self._d_y, feature_names=feature_names,
treatment_names=treatment_names, output_names=output_names)
F, self._d_t, self._d_y,
feature_names=feature_names,
treatment_names=treatment_names,
output_names=output_names,
input_names=self._input_names,
background_samples=background_samples)
shap_values.__doc__ = LinearCateEstimator.shap_values.__doc__
@ -652,7 +663,7 @@ class LinearDRLearner(StatsModelsCateEstimatorDiscreteMixin, DRLearner):
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`cv splitter`
- :term:`CV splitter`
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the treatment is discrete
@ -678,6 +689,7 @@ class LinearDRLearner(StatsModelsCateEstimatorDiscreteMixin, DRLearner):
import scipy.special
from econml.drlearner import DRLearner, LinearDRLearner
np.set_printoptions(suppress=True)
np.random.seed(123)
X = np.random.normal(size=(1000, 3))
T = np.random.binomial(2, scipy.special.expit(X[:, 0]))
@ -863,7 +875,7 @@ class SparseLinearDRLearner(DebiasedLassoCateEstimatorDiscreteMixin, DRLearner):
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`cv splitter`
- :term:`CV splitter`
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the treatment is discrete
@ -889,6 +901,7 @@ class SparseLinearDRLearner(DebiasedLassoCateEstimatorDiscreteMixin, DRLearner):
import scipy.special
from econml.drlearner import DRLearner, SparseLinearDRLearner
np.set_printoptions(suppress=True)
np.random.seed(123)
X = np.random.normal(size=(1000, 3))
T = np.random.binomial(2, scipy.special.expit(X[:, 0]))
@ -1041,7 +1054,7 @@ class ForestDRLearner(ForestModelFinalCateEstimatorDiscreteMixin, DRLearner):
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`cv splitter`
- :term:`CV splitter`
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the treatment is discrete
@ -1172,33 +1185,48 @@ class ForestDRLearner(ForestModelFinalCateEstimatorDiscreteMixin, DRLearner):
categories='auto',
n_crossfit_splits=2,
n_estimators=1000,
criterion="mse",
criterion='deprecated',
max_depth=None,
min_samples_split=5,
min_samples_leaf=5,
min_weight_fraction_leaf=0.,
max_features="auto",
max_leaf_nodes=None,
max_leaf_nodes='deprecated',
min_impurity_decrease=0.,
subsample_fr='auto',
subsample_fr='deprecated',
max_samples=.45,
min_balancedness_tol=.45,
honest=True,
n_jobs=None,
subforest_size=4,
n_jobs=-1,
verbose=0,
random_state=None):
model_final = SubsampledHonestForest(n_estimators=n_estimators,
criterion=criterion,
max_depth=max_depth,
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
min_weight_fraction_leaf=min_weight_fraction_leaf,
max_features=max_features,
max_leaf_nodes=max_leaf_nodes,
min_impurity_decrease=min_impurity_decrease,
subsample_fr=subsample_fr,
honest=honest,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose)
if criterion != 'deprecated':
warn("The parameter 'criterion' has been deprecated and will be removed in the next version. "
"Only the 'mse' criterion is supported.")
if max_leaf_nodes != 'deprecated':
warn("The parameter 'max_leaf_nodes' has been deprecated and will be removed in the next version.")
if subsample_fr != 'deprecated':
warn("The parameter 'subsample_fr' has been deprecated and will be removed in the next version. "
"Use 'max_samples' instead, with the convention that "
"'subsample_fr=x' is equivalent to 'max_samples=x/2'.")
max_samples = .45 if subsample_fr == 'auto' else subsample_fr / 2
model_final = RegressionForest(n_estimators=n_estimators,
max_depth=max_depth,
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
min_weight_fraction_leaf=min_weight_fraction_leaf,
max_features=max_features,
min_impurity_decrease=min_impurity_decrease,
max_samples=max_samples,
min_balancedness_tol=min_balancedness_tol,
honest=honest,
inference=True,
subforest_size=subforest_size,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
warm_start=False)
super().__init__(model_regression=model_regression, model_propensity=model_propensity,
model_final=model_final, featurizer=None,
multitask_model_final=False,
@ -1255,16 +1283,3 @@ class ForestDRLearner(ForestModelFinalCateEstimatorDiscreteMixin, DRLearner):
@property
def fitted_models_final(self):
return super().model_final.models_cate
def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None):
models = []
for fitted_model in self.fitted_models_final:
# SubsampleHonestForest can't be recognized by SHAP, but the tree entries are consistent with a tree in
# a RandomForestRegressor, modify the class name in order to be identified as tree models.
model = deepcopy(fitted_model)
model.__class__ = RandomForestRegressor
models.append(model)
return _shap_explain_model_cate(self.const_marginal_effect, models, X, self._d_t, self._d_y,
feature_names=feature_names,
treatment_names=treatment_names, output_names=output_names)
shap_values.__doc__ = LinearCateEstimator.shap_values.__doc__
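Taken together, the changes above mean existing ForestDRLearner code migrates by swapping the deprecated SubsampledHonestForest-style arguments for the RegressionForest ones, and gains the capped SHAP background set. A minimal sketch; the data and the nuisance models are illustrative only and not part of this commit:

    import numpy as np
    import scipy.special
    from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
    from econml.drlearner import ForestDRLearner

    np.random.seed(123)
    X = np.random.normal(size=(1000, 3))
    T = np.random.binomial(2, scipy.special.expit(X[:, 0]))
    y = (X[:, 0] > 0) * T + np.random.normal(size=(1000,))

    # A legacy setting of subsample_fr=0.9 corresponds to max_samples=0.45 under the new API.
    est = ForestDRLearner(model_regression=GradientBoostingRegressor(),
                          model_propensity=GradientBoostingClassifier(),
                          n_estimators=1000, max_samples=.45, min_samples_leaf=5)
    est.fit(y, T, X=X)
    est.effect(X[:5])                                 # CATE point estimates
    est.shap_values(X[:50], background_samples=100)   # SHAP values with a capped background set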

12
econml/grf/__init__.py Normal file

@ -0,0 +1,12 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from ._criterion import LinearMomentGRFCriterion, LinearMomentGRFCriterionMSE
from .classes import CausalForest, CausalIVForest, RegressionForest, MultiOutputGRF
__all__ = ["CausalForest",
"CausalIVForest",
"RegressionForest",
"MultiOutputGRF",
"LinearMomentGRFCriterion",
"LinearMomentGRFCriterionMSE"]

1101
econml/grf/_base_grf.py Normal file

Diff not shown because of its large size.

777
econml/grf/_base_grftree.py Normal file

@ -0,0 +1,777 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# This code contains snippets of code from:
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_classes.py
# published under the following license and copyright:
# BSD 3-Clause License
#
# Copyright (c) 2007-2020 The scikit-learn developers.
# All rights reserved.
import numpy as np
import numbers
from math import ceil
from ..tree import Tree
from ._criterion import LinearMomentGRFCriterionMSE, LinearMomentGRFCriterion
from ..tree._criterion import Criterion
from ..tree._splitter import Splitter, BestSplitter
from ..tree import DepthFirstTreeBuilder
from . import _criterion
from ..tree import _tree
from sklearn.base import BaseEstimator
from sklearn.model_selection import train_test_split
from sklearn.utils import check_array
from sklearn.utils import check_random_state
from sklearn.utils.validation import _check_sample_weight
from sklearn.utils.validation import check_is_fitted
import copy
# =============================================================================
# Types and constants
# =============================================================================
DTYPE = _tree.DTYPE
DOUBLE = _tree.DOUBLE
CRITERIA_GRF = {"het": LinearMomentGRFCriterion,
"mse": LinearMomentGRFCriterionMSE}
SPLITTERS = {"best": BestSplitter, }
MAX_INT = np.iinfo(np.int32).max
# =============================================================================
# Base GRF tree
# =============================================================================
class GRFTree(BaseEstimator):
"""A tree of a Generalized Random Forest [grftree1]. This method should be used primarily
through the BaseGRF forest class and its derivatives and not as a standalone
estimator. It fits a tree that solves the local moment equation problem::
E[ m(Z; theta(x)) | X=x] = 0
For some moment vector function m, that takes as input random samples of a random variable Z
and is parameterized by some unknown parameter theta(x). Each node in the tree
contains a local estimate of the parameter theta(x), for every region of X that
falls within that leaf.
Parameters
----------
criterion : {``'mse'``, ``'het'``}, default='mse'
The function to measure the quality of a split. Supported criteria
are ``'mse'`` for the mean squared error in a linear moment estimation tree and ``'het'`` for
heterogeneity score. These criteria solve any linear moment problem of the form::
E[J * theta(x) - A | X = x] = 0
- The ``'mse'`` criterion finds splits that maximize the score:
.. code-block::
sum_{child} weight(child) * theta(child).T @ E[J | X in child] @ theta(child)
- In the case of a causal tree, this coincides with minimizing the MSE:
.. code-block::
sum_{child} E[(Y - <theta(child), T>)^2 | X=child] weight(child)
- In the case of an IV tree, this roughly coincides with minimizing the projected MSE:
.. code-block::
sum_{child} E[(Y - <theta(child), E[T|Z]>)^2 | X=child] weight(child)
Internally, for the case of more than two treatments or for the case of one treatment with
``fit_intercept=True``, this criterion is approximated by computationally simpler variants for
computational purposes. In particular, it is replaced by::
sum_{child} weight(child) * rho(child).T @ E[J | X in child] @ rho(child)
where:
.. code-block::
rho(child) := J(parent)^{-1} E[A - J * theta(parent) | X in child]
This can be thought of as a heterogeneity inducing score, but putting more weight on scores
with a large minimum eigenvalue of the child jacobian ``E[J | X in child]``, which leads to smaller
variance of the estimate and stronger identification of the parameters.
- The ``'het'`` criterion finds splits that maximize the pure parameter heterogeneity score:
.. code-block::
sum_{child} weight(child) * rho(child).T @ rho(child)
This can be thought of as an approximation to the ideal heterogeneity score:
.. code-block::
weight(left) * weight(right) || theta(left) - theta(right)||_2^2 / weight(parent)^2
as outlined in [grftree1]_
splitter : {"best"}, default="best"
The strategy used to choose the split at each node. Supported
strategies are "best" to choose the best split.
max_depth : int, default=None
The maximum depth of the tree. If None, then nodes are expanded until
all leaves are pure or until all leaves contain less than
min_samples_split samples.
min_samples_split : int or float, default=10
The minimum number of samples required to split an internal node:
- If int, then consider `min_samples_split` as the minimum number.
- If float, then `min_samples_split` is a fraction and
`ceil(min_samples_split * n_samples)` are the minimum
number of samples for each split.
min_samples_leaf : int or float, default=5
The minimum number of samples required to be at a leaf node.
A split point at any depth will only be considered if it leaves at
least ``min_samples_leaf`` training samples in each of the left and
right branches. This may have the effect of smoothing the model,
especially in regression.
- If int, then consider `min_samples_leaf` as the minimum number.
- If float, then `min_samples_leaf` is a fraction and
`ceil(min_samples_leaf * n_samples)` are the minimum
number of samples for each node.
min_weight_fraction_leaf : float, default=0.0
The minimum weighted fraction of the sum total of weights (of all
the input samples) required to be at a leaf node. Samples have
equal weight when sample_weight is not provided.
min_var_leaf : None or double in (0, infinity), default=None
A constraint on the minimum degree of identification of the parameter of interest. This avoids performing
splits where either the variance of the treatment is small or the correlation of the instrument with the
treatment is small, or the variance of the instrument is small. Generically for any linear moment problem
this translates to conditions on the leaf jacobian matrix J(leaf) that are proxies for a well-conditioned
matrix, which leads to smaller variance of the local estimate. The proxy of the well-conditioning is
different for different criterion, primarily for computational efficiency reasons.
- If ``criterion='het'``, then the diagonal entries of J(leaf) are constraint to have absolute
value at least `min_var_leaf`:
.. code-block::
for all i in {1, ..., n_outputs}: abs(J(leaf)[i, i]) > `min_var_leaf`
In the context of a causal tree, when residual treatment is passed
at fit time, this translates to a requirement on Var(T[i]) for every treatment coordinate i.
In the context of an IV tree, with residual instruments and residual treatments passed at fit time
this translates to ``Cov(T[i], Z[i]) > min_var_leaf`` for each coordinate i of the instrument and the
treatment.
- If ``criterion='mse'``, because the criterion stores more information about the leaf jacobian for
every candidate split, we impose further constraints on the pairwise determinants of the
leaf jacobian, as they come at small extra computational cost, i.e.::
for all i neq j:
sqrt(abs(J(leaf)[i, i] * J(leaf)[j, j] - J(leaf)[i, j] * J(leaf)[j, i])) > `min_var_leaf`
In the context of a causal tree, when residual treatment is passed at fit time, this
translates to a constraint on the Pearson correlation coefficient of any two coordinates
of the treatment within the leaf, i.e.::
for all i neq j:
sqrt( Var(T[i]) * Var(T[j]) * (1 - rho(T[i], T[j])^2) ) > `min_var_leaf`
where rho(X, Y) is the Pearson correlation coefficient of two random variables X, Y. Thus this
constraint also enforces that no two pairs of treatments be very co-linear within a leaf. This
extra constraint primarily has bite in the case of more than two input treatments.
min_var_leaf_on_val : bool, default=False
Whether the `min_var_leaf` constraint should also be enforced to hold on the validation set of the
honest split too. If `min_var_leaf=None` then this flag does nothing. Setting this to True should
be done with caution, as this partially violates the honesty structure, since parts of the variables
other than the X variable (e.g. the variables that go into the jacobian J of the linear model) are
used to inform the split structure of the tree. However, this is a benign dependence; for instance,
in a causal tree or an IV tree it does not use the label y. It only uses the treatment T and the instrument
Z and their local correlation structures to decide whether a split is feasible.
max_features : int, float or {"auto", "sqrt", "log2"}, default=None
The number of features to consider when looking for the best split:
- If int, then consider `max_features` features at each split.
- If float, then `max_features` is a fraction and
`int(max_features * n_features)` features are considered at each
split.
- If "auto", then `max_features=n_features`.
- If "sqrt", then `max_features=sqrt(n_features)`.
- If "log2", then `max_features=log2(n_features)`.
- If None, then `max_features=n_features`.
Note: the search for a split does not stop until at least one
valid partition of the node samples is found, even if it requires to
effectively inspect more than ``max_features`` features.
random_state : int, RandomState instance or None, default=None
Controls the randomness of the estimator. The features are always
randomly permuted at each split, even if ``splitter`` is set to
``"best"``. When ``max_features < n_features``, the algorithm will
select ``max_features`` at random at each split before finding the best
split among them. But the best found split may vary across different
runs, even if ``max_features=n_features``. That is the case, if the
improvement of the criterion is identical for several splits and one
split has to be selected at random. To obtain a deterministic behaviour
during fitting, ``random_state`` has to be fixed to an integer.
min_impurity_decrease : float, default=0.0
A node will be split if this split induces a decrease of the impurity
greater than or equal to this value.
The weighted impurity decrease equation is the following::
N_t / N * (impurity - N_t_R / N_t * right_impurity
- N_t_L / N_t * left_impurity)
where ``N`` is the total number of samples, ``N_t`` is the number of
samples at the current node, ``N_t_L`` is the number of samples in the
left child, and ``N_t_R`` is the number of samples in the right child.
``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
if ``sample_weight`` is passed.
min_balancedness_tol: float in [0, .5], default=.45
How imbalanced a split we can tolerate. This enforces that each split leaves at least
(.5 - min_balancedness_tol) fraction of samples on each side of the split; or fraction
of the total weight of samples, when sample_weight is not None. The default value ensures
that at least 5% of the parent node weight falls on each side of the split. Set it to 0.0 for no
balancedness and to .5 for perfectly balanced splits. For the formal inference theory
to be valid, this has to be any positive constant bounded away from zero.
honest: bool, default=True
Whether the data should be split in two equally sized samples, such that the one half-sample
is used to determine the optimal split at each node and the other sample is used to determine
the value of every node.
Attributes
----------
feature_importances_ : ndarray of shape (n_features,)
The feature importances based on the amount of parameter heterogeneity they create.
The higher, the more important the feature.
The importance of a feature is computed as the (normalized) total heterogeneity that the feature
creates. Each split at which the feature was chosen adds::
parent_weight * (left_weight * right_weight)
* mean((value_left[k] - value_right[k])**2) / parent_weight**2
to the importance of the feature. Each such quantity is also weighted by the depth of the split.
By default splits below `max_depth=4` are not used in this calculation and also each split
at depth `depth` is re-weighted by 1 / (1 + `depth`)**2.0. See the method ``feature_importances``
for a method that allows one to change these defaults.
max_features_ : int
The inferred value of max_features.
n_features_ : int
The number of features when ``fit`` is performed.
n_outputs_ : int
The number of outputs when ``fit`` is performed.
n_relevant_outputs_ : int
The first `n_relevant_outputs_` outputs are the ones of primary interest when ``fit`` is performed; the rest are nuisance outputs.
n_y_ : int
The raw label dimension when ``fit`` is performed.
n_samples_ : int
The number of training samples when ``fit`` is performed.
honest_ : int
Whether honesty was enabled when ``fit`` was performed
tree_ : Tree instance
The underlying Tree object. Please refer to
``help(econml.tree._tree.Tree)`` for attributes of Tree object.
References
----------
.. [grftree1] Athey, Susan, Julie Tibshirani, and Stefan Wager. "Generalized random forests."
The Annals of Statistics 47.2 (2019): 1148-1178
https://arxiv.org/pdf/1610.01271.pdf
"""
def __init__(self, *,
criterion="mse",
splitter="best",
max_depth=None,
min_samples_split=10,
min_samples_leaf=5,
min_weight_fraction_leaf=0.,
min_var_leaf=None,
min_var_leaf_on_val=False,
max_features=None,
random_state=None,
min_impurity_decrease=0.,
min_balancedness_tol=0.45,
honest=True):
self.criterion = criterion
self.splitter = splitter
self.max_depth = max_depth
self.min_samples_split = min_samples_split
self.min_samples_leaf = min_samples_leaf
self.min_weight_fraction_leaf = min_weight_fraction_leaf
self.min_var_leaf = min_var_leaf
self.min_var_leaf_on_val = min_var_leaf_on_val
self.max_features = max_features
self.random_state = random_state
self.min_impurity_decrease = min_impurity_decrease
self.min_balancedness_tol = min_balancedness_tol
self.honest = honest
def get_depth(self):
"""Return the depth of the decision tree.
The depth of a tree is the maximum distance between the root
and any leaf.
Returns
-------
self.tree_.max_depth : int
The maximum depth of the tree.
"""
check_is_fitted(self)
return self.tree_.max_depth
def get_n_leaves(self):
"""Return the number of leaves of the decision tree.
Returns
-------
self.tree_.n_leaves : int
Number of leaves.
"""
check_is_fitted(self)
return self.tree_.n_leaves
def init(self,):
""" This method should be called before fit. We added this pre-fit step so that this step
can be executed without parallelism as it contains code that holds the GIL and can hinder
parallel execution. We also did not merge this step into ``__init__`` as we want ``__init__`` to just
be storing the parameters for easy cloning. We also don't want to directly pass a RandomState
object as random_state, as we want to keep the starting seed to be able to replicate the
randomness of the object outside the object.
"""
self.random_seed_ = self.random_state
self.random_state_ = check_random_state(self.random_seed_)
return self
def fit(self, X, y, n_y, n_outputs, n_relevant_outputs, sample_weight=None, check_input=True):
""" Fit the tree from the data
Parameters
----------
X : (n, d) array
The features to split on
y : (n, m) array
All the variables required to calculate the criterion function, evaluate splits and
estimate local values, i.e. all the values that go into the moment function except X.
n_y, n_outputs, n_relevant_outputs : auxiliary info passed to the criterion objects that
help the criterion parse the variable y into its separate components.
- In the case when `isinstance(criterion, LinearMomentGRFCriterion)`, the first
n_y columns of y are the raw outputs, the next n_outputs columns contain the A part
of the moment and the next n_outputs * n_outputs columns contain the J part of the moment
in row contiguous format. The first n_relevant_outputs parameters of the linear moment
are the ones that we care about. The rest are nuisance parameters.
sample_weight : (n,) array, default=None
The sample weights
check_input : bool, default=True
Whether to check the input parameters for validity. Should be set to False to improve
running time in parallel execution, if the variables have already been checked by the
forest class that spawned this tree.
"""
random_state = self.random_state_
# Determine output settings
n_samples, self.n_features_ = X.shape
self.n_outputs_ = n_outputs
self.n_relevant_outputs_ = n_relevant_outputs
self.n_y_ = n_y
self.n_samples_ = n_samples
self.honest_ = self.honest
# Important: This must be the first invocation of the random state at fit time, so that
# train/test splits are re-generatable from an external object simply by knowing the
# random_state parameter of the tree. Can be useful in the future if one wants to create local
# linear predictions. Currently is also useful for testing.
inds = np.arange(n_samples, dtype=np.intp)
if self.honest:
random_state.shuffle(inds)
samples_train, samples_val = inds[:n_samples // 2], inds[n_samples // 2:]
else:
samples_train, samples_val = inds, inds
if check_input:
if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
y = np.ascontiguousarray(y, dtype=DOUBLE)
y = np.atleast_1d(y)
if y.ndim == 1:
# reshape is necessary to preserve data contiguity, unlike
# [:, np.newaxis] which does not.
y = np.reshape(y, (-1, 1))
if len(y) != n_samples:
raise ValueError("Number of labels=%d does not match "
"number of samples=%d" % (len(y), n_samples))
if (sample_weight is not None):
sample_weight = _check_sample_weight(sample_weight, X, DOUBLE)
# Check parameters
max_depth = (np.iinfo(np.int32).max if self.max_depth is None
else self.max_depth)
if isinstance(self.min_samples_leaf, numbers.Integral):
if not 1 <= self.min_samples_leaf:
raise ValueError("min_samples_leaf must be at least 1 "
"or in (0, 0.5], got %s"
% self.min_samples_leaf)
min_samples_leaf = self.min_samples_leaf
else: # float
if not 0. < self.min_samples_leaf <= 0.5:
raise ValueError("min_samples_leaf must be at least 1 "
"or in (0, 0.5], got %s"
% self.min_samples_leaf)
min_samples_leaf = int(ceil(self.min_samples_leaf * n_samples))
if isinstance(self.min_samples_split, numbers.Integral):
if not 2 <= self.min_samples_split:
raise ValueError("min_samples_split must be an integer "
"greater than 1 or a float in (0.0, 1.0]; "
"got the integer %s"
% self.min_samples_split)
min_samples_split = self.min_samples_split
else: # float
if not 0. < self.min_samples_split <= 1.:
raise ValueError("min_samples_split must be an integer "
"greater than 1 or a float in (0.0, 1.0]; "
"got the float %s"
% self.min_samples_split)
min_samples_split = int(ceil(self.min_samples_split * n_samples))
min_samples_split = max(2, min_samples_split)
min_samples_split = max(min_samples_split, 2 * min_samples_leaf)
if isinstance(self.max_features, str):
if self.max_features == "auto":
max_features = self.n_features_
elif self.max_features == "sqrt":
max_features = max(1, int(np.sqrt(self.n_features_)))
elif self.max_features == "log2":
max_features = max(1, int(np.log2(self.n_features_)))
else:
raise ValueError("Invalid value for max_features. "
"Allowed string values are 'auto', "
"'sqrt' or 'log2'.")
elif self.max_features is None:
max_features = self.n_features_
elif isinstance(self.max_features, numbers.Integral):
max_features = self.max_features
else: # float
if self.max_features > 0.0:
max_features = max(1,
int(self.max_features * self.n_features_))
else:
max_features = 0
self.max_features_ = max_features
if not 0 <= self.min_weight_fraction_leaf <= 0.5:
raise ValueError("min_weight_fraction_leaf must in [0, 0.5]")
if max_depth < 0:
raise ValueError("max_depth must be greater than or equal to zero. ")
if not (0 < max_features <= self.n_features_):
raise ValueError("max_features must be in (0, n_features]")
if not 0 <= self.min_balancedness_tol <= 0.5:
raise ValueError("min_balancedness_tol must be in [0, 0.5]")
if self.min_var_leaf is None:
min_var_leaf = -1.0
elif isinstance(self.min_var_leaf, numbers.Real) and (self.min_var_leaf >= 0.0):
min_var_leaf = self.min_var_leaf
else:
raise ValueError("min_var_leaf must be either None or a real in [0, infinity). "
"Got {}".format(self.min_var_leaf))
if not isinstance(self.min_var_leaf_on_val, bool):
raise ValueError("min_var_leaf_on_val must be either True or False. "
"Got {}".format(self.min_var_leaf_on_val))
# Set min_weight_leaf from min_weight_fraction_leaf
if sample_weight is None:
min_weight_leaf = (self.min_weight_fraction_leaf *
n_samples)
else:
min_weight_leaf = (self.min_weight_fraction_leaf *
np.sum(sample_weight))
# Build tree
# We calculate the maximum number of samples from each half-split that any node in the tree can
# hold. Used by criterion for memory space savings.
max_train = len(samples_train) if sample_weight is None else np.count_nonzero(sample_weight[samples_train])
if self.honest:
max_val = len(samples_val) if sample_weight is None else np.count_nonzero(sample_weight[samples_val])
# Initialize the criterion object and the criterion_val object if honest.
if callable(self.criterion):
criterion = self.criterion(self.n_outputs_, self.n_relevant_outputs_, self.n_features_, self.n_y_,
n_samples, max_train,
random_state.randint(np.iinfo(np.int32).max))
if not isinstance(criterion, Criterion):
raise ValueError("Input criterion is not a valid criterion")
if self.honest:
criterion_val = self.criterion(self.n_outputs_, self.n_relevant_outputs_, self.n_features_, self.n_y_,
n_samples, max_val,
random_state.randint(np.iinfo(np.int32).max))
else:
criterion_val = criterion
else:
criterion = CRITERIA_GRF[self.criterion](
self.n_outputs_, self.n_relevant_outputs_, self.n_features_, self.n_y_, n_samples, max_train,
random_state.randint(np.iinfo(np.int32).max))
if self.honest:
criterion_val = CRITERIA_GRF[self.criterion](
self.n_outputs_, self.n_relevant_outputs_, self.n_features_, self.n_y_, n_samples, max_val,
random_state.randint(np.iinfo(np.int32).max))
else:
criterion_val = criterion
if (min_var_leaf >= 0.0 and (not isinstance(criterion, LinearMomentGRFCriterion)) and
(not isinstance(criterion_val, LinearMomentGRFCriterion))):
raise ValueError("This criterion does not support min_var_leaf constraint!")
splitter = self.splitter
if not isinstance(self.splitter, Splitter):
splitter = SPLITTERS[self.splitter](criterion, criterion_val,
self.max_features_,
min_samples_leaf,
min_weight_leaf,
self.min_balancedness_tol,
self.honest,
min_var_leaf,
self.min_var_leaf_on_val,
random_state.randint(np.iinfo(np.int32).max))
self.tree_ = Tree(self.n_features_, self.n_outputs_, self.n_relevant_outputs_, store_jac=True)
builder = DepthFirstTreeBuilder(splitter, min_samples_split,
min_samples_leaf,
min_weight_leaf,
max_depth,
self.min_impurity_decrease)
builder.build(self.tree_, X, y, samples_train, samples_val,
sample_weight=sample_weight,
store_jac=True)
return self
def _validate_X_predict(self, X, check_input):
"""Validate X whenever one tries to predict, apply, or any other of the prediction
related methods. """
if check_input:
X = check_array(X, dtype=DTYPE, accept_sparse=False)
n_features = X.shape[1]
if self.n_features_ != n_features:
raise ValueError("Number of features of the model must "
"match the input. Model n_features is %s and "
"input n_features is %s "
% (self.n_features_, n_features))
return X
def get_train_test_split_inds(self,):
""" Regenerate the train_test_split of input sample indices that was used for the training
and the evaluation split of the honest tree construction structure. Uses the same random seed
that was used at ``fit`` time and re-generates the indices.
"""
check_is_fitted(self)
random_state = check_random_state(self.random_seed_)
inds = np.arange(self.n_samples_, dtype=np.intp)
if self.honest_:
random_state.shuffle(inds)
return inds[:self.n_samples_ // 2], inds[self.n_samples_ // 2:]
else:
return inds, inds
def predict(self, X, check_input=True):
"""Return the prefix of relevant fitted local parameters for each X, i.e. theta(X).
Parameters
----------
X : {array-like} of shape (n_samples, n_features)
The input samples. Internally, it will be converted to
``dtype=np.float64``.
check_input : bool, default=True
Allow to bypass several input checking.
Don't use this parameter unless you know what you do.
Returns
-------
theta(X)[:n_relevant_outputs] : array-like of shape (n_samples, n_relevant_outputs)
The estimated relevant parameters for each row of X
"""
check_is_fitted(self)
X = self._validate_X_predict(X, check_input)
pred = self.tree_.predict(X)
return pred
def predict_full(self, X, check_input=True):
"""Return the fitted local parameters for each X, i.e. theta(X).
Parameters
----------
X : {array-like} of shape (n_samples, n_features)
The input samples. Internally, it will be converted to
``dtype=np.float64``.
check_input : bool, default=True
Allow to bypass several input checking.
Don't use this parameter unless you know what you do.
Returns
-------
theta(X) : array-like of shape (n_samples, n_outputs)
All the estimated parameters for each row of X
"""
check_is_fitted(self)
X = self._validate_X_predict(X, check_input)
pred = self.tree_.predict_full(X)
return pred
def predict_alpha_and_jac(self, X, check_input=True):
"""Predict the local jacobian ``E[J | X=x]`` and the local alpha ``E[A | X=x]`` of
a linear moment equation.
Parameters
----------
X : {array-like} of shape (n_samples, n_features)
The input samples. Internally, it will be converted to
``dtype=np.float64``
check_input : bool, default=True
Allow to bypass several input checking.
Don't use this parameter unless you know what you do.
Returns
-------
alpha : array-like of shape (n_samples, n_outputs)
The local alpha E[A | X=x] for each sample x
jac : array-like of shape (n_samples, n_outputs * n_outputs)
The local jacobian E[J | X=x] flattened in a C contiguous format
"""
check_is_fitted(self)
X = self._validate_X_predict(X, check_input)
return self.tree_.predict_precond_and_jac(X)
def predict_moment(self, X, parameter, check_input=True):
"""
Predict the local moment value for each sample and at the given parameter::
E[J | X=x] theta(x) - E[A | X=x]
Parameters
----------
X : {array-like} of shape (n_samples, n_features)
The input samples. Internally, it will be converted to
``dtype=np.float64``
parameter : {array-like} of shape (n_samples, n_outputs)
A parameter estimate for each sample
check_input : bool, default=True
Allow to bypass several input checking.
Don't use this parameter unless you know what you do.
Returns
-------
moment : array-like of shape (n_samples, n_outputs)
The local moment E[J | X=x] theta(x) - E[A | X=x] for each sample x
"""
alpha, jac = self.predict_alpha_and_jac(X)
return alpha - np.einsum('ijk,ik->ij', jac.reshape((-1, self.n_outputs_, self.n_outputs_)), parameter)
def apply(self, X, check_input=True):
"""Return the index of the leaf that each sample is predicted as.
Parameters
----------
X : {array-like} of shape (n_samples, n_features)
The input samples. Internally, it will be converted to
``dtype=np.float64``
check_input : bool, default=True
Allow to bypass several input checking.
Don't use this parameter unless you know what you do.
Returns
-------
X_leaves : array-like of shape (n_samples,)
For each datapoint x in X, return the index of the leaf x
ends up in. Leaves are numbered within
``[0; self.tree_.node_count)``, possibly with gaps in the
numbering.
"""
check_is_fitted(self)
X = self._validate_X_predict(X, check_input)
return self.tree_.apply(X)
def decision_path(self, X, check_input=True):
"""Return the decision path in the tree.
Parameters
----------
X : {array-like} of shape (n_samples, n_features)
The input samples. Internally, it will be converted to
``dtype=np.float64``
check_input : bool, default=True
Allow to bypass several input checking.
Don't use this parameter unless you know what you do.
Returns
-------
indicator : sparse matrix of shape (n_samples, n_nodes)
Return a node indicator CSR matrix where non-zero elements
indicate that the sample goes through the corresponding nodes.
"""
X = self._validate_X_predict(X, check_input)
return self.tree_.decision_path(X)
def feature_importances(self, max_depth=4, depth_decay_exponent=2.0):
"""The feature importances based on the amount of parameter heterogeneity they create.
The higher, the more important the feature.
The importance of a feature is computed as the (normalized) total heterogeneity that the feature
creates. Each split at which the feature was chosen adds::
parent_weight * (left_weight * right_weight)
* mean((value_left[k] - value_right[k])**2) / parent_weight**2
to the importance of the feature. Each such quantity is also weighted by the depth of the split.
Parameters
----------
max_depth : int, default=4
Splits of depth larger than `max_depth` are not used in this calculation
depth_decay_exponent: double, default=2.0
The contribution of each split to the total score is re-weighted by ``1 / (1 + depth)**depth_decay_exponent``.
Returns
-------
feature_importances_ : ndarray of shape (n_features,)
Normalized total parameter heterogeneity inducing importance of each feature
"""
check_is_fitted(self)
return self.tree_.compute_feature_heterogeneity_importances(normalize=True, max_depth=max_depth,
depth_decay=depth_decay_exponent)
@property
def feature_importances_(self):
return self.feature_importances()
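To make the moment notation in the docstring concrete: the value stored at a node is the solution of the local linear moment equation, theta(node) = E[J | node]^{-1} E[A | node]. A pure-numpy sketch for a causal-tree-style moment, assuming J_i = T_i T_i^T and A_i = y_i T_i, no intercept and no sample weights:

    import numpy as np

    rng = np.random.default_rng(0)
    T = rng.binomial(1, .5, size=(200, 1)).astype(float)  # treatments of samples falling in a node
    y = 2.0 * T[:, 0] + rng.normal(size=200)              # outcome with a constant true effect of 2

    # Causal moment E[(y - <theta, T>) * T | X in node] = 0  =>  J = T T^T, A = y * T
    J = (T[:, :, None] * T[:, None, :]).mean(axis=0)      # E[J | node]
    A = (y[:, None] * T).mean(axis=0)                     # E[A | node]
    theta = np.linalg.solve(J, A)                         # node value: E[J | node]^{-1} E[A | node]
    print(theta)                                          # close to [2.]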

82
econml/grf/_criterion.pxd Normal file

@ -0,0 +1,82 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# published under the following license and copyright:
# BSD 3-Clause License
#
# Copyright (c) 2007-2020 The scikit-learn developers.
# All rights reserved.
# See _criterion.pyx for implementation details.
import numpy as np
cimport numpy as np
from ..tree._tree cimport DTYPE_t # Type of X
from ..tree._tree cimport DOUBLE_t # Type of y, sample_weight
from ..tree._tree cimport SIZE_t # Type for indices and counters
from ..tree._tree cimport INT32_t # Signed 32 bit integer
from ..tree._tree cimport UINT32_t # Unsigned 32 bit integer
from ..tree._criterion cimport Criterion, RegressionCriterion
cdef class LinearMomentGRFCriterion(RegressionCriterion):
""" A criterion class that estimates local parameters defined via linear moment equations
of the form:
E[ m(J, A; theta(x)) | X=x] = E[ J * theta(x) - A | X=x] = 0
Calculates impurity based on heterogeneity induced on the estimated parameters, based on the proxy score
defined in the Generalized Random Forest paper:
Athey, Susan, Julie Tibshirani, and Stefan Wager. "Generalized random forests."
The Annals of Statistics 47.2 (2019): 1148-1178
https://arxiv.org/pdf/1610.01271.pdf
"""
cdef const DOUBLE_t[:, ::1] alpha # The A random vector of the linear moment equation for each sample
cdef const DOUBLE_t[:, ::1] pointJ # The J random vector of the linear moment equation for each sample
cdef DOUBLE_t* rho # Proxy heterogeneity label: rho = E[J | X in Node]^{-1} m(J, A; theta(Node))
cdef DOUBLE_t* moment # Moment for each sample: m(J, A; theta(Node))
cdef DOUBLE_t* parameter # Estimated node parameter: theta(Node) = E[J|X in Node]^{-1} E[A|X in Node]
cdef DOUBLE_t* parameter_pre # Preconditioned node parameter: theta_pre(Node) = E[A | X in Node]
cdef DOUBLE_t* J # Node average jacobian: J(Node) = E[J | X in Node]
cdef DOUBLE_t* invJ # Inverse of node average jacobian: J(Node)^{-1}
cdef DOUBLE_t* var_total # The diagonal elements of J(Node) (used for proxy of min eigenvalue)
cdef DOUBLE_t* var_left # The diagonal elements of J(Left) = E[J | X in Left-Child]
cdef DOUBLE_t* var_right # The diagonal elements of J(Right) = E[J | X in Right-Child]
cdef SIZE_t* node_index_mapping # Used internally to map between sample index in y, with sample index in
# internal memory space that stores rho and moment for each sample
cdef DOUBLE_t y_sq_sum_total # The weighted sum of squared raw labels y: \sum_i \sum_k w_i y_{ik}^2
cdef int node_reset_jacobian(self, DOUBLE_t* J, DOUBLE_t* invJ, double* weighted_n_node_samples,
const DOUBLE_t[:, ::1] pointJ,
DOUBLE_t* sample_weight,
SIZE_t* samples, SIZE_t start, SIZE_t end) nogil except -1
cdef int node_reset_parameter(self, DOUBLE_t* parameter, DOUBLE_t* parameter_pre,
DOUBLE_t* invJ,
const DOUBLE_t[:, ::1] alpha,
DOUBLE_t* sample_weight, double weighted_n_node_samples,
SIZE_t* samples, SIZE_t start, SIZE_t end) nogil except -1
cdef int node_reset_rho(self, DOUBLE_t* rho, DOUBLE_t* moment, SIZE_t* node_index_mapping,
DOUBLE_t* parameter, DOUBLE_t* invJ, double weighted_n_node_samples,
const DOUBLE_t[:, ::1] pointJ, const DOUBLE_t[:, ::1] alpha,
DOUBLE_t* sample_weight, SIZE_t* samples,
SIZE_t start, SIZE_t end) nogil except -1
cdef int node_reset_sums(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* rho,
DOUBLE_t* J,
DOUBLE_t* sample_weight, SIZE_t* samples,
DOUBLE_t* sum_total, DOUBLE_t* var_total,
DOUBLE_t* sq_sum_total, DOUBLE_t* y_sq_sum_total,
SIZE_t start, SIZE_t end) nogil except -1
cdef class LinearMomentGRFCriterionMSE(LinearMomentGRFCriterion):
cdef DOUBLE_t* J_left # The jacobian of the left child: J(Left) = E[J | X in Left-Child]
cdef DOUBLE_t* J_right # The jacobian of the right child: J(Right) = E[J | X in Right-Child]
cdef DOUBLE_t* invJ_left # Inverse of the left child jacobian: J(Left)^{-1}
cdef DOUBLE_t* invJ_right # Inverse of the right child jacobian: J(Right)^{-1}
cdef DOUBLE_t* parameter_pre_left
cdef DOUBLE_t* parameter_pre_right
cdef DOUBLE_t* parameter_left
cdef DOUBLE_t* parameter_right
cdef double _get_min_eigv(self, DOUBLE_t* J_child, DOUBLE_t* var_child,
double weighted_n_child) nogil except -1
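As a sanity check on the quantities declared above, the two split scores can be reproduced in plain numpy. The sketch assumes a causal moment (per-sample J_i = T_i T_i^T, A_i = y_i T_i), a single candidate split on a raw feature x, and no sample weights; rho follows the definition rho(child) = J(parent)^{-1} E[A - J theta(parent) | X in child]:

    import numpy as np

    rng = np.random.default_rng(1)
    n, p = 400, 2
    x = rng.normal(size=n)                          # candidate splitting feature
    T = rng.normal(size=(n, p))                     # e.g. residualized treatments
    y = T @ np.array([1.0, -1.0]) * (1 + (x > 0)) + rng.normal(size=n)  # effect heterogeneous in x

    J_i = T[:, :, None] * T[:, None, :]             # per-sample jacobians J_i = T_i T_i^T
    A_i = y[:, None] * T                            # per-sample A_i = y_i T_i
    J_parent = J_i.mean(axis=0)
    theta_parent = np.linalg.solve(J_parent, A_i.mean(axis=0))

    def child_stats(mask):
        rho = np.linalg.solve(J_parent, (A_i[mask] - J_i[mask] @ theta_parent).mean(axis=0))
        return mask.mean(), rho, J_i[mask].mean(axis=0)

    children = [child_stats(x > 0), child_stats(x <= 0)]
    score_het = sum(w * rho @ rho for w, rho, _ in children)        # 'het' criterion proxy
    score_mse = sum(w * rho @ Jc @ rho for w, rho, Jc in children)  # 'mse' criterion proxy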

1144
econml/grf/_criterion.pyx Normal file

Diff not shown because of its large size.

165
econml/grf/_ensemble.py Normal file

@ -0,0 +1,165 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# This code is a fork from:
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/ensemble/_base.py
# published under the following license and copyright:
# BSD 3-Clause License
#
# Copyright (c) 2007-2020 The scikit-learn developers.
# All rights reserved.
import numbers
import numpy as np
from abc import ABCMeta, abstractmethod
from sklearn.base import BaseEstimator, clone
from sklearn.utils import _print_elapsed_time
from sklearn.utils import check_random_state
from joblib import effective_n_jobs
def _fit_single_estimator(estimator, X, y, sample_weight=None,
message_clsname=None, message=None):
"""Private function used to fit an estimator within a job."""
if sample_weight is not None:
try:
with _print_elapsed_time(message_clsname, message):
estimator.fit(X, y, sample_weight=sample_weight)
except TypeError as exc:
if "unexpected keyword argument 'sample_weight'" in str(exc):
raise TypeError(
"Underlying estimator {} does not support sample weights."
.format(estimator.__class__.__name__)
) from exc
raise
else:
with _print_elapsed_time(message_clsname, message):
estimator.fit(X, y)
return estimator
def _set_random_states(estimator, random_state):
"""Set fixed random_state parameters for an estimator.
Finds all parameters ending ``random_state`` and sets them to integers
derived from ``random_state``.
Parameters
----------
estimator : estimator supporting get/set_params
Estimator with potential randomness managed by random_state
parameters.
random_state : np.RandomState object
Pseudo-random number generator to control the generation of the random
integers.
Notes
-----
This does not necessarily set *all* ``random_state`` attributes that
control an estimator's randomness, only those accessible through
``estimator.get_params()``. ``random_state``s not controlled include
those belonging to:
* cross-validation splitters
* ``scipy.stats`` rvs
"""
to_set = {}
for key in sorted(estimator.get_params(deep=True)):
if key == 'random_state' or key.endswith('__random_state'):
to_set[key] = random_state.randint(np.iinfo(np.int32).max)
if to_set:
estimator.set_params(**to_set)
class BaseEnsemble(BaseEstimator, metaclass=ABCMeta):
"""Base class for all ensemble classes.
Warning: This class should not be used directly. Use derived classes
instead.
Parameters
----------
base_estimator : object
The base estimator from which the ensemble is built.
n_estimators : int, default=10
The number of estimators in the ensemble.
estimator_params : list of str, default=tuple()
The list of attributes to use as parameters when instantiating a
new base estimator. If none are given, default parameters are used.
Attributes
----------
base_estimator_ : estimator
The base estimator from which the ensemble is grown.
estimators_ : list of estimators
The collection of fitted base estimators.
"""
@abstractmethod
def __init__(self, base_estimator, *, n_estimators=10,
estimator_params=tuple()):
# Set parameters
self.base_estimator = base_estimator
self.n_estimators = n_estimators
self.estimator_params = estimator_params
# Don't instantiate estimators now! Parameters of base_estimator might
# still change. Eg., when grid-searching with the nested object syntax.
# self.estimators_ needs to be filled by the derived classes in fit.
def _validate_estimator(self, default=None):
"""Check the estimator and the n_estimator attribute.
Sets the base_estimator_` attributes.
"""
if not isinstance(self.n_estimators, numbers.Integral):
raise ValueError("n_estimators must be an integer, "
"got {0}.".format(type(self.n_estimators)))
if self.n_estimators <= 0:
raise ValueError("n_estimators must be greater than zero, "
"got {0}.".format(self.n_estimators))
if self.base_estimator is not None:
self.base_estimator_ = self.base_estimator
else:
self.base_estimator_ = default
if self.base_estimator_ is None:
raise ValueError("base_estimator cannot be None")
def _make_estimator(self, append=True, random_state=None):
"""Make and configure a copy of the `base_estimator_` attribute.
Warning: This method should be used to properly instantiate new
sub-estimators.
"""
estimator = clone(self.base_estimator_)
estimator.set_params(**{p: getattr(self, p)
for p in self.estimator_params})
if random_state is not None:
_set_random_states(estimator, random_state)
if append:
self.estimators_.append(estimator)
return estimator
def __len__(self):
"""Return the number of estimators in the ensemble."""
return len(self.estimators_)
def __getitem__(self, index):
"""Return the index'th estimator in the ensemble."""
return self.estimators_[index]
def __iter__(self):
"""Return iterator over estimators in the ensemble."""
return iter(self.estimators_)
def _partition_estimators(n_estimators, n_jobs):
"""Private function used to partition estimators between jobs."""
# Compute the number of jobs
n_jobs = min(effective_n_jobs(n_jobs), n_estimators)
# Partition estimators between jobs
n_estimators_per_job = np.full(n_jobs, n_estimators // n_jobs,
dtype=np.int)
n_estimators_per_job[:n_estimators % n_jobs] += 1
starts = np.cumsum(n_estimators_per_job)
return n_jobs, n_estimators_per_job.tolist(), [0] + starts.tolist()
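_partition_estimators mirrors scikit-learn's helper: it splits n_estimators as evenly as possible over the effective number of jobs and returns the cumulative start offsets that each worker slices on. A standalone re-implementation of the same arithmetic, to show the output shape (the result of effective_n_jobs depends on the machine):

    import numpy as np
    from joblib import effective_n_jobs

    def partition(n_estimators, n_jobs):
        # same arithmetic as _partition_estimators above
        n_jobs = min(effective_n_jobs(n_jobs), n_estimators)
        per_job = np.full(n_jobs, n_estimators // n_jobs, dtype=int)
        per_job[:n_estimators % n_jobs] += 1
        starts = np.cumsum(per_job)
        return n_jobs, per_job.tolist(), [0] + starts.tolist()

    print(partition(10, 3))   # on a machine with >= 3 cores: (3, [4, 3, 3], [0, 4, 7, 10])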

31
econml/grf/_utils.pxd Normal file

@ -0,0 +1,31 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import numpy as np
cimport numpy as np
ctypedef np.npy_float64 DTYPE_t # Type of X
ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight
ctypedef np.npy_intp SIZE_t # Type for indices and counters
ctypedef np.npy_int32 INT32_t # Signed 32 bit integer
ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer
cpdef bint matinv(DOUBLE_t[::1, :] a, DOUBLE_t[::1, :] inv_a) nogil
cdef bint matinv_(DOUBLE_t* a, DOUBLE_t* inv_a, int m) nogil
cpdef void lstsq(DOUBLE_t[::1,:] a, DOUBLE_t[::1,:] b, DOUBLE_t[::1, :] sol, bint copy_b=*) nogil
cdef void lstsq_(DOUBLE_t* a, DOUBLE_t* b, DOUBLE_t* sol, int m, int n, int ldb, int nrhs, bint copy_b=*) nogil
cpdef void pinv(DOUBLE_t[::1,:] a, DOUBLE_t[::1, :] sol) nogil
cdef void pinv_(DOUBLE_t* a, DOUBLE_t* sol, int m, int n) nogil
cpdef double fast_max_eigv(DOUBLE_t[::1, :] A, int reps, UINT32_t random_state) nogil
cdef double fast_max_eigv_(DOUBLE_t* A, int n, int reps, UINT32_t* random_state) nogil
cpdef double fast_min_eigv(DOUBLE_t[::1, :] A, int reps, UINT32_t random_state) nogil
cdef double fast_min_eigv_(DOUBLE_t* A, int n, int reps, UINT32_t* random_state) nogil

280
econml/grf/_utils.pyx Normal file

@ -0,0 +1,280 @@
# cython: cdivision=True
# cython: boundscheck=False
# cython: wraparound=False
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from libc.stdlib cimport free
from libc.stdlib cimport malloc
from libc.stdlib cimport calloc
from libc.stdlib cimport realloc
from libc.string cimport memcpy
from libc.math cimport log as ln
from libc.stdlib cimport abort
from scipy.linalg.cython_lapack cimport dgelsy, dgetrf, dgetri, dgecon, dlacpy, dlange
import numpy as np
cimport numpy as np
np.import_array()
from ..tree._utils cimport rand_int
rcond_ = np.finfo(np.float64).eps
cdef inline double RCOND = rcond_
# =============================================================================
# Linear Algebra Functions
# =============================================================================
cpdef bint matinv(DOUBLE_t[::1, :] a, DOUBLE_t[::1, :] inv_a) nogil:
""" Compute matrix inverse and store it in inv_a.
"""
cdef int m, n
m = a.shape[0]
if not (m == a.shape[1]):
raise ValueError("Can only invert square matrices!")
return matinv_(&a[0, 0], &inv_a[0, 0], m)
cdef bint matinv_(DOUBLE_t* a, DOUBLE_t* inv_a, int m) nogil:
""" Compute matrix inverse of matrix a of size (m, m) and store it in inv_a.
"""
cdef:
int* pivot
DOUBLE_t* work
int lda, INFO, Lwork
bint failed
lda = m
Lwork = m**2
pivot = <int*> malloc(m * sizeof(int))
work = <DOUBLE_t*> malloc(Lwork * sizeof(DOUBLE_t))
failed = False
if (pivot==NULL or work==NULL):
with gil:
raise MemoryError()
try:
memcpy(inv_a, a, m * m * sizeof(DOUBLE_t))
#Conduct the LU factorization of the array a
dgetrf(&m, &m, inv_a, &lda, pivot, &INFO)
if not (INFO == 0):
failed = True
else:
#Now use the LU factorization and the pivot information to invert
dgetri(&m, inv_a, &lda, pivot, work, &Lwork, &INFO)
if not (INFO == 0):
failed = True
finally:
free(pivot)
free(work)
return (not failed)
cpdef void lstsq(DOUBLE_t[::1, :] a, DOUBLE_t[::1, :] b, DOUBLE_t[::1, :] sol, bint copy_b=True) nogil:
""" Compute solution to least squares problem min ||b - a sol||_2^2,
where a is a matrix of size (m, n), b is (m, nrhs). Store (n, nrhs) solution in sol.
The memory view b must have at least max(m, n) rows. If m < n, then pad the remainder with zeros.
If copy_b=True, then b is left unaltered on output. Otherwise b is altered by this call.
"""
cdef int m, n, nrhs
m = a.shape[0]
n = a.shape[1]
nrhs = b.shape[1]
ldb = b.shape[0]
if ldb < max(m, n):
with gil:
raise ValueError("Matrix b must have first dimension at least max(a.shape[0], a.shape[1]). "
"Please pad with zeros.")
if (sol.shape[0] != n) or (sol.shape[1] != nrhs):
with gil:
raise ValueError("Matrix sol must have dimensions (a.shape[1], b.shape[1]).")
lstsq_(&a[0, 0], &b[0, 0], &sol[0, 0], m, n, ldb, nrhs, copy_b)
cdef void lstsq_(DOUBLE_t* a, DOUBLE_t* b, DOUBLE_t* sol, int m, int n, int ldb, int nrhs, bint copy_b=True) nogil:
""" Compute solution to least squares problem min ||b - a sol||_2^2,
where a is a matrix of size (m, n), b is (m, nrhs). Store (n, nrhs) solution in sol.
The leading (row) dimension of b must be at least max(m, n). If m < n, then pad the remainder with zeros.
If copy_b=True, then b is left unaltered on output. Otherwise b is altered by this call.
"""
cdef:
int lda, rank, info, lwork, n_out
double rcond
Py_ssize_t i, j
#array pointers
int* jpvt
double* work
double* b_copy
char* UPLO = 'O' # Any letter other than 'U' or 'L' will copy the entire array
lda = m
if ldb < max(m, n):
with gil:
raise ValueError("Matrix b must have dimension at least max(a.shape[0], a.shape[1]). "
"Please pad with zeros.")
rcond = max(m, n) * RCOND
jpvt = <int*> calloc(n, sizeof(int))
lwork = max(min(n, m) + 3 * n + 1, 2 * min(n, m) + nrhs)
work = <DOUBLE_t*> malloc(lwork * sizeof(DOUBLE_t))
# TODO. can we avoid all this malloc and copying in our context?
a_copy = <DOUBLE_t*> calloc(lda * n, sizeof(DOUBLE_t))
if copy_b:
b_copy = <DOUBLE_t*> calloc(ldb * nrhs, sizeof(DOUBLE_t))
else:
b_copy = b
try:
dlacpy(UPLO, &lda, &n, a, &lda, a_copy, &lda)
if copy_b:
dlacpy(UPLO, &ldb, &nrhs, b, &ldb, b_copy, &ldb)
dgelsy(&m, &n, &nrhs, a_copy, &lda, b_copy, &ldb,
&jpvt[0], &rcond, &rank, &work[0], &lwork, &info)
for i in range(n):
for j in range(nrhs):
sol[i + j * n] = b_copy[i + j * ldb]
finally:
free(jpvt)
free(work)
free(a_copy)
if copy_b:
free(b_copy)
cpdef void pinv(DOUBLE_t[::1,:] a, DOUBLE_t[::1, :] sol) nogil:
""" Compute pseudo-inverse of (m, n) matrix a and store it in (n, m) matrix sol.
Matrix a is left un-altered by this call.
"""
cdef int m = a.shape[0]
cdef int n = a.shape[1]
pinv_(&a[0, 0], &sol[0, 0], m, n)
cdef void pinv_(DOUBLE_t* a, DOUBLE_t* sol, int m, int n) nogil:
""" Compute pseudo-inverse of (m, n) matrix a and store it in (n, m) matrix sol.
Matrix a is left un-altered by this call.
"""
# TODO. can we avoid this malloc in our context? Maybe create some fixed memory allocations?
cdef int ldb = max(m, n)
cdef double* b = <DOUBLE_t*> calloc(ldb * m, sizeof(double))
cdef Py_ssize_t i
for i in range(m):
b[i + i * ldb] = 1.0
try:
lstsq_(a, b, sol, m, n, ldb, m, copy_b=False)
finally:
free(b)
cpdef double fast_max_eigv(DOUBLE_t[::1, :] A, int reps, UINT32_t random_state) nogil:
""" Calculate approximation of maximum eigenvalue via randomized power iteration algorithm.
See e.g.: http://theory.stanford.edu/~trevisan/expander-online/lecture03.pdf
Use reps repetition and random seed based on random_state
"""
return fast_max_eigv_(&A[0, 0], A.shape[0], reps, &random_state)
cdef double fast_max_eigv_(DOUBLE_t* A, int n, int reps, UINT32_t* random_state) nogil:
""" Calculate approximation of maximum eigenvalue via randomized power iteration algorithm.
See e.g.: http://theory.stanford.edu/~trevisan/expander-online/lecture03.pdf
Use reps repetition and random seed based on random_state
"""
cdef int t, i, j
cdef double normx, Anormx
cdef double* xnew
cdef double* xold
cdef double* temp
xnew = NULL
xold = NULL
try:
xnew = <double*> calloc(n, sizeof(double))
xold = <double*> calloc(n, sizeof(double))
if xnew == NULL or xold == NULL:
with gil:
raise MemoryError()
for i in range(n):
xold[i] = (1 - 2*rand_int(0, 2, random_state))
for t in range(reps):
for i in range(n):
xnew[i] = 0
for j in range(n):
xnew[i] += A[i + j * n] * xold[j]
temp = xold
xold = xnew
xnew = temp
normx = 0
Anormx = 0
for i in range(n):
normx += xnew[i] * xnew[i]
for j in range(n):
Anormx += xnew[i] * A[i + j * n] * xnew[j]
return Anormx / normx
finally:
free(xnew)
free(xold)
cpdef double fast_min_eigv(DOUBLE_t[::1, :] A, int reps, UINT32_t random_state) nogil:
""" Calculate approximation of minimum eigenvalue via randomized power iteration algorithm.
See e.g.: http://theory.stanford.edu/~trevisan/expander-online/lecture03.pdf
Use reps repetition and random seed based on random_state
"""
return fast_min_eigv_(&A[0, 0], A.shape[0], reps, &random_state)
cdef double fast_min_eigv_(DOUBLE_t* A, int n, int reps, UINT32_t* random_state) nogil:
""" Calculate approximation of minimum eigenvalue via randomized power iteration algorithm.
See e.g.: http://theory.stanford.edu/~trevisan/expander-online/lecture03.pdf
Use reps repetition and random seed based on random_state.
"""
cdef int t, i, j
cdef double normx, Anormx
cdef double* xnew
cdef double* xold
cdef double* temp
cdef double* update
xnew = NULL
xold = NULL
try:
xnew = <double*> calloc(n, sizeof(double))
xold = <double*> calloc(n, sizeof(double))
update = <double*> calloc(n, sizeof(double))
if xnew == NULL or xold == NULL or update == NULL:
with gil:
raise MemoryError()
for i in range(n):
xold[i] = (1 - 2*rand_int(0, 2, random_state))
for t in range(reps):
lstsq_(A, xold, update, n, n, n, 1, copy_b=False)
for i in range(n):
xnew[i] = 0
for j in range(n):
xnew[i] += update[i]
temp = xold
xold = xnew
xnew = temp
normx = 0
Anormx = 0
for i in range(n):
normx += xnew[i] * xnew[i]
for j in range(n):
Anormx += xnew[i] * A[i + j * n] * xnew[j]
return Anormx / normx
finally:
free(xnew)
free(xold)
free(update)
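fast_max_eigv / fast_min_eigv trade exactness for speed inside the splitter. A rough numpy equivalent of the same power-iteration idea, compared against np.linalg.eigvalsh; this is only a sketch of the algorithm, not a wrapper around the exported functions:

    import numpy as np

    def max_eigv_approx(A, reps=10, seed=0):
        # randomized power iteration with a +/-1 starting vector, Rayleigh quotient at the end
        x = np.random.default_rng(seed).choice([-1.0, 1.0], size=A.shape[0])
        for _ in range(reps):
            x = A @ x
        return (x @ A @ x) / (x @ x)

    def min_eigv_approx(A, reps=10, seed=0):
        # inverse iteration: repeatedly solve A x_new = x_old, as in fast_min_eigv_
        x = np.random.default_rng(seed).choice([-1.0, 1.0], size=A.shape[0])
        for _ in range(reps):
            x = np.linalg.lstsq(A, x, rcond=None)[0]
        return (x @ A @ x) / (x @ x)

    B = np.random.default_rng(1).normal(size=(4, 4))
    A = B @ B.T + np.eye(4)   # symmetric positive definite test matrix
    print(max_eigv_approx(A), np.linalg.eigvalsh(A).max())
    print(min_eigv_approx(A), np.linalg.eigvalsh(A).min())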

1027
econml/grf/classes.py Normal file

Diff not shown because of its large size.


@ -345,17 +345,19 @@ class GenericModelFinalInferenceDiscrete(Inference):
def const_marginal_effect_interval(self, X, *, alpha=0.1):
if (X is not None) and (self.featurizer is not None):
X = self.featurizer.transform(X)
preds = np.array([mdl.predict_interval(X, alpha=alpha) for mdl in self.fitted_models_final])
preds = np.array([tuple(map(lambda x: x.reshape((-1,) + self._d_y), mdl.predict_interval(X, alpha=alpha)))
for mdl in self.fitted_models_final])
return tuple(np.moveaxis(preds, [0, 1], [-1, 0])) # send treatment to the end, pull bounds to the front
def const_marginal_effect_inference(self, X):
if (X is not None) and (self.featurizer is not None):
X = self.featurizer.transform(X)
pred = np.array([mdl.predict(X) for mdl in self.fitted_models_final])
pred = np.array([mdl.predict(X).reshape((-1,) + self._d_y) for mdl in self.fitted_models_final])
if not hasattr(self.fitted_models_final[0], 'prediction_stderr'):
raise AttributeError("Final model doesn't support prediction standard eror, "
"please call const_marginal_effect_interval to get confidence interval.")
pred_stderr = np.array([mdl.prediction_stderr(X) for mdl in self.fitted_models_final])
pred_stderr = np.array([mdl.prediction_stderr(X).reshape((-1,) + self._d_y)
for mdl in self.fitted_models_final])
return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=np.moveaxis(pred, 0, -1),
# send treatment to the end, pull bounds to the front
pred_stderr=np.moveaxis(pred_stderr, 0, -1), inf_type='effect',


@ -9,7 +9,7 @@ For more details on these CATE methods, see <https://arxiv.org/abs/1706.03461>
import numpy as np
import warnings
from .cate_estimator import BaseCateEstimator, LinearCateEstimator, TreatmentExpansionMixin
from ._cate_estimator import BaseCateEstimator, LinearCateEstimator, TreatmentExpansionMixin
from sklearn import clone
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
@ -17,7 +17,7 @@ from sklearn.utils import check_array, check_X_y
from sklearn.preprocessing import OneHotEncoder, FunctionTransformer
from .utilities import (check_inputs, check_models, broadcast_unit_treatments, reshape_treatmentwise_effects,
inverse_onehot, transpose, _EncoderWrapper, _deprecate_positional)
from .shap import _shap_explain_model_cate
from ._shap import _shap_explain_model_cate
class TLearner(TreatmentExpansionMixin, LinearCateEstimator):
@ -460,8 +460,11 @@ class DomainAdaptationLearner(TreatmentExpansionMixin, LinearCateEstimator):
last_step_name = model_instance.steps[-1][0]
model_instance.fit(X, y, **{"{0}__sample_weight".format(last_step_name): sample_weight})
def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None):
def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None, background_samples=100):
return _shap_explain_model_cate(self.const_marginal_effect, self.final_models, X, self._d_t, self._d_y,
feature_names=feature_names,
treatment_names=treatment_names, output_names=output_names)
treatment_names=treatment_names,
output_names=output_names,
input_names=self._input_names,
background_samples=background_samples)
shap_values.__doc__ = LinearCateEstimator.shap_values.__doc__
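A hedged usage sketch of the new background_samples keyword (only the keyword and its default of 100 come from this diff; the estimator construction and data below are illustrative assumptions): capping the background set keeps the SHAP computation tractable on large training sets.

import numpy as np
from econml.metalearners import DomainAdaptationLearner
from sklearn.ensemble import GradientBoostingRegressor

X = np.random.normal(size=(500, 5))
T = np.random.binomial(1, 0.5, size=(500,))
Y = (X[:, 0] > 0) * T + np.random.normal(scale=0.1, size=(500,))

# Constructor arguments follow the econml metalearner API of this release (assumed here).
est = DomainAdaptationLearner(models=GradientBoostingRegressor(),
                              final_models=GradientBoostingRegressor())
est.fit(Y, T, X=X)

# Limit the SHAP background set to at most 100 samples (the new default) to bound computation.
shap_vals = est.shap_values(X[:20], background_samples=100)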


@ -35,8 +35,8 @@ from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, LabelEncoder, PolynomialFeatures, FunctionTransformer
from sklearn.utils import check_random_state, check_array, column_or_1d
from .sklearn_extensions.linear_model import WeightedLassoCVWrapper
from .cate_estimator import BaseCateEstimator, LinearCateEstimator, TreatmentExpansionMixin
from .causal_tree import CausalTree
from ._cate_estimator import BaseCateEstimator, LinearCateEstimator, TreatmentExpansionMixin
from ._causal_tree import CausalTree
from .inference import Inference, NormalInferenceResults
from .utilities import (reshape, reshape_Y_T, MAX_RAND_SEED, check_inputs, _deprecate_positional,
cross_product, inverse_onehot, _EncoderWrapper, check_input_arrays,
@ -46,19 +46,10 @@ from sklearn.model_selection import check_cv
from .sklearn_extensions.model_selection import _cross_val_predict
def _build_tree_in_parallel(Y, T, X, W,
nuisance_estimator,
parameter_estimator,
moment_and_mean_gradient_estimator,
min_leaf_size, max_depth, random_state):
tree = CausalTree(nuisance_estimator=nuisance_estimator,
parameter_estimator=parameter_estimator,
moment_and_mean_gradient_estimator=moment_and_mean_gradient_estimator,
min_leaf_size=min_leaf_size,
max_depth=max_depth,
random_state=random_state)
def _build_tree_in_parallel(tree, Y, T, X, W,
nuisance_estimator, parameter_estimator, moment_and_mean_gradient_estimator):
# Create splits of causal tree
tree.create_splits(Y, T, X, W)
tree.create_splits(Y, T, X, W, nuisance_estimator, parameter_estimator, moment_and_mean_gradient_estimator)
return tree
@ -223,6 +214,9 @@ class BaseOrthoForest(TreatmentExpansionMixin, LinearCateEstimator):
subsample_ratio=0.25,
bootstrap=False,
n_jobs=-1,
backend='loky',
verbose=3,
batch_size='auto',
random_state=None):
# Estimators
self.nuisance_estimator = nuisance_estimator
@ -249,6 +243,9 @@ class BaseOrthoForest(TreatmentExpansionMixin, LinearCateEstimator):
# Fit check
self.model_is_fitted = False
self.discrete_treatment = discrete_treatment
self.backend = backend
self.verbose = verbose
self.batch_size = batch_size
super().__init__()
@_deprecate_positional("X and W should be passed by keyword only. In a future release "
@ -325,7 +322,8 @@ class BaseOrthoForest(TreatmentExpansionMixin, LinearCateEstimator):
if not self.model_is_fitted:
raise NotFittedError('This {0} instance is not fitted yet.'.format(self.__class__.__name__))
X = check_array(X)
results = Parallel(n_jobs=self.n_jobs, verbose=3)(
results = Parallel(n_jobs=self.n_jobs, backend=self.backend,
batch_size=self.batch_size, verbose=self.verbose)(
delayed(_pointwise_effect)(X_single, *self._pw_effect_inputs(X_single, stderr=stderr),
self.second_stage_nuisance_estimator, self.second_stage_parameter_estimator,
self.moment_and_mean_gradient_estimator, self.slice_len, self.n_slices,
@ -371,14 +369,17 @@ class BaseOrthoForest(TreatmentExpansionMixin, LinearCateEstimator):
# Generate subsample indices
subsample_ind = self._get_blb_indices(X)
# Build trees in parallel
return subsample_ind, Parallel(n_jobs=self.n_jobs, verbose=3, max_nbytes=None)(
delayed(_build_tree_in_parallel)(
Y[s], T[s], X[s], W[s] if W is not None else None,
self.nuisance_estimator,
self.parameter_estimator,
self.moment_and_mean_gradient_estimator,
self.min_leaf_size, self.max_depth,
self.random_state.randint(MAX_RAND_SEED)) for s in subsample_ind)
trees = [CausalTree(self.min_leaf_size, self.max_depth, 1000, .4,
check_random_state(self.random_state.randint(MAX_RAND_SEED)))
for _ in range(len(subsample_ind))]
return subsample_ind, Parallel(n_jobs=self.n_jobs, backend=self.backend,
batch_size=self.batch_size, verbose=self.verbose, max_nbytes=None)(
delayed(_build_tree_in_parallel)(tree,
Y[s], T[s], X[s], W[s] if W is not None else None,
self.nuisance_estimator,
self.parameter_estimator,
self.moment_and_mean_gradient_estimator)
for s, tree in zip(subsample_ind, trees))
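The pattern in this hunk (build the CausalTree objects up front, then hand each one to a module-level worker through joblib.Parallel with a configurable backend, verbose level, and batch_size) can be sketched in isolation as follows; the worker and tree class here are stand-ins, not the econml internals.

from joblib import Parallel, delayed
import numpy as np

class ToyTree:
    # Stand-in for CausalTree: just remembers which rows it was fit on.
    def __init__(self, seed):
        self.seed = seed
    def fit(self, idx):
        self.idx = np.asarray(idx)
        return self

def _fit_one(tree, idx):
    # Module-level worker so it can be shipped to process-based backends such as 'loky'.
    return tree.fit(idx)

rng = np.random.RandomState(123)
subsample_ind = [rng.choice(100, 30, replace=False) for _ in range(8)]
trees = [ToyTree(seed=rng.randint(2**31 - 1)) for _ in subsample_ind]

fitted = Parallel(n_jobs=-1, backend='loky', batch_size='auto', verbose=0)(
    delayed(_fit_one)(tree, idx) for tree, idx in zip(trees, subsample_ind))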
def _get_weights(self, X_single, tree_slice=None):
"""Calculate weights for a single input feature vector over a subset of trees.
@ -492,7 +493,7 @@ class DMLOrthoForest(BaseOrthoForest):
power, especially when W is not None.
global_res_cv : int, cross-validation generator or an iterable, optional (default=2)
The specification of the cv splitter to be used for cross-fitting, when constructing
The specification of the CV splitter to be used for cross-fitting, when constructing
the global residuals of Y and T.
discrete_treatment : bool, optional (default=False)
@ -509,6 +510,15 @@ class DMLOrthoForest(BaseOrthoForest):
``-1`` means using all processors. Since OrthoForest methods are
computationally heavy, it is recommended to set `n_jobs` to -1.
backend : 'threading' or 'loky', optional (default='loky')
The joblib backend to use for parallelization.
verbose : int, optional (default=3)
Verbosity level of the parallel jobs.
batch_size : int or 'auto', optional (default='auto')
Batch size of the jobs dispatched to the parallel backend.
random_state : int, :class:`~numpy.random.mtrand.RandomState` instance or None, optional (default=None)
If int, random_state is the seed used by the random number generator;
If :class:`~numpy.random.mtrand.RandomState` instance, random_state is the random number generator;
@ -532,6 +542,9 @@ class DMLOrthoForest(BaseOrthoForest):
discrete_treatment=False,
categories='auto',
n_jobs=-1,
backend='loky',
verbose=3,
batch_size='auto',
random_state=None):
# Copy and/or define models
self.lambda_reg = lambda_reg
@ -555,18 +568,18 @@ class DMLOrthoForest(BaseOrthoForest):
self.global_residualization = global_residualization
self.global_res_cv = global_res_cv
# Define nuisance estimators
nuisance_estimator = DMLOrthoForest.nuisance_estimator_generator(
nuisance_estimator = _DMLOrthoForest_nuisance_estimator_generator(
self.model_T, self.model_Y, self.random_state, second_stage=False,
global_residualization=self.global_residualization, discrete_treatment=discrete_treatment)
second_stage_nuisance_estimator = DMLOrthoForest.nuisance_estimator_generator(
second_stage_nuisance_estimator = _DMLOrthoForest_nuisance_estimator_generator(
self.model_T_final, self.model_Y_final, self.random_state, second_stage=True,
global_residualization=self.global_residualization, discrete_treatment=discrete_treatment)
# Define parameter estimators
parameter_estimator = DMLOrthoForest.parameter_estimator_func
second_stage_parameter_estimator = DMLOrthoForest.second_stage_parameter_estimator_gen(
parameter_estimator = _DMLOrthoForest_parameter_estimator_func
second_stage_parameter_estimator = _DMLOrthoForest_second_stage_parameter_estimator_gen(
self.lambda_reg)
# Define
moment_and_mean_gradient_estimator = DMLOrthoForest.moment_and_mean_gradient_estimator_func
moment_and_mean_gradient_estimator = _DMLOrthoForest_moment_and_mean_gradient_estimator_func
if discrete_treatment:
if categories != 'auto':
categories = [categories] # OneHotEncoder expects a 2D array with features per column
@ -583,6 +596,9 @@ class DMLOrthoForest(BaseOrthoForest):
subsample_ratio=subsample_ratio,
bootstrap=bootstrap,
n_jobs=n_jobs,
backend=backend,
verbose=verbose,
batch_size=batch_size,
discrete_treatment=discrete_treatment,
categories=categories,
random_state=self.random_state)
@ -654,135 +670,144 @@ class DMLOrthoForest(BaseOrthoForest):
return effects.reshape((-1,) + self._d_y + self._d_t)
const_marginal_effect.__doc__ = BaseOrthoForest.const_marginal_effect.__doc__
@staticmethod
def nuisance_estimator_generator(model_T, model_Y, random_state=None, second_stage=True,
global_residualization=False, discrete_treatment=False):
"""Generate nuissance estimator given model inputs from the class."""
def nuisance_estimator(Y, T, X, W, sample_weight=None, split_indices=None):
if global_residualization:
return 0
if discrete_treatment:
# Check that all discrete treatments are represented
if len(np.unique(T @ np.arange(1, T.shape[1] + 1))) < T.shape[1] + 1:
return None
# Nuisance estimates evaluated with cross-fitting
this_random_state = check_random_state(random_state)
if (split_indices is None) and second_stage:
if discrete_treatment:
# Define 2-fold iterator
kfold_it = StratifiedKFold(n_splits=2, shuffle=True, random_state=this_random_state).split(X, T)
# Check if there is only one example of some class
with warnings.catch_warnings():
warnings.filterwarnings('error')
try:
split_indices = list(kfold_it)[0]
except Warning as warn:
msg = str(warn)
if "The least populated class in y has only 1 members" in msg:
return None
else:
# Define 2-fold iterator
kfold_it = KFold(n_splits=2, shuffle=True, random_state=this_random_state).split(X)
split_indices = list(kfold_it)[0]
if W is not None:
X_tilde = np.concatenate((X, W), axis=1)
class _DMLOrthoForest_nuisance_estimator_generator:
"""Generate nuissance estimator given model inputs from the class."""
def __init__(self, model_T, model_Y, random_state=None, second_stage=True,
global_residualization=False, discrete_treatment=False):
self.model_T = model_T
self.model_Y = model_Y
self.random_state = random_state
self.second_stage = second_stage
self.global_residualization = global_residualization
self.discrete_treatment = discrete_treatment
def __call__(self, Y, T, X, W, sample_weight=None, split_indices=None):
if self.global_residualization:
return 0
if self.discrete_treatment:
# Check that all discrete treatments are represented
if len(np.unique(T @ np.arange(1, T.shape[1] + 1))) < T.shape[1] + 1:
return None
# Nuisance estimates evaluated with cross-fitting
this_random_state = check_random_state(self.random_state)
if (split_indices is None) and self.second_stage:
if self.discrete_treatment:
# Define 2-fold iterator
kfold_it = StratifiedKFold(n_splits=2, shuffle=True, random_state=this_random_state).split(X, T)
# Check if there is only one example of some class
with warnings.catch_warnings():
warnings.filterwarnings('error')
try:
split_indices = list(kfold_it)[0]
except Warning as warn:
msg = str(warn)
if "The least populated class in y has only 1 members" in msg:
return None
else:
X_tilde = X
# Define 2-fold iterator
kfold_it = KFold(n_splits=2, shuffle=True, random_state=this_random_state).split(X)
split_indices = list(kfold_it)[0]
if W is not None:
X_tilde = np.concatenate((X, W), axis=1)
else:
X_tilde = X
try:
if second_stage:
T_hat = _cross_fit(model_T, X_tilde, T, split_indices, sample_weight=sample_weight)
Y_hat = _cross_fit(model_Y, X_tilde, Y, split_indices, sample_weight=sample_weight)
else:
# need safe=False when cloning for WeightedModelWrapper
T_hat = clone(model_T, safe=False).fit(X_tilde, T).predict(X_tilde)
Y_hat = clone(model_Y, safe=False).fit(X_tilde, Y).predict(X_tilde)
except ValueError as exc:
raise ValueError("The original error: {0}".format(str(exc)) +
" This might be caused by too few sample in the tree leafs." +
" Try increasing the min_leaf_size.")
return Y_hat, T_hat
try:
if self.second_stage:
T_hat = _cross_fit(self.model_T, X_tilde, T, split_indices, sample_weight=sample_weight)
Y_hat = _cross_fit(self.model_Y, X_tilde, Y, split_indices, sample_weight=sample_weight)
else:
# need safe=False when cloning for WeightedModelWrapper
T_hat = clone(self.model_T, safe=False).fit(X_tilde, T).predict(X_tilde)
Y_hat = clone(self.model_Y, safe=False).fit(X_tilde, Y).predict(X_tilde)
except ValueError as exc:
raise ValueError("The original error: {0}".format(str(exc)) +
" This might be caused by too few sample in the tree leafs." +
" Try increasing the min_leaf_size.")
return Y_hat, T_hat
return nuisance_estimator
@staticmethod
def parameter_estimator_func(Y, T, X,
nuisance_estimates,
sample_weight=None):
"""Calculate the parameter of interest for points given by (Y, T) and corresponding nuisance estimates."""
def _DMLOrthoForest_parameter_estimator_func(Y, T, X,
nuisance_estimates,
sample_weight=None):
"""Calculate the parameter of interest for points given by (Y, T) and corresponding nuisance estimates."""
# Compute residuals
Y_res, T_res = _DMLOrthoForest_get_conforming_residuals(Y, T, nuisance_estimates)
# Compute coefficient by OLS on residuals
param_estimate = LinearRegression(fit_intercept=False).fit(
T_res, Y_res, sample_weight=sample_weight
).coef_
# Parameter returned by LinearRegression is (d_T, )
return param_estimate
class _DMLOrthoForest_second_stage_parameter_estimator_gen:
"""
For the second-stage parameter estimation we add a local linear correction: we fit
a local linear function rather than a local constant function, and we penalize
the linear part to reduce variance.
"""
def __init__(self, lambda_reg):
self.lambda_reg = lambda_reg
def __call__(self, Y, T, X,
nuisance_estimates,
sample_weight,
X_single):
"""Calculate the parameter of interest for points given by (Y, T) and corresponding nuisance estimates.
The parameter is calculated around the feature vector given by `X_single`. `X_single` can be used to do
local corrections on a preliminary parameter estimate.
"""
# Compute residuals
Y_res, T_res = DMLOrthoForest._get_conforming_residuals(Y, T, nuisance_estimates)
Y_res, T_res = _DMLOrthoForest_get_conforming_residuals(Y, T, nuisance_estimates)
X_aug = np.hstack([np.ones((X.shape[0], 1)), X])
XT_res = cross_product(T_res, X_aug)
# Compute coefficient by OLS on residuals
param_estimate = LinearRegression(fit_intercept=False).fit(
T_res, Y_res, sample_weight=sample_weight
).coef_
# Parameter returned by LinearRegression is (d_T, )
return param_estimate
if sample_weight is not None:
weighted_XT_res = sample_weight.reshape(-1, 1) * XT_res
else:
weighted_XT_res = XT_res / XT_res.shape[0]
# ell_2 regularization
diagonal = np.ones(XT_res.shape[1])
diagonal[:T_res.shape[1]] = 0
reg = self.lambda_reg * np.diag(diagonal)
# Ridge regression estimate
linear_coef_estimate = np.linalg.lstsq(np.matmul(weighted_XT_res.T, XT_res) + reg,
np.matmul(weighted_XT_res.T, Y_res.reshape(-1, 1)),
rcond=None)[0].flatten()
X_aug = np.append([1], X_single)
linear_coef_estimate = linear_coef_estimate.reshape((X_aug.shape[0], -1)).T
# Parameter returned is of shape (d_T, )
return np.dot(linear_coef_estimate, X_aug)
@staticmethod
def second_stage_parameter_estimator_gen(lambda_reg):
"""
For the second stage parameter estimation we add a local linear correction. So
we fit a local linear function as opposed to a local constant function. We also penalize
the linear part to reduce variance.
"""
def parameter_estimator_func(Y, T, X,
nuisance_estimates,
sample_weight,
X_single):
"""Calculate the parameter of interest for points given by (Y, T) and corresponding nuisance estimates.
The parameter is calculated around the feature vector given by `X_single`. `X_single` can be used to do
local corrections on a preliminary parameter estimate.
"""
# Compute residuals
Y_res, T_res = DMLOrthoForest._get_conforming_residuals(Y, T, nuisance_estimates)
X_aug = np.hstack([np.ones((X.shape[0], 1)), X])
XT_res = cross_product(T_res, X_aug)
# Compute coefficient by OLS on residuals
if sample_weight is not None:
weighted_XT_res = sample_weight.reshape(-1, 1) * XT_res
else:
weighted_XT_res = XT_res / XT_res.shape[0]
# ell_2 regularization
diagonal = np.ones(XT_res.shape[1])
diagonal[:T_res.shape[1]] = 0
reg = lambda_reg * np.diag(diagonal)
# Ridge regression estimate
linear_coef_estimate = np.linalg.lstsq(np.matmul(weighted_XT_res.T, XT_res) + reg,
np.matmul(weighted_XT_res.T, Y_res.reshape(-1, 1)),
rcond=None)[0].flatten()
X_aug = np.append([1], X_single)
linear_coef_estimate = linear_coef_estimate.reshape((X_aug.shape[0], -1)).T
# Parameter returned is of shape (d_T, )
return np.dot(linear_coef_estimate, X_aug)
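Both the new callable class and the staticmethod it replaces implement the same second-stage calculation: regress the residualized outcome on cross products of the residualized treatment with [1, X], penalize only the part that is linear in X, and evaluate the fitted local-linear function at the target point. A minimal NumPy sketch of that calculation (the function name and the synthetic data are illustrative):

import numpy as np

def local_linear_effect(Y_res, T_res, X, x_target, lambda_reg=0.01):
    # Ridge-regularized local-linear estimate of the treatment effect at x_target.
    n, d_t = T_res.shape
    X_aug = np.hstack([np.ones((n, 1)), X])                        # [1, X]
    # Cross products ordered so the first d_t columns are T_res times the intercept
    XT_res = np.einsum('ik,ij->ikj', X_aug, T_res).reshape(n, -1)
    W = XT_res / n                                                 # uniform weights
    # Penalize only the coefficients that multiply X, not the constant-in-X block
    diagonal = np.ones(XT_res.shape[1])
    diagonal[:d_t] = 0
    reg = lambda_reg * np.diag(diagonal)
    coef = np.linalg.lstsq(W.T @ XT_res + reg, W.T @ Y_res.reshape(-1, 1), rcond=None)[0].ravel()
    # Evaluate the local-linear function at the target point
    return coef.reshape((X_aug.shape[1], -1)).T @ np.append([1.0], x_target)

rng = np.random.RandomState(0)
X = rng.normal(size=(200, 2))
T_res = rng.normal(size=(200, 1))
Y_res = (1 + X[:, [0]]) * T_res + rng.normal(scale=0.1, size=(200, 1))
print(local_linear_effect(Y_res, T_res, X, x_target=np.zeros(2)))  # roughly [1.0]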
def _DMLOrthoForest_moment_and_mean_gradient_estimator_func(Y, T, X, W,
nuisance_estimates,
parameter_estimate):
"""Calculate the moments and mean gradient at points given by (Y, T, X, W)."""
# Return moments and gradients
# Compute residuals
Y_res, T_res = _DMLOrthoForest_get_conforming_residuals(Y, T, nuisance_estimates)
# Compute moments
# Moments shape is (n, d_T)
moments = (Y_res - np.matmul(T_res, parameter_estimate)).reshape(-1, 1) * T_res
# Compute moment gradients
mean_gradient = - np.matmul(T_res.T, T_res) / T_res.shape[0]
return moments, mean_gradient
return parameter_estimator_func
@staticmethod
def moment_and_mean_gradient_estimator_func(Y, T, X, W,
nuisance_estimates,
parameter_estimate):
"""Calculate the moments and mean gradient at points given by (Y, T, X, W)."""
# Return moments and gradients
# Compute residuals
Y_res, T_res = DMLOrthoForest._get_conforming_residuals(Y, T, nuisance_estimates)
# Compute moments
# Moments shape is (n, d_T)
moments = (Y_res - np.matmul(T_res, parameter_estimate)).reshape(-1, 1) * T_res
# Compute moment gradients
mean_gradient = - np.matmul(T_res.T, T_res) / T_res.shape[0]
return moments, mean_gradient
@staticmethod
def _get_conforming_residuals(Y, T, nuisance_estimates):
if nuisance_estimates == 0:
return reshape_Y_T(Y, T)
# returns shape-conforming residuals
Y_hat, T_hat = reshape_Y_T(*nuisance_estimates)
Y, T = reshape_Y_T(Y, T)
Y_res, T_res = Y - Y_hat, T - T_hat
return Y_res, T_res
def _DMLOrthoForest_get_conforming_residuals(Y, T, nuisance_estimates):
if nuisance_estimates == 0:
return reshape_Y_T(Y, T)
# returns shape-conforming residuals
Y_hat, T_hat = reshape_Y_T(*nuisance_estimates)
Y, T = reshape_Y_T(Y, T)
Y_res, T_res = Y - Y_hat, T - T_hat
return Y_res, T_res
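A plausible reason for converting these helpers from staticmethod-generated closures into module-level functions and callable classes (this is our reading, consistent with the new 'loky' backend option, not something stated in the diff) is that process-based joblib backends must serialize whatever they ship to worker processes, and closures defined inside a function cannot be serialized by the standard pickle module, while module-level callables can. A tiny illustration:

import pickle

class ScaledAdder:
    # Module-level callable class: picklable, so process-based workers can receive it.
    def __init__(self, scale):
        self.scale = scale
    def __call__(self, x):
        return self.scale * x

def make_adder(scale):
    # Closure defined inside a function: not picklable with the standard pickle module.
    def adder(x):
        return scale * x
    return adder

pickle.dumps(ScaledAdder(2.0))       # works
try:
    pickle.dumps(make_adder(2.0))    # fails: can't pickle a local object
except Exception as err:
    print(type(err).__name__)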
class DROrthoForest(BaseOrthoForest):
@ -849,6 +874,15 @@ class DROrthoForest(BaseOrthoForest):
``-1`` means using all processors. Since OrthoForest methods are
computationally heavy, it is recommended to set `n_jobs` to -1.
backend : 'threading' or 'loky', optional (default='loky')
The joblib backend to use for parallelization.
verbose : int, optional (default=3)
Verbosity level of the parallel jobs.
batch_size : int or 'auto', optional (default='auto')
Batch size of the jobs dispatched to the parallel backend.
random_state : int, :class:`~numpy.random.mtrand.RandomState` instance or None, optional (default=None)
If int, random_state is the seed used by the random number generator;
If :class:`~numpy.random.mtrand.RandomState` instance, random_state is the random number generator;
@ -871,6 +905,9 @@ class DROrthoForest(BaseOrthoForest):
model_Y_final=None,
categories='auto',
n_jobs=-1,
backend='loky',
verbose=3,
batch_size='auto',
random_state=None):
# Copy and/or define models
self.propensity_model = clone(propensity_model, safe=False)
@ -911,6 +948,9 @@ class DROrthoForest(BaseOrthoForest):
subsample_ratio=subsample_ratio,
bootstrap=bootstrap,
n_jobs=n_jobs,
backend=backend,
verbose=verbose,
batch_size=batch_size,
random_state=self.random_state)
@_deprecate_positional("X and W should be passed by keyword only. In a future release "


@ -20,8 +20,8 @@ from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer
from ._ortho_learner import _OrthoLearner
from .cate_estimator import StatsModelsCateEstimatorMixin
from .dml import _FinalWrapper
from ._cate_estimator import StatsModelsCateEstimatorMixin
from .dml.dml import _FinalWrapper
from .inference import StatsModelsInference
from .sklearn_extensions.linear_model import StatsModelsLinearRegression
from .utilities import (_deprecate_positional, add_intercept, fit_with_groups, filter_none_kwargs,
@ -483,7 +483,7 @@ class _BaseDMLIV(_OrthoLearner):
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`cv splitter`
- :term:`CV splitter`
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the treatment is discrete
@ -735,7 +735,7 @@ class DMLIV(_BaseDMLIV):
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`cv splitter`
- :term:`CV splitter`
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the treatment is discrete
@ -830,7 +830,7 @@ class NonParamDMLIV(_BaseDMLIV):
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`cv splitter`
- :term:`CV splitter`
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the treatment is discrete
@ -1025,7 +1025,7 @@ class _BaseDRIV(_OrthoLearner):
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`cv splitter`
- :term:`CV splitter`
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the treatment is discrete
@ -1321,7 +1321,7 @@ class IntentToTreatDRIV(_IntentToTreatDRIV):
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`cv splitter`
- :term:`CV splitter`
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the treatment is discrete
@ -1435,7 +1435,7 @@ class LinearIntentToTreatDRIV(StatsModelsCateEstimatorMixin, IntentToTreatDRIV):
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`cv splitter`
- :term:`CV splitter`
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the treatment is discrete


@ -4,188 +4,33 @@
""" Subsampled honest forest extension to scikit-learn's forest methods. Contains pieces of code from
scikit-learn's random forest implementation.
"""
import numpy as np
import scipy.sparse
import threading
import sparse as sp
import itertools
from joblib import effective_n_jobs, Parallel, delayed
from sklearn.utils import check_array, check_X_y, issparse
from sklearn.ensemble.forest import ForestRegressor, _accumulate_prediction
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.base import RegressorMixin
from warnings import catch_warnings, simplefilter, warn
from sklearn.exceptions import DataConversionWarning, NotFittedError
from sklearn.utils import check_random_state, check_array, compute_sample_weight
from sklearn.utils.validation import check_is_fitted
MAX_INT = np.iinfo(np.int32).max
from ..grf import RegressionForest
from ..utilities import deprecated
def _parallel_add_trees(tree, forest, X, y, sample_weight, s_inds, tree_idx, n_trees, verbose=0):
"""Private function used to fit a single subsampled honest tree in parallel."""
if verbose > 1:
print("building tree %d of %d" % (tree_idx + 1, n_trees))
# Construct the subsample of data
X, y = X[s_inds], y[s_inds]
if sample_weight is None:
sample_weight = np.ones((X.shape[0],), dtype=np.float64)
else:
sample_weight = sample_weight[s_inds]
# Split into estimation and splitting sample set
if forest.honest:
X_split, X_est, y_split, y_est,\
sample_weight_split, sample_weight_est = train_test_split(
X, y, sample_weight, test_size=.5, shuffle=True, random_state=tree.random_state)
else:
X_split, X_est, y_split, y_est, sample_weight_split, sample_weight_est =\
X, X, y, y, sample_weight, sample_weight
# Fit the tree on the splitting sample
tree.fit(X_split, y_split, sample_weight=sample_weight_split,
check_input=False)
# Set the estimation values based on the estimation split
total_weight_est = np.sum(sample_weight_est)
# Apply the trained tree on the estimation sample to get the path for every estimation sample
path_est = tree.decision_path(X_est)
# Calculate the total weight of estimation samples on each tree node:
# \sum_i sample_weight[i] * 1{i \\in node}
weight_est = sample_weight_est.reshape(1, -1) @ path_est
# Calculate the total number of estimation samples on each tree node:
# |node| = \sum_{i} 1{i \\in node}
count_est = path_est.sum(axis=0)
# Calculate the weighted sum of responses on the estimation sample on each node:
# \sum_{i} sample_weight[i] 1{i \\in node} Y_i
num_est = (sample_weight_est.reshape(-1, 1) * y_est).T @ path_est
# Calculate the predicted value on each node based on the estimation sample:
# weighted sum of responses / total weight
value_est = num_est / weight_est
# Calculate the criterion on each node based on the estimation sample and for each output dimension,
# summing the impurity across dimensions.
# First we calculate the difference of observed label y of each node and predicted value for each
# node that the sample falls in: y[i] - value_est[node]
impurity_est = np.zeros((1, path_est.shape[1]))
for i in range(tree.n_outputs_):
diff = path_est.multiply(y_est[:, [i]]) - path_est.multiply(value_est[[i], :])
if tree.criterion == 'mse':
# If criterion is mse then calculate weighted sum of squared differences for each node
impurity_est_i = sample_weight_est.reshape(1, -1) @ diff.power(2)
elif tree.criterion == 'mae':
# If criterion is mae then calculate weighted sum of absolute differences for each node
impurity_est_i = sample_weight_est.reshape(1, -1) @ np.abs(diff)
else:
raise AttributeError("Criterion {} not yet supported by SubsampledHonestForest!".format(tree.criterion))
# Normalize each weighted sum of criterion for each node by the total weight of each node
impurity_est += impurity_est_i / weight_est
# Prune tree to remove leafs that don't satisfy the leaf requirements on the estimation sample
# and for each un-pruned tree set the value and the weight appropriately.
children_left = tree.tree_.children_left
children_right = tree.tree_.children_right
stack = [(0, -1)] # seed is the root node id and its parent depth
numerator = np.empty_like(tree.tree_.value)
denominator = np.empty_like(tree.tree_.weighted_n_node_samples)
while len(stack) > 0:
node_id, parent_id = stack.pop()
# If minimum weight requirement or minimum leaf size requirement is not satisfied on estimation
# sample, then prune the whole sub-tree
if weight_est[0, node_id] / total_weight_est < forest.min_weight_fraction_leaf\
or count_est[0, node_id] < forest.min_samples_leaf:
tree.tree_.children_left[parent_id] = -1
tree.tree_.children_right[parent_id] = -1
else:
for i in range(tree.n_outputs_):
# Set the numerator of the node to: \sum_{i} sample_weight[i] 1{i \\in node} Y_i / |node|
numerator[node_id, i] = num_est[i, node_id] / count_est[0, node_id]
# Set the value of the node to:
# \sum_{i} sample_weight[i] 1{i \\in node} Y_i / \sum_{i} sample_weight[i] 1{i \\in node}
tree.tree_.value[node_id, i] = value_est[i, node_id]
# Set the denominator of the node to: \sum_{i} sample_weight[i] 1{i \\in node} / |node|
denominator[node_id] = weight_est[0, node_id] / count_est[0, node_id]
# Set the weight of the node to: \sum_{i} sample_weight[i] 1{i \\in node}
tree.tree_.weighted_n_node_samples[node_id] = weight_est[0, node_id]
# Set the count to the estimation split count
tree.tree_.n_node_samples[node_id] = count_est[0, node_id]
# Set the node impurity to the estimation split impurity
tree.tree_.impurity[node_id] = impurity_est[0, node_id]
if (children_left[node_id] != children_right[node_id]):
stack.append((children_left[node_id], node_id))
stack.append((children_right[node_id], node_id))
return tree, numerator, denominator
class SubsampledHonestForest(ForestRegressor, RegressorMixin):
@deprecated("The SubsampledHonestForest class has been deprecated by the grf.RegressionForest class; "
"an upcoming release will remove support for the this class.")
def SubsampledHonestForest(n_estimators=100,
criterion="mse",
max_depth=None,
min_samples_split=2,
min_samples_leaf=1,
min_weight_fraction_leaf=0.,
max_features="auto",
max_leaf_nodes=None,
min_impurity_decrease=0.,
subsample_fr='auto',
honest=True,
n_jobs=None,
random_state=None,
verbose=0,
warm_start=False):
"""
An implementation of a subsampled honest random forest regressor on top of an sklearn
regression tree. Implements subsampling and honesty as described in [3]_,
but uses a scikit-learn regression tree as a base. It provides confidence intervals based on ideas
described in [3]_ and [4]_
A random forest is a meta estimator that fits a number of classifying
decision trees on various sub-samples of the dataset and uses averaging
to improve the predictive accuracy and control over-fitting.
The sub-sample size is smaller than the original size and subsampling is
performed without replacement. Each decision tree is built in an honest
manner: half of the sub-sampled data are used for creating the tree structure
(referred to as the splitting sample) and the other half for calculating the
constant regression estimate at each leaf of the tree (referred to as the estimation sample).
One difference with the algorithm proposed in [3]_ is that we do not ensure balancedness
and we do not consider poisson sampling of the features, so that we guarantee
that each feature has a positive probability of being selected on each split.
Rather we use the original algorithm of Breiman [1]_, which selects the best split
among a collection of candidate splits, as long as the max_depth is not reached
and as long as there are not more than max_leafs and each child contains
at least min_samples_leaf samples and total weight fraction of
min_weight_fraction_leaf. Moreover, it allows the use of both mean squared error (MSE)
and mean absolute error (MAE) as the splitting criterion. Finally, we allow
for early stopping of the splits if the criterion is not improved by more than
min_impurity_decrease. These techniques, which date back to the work of [1]_,
should lead to finite-sample performance improvements, especially for
high dimensional features.
The implementation also provides confidence intervals
for each prediction using a bootstrap of little bags approach described in [3]_:
subsampling is performed at hierarchical level by first drawing a set of half-samples
at random and then sub-sampling from each half-sample to build a forest
of forests. All the trees are used for the point prediction and the distribution
of predictions returned by each of the sub-forests is used to calculate the standard error
of the point prediction.
In particular we use a variant of the standard error estimation approach proposed in [4]_,
where, if :math:`\\theta(X)` is the point prediction at X, then the variance of :math:`\\theta(X)`
is computed as:
.. math ::
Var(\\theta(X)) = \\frac{\\hat{V}}{\\left(\\frac{1}{B} \\sum_{b \\in [B], i\\in [n]} w_{b, i}(x)\\right)^2}
where B is the number of trees, n the number of training points, and:
.. math ::
w_{b, i}(x) = \\text{sample\\_weight}[i] \\cdot \\frac{1\\{i \\in \\text{leaf}(x; b)\\}}{|\\text{leaf}(x; b)|}
.. math ::
\\hat{V} = \\text{Var}_{\\text{random half-samples } S}\\left[ \\frac{1}{B_S}\
\\sum_{b\\in S, i\\in [n]} w_{b, i}(x) (Y_i - \\theta(X)) \\right]
where :math:`B_S` is the number of trees in half sample S. The latter variance is approximated by:
.. math ::
\\hat{V} = \\frac{1}{|\\text{drawn half samples } S|} \\sum_{S} \\left( \\frac{1}{B_S}\
\\sum_{b\\in S, i\\in [n]} w_{b, i}(x) (Y_i - \\theta(X)) \\right)^2
This variance calculation does not contain the correction due to the finite number of Monte Carlo half-samples
used (as proposed in [4]_), hence it can be a bit conservative when a small number of half-samples is used;
the error, however, is on the conservative side. We use ceil(sqrt(n_estimators)) half-samples, and the forest associated
with each such half-sample contains roughly sqrt(n_estimators) trees, amounting to a total of n_estimators trees
overall.
Parameters
----------
n_estimators : integer, optional (default=100)
@ -196,8 +41,7 @@ class SubsampledHonestForest(ForestRegressor, RegressorMixin):
criterion : string, optional (default="mse")
The function to measure the quality of a split. Supported criteria
are "mse" for the mean squared error, which is equal to variance
reduction as feature selection criterion, and "mae" for the mean
absolute error.
reduction as feature selection criterion.
max_depth : integer or None, optional (default=None)
The maximum depth of the tree. If None, then nodes are expanded until
@ -321,62 +165,9 @@ class SubsampledHonestForest(ForestRegressor, RegressorMixin):
The chosen subsample ratio. Each tree was trained on ``subsample_fr_ * n_samples / 2``
data points.
Examples
--------
.. testcode::
import numpy as np
from econml.sklearn_extensions.ensemble import SubsampledHonestForest
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
np.random.seed(123)
X, y = make_regression(n_samples=1000, n_features=4, n_informative=2,
random_state=0, shuffle=False)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.5)
regr = SubsampledHonestForest(max_depth=None, random_state=0,
n_estimators=1000)
>>> regr.fit(X_train, y_train)
SubsampledHonestForest(n_estimators=1000, random_state=0)
>>> regr.feature_importances_
array([0.64..., 0.33..., 0.01..., 0.01...])
>>> regr.predict(np.ones((1, 4)))
array([112.9...])
>>> regr.predict_interval(np.ones((1, 4)), alpha=.05)
(array([94.9...]), array([130.9...]))
>>> regr.score(X_test, y_test)
0.94...
Notes
-----
The default values for the parameters controlling the size of the trees
(e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and
unpruned trees which can potentially be very large on some data sets. To
reduce memory consumption, the complexity and size of the trees should be
controlled by setting those parameter values. For valid inference, the trees
are recommended to be fully grown.
The features are always randomly permuted at each split. Therefore,
the best found split may vary, even with the same training data,
``max_features=n_features``, if the improvement
of the criterion is identical for several splits enumerated during the
search of the best split. To obtain a deterministic behaviour during
fitting, ``random_state`` has to be fixed.
The default value ``max_features="auto"`` uses ``n_features``
rather than ``n_features / 3``. The latter was originally suggested in
[1]_, whereas the former was more recently justified empirically in [2]_.
References
----------
.. [1] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32, 2001.
.. [2] P. Geurts, D. Ernst., and L. Wehenkel, "Extremely randomized
trees", Machine Learning, 63(1), 3-42, 2006.
.. [3] S. Athey, S. Wager, "Estimation and Inference of Heterogeneous Treatment Effects using Random Forests",
Journal of the American Statistical Association 113.523 (2018): 1228-1242.
@ -384,391 +175,17 @@ class SubsampledHonestForest(ForestRegressor, RegressorMixin):
The Annals of Statistics, 47(2), 1148-1178, 2019.
"""
def __init__(self,
n_estimators=100,
criterion="mse",
max_depth=None,
min_samples_split=2,
min_samples_leaf=1,
min_weight_fraction_leaf=0.,
max_features="auto",
max_leaf_nodes=None,
min_impurity_decrease=0.,
subsample_fr='auto',
honest=True,
n_jobs=None,
random_state=None,
verbose=0,
warm_start=False):
super().__init__(
base_estimator=DecisionTreeRegressor(),
n_estimators=n_estimators,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes",
"min_impurity_decrease",
"random_state"),
bootstrap=False,
oob_score=False,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
warm_start=warm_start)
self.n_estimators = n_estimators
self.criterion = criterion
self.max_depth = max_depth
self.min_samples_split = min_samples_split
self.min_samples_leaf = min_samples_leaf
self.min_weight_fraction_leaf = min_weight_fraction_leaf
self.max_features = max_features
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_decrease = min_impurity_decrease
self.subsample_fr = subsample_fr
self.honest = honest
self.random_state = random_state
self.estimators_ = None
self.vars_ = None
self.subsample_fr_ = None
return
def fit(self, X, y, sample_weight=None, sample_var=None):
"""
Fit the forest.
Parameters
----------
X : ndarray or scipy.sparse matrix, (n_samples, n_features)
Input data.
y : array, shape (n_samples, n_outputs)
Target. Will be cast to X's dtype if necessary
sample_weight : numpy array of shape [n_samples]
Individual weights for each sample. Weights will not be normalized. The weighted square loss
will be minimized by the forest.
sample_var : numpy array of shape [n_samples, n_outputs]
Variance of composite samples (not used here. Exists for API compatibility)
Returns
-------
self
"""
# Validate or convert input data
X = check_array(X, accept_sparse="csc", dtype=np.float32)
y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None)
if sample_weight is not None:
sample_weight = check_array(sample_weight, ensure_2d=False)
if issparse(X):
# Pre-sort indices to avoid that each individual tree of the
# ensemble sorts the indices.
X.sort_indices()
# Remap output
self.n_features_ = X.shape[1]
y = np.atleast_1d(y)
if y.ndim == 2 and y.shape[1] == 1:
warn("A column-vector y was passed when a 1d array was"
" expected. Please change the shape of y to "
"(n_samples,), for example using ravel().",
DataConversionWarning, stacklevel=2)
if y.ndim == 1:
# reshape is necessary to preserve the data contiguity against vs
# [:, np.newaxis] that does not.
y = np.reshape(y, (-1, 1))
self.n_outputs_ = y.shape[1]
y, expanded_class_weight = self._validate_y_class_weight(y)
if getattr(y, "dtype", None) != np.float64 or not y.flags.contiguous:
y = np.ascontiguousarray(y, dtype=np.float64)
if expanded_class_weight is not None:
if sample_weight is not None:
sample_weight = sample_weight * expanded_class_weight
else:
sample_weight = expanded_class_weight
if self.subsample_fr == 'auto':
self.subsample_fr_ = (
X.shape[0] / 2)**(1 - 1 / (2 * X.shape[1] + 2)) / (X.shape[0] / 2)
else:
self.subsample_fr_ = self.subsample_fr
# Check parameters
self._validate_estimator()
random_state = check_random_state(self.random_state)
if not self.warm_start or not hasattr(self, "estimators_"):
# Free allocated memory, if any
self.estimators_ = []
self.numerators_ = []
self.denominators_ = []
n_more_estimators = self.n_estimators - len(self.estimators_)
if n_more_estimators < 0:
raise ValueError('n_estimators=%d must be larger or equal to '
'len(estimators_)=%d when warm_start==True'
% (self.n_estimators, len(self.estimators_)))
elif n_more_estimators == 0:
warn("Warm-start fitting without increasing n_estimators does not "
"fit new trees.")
else:
if self.warm_start and len(self.estimators_) > 0:
# We draw from the random state to get the random state we
# would have got if we hadn't used a warm_start.
random_state.randint(MAX_INT, size=len(self.estimators_))
trees = [self._make_estimator(append=False,
random_state=random_state)
for i in range(n_more_estimators)]
# Parallel loop: we prefer the threading backend as the Cython code
# for fitting the trees is internally releasing the Python GIL
# making threading more efficient than multiprocessing in
# that case. However, for joblib 0.12+ we respect any
# parallel_backend contexts set at a higher level,
# since correctness does not rely on using threads.
self.n_slices = int(np.ceil((self.n_estimators)**(1 / 2)))
self.slice_len = int(np.ceil(self.n_estimators / self.n_slices))
s_inds = []
# TODO. This slicing should ultimately be done inside the parallel function
# so that we don't need to create a matrix of size roughly n_samples * n_estimators
for it in range(self.n_slices):
half_sample_inds = random_state.choice(
X.shape[0], X.shape[0] // 2, replace=False)
for _ in np.arange(it * self.slice_len, min((it + 1) * self.slice_len, self.n_estimators)):
s_inds.append(half_sample_inds[random_state.choice(X.shape[0] // 2,
int(np.ceil(self.subsample_fr_ *
(X.shape[0] // 2))),
replace=False)])
res = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, prefer='threads')(
delayed(_parallel_add_trees)(
t, self, X, y, sample_weight, s_inds[i], i, len(trees),
verbose=self.verbose)
for i, t in enumerate(trees))
trees, numerators, denominators = zip(*res)
# Collect newly grown trees
self.estimators_.extend(trees)
self.numerators_.extend(numerators)
self.denominators_.extend(denominators)
return self
def _mean_fn(self, X, fn, acc, slice=None):
# Helper class that accumulates an arbitrary function in parallel on the accumulator acc
# and calls the function fn on each tree e and returns the mean output. The function fn
# should take as input a tree e and associated numerator n and denominator d structures and
# return another function g_e, which takes as input X, check_input
# If slice is not None, but rather a tuple (start, end), then a subset of the trees from
# index start to index end will be used. The returned result is essentially:
# (mean over e in slice)(g_e(X)).
check_is_fitted(self, 'estimators_')
# Check data
X = self._validate_X_predict(X)
if slice is None:
estimator_slice = zip(self.estimators_, self.numerators_, self.denominators_)
n_estimators = len(self.estimators_)
else:
estimator_slice = zip(self.estimators_[slice[0]:slice[1]], self.numerators_[slice[0]:slice[1]],
self.denominators_[slice[0]:slice[1]])
n_estimators = slice[1] - slice[0]
# Assign chunk of trees to jobs
n_jobs = min(effective_n_jobs(self.n_jobs), n_estimators)
lock = threading.Lock()
Parallel(n_jobs=n_jobs, verbose=self.verbose, require="sharedmem")(
delayed(_accumulate_prediction)(fn(e, n, d), X, [acc], lock)
for e, n, d in estimator_slice)
acc /= n_estimators
return acc
def _weight(self, X, slice=None):
"""
Returns the cumulative sum of training data weights for each of a target set of samples
Parameters
----------
X : (n, d_x) array
The target samples
slice : tuple(int, int) or None
(start, end) tree index of the slice to be used
Returns
-------
W : (n,) array
For each sample x in X, it returns the quantity:
1/B_S \\sum_{b \\in S} \\sum_{i\\in [n]} sample\\_weight[i] * 1{i \\in leaf(x; b)} / |leaf(x; b)|.
where S is the slice of estimators chosen. If slice is None, then all estimators are used else
the slice start:end is used.
"""
# Check data
X = self._validate_X_predict(X)
weight_hat = np.zeros((X.shape[0]), dtype=np.float64)
return self._mean_fn(X, lambda e, n, d: (lambda x, check_input: d[e.apply(x)]),
weight_hat, slice=slice)
def _predict(self, X, slice=None):
"""Construct un-normalized numerator of the prediction for taret X, which when divided by weights
creates the point prediction. Allows for subselecting the set of trees to use.
The predicted regression unnormalized target of an input sample is computed as the
mean predicted regression unnormalized targets of the trees in the forest.
Parameters
----------
X : array-like or sparse matrix of shape = [n_samples, n_features]
The input samples. Internally, its dtype will be converted to
``dtype=np.float32``. If a sparse matrix is provided, it will be
converted into a sparse ``csr_matrix``.
slice : tuple(int, int) or None
(start, end) tree index of the slice to be used
Returns
-------
y : array of shape = [n_samples] or [n_samples, n_outputs]
The predicted values based on the subset of estimators in the slice start:end
(all estimators, if slice=None). This is equivalent to:
1/B_S \\sum_{b\\in S} \\sum_{i\\in [n]} sample_weight[i] * 1{i \\in leaf(x; b)} * Y_i / |leaf(x; b)|
"""
# Check data
X = self._validate_X_predict(X)
# avoid storing the output of every estimator by summing them here
y_hat = np.zeros((X.shape[0], self.n_outputs_), dtype=np.float64)
y_hat = self._mean_fn(X, lambda e, n, d: (lambda x, check_input: n[e.apply(x), :, 0]), y_hat, slice=slice)
if self.n_outputs_ == 1:
y_hat = y_hat.flatten()
return y_hat
def _inference(self, X, stderr=False):
"""
Returns the point prediction for a set of samples X and if stderr=True, also returns stderr of the prediction.
For standard error calculation it implements the bootstrap-of-little-bags approach proposed
in the GRF paper for estimating the variance of the estimate, specialized to this setup.
.. math ::
Var(\\theta(X)) = \\frac{V_hat}{(1/B \\sum_{b \\in [B], i\\in [n]} w_{b, i}(x))^2}
where B is the number of trees, n the number of training points,
.. math ::
w_{b, i}(x) = sample\\_weight[i] \\cdot 1{i \\in leaf(x; b)} / |leaf(x; b)|
.. math ::
V_hat = Var_{random half-samples S}[ 1/B_S \\sum_{b\\in S, i\\in [n]} w_{b, i}(x) (Y_i - \\theta(X)) ]
= E_S[(1/B_S \\sum_{b\\in S, i\\in [n]} w_{b, i}(x) (Y_i - \\theta(X)))^2]
where B_S is the number of trees in half sample S. This variance calculation does not contain the
correction due to the finite number of Monte Carlo half-samples used, hence it can be a bit conservative
when a small number of half-samples is used; the error, however, is on the conservative side.
Parameters
----------
X : (n, d_x) array
The target samples
stderr : bool, optional (default=False)
Whether to return stderr information for each prediction
Returns
-------
Y_pred : (n,) or (n, d_y) array
For each sample x in X, it returns the point prediction
stderr : (n,) or (n, d_y) array
The standard error for each prediction. Returned only if stderr=True.
"""
y_pred = self._predict(X) # get 1/B \sum_{b, i} w_{b, i}(x) Y_i
weight_hat = self._weight(X) # get 1/B \sum_{b, i} w_{b, i}(x)
if len(y_pred.shape) > 1:
weight_hat = weight_hat[:, np.newaxis]
y_point_pred = y_pred / weight_hat # point prediction: \sum_{b, i} w_{b, i} Y_i / \sum_{b, i} w_{b, i}
if stderr:
def slice_inds(it):
return (it * self.slice_len, min((it + 1) * self.slice_len, self.n_estimators))
# Calculate for each slice S: 1/B_S \sum_{b\in S, i\in [n]} w_{b, i}(x) Y_i
y_bags_pred = np.array([self._predict(X, slice=slice_inds(it))
for it in range(self.n_slices)])
# Calculate for each slice S: 1/B_S \sum_{b\in S, i\in [n]} w_{b, i}(x)
weight_hat_bags = np.array([self._weight(X, slice=slice_inds(it))
for it in range(self.n_slices)])
if np.ndim(y_bags_pred) > 2:
weight_hat_bags = weight_hat_bags[:, :, np.newaxis]
# Calculate for each slice S: Q(S) = 1/B_S \sum_{b\in S, i\in [n]} w_{b, i}(x) (Y_i - \theta(X))
# where \theta(X) is the point estimate using the whole forest
bag_res = y_bags_pred - weight_hat_bags * \
np.expand_dims(y_point_pred, axis=0)
# Calculate the variance of the latter as E[Q(S)^2]
std_pred = np.sqrt(np.nanmean(bag_res**2, axis=0)) / weight_hat
return y_point_pred, std_pred
return y_point_pred
def predict(self, X):
"""
Returns point prediction.
Parameters
----------
X : (n, d_x) array
Features
Returns
-------
y_pred : (n,) or (n, d_y) array
Point predictions
"""
return self._inference(X)
def prediction_stderr(self, X):
"""
Returns the standard deviation of the point prediction.
Parameters
----------
X : (n, d_x) array
Features
Returns
-------
pred_stderr : (n,) or (n, d_y) array
The standard error for each point prediction
"""
_, pred_stderr = self._inference(X, stderr=True)
return pred_stderr
def predict_interval(self, X, alpha=.1, normal=True):
"""
Return the confidence interval of the prediction.
Parameters
----------
X : (n, d_x) array
Features
alpha : float
The significance level of the interval
Returns
-------
lb, ub : tuple(shape of :meth:`predict(X)<predict>`, shape of :meth:`predict(X)<predict>`)
The lower and upper bound of an alpha-confidence interval for each prediction
"""
y_point_pred, pred_stderr = self._inference(X, stderr=True)
upper_pred = scipy.stats.norm.ppf(
1 - alpha / 2, loc=y_point_pred, scale=pred_stderr)
lower_pred = scipy.stats.norm.ppf(
alpha / 2, loc=y_point_pred, scale=pred_stderr)
return lower_pred, upper_pred
return RegressionForest(n_estimators=n_estimators,
criterion=criterion,
max_depth=max_depth,
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
min_weight_fraction_leaf=min_weight_fraction_leaf,
max_features=max_features,
min_impurity_decrease=min_impurity_decrease,
max_samples=.45 if subsample_fr == 'auto' else subsample_fr / 2,
honest=honest,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
warm_start=warm_start)
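Given the delegation above, migrating code that used SubsampledHonestForest amounts to constructing a grf.RegressionForest directly; a hedged sketch (the data is illustrative, and the fit/predict_interval signatures are assumed to mirror the deprecated class, as the delegation implies):

import numpy as np
from econml.grf import RegressionForest
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

X, y = make_regression(n_samples=1000, n_features=4, n_informative=2, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)

# Honest, subsampled regression forest from the new Cython grf module
regr = RegressionForest(n_estimators=1000, random_state=0)
regr.fit(X_train, y_train)
point = regr.predict(X_test)
lb, ub = regr.predict_interval(X_test, alpha=0.05)   # bootstrap-of-little-bags intervals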


@ -365,7 +365,7 @@ class WeightedLassoCV(WeightedModelMixin, LassoCV):
Possible inputs for cv are:
- None, to use the default 3-fold weighted cross-validation,
- integer, to specify the number of folds.
- :term:`cv splitter`,
- :term:`CV splitter`,
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, :class:`WeightedKFold` is used.
@ -472,7 +472,7 @@ class WeightedMultiTaskLassoCV(WeightedModelMixin, MultiTaskLassoCV):
Possible inputs for cv are:
- None, to use the default 3-fold weighted cross-validation,
- integer, to specify the number of folds.
- :term:`cv splitter`,
- :term:`CV splitter`,
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, :class:`WeightedKFold` is used.


@ -21,12 +21,13 @@ from sklearn.utils.validation import _num_samples
def _split_weighted_sample(self, X, y, sample_weight, is_stratified=False):
random_state = self.random_state if self.shuffle else None
if is_stratified:
kfold_model = StratifiedKFold(n_splits=self.n_splits, shuffle=self.shuffle,
random_state=self.random_state)
random_state=random_state)
else:
kfold_model = KFold(n_splits=self.n_splits, shuffle=self.shuffle,
random_state=self.random_state)
random_state=random_state)
if sample_weight is None:
return kfold_model.split(X, y)
weights_sum = np.sum(sample_weight)
@ -44,7 +45,10 @@ def _split_weighted_sample(self, X, y, sample_weight, is_stratified=False):
max_deviations.append(max_deviation)
# Reseed random generator and try again
kfold_model.shuffle = True
kfold_model.random_state = None
if isinstance(kfold_model.random_state, numbers.Integral):
kfold_model.random_state = kfold_model.random_state + 1
elif kfold_model.random_state is not None:
kfold_model.random_state = np.random.RandomState(kfold_model.random_state.randint(np.iinfo(np.int32).max))
# If KFold fails after n_trials, we try the next best thing: stratifying by weight groups
warnings.warn("The KFold algorithm failed to find a weight-balanced partition after " +
@ -227,7 +231,7 @@ class GridSearchCVList(BaseEstimator):
def __init__(self, estimator_list, param_grid_list, scoring=None,
n_jobs=None, refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs',
error_score='raise-deprecating', return_train_score=False):
error_score=np.nan, return_train_score=False):
self.estimator_list = estimator_list
self.param_grid_list = param_grid_list
self.scoring = scoring
@ -240,7 +244,7 @@ class GridSearchCVList(BaseEstimator):
self.return_train_score = return_train_score
return
def fit(self, X, y, **fit_params):
def fit(self, X, y=None, **fit_params):
self._gcv_list = [GridSearchCV(estimator, param_grid, scoring=self.scoring,
n_jobs=self.n_jobs, refit=self.refit, cv=self.cv, verbose=self.verbose,
pre_dispatch=self.pre_dispatch, error_score=self.error_score,
@ -347,6 +351,11 @@ def _cross_val_predict(estimator, X, y=None, *, groups=None, cv=None,
X, y, groups = indexable(X, y, groups)
cv = check_cv(cv, y, classifier=is_classifier(estimator))
splits = list(cv.split(X, y, groups))
test_indices = np.concatenate([test for _, test in splits])
if not _check_is_permutation(test_indices, _num_samples(X)):
raise ValueError('cross_val_predict only works for partitions')
# If classification methods produce multiple columns of output,
# we need to manually encode classes to ensure consistent column ordering.
@ -358,7 +367,7 @@ def _cross_val_predict(estimator, X, y=None, *, groups=None, cv=None,
le = LabelEncoder()
y = le.fit_transform(y)
elif y.ndim == 2:
y_enc = np.zeros_like(y, dtype=np.int)
y_enc = np.zeros_like(y, dtype=int)
for i_label in range(y.shape[1]):
y_enc[:, i_label] = LabelEncoder().fit_transform(y[:, i_label])
y = y_enc
@ -367,17 +376,12 @@ def _cross_val_predict(estimator, X, y=None, *, groups=None, cv=None,
# independent, and that it is pickle-able.
parallel = Parallel(n_jobs=n_jobs, verbose=verbose,
pre_dispatch=pre_dispatch)
prediction_blocks = parallel(delayed(_fit_and_predict)(
# TODO. The API of the private scikit-learn `_fit_and_predict` has changed
# between 0.23.2 and 0.24. For this to work with <0.24, we need to add a
# case analysis based on sklearn version.
predictions = parallel(delayed(_fit_and_predict)(
clone(estimator, safe=safe), X, y, train, test, verbose, fit_params, method)
for train, test in cv.split(X, y, groups))
# Concatenate the predictions
predictions = [pred_block_i for pred_block_i, _ in prediction_blocks]
test_indices = np.concatenate([indices_i
for _, indices_i in prediction_blocks])
if not _check_is_permutation(test_indices, _num_samples(X)):
raise ValueError('cross_val_predict only works for partitions')
for train, test in splits)
inv_test_indices = np.empty(len(test_indices), dtype=int)
inv_test_indices[test_indices] = np.arange(len(test_indices))
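The inv_test_indices construction at the end of this hunk is the standard inverse-permutation trick for undoing the reordering introduced by concatenating per-fold predictions; a toy sketch with made-up folds:

import numpy as np

# Toy 2-fold split whose test folds are not in row order.
splits = [(np.array([0, 1, 2]), np.array([3, 4])),   # (train, test) of fold 1
          (np.array([3, 4]), np.array([0, 1, 2]))]   # (train, test) of fold 2
fold_preds = [np.array([13., 14.]), np.array([10., 11., 12.])]  # predictions for rows 3,4 then 0,1,2

test_indices = np.concatenate([test for _, test in splits])     # [3 4 0 1 2]
predictions = np.concatenate(fold_preds)                        # in fold order

# Inverse permutation, exactly as in the hunk above, to restore the original row order.
inv_test_indices = np.empty(len(test_indices), dtype=int)
inv_test_indices[test_indices] = np.arange(len(test_indices))
print(predictions[inv_test_indices])                            # [10. 11. 12. 13. 14.]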

econml/tests/__init__.py: new empty file.

@ -27,7 +27,7 @@ class TestCateInterpreter(unittest.TestCase):
for y_shape in [(n,), (n, 1)]:
X = np.random.normal(size=(n, 4))
T = np.random.binomial(1, 0.5, size=t_shape)
Y = np.random.normal(size=y_shape)
Y = T.flatten() * (X[:, 0] > 0)
est = LinearDML(discrete_treatment=True)
est.fit(Y, T, X=X)
for intrp in [SingleTreeCateInterpreter(), SingleTreePolicyInterpreter()]:


@ -8,8 +8,8 @@ from sklearn.linear_model import LinearRegression, Lasso, LassoCV, LogisticRegre
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, FunctionTransformer, PolynomialFeatures
from sklearn.model_selection import KFold, GroupKFold
from econml.dml import DML, LinearDML, SparseLinearDML, KernelDML
from econml.dml import NonParamDML, ForestDML
from econml.dml import DML, LinearDML, SparseLinearDML, KernelDML, CausalForestDML
from econml.dml import NonParamDML
import numpy as np
from econml.utilities import shape, hstack, vstack, reshape, cross_product
from econml.inference import BootstrapInference
@ -145,11 +145,22 @@ class TestDML(unittest.TestCase):
fit_cate_intercept=fit_cate_intercept,
discrete_treatment=is_discrete),
False,
[None])]:
[None]),
(CausalForestDML(model_y=WeightedLasso(),
model_t=model_t,
featurizer=featurizer,
n_estimators=4,
n_jobs=1,
discrete_treatment=is_discrete),
True,
['auto', 'blb'])]:
if not(multi) and d_y > 1:
continue
if X is None and isinstance(est, CausalForestDML):
continue
# ensure we can serialize the unfit estimator
pickle.dumps(est)
@ -173,7 +184,7 @@ class TestDML(unittest.TestCase):
self.assertEqual(shape(marg_eff), marginal_effect_shape)
self.assertEqual(shape(const_marg_eff), const_marginal_effect_shape)
np.testing.assert_array_equal(
np.testing.assert_allclose(
marg_eff if d_x else marg_eff[0:1], const_marg_eff)
assert isinstance(est.score_, float)
@ -186,7 +197,8 @@ class TestDML(unittest.TestCase):
eff = est.effect(X, T0=T0, T1=T)
self.assertEqual(shape(eff), effect_shape)
if not isinstance(est, KernelDML):
if ((not isinstance(est, KernelDML)) and
(not isinstance(est, CausalForestDML))):
self.assertEqual(shape(est.coef_), coef_shape)
if fit_cate_intercept:
self.assertEqual(shape(est.intercept_), intercept_shape)
@ -203,7 +215,8 @@ class TestDML(unittest.TestCase):
(2,) + const_marginal_effect_shape)
self.assertEqual(shape(est.effect_interval(X, T0=T0, T1=T)),
(2,) + effect_shape)
if not isinstance(est, KernelDML):
if ((not isinstance(est, KernelDML)) and
(not isinstance(est, CausalForestDML))):
self.assertEqual(shape(est.coef__interval()),
(2,) + coef_shape)
if fit_cate_intercept:
@ -278,7 +291,8 @@ class TestDML(unittest.TestCase):
marg_effect_inf.population_summary()._repr_html_()
# test coef__inference and intercept__inference
if not isinstance(est, KernelDML):
if ((not isinstance(est, KernelDML)) and
(not isinstance(est, CausalForestDML))):
if X is not None:
self.assertEqual(
shape(est.coef__inference().summary_frame()),
@ -304,6 +318,10 @@ class TestDML(unittest.TestCase):
est.score(Y, T, X, W)
if isinstance(est, CausalForestDML):
np.testing.assert_array_equal(est.feature_importances_.shape,
((d_y,) if d_y > 0 else()) + fd_x)
# make sure we can call effect with implied scalar treatments,
# no matter the dimensions of T, and also that we warn when there
# are multiple treatments
@ -380,12 +398,7 @@ class TestDML(unittest.TestCase):
featurizer=FunctionTransformer(),
discrete_treatment=is_discrete),
True,
base_infs),
(ForestDML(model_y=WeightedLasso(),
model_t=model_t,
discrete_treatment=is_discrete),
True,
base_infs + ['auto', 'blb'])]:
base_infs), ]:
if not(multi) and d_y > 1:
continue
@@ -412,10 +425,6 @@ class TestDML(unittest.TestCase):
eff = est.effect(X, T0=T0, T1=T)
self.assertEqual(shape(eff), effect_shape)
if isinstance(est, ForestDML):
np.testing.assert_array_equal(est.feature_importances_.shape,
[X.shape[1]])
if inf is not None:
const_marg_eff_int = est.const_marginal_effect_interval(X)
marg_eff_int = est.marginal_effect_interval(T, X)
@@ -601,24 +610,21 @@ class TestDML(unittest.TestCase):
y_sum = np.concatenate((y1_sum, y2_sum)) # outcome
n_sum = np.concatenate((n1_sum, n2_sum)) # number of summarized points
var_sum = np.concatenate((var1_sum, var2_sum)) # variance of the summarized points
for summarized, min_samples_leaf, sample_var in [(False, 20, False), (True, 1, True), (True, 1, False)]:
est = ForestDML(model_y=GradientBoostingRegressor(n_estimators=30, min_samples_leaf=30),
model_t=GradientBoostingClassifier(n_estimators=30, min_samples_leaf=30),
discrete_treatment=True,
n_crossfit_splits=2,
n_estimators=1000,
subsample_fr=.8,
min_samples_leaf=min_samples_leaf,
min_impurity_decrease=0.001,
verbose=0, min_weight_fraction_leaf=.03,
random_state=12345)
for summarized, min_samples_leaf in [(False, 20), (True, 1)]:
est = CausalForestDML(model_y=GradientBoostingRegressor(n_estimators=30, min_samples_leaf=30),
model_t=GradientBoostingClassifier(n_estimators=30, min_samples_leaf=30),
discrete_treatment=True,
n_crossfit_splits=2,
n_estimators=1000,
max_samples=.4,
min_samples_leaf=min_samples_leaf,
min_impurity_decrease=0.001,
verbose=0, min_var_fraction_leaf=.1,
fit_intercept=False,
random_state=12345)
if summarized:
if sample_var:
est.fit(y_sum, T_sum, X=X_sum[:, :4], W=X_sum[:, 4:],
sample_weight=n_sum, sample_var=var_sum)
else:
est.fit(y_sum, T_sum, X=X_sum[:, :4], W=X_sum[:, 4:],
sample_weight=n_sum)
est.fit(y_sum, T_sum, X=X_sum[:, :4], W=X_sum[:, 4:],
sample_weight=n_sum)
else:
est.fit(y, T, X=X[:, :4], W=X[:, 4:])
X_test = np.array(list(itertools.product([0, 1], repeat=4)))
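A related sketch for the pre-aggregated branch above: each row of the hypothetical y_sum/T_sum/X_sum is a summarized cell and n_sum carries the raw counts, passed through sample_weight (constructor values mirror this test; the data itself is illustrative):
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
from econml.dml import CausalForestDML

rng = np.random.RandomState(12345)
X_sum = rng.binomial(1, .5, size=(200, 8)).astype(float)   # 4 heterogeneity features + 4 controls
T_sum = rng.binomial(1, .5, size=(200,))
n_sum = rng.randint(1, 20, size=(200,))                    # raw points represented by each cell
y_sum = T_sum * (1 + .5 * X_sum[:, 0]) + rng.normal(scale=.1, size=(200,))

est = CausalForestDML(model_y=GradientBoostingRegressor(), model_t=GradientBoostingClassifier(),
                      discrete_treatment=True, n_estimators=100, max_samples=.4,
                      min_var_fraction_leaf=.1, fit_intercept=False, random_state=12345)
est.fit(y_sum, T_sum, X=X_sum[:, :4], W=X_sum[:, 4:], sample_weight=n_sum)
lb, ub = est.effect_interval(X_sum[:5, :4], T0=0, T1=1)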
@@ -629,23 +635,20 @@ class TestDML(unittest.TestCase):
np.testing.assert_array_less(lb - .01, truth)
np.testing.assert_array_less(truth, ub + .01)
est = ForestDML(model_y=GradientBoostingRegressor(n_estimators=50, min_samples_leaf=100),
model_t=GradientBoostingRegressor(n_estimators=50, min_samples_leaf=100),
discrete_treatment=False,
n_crossfit_splits=2,
n_estimators=1000,
subsample_fr=.8,
min_samples_leaf=min_samples_leaf,
min_impurity_decrease=0.001,
verbose=0, min_weight_fraction_leaf=.03,
random_state=12345)
est = CausalForestDML(model_y=GradientBoostingRegressor(n_estimators=50, min_samples_leaf=100),
model_t=GradientBoostingRegressor(n_estimators=50, min_samples_leaf=100),
discrete_treatment=False,
n_crossfit_splits=2,
n_estimators=1000,
max_samples=.4,
min_samples_leaf=min_samples_leaf,
min_impurity_decrease=0.001,
verbose=0, min_var_fraction_leaf=.1,
fit_intercept=False,
random_state=12345)
if summarized:
if sample_var:
est.fit(y_sum, T_sum, X=X_sum[:, :4], W=X_sum[:, 4:],
sample_weight=n_sum, sample_var=var_sum)
else:
est.fit(y_sum, T_sum, X=X_sum[:, :4], W=X_sum[:, 4:],
sample_weight=n_sum)
est.fit(y_sum, T_sum, X=X_sum[:, :4], W=X_sum[:, 4:],
sample_weight=n_sum)
else:
est.fit(y, T, X=X[:, :4], W=X[:, 4:])
X_test = np.array(list(itertools.product([0, 1], repeat=4)))
@@ -1049,19 +1052,3 @@ class TestDML(unittest.TestCase):
est = LinearDML(n_splits=GroupKFold(2))
with pytest.raises(Exception):
est.fit(y, t, groups=groups)
def test_deprecation(self):
from econml.dml import LinearDMLCateEstimator
# make sure we warn when using old aliases
with self.assertWarns(FutureWarning):
est = LinearDMLCateEstimator()
# make sure we can use the old alias as a type
self.assertIsInstance(est, LinearDMLCateEstimator)
# make sure that we can still pickle the old aliases
import pickle
d = pickle.dumps(LinearDMLCateEstimator())
e = pickle.loads(d)


@@ -514,178 +514,181 @@ class TestDRLearner(unittest.TestCase):
for W in [None, controls]:
for sample_weight, sample_var in [(None, None), (np.ones(T.shape[0]), np.zeros(T.shape[0]))]:
for featurizer in [None, PolynomialFeatures(degree=2, include_bias=False)]:
for models in [(GradientBoostingClassifier(), GradientBoostingRegressor()),
(LogisticRegression(solver='lbfgs', multi_class='auto'),
LinearRegression())]:
for est_class,\
inference in [(ForestDRLearner, 'auto'),
(LinearDRLearner, 'auto'),
(LinearDRLearner, StatsModelsInferenceDiscrete(
for model_t, model_y, est_class,\
inference in [(GradientBoostingClassifier(), GradientBoostingRegressor(),
ForestDRLearner, 'auto'),
(LogisticRegression(solver='lbfgs', multi_class='auto'),
LinearRegression(), LinearDRLearner, 'auto'),
(LogisticRegression(solver='lbfgs', multi_class='auto'),
LinearRegression(), LinearDRLearner, StatsModelsInferenceDiscrete(
cov_type='nonrobust')),
(SparseLinearDRLearner, 'auto')]:
with self.subTest(X=X, W=W, sample_weight=sample_weight, sample_var=sample_var,
featurizer=featurizer, models=models,
est_class=est_class, inference=inference):
if (X is None) and (est_class == SparseLinearDRLearner):
continue
if (X is None) and (est_class == ForestDRLearner):
continue
if (featurizer is not None) and (est_class == ForestDRLearner):
continue
(LogisticRegression(solver='lbfgs', multi_class='auto'),
LinearRegression(), SparseLinearDRLearner, 'auto')
]:
with self.subTest(X=X, W=W, sample_weight=sample_weight, sample_var=sample_var,
featurizer=featurizer, model_y=model_y, model_t=model_t,
est_class=est_class, inference=inference):
if (X is None) and (est_class == SparseLinearDRLearner):
continue
if (X is None) and (est_class == ForestDRLearner):
continue
if (featurizer is not None) and (est_class == ForestDRLearner):
continue
if est_class == ForestDRLearner:
est = est_class(model_propensity=models[0],
model_regression=models[1])
else:
est = est_class(model_propensity=models[0],
model_regression=models[1],
featurizer=featurizer)
if est_class == ForestDRLearner:
est = est_class(model_propensity=model_t,
model_regression=model_y,
n_estimators=1000)
else:
est = est_class(model_propensity=model_t,
model_regression=model_y,
featurizer=featurizer)
if (X is None) and (W is None):
with pytest.raises(AttributeError) as e_info:
est.fit(
y, T, X=X, W=W, sample_weight=sample_weight, sample_var=sample_var)
continue
est.fit(y, T, X=X, W=W, sample_weight=sample_weight,
sample_var=sample_var, inference=inference)
if (X is None) and (W is None):
with pytest.raises(AttributeError) as e_info:
est.fit(
y, T, X=X, W=W, sample_weight=sample_weight, sample_var=sample_var)
continue
est.fit(y, T, X=X, W=W, sample_weight=sample_weight,
sample_var=sample_var, inference=inference)
if X is not None:
lower, upper = est.effect_interval(
X[:3], T0=0, T1=1)
point = est.effect(X[:3], T0=0, T1=1)
truth = 1 + .5 * X[:3, 0]
TestDRLearner._check_with_interval(
truth, point, lower, upper)
lower, upper = est.const_marginal_effect_interval(
X[:3])
point = est.const_marginal_effect(
X[:3])
truth = np.hstack(
[1 + .5 * X[:3, [0]], 2 * (1 + .5 * X[:3, [0]])])
TestDRLearner._check_with_interval(
truth, point, lower, upper)
else:
lower, upper = est.effect_interval(
T0=0, T1=1)
point = est.effect(T0=0, T1=1)
truth = np.array([1])
TestDRLearner._check_with_interval(
truth, point, lower, upper)
lower, upper = est.const_marginal_effect_interval()
point = est.const_marginal_effect()
truth = np.array([[1, 2]])
TestDRLearner._check_with_interval(
truth, point, lower, upper)
for t in [1, 2]:
if X is not None:
lower, upper = est.effect_interval(
X[:3], T0=0, T1=1)
point = est.effect(X[:3], T0=0, T1=1)
truth = 1 + .5 * X[:3, 0]
TestDRLearner._check_with_interval(
truth, point, lower, upper)
lower, upper = est.const_marginal_effect_interval(
X[:3])
point = est.const_marginal_effect(
X[:3])
lower, upper = est.marginal_effect_interval(
t, X[:3])
point = est.marginal_effect(
t, X[:3])
truth = np.hstack(
[1 + .5 * X[:3, [0]], 2 * (1 + .5 * X[:3, [0]])])
TestDRLearner._check_with_interval(
truth, point, lower, upper)
else:
lower, upper = est.effect_interval(
T0=0, T1=1)
point = est.effect(T0=0, T1=1)
truth = np.array([1])
TestDRLearner._check_with_interval(
truth, point, lower, upper)
lower, upper = est.const_marginal_effect_interval()
point = est.const_marginal_effect()
lower, upper = est.marginal_effect_interval(
t)
point = est.marginal_effect(t)
truth = np.array([[1, 2]])
TestDRLearner._check_with_interval(
truth, point, lower, upper)
assert isinstance(est.score_, float)
assert isinstance(
est.score(y, T, X=X, W=W), float)
for t in [1, 2]:
if X is not None:
lower, upper = est.marginal_effect_interval(
t, X[:3])
point = est.marginal_effect(
t, X[:3])
truth = np.hstack(
[1 + .5 * X[:3, [0]], 2 * (1 + .5 * X[:3, [0]])])
TestDRLearner._check_with_interval(
truth, point, lower, upper)
else:
lower, upper = est.marginal_effect_interval(
t)
point = est.marginal_effect(t)
truth = np.array([[1, 2]])
TestDRLearner._check_with_interval(
truth, point, lower, upper)
assert isinstance(est.score_, float)
assert isinstance(
est.score(y, T, X=X, W=W), float)
if X is not None:
feature_names = ['A', 'B']
else:
feature_names = []
out_feat_names = feature_names
if X is not None:
if (featurizer is not None):
out_feat_names = featurizer.fit(
X).get_feature_names(feature_names)
np.testing.assert_array_equal(
est.featurizer.n_input_features_, 2)
np.testing.assert_array_equal(est.cate_feature_names(feature_names),
out_feat_names)
if isinstance(model_t, GradientBoostingClassifier):
np.testing.assert_array_equal(np.array([mdl.feature_importances_
for mdl
in est.models_regression]).shape,
[2, 2 + len(feature_names) +
(W.shape[1] if W is not None else 0)])
np.testing.assert_array_equal(np.array([mdl.feature_importances_
for mdl
in est.models_propensity]).shape,
[2, len(feature_names) +
(W.shape[1] if W is not None else 0)])
else:
np.testing.assert_array_equal(np.array([mdl.coef_
for mdl
in est.models_regression]).shape,
[2, 2 + len(feature_names) +
(W.shape[1] if W is not None else 0)])
np.testing.assert_array_equal(np.array([mdl.coef_
for mdl
in est.models_propensity]).shape,
[2, 3, len(feature_names) +
(W.shape[1] if W is not None else 0)])
if isinstance(est, LinearDRLearner) or isinstance(est, SparseLinearDRLearner):
if X is not None:
feature_names = ['A', 'B']
else:
feature_names = []
out_feat_names = feature_names
if X is not None:
if (featurizer is not None):
out_feat_names = featurizer.fit(
X).get_feature_names(feature_names)
np.testing.assert_array_equal(
est.featurizer.n_input_features_, 2)
np.testing.assert_array_equal(est.cate_feature_names(feature_names),
out_feat_names)
if isinstance(models[0], GradientBoostingClassifier):
np.testing.assert_array_equal(np.array([mdl.feature_importances_
for mdl
in est.models_regression]).shape,
[2, 2 + len(feature_names) +
(W.shape[1] if W is not None else 0)])
np.testing.assert_array_equal(np.array([mdl.feature_importances_
for mdl
in est.models_propensity]).shape,
[2, len(feature_names) +
(W.shape[1] if W is not None else 0)])
else:
np.testing.assert_array_equal(np.array([mdl.coef_
for mdl
in est.models_regression]).shape,
[2, 2 + len(feature_names) +
(W.shape[1] if W is not None else 0)])
np.testing.assert_array_equal(np.array([mdl.coef_
for mdl
in est.models_propensity]).shape,
[2, 3, len(feature_names) +
(W.shape[1] if W is not None else 0)])
if isinstance(est, LinearDRLearner) or isinstance(est, SparseLinearDRLearner):
if X is not None:
for t in [1, 2]:
true_coef = np.zeros(
len(out_feat_names))
true_coef[0] = .5 * t
lower, upper = est.model_cate(
T=t).coef__interval()
point = est.model_cate(
T=t).coef_
truth = true_coef
TestDRLearner._check_with_interval(
truth, point, lower, upper)
lower, upper = est.coef__interval(
t)
point = est.coef_(t)
truth = true_coef
TestDRLearner._check_with_interval(
truth, point, lower, upper)
# test coef__inference function works
est.coef__inference(
t).summary_frame()
np.testing.assert_array_almost_equal(
est.coef__inference(t).conf_int()[0], lower, decimal=5)
for t in [1, 2]:
true_coef = np.zeros(
len(out_feat_names))
true_coef[0] = .5 * t
lower, upper = est.model_cate(
T=t).intercept__interval()
T=t).coef__interval()
point = est.model_cate(
T=t).intercept_
truth = t
T=t).coef_
truth = true_coef
TestDRLearner._check_with_interval(
truth, point, lower, upper)
lower, upper = est.intercept__interval(
lower, upper = est.coef__interval(
t)
point = est.intercept_(t)
truth = t
point = est.coef_(t)
truth = true_coef
TestDRLearner._check_with_interval(
truth, point, lower, upper)
# test intercept__inference function works
est.intercept__inference(
# test coef__inference function works
est.coef__inference(
t).summary_frame()
np.testing.assert_array_almost_equal(
est.intercept__inference(t).conf_int()[0], lower, decimal=5)
# test summary function works
est.summary(t)
est.coef__inference(t).conf_int()[0], lower, decimal=5)
for t in [1, 2]:
lower, upper = est.model_cate(
T=t).intercept__interval()
point = est.model_cate(
T=t).intercept_
truth = t
TestDRLearner._check_with_interval(
truth, point, lower, upper)
if isinstance(est, ForestDRLearner):
for t in [1, 2]:
np.testing.assert_array_equal(est.feature_importances_(t).shape,
[X.shape[1]])
lower, upper = est.intercept__interval(
t)
point = est.intercept_(t)
truth = t
TestDRLearner._check_with_interval(
truth, point, lower, upper)
# test intercept__inference function works
est.intercept__inference(
t).summary_frame()
np.testing.assert_array_almost_equal(
est.intercept__inference(t).conf_int()[0], lower, decimal=5)
# test summary function works
est.summary(t)
if isinstance(est, ForestDRLearner):
for t in [1, 2]:
np.testing.assert_array_equal(est.feature_importances_(t).shape,
[X.shape[1]])
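For orientation, a minimal sketch of the ForestDRLearner path exercised above (the econml.drlearner import path and the toy data are assumptions; the calls mirror this test):
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
from econml.drlearner import ForestDRLearner

rng = np.random.RandomState(123)
X = rng.normal(size=(1000, 2))
W = rng.normal(size=(1000, 3))
T = rng.randint(0, 3, size=(1000,))                 # treatment levels 0, 1, 2 as in the test
y = (1 + .5 * X[:, 0]) * T + rng.normal(scale=.1, size=(1000,))

est = ForestDRLearner(model_propensity=GradientBoostingClassifier(),
                      model_regression=GradientBoostingRegressor(),
                      n_estimators=100)
est.fit(y, T, X=X, W=W, inference='auto')
point = est.effect(X[:3], T0=0, T1=1)
lb, ub = est.effect_interval(X[:3], T0=0, T1=1)
imps = est.feature_importances_(1)                  # one importance vector per treatment level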
@staticmethod
def _check_with_interval(truth, point, lower, upper):


@@ -1,122 +0,0 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Tests for linear_model extensions."""
import numpy as np
import pytest
import unittest
import warnings
from econml.sklearn_extensions.ensemble import SubsampledHonestForest
class TestSubsampledHonestForest(unittest.TestCase):
"""Test SubsampledHonestForest."""
def test_y1d(self):
np.random.seed(123)
n = 5000
d = 5
x_grid = np.linspace(-1, 1, 10)
X_test = np.hstack([x_grid.reshape(-1, 1), np.random.normal(size=(10, d - 1))])
for _ in range(3):
for criterion in ['mse', 'mae']:
X = np.random.normal(0, 1, size=(n, d))
y = X[:, 0] + np.random.normal(0, .1, size=(n,))
est = SubsampledHonestForest(n_estimators=100, max_depth=5, criterion=criterion,
min_samples_leaf=10, verbose=0)
est.fit(X, y)
point = est.predict(X_test)
lb, ub = est.predict_interval(X_test, alpha=0.01)
np.testing.assert_allclose(point, X_test[:, 0], rtol=0, atol=.2)
np.testing.assert_array_less(lb, X_test[:, 0] + .05)
np.testing.assert_array_less(X_test[:, 0], ub + .05)
def test_nonauto_subsample_fr(self):
np.random.seed(123)
n = 5000
d = 5
x_grid = np.linspace(-1, 1, 10)
X_test = np.hstack([x_grid.reshape(-1, 1), np.random.normal(size=(10, d - 1))])
X = np.random.normal(0, 1, size=(n, d))
y = X[:, 0] + np.random.normal(0, .1, size=(n,))
est = SubsampledHonestForest(n_estimators=100, subsample_fr=.8, max_depth=5, min_samples_leaf=10, verbose=0)
est.fit(X, y)
point = est.predict(X_test)
lb, ub = est.predict_interval(X_test, alpha=0.01)
np.testing.assert_allclose(point, X_test[:, 0], rtol=0, atol=.2)
np.testing.assert_array_less(lb, X_test[:, 0] + .05)
np.testing.assert_array_less(X_test[:, 0], ub + .05)
def test_y2d(self):
np.random.seed(123)
n = 5000
d = 5
x_grid = np.linspace(-1, 1, 10)
X_test = np.hstack([x_grid.reshape(-1, 1), np.random.normal(size=(10, d - 1))])
for _ in range(3):
for criterion in ['mse', 'mae']:
X = np.random.normal(0, 1, size=(n, d))
y = X[:, [0, 0]] + np.random.normal(0, .1, size=(n, 2))
est = SubsampledHonestForest(n_estimators=100, max_depth=5, criterion=criterion,
min_samples_leaf=10, verbose=0)
est.fit(X, y)
point = est.predict(X_test)
lb, ub = est.predict_interval(X_test, alpha=0.01)
np.testing.assert_allclose(point, X_test[:, [0, 0]], rtol=0, atol=.2)
np.testing.assert_array_less(lb, X_test[:, [0, 0]] + .05)
np.testing.assert_array_less(X_test[:, [0, 0]], ub + .05)
def test_dishonest_y1d(self):
np.random.seed(123)
n = 5000
d = 1
x_grid = np.linspace(-1, 1, 10)
X_test = np.hstack([x_grid.reshape(-1, 1), np.random.normal(size=(10, d - 1))])
for _ in range(3):
X = np.random.normal(0, 1, size=(n, d))
y = 1. * (X[:, 0] > 0) + np.random.normal(0, .1, size=(n,))
est = SubsampledHonestForest(n_estimators=100, honest=False, max_depth=3,
min_samples_leaf=10, verbose=0)
est.fit(X, y)
point = est.predict(X_test)
lb, ub = est.predict_interval(X_test, alpha=0.01)
np.testing.assert_allclose(point, 1 * (X_test[:, 0] > 0), rtol=0, atol=.2)
np.testing.assert_array_less(lb, 1 * (X_test[:, 0] > 0) + .05)
np.testing.assert_array_less(1 * (X_test[:, 0] > 0), ub + .05)
def test_dishonest_y2d(self):
np.random.seed(123)
n = 5000
d = 1
x_grid = np.linspace(-1, 1, 10)
X_test = np.hstack([x_grid.reshape(-1, 1), np.random.normal(size=(10, d - 1))])
for _ in range(3):
X = np.random.normal(0, 1, size=(n, d))
y = 1. * (X[:, [0, 0]] > 0) + np.random.normal(0, .1, size=(n, 2))
est = SubsampledHonestForest(n_estimators=100, honest=False, max_depth=3,
min_samples_leaf=10, verbose=0)
est.fit(X, y)
point = est.predict(X_test)
lb, ub = est.predict_interval(X_test, alpha=0.01)
np.testing.assert_allclose(point, 1. * (X_test[:, [0, 0]] > 0), rtol=0, atol=.2)
np.testing.assert_array_less(lb, 1. * (X_test[:, [0, 0]] > 0) + .05)
np.testing.assert_array_less(1. * (X_test[:, [0, 0]] > 0), ub + .05)
def test_random_state(self):
np.random.seed(123)
n = 5000
d = 5
x_grid = np.linspace(-1, 1, 10)
X_test = np.hstack([x_grid.reshape(-1, 1), np.random.normal(size=(10, d - 1))])
X = np.random.normal(0, 1, size=(n, d))
y = X[:, 0] + np.random.normal(0, .1, size=(n,))
est = SubsampledHonestForest(n_estimators=100, max_depth=5, min_samples_leaf=10, verbose=0, random_state=12345)
est.fit(X, y)
point1 = est.predict(X_test)
est = SubsampledHonestForest(n_estimators=100, max_depth=5,
min_samples_leaf=10, verbose=0, random_state=12345)
est.fit(X, y)
point2 = est.predict(X_test)
# Check that the point estimates are the same
np.testing.assert_equal(point1, point2)
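For comparison, the RegressionForest in econml.grf exercised later in this diff exposes a similar fit/predict/predict_interval API; a minimal sketch (parameter values are illustrative and the inference-related defaults are assumptions):
import numpy as np
from econml.grf import RegressionForest

rng = np.random.RandomState(123)
X = rng.normal(size=(5000, 5))
y = X[:, 0] + rng.normal(0, .1, size=(5000,))

forest = RegressionForest(n_estimators=100, min_samples_leaf=10,
                          inference=True, random_state=123).fit(X, y)
point = forest.predict(X[:10])
lb, ub = forest.predict_interval(X[:10], alpha=0.01)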


@@ -0,0 +1,298 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import unittest
import logging
import time
import random
import numpy as np
import sparse as sp
import pytest
from econml.tree import DepthFirstTreeBuilder, BestSplitter, Tree, MSE
from econml.grf import LinearMomentGRFCriterion, LinearMomentGRFCriterionMSE
from econml.grf._utils import matinv, lstsq, pinv, fast_max_eigv, fast_min_eigv
from econml.utilities import cross_product
class TestGRFCython(unittest.TestCase):
def _get_base_config(self, n_features=2, n_t=2, n_samples_train=1000):
n_y = 1
return {'criterion': 'het',
'n_features': n_features,
'n_y': n_y,
'n_outputs': n_t + 1,
'n_relevant_outputs': n_t,
'store_jac': True,
'n_samples': n_samples_train,
'n_samples_train': n_samples_train,
'max_features': n_features,
'min_samples_split': 2,
'min_samples_leaf': 10,
'min_weight_leaf': 1,
'min_eig_leaf': -1,
'min_eig_leaf_on_val': False,
'min_balancedness_tol': .3,
'max_depth': 2,
'min_impurity_decrease': 0.0,
'honest': False,
'random_state': 1234,
'max_node_samples': n_samples_train,
'samples_train': np.arange(n_samples_train, dtype=np.intp),
'samples_val': np.arange(n_samples_train, dtype=np.intp)
}
def _get_base_honest_config(self, n_features=2, n_t=2, n_samples_train=1000):
n_y = 1
return {'criterion': 'het',
'n_features': n_features,
'n_y': n_y,
'n_outputs': n_t + 1,
'n_relevant_outputs': n_t,
'store_jac': True,
'n_samples': 2 * n_samples_train,
'n_samples_train': n_samples_train,
'max_features': n_features,
'min_samples_split': 2,
'min_samples_leaf': 10,
'min_weight_leaf': 1,
'min_eig_leaf': -1,
'min_eig_leaf_on_val': False,
'min_balancedness_tol': .3,
'max_depth': 2,
'min_impurity_decrease': 0.0,
'honest': True,
'random_state': 1234,
'max_node_samples': n_samples_train,
'samples_train': np.arange(n_samples_train, dtype=np.intp),
'samples_val': np.arange(n_samples_train, 2 * n_samples_train, dtype=np.intp)
}
def _get_cython_objects(self, *, criterion, n_features, n_y, n_outputs, n_relevant_outputs,
store_jac, n_samples, n_samples_train, max_features,
min_samples_split, min_samples_leaf, min_weight_leaf,
min_eig_leaf, min_eig_leaf_on_val, min_balancedness_tol, max_depth, min_impurity_decrease,
honest, random_state, max_node_samples, samples_train,
samples_val):
tree = Tree(n_features, n_outputs, n_relevant_outputs, store_jac)
if criterion == 'het':
criterion = LinearMomentGRFCriterion(n_outputs, n_relevant_outputs, n_features, n_y,
n_samples, max_node_samples, random_state)
criterion_val = LinearMomentGRFCriterion(n_outputs, n_relevant_outputs, n_features, n_y,
n_samples, max_node_samples, random_state)
else:
criterion = LinearMomentGRFCriterionMSE(n_outputs, n_relevant_outputs, n_features, n_y,
n_samples, max_node_samples, random_state)
criterion_val = LinearMomentGRFCriterionMSE(n_outputs, n_relevant_outputs, n_features, n_y,
n_samples, max_node_samples, random_state)
splitter = BestSplitter(criterion, criterion_val,
max_features, min_samples_leaf, min_weight_leaf,
min_balancedness_tol, honest, min_eig_leaf, min_eig_leaf_on_val, random_state)
builder = DepthFirstTreeBuilder(splitter, min_samples_split,
min_samples_leaf, min_weight_leaf,
max_depth, min_impurity_decrease)
return tree, criterion, criterion_val, splitter, builder
def _get_continuous_data(self, config):
random_state = np.random.RandomState(config['random_state'])
X = random_state.normal(size=(config['n_samples_train'], config['n_features']))
T = np.zeros((config['n_samples_train'], config['n_relevant_outputs']))
for t in range(T.shape[1]):
T[:, t] = random_state.binomial(1, .5, size=(T.shape[0],))
Taug = np.hstack([T, np.ones((T.shape[0], 1))])
y = ((X[:, [0]] > 0.0) + .5) * np.sum(T, axis=1, keepdims=True) + .5
yaug = np.hstack([y, y * Taug, cross_product(Taug, Taug)])
X = np.vstack([X, X])
yaug = np.vstack([yaug, yaug])
return X, yaug, np.hstack([(X[:, [0]] > 0.0) + .5, (X[:, [0]] > 0.0) + .5])
def _train_tree(self, config, X, y):
tree, criterion, criterion_val, splitter, builder = self._get_cython_objects(**config)
builder.build(tree, X, y,
config['samples_train'],
config['samples_val'],
store_jac=config['store_jac'])
return tree
def _get_true_quantities(self, config, X, y, mask, criterion):
alpha = y[mask, config['n_y']:config['n_y'] + config['n_outputs']]
pointJ = y[mask, config['n_y'] + config['n_outputs']:
config['n_y'] + (config['n_outputs'] + 1) * config['n_outputs']]
jac = np.mean(pointJ, axis=0)
precond = np.mean(alpha, axis=0)
invJ = np.linalg.inv(jac.reshape((alpha.shape[1], alpha.shape[1])))
param = invJ @ precond
moment = alpha - pointJ.reshape((-1, alpha.shape[1], alpha.shape[1])) @ param
rho = ((invJ @ moment.T).T)[:, :config['n_relevant_outputs']]
if criterion == 'het':
impurity = np.mean(rho**2) - np.mean(np.mean(rho, axis=0)**2)
else:
impurity = np.mean(y[mask, :config['n_y']]**2)
impurity -= (param.reshape(1, -1) @ jac.reshape((alpha.shape[1], alpha.shape[1])) @ param)[0]
return jac, precond, param, impurity
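As a reading aid, the quantities computed above follow the linear moment structure packed into the augmented y; a plain restatement of this function's logic (comments only, no new behavior):
# the node parameter solves the sample moment  mean_i(alpha_i) - mean_i(J_i) @ theta = 0,
# i.e.  theta = inv(mean(J)) @ mean(alpha)   (the 'param' returned above);
# the 'het' criterion scores a node by the coordinate-averaged variance of the pseudo-outcomes
#     rho_i = inv(mean(J)) @ (alpha_i - J_i @ theta),
# restricted to the first n_relevant_outputs coordinates, while 'mse' uses the residual
# second moment  mean(y**2) - theta @ mean(J) @ theta.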
def _get_node_quantities(self, tree, node_id):
return (tree.jac[node_id, :], tree.precond[node_id, :],
tree.full_value[node_id, :, 0], tree.impurity[node_id])
def _test_tree_quantities(self, base_config_gen, criterion):
config = base_config_gen()
config['criterion'] = criterion
config['max_depth'] = 1
X, y, truth = self._get_continuous_data(config)
tree = self._train_tree(config, X, y)
np.testing.assert_array_equal(X[:config['n_samples_train']], X[config['n_samples_train']:])
np.testing.assert_array_equal(y[:config['n_samples_train']], y[config['n_samples_train']:])
np.testing.assert_array_equal(config['samples_train'], np.arange(config['n_samples_train']))
if config['honest']:
np.testing.assert_array_equal(config['samples_val'],
np.arange(config['n_samples_train'], 2 * config['n_samples_train']))
np.testing.assert_array_equal(tree.feature, np.array([0, -2, -2]))
np.testing.assert_allclose(tree.threshold, np.array([0, -2, -2]), atol=.1, rtol=0)
[np.testing.assert_allclose(a, b, atol=1e-4)
for a, b in zip(self._get_true_quantities(config, X, y, np.ones(X.shape[0]) > 0, criterion),
self._get_node_quantities(tree, 0))]
[np.testing.assert_allclose(a, b, atol=1e-4)
for a, b in zip(self._get_true_quantities(config, X, y,
X[:, tree.feature[0]] < tree.threshold[0], criterion),
self._get_node_quantities(tree, 1))]
[np.testing.assert_allclose(a, b, atol=1e-4)
for a, b in zip(self._get_true_quantities(config, X, y,
X[:, tree.feature[0]] >= tree.threshold[0], criterion),
self._get_node_quantities(tree, 2))]
mask = np.abs(X[:, 0]) > .05
np.testing.assert_allclose(tree.predict(X[mask]), truth[mask], atol=.05)
config = base_config_gen()
config['criterion'] = criterion
config['max_depth'] = 2
X, y, truth = self._get_continuous_data(config)
tree = self._train_tree(config, X, y)
[np.testing.assert_allclose(a, b, atol=1e-4)
for a, b in zip(self._get_true_quantities(config, X, y, np.ones(X.shape[0]) > 0, criterion),
self._get_node_quantities(tree, 0))]
mask0 = X[:, tree.feature[0]] < tree.threshold[0]
[np.testing.assert_allclose(a, b, atol=1e-4)
for a, b in zip(self._get_true_quantities(config, X, y, mask0, criterion),
self._get_node_quantities(tree, 1))]
[np.testing.assert_allclose(a, b, atol=1e-4)
for a, b in zip(self._get_true_quantities(config, X, y, ~mask0, criterion),
self._get_node_quantities(tree, 4))]
mask1a = mask0 & (X[:, tree.feature[1]] < tree.threshold[1])
[np.testing.assert_allclose(a, b, atol=1e-4)
for a, b in zip(self._get_true_quantities(config, X, y, mask1a, criterion),
self._get_node_quantities(tree, 2))]
mask1b = mask0 & (X[:, tree.feature[1]] >= tree.threshold[1])
[np.testing.assert_allclose(a, b, atol=1e-4)
for a, b in zip(self._get_true_quantities(config, X, y, mask1b, criterion),
self._get_node_quantities(tree, 3))]
mask1c = (~mask0) & (X[:, tree.feature[4]] < tree.threshold[4])
[np.testing.assert_allclose(a, b, atol=1e-4)
for a, b in zip(self._get_true_quantities(config, X, y, mask1c, criterion),
self._get_node_quantities(tree, 5))]
mask1d = (~mask0) & (X[:, tree.feature[4]] >= tree.threshold[4])
[np.testing.assert_allclose(a, b, atol=1e-4)
for a, b in zip(self._get_true_quantities(config, X, y, mask1d, criterion),
self._get_node_quantities(tree, 6))]
mask = np.abs(X[:, 0]) > .05
np.testing.assert_allclose(tree.predict(X[mask]), truth[mask], atol=.05)
def test_dishonest_tree(self):
self._test_tree_quantities(self._get_base_config, criterion='het')
self._test_tree_quantities(self._get_base_config, criterion='mse')
def test_honest_tree(self):
self._test_tree_quantities(self._get_base_honest_config, criterion='het')
self._test_tree_quantities(self._get_base_honest_config, criterion='mse')
def test_honest_dishonest_equivalency(self):
for criterion in ['het', 'mse']:
config = self._get_base_config()
config['criterion'] = criterion
config['max_depth'] = 4
X, y, _ = self._get_continuous_data(config)
tree = self._train_tree(config, X, y)
config = self._get_base_honest_config()
config['criterion'] = criterion
config['max_depth'] = 4
X, y, _ = self._get_continuous_data(config)
honest_tree = self._train_tree(config, X, y)
np.testing.assert_equal(tree.feature, honest_tree.feature)
np.testing.assert_equal(tree.threshold, honest_tree.threshold)
np.testing.assert_equal(tree.value, honest_tree.value)
np.testing.assert_equal(tree.full_value, honest_tree.full_value)
np.testing.assert_equal(tree.impurity, honest_tree.impurity)
np.testing.assert_equal(tree.impurity, honest_tree.impurity_train)
np.testing.assert_equal(tree.n_node_samples, honest_tree.n_node_samples)
np.testing.assert_equal(tree.weighted_n_node_samples, honest_tree.weighted_n_node_samples_train)
np.testing.assert_equal(tree.n_node_samples, honest_tree.n_node_samples_train)
np.testing.assert_equal(tree.jac, honest_tree.jac)
np.testing.assert_equal(tree.precond, honest_tree.precond)
np.testing.assert_equal(tree.predict(X), honest_tree.predict(X))
np.testing.assert_equal(tree.predict_full(X), honest_tree.predict_full(X))
np.testing.assert_equal(tree.compute_feature_importances(), honest_tree.compute_feature_importances())
np.testing.assert_equal(tree.compute_feature_heterogeneity_importances(),
honest_tree.compute_feature_heterogeneity_importances())
def test_min_var_leaf(self):
n_samples_train = 10
for criterion in ['het', 'mse']:
config = self._get_base_config(n_samples_train=n_samples_train, n_t=1, n_features=1)
config['max_depth'] = 1
config['min_samples_leaf'] = 1
config['min_eig_leaf'] = .2
config['criterion'] = criterion
X = np.arange(n_samples_train).reshape(-1, 1)
T = np.random.binomial(1, .5, size=(n_samples_train, 1))
T[X[:, 0] < n_samples_train // 2] = 0
T[X[:, 0] >= n_samples_train // 2] = 1
Taug = np.hstack([T, np.ones((T.shape[0], 1))])
y = np.zeros((n_samples_train, 1))
yaug = np.hstack([y, y * Taug, cross_product(Taug, Taug)])
tree = self._train_tree(config, X, yaug)
if criterion == 'het':
np.testing.assert_array_less(config['min_eig_leaf'], np.mean(T[X[:, 0] > tree.threshold[0]]**2))
np.testing.assert_array_less(config['min_eig_leaf'], np.mean(T[X[:, 0] <= tree.threshold[0]]**2))
else:
np.testing.assert_array_equal(tree.feature, np.array([-2]))
def test_fast_eigv(self):
n = 4
np.random.seed(123)
for _ in range(10):
A = np.random.normal(0, 1, size=(n, n))
A = np.asfortranarray(A @ A.T)
apx = fast_min_eigv(A, 5, 123)
opt = np.min(np.linalg.eig(A)[0])
np.testing.assert_allclose(apx, opt, atol=.01, rtol=.3)
apx = fast_max_eigv(A, 10, 123)
opt = np.max(np.linalg.eig(A)[0])
np.testing.assert_allclose(apx, opt, atol=.5, rtol=.2)
def test_linalg(self):
np.random.seed(1235)
for n, m, nrhs in [(3, 3, 3), (3, 2, 1), (3, 1, 2), (1, 4, 2), (3, 4, 5)]:
for _ in range(100):
A = np.random.normal(0, 1, size=(n, m))
y = np.random.normal(0, 1, size=(n, nrhs))
yf = y
if m > n:
yf = np.zeros((m, nrhs))
yf[:n] = y
ours = np.asfortranarray(np.zeros((m, nrhs)))
lstsq(np.asfortranarray(A), np.asfortranarray(yf.copy()), ours, copy_b=True)
true = np.linalg.lstsq(A, y, rcond=np.finfo(np.float64).eps * max(n, m))[0]
np.testing.assert_allclose(ours, true, atol=.00001, rtol=.0)
ours = np.asfortranarray(np.zeros(A.T.shape, dtype=np.float64))
pinv(np.asfortranarray(A), ours)
true = np.linalg.pinv(A)
np.testing.assert_allclose(ours, true, atol=.00001, rtol=.0)
if n == m:
ours = np.asfortranarray(np.zeros(A.T.shape, dtype=np.float64))
matinv(np.asfortranarray(A), ours)
true = np.linalg.inv(A)
np.testing.assert_allclose(ours, true, atol=.00001, rtol=.0)


@@ -0,0 +1,690 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import unittest
import logging
import time
import random
import numpy as np
import pandas as pd
import pytest
from econml.grf import RegressionForest, CausalForest, CausalIVForest, MultiOutputGRF
from econml.utilities import cross_product
from copy import deepcopy
from sklearn.utils import check_random_state
import scipy.stats
class TestGRFPython(unittest.TestCase):
def _get_base_config(self):
return {'n_estimators': 1, 'subforest_size': 1, 'max_depth': 2,
'min_samples_split': 2, 'min_samples_leaf': 1,
'inference': False, 'max_samples': 1.0, 'honest': False,
'n_jobs': None, 'random_state': 123}
def _get_regression_data(self, n, n_features, random_state):
X = np.zeros((n, n_features))
X[:, 0] = np.arange(X.shape[0])
X[:, 1] = np.random.RandomState(random_state).normal(0, 1, size=(X.shape[0]))
y = 1.0 * (X[:, 0] >= n / 2).reshape(-1, 1)
y += 1.0 * (X[:, 0] >= n / 4).reshape(-1, 1)
y += 1.0 * (X[:, 0] >= 3 * n / 4).reshape(-1, 1)
return X, y, y
def test_regression_tree_internals(self):
base_config = self._get_base_config()
n, n_features = 10, 2
random_state = 123
X, y, truth = self._get_regression_data(n, n_features, random_state)
forest = RegressionForest(**base_config).fit(X, y)
tree = forest[0].tree_
np.testing.assert_array_equal(tree.feature, np.array([0, 0, -2, -2, 0, -2, -2]))
np.testing.assert_array_equal(tree.threshold, np.array([4.5, 2.5, - 2, -2, 7.5, -2, -2]))
np.testing.assert_array_almost_equal(tree.value.flatten()[:3],
np.array([np.mean(y),
np.mean(y[X[:, tree.feature[0]] < tree.threshold[0]]),
np.mean(y[(X[:, tree.feature[0]] < tree.threshold[0]) &
(X[:, tree.feature[1]] < tree.threshold[1])])]),
decimal=5)
np.testing.assert_array_almost_equal(tree.predict(X), y, decimal=5)
tree.predict_precond(X)
tree.predict_jac(X)
tree.predict_precond_and_jac(X)
less = X[:, tree.feature[0]] < tree.threshold[0]
# testing importances
feature_importances = np.zeros(X.shape[1])
feature_importances[0] = np.var(y)
np.testing.assert_array_almost_equal(tree.compute_feature_importances(normalize=False),
feature_importances, decimal=5)
feature_importances = np.zeros(X.shape[1])
feature_importances[0] = np.var(y) - np.var(y[less])
np.testing.assert_array_almost_equal(tree.compute_feature_importances(normalize=False, max_depth=0),
feature_importances, decimal=5)
feature_importances = np.zeros(X.shape[1])
feature_importances[0] = np.var(y) - np.var(y[less]) + .5 * (np.var(y[less]))
np.testing.assert_array_almost_equal(tree.compute_feature_importances(normalize=False,
max_depth=1, depth_decay=1.0),
feature_importances, decimal=5)
# testing heterogeneity importances
feature_importances = np.zeros(X.shape[1])
feature_importances[0] = 5 * 5 * (np.mean(y[less]) - np.mean(y[~less]))**2 / 100
np.testing.assert_array_almost_equal(tree.compute_feature_heterogeneity_importances(normalize=False,
max_depth=0),
feature_importances, decimal=5)
feature_importances[0] += .5 * (2 * 2 * 3 * (1)**2 / 5) / 10
np.testing.assert_array_almost_equal(tree.compute_feature_heterogeneity_importances(normalize=False,
max_depth=1,
depth_decay=1.0),
feature_importances, decimal=5)
feature_importances[0] += .5 * (2 * 2 * 3 * (1)**2 / 5) / 10
np.testing.assert_array_almost_equal(tree.compute_feature_heterogeneity_importances(normalize=False),
feature_importances, decimal=5)
# Testing that all parameters do what they are supposed to
config = deepcopy(base_config)
config['min_samples_leaf'] = 5
forest = RegressionForest(**config).fit(X, y)
tree = forest[0].tree_
np.testing.assert_array_equal(tree.feature, np.array([0, -2, -2, ]))
np.testing.assert_array_equal(tree.threshold, np.array([4.5, -2, -2]))
config = deepcopy(base_config)
config['min_samples_split'] = 11
forest = RegressionForest(**config).fit(X, y)
tree = forest[0].tree_
np.testing.assert_array_equal(tree.feature, np.array([-2]))
np.testing.assert_array_equal(tree.threshold, np.array([-2]))
np.testing.assert_array_almost_equal(tree.predict(X), np.mean(y), decimal=5)
np.testing.assert_array_almost_equal(tree.predict_full(X), np.mean(y), decimal=5)
config = deepcopy(base_config)
config['min_weight_fraction_leaf'] = .5
forest = RegressionForest(**config).fit(X, y)
tree = forest[0].tree_
np.testing.assert_array_equal(tree.feature, np.array([0, -2, -2, ]))
np.testing.assert_array_equal(tree.threshold, np.array([4.5, -2, -2]))
# testing predict, apply and decision path
less = X[:, tree.feature[0]] < tree.threshold[0]
y_pred = np.zeros((X.shape[0], 1))
y_pred[less] = np.mean(y[less])
y_pred[~less] = np.mean(y[~less])
np.testing.assert_array_almost_equal(tree.predict(X), y_pred, decimal=5)
np.testing.assert_array_almost_equal(tree.predict_full(X), y_pred, decimal=5)
decision_path = np.zeros((X.shape[0], len(tree.feature)))
decision_path[less, :] = np.array([1, 1, 0])
decision_path[~less, :] = np.array([1, 0, 1])
np.testing.assert_array_equal(tree.decision_path(X).todense(), decision_path)
apply = np.zeros(X.shape[0])
apply[less] = 1
apply[~less] = 2
np.testing.assert_array_equal(tree.apply(X), apply)
feature_importances = np.zeros(X.shape[1])
feature_importances[0] = 1
np.testing.assert_array_equal(tree.compute_feature_importances(),
feature_importances)
config = deepcopy(base_config)
config['min_balancedness_tol'] = 0.
forest = RegressionForest(**config).fit(X, y)
tree = forest[0].tree_
np.testing.assert_array_equal(tree.feature, np.array([0, -2, -2, ]))
np.testing.assert_array_equal(tree.threshold, np.array([4.5, -2, -2]))
config = deepcopy(base_config)
config['min_balancedness_tol'] = 0.1
forest = RegressionForest(**config).fit(X, y)
tree = forest[0].tree_
np.testing.assert_array_equal(tree.feature, np.array([0, 0, -2, -2, 0, -2, -2]))
np.testing.assert_array_equal(tree.threshold, np.array([4.5, 2.5, - 2, -2, 7.5, -2, -2]))
config = deepcopy(base_config)
config['max_depth'] = 1
forest = RegressionForest(**config).fit(X, y)
tree = forest[0].tree_
np.testing.assert_array_equal(tree.feature, np.array([0, -2, -2, ]))
np.testing.assert_array_equal(tree.threshold, np.array([4.5, -2, -2]))
config = deepcopy(base_config)
config['min_impurity_decrease'] = 0.9999
forest = RegressionForest(**config).fit(X, y)
tree = forest[0].tree_
np.testing.assert_array_equal(tree.feature, np.array([0, -2, -2, ]))
np.testing.assert_array_equal(tree.threshold, np.array([4.5, -2, -2]))
config = deepcopy(base_config)
config['min_impurity_decrease'] = 1.0001
forest = RegressionForest(**config).fit(X, y)
tree = forest[0].tree_
np.testing.assert_array_equal(tree.feature, np.array([-2, ]))
np.testing.assert_array_equal(tree.threshold, np.array([-2, ]))
def _get_causal_data(self, n, n_features, n_treatments, random_state):
random_state = np.random.RandomState(random_state)
X = random_state.normal(size=(n, n_features))
T = np.zeros((n, n_treatments))
for t in range(T.shape[1]):
T[:, t] = random_state.binomial(1, .5, size=(T.shape[0],))
y = ((X[:, [0]] > 0.0) + .5) * np.sum(T, axis=1, keepdims=True) + .5
return (X, T, y, np.hstack([(X[:, [0]] > 0.0) + .5, (X[:, [0]] > 0.0) + .5]),
np.hstack([(X[:, [0]] > 0.0) + .5, (X[:, [0]] > 0.0) + .5, .5 * np.ones((X.shape[0], 1))]))
def _get_true_quantities(self, X, T, y, mask, criterion, fit_intercept, sample_weight=None):
if sample_weight is None:
sample_weight = np.ones(X.shape[0])
X, T, y, sample_weight = X[mask], T[mask], y[mask], sample_weight[mask]
n_relevant_outputs = T.shape[1]
if fit_intercept:
T = np.hstack([T, np.ones((T.shape[0], 1))])
alpha = y * T
pointJ = cross_product(T, T)
node_weight = np.sum(sample_weight)
jac = node_weight * np.average(pointJ, axis=0, weights=sample_weight)
precond = node_weight * np.average(alpha, axis=0, weights=sample_weight)
if jac.shape[0] == 1:
invJ = np.array([[1 / jac[0]]])
elif jac.shape[0] == 4:
det = jac[0] * jac[3] - jac[1] * jac[2]
if abs(det) < 1e-6:
det = 1e-6
invJ = np.array([[jac[3], -jac[1]], [-jac[2], jac[0]]]) / det
else:
invJ = np.linalg.inv(jac.reshape((alpha.shape[1], alpha.shape[1])) + 1e-6 * np.eye(T.shape[1]))
param = invJ @ precond
jac = jac / node_weight
precond = precond / node_weight
if criterion == 'het':
moment = alpha - pointJ.reshape((-1, alpha.shape[1], alpha.shape[1])) @ param
rho = ((invJ @ moment.T).T)[:, :n_relevant_outputs] * node_weight
impurity = np.mean(np.average(rho**2, axis=0, weights=sample_weight))
impurity -= np.mean(np.average(rho, axis=0, weights=sample_weight)**2)
else:
impurity = np.mean(np.average(y**2, axis=0, weights=sample_weight))
impurity -= (param.reshape(1, -1) @ jac.reshape((alpha.shape[1], alpha.shape[1])) @ param)[0]
return jac, precond, param, impurity
def _get_node_quantities(self, tree, node_id):
return (tree.jac[node_id, :], tree.precond[node_id, :],
tree.full_value[node_id, :, 0], tree.impurity[node_id])
def _train_causal_forest(self, X, T, y, config, sample_weight=None):
return CausalForest(**config).fit(X, T, y, sample_weight=sample_weight)
def _train_iv_forest(self, X, T, y, config, sample_weight=None):
return CausalIVForest(**config).fit(X, T, y, Z=T, sample_weight=sample_weight)
def _test_causal_tree_internals(self, trainer):
config = self._get_base_config()
for criterion in ['het', 'mse']:
for fit_intercept in [False, True]:
for min_var_fraction_leaf in [None, .4]:
config['criterion'] = criterion
config['fit_intercept'] = fit_intercept
config['max_depth'] = 2
config['min_samples_leaf'] = 5
config['min_var_fraction_leaf'] = min_var_fraction_leaf
n, n_features, n_treatments = 100, 2, 2
random_state = 123
X, T, y, truth, truth_full = self._get_causal_data(n, n_features, n_treatments, random_state)
forest = trainer(X, T, y, config)
tree = forest[0].tree_
paths = np.array(forest[0].decision_path(X).todense())
for node_id in range(len(tree.feature)):
mask = paths[:, node_id] > 0
[np.testing.assert_allclose(a, b, atol=1e-4)
for a, b in zip(self._get_true_quantities(X, T, y, mask, criterion, fit_intercept),
self._get_node_quantities(tree, node_id))]
if fit_intercept and (min_var_fraction_leaf is not None):
mask = np.abs(X[:, 0]) > .3
np.testing.assert_allclose(tree.predict(X[mask]), truth[mask], atol=.05)
np.testing.assert_allclose(tree.predict_full(X[mask]), truth_full[mask], atol=.05)
def _test_causal_honesty(self, trainer):
for criterion in ['het', 'mse']:
for fit_intercept in [False, True]:
for min_var_fraction_leaf, min_var_leaf_on_val in [(None, False), (.4, False), (.4, True)]:
for min_impurity_decrease in [0.0, 0.07]:
for inference in [False, True]:
for sample_weight in [None, 'rand']:
config = self._get_base_config()
config['honest'] = True
config['criterion'] = criterion
config['fit_intercept'] = fit_intercept
config['max_depth'] = 2
config['min_samples_leaf'] = 5
config['min_var_fraction_leaf'] = min_var_fraction_leaf
config['min_var_leaf_on_val'] = min_var_leaf_on_val
config['min_impurity_decrease'] = min_impurity_decrease
config['inference'] = inference
n, n_features, n_treatments = 400, 2, 2
if inference:
config['n_estimators'] = 4
config['subforest_size'] = 2
config['max_samples'] = .4
config['n_jobs'] = 1
n = 800
random_state = 123
if sample_weight is not None:
sample_weight = check_random_state(random_state).randint(0, 4, size=n)
X, T, y, truth, truth_full = self._get_causal_data(n, n_features,
n_treatments, random_state)
forest = trainer(X, T, y, config, sample_weight=sample_weight)
subinds = forest.get_subsample_inds()
if (sample_weight is None) and fit_intercept and (min_var_fraction_leaf is not None):
mask = np.abs(X[:, 0]) > .5
np.testing.assert_allclose(forest.predict(X[mask]),
truth[mask], atol=.07)
np.testing.assert_allclose(forest.predict_full(X[mask]),
truth_full[mask], atol=.07)
np.testing.assert_allclose(forest.predict_tree_average(X[mask]),
truth[mask], atol=.07)
np.testing.assert_allclose(forest.predict_tree_average_full(X[mask]),
truth_full[mask], atol=.07)
forest_paths, ptr = forest.decision_path(X)
forest_paths = np.array(forest_paths.todense())
forest_apply = forest.apply(X)
for it, tree in enumerate(forest):
tree_paths = np.array(tree.decision_path(X).todense())
np.testing.assert_array_equal(tree_paths, forest_paths[:, ptr[it]:ptr[it + 1]])
tree_apply = tree.apply(X)
np.testing.assert_array_equal(tree_apply, forest_apply[:, it])
_, samples_val = tree.get_train_test_split_inds()
inds_val = subinds[it][samples_val]
Xval, Tval, yval, truthval = X[inds_val], T[inds_val], y[inds_val], truth[inds_val]
sample_weightval = sample_weight[inds_val] if sample_weight is not None else None
paths = np.array(tree.decision_path(Xval).todense())
for node_id in range(len(tree.tree_.feature)):
mask = paths[:, node_id] > 0
[np.testing.assert_allclose(a, b, atol=1e-4)
for a, b in zip(self._get_true_quantities(Xval, Tval, yval, mask,
criterion, fit_intercept,
sample_weight=sample_weightval),
self._get_node_quantities(tree.tree_, node_id))]
if ((sample_weight is None) and
fit_intercept and (min_var_fraction_leaf is not None)):
mask = np.abs(Xval[:, 0]) > .5
np.testing.assert_allclose(tree.tree_.predict(Xval[mask]),
truthval[mask], atol=.07)
if (sample_weight is None) and min_impurity_decrease > 0.0005:
assert np.all((tree.tree_.feature == 0) | (tree.tree_.feature == -2))
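For orientation, a minimal sketch of inspecting the honest subsample/split bookkeeping that this test validates (estimator and method names are taken from the test; the toy data mirrors the generator above and is otherwise made up):
import numpy as np
from econml.grf import CausalForest

rng = np.random.RandomState(123)
X = rng.normal(size=(400, 2))
T = rng.binomial(1, .5, size=(400, 2)).astype(float)
y = ((X[:, [0]] > 0) + .5) * T.sum(axis=1, keepdims=True) + .5

forest = CausalForest(n_estimators=4, subforest_size=2, max_samples=.4,
                      honest=True, inference=True, random_state=123).fit(X, T, y)
subinds = forest.get_subsample_inds()                  # rows drawn for each tree
for it, tree in enumerate(forest):
    train, val = tree.get_train_test_split_inds()      # honest split within the subsample
    # split structure is chosen on subinds[it][train]; node estimates use subinds[it][val]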
def test_causal_tree(self,):
self._test_causal_tree_internals(self._train_causal_forest)
self._test_causal_honesty(self._train_causal_forest)
def test_iv_tree(self,):
self._test_causal_tree_internals(self._train_iv_forest)
self._test_causal_honesty(self._train_iv_forest)
def test_min_var_leaf(self,):
random_state = np.random.RandomState(123)
n, n_features, n_treatments = 200, 2, 1
X = random_state.normal(size=(n, n_features))
T = np.zeros((n, n_treatments))
for t in range(T.shape[1]):
T[:, t] = random_state.binomial(1, .5 + .2 * np.clip(X[:, 0], -1, 1), size=(T.shape[0],))
y = ((X[:, [0]] > 0.0) + .5) * np.sum(T, axis=1, keepdims=True) + .5
total_std = np.std(T)
min_var = .7 * total_std
for honest, min_var_fraction_leaf, min_var_leaf_on_val in [(False, None, False), (False, .8, False),
(True, None, True), (True, .8, True)]:
config = self._get_base_config()
config['criterion'] = 'mse'
config['n_estimators'] = 4
config['max_samples'] = 1.0
config['max_depth'] = None
config['min_var_fraction_leaf'] = min_var_fraction_leaf
config['fit_intercept'] = True
config['honest'] = honest
config['min_var_leaf_on_val'] = min_var_leaf_on_val
forest = self._train_causal_forest(X, T, y, config)
subinds = forest.get_subsample_inds()
for it, tree in enumerate(forest):
_, samples_val = tree.get_train_test_split_inds()
inds_val = subinds[it][samples_val]
Xval, Tval, _ = X[inds_val], T[inds_val], y[inds_val]
paths = np.array(tree.decision_path(Xval).todense())
if min_var_fraction_leaf is None:
with np.testing.assert_raises(AssertionError):
for node_id in range(len(tree.tree_.feature)):
mask = paths[:, node_id] > 0
np.testing.assert_array_less(min_var - 1e-7, np.std(Tval[mask]))
else:
for node_id in range(len(tree.tree_.feature)):
mask = paths[:, node_id] > 0
np.testing.assert_array_less(min_var - 1e-7, np.std(Tval[mask]))
def test_subsampling(self,):
# test that the subsampling scheme passed to the trees is correct
random_state = 123
n, n_features, n_treatments = 10, 2, 2
n_estimators = 600
config = self._get_base_config()
config['n_estimators'] = n_estimators
config['max_samples'] = .7
config['max_depth'] = 1
X, T, y, _, _ = self._get_causal_data(n, n_features, n_treatments, random_state)
forest = self._train_causal_forest(X, T, y, config)
subinds = forest.get_subsample_inds()
inds, counts = np.unique(subinds, return_counts=True)
np.testing.assert_allclose(counts / n_estimators, .7, atol=.06)
counts = np.zeros(n)
for it, tree in enumerate(forest):
samples_train, samples_val = tree.get_train_test_split_inds()
np.testing.assert_equal(samples_train, samples_val)
config = self._get_base_config()
config['n_estimators'] = n_estimators
config['max_samples'] = 7
config['max_depth'] = 1
X, T, y, _, _ = self._get_causal_data(n, n_features, n_treatments, random_state)
forest = self._train_causal_forest(X, T, y, config)
subinds = forest.get_subsample_inds()
inds, counts = np.unique(subinds, return_counts=True)
np.testing.assert_allclose(counts / n_estimators, .7, atol=.06)
config = self._get_base_config()
config['n_estimators'] = n_estimators
config['inference'] = True
config['subforest_size'] = 2
config['max_samples'] = .4
config['max_depth'] = 1
config['honest'] = True
X, T, y, _, _ = self._get_causal_data(n, n_features, n_treatments, random_state)
forest = self._train_causal_forest(X, T, y, config)
subinds = forest.get_subsample_inds()
inds, counts = np.unique(subinds, return_counts=True)
np.testing.assert_allclose(counts / n_estimators, .4, atol=.06)
counts = np.zeros(n)
for it, tree in enumerate(forest):
_, samples_val = tree.get_train_test_split_inds()
inds_val = subinds[it][samples_val]
counts[inds_val] += 1
np.testing.assert_allclose(counts / n_estimators, .2, atol=.05)
return
def _get_step_regression_data(self, n, n_features, random_state):
rnd = np.random.RandomState(random_state)
X = rnd.uniform(-1, 1, size=(n, n_features))
y = 1.0 * (X[:, 0] >= 0.0).reshape(-1, 1) + rnd.normal(0, 1, size=(n, 1))
return X, y, y
def test_var(self,):
# test that the estimator calculates var correctly
config = self._get_base_config()
config['honest'] = True
config['max_depth'] = 0
config['inference'] = True
config['n_estimators'] = 1000
config['subforest_size'] = 2
config['max_samples'] = .5
config['n_jobs'] = 1
n_features = 2
# test api
n = 100
random_state = 123
X, y, truth = self._get_regression_data(n, n_features, random_state)
forest = RegressionForest(**config).fit(X, y)
alpha = .1
mean, var = forest.predict_and_var(X)
lb = scipy.stats.norm.ppf(alpha / 2, loc=mean[:, 0], scale=np.sqrt(var[:, 0, 0])).reshape(-1, 1)
ub = scipy.stats.norm.ppf(1 - alpha / 2, loc=mean[:, 0], scale=np.sqrt(var[:, 0, 0])).reshape(-1, 1)
np.testing.assert_allclose(var, forest.predict_var(X))
lbtest, ubtest = forest.predict_interval(X, alpha=alpha)
np.testing.assert_allclose(lb, lbtest)
np.testing.assert_allclose(ub, ubtest)
meantest, lbtest, ubtest = forest.predict(X, interval=True, alpha=alpha)
np.testing.assert_allclose(mean, meantest)
np.testing.assert_allclose(lb, lbtest)
np.testing.assert_allclose(ub, ubtest)
np.testing.assert_allclose(np.sqrt(var[:, 0, 0]), forest.prediction_stderr(X)[:, 0])
# test accuracy
for n in [10, 100, 1000, 10000]:
random_state = 123
X, y, truth = self._get_regression_data(n, n_features, random_state)
forest = RegressionForest(**config).fit(X, y)
our_mean, our_var = forest.predict_and_var(X[:1])
true_mean, true_var = np.mean(y), np.var(y) / y.shape[0]
np.testing.assert_allclose(our_mean, true_mean, atol=0.05)
np.testing.assert_allclose(our_var, true_var, atol=0.05, rtol=.1)
for n, our_thr, true_thr in [(1000, .5, .25), (10000, .05, .05)]:
random_state = 123
config['max_depth'] = 1
X, y, truth = self._get_step_regression_data(n, n_features, random_state)
forest = RegressionForest(**config).fit(X, y)
posX = X[X[:, 0] > our_thr]
negX = X[X[:, 0] < -our_thr]
our_pos_mean, our_pos_var = forest.predict_and_var(posX)
our_neg_mean, our_neg_var = forest.predict_and_var(negX)
pos = X[:, 0] > true_thr
true_pos_mean, true_pos_var = np.mean(y[pos]), np.var(y[pos]) / y[pos].shape[0]
neg = X[:, 0] < -true_thr
true_neg_mean, true_neg_var = np.mean(y[neg]), np.var(y[neg]) / y[neg].shape[0]
np.testing.assert_allclose(our_pos_mean, true_pos_mean, atol=0.07)
np.testing.assert_allclose(our_pos_var, true_pos_var, atol=0.0, rtol=.25)
np.testing.assert_allclose(our_neg_mean, true_neg_mean, atol=0.07)
np.testing.assert_allclose(our_neg_var, true_neg_var, atol=0.0, rtol=.25)
return
def test_projection(self,):
# test the projection functionality of forests
# test that the estimator calculates the projected mean and variance correctly
np.set_printoptions(precision=10, suppress=True)
config = self._get_base_config()
config['honest'] = True
config['max_depth'] = 0
config['inference'] = True
config['n_estimators'] = 100
config['subforest_size'] = 2
config['max_samples'] = .5
config['n_jobs'] = 1
n_features = 2
# test api
n = 100
random_state = 123
X, y, truth = self._get_regression_data(n, n_features, random_state)
forest = RegressionForest(**config).fit(X, y)
mean, var = forest.predict_and_var(X)
mean = mean.flatten()
var = var.flatten()
y = np.hstack([y, y])
truth = np.hstack([truth, truth])
forest = RegressionForest(**config).fit(X, y)
projector = np.ones((X.shape[0], 2)) / 2.0
mean_proj, var_proj = forest.predict_projection_and_var(X, projector)
np.testing.assert_array_equal(mean_proj, mean)
np.testing.assert_array_equal(var_proj, var)
np.testing.assert_array_equal(var_proj, forest.predict_projection_var(X, projector))
np.testing.assert_array_equal(mean_proj, forest.predict_projection(X, projector))
return
def test_feature_importances(self,):
# test that the estimator computes feature importances correctly
for trainer in [self._train_causal_forest, self._train_iv_forest]:
for criterion in ['het', 'mse']:
for sample_weight in [None, 'rand']:
config = self._get_base_config()
config['honest'] = True
config['criterion'] = criterion
config['fit_intercept'] = True
config['max_depth'] = 2
config['min_samples_leaf'] = 5
config['min_var_fraction_leaf'] = None
config['min_impurity_decrease'] = 0.0
config['inference'] = True
config['n_estimators'] = 4
config['subforest_size'] = 2
config['max_samples'] = .4
config['n_jobs'] = 1
n, n_features, n_treatments = 800, 2, 2
random_state = 123
if sample_weight is not None:
sample_weight = check_random_state(random_state).randint(0, 4, size=n)
X, T, y, truth, truth_full = self._get_causal_data(n, n_features,
n_treatments, random_state)
forest = trainer(X, T, y, config, sample_weight=sample_weight)
forest_het_importances = np.zeros(n_features)
for it, tree in enumerate(forest):
tree_ = tree.tree_
tfeature = tree_.feature
timpurity = tree_.impurity
tdepth = tree_.depth
tleft = tree_.children_left
tright = tree_.children_right
tw = tree_.weighted_n_node_samples
tvalue = tree_.value
for max_depth in [0, 2]:
feature_importances = np.zeros(n_features)
for it, (feat, impurity, depth, left, right, w) in\
enumerate(zip(tfeature, timpurity, tdepth, tleft, tright, tw)):
if (left != -1) and (depth <= max_depth):
gain = w * impurity - tw[left] * timpurity[left] - tw[right] * timpurity[right]
feature_importances[feat] += gain / (depth + 1)**2.0
feature_importances /= tw[0]
totest = tree.tree_.compute_feature_importances(normalize=False,
max_depth=max_depth, depth_decay=2.0)
np.testing.assert_array_equal(feature_importances, totest)
het_importances = np.zeros(n_features)
for it, (feat, depth, left, right, w) in\
enumerate(zip(tfeature, tdepth, tleft, tright, tw)):
if (left != -1) and (depth <= max_depth):
gain = tw[left] * tw[right] * np.mean((tvalue[left] - tvalue[right])**2) / w
het_importances[feat] += gain / (depth + 1)**2.0
het_importances /= tw[0]
totest = tree.tree_.compute_feature_heterogeneity_importances(normalize=False,
max_depth=max_depth,
depth_decay=2.0)
np.testing.assert_allclose(het_importances, totest)
het_importances /= np.sum(het_importances)
forest_het_importances += het_importances / len(forest)
np.testing.assert_allclose(forest_het_importances,
forest.feature_importances(max_depth=2, depth_decay_exponent=2.0))
np.testing.assert_allclose(forest_het_importances, forest.feature_importances_)
return
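As a reading aid, the importance conventions verified above, restated as comments (derived from this test):
# impurity-based split gain:  w_node * imp_node - w_left * imp_left - w_right * imp_right
# heterogeneity split gain:   w_left * w_right * mean((value_left - value_right)**2) / w_node
# each gain is discounted by 1 / (depth + 1)**depth_decay, accumulated per split feature
# over nodes with depth <= max_depth, and divided by the root weight; the forest-level
# feature_importances_ averages the normalized per-tree heterogeneity importances.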
def test_non_standard_input(self,):
# test that the estimator accepts lists, tuples and pandas data frames
n_features = 2
n = 100
random_state = 123
X, y, truth = self._get_regression_data(n, n_features, random_state)
forest = RegressionForest(n_estimators=20, n_jobs=1, random_state=123).fit(X, y)
pred = forest.predict(X)
forest = RegressionForest(n_estimators=20, n_jobs=1, random_state=123).fit(tuple(X), tuple(y))
np.testing.assert_allclose(pred, forest.predict(tuple(X)))
forest = RegressionForest(n_estimators=20, n_jobs=1, random_state=123).fit(list(X), list(y))
np.testing.assert_allclose(pred, forest.predict(list(X)))
forest = RegressionForest(n_estimators=20, n_jobs=1, random_state=123).fit(pd.DataFrame(X), pd.DataFrame(y))
np.testing.assert_allclose(pred, forest.predict(pd.DataFrame(X)))
forest = RegressionForest(n_estimators=20, n_jobs=1, random_state=123).fit(
pd.DataFrame(X), pd.Series(y.ravel()))
np.testing.assert_allclose(pred, forest.predict(pd.DataFrame(X)))
return
def test_raise_exceptions(self,):
# test that we raise errors in mishandled situations.
n_features = 2
n = 10
random_state = 123
X, y, truth = self._get_regression_data(n, n_features, random_state)
with np.testing.assert_raises(ValueError):
forest = RegressionForest(n_estimators=20).fit(X, y[:4])
with np.testing.assert_raises(ValueError):
forest = RegressionForest(n_estimators=20, subforest_size=3).fit(X, y)
with np.testing.assert_raises(ValueError):
forest = RegressionForest(n_estimators=20, inference=True, max_samples=.6).fit(X, y)
with np.testing.assert_raises(ValueError):
forest = RegressionForest(n_estimators=20, max_samples=20).fit(X, y)
with np.testing.assert_raises(ValueError):
forest = RegressionForest(n_estimators=20, max_samples=1.2).fit(X, y)
with np.testing.assert_raises(ValueError):
forest = RegressionForest(n_estimators=4, warm_start=True, inference=True).fit(X, y)
forest.inference = False
forest.n_estimators = 8
forest.fit(X, y)
with np.testing.assert_raises(KeyError):
forest = CausalForest(n_estimators=4, criterion='peculiar').fit(X, y, y)
with np.testing.assert_raises(ValueError):
forest = CausalForest(n_estimators=4, max_depth=-1).fit(X, y, y)
with np.testing.assert_raises(ValueError):
forest = CausalForest(n_estimators=4, min_samples_split=-1).fit(X, y, y)
with np.testing.assert_raises(ValueError):
forest = CausalForest(n_estimators=4, min_samples_leaf=-1).fit(X, y, y)
with np.testing.assert_raises(ValueError):
forest = CausalForest(n_estimators=4, min_weight_fraction_leaf=-1.0).fit(X, y, y)
with np.testing.assert_raises(ValueError):
forest = CausalForest(n_estimators=4, min_var_fraction_leaf=-1.0).fit(X, y, y)
with np.testing.assert_raises(ValueError):
forest = CausalForest(n_estimators=4, max_features=10).fit(X, y, y)
with np.testing.assert_raises(ValueError):
forest = CausalForest(n_estimators=4, min_balancedness_tol=.55).fit(X, y, y)
return
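Taken together, the checks above document the admissible ranges: y must match X in length, n_estimators must be divisible by subforest_size, inference=True requires max_samples <= .5, max_samples as an int must not exceed the number of samples (and as a float must be at most 1), min_balancedness_tol can be at most .5, and the depth, leaf-size and fraction parameters must be non-negative. A configuration respecting all of these, as a sketch; the econml.grf import path and the default criterion are assumptions based on the imports used earlier in this module:
import numpy as np
from econml.grf import RegressionForest, CausalForest  # assumed import path

X = np.random.normal(size=(100, 2))
y = X[:, [0]] + np.random.normal(scale=.1, size=(100, 1))

# inference=True needs max_samples <= .5 and n_estimators divisible by subforest_size
rf = RegressionForest(n_estimators=20, subforest_size=4, max_samples=.45,
                      inference=True, random_state=123).fit(X, y)

# all structural parameters within their admissible ranges
cf = CausalForest(n_estimators=8, max_depth=5, min_samples_split=5, min_samples_leaf=5,
                  min_weight_fraction_leaf=.01, min_var_fraction_leaf=.01,
                  max_features=2, min_balancedness_tol=.45,
                  random_state=123).fit(X, y, y)
print(rf.predict(X[:3]).shape)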
def test_warm_start(self,):
n_features = 2
n = 10
random_state = 123
X, y, _ = self._get_regression_data(n, n_features, random_state)
forest = RegressionForest(n_estimators=4, warm_start=True, random_state=123).fit(X, y)
forest.n_estimators = 8
forest.fit(X, y)
pred1 = forest.predict(X)
inds1 = forest.get_subsample_inds()
tree_states1 = [t.random_state for t in forest]
forest = RegressionForest(n_estimators=8, warm_start=True, random_state=123).fit(X, y)
pred2 = forest.predict(X)
inds2 = forest.get_subsample_inds()
tree_states2 = [t.random_state for t in forest]
np.testing.assert_allclose(pred1, pred2)
np.testing.assert_allclose(inds1, inds2)
np.testing.assert_allclose(tree_states1, tree_states2)
return
def test_multioutput(self,):
# test that multi-output y is handled correctly: duplicating the outcome column should yield identical per-output estimates, intervals and importances
random_state = 123
n, n_features, n_treatments = 10, 2, 2
X, T, y, _, _ = self._get_causal_data(n, n_features, n_treatments, random_state)
y = np.hstack([y, y])
for est in [CausalForest(n_estimators=4, random_state=123),
CausalIVForest(n_estimators=4, random_state=123)]:
forest = MultiOutputGRF(est)
if isinstance(est, CausalForest):
forest.fit(X, T, y)
else:
forest.fit(X, T, y, Z=T)
pred, lb, ub = forest.predict(X, interval=True, alpha=.05)
np.testing.assert_array_equal(pred.shape, (X.shape[0], 2, n_treatments))
np.testing.assert_allclose(pred[:, 0, :], pred[:, 1, :])
np.testing.assert_allclose(lb[:, 0, :], lb[:, 1, :])
np.testing.assert_allclose(ub[:, 0, :], ub[:, 1, :])
pred, var = forest.predict_and_var(X)
np.testing.assert_array_equal(pred.shape, (X.shape[0], 2, n_treatments))
np.testing.assert_array_equal(var.shape, (X.shape[0], 2, n_treatments, n_treatments))
np.testing.assert_allclose(pred[:, 0, :], pred[:, 1, :])
np.testing.assert_allclose(var[:, 0, :, :], var[:, 1, :, :])
pred, var = forest.predict_projection_and_var(X, np.ones((X.shape[0], n_treatments)))
np.testing.assert_array_equal(pred.shape, (X.shape[0], 2))
np.testing.assert_array_equal(var.shape, (X.shape[0], 2))
np.testing.assert_allclose(pred[:, 0], pred[:, 1])
np.testing.assert_allclose(var[:, 0], var[:, 1])
imps = forest.feature_importances(max_depth=3, depth_decay_exponent=1.0)
np.testing.assert_array_equal(imps.shape, (X.shape[1], 2))
np.testing.assert_allclose(imps[:, 0], imps[:, 1])
imps = forest.feature_importances_
np.testing.assert_array_equal(imps.shape, (2, X.shape[1]))
np.testing.assert_allclose(imps[0, :], imps[1, :])
return
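MultiOutputGRF fits one GRF per outcome column and stacks the results, which is why every per-output slice above must coincide when the two outcome columns are identical. A short usage sketch; the econml.grf import path is an assumption based on the imports used earlier in this module, and the data is synthetic:
import numpy as np
from econml.grf import CausalForest, MultiOutputGRF  # assumed import path

n, n_treatments = 100, 2
X = np.random.normal(size=(n, 3))
T = np.random.normal(size=(n, n_treatments))
y = X[:, [0]] * T[:, [0]] + np.random.normal(scale=.1, size=(n, 1))
Y = np.hstack([y, 2 * y])                      # two outcome columns

est = MultiOutputGRF(CausalForest(n_estimators=8, random_state=123))
est.fit(X, T, Y)
point, lb, ub = est.predict(X, interval=True, alpha=.05)
print(point.shape)                             # (n, n_outputs, n_treatments)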


@ -12,7 +12,6 @@ from sklearn.multioutput import MultiOutputRegressor
from sklearn.pipeline import Pipeline
from econml.ortho_forest import DMLOrthoForest, DROrthoForest
from econml.sklearn_extensions.linear_model import WeightedLassoCVWrapper
from econml.causal_forest import CausalForest
class TestOrthoForest(unittest.TestCase):
@ -86,41 +85,6 @@ class TestOrthoForest(unittest.TestCase):
self._test_te(est, TestOrthoForest.expected_exp_te, tol=0.5)
self._test_ci(est, TestOrthoForest.expected_exp_te, tol=1.5)
# Test Causal Forest API
# Generate data with continuous treatments
T = np.dot(TestOrthoForest.W[:, TestOrthoForest.support], TestOrthoForest.coefs_T) + \
TestOrthoForest.eta_sample(TestOrthoForest.n)
TE = np.array([self._exp_te(x) for x in TestOrthoForest.X])
Y = np.dot(TestOrthoForest.W[:, TestOrthoForest.support], TestOrthoForest.coefs_Y) + \
T * TE + TestOrthoForest.epsilon_sample(TestOrthoForest.n)
# Instantiate model with most of the default parameters.
est = CausalForest(n_jobs=1, n_trees=10,
model_T=WeightedLassoCVWrapper(),
model_Y=WeightedLassoCVWrapper())
# Test inputs for continuous treatments
# --> Check that one can pass in regular lists
est.fit(list(Y), list(T), X=list(TestOrthoForest.X), W=list(TestOrthoForest.W))
# --> Check that it fails correctly if lists of different shape are passed in
self.assertRaises(ValueError, est.fit, Y[:TestOrthoForest.n // 2], T[:TestOrthoForest.n // 2],
TestOrthoForest.X, TestOrthoForest.W)
# Check that outputs have the correct shape
out_te = est.const_marginal_effect(TestOrthoForest.x_test)
self.assertEqual(TestOrthoForest.x_test.shape[0], out_te.shape[0])
# Test continuous treatments with controls
est = CausalForest(n_jobs=1, n_trees=100, min_leaf_size=10,
max_depth=50, subsample_ratio=0.50,
model_T=WeightedLassoCVWrapper(),
model_Y=WeightedLassoCVWrapper(), cv=5)
est.fit(Y, T, X=TestOrthoForest.X, W=TestOrthoForest.W, inference="blb")
self._test_te(est, TestOrthoForest.expected_exp_te, tol=0.5)
self._test_ci(est, TestOrthoForest.expected_exp_te, tol=1.5)
# Test continuous treatments without controls
T = TestOrthoForest.eta_sample(TestOrthoForest.n)
Y = T * TE + TestOrthoForest.epsilon_sample(TestOrthoForest.n)
est.fit(Y, T, X=TestOrthoForest.X, inference="blb")
self._test_te(est, TestOrthoForest.expected_exp_te, tol=0.5)
self._test_ci(est, TestOrthoForest.expected_exp_te, tol=1.5)
def test_binary_treatments(self):
np.random.seed(123)
# Generate data with binary treatments
@ -174,57 +138,6 @@ class TestOrthoForest(unittest.TestCase):
self._test_te(est, TestOrthoForest.expected_exp_te, tol=0.5, treatment_type='discrete')
self._test_ci(est, TestOrthoForest.expected_exp_te, tol=1.5, treatment_type='discrete')
# Test CausalForest API
np.random.seed(123)
# Generate data with binary treatments
log_odds = np.dot(TestOrthoForest.W[:, TestOrthoForest.support], TestOrthoForest.coefs_T) + \
TestOrthoForest.eta_sample(TestOrthoForest.n)
T_sigmoid = 1 / (1 + np.exp(-log_odds))
T = np.array([np.random.binomial(1, p) for p in T_sigmoid])
TE = np.array([self._exp_te(x) for x in TestOrthoForest.X])
Y = np.dot(TestOrthoForest.W[:, TestOrthoForest.support], TestOrthoForest.coefs_Y) + \
T * TE + TestOrthoForest.epsilon_sample(TestOrthoForest.n)
# Instantiate model with default params. Using n_jobs=1 since code coverage
# does not work well with parallelism.
est = CausalForest(n_trees=10, n_jobs=1,
model_Y=Lasso(),
model_T=LogisticRegressionCV(penalty='l1', solver='saga'))
# Test inputs for binary treatments
# --> Check that one can pass in regular lists
est.fit(list(Y), list(T), X=list(TestOrthoForest.X), W=list(TestOrthoForest.W))
# --> Check that it fails correctly if lists of different shape are passed in
self.assertRaises(ValueError, est.fit, Y[:TestOrthoForest.n // 2], T[:TestOrthoForest.n // 2],
TestOrthoForest.X, TestOrthoForest.W)
# --> Check that it works when T, Y have shape (n, 1)
est.fit(Y.reshape(-1, 1), T.reshape(-1, 1), X=TestOrthoForest.X, W=TestOrthoForest.W)
# --> Check that it fails correctly when T has shape (n, 2)
self.assertRaises(ValueError, est.fit, Y, np.ones((TestOrthoForest.n, 2)),
TestOrthoForest.X, TestOrthoForest.W)
# --> Check that it fails correctly when the treatments are not numeric
self.assertRaises(ValueError, est.fit, Y, np.array(["a"] * TestOrthoForest.n),
TestOrthoForest.X, TestOrthoForest.W)
# Check that outputs have the correct shape
out_te = est.const_marginal_effect(TestOrthoForest.x_test)
self.assertSequenceEqual((TestOrthoForest.x_test.shape[0], 1, 1), out_te.shape)
# Test binary treatments with controls
est = CausalForest(n_trees=100, min_leaf_size=10,
max_depth=30, subsample_ratio=0.30, n_jobs=1,
model_Y=Lasso(),
model_T=LogisticRegressionCV(penalty='l1', solver='saga'),
discrete_treatment=True,
cv=5)
est.fit(Y, T, X=TestOrthoForest.X, W=TestOrthoForest.W, inference="blb")
self._test_te(est, TestOrthoForest.expected_exp_te, tol=0.7, treatment_type='discrete')
self._test_ci(est, TestOrthoForest.expected_exp_te, tol=1.5, treatment_type='discrete')
# Test binary treatments without controls
log_odds = TestOrthoForest.eta_sample(TestOrthoForest.n)
T_sigmoid = 1 / (1 + np.exp(-log_odds))
T = np.array([np.random.binomial(1, p) for p in T_sigmoid])
Y = T * TE + TestOrthoForest.epsilon_sample(TestOrthoForest.n)
est.fit(Y, T, X=TestOrthoForest.X, inference="blb")
self._test_te(est, TestOrthoForest.expected_exp_te, tol=0.5, treatment_type='discrete')
self._test_ci(est, TestOrthoForest.expected_exp_te, tol=1.5, treatment_type='discrete')
def test_multiple_treatments(self):
np.random.seed(123)
# Only applicable to continuous treatments
@ -251,17 +164,6 @@ class TestOrthoForest(unittest.TestCase):
self._test_te(est, expected_te, tol=0.5, treatment_type='multi')
self._test_ci(est, expected_te, tol=2.0, treatment_type='multi')
# Test CausalForest API
est = CausalForest(n_trees=100, min_leaf_size=10,
max_depth=50, subsample_ratio=0.50, n_jobs=1,
model_T=WeightedLassoCVWrapper(cv=5),
model_Y=WeightedLassoCVWrapper(cv=5),
cv=5)
est.fit(Y, T, X=TestOrthoForest.X, W=TestOrthoForest.W, inference="blb")
expected_te = np.array([TestOrthoForest.expected_exp_te, TestOrthoForest.expected_const_te]).T
self._test_te(est, expected_te, tol=0.5, treatment_type='multi')
self._test_ci(est, expected_te, tol=2.0, treatment_type='multi')
def test_effect_shape(self):
import scipy.special
np.random.seed(123)
@ -317,46 +219,6 @@ class TestOrthoForest(unittest.TestCase):
assert lb.shape == (3, 1, 2), "Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_inference(1, X[:3]).conf_int()
assert lb.shape == (3, 1, 2), "Marginal Effect interval dimension incorrect"
# Test CausalForest API
est = CausalForest(n_trees=10, model_Y=DummyRegressor(strategy='mean'),
model_T=DummyClassifier(strategy='prior'), discrete_treatment=True,
n_jobs=1)
est.fit(y, T, X=X)
assert est.const_marginal_effect(X[:3]).shape == (3, 2), "Const Marginal Effect dimension incorrect"
assert est.marginal_effect(1, X[:3]).shape == (3, 2), "Marginal Effect dimension incorrect"
assert est.effect(X[:3]).shape == (3,), "Effect dimension incorrect"
assert est.effect(X[:3], T0=0, T1=2).shape == (3,), "Effect dimension incorrect"
assert est.effect(X[:3], T0=1, T1=2).shape == (3,), "Effect dimension incorrect"
lb, _ = est.effect_interval(X[:3], T0=1, T1=2)
assert lb.shape == (3,), "Effect interval dimension incorrect"
lb, _ = est.effect_inference(X[:3], T0=1, T1=2).conf_int()
assert lb.shape == (3,), "Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_interval(X[:3])
assert lb.shape == (3, 2), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_inference(X[:3]).conf_int()
assert lb.shape == (3, 2), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_interval(1, X[:3])
assert lb.shape == (3, 2), "Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_inference(1, X[:3]).conf_int()
assert lb.shape == (3, 2), "Marginal Effect interval dimension incorrect"
est.fit(y.reshape(-1, 1), T, X=X)
assert est.const_marginal_effect(X[:3]).shape == (3, 1, 2), "Const Marginal Effect dimension incorrect"
assert est.marginal_effect(1, X[:3]).shape == (3, 1, 2), "Marginal Effect dimension incorrect"
assert est.effect(X[:3]).shape == (3, 1), "Effect dimension incorrect"
assert est.effect(X[:3], T0=0, T1=2).shape == (3, 1), "Effect dimension incorrect"
assert est.effect(X[:3], T0=1, T1=2).shape == (3, 1), "Effect dimension incorrect"
lb, _ = est.effect_interval(X[:3], T0=1, T1=2)
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
lb, _ = est.effect_inference(X[:3], T0=1, T1=2).conf_int()
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_interval(X[:3])
assert lb.shape == (3, 1, 2), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_inference(X[:3]).conf_int()
assert lb.shape == (3, 1, 2), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_interval(1, X[:3])
assert lb.shape == (3, 1, 2), "Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_inference(1, X[:3]).conf_int()
assert lb.shape == (3, 1, 2), "Marginal Effect interval dimension incorrect"
from sklearn.dummy import DummyClassifier, DummyRegressor
for global_residualization in [False, True]:
@ -420,65 +282,6 @@ class TestOrthoForest(unittest.TestCase):
lb, _ = est.marginal_effect_inference(1, X[:3]).conf_int()
assert lb.shape == (3,), "Marginal Effect interval dimension incorrect"
# Test Causal Forest API
est = CausalForest(n_trees=10, model_Y=DummyRegressor(strategy='mean'),
model_T=DummyRegressor(strategy='mean'),
n_jobs=1)
est.fit(y.reshape(-1, 1), T.reshape(-1, 1), X=X)
assert est.const_marginal_effect(X[:3]).shape == (3, 1, 1), "Const Marginal Effect dimension incorrect"
assert est.marginal_effect(1, X[:3]).shape == (3, 1, 1), "Marginal Effect dimension incorrect"
assert est.effect(X[:3]).shape == (3, 1), "Effect dimension incorrect"
assert est.effect(X[:3], T0=0, T1=2).shape == (3, 1), "Effect dimension incorrect"
assert est.effect(X[:3], T0=1, T1=2).shape == (3, 1), "Effect dimension incorrect"
lb, _ = est.effect_interval(X[:3], T0=1, T1=2)
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
lb, _ = est.effect_inference(X[:3], T0=1, T1=2).conf_int()
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_interval(X[:3])
assert lb.shape == (3, 1, 1), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_inference(X[:3]).conf_int()
assert lb.shape == (3, 1, 1), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_interval(1, X[:3])
assert lb.shape == (3, 1, 1), "Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_inference(1, X[:3]).conf_int()
assert lb.shape == (3, 1, 1), "Marginal Effect interval dimension incorrect"
est.fit(y.reshape(-1, 1), T, X=X)
assert est.const_marginal_effect(X[:3]).shape == (3, 1), "Const Marginal Effect dimension incorrect"
assert est.marginal_effect(1, X[:3]).shape == (3, 1), "Marginal Effect dimension incorrect"
assert est.effect(X[:3]).shape == (3, 1), "Effect dimension incorrect"
assert est.effect(X[:3], T0=0, T1=2).shape == (3, 1), "Effect dimension incorrect"
assert est.effect(X[:3], T0=1, T1=2).shape == (3, 1), "Effect dimension incorrect"
lb, _ = est.effect_interval(X[:3], T0=1, T1=2)
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
lb, _ = est.effect_inference(X[:3], T0=1, T1=2).conf_int()
assert lb.shape == (3, 1), "Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_interval(X[:3])
assert lb.shape == (3, 1), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_inference(X[:3]).conf_int()
assert lb.shape == (3, 1), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_interval(1, X[:3])
assert lb.shape == (3, 1), "Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_inference(1, X[:3]).conf_int()
assert lb.shape == (3, 1), "Marginal Effect interval dimension incorrect"
est.fit(y, T, X=X)
assert est.const_marginal_effect(X[:3]).shape == (3,), "Const Marginal Effect dimension incorrect"
assert est.marginal_effect(1, X[:3]).shape == (3,), "Marginal Effect dimension incorrect"
assert est.effect(X[:3]).shape == (3,), "Effect dimension incorrect"
assert est.effect(X[:3], T0=0, T1=2).shape == (3,), "Effect dimension incorrect"
assert est.effect(X[:3], T0=1, T1=2).shape == (3,), "Effect dimension incorrect"
lb, _ = est.effect_interval(X[:3], T0=1, T1=2)
assert lb.shape == (3,), "Effect interval dimension incorrect"
lb, _ = est.effect_inference(X[:3], T0=1, T1=2).conf_int()
assert lb.shape == (3,), "Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_interval(X[:3])
assert lb.shape == (3,), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.const_marginal_effect_inference(X[:3]).conf_int()
assert lb.shape == (3,), "Const Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_interval(1, X[:3])
assert lb.shape == (3,), "Marginal Effect interval dimension incorrect"
lb, _ = est.marginal_effect_inference(1, X[:3]).conf_int()
assert lb.shape == (3,), "Marginal Effect interval dimension incorrect"
def test_nuisance_model_has_weights(self):
"""Test whether the correct exception is being raised if model_final doesn't have weights."""


@ -6,7 +6,7 @@ from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, FunctionTransformer, PolynomialFeatures
from sklearn.model_selection import KFold, GroupKFold
from econml.dml import DML, LinearDML, SparseLinearDML, KernelDML
from econml.dml import NonParamDML, ForestDML
from econml.dml import NonParamDML, CausalForestDML
from econml.drlearner import DRLearner, SparseLinearDRLearner, LinearDRLearner, ForestDRLearner
from econml.ortho_iv import DMLATEIV, ProjectedDMLATEIV, DMLIV, NonParamDMLIV,\
IntentToTreatDRIV, LinearIntentToTreatDRIV
@ -53,8 +53,8 @@ class TestRandomState(unittest.TestCase):
te2 = est.effect(X_test)
est.fit(Y, T, **kwargs)
te3 = est.effect(X_test)
np.testing.assert_array_equal(te1, te2, err_msg='random state fixing does not work')
np.testing.assert_array_equal(te1, te3, err_msg='random state fixing does not work')
np.testing.assert_allclose(te1, te2, err_msg='random state fixing does not work')
np.testing.assert_allclose(te1, te3, err_msg='random state fixing does not work')
def test_dml_random_state(self):
Y, T, X, W, X_test = TestRandomState._make_data(500, 2)
@ -64,10 +64,10 @@ class TestRandomState(unittest.TestCase):
model_final=RandomForestRegressor(max_depth=3, n_estimators=10, min_samples_leaf=100,
bootstrap=True, random_state=123),
discrete_treatment=True, n_splits=2, random_state=123),
ForestDML(model_y=RandomForestRegressor(n_estimators=10, max_depth=4, random_state=123),
model_t=RandomForestClassifier(n_estimators=10, max_depth=4, random_state=123),
n_estimators=10,
discrete_treatment=True, n_crossfit_splits=2, random_state=123),
CausalForestDML(model_y=RandomForestRegressor(n_estimators=10, max_depth=4, random_state=123),
model_t=RandomForestClassifier(n_estimators=10, max_depth=4, random_state=123),
n_estimators=8,
discrete_treatment=True, n_crossfit_splits=2, random_state=123),
LinearDML(model_y=RandomForestRegressor(n_estimators=10, max_depth=4, random_state=123),
model_t=RandomForestClassifier(n_estimators=10, max_depth=4, random_state=123),
discrete_treatment=True, n_splits=2, random_state=123),


@ -30,25 +30,27 @@ class TestShap(unittest.TestCase):
for featurizer in [None, PolynomialFeatures(degree=2, include_bias=False)]:
est_list = [LinearDML(model_y=LinearRegression(),
model_t=LinearRegression(), featurizer=featurizer)]
model_t=LinearRegression(), featurizer=featurizer),
CausalForestDML(model_y=LinearRegression(), model_t=LinearRegression())]
if d_t == 1:
est_list += [
NonParamDML(model_y=LinearRegression(
), model_t=LinearRegression(), model_final=RandomForestRegressor(), featurizer=featurizer),
ForestDML(model_y=LinearRegression(), model_t=LinearRegression())]
]
for est in est_list:
with self.subTest(est=est, featurizer=featurizer, d_y=d_y, d_t=d_t):
fd_x = featurizer.fit_transform(X).shape[1] if featurizer is not None else d_x
est.fit(Y, T, X, W)
shap_values = est.shap_values(X[:10], feature_names=["a", "b", "c"])
shap_values = est.shap_values(X[:10], feature_names=["a", "b", "c"],
background_samples=None)
# test base values equals to mean of constant marginal effect
if not isinstance(est, (ForestDML, DMLOrthoForest)):
if not isinstance(est, (CausalForestDML, DMLOrthoForest)):
mean_cate = est.const_marginal_effect(X[:10]).mean(axis=0)
mean_cate = mean_cate.flatten()[0] if not np.isscalar(mean_cate) else mean_cate
self.assertAlmostEqual(shap_values["Y0"]["T0"].base_values[0], mean_cate, delta=1e-2)
if isinstance(est, (ForestDML, DMLOrthoForest)):
if isinstance(est, (CausalForestDML, DMLOrthoForest)):
fd_x = d_x
# test shape of shap values output is as expected
@ -80,15 +82,15 @@ class TestShap(unittest.TestCase):
SLearner(overall_model=RandomForestRegressor()),
XLearner(models=RandomForestRegressor()),
DomainAdaptationLearner(models=RandomForestRegressor(),
final_models=RandomForestRegressor())
final_models=RandomForestRegressor()),
CausalForestDML(model_y=LinearRegression(), model_t=LogisticRegression(),
discrete_treatment=True)
]
if d_t == 2:
est_list += [
NonParamDML(model_y=LinearRegression(
), model_t=LogisticRegression(), model_final=RandomForestRegressor(),
featurizer=featurizer, discrete_treatment=True),
ForestDML(model_y=LinearRegression(), model_t=LogisticRegression(),
discrete_treatment=True)]
featurizer=featurizer, discrete_treatment=True)]
if d_y == 1:
est_list += [DRLearner(multitask_model_final=True, featurizer=featurizer),
DRLearner(multitask_model_final=False, featurizer=featurizer),
@ -100,15 +102,16 @@ class TestShap(unittest.TestCase):
est.fit(Y, T, X)
else:
est.fit(Y, T, X, W)
shap_values = est.shap_values(X[:10], feature_names=["a", "b", "c"])
shap_values = est.shap_values(X[:10], feature_names=["a", "b", "c"],
background_samples=None)
# test base values equals to mean of constant marginal effect
if not isinstance(est, (ForestDML, ForestDRLearner, DROrthoForest)):
if not isinstance(est, (CausalForestDML, ForestDRLearner, DROrthoForest)):
mean_cate = est.const_marginal_effect(X[:10]).mean(axis=0)
mean_cate = mean_cate.flatten()[0] if not np.isscalar(mean_cate) else mean_cate
self.assertAlmostEqual(shap_values["Y0"]["T0"].base_values[0], mean_cate, delta=1e-2)
if isinstance(est, (TLearner, SLearner, XLearner, DomainAdaptationLearner, ForestDML,
if isinstance(est, (TLearner, SLearner, XLearner, DomainAdaptationLearner, CausalForestDML,
ForestDRLearner, DROrthoForest)):
fd_x = d_x
# test shape of shap values output is as expected
@ -156,14 +159,16 @@ class TestShap(unittest.TestCase):
fit_cate_intercept=True,
featurizer=PolynomialFeatures(degree=2, include_bias=False))
est.fit(Y, T, X=X, W=W)
shap_values1 = est.shap_values(X[:10], feature_names=["A", "B"], treatment_names=["orange"])
shap_values1 = est.shap_values(X[:10], feature_names=["A", "B"], treatment_names=["orange"],
background_samples=None)
est = LinearDML(model_y=Lasso(),
model_t=Lasso(),
random_state=123,
fit_cate_intercept=True,
featurizer=PolynomialFeatures(degree=2, include_bias=False))
est.fit(Y[:, 0], T, X=X, W=W)
shap_values2 = est.shap_values(X[:10], feature_names=["A", "B"], treatment_names=["orange"])
shap_values2 = est.shap_values(X[:10], feature_names=["A", "B"], treatment_names=["orange"],
background_samples=None)
np.testing.assert_allclose(shap_values1["Y0"]["orange"].data,
shap_values2["Y0"]["orange"].data)
np.testing.assert_allclose(shap_values1["Y0"]["orange"].values,

econml/tests/test_tree.py Normal file

@ -0,0 +1,297 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import unittest
import logging
import time
import random
import numpy as np
import sparse as sp
import pytest
from econml.tree import DepthFirstTreeBuilder, BestSplitter, Tree, MSE
class TestTree(unittest.TestCase):
def _get_base_config(self):
n_features = 2
n_samples_train = 10
n_y = 1
return {'n_features': n_features,
'n_y': n_y,
'n_outputs': n_y,
'n_relevant_outputs': n_y,
'store_jac': False,
'n_samples': n_samples_train,
'n_samples_train': n_samples_train,
'max_features': n_features,
'min_samples_split': 2,
'min_samples_leaf': 1,
'min_weight_leaf': 1,
'min_eig_leaf': -1,
'min_eig_leaf_on_val': False,
'min_balancedness_tol': .3,
'max_depth': 2,
'min_impurity_decrease': 0.0,
'honest': False,
'random_state': 123,
'max_node_samples': n_samples_train,
'samples_train': np.arange(n_samples_train, dtype=np.intp),
'samples_val': np.arange(n_samples_train, dtype=np.intp)
}
def _get_base_honest_config(self):
n_features = 2
n_samples_train = 10
n_y = 1
return {'n_features': n_features,
'n_y': n_y,
'n_outputs': n_y,
'n_relevant_outputs': n_y,
'store_jac': False,
'n_samples': 2 * n_samples_train,
'n_samples_train': n_samples_train,
'max_features': n_features,
'min_samples_split': 2,
'min_samples_leaf': 1,
'min_weight_leaf': 1,
'min_eig_leaf': -1,
'min_eig_leaf_on_val': False,
'min_balancedness_tol': .3,
'max_depth': 2,
'min_impurity_decrease': 0.0,
'honest': True,
'random_state': 123,
'max_node_samples': n_samples_train,
'samples_train': np.arange(n_samples_train, dtype=np.intp),
'samples_val': np.arange(n_samples_train, 2 * n_samples_train, dtype=np.intp)
}
def _get_cython_objects(self, *, n_features, n_y, n_outputs, n_relevant_outputs,
store_jac, n_samples, n_samples_train, max_features,
min_samples_split, min_samples_leaf, min_weight_leaf,
min_eig_leaf, min_eig_leaf_on_val, min_balancedness_tol, max_depth, min_impurity_decrease,
honest, random_state, max_node_samples, samples_train,
samples_val):
tree = Tree(n_features, n_outputs, n_relevant_outputs, store_jac)
criterion = MSE(n_outputs, n_relevant_outputs, n_features, n_y,
n_samples, max_node_samples, random_state)
criterion_val = MSE(n_outputs, n_relevant_outputs, n_features, n_y,
n_samples, max_node_samples, random_state)
splitter = BestSplitter(criterion, criterion_val,
max_features, min_samples_leaf, min_weight_leaf,
min_balancedness_tol, honest, min_eig_leaf, min_eig_leaf_on_val, random_state)
builder = DepthFirstTreeBuilder(splitter, min_samples_split,
min_samples_leaf, min_weight_leaf,
max_depth, min_impurity_decrease)
return tree, criterion, criterion_val, splitter, builder
def _get_continuous_data(self, config):
X = np.zeros((config['n_samples_train'], config['n_features']))
X[:, 0] = np.arange(X.shape[0])
X[:, 1] = np.random.RandomState(config['random_state']).normal(0, 1, size=(X.shape[0]))
y = 1.0 * (X[:, 0] >= config['n_samples_train'] / 2).reshape(-1, 1)
y += 1.0 * (X[:, 0] >= config['n_samples_train'] / 4).reshape(-1, 1)
y += 1.0 * (X[:, 0] >= 3 * config['n_samples_train'] / 4).reshape(-1, 1)
X = np.vstack([X, X])
y = np.vstack([y, y])
return X, y
def _get_binary_data(self, config):
n_samples_train = config['n_samples_train']
X = np.zeros((n_samples_train, config['n_features']))
X[:n_samples_train // 2, 0] = 1
X[:n_samples_train // 4, 1] = 1
X[3 * n_samples_train // 4:, 1] = 1
y = 1.0 * (X[:, 0] + X[:, 1]).reshape(-1, 1)
X = np.vstack([X, X])
y = np.vstack([y, y])
return X, y
def _train_tree(self, config, X, y):
tree, criterion, criterion_val, splitter, builder = self._get_cython_objects(**config)
builder.build(tree, X, y,
config['samples_train'],
config['samples_val'],
store_jac=config['store_jac'])
return tree
def _test_tree_continuous(self, base_config_gen):
config = base_config_gen()
X, y = self._get_continuous_data(config)
tree = self._train_tree(config, X, y)
np.testing.assert_array_equal(tree.feature, np.array([0, 0, -2, -2, 0, -2, -2]))
np.testing.assert_array_equal(tree.threshold, np.array([4.5, 2.5, - 2, -2, 7.5, -2, -2]))
np.testing.assert_array_equal(tree.value.flatten()[:3],
np.array([np.mean(y),
np.mean(y[X[:, tree.feature[0]] < tree.threshold[0]]),
np.mean(y[(X[:, tree.feature[0]] < tree.threshold[0]) &
(X[:, tree.feature[1]] < tree.threshold[1])])]))
np.testing.assert_array_almost_equal(tree.predict(X), y, decimal=10)
with np.testing.assert_raises(AttributeError):
tree.predict_precond(X)
with np.testing.assert_raises(AttributeError):
tree.predict_jac(X)
with np.testing.assert_raises(AttributeError):
tree.predict_precond_and_jac(X)
less = X[:, tree.feature[0]] < tree.threshold[0]
# testing importances
feature_importances = np.zeros(X.shape[1])
feature_importances[0] = np.var(y)
np.testing.assert_array_almost_equal(tree.compute_feature_importances(normalize=False),
feature_importances, decimal=10)
feature_importances = np.zeros(X.shape[1])
feature_importances[0] = np.var(y) - np.var(y[less])
np.testing.assert_array_almost_equal(tree.compute_feature_importances(normalize=False, max_depth=0),
feature_importances, decimal=10)
feature_importances = np.zeros(X.shape[1])
feature_importances[0] = np.var(y) - np.var(y[less]) + .5 * (np.var(y[less]))
np.testing.assert_array_almost_equal(tree.compute_feature_importances(normalize=False,
max_depth=1, depth_decay=1.0),
feature_importances, decimal=10)
# testing heterogeneity importances
feature_importances = np.zeros(X.shape[1])
feature_importances[0] = 5 * 5 * (np.mean(y[less]) - np.mean(y[~less]))**2 / 100
np.testing.assert_array_almost_equal(tree.compute_feature_heterogeneity_importances(normalize=False,
max_depth=0),
feature_importances, decimal=10)
feature_importances[0] += .5 * (2 * 2 * 3 * (1)**2 / 5) / 10
np.testing.assert_array_almost_equal(tree.compute_feature_heterogeneity_importances(normalize=False,
max_depth=1,
depth_decay=1.0),
feature_importances, decimal=10)
feature_importances[0] += .5 * (2 * 2 * 3 * (1)**2 / 5) / 10
np.testing.assert_array_almost_equal(tree.compute_feature_heterogeneity_importances(normalize=False),
feature_importances, decimal=10)
# Testing that all parameters do what they are supposed to
config = base_config_gen()
config['min_samples_leaf'] = 5
tree = self._train_tree(config, X, y)
np.testing.assert_array_equal(tree.feature, np.array([0, -2, -2, ]))
np.testing.assert_array_equal(tree.threshold, np.array([4.5, -2, -2]))
config = base_config_gen()
config['min_samples_split'] = 11
tree = self._train_tree(config, X, y)
np.testing.assert_array_equal(tree.feature, np.array([-2]))
np.testing.assert_array_equal(tree.threshold, np.array([-2]))
np.testing.assert_array_almost_equal(tree.predict(X), np.mean(y), decimal=10)
np.testing.assert_array_almost_equal(tree.predict_full(X), np.mean(y), decimal=10)
config = base_config_gen()
config['min_weight_leaf'] = 5
tree = self._train_tree(config, X, y)
np.testing.assert_array_equal(tree.feature, np.array([0, -2, -2, ]))
np.testing.assert_array_equal(tree.threshold, np.array([4.5, -2, -2]))
# testing predict, apply and decision path
less = X[:, tree.feature[0]] < tree.threshold[0]
y_pred = np.zeros((X.shape[0], 1))
y_pred[less] = np.mean(y[less])
y_pred[~less] = np.mean(y[~less])
np.testing.assert_array_almost_equal(tree.predict(X), y_pred, decimal=10)
np.testing.assert_array_almost_equal(tree.predict_full(X), y_pred, decimal=10)
decision_path = np.zeros((X.shape[0], len(tree.feature)))
decision_path[less, :] = np.array([1, 1, 0])
decision_path[~less, :] = np.array([1, 0, 1])
np.testing.assert_array_equal(tree.decision_path(X).todense(), decision_path)
apply = np.zeros(X.shape[0])
apply[less] = 1
apply[~less] = 2
np.testing.assert_array_equal(tree.apply(X), apply)
feature_importances = np.zeros(X.shape[1])
feature_importances[0] = 1
np.testing.assert_array_equal(tree.compute_feature_importances(),
feature_importances)
config = base_config_gen()
config['min_balancedness_tol'] = .0
tree = self._train_tree(config, X, y)
np.testing.assert_array_equal(tree.feature, np.array([0, -2, -2, ]))
np.testing.assert_array_equal(tree.threshold, np.array([4.5, -2, -2]))
config = base_config_gen()
config['min_balancedness_tol'] = .1
tree = self._train_tree(config, X, y)
np.testing.assert_array_equal(tree.feature, np.array([0, 0, -2, -2, 0, -2, -2]))
np.testing.assert_array_equal(tree.threshold, np.array([4.5, 2.5, - 2, -2, 7.5, -2, -2]))
config = base_config_gen()
config['max_depth'] = 1
tree = self._train_tree(config, X, y)
np.testing.assert_array_equal(tree.feature, np.array([0, -2, -2, ]))
np.testing.assert_array_equal(tree.threshold, np.array([4.5, -2, -2]))
config = base_config_gen()
config['min_impurity_decrease'] = .99999
tree = self._train_tree(config, X, y)
np.testing.assert_array_equal(tree.feature, np.array([0, -2, -2, ]))
np.testing.assert_array_equal(tree.threshold, np.array([4.5, -2, -2]))
config = base_config_gen()
config['min_impurity_decrease'] = 1.00001
tree = self._train_tree(config, X, y)
np.testing.assert_array_equal(tree.feature, np.array([-2, ]))
np.testing.assert_array_equal(tree.threshold, np.array([-2, ]))
def test_dishonest_tree(self):
self._test_tree_continuous(self._get_base_config)
def test_honest_tree(self):
self._test_tree_continuous(self._get_base_honest_config)
def test_multivariable_split(self):
config = self._get_base_config()
X, y = self._get_binary_data(config)
tree = self._train_tree(config, X, y)
np.testing.assert_array_equal(tree.feature, np.array([0, 1, -2, -2, 1, -2, -2]))
np.testing.assert_array_equal(tree.threshold, np.array([0.5, 0.5, - 2, -2, 0.5, -2, -2]))
def test_honest_values(self):
config = self._get_base_honest_config()
X, y = self._get_binary_data(config)
y[config['n_samples_train']:] = .4
tree = self._train_tree(config, X, y)
np.testing.assert_array_equal(tree.feature, np.array([0, 1, -2, -2, 1, -2, -2]))
np.testing.assert_array_equal(tree.threshold, np.array([0.5, 0.5, - 2, -2, 0.5, -2, -2]))
np.testing.assert_array_almost_equal(tree.value.flatten(), .4 * np.ones(len(tree.value)))
def test_noisy_instance(self):
n_samples = 5000
X = np.random.normal(0, 1, size=(n_samples, 1))
y_base = 1.0 * X[:, [0]] * (X[:, [0]] > 0)
y = y_base + np.random.normal(0, .1, size=(n_samples, 1))
config = self._get_base_config()
config['n_features'] = 1
config['max_features'] = 1
config['max_depth'] = 10
config['min_samples_leaf'] = 20
config['n_samples'] = X.shape[0]
config['min_balancedness_tol'] = .5
config['n_samples_train'] = X.shape[0]
config['max_node_samples'] = X.shape[0]
config['samples_train'] = np.arange(X.shape[0], dtype=np.intp)
config['samples_val'] = np.arange(X.shape[0], dtype=np.intp)
tree = self._train_tree(config, X, y)
X_test = np.zeros((100, 1))
X_test[:, 0] = np.linspace(np.percentile(X, 10), np.percentile(X, 90), 100)
y_test = 1.0 * X_test[:, [0]] * (X_test[:, [0]] > 0)
np.testing.assert_array_almost_equal(tree.predict(X_test), y_test, decimal=1)
config = self._get_base_honest_config()
config['n_features'] = 1
config['max_features'] = 1
config['max_depth'] = 10
config['min_samples_leaf'] = 20
config['n_samples'] = X.shape[0]
config['min_balancedness_tol'] = .5
config['n_samples_train'] = X.shape[0] // 2
config['max_node_samples'] = X.shape[0] // 2
config['samples_train'] = np.arange(X.shape[0] // 2, dtype=np.intp)
config['samples_val'] = np.arange(X.shape[0] // 2, X.shape[0], dtype=np.intp)
tree = self._train_tree(config, X, y)
X_test = np.zeros((100, 1))
X_test[:, 0] = np.linspace(np.percentile(X, 10), np.percentile(X, 90), 100)
y_test = 1.0 * X_test[:, [0]] * (X_test[:, [0]] > 0)
np.testing.assert_array_almost_equal(tree.predict(X_test), y_test, decimal=1)
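The two configs above differ only in how samples are assigned: in the honest case the split structure is chosen on samples_train while node values (and hence predictions) come from the disjoint samples_val half, which is why test_honest_values recovers .4 even though the train half carries different labels. A plain-numpy illustration of that estimation scheme for a single split, not using the Cython classes and with made-up data:
import numpy as np

rng = np.random.default_rng(123)
X = rng.normal(size=(200, 1))
y = 1.0 * (X[:, 0] > 0) + rng.normal(scale=.1, size=200)

train, val = np.arange(100), np.arange(100, 200)   # disjoint halves

def sse(thr):
    # squared error of a split at thr, evaluated on the train half only
    left = y[train][X[train, 0] <= thr]
    right = y[train][X[train, 0] > thr]
    if len(left) == 0 or len(right) == 0:
        return np.inf
    return ((left - left.mean()) ** 2).sum() + ((right - right.mean()) ** 2).sum()

best_thr = min(np.unique(X[train, 0]), key=sse)    # structure learned on train

# "honest" leaf values: averages over the *validation* half
left_value = y[val][X[val, 0] <= best_thr].mean()
right_value = y[val][X[val, 0] > best_thr].mean()
print(best_thr, left_value, right_value)           # threshold near 0, values near 0 and 1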

econml/tree/__init__.py Normal file

@ -0,0 +1,15 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from ._criterion import Criterion, RegressionCriterion, MSE
from ._splitter import Splitter, BestSplitter
from ._tree import DepthFirstTreeBuilder
from ._tree import Tree
__all__ = ["Tree",
"Splitter",
"BestSplitter",
"DepthFirstTreeBuilder",
"Criterion",
"RegressionCriterion",
"MSE"]

econml/tree/_criterion.pxd Normal file

@ -0,0 +1,90 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# This code is a fork from: https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_criterion.pxd
# published under the following license and copyright:
# BSD 3-Clause License
#
# Copyright (c) 2007-2020 The scikit-learn developers.
# All rights reserved.
# See _criterion.pyx for implementation details.
import numpy as np
cimport numpy as np
from ._tree cimport DOUBLE_t # Type of y, sample_weight
from ._tree cimport SIZE_t # Type for indices and counters
from ._tree cimport UINT32_t # Unsigned 32 bit integer
cdef class Criterion:
# The criterion computes the impurity of a node and the reduction of
# impurity of a split on that node. It also computes the output statistics
# such as the mean in regression and class probabilities in classification
# and parameter estimates in a tree that solves a moment equation.
# Internal structures
cdef bint proxy_children_impurity # Whether the value returned by children_impurity is only an approximation
cdef const DOUBLE_t[:, ::1] y # Values of y (y contains all the variables for node parameter estimation)
cdef DOUBLE_t* sample_weight # Sample weights
cdef SIZE_t n_outputs # Number of outputs
cdef SIZE_t n_relevant_outputs # The first n_relevant_outputs are the ones we care about
cdef SIZE_t n_features # Number of features
cdef SIZE_t n_y # The first n_y columns of the y matrix correspond to raw labels.
# The remainder are auxiliary variables required for parameter estimation
cdef UINT32_t random_state # A random seed for any internal randomness
cdef SIZE_t* samples # Sample indices in X, y
cdef SIZE_t start # samples[start:pos] are the samples in the left node
cdef SIZE_t pos # samples[pos:end] are the samples in the right node
cdef SIZE_t end
cdef SIZE_t n_samples # Number of all samples in y (i.e. rows of y)
cdef SIZE_t max_node_samples # The maximum number of samples that can ever be contained in a node
# Used for memory space saving, as we need to allocate memory space for
# internal quantities that will store as many values as the number of samples
# in the current node under consideration. Providing this can save space
# allocation time.
cdef SIZE_t n_node_samples # Number of samples in the node (end-start)
cdef double weighted_n_samples # Weighted number of samples (in total)
cdef double weighted_n_node_samples # Weighted number of samples in the node
cdef double weighted_n_left # Weighted number of samples in the left node
cdef double weighted_n_right # Weighted number of samples in the right node
cdef double* sum_total # For classification criteria, the sum of the
# weighted count of each label. For regression,
# the sum of w*y. sum_total[k] is equal to
# sum_{i=start}^{end-1} w[samples[i]]*y[samples[i], k],
# where k is output index.
cdef double* sum_left # Same as above, but for the left side of the split
cdef double* sum_right # same as above, but for the right side of the split
# The criterion object is maintained such that left and right collected
# statistics correspond to samples[start:pos] and samples[pos:end].
# Methods
cdef int init(self, const DOUBLE_t[:, ::1] y,
DOUBLE_t* sample_weight, double weighted_n_samples,
SIZE_t* samples) nogil except -1
cdef int node_reset(self, SIZE_t start, SIZE_t end) nogil except -1
cdef int reset(self) nogil except -1
cdef int reverse_reset(self) nogil except -1
cdef int update(self, SIZE_t new_pos) nogil except -1
cdef double node_impurity(self) nogil
cdef double proxy_node_impurity(self) nogil
cdef void children_impurity(self, double* impurity_left,
double* impurity_right) nogil
cdef void node_value(self, double* dest) nogil
cdef void node_jacobian(self, double* dest) nogil
cdef void node_precond(self, double* dest) nogil
cdef double impurity_improvement(self, double impurity) nogil
cdef double proxy_impurity_improvement(self) nogil
cdef double min_eig_left(self) nogil
cdef double min_eig_right(self) nogil
cdef class RegressionCriterion(Criterion):
"""Abstract regression criterion."""
cdef double sq_sum_total # Stores sum_i sum_k y_{ik}^2, used for MSE calculation
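The accumulators declared above are all the criterion needs: with W = weighted_n_node_samples, the MSE node impurity defined in _criterion.pyx below reduces to (sq_sum_total / W - sum_k (sum_total[k] / W)**2) / n_outputs, i.e. the per-output weighted variance averaged over outputs. A quick numpy check of that identity on toy data:
import numpy as np

rng = np.random.default_rng(0)
y = rng.normal(size=(20, 3))                   # 20 node samples, 3 outputs
w = rng.uniform(.5, 2.0, size=20)              # sample weights

W = w.sum()                                    # weighted_n_node_samples
sum_total = (w[:, None] * y).sum(axis=0)       # sum_i w_i * y_ik per output
sq_sum_total = (w[:, None] * y ** 2).sum()     # sum_ik w_i * y_ik^2

impurity = (sq_sum_total / W - ((sum_total / W) ** 2).sum()) / y.shape[1]

# the same quantity computed directly as the average weighted variance per output
mean = sum_total / W
direct = np.mean([(w * (y[:, k] - mean[k]) ** 2).sum() / W for k in range(y.shape[1])])
np.testing.assert_allclose(impurity, direct)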

econml/tree/_criterion.pyx Normal file

@ -0,0 +1,551 @@
# cython: cdivision=True
# cython: boundscheck=False
# cython: wraparound=False
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# This code is a fork from: https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_criterion.pyx
# published under the following license and copyright:
# BSD 3-Clause License
#
# Copyright (c) 2007-2020 The scikit-learn developers.
# All rights reserved.
from libc.stdlib cimport calloc
from libc.stdlib cimport free
from libc.string cimport memcpy
from libc.string cimport memset
from libc.math cimport fabs
import numpy as np
cimport numpy as np
np.import_array()
from ._utils cimport log
from ._utils cimport safe_realloc
from ._utils cimport sizet_ptr_to_ndarray
cdef class Criterion:
"""Interface for impurity criteria.
This object stores methods on how to calculate how good a split is using
different metrics.
"""
def __dealloc__(self):
"""Destructor."""
free(self.sum_total)
free(self.sum_left)
free(self.sum_right)
def __getstate__(self):
return {}
def __setstate__(self, d):
pass
cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* sample_weight,
double weighted_n_samples,
SIZE_t* samples) nogil except -1:
"""Placeholder for a method which will initialize the criterion.
Returns -1 in case of failure to allocate memory (and raise MemoryError)
or 0 otherwise.
Parameters
----------
y : array-like, dtype=DOUBLE_t
y is a buffer that can store values for variables required for parameter/value estimation
sample_weight : array-like, dtype=DOUBLE_t
The weight of each sample in y
weighted_n_samples : double
The total weight of all the samples whose indices are contained in the samples array
samples : array-like, dtype=SIZE_t
Indices of the samples in X and y, where samples[start:end]
correspond to the samples in this node
"""
pass
cdef int node_reset(self, SIZE_t start, SIZE_t end) nogil except -1:
""" Initialize a node calculation
Parameters
----------
start : SIZE_t
The first sample to be used on this node
end : SIZE_t
The last sample used on this node
"""
pass
cdef int reset(self) nogil except -1:
"""Reset the criterion at pos=start.
This method must be implemented by the subclass.
"""
pass
cdef int reverse_reset(self) nogil except -1:
"""Reset the criterion at pos=end.
This method must be implemented by the subclass.
"""
pass
cdef int update(self, SIZE_t new_pos) nogil except -1:
"""Updated statistics by moving samples[pos:new_pos] to the left child.
This updates the collected statistics by moving samples[pos:new_pos]
from the right child to the left child. It must be implemented by
the subclass.
Parameters
----------
new_pos : SIZE_t
New starting index position of the samples in the right child
"""
pass
cdef double node_impurity(self) nogil:
"""Placeholder for calculating the impurity of the node.
Placeholder for a method which will evaluate the impurity of
the current node, i.e. the impurity of samples[start:end]. This is the
primary function of the criterion class.
"""
pass
cdef double proxy_node_impurity(self) nogil:
""" A proxy for the node impurity to be used for min_impurity_decrease.
By default it is equivalent to node_impurity, unless overwritten by child class.
"""
return self.node_impurity()
cdef void children_impurity(self, double* impurity_left,
double* impurity_right) nogil:
"""Placeholder for calculating the impurity of children.
Placeholder for a method which evaluates the impurity in
children nodes, i.e. the impurity of samples[start:pos] + the impurity
of samples[pos:end].
Parameters
----------
impurity_left : double pointer
The memory address where the impurity of the left child should be
stored.
impurity_right : double pointer
The memory address where the impurity of the right child should be
stored
"""
pass
cdef void node_value(self, double* dest) nogil:
"""Placeholder for storing the node value.
Placeholder for a method which will compute the node value
of samples[start:end] and save the value into dest.
Parameters
----------
dest : double pointer
The memory address where the node value should be stored.
"""
pass
cdef void node_jacobian(self, double* dest) nogil:
"""Placeholder for storing the node jacobian value.
Placeholder for a method which will compute the node jacobian value in a linear
moment J(x) * theta(x) - precond(x) = 0 of samples[start:end] and save the value
into dest. If not implemented by child, raises an AttributeError if called.
Parameters
----------
dest : double pointer
The memory address where the node jacobian should be stored.
"""
with gil:
raise AttributeError("Criterion does not support jacobian calculation")
cdef void node_precond(self, double* dest) nogil:
"""Placeholder for storing the node precond value.
Placeholder for a method which will compute the node precond value in a linear
moment J(x) * theta(x) - precond(x) = 0 of samples[start:end] and save the value
into dest. If not implemented by child, raises an AttributeError if called.
Parameters
----------
dest : double pointer
The memory address where the node precond should be stored.
"""
with gil:
raise AttributeError("Criterion does not support preconditioned value calculation")
cdef double min_eig_left(self) nogil:
"""Placeholder for calculating proxy for minimum eigenvalue of the jacobian
of the left child of the current split. If not implemented by child, raises an AttributeError if called.
"""
with gil:
raise AttributeError("Criterion does not support jacobian and eigenvalue calculation!")
cdef double min_eig_right(self) nogil:
"""Placeholder for calculating proxy for minimum eigenvalue of the jacobian
of the right child of the current split. If not implemented by child, raises an AttributeError if called.
"""
with gil:
raise AttributeError("Criterion does not support jacobian and eigenvalue calculation!")
cdef double proxy_impurity_improvement(self) nogil:
"""Compute a proxy of the impurity reduction
This method is used to speed up the search for the best split.
It is a proxy quantity such that the split that maximizes this value
also maximizes the impurity improvement. It neglects all constant terms
of the impurity decrease for a given split.
The absolute impurity improvement is only computed by the
impurity_improvement method once the best split has been found.
"""
cdef double impurity_left
cdef double impurity_right
self.children_impurity(&impurity_left, &impurity_right)
return (- self.weighted_n_right * impurity_right
- self.weighted_n_left * impurity_left)
cdef double impurity_improvement(self, double impurity) nogil:
"""Compute the improvement in impurity
This method computes the improvement in impurity when a split occurs.
The weighted impurity improvement equation is the following:
N_t / N * (impurity - N_t_R / N_t * right_impurity
- N_t_L / N_t * left_impurity)
where N is the total number of samples, N_t is the number of samples
at the current node, N_t_L is the number of samples in the left child,
and N_t_R is the number of samples in the right child,
Parameters
----------
impurity : double
The initial impurity of the node before the split
Return
------
double : improvement in impurity after the split occurs
"""
cdef double impurity_left
cdef double impurity_right
self.children_impurity(&impurity_left, &impurity_right)
return ((self.weighted_n_node_samples / self.weighted_n_samples) *
(impurity - (self.weighted_n_right /
self.weighted_n_node_samples * impurity_right)
- (self.weighted_n_left /
self.weighted_n_node_samples * impurity_left)))
# =============================================================================
# Regression Criterion
# =============================================================================
cdef class RegressionCriterion(Criterion):
r"""Abstract regression criterion.
This handles cases where the target is a continuous value, and is
evaluated by computing the variance of the target values left and right
of the split point. The computation takes linear time with `n_samples`
by using ::
var = \sum_i^n (y_i - y_bar) ** 2
= (\sum_i^n y_i ** 2) - n_samples * y_bar ** 2
"""
def __cinit__(self, SIZE_t n_outputs, SIZE_t n_relevant_outputs, SIZE_t n_features, SIZE_t n_y,
SIZE_t n_samples, SIZE_t max_node_samples, UINT32_t random_state):
"""Initialize parameters for this criterion.
Parameters
----------
n_outputs : SIZE_t
The number of parameters/values to be estimated
n_relevant_outputs : SIZE_t
We only care about the first n_relevant_outputs of these parameters/values
n_features : SIZE_t
The number of features
n_y : SIZE_t
The first n_y columns of the 2d matrix y, contain the raw labels y_{ik}, the rest are auxiliary variables
n_samples : SIZE_t
The total number of rows in the 2d matrix y
max_node_samples : SIZE_t
The maximum number of samples that can ever be contained in a node
random_state : UINT32_t
A random seed for any internal randomness
"""
# Default values
self.n_outputs = n_outputs
self.n_relevant_outputs = n_relevant_outputs
self.n_features = n_features
self.n_y = n_y
self.random_state = random_state
self.proxy_children_impurity = False
self.samples = NULL
self.start = 0
self.pos = 0
self.end = 0
self.n_samples = n_samples
self.max_node_samples = max_node_samples
self.n_node_samples = 0
self.weighted_n_node_samples = 0.0
self.weighted_n_left = 0.0
self.weighted_n_right = 0.0
self.sq_sum_total = 0.0
# Allocate accumulators. Make sure they are NULL, not uninitialized,
# before an exception can be raised (which triggers __dealloc__).
self.sum_total = NULL
self.sum_left = NULL
self.sum_right = NULL
# Allocate memory for the accumulators
self.sum_total = <double*> calloc(n_outputs, sizeof(double))
self.sum_left = <double*> calloc(n_outputs, sizeof(double))
self.sum_right = <double*> calloc(n_outputs, sizeof(double))
if (self.sum_total == NULL or
self.sum_left == NULL or
self.sum_right == NULL):
raise MemoryError()
def __reduce__(self):
return (type(self), (self.n_outputs, self.n_relevant_outputs, self.n_features, self.n_y,
self.n_samples, self.max_node_samples, self.random_state), self.__getstate__())
cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* sample_weight,
double weighted_n_samples,
SIZE_t* samples) nogil except -1:
# Initialize fields
self.y = y
self.sample_weight = sample_weight
self.samples = samples
self.weighted_n_samples = weighted_n_samples
return 0
cdef int node_reset(self, SIZE_t start, SIZE_t end) nogil except -1:
"""Initialize the criterion at node samples[start:end] and
children samples[start:start] and samples[start:end]."""
self.start = start
self.end = end
self.n_node_samples = end - start
self.weighted_n_node_samples = 0.
cdef SIZE_t i
cdef SIZE_t p
cdef SIZE_t k
cdef DOUBLE_t y_ik
cdef DOUBLE_t w_y_ik
cdef DOUBLE_t w = 1.0
self.sq_sum_total = 0.0
memset(self.sum_total, 0, self.n_outputs * sizeof(double))
for p in range(start, end):
i = self.samples[p]
if self.sample_weight != NULL:
w = self.sample_weight[i]
for k in range(self.n_outputs):
y_ik = self.y[i, k]
w_y_ik = w * y_ik
self.sum_total[k] += w_y_ik
self.sq_sum_total += w_y_ik * y_ik
self.weighted_n_node_samples += w
# Reset to pos=start
self.reset()
return 0
cdef int reset(self) nogil except -1:
"""Reset the criterion at pos=start."""
cdef SIZE_t n_bytes = self.n_outputs * sizeof(double)
memset(self.sum_left, 0, n_bytes)
memcpy(self.sum_right, self.sum_total, n_bytes)
self.weighted_n_left = 0
self.weighted_n_right = self.weighted_n_node_samples
self.pos = self.start
return 0
cdef int reverse_reset(self) nogil except -1:
"""Reset the criterion at pos=end."""
cdef SIZE_t n_bytes = self.n_outputs * sizeof(double)
memset(self.sum_right, 0, n_bytes)
memcpy(self.sum_left, self.sum_total, n_bytes)
self.weighted_n_right = 0
self.weighted_n_left = self.weighted_n_node_samples
self.pos = self.end
return 0
cdef int update(self, SIZE_t new_pos) nogil except -1:
"""Updated statistics by moving samples[pos:new_pos] to the left."""
cdef double* sum_left = self.sum_left
cdef double* sum_right = self.sum_right
cdef double* sum_total = self.sum_total
cdef SIZE_t* samples = self.samples
cdef DOUBLE_t* sample_weight = self.sample_weight
cdef SIZE_t pos = self.pos
cdef SIZE_t end = self.end
cdef SIZE_t i
cdef SIZE_t p
cdef SIZE_t k
cdef DOUBLE_t w = 1.0
# Update statistics up to new_pos
#
# Given that
# sum_left[x] + sum_right[x] = sum_total[x]
# and that sum_total is known, we are going to update
# sum_left from the direction that require the least amount
# of computations, i.e. from pos to new_pos or from end to new_pos.
if (new_pos - pos) <= (end - new_pos):
for p in range(pos, new_pos):
i = samples[p]
if sample_weight != NULL:
w = sample_weight[i]
for k in range(self.n_outputs):
sum_left[k] += w * self.y[i, k]
self.weighted_n_left += w
else:
self.reverse_reset()
for p in range(end - 1, new_pos - 1, -1):
i = samples[p]
if sample_weight != NULL:
w = sample_weight[i]
for k in range(self.n_outputs):
sum_left[k] -= w * self.y[i, k]
self.weighted_n_left -= w
self.weighted_n_right = (self.weighted_n_node_samples -
self.weighted_n_left)
for k in range(self.n_outputs):
sum_right[k] = sum_total[k] - sum_left[k]
self.pos = new_pos
return 0
cdef double node_impurity(self) nogil:
pass
cdef void children_impurity(self, double* impurity_left,
double* impurity_right) nogil:
pass
cdef void node_value(self, double* dest) nogil:
"""Compute the node value of samples[start:end] into dest."""
cdef SIZE_t k
for k in range(self.n_outputs):
dest[k] = self.sum_total[k] / self.weighted_n_node_samples
cdef class MSE(RegressionCriterion):
"""Mean squared error impurity criterion.
MSE = var_left + var_right
"""
cdef double node_impurity(self) nogil:
"""Evaluate the impurity of the current node, i.e. the impurity of
samples[start:end]."""
cdef double* sum_total = self.sum_total
cdef double impurity
cdef SIZE_t k
impurity = self.sq_sum_total / self.weighted_n_node_samples
for k in range(self.n_outputs):
impurity -= (sum_total[k] / self.weighted_n_node_samples)**2.0
return impurity / self.n_outputs
cdef double proxy_impurity_improvement(self) nogil:
"""Compute a proxy of the impurity reduction
This method is used to speed up the search for the best split.
It is a proxy quantity such that the split that maximizes this value
also maximizes the impurity improvement. It neglects all constant terms
of the impurity decrease for a given split.
The absolute impurity improvement is only computed by the
impurity_improvement method once the best split has been found.
"""
cdef double* sum_left = self.sum_left
cdef double* sum_right = self.sum_right
cdef SIZE_t k
cdef double proxy_impurity_left = 0.0
cdef double proxy_impurity_right = 0.0
for k in range(self.n_outputs):
proxy_impurity_left += sum_left[k] * sum_left[k]
proxy_impurity_right += sum_right[k] * sum_right[k]
return (proxy_impurity_left / self.weighted_n_left +
proxy_impurity_right / self.weighted_n_right)
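The proxy above drops every term of the exact weighted impurity improvement that does not depend on the split position (the parent impurity, sq_sum_total and the node/total weight ratio), so ranking candidate splits by the proxy is equivalent to ranking them by impurity_improvement. A small numpy sketch checking that the two orderings agree on a toy node, single output and unit weights:
import numpy as np

# node samples already ordered by the candidate feature
y = np.array([0., 0., 1., 1., 1., 3., 3., 4.])
n = len(y)

def exact_improvement(pos):
    # weighted_n_node / weighted_n_samples is constant within a node, so it is omitted
    left, right = y[:pos], y[pos:]
    return y.var() - len(right) / n * right.var() - len(left) / n * left.var()

def mse_proxy(pos):
    left, right = y[:pos], y[pos:]
    return left.sum() ** 2 / len(left) + right.sum() ** 2 / len(right)

positions = list(range(1, n))
assert max(positions, key=exact_improvement) == max(positions, key=mse_proxy)
print(max(positions, key=mse_proxy))   # both criteria pick the same split position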
cdef void children_impurity(self, double* impurity_left,
double* impurity_right) nogil:
"""Evaluate the impurity in children nodes, i.e. the impurity of the
left child (samples[start:pos]) and the impurity the right child
(samples[pos:end])."""
cdef DOUBLE_t* sample_weight = self.sample_weight
cdef SIZE_t* samples = self.samples
cdef SIZE_t pos = self.pos
cdef SIZE_t start = self.start
cdef double* sum_left = self.sum_left
cdef double* sum_right = self.sum_right
cdef DOUBLE_t y_ik
cdef double sq_sum_left = 0.0
cdef double sq_sum_right
cdef SIZE_t i
cdef SIZE_t p
cdef SIZE_t k
cdef DOUBLE_t w = 1.0
for p in range(start, pos):
i = samples[p]
if sample_weight != NULL:
w = sample_weight[i]
for k in range(self.n_outputs):
y_ik = self.y[i, k]
sq_sum_left += w * y_ik * y_ik
sq_sum_right = self.sq_sum_total - sq_sum_left
impurity_left[0] = sq_sum_left / self.weighted_n_left
impurity_right[0] = sq_sum_right / self.weighted_n_right
for k in range(self.n_outputs):
impurity_left[0] -= (sum_left[k] / self.weighted_n_left) ** 2.0
impurity_right[0] -= (sum_right[k] / self.weighted_n_right) ** 2.0
impurity_left[0] /= self.n_outputs
impurity_right[0] /= self.n_outputs

econml/tree/_splitter.pxd Normal file

@ -0,0 +1,129 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# This code is a fork from: https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_splitter.pxd
# published under the following license and copyright:
# BSD 3-Clause License
#
# Copyright (c) 2007-2020 The scikit-learn developers.
# All rights reserved.
# See _splitter.pyx for details.
import numpy as np
cimport numpy as np
from ._criterion cimport Criterion
from ._tree cimport DTYPE_t # Type of X
from ._tree cimport DOUBLE_t # Type of y, sample_weight
from ._tree cimport SIZE_t # Type for indices and counters
from ._tree cimport INT32_t # Signed 32 bit integer
from ._tree cimport UINT32_t # Unsigned 32 bit integer
cdef struct SplitRecord:
# Data to track sample split
SIZE_t feature # Which feature to split on.
SIZE_t pos # Split samples array at the given position,
# i.e. count of samples below threshold for feature.
# pos is >= end if the node is a leaf.
SIZE_t pos_val # Split samples_val array at the given position,
# i.e. count of samples below threshold for feature.
# pos_val is >= end_val if the node is a leaf.
double threshold # Threshold to split at.
double improvement # Impurity improvement given parent node.
double impurity_left # Impurity of the left split on train set.
double impurity_right # Impurity of the right split on train set.
double impurity_left_val # Impurity of the left split on validation set.
double impurity_right_val # Impurity of the right split on validation set.
cdef class Splitter:
# The splitter searches in the input space for a feature and a threshold
# to split the samples samples[start:end] and samples_val[start_val:end_val].
#
# The impurity and value computations are delegated to a criterion and criterion_val object.
# Internal structures
cdef public Criterion criterion # Impurity criterion for train set calculations
cdef public Criterion criterion_val # Impurity criterion for val set calculations
cdef public SIZE_t max_features # Number of features to test
cdef public SIZE_t min_samples_leaf # Min samples in a leaf (on both train and val set)
cdef public double min_weight_leaf # Minimum weight in a leaf (on both train and val set)
cdef public double min_eig_leaf # Minimum value of proxy for the min eigenvalue of the jacobian (on train)
cdef public bint min_eig_leaf_on_val # Whether minimum eigenvalue constraint should also be enforced on val
cdef public double min_balancedness_tol # Tolerance level of how balanced a split can be (in [0, .5])
cdef public bint honest # Are we doing train/val honest splitting
cdef UINT32_t rand_r_state # sklearn_rand_r random number state
cdef SIZE_t* samples # Sample indices in X, y
cdef SIZE_t n_samples # X.shape[0]
cdef double weighted_n_samples # Weighted number of samples
cdef SIZE_t* samples_val # Sample indices in Xval
cdef SIZE_t n_samples_val # Xval.shape[0]
cdef double weighted_n_samples_val # Weighted number of val samples
cdef SIZE_t* features # Feature indices in X
cdef SIZE_t* constant_features # Constant features indices
cdef SIZE_t n_features # X.shape[1]
cdef DTYPE_t* feature_values # temp. array holding feature values
cdef DTYPE_t* feature_values_val # temp. array holding feature values from validation set
cdef SIZE_t start # Start position for the current node for samples
cdef SIZE_t end # End position for the current node for samples
cdef SIZE_t start_val # Start position for the current node for samples_val
cdef SIZE_t end_val # End position for the current node for samples_val
cdef const DTYPE_t[:, :] X
cdef const DOUBLE_t[:, ::1] y
cdef DOUBLE_t* sample_weight
# The samples vector `samples` is maintained by the Splitter object such
# that the train samples contained in a node are contiguous. With this setting,
# `node_split` reorganizes the node samples `samples[start:end]` in two
# subsets `samples[start:pos]` and `samples[pos:end]`.
# When `honest=True` then we also store a samples vector `samples_val` that
# is also maintained by the Splitter that keeps indices of the val set and
# such that the val samples contained in a node are contiguous in `samples_val`.
# `node_split` also reorganizes the node samples `samples_val[start_val:end_val]` in two
# subsets `samples_val[start_val:pos_val]` and `samples_val[pos_val:end_val]`.
# The 1-d `features` array of size n_features contains the features
# indices and allows fast sampling without replacement of features.
# The 1-d `constant_features` array of size n_features holds in
# `constant_features[:n_constant_features]` the feature ids with
# constant values for all the samples in the train set that reached a specific node.
# The value `n_constant_features` is given by the parent node to its
# child nodes. The content of the range `[n_constant_features:]` is left
# undefined, but preallocated for performance reasons
# This allows optimization with depth-based tree building.
# Methods
cdef int init_sample_inds(self, SIZE_t* samples,
const SIZE_t[::1] np_samples,
DOUBLE_t* sample_weight,
SIZE_t* n_samples, double* weighted_n_samples) nogil except -1
cdef int init(self, const DTYPE_t[:, :] X, const DOUBLE_t[:, ::1] y,
DOUBLE_t* sample_weight,
const SIZE_t[::1] np_samples_train,
const SIZE_t[::1] np_samples_val) nogil except -1
cdef int node_reset(self, SIZE_t start, SIZE_t end, double* weighted_n_node_samples,
SIZE_t start_val, SIZE_t end_val, double* weighted_n_node_samples_val) nogil except -1
cdef int node_split(self,
double impurity, # Impurity of the node
SplitRecord* split,
SIZE_t* n_constant_features) nogil except -1
cdef void node_value_val(self, double* dest) nogil
cdef void node_jacobian_val(self, double* dest) nogil
cdef void node_precond_val(self, double* dest) nogil
cdef double node_impurity(self) nogil
cdef double node_impurity_val(self) nogil
cdef double proxy_node_impurity(self) nogil
cdef double proxy_node_impurity_val(self) nogil
cdef bint is_children_impurity_proxy(self) nogil
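The declarations above keep two parallel index arrays, `samples` for the train half and `samples_val` for the val half, and each node owns a contiguous range of both. A rough pure-Python sketch of the partitioning that `node_split` performs on such an index array (the helper name is hypothetical, not econml API):

import numpy as np

def partition_node(X, index, start, end, feature, threshold):
    """Reorder index[start:end] in place so that samples with
    X[i, feature] <= threshold come first; return the split position.
    With honest splitting the same threshold is applied to both the
    train index array and the val index array."""
    left = [i for i in index[start:end] if X[i, feature] <= threshold]
    right = [i for i in index[start:end] if X[i, feature] > threshold]
    index[start:end] = left + right
    return start + len(left)

# e.g. pos = partition_node(X, samples, start, end, f, t)
#      pos_val = partition_node(X, samples_val, start_val, end_val, f, t)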

761
econml/tree/_splitter.pyx Normal file

@@ -0,0 +1,761 @@
# cython: cdivision=True
# cython: boundscheck=False
# cython: wraparound=False
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# This code is a fork from: https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_splitter.pyx
# published under the following license and copyright:
# BSD 3-Clause License
#
# Copyright (c) 2007-2020 The scikit-learn developers.
# All rights reserved.
from ._criterion cimport Criterion
from libc.stdlib cimport free
from libc.string cimport memcpy
from libc.math cimport floor
import copy
import numpy as np
cimport numpy as np
np.import_array()
from ._utils cimport log
from ._utils cimport rand_int
from ._utils cimport rand_uniform
from ._utils cimport RAND_R_MAX
from ._utils cimport safe_realloc
cdef double INFINITY = np.inf
# Mitigate precision differences between 32 bit and 64 bit
cdef DTYPE_t FEATURE_THRESHOLD = 1e-7
cdef inline void _init_split(SplitRecord* self, SIZE_t start_pos, SIZE_t start_pos_val) nogil:
self.impurity_left = INFINITY
self.impurity_right = INFINITY
self.impurity_left_val = INFINITY
self.impurity_right_val = INFINITY
self.pos = start_pos
self.pos_val = start_pos_val
self.feature = 0
self.threshold = 0.
self.improvement = -INFINITY
cdef class Splitter:
"""Abstract splitter class.
Splitters are called by tree builders to find the best splits, one split at a time.
"""
def __cinit__(self, Criterion criterion, Criterion criterion_val,
SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf,
DTYPE_t min_balancedness_tol, bint honest, double min_eig_leaf, bint min_eig_leaf_on_val,
UINT32_t random_state):
"""
Parameters
----------
criterion : Criterion
The criterion to measure the quality of a split on the train set.
criterion_val : Criterion
The criterion to be used to calculate quantities related to split quality on the val set.
max_features : SIZE_t
The maximal number of randomly selected features which can be
considered for a split.
min_samples_leaf : SIZE_t
The minimal number of samples each leaf can have, where splits
which would result in having less samples in a leaf are not considered.
Constraint is enforced on both the train and val set separately.
min_weight_leaf : double
The minimal total weight of samples each leaf can have, where splits
which would result in having less weight in a leaf are not considered.
Constraint is enforced on both train and val set separately.
min_balancedness_tol : DTYPE_t
Tolerance level of how balanced a split can be (in [0, .5]) with
0 meaning split has to be fully balanced and .5 meaning no balancedness
constraint. Constraint is enforced on both train and val set separately.
honest : bint
Whether we should do honest splitting, i.e. train and val set are different.
min_eig_leaf : double
The minimum value of computationally fast proxies for the minimum eigenvalue
of the jacobian J(x) of a node in the case of linear moment equation trees:
J(x) * theta(x) - precond(x) = 0. The proxy used is deferred and must be implemented
by the criterion objects, if min_eig_leaf >= 0.0, via the methods `min_eig_left()`
and `min_eig_right()` for the minimum eigenvalue proxy of the left and right child
correspondingly.
min_eig_leaf_on_val : bool
Whether the minimum eigenvalue constraint should also be enforced on the val set.
Should be used with caution as honesty is partially violated.
random_state : UINT32_t
The user-provided random seed to be used for pseudo-randomness
"""
self.criterion = criterion
if honest:
self.criterion_val = criterion_val
else:
self.criterion_val = criterion
self.features = NULL
self.n_features = 0
self.samples = NULL
self.n_samples = 0
self.samples_val = NULL
self.n_samples_val = 0
self.feature_values = NULL
self.feature_values_val = NULL
self.sample_weight = NULL
self.max_features = max_features
self.min_samples_leaf = min_samples_leaf
self.min_weight_leaf = min_weight_leaf
self.min_balancedness_tol = min_balancedness_tol
self.honest = honest
self.min_eig_leaf = min_eig_leaf
self.min_eig_leaf_on_val = min_eig_leaf_on_val
self.rand_r_state = random_state
def __dealloc__(self):
"""Destructor."""
free(self.samples)
free(self.features)
free(self.constant_features)
free(self.feature_values)
if self.honest:
free(self.samples_val)
free(self.feature_values_val)
def __getstate__(self):
return {}
def __setstate__(self, d):
pass
cdef int init_sample_inds(self, SIZE_t* samples,
const SIZE_t[::1] np_samples,
DOUBLE_t* sample_weight,
SIZE_t* n_samples,
double* weighted_n_samples) nogil except -1:
""" Initialize the cython sample index arrays `samples` based on the python
numpy array `np_samples`. Calculate the total weight of samples as you go through the pass
and store it in the output variable `weighted_n_samples`. Update the number of samples
passed via the input/output variable `n_samples` to the number of *positively* weighted
samples, so that we only work with that subset.
"""
cdef SIZE_t i, j, ind
weighted_n_samples[0] = 0.0
j = 0
for i in range(np_samples.shape[0]):
ind = np_samples[i]
# Only work with positively weighted samples
if sample_weight == NULL or sample_weight[ind] > 0.0:
samples[j] = ind
j += 1
if sample_weight != NULL:
weighted_n_samples[0] += sample_weight[ind]
else:
weighted_n_samples[0] += 1.0
# Number of samples is number of positively weighted samples
n_samples[0] = j
cdef int init(self, const DTYPE_t[:, :] X, const DOUBLE_t[:, ::1] y,
DOUBLE_t* sample_weight,
const SIZE_t[::1] np_samples_train,
const SIZE_t[::1] np_samples_val) nogil except -1:
"""Initialize the splitter.
Take in the input data X, y and the train/val split. Returns -1 in case of failure to
allocate memory (and raise MemoryError) or 0 otherwise.
Parameters
----------
X : object
This contains the inputs. Usually it is a 2d numpy array.
y : ndarray, dtype=DOUBLE_t
This is the vector of targets, or true labels, for the samples
sample_weight : ndarray, dtype=DOUBLE_t
The sample weights
np_samples_train : ndarray, dtype=SIZE_t
The indices of the samples in the train set
np_samples_val : ndarray, dtype=SIZE_t
The indices of the samples in the val set
"""
cdef SIZE_t n_features = X.shape[1]
cdef SIZE_t n_samples = np_samples_train.shape[0]
# Create a new array which will be used to store nonzero weighted
# sample indices of the training set.
cdef SIZE_t* samples = safe_realloc(&self.samples, n_samples)
# Initialize this array based on the numpy array np_samples_train
self.init_sample_inds(self.samples, np_samples_train, sample_weight,
&self.n_samples, &self.weighted_n_samples)
# Create an array that will store a permuted version of the feature indices
# such that throughout the execution the first n_constant_features in this
# array are features that have been deemed as constant valued and are ignored.
# See loop invariant in `node_split()` method below.
cdef SIZE_t* features = safe_realloc(&self.features, n_features)
for i in range(n_features):
features[i] = i
self.n_features = n_features
safe_realloc(&self.feature_values, self.n_samples) # Will store feature_values of the drawn feature
safe_realloc(&self.constant_features, self.n_features) # Used as helper storage
self.X = X
self.y = y
self.sample_weight = sample_weight
# Initialize criterion
self.criterion.init(self.y, self.sample_weight, self.weighted_n_samples, self.samples)
# If `honest=True` do initialize analogous validation set objects.
cdef SIZE_t n_samples_val
cdef SIZE_t* samples_val
if self.honest:
n_samples_val = np_samples_val.shape[0]
samples_val = safe_realloc(&self.samples_val, n_samples_val)
self.init_sample_inds(self.samples_val, np_samples_val, sample_weight,
&self.n_samples_val, &self.weighted_n_samples_val)
safe_realloc(&self.feature_values_val, self.n_samples_val)
self.criterion_val.init(self.y, self.sample_weight, self.weighted_n_samples_val,
self.samples_val)
else:
self.n_samples_val = self.n_samples
self.samples_val = self.samples
self.weighted_n_samples_val = self.weighted_n_samples
self.feature_values_val = self.feature_values
return 0
cdef int node_reset(self, SIZE_t start, SIZE_t end, double* weighted_n_node_samples,
SIZE_t start_val, SIZE_t end_val, double* weighted_n_node_samples_val) nogil except -1:
"""Reset splitter on node samples[start:end] on train set and [start_val:end_val] on val set.
Returns -1 in case of failure to allocate memory (and raise MemoryError)
or 0 otherwise.
Parameters
----------
start : SIZE_t
The index of the first sample to consider on the train set
end : SIZE_t
The index one past the last sample to consider on the train set
weighted_n_node_samples : double*
On output, weighted_n_node_samples[0] stores the total weight of the training samples in the node
start_val : SIZE_t
The index of the first sample to consider on the val set
end_val : SIZE_t
The index one past the last sample to consider on the val set
weighted_n_node_samples_val : double*
On output, weighted_n_node_samples_val[0] stores the total weight of the val samples in the node
"""
self.start = start
self.end = end
self.start_val = start_val
self.end_val = end_val
self.criterion.node_reset(start, end)
weighted_n_node_samples[0] = self.criterion.weighted_n_node_samples
if self.honest:
self.criterion_val.node_reset(start_val, end_val)
weighted_n_node_samples_val[0] = self.criterion_val.weighted_n_node_samples
else:
weighted_n_node_samples_val[0] = self.criterion.weighted_n_node_samples
return 0
cdef int node_split(self, double impurity, SplitRecord* split,
SIZE_t* n_constant_features) nogil except -1:
"""Find the best split on node samples[start:end].
This is a placeholder method. The majority of computation will be done
here.
It should return -1 upon errors.
"""
pass
cdef void node_value_val(self, double* dest) nogil:
"""Copy the value of node samples[start:end] into dest."""
self.criterion_val.node_value(dest)
cdef void node_jacobian_val(self, double* dest) nogil:
"""Copy the mean jacobian of node samples[start:end] into dest."""
self.criterion_val.node_jacobian(dest)
cdef void node_precond_val(self, double* dest) nogil:
"""Copy the mean precond of node samples[start:end] into dest."""
self.criterion_val.node_precond(dest)
cdef double node_impurity(self) nogil:
"""Return the impurity of the current node on the train set."""
return self.criterion.node_impurity()
cdef double node_impurity_val(self) nogil:
"""Return the impurity of the current node on the val set."""
return self.criterion_val.node_impurity()
cdef double proxy_node_impurity(self) nogil:
"""Return the impurity of the current node on the train set."""
return self.criterion.proxy_node_impurity()
cdef double proxy_node_impurity_val(self) nogil:
"""Return the impurity of the current node on the val set."""
return self.criterion_val.proxy_node_impurity()
cdef bint is_children_impurity_proxy(self) nogil:
"""Whether the criterion method children_impurity() returns an
accurate node impurity of the children or just some computationally efficient
approximation.
"""
return (self.criterion.proxy_children_impurity or
self.criterion_val.proxy_children_impurity)
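When `honest=True` the splitter searches for splits with `criterion` (train half) but reports node values and impurities through `criterion_val` (val half); when `honest=False` both names point to the same object. A toy sketch of that delegation pattern (hypothetical class, for illustration only, not the econml API):

class HonestDelegation:
    """Toy mirror of the criterion / criterion_val wiring above."""
    def __init__(self, criterion, criterion_val, honest):
        self.criterion = criterion
        # Without honesty, validation-side queries fall back to the train criterion.
        self.criterion_val = criterion_val if honest else criterion
    def node_value_val(self):
        # Estimation quantities always come from the validation-side criterion.
        return self.criterion_val.node_value()
    def node_impurity(self):
        # Split-search quantities come from the train-side criterion.
        return self.criterion.node_impurity()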
cdef class BestSplitter(Splitter):
"""Splitter for finding the best split."""
def __reduce__(self):
return (BestSplitter, (self.criterion,
self.criterion_val,
self.max_features,
self.min_samples_leaf,
self.min_weight_leaf,
self.min_balancedness_tol,
self.honest,
self.min_eig_leaf,
self.min_eig_leaf_on_val,
self.random_state), self.__getstate__())
cdef int node_split(self, double impurity, SplitRecord* split,
SIZE_t* n_constant_features) nogil except -1:
"""Find the best split on node samples[start:end]
Returns -1 in case of failure to allocate memory (and raise MemoryError)
or 0 otherwise.
Parameters
----------
impurity : double
The impurity of the current node to be split. Passed in explicitly as it might
potentially have been calculated by the builder from a prior call to children_impurity
when the parent node was split, so as to avoid double calculation.
split : SplitRecord*
On output, it stores the information that describes the best split found
n_constant_features : SIZE_t*
On input, it contains the number of known constant features in the node (because
they were already constant in the parent node). On output, it contains the new
number of constant features, including the new ones that were found constant within this
node.
"""
# Find the best split
cdef SIZE_t* samples = self.samples
cdef SIZE_t start = self.start
cdef SIZE_t end = self.end
cdef SIZE_t* samples_val = self.samples_val
cdef SIZE_t start_val = self.start_val
cdef SIZE_t end_val = self.end_val
cdef SIZE_t* features = self.features
cdef SIZE_t* constant_features = self.constant_features
cdef SIZE_t n_features = self.n_features
cdef DTYPE_t* Xf = self.feature_values
cdef DTYPE_t* Xf_val = self.feature_values_val
cdef SIZE_t max_features = self.max_features
cdef SIZE_t min_samples_leaf = self.min_samples_leaf
cdef double min_weight_leaf = self.min_weight_leaf
cdef double min_eig_leaf = self.min_eig_leaf
cdef UINT32_t* random_state = &self.rand_r_state
cdef SplitRecord best, current
cdef double current_proxy_improvement = -INFINITY
cdef double best_proxy_improvement = -INFINITY
cdef double current_threshold = 0.0
cdef double weighted_n_node_samples, weighted_n_samples, weighted_n_left, weighted_n_right
cdef SIZE_t f_i = n_features
cdef SIZE_t f_j
cdef SIZE_t p
cdef SIZE_t p_val
cdef SIZE_t i
cdef SIZE_t n_visited_features = 0
# Number of features discovered to be constant during the split search
cdef SIZE_t n_found_constants = 0
# Number of features known to be constant and drawn without replacement
cdef SIZE_t n_drawn_constants = 0
cdef SIZE_t n_known_constants = n_constant_features[0]
# n_total_constants = n_known_constants + n_found_constants
cdef SIZE_t n_total_constants = n_known_constants
cdef SIZE_t partition_end
_init_split(&best, end, end_val)
# Sample up to max_features without replacement using a
# Fisher-Yates-based algorithm (using the local variables `f_i` and
# `f_j` to compute a permutation of the `features` array).
#
# Skip the CPU intensive evaluation of the impurity criterion for
# features that were already detected as constant (hence not suitable
# for good splitting) by ancestor nodes and save the information on
# newly discovered constant features to spare computation on descendant
# nodes.
while (f_i > n_total_constants and # Stop early if remaining features
# are constant
(n_visited_features < max_features or
# At least one drawn feature must be non-constant
n_visited_features <= n_found_constants + n_drawn_constants)):
n_visited_features += 1
# Loop invariant: elements of features in
# - [0, n_drawn_constant) holds drawn and known constant features;
# - [n_drawn_constant, n_known_constant) holds known constant
# features that haven't been drawn yet;
# - [n_known_constant, n_total_constant) holds newly found constant
# features;
# - [n_total_constant, f_i) holds features that haven't been drawn
# yet and aren't constant apriori.
# - [f_i, n_features) holds features that have been drawn
# and aren't constant.
# TODO. Add some randomness that rejects some splits so that the threshold on
# each feature is sufficiently random and so that the number of features that
# are inspected is also random. This is required by the theory so that every feature
# will have the possibility to be the one that is split upon.
# Draw a feature at random
f_j = rand_int(n_drawn_constants, f_i - n_found_constants,
random_state)
if f_j < n_known_constants:
# f_j in the interval [n_drawn_constants, n_known_constants[
features[n_drawn_constants], features[f_j] = features[f_j], features[n_drawn_constants]
n_drawn_constants += 1
else:
# f_j in the interval [n_known_constants, f_i - n_found_constants[
f_j += n_found_constants
# f_j in the interval [n_total_constants, f_i[
current.feature = features[f_j]
# Sort samples along that feature; by
# copying the values into an array and
# sorting the array in a manner which utilizes the cache more
# effectively.
for i in range(start, end):
Xf[i] = self.X[samples[i], current.feature]
sort(Xf + start, samples + start, end - start)
if self.honest:
for i in range(start_val, end_val):
Xf_val[i] = self.X[samples_val[i], current.feature]
sort(Xf_val + start_val, samples_val + start_val, end_val - start_val)
if Xf[end - 1] <= Xf[start] + FEATURE_THRESHOLD:
features[f_j], features[n_total_constants] = features[n_total_constants], features[f_j]
n_found_constants += 1
n_total_constants += 1
else:
f_i -= 1
features[f_i], features[f_j] = features[f_j], features[f_i]
# Evaluate all splits
self.criterion.reset() # Reset criterion to start evaluating splits in increasing feature manner
if self.honest:
self.criterion_val.reset() # If honest, then reset val criterion too
# We know that by balancedness we must start by at least this position of the node
p = start + <int>floor((.5 - self.min_balancedness_tol) * (end - start)) - 1
p_val = start_val # p_val will track p so no need to add the offset
while p < end and p_val < end_val:
# We find equivalent values up to floating point precision
while (p + 1 < end and
Xf[p + 1] <= Xf[p] + FEATURE_THRESHOLD):
p += 1
# (p + 1 >= end) or (X[samples[p + 1], current.feature] >
# X[samples[p], current.feature])
p += 1
# (p >= end) or (X[samples[p], current.feature] >
# X[samples[p - 1], current.feature])
# we set the threshold to be the mid-point between the two feature values
current_threshold = Xf[p] / 2.0 + Xf[p - 1] / 2.0
if ((current_threshold == Xf[p]) or
(current_threshold == INFINITY) or
(current_threshold == -INFINITY)):
current_threshold = Xf[p - 1]
# We need to advance p_val such that if we partition samples_val[start_val:end_val]
# into samples_val[start_val:best.pos_val] and samples_val[best.pos_val:end_val], then
# the first part contains all samples in Xval that are below the threshold. Thus we need
# to advance p_val until it is the first position such that Xf_val[p_val] > threshold.
if self.honest:
while (p_val < end_val and
Xf_val[p_val] <= current_threshold):
p_val += 1
else:
p_val = p # If not honest then p_val is same as p
if p < end and p_val < end_val:
current.pos = p
current.pos_val = p_val
# Reject if imbalanced on either train or val set. We know that the first
# direction on the train set is guaranteed due to the offset we added to the
# starting point.
if (end - current.pos) < (.5 - self.min_balancedness_tol) * (end - start):
break
if (current.pos_val - start_val) < (.5 - self.min_balancedness_tol) * (end_val - start_val):
continue
if (end_val - current.pos_val) < (.5 - self.min_balancedness_tol) * (end_val - start_val):
break
# Reject if min_samples_leaf is not guaranteed
if (current.pos - start) < min_samples_leaf:
continue
if (end - current.pos) < min_samples_leaf:
break
# Reject if min_samples_leaf is not guaranteed on val
if (current.pos_val - start_val) < min_samples_leaf:
continue
if (end_val - current.pos_val) < min_samples_leaf:
break
# If nothing is rejected, then update the criterion to hold info on the split
# at position current.pos
self.criterion.update(current.pos)
if self.honest:
self.criterion_val.update(current.pos_val) # similarly for criterion_val if honest
# Reject if min_weight_leaf is not satisfied
if self.criterion.weighted_n_left < min_weight_leaf:
continue
if self.criterion.weighted_n_right < min_weight_leaf:
break
# Reject if minimum eigenvalue proxy requirement is not satisfied on train
# We do not check this constraint on val, since the eigenvalue proxy can depend on
# label information and we will be violating honesty.
if min_eig_leaf >= 0.0:
if self.criterion.min_eig_left() < min_eig_leaf:
continue
if self.criterion.min_eig_right() < min_eig_leaf:
continue
if self.min_eig_leaf_on_val:
if self.criterion_val.min_eig_left() < min_eig_leaf:
continue
if self.criterion_val.min_eig_right() < min_eig_leaf:
continue
# Reject if min_weight_leaf constraint is violated
if self.honest:
if self.criterion_val.weighted_n_left < min_weight_leaf:
continue
if self.criterion_val.weighted_n_right < min_weight_leaf:
break
# Calculate fast version of impurity_improvement of the split to be used for ranking splits
current_proxy_improvement = self.criterion.proxy_impurity_improvement()
if current_proxy_improvement > best_proxy_improvement:
best_proxy_improvement = current_proxy_improvement
# Store the mid-point threshold computed above (a sum of halves, to avoid overflow to infinity)
current.threshold = current_threshold
best = current # copy
# Reorganize into samples[start:best.pos] + samples[best.pos:end]
if best.pos < end and best.pos_val < end_val:
partition_end = end
p = start
while p < partition_end:
if self.X[samples[p], best.feature] <= best.threshold:
p += 1
else:
partition_end -= 1
samples[p], samples[partition_end] = samples[partition_end], samples[p]
if self.honest:
partition_end = end_val
p = start_val
while p < partition_end:
if self.X[samples_val[p], best.feature] <= best.threshold:
p += 1
else:
partition_end -= 1
samples_val[p], samples_val[partition_end] = samples_val[partition_end], samples_val[p]
self.criterion.reset()
self.criterion.update(best.pos)
if self.honest:
self.criterion_val.reset()
self.criterion_val.update(best.pos_val)
# Calculate a more accurate version of impurity improvement using the input baseline impurity
# passed here by the TreeBuilder. The TreeBuilder uses the proxy_node_impurity() to calculate
# this baseline if self.is_children_impurity_proxy(), else uses the call to children_impurity()
# on the parent node, when that node was split.
best.improvement = self.criterion.impurity_improvement(impurity)
# if we need children impurities by the builder, then we populate these entries
# otherwise, we leave them blank to avoid the extra computation.
if not self.is_children_impurity_proxy():
self.criterion.children_impurity(&best.impurity_left, &best.impurity_right)
if self.honest:
self.criterion_val.children_impurity(&best.impurity_left_val,
&best.impurity_right_val)
else:
best.impurity_left_val = best.impurity_left
best.impurity_right_val = best.impurity_right
# Respect invariant for constant features: the original order of
# element in features[:n_known_constants] must be preserved for sibling
# and child nodes
memcpy(features, constant_features, sizeof(SIZE_t) * n_known_constants)
# Copy newly found constant features
memcpy(constant_features + n_known_constants,
features + n_known_constants,
sizeof(SIZE_t) * n_found_constants)
# Return values
split[0] = best
n_constant_features[0] = n_total_constants
return 0
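Stripped of the Cython bookkeeping, the scan above walks candidate thresholds in sorted feature order and rejects those that violate the balancedness, leaf-size and leaf-weight constraints before ranking the survivors by a proxy improvement. A simplified pure-Python sketch for a single feature, using plain variance reduction in place of the criterion's proxy improvement (1-D NumPy inputs assumed; this is a didactic stand-in, not the econml splitter):

import numpy as np

def best_split_1d(x, y, min_balancedness_tol=.45, min_samples_leaf=5):
    """Rank thresholds on one feature by variance reduction, after the same
    style of balancedness / leaf-size rejections as above."""
    order = np.argsort(x)
    x_s, y_s = x[order], y[order]
    n = len(x_s)
    best_gain, best_thr = -np.inf, None
    for pos in range(1, n):
        if x_s[pos] <= x_s[pos - 1] + 1e-7:                        # skip ties (FEATURE_THRESHOLD)
            continue
        if min(pos, n - pos) < (.5 - min_balancedness_tol) * n:    # reject imbalanced splits
            continue
        if min(pos, n - pos) < min_samples_leaf:                   # reject too-small leaves
            continue
        left, right = y_s[:pos], y_s[pos:]
        gain = y_s.var() - (pos * left.var() + (n - pos) * right.var()) / n
        if gain > best_gain:
            best_gain = gain
            best_thr = x_s[pos - 1] / 2.0 + x_s[pos] / 2.0         # mid-point threshold, as above
    return best_thr, best_gain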
# Sort n-element arrays pointed to by Xf and samples, simultaneously,
# by the values in Xf. Algorithm: Introsort (Musser, SP&E, 1997).
cdef inline void sort(DTYPE_t* Xf, SIZE_t* samples, SIZE_t n) nogil:
if n == 0:
return
cdef int maxd = 2 * <int>log(n)
introsort(Xf, samples, n, maxd)
cdef inline void swap(DTYPE_t* Xf, SIZE_t* samples,
SIZE_t i, SIZE_t j) nogil:
# Helper for sort
Xf[i], Xf[j] = Xf[j], Xf[i]
samples[i], samples[j] = samples[j], samples[i]
cdef inline DTYPE_t median3(DTYPE_t* Xf, SIZE_t n) nogil:
# Median of three pivot selection, after Bentley and McIlroy (1993).
# Engineering a sort function. SP&E. Requires 8/3 comparisons on average.
cdef DTYPE_t a = Xf[0], b = Xf[n / 2], c = Xf[n - 1]
if a < b:
if b < c:
return b
elif a < c:
return c
else:
return a
elif b < c:
if a < c:
return a
else:
return c
else:
return b
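A quick way to sanity-check the median-of-three pivot selection above is to compare it against a full sort of the three candidates; a small, hedged Python mirror:

from itertools import permutations

def median3(a, b, c):
    """Pure-Python mirror of the pivot selection above."""
    if a < b:
        if b < c:
            return b
        return c if a < c else a
    if b < c:
        return a if a < c else c
    return b

# Sanity check: for distinct values the median-of-three equals the middle element.
assert all(median3(*p) == sorted(p)[1] for p in permutations((1.0, 2.0, 3.0)))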
# Introsort with median of 3 pivot selection and 3-way partition function
# (robust to repeated elements, e.g. lots of zero features).
cdef void introsort(DTYPE_t* Xf, SIZE_t *samples,
SIZE_t n, int maxd) nogil:
cdef DTYPE_t pivot
cdef SIZE_t i, l, r
while n > 1:
if maxd <= 0: # max depth limit exceeded ("gone quadratic")
heapsort(Xf, samples, n)
return
maxd -= 1
pivot = median3(Xf, n)
# Three-way partition.
i = l = 0
r = n
while i < r:
if Xf[i] < pivot:
swap(Xf, samples, i, l)
i += 1
l += 1
elif Xf[i] > pivot:
r -= 1
swap(Xf, samples, i, r)
else:
i += 1
introsort(Xf, samples, l, maxd)
Xf += r
samples += r
n -= r
cdef inline void sift_down(DTYPE_t* Xf, SIZE_t* samples,
SIZE_t start, SIZE_t end) nogil:
# Restore heap order in Xf[start:end] by moving the max element to start.
cdef SIZE_t child, maxind, root
root = start
while True:
child = root * 2 + 1
# find max of root, left child, right child
maxind = root
if child < end and Xf[maxind] < Xf[child]:
maxind = child
if child + 1 < end and Xf[maxind] < Xf[child + 1]:
maxind = child + 1
if maxind == root:
break
else:
swap(Xf, samples, root, maxind)
root = maxind
cdef void heapsort(DTYPE_t* Xf, SIZE_t* samples, SIZE_t n) nogil:
cdef SIZE_t start, end
# heapify
start = (n - 2) / 2
end = n
while True:
sift_down(Xf, samples, start, end)
if start == 0:
break
start -= 1
# sort by shrinking the heap, putting the max element immediately after it
end = n - 1
while end > 0:
swap(Xf, samples, 0, end)
sift_down(Xf, samples, 0, end)
end = end - 1
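In NumPy terms, the in-place `sort(Xf + start, samples + start, end - start)` used during the split search is a joint argsort: sort the feature values over the node's range and apply the same permutation to the sample indices. A hedged sketch (illustrative helper, not the library's sorter):

import numpy as np

def joint_sort(Xf, samples, start, end):
    """Sort Xf[start:end] and reorder samples[start:end] with the same
    permutation, mirroring what the introsort above does in place."""
    order = np.argsort(Xf[start:end], kind="stable")
    Xf[start:end] = Xf[start:end][order]
    samples[start:end] = samples[start:end][order]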

120
econml/tree/_tree.pxd Normal file

@@ -0,0 +1,120 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# This code is a fork from: https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_tree.pxd
# published under the following license and copyright:
# BSD 3-Clause License
#
# Copyright (c) 2007-2020 The scikit-learn developers.
# All rights reserved.
# See _tree.pyx for details.
import numpy as np
cimport numpy as np
ctypedef np.npy_float64 DTYPE_t # Type of X
ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight
ctypedef np.npy_intp SIZE_t # Type for indices and counters
ctypedef np.npy_int32 INT32_t # Signed 32 bit integer
ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer
from ._splitter cimport Splitter
from ._splitter cimport SplitRecord
cdef struct Node:
# Base storage structure for the nodes in a Tree object
SIZE_t left_child # id of the left child of the node
SIZE_t right_child # id of the right child of the node
SIZE_t depth # the depth level of the node
SIZE_t feature # Feature used for splitting the node
DOUBLE_t threshold # Threshold value at the node
DOUBLE_t impurity # Impurity of the node on the val set
SIZE_t n_node_samples # Number of samples at the node on the val set
DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node on the val set
DOUBLE_t impurity_train # Impurity of the node on the training set
SIZE_t n_node_samples_train # Number of samples at the node on the training set
DOUBLE_t weighted_n_node_samples_train # Weighted number of samples at the node on the training set
cdef class Tree:
# The Tree object is a binary tree structure constructed by the
# TreeBuilder. The tree structure is used for predictions and
# feature importances.
# Input/Output layout
cdef public SIZE_t n_features # Number of features in X
cdef public SIZE_t n_outputs # Number of parameters estimated at each node
cdef public SIZE_t n_relevant_outputs # Prefix of the parameters that we care about
cdef SIZE_t* n_classes # Legacy from sklearn for compatibility. Number of classes in classification
cdef public SIZE_t max_n_classes # Number of classes for each output coordinate
# Inner structures: values are stored separately from node structure,
# since size is determined at runtime.
cdef public SIZE_t max_depth # Max depth of the tree
cdef public SIZE_t node_count # Counter for node IDs
cdef public SIZE_t capacity # Capacity of tree, in terms of nodes
cdef Node* nodes # Array of nodes
cdef double* value # (capacity, n_outputs, max_n_classes) array of values
cdef SIZE_t value_stride # = n_outputs * max_n_classes
cdef bint store_jac # whether to store jacobian and precond information
cdef double* jac # node jacobian in linear moment: J(x) * theta - precond(x) = 0
cdef SIZE_t jac_stride # = n_outputs * n_outputs
cdef double* precond # node preconditioned value
cdef SIZE_t precond_stride # = n_outputs
# Methods
cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf,
SIZE_t feature, double threshold,
double impurity_train, SIZE_t n_node_samples_train,
double weighted_n_samples_train,
double impurity_val, SIZE_t n_node_samples_val,
double weighted_n_samples_val) nogil except -1
cdef int _resize(self, SIZE_t capacity) nogil except -1
cdef int _resize_c(self, SIZE_t capacity=*) nogil except -1
cdef np.ndarray _get_value_ndarray(self)
cdef np.ndarray _get_jac_ndarray(self)
cdef np.ndarray _get_precond_ndarray(self)
cdef np.ndarray _get_node_ndarray(self)
cpdef np.ndarray predict(self, object X)
cpdef np.ndarray predict_jac(self, object X)
cpdef np.ndarray predict_precond(self, object X)
cpdef predict_precond_and_jac(self, object X)
cpdef np.ndarray predict_full(self, object X)
cpdef np.ndarray apply(self, object X)
cdef np.ndarray _apply(self, object X)
cpdef object decision_path(self, object X)
cdef object _decision_path(self, object X)
cpdef compute_feature_importances(self, normalize=*, max_depth=*, depth_decay=*)
cpdef compute_feature_heterogeneity_importances(self, normalize=*, max_depth=*, depth_decay=*)
# =============================================================================
# Tree builder
# =============================================================================
cdef class TreeBuilder:
# The TreeBuilder recursively builds a Tree object from training samples,
# using a Splitter object for splitting internal nodes and assigning
# values to leaves.
#
# This class controls the various stopping criteria and the node splitting
# evaluation order, e.g. depth-first or best-first.
cdef Splitter splitter # Splitting algorithm
cdef SIZE_t min_samples_split # Minimum number of samples in an internal node
cdef SIZE_t min_samples_leaf # Minimum number of samples in a leaf
cdef double min_weight_leaf # Minimum weight in a leaf
cdef SIZE_t max_depth # Maximal tree depth
cdef double min_impurity_decrease # Impurity threshold for early stopping
cpdef build(self, Tree tree, object X, np.ndarray y,
np.ndarray samples_train,
np.ndarray samples_val,
np.ndarray sample_weight=*,
bint store_jac=*)
cdef _check_input(self, object X, np.ndarray y, np.ndarray sample_weight)
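The `Node` struct and the `apply`/`predict` declarations above describe a flat array-of-nodes layout: prediction walks nodes from the root, following `left_child`/`right_child` by comparing the split feature against its threshold, until it reaches a leaf (marked by `TREE_LEAF == -1`). A hedged Python sketch of that traversal over plain arrays (names assumed from the declarations, not the actual API):

import numpy as np

def apply_sketch(X, children_left, children_right, feature, threshold):
    """Return the leaf id reached by each row of X, walking the flat node
    arrays the way Tree.apply does; leaves have child id -1 (TREE_LEAF)."""
    leaves = np.empty(X.shape[0], dtype=np.intp)
    for i, x in enumerate(X):
        node = 0                                   # start at the root
        while children_left[node] != -1:           # internal nodes have both children
            if x[feature[node]] <= threshold[node]:
                node = children_left[node]
            else:
                node = children_right[node]
        leaves[i] = node
    return leaves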

973
econml/tree/_tree.pyx Normal file

@@ -0,0 +1,973 @@
# cython: cdivision=True
# cython: boundscheck=False
# cython: wraparound=False
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# This code is a fork from: https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_tree.pyx
# published under the following license and copyright:
# BSD 3-Clause License
#
# Copyright (c) 2007-2020 The scikit-learn developers.
# All rights reserved.
from cpython cimport Py_INCREF, PyObject, PyTypeObject
from libc.stdlib cimport free
from libc.string cimport memcpy
from libc.string cimport memset
from libc.stdint cimport SIZE_MAX
from libc.math cimport pow
import numpy as np
cimport numpy as np
np.import_array()
from scipy.sparse import csr_matrix
from ._utils cimport Stack
from ._utils cimport StackRecord
from ._utils cimport safe_realloc
from ._utils cimport sizet_ptr_to_ndarray
cdef extern from "numpy/arrayobject.h":
object PyArray_NewFromDescr(PyTypeObject* subtype, np.dtype descr,
int nd, np.npy_intp* dims,
np.npy_intp* strides,
void* data, int flags, object obj)
# =============================================================================
# Types and constants
# =============================================================================
from numpy import float64 as DTYPE
from numpy import float64 as DOUBLE
cdef double INFINITY = np.inf
cdef double EPSILON = np.finfo('double').eps
TREE_LEAF = -1
TREE_UNDEFINED = -2
cdef SIZE_t _TREE_LEAF = TREE_LEAF
cdef SIZE_t _TREE_UNDEFINED = TREE_UNDEFINED
cdef SIZE_t INITIAL_STACK_SIZE = 10
# The definition of a numpy dtype to be used for converting the malloc'ed memory space that
# contains an array of Node structs into a parallel structured numpy array that can be accessed as
# array[key][index].
NODE_DTYPE = np.dtype({
'names': ['left_child', 'right_child', 'depth', 'feature', 'threshold',
'impurity', 'n_node_samples', 'weighted_n_node_samples',
'impurity_train', 'n_node_samples_train', 'weighted_n_node_samples_train'],
'formats': [np.intp, np.intp, np.intp, np.intp, np.float64,
np.float64, np.intp, np.float64,
np.float64, np.intp, np.float64],
'offsets': [
<Py_ssize_t> &(<Node*> NULL).left_child,
<Py_ssize_t> &(<Node*> NULL).right_child,
<Py_ssize_t> &(<Node*> NULL).depth,
<Py_ssize_t> &(<Node*> NULL).feature,
<Py_ssize_t> &(<Node*> NULL).threshold,
<Py_ssize_t> &(<Node*> NULL).impurity,
<Py_ssize_t> &(<Node*> NULL).n_node_samples,
<Py_ssize_t> &(<Node*> NULL).weighted_n_node_samples,
<Py_ssize_t> &(<Node*> NULL).impurity_train,
<Py_ssize_t> &(<Node*> NULL).n_node_samples_train,
<Py_ssize_t> &(<Node*> NULL).weighted_n_node_samples_train,
]
})
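The structured dtype above lets the C buffer of `Node` structs be viewed as a NumPy record array without copying, so fields are read as `array['field'][node_id]`. A hedged toy illustration with a hand-built record array (offsets omitted, so it is not byte-compatible with the C struct; the field-access pattern is the same):

import numpy as np

toy_dtype = np.dtype([('left_child', np.intp), ('right_child', np.intp),
                      ('feature', np.intp), ('threshold', np.float64)])
nodes = np.zeros(3, dtype=toy_dtype)
nodes[0] = (1, 2, 0, 0.5)           # root: children 1 and 2, splits feature 0 at 0.5
print(nodes['threshold'][0])        # field access by name, as in tree.threshold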
cdef class TreeBuilder:
"""Interface for different tree building strategies."""
cpdef build(self, Tree tree, object X, np.ndarray y,
np.ndarray samples_train,
np.ndarray samples_val,
np.ndarray sample_weight=None,
bint store_jac=False):
"""Build tree from the training set (X, y) using samples_train for the
constructing the splits and samples_val for estimating the node values. store_jac
controls whether jacobian information is stored in the tree nodes in the case of
generalized random forests.
"""
pass
cdef inline _check_input(self, object X, np.ndarray y, np.ndarray sample_weight):
"""Check input dtype, layout and format"""
# since we have to copy and perform linear algebra we will make it fortran for efficiency
if X.dtype != DTYPE:
X = np.asfortranarray(X, dtype=DTYPE)
if y.dtype != DOUBLE or not y.flags.contiguous:
y = np.ascontiguousarray(y, dtype=DOUBLE)
if (sample_weight is not None and
(sample_weight.dtype != DOUBLE or
not sample_weight.flags.contiguous)):
sample_weight = np.asarray(sample_weight, dtype=DOUBLE,
order="C")
return X, y, sample_weight
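The coercion above boils down to: X becomes Fortran-ordered float64 (the split search reads one feature column at a time), while y and sample_weight become C-contiguous float64. A hedged NumPy sketch of the same checks (illustrative helper name):

import numpy as np

def check_input_sketch(X, y, sample_weight=None):
    """Hedged mirror of _check_input: enforce float64 dtypes and memory layout."""
    X = np.asfortranarray(X, dtype=np.float64)     # column-major for per-feature scans
    y = np.ascontiguousarray(y, dtype=np.float64)
    if sample_weight is not None:
        sample_weight = np.ascontiguousarray(sample_weight, dtype=np.float64)
    return X, y, sample_weight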
# Depth first builder ---------------------------------------------------------
cdef class DepthFirstTreeBuilder(TreeBuilder):
"""Build a tree in depth-first fashion."""
def __cinit__(self, Splitter splitter, SIZE_t min_samples_split,
SIZE_t min_samples_leaf, double min_weight_leaf,
SIZE_t max_depth, double min_impurity_decrease):
""" Initialize parameters.
Parameters
----------
splitter : cython extension class of type Splitter
The splitter to be used for deciding the best split of each node.
min_samples_split : SIZE_t
The minimum number of samples required for a node to be considered for splitting
min_samples_leaf : SIZE_t
The minimum number of samples that each node must contain
min_weight_leaf : double
The minimum total weight of samples that each node must contain
max_depth : SIZE_t
The maximum depth of the tree
min_impurity_decrease : double
The minimum improvement in impurity that a split must provide to be executed
"""
self.splitter = splitter
self.min_samples_split = min_samples_split
self.min_samples_leaf = min_samples_leaf
self.min_weight_leaf = min_weight_leaf
self.max_depth = max_depth
self.min_impurity_decrease = min_impurity_decrease
cpdef build(self, Tree tree, object X, np.ndarray y,
np.ndarray samples_train,
np.ndarray samples_val,
np.ndarray sample_weight=None,
bint store_jac=False):
"""Build an honest tree from data (X, y).
Parameters
----------
X : (n, d) np.array
The features to use for splitting
y : (n, p) np.array
Any information used by the criterion to calculate the node values for a given node defined by
the X's
samples_train : (n,) np.array of type np.intp
The indices of the samples in X to be used for creating the splits of the tree (training set).
samples_val : (n,) np.array of type np.intp
The indices of the samples in X to be used for calculating the node values of the tree (val set).
sample_weight : (n,) np.array of type np.float64
The weight of each sample
store_jac : bool, optional (default=False)
Whether jacobian information should be stored in the tree nodes by calling the node_jacobian_val
and node_precond_val of the splitter. This is related to trees that solve linear moment equations
of the form: J(x) * theta(x) - precond(x) = 0. If store_jac=True, then J(x) and precond(x) are also
stored at the tree nodes at the end of the build for easy access.
"""
# check input
X, y, sample_weight = self._check_input(X, y, sample_weight)
cdef DOUBLE_t* sample_weight_ptr = NULL
if sample_weight is not None:
sample_weight_ptr = <DOUBLE_t*> sample_weight.data
# Initial capacity
cdef int init_capacity
if tree.max_depth <= 10:
init_capacity = (2 ** (tree.max_depth + 1)) - 1
else:
init_capacity = 2047
tree._resize(init_capacity)
# Parameters
cdef Splitter splitter = self.splitter
cdef SIZE_t max_depth = self.max_depth
cdef SIZE_t min_samples_leaf = self.min_samples_leaf
cdef double min_weight_leaf = self.min_weight_leaf
cdef SIZE_t min_samples_split = self.min_samples_split
cdef double min_impurity_decrease = self.min_impurity_decrease
# Recursive partition (without actual recursion)
splitter.init(X, y, sample_weight_ptr, samples_train, samples_val)
# The indices of all samples are stored in two arrays, `samples` and `samples_val` by the splitter,
# such that each node contains the samples from `start` to `end` in the samples array and from
# `start_val` to `end_val` in the `samples_val` array. Thus these four numbers are sufficient to
# represent a "node" of the tree during building.
cdef SIZE_t start
cdef SIZE_t end
cdef SIZE_t start_val
cdef SIZE_t end_val
cdef SIZE_t depth # The depth of the current node considered for splitting
cdef SIZE_t parent # The parent of the current node considered for splitting
cdef bint is_left # Whether the current node considered for splitting is a left or right child
cdef SIZE_t n_node_samples = splitter.n_samples # Total number of training samples
cdef double weighted_n_samples = splitter.weighted_n_samples # Total weight of training samples
cdef double weighted_n_node_samples # Will be storing the total training weight of the current node
cdef SIZE_t n_node_samples_val = splitter.n_samples_val # Total number of val samples
cdef double weighted_n_samples_val = splitter.weighted_n_samples_val # Total weight of val samples
cdef double weighted_n_node_samples_val # Will be storing the total val weight of the current node
cdef SplitRecord split # A split record is a struct produced by the splitter that contains all the split info
cdef SIZE_t node_id # Will be storing the id that the tree assigns to a node when added to the tree
cdef double impurity = INFINITY # Will be storing the train impurity of the node considered for splitting
cdef double impurity_val = INFINITY # Will be storing the val impurity of the node considered for splitting
cdef double proxy_impurity = INFINITY # An approximate version of the impurity used for min impurity decrease
cdef SIZE_t n_constant_features # number of features identified as taking a constant value in the node
cdef bint is_leaf # Whether the node we are about to add to the tree is a leaf
cdef bint first = 1 # If this is the root node we are splitting
cdef SIZE_t max_depth_seen = -1 # Max depth we've seen so far
cdef int rc = 0 # To be used as a success flag for memory resizing calls
cdef Stack stack = Stack(INITIAL_STACK_SIZE) # A stack contains the entries of all nodes to be considered
cdef StackRecord stack_record # A stack record contains all the information required to split a node
with nogil:
# push root node onto stack
rc = stack.push(0, n_node_samples, 0, n_node_samples_val,
0, _TREE_UNDEFINED, 0, INFINITY, INFINITY, 0)
if rc == -1:
# got return code -1 - out-of-memory
with gil:
raise MemoryError()
while not stack.is_empty():
stack.pop(&stack_record) # Let's pop a node from the stack to split
# Let's store the stack record in local values for easy access and manipulation
start = stack_record.start
end = stack_record.end
start_val = stack_record.start_val
end_val = stack_record.end_val
depth = stack_record.depth
parent = stack_record.parent
is_left = stack_record.is_left
impurity = stack_record.impurity
impurity_val = stack_record.impurity_val
n_constant_features = stack_record.n_constant_features
# Some easy calculations
n_node_samples = end - start
n_node_samples_val = end_val - start_val
# Let's reset the splitter to the initial state of considering the current node to split
# This will also return the total weight of the node in the training and validation set
# in the two variables passed by reference.
splitter.node_reset(start, end, &weighted_n_node_samples,
start_val, end_val, &weighted_n_node_samples_val)
# Determine if the node is a leaf based on simple constraints on the training and val set
is_leaf = (depth >= max_depth or
n_node_samples < min_samples_split or
n_node_samples < 2 * min_samples_leaf or
weighted_n_node_samples < 2 * min_weight_leaf or
n_node_samples_val < min_samples_split or
n_node_samples_val < 2 * min_samples_leaf or
weighted_n_node_samples_val < 2 * min_weight_leaf)
# If either the splitter only returns approximate children impurities at the end of each split
# or if we are in the root node, then we need to calculate node impurity for both evaluating
# the min_impurity_decrease constraint and for storing the impurity at the tree. This is done
# because sometimes node impurity might be computationally intensive to calculate and can be easily
# done once the splitter has calculated all the quantities at `node_reset`, but would add too much
# computational burden if done twice (once when considering the node for splitting and once when
# calculating the node's impurity when it is the children of a node that has just been split). Thus
# sometimes we just want to calculate an approximate child node impurity, solely for the purpose
# of evaluating whether the min_impurity_decrease constraint is satisfied by the returned split.
if (splitter.is_children_impurity_proxy()) or first:
# This is the baseline of what we should use for impurity improvement
proxy_impurity = splitter.proxy_node_impurity()
# This is the true node impurity we want to store in the tree. The two should coincide if
# `splitter.is_children_impurity_proxy()==False`.
impurity = splitter.node_impurity() # The node impurity on the training set
impurity_val = splitter.node_impurity_val() # The node impurity on the val set
first = 0
else:
# We use the impurity value stored in the stack, which was returned by children_impurity()
# when the parent node of this node was split.
proxy_impurity = impurity
if not is_leaf:
# Find the best split of the node and return it in the `split` variable
# Also use the fact that so far we have deemed `n_constant_features` to be constant in the
# parent node and at the end return the number of features that have been deemed as taking
# constant value, by updating the `n_constant_features` which is passed by reference.
# This is used for speeding up computation as these features are not considered further for
# splitting. This speed up is also enabled by the fact that we are doing depth-first-search
# build.
splitter.node_split(proxy_impurity, &split, &n_constant_features)
# Note from original sklearn comments: If EPSILON=0 in the below comparison, float precision
# issues stop splitting, producing trees that are dissimilar to v0.18
is_leaf = (is_leaf or
split.pos >= end or # no split of the training set was valid
split.pos_val >= end_val or # no split of the validation set was valid
(split.improvement + EPSILON < min_impurity_decrease)) # min impurity is violated
# Add the node that was just split to the tree, with all the auxiliary information and
# get the `node_id` assigned to it.
node_id = tree._add_node(parent, is_left, is_leaf,
split.feature, split.threshold,
impurity, n_node_samples, weighted_n_node_samples,
impurity_val, n_node_samples_val, weighted_n_node_samples_val)
# Memory error
if node_id == SIZE_MAX:
rc = -1
break
# Store value for all nodes, to facilitate tree/model inspection and interpretation
splitter.node_value_val(tree.value + node_id * tree.value_stride)
# If we are in a linear moment case and we want to store the node jacobian and node precond,
# i.e. value = Jacobian^{-1} @ precond
if store_jac:
splitter.node_jacobian_val(tree.jac + node_id * tree.jac_stride)
splitter.node_precond_val(tree.precond + node_id * tree.precond_stride)
if not is_leaf:
# Push right child on stack
rc = stack.push(split.pos, end, split.pos_val, end_val, depth + 1, node_id, 0,
split.impurity_right, split.impurity_right_val, n_constant_features)
if rc == -1:
break
# Push left child on stack
rc = stack.push(start, split.pos, start_val, split.pos_val, depth + 1, node_id, 1,
split.impurity_left, split.impurity_left_val, n_constant_features)
if rc == -1:
break
if depth > max_depth_seen:
max_depth_seen = depth
# Resize the tree to use the minimal required memory
if rc >= 0:
rc = tree._resize_c(tree.node_count)
# Update trees max_depth variable for the maximum seen depth
if rc >= 0:
tree.max_depth = max_depth_seen
if rc == -1:
raise MemoryError()
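The build loop above replaces recursion with an explicit stack of (start, end, start_val, end_val, depth, parent, ...) records: pop a node, decide leaf-ness, ask the splitter for a split, push the two children. A condensed, hedged Python sketch of that control flow, with `split_fn` standing in for `splitter.node_split` (not an econml API):

def depth_first_build(n, n_val, split_fn, max_depth=10, min_samples_split=2):
    """Toy depth-first builder: pops node ranges off a stack, asks split_fn
    for (pos, pos_val) on the train/val ranges, and pushes the two children."""
    nodes = []                                  # (depth, start, end, start_val, end_val, is_leaf)
    stack = [(0, n, 0, n_val, 0)]               # root covers all train and all val samples
    while stack:
        start, end, start_val, end_val, depth = stack.pop()
        is_leaf = depth >= max_depth or end - start < min_samples_split
        pos = pos_val = None
        if not is_leaf:
            pos, pos_val = split_fn(start, end, start_val, end_val)
            is_leaf = pos is None               # no admissible split was found
        nodes.append((depth, start, end, start_val, end_val, is_leaf))
        if not is_leaf:
            stack.append((pos, end, pos_val, end_val, depth + 1))      # right child
            stack.append((start, pos, start_val, pos_val, depth + 1))  # left child
    return nodes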
# =============================================================================
# Tree
# =============================================================================
cdef class Tree:
# This is only used for compatibility with sklearn trees. In sklearn this represents the number of classes
# in a classification tree for each target output. Here it is always an array of 1's of size `self.n_outputs`.
property n_classes:
def __get__(self):
return sizet_ptr_to_ndarray(self.n_classes, self.n_outputs)
property children_left:
def __get__(self):
return self._get_node_ndarray()['left_child'][:self.node_count]
property children_right:
def __get__(self):
return self._get_node_ndarray()['right_child'][:self.node_count]
property depth:
def __get__(self):
return self._get_node_ndarray()['depth'][:self.node_count]
property n_leaves:
def __get__(self):
return np.sum(np.logical_and(
self.children_left == -1,
self.children_right == -1))
property feature:
def __get__(self):
return self._get_node_ndarray()['feature'][:self.node_count]
property threshold:
def __get__(self):
return self._get_node_ndarray()['threshold'][:self.node_count]
property impurity:
def __get__(self):
return self._get_node_ndarray()['impurity'][:self.node_count]
property n_node_samples:
def __get__(self):
return self._get_node_ndarray()['n_node_samples'][:self.node_count]
property weighted_n_node_samples:
def __get__(self):
return self._get_node_ndarray()['weighted_n_node_samples'][:self.node_count]
property impurity_train:
def __get__(self):
return self._get_node_ndarray()['impurity_train'][:self.node_count]
property n_node_samples_train:
def __get__(self):
return self._get_node_ndarray()['n_node_samples_train'][:self.node_count]
property weighted_n_node_samples_train:
def __get__(self):
return self._get_node_ndarray()['weighted_n_node_samples_train'][:self.node_count]
# Value returns the relevant parameters estimated at each node (the ones we care about)
property value:
def __get__(self):
return self._get_value_ndarray()[:self.node_count, :self.n_relevant_outputs]
# Value returns all the parameters estimated at each node (even the nuisance ones we don't care about)
property full_value:
def __get__(self):
return self._get_value_ndarray()[:self.node_count]
# The jacobian J(x) of the node, for the case of linear moment trees with moment: J(x) * theta(x) - precond(x) = 0
property jac:
def __get__(self):
if not self.store_jac:
raise AttributeError("Jacobian computation was not enabled. Set store_jac=True")
return self._get_jac_ndarray()[:self.node_count]
# The precond(x) of the node, for the case of linear moment trees with moment: J(x) * theta(x) - precond(x) = 0
property precond:
def __get__(self):
if not self.store_jac:
raise AttributeError("Preconditioned quantity computation was not enabled. Set store_jac=True")
return self._get_precond_ndarray()[:self.node_count]
def __cinit__(self, int n_features, int n_outputs, int n_relevant_outputs=-1, bint store_jac=False):
""" Initialize parameters
Parameters
----------
n_features : int
Number of features X at train time
n_outputs : int
How many parameters/outputs are stored/estimated at each node
n_relevant_outputs : int, optional (default=-1)
Which prefix of the parameters do we care about. The remainder are nuisance parameters.
If `n_relevant_outputs=-1`, then all parameters are relevant.
store_jac : bool, optional (default=False)
Whether we will be storing jacobian and precond of linear moments information at each node.
"""
self.n_features = n_features
self.n_outputs = n_outputs
self.n_relevant_outputs = n_relevant_outputs if n_relevant_outputs > 0 else n_outputs
self.value_stride = n_outputs
self.n_classes = NULL
safe_realloc(&self.n_classes, n_outputs)
self.max_n_classes = 1
cdef SIZE_t k
for k in range(n_outputs):
self.n_classes[k] = 1
# Inner structures
self.max_depth = 0
self.node_count = 0
self.capacity = 0
self.value = NULL
self.nodes = NULL
self.store_jac = store_jac
self.jac = NULL
self.jac_stride = n_outputs * n_outputs
self.precond = NULL
self.precond_stride = n_outputs
def __dealloc__(self):
"""Destructor."""
# Free all inner structures
free(self.value)
free(self.nodes)
if self.store_jac:
free(self.jac)
free(self.precond)
def __reduce__(self):
"""Reduce re-implementation, for pickling."""
return (Tree, (self.n_features, self.n_outputs,
self.n_relevant_outputs, self.store_jac), self.__getstate__())
def __getstate__(self):
"""Getstate re-implementation, for pickling."""
d = {}
# capacity is inferred during the __setstate__ using nodes
d['max_depth'] = self.max_depth
d["node_count"] = self.node_count
d["nodes"] = self._get_node_ndarray()
d["values"] = self._get_value_ndarray()
if self.store_jac:
d['jac'] = self._get_jac_ndarray()
d['precond'] = self._get_precond_ndarray()
return d
def __setstate__(self, d):
"""Setstate re-implementation, for unpickling."""
self.max_depth = d['max_depth']
self.node_count = d['node_count']
if 'nodes' not in d:
raise ValueError('You have loaded a Tree version which '
'cannot be imported')
node_ndarray = d['nodes']
value_ndarray = d['values']
value_shape = (node_ndarray.shape[0], self.n_outputs)
if (node_ndarray.ndim != 1 or
node_ndarray.dtype != NODE_DTYPE or
not node_ndarray.flags.c_contiguous or
value_ndarray.shape != value_shape or
not value_ndarray.flags.c_contiguous or
value_ndarray.dtype != np.float64):
raise ValueError('Did not recognise loaded array layout')
self.capacity = node_ndarray.shape[0]
if self._resize_c(self.capacity) != 0:
raise MemoryError("resizing tree to %d" % self.capacity)
nodes = memcpy(self.nodes, (<np.ndarray> node_ndarray).data,
self.capacity * sizeof(Node))
value = memcpy(self.value, (<np.ndarray> value_ndarray).data,
self.capacity * self.value_stride * sizeof(double))
if self.store_jac:
jac_ndarray = d['jac']
jac_shape = (node_ndarray.shape[0], self.n_outputs * self.n_outputs)
if (jac_ndarray.shape != jac_shape or
not jac_ndarray.flags.c_contiguous or
jac_ndarray.dtype != np.float64):
raise ValueError('Did not recognise loaded array layout')
jac = memcpy(self.jac, (<np.ndarray> jac_ndarray).data,
self.capacity * self.jac_stride * sizeof(double))
precond_ndarray = d['precond']
precond_shape = (node_ndarray.shape[0], self.n_outputs)
if (precond_ndarray.shape != precond_shape or
not precond_ndarray.flags.c_contiguous or
precond_ndarray.dtype != np.float64):
raise ValueError('Did not recognise loaded array layout')
precond = memcpy(self.precond, (<np.ndarray> precond_ndarray).data,
self.capacity * self.precond_stride * sizeof(double))
cdef int _resize(self, SIZE_t capacity) nogil except -1:
"""Resize all inner arrays to `capacity`, if `capacity` == -1, then
double the size of the inner arrays.
Returns -1 in case of failure to allocate memory (and raise MemoryError)
or 0 otherwise.
"""
if self._resize_c(capacity) != 0:
# Acquire gil only if we need to raise
with gil:
raise MemoryError()
cdef int _resize_c(self, SIZE_t capacity=SIZE_MAX) nogil except -1:
"""Guts of _resize
Returns -1 in case of failure to allocate memory (and raise MemoryError)
or 0 otherwise.
"""
if capacity == self.capacity and self.nodes != NULL:
return 0
if capacity == SIZE_MAX:
if self.capacity == 0:
capacity = 3 # default initial value
else:
capacity = 2 * self.capacity
safe_realloc(&self.nodes, capacity)
safe_realloc(&self.value, capacity * self.value_stride)
# value memory is initialised to 0 to enable classifier argmax
if capacity > self.capacity:
memset(<void*>(self.value + self.capacity * self.value_stride), 0,
(capacity - self.capacity) * self.value_stride *
sizeof(double))
if self.store_jac:
safe_realloc(&self.jac, capacity * self.jac_stride)
safe_realloc(&self.precond, capacity * self.precond_stride)
if capacity > self.capacity:
memset(<void*>(self.jac + self.capacity * self.jac_stride), 0,
(capacity - self.capacity) * self.jac_stride * sizeof(double))
memset(<void*>(self.precond + self.capacity * self.precond_stride), 0,
(capacity - self.capacity) * self.precond_stride * sizeof(double))
# if capacity smaller than node_count, adjust the counter
if capacity < self.node_count:
self.node_count = capacity
self.capacity = capacity
return 0
cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf,
SIZE_t feature, double threshold,
double impurity_train, SIZE_t n_node_samples_train,
double weighted_n_node_samples_train,
double impurity_val, SIZE_t n_node_samples_val,
double weighted_n_node_samples_val) nogil except -1:
"""Add a node to the tree.
The new node registers itself as the child of its parent.
Returns (size_t)(-1) on error.
"""
cdef SIZE_t node_id = self.node_count
if node_id >= self.capacity:
if self._resize_c() != 0:
return SIZE_MAX
cdef Node* node = &self.nodes[node_id]
node.impurity = impurity_val
node.n_node_samples = n_node_samples_val
node.weighted_n_node_samples = weighted_n_node_samples_val
node.impurity_train = impurity_train
node.n_node_samples_train = n_node_samples_train
node.weighted_n_node_samples_train = weighted_n_node_samples_train
if parent != _TREE_UNDEFINED:
if is_left:
self.nodes[parent].left_child = node_id
else:
self.nodes[parent].right_child = node_id
node.depth = self.nodes[parent].depth + 1
else:
node.depth = 0
if is_leaf:
node.left_child = _TREE_LEAF
node.right_child = _TREE_LEAF
node.feature = _TREE_UNDEFINED
node.threshold = _TREE_UNDEFINED
else:
# left_child and right_child will be set later
node.feature = feature
node.threshold = threshold
self.node_count += 1
return node_id
cpdef np.ndarray predict(self, object X):
"""Predict target for X."""
out = self._get_value_ndarray().take(self.apply(X), axis=0,
mode='clip')[:, :self.n_relevant_outputs, 0]
return out
cpdef np.ndarray predict_full(self, object X):
"""Predict target for X."""
out = self._get_value_ndarray().take(self.apply(X), axis=0,
mode='clip')[:, :, 0]
return out
cpdef np.ndarray predict_jac(self, object X):
"""Predict target for X."""
if not self.store_jac:
raise AttributeError("Jacobian computation was not enalbed. Set store_jac=True")
out = self._get_jac_ndarray().take(self.apply(X), axis=0,
mode='clip')
return out
cpdef np.ndarray predict_precond(self, object X):
"""Predict target for X."""
if not self.store_jac:
raise AttributeError("Preconditioned quantity computation was not enalbed. Set store_jac=True")
out = self._get_precond_ndarray().take(self.apply(X), axis=0,
mode='clip')
return out
cpdef predict_precond_and_jac(self, object X):
if not self.store_jac:
raise AttributeError("Preconditioned quantity computation was not enalbed. Set store_jac=True")
leafs = self.apply(X)
precond = self._get_precond_ndarray().take(leafs, axis=0,
mode='clip')
jac = self._get_jac_ndarray().take(leafs, axis=0,
mode='clip')
return precond, jac
cpdef np.ndarray apply(self, object X):
return self._apply(X)
cdef inline np.ndarray _apply(self, object X):
# Check input
if not isinstance(X, np.ndarray):
raise ValueError("X should be in np.ndarray format, got %s"
% type(X))
if X.dtype != DTYPE:
raise ValueError("X.dtype should be np.float64, got %s" % X.dtype)
# Extract input
cdef const DTYPE_t[:, :] X_ndarray = X
cdef SIZE_t n_samples = X.shape[0]
# Initialize output
cdef np.ndarray[SIZE_t] out = np.zeros((n_samples,), dtype=np.intp)
cdef SIZE_t* out_ptr = <SIZE_t*> out.data
# Initialize auxiliary data-structure
cdef Node* node = NULL
cdef SIZE_t i = 0
with nogil:
for i in range(n_samples):
node = self.nodes
# While node not a leaf
while node.left_child != _TREE_LEAF:
# ... and node.right_child != _TREE_LEAF:
if X_ndarray[i, node.feature] <= node.threshold:
node = &self.nodes[node.left_child]
else:
node = &self.nodes[node.right_child]
out_ptr[i] = <SIZE_t>(node - self.nodes) # node offset
return out
cpdef object decision_path(self, object X):
"""Finds the decision path (=node) for each sample in X."""
return self._decision_path(X)
cdef inline object _decision_path(self, object X):
"""Finds the decision path (=node) for each sample in X."""
# Check input
if not isinstance(X, np.ndarray):
raise ValueError("X should be in np.ndarray format, got %s"
% type(X))
if X.dtype != DTYPE:
raise ValueError("X.dtype should be np.float64, got %s" % X.dtype)
# Extract input
cdef const DTYPE_t[:, :] X_ndarray = X
cdef SIZE_t n_samples = X.shape[0]
# Initialize output
cdef np.ndarray[SIZE_t] indptr = np.zeros(n_samples + 1, dtype=np.intp)
cdef SIZE_t* indptr_ptr = <SIZE_t*> indptr.data
cdef np.ndarray[SIZE_t] indices = np.zeros(n_samples *
(1 + self.max_depth),
dtype=np.intp)
cdef SIZE_t* indices_ptr = <SIZE_t*> indices.data
# Initialize auxiliary data-structure
cdef Node* node = NULL
cdef SIZE_t i = 0
with nogil:
for i in range(n_samples):
node = self.nodes
indptr_ptr[i + 1] = indptr_ptr[i]
# Add all internal nodes on the path
while node.left_child != _TREE_LEAF:
# ... and node.right_child != _TREE_LEAF:
indices_ptr[indptr_ptr[i + 1]] = <SIZE_t>(node - self.nodes)
indptr_ptr[i + 1] += 1
if X_ndarray[i, node.feature] <= node.threshold:
node = &self.nodes[node.left_child]
else:
node = &self.nodes[node.right_child]
# Add the leaf node
indices_ptr[indptr_ptr[i + 1]] = <SIZE_t>(node - self.nodes)
indptr_ptr[i + 1] += 1
indices = indices[:indptr[n_samples]]
cdef np.ndarray[SIZE_t] data = np.ones(shape=len(indices),
dtype=np.intp)
out = csr_matrix((data, indices, indptr),
shape=(n_samples, self.node_count))
return out
cpdef compute_feature_importances(self, normalize=True, max_depth=None, depth_decay=.0):
"""Computes the importance of each feature (aka variable) based on impurity decrease.
Parameters
----------
normalize : bool, optional (default=True)
Whether to normalize importances to sum to 1
max_depth : int or None, optional (default=None)
The max depth of a split to consider when calculating importance
depth_decay : float, optional (default=.0)
The decay of the importance of a split as a function of depth. The split importance is
re-weighted by 1 / (1 + depth)**depth_decay.
"""
cdef Node* left
cdef Node* right
cdef Node* nodes = self.nodes
cdef Node* node = nodes
cdef Node* end_node = node + self.node_count
cdef double c_depth_decay = depth_decay
cdef SIZE_t c_max_depth
cdef double normalizer = 0.
cdef np.ndarray[np.float64_t, ndim=1] importances
importances = np.zeros((self.n_features,))
cdef DOUBLE_t* importance_data = <DOUBLE_t*>importances.data
if max_depth is None:
c_max_depth = self.max_depth
else:
c_max_depth = max_depth
with nogil:
while node != end_node:
if node.left_child != _TREE_LEAF:
# ... and node.right_child != _TREE_LEAF:
if (max_depth is None) or node.depth <= c_max_depth:
left = &nodes[node.left_child]
right = &nodes[node.right_child]
importance_data[node.feature] += pow(1 + node.depth, -c_depth_decay) * (
node.weighted_n_node_samples * node.impurity -
left.weighted_n_node_samples * left.impurity -
right.weighted_n_node_samples * right.impurity)
node += 1
importances /= nodes[0].weighted_n_node_samples
if normalize:
normalizer = np.sum(importances)
if normalizer > 0.0:
# Avoid dividing by zero (e.g., when root is pure)
importances /= normalizer
return importances
cpdef compute_feature_heterogeneity_importances(self, normalize=True, max_depth=None, depth_decay=.0):
"""Computes the importance of each feature (aka variable) based on amount of
parameter heterogeneity it creates. Each split adds:
parent_weight * (left_weight * right_weight) * mean((value_left[k] - value_right[k])**2) / parent_weight**2
Parameters
----------
normalize : bool, optional (default=True)
Whether to normalize importances to sum to 1
max_depth : int or None, optional (default=None)
The max depth of a split to consider when calculating importance
depth_decay : float, optional (default=.0)
The decay of the importance of a split as a function of depth. The split importance is
re-weighted by 1 / (1 + depth)**depth_decay.
"""
cdef Node* left
cdef Node* right
cdef Node* nodes = self.nodes
cdef Node* node = nodes
cdef Node* end_node = node + self.node_count
cdef double c_depth_decay = depth_decay
cdef SIZE_t c_max_depth
cdef SIZE_t i
cdef double normalizer = 0.
cdef np.ndarray[np.float64_t, ndim=1] importances
importances = np.zeros((self.n_features,))
cdef DOUBLE_t* importance_data = <DOUBLE_t*>importances.data
if max_depth is None:
c_max_depth = self.max_depth
else:
c_max_depth = max_depth
with nogil:
while node != end_node:
if node.left_child != _TREE_LEAF:
if (max_depth is None) or node.depth <= c_max_depth:
# ... and node.right_child != _TREE_LEAF:
left = &nodes[node.left_child]
right = &nodes[node.right_child]
# node_value = &self.value[(node - nodes) * self.value_stride]
left_value = &self.value[(left - nodes) * self.value_stride]
right_value = &self.value[(right - nodes) * self.value_stride]
for i in range(self.n_relevant_outputs):
importance_data[node.feature] += pow(1 + node.depth, -c_depth_decay) * (
left.weighted_n_node_samples * right.weighted_n_node_samples *
(left_value[i] - right_value[i])**2 / node.weighted_n_node_samples)
node += 1
importances /= (nodes[0].weighted_n_node_samples * self.n_relevant_outputs)
if normalize:
normalizer = np.sum(importances)
if normalizer > 0.0:
# Avoid dividing by zero (e.g., when root is pure)
importances /= normalizer
return importances
cdef np.ndarray _get_value_ndarray(self):
"""Wraps value as a 3-d NumPy array.
The array keeps a reference to this Tree, which manages the underlying
memory.
"""
# we make it a 3d array even though we only need 2d, for compatibility with sklearn
# plotting of trees.
cdef np.npy_intp shape[3]
shape[0] = <np.npy_intp> self.node_count
shape[1] = <np.npy_intp> self.n_outputs
shape[2] = 1
cdef np.ndarray arr
arr = np.PyArray_SimpleNewFromData(3, shape, np.NPY_DOUBLE, self.value)
Py_INCREF(self)
arr.base = <PyObject*> self
return arr
cdef np.ndarray _get_jac_ndarray(self):
"""Wraps jacobian as a 2-d NumPy array.
The array keeps a reference to this Tree, which manages the underlying
memory.
"""
cdef np.npy_intp shape[2]
shape[0] = <np.npy_intp> self.node_count
shape[1] = <np.npy_intp> (self.n_outputs * self.n_outputs)
cdef np.ndarray arr
arr = np.PyArray_SimpleNewFromData(2, shape, np.NPY_DOUBLE, self.jac)
Py_INCREF(self)
arr.base = <PyObject*> self
return arr
cdef np.ndarray _get_precond_ndarray(self):
"""Wraps precond as a 2-d NumPy array.
The array keeps a reference to this Tree, which manages the underlying
memory.
"""
cdef np.npy_intp shape[2]
shape[0] = <np.npy_intp> self.node_count
shape[1] = <np.npy_intp> self.n_outputs
cdef np.ndarray arr
arr = np.PyArray_SimpleNewFromData(2, shape, np.NPY_DOUBLE, self.precond)
Py_INCREF(self)
arr.base = <PyObject*> self
return arr
cdef np.ndarray _get_node_ndarray(self):
"""Wraps nodes as a NumPy struct array.
The array keeps a reference to this Tree, which manages the underlying
memory. Individual fields are publicly accessible as properties of the
Tree.
"""
cdef np.npy_intp shape[1]
shape[0] = <np.npy_intp> self.node_count
cdef np.npy_intp strides[1]
strides[0] = sizeof(Node)
cdef np.ndarray arr
Py_INCREF(NODE_DTYPE)
arr = PyArray_NewFromDescr(<PyTypeObject *> np.ndarray,
<np.dtype> NODE_DTYPE, 1, shape,
strides, <void*> self.nodes,
np.NPY_DEFAULT, None)
Py_INCREF(self)
arr.base = <PyObject*> self
return arr

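For intuition, the split-level rule implemented in compute_feature_heterogeneity_importances above can be sketched in plain NumPy. The dict-based node layout and the helper below are hypothetical stand-ins for the C Node struct, and the max_depth filter is omitted for brevity:

# Hypothetical sketch: mirrors the per-split credit used by
# compute_feature_heterogeneity_importances, with nodes given as plain dicts
# instead of the C Node struct (max_depth filtering omitted for brevity).
import numpy as np

def heterogeneity_importances(nodes, n_features, n_relevant_outputs,
                              depth_decay=0.0, normalize=True):
    importances = np.zeros(n_features)
    root_weight = nodes[0]['weighted_n_node_samples']
    for node in nodes:
        if node['is_leaf']:
            continue
        diff = node['left_value'] - node['right_value']      # length n_relevant_outputs
        importances[node['feature']] += (1 + node['depth']) ** (-depth_decay) * (
            node['left_weight'] * node['right_weight'] *
            np.sum(diff ** 2) / node['weighted_n_node_samples'])
    importances /= root_weight * n_relevant_outputs
    if normalize and importances.sum() > 0.0:
        importances /= importances.sum()
    return importances

In contrast to the impurity-based importances above, this credits a split by how far apart the fitted parameters of its two children are, scaled by how balanced the split is.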
85
econml/tree/_utils.pxd Normal file

@ -0,0 +1,85 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# This code is a fork from: https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_utils.pxd
# published under the following license and copyright:
# BSD 3-Clause License
#
# Copyright (c) 2007-2020 The scikit-learn developers.
# All rights reserved.
import numpy as np
cimport numpy as np
from ._tree cimport Node
ctypedef np.npy_float64 DTYPE_t # Type of X
ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight
ctypedef np.npy_intp SIZE_t # Type for indices and counters
ctypedef np.npy_int32 INT32_t # Signed 32 bit integer
ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer
cdef enum:
# Max value for our rand_r replacement (near the bottom).
# We don't use RAND_MAX because it's different across platforms and
# particularly tiny on Windows/MSVC.
RAND_R_MAX = 0x7FFFFFFF
# safe_realloc(&p, n) resizes the allocation of p to n * sizeof(*p) bytes or
# raises a MemoryError. It never calls free, since that's __dealloc__'s job.
# cdef DTYPE_t *p = NULL
# safe_realloc(&p, n)
# is equivalent to p = malloc(n * sizeof(*p)) with error checking.
ctypedef fused realloc_ptr:
# Add pointer types here as needed.
# (DTYPE_t*)
(SIZE_t*)
(unsigned char*)
(DOUBLE_t*)
(DOUBLE_t**)
(Node*)
(StackRecord*)
cdef realloc_ptr safe_realloc(realloc_ptr* p, SIZE_t nelems) nogil except *
cdef np.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size)
cdef SIZE_t rand_int(SIZE_t low, SIZE_t high,
UINT32_t* random_state) nogil
cdef double rand_uniform(double low, double high,
UINT32_t* random_state) nogil
cdef double log(double x) nogil
# =============================================================================
# Stack data structure
# =============================================================================
# A record on the stack for depth-first tree growing
cdef struct StackRecord:
SIZE_t start
SIZE_t end
SIZE_t start_val
SIZE_t end_val
SIZE_t depth
SIZE_t parent
bint is_left
double impurity
double impurity_val
SIZE_t n_constant_features
cdef class Stack:
cdef SIZE_t capacity
cdef SIZE_t top
cdef StackRecord* stack_
cdef bint is_empty(self) nogil
cdef int push(self, SIZE_t start, SIZE_t end, SIZE_t start_val, SIZE_t end_val,
SIZE_t depth, SIZE_t parent,
bint is_left, double impurity, double impurity_val,
SIZE_t n_constant_features) nogil except -1
cdef int pop(self, StackRecord* res) nogil

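A rough Python rendering of the depth-first growth loop this Stack is designed to drive; split_node and the record layout are hypothetical placeholders for the builder logic, shown only to illustrate why each StackRecord carries both a train range (start/end) and an honest validation range (start_val/end_val):

# Hypothetical sketch of the depth-first growth loop the Stack supports.
# split_node is a placeholder for the tree builder's split evaluation; each
# record mirrors StackRecord, carrying a train range and an honest val range.
def grow_depth_first(split_node, n_train, n_val, max_depth):
    stack = [dict(start=0, end=n_train, start_val=0, end_val=n_val,
                  depth=0, parent=None, is_left=False)]
    while stack:
        rec = stack.pop()                      # LIFO, mirrors Stack.pop
        node_id, split = split_node(rec)       # split is None when rec becomes a leaf
        if split is None or rec['depth'] >= max_depth:
            continue
        child = dict(depth=rec['depth'] + 1, parent=node_id)
        # Push the right child first so the left child is processed next (LIFO order).
        stack.append(dict(child, is_left=False,
                          start=split['pos'], end=rec['end'],
                          start_val=split['pos_val'], end_val=rec['end_val']))
        stack.append(dict(child, is_left=True,
                          start=rec['start'], end=split['pos'],
                          start_val=rec['start_val'], end_val=split['pos_val']))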
173
econml/tree/_utils.pyx Normal file

@ -0,0 +1,173 @@
# cython: cdivision=True
# cython: boundscheck=False
# cython: wraparound=False
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# This code is a fork from: https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_utils.pyx
# published under the following license and copyright:
# BSD 3-Clause License
#
# Copyright (c) 2007-2020 The scikit-learn developers.
# All rights reserved.
from libc.stdlib cimport free
from libc.stdlib cimport malloc
from libc.stdlib cimport realloc
from libc.math cimport log as ln
import numpy as np
cimport numpy as np
np.import_array()
cdef inline UINT32_t DEFAULT_SEED = 1
cdef inline double LN_TWO = ln(2.0)
# =============================================================================
# Helper functions
# =============================================================================
# rand_r replacement using a 32bit XorShift generator
# See http://www.jstatsoft.org/v08/i14/paper for details
cdef inline UINT32_t our_rand_r(UINT32_t* seed) nogil:
"""Generate a pseudo-random np.uint32 from a np.uint32 seed"""
# seed shouldn't ever be 0.
if (seed[0] == 0): seed[0] = DEFAULT_SEED
seed[0] ^= <UINT32_t>(seed[0] << 13)
seed[0] ^= <UINT32_t>(seed[0] >> 17)
seed[0] ^= <UINT32_t>(seed[0] << 5)
# Note: we must be careful with the final line cast to np.uint32 so that
# the function behaves consistently across platforms.
#
# The following cast might yield different results on different platforms:
# wrong_cast = <UINT32_t> RAND_R_MAX + 1
#
# We can use:
# good_cast = <UINT32_t>(RAND_R_MAX + 1)
# or:
# cdef np.uint32_t another_good_cast = <UINT32_t>RAND_R_MAX + 1
return seed[0] % <UINT32_t>(RAND_R_MAX + 1)
cdef realloc_ptr safe_realloc(realloc_ptr* p, SIZE_t nelems) nogil except *:
# sizeof(realloc_ptr[0]) would be more like idiomatic C, but causes Cython
# 0.20.1 to crash.
cdef SIZE_t nbytes = nelems * sizeof(p[0][0])
if nbytes / sizeof(p[0][0]) != nelems:
# Overflow in the multiplication
with gil:
raise MemoryError("could not allocate (%d * %d) bytes"
% (nelems, sizeof(p[0][0])))
cdef realloc_ptr tmp = <realloc_ptr>realloc(p[0], nbytes)
if tmp == NULL:
with gil:
raise MemoryError("could not allocate %d bytes" % nbytes)
p[0] = tmp
return tmp # for convenience
cdef inline np.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size):
"""Return copied data as 1D numpy array of intp's."""
cdef np.npy_intp shape[1]
shape[0] = <np.npy_intp> size
return np.PyArray_SimpleNewFromData(1, shape, np.NPY_INTP, data).copy()
cdef inline SIZE_t rand_int(SIZE_t low, SIZE_t high,
UINT32_t* random_state) nogil:
"""Generate a random integer in [low; end)."""
return low + our_rand_r(random_state) % (high - low)
cdef inline double rand_uniform(double low, double high,
UINT32_t* random_state) nogil:
"""Generate a random double in [low; high)."""
return ((high - low) * <double> our_rand_r(random_state) /
<double> RAND_R_MAX) + low
cdef inline double log(double x) nogil:
return ln(x) / LN_TWO
# =============================================================================
# Stack data structure
# =============================================================================
cdef class Stack:
"""A LIFO data structure.
Attributes
----------
capacity : SIZE_t
The number of elements the stack can hold; if more are added,
``self.stack_`` needs to be resized.
top : SIZE_t
The number of elements currently on the stack.
stack : StackRecord pointer
The stack of records (upward in the stack corresponds to the right).
"""
def __cinit__(self, SIZE_t capacity):
self.capacity = capacity
self.top = 0
self.stack_ = <StackRecord*> malloc(capacity * sizeof(StackRecord))
def __dealloc__(self):
free(self.stack_)
cdef bint is_empty(self) nogil:
return self.top <= 0
cdef int push(self, SIZE_t start, SIZE_t end, SIZE_t start_val, SIZE_t end_val,
SIZE_t depth, SIZE_t parent,
bint is_left, double impurity, double impurity_val,
SIZE_t n_constant_features) nogil except -1:
"""Push a new element onto the stack.
Return -1 in case of failure to allocate memory (and raise MemoryError)
or 0 otherwise.
"""
cdef SIZE_t top = self.top
cdef StackRecord* stack = NULL
# Resize if capacity not sufficient
if top >= self.capacity:
self.capacity *= 2
# Since safe_realloc can raise MemoryError, use `except -1`
safe_realloc(&self.stack_, self.capacity)
stack = self.stack_
stack[top].start = start
stack[top].end = end
stack[top].start_val = start_val
stack[top].end_val = end_val
stack[top].depth = depth
stack[top].parent = parent
stack[top].is_left = is_left
stack[top].impurity = impurity
stack[top].impurity_val = impurity_val
stack[top].n_constant_features = n_constant_features
# Increment stack pointer
self.top = top + 1
return 0
cdef int pop(self, StackRecord* res) nogil:
"""Remove the top element from the stack and copy to ``res``.
Returns 0 if pop was successful (and ``res`` is set); -1
otherwise.
"""
cdef SIZE_t top = self.top
cdef StackRecord* stack = self.stack_
if top <= 0:
return -1
res[0] = stack[top - 1]
self.top = top - 1
return 0

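For reference, a pure-Python transcription of our_rand_r and rand_int above (a sketch for checking the XorShift update, not used anywhere in the library):

# Pure-Python transcription of the 32-bit XorShift update above, for checking only.
RAND_R_MAX = 0x7FFFFFFF
MASK32 = 0xFFFFFFFF

def our_rand_r(state):            # `state` is a one-element list standing in for UINT32_t*
    s = state[0] or 1             # the seed must never be 0 (DEFAULT_SEED = 1)
    s ^= (s << 13) & MASK32
    s ^= s >> 17
    s ^= (s << 5) & MASK32
    state[0] = s & MASK32
    return state[0] % (RAND_R_MAX + 1)

def rand_int(low, high, state):
    """Random integer in [low, high), as in the Cython rand_int."""
    return low + our_rand_r(state) % (high - low)

state = [123]
print([rand_int(0, 10, state) for _ in range(5)])   # deterministic given the seed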

@ -9,7 +9,7 @@ from sklearn import clone
from sklearn.linear_model import LinearRegression
from .utilities import shape, transpose, reshape, cross_product, ndim, size,\
_deprecate_positional, check_input_arrays
from .cate_estimator import BaseCateEstimator, LinearCateEstimator
from ._cate_estimator import BaseCateEstimator, LinearCateEstimator
from numpy.polynomial.hermite_e import hermeval
from sklearn.base import TransformerMixin
from sklearn.preprocessing import PolynomialFeatures


@ -1,85 +0,0 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from itertools import product
from sklearn.linear_model import Lasso, LassoCV, LinearRegression
import econml.dml
import econml.dgp
import unittest
########################################
# Core DML Tests
########################################
class TestDMLMethods(unittest.TestCase):
def test_dml_effect(self):
"""Testing dml.effect."""
np.random.seed(123)
# How many samples
n_samples = 200
# How many control features
n_cov = 5
# How many treatment variables
n_treatments = 10
for exp in range(100):
# Coefficients of how controls affect treatments
Alpha = 20 * np.random.rand(n_cov, n_treatments) - 10
# Coefficients of how controls affect outcome
beta = 20 * np.random.rand(n_cov) - 10
# Treatment effects that we want to estimate
effect = 20 * np.random.rand(n_treatments) - 10
y, T, X, epsilon = dgp.dgp_perfect_data_multiple_treatments(
n_samples, n_cov, n_treatments, Alpha, beta, effect)
# Run dml estimation
reg = dml.LinearDML(np.arange(X.shape[1]), [], np.arange(X.shape[1], X.shape[1] + T.shape[1]))
reg.fit(np.concatenate((X, T), axis=1), y)
T0 = np.zeros((1, T.shape[1]))
T1 = np.zeros((1, T.shape[1]))
dml_coef = np.zeros(T.shape[1])
for t in range(T.shape[1]):
T1[:, t] = 1
dml_coef[t] = reg.effect([], T0, T1)
T1[:, t] = 0
self.assertTrue(np.max(np.abs(dml_coef - effect)) < 0.0000000001, "core.double_ml() wrong")
def test_dml_predict(self):
"""Testing dml.predict."""
np.random.seed(123)
# How many samples
n_samples = 200
# How many control features
n_cov = 5
# How many treatment variables
n_treatments = 10
for exp in range(100):
# Coefficients of how controls affect treatments
Alpha = 20 * np.random.rand(n_cov, n_treatments) - 10
# Coefficients of how controls affect outcome
beta = 20 * np.random.rand(n_cov) - 10
# Treatment effects that we want to estimate
effect = 20 * np.random.rand(n_treatments) - 10
y, T, X, epsilon = dgp.dgp_perfect_data_multiple_treatments(
n_samples, n_cov, n_treatments, Alpha, beta, effect)
# Run dml estimation
reg = dml.LinearDML(np.arange(X.shape[1]), [], np.arange(X.shape[1], X.shape[1] + T.shape[1]))
reg.fit(np.concatenate((X, T), axis=1), y)
y, T, X = dgp.dgp_perfect_counterfactual_data_multiple_treatments(
n_samples, n_cov, beta, effect, np.ones(n_treatments))
r2score = reg.score(np.concatenate((X, T), axis=1), y)
self.assertTrue(r2score > 0.99, "core.double_ml() wrong")
if __name__ == '__main__':
unittest.main(verbosity=2)


@ -508,7 +508,7 @@ def check_inputs(Y, T, X, W=None, multi_output_T=True, multi_output_Y=True):
X, T = check_X_y(X, T, multi_output=multi_output_T, y_numeric=True)
_, Y = check_X_y(X, Y, multi_output=multi_output_Y, y_numeric=True)
if W is not None:
W, _ = check_X_y(W, Y)
W, _ = check_X_y(W, Y, multi_output=multi_output_Y, y_numeric=True)
return Y, T, X, W

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long


@ -948,7 +948,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.7.1"
}
},
"nbformat": 4,

File diff suppressed because one or more lines are too long


@ -768,7 +768,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.7.1"
}
},
"nbformat": 4,

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

17
pyproject.toml Normal file

@ -0,0 +1,17 @@
[build-system]
requires = [
"setuptools",
"wheel",
"Cython",
"numpy == 1.19.3",
"scipy"
]
build-backend = "setuptools.build_meta"
[tool.pytest.ini_options]
addopts = "--junitxml=junit/test-results.xml -n auto --strict-markers --cov-config=setup.cfg --cov=econml --cov-report=xml"
markers = [
"slow",
"notebook",
"automl"
]


@ -36,10 +36,11 @@ setup_requires =
pytest-runner
sphinx < 3.2
sphinx_rtd_theme
Cython
install_requires =
numpy
scipy != 1.4.0
scikit-learn > 0.21.0, <0.24.0
scikit-learn > 0.21.0
keras < 2.4
sparse
tensorflow > 1.10, < 2.3
@ -57,6 +58,7 @@ tests_require =
pytest-cov
jupyter
nbconvert < 6
nbformat
seaborn
lightgbm
dowhy
@ -78,13 +80,6 @@ include =
; include all CSV files as data
* = *.csv
[tool:pytest]
addopts = --junitxml=junit/test-results.xml -n auto --strict --cov-config=setup.cfg --cov=econml --cov-report=xml
markers =
slow
notebook
automl
; coverage configuration
[coverage:run]
omit = econml/tests/*


@ -1,4 +1,11 @@
from setuptools import setup, find_packages
from setuptools import setup
from setuptools.extension import Extension
from Cython.Build import cythonize
import numpy as np
# configuration is all pulled from setup.cfg
setup()
setup(ext_modules=cythonize([Extension("*", ["**/*.pyx"],
include_dirs=[np.get_include()])],
language_level="3"),
zip_safe=False)
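
The glob-based Extension("*", ["**/*.pyx"]) pattern lets cythonize pick up every Cython module in the tree (e.g. econml/tree/_tree.pyx and econml/tree/_utils.pyx from this change) without listing them by hand. Purely as an illustration, the explicit equivalent for a single module would look roughly like this:

# Illustrative only: roughly what the glob pattern expands to for one module.
from setuptools import setup
from setuptools.extension import Extension
from Cython.Build import cythonize
import numpy as np

ext = Extension("econml.tree._tree",               # dotted module name
                ["econml/tree/_tree.pyx"],         # matching Cython source
                include_dirs=[np.get_include()])   # NumPy C headers for `cimport numpy`

setup(ext_modules=cythonize([ext], language_level="3"),
      zip_safe=False)

With this in place, pip install . (or python setup.py build_ext --inplace during development) compiles the extensions using the build requirements pinned in pyproject.toml.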