Merge pull request #717 from py-why/fix/avoid_retrain_estimator

Fix avoid retrain estimator on causal_model API
This commit is contained in:
Andres Morales 2022-10-27 15:13:16 -06:00 коммит произвёл GitHub
Родитель 6dc354c075 80e32cd4d1
Коммит a18c97eb2e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 21 добавлений и 15 удалений

Просмотреть файл

@ -73,6 +73,7 @@ class CausalModel:
self._proceed_when_unidentifiable = proceed_when_unidentifiable
self._missing_nodes_as_confounders = missing_nodes_as_confounders
self.logger = logging.getLogger(__name__)
self._estimator_cache = {}
if graph is None:
self.logger.warning("Causal Graph not provided. DoWhy will construct a graph based on data inputs.")
@ -302,21 +303,26 @@ class CausalModel:
method_params = {}
identified_estimand.set_identifier_method(identifier_name)
causal_estimator = causal_estimator_class(
self._data,
identified_estimand,
self._treatment,
self._outcome, # names of treatment and outcome
control_value=control_value,
treatment_value=treatment_value,
test_significance=test_significance,
evaluate_effect_strength=evaluate_effect_strength,
confidence_intervals=confidence_intervals,
target_units=target_units,
effect_modifiers=effect_modifiers,
**method_params,
**extra_args,
)
if not fit_estimator and method_name in self._estimator_cache:
causal_estimator = self._estimator_cache[method_name]
else:
causal_estimator = causal_estimator_class(
self._data,
identified_estimand,
self._treatment,
self._outcome, # names of treatment and outcome
control_value=control_value,
treatment_value=treatment_value,
test_significance=test_significance,
evaluate_effect_strength=evaluate_effect_strength,
confidence_intervals=confidence_intervals,
target_units=target_units,
effect_modifiers=effect_modifiers,
**method_params,
**extra_args,
)
self._estimator_cache[method_name] = causal_estimator
return estimate_effect(
self._treatment,