Add interface to make the C# wrapper work

This commit is contained in:
feiga 2016-03-03 22:03:56 +08:00
Родитель fcca31a17f
Коммит 5c8e644f82
8 изменённых файлов: 127 добавлений и 30 удалений

Просмотреть файл

@ -108,7 +108,7 @@
</PrecompiledHeader>
<WarningLevel>Level3</WarningLevel>
<Optimization>Disabled</Optimization>
<PreprocessorDefinitions>MULTIVERSO_USE_MPI;WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PreprocessorDefinitions>MULTIVERSO_USE_ZMQ;WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
@ -126,7 +126,7 @@
<Optimization>MaxSpeed</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<PreprocessorDefinitions>MULTIVERSO_USE_MPI;WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PreprocessorDefinitions>MULTIVERSO_USE_ZMQ;WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>

Просмотреть файл

@ -22,8 +22,31 @@ void MV_ShutDown(bool finalize_mpi = true);
int MV_Rank();
int MV_Size();
int MV_Num_Workers();
int MV_Num_Servers();
int MV_Worker_Id();
int MV_Server_Id();
}
// --- Net API -------------------------------------------------------------- //
// NOTE(feiga): these APIs are only used in specific situations.
// Initializes Multiverso Net with the provided endpoint. Multiverso Net binds
// to this endpoint and uses it to listen for and receive messages.
// \param rank the rank of this MV process
// \param endpoint endpoint in the format ip:port, e.g., 127.0.0.1:9999
// \return 0 SUCCESS
// \return -1 FAIL
int MV_Net_Bind(int rank, char* endpoint);
// Connects Multiverso Net to the other processes in the system. Multiverso
// Net connects to these endpoints and uses them to send messages.
// \param ranks array of ranks, one entry per remote process
// \param endpoints array of endpoints (ip:port); endpoints[i] belongs to
//        ranks[i]
// \param size number of entries in each of the two arrays
// \return 0 SUCCESS
// \return -1 FAIL
int MV_Net_Connect(int* ranks, char* endpoints[], int size);
} // namespace multiverso
#endif // MULTIVERSO_INCLUDE_MULTIVERSO_H_

Просмотреть файл

