This commit is contained in:
Amit Sharma 2018-08-25 02:42:48 +05:30
Родитель b1c8362c1a
Коммит 8a5722ea34
1 изменённых файлов: 235 добавлений и 0 удалений

Просмотреть файл

@ -0,0 +1,235 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# DoWhy example on the Lalonde dataset\n",
"\n",
"Thanks to [@mizuy](https://github.com/mizuy) for providing this example. Here we use the Lalonde dataset and apply the inverse probability weighting (IPW) estimator to it."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.5/dist-packages/rpy2/rinterface/__init__.py:146: RRuntimeWarning: Installing package into /home/amit/R/x86_64-pc-linux-gnu-library/3.4\n",
"(as lib is unspecified)\n",
"\n",
" warnings.warn(x, RRuntimeWarning)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"The rpy2.ipython extension is already loaded. To reload it, use:\n",
" %reload_ext rpy2.ipython\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.5/dist-packages/rpy2/rinterface/__init__.py:146: RRuntimeWarning: trying URL 'https://cloud.r-project.org/src/contrib/Matching_4.9-3.tar.gz'\n",
"\n",
" warnings.warn(x, RRuntimeWarning)\n",
"/usr/local/lib/python3.5/dist-packages/rpy2/rinterface/__init__.py:146: RRuntimeWarning: Content type 'application/x-gzip'\n",
" warnings.warn(x, RRuntimeWarning)\n",
"/usr/local/lib/python3.5/dist-packages/rpy2/rinterface/__init__.py:146: RRuntimeWarning: length 302135 bytes (295 KB)\n",
"\n",
" warnings.warn(x, RRuntimeWarning)\n",
"/usr/local/lib/python3.5/dist-packages/rpy2/rinterface/__init__.py:146: RRuntimeWarning: =\n",
" warnings.warn(x, RRuntimeWarning)\n",
"/usr/local/lib/python3.5/dist-packages/rpy2/rinterface/__init__.py:146: RRuntimeWarning: \n",
"\n",
" warnings.warn(x, RRuntimeWarning)\n",
"/usr/local/lib/python3.5/dist-packages/rpy2/rinterface/__init__.py:146: RRuntimeWarning: downloaded 295 KB\n",
"\n",
"\n",
" warnings.warn(x, RRuntimeWarning)\n",
"/usr/local/lib/python3.5/dist-packages/rpy2/rinterface/__init__.py:146: RRuntimeWarning: \n",
" warnings.warn(x, RRuntimeWarning)\n",
"/usr/local/lib/python3.5/dist-packages/rpy2/rinterface/__init__.py:146: RRuntimeWarning: The downloaded source packages are in\n",
"\t/tmp/RtmpVW644g/downloaded_packages\n",
" warnings.warn(x, RRuntimeWarning)\n"
]
},
{
"data": {
"text/plain": [
"array(['Matching', 'MASS', 'tools', 'stats', 'graphics', 'grDevices',\n",
" 'utils', 'datasets', 'methods', 'base'], \n",
" dtype='<U9')"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import os, sys\n",
"sys.path.append(os.path.abspath(\"../../\"))\n",
"\n",
"import dowhy\n",
"from dowhy.do_why import CausalModel\n",
"from rpy2.robjects import r as R\n",
"%load_ext rpy2.ipython\n",
"\n",
"#%R install.packages(\"Matching\")\n",
"%R library(Matching)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the data"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"%R data(lalonde)\n",
"%R -o lalonde\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run DoWhy analysis: model, identify, estimate"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:dowhy.do_why:Causal Graph not provided. DoWhy will construct a graph based on data inputs.\n",
"INFO:dowhy.causal_identifier:Common causes of treatment and outcome:{'educ', 'nodegr', 'hisp', 'married', 'black', 'age', 'U'}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model to find the causal effect of treatment treat on outcome re78\n",
"There are unobserved common causes. Causal effect cannot be identified.\n",
"WARN: Do you want to continue by ignoring these unobserved confounders? [y/n] y\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:dowhy.causal_identifier:Instrumental variables for treatment and outcome:[]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"PropensityScoreWeightingEstimator\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:dowhy.causal_estimator:INFO: Using Propensity Score Weighting Estimator\n",
"INFO:dowhy.causal_estimator:b: re78~treat+educ+nodegr+hisp+married+black+age\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Causal Estimate is 1634.98683597\n"
]
}
],
"source": [
"# Covariates handed to DoWhy as common causes of the treatment (`treat`)\n",
"# and the outcome (`re78`). No causal graph is supplied, only variable roles.\n",
"covariates = 'nodegr+black+hisp+age+educ+married'.split('+')\n",
"\n",
"model = CausalModel(data=lalonde,\n",
"                    treatment='treat',\n",
"                    outcome='re78',\n",
"                    common_causes=covariates)\n",
"\n",
"# Identify the estimand, then estimate it with propensity score weighting.\n",
"identified_estimand = model.identify_effect()\n",
"ipw_method = \"backdoor.propensity_score_weighting\"\n",
"estimate = model.estimate_effect(identified_estimand, method_name=ipw_method)\n",
"\n",
"print(\"Causal Estimate is \" + str(estimate.value))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sanity check: compare to manual IPW estimate"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Causal Estimate is 1634.98683597\n"
]
}
],
"source": [
"# Recompute the estimate by hand as a normalized IPW contrast: within each\n",
"# arm, outcomes are weighted by inverse propensity and the weights are\n",
"# rescaled to sum to one.\n",
"#\n",
"# NOTE(review): `model._data` is a private attribute of CausalModel; the `ps`\n",
"# column is presumably added to it during estimation -- confirm if DoWhy changes.\n",
"df = model._data\n",
"ps = df['ps']     # propensity score column\n",
"y = df['re78']    # outcome\n",
"z = df['treat']   # treatment indicator (used as 0/1 below)\n",
"\n",
"w1 = z / ps                # weights for treated units\n",
"w0 = (1 - z) / (1 - ps)    # weights for control units\n",
"ate = (w1 * y).sum() / w1.sum() - (w0 * y).sum() / w0.sum()\n",
"print(\"Causal Estimate is \" + str(ate))\n",
"\n",
"# Enforce the sanity check instead of relying on a reader to compare the two\n",
"# printed numbers (the original run gave about 1634.9868359746906).\n",
"assert abs(ate - estimate.value) <= 1e-6 * abs(estimate.value), (ate, estimate.value)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}