Add dataset metadata to `IDataset` (#2071)
* Add dataset metadata to `IDataset` Signed-off-by: Gaurav Gupta <gaugup@microsoft.com> * Fix tests Signed-off-by: Gaurav Gupta <gaugup@microsoft.com> --------- Signed-off-by: Gaurav Gupta <gaugup@microsoft.com>
This commit is contained in:
Родитель
e398f5a396
Коммит
d6d88694ab
|
@ -46,6 +46,7 @@
|
|||
"ci_upper",
|
||||
"class_names",
|
||||
"cohort_filter_list",
|
||||
"column_name",
|
||||
"comparison_value",
|
||||
"control_treatment",
|
||||
"data_balance_measures",
|
||||
|
@ -74,6 +75,7 @@
|
|||
"feature_metadata",
|
||||
"feature_names",
|
||||
"feature_names_including_target",
|
||||
"feature_ranges",
|
||||
"feature_value",
|
||||
"global_effects",
|
||||
"identity_feature_name",
|
||||
|
@ -85,15 +87,18 @@
|
|||
"lower_bounds",
|
||||
"matrix_features",
|
||||
"max_error",
|
||||
"max_value",
|
||||
"mean_absolute_error",
|
||||
"mean_absolute_error_min",
|
||||
"mean_squared_error",
|
||||
"mean_squared_error_min",
|
||||
"mean_squared_log_error",
|
||||
"median_absolute_error",
|
||||
"min_value",
|
||||
"miss_rate",
|
||||
"model_type",
|
||||
"n_samples",
|
||||
"num_rows",
|
||||
"p_value",
|
||||
"plot_bgcolor",
|
||||
"point_estimate",
|
||||
|
@ -104,6 +109,7 @@
|
|||
"probability_y",
|
||||
"r2_score",
|
||||
"r2_score_min",
|
||||
"range_type",
|
||||
"recall_score",
|
||||
"recall_score_min",
|
||||
"recommended_policy_gains",
|
||||
|
@ -120,6 +126,7 @@
|
|||
"shap_tree",
|
||||
"specificity_score",
|
||||
"summary_importance",
|
||||
"tabular_dataset_metadata",
|
||||
"target_column",
|
||||
"task_type",
|
||||
"test_data",
|
||||
|
@ -136,6 +143,7 @@
|
|||
"use_entire_test_data",
|
||||
"zero_one_loss",
|
||||
"index_series",
|
||||
"unique_values",
|
||||
"x_series",
|
||||
"y_series",
|
||||
"x_map",
|
||||
|
|
|
@ -16916,6 +16916,86 @@ exports[`FairnessWizardV1 should render successfully 1`] = `
|
|||
0.3195856063209964,
|
||||
],
|
||||
],
|
||||
"tabular_dataset_metadata": Object {
|
||||
"feature_ranges": Array [
|
||||
Object {
|
||||
"column_name": "age",
|
||||
"max_value": 74,
|
||||
"min_value": 17,
|
||||
"range_type": "integer",
|
||||
},
|
||||
Object {
|
||||
"column_name": "workclass",
|
||||
"range_type": "categorical",
|
||||
"unique_values": Array [
|
||||
"Private",
|
||||
"Government",
|
||||
"Other/Unknown",
|
||||
"Self-Employed",
|
||||
],
|
||||
},
|
||||
Object {
|
||||
"column_name": "education",
|
||||
"range_type": "categorical",
|
||||
"unique_values": Array [
|
||||
"HS-grad",
|
||||
"Masters",
|
||||
"Some-college",
|
||||
"Assoc",
|
||||
"Bachelors",
|
||||
"School",
|
||||
"Doctorate",
|
||||
],
|
||||
},
|
||||
Object {
|
||||
"column_name": "marital_status",
|
||||
"range_type": "categorical",
|
||||
"unique_values": Array [
|
||||
"Married",
|
||||
"Divorced",
|
||||
"Single",
|
||||
"Widowed",
|
||||
],
|
||||
},
|
||||
Object {
|
||||
"column_name": "occupation",
|
||||
"range_type": "categorical",
|
||||
"unique_values": Array [
|
||||
"Blue-Collar",
|
||||
"Sales",
|
||||
"White-Collar",
|
||||
"Service",
|
||||
"Other/Unknown",
|
||||
"Professional",
|
||||
],
|
||||
},
|
||||
Object {
|
||||
"column_name": "race",
|
||||
"range_type": "categorical",
|
||||
"unique_values": Array [
|
||||
"White",
|
||||
"Other",
|
||||
],
|
||||
},
|
||||
Object {
|
||||
"column_name": "gender",
|
||||
"range_type": "categorical",
|
||||
"unique_values": Array [
|
||||
"Male",
|
||||
"Female",
|
||||
],
|
||||
},
|
||||
Object {
|
||||
"column_name": "hours_per_week",
|
||||
"max_value": 70,
|
||||
"min_value": 10,
|
||||
"range_type": "integer",
|
||||
},
|
||||
],
|
||||
"is_large_data_scenario": false,
|
||||
"num_rows": 48,
|
||||
"use_entire_test_data": false,
|
||||
},
|
||||
"target_column": "income",
|
||||
"task_type": "classification",
|
||||
"true_y": Array [
|
||||
|
|
|
@ -2978,6 +2978,75 @@ export const adultCensusWithFairnessDataset: IDataset = {
|
|||
[0.7355988315726243, 0.2644011684273757],
|
||||
[0.6804143936790036, 0.3195856063209964]
|
||||
],
|
||||
tabular_dataset_metadata: {
|
||||
feature_ranges: [
|
||||
{
|
||||
column_name: "age",
|
||||
max_value: 74,
|
||||
min_value: 17,
|
||||
range_type: "integer"
|
||||
},
|
||||
{
|
||||
column_name: "workclass",
|
||||
range_type: "categorical",
|
||||
unique_values: [
|
||||
"Private",
|
||||
"Government",
|
||||
"Other/Unknown",
|
||||
"Self-Employed"
|
||||
]
|
||||
},
|
||||
{
|
||||
column_name: "education",
|
||||
range_type: "categorical",
|
||||
unique_values: [
|
||||
"HS-grad",
|
||||
"Masters",
|
||||
"Some-college",
|
||||
"Assoc",
|
||||
"Bachelors",
|
||||
"School",
|
||||
"Doctorate"
|
||||
]
|
||||
},
|
||||
{
|
||||
column_name: "marital_status",
|
||||
range_type: "categorical",
|
||||
unique_values: ["Married", "Divorced", "Single", "Widowed"]
|
||||
},
|
||||
{
|
||||
column_name: "occupation",
|
||||
range_type: "categorical",
|
||||
unique_values: [
|
||||
"Blue-Collar",
|
||||
"Sales",
|
||||
"White-Collar",
|
||||
"Service",
|
||||
"Other/Unknown",
|
||||
"Professional"
|
||||
]
|
||||
},
|
||||
{
|
||||
column_name: "race",
|
||||
range_type: "categorical",
|
||||
unique_values: ["White", "Other"]
|
||||
},
|
||||
{
|
||||
column_name: "gender",
|
||||
range_type: "categorical",
|
||||
unique_values: ["Male", "Female"]
|
||||
},
|
||||
{
|
||||
column_name: "hours_per_week",
|
||||
max_value: 70,
|
||||
min_value: 10,
|
||||
range_type: "integer"
|
||||
}
|
||||
],
|
||||
is_large_data_scenario: false,
|
||||
num_rows: 48,
|
||||
use_entire_test_data: false
|
||||
},
|
||||
target_column: "income",
|
||||
task_type: DatasetTaskType.Classification,
|
||||
true_y: [
|
||||
|
|
|
@ -16,6 +16,13 @@ export enum DatasetTaskType {
|
|||
QuestionAnswering = "question_answering"
|
||||
}
|
||||
|
||||
export interface ITabularDatasetMetadata {
|
||||
is_large_data_scenario: boolean;
|
||||
use_entire_test_data: boolean;
|
||||
num_rows: number;
|
||||
feature_ranges: Array<{ [key: string]: any }>;
|
||||
}
|
||||
|
||||
export interface IDataset {
|
||||
task_type: DatasetTaskType;
|
||||
true_y: number[] | number[][] | string[];
|
||||
|
@ -30,6 +37,7 @@ export interface IDataset {
|
|||
target_column?: string | string[];
|
||||
data_balance_measures?: IDataBalanceMeasures;
|
||||
feature_metadata?: IFeatureMetaData;
|
||||
tabular_dataset_metadata?: ITabularDatasetMetadata;
|
||||
images?: string[];
|
||||
index?: string[];
|
||||
object_detection_true_y?: number[][][];
|
||||
|
|
|
@ -11,6 +11,13 @@ class TaskType(str, Enum):
|
|||
FORECASTING = 'forecasting'
|
||||
|
||||
|
||||
class TabularDatasetMetadata:
|
||||
is_large_data_scenario: bool
|
||||
use_entire_test_data: bool
|
||||
feature_ranges: List[Dict[str, Any]]
|
||||
num_rows: int
|
||||
|
||||
|
||||
class Dataset:
|
||||
task_type: TaskType
|
||||
predicted_y: List
|
||||
|
@ -24,6 +31,7 @@ class Dataset:
|
|||
is_large_data_scenario: bool
|
||||
use_entire_test_data: bool
|
||||
feature_metadata: Optional[Dict[str, Any]]
|
||||
tabular_dataset_metadata: Optional[TabularDatasetMetadata]
|
||||
data_balance_measures: Dict[str, Any]
|
||||
images: Optional[List[str]]
|
||||
index: Optional[List[str]]
|
||||
|
|
|
@ -18,7 +18,8 @@ from erroranalysis._internal.process_categoricals import process_categoricals
|
|||
from raiutils.data_processing import convert_to_list
|
||||
from raiutils.exceptions import UserConfigValidationException
|
||||
from raiutils.models import Forecasting, ModelTask, SKLearn
|
||||
from responsibleai._interfaces import Dataset, RAIInsightsData
|
||||
from responsibleai._interfaces import (Dataset, RAIInsightsData,
|
||||
TabularDatasetMetadata)
|
||||
from responsibleai._internal._forecasting_wrappers import _wrap_model
|
||||
from responsibleai._internal.constants import (FileFormats, ManagerNames,
|
||||
Metadata,
|
||||
|
@ -234,7 +235,7 @@ class RAIInsights(RAIBaseInsights):
|
|||
self._feature_columns = \
|
||||
test.drop(columns=[target_column]).columns.tolist()
|
||||
self._feature_ranges = RAIInsights._get_feature_ranges(
|
||||
test=test,
|
||||
test=(self._large_test if self._large_test is not None else test),
|
||||
categorical_features=self.categorical_features,
|
||||
feature_columns=self._feature_columns,
|
||||
datetime_features=self._feature_metadata.datetime_features)
|
||||
|
@ -923,6 +924,16 @@ class RAIInsights(RAIBaseInsights):
|
|||
True if self._large_test is not None else False
|
||||
dashboard_dataset.use_entire_test_data = False
|
||||
|
||||
dashboard_dataset.tabular_dataset_metadata = TabularDatasetMetadata()
|
||||
dashboard_dataset.tabular_dataset_metadata.is_large_data_scenario = \
|
||||
True if self._large_test is not None else False
|
||||
dashboard_dataset.tabular_dataset_metadata.use_entire_test_data = False
|
||||
dashboard_dataset.tabular_dataset_metadata.num_rows = \
|
||||
len(self._large_test) \
|
||||
if self._large_test is not None else len(self.test)
|
||||
dashboard_dataset.tabular_dataset_metadata.feature_ranges = \
|
||||
self._feature_ranges
|
||||
|
||||
if self._feature_metadata is not None:
|
||||
dashboard_dataset.feature_metadata = \
|
||||
self._feature_metadata.to_dict()
|
||||
|
|
|
@ -14,6 +14,7 @@ from rai_test_utils.models.sklearn import (
|
|||
create_sklearn_random_forest_classifier,
|
||||
create_sklearn_random_forest_regressor)
|
||||
from responsibleai import RAIInsights
|
||||
from responsibleai._interfaces import Dataset, TabularDatasetMetadata
|
||||
|
||||
LABELS = 'labels'
|
||||
|
||||
|
@ -33,9 +34,19 @@ class TestRAIInsightsLargeData(object):
|
|||
rai_insights._large_predict_proba_output)
|
||||
|
||||
dataset = rai_insights._get_dataset()
|
||||
assert isinstance(dataset, Dataset)
|
||||
assert dataset.is_large_data_scenario
|
||||
assert not dataset.use_entire_test_data
|
||||
|
||||
assert isinstance(
|
||||
dataset.tabular_dataset_metadata, TabularDatasetMetadata)
|
||||
assert dataset.tabular_dataset_metadata is not None
|
||||
assert dataset.tabular_dataset_metadata.is_large_data_scenario
|
||||
assert not dataset.tabular_dataset_metadata.use_entire_test_data
|
||||
assert dataset.tabular_dataset_metadata.num_rows == \
|
||||
len(rai_insights.test) + 1
|
||||
assert dataset.tabular_dataset_metadata.feature_ranges is not None
|
||||
|
||||
filtered_small_data = rai_insights.get_filtered_test_data(
|
||||
[], [], use_entire_test_data=False)
|
||||
assert len(filtered_small_data) == len(rai_insights.test)
|
||||
|
|
Загрузка…
Ссылка в новой задаче