Add multiclass classification dataset & set up basic model-assessment tests (#1191)

* add wine dataset for multiclass classification

* add basic tests for model overview section for binary and multiclass classification as well as regression

* lintfix

* undo changes to notebook

* update tests to reflect changes on main (addition of new metrics and string changes), simplify test setup

* lintfix
This commit is contained in:
Roman Lutz 2022-02-04 11:33:28 -05:00 коммит произвёл GitHub
Родитель ea2e14f39d
Коммит f0f9c7f7b4
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
14 изменённых файлов: 4109 добавлений и 28 удалений

Просмотреть файл

@ -23,35 +23,29 @@ export function describeModelPerformanceStats(dataShape: IInterpretData): void {
cy.get('#OverallMetricChart div[class*="statsBox"]').should("exist");
});
it("should have some stats", () => {
let expectedMetrics: string[];
if (dataShape.isClassification) {
cy.get('#OverallMetricChart div[class*="statsBox"]').contains(
"Accuracy"
);
cy.get('#OverallMetricChart div[class*="statsBox"]').contains(
"Precision"
);
cy.get('#OverallMetricChart div[class*="statsBox"]').contains("Recall");
cy.get('#OverallMetricChart div[class*="statsBox"]').contains(
"False positive rate"
);
cy.get('#OverallMetricChart div[class*="statsBox"]').contains(
"False negative rate"
);
cy.get('#OverallMetricChart div[class*="statsBox"]').contains(
expectedMetrics = [
"Accuracy",
"Precision",
"F1 score",
"False positive rate",
"False negative rate",
"Selection rate"
);
];
} else {
cy.get('#OverallMetricChart div[class*="statsBox"]').contains(
"Mean squared error"
);
cy.get('#OverallMetricChart div[class*="statsBox"]').contains(
"Mean absolute error"
);
cy.get('#OverallMetricChart div[class*="statsBox"]').contains("R²");
cy.get('#OverallMetricChart div[class*="statsBox"]').contains(
expectedMetrics = [
"Mean squared error",
"Mean absolute error",
"R²",
"Mean prediction"
);
];
}
expectedMetrics.forEach((metricName) => {
cy.get('#OverallMetricChart div[class*="statsBox"]').contains(
metricName
);
});
});
});
}

Просмотреть файл

@ -0,0 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/**
 * Flags describing a model-assessment test dataset. The describers read them
 * to decide which assertions to register for that dataset.
 */
export interface IModelAssessmentData {
/** Truthy: the classification metric names are expected in the stats box; otherwise the regression metric names are expected. */
isClassification?: boolean;
/** Truthy: the stats-box tests are not registered at all. */
isMulticlass?: boolean;
}

Просмотреть файл

@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
import { IModelAssessmentData } from "./IModelAssessmentData";
// Per-dataset flags. The keys double as the dataset names that
// describeModelOverview accepts and interpolates into the page route.
const modelAssessmentDatasets = {
adultCensusIncomeData: {
isClassification: true,
isMulticlass: false
},
// isMulticlass is left unset here, so it reads as undefined (falsy).
bostonData: {
isClassification: false
},
wineData: {
isClassification: true,
isMulticlass: true
}
};
// Re-bind the literal under a mapped type: the key set stays exactly that of
// the literal above, while every value is checked against IModelAssessmentData.
const withType: {
[key in keyof typeof modelAssessmentDatasets]: IModelAssessmentData;
} = modelAssessmentDatasets;
// Export the typed binding under the original name.
export { withType as modelAssessmentDatasets };

Просмотреть файл

@ -0,0 +1,67 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/**
 * Registers the "Axis settings dialog" suite: for the Y axis and then the X
 * axis of the overall metric chart, checks that the settings panel opens,
 * closes, and that selecting the second option updates the axis button text.
 */
export function describeAxisConfigDialog(): void {
  describe("Axis settings dialog", () => {
    describeSingleAxisDialog("Y", "rotatedVerticalBox");
    describeSingleAxisDialog("X", "horizontalAxis");
  });
}

/**
 * Registers the open / close / change-title tests for a single chart axis.
 *
 * @param axisLetter Axis letter ("Y" or "X"); only used to build the suite
 *   and test titles.
 * @param axisBoxClass Class-name fragment of the element that wraps the
 *   axis button inside #OverallMetricChart.
 */
