From 1e26221f28bc776dd8374c020f9c60b2cf0fbfb0 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 15 Oct 2020 13:54:36 +0200 Subject: [PATCH] Add --eval option --- cli/src/cli.ts | 14 ++++++++++++- src/testing.ts | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/cli/src/cli.ts b/cli/src/cli.ts index 30063b6..627d788 100644 --- a/cli/src/cli.ts +++ b/cli/src/cli.ts @@ -5,6 +5,7 @@ import * as child_process from 'child_process' import { program as commander } from "commander" import { compileModel, compileModelAndFullValidate, + evalModel, Options, sampleModel, testAllModels, testFloatConv } from '../..' @@ -18,6 +19,7 @@ interface CmdOptions { testAll?: boolean optimize?: boolean float16?: boolean + eval?: string } let options: CmdOptions @@ -160,13 +162,19 @@ async function processModelFile(modelFile: string) { write(".js", cres.js) write(".bin", cres.machineCode) + if (options.eval) { + const ev = evalModel(cres, JSON.parse(fs.readFileSync(options.eval, "utf8"))) + console.log(`\n*** ${built("model.bin")}\n${ev}`) + } + console.log(cres.memInfo) console.log(cres.timeInfo) function write(ext: string, buf: string | Uint8Array) { const fn = built("model" + ext) const binbuf = typeof buf == "string" ? Buffer.from(buf, "utf8") : buf - console.log(`write ${fn} (${binbuf.length} bytes)`) + if (!options.eval) + console.log(`write ${fn} (${binbuf.length} bytes)`) fs.writeFileSync(fn, binbuf) } } @@ -174,6 +182,9 @@ async function processModelFile(modelFile: string) { export async function mainCli() { // require('@tensorflow/tfjs-node'); + // shut up warning + (tf.backend() as any).firstUse = false; + const pkg = require("../../package.json") commander .version(pkg.version) @@ -184,6 +195,7 @@ export async function mainCli() { .option("-t, --test-data", "include test data in binary model") .option("-T, --test-all", "test all included sample models") .option("-s, --sample-model ", "use an included sample model") + .option("-e, --eval ", "evaluate model on given test data") .option("-o, --output ", "path to store compilation results (default: 'built')") .arguments("") .parse(process.argv) diff --git a/src/testing.ts b/src/testing.ts index 672ce82..f7eb675 100644 --- a/src/testing.ts +++ b/src/testing.ts @@ -200,4 +200,61 @@ export async function testAllModels(opts: Options) { await compileModelAndFullValidate(m, opts) } console.log(`\n*** All OK (${Date.now() - t0}ms)\n`) +} + +export type EvalSample = number | number[] | number[][] | number[][][] +export interface EvalData { + x: EvalSample[] + y: number[][] +} + +function flattenSample(s: EvalSample) { + const res: number[] = [] + const rec = (v: any) => { + if (Array.isArray(v)) + v.forEach(rec) + else if (typeof v == "number") + res.push(v) + else + throw new Error("invalid input") + } + rec(s) + return res +} + +function argmax(r: ArrayLike) { + let maxI = 0 + let max = r[0] + for (let i = 1; i < r.length; ++i) { + if (r[i] > max) { + max = r[i] + maxI = i + } + } + return maxI +} + +export function evalModel(cres: CompileResult, data: EvalData) { + let numOK = 0 + const dim = data.y[0].length + const confusion = U.range(dim).map(_ => U.range(dim).map(_ => 0)) + for (let i = 0; i < data.x.length; ++i) { + const predProb = cres.execute(flattenSample(data.x[i])) + const pred = argmax(predProb) + const ok = argmax(data.y[i]) + confusion[pred][ok]++ + if (pred == ok) numOK++ + } + + let r = "" + + r += `Accuracy: ${(numOK / data.x.length).toFixed(4)}\n` + for (let i = 0; i < dim; i++) { + for (let j = 0; j < dim; j++) { + r += (" " + confusion[i][j]).slice(-5) + } + r += "\n" + } + + return r } \ No newline at end of file