add genai metrics endpoint in UI for model overview metrics (#2517) (#2520)

Ilya Matiach 2024-01-30 14:00:41 -05:00 committed by GitHub
Parent 7aa72fbac0
Commit 84428aa63f
No known key found for this signature
GPG key ID: B5690EEEBB952194
15 changed files: 367 additions and 28 deletions
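
The change threads a new requestGenerativeTextMetrics callback from the dashboard props, through the model assessment context, into the Model Overview tab, and adds a GenerativeText model type plus per-cohort caching of the returned scores. A minimal sketch of how a host application might implement the callback, assuming a plain POST to the new route; this raw fetch wiring is an illustration only, since the dashboard itself goes through connectToFlaskServiceWithBackupCall (first diff below):

const requestGenerativeTextMetrics = async (
  selectionIndexes: number[][],
  generativeTextCache: Map<string, Map<string, number>>,
  abortSignal: AbortSignal
): Promise<any[]> => {
  // Payload shape is an assumption for illustration; Maps are not
  // JSON-serializable, so the cache is flattened to entry arrays.
  const response = await fetch("/get_generative_text_metrics", {
    body: JSON.stringify([
      selectionIndexes,
      [...generativeTextCache.entries()].map(([k, v]) => [k, [...v.entries()]])
    ]),
    headers: { "Content-Type": "application/json" },
    method: "POST",
    signal: abortSignal
  });
  return response.json();
};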

View File

@@ -71,6 +71,20 @@ export class ModelAssessment extends React.Component<IModelAssessmentProps> {
abortSignal
);
};
callBack.requestGenerativeTextMetrics = async (
selectionIndexes: number[][],
generativeTextCache: Map<string, Map<string, number>>,
abortSignal: AbortSignal
): Promise<any[]> => {
const parameters = [selectionIndexes, generativeTextCache];
return connectToFlaskServiceWithBackupCall(
this.props.config,
parameters,
"handle_generative_text_json",
"/get_generative_text_metrics",
abortSignal
);
};
callBack.requestMatrix = async (
data: any[]
): Promise<IErrorAnalysisMatrix> => {

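The call site above passes both a Python handler name ("handle_generative_text_json") and an HTTP route ("/get_generative_text_metrics"), which suggests a primary transport with an HTTP fallback. A hedged sketch of that backup-call pattern, inferred from the call site rather than from the helper's source:

async function callWithBackup<T>(
  primary: () => Promise<T>,
  backup: () => Promise<T>
): Promise<T> {
  try {
    // Try the primary transport first (e.g. a registered service handler).
    return await primary();
  } catch {
    // Fall back to the HTTP route if the primary transport is unavailable.
    return await backup();
  }
}
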
View File

@@ -16,6 +16,7 @@ export interface IModelAssessmentProps {
export type CallbackType = Pick<
IModelAssessmentDashboardProps,
| "requestExp"
| "requestGenerativeTextMetrics"
| "requestObjectDetectionMetrics"
| "requestPredictions"
| "requestQuestionAnsweringMetrics"

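CallbackType narrows the dashboard props down to just the request callbacks so they can be forwarded as one group; a toy illustration of the Pick pattern with an invented two-callback interface:

interface IExampleProps {
  requestExp?: (index: number) => Promise<any[]>;
  requestGenerativeTextMetrics?: (indexes: number[][]) => Promise<any[]>;
  theme?: string; // not a callback, deliberately excluded below
}
type ExampleCallbacks = Pick<
  IExampleProps,
  "requestExp" | "requestGenerativeTextMetrics"
>;
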
View File

@@ -56,6 +56,7 @@ export * from "./lib/util/getFilterBoundsArgs";
export * from "./lib/util/calculateBoxData";
export * from "./lib/util/calculateConfusionMatrixData";
export * from "./lib/util/calculateLineData";
export * from "./lib/util/GenerativeTextStatisticsUtils";
export * from "./lib/util/MultilabelStatisticsUtils";
export * from "./lib/util/ObjectDetectionStatisticsUtils";
export * from "./lib/util/QuestionAnsweringStatisticsUtils";

View File

@@ -140,6 +140,13 @@ export interface IModelAssessmentContext {
requestExp?:
| ((index: number | number[], abortSignal: AbortSignal) => Promise<any[]>)
| undefined;
requestGenerativeTextMetrics?:
| ((
selectionIndexes: number[][],
generativeTextCache: Map<string, Map<string, number>>,
abortSignal: AbortSignal
) => Promise<any[]>)
| undefined;
requestObjectDetectionMetrics?:
| ((
selectionIndexes: number[][],

View File

@@ -8,6 +8,7 @@ import { JointDataset } from "../util/JointDataset";
export enum ModelTypes {
Regression = "regression",
Binary = "binary",
GenerativeText = "generativetext",
Multiclass = "multiclass",
ImageBinary = "imagebinary",
ImageMulticlass = "imagemulticlass",

View File

@@ -0,0 +1,88 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
import { localization } from "@responsible-ai/localization";
import {
ILabeledStatistic,
TotalCohortSamples
} from "../Interfaces/IStatistic";
import { QuestionAnsweringMetrics } from "./QuestionAnsweringStatisticsUtils";
export enum GenerativeTextMetrics {
Coherence = "coherence",
Fluency = "fluency",
Equivalence = "equivalence",
Groundedness = "groundedness",
Relevance = "relevance"
}
export const generateGenerativeTextStats: (
selectionIndexes: number[][],
generativeTextCache: Map<string, Map<string, number>>
) => ILabeledStatistic[][] = (
selectionIndexes: number[][],
generativeTextCache: Map<string, Map<string, number>>
): ILabeledStatistic[][] => {
return selectionIndexes.map((selectionArray) => {
const count = selectionArray.length;
const value = generativeTextCache.get(selectionArray.toString());
const stat: Map<string, number> = value ? value : new Map<string, number>();
const stats = [
{
key: TotalCohortSamples,
label: localization.Interpret.Statistics.samples,
stat: count
}
];
for (const [key, value] of stat.entries()) {
let label = "";
switch (key) {
case GenerativeTextMetrics.Coherence:
label = localization.Interpret.Statistics.coherence;
break;
case GenerativeTextMetrics.Fluency:
label = localization.Interpret.Statistics.fluency;
break;
case GenerativeTextMetrics.Equivalence:
label = localization.Interpret.Statistics.equivalence;
break;
case GenerativeTextMetrics.Groundedness:
label = localization.Interpret.Statistics.groundedness;
break;
case GenerativeTextMetrics.Relevance:
label = localization.Interpret.Statistics.relevance;
break;
case QuestionAnsweringMetrics.ExactMatchRatio:
label = localization.Interpret.Statistics.exactMatchRatio;
break;
case QuestionAnsweringMetrics.F1Score:
label = localization.Interpret.Statistics.f1Score;
break;
case QuestionAnsweringMetrics.MeteorScore:
label = localization.Interpret.Statistics.meteorScore;
break;
case QuestionAnsweringMetrics.BleuScore:
label = localization.Interpret.Statistics.bleuScore;
break;
case QuestionAnsweringMetrics.BertScore:
label = localization.Interpret.Statistics.bertScore;
break;
case QuestionAnsweringMetrics.RougeScore:
label = localization.Interpret.Statistics.rougeScore;
break;
default:
break;
}
stats.push({
key,
label,
stat: value
});
}
return stats;
});
};

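A hypothetical use of generateGenerativeTextStats: the cache is keyed by the stringified index array of each cohort, and each entry maps metric names to precomputed scores (the values below are invented for illustration):

const cache = new Map<string, Map<string, number>>([
  ["0,1,2", new Map([["coherence", 4.2], ["fluency", 4.8]])]
]);
const stats = generateGenerativeTextStats([[0, 1, 2]], cache);
// stats[0] holds the sample count (3) followed by one labeled
// statistic per cached metric: coherence 4.2 and fluency 4.8.
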
View File

@@ -10,6 +10,7 @@ import {
} from "../Interfaces/IStatistic";
import { IsBinary } from "../util/ExplanationUtils";
import { generateGenerativeTextStats } from "./GenerativeTextStatisticsUtils";
import { JointDataset } from "./JointDataset";
import { ClassificationEnum } from "./JointDatasetUtils";
import { generateMulticlassStats } from "./MulticlassStatisticsUtils";
@@ -156,7 +157,8 @@ export const generateMetrics: (
modelType: ModelTypes,
objectDetectionCache?: Map<string, [number, number, number]>,
objectDetectionInputs?: [string, string, number],
questionAnsweringCache?: QuestionAnsweringCacheType
questionAnsweringCache?: QuestionAnsweringCacheType,
generativeTextCache?: Map<string, Map<string, number>>
): ILabeledStatistic[][] => {
if (
modelType === ModelTypes.ImageMultilabel ||
@@ -192,6 +194,9 @@
objectDetectionInputs
);
}
if (modelType === ModelTypes.GenerativeText && generativeTextCache) {
return generateGenerativeTextStats(selectionIndexes, generativeTextCache);
}
const outcomes = jointDataset.unwrap(JointDataset.ClassificationError);
if (IsBinary(modelType)) {
return selectionIndexes.map((selectionArray) => {

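Both the statistics utility and the Model Overview component below key their caches by the cohort's stringified index array; a small illustration of that convention (values invented):

const cohort = [3, 1, 4];
const key = cohort.toString(); // "3,1,4" -- comma-joined, order-sensitive
const cache = new Map<string, Map<string, number>>();
cache.set(key, new Map([["fluency", 4.5]]));
cache.get([3, 1, 4].toString()); // hit
cache.get([1, 3, 4].toString()); // miss: different order, different key
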
View File

@@ -58,9 +58,12 @@ export class MetricSelector extends React.Component<IMetricSelectorProps> {
options.push(this.addDropdownOption(Metrics.AccuracyScore));
} else if (
IsMultilabel(modelType) ||
modelType === ModelTypes.ObjectDetection
modelType === ModelTypes.ObjectDetection ||
modelType === ModelTypes.QuestionAnswering
) {
options.push(this.addDropdownOption(Metrics.ErrorRate));
} else if (modelType === ModelTypes.GenerativeText) {
options.push(this.addDropdownOption(Metrics.MeanSquaredError));
}
return (
<Dropdown

View File

@@ -1221,12 +1221,16 @@
"_rSquared.comment": "the coefficient of determination, see https://en.wikipedia.org/wiki/Coefficient_of_determination",
"_recall.comment": "computed recall of model, see https://en.wikipedia.org/wiki/Evaluation_of_binary_classifiers",
"accuracy": "Accuracy: {0}",
"coherence": "Coherence: {0}",
"bleuScore": "Bleu score: {0}",
"bertScore": "Bert score: {0}",
"exactMatchRatio": "Exact match ratio: {0}",
"equivalence": "Equivalence: {0}",
"rougeScore": "Rouge Score: {0}",
"fluency": "Fluency: {0}",
"fnr": "False negative rate: {0}",
"fpr": "False positive rate: {0}",
"groundedness": "Groundedness: {0}",
"hammingScore": "Hamming score: {0}",
"meanPrediction": "Mean prediction {0}",
"meteorScore": "Meteor Score: {0}",
@@ -1234,6 +1238,7 @@
"precision": "Precision: {0}",
"rSquared": "R²: {0}",
"recall": "Recall: {0}",
"relevance": "Relevance: {0}",
"selectionRate": "Selection rate: {0}",
"mae": "Mean absolute error: {0}",
"f1Score": "F1 score: {0}",
@@ -1766,10 +1771,26 @@
"name": "Accuracy score",
"description": "The fraction of data points classified correctly."
},
"coherence": {
"name": "Coherence",
"description": "Coherence of an answer is measured by how well all the sentences fit together and sound naturally as a whole."
},
"fluency": {
"name": "Fluency",
"description": "Fluency measures the quality of individual sentences in the answer, and whether they are well-written and grammatically correct."
},
"equivalence": {
"name": "Equivalence",
"description": "Equivalence, as a metric, measures the similarity between the predicted answer and the correct answer."
},
"exactMatchRatio": {
"name": "Exact match ratio",
"description": "The ratio of instances classified correctly for every label."
},
"groundedness": {
"name": "Groundedness",
"description": "Groundedness measures whether the answer follows logically from the information in the context."
},
"meteorScore": {
"name": "Meteor Score",
"description": "METEOR Score is calculated based on the harmonic mean of precision and recall, with recall weighted more than precision in question answering task."
@@ -1782,6 +1803,10 @@
"name": "Bert Score",
"description": "BERTScore focuses on computing semantic similarity between tokens of reference and machine generated text in question answering task."
},
"relevance": {
"name": "Relevance",
"description": "Relevance measures how well the answer addresses the main aspects of the question, based on the context"
},
"rougeScore": {
"name": "Rouge Score",
"description": "Rouge Score measures the ratio of words (and/or n-grams) in the reference text that appeared in the machine generated text in question answering task."

View File

@@ -34,6 +34,7 @@ import {
TelemetryEventName,
DatasetTaskType,
QuestionAnsweringMetrics,
GenerativeTextMetrics,
TotalCohortSamples
} from "@responsible-ai/core-ui";
import { localization } from "@responsible-ai/localization";
@@ -53,20 +54,6 @@ import { getSelectableMetrics } from "./StatsTableUtils";
interface IModelOverviewProps {
telemetryHook?: (message: ITelemetryEvent) => void;
requestObjectDetectionMetrics?: (
selectionIndexes: number[][],
aggregateMethod: string,
className: string,
iouThreshold: number,
objectDetectionCache: Map<string, [number, number, number]>
) => Promise<any[]>;
requestQuestionAnsweringMetrics?: (
selectionIndexes: number[][],
questionAnsweringCache: Map<
string,
[number, number, number, number, number, number]
>
) => Promise<any[]>;
}
interface IModelOverviewState {
@@ -88,6 +75,7 @@ interface IModelOverviewState {
featureBasedCohortLabeledStatistics: ILabeledStatistic[][];
featureBasedCohorts: ErrorCohort[];
iouThreshold: number;
generativeTextAbortController: AbortController | undefined;
objectDetectionAbortController: AbortController | undefined;
questionAnsweringAbortController: AbortController | undefined;
}
@@ -100,6 +88,7 @@ export class ModelOverview extends React.Component<
IModelOverviewState
> {
public static contextType = ModelAssessmentContext;
public generativeTextCache: Map<string, Map<string, number>> = new Map();
public questionAnsweringCache: Map<
string,
[number, number, number, number, number, number]
@@ -125,6 +114,7 @@ export class ModelOverview extends React.Component<
featureBasedCohortLabeledStatistics: [],
featureBasedCohorts: [],
featureConfigurationIsVisible: false,
generativeTextAbortController: undefined,
iouThreshold: 70,
metricConfigurationIsVisible: false,
objectDetectionAbortController: undefined,
@@ -184,6 +174,14 @@ export class ModelOverview extends React.Component<
QuestionAnsweringMetrics.F1Score,
QuestionAnsweringMetrics.BertScore
];
} else if (
this.context.dataset.task_type === DatasetTaskType.GenerativeText
) {
defaultSelectedMetrics = [
GenerativeTextMetrics.Fluency,
GenerativeTextMetrics.Coherence,
GenerativeTextMetrics.Relevance
];
} else {
// task_type === "regression"
defaultSelectedMetrics = [
@@ -633,6 +631,10 @@ export class ModelOverview extends React.Component<
this.context.modelMetadata.modelType === ModelTypes.QuestionAnswering
) {
this.updateQuestionAnsweringMetrics(selectionIndexes, true);
} else if (
this.context.modelMetadata.modelType === ModelTypes.GenerativeText
) {
this.updateGenerativeTextMetrics(selectionIndexes, true);
}
};
@@ -838,6 +840,108 @@
}
}
private updateGenerativeTextMetrics(
selectionIndexes: number[][],
isDatasetCohort: boolean
): void {
if (this.state.generativeTextAbortController !== undefined) {
this.state.generativeTextAbortController.abort();
}
const newAbortController = new AbortController();
this.setState({ generativeTextAbortController: newAbortController });
if (
this.context.requestGenerativeTextMetrics &&
selectionIndexes.length > 0
) {
this.context
.requestGenerativeTextMetrics(
selectionIndexes,
this.generativeTextCache,
newAbortController.signal
)
.then((result) => {
// Assumption: the lengths of `result` and `selectionIndexes` are the same.
const updatedMetricStats: ILabeledStatistic[][] = [];
for (const [cohortIndex, metrics] of result.entries()) {
const count = selectionIndexes[cohortIndex].length;
const metricsMap = new Map<string, number>(Object.entries(metrics));
if (
!this.generativeTextCache.has(
selectionIndexes[cohortIndex].toString()
)
) {
this.generativeTextCache.set(
selectionIndexes[cohortIndex].toString(),
metricsMap
);
}
const updatedCohortMetricStats = [
{
key: TotalCohortSamples,
label: localization.Interpret.Statistics.samples,
stat: count
}
];
for (const [key, value] of metricsMap.entries()) {
let label = "";
switch (key) {
case GenerativeTextMetrics.Coherence:
label = localization.Interpret.Statistics.coherence;
break;
case GenerativeTextMetrics.Fluency:
label = localization.Interpret.Statistics.fluency;
break;
case GenerativeTextMetrics.Equivalence:
label = localization.Interpret.Statistics.equivalence;
break;
case GenerativeTextMetrics.Groundedness:
label = localization.Interpret.Statistics.groundedness;
break;
case GenerativeTextMetrics.Relevance:
label = localization.Interpret.Statistics.relevance;
break;
case QuestionAnsweringMetrics.ExactMatchRatio:
label = localization.Interpret.Statistics.exactMatchRatio;
break;
case QuestionAnsweringMetrics.F1Score:
label = localization.Interpret.Statistics.f1Score;
break;
case QuestionAnsweringMetrics.MeteorScore:
label = localization.Interpret.Statistics.meteorScore;
break;
case QuestionAnsweringMetrics.BleuScore:
label = localization.Interpret.Statistics.bleuScore;
break;
case QuestionAnsweringMetrics.BertScore:
label = localization.Interpret.Statistics.bertScore;
break;
case QuestionAnsweringMetrics.RougeScore:
label = localization.Interpret.Statistics.rougeScore;
break;
default:
break;
}
updatedCohortMetricStats.push({
key,
label,
stat: value
});
}
updatedMetricStats.push(updatedCohortMetricStats);
}
isDatasetCohort
? this.updateDatasetCohortState(updatedMetricStats)
: this.updateFeatureCohortState(updatedMetricStats);
});
}
}
private updateDatasetCohortState(
cohortMetricStats: ILabeledStatistic[][]
): void {
@@ -884,6 +988,10 @@ export class ModelOverview extends React.Component<
this.context.modelMetadata.modelType === ModelTypes.QuestionAnswering
) {
this.updateQuestionAnsweringMetrics(selectionIndexes, false);
} else if (
this.context.modelMetadata.modelType === ModelTypes.GenerativeText
) {
this.updateGenerativeTextMetrics(selectionIndexes, false);
}
};
@@ -998,6 +1106,8 @@ export class ModelOverview extends React.Component<
abortController = this.state.objectDetectionAbortController;
} else if (taskType === DatasetTaskType.QuestionAnswering) {
abortController = this.state.questionAnsweringAbortController;
} else if (taskType === DatasetTaskType.GenerativeText) {
abortController = this.state.generativeTextAbortController;
}
if (abortController !== undefined) {
abortController.abort();

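From updateGenerativeTextMetrics above, the service is expected to return one metric-name-to-score record per requested cohort, in selectionIndexes order; a sketch of that assumed shape with invented scores:

const result: Array<Record<string, number>> = [
  { coherence: 4.1, fluency: 4.6, relevance: 3.9 }, // first cohort
  { coherence: 3.7, fluency: 4.2, relevance: 4.0 } // second cohort
];
// Each record is converted with new Map(Object.entries(metrics)) and
// cached under selectionIndexes[cohortIndex].toString().
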
View File

@@ -6,6 +6,7 @@ import {
BinaryClassificationMetrics,
DatasetTaskType,
ErrorCohort,
GenerativeTextMetrics,
HighchartsNull,
ILabeledStatistic,
ModelTypes,
@@ -154,7 +155,8 @@ export function generateCohortsStatsTable(
const colorConfig =
useTexturedBackgroundForNaN &&
modelType !== ModelTypes.ObjectDetection &&
modelType !== ModelTypes.QuestionAnswering
modelType !== ModelTypes.QuestionAnswering &&
modelType !== ModelTypes.GenerativeText
? {
color: {
pattern: {
@@ -458,6 +460,90 @@ export function getSelectableMetrics(
text: localization.ModelAssessment.ModelOverview.metrics.bleuScore.name
},
{
description:
localization.ModelAssessment.ModelOverview.metrics.bertScore
.description,
key: QuestionAnsweringMetrics.BertScore,
text: localization.ModelAssessment.ModelOverview.metrics.bertScore.name
},
{
description:
localization.ModelAssessment.ModelOverview.metrics.rougeScore
.description,
key: QuestionAnsweringMetrics.RougeScore,
text: localization.ModelAssessment.ModelOverview.metrics.rougeScore.name
}
);
} else if (taskType === DatasetTaskType.GenerativeText) {
selectableMetrics.push(
{
description:
localization.ModelAssessment.ModelOverview.metrics.coherence
.description,
key: GenerativeTextMetrics.Coherence,
text: localization.ModelAssessment.ModelOverview.metrics.coherence.name
},
{
description:
localization.ModelAssessment.ModelOverview.metrics.fluency
.description,
key: GenerativeTextMetrics.Fluency,
text: localization.ModelAssessment.ModelOverview.metrics.fluency.name
},
{
description:
localization.ModelAssessment.ModelOverview.metrics.equivalence
.description,
key: GenerativeTextMetrics.Equivalence,
text: localization.ModelAssessment.ModelOverview.metrics.equivalence
.name
},
{
description:
localization.ModelAssessment.ModelOverview.metrics.groundedness
.description,
key: GenerativeTextMetrics.Groundedness,
text: localization.ModelAssessment.ModelOverview.metrics.groundedness
.name
},
{
description:
localization.ModelAssessment.ModelOverview.metrics.relevance
.description,
key: GenerativeTextMetrics.Relevance,
text: localization.ModelAssessment.ModelOverview.metrics.relevance.name
},
{
description:
localization.ModelAssessment.ModelOverview.metrics.exactMatchRatio
.description,
key: QuestionAnsweringMetrics.ExactMatchRatio,
text: localization.ModelAssessment.ModelOverview.metrics.exactMatchRatio
.name
},
{
description:
localization.ModelAssessment.ModelOverview.metrics.meteorScore
.description,
key: QuestionAnsweringMetrics.MeteorScore,
text: localization.ModelAssessment.ModelOverview.metrics.meteorScore
.name
},
{
description:
localization.ModelAssessment.ModelOverview.metrics.f1Score
.description,
key: QuestionAnsweringMetrics.F1Score,
text: localization.ModelAssessment.ModelOverview.metrics.f1Score.name
},
{
description:
localization.ModelAssessment.ModelOverview.metrics.bleuScore
.description,
key: QuestionAnsweringMetrics.BleuScore,
text: localization.ModelAssessment.ModelOverview.metrics.bleuScore.name
},
{
description:
localization.ModelAssessment.ModelOverview.metrics.bertScore

View File

@@ -57,14 +57,6 @@ export interface ITabsViewProps {
request: any[],
abortSignal: AbortSignal
) => Promise<any[]>;
requestQuestionAnsweringMetrics?: (
selectionIndexes: number[][],
questionAnsweringCache: Map<
string,
[number, number, number, number, number, number]
>,
abortSignal: AbortSignal
) => Promise<any[]>;
requestDebugML?: (
request: any[],
abortSignal: AbortSignal

View File

@@ -89,6 +89,7 @@ export class ModelAssessmentDashboard extends CohortBasedComponent<
this.props.requestDatasetAnalysisBoxChart,
requestExp: this.props.requestExp,
requestForecast: this.props.requestForecast,
requestGenerativeTextMetrics: this.props.requestGenerativeTextMetrics,
requestGlobalCausalEffects: this.props.requestGlobalCausalEffects,
requestGlobalCausalPolicy: this.props.requestGlobalCausalPolicy,
requestGlobalExplanations: this.props.requestGlobalExplanations,
@@ -143,9 +144,6 @@ export class ModelAssessmentDashboard extends CohortBasedComponent<
this.props.requestObjectDetectionMetrics
}
requestPredictions={this.props.requestPredictions}
requestQuestionAnsweringMetrics={
this.props.requestQuestionAnsweringMetrics
}
requestDebugML={this.props.requestDebugML}
requestImportances={this.props.requestImportances}
requestMatrix={this.props.requestMatrix}

View File

@@ -115,6 +115,11 @@ export interface IModelAssessmentDashboardProps
index: number | number[],
abortSignal: AbortSignal
) => Promise<any[]>;
requestGenerativeTextMetrics?: (
selectionIndexes: number[][],
generativeTextCache: Map<string, Map<string, number>>,
abortSignal: AbortSignal
): Promise<any[]>;
requestObjectDetectionMetrics?: (
selectionIndexes: number[][],
aggregateMethod: string,

View File

@@ -53,5 +53,8 @@ export function getModelTypeFromProps(
if (taskType === DatasetTaskType.QuestionAnswering) {
return ModelTypes.QuestionAnswering;
}
if (taskType === DatasetTaskType.GenerativeText) {
return ModelTypes.GenerativeText;
}
return modelType;
}