Add new statistical method to merge p-values
This adds `merge_p_values_average`, an improved version of the "twice the average" rule following recent results from M. Gasparini, R. Wang, and A. Ramdas, *Combining exchangeable p-values*, arXiv 2404.03484, 2024. The new method is now used by default when merging p-values. Accordingly, the quantile-based method `quantile_based_fwer` was renamed to `merge_p_values_quantile` to be consistent with the new naming pattern. Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
This commit is contained in:
Родитель
c6c11933b6
Коммит
f0ca30925e
|
@ -8,7 +8,7 @@ from sklearn.preprocessing import scale
|
||||||
|
|
||||||
import dowhy.gcm.config as config
|
import dowhy.gcm.config as config
|
||||||
from dowhy.gcm.independence_test.kernel_operation import approximate_rbf_kernel_features
|
from dowhy.gcm.independence_test.kernel_operation import approximate_rbf_kernel_features
|
||||||
from dowhy.gcm.stats import quantile_based_fwer
|
from dowhy.gcm.stats import merge_p_values_average
|
||||||
from dowhy.gcm.util.general import auto_apply_encoders, auto_fit_encoders, set_random_seed, shape_into_2d
|
from dowhy.gcm.util.general import auto_apply_encoders, auto_fit_encoders, set_random_seed, shape_into_2d
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ def kernel_based(
|
||||||
bootstrap_num_runs: int = 10,
|
bootstrap_num_runs: int = 10,
|
||||||
max_num_samples_run: int = 2000,
|
max_num_samples_run: int = 2000,
|
||||||
bootstrap_n_jobs: Optional[int] = None,
|
bootstrap_n_jobs: Optional[int] = None,
|
||||||
p_value_adjust_func: Callable[[Union[np.ndarray, List[float]]], float] = quantile_based_fwer,
|
p_value_adjust_func: Callable[[Union[np.ndarray, List[float]]], float] = merge_p_values_average,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> float:
|
) -> float:
|
||||||
"""Prepares the data and uses kernel (conditional) independence test. The independence test estimates a p-value
|
"""Prepares the data and uses kernel (conditional) independence test. The independence test estimates a p-value
|
||||||
|
@ -124,7 +124,7 @@ def approx_kernel_based(
|
||||||
bootstrap_num_runs: int = 10,
|
bootstrap_num_runs: int = 10,
|
||||||
bootstrap_num_samples: int = 1000,
|
bootstrap_num_samples: int = 1000,
|
||||||
bootstrap_n_jobs: Optional[int] = None,
|
bootstrap_n_jobs: Optional[int] = None,
|
||||||
p_value_adjust_func: Callable[[Union[np.ndarray, List[float]]], float] = quantile_based_fwer,
|
p_value_adjust_func: Callable[[Union[np.ndarray, List[float]]], float] = merge_p_values_average,
|
||||||
) -> float:
|
) -> float:
|
||||||
"""Implementation of the Randomized Conditional Independence Test. The independence test estimates a p-value
|
"""Implementation of the Randomized Conditional Independence Test. The independence test estimates a p-value
|
||||||
for the null hypothesis that X and Y are independent (given Z). Depending whether Z is given, a conditional or
|
for the null hypothesis that X and Y are independent (given Z). Depending whether Z is given, a conditional or
|
||||||
|
|
|
@ -11,7 +11,7 @@ from sklearn.model_selection import KFold
|
||||||
from sklearn.preprocessing import scale
|
from sklearn.preprocessing import scale
|
||||||
|
|
||||||
import dowhy.gcm.config as config
|
import dowhy.gcm.config as config
|
||||||
from dowhy.gcm.stats import estimate_ftest_pvalue, quantile_based_fwer
|
from dowhy.gcm.stats import estimate_ftest_pvalue, merge_p_values_average
|
||||||
from dowhy.gcm.util.general import auto_apply_encoders, auto_fit_encoders, set_random_seed, shape_into_2d
|
from dowhy.gcm.util.general import auto_apply_encoders, auto_fit_encoders, set_random_seed, shape_into_2d
|
||||||
|
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ def regression_based(
|
||||||
Z: Optional[np.ndarray] = None,
|
Z: Optional[np.ndarray] = None,
|
||||||
max_num_components_all_inputs: int = 40,
|
max_num_components_all_inputs: int = 40,
|
||||||
k_folds: int = 3,
|
k_folds: int = 3,
|
||||||
p_value_adjust_func: Callable[[Union[np.ndarray, List[float]]], float] = quantile_based_fwer,
|
p_value_adjust_func: Callable[[Union[np.ndarray, List[float]]], float] = merge_p_values_average,
|
||||||
max_samples_per_fold: int = -1,
|
max_samples_per_fold: int = -1,
|
||||||
n_jobs: Optional[int] = None,
|
n_jobs: Optional[int] = None,
|
||||||
) -> float:
|
) -> float:
|
||||||
|
|
|
@ -43,7 +43,7 @@ from dowhy.gcm.ml.classification import (
|
||||||
create_polynom_logistic_regression_classifier,
|
create_polynom_logistic_regression_classifier,
|
||||||
)
|
)
|
||||||
from dowhy.gcm.ml.regression import create_ada_boost_regressor, create_extra_trees_regressor, create_polynom_regressor
|
from dowhy.gcm.ml.regression import create_ada_boost_regressor, create_extra_trees_regressor, create_polynom_regressor
|
||||||
from dowhy.gcm.stats import quantile_based_fwer
|
from dowhy.gcm.stats import merge_p_values_average
|
||||||
from dowhy.gcm.util.general import is_categorical, set_random_seed, shape_into_2d
|
from dowhy.gcm.util.general import is_categorical, set_random_seed, shape_into_2d
|
||||||
from dowhy.graph import get_ordered_predecessors, is_root_node
|
from dowhy.graph import get_ordered_predecessors, is_root_node
|
||||||
|
|
||||||
|
@ -599,7 +599,7 @@ def _evaluate_invertibility_assumptions(
|
||||||
parent_samples[random_indices],
|
parent_samples[random_indices],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
all_pnl_p_values[node] = quantile_based_fwer(tmp_p_values)
|
all_pnl_p_values[node] = merge_p_values_average(tmp_p_values)
|
||||||
|
|
||||||
if len(all_pnl_p_values) == 0:
|
if len(all_pnl_p_values) == 0:
|
||||||
return all_pnl_p_values
|
return all_pnl_p_values
|
||||||
|
|
|
@ -9,14 +9,51 @@ from dowhy.gcm.constant import EPS
|
||||||
from dowhy.gcm.util.general import shape_into_2d
|
from dowhy.gcm.util.general import shape_into_2d
|
||||||
|
|
||||||
|
|
||||||
def quantile_based_fwer(
|
def merge_p_values_average(p_values: Union[np.ndarray, List[float]], randomization: bool = False) -> float:
    """A statistically sound method to merge multiple potentially dependent p-values into one. This is a statistically
    improved (i.e., more powerful) version of the "twice the average" rule, following Theorem 5.3
    (second equation, F_UA) in

        M. Gasparini, R. Wang, and A. Ramdas, *Combining exchangeable p-values*, arXiv 2404.03484, 2024

    With the non-NaN p-values sorted in ascending order, K their count and u the randomization variable, the result is

        min(1, min over m in {1, ..., K} with K * u / m < 2 of 2 * mean(p_(1), ..., p_(m)) / (2 - K * u / m))

    Note, if randomization is False, we have u = 1 here. Generally, randomization requires fewer assumptions but leads
    to non-deterministic behavior.

    NaN entries are ignored. The given p-values are not modified.

    :param p_values: A list or array of p-values.
    :param randomization: If True, u is taken uniformly randomly from [0, 1] (non-deterministic). If False, u is set
                          to 1 (deterministic). Randomization is generally more powerful but provides non-deterministic
                          results.
    :return: A single p-value based on the given p-values. NaN if all given p-values are NaN.
    :raises ValueError: If the given list of p-values is empty.
    """
    if len(p_values) == 0:
        raise ValueError("Given list of p-values is empty!")

    # Boolean indexing returns a new array, so the caller's data is never modified by the sort below.
    p_values = np.asarray(p_values, dtype=float)
    p_values = np.sort(p_values[~np.isnan(p_values)])

    if p_values.size == 0:
        # Only NaNs were given, i.e. there is nothing to merge.
        return float(np.nan)

    # The random number is only drawn when there is something to merge.
    u = float(np.random.uniform(0, 1)) if randomization else 1.0

    K = p_values.size
    m = np.arange(1, K + 1)
    ratios = K * u / m

    # Only prefixes with K * u / m < 2 yield a positive denominator. Since u <= 1, m = K always qualifies, so there
    # is at least one candidate.
    valid = ratios < 2

    # Running means of the m smallest p-values, computed in O(K) via a cumulative sum instead of re-averaging every
    # prefix.
    prefix_means = np.cumsum(p_values) / m

    # Select the valid entries before dividing to avoid zero or negative denominators.
    candidates = 2 * prefix_means[valid] / (2 - ratios[valid])

    return min(1.0, float(np.min(candidates)))
|
||||||
|
|
||||||
|
|
||||||
|
def merge_p_values_quantile(
|
||||||
p_values: Union[np.ndarray, List[float]], p_values_scaling: Optional[np.ndarray] = None, quantile: float = 0.5
|
p_values: Union[np.ndarray, List[float]], p_values_scaling: Optional[np.ndarray] = None, quantile: float = 0.5
|
||||||
) -> float:
|
) -> float:
|
||||||
"""Applies a quantile based family wise error rate (FWER) control to the given p-values. This is based on the
|
"""Applies a quantile based approach to merge multiple potentially dependent p-values to one. This is based on the
|
||||||
approach described in:
|
approach described in:
|
||||||
|
|
||||||
Meinshausen, N., Meier, L. and Buehlmann, P. (2009).
|
Meinshausen, N., Meier, L. and Buehlmann, P., *p-values for high-dimensional regression*,
|
||||||
p-values for high-dimensional regression. J. Amer. Statist. Assoc.104 1671–1681
|
J. Amer. Statist. Assoc.104 1671–1681, 2009
|
||||||
|
|
||||||
:param p_values: A list or array of p-values.
|
:param p_values: A list or array of p-values.
|
||||||
:param p_values_scaling: An optional list of scaling factors for each p-value.
|
:param p_values_scaling: An optional list of scaling factors for each p-value.
|
||||||
|
|
|
@ -10,7 +10,7 @@ from dowhy.gcm.ml import (
|
||||||
create_linear_regressor,
|
create_linear_regressor,
|
||||||
create_logistic_regression_classifier,
|
create_logistic_regression_classifier,
|
||||||
)
|
)
|
||||||
from dowhy.gcm.stats import estimate_ftest_pvalue, marginal_expectation, quantile_based_fwer
|
from dowhy.gcm.stats import estimate_ftest_pvalue, marginal_expectation, merge_p_values_average, merge_p_values_quantile
|
||||||
from dowhy.gcm.util.general import geometric_median
|
from dowhy.gcm.util.general import geometric_median
|
||||||
|
|
||||||
|
|
||||||
|
@ -27,47 +27,65 @@ def test_when_estimate_geometric_median_then_returns_correct_median_vector():
|
||||||
assert gm[1] == approx(-5, abs=0.5)
|
assert gm[1] == approx(-5, abs=0.5)
|
||||||
|
|
||||||
|
|
||||||
def test_when_apply_quantile_based_fwer_control_then_returns_single_adjusted_pvalue():
|
def test_when_merge_p_values_quantile_then_returns_single_adjusted_pvalue():
|
||||||
p_values = np.array([0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1])
|
p_values = np.array([0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1])
|
||||||
assert quantile_based_fwer(p_values, quantile=0.5) == 0.055 / 0.5
|
assert merge_p_values_quantile(p_values, quantile=0.5) == 0.055 / 0.5
|
||||||
assert quantile_based_fwer(p_values, quantile=0.25) == 0.0325 / 0.25
|
assert merge_p_values_quantile(p_values, quantile=0.25) == 0.0325 / 0.25
|
||||||
assert quantile_based_fwer(p_values, quantile=0.75) == 0.0775 / 0.75
|
assert merge_p_values_quantile(p_values, quantile=0.75) == 0.0775 / 0.75
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
quantile_based_fwer(np.array([0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 1]), quantile=0.5)
|
merge_p_values_quantile(np.array([0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 1]), quantile=0.5)
|
||||||
== 0.06 / 0.5
|
== 0.06 / 0.5
|
||||||
)
|
)
|
||||||
assert quantile_based_fwer(np.array([0.9, 0.95, 1]), quantile=0.5) == 1
|
assert merge_p_values_quantile(np.array([0.9, 0.95, 1]), quantile=0.5) == 1
|
||||||
assert quantile_based_fwer(np.array([0, 0, 0]), quantile=0.5) == 0
|
assert merge_p_values_quantile(np.array([0, 0, 0]), quantile=0.5) == 0
|
||||||
assert quantile_based_fwer(np.array([0.33]), quantile=0.5) == 0.33
|
assert merge_p_values_quantile(np.array([0.33]), quantile=0.5) == 0.33
|
||||||
|
|
||||||
|
|
||||||
def test_given_p_values_with_nans_when_using_quantile_based_fwer_then_ignores_the_nan_values():
|
def test_given_p_values_with_nans_when_merge_p_values_quantile_then_ignores_the_nan_values():
|
||||||
p_values = np.array([0.01, np.nan, 0.02, 0.03, 0.04, 0.05, np.nan, 0.06, 0.07, 0.08, 0.09, 0.1])
|
p_values = np.array([0.01, np.nan, 0.02, 0.03, 0.04, 0.05, np.nan, 0.06, 0.07, 0.08, 0.09, 0.1])
|
||||||
assert quantile_based_fwer(p_values, quantile=0.5) == 0.055 / 0.5
|
assert merge_p_values_quantile(p_values, quantile=0.5) == 0.055 / 0.5
|
||||||
|
|
||||||
|
|
||||||
def test_given_p_values_with_scaling_when_apply_quantile_based_fwer_control_then_returns_single_adjusted_pvalue():
|
def test_given_p_values_with_scaling_when_merge_p_values_quantile_then_returns_single_adjusted_pvalue():
|
||||||
p_values = np.array([0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1])
|
p_values = np.array([0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1])
|
||||||
p_values_scaling = np.array([2, 2, 1, 2, 1, 3, 1, 2, 4, 1])
|
p_values_scaling = np.array([2, 2, 1, 2, 1, 3, 1, 2, 4, 1])
|
||||||
|
|
||||||
assert quantile_based_fwer(p_values, p_values_scaling, quantile=0.5) == approx(0.15)
|
assert merge_p_values_quantile(p_values, p_values_scaling, quantile=0.5) == approx(0.15)
|
||||||
assert quantile_based_fwer(p_values, p_values_scaling, quantile=0.25) == approx(0.17)
|
assert merge_p_values_quantile(p_values, p_values_scaling, quantile=0.25) == approx(0.17)
|
||||||
assert quantile_based_fwer(p_values, p_values_scaling, quantile=0.75) == approx(0.193, abs=0.001)
|
assert merge_p_values_quantile(p_values, p_values_scaling, quantile=0.75) == approx(0.193, abs=0.001)
|
||||||
|
|
||||||
|
|
||||||
def test_given_invalid_inputs_when_apply_quantile_based_fwer_control_then_raises_error():
|
def test_given_invalid_inputs_when_merge_p_values_quantile_then_raises_error():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
assert quantile_based_fwer(np.array([0.1, 0.5, 1]), quantile=0)
|
assert merge_p_values_quantile(np.array([0.1, 0.5, 1]), quantile=0)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
assert quantile_based_fwer(np.array([0.1, 0.5, 1]), np.array([1, 2]), quantile=0.1)
|
assert merge_p_values_quantile(np.array([0.1, 0.5, 1]), np.array([1, 2]), quantile=0.1)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
assert quantile_based_fwer(np.array([0.1, 0.5, 1]), quantile=1.1)
|
assert merge_p_values_quantile(np.array([0.1, 0.5, 1]), quantile=1.1)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
assert quantile_based_fwer(np.array([0.1, 0.5, 1]), quantile=-0.5)
|
assert merge_p_values_quantile(np.array([0.1, 0.5, 1]), quantile=-0.5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_when_merge_p_values_average_without_randomization_then_returns_expected_results():
    # Each entry pairs the given p-values with the merged p-value expected from the default (non-randomized) call.
    expectations = [
        ([0], 0),
        ([1], 1),
        ([0, 1], approx(1.0)),
        ([0, 0, 1], 0),
        ([0, 0.5, 0.5, np.nan, 1, np.nan], approx(1.0)),
        ([0, 0, 1, 1, 1], approx(1.0)),
    ]

    for given_p_values, expected in expectations:
        assert merge_p_values_average(given_p_values) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@flaky(max_runs=3)
def test_when_merge_p_values_average_with_randomization_then_returns_expected_results():
    # Inputs whose merged p-value is checked for exact equality.
    exact_expectations = [([0], 0), ([1], 1)]
    for given_p_values, expected in exact_expectations:
        assert merge_p_values_average(given_p_values, randomization=True) == expected

    # Inputs whose merged p-value is expected to be (close to) zero.
    near_zero_inputs = [
        [0, 1],
        [0, 0, 1],
        [0, np.nan, 0, np.nan, 1, 1],
    ]
    for given_p_values in near_zero_inputs:
        assert merge_p_values_average(given_p_values, randomization=True) == approx(0.0, abs=0.01)
|
||||||
|
|
||||||
|
|
||||||
def test_when_evaluate_marginal_expectation_without_averaging_result_then_returned_results_have_correct_format():
|
def test_when_evaluate_marginal_expectation_without_averaging_result_then_returned_results_have_correct_format():
|
||||||
|
|
Загрузка…
Ссылка в новой задаче