ml4f/pxt/extension.ts

474 lines
14 KiB
TypeScript

import { compileModelAndFullValidate, loadFlatJSONModel } from "../src/driver"
import { setBackend, loadLayersModel, SymbolicTensor } from "@tensorflow/tfjs"
/**
 * Tells whether this page is running inside an iframe.
 *
 * Returns false when there is no `window` global at all. When comparing
 * `window.self` with `window.top` throws, the page is reported as framed.
 */
export function inIFrame() {
    // No window object: nothing can be framing us.
    if (typeof window === "undefined") return false
    try {
        return window.self !== window.top
    } catch (e) {
        // The comparison itself failed while a window exists.
        return true
    }
}
// Event names passed to MakeCodeEditorExtensionClient.emit().

// Emitted whenever client state may have changed (connection, visibility, responses).
const CHANGE = "change"
// Emitted together with the payload of a code read.
export const READ = "read"
// Emitted with the body of an "extmessagepacket" event from the host.
export const MESSAGE_PACKET = "messagepacket"
// Emitted when the host reports the extension as hidden ("exthidden").
const HIDDEN = "hidden"
// Emitted when the host reports the extension as shown ("extshown").
const SHOWN = "shown"
//const SENDER = "jacdac-editor-extension"
// Emitted when the host's "extinit" event is received.
const CONNECT = "connect"
/**
 * Result type of MakeCodeEditorExtensionClient.read(), i.e. the response
 * to an "extreadcode" request. All fields are optional.
 */
export interface ReadResponse {
// NOTE(review): the fields mirror the code/json/jres values sent by write();
// presumably source text, JSON data and JRES resources stored for this
// extension — confirm against the host protocol.
code?: string
json?: string
jres?: string
}
/** Dictionary keyed by arbitrary strings, with values of type T. */
export interface SMap<T> {
[index: string]: T
}
// Generated-code template for a stand-in `_sample()`: it loops forever
// showing "_sample() missing", and its trailing return statement carries the
// `@samples@` placeholder token. start() selects this template when the number
// of elements per sample is neither 1 nor 3.
const fakeSample = `
export function _sample() {
while (true) {
basic.showString("_sample() missing")
}
return [@samples@]
}`
// Generated-code template for a `_sample()` that returns the X, Y and Z
// acceleration readings, each divided by 1024. start() selects this template
// when a sample has exactly 3 elements.
const accelSample = `
export function _sample() {
return [
input.acceleration(Dimension.X) / 1024,
input.acceleration(Dimension.Y) / 1024,
input.acceleration(Dimension.Z) / 1024
]
}`
// Generated-code template for a `_sample()` that returns a one-element array:
// 1 while the selected button is pressed, 0 otherwise. It also emits a
// `setButton` block function to change the button (Button.A by default).
// start() selects this template when a sample has exactly 1 element.
const buttonSample = `
let _button = Button.A
export function _sample() {
return [input.buttonIsPressed(_button) ? 1 : 0]
}
//% block="set ml button %button" blockId="ml_set_button"
export function setButton(button: Button) {
_button = button
}
`
/**
 * Client for the "pxtpkgext" message protocol spoken with the parent window
 * (the hosting editor) via `window.postMessage`.
 *
 * On construction it starts listening for "message" events and sends an
 * "extinit" request. Requests carry a generated id; the matching response
 * settles the promise returned by the request method. When the page is not in
 * an iframe there is no extension id and requests resolve immediately with
 * `undefined` without posting anything.
 */
export class MakeCodeEditorExtensionClient {
    /** Requests that were posted and are still waiting for a response, by request id. */
    private readonly pendingCommands: {
        [key: string]: {
            action: string
            resolve: (resp: any) => void
            reject: (e: any) => void
        }
    } = {}
    /** URL hash (without the leading "#") when framed; undefined otherwise. */
    private readonly extensionId: string | undefined = inIFrame()
        ? window.location.hash.slice(1)
        : undefined
    private _target: any // full apptarget
    private _connected = false
    private _visible = false
    /** Counter used to build unique request ids. */
    private nextRequestId = 1

