mirror of https://github.com/py-why/EconML.git
1754 lines
313 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"<table border=\"0\">\n",
|
||
" <tr>\n",
|
||
" <td>\n",
|
||
" <img src=\"https://ictd2016.files.wordpress.com/2016/04/microsoft-research-logo-copy.jpg\" style=\"width 30px;\" />\n",
|
||
" </td>\n",
|
||
" <td>\n",
|
||
" <img src=\"https://www.microsoft.com/en-us/research/wp-content/uploads/2016/12/MSR-ALICE-HeaderGraphic-1920x720_1-800x550.jpg\" style=\"width 100px;\"/></td>\n",
|
||
" </tr>\n",
|
||
"</table>"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"# Orthogonal Random Forest and Causal Forest: Use Cases and Examples\n",
"\n",
"Causal Forests and Generalized Random Forests are flexible methods for estimating treatment effect heterogeneity with random forests. Orthogonal Random Forest (ORF) combines generalized random forests with orthogonalization, a technique that effectively removes the confounding effect in two-stage estimation. Because of this orthogonalization, the ORF performs especially well in the presence of high-dimensional confounders. For more details, see [this paper](https://arxiv.org/abs/1806.03467) or the [EconML documentation](https://econml.azurewebsites.net/).\n",
"\n",
"The EconML SDK implements the following forest-based estimators:\n",
"\n",
"* DMLOrthoForest: suitable for continuous or discrete treatments\n",
"\n",
"* DROrthoForest: suitable for discrete treatments\n",
"\n",
"* CausalForestDML: suitable for both discrete and continuous treatments\n",
"\n",
"In this notebook, we show the performance of the ORF on synthetic and observational data.\n",
"\n",
"## Notebook Contents\n",
"\n",
"1. [Example Usage with Continuous Treatment Synthetic Data](#1.-Example-Usage-with-Continuous-Treatment-Synthetic-Data)\n",
"2. [Example Usage with Binary Treatment Synthetic Data](#2.-Example-Usage-with-Binary-Treatment-Synthetic-Data)\n",
"3. [Example Usage with Multiple Treatment Synthetic Data](#3.-Example-Usage-with-Multiple-Treatment-Synthetic-Data)\n",
"4. [Example Usage with Real Continuous Treatment Observational Data](#4.-Example-Usage-with-Real-Continuous-Treatment-Observational-Data)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import econml"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Main imports\n",
|
||
"from econml.orf import DMLOrthoForest, DROrthoForest\n",
|
||
"from econml.dml import CausalForestDML\n",
|
||
"from econml.sklearn_extensions.linear_model import WeightedLassoCVWrapper, WeightedLasso, WeightedLassoCV\n",
|
||
"\n",
|
||
"# Helper imports\n",
|
||
"import numpy as np\n",
|
||
"from itertools import product\n",
|
||
"from sklearn.linear_model import Lasso, LassoCV, LogisticRegression, LogisticRegressionCV\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"\n",
|
||
"%matplotlib inline"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# 1. Example Usage with Continuous Treatment Synthetic Data"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"## 1.1. DGP\n",
"We use the data generating process (DGP) from the [ORF paper](https://arxiv.org/abs/1806.03467). The DGP is described by the following equations:\n",
"\n",
"\\begin{align}\n",
"T =& \\langle W, \\beta\\rangle + \\eta, & \\;\\eta \\sim \\text{Uniform}(-1, 1)\\\\\n",
"Y =& T\\cdot \\theta(X) + \\langle W, \\gamma\\rangle + \\epsilon, &\\; \\epsilon \\sim \\text{Uniform}(-1, 1)\\\\\n",
"W \\sim& \\text{Normal}(0,\\, I_{n_w})\\\\\n",
"X \\sim& \\text{Uniform}(0,1)^{n_x}\n",
"\\end{align}\n",
"\n",
"where $W$ is a matrix of high-dimensional confounders and the coefficient vectors $\\beta, \\gamma$ are sparse.\n",
"\n",
"For this DGP,\n",
"\\begin{align}\n",
"\\theta(x) = \\exp(2\\cdot x_1).\n",
"\\end{align}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Treatment effect function\n",
"def exp_te(x):\n",
"    return np.exp(2*x[0])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# DGP constants\n",
|
||
"np.random.seed(123)\n",
|
||
"n = 1000\n",
|
||
"n_w = 30\n",
|
||
"support_size = 5\n",
|
||
"n_x = 1\n",
|
||
"# Outcome support\n",
|
||
"support_Y = np.random.choice(range(n_w), size=support_size, replace=False)\n",
|
||
"coefs_Y = np.random.uniform(0, 1, size=support_size)\n",
|
||
"epsilon_sample = lambda n: np.random.uniform(-1, 1, size=n)\n",
|
||
"# Treatment support \n",
|
||
"support_T = support_Y\n",
|
||
"coefs_T = np.random.uniform(0, 1, size=support_size)\n",
|
||
"eta_sample = lambda n: np.random.uniform(-1, 1, size=n) \n",
|
||
"\n",
|
||
"# Generate controls, covariates, treatments and outcomes\n",
|
||
"W = np.random.normal(0, 1, size=(n, n_w))\n",
|
||
"X = np.random.uniform(0, 1, size=(n, n_x))\n",
|
||
"# Heterogeneous treatment effects\n",
|
||
"TE = np.array([exp_te(x_i) for x_i in X])\n",
|
||
"T = np.dot(W[:, support_T], coefs_T) + eta_sample(n)\n",
|
||
"Y = TE * T + np.dot(W[:, support_Y], coefs_Y) + epsilon_sample(n)\n",
|
||
"\n",
|
||
"# ORF parameters and test data\n",
|
||
"subsample_ratio = 0.3\n",
|
||
"lambda_reg = np.sqrt(np.log(n_w) / (10 * subsample_ratio * n))\n",
|
||
"X_test = np.array(list(product(np.arange(0, 1, 0.01), repeat=n_x)))"
|
||
]
|
||
},
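{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Lasso penalty `lambda_reg` defined in the cell above grows with the number of controls and shrinks as the per-tree subsample gets larger. As a minimal worked example, the cell below recomputes it from the same constants. The names with a `_demo` suffix are introduced here for illustration only, and reading `subsample_ratio * n` as the nominal subsample size per tree is an assumption."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Constants copied from the DGP cell above\n",
"n_demo, n_w_demo, subsample_ratio_demo = 1000, 30, 0.3\n",
"# Nominal number of observations in each subsample (assumption, see note above)\n",
"nominal_subsample = subsample_ratio_demo * n_demo\n",
"# Same formula as lambda_reg: sqrt(log(n_w) / (10 * subsample_ratio * n))\n",
"lambda_reg_demo = np.sqrt(np.log(n_w_demo) / (10 * nominal_subsample))\n",
"print(round(lambda_reg_demo, 4))"
]
},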
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 1.2. Train Estimator\n",
|
||
"\n",
|
||
"**Note:** The models in the final stage of the estimation (``model_T_final``, ``model_Y_final``) need to support sample weighting. \n",
|
||
"\n",
"If the models of choice do not support sample weights (e.g. ``sklearn.linear_model.LassoCV``), the ``econml`` package provides a convenient wrapper, ``WeightedModelWrapper``, that adds sample-weight support to these models."
|
||
]
|
||
},
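{
"cell_type": "markdown",
"metadata": {},
"source": [
"Whether a given scikit-learn estimator accepts sample weights can be checked before it is passed as a final-stage model. The sketch below uses `sklearn.utils.validation.has_fit_parameter`. The two estimators are arbitrary examples chosen for illustration and are not used elsewhere in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.neighbors import KNeighborsRegressor\n",
"from sklearn.utils.validation import has_fit_parameter\n",
"\n",
"# Report, for each example estimator, whether fit() takes a sample_weight argument\n",
"for model in (LinearRegression(), KNeighborsRegressor()):\n",
"    supports = has_fit_parameter(model, 'sample_weight')\n",
"    print(type(model).__name__, 'accepts sample_weight:', supports)"
]
},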
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"est = DMLOrthoForest(\n",
|
||
" n_trees=1000, min_leaf_size=5,\n",
|
||
" max_depth=50, subsample_ratio=subsample_ratio,\n",
|
||
" model_T=Lasso(alpha=lambda_reg),\n",
|
||
" model_Y=Lasso(alpha=lambda_reg),\n",
|
||
" model_T_final=WeightedLasso(alpha=lambda_reg),\n",
|
||
" model_Y_final=WeightedLasso(alpha=lambda_reg),\n",
|
||
" global_residualization=False,\n",
|
||
" random_state=123)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"To use the built-in confidence intervals constructed via Bootstrap of Little Bags, specify `inference=\"blb\"` at `fit` time, or leave the default `inference='auto'`, which selects the Bootstrap of Little Bags automatically."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
|
||
"[Parallel(n_jobs=-1)]: Done 16 tasks | elapsed: 21.6s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 176 tasks | elapsed: 22.6s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 816 tasks | elapsed: 25.6s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 1000 out of 1000 | elapsed: 26.5s finished\n",
|
||
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
|
||
"[Parallel(n_jobs=-1)]: Done 16 tasks | elapsed: 0.0s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 368 tasks | elapsed: 1.7s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 984 tasks | elapsed: 4.7s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 1000 out of 1000 | elapsed: 4.7s finished\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"<econml.orf._ortho_forest.DMLOrthoForest at 0x1a7d2b58a58>"
|
||
]
|
||
},
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"est.fit(Y, T, X=X, W=W, inference=\"blb\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
|
||
"[Parallel(n_jobs=-1)]: Done 16 tasks | elapsed: 21.6s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed: 23.9s finished\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Calculate treatment effects\n",
|
||
"treatment_effects = est.effect(X_test)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
|
||
"[Parallel(n_jobs=-1)]: Done 18 tasks | elapsed: 3.3s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed: 7.6s finished\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Calculate default (95%) confidence intervals for the test data\n",
|
||
"te_lower, te_upper = est.effect_interval(X_test)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
|
||
"[Parallel(n_jobs=-1)]: Done 16 tasks | elapsed: 3.4s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed: 7.4s finished\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"res = est.effect_inference(X_test)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>point_estimate</th>\n",
|
||
" <th>stderr</th>\n",
|
||
" <th>zstat</th>\n",
|
||
" <th>pvalue</th>\n",
|
||
" <th>ci_lower</th>\n",
|
||
" <th>ci_upper</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>X</th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>1.161</td>\n",
|
||
" <td>0.183</td>\n",
|
||
" <td>6.339</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.802</td>\n",
|
||
" <td>1.520</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>1.171</td>\n",
|
||
" <td>0.177</td>\n",
|
||
" <td>6.628</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.825</td>\n",
|
||
" <td>1.518</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>1.182</td>\n",
|
||
" <td>0.171</td>\n",
|
||
" <td>6.925</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.847</td>\n",
|
||
" <td>1.516</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>1.192</td>\n",
|
||
" <td>0.165</td>\n",
|
||
" <td>7.228</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.869</td>\n",
|
||
" <td>1.515</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>1.202</td>\n",
|
||
" <td>0.160</td>\n",
|
||
" <td>7.533</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.890</td>\n",
|
||
" <td>1.515</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" point_estimate stderr zstat pvalue ci_lower ci_upper\n",
|
||
"X \n",
|
||
"0 1.161 0.183 6.339 0.0 0.802 1.520\n",
|
||
"1 1.171 0.177 6.628 0.0 0.825 1.518\n",
|
||
"2 1.182 0.171 6.925 0.0 0.847 1.516\n",
|
||
"3 1.192 0.165 7.228 0.0 0.869 1.515\n",
|
||
"4 1.202 0.160 7.533 0.0 0.890 1.515"
|
||
]
|
||
},
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"res.summary_frame().head()"
|
||
]
|
||
},
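{
"cell_type": "markdown",
"metadata": {},
"source": [
"The columns of the summary frame are related in a simple way if the intervals are normal-approximation intervals. The printed numbers are consistent with that reading, but it is an assumption about the table and not a statement about EconML internals. The sketch below recomputes the 95% interval for the first row from its rounded `point_estimate` and `stderr`; the result can be compared with the `ci_lower` and `ci_upper` columns."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from scipy.stats import norm\n",
"\n",
"# First row of the table above (values are rounded to three decimals)\n",
"point_estimate, stderr = 1.161, 0.183\n",
"alpha = 0.05  # 95% interval\n",
"# Two-sided standard normal critical value, about 1.96\n",
"z = norm.ppf(1 - alpha / 2)\n",
"ci_lower = point_estimate - z * stderr\n",
"ci_upper = point_estimate + z * stderr\n",
"print(round(ci_lower, 3), round(ci_upper, 3))"
]
},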
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<table class=\"simpletable\">\n",
|
||
"<caption>Uncertainty of Mean Point Estimate</caption>\n",
|
||
"<tr>\n",
|
||
" <th>mean_point</th> <th>stderr_mean</th> <th>zstat</th> <th>pvalue</th> <th>ci_mean_lower</th> <th>ci_mean_upper</th>\n",
|
||
"</tr>\n",
|
||
"<tr>\n",
|
||
" <td>3.179</td> <td>0.287</td> <td>11.06</td> <td>0.0</td> <td>2.616</td> <td>3.742</td> \n",
|
||
"</tr>\n",
|
||
"</table>\n",
|
||
"<table class=\"simpletable\">\n",
|
||
"<caption>Distribution of Point Estimate</caption>\n",
|
||
"<tr>\n",
|
||
" <th>std_point</th> <th>pct_point_lower</th> <th>pct_point_upper</th>\n",
|
||
"</tr>\n",
|
||
"<tr>\n",
|
||
" <td>1.715</td> <td>1.187</td> <td>6.276</td> \n",
|
||
"</tr>\n",
|
||
"</table>\n",
|
||
"<table class=\"simpletable\">\n",
|
||
"<caption>Total Variance of Point Estimate</caption>\n",
|
||
"<tr>\n",
|
||
" <th>stderr_point</th> <th>ci_point_lower</th> <th>ci_point_upper</th>\n",
|
||
"</tr>\n",
|
||
"<tr>\n",
|
||
" <td>1.739</td> <td>1.079</td> <td>6.525</td> \n",
|
||
"</tr>\n",
|
||
"</table><br/><br/>Note: The stderr_mean is a conservative upper bound."
|
||
],
|
||
"text/plain": [
|
||
"<econml.inference._inference.PopulationSummaryResults at 0x1a7b5af2f98>"
|
||
]
|
||
},
|
||
"execution_count": 11,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"res.population_summary()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"Similarly, we can estimate effects and get confidence intervals and inference results using `CausalForestDML`."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"est2 = CausalForestDML(model_t=Lasso(alpha=lambda_reg),\n",
|
||
" model_y=Lasso(alpha=lambda_reg),\n",
|
||
" n_estimators=4000, min_samples_leaf=5,\n",
|
||
" max_depth=50,\n",
|
||
" verbose=0, random_state=123)\n",
|
||
"est2.tune(Y, T, X=X, W=W)\n",
|
||
"est2.fit(Y, T, X=X, W=W)\n",
|
||
"treatment_effects2 = est2.effect(X_test)\n",
|
||
"te_lower2, te_upper2 = est2.effect_interval(X_test, alpha=0.01)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 1.3. Performance Visualization"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 1080x360 with 2 Axes>"
|
||
]
|
||
},
|
||
"metadata": {
|
||
"needs_background": "light"
|
||
},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"plt.figure(figsize=(15, 5))\n",
|
||
"plt.subplot(1, 2, 1)\n",
"plt.title(\"DMLOrthoForest\")\n",
|
||
"plt.plot(X_test, treatment_effects, label='ORF estimate')\n",
|
||
"expected_te = np.array([exp_te(x_i) for x_i in X_test])\n",
|
||
"plt.plot(X_test[:, 0], expected_te, 'b--', label='True effect')\n",
|
||
"plt.fill_between(X_test[:, 0], te_lower, te_upper, label=\"95% BLB CI\", alpha=0.3)\n",
|
||
"plt.ylabel(\"Treatment Effect\")\n",
|
||
"plt.xlabel(\"x\")\n",
|
||
"plt.legend()\n",
|
||
"plt.subplot(1, 2, 2)\n",
|
||
"plt.title(\"CausalForest\")\n",
"plt.plot(X_test, treatment_effects2, label='CausalForest estimate')\n",
|
||
"expected_te = np.array([exp_te(x_i) for x_i in X_test])\n",
|
||
"plt.plot(X_test[:, 0], expected_te, 'b--', label='True effect')\n",
"plt.fill_between(X_test[:, 0], te_lower2, te_upper2, label=\"99% BLB CI\", alpha=0.3)\n",
|
||
"plt.ylabel(\"Treatment Effect\")\n",
|
||
"plt.xlabel(\"x\")\n",
|
||
"plt.legend()\n",
|
||
"plt.show()"
|
||
]
|
||
},
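{
"cell_type": "markdown",
"metadata": {},
"source": [
"The plots can be complemented with numeric summaries. The helper below is a hypothetical addition (`summarize_fit` is not part of EconML). It reports the root-mean-squared error against the true effect and the fraction of test points whose interval covers the true effect. The toy arrays at the end only demonstrate the call."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def summarize_fit(true_te, est_te, lower, upper):\n",
"    # Flatten so that inputs of shape (n,) and (n, 1) are handled alike\n",
"    true_te, est_te, lower, upper = (np.asarray(a, dtype=float).ravel()\n",
"                                     for a in (true_te, est_te, lower, upper))\n",
"    # Root-mean-squared error of the point estimates\n",
"    rmse = np.sqrt(np.mean((est_te - true_te) ** 2))\n",
"    # Share of points whose interval contains the true effect\n",
"    coverage = np.mean((lower <= true_te) & (true_te <= upper))\n",
"    return rmse, coverage\n",
"\n",
"# Toy demonstration; in this notebook one could instead pass\n",
"# expected_te, treatment_effects, te_lower, te_upper\n",
"rmse, coverage = summarize_fit([1.0, 2.0], [1.1, 1.9], [0.8, 2.1], [1.3, 2.4])\n",
"print(rmse, coverage)"
]
},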
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# 2. Example Usage with Binary Treatment Synthetic Data"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 2.1. DGP \n",
|
||
"We use the following DGP:\n",
|
||
"\n",
|
||
"\\begin{align}\n",
|
||
"T \\sim & \\text{Bernoulli}\\left(f(W)\\right), &\\; f(W)=\\sigma(\\langle W, \\beta\\rangle + \\eta), \\;\\eta \\sim \\text{Uniform}(-1, 1)\\\\\n",
|
||
"Y = & T\\cdot \\theta(X) + \\langle W, \\gamma\\rangle + \\epsilon, & \\; \\epsilon \\sim \\text{Uniform}(-1, 1)\\\\\n",
|
||
"W \\sim & \\text{Normal}(0,\\, I_{n_w}) & \\\\\n",
|
||
"X \\sim & \\text{Uniform}(0,\\, 1)^{n_x}\n",
|
||
"\\end{align}\n",
|
||
"\n",
|
||
"where $W$ is a matrix of high-dimensional confounders, $\\beta, \\gamma$ have high sparsity and $\\sigma$ is the sigmoid function.\n",
|
||
"\n",
|
||
"For this DGP, \n",
|
||
"\\begin{align}\n",
|
||
"\\theta(x) = \\exp( 2\\cdot x_1 ).\n",
|
||
"\\end{align}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# DGP constants\n",
|
||
"np.random.seed(1234)\n",
|
||
"n = 1000\n",
|
||
"n_w = 30\n",
|
||
"support_size = 5\n",
|
||
"n_x = 1\n",
|
||
"# Outcome support\n",
|
||
"support_Y = np.random.choice(range(n_w), size=support_size, replace=False)\n",
|
||
"coefs_Y = np.random.uniform(0, 1, size=support_size)\n",
|
||
"epsilon_sample = lambda n: np.random.uniform(-1, 1, size=n)\n",
|
||
"# Treatment support\n",
|
||
"support_T = support_Y\n",
|
||
"coefs_T = np.random.uniform(0, 1, size=support_size)\n",
|
||
"eta_sample = lambda n: np.random.uniform(-1, 1, size=n) \n",
|
||
"\n",
|
||
"# Generate controls, covariates, treatments and outcomes\n",
|
||
"W = np.random.normal(0, 1, size=(n, n_w))\n",
|
||
"X = np.random.uniform(0, 1, size=(n, n_x))\n",
|
||
"# Heterogeneous treatment effects\n",
|
||
"TE = np.array([exp_te(x_i) for x_i in X])\n",
|
||
"# Define treatment\n",
|
||
"log_odds = np.dot(W[:, support_T], coefs_T) + eta_sample(n)\n",
|
||
"T_sigmoid = 1/(1 + np.exp(-log_odds))\n",
|
||
"T = np.array([np.random.binomial(1, p) for p in T_sigmoid])\n",
|
||
"# Define the outcome\n",
|
||
"Y = TE * T + np.dot(W[:, support_Y], coefs_Y) + epsilon_sample(n)\n",
|
||
"\n",
|
||
"# ORF parameters and test data\n",
|
||
"subsample_ratio = 0.4\n",
|
||
"X_test = np.array(list(product(np.arange(0, 1, 0.01), repeat=n_x)))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 2.2. Train Estimator "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"est = DROrthoForest(\n",
|
||
" n_trees=200, min_leaf_size=10,\n",
|
||
" max_depth=30, subsample_ratio=subsample_ratio,\n",
|
||
" propensity_model = LogisticRegression(C=1/(X.shape[0]*lambda_reg), penalty='l1', solver='saga'),\n",
|
||
" model_Y = Lasso(alpha=lambda_reg),\n",
|
||
" propensity_model_final=LogisticRegression(C=1/(X.shape[0]*lambda_reg), penalty='l1', solver='saga'), \n",
|
||
" model_Y_final=WeightedLasso(alpha=lambda_reg)\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
|
||
"[Parallel(n_jobs=-1)]: Done 16 tasks | elapsed: 26.6s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 176 tasks | elapsed: 27.6s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 200 out of 200 | elapsed: 27.8s finished\n",
|
||
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
|
||
"[Parallel(n_jobs=-1)]: Done 16 tasks | elapsed: 0.2s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 185 out of 200 | elapsed: 1.0s remaining: 0.0s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 200 out of 200 | elapsed: 1.0s finished\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"<econml.orf._ortho_forest.DROrthoForest at 0x1a7b974ee48>"
|
||
]
|
||
},
|
||
"execution_count": 16,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"est.fit(Y, T, X=X, W=W)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 17,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
|
||
"[Parallel(n_jobs=-1)]: Done 16 tasks | elapsed: 37.4s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed: 41.0s finished\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Calculate treatment effects for the default treatment points T0=0 and T1=1\n",
|
||
"treatment_effects = est.effect(X_test)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 18,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
|
||
"[Parallel(n_jobs=-1)]: Done 16 tasks | elapsed: 1.8s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed: 3.5s finished\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Calculate default (95%) confidence intervals for the default treatment points T0=0 and T1=1\n",
|
||
"te_lower, te_upper = est.effect_interval(X_test)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 19,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"est2 = CausalForestDML(model_y=Lasso(alpha=lambda_reg),\n",
|
||
" model_t=LogisticRegression(C=1/(X.shape[0]*lambda_reg)),\n",
|
||
" n_estimators=200, min_samples_leaf=5,\n",
|
||
" max_depth=50, max_samples=subsample_ratio/2,\n",
|
||
" discrete_treatment=True,\n",
|
||
" random_state=123)\n",
|
||
"est2.fit(Y, T, X=X, W=W, cache_values=True)\n",
|
||
"treatment_effects2 = est2.effect(X_test)\n",
|
||
"te_lower2, te_upper2 = est2.effect_interval(X_test)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 20,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Population summary of CATE predictions on Training Data\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<table class=\"simpletable\">\n",
|
||
"<caption>Uncertainty of Mean Point Estimate</caption>\n",
|
||
"<tr>\n",
|
||
" <th>mean_point</th> <th>stderr_mean</th> <th>zstat</th> <th>pvalue</th> <th>ci_mean_lower</th> <th>ci_mean_upper</th>\n",
|
||
"</tr>\n",
|
||
"<tr>\n",
|
||
" <td>3.088</td> <td>0.157</td> <td>19.677</td> <td>0.0</td> <td>2.78</td> <td>3.396</td> \n",
|
||
"</tr>\n",
|
||
"</table>\n",
|
||
"<table class=\"simpletable\">\n",
|
||
"<caption>Distribution of Point Estimate</caption>\n",
|
||
"<tr>\n",
|
||
" <th>std_point</th> <th>pct_point_lower</th> <th>pct_point_upper</th>\n",
|
||
"</tr>\n",
|
||
"<tr>\n",
|
||
" <td>1.757</td> <td>0.846</td> <td>6.962</td> \n",
|
||
"</tr>\n",
|
||
"</table>\n",
|
||
"<table class=\"simpletable\">\n",
|
||
"<caption>Total Variance of Point Estimate</caption>\n",
|
||
"<tr>\n",
|
||
" <th>stderr_point</th> <th>ci_point_lower</th> <th>ci_point_upper</th>\n",
|
||
"</tr>\n",
|
||
"<tr>\n",
|
||
" <td>1.764</td> <td>0.774</td> <td>6.951</td> \n",
|
||
"</tr>\n",
|
||
"</table>\n",
|
||
"<table class=\"simpletable\">\n",
|
||
"<caption>Doubly Robust ATE on Training Data Results</caption>\n",
|
||
"<tr>\n",
|
||
" <td></td> <th>point_estimate</th> <th>stderr</th> <th>zstat</th> <th>pvalue</th> <th>ci_lower</th> <th>ci_upper</th>\n",
|
||
"</tr>\n",
|
||
"<tr>\n",
|
||
" <th>ATE</th> <td>3.158</td> <td>0.082</td> <td>38.551</td> <td>0.0</td> <td>2.997</td> <td>3.318</td> \n",
|
||
"</tr>\n",
|
||
"</table>\n",
|
||
"<table class=\"simpletable\">\n",
|
||
"<caption>Doubly Robust ATT(T=0) on Training Data Results</caption>\n",
|
||
"<tr>\n",
|
||
" <td></td> <th>point_estimate</th> <th>stderr</th> <th>zstat</th> <th>pvalue</th> <th>ci_lower</th> <th>ci_upper</th>\n",
|
||
"</tr>\n",
|
||
"<tr>\n",
|
||
" <th>ATT</th> <td>3.1</td> <td>0.096</td> <td>32.322</td> <td>0.0</td> <td>2.912</td> <td>3.288</td> \n",
|
||
"</tr>\n",
|
||
"</table>\n",
|
||
"<table class=\"simpletable\">\n",
|
||
"<caption>Doubly Robust ATT(T=1) on Training Data Results</caption>\n",
|
||
"<tr>\n",
|
||
" <td></td> <th>point_estimate</th> <th>stderr</th> <th>zstat</th> <th>pvalue</th> <th>ci_lower</th> <th>ci_upper</th>\n",
|
||
"</tr>\n",
|
||
"<tr>\n",
|
||
" <th>ATT</th> <td>3.218</td> <td>0.134</td> <td>23.965</td> <td>0.0</td> <td>2.955</td> <td>3.481</td> \n",
|
||
"</tr>\n",
|
||
"</table><br/><br/>Note: The stderr_mean is a conservative upper bound."
|
||
],
|
||
"text/plain": [
|
||
"<class 'econml.utilities.Summary'>\n",
|
||
"\"\"\"\n",
|
||
" Uncertainty of Mean Point Estimate \n",
|
||
"================================================================\n",
|
||
"mean_point stderr_mean zstat pvalue ci_mean_lower ci_mean_upper\n",
|
||
"----------------------------------------------------------------\n",
|
||
" 3.088 0.157 19.677 0.0 2.78 3.396\n",
|
||
" Distribution of Point Estimate \n",
|
||
"=========================================\n",
|
||
"std_point pct_point_lower pct_point_upper\n",
|
||
"-----------------------------------------\n",
|
||
" 1.757 0.846 6.962\n",
|
||
" Total Variance of Point Estimate \n",
|
||
"==========================================\n",
|
||
"stderr_point ci_point_lower ci_point_upper\n",
|
||
"------------------------------------------\n",
|
||
" 1.764 0.774 6.951\n",
|
||
" Doubly Robust ATE on Training Data Results \n",
|
||
"=========================================================\n",
|
||
" point_estimate stderr zstat pvalue ci_lower ci_upper\n",
|
||
"---------------------------------------------------------\n",
|
||
"ATE 3.158 0.082 38.551 0.0 2.997 3.318\n",
|
||
" Doubly Robust ATT(T=0) on Training Data Results \n",
|
||
"=========================================================\n",
|
||
" point_estimate stderr zstat pvalue ci_lower ci_upper\n",
|
||
"---------------------------------------------------------\n",
|
||
"ATT 3.1 0.096 32.322 0.0 2.912 3.288\n",
|
||
" Doubly Robust ATT(T=1) on Training Data Results \n",
|
||
"=========================================================\n",
|
||
" point_estimate stderr zstat pvalue ci_lower ci_upper\n",
|
||
"---------------------------------------------------------\n",
|
||
"ATT 3.218 0.134 23.965 0.0 2.955 3.481\n",
|
||
"---------------------------------------------------------\n",
|
||
"\n",
|
||
"Note: The stderr_mean is a conservative upper bound.\n",
|
||
"\"\"\""
|
||
]
|
||
},
|
||
"execution_count": 20,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"est2.summary()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 2.3. Performance Visualization"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 21,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 1080x360 with 2 Axes>"
|
||
]
|
||
},
|
||
"metadata": {
|
||
"needs_background": "light"
|
||
},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"plt.figure(figsize=(15, 5))\n",
|
||
"plt.subplot(1, 2, 1)\n",
|
||
"plt.title(\"DROrthoForest\")\n",
|
||
"plt.plot(X_test, treatment_effects, label='ORF estimate')\n",
|
||
"expected_te = np.array([exp_te(x_i) for x_i in X_test])\n",
|
||
"plt.plot(X_test[:, 0], expected_te, 'b--', label='True effect')\n",
|
||
"plt.fill_between(X_test[:, 0], te_lower, te_upper, label=\"95% BLB CI\", alpha=0.3)\n",
|
||
"plt.ylabel(\"Treatment Effect\")\n",
|
||
"plt.xlabel(\"x\")\n",
|
||
"plt.legend()\n",
|
||
"plt.subplot(1, 2, 2)\n",
|
||
"plt.title(\"CausalForest\")\n",
"plt.plot(X_test, treatment_effects2, label='CausalForest estimate')\n",
|
||
"expected_te = np.array([exp_te(x_i) for x_i in X_test])\n",
|
||
"plt.plot(X_test[:, 0], expected_te, 'b--', label='True effect')\n",
|
||
"plt.fill_between(X_test[:, 0], te_lower2, te_upper2, label=\"95% BLB CI\", alpha=0.3)\n",
|
||
"plt.ylabel(\"Treatment Effect\")\n",
|
||
"plt.xlabel(\"x\")\n",
|
||
"plt.legend()\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# 3. Example Usage with Multiple Treatment Synthetic Data"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"## 3.1. DGP\n",
"We use the following DGP:\n",
"\n",
"\\begin{align}\n",
"Y = & \\sum_{t=1}^{n_{\\text{treatments}}} 1\\{T=t\\}\\cdot \\theta_{t}(X) + \\langle W, \\gamma\\rangle + \\epsilon, \\; \\epsilon \\sim \\text{Uniform}(-1, 1), \\\\\n",
"\\text{Pr}[T=t \\mid W] \\propto & \\exp\\{\\langle W, \\beta_t \\rangle\\}, \\;\\;\\;\\; \\forall t\\in \\{0, 1, \\ldots, n_{\\text{treatments}}\\}\n",
"\\end{align}\n",
"\n",
"where $W$ is a matrix of high-dimensional confounders and $\\beta_t, \\gamma$ are sparse.\n",
"\n",
"For this particular example DGP, we used $n_{\\text{treatments}}=3$ and\n",
"\\begin{align}\n",
"\\theta_1(x) = & \\exp( 2 x_1 ),\\\\\n",
"\\theta_2(x) = & 3 \\cdot \\sigma(100\\cdot (x_1 - .5)) - 1,\\\\\n",
"\\theta_3(x) = & -2 \\cdot \\sigma(100\\cdot (x_1 - .25)),\n",
"\\end{align}\n",
"where $\\sigma$ is the sigmoid function."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 22,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
"def get_test_train_data(n, n_w, support_size, n_x, te_func, n_treatments):\n",
"    # Outcome support\n",
"    support_Y = np.random.choice(range(n_w), size=support_size, replace=False)\n",
"    coefs_Y = np.random.uniform(0, 1, size=support_size)\n",
"    epsilon_sample = lambda n: np.random.uniform(-1, 1, size=n)\n",
"    # Treatment support\n",
"    support_T = support_Y\n",
"    coefs_T = np.random.uniform(0, 1, size=(support_size, n_treatments))\n",
"    eta_sample = lambda n: np.random.uniform(-1, 1, size=n)\n",
"    # Generate controls, covariates, treatments and outcomes\n",
"    W = np.random.normal(0, 1, size=(n, n_w))\n",
"    X = np.random.uniform(0, 1, size=(n, n_x))\n",
"    # Heterogeneous treatment effects\n",
"    TE = np.array([te_func(x_i, n_treatments) for x_i in X])\n",
"    # Treatment probabilities: normalized exponentials of the per-treatment log-odds\n",
"    log_odds = np.dot(W[:, support_T], coefs_T)\n",
"    T_sigmoid = np.exp(log_odds)\n",
"    T_sigmoid = T_sigmoid/np.sum(T_sigmoid, axis=1, keepdims=True)\n",
"    T = np.array([np.random.choice(n_treatments, p=p) for p in T_sigmoid])\n",
"    # Prepend a zero-effect column for the control (T=0)\n",
"    TE = np.concatenate((np.zeros((n,1)), TE), axis=1)\n",
"    Y = TE[np.arange(n), T] + np.dot(W[:, support_Y], coefs_Y) + epsilon_sample(n)\n",
"    X_test = np.array(list(product(np.arange(0, 1, 0.01), repeat=n_x)))\n",
"\n",
"    return (Y, T, X, W), (X_test, np.array([te_func(x, n_treatments) for x in X_test]))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 23,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
"import scipy.special\n",
"def te_func(x, n_treatments):\n",
"    return [np.exp(2*x[0]), 3*scipy.special.expit(100*(x[0] - .5)) - 1, -2*scipy.special.expit(100*(x[0] - .25))]\n",
|
||
"\n",
|
||
"np.random.seed(123)\n",
|
||
"(Y, T, X, W), (X_test, te_test) = get_test_train_data(2000, 3, 3, 1, te_func, 4)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 3.2. Train Estimator"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 24,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"est = DROrthoForest(n_trees=500, model_Y = WeightedLasso(alpha=lambda_reg))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 25,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
|
||
"[Parallel(n_jobs=-1)]: Done 16 tasks | elapsed: 31.2s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 112 tasks | elapsed: 32.9s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 272 tasks | elapsed: 35.5s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 500 out of 500 | elapsed: 39.5s finished\n",
|
||
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
|
||
"[Parallel(n_jobs=-1)]: Done 16 tasks | elapsed: 0.3s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 208 tasks | elapsed: 3.4s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 500 out of 500 | elapsed: 7.9s finished\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"<econml.orf._ortho_forest.DROrthoForest at 0x1a7bac95828>"
|
||
]
|
||
},
|
||
"execution_count": 25,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"est.fit(Y, T, X=X, W=W)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 26,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
|
||
"[Parallel(n_jobs=-1)]: Done 16 tasks | elapsed: 20.6s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed: 22.8s finished\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Calculate marginal treatment effects\n",
|
||
"treatment_effects = est.const_marginal_effect(X_test)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 27,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
|
||
"[Parallel(n_jobs=-1)]: Done 16 tasks | elapsed: 2.7s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed: 5.1s finished\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Calculate default (95%) marginal confidence intervals for the test data\n",
|
||
"te_lower, te_upper = est.const_marginal_effect_interval(X_test)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 28,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
|
||
"[Parallel(n_jobs=-1)]: Done 16 tasks | elapsed: 2.5s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed: 4.9s finished\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"res = est.const_marginal_effect_inference(X_test)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 29,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th>point_estimate</th>\n",
|
||
" <th>stderr</th>\n",
|
||
" <th>zstat</th>\n",
|
||
" <th>pvalue</th>\n",
|
||
" <th>ci_lower</th>\n",
|
||
" <th>ci_upper</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>X</th>\n",
|
||
" <th>T</th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th rowspan=\"3\" valign=\"top\">0</th>\n",
|
||
" <th>T0_1</th>\n",
|
||
" <td>1.013</td>\n",
|
||
" <td>0.159</td>\n",
|
||
" <td>6.360</td>\n",
|
||
" <td>0.000</td>\n",
|
||
" <td>0.701</td>\n",
|
||
" <td>1.325</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>T0_2</th>\n",
|
||
" <td>-0.989</td>\n",
|
||
" <td>0.149</td>\n",
|
||
" <td>-6.636</td>\n",
|
||
" <td>0.000</td>\n",
|
||
" <td>-1.281</td>\n",
|
||
" <td>-0.697</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>T0_3</th>\n",
|
||
" <td>0.034</td>\n",
|
||
" <td>0.226</td>\n",
|
||
" <td>0.152</td>\n",
|
||
" <td>0.879</td>\n",
|
||
" <td>-0.408</td>\n",
|
||
" <td>0.477</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th rowspan=\"2\" valign=\"top\">1</th>\n",
|
||
" <th>T0_1</th>\n",
|
||
" <td>1.018</td>\n",
|
||
" <td>0.160</td>\n",
|
||
" <td>6.379</td>\n",
|
||
" <td>0.000</td>\n",
|
||
" <td>0.705</td>\n",
|
||
" <td>1.331</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>T0_2</th>\n",
|
||
" <td>-0.987</td>\n",
|
||
" <td>0.147</td>\n",
|
||
" <td>-6.717</td>\n",
|
||
" <td>0.000</td>\n",
|
||
" <td>-1.276</td>\n",
|
||
" <td>-0.699</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th rowspan=\"2\" valign=\"top\">98</th>\n",
|
||
" <th>T0_2</th>\n",
|
||
" <td>1.967</td>\n",
|
||
" <td>0.210</td>\n",
|
||
" <td>9.345</td>\n",
|
||
" <td>0.000</td>\n",
|
||
" <td>1.554</td>\n",
|
||
" <td>2.379</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>T0_3</th>\n",
|
||
" <td>-2.021</td>\n",
|
||
" <td>0.163</td>\n",
|
||
" <td>-12.414</td>\n",
|
||
" <td>0.000</td>\n",
|
||
" <td>-2.340</td>\n",
|
||
" <td>-1.702</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th rowspan=\"3\" valign=\"top\">99</th>\n",
|
||
" <th>T0_1</th>\n",
|
||
" <td>6.867</td>\n",
|
||
" <td>0.244</td>\n",
|
||
" <td>28.194</td>\n",
|
||
" <td>0.000</td>\n",
|
||
" <td>6.390</td>\n",
|
||
" <td>7.344</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>T0_2</th>\n",
|
||
" <td>1.966</td>\n",
|
||
" <td>0.212</td>\n",
|
||
" <td>9.276</td>\n",
|
||
" <td>0.000</td>\n",
|
||
" <td>1.551</td>\n",
|
||
" <td>2.382</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>T0_3</th>\n",
|
||
" <td>-2.017</td>\n",
|
||
" <td>0.163</td>\n",
|
||
" <td>-12.352</td>\n",
|
||
" <td>0.000</td>\n",
|
||
" <td>-2.337</td>\n",
|
||
" <td>-1.697</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>300 rows × 6 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" point_estimate stderr zstat pvalue ci_lower ci_upper\n",
|
||
"X T \n",
|
||
"0 T0_1 1.013 0.159 6.360 0.000 0.701 1.325\n",
|
||
" T0_2 -0.989 0.149 -6.636 0.000 -1.281 -0.697\n",
|
||
" T0_3 0.034 0.226 0.152 0.879 -0.408 0.477\n",
|
||
"1 T0_1 1.018 0.160 6.379 0.000 0.705 1.331\n",
|
||
" T0_2 -0.987 0.147 -6.717 0.000 -1.276 -0.699\n",
|
||
"... ... ... ... ... ... ...\n",
|
||
"98 T0_2 1.967 0.210 9.345 0.000 1.554 2.379\n",
|
||
" T0_3 -2.021 0.163 -12.414 0.000 -2.340 -1.702\n",
|
||
"99 T0_1 6.867 0.244 28.194 0.000 6.390 7.344\n",
|
||
" T0_2 1.966 0.212 9.276 0.000 1.551 2.382\n",
|
||
" T0_3 -2.017 0.163 -12.352 0.000 -2.337 -1.697\n",
|
||
"\n",
|
||
"[300 rows x 6 columns]"
|
||
]
|
||
},
|
||
"execution_count": 29,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"res.summary_frame()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 30,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"est2 = CausalForestDML(model_y=Lasso(alpha=lambda_reg),\n",
|
||
" model_t=LogisticRegression(C=1/(X.shape[0]*lambda_reg)),\n",
|
||
" n_estimators=4000, min_samples_leaf=5,\n",
|
||
" max_depth=50, max_samples=subsample_ratio/2,\n",
|
||
" discrete_treatment=True,\n",
|
||
" random_state=123)\n",
|
||
"est2.fit(Y, T, X=X, W=W)\n",
|
||
"treatment_effects2 = est2.const_marginal_effect(X_test)\n",
|
||
"te_lower2, te_upper2 = est2.const_marginal_effect_interval(X_test, alpha=.01)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 3.3. Performance Visualization"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 31,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 1080x360 with 2 Axes>"
|
||
]
|
||
},
|
||
"metadata": {
|
||
"needs_background": "light"
|
||
},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"plt.figure(figsize=(15, 5))\n",
|
||
"plt.subplot(1, 2, 1)\n",
|
||
"plt.title(\"DROrthoForest\")\n",
|
||
"y = treatment_effects\n",
|
||
"colors = ['b', 'r', 'g']\n",
|
||
"for it in range(y.shape[1]):\n",
|
||
" plt.plot(X_test[:, 0], te_test[:, it], '--', label='True effect T={}'.format(it), color=colors[it])\n",
|
||
" plt.fill_between(X_test[:, 0], te_lower[:, it], te_upper[:, it], alpha=0.3, color='C{}'.format(it))\n",
|
||
" plt.plot(X_test, y[:, it], label='ORF estimate T={}'.format(it), color='C{}'.format(it))\n",
|
||
"plt.ylabel(\"Treatment Effect\")\n",
|
||
"plt.xlabel(\"x\")\n",
|
||
"plt.legend()\n",
|
||
"plt.subplot(1, 2, 2)\n",
|
||
"plt.title(\"CausalForest\")\n",
|
||
"y = treatment_effects2\n",
|
||
"colors = ['b', 'r', 'g']\n",
|
||
"for it in range(y.shape[1]):\n",
|
||
" plt.plot(X_test[:, 0], te_test[:, it], '--', label='True effect T={}'.format(it), color=colors[it])\n",
|
||
" plt.fill_between(X_test[:, 0], te_lower2[:, it], te_upper2[:, it], alpha=0.3, color='C{}'.format(it))\n",
|
||
" plt.plot(X_test, y[:, it], label='ORF estimate T={}'.format(it), color='C{}'.format(it))\n",
|
||
"plt.ylabel(\"Treatment Effect\")\n",
|
||
"plt.xlabel(\"x\")\n",
|
||
"plt.legend()\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# 4. Example Usage with Real Continuous Treatment Observational Data\n",
|
||
"\n",
|
||
"We applied our technique to Dominick’s dataset, a popular historical dataset of store-level orange juice prices and sales provided by University of Chicago Booth School of Business. \n",
|
||
"\n",
|
||
"The dataset is comprised of a large number of covariates $W$, but researchers might only be interested in learning the elasticity of demand as a function of a few variables $x$ such\n",
|
||
"as income or education. \n",
|
||
"\n",
|
||
"We applied the `DMLOrthoForest` to estimate orange juice price elasticity\n",
|
||
"as a function of income, and our results, unveil the natural phenomenon that lower income consumers are more price-sensitive."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 4.1. Data"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 32,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# A few more imports\n",
|
||
"import os\n",
|
||
"import pandas as pd\n",
|
||
"import urllib.request\n",
|
||
"from sklearn.preprocessing import StandardScaler"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 33,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>store</th>\n",
|
||
" <th>brand</th>\n",
|
||
" <th>week</th>\n",
|
||
" <th>logmove</th>\n",
|
||
" <th>feat</th>\n",
|
||
" <th>price</th>\n",
|
||
" <th>AGE60</th>\n",
|
||
" <th>EDUC</th>\n",
|
||
" <th>ETHNIC</th>\n",
|
||
" <th>INCOME</th>\n",
|
||
" <th>HHLARGE</th>\n",
|
||
" <th>WORKWOM</th>\n",
|
||
" <th>HVAL150</th>\n",
|
||
" <th>SSTRDIST</th>\n",
|
||
" <th>SSTRVOL</th>\n",
|
||
" <th>CPDIST5</th>\n",
|
||
" <th>CPWVOL5</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>2</td>\n",
|
||
" <td>tropicana</td>\n",
|
||
" <td>40</td>\n",
|
||
" <td>9.018695</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>3.87</td>\n",
|
||
" <td>0.232865</td>\n",
|
||
" <td>0.248935</td>\n",
|
||
" <td>0.11428</td>\n",
|
||
" <td>10.553205</td>\n",
|
||
" <td>0.103953</td>\n",
|
||
" <td>0.303585</td>\n",
|
||
" <td>0.463887</td>\n",
|
||
" <td>2.110122</td>\n",
|
||
" <td>1.142857</td>\n",
|
||
" <td>1.92728</td>\n",
|
||
" <td>0.376927</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>2</td>\n",
|
||
" <td>tropicana</td>\n",
|
||
" <td>46</td>\n",
|
||
" <td>8.723231</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>3.87</td>\n",
|
||
" <td>0.232865</td>\n",
|
||
" <td>0.248935</td>\n",
|
||
" <td>0.11428</td>\n",
|
||
" <td>10.553205</td>\n",
|
||
" <td>0.103953</td>\n",
|
||
" <td>0.303585</td>\n",
|
||
" <td>0.463887</td>\n",
|
||
" <td>2.110122</td>\n",
|
||
" <td>1.142857</td>\n",
|
||
" <td>1.92728</td>\n",
|
||
" <td>0.376927</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>2</td>\n",
|
||
" <td>tropicana</td>\n",
|
||
" <td>47</td>\n",
|
||
" <td>8.253228</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>3.87</td>\n",
|
||
" <td>0.232865</td>\n",
|
||
" <td>0.248935</td>\n",
|
||
" <td>0.11428</td>\n",
|
||
" <td>10.553205</td>\n",
|
||
" <td>0.103953</td>\n",
|
||
" <td>0.303585</td>\n",
|
||
" <td>0.463887</td>\n",
|
||
" <td>2.110122</td>\n",
|
||
" <td>1.142857</td>\n",
|
||
" <td>1.92728</td>\n",
|
||
" <td>0.376927</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>2</td>\n",
|
||
" <td>tropicana</td>\n",
|
||
" <td>48</td>\n",
|
||
" <td>8.987197</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>3.87</td>\n",
|
||
" <td>0.232865</td>\n",
|
||
" <td>0.248935</td>\n",
|
||
" <td>0.11428</td>\n",
|
||
" <td>10.553205</td>\n",
|
||
" <td>0.103953</td>\n",
|
||
" <td>0.303585</td>\n",
|
||
" <td>0.463887</td>\n",
|
||
" <td>2.110122</td>\n",
|
||
" <td>1.142857</td>\n",
|
||
" <td>1.92728</td>\n",
|
||
" <td>0.376927</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>2</td>\n",
|
||
" <td>tropicana</td>\n",
|
||
" <td>50</td>\n",
|
||
" <td>9.093357</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>3.87</td>\n",
|
||
" <td>0.232865</td>\n",
|
||
" <td>0.248935</td>\n",
|
||
" <td>0.11428</td>\n",
|
||
" <td>10.553205</td>\n",
|
||
" <td>0.103953</td>\n",
|
||
" <td>0.303585</td>\n",
|
||
" <td>0.463887</td>\n",
|
||
" <td>2.110122</td>\n",
|
||
" <td>1.142857</td>\n",
|
||
" <td>1.92728</td>\n",
|
||
" <td>0.376927</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" store brand week logmove feat price AGE60 EDUC ETHNIC \\\n",
|
||
"0 2 tropicana 40 9.018695 0 3.87 0.232865 0.248935 0.11428 \n",
|
||
"1 2 tropicana 46 8.723231 0 3.87 0.232865 0.248935 0.11428 \n",
|
||
"2 2 tropicana 47 8.253228 0 3.87 0.232865 0.248935 0.11428 \n",
|
||
"3 2 tropicana 48 8.987197 0 3.87 0.232865 0.248935 0.11428 \n",
|
||
"4 2 tropicana 50 9.093357 0 3.87 0.232865 0.248935 0.11428 \n",
|
||
"\n",
|
||
" INCOME HHLARGE WORKWOM HVAL150 SSTRDIST SSTRVOL CPDIST5 \\\n",
|
||
"0 10.553205 0.103953 0.303585 0.463887 2.110122 1.142857 1.92728 \n",
|
||
"1 10.553205 0.103953 0.303585 0.463887 2.110122 1.142857 1.92728 \n",
|
||
"2 10.553205 0.103953 0.303585 0.463887 2.110122 1.142857 1.92728 \n",
|
||
"3 10.553205 0.103953 0.303585 0.463887 2.110122 1.142857 1.92728 \n",
|
||
"4 10.553205 0.103953 0.303585 0.463887 2.110122 1.142857 1.92728 \n",
|
||
"\n",
|
||
" CPWVOL5 \n",
|
||
"0 0.376927 \n",
|
||
"1 0.376927 \n",
|
||
"2 0.376927 \n",
|
||
"3 0.376927 \n",
|
||
"4 0.376927 "
|
||
]
|
||
},
|
||
"execution_count": 33,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Import the data\n",
|
||
"file_name = \"oj_large.csv\"\n",
|
||
"\n",
|
||
"if not os.path.isfile(file_name):\n",
|
||
" print(\"Downloading file (this might take a few seconds)...\")\n",
|
||
" urllib.request.urlretrieve(\"https://msalicedatapublic.z5.web.core.windows.net/datasets/OrangeJuice/oj_large.csv\", file_name)\n",
|
||
"oj_data = pd.read_csv(file_name)\n",
|
||
"oj_data.head()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 34,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Prepare data\n",
|
||
"Y = oj_data['logmove'].values\n",
|
||
"T = np.log(oj_data[\"price\"]).values\n",
|
||
"scaler = StandardScaler()\n",
|
||
"W1 = scaler.fit_transform(oj_data[[c for c in oj_data.columns if c not in ['price', 'logmove', 'brand', 'week', 'store']]].values)\n",
|
||
"W2 = pd.get_dummies(oj_data[['brand']]).values\n",
|
||
"W = np.concatenate([W1, W2], axis=1)\n",
|
||
"X = oj_data[['INCOME']].values"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 4.2. Train Estimator"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 35,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Define some parameters\n",
|
||
"n_trees = 1000\n",
|
||
"min_leaf_size = 50\n",
|
||
"max_depth = 20\n",
|
||
"subsample_ratio = 0.04"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 36,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"est = DMLOrthoForest(\n",
|
||
" n_trees=n_trees, min_leaf_size=min_leaf_size, max_depth=max_depth, \n",
|
||
" subsample_ratio=subsample_ratio,\n",
|
||
" model_T=Lasso(alpha=0.1),\n",
|
||
" model_Y=Lasso(alpha=0.1),\n",
|
||
" model_T_final=WeightedLassoCVWrapper(cv=3), \n",
|
||
" model_Y_final=WeightedLassoCVWrapper(cv=3)\n",
|
||
" )"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 37,
|
||
"metadata": {
|
||
"scrolled": true
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
|
||
"[Parallel(n_jobs=-1)]: Done 16 tasks | elapsed: 20.3s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 152 tasks | elapsed: 21.0s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 1000 out of 1000 | elapsed: 22.5s finished\n",
|
||
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
|
||
"[Parallel(n_jobs=-1)]: Done 16 tasks | elapsed: 0.0s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 888 tasks | elapsed: 1.6s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 1000 out of 1000 | elapsed: 2.1s finished\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"<econml.orf._ortho_forest.DMLOrthoForest at 0x1a7cdd37588>"
|
||
]
|
||
},
|
||
"execution_count": 37,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"est.fit(Y, T, X=X, W=W)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 38,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"min_income = 10.0 \n",
|
||
"max_income = 11.1\n",
|
||
"delta = (max_income - min_income) / 100\n",
|
||
"X_test = np.arange(min_income, max_income + delta - 0.001, delta).reshape(-1, 1)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 39,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
|
||
"[Parallel(n_jobs=-1)]: Done 16 tasks | elapsed: 23.0s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 101 out of 101 | elapsed: 35.2s finished\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Calculate marginal treatment effects\n",
|
||
"treatment_effects = est.const_marginal_effect(X_test)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 40,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
|
||
"[Parallel(n_jobs=-1)]: Done 16 tasks | elapsed: 6.1s\n",
|
||
"[Parallel(n_jobs=-1)]: Done 101 out of 101 | elapsed: 21.3s finished\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Calculate default (95%) marginal confidence intervals for the test data\n",
|
||
"te_upper, te_lower = est.const_marginal_effect_interval(X_test)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 41,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"est2 = CausalForestDML(model_y=WeightedLassoCVWrapper(cv=3),\n",
|
||
" model_t=WeightedLassoCVWrapper(cv=3),\n",
|
||
" n_estimators=n_trees, min_samples_leaf=min_leaf_size, max_depth=max_depth,\n",
|
||
" max_samples=subsample_ratio/2,\n",
|
||
" random_state=123)\n",
|
||
"est2.fit(Y, T, X=X, W=W)\n",
|
||
"treatment_effects2 = est2.effect(X_test)\n",
|
||
"te_lower2, te_upper2 = est2.effect_interval(X_test)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 4.3. Performance Visualization"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 42,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 1080x360 with 2 Axes>"
|
||
]
|
||
},
|
||
"metadata": {
|
||
"needs_background": "light"
|
||
},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Plot Orange Juice elasticity as a function of income\n",
|
||
"plt.figure(figsize=(15, 5))\n",
|
||
"plt.subplot(1, 2, 1)\n",
|
||
"plt.plot(X_test.flatten(), treatment_effects, label=\"OJ Elasticity\")\n",
|
||
"plt.fill_between(X_test.flatten(), te_lower, te_upper, label=\"95% BLB CI\", alpha=0.3)\n",
|
||
"plt.xlabel(r'$\\log$(Income)')\n",
|
||
"plt.ylabel('Orange Juice Elasticity')\n",
|
||
"plt.legend()\n",
|
||
"plt.title(\"Orange Juice Elasticity vs Income: DMLOrthoForest\")\n",
|
||
"plt.subplot(1, 2, 2)\n",
|
||
"plt.plot(X_test.flatten(), treatment_effects2, label=\"OJ Elasticity\")\n",
|
||
"plt.fill_between(X_test.flatten(), te_lower2, te_upper2, label=\"95% BLB CI\", alpha=0.3)\n",
|
||
"plt.xlabel(r'$\\log$(Income)')\n",
|
||
"plt.ylabel('Orange Juice Elasticity')\n",
|
||
"plt.legend()\n",
|
||
"plt.title(\"Orange Juice Elasticity vs Income: CausalForest\")\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3.9.13 64-bit (microsoft store)",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.9.13"
|
||
},
|
||
"vscode": {
|
||
"interpreter": {
|
||
"hash": "d2603944574c6ce6e242666bf20bfb1bc23ccfec8e562036b69397ff157ca866"
|
||
}
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|