@ -16,7 +16,15 @@ enum NetThreadLevel {
class NetInterface {
public:
static NetInterface* Get();
virtual void Init(int* argc = nullptr, char** argv = nullptr) = 0;
// Bind this net instance, acting as `rank`, to a specific endpoint.
// NOTE(review): the return value presumably follows the 0 = success /
// -1 = failure convention of MV_Net_Bind -- confirm per implementation.
virtual int Bind(int rank, char* endpoint) = 0;
// Connect with other endpoints. `rank` and `endpoints` are arrays and `size`
// is the number of entries to use from them.
// NOTE(review): the return value presumably follows the 0 = success /
// -1 = failure convention of MV_Net_Connect -- confirm per implementation.
virtual int Connect(int* rank, char* endpoints[], int size) = 0;
virtual void Finalize() = 0;
virtual std::string name() const = 0;

Просмотреть файл

@ -87,6 +87,14 @@ public:
void Finalize() override { MPI_Finalize(); }
// Bind is not supported by the MPI net; calling it reports a fatal error.
// \return -1 (FAIL) if Log::Fatal returns instead of terminating. The explicit
// return keeps this non-void function from flowing off its end, which would
// be undefined behavior (and is rejected by MSVC as C4716).
int Bind(int rank, char* endpoint) override {
  Log::Fatal("Shouldn't call this in MPI Net\n");
  return -1;
}
// Connect is not supported by the MPI net; calling it reports a fatal error.
// \return -1 (FAIL) if Log::Fatal returns instead of terminating. The explicit
// return keeps this non-void function from flowing off its end, which would
// be undefined behavior (and is rejected by MSVC as C4716).
int Connect(int* ranks, char* endpoints[], int size) override {
  Log::Fatal("Shouldn't call this in MPI Net\n");
  return -1;
}
int rank() const override { return rank_; }
int size() const override { return size_; }
std::string name() const override { return "MPI"; }

Просмотреть файл

@ -21,6 +21,7 @@ public:
// argv[2]: port used
void Init(int* argc, char** argv) override {
// get machine file
if (inited_) return;
CHECK(*argc > 2);
std::vector<std::string> machine_lists;
ParseMachineFile(argv[1], &machine_lists);
@ -36,28 +37,65 @@ public:
for (auto ip : machine_lists) {
if (local_ip.find(ip) != local_ip.end()) { // my rank
rank_ = static_cast<int>(requester_.size());
requester_.push_back(nullptr);
responder_ = zmq_socket(context_, ZMQ_DEALER);
int rc = zmq_bind(responder_,
rank_ = static_cast<int>(senders_.size());
senders_.push_back(nullptr);
receiver_ = zmq_socket(context_, ZMQ_DEALER);
int rc = zmq_bind(receiver_,
("tcp://" + ip + ":" + std::to_string(port)).c_str());
CHECK(rc == 0);
} else {
void* requester = zmq_socket(context_, ZMQ_DEALER);
int rc = zmq_connect(requester,
void* senders = zmq_socket(context_, ZMQ_DEALER);
int rc = zmq_connect(senders,
("tcp://" + ip + ":" + std::to_string(port)).c_str());
CHECK(rc == 0);
requester_.push_back(requester);
senders_.push_back(senders);
}
}
CHECK_NOTNULL(responder_);
CHECK_NOTNULL(receiver_);
inited_ = true;
Log::Info("%s net util inited, rank = %d, size = %d\n",
name().c_str(), rank(), size());
}
virtual int Bind(int rank, char* endpoint) override {
rank_ = rank;
std::string ip_port(endpoint);
if (context_ == nullptr) { context_ = zmq_ctx_new(); }
CHECK_NOTNULL(context_);
receiver_ = zmq_socket(context_, ZMQ_DEALER);
int rc = zmq_bind(receiver_, ("tcp://" + ip_port).c_str());
if (rc == 0) return 0;
else {
Log::Error("Failed to bind the socket for receiver, ip:port = %s\n",
endpoint);
return -1;
}
}
// Connect with other endpoints
virtual int Connect(int* ranks, char* endpoints[], int size) override {
CHECK_NOTNULL(receiver_);
CHECK_NOTNULL(context_);
size_ = size + 1;
senders_.resize(size_);
for (int i = 0; i < size; ++i) {
int rank = ranks[i];
std::string ip_port(endpoints[i]);
senders_[rank] = zmq_socket(context_, ZMQ_DEALER);
int rc = zmq_connect(senders_[rank], ("tcp://" + ip_port).c_str());
if (rc != 0) {
Log::Error("Failed to connect the socket for sender, rank = %d, "
"ip:port = %s\n", rank, endpoints[i]);
return -1;
}
}
inited_ = true;
return 0;
}
// Releases the ZMQ resources held by this net: closes the receiving socket
// and every non-null sending socket, then destroys the context.
void Finalize() override {
zmq_close(responder_);
for (auto& p : requester_) if (p) zmq_close(p);
zmq_close(receiver_);
for (auto& p : senders_) if (p) zmq_close(p);
zmq_ctx_destroy(context_);
}
@ -68,7 +106,7 @@ public:
size_t Send(MessagePtr& msg) override {
size_t size = 0;
int dst = msg->dst();
void* socket = requester_[dst];
void* socket = senders_[dst];
CHECK_NOTNULL(socket);
int send_size;
send_size = zmq_send(socket, msg->header(),
@ -101,25 +139,25 @@ public:
MessagePtr& msg = *msg_ptr;
msg->data().clear();
CHECK(msg.get());
recv_size = zmq_recv(responder_, msg->header(), Message::kHeaderSize, 0);
recv_size = zmq_recv(receiver_, msg->header(), Message::kHeaderSize, 0);
if (recv_size < 0) { return -1; }
CHECK(Message::kHeaderSize == recv_size);
size += recv_size;
zmq_getsockopt(responder_, ZMQ_RCVMORE, &more, &more_size);
zmq_getsockopt(receiver_, ZMQ_RCVMORE, &more, &more_size);
while (more) {
recv_size = zmq_recv(responder_, &blob_size, sizeof(size_t), 0);
recv_size = zmq_recv(receiver_, &blob_size, sizeof(size_t), 0);
CHECK(recv_size == sizeof(size_t));
size += recv_size;
zmq_getsockopt(responder_, ZMQ_RCVMORE, &more, &more_size);
zmq_getsockopt(receiver_, ZMQ_RCVMORE, &more, &more_size);
CHECK(more);
Blob blob(blob_size);
recv_size = zmq_recv(responder_, blob.data(), blob.size(), 0);
recv_size = zmq_recv(receiver_, blob.data(), blob.size(), 0);
CHECK(recv_size == blob_size);
size += recv_size;
msg->Push(blob);
zmq_getsockopt(responder_, ZMQ_RCVMORE, &more, &more_size);
zmq_getsockopt(receiver_, ZMQ_RCVMORE, &more, &more_size);
}
return size;
}
@ -151,10 +189,10 @@ private:
fclose(file);
}
bool inited_;
void* context_;
void* responder_;
std::vector<void*> requester_;
void* receiver_;
std::vector<void*> senders_;
int rank_;
int size_;
};

Просмотреть файл

@ -36,8 +36,8 @@ public:
int size() const;
// TODO(to change)
int worker_rank() const;
int server_rank() const;
// Returns the worker_id stored in nodes_ for this process (entry rank()).
int worker_rank() const { return nodes_[rank()].worker_id; }
// Returns the server_id stored in nodes_ for this process (entry rank()).
int server_rank() const { return nodes_[rank()].server_id; }
// Returns the cached num_workers_ value.
int num_workers() const { return num_workers_; }
// Returns the cached num_servers_ value.
int num_servers() const { return num_servers_; }

Просмотреть файл

@ -70,7 +70,6 @@ void Communicator::ProcessMessage(MessagePtr& msg) {
if (msg->dst() != net_util_->rank()) {
// Log::Debug("Send a msg from %d to %d, type = %d\n", msg->src(), msg->dst(), msg->type());
net_util_->Send(msg);
CHECK(msg.get() == nullptr)
return;
}
LocalForward(msg);

Просмотреть файл

@ -1,5 +1,6 @@
#include "multiverso/multiverso.h"
#include "multiverso/net.h"
#include "multiverso/zoo.h"
namespace multiverso {
@ -12,12 +13,32 @@ void MV_ShutDown(bool finalize_net) {
Zoo::Get()->Stop(finalize_net);
}
void MV_Barrier() {
Zoo::Get()->Barrier();
// Forwards to the barrier of the Zoo singleton.
void MV_Barrier() {
  auto&& zoo = Zoo::Get();
  zoo->Barrier();
}
// Returns the rank reported by the Zoo singleton.
int MV_Rank() {
  auto&& zoo = Zoo::Get();
  return zoo->rank();
}
// Returns the size reported by the Zoo singleton.
int MV_Size() {
  auto&& zoo = Zoo::Get();
  return zoo->size();
}
// Returns the worker rank of this process as reported by the Zoo singleton.
int MV_Worker_Id() {
  auto&& zoo = Zoo::Get();
  const int worker_id = zoo->worker_rank();
  return worker_id;
}
// Returns the server rank of this process as reported by the Zoo singleton.
int MV_Server_Id() {
  auto&& zoo = Zoo::Get();
  const int server_id = zoo->server_rank();
  return server_id;
}
int MV_Rank() {
return Zoo::Get()->rank();
// Returns the number of workers reported by the Zoo singleton.
int MV_Num_Workers() {
  auto&& zoo = Zoo::Get();
  const int worker_count = zoo->num_workers();
  return worker_count;
}
// Returns the number of servers reported by the Zoo singleton.
int MV_Num_Servers() {
  auto&& zoo = Zoo::Get();
  const int server_count = zoo->num_servers();
  return server_count;
}
int MV_Net_Bind(int rank, char* endpoint) {
return NetInterface::Get()->Bind(rank, endpoint);
}
int MV_Net_Connect(int* ranks, char* endpoints[], int size) {
return NetInterface::Get()->Connect(ranks, endpoints, size);
}
}