diff --git a/Test/Test.vcxproj b/Test/Test.vcxproj
index 01ed1e3..81a2309 100644
--- a/Test/Test.vcxproj
+++ b/Test/Test.vcxproj
@@ -83,7 +83,7 @@
true
true
true
- IMultiverso.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies);
+ Multiverso.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies);
D:\multiverso-next\x64\release
diff --git a/Test/main.cpp b/Test/main.cpp
index 2f2aafa..adf1f9e 100644
--- a/Test/main.cpp
+++ b/Test/main.cpp
@@ -10,6 +10,7 @@
#include
#include
#include
+#include
#include
#include
@@ -112,7 +113,6 @@ void TestArray(int argc, char* argv[]) {
std::cout << data[i] << " "; std::cout << std::endl;
MV_Barrier();
- if (iter % 100 == 0) MV_Dashboard();
}
MV_ShutDown();
}
@@ -157,52 +157,52 @@ void TestArray(int argc, char* argv[]) {
#define ARRAY_SIZE 4683776
void TestMultipleThread(int argc, char* argv[])
{
- Log::Info("Test Multiple threads \n");
- std::mt19937_64 eng{ std::random_device{}() };
- std::uniform_int_distribution<> dist{ 5, 10000 };
- std::this_thread::sleep_for(std::chrono::milliseconds{ dist(eng) });
- //Log::ResetLogLevel(LogLevel::Debug);
- MV_Init(&argc, argv);
+ Log::Info("Test Multiple threads \n");
+ std::mt19937_64 eng{ std::random_device{}() };
+ std::uniform_int_distribution<> dist{ 5, 10000 };
+ std::this_thread::sleep_for(std::chrono::milliseconds{ dist(eng) });
+ //Log::ResetLogLevel(LogLevel::Debug);
+ MV_Init(&argc, argv);
- ArrayWorker* shared_array = new ArrayWorker(ARRAY_SIZE);
- ArrayServer* server_array = new ArrayServer(ARRAY_SIZE);
- std::thread* m_prefetchThread = nullptr;
- MV_Barrier();
- Log::Info("Create tables OK\n");
+ ArrayWorker* shared_array = new ArrayWorker(ARRAY_SIZE);
+ ArrayServer* server_array = new ArrayServer(ARRAY_SIZE);
+ std::thread* m_prefetchThread = nullptr;
+ MV_Barrier();
+ Log::Info("Create tables OK\n");
- std::vector delta(ARRAY_SIZE);
- while (true){
- if (m_prefetchThread != nullptr && m_prefetchThread->joinable())
- {
- m_prefetchThread->join();
- delete m_prefetchThread;
- m_prefetchThread = nullptr;
- }
+ std::vector delta(ARRAY_SIZE);
+ while (true){
+ if (m_prefetchThread != nullptr && m_prefetchThread->joinable())
+ {
+ m_prefetchThread->join();
+ delete m_prefetchThread;
+ m_prefetchThread = nullptr;
+ }
- std::fill(delta.begin(), delta.end(), 0);
- for (int i = 0; i < ARRAY_SIZE; ++i)
- {
- std::mt19937_64 eng{ std::random_device{}() };
- std::uniform_real_distribution dist{ -1, 1 };
- delta[i] = dist(eng);
- }
- m_prefetchThread = new std::thread([&](){
-
- //std::mt19937_64 eng{ std::random_device{}() };
- //std::uniform_int_distribution<> dist{ 50, 500 };
- //std::this_thread::sleep_for(std::chrono::milliseconds{ dist(eng) });
- shared_array->Add(delta.data(), ARRAY_SIZE);
- shared_array->Get(delta.data(), ARRAY_SIZE);
- Log::Info("Rank %d Get OK\n", MV_Rank());
- for (int i = 0; i < 10; ++i)
- std::cout << delta[i] << " "; std::cout << std::endl;
- });
+ std::fill(delta.begin(), delta.end(), 0);
+ for (int i = 0; i < ARRAY_SIZE; ++i)
+ {
+ std::mt19937_64 eng{ std::random_device{}() };
+ std::uniform_real_distribution dist{ -1, 1 };
+ delta[i] = dist(eng);
+ }
+ m_prefetchThread = new std::thread([&](){
- //shared_array->Get(data, 10);
- MV_Barrier();
+ //std::mt19937_64 eng{ std::random_device{}() };
+ //std::uniform_int_distribution<> dist{ 50, 500 };
+ //std::this_thread::sleep_for(std::chrono::milliseconds{ dist(eng) });
+ shared_array->Add(delta.data(), ARRAY_SIZE);
+ shared_array->Get(delta.data(), ARRAY_SIZE);
+ Log::Info("Rank %d Get OK\n", MV_Rank());
+ for (int i = 0; i < 10; ++i)
+ std::cout << delta[i] << " "; std::cout << std::endl;
+ });
- }
- MV_ShutDown();
+ //shared_array->Get(data, 10);
+ MV_Barrier();
+
+ }
+ MV_ShutDown();
}
@@ -230,7 +230,7 @@ void TestNet(int argc, char* argv[]) {
for (int i = 0; i < msg->size(); ++i) {
Log::Info("In Send: %s\n", msg->data()[i].data());
};
- while (net->Send(msg) == 0) ;
+ while (net->Send(msg) == 0);
Log::Info("rank 0 send\n");
}
@@ -248,7 +248,8 @@ void TestNet(int argc, char* argv[]) {
Log::Info("recv from srv %d: %s\n", msg->src(), recv_data[i].data());
};
}
- } else {// other rank
+ }
+ else {// other rank
MessagePtr msg(new Message());// = std::make_unique();
while (net->Recv(&msg) == 0) {
// Log::Info("recv return 0\n");
@@ -267,7 +268,7 @@ void TestNet(int argc, char* argv[]) {
msg->Push(Blob(hi1, 13));
msg->Push(Blob(hi2, 11));
msg->Push(Blob(hi3, 18));
- while (net->Send(msg) == 0) ;
+ while (net->Send(msg) == 0);
Log::Info("rank %d send\n", net->rank());
}
// while (!net->Test()) {
@@ -283,7 +284,7 @@ void TestIP() {
for (auto ip : ip_list) Log::Info("%s\n", ip.c_str());
}
-void TestNoNet(int argc, char* argv[]) {
+void TestNoNet(int argc, char* argv[]) {
int provided;
MPI_Init_thread(&argc, &argv, MPI_THREAD_SERIALIZED, &provided);
@@ -329,90 +330,90 @@ void TestNoNet(int argc, char* argv[]) {
}
void TestMatrix(int argc, char* argv[]){
- Log::Info("Test Matrix\n");
+ Log::Info("Test Matrix\n");
- MV_Init(&argc, argv);
+ MV_Init(&argc, argv);
- int num_row = 11, num_col = 10;
- int size = num_row * num_col;
+ int num_row = 11, num_col = 10;
+ int size = num_row * num_col;
// MatrixWorkerTable* worker_table =
- // static_cast*>(MV_CreateTable("matrix", { &num_row, &num_col })); //new implementation
+ // static_cast*>(MV_CreateTable("matrix", { &num_row, &num_col })); //new implementation
// static_cast*>((new MatrixTableHelper(num_row, num_col))->CreateTable()); //older one
//if (worker_table == nullptr){ //should have more if statement to avoid nullptr in using worker_table
// Log::Debug("rank %d has no worker\n", MV_Rank());
// }
- MatrixWorkerTable* worker_table = new MatrixWorkerTable(num_row, num_col);
- MatrixServerTable* server_table = new MatrixServerTable(num_row, num_col);
- std::thread* m_prefetchThread = nullptr;
- MV_Barrier();
+ MatrixWorkerTable* worker_table = new MatrixWorkerTable(num_row, num_col);
+ MatrixServerTable* server_table = new MatrixServerTable(num_row, num_col);
+ std::thread* m_prefetchThread = nullptr;
+ MV_Barrier();
- while (true)
- {
- if (m_prefetchThread != nullptr && m_prefetchThread->joinable())
- {
- m_prefetchThread->join();
- delete m_prefetchThread;
- m_prefetchThread = nullptr;
- }
- std::vector v = { 0, 1, 5, 10 };
+ while (true)
+ {
+ if (m_prefetchThread != nullptr && m_prefetchThread->joinable())
+ {
+ m_prefetchThread->join();
+ delete m_prefetchThread;
+ m_prefetchThread = nullptr;
+ }
+ std::vector v = { 0, 1, 5, 10 };
- // test data
- std::vector delta(size);
- for (int i = 0; i < size; ++i)
- delta[i] = i;
+ // test data
+ std::vector delta(size);
+ for (int i = 0; i < size; ++i)
+ delta[i] = i;
- float * data = new float[size];
- m_prefetchThread = new std::thread([&](){
+ float * data = new float[size];
+ m_prefetchThread = new std::thread([&](){
UpdateOption option;
- worker_table->Add(delta.data(), size, &option); //add all
+ worker_table->Add(delta.data(), size, &option); //add all
- worker_table->Get(data, size); //get all
- printf("----------------------------\n");
- for (int i = 0; i < num_row; ++i){
- printf("rank %d, row %d: ", MV_Rank(), i);
- for (int j = 0; j < num_col; ++j)
- printf("%.2f ", data[i * num_col + j]);
- printf("\n");
- };
- });
+ worker_table->Get(data, size); //get all
+ printf("----------------------------\n");
+ for (int i = 0; i < num_row; ++i){
+ printf("rank %d, row %d: ", MV_Rank(), i);
+ for (int j = 0; j < num_col; ++j)
+ printf("%.2f ", data[i * num_col + j]);
+ printf("\n");
+ };
+ });
- MV_Barrier();
- if (m_prefetchThread != nullptr && m_prefetchThread->joinable())
- {
- m_prefetchThread->join();
- delete m_prefetchThread;
- m_prefetchThread = nullptr;
- }
- //test data_vec
- std::vector data_rows = { &data[0], &data[num_col], &data[5 * num_col], &data[10 * num_col] };
- std::vector delta_rows = { &delta[0], &delta[num_col], &delta[5 * num_col], &delta[10 * num_col] };
+ MV_Barrier();
+ if (m_prefetchThread != nullptr && m_prefetchThread->joinable())
+ {
+ m_prefetchThread->join();
+ delete m_prefetchThread;
+ m_prefetchThread = nullptr;
+ }
+ //test data_vec
+ std::vector data_rows = { &data[0], &data[num_col], &data[5 * num_col], &data[10 * num_col] };
+ std::vector delta_rows = { &delta[0], &delta[num_col], &delta[5 * num_col], &delta[10 * num_col] };
UpdateOption option;
- worker_table->Add(v, delta_rows, num_col, &option);
- worker_table->Get(v, data_rows, num_col);
- MV_Barrier();
+ worker_table->Add(v, delta_rows, num_col, &option);
+ worker_table->Get(v, data_rows, num_col);
+ MV_Barrier();
- printf("----------------------------\n");
- for (int i = 0; i < num_row; ++i){
- printf("rank %d, row %d: ", MV_Rank(), i);
- for (int j = 0; j < num_col; ++j)
- printf("%.2f ", data[i * num_col + j]);
- printf("\n");
- }
- MV_Barrier();
+ printf("----------------------------\n");
+ for (int i = 0; i < num_row; ++i){
+ printf("rank %d, row %d: ", MV_Rank(), i);
+ for (int j = 0; j < num_col; ++j)
+ printf("%.2f ", data[i * num_col + j]);
+ printf("\n");
+ }
+ MV_Barrier();
- }
- MV_ShutDown();
+ }
+ MV_ShutDown();
}
// NOTE(feiga): this doesn't work now since I roll back some implementation
void TestCheckPoint(int argc, char* argv[], bool restore){
Log::Info("Test CheckPoint\n");
- MV_Init(&argc, argv, 3 /*, restore */);
+ MV_Init(&argc, argv);
int num_row = 11, num_col = 10;
int size = num_row * num_col;
@@ -453,19 +454,25 @@ void TestCheckPoint(int argc, char* argv[], bool restore){
MV_ShutDown();
}
-void TestComm(int argc, char* argv[]) {
-
+void TestAllreduce(int argc, char* argv[]) {
+ multiverso::SetCMDFlag("ps_role", std::string("none"));
+ MV_Init(&argc, argv);
+ int a = 1;
+ MV_Aggregate(&a, 1);
+ std::cout << "a = " << a << std::endl;
+ MV_ShutDown();
}
+
int main(int argc, char* argv[]) {
Log::ResetLogLevel(LogLevel::Debug);
if (argc == 1){
- multiverso::MV_Init();
- ::testing::InitGoogleTest(&argc, argv);
- auto res = RUN_ALL_TESTS();
- multiverso::MV_ShutDown();
- return res;
+ multiverso::MV_Init();
+ ::testing::InitGoogleTest(&argc, argv);
+ auto res = RUN_ALL_TESTS();
+ multiverso::MV_ShutDown();
+ return res;
}
else {
if (strcmp(argv[1], "kv") == 0) TestKV(argc, argv);
@@ -477,6 +484,7 @@ int main(int argc, char* argv[]) {
else if (strcmp(argv[1], "nonet") == 0) TestNoNet(argc, argv);
else if (strcmp(argv[1], "checkpoint") == 0) TestCheckPoint(argc, argv, false);
else if (strcmp(argv[1], "restore") == 0) TestCheckPoint(argc, argv, true);
+ else if (strcmp(argv[1], "allreduce") == 0) TestAllreduce(argc, argv);
else CHECK(false);
}
return 0;
diff --git a/include/multiverso/multiverso.h b/include/multiverso/multiverso.h
index 01a261f..510d2d3 100644
--- a/include/multiverso/multiverso.h
+++ b/include/multiverso/multiverso.h
@@ -5,13 +5,11 @@
namespace multiverso {
-void MV_Init(int* argc = nullptr,
- char* argv[] = nullptr,
- int role = 3);
+void MV_Init(int* argc = nullptr, char* argv[] = nullptr);
void MV_Barrier();
-void MV_ShutDown(bool finalize_mpi = true);
+void MV_ShutDown(bool finalize_net = true);
int MV_Rank();
int MV_Size();
@@ -25,9 +23,9 @@ int MV_ServerId();
int MV_WorkerIdToRank(int worker_id);
int MV_ServerIdToRank(int server_id);
-// Show the dashboard information about the monitored excuation time
-// used for profile
-void MV_Dashboard();
+// inplace sum by allreduce
+template
+void MV_Aggregate(ElemType* data, int size);
// --- Net API -------------------------------------------------------------- //
// NOTE(feiga): these API is only used for specific situation.
diff --git a/include/multiverso/net.h b/include/multiverso/net.h
index 150b321..a907ce2 100644
--- a/include/multiverso/net.h
+++ b/include/multiverso/net.h
@@ -32,6 +32,8 @@ public:
virtual int size() const = 0;
virtual int rank() const = 0;
+ // virtual void Allreduce(void* data, size_t count, int type, int type_size);
+
// \return 1. > 0 sended size
// 2. = 0 not sended
// 3. < 0 net error
@@ -45,6 +47,12 @@ public:
virtual int thread_level_support() = 0;
};
+namespace net {
+// inplace allreduce
+template
+void Allreduce(Typename* data, size_t elem_count);
+}
+
} // namespace multiverso
#endif // MULTIVERSO_NET_NET_H_
diff --git a/include/multiverso/net/mpi_net.h b/include/multiverso/net/mpi_net.h
index 2e58dc2..fd49181 100644
--- a/include/multiverso/net/mpi_net.h
+++ b/include/multiverso/net/mpi_net.h
@@ -70,6 +70,7 @@ public:
MV_MPI_CALL(MPI_Initialized(&inited_));
if (!inited_) {
MV_MPI_CALL(MPI_Init_thread(argc, &argv, MPI_THREAD_SERIALIZED, &thread_provided_));
+ MV_MPI_CALL(MPI_Initialized(&inited_));
}
MV_MPI_CALL(MPI_Query_thread(&thread_provided_));
if (thread_provided_ < MPI_THREAD_SERIALIZED) {
@@ -105,6 +106,9 @@ public:
int size() const override { return size_; }
std::string name() const override { return "MPI"; }
+ template
+ static void Allreduce(ElemType* data, size_t elem_count, int op = MPI_SUM);
+
//size_t Send(MessagePtr& msg) override {
// while (!msg_handles_.empty()) {
// MPIMsgHandle* prev = msg_handles_.front();
diff --git a/include/multiverso/node.h b/include/multiverso/node.h
index a11cd42..9f3155d 100644
--- a/include/multiverso/node.h
+++ b/include/multiverso/node.h
@@ -4,15 +4,16 @@
namespace multiverso {
enum Role {
+ NONE = 0,
WORKER = 1,
- SERVER = 2
+ SERVER = 2,
+ ALL = 3
};
struct Node {
int rank;
// role can be 0, 1, 2, 3
- // 00 means neither worker nor server, should be controllor, so at most
- // one node could use this value
+ // 00 means neither worker nor server
// 01 means worker
// 10 means server
// 11 means both server and worker, default value
diff --git a/include/multiverso/server.h b/include/multiverso/server.h
index fed650a..7c83832 100644
--- a/include/multiverso/server.h
+++ b/include/multiverso/server.h
@@ -9,21 +9,17 @@
namespace multiverso {
class ServerTable;
+class Synchronizer;
class Server : public Actor {
public:
Server();
int RegisterTable(ServerTable* table);
- // store server data to file
- void StoreTable(int epoch);
- // load data from file and return next iteration number
- int LoadTable(const std::string& file_path);
- void SetTableFilePath(const std::string& table_file_path);
private:
void ProcessGet(MessagePtr& msg);
void ProcessAdd(MessagePtr& msg);
- std::string table_file_path_;
// contains the parameter data structure and related handle method
+ // Synchronizer* sync_;
std::vector store_;
};
diff --git a/include/multiverso/zoo.h b/include/multiverso/zoo.h
index bc1e565..f6a76c9 100644
--- a/include/multiverso/zoo.h
+++ b/include/multiverso/zoo.h
@@ -24,7 +24,7 @@ public:
static Zoo* Get() { static Zoo zoo; return &zoo; }
// Start all actors
- void Start(int* argc, char** argv, int role);
+ void Start(int* argc, char** argv);
// Stop all actors
void Stop(bool finalize_net);
diff --git a/src/Multiverso.vcxproj b/src/Multiverso.vcxproj
index 0610988..890303b 100644
--- a/src/Multiverso.vcxproj
+++ b/src/Multiverso.vcxproj
@@ -195,7 +195,6 @@
-
@@ -221,6 +220,7 @@
+
diff --git a/src/Multiverso.vcxproj.filters b/src/Multiverso.vcxproj.filters
index d929ffc..75165dd 100644
--- a/src/Multiverso.vcxproj.filters
+++ b/src/Multiverso.vcxproj.filters
@@ -94,9 +94,6 @@
updater
-
- updater
-
table
@@ -215,5 +212,8 @@
io
+
+ net
+
\ No newline at end of file
diff --git a/src/multiverso.cpp b/src/multiverso.cpp
index d8e8b98..208ba9a 100644
--- a/src/multiverso.cpp
+++ b/src/multiverso.cpp
@@ -6,8 +6,8 @@
namespace multiverso {
-void MV_Init(int* argc, char* argv[], int role) {
- Zoo::Get()->Start(argc, argv, role);
+void MV_Init(int* argc, char* argv[]) {
+ Zoo::Get()->Start(argc, argv);
}
void MV_ShutDown(bool finalize_net) {
@@ -42,8 +42,9 @@ int MV_ServerIdToRank(int server_id) {
return Zoo::Get()->server_id_to_rank(server_id);
}
-void MV_Dashboard() {
- Dashboard::Display();
+template
+void MV_Aggregate(ElemType* data, int size) {
+ net::Allreduce(data, size);
}
int MV_NetBind(int rank, char* endpoint) {
@@ -54,4 +55,9 @@ int MV_NetConnect(int* ranks, char* endpoints[], int size) {
return NetInterface::Get()->Connect(ranks, endpoints, size);
}
+template void MV_Aggregate(char*, int);
+template void MV_Aggregate(int*, int);
+template void MV_Aggregate(float*, int);
+template void MV_Aggregate(double*, int);
+
} // namespace multiverso
diff --git a/src/net.cpp b/src/net.cpp
index 6127cba..24e8164 100644
--- a/src/net.cpp
+++ b/src/net.cpp
@@ -23,4 +23,23 @@ NetInterface* NetInterface::Get() {
#endif
}
+namespace net {
+template
+void Allreduce(Typename* data, size_t elem_count) {
+#ifdef MULTIVERSO_USE_MPI
+ CHECK(NetInterface::Get()->active());
+ MPINetWrapper::Allreduce(data, elem_count);
+#else
+
+#endif
+}
+
+template void Allreduce(char*, size_t);
+template void Allreduce(int*, size_t);
+template void Allreduce(float*, size_t);
+template void Allreduce(double*, size_t);
+
+} // namespace net
+
+
} // namespace multiverso
diff --git a/src/net/mpi_net.cpp b/src/net/mpi_net.cpp
new file mode 100644
index 0000000..1083a04
--- /dev/null
+++ b/src/net/mpi_net.cpp
@@ -0,0 +1,23 @@
+#include "multiverso/net/mpi_net.h"
+
+namespace multiverso {
+
+namespace {
+MPI_Datatype GetDataType(char*) { return MPI_CHAR; }
+MPI_Datatype GetDataType(int*) { return MPI_INT; }
+MPI_Datatype GetDataType(float*) { return MPI_FLOAT; }
+MPI_Datatype GetDataType(double*) { return MPI_DOUBLE; }
+}
+
+template
+void MPINetWrapper::Allreduce(ElemType* data, size_t elem_count, int op) {
+ MPI_Allreduce(MPI_IN_PLACE, data, (int)elem_count,
+ GetDataType(data), op, MPI_COMM_WORLD);
+}
+
+template void MPINetWrapper::Allreduce(char*, size_t, int);
+template void MPINetWrapper::Allreduce(int*, size_t, int);
+template void MPINetWrapper::Allreduce(float*, size_t, int);
+template void MPINetWrapper::Allreduce(double*, size_t, int);
+
+} // namespace multiverso
\ No newline at end of file
diff --git a/src/server.cpp b/src/server.cpp
index 00a76a9..b083718 100644
--- a/src/server.cpp
+++ b/src/server.cpp
@@ -44,51 +44,4 @@ void Server::ProcessAdd(MessagePtr& msg) {
MONITOR_END(SERVER_PROCESS_ADD)
}
-void Server::SetTableFilePath(const std::string& table_file_path) {
- int id = Zoo::Get()->server_rank();
- std::string server_id_str = (id == 0 ? "0" : "");
- while (id > 0) {
- server_id_str = static_cast((id % 10) + '0') + server_id_str;
- id /= 10;
- }
- table_file_path_ = table_file_path + server_id_str;
-}
-
-void Server::StoreTable(int epoch) {
- Stream* stream = StreamFactory::GetStream(URI(table_file_path_),
- FileOpenMode::Write);
- stream->Write(&epoch, sizeof(int));
- for (int i = 0; i < store_.size(); ++i) {
- store_[i]->Store(stream);
- }
- delete stream;
-}
-
-int Server::LoadTable(const std::string& file_path) {
- Stream* stream = StreamFactory::GetStream(URI(table_file_path_),
- FileOpenMode::Read);
- if (!stream->Good()) {
- Log::Error("Rank %d open file %s error in Server::LoadTable\n",
- Zoo::Get()->rank(), file_path.c_str());
- delete stream;
- return 0; // open file error, may not exist
- }
-
- int iter;
- size_t readsize = stream->Read(&iter, sizeof(int));
- if (readsize == 0) {
- Log::Error("Rank %d read file %s no data in Server::LoadTable\n",
- Zoo::Get()->rank(), file_path.c_str());
- delete stream;
- return 0; // no store data
- }
-
- for (int i = 0; i < store_.size(); ++i) {
- store_[i]->Load(stream);
- }
-
- delete stream;
- return iter + 1; // the next iteration number
-}
-
} // namespace multiverso
diff --git a/src/updater/updater.cpp b/src/updater/updater.cpp
index c9d3a00..98d70da 100644
--- a/src/updater/updater.cpp
+++ b/src/updater/updater.cpp
@@ -2,7 +2,6 @@
#include "multiverso/updater/adagrad_updater.h"
#include "multiverso/updater/momentum_updater.h"
-#include "multiverso/updater/second_order_gradient_updater.h"
#include "multiverso/updater/sgd_updater.h"
#include "multiverso/util/configure.h"
#include "multiverso/util/log.h"
@@ -35,7 +34,6 @@ Updater* Updater::GetUpdater(size_t size) {
if (type == "sgd") return new SGDUpdater(size);
if (type == "adagrad") return new AdaGradUpdater(size);
if (type == "momentum_sgd") return new MomentumUpdater(size);
- if (type == "second_order_sgd") return new SecondOrderUpdater(size);
// Default: simple updater
return new Updater();
}
diff --git a/src/zoo.cpp b/src/zoo.cpp
index ad77b0f..8e7c1c5 100644
--- a/src/zoo.cpp
+++ b/src/zoo.cpp
@@ -20,39 +20,59 @@ Zoo::Zoo() {}
Zoo::~Zoo() {}
-void Zoo::Start(int* argc, char** argv, int role) {
+MV_DEFINE_string(ps_role, "default", "none / worker / server / default");
+MV_DEFINE_bool(ma, "false", "model average, will not start server if true");
+
+namespace {
+int ParsePSRole(const std::string& ps_role) {
+ if (ps_role == "none") return Role::NONE;
+ if (ps_role == "worker") return Role::WORKER;
+ if (ps_role == "server") return Role::SERVER;
+ if (ps_role == "default") return Role::ALL;
+ return -1;
+}
+} // namespace
+
+void Zoo::Start(int* argc, char** argv) {
Log::Debug("Zoo started\n");
- CHECK(role >= 0 && role <= 3);
ParseCMDFlags(argc, argv);
// Init the network
net_util_ = NetInterface::Get();
net_util_->Init(argc, argv);
- nodes_.resize(size());
- nodes_[rank()].rank = rank();
- nodes_[rank()].role = role;
- mailbox_.reset(new MtQueue);
- // NOTE(feiga): the start order is non-trivial, communicator should be last.
- if (rank() == 0) { Actor* controler = new Controller(); controler->Start(); }
- if (node::is_server(role)) { Actor* server = new Server(); server->Start(); }
- if (node::is_worker(role)) { Actor* worker = new Worker(); worker->Start(); }
- Actor* communicator = new Communicator();
- communicator->Start();
+ if (!MV_CONFIG_ma) {
+ int role = ParsePSRole(MV_CONFIG_ps_role);
+ CHECK(role != -1);
- // activate the system
- RegisterNode();
- Log::Info("Rank %d: Zoo start sucessfully\n", rank());
+ nodes_.resize(size());
+ nodes_[rank()].rank = rank();
+ nodes_[rank()].role = role;
+ mailbox_.reset(new MtQueue);
+
+ // NOTE(feiga): the start order is non-trivial, communicator should be last.
+ if (rank() == 0) { Actor* controler = new Controller(); controler->Start(); }
+ if (node::is_server(role)) { Actor* server = new Server(); server->Start(); }
+ if (node::is_worker(role)) { Actor* worker = new Worker(); worker->Start(); }
+ Actor* communicator = new Communicator();
+ communicator->Start();
+
+ // activate the system
+ RegisterNode();
+ Log::Info("Rank %d: Zoo start sucessfully\n", rank());
+ }
}
void Zoo::Stop(bool finalize_net) {
// Stop the system
- Barrier();
+ if (!MV_CONFIG_ma) {
+ Barrier();
- Dashboard::Display();
+ Dashboard::Display();
- // Stop all actors
- for (auto actor : zoo_) { actor.second->Stop(); }
+ // Stop all actors
+ for (auto actor : zoo_) { actor.second->Stop(); }
+ }
// Stop the network
if (finalize_net) net_util_->Finalize();
}