diff --git a/responsibleai/responsibleai/rai_insights/rai_insights.py b/responsibleai/responsibleai/rai_insights/rai_insights.py
index 68c2557eb..89f14f220 100644
--- a/responsibleai/responsibleai/rai_insights/rai_insights.py
+++ b/responsibleai/responsibleai/rai_insights/rai_insights.py
@@ -260,6 +260,14 @@ class RAIInsights(RAIBaseInsights):
         self._initialize_managers()
         self._try_add_data_balance()
 
+    def get_categorical_features_after_drop(self):
+        dropped_features = self._feature_metadata.dropped_features
+        if dropped_features is None:
+            return self.categorical_features
+        else:
+            return list(set(self.categorical_features) -
+                        set(dropped_features))
+
     def get_train_data(self):
         """Returns the training dataset after dropping features if any were
         configured to be dropped.
@@ -372,7 +380,8 @@ class RAIInsights(RAIBaseInsights):
         dropped_features = self._feature_metadata.dropped_features
         self._causal_manager = CausalManager(
             self.get_train_data(), self.get_test_data(), self.target_column,
-            self.task_type, self.categorical_features, self._feature_metadata)
+            self.task_type, self.get_categorical_features_after_drop(),
+            self._feature_metadata)
         self._counterfactual_manager = CounterfactualManager(
             model=self.model,
             train=self.get_train_data(),
@@ -394,7 +403,7 @@ class RAIInsights(RAIBaseInsights):
             self.model, self.get_train_data(), self.get_test_data(),
             self.target_column,
             self._classes,
-            categorical_features=self.categorical_features)
+            self.get_categorical_features_after_drop())
 
         self._managers = [self._causal_manager,
                           self._counterfactual_manager,
@@ -881,10 +890,11 @@ class RAIInsights(RAIBaseInsights):
         true_y = self.test[self.target_column]
         X_test = test_data.drop(columns=[self.target_column])
+        X_test_after_drop = self.get_test_data(X_test)
 
         filter_data_with_cohort = FilterDataWithCohortFilters(
             model=self.model,
-            dataset=X_test,
-            features=X_test.columns,
+            dataset=X_test_after_drop,
+            features=X_test_after_drop.columns,
             categorical_features=self.categorical_features,
             categories=self._categories,
             true_y=true_y,
diff --git a/responsibleai/tests/rai_insights/test_rai_insights.py b/responsibleai/tests/rai_insights/test_rai_insights.py
index 8bacf34e0..5b355dd45 100644
--- a/responsibleai/tests/rai_insights/test_rai_insights.py
+++ b/responsibleai/tests/rai_insights/test_rai_insights.py
@@ -132,19 +132,26 @@ class TestRAIInsights(object):
         categorical_features = categorical_features + ['is_adult']
 
         X_train = data_train.drop([target_name], axis=1)
+        dropped_feature = 'education'
+        X_train_after_drop = X_train.drop([dropped_feature], axis=1)
+        categorical_features_after_drop = categorical_features.copy()
+        categorical_features_after_drop.remove(dropped_feature)
 
         model = create_complex_classification_pipeline(
-            X_train, y_train, continuous_features, categorical_features)
+            X_train_after_drop, y_train, continuous_features,
+            categorical_features_after_drop)
 
         manager_args = {
             ManagerParams.TREATMENT_FEATURES: ['age', 'hours_per_week'],
             ManagerParams.DESIRED_CLASS: 'opposite',
             ManagerParams.FEATURE_IMPORTANCE: False
         }
+        feature_metadata = FeatureMetadata(
+            dropped_features=[dropped_feature])
 
         run_rai_insights(model, data_train, data_test, target_name,
                          categorical_features, manager_type, manager_args,
-                         classes=classes)
+                         classes=classes, feature_metadata=feature_metadata)
 
     @pytest.mark.parametrize('manager_type', [ManagerNames.CAUSAL,
                                               ManagerNames.ERROR_ANALYSIS,
@@ -236,8 +243,8 @@ def run_rai_insights(model, train_data, test_data, target_column,
    elif manager_type == ManagerNames.ERROR_ANALYSIS:
        setup_error_analysis(rai_insights)
 
-    validate_rai_insights(rai_insights, train_data, test_data,
-                          target_column, task_type, categorical_features)
+    validate_rai_insights(rai_insights, train_data, test_data, target_column,
+                          task_type, categorical_features, feature_metadata)
 
     if manager_type == ManagerNames.CAUSAL:
         treatment_features = manager_args.get(ManagerParams.TREATMENT_FEATURES)
@@ -275,7 +282,8 @@ def run_rai_insights(model, train_data, test_data, target_column,
 
     validate_rai_insights(
         rai_insights, train_data, test_data,
-        target_column, task_type, categorical_features)
+        target_column, task_type, categorical_features,
+        feature_metadata)
 
     if manager_type == ManagerNames.ERROR_ANALYSIS:
         validate_error_analysis(rai_insights)
@@ -359,7 +367,8 @@ def validate_rai_insights(
         test_data,
         target_column,
         task_type,
-        categorical_features
+        categorical_features,
+        feature_metadata=None
 ):
 
     pd.testing.assert_frame_equal(rai_insights.train, train_data)
@@ -377,6 +386,19 @@ def validate_rai_insights(
     for ind_data in rai_insights._string_ind_data:
         assert len(ind_data) == expected_length
 
+    if feature_metadata is not None:
+        if feature_metadata.dropped_features is not None and \
+                categorical_features is not None:
+            assert len(rai_insights.get_categorical_features_after_drop()) == \
+                len(categorical_features) - \
+                len(feature_metadata.dropped_features)
+    else:
+        if categorical_features is not None:
+            assert len(rai_insights.get_categorical_features_after_drop()) == \
+                len(categorical_features)
+        else:
+            assert len(rai_insights.get_categorical_features_after_drop()) == 0
+
     if rai_insights.model is None:
         assert rai_insights._predict_output is None
         assert rai_insights._predict_proba_output is None
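
For reviewers, a minimal standalone sketch of the filtering that the new get_categorical_features_after_drop helper performs before categorical features are handed to the causal and error-analysis managers. The helper itself lives on RAIInsights as shown in the diff; the function name and the column names other than 'education' and 'is_adult' below are illustrative only, not part of the patch.

# Illustrative sketch (not part of the patch): mirrors the set arithmetic in
# RAIInsights.get_categorical_features_after_drop.
def categorical_features_after_drop(categorical_features, dropped_features):
    # No dropped features configured: pass the categorical list through as-is.
    if dropped_features is None:
        return categorical_features
    # Otherwise exclude any categorical feature that was configured as dropped.
    return list(set(categorical_features) - set(dropped_features))


# Mirrors the test scenario above: 'education' is configured as dropped, so it
# is excluded from the categorical features passed to the managers.
# 'workclass' and 'marital_status' are hypothetical extra columns.
categorical = ['workclass', 'education', 'marital_status', 'is_adult']
print(sorted(categorical_features_after_drop(categorical, ['education'])))
# -> ['is_adult', 'marital_status', 'workclass']

Because the helper uses a set difference, the returned list does not preserve the original ordering of categorical_features; the example sorts the result only to make the printed output deterministic.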