From 9327d76850a06b55f75431d24851918c12d99ae4 Mon Sep 17 00:00:00 2001
From: Olli Saarikivi
Date: Fri, 9 Apr 2021 00:59:42 -0700
Subject: [PATCH] Extract shared parts of interpreter code

Simple, LL and LL128 specify the differing parts with wrappers for the
*Primitives structs. A SCKLFunction class template contains most of the
interpreter logic.
---
 src/collectives/device/sckl_interpreter.h | 325 +++++++++-------------
 1 file changed, 134 insertions(+), 191 deletions(-)

diff --git a/src/collectives/device/sckl_interpreter.h b/src/collectives/device/sckl_interpreter.h
index bb2913d..15f8cfa 100644
--- a/src/collectives/device/sckl_interpreter.h
+++ b/src/collectives/device/sckl_interpreter.h
@@ -14,14 +14,13 @@
 #define COMPUTE_FLAG(__WORKINDEX__,__GRIDOFFSET_ITER__,__STEP__) \
   SCKL_MAX_ITER*SCKL_MAX_NUM_STEPS*__WORKINDEX__ + (__GRIDOFFSET_ITER__ * SCKL_MAX_NUM_STEPS + __STEP__)
 
-template<int UNROLL, class FUNC, typename T>
-class SCKLFunctionSimple {
+template<class FUNC, typename T, typename PRIMS_WRAPPER>
+class SCKLFunction {
   public:
     __device__ void run(struct ncclWorkElem* args) {
       struct ncclDevComm* comm = args->comm;
       struct scklAlgorithm* scklAlgo = &comm->scklAlgo;
       const int tid = threadIdx.x;
-      const int nthreads = args->nThreads-WARP_SIZE;
       const int sync_tid = args->nThreads-1; // last thread is most likely not doing anthing and used for sckl cross thread synchronization
       const int bid = blockIdx.x;
       const int scklNBlocks = scklAlgo->nBlocks;
@@ -31,32 +30,29 @@ class SCKLFunctionSimple {
       struct scklThreadBlock* scklTB = &scklAlgo->scklTB[rscklbid];
       const int channelId = scklIndex * scklAlgo->nChannels + scklTB->channelId;
       struct ncclChannel* channel = comm->channels+channelId;
-      const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
-      const int chunkSize = stepSize * SCKL_CHUNKSTEPS;
-      const int nranks = comm->nRanks;
-      const int nchunksPerLoopPerRank = scklAlgo->nchunksPerLoop/nranks;
-      const ssize_t loopSize = (ssize_t)chunkSize*nScklInstnaces;
-      const ssize_t size = args->coll.count;
-      const ssize_t sizePerScklChunk = size/nchunksPerLoopPerRank;
-      // sckl flags all start out with 0. this is used as a part of the flag to make sure different work items deal with different synchronization flags
-      // this still needs more work. when we make a way around the queue, the flag might have been set to undesired values. will be fixed in subsequent versions.
-      const int workIndex = args->index+1;
-      volatile struct scklFlag* scklFlags = comm->scklFlags;
+      // Compute pointers
       T * thisInput = (T*)args->sendbuff;
       T * thisOutput = (T*)args->recvbuff;
       int recvPeer = scklTB->recvpeer;
       int sendPeer = scklTB->sendpeer;
 
-      ncclPrimitives<UNROLL, SCKL_CHUNKSTEPS/SCKL_SLICESTEPS, SCKL_SLICESTEPS, T, 1, 1, 1, FUNC>
-        prims(tid, nthreads, &recvPeer, &sendPeer, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 0);
+      PRIMS_WRAPPER prims{args, tid, &recvPeer, &sendPeer, thisOutput, channel};
+
+      const int nranks = comm->nRanks;
+      const int nchunksPerLoopPerRank = scklAlgo->nchunksPerLoop/nranks;
+      const ssize_t loopSize = (ssize_t)prims.chunkSize*nScklInstnaces;
+      const ssize_t size = args->coll.count;
+      const ssize_t sizePerScklChunk = size/nchunksPerLoopPerRank;
+      // sckl flags all start out with 0. this is used as a part of the flag to make sure different work items deal with different synchronization flags
+      // this still needs more work. when we make a way around the queue, the flag might have been set to undesired values. will be fixed in subsequent versions.
+      const int workIndex = args->index+1;
+      volatile struct scklFlag* scklFlags = comm->scklFlags;
+
       for (ssize_t gridOffset = 0, iter = 0; gridOffset < sizePerScklChunk; gridOffset += loopSize, iter++) {
-        int realChunkSize = min(chunkSize, DIVUP(sizePerScklChunk-gridOffset,nScklInstnaces));
-        ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
-        ssize_t chunkOffset = gridOffset + scklIndex*realChunkSize;
+        size_t chunkOffset = prims.initIter(sizePerScklChunk, gridOffset, nScklInstnaces, scklIndex);
         ssize_t srcoffset, dstoffset;
         T* srcPointer, * dstPointer;
-        int nelem = min(realChunkSize, sizePerScklChunk-chunkOffset);
         for (int i = 0; i < scklTB->nsteps; i++){
           struct scklTransfer* sckltran = &scklTB->transfers[i];
           if (sckltran->type == SCKL_NO_OP) continue;
@@ -77,180 +73,13 @@ class SCKLFunctionSimple {
           dstoffset = chunkOffset + (ssize_t) sckltran->dstoffset * sizePerScklChunk;
           switch (sckltran->type) {
             case SCKL_SEND:
-              prims.directSend(srcPointer + srcoffset, dstoffset, nelem);
+              prims.send(srcPointer + srcoffset, dstoffset);
               break;
             case SCKL_RECV:
-              prims.directRecv(dstPointer + dstoffset, dstoffset, nelem);
+              prims.recv(dstPointer + dstoffset, dstoffset);
               break;
             case SCKL_RECV_COPY_SEND:
-              prims.directRecvCopySend(dstPointer + dstoffset, dstoffset, nelem);
-              break;
-            default:
-              return;
-          }
-
-          if (tid == sync_tid){
-            __threadfence();
-            uint64_t curFlag = COMPUTE_FLAG(workIndex, iter, i);
-            scklFlags[bid].flag = curFlag;
-          }
-        }
-      }
-    }
-};
-
-#include "prims_ll128.h"
-template<int UNROLL, class FUNC, typename T>
-class SCKLFunctionLL128 {
-  public:
-    __device__ void run(struct ncclWorkElem* args) {
-      struct ncclDevComm* comm = args->comm;
-      struct scklAlgorithm* scklAlgo = &comm->scklAlgo;
-      const int tid = threadIdx.x;
-      const int nthreads = args->nThreads;
-      const int sync_tid = args->nThreads-1; // last thread is most likely not doing anthing and used for sckl cross thread synchronization
-      const int bid = blockIdx.x;
-      const int scklNBlocks = scklAlgo->nBlocks;
-      const int rscklbid = bid % scklNBlocks; // bid within a sckl algo
-      const int scklIndex = bid / scklNBlocks; // which instance of sckl algo
-      const int nScklInstnaces = gridDim.x / scklAlgo->nBlocks; // number of sckl aglos
-      struct scklThreadBlock* scklTB = &scklAlgo->scklTB[rscklbid];
-      const int channelId = scklIndex * scklAlgo->nChannels + scklTB->channelId;
-      struct ncclChannel* channel = comm->channels+channelId;
-      const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
-      ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
-      const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2;
-      const int nranks = comm->nRanks;
-      const int nchunksPerLoopPerRank = scklAlgo->nchunksPerLoop/nranks;
-      const ssize_t loopSize = (ssize_t)chunkSize*nScklInstnaces;
-      const ssize_t size = args->coll.count;
-      const ssize_t sizePerScklChunk = size/nchunksPerLoopPerRank;
-      // sckl flags all start out with 0. this is used as a part of the flag to make sure different work items deal with different synchronization flags
-      // this still needs more work. when we make a way around the queue, the flag might have been set to undesired values. will be fixed in subsequent versions.
-      const int workIndex = args->index+1;
-      volatile struct scklFlag* scklFlags = comm->scklFlags;
-      // Compute pointers
-      T * thisInput = (T*)args->sendbuff;
-      T * thisOutput = (T*)args->recvbuff;
-      int recvPeer = scklTB->recvpeer;
-      int sendPeer = scklTB->sendpeer;
-
-      ncclLL128Primitives<T, FUNC, 1, 1> prims(tid, nthreads, &recvPeer, &sendPeer, stepSize, channel, comm);
-      for (ssize_t gridOffset = 0, iter = 0; gridOffset < sizePerScklChunk; gridOffset += loopSize, iter++) {
-        chunkSize = min(chunkSize, DIVUP(sizePerScklChunk-gridOffset,nScklInstnaces*minChunkSize)*minChunkSize);
-        ssize_t chunkOffset = gridOffset + scklIndex*chunkSize;
-        ssize_t srcoffset, dstoffset;
-        T* srcPointer, * dstPointer;
-        int nelem = min(chunkSize, sizePerScklChunk-chunkOffset);
-        for (int i = 0; i < scklTB->nsteps; i++){
-          struct scklTransfer* sckltran = &scklTB->transfers[i];
-          if (sckltran->type == SCKL_NO_OP) continue;
-          // first wait if there is a dependence
-          int8_t dependentBid = sckltran->dependentBid + scklIndex * scklNBlocks;
-          int8_t dependentStep = sckltran->dependentStep;
-          if (sckltran->dependentBid >= 0){
-            if (tid == sync_tid){
-              uint64_t goalFlag = COMPUTE_FLAG(workIndex, iter, dependentStep);
-              while ((scklFlags + dependentBid)->flag < goalFlag){};
-            }
-            __syncthreads();
-          }
-
-          srcPointer = (sckltran->srcbuffer == SCKL_INPUT_BUFFER) ? thisInput : thisOutput;
-          srcoffset = chunkOffset + (ssize_t) sckltran->srcoffset * sizePerScklChunk;
-          dstPointer = (sckltran->dstbuffer == SCKL_INPUT_BUFFER) ? thisInput : thisOutput;
-          dstoffset = chunkOffset + (ssize_t) sckltran->dstoffset * sizePerScklChunk;
-          switch (sckltran->type) {
-            case SCKL_SEND:
-              prims.send(srcPointer + srcoffset, nelem);
-              break;
-            case SCKL_RECV:
-              prims.recv(dstPointer + dstoffset, nelem);
-              break;
-            case SCKL_RECV_COPY_SEND:
-              prims.recvCopySend(dstPointer + dstoffset, nelem);
-              break;
-            default:
-              return;
-          }
-          if (tid == sync_tid){
-            __threadfence();
-            uint64_t curFlag = COMPUTE_FLAG(workIndex, iter, i);
-            scklFlags[bid].flag = curFlag;
-          }
-        }
-      }
-    }
-};
-
-
-template<int UNROLL, class FUNC, typename T>
-class SCKLFunctionLL {
-  public:
-    __device__ void run(struct ncclWorkElem* args) {
-      struct ncclDevComm* comm = args->comm;
-      struct scklAlgorithm* scklAlgo = &comm->scklAlgo;
-      const int tid = threadIdx.x;
-      const int nthreads = args->nThreads;
-      const int sync_tid = args->nThreads-1; // last thread is most likely not doing anthing and used for sckl cross thread synchronization
-      const int bid = blockIdx.x;
-      const int scklNBlocks = scklAlgo->nBlocks;
-      const int rscklbid = bid % scklNBlocks; // bid within a sckl algo
-      const int scklIndex = bid / scklNBlocks; // which instance of sckl algo
-      const int nScklInstnaces = gridDim.x / scklAlgo->nBlocks; // number of sckl aglos
-      struct scklThreadBlock* scklTB = &scklAlgo->scklTB[rscklbid];
-      const int channelId = scklIndex * scklAlgo->nChannels + scklTB->channelId;
-      struct ncclChannel* channel = comm->channels+channelId;
-      const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
-      ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
-      const int nranks = comm->nRanks;
-      const int nchunksPerLoopPerRank = scklAlgo->nchunksPerLoop/nranks;
-      const ssize_t loopSize = (ssize_t)chunkSize*nScklInstnaces;
-      const ssize_t size = args->coll.count;
-      const ssize_t sizePerScklChunk = size/nchunksPerLoopPerRank;
-      // sckl flags all start out with 0. this is used as a part of the flag to make sure different work items deal with different synchronization flags
-      // this still needs more work. when we make a way around the queue, the flag might have been set to undesired values. will be fixed in subsequent versions.
-      const int workIndex = args->index+1;
-      volatile struct scklFlag* scklFlags = comm->scklFlags;
-      // Compute pointers
-      T * thisInput = (T*)args->sendbuff;
-      T * thisOutput = (T*)args->recvbuff;
-      int recvPeer = scklTB->recvpeer;
-      int sendPeer = scklTB->sendpeer;
-
-      ncclLLPrimitives<T, FUNC, 1, 1> prims(tid, nthreads, &recvPeer, &sendPeer, stepLines, channel, comm);
-      for (ssize_t gridOffset = 0, iter = 0; gridOffset < sizePerScklChunk; gridOffset += loopSize, iter++) {
-        ssize_t chunkOffset = gridOffset + scklIndex*chunkSize;
-        ssize_t srcoffset, dstoffset;
-        T* srcPointer, * dstPointer;
-        int nelem = min(chunkSize, sizePerScklChunk-chunkOffset);
-        for (int i = 0; i < scklTB->nsteps; i++){
-          struct scklTransfer* sckltran = &scklTB->transfers[i];
-          if (sckltran->type == SCKL_NO_OP) continue;
-          // first wait if there is a dependence
-          int8_t dependentBid = sckltran->dependentBid + scklIndex * scklNBlocks;
-          int8_t dependentStep = sckltran->dependentStep;
-          if (sckltran->dependentBid >= 0){
-            if (tid == sync_tid){
-              uint64_t goalFlag = COMPUTE_FLAG(workIndex, iter, dependentStep);
-              while ((scklFlags + dependentBid)->flag < goalFlag){};
-            }
-            __syncthreads();
-          }
-
-          srcPointer = (sckltran->srcbuffer == SCKL_INPUT_BUFFER) ? thisInput : thisOutput;
-          srcoffset = chunkOffset + (ssize_t) sckltran->srcoffset * sizePerScklChunk;
-          dstPointer = (sckltran->dstbuffer == SCKL_INPUT_BUFFER) ? thisInput : thisOutput;
-          dstoffset = chunkOffset + (ssize_t) sckltran->dstoffset * sizePerScklChunk;
-          switch (sckltran->type) {
-            case SCKL_SEND:
-              prims.send(srcPointer + srcoffset, nelem);
-              break;
-            case SCKL_RECV:
-              prims.recv(dstPointer + dstoffset, nelem);
-              break;
-            case SCKL_RECV_COPY_SEND:
-              prims.recvCopySend(dstPointer + dstoffset, nelem);
+              prims.recvCopySend(dstPointer + dstoffset, dstoffset);
               break;
             default:
               return;
@@ -263,4 +92,118 @@ class SCKLFunctionLL {
         }
       }
     }
-};
\ No newline at end of file
+};
+
+template<int UNROLL, class FUNC, typename T>
+struct SimpleWrapper {
+  const int nthreads;
+  const int stepSize;
+  const int chunkSize;
+  ncclPrimitives<UNROLL, SCKL_CHUNKSTEPS/SCKL_SLICESTEPS, SCKL_SLICESTEPS, T, 1, 1, 1, FUNC> prims;
+
+  int nelem;
+
+  __device__ SimpleWrapper(struct ncclWorkElem* args, int tid, int* recvPeer, int* sendPeer, T * thisOutput, struct ncclChannel* channel)
+    : nthreads(args->nThreads-WARP_SIZE),
+      stepSize(args->comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS)),
+      chunkSize(stepSize * SCKL_CHUNKSTEPS),
+      prims(tid, nthreads, recvPeer, sendPeer, thisOutput, stepSize, channel, args->comm, ncclShmem->ptrs, 0) {}
+
+  __device__ size_t initIter(ssize_t sizePerScklChunk, ssize_t gridOffset, int nScklInstnaces, int scklIndex) {
+    int realChunkSize = min(chunkSize, DIVUP(sizePerScklChunk-gridOffset,nScklInstnaces));
+    ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
+    ssize_t chunkOffset = gridOffset + scklIndex*realChunkSize;
+    nelem = min(realChunkSize, sizePerScklChunk-chunkOffset);
+    return chunkOffset;
+  }
+
+  __device__ void send(T * chunkPointer, ssize_t dstoffset) {
+    prims.directSend(chunkPointer, dstoffset, nelem);
+  }
+
+  __device__ void recv(T * chunkPointer, ssize_t dstoffset) {
+    prims.directRecv(chunkPointer, dstoffset, nelem);
+  }
+
+  __device__ void recvCopySend(T * chunkPointer, ssize_t dstoffset) {
+    prims.directRecvCopySend(chunkPointer, dstoffset, nelem);
+  }
+};
+
+template<int UNROLL, class FUNC, typename T>
+class SCKLFunctionSimple : public SCKLFunction<FUNC, T, SimpleWrapper<UNROLL, FUNC, T>> {};
+
+#include "prims_ll128.h"
+template<class FUNC, typename T>
+struct LL128Wrapper {
+  const int stepSize;
+  ssize_t chunkSize;
+  const ssize_t minChunkSize;
+  ncclLL128Primitives<T, FUNC, 1, 1> prims;
+
+  int nelem;
+
+  __device__ LL128Wrapper(struct ncclWorkElem* args, int tid, int* recvPeer, int* sendPeer, T * thisOutput, struct ncclChannel* channel)
+    : stepSize(args->comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS)),
+      chunkSize(stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T))),
+      minChunkSize((NCCL_LL128_SHMEM_ELEMS_PER_THREAD*args->nThreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2),
+      prims(tid, args->nThreads, recvPeer, sendPeer, stepSize, channel, args->comm) {}
+
+  __device__ size_t initIter(ssize_t sizePerScklChunk, ssize_t gridOffset, int nScklInstnaces, int scklIndex) {
+    chunkSize = min(chunkSize, DIVUP(sizePerScklChunk-gridOffset,nScklInstnaces*minChunkSize)*minChunkSize);
+    ssize_t chunkOffset = gridOffset + scklIndex*chunkSize;
+    nelem = min(chunkSize, sizePerScklChunk-chunkOffset);
+    return chunkOffset;
+  }
+
+  __device__ void send(T * chunkPointer, ssize_t dstoffset) {
+    prims.send(chunkPointer, nelem);
+  }
+
+  __device__ void recv(T * chunkPointer, ssize_t dstoffset) {
+    prims.recv(chunkPointer, nelem);
+  }
+
+  __device__ void recvCopySend(T * chunkPointer, ssize_t dstoffset) {
+    prims.recvCopySend(chunkPointer, nelem);
+  }
+};
+
+template<int UNROLL, class FUNC, typename T>
+class SCKLFunctionLL128 : public SCKLFunction<FUNC, T, LL128Wrapper<FUNC, T>> {};
+
+template<class FUNC, typename T>
+struct LLWrapper {
+  const int stepLines;
+  const ssize_t chunkSize;
+  ncclLLPrimitives<T, FUNC, 1, 1> prims;
+
+  int nelem;
+
+  __device__ LLWrapper(struct ncclWorkElem* args, int tid, int* recvPeer, int* sendPeer, T * thisOutput, struct ncclChannel* channel)
+    : stepLines(args->comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS)),
+      chunkSize(stepLines * sizeof(uint64_t) / sizeof(T)),
+      prims(tid, args->nThreads, recvPeer, sendPeer, stepLines, channel, args->comm) {}
+
+  __device__ size_t initIter(ssize_t sizePerScklChunk, ssize_t gridOffset, int nScklInstnaces, int scklIndex) {
+    ssize_t chunkOffset = gridOffset + scklIndex*chunkSize;
+    nelem = min(chunkSize, sizePerScklChunk-chunkOffset);
+    return chunkOffset;
+  }
+
+  __device__ void send(T * chunkPointer, ssize_t dstoffset) {
+    prims.send(chunkPointer, nelem);
+  }
+
+  __device__ void recv(T * chunkPointer, ssize_t dstoffset) {
+    prims.recv(chunkPointer, nelem);
+  }
+
+  __device__ void recvCopySend(T * chunkPointer, ssize_t dstoffset) {
+    prims.recvCopySend(chunkPointer, nelem);
+  }
+};
+
+template<int UNROLL, class FUNC, typename T>
+class SCKLFunctionLL : public SCKLFunction<FUNC, T, LLWrapper<FUNC, T>> {};
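
A minimal, self-contained sketch of the wrapper pattern this patch introduces, for readers who want to see the shape of the refactoring outside the NCCL code base. All names below (Interpreter, FixedChunkPrims, WideChunkPrims) are invented for illustration and are not part of this patch; in the patch itself the shared loop is SCKLFunction and the protocol-specific policies are SimpleWrapper, LL128Wrapper and LLWrapper.

  // wrapper_pattern_sketch.cpp -- illustrative only; build with: g++ -std=c++14 wrapper_pattern_sketch.cpp
  #include <algorithm>
  #include <cstddef>
  #include <cstdio>

  // Shared interpreter loop, analogous to SCKLFunction: it only assumes the wrapper
  // exposes chunkSize, an initIter() that clamps the current chunk and records nelem,
  // and send()/recv() primitives.
  template <typename PRIMS_WRAPPER>
  struct Interpreter {
    void run(std::size_t totalElems) {
      PRIMS_WRAPPER prims;  // protocol-specific state lives in the wrapper
      for (std::size_t gridOffset = 0; gridOffset < totalElems; gridOffset += prims.chunkSize) {
        std::size_t chunkOffset = prims.initIter(totalElems, gridOffset);
        prims.send(chunkOffset);  // the wrapper knows how many elements (nelem) to move
        prims.recv(chunkOffset);
      }
    }
  };

  // One "protocol": small fixed chunks, clamped to the remaining elements.
  struct FixedChunkPrims {
    std::size_t chunkSize = 4;
    std::size_t nelem = 0;
    std::size_t initIter(std::size_t total, std::size_t gridOffset) {
      nelem = std::min(chunkSize, total - gridOffset);
      return gridOffset;  // chunk offset for this iteration
    }
    void send(std::size_t off) { std::printf("fixed send off=%zu nelem=%zu\n", off, nelem); }
    void recv(std::size_t off) { std::printf("fixed recv off=%zu nelem=%zu\n", off, nelem); }
  };

  // Another "protocol": same interface, different chunking policy.
  struct WideChunkPrims {
    std::size_t chunkSize = 8;
    std::size_t nelem = 0;
    std::size_t initIter(std::size_t total, std::size_t gridOffset) {
      nelem = std::min(chunkSize, total - gridOffset);
      return gridOffset;
    }
    void send(std::size_t off) { std::printf("wide  send off=%zu nelem=%zu\n", off, nelem); }
    void recv(std::size_t off) { std::printf("wide  recv off=%zu nelem=%zu\n", off, nelem); }
  };

  int main() {
    Interpreter<FixedChunkPrims>{}.run(10);  // same shared loop ...
    Interpreter<WideChunkPrims>{}.run(10);   // ... two different prims policies
    return 0;
  }

This mirrors the design choice in the patch: the interpreter never branches on the protocol at runtime; each protocol supplies a wrapper type and the compiler instantiates one specialized interpreter per protocol.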