Add multiclass classification dataset & set up basic model-assessment tests (#1191)
* Add wine dataset for multiclass classification.
* Add basic tests for the model overview section for binary classification, multiclass classification, and regression.
* Lint fix.
* Undo changes to notebook.
* Update tests to reflect changes on main (addition of new metrics and string changes); simplify test setup.
* Lint fix.
This commit is contained in:
Родитель
ea2e14f39d
Коммит
f0f9c7f7b4
|
@ -23,35 +23,29 @@ export function describeModelPerformanceStats(dataShape: IInterpretData): void {
|
|||
cy.get('#OverallMetricChart div[class*="statsBox"]').should("exist");
|
||||
});
|
||||
it("should have some stats", () => {
|
||||
let expectedMetrics: string[];
|
||||
if (dataShape.isClassification) {
|
||||
cy.get('#OverallMetricChart div[class*="statsBox"]').contains(
|
||||
"Accuracy"
|
||||
);
|
||||
cy.get('#OverallMetricChart div[class*="statsBox"]').contains(
|
||||
"Precision"
|
||||
);
|
||||
cy.get('#OverallMetricChart div[class*="statsBox"]').contains("Recall");
|
||||
cy.get('#OverallMetricChart div[class*="statsBox"]').contains(
|
||||
"False positive rate"
|
||||
);
|
||||
cy.get('#OverallMetricChart div[class*="statsBox"]').contains(
|
||||
"False negative rate"
|
||||
);
|
||||
cy.get('#OverallMetricChart div[class*="statsBox"]').contains(
|
||||
expectedMetrics = [
|
||||
"Accuracy",
|
||||
"Precision",
|
||||
"F1 score",
|
||||
"False positive rate",
|
||||
"False negative rate",
|
||||
"Selection rate"
|
||||
);
|
||||
];
|
||||
} else {
|
||||
cy.get('#OverallMetricChart div[class*="statsBox"]').contains(
|
||||
"Mean squared error"
|
||||
);
|
||||
cy.get('#OverallMetricChart div[class*="statsBox"]').contains(
|
||||
"Mean absolute error"
|
||||
);
|
||||
cy.get('#OverallMetricChart div[class*="statsBox"]').contains("R²");
|
||||
cy.get('#OverallMetricChart div[class*="statsBox"]').contains(
|
||||
expectedMetrics = [
|
||||
"Mean squared error",
|
||||
"Mean absolute error",
|
||||
"R²",
|
||||
"Mean prediction"
|
||||
);
|
||||
];
|
||||
}
|
||||
expectedMetrics.forEach((metricName) => {
|
||||
cy.get('#OverallMetricChart div[class*="statsBox"]').contains(
|
||||
metricName
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
/**
 * Flags describing the kind of model behind a model-assessment test dataset.
 * The model-overview describers read these flags to decide which checks to
 * register.
 */
export interface IModelAssessmentData {
  /**
   * Truthy for classification datasets. Selects the classification metric
   * names (instead of the regression ones) expected in the stats box.
   */
  isClassification?: boolean;
  /**
   * Truthy for multiclass classification datasets. The stats-box checks are
   * not registered for these.
   */
  isMulticlass?: boolean;
}
|
|
@ -0,0 +1,23 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { IModelAssessmentData } from "./IModelAssessmentData";
|
||||
|
||||
// Raw per-dataset flags. Deliberately left un-annotated so that
// `keyof typeof datasetFlags` below yields the exact dataset names.
const datasetFlags = {
  adultCensusIncomeData: {
    isClassification: true,
    isMulticlass: false
  },
  bostonData: {
    isClassification: false
  },
  wineData: {
    isClassification: true,
    isMulticlass: true
  }
};

/**
 * Model-assessment test datasets keyed by dataset name. Every entry is exposed
 * as `IModelAssessmentData`, so optional flags (e.g. `isMulticlass`) can be
 * read on any entry, while the set of keys stays restricted to the names
 * declared above.
 */
export const modelAssessmentDatasets: Record<
  keyof typeof datasetFlags,
  IModelAssessmentData
> = datasetFlags;
|
|
@ -0,0 +1,67 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
/** Titles and selector that differ between the Y-axis and X-axis suites. */
interface IAxisDialogSuite {
  /** Selector of the chart button that opens the axis settings panel. */
  axisButton: string;
  /** Title of the test that switches the axis to another option. */
  changeTitleTest: string;
  /** Title of the `describe` block for this axis. */
  suiteTitle: string;
}

/**
 * Registers the settings-dialog checks for a single chart axis: the panel
 * opens, the panel closes, and picking another option updates the axis button.
 */
function describeAxisDialogSuite(suite: IAxisDialogSuite): void {
  const panelMain = "#AxisConfigPanel div.ms-Panel-main";
  const openPanel = (): void => {
    cy.get(suite.axisButton).click();
  };

  describe(suite.suiteTitle, () => {
    it("should display settings dialog", () => {
      openPanel();
      cy.get(panelMain).should("exist");
    });

    it("should be able to hide settings", () => {
      // Does not open the panel itself; presumably relies on the panel left
      // open by the previous test.
      cy.get("#AxisConfigPanel button.ms-Panel-closeButton").click();
      cy.get(panelMain).should("not.exist");
    });

    it(suite.changeTitleTest, () => {
      openPanel();

      // Read the text of the second choice label, select that choice, confirm,
      // and expect the axis button to show the same text afterwards.
      cy.get("#AxisConfigPanel div[class*='ms-ChoiceFieldGroup'] label:eq(1)")
        .invoke("text")
        .then((optionLabel) => {
          cy.get(`#AxisConfigPanel label:contains(${optionLabel})`).click();
          cy.get("#AxisConfigPanel").find("button").contains("Select").click();
          cy.get(`${suite.axisButton}:eq(0)`).contains(optionLabel);
        });
    });
  });
}

/**
 * Registers the "Axis settings dialog" Cypress suites for the Y axis and the
 * X axis of the overall metric chart.
 */
export function describeAxisConfigDialog(): void {
  describe("Axis settings dialog", () => {
    describeAxisDialogSuite({
      axisButton: '#OverallMetricChart div[class*="rotatedVerticalBox"] button',
      changeTitleTest: "should change to different y-axis title",
      suiteTitle: "Y Axis settings dialog"
    });
    describeAxisDialogSuite({
      axisButton: '#OverallMetricChart div[class*="horizontalAxis"] button',
      changeTitleTest: "should change to different x-axis title",
      suiteTitle: "X Axis settings dialog"
    });
  });
}
|
|
@ -0,0 +1,27 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { modelAssessmentDatasets } from "../modelAssessmentDatasets";
|
||||
|
||||
import { describeAxisConfigDialog } from "./describeAxisConfigDialog";
|
||||
import { describeModelPerformanceStats } from "./describeModelPerformanceStats";
|
||||
|
||||
const testName = "Model overview";
|
||||
/**
 * Registers the "Model overview" Cypress suite for one model-assessment
 * dataset.
 *
 * @param name Key of the dataset in `modelAssessmentDatasets`; it is also used
 * as the dataset segment of the dashboard route visited before the tests run.
 */
export function describeModelOverview(
  name: keyof typeof modelAssessmentDatasets
): void {
  const route = `#/modelAssessment/${name}/light/english/Version-2`;
  const datasetFlags = modelAssessmentDatasets[name];

  describe(testName, () => {
    before(() => {
      cy.visit(route);
    });

    it("Model overview title", () => {
      cy.get("#modelStatisticsHeader").contains("Model overview");
    });

    describe("Model performance Chart", () => {
      describeAxisConfigDialog();
    });

    describe("Model performance stats", () => {
      describeModelPerformanceStats(datasetFlags);
    });
  });
}
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { IModelAssessmentData } from "../IModelAssessmentData";
|
||||
|
||||
/**
 * Registers the "performance stats" Cypress checks for the overall metric
 * chart: the legend is always checked; the stats box and its metric names are
 * checked only for datasets that are not multiclass.
 *
 * @param modelAssessmentData Flags of the dataset under test; they select
 * which checks are registered and which metric names are expected.
 */
export function describeModelPerformanceStats(
  modelAssessmentData: IModelAssessmentData
): void {
  const statsBox = '#OverallMetricChart div[class*="statsBox"]';

  describe("performance stats", () => {
    it("should have legend", () => {
      cy.get('#OverallMetricChart g[class*="infolayer"]').should("exist");
    });

    // stats box currently not available for multiclass
    if (modelAssessmentData.isMulticlass) {
      return;
    }

    it("should have stats box", () => {
      cy.get(statsBox).should("exist");
    });

    it("should have some stats", () => {
      const expectedMetrics: string[] = modelAssessmentData.isClassification
        ? [
            "Accuracy",
            "Precision",
            "F1 score",
            "False positive rate",
            "False negative rate",
            "Selection rate"
          ]
        : ["Mean squared error", "Mean absolute error", "R²", "Mean prediction"];

      for (const metricName of expectedMetrics) {
        cy.get(statsBox).contains(metricName);
      }
    });
  });
}
|
|
@ -0,0 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { describeModelOverview } from "../../../describer/modelAssessment/modelOverview/describeModelOverview";
|
||||
|
||||
describeModelOverview("adultCensusIncomeData");
|
|
@ -0,0 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { describeModelOverview } from "../../../describer/modelAssessment/modelOverview/describeModelOverview";
|
||||
|
||||
describeModelOverview("bostonData");
|
|
@ -0,0 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { describeModelOverview } from "../../../describer/modelAssessment/modelOverview/describeModelOverview";
|
||||
|
||||
describeModelOverview("wineData");
|
|
@ -49,6 +49,11 @@ import {
|
|||
bostonErrorAnalysisData,
|
||||
bostonWithFairnessModelExplanationData
|
||||
} from "../model-assessment/__mock_data__/bostonData";
|
||||
import {
|
||||
wineData as wineDataMAD,
|
||||
wineErrorAnalysisData,
|
||||
wineWithFairnessModelExplanationData
|
||||
} from "../model-assessment/__mock_data__/wineData";
|
||||
|
||||
export interface IInterpretDataSet {
|
||||
data: IExplanationDashboardData;
|
||||
|
@ -195,7 +200,7 @@ export const applications: IApplications = <const>{
|
|||
adultCensusIncomeNoModelData: {
|
||||
classDimension: 2,
|
||||
dataset: adultCensusWithFairnessDataset
|
||||
},
|
||||
} as IModelAssessmentDataSet,
|
||||
bostonData: {
|
||||
causalAnalysisData: [bostonCensusCausalAnalysisData],
|
||||
classDimension: 1,
|
||||
|
@ -203,6 +208,12 @@ export const applications: IApplications = <const>{
|
|||
dataset: bostonDataMAD,
|
||||
errorAnalysisData: [bostonErrorAnalysisData],
|
||||
modelExplanationData: [bostonWithFairnessModelExplanationData]
|
||||
} as IModelAssessmentDataSet,
|
||||
wineData: {
|
||||
classDimension: 3,
|
||||
dataset: wineDataMAD,
|
||||
errorAnalysisData: [wineErrorAnalysisData],
|
||||
modelExplanationData: [wineWithFairnessModelExplanationData]
|
||||
} as IModelAssessmentDataSet
|
||||
},
|
||||
versions: { "1": 1, "2:Static-View": 2 }
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { IErrorAnalysisTreeNode } from "@responsible-ai/core-ui";
|
||||
|
||||
/**
 * Mock error-analysis tree for the wine dataset. The five nodes form a
 * two-level tree: the root (id 0) splits on `color_intensity` at 3.87, and its
 * "greater" child (id 1) splits again on `proline` at 666.
 */
export const dummyTreeWineData: IErrorAnalysisTreeNode[] = [
  // Root node (id 0).
  {
    arg: undefined,
    badFeaturesRowCount: 0,
    condition: "",
    error: 49,
    id: 0,
    isErrorMetric: true,
    method: "",
    metricName: "Error rate",
    metricValue: 0.550561797752809,
    nodeIndex: 0,
    nodeName: "color_intensity",
    parentId: undefined,
    parentNodeName: "",
    pathFromRoot: "",
    size: 89,
    sourceRowKeyHash: "hashkey",
    success: 40
  },
  // Child of the root (id 2): color_intensity <= 3.87.
  {
    arg: 3.8700000000000006,
    badFeaturesRowCount: 0,
    condition: "color_intensity <= 3.87",
    error: 2,
    id: 2,
    isErrorMetric: true,
    method: "less and equal",
    metricName: "Error rate",
    metricValue: 0.05714285714285714,
    nodeIndex: 2,
    nodeName: "",
    parentId: 0,
    parentNodeName: "color_intensity",
    pathFromRoot: "",
    size: 35,
    sourceRowKeyHash: "hashkey",
    success: 33
  },
  // Child of the root (id 1): color_intensity > 3.87; splits on proline.
  {
    arg: 3.8700000000000006,
    badFeaturesRowCount: 0,
    condition: "color_intensity > 3.87",
    error: 47,
    id: 1,
    isErrorMetric: true,
    method: "greater",
    metricName: "Error rate",
    metricValue: 0.8703703703703703,
    nodeIndex: 1,
    nodeName: "proline",
    parentId: 0,
    parentNodeName: "color_intensity",
    pathFromRoot: "",
    size: 54,
    sourceRowKeyHash: "hashkey",
    success: 7
  },
  // Child of node 1 (id 3): proline <= 666.00.
  {
    arg: 666.0000000000001,
    badFeaturesRowCount: 0,
    condition: "proline <= 666.00",
    error: 13,
    id: 3,
    isErrorMetric: true,
    method: "less and equal",
    metricName: "Error rate",
    metricValue: 0.65,
    nodeIndex: 3,
    nodeName: "",
    parentId: 1,
    parentNodeName: "proline",
    pathFromRoot: "",
    size: 20,
    sourceRowKeyHash: "hashkey",
    success: 7
  },
  // Child of node 1 (id 4): proline > 666.00.
  {
    arg: 666.0000000000001,
    badFeaturesRowCount: 0,
    condition: "proline > 666.00",
    error: 34,
    id: 4,
    isErrorMetric: true,
    method: "greater",
    metricName: "Error rate",
    metricValue: 1,
    nodeIndex: 4,
    nodeName: "",
    parentId: 1,
    parentNodeName: "proline",
    pathFromRoot: "",
    size: 34,
    sourceRowKeyHash: "hashkey",
    success: 0
  }
];
|
|
@ -15,13 +15,15 @@ import { dummyTreeBostonData } from "./__mock_data__/dummyTreeBoston";
|
|||
import { dummyTreeBreastCancerData } from "./__mock_data__/dummyTreeBreastCancer";
|
||||
import { dummyTreeBreastCancerPrecisionData } from "./__mock_data__/dummyTreeBreastCancerPrecision";
|
||||
import { dummyTreeBreastCancerRecallData } from "./__mock_data__/dummyTreeBreastCancerRecall";
|
||||
import { dummyTreeWineData } from "./__mock_data__/dummyTreeWine";
|
||||
|
||||
/**
 * Identifiers of the mock datasets that the error-analysis helpers in this
 * file can serve canned responses for.
 *
 * Members are numbered automatically starting from 1 in declaration order, so
 * appending a member (as done for `Wine`) leaves the existing values unchanged.
 */
export enum DatasetName {
  AdultCensusIncome = 1,
  BreastCancer,
  Boston,
  BreastCancerPrecision,
  BreastCancerRecall,
  Wine
}
|
||||
|
||||
export function getJsonMatrix(): any {
|
||||
|
@ -81,6 +83,8 @@ export function getJsonTree(dataset: DatasetName): any {
|
|||
return _.cloneDeep(dummyTreeBreastCancerPrecisionData);
|
||||
} else if (dataset === DatasetName.BreastCancerRecall) {
|
||||
return _.cloneDeep(dummyTreeBreastCancerRecallData);
|
||||
} else if (dataset === DatasetName.Wine) {
|
||||
return _.cloneDeep(dummyTreeWineData);
|
||||
}
|
||||
return _.cloneDeep(dummyTreeAdultCensusIncomeData);
|
||||
}
|
||||
|
@ -100,6 +104,8 @@ export function generateJsonTree(
|
|||
resolve(_.cloneDeep(dummyTreeBreastCancerPrecisionData));
|
||||
} else if (dataset === DatasetName.BreastCancerRecall) {
|
||||
resolve(_.cloneDeep(dummyTreeBreastCancerRecallData));
|
||||
} else if (dataset === DatasetName.Wine) {
|
||||
resolve(_.cloneDeep(dummyTreeWineData));
|
||||
} else {
|
||||
resolve(_.cloneDeep(dummyTreeBostonData));
|
||||
}
|
||||
|
@ -148,6 +154,13 @@ export function generateJsonTreeBoston(
|
|||
return generateJsonTree(_data, signal, DatasetName.Boston);
|
||||
}
|
||||
|
||||
/**
 * Wine-dataset variant of `generateJsonTree`: forwards its arguments with the
 * dataset fixed to `DatasetName.Wine`.
 *
 * @param _data Forwarded unchanged to `generateJsonTree`.
 * @param signal Abort signal forwarded unchanged to `generateJsonTree`.
 * @returns The promise returned by `generateJsonTree`.
 */
export function generateJsonTreeWine(
  _data: any[],
  signal: AbortSignal
): Promise<any> {
  const wineDataset = DatasetName.Wine;
  return generateJsonTree(_data, signal, wineDataset);
}
|
||||
|
||||
export function generateJsonMatrix(dataset: DatasetName) {
|
||||
return (data: any[], signal: AbortSignal): Promise<any> => {
|
||||
const promise = new Promise((resolve, reject) => {
|
||||
|
@ -200,7 +213,8 @@ export function createJsonImportancesGenerator(
|
|||
dataset === DatasetName.AdultCensusIncome ||
|
||||
dataset === DatasetName.Boston ||
|
||||
dataset === DatasetName.BreastCancerPrecision ||
|
||||
dataset === DatasetName.BreastCancerRecall
|
||||
dataset === DatasetName.BreastCancerRecall ||
|
||||
dataset === DatasetName.Wine
|
||||
);
|
||||
resolve(featureNames.map(() => Math.random()));
|
||||
}, 300);
|
||||
|
|
|
@ -19,6 +19,7 @@ import {
|
|||
DatasetName,
|
||||
generateJsonTreeBoston,
|
||||
generateJsonTreeAdultCensusIncome,
|
||||
generateJsonTreeWine,
|
||||
getJsonMatrix,
|
||||
getJsonTreeAdultCensusIncome
|
||||
} from "../error-analysis/utils";
|
||||
|
@ -69,7 +70,7 @@ export class App extends React.Component<IAppProps> {
|
|||
this.props.dataset.feature_names,
|
||||
DatasetName.Boston
|
||||
);
|
||||
} else {
|
||||
} else if (this.props.classDimension === 2) {
|
||||
// Adult
|
||||
modelAssessmentDashboardProps.requestDebugML =
|
||||
generateJsonTreeAdultCensusIncome;
|
||||
|
@ -78,6 +79,14 @@ export class App extends React.Component<IAppProps> {
|
|||
this.props.dataset.feature_names,
|
||||
DatasetName.AdultCensusIncome
|
||||
);
|
||||
} else {
|
||||
// Wine
|
||||
modelAssessmentDashboardProps.requestDebugML = generateJsonTreeWine;
|
||||
modelAssessmentDashboardProps.requestImportances =
|
||||
createJsonImportancesGenerator(
|
||||
this.props.dataset.feature_names,
|
||||
DatasetName.Wine
|
||||
);
|
||||
}
|
||||
} else {
|
||||
const staticTree = getJsonTreeAdultCensusIncome(
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Загрузка…
Ссылка в новой задаче