Merge pull request #693 from py-why/functional_api/estimate_effect_function
Functional api/estimate effect function
This commit is contained in:
Коммит
05bfa49dac
|
@ -62,6 +62,12 @@
|
|||
" refute_estimate,\n",
|
||||
") # import refuters\n",
|
||||
"\n",
|
||||
"from dowhy.causal_estimators.propensity_score_matching_estimator import PropensityScoreMatchingEstimator\n",
|
||||
"\n",
|
||||
"from dowhy.utils.api import parse_state\n",
|
||||
"\n",
|
||||
"from dowhy.causal_estimator import estimate_effect # Estimate effect function\n",
|
||||
"\n",
|
||||
"from dowhy.causal_graph import CausalGraph\n",
|
||||
"\n",
|
||||
"# Other imports required\n",
|
||||
|
@ -122,7 +128,9 @@
|
|||
")\n",
|
||||
"\n",
|
||||
"treatment_name = data[\"treatment_name\"]\n",
|
||||
"print(treatment_name)\n",
|
||||
"outcome_name = data[\"outcome_name\"]\n",
|
||||
"print(outcome_name)\n",
|
||||
"\n",
|
||||
"graph = CausalGraph(\n",
|
||||
" treatment_name=treatment_name,\n",
|
||||
|
@ -171,9 +179,7 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Estimate Effect\n",
|
||||
"\n",
|
||||
"Estimate Effect is performed by using the causal_model api as there is not functional equivalent yet"
|
||||
"## Estimate Effect - Functional API (Preview)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -182,10 +188,30 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# We will still need CausalModel as the Functional Effect Estimation is still Work-In-Progress\n",
|
||||
"causal_model = CausalModel(data=data[\"df\"], treatment=treatment_name, outcome=outcome_name, graph=data[\"gml_graph\"])\n",
|
||||
"# Basic Estimate Effect function\n",
|
||||
"\n",
|
||||
"estimate = causal_model.estimate_effect(identified_estimand, method_name=\"backdoor.propensity_score_matching\")\n",
|
||||
"\n",
|
||||
"propensity_score_estimator = PropensityScoreMatchingEstimator(\n",
|
||||
" data=data[\"df\"],\n",
|
||||
" identified_estimand=identified_estimand,\n",
|
||||
" treatment=treatment_name,\n",
|
||||
" outcome=outcome_name,\n",
|
||||
" control_value=0,\n",
|
||||
" treatment_value=1,\n",
|
||||
" test_significance=None,\n",
|
||||
" evaluate_effect_strength=False,\n",
|
||||
" confidence_intervals=False,\n",
|
||||
" target_units=\"ate\",\n",
|
||||
" effect_modifiers=graph.get_effect_modifiers(treatment_name, outcome_name),\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"estimate = estimate_effect(\n",
|
||||
" treatment=treatment_name,\n",
|
||||
" outcome=outcome_name,\n",
|
||||
" identified_estimand=identified_estimand,\n",
|
||||
" identifier_name=\"backdoor\",\n",
|
||||
" method=propensity_score_estimator,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(estimate)"
|
||||
]
|
||||
|
@ -236,6 +262,16 @@
|
|||
"This section shows replicating the same results using only the CausalModel API"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create Causal Model\n",
|
||||
"causal_model = CausalModel(data=data[\"df\"], treatment=treatment_name, outcome=outcome_name, graph=data[\"gml_graph\"])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
@ -297,6 +333,13 @@
|
|||
")\n",
|
||||
"print(data_subset_refutation_causal_model_api)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
@ -7,8 +8,13 @@ import sympy as sp
|
|||
from sklearn.utils import resample
|
||||
|
||||
import dowhy.interpreters as interpreters
|
||||
from dowhy import causal_estimators
|
||||
from dowhy.causal_graph import CausalGraph
|
||||
from dowhy.causal_identifier.identified_estimand import IdentifiedEstimand
|
||||
from dowhy.utils.api import parse_state
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CausalEstimator:
|
||||
"""Base class for an estimator of causal effect.
|
||||
|
@ -124,7 +130,6 @@ class CausalEstimator:
|
|||
self._treatment = self._data[self._treatment_name]
|
||||
self._outcome = self._data[self._outcome_name]
|
||||
|
||||
# Now saving the effect modifiers
|
||||
if self._effect_modifier_names:
|
||||
# only add the observed nodes
|
||||
self._effect_modifier_names = [
|
||||
|
@ -178,7 +183,7 @@ class CausalEstimator:
|
|||
confidence_intervals=estimate.params["confidence_intervals"],
|
||||
target_units=estimate.params["target_units"],
|
||||
effect_modifiers=estimate.params["effect_modifiers"],
|
||||
**estimate.params["method_params"],
|
||||
**estimate.params["method_params"] if estimate.params["method_params"] is not None else {},
|
||||
)
|
||||
|
||||
return new_estimator
|
||||
|
@ -197,6 +202,7 @@ class CausalEstimator:
|
|||
:param self: object instance of class Estimator
|
||||
:returns: A CausalEstimate instance that contains point estimates of average and conditional effects. Based on the parameters provided, it optionally includes confidence intervals, standard errors,statistical significance and other statistical parameters.
|
||||
"""
|
||||
|
||||
est = self._estimate_effect()
|
||||
est.add_estimator(self)
|
||||
|
||||
|
@ -679,6 +685,90 @@ class CausalEstimator:
|
|||
return s
|
||||
|
||||
|
||||
def estimate_effect(
|
||||
treatment: Union[str, List[str]],
|
||||
outcome: Union[str, List[str]],
|
||||
identified_estimand: IdentifiedEstimand,
|
||||
identifier_name: str,
|
||||
method: CausalEstimator,
|
||||
control_value: int = 0,
|
||||
treatment_value: int = 1,
|
||||
test_significance: Optional[bool] = None,
|
||||
evaluate_effect_strength: bool = False,
|
||||
confidence_intervals: bool = False,
|
||||
target_units: str = "ate",
|
||||
effect_modifiers: List[str] = [],
|
||||
fit_estimator: bool = True,
|
||||
method_params: Optional[Dict] = None,
|
||||
):
|
||||
"""Estimate the identified causal effect.
|
||||
|
||||
Currently requires an explicit method name to be specified. Method names follow the convention of identification method followed by the specific estimation method: "[backdoor/iv].estimation_method_name". Following methods are supported.
|
||||
* Propensity Score Matching: "backdoor.propensity_score_matching"
|
||||
* Propensity Score Stratification: "backdoor.propensity_score_stratification"
|
||||
* Propensity Score-based Inverse Weighting: "backdoor.propensity_score_weighting"
|
||||
* Linear Regression: "backdoor.linear_regression"
|
||||
* Generalized Linear Models (e.g., logistic regression): "backdoor.generalized_linear_model"
|
||||
* Instrumental Variables: "iv.instrumental_variable"
|
||||
* Regression Discontinuity: "iv.regression_discontinuity"
|
||||
|
||||
In addition, you can directly call any of the EconML estimation methods. The convention is "backdoor.econml.path-to-estimator-class". For example, for the double machine learning estimator ("DML" class) that is located inside "dml" module of EconML, you can use the method name, "backdoor.econml.dml.DML". CausalML estimators can also be called. See `this demo notebook <https://py-why.github.io/dowhy/example_notebooks/dowhy-conditional-treatment-effects.html>`_.
|
||||
|
||||
:param treatment: Name of the treatment
|
||||
:param outcome: Name of the outcome
|
||||
:param identified_estimand: a probability expression
|
||||
that represents the effect to be estimated. Output of
|
||||
CausalModel.identify_effect method
|
||||
:param method_name: name of the estimation method to be used.
|
||||
:param control_value: Value of the treatment in the control group, for effect estimation. If treatment is multi-variate, this can be a list.
|
||||
:param treatment_value: Value of the treatment in the treated group, for effect estimation. If treatment is multi-variate, this can be a list.
|
||||
:param test_significance: Binary flag on whether to additionally do a statistical signficance test for the estimate.
|
||||
:param evaluate_effect_strength: (Experimental) Binary flag on whether to estimate the relative strength of the treatment's effect. This measure can be used to compare different treatments for the same outcome (by running this method with different treatments sequentially).
|
||||
:param confidence_intervals: (Experimental) Binary flag indicating whether confidence intervals should be computed.
|
||||
:param target_units: (Experimental) The units for which the treatment effect should be estimated. This can be of three types. (1) a string for common specifications of target units (namely, "ate", "att" and "atc"), (2) a lambda function that can be used as an index for the data (pandas DataFrame), or (3) a new DataFrame that contains values of the effect_modifiers and effect will be estimated only for this new data.
|
||||
:param effect_modifiers: Names of effect modifier variables can be (optionally) specified here too, since they do not affect identification. If None, the effect_modifiers from the CausalModel are used.
|
||||
:param fit_estimator: Boolean flag on whether to fit the estimator.
|
||||
Setting it to False is useful to estimate the effect on new data using a previously fitted estimator.
|
||||
:param method_params: Dictionary containing any method-specific parameters. These are passed directly to the estimating method. See the docs for each estimation method for allowed method-specific params.
|
||||
:returns: An instance of the CausalEstimate class, containing the causal effect estimate
|
||||
and other method-dependent information
|
||||
|
||||
"""
|
||||
treatment = parse_state(treatment)
|
||||
outcome = parse_state(outcome)
|
||||
causal_estimator_class = method.__class__
|
||||
|
||||
identified_estimand.set_identifier_method(identifier_name)
|
||||
|
||||
if identified_estimand.no_directed_path:
|
||||
logger.warning("No directed path from {0} to {1}.".format(treatment, outcome))
|
||||
return CausalEstimate(
|
||||
0, identified_estimand, None, control_value=control_value, treatment_value=treatment_value
|
||||
)
|
||||
# Check if estimator's target estimand is identified
|
||||
elif identified_estimand.estimands[identifier_name] is None:
|
||||
logger.error("No valid identified estimand available.")
|
||||
return CausalEstimate(None, None, None, control_value=control_value, treatment_value=treatment_value)
|
||||
|
||||
method.update_input(treatment_value, control_value, target_units)
|
||||
|
||||
estimate = method.estimate_effect()
|
||||
# Store parameters inside estimate object for refutation methods
|
||||
# TODO: This add_params needs to move to the estimator class
|
||||
# inside estimate_effect and estimate_conditional_effect
|
||||
estimate.add_params(
|
||||
estimand_type=identified_estimand.estimand_type,
|
||||
estimator_class=causal_estimator_class,
|
||||
test_significance=test_significance,
|
||||
evaluate_effect_strength=evaluate_effect_strength,
|
||||
confidence_intervals=confidence_intervals,
|
||||
target_units=target_units,
|
||||
effect_modifiers=effect_modifiers,
|
||||
method_params=method_params,
|
||||
)
|
||||
return estimate
|
||||
|
||||
|
||||
class CausalEstimate:
|
||||
"""Class for the estimate object that every causal estimator returns"""
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ import dowhy.causal_estimators as causal_estimators
|
|||
import dowhy.causal_refuters as causal_refuters
|
||||
import dowhy.graph_learners as graph_learners
|
||||
import dowhy.utils.cli_helpers as cli
|
||||
from dowhy.causal_estimator import CausalEstimate
|
||||
from dowhy.causal_estimator import CausalEstimate, estimate_effect
|
||||
from dowhy.causal_graph import CausalGraph
|
||||
from dowhy.causal_identifier import AutoIdentifier, BackdoorAdjustment, IDIdentifier
|
||||
from dowhy.causal_identifier.identify_effect import EstimandType
|
||||
|
@ -259,96 +259,81 @@ class CausalModel:
|
|||
and other method-dependent information
|
||||
|
||||
"""
|
||||
if effect_modifiers is None:
|
||||
if self._effect_modifiers is None or len(self._effect_modifiers) == 0:
|
||||
effect_modifiers = self.get_effect_modifiers()
|
||||
else:
|
||||
effect_modifiers = self._effect_modifiers
|
||||
if fit_estimator:
|
||||
if method_name is None:
|
||||
# TODO add propensity score as default backdoor method, iv as default iv method, add an informational message to show which method has been selected.
|
||||
pass
|
||||
else:
|
||||
# TODO add dowhy as a prefix to all dowhy estimators
|
||||
num_components = len(method_name.split("."))
|
||||
str_arr = method_name.split(".", maxsplit=1)
|
||||
identifier_name = str_arr[0]
|
||||
estimator_name = str_arr[1]
|
||||
identified_estimand.set_identifier_method(identifier_name)
|
||||
# This is done as all dowhy estimators have two parts and external ones have two or more parts
|
||||
if num_components > 2:
|
||||
estimator_package = estimator_name.split(".")[0]
|
||||
if estimator_package == "dowhy": # For updated dowhy methods
|
||||
estimator_method = estimator_name.split(".", maxsplit=1)[
|
||||
1
|
||||
] # discard dowhy from the full package name
|
||||
causal_estimator_class = causal_estimators.get_class_object(estimator_method + "_estimator")
|
||||
else:
|
||||
third_party_estimator_package = estimator_package
|
||||
causal_estimator_class = causal_estimators.get_class_object(
|
||||
third_party_estimator_package, estimator_name
|
||||
)
|
||||
if method_params is None:
|
||||
method_params = {}
|
||||
# Define the third-party estimation method to be used
|
||||
method_params[third_party_estimator_package + "_methodname"] = estimator_name
|
||||
else: # For older dowhy methods
|
||||
self.logger.info(estimator_name)
|
||||
# Process the dowhy estimators
|
||||
causal_estimator_class = causal_estimators.get_class_object(estimator_name + "_estimator")
|
||||
if identified_estimand.no_directed_path:
|
||||
self.logger.warning("No directed path from {0} to {1}.".format(self._treatment, self._outcome))
|
||||
return CausalEstimate(
|
||||
0, identified_estimand, None, control_value=control_value, treatment_value=treatment_value
|
||||
)
|
||||
# Check if estimator's target estimand is identified
|
||||
elif identified_estimand.estimands[identifier_name] is None:
|
||||
self.logger.error("No valid identified estimand available.")
|
||||
return CausalEstimate(None, None, None, control_value=control_value, treatment_value=treatment_value)
|
||||
else:
|
||||
|
||||
if method_params is not None and (num_components <= 2 or estimator_package == "dowhy"):
|
||||
extra_args = method_params.get("init_params", {})
|
||||
else:
|
||||
extra_args = {}
|
||||
if method_params is None:
|
||||
method_params = {}
|
||||
self.causal_estimator = causal_estimator_class(
|
||||
self._data,
|
||||
identified_estimand,
|
||||
self._treatment,
|
||||
self._outcome, # names of treatment and outcome
|
||||
control_value=control_value,
|
||||
treatment_value=treatment_value,
|
||||
test_significance=test_significance,
|
||||
evaluate_effect_strength=evaluate_effect_strength,
|
||||
confidence_intervals=confidence_intervals,
|
||||
target_units=target_units,
|
||||
effect_modifiers=effect_modifiers,
|
||||
**method_params,
|
||||
**extra_args,
|
||||
)
|
||||
if effect_modifiers is None or len(effect_modifiers) == 0:
|
||||
effect_modifiers = self._graph.get_effect_modifiers(self._treatment, self._outcome)
|
||||
|
||||
if method_name is None:
|
||||
# TODO add propensity score as default backdoor method, iv as default iv method, add an informational message to show which method has been selected.
|
||||
pass
|
||||
else:
|
||||
# Estimator had been computed in a previous call
|
||||
assert self.causal_estimator is not None
|
||||
causal_estimator_class = self.causal_estimator.__class__
|
||||
self.causal_estimator.update_input(treatment_value, control_value, target_units)
|
||||
# TODO add dowhy as a prefix to all dowhy estimators
|
||||
num_components = len(method_name.split("."))
|
||||
str_arr = method_name.split(".", maxsplit=1)
|
||||
identifier_name = str_arr[0]
|
||||
estimator_name = str_arr[1]
|
||||
# This is done as all dowhy estimators have two parts and external ones have two or more parts
|
||||
if num_components > 2:
|
||||
estimator_package = estimator_name.split(".")[0]
|
||||
if estimator_package == "dowhy": # For updated dowhy methods
|
||||
estimator_method = estimator_name.split(".", maxsplit=1)[
|
||||
1
|
||||
] # discard dowhy from the full package name
|
||||
causal_estimator_class = causal_estimators.get_class_object(estimator_method + "_estimator")
|
||||
else:
|
||||
third_party_estimator_package = estimator_package
|
||||
causal_estimator_class = causal_estimators.get_class_object(
|
||||
third_party_estimator_package, estimator_name
|
||||
)
|
||||
if method_params is None:
|
||||
method_params = {}
|
||||
# Define the third-party estimation method to be used
|
||||
method_params[third_party_estimator_package + "_methodname"] = estimator_name
|
||||
else: # For older dowhy methods
|
||||
self.logger.info(estimator_name)
|
||||
# Process the dowhy estimators
|
||||
causal_estimator_class = causal_estimators.get_class_object(estimator_name + "_estimator")
|
||||
|
||||
estimate = self.causal_estimator.estimate_effect()
|
||||
# Store parameters inside estimate object for refutation methods
|
||||
# TODO: This add_params needs to move to the estimator class
|
||||
# inside estimate_effect and estimate_conditional_effect
|
||||
estimate.add_params(
|
||||
estimand_type=identified_estimand.estimand_type,
|
||||
estimator_class=causal_estimator_class,
|
||||
test_significance=test_significance,
|
||||
evaluate_effect_strength=evaluate_effect_strength,
|
||||
confidence_intervals=confidence_intervals,
|
||||
target_units=target_units,
|
||||
effect_modifiers=effect_modifiers,
|
||||
method_params=method_params,
|
||||
if method_params is not None and (num_components <= 2 or estimator_package == "dowhy"):
|
||||
extra_args = method_params.get("init_params", {})
|
||||
else:
|
||||
extra_args = {}
|
||||
if method_params is None:
|
||||
method_params = {}
|
||||
|
||||
identified_estimand.set_identifier_method(identifier_name)
|
||||
causal_estimator = causal_estimator_class(
|
||||
self._data,
|
||||
identified_estimand,
|
||||
self._treatment,
|
||||
self._outcome, # names of treatment and outcome
|
||||
control_value=control_value,
|
||||
treatment_value=treatment_value,
|
||||
test_significance=test_significance,
|
||||
evaluate_effect_strength=evaluate_effect_strength,
|
||||
confidence_intervals=confidence_intervals,
|
||||
target_units=target_units,
|
||||
effect_modifiers=effect_modifiers,
|
||||
**method_params,
|
||||
**extra_args,
|
||||
)
|
||||
|
||||
return estimate_effect(
|
||||
self._treatment,
|
||||
self._outcome,
|
||||
identified_estimand,
|
||||
identifier_name,
|
||||
causal_estimator,
|
||||
control_value,
|
||||
treatment_value,
|
||||
test_significance,
|
||||
evaluate_effect_strength,
|
||||
confidence_intervals,
|
||||
target_units,
|
||||
effect_modifiers,
|
||||
fit_estimator,
|
||||
method_params,
|
||||
)
|
||||
return estimate
|
||||
|
||||
def do(self, x, identified_estimand, method_name=None, fit_estimator=True, method_params=None):
|
||||
"""Do operator for estimating values of the outcome after intervening on treatment.
|
||||
|
|
Загрузка…
Ссылка в новой задаче