added support for metalearners from econml

This commit is contained in:
Amit Sharma 2020-02-08 16:59:15 +05:30
Родитель 8d836aa5a1
Коммит 5a41191a0e
2 изменённых файлов: 1136 добавлений и 296 удалений

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -20,9 +20,6 @@ class EconmlCateEstimator(CausalEstimator):
self._observed_common_causes = pd.get_dummies(self._observed_common_causes, drop_first=True)
else:
self._observed_common_causes= None
error_msg ="No common causes/confounders present."
self.logger.error(error_msg)
raise Exception(error_msg)
# Instrumental variables names, if present
self._instrumental_variable_names = self._target_estimand.instrumental_variables
@ -64,8 +61,13 @@ class EconmlCateEstimator(CausalEstimator):
Z = np.reshape(np.array(self._instrumental_variables), (n_samples, len(self._instrumental_variable_names)))
# Calling the econml estimator's fit method
(module_name, _, class_name) = self._econml_methodname.rpartition(".")
if self.identifier_method == "backdoor":
self.estimator.fit(Y, T, X, W, **self.method_params["fit_params"])
if module_name == "econml.metalearners":
# Meta learners only need X (e.g., data is from a randomized experiment)
self.estimator.fit(Y, T, X, **self.method_params["fit_params"])
else:
self.estimator.fit(Y, T, X, W, **self.method_params["fit_params"])
else:
self.estimator.fit(Y, T, X, Z, **self.method_params["fit_params"])
@ -87,6 +89,7 @@ class EconmlCateEstimator(CausalEstimator):
self._treatment_value = [self._treatment_value]
T0_test = np.repeat([self._control_value], n_target_units, axis=0)
T1_test = np.repeat([self._treatment_value], n_target_units, axis=0)
est = self.estimator.effect(X_test, T0 = T0_test, T1 = T1_test)
ate = np.mean(est)