Merge pull request #717 from py-why/fix/avoid_retrain_estimator
Fix avoid retrain estimator on causal_model API
This commit is contained in:
Коммит
a18c97eb2e
|
@ -73,6 +73,7 @@ class CausalModel:
|
|||
self._proceed_when_unidentifiable = proceed_when_unidentifiable
|
||||
self._missing_nodes_as_confounders = missing_nodes_as_confounders
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self._estimator_cache = {}
|
||||
|
||||
if graph is None:
|
||||
self.logger.warning("Causal Graph not provided. DoWhy will construct a graph based on data inputs.")
|
||||
|
@ -302,21 +303,26 @@ class CausalModel:
|
|||
method_params = {}
|
||||
|
||||
identified_estimand.set_identifier_method(identifier_name)
|
||||
causal_estimator = causal_estimator_class(
|
||||
self._data,
|
||||
identified_estimand,
|
||||
self._treatment,
|
||||
self._outcome, # names of treatment and outcome
|
||||
control_value=control_value,
|
||||
treatment_value=treatment_value,
|
||||
test_significance=test_significance,
|
||||
evaluate_effect_strength=evaluate_effect_strength,
|
||||
confidence_intervals=confidence_intervals,
|
||||
target_units=target_units,
|
||||
effect_modifiers=effect_modifiers,
|
||||
**method_params,
|
||||
**extra_args,
|
||||
)
|
||||
|
||||
if not fit_estimator and method_name in self._estimator_cache:
|
||||
causal_estimator = self._estimator_cache[method_name]
|
||||
else:
|
||||
causal_estimator = causal_estimator_class(
|
||||
self._data,
|
||||
identified_estimand,
|
||||
self._treatment,
|
||||
self._outcome, # names of treatment and outcome
|
||||
control_value=control_value,
|
||||
treatment_value=treatment_value,
|
||||
test_significance=test_significance,
|
||||
evaluate_effect_strength=evaluate_effect_strength,
|
||||
confidence_intervals=confidence_intervals,
|
||||
target_units=target_units,
|
||||
effect_modifiers=effect_modifiers,
|
||||
**method_params,
|
||||
**extra_args,
|
||||
)
|
||||
self._estimator_cache[method_name] = causal_estimator
|
||||
|
||||
return estimate_effect(
|
||||
self._treatment,
|
||||
|
|
Загрузка…
Ссылка в новой задаче