This commit is contained in:
Michal Moskal 2020-10-15 13:54:36 +02:00
Родитель 74c222e6cf
Коммит 1e26221f28
2 изменённых файлов: 70 добавлений и 1 удалений

Просмотреть файл

@ -5,6 +5,7 @@ import * as child_process from 'child_process'
import { program as commander } from "commander"
import {
compileModel, compileModelAndFullValidate,
evalModel,
Options, sampleModel, testAllModels,
testFloatConv
} from '../..'
@ -18,6 +19,7 @@ interface CmdOptions {
testAll?: boolean
optimize?: boolean
float16?: boolean
eval?: string
}
let options: CmdOptions
@ -160,13 +162,19 @@ async function processModelFile(modelFile: string) {
write(".js", cres.js)
write(".bin", cres.machineCode)
if (options.eval) {
const ev = evalModel(cres, JSON.parse(fs.readFileSync(options.eval, "utf8")))
console.log(`\n*** ${built("model.bin")}\n${ev}`)
}
console.log(cres.memInfo)
console.log(cres.timeInfo)
function write(ext: string, buf: string | Uint8Array) {
const fn = built("model" + ext)
const binbuf = typeof buf == "string" ? Buffer.from(buf, "utf8") : buf
console.log(`write ${fn} (${binbuf.length} bytes)`)
if (!options.eval)
console.log(`write ${fn} (${binbuf.length} bytes)`)
fs.writeFileSync(fn, binbuf)
}
}
@ -174,6 +182,9 @@ async function processModelFile(modelFile: string) {
export async function mainCli() {
// require('@tensorflow/tfjs-node');
// shut up warning
(tf.backend() as any).firstUse = false;
const pkg = require("../../package.json")
commander
.version(pkg.version)
@ -184,6 +195,7 @@ export async function mainCli() {
.option("-t, --test-data", "include test data in binary model")
.option("-T, --test-all", "test all included sample models")
.option("-s, --sample-model <name>", "use an included sample model")
.option("-e, --eval <file.json>", "evaluate model on given test data")
.option("-o, --output <folder>", "path to store compilation results (default: 'built')")
.arguments("<model>")
.parse(process.argv)

Просмотреть файл

@ -200,4 +200,61 @@ export async function testAllModels(opts: Options) {
await compileModelAndFullValidate(m, opts)
}
console.log(`\n*** All OK (${Date.now() - t0}ms)\n`)
}
export type EvalSample = number | number[] | number[][] | number[][][]
export interface EvalData {
x: EvalSample[]
y: number[][]
}
function flattenSample(s: EvalSample) {
const res: number[] = []
const rec = (v: any) => {
if (Array.isArray(v))
v.forEach(rec)
else if (typeof v == "number")
res.push(v)
else
throw new Error("invalid input")
}
rec(s)
return res
}
function argmax(r: ArrayLike<number>) {
let maxI = 0
let max = r[0]
for (let i = 1; i < r.length; ++i) {
if (r[i] > max) {
max = r[i]
maxI = i
}
}
return maxI
}
export function evalModel(cres: CompileResult, data: EvalData) {
let numOK = 0
const dim = data.y[0].length
const confusion = U.range(dim).map(_ => U.range(dim).map(_ => 0))
for (let i = 0; i < data.x.length; ++i) {
const predProb = cres.execute(flattenSample(data.x[i]))
const pred = argmax(predProb)
const ok = argmax(data.y[i])
confusion[pred][ok]++
if (pred == ok) numOK++
}
let r = ""
r += `Accuracy: ${(numOK / data.x.length).toFixed(4)}\n`
for (let i = 0; i < dim; i++) {
for (let j = 0; j < dim; j++) {
r += (" " + confusion[i][j]).slice(-5)
}
r += "\n"
}
return r
}