    constructor() {
        this.handleMessage = this.handleMessage.bind(this)
        window.addEventListener("message", this.handleMessage, false)
        // notify parent that we're ready; a rejected request must not end up
        // as an unhandled promise rejection
        this.init().catch(e => this.reportFailure("init", e))
    }

    /** Event hook; currently it only logs the event id and argument. */
    emit(id: string, arg?: any) {
        console.log("EMIT", id, { arg })
    }

    /** Logs a message with the "ML4F-PXT" prefix. */
    log(msg: string) {
        console.log(`ML4F-PXT: ${msg}`)
    }

    /** Target object received with the host's "extinit" event. */
    get target() {
        return this._target
    }

    /** True once an "extinit" event or response has been received. */
    get connected() {
        return this._connected
    }

    /** Last visibility reported by the host ("extshown" / "exthidden"). */
    get visible() {
        return this._visible
    }

    /** Updates the visibility flag, emitting CHANGE only when it actually flips. */
    private setVisible(vis: boolean) {
        if (this._visible !== vis) {
            this._visible = vis
            this.emit(CHANGE)
        }
    }

    /** Logs a failed background operation (one nobody awaits). */
    private reportFailure(what: string, e: unknown) {
        console.error(`ML4F-PXT: ${what} failed`, e)
    }

    /**
     * Registers a pending request and builds the message to post for it.
     *
     * @param resolve called with the response payload on success
     * @param reject called with the response payload on failure
     * @param action protocol action name (e.g. "extreadcode")
     * @param body optional request payload
     * @returns the message object to post to the parent window
     */
    private mkRequest(
        resolve: (resp: any) => void,
        reject: (e: any) => void,
        action: string,
        body?: any
    ): any {
        const id = "ml_" + this.nextRequestId++
        this.pendingCommands[id] = { action, resolve, reject }
        return {
            type: "pxtpkgext",
            action,
            extId: this.extensionId,
            response: true,
            id,
            body,
        }
    }

    /**
     * Posts a request to the parent window.
     *
     * @returns a promise settled by the matching response; resolves with
     * `undefined` right away when there is no extension id.
     */
    private sendRequest<T>(action: string, body?: any): Promise<T> {
        this.log(`send ${action}`)
        if (!this.extensionId) return Promise.resolve(undefined)
        return new Promise((resolve, reject) => {
            const msg = this.mkRequest(resolve, reject, action, body)
            window.parent.postMessage(msg, "*")
        })
    }

    /** "message" listener: ignores foreign messages, then dispatches events and responses. */
    private handleMessage(ev: any) {
        const msg = ev.data
        if (msg?.type !== "pxtpkgext") return
        // Messages without an id are host-initiated events; the rest are
        // responses to requests sent by this client.
        if (!msg.id) this.handleEvent(msg)
        else this.handleResponse(msg)
    }

    /** Handles a host-initiated event message. */
    private handleEvent(msg: any) {
        switch (msg.event) {
            case "extinit":
                this.log(`init`)
                this._target = msg.target
                this._connected = true
                this.emit(CONNECT)
                this.emit(CHANGE)
                break
            case "extloaded":
                this.log(`loaded`)
                break
            case "extshown":
                this.setVisible(true)
                // fire-and-forget, but never leave a rejection unhandled
                this.refresh().catch(e => this.reportFailure("refresh", e))
                this.emit(SHOWN)
                this.emit(CHANGE)
                break
            case "exthidden":
                this.setVisible(false)
                this.emit(HIDDEN)
                this.emit(CHANGE)
                break
            case "extdatastream":
                this.emit("datastream", true)
                break
            case "extconsole":
                this.emit("console", msg.body)
                break
            case "extmessagepacket":
                this.emit(MESSAGE_PACKET, msg.body)
                break
            default:
                console.debug("Unhandled event", msg)
        }
    }

