change mpi send and recv to debug

This commit is contained in:
feiga 2016-03-09 03:31:22 +08:00
Родитель 45d8b53053
Коммит f2e452b24b
3 изменённых файлов: 107 добавлений и 9 удалений

Просмотреть файл

@ -222,6 +222,9 @@ void TestNet(int argc, char* argv[]) {
msg->Push(Blob(hi1, 13)); msg->Push(Blob(hi1, 13));
msg->Push(Blob(hi2, 11)); msg->Push(Blob(hi2, 11));
msg->Push(Blob(hi3, 18)); msg->Push(Blob(hi3, 18));
for (int i = 0; i < msg->size(); ++i) {
Log::Info("In Send: %s\n", msg->data()[i].data());
};
while (net->Send(msg) == 0) ; while (net->Send(msg) == 0) ;
Log::Info("rank 0 send\n"); Log::Info("rank 0 send\n");
} }

Просмотреть файл

@ -40,7 +40,7 @@ public:
size_t size() const { return size_; } size_t size() const { return size_; }
void Wait() { void Wait() {
CHECK_NOTNULL(msg_.get()); // CHECK_NOTNULL(msg_.get());
int count = static_cast<int>(handles_.size()); int count = static_cast<int>(handles_.size());
MPI_Status* status = new MPI_Status[count]; MPI_Status* status = new MPI_Status[count];
MV_MPI_CALL(MPI_Waitall(count, handles_.data(), status)); MV_MPI_CALL(MPI_Waitall(count, handles_.data(), status));
@ -48,7 +48,7 @@ public:
} }
int Test() { int Test() {
CHECK_NOTNULL(msg_.get()); // CHECK_NOTNULL(msg_.get());
int count = static_cast<int>(handles_.size()); int count = static_cast<int>(handles_.size());
MPI_Status* status = new MPI_Status[count]; MPI_Status* status = new MPI_Status[count];
int flag; int flag;
@ -121,6 +121,29 @@ public:
// return size; // return size;
//} //}
//size_t Send(MessagePtr& msg) override {
// if (msg.get()) { send_queue_.Push(msg); }
//
// if (last_handle_.get() != nullptr && !last_handle_->Test()) {
// // Last msg is still on the air
// return 0;
// }
// // send over, free the last msg
// last_handle_.reset();
// // if there is more msg to send
// if (send_queue_.Empty()) return 0;
//
// // Send a front msg of send queue
// last_handle_.reset(new MPIMsgHandle());
// MessagePtr sending_msg;
// CHECK(send_queue_.TryPop(sending_msg));
// last_handle_->set_msg(sending_msg);
// size_t size = SendAsync(last_handle_->msg(), last_handle_.get());
// return size;
//}
size_t Send(MessagePtr& msg) override { size_t Send(MessagePtr& msg) override {
if (msg.get()) { send_queue_.Push(msg); } if (msg.get()) { send_queue_.Push(msg); }
@ -139,21 +162,88 @@ public:
last_handle_.reset(new MPIMsgHandle()); last_handle_.reset(new MPIMsgHandle());
MessagePtr sending_msg; MessagePtr sending_msg;
CHECK(send_queue_.TryPop(sending_msg)); CHECK(send_queue_.TryPop(sending_msg));
last_handle_->set_msg(sending_msg);
size_t size = SendAsync(last_handle_->msg(), last_handle_.get()); size_t size = SerializeAndSend(sending_msg, last_handle_.get());
return size; return size;
} }
//size_t Recv(MessagePtr* msg) override {
// MPI_Status status;
// int flag;
// // non-blocking probe whether message comes
// MV_MPI_CALL(MPI_Iprobe(MPI_ANY_SOURCE, 0, MPI_COMM_WORLD, &flag, &status));
// int count;
// MV_MPI_CALL(MPI_Get_count(&status, MPI_BYTE, &count));
// if (!flag) return 0;
// CHECK(count == Message::kHeaderSize);
// return RecvMsgFrom(status.MPI_SOURCE, msg);
//}
size_t Recv(MessagePtr* msg) override { size_t Recv(MessagePtr* msg) override {
MPI_Status status; MPI_Status status;
int flag; int flag;
// non-blocking probe whether message comes // non-blocking probe whether message comes
MV_MPI_CALL(MPI_Iprobe(MPI_ANY_SOURCE, 0, MPI_COMM_WORLD, &flag, &status)); MV_MPI_CALL(MPI_Iprobe(MPI_ANY_SOURCE, 0, MPI_COMM_WORLD, &flag, &status));
if (!flag) return 0;
int count; int count;
MV_MPI_CALL(MPI_Get_count(&status, MPI_BYTE, &count)); MV_MPI_CALL(MPI_Get_count(&status, MPI_BYTE, &count));
if (!flag) return 0; if (count > recv_size_) {
CHECK(count == Message::kHeaderSize); recv_buffer_ = (char*)realloc(recv_buffer_, count);
return RecvMsgFrom(status.MPI_SOURCE, msg); recv_size_ = count;
}
// CHECK(count == Message::kHeaderSize);
return RecvAndDeserialize(status.MPI_SOURCE, count, msg);
}
// Flattens `msg` into the reusable send buffer and posts a single
// non-blocking MPI send (tag 0, MPI_COMM_WORLD) to msg->dst().
//
// Layout written to send_buffer_:
//   [header : Message::kHeaderSize bytes]
//   for every blob: [int length][length payload bytes]
//   [int -1]  end-of-message marker
//
// Parameters:
//   msg        - message to send; must be non-null.
//   msg_handle - receives the MPI_Request of the posted send; must be non-null.
// Returns the number of bytes handed to MPI_Isend.
//
// NOTE(review): MPI_Isend reads send_buffer_ until the request completes, so
// this must not be called again (it may overwrite or realloc the buffer)
// while the previous request is in flight. Presumably Send() guarantees that
// by testing the previous handle first - confirm against the caller.
size_t SerializeAndSend(MessagePtr& msg, MPIMsgHandle* msg_handle) {
  CHECK_NOTNULL(msg_handle);
  CHECK(msg.get() != nullptr);

  // Header + (length prefix + payload) per blob + trailing terminator.
  size_t size = sizeof(int) + Message::kHeaderSize;
  for (auto& data : msg->data()) size += sizeof(int) + data.size();

  // MPI counts are `int`; refuse messages that would be silently truncated.
  // This also bounds every blob length below INT_MAX, so a length can never
  // collide with the -1 terminator.
  const int count = static_cast<int>(size);
  CHECK(count > 0 && static_cast<size_t>(count) == size);

  if (send_buffer_ == nullptr || count > send_size_) {
    // Keep the old pointer until realloc succeeds so a failure neither leaks
    // the buffer nor leaves us writing through a null pointer.
    char* grown = static_cast<char*>(realloc(send_buffer_, size));
    CHECK(grown != nullptr);
    send_buffer_ = grown;
    send_size_ = count;
  }

  memcpy(send_buffer_, msg->header(), Message::kHeaderSize);
  char* p = send_buffer_ + Message::kHeaderSize;
  for (auto& data : msg->data()) {
    const int s = static_cast<int>(data.size());
    memcpy(p, &s, sizeof(int));
    p += sizeof(int);
    // Skip the copy for empty blobs: their data pointer may be null.
    if (s > 0) {
      memcpy(p, data.data(), s);
      p += s;
    }
  }
  const int over = -1;
  memcpy(p, &over, sizeof(int));

  MPI_Request handle;
  MV_MPI_CALL(MPI_Isend(send_buffer_, count, MPI_BYTE, msg->dst(), 0,
                        MPI_COMM_WORLD, &handle));
  msg_handle->add_handle(handle);
  return size;
}
// Receives one serialized message of `count` bytes from rank `src` and
// rebuilds it into *msg_ptr.
//
// Parameters:
//   src     - source rank passed to MPI_Recv.
//   count   - number of bytes to receive into recv_buffer_.
//   msg_ptr - output message; a new Message is allocated when it is empty,
//             and any blobs it already holds are dropped.
// Returns `count`.
//
// Expected buffer layout (mirrors what the parse below reads):
//   [header : Message::kHeaderSize bytes]
//   repeated: [int length][length payload bytes]
//   [int -1]  end-of-message marker
//
// NOTE(review): recv_buffer_ is not resized here; it must already hold at
// least `count` bytes - presumably the caller grows it after probing; confirm.
// NOTE(review): the parse loop stops only on a -1 length and never compares
// `p` against recv_buffer_ + count, so a truncated or malformed message would
// read past the received data. Consider bounding the loop by `count`.
size_t RecvAndDeserialize(int src, int count, MessagePtr* msg_ptr) {
// Allocate the output message on demand, then discard any previous payload.
if (!msg_ptr->get()) msg_ptr->reset(new Message());
MessagePtr& msg = *msg_ptr;
msg->data().clear();
MPI_Status status;
// Blocking receive of the whole serialized message (tag 0).
MV_MPI_CALL(MPI_Recv(recv_buffer_, count,
MPI_BYTE, src, 0, MPI_COMM_WORLD, &status));
char* p = recv_buffer_;
int s;
// The fixed-size header comes first.
memcpy(msg->header(), p, Message::kHeaderSize);
p += Message::kHeaderSize;
// Read the first length prefix; -1 means the message carries no blobs.
memcpy(&s, p, sizeof(int));
p += sizeof(int);
while (s != -1) {
// presumably Blob(s) allocates s bytes of owned storage - TODO confirm
Blob data(s);
memcpy(data.data(), p, data.size());
msg->Push(data);
p += data.size();
// Fetch the next length prefix (or the -1 terminator).
memcpy(&s, p, sizeof(int));
p += sizeof(int);
}
return count;
} }
int thread_level_support() override { int thread_level_support() override {
@ -226,6 +316,10 @@ private:
// std::queue<MPIMsgHandle *> msg_handles_; // std::queue<MPIMsgHandle *> msg_handles_;
std::unique_ptr<MPIMsgHandle> last_handle_; std::unique_ptr<MPIMsgHandle> last_handle_;
MtQueue<MessagePtr> send_queue_; MtQueue<MessagePtr> send_queue_;
char* send_buffer_;
int send_size_;
char* recv_buffer_;
int recv_size_;
}; };
} }

Просмотреть файл

@ -68,7 +68,8 @@ public:
void Add(int row_id, T* data, size_t size) { void Add(int row_id, T* data, size_t size) {
if (row_id >= 0) CHECK(size == num_col_); if (row_id >= 0) CHECK(size == num_col_);
Blob ids_blob(&row_id, sizeof(int)); int row = row_id;
Blob ids_blob(&row, sizeof(int));
Blob data_blob(data, size * sizeof(T)); Blob data_blob(data, size * sizeof(T));
WorkerTable::Add(ids_blob, data_blob); WorkerTable::Add(ids_blob, data_blob);
Log::Debug("worker %d adding rows\n", MV_Rank()); Log::Debug("worker %d adding rows\n", MV_Rank());
@ -195,7 +196,7 @@ public:
CHECK(server_id_ != -1); CHECK(server_id_ != -1);
int size = num_row / MV_NumServers(); int size = num_row / MV_NumServers();
row_offset_ = size * MV_Rank(); // Zoo::Get()->rank(); row_offset_ = size * server_id_; // Zoo::Get()->rank();
if (server_id_ == MV_NumServers() - 1){ if (server_id_ == MV_NumServers() - 1){
size = num_row - row_offset_; size = num_row - row_offset_;
} }