[CCLBackend] Using parallel memcpy for inference_all_reduce (#4404)

* use a parallel (OpenMP + vectorized) version of memcpy

* limit max buf size to 16MB per rank

* support any input buffer size by processing the tensor in chunks (see the sketch after this list)

* fix format error
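
The chunking behind the last two bullets can be illustrated in isolation. Below is a minimal standalone sketch, not part of the patch: MAX_BUF_SIZE, data_size, numel, chunk_size and chunk_el mirror the names used in the diff further down, while the 16MB constant and the float32 tensor in main() are illustrative assumptions.

// Standalone sketch of the chunking scheme: an arbitrarily large tensor is
// processed in pieces of at most MAX_BUF_SIZE bytes, so the shared per-rank
// workspace never has to grow with the input size.
#include <cstddef>
#include <cstdio>

constexpr size_t MAX_BUF_SIZE = 16 * 1024 * 1024;  // 16MB per rank (assumed value)

int main()
{
    // Illustrative input: a float32 tensor with 10M elements (40MB) -> 3 chunks.
    size_t numel = 10 * 1000 * 1000;
    size_t elem_size = 4;                  // sizeof(float), i.e. data_size / numel
    size_t data_size = numel * elem_size;  // total bytes

    for (size_t offset = 0; offset < data_size; offset += MAX_BUF_SIZE) {
        size_t chunk_size = data_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : data_size - offset;
        size_t chunk_el = chunk_size / elem_size;  // elements reduced in this pass
        printf("chunk at offset %zu: %zu bytes, %zu elements\n", offset, chunk_size, chunk_el);
    }
    return 0;
}

Because each pass of the real loop reduces only chunk_el elements, the old fallback on data_size > MAX_BUF_SIZE is no longer needed, which is exactly what the second hunk below removes.
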
Ma, Guokai 2023-10-03 21:38:20 +08:00 committed by GitHub
Parent 1760627eb9
Commit 9a55291452
1 changed file with 48 additions and 34 deletions

@@ -499,6 +499,17 @@ void all_reduce_caching(torch::Tensor& data,
                  .wait());
 }
 
+static void parallel_memcpy(void* to, void* from, size_t n_bytes)
+    __attribute__((target("avx512bw")));
+static void parallel_memcpy(void* to, void* from, size_t n_bytes)
+{
+#pragma omp parallel for
+    for (int i = 0; i < n_bytes; i += VECTOR_LENGTH_IN_BYTES) {
+        auto val = _mm256_loadu_si256((__m256i*)((char*)from + i));
+        _mm256_storeu_si256((__m256i*)((char*)to + i), val);
+    }
+}
+
 void inference_all_reduce(torch::Tensor& data, py::object op, py::object group, bool async_op)
 {
     static py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp");
@@ -517,8 +528,7 @@ void inference_all_reduce(torch::Tensor& data, py::object op, py::object group,
         default: data_type_fallback = true;
     }
 
-    if (data_size > MAX_BUF_SIZE || data_type_fallback ||
-        (data_size % VECTOR_LENGTH_IN_BYTES) != 0 || !all_ranks_local_p) {
+    if (data_type_fallback || (data_size % VECTOR_LENGTH_IN_BYTES) != 0 || !all_ranks_local_p) {
         // fallback to oneccl allreduce
         CCLCHECK(ccl::allreduce(data.data_ptr(),
                                 data.data_ptr(),
@@ -530,42 +540,46 @@ void inference_all_reduce(torch::Tensor& data, py::object op, py::object group,
         return;
     }
 
-    auto data_ptr = data.data_ptr();
+    for (int offset = 0; offset < data_size; offset += MAX_BUF_SIZE) {
+        auto data_ptr = ((char*)(data.data_ptr()) + offset);
+        size_t chunk_size = data_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : data_size - offset;
+        size_t chunk_el = chunk_size / (data_size / numel);
 
-    memcpy(workspace[world_rank].buffer, data_ptr, data_size);
-    std::atomic_thread_fence(std::memory_order_release);
-    workspace[world_rank].state = coll_allreduce_naive__copy_in_done;
+        parallel_memcpy(workspace[world_rank].buffer, data_ptr, chunk_size);
+        std::atomic_thread_fence(std::memory_order_release);
+        workspace[world_rank].state = coll_allreduce_naive__copy_in_done;
 
-    if (world_rank == 0) {
-        // compute allreduce result on rank 0
-        for (int i = 1; i < world_size; i++) {
-            // wait until the other rank copy the buffer
-            wait_buffer_state_until(i, coll_allreduce_naive__copy_in_done);
-        }
-        reduce_all_buffers(workspace, numel, data.scalar_type(), world_size);
-        std::atomic_thread_fence(std::memory_order_release);
-        workspace[world_rank].state = coll_allreduce_naive__reduce_done;
-        memcpy(data_ptr, workspace[0].buffer, data_size);
-    }
-    if (world_rank != 0) {
-        wait_buffer_state_until(0, coll_allreduce_naive__reduce_done);
-        memcpy(data_ptr, workspace[0].buffer, data_size);
-        std::atomic_thread_fence(std::memory_order_release);
-        workspace[world_rank].state = coll_allreduce_naive__copy_out_done;
-    }
-    if (world_rank == 0) {
-        for (int i = 1; i < world_size; i++) {
-            wait_buffer_state_until(i, coll_allreduce_naive__copy_out_done);
-        }
-        std::atomic_thread_fence(std::memory_order_release);
-        workspace[world_rank].state = coll_begin;
-    }
-    if (world_rank != 0) {
-        // if rank 0 spin too fast it could be in state 1 of next allreduce
-        // in this case wait_buffer_state_until(0, 0) may cause deadlock
-        // what we are certain is when rank 0 finishes the state won't be 2
-        wait_buffer_state_until_not(0, coll_allreduce_naive__reduce_done);
-        workspace[world_rank].state = coll_begin;
-    }
+        if (world_rank == 0) {
+            // compute allreduce result on rank 0
+            for (int i = 1; i < world_size; i++) {
+                // wait until the other rank copy the buffer
+                wait_buffer_state_until(i, coll_allreduce_naive__copy_in_done);
+            }
+            reduce_all_buffers(workspace, chunk_el, data.scalar_type(), world_size);
+            std::atomic_thread_fence(std::memory_order_release);
+            workspace[world_rank].state = coll_allreduce_naive__reduce_done;
+            parallel_memcpy(data_ptr, workspace[0].buffer, chunk_size);
+        }
+        if (world_rank != 0) {
+            wait_buffer_state_until(0, coll_allreduce_naive__reduce_done);
+            parallel_memcpy(data_ptr, workspace[0].buffer, chunk_size);
+            std::atomic_thread_fence(std::memory_order_release);
+            workspace[world_rank].state = coll_allreduce_naive__copy_out_done;
+        }
+        if (world_rank == 0) {
+            for (int i = 1; i < world_size; i++) {
+                wait_buffer_state_until(i, coll_allreduce_naive__copy_out_done);
+            }
+            std::atomic_thread_fence(std::memory_order_release);
+            workspace[world_rank].state = coll_begin;
+        }
+        if (world_rank != 0) {
+            // if rank 0 spin too fast it could be in state 1 of next allreduce
+            // in this case wait_buffer_state_until(0, 0) may cause deadlock
+            // what we are certain is when rank 0 finishes the state won't be 2
+            wait_buffer_state_until_not(0, coll_allreduce_naive__reduce_done);
+            workspace[world_rank].state = coll_begin;
+        }
+    }
 }
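
For readers who want to try the copy routine outside of DeepSpeed, here is a hypothetical test harness, not part of this commit, that checks an OpenMP + AVX2 copy of the same shape as parallel_memcpy above against std::memcpy. VECTOR_LENGTH_IN_BYTES is assumed to be 32, the width of one __m256i register, and the build line (g++ -O2 -mavx2 -fopenmp test_parallel_memcpy.cpp) is likewise an assumption.

// Hypothetical harness: verify an OpenMP + AVX2 copy against std::memcpy.
#include <immintrin.h>
#include <cstring>
#include <cstdio>
#include <vector>

constexpr size_t VECTOR_LENGTH_IN_BYTES = 32;  // assumed: width of one __m256i

static void parallel_memcpy(void* to, void* from, size_t n_bytes)
{
    // n_bytes must be a multiple of VECTOR_LENGTH_IN_BYTES; inference_all_reduce
    // guarantees this via its fallback check on data_size before calling the copy.
#pragma omp parallel for
    for (long i = 0; i < (long)n_bytes; i += VECTOR_LENGTH_IN_BYTES) {
        auto val = _mm256_loadu_si256((__m256i*)((char*)from + i));
        _mm256_storeu_si256((__m256i*)((char*)to + i), val);
    }
}

int main()
{
    const size_t n = 16 * 1024 * 1024;  // 16MB, a multiple of 32
    std::vector<char> src(n), dst_ref(n), dst_par(n);
    for (size_t i = 0; i < n; i++) src[i] = (char)(i * 31 + 7);

    std::memcpy(dst_ref.data(), src.data(), n);
    parallel_memcpy(dst_par.data(), src.data(), n);

    printf("parallel copy %s\n",
           std::memcmp(dst_ref.data(), dst_par.data(), n) == 0 ? "matches memcpy" : "MISMATCH");
    return 0;
}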