Individual feature importance interpret QA (#2186)

* add

Signed-off-by: vinutha karanth <vinutha.karanth@gmail.com>

* update

Signed-off-by: vinutha karanth <vinutha.karanth@gmail.com>

* cleanup

Signed-off-by: vinutha karanth <vinutha.karanth@gmail.com>

* lintfix

Signed-off-by: vinutha karanth <vinutha.karanth@gmail.com>

* lintfix

Signed-off-by: vinutha karanth <vinutha.karanth@gmail.com>

* lintfix

Signed-off-by: vinutha karanth <vinutha.karanth@gmail.com>

* fix row change err

Signed-off-by: vinutha karanth <vinutha.karanth@gmail.com>

* address comments

Signed-off-by: vinutha karanth <vinutha.karanth@gmail.com>

---------

Signed-off-by: vinutha karanth <vinutha.karanth@gmail.com>
This commit is contained in:
Vinutha Karanth 2023-07-24 10:32:44 -07:00 коммит произвёл GitHub
Родитель 37c340c8d6
Коммит 29bfb51675
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
20 изменённых файлов: 598 добавлений и 240 удалений

Просмотреть файл

@ -57,6 +57,7 @@ export interface IPrecomputedExplanations {
export interface ITextFeatureImportance {
text: string[];
localExplanations: number[][];
baseValues?: number[][];
}
export interface IEBMGlobalExplanation {

Просмотреть файл

@ -6,4 +6,7 @@ export interface ITextExplanationDashboardData {
localExplanations: number[][];
prediction: number[];
text: string[];
baseValues?: number[][];
predictedY?: number[] | number[][] | string[] | string | number;
trueY?: number[] | number[][] | string[] | string | number;
}

Просмотреть файл

@ -78,6 +78,14 @@ export class Utils {
return sortedList;
}
public static addItem(value: number, radio: string | undefined): boolean {
return (
radio === RadioKeys.All ||
(radio === RadioKeys.Neg && value <= 0) ||
(radio === RadioKeys.Pos && value >= 0)
);
}
public static takeTopK(list: number[], k: number): number[] {
/*
* Returns a list after splicing and taking the top K

Просмотреть файл

@ -2,17 +2,27 @@
// Licensed under the MIT License.
import { ITheme } from "@fluentui/react";
import {
IHighchartsConfig,
getPrimaryChartColor,
getPrimaryBackgroundChartColor
} from "@responsible-ai/core-ui";
import { IHighchartsConfig } from "@responsible-ai/core-ui";
import { localization } from "@responsible-ai/localization";
import { SeriesOptionsType } from "highcharts";
import _ from "lodash";
import { Utils } from "../../CommonUtils";
import { IChartProps } from "../../Interfaces/IChartProps";
function findNearestIndex(
array: number[],
target?: number
): number | undefined {
if (!target) {
return array.length;
}
const nearestElement = _.minBy(array, (element) =>
Math.abs(element - target)
);
return _.indexOf(array, nearestElement);
}
export function getTokenImportancesChartOptions(
props: IChartProps,
theme: ITheme
@ -20,6 +30,11 @@ export function getTokenImportancesChartOptions(
const importances = props.localExplanations;
const k = props.topK;
const sortedList = Utils.sortedTopK(importances, k, props.radio);
const outputFeatureImportanceLabel = `f ${
props.text[props.selectedTokenIndex || 0]
} (inputs)`;
const baseValueLabel = "base value";
const [x, y, ylabel, tooltip]: [number[], number[], string[], string[]] = [
[],
[],
@ -46,6 +61,36 @@ export function getTokenImportancesChartOptions(
ylabel.push(props.text[idx]);
tooltip.push(str);
});
// add output feature importance
if (props.outputFeatureValue && props.baseValue) {
const outputFeatureValueIndex = findNearestIndex(
x,
props.outputFeatureValue
);
const baseValueFeatureValueIndex = findNearestIndex(x, props.baseValue);
if (outputFeatureValueIndex && baseValueFeatureValueIndex) {
if (Utils.addItem(props.outputFeatureValue, props.radio)) {
addItem(
x,
props.outputFeatureValue,
ylabel,
outputFeatureImportanceLabel,
outputFeatureValueIndex
);
}
if (Utils.addItem(props.baseValue, props.radio)) {
addItem(
x,
props.baseValue,
ylabel,
baseValueLabel,
baseValueFeatureValueIndex
);
}
}
}
// Put most significant word at the top by reversing order
tooltip.reverse();
ylabel.reverse();
@ -54,11 +99,10 @@ export function getTokenImportancesChartOptions(
const data: any[] = [];
x.forEach((p, index) => {
const temp = {
borderColor: getPrimaryChartColor(theme),
color:
(p || 0) >= 0
? getPrimaryChartColor(theme)
: getPrimaryBackgroundChartColor(theme),
? theme.semanticColors.errorText
: theme.semanticColors.link,
x: index,
y: p
};
@ -68,6 +112,15 @@ export function getTokenImportancesChartOptions(
const series: SeriesOptionsType[] = [
{
data,
dataLabels: {
align: "center",
color: theme.semanticColors.bodyBackground,
enabled: true,
formatter(): string | number | undefined {
return this.x; // Display the Y-axis value inside the bar
},
inside: true
},
name: "",
showInLegend: false,
type: "bar"
@ -80,11 +133,12 @@ export function getTokenImportancesChartOptions(
},
plotOptions: {
bar: {
minPointLength: 10,
tooltip: {
pointFormatter(): string {
return `${tooltip[this.x || 0]}: ${this.y || 0}`;
}
}
} // Set the minimum pixel width for bars
}
},
series,
@ -98,3 +152,14 @@ export function getTokenImportancesChartOptions(
}
};
}
function addItem(
x: any[],
xValue: any,
yLabel: any[],
yLabelValue: any,
index: number
): void {
x.splice(index, 0, xValue);
yLabel.splice(index, 0, yLabelValue);
}

Просмотреть файл

@ -13,12 +13,13 @@ export interface ITextExplanationViewState {
maxK: number;
topK: number;
radio: string;
// qaRadio?: string;
qaRadio?: string;
importances: number[];
singleTokenImportances: number[];
selectedToken: number;
tokenIndexes: number[];
text: string[];
outputFeatureImportances: number[][];
}
export const options: IChoiceGroupOption[] = [

Просмотреть файл

@ -38,6 +38,9 @@ export interface ISidePanelOfChartProps {
selectedWeightVector: WeightVectorOption;
weightOptions: WeightVectorOption[];
weightLabels: any;
baseValue?: number;
outputFeatureValue?: number;
selectedTokenIndex?: number;
changeRadioButton: (
_event?: React.FormEvent,
item?: IChoiceGroupOption
@ -63,6 +66,9 @@ export class SidePanelOfChart extends React.PureComponent<ISidePanelOfChartProps
localExplanations={this.props.importances}
topK={this.props.topK}
radio={this.props.radio}
baseValue={this.props.baseValue}
outputFeatureValue={this.props.outputFeatureValue}
selectedTokenIndex={this.props.selectedTokenIndex}
/>
</Stack.Item>
<Stack.Item grow className={classNames.chartRight}>

Просмотреть файл

@ -11,16 +11,25 @@ import {
export interface ITextExplanationDashboardStyles {
chartRight: IStyle;
textHighlighting: IStyle;
predictedAnswer: IStyle;
boldText: IStyle;
}
export const textExplanationDashboardStyles: () => IProcessedStyleSet<ITextExplanationDashboardStyles> =
() => {
const theme = getTheme();
return mergeStyleSets<ITextExplanationDashboardStyles>({
boldText: {
fontWeight: "bold"
},
chartRight: {
maxWidth: "230px",
minWidth: "230px"
},
predictedAnswer: {
fontWeight: "bold",
paddingBottom: "14px"
},
textHighlighting: {
borderColor: theme.semanticColors.variantBorder,
borderRadius: "1px",

Просмотреть файл

@ -2,24 +2,29 @@
// Licensed under the MIT License.
import { IChoiceGroupOption, Stack, Text } from "@fluentui/react";
import { WeightVectorOption, WeightVectors } from "@responsible-ai/core-ui";
import { WeightVectorOption } from "@responsible-ai/core-ui";
import { localization } from "@responsible-ai/localization";
import React from "react";
import { RadioKeys, Utils } from "../../CommonUtils";
import { QAExplanationType, RadioKeys } from "../../CommonUtils";
import { ITextExplanationViewProps } from "../../Interfaces/IExplanationViewProps";
import { TextFeatureLegend } from "../TextFeatureLegend/TextFeatureLegend";
import { TextHighlighting } from "../TextHighlighting/TextHightlighting";
import {
ITextExplanationViewState,
MaxImportantWords,
componentStackTokens
} from "./ITextExplanationViewSpec";
import { SidePanelOfChart } from "./SidePanelOfChart";
import { textExplanationDashboardStyles } from "./TextExplanationView.styles";
import {
calculateMaxKImportances,
calculateTopKImportances,
computeImportancesForAllTokens,
computeImportancesForWeightVector,
getOutputFeatureImportances
} from "./TextExplanationViewUtils";
import { TextInputOutputAreaWithLegend } from "./TextInputOutputAreaWithLegend";
import { TrueAndPredictedAnswerView } from "./TrueAndPredictedAnswerView";
export class TextExplanationView extends React.PureComponent<
export class TextExplanationView extends React.Component<
ITextExplanationViewProps,
ITextExplanationViewState
> {
@ -31,23 +36,32 @@ export class TextExplanationView extends React.PureComponent<
const weightVector = this.props.selectedWeightVector;
const importances = this.props.isQA
? this.computeImportancesForAllTokens(
this.props.dataSummary.localExplanations
? computeImportancesForAllTokens(
this.props.dataSummary.localExplanations,
true
)
: this.computeImportancesForWeightVector(
: computeImportancesForWeightVector(
this.props.dataSummary.localExplanations,
weightVector
);
const maxK = this.calculateMaxKImportances(importances);
const topK = this.calculateTopKImportances(importances);
const maxK = calculateMaxKImportances(importances);
const topK = calculateTopKImportances(importances);
this.state = {
importances,
maxK,
// qaRadio: QAExplanationType.Start,
outputFeatureImportances: getOutputFeatureImportances(
this.props.dataSummary.localExplanations,
this.props.dataSummary.baseValues
),
qaRadio: QAExplanationType.Start,
radio: RadioKeys.All,
selectedToken: 0, // default to the first token
singleTokenImportances: this.getImportanceForSingleToken(0), // get importance for first token
selectedToken: 0,
// default to the first token
singleTokenImportances: this.props.dataSummary.localExplanations[0].map(
(row) => row[0]
),
// get importance for first token
text: this.props.dataSummary.text,
tokenIndexes: [...this.props.dataSummary.text].map((_, index) => index),
topK
@ -60,28 +74,19 @@ export class TextExplanationView extends React.PureComponent<
this.props.dataSummary.localExplanations !==
prevProps.dataSummary.localExplanations
) {
if (this.props.isQA) {
this.setState(
{
selectedToken: 0,
//update token dropdown
tokenIndexes: [...this.props.dataSummary.text].map(
(_, index) => index
)
},
() => {
this.updateTokenImportances();
this.updateSingleTokenImportances();
}
);
} else {
this.updateImportances(this.props.selectedWeightVector);
}
this.updateState();
}
}
public render(): React.ReactNode {
const classNames = textExplanationDashboardStyles();
const outputLocalExplanations =
this.state.qaRadio === QAExplanationType.Start
? this.state.outputFeatureImportances[0]
: this.state.outputFeatureImportances[1];
const inputLocalExplanations = this.props.isQA
? this.state.singleTokenImportances
: this.state.importances;
const baseValue = this.props.isQA ? this.getBaseValue() : undefined;
return (
<Stack>
@ -93,9 +98,30 @@ export class TextExplanationView extends React.PureComponent<
)}
</Stack>
<Stack tokens={componentStackTokens} horizontal>
{this.props.isQA && (
<TrueAndPredictedAnswerView
predictedY={this.props.dataSummary.predictedY}
trueY={this.props.dataSummary.trueY}
/>
)}
</Stack>
<TextInputOutputAreaWithLegend
topK={this.state.topK}
radio={this.state.radio}
selectedToken={this.state.selectedToken}
text={this.state.text}
outputLocalExplanations={outputLocalExplanations}
inputLocalExplanations={inputLocalExplanations}
isQA={this.props.isQA}
getSelectedWord={this.getSelectedWord}
onSelectedTokenChange={this.onSelectedTokenChange}
/>
<SidePanelOfChart
text={this.state.text}
importances={this.state.importances}
importances={inputLocalExplanations}
topK={this.state.topK}
radio={this.state.radio}
isQA={this.props.isQA}
@ -111,153 +137,93 @@ export class TextExplanationView extends React.PureComponent<
setTopK={this.setTopK}
onWeightVectorChange={this.onWeightVectorChange}
onSelectedTokenChange={this.onSelectedTokenChange}
outputFeatureValue={outputLocalExplanations[this.state.selectedToken]}
baseValue={baseValue}
selectedTokenIndex={this.state.selectedToken}
/>
<Stack tokens={componentStackTokens} horizontal>
<Stack.Item
align="stretch"
grow
disableShrink
className={classNames.textHighlighting}
>
<TextHighlighting
text={this.state.text}
localExplanations={this.state.importances}
topK={this.state.topK}
radio={this.state.radio}
/>
</Stack.Item>
{this.props.isQA && (
<Stack.Item
align="stretch"
grow
disableShrink
className={classNames.textHighlighting}
>
<TextHighlighting
text={this.state.text}
localExplanations={this.state.singleTokenImportances}
topK={
// keep all importances for single token(set topK to length)
this.state.singleTokenImportances.length
}
radio={this.state.radio}
/>
</Stack.Item>
)}
<Stack.Item align="end">
<TextFeatureLegend />
</Stack.Item>
</Stack>
</Stack>
);
}
private updateState(): void {
const importances = this.props.isQA
? this.getTokenImportances()
: this.getImportances(this.props.selectedWeightVector);
const [topK, maxK] = this.getTopKMaxK(importances);
this.setState({
importances,
maxK,
outputFeatureImportances: getOutputFeatureImportances(
this.props.dataSummary.localExplanations,
this.props.dataSummary.baseValues
),
selectedToken: 0,
singleTokenImportances: this.getImportanceForSingleToken(
this.state.selectedToken
),
text: this.props.dataSummary.text,
tokenIndexes: [...this.props.dataSummary.text].map((_, index) => index),
topK
});
}
private onWeightVectorChange = (weightOption: WeightVectorOption): void => {
this.updateImportances(weightOption);
const importances = this.getImportances(weightOption);
const [topK, maxK] = this.getTopKMaxK(importances);
this.setState({ importances, maxK, topK });
this.props.onWeightChange(weightOption);
};
private onSelectedTokenChange = (newIndex: number): void => {
this.setState({ selectedToken: newIndex }, () => {
this.updateSingleTokenImportances();
const singleTokenImportances = this.getImportanceForSingleToken(newIndex);
this.setState({
selectedToken: newIndex,
singleTokenImportances
});
};
private updateImportances(weightOption: WeightVectorOption): void {
const importances = this.computeImportancesForWeightVector(
private getSelectedWord = (): string => {
return this.props.dataSummary.text[this.state.selectedToken];
};
private getTopKMaxK(importances: number[]): [number, number] {
const topK = calculateTopKImportances(importances);
const maxK = calculateMaxKImportances(importances);
return [topK, maxK];
}
private getImportances(weightOption: WeightVectorOption): number[] {
return computeImportancesForWeightVector(
this.props.dataSummary.localExplanations,
weightOption
);
const topK = this.calculateTopKImportances(importances);
const maxK = this.calculateMaxKImportances(importances);
this.setState({
importances,
maxK,
text: this.props.dataSummary.text,
topK
});
}
// for QA
private updateTokenImportances(): void {
const importances = this.computeImportancesForAllTokens(
private getTokenImportances(): number[] {
return computeImportancesForAllTokens(
this.props.dataSummary.localExplanations
);
const topK = this.calculateTopKImportances(importances);
const maxK = this.calculateMaxKImportances(importances);
this.setState({
importances,
maxK,
text: this.props.dataSummary.text,
topK
});
}
private updateSingleTokenImportances(): void {
const singleTokenImportances = this.getImportanceForSingleToken(
this.state.selectedToken
);
this.setState({ singleTokenImportances });
}
private calculateTopKImportances(importances: number[]): number {
return Math.min(
MaxImportantWords,
Math.ceil(Utils.countNonzeros(importances) / 2)
);
}
private calculateMaxKImportances(importances: number[]): number {
return Math.min(
MaxImportantWords,
Math.ceil(Utils.countNonzeros(importances))
);
}
private computeImportancesForWeightVector(
importances: number[][],
weightVector: WeightVectorOption
): number[] {
if (weightVector === WeightVectors.AbsAvg) {
// Sum the multidimensional array to one dimension across rows for each token
const numClasses = importances[0].length;
const sumImportances = importances.map((row) =>
row.reduce((a, b): number => {
return (a + Math.abs(b)) / numClasses;
}, 0)
);
return sumImportances;
}
return importances.map(
(perClassImportances) => perClassImportances[weightVector as number]
);
}
private computeImportancesForAllTokens(importances: number[][]): number[] {
/*
* sum the tokens importance
* TODO: add base values?
*/
const sumImportances = importances[0].map((_, index) =>
importances.reduce((sum, row) => sum + row[index], 0)
);
return sumImportances;
}
private getImportanceForSingleToken(index: number): number[] {
return this.props.dataSummary.localExplanations.map((row) => row[index]);
const expIndex = this.state.qaRadio === QAExplanationType.Start ? 0 : 1;
return this.props.dataSummary.localExplanations[expIndex].map(
(row) => row[index]
);
}
private getBaseValue(): number {
if (this.props.dataSummary.baseValues) {
const expIndex = this.state.qaRadio === QAExplanationType.Start ? 0 : 1;
return this.props.dataSummary.baseValues?.[expIndex][
this.state.selectedToken
];
}
return 0;
}
private setTopK = (newNumber: number): void => {
/*
* Changes the state of K
*/
this.setState({ topK: newNumber });
};
@ -265,23 +231,23 @@ export class TextExplanationView extends React.PureComponent<
_event?: React.FormEvent,
item?: IChoiceGroupOption
): void => {
/*
* Changes the state of the radio button
*/
if (item?.key !== undefined) {
if (item?.key) {
this.setState({ radio: item.key });
}
};
private switchQAPrediction = (): // _event?: React.FormEvent,
// _item?: IChoiceGroupOption
void => {
/*
* switch to the target predictions(starting or ending)
* TODO: add logic for switching explanation data
*/
// if (item?.key !== undefined) {
// this.setState({ qaRadio: item.key });
// }
private switchQAPrediction = (
_event?: React.FormEvent,
item?: IChoiceGroupOption
): void => {
if (item?.key) {
const singleTokenImportances = this.getImportanceForSingleToken(
this.state.selectedToken
);
this.setState({
qaRadio: item.key,
singleTokenImportances
});
}
};
}

Просмотреть файл

@ -0,0 +1,100 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
import { WeightVectorOption, WeightVectors } from "@responsible-ai/core-ui";
import { QAExplanationType, Utils } from "../../CommonUtils";
import { MaxImportantWords } from "./ITextExplanationViewSpec";
export function getOutputFeatureImportances(
localExplanations: number[][],
baseValues?: number[][]
): number[][] {
const startSumOfFeatureImportances = getSumOfFeatureImportances(
localExplanations[0]
);
const endSumOfFeatureImportances = getSumOfFeatureImportances(
localExplanations[1]
);
const startOutputFeatureImportances = getOutputFeatureImportancesIntl(
startSumOfFeatureImportances,
baseValues?.[0]
);
const endOutputFeatureImportances = getOutputFeatureImportancesIntl(
endSumOfFeatureImportances,
baseValues?.[1]
);
return [
startOutputFeatureImportances || [],
endOutputFeatureImportances || []
];
}
export function getSumOfFeatureImportances(importances: number[]): number[] {
return importances.map((_, index) =>
importances.reduce((sum, row) => sum + row[index], 0)
);
}
export function getOutputFeatureImportancesIntl(
sumOfFeatureImportances: number[],
baseValues?: number[]
): number[] | undefined {
return baseValues?.map(
(bValue, index) => sumOfFeatureImportances[index] + bValue
);
}
export function calculateTopKImportances(importances: number[]): number {
return Math.min(
MaxImportantWords,
Math.ceil(Utils.countNonzeros(importances) / 2)
);
}
export function calculateMaxKImportances(importances: number[]): number {
return Math.min(
MaxImportantWords,
Math.ceil(Utils.countNonzeros(importances))
);
}
export function computeImportancesForWeightVector(
importances: number[][],
weightVector: WeightVectorOption
): number[] {
if (weightVector === WeightVectors.AbsAvg) {
// Sum the multidimensional array to one dimension across rows for each token
const numClasses = importances[0].length;
const sumImportances = importances.map((row) =>
row.reduce((a, b): number => {
return (a + Math.abs(b)) / numClasses;
}, 0)
);
return sumImportances;
}
return importances.map(
(perClassImportances) => perClassImportances[weightVector as number]
);
}
export function computeImportancesForAllTokens(
importances: number[][],
isInitialState?: boolean,
qaRadio?: string
): number[] {
const startSumImportances = importances[0].map((_, index) =>
importances.reduce((sum, row) => sum + row[index], 0)
);
const endSumImportances = importances[1].map((_, index) =>
importances.reduce((sum, row) => sum + row[index], 0)
);
if (isInitialState) {
return startSumImportances;
}
return qaRadio === QAExplanationType.Start
? startSumImportances
: endSumImportances;
}

Просмотреть файл

@ -0,0 +1,94 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
import { Stack, Text } from "@fluentui/react";
import { localization } from "@responsible-ai/localization";
import React from "react";
import { TextFeatureLegend } from "../TextFeatureLegend/TextFeatureLegend";
import { TextHighlighting } from "../TextHighlighting/TextHightlighting";
import { componentStackTokens } from "./ITextExplanationViewSpec";
import { textExplanationDashboardStyles } from "./TextExplanationView.styles";
interface ITextInputOutputAreaWithLegendProps {
topK: number;
radio: string;
selectedToken: number;
text: string[];
outputLocalExplanations: number[];
inputLocalExplanations: number[];
isQA?: boolean;
getSelectedWord: () => string;
onSelectedTokenChange: (newIndex: number) => void;
}
export class TextInputOutputAreaWithLegend extends React.Component<ITextInputOutputAreaWithLegendProps> {
public render(): React.ReactNode {
const classNames = textExplanationDashboardStyles();
return (
<Stack tokens={componentStackTokens} horizontal>
{this.props.isQA && (
<Stack.Item grow>
<Stack horizontal={false}>
<Stack.Item>
<Text className={classNames.boldText}>
{localization.InterpretText.View.outputs}
</Text>
</Stack.Item>
<Stack.Item
align="stretch"
grow
disableShrink
className={classNames.textHighlighting}
>
<TextHighlighting
text={this.props.text}
localExplanations={this.props.outputLocalExplanations}
topK={
// keep all importances for single token(set topK to length)
this.props.outputLocalExplanations.length
}
radio={this.props.radio}
isInput={false}
onSelectedTokenChange={this.props.onSelectedTokenChange}
selectedTokenIndex={this.props.selectedToken}
/>
</Stack.Item>
</Stack>
</Stack.Item>
)}
<Stack.Item grow>
<Stack horizontal={false}>
<Stack.Item>
<Text className={classNames.boldText}>
{localization.InterpretText.View.inputs}
</Text>
</Stack.Item>
<Stack.Item
align="stretch"
grow
disableShrink
className={classNames.textHighlighting}
>
<TextHighlighting
text={this.props.text}
localExplanations={this.props.inputLocalExplanations}
topK={this.props.topK}
radio={this.props.radio}
isInput
/>
</Stack.Item>
</Stack>
</Stack.Item>
<Stack.Item grow className={classNames.chartRight}>
<TextFeatureLegend
selectedWord={this.props.getSelectedWord()}
isQA={this.props.isQA}
/>
</Stack.Item>
</Stack>
);
}
}

Просмотреть файл

@ -0,0 +1,46 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
import { Stack, Text } from "@fluentui/react";
import { localization } from "@responsible-ai/localization";
import React from "react";
import { textExplanationDashboardStyles } from "./TextExplanationView.styles";
interface ITrueAndPredictedAnswerViewProps {
predictedY: string | number | number[] | string[] | number[][] | undefined;
trueY: string | number | number[] | string[] | number[][] | undefined;
}
export class TrueAndPredictedAnswerView extends React.Component<ITrueAndPredictedAnswerViewProps> {
public render(): React.ReactNode {
const classNames = textExplanationDashboardStyles();
return (
<Stack horizontal={false}>
<Stack horizontal>
<Stack.Item>
<Text className={classNames.predictedAnswer}>
{localization.InterpretText.View.predictedAnswer} &nbsp;
</Text>
</Stack.Item>
<Stack.Item>
<Text className={classNames.predictedAnswer}>
{this.props.predictedY}
</Text>
</Stack.Item>
</Stack>
<Stack horizontal>
<Stack.Item>
<Text className={classNames.boldText}>
{localization.InterpretText.View.trueAnswer} &nbsp;
</Text>
</Stack.Item>
<Stack.Item>
<Text className={classNames.boldText}>{this.props.trueY}</Text>
</Stack.Item>
</Stack>
</Stack>
);
}
}

Просмотреть файл

@ -7,10 +7,6 @@ import {
IProcessedStyleSet,
getTheme
} from "@fluentui/react";
import {
getPrimaryBackgroundChartColor,
getPrimaryChartColor
} from "@responsible-ai/core-ui";
export interface ITextFeatureLegendStyles {
legend: IStyle;
@ -26,12 +22,12 @@ export const textFeatureLegendStyles: () => IProcessedStyleSet<ITextFeatureLegen
color: theme.semanticColors.disabledText
},
negFeatureImportance: {
color: getPrimaryChartColor(theme),
textDecorationLine: "underline"
backgroundColor: theme.semanticColors.link,
color: theme.semanticColors.bodyBackground
},
posFeatureImportance: {
backgroundColor: getPrimaryChartColor(theme),
color: getPrimaryBackgroundChartColor(theme)
backgroundColor: theme.semanticColors.errorText,
color: theme.semanticColors.bodyBackground
}
});
};

