Mirror of https://github.com/microsoft/msccl.git
Merge pull request #5 from parasailteam/channel_management
Channel management
This commit is contained in:
Commit
e785648283
@@ -18,26 +18,27 @@ template<int ALGO, int PROTO, class FUNC, typename T, int UNROLL>
class ncclFunction<ncclFuncAllToAll, ALGO, PROTO, FUNC, T, UNROLL> {
public:
__device__ void run(struct ncclWorkElem* args) {
struct ncclDevComm* comm = args->comm;
struct scklAlgorithm* scklAlgo = &comm->scklAlgo;
const int tid = threadIdx.x;
const int nthreads = args->nThreads-WARP_SIZE;
const int bid = blockIdx.x;
const int nChannels = args->coll.nChannels;
struct ncclDevComm* comm = args->comm;
const int scklNumBlocksPerChannel = args->scklNumBlocksPerChannel;
const int channelId = bid/scklNumBlocksPerChannel;
const int scklNBlocks = scklAlgo->nBlocks;
const int rscklbid = bid % scklNBlocks; // bid within a sckl algo
const int scklIndex = bid / scklNBlocks; // which instance of sckl algo
const int nScklInstnaces = gridDim.x / scklAlgo->nBlocks; // number of sckl algo instances
struct scklThreadBlock* scklTB = &scklAlgo->scklTB[rscklbid];
const int channelId = scklIndex * scklAlgo->nChannels + scklTB->channelId;
struct ncclChannel* channel = comm->channels+channelId;
// relative bid to a channel
int rbid = bid % scklNumBlocksPerChannel;
struct scklAlgorithm* scklAlgo = &comm->scklAlgo;
struct scklThreadBlock* sckltb = &scklAlgo->scklTB[rbid];
const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
const int chunkSize = stepSize * ALLTOALL_CHUNKSTEPS;
const int nranks = comm->nRanks;
const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
const int nchunksPerLoopPerRank = scklAlgo->nchunksPerLoop/nranks;
const int totalNChunksPerLoopPerRank = nScklInstnaces*nchunksPerLoopPerRank;
const ssize_t loopSize = (ssize_t)chunkSize*nScklInstnaces;
const ssize_t size = args->coll.count;
const int nChunks = scklAlgo->nChunks;
// assume that size is divisible by nchunks
const ssize_t sizePerChunk = size/nChunks;
const ssize_t sizePerScklChunk = size/nchunksPerLoopPerRank;
// sckl flags all start out with 0. this is used as a part of the flag to make sure different work items deal with different synchronization flags
// this still needs more work. when we make a way around the queue, the flag might have been set to undesired values. will be fixed in subsequent versions.
const int workIndex = args->index+1;
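For reference, the new indexing splits a flat blockIdx.x into an algorithm instance, a thread block within that instance, and a derived channel. The following is only a minimal host-side sketch of that arithmetic with hypothetical values for nBlocks, nChannels, and the per-block channelId (not the actual kernel code):

// Sketch of the bid decomposition used above (hypothetical algorithm shape).
#include <cstdio>
int main() {
  const int nBlocks = 4;                 // scklAlgo->nBlocks
  const int nChannels = 2;               // scklAlgo->nChannels
  const int tbChannel[4] = {0, 0, 1, 1}; // scklTB[i].channelId for each thread block
  const int gridDimX = 8;                // two instances of the algorithm
  for (int bid = 0; bid < gridDimX; bid++) {
    int rscklbid = bid % nBlocks;   // thread block within one instance
    int scklIndex = bid / nBlocks;  // which instance of the algorithm
    int channelId = scklIndex * nChannels + tbChannel[rscklbid];
    printf("bid %d -> instance %d, block %d, channel %d\n", bid, scklIndex, rscklbid, channelId);
  }
  return 0;
}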
@@ -47,26 +48,26 @@ class ncclFunction<ncclFuncAllToAll, ALGO, PROTO, FUNC, T, UNROLL> {
T * thisOutput = (T*)args->recvbuff;
int myRank = channel->ring.devUserRanks[0];
int m1 = -1;
int recvPeer = (sckltb->type == SCKL_RECV) ? sckltb->peer : m1;
int sendPeer = (sckltb->type == SCKL_SEND) ? sckltb->peer : m1;
int recvPeer = (scklTB->type == SCKL_RECV) ? scklTB->peer : m1;
int sendPeer = (scklTB->type == SCKL_SEND) ? scklTB->peer : m1;

ncclPrimitives<UNROLL, ALLTOALL_CHUNKSTEPS/ALLTOALL_SLICESTEPS, ALLTOALL_SLICESTEPS, T, 1, 1, 1, FUNC>
prims(tid, nthreads, &recvPeer, &sendPeer, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 0);
for (ssize_t gridOffset = 0, iter = 0; gridOffset < sizePerChunk; gridOffset += loopSize, iter++) {
int realChunkSize = min(chunkSize, DIVUP(sizePerChunk-gridOffset,nChannels));
for (ssize_t gridOffset = 0, iter = 0; gridOffset < sizePerScklChunk; gridOffset += loopSize, iter++) {
int realChunkSize = min(chunkSize, DIVUP(sizePerScklChunk-gridOffset,nScklInstnaces));
ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
ssize_t chunkOffset = gridOffset + channelId*realChunkSize;
ssize_t chunkOffset = gridOffset + scklIndex*realChunkSize;
ssize_t offset;
int nelem = min(realChunkSize, sizePerChunk-chunkOffset);
for (int i = 0; i < sckltb->nsteps; i++){
struct scklTransfer* sckltran = &sckltb->transfers[i];
int nelem = min(realChunkSize, sizePerScklChunk-chunkOffset);
for (int i = 0; i < scklTB->nsteps; i++){
struct scklTransfer* sckltran = &scklTB->transfers[i];
if (sckltran->offset == -1) continue;
offset = chunkOffset + sckltran->offset * sizePerChunk;
offset = chunkOffset + sckltran->offset * sizePerScklChunk;
T* thisbuffer = (sckltran->buffer == SCKL_INPUT_BUFFER) ? thisInput : thisOutput;
if (sckltb->type == SCKL_SEND){
int8_t dependentBid = sckltran->dependentRbid + scklNumBlocksPerChannel * channelId;
if (scklTB->type == SCKL_SEND){
int8_t dependentBid = sckltran->dependentBid + scklIndex * scklNBlocks;
int8_t dependentStep = sckltran->dependentStep;
if (sckltran->dependentRbid >= 0){
if (sckltran->dependentBid >= 0){
if (tid == 0){
uint64_t goalFlag = COMPUTE_FLAG(workIndex, iter, dependentStep);
while ((scklFlags + dependentBid)->flag < goalFlag){};
@@ -74,7 +75,7 @@ class ncclFunction<ncclFuncAllToAll, ALGO, PROTO, FUNC, T, UNROLL> {
__syncthreads();
}
prims.directSend(thisbuffer + offset, offset, nelem);
} else if (sckltb->type == SCKL_RECV) {
} else if (scklTB->type == SCKL_RECV) {
prims.directRecv(thisbuffer + offset, offset, nelem);
if (tid == 0){
__threadfence();
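The send path above spins on a per-thread-block flag array (scklFlags) that the receiving block advances; COMPUTE_FLAG folds the work index, loop iteration, and step into a monotonically increasing value. The flag update itself is cut off in this hunk and the real encoding is not shown, so the following is only a host-side analogue of the producer/consumer handshake, with a hypothetical encoding standing in for COMPUTE_FLAG:

// Host-side simulation of the flag handshake (hypothetical encoding; the kernel
// uses COMPUTE_FLAG, scklFlags, and __threadfence instead of std::atomic).
#include <atomic>
#include <cstdint>
#include <cstdio>
#include <thread>

static std::atomic<uint64_t> flags[2]; // one slot per thread block, starts at 0

uint64_t goalOf(int workIndex, int iter, int step, int maxSteps, int maxIters) {
  // Later (workIndex, iter, step) must map to a strictly larger value.
  return ((uint64_t)workIndex * maxIters + iter) * (uint64_t)maxSteps + step;
}

int main() {
  const int maxSteps = 8, maxIters = 16;
  // "Receiver" block 0 publishes its progress step by step.
  std::thread recv([&] {
    for (int step = 0; step < 3; step++)
      flags[0].store(goalOf(1, 0, step, maxSteps, maxIters));
  });
  // "Sender" block 1 waits until block 0 has reached step 2 before it may send.
  uint64_t goal = goalOf(1, 0, 2, maxSteps, maxIters);
  while (flags[0].load() < goal) { /* spin, as the kernel does on scklFlags */ }
  printf("dependency satisfied, send may proceed\n");
  recv.join();
  return 0;
}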
@@ -9,6 +9,7 @@

#include "collectives.h"
#include "devcomm.h"
#include <stdio.h>

#if __CUDA_ARCH__ >= 800
#define COLL_UNROLL 8
@@ -40,13 +41,13 @@ static __device__ void load_parallel(void* dst, void* src, size_t size, int tid)
int* s = (int*)src;
for (int o = tid; o < (size/sizeof(int)); o += blockDim.x) d[o] = s[o];
}
static __device__ void load_coll(struct ncclWork* localWork, struct ncclWork* hostWork, int tid, struct ncclDevComm* comm, int rbid) {
static __device__ void load_coll(struct ncclWork* localWork, struct ncclWork* hostWork, int tid, struct ncclDevComm* comm, int activeId) {
__syncthreads();
load_parallel(localWork, hostWork, sizeof(struct ncclWork), tid);
// Check whether the last operation was aborted and make sure all threads exit
int abort = tid == 0 ? *(comm->abortFlag) : 0;
exitIfAbortBarrier(abort);
if (tid == 0) hostWork->elems[0].active[rbid] = 0;
if (tid == 0) hostWork->elems[0].active[activeId] = 0;
}

template <ncclFunc_t FUNCTION, int ALGO, int PROTO, class REDOP, typename T, int UNROLL>
@@ -79,11 +80,13 @@ __device__ void ncclKernel(struct ncclWorkElem first) {
auto f = ncclFunction<FUNCTION, ALGO, PROTO, REDOP, T, UNROLL>();

struct ncclDevComm* comm = first.comm;
// SCKL: this needs to be changed such that a mixture of SCKL and NCCL can be handled
const int scklNumBlocksPerChannel = first.scklNumBlocksPerChannel;

int channelId = bid / scklNumBlocksPerChannel;
int rbid = bid % scklNumBlocksPerChannel;
int channelId = bid;
int activeId = 0;
if (ALGO == NCCL_ALGO_SCKL){
int rbid = bid % comm->scklAlgo.nBlocks;
channelId = (bid / comm->scklAlgo.nBlocks) * comm->scklAlgo.nChannels + comm->scklAlgo.scklTB[rbid].channelId;
activeId = comm->scklAlgo.scklTB[rbid].rid;
}
struct ncclChannel* channel = comm->channels+channelId;
struct ncclWorkElem* w = NULL;
uint16_t index = first.index;
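With this change a work element carries one active slot per SCKL thread block that shares a channel, and each block clears only its own slot, indexed by its relative id (rid) on that channel. A small host-side sketch of that bookkeeping with a hypothetical block count:

// Sketch: per-channel active slots, indexed by a block's relative id (rid).
#include <cstdio>
#include <cstdint>
#include <cstring>
int main() {
  const int nActives = 3;   // thread blocks sharing this channel (hypothetical)
  uint16_t active[8];
  memset(active, 0, sizeof(active));
  // Host side (getNextOp/setupLaunch): arm one slot per participating block.
  for (int i = 0; i < nActives; i++) active[i] = 1;
  // Device side (load_coll): each block clears only its own slot once the work is copied.
  for (int rid = 0; rid < nActives; rid++) {
    active[rid] = 0;
    printf("block rid=%d cleared its active slot\n", rid);
  }
  return 0;
}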
@@ -94,7 +97,7 @@ __device__ void ncclKernel(struct ncclWorkElem first) {
while (1) {
if (w == NULL) {
w = shmem.localWork.elems;
load_coll(&shmem.localWork, channel->workFifo+index, tid, comm, rbid);
load_coll(&shmem.localWork, channel->workFifo+index, tid, comm, activeId);
}
if (tid < w->nThreads) {
// SCKL uses w->index as an indicator for the progress this threadblock has made. in case index wraps around due to overflow, w->index is incremented so that the progress invariant is still true
@@ -112,7 +115,7 @@ __device__ void ncclKernel(struct ncclWorkElem first) {
}
if (index == NCCL_MAX_OPS-1) wrappedAround = 1;
index = (index+1) % NCCL_MAX_OPS;
if (w->active[rbid] == 2) {
if (w->active[activeId] == 2) {
return;
}
w = NULL;
@@ -18,8 +18,8 @@ __device__ struct ncclShmemData* ncclShmem;
#define NCCL_FUNC4(func, redop, type) \
NCCL_FUNC5(func, TREE, redop, type), \
NCCL_FUNC5(func, RING, redop, type), \
NCCL_FUNC5(func, COLLNET, redop, type), \
NCCL_FUNC5(func, SCKL, redop, type)
NCCL_FUNC5(func, SCKL, redop, type), \
NCCL_FUNC5(func, COLLNET, redop, type)

// Must be consistent with ncclDataType_t
#define NCCL_FUNCS3A(func, redop) \
@@ -17,8 +17,8 @@
#define NCCL_FUNC4(func, redop, type) \
(void*)NCCL_FUNC5(func, TREE, redop, type), \
(void*)NCCL_FUNC5(func, RING, redop, type), \
(void*)NCCL_FUNC5(func, COLLNET, redop, type), \
(void*)NCCL_FUNC5(func, SCKL, redop, type)
(void*)NCCL_FUNC5(func, SCKL, redop, type), \
(void*)NCCL_FUNC5(func, COLLNET, redop, type)

// Must be consistent with ncclDataType_t
#define NCCL_FUNCS3A(func, redop) \
@@ -97,8 +97,7 @@ static ncclResult_t getNextOp(struct ncclChannel* channel, struct ncclWork** wor
int opIndex = channel->workFifoTail%NCCL_MAX_OPS;
struct ncclWork* w = channel->workFifo+opIndex;
struct ncclWorkElem* e = w->elems;
// SCKL replicates active, make sure all of them are 0
for (int i=0; i<e->scklNumBlocksPerChannel; i++){
for (int i=0; i<e->nActives; i++){
volatile uint16_t* activePtr = (volatile uint16_t*)&e->active[i];
while (activePtr[0] != 0) sched_yield();
}
@@ -106,7 +105,7 @@ static ncclResult_t getNextOp(struct ncclChannel* channel, struct ncclWork** wor
// Initialize with work elem if provided
if (base) memcpy(e, base, sizeof(struct ncclWorkElem));

for (int i=0; i<base->scklNumBlocksPerChannel; i++){
for (int i=0; i<base->nActives; i++){
e->active[i] = 1;
}
e->index = opIndex;
@@ -123,8 +122,8 @@ static ncclResult_t setupLaunch(struct ncclComm* comm, struct cudaLaunchParams*
}

// Set active = 2 for the last operation and add a no-op on empty channels (p2p case).
// SCKL: this needs to be set properly as different work items can have different scklNumBlocksPerChannel
int scklNumBlocksPerChannel = 1;
// SCKL: this loop sets the number of active elements according to the SCKL algo. The total number of thread blocks is also calculated here.
int totalNBlocks = 0;
for (int c=0; c<params->gridDim.x; c++) {
struct ncclChannel* channel = comm->channels+c;
if (channel->workCount == 0) {
@@ -136,16 +135,17 @@ static ncclResult_t setupLaunch(struct ncclComm* comm, struct cudaLaunchParams*
e->p2p.nThreads = 0;
}
int channelTailIndex = ((channel->workFifoTail-1)%NCCL_MAX_OPS);
scklNumBlocksPerChannel = channel->workFifo[channelTailIndex].elems[0].scklNumBlocksPerChannel;
for (int i=0; i<channel->workFifo[channelTailIndex].elems[0].scklNumBlocksPerChannel; i++) {
int nActives = channel->workFifo[channelTailIndex].elems[0].nActives;
for (int i=0; i<nActives; i++) {
channel->workFifo[channelTailIndex].elems[0].active[i] = 2;
}
totalNBlocks += nActives;
}

// This is the first time SCKL disassociates bids and channels
// SCKL: for now we are assuming scklNumBlocksPerChannel is the same for each channel.
// multiply the number of threadblocks by scklNumBlocksPerChannel
params->gridDim.x *= scklNumBlocksPerChannel;
// set the gridDim according to the number of active elements. if the algorithm is non-SCKL,
// it remains the same. otherwise, it is set to the total number of thread blocks needed by the SCKL algorithm.
// this is the first time SCKL disassociates thread blocks and channels
params->gridDim.x = totalNBlocks;

// Find the first operation, choose the kernel accordingly and pass it
// as the first argument.
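In other words, the launch no longer scales the channel count by a fixed per-channel block count; it sums the per-channel nActives values. A minimal sketch with hypothetical per-channel counts:

// Sketch: gridDim.x becomes the sum of nActives over the channels in use.
#include <cstdio>
int main() {
  // Hypothetical: 3 channels whose tail work elements report these nActives values.
  const int nActivesPerChannel[3] = {2, 2, 1};
  int totalNBlocks = 0;
  for (int c = 0; c < 3; c++) totalNBlocks += nActivesPerChannel[c];
  printf("gridDim.x = %d\n", totalNBlocks);  // 5 thread blocks launched
  return 0;
}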
@@ -155,7 +155,7 @@ static ncclResult_t setupLaunch(struct ncclComm* comm, struct cudaLaunchParams*
memcpy(&comm->args, elem, sizeof(struct ncclWorkElem));
// As we inline the first coll directly, we can free it immediately.
if (elem->funcIndex != FUNC_INDEX_P2P){
for (int i=0; i<elem->scklNumBlocksPerChannel; i++)
for (int i=0; i<elem->nActives; i++)
elem->active[i] = 0;
}
@@ -261,16 +261,13 @@ ncclResult_t ncclBarrierEnqueueWait(ncclComm_t comm) {
// Also, starting the proxies after the CUDA launch seems to be better for
// performance (latency).

// try to find how many extra threadblocks were allocated for SCKL and adjust gridDim.x
int channelTailIndex = ((comm->channels[0].workFifoTail-1)%NCCL_MAX_OPS);
int scklNumBlocksPerChannel = comm->channels[0].workFifo[channelTailIndex].elems[0].scklNumBlocksPerChannel;
params->gridDim.x /= scklNumBlocksPerChannel;

uint64_t max = 0ULL;
for (int r=0; r<params->gridDim.x; r++) {
for (int r=0; r<comm->p2pnChannels; r++) {
struct ncclChannel* channel = comm->channels+r;
max = std::max(max, channel->workFifoTail);
channel->workCount = 0;
if (channel->workCount) {
max = std::max(max, channel->workFifoTail);
channel->workCount = 0;
}
}
for (int r=0; r<comm->p2pnChannels; r++) {
struct ncclChannel* channel = comm->channels+r;
@@ -341,6 +338,16 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info) {
if (info->protocol == NCCL_PROTO_SIMPLE) nt += WARP_SIZE; // Extra warp for sync
if (info->protocol == NCCL_PROTO_SIMPLE && info->algorithm == NCCL_ALGO_TREE) nt += WARP_SIZE;
info->nChannels = nc;
// SCKL needs comm->scklAlgo.nChannels. if there are more channels, extra ones replicate the SCKL algorithm
if (info->algorithm == NCCL_ALGO_SCKL){
info->nChannels = ROUNDUP(nc,comm->scklAlgo.nChannels);
if (info->nChannels > comm->nChannels)
info->nChannels -= comm->scklAlgo.nChannels;
if (info->nChannels > comm->nChannels || info->nChannels < comm->scklAlgo.nChannels){
WARN("SCKL algo should have at least %d channels but ended up with %d channels.", comm->scklAlgo.nChannels, comm->nChannels);
return ncclInternalError;
}
}
info->nThreads = nt;
return ncclSuccess;
}
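The rounding keeps the channel count a multiple of the SCKL algorithm's channel count so that only complete replicas of the algorithm are laid out across channels. A worked example with hypothetical inputs (a local ROUNDUP equivalent to NCCL's macro):

// Sketch of the SCKL channel rounding in getAlgoInfo (hypothetical inputs).
#include <cstdio>
#define ROUNDUP(x, y) ((((x) + (y) - 1) / (y)) * (y))
int main() {
  int nc = 7;               // channel count suggested by the tuner
  int scklNChannels = 3;    // comm->scklAlgo.nChannels
  int commNChannels = 8;    // comm->nChannels
  int nChannels = ROUNDUP(nc, scklNChannels);                  // 9
  if (nChannels > commNChannels) nChannels -= scklNChannels;   // 6: drop one replica
  bool ok = !(nChannels > commNChannels || nChannels < scklNChannels);
  printf("nChannels = %d (%s)\n", nChannels, ok ? "ok" : "error");
  return 0;
}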
@@ -381,8 +388,9 @@ static ncclResult_t getLoopInfo(struct ncclInfo* info) {
case ncclPatternRingTwice:
info->nstepsPerLoop = 2*(info->comm->nRanks-1); info->nchunksPerLoop = info->comm->nRanks; break;
case ncclPatternSckl:
// SCKL needs a specific number of steps per loop for each connection. it is set properly in ncclProxySaveColl
info->nstepsPerLoop = 1; info->nchunksPerLoop = info->comm->nRanks * info->comm->scklAlgo.nChunks; break;
// SCKL needs a specific number of steps per loop for each channel/connection. it is set properly in ncclProxySaveColl
// nchunksPerLoop identifies how many chunks from the input buffer are processed in each iteration.
info->nstepsPerLoop = 1; info->nchunksPerLoop = info->comm->scklAlgo.nchunksPerLoop; break;
default:
WARN("Unknown pattern %d", info->pattern);
return ncclInternalError;
@@ -450,7 +458,8 @@ static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclWo
if (info->protocol == NCCL_PROTO_LL) chunkEffectiveSize /= 2;
if (info->protocol == NCCL_PROTO_LL128) chunkEffectiveSize = (chunkSize / NCCL_LL128_LINEELEMS) * NCCL_LL128_DATAELEMS;
//if (info->comm->rank == 0) printf("Coll %d, size %ld -> %dx%d, chunkSize %d (algo %d proto%d)\n", info->coll, info->nBytes, info->nChannels, info->nThreads, chunkSize, info->algorithm, info->protocol);
proxyArgs->nLoops = (int)(DIVUP(info->nBytes, (((size_t)(info->nChannels))*info->nchunksPerLoop*chunkEffectiveSize)));
// sckl might use multiple channels per loop. therefore, the division by info->comm->scklAlgo.nChannels is necessary if the algo is SCKL.
proxyArgs->nLoops = (int)(DIVUP(info->nBytes, ((((size_t)(info->nChannels))*info->nchunksPerLoop*chunkEffectiveSize)/(size_t) (info->algorithm == NCCL_ALGO_SCKL ? info->comm->scklAlgo.nChannels : 1))));
// nstepsPerLoop for sckl is incorrect and will be adjusted in ncclProxySaveColl
proxyArgs->nsteps = info->nstepsPerLoop * proxyArgs->nLoops * chunkSteps;
proxyArgs->sliceSteps = sliceSteps;
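To make the adjusted nLoops formula concrete: because one SCKL loop spans several channels, the per-loop capacity is divided by the SCKL channel count. A small worked example with hypothetical sizes:

// Sketch: nLoops computation for a SCKL algorithm (hypothetical sizes).
#include <cstdio>
#include <cstddef>
#define DIVUP(x, y) (((x) + (y) - 1) / (y))
int main() {
  size_t nBytes = 64 << 20;             // 64 MiB to move
  size_t nChannels = 6;                 // info->nChannels (two replicas of a 3-channel algo)
  size_t nchunksPerLoop = 16;           // scklAlgo->nchunksPerLoop
  size_t chunkEffectiveSize = 1 << 17;  // 128 KiB per chunk
  size_t scklAlgoChannels = 3;          // scklAlgo->nChannels
  size_t perLoopBytes = nChannels * nchunksPerLoop * chunkEffectiveSize / scklAlgoChannels;
  int nLoops = (int)DIVUP(nBytes, perLoopBytes);
  printf("perLoopBytes = %zu, nLoops = %d\n", perLoopBytes, nLoops);  // 4 MiB per loop, 16 loops
  return 0;
}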
@@ -507,11 +516,10 @@ ncclResult_t ncclSaveKernel(struct ncclInfo* info) {

int nChannels = work.coll.nChannels;
int nSubChannels = (info->pattern == ncclPatternCollTreeUp || info->pattern == ncclPatternCollTreeDown) ? 2 : 1;

for (int bid=0; bid<nChannels*nSubChannels; bid++) {
int channelId = info->comm->myParams->gridDim.x % info->comm->nChannels;
struct ncclChannel* channel = info->comm->channels+channelId;
work.scklNumBlocksPerChannel = (info->algorithm == NCCL_ALGO_SCKL) ? info->comm->scklAlgo.nBlocks : 1;
work.nActives = (info->algorithm == NCCL_ALGO_SCKL) ? info->comm->scklAlgo.scklChannels[channelId % info->comm->scklAlgo.nChannels].nBlocksForChannel : 1;
// Proxy
proxyArgs.channel = channel;
// Adjust pattern for CollNet based on channel index
@@ -658,7 +666,7 @@ ncclResult_t ncclSaveP2pKernel(struct ncclInfo* info) {
info->comm->myParams->gridDim.x = std::max<unsigned>(info->comm->myParams->gridDim.x, channelId+1);
info->comm->myParams->blockDim.x = std::max<unsigned>(info->comm->myParams->blockDim.x, info->nThreads);
// sckl does not generate p2p kernels.
w->elems[0].scklNumBlocksPerChannel = 1;
w->elems[0].isScklAlgorithm = 0;

return ncclSuccess;
}
@@ -622,9 +622,9 @@ ncclResult_t scklGetAlgoFromXMLAndSetComm(struct ncclComm* comm) {
memset(scklAlgo, 0, sizeof(struct scklAlgorithm));
struct ncclXmlNode* topNode;
NCCLCHECK(xmlFindTag(xml, "algo", &topNode));
int nchunks;
NCCLCHECK(xmlGetAttrInt(topNode, "nchunks", &nchunks));
scklAlgo->nChunks = nchunks;
int nchunksPerLoop;
NCCLCHECK(xmlGetAttrInt(topNode, "nchunksperloop", &nchunksPerLoop));
scklAlgo->nchunksPerLoop = nchunksPerLoop;
for (int s=0; s<topNode->nSubs; s++) {
struct ncclXmlNode* node = topNode->subs[s];
if (strcmp(node->name, "gpu") == 0){
@@ -635,24 +635,30 @@ ncclResult_t scklGetAlgoFromXMLAndSetComm(struct ncclComm* comm) {
for (int t=0; t<node->nSubs; t++) {
struct ncclXmlNode* threadblockNode = node->subs[t];
if (strcmp(threadblockNode->name, "threadblock") == 0){
int rbid, peer, channelId;
int bid, peer, channelId;
const char* type;
NCCLCHECK(xmlGetAttrInt(threadblockNode, "rbid", &rbid));
NCCLCHECK(xmlGetAttrInt(threadblockNode, "bid", &bid));
NCCLCHECK(xmlGetAttrInt(threadblockNode, "peer", &peer));
NCCLCHECK(xmlGetAttrStr(threadblockNode, "type", &type));
NCCLCHECK(xmlGetAttrInt(threadblockNode, "chan", &channelId));
if (rbid >= SCKL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL){
WARN("Too many thread blocks are requested. Max thread blocks: %d, requested: %d", SCKL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL, rbid+1);
return ncclInternalError;
}
if (rbid < 0){
WARN("rbid must be positive. rbid: %d", rbid);
return ncclInternalError;
if (bid < 0){
WARN("bid must be positive. bid: %d", bid);
return ncclInvalidUsage;
}
scklAlgo->nBlocks = std::max(comm->scklAlgo.nBlocks, rbid+1);
struct scklThreadBlock* sTB = &scklAlgo->scklTB[rbid];
scklAlgo->nBlocks = std::max(scklAlgo->nBlocks, bid+1);
if (bid >= MAXCHANNELS*SCKL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL){
WARN("Too many thread blocks are requested. Max thread blocks: %d", SCKL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL*MAXCHANNELS);
return ncclInvalidUsage;
}
struct scklThreadBlock* sTB = &scklAlgo->scklTB[bid];
sTB->nsteps = 0;
sTB->peer = peer;
if (channelId < 0 || channelId > MAXCHANNELS){
WARN("ChannelId needs to be between 0 and %d and it was %d", MAXCHANNELS, channelId);
return ncclInvalidUsage;
}
sTB->channelId = channelId;
scklAlgo->nChannels = std::max(scklAlgo->nChannels, channelId+1);
if (strcmp(type, "send") == 0){
sTB->type = SCKL_SEND;
} else if (strcmp(type, "recv") == 0) {
@@ -686,7 +692,7 @@ ncclResult_t scklGetAlgoFromXMLAndSetComm(struct ncclComm* comm) {
NCCLCHECK(xmlGetAttrStr(stepNode, "buffer", &buffer));
sTB->nsteps = std::max(sTB->nsteps, (uint8_t)(s+1));
sTB->transfers[s].offset = offset;
sTB->transfers[s].dependentRbid = depend_bid;
sTB->transfers[s].dependentBid = depend_bid;
sTB->transfers[s].dependentStep = depend_step;
if (strcmp(buffer, "input") == 0){
sTB->transfers[s].buffer = SCKL_INPUT_BUFFER;
@@ -694,20 +700,27 @@ ncclResult_t scklGetAlgoFromXMLAndSetComm(struct ncclComm* comm) {
sTB->transfers[s].buffer = SCKL_OUTPUT_BUFFER;
} else {
WARN("type of buffer is not supported: %s", buffer);
return ncclInternalError;
return ncclInvalidUsage;
}
ntransfers++;
}
}
// setting the summary of the sckl algorithm
scklChannelInfo* scklChannel = &scklAlgo->scklChannels[sTB->channelId];
sTB->rid = scklChannel->nsendPeers + scklChannel->nrecvPeers;
if (sTB->type == SCKL_SEND){
scklAlgo->sendPeers[scklAlgo->nsendPeers] = peer;
scklAlgo->nchunksForSendPeer[scklAlgo->nsendPeers] = ntransfers;
scklAlgo->nsendPeers++;
scklChannel->sendPeers[scklChannel->nsendPeers] = peer;
scklChannel->nchunksForSendPeer[scklChannel->nsendPeers] = ntransfers;
scklChannel->nsendPeers++;
} else if (sTB->type == SCKL_RECV){
scklAlgo->recvPeers[scklAlgo->nrecvPeers] = peer;
scklAlgo->nchunksForRecvPeer[scklAlgo->nrecvPeers] = ntransfers;
scklAlgo->nrecvPeers++;
scklChannel->recvPeers[scklChannel->nrecvPeers] = peer;
scklChannel->nchunksForRecvPeer[scklChannel->nrecvPeers] = ntransfers;
scklChannel->nrecvPeers++;
}
scklChannel->nBlocksForChannel = std::max(scklChannel->nBlocksForChannel, sTB->rid+1);
if (scklChannel->nBlocksForChannel > SCKL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL){
WARN("Too many sends/recv per channel. Max allowed %d", SCKL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL);
return ncclInvalidUsage;
}
}
}
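The per-channel summary assigns each thread block a relative id (rid) equal to the number of blocks already registered on its channel, and nBlocksForChannel tracks the maximum. A distilled restatement of that bookkeeping with hypothetical thread block channel assignments:

// Sketch: deriving per-channel rids and block counts (hypothetical inputs).
#include <cstdio>
#include <algorithm>
int main() {
  // Hypothetical parsed thread blocks: the channel each one is assigned to.
  const int tbChannel[5] = {0, 0, 1, 0, 1};
  int blocksOnChannel[2] = {0, 0};    // plays the role of nsendPeers + nrecvPeers
  int nBlocksForChannel[2] = {0, 0};
  for (int bid = 0; bid < 5; bid++) {
    int c = tbChannel[bid];
    int rid = blocksOnChannel[c]++;   // relative id of this block on its channel
    nBlocksForChannel[c] = std::max(nBlocksForChannel[c], rid + 1);
    printf("bid %d on channel %d gets rid %d\n", bid, c, rid);
  }
  printf("nBlocksForChannel = {%d, %d}\n", nBlocksForChannel[0], nBlocksForChannel[1]);
  return 0;
}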
@@ -32,8 +32,8 @@
#define DECL3(func, redop, type) \
DECL4(func, RING, redop, type) \
DECL4(func, TREE, redop, type) \
DECL4(func, COLLNET, redop, type) \
DECL4(func, SCKL, redop, type)
DECL4(func, SCKL, redop, type) \
DECL4(func, COLLNET, redop, type)

#define DECL2(func, redop) \
DECL3(func, redop, int8_t) \
@@ -126,7 +126,7 @@ struct ncclRing {
struct scklTransfer {
int16_t offset;
uint8_t buffer; // follow SCKL_THIS_INPUT/SCKL_THIS_OUTPUT macros
int8_t dependentRbid; // -1 if not dependent on any threadblock
int8_t dependentBid; // -1 if not dependent on any threadblock
int8_t dependentStep;
};
@@ -137,27 +137,34 @@ struct scklThreadBlock {
uint8_t peer;
uint8_t type; // follow SCKL_SEND and SCKL_RECV macros
uint8_t nsteps;
uint8_t channelId; // not going to be used for this version. just setting it up for the next version
uint8_t channelId; // associated channel
uint8_t rid; // relative id of this thread block within its channel
// step is used to index into this array. transfers[step] is the addr to transfer.
struct scklTransfer transfers[SCKL_MAX_NUM_STEPS];
};

// gpuId is the one that is in comm->rank
struct scklAlgorithm {
// number of chunks per gpu
int nChunks;
// number of threadblocks
int nBlocks;
// rbid is used as an index into this array
struct scklThreadBlock scklTB[SCKL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL];
// these two arrays can be inferred from scklTB. they are created to use NCCL API easily
struct scklChannelInfo {
int sendPeers[SCKL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL];
int nchunksForSendPeer[SCKL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL];
int nsendPeers;
int recvPeers[SCKL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL];
int nchunksForRecvPeer[SCKL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL];
int nrecvPeers;
int nBlocksForChannel;
};

// gpuId is the one that is in comm->rank
struct scklAlgorithm {
// max(#chunks in input, #chunks in output)
int nchunksPerLoop;
// total number of threadblocks needed by sckl algorithm
int nBlocks;
// bid is used as an index into this array
struct scklThreadBlock scklTB[MAXCHANNELS*SCKL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL];
// number of channels needed by sckl algorithm
int nChannels;
// the arrays in this struct can be inferred from scklTB. they are created to use NCCL API easily
struct scklChannelInfo scklChannels[MAXCHANNELS];
};

#define NCCL_MAX_TREE_ARITY 3
@@ -188,7 +195,8 @@ struct ncclWorkElem {
uint16_t index;
// in SCKL algorithms, the ncclWorkElem.active element from workFifo is replicated for all other thread blocks
uint16_t active[SCKL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL];
uint16_t scklNumBlocksPerChannel;
uint8_t isScklAlgorithm; // right now, 0 indicates not a sckl algorithm and 1 indicates it is. In future versions, this will be the index into arrays of scklAlgorithms.
uint8_t nActives; // if it is a sckl algorithm, it must be set to the number of thread blocks associated with this channel. if not a sckl algorithm, it is 1.

const void * sendbuff;
void * recvbuff;
@@ -38,7 +38,7 @@ std::chrono::high_resolution_clock::time_point ncclEpoch;
#endif

const char* ncclFuncStr[NCCL_NUM_FUNCTIONS] = { "Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "AllToAll" };
const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS] = { "Tree", "Ring", "CollNet", "SCKL" };
const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS] = { "Tree", "Ring", "SCKL", "CollNet" };
const char* ncclProtoStr[NCCL_NUM_PROTOCOLS] = { "LL", "LL128", "Simple" };

NCCL_PARAM(GroupCudaStream, "GROUP_CUDA_STREAM", NCCL_GROUP_CUDA_STREAM);
@@ -865,7 +865,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
for (int c=0; c<comm->nChannels; c++) {
struct ncclChannel* channel = comm->channels+c;
if (comm->nRanks == 1) continue;
NCCLCHECKGOTO(ncclTransportP2pConnect(comm, channel, comm->scklAlgo.nrecvPeers, comm->scklAlgo.recvPeers, comm->scklAlgo.nsendPeers, comm->scklAlgo.sendPeers), ret, affinity_restore);
struct scklChannelInfo* scklChannel = &comm->scklAlgo.scklChannels[c % comm->scklAlgo.nChannels];
NCCLCHECKGOTO(ncclTransportP2pConnect(comm, channel, scklChannel->nrecvPeers, scklChannel->recvPeers, scklChannel->nsendPeers, scklChannel->sendPeers), ret, affinity_restore);
}
// It appears that graph is not really needed for P2pSetup. The only place that actually uses it is in ncclTopoGetNetDev which has a bypass for when it is set to NULL.
NCCLCHECKGOTO(ncclTransportP2pSetup(comm, NULL), ret, affinity_restore);
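Because extra communicator channels replicate the SCKL algorithm, each channel connects to the peers of its corresponding SCKL channel via the modular index c % scklAlgo.nChannels. A small sketch of that mapping with hypothetical channel counts:

// Sketch: which scklChannelInfo each communicator channel draws its peers from.
#include <cstdio>
int main() {
  const int commNChannels = 6;  // comm->nChannels (hypothetical)
  const int scklNChannels = 2;  // comm->scklAlgo.nChannels (hypothetical)
  for (int c = 0; c < commNChannels; c++)
    printf("channel %d connects the peers of scklChannels[%d]\n", c, c % scklNChannels);
  return 0;
}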
src/proxy.cc (14 changed lines)
@@ -220,14 +220,16 @@ ncclResult_t ncclProxySaveColl(struct ncclProxyArgs* args, int pattern, int root
NCCLCHECK(SaveProxy(proxyRecv, tree->up, args));
}
if (pattern == ncclPatternSckl){
int relativeChannelId = args->channel->id % scklAlgo->nChannels;
scklChannelInfo* scklChannel = &scklAlgo->scklChannels[relativeChannelId];
// nsteps is adjusted here for SCKL algo
for (int i=0; i<scklAlgo->nrecvPeers; i++){
args->nsteps = scklAlgo->nchunksForRecvPeer[i] * args->nLoops * args->chunkSteps;
NCCLCHECK(SaveProxy(proxyRecv, scklAlgo->recvPeers[i], args));
for (int i=0; i<scklChannel->nrecvPeers; i++){
args->nsteps = scklChannel->nchunksForRecvPeer[i] * args->nLoops * args->chunkSteps;
NCCLCHECK(SaveProxy(proxyRecv, scklChannel->recvPeers[i], args));
}
for (int i=0; i<scklAlgo->nsendPeers; i++){
args->nsteps = scklAlgo->nchunksForSendPeer[i] * args->nLoops * args->chunkSteps;
NCCLCHECK(SaveProxy(proxySend, scklAlgo->sendPeers[i], args));
for (int i=0; i<scklChannel->nsendPeers; i++){
args->nsteps = scklChannel->nchunksForSendPeer[i] * args->nLoops * args->chunkSteps;
NCCLCHECK(SaveProxy(proxySend, scklChannel->sendPeers[i], args));
}

if (args->connector && args->connector->conn.shared != 0){