Add new statistical method to merge p-values

This uses an improved (i.e., more powerful) version of the "twice the average" rule, following recent results from M. Gasparini, R. Wang, and A. Ramdas, *Combining exchangeable p-values*, arXiv 2404.03484, 2024.

The new method is now used by default when merging p-values. Accordingly, the quantile-based method was renamed from quantile_based_fwer to merge_p_values_quantile to match the new naming pattern.
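For illustration, a minimal usage sketch of the two merging functions touched by this commit (function names and parameters are taken from the diff below; the input p-values are made up):

    import numpy as np
    from dowhy.gcm.stats import merge_p_values_average, merge_p_values_quantile

    p_values = np.array([0.02, 0.04, 0.31, 0.05, 0.11])  # hypothetical p-values

    # New default: improved "twice the average" rule (deterministic, u = 1).
    merged_default = merge_p_values_average(p_values)

    # Optional randomized variant: u ~ Uniform[0, 1], more powerful but non-deterministic.
    merged_randomized = merge_p_values_average(p_values, randomization=True)

    # Previous default, now renamed from quantile_based_fwer.
    merged_quantile = merge_p_values_quantile(p_values, quantile=0.5)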

Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
Patrick Bloebaum 2024-04-12 10:25:32 -07:00 committed by Patrick Blöbaum
Parent c6c11933b6
Commit f0ca30925e
5 changed files with 86 additions and 31 deletions

View file

@@ -8,7 +8,7 @@ from sklearn.preprocessing import scale
 import dowhy.gcm.config as config
 from dowhy.gcm.independence_test.kernel_operation import approximate_rbf_kernel_features
-from dowhy.gcm.stats import quantile_based_fwer
+from dowhy.gcm.stats import merge_p_values_average
 from dowhy.gcm.util.general import auto_apply_encoders, auto_fit_encoders, set_random_seed, shape_into_2d
@@ -20,7 +20,7 @@ def kernel_based(
     bootstrap_num_runs: int = 10,
     max_num_samples_run: int = 2000,
     bootstrap_n_jobs: Optional[int] = None,
-    p_value_adjust_func: Callable[[Union[np.ndarray, List[float]]], float] = quantile_based_fwer,
+    p_value_adjust_func: Callable[[Union[np.ndarray, List[float]]], float] = merge_p_values_average,
     **kwargs,
 ) -> float:
     """Prepares the data and uses kernel (conditional) independence test. The independence test estimates a p-value
@@ -124,7 +124,7 @@ def approx_kernel_based(
     bootstrap_num_runs: int = 10,
     bootstrap_num_samples: int = 1000,
     bootstrap_n_jobs: Optional[int] = None,
-    p_value_adjust_func: Callable[[Union[np.ndarray, List[float]]], float] = quantile_based_fwer,
+    p_value_adjust_func: Callable[[Union[np.ndarray, List[float]]], float] = merge_p_values_average,
 ) -> float:
     """Implementation of the Randomized Conditional Independence Test. The independence test estimates a p-value
     for the null hypothesis that X and Y are independent (given Z). Depending whether Z is given, a conditional or
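As a hedged usage sketch (not taken from this commit), callers who prefer the previous merging behavior can still pass the renamed quantile-based function explicitly; this assumes kernel_based is exposed via dowhy.gcm.independence_test:

    import numpy as np
    from dowhy.gcm.independence_test import kernel_based  # assumed import path
    from dowhy.gcm.stats import merge_p_values_quantile

    X = np.random.normal(size=(200, 1))
    Y = X + np.random.normal(size=(200, 1))

    # Override the new default (merge_p_values_average) with the quantile-based rule.
    p_value = kernel_based(X, Y, p_value_adjust_func=merge_p_values_quantile)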

View file

@@ -11,7 +11,7 @@ from sklearn.model_selection import KFold
 from sklearn.preprocessing import scale
 import dowhy.gcm.config as config
-from dowhy.gcm.stats import estimate_ftest_pvalue, quantile_based_fwer
+from dowhy.gcm.stats import estimate_ftest_pvalue, merge_p_values_average
 from dowhy.gcm.util.general import auto_apply_encoders, auto_fit_encoders, set_random_seed, shape_into_2d
@@ -21,7 +21,7 @@ def regression_based(
     Z: Optional[np.ndarray] = None,
     max_num_components_all_inputs: int = 40,
     k_folds: int = 3,
-    p_value_adjust_func: Callable[[Union[np.ndarray, List[float]]], float] = quantile_based_fwer,
+    p_value_adjust_func: Callable[[Union[np.ndarray, List[float]]], float] = merge_p_values_average,
     max_samples_per_fold: int = -1,
     n_jobs: Optional[int] = None,
 ) -> float:

View file

@@ -43,7 +43,7 @@ from dowhy.gcm.ml.classification import (
     create_polynom_logistic_regression_classifier,
 )
 from dowhy.gcm.ml.regression import create_ada_boost_regressor, create_extra_trees_regressor, create_polynom_regressor
-from dowhy.gcm.stats import quantile_based_fwer
+from dowhy.gcm.stats import merge_p_values_average
 from dowhy.gcm.util.general import is_categorical, set_random_seed, shape_into_2d
 from dowhy.graph import get_ordered_predecessors, is_root_node
@@ -599,7 +599,7 @@ def _evaluate_invertibility_assumptions(
                     parent_samples[random_indices],
                 )
             )
-        all_pnl_p_values[node] = quantile_based_fwer(tmp_p_values)
+        all_pnl_p_values[node] = merge_p_values_average(tmp_p_values)
     if len(all_pnl_p_values) == 0:
         return all_pnl_p_values

View file

@@ -9,14 +9,51 @@ from dowhy.gcm.constant import EPS
 from dowhy.gcm.util.general import shape_into_2d

-def quantile_based_fwer(
+def merge_p_values_average(p_values: Union[np.ndarray, List[float]], randomization: bool = False) -> float:
+    """A statistically sound method to merge multiple potentially dependent p-values into one. This is a statistically
+    improved (i.e., more powerful) version of the "twice the average" rule, following Theorem 5.3
+    (second equation, F_UA) in
+
+    M. Gasparini, R. Wang, and A. Ramdas, *Combining exchangeable p-values*, arXiv 2404.03484, 2024
+
+    Note, if randomization is False, we have u = 1 here. Generally, randomization requires fewer assumptions but leads
+    to non-deterministic behavior.
+
+    :param p_values: A list or array of p-values.
+    :param randomization: If True, u is taken uniformly randomly from [0, 1] (non-deterministic). If False, u is set
+                          to 1 (deterministic). Randomization is generally more powerful but provides non-deterministic results.
+    :return: A single p-value based on the given p-values.
+    """
+    if len(p_values) == 0:
+        raise ValueError("Given list of p-values is empty!")
+
+    if np.all(np.isnan(p_values)):
+        return float(np.nan)
+
+    if randomization:
+        u = float(np.random.uniform(0, 1))
+    else:
+        u = 1
+
+    p_values = np.array(p_values)
+    p_values = p_values[~np.isnan(p_values)]
+    p_values.sort()
+    K = len(p_values)
+
+    return min(
+        1.0, float(np.min([2 * np.mean(p_values[:m]) / (2 - (K * u / m)) for m in range(1, K + 1) if (K * u / m) < 2]))
+    )
+
+
+def merge_p_values_quantile(
     p_values: Union[np.ndarray, List[float]], p_values_scaling: Optional[np.ndarray] = None, quantile: float = 0.5
 ) -> float:
-    """Applies a quantile based family wise error rate (FWER) control to the given p-values. This is based on the
+    """Applies a quantile based approach to merge multiple potentially dependent p-values to one. This is based on the
     approach described in:
-    Meinshausen, N., Meier, L. and Buehlmann, P. (2009).
-    p-values for high-dimensional regression. J. Amer. Statist. Assoc.104 1671-1681
+    Meinshausen, N., Meier, L. and Buehlmann, P., *p-values for high-dimensional regression*,
+    J. Amer. Statist. Assoc.104 1671-1681, 2009
     :param p_values: A list or array of p-values.
     :param p_values_scaling: An optional list of scaling factors for each p-value.
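Read directly off the new code above (a transcription of the implementation, not an authoritative restatement of the paper's theorem): with the non-NaN p-values sorted as p_(1) <= ... <= p_(K) and u = 1 (or u drawn from Uniform[0, 1] when randomization=True), the merged p-value is

$$ \min\Bigl(1,\ \min_{\substack{1 \le m \le K \\ Ku/m < 2}} \frac{2\,\bar{p}_m}{2 - Ku/m}\Bigr), \qquad \bar{p}_m = \frac{1}{m}\sum_{i=1}^{m} p_{(i)}. $$

For u = 1 and m = K the inner term reduces to twice the average of all p-values, i.e., the classical "twice the average" rule; taking the minimum over m can only tighten that value.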

View file

@@ -10,7 +10,7 @@ from dowhy.gcm.ml import (
     create_linear_regressor,
     create_logistic_regression_classifier,
 )
-from dowhy.gcm.stats import estimate_ftest_pvalue, marginal_expectation, quantile_based_fwer
+from dowhy.gcm.stats import estimate_ftest_pvalue, marginal_expectation, merge_p_values_average, merge_p_values_quantile
 from dowhy.gcm.util.general import geometric_median
@@ -27,47 +27,65 @@ def test_when_estimate_geometric_median_then_returns_correct_median_vector():
     assert gm[1] == approx(-5, abs=0.5)

-def test_when_apply_quantile_based_fwer_control_then_returns_single_adjusted_pvalue():
+def test_when_merge_p_values_quantile_then_returns_single_adjusted_pvalue():
     p_values = np.array([0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1])
-    assert quantile_based_fwer(p_values, quantile=0.5) == 0.055 / 0.5
-    assert quantile_based_fwer(p_values, quantile=0.25) == 0.0325 / 0.25
-    assert quantile_based_fwer(p_values, quantile=0.75) == 0.0775 / 0.75
+    assert merge_p_values_quantile(p_values, quantile=0.5) == 0.055 / 0.5
+    assert merge_p_values_quantile(p_values, quantile=0.25) == 0.0325 / 0.25
+    assert merge_p_values_quantile(p_values, quantile=0.75) == 0.0775 / 0.75
     assert (
-        quantile_based_fwer(np.array([0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 1]), quantile=0.5)
+        merge_p_values_quantile(np.array([0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 1]), quantile=0.5)
         == 0.06 / 0.5
     )
-    assert quantile_based_fwer(np.array([0.9, 0.95, 1]), quantile=0.5) == 1
-    assert quantile_based_fwer(np.array([0, 0, 0]), quantile=0.5) == 0
-    assert quantile_based_fwer(np.array([0.33]), quantile=0.5) == 0.33
+    assert merge_p_values_quantile(np.array([0.9, 0.95, 1]), quantile=0.5) == 1
+    assert merge_p_values_quantile(np.array([0, 0, 0]), quantile=0.5) == 0
+    assert merge_p_values_quantile(np.array([0.33]), quantile=0.5) == 0.33

-def test_given_p_values_with_nans_when_using_quantile_based_fwer_then_ignores_the_nan_values():
+def test_given_p_values_with_nans_when_merge_p_values_quantile_then_ignores_the_nan_values():
     p_values = np.array([0.01, np.nan, 0.02, 0.03, 0.04, 0.05, np.nan, 0.06, 0.07, 0.08, 0.09, 0.1])
-    assert quantile_based_fwer(p_values, quantile=0.5) == 0.055 / 0.5
+    assert merge_p_values_quantile(p_values, quantile=0.5) == 0.055 / 0.5

-def test_given_p_values_with_scaling_when_apply_quantile_based_fwer_control_then_returns_single_adjusted_pvalue():
+def test_given_p_values_with_scaling_when_merge_p_values_quantile_then_returns_single_adjusted_pvalue():
     p_values = np.array([0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1])
     p_values_scaling = np.array([2, 2, 1, 2, 1, 3, 1, 2, 4, 1])
-    assert quantile_based_fwer(p_values, p_values_scaling, quantile=0.5) == approx(0.15)
-    assert quantile_based_fwer(p_values, p_values_scaling, quantile=0.25) == approx(0.17)
-    assert quantile_based_fwer(p_values, p_values_scaling, quantile=0.75) == approx(0.193, abs=0.001)
+    assert merge_p_values_quantile(p_values, p_values_scaling, quantile=0.5) == approx(0.15)
+    assert merge_p_values_quantile(p_values, p_values_scaling, quantile=0.25) == approx(0.17)
+    assert merge_p_values_quantile(p_values, p_values_scaling, quantile=0.75) == approx(0.193, abs=0.001)

-def test_given_invalid_inputs_when_apply_quantile_based_fwer_control_then_raises_error():
+def test_given_invalid_inputs_when_merge_p_values_quantile_then_raises_error():
     with pytest.raises(ValueError):
-        assert quantile_based_fwer(np.array([0.1, 0.5, 1]), quantile=0)
+        assert merge_p_values_quantile(np.array([0.1, 0.5, 1]), quantile=0)
     with pytest.raises(ValueError):
-        assert quantile_based_fwer(np.array([0.1, 0.5, 1]), np.array([1, 2]), quantile=0.1)
+        assert merge_p_values_quantile(np.array([0.1, 0.5, 1]), np.array([1, 2]), quantile=0.1)
     with pytest.raises(ValueError):
-        assert quantile_based_fwer(np.array([0.1, 0.5, 1]), quantile=1.1)
+        assert merge_p_values_quantile(np.array([0.1, 0.5, 1]), quantile=1.1)
     with pytest.raises(ValueError):
-        assert quantile_based_fwer(np.array([0.1, 0.5, 1]), quantile=-0.5)
+        assert merge_p_values_quantile(np.array([0.1, 0.5, 1]), quantile=-0.5)

+def test_when_merge_p_values_average_without_randomization_then_returns_expected_results():
+    assert merge_p_values_average([0]) == 0
+    assert merge_p_values_average([1]) == 1
+    assert merge_p_values_average([0, 1]) == approx(1.0)
+    assert merge_p_values_average([0, 0, 1]) == 0
+    assert merge_p_values_average([0, 0.5, 0.5, np.nan, 1, np.nan]) == approx(1.0)
+    assert merge_p_values_average([0, 0, 1, 1, 1]) == approx(1.0)
+
+
+@flaky(max_runs=3)
+def test_when_merge_p_values_average_with_randomization_then_returns_expected_results():
+    assert merge_p_values_average([0], randomization=True) == 0
+    assert merge_p_values_average([1], randomization=True) == 1
+    assert merge_p_values_average([0, 1], randomization=True) == approx(0.0, abs=0.01)
+    assert merge_p_values_average([0, 0, 1], randomization=True) == approx(0.0, abs=0.01)
+    assert merge_p_values_average([0, np.nan, 0, np.nan, 1, 1], randomization=True) == approx(0.0, abs=0.01)

 def test_when_evaluate_marginal_expectation_without_averaging_result_then_returned_results_have_correct_format():