Просмотреть файл

@ -17,7 +17,12 @@ const legendStackTokens: IStackTokens = {
padding: "s"
};
export class TextFeatureLegend extends React.Component {
interface ITextFeatureLegendProps {
selectedWord: string;
isQA?: boolean;
}
export class TextFeatureLegend extends React.Component<ITextFeatureLegendProps> {
public render(): React.ReactNode {
const classNames = textFeatureLegendStyles();
return (
@ -51,6 +56,28 @@ export class TextFeatureLegend extends React.Component {
</Stack.Item>
</Stack>
</Stack.Item>
{this.props.isQA && (
<Stack tokens={componentStackTokens}>
<Stack.Item>
<Text>{localization.InterpretText.Legend.cls}</Text>
</Stack.Item>
<Stack.Item>
<Text>{localization.InterpretText.Legend.sep}</Text>
</Stack.Item>
<Stack.Item>
<Stack horizontal>
<Stack.Item>
<Text>
{localization.InterpretText.Legend.selectedWord} &nbsp;
</Text>
</Stack.Item>
<Stack.Item>
<Text>{this.props.selectedWord}</Text>
</Stack.Item>
</Stack>
</Stack.Item>
</Stack>
)}
</Stack>
);
}

Просмотреть файл

@ -2,14 +2,13 @@
// Licensed under the MIT License.
import {
IStyle,
mergeStyles,
mergeStyleSets,
IProcessedStyleSet,
IStackStyles,
getTheme
IStyle,
getTheme,
mergeStyleSets,
mergeStyles
} from "@fluentui/react";
import { getPrimaryChartColor } from "@responsible-ai/core-ui";
export const textStackStyles: IStackStyles = {
root: {
@ -31,30 +30,41 @@ export interface ITextHighlightingStyles {
boldunderline: IStyle;
}
export const textHighlightingStyles: () => IProcessedStyleSet<ITextHighlightingStyles> =
() => {
const theme = getTheme();
const normal = {
color: theme.semanticColors.bodyText
};
return mergeStyleSets<ITextHighlightingStyles>({
boldunderline: mergeStyles([
normal,
{
color: getPrimaryChartColor(theme),
fontSize: theme.fonts.large.fontSize,
margin: "2px",
padding: 0,
textDecorationLine: "underline"
}
]),
highlighted: mergeStyles([
normal,
{
backgroundColor: getPrimaryChartColor(theme),
color: theme.semanticColors.bodyBackground
}
]),
normal
});
export const textHighlightingStyles: (
isTextSelected: boolean
) => IProcessedStyleSet<ITextHighlightingStyles> = (isTextSelected) => {
const theme = getTheme();
const normal = {
color: theme.semanticColors.bodyText
};
const selectedTextStyle = isTextSelected
? {
textDecorationColor: "black",
textDecorationLine: "underline",
textDecorationStyle: "solid",
textDecorationThickness: "4px"
}
: {};
return mergeStyleSets<ITextHighlightingStyles>({
boldunderline: mergeStyles([
normal,
{
backgroundColor: theme.semanticColors.link,
color: theme.semanticColors.bodyBackground,
fontSize: theme.fonts.large.fontSize,
margin: "2px",
padding: 0
},
selectedTextStyle
]),
highlighted: mergeStyles([
normal,
selectedTextStyle,
{
backgroundColor: theme.semanticColors.errorText,
color: theme.semanticColors.bodyBackground
}
]),
normal: mergeStyles([normal, selectedTextStyle])
});
};

Просмотреть файл

@ -2,7 +2,6 @@
// Licensed under the MIT License.
import {
Label,
Text,
Stack,
IStackTokens,
@ -25,12 +24,11 @@ const textStackTokens: IStackTokens = {
padding: "s2"
};
export class TextHighlighting extends React.PureComponent<IChartProps> {
export class TextHighlighting extends React.Component<IChartProps> {
/*
* Presents the document in an accessible manner with text highlighting
*/
public render(): React.ReactNode {
const classNames = textHighlightingStyles();
const text = this.props.text;
const importances = this.props.localExplanations;
const k = this.props.topK;
@ -47,30 +45,22 @@ export class TextHighlighting extends React.PureComponent<IChartProps> {
styles={textStackStyles}
>
{text.map((word, wordIndex) => {
const isWordSelected =
(this.props.selectedTokenIndex &&
wordIndex === this.props.selectedTokenIndex) ||
false;
const classNames = textHighlightingStyles(isWordSelected);
let styleType = classNames.normal;
const score = importances[wordIndex];
let isBold = false;
if (sortedList.includes(wordIndex)) {
if (score > 0) {
styleType = classNames.highlighted;
} else if (score < 0) {
styleType = classNames.boldunderline;
isBold = true;
} else {
styleType = classNames.normal;
}
}
if (isBold) {
return (
<Label
key={wordIndex}
className={styleType}
title={score.toString()}
>
{word}
</Label>
);
}
return (
<Text
@ -78,6 +68,7 @@ export class TextHighlighting extends React.PureComponent<IChartProps> {
key={wordIndex}
className={styleType}
title={score.toString()}
onClick={(): void => this.handleClick(wordIndex)}
>
{word}
</Text>
@ -88,4 +79,13 @@ export class TextHighlighting extends React.PureComponent<IChartProps> {
</Stack>
);
}
private readonly handleClick = (wordIndex: number): void => {
if (this.props.isInput) {
return;
}
if (this.props.onSelectedTokenChange) {
this.props.onSelectedTokenChange(wordIndex);
}
};
}

Просмотреть файл

@ -9,4 +9,9 @@ export interface IChartProps {
localExplanations: number[];
topK?: number;
radio?: string;
isInput?: boolean;
baseValue?: number;
outputFeatureValue?: number;
selectedTokenIndex?: number;
onSelectedTokenChange?: (newIndex: number) => void;
}

Просмотреть файл

@ -19,5 +19,8 @@ export interface IDatasetSummary {
text: string[];
classNames?: string[];
localExplanations: number[][];
baseValues?: number[][];
prediction?: number[];
predictedY?: number[] | number[][] | string[] | string | number;
trueY?: number[] | number[][] | string[] | string | number;
}

Просмотреть файл

@ -1374,12 +1374,19 @@
"label": "Label",
"colon": ": ",
"startingPosition": "STARTING POSITION",
"endingPosition": "ENDING POSITION"
"endingPosition": "ENDING POSITION",
"predictedAnswer": "Predicted answer: ",
"trueAnswer": "True answer: ",
"inputs": "Inputs",
"outputs": "Outputs"
},
"Legend": {
"featureLegend": "TEXT FEATURE LEGEND",
"posFeatureImportance": "POSITIVE FEATURE IMPORTANCE",
"negFeatureImportance": "NEGATIVE FEATURE IMPORTANCE"
"negFeatureImportance": "NEGATIVE FEATURE IMPORTANCE",
"cls": "CLS: start of the sentence",
"sep": "SEP: end of the sentence",
"selectedWord": "Selected word: "
},
"BarChart": {
"featureImportance": "FEATURE IMPORTANCE"

Просмотреть файл

@ -75,7 +75,6 @@ export class FeatureImportancesTab extends React.PureComponent<
return React.Fragment;
}
const classNames = featureImportanceTabStyles();
return (
<Stack className={classNames.container}>
<Pivot

Просмотреть файл

@ -29,7 +29,10 @@ export interface ITextLocalImportancePlotsProps {
export interface ITextFeatureImportances {
text: string[];
importances: number[][];
baseValues?: number[][];
prediction: number[];
predictedY?: number[] | number[][] | string[] | string | number;
trueY?: number[] | number[][] | string[] | string | number;
}
export class TextLocalImportancePlots extends React.Component<ITextLocalImportancePlotsProps> {
@ -44,10 +47,13 @@ export class TextLocalImportancePlots extends React.Component<ITextLocalImportan
}
const classNames = this.props.jointDataset.getModelClasses();
const textExplanationDashboardData: ITextExplanationDashboardData = {
baseValues: textFeatureImportances.baseValues,
classNames,
localExplanations: textFeatureImportances.importances,
predictedY: textFeatureImportances.predictedY,
prediction: textFeatureImportances.prediction,
text: textFeatureImportances.text
text: textFeatureImportances.text,
trueY: textFeatureImportances.trueY
};
const dashboardProp: ITextExplanationViewProps = {
dataSummary: textExplanationDashboardData,
@ -67,7 +73,7 @@ export class TextLocalImportancePlots extends React.Component<ITextLocalImportan
this.context.modelExplanationData?.precomputedExplanations
?.textFeatureImportance?.[row[0]];
if (!textFeatureImportance) {
return { importances: [], prediction: [], text: [] };
return { baseValues: [], importances: [], prediction: [], text: [] };
}
const text = textFeatureImportance?.text;
const rowDict = this.props.jointDataset.getRow(row[0]);
@ -78,10 +84,16 @@ export class TextLocalImportancePlots extends React.Component<ITextLocalImportan
return rowDict[key];
});
const importances: number[][] = textFeatureImportance?.localExplanations;
const baseValues = textFeatureImportance?.baseValues;
const trueY = this.context.dataset.true_y[row[0]];
const predictedY = this.context.dataset.predicted_y?.[row[0]];
return {
baseValues,
importances,
predictedY,
prediction,
text
text,
trueY
};
});
return featureImportances;