Updated EconML names and argument parsing. (#229)

* Updated EconML names and argument parsing.

* Changed EconML namespaces to match v0.8.1. Further name
  changes will be implemented in v0.9
* Updated argument parsing. As of v0.9, some `fit` arguments will be
  passed in by keyword only
* Updated tests and example notebooks to be compatible with the latest
  EconML changes
* Fix Windows compatibility of example notebook

* Point econml dependency to specific branch to debug build failures.

* Account for EconML internal features

* EconML concatenates the common causes W and effect modifiers X
  internally, so there is no need to do it explicitly in DoWhy.
* EconML has better support for pandas DataFrames, so I removed the
  casting to NumPy arrays.
This commit is contained in:
Miruna Oprescu 2021-02-17 08:47:43 -05:00 коммит произвёл GitHub
Родитель 2681d5eb6c
Коммит 986134c9ad
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
8 изменённых файлов: 3177 добавлений и 7637 удалений

3
.github/workflows/python-package.yml поставляемый
Просмотреть файл

@ -28,7 +28,8 @@ jobs:
python -m pip install --upgrade pip
pip install flake8 pytest twine
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install econml causalml nbformat jupyter
pip install causalml nbformat jupyter
pip install econml
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names

Просмотреть файл

@ -398,7 +398,7 @@ learning estimator.
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LassoCV
from sklearn.ensemble import GradientBoostingRegressor
dml_estimate = model.estimate_effect(identified_estimand, method_name="backdoor.econml.dml.DMLCateEstimator",
dml_estimate = model.estimate_effect(identified_estimand, method_name="backdoor.econml.dml.DML",
control_value = 0,
treatment_value = 1,
target_units = lambda df: df["X0"]>1,

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -14,17 +14,6 @@ class Econml(CausalEstimator):
self.logger.info("INFO: Using EconML Estimator")
self.identifier_method = self._target_estimand.identifier_method
self._observed_common_causes_names = self._target_estimand.get_backdoor_variables().copy()
# Checking if effect modifiers are a subset of common causes
x_subsetof_w = True
unique_effect_modifier_names = []
for em_name in self._effect_modifier_names:
if em_name not in self._observed_common_causes_names:
x_subsetof_w = False
unique_effect_modifier_names.append(em_name)
if not x_subsetof_w:
self.logger.warn("Effect modifiers are not a subset of common causes. For efficiency in estimation, EconML will consider all effect modifiers as common causes too.")
self._observed_common_causes_names.extend(unique_effect_modifier_names)
# For metalearners only--issue a warning if w contains variables not in x
(module_name, _, class_name) = self._econml_methodname.rpartition(".")
if module_name.endswith("metalearners"):
@ -80,20 +69,21 @@ class Econml(CausalEstimator):
X = None # Effect modifiers
W = None # common causes/ confounders
Z = None # Instruments
Y = np.array(self._outcome)
T = np.array(self._treatment)
Y = self._outcome
T = self._treatment
if self._effect_modifiers is not None:
X = np.reshape(np.array(self._effect_modifiers), (n_samples, self._effect_modifiers.shape[1]))
X = self._effect_modifiers
if self._observed_common_causes_names:
W = np.reshape(np.array(self._observed_common_causes), (n_samples, self._observed_common_causes.shape[1]))
W = self._observed_common_causes
if self._instrumental_variable_names:
Z = np.array(self._instrumental_variables)
Z = self._instrumental_variables
named_data_args = {'Y': Y, 'T': T, 'X': X, 'W': W, 'Z': Z}
# Calling the econml estimator's fit method
estimator_named_args = inspect.getfullargspec(
inspect.unwrap(self.estimator.fit)
)[0]
estimator_argspec = inspect.getfullargspec(
inspect.unwrap(self.estimator.fit))
# As of v0.9, econml has some keyword-only arguments
estimator_named_args = estimator_argspec.args + estimator_argspec.kwonlyargs
estimator_data_args = {
arg: named_data_args[arg] for arg in named_data_args.keys() if arg in estimator_named_args
}

Просмотреть файл

@ -158,7 +158,7 @@ class CausalModel:
* Instrumental Variables: "iv.instrumental_variable"
* Regression Discontinuity: "iv.regression_discontinuity"
In addition, you can directly call any of the EconML estimation methods. The convention is "backdoor.econml.path-to-estimator-class". For example, for the double machine learning estimator ("DMLCateEstimator" class) that is located inside "dml" module of EconML, you can use the method name, "backdoor.econml.dml.DMLCateEstimator". CausalML estimators can also be called. See `this demo notebook <https://microsoft.github.io/dowhy/example_notebooks/dowhy-conditional-treatment-effects.html>`_.
In addition, you can directly call any of the EconML estimation methods. The convention is "backdoor.econml.path-to-estimator-class". For example, for the double machine learning estimator ("DML" class) that is located inside "dml" module of EconML, you can use the method name, "backdoor.econml.dml.DML". CausalML estimators can also be called. See `this demo notebook <https://microsoft.github.io/dowhy/example_notebooks/dowhy-conditional-treatment-effects.html>`_.
:param identified_estimand: a probability expression

Просмотреть файл

@ -4,19 +4,20 @@ import itertools
from dowhy import CausalModel
from dowhy import datasets
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
from sklearn.preprocessing import PolynomialFeatures
econml = pytest.importorskip("econml")
class TestEconMLEstimator:
"""Smoke tests for the integration with EconML estimators
These tests only check that the ATE estimation routine can be executed without errors.
We don't check the accuracy of the ATE estimates as we don't want to take dependencies on
EconML estimators.
"""
"""
def test_backdoor_estimators(self):
# Setup data
@ -32,12 +33,13 @@ class TestEconMLEstimator:
outcome=data["outcome_name"],
effect_modifiers=data["effect_modifier_names"],
graph=data["gml_graph"]
)
identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)
# Test LinearDMLCateEstimator
)
identified_estimand = model.identify_effect(
proceed_when_unidentifiable=True)
# Test LinearDML
dml_estimate = model.estimate_effect(
identified_estimand,
method_name="backdoor.econml.dml.LinearDMLCateEstimator",
method_name="backdoor.econml.dml.LinearDML",
control_value=0,
treatment_value=1,
target_units=lambda df: df["X0"] > 1, # condition used for CATE
@ -46,8 +48,8 @@ class TestEconMLEstimator:
'model_t': GradientBoostingRegressor(),
'featurizer': PolynomialFeatures(degree=1, include_bias=True)},
"fit_params": {}
}
)
}
)
# Test ContinuousTreatmentOrthoForest
orthoforest_estimate = model.estimate_effect(
identified_estimand,
@ -56,8 +58,8 @@ class TestEconMLEstimator:
method_params={
"init_params": {'n_trees': 10},
"fit_params": {}
}
)
}
)
# Test LinearDRLearner
data_binary = datasets.linear_dataset(
10, num_common_causes=4, num_samples=10000,
@ -65,11 +67,12 @@ class TestEconMLEstimator:
treatment_is_binary=True, outcome_is_binary=True)
model_binary = CausalModel(
data=data_binary["df"],
treatment=data_binary["treatment_name"],
treatment=data_binary["treatment_name"],
outcome=data_binary["outcome_name"],
effect_modifiers=data["effect_modifier_names"],
graph=data_binary["gml_graph"])
identified_estimand_binary = model_binary.identify_effect(proceed_when_unidentifiable=True)
identified_estimand_binary = model_binary.identify_effect(
proceed_when_unidentifiable=True)
drlearner_estimate = model_binary.estimate_effect(
identified_estimand_binary,
method_name="backdoor.econml.drlearner.LinearDRLearner",
@ -78,7 +81,7 @@ class TestEconMLEstimator:
method_params={
"init_params": {'model_propensity': LogisticRegressionCV(cv=3, solver='lbfgs', multi_class='auto')},
"fit_params": {}
})
})
def test_iv_estimators(self):
keras = pytest.importorskip("keras")
@ -95,19 +98,23 @@ class TestEconMLEstimator:
outcome=data["outcome_name"],
effect_modifiers=data["effect_modifier_names"],
graph=data["gml_graph"]
)
identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)
)
identified_estimand = model.identify_effect(
proceed_when_unidentifiable=True)
# Test DeepIV
dims_zx = len(model._instruments)+len(model._effect_modifiers)
dims_tx = len(model._treatment)+len(model._effect_modifiers)
treatment_model = keras.Sequential([keras.layers.Dense(128, activation='relu', input_shape=(dims_zx,)), # sum of dims of Z and X
treatment_model = keras.Sequential([keras.layers.Dense(128, activation='relu', input_shape=(dims_zx,)), # sum of dims of Z and X
keras.layers.Dropout(0.17),
keras.layers.Dense(64, activation='relu'),
keras.layers.Dense(
64, activation='relu'),
keras.layers.Dropout(0.17),
keras.layers.Dense(32, activation='relu'),
keras.layers.Dense(
32, activation='relu'),
keras.layers.Dropout(0.17)])
response_model = keras.Sequential([
keras.layers.Dense(128, activation='relu', input_shape=(dims_tx,)), # sum of dims of T and X
keras.layers.Dense(128, activation='relu', input_shape=(
dims_tx,)), # sum of dims of T and X
keras.layers.Dropout(0.17),
keras.layers.Dense(64, activation='relu'),
keras.layers.Dropout(0.17),
@ -119,14 +126,44 @@ class TestEconMLEstimator:
target_units=lambda df: df["X0"] > -1,
confidence_intervals=False,
method_params={
"init_params":{'n_components': 10, # Number of gaussians in the mixture density networks
'm': lambda z, x: treatment_model(keras.layers.concatenate([z, x])), # Treatment model,
"h": lambda t, x: response_model(keras.layers.concatenate([t, x])), # Response model
'n_samples': 1, # Number of samples used to estimate the response
'first_stage_options': {'epochs': 25},
'second_stage_options': {'epochs': 25}
},
"init_params": {'n_components': 10, # Number of gaussians in the mixture density networks
# Treatment model,
'm': lambda z, x: treatment_model(keras.layers.concatenate([z, x])),
# Response model
"h": lambda t, x: response_model(keras.layers.concatenate([t, x])),
'n_samples': 1, # Number of samples used to estimate the response
'first_stage_options': {'epochs': 25},
'second_stage_options': {'epochs': 25}
},
"fit_params": {}
}
}
)
# TODO: Test IntentToTreatDRIV when EconML v0.7 comes out
# Test IntentToTreatDRIV
data = datasets.linear_dataset(
10, num_common_causes=4, num_samples=10000,
num_instruments=1, num_effect_modifiers=2,
num_treatments=1,
treatment_is_binary=True,
num_discrete_instruments=1)
df = data['df']
model = CausalModel(
data=data["df"],
treatment=data["treatment_name"],
outcome=data["outcome_name"],
effect_modifiers=data["effect_modifier_names"],
graph=data["gml_graph"]
)
identified_estimand = model.identify_effect(
proceed_when_unidentifiable=True)
driv_estimate = model.estimate_effect(
identified_estimand,
method_name="iv.econml.ortho_iv.LinearIntentToTreatDRIV",
target_units=lambda df: df["X0"] > 1,
confidence_intervals=False,
method_params={
"init_params": {'model_T_XZ': GradientBoostingClassifier(),
'model_Y_X': GradientBoostingRegressor(),
'flexible_model_effect': GradientBoostingRegressor(),
'featurizer': PolynomialFeatures(degree=1, include_bias=False)
},
"fit_params": {}})