    /** Settles the pending request matching a response and raises the related events. */
    private handleResponse(msg: any) {
        const { action, resolve, reject } =
            this.pendingCommands[msg.id] || ({} as any)
        delete this.pendingCommands[msg.id]
        if (msg.success && resolve) resolve(msg.resp)
        else if (!msg.success && reject) reject(msg.resp)
        // raise event as well
        switch (action) {
            case "extinit":
                this._connected = true
                // same event name as the one raised for the "extinit" event
                this.emit(CONNECT)
                this.emit(CHANGE)
                break
            case "extusercode":
                this.emit("readuser", msg.resp)
                this.emit(CHANGE)
                break
            case "extreadcode":
                this.emit(READ, msg.resp)
                this.emit(CHANGE)
                break
            case "extwritecode":
                this.emit("written", undefined)
                break
        }
    }

    /** Sends "extinit", then performs a first read. */
    private async init() {
        this.log(`initializing`)
        await this.sendRequest<void>("extinit")
        this.log(`connected`)
        await this.refresh()
    }

    /** Re-reads the extension data; the result is delivered through the READ event. */
    private async refresh() {
        this.log(`refresh`)
        await this.read()
    }

    /**
     * Reads the data stored for this extension ("extreadcode").
     *
     * @returns the host's response; an empty object (also emitted as READ)
     * when there is no extension id.
     */
    async read(): Promise<ReadResponse> {
        if (!this.extensionId) {
            const r: ReadResponse = {}
            this.emit(READ, r)
            return r
        } else {
            const resp: ReadResponse = await this.sendRequest("extreadcode")
            return resp
        }
    }

    /** Sends an "extusercode" request; the response is raised as a "readuser" event. */
    async readUser() {
        await this.sendRequest("extusercode")
    }

    /**
     * Sends an "extwritecode" request with the given payload. Empty strings
     * are sent as `undefined`.
     *
     * Without an extension id nothing is sent; only "written" is emitted.
     */
    async write(
        code: string,
        json?: string,
        jres?: string,
        dependencies?: SMap<string>
    ): Promise<void> {
        if (!this.extensionId) {
            // No host to write to: just signal completion
            this.emit("written", undefined)
        } else {
            await this.sendRequest<void>("extwritecode", {
                code: code || undefined,
                json: json || undefined,
                jres: jres || undefined,
                dependencies,
            })
        }
    }

    /** Sends an "extquerypermission" request. */
    async queryPermission() {
        await this.sendRequest("extquerypermission")
    }

    /** Sends an "extrequestpermission" request carrying the `console` flag. */
    async requestPermission(console: boolean) {
        await this.sendRequest("extrequestpermission", {
            console,
        })
    }

    /** Sends an "extdatastream" request carrying the `console` flag. */
    async dataStreamConsole(console: boolean) {
        await this.sendRequest("extdatastream", {
            console,
        })
    }

    /** Sends an "extdatastream" request carrying the `messages` flag. */
    async dataStreamMessages(messages: boolean) {
        await this.sendRequest("extdatastream", {
            messages,
        })
    }
}
/**
 * Model file format accepted by start(): the dropped file's text is parsed as
 * JSON into this shape and handed to loadFlatJSONModel().
 */
export interface FlatJSONModel {
// Used as the model name in generated code; the file name is used when empty.
name: string
inputTypes: string[] // ["x","y","z"]; ["pressure"]
// Class labels; start() truncates or pads them to the model's output size.
labels: string[]
modelJSON: unknown
// start() falls back to 100 when this is missing or 0.
inputInterval: number // ms
weights: number[] // UInt32Array (little endian)
}
/**
 * Builds the extension page and wires it up.
 *
 * The page shows a status line, a drop area and a "Use float16 type"
 * checkbox. Dropping a model file (JSON in the FlatJSONModel format) loads
 * and compiles the model, generates the MakeCode source wrapping it, and
 * writes that source through a MakeCodeEditorExtensionClient.
 */
export async function start() {
    setBackend("cpu")
    // Compilation options, toggled by the checkboxes created with addCheckbox().
    const options: SMap<boolean> = {
        f16: true,
    }
    const pxtClient = new MakeCodeEditorExtensionClient()

    const maindiv = document.createElement("div")
    maindiv.style.background = "white"
    document.body.appendChild(maindiv)

    const status = div("")
    maindiv.append(status)
    setStatus("waiting for model file")

    const d = div("Drop TF.JS model file here")
    d.style.padding = "2em"
    d.style.margin = "1em 0em"
    d.style.border = "1px dotted gray"
    maindiv.append(d)

    addCheckbox("f16", "Use float16 type")

    // The whole main div accepts drops, not only the dotted area.
    const dropbox = maindiv
    dropbox.addEventListener("dragenter", stopEv, false)
    dropbox.addEventListener("dragover", stopEv, false)
    dropbox.addEventListener("drop", onDrop, false)

    /** Reads the dropped file as text, parses it as JSON and compiles it. */
    function onDrop(ev: DragEvent) {
        setStatus("reading model")
        stopEv(ev)
        // A drop may carry no file at all (e.g. dragged text).
        const file = ev.dataTransfer?.files.item(0)
        if (!file) {
            setError("no file dropped")
            return
        }
        const reader = new FileReader()
        reader.onload = async () => {
            try {
                const mod: FlatJSONModel = JSON.parse(reader.result as string)
                await compileModel(mod, file.name)
            } catch (e) {
                // Anything can be thrown; normalize before reading stack/message.
                const err = e instanceof Error ? e : new Error(String(e))
                console.error(err.stack)
                setError(err.message)
            }
        }
        reader.onerror = () => setError("cannot read " + file.name)
        reader.readAsText(file)
    }

    /** Product of all non-null dimensions of a shape (1 for an empty shape). */
    function shapeElements(shape: ReadonlyArray<number | null>) {
        let res = 1
        for (const s of shape) if (s != null) res *= s
        return res
    }

    /** Upper-cases the first character and each character following spaces, dropping those spaces. */
    function toCamelCase(name: string) {
        return name.replace(/(^|( +))(.)/g, (_0, _1, _2, l) => l.toUpperCase())
    }

    /**
     * Compiles the model and writes the generated MakeCode source.
     *
     * @param mod parsed model file
     * @param fileName name of the dropped file, used when the model has no name
     */
    async function compileModel(mod: FlatJSONModel, fileName: string) {
        const name = mod.name || fileName
        const ma = loadFlatJSONModel(mod)
        const m = await loadLayersModel({ load: () => Promise.resolve(ma) })
        const inpTen = m.getInputAt(0) as SymbolicTensor
        const numClasses = shapeElements(
            (m.getOutputAt(0) as SymbolicTensor).shape
        )

        // Exactly one label per output class: drop extras, invent missing ones.
        const labels = (mod.labels || []).slice()
        while (labels.length > numClasses) labels.pop()
        while (labels.length < numClasses) labels.push("class " + labels.length)

        const inputShape = inpTen.shape
        const samplingPeriod = mod.inputInterval || 100

        setStatus("compiling...") // can't see that...
        const res = await compileModelAndFullValidate(m, {
            verbose: false,
            includeTest: true,
            float16weights: options.f16,
            optimize: true,
        })
        setStatus("compiled!")

        // First non-null input dimension is the window length; the rest
        // multiplied together give the size of one sample.
        const shape2 = inputShape.filter(v => v != null)
        const samplesInWindow = shape2.shift()
        const elementsInSample = shapeElements(shape2)
        const samplingInterval = Math.round(samplingPeriod)
        const samplesOverlap = Math.max(samplesInWindow >> 2, 1)

        let code =
            `// model: ${name}; input: ${JSON.stringify(
                inputShape
            )}; sampling at: ${samplingPeriod}ms\n` +
            `// ${res.memInfo}\n` +
            `// ${res.timeInfo}\n`

        // One enum member per label, numbered in label order.
        code += "const enum MLEvent {\n"
        labels.forEach((rawLabel, idx) => {
            const lbl = rawLabel.replace(/_/g, " ")
            code += ` //% block="${lbl}"\n`
            code += ` ${toCamelCase(lbl)} = ${idx},\n`
        })
        code += `}\n\n`

        code += `namespace ml {\n`
        code += `
let _classifier: Classifier
export function classifier() {
if (_classifier) return _classifier
_classifier = new Classifier(input => _model.invoke(input), _sample)
_classifier.detectionThreshold = 0.7
_classifier.samplingInterval = ${samplingInterval} // ms
_classifier.samplesOverlap = ${samplesOverlap}
_classifier.samplesInWindow = ${samplesInWindow}
_classifier.elementsInSample = ${elementsInSample}
_classifier.noiseClassNo = -1 // disable
_classifier.noiseSuppressionTime = 500 // ms
_classifier.start()
return _classifier
}
/**
* Run some code when a particular ML event is detected.
*/
//% blockId=ml_on_event block="on ml event %condition"
//% blockGap=12
export function onEvent(mlevent: MLEvent, handler: () => void) {
classifier().onEvent(mlevent, handler)
}
`

        // Pick the _sample() implementation from the size of one sample.
        let sample = fakeSample
        if (elementsInSample === 1) sample = buttonSample
        else if (elementsInSample === 3) sample = accelSample

        // The fallback template has an "@samples@" placeholder already wrapped
        // in brackets, so substitute a bare comma-separated list: that keeps
        // the generated code valid and returning a flat array.
        const exampleSample = Array.from(
            { length: elementsInSample },
            (_, i) => i
        )
        code +=
            "\n" + sample.replace("@samples@", exampleSample.join(", ")) + "\n"

        code += `export const _model = new ml4f.Model(\n` + "hex`"
        for (let i = 0; i < res.machineCode.length; ++i) {
            code += ("0" + res.machineCode[i].toString(16)).slice(-2)
            // First break comes after 30 bytes, then every 32 bytes —
            // presumably to make up for the 4-character "hex`" prefix.
            if ((i + 3) % 32 == 0) code += "\n"
        }
        code += "`);\n"
        code += "\n} // namespace ml\n"

        // Log the source with the run of full hex lines collapsed.
        console.log(code.replace(/([a-f0-9]{64}\n)+/, "..."))
        await pxtClient.write(code)
        setStatus("done; you can Go back now")
    }

    /** Stops propagation and the default action of an event. */
    function stopEv(e: Event) {
        e.stopPropagation()
        e.preventDefault()
    }

    /** Creates a div holding the given text. */
    function div(text: string): HTMLDivElement {
        const el = document.createElement("div")
        el.textContent = text
        return el
    }

    /** Shows an error in the status line, in red. */
    function setError(msg: string) {
        status.style.color = "red"
        status.textContent = "Error: " + msg
    }

    /** Shows a progress message in the status line, in green. */
    function setStatus(msg: string) {
        status.style.color = "green"
        status.textContent = msg
    }

    /**
     * Adds a labelled checkbox bound to `options[field]`.
     *
     * @param field key in `options` that the checkbox controls
     * @param name label text
     */
    function addCheckbox(field: string, name: string) {
        const lbl = document.createElement("label")
        lbl.textContent = name
        const box = document.createElement("input")
        lbl.prepend(box)
        box.type = "checkbox"
        box.checked = !!options[field]
        box.addEventListener("change", () => {
            // Mirror both states, so that unticking turns the option off.
            options[field] = box.checked
        })
        maindiv.appendChild(lbl)
    }
}