function describeSingleAxisDialog(
  axisLetter: string,
  axisBoxClass: string
): void {
  const axisButton = `#OverallMetricChart div[class*="${axisBoxClass}"] button`;
  const panelMain = "#AxisConfigPanel div.ms-Panel-main";
  const changeTitleTest = `should change to different ${axisLetter.toLowerCase()}-axis title`;

  describe(`${axisLetter} Axis settings dialog`, () => {
    it("should display settings dialog", () => {
      cy.get(axisButton).click();
      cy.get(panelMain).should("exist");
    });

    it("should be able to hide settings", () => {
      cy.get("#AxisConfigPanel button.ms-Panel-closeButton").click();
      cy.get(panelMain).should("not.exist");
    });

    it(changeTitleTest, () => {
      cy.get(axisButton).click();
      // Read the text of the second choice, pick it, confirm, and then expect
      // the axis button to show that same text.
      cy.get("#AxisConfigPanel div[class*='ms-ChoiceFieldGroup'] label:eq(1)")
        .invoke("text")
        .then((optionLabel) => {
          cy.get(`#AxisConfigPanel label:contains(${optionLabel})`).click();
          cy.get("#AxisConfigPanel").find("button").contains("Select").click();
          cy.get(`${axisButton}:eq(0)`).contains(optionLabel);
        });
    });
  });
}

Просмотреть файл

@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
import { modelAssessmentDatasets } from "../modelAssessmentDatasets";
import { describeAxisConfigDialog } from "./describeAxisConfigDialog";
import { describeModelPerformanceStats } from "./describeModelPerformanceStats";
const testName = "Model overview";
/**
 * Registers the "Model overview" suite for one model-assessment dataset.
 *
 * The dashboard page for the dataset is visited once before the tests run;
 * the suite then checks the section header and registers the axis-dialog and
 * performance-stats sub-suites.
 *
 * @param name Key of the dataset in modelAssessmentDatasets; also used as the
 *   dataset segment of the visited route.
 */
export function describeModelOverview(
  name: keyof typeof modelAssessmentDatasets
): void {
  const dashboardRoute = `#/modelAssessment/${name}/light/english/Version-2`;
  const datasetFlags = modelAssessmentDatasets[name];

  describe(testName, () => {
    before(() => {
      cy.visit(dashboardRoute);
    });

    it("Model overview title", () => {
      cy.get("#modelStatisticsHeader").contains("Model overview");
    });

    describe("Model performance Chart", () => {
      describeAxisConfigDialog();
    });

    describe("Model performance stats", () => {
      describeModelPerformanceStats(datasetFlags);
    });
  });
}

Просмотреть файл

@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
import { IModelAssessmentData } from "../IModelAssessmentData";
/**
 * Registers the "performance stats" suite for the overall metric chart.
 *
 * The legend test is always registered. The stats-box tests are registered
 * only when the dataset is not flagged as multiclass; the metric names they
 * look for depend on whether the dataset is flagged as classification.
 *
 * @param modelAssessmentData Flags of the dataset under test.
 */
export function describeModelPerformanceStats(
  modelAssessmentData: IModelAssessmentData
): void {
  const statsBoxSelector = '#OverallMetricChart div[class*="statsBox"]';

  describe("performance stats", () => {
    it("should have legend", () => {
      cy.get('#OverallMetricChart g[class*="infolayer"]').should("exist");
    });

    // stats box currently not available for multiclass
    if (modelAssessmentData.isMulticlass) {
      return;
    }

    it("should have stats box", () => {
      cy.get(statsBoxSelector).should("exist");
    });

    it("should have some stats", () => {
      const expectedMetrics: string[] = modelAssessmentData.isClassification
        ? [
            "Accuracy",
            "Precision",
            "F1 score",
            "False positive rate",
            "False negative rate",
            "Selection rate"
          ]
        : [
            "Mean squared error",
            "Mean absolute error",
            "R²",
            "Mean prediction"
          ];

      for (const metricName of expectedMetrics) {
        cy.get(statsBoxSelector).contains(metricName);
      }
    });
  });
}

Просмотреть файл

@ -0,0 +1,6 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
import { describeModelOverview } from "../../../describer/modelAssessment/modelOverview/describeModelOverview";
describeModelOverview("adultCensusIncomeData");

Просмотреть файл

@ -0,0 +1,6 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
import { describeModelOverview } from "../../../describer/modelAssessment/modelOverview/describeModelOverview";
describeModelOverview("bostonData");

Просмотреть файл

@ -0,0 +1,6 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
import { describeModelOverview } from "../../../describer/modelAssessment/modelOverview/describeModelOverview";
describeModelOverview("wineData");

Просмотреть файл

@ -49,6 +49,11 @@ import {
bostonErrorAnalysisData,
bostonWithFairnessModelExplanationData
} from "../model-assessment/__mock_data__/bostonData";
import {
wineData as wineDataMAD,
wineErrorAnalysisData,
wineWithFairnessModelExplanationData
} from "../model-assessment/__mock_data__/wineData";
export interface IInterpretDataSet {
data: IExplanationDashboardData;
@ -195,7 +200,7 @@ export const applications: IApplications = <const>{
adultCensusIncomeNoModelData: {
classDimension: 2,
dataset: adultCensusWithFairnessDataset
},
} as IModelAssessmentDataSet,
bostonData: {
causalAnalysisData: [bostonCensusCausalAnalysisData],
classDimension: 1,
@ -203,6 +208,12 @@ export const applications: IApplications = <const>{
dataset: bostonDataMAD,
errorAnalysisData: [bostonErrorAnalysisData],
modelExplanationData: [bostonWithFairnessModelExplanationData]
} as IModelAssessmentDataSet,
wineData: {
classDimension: 3,
dataset: wineDataMAD,
errorAnalysisData: [wineErrorAnalysisData],
modelExplanationData: [wineWithFairnessModelExplanationData]
} as IModelAssessmentDataSet
},
versions: { "1": 1, "2:Static-View": 2 }

Просмотреть файл

@ -0,0 +1,102 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
import { IErrorAnalysisTreeNode } from "@responsible-ai/core-ui";
/**
 * Mock error-analysis tree for the wine dataset: one root node (id 0) and
 * four descendants linked to their parent through parentId.
 *
 * In every node below, size equals error + success and metricValue equals
 * error / size. Note that the node with id 2 is listed before the node with
 * id 1.
 */
export const dummyTreeWineData: IErrorAnalysisTreeNode[] = [
// Root node: no parent and no split condition.
{
arg: undefined,
badFeaturesRowCount: 0,
condition: "",
error: 49,
id: 0,
isErrorMetric: true,
method: "",
metricName: "Error rate",
metricValue: 0.550561797752809,
nodeIndex: 0,
nodeName: "color_intensity",
parentId: undefined,
parentNodeName: "",
pathFromRoot: "",
size: 89,
sourceRowKeyHash: "hashkey",
success: 40
},
// Child of the root on the "color_intensity <= 3.87" side.
{
arg: 3.8700000000000006,
badFeaturesRowCount: 0,
condition: "color_intensity <= 3.87",
error: 2,
id: 2,
isErrorMetric: true,
method: "less and equal",
metricName: "Error rate",
metricValue: 0.05714285714285714,
nodeIndex: 2,
nodeName: "",
parentId: 0,
parentNodeName: "color_intensity",
pathFromRoot: "",
size: 35,
sourceRowKeyHash: "hashkey",
success: 33
},
// Child of the root on the "color_intensity > 3.87" side; parent of nodes 3 and 4.
{
arg: 3.8700000000000006,
badFeaturesRowCount: 0,
condition: "color_intensity > 3.87",
error: 47,
id: 1,
isErrorMetric: true,
method: "greater",
metricName: "Error rate",
metricValue: 0.8703703703703703,
nodeIndex: 1,
nodeName: "proline",
parentId: 0,
parentNodeName: "color_intensity",
pathFromRoot: "",
size: 54,
sourceRowKeyHash: "hashkey",
success: 7
},
// Child of node 1 on the "proline <= 666.00" side.
{
arg: 666.0000000000001,
badFeaturesRowCount: 0,
condition: "proline <= 666.00",
error: 13,
id: 3,
isErrorMetric: true,
method: "less and equal",
metricName: "Error rate",
metricValue: 0.65,
nodeIndex: 3,
nodeName: "",
parentId: 1,
parentNodeName: "proline",
pathFromRoot: "",
size: 20,
sourceRowKeyHash: "hashkey",
success: 7
},
// Child of node 1 on the "proline > 666.00" side.
{
arg: 666.0000000000001,
badFeaturesRowCount: 0,
condition: "proline > 666.00",
error: 34,
id: 4,
isErrorMetric: true,
method: "greater",
metricName: "Error rate",
metricValue: 1,
nodeIndex: 4,
nodeName: "",
parentId: 1,
parentNodeName: "proline",
pathFromRoot: "",
size: 34,
sourceRowKeyHash: "hashkey",
success: 0
}
];

Просмотреть файл

@ -15,13 +15,15 @@ import { dummyTreeBostonData } from "./__mock_data__/dummyTreeBoston";
import { dummyTreeBreastCancerData } from "./__mock_data__/dummyTreeBreastCancer";
import { dummyTreeBreastCancerPrecisionData } from "./__mock_data__/dummyTreeBreastCancerPrecision";
import { dummyTreeBreastCancerRecallData } from "./__mock_data__/dummyTreeBreastCancerRecall";
import { dummyTreeWineData } from "./__mock_data__/dummyTreeWine";
export enum DatasetName {
AdultCensusIncome = 1,
BreastCancer,
Boston,
BreastCancerPrecision,
BreastCancerRecall
BreastCancerRecall,
Wine
}
export function getJsonMatrix(): any {
@ -81,6 +83,8 @@ export function getJsonTree(dataset: DatasetName): any {
return _.cloneDeep(dummyTreeBreastCancerPrecisionData);
} else if (dataset === DatasetName.BreastCancerRecall) {
return _.cloneDeep(dummyTreeBreastCancerRecallData);
} else if (dataset === DatasetName.Wine) {
return _.cloneDeep(dummyTreeWineData);
}
return _.cloneDeep(dummyTreeAdultCensusIncomeData);
}
@ -100,6 +104,8 @@ export function generateJsonTree(
resolve(_.cloneDeep(dummyTreeBreastCancerPrecisionData));
} else if (dataset === DatasetName.BreastCancerRecall) {
resolve(_.cloneDeep(dummyTreeBreastCancerRecallData));
} else if (dataset === DatasetName.Wine) {
resolve(_.cloneDeep(dummyTreeWineData));
} else {
resolve(_.cloneDeep(dummyTreeBostonData));
}
@ -148,6 +154,13 @@ export function generateJsonTreeBoston(
return generateJsonTree(_data, signal, DatasetName.Boston);
}
/**
 * Wine-dataset variant of the tree request: delegates to generateJsonTree
 * with DatasetName.Wine.
 *
 * @param _data Forwarded to generateJsonTree unchanged.
 * @param signal Forwarded to generateJsonTree unchanged.
 * @returns The promise returned by generateJsonTree.
 */
export function generateJsonTreeWine(
_data: any[],
signal: AbortSignal
): Promise<any> {
return generateJsonTree(_data, signal, DatasetName.Wine);
}
export function generateJsonMatrix(dataset: DatasetName) {
return (data: any[], signal: AbortSignal): Promise<any> => {
const promise = new Promise((resolve, reject) => {
@ -200,7 +213,8 @@ export function createJsonImportancesGenerator(
dataset === DatasetName.AdultCensusIncome ||
dataset === DatasetName.Boston ||
dataset === DatasetName.BreastCancerPrecision ||
dataset === DatasetName.BreastCancerRecall
dataset === DatasetName.BreastCancerRecall ||
dataset === DatasetName.Wine
);
resolve(featureNames.map(() => Math.random()));
}, 300);

Просмотреть файл

@ -19,6 +19,7 @@ import {
DatasetName,
generateJsonTreeBoston,
generateJsonTreeAdultCensusIncome,
generateJsonTreeWine,
getJsonMatrix,
getJsonTreeAdultCensusIncome
} from "../error-analysis/utils";
@ -69,7 +70,7 @@ export class App extends React.Component<IAppProps> {
this.props.dataset.feature_names,
DatasetName.Boston
);
} else {
} else if (this.props.classDimension === 2) {
// Adult
modelAssessmentDashboardProps.requestDebugML =
generateJsonTreeAdultCensusIncome;
@ -78,6 +79,14 @@ export class App extends React.Component<IAppProps> {
this.props.dataset.feature_names,
DatasetName.AdultCensusIncome
);
} else {
// Wine
modelAssessmentDashboardProps.requestDebugML = generateJsonTreeWine;
modelAssessmentDashboardProps.requestImportances =
createJsonImportancesGenerator(
this.props.dataset.feature_names,
DatasetName.Wine
);
}
} else {
const staticTree = getJsonTreeAdultCensusIncome(

Разница между файлами не показана из-за своего большого размера Загрузить разницу