зеркало из https://github.com/microsoft/ml4f.git
Fixes for loading
This commit is contained in:
Родитель
bf75e54910
Коммит
1c68157612
23
cli/cli.ts
23
cli/cli.ts
|
@ -19,6 +19,19 @@ function loadJSONModel(modelPath: string) {
|
|||
throw new Error("model not in JSON format")
|
||||
|
||||
const modelJSON = JSON.parse(modelBuf.toString("utf8")) as tf.io.ModelJSON
|
||||
|
||||
// remove regularizers, as we're not going to train the model, and unknown regularizers
|
||||
// cause it to fail to load
|
||||
const cfg = (modelJSON.modelTopology as any)?.model_config?.config
|
||||
for (const layer of cfg?.layers || []) {
|
||||
const layerConfig = layer?.config
|
||||
if (layerConfig) {
|
||||
layerConfig.bias_regularizer = null
|
||||
layerConfig.activity_regularizer = null
|
||||
layerConfig.bias_constraint = null
|
||||
}
|
||||
}
|
||||
|
||||
const model: tf.io.ModelArtifacts = {
|
||||
modelTopology: modelJSON.modelTopology,
|
||||
format: modelJSON.format,
|
||||
|
@ -132,6 +145,12 @@ export async function mainCli() {
|
|||
}
|
||||
|
||||
const modelFile = commander.args[0]
|
||||
const m = await tf.loadLayersModel({ load: () => loadModel(modelFile) })
|
||||
m.summary()
|
||||
|
||||
try {
|
||||
const model = loadModel(modelFile)
|
||||
const m = await tf.loadLayersModel({ load: () => model })
|
||||
m.summary()
|
||||
} catch (e) {
|
||||
console.error(e.stack)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
"description": "",
|
||||
"scripts": {
|
||||
"build": "node node_modules/typescript/bin/tsc",
|
||||
"watch-cli": "node node_modules/typescript/bin/tsc --watch",
|
||||
"watch": "node node_modules/rollup/dist/bin/rollup -c rollup.config.ts -w"
|
||||
},
|
||||
"author": "",
|
||||
|
|
Загрузка…
Ссылка в новой задаче