diff --git a/Source/ComputationNetworkLib/ComputationNetworkAnalysis.cpp b/Source/ComputationNetworkLib/ComputationNetworkAnalysis.cpp index 168ecda06..cc7f65c08 100644 --- a/Source/ComputationNetworkLib/ComputationNetworkAnalysis.cpp +++ b/Source/ComputationNetworkLib/ComputationNetworkAnalysis.cpp @@ -106,13 +106,13 @@ void ComputationNetwork::FormRecurrentLoops(const ComputationNodeBasePtr& rootNo assert(node->m_numNonDelayedParentsInLoop == 0); // (in PurgeStateForFormingRecurrentLoops()) } for (let& node : nestedNodes) - { + { for (auto& input : node->GetInputs()) - { + { if (input->m_loopId == node->m_loopId && GetRecurrenceSteppingDirection(node) == 0/*not a Delay node*/) input->m_numNonDelayedParentsInLoop++; // cound #parents of 'input' that are not delay nodes - } } + } // re-traverse the graph for all nestedNodes, starting with the first // Then update m_nestedNodes with the re-traversed order. diff --git a/Source/ComputationNetworkLib/ComputationNetworkScripting.cpp b/Source/ComputationNetworkLib/ComputationNetworkScripting.cpp index 1ab29bd94..cf634c16a 100644 --- a/Source/ComputationNetworkLib/ComputationNetworkScripting.cpp +++ b/Source/ComputationNetworkLib/ComputationNetworkScripting.cpp @@ -390,7 +390,7 @@ ScriptableObjects::ConfigurableRuntimeTypeRegister::Add roots; + for (let& outputNodeKV : outputNodes) + roots.push_back(outputNodeKV.second); + let allInputs = ComputationNodeBase::EnumerateNodes(roots); + + // determine all leaves and their dependents + dependentSet = set(inputNodes.begin(), inputNodes.end()); // start with the specified inputs + for (let& node : allInputs) + { + // add parameters that are to be cloned to dependent set + if (parameterTreatment != ParameterTreatment::shared && node->Is()) + dependentSet.insert(node); + // if at least one input is in the dependent set then this node is, too + else + for (let& input : node->GetInputs()) + if (dependentSet.find(input) != dependentSet.end()) + dependentSet.insert(node); + } + +#if 1 + for (let& node : dependentSet) + fprintf(stderr, "CloneFunction: cloning %ls\n", node->NodeDescription().c_str()); +#endif + + // ensure none of the specified inputs reference back into the cloned set + // The function we extract must be separable. + for (let& input : inputNodes) + for (let& node : ComputationNodeBase::EnumerateNodes(vector{input})) // check all indirect inputs of each specified input + { + let iter = dependentSet.find(input); + if (iter != dependentSet.end() && *iter != input) + InvalidArgument("CloneFunction: specified function input %ls recursively depends on %ls inside the function.", input->NodeDescription().c_str(), node->NodeDescription().c_str()); + } } private: @@ -509,9 +552,49 @@ private: // This will clone all nodes that the outputNodes depend on, and rewire inputs matching inputNodes to inputArgs. ConfigValuePtr DoClone(const vector& inputValues, const std::wstring& exprName) { + // resolve the input arguments vector inputs; for (let& inputValue : inputValues) - inputs.push_back(inputValue.ResolveValue()); // .AsPtr()); + inputs.push_back(inputValue.ResolveValue()); + assert(inputValues.size() == inputNodes.size()); // (this should have been checked by BrainScript) + + // clone everything in the dependent set + // - specified inputs get mapped to actual parameters + // - all others get duplicated + // Note that at this point, the "shared" option has already been considered, + // and is reflected in whether parameters are included or not in 'dependentSet'. 
+ map clonedNodes; + for (size_t i = 0; i < inputNodes.size(); i++) + clonedNodes[inputNodes[i]] = inputs[i]; + for (let& node : dependentSet) + { + // if already there then it's an input that we just mapped above + if (clonedNodes.find(node) != clonedNodes.end()) + continue; + // clone + ComputationNodeBasePtr newNode; + let newName = exprName + L"." + node->GetName(); + newNode = node->Duplicate(newName, CopyNodeFlags::copyNodeAll); + // make it read-only if desired + if (parameterTreatment == ParameterTreatment::constant) + newNode->SetLearningRateMultiplier(0); + // and that's our cloned node + clonedNodes[node] = newNode; + } + + // all cloned nodes' inputs must be redirected if they reference a node that has been cloned as well + for (let& clonedNodesKV : clonedNodes) + { + let& node = clonedNodesKV.second; + let& inputs = node->GetInputs(); + for (size_t i = 0; i < inputs.size(); i++) + { + let iter = clonedNodes.find(inputs[i]); + if (iter != clonedNodes.end()) // input is also a cloned node + node->SetInput(i, iter->second); + } + } + return ConfigValuePtr(); } @@ -521,7 +604,7 @@ private: map outputNodes; ParameterTreatment parameterTreatment; // other - map clonedReadOnlyParameters; // if we clone 'constant' multiple times, we can share readOnly parameters + set dependentSet; // set of nodes that outputNodes depend on }; ScriptableObjects::ConfigurableRuntimeTypeRegister::Add registerCloneFunctionConfigLambda(L"CloneFunctionConfigLambda"); diff --git a/Source/ComputationNetworkLib/ComputationNode.h b/Source/ComputationNetworkLib/ComputationNode.h index aa5cf658a..ee651751c 100644 --- a/Source/ComputationNetworkLib/ComputationNode.h +++ b/Source/ComputationNetworkLib/ComputationNode.h @@ -184,7 +184,7 @@ protected: // TODO: should be fully encapsulated here bool m_needsGradient; // true if this node or any children need a gradient to be computed (for own consumption or propagation to somewhere in the child tree) bool m_valueSharable; // a flag is needed for memory share. - // If it is false (e.g., learnableParameters/InputValue and those nodes are solely induced by learnableParameters), + // If it is false (e.g., LearnableParameters/InputValue and those nodes are solely induced by LearnableParameters), // it will never be released to memory pool private: bool m_isPartOfLoop; // true if this loop is part of a recurrent loop @@ -1891,6 +1891,13 @@ public: struct IRecurrentNode { virtual int GetRecurrenceSteppingDirection() const = 0; }; +// ======================================================================= +// IParameterNode -- interface implemented by ComputationNodes that are parameters +// Note: There is possibly code that identifies parameters by the type name instead. Should be unified. +// ======================================================================= + +struct IParameterNode { virtual ~IParameterNode() { } }; + // ======================================================================= // PreComputedNodeBase -- interface implemented by ComputationNodes that precompute // TODO: We can use this interface in more places. 
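Aside on the cloning logic in DoClone() above: it is the standard clone-and-rewire pattern for copying a subgraph. First, every node that must be duplicated is mapped to its copy (with the specified function inputs mapped to the actual arguments instead of being copied); then the copies are walked and any input that also appears in the map is re-pointed to the corresponding copy or argument. The following is a minimal standalone sketch of that pattern; Node, NodePtr and CloneAndRewire are simplified, hypothetical types for illustration, not the CNTK ComputationNode/Duplicate/SetInput API (which additionally handles the ParameterTreatment options):

#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

// Hypothetical minimal graph node; in CNTK this role is played by ComputationNodeBase.
struct Node
{
    std::wstring name;
    std::vector<std::shared_ptr<Node>> inputs;
};
using NodePtr = std::shared_ptr<Node>;

// Clone all nodes in 'dependentSet'; the i-th entry of 'inputNodes' is mapped to the i-th
// actual argument in 'inputArgs' instead of being copied. Inputs of the copies are then
// redirected to the corresponding copies/arguments; inputs outside the set stay shared.
std::map<NodePtr, NodePtr> CloneAndRewire(const std::set<NodePtr>& dependentSet,
                                          const std::vector<NodePtr>& inputNodes,
                                          const std::vector<NodePtr>& inputArgs,
                                          const std::wstring& exprName)
{
    std::map<NodePtr, NodePtr> cloned;
    for (size_t i = 0; i < inputNodes.size(); i++)
        cloned[inputNodes[i]] = inputArgs[i];      // specified inputs map to actual parameters
    std::vector<NodePtr> copies;                   // nodes that were actually duplicated
    for (const auto& node : dependentSet)
    {
        if (cloned.find(node) != cloned.end())
            continue;                              // already mapped above: it is a function input
        auto copy = std::make_shared<Node>(*node); // duplicate; inputs still point at originals
        copy->name = exprName + L"." + node->name;
        cloned[node] = copy;
        copies.push_back(copy);
    }
    for (const auto& copy : copies)                // rewire inputs that were cloned or mapped as well
        for (auto& input : copy->inputs)
        {
            auto it = cloned.find(input);
            if (it != cloned.end())
                input = it->second;
        }
    return cloned;
}

In the patch, Duplicate() and SetInput() take the place of the copy construction and the input reassignment here, and parameterTreatment == constant additionally sets the learning-rate multiplier of cloned parameters to 0.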
diff --git a/Source/ComputationNetworkLib/InputAndParamNodes.h b/Source/ComputationNetworkLib/InputAndParamNodes.h index aaccb75ff..7f4e3049e 100644 --- a/Source/ComputationNetworkLib/InputAndParamNodes.h +++ b/Source/ComputationNetworkLib/InputAndParamNodes.h @@ -21,7 +21,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { // ----------------------------------------------------------------------- template -class LearnableParameter : public ComputationNode, public NumInputs<0> +class LearnableParameter : public ComputationNode, public NumInputs<0>, public IParameterNode { typedef ComputationNode Base; UsingComputationNodeMembersBoilerplate; static const std::wstring TypeName() { return L"LearnableParameter"; } diff --git a/Source/ComputationNetworkLib/TrainingNodes.h b/Source/ComputationNetworkLib/TrainingNodes.h index 1580de11b..ac51acb1f 100644 --- a/Source/ComputationNetworkLib/TrainingNodes.h +++ b/Source/ComputationNetworkLib/TrainingNodes.h @@ -1553,51 +1553,48 @@ template class DropoutNode; // where gamma and beta are trainable parameters(represented as LearnableParameter). // // * input is the input of the batch normalization node -// * scale is a LearnableParameter that stores scale vector(gamma term in the equation above). -// * bias is a LearnableParameter that stores bias vector(beta term). scale and bias must have the same dimensions which must be equal +// * scale is a LearnableParameter that stores scale vector (gamma term in the equation above). +// * bias is a LearnableParameter that stores bias vector (beta term). scale and bias must have the same dimensions which must be equal // to the input dimensions in case of spatial = false or number of output convolution feature maps in case of spatial = true. // * runMean is the running mean which is used during evaluation phase and might be used during training as well. // It is represented as a LearnableParameter with the same dimensions as scale and bias. // * runInvStdDev is the running inverse square root of variance(so InvStdDev = 1 / sqrt(var + epsilon)). // It is represented as a LearnableParameter with the same dimensions as scale and bias. // * spatial is a flag that specifies whether to compute mean / var for each feature in a mininbatch independently or, in case of convolutional layers, per feature map. +// TODO: This must be configured in a generic fashion where tensor axes are chosen along which parameters are tied. // * normalizationTimeConstant is the time constant which is used to compute running average of mean and variance. -// Value 0 (default) means there will be no exponential smoothing and running mean / variance will always have values computed for the last seen mininbatch. -// Value 1#INF (infinity)means running values are "frozen" (i.e.will not be updated). +// Value 0 (default) means there will be no exponential smoothing and running mean/variance will always have values computed for the last seen minibatch. +// Value 1#INF (infinity) means running values are "frozen" (i.e. will not be updated). // * blendTimeConstant is the time constant which allows to specify how much of running mean / var should be "blended" into mean / var of the current minibatch. // Value 0 (default) means no blending will happen and only the current minibatch statistics will be used. -// Value 1#INF (infinity)means only running mean / var will be used(this is used, for example, in evaluation phase). +// Value 1#INF (infinity) means only running mean / var will be used (this is used, for example, in evaluation phase).
// * epsilon is a conditioner constant used in computing InvStdDev -// * useCntkEngine is a boolean flag that specifies which batch normalization implementation to use : CNTK or cuDNN - based. -// * imageLayout is the image layout.Only cudnn is supported. +// * useCntkEngine is a boolean flag that specifies which batch normalization implementation to use: CNTK or cuDNN-based. +// * imageLayout is the image layout. Only cudnn is supported at present. // ----------------------------------------------------------------------- template class BatchNormalizationNode : public ComputationNode, public NumInputs<5> { - typedef ComputationNode Base; - UsingComputationNodeMembersBoilerplate; - static const std::wstring TypeName() - { - return L"BatchNormalization"; - } + typedef ComputationNode Base; UsingComputationNodeMembersBoilerplate; + static const std::wstring TypeName() { return L"BatchNormalization"; } public: - BatchNormalizationNode(DEVICEID_TYPE deviceId, const wstring& name) - : Base(deviceId, name), m_spatial(false), m_normTimeConst(0), m_blendTimeConst(0), m_epsilon(0), m_useCntkEngine(true), + BatchNormalizationNode(DEVICEID_TYPE deviceId, const wstring& name) : + Base(deviceId, name), m_spatial(false), m_normTimeConst(0), m_blendTimeConst(0), m_epsilon(0), m_useCntkEngine(true), m_mbCount(0), m_imageLayoutKind(ImageLayoutKind::CHW) { } BatchNormalizationNode(DEVICEID_TYPE deviceId, const wstring& name, bool spatial, double normalizationTimeConstant, double blendTimeConstant, - double epsilon, bool useCntkEngine, ImageLayoutKind imageLayoutKind) - : Base(deviceId, name), m_spatial(spatial), m_normTimeConst(normalizationTimeConstant), m_blendTimeConst(blendTimeConstant), - m_epsilon(epsilon), m_useCntkEngine(useCntkEngine), m_imageLayoutKind(imageLayoutKind), m_mbCount(0) + double epsilon, bool useCntkEngine, ImageLayoutKind imageLayoutKind) : + Base(deviceId, name), m_spatial(spatial), m_normTimeConst(normalizationTimeConstant), m_blendTimeConst(blendTimeConstant), + m_epsilon(epsilon), m_useCntkEngine(useCntkEngine), m_imageLayoutKind(imageLayoutKind), m_mbCount(0) { } - BatchNormalizationNode(const ScriptableObjects::IConfigRecordPtr configp) - : BatchNormalizationNode(configp->Get(L"deviceId"), L"", configp->Get(L"spatial"), - configp->Get(L"normalizationTimeConstant"), configp->Get(L"blendTimeConstant"), - configp->Get(L"epsilon"), configp->Get(L"useCntkEngine"), - ImageLayoutKindFrom(configp->Get(L"imageLayout"))) + BatchNormalizationNode(const ScriptableObjects::IConfigRecordPtr configp) : + BatchNormalizationNode(configp->Get(L"deviceId"), L"", configp->Get(L"spatial"), + configp->Get(L"normalizationTimeConstant"), configp->Get(L"blendTimeConstant"), + configp->Get(L"epsilon"), configp->Get(L"useCntkEngine"), + ImageLayoutKindFrom(configp->Get(L"imageLayout"))) { AttachInputsFromConfig(configp, this->GetExpectedNumInputs()); } @@ -1689,6 +1686,9 @@ public: } } + // Note: This function assumes that inputIndex=0 is called before the others. + // BUGBUG: The node should not make assumptions in which order the inputs' derivatives are computed. It currently assumes to start with 0. + // BUGBUG: If the input has no learnables (e.g. using BN instead of corpus mean/var norm), this will not be called for inputIndex=0 at all. void BackpropTo(const size_t inputIndex, const FrameRange& fr) override { if (inputIndex == 0) // derivative with respect to the input. @@ -1702,30 +1702,31 @@ public: m_dScale->Resize(scale); m_dBias->Resize(bias); // Compute all derivatives in one step.
Save derivatives with respect to scale and bias in temp matrices. - m_bnEng->Backward(sliceInputValue, sliceOutputGrad, sliceInputGrad, scale, - *m_saveMean, *m_saveInvStdDev, *m_dScale, *m_dBias); + m_bnEng->Backward(sliceInputValue, sliceOutputGrad, // (in) input from below, gradient from above + sliceInputGrad, // (out) gradient for data input goes here + scale, // (in) scaling is needed in gradient propagation + *m_saveMean, *m_saveInvStdDev, // (in) actual interpolated mean/stddev values from ForwardProp(). Note: unused/uninitialized for blendFactor=1. + // BUGBUG: ^^ For blendFactor=1, saveMean/saveInvStdDev are uninitialized; and the running mean/stddev should be passed instead + *m_dScale, *m_dBias); // (out) gradients for scale and bias } else if (inputIndex == 1) // derivative with respect to the scale { // Derivative with respect to the scale was precomputed during input derivative computation. Matrix& grad = Input(1)->Gradient(); grad.SetValue(grad.GetNumRows(), grad.GetNumCols(), grad.GetDeviceId(), m_dScale->Data()); + // BUGBUG: ^^ This should add the gradient, not overwrite it. } else if (inputIndex == 2) // derivative with respect to the bias { // Derivative with respect to the bias was precomputed during input derivative computation. Matrix& grad = Input(2)->Gradient(); grad.SetValue(grad.GetNumRows(), grad.GetNumCols(), grad.GetDeviceId(), m_dBias->Data()); + // BUGBUG: ^^ Also here, this should add the gradient, not overwrite it. } // No derivatives with respect to running mean and InvStdDev. } - virtual bool OutputUsedInComputingInputNodesGradients() const override - { - // The BatchNormalizationNode does not require its output value for computing - // the gradients of its input nodes - return false; - } + virtual bool OutputUsedInComputingInputNodesGradients() const override { return false; } void ForwardProp(const FrameRange& fr) override { @@ -1744,15 +1745,22 @@ public: Matrix sliceOutputValue = ValueFor(fr); + // are we training or in inference mode? + // In inference mode, running estimates are used as the sole estimates, while the MB values are not used at all. + if (Input(3)->IsParameterUpdateRequired() ^ Input(4)->IsParameterUpdateRequired()) + InvalidArgument("BatchNormalization: Either both or none of %ls and %ls must be enabled for model update.", + Input(3)->NodeDescription().c_str(), Input(4)->NodeDescription().c_str()); + bool inferenceMode = + !Environment().IsTraining() || // we are actually inferring + !Input(3)->IsParameterUpdateRequired(); // we are training, but this piece of network has been frozen (e.g. as a fixed feature extractor) + + // determine the factors from the time constants double expAvgFactor; double blendFactor; - if (!Environment().IsTraining()) + if (inferenceMode) // in inference mode, only use long-term mean and do not update running estimates { - expAvgFactor = 0; - blendFactor = 1.0; - - m_saveMean->Resize(0, 0); - m_saveInvStdDev->Resize(0, 0); + expAvgFactor = 0; // no new contribution from current minibatch + blendFactor = 1.0; // estimate is taken 100% from the long-term running estimate } else { @@ -1773,13 +1781,28 @@ public: blendFactor = 1.0; else blendFactor = m_blendTimeConst > 0 ? (m_blendTimeConst / (m_blendTimeConst + numSamples)) : 0; + } + // TODO: These Resize() operations belong INSIDE Forward(). + // Specifically, for blendFactor=1, they must come back resized to (0,0). This is how Backward() will know & use running ones instead. 
+ // I am not fixing this now because I don't know how to identify all variants of Forward(), across engines, CPU/GPU etc. + if (blendFactor == 1.0) + { + m_saveMean->Resize(0, 0); + m_saveInvStdDev->Resize(0, 0); + } + else + { m_saveMean->Resize(runMean); m_saveInvStdDev->Resize(runMean); } - m_bnEng->Forward(sliceInputValue, scale, bias, expAvgFactor, blendFactor, runMean, runInvStdDev, - sliceOutputValue, m_epsilon, *m_saveMean, *m_saveInvStdDev); + m_bnEng->Forward(/*in=*/ sliceInputValue, scale, bias, // (in) + expAvgFactor, blendFactor, + runMean, runInvStdDev, // (in/out) running estimates, updated from the current MB mean/stddev + /*out=*/ sliceOutputValue, // (out) batch-normalized output value + m_epsilon, + *m_saveMean, *m_saveInvStdDev); // (out) actual interpolated mean/stddev values. Note: unused/untouched for blendFactor==1 m_mbCount++; } @@ -1820,25 +1843,25 @@ public: void RequestMatricesBeforeForwardProp(MatrixPool& matrixPool) override { Base::RequestMatricesBeforeForwardProp(matrixPool); - RequestMatrixFromPool(m_saveMean, matrixPool); - RequestMatrixFromPool(m_saveInvStdDev, matrixPool); - } + RequestMatrixFromPool(m_saveMean, matrixPool); + RequestMatrixFromPool(m_saveInvStdDev, matrixPool); + } void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) override { Base::RequestMatricesBeforeBackprop(matrixPool); - RequestMatrixFromPool(m_dScale, matrixPool); - RequestMatrixFromPool(m_dBias, matrixPool); - } + RequestMatrixFromPool(m_dScale, matrixPool); + RequestMatrixFromPool(m_dBias, matrixPool); + } void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) override { Base::ReleaseMatricesAfterBackprop(matrixPool); - ReleaseMatrixToPool(m_saveMean, matrixPool); - ReleaseMatrixToPool(m_saveInvStdDev, matrixPool); - ReleaseMatrixToPool(m_dScale, matrixPool); - ReleaseMatrixToPool(m_dBias, matrixPool); - } + ReleaseMatrixToPool(m_saveMean, matrixPool); + ReleaseMatrixToPool(m_saveInvStdDev, matrixPool); + ReleaseMatrixToPool(m_dScale, matrixPool); + ReleaseMatrixToPool(m_dBias, matrixPool); + } void SetNormalizationTimeConstants(double normalizationTimeConstant, double prevNormalizationTimeConstant, double blendTimeConstant, double prevBlendTimeConstant) @@ -1888,13 +1911,12 @@ private: // Minibatch count, used to compute cumulative moving average. size_t m_mbCount; - // Stores pre-computed on forward pass mean values that are used in gradient computation. + // Interpolated actual mean/stddev values. Pre-computed on forward pass, also used in gradient computation. shared_ptr> m_saveMean; - // Stores pre-computed on forward pass InvStdDev values that are used in gradient computation. shared_ptr> m_saveInvStdDev; - // Stores scale derivatives + // Temp buffer for scale and bias derivatives. Only used in BackpropTo(), carrying info from first call to subsequent calls. + // Not used for blendFactor=1. shared_ptr> m_dScale; - // Stores bias derivatives. 
shared_ptr> m_dBias; std::unique_ptr> m_bnEng; diff --git a/Source/Math/BatchNormalizationEngine.h b/Source/Math/BatchNormalizationEngine.h index 89252f7d0..58e813033 100644 --- a/Source/Math/BatchNormalizationEngine.h +++ b/Source/Math/BatchNormalizationEngine.h @@ -55,6 +55,7 @@ protected: virtual void EnsureCompatible() = 0; + // saveMean/saveInvStdDev return the actual mean/stddev used for normalization, except for blendFactor=1, these are unused and untouched virtual void ForwardCore(const Mat& in, const Mat& scale, const Mat& bias, double expAvgFactor, double blendFactor, Mat& runMean, Mat& runInvStdDev, Mat& out, double epsilon, Mat& saveMean, Mat& saveInvStdDev) = 0; @@ -70,4 +71,4 @@ protected: #pragma warning(pop) -} } } +}}} diff --git a/Source/Math/CntkBatchNormalization.cuh b/Source/Math/CntkBatchNormalization.cuh index 5ca447537..c0719f3f6 100644 --- a/Source/Math/CntkBatchNormalization.cuh +++ b/Source/Math/CntkBatchNormalization.cuh @@ -162,8 +162,12 @@ void Call(size_t vectorSize, Targs... args) // As a result, each block has 2 * blockDim.x (mean and inverse stddev) values to write at the end. // template -__global__ void kComputeBatchMeanAndInvStdDev(int vectorSize, int batchSize, const ElemType* x, double expAvgFactor, ElemType* runMean, ElemType* runInvStdDev, - double epsilon, ElemType* xMean, ElemType* xInvStdDev) +__global__ void kComputeBatchMeanAndInvStdDev(int vectorSize, int batchSize, + const ElemType* x, // (in) input data + double expAvgFactor, + ElemType* runMean, ElemType* runInvStdDev, // (in/out) running mean/stddev, gets updated with current minibatch + double epsilon, + ElemType* xMean, ElemType* xInvStdDev) // (out) this minibatch's mean { static_assert(BlockDimX * U == CUB_PTX_WARP_THREADS, "BlockDimX * U must be equal to warp size (32)."); static_assert((BlockDimX * BlockDimY % CUB_PTX_WARP_THREADS) == 0, "Block size must be a multiple of warp size (32)."); @@ -181,9 +185,12 @@ __global__ void kComputeBatchMeanAndInvStdDev(int vectorSize, int batchSize, con return; assert(irowSrcBase + U <= vectorSize); + // --- estimate this minibatch's mean/stddev + + // first estimate mean over all data for this thread int n = 0; - ElemType mean[U]; - ElemType m2[U]; + ElemType mean[U]; // this thread's part of the mean vector (stored as a normalized mean also during accumulation) + ElemType m2[U]; // likewise for stdev #pragma unroll for (int k = 0; k < U; k++) { @@ -206,12 +213,13 @@ __global__ void kComputeBatchMeanAndInvStdDev(int vectorSize, int batchSize, con ElemType d = curVal[k] - mean[k]; // REVIEW alexeyk: we enabled fast CUDA math in CNTK so division below will be approximate, is this a problem? // Using precise math slows down the code by about 40%. - mean[k] += d / n; + mean[k] += d / n; // mean_n = [mean_{n-1} * (n-1) + curVal] / n = mean_{n-1} *n/n - mean_{n-1} / n + curVal / n m2[k] += d * (curVal[k] - mean[k]); } psrc += vectorSize * BlockDimY; } + // now reduce minibatch mean/stddev across threads const int tid = threadIdx.y * BlockDimX + threadIdx.x; const int laneId = tid & 0x1f; // First, reduce within warp using shuffle. @@ -258,6 +266,8 @@ __global__ void kComputeBatchMeanAndInvStdDev(int vectorSize, int batchSize, con } __syncthreads(); + // --- final reduction and update of running mean/stddev + // Accumulate and write final results. // REVIEW alexeyk: see if atomicAdd can be used instead, do perf comparison. 
if (threadIdx.y == 0) @@ -282,7 +292,10 @@ __global__ void kComputeBatchMeanAndInvStdDev(int vectorSize, int batchSize, con size_t idxDstBase = (blockIdx.x * BlockDimX + threadIdx.x) * U; // Store mean and running mean. StoreValues(mean, xMean + idxDstBase); - if (expAvgFactor == 1) + // at this point, minibatch mean has been saved into xMean[] + + // accumulate running mean + if (expAvgFactor == 1) // 100% comes from current minibatch, nothing from history StoreValues(mean, runMean + idxDstBase); else { @@ -293,6 +306,8 @@ __global__ void kComputeBatchMeanAndInvStdDev(int vectorSize, int batchSize, con run[k] = expAvgFactor * mean[k] + (1.0 - expAvgFactor) * run[k]; StoreValues(run, runMean + idxDstBase); } + // at this point, runMean[] has been updated + // Store inv std dev and its running version. #pragma unroll for (int k = 0; k < U; k++) @@ -300,6 +315,8 @@ __global__ void kComputeBatchMeanAndInvStdDev(int vectorSize, int batchSize, con m2[k] = Operations::RSqrt(static_cast(m2[k] / batchSize + epsilon)); } StoreValues(m2, xInvStdDev + idxDstBase); + // at this point, minibatch stddev has been saved into xInvStdDev[] + if (expAvgFactor == 1) StoreValues(m2, runInvStdDev + idxDstBase); else @@ -311,6 +328,7 @@ __global__ void kComputeBatchMeanAndInvStdDev(int vectorSize, int batchSize, con run[k] = expAvgFactor * m2[k] + (1.0 - expAvgFactor) * run[k]; StoreValues(run, runInvStdDev + idxDstBase); } + // at this point, runInvStdDev[] has been updated } } @@ -466,8 +484,13 @@ template struct ComputeBatchMeanAndInvStdDev { template - static void Call(size_t vectorSize, size_t batchSize, const ElemType* x, double expAvgFactor, ElemType* runMean, ElemType* runInvStdDev, - double epsilon, ElemType* xMean, ElemType* xInvStdDev, cudaStream_t stream) + static void Call(size_t vectorSize, size_t batchSize, + const ElemType* x, // (in) input data + double expAvgFactor, + ElemType* runMean, ElemType* runInvStdDev, // (in/out) running mean/stddev, gets updated with current minibatch + double epsilon, + ElemType* xMean, ElemType* xInvStdDev, // (out) actual interpolated mean/stddev that are used to normalize. Returned since needed in backprop. 
+ cudaStream_t stream) { assert((vectorSize % U) == 0); @@ -593,8 +616,11 @@ template struct NormalizeBatchTraining { template - static void Call(size_t vectorSize, size_t spatialSize, size_t batchSize, bool spatial, const ElemType* x, ElemType* y, - const ElemType* bnScale, const ElemType* bnBias, const ElemType* batchMean, const ElemType* batchInvStdDev, cudaStream_t stream) + static void Call(size_t vectorSize, size_t spatialSize, size_t batchSize, bool spatial, + const ElemType* x, ElemType* y, // (in, out) data to normalize -> normalized data + const ElemType* bnScale, const ElemType* bnBias, // (in) scale/bias to denormalize with + const ElemType* batchMean, const ElemType* batchInvStdDev, // (in) actual mean/stddev to normalize with + cudaStream_t stream) { assert((vectorSize % U) == 0); diff --git a/Source/Math/GPUMatrix.cu b/Source/Math/GPUMatrix.cu index 1330cc0e3..a6a78014b 100644 --- a/Source/Math/GPUMatrix.cu +++ b/Source/Math/GPUMatrix.cu @@ -3107,6 +3107,7 @@ void GPUMatrix::AveragePoolingBackward(const GPUMatrix& mpRowCol, Data(), (int)GetNumRows(), grad.Data(), (int)grad.GetNumRows()); } +// returns saveMean/saveInvStdDev which are the actual values used to perform the normalization, except for blendFactor 1, in which case they are unused and untouched template void GPUMatrix::BatchNormalizationForward(const GPUMatrix& scale, const GPUMatrix& bias, double expAvgFactor, double blendFactor, GPUMatrix& runMean, GPUMatrix& runInvStdDev, GPUMatrix& out, double epsilon, @@ -3122,6 +3123,7 @@ void GPUMatrix::BatchNormalizationForward(const GPUMatrix& s assert(0 < vectorSize && vectorSize <= std::numeric_limits::max()); assert(0 < batchSize && batchSize <= std::numeric_limits::max()); + // --- compute data mean/stddev (into saveMean/saveInvStdDev) and update running mean/stddev SyncGuard syncGuard; // If expAvgFactor == 0 && blendFactor == 1 then we don't need to compute current minibatch statistics. if (expAvgFactor > 0 || blendFactor < 1) @@ -3139,12 +3141,15 @@ void GPUMatrix::BatchNormalizationForward(const GPUMatrix& s saveMean.Data(), saveInvStdDev.Data(), GetStream()); } } + + // --- apply MAP estimates of mean/stddev (interpolation of data and running mean/stddev) to data // When: // blendFactor == 1 - use running mean/var instead of the current minibatch mean/var. // 0 < blendFactor < 1 - blend running mean/var with mean/var of the current minibatch: saveMean = (1 - blendFactor) * saveMean + blendFactor * runMean - // blendFactor == 0 - use mean/var of the current minibatch. + // blendFactor == 0 - use mean/var of the current minibatch. Note: saveMean/saveInvStdDev are NOT updated. if (blendFactor < 1) { + // non-zero blendFactor: interpolate minibatch mean/stddev in-place with running mean/stddev if (blendFactor > 0) { // REVIEW alexeyk: can be rolled into NormalizeBatchTraining to save bandwidth. @@ -3154,18 +3159,26 @@ void GPUMatrix::BatchNormalizationForward(const GPUMatrix& s Scale((ElemType)(1 - blendFactor), saveInvStdDev); ScaleAndAdd((ElemType)blendFactor, runInvStdDev, saveInvStdDev); } - Call(spatial ? spatialSize : vectorSize, vectorSize, spatialSize, batchSize, - spatial, Data(), out.Data(), scale.Data(), bias.Data(), - saveMean.Data(), saveInvStdDev.Data(), GetStream()); + // normalize + Call(spatial ? 
spatialSize : vectorSize, vectorSize, spatialSize, batchSize, spatial, + Data(), out.Data(), // (in, out) data to be normalized -> normalized data + scale.Data(), bias.Data(), // (in) scale/bias to denormalize with + /*(in)*/saveMean.Data(), saveInvStdDev.Data(), // (in) actual mean/stddev to normalize with + GetStream()); } - else + else // blendFactor == 1: use running mean/stddev only { - Call(spatial ? spatialSize : vectorSize, vectorSize, spatialSize, batchSize, - spatial, Data(), out.Data(), scale.Data(), bias.Data(), + assert(saveMean.IsEmpty() && saveInvStdDev.IsEmpty()); // TODO: We should rather Resize() them in here. + // TODO: require saveMean/saveInvStdDev to be passed in as empty matrices, to clarify/enforce the semantics + Call(spatial ? spatialSize : vectorSize, vectorSize, spatialSize, batchSize, spatial, + Data(), out.Data(), + scale.Data(), bias.Data(), runMean.Data(), runInvStdDev.Data(), GetStream()); } } +// saveMean/saveInvStdDev are the interpolated mean/stddev as used in ForwardProp(). +// BUGBUG (in call site): For blendFactor=1, they are uninitialized. Caller must pass running mean/stddev instead in that case. template void GPUMatrix::BatchNormalizationBackward(const GPUMatrix& in, GPUMatrix& grad, const GPUMatrix& scale, const GPUMatrix& saveMean, const GPUMatrix& saveInvStdDev,
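As a closing reference for the batch-normalization changes above: expAvgFactor determines how much of the current minibatch statistic enters the running estimates (runMean/runInvStdDev), while blendFactor determines how much of the (updated) running estimate is blended back into the statistic actually used for normalization and backprop (saveMean/saveInvStdDev); blendFactor == 1 means "use running estimates only", which is what inference mode selects. Below is a small self-contained sketch of that arithmetic, using plain vectors instead of the CNTK Matrix/GPUMatrix types; the blendFactor formula mirrors the one visible in ForwardProp(), the rest is illustrative:

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// blendFactor as derived in ForwardProp() from the blend time constant and the minibatch size.
double BlendFactor(double blendTimeConst, size_t numSamples)
{
    if (std::isinf(blendTimeConst))
        return 1.0;                                 // use running estimates only (inference)
    return blendTimeConst > 0 ? blendTimeConst / (blendTimeConst + numSamples) : 0.0;
}

// Per-element update of the running statistic and computation of the statistic used to normalize.
// runStat:  running mean or running InvStdDev, updated in place with the minibatch statistic
// mbStat:   statistic computed from the current minibatch
// saveStat: interpolated statistic that normalization (and the backward pass) would actually use
void UpdateBatchNormStats(std::vector<double>& runStat, const std::vector<double>& mbStat,
                          double expAvgFactor, double blendFactor, std::vector<double>& saveStat)
{
    assert(runStat.size() == mbStat.size());
    saveStat.resize(runStat.size());
    for (size_t i = 0; i < runStat.size(); i++)
    {
        // exponential smoothing of the running estimate, as in kComputeBatchMeanAndInvStdDev
        runStat[i] = expAvgFactor * mbStat[i] + (1.0 - expAvgFactor) * runStat[i];
        // blend the running estimate into the minibatch statistic, as in BatchNormalizationForward
        saveStat[i] = blendFactor * runStat[i] + (1.0 - blendFactor) * mbStat[i];
    }
}

With expAvgFactor == 0 and blendFactor == 1 (the inference-mode setting in ForwardProp()), the running estimates stay frozen and saveStat reduces to the running estimate itself, which is why the patch resizes saveMean/saveInvStdDev to (0, 0) in that case and flags in a BUGBUG that Backward() should then be given the running values instead.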