diff --git a/next/Test/main.cpp b/next/Test/main.cpp index 26fe380..59c938b 100644 --- a/next/Test/main.cpp +++ b/next/Test/main.cpp @@ -222,6 +222,9 @@ void TestNet(int argc, char* argv[]) { msg->Push(Blob(hi1, 13)); msg->Push(Blob(hi2, 11)); msg->Push(Blob(hi3, 18)); + for (int i = 0; i < msg->size(); ++i) { + Log::Info("In Send: %s\n", msg->data()[i].data()); + }; while (net->Send(msg) == 0) ; Log::Info("rank 0 send\n"); } diff --git a/next/include/multiverso/net/mpi_net.h b/next/include/multiverso/net/mpi_net.h index f112bb8..6ee2606 100644 --- a/next/include/multiverso/net/mpi_net.h +++ b/next/include/multiverso/net/mpi_net.h @@ -40,7 +40,7 @@ public: size_t size() const { return size_; } void Wait() { - CHECK_NOTNULL(msg_.get()); + // CHECK_NOTNULL(msg_.get()); int count = static_cast(handles_.size()); MPI_Status* status = new MPI_Status[count]; MV_MPI_CALL(MPI_Waitall(count, handles_.data(), status)); @@ -48,7 +48,7 @@ public: } int Test() { - CHECK_NOTNULL(msg_.get()); + // CHECK_NOTNULL(msg_.get()); int count = static_cast(handles_.size()); MPI_Status* status = new MPI_Status[count]; int flag; @@ -121,6 +121,29 @@ public: // return size; //} + //size_t Send(MessagePtr& msg) override { + // if (msg.get()) { send_queue_.Push(msg); } + // + // if (last_handle_.get() != nullptr && !last_handle_->Test()) { + // // Last msg is still on the air + // return 0; + // } + + // // send over, free the last msg + // last_handle_.reset(); + + // // if there is more msg to send + // if (send_queue_.Empty()) return 0; + // + // // Send a front msg of send queue + // last_handle_.reset(new MPIMsgHandle()); + // MessagePtr sending_msg; + // CHECK(send_queue_.TryPop(sending_msg)); + // last_handle_->set_msg(sending_msg); + // size_t size = SendAsync(last_handle_->msg(), last_handle_.get()); + // return size; + //} + size_t Send(MessagePtr& msg) override { if (msg.get()) { send_queue_.Push(msg); } @@ -139,21 +162,88 @@ public: last_handle_.reset(new MPIMsgHandle()); MessagePtr sending_msg; CHECK(send_queue_.TryPop(sending_msg)); - last_handle_->set_msg(sending_msg); - size_t size = SendAsync(last_handle_->msg(), last_handle_.get()); + + size_t size = SerializeAndSend(sending_msg, last_handle_.get()); return size; } + //size_t Recv(MessagePtr* msg) override { + // MPI_Status status; + // int flag; + // // non-blocking probe whether message comes + // MV_MPI_CALL(MPI_Iprobe(MPI_ANY_SOURCE, 0, MPI_COMM_WORLD, &flag, &status)); + // int count; + // MV_MPI_CALL(MPI_Get_count(&status, MPI_BYTE, &count)); + // if (!flag) return 0; + // CHECK(count == Message::kHeaderSize); + // return RecvMsgFrom(status.MPI_SOURCE, msg); + //} + size_t Recv(MessagePtr* msg) override { MPI_Status status; int flag; // non-blocking probe whether message comes MV_MPI_CALL(MPI_Iprobe(MPI_ANY_SOURCE, 0, MPI_COMM_WORLD, &flag, &status)); + if (!flag) return 0; int count; MV_MPI_CALL(MPI_Get_count(&status, MPI_BYTE, &count)); - if (!flag) return 0; - CHECK(count == Message::kHeaderSize); - return RecvMsgFrom(status.MPI_SOURCE, msg); + if (count > recv_size_) { + recv_buffer_ = (char*)realloc(recv_buffer_, count); + recv_size_ = count; + } + // CHECK(count == Message::kHeaderSize); + return RecvAndDeserialize(status.MPI_SOURCE, count, msg); + } + + size_t SerializeAndSend(MessagePtr& msg, MPIMsgHandle* msg_handle) { + + CHECK_NOTNULL(msg_handle); + size_t size = sizeof(int) + Message::kHeaderSize; + for (auto& data : msg->data()) size += sizeof(int) + data.size(); + if (size > send_size_) { + send_buffer_ = (char*)realloc(send_buffer_, size); + send_size_ = size; + } + memcpy(send_buffer_, msg->header(), Message::kHeaderSize); + char* p = send_buffer_ + Message::kHeaderSize; + for (auto& data : msg->data()) { + int s = data.size(); + memcpy(p, &s, sizeof(int)); + p += sizeof(int); + memcpy(p, data.data(), s); + p += s; + } + int over = -1; + memcpy(p, &over, sizeof(int)); + + MPI_Request handle; + MV_MPI_CALL(MPI_Isend(send_buffer_, size, MPI_BYTE, msg->dst(), 0, MPI_COMM_WORLD, &handle)); + msg_handle->add_handle(handle); + return size; + } + + size_t RecvAndDeserialize(int src, int count, MessagePtr* msg_ptr) { + if (!msg_ptr->get()) msg_ptr->reset(new Message()); + MessagePtr& msg = *msg_ptr; + msg->data().clear(); + MPI_Status status; + MV_MPI_CALL(MPI_Recv(recv_buffer_, count, + MPI_BYTE, src, 0, MPI_COMM_WORLD, &status)); + char* p = recv_buffer_; + int s; + memcpy(msg->header(), p, Message::kHeaderSize); + p += Message::kHeaderSize; + memcpy(&s, p, sizeof(int)); + p += sizeof(int); + while (s != -1) { + Blob data(s); + memcpy(data.data(), p, data.size()); + msg->Push(data); + p += data.size(); + memcpy(&s, p, sizeof(int)); + p += sizeof(int); + } + return count; } int thread_level_support() override { @@ -226,6 +316,10 @@ private: // std::queue msg_handles_; std::unique_ptr last_handle_; MtQueue send_queue_; + char* send_buffer_; + int send_size_; + char* recv_buffer_; + int recv_size_; }; } diff --git a/next/include/multiverso/table/matrix_table.h b/next/include/multiverso/table/matrix_table.h index 0dbd814..c4a7881 100644 --- a/next/include/multiverso/table/matrix_table.h +++ b/next/include/multiverso/table/matrix_table.h @@ -68,7 +68,8 @@ public: void Add(int row_id, T* data, size_t size) { if (row_id >= 0) CHECK(size == num_col_); - Blob ids_blob(&row_id, sizeof(int)); + int row = row_id; + Blob ids_blob(&row, sizeof(int)); Blob data_blob(data, size * sizeof(T)); WorkerTable::Add(ids_blob, data_blob); Log::Debug("worker %d adding rows\n", MV_Rank()); @@ -195,7 +196,7 @@ public: CHECK(server_id_ != -1); int size = num_row / MV_NumServers(); - row_offset_ = size * MV_Rank(); // Zoo::Get()->rank(); + row_offset_ = size * server_id_; // Zoo::Get()->rank(); if (server_id_ == MV_NumServers() - 1){ size = num_row - row_offset_; }