fix notebook by inlining all scripts and data (#129)

Inline scripts and data so it will work in Jupyter and Azure Databricks.
extract common dashboard code to a separate dashboard class
This commit is contained in:
xuke444 2020-11-02 07:59:49 -08:00 коммит произвёл GitHub
Родитель 6c0d390985
Коммит 494e13a34c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
17 изменённых файлов: 349 добавлений и 192 удалений

Просмотреть файл

@ -7,8 +7,6 @@
<meta name="viewport" content="width=device-width, initial-scale=1" />
<link rel="icon" type="image/x-icon" href="favicon.ico" />
<script src="https://cdn.plot.ly/plotly-latest.js"></script>
<style type="text/css">
html,
body,

Просмотреть файл

@ -2,34 +2,17 @@
// Licensed under the MIT License.
import React from "react";
import { Route, Switch, RouteComponentProps } from "react-router-dom";
import { config } from "./config";
import { Fairness } from "./Fairness";
import { IAppConfig } from "./IAppConfig";
import { IFairnessRouteProps } from "./IFairnessRouteProps";
export interface IAppState {
config: IAppConfig;
}
export class App extends React.Component<unknown, IAppState> {
export class App extends React.Component {
public render(): React.ReactNode {
return (
<Switch>
<Route path={Fairness.route} render={this.renderFairness} exact />
</Switch>
);
}
public async componentDidMount(): Promise<void> {
const res = await (await fetch(new Request("/getconfig"))).json();
this.setState({ config: res });
}
private readonly renderFairness = (
props: RouteComponentProps<IFairnessRouteProps>
): React.ReactNode => {
if (!this.state?.config) {
return "Loading";
switch (config.dashboardType) {
case "Fairness":
return <Fairness />;
default:
return "Not Found";
}
return <Fairness {...props.match.params} {...this.state.config} />;
};
}
}

Просмотреть файл

@ -1,7 +1,6 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
import { generateRoute } from "@responsible-ai/core-ui";
import {
FairnessWizardV2,
IMetricRequest,
@ -9,62 +8,44 @@ import {
} from "@responsible-ai/fairness";
import React from "react";
import { IAppConfig } from "./IAppConfig";
import { IFairnessRouteProps, routeKey } from "./IFairnessRouteProps";
import { config } from "./config";
import { modelData } from "./modelData";
export interface IFairnessState {
fairnessConfig: any | undefined;
}
export type IFairnessProps = IFairnessRouteProps & IAppConfig;
export class Fairness extends React.Component<IFairnessProps, IFairnessState> {
public static route = `/fairness/model${generateRoute(routeKey)}`;
public async componentDidMount(): Promise<void> {
const res = await (
await fetch(new Request(`/fairness/getmodel/${this.props.model}`))
).json();
this.setState({ fairnessConfig: res });
}
export class Fairness extends React.Component {
public render(): React.ReactNode {
if (this.state?.fairnessConfig) {
return (
<FairnessWizardV2
dataSummary={{
classNames: this.state.fairnessConfig.classes,
featureNames: this.state.fairnessConfig.features
}}
testData={this.state.fairnessConfig.dataset}
predictedY={this.state.fairnessConfig.predicted_ys}
trueY={this.state.fairnessConfig.true_y}
modelNames={this.state.fairnessConfig.model_names}
precomputedMetrics={this.state.fairnessConfig.precomputedMetrics}
precomputedFeatureBins={
this.state.fairnessConfig.precomputedFeatureBins
}
customMetrics={this.state.fairnessConfig.customMetrics}
predictionType={this.state.fairnessConfig.predictionType}
supportedBinaryClassificationPerformanceKeys={
this.state.fairnessConfig.classification_methods
}
supportedRegressionPerformanceKeys={
this.state.fairnessConfig.regression_methods
}
supportedProbabilityPerformanceKeys={
this.state.fairnessConfig.probability_methods
}
locale={this.state.fairnessConfig.locale}
requestMetrics={this.requestMetrics}
/>
);
}
return "Loading";
return (
<FairnessWizardV2
dataSummary={{
classNames: modelData.classes,
featureNames: modelData.features
}}
testData={modelData.dataset}
predictedY={modelData.predicted_ys}
trueY={modelData.true_y}
modelNames={modelData.model_names}
precomputedMetrics={modelData.precomputedMetrics}
precomputedFeatureBins={modelData.precomputedFeatureBins}
customMetrics={modelData.customMetrics}
predictionType={modelData.predictionType}
supportedBinaryClassificationPerformanceKeys={
modelData.classification_methods
}
supportedRegressionPerformanceKeys={modelData.regression_methods}
supportedProbabilityPerformanceKeys={modelData.probability_methods}
locale={modelData.locale}
requestMetrics={config.hasCallback ? this.requestMetrics : undefined}
/>
);
}
private readonly requestMetrics = (
postData: IMetricRequest
): Promise<IMetricResponse> => {
return fetch(this.state.fairnessConfig.metricsUrl, {
return fetch(config.baseUrl + `/fairness/model/${config.id}/metrics`, {
body: JSON.stringify(postData),
headers: {
"Content-Type": "application/json"

Просмотреть файл

@ -1,7 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
export const routeKey = ["model"] as const;
export type IFairnessRouteProps = {
[key in typeof routeKey[number]]?: string;
};

Просмотреть файл

@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/**
 * Bootstrap configuration injected into the bundle by the hosting Python
 * process (its `load_widget_file` replaces the placeholder below with a
 * JSON payload before the page is served).
 */
export interface IAppConfig {
  // Which dashboard flavor to render.
  dashboardType: "Fairness";
  // Identifier of this dashboard instance.
  id: string;
  // Base URL of the backing service used for callback requests.
  baseUrl: string;
  // Whether a metrics callback endpoint is available.
  hasCallback: boolean;
  // Whether callback requests should send credentials.
  withCredentials: boolean;
}
// NOTE(review): the `${""}` split presumably keeps this source file itself
// from matching the host's string substitution — TODO confirm.
export const config: IAppConfig = JSON.parse(`__rai_config__${""}`);

Просмотреть файл

@ -1,6 +1,4 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
export interface IAppConfig {
localUrl: string;
}
export const modelData = JSON.parse(`__rai_model_data__${""}`);

Просмотреть файл

@ -6,10 +6,8 @@
<base href="/" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<link rel="icon" type="image/x-icon" href="favicon.ico" />
<script src="https://cdn.plot.ly/plotly-latest.js"></script>
</head>
<body>
<div id="root"></div>
<div id="__rai_app_id__"></div>
</body>
</html>

Просмотреть файл

@ -3,15 +3,12 @@
import React from "react";
import ReactDOM from "react-dom";
import { BrowserRouter } from "react-router-dom";
import { App } from "./app/App";
ReactDOM.render(
<React.StrictMode>
<BrowserRouter>
<App />
</BrowserRouter>
<App />
</React.StrictMode>,
document.querySelector("#root")
document.querySelector(`#${"__rai_app_id__"}`)
);

Просмотреть файл

@ -23,6 +23,7 @@
"format:check": "nx format:check",
"format:write": "nx format:write",
"help": "nx help",
"postinstall": "node ./scripts/postInstall.js",
"kill": "taskkill /IM:node.exe /F",
"lint": "cross-env NODE_OPTIONS=--max_old_space_size=4096 nx lint",
"lintall": "nx workspace-lint && cross-env NODE_OPTIONS=--max_old_space_size=4096 nx run-many --target=lint --all && yarn prettier . --check",

Просмотреть файл

@ -0,0 +1,130 @@
# Copyright (c) Microsoft Corporation
# Licensed under the MIT License.
"""Defines the dashboard class."""
# TODO: use environment_detector
# https://github.com/microsoft/responsible-ai-widgets/issues/92
from rai_core_flask import FlaskHelper # , environment_detector
from flask import Response
from IPython.display import display, HTML
import json
import os
from html.parser import HTMLParser
class InLineScript(HTMLParser):
    """HTML rewriter that inlines external ``<script src=...>`` tags.

    Feed the widget's index.html through the parser; the rewritten
    document accumulates in ``self.content``. Every ``<script>`` tag with
    a ``src`` attribute is replaced by an inline ``<script>`` whose body
    is the referenced file's content (loaded via
    ``Dashboard.load_widget_file``); all other markup passes through
    unchanged.

    :param id: Dashboard id forwarded to ``Dashboard.load_widget_file``
        so per-dashboard placeholders are substituted into inlined
        scripts.
    """

    def __init__(self, id):
        HTMLParser.__init__(self)
        self.content = ""
        self.id = id

    def handle_starttag(self, tag, attrs):
        if tag == "script":
            # Look for an external script reference.
            src = None
            for att in attrs:
                if att[0] == "src":
                    src = att[1]
                    break
            if src is not None:
                # Replace the reference with the file's content; the
                # matching </script> is emitted by handle_endtag.
                content = Dashboard.load_widget_file(src, self.id)
                self.content += f'<script>\r\n{content}\r\n'
                return
        self.content += self.get_starttag_text()

    def handle_startendtag(self, tag, attrs):
        # Fix: emit self-closing tags (e.g. <meta ... />) verbatim; the
        # default expansion would append a bogus closing tag.
        self.content += self.get_starttag_text()

    def handle_endtag(self, tag):
        self.content += f'</{tag}>'

    def handle_data(self, data):
        self.content += data

    def handle_comment(self, data):
        # Fix: the base class drops comments, stripping them from the
        # rewritten page.
        self.content += f'<!--{data}-->'

    def handle_decl(self, decl):
        # Fix: preserve declarations such as <!DOCTYPE html>, which the
        # base class would otherwise silently drop (quirks-mode risk).
        self.content += f'<!{decl}>'
class Dashboard(object):
    """Base dashboard class: serves a widget UI from a local Flask
    service and renders it inline in the notebook.

    The widget's index.html is loaded with all scripts inlined (see
    ``InLineScript``) and its placeholders substituted, so the resulting
    HTML is self-contained and works in environments such as Jupyter and
    Azure Databricks.

    :param dashboard_type: Identifier of the concrete dashboard flavor,
        e.g. ``"Fairness"``; forwarded to the frontend in the config
        payload.
    :type dashboard_type: str
    :param model_data: JSON-serializable payload embedded into the page
        for the frontend to consume.
    :param port: Optional port for the backing Flask service.
    :type port: int
    """

    # Per-dashboard state, keyed by the stringified instance id.
    model_data = {}
    config = {}
    model_count = 0
    _service = None  # shared FlaskHelper instance, created lazily

    @FlaskHelper.app.route('/')
    def list():
        # Plain-text list of known dashboard ids.
        return ','.join(Dashboard.config.keys())

    @FlaskHelper.app.route('/<int:id>')
    def visual(id):
        if str(id) in Dashboard.config:
            return Dashboard.load_index(str(id))
        else:
            return Response("Unknown model id.", status=404)

    def __init__(
            self, *,
            dashboard_type,
            model_data,
            port=None):
        """Initialize the Dashboard."""
        # Fix: the original tested the builtin `type`, which is never
        # None; the intent was clearly to validate dashboard_type.
        if model_data is None or dashboard_type is None:
            raise ValueError("Required parameters not provided")
        if Dashboard._service is None:
            try:
                Dashboard._service = FlaskHelper(port=port)
            except Exception as e:
                # Leave the class in a clean state so a retry can
                # attempt to start the service again.
                Dashboard._service = None
                raise e
        Dashboard.model_count += 1
        self.id = str(Dashboard.model_count)
        Dashboard.config[self.id] = {
            "dashboardType": dashboard_type,
            "id": self.id,
            "baseUrl": Dashboard._service.env.base_url,
            'withCredentials': False,
            'hasCallback': True
        }
        Dashboard.model_data[self.id] = model_data
        html = Dashboard.load_index(self.id)
        # TODO https://github.com/microsoft/responsible-ai-widgets/issues/92
        # FairnessDashboard._service.env.display(html)
        display(HTML(html))

    def get_widget_path(path):
        # Resolve `path` relative to the bundled "widget" directory.
        script_path = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(script_path, "widget", path)

    def load_index(id):
        # Load index.html for dashboard `id` with every <script src=...>
        # replaced by the referenced script's inlined content.
        index = Dashboard.load_widget_file("index.html", id)
        parser = InLineScript(id)
        parser.feed(index)
        return parser.content

    def load_widget_file(path, id):
        # Read a widget file and substitute the per-dashboard
        # placeholders (element id, config JSON, model data JSON).
        js_path = Dashboard.get_widget_path(path)
        with open(js_path, "r", encoding="utf-8") as f:
            content = f.read()
        content = content.replace("__rai_app_id__", f'rai_widget_{id}')
        content = content.replace(
            "__rai_config__", json.dumps(Dashboard.config[id]))
        content = content.replace(
            "__rai_model_data__", json.dumps(Dashboard.model_data[id]))
        return content

Просмотреть файл

@ -5,20 +5,17 @@
# TODO: use environment_detector
# https://github.com/microsoft/responsible-ai-widgets/issues/92
from .dashboard import Dashboard
from rai_core_flask import FlaskHelper # , environment_detector
from .fairness_metric_calculation import FairnessMetricModule
from flask import jsonify, request, Response
from IPython.display import display, HTML
from jinja2 import Environment, PackageLoader
import json
from flask import jsonify, request
import numpy as np
import os
import pandas as pd
from scipy.sparse import issparse
class FairnessDashboard(object):
class FairnessDashboard(Dashboard):
"""The dashboard class, wraps the dashboard component.
:param sensitive_features: A matrix of feature vector examples
@ -35,59 +32,21 @@ class FairnessDashboard(object):
:param sensitive_feature_names: Feature names
:type sensitive_feature_names: numpy.array or list[]
"""
env = Environment(loader=PackageLoader(__name__, 'widget'))
_dashboard_js = None
fairness_inputs = {}
model_count = 0
_service = None
@FlaskHelper.app.route('/widget/<path:path>')
def widget_static(path):
mimetypes = {
".css": "text/css",
".html": "text/html",
".js": "application/javascript",
}
ext = os.path.splitext(path)[1]
mimetype = mimetypes.get(ext, "application/octet-stream")
return Response(load_widget_file(path), mimetype=mimetype)
@FlaskHelper.app.route('/getconfig')
def get_config():
burl = FairnessDashboard._service.env.base_url
ct = FairnessDashboard.model_count
return {
"local_url": f"{burl}/fairness/model/{ct}"
}
@FlaskHelper.app.route('/fairness')
def list():
return "No global list view supported at this time."
@FlaskHelper.app.route('/fairness/model/<id>')
def fairness_visual(id):
return load_widget_file("index.html")
@FlaskHelper.app.route('/fairness/getmodel/<id>')
def fairness_get_model(id):
if id in FairnessDashboard.fairness_inputs:
model_data = json.dumps(FairnessDashboard.fairness_inputs[id])
return model_data
else:
return Response("Unknown model id.", status=404)
fairness_metrics_module = {}
@FlaskHelper.app.route('/fairness/model/<id>/metrics', methods=['POST'])
def fairness_metrics_calculation(id):
try:
data = request.get_json(force=True)
if id in FairnessDashboard.fairness_inputs:
data.update(FairnessDashboard.fairness_inputs[id])
if id in FairnessDashboard.model_data:
data.update(FairnessDashboard.model_data[id])
if type(data["binVector"][0]) == np.int32:
data['binVector'] = [
str(bin_) for bin_ in data['binVector']]
method = FairnessDashboard.fairness_metrics_module. \
method = FairnessDashboard.fairness_metrics_module[id]. \
_metric_methods.get(data["metricKey"]).get("function")
prediction = method(
data['true_y'],
@ -121,7 +80,7 @@ class FairnessDashboard(object):
fairness_metric_mapping=None):
"""Initialize the fairness Dashboard."""
FairnessDashboard.fairness_metrics_module = FairnessMetricModule(
metrics_module = FairnessMetricModule(
module_name=fairness_metric_module,
mapping=fairness_metric_mapping)
@ -154,14 +113,11 @@ class FairnessDashboard(object):
"predicted_ys": self._y_pred,
"dataset": dataset,
"classification_methods":
FairnessDashboard.fairness_metrics_module.
classification_methods,
metrics_module.classification_methods,
"regression_methods":
FairnessDashboard.fairness_metrics_module.
regression_methods,
metrics_module.regression_methods,
"probability_methods":
FairnessDashboard.fairness_metrics_module.
probability_methods,
metrics_module.probability_methods,
}
if model_names is not None:
@ -178,32 +134,11 @@ class FairnessDashboard(object):
"ignoring")
fairness_input["features"] = sensitive_feature_names
if FairnessDashboard._service is None:
try:
FairnessDashboard._service = FlaskHelper(port=port)
except Exception as e:
FairnessDashboard._service = None
raise e
Dashboard.__init__(self, dashboard_type="Fairness",
model_data=fairness_input,
port=port)
FairnessDashboard.model_count += 1
model_count = FairnessDashboard.model_count
burl = FairnessDashboard._service.env.base_url
local_url = f"{burl}/fairness/model/{model_count}"
metrics_url = f"{local_url}/metrics"
fairness_input['metricsUrl'] = metrics_url
# TODO
fairness_input['withCredentials'] = False
FairnessDashboard.fairness_inputs[str(model_count)] = fairness_input
html = load_widget_file("index.html")
# TODO https://github.com/microsoft/responsible-ai-widgets/issues/92
# FairnessDashboard._service.env.display(html)
display(HTML(html))
FairnessDashboard.fairness_metrics_module[self.id] = metrics_module
def _sanitize_data_shape(self, dataset):
result = self._convert_to_list(dataset)
@ -224,14 +159,3 @@ class FairnessDashboard(object):
if (isinstance(array, np.ndarray)):
return array.tolist()
return array
def get_widget_path(path):
script_path = os.path.dirname(os.path.abspath(__file__))
return os.path.join(script_path, "widget", path)
def load_widget_file(path):
js_path = get_widget_path(path)
with open(js_path, "r", encoding="utf-8") as f:
return f.read()

Просмотреть файл

@ -0,0 +1,114 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from raiwidgets import FairnessDashboard\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"from sklearn.preprocessing import LabelEncoder, StandardScaler\n",
"from sklearn.linear_model import LogisticRegression\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import fetch_openml\n",
"data = fetch_openml(data_id=1590, as_frame=True)\n",
"X_raw = data.data\n",
"Y = (data.target == '>50K') * 1\n",
"X_raw"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"A = X_raw[\"sex\"]\n",
"X = X_raw.drop(labels=['sex'], axis=1)\n",
"X = pd.get_dummies(X)\n",
"\n",
"sc = StandardScaler()\n",
"X_scaled = sc.fit_transform(X)\n",
"X_scaled = pd.DataFrame(X_scaled, columns=X.columns)\n",
"\n",
"le = LabelEncoder()\n",
"Y = le.fit_transform(Y)\n",
"\n",
"\n",
"X_train,\\\n",
" X_test,\\\n",
" Y_train,\\\n",
" Y_test,\\\n",
" A_train,\\\n",
" A_test = train_test_split(X_scaled,\n",
" Y,\n",
" A,\n",
" test_size=0.2,\n",
" random_state=0,\n",
" stratify=Y)\n",
"\n",
"\n",
"X_train = X_train.reset_index(drop=True)\n",
"A_train = A_train.reset_index(drop=True)\n",
"X_test = X_test.reset_index(drop=True)\n",
"A_test = A_test.reset_index(drop=True)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"unmitigated_predictor = LogisticRegression(\n",
" solver='liblinear', fit_intercept=True)\n",
"\n",
"unmitigated_predictor.fit(X_train, Y_train)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"FairnessDashboard(sensitive_features=A_test, sensitive_feature_names=['sex'],\n",
" y_true=Y_test,\n",
" y_pred={\n",
" \"unmitigated\": unmitigated_predictor.predict(X_test)\n",
"})"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

Просмотреть файл

23
scripts/fixPlotly.js Normal file
Просмотреть файл

@ -0,0 +1,23 @@
const plotlyFolder = "./node_modules/plotly.js/";
const jsPath = plotlyFolder + "dist/plotly.js";
const packagePath = plotlyFolder + "package.json";
const jsReg = /(define\(d3\).*)\n(\}\.apply\(self\)|\}\(\));/;
const packageReg = /\"main\":\s\"(.\/lib\/index\.js|\.\/dist\/plotly.js)",/;
const fs = require("fs");
function replaceFile(path, reg, to) {
if (!fs.existsSync(path)) {
throw new Error(`${path} does not exist.`);
}
let content = fs.readFileSync(path, { encoding: "utf-8" });
if (!reg.test(content)) {
throw new Error(`${path} has wrong content`);
}
content = content.replace(reg, to);
fs.writeFileSync(path, content, { encoding: "utf-8" });
}
module.exports = function () {
replaceFile(jsPath, jsReg, "$1\n}.apply(self);");
replaceFile(packagePath, packageReg, '"main": "./dist/plotly.js",');
};

2
scripts/postInstall.js Normal file
Просмотреть файл

@ -0,0 +1,2 @@
// Post-install hook (wired up via the "postinstall" npm script): applies
// the plotly.js package patch defined in ./fixPlotly.
const fixPlotly = require("./fixPlotly");
fixPlotly();

Просмотреть файл

@ -13,8 +13,11 @@ module.exports = (config) => {
tls: "empty",
child_process: "empty"
};
config.externals = config.externals || {};
config.externals["plotly.js"] = "Plotly";
if (process.env.debug) {
require("fs-extra").writeJSONSync("./webpack.json", config, {
spaces: 2
});
}
return config;
};

Просмотреть файл

@ -5800,14 +5800,6 @@ doctrine@^3.0.0:
dependencies:
esutils "^2.0.2"
dom-helpers@^5.0.1:
version "5.2.0"
resolved "https://registry.yarnpkg.com/dom-helpers/-/dom-helpers-5.2.0.tgz#57fd054c5f8f34c52a3eeffdb7e7e93cd357d95b"
integrity sha512-Ru5o9+V8CpunKnz5LGgWXkmrH/20cGKwcHwS4m73zIvs54CN9epEmT/HLqFJW3kXpakAFkEdzgy1hzlJe3E4OQ==
dependencies:
"@babel/runtime" "^7.8.7"
csstype "^3.0.2"
document-register-element@1.13.1:
version "1.13.1"
resolved "https://registry.yarnpkg.com/document-register-element/-/document-register-element-1.13.1.tgz#dad8cb7be38e04ee3f56842e6cf81af46c1249ba"
@ -5820,6 +5812,14 @@ dom-accessibility-api@^0.5.1:
resolved "https://registry.yarnpkg.com/dom-accessibility-api/-/dom-accessibility-api-0.5.4.tgz#b06d059cdd4a4ad9a79275f9d414a5c126241166"
integrity sha512-TvrjBckDy2c6v6RLxPv5QXOnU+SmF9nBII5621Ve5fu6Z/BDrENurBEvlC1f44lKEUVqOpK4w9E5Idc5/EgkLQ==
dom-helpers@^5.0.1:
version "5.2.0"
resolved "https://registry.yarnpkg.com/dom-helpers/-/dom-helpers-5.2.0.tgz#57fd054c5f8f34c52a3eeffdb7e7e93cd357d95b"
integrity sha512-Ru5o9+V8CpunKnz5LGgWXkmrH/20cGKwcHwS4m73zIvs54CN9epEmT/HLqFJW3kXpakAFkEdzgy1hzlJe3E4OQ==
dependencies:
"@babel/runtime" "^7.8.7"
csstype "^3.0.2"
dom-serializer@0:
version "0.2.2"
resolved "https://registry.yarnpkg.com/dom-serializer/-/dom-serializer-0.2.2.tgz#1afb81f533717175d478655debc5e332d9f9bb51"