зеркало из https://github.com/microsoft/ml4f.git
Add --eval option
This commit is contained in:
Родитель
74c222e6cf
Коммит
1e26221f28
|
@ -5,6 +5,7 @@ import * as child_process from 'child_process'
|
|||
import { program as commander } from "commander"
|
||||
import {
|
||||
compileModel, compileModelAndFullValidate,
|
||||
evalModel,
|
||||
Options, sampleModel, testAllModels,
|
||||
testFloatConv
|
||||
} from '../..'
|
||||
|
@ -18,6 +19,7 @@ interface CmdOptions {
|
|||
testAll?: boolean
|
||||
optimize?: boolean
|
||||
float16?: boolean
|
||||
eval?: string
|
||||
}
|
||||
|
||||
let options: CmdOptions
|
||||
|
@ -160,13 +162,19 @@ async function processModelFile(modelFile: string) {
|
|||
write(".js", cres.js)
|
||||
write(".bin", cres.machineCode)
|
||||
|
||||
if (options.eval) {
|
||||
const ev = evalModel(cres, JSON.parse(fs.readFileSync(options.eval, "utf8")))
|
||||
console.log(`\n*** ${built("model.bin")}\n${ev}`)
|
||||
}
|
||||
|
||||
console.log(cres.memInfo)
|
||||
console.log(cres.timeInfo)
|
||||
|
||||
function write(ext: string, buf: string | Uint8Array) {
|
||||
const fn = built("model" + ext)
|
||||
const binbuf = typeof buf == "string" ? Buffer.from(buf, "utf8") : buf
|
||||
console.log(`write ${fn} (${binbuf.length} bytes)`)
|
||||
if (!options.eval)
|
||||
console.log(`write ${fn} (${binbuf.length} bytes)`)
|
||||
fs.writeFileSync(fn, binbuf)
|
||||
}
|
||||
}
|
||||
|
@ -174,6 +182,9 @@ async function processModelFile(modelFile: string) {
|
|||
export async function mainCli() {
|
||||
// require('@tensorflow/tfjs-node');
|
||||
|
||||
// shut up warning
|
||||
(tf.backend() as any).firstUse = false;
|
||||
|
||||
const pkg = require("../../package.json")
|
||||
commander
|
||||
.version(pkg.version)
|
||||
|
@ -184,6 +195,7 @@ export async function mainCli() {
|
|||
.option("-t, --test-data", "include test data in binary model")
|
||||
.option("-T, --test-all", "test all included sample models")
|
||||
.option("-s, --sample-model <name>", "use an included sample model")
|
||||
.option("-e, --eval <file.json>", "evaluate model on given test data")
|
||||
.option("-o, --output <folder>", "path to store compilation results (default: 'built')")
|
||||
.arguments("<model>")
|
||||
.parse(process.argv)
|
||||
|
|
|
@ -200,4 +200,61 @@ export async function testAllModels(opts: Options) {
|
|||
await compileModelAndFullValidate(m, opts)
|
||||
}
|
||||
console.log(`\n*** All OK (${Date.now() - t0}ms)\n`)
|
||||
}
|
||||
|
||||
export type EvalSample = number | number[] | number[][] | number[][][]
|
||||
export interface EvalData {
|
||||
x: EvalSample[]
|
||||
y: number[][]
|
||||
}
|
||||
|
||||
function flattenSample(s: EvalSample) {
|
||||
const res: number[] = []
|
||||
const rec = (v: any) => {
|
||||
if (Array.isArray(v))
|
||||
v.forEach(rec)
|
||||
else if (typeof v == "number")
|
||||
res.push(v)
|
||||
else
|
||||
throw new Error("invalid input")
|
||||
}
|
||||
rec(s)
|
||||
return res
|
||||
}
|
||||
|
||||
function argmax(r: ArrayLike<number>) {
|
||||
let maxI = 0
|
||||
let max = r[0]
|
||||
for (let i = 1; i < r.length; ++i) {
|
||||
if (r[i] > max) {
|
||||
max = r[i]
|
||||
maxI = i
|
||||
}
|
||||
}
|
||||
return maxI
|
||||
}
|
||||
|
||||
export function evalModel(cres: CompileResult, data: EvalData) {
|
||||
let numOK = 0
|
||||
const dim = data.y[0].length
|
||||
const confusion = U.range(dim).map(_ => U.range(dim).map(_ => 0))
|
||||
for (let i = 0; i < data.x.length; ++i) {
|
||||
const predProb = cres.execute(flattenSample(data.x[i]))
|
||||
const pred = argmax(predProb)
|
||||
const ok = argmax(data.y[i])
|
||||
confusion[pred][ok]++
|
||||
if (pred == ok) numOK++
|
||||
}
|
||||
|
||||
let r = ""
|
||||
|
||||
r += `Accuracy: ${(numOK / data.x.length).toFixed(4)}\n`
|
||||
for (let i = 0; i < dim; i++) {
|
||||
for (let j = 0; j < dim; j++) {
|
||||
r += (" " + confusion[i][j]).slice(-5)
|
||||
}
|
||||
r += "\n"
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
Загрузка…
Ссылка в новой задаче