add interface to make C# wrapper works
This commit is contained in:
Родитель
fcca31a17f
Коммит
5c8e644f82
|
@ -108,7 +108,7 @@
|
|||
</PrecompiledHeader>
|
||||
<WarningLevel>Level3</WarningLevel>
|
||||
<Optimization>Disabled</Optimization>
|
||||
<PreprocessorDefinitions>MULTIVERSO_USE_MPI;WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<PreprocessorDefinitions>MULTIVERSO_USE_ZMQ;WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Windows</SubSystem>
|
||||
|
@ -126,7 +126,7 @@
|
|||
<Optimization>MaxSpeed</Optimization>
|
||||
<FunctionLevelLinking>true</FunctionLevelLinking>
|
||||
<IntrinsicFunctions>true</IntrinsicFunctions>
|
||||
<PreprocessorDefinitions>MULTIVERSO_USE_MPI;WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<PreprocessorDefinitions>MULTIVERSO_USE_ZMQ;WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Windows</SubSystem>
|
||||
|
|
|
@ -22,8 +22,31 @@ void MV_ShutDown(bool finalize_mpi = true);
|
|||
int MV_Rank();
|
||||
int MV_Size();
|
||||
|
||||
int MV_Num_Workers();
|
||||
int MV_Num_Servers();
|
||||
|
||||
int MV_Worker_Id();
|
||||
int MV_Server_Id();
|
||||
}
|
||||
|
||||
// --- Net API -------------------------------------------------------------- //
|
||||
// NOTE(feiga): these API is only used for specific situation.
|
||||
// Init Multiverso Net with the provided endpoint. Multiverso Net will bind
|
||||
// the provided endpoint and use this endpoint to listen and recv message
|
||||
// \param rank the rank of this MV process
|
||||
// \param endpoint endpoint with format ip:port, e.g., 127.0.0.1:9999
|
||||
// \return 0 SUCCESS
|
||||
// \return -1 FAIL
|
||||
int MV_Net_Bind(int rank, char* endpoint);
|
||||
|
||||
// Connect Multiverso Net with other processes in the system. Multiverso Net
|
||||
// will connect these endpoints and send msgs
|
||||
// \param ranks array of rank
|
||||
// \param endpoints endpoints for each rank
|
||||
// \param size size of the array
|
||||
// \return 0 SUCCESS
|
||||
// \return -1 FAIL
|
||||
int MV_Net_Connect(int* rank, char* endpoint[], int size);
|
||||
|
||||
} // namespace multiverso
|
||||
|
||||
#endif // MULTIVERSO_INCLUDE_MULTIVERSO_H_
|
|
@ -16,7 +16,15 @@ enum NetThreadLevel {
|
|||
class NetInterface {
|
||||
public:
|
||||
static NetInterface* Get();
|
||||
|
||||
virtual void Init(int* argc = nullptr, char** argv = nullptr) = 0;
|
||||
|
||||
// Bind with a specific endpoint
|
||||
virtual int Bind(int rank, char* endpoint) = 0;
|
||||
|
||||
// Connect with other endpoints
|
||||
virtual int Connect(int* rank, char* endpoints[], int size) = 0;
|
||||
|
||||
virtual void Finalize() = 0;
|
||||
|
||||
virtual std::string name() const = 0;
|
||||
|
|
|
@ -87,6 +87,14 @@ public:
|
|||
|
||||
void Finalize() override { MPI_Finalize(); }
|
||||
|
||||
int Bind(int rank, char* endpoint) override {
|
||||
Log::Fatal("Shouldn't call this in MPI Net\n");
|
||||
}
|
||||
|
||||
int Connect(int* ranks, char* endpoints[], int size) override {
|
||||
Log::Fatal("Shouldn't call this in MPI Net\n");
|
||||
}
|
||||
|
||||
int rank() const override { return rank_; }
|
||||
int size() const override { return size_; }
|
||||
std::string name() const override { return "MPI"; }
|
||||
|
|
|
@ -21,6 +21,7 @@ public:
|
|||
// argv[2]: port used
|
||||
void Init(int* argc, char** argv) override {
|
||||
// get machine file
|
||||
if (inited_) return;
|
||||
CHECK(*argc > 2);
|
||||
std::vector<std::string> machine_lists;
|
||||
ParseMachineFile(argv[1], &machine_lists);
|
||||
|
@ -36,28 +37,65 @@ public:
|
|||
|
||||
for (auto ip : machine_lists) {
|
||||
if (local_ip.find(ip) != local_ip.end()) { // my rank
|
||||
rank_ = static_cast<int>(requester_.size());
|
||||
requester_.push_back(nullptr);
|
||||
responder_ = zmq_socket(context_, ZMQ_DEALER);
|
||||
int rc = zmq_bind(responder_,
|
||||
rank_ = static_cast<int>(senders_.size());
|
||||
senders_.push_back(nullptr);
|
||||
receiver_ = zmq_socket(context_, ZMQ_DEALER);
|
||||
int rc = zmq_bind(receiver_,
|
||||
("tcp://" + ip + ":" + std::to_string(port)).c_str());
|
||||
CHECK(rc == 0);
|
||||
} else {
|
||||
void* requester = zmq_socket(context_, ZMQ_DEALER);
|
||||
int rc = zmq_connect(requester,
|
||||
void* senders = zmq_socket(context_, ZMQ_DEALER);
|
||||
int rc = zmq_connect(senders,
|
||||
("tcp://" + ip + ":" + std::to_string(port)).c_str());
|
||||
CHECK(rc == 0);
|
||||
requester_.push_back(requester);
|
||||
senders_.push_back(senders);
|
||||
}
|
||||
}
|
||||
CHECK_NOTNULL(responder_);
|
||||
CHECK_NOTNULL(receiver_);
|
||||
inited_ = true;
|
||||
Log::Info("%s net util inited, rank = %d, size = %d\n",
|
||||
name().c_str(), rank(), size());
|
||||
}
|
||||
|
||||
virtual int Bind(int rank, char* endpoint) override {
|
||||
rank_ = rank;
|
||||
std::string ip_port(endpoint);
|
||||
if (context_ == nullptr) { context_ = zmq_ctx_new(); }
|
||||
CHECK_NOTNULL(context_);
|
||||
receiver_ = zmq_socket(context_, ZMQ_DEALER);
|
||||
int rc = zmq_bind(receiver_, ("tcp://" + ip_port).c_str());
|
||||
if (rc == 0) return 0;
|
||||
else {
|
||||
Log::Error("Failed to bind the socket for receiver, ip:port = %s\n",
|
||||
endpoint);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// Connect with other endpoints
|
||||
virtual int Connect(int* ranks, char* endpoints[], int size) override {
|
||||
CHECK_NOTNULL(receiver_);
|
||||
CHECK_NOTNULL(context_);
|
||||
size_ = size + 1;
|
||||
senders_.resize(size_);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
int rank = ranks[i];
|
||||
std::string ip_port(endpoints[i]);
|
||||
senders_[rank] = zmq_socket(context_, ZMQ_DEALER);
|
||||
int rc = zmq_connect(senders_[rank], ("tcp://" + ip_port).c_str());
|
||||
if (rc != 0) {
|
||||
Log::Error("Failed to connect the socket for sender, rank = %d, "
|
||||
"ip:port = %s\n", rank, endpoints[i]);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
inited_ = true;
|
||||
return 0;
|
||||
}
|
||||
|
||||
void Finalize() override {
|
||||
zmq_close(responder_);
|
||||
for (auto& p : requester_) if (p) zmq_close(p);
|
||||
zmq_close(receiver_);
|
||||
for (auto& p : senders_) if (p) zmq_close(p);
|
||||
zmq_ctx_destroy(context_);
|
||||
}
|
||||
|
||||
|
@ -68,7 +106,7 @@ public:
|
|||
size_t Send(MessagePtr& msg) override {
|
||||
size_t size = 0;
|
||||
int dst = msg->dst();
|
||||
void* socket = requester_[dst];
|
||||
void* socket = senders_[dst];
|
||||
CHECK_NOTNULL(socket);
|
||||
int send_size;
|
||||
send_size = zmq_send(socket, msg->header(),
|
||||
|
@ -101,25 +139,25 @@ public:
|
|||
MessagePtr& msg = *msg_ptr;
|
||||
msg->data().clear();
|
||||
CHECK(msg.get());
|
||||
recv_size = zmq_recv(responder_, msg->header(), Message::kHeaderSize, 0);
|
||||
recv_size = zmq_recv(receiver_, msg->header(), Message::kHeaderSize, 0);
|
||||
if (recv_size < 0) { return -1; }
|
||||
CHECK(Message::kHeaderSize == recv_size);
|
||||
|
||||
size += recv_size;
|
||||
zmq_getsockopt(responder_, ZMQ_RCVMORE, &more, &more_size);
|
||||
zmq_getsockopt(receiver_, ZMQ_RCVMORE, &more, &more_size);
|
||||
|
||||
while (more) {
|
||||
recv_size = zmq_recv(responder_, &blob_size, sizeof(size_t), 0);
|
||||
recv_size = zmq_recv(receiver_, &blob_size, sizeof(size_t), 0);
|
||||
CHECK(recv_size == sizeof(size_t));
|
||||
size += recv_size;
|
||||
zmq_getsockopt(responder_, ZMQ_RCVMORE, &more, &more_size);
|
||||
zmq_getsockopt(receiver_, ZMQ_RCVMORE, &more, &more_size);
|
||||
CHECK(more);
|
||||
Blob blob(blob_size);
|
||||
recv_size = zmq_recv(responder_, blob.data(), blob.size(), 0);
|
||||
recv_size = zmq_recv(receiver_, blob.data(), blob.size(), 0);
|
||||
CHECK(recv_size == blob_size);
|
||||
size += recv_size;
|
||||
msg->Push(blob);
|
||||
zmq_getsockopt(responder_, ZMQ_RCVMORE, &more, &more_size);
|
||||
zmq_getsockopt(receiver_, ZMQ_RCVMORE, &more, &more_size);
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
@ -151,10 +189,10 @@ private:
|
|||
fclose(file);
|
||||
}
|
||||
|
||||
|
||||
bool inited_;
|
||||
void* context_;
|
||||
void* responder_;
|
||||
std::vector<void*> requester_;
|
||||
void* receiver_;
|
||||
std::vector<void*> senders_;
|
||||
int rank_;
|
||||
int size_;
|
||||
};
|
||||
|
|
|
@ -36,8 +36,8 @@ public:
|
|||
int size() const;
|
||||
|
||||
// TODO(to change)
|
||||
int worker_rank() const;
|
||||
int server_rank() const;
|
||||
int worker_rank() const { return nodes_[rank()].worker_id; }
|
||||
int server_rank() const { return nodes_[rank()].server_id; }
|
||||
|
||||
int num_workers() const { return num_workers_; }
|
||||
int num_servers() const { return num_servers_; }
|
||||
|
|
|
@ -70,7 +70,6 @@ void Communicator::ProcessMessage(MessagePtr& msg) {
|
|||
if (msg->dst() != net_util_->rank()) {
|
||||
// Log::Debug("Send a msg from %d to %d, type = %d\n", msg->src(), msg->dst(), msg->type());
|
||||
net_util_->Send(msg);
|
||||
CHECK(msg.get() == nullptr)
|
||||
return;
|
||||
}
|
||||
LocalForward(msg);
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#include "multiverso/multiverso.h"
|
||||
|
||||
#include "multiverso/net.h"
|
||||
#include "multiverso/zoo.h"
|
||||
|
||||
namespace multiverso {
|
||||
|
@ -12,12 +13,32 @@ void MV_ShutDown(bool finalize_net) {
|
|||
Zoo::Get()->Stop(finalize_net);
|
||||
}
|
||||
|
||||
void MV_Barrier() {
|
||||
Zoo::Get()->Barrier();
|
||||
void MV_Barrier() { Zoo::Get()->Barrier(); }
|
||||
|
||||
int MV_Rank() { return Zoo::Get()->rank(); }
|
||||
|
||||
int MV_Size() { return Zoo::Get()->size(); }
|
||||
|
||||
int MV_Worker_Id() {
|
||||
return Zoo::Get()->worker_rank();
|
||||
}
|
||||
int MV_Server_Id() {
|
||||
return Zoo::Get()->server_rank();
|
||||
}
|
||||
|
||||
int MV_Rank() {
|
||||
return Zoo::Get()->rank();
|
||||
int MV_Num_Workers() {
|
||||
return Zoo::Get()->num_workers();
|
||||
}
|
||||
int MV_Num_Servers() {
|
||||
return Zoo::Get()->num_servers();
|
||||
}
|
||||
|
||||
int MV_Net_Bind(int rank, char* endpoint) {
|
||||
return NetInterface::Get()->Bind(rank, endpoint);
|
||||
}
|
||||
|
||||
int MV_Net_Connect(int* ranks, char* endpoints[], int size) {
|
||||
return NetInterface::Get()->Connect(ranks, endpoints, size);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче