Fix dimension mismatch error when dropped_features contains cat_features (#2099)

* fix error when dropped_features contains cat_features

* fix lint

* add tests

* add tests

* fix UT error
This commit is contained in:
tongy-msft 2023-06-08 19:20:48 -07:00 коммит произвёл GitHub
Родитель 53f2c72912
Коммит 069c0a44b4
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 42 добавлений и 10 удалений

Просмотреть файл

@ -260,6 +260,14 @@ class RAIInsights(RAIBaseInsights):
self._initialize_managers()
self._try_add_data_balance()
def get_categorical_features_after_drop(self):
dropped_features = self._feature_metadata.dropped_features
if dropped_features is None:
return self.categorical_features
else:
return list(set(self.categorical_features) -
set(dropped_features))
def get_train_data(self):
"""Returns the training dataset after dropping
features if any were configured to be dropped.
@ -372,7 +380,8 @@ class RAIInsights(RAIBaseInsights):
dropped_features = self._feature_metadata.dropped_features
self._causal_manager = CausalManager(
self.get_train_data(), self.get_test_data(), self.target_column,
self.task_type, self.categorical_features, self._feature_metadata)
self.task_type, self.get_categorical_features_after_drop(),
self._feature_metadata)
self._counterfactual_manager = CounterfactualManager(
model=self.model, train=self.get_train_data(),
@ -394,7 +403,7 @@ class RAIInsights(RAIBaseInsights):
self.model, self.get_train_data(), self.get_test_data(),
self.target_column,
self._classes,
categorical_features=self.categorical_features)
self.get_categorical_features_after_drop())
self._managers = [self._causal_manager,
self._counterfactual_manager,
@ -881,10 +890,11 @@ class RAIInsights(RAIBaseInsights):
true_y = self.test[self.target_column]
X_test = test_data.drop(columns=[self.target_column])
X_test_after_drop = self.get_test_data(X_test)
filter_data_with_cohort = FilterDataWithCohortFilters(
model=self.model,
dataset=X_test,
features=X_test.columns,
dataset=X_test_after_drop,
features=X_test_after_drop.columns,
categorical_features=self.categorical_features,
categories=self._categories,
true_y=true_y,

Просмотреть файл

@ -132,19 +132,26 @@ class TestRAIInsights(object):
categorical_features = categorical_features + ['is_adult']
X_train = data_train.drop([target_name], axis=1)
dropped_feature = 'education'
X_train_after_drop = X_train.drop([dropped_feature], axis=1)
categorical_features_after_drop = categorical_features.copy()
categorical_features_after_drop.remove(dropped_feature)
model = create_complex_classification_pipeline(
X_train, y_train, continuous_features, categorical_features)
X_train_after_drop, y_train, continuous_features,
categorical_features_after_drop)
manager_args = {
ManagerParams.TREATMENT_FEATURES: ['age', 'hours_per_week'],
ManagerParams.DESIRED_CLASS: 'opposite',
ManagerParams.FEATURE_IMPORTANCE: False
}
feature_metadata = FeatureMetadata(
dropped_features=[dropped_feature])
run_rai_insights(model, data_train, data_test, target_name,
categorical_features,
manager_type, manager_args,
classes=classes)
classes=classes, feature_metadata=feature_metadata)
@pytest.mark.parametrize('manager_type', [ManagerNames.CAUSAL,
ManagerNames.ERROR_ANALYSIS,
@ -236,8 +243,8 @@ def run_rai_insights(model, train_data, test_data, target_column,
elif manager_type == ManagerNames.ERROR_ANALYSIS:
setup_error_analysis(rai_insights)
validate_rai_insights(rai_insights, train_data, test_data,
target_column, task_type, categorical_features)
validate_rai_insights(rai_insights, train_data, test_data, target_column,
task_type, categorical_features, feature_metadata)
if manager_type == ManagerNames.CAUSAL:
treatment_features = manager_args.get(ManagerParams.TREATMENT_FEATURES)
@ -275,7 +282,8 @@ def run_rai_insights(model, train_data, test_data, target_column,
validate_rai_insights(
rai_insights, train_data, test_data,
target_column, task_type, categorical_features)
target_column, task_type, categorical_features,
feature_metadata)
if manager_type == ManagerNames.ERROR_ANALYSIS:
validate_error_analysis(rai_insights)
@ -359,7 +367,8 @@ def validate_rai_insights(
test_data,
target_column,
task_type,
categorical_features
categorical_features,
feature_metadata=None
):
pd.testing.assert_frame_equal(rai_insights.train, train_data)
@ -377,6 +386,19 @@ def validate_rai_insights(
for ind_data in rai_insights._string_ind_data:
assert len(ind_data) == expected_length
if feature_metadata is not None:
if feature_metadata.dropped_features is not None and \
categorical_features is not None:
assert len(rai_insights.get_categorical_features_after_drop()) == \
len(categorical_features) - \
len(feature_metadata.dropped_features)
else:
if categorical_features is not None:
assert len(rai_insights.get_categorical_features_after_drop()) == \
len(categorical_features)
else:
assert len(rai_insights.get_categorical_features_after_drop()) == 0
if rai_insights.model is None:
assert rai_insights._predict_output is None
assert rai_insights._predict_proba_output is None