clear logic for raw send/recv
Parent: c81bd1c176
Commit: 05b80d28ea
@@ -44,12 +44,13 @@ public:
   // 3. < 0 net error
   virtual size_t Recv(MessagePtr* msg) = 0;
 
-  // Non-blocking send raw data to rank
-  virtual void SendTo(int rank, const char* buf, int len) = 0;
-  // wait for last SendTo success
-  virtual bool WaitLastSend() = 0;
-  // Blocking receive raw data from rank
-  virtual void RecvFrom(int rank, char* buf, int len) = 0;
+  // Blocking, send raw data to rank
+  virtual void SendTo(int rank, const char* buf, int len) const = 0;
+  // Blocking, receive raw data from rank
+  virtual void RecvFrom(int rank, char* buf, int len) const = 0;
+  // Blocking, send and recv at same time
+  virtual void SendRecv(int send_rank, const char* send_buf, int send_len,
+                        int recv_rank, char* recv_buf, int recv_len) const = 0;
 
   virtual int thread_level_support() = 0;
 };
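
For context, the reworked interface reduces raw transfers to three blocking primitives: SendTo, RecvFrom, and the combined SendRecv. Below is a minimal usage sketch (not part of the commit) of a ring exchange built on them; the net pointer and buffers are hypothetical stand-ins for however a concrete wrapper (MPI or ZMQ) is obtained.

// Usage sketch only, assuming a connected NetInterface-derived wrapper.
// Each rank sends a block to its right neighbor while receiving from its
// left neighbor; the combined SendRecv sidesteps the ordering problem where
// every rank blocks in SendTo at the same time.
void RingExchange(const multiverso::NetInterface* net,
                  const char* send_buf, char* recv_buf, int len) {
  int rank = net->rank();                // rank()/size() as used by AllreduceEngine::Init
  int size = net->size();
  int right = (rank + 1) % size;         // send partner
  int left = (rank - 1 + size) % size;   // recv partner
  net->SendRecv(right, send_buf, len, left, recv_buf, len);
}
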
@@ -86,7 +86,7 @@ public:
   * \brief Initial
   * \param linkers, the low-level communication methods
   */
-  void Init(NetInterface* linkers);
+  void Init(const NetInterface* linkers);
 
   ~AllreduceEngine();
   /*! \brief Get rank of this machine */
@@ -152,7 +152,7 @@ private:
   /*! \brief Rank of local machine */
   int rank_;
   /*! \brief The network interface, provide send/recv functions */
-  NetInterface* linkers_;
+  const NetInterface* linkers_;
   /*! \brief Bruck map for all gather algorithm*/
   BruckMap bruck_map_;
   /*! \brief Recursive halving map for reduce scatter */
@@ -202,20 +202,17 @@ public:
     return RecvAndDeserialize(status.MPI_SOURCE, count, msg);
   }
 
-  void SendTo(int rank, const char* buf, int len) override {
+  void SendTo(int rank, const char* buf, int len) const override {
     if (len <= 0) {
       return;
     }
-    MV_MPI_CALL(MPI_Isend(buf, len, MPI_BYTE, rank, MPI_ANY_TAG, MPI_COMM_WORLD, &last_send_request_));
-  }
-
-  bool WaitLastSend() override {
+    MPI_Request send_request;
     MPI_Status status;
-    MV_MPI_CALL(MPI_Wait(&last_send_request_, &status));
-    return true;
+    MV_MPI_CALL(MPI_Isend(buf, len, MPI_BYTE, rank, MPI_ANY_TAG, MPI_COMM_WORLD, &send_request));
+    MV_MPI_CALL(MPI_Wait(&send_request, &status));
   }
 
-  void RecvFrom(int rank, char* buf, int len) override {
+  void RecvFrom(int rank, char* buf, int len) const override {
     MPI_Status status;
     int read_cnt = 0;
     while (read_cnt < len) {
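
SendTo is now blocking end to end: it issues MPI_Isend and immediately waits on the request, which is why WaitLastSend and the cached last_send_request_ member disappear. The standalone sketch below shows the same Isend-then-Wait idiom with plain MPI calls; the literal tag 0 and the bare calls without MV_MPI_CALL are assumptions for illustration only.

#include <mpi.h>

// Sketch: a blocking send expressed as non-blocking Isend followed by Wait,
// mirroring the pattern used inside the updated MPI SendTo.
void BlockingSendSketch(int dest_rank, const char* buf, int len) {
  if (len <= 0) return;
  MPI_Request request;
  MPI_Status status;
  MPI_Isend(buf, len, MPI_BYTE, dest_rank, /*tag=*/0, MPI_COMM_WORLD, &request);
  MPI_Wait(&request, &status);  // equivalent in effect to a plain MPI_Send
}
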
@@ -226,6 +223,24 @@ public:
     }
   }
 
+  void SendRecv(int send_rank, const char* send_data, int send_len,
+                int recv_rank, char* recv_data, int recv_len) const {
+    MPI_Request send_request;
+    // send first, non-blocking
+    MV_MPI_CALL(MPI_Isend(send_data, send_len, MPI_BYTE, send_rank, MPI_ANY_TAG, MPI_COMM_WORLD, &send_request));
+    // then receive, blocking
+    MPI_Status status;
+    int read_cnt = 0;
+    while (read_cnt < recv_len) {
+      MV_MPI_CALL(MPI_Recv(recv_data + read_cnt, recv_len - read_cnt, MPI_BYTE, recv_rank, MPI_ANY_TAG, MPI_COMM_WORLD, &status));
+      int cur_cnt;
+      MV_MPI_CALL(MPI_Get_count(&status, MPI_BYTE, &cur_cnt));
+      read_cnt += cur_cnt;
+    }
+    // wait for send complete
+    MV_MPI_CALL(MPI_Wait(&send_request, &status));
+  }
+
   size_t SerializeAndSend(MessagePtr& msg, MPIMsgHandle* msg_handle) {
 
     CHECK_NOTNULL(msg_handle);
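
The combined call posts the send as non-blocking, drains the incoming bytes with a blocking receive loop (MPI_Get_count handles short receives), and only then waits on the send request, so two ranks can exchange data with each other without deadlocking. For comparison, MPI also offers a built-in combined primitive; the sketch below is illustrative only, is not what the commit uses, and assumes both sides agree on the exact message lengths so no receive loop is needed.

#include <mpi.h>

// Illustration: MPI's own combined send/receive with the same
// "send to one rank while receiving from another" semantics.
void SendRecvSketch(int send_rank, const char* send_buf, int send_len,
                    int recv_rank, char* recv_buf, int recv_len) {
  MPI_Status status;
  MPI_Sendrecv(send_buf, send_len, MPI_BYTE, send_rank, /*sendtag=*/0,
               recv_buf, recv_len, MPI_BYTE, recv_rank, /*recvtag=*/0,
               MPI_COMM_WORLD, &status);
}
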
@@ -361,7 +376,6 @@ private:
   size_t send_size_;
   char* recv_buffer_;
   size_t recv_size_;
-  MPI_Request last_send_request_;
 };
 
 }
@@ -6,6 +6,7 @@
 #include "multiverso/net.h"
 
 #include <limits>
+#include <thread>
 
 #include "multiverso/message.h"
 #include "multiverso/util/log.h"
@@ -169,7 +170,7 @@ public:
   }
 
 
-  void SendTo(int rank, const char* buf, int len) override {
+  void SendTo(int rank, const char* buf, int len) const override {
     int send_size = 0;
     while (send_size < len) {
       int cur_size = zmq_send(senders_[rank], buf + send_size, len - send_size, 0);
@@ -178,12 +179,7 @@ public:
     }
   }
 
-  bool WaitLastSend() override {
-    // not need to wait in ZMQ
-    return true;
-  }
-
-  void RecvFrom(int, char* buf, int len) override {
+  void RecvFrom(int, char* buf, int len) const override {
     // note: rank is not used here
     int recv_size = 0;
     while (recv_size < len) {
@@ -193,6 +189,20 @@ public:
     }
   }
 
+  void SendRecv(int send_rank, const char* send_buf, int send_len,
+                int recv_rank, char* recv_buf, int recv_len) const override {
+    // send first
+    std::thread send_worker(
+      [this, send_rank, send_buf, send_len] {
+        SendTo(send_rank, send_buf, send_len);
+      }
+    );
+    // then recv
+    RecvFrom(recv_rank, recv_buf, recv_len);
+    // wait for send complete
+    send_worker.join();
+  }
+
   int thread_level_support() override {
     return NetThreadLevel::THREAD_MULTIPLE;
   }
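
The ZMQ sockets are used with blocking calls here, so the combined SendRecv overlaps the two directions by running the send on a short-lived thread while the calling thread receives, then joining; this is consistent with the wrapper reporting NetThreadLevel::THREAD_MULTIPLE. A minimal sketch of that overlap pattern in isolation, with generic callables standing in for SendTo and RecvFrom:

#include <functional>
#include <thread>

// Sketch: run a blocking send on a helper thread, perform the blocking
// receive on the calling thread, then wait for the sender to finish.
void OverlappedSendRecv(const std::function<void()>& blocking_send,
                        const std::function<void()>& blocking_recv) {
  std::thread send_worker(blocking_send);  // send proceeds concurrently
  blocking_recv();                         // receive on this thread
  send_worker.join();                      // wait for send completion
}
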
@@ -2,7 +2,6 @@
 #include <algorithm>
 
 #include "multiverso/net/allreduce_engine.h"
-#include "multiverso/net/net_allreduce.h"
 
 namespace multiverso {
 
@@ -11,7 +10,7 @@ AllreduceEngine::AllreduceEngine()
 
 }
 
-void AllreduceEngine::Init(NetInterface* linkers) {
+void AllreduceEngine::Init(const NetInterface* linkers) {
   linkers_ = linkers;
   rank_ = linkers_->rank();
   num_machines_ = linkers_->size();
@@ -93,21 +92,20 @@ void AllreduceEngine::Allgather(char* input, int all_size, int* block_start, int
   write_ptr += block_len[rank_];
   int accumulated_block = 1;
   for (int i = 0; i < bruck_map_.k; ++i) {
-    //send
     int cur_block_size = (1 << i) < num_machines_ - accumulated_block ? (1 << i) : num_machines_ - accumulated_block;
     int target = bruck_map_.out_ranks[i];
     int send_len = 0;
     for (int j = 0; j < cur_block_size; ++j) {
       send_len += block_len[(rank_ + j) % num_machines_];
     }
-    linkers_->SendTo(target, output, send_len);
-    //rec
+
     int incoming = bruck_map_.in_ranks[i];
     int need_recv_cnt = 0;
     for (int j = 0; j < cur_block_size; ++j) {
       need_recv_cnt += block_len[(rank_ + accumulated_block + j) % num_machines_];
     }
-    linkers_->RecvFrom(incoming, output + write_ptr, need_recv_cnt);
+
+    linkers_->SendRecv(target, output, send_len, incoming, output + write_ptr, need_recv_cnt);
     write_ptr += need_recv_cnt;
     accumulated_block += cur_block_size;
   }
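
Each Bruck round now pairs the outgoing transfer to out_ranks[i] with the incoming transfer from in_ranks[i] in a single SendRecv. The sketch below only illustrates how the per-round block counts grow (doubling until all num_machines_ blocks are gathered); the partner formulas are a common Bruck formulation and are assumptions here, since the real mapping comes from BruckMap.

#include <cstdio>

// Illustration of the round structure for a hypothetical n = 8, rank = 3.
// Partner formulas (rank +/- 2^i) are assumed; multiverso takes them from BruckMap.
int main() {
  const int n = 8, rank = 3;
  int accumulated = 1;
  for (int i = 0; (1 << i) < n; ++i) {
    int step = 1 << i;
    int blocks = step < n - accumulated ? step : n - accumulated;
    int target = (rank + step) % n;        // stand-in for bruck_map_.out_ranks[i]
    int incoming = (rank - step + n) % n;  // stand-in for bruck_map_.in_ranks[i]
    std::printf("round %d: send %d block(s) to %d, recv %d block(s) from %d\n",
                i, blocks, target, blocks, incoming);
    accumulated += blocks;
  }
  return 0;
}
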
@@ -139,18 +137,17 @@ void AllreduceEngine::ReduceScatter(char* input, int input_size, int type_size,
     int target = recursive_halving_map_.ranks[i];
     int send_block_start = recursive_halving_map_.send_block_start[i];
     int recv_block_start = recursive_halving_map_.recv_block_start[i];
-    //send
     int send_size = 0;
     for (int j = 0; j < recursive_halving_map_.send_block_len[i]; ++j) {
       send_size += block_len[send_block_start + j];
     }
-    linkers_->SendTo(target, input + block_start[send_block_start], send_size);
-    //receive
+
    int need_recv_cnt = 0;
     for (int j = 0; j < recursive_halving_map_.recv_block_len[i]; ++j) {
       need_recv_cnt += block_len[recv_block_start + j];
     }
-    linkers_->RecvFrom(target, output, need_recv_cnt);
+
+    linkers_->SendRecv(target, input + block_start[send_block_start], send_size, target, output, need_recv_cnt);
     //reduce
     reducer(output, input + block_start[recv_block_start], need_recv_cnt);
   }
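
In the recursive-halving step the peer for the outgoing and the incoming halves is the same rank, so both rank arguments of SendRecv are target, and the received bytes are then folded into the matching local blocks by reducer. A hedged sketch of what such a reducer could look like for int32 sums is below; the (src, dst, len) argument order mirrors the reducer(output, input + block_start[recv_block_start], need_recv_cnt) call above but is an assumption, since the reducer's signature is outside this diff.

#include <cstdint>
#include <cstring>

// Hypothetical reducer: element-wise int32 sum over raw byte ranges.
// Argument order (received bytes, local block, byte count) is assumed.
void SumInt32(const char* src, char* dst, int len) {
  const int n = len / static_cast<int>(sizeof(int32_t));
  for (int i = 0; i < n; ++i) {
    int32_t a, b;
    std::memcpy(&a, src + i * sizeof(int32_t), sizeof(int32_t));
    std::memcpy(&b, dst + i * sizeof(int32_t), sizeof(int32_t));
    b += a;
    std::memcpy(dst + i * sizeof(int32_t), &b, sizeof(int32_t));
  }
}
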