[CPU] SHM based allreduce improvement for small message size (#5571)

On CPU server, when running SHM based allreduce for small messages, the
performance is pretty much dominated by synchronization latency. These
latency includes the following two situations:
1. Wait for status change from other ranks.
2. Use `#pragma omp parallel for` to accelerator memory bandwidth bound
operations such as `parallel_memcpy` or `reduce`.

Each synchronization add a little time to allreduce latency. In current
implementation, for small messages, 5 syncs on rank 0 are needed. This
includes: 1) copy-in; 2) wait for other ranks done copy; 3) reduce; 4)
copy-out; 5) wait for other ranks finish copy-out

We redesign the algorithm for small message allreduce (called
`symmetric_naive_allreduce`) to have only three syncs, each rank do
exactly the same steps: 1) copy-in; 2) wait for other ranks done copy;
3) reduce to output buffer directly. We use double buffer so we can skip
the last wait and go directly to next call using another buffer. We have
a carefully designed state check to avoid using global barrier among
ranks.

Test shows for message size < 1MB, allreduce latency will reduce 30% to
50%. This is especially helpful for tensor parallel decoding with small
batch size, where the tensor size is usually a few 10s of KBytes.

|message size(bytes)|new method latency(us)|old method latency(us)|
|---|---|---| 
|             2      |    13.34|20.39
|             4      |    13.44|19.57
|             8      |    13.70|19.76
|            16     |     13.27|20.43
|            32     |     13.42|19.75
|            64     |     13.38|19.80
|           128    |      13.70|19.44
|           256    |      13.99|20.33
|           512    |      13.91|20.28
|          1024   |       15.00|22.86
|          2048   |       15.82|20.93
|          4096   |       16.00|21.08
|          8192   |       16.31|21.50
|         16384  |        16.27|22.95
|         32768  |        16.13|25.17
|         65536  |        18.92|25.90
|        131072 |         21.12|27.42
|        262144 |         23.09|32.36
|        524288 |         32.78|42.80

Because the new method would compute same reduce value on all ranks.
Caution needs to be taken to ensure the result is identical on all
ranks. We use the test in the link
https://github.com/delock/ds_allreduce_bench/blob/main/ds_comm_bench.py#L70
to ensure the implementation is correct.
https://github.com/delock/ds_allreduce_bench/blob/main/validate.sh is a
test script for better coverage.

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Abhishek Kulkarni <11399+adk9@users.noreply.github.com>
This commit is contained in:
Ma, Guokai 2024-06-13 05:00:20 +08:00 коммит произвёл GitHub
Родитель dfcade2414
Коммит eda5075b88
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
1 изменённых файлов: 279 добавлений и 265 удалений

Просмотреть файл

