added support for metalearners from econml
This commit is contained in:
Родитель
8d836aa5a1
Коммит
5a41191a0e
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -20,9 +20,6 @@ class EconmlCateEstimator(CausalEstimator):
|
|||
self._observed_common_causes = pd.get_dummies(self._observed_common_causes, drop_first=True)
|
||||
else:
|
||||
self._observed_common_causes= None
|
||||
error_msg ="No common causes/confounders present."
|
||||
self.logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
# Instrumental variables names, if present
|
||||
self._instrumental_variable_names = self._target_estimand.instrumental_variables
|
||||
|
@ -64,8 +61,13 @@ class EconmlCateEstimator(CausalEstimator):
|
|||
Z = np.reshape(np.array(self._instrumental_variables), (n_samples, len(self._instrumental_variable_names)))
|
||||
|
||||
# Calling the econml estimator's fit method
|
||||
(module_name, _, class_name) = self._econml_methodname.rpartition(".")
|
||||
if self.identifier_method == "backdoor":
|
||||
self.estimator.fit(Y, T, X, W, **self.method_params["fit_params"])
|
||||
if module_name == "econml.metalearners":
|
||||
# Meta learners only need X (e.g., data is from a randomized experiment)
|
||||
self.estimator.fit(Y, T, X, **self.method_params["fit_params"])
|
||||
else:
|
||||
self.estimator.fit(Y, T, X, W, **self.method_params["fit_params"])
|
||||
else:
|
||||
self.estimator.fit(Y, T, X, Z, **self.method_params["fit_params"])
|
||||
|
||||
|
@ -87,6 +89,7 @@ class EconmlCateEstimator(CausalEstimator):
|
|||
self._treatment_value = [self._treatment_value]
|
||||
T0_test = np.repeat([self._control_value], n_target_units, axis=0)
|
||||
T1_test = np.repeat([self._treatment_value], n_target_units, axis=0)
|
||||
|
||||
est = self.estimator.effect(X_test, T0 = T0_test, T1 = T1_test)
|
||||
ate = np.mean(est)
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче