method_name to class instance
Signed-off-by: Andres Morales <andresmor@microsoft.com>
This commit is contained in:
Родитель
196c64ac0c
Коммит
6995e467db
|
@ -62,6 +62,10 @@
|
|||
" refute_estimate,\n",
|
||||
") # import refuters\n",
|
||||
"\n",
|
||||
"from dowhy.causal_estimators.propensity_score_matching_estimator import PropensityScoreMatchingEstimator\n",
|
||||
"\n",
|
||||
"from dowhy.utils.api import parse_state\n",
|
||||
"\n",
|
||||
"from dowhy.causal_estimator import estimate_effect # Estimate effect function\n",
|
||||
"\n",
|
||||
"from dowhy.causal_graph import CausalGraph\n",
|
||||
|
@ -124,7 +128,9 @@
|
|||
")\n",
|
||||
"\n",
|
||||
"treatment_name = data[\"treatment_name\"]\n",
|
||||
"print(treatment_name)\n",
|
||||
"outcome_name = data[\"outcome_name\"]\n",
|
||||
"print(outcome_name)\n",
|
||||
"\n",
|
||||
"graph = CausalGraph(\n",
|
||||
" treatment_name=treatment_name,\n",
|
||||
|
@ -184,13 +190,27 @@
|
|||
"source": [
|
||||
"# Basic Estimate Effect function\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"propensity_score_estimator = PropensityScoreMatchingEstimator(\n",
|
||||
" data=data[\"df\"],\n",
|
||||
" identified_estimand=identified_estimand,\n",
|
||||
" treatment=treatment_name,\n",
|
||||
" outcome=outcome_name,\n",
|
||||
" control_value=0,\n",
|
||||
" treatment_value=1,\n",
|
||||
" test_significance=None,\n",
|
||||
" evaluate_effect_strength=False,\n",
|
||||
" confidence_intervals=False,\n",
|
||||
" target_units=\"ate\",\n",
|
||||
" effect_modifiers=graph.get_effect_modifiers(treatment_name, outcome_name),\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"estimate = estimate_effect(\n",
|
||||
" data[\"df\"],\n",
|
||||
" treatment_name,\n",
|
||||
" outcome_name,\n",
|
||||
" identified_estimand,\n",
|
||||
" graph,\n",
|
||||
" method_name=\"backdoor.propensity_score_matching\",\n",
|
||||
" treatment=treatment_name,\n",
|
||||
" outcome=outcome_name,\n",
|
||||
" identified_estimand=identified_estimand,\n",
|
||||
" identifier_name=\"backdoor\",\n",
|
||||
" method=propensity_score_estimator,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(estimate)"
|
||||
|
@ -313,6 +333,13 @@
|
|||
")\n",
|
||||
"print(data_subset_refutation_causal_model_api)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
|
@ -183,7 +183,7 @@ class CausalEstimator:
|
|||
confidence_intervals=estimate.params["confidence_intervals"],
|
||||
target_units=estimate.params["target_units"],
|
||||
effect_modifiers=estimate.params["effect_modifiers"],
|
||||
**estimate.params["method_params"],
|
||||
**estimate.params["method_params"] if estimate.params["method_params"] is not None else {},
|
||||
)
|
||||
|
||||
return new_estimator
|
||||
|
@ -686,19 +686,18 @@ class CausalEstimator:
|
|||
|
||||
|
||||
def estimate_effect(
|
||||
data: pd.DataFrame,
|
||||
treatment: Union[str, List[str]],
|
||||
outcome: Union[str, List[str]],
|
||||
identified_estimand: IdentifiedEstimand,
|
||||
graph: CausalGraph,
|
||||
method_name: Optional[str] = None,
|
||||
identifier_name: str,
|
||||
method: CausalEstimator,
|
||||
control_value: int = 0,
|
||||
treatment_value: int = 1,
|
||||
test_significance: Optional[bool] = None,
|
||||
evaluate_effect_strength: bool = False,
|
||||
confidence_intervals: bool = False,
|
||||
target_units: str = "ate",
|
||||
effect_modifiers: List[str] = None,
|
||||
effect_modifiers: List[str] = [],
|
||||
fit_estimator: bool = True,
|
||||
method_params: Optional[Dict] = None,
|
||||
):
|
||||
|
@ -736,42 +735,11 @@ def estimate_effect(
|
|||
"""
|
||||
treatment = parse_state(treatment)
|
||||
outcome = parse_state(outcome)
|
||||
causal_estimator_class = method.__class__
|
||||
|
||||
if effect_modifiers is None or len(effect_modifiers) == 0:
|
||||
effect_modifiers = graph.get_effect_modifiers(treatment, outcome)
|
||||
identified_estimand.set_identifier_method(identifier_name)
|
||||
|
||||
if fit_estimator:
|
||||
if method_name is None:
|
||||
# TODO add propensity score as default backdoor method, iv as default iv method, add an informational message to show which method has been selected.
|
||||
pass
|
||||
else:
|
||||
# TODO add dowhy as a prefix to all dowhy estimators
|
||||
num_components = len(method_name.split("."))
|
||||
str_arr = method_name.split(".", maxsplit=1)
|
||||
identifier_name = str_arr[0]
|
||||
estimator_name = str_arr[1]
|
||||
identified_estimand.set_identifier_method(identifier_name)
|
||||
# This is done as all dowhy estimators have two parts and external ones have two or more parts
|
||||
if num_components > 2:
|
||||
estimator_package = estimator_name.split(".")[0]
|
||||
if estimator_package == "dowhy": # For updated dowhy methods
|
||||
estimator_method = estimator_name.split(".", maxsplit=1)[
|
||||
1
|
||||
] # discard dowhy from the full package name
|
||||
causal_estimator_class = causal_estimators.get_class_object(estimator_method + "_estimator")
|
||||
else:
|
||||
third_party_estimator_package = estimator_package
|
||||
causal_estimator_class = causal_estimators.get_class_object(
|
||||
third_party_estimator_package, estimator_name
|
||||
)
|
||||
if method_params is None:
|
||||
method_params = {}
|
||||
# Define the third-party estimation method to be used
|
||||
method_params[third_party_estimator_package + "_methodname"] = estimator_name
|
||||
else: # For older dowhy methods
|
||||
logger.info(estimator_name)
|
||||
# Process the dowhy estimators
|
||||
causal_estimator_class = causal_estimators.get_class_object(estimator_name + "_estimator")
|
||||
if identified_estimand.no_directed_path:
|
||||
logger.warning("No directed path from {0} to {1}.".format(treatment, outcome))
|
||||
return CausalEstimate(
|
||||
|
@ -781,36 +749,11 @@ def estimate_effect(
|
|||
elif identified_estimand.estimands[identifier_name] is None:
|
||||
logger.error("No valid identified estimand available.")
|
||||
return CausalEstimate(None, None, None, control_value=control_value, treatment_value=treatment_value)
|
||||
else:
|
||||
|
||||
if method_params is not None and (num_components <= 2 or estimator_package == "dowhy"):
|
||||
extra_args = method_params.get("init_params", {})
|
||||
else:
|
||||
extra_args = {}
|
||||
if method_params is None:
|
||||
method_params = {}
|
||||
causal_estimator = causal_estimator_class(
|
||||
data,
|
||||
identified_estimand,
|
||||
treatment,
|
||||
outcome, # names of treatment and outcome
|
||||
control_value=control_value,
|
||||
treatment_value=treatment_value,
|
||||
test_significance=test_significance,
|
||||
evaluate_effect_strength=evaluate_effect_strength,
|
||||
confidence_intervals=confidence_intervals,
|
||||
target_units=target_units,
|
||||
effect_modifiers=effect_modifiers,
|
||||
**method_params,
|
||||
**extra_args,
|
||||
)
|
||||
else:
|
||||
# Estimator had been computed in a previous call
|
||||
assert causal_estimator is not None
|
||||
causal_estimator_class = causal_estimator.__class__
|
||||
causal_estimator.update_input(treatment_value, control_value, target_units)
|
||||
method.update_input(treatment_value, control_value, target_units)
|
||||
|
||||
estimate = causal_estimator.estimate_effect()
|
||||
estimate = method.estimate_effect()
|
||||
# Store parameters inside estimate object for refutation methods
|
||||
# TODO: This add_params needs to move to the estimator class
|
||||
# inside estimate_effect and estimate_conditional_effect
|
||||
|
|
|
@ -259,13 +259,71 @@ class CausalModel:
|
|||
and other method-dependent information
|
||||
|
||||
"""
|
||||
|
||||
if effect_modifiers is None or len(effect_modifiers) == 0:
|
||||
effect_modifiers = self._graph.get_effect_modifiers(self._treatment, self._outcome)
|
||||
|
||||
if method_name is None:
|
||||
# TODO add propensity score as default backdoor method, iv as default iv method, add an informational message to show which method has been selected.
|
||||
pass
|
||||
else:
|
||||
# TODO add dowhy as a prefix to all dowhy estimators
|
||||
num_components = len(method_name.split("."))
|
||||
str_arr = method_name.split(".", maxsplit=1)
|
||||
identifier_name = str_arr[0]
|
||||
estimator_name = str_arr[1]
|
||||
# This is done as all dowhy estimators have two parts and external ones have two or more parts
|
||||
if num_components > 2:
|
||||
estimator_package = estimator_name.split(".")[0]
|
||||
if estimator_package == "dowhy": # For updated dowhy methods
|
||||
estimator_method = estimator_name.split(".", maxsplit=1)[
|
||||
1
|
||||
] # discard dowhy from the full package name
|
||||
causal_estimator_class = causal_estimators.get_class_object(estimator_method + "_estimator")
|
||||
else:
|
||||
third_party_estimator_package = estimator_package
|
||||
causal_estimator_class = causal_estimators.get_class_object(
|
||||
third_party_estimator_package, estimator_name
|
||||
)
|
||||
if method_params is None:
|
||||
method_params = {}
|
||||
# Define the third-party estimation method to be used
|
||||
method_params[third_party_estimator_package + "_methodname"] = estimator_name
|
||||
else: # For older dowhy methods
|
||||
self.logger.info(estimator_name)
|
||||
# Process the dowhy estimators
|
||||
causal_estimator_class = causal_estimators.get_class_object(estimator_name + "_estimator")
|
||||
|
||||
if method_params is not None and (num_components <= 2 or estimator_package == "dowhy"):
|
||||
extra_args = method_params.get("init_params", {})
|
||||
else:
|
||||
extra_args = {}
|
||||
if method_params is None:
|
||||
method_params = {}
|
||||
|
||||
identified_estimand.set_identifier_method(identifier_name)
|
||||
causal_estimator = causal_estimator_class(
|
||||
self._data,
|
||||
identified_estimand,
|
||||
self._treatment,
|
||||
self._outcome, # names of treatment and outcome
|
||||
control_value=control_value,
|
||||
treatment_value=treatment_value,
|
||||
test_significance=test_significance,
|
||||
evaluate_effect_strength=evaluate_effect_strength,
|
||||
confidence_intervals=confidence_intervals,
|
||||
target_units=target_units,
|
||||
effect_modifiers=effect_modifiers,
|
||||
**method_params,
|
||||
**extra_args,
|
||||
)
|
||||
|
||||
return estimate_effect(
|
||||
self._data,
|
||||
self._treatment,
|
||||
self._outcome,
|
||||
identified_estimand,
|
||||
self._graph,
|
||||
method_name,
|
||||
identifier_name,
|
||||
causal_estimator,
|
||||
control_value,
|
||||
treatment_value,
|
||||
test_significance,
|
||||
|
|
Загрузка…
Ссылка в новой задаче