Add dataset metadata to `IDataset` (#2071)

* Add dataset metadata to `IDataset`

Signed-off-by: Gaurav Gupta <gaugup@microsoft.com>

* Fix tests

Signed-off-by: Gaurav Gupta <gaugup@microsoft.com>

---------

Signed-off-by: Gaurav Gupta <gaugup@microsoft.com>
This commit is contained in:
Gaurav Gupta 2023-05-22 14:18:55 -07:00 коммит произвёл GitHub
Родитель e398f5a396
Коммит d6d88694ab
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
7 изменённых файлов: 197 добавлений и 2 удаления

Просмотреть файл

@ -46,6 +46,7 @@
"ci_upper",
"class_names",
"cohort_filter_list",
"column_name",
"comparison_value",
"control_treatment",
"data_balance_measures",
@ -74,6 +75,7 @@
"feature_metadata",
"feature_names",
"feature_names_including_target",
"feature_ranges",
"feature_value",
"global_effects",
"identity_feature_name",
@ -85,15 +87,18 @@
"lower_bounds",
"matrix_features",
"max_error",
"max_value",
"mean_absolute_error",
"mean_absolute_error_min",
"mean_squared_error",
"mean_squared_error_min",
"mean_squared_log_error",
"median_absolute_error",
"min_value",
"miss_rate",
"model_type",
"n_samples",
"num_rows",
"p_value",
"plot_bgcolor",
"point_estimate",
@ -104,6 +109,7 @@
"probability_y",
"r2_score",
"r2_score_min",
"range_type",
"recall_score",
"recall_score_min",
"recommended_policy_gains",
@ -120,6 +126,7 @@
"shap_tree",
"specificity_score",
"summary_importance",
"tabular_dataset_metadata",
"target_column",
"task_type",
"test_data",
@ -136,6 +143,7 @@
"use_entire_test_data",
"zero_one_loss",
"index_series",
"unique_values",
"x_series",
"y_series",
"x_map",

Просмотреть файл

@ -16916,6 +16916,86 @@ exports[`FairnessWizardV1 should render successfully 1`] = `
0.3195856063209964,
],
],
"tabular_dataset_metadata": Object {
"feature_ranges": Array [
Object {
"column_name": "age",
"max_value": 74,
"min_value": 17,
"range_type": "integer",
},
Object {
"column_name": "workclass",
"range_type": "categorical",
"unique_values": Array [
"Private",
"Government",
"Other/Unknown",
"Self-Employed",
],
},
Object {
"column_name": "education",
"range_type": "categorical",
"unique_values": Array [
"HS-grad",
"Masters",
"Some-college",
"Assoc",
"Bachelors",
"School",
"Doctorate",
],
},
Object {
"column_name": "marital_status",
"range_type": "categorical",
"unique_values": Array [
"Married",
"Divorced",
"Single",
"Widowed",
],
},
Object {
"column_name": "occupation",
"range_type": "categorical",
"unique_values": Array [
"Blue-Collar",
"Sales",
"White-Collar",
"Service",
"Other/Unknown",
"Professional",
],
},
Object {
"column_name": "race",
"range_type": "categorical",
"unique_values": Array [
"White",
"Other",
],
},
Object {
"column_name": "gender",
"range_type": "categorical",
"unique_values": Array [
"Male",
"Female",
],
},
Object {
"column_name": "hours_per_week",
"max_value": 70,
"min_value": 10,
"range_type": "integer",
},
],
"is_large_data_scenario": false,
"num_rows": 48,
"use_entire_test_data": false,
},
"target_column": "income",
"task_type": "classification",
"true_y": Array [

Просмотреть файл

@ -2978,6 +2978,75 @@ export const adultCensusWithFairnessDataset: IDataset = {
[0.7355988315726243, 0.2644011684273757],
[0.6804143936790036, 0.3195856063209964]
],
tabular_dataset_metadata: {
feature_ranges: [
{
column_name: "age",
max_value: 74,
min_value: 17,
range_type: "integer"
},
{
column_name: "workclass",
range_type: "categorical",
unique_values: [
"Private",
"Government",
"Other/Unknown",
"Self-Employed"
]
},
{
column_name: "education",
range_type: "categorical",
unique_values: [
"HS-grad",
"Masters",
"Some-college",
"Assoc",
"Bachelors",
"School",
"Doctorate"
]
},
{
column_name: "marital_status",
range_type: "categorical",
unique_values: ["Married", "Divorced", "Single", "Widowed"]
},
{
column_name: "occupation",
range_type: "categorical",
unique_values: [
"Blue-Collar",
"Sales",
"White-Collar",
"Service",
"Other/Unknown",
"Professional"
]
},
{
column_name: "race",
range_type: "categorical",
unique_values: ["White", "Other"]
},
{
column_name: "gender",
range_type: "categorical",
unique_values: ["Male", "Female"]
},
{
column_name: "hours_per_week",
max_value: 70,
min_value: 10,
range_type: "integer"
}
],
is_large_data_scenario: false,
num_rows: 48,
use_entire_test_data: false
},
target_column: "income",
task_type: DatasetTaskType.Classification,
true_y: [

Просмотреть файл

@ -16,6 +16,13 @@ export enum DatasetTaskType {
QuestionAnswering = "question_answering"
}
/**
 * Summary metadata describing a tabular dataset as a whole, as opposed to
 * the per-row values carried elsewhere on the dataset object.
 */
export interface ITabularDatasetMetadata {
  /**
   * Loosely-typed descriptors of the dataset's feature columns. The shape of
   * each entry is not enforced by this type; in the fixture data an entry
   * carries `column_name` and `range_type`, plus either
   * `min_value`/`max_value` or `unique_values`.
   */
  feature_ranges: Array<Record<string, any>>;

  /** Row count reported for the dataset. */
  num_rows: number;

  /** Flag reported by the data producer for the large-data scenario. */
  is_large_data_scenario: boolean;

  /**
   * Flag reported by the data producer.
   * NOTE(review): presumably signals whether operations use the entire test
   * set rather than a sample; confirm against the producer.
   */
  use_entire_test_data: boolean;
}
export interface IDataset {
task_type: DatasetTaskType;
true_y: number[] | number[][] | string[];
@ -30,6 +37,7 @@ export interface IDataset {
target_column?: string | string[];
data_balance_measures?: IDataBalanceMeasures;
feature_metadata?: IFeatureMetaData;
tabular_dataset_metadata?: ITabularDatasetMetadata;
images?: string[];
index?: string[];
object_detection_true_y?: number[][][];

Просмотреть файл

@ -11,6 +11,13 @@ class TaskType(str, Enum):
FORECASTING = 'forecasting'
# Plain attribute container for summary metadata about a tabular dataset.
# It declares annotations only; instances are created empty and the
# attributes are assigned afterwards by the code that builds the dataset.
class TabularDatasetMetadata:
# Whether the large-data scenario applies.
# NOTE(review): presumably True when a separate large test set was
# supplied -- confirm against the code that populates this object.
is_large_data_scenario: bool
# NOTE(review): presumably whether the entire test set is used rather than
# a sample -- confirm against the code that populates this object.
use_entire_test_data: bool
# One dictionary per feature; the keys are not constrained by this type.
feature_ranges: List[Dict[str, Any]]
# Number of rows in the dataset being described.
num_rows: int
class Dataset:
task_type: TaskType
predicted_y: List
@ -24,6 +31,7 @@ class Dataset:
is_large_data_scenario: bool
use_entire_test_data: bool
feature_metadata: Optional[Dict[str, Any]]
tabular_dataset_metadata: Optional[TabularDatasetMetadata]
data_balance_measures: Dict[str, Any]
images: Optional[List[str]]
index: Optional[List[str]]

Просмотреть файл

@ -18,7 +18,8 @@ from erroranalysis._internal.process_categoricals import process_categoricals
from raiutils.data_processing import convert_to_list
from raiutils.exceptions import UserConfigValidationException
from raiutils.models import Forecasting, ModelTask, SKLearn
from responsibleai._interfaces import Dataset, RAIInsightsData
from responsibleai._interfaces import (Dataset, RAIInsightsData,
TabularDatasetMetadata)
from responsibleai._internal._forecasting_wrappers import _wrap_model
from responsibleai._internal.constants import (FileFormats, ManagerNames,
Metadata,
@ -234,7 +235,7 @@ class RAIInsights(RAIBaseInsights):
self._feature_columns = \
test.drop(columns=[target_column]).columns.tolist()
self._feature_ranges = RAIInsights._get_feature_ranges(
test=test,
test=(self._large_test if self._large_test is not None else test),
categorical_features=self.categorical_features,
feature_columns=self._feature_columns,
datetime_features=self._feature_metadata.datetime_features)
@ -923,6 +924,16 @@ class RAIInsights(RAIBaseInsights):
True if self._large_test is not None else False
dashboard_dataset.use_entire_test_data = False
dashboard_dataset.tabular_dataset_metadata = TabularDatasetMetadata()
dashboard_dataset.tabular_dataset_metadata.is_large_data_scenario = \
True if self._large_test is not None else False
dashboard_dataset.tabular_dataset_metadata.use_entire_test_data = False
dashboard_dataset.tabular_dataset_metadata.num_rows = \
len(self._large_test) \
if self._large_test is not None else len(self.test)
dashboard_dataset.tabular_dataset_metadata.feature_ranges = \
self._feature_ranges
if self._feature_metadata is not None:
dashboard_dataset.feature_metadata = \
self._feature_metadata.to_dict()

Просмотреть файл

@ -14,6 +14,7 @@ from rai_test_utils.models.sklearn import (
create_sklearn_random_forest_classifier,
create_sklearn_random_forest_regressor)
from responsibleai import RAIInsights
from responsibleai._interfaces import Dataset, TabularDatasetMetadata
LABELS = 'labels'
@ -33,9 +34,19 @@ class TestRAIInsightsLargeData(object):
rai_insights._large_predict_proba_output)
dataset = rai_insights._get_dataset()
assert isinstance(dataset, Dataset)
assert dataset.is_large_data_scenario
assert not dataset.use_entire_test_data
assert isinstance(
dataset.tabular_dataset_metadata, TabularDatasetMetadata)
assert dataset.tabular_dataset_metadata is not None
assert dataset.tabular_dataset_metadata.is_large_data_scenario
assert not dataset.tabular_dataset_metadata.use_entire_test_data
assert dataset.tabular_dataset_metadata.num_rows == \
len(rai_insights.test) + 1
assert dataset.tabular_dataset_metadata.feature_ranges is not None
filtered_small_data = rai_insights.get_filtered_test_data(
[], [], use_entire_test_data=False)
assert len(filtered_small_data) == len(rai_insights.test)