Individual feature importance interpret QA (#2186)
* add Signed-off-by: vinutha karanth <vinutha.karanth@gmail.com> * update Signed-off-by: vinutha karanth <vinutha.karanth@gmail.com> * cleanup Signed-off-by: vinutha karanth <vinutha.karanth@gmail.com> * lintfix Signed-off-by: vinutha karanth <vinutha.karanth@gmail.com> * lintfix Signed-off-by: vinutha karanth <vinutha.karanth@gmail.com> * lintfix Signed-off-by: vinutha karanth <vinutha.karanth@gmail.com> * fix row change err Signed-off-by: vinutha karanth <vinutha.karanth@gmail.com> * address comments Signed-off-by: vinutha karanth <vinutha.karanth@gmail.com> --------- Signed-off-by: vinutha karanth <vinutha.karanth@gmail.com>
This commit is contained in:
Родитель
37c340c8d6
Коммит
29bfb51675
|
@ -57,6 +57,7 @@ export interface IPrecomputedExplanations {
|
|||
export interface ITextFeatureImportance {
|
||||
text: string[];
|
||||
localExplanations: number[][];
|
||||
baseValues?: number[][];
|
||||
}
|
||||
|
||||
export interface IEBMGlobalExplanation {
|
||||
|
|
|
@ -6,4 +6,7 @@ export interface ITextExplanationDashboardData {
|
|||
localExplanations: number[][];
|
||||
prediction: number[];
|
||||
text: string[];
|
||||
baseValues?: number[][];
|
||||
predictedY?: number[] | number[][] | string[] | string | number;
|
||||
trueY?: number[] | number[][] | string[] | string | number;
|
||||
}
|
||||
|
|
|
@ -78,6 +78,14 @@ export class Utils {
|
|||
return sortedList;
|
||||
}
|
||||
|
||||
public static addItem(value: number, radio: string | undefined): boolean {
|
||||
return (
|
||||
radio === RadioKeys.All ||
|
||||
(radio === RadioKeys.Neg && value <= 0) ||
|
||||
(radio === RadioKeys.Pos && value >= 0)
|
||||
);
|
||||
}
|
||||
|
||||
public static takeTopK(list: number[], k: number): number[] {
|
||||
/*
|
||||
* Returns a list after splicing and taking the top K
|
||||
|
|
|
@ -2,17 +2,27 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
import { ITheme } from "@fluentui/react";
|
||||
import {
|
||||
IHighchartsConfig,
|
||||
getPrimaryChartColor,
|
||||
getPrimaryBackgroundChartColor
|
||||
} from "@responsible-ai/core-ui";
|
||||
import { IHighchartsConfig } from "@responsible-ai/core-ui";
|
||||
import { localization } from "@responsible-ai/localization";
|
||||
import { SeriesOptionsType } from "highcharts";
|
||||
import _ from "lodash";
|
||||
|
||||
import { Utils } from "../../CommonUtils";
|
||||
import { IChartProps } from "../../Interfaces/IChartProps";
|
||||
|
||||
function findNearestIndex(
|
||||
array: number[],
|
||||
target?: number
|
||||
): number | undefined {
|
||||
if (!target) {
|
||||
return array.length;
|
||||
}
|
||||
const nearestElement = _.minBy(array, (element) =>
|
||||
Math.abs(element - target)
|
||||
);
|
||||
return _.indexOf(array, nearestElement);
|
||||
}
|
||||
|
||||
export function getTokenImportancesChartOptions(
|
||||
props: IChartProps,
|
||||
theme: ITheme
|
||||
|
@ -20,6 +30,11 @@ export function getTokenImportancesChartOptions(
|
|||
const importances = props.localExplanations;
|
||||
const k = props.topK;
|
||||
const sortedList = Utils.sortedTopK(importances, k, props.radio);
|
||||
|
||||
const outputFeatureImportanceLabel = `f ${
|
||||
props.text[props.selectedTokenIndex || 0]
|
||||
} (inputs)`;
|
||||
const baseValueLabel = "base value";
|
||||
const [x, y, ylabel, tooltip]: [number[], number[], string[], string[]] = [
|
||||
[],
|
||||
[],
|
||||
|
@ -46,6 +61,36 @@ export function getTokenImportancesChartOptions(
|
|||
ylabel.push(props.text[idx]);
|
||||
tooltip.push(str);
|
||||
});
|
||||
|
||||
// add output feature importance
|
||||
if (props.outputFeatureValue && props.baseValue) {
|
||||
const outputFeatureValueIndex = findNearestIndex(
|
||||
x,
|
||||
props.outputFeatureValue
|
||||
);
|
||||
const baseValueFeatureValueIndex = findNearestIndex(x, props.baseValue);
|
||||
if (outputFeatureValueIndex && baseValueFeatureValueIndex) {
|
||||
if (Utils.addItem(props.outputFeatureValue, props.radio)) {
|
||||
addItem(
|
||||
x,
|
||||
props.outputFeatureValue,
|
||||
ylabel,
|
||||
outputFeatureImportanceLabel,
|
||||
outputFeatureValueIndex
|
||||
);
|
||||
}
|
||||
if (Utils.addItem(props.baseValue, props.radio)) {
|
||||
addItem(
|
||||
x,
|
||||
props.baseValue,
|
||||
ylabel,
|
||||
baseValueLabel,
|
||||
baseValueFeatureValueIndex
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Put most significant word at the top by reversing order
|
||||
tooltip.reverse();
|
||||
ylabel.reverse();
|
||||
|
@ -54,11 +99,10 @@ export function getTokenImportancesChartOptions(
|
|||
const data: any[] = [];
|
||||
x.forEach((p, index) => {
|
||||
const temp = {
|
||||
borderColor: getPrimaryChartColor(theme),
|
||||
color:
|
||||
(p || 0) >= 0
|
||||
? getPrimaryChartColor(theme)
|
||||
: getPrimaryBackgroundChartColor(theme),
|
||||
? theme.semanticColors.errorText
|
||||
: theme.semanticColors.link,
|
||||
x: index,
|
||||
y: p
|
||||
};
|
||||
|
@ -68,6 +112,15 @@ export function getTokenImportancesChartOptions(
|
|||
const series: SeriesOptionsType[] = [
|
||||
{
|
||||
data,
|
||||
dataLabels: {
|
||||
align: "center",
|
||||
color: theme.semanticColors.bodyBackground,
|
||||
enabled: true,
|
||||
formatter(): string | number | undefined {
|
||||
return this.x; // Display the Y-axis value inside the bar
|
||||
},
|
||||
inside: true
|
||||
},
|
||||
name: "",
|
||||
showInLegend: false,
|
||||
type: "bar"
|
||||
|
@ -80,11 +133,12 @@ export function getTokenImportancesChartOptions(
|
|||
},
|
||||
plotOptions: {
|
||||
bar: {
|
||||
minPointLength: 10,
|
||||
tooltip: {
|
||||
pointFormatter(): string {
|
||||
return `${tooltip[this.x || 0]}: ${this.y || 0}`;
|
||||
}
|
||||
}
|
||||
} // Set the minimum pixel width for bars
|
||||
}
|
||||
},
|
||||
series,
|
||||
|
@ -98,3 +152,14 @@ export function getTokenImportancesChartOptions(
|
|||
}
|
||||
};
|
||||
}
|
||||
|
||||
function addItem(
|
||||
x: any[],
|
||||
xValue: any,
|
||||
yLabel: any[],
|
||||
yLabelValue: any,
|
||||
index: number
|
||||
): void {
|
||||
x.splice(index, 0, xValue);
|
||||
yLabel.splice(index, 0, yLabelValue);
|
||||
}
|
||||
|
|
|
@ -13,12 +13,13 @@ export interface ITextExplanationViewState {
|
|||
maxK: number;
|
||||
topK: number;
|
||||
radio: string;
|
||||
// qaRadio?: string;
|
||||
qaRadio?: string;
|
||||
importances: number[];
|
||||
singleTokenImportances: number[];
|
||||
selectedToken: number;
|
||||
tokenIndexes: number[];
|
||||
text: string[];
|
||||
outputFeatureImportances: number[][];
|
||||
}
|
||||
|
||||
export const options: IChoiceGroupOption[] = [
|
||||
|
|
|
@ -38,6 +38,9 @@ export interface ISidePanelOfChartProps {
|
|||
selectedWeightVector: WeightVectorOption;
|
||||
weightOptions: WeightVectorOption[];
|
||||
weightLabels: any;
|
||||
baseValue?: number;
|
||||
outputFeatureValue?: number;
|
||||
selectedTokenIndex?: number;
|
||||
changeRadioButton: (
|
||||
_event?: React.FormEvent,
|
||||
item?: IChoiceGroupOption
|
||||
|
@ -63,6 +66,9 @@ export class SidePanelOfChart extends React.PureComponent<ISidePanelOfChartProps
|
|||
localExplanations={this.props.importances}
|
||||
topK={this.props.topK}
|
||||
radio={this.props.radio}
|
||||
baseValue={this.props.baseValue}
|
||||
outputFeatureValue={this.props.outputFeatureValue}
|
||||
selectedTokenIndex={this.props.selectedTokenIndex}
|
||||
/>
|
||||
</Stack.Item>
|
||||
<Stack.Item grow className={classNames.chartRight}>
|
||||
|
|
|
@ -11,16 +11,25 @@ import {
|
|||
export interface ITextExplanationDashboardStyles {
|
||||
chartRight: IStyle;
|
||||
textHighlighting: IStyle;
|
||||
predictedAnswer: IStyle;
|
||||
boldText: IStyle;
|
||||
}
|
||||
|
||||
export const textExplanationDashboardStyles: () => IProcessedStyleSet<ITextExplanationDashboardStyles> =
|
||||
() => {
|
||||
const theme = getTheme();
|
||||
return mergeStyleSets<ITextExplanationDashboardStyles>({
|
||||
boldText: {
|
||||
fontWeight: "bold"
|
||||
},
|
||||
chartRight: {
|
||||
maxWidth: "230px",
|
||||
minWidth: "230px"
|
||||
},
|
||||
predictedAnswer: {
|
||||
fontWeight: "bold",
|
||||
paddingBottom: "14px"
|
||||
},
|
||||
textHighlighting: {
|
||||
borderColor: theme.semanticColors.variantBorder,
|
||||
borderRadius: "1px",
|
||||
|
|
|
@ -2,24 +2,29 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
import { IChoiceGroupOption, Stack, Text } from "@fluentui/react";
|
||||
import { WeightVectorOption, WeightVectors } from "@responsible-ai/core-ui";
|
||||
import { WeightVectorOption } from "@responsible-ai/core-ui";
|
||||
import { localization } from "@responsible-ai/localization";
|
||||
import React from "react";
|
||||
|
||||
import { RadioKeys, Utils } from "../../CommonUtils";
|
||||
import { QAExplanationType, RadioKeys } from "../../CommonUtils";
|
||||
import { ITextExplanationViewProps } from "../../Interfaces/IExplanationViewProps";
|
||||
import { TextFeatureLegend } from "../TextFeatureLegend/TextFeatureLegend";
|
||||
import { TextHighlighting } from "../TextHighlighting/TextHightlighting";
|
||||
|
||||
import {
|
||||
ITextExplanationViewState,
|
||||
MaxImportantWords,
|
||||
componentStackTokens
|
||||
} from "./ITextExplanationViewSpec";
|
||||
import { SidePanelOfChart } from "./SidePanelOfChart";
|
||||
import { textExplanationDashboardStyles } from "./TextExplanationView.styles";
|
||||
import {
|
||||
calculateMaxKImportances,
|
||||
calculateTopKImportances,
|
||||
computeImportancesForAllTokens,
|
||||
computeImportancesForWeightVector,
|
||||
getOutputFeatureImportances
|
||||
} from "./TextExplanationViewUtils";
|
||||
import { TextInputOutputAreaWithLegend } from "./TextInputOutputAreaWithLegend";
|
||||
import { TrueAndPredictedAnswerView } from "./TrueAndPredictedAnswerView";
|
||||
|
||||
export class TextExplanationView extends React.PureComponent<
|
||||
export class TextExplanationView extends React.Component<
|
||||
ITextExplanationViewProps,
|
||||
ITextExplanationViewState
|
||||
> {
|
||||
|
@ -31,23 +36,32 @@ export class TextExplanationView extends React.PureComponent<
|
|||
|
||||
const weightVector = this.props.selectedWeightVector;
|
||||
const importances = this.props.isQA
|
||||
? this.computeImportancesForAllTokens(
|
||||
this.props.dataSummary.localExplanations
|
||||
? computeImportancesForAllTokens(
|
||||
this.props.dataSummary.localExplanations,
|
||||
true
|
||||
)
|
||||
: this.computeImportancesForWeightVector(
|
||||
: computeImportancesForWeightVector(
|
||||
this.props.dataSummary.localExplanations,
|
||||
weightVector
|
||||
);
|
||||
|
||||
const maxK = this.calculateMaxKImportances(importances);
|
||||
const topK = this.calculateTopKImportances(importances);
|
||||
const maxK = calculateMaxKImportances(importances);
|
||||
const topK = calculateTopKImportances(importances);
|
||||
this.state = {
|
||||
importances,
|
||||
maxK,
|
||||
// qaRadio: QAExplanationType.Start,
|
||||
outputFeatureImportances: getOutputFeatureImportances(
|
||||
this.props.dataSummary.localExplanations,
|
||||
this.props.dataSummary.baseValues
|
||||
),
|
||||
qaRadio: QAExplanationType.Start,
|
||||
radio: RadioKeys.All,
|
||||
selectedToken: 0, // default to the first token
|
||||
singleTokenImportances: this.getImportanceForSingleToken(0), // get importance for first token
|
||||
selectedToken: 0,
|
||||
// default to the first token
|
||||
singleTokenImportances: this.props.dataSummary.localExplanations[0].map(
|
||||
(row) => row[0]
|
||||
),
|
||||
// get importance for first token
|
||||
text: this.props.dataSummary.text,
|
||||
tokenIndexes: [...this.props.dataSummary.text].map((_, index) => index),
|
||||
topK
|
||||
|
@ -60,28 +74,19 @@ export class TextExplanationView extends React.PureComponent<
|
|||
this.props.dataSummary.localExplanations !==
|
||||
prevProps.dataSummary.localExplanations
|
||||
) {
|
||||
if (this.props.isQA) {
|
||||
this.setState(
|
||||
{
|
||||
selectedToken: 0,
|
||||
//update token dropdown
|
||||
tokenIndexes: [...this.props.dataSummary.text].map(
|
||||
(_, index) => index
|
||||
)
|
||||
},
|
||||
() => {
|
||||
this.updateTokenImportances();
|
||||
this.updateSingleTokenImportances();
|
||||
}
|
||||
);
|
||||
} else {
|
||||
this.updateImportances(this.props.selectedWeightVector);
|
||||
}
|
||||
this.updateState();
|
||||
}
|
||||
}
|
||||
|
||||
public render(): React.ReactNode {
|
||||
const classNames = textExplanationDashboardStyles();
|
||||
const outputLocalExplanations =
|
||||
this.state.qaRadio === QAExplanationType.Start
|
||||
? this.state.outputFeatureImportances[0]
|
||||
: this.state.outputFeatureImportances[1];
|
||||
const inputLocalExplanations = this.props.isQA
|
||||
? this.state.singleTokenImportances
|
||||
: this.state.importances;
|
||||
const baseValue = this.props.isQA ? this.getBaseValue() : undefined;
|
||||
|
||||
return (
|
||||
<Stack>
|
||||
|
@ -93,9 +98,30 @@ export class TextExplanationView extends React.PureComponent<
|
|||
)}
|
||||
</Stack>
|
||||
|
||||
<Stack tokens={componentStackTokens} horizontal>
|
||||
{this.props.isQA && (
|
||||
<TrueAndPredictedAnswerView
|
||||
predictedY={this.props.dataSummary.predictedY}
|
||||
trueY={this.props.dataSummary.trueY}
|
||||
/>
|
||||
)}
|
||||
</Stack>
|
||||
|
||||
<TextInputOutputAreaWithLegend
|
||||
topK={this.state.topK}
|
||||
radio={this.state.radio}
|
||||
selectedToken={this.state.selectedToken}
|
||||
text={this.state.text}
|
||||
outputLocalExplanations={outputLocalExplanations}
|
||||
inputLocalExplanations={inputLocalExplanations}
|
||||
isQA={this.props.isQA}
|
||||
getSelectedWord={this.getSelectedWord}
|
||||
onSelectedTokenChange={this.onSelectedTokenChange}
|
||||
/>
|
||||
|
||||
<SidePanelOfChart
|
||||
text={this.state.text}
|
||||
importances={this.state.importances}
|
||||
importances={inputLocalExplanations}
|
||||
topK={this.state.topK}
|
||||
radio={this.state.radio}
|
||||
isQA={this.props.isQA}
|
||||
|
@ -111,153 +137,93 @@ export class TextExplanationView extends React.PureComponent<
|
|||
setTopK={this.setTopK}
|
||||
onWeightVectorChange={this.onWeightVectorChange}
|
||||
onSelectedTokenChange={this.onSelectedTokenChange}
|
||||
outputFeatureValue={outputLocalExplanations[this.state.selectedToken]}
|
||||
baseValue={baseValue}
|
||||
selectedTokenIndex={this.state.selectedToken}
|
||||
/>
|
||||
|
||||
<Stack tokens={componentStackTokens} horizontal>
|
||||
<Stack.Item
|
||||
align="stretch"
|
||||
grow
|
||||
disableShrink
|
||||
className={classNames.textHighlighting}
|
||||
>
|
||||
<TextHighlighting
|
||||
text={this.state.text}
|
||||
localExplanations={this.state.importances}
|
||||
topK={this.state.topK}
|
||||
radio={this.state.radio}
|
||||
/>
|
||||
</Stack.Item>
|
||||
|
||||
{this.props.isQA && (
|
||||
<Stack.Item
|
||||
align="stretch"
|
||||
grow
|
||||
disableShrink
|
||||
className={classNames.textHighlighting}
|
||||
>
|
||||
<TextHighlighting
|
||||
text={this.state.text}
|
||||
localExplanations={this.state.singleTokenImportances}
|
||||
topK={
|
||||
// keep all importances for single token(set topK to length)
|
||||
this.state.singleTokenImportances.length
|
||||
}
|
||||
radio={this.state.radio}
|
||||
/>
|
||||
</Stack.Item>
|
||||
)}
|
||||
|
||||
<Stack.Item align="end">
|
||||
<TextFeatureLegend />
|
||||
</Stack.Item>
|
||||
</Stack>
|
||||
</Stack>
|
||||
);
|
||||
}
|
||||
|
||||
private updateState(): void {
|
||||
const importances = this.props.isQA
|
||||
? this.getTokenImportances()
|
||||
: this.getImportances(this.props.selectedWeightVector);
|
||||
const [topK, maxK] = this.getTopKMaxK(importances);
|
||||
this.setState({
|
||||
importances,
|
||||
maxK,
|
||||
outputFeatureImportances: getOutputFeatureImportances(
|
||||
this.props.dataSummary.localExplanations,
|
||||
this.props.dataSummary.baseValues
|
||||
),
|
||||
selectedToken: 0,
|
||||
singleTokenImportances: this.getImportanceForSingleToken(
|
||||
this.state.selectedToken
|
||||
),
|
||||
text: this.props.dataSummary.text,
|
||||
tokenIndexes: [...this.props.dataSummary.text].map((_, index) => index),
|
||||
topK
|
||||
});
|
||||
}
|
||||
|
||||
private onWeightVectorChange = (weightOption: WeightVectorOption): void => {
|
||||
this.updateImportances(weightOption);
|
||||
const importances = this.getImportances(weightOption);
|
||||
const [topK, maxK] = this.getTopKMaxK(importances);
|
||||
this.setState({ importances, maxK, topK });
|
||||
this.props.onWeightChange(weightOption);
|
||||
};
|
||||
|
||||
private onSelectedTokenChange = (newIndex: number): void => {
|
||||
this.setState({ selectedToken: newIndex }, () => {
|
||||
this.updateSingleTokenImportances();
|
||||
const singleTokenImportances = this.getImportanceForSingleToken(newIndex);
|
||||
this.setState({
|
||||
selectedToken: newIndex,
|
||||
singleTokenImportances
|
||||
});
|
||||
};
|
||||
|
||||
private updateImportances(weightOption: WeightVectorOption): void {
|
||||
const importances = this.computeImportancesForWeightVector(
|
||||
private getSelectedWord = (): string => {
|
||||
return this.props.dataSummary.text[this.state.selectedToken];
|
||||
};
|
||||
|
||||
private getTopKMaxK(importances: number[]): [number, number] {
|
||||
const topK = calculateTopKImportances(importances);
|
||||
const maxK = calculateMaxKImportances(importances);
|
||||
return [topK, maxK];
|
||||
}
|
||||
|
||||
private getImportances(weightOption: WeightVectorOption): number[] {
|
||||
return computeImportancesForWeightVector(
|
||||
this.props.dataSummary.localExplanations,
|
||||
weightOption
|
||||
);
|
||||
|
||||
const topK = this.calculateTopKImportances(importances);
|
||||
const maxK = this.calculateMaxKImportances(importances);
|
||||
this.setState({
|
||||
importances,
|
||||
maxK,
|
||||
text: this.props.dataSummary.text,
|
||||
topK
|
||||
});
|
||||
}
|
||||
|
||||
// for QA
|
||||
private updateTokenImportances(): void {
|
||||
const importances = this.computeImportancesForAllTokens(
|
||||
private getTokenImportances(): number[] {
|
||||
return computeImportancesForAllTokens(
|
||||
this.props.dataSummary.localExplanations
|
||||
);
|
||||
const topK = this.calculateTopKImportances(importances);
|
||||
const maxK = this.calculateMaxKImportances(importances);
|
||||
this.setState({
|
||||
importances,
|
||||
maxK,
|
||||
text: this.props.dataSummary.text,
|
||||
topK
|
||||
});
|
||||
}
|
||||
|
||||
private updateSingleTokenImportances(): void {
|
||||
const singleTokenImportances = this.getImportanceForSingleToken(
|
||||
this.state.selectedToken
|
||||
);
|
||||
this.setState({ singleTokenImportances });
|
||||
}
|
||||
|
||||
private calculateTopKImportances(importances: number[]): number {
|
||||
return Math.min(
|
||||
MaxImportantWords,
|
||||
Math.ceil(Utils.countNonzeros(importances) / 2)
|
||||
);
|
||||
}
|
||||
|
||||
private calculateMaxKImportances(importances: number[]): number {
|
||||
return Math.min(
|
||||
MaxImportantWords,
|
||||
Math.ceil(Utils.countNonzeros(importances))
|
||||
);
|
||||
}
|
||||
|
||||
private computeImportancesForWeightVector(
|
||||
importances: number[][],
|
||||
weightVector: WeightVectorOption
|
||||
): number[] {
|
||||
if (weightVector === WeightVectors.AbsAvg) {
|
||||
// Sum the multidimensional array to one dimension across rows for each token
|
||||
const numClasses = importances[0].length;
|
||||
const sumImportances = importances.map((row) =>
|
||||
row.reduce((a, b): number => {
|
||||
return (a + Math.abs(b)) / numClasses;
|
||||
}, 0)
|
||||
);
|
||||
return sumImportances;
|
||||
}
|
||||
return importances.map(
|
||||
(perClassImportances) => perClassImportances[weightVector as number]
|
||||
);
|
||||
}
|
||||
|
||||
private computeImportancesForAllTokens(importances: number[][]): number[] {
|
||||
/*
|
||||
* sum the tokens importance
|
||||
* TODO: add base values?
|
||||
*/
|
||||
|
||||
const sumImportances = importances[0].map((_, index) =>
|
||||
importances.reduce((sum, row) => sum + row[index], 0)
|
||||
);
|
||||
|
||||
return sumImportances;
|
||||
}
|
||||
|
||||
private getImportanceForSingleToken(index: number): number[] {
|
||||
return this.props.dataSummary.localExplanations.map((row) => row[index]);
|
||||
const expIndex = this.state.qaRadio === QAExplanationType.Start ? 0 : 1;
|
||||
return this.props.dataSummary.localExplanations[expIndex].map(
|
||||
(row) => row[index]
|
||||
);
|
||||
}
|
||||
|
||||
private getBaseValue(): number {
|
||||
if (this.props.dataSummary.baseValues) {
|
||||
const expIndex = this.state.qaRadio === QAExplanationType.Start ? 0 : 1;
|
||||
return this.props.dataSummary.baseValues?.[expIndex][
|
||||
this.state.selectedToken
|
||||
];
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
private setTopK = (newNumber: number): void => {
|
||||
/*
|
||||
* Changes the state of K
|
||||
*/
|
||||
this.setState({ topK: newNumber });
|
||||
};
|
||||
|
||||
|
@ -265,23 +231,23 @@ export class TextExplanationView extends React.PureComponent<
|
|||
_event?: React.FormEvent,
|
||||
item?: IChoiceGroupOption
|
||||
): void => {
|
||||
/*
|
||||
* Changes the state of the radio button
|
||||
*/
|
||||
if (item?.key !== undefined) {
|
||||
if (item?.key) {
|
||||
this.setState({ radio: item.key });
|
||||
}
|
||||
};
|
||||
|
||||
private switchQAPrediction = (): // _event?: React.FormEvent,
|
||||
// _item?: IChoiceGroupOption
|
||||
void => {
|
||||
/*
|
||||
* switch to the target predictions(starting or ending)
|
||||
* TODO: add logic for switching explanation data
|
||||
*/
|
||||
// if (item?.key !== undefined) {
|
||||
// this.setState({ qaRadio: item.key });
|
||||
// }
|
||||
private switchQAPrediction = (
|
||||
_event?: React.FormEvent,
|
||||
item?: IChoiceGroupOption
|
||||
): void => {
|
||||
if (item?.key) {
|
||||
const singleTokenImportances = this.getImportanceForSingleToken(
|
||||
this.state.selectedToken
|
||||
);
|
||||
this.setState({
|
||||
qaRadio: item.key,
|
||||
singleTokenImportances
|
||||
});
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { WeightVectorOption, WeightVectors } from "@responsible-ai/core-ui";
|
||||
|
||||
import { QAExplanationType, Utils } from "../../CommonUtils";
|
||||
|
||||
import { MaxImportantWords } from "./ITextExplanationViewSpec";
|
||||
|
||||
export function getOutputFeatureImportances(
|
||||
localExplanations: number[][],
|
||||
baseValues?: number[][]
|
||||
): number[][] {
|
||||
const startSumOfFeatureImportances = getSumOfFeatureImportances(
|
||||
localExplanations[0]
|
||||
);
|
||||
const endSumOfFeatureImportances = getSumOfFeatureImportances(
|
||||
localExplanations[1]
|
||||
);
|
||||
const startOutputFeatureImportances = getOutputFeatureImportancesIntl(
|
||||
startSumOfFeatureImportances,
|
||||
baseValues?.[0]
|
||||
);
|
||||
const endOutputFeatureImportances = getOutputFeatureImportancesIntl(
|
||||
endSumOfFeatureImportances,
|
||||
baseValues?.[1]
|
||||
);
|
||||
return [
|
||||
startOutputFeatureImportances || [],
|
||||
endOutputFeatureImportances || []
|
||||
];
|
||||
}
|
||||
|
||||
export function getSumOfFeatureImportances(importances: number[]): number[] {
|
||||
return importances.map((_, index) =>
|
||||
importances.reduce((sum, row) => sum + row[index], 0)
|
||||
);
|
||||
}
|
||||
|
||||
export function getOutputFeatureImportancesIntl(
|
||||
sumOfFeatureImportances: number[],
|
||||
baseValues?: number[]
|
||||
): number[] | undefined {
|
||||
return baseValues?.map(
|
||||
(bValue, index) => sumOfFeatureImportances[index] + bValue
|
||||
);
|
||||
}
|
||||
|
||||
export function calculateTopKImportances(importances: number[]): number {
|
||||
return Math.min(
|
||||
MaxImportantWords,
|
||||
Math.ceil(Utils.countNonzeros(importances) / 2)
|
||||
);
|
||||
}
|
||||
|
||||
export function calculateMaxKImportances(importances: number[]): number {
|
||||
return Math.min(
|
||||
MaxImportantWords,
|
||||
Math.ceil(Utils.countNonzeros(importances))
|
||||
);
|
||||
}
|
||||
|
||||
export function computeImportancesForWeightVector(
|
||||
importances: number[][],
|
||||
weightVector: WeightVectorOption
|
||||
): number[] {
|
||||
if (weightVector === WeightVectors.AbsAvg) {
|
||||
// Sum the multidimensional array to one dimension across rows for each token
|
||||
const numClasses = importances[0].length;
|
||||
const sumImportances = importances.map((row) =>
|
||||
row.reduce((a, b): number => {
|
||||
return (a + Math.abs(b)) / numClasses;
|
||||
}, 0)
|
||||
);
|
||||
return sumImportances;
|
||||
}
|
||||
return importances.map(
|
||||
(perClassImportances) => perClassImportances[weightVector as number]
|
||||
);
|
||||
}
|
||||
|
||||
export function computeImportancesForAllTokens(
|
||||
importances: number[][],
|
||||
isInitialState?: boolean,
|
||||
qaRadio?: string
|
||||
): number[] {
|
||||
const startSumImportances = importances[0].map((_, index) =>
|
||||
importances.reduce((sum, row) => sum + row[index], 0)
|
||||
);
|
||||
const endSumImportances = importances[1].map((_, index) =>
|
||||
importances.reduce((sum, row) => sum + row[index], 0)
|
||||
);
|
||||
if (isInitialState) {
|
||||
return startSumImportances;
|
||||
}
|
||||
|
||||
return qaRadio === QAExplanationType.Start
|
||||
? startSumImportances
|
||||
: endSumImportances;
|
||||
}
|
|
@ -0,0 +1,94 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { Stack, Text } from "@fluentui/react";
|
||||
import { localization } from "@responsible-ai/localization";
|
||||
import React from "react";
|
||||
|
||||
import { TextFeatureLegend } from "../TextFeatureLegend/TextFeatureLegend";
|
||||
import { TextHighlighting } from "../TextHighlighting/TextHightlighting";
|
||||
|
||||
import { componentStackTokens } from "./ITextExplanationViewSpec";
|
||||
import { textExplanationDashboardStyles } from "./TextExplanationView.styles";
|
||||
|
||||
interface ITextInputOutputAreaWithLegendProps {
|
||||
topK: number;
|
||||
radio: string;
|
||||
selectedToken: number;
|
||||
text: string[];
|
||||
outputLocalExplanations: number[];
|
||||
inputLocalExplanations: number[];
|
||||
isQA?: boolean;
|
||||
getSelectedWord: () => string;
|
||||
onSelectedTokenChange: (newIndex: number) => void;
|
||||
}
|
||||
|
||||
export class TextInputOutputAreaWithLegend extends React.Component<ITextInputOutputAreaWithLegendProps> {
|
||||
public render(): React.ReactNode {
|
||||
const classNames = textExplanationDashboardStyles();
|
||||
|
||||
return (
|
||||
<Stack tokens={componentStackTokens} horizontal>
|
||||
{this.props.isQA && (
|
||||
<Stack.Item grow>
|
||||
<Stack horizontal={false}>
|
||||
<Stack.Item>
|
||||
<Text className={classNames.boldText}>
|
||||
{localization.InterpretText.View.outputs}
|
||||
</Text>
|
||||
</Stack.Item>
|
||||
<Stack.Item
|
||||
align="stretch"
|
||||
grow
|
||||
disableShrink
|
||||
className={classNames.textHighlighting}
|
||||
>
|
||||
<TextHighlighting
|
||||
text={this.props.text}
|
||||
localExplanations={this.props.outputLocalExplanations}
|
||||
topK={
|
||||
// keep all importances for single token(set topK to length)
|
||||
this.props.outputLocalExplanations.length
|
||||
}
|
||||
radio={this.props.radio}
|
||||
isInput={false}
|
||||
onSelectedTokenChange={this.props.onSelectedTokenChange}
|
||||
selectedTokenIndex={this.props.selectedToken}
|
||||
/>
|
||||
</Stack.Item>
|
||||
</Stack>
|
||||
</Stack.Item>
|
||||
)}
|
||||
<Stack.Item grow>
|
||||
<Stack horizontal={false}>
|
||||
<Stack.Item>
|
||||
<Text className={classNames.boldText}>
|
||||
{localization.InterpretText.View.inputs}
|
||||
</Text>
|
||||
</Stack.Item>
|
||||
<Stack.Item
|
||||
align="stretch"
|
||||
grow
|
||||
disableShrink
|
||||
className={classNames.textHighlighting}
|
||||
>
|
||||
<TextHighlighting
|
||||
text={this.props.text}
|
||||
localExplanations={this.props.inputLocalExplanations}
|
||||
topK={this.props.topK}
|
||||
radio={this.props.radio}
|
||||
isInput
|
||||
/>
|
||||
</Stack.Item>
|
||||
</Stack>
|
||||
</Stack.Item>
|
||||
<Stack.Item grow className={classNames.chartRight}>
|
||||
<TextFeatureLegend
|
||||
selectedWord={this.props.getSelectedWord()}
|
||||
isQA={this.props.isQA}
|
||||
/>
|
||||
</Stack.Item>
|
||||
</Stack>
|
||||
);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { Stack, Text } from "@fluentui/react";
|
||||
import { localization } from "@responsible-ai/localization";
|
||||
import React from "react";
|
||||
|
||||
import { textExplanationDashboardStyles } from "./TextExplanationView.styles";
|
||||
|
||||
interface ITrueAndPredictedAnswerViewProps {
|
||||
predictedY: string | number | number[] | string[] | number[][] | undefined;
|
||||
trueY: string | number | number[] | string[] | number[][] | undefined;
|
||||
}
|
||||
|
||||
export class TrueAndPredictedAnswerView extends React.Component<ITrueAndPredictedAnswerViewProps> {
|
||||
public render(): React.ReactNode {
|
||||
const classNames = textExplanationDashboardStyles();
|
||||
|
||||
return (
|
||||
<Stack horizontal={false}>
|
||||
<Stack horizontal>
|
||||
<Stack.Item>
|
||||
<Text className={classNames.predictedAnswer}>
|
||||
{localization.InterpretText.View.predictedAnswer}
|
||||
</Text>
|
||||
</Stack.Item>
|
||||
<Stack.Item>
|
||||
<Text className={classNames.predictedAnswer}>
|
||||
{this.props.predictedY}
|
||||
</Text>
|
||||
</Stack.Item>
|
||||
</Stack>
|
||||
<Stack horizontal>
|
||||
<Stack.Item>
|
||||
<Text className={classNames.boldText}>
|
||||
{localization.InterpretText.View.trueAnswer}
|
||||
</Text>
|
||||
</Stack.Item>
|
||||
<Stack.Item>
|
||||
<Text className={classNames.boldText}>{this.props.trueY}</Text>
|
||||
</Stack.Item>
|
||||
</Stack>
|
||||
</Stack>
|
||||
);
|
||||
}
|
||||
}
|
|
@ -7,10 +7,6 @@ import {
|
|||
IProcessedStyleSet,
|
||||
getTheme
|
||||
} from "@fluentui/react";
|
||||
import {
|
||||
getPrimaryBackgroundChartColor,
|
||||
getPrimaryChartColor
|
||||
} from "@responsible-ai/core-ui";
|
||||
|
||||
export interface ITextFeatureLegendStyles {
|
||||
legend: IStyle;
|
||||
|
@ -26,12 +22,12 @@ export const textFeatureLegendStyles: () => IProcessedStyleSet<ITextFeatureLegen
|
|||
color: theme.semanticColors.disabledText
|
||||
},
|
||||
negFeatureImportance: {
|
||||
color: getPrimaryChartColor(theme),
|
||||
textDecorationLine: "underline"
|
||||
backgroundColor: theme.semanticColors.link,
|
||||
color: theme.semanticColors.bodyBackground
|
||||
},
|
||||
posFeatureImportance: {
|
||||
backgroundColor: getPrimaryChartColor(theme),
|
||||
color: getPrimaryBackgroundChartColor(theme)
|
||||
backgroundColor: theme.semanticColors.errorText,
|
||||
color: theme.semanticColors.bodyBackground
|
||||
}
|
||||
});
|
||||
};
|
||||
|
|
|
@ -17,7 +17,12 @@ const legendStackTokens: IStackTokens = {
|
|||
padding: "s"
|
||||
};
|
||||
|
||||
export class TextFeatureLegend extends React.Component {
|
||||
interface ITextFeatureLegendProps {
|
||||
selectedWord: string;
|
||||
isQA?: boolean;
|
||||
}
|
||||
|
||||
export class TextFeatureLegend extends React.Component<ITextFeatureLegendProps> {
|
||||
public render(): React.ReactNode {
|
||||
const classNames = textFeatureLegendStyles();
|
||||
return (
|
||||
|
@ -51,6 +56,28 @@ export class TextFeatureLegend extends React.Component {
|
|||
</Stack.Item>
|
||||
</Stack>
|
||||
</Stack.Item>
|
||||
{this.props.isQA && (
|
||||
<Stack tokens={componentStackTokens}>
|
||||
<Stack.Item>
|
||||
<Text>{localization.InterpretText.Legend.cls}</Text>
|
||||
</Stack.Item>
|
||||
<Stack.Item>
|
||||
<Text>{localization.InterpretText.Legend.sep}</Text>
|
||||
</Stack.Item>
|
||||
<Stack.Item>
|
||||
<Stack horizontal>
|
||||
<Stack.Item>
|
||||
<Text>
|
||||
{localization.InterpretText.Legend.selectedWord}
|
||||
</Text>
|
||||
</Stack.Item>
|
||||
<Stack.Item>
|
||||
<Text>{this.props.selectedWord}</Text>
|
||||
</Stack.Item>
|
||||
</Stack>
|
||||
</Stack.Item>
|
||||
</Stack>
|
||||
)}
|
||||
</Stack>
|
||||
);
|
||||
}
|
||||
|
|
|
@ -2,14 +2,13 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
import {
|
||||
IStyle,
|
||||
mergeStyles,
|
||||
mergeStyleSets,
|
||||
IProcessedStyleSet,
|
||||
IStackStyles,
|
||||
getTheme
|
||||
IStyle,
|
||||
getTheme,
|
||||
mergeStyleSets,
|
||||
mergeStyles
|
||||
} from "@fluentui/react";
|
||||
import { getPrimaryChartColor } from "@responsible-ai/core-ui";
|
||||
|
||||
export const textStackStyles: IStackStyles = {
|
||||
root: {
|
||||
|
@ -31,30 +30,41 @@ export interface ITextHighlightingStyles {
|
|||
boldunderline: IStyle;
|
||||
}
|
||||
|
||||
export const textHighlightingStyles: () => IProcessedStyleSet<ITextHighlightingStyles> =
|
||||
() => {
|
||||
const theme = getTheme();
|
||||
const normal = {
|
||||
color: theme.semanticColors.bodyText
|
||||
};
|
||||
return mergeStyleSets<ITextHighlightingStyles>({
|
||||
boldunderline: mergeStyles([
|
||||
normal,
|
||||
{
|
||||
color: getPrimaryChartColor(theme),
|
||||
fontSize: theme.fonts.large.fontSize,
|
||||
margin: "2px",
|
||||
padding: 0,
|
||||
textDecorationLine: "underline"
|
||||
}
|
||||
]),
|
||||
highlighted: mergeStyles([
|
||||
normal,
|
||||
{
|
||||
backgroundColor: getPrimaryChartColor(theme),
|
||||
color: theme.semanticColors.bodyBackground
|
||||
}
|
||||
]),
|
||||
normal
|
||||
});
|
||||
export const textHighlightingStyles: (
|
||||
isTextSelected: boolean
|
||||
) => IProcessedStyleSet<ITextHighlightingStyles> = (isTextSelected) => {
|
||||
const theme = getTheme();
|
||||
const normal = {
|
||||
color: theme.semanticColors.bodyText
|
||||
};
|
||||
const selectedTextStyle = isTextSelected
|
||||
? {
|
||||
textDecorationColor: "black",
|
||||
textDecorationLine: "underline",
|
||||
textDecorationStyle: "solid",
|
||||
textDecorationThickness: "4px"
|
||||
}
|
||||
: {};
|
||||
return mergeStyleSets<ITextHighlightingStyles>({
|
||||
boldunderline: mergeStyles([
|
||||
normal,
|
||||
{
|
||||
backgroundColor: theme.semanticColors.link,
|
||||
color: theme.semanticColors.bodyBackground,
|
||||
fontSize: theme.fonts.large.fontSize,
|
||||
margin: "2px",
|
||||
padding: 0
|
||||
},
|
||||
selectedTextStyle
|
||||
]),
|
||||
highlighted: mergeStyles([
|
||||
normal,
|
||||
selectedTextStyle,
|
||||
{
|
||||
backgroundColor: theme.semanticColors.errorText,
|
||||
color: theme.semanticColors.bodyBackground
|
||||
}
|
||||
]),
|
||||
normal: mergeStyles([normal, selectedTextStyle])
|
||||
});
|
||||
};
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
import {
|
||||
Label,
|
||||
Text,
|
||||
Stack,
|
||||
IStackTokens,
|
||||
|
@ -25,12 +24,11 @@ const textStackTokens: IStackTokens = {
|
|||
padding: "s2"
|
||||
};
|
||||
|
||||
export class TextHighlighting extends React.PureComponent<IChartProps> {
|
||||
export class TextHighlighting extends React.Component<IChartProps> {
|
||||
/*
|
||||
* Presents the document in an accessible manner with text highlighting
|
||||
*/
|
||||
public render(): React.ReactNode {
|
||||
const classNames = textHighlightingStyles();
|
||||
const text = this.props.text;
|
||||
const importances = this.props.localExplanations;
|
||||
const k = this.props.topK;
|
||||
|
@ -47,30 +45,22 @@ export class TextHighlighting extends React.PureComponent<IChartProps> {
|
|||
styles={textStackStyles}
|
||||
>
|
||||
{text.map((word, wordIndex) => {
|
||||
const isWordSelected =
|
||||
(this.props.selectedTokenIndex &&
|
||||
wordIndex === this.props.selectedTokenIndex) ||
|
||||
false;
|
||||
const classNames = textHighlightingStyles(isWordSelected);
|
||||
let styleType = classNames.normal;
|
||||
const score = importances[wordIndex];
|
||||
let isBold = false;
|
||||
if (sortedList.includes(wordIndex)) {
|
||||
if (score > 0) {
|
||||
styleType = classNames.highlighted;
|
||||
} else if (score < 0) {
|
||||
styleType = classNames.boldunderline;
|
||||
isBold = true;
|
||||
} else {
|
||||
styleType = classNames.normal;
|
||||
}
|
||||
}
|
||||
if (isBold) {
|
||||
return (
|
||||
<Label
|
||||
key={wordIndex}
|
||||
className={styleType}
|
||||
title={score.toString()}
|
||||
>
|
||||
{word}
|
||||
</Label>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Text
|
||||
|
@ -78,6 +68,7 @@ export class TextHighlighting extends React.PureComponent<IChartProps> {
|
|||
key={wordIndex}
|
||||
className={styleType}
|
||||
title={score.toString()}
|
||||
onClick={(): void => this.handleClick(wordIndex)}
|
||||
>
|
||||
{word}
|
||||
</Text>
|
||||
|
@ -88,4 +79,13 @@ export class TextHighlighting extends React.PureComponent<IChartProps> {
|
|||
</Stack>
|
||||
);
|
||||
}
|
||||
|
||||
private readonly handleClick = (wordIndex: number): void => {
|
||||
if (this.props.isInput) {
|
||||
return;
|
||||
}
|
||||
if (this.props.onSelectedTokenChange) {
|
||||
this.props.onSelectedTokenChange(wordIndex);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -9,4 +9,9 @@ export interface IChartProps {
|
|||
localExplanations: number[];
|
||||
topK?: number;
|
||||
radio?: string;
|
||||
isInput?: boolean;
|
||||
baseValue?: number;
|
||||
outputFeatureValue?: number;
|
||||
selectedTokenIndex?: number;
|
||||
onSelectedTokenChange?: (newIndex: number) => void;
|
||||
}
|
||||
|
|
|
@ -19,5 +19,8 @@ export interface IDatasetSummary {
|
|||
text: string[];
|
||||
classNames?: string[];
|
||||
localExplanations: number[][];
|
||||
baseValues?: number[][];
|
||||
prediction?: number[];
|
||||
predictedY?: number[] | number[][] | string[] | string | number;
|
||||
trueY?: number[] | number[][] | string[] | string | number;
|
||||
}
|
||||
|
|
|
@ -1374,12 +1374,19 @@
|
|||
"label": "Label",
|
||||
"colon": ": ",
|
||||
"startingPosition": "STARTING POSITION",
|
||||
"endingPosition": "ENDING POSITION"
|
||||
"endingPosition": "ENDING POSITION",
|
||||
"predictedAnswer": "Predicted answer: ",
|
||||
"trueAnswer": "True answer: ",
|
||||
"inputs": "Inputs",
|
||||
"outputs": "Outputs"
|
||||
},
|
||||
"Legend": {
|
||||
"featureLegend": "TEXT FEATURE LEGEND",
|
||||
"posFeatureImportance": "POSITIVE FEATURE IMPORTANCE",
|
||||
"negFeatureImportance": "NEGATIVE FEATURE IMPORTANCE"
|
||||
"negFeatureImportance": "NEGATIVE FEATURE IMPORTANCE",
|
||||
"cls": "CLS: start of the sentence",
|
||||
"sep": "SEP: end of the sentence",
|
||||
"selectedWord": "Selected word: "
|
||||
},
|
||||
"BarChart": {
|
||||
"featureImportance": "FEATURE IMPORTANCE"
|
||||
|
|
|
@ -75,7 +75,6 @@ export class FeatureImportancesTab extends React.PureComponent<
|
|||
return React.Fragment;
|
||||
}
|
||||
const classNames = featureImportanceTabStyles();
|
||||
|
||||
return (
|
||||
<Stack className={classNames.container}>
|
||||
<Pivot
|
||||
|
|
|
@ -29,7 +29,10 @@ export interface ITextLocalImportancePlotsProps {
|
|||
export interface ITextFeatureImportances {
|
||||
text: string[];
|
||||
importances: number[][];
|
||||
baseValues?: number[][];
|
||||
prediction: number[];
|
||||
predictedY?: number[] | number[][] | string[] | string | number;
|
||||
trueY?: number[] | number[][] | string[] | string | number;
|
||||
}
|
||||
|
||||
export class TextLocalImportancePlots extends React.Component<ITextLocalImportancePlotsProps> {
|
||||
|
@ -44,10 +47,13 @@ export class TextLocalImportancePlots extends React.Component<ITextLocalImportan
|
|||
}
|
||||
const classNames = this.props.jointDataset.getModelClasses();
|
||||
const textExplanationDashboardData: ITextExplanationDashboardData = {
|
||||
baseValues: textFeatureImportances.baseValues,
|
||||
classNames,
|
||||
localExplanations: textFeatureImportances.importances,
|
||||
predictedY: textFeatureImportances.predictedY,
|
||||
prediction: textFeatureImportances.prediction,
|
||||
text: textFeatureImportances.text
|
||||
text: textFeatureImportances.text,
|
||||
trueY: textFeatureImportances.trueY
|
||||
};
|
||||
const dashboardProp: ITextExplanationViewProps = {
|
||||
dataSummary: textExplanationDashboardData,
|
||||
|
@ -67,7 +73,7 @@ export class TextLocalImportancePlots extends React.Component<ITextLocalImportan
|
|||
this.context.modelExplanationData?.precomputedExplanations
|
||||
?.textFeatureImportance?.[row[0]];
|
||||
if (!textFeatureImportance) {
|
||||
return { importances: [], prediction: [], text: [] };
|
||||
return { baseValues: [], importances: [], prediction: [], text: [] };
|
||||
}
|
||||
const text = textFeatureImportance?.text;
|
||||
const rowDict = this.props.jointDataset.getRow(row[0]);
|
||||
|
@ -78,10 +84,16 @@ export class TextLocalImportancePlots extends React.Component<ITextLocalImportan
|
|||
return rowDict[key];
|
||||
});
|
||||
const importances: number[][] = textFeatureImportance?.localExplanations;
|
||||
const baseValues = textFeatureImportance?.baseValues;
|
||||
const trueY = this.context.dataset.true_y[row[0]];
|
||||
const predictedY = this.context.dataset.predicted_y?.[row[0]];
|
||||
return {
|
||||
baseValues,
|
||||
importances,
|
||||
predictedY,
|
||||
prediction,
|
||||
text
|
||||
text,
|
||||
trueY
|
||||
};
|
||||
});
|
||||
return featureImportances;
|
||||
|
|
Загрузка…
Ссылка в новой задаче