diff --git a/Source/CNTKv2LibraryDll/PrimitiveFunction.h b/Source/CNTKv2LibraryDll/PrimitiveFunction.h index 0f5f3df5a..68666491f 100644 --- a/Source/CNTKv2LibraryDll/PrimitiveFunction.h +++ b/Source/CNTKv2LibraryDll/PrimitiveFunction.h @@ -762,7 +762,7 @@ namespace CNTK if (i < operands.size() - 1) { - if (inferDimensions && ((paramShape.Rank() == 1) && paramShape.HasInferredDimension()) && !mainOperandShape.HasUnboundDimension()) + if (inferDimensions && ((paramShape.Rank() == 1) && paramShape.HasInferredDimension()) && (!mainOperandShape.HasUnboundDimension() || (spatial && mainOperandShape[mainOperandShape.Rank() - 1] != NDShape::FreeDimension))) { size_t total = spatial ? mainOperandShape[mainOperandShape.Rank() - 1] : mainOperandShape.TotalSize(); paramShape[0] = total; diff --git a/bindings/python/cntk/ops/tests/non_linear_test.py b/bindings/python/cntk/ops/tests/non_linear_test.py index c8282a5b7..1dc2b38c7 100644 --- a/bindings/python/cntk/ops/tests/non_linear_test.py +++ b/bindings/python/cntk/ops/tests/non_linear_test.py @@ -612,7 +612,7 @@ def test_op_batch_normalization(use_cudnn, sample, device_id, precision): forward_input = {a: t} unittest_helper(op_node, forward_input, expected_forward, expected_backward=None, device_id=device_id, precision=precision) - + @pytest.mark.parametrize("shape", [(1,), (16,), (16,32,), (16,32,32,)]) @pytest.mark.parametrize("spatial", [True, False]) def test_op_batch_normalization_numpy(shape, spatial, device_id, precision): @@ -666,7 +666,7 @@ def test_op_batch_normalization_numpy(shape, spatial, device_id, precision): var_out = var * reduced_count / (reduced_count - 1) var_b = np.asarray([[np.ones(reduced_shape)*x for x in var]]*batch_size) x_hat = (x - mean_b) / np.sqrt(var_b + epsilon) - y = init_scale * x_hat + init_bias; + y = init_scale * x_hat + init_bias d_scale = np.sum(x_hat, reduce_dims) d_bias = np.sum(np.ones_like(x_hat), reduce_dims) @@ -678,6 +678,37 @@ def test_op_batch_normalization_numpy(shape, spatial, device_id, precision): assert(np.allclose(init_mean * (1-exp_avg) + mean.reshape(param_shape) * exp_avg, run_mean.value)) assert(run_count.value == init_count + batch_size) +@pytest.mark.parametrize("channels", [1, 16]) +@pytest.mark.parametrize("input_size", [32, C.FreeDimension, C.InferredDimension]) +def test_op_batch_normalization_spatial_shape_inference(channels, input_size, device_id, precision): + dtype = PRECISION_TO_TYPE[precision] + dev = cntk_device(device_id) + + spatial = True + epsilon = 0.01 + + init_scale = 1 + init_bias = 2 + init_mean = 3 + init_var = 4 + init_count = 2 + + shape = (channels, input_size, input_size) + param_shape = (C.InferredDimension,) + + i = C.input_variable(shape, dtype=dtype) + scale = C.parameter(param_shape, init=init_scale, dtype=dtype, device=dev) + bias = C.parameter(param_shape, init=init_bias, dtype=dtype, device=dev) + run_mean = C.constant(init_mean, shape=param_shape, dtype=dtype, device=dev) + run_var = C.constant(init_var, shape=param_shape, dtype=dtype, device=dev) + run_count = C.constant(init_count, shape=(), dtype=dtype, device=dev) + + bn = C.batch_normalization(i, scale, bias, run_mean, run_var, spatial, normalization_time_constant=-1, epsilon=epsilon, running_count = run_count) + + for param in [scale, bias, run_mean, run_var]: + assert(param.shape == (channels,)) + + def test_local_response_normalization(device_id, precision): dtype = PRECISION_TO_TYPE[precision] dev = cntk_device(device_id)