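Added support for specifying the causal graph in both DOT and GML formats (as a string or a file); dataset helpers now return both `dot_graph` and `gml_graph`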

This commit is contained in:
Amit Sharma 2018-09-10 00:29:21 +05:30
Родитель e8e34c8b52
Коммит ee56b2a8a8
3 изменённых файлов: 110 добавлений и 269 удалений

Просмотреть файл

@ -47,25 +47,30 @@
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Z0 Z1 X0 X1 X2 X3 X4 v \\\n",
"0 0.0 0.682129 0.619275 -1.338917 -1.344954 -1.205516 -0.380428 0.0 \n",
"1 0.0 0.780360 1.489020 0.224417 -0.925444 -1.178246 -1.462796 1.0 \n",
"2 0.0 0.749278 0.338573 0.200353 0.144210 0.006207 -0.311147 1.0 \n",
"3 0.0 0.604209 0.104159 -0.659282 0.186016 -0.054680 -0.065804 1.0 \n",
"4 0.0 0.966614 0.441749 -1.867424 -0.851817 -0.843792 -1.123326 0.0 \n",
"0 0.0 0.236335 0.266520 -1.282418 0.602125 0.631051 -0.117089 1.0 \n",
"1 0.0 0.644808 -0.098507 1.124507 1.039807 -1.105611 -0.293460 1.0 \n",
"2 0.0 0.283139 -1.316949 -1.048244 1.198808 -0.362258 2.357232 1.0 \n",
"3 0.0 0.818932 -2.352635 -0.024943 -0.126628 -0.903706 -0.055516 0.0 \n",
"4 0.0 0.485443 0.117577 -0.683765 0.250597 0.002824 0.366975 0.0 \n",
"\n",
" y \n",
"0 -9.677161 \n",
"1 4.996252 \n",
"2 11.352700 \n",
"3 9.375556 \n",
"4 -10.086670 \n",
"0 9.427494 \n",
"1 13.830990 \n",
"2 8.941143 \n",
"3 -11.145470 \n",
"4 -0.987937 \n",
"digraph { v ->y; U[label=\"Unobserved Confounders\"]; U->v; U->y;X0-> v; X1-> v; X2-> v; X3-> v; X4-> v;X0-> y; X1-> y; X2-> y; X3-> y; X4-> y;Z0-> v; Z1-> v;}\n",
"\n",
"\n",
"graph[node[ id \"v\" label \"v\"]node[ id \"y\" label \"y\"]node[ id \"Unobserved Confounders\" label \"Unobserved Confounders\"]edge[source \"v\" target \"y\"]edge[source \"Unobserved Confounders\" target \"v\"]edge[source \"Unobserved Confounders\" target \"y\"]node[ id \"X0\" label \"X0\"] edge[ source \"X0\" target \"v\"] node[ id \"X1\" label \"X1\"] edge[ source \"X1\" target \"v\"] node[ id \"X2\" label \"X2\"] edge[ source \"X2\" target \"v\"] node[ id \"X3\" label \"X3\"] edge[ source \"X3\" target \"v\"] node[ id \"X4\" label \"X4\"] edge[ source \"X4\" target \"v\"]edge[ source \"X0\" target \"y\"] edge[ source \"X1\" target \"y\"] edge[ source \"X2\" target \"y\"] edge[ source \"X3\" target \"y\"] edge[ source \"X4\" target \"y\"]node[ id \"Z0\" label \"Z0\"] edge[ source \"Z0\" target \"v\"] node[ id \"Z1\" label \"Z1\"] edge[ source \"Z1\" target \"v\"]]\n"
]
}
@ -78,7 +83,9 @@
" treatment_is_binary=True)\n",
"df = data[\"df\"]\n",
"print(df.head())\n",
"print(data[\"dot_graph\"])"
"print(data[\"dot_graph\"])\n",
"print(\"\\n\")\n",
"print(data[\"gml_graph\"])"
]
},
{
@ -104,14 +111,34 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model to find the causal effect of treatment v on outcome y\n"
"Error: Pygraphviz cannot be loaded. No module named 'pygraphviz'\n",
"Trying pydot ...\n",
"Error: Pydot cannot be loaded. 'list' object has no attribute 'get_strict'\n"
]
},
{
"ename": "AttributeError",
"evalue": "'list' object has no attribute 'get_strict'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/mnt/c/Users/amit_/code/dowhy/dowhy/causal_graph.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, treatment_name, outcome_name, graph, common_cause_names, instrument_names, observed_node_names)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 42\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mpygraphviz\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mpgv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 43\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpgv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAGraph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstrict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdirected\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mImportError\u001b[0m: No module named 'pygraphviz'",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-6-6c8a1e7a754d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mtreatment\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"treatment_name\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0moutcome\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"outcome_name\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mgraph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"dot_graph\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m )\n",
"\u001b[0;32m/mnt/c/Users/amit_/code/dowhy/dowhy/do_why.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, data, treatment, outcome, graph, common_causes, instruments, estimand_type, **kwargs)\u001b[0m\n\u001b[1;32m 96\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_outcome\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[0mgraph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 98\u001b[0;31m \u001b[0mobserved_node_names\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_data\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolumns\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtolist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 99\u001b[0m )\n\u001b[1;32m 100\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_common_causes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_graph\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_common_causes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_treatment\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_outcome\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/mnt/c/Users/amit_/code/dowhy/dowhy/causal_graph.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, treatment_name, outcome_name, graph, common_cause_names, instrument_names, observed_node_names)\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Error: Pydot cannot be loaded. \"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 54\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 55\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mre\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\".*graph\\s*\\[.*\\]\\s*\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgraph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDiGraph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparse_gml\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/mnt/c/Users/amit_/code/dowhy/dowhy/causal_graph.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, treatment_name, outcome_name, graph, common_cause_names, instrument_names, observed_node_names)\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mpydot\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0mP\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpydot\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgraph_from_dot_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 50\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mP\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_strict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 51\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdrawing\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnx_pydot\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_pydot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mP\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'get_strict'"
]
}
],
@ -121,15 +148,24 @@
" data = df,\n",
" treatment=data[\"treatment_name\"],\n",
" outcome=data[\"outcome_name\"],\n",
" graph=data[\"dot_graph\"],\n",
" graph=data[\"dot_graph\"]\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Warning: Pygraphviz cannot be loaded. Check that graphviz and pygraphviz are installed.\n",
"Using Matplotlib for plotting\n"
]
}
],
"source": [
"model.view_model()"
]
@ -150,55 +186,9 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:dowhy.causal_identifier:Common causes of treatment and outcome:{'X3', 'Z0', 'X4', 'Z1', 'Unobserved Confounders', 'X2', 'X0', 'X1'}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'observed': 'yes'}\n",
"{'observed': 'yes'}\n",
"{'observed': 'yes'}\n",
"{'observed': 'yes'}\n",
"{'observed': 'no'}\n",
"There are unobserved common causes. Causal effect cannot be identified.\n",
"WARN: Do you want to continue by ignoring these unobserved confounders? [y/n] y\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:dowhy.causal_identifier:Instrumental variables for treatment and outcome:[]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Estimand type: ate\n",
"### Estimand : 1\n",
"Estimand name: iv\n",
"No such variable found!\n",
"### Estimand : 2\n",
"Estimand name: backdoor\n",
"Estimand expression:\n",
"d \n",
"──(Expectation(y|X3,Z0,X4,Z1,X2,X0,X1))\n",
"dv \n",
"Estimand assumption 1, Unconfoundedness: If U→v and U→y then P(y|v,X3,Z0,X4,Z1,X2,X0,X1,U) = P(y|v,X3,Z0,X4,Z1,X2,X0,X1)\n",
"\n"
]
}
],
"outputs": [],
"source": [
"identified_estimand = model.identify_effect()\n",
"print(identified_estimand)"
@ -206,52 +196,9 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LinearRegressionEstimator\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:dowhy.causal_estimator:INFO: Using Linear Regression Estimator\n",
"INFO:dowhy.causal_estimator:b: y~v+X3+Z0+X4+Z1+X2+X0+X1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"*** Causal Estimate ***\n",
"\n",
"## Target estimand\n",
"Estimand type: ate\n",
"### Estimand : 1\n",
"Estimand name: iv\n",
"No such variable found!\n",
"### Estimand : 2\n",
"Estimand name: backdoor\n",
"Estimand expression:\n",
"d \n",
"──(Expectation(y|X3,Z0,X4,Z1,X2,X0,X1))\n",
"dv \n",
"Estimand assumption 1, Unconfoundedness: If U→v and U→y then P(y|v,X3,Z0,X4,Z1,X2,X0,X1,U) = P(y|v,X3,Z0,X4,Z1,X2,X0,X1)\n",
"\n",
"## Realized estimand\n",
"b: y~v+X3+Z0+X4+Z1+X2+X0+X1\n",
"## Estimate\n",
"Value: 9.999999999999842\n",
"\n",
"Causal Estimate is 10.0\n"
]
}
],
"outputs": [],
"source": [
"causal_estimate = model.estimate_effect(identified_estimand,\n",
" method_name=\"backdoor.linear_regression\")\n",
@ -268,26 +215,11 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:dowhy.do_why:Causal Graph not provided. DoWhy will construct a graph based on data inputs.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model to find the causal effect of treatment v on outcome y\n"
]
}
],
"outputs": [],
"source": [
"# Without graph \n",
"model= CausalModel( \n",
@ -299,7 +231,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -335,38 +267,9 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:dowhy.causal_identifier:Common causes of treatment and outcome:{'X3', 'X4', 'X2', 'X0', 'X1', 'U'}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'observed': 'yes'}\n",
"{'observed': 'yes'}\n",
"{'observed': 'yes'}\n",
"{'observed': 'yes'}\n",
"{'observed': 'yes'}\n",
"{'label': 'Unobserved Confounders', 'observed': 'no'}\n",
"There are unobserved common causes. Causal effect cannot be identified.\n",
"WARN: Do you want to continue by ignoring these unobserved confounders? [y/n] y\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:dowhy.causal_identifier:Instrumental variables for treatment and outcome:[]\n"
]
}
],
"outputs": [],
"source": [
"identified_estimand = model.identify_effect() "
]
@ -380,49 +283,9 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:dowhy.causal_estimator:INFO: Using Linear Regression Estimator\n",
"INFO:dowhy.causal_estimator:b: y~v+X3+X4+X2+X0+X1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"LinearRegressionEstimator\n",
"*** Causal Estimate ***\n",
"\n",
"## Target estimand\n",
"Estimand type: ate\n",
"### Estimand : 1\n",
"Estimand name: iv\n",
"No such variable found!\n",
"### Estimand : 2\n",
"Estimand name: backdoor\n",
"Estimand expression:\n",
"d \n",
"──(Expectation(y|X3,X4,X2,X0,X1))\n",
"dv \n",
"Estimand assumption 1, Unconfoundedness: If U→v and U→y then P(y|v,X3,X4,X2,X0,X1,U) = P(y|v,X3,X4,X2,X0,X1)\n",
"\n",
"## Realized estimand\n",
"b: y~v+X3+X4+X2+X0+X1\n",
"## Estimate\n",
"Value: 9.999999999999849\n",
"\n",
"## Statistical Significance\n",
"p-value: 0.0\n",
"\n",
"Causal Estimate is 10.0\n"
]
}
],
"outputs": [],
"source": [
"estimate = model.estimate_effect(identified_estimand,\n",
" method_name=\"backdoor.linear_regression\", \n",
@ -450,28 +313,9 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:dowhy.causal_estimator:INFO: Using Linear Regression Estimator\n",
"INFO:dowhy.causal_estimator:b: y~v+X3+X4+X2+X0+X1+w_random\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Refute: Add a Random Common Cause\n",
"Estimated effect:(9.999999999999849,)\n",
"New effect:(9.9999999999998561,)\n",
"\n"
]
}
],
"outputs": [],
"source": [
"res_random=model.refute_estimate(identified_estimand, estimate, method_name=\"random_common_cause\")\n",
"print(res_random)"
@ -486,28 +330,9 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:dowhy.causal_estimator:INFO: Using Linear Regression Estimator\n",
"INFO:dowhy.causal_estimator:b: y~placebo+X3+X4+X2+X0+X1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Refute: Use a Placebo Treatment\n",
"Estimated effect:(9.999999999999849,)\n",
"New effect:(-0.039298742866320742,)\n",
"\n"
]
}
],
"outputs": [],
"source": [
"res_placebo=model.refute_estimate(identified_estimand, estimate,\n",
" method_name=\"placebo_treatment_refuter\", placebo_type=\"permute\")\n",
@ -523,28 +348,9 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:dowhy.causal_estimator:INFO: Using Linear Regression Estimator\n",
"INFO:dowhy.causal_estimator:b: y~v+X3+X4+X2+X0+X1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Refute: Use a subset of data\n",
"Estimated effect:(9.999999999999849,)\n",
"New effect:(9.9999999999998206,)\n",
"\n"
]
}
],
"outputs": [],
"source": [
"res_subset=model.refute_estimate(identified_estimand, estimate,\n",
" method_name=\"data_subset_refuter\", subset_fraction=0.9)\n",
@ -561,9 +367,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "dowhy",
"language": "python",
"name": "python3"
"name": "dowhy"
},
"language_info": {
"codemirror_mode": {

Просмотреть файл

@ -1,5 +1,5 @@
import logging
import re
import networkx as nx
@ -22,8 +22,41 @@ class CausalGraph:
self._graph = nx.DiGraph()
self._graph = self.build_graph(common_cause_names,
instrument_names)
else:
elif re.match(r".*\.dot", graph):
# load dot file
try:
import pygraphviz as pgv
self._graph = nx.DiGraph(nx.drawing.nx_agraph.read_dot(graph))
except Exception as e:
print("Pygraphviz cannot be loaded. "+ str(e) + "\nTrying pydot...")
try:
import pydot
self._graph = nx.DiGraph(nx.drawing.nx_pydot.read_dot(graph))
except Exception as e:
print("Error: Pydot cannot be loaded. " + str(e))
raise e
elif re.match(r".*\.gml", graph):
self._graph = nx.DiGraph(nx.read_gml(graph))
elif re.match(r".*graph\s*\{.*\}\s*", graph):
try:
import pygraphviz as pgv
self._graph = pgv.AGraph(graph, strict=True, directed=True)
self._graph = nx.drawing.nx_agraph.from_agraph(self._graph)
except Exception as e:
print("Error: Pygraphviz cannot be loaded. " + str(e) + "\nTrying pydot ...")
try:
import pydot
P_list = pydot.graph_from_dot_data(graph)
self._graph = nx.drawing.nx_pydot.from_pydot(P_list[0])
except Exception as e:
print("Error: Pydot cannot be loaded. " + str(e))
raise e
elif re.match(".*graph\s*\[.*\]\s*", graph):
self._graph = nx.DiGraph(nx.parse_gml(graph))
else:
print("Error: Please provide graph (as string or text file) in dot or gml format.")
print("Error: Incorrect graph format")
raise ValueError
self._graph = self.add_node_attributes(observed_node_names)
self._graph = self.add_unobserved_common_cause(observed_node_names)

Просмотреть файл

@ -87,7 +87,8 @@ def linear_dataset(beta, num_common_causes, num_samples, num_instruments=0,
"outcome_name": outcome,
"common_causes_names": common_causes,
"instrument_names": instruments,
"dot_graph": gml_graph,
"dot_graph": dot_graph,
"gml_graph": gml_graph,
"ate": ate
}
return ret_dict
@ -130,6 +131,7 @@ def xy_dataset(num_samples, effect=True, sd_error=1):
"time_val": time_var,
"instrument_names": None,
"dot_graph": None,
"gml_graph": None,
"ate": None,
}
return ret_dict