@ -21,9 +21,13 @@
// states for collectives // states for collectives
enum coll_state { enum coll_state {
coll_begin = 0, coll_begin = 0,
coll_allreduce_naive__copy_in_done, // this state is for rank != 0 coll_allreduce_naive__copy_in_done,
coll_allreduce_naive__reduce_done, // this state is for rank == 0 coll_allreduce_naive__reduce_done,
coll_allreduce_naive__copy_out_done, // this state is for rank != 0 // alternative state when allreduce is working on alternative buffer
// of the double buffer.
coll_alt1_allreduce_naive__copy_in_done,
coll_alt2_allreduce_naive__copy_in_done,
coll_alt1_allreduce_naive__reduce_done,
}; };
// SHM building blocks // SHM building blocks
@ -71,6 +75,8 @@ void shared_close(SharedData* data)
} }
} }
static int world_size;
// SHM based allreduce helper functions // SHM based allreduce helper functions
// buffer that holds shm name // buffer that holds shm name
#define NAME_BUF_SIZE 1000 #define NAME_BUF_SIZE 1000
@ -78,64 +84,37 @@ void shared_close(SharedData* data)
#define NAIVE_ALLREDUCE_THRESHOLD 1048576 #define NAIVE_ALLREDUCE_THRESHOLD 1048576
#define SHM_BUFFER_NAME "deepspeed_allreduce_buffer" #define SHM_BUFFER_NAME "deepspeed_allreduce_buffer"
struct allreduce_workspace { struct allreduce_workspace {
enum coll_state state; enum coll_state states[2]; // idx=0 -- state for symmetric_naive_all_reduce
sem_t mutex; // idx=1 -- state for distributed_naive_all_reduce
sem_t turnstile1; // double buffer to avoid syncing between rounds
sem_t turnstile2; // offset=0 -- 2*NAIVE_ALLREDUCE_THRESHOLD : buffer for symmetric_naive_all_reduce
int counter; // after that : buffer for distributed_naive_all_reduce
char buffer[MAX_BUF_SIZE]; char buffer[2 * NAIVE_ALLREDUCE_THRESHOLD + 2 * MAX_BUF_SIZE];
}; };
#define BUFFER0_OFFSET(current_buffer) current_buffer* NAIVE_ALLREDUCE_THRESHOLD
#define BUFFER1_OFFSET(current_buffer) 2 * NAIVE_ALLREDUCE_THRESHOLD + current_buffer* MAX_BUF_SIZE
struct allreduce_workspace** workspace; struct allreduce_workspace** workspace;
void wait_buffer_state_until(int index, enum coll_state state) // buffer for small messages, double buffer
{ char** symmetric_buffer[2];
volatile enum coll_state* state_ptr = &(workspace[index]->state); // buffer for large messages, double buffer
char** distributed_buffer[2];
while (*state_ptr != state) void wait_buffer_state_until_2(int index,
; enum coll_state state0,
} enum coll_state state1,
int state_group)
void wait_buffer_state_until_range(int index, enum coll_state start, int size)
{ {
volatile enum coll_state* state_ptr = &(workspace[index]->state); volatile enum coll_state* state_ptr = &(workspace[index]->states[state_group]);
enum coll_state end = (enum coll_state)(start + size);
while (1) { while (1) {
volatile enum coll_state cur_state = *state_ptr; volatile enum coll_state cur_state = *state_ptr;
if (cur_state >= start and cur_state < end) break; if (cur_state == state0 || cur_state == state1) break;
} }
} }
void wait_buffer_state_until_not(int index, enum coll_state state)
{
volatile enum coll_state* state_ptr = &(workspace[index]->state);
while (*state_ptr == state)
;
}
void barrier_wait(int root_idx, int num_ranks)
{
// Phase 1: Wait for all threads to enter the barrier
auto shared = workspace[root_idx];
sem_wait(&shared->mutex);
shared->counter++;
if (shared->counter == num_ranks) {
for (int i = 0; i < num_ranks; ++i) { sem_post(&shared->turnstile1); }
}
sem_post(&shared->mutex);
sem_wait(&shared->turnstile1);
// Phase 2: Wait for all threads to exit the barrier
sem_wait(&shared->mutex);
shared->counter--;
if (shared->counter == 0) {
for (int i = 0; i < num_ranks; ++i) { sem_post(&shared->turnstile2); }
}
sem_post(&shared->mutex);
sem_wait(&shared->turnstile2);
}
__m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw"))); __m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
inline __m512 cvt_bf16_to_fp32(const __m256i src) inline __m512 cvt_bf16_to_fp32(const __m256i src)
{ {
@ -167,122 +146,52 @@ inline __m256i cvt_fp32_to_bf16(const __m512 src)
void reduce_2_bf16_buffers_iio(int num_elements, void* in0, void* in1, void* out) void reduce_2_bf16_buffers_iio(int num_elements, void* in0, void* in1, void* out)
__attribute__((target("avx512bw"))); __attribute__((target("avx512bw")));
void reduce_bf16_buffers(int start_elements, void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
int num_elements,
int num_buffers,
int to_buffer_idx,
struct allreduce_workspace** workspace)
__attribute__((target("avx512bw"))); __attribute__((target("avx512bw")));
void reduce_2_fp32_buffers_iio(int num_elements, void* in0, void* in1, void* out) void reduce_2_fp32_buffers_iio(int num_elements, void* in0, void* in1, void* out)
__attribute__((target("avx512bw"))); __attribute__((target("avx512bw")));
void reduce_fp32_buffers(int start_elements, void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
int num_elements,
int num_buffers,
int to_buffer_idx,
struct allreduce_workspace** workspace)
__attribute__((target("avx512bw"))); __attribute__((target("avx512bw")));
// N_REDUCE_LIMIT is the number of buffers that can be reduced together in one shot. void reduce_all_buffers(int start_elements,
// Compared with do N-1 2-reduces which needs 2*(N-1) read and N-1 write,
// N-reduce only needs N read and 1 write, this saves 2/3 memory bandwidth.
// When increase N_REDUCE_LIMIT to a bigger number, do the following steps
// 1. Extend REPEAT_<X> macros list down below
// 2. Extend switch cases which call "REPEAT(X, ...)" down below
#define N_REDUCE_LIMIT 16
void reduce_all_buffers(struct allreduce_workspace** workspace,
int start_elements,
int num_elements, int num_elements,
c10::ScalarType scalar_type, c10::ScalarType scalar_type,
int num_buffers, int to_buffer_idx,
int to_buffer_idx) char* to_buffer,
char** buffers)
{ {
switch (scalar_type) { switch (scalar_type) {
case c10::ScalarType::BFloat16: case c10::ScalarType::BFloat16:
if (num_buffers > 2 && num_buffers <= N_REDUCE_LIMIT) { if (world_size == 2) {
reduce_bf16_buffers( // add the other buffer to to_buffer
start_elements, num_elements, num_buffers, to_buffer_idx, workspace); reduce_2_bf16_buffers_iio(num_elements,
buffers[1 - to_buffer_idx] + start_elements * 2,
to_buffer + start_elements * 2,
to_buffer + start_elements * 2);
} else { } else {
for (int i = 0; i < num_buffers; i++) { reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers);
if (i == to_buffer_idx) continue;
reduce_2_bf16_buffers_iio(
num_elements,
workspace[i]->buffer + start_elements * 2,
workspace[to_buffer_idx]->buffer + start_elements * 2,
workspace[to_buffer_idx]->buffer + start_elements * 2);
}
} }
break; break;
case c10::ScalarType::Float: case c10::ScalarType::Float:
if (num_buffers > 2 && num_buffers <= N_REDUCE_LIMIT) { if (world_size == 2) {
reduce_fp32_buffers( reduce_2_fp32_buffers_iio(num_elements,
start_elements, num_elements, num_buffers, to_buffer_idx, workspace); buffers[1 - to_buffer_idx] + start_elements * 4,
to_buffer + start_elements * 4,
to_buffer + start_elements * 4);
} else { } else {
for (int i = 0; i < num_buffers; i++) { assert(world_size > 2);
if (i == to_buffer_idx) continue; reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers);
reduce_2_fp32_buffers_iio(
num_elements,
workspace[i]->buffer + start_elements * 4,
workspace[to_buffer_idx]->buffer + start_elements * 4,
workspace[to_buffer_idx]->buffer + start_elements * 4);
}
} }
break; break;
default: assert(!"Should not get here"); default: assert(!"Should not get here");
} }
} }
#define REPEAT(N, x) REPEAT_##N(x)
#define REPEAT_1(x) x(1)
#define REPEAT_2(x) \
REPEAT_1(x); \
x(2)
#define REPEAT_3(x) \
REPEAT_2(x); \
x(3)
#define REPEAT_4(x) \
REPEAT_3(x); \
x(4)
#define REPEAT_5(x) \
REPEAT_4(x); \
x(5)
#define REPEAT_6(x) \
REPEAT_5(x); \
x(6)
#define REPEAT_7(x) \
REPEAT_6(x); \
x(7)
#define REPEAT_8(x) \
REPEAT_7(x); \
x(8)
#define REPEAT_9(x) \
REPEAT_8(x); \
x(9)
#define REPEAT_10(x) \
REPEAT_9(x); \
x(10)
#define REPEAT_11(x) \
REPEAT_10(x); \
x(11)
#define REPEAT_12(x) \
REPEAT_11(x); \
x(12)
#define REPEAT_13(x) \
REPEAT_12(x); \
x(13)
#define REPEAT_14(x) \
REPEAT_13(x); \
x(14)
#define REPEAT_15(x) \
REPEAT_14(x); \
x(15)
#define CVT_ADD_BF16(x) \ #define CVT_ADD_BF16(x) \
do { \ do { \
auto in##x##_val = \ auto in##x##_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \
cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(workspace[x]->buffer + i))); \
inout_val = _mm512_add_ps(inout_val, in##x##_val); \ inout_val = _mm512_add_ps(inout_val, in##x##_val); \
} while (0) } while (0)
@ -292,11 +201,7 @@ void reduce_all_buffers(struct allreduce_workspace** workspace,
// whether this number needs to be changed // whether this number needs to be changed
#define VECTOR_LENGTH_IN_BYTES 32 #define VECTOR_LENGTH_IN_BYTES 32
void reduce_bf16_buffers(int start_elements, void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
int num_elements,
int num_buffers,
int to_buffer_idx,
struct allreduce_workspace** workspace)
{ {
const int element_size = 2; const int element_size = 2;
const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
@ -307,34 +212,40 @@ void reduce_bf16_buffers(int start_elements,
#pragma omp parallel for #pragma omp parallel for
for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size; for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
i += VECTOR_LENGTH_IN_BYTES) { i += VECTOR_LENGTH_IN_BYTES) {
auto inout_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(workspace[0]->buffer + i))); auto inout_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i)));
switch (num_buffers) { switch (world_size) {
case 16: REPEAT(15, CVT_ADD_BF16); break; case 16: CVT_ADD_BF16(15);
case 15: REPEAT(14, CVT_ADD_BF16); break; case 15: CVT_ADD_BF16(14);
case 14: REPEAT(13, CVT_ADD_BF16); break; case 14: CVT_ADD_BF16(13);
case 13: REPEAT(12, CVT_ADD_BF16); break; case 13: CVT_ADD_BF16(12);
case 12: REPEAT(11, CVT_ADD_BF16); break; case 12: CVT_ADD_BF16(11);
case 11: REPEAT(10, CVT_ADD_BF16); break; case 11: CVT_ADD_BF16(10);
case 10: REPEAT(9, CVT_ADD_BF16); break; case 10: CVT_ADD_BF16(9);
case 9: REPEAT(8, CVT_ADD_BF16); break; case 9: CVT_ADD_BF16(8);
case 8: REPEAT(7, CVT_ADD_BF16); break; case 8: CVT_ADD_BF16(7);
case 7: REPEAT(6, CVT_ADD_BF16); break; case 7: CVT_ADD_BF16(6);
case 6: REPEAT(5, CVT_ADD_BF16); break; case 6: CVT_ADD_BF16(5);
case 5: REPEAT(4, CVT_ADD_BF16); break; case 5: CVT_ADD_BF16(4);
case 4: REPEAT(3, CVT_ADD_BF16); break; case 4: CVT_ADD_BF16(3);
case 3: REPEAT(2, CVT_ADD_BF16); break; case 3:
default: assert(!"Should not get here."); CVT_ADD_BF16(2);
CVT_ADD_BF16(1);
break;
default:
for (int j = 1; j < world_size; j++) {
auto in_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i)));
inout_val = _mm512_add_ps(inout_val, in_val);
} }
_mm256_storeu_si256((__m256i*)(workspace[to_buffer_idx]->buffer + i), }
cvt_fp32_to_bf16(inout_val)); _mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_bf16(inout_val));
} }
// process remaining part // process remaining part
int i = (start_elements + main_elements) * element_size; int i = (start_elements + main_elements) * element_size;
while (remain_elements > 0) { while (remain_elements > 0) {
float val = 0.0f; float val = 0.0f;
for (int j = 0; j < num_buffers; j++) { val += *(at::BFloat16*)(workspace[j]->buffer + i); } for (int j = 0; j < world_size; j++) { val += *(at::BFloat16*)(buffers[j] + i); }
*(at::BFloat16*)(workspace[to_buffer_idx]->buffer + i) = val; *(at::BFloat16*)(to_buffer + i) = val;
remain_elements--; remain_elements--;
i += element_size; i += element_size;
} }
@ -369,15 +280,11 @@ void reduce_2_bf16_buffers_iio(int num_elements, void* in0, void* in1, void* out
#define CVT_ADD_F32(x) \ #define CVT_ADD_F32(x) \
do { \ do { \
auto in##x##_val = _mm256_loadu_ps((float*)(workspace[x]->buffer + i)); \ auto in##x##_val = _mm256_loadu_ps((float*)(buffers[x] + i)); \
inout_val = _mm256_add_ps(inout_val, in##x##_val); \ inout_val = _mm256_add_ps(inout_val, in##x##_val); \
} while (0) } while (0)
void reduce_fp32_buffers(int start_elements, void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
int num_elements,
int num_buffers,
int to_buffer_idx,
struct allreduce_workspace** workspace)
{ {
const int element_size = 4; const int element_size = 4;
const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
@ -388,33 +295,40 @@ void reduce_fp32_buffers(int start_elements,
#pragma omp parallel for #pragma omp parallel for
for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size; for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
i += VECTOR_LENGTH_IN_BYTES) { i += VECTOR_LENGTH_IN_BYTES) {
auto inout_val = _mm256_loadu_ps((float*)(workspace[0]->buffer + i)); auto inout_val = _mm256_loadu_ps((float*)(buffers[0] + i));
switch (num_buffers) { switch (world_size) {
case 16: REPEAT(15, CVT_ADD_F32); break; case 16: CVT_ADD_F32(15);
case 15: REPEAT(14, CVT_ADD_F32); break; case 15: CVT_ADD_F32(14);
case 14: REPEAT(13, CVT_ADD_F32); break; case 14: CVT_ADD_F32(13);
case 13: REPEAT(12, CVT_ADD_F32); break; case 13: CVT_ADD_F32(12);
case 12: REPEAT(11, CVT_ADD_F32); break; case 12: CVT_ADD_F32(11);
case 11: REPEAT(10, CVT_ADD_F32); break; case 11: CVT_ADD_F32(10);
case 10: REPEAT(9, CVT_ADD_F32); break; case 10: CVT_ADD_F32(9);
case 9: REPEAT(8, CVT_ADD_F32); break; case 9: CVT_ADD_F32(8);
case 8: REPEAT(7, CVT_ADD_F32); break; case 8: CVT_ADD_F32(7);
case 7: REPEAT(6, CVT_ADD_F32); break; case 7: CVT_ADD_F32(6);
case 6: REPEAT(5, CVT_ADD_F32); break; case 6: CVT_ADD_F32(5);
case 5: REPEAT(4, CVT_ADD_F32); break; case 5: CVT_ADD_F32(4);
case 4: REPEAT(3, CVT_ADD_F32); break; case 4: CVT_ADD_F32(3);
case 3: REPEAT(2, CVT_ADD_F32); break; case 3:
default: assert(!"Should not get here."); CVT_ADD_F32(2);
CVT_ADD_F32(1);
break;
default:
for (int j = 1; j < world_size; j++) {
auto in_val = _mm256_loadu_ps((float*)(buffers[j] + i));
inout_val = _mm256_add_ps(inout_val, in_val);
} }
_mm256_storeu_ps((float*)(workspace[to_buffer_idx]->buffer + i), inout_val); }
_mm256_storeu_ps((float*)(to_buffer + i), inout_val);
} }
// process remaining part // process remaining part
int i = (start_elements + main_elements) * element_size; int i = (start_elements + main_elements) * element_size;
while (remain_elements > 0) { while (remain_elements > 0) {
float val = 0.0f; float val = 0.0f;
for (int j = 0; j < num_buffers; j++) { val += *(float*)(workspace[j]->buffer + i); } for (int j = 0; j < world_size; j++) { val += *(float*)(buffers[j] + i); }
*(float*)(workspace[to_buffer_idx]->buffer + i) = val; *(float*)(to_buffer + i) = val;
remain_elements--; remain_elements--;
i += element_size; i += element_size;
} }
@ -448,7 +362,6 @@ void reduce_2_fp32_buffers_iio(int num_elements, void* in0, void* in1, void* out
} }
static bool is_initialized = 0; static bool is_initialized = 0;
static int world_size;
static int world_rank; static int world_rank;
void shm_initialize(int size, int rank, char* addr_string, char* port_string) void shm_initialize(int size, int rank, char* addr_string, char* port_string)
@ -477,10 +390,15 @@ void shm_initialize(int size, int rank, char* addr_string, char* port_string)
snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, rank); snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, rank);
shared_create(&allreduce_buffer, shm_name, workspace_buf, sizeof(struct allreduce_workspace)); shared_create(&allreduce_buffer, shm_name, workspace_buf, sizeof(struct allreduce_workspace));
workspace_buf = (struct allreduce_workspace*)allreduce_buffer.bytes; workspace_buf = (struct allreduce_workspace*)allreduce_buffer.bytes;
workspace_buf->state = coll_begin; workspace_buf->states[0] = coll_alt2_allreduce_naive__copy_in_done;
workspace_buf->states[1] = coll_begin;
// create the workspace pointer list // create the workspace pointer list
workspace = (struct allreduce_workspace**)malloc(size * sizeof(struct allreduce_workspace*)); workspace = (struct allreduce_workspace**)malloc(size * sizeof(struct allreduce_workspace*));
symmetric_buffer[0] = (char**)malloc(size * sizeof(char**));
symmetric_buffer[1] = (char**)malloc(size * sizeof(char**));
distributed_buffer[0] = (char**)malloc(size * sizeof(char**));
distributed_buffer[1] = (char**)malloc(size * sizeof(char**));
// map shm of all ranks // map shm of all ranks
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
@ -494,11 +412,11 @@ void shm_initialize(int size, int rank, char* addr_string, char* port_string)
workspace[i] = workspace_buf_other; workspace[i] = workspace_buf_other;
} else { } else {
workspace[i] = workspace_buf; workspace[i] = workspace_buf;
workspace_buf->counter = 0;
sem_init(&workspace_buf->mutex, 1, 1);
sem_init(&workspace_buf->turnstile1, 1, 0);
sem_init(&workspace_buf->turnstile2, 1, 0);
} }
symmetric_buffer[0][i] = workspace[i]->buffer + BUFFER0_OFFSET(0);
symmetric_buffer[1][i] = workspace[i]->buffer + BUFFER0_OFFSET(1);
distributed_buffer[0][i] = workspace[i]->buffer + BUFFER1_OFFSET(0);
distributed_buffer[1][i] = workspace[i]->buffer + BUFFER1_OFFSET(1);
} }
} }
@ -539,46 +457,122 @@ size_t slice_el_start(size_t chunk_el, int slice_idx)
return slice_size * slice_idx; return slice_size * slice_idx;
} }
void naive_all_reduce(char* data_ptr, /*
Symmetrical naive all_reduce
step 0: before enter the function ith times, state is copy(i-1)
step 1: each rank copy data from input (data_ptr) to SHM buffer[i]
step 2: set own state to copy(i)
step 3: wait each other rank's state equal or later than copy(i)
step 4: reduce across SHM buffer(ith) directly into output (data_ptr)
*/
void symmetric_naive_all_reduce(char* data_ptr,
c10::ScalarType scalar_type, c10::ScalarType scalar_type,
size_t chunk_size, size_t chunk_size,
size_t chunk_el) size_t chunk_el)
{ {
parallel_memcpy(workspace[world_rank]->buffer, data_ptr, chunk_size); #ifdef DO_PROFILE
std::atomic_thread_fence(std::memory_order_release); static double total_t1_t0 = 0.0;
workspace[world_rank]->state = coll_allreduce_naive__copy_in_done; static double total_t2_t1 = 0.0;
static double total_t3_t2 = 0.0;
static int count = -16; // warmup
auto t0 = std::chrono::system_clock::now();
#endif
if (world_rank == 0) { /*
// compute allreduce result on rank 0 We can't have infinite number of buffers and states. 2 sets of buffer
for (int i = 1; i < world_size; i++) { and 3 sets of states is just enough. Consider current rank is in step 3,
with it's own state set to copy(i), the other rank will them have the
following situations:
------------------------------------------------
my state | can I proceed? | the other rank state
================================================
| N | copy(i-1)
|----------------|---------------------
copy(i) | Y | copy(i)
|----------------|---------------------
| Y | copy(i+1)
------------------------------------------------
* When I have state as copy(i), the other rank cannot have state
copy(i-2) or before. In that case I'll be in state copy(i-1) and cannot
proceed to copy(i).
* The other rank cannot have state copy(i+2) or beyond because my
state is still copy(i), copy(i+1) is as far as the other rank could go.
* From a rank's POV, all the other ranks can be divided into three sets:
- Lagging ranks: ranks that are still working on previous iteration
- Syncing ranks: ranks that are working on current iteration
- Leading ranks: ranks that are working on next iteration
* We can have 3 sets of states, one set for syncing ranks; one set for
lagging ranks; one set of leading ranks. With 3 sets of states, we can
distinguish between lagging and leading ranks.
* Note from any rank's POV, leading ranks and lagging ranks does not
appear at the same time. Either all other ranks are syncing or
lagging, or all other ranks are syncing or leading. Otherwise leading
and lagging ranks will be 2 iterations apart and this should not happen.
* So we have 2 sets of buffers, one buffer is used by current iter;
one buffer used by either lagging ranks or leading ranks.
*/
const int state_group = 0;
static int current_buffer = 0;
static int state_idx = 0;
enum coll_state copy_current, copy_next;
switch (state_idx) {
case 0:
copy_current = coll_allreduce_naive__copy_in_done;
copy_next = coll_alt1_allreduce_naive__copy_in_done;
break;
case 1:
copy_current = coll_alt1_allreduce_naive__copy_in_done;
copy_next = coll_alt2_allreduce_naive__copy_in_done;
break;
case 2:
copy_current = coll_alt2_allreduce_naive__copy_in_done;
copy_next = coll_allreduce_naive__copy_in_done;
break;
default: assert(!"Should not get here.");
}
state_idx = (state_idx + 1) % 3;
parallel_memcpy(symmetric_buffer[current_buffer][world_rank], data_ptr, chunk_size);
std::atomic_thread_fence(std::memory_order_release);
workspace[world_rank]->states[state_group] = copy_current;
#ifdef DO_PROFILE
auto t1 = std::chrono::system_clock::now();
#endif
for (int i = 0; i < world_size; i++) {
// wait until the other rank copy the buffer // wait until the other rank copy the buffer
wait_buffer_state_until(i, coll_allreduce_naive__copy_in_done); if (i != world_rank) { wait_buffer_state_until_2(i, copy_current, copy_next, state_group); }
} }
reduce_all_buffers(workspace, 0, chunk_el, scalar_type, world_size, 0); #ifdef DO_PROFILE
std::atomic_thread_fence(std::memory_order_release); auto t2 = std::chrono::system_clock::now();
workspace[world_rank]->state = coll_allreduce_naive__reduce_done; #endif
parallel_memcpy(data_ptr, workspace[0]->buffer, chunk_size);
// each rank reduce the buffer independently so therre is no need for synchronization afterward
reduce_all_buffers(
0, chunk_el, scalar_type, world_rank, data_ptr, symmetric_buffer[current_buffer]);
// switch buffer
current_buffer = 1 - current_buffer;
#ifdef DO_PROFILE
auto t3 = std::chrono::system_clock::now();
count++;
if (count > 0) {
total_t1_t0 += std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count();
total_t2_t1 += std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();
total_t3_t2 += std::chrono::duration_cast<std::chrono::microseconds>(t3 - t2).count();
if (world_rank == 0 && count == 1000) {
printf("symmetric_naive_all_reduce time breakdown:\n");
printf("\tcopy input buffer: %.2f\n", total_t1_t0 / count);
printf("\twait for copy: %.2f\n", total_t2_t1 / count);
printf("\treduce: %.2f\n", total_t3_t2 / count);
} }
if (world_rank != 0) {
wait_buffer_state_until(0, coll_allreduce_naive__reduce_done);
parallel_memcpy(data_ptr, workspace[0]->buffer, chunk_size);
std::atomic_thread_fence(std::memory_order_release);
workspace[world_rank]->state = coll_allreduce_naive__copy_out_done;
}
if (world_rank == 0) {
for (int i = 1; i < world_size; i++) {
wait_buffer_state_until(i, coll_allreduce_naive__copy_out_done);
}
std::atomic_thread_fence(std::memory_order_release);
workspace[world_rank]->state = coll_begin;
}
if (world_rank != 0) {
// if rank 0 spin too fast it could be in state 1 of next allreduce
// in this case wait_buffer_state_until(0, 0) may cause deadlock
// what we are certain is when rank 0 finishes the state won't be 2
wait_buffer_state_until_not(0, coll_allreduce_naive__reduce_done);
workspace[world_rank]->state = coll_begin;
} }
#endif
} }
// naive allreduce distributed, each rank do naive reduce on its slice // naive allreduce distributed, each rank do naive reduce on its slice
@ -597,10 +591,33 @@ void distributed_naive_reduce(char* data_ptr,
auto t0 = std::chrono::system_clock::now(); auto t0 = std::chrono::system_clock::now();
#endif #endif
const int state_group = 1;
static int current_buffer = 0;
static int state_idx = 0;
enum coll_state copy_current, copy_next, reduce_current;
// similar to symmetric_naive_allreduce, but here we only need two sets of
// states, because distributed naive reduce has two barriers in the algorithm
switch (state_idx) {
case 0:
copy_current = coll_allreduce_naive__copy_in_done;
reduce_current = coll_allreduce_naive__reduce_done;
copy_next = coll_alt1_allreduce_naive__copy_in_done;
break;
case 1:
copy_current = coll_alt1_allreduce_naive__copy_in_done;
reduce_current = coll_alt1_allreduce_naive__reduce_done;
copy_next = coll_allreduce_naive__copy_in_done;
break;
default: assert(!"Should not get here.");
}
state_idx = (state_idx + 1) % 2;
int data_size = chunk_size / chunk_el; int data_size = chunk_size / chunk_el;
parallel_memcpy(workspace[world_rank]->buffer, data_ptr, chunk_size); parallel_memcpy(distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size);
std::atomic_thread_fence(std::memory_order_release); std::atomic_thread_fence(std::memory_order_release);
workspace[world_rank]->state = coll_allreduce_naive__copy_in_done; workspace[world_rank]->states[state_group] = copy_current;
#ifdef DO_PROFILE #ifdef DO_PROFILE
auto t1 = std::chrono::system_clock::now(); auto t1 = std::chrono::system_clock::now();
@ -608,7 +625,8 @@ void distributed_naive_reduce(char* data_ptr,
for (int i = 0; i < world_size; i++) { for (int i = 0; i < world_size; i++) {
// wait until all the other ranks copy the buffer // wait until all the other ranks copy the buffer
wait_buffer_state_until_range(i, coll_allreduce_naive__copy_in_done, 2); if (i != world_rank)
wait_buffer_state_until_2(i, copy_current, reduce_current, state_group);
} }
#ifdef DO_PROFILE #ifdef DO_PROFILE
@ -616,40 +634,36 @@ void distributed_naive_reduce(char* data_ptr,
#endif #endif
// reduce scatter // reduce scatter
reduce_all_buffers(workspace, reduce_all_buffers(slice_el_start(chunk_el, world_rank),
slice_el_start(chunk_el, world_rank),
slice_size(chunk_el, world_rank), slice_size(chunk_el, world_rank),
scalar_type, scalar_type,
world_size, world_rank,
world_rank); distributed_buffer[current_buffer][world_rank],
distributed_buffer[current_buffer]);
std::atomic_thread_fence(std::memory_order_release); std::atomic_thread_fence(std::memory_order_release);
workspace[world_rank]->state = coll_allreduce_naive__reduce_done; workspace[world_rank]->states[state_group] = reduce_current;
#ifdef DO_PROFILE #ifdef DO_PROFILE
auto t3 = std::chrono::system_clock::now(); auto t3 = std::chrono::system_clock::now();
#endif #endif
for (int i = 0; i < world_size; i++) { for (int i = 0; i < world_size; i++) {
int rank = (i + world_rank) % world_size; // wait until all the other ranks reduce the buffer
// wait until the other rank reduce the buffer if (i != world_rank) wait_buffer_state_until_2(i, reduce_current, copy_next, state_group);
wait_buffer_state_until_range(rank, coll_allreduce_naive__reduce_done, 2);
parallel_memcpy(slice_data(data_ptr, chunk_el, data_size, rank),
slice_data(workspace[rank]->buffer, chunk_el, chunk_size / chunk_el, rank),
slice_size(chunk_el, rank) * data_size);
} }
std::atomic_thread_fence(std::memory_order_release);
workspace[world_rank]->state = coll_allreduce_naive__copy_out_done;
#ifdef DO_PROFILE
auto t4 = std::chrono::system_clock::now(); auto t4 = std::chrono::system_clock::now();
#endif
for (int i = 0; i < world_size; i++) { for (int i = 0; i < world_size; i++) {
wait_buffer_state_until_not(i, coll_allreduce_naive__reduce_done); int rank = (i + world_rank) % world_size;
parallel_memcpy(
slice_data(data_ptr, chunk_el, data_size, rank),
slice_data(
distributed_buffer[current_buffer][rank], chunk_el, chunk_size / chunk_el, rank),
slice_size(chunk_el, rank) * data_size);
} }
std::atomic_thread_fence(std::memory_order_release); current_buffer = 1 - current_buffer;
workspace[world_rank]->state = coll_begin;
#ifdef DO_PROFILE #ifdef DO_PROFILE
auto t5 = std::chrono::system_clock::now(); auto t5 = std::chrono::system_clock::now();
@ -665,8 +679,8 @@ void distributed_naive_reduce(char* data_ptr,
printf("\tcopy input buffer: %.2f\n", total_t1_t0 / count); printf("\tcopy input buffer: %.2f\n", total_t1_t0 / count);
printf("\twait for copy: %.2f\n", total_t2_t1 / count); printf("\twait for copy: %.2f\n", total_t2_t1 / count);
printf("\treduce: %.2f\n", total_t3_t2 / count); printf("\treduce: %.2f\n", total_t3_t2 / count);
printf("\tcopy buffer to output: %.2f\n", total_t4_t3 / count); printf("\twait for reduce finish: %.2f\n", total_t4_t3 / count);
printf("\twait finish: %.2f\n", total_t5_t4 / count); printf("\tcopy out: %.2f\n", total_t5_t4 / count);
} }
} }
#endif #endif
@ -679,7 +693,7 @@ void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size)
size_t chunk_size = data_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : data_size - offset; size_t chunk_size = data_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : data_size - offset;
size_t chunk_el = chunk_size / (data_size / numel); size_t chunk_el = chunk_size / (data_size / numel);
if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD) if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD)
naive_all_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el); symmetric_naive_all_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el);
else else
distributed_naive_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el); distributed_naive_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el);
} }