change mpi send and recv to debug
This commit is contained in:
Родитель
45d8b53053
Коммит
f2e452b24b
|
@ -222,6 +222,9 @@ void TestNet(int argc, char* argv[]) {
|
||||||
msg->Push(Blob(hi1, 13));
|
msg->Push(Blob(hi1, 13));
|
||||||
msg->Push(Blob(hi2, 11));
|
msg->Push(Blob(hi2, 11));
|
||||||
msg->Push(Blob(hi3, 18));
|
msg->Push(Blob(hi3, 18));
|
||||||
|
for (int i = 0; i < msg->size(); ++i) {
|
||||||
|
Log::Info("In Send: %s\n", msg->data()[i].data());
|
||||||
|
};
|
||||||
while (net->Send(msg) == 0) ;
|
while (net->Send(msg) == 0) ;
|
||||||
Log::Info("rank 0 send\n");
|
Log::Info("rank 0 send\n");
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,7 +40,7 @@ public:
|
||||||
size_t size() const { return size_; }
|
size_t size() const { return size_; }
|
||||||
|
|
||||||
void Wait() {
|
void Wait() {
|
||||||
CHECK_NOTNULL(msg_.get());
|
// CHECK_NOTNULL(msg_.get());
|
||||||
int count = static_cast<int>(handles_.size());
|
int count = static_cast<int>(handles_.size());
|
||||||
MPI_Status* status = new MPI_Status[count];
|
MPI_Status* status = new MPI_Status[count];
|
||||||
MV_MPI_CALL(MPI_Waitall(count, handles_.data(), status));
|
MV_MPI_CALL(MPI_Waitall(count, handles_.data(), status));
|
||||||
|
@ -48,7 +48,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
int Test() {
|
int Test() {
|
||||||
CHECK_NOTNULL(msg_.get());
|
// CHECK_NOTNULL(msg_.get());
|
||||||
int count = static_cast<int>(handles_.size());
|
int count = static_cast<int>(handles_.size());
|
||||||
MPI_Status* status = new MPI_Status[count];
|
MPI_Status* status = new MPI_Status[count];
|
||||||
int flag;
|
int flag;
|
||||||
|
@ -121,6 +121,29 @@ public:
|
||||||
// return size;
|
// return size;
|
||||||
//}
|
//}
|
||||||
|
|
||||||
|
//size_t Send(MessagePtr& msg) override {
|
||||||
|
// if (msg.get()) { send_queue_.Push(msg); }
|
||||||
|
//
|
||||||
|
// if (last_handle_.get() != nullptr && !last_handle_->Test()) {
|
||||||
|
// // Last msg is still on the air
|
||||||
|
// return 0;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // send over, free the last msg
|
||||||
|
// last_handle_.reset();
|
||||||
|
|
||||||
|
// // if there is more msg to send
|
||||||
|
// if (send_queue_.Empty()) return 0;
|
||||||
|
//
|
||||||
|
// // Send a front msg of send queue
|
||||||
|
// last_handle_.reset(new MPIMsgHandle());
|
||||||
|
// MessagePtr sending_msg;
|
||||||
|
// CHECK(send_queue_.TryPop(sending_msg));
|
||||||
|
// last_handle_->set_msg(sending_msg);
|
||||||
|
// size_t size = SendAsync(last_handle_->msg(), last_handle_.get());
|
||||||
|
// return size;
|
||||||
|
//}
|
||||||
|
|
||||||
size_t Send(MessagePtr& msg) override {
|
size_t Send(MessagePtr& msg) override {
|
||||||
if (msg.get()) { send_queue_.Push(msg); }
|
if (msg.get()) { send_queue_.Push(msg); }
|
||||||
|
|
||||||
|
@ -139,21 +162,88 @@ public:
|
||||||
last_handle_.reset(new MPIMsgHandle());
|
last_handle_.reset(new MPIMsgHandle());
|
||||||
MessagePtr sending_msg;
|
MessagePtr sending_msg;
|
||||||
CHECK(send_queue_.TryPop(sending_msg));
|
CHECK(send_queue_.TryPop(sending_msg));
|
||||||
last_handle_->set_msg(sending_msg);
|
|
||||||
size_t size = SendAsync(last_handle_->msg(), last_handle_.get());
|
size_t size = SerializeAndSend(sending_msg, last_handle_.get());
|
||||||
return size;
|
return size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//size_t Recv(MessagePtr* msg) override {
|
||||||
|
// MPI_Status status;
|
||||||
|
// int flag;
|
||||||
|
// // non-blocking probe whether message comes
|
||||||
|
// MV_MPI_CALL(MPI_Iprobe(MPI_ANY_SOURCE, 0, MPI_COMM_WORLD, &flag, &status));
|
||||||
|
// int count;
|
||||||
|
// MV_MPI_CALL(MPI_Get_count(&status, MPI_BYTE, &count));
|
||||||
|
// if (!flag) return 0;
|
||||||
|
// CHECK(count == Message::kHeaderSize);
|
||||||
|
// return RecvMsgFrom(status.MPI_SOURCE, msg);
|
||||||
|
//}
|
||||||
|
|
||||||
size_t Recv(MessagePtr* msg) override {
|
size_t Recv(MessagePtr* msg) override {
|
||||||
MPI_Status status;
|
MPI_Status status;
|
||||||
int flag;
|
int flag;
|
||||||
// non-blocking probe whether message comes
|
// non-blocking probe whether message comes
|
||||||
MV_MPI_CALL(MPI_Iprobe(MPI_ANY_SOURCE, 0, MPI_COMM_WORLD, &flag, &status));
|
MV_MPI_CALL(MPI_Iprobe(MPI_ANY_SOURCE, 0, MPI_COMM_WORLD, &flag, &status));
|
||||||
|
if (!flag) return 0;
|
||||||
int count;
|
int count;
|
||||||
MV_MPI_CALL(MPI_Get_count(&status, MPI_BYTE, &count));
|
MV_MPI_CALL(MPI_Get_count(&status, MPI_BYTE, &count));
|
||||||
if (!flag) return 0;
|
if (count > recv_size_) {
|
||||||
CHECK(count == Message::kHeaderSize);
|
recv_buffer_ = (char*)realloc(recv_buffer_, count);
|
||||||
return RecvMsgFrom(status.MPI_SOURCE, msg);
|
recv_size_ = count;
|
||||||
|
}
|
||||||
|
// CHECK(count == Message::kHeaderSize);
|
||||||
|
return RecvAndDeserialize(status.MPI_SOURCE, count, msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t SerializeAndSend(MessagePtr& msg, MPIMsgHandle* msg_handle) {
|
||||||
|
|
||||||
|
CHECK_NOTNULL(msg_handle);
|
||||||
|
size_t size = sizeof(int) + Message::kHeaderSize;
|
||||||
|
for (auto& data : msg->data()) size += sizeof(int) + data.size();
|
||||||
|
if (size > send_size_) {
|
||||||
|
send_buffer_ = (char*)realloc(send_buffer_, size);
|
||||||
|
send_size_ = size;
|
||||||
|
}
|
||||||
|
memcpy(send_buffer_, msg->header(), Message::kHeaderSize);
|
||||||
|
char* p = send_buffer_ + Message::kHeaderSize;
|
||||||
|
for (auto& data : msg->data()) {
|
||||||
|
int s = data.size();
|
||||||
|
memcpy(p, &s, sizeof(int));
|
||||||
|
p += sizeof(int);
|
||||||
|
memcpy(p, data.data(), s);
|
||||||
|
p += s;
|
||||||
|
}
|
||||||
|
int over = -1;
|
||||||
|
memcpy(p, &over, sizeof(int));
|
||||||
|
|
||||||
|
MPI_Request handle;
|
||||||
|
MV_MPI_CALL(MPI_Isend(send_buffer_, size, MPI_BYTE, msg->dst(), 0, MPI_COMM_WORLD, &handle));
|
||||||
|
msg_handle->add_handle(handle);
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t RecvAndDeserialize(int src, int count, MessagePtr* msg_ptr) {
|
||||||
|
if (!msg_ptr->get()) msg_ptr->reset(new Message());
|
||||||
|
MessagePtr& msg = *msg_ptr;
|
||||||
|
msg->data().clear();
|
||||||
|
MPI_Status status;
|
||||||
|
MV_MPI_CALL(MPI_Recv(recv_buffer_, count,
|
||||||
|
MPI_BYTE, src, 0, MPI_COMM_WORLD, &status));
|
||||||
|
char* p = recv_buffer_;
|
||||||
|
int s;
|
||||||
|
memcpy(msg->header(), p, Message::kHeaderSize);
|
||||||
|
p += Message::kHeaderSize;
|
||||||
|
memcpy(&s, p, sizeof(int));
|
||||||
|
p += sizeof(int);
|
||||||
|
while (s != -1) {
|
||||||
|
Blob data(s);
|
||||||
|
memcpy(data.data(), p, data.size());
|
||||||
|
msg->Push(data);
|
||||||
|
p += data.size();
|
||||||
|
memcpy(&s, p, sizeof(int));
|
||||||
|
p += sizeof(int);
|
||||||
|
}
|
||||||
|
return count;
|
||||||
}
|
}
|
||||||
|
|
||||||
int thread_level_support() override {
|
int thread_level_support() override {
|
||||||
|
@ -226,6 +316,10 @@ private:
|
||||||
// std::queue<MPIMsgHandle *> msg_handles_;
|
// std::queue<MPIMsgHandle *> msg_handles_;
|
||||||
std::unique_ptr<MPIMsgHandle> last_handle_;
|
std::unique_ptr<MPIMsgHandle> last_handle_;
|
||||||
MtQueue<MessagePtr> send_queue_;
|
MtQueue<MessagePtr> send_queue_;
|
||||||
|
char* send_buffer_;
|
||||||
|
int send_size_;
|
||||||
|
char* recv_buffer_;
|
||||||
|
int recv_size_;
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -68,7 +68,8 @@ public:
|
||||||
|
|
||||||
void Add(int row_id, T* data, size_t size) {
|
void Add(int row_id, T* data, size_t size) {
|
||||||
if (row_id >= 0) CHECK(size == num_col_);
|
if (row_id >= 0) CHECK(size == num_col_);
|
||||||
Blob ids_blob(&row_id, sizeof(int));
|
int row = row_id;
|
||||||
|
Blob ids_blob(&row, sizeof(int));
|
||||||
Blob data_blob(data, size * sizeof(T));
|
Blob data_blob(data, size * sizeof(T));
|
||||||
WorkerTable::Add(ids_blob, data_blob);
|
WorkerTable::Add(ids_blob, data_blob);
|
||||||
Log::Debug("worker %d adding rows\n", MV_Rank());
|
Log::Debug("worker %d adding rows\n", MV_Rank());
|
||||||
|
@ -195,7 +196,7 @@ public:
|
||||||
CHECK(server_id_ != -1);
|
CHECK(server_id_ != -1);
|
||||||
|
|
||||||
int size = num_row / MV_NumServers();
|
int size = num_row / MV_NumServers();
|
||||||
row_offset_ = size * MV_Rank(); // Zoo::Get()->rank();
|
row_offset_ = size * server_id_; // Zoo::Get()->rank();
|
||||||
if (server_id_ == MV_NumServers() - 1){
|
if (server_id_ == MV_NumServers() - 1){
|
||||||
size = num_row - row_offset_;
|
size = num_row - row_offset_;
|
||||||
}
|
}
|
||||||
|
|
Загрузка…
Ссылка в новой задаче