Added support for both DOT and GML graph formats
This commit is contained in:
Родитель
e8e34c8b52
Коммит
ee56b2a8a8
|
@ -47,25 +47,30 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" Z0 Z1 X0 X1 X2 X3 X4 v \\\n",
|
||||
"0 0.0 0.682129 0.619275 -1.338917 -1.344954 -1.205516 -0.380428 0.0 \n",
|
||||
"1 0.0 0.780360 1.489020 0.224417 -0.925444 -1.178246 -1.462796 1.0 \n",
|
||||
"2 0.0 0.749278 0.338573 0.200353 0.144210 0.006207 -0.311147 1.0 \n",
|
||||
"3 0.0 0.604209 0.104159 -0.659282 0.186016 -0.054680 -0.065804 1.0 \n",
|
||||
"4 0.0 0.966614 0.441749 -1.867424 -0.851817 -0.843792 -1.123326 0.0 \n",
|
||||
"0 0.0 0.236335 0.266520 -1.282418 0.602125 0.631051 -0.117089 1.0 \n",
|
||||
"1 0.0 0.644808 -0.098507 1.124507 1.039807 -1.105611 -0.293460 1.0 \n",
|
||||
"2 0.0 0.283139 -1.316949 -1.048244 1.198808 -0.362258 2.357232 1.0 \n",
|
||||
"3 0.0 0.818932 -2.352635 -0.024943 -0.126628 -0.903706 -0.055516 0.0 \n",
|
||||
"4 0.0 0.485443 0.117577 -0.683765 0.250597 0.002824 0.366975 0.0 \n",
|
||||
"\n",
|
||||
" y \n",
|
||||
"0 -9.677161 \n",
|
||||
"1 4.996252 \n",
|
||||
"2 11.352700 \n",
|
||||
"3 9.375556 \n",
|
||||
"4 -10.086670 \n",
|
||||
"0 9.427494 \n",
|
||||
"1 13.830990 \n",
|
||||
"2 8.941143 \n",
|
||||
"3 -11.145470 \n",
|
||||
"4 -0.987937 \n",
|
||||
"digraph { v ->y; U[label=\"Unobserved Confounders\"]; U->v; U->y;X0-> v; X1-> v; X2-> v; X3-> v; X4-> v;X0-> y; X1-> y; X2-> y; X3-> y; X4-> y;Z0-> v; Z1-> v;}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"graph[node[ id \"v\" label \"v\"]node[ id \"y\" label \"y\"]node[ id \"Unobserved Confounders\" label \"Unobserved Confounders\"]edge[source \"v\" target \"y\"]edge[source \"Unobserved Confounders\" target \"v\"]edge[source \"Unobserved Confounders\" target \"y\"]node[ id \"X0\" label \"X0\"] edge[ source \"X0\" target \"v\"] node[ id \"X1\" label \"X1\"] edge[ source \"X1\" target \"v\"] node[ id \"X2\" label \"X2\"] edge[ source \"X2\" target \"v\"] node[ id \"X3\" label \"X3\"] edge[ source \"X3\" target \"v\"] node[ id \"X4\" label \"X4\"] edge[ source \"X4\" target \"v\"]edge[ source \"X0\" target \"y\"] edge[ source \"X1\" target \"y\"] edge[ source \"X2\" target \"y\"] edge[ source \"X3\" target \"y\"] edge[ source \"X4\" target \"y\"]node[ id \"Z0\" label \"Z0\"] edge[ source \"Z0\" target \"v\"] node[ id \"Z1\" label \"Z1\"] edge[ source \"Z1\" target \"v\"]]\n"
|
||||
]
|
||||
}
|
||||
|
@ -78,7 +83,9 @@
|
|||
" treatment_is_binary=True)\n",
|
||||
"df = data[\"df\"]\n",
|
||||
"print(df.head())\n",
|
||||
"print(data[\"dot_graph\"])"
|
||||
"print(data[\"dot_graph\"])\n",
|
||||
"print(\"\\n\")\n",
|
||||
"print(data[\"gml_graph\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -104,14 +111,34 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model to find the causal effect of treatment v on outcome y\n"
|
||||
"Error: Pygraphviz cannot be loaded. No module named 'pygraphviz'\n",
|
||||
"Trying pydot ...\n",
|
||||
"Error: Pydot cannot be loaded. 'list' object has no attribute 'get_strict'\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "AttributeError",
|
||||
"evalue": "'list' object has no attribute 'get_strict'",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[0;32m/mnt/c/Users/amit_/code/dowhy/dowhy/causal_graph.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, treatment_name, outcome_name, graph, common_cause_names, instrument_names, observed_node_names)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 42\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mpygraphviz\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mpgv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 43\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpgv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAGraph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstrict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdirected\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;31mImportError\u001b[0m: No module named 'pygraphviz'",
|
||||
"\nDuring handling of the above exception, another exception occurred:\n",
|
||||
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[0;32m<ipython-input-6-6c8a1e7a754d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mtreatment\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"treatment_name\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0moutcome\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"outcome_name\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mgraph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"dot_graph\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m )\n",
|
||||
"\u001b[0;32m/mnt/c/Users/amit_/code/dowhy/dowhy/do_why.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, data, treatment, outcome, graph, common_causes, instruments, estimand_type, **kwargs)\u001b[0m\n\u001b[1;32m 96\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_outcome\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[0mgraph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 98\u001b[0;31m \u001b[0mobserved_node_names\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_data\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolumns\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtolist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 99\u001b[0m )\n\u001b[1;32m 100\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_common_causes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_graph\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_common_causes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_treatment\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_outcome\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/mnt/c/Users/amit_/code/dowhy/dowhy/causal_graph.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, treatment_name, outcome_name, graph, common_cause_names, instrument_names, observed_node_names)\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Error: Pydot cannot be loaded. \"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 54\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 55\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mre\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\".*graph\\s*\\[.*\\]\\s*\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgraph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDiGraph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparse_gml\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/mnt/c/Users/amit_/code/dowhy/dowhy/causal_graph.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, treatment_name, outcome_name, graph, common_cause_names, instrument_names, observed_node_names)\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mpydot\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0mP\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpydot\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgraph_from_dot_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 50\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mP\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_strict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 51\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdrawing\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnx_pydot\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_pydot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mP\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'get_strict'"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -121,15 +148,24 @@
|
|||
" data = df,\n",
|
||||
" treatment=data[\"treatment_name\"],\n",
|
||||
" outcome=data[\"outcome_name\"],\n",
|
||||
" graph=data[\"dot_graph\"],\n",
|
||||
" graph=data[\"dot_graph\"]\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Warning: Pygraphviz cannot be loaded. Check that graphviz and pygraphviz are installed.\n",
|
||||
"Using Matplotlib for plotting\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.view_model()"
|
||||
]
|
||||
|
@ -150,55 +186,9 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:dowhy.causal_identifier:Common causes of treatment and outcome:{'X3', 'Z0', 'X4', 'Z1', 'Unobserved Confounders', 'X2', 'X0', 'X1'}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'observed': 'yes'}\n",
|
||||
"{'observed': 'yes'}\n",
|
||||
"{'observed': 'yes'}\n",
|
||||
"{'observed': 'yes'}\n",
|
||||
"{'observed': 'no'}\n",
|
||||
"There are unobserved common causes. Causal effect cannot be identified.\n",
|
||||
"WARN: Do you want to continue by ignoring these unobserved confounders? [y/n] y\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:dowhy.causal_identifier:Instrumental variables for treatment and outcome:[]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Estimand type: ate\n",
|
||||
"### Estimand : 1\n",
|
||||
"Estimand name: iv\n",
|
||||
"No such variable found!\n",
|
||||
"### Estimand : 2\n",
|
||||
"Estimand name: backdoor\n",
|
||||
"Estimand expression:\n",
|
||||
"d \n",
|
||||
"──(Expectation(y|X3,Z0,X4,Z1,X2,X0,X1))\n",
|
||||
"dv \n",
|
||||
"Estimand assumption 1, Unconfoundedness: If U→v and U→y then P(y|v,X3,Z0,X4,Z1,X2,X0,X1,U) = P(y|v,X3,Z0,X4,Z1,X2,X0,X1)\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"identified_estimand = model.identify_effect()\n",
|
||||
"print(identified_estimand)"
|
||||
|
@ -206,52 +196,9 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"LinearRegressionEstimator\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:dowhy.causal_estimator:INFO: Using Linear Regression Estimator\n",
|
||||
"INFO:dowhy.causal_estimator:b: y~v+X3+Z0+X4+Z1+X2+X0+X1\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"*** Causal Estimate ***\n",
|
||||
"\n",
|
||||
"## Target estimand\n",
|
||||
"Estimand type: ate\n",
|
||||
"### Estimand : 1\n",
|
||||
"Estimand name: iv\n",
|
||||
"No such variable found!\n",
|
||||
"### Estimand : 2\n",
|
||||
"Estimand name: backdoor\n",
|
||||
"Estimand expression:\n",
|
||||
"d \n",
|
||||
"──(Expectation(y|X3,Z0,X4,Z1,X2,X0,X1))\n",
|
||||
"dv \n",
|
||||
"Estimand assumption 1, Unconfoundedness: If U→v and U→y then P(y|v,X3,Z0,X4,Z1,X2,X0,X1,U) = P(y|v,X3,Z0,X4,Z1,X2,X0,X1)\n",
|
||||
"\n",
|
||||
"## Realized estimand\n",
|
||||
"b: y~v+X3+Z0+X4+Z1+X2+X0+X1\n",
|
||||
"## Estimate\n",
|
||||
"Value: 9.999999999999842\n",
|
||||
"\n",
|
||||
"Causal Estimate is 10.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"causal_estimate = model.estimate_effect(identified_estimand,\n",
|
||||
" method_name=\"backdoor.linear_regression\")\n",
|
||||
|
@ -268,26 +215,11 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"WARNING:dowhy.do_why:Causal Graph not provided. DoWhy will construct a graph based on data inputs.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model to find the causal effect of treatment v on outcome y\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Without graph \n",
|
||||
"model= CausalModel( \n",
|
||||
|
@ -299,7 +231,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -335,38 +267,9 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:dowhy.causal_identifier:Common causes of treatment and outcome:{'X3', 'X4', 'X2', 'X0', 'X1', 'U'}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'observed': 'yes'}\n",
|
||||
"{'observed': 'yes'}\n",
|
||||
"{'observed': 'yes'}\n",
|
||||
"{'observed': 'yes'}\n",
|
||||
"{'observed': 'yes'}\n",
|
||||
"{'label': 'Unobserved Confounders', 'observed': 'no'}\n",
|
||||
"There are unobserved common causes. Causal effect cannot be identified.\n",
|
||||
"WARN: Do you want to continue by ignoring these unobserved confounders? [y/n] y\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:dowhy.causal_identifier:Instrumental variables for treatment and outcome:[]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"identified_estimand = model.identify_effect() "
|
||||
]
|
||||
|
@ -380,49 +283,9 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:dowhy.causal_estimator:INFO: Using Linear Regression Estimator\n",
|
||||
"INFO:dowhy.causal_estimator:b: y~v+X3+X4+X2+X0+X1\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"LinearRegressionEstimator\n",
|
||||
"*** Causal Estimate ***\n",
|
||||
"\n",
|
||||
"## Target estimand\n",
|
||||
"Estimand type: ate\n",
|
||||
"### Estimand : 1\n",
|
||||
"Estimand name: iv\n",
|
||||
"No such variable found!\n",
|
||||
"### Estimand : 2\n",
|
||||
"Estimand name: backdoor\n",
|
||||
"Estimand expression:\n",
|
||||
"d \n",
|
||||
"──(Expectation(y|X3,X4,X2,X0,X1))\n",
|
||||
"dv \n",
|
||||
"Estimand assumption 1, Unconfoundedness: If U→v and U→y then P(y|v,X3,X4,X2,X0,X1,U) = P(y|v,X3,X4,X2,X0,X1)\n",
|
||||
"\n",
|
||||
"## Realized estimand\n",
|
||||
"b: y~v+X3+X4+X2+X0+X1\n",
|
||||
"## Estimate\n",
|
||||
"Value: 9.999999999999849\n",
|
||||
"\n",
|
||||
"## Statistical Significance\n",
|
||||
"p-value: 0.0\n",
|
||||
"\n",
|
||||
"Causal Estimate is 10.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"estimate = model.estimate_effect(identified_estimand,\n",
|
||||
" method_name=\"backdoor.linear_regression\", \n",
|
||||
|
@ -450,28 +313,9 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:dowhy.causal_estimator:INFO: Using Linear Regression Estimator\n",
|
||||
"INFO:dowhy.causal_estimator:b: y~v+X3+X4+X2+X0+X1+w_random\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Refute: Add a Random Common Cause\n",
|
||||
"Estimated effect:(9.999999999999849,)\n",
|
||||
"New effect:(9.9999999999998561,)\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res_random=model.refute_estimate(identified_estimand, estimate, method_name=\"random_common_cause\")\n",
|
||||
"print(res_random)"
|
||||
|
@ -486,28 +330,9 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:dowhy.causal_estimator:INFO: Using Linear Regression Estimator\n",
|
||||
"INFO:dowhy.causal_estimator:b: y~placebo+X3+X4+X2+X0+X1\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Refute: Use a Placebo Treatment\n",
|
||||
"Estimated effect:(9.999999999999849,)\n",
|
||||
"New effect:(-0.039298742866320742,)\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res_placebo=model.refute_estimate(identified_estimand, estimate,\n",
|
||||
" method_name=\"placebo_treatment_refuter\", placebo_type=\"permute\")\n",
|
||||
|
@ -523,28 +348,9 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:dowhy.causal_estimator:INFO: Using Linear Regression Estimator\n",
|
||||
"INFO:dowhy.causal_estimator:b: y~v+X3+X4+X2+X0+X1\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Refute: Use a subset of data\n",
|
||||
"Estimated effect:(9.999999999999849,)\n",
|
||||
"New effect:(9.9999999999998206,)\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res_subset=model.refute_estimate(identified_estimand, estimate,\n",
|
||||
" method_name=\"data_subset_refuter\", subset_fraction=0.9)\n",
|
||||
|
@ -561,9 +367,9 @@
|
|||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "dowhy",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
"name": "dowhy"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import logging
|
||||
|
||||
import re
|
||||
import networkx as nx
|
||||
|
||||
|
||||
|
@ -22,8 +22,41 @@ class CausalGraph:
|
|||
self._graph = nx.DiGraph()
|
||||
self._graph = self.build_graph(common_cause_names,
|
||||
instrument_names)
|
||||
else:
|
||||
elif re.match(r".*\.dot", graph):
|
||||
# load dot file
|
||||
try:
|
||||
import pygraphviz as pgv
|
||||
self._graph = nx.DiGraph(nx.drawing.nx_agraph.read_dot(graph))
|
||||
except Exception as e:
|
||||
print("Pygraphviz cannot be loaded. "+ str(e) + "\nTrying pydot...")
|
||||
try:
|
||||
import pydot
|
||||
self._graph = nx.DiGraph(nx.drawing.nx_pydot.read_dot(graph))
|
||||
except Exception as e:
|
||||
print("Error: Pydot cannot be loaded. " + str(e))
|
||||
raise e
|
||||
elif re.match(r".*\.gml", graph):
|
||||
self._graph = nx.DiGraph(nx.read_gml(graph))
|
||||
elif re.match(r".*graph\s*\{.*\}\s*", graph):
|
||||
try:
|
||||
import pygraphviz as pgv
|
||||
self._graph = pgv.AGraph(graph, strict=True, directed=True)
|
||||
self._graph = nx.drawing.nx_agraph.from_agraph(self._graph)
|
||||
except Exception as e:
|
||||
print("Error: Pygraphviz cannot be loaded. " + str(e) + "\nTrying pydot ...")
|
||||
try:
|
||||
import pydot
|
||||
P_list = pydot.graph_from_dot_data(graph)
|
||||
self._graph = nx.drawing.nx_pydot.from_pydot(P_list[0])
|
||||
except Exception as e:
|
||||
print("Error: Pydot cannot be loaded. " + str(e))
|
||||
raise e
|
||||
elif re.match(".*graph\s*\[.*\]\s*", graph):
|
||||
self._graph = nx.DiGraph(nx.parse_gml(graph))
|
||||
else:
|
||||
print("Error: Please provide graph (as string or text file) in dot or gml format.")
|
||||
print("Error: Incorrect graph format")
|
||||
raise ValueError
|
||||
|
||||
self._graph = self.add_node_attributes(observed_node_names)
|
||||
self._graph = self.add_unobserved_common_cause(observed_node_names)
|
||||
|
|
|
@ -87,7 +87,8 @@ def linear_dataset(beta, num_common_causes, num_samples, num_instruments=0,
|
|||
"outcome_name": outcome,
|
||||
"common_causes_names": common_causes,
|
||||
"instrument_names": instruments,
|
||||
"dot_graph": gml_graph,
|
||||
"dot_graph": dot_graph,
|
||||
"gml_graph": gml_graph,
|
||||
"ate": ate
|
||||
}
|
||||
return ret_dict
|
||||
|
@ -130,6 +131,7 @@ def xy_dataset(num_samples, effect=True, sd_error=1):
|
|||
"time_val": time_var,
|
||||
"instrument_names": None,
|
||||
"dot_graph": None,
|
||||
"gml_graph": None,
|
||||
"ate": None,
|
||||
}
|
||||
return ret_dict
|
||||
|
|
Загрузка…
Ссылка в новой задаче