diff --git a/econml/_shap.py b/econml/_shap.py index 986ec152..cc417d71 100644 --- a/econml/_shap.py +++ b/econml/_shap.py @@ -82,7 +82,8 @@ def _shap_explain_cme(cme_model, X, d_t, d_y, data=shap_out.data, main_effects=shap_out.main_effects, feature_names=shap_out.feature_names) shap_outs[output_names[i]][treatment_names[0]] = shap_out_new - return shap_outs + # return plain dictionary so that erroneous accesses don't half work (see #708) + return dict(shap_outs) def _shap_explain_model_cate(cme_model, models, X, d_t, d_y, featurizer=None, feature_names=None, @@ -176,7 +177,8 @@ def _shap_explain_model_cate(cme_model, models, X, d_t, d_y, featurizer=None, fe else: shap_outs[output_names[0]][treatment_names[i]] = shap_out - return shap_outs + # return plain dictionary so that erroneous accesses don't half work (see #708) + return dict(shap_outs) def _shap_explain_joint_linear_model_cate(model_final, X, d_t, d_y, fit_cate_intercept, @@ -258,7 +260,8 @@ def _shap_explain_joint_linear_model_cate(model_final, X, d_t, d_y, fit_cate_int feature_names=shap_out.feature_names) shap_outs[output_names[0]][treatment_names[i]] = shap_out_new - return shap_outs + # return plain dictionary so that erroneous accesses don't half work (see #708) + return dict(shap_outs) def _shap_explain_multitask_model_cate(cme_model, multitask_model_cate, X, d_t, d_y, featurizer=None, @@ -352,7 +355,8 @@ def _shap_explain_multitask_model_cate(cme_model, multitask_model_cate, X, d_t, shap_outs[output_names[j]][treatment_names[i]] = shap_out_new else: shap_outs[output_names[j]][treatment_names[0]] = shap_out - return shap_outs + # return plain dictionary so that erroneous accesses don't half work (see #708) + return dict(shap_outs) def _define_names(d_t, d_y, treatment_names, output_names, feature_names, input_names, featurizer): diff --git a/econml/utilities.py b/econml/utilities.py index b31843f8..a7becc0a 100644 --- a/econml/utilities.py +++ b/econml/utilities.py @@ -1360,7 +1360,9 @@ def transpose_dictionary(d): for key1, value in d.items(): for key2, val in value.items(): output[key2][key1] = val - return output + + # return plain dictionary so that erroneous accesses don't half work (see e.g. #708) + return dict(output) def reshape_arrays_2dim(length, *args):