diff --git a/Test/Test.vcxproj b/Test/Test.vcxproj index 01ed1e3..81a2309 100644 --- a/Test/Test.vcxproj +++ b/Test/Test.vcxproj @@ -83,7 +83,7 @@ true true true - IMultiverso.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies); + Multiverso.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies); D:\multiverso-next\x64\release diff --git a/Test/main.cpp b/Test/main.cpp index 2f2aafa..adf1f9e 100644 --- a/Test/main.cpp +++ b/Test/main.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -112,7 +113,6 @@ void TestArray(int argc, char* argv[]) { std::cout << data[i] << " "; std::cout << std::endl; MV_Barrier(); - if (iter % 100 == 0) MV_Dashboard(); } MV_ShutDown(); } @@ -157,52 +157,52 @@ void TestArray(int argc, char* argv[]) { #define ARRAY_SIZE 4683776 void TestMultipleThread(int argc, char* argv[]) { - Log::Info("Test Multiple threads \n"); - std::mt19937_64 eng{ std::random_device{}() }; - std::uniform_int_distribution<> dist{ 5, 10000 }; - std::this_thread::sleep_for(std::chrono::milliseconds{ dist(eng) }); - //Log::ResetLogLevel(LogLevel::Debug); - MV_Init(&argc, argv); + Log::Info("Test Multiple threads \n"); + std::mt19937_64 eng{ std::random_device{}() }; + std::uniform_int_distribution<> dist{ 5, 10000 }; + std::this_thread::sleep_for(std::chrono::milliseconds{ dist(eng) }); + //Log::ResetLogLevel(LogLevel::Debug); + MV_Init(&argc, argv); - ArrayWorker* shared_array = new ArrayWorker(ARRAY_SIZE); - ArrayServer* server_array = new ArrayServer(ARRAY_SIZE); - std::thread* m_prefetchThread = nullptr; - MV_Barrier(); - Log::Info("Create tables OK\n"); + ArrayWorker* shared_array = new ArrayWorker(ARRAY_SIZE); + ArrayServer* server_array = new ArrayServer(ARRAY_SIZE); + std::thread* m_prefetchThread = nullptr; + MV_Barrier(); + Log::Info("Create tables OK\n"); - std::vector delta(ARRAY_SIZE); - while (true){ - if (m_prefetchThread != nullptr && m_prefetchThread->joinable()) - { - m_prefetchThread->join(); - delete m_prefetchThread; - m_prefetchThread = nullptr; - } + std::vector delta(ARRAY_SIZE); + while (true){ + if (m_prefetchThread != nullptr && m_prefetchThread->joinable()) + { + m_prefetchThread->join(); + delete m_prefetchThread; + m_prefetchThread = nullptr; + } - std::fill(delta.begin(), delta.end(), 0); - for (int i = 0; i < ARRAY_SIZE; ++i) - { - std::mt19937_64 eng{ std::random_device{}() }; - std::uniform_real_distribution dist{ -1, 1 }; - delta[i] = dist(eng); - } - m_prefetchThread = new std::thread([&](){ - - //std::mt19937_64 eng{ std::random_device{}() }; - //std::uniform_int_distribution<> dist{ 50, 500 }; - //std::this_thread::sleep_for(std::chrono::milliseconds{ dist(eng) }); - shared_array->Add(delta.data(), ARRAY_SIZE); - shared_array->Get(delta.data(), ARRAY_SIZE); - Log::Info("Rank %d Get OK\n", MV_Rank()); - for (int i = 0; i < 10; ++i) - std::cout << delta[i] << " "; std::cout << std::endl; - }); + std::fill(delta.begin(), delta.end(), 0); + for (int i = 0; i < ARRAY_SIZE; ++i) + { + std::mt19937_64 eng{ std::random_device{}() }; + std::uniform_real_distribution dist{ -1, 1 }; + delta[i] = dist(eng); + } + m_prefetchThread = new std::thread([&](){ - //shared_array->Get(data, 10); - MV_Barrier(); + //std::mt19937_64 eng{ std::random_device{}() }; + //std::uniform_int_distribution<> dist{ 50, 500 }; + //std::this_thread::sleep_for(std::chrono::milliseconds{ dist(eng) }); + shared_array->Add(delta.data(), ARRAY_SIZE); + shared_array->Get(delta.data(), ARRAY_SIZE); + Log::Info("Rank %d Get OK\n", MV_Rank()); + for (int i = 0; i < 10; ++i) + std::cout << delta[i] << " "; std::cout << std::endl; + }); - } - MV_ShutDown(); + //shared_array->Get(data, 10); + MV_Barrier(); + + } + MV_ShutDown(); } @@ -230,7 +230,7 @@ void TestNet(int argc, char* argv[]) { for (int i = 0; i < msg->size(); ++i) { Log::Info("In Send: %s\n", msg->data()[i].data()); }; - while (net->Send(msg) == 0) ; + while (net->Send(msg) == 0); Log::Info("rank 0 send\n"); } @@ -248,7 +248,8 @@ void TestNet(int argc, char* argv[]) { Log::Info("recv from srv %d: %s\n", msg->src(), recv_data[i].data()); }; } - } else {// other rank + } + else {// other rank MessagePtr msg(new Message());// = std::make_unique(); while (net->Recv(&msg) == 0) { // Log::Info("recv return 0\n"); @@ -267,7 +268,7 @@ void TestNet(int argc, char* argv[]) { msg->Push(Blob(hi1, 13)); msg->Push(Blob(hi2, 11)); msg->Push(Blob(hi3, 18)); - while (net->Send(msg) == 0) ; + while (net->Send(msg) == 0); Log::Info("rank %d send\n", net->rank()); } // while (!net->Test()) { @@ -283,7 +284,7 @@ void TestIP() { for (auto ip : ip_list) Log::Info("%s\n", ip.c_str()); } -void TestNoNet(int argc, char* argv[]) { +void TestNoNet(int argc, char* argv[]) { int provided; MPI_Init_thread(&argc, &argv, MPI_THREAD_SERIALIZED, &provided); @@ -329,90 +330,90 @@ void TestNoNet(int argc, char* argv[]) { } void TestMatrix(int argc, char* argv[]){ - Log::Info("Test Matrix\n"); + Log::Info("Test Matrix\n"); - MV_Init(&argc, argv); + MV_Init(&argc, argv); - int num_row = 11, num_col = 10; - int size = num_row * num_col; + int num_row = 11, num_col = 10; + int size = num_row * num_col; // MatrixWorkerTable* worker_table = - // static_cast*>(MV_CreateTable("matrix", { &num_row, &num_col })); //new implementation + // static_cast*>(MV_CreateTable("matrix", { &num_row, &num_col })); //new implementation // static_cast*>((new MatrixTableHelper(num_row, num_col))->CreateTable()); //older one //if (worker_table == nullptr){ //should have more if statement to avoid nullptr in using worker_table // Log::Debug("rank %d has no worker\n", MV_Rank()); // } - MatrixWorkerTable* worker_table = new MatrixWorkerTable(num_row, num_col); - MatrixServerTable* server_table = new MatrixServerTable(num_row, num_col); - std::thread* m_prefetchThread = nullptr; - MV_Barrier(); + MatrixWorkerTable* worker_table = new MatrixWorkerTable(num_row, num_col); + MatrixServerTable* server_table = new MatrixServerTable(num_row, num_col); + std::thread* m_prefetchThread = nullptr; + MV_Barrier(); - while (true) - { - if (m_prefetchThread != nullptr && m_prefetchThread->joinable()) - { - m_prefetchThread->join(); - delete m_prefetchThread; - m_prefetchThread = nullptr; - } - std::vector v = { 0, 1, 5, 10 }; + while (true) + { + if (m_prefetchThread != nullptr && m_prefetchThread->joinable()) + { + m_prefetchThread->join(); + delete m_prefetchThread; + m_prefetchThread = nullptr; + } + std::vector v = { 0, 1, 5, 10 }; - // test data - std::vector delta(size); - for (int i = 0; i < size; ++i) - delta[i] = i; + // test data + std::vector delta(size); + for (int i = 0; i < size; ++i) + delta[i] = i; - float * data = new float[size]; - m_prefetchThread = new std::thread([&](){ + float * data = new float[size]; + m_prefetchThread = new std::thread([&](){ UpdateOption option; - worker_table->Add(delta.data(), size, &option); //add all + worker_table->Add(delta.data(), size, &option); //add all - worker_table->Get(data, size); //get all - printf("----------------------------\n"); - for (int i = 0; i < num_row; ++i){ - printf("rank %d, row %d: ", MV_Rank(), i); - for (int j = 0; j < num_col; ++j) - printf("%.2f ", data[i * num_col + j]); - printf("\n"); - }; - }); + worker_table->Get(data, size); //get all + printf("----------------------------\n"); + for (int i = 0; i < num_row; ++i){ + printf("rank %d, row %d: ", MV_Rank(), i); + for (int j = 0; j < num_col; ++j) + printf("%.2f ", data[i * num_col + j]); + printf("\n"); + }; + }); - MV_Barrier(); - if (m_prefetchThread != nullptr && m_prefetchThread->joinable()) - { - m_prefetchThread->join(); - delete m_prefetchThread; - m_prefetchThread = nullptr; - } - //test data_vec - std::vector data_rows = { &data[0], &data[num_col], &data[5 * num_col], &data[10 * num_col] }; - std::vector delta_rows = { &delta[0], &delta[num_col], &delta[5 * num_col], &delta[10 * num_col] }; + MV_Barrier(); + if (m_prefetchThread != nullptr && m_prefetchThread->joinable()) + { + m_prefetchThread->join(); + delete m_prefetchThread; + m_prefetchThread = nullptr; + } + //test data_vec + std::vector data_rows = { &data[0], &data[num_col], &data[5 * num_col], &data[10 * num_col] }; + std::vector delta_rows = { &delta[0], &delta[num_col], &delta[5 * num_col], &delta[10 * num_col] }; UpdateOption option; - worker_table->Add(v, delta_rows, num_col, &option); - worker_table->Get(v, data_rows, num_col); - MV_Barrier(); + worker_table->Add(v, delta_rows, num_col, &option); + worker_table->Get(v, data_rows, num_col); + MV_Barrier(); - printf("----------------------------\n"); - for (int i = 0; i < num_row; ++i){ - printf("rank %d, row %d: ", MV_Rank(), i); - for (int j = 0; j < num_col; ++j) - printf("%.2f ", data[i * num_col + j]); - printf("\n"); - } - MV_Barrier(); + printf("----------------------------\n"); + for (int i = 0; i < num_row; ++i){ + printf("rank %d, row %d: ", MV_Rank(), i); + for (int j = 0; j < num_col; ++j) + printf("%.2f ", data[i * num_col + j]); + printf("\n"); + } + MV_Barrier(); - } - MV_ShutDown(); + } + MV_ShutDown(); } // NOTE(feiga): this doesn't work now since I roll back some implementation void TestCheckPoint(int argc, char* argv[], bool restore){ Log::Info("Test CheckPoint\n"); - MV_Init(&argc, argv, 3 /*, restore */); + MV_Init(&argc, argv); int num_row = 11, num_col = 10; int size = num_row * num_col; @@ -453,19 +454,25 @@ void TestCheckPoint(int argc, char* argv[], bool restore){ MV_ShutDown(); } -void TestComm(int argc, char* argv[]) { - +void TestAllreduce(int argc, char* argv[]) { + multiverso::SetCMDFlag("ps_role", std::string("none")); + MV_Init(&argc, argv); + int a = 1; + MV_Aggregate(&a, 1); + std::cout << "a = " << a << std::endl; + MV_ShutDown(); } + int main(int argc, char* argv[]) { Log::ResetLogLevel(LogLevel::Debug); if (argc == 1){ - multiverso::MV_Init(); - ::testing::InitGoogleTest(&argc, argv); - auto res = RUN_ALL_TESTS(); - multiverso::MV_ShutDown(); - return res; + multiverso::MV_Init(); + ::testing::InitGoogleTest(&argc, argv); + auto res = RUN_ALL_TESTS(); + multiverso::MV_ShutDown(); + return res; } else { if (strcmp(argv[1], "kv") == 0) TestKV(argc, argv); @@ -477,6 +484,7 @@ int main(int argc, char* argv[]) { else if (strcmp(argv[1], "nonet") == 0) TestNoNet(argc, argv); else if (strcmp(argv[1], "checkpoint") == 0) TestCheckPoint(argc, argv, false); else if (strcmp(argv[1], "restore") == 0) TestCheckPoint(argc, argv, true); + else if (strcmp(argv[1], "allreduce") == 0) TestAllreduce(argc, argv); else CHECK(false); } return 0; diff --git a/include/multiverso/multiverso.h b/include/multiverso/multiverso.h index 01a261f..510d2d3 100644 --- a/include/multiverso/multiverso.h +++ b/include/multiverso/multiverso.h @@ -5,13 +5,11 @@ namespace multiverso { -void MV_Init(int* argc = nullptr, - char* argv[] = nullptr, - int role = 3); +void MV_Init(int* argc = nullptr, char* argv[] = nullptr); void MV_Barrier(); -void MV_ShutDown(bool finalize_mpi = true); +void MV_ShutDown(bool finalize_net = true); int MV_Rank(); int MV_Size(); @@ -25,9 +23,9 @@ int MV_ServerId(); int MV_WorkerIdToRank(int worker_id); int MV_ServerIdToRank(int server_id); -// Show the dashboard information about the monitored excuation time -// used for profile -void MV_Dashboard(); +// inplace sum by allreduce +template +void MV_Aggregate(ElemType* data, int size); // --- Net API -------------------------------------------------------------- // // NOTE(feiga): these API is only used for specific situation. diff --git a/include/multiverso/net.h b/include/multiverso/net.h index 150b321..a907ce2 100644 --- a/include/multiverso/net.h +++ b/include/multiverso/net.h @@ -32,6 +32,8 @@ public: virtual int size() const = 0; virtual int rank() const = 0; + // virtual void Allreduce(void* data, size_t count, int type, int type_size); + // \return 1. > 0 sended size // 2. = 0 not sended // 3. < 0 net error @@ -45,6 +47,12 @@ public: virtual int thread_level_support() = 0; }; +namespace net { +// inplace allreduce +template +void Allreduce(Typename* data, size_t elem_count); +} + } // namespace multiverso #endif // MULTIVERSO_NET_NET_H_ diff --git a/include/multiverso/net/mpi_net.h b/include/multiverso/net/mpi_net.h index 2e58dc2..fd49181 100644 --- a/include/multiverso/net/mpi_net.h +++ b/include/multiverso/net/mpi_net.h @@ -70,6 +70,7 @@ public: MV_MPI_CALL(MPI_Initialized(&inited_)); if (!inited_) { MV_MPI_CALL(MPI_Init_thread(argc, &argv, MPI_THREAD_SERIALIZED, &thread_provided_)); + MV_MPI_CALL(MPI_Initialized(&inited_)); } MV_MPI_CALL(MPI_Query_thread(&thread_provided_)); if (thread_provided_ < MPI_THREAD_SERIALIZED) { @@ -105,6 +106,9 @@ public: int size() const override { return size_; } std::string name() const override { return "MPI"; } + template + static void Allreduce(ElemType* data, size_t elem_count, int op = MPI_SUM); + //size_t Send(MessagePtr& msg) override { // while (!msg_handles_.empty()) { // MPIMsgHandle* prev = msg_handles_.front(); diff --git a/include/multiverso/node.h b/include/multiverso/node.h index a11cd42..9f3155d 100644 --- a/include/multiverso/node.h +++ b/include/multiverso/node.h @@ -4,15 +4,16 @@ namespace multiverso { enum Role { + NONE = 0, WORKER = 1, - SERVER = 2 + SERVER = 2, + ALL = 3 }; struct Node { int rank; // role can be 0, 1, 2, 3 - // 00 means neither worker nor server, should be controllor, so at most - // one node could use this value + // 00 means neither worker nor server // 01 means worker // 10 means server // 11 means both server and worker, default value diff --git a/include/multiverso/server.h b/include/multiverso/server.h index fed650a..7c83832 100644 --- a/include/multiverso/server.h +++ b/include/multiverso/server.h @@ -9,21 +9,17 @@ namespace multiverso { class ServerTable; +class Synchronizer; class Server : public Actor { public: Server(); int RegisterTable(ServerTable* table); - // store server data to file - void StoreTable(int epoch); - // load data from file and return next iteration number - int LoadTable(const std::string& file_path); - void SetTableFilePath(const std::string& table_file_path); private: void ProcessGet(MessagePtr& msg); void ProcessAdd(MessagePtr& msg); - std::string table_file_path_; // contains the parameter data structure and related handle method + // Synchronizer* sync_; std::vector store_; }; diff --git a/include/multiverso/zoo.h b/include/multiverso/zoo.h index bc1e565..f6a76c9 100644 --- a/include/multiverso/zoo.h +++ b/include/multiverso/zoo.h @@ -24,7 +24,7 @@ public: static Zoo* Get() { static Zoo zoo; return &zoo; } // Start all actors - void Start(int* argc, char** argv, int role); + void Start(int* argc, char** argv); // Stop all actors void Stop(bool finalize_net); diff --git a/src/Multiverso.vcxproj b/src/Multiverso.vcxproj index 0610988..890303b 100644 --- a/src/Multiverso.vcxproj +++ b/src/Multiverso.vcxproj @@ -195,7 +195,6 @@ - @@ -221,6 +220,7 @@ + diff --git a/src/Multiverso.vcxproj.filters b/src/Multiverso.vcxproj.filters index d929ffc..75165dd 100644 --- a/src/Multiverso.vcxproj.filters +++ b/src/Multiverso.vcxproj.filters @@ -94,9 +94,6 @@ updater - - updater - table @@ -215,5 +212,8 @@ io + + net + \ No newline at end of file diff --git a/src/multiverso.cpp b/src/multiverso.cpp index d8e8b98..208ba9a 100644 --- a/src/multiverso.cpp +++ b/src/multiverso.cpp @@ -6,8 +6,8 @@ namespace multiverso { -void MV_Init(int* argc, char* argv[], int role) { - Zoo::Get()->Start(argc, argv, role); +void MV_Init(int* argc, char* argv[]) { + Zoo::Get()->Start(argc, argv); } void MV_ShutDown(bool finalize_net) { @@ -42,8 +42,9 @@ int MV_ServerIdToRank(int server_id) { return Zoo::Get()->server_id_to_rank(server_id); } -void MV_Dashboard() { - Dashboard::Display(); +template +void MV_Aggregate(ElemType* data, int size) { + net::Allreduce(data, size); } int MV_NetBind(int rank, char* endpoint) { @@ -54,4 +55,9 @@ int MV_NetConnect(int* ranks, char* endpoints[], int size) { return NetInterface::Get()->Connect(ranks, endpoints, size); } +template void MV_Aggregate(char*, int); +template void MV_Aggregate(int*, int); +template void MV_Aggregate(float*, int); +template void MV_Aggregate(double*, int); + } // namespace multiverso diff --git a/src/net.cpp b/src/net.cpp index 6127cba..24e8164 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -23,4 +23,23 @@ NetInterface* NetInterface::Get() { #endif } +namespace net { +template +void Allreduce(Typename* data, size_t elem_count) { +#ifdef MULTIVERSO_USE_MPI + CHECK(NetInterface::Get()->active()); + MPINetWrapper::Allreduce(data, elem_count); +#else + +#endif +} + +template void Allreduce(char*, size_t); +template void Allreduce(int*, size_t); +template void Allreduce(float*, size_t); +template void Allreduce(double*, size_t); + +} // namespace net + + } // namespace multiverso diff --git a/src/net/mpi_net.cpp b/src/net/mpi_net.cpp new file mode 100644 index 0000000..1083a04 --- /dev/null +++ b/src/net/mpi_net.cpp @@ -0,0 +1,23 @@ +#include "multiverso/net/mpi_net.h" + +namespace multiverso { + +namespace { +MPI_Datatype GetDataType(char*) { return MPI_CHAR; } +MPI_Datatype GetDataType(int*) { return MPI_INT; } +MPI_Datatype GetDataType(float*) { return MPI_FLOAT; } +MPI_Datatype GetDataType(double*) { return MPI_DOUBLE; } +} + +template +void MPINetWrapper::Allreduce(ElemType* data, size_t elem_count, int op) { + MPI_Allreduce(MPI_IN_PLACE, data, (int)elem_count, + GetDataType(data), op, MPI_COMM_WORLD); +} + +template void MPINetWrapper::Allreduce(char*, size_t, int); +template void MPINetWrapper::Allreduce(int*, size_t, int); +template void MPINetWrapper::Allreduce(float*, size_t, int); +template void MPINetWrapper::Allreduce(double*, size_t, int); + +} // namespace multiverso \ No newline at end of file diff --git a/src/server.cpp b/src/server.cpp index 00a76a9..b083718 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -44,51 +44,4 @@ void Server::ProcessAdd(MessagePtr& msg) { MONITOR_END(SERVER_PROCESS_ADD) } -void Server::SetTableFilePath(const std::string& table_file_path) { - int id = Zoo::Get()->server_rank(); - std::string server_id_str = (id == 0 ? "0" : ""); - while (id > 0) { - server_id_str = static_cast((id % 10) + '0') + server_id_str; - id /= 10; - } - table_file_path_ = table_file_path + server_id_str; -} - -void Server::StoreTable(int epoch) { - Stream* stream = StreamFactory::GetStream(URI(table_file_path_), - FileOpenMode::Write); - stream->Write(&epoch, sizeof(int)); - for (int i = 0; i < store_.size(); ++i) { - store_[i]->Store(stream); - } - delete stream; -} - -int Server::LoadTable(const std::string& file_path) { - Stream* stream = StreamFactory::GetStream(URI(table_file_path_), - FileOpenMode::Read); - if (!stream->Good()) { - Log::Error("Rank %d open file %s error in Server::LoadTable\n", - Zoo::Get()->rank(), file_path.c_str()); - delete stream; - return 0; // open file error, may not exist - } - - int iter; - size_t readsize = stream->Read(&iter, sizeof(int)); - if (readsize == 0) { - Log::Error("Rank %d read file %s no data in Server::LoadTable\n", - Zoo::Get()->rank(), file_path.c_str()); - delete stream; - return 0; // no store data - } - - for (int i = 0; i < store_.size(); ++i) { - store_[i]->Load(stream); - } - - delete stream; - return iter + 1; // the next iteration number -} - } // namespace multiverso diff --git a/src/updater/updater.cpp b/src/updater/updater.cpp index c9d3a00..98d70da 100644 --- a/src/updater/updater.cpp +++ b/src/updater/updater.cpp @@ -2,7 +2,6 @@ #include "multiverso/updater/adagrad_updater.h" #include "multiverso/updater/momentum_updater.h" -#include "multiverso/updater/second_order_gradient_updater.h" #include "multiverso/updater/sgd_updater.h" #include "multiverso/util/configure.h" #include "multiverso/util/log.h" @@ -35,7 +34,6 @@ Updater* Updater::GetUpdater(size_t size) { if (type == "sgd") return new SGDUpdater(size); if (type == "adagrad") return new AdaGradUpdater(size); if (type == "momentum_sgd") return new MomentumUpdater(size); - if (type == "second_order_sgd") return new SecondOrderUpdater(size); // Default: simple updater return new Updater(); } diff --git a/src/zoo.cpp b/src/zoo.cpp index ad77b0f..8e7c1c5 100644 --- a/src/zoo.cpp +++ b/src/zoo.cpp @@ -20,39 +20,59 @@ Zoo::Zoo() {} Zoo::~Zoo() {} -void Zoo::Start(int* argc, char** argv, int role) { +MV_DEFINE_string(ps_role, "default", "none / worker / server / default"); +MV_DEFINE_bool(ma, "false", "model average, will not start server if true"); + +namespace { +int ParsePSRole(const std::string& ps_role) { + if (ps_role == "none") return Role::NONE; + if (ps_role == "worker") return Role::WORKER; + if (ps_role == "server") return Role::SERVER; + if (ps_role == "default") return Role::ALL; + return -1; +} +} // namespace + +void Zoo::Start(int* argc, char** argv) { Log::Debug("Zoo started\n"); - CHECK(role >= 0 && role <= 3); ParseCMDFlags(argc, argv); // Init the network net_util_ = NetInterface::Get(); net_util_->Init(argc, argv); - nodes_.resize(size()); - nodes_[rank()].rank = rank(); - nodes_[rank()].role = role; - mailbox_.reset(new MtQueue); - // NOTE(feiga): the start order is non-trivial, communicator should be last. - if (rank() == 0) { Actor* controler = new Controller(); controler->Start(); } - if (node::is_server(role)) { Actor* server = new Server(); server->Start(); } - if (node::is_worker(role)) { Actor* worker = new Worker(); worker->Start(); } - Actor* communicator = new Communicator(); - communicator->Start(); + if (!MV_CONFIG_ma) { + int role = ParsePSRole(MV_CONFIG_ps_role); + CHECK(role != -1); - // activate the system - RegisterNode(); - Log::Info("Rank %d: Zoo start sucessfully\n", rank()); + nodes_.resize(size()); + nodes_[rank()].rank = rank(); + nodes_[rank()].role = role; + mailbox_.reset(new MtQueue); + + // NOTE(feiga): the start order is non-trivial, communicator should be last. + if (rank() == 0) { Actor* controler = new Controller(); controler->Start(); } + if (node::is_server(role)) { Actor* server = new Server(); server->Start(); } + if (node::is_worker(role)) { Actor* worker = new Worker(); worker->Start(); } + Actor* communicator = new Communicator(); + communicator->Start(); + + // activate the system + RegisterNode(); + Log::Info("Rank %d: Zoo start sucessfully\n", rank()); + } } void Zoo::Stop(bool finalize_net) { // Stop the system - Barrier(); + if (!MV_CONFIG_ma) { + Barrier(); - Dashboard::Display(); + Dashboard::Display(); - // Stop all actors - for (auto actor : zoo_) { actor.second->Stop(); } + // Stop all actors + for (auto actor : zoo_) { actor.second->Stop(); } + } // Stop the network if (finalize_net) net_util_->Finalize(); }