Changed bootstrap method to empirical and added treatment and control values to bootstrap generation (#278)
This commit is contained in:
Родитель
7238193052
Коммит
21fccf133a
|
@ -9,6 +9,7 @@ from sklearn.utils import resample
|
|||
import dowhy.interpreters as interpreters
|
||||
from dowhy.utils.api import parse_state
|
||||
|
||||
|
||||
class CausalEstimator:
|
||||
"""Base class for an estimator of causal effect.
|
||||
|
||||
|
@ -140,7 +141,8 @@ class CausalEstimator:
|
|||
new_estimator = estimator_class(
|
||||
new_data,
|
||||
identified_estimand,
|
||||
identified_estimand.treatment_variable, identified_estimand.outcome_variable, #names of treatment and outcome
|
||||
identified_estimand.treatment_variable, identified_estimand.outcome_variable,
|
||||
# names of treatment and outcome
|
||||
control_value=estimate.control_value,
|
||||
treatment_value=estimate.treatment_value,
|
||||
test_significance=False,
|
||||
|
@ -153,11 +155,11 @@ class CausalEstimator:
|
|||
|
||||
return new_estimator
|
||||
|
||||
|
||||
def _estimate_effect(self):
|
||||
'''This method is to be overriden by the child classes, so that they can run the estimation technique of their choice
|
||||
'''
|
||||
raise NotImplementedError(("Main estimation method is " + CausalEstimator.DEFAULT_NOTIMPLEMENTEDERROR_MSG).format(self.__class__))
|
||||
raise NotImplementedError(
|
||||
("Main estimation method is " + CausalEstimator.DEFAULT_NOTIMPLEMENTEDERROR_MSG).format(self.__class__))
|
||||
|
||||
def estimate_effect(self):
|
||||
"""Base estimation method that calls the estimate_effect method of its calling subclass.
|
||||
|
@ -173,8 +175,8 @@ class CausalEstimator:
|
|||
if self._significance_test:
|
||||
self.test_significance(est.value, method=self._significance_test)
|
||||
if self._confidence_intervals:
|
||||
self.estimate_confidence_intervals(method=self._confidence_intervals,
|
||||
confidence_level=self.confidence_level)
|
||||
self.estimate_confidence_intervals(est.value, confidence_level=self.confidence_level,
|
||||
method=self._confidence_intervals)
|
||||
if self._effect_strength_eval:
|
||||
effect_strength_dict = self.evaluate_effect_strength(est)
|
||||
est.add_effect_strength(effect_strength_dict)
|
||||
|
@ -193,7 +195,9 @@ class CausalEstimator:
|
|||
|
||||
The overridden function should take in a dataframe as input and return the estimate for that data.
|
||||
"""
|
||||
raise NotImplementedError(("Conditional treatment effects are " + CausalEstimator.DEFAULT_NOTIMPLEMENTEDERROR_MSG).format(self.__class__))
|
||||
raise NotImplementedError(
|
||||
("Conditional treatment effects are " + CausalEstimator.DEFAULT_NOTIMPLEMENTEDERROR_MSG).format(
|
||||
self.__class__))
|
||||
|
||||
def _estimate_conditional_effects(self, estimate_effect_fn,
|
||||
effect_modifier_names=None,
|
||||
|
@ -219,7 +223,8 @@ class CausalEstimator:
|
|||
# Making sure that effect_modifier_names is a list
|
||||
effect_modifier_names = parse_state(effect_modifier_names)
|
||||
if not all(em in self._effect_modifier_names for em in effect_modifier_names):
|
||||
self.logger.warn("At least one of the provided effect modifiers was not included while fitting the estimator. You may get incorrect results. To resolve, fit the estimator again by providing the updated effect modifiers in estimate_effect().")
|
||||
self.logger.warn(
|
||||
"At least one of the provided effect modifiers was not included while fitting the estimator. You may get incorrect results. To resolve, fit the estimator again by providing the updated effect modifiers in estimate_effect().")
|
||||
# Making a copy since we are going to be changing effect modifier names
|
||||
effect_modifier_names = effect_modifier_names.copy()
|
||||
prefix = CausalEstimator.TEMP_CAT_COLUMN_PREFIX
|
||||
|
@ -240,9 +245,9 @@ class CausalEstimator:
|
|||
self._data.pop(em)
|
||||
return conditional_estimates
|
||||
|
||||
|
||||
def _do(self, x, data_df=None):
|
||||
raise NotImplementedError(("Do-operator is " + CausalEstimator.DEFAULT_NOTIMPLEMENTEDERROR_MSG).format(self.__class__))
|
||||
raise NotImplementedError(
|
||||
("Do-operator is " + CausalEstimator.DEFAULT_NOTIMPLEMENTEDERROR_MSG).format(self.__class__))
|
||||
|
||||
def do(self, x, data_df=None):
|
||||
"""Method that implements the do-operator.
|
||||
|
@ -288,7 +293,10 @@ class CausalEstimator:
|
|||
new_estimator = type(self)(
|
||||
new_data,
|
||||
self._target_estimand,
|
||||
self._target_estimand.treatment_variable, self._target_estimand.outcome_variable, #names of treatment and outcome
|
||||
self._target_estimand.treatment_variable, self._target_estimand.outcome_variable,
|
||||
# names of treatment and outcome
|
||||
treatment_value=self._treatment_value,
|
||||
control_value=self._control_value,
|
||||
test_significance=False,
|
||||
evaluate_effect_strength=False,
|
||||
confidence_intervals=False,
|
||||
|
@ -304,13 +312,13 @@ class CausalEstimator:
|
|||
'sample_size_fraction': sample_size_fraction})
|
||||
return estimates
|
||||
|
||||
|
||||
def _estimate_confidence_intervals_with_bootstrap(self,
|
||||
def _estimate_confidence_intervals_with_bootstrap(self, estimate_value,
|
||||
confidence_level=None,
|
||||
num_simulations=None, sample_size_fraction=None):
|
||||
'''
|
||||
Method to compute confidence interval using bootstrapped sampling.
|
||||
|
||||
:param estimate_value: obtained estimate's value
|
||||
:param confidence_level: The level for which to compute CI (e.g., 95% confidence level translates to confidence_level=0.95)
|
||||
:param num_simulations: The number of simulations to be performed to get the bootstrap confidence intervals.
|
||||
:param sample_size_fraction: The fraction of the dataset to be resampled.
|
||||
|
@ -335,28 +343,32 @@ class CausalEstimator:
|
|||
self._bootstrap_estimates = self._generate_bootstrap_estimates(
|
||||
num_simulations, sample_size_fraction)
|
||||
# Now use the data obtained from the simulations to get the value of the confidence estimates
|
||||
# Sort the simulations
|
||||
bootstrap_estimates = np.sort(self._bootstrap_estimates.estimates)
|
||||
# Now we take the (1- p)th and the (p)th values, where p is the chosen confidence level
|
||||
lower_bound_index = int( ( 1 - confidence_level ) * len(bootstrap_estimates) )
|
||||
upper_bound_index = int( confidence_level * len(bootstrap_estimates) )
|
||||
bootstrap_estimates = self._bootstrap_estimates.estimates
|
||||
# Get the variations of each bootstrap estimate and sort
|
||||
bootstrap_variations = [bootstrap_estimate - estimate_value for bootstrap_estimate in bootstrap_estimates]
|
||||
sorted_bootstrap_variations = np.sort(bootstrap_variations)
|
||||
|
||||
# get the values
|
||||
lower_bound = bootstrap_estimates[lower_bound_index]
|
||||
upper_bound = bootstrap_estimates[upper_bound_index]
|
||||
# Now we take the (1- p)th and the (p)th variations, where p is the chosen confidence level
|
||||
upper_bound_index = int((1 - confidence_level) * len(sorted_bootstrap_variations))
|
||||
lower_bound_index = int(confidence_level * len(sorted_bootstrap_variations))
|
||||
|
||||
return (lower_bound, upper_bound)
|
||||
# Get the lower and upper bounds by subtracting the variations from the estimate
|
||||
lower_bound = estimate_value - sorted_bootstrap_variations[lower_bound_index]
|
||||
upper_bound = estimate_value - sorted_bootstrap_variations[upper_bound_index]
|
||||
return lower_bound, upper_bound
|
||||
|
||||
def _estimate_confidence_intervals(self, confidence_level, method=None,
|
||||
def _estimate_confidence_intervals(self, confidence_level=None, method=None,
|
||||
**kwargs):
|
||||
'''
|
||||
This method is to be overriden by the child classes, so that they
|
||||
can run a confidence interval estimation method suited to the specific
|
||||
causal estimator.
|
||||
'''
|
||||
raise NotImplementedError(("This method for estimating confidence intervals is " + CausalEstimator.DEFAULT_NOTIMPLEMENTEDERROR_MSG + " Meanwhile, you can try the bootstrap method (method='bootstrap') to estimate confidence intervals.").format(self.__class__))
|
||||
raise NotImplementedError((
|
||||
"This method for estimating confidence intervals is " + CausalEstimator.DEFAULT_NOTIMPLEMENTEDERROR_MSG + " Meanwhile, you can try the bootstrap method (method='bootstrap') to estimate confidence intervals.").format(
|
||||
self.__class__))
|
||||
|
||||
def estimate_confidence_intervals(self, confidence_level=None, method=None,
|
||||
def estimate_confidence_intervals(self, estimate_value, confidence_level=None, method=None,
|
||||
**kwargs):
|
||||
''' Find the confidence intervals corresponding to any estimator
|
||||
By default, this is done with the help of bootstrapped confidence intervals
|
||||
|
@ -364,6 +376,7 @@ class CausalEstimator:
|
|||
|
||||
If the method provided is not bootstrap, this function calls the implementation of the specific estimator.
|
||||
|
||||
:param estimate_value: obtained estimate's value
|
||||
:param method: Method for estimating confidence intervals.
|
||||
:param confidence_level: The confidence level of the confidence intervals of the estimate.
|
||||
:param kwargs: Other optional args to be passed to the CI method.
|
||||
|
@ -382,11 +395,11 @@ class CausalEstimator:
|
|||
confidence_intervals = self._estimate_confidence_intervals(
|
||||
confidence_level, method=method, **kwargs)
|
||||
except NotImplementedError:
|
||||
confidence_intervals = self._estimate_confidence_intervals_with_bootstrap(
|
||||
confidence_intervals = self._estimate_confidence_intervals_with_bootstrap(estimate_value,
|
||||
confidence_level, **kwargs)
|
||||
else:
|
||||
if method == "bootstrap":
|
||||
confidence_intervals = self._estimate_confidence_intervals_with_bootstrap(
|
||||
confidence_intervals = self._estimate_confidence_intervals_with_bootstrap(estimate_value,
|
||||
confidence_level, **kwargs)
|
||||
else:
|
||||
confidence_intervals = self._estimate_confidence_intervals(
|
||||
|
@ -426,7 +439,9 @@ class CausalEstimator:
|
|||
can run a standard error estimation method suited to the specific
|
||||
causal estimator.
|
||||
'''
|
||||
raise NotImplementedError(("This method for estimating standard errors is " + CausalEstimator.DEFAULT_NOTIMPLEMENTEDERROR_MSG + " Meanwhile, you can try the bootstrap method (method='bootstrap') to estimate standard errors.").format(self.__class__))
|
||||
raise NotImplementedError((
|
||||
"This method for estimating standard errors is " + CausalEstimator.DEFAULT_NOTIMPLEMENTEDERROR_MSG + " Meanwhile, you can try the bootstrap method (method='bootstrap') to estimate standard errors.").format(
|
||||
self.__class__))
|
||||
|
||||
def estimate_std_error(self, method=None, **kwargs):
|
||||
""" Compute standard error of an obtained causal estimate.
|
||||
|
@ -463,7 +478,8 @@ class CausalEstimator:
|
|||
# Use existing params, if new user defined params are not present
|
||||
if num_null_simulations is None:
|
||||
num_null_simulations = self.num_null_simulations
|
||||
do_retest = self._bootstrap_null_estimates is None or CausalEstimator.is_bootstrap_parameter_changed(self._bootstrap_null_estimates.params, locals())
|
||||
do_retest = self._bootstrap_null_estimates is None or CausalEstimator.is_bootstrap_parameter_changed(
|
||||
self._bootstrap_null_estimates.params, locals())
|
||||
if do_retest:
|
||||
null_estimates = np.zeros(num_null_simulations)
|
||||
for i in range(num_null_simulations):
|
||||
|
@ -517,7 +533,9 @@ class CausalEstimator:
|
|||
can run a significance test suited to the specific
|
||||
causal estimator.
|
||||
'''
|
||||
raise NotImplementedError(("This method for testing statistical significance is " + CausalEstimator.DEFAULT_NOTIMPLEMENTEDERROR_MSG + " Meanwhile, you can try the bootstrap method (method='bootstrap') to test statistical significance.").format(self.__class__))
|
||||
raise NotImplementedError((
|
||||
"This method for testing statistical significance is " + CausalEstimator.DEFAULT_NOTIMPLEMENTEDERROR_MSG + " Meanwhile, you can try the bootstrap method (method='bootstrap') to test statistical significance.").format(
|
||||
self.__class__))
|
||||
|
||||
def test_significance(self, estimate_value, method=None, **kwargs):
|
||||
"""Test statistical significance of obtained estimate.
|
||||
|
@ -621,6 +639,7 @@ class CausalEstimator:
|
|||
s += "{0}".format(pval)
|
||||
return s
|
||||
|
||||
|
||||
class CausalEstimate:
|
||||
"""Class for the estimate object that every causal estimator returns
|
||||
|
||||
|
@ -667,6 +686,7 @@ class CausalEstimate:
|
|||
:returns: The obtained confidence interval.
|
||||
"""
|
||||
confidence_intervals = self.estimator.estimate_confidence_intervals(
|
||||
estimate_value=self.value,
|
||||
confidence_level=confidence_level,
|
||||
method=method,
|
||||
**kwargs)
|
||||
|
@ -751,7 +771,8 @@ class CausalEstimate:
|
|||
if self.estimator._significance_test:
|
||||
s += "p-value: {0}\n".format(self.estimator.signif_results_tostr(self.test_stat_significance()))
|
||||
if self.estimator._confidence_intervals:
|
||||
s += "{0}% confidence interval: {1}\n".format(100 * self.estimator.confidence_level, self.get_confidence_intervals())
|
||||
s += "{0}% confidence interval: {1}\n".format(100 * self.estimator.confidence_level,
|
||||
self.get_confidence_intervals())
|
||||
if self.conditional_estimates is not None:
|
||||
s += "### Conditional Estimates\n"
|
||||
s += str(self.conditional_estimates)
|
||||
|
@ -789,4 +810,3 @@ class RealizedEstimand(object):
|
|||
s += "Estimand assumption {0}, {1}: {2}\n".format(j, ass_name, ass_str)
|
||||
j += 1
|
||||
return s
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче