heavily commented the batch-normalization code, including notes on several bugs;

new interface IParameterNode for identifying LearnableParameters;
first implementation of CloneFunctionConfigLambda (except for returning the result)
This commit is contained in:
Frank Seide 2016-07-21 17:37:44 -07:00
Parent 614762c03b
Commit 3d70ff34e0
8 changed files with 232 additions and 80 deletions


@ -106,13 +106,13 @@ void ComputationNetwork::FormRecurrentLoops(const ComputationNodeBasePtr& rootNo
assert(node->m_numNonDelayedParentsInLoop == 0); // (in PurgeStateForFormingRecurrentLoops())
}
for (let& node : nestedNodes)
{
for (auto& input : node->GetInputs())
{
if (input->m_loopId == node->m_loopId && GetRecurrenceSteppingDirection(node) == 0/*not a Delay node*/)
input->m_numNonDelayedParentsInLoop++; // count #parents of 'input' that are not delay nodes
}
}
}
// re-traverse the graph for all nestedNodes, starting with the first
// Then update m_nestedNodes with the re-traversed order.


@ -390,7 +390,7 @@ ScriptableObjects::ConfigurableRuntimeTypeRegister::Add<ComputationNetworkWithEd
// - Input() nodes not listed as `inputNodes` are always shared
// - the source network may be a different network, e.g. loaded with BS.Network.Load()
// - a deep copy can be read-only (parameters="constant")
// - Note: multiple uses of the lambda will not share read-only parameters. This is trickier to implement than one might expect.
// - example use cases:
// - adaptation (KL): a frozen read-only copy of the starting model is used as a KL-regularizer
// - adaptation (DLR): an injected input transform is trained while the network is fixed
@ -481,6 +481,49 @@ public:
else if (parametersOption == L"constant") parameterTreatment = ParameterTreatment::constant;
else if (parametersOption == L"shared") parameterTreatment = ParameterTreatment::shared;
else InvalidArgument("CloneFunction: 'parameters' option must be 'learnable', 'constant', or 'shared'.");
// determine which nodes must be cloned
// - intersection of:
// - all indirect inputs of the specified outputs
// - all dependents of leaves
// - where leaves are:
// - specified inputs
// - unless parameters="shared": all parameters
// determine all indirect inputs of the specified outputs
vector<ComputationNodeBasePtr> roots;
for (let& outputNodeKV : outputNodes)
roots.push_back(outputNodeKV.second);
let allInputs = ComputationNodeBase::EnumerateNodes(roots);
// determine all leaves and their dependents
dependentSet = set<ComputationNodeBasePtr>(inputNodes.begin(), inputNodes.end()); // start with the specified inputs
for (let& node : allInputs)
{
// add parameters that are to be cloned to dependent set
if (parameterTreatment != ParameterTreatment::shared && node->Is<IParameterNode>())
dependentSet.insert(node);
// if at least one input is in the dependent set then this node is, too
else
for (let& input : node->GetInputs())
if (dependentSet.find(input) != dependentSet.end())
dependentSet.insert(node);
}
#if 1
for (let& node : dependentSet)
fprintf(stderr, "CloneFunction: cloning %ls\n", node->NodeDescription().c_str());
#endif
// ensure none of the specified inputs reference back into the cloned set
// The function we extract must be separable.
for (let& input : inputNodes)
for (let& node : ComputationNodeBase::EnumerateNodes(vector<ComputationNodeBasePtr>{input})) // check all indirect inputs of each specified input
{
let iter = dependentSet.find(node);
if (iter != dependentSet.end() && node != input)
InvalidArgument("CloneFunction: specified function input %ls recursively depends on %ls inside the function.", input->NodeDescription().c_str(), node->NodeDescription().c_str());
}
}
private:
@ -509,9 +552,49 @@ private:
// This will clone all nodes that the outputNodes depend on, and rewire inputs matching inputNodes to inputArgs.
ConfigValuePtr DoClone(const vector<ConfigValuePtr>& inputValues, const std::wstring& exprName)
{
// resolve the input arguments
vector<ComputationNodeBasePtr> inputs;
for (let& inputValue : inputValues)
inputs.push_back(inputValue.ResolveValue());
assert(inputValues.size() == inputNodes.size()); // (this should have been checked by BrainScript)
// clone everything in the dependent set
// - specified inputs get mapped to actual parameters
// - all others get duplicated
// Note that at this point, the "shared" option has already been considered,
// and is reflected in whether parameters are included or not in 'dependentSet'.
map<ComputationNodeBasePtr, ComputationNodeBasePtr> clonedNodes;
for (size_t i = 0; i < inputNodes.size(); i++)
clonedNodes[inputNodes[i]] = inputs[i];
for (let& node : dependentSet)
{
// if already there then it's an input that we just mapped above
if (clonedNodes.find(node) != clonedNodes.end())
continue;
// clone
ComputationNodeBasePtr newNode;
let newName = exprName + L"." + node->GetName();
newNode = node->Duplicate(newName, CopyNodeFlags::copyNodeAll);
// make it read-only if desired
if (parameterTreatment == ParameterTreatment::constant)
newNode->SetLearningRateMultiplier(0);
// and that's our cloned node
clonedNodes[node] = newNode;
}
// all cloned nodes' inputs must be redirected if they reference a node that has been cloned as well
for (let& clonedNodesKV : clonedNodes)
{
let& node = clonedNodesKV.second;
let& inputs = node->GetInputs();
for (size_t i = 0; i < inputs.size(); i++)
{
let iter = clonedNodes.find(inputs[i]);
if (iter != clonedNodes.end()) // input is also a cloned node
node->SetInput(i, iter->second);
}
}
return ConfigValuePtr();
}
@ -521,7 +604,7 @@ private:
map<wstring, ComputationNodeBasePtr> outputNodes;
ParameterTreatment parameterTreatment;
// other
map<ComputationNodeBasePtr, ComputationNodeBasePtr> clonedReadOnlyParameters; // if we clone 'constant' multiple times, we can share readOnly parameters
set<ComputationNodeBasePtr> dependentSet; // set of nodes that outputNodes depend on
};
ScriptableObjects::ConfigurableRuntimeTypeRegister::Add<CloneFunctionConfigLambda> registerCloneFunctionConfigLambda(L"CloneFunctionConfigLambda");


@ -184,7 +184,7 @@ protected: // TODO: should be fully encapsulated here
bool m_needsGradient; // true if this node or any children need a gradient to be computed (for own consumption or propagation to somewhere in the child tree)
bool m_valueSharable; // a flag is needed for memory share.
// If it is false (e.g., LearnableParameters/InputValue and those nodes are solely induced by LearnableParameters),
// it will never be released to memory pool
private:
bool m_isPartOfLoop; // true if this loop is part of a recurrent loop
@ -1891,6 +1891,13 @@ public:
struct IRecurrentNode { virtual int GetRecurrenceSteppingDirection() const = 0; };
// =======================================================================
// IParameterNode -- interface implemented by ComputationNodes that are parameters
// Note: There is possibly code that identifies parameters by the type name instead. Should be unified.
// =======================================================================
struct IParameterNode { virtual ~IParameterNode() { } };
// =======================================================================
// PreComputedNodeBase -- interface implemented by ComputationNodes that precompute
// TODO: We can use this interface in more places.


@ -21,7 +21,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// -----------------------------------------------------------------------
template <class ElemType>
class LearnableParameter : public ComputationNode<ElemType>, public NumInputs<0>, public IParameterNode
{
typedef ComputationNode<ElemType> Base; UsingComputationNodeMembersBoilerplate;
static const std::wstring TypeName() { return L"LearnableParameter"; }


@ -1553,51 +1553,48 @@ template class DropoutNode<double>;
// where gamma and beta are trainable parameters(represented as LearnableParameter).
//
// * input is the input of the batch normalization node
// * scale is a LearnableParameter that stores scale vector (gamma term in the equation above).
// * bias is a LearnableParameter that stores bias vector (beta term). scale and bias must have the same dimensions which must be equal
// to the input dimensions in case of spatial = false or number of output convolution feature maps in case of spatial = true.
// * runMean is the running mean which is used during evaluation phase and might be used during training as well.
// It is represented as a LearnableParameter with the same dimensions as scale and bias.
// * runInvStdDev is the running inverse square root of variance (so InvStdDev = 1 / sqrt(var + epsilon)).
// It is represented as a LearnableParameter with the same dimensions as scale and bias.
// * spatial is a flag that specifies whether to compute mean / var for each feature in a minibatch independently or, in case of convolutional layers, per feature map.
// TODO: This must be configured in a generic fashion where tensor axes are chosen along which parameters are tied.
// * normalizationTimeConstant is the time constant which is used to compute running average of mean and variance.
// Value 0 (default) means there will be no exponential smoothing and running mean/variance will always have values computed for the last seen minibatch.
// Value 1#INF (infinity) means running values are "frozen" (i.e. will not be updated).
// * blendTimeConstant is the time constant which specifies how much of the running mean / var should be "blended" into the mean / var of the current minibatch.
// Value 0 (default) means no blending will happen and only the current minibatch statistics will be used.
// Value 1#INF (infinity) means only running mean / var will be used (this is used, for example, in evaluation phase).
// (A short sketch after this comment block shows how both time constants map to per-minibatch factors.)
// * epsilon is a conditioner constant used in computing InvStdDev
// * useCntkEngine is a boolean flag that specifies which batch normalization implementation to use: CNTK or cuDNN-based.
// * imageLayout is the image layout. Only cudnn is supported at present.
// -----------------------------------------------------------------------
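// A minimal sketch (illustration only, not the node's code) of how the two time constants above can be
// mapped to the per-minibatch factors used in ForwardProp() below. The blendFactor formula matches the one
// in ForwardProp(); the exp()-based expAvgFactor mapping and the helper name TimeConstantsToFactors are
// assumptions for illustration. Uses std::exp/std::isinf from <cmath>.
static void TimeConstantsToFactors(double normTimeConst, double blendTimeConst, size_t numSamples,
                                   double& expAvgFactor, double& blendFactor)
{
    // expAvgFactor: weight of the current minibatch when updating the running mean/variance
    if (std::isinf(normTimeConst))  expAvgFactor = 0; // "frozen": running estimates are never updated
    else if (normTimeConst <= 0)    expAvgFactor = 1; // no smoothing: running estimates = last seen minibatch
    else                            expAvgFactor = 1 - std::exp(-(double)numSamples / normTimeConst);
    // blendFactor: weight of the running estimates when normalizing the current minibatch
    if (std::isinf(blendTimeConst)) blendFactor = 1;  // use running estimates only (e.g. evaluation)
    else                            blendFactor = blendTimeConst > 0 ? blendTimeConst / (blendTimeConst + numSamples) : 0;
}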
template <class ElemType>
class BatchNormalizationNode : public ComputationNode<ElemType>, public NumInputs<5>
{
typedef ComputationNode<ElemType> Base; UsingComputationNodeMembersBoilerplate;
static const std::wstring TypeName() { return L"BatchNormalization"; }
public:
BatchNormalizationNode(DEVICEID_TYPE deviceId, const wstring& name) :
Base(deviceId, name), m_spatial(false), m_normTimeConst(0), m_blendTimeConst(0), m_epsilon(0), m_useCntkEngine(true),
m_mbCount(0), m_imageLayoutKind(ImageLayoutKind::CHW)
{
}
BatchNormalizationNode(DEVICEID_TYPE deviceId, const wstring& name, bool spatial, double normalizationTimeConstant, double blendTimeConstant,
double epsilon, bool useCntkEngine, ImageLayoutKind imageLayoutKind) :
Base(deviceId, name), m_spatial(spatial), m_normTimeConst(normalizationTimeConstant), m_blendTimeConst(blendTimeConstant),
m_epsilon(epsilon), m_useCntkEngine(useCntkEngine), m_imageLayoutKind(imageLayoutKind), m_mbCount(0)
{
}
BatchNormalizationNode(const ScriptableObjects::IConfigRecordPtr configp) :
BatchNormalizationNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"spatial"),
configp->Get(L"normalizationTimeConstant"), configp->Get(L"blendTimeConstant"),
configp->Get(L"epsilon"), configp->Get(L"useCntkEngine"),
ImageLayoutKindFrom(configp->Get(L"imageLayout")))
{
AttachInputsFromConfig(configp, this->GetExpectedNumInputs());
}
@ -1689,6 +1686,9 @@ public:
}
}
// Note: This function assumes that inputIndex=0 is called before the others.
// BUGBUG: The node should not make assumptions about the order in which the inputs' derivatives are computed. It currently assumes the computation starts with inputIndex=0.
// BUGBUG: If the input has no learnables (e.g. using BN instead of corpus mean/var norm), this will not be called for inputIndex=0 at all.
void BackpropTo(const size_t inputIndex, const FrameRange& fr) override
{
if (inputIndex == 0) // derivative with respect to the input.
@ -1702,30 +1702,31 @@ public:
m_dScale->Resize(scale);
m_dBias->Resize(bias);
// Compute all derivatives in one step. Save derivatives with respect to scale and bias in temp matrices.
m_bnEng->Backward(sliceInputValue, sliceOutputGrad, // (in) input from below, gradient from above
sliceInputGrad, // (out) gradient for data input goes here
scale, // (in) scaling is needed in gradient propagation
*m_saveMean, *m_saveInvStdDev, // (in) actual interpolated mean/stddev values from ForwardProp(). Note: unused/uninitialized for blendFactor=1.
// BUGBUG: ^^ For blendFactor=1, saveMean/saveInvStdDev are uninitialized; and the running mean/stddev should be passed instead
*m_dScale, *m_dBias); // (out) gradients for scale and bias
}
else if (inputIndex == 1) // derivative with respect to the scale
{
// Derivative with respect to the scale was precomputed during input derivative computation.
Matrix<ElemType>& grad = Input(1)->Gradient();
grad.SetValue(grad.GetNumRows(), grad.GetNumCols(), grad.GetDeviceId(), m_dScale->Data());
// BUGBUG: ^^ This should add the gradient, not overwrite it.
}
else if (inputIndex == 2) // derivative with respect to the bias
{
// Derivative with respect to the bias was precomputed during input derivative computation.
Matrix<ElemType>& grad = Input(2)->Gradient();
grad.SetValue(grad.GetNumRows(), grad.GetNumCols(), grad.GetDeviceId(), m_dBias->Data());
// BUGBUG: ^^ Also here, this should add the gradient, not overwrite it.
}
// No derivatives with respect to running mean and InvStdDev.
}
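// A minimal sketch (an assumption, not the committed fix) of the accumulation that the two BUGBUG comments
// in BackpropTo() above call for: the precomputed scale/bias derivative should be added into the parameter's
// gradient rather than overwrite it. ScaleAndAdd(alpha, a, c) computes c += alpha * a; the helper name
// AccumulateParameterGradient is illustrative only.
template <class ElemTypeX>
static void AccumulateParameterGradient(const Matrix<ElemTypeX>& precomputedDerivative, Matrix<ElemTypeX>& parameterGradient)
{
    Matrix<ElemTypeX>::ScaleAndAdd((ElemTypeX)1, precomputedDerivative, parameterGradient); // parameterGradient += precomputedDerivative
}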
virtual bool OutputUsedInComputingInputNodesGradients() const override { return false; }
void ForwardProp(const FrameRange& fr) override
{
@ -1744,15 +1745,22 @@ public:
Matrix<ElemType> sliceOutputValue = ValueFor(fr);
// are we training or in inference mode?
// In inference mode, running estimates are used as the sole estimates, while the MB values are not used at all.
if (Input(3)->IsParameterUpdateRequired() ^ Input(4)->IsParameterUpdateRequired())
InvalidArgument("BatchNormalization: Either both or none of %ls and %ls must be enabled for model update.",
Input(3)->NodeDescription().c_str(), Input(4)->NodeDescription().c_str());
bool inferenceMode =
!Environment().IsTraining() || // we are actually inferring
!Input(3)->IsParameterUpdateRequired(); // we are training, but this piece of network has been frozen (e.g. as a fixed feature extractor)
// determine the factors from the time constants
double expAvgFactor;
double blendFactor;
if (inferenceMode) // in inference mode, only use long-term mean and do not update running estimates
{
expAvgFactor = 0; // no new contribution from current minibatch
blendFactor = 1.0; // estimate is taken 100% from the long-term running estimate
}
else
{
@ -1773,13 +1781,28 @@ public:
blendFactor = 1.0;
else
blendFactor = m_blendTimeConst > 0 ? (m_blendTimeConst / (m_blendTimeConst + numSamples)) : 0;
}
// TODO: These Resize() operations belong INSIDE Forward().
// Specifically, for blendFactor=1, they must come back resized to (0,0). This is how Backward() will know & use running ones instead.
// I am not fixing this now because I don't know how to identify all variants of Forward(), across engines, CPU/GPU etc.
if (blendFactor == 1.0)
{
m_saveMean->Resize(0, 0);
m_saveInvStdDev->Resize(0, 0);
}
else
{
m_saveMean->Resize(runMean);
m_saveInvStdDev->Resize(runMean);
}
m_bnEng->Forward(/*in=*/ sliceInputValue, scale, bias, // (in)
expAvgFactor, blendFactor,
runMean, runInvStdDev, // (in/out) running estimates, updated from the current MB mean/stddev
/*out=*/ sliceOutputValue, // (out) batch-normalized output value
m_epsilon,
*m_saveMean, *m_saveInvStdDev); // (out) actual interpolated mean/stddev values. Note: unused/untouched for blendFactor==1
m_mbCount++;
}
@ -1820,25 +1843,25 @@ public:
void RequestMatricesBeforeForwardProp(MatrixPool& matrixPool) override
{
Base::RequestMatricesBeforeForwardProp(matrixPool);
RequestMatrixFromPool(m_saveMean, matrixPool);
RequestMatrixFromPool(m_saveInvStdDev, matrixPool);
}
void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) override
{
Base::RequestMatricesBeforeBackprop(matrixPool);
RequestMatrixFromPool(m_dScale, matrixPool);
RequestMatrixFromPool(m_dBias, matrixPool);
}
void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) override
{
Base::ReleaseMatricesAfterBackprop(matrixPool);
ReleaseMatrixToPool(m_saveMean, matrixPool);
ReleaseMatrixToPool(m_saveInvStdDev, matrixPool);
ReleaseMatrixToPool(m_dScale, matrixPool);
ReleaseMatrixToPool(m_dBias, matrixPool);
}
void SetNormalizationTimeConstants(double normalizationTimeConstant, double prevNormalizationTimeConstant,
double blendTimeConstant, double prevBlendTimeConstant)
@ -1888,13 +1911,12 @@ private:
// Minibatch count, used to compute cumulative moving average.
size_t m_mbCount;
// Interpolated actual mean/stddev values. Pre-computed on forward pass, also used in gradient computation.
shared_ptr<Matrix<ElemType>> m_saveMean;
shared_ptr<Matrix<ElemType>> m_saveInvStdDev;
// Temp buffer for scale and bias derivatives. Only used in BackpropTo(), carrying info from first call to subsequent calls.
// Not used for blendFactor=1.
shared_ptr<Matrix<ElemType>> m_dScale;
shared_ptr<Matrix<ElemType>> m_dBias;
std::unique_ptr<BatchNormEngine<ElemType>> m_bnEng;


@ -55,6 +55,7 @@ protected:
virtual void EnsureCompatible() = 0;
// saveMean/saveInvStdDev return the actual mean/stddev used for normalization, except for blendFactor=1, in which case they are unused and untouched
virtual void ForwardCore(const Mat& in, const Mat& scale, const Mat& bias, double expAvgFactor, double blendFactor, Mat& runMean, Mat& runInvStdDev,
Mat& out, double epsilon, Mat& saveMean, Mat& saveInvStdDev) = 0;
@ -70,4 +71,4 @@ protected:
#pragma warning(pop)
}}}


@ -162,8 +162,12 @@ void Call(size_t vectorSize, Targs... args)
// As a result, each block has 2 * blockDim.x (mean and inverse stddev) values to write at the end.
//
template <int BlockDimX, int BlockDimY, int U, typename ElemType>
__global__ void kComputeBatchMeanAndInvStdDev(int vectorSize, int batchSize,
const ElemType* x, // (in) input data
double expAvgFactor,
ElemType* runMean, ElemType* runInvStdDev, // (in/out) running mean/stddev, gets updated with current minibatch
double epsilon,
ElemType* xMean, ElemType* xInvStdDev) // (out) this minibatch's mean and inverse stddev
{
static_assert(BlockDimX * U == CUB_PTX_WARP_THREADS, "BlockDimX * U must be equal to warp size (32).");
static_assert((BlockDimX * BlockDimY % CUB_PTX_WARP_THREADS) == 0, "Block size must be a multiple of warp size (32).");
@ -181,9 +185,12 @@ __global__ void kComputeBatchMeanAndInvStdDev(int vectorSize, int batchSize, con
return;
assert(irowSrcBase + U <= vectorSize);
// --- estimate this minibatch's mean/stddev
// first estimate mean over all data for this thread
int n = 0;
ElemType mean[U]; // this thread's part of the mean vector (stored as a normalized mean also during accumulation)
ElemType m2[U]; // likewise for the stddev (accumulated as a running sum of squared deviations)
#pragma unroll
for (int k = 0; k < U; k++)
{
@ -206,12 +213,13 @@ __global__ void kComputeBatchMeanAndInvStdDev(int vectorSize, int batchSize, con
ElemType d = curVal[k] - mean[k];
// REVIEW alexeyk: we enabled fast CUDA math in CNTK so division below will be approximate, is this a problem?
// Using precise math slows down the code by about 40%.
mean[k] += d / n; // mean_n = [mean_{n-1} * (n-1) + curVal] / n = mean_{n-1} * n/n - mean_{n-1} / n + curVal / n
m2[k] += d * (curVal[k] - mean[k]);
}
psrc += vectorSize * BlockDimY;
}
// now reduce minibatch mean/stddev across threads
const int tid = threadIdx.y * BlockDimX + threadIdx.x;
const int laneId = tid & 0x1f;
// First, reduce within warp using shuffle.
@ -258,6 +266,8 @@ __global__ void kComputeBatchMeanAndInvStdDev(int vectorSize, int batchSize, con
}
__syncthreads();
// --- final reduction and update of running mean/stddev
// Accumulate and write final results.
// REVIEW alexeyk: see if atomicAdd can be used instead, do perf comparison.
if (threadIdx.y == 0)
@ -282,7 +292,10 @@ __global__ void kComputeBatchMeanAndInvStdDev(int vectorSize, int batchSize, con
size_t idxDstBase = (blockIdx.x * BlockDimX + threadIdx.x) * U;
// Store mean and running mean.
StoreValues<U>(mean, xMean + idxDstBase);
// at this point, minibatch mean has been saved into xMean[]
// accumulate running mean
if (expAvgFactor == 1) // 100% comes from current minibatch, nothing from history
StoreValues<U>(mean, runMean + idxDstBase);
else
{
@ -293,6 +306,8 @@ __global__ void kComputeBatchMeanAndInvStdDev(int vectorSize, int batchSize, con
run[k] = expAvgFactor * mean[k] + (1.0 - expAvgFactor) * run[k];
StoreValues<U>(run, runMean + idxDstBase);
}
// at this point, runMean[] has been updated
// Store inv std dev and its running version.
#pragma unroll
for (int k = 0; k < U; k++)
@ -300,6 +315,8 @@ __global__ void kComputeBatchMeanAndInvStdDev(int vectorSize, int batchSize, con
m2[k] = Operations::RSqrt(static_cast<ElemType>(m2[k] / batchSize + epsilon));
}
StoreValues<U>(m2, xInvStdDev + idxDstBase);
// at this point, minibatch stddev has been saved into xInvStdDev[]
if (expAvgFactor == 1)
StoreValues<U>(m2, runInvStdDev + idxDstBase);
else
@ -311,6 +328,7 @@ __global__ void kComputeBatchMeanAndInvStdDev(int vectorSize, int batchSize, con
run[k] = expAvgFactor * m2[k] + (1.0 - expAvgFactor) * run[k];
StoreValues<U>(run, runInvStdDev + idxDstBase);
}
// at this point, runInvStdDev[] has been updated
}
}
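// A small host-side sketch of the same online (Welford-style) accumulation that the per-thread loop in
// kComputeBatchMeanAndInvStdDev above performs, written for a plain array. OnlineMeanAndM2 is an
// illustrative name and not part of this file.
static void OnlineMeanAndM2(const float* x, int n, float& mean, float& m2)
{
    mean = 0; m2 = 0;
    for (int i = 0; i < n; i++)
    {
        float d = x[i] - mean;
        mean += d / (i + 1);     // running mean over the first i+1 samples
        m2 += d * (x[i] - mean); // running sum of squared deviations; variance = m2 / n, invStdDev = 1 / sqrt(variance + epsilon)
    }
}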
@ -466,8 +484,13 @@ template <int U>
struct ComputeBatchMeanAndInvStdDev
{
template <typename ElemType>
static void Call(size_t vectorSize, size_t batchSize,
const ElemType* x, // (in) input data
double expAvgFactor,
ElemType* runMean, ElemType* runInvStdDev, // (in/out) running mean/stddev, gets updated with current minibatch
double epsilon,
ElemType* xMean, ElemType* xInvStdDev, // (out) actual interpolated mean/stddev that are used to normalize. Returned since needed in backprop.
cudaStream_t stream)
{
assert((vectorSize % U) == 0);
@ -593,8 +616,11 @@ template <int U>
struct NormalizeBatchTraining
{
template <typename ElemType>
static void Call(size_t vectorSize, size_t spatialSize, size_t batchSize, bool spatial,
const ElemType* x, ElemType* y, // (in, out) data to normalize -> normalized data
const ElemType* bnScale, const ElemType* bnBias, // (in) scale/bias to denormalize with
const ElemType* batchMean, const ElemType* batchInvStdDev, // (in) actual mean/stddev to normalize with
cudaStream_t stream)
{
assert((vectorSize % U) == 0);


@ -3107,6 +3107,7 @@ void GPUMatrix<ElemType>::AveragePoolingBackward(const GPUMatrix<int>& mpRowCol,
Data(), (int)GetNumRows(), grad.Data(), (int)grad.GetNumRows());
}
// returns saveMean/saveInvStdDev which are the actual values used to perform the normalization, except for blendFactor 1, in which case they are unused and untouched
template <class ElemType>
void GPUMatrix<ElemType>::BatchNormalizationForward(const GPUMatrix<ElemType>& scale, const GPUMatrix<ElemType>& bias, double expAvgFactor, double blendFactor,
GPUMatrix<ElemType>& runMean, GPUMatrix<ElemType>& runInvStdDev, GPUMatrix<ElemType>& out, double epsilon,
@ -3122,6 +3123,7 @@ void GPUMatrix<ElemType>::BatchNormalizationForward(const GPUMatrix<ElemType>& s
assert(0 < vectorSize && vectorSize <= std::numeric_limits<int>::max());
assert(0 < batchSize && batchSize <= std::numeric_limits<int>::max());
// --- compute data mean/stddev (into saveMean/saveInvStdDev) and update running mean/stddev
SyncGuard syncGuard;
// If expAvgFactor == 0 && blendFactor == 1 then we don't need to compute current minibatch statistics.
if (expAvgFactor > 0 || blendFactor < 1)
@ -3139,12 +3141,15 @@ void GPUMatrix<ElemType>::BatchNormalizationForward(const GPUMatrix<ElemType>& s
saveMean.Data(), saveInvStdDev.Data(), GetStream());
}
}
// --- apply MAP estimates of mean/stddev (interpolation of data and running mean/stddev) to data
// When:
// blendFactor == 1 - use running mean/var instead of the current minibatch mean/var.
// 0 < blendFactor < 1 - blend running mean/var with mean/var of the current minibatch: saveMean = (1 - blendFactor) * saveMean + blendFactor * runMean
// blendFactor == 0 - use mean/var of the current minibatch. Note: saveMean/saveInvStdDev are NOT updated.
if (blendFactor < 1)
{
// non-zero blendFactor: interpolate minibatch mean/stddev in-place with running mean/stddev
if (blendFactor > 0)
{
// REVIEW alexeyk: can be rolled into NormalizeBatchTraining to save bandwidth.
@ -3154,18 +3159,26 @@ void GPUMatrix<ElemType>::BatchNormalizationForward(const GPUMatrix<ElemType>& s
Scale((ElemType)(1 - blendFactor), saveInvStdDev);
ScaleAndAdd((ElemType)blendFactor, runInvStdDev, saveInvStdDev);
}
// normalize
Call<NormalizeBatchTraining, ElemType>(spatial ? spatialSize : vectorSize, vectorSize, spatialSize, batchSize, spatial,
Data(), out.Data(), // (in, out) data to be normalized -> normalized data
scale.Data(), bias.Data(), // (in) scale/bias to denormalize with
/*(in)*/saveMean.Data(), saveInvStdDev.Data(), // (in) actual mean/stddev to normalize with
GetStream());
}
else // blendFactor == 1: use running mean/stddev only
{
assert(saveMean.IsEmpty() && saveInvStdDev.IsEmpty()); // TODO: We should rather Resize() them in here.
// TODO: require saveMean/saveInvStdDev to be passed in as empty matrices, to clarify/enforce the semantics
Call<NormalizeBatchTraining, ElemType>(spatial ? spatialSize : vectorSize, vectorSize, spatialSize, batchSize, spatial,
Data(), out.Data(),
scale.Data(), bias.Data(),
runMean.Data(), runInvStdDev.Data(), GetStream());
}
}
// saveMean/saveInvStdDev are the interpolated mean/stddev as used in ForwardProp().
// BUGBUG (in call site): For blendFactor=1, they are uninitialized. Caller must pass running mean/stddev instead in that case.
template <class ElemType>
void GPUMatrix<ElemType>::BatchNormalizationBackward(const GPUMatrix<ElemType>& in, GPUMatrix<ElemType>& grad, const GPUMatrix<ElemType>& scale,
const GPUMatrix<ElemType>& saveMean, const GPUMatrix<ElemType>& saveInvStdDev,