Dev harness style fixes (#63)
* Fix CSS to align venn diagram box properly. Structure dev harness HTML to resemble Jupyter notebok. * Remove some unnecessary experimental code * Add image loading to dev harness backend * Rename get_instance_data_by_id to get_instance_image_by_id Co-authored-by: Nicholas King <v-nicki@microsoft.com>
This commit is contained in:
Родитель
42b7c18413
Коммит
e43fc5b296
|
@ -3,6 +3,8 @@
|
|||
|
||||
import os
|
||||
import copy
|
||||
import io
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
@ -13,6 +15,8 @@ from backwardcompatibilityml import loss as bcloss
|
|||
from backwardcompatibilityml.helpers import training
|
||||
from backwardcompatibilityml.helpers.models import LogisticRegression, MLPClassifier
|
||||
from backwardcompatibilityml.widget.compatibility_analysis import CompatibilityAnalysis
|
||||
from flask import send_file
|
||||
from PIL import Image
|
||||
from rai_core_flask.flask_helper import FlaskHelper
|
||||
|
||||
|
||||
|
@ -175,13 +179,34 @@ def mnist_sweep():
|
|||
|
||||
h2 = copy.deepcopy(h1)
|
||||
|
||||
def unnormalize(img):
|
||||
img = img / 2 + 0.5
|
||||
return img
|
||||
|
||||
def get_instance_image(instance_id):
|
||||
img_bytes = io.BytesIO()
|
||||
data = np.reshape(
|
||||
np.uint8(np.transpose((unnormalize(dataset[instance_id][1])), (1, 2, 0)).numpy() * 255),
|
||||
(28, 28))
|
||||
img = Image.fromarray(data)
|
||||
img.save(img_bytes, format="PNG")
|
||||
img_bytes.seek(0)
|
||||
return send_file(img_bytes, mimetype='image/png')
|
||||
|
||||
def get_instance_label(instance_id):
|
||||
label = data_loader[instance_id][2].item()
|
||||
return {"label": label}
|
||||
|
||||
CompatibilityAnalysis(sweeps_folder, n_epochs, h1, h2, train_loader, test_loader,
|
||||
batch_size_train, batch_size_test,
|
||||
OptimizerClass=optim.SGD,
|
||||
optimizer_kwargs={"lr": learning_rate, "momentum": momentum},
|
||||
NewErrorLossClass=bcloss.BCCrossEntropyLoss,
|
||||
StrictImitationLossClass=bcloss.StrictImitationCrossEntropyLoss,
|
||||
lambda_c_stepsize=0.25, device="cuda")
|
||||
lambda_c_stepsize=0.25,
|
||||
get_instance_image_by_id=get_instance_image,
|
||||
get_instance_metadata=get_instance_label,
|
||||
device="cuda")
|
||||
|
||||
|
||||
mnist_sweep()
|
||||
|
|
|
@ -6,36 +6,50 @@
|
|||
<html>
|
||||
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
|
||||
<link rel="stylesheet" href="css/bootstrap-tour.min.css">
|
||||
<link rel="stylesheet" href="css/codemirror.css">
|
||||
<link rel="stylesheet" href="css/custom.css">
|
||||
<link rel="stylesheet" href="css/jquery.typeahead.min.css">
|
||||
<link rel="stylesheet" href="css/jquery-ui.min.css">
|
||||
<link rel="stylesheet" href="css/override.css">
|
||||
<link rel="stylesheet" href="css/style.min.css">
|
||||
<title>BCML widget</title>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
|
||||
<link rel="stylesheet" href="css/bootstrap-tour.min.css">
|
||||
<link rel="stylesheet" href="css/codemirror.css">
|
||||
<link rel="stylesheet" href="css/custom.css">
|
||||
<link rel="stylesheet" href="css/jquery.typeahead.min.css">
|
||||
<link rel="stylesheet" href="css/jquery-ui.min.css">
|
||||
<link rel="stylesheet" href="css/override.css">
|
||||
<link rel="stylesheet" href="css/style.min.css">
|
||||
<title>BCML widget</title>
|
||||
</head>
|
||||
|
||||
<div class="container">
|
||||
<div id="widget">
|
||||
</div>
|
||||
<script>
|
||||
var data = null;
|
||||
window.API_SERVICE_ENVIRONMENT = {environment_type: "local", base_url: "", port: 5000};
|
||||
window.WIDGET_STATE = {
|
||||
data: data,
|
||||
sweepStatus: null,
|
||||
selectedDataPoint: null,
|
||||
training: true,
|
||||
testing: true,
|
||||
newError: true,
|
||||
strictImitation: true,
|
||||
error: null,
|
||||
loading: false
|
||||
};
|
||||
</script>
|
||||
</div>
|
||||
<body class="notebook_app command_mode ms-Fabric--isFocusHidden">
|
||||
<div id="notebook">
|
||||
<div class="container" id="notebook-container">
|
||||
<div class="cell code_cell rendered selected">
|
||||
<div class="output_wrapper">
|
||||
<div class="output">
|
||||
<div class="output_area">
|
||||
<div class="prompt"></div>
|
||||
<div class="output_subarea output_html rendered_html">
|
||||
<div id="widget">
|
||||
</div>
|
||||
<script>
|
||||
var data = null;
|
||||
window.API_SERVICE_ENVIRONMENT = { environment_type: "local", base_url: "", port: 5000 };
|
||||
window.WIDGET_STATE = {
|
||||
data: data,
|
||||
sweepStatus: null,
|
||||
selectedDataPoint: null,
|
||||
training: true,
|
||||
testing: true,
|
||||
newError: true,
|
||||
strictImitation: true,
|
||||
error: null,
|
||||
loading: false
|
||||
};
|
||||
</script>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
</html>
|
|
@ -64,7 +64,7 @@ label {
|
|||
padding-right:20px;
|
||||
padding-top: 20px;
|
||||
padding-bottom: 39px;
|
||||
bottom: 27px;
|
||||
vertical-align: top;
|
||||
}
|
||||
|
||||
.plot-distribution {
|
||||
|
@ -225,4 +225,4 @@ label {
|
|||
.highlighted-bar {
|
||||
stroke-width: 2;
|
||||
stroke: rgb(0,0,0);
|
||||
}
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче