added lalonde ipw estimate
This commit is contained in:
Родитель
b1c8362c1a
Коммит
8a5722ea34
|
@ -0,0 +1,235 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# DoWhy example on the Lalonde dataset\n",
|
||||
"\n",
|
||||
"Thanks to [@mizuy](https://github.com/mizuy) for providing this example. Here we use the Lalonde dataset and apply IPW estimator to it. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/usr/local/lib/python3.5/dist-packages/rpy2/rinterface/__init__.py:146: RRuntimeWarning: Installing package into ‘/home/amit/R/x86_64-pc-linux-gnu-library/3.4’\n",
|
||||
"(as ‘lib’ is unspecified)\n",
|
||||
"\n",
|
||||
" warnings.warn(x, RRuntimeWarning)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The rpy2.ipython extension is already loaded. To reload it, use:\n",
|
||||
" %reload_ext rpy2.ipython\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/usr/local/lib/python3.5/dist-packages/rpy2/rinterface/__init__.py:146: RRuntimeWarning: trying URL 'https://cloud.r-project.org/src/contrib/Matching_4.9-3.tar.gz'\n",
|
||||
"\n",
|
||||
" warnings.warn(x, RRuntimeWarning)\n",
|
||||
"/usr/local/lib/python3.5/dist-packages/rpy2/rinterface/__init__.py:146: RRuntimeWarning: Content type 'application/x-gzip'\n",
|
||||
" warnings.warn(x, RRuntimeWarning)\n",
|
||||
"/usr/local/lib/python3.5/dist-packages/rpy2/rinterface/__init__.py:146: RRuntimeWarning: length 302135 bytes (295 KB)\n",
|
||||
"\n",
|
||||
" warnings.warn(x, RRuntimeWarning)\n",
|
||||
"/usr/local/lib/python3.5/dist-packages/rpy2/rinterface/__init__.py:146: RRuntimeWarning: =\n",
|
||||
" warnings.warn(x, RRuntimeWarning)\n",
|
||||
"/usr/local/lib/python3.5/dist-packages/rpy2/rinterface/__init__.py:146: RRuntimeWarning: \n",
|
||||
"\n",
|
||||
" warnings.warn(x, RRuntimeWarning)\n",
|
||||
"/usr/local/lib/python3.5/dist-packages/rpy2/rinterface/__init__.py:146: RRuntimeWarning: downloaded 295 KB\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" warnings.warn(x, RRuntimeWarning)\n",
|
||||
"/usr/local/lib/python3.5/dist-packages/rpy2/rinterface/__init__.py:146: RRuntimeWarning: \n",
|
||||
" warnings.warn(x, RRuntimeWarning)\n",
|
||||
"/usr/local/lib/python3.5/dist-packages/rpy2/rinterface/__init__.py:146: RRuntimeWarning: The downloaded source packages are in\n",
|
||||
"\t‘/tmp/RtmpVW644g/downloaded_packages’\n",
|
||||
" warnings.warn(x, RRuntimeWarning)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array(['Matching', 'MASS', 'tools', 'stats', 'graphics', 'grDevices',\n",
|
||||
" 'utils', 'datasets', 'methods', 'base'], \n",
|
||||
" dtype='<U9')"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os, sys\n",
|
||||
"sys.path.append(os.path.abspath(\"../../\"))\n",
|
||||
"\n",
|
||||
"import dowhy\n",
|
||||
"from dowhy.do_why import CausalModel\n",
|
||||
"from rpy2.robjects import r as R\n",
|
||||
"%load_ext rpy2.ipython\n",
|
||||
"\n",
|
||||
"#%R install.packages(\"Matching\")\n",
|
||||
"%R library(Matching)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. Load the data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%R data(lalonde)\n",
|
||||
"%R -o lalonde\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Run DoWhy analysis: model, identify, estimate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"WARNING:dowhy.do_why:Causal Graph not provided. DoWhy will construct a graph based on data inputs.\n",
|
||||
"INFO:dowhy.causal_identifier:Common causes of treatment and outcome:{'educ', 'nodegr', 'hisp', 'married', 'black', 'age', 'U'}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model to find the causal effect of treatment treat on outcome re78\n",
|
||||
"There are unobserved common causes. Causal effect cannot be identified.\n",
|
||||
"WARN: Do you want to continue by ignoring these unobserved confounders? [y/n] y\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:dowhy.causal_identifier:Instrumental variables for treatment and outcome:[]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"PropensityScoreWeightingEstimator\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:dowhy.causal_estimator:INFO: Using Propensity Score Weighting Estimator\n",
|
||||
"INFO:dowhy.causal_estimator:b: re78~treat+educ+nodegr+hisp+married+black+age\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Causal Estimate is 1634.98683597\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"\n",
|
||||
"model=CausalModel(\n",
|
||||
" data = lalonde,\n",
|
||||
" treatment='treat',\n",
|
||||
" outcome='re78',\n",
|
||||
" common_causes='nodegr+black+hisp+age+educ+married'.split('+'))\n",
|
||||
"identified_estimand = model.identify_effect()\n",
|
||||
"estimate = model.estimate_effect(identified_estimand,\n",
|
||||
" method_name=\"backdoor.propensity_score_weighting\")\n",
|
||||
"#print(estimate)\n",
|
||||
"print(\"Causal Estimate is \" + str(estimate.value))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Sanity check: compare to manual IPW estimate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Causal Estimate is 1634.98683597\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"df = model._data\n",
|
||||
"ps = df['ps']\n",
|
||||
"y = df['re78']\n",
|
||||
"z = df['treat']\n",
|
||||
"\n",
|
||||
"ey1 = z*y/ps / sum(z/ps)\n",
|
||||
"ey0 = (1-z)*y/(1-ps) / sum((1-z)/(1-ps))\n",
|
||||
"ate = ey1.sum()-ey0.sum()\n",
|
||||
"print(\"Causal Estimate is \" + str(ate))\n",
|
||||
"\n",
|
||||
"# correct -> Causal Estimate is 1634.9868359746906"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.5.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
Загрузка…
Ссылка в новой задаче