add interface for model average

This commit is contained in:
feiga 2016-04-17 19:56:58 +08:00
Родитель 2af610bf25
Коммит 468948ea44
16 изменённых файлов: 239 добавлений и 205 удалений

Просмотреть файл

@ -83,7 +83,7 @@
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<AdditionalDependencies>IMultiverso.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies);</AdditionalDependencies>
<AdditionalDependencies>Multiverso.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies);</AdditionalDependencies>
<AdditionalLibraryDirectories>D:\multiverso-next\x64\release</AdditionalLibraryDirectories>
</Link>
<ProjectReference>

Просмотреть файл

@ -10,6 +10,7 @@
#include <multiverso/net.h>
#include <multiverso/util/log.h>
#include <multiverso/util/net_util.h>
#include <multiverso/util/configure.h>
#include <multiverso/table/array_table.h>
#include <multiverso/table/kv_table.h>
@ -112,7 +113,6 @@ void TestArray(int argc, char* argv[]) {
std::cout << data[i] << " "; std::cout << std::endl;
MV_Barrier();
if (iter % 100 == 0) MV_Dashboard();
}
MV_ShutDown();
}
@ -157,52 +157,52 @@ void TestArray(int argc, char* argv[]) {
#define ARRAY_SIZE 4683776
void TestMultipleThread(int argc, char* argv[])
{
Log::Info("Test Multiple threads \n");
std::mt19937_64 eng{ std::random_device{}() };
std::uniform_int_distribution<> dist{ 5, 10000 };
std::this_thread::sleep_for(std::chrono::milliseconds{ dist(eng) });
//Log::ResetLogLevel(LogLevel::Debug);
MV_Init(&argc, argv);
Log::Info("Test Multiple threads \n");
std::mt19937_64 eng{ std::random_device{}() };
std::uniform_int_distribution<> dist{ 5, 10000 };
std::this_thread::sleep_for(std::chrono::milliseconds{ dist(eng) });
//Log::ResetLogLevel(LogLevel::Debug);
MV_Init(&argc, argv);
ArrayWorker<float>* shared_array = new ArrayWorker<float>(ARRAY_SIZE);
ArrayServer<float>* server_array = new ArrayServer<float>(ARRAY_SIZE);
std::thread* m_prefetchThread = nullptr;
MV_Barrier();
Log::Info("Create tables OK\n");
ArrayWorker<float>* shared_array = new ArrayWorker<float>(ARRAY_SIZE);
ArrayServer<float>* server_array = new ArrayServer<float>(ARRAY_SIZE);
std::thread* m_prefetchThread = nullptr;
MV_Barrier();
Log::Info("Create tables OK\n");
std::vector<float> delta(ARRAY_SIZE);
while (true){
if (m_prefetchThread != nullptr && m_prefetchThread->joinable())
{
m_prefetchThread->join();
delete m_prefetchThread;
m_prefetchThread = nullptr;
}
std::vector<float> delta(ARRAY_SIZE);
while (true){
if (m_prefetchThread != nullptr && m_prefetchThread->joinable())
{
m_prefetchThread->join();
delete m_prefetchThread;
m_prefetchThread = nullptr;
}
std::fill(delta.begin(), delta.end(), 0);
for (int i = 0; i < ARRAY_SIZE; ++i)
{
std::mt19937_64 eng{ std::random_device{}() };
std::uniform_real_distribution<float> dist{ -1, 1 };
delta[i] = dist(eng);
}
m_prefetchThread = new std::thread([&](){
//std::mt19937_64 eng{ std::random_device{}() };
//std::uniform_int_distribution<> dist{ 50, 500 };
//std::this_thread::sleep_for(std::chrono::milliseconds{ dist(eng) });
shared_array->Add(delta.data(), ARRAY_SIZE);
shared_array->Get(delta.data(), ARRAY_SIZE);
Log::Info("Rank %d Get OK\n", MV_Rank());
for (int i = 0; i < 10; ++i)
std::cout << delta[i] << " "; std::cout << std::endl;
});
std::fill(delta.begin(), delta.end(), 0);
for (int i = 0; i < ARRAY_SIZE; ++i)
{
std::mt19937_64 eng{ std::random_device{}() };
std::uniform_real_distribution<float> dist{ -1, 1 };
delta[i] = dist(eng);
}
m_prefetchThread = new std::thread([&](){
//shared_array->Get(data, 10);
MV_Barrier();
//std::mt19937_64 eng{ std::random_device{}() };
//std::uniform_int_distribution<> dist{ 50, 500 };
//std::this_thread::sleep_for(std::chrono::milliseconds{ dist(eng) });
shared_array->Add(delta.data(), ARRAY_SIZE);
shared_array->Get(delta.data(), ARRAY_SIZE);
Log::Info("Rank %d Get OK\n", MV_Rank());
for (int i = 0; i < 10; ++i)
std::cout << delta[i] << " "; std::cout << std::endl;
});
}
MV_ShutDown();
//shared_array->Get(data, 10);
MV_Barrier();
}
MV_ShutDown();
}
@ -230,7 +230,7 @@ void TestNet(int argc, char* argv[]) {
for (int i = 0; i < msg->size(); ++i) {
Log::Info("In Send: %s\n", msg->data()[i].data());
};
while (net->Send(msg) == 0) ;
while (net->Send(msg) == 0);
Log::Info("rank 0 send\n");
}
@ -248,7 +248,8 @@ void TestNet(int argc, char* argv[]) {
Log::Info("recv from srv %d: %s\n", msg->src(), recv_data[i].data());
};
}
} else {// other rank
}
else {// other rank
MessagePtr msg(new Message());// = std::make_unique<Message>();
while (net->Recv(&msg) == 0) {
// Log::Info("recv return 0\n");
@ -267,7 +268,7 @@ void TestNet(int argc, char* argv[]) {
msg->Push(Blob(hi1, 13));
msg->Push(Blob(hi2, 11));
msg->Push(Blob(hi3, 18));
while (net->Send(msg) == 0) ;
while (net->Send(msg) == 0);
Log::Info("rank %d send\n", net->rank());
}
// while (!net->Test()) {
@ -283,7 +284,7 @@ void TestIP() {
for (auto ip : ip_list) Log::Info("%s\n", ip.c_str());
}
void TestNoNet(int argc, char* argv[]) {
void TestNoNet(int argc, char* argv[]) {
int provided;
MPI_Init_thread(&argc, &argv, MPI_THREAD_SERIALIZED, &provided);
@ -329,90 +330,90 @@ void TestNoNet(int argc, char* argv[]) {
}
void TestMatrix(int argc, char* argv[]){
Log::Info("Test Matrix\n");
Log::Info("Test Matrix\n");
MV_Init(&argc, argv);
MV_Init(&argc, argv);
int num_row = 11, num_col = 10;
int size = num_row * num_col;
int num_row = 11, num_col = 10;
int size = num_row * num_col;
// MatrixWorkerTable<int>* worker_table =
// static_cast<MatrixWorkerTable<int>*>(MV_CreateTable<int>("matrix", { &num_row, &num_col })); //new implementation
// static_cast<MatrixWorkerTable<int>*>(MV_CreateTable<int>("matrix", { &num_row, &num_col })); //new implementation
// static_cast<MatrixWorkerTable<int>*>((new MatrixTableHelper<int>(num_row, num_col))->CreateTable()); //older one
//if (worker_table == nullptr){ //should have more if statement to avoid nullptr in using worker_table
// Log::Debug("rank %d has no worker\n", MV_Rank());
// }
MatrixWorkerTable<float>* worker_table = new MatrixWorkerTable<float>(num_row, num_col);
MatrixServerTable<float>* server_table = new MatrixServerTable<float>(num_row, num_col);
std::thread* m_prefetchThread = nullptr;
MV_Barrier();
MatrixWorkerTable<float>* worker_table = new MatrixWorkerTable<float>(num_row, num_col);
MatrixServerTable<float>* server_table = new MatrixServerTable<float>(num_row, num_col);
std::thread* m_prefetchThread = nullptr;
MV_Barrier();
while (true)
{
if (m_prefetchThread != nullptr && m_prefetchThread->joinable())
{
m_prefetchThread->join();
delete m_prefetchThread;
m_prefetchThread = nullptr;
}
std::vector<int> v = { 0, 1, 5, 10 };
while (true)
{
if (m_prefetchThread != nullptr && m_prefetchThread->joinable())
{
m_prefetchThread->join();
delete m_prefetchThread;
m_prefetchThread = nullptr;
}
std::vector<int> v = { 0, 1, 5, 10 };
// test data
std::vector<float> delta(size);
for (int i = 0; i < size; ++i)
delta[i] = i;
// test data
std::vector<float> delta(size);
for (int i = 0; i < size; ++i)
delta[i] = i;
float * data = new float[size];
m_prefetchThread = new std::thread([&](){
float * data = new float[size];
m_prefetchThread = new std::thread([&](){
UpdateOption option;
worker_table->Add(delta.data(), size, &option); //add all
worker_table->Add(delta.data(), size, &option); //add all
worker_table->Get(data, size); //get all
printf("----------------------------\n");
for (int i = 0; i < num_row; ++i){
printf("rank %d, row %d: ", MV_Rank(), i);
for (int j = 0; j < num_col; ++j)
printf("%.2f ", data[i * num_col + j]);
printf("\n");
};
});
worker_table->Get(data, size); //get all
printf("----------------------------\n");
for (int i = 0; i < num_row; ++i){
printf("rank %d, row %d: ", MV_Rank(), i);
for (int j = 0; j < num_col; ++j)
printf("%.2f ", data[i * num_col + j]);
printf("\n");
};
});
MV_Barrier();
if (m_prefetchThread != nullptr && m_prefetchThread->joinable())
{
m_prefetchThread->join();
delete m_prefetchThread;
m_prefetchThread = nullptr;
}
//test data_vec
std::vector<float*> data_rows = { &data[0], &data[num_col], &data[5 * num_col], &data[10 * num_col] };
std::vector<float*> delta_rows = { &delta[0], &delta[num_col], &delta[5 * num_col], &delta[10 * num_col] };
MV_Barrier();
if (m_prefetchThread != nullptr && m_prefetchThread->joinable())
{
m_prefetchThread->join();
delete m_prefetchThread;
m_prefetchThread = nullptr;
}
//test data_vec
std::vector<float*> data_rows = { &data[0], &data[num_col], &data[5 * num_col], &data[10 * num_col] };
std::vector<float*> delta_rows = { &delta[0], &delta[num_col], &delta[5 * num_col], &delta[10 * num_col] };
UpdateOption option;
worker_table->Add(v, delta_rows, num_col, &option);
worker_table->Get(v, data_rows, num_col);
MV_Barrier();
worker_table->Add(v, delta_rows, num_col, &option);
worker_table->Get(v, data_rows, num_col);
MV_Barrier();
printf("----------------------------\n");
for (int i = 0; i < num_row; ++i){
printf("rank %d, row %d: ", MV_Rank(), i);
for (int j = 0; j < num_col; ++j)
printf("%.2f ", data[i * num_col + j]);
printf("\n");
}
MV_Barrier();
printf("----------------------------\n");
for (int i = 0; i < num_row; ++i){
printf("rank %d, row %d: ", MV_Rank(), i);
for (int j = 0; j < num_col; ++j)
printf("%.2f ", data[i * num_col + j]);
printf("\n");
}
MV_Barrier();
}
MV_ShutDown();
}
MV_ShutDown();
}
// NOTE(feiga): this doesn't work now since I roll back some implementation
void TestCheckPoint(int argc, char* argv[], bool restore){
Log::Info("Test CheckPoint\n");
MV_Init(&argc, argv, 3 /*, restore */);
MV_Init(&argc, argv);
int num_row = 11, num_col = 10;
int size = num_row * num_col;
@ -453,19 +454,25 @@ void TestCheckPoint(int argc, char* argv[], bool restore){
MV_ShutDown();
}
void TestComm(int argc, char* argv[]) {
void TestAllreduce(int argc, char* argv[]) {
multiverso::SetCMDFlag("ps_role", std::string("none"));
MV_Init(&argc, argv);
int a = 1;
MV_Aggregate(&a, 1);
std::cout << "a = " << a << std::endl;
MV_ShutDown();
}
int main(int argc, char* argv[]) {
Log::ResetLogLevel(LogLevel::Debug);
if (argc == 1){
multiverso::MV_Init();
::testing::InitGoogleTest(&argc, argv);
auto res = RUN_ALL_TESTS();
multiverso::MV_ShutDown();
return res;
multiverso::MV_Init();
::testing::InitGoogleTest(&argc, argv);
auto res = RUN_ALL_TESTS();
multiverso::MV_ShutDown();
return res;
}
else {
if (strcmp(argv[1], "kv") == 0) TestKV(argc, argv);
@ -477,6 +484,7 @@ int main(int argc, char* argv[]) {
else if (strcmp(argv[1], "nonet") == 0) TestNoNet(argc, argv);
else if (strcmp(argv[1], "checkpoint") == 0) TestCheckPoint(argc, argv, false);
else if (strcmp(argv[1], "restore") == 0) TestCheckPoint(argc, argv, true);
else if (strcmp(argv[1], "allreduce") == 0) TestAllreduce(argc, argv);
else CHECK(false);
}
return 0;

Просмотреть файл

@ -5,13 +5,11 @@
namespace multiverso {
void MV_Init(int* argc = nullptr,
char* argv[] = nullptr,
int role = 3);
void MV_Init(int* argc = nullptr, char* argv[] = nullptr);
void MV_Barrier();
void MV_ShutDown(bool finalize_mpi = true);
void MV_ShutDown(bool finalize_net = true);
int MV_Rank();
int MV_Size();
@ -25,9 +23,9 @@ int MV_ServerId();
int MV_WorkerIdToRank(int worker_id);
int MV_ServerIdToRank(int server_id);
// Show the dashboard information about the monitored excuation time
// used for profile
void MV_Dashboard();
// inplace sum by allreduce
template <typename ElemType>
void MV_Aggregate(ElemType* data, int size);
// --- Net API -------------------------------------------------------------- //
// NOTE(feiga): these API is only used for specific situation.

Просмотреть файл

@ -32,6 +32,8 @@ public:
virtual int size() const = 0;
virtual int rank() const = 0;
// virtual void Allreduce(void* data, size_t count, int type, int type_size);
// \return 1. > 0 sended size
// 2. = 0 not sended
// 3. < 0 net error
@ -45,6 +47,12 @@ public:
virtual int thread_level_support() = 0;
};
namespace net {
// inplace allreduce
template <typename Typename>
void Allreduce(Typename* data, size_t elem_count);
}
} // namespace multiverso
#endif // MULTIVERSO_NET_NET_H_

Просмотреть файл

@ -70,6 +70,7 @@ public:
MV_MPI_CALL(MPI_Initialized(&inited_));
if (!inited_) {
MV_MPI_CALL(MPI_Init_thread(argc, &argv, MPI_THREAD_SERIALIZED, &thread_provided_));
MV_MPI_CALL(MPI_Initialized(&inited_));
}
MV_MPI_CALL(MPI_Query_thread(&thread_provided_));
if (thread_provided_ < MPI_THREAD_SERIALIZED) {
@ -105,6 +106,9 @@ public:
int size() const override { return size_; }
std::string name() const override { return "MPI"; }
template <typename ElemType>
static void Allreduce(ElemType* data, size_t elem_count, int op = MPI_SUM);
//size_t Send(MessagePtr& msg) override {
// while (!msg_handles_.empty()) {
// MPIMsgHandle* prev = msg_handles_.front();

Просмотреть файл

@ -4,15 +4,16 @@
namespace multiverso {
enum Role {
NONE = 0,
WORKER = 1,
SERVER = 2
SERVER = 2,
ALL = 3
};
struct Node {
int rank;
// role can be 0, 1, 2, 3
// 00 means neither worker nor server, should be controllor, so at most
// one node could use this value
// 00 means neither worker nor server
// 01 means worker
// 10 means server
// 11 means both server and worker, default value

Просмотреть файл

@ -9,21 +9,17 @@
namespace multiverso {
class ServerTable;
class Synchronizer;
class Server : public Actor {
public:
Server();
int RegisterTable(ServerTable* table);
// store server data to file
void StoreTable(int epoch);
// load data from file and return next iteration number
int LoadTable(const std::string& file_path);
void SetTableFilePath(const std::string& table_file_path);
private:
void ProcessGet(MessagePtr& msg);
void ProcessAdd(MessagePtr& msg);
std::string table_file_path_;
// contains the parameter data structure and related handle method
// Synchronizer* sync_;
std::vector<ServerTable*> store_;
};

Просмотреть файл

@ -24,7 +24,7 @@ public:
static Zoo* Get() { static Zoo zoo; return &zoo; }
// Start all actors
void Start(int* argc, char** argv, int role);
void Start(int* argc, char** argv);
// Stop all actors
void Stop(bool finalize_net);

Просмотреть файл

@ -195,7 +195,6 @@
<ClInclude Include="..\include\multiverso\table_interface.h" />
<ClInclude Include="..\include\multiverso\updater\adagrad_updater.h" />
<ClInclude Include="..\include\multiverso\updater\sgd_updater.h" />
<ClInclude Include="..\include\multiverso\updater\second_order_gradient_updater.h" />
<ClInclude Include="..\include\multiverso\updater\momentum_updater.h" />
<ClInclude Include="..\include\multiverso\updater\updater.h" />
<ClInclude Include="..\include\multiverso\util\configure.h" />
@ -221,6 +220,7 @@
<ClCompile Include="net.cpp" />
<ClCompile Include="net\allreduce_engine.cpp" />
<ClCompile Include="net\allreduce_topo.cpp" />
<ClCompile Include="net\mpi_net.cpp" />
<ClCompile Include="node.cpp" />
<ClCompile Include="server.cpp" />
<ClCompile Include="table.cpp" />

Просмотреть файл

@ -94,9 +94,6 @@
<ClInclude Include="..\include\multiverso\updater\adagrad_updater.h">
<Filter>updater</Filter>
</ClInclude>
<ClInclude Include="..\include\multiverso\updater\second_order_gradient_updater.h">
<Filter>updater</Filter>
</ClInclude>
<ClInclude Include="..\include\multiverso\table\sparse_matrix_table.h">
<Filter>table</Filter>
</ClInclude>
@ -215,5 +212,8 @@
<ClCompile Include="io\local_stream.cpp">
<Filter>io</Filter>
</ClCompile>
<ClCompile Include="net\mpi_net.cpp">
<Filter>net</Filter>
</ClCompile>
</ItemGroup>
</Project>

Просмотреть файл

@ -6,8 +6,8 @@
namespace multiverso {
void MV_Init(int* argc, char* argv[], int role) {
Zoo::Get()->Start(argc, argv, role);
void MV_Init(int* argc, char* argv[]) {
Zoo::Get()->Start(argc, argv);
}
void MV_ShutDown(bool finalize_net) {
@ -42,8 +42,9 @@ int MV_ServerIdToRank(int server_id) {
return Zoo::Get()->server_id_to_rank(server_id);
}
void MV_Dashboard() {
Dashboard::Display();
template <typename ElemType>
void MV_Aggregate(ElemType* data, int size) {
net::Allreduce(data, size);
}
int MV_NetBind(int rank, char* endpoint) {
@ -54,4 +55,9 @@ int MV_NetConnect(int* ranks, char* endpoints[], int size) {
return NetInterface::Get()->Connect(ranks, endpoints, size);
}
template void MV_Aggregate<char>(char*, int);
template void MV_Aggregate<int>(int*, int);
template void MV_Aggregate<float>(float*, int);
template void MV_Aggregate<double>(double*, int);
} // namespace multiverso

Просмотреть файл

@ -23,4 +23,23 @@ NetInterface* NetInterface::Get() {
#endif
}
namespace net {
template <typename Typename>
void Allreduce(Typename* data, size_t elem_count) {
#ifdef MULTIVERSO_USE_MPI
CHECK(NetInterface::Get()->active());
MPINetWrapper::Allreduce(data, elem_count);
#else
#endif
}
template void Allreduce<char>(char*, size_t);
template void Allreduce<int>(int*, size_t);
template void Allreduce<float>(float*, size_t);
template void Allreduce<double>(double*, size_t);
} // namespace net
} // namespace multiverso

23
src/net/mpi_net.cpp Normal file
Просмотреть файл

@ -0,0 +1,23 @@
#include "multiverso/net/mpi_net.h"
namespace multiverso {
namespace {
MPI_Datatype GetDataType(char*) { return MPI_CHAR; }
MPI_Datatype GetDataType(int*) { return MPI_INT; }
MPI_Datatype GetDataType(float*) { return MPI_FLOAT; }
MPI_Datatype GetDataType(double*) { return MPI_DOUBLE; }
}
template <typename ElemType>
void MPINetWrapper::Allreduce(ElemType* data, size_t elem_count, int op) {
MPI_Allreduce(MPI_IN_PLACE, data, (int)elem_count,
GetDataType(data), op, MPI_COMM_WORLD);
}
template void MPINetWrapper::Allreduce<char>(char*, size_t, int);
template void MPINetWrapper::Allreduce<int>(int*, size_t, int);
template void MPINetWrapper::Allreduce<float>(float*, size_t, int);
template void MPINetWrapper::Allreduce<double>(double*, size_t, int);
} // namespace multiverso

Просмотреть файл

@ -44,51 +44,4 @@ void Server::ProcessAdd(MessagePtr& msg) {
MONITOR_END(SERVER_PROCESS_ADD)
}
void Server::SetTableFilePath(const std::string& table_file_path) {
int id = Zoo::Get()->server_rank();
std::string server_id_str = (id == 0 ? "0" : "");
while (id > 0) {
server_id_str = static_cast<char>((id % 10) + '0') + server_id_str;
id /= 10;
}
table_file_path_ = table_file_path + server_id_str;
}
void Server::StoreTable(int epoch) {
Stream* stream = StreamFactory::GetStream(URI(table_file_path_),
FileOpenMode::Write);
stream->Write(&epoch, sizeof(int));
for (int i = 0; i < store_.size(); ++i) {
store_[i]->Store(stream);
}
delete stream;
}
int Server::LoadTable(const std::string& file_path) {
Stream* stream = StreamFactory::GetStream(URI(table_file_path_),
FileOpenMode::Read);
if (!stream->Good()) {
Log::Error("Rank %d open file %s error in Server::LoadTable\n",
Zoo::Get()->rank(), file_path.c_str());
delete stream;
return 0; // open file error, may not exist
}
int iter;
size_t readsize = stream->Read(&iter, sizeof(int));
if (readsize == 0) {
Log::Error("Rank %d read file %s no data in Server::LoadTable\n",
Zoo::Get()->rank(), file_path.c_str());
delete stream;
return 0; // no store data
}
for (int i = 0; i < store_.size(); ++i) {
store_[i]->Load(stream);
}
delete stream;
return iter + 1; // the next iteration number
}
} // namespace multiverso

Просмотреть файл

@ -2,7 +2,6 @@
#include "multiverso/updater/adagrad_updater.h"
#include "multiverso/updater/momentum_updater.h"
#include "multiverso/updater/second_order_gradient_updater.h"
#include "multiverso/updater/sgd_updater.h"
#include "multiverso/util/configure.h"
#include "multiverso/util/log.h"
@ -35,7 +34,6 @@ Updater<T>* Updater<T>::GetUpdater(size_t size) {
if (type == "sgd") return new SGDUpdater<T>(size);
if (type == "adagrad") return new AdaGradUpdater<T>(size);
if (type == "momentum_sgd") return new MomentumUpdater<T>(size);
if (type == "second_order_sgd") return new SecondOrderUpdater<T>(size);
// Default: simple updater
return new Updater<T>();
}

Просмотреть файл

@ -20,39 +20,59 @@ Zoo::Zoo() {}
Zoo::~Zoo() {}
void Zoo::Start(int* argc, char** argv, int role) {
MV_DEFINE_string(ps_role, "default", "none / worker / server / default");
MV_DEFINE_bool(ma, "false", "model average, will not start server if true");
namespace {
int ParsePSRole(const std::string& ps_role) {
if (ps_role == "none") return Role::NONE;
if (ps_role == "worker") return Role::WORKER;
if (ps_role == "server") return Role::SERVER;
if (ps_role == "default") return Role::ALL;
return -1;
}
} // namespace
void Zoo::Start(int* argc, char** argv) {
Log::Debug("Zoo started\n");
CHECK(role >= 0 && role <= 3);
ParseCMDFlags(argc, argv);
// Init the network
net_util_ = NetInterface::Get();
net_util_->Init(argc, argv);
nodes_.resize(size());
nodes_[rank()].rank = rank();
nodes_[rank()].role = role;
mailbox_.reset(new MtQueue<MessagePtr>);
// NOTE(feiga): the start order is non-trivial, communicator should be last.
if (rank() == 0) { Actor* controler = new Controller(); controler->Start(); }
if (node::is_server(role)) { Actor* server = new Server(); server->Start(); }
if (node::is_worker(role)) { Actor* worker = new Worker(); worker->Start(); }
Actor* communicator = new Communicator();
communicator->Start();
if (!MV_CONFIG_ma) {
int role = ParsePSRole(MV_CONFIG_ps_role);
CHECK(role != -1);
// activate the system
RegisterNode();
Log::Info("Rank %d: Zoo start sucessfully\n", rank());
nodes_.resize(size());
nodes_[rank()].rank = rank();
nodes_[rank()].role = role;
mailbox_.reset(new MtQueue<MessagePtr>);
// NOTE(feiga): the start order is non-trivial, communicator should be last.
if (rank() == 0) { Actor* controler = new Controller(); controler->Start(); }
if (node::is_server(role)) { Actor* server = new Server(); server->Start(); }
if (node::is_worker(role)) { Actor* worker = new Worker(); worker->Start(); }
Actor* communicator = new Communicator();
communicator->Start();
// activate the system
RegisterNode();
Log::Info("Rank %d: Zoo start sucessfully\n", rank());
}
}
void Zoo::Stop(bool finalize_net) {
// Stop the system
Barrier();
if (!MV_CONFIG_ma) {
Barrier();
Dashboard::Display();
Dashboard::Display();
// Stop all actors
for (auto actor : zoo_) { actor.second->Stop(); }
// Stop all actors
for (auto actor : zoo_) { actor.second->Stop(); }
}
// Stop the network
if (finalize_net) net_util_->Finalize();
}