This commit is contained in:
feiga 2016-06-30 16:19:08 +08:00
Родитель 4bda014f0b
Коммит 798353f799
5 изменённых файлов: 105 добавлений и 43 удалений

Просмотреть файл

@ -58,6 +58,8 @@ int MV_NetBind(int rank, char* endpoint);
// \return 0 SUCCESS
// \return -1 FAIL
int MV_NetConnect(int* rank, char* endpoint[], int size);
// \brief Close the connection registered for the given endpoint.
//        Forwards to the active net interface's Close().
// \param endpoint endpoint string; presumably the same "ip:port" text that
//        was passed to MV_NetBind / MV_NetConnect -- TODO confirm
void MV_NetClose(const char* endpoint);
// \brief Finalize the active net interface (forwards to its Finalize()).
void MV_NetFinalize();
} // namespace multiverso

Просмотреть файл

@ -26,6 +26,8 @@ public:
// Connect with other endpoints
virtual int Connect(int* rank, char* endpoints[], int size) = 0;
// Close the connection associated with the given endpoint. Backends that do
// not support per-endpoint close may treat a call as an error.
virtual void Close(const char* endpoint) = 0;
// Whether the net backend is currently initialized and usable
virtual bool active() const = 0;
// Name of the backend; used in log messages
virtual std::string name() const = 0;

Просмотреть файл

@ -139,6 +139,10 @@ public:
return -1;
}
// Per-endpoint close is not supported by the MPI backend: any call is
// reported as a fatal error. The parameter list mirrors the interface's
// Close(const char* endpoint) so that `override` actually overrides it;
// the argument itself is intentionally unused.
void Close(const char* /*endpoint*/) override {
  Log::Fatal("Shouldn't call this in MPI Net\n");
}
bool active() const { return inited_ != 0; }
// Rank stored for this process.
int rank() const override {
  return rank_;
}
// Number of participants stored for this net.
int size() const override {
  return size_;
}

Просмотреть файл

@ -43,24 +43,28 @@ public:
for (auto ip : machine_lists_) {
if (local_ip.find(ip) != local_ip.end()) { // my rank
rank_ = static_cast<int>(senders_.size());
senders_.push_back(nullptr);
receiver_ = zmq_socket(context_, ZMQ_DEALER);
int rc = zmq_bind(receiver_,
("tcp://" + ip + ":" + std::to_string(port)).c_str());
Entity sender; sender.endpoint = ""; sender.socket = nullptr;
senders_.push_back(sender);
receiver_.socket = zmq_socket(context_, ZMQ_DEALER);
receiver_.endpoint = ip + ":" + std::to_string(port);
int rc = zmq_bind(receiver_.socket, ("tcp://" + receiver_.endpoint).c_str());
endpoint_to_socket_[receiver_.endpoint] = receiver_.socket;
CHECK(rc == 0);
int linger = 0;
CHECK(zmq_setsockopt(receiver_, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
CHECK(zmq_setsockopt(receiver_.socket, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
} else {
void* sender = zmq_socket(context_, ZMQ_DEALER);
int rc = zmq_connect(sender,
("tcp://" + ip + ":" + std::to_string(port)).c_str());
Entity sender;
sender.socket = zmq_socket(context_, ZMQ_DEALER);
sender.endpoint = ip + ":" + std::to_string(port);
int rc = zmq_connect(sender.socket, ("tcp://" + sender.endpoint).c_str());
endpoint_to_socket_[sender.endpoint] = sender.socket;
CHECK(rc == 0);
senders_.push_back(sender);
int linger = 0;
CHECK(zmq_setsockopt(sender, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
CHECK(zmq_setsockopt(sender.socket, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
}
}
CHECK_NOTNULL(receiver_);
CHECK_NOTNULL(receiver_.socket);
active_ = true;
Log::Info("%s net util inited, rank = %d, size = %d\n",
name().c_str(), rank(), size());
@ -71,11 +75,13 @@ public:
std::string ip_port(endpoint);
if (context_ == nullptr) { context_ = zmq_ctx_new(); }
CHECK_NOTNULL(context_);
receiver_ = zmq_socket(context_, ZMQ_DEALER);
int rc = zmq_bind(receiver_, ("tcp://" + ip_port).c_str());
receiver_.socket = zmq_socket(context_, ZMQ_DEALER);
receiver_.endpoint = ip_port;
int rc = zmq_bind(receiver_.socket, ("tcp://" + receiver_.endpoint).c_str());
endpoint_to_socket_[receiver_.endpoint] = receiver_.socket;
if (rc == 0) {
int linger = 0;
CHECK(zmq_setsockopt(receiver_, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
CHECK(zmq_setsockopt(receiver_.socket, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
return 0;
}
else {
@ -87,17 +93,23 @@ public:
// Connect with other endpoints
virtual int Connect(int* ranks, char* endpoints[], int size) override {
CHECK_NOTNULL(receiver_);
CHECK_NOTNULL(receiver_.socket);
CHECK_NOTNULL(context_);
size_ = size;
senders_.resize(size_);
for (int i = 0; i < size; ++i) {
int rank = ranks[i];
if (rank == rank_) continue;
std::string ip_port(endpoints[i]);
senders_[rank] = zmq_socket(context_, ZMQ_DEALER);
// if (rank == rank_) continue;
if (ip_port == receiver_.endpoint) {
rank_ = rank;
continue;
}
senders_[rank].socket = zmq_socket(context_, ZMQ_DEALER);
senders_[rank].endpoint = ip_port;
endpoint_to_socket_[senders_[rank].endpoint] = senders_[rank].socket;
// NOTE(feiga): set linger to 0, otherwise will hang
int rc = zmq_connect(senders_[rank], ("tcp://" + ip_port).c_str());
int rc = zmq_connect(senders_[rank].socket, ("tcp://" + senders_[rank].endpoint).c_str());
if (rc != 0) {
Log::Error("Failed to connect the socket for sender, rank = %d, "
"ip:port = %s\n", rank, endpoints[i]);
@ -108,24 +120,47 @@ public:
return 0;
}
void Finalize() override {
active_ = false;
for (int i = 0; i < senders_.size(); ++i) {
if (i != rank_) {
int linger = 0;
CHECK(zmq_setsockopt(senders_[i], ZMQ_LINGER, &linger, sizeof(linger)) == 0);
int rc = zmq_close(senders_[i]);
CHECK(rc == 0);
void Close(const char* endpoint) override {
std::string str_endpoint(endpoint);
auto it = endpoint_to_socket_.find(str_endpoint);
if (it != endpoint_to_socket_.end()) {
Log::Info("Close endpoint %s\n", it->first.c_str());
CHECK(zmq_close(it->second) == 0);
endpoint_to_socket_.erase(it);
if (endpoint_to_socket_.empty()) {
// Term the context when the last endpoint is closed properly
zmq_ctx_term(context_);
Log::Info("ZMQ Finalize sucessfully\n");
}
}
int linger = 0;
CHECK(zmq_setsockopt(receiver_, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
int rc = zmq_close(receiver_);
CHECK(rc == 0);
}
CHECK(zmq_ctx_shutdown(context_)==0);
// zmq_ctx_term(context_);
Log::Info("zmq finalize: close context\n");
void Finalize() override {
active_ = false;
for (auto entity : senders_) {
if (entity.socket != nullptr) {
zmq_disconnect(entity.socket, entity.endpoint.c_str());
Close(entity.endpoint.c_str());
}
}
zmq_unbind(receiver_.socket, receiver_.endpoint.c_str());
Close(receiver_.endpoint.c_str());
//for (int i = 0; i < senders_.size(); ++i) {
// if (i != rank_) {
// int linger = 0;
// CHECK(zmq_setsockopt(senders_[i], ZMQ_LINGER, &linger, sizeof(linger)) == 0);
// int rc = zmq_close(senders_[i]);
// CHECK(rc == 0);
// }
//}
//int linger = 0;
//CHECK(zmq_setsockopt(receiver_, ZMQ_LINGER, &linger, sizeof(linger)) == 0);
//int rc = zmq_close(receiver_);
//CHECK(rc == 0);
//CHECK(zmq_ctx_shutdown(context_)==0);
//zmq_ctx_term(context_);
//Log::Info("zmq finalize: close context\n");
}
bool active() const override { return active_; }
@ -136,7 +171,7 @@ public:
size_t Send(MessagePtr& msg) override {
size_t size = 0;
int dst = msg->dst();
void* socket = senders_[dst];
void* socket = senders_[dst].socket;
CHECK_NOTNULL(socket);
int send_size;
send_size = zmq_send(socket, msg->header(),
@ -169,25 +204,25 @@ public:
MessagePtr& msg = *msg_ptr;
msg->data().clear();
CHECK(msg.get());
recv_size = zmq_recv(receiver_, msg->header(), Message::kHeaderSize, ZMQ_DONTWAIT);
recv_size = zmq_recv(receiver_.socket, msg->header(), Message::kHeaderSize, ZMQ_DONTWAIT);
if (recv_size < 0) { return -1; }
CHECK(Message::kHeaderSize == recv_size);
size += recv_size;
zmq_getsockopt(receiver_, ZMQ_RCVMORE, &more, &more_size);
zmq_getsockopt(receiver_.socket, ZMQ_RCVMORE, &more, &more_size);
while (more) {
recv_size = zmq_recv(receiver_, &blob_size, sizeof(size_t), 0);
recv_size = zmq_recv(receiver_.socket, &blob_size, sizeof(size_t), 0);
CHECK(recv_size == sizeof(size_t));
size += recv_size;
zmq_getsockopt(receiver_, ZMQ_RCVMORE, &more, &more_size);
zmq_getsockopt(receiver_.socket, ZMQ_RCVMORE, &more, &more_size);
CHECK(more);
Blob blob(blob_size);
recv_size = zmq_recv(receiver_, blob.data(), blob.size(), 0);
recv_size = zmq_recv(receiver_.socket, blob.data(), blob.size(), 0);
CHECK(recv_size == blob_size);
size += recv_size;
msg->Push(blob);
zmq_getsockopt(receiver_, ZMQ_RCVMORE, &more, &more_size);
zmq_getsockopt(receiver_.socket, ZMQ_RCVMORE, &more, &more_size);
}
return size;
}
@ -196,7 +231,7 @@ public:
void SendTo(int rank, char* buf, int len) const override {
int send_size = 0;
while (send_size < len) {
int cur_size = zmq_send(senders_[rank], buf + send_size, len - send_size, 0);
int cur_size = zmq_send(senders_[rank].socket, buf + send_size, len - send_size, 0);
if (cur_size < 0) { Log::Error("socket send error %d", cur_size); }
send_size += cur_size;
}
@ -206,7 +241,7 @@ public:
// note: rank is not used here
int recv_size = 0;
while (recv_size < len) {
int cur_size = zmq_recv(receiver_, buf + recv_size, len - recv_size, 0);
int cur_size = zmq_recv(receiver_.socket, buf + recv_size, len - recv_size, 0);
if (cur_size < 0) { Log::Error("socket receive error %d", cur_size); }
recv_size += cur_size;
}
@ -250,11 +285,22 @@ protected:
bool active_;
void* context_;
void* receiver_;
std::vector<void*> senders_;
// A ZMQ socket together with the endpoint it is bound or connected to.
struct Entity {
  // Address in "ip:port" form; the "tcp://" prefix is added at the call site.
  std::string endpoint;
  // Socket handle returned by zmq_socket. Defaults to null so that a
  // default-constructed Entity is recognizably "no socket" instead of
  // carrying an indeterminate pointer.
  void* socket = nullptr;
};
Entity receiver_;
std::vector<Entity> senders_;
// void* receiver_;
// std::vector<void*> senders_;
int rank_;
int size_;
std::vector<std::string> machine_lists_;
std::unordered_map<std::string, void*> endpoint_to_socket_;
};
} // namespace multiverso

Просмотреть файл

@ -57,6 +57,14 @@ int MV_NetConnect(int* ranks, char* endpoints[], int size) {
return NetInterface::Get()->Connect(ranks, endpoints, size);
}
// Ask the active net interface to close the given endpoint.
void MV_NetClose(const char* endpoint) {
  auto&& net = NetInterface::Get();
  net->Close(endpoint);
}
// Ask the active net interface to finalize itself.
void MV_NetFinalize() {
  auto&& net = NetInterface::Get();
  net->Finalize();
}
template void MV_Aggregate<char>(char*, int);
template void MV_Aggregate<int>(int*, int);
template void MV_Aggregate<float>(float*, int);