зеркало из https://github.com/microsoft/ml4f.git
add support for single-dimensional batch-norm; fixes #18
This commit is contained in:
Родитель
19cb3389bc
Коммит
c507f38bbf
|
@ -0,0 +1,22 @@
|
|||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"name": "Launch Program",
|
||||
"skipFiles": [
|
||||
"<node_internals>/**"
|
||||
],
|
||||
"program": "${workspaceFolder}/built/cli.cjs",
|
||||
"args": ["--debug", "tmp/models/1/test-model.json"],
|
||||
// "preLaunchTask": "tsc: build - tsconfig.json",
|
||||
"outFiles": [
|
||||
"${workspaceFolder}/built/**/*.js"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
|
@ -557,13 +557,17 @@ function compileBatchNorm(info: LayerInfo) {
|
|||
const flashRegs = numFPRegs - 2
|
||||
const flashReg0 = Reg.S0 + 2
|
||||
|
||||
if (info.inputShape.length != 4)
|
||||
unsupported("inputShape: " + info.inputShape.length)
|
||||
let inpShape = info.inputShape
|
||||
if (inpShape.length == 2)
|
||||
inpShape = [inpShape[0], 1, 1, inpShape[1]]
|
||||
|
||||
if (inpShape.length != 4)
|
||||
unsupported("inputShape: " + inpShape.length)
|
||||
|
||||
if (config.dtype && config.dtype != "float32")
|
||||
unsupported("dtype: " + config.dtype)
|
||||
|
||||
const [_null, outh, outw, numch] = info.inputShape
|
||||
const [_null, outh, outw, numch] = inpShape
|
||||
|
||||
function readVar(name: string) {
|
||||
const r = info.layer.weights.find(w => w.originalName.endsWith("/" + name)).read().arraySync() as number[]
|
||||
|
|
|
@ -280,7 +280,19 @@ function getSampleModels(): SMap<tf.layers.Layer[]> {
|
|||
batch2: [
|
||||
tf.layers.inputLayer({ inputShape: [213, 1, 100] }),
|
||||
tf.layers.batchNormalization({})
|
||||
]
|
||||
],
|
||||
singleDimBatchNorm: [
|
||||
tf.layers.inputLayer({ inputShape: [24] }),
|
||||
tf.layers.batchNormalization({}),
|
||||
tf.layers.dense({
|
||||
units: 16,
|
||||
activation: "relu"
|
||||
}),
|
||||
tf.layers.dense({
|
||||
units: 3,
|
||||
activation: "softmax"
|
||||
})
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче