Packing matrices in V1 Aggregation

Pack the matrices to be aggregated into a contiguous buffer so that fewer MPI operations are needed.

Usage:
Introduce a configurable threshold size (in KB) and pack only those
gradients whose individual size is below the threshold into the contiguous
buffer. This keeps the extra buffer small and avoids running out of memory.

The default threshold size is 32KB. To change this value, specify
"packThresholdSizeInKB=value" (e.g. 2048 for 2MB) in the train/eval config
file.
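
For orientation, the size test behind this option amounts to the sketch below. The helper names are hypothetical; the default constant, the KB-to-bytes conversion, and the per-gradient comparison mirror the code in the diff further down.

#include <cstddef>

// Default threshold, as defined in Constants.h below.
const size_t DEFAULT_PACK_THRESHOLD_SIZE_IN_KB = 32;

// "packThresholdSizeInKB" from the config is converted to bytes once.
inline size_t PackThresholdInBytes(size_t packThresholdSizeInKB)
{
    return packThresholdSizeInKB * 1024; // e.g. 2048 KB -> 2 MB
}

// A gradient is a packing candidate only if its own payload fits under the
// threshold; larger gradients are still aggregated individually.
template <class ElemType>
bool IsPackingCandidate(size_t numElements, size_t packThresholdSizeInBytes)
{
    return sizeof(ElemType) * numElements <= packThresholdSizeInBytes;
}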

Notes:
Acquire the contiguous buffer only if NCCL is not supported.

Fall back to normal execution if not enough extra contiguous memory is available.

Change allreduce to non-blocking allreduce.

SimpleDistGradAggregator: broadcast the aggregated header to all nodes.

Configurable threshold size for packing gradients into a buffer.

SimpleDistGradAggregator: pack a gradient into the contiguous buffer if its size does not exceed the pack threshold (see the sketch after these notes).

Add all constant definitions to a header file in Common
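
As a rough summary of the notes above, the classification and fallback logic looks like the following self-contained sketch. "Gradient" is a stand-in type and the function names are hypothetical; the real change (in the SimpleDistGradAggregator diff below) operates on Matrix<ElemType> and also skips packing when async aggregation is enabled.

#include <cstddef>
#include <numeric>
#include <vector>

struct Gradient { size_t numElements; }; // stand-in for Matrix<ElemType>

// Split gradients into "packed" (small, copied into one contiguous buffer)
// and "individually aggregated" (large), using the size threshold.
void ClassifyForPacking(const std::vector<Gradient>& gradients,
                        size_t elemSize, size_t packThresholdSizeInBytes,
                        std::vector<size_t>& packedIndex,
                        std::vector<size_t>& aggregateIndex,
                        size_t& packedElements)
{
    packedElements = 0;
    for (size_t i = 0; i < gradients.size(); ++i)
    {
        if (elemSize * gradients[i].numElements <= packThresholdSizeInBytes)
        {
            packedIndex.push_back(i);
            packedElements += gradients[i].numElements;
        }
        else
        {
            aggregateIndex.push_back(i);
        }
    }
}

// Fallback: if the contiguous buffer cannot be allocated (the real code
// allocates it with new (std::nothrow)), aggregate every gradient
// individually, as was done before this change.
void FallBackToUnpacked(size_t gradientCount,
                        std::vector<size_t>& packedIndex,
                        std::vector<size_t>& aggregateIndex)
{
    packedIndex.clear();
    aggregateIndex.resize(gradientCount);
    std::iota(aggregateIndex.begin(), aggregateIndex.end(), 0);
}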
Junjie Qian 2016-12-13 13:58:26 -08:00
Parent 9375555343
Commit a6e13ce969
8 changed files with 204 additions and 72 deletions


@@ -0,0 +1,13 @@
// Constants.h -- the constants used by CNTK
//
#pragma once
#ifndef _CONSTANTS_H_
#define _CONSTANTS_H_
// Constants used in aggregation
const size_t DEFAULT_PACK_THRESHOLD_SIZE_IN_KB = 32;
const size_t DEFAULT_PACK_THRESHOLD_SIZE_IN_BYTES = DEFAULT_PACK_THRESHOLD_SIZE_IN_KB * 1024;
#endif


@@ -6,6 +6,7 @@
#include "Basics.h"
#include "ComputationNode.h"
#include "Constants.h"
#include "Matrix.h"
#include "TensorView.h"
#include <unordered_set>
@@ -1448,7 +1449,8 @@ void AggregateAccumulatorValuesAndUpdateEvaluation(
shared_ptr<ComputationNetwork> net,
set<shared_ptr<ComputationNodeBase>> evalNodesWhichAccumulateResult,
shared_ptr<DistGradHeader> gradHeader,
shared_ptr<MPIWrapper> mpi);
shared_ptr<MPIWrapper> mpi,
size_t packThresholdSizeInBytes = (size_t)DEFAULT_PACK_THRESHOLD_SIZE_IN_BYTES);
// -----------------------------------------------------------------------
// EpochAccumulatorNode calculates mean values of all samples used in forward pass.
@@ -1499,7 +1501,8 @@ protected:
shared_ptr<ComputationNetwork> net,
set<shared_ptr<ComputationNodeBase>> evalNodesWhichAccumulateResult,
shared_ptr<DistGradHeader> gradHeader,
shared_ptr<MPIWrapper> mpi);
shared_ptr<MPIWrapper> mpi,
size_t packThresholdSize);
void Reset();


@@ -93,6 +93,23 @@ void NcclComm::AllReduceImpl(void* buffer, size_t count, DataType dtype)
RuntimeError("NcclComm ncclAllReduce failed: %s", ncclGetErrorString(res));
}
void NcclComm::BroadcastImpl(void* buffer, size_t count, MPI_Datatype dtype, int root)
{
ncclResult_t res;
if (dtype == MPI_CHAR)
{
res = ncclBcast(buffer, count, ncclChar, root, m_ncclComm, m_stream);
}
else
{
RuntimeError("NcclComm Broadcast supports Char type only");
}
if (res != ncclSuccess)
{
RuntimeError("NcclComm ncclBcast failed: %s", ncclGetErrorString(res));
}
}
void NcclComm::Sync()
{
cudaStreamSynchronize(m_stream) || "NcclComm: cudaStreamSynchronize failed";


@@ -23,6 +23,7 @@ class NcclComm
private:
enum class DataType : int {FLOAT, DOUBLE};
void AllReduceImpl(void* buffer, size_t count, DataType dtype);
void BroadcastImpl(void* buffer, size_t count, MPI_Datatype dtype, int root);
cudaStream_t m_stream;
ncclComm_t m_ncclComm;
#endif
@@ -53,6 +54,20 @@ public:
RuntimeError("NcclComm: CNTK was built without NCCL support.");
#endif
}
#pragma warning( push )
#pragma warning ( disable : 4100 ) // Disable warning 4100 in Broadcast function
void Broadcast(void* buffer, size_t count, MPI_Datatype dtype, int root)
{
#ifdef USE_NCCL
BroadcastImpl(buffer, count, dtype, root);
#else
RuntimeError("NcclComm: CNTK was built without NCCL support.");
#endif
}
};
#pragma warning( pop )
}}}
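
For reference, a caller of the new Broadcast wrapper might look like the hypothetical sketch below. Only the Broadcast/Sync members and the char-only restriction come from the diff; the call site, and the assumption that NcclComm lives in the usual Microsoft::MSR::CNTK namespace, are illustrative.

#include <cstddef>
#include <mpi.h>
#include "NcclComm.h" // wrapper extended in this change

// Broadcast a serialized header from the main rank to all ranks over NCCL.
// The wrapper currently accepts MPI_CHAR payloads only and raises a
// RuntimeError for any other datatype, so the header is sent as raw bytes.
void BroadcastHeaderBytes(Microsoft::MSR::CNTK::NcclComm& nccl,
                          void* headerBytes, size_t sizeInBytes, int mainRank)
{
    nccl.Broadcast(headerBytes, sizeInBytes, MPI_CHAR, mainRank);
    nccl.Sync(); // wait for the NCCL stream before reading the buffer
}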


@@ -26,7 +26,8 @@ void AggregateAccumulatorValuesAndUpdateEvaluation(
std::shared_ptr<ComputationNetwork> net,
std::set<std::shared_ptr<ComputationNodeBase>> evalNodesWhichAccumulateResult,
std::shared_ptr<DistGradHeader> gradHeader,
std::shared_ptr<MPIWrapper> mpi)
std::shared_ptr<MPIWrapper> mpi,
size_t packThresholdSizeInBytes)
{
// Accumulator stores mean value and number of samples. Aggregation performs simple summation of values,
// so we transfer sum instead of mean, and calculate mean after aggregation is finished.
@@ -58,7 +59,8 @@ void AggregateAccumulatorValuesAndUpdateEvaluation(
mpi,
false /*useAsyncAggregation*/,
net->GetDeviceId(),
0 /*syncStatsTrace*/);
0 /*syncStatsTrace*/,
packThresholdSizeInBytes);
// Prepare header.
const size_t c_evalNodes = 1;
@@ -127,10 +129,11 @@ void AggregateAccumulatorValuesAndUpdateEpochEvaluation(
std::vector<EpochCriterion>& epochEvalErrors,
const std::vector<ComputationNodeBasePtr>& evaluationNodes,
CriterionAccumulator<ElemType> localEpochEvalErrors,
std::function<bool(ComputationNodeBasePtr)> containsAccumulatedResult)
std::function<bool(ComputationNodeBasePtr)> containsAccumulatedResult,
size_t packThresholdSizeInBytes = DEFAULT_PACK_THRESHOLD_SIZE_IN_BYTES)
{
// Each node contains accumulated values for part of the data set, we have to aggregate accumulated values.
AggregateAccumulatorValuesAndUpdateEvaluation<ElemType>(net, evalNodesWhichAccumulateResult, gradHeader, mpi);
AggregateAccumulatorValuesAndUpdateEvaluation<ElemType>(net, evalNodesWhichAccumulateResult, gradHeader, mpi, packThresholdSizeInBytes);
// After values of accumulators have been aggregated accross nodes, we have to update evaluation results for
// evaluation nodes that accumulate results.


@@ -1511,7 +1511,7 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
// and recalculate evaluation errors based on accumulators.
AggregateAccumulatorValuesAndUpdateEpochEvaluation<ElemType>(
net, evaluationNodesWhichAccumulateResult, m_gradHeader, m_mpi, epochEvalErrors, evaluationNodes,
localEpochEvalErrors, ContainsAccumulatedResult);
localEpochEvalErrors, ContainsAccumulatedResult, m_packThresholdSizeInBytes);
}
return totalEpochSamples;
@@ -2111,7 +2111,7 @@ void SGD<ElemType>::InitDistGradAgg(int numEvalNodes, int numGradientBits, int d
if (Globals::UseV2Aggregator()) // Currently used to check V2 against baselines.
m_distGradAgg = std::make_shared<V2SimpleDistGradAggregator<ElemType>>(m_mpi, m_bufferedAsyncGradientAggregation, deviceId, m_syncStatsTrace, ::CNTK::MPICommunicator());
else
m_distGradAgg = std::make_shared<SimpleDistGradAggregator<ElemType>>(m_mpi, m_bufferedAsyncGradientAggregation, deviceId, m_syncStatsTrace);
m_distGradAgg = std::make_shared<SimpleDistGradAggregator<ElemType>>(m_mpi, m_bufferedAsyncGradientAggregation, deviceId, m_syncStatsTrace, m_packThresholdSizeInBytes);
}
m_gradHeader.reset(DistGradHeader::Create(numEvalNodes), [](DistGradHeader* ptr) { DistGradHeader::Destroy(ptr); });
@@ -2701,6 +2701,8 @@ SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType)
m_maxSamplesInRAM = configSGD(L"maxSamplesInRAM", (size_t) SIZE_MAX);
m_numSubminiBatches = configSGD(L"numSubminibatches", (size_t) 1);
m_packThresholdSizeInBytes = configSGD(L"packThresholdSizeInKB", DEFAULT_PACK_THRESHOLD_SIZE_IN_KB) * 1024;
if (configAALR.Exists(L"numMiniBatch4LRSearch"))
{
LOGPRINTF(stderr, "WARNING: 'numMiniBatch4LRSearch' is deprecated, please remove it and use 'numSamples4Search' instead.\n");


@@ -200,6 +200,9 @@ protected:
intargvector m_numSamples4Search;
size_t m_numBestSearchEpoch;
// Threshold size in bytes for a single gradient to be packed
size_t m_packThresholdSizeInBytes;
LearningRateSearchAlgorithm m_autoLearnRateSearchType;
AdaptationRegType m_adaptationRegType;


@@ -6,6 +6,7 @@
#pragma once
#include "Constants.h"
#include "IDistGradAggregator.h"
#include "CUDAPageLockedMemAllocator.h"
#include "NcclComm.h"
@@ -22,8 +23,9 @@ class SimpleDistGradAggregator : public IDistGradAggregator<ElemType>
UsingIDistGradAggregatorMembers;
public:
SimpleDistGradAggregator(const MPIWrapperPtr& mpi, bool useAsyncAggregation, int deviceId, int syncStatsTrace)
: IDistGradAggregator<ElemType>(mpi), m_useAsyncAggregation(useAsyncAggregation), m_initialized(false), m_bufferedGradHeader(nullptr), m_syncStatsTrace(syncStatsTrace), m_iterationCount(0), m_nccl(deviceId, mpi)
SimpleDistGradAggregator(const MPIWrapperPtr& mpi, bool useAsyncAggregation, int deviceId, int syncStatsTrace, size_t packThresholdSizeInBytes = DEFAULT_PACK_THRESHOLD_SIZE_IN_BYTES)
: IDistGradAggregator<ElemType>(mpi), m_useAsyncAggregation(useAsyncAggregation), m_initialized(false), m_bufferedGradHeader(nullptr), m_syncStatsTrace(syncStatsTrace),
m_iterationCount(0), m_nccl(deviceId, mpi), m_packThresholdSizeInBytes(packThresholdSizeInBytes)
{}
~SimpleDistGradAggregator()
@@ -144,25 +146,65 @@ private:
m_initialized = true;
int deviceId = gradients[0]->GetDeviceId();
if (!m_nccl.IsSupported() && deviceId != CPUDEVICE)
if (!m_nccl.IsSupported() && (deviceId != CPUDEVICE))
m_allocator.reset(new CUDAPageLockedMemAllocator(deviceId));
size_t packedGradientsSizeInElements = 0;
for (size_t i = 0; i < gradients.size(); i++)
{
if (!m_useAsyncAggregation && sizeof(ElemType) * gradients[i]->GetNumElements() <= m_packThresholdSizeInBytes)
{
packedGradientsSizeInElements += gradients[i]->GetNumElements();
m_packedGradientsIndex.push_back(i);
}
else
{
m_gradientIndexToAggregate.push_back(i);
}
// Make sure none of the gradient matrices are sparse - we currently do not support aggregation of sparse gradient matrices
if (gradients[i]->GetMatrixType() != DENSE)
RuntimeError("Gradient aggregation for sparse gradient matrices is currently unsupported!");
if (!m_nccl.IsSupported() && deviceId != CPUDEVICE)
{
m_gpuDataTransferers.push_back(std::make_unique<GPUDataTransferer>(deviceId, m_useAsyncAggregation));
m_intermediateCPUBuffers.push_back(AllocateIntermediateBuffer(deviceId, gradients[i]->GetNumElements()));
}
if (m_useAsyncAggregation)
m_bufferedGradients[gradients[i]].reset(new Matrix<ElemType>(gradients[i]->GetNumRows(), gradients[i]->GetNumCols(), deviceId));
}
// Packing matrices into a contiguous buffer if not doing async aggregation
m_aggregationBuffer.reset();
if (packedGradientsSizeInElements > 0)
{
m_aggregationBuffer.reset(new (std::nothrow) Matrix<ElemType>(1, packedGradientsSizeInElements, deviceId));
}
// If no extra contiguous buffer was allocated, or async aggregation is used
if (m_aggregationBuffer == nullptr)
{
m_gradientIndexToAggregate.clear();
m_packedGradientsIndex.clear();
packedGradientsSizeInElements = 0;
// Reuse m_gradientIndexToAggregate for the following code if no contiguous buffer was allocated
for (size_t i = 0; i < gradients.size(); i++)
{
m_gradientIndexToAggregate.push_back(i);
}
}
else
{
// First element is reserved for the contiguous buffer
m_gradientIndexToAggregate.insert(m_gradientIndexToAggregate.begin(), 1, (size_t)-1);
}
// If running on GPU and NCCL not supported, initialize GPU and CPU data transfer
if (!m_nccl.IsSupported() && (deviceId != CPUDEVICE))
{
for (size_t i : m_gradientIndexToAggregate)
{
m_gpuDataTransferers.push_back(std::make_unique<GPUDataTransferer>(deviceId, m_useAsyncAggregation));
m_intermediateCPUBuffers.push_back(AllocateIntermediateBuffer(deviceId,
(i == -1) ? packedGradientsSizeInElements : gradients[i]->GetNumElements()));
}
}
if (m_useAsyncAggregation)
{
m_bufferedGradHeader = DistGradHeader::Create(numEvalNodes);
@@ -223,11 +265,33 @@ private:
}
}
// Initiate transfer of the gradient matrices to the CPU if needed
if (!m_nccl.IsSupported() && deviceId >= 0)
// Copy all packed gradient data into a single contiguous buffer, if the additional contiguous buffer was allocated
size_t offset = 0;
for (size_t i : m_packedGradientsIndex)
{
for (size_t i = 0; i < numGradMatrices; ++i)
m_gpuDataTransferers[i]->CopyGPUToCPUAsync(gradients[i]->Data(), gradients[i]->GetNumElements(), m_intermediateCPUBuffers[i].get());
m_aggregationBuffer->ColumnSlice(offset, gradients[i]->GetNumElements()).AssignValuesOf(gradients[i]->Reshaped(1, gradients[i]->GetNumElements()));
offset += gradients[i]->GetNumElements();
}
// Initiate transfer of the buffered data to the CPU if needed
if (!m_nccl.IsSupported() && deviceId != CPUDEVICE)
{
size_t gpuDataTransfersIdx = 0;
Matrix<ElemType>* gpuCopyBuffer = m_aggregationBuffer.get();
for (size_t i : m_gradientIndexToAggregate)
{
if (i != -1)
{
gpuCopyBuffer = gradients[i];
}
else
{
// i == -1: the first element is the packed-gradient buffer, which is not used with async aggregation
assert(m_useAsyncAggregation == false);
}
m_gpuDataTransferers[gpuDataTransfersIdx]->CopyGPUToCPUAsync(gpuCopyBuffer->Data(), gpuCopyBuffer->GetNumElements(), m_intermediateCPUBuffers[gpuDataTransfersIdx].get());
gpuDataTransfersIdx++;
}
}
// Initiate receive of the header on the main node
@@ -248,26 +312,35 @@ private:
m_mpi->Isend(headerCPU, headerCPU->Size(), MPI_CHAR, m_mpi->MainNodeRank(), numGradMatrices, &sendHeaderRequest) || MpiFail("MPI_Isend");
// Perform async allreduce on the gradient data
std::vector<MPI_Request> allReduceRequests(numGradMatrices);
std::vector<MPI_Request> allReduceRequests;
if (!m_nccl.IsSupported())
{
for (size_t i = 0; i < numGradMatrices; ++i)
size_t allReduceIndex = 0;
ElemType* reductionBuffer;
for (size_t i : m_gradientIndexToAggregate)
{
ElemType* reductionBuffer = gradients[i]->Data();
if (deviceId >= 0)
allReduceRequests.push_back(MPI_Request());
reductionBuffer = (i == -1)? m_aggregationBuffer->Data() : gradients[i]->Data();
if (deviceId != CPUDEVICE)
{
m_gpuDataTransferers[i]->WaitForCopyGPUToCPUAsync();
reductionBuffer = m_intermediateCPUBuffers[i].get();
m_gpuDataTransferers[allReduceIndex]->WaitForCopyGPUToCPUAsync();
reductionBuffer = m_intermediateCPUBuffers[allReduceIndex].get();
}
// On Windows this async MPI_Iallreduce call requires MS MPI v7 or higher to be installed
m_mpi->Iallreduce(MPI_IN_PLACE, reductionBuffer, gradients[i]->GetNumElements(),
MPIWrapper::GetDataType(reductionBuffer), MPI_SUM,
&allReduceRequests[i]) || MpiFail("MPI_Iallreduce");
m_mpi->Iallreduce(MPI_IN_PLACE, reductionBuffer, (i == -1) ? m_aggregationBuffer->GetNumElements() : gradients[i]->GetNumElements(),
MPIWrapper::GetDataType(reductionBuffer), MPI_SUM, &allReduceRequests.back()) || MpiFail("MPI_Iallreduce");
allReduceIndex++;
}
}
}
else
m_nccl.AllReduce(gradients);
{
std::vector<Matrix<ElemType>*> ncclReduceGradients;
for (size_t i : m_gradientIndexToAggregate)
{
ncclReduceGradients.push_back((i == -1) ? m_aggregationBuffer.get() : gradients[i]);
}
m_nccl.AllReduce(ncclReduceGradients);
}
// On the main node wait for the headers to arrive and aggregate
if (m_mpi->IsMainNode())
@@ -290,52 +363,48 @@ private:
assert(numNodesHeadersReceivedFrom == (NumProc() - 1));
}
// Initiate receive of the aggregate header
MPI_Request recvAggHeaderRequest;
if (!m_mpi->IsMainNode())
m_mpi->Irecv(headerCPU, headerCPU->Size(), MPI_CHAR, m_mpi->MainNodeRank(), numGradMatrices + 1 + numGradMatrices, &recvAggHeaderRequest) || MpiFail("MPI_Irecv");
// Broadcast the aggregated header to all nodes
m_mpi->Bcast(headerCPU, headerCPU->Size(), MPI_CHAR, m_mpi->MainNodeRank());
// Initiate send of the aggregate header from main node
std::vector<MPI_Request> sendAggHeaderRequests(NumProc() - 1);
if (m_mpi->IsMainNode())
{
for (size_t j = 0; j < NumProc() - 1; ++j)
{
int dest = (j >= MyRank()) ? (j + 1) : j;
// TODO: Should we use MPI_Bcast instead for better performance
m_mpi->Isend(headerCPU, headerCPU->Size(), MPI_CHAR, dest, numGradMatrices + 1 + numGradMatrices, &(sendAggHeaderRequests[j])) || MpiFail("MPI_Isend");
}
}
// Wait for the allreduce operations to finish and initiate transfer back to the GPU if needed
if (!m_nccl.IsSupported())
{
for (size_t i = 0; i < numGradMatrices; ++i)
{
m_mpi->Wait(&allReduceRequests[i], MPI_STATUSES_IGNORE) || MpiFail("MPI_Wait");
if (deviceId >= 0)
m_gpuDataTransferers[i]->CopyCPUToGPUAsync(m_intermediateCPUBuffers[i].get(), gradients[i]->GetNumElements(), gradients[i]->Data());
}
}
// Wait to receive aggregate header
if (!m_mpi->IsMainNode())
m_mpi->Wait(&recvAggHeaderRequest, MPI_STATUSES_IGNORE) || MpiFail("MPI_Wait");
// Wait for all the transfers to finish
if (m_nccl.IsSupported())
m_nccl.Sync();
else if (deviceId >= 0)
{
for (size_t i = 0; i < numGradMatrices; ++i)
m_gpuDataTransferers[i]->WaitForCopyCPUToGPUAsync();
m_nccl.Sync();
}
else
{
// Wait for the allreduce operations to finish and initiate transfer back to the GPU if needed
size_t gpuDataTransfersIdx = 0; // Index of the allReduceRequest for each unpacked gradient
for (size_t i : m_gradientIndexToAggregate)
{
m_mpi->Wait(&allReduceRequests[gpuDataTransfersIdx], MPI_STATUSES_IGNORE) || MpiFail("MPI_Wait");
if (deviceId != CPUDEVICE)
{
m_gpuDataTransferers[gpuDataTransfersIdx]->CopyCPUToGPUAsync(m_intermediateCPUBuffers[gpuDataTransfersIdx].get(),
(i == -1) ? m_aggregationBuffer->GetNumElements() : gradients[i]->GetNumElements(),
(i == -1) ? m_aggregationBuffer->Data() : gradients[i]->Data());
}
gpuDataTransfersIdx++;
}
// Wait for the data copies from CPU to GPU, if not running on the CPU and NCCL is not enabled
if (deviceId != CPUDEVICE)
{
for (size_t i = 0; i < m_gradientIndexToAggregate.size(); i++)
m_gpuDataTransferers[i]->WaitForCopyCPUToGPUAsync();
}
}
// Copy data back to the packed gradients from the contiguous buffer
offset = 0;
for (size_t i : m_packedGradientsIndex)
{
gradients[i]->AssignValuesOf(m_aggregationBuffer->ColumnSlice(offset, gradients[i]->GetNumElements()).Reshaped(gradients[i]->GetNumRows(), gradients[i]->GetNumCols()));
offset += gradients[i]->GetNumElements();
}
// Wait for completion of the async send requests
if (!m_mpi->IsMainNode())
m_mpi->Wait(&sendHeaderRequest, MPI_STATUSES_IGNORE) || MpiFail("MPI_Wait");
else
m_mpi->Waitall(sendAggHeaderRequests.size(), sendAggHeaderRequests.data(), MPI_STATUSES_IGNORE) || MpiFail("MPI_Waitall");
if (showSyncPerfStats)
{
@@ -347,8 +416,8 @@ private:
private:
std::unique_ptr<CUDAPageLockedMemAllocator> m_allocator;
std::vector<std::shared_ptr<ElemType>> m_intermediateCPUBuffers;
std::vector<std::unique_ptr<GPUDataTransferer>> m_gpuDataTransferers;
std::vector<DistGradHeader*> m_recvHeaders;
@@ -363,6 +432,13 @@ private:
std::unordered_map<Matrix<ElemType>*, std::unique_ptr<Matrix<ElemType>>> m_bufferedGradients;
DistGradHeader* m_bufferedGradHeader;
// Packing small gradients (whose size does not exceed the threshold) into a contiguous buffer reduces the number of MPI calls.
// Threshold size for packing a gradient into the contiguous buffer, default 32KB (tunable via "packThresholdSizeInKB=[value]" in the config)
const size_t m_packThresholdSizeInBytes;
std::unique_ptr<Matrix<ElemType>> m_aggregationBuffer;
std::vector<size_t> m_packedGradientsIndex;
std::vector<size_t> m_gradientIndexToAggregate;
int m_syncStatsTrace;
// Only used for controlling frequency of measuring/showing gradient aggregation perf stats