MPI thread multiple and serialized mode failed. Now using MPI in a single thread
This commit is contained in:
Родитель
9d1711da31
Коммит
aee379d826
|
@ -15,8 +15,8 @@ Global
|
|||
GlobalSection(ProjectConfigurationPlatforms) = postSolution
|
||||
{F2DD7153-E4FB-4CA8-9B9E-CB6AB025CCBA}.debug|x64.ActiveCfg = debug|x64
|
||||
{F2DD7153-E4FB-4CA8-9B9E-CB6AB025CCBA}.debug|x64.Build.0 = debug|x64
|
||||
{F2DD7153-E4FB-4CA8-9B9E-CB6AB025CCBA}.release|x64.ActiveCfg = debug|x64
|
||||
{F2DD7153-E4FB-4CA8-9B9E-CB6AB025CCBA}.release|x64.Build.0 = debug|x64
|
||||
{F2DD7153-E4FB-4CA8-9B9E-CB6AB025CCBA}.release|x64.ActiveCfg = release|x64
|
||||
{F2DD7153-E4FB-4CA8-9B9E-CB6AB025CCBA}.release|x64.Build.0 = release|x64
|
||||
{546681D6-495C-4AEE-BBC2-3CAEC86B5137}.debug|x64.ActiveCfg = debug|x64
|
||||
{546681D6-495C-4AEE-BBC2-3CAEC86B5137}.release|x64.ActiveCfg = release|x64
|
||||
{546681D6-495C-4AEE-BBC2-3CAEC86B5137}.release|x64.Build.0 = release|x64
|
||||
|
|
|
@ -74,27 +74,27 @@ void TestArray() {
|
|||
MultiversoBarrier();
|
||||
Log::Info("Create tables OK\n");
|
||||
|
||||
for (int i = 0; i < 100000; ++i) {
|
||||
// std::vector<float>& vec = shared_array->raw();
|
||||
|
||||
// shared_array->Get();
|
||||
float data[10];
|
||||
shared_array->Get(data, 10);
|
||||
float data[10];
|
||||
|
||||
Log::Info("Get OK\n");
|
||||
std::vector<float> delta(10);
|
||||
for (int i = 0; i < 10; ++i)
|
||||
delta[i] = static_cast<float>(i);
|
||||
|
||||
for (int i = 0; i < 10; ++i) std::cout << data[i] << " "; std::cout << std::endl;
|
||||
shared_array->Add(delta.data(), 10);
|
||||
|
||||
std::vector<float> delta(10);
|
||||
for (int i = 0; i < 10; ++i) delta[i] = static_cast<float>(i);
|
||||
Log::Info("Rank %d Add OK\n", MultiversoRank());
|
||||
|
||||
shared_array->Add(delta.data(), 10);
|
||||
|
||||
Log::Info("Add OK\n");
|
||||
|
||||
shared_array->Get(data, 10);
|
||||
|
||||
for (int i = 0; i < 10; ++i) std::cout << data[i] << " "; std::cout << std::endl;
|
||||
shared_array->Get(data, 10);
|
||||
Log::Info("Rank %d Get OK\n", MultiversoRank());
|
||||
for (int i = 0; i < 10; ++i)
|
||||
std::cout << data[i] << " "; std::cout << std::endl;
|
||||
MultiversoBarrier();
|
||||
|
||||
}
|
||||
MultiversoShutDown();
|
||||
}
|
||||
|
||||
|
@ -122,6 +122,7 @@ void TestNet() {
|
|||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
// Log::ResetLogLevel(LogLevel::Debug);
|
||||
if (argc == 2) {
|
||||
if (strcmp(argv[1], "kv") == 0) TestKV();
|
||||
else if (strcmp(argv[1], "array") == 0) TestArray();
|
||||
|
|
|
@ -17,6 +17,7 @@ void MultiversoBarrier();
|
|||
|
||||
void MultiversoShutDown(bool finalize_mpi = true);
|
||||
|
||||
int MultiversoRank();
|
||||
}
|
||||
|
||||
#endif // MULTIVERSO_INCLUDE_MULTIVERSO_H_
|
|
@ -3,6 +3,7 @@
|
|||
#ifndef MULTIVERSO_MT_QUEUE_H_
|
||||
#define MULTIVERSO_MT_QUEUE_H_
|
||||
|
||||
#include <atomic>
|
||||
#include <queue>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
|
@ -16,7 +17,7 @@ template<typename T>
|
|||
class MtQueue {
|
||||
public:
|
||||
/*! \brief Constructor */
|
||||
MtQueue() : exit_(false) {}
|
||||
MtQueue() { exit_.store(false); }
|
||||
|
||||
/*!
|
||||
* \brief Push an element into the queue. the function is based on
|
||||
|
@ -60,13 +61,16 @@ public:
|
|||
/*! \brief Exit queue, awake all threads blocked by the queue */
|
||||
void Exit();
|
||||
|
||||
bool Alive();
|
||||
|
||||
private:
|
||||
/*! the underlying container of queue */
|
||||
std::queue<T> buffer_;
|
||||
mutable std::mutex mutex_;
|
||||
std::condition_variable empty_condition_;
|
||||
/*! whether the queue is still work */
|
||||
bool exit_;
|
||||
std::atomic_bool exit_;
|
||||
// bool exit_;
|
||||
|
||||
// No copying allowed
|
||||
MtQueue(const MtQueue&);
|
||||
|
@ -126,9 +130,14 @@ bool MtQueue<T>::Empty() const {
|
|||
template<typename T>
|
||||
void MtQueue<T>::Exit() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
exit_ = true;
|
||||
exit_.store(true);
|
||||
empty_condition_.notify_all();
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
bool MtQueue<T>::Alive() {
|
||||
return exit_ == false;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // MULTIVERSO_MT_QUEUE_H_
|
|
@ -42,7 +42,7 @@ void Actor::Main() {
|
|||
}
|
||||
|
||||
void Actor::DeliverTo(const std::string& dst_name, MessagePtr& msg) {
|
||||
Log::Debug("actors delivering msg (type = %d) from (rank = %d, %s) to (rank = %d, %s).\n", msg->type(), msg->src(), name().c_str(), msg->dst(), dst_name.c_str());
|
||||
// Log::Debug("actors delivering msg (type = %d) from (rank = %d, %s) to (rank = %d, %s).\n", msg->type(), msg->src(), name().c_str(), msg->dst(), dst_name.c_str());
|
||||
Zoo::Get()->Deliver(dst_name, msg);
|
||||
}
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
#include "multiverso/zoo.h"
|
||||
#include "multiverso/net.h"
|
||||
#include "multiverso/util/log.h"
|
||||
#include "multiverso/util/mt_queue.h"
|
||||
|
||||
namespace multiverso {
|
||||
|
||||
|
@ -33,12 +34,21 @@ Communicator::Communicator() : Actor(actor::kCommunicator) {
|
|||
|
||||
void Communicator::Main() {
|
||||
// TODO(feiga): join the thread, make sure it exit properly
|
||||
recv_thread_.reset(new std::thread(&Communicator::Communicate, this));
|
||||
Actor::Main();
|
||||
//recv_thread_.reset(new std::thread(&Communicator::Communicate, this));
|
||||
//Actor::Main();
|
||||
MessagePtr msg;
|
||||
while (mailbox_->Alive()) {
|
||||
while (mailbox_->TryPop(msg)) {
|
||||
ProcessMessage(msg);
|
||||
};
|
||||
size_t size = net_util_->Recv(&msg);
|
||||
if (size > 0) LocalForward(msg);
|
||||
}
|
||||
}
|
||||
|
||||
void Communicator::ProcessMessage(MessagePtr& msg) {
|
||||
if (msg->dst() != net_util_->rank()) {
|
||||
Log::Debug("Send a msg from %d to %d, type = %d\n", msg->src(), msg->dst(), msg->type());
|
||||
net_util_->Send(msg);
|
||||
return;
|
||||
}
|
||||
|
@ -51,6 +61,7 @@ void Communicator::Communicate() {
|
|||
size_t size = net_util_->Recv(&msg);
|
||||
if (size > 0) {
|
||||
// a message received
|
||||
Log::Debug("Recv a msg from %d to %d, size = %d, type = %d\n", msg->src(), msg->dst(), msg->size(), msg->type());
|
||||
CHECK(msg->dst() == Zoo::Get()->rank());
|
||||
LocalForward(msg);
|
||||
}
|
||||
|
|
|
@ -16,4 +16,8 @@ void MultiversoBarrier() {
|
|||
Zoo::Get()->Barrier();
|
||||
}
|
||||
|
||||
int MultiversoRank() {
|
||||
return Zoo::Get()->rank();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -25,13 +25,14 @@ namespace multiverso {
|
|||
#ifdef MULTIVERSO_USE_MPI
|
||||
class MPINetWrapper : public NetInterface {
|
||||
public:
|
||||
MPINetWrapper() : more_(std::numeric_limits<int>::max()) {}
|
||||
MPINetWrapper() : more_(std::numeric_limits<char>::max()) {}
|
||||
|
||||
void Init(int* argc, char** argv) override {
|
||||
// MPI_Init(argc, &argv);
|
||||
MPI_Initialized(&inited_);
|
||||
if (!inited_) {
|
||||
MPI_Init_thread(argc, &argv, MPI_THREAD_MULTIPLE, &thread_provided_);
|
||||
MPI_Init_thread(argc, &argv, MPI_THREAD_SERIALIZED, &thread_provided_);
|
||||
// MPI_Init_thread(argc, &argv, MPI_THREAD_SERIALIZED, &thread_provided_);
|
||||
// CHECK(thread_provided_ == MPI_THREAD_MULTIPLE);
|
||||
}
|
||||
MPI_Query_thread(&thread_provided_);
|
||||
|
@ -60,20 +61,16 @@ public:
|
|||
}
|
||||
|
||||
size_t Recv(MessagePtr* msg) override {
|
||||
MPI_Status status;
|
||||
int flag;
|
||||
// non-blocking probe whether message comes
|
||||
MPI_Iprobe(MPI_ANY_SOURCE, 0, MPI_COMM_WORLD, &flag, &status);
|
||||
if (!flag) return 0;
|
||||
if (thread_provided_ == MPI_THREAD_SERIALIZED) {
|
||||
MPI_Status status;
|
||||
int flag;
|
||||
// non-blocking probe whether message comes
|
||||
MPI_Iprobe(MPI_ANY_SOURCE, 0, MPI_COMM_WORLD, &flag, &status);
|
||||
if (flag) {
|
||||
// a message come
|
||||
// block receive with lock guard
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return RecvMsg(msg);
|
||||
} else {
|
||||
// no message comes
|
||||
return 0;
|
||||
}
|
||||
// a message come
|
||||
// block receive with lock guard
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return RecvMsg(msg);
|
||||
} else if (thread_provided_ == MPI_THREAD_MULTIPLE) {
|
||||
return RecvMsg(msg);
|
||||
} else {
|
||||
|
@ -94,39 +91,53 @@ public:
|
|||
size += blob.size();
|
||||
}
|
||||
// Send an extra over tag indicating the finish of this Message
|
||||
MPI_Send(&more_, sizeof(int), MPI_BYTE, msg->dst(),
|
||||
MPI_Send(&more_, sizeof(char), MPI_BYTE, msg->dst(),
|
||||
0, MPI_COMM_WORLD);
|
||||
Log::Debug("MPI-Net: rank %d send msg size = %d\n", rank(), size+4);
|
||||
return size + sizeof(int);
|
||||
// Log::Debug("MPI-Net: rank %d send msg size = %d\n", rank(), size+4);
|
||||
return size + sizeof(char);
|
||||
}
|
||||
|
||||
size_t RecvMsg(MessagePtr* msg_ptr) {
|
||||
if (!msg_ptr->get()) msg_ptr->reset(new Message());
|
||||
// Receiving a Message from multiple recv
|
||||
Log::Debug("MPI-Net: rank %d started recv msg\n", rank());
|
||||
// Log::Debug("MPI-Net: rank %d started recv msg\n", rank());
|
||||
MessagePtr& msg = *msg_ptr;
|
||||
msg->data().clear();
|
||||
MPI_Status status;
|
||||
MPI_Recv(msg->header(), Message::kHeaderSize,
|
||||
MPI_BYTE, MPI_ANY_SOURCE,
|
||||
0, MPI_COMM_WORLD, &status);
|
||||
size_t size = Message::kHeaderSize;
|
||||
int i = 0;
|
||||
int flag;
|
||||
int num_probe = 0;
|
||||
while (true) {
|
||||
int count;
|
||||
MPI_Probe(msg->src(), 0, MPI_COMM_WORLD, &status);
|
||||
CHECK(MPI_SUCCESS == MPI_Probe(msg->src(), 0, MPI_COMM_WORLD, &status));
|
||||
//CHECK(MPI_SUCCESS == MPI_Iprobe(msg->src(), 0, MPI_COMM_WORLD, &flag, &status));
|
||||
//if (!flag) {
|
||||
// if (num_probe > 100) Log::Debug(" VLOG(RECV), Iprobe failed too much time \n", ++num_probe);
|
||||
// continue;
|
||||
//}
|
||||
MPI_Get_count(&status, MPI_BYTE, &count);
|
||||
Blob blob(count);
|
||||
// We only receive from msg->src() until we recv the overtag msg
|
||||
MPI_Recv(blob.data(), count, MPI_BYTE, msg->src(),
|
||||
0, MPI_COMM_WORLD, &status);
|
||||
size += count;
|
||||
if (count == sizeof(int) && blob.As<int>() == more_) break;
|
||||
if (count == sizeof(char)) {
|
||||
if (blob.As<char>() == more_) break;
|
||||
CHECK(1+1 != 2);
|
||||
}
|
||||
msg->Push(blob);
|
||||
// Log::Debug(" VLOG(RECV): i = %d\n", ++i);
|
||||
}
|
||||
Log::Debug("MPI-Net: rank %d end recv from src %d, size = %d\n", rank(), msg->src(), size);
|
||||
// Log::Debug("MPI-Net: rank %d end recv from src %d, size = %d\n", rank(), msg->src(), size);
|
||||
return size;
|
||||
}
|
||||
|
||||
private:
|
||||
const int more_;
|
||||
const char more_;
|
||||
std::mutex mutex_;
|
||||
int thread_provided_;
|
||||
int inited_;
|
||||
|
|
|
@ -94,11 +94,11 @@ void Zoo::Barrier() {
|
|||
msg->set_type(MsgType::Control_Barrier);
|
||||
Deliver(actor::kCommunicator, msg);
|
||||
|
||||
Log::Debug("rank %d requested barrier.\n", rank());
|
||||
// Log::Debug("rank %d requested barrier.\n", rank());
|
||||
// wait for reply
|
||||
mailbox_->Pop(msg);
|
||||
CHECK(msg->type() == MsgType::Control_Reply_Barrier);
|
||||
Log::Debug("rank %d reached barrier\n", rank());
|
||||
// Log::Debug("rank %d reached barrier\n", rank());
|
||||
}
|
||||
|
||||
int Zoo::RegisterTable(WorkerTable* worker_table) {
|
||||
|
|
Загрузка…
Ссылка в новой задаче