add support for single-dimensional batch-norm; fixes #18

This commit is contained in:
Michal Moskal 2024-05-27 12:06:53 -07:00
Родитель 19cb3389bc
Коммит c507f38bbf
3 изменённых файлов: 42 добавлений и 4 удалений

22
.vscode/launch.json поставляемый Normal file
Просмотреть файл

@ -0,0 +1,22 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "node",
"request": "launch",
"name": "Launch Program",
"skipFiles": [
"<node_internals>/**"
],
"program": "${workspaceFolder}/built/cli.cjs",
"args": ["--debug", "tmp/models/1/test-model.json"],
// "preLaunchTask": "tsc: build - tsconfig.json",
"outFiles": [
"${workspaceFolder}/built/**/*.js"
]
}
]
}

Просмотреть файл

@ -557,13 +557,17 @@ function compileBatchNorm(info: LayerInfo) {
const flashRegs = numFPRegs - 2
const flashReg0 = Reg.S0 + 2
if (info.inputShape.length != 4)
unsupported("inputShape: " + info.inputShape.length)
let inpShape = info.inputShape
if (inpShape.length == 2)
inpShape = [inpShape[0], 1, 1, inpShape[1]]
if (inpShape.length != 4)
unsupported("inputShape: " + inpShape.length)
if (config.dtype && config.dtype != "float32")
unsupported("dtype: " + config.dtype)
const [_null, outh, outw, numch] = info.inputShape
const [_null, outh, outw, numch] = inpShape
function readVar(name: string) {
const r = info.layer.weights.find(w => w.originalName.endsWith("/" + name)).read().arraySync() as number[]

Просмотреть файл

@ -280,7 +280,19 @@ function getSampleModels(): SMap<tf.layers.Layer[]> {
batch2: [
tf.layers.inputLayer({ inputShape: [213, 1, 100] }),
tf.layers.batchNormalization({})
]
],
singleDimBatchNorm: [
tf.layers.inputLayer({ inputShape: [24] }),
tf.layers.batchNormalization({}),
tf.layers.dense({
units: 16,
activation: "relu"
}),
tf.layers.dense({
units: 3,
activation: "softmax"
})
],
}
}