diff --git a/include/multiverso/multiverso.h b/include/multiverso/multiverso.h
index 3b18006..4f96f61 100644
--- a/include/multiverso/multiverso.h
+++ b/include/multiverso/multiverso.h
@@ -58,6 +58,8 @@ int MV_NetBind(int rank, char* endpoint);
 // \return  0 SUCCESS
 // \return -1 FAIL
 int MV_NetConnect(int* rank, char* endpoint[], int size);
+void MV_NetClose(const char* endpoint);
+void MV_NetFinalize();
 
 }  // namespace multiverso
 
diff --git a/include/multiverso/net.h b/include/multiverso/net.h
index 7bd4d13..dfdaec1 100644
--- a/include/multiverso/net.h
+++ b/include/multiverso/net.h
@@ -26,6 +26,8 @@ public:
   // Connect with other endpoints
   virtual int Connect(int* rank, char* endpoints[], int size) = 0;
 
+  virtual void Close(const char* endpoint) = 0;
+
   virtual bool active() const = 0;
 
   virtual std::string name() const = 0;
diff --git a/include/multiverso/net/mpi_net.h b/include/multiverso/net/mpi_net.h
index c799981..219012c 100644
--- a/include/multiverso/net/mpi_net.h
+++ b/include/multiverso/net/mpi_net.h
@@ -139,6 +139,10 @@ public:
     return -1;
   }
 
+  void Close(const char* endpoint) override {
+    Log::Fatal("Shouldn't call this in MPI Net\n");
+  }
+
   bool active() const { return inited_ != 0; }
   int rank() const override { return rank_; }
   int size() const override { return size_; }
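A note on the MPI stub above: NetInterface::Close is pure virtual and takes a const char* endpoint, so the MPI backend must repeat that exact parameter list even though it only raises a fatal error. With the override keyword, a mismatched signature such as a parameterless Close() would be a compile error instead of silently becoming an unrelated overload, and MPINetWrapper would stay abstract. A minimal illustration of that rule (hypothetical Base/Good/Bad types, not Multiverso code):

struct Base {
  virtual ~Base() = default;
  virtual void Close(const char* endpoint) = 0;
};

struct Good : Base {
  // Same parameter list as the base declaration: compiles.
  void Close(const char* endpoint) override {}
};

// struct Bad : Base {
//   // error: 'Close' marked 'override' but does not override any member
//   void Close() override {}
// };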
diff --git a/include/multiverso/net/zmq_net.h b/include/multiverso/net/zmq_net.h
index 28c6876..22ab8fb 100644
--- a/include/multiverso/net/zmq_net.h
+++ b/include/multiverso/net/zmq_net.h
@@ -43,24 +43,28 @@ public:
     for (auto ip : machine_lists_) {
       if (local_ip.find(ip) != local_ip.end()) {  // my rank
         rank_ = static_cast<int>(senders_.size());
-        senders_.push_back(nullptr);
-        receiver_ = zmq_socket(context_, ZMQ_DEALER);
-        int rc = zmq_bind(receiver_,
-          ("tcp://" + ip + ":" + std::to_string(port)).c_str());
+        Entity sender; sender.endpoint = ""; sender.socket = nullptr;
+        senders_.push_back(sender);
+        receiver_.socket = zmq_socket(context_, ZMQ_DEALER);
+        receiver_.endpoint = ip + ":" + std::to_string(port);
+        int rc = zmq_bind(receiver_.socket, ("tcp://" + receiver_.endpoint).c_str());
+        endpoint_to_socket_[receiver_.endpoint] = receiver_.socket;
         CHECK(rc == 0);
         int linger = 0;
-        CHECK(zmq_setsockopt(receiver_, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
+        CHECK(zmq_setsockopt(receiver_.socket, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
       } else {
-        void* sender = zmq_socket(context_, ZMQ_DEALER);
-        int rc = zmq_connect(sender,
-          ("tcp://" + ip + ":" + std::to_string(port)).c_str());
+        Entity sender;
+        sender.socket = zmq_socket(context_, ZMQ_DEALER);
+        sender.endpoint = ip + ":" + std::to_string(port);
+        int rc = zmq_connect(sender.socket, ("tcp://" + sender.endpoint).c_str());
+        endpoint_to_socket_[sender.endpoint] = sender.socket;
         CHECK(rc == 0);
         senders_.push_back(sender);
         int linger = 0;
-        CHECK(zmq_setsockopt(sender, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
+        CHECK(zmq_setsockopt(sender.socket, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
       }
     }
-    CHECK_NOTNULL(receiver_);
+    CHECK_NOTNULL(receiver_.socket);
     active_ = true;
     Log::Info("%s net util inited, rank = %d, size = %d\n",
               name().c_str(), rank(), size());
@@ -71,11 +75,13 @@ public:
     std::string ip_port(endpoint);
     if (context_ == nullptr) { context_ = zmq_ctx_new(); }
     CHECK_NOTNULL(context_);
-    receiver_ = zmq_socket(context_, ZMQ_DEALER);
-    int rc = zmq_bind(receiver_, ("tcp://" + ip_port).c_str());
+    receiver_.socket = zmq_socket(context_, ZMQ_DEALER);
+    receiver_.endpoint = ip_port;
+    int rc = zmq_bind(receiver_.socket, ("tcp://" + receiver_.endpoint).c_str());
+    endpoint_to_socket_[receiver_.endpoint] = receiver_.socket;
     if (rc == 0) {
       int linger = 0;
-      CHECK(zmq_setsockopt(receiver_, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
+      CHECK(zmq_setsockopt(receiver_.socket, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
       return 0;
     }
     else {
@@ -87,17 +93,23 @@ public:
   // Connect with other endpoints
   virtual int Connect(int* ranks, char* endpoints[], int size) override {
-    CHECK_NOTNULL(receiver_);
+    CHECK_NOTNULL(receiver_.socket);
     CHECK_NOTNULL(context_);
     size_ = size;
     senders_.resize(size_);
     for (int i = 0; i < size; ++i) {
       int rank = ranks[i];
-      if (rank == rank_) continue;
       std::string ip_port(endpoints[i]);
-      senders_[rank] = zmq_socket(context_, ZMQ_DEALER);
+      // if (rank == rank_) continue;
+      if (ip_port == receiver_.endpoint) {
+        rank_ = rank;
+        continue;
+      }
+      senders_[rank].socket = zmq_socket(context_, ZMQ_DEALER);
+      senders_[rank].endpoint = ip_port;
+      endpoint_to_socket_[senders_[rank].endpoint] = senders_[rank].socket;
       // NOTE(feiga): set linger to 0, otherwise will hang
-      int rc = zmq_connect(senders_[rank], ("tcp://" + ip_port).c_str());
+      int rc = zmq_connect(senders_[rank].socket, ("tcp://" + senders_[rank].endpoint).c_str());
       if (rc != 0) {
         Log::Error("Failed to connect the socket for sender, rank = %d, "
                    "ip:port = %s\n", rank, endpoints[i]);
@@ -108,24 +120,47 @@ public:
     return 0;
   }
 
-  void Finalize() override {
-    active_ = false;
-    for (int i = 0; i < senders_.size(); ++i) {
-      if (i != rank_) {
-        int linger = 0;
-        CHECK(zmq_setsockopt(senders_[i], ZMQ_LINGER, &linger, sizeof(linger)) == 0);
-        int rc = zmq_close(senders_[i]);
-        CHECK(rc == 0);
+  void Close(const char* endpoint) override {
+    std::string str_endpoint(endpoint);
+    auto it = endpoint_to_socket_.find(str_endpoint);
+    if (it != endpoint_to_socket_.end()) {
+      Log::Info("Close endpoint %s\n", it->first.c_str());
+      CHECK(zmq_close(it->second) == 0);
+      endpoint_to_socket_.erase(it);
+      if (endpoint_to_socket_.empty()) {
+        // Term the context when the last endpoint is closed properly
+        zmq_ctx_term(context_);
+        Log::Info("ZMQ Finalize successfully\n");
       }
     }
-    int linger = 0;
-    CHECK(zmq_setsockopt(receiver_, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
-    int rc = zmq_close(receiver_);
-    CHECK(rc == 0);
+  }
 
-    CHECK(zmq_ctx_shutdown(context_)==0);
-    // zmq_ctx_term(context_);
-    Log::Info("zmq finalize: close context\n");
+  void Finalize() override {
+    active_ = false;
+    for (auto entity : senders_) {
+      if (entity.socket != nullptr) {
+        zmq_disconnect(entity.socket, ("tcp://" + entity.endpoint).c_str());
+        Close(entity.endpoint.c_str());
+      }
+    }
+    zmq_unbind(receiver_.socket, ("tcp://" + receiver_.endpoint).c_str());
+    Close(receiver_.endpoint.c_str());
+    //for (int i = 0; i < senders_.size(); ++i) {
+    //  if (i != rank_) {
+    //    int linger = 0;
+    //    CHECK(zmq_setsockopt(senders_[i], ZMQ_LINGER, &linger, sizeof(linger)) == 0);
+    //    int rc = zmq_close(senders_[i]);
+    //    CHECK(rc == 0);
+    //  }
+    //}
+    //int linger = 0;
+    //CHECK(zmq_setsockopt(receiver_, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
+    //int rc = zmq_close(receiver_);
+    //CHECK(rc == 0);
+
+    //CHECK(zmq_ctx_shutdown(context_)==0);
+    //zmq_ctx_term(context_);
+    //Log::Info("zmq finalize: close context\n");
   }
 
   bool active() const override { return active_; }
@@ -136,7 +171,7 @@ public:
   size_t Send(MessagePtr& msg) override {
     size_t size = 0;
     int dst = msg->dst();
-    void* socket = senders_[dst];
+    void* socket = senders_[dst].socket;
     CHECK_NOTNULL(socket);
     int send_size;
     send_size = zmq_send(socket, msg->header(),
@@ -169,25 +204,25 @@ public:
     MessagePtr& msg = *msg_ptr;
     msg->data().clear();
     CHECK(msg.get());
-    recv_size = zmq_recv(receiver_, msg->header(), Message::kHeaderSize, ZMQ_DONTWAIT);
+    recv_size = zmq_recv(receiver_.socket, msg->header(), Message::kHeaderSize, ZMQ_DONTWAIT);
     if (recv_size < 0) {
       return -1;
     }
     CHECK(Message::kHeaderSize == recv_size);
     size += recv_size;
-    zmq_getsockopt(receiver_, ZMQ_RCVMORE, &more, &more_size);
+    zmq_getsockopt(receiver_.socket, ZMQ_RCVMORE, &more, &more_size);
     while (more) {
-      recv_size = zmq_recv(receiver_, &blob_size, sizeof(size_t), 0);
+      recv_size = zmq_recv(receiver_.socket, &blob_size, sizeof(size_t), 0);
       CHECK(recv_size == sizeof(size_t));
       size += recv_size;
-      zmq_getsockopt(receiver_, ZMQ_RCVMORE, &more, &more_size);
+      zmq_getsockopt(receiver_.socket, ZMQ_RCVMORE, &more, &more_size);
       CHECK(more);
       Blob blob(blob_size);
-      recv_size = zmq_recv(receiver_, blob.data(), blob.size(), 0);
+      recv_size = zmq_recv(receiver_.socket, blob.data(), blob.size(), 0);
       CHECK(recv_size == blob_size);
       size += recv_size;
       msg->Push(blob);
-      zmq_getsockopt(receiver_, ZMQ_RCVMORE, &more, &more_size);
+      zmq_getsockopt(receiver_.socket, ZMQ_RCVMORE, &more, &more_size);
     }
     return size;
   }
@@ -196,7 +231,7 @@ public:
   void SendTo(int rank, char* buf, int len) const override {
     int send_size = 0;
     while (send_size < len) {
-      int cur_size = zmq_send(senders_[rank], buf + send_size, len - send_size, 0);
+      int cur_size = zmq_send(senders_[rank].socket, buf + send_size, len - send_size, 0);
       if (cur_size < 0) { Log::Error("socket send error %d", cur_size); }
       send_size += cur_size;
     }
@@ -206,7 +241,7 @@ public:
     // note: rank is not used here
     int recv_size = 0;
     while (recv_size < len) {
-      int cur_size = zmq_recv(receiver_, buf + recv_size, len - recv_size, 0);
+      int cur_size = zmq_recv(receiver_.socket, buf + recv_size, len - recv_size, 0);
      if (cur_size < 0) { Log::Error("socket receive error %d", cur_size); }
      recv_size += cur_size;
    }
@@ -250,11 +285,22 @@ protected:
   bool active_;
   void* context_;
-  void* receiver_;
-  std::vector<void*> senders_;
+
+  struct Entity {
+    std::string endpoint;
+    void* socket;
+  };
+
+  Entity receiver_;
+  std::vector<Entity> senders_;
+
+  // void* receiver_;
+  // std::vector<void*> senders_;
+
   int rank_;
   int size_;
   std::vector<std::string> machine_lists_;
+  std::unordered_map<std::string, void*> endpoint_to_socket_;
 };
 
 }  // namespace multiverso
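The core of this change is that every ZMQ socket is now registered in endpoint_to_socket_ under its "ip:port" string when it is created, in Init, Bind, and Connect alike. Close(endpoint) then treats senders and the receiver uniformly: look the endpoint up, zmq_close the socket, erase the entry, and call zmq_ctx_term only once the map is empty. This replaces the old all-at-once teardown in Finalize and guarantees the context is never terminated while a socket is still open; because every socket sets ZMQ_LINGER to 0, zmq_ctx_term cannot block on unsent messages. Note also that zmq_disconnect and zmq_unbind only succeed when given the exact endpoint string previously passed to zmq_connect/zmq_bind, which is why Finalize rebuilds the full "tcp://" + endpoint form rather than passing the bare "ip:port" stored in Entity::endpoint. A standalone sketch of this bookkeeping pattern (not Multiverso code; the addresses are placeholders), assuming libzmq is installed:

#include <zmq.h>
#include <cstdio>
#include <string>
#include <unordered_map>

int main() {
  void* context = zmq_ctx_new();
  std::unordered_map<std::string, void*> endpoint_to_socket;

  // Open two DEALER sockets and register each under its endpoint string.
  for (std::string ep : { "127.0.0.1:5555", "127.0.0.1:5556" }) {
    void* s = zmq_socket(context, ZMQ_DEALER);
    zmq_connect(s, ("tcp://" + ep).c_str());  // async: succeeds even with no peer yet
    int linger = 0;
    zmq_setsockopt(s, ZMQ_LINGER, &linger, sizeof(linger));
    endpoint_to_socket[ep] = s;
  }

  // Close by endpoint; terminate the context after the last socket is gone.
  auto close_endpoint = [&](const std::string& ep) {
    auto it = endpoint_to_socket.find(ep);
    if (it == endpoint_to_socket.end()) return;  // unknown or already closed
    zmq_close(it->second);
    endpoint_to_socket.erase(it);
    if (endpoint_to_socket.empty()) {
      zmq_ctx_term(context);  // cannot block: no sockets remain, linger is 0
      std::printf("context terminated\n");
    }
  };

  close_endpoint("127.0.0.1:5555");
  close_endpoint("127.0.0.1:5556");
  return 0;
}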
diff --git a/src/multiverso.cpp b/src/multiverso.cpp
index b1b1b48..94d5037 100644
--- a/src/multiverso.cpp
+++ b/src/multiverso.cpp
@@ -57,6 +57,14 @@ int MV_NetConnect(int* ranks, char* endpoints[], int size) {
   return NetInterface::Get()->Connect(ranks, endpoints, size);
 }
 
+void MV_NetClose(const char* endpoint) {
+  NetInterface::Get()->Close(endpoint);
+}
+
+void MV_NetFinalize() {
+  NetInterface::Get()->Finalize();
+}
+
 template void MV_Aggregate<char>(char*, int);
 template void MV_Aggregate<int>(int*, int);
 template void MV_Aggregate<float>(float*, int);
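The two new exports complete the C-style lifecycle API: MV_NetBind creates the local receiver, MV_NetConnect wires up the senders (and, after this patch, derives the local rank by matching the bound endpoint instead of trusting a preset rank_), and MV_NetClose/MV_NetFinalize release a single endpoint or the whole transport. A sketch of the intended call sequence, with placeholder addresses and ranks (illustrative usage, not a test from the repository):

#include "multiverso/multiverso.h"

int main() {
  using namespace multiverso;

  char ep0[] = "127.0.0.1:5555";
  char ep1[] = "127.0.0.1:5556";
  int ranks[] = { 0, 1 };
  char* endpoints[] = { ep0, ep1 };

  MV_NetBind(0, ep0);                  // bind the local receiver on this node
  MV_NetConnect(ranks, endpoints, 2);  // connect senders; the entry matching the
                                       // bound endpoint fixes this node's rank

  // ... send and receive messages ...

  // MV_NetClose("127.0.0.1:5556");    // would close just one registered endpoint
  MV_NetFinalize();                    // disconnect/unbind everything; the ZMQ
                                       // context terms after the last Close
  return 0;
}

Under the MPI backend MV_NetClose aborts with Log::Fatal, so the per-endpoint Close path is only meaningful for the ZMQ transport.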