Signed-off-by: Andres Morales <andresmor@microsoft.com>
This commit is contained in:
Andres Morales 2022-10-19 12:01:42 -06:00
Родитель 196c64ac0c
Коммит 6995e467db
3 изменённых файлов: 102 добавлений и 74 удалений

Просмотреть файл

@@ -62,6 +62,10 @@
" refute_estimate,\n",
") # import refuters\n",
"\n",
"from dowhy.causal_estimators.propensity_score_matching_estimator import PropensityScoreMatchingEstimator\n",
"\n",
"from dowhy.utils.api import parse_state\n",
"\n",
"from dowhy.causal_estimator import estimate_effect # Estimate effect function\n",
"\n",
"from dowhy.causal_graph import CausalGraph\n",
@@ -124,7 +128,9 @@
")\n",
"\n",
"treatment_name = data[\"treatment_name\"]\n",
"print(treatment_name)\n",
"outcome_name = data[\"outcome_name\"]\n",
"print(outcome_name)\n",
"\n",
"graph = CausalGraph(\n",
" treatment_name=treatment_name,\n",
@@ -184,13 +190,27 @@
"source": [
"# Basic Estimate Effect function\n",
"\n",
"\n",
"propensity_score_estimator = PropensityScoreMatchingEstimator(\n",
" data=data[\"df\"],\n",
" identified_estimand=identified_estimand,\n",
" treatment=treatment_name,\n",
" outcome=outcome_name,\n",
" control_value=0,\n",
" treatment_value=1,\n",
" test_significance=None,\n",
" evaluate_effect_strength=False,\n",
" confidence_intervals=False,\n",
" target_units=\"ate\",\n",
" effect_modifiers=graph.get_effect_modifiers(treatment_name, outcome_name),\n",
")\n",
"\n",
"estimate = estimate_effect(\n",
" data[\"df\"],\n",
" treatment_name,\n",
" outcome_name,\n",
" identified_estimand,\n",
" graph,\n",
" method_name=\"backdoor.propensity_score_matching\",\n",
" treatment=treatment_name,\n",
" outcome=outcome_name,\n",
" identified_estimand=identified_estimand,\n",
" identifier_name=\"backdoor\",\n",
" method=propensity_score_estimator,\n",
")\n",
"\n",
"print(estimate)"
@@ -313,6 +333,13 @@
")\n",
"print(data_subset_refutation_causal_model_api)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

Просмотреть файл

@@ -183,7 +183,7 @@ class CausalEstimator:
confidence_intervals=estimate.params["confidence_intervals"],
target_units=estimate.params["target_units"],
effect_modifiers=estimate.params["effect_modifiers"],
**estimate.params["method_params"],
**estimate.params["method_params"] if estimate.params["method_params"] is not None else {},
)
return new_estimator
@@ -686,19 +686,18 @@ class CausalEstimator:
def estimate_effect(
data: pd.DataFrame,
treatment: Union[str, List[str]],
outcome: Union[str, List[str]],
identified_estimand: IdentifiedEstimand,
graph: CausalGraph,
method_name: Optional[str] = None,
identifier_name: str,
method: CausalEstimator,
control_value: int = 0,
treatment_value: int = 1,
test_significance: Optional[bool] = None,
evaluate_effect_strength: bool = False,
confidence_intervals: bool = False,
target_units: str = "ate",
effect_modifiers: List[str] = None,
effect_modifiers: List[str] = [],
fit_estimator: bool = True,
method_params: Optional[Dict] = None,
):
@@ -736,42 +735,11 @@ def estimate_effect(
"""
treatment = parse_state(treatment)
outcome = parse_state(outcome)
causal_estimator_class = method.__class__
if effect_modifiers is None or len(effect_modifiers) == 0:
effect_modifiers = graph.get_effect_modifiers(treatment, outcome)
identified_estimand.set_identifier_method(identifier_name)
if fit_estimator:
if method_name is None:
# TODO add propensity score as default backdoor method, iv as default iv method, add an informational message to show which method has been selected.
pass
else:
# TODO add dowhy as a prefix to all dowhy estimators
num_components = len(method_name.split("."))
str_arr = method_name.split(".", maxsplit=1)
identifier_name = str_arr[0]
estimator_name = str_arr[1]
identified_estimand.set_identifier_method(identifier_name)
# This is done as all dowhy estimators have two parts and external ones have two or more parts
if num_components > 2:
estimator_package = estimator_name.split(".")[0]
if estimator_package == "dowhy": # For updated dowhy methods
estimator_method = estimator_name.split(".", maxsplit=1)[
1
] # discard dowhy from the full package name
causal_estimator_class = causal_estimators.get_class_object(estimator_method + "_estimator")
else:
third_party_estimator_package = estimator_package
causal_estimator_class = causal_estimators.get_class_object(
third_party_estimator_package, estimator_name
)
if method_params is None:
method_params = {}
# Define the third-party estimation method to be used
method_params[third_party_estimator_package + "_methodname"] = estimator_name
else: # For older dowhy methods
logger.info(estimator_name)
# Process the dowhy estimators
causal_estimator_class = causal_estimators.get_class_object(estimator_name + "_estimator")
if identified_estimand.no_directed_path:
logger.warning("No directed path from {0} to {1}.".format(treatment, outcome))
return CausalEstimate(
@@ -781,36 +749,11 @@ def estimate_effect(
elif identified_estimand.estimands[identifier_name] is None:
logger.error("No valid identified estimand available.")
return CausalEstimate(None, None, None, control_value=control_value, treatment_value=treatment_value)
else:
if method_params is not None and (num_components <= 2 or estimator_package == "dowhy"):
extra_args = method_params.get("init_params", {})
else:
extra_args = {}
if method_params is None:
method_params = {}
causal_estimator = causal_estimator_class(
data,
identified_estimand,
treatment,
outcome, # names of treatment and outcome
control_value=control_value,
treatment_value=treatment_value,
test_significance=test_significance,
evaluate_effect_strength=evaluate_effect_strength,
confidence_intervals=confidence_intervals,
target_units=target_units,
effect_modifiers=effect_modifiers,
**method_params,
**extra_args,
)
else:
# Estimator had been computed in a previous call
assert causal_estimator is not None
causal_estimator_class = causal_estimator.__class__
causal_estimator.update_input(treatment_value, control_value, target_units)
method.update_input(treatment_value, control_value, target_units)
estimate = causal_estimator.estimate_effect()
estimate = method.estimate_effect()
# Store parameters inside estimate object for refutation methods
# TODO: This add_params needs to move to the estimator class
# inside estimate_effect and estimate_conditional_effect

Просмотреть файл

@@ -259,13 +259,71 @@ class CausalModel:
and other method-dependent information
"""
if effect_modifiers is None or len(effect_modifiers) == 0:
effect_modifiers = self._graph.get_effect_modifiers(self._treatment, self._outcome)
if method_name is None:
# TODO add propensity score as default backdoor method, iv as default iv method, add an informational message to show which method has been selected.
pass
else:
# TODO add dowhy as a prefix to all dowhy estimators
num_components = len(method_name.split("."))
str_arr = method_name.split(".", maxsplit=1)
identifier_name = str_arr[0]
estimator_name = str_arr[1]
# This is done as all dowhy estimators have two parts and external ones have two or more parts
if num_components > 2:
estimator_package = estimator_name.split(".")[0]
if estimator_package == "dowhy": # For updated dowhy methods
estimator_method = estimator_name.split(".", maxsplit=1)[
1
] # discard dowhy from the full package name
causal_estimator_class = causal_estimators.get_class_object(estimator_method + "_estimator")
else:
third_party_estimator_package = estimator_package
causal_estimator_class = causal_estimators.get_class_object(
third_party_estimator_package, estimator_name
)
if method_params is None:
method_params = {}
# Define the third-party estimation method to be used
method_params[third_party_estimator_package + "_methodname"] = estimator_name
else: # For older dowhy methods
self.logger.info(estimator_name)
# Process the dowhy estimators
causal_estimator_class = causal_estimators.get_class_object(estimator_name + "_estimator")
if method_params is not None and (num_components <= 2 or estimator_package == "dowhy"):
extra_args = method_params.get("init_params", {})
else:
extra_args = {}
if method_params is None:
method_params = {}
identified_estimand.set_identifier_method(identifier_name)
causal_estimator = causal_estimator_class(
self._data,
identified_estimand,
self._treatment,
self._outcome, # names of treatment and outcome
control_value=control_value,
treatment_value=treatment_value,
test_significance=test_significance,
evaluate_effect_strength=evaluate_effect_strength,
confidence_intervals=confidence_intervals,
target_units=target_units,
effect_modifiers=effect_modifiers,
**method_params,
**extra_args,
)
return estimate_effect(
self._data,
self._treatment,
self._outcome,
identified_estimand,
self._graph,
method_name,
identifier_name,
causal_estimator,
control_value,
treatment_value,
test_significance,