Guolin Ke 2016-04-22 10:49:18 +08:00
Parent c81bd1c176
Commit 05b80d28ea
5 changed files with 56 additions and 34 deletions

View file

@@ -44,12 +44,13 @@ public:
// 3. < 0 net error
virtual size_t Recv(MessagePtr* msg) = 0;
// Non-blocking send raw data to rank
virtual void SendTo(int rank, const char* buf, int len) = 0;
// wait for last SendTo success
virtual bool WaitLastSend() = 0;
// Blocking receive raw data from rank
virtual void RecvFrom(int rank, char* buf, int len) = 0;
// Blocking, send raw data to rank
virtual void SendTo(int rank, const char* buf, int len) const = 0;
// Blocking, receive raw data from rank
virtual void RecvFrom(int rank, char* buf, int len) const = 0;
// Blocking, send and recv at same time
virtual void SendRecv(int send_rank, const char* send_buf, int send_len,
int recv_rank, char* recv_buf, int recv_len) const = 0;
virtual int thread_level_support() = 0;
};
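
For orientation, here is a minimal sketch of how a caller might drive the reworked blocking interface. Only SendRecv, rank() and size() are taken from this commit; the ring-neighbor exchange itself is purely illustrative.

#include "multiverso/net.h"

// Exchange a buffer with both ring neighbours in one step. Pairing the send
// and the receive in a single SendRecv call avoids the deadlock that every
// rank doing a blocking SendTo before its RecvFrom could run into.
void RingExchange(const multiverso::NetInterface* net,
                  const char* send_buf, char* recv_buf, int len) {
  int next = (net->rank() + 1) % net->size();
  int prev = (net->rank() + net->size() - 1) % net->size();
  net->SendRecv(next, send_buf, len, prev, recv_buf, len);
}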

View file

@@ -86,7 +86,7 @@ public:
* \brief Initialize
* \param linkers, the low-level communication methods
*/
void Init(NetInterface* linkers);
void Init(const NetInterface* linkers);
~AllreduceEngine();
/*! \brief Get rank of this machine */
@@ -152,7 +152,7 @@ private:
/*! \brief Rank of local machine */
int rank_;
/*! \brief The network interface, provide send/recv functions */
NetInterface* linkers_;
const NetInterface* linkers_;
/*! \brief Bruck map for all gather algorithm*/
BruckMap bruck_map_;
/*! \brief Recursive halving map for reduce scatter */

View file

@@ -202,20 +202,17 @@ public:
return RecvAndDeserialize(status.MPI_SOURCE, count, msg);
}
void SendTo(int rank, const char* buf, int len) override {
void SendTo(int rank, const char* buf, int len) const override {
if (len <= 0) {
return;
}
MV_MPI_CALL(MPI_Isend(buf, len, MPI_BYTE, rank, MPI_ANY_TAG, MPI_COMM_WORLD, &last_send_request_));
}
bool WaitLastSend() override {
MPI_Request send_request;
MPI_Status status;
MV_MPI_CALL(MPI_Wait(&last_send_request_, &status));
return true;
MV_MPI_CALL(MPI_Isend(buf, len, MPI_BYTE, rank, MPI_ANY_TAG, MPI_COMM_WORLD, &send_request));
MV_MPI_CALL(MPI_Wait(&send_request, &status));
}
void RecvFrom(int rank, char* buf, int len) override {
void RecvFrom(int rank, char* buf, int len) const override {
MPI_Status status;
int read_cnt = 0;
while (read_cnt < len) {
@@ -226,6 +223,24 @@ public:
}
}
void SendRecv(int send_rank, const char* send_data, int send_len,
int recv_rank, char* recv_data, int recv_len) const {
MPI_Request send_request;
// send first, non-blocking
MV_MPI_CALL(MPI_Isend(send_data, send_len, MPI_BYTE, send_rank, MPI_ANY_TAG, MPI_COMM_WORLD, &send_request));
// then receive, blocking
MPI_Status status;
int read_cnt = 0;
while (read_cnt < recv_len) {
MV_MPI_CALL(MPI_Recv(recv_data + read_cnt, recv_len - read_cnt, MPI_BYTE, recv_rank, MPI_ANY_TAG, MPI_COMM_WORLD, &status));
int cur_cnt;
MV_MPI_CALL(MPI_Get_count(&status, MPI_BYTE, &cur_cnt));
read_cnt += cur_cnt;
}
// wait for send complete
MV_MPI_CALL(MPI_Wait(&send_request, &status));
}
size_t SerializeAndSend(MessagePtr& msg, MPIMsgHandle* msg_handle) {
CHECK_NOTNULL(msg_handle);
@@ -361,7 +376,6 @@ private:
size_t send_size_;
char* recv_buffer_;
size_t recv_size_;
MPI_Request last_send_request_;
};
}
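
The SendRecv handler above follows a fixed ordering: post the send non-blocking, complete the blocking receive, then wait on the send. The stand-alone sketch below replays that ordering in plain MPI (tag 0 instead of MPI_ANY_TAG, a paired exchange between even/odd ranks); it only illustrates the pattern and is not part of the library.

// sendrecv_sketch.cpp -- compile with mpicxx, run with an even number of ranks.
#include <mpi.h>
#include <cstdio>

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);
  int rank, size;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);
  int peer = rank ^ 1;  // pair ranks 0<->1, 2<->3, ...
  if (peer < size) {
    int send_val = rank, recv_val = -1;
    MPI_Request send_request;
    MPI_Status status;
    // 1. non-blocking send, so both peers can go on to post their receive
    MPI_Isend(&send_val, 1, MPI_INT, peer, 0, MPI_COMM_WORLD, &send_request);
    // 2. blocking receive from the peer
    MPI_Recv(&recv_val, 1, MPI_INT, peer, 0, MPI_COMM_WORLD, &status);
    // 3. only now wait for the local send to complete
    MPI_Wait(&send_request, &status);
    std::printf("rank %d received %d from rank %d\n", rank, recv_val, peer);
  }
  MPI_Finalize();
  return 0;
}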

View file

@@ -6,6 +6,7 @@
#include "multiverso/net.h"
#include <limits>
#include <thread>
#include "multiverso/message.h"
#include "multiverso/util/log.h"
@@ -169,7 +170,7 @@ public:
}
void SendTo(int rank, const char* buf, int len) override {
void SendTo(int rank, const char* buf, int len) const override {
int send_size = 0;
while (send_size < len) {
int cur_size = zmq_send(senders_[rank], buf + send_size, len - send_size, 0);
@@ -178,12 +179,7 @@ public:
}
}
bool WaitLastSend() override {
// not need to wait in ZMQ
return true;
}
void RecvFrom(int, char* buf, int len) override {
void RecvFrom(int, char* buf, int len) const override {
// note: rank is not used here
int recv_size = 0;
while (recv_size < len) {
@@ -193,6 +189,20 @@ public:
}
}
void SendRecv(int send_rank, const char* send_buf, int send_len,
int recv_rank, char* recv_buf, int recv_len) const override {
// send first
std::thread send_worker(
[this, send_rank, send_buf, send_len] {
SendTo(send_rank, send_buf, send_len);
}
);
// then recv
RecvFrom(recv_rank, recv_buf, recv_len);
// wait for send complete
send_worker.join();
}
int thread_level_support() override {
return NetThreadLevel::THREAD_MULTIPLE;
}
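
The ZMQ variant gets the same overlap with a helper thread instead of a non-blocking send: the send runs on a worker thread while the calling thread blocks in the receive, and the join plays the role of the MPI_Wait. Below is a transport-agnostic sketch of that pattern; the lambdas stand in for the real SendTo/RecvFrom calls, and it assumes the transport tolerates concurrent use from two threads, which is what THREAD_MULTIPLE above advertises.

#include <functional>
#include <thread>

void SendRecvOverlap(const std::function<void()>& blocking_send,
                     const std::function<void()>& blocking_recv) {
  std::thread send_worker(blocking_send);  // the send proceeds concurrently
  blocking_recv();                         // receive on the calling thread
  send_worker.join();                      // wait for the send to finish
}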

View file

@@ -2,7 +2,6 @@
#include <algorithm>
#include "multiverso/net/allreduce_engine.h"
#include "multiverso/net/net_allreduce.h"
namespace multiverso {
@@ -11,7 +10,7 @@ AllreduceEngine::AllreduceEngine()
}
void AllreduceEngine::Init(NetInterface* linkers) {
void AllreduceEngine::Init(const NetInterface* linkers) {
linkers_ = linkers;
rank_ = linkers_->rank();
num_machines_ = linkers_->size();
@@ -93,21 +92,20 @@ void AllreduceEngine::Allgather(char* input, int all_size, int* block_start, int
write_ptr += block_len[rank_];
int accumulated_block = 1;
for (int i = 0; i < bruck_map_.k; ++i) {
//send
int cur_block_size = (1 << i) < num_machines_ - accumulated_block ? (1 << i) : num_machines_ - accumulated_block;
int target = bruck_map_.out_ranks[i];
int send_len = 0;
for (int j = 0; j < cur_block_size; ++j) {
send_len += block_len[(rank_ + j) % num_machines_];
}
linkers_->SendTo(target, output, send_len);
//recv
int incoming = bruck_map_.in_ranks[i];
int need_recv_cnt = 0;
for (int j = 0; j < cur_block_size; ++j) {
need_recv_cnt += block_len[(rank_ + accumulated_block + j) % num_machines_];
}
linkers_->RecvFrom(incoming, output + write_ptr, need_recv_cnt);
linkers_->SendRecv(target, output, send_len, incoming, output + write_ptr, need_recv_cnt);
write_ptr += need_recv_cnt;
accumulated_block += cur_block_size;
}
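
A worked example may help when reading this loop: for five machines the Bruck exchange runs ceil(log2(5)) = 3 rounds (assuming that is how bruck_map_.k is built, which this diff does not show), doubling the number of held blocks until only the remainder is left. The sketch below just replays the cur_block_size / accumulated_block arithmetic.

#include <cstdio>

int main() {
  const int num_machines = 5;
  const int k = 3;            // assumed: ceil(log2(num_machines)) rounds
  int accumulated_block = 1;  // every rank starts with its own block
  for (int i = 0; i < k; ++i) {
    int cur_block_size = (1 << i) < num_machines - accumulated_block
                             ? (1 << i) : num_machines - accumulated_block;
    std::printf("round %d: exchange %d block(s), hold %d afterwards\n",
                i, cur_block_size, accumulated_block + cur_block_size);
    accumulated_block += cur_block_size;
  }
  // Prints 1, 2 and 1 exchanged blocks: all 5 blocks are held after 3 rounds.
  return 0;
}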
@@ -139,18 +137,17 @@ void AllreduceEngine::ReduceScatter(char* input, int input_size, int type_size,
int target = recursive_halving_map_.ranks[i];
int send_block_start = recursive_halving_map_.send_block_start[i];
int recv_block_start = recursive_halving_map_.recv_block_start[i];
//send
int send_size = 0;
for (int j = 0; j < recursive_halving_map_.send_block_len[i]; ++j) {
send_size += block_len[send_block_start + j];
}
linkers_->SendTo(target, input + block_start[send_block_start], send_size);
//receive
int need_recv_cnt = 0;
for (int j = 0; j < recursive_halving_map_.recv_block_len[i]; ++j) {
need_recv_cnt += block_len[recv_block_start + j];
}
linkers_->RecvFrom(target, output, need_recv_cnt);
linkers_->SendRecv(target, input + block_start[send_block_start], send_size, target, output, need_recv_cnt);
//reduce
reducer(output, input + block_start[recv_block_start], need_recv_cnt);
}
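
The reducer invoked above folds the bytes just received (output) into the locally kept block of input. Its exact signature is not shown in this hunk; assuming the reducer(src, dst, len) convention that the call reads as, a minimal element-wise sum reducer for int data could look like this.

// Hypothetical reducer: fold len bytes of src into dst, treating them as ints.
void SumIntReducer(const char* src, char* dst, int len) {
  const int* s = reinterpret_cast<const int*>(src);
  int* d = reinterpret_cast<int*>(dst);
  for (int i = 0; i < len / static_cast<int>(sizeof(int)); ++i) {
    d[i] += s[i];
  }
}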