diff --git a/Source/Multiverso/include/multiverso/multiverso.h b/Source/Multiverso/include/multiverso/multiverso.h index 145b49d96..4586221ee 100644 --- a/Source/Multiverso/include/multiverso/multiverso.h +++ b/Source/Multiverso/include/multiverso/multiverso.h @@ -5,15 +5,15 @@ namespace multiverso { enum Role { - kNull = 0, - kWorker = 1, - kServer = 2, - kAll = 3 + Null = 0, + Worker = 1, + Server = 2, + All = 3 }; void MV_Init(int* argc = nullptr, char* argv[] = nullptr, - int role = kAll); + int role = All); void MV_Barrier(); @@ -24,17 +24,6 @@ int MV_Size(); int MV_Worker_Id(); int MV_Server_Id(); - -// will deprecate the following function name -void MultiversoInit(int* argc = nullptr, - char* argv[] = nullptr, - int role = kAll); - -void MultiversoBarrier(); - -void MultiversoShutDown(bool finalize_mpi = true); - -int MultiversoRank(); } #endif // MULTIVERSO_INCLUDE_MULTIVERSO_H_ \ No newline at end of file diff --git a/Source/Multiverso/include/multiverso/net/mpi_net.h b/Source/Multiverso/include/multiverso/net/mpi_net.h index efd49672d..43c74a855 100644 --- a/Source/Multiverso/include/multiverso/net/mpi_net.h +++ b/Source/Multiverso/include/multiverso/net/mpi_net.h @@ -1,4 +1,4 @@ -#ifndef MULTIVERSO_NET_MPI_NET_H_ +#ifndef MULTIVERSO_NET_MPI_NET_H_ #define MULTIVERSO_NET_MPI_NET_H_ #ifdef MULTIVERSO_USE_MPI @@ -22,6 +22,8 @@ namespace multiverso { +#define MV_MPI_CALL(mpi_return) CHECK((mpi_return) == MPI_SUCCESS) + class MPINetWrapper : public NetInterface { public: MPINetWrapper() : more_(std::numeric_limits::max()) {} @@ -41,18 +43,18 @@ public: CHECK_NOTNULL(msg_.get()); int count = static_cast(handles_.size()); MPI_Status* status = new MPI_Status[count]; - MPI_Waitall(count, handles_.data(), status); + MV_MPI_CALL(MPI_Waitall(count, handles_.data(), status)); delete[] status; } - bool Test() { + int Test() { CHECK_NOTNULL(msg_.get()); int count = static_cast(handles_.size()); MPI_Status* status = new MPI_Status[count]; int flag; - MPI_Testall(count, handles_.data(), &flag, status); + MV_MPI_CALL(MPI_Testall(count, handles_.data(), &flag, status)); delete[] status; - return flag != 0; + return flag; } private: std::vector handles_; @@ -62,11 +64,11 @@ public: void Init(int* argc, char** argv) override { // MPI_Init(argc, &argv); - MPI_Initialized(&inited_); + MV_MPI_CALL(MPI_Initialized(&inited_)); if (!inited_) { - MPI_Init_thread(argc, &argv, MPI_THREAD_SERIALIZED, &thread_provided_); + MV_MPI_CALL(MPI_Init_thread(argc, &argv, MPI_THREAD_SERIALIZED, &thread_provided_)); } - MPI_Query_thread(&thread_provided_); + MV_MPI_CALL(MPI_Query_thread(&thread_provided_)); if (thread_provided_ < MPI_THREAD_SERIALIZED) { Log::Fatal("At least MPI_THREAD_SERIALIZED supported is needed by multiverso.\n"); } @@ -78,7 +80,7 @@ public: } MPI_Comm_rank(MPI_COMM_WORLD, &rank_); MPI_Comm_size(MPI_COMM_WORLD, &size_); - MPI_Barrier(MPI_COMM_WORLD); + MPI_Barrier(MPI_COMM_WORLD); Log::Debug("%s net util inited, rank = %d, size = %d\n", name().c_str(), rank(), size()); } @@ -135,9 +137,12 @@ public: MPI_Status status; int flag; // non-blocking probe whether message comes - MPI_Iprobe(MPI_ANY_SOURCE, 0, MPI_COMM_WORLD, &flag, &status); + MV_MPI_CALL(MPI_Iprobe(MPI_ANY_SOURCE, 0, MPI_COMM_WORLD, &flag, &status)); + int count; + MV_MPI_CALL(MPI_Get_count(&status, MPI_BYTE, &count)); if (!flag) return 0; - return RecvMsg(msg); + CHECK(count == Message::kHeaderSize); + return RecvMsgFrom(status.MPI_SOURCE, msg); } int thread_level_support() override { @@ -152,44 +157,44 @@ private: CHECK_NOTNULL(msg_handle); size_t size = Message::kHeaderSize; MPI_Request handle; - MPI_Isend(msg->header(), Message::kHeaderSize, MPI_BYTE, - msg->dst(), 0, MPI_COMM_WORLD, &handle); + CHECK_NOTNULL(msg->header()); + MV_MPI_CALL(MPI_Isend(msg->header(), Message::kHeaderSize, MPI_BYTE, + msg->dst(), 0, MPI_COMM_WORLD, &handle)); msg_handle->add_handle(handle); // Send multiple msg for (auto& blob : msg->data()) { CHECK_NOTNULL(blob.data()); - MPI_Isend(blob.data(), static_cast(blob.size()), MPI_BYTE, msg->dst(), - 0, MPI_COMM_WORLD, &handle); + MV_MPI_CALL(MPI_Isend(blob.data(), static_cast(blob.size()), + MPI_BYTE, msg->dst(), + 0, MPI_COMM_WORLD, &handle)); size += blob.size(); msg_handle->add_handle(handle); } // Send an extra over tag indicating the finish of this Message - MPI_Isend(&more_, sizeof(char), MPI_BYTE, msg->dst(), - 0, MPI_COMM_WORLD, &handle); + MV_MPI_CALL(MPI_Isend(&more_, sizeof(char), MPI_BYTE, msg->dst(), + 0, MPI_COMM_WORLD, &handle)); // Log::Debug("MPI-Net: rank %d send msg size = %d\n", rank(), size+4); msg_handle->add_handle(handle); return size + sizeof(char); } - size_t RecvMsg(MessagePtr* msg_ptr) { + size_t RecvMsgFrom(int source, MessagePtr* msg_ptr) { if (!msg_ptr->get()) msg_ptr->reset(new Message()); MessagePtr& msg = *msg_ptr; msg->data().clear(); MPI_Status status; - MPI_Recv(msg->header(), Message::kHeaderSize, - MPI_BYTE, MPI_ANY_SOURCE, - 0, MPI_COMM_WORLD, &status); + CHECK_NOTNULL(msg->header()); + MV_MPI_CALL(MPI_Recv(msg->header(), Message::kHeaderSize, + MPI_BYTE, source, 0, MPI_COMM_WORLD, &status)); size_t size = Message::kHeaderSize; - int i = 0; - int num_probe = 0; while (true) { int count; - CHECK(MPI_SUCCESS == MPI_Probe(msg->src(), 0, MPI_COMM_WORLD, &status)); - MPI_Get_count(&status, MPI_BYTE, &count); + MV_MPI_CALL(MPI_Probe(source, 0, MPI_COMM_WORLD, &status)); + MV_MPI_CALL(MPI_Get_count(&status, MPI_BYTE, &count)); Blob blob(count); // We only receive from msg->src() until we recv the overtag msg - MPI_Recv(blob.data(), count, MPI_BYTE, msg->src(), - 0, MPI_COMM_WORLD, &status); + MV_MPI_CALL(MPI_Recv(blob.data(), count, MPI_BYTE, source, + 0, MPI_COMM_WORLD, &status)); size += count; if (count == sizeof(char)) { if (blob.As() == more_) break; @@ -197,7 +202,6 @@ private: } msg->Push(blob); } - // Log::Debug("MPI-Net: rank %d end recv from src %d, size = %d\n", rank(), msg->src(), size); return size; } diff --git a/Source/Multiverso/include/multiverso/net/zmq_net.h b/Source/Multiverso/include/multiverso/net/zmq_net.h index b9dfea8b9..f013b1b0f 100644 --- a/Source/Multiverso/include/multiverso/net/zmq_net.h +++ b/Source/Multiverso/include/multiverso/net/zmq_net.h @@ -1,6 +1,7 @@ #ifndef MULTIVERSO_NET_ZMQ_NET_H_ #define MULTIVERSO_NET_ZMQ_NET_H_ +// #define MULTIVERSO_USE_ZMQ #ifdef MULTIVERSO_USE_ZMQ #include "multiverso/net.h" diff --git a/Source/Multiverso/include/multiverso/util/mt_queue.h b/Source/Multiverso/include/multiverso/util/mt_queue.h index af95453f7..3d43f870f 100644 --- a/Source/Multiverso/include/multiverso/util/mt_queue.h +++ b/Source/Multiverso/include/multiverso/util/mt_queue.h @@ -95,7 +95,6 @@ bool MtQueue::Pop(T& result) // empty_condition_.wait(lock, // [this]{ return !buffer_.empty() || exit_; }); while (buffer_.empty() && !exit_) { - // while (true) { empty_condition_.wait(lock); } if (buffer_.empty()) return false; diff --git a/Source/Multiverso/include/multiverso/zoo.h b/Source/Multiverso/include/multiverso/zoo.h index 777b35cef..d403fe695 100644 --- a/Source/Multiverso/include/multiverso/zoo.h +++ b/Source/Multiverso/include/multiverso/zoo.h @@ -36,6 +36,9 @@ public: int size() const; // TODO(to change) + int worker_rank() const; + int server_rank() const; + int num_workers() const { return num_workers_; } int num_servers() const { return num_servers_; } diff --git a/Source/Multiverso/x64/release/libmultiverso.a b/Source/Multiverso/x64/release/libmultiverso.a index a4f89996a..62698b8a2 100644 Binary files a/Source/Multiverso/x64/release/libmultiverso.a and b/Source/Multiverso/x64/release/libmultiverso.a differ