Use the common pxt-ml library
This commit is contained in:
Родитель
9a92b39ee6
Коммит
32760e40c6
122
binstore.cpp
122
binstore.cpp
|
@ -1,122 +0,0 @@
|
|||
#include "pxt.h"
|
||||
#include "Flash.h"
|
||||
|
||||
// TODO move this to separate package
|
||||
|
||||
namespace settings {
|
||||
uintptr_t largeStoreStart();
|
||||
size_t largeStoreSize();
|
||||
CODAL_FLASH *largeStoreFlash();
|
||||
} // namespace settings
|
||||
|
||||
namespace binstore {
|
||||
|
||||
/**
|
||||
* Returns the maximum allowed size of binstore buffers.
|
||||
*/
|
||||
//%
|
||||
uint32_t totalSize() {
|
||||
return settings::largeStoreStart() ? settings::largeStoreSize() : 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear storage.
|
||||
*/
|
||||
//%
|
||||
int erase() {
|
||||
size_t sz = settings::largeStoreSize();
|
||||
uintptr_t beg = settings::largeStoreStart();
|
||||
if (!beg)
|
||||
return -1;
|
||||
auto flash = settings::largeStoreFlash();
|
||||
|
||||
uintptr_t p = beg;
|
||||
uintptr_t end = beg + sz;
|
||||
while (p < end) {
|
||||
DMESG("erase at %p", p);
|
||||
if (flash->erasePage(p))
|
||||
return -2;
|
||||
p += flash->pageSize(p);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
//%
|
||||
RefCollection *buffers() {
|
||||
auto res = Array_::mk();
|
||||
registerGCObj(res);
|
||||
uintptr_t p = settings::largeStoreStart();
|
||||
if (p) {
|
||||
uintptr_t end = p + settings::largeStoreSize();
|
||||
for (;;) {
|
||||
BoxedBuffer *buf = (BoxedBuffer *)p;
|
||||
if (buf->vtable != (uint32_t)&pxt::buffer_vt)
|
||||
break;
|
||||
res->head.push((TValue)buf);
|
||||
p += (8 + buf->length + 7) & ~7;
|
||||
if (p >= end)
|
||||
break;
|
||||
}
|
||||
}
|
||||
unregisterGCObj(res);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a buffer of given size to binstore.
|
||||
*/
|
||||
//%
|
||||
Buffer addBuffer(uint32_t size) {
|
||||
uintptr_t p = settings::largeStoreStart();
|
||||
if (!p)
|
||||
return NULL;
|
||||
|
||||
BoxedBuffer *buf;
|
||||
uintptr_t end = p + settings::largeStoreSize();
|
||||
for (;;) {
|
||||
buf = (BoxedBuffer *)p;
|
||||
if (buf->vtable != (uint32_t)&pxt::buffer_vt)
|
||||
break;
|
||||
p += (8 + buf->length + 7) & ~7;
|
||||
if (p >= end)
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (buf->vtable + 1 || buf->length + 1)
|
||||
return NULL;
|
||||
|
||||
auto flash = settings::largeStoreFlash();
|
||||
uint32_t header[] = {(uint32_t)&pxt::buffer_vt, size};
|
||||
if (flash->writeBytes(p, header, sizeof(header)))
|
||||
return NULL;
|
||||
|
||||
return buf;
|
||||
}
|
||||
|
||||
PXT_DEF_STRING(sNotAligned, "binstore: not aligned")
|
||||
PXT_DEF_STRING(sOOR, "binstore: out of range")
|
||||
PXT_DEF_STRING(sWriteError, "binstore: write failure")
|
||||
PXT_DEF_STRING(sNotErased, "binstore: not erased")
|
||||
|
||||
/**
|
||||
* Write bytes in a binstore buffer.
|
||||
*/
|
||||
//%
|
||||
void write(Buffer dst, int dstOffset, Buffer src) {
|
||||
if (dstOffset & 7)
|
||||
pxt::throwValue((TValue)sNotAligned);
|
||||
if (dstOffset < 0 || dstOffset + src->length > dst->length)
|
||||
pxt::throwValue((TValue)sOOR);
|
||||
|
||||
auto flash = settings::largeStoreFlash();
|
||||
uint32_t len = (src->length + 7) & ~7;
|
||||
for (unsigned i = 0; i < len; ++i)
|
||||
if (dst->data[dstOffset + i] != 0xff)
|
||||
pxt::throwValue((TValue)sNotErased);
|
||||
if (flash->writeBytes((uintptr_t)dst->data + dstOffset, src->data, len))
|
||||
pxt::throwValue((TValue)sWriteError);
|
||||
}
|
||||
|
||||
} // namespace binstore
|
|
@ -1,7 +0,0 @@
|
|||
namespace binstore {
|
||||
/**
|
||||
* Returns all buffers currently allocated in binstore.
|
||||
*/
|
||||
//% shim=binstore::buffers
|
||||
export declare function buffers(): Buffer[];
|
||||
}
|
3
pxt.json
3
pxt.json
|
@ -3,6 +3,7 @@
|
|||
"version": "0.0.1",
|
||||
"description": "Run Tensorflow Lite models in MakeCode - beta",
|
||||
"dependencies": {
|
||||
"ml": "github:microsoft/pxt-ml#v0.0.1",
|
||||
"core": "*"
|
||||
},
|
||||
"files": [
|
||||
|
@ -11,8 +12,6 @@
|
|||
"tf.ts",
|
||||
"shims.d.ts",
|
||||
"tfjacdac.ts",
|
||||
"binstore.cpp",
|
||||
"binstore.ts",
|
||||
"ns.ts"
|
||||
],
|
||||
"codal": {
|
||||
|
|
|
@ -19,31 +19,5 @@ declare namespace tf {
|
|||
//% shim=tf::arenaBytes
|
||||
function arenaBytes(): uint32;
|
||||
}
|
||||
declare namespace binstore {
|
||||
|
||||
/**
|
||||
* Returns the maximum allowed size of binstore buffers.
|
||||
*/
|
||||
//% shim=binstore::totalSize
|
||||
function totalSize(): uint32;
|
||||
|
||||
/**
|
||||
* Clear storage.
|
||||
*/
|
||||
//% shim=binstore::erase
|
||||
function erase(): int32;
|
||||
|
||||
/**
|
||||
* Add a buffer of given size to binstore.
|
||||
*/
|
||||
//% shim=binstore::addBuffer
|
||||
function addBuffer(size: uint32): Buffer;
|
||||
|
||||
/**
|
||||
* Write bytes in a binstore buffer.
|
||||
*/
|
||||
//% shim=binstore::write
|
||||
function write(dst: Buffer, dstOffset: int32, src: Buffer): void;
|
||||
}
|
||||
|
||||
// Auto-generated. Do not edit. Really.
|
||||
|
|
179
tfjacdac.ts
179
tfjacdac.ts
|
@ -3,104 +3,15 @@ namespace jd_class {
|
|||
}
|
||||
|
||||
namespace jacdac {
|
||||
export enum TFLiteCmd {
|
||||
/**
|
||||
* Argument: model_size bytes uint32_t. Open pipe for streaming in the model. The size of the model has to be declared upfront.
|
||||
* The model is streamed over regular pipe data packets, in the `.tflite` flatbuffer format.
|
||||
* When the pipe is closed, the model is written all into flash, and the device running the service may reset.
|
||||
*/
|
||||
SetModel = 0x80,
|
||||
|
||||
/**
|
||||
* Argument: outputs pipe (bytes). Open channel that can be used to manually invoke the model. When enough data is sent over the `inputs` pipe, the model is invoked,
|
||||
* and results are send over the `outputs` pipe.
|
||||
*/
|
||||
Predict = 0x81,
|
||||
}
|
||||
|
||||
export enum TFLiteReg {
|
||||
/**
|
||||
* Read-write uint16_t. When register contains `N > 0`, run the model automatically every time new `N` samples are collected.
|
||||
* Model may be run less often if it takes longer to run than `N * sampling_interval`.
|
||||
* The `outputs` register will stream its value after each run.
|
||||
* This register is not stored in flash.
|
||||
*/
|
||||
AutoInvokeEvery = 0x80,
|
||||
|
||||
/** Read-only bytes. Results of last model invocation as `float32` array. */
|
||||
Outputs = 0x101,
|
||||
|
||||
/** Read-only dimension uint16_t. The shape of the input tensor. */
|
||||
InputShape = 0x180,
|
||||
|
||||
/** Read-only dimension uint16_t. The shape of the output tensor. */
|
||||
OutputShape = 0x181,
|
||||
|
||||
/** Read-only μs uint32_t. The time consumed in last model execution. */
|
||||
LastRunTime = 0x182,
|
||||
|
||||
/** Read-only bytes uint32_t. Number of RAM bytes allocated for model execution. */
|
||||
AllocatedArenaSize = 0x183,
|
||||
|
||||
/** Read-only bytes uint32_t. The size of `.tflite` model in bytes. */
|
||||
ModelSize = 0x184,
|
||||
|
||||
/** Read-only string (bytes). Textual description of last error when running or loading model (if any). */
|
||||
LastError = 0x185,
|
||||
}
|
||||
|
||||
function packArray(arr: number[], fmt: NumberFormat) {
|
||||
const sz = Buffer.sizeOfNumberFormat(fmt)
|
||||
const res = Buffer.create(arr.length * sz)
|
||||
for (let i = 0; i < arr.length; ++i)
|
||||
res.setNumber(fmt, i * sz, arr[i])
|
||||
return res
|
||||
}
|
||||
|
||||
|
||||
const arenaSizeSettingsKey = "#jd-tflite-arenaSize"
|
||||
|
||||
|
||||
export class TFLiteHost extends Host {
|
||||
private autoInvokeSamples = 0
|
||||
private execTime = 0
|
||||
private outputs = Buffer.create(0)
|
||||
private lastError: string
|
||||
private lastRunNumSamples = 0
|
||||
|
||||
constructor(private agg: SensorAggregatorHost) {
|
||||
super("tflite", jd_class.TFLITE);
|
||||
agg.newDataCallback = () => {
|
||||
if (this.autoInvokeSamples && this.lastRunNumSamples >= 0 &&
|
||||
this.numSamples - this.lastRunNumSamples >= this.autoInvokeSamples) {
|
||||
this.lastRunNumSamples = -1
|
||||
control.runInBackground(() => this.runModel())
|
||||
}
|
||||
}
|
||||
export class TFLiteHost extends MLHost {
|
||||
constructor(agg: SensorAggregatorHost) {
|
||||
super("tflite", jd_class.TFLITE, agg);
|
||||
}
|
||||
|
||||
get numSamples() {
|
||||
return this.agg.numSamples
|
||||
}
|
||||
|
||||
get modelBuffer() {
|
||||
const bufs = binstore.buffers()
|
||||
if (!bufs || !bufs[0]) return null
|
||||
if (bufs[0].getNumber(NumberFormat.Int32LE, 0) == -1)
|
||||
return null
|
||||
return bufs[0]
|
||||
}
|
||||
|
||||
get modelSize() {
|
||||
const m = this.modelBuffer
|
||||
if (m) return m.length
|
||||
else return 0
|
||||
}
|
||||
|
||||
private runModel() {
|
||||
if (this.lastError) return
|
||||
const numSamples = this.numSamples
|
||||
const t0 = control.micros()
|
||||
protected invokeModel() {
|
||||
try {
|
||||
const res = tf.invokeModelF([this.agg.samplesBuffer])
|
||||
this.outputs = packArray(res[0], NumberFormat.Float32LE)
|
||||
|
@ -109,29 +20,15 @@ namespace jacdac {
|
|||
this.lastError = e
|
||||
control.dmesgValue(e)
|
||||
}
|
||||
this.execTime = control.micros() - t0
|
||||
this.lastRunNumSamples = numSamples
|
||||
this.sendReport(JDPacket.from(CMD_GET_REG | TFLiteReg.Outputs, this.outputs))
|
||||
}
|
||||
|
||||
start() {
|
||||
super.start()
|
||||
this.agg.start()
|
||||
this.loadModel()
|
||||
}
|
||||
|
||||
private eraseModel() {
|
||||
protected eraseModel() {
|
||||
tf.freeModel()
|
||||
binstore.erase()
|
||||
settings.remove(arenaSizeSettingsKey)
|
||||
}
|
||||
|
||||
private loadModel() {
|
||||
this.lastError = null
|
||||
if (!this.modelBuffer) {
|
||||
this.lastError = "no model"
|
||||
return
|
||||
}
|
||||
protected loadModelImpl() {
|
||||
try {
|
||||
const sizeHint = settings.readNumber(arenaSizeSettingsKey)
|
||||
tf.loadModel(this.modelBuffer, sizeHint)
|
||||
|
@ -144,68 +41,12 @@ namespace jacdac {
|
|||
}
|
||||
}
|
||||
|
||||
private readModel(packet: JDPacket) {
|
||||
const sz = packet.intData
|
||||
console.log(`model ${sz} bytes (of ${binstore.totalSize()})`)
|
||||
if (sz > binstore.totalSize() - 8)
|
||||
return
|
||||
this.eraseModel()
|
||||
const flash = binstore.addBuffer(sz)
|
||||
const pipe = new InPipe()
|
||||
this.sendReport(JDPacket.packed(packet.service_command, "H", [pipe.port]))
|
||||
console.log(`pipe ${pipe.port}`)
|
||||
let off = 0
|
||||
const headBuffer = Buffer.create(8)
|
||||
while (true) {
|
||||
const buf = pipe.read()
|
||||
if (!buf)
|
||||
return
|
||||
if (off == 0) {
|
||||
// don't write the header before we finish
|
||||
headBuffer.write(0, buf)
|
||||
binstore.write(flash, 8, buf.slice(8))
|
||||
} else {
|
||||
binstore.write(flash, off, buf)
|
||||
}
|
||||
off += buf.length
|
||||
if (off >= sz) {
|
||||
// now that we're done, write the header
|
||||
binstore.write(flash, 0, headBuffer)
|
||||
// and reset, so we're sure the GC heap is not fragmented when we allocate new arena
|
||||
//control.reset()
|
||||
break
|
||||
}
|
||||
if (off & 7)
|
||||
throw "invalid model stream size"
|
||||
}
|
||||
pipe.close()
|
||||
this.loadModel()
|
||||
get inputShape(): number[] {
|
||||
return tf.inputShape(0)
|
||||
}
|
||||
|
||||
handlePacket(packet: JDPacket) {
|
||||
this.handleRegInt(packet, TFLiteReg.AllocatedArenaSize, tf.arenaBytes())
|
||||
this.handleRegInt(packet, TFLiteReg.LastRunTime, this.execTime)
|
||||
this.handleRegInt(packet, TFLiteReg.ModelSize, this.modelSize)
|
||||
this.handleRegBuffer(packet, TFLiteReg.Outputs, this.outputs)
|
||||
this.autoInvokeSamples = this.handleRegInt(packet, TFLiteReg.AutoInvokeEvery, this.autoInvokeSamples)
|
||||
|
||||
let arr: number[]
|
||||
switch (packet.service_command) {
|
||||
case TFLiteCmd.SetModel:
|
||||
control.runInBackground(() => this.readModel(packet))
|
||||
break
|
||||
case TFLiteReg.OutputShape | CMD_GET_REG:
|
||||
arr = tf.outputShape(0)
|
||||
case TFLiteReg.InputShape | CMD_GET_REG:
|
||||
arr = arr || tf.inputShape(0)
|
||||
this.sendReport(JDPacket.from(packet.service_command, packArray(arr, NumberFormat.UInt16LE)))
|
||||
break;
|
||||
case TFLiteReg.LastError | CMD_GET_REG:
|
||||
this.sendReport(JDPacket.from(packet.service_command, Buffer.fromUTF8(this.lastError || "")))
|
||||
break
|
||||
default:
|
||||
break;
|
||||
}
|
||||
get outputShape(): number[] {
|
||||
return tf.outputShape(0)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче