From 6f721e5c1b64ffedebfc4718b14b205e6fdaa88a Mon Sep 17 00:00:00 2001 From: Andres Morales Date: Mon, 31 Oct 2022 10:13:16 -0600 Subject: [PATCH] Check random_state type on placebo_treatment_refuter Fix #719 Signed-off-by: Andres Morales --- dowhy/causal_refuters/placebo_treatment_refuter.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dowhy/causal_refuters/placebo_treatment_refuter.py b/dowhy/causal_refuters/placebo_treatment_refuter.py index 15767ff11..dd06a4a1a 100755 --- a/dowhy/causal_refuters/placebo_treatment_refuter.py +++ b/dowhy/causal_refuters/placebo_treatment_refuter.py @@ -88,7 +88,7 @@ def _refute_once( treatment_names: List[str], type_dict: Dict, placebo_type: PlaceboType = PlaceboType.DEFAULT, - random_state: Optional[Union[int, np.random.RandomState]] = None, + random_state: Optional[np.random.RandomState] = None, ): if placebo_type == PlaceboType.PERMUTE: permuted_idx = None @@ -179,6 +179,10 @@ def refute_placebo_treatment( :param n_jobs: The maximum number of concurrently running jobs. If -1 all CPUs are used. If 1 is given, no parallel computing code is used at all (this is the default). :param verbose: The verbosity level: if non zero, progress messages are printed. Above 50, the output is sent to stdout. The frequency of the messages increases with the verbosity level. If it more than 10, all iterations are reported. The default is 0. """ + + if isinstance(random_state, int): + random_state = np.random.RandomState(random_state) + # only permute is supported for iv methods if target_estimand.identifier_method.startswith("iv"): if placebo_type != PlaceboType.PERMUTE: