zmq issue

Parent: 4bda014f0b
Commit: 798353f799
@@ -58,6 +58,8 @@ int MV_NetBind(int rank, char* endpoint);
 // \return 0 SUCCESS
 // \return -1 FAIL
 int MV_NetConnect(int* rank, char* endpoint[], int size);
+void MV_NetClose(const char* endpoint);
+void MV_NetFinalize();
 
 }  // namespace multiverso
 
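For orientation, a minimal sketch of how the extended MV_Net* API above could be driven end to end. The endpoint strings, rank values, and header path are illustrative assumptions, not taken from this commit; the error checks assume the documented 0 SUCCESS / -1 FAIL convention.

// Hypothetical two-node setup; endpoints and ranks are made up for illustration.
#include <cstdio>
#include "multiverso/multiverso.h"  // assumed header location

int main() {
  char ep0[] = "127.0.0.1:5555";          // endpoint of rank 0
  char ep1[] = "127.0.0.1:5556";          // endpoint of rank 1
  char* endpoints[] = { ep0, ep1 };
  int ranks[] = { 0, 1 };

  int my_rank = 0;                        // pretend this process is rank 0
  if (multiverso::MV_NetBind(my_rank, ep0) != 0) {          // bind the local receiver
    std::printf("bind failed\n");
    return 1;
  }
  if (multiverso::MV_NetConnect(ranks, endpoints, 2) != 0) { // connect to all peers
    std::printf("connect failed\n");
    return 1;
  }

  // ... exchange messages ...

  multiverso::MV_NetClose(ep1);   // close a single endpoint
  multiverso::MV_NetFinalize();   // close whatever is still open
  return 0;
}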
@@ -26,6 +26,8 @@ public:
   // Connect with other endpoints
   virtual int Connect(int* rank, char* endpoints[], int size) = 0;
 
+  virtual void Close(const char* endpoint) = 0;
+
   virtual bool active() const = 0;
 
   virtual std::string name() const = 0;
@@ -139,6 +139,10 @@ public:
     return -1;
   }
 
+  void Close(const char* endpoint) override {
+    Log::Fatal("Shouldn't call this in MPI Net\n");
+  }
+
   bool active() const { return inited_ != 0; }
   int rank() const override { return rank_; }
   int size() const override { return size_; }
@@ -43,24 +43,28 @@ public:
     for (auto ip : machine_lists_) {
       if (local_ip.find(ip) != local_ip.end()) { // my rank
         rank_ = static_cast<int>(senders_.size());
-        senders_.push_back(nullptr);
-        receiver_ = zmq_socket(context_, ZMQ_DEALER);
-        int rc = zmq_bind(receiver_,
-          ("tcp://" + ip + ":" + std::to_string(port)).c_str());
+        Entity sender; sender.endpoint = ""; sender.socket = nullptr;
+        senders_.push_back(sender);
+        receiver_.socket = zmq_socket(context_, ZMQ_DEALER);
+        receiver_.endpoint = ip + ":" + std::to_string(port);
+        int rc = zmq_bind(receiver_.socket, ("tcp://" + receiver_.endpoint).c_str());
+        endpoint_to_socket_[receiver_.endpoint] = receiver_.socket;
         CHECK(rc == 0);
         int linger = 0;
-        CHECK(zmq_setsockopt(receiver_, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
+        CHECK(zmq_setsockopt(receiver_.socket, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
       } else {
-        void* sender = zmq_socket(context_, ZMQ_DEALER);
-        int rc = zmq_connect(sender,
-          ("tcp://" + ip + ":" + std::to_string(port)).c_str());
+        Entity sender;
+        sender.socket = zmq_socket(context_, ZMQ_DEALER);
+        sender.endpoint = ip + ":" + std::to_string(port);
+        int rc = zmq_connect(sender.socket, ("tcp://" + sender.endpoint).c_str());
+        endpoint_to_socket_[sender.endpoint] = sender.socket;
         CHECK(rc == 0);
         senders_.push_back(sender);
         int linger = 0;
-        CHECK(zmq_setsockopt(sender, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
+        CHECK(zmq_setsockopt(sender.socket, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
       }
     }
-    CHECK_NOTNULL(receiver_);
+    CHECK_NOTNULL(receiver_.socket);
     active_ = true;
     Log::Info("%s net util inited, rank = %d, size = %d\n",
       name().c_str(), rank(), size());
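The hunk above replaces bare void* sockets with an Entity that pairs each socket with its "ip:port" string and registers it in endpoint_to_socket_ so it can later be closed by name. A stand-alone sketch of that bookkeeping pattern in plain libzmq; the helper name make_sender and the free-standing map are illustrative, not part of the commit.

#include <string>
#include <unordered_map>
#include <zmq.h>

struct Entity {
  std::string endpoint;  // "ip:port", without the "tcp://" prefix
  void* socket;
};

static std::unordered_map<std::string, void*> endpoint_to_socket;

Entity make_sender(void* context, const std::string& ip, int port) {
  Entity e;
  e.endpoint = ip + ":" + std::to_string(port);
  e.socket = zmq_socket(context, ZMQ_DEALER);
  int linger = 0;  // do not block on close waiting for unsent messages
  zmq_setsockopt(e.socket, ZMQ_LINGER, &linger, sizeof(linger));
  zmq_connect(e.socket, ("tcp://" + e.endpoint).c_str());
  endpoint_to_socket[e.endpoint] = e.socket;  // enables closing by endpoint later
  return e;
}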
@@ -71,11 +75,13 @@ public:
     std::string ip_port(endpoint);
     if (context_ == nullptr) { context_ = zmq_ctx_new(); }
     CHECK_NOTNULL(context_);
-    receiver_ = zmq_socket(context_, ZMQ_DEALER);
-    int rc = zmq_bind(receiver_, ("tcp://" + ip_port).c_str());
+    receiver_.socket = zmq_socket(context_, ZMQ_DEALER);
+    receiver_.endpoint = ip_port;
+    int rc = zmq_bind(receiver_.socket, ("tcp://" + receiver_.endpoint).c_str());
+    endpoint_to_socket_[receiver_.endpoint] = receiver_.socket;
     if (rc == 0) {
       int linger = 0;
-      CHECK(zmq_setsockopt(receiver_, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
+      CHECK(zmq_setsockopt(receiver_.socket, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
       return 0;
     }
     else {
@@ -87,17 +93,23 @@ public:
 
   // Connect with other endpoints
   virtual int Connect(int* ranks, char* endpoints[], int size) override {
-    CHECK_NOTNULL(receiver_);
+    CHECK_NOTNULL(receiver_.socket);
     CHECK_NOTNULL(context_);
     size_ = size;
     senders_.resize(size_);
     for (int i = 0; i < size; ++i) {
       int rank = ranks[i];
-      if (rank == rank_) continue;
       std::string ip_port(endpoints[i]);
-      senders_[rank] = zmq_socket(context_, ZMQ_DEALER);
+      // if (rank == rank_) continue;
+      if (ip_port == receiver_.endpoint) {
+        rank_ = rank;
+        continue;
+      }
+      senders_[rank].socket = zmq_socket(context_, ZMQ_DEALER);
+      senders_[rank].endpoint = ip_port;
+      endpoint_to_socket_[senders_[rank].endpoint] = senders_[rank].socket;
       // NOTE(feiga): set linger to 0, otherwise will hang
-      int rc = zmq_connect(senders_[rank], ("tcp://" + ip_port).c_str());
+      int rc = zmq_connect(senders_[rank].socket, ("tcp://" + senders_[rank].endpoint).c_str());
       if (rc != 0) {
         Log::Error("Failed to connect the socket for sender, rank = %d, "
           "ip:port = %s\n", rank, endpoints[i]);
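A side effect of the change above is that a process now discovers its own rank by matching the bound receiver endpoint against the endpoint list, instead of trusting the caller-supplied rank check. A small sketch of that rule in isolation; function and variable names are illustrative.

#include <string>
#include <vector>

int DiscoverRank(const std::string& bound_endpoint,            // e.g. "10.0.0.1:5555"
                 const std::vector<std::string>& endpoints) {  // one entry per rank
  for (int rank = 0; rank < static_cast<int>(endpoints.size()); ++rank) {
    if (endpoints[rank] == bound_endpoint) {
      return rank;  // this process' own slot; no sender socket is created for it
    }
  }
  return -1;  // bound endpoint not found in the list
}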
@@ -108,24 +120,47 @@ public:
     return 0;
   }
 
-  void Finalize() override {
-    active_ = false;
-    for (int i = 0; i < senders_.size(); ++i) {
-      if (i != rank_) {
-        int linger = 0;
-        CHECK(zmq_setsockopt(senders_[i], ZMQ_LINGER, &linger, sizeof(linger)) == 0);
-        int rc = zmq_close(senders_[i]);
-        CHECK(rc == 0);
-      }
-    }
-    int linger = 0;
-    CHECK(zmq_setsockopt(receiver_, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
-    int rc = zmq_close(receiver_);
-    CHECK(rc == 0);
-
-    CHECK(zmq_ctx_shutdown(context_)==0);
-    // zmq_ctx_term(context_);
-    Log::Info("zmq finalize: close context\n");
+  void Close(const char* endpoint) override {
+    std::string str_endpoint(endpoint);
+    auto it = endpoint_to_socket_.find(str_endpoint);
+    if (it != endpoint_to_socket_.end()) {
+      Log::Info("Close endpoint %s\n", it->first.c_str());
+      CHECK(zmq_close(it->second) == 0);
+      endpoint_to_socket_.erase(it);
+      if (endpoint_to_socket_.empty()) {
+        // Term the context when the last endpoint is closed properly
+        zmq_ctx_term(context_);
+        Log::Info("ZMQ Finalize successfully\n");
+      }
+    }
+  }
+
+  void Finalize() override {
+    active_ = false;
+    for (auto entity : senders_) {
+      if (entity.socket != nullptr) {
+        zmq_disconnect(entity.socket, entity.endpoint.c_str());
+        Close(entity.endpoint.c_str());
+      }
+    }
+    zmq_unbind(receiver_.socket, receiver_.endpoint.c_str());
+    Close(receiver_.endpoint.c_str());
+    //for (int i = 0; i < senders_.size(); ++i) {
+    //  if (i != rank_) {
+    //    int linger = 0;
+    //    CHECK(zmq_setsockopt(senders_[i], ZMQ_LINGER, &linger, sizeof(linger)) == 0);
+    //    int rc = zmq_close(senders_[i]);
+    //    CHECK(rc == 0);
+    //  }
+    //}
+    //int linger = 0;
+    //CHECK(zmq_setsockopt(receiver_, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
+    //int rc = zmq_close(receiver_);
+    //CHECK(rc == 0);
+
+    //CHECK(zmq_ctx_shutdown(context_)==0);
+    //zmq_ctx_term(context_);
+    //Log::Info("zmq finalize: close context\n");
   }
 
   bool active() const override { return active_; }
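The new Close()/Finalize() pair above relies on the usual libzmq teardown order. A hedged, stand-alone sketch of that order; endpoints and names are illustrative, and note that zmq_disconnect()/zmq_unbind() expect the same full "tcp://ip:port" string that was passed to zmq_connect()/zmq_bind().

#include <zmq.h>

void Teardown(void* context, void* sender, void* receiver,
              const char* sender_ep,     /* e.g. "tcp://10.0.0.2:5555" */
              const char* receiver_ep) { /* e.g. "tcp://10.0.0.1:5555" */
  int linger = 0;  // otherwise zmq_close()/zmq_ctx_term() may block on pending data
  zmq_setsockopt(sender, ZMQ_LINGER, &linger, sizeof(linger));
  zmq_setsockopt(receiver, ZMQ_LINGER, &linger, sizeof(linger));

  zmq_disconnect(sender, sender_ep);    // stop outgoing traffic first
  zmq_close(sender);

  zmq_unbind(receiver, receiver_ep);    // then release the bound address
  zmq_close(receiver);

  zmq_ctx_term(context);                // only after every socket is closed
}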
@@ -136,7 +171,7 @@ public:
   size_t Send(MessagePtr& msg) override {
     size_t size = 0;
     int dst = msg->dst();
-    void* socket = senders_[dst];
+    void* socket = senders_[dst].socket;
     CHECK_NOTNULL(socket);
     int send_size;
     send_size = zmq_send(socket, msg->header(),
@@ -169,25 +204,25 @@ public:
     MessagePtr& msg = *msg_ptr;
     msg->data().clear();
     CHECK(msg.get());
-    recv_size = zmq_recv(receiver_, msg->header(), Message::kHeaderSize, ZMQ_DONTWAIT);
+    recv_size = zmq_recv(receiver_.socket, msg->header(), Message::kHeaderSize, ZMQ_DONTWAIT);
     if (recv_size < 0) { return -1; }
     CHECK(Message::kHeaderSize == recv_size);
 
     size += recv_size;
-    zmq_getsockopt(receiver_, ZMQ_RCVMORE, &more, &more_size);
+    zmq_getsockopt(receiver_.socket, ZMQ_RCVMORE, &more, &more_size);
 
     while (more) {
-      recv_size = zmq_recv(receiver_, &blob_size, sizeof(size_t), 0);
+      recv_size = zmq_recv(receiver_.socket, &blob_size, sizeof(size_t), 0);
       CHECK(recv_size == sizeof(size_t));
       size += recv_size;
-      zmq_getsockopt(receiver_, ZMQ_RCVMORE, &more, &more_size);
+      zmq_getsockopt(receiver_.socket, ZMQ_RCVMORE, &more, &more_size);
       CHECK(more);
       Blob blob(blob_size);
-      recv_size = zmq_recv(receiver_, blob.data(), blob.size(), 0);
+      recv_size = zmq_recv(receiver_.socket, blob.data(), blob.size(), 0);
       CHECK(recv_size == blob_size);
       size += recv_size;
       msg->Push(blob);
-      zmq_getsockopt(receiver_, ZMQ_RCVMORE, &more, &more_size);
+      zmq_getsockopt(receiver_.socket, ZMQ_RCVMORE, &more, &more_size);
     }
     return size;
   }
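Recv() above drains a multipart message frame by frame, asking ZMQ_RCVMORE after each frame whether another one is pending. A self-contained sketch of the same pattern with an assumed per-frame size bound and illustrative names.

#include <utility>
#include <vector>
#include <zmq.h>

// Returns the total number of bytes received, or -1 if no message was waiting.
long RecvAllFrames(void* socket, std::vector<std::vector<char>>* frames) {
  long total = 0;
  int more = 0;
  size_t more_size = sizeof(more);
  int flags = ZMQ_DONTWAIT;            // do not block if nothing has arrived yet
  do {
    std::vector<char> frame(1 << 16);  // assumed upper bound per frame
    int n = zmq_recv(socket, frame.data(), frame.size(), flags);
    if (n < 0) { return frames->empty() ? -1 : total; }
    // zmq_recv() may report a size larger than the buffer if the frame was truncated.
    int kept = n < static_cast<int>(frame.size()) ? n : static_cast<int>(frame.size());
    frame.resize(kept);
    frames->push_back(std::move(frame));
    total += kept;
    flags = 0;                         // remaining frames of the message are guaranteed to follow
    zmq_getsockopt(socket, ZMQ_RCVMORE, &more, &more_size);  // another frame pending?
  } while (more);
  return total;
}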
@@ -196,7 +231,7 @@ public:
   void SendTo(int rank, char* buf, int len) const override {
     int send_size = 0;
     while (send_size < len) {
-      int cur_size = zmq_send(senders_[rank], buf + send_size, len - send_size, 0);
+      int cur_size = zmq_send(senders_[rank].socket, buf + send_size, len - send_size, 0);
       if (cur_size < 0) { Log::Error("socket send error %d", cur_size); }
       send_size += cur_size;
     }
@@ -206,7 +241,7 @@ public:
     // note: rank is not used here
     int recv_size = 0;
     while (recv_size < len) {
-      int cur_size = zmq_recv(receiver_, buf + recv_size, len - recv_size, 0);
+      int cur_size = zmq_recv(receiver_.socket, buf + recv_size, len - recv_size, 0);
       if (cur_size < 0) { Log::Error("socket receive error %d", cur_size); }
       recv_size += cur_size;
     }
@@ -250,11 +285,22 @@ protected:
 
   bool active_;
   void* context_;
-  void* receiver_;
-  std::vector<void*> senders_;
+
+  struct Entity {
+    std::string endpoint;
+    void* socket;
+  };
+
+  Entity receiver_;
+  std::vector<Entity> senders_;
+
+  // void* receiver_;
+  // std::vector<void*> senders_;
 
   int rank_;
   int size_;
   std::vector<std::string> machine_lists_;
+  std::unordered_map<std::string, void*> endpoint_to_socket_;
 };
 }  // namespace multiverso
 
@@ -57,6 +57,14 @@ int MV_NetConnect(int* ranks, char* endpoints[], int size) {
   return NetInterface::Get()->Connect(ranks, endpoints, size);
 }
 
+void MV_NetClose(const char* endpoint) {
+  NetInterface::Get()->Close(endpoint);
+}
+
+void MV_NetFinalize() {
+  NetInterface::Get()->Finalize();
+}
+
 template void MV_Aggregate<char>(char*, int);
 template void MV_Aggregate<int>(int*, int);
 template void MV_Aggregate<float>(float*, int);