add interface for model average
This commit is contained in:
Родитель
2af610bf25
Коммит
468948ea44
|
@ -83,7 +83,7 @@
|
|||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<EnableCOMDATFolding>true</EnableCOMDATFolding>
|
||||
<OptimizeReferences>true</OptimizeReferences>
|
||||
<AdditionalDependencies>IMultiverso.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies);</AdditionalDependencies>
|
||||
<AdditionalDependencies>Multiverso.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies);</AdditionalDependencies>
|
||||
<AdditionalLibraryDirectories>D:\multiverso-next\x64\release</AdditionalLibraryDirectories>
|
||||
</Link>
|
||||
<ProjectReference>
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include <multiverso/net.h>
|
||||
#include <multiverso/util/log.h>
|
||||
#include <multiverso/util/net_util.h>
|
||||
#include <multiverso/util/configure.h>
|
||||
|
||||
#include <multiverso/table/array_table.h>
|
||||
#include <multiverso/table/kv_table.h>
|
||||
|
@ -112,7 +113,6 @@ void TestArray(int argc, char* argv[]) {
|
|||
std::cout << data[i] << " "; std::cout << std::endl;
|
||||
MV_Barrier();
|
||||
|
||||
if (iter % 100 == 0) MV_Dashboard();
|
||||
}
|
||||
MV_ShutDown();
|
||||
}
|
||||
|
@ -248,7 +248,8 @@ void TestNet(int argc, char* argv[]) {
|
|||
Log::Info("recv from srv %d: %s\n", msg->src(), recv_data[i].data());
|
||||
};
|
||||
}
|
||||
} else {// other rank
|
||||
}
|
||||
else {// other rank
|
||||
MessagePtr msg(new Message());// = std::make_unique<Message>();
|
||||
while (net->Recv(&msg) == 0) {
|
||||
// Log::Info("recv return 0\n");
|
||||
|
@ -412,7 +413,7 @@ void TestMatrix(int argc, char* argv[]){
|
|||
void TestCheckPoint(int argc, char* argv[], bool restore){
|
||||
Log::Info("Test CheckPoint\n");
|
||||
|
||||
MV_Init(&argc, argv, 3 /*, restore */);
|
||||
MV_Init(&argc, argv);
|
||||
|
||||
int num_row = 11, num_col = 10;
|
||||
int size = num_row * num_col;
|
||||
|
@ -453,10 +454,16 @@ void TestCheckPoint(int argc, char* argv[], bool restore){
|
|||
MV_ShutDown();
|
||||
}
|
||||
|
||||
void TestComm(int argc, char* argv[]) {
|
||||
|
||||
// Smoke test for MV_Aggregate: sets the "ps_role" flag to "none", starts
// Multiverso, aggregates a single int in place and prints the result.
void TestAllreduce(int argc, char* argv[]) {
  multiverso::SetCMDFlag("ps_role", std::string("none"));
  MV_Init(&argc, argv);

  // Every process contributes the value 1; MV_Aggregate overwrites it with
  // the aggregated result.
  int local_value = 1;
  MV_Aggregate(&local_value, 1);
  std::cout << "a = " << local_value << std::endl;

  MV_ShutDown();
}
|
||||
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
Log::ResetLogLevel(LogLevel::Debug);
|
||||
|
||||
|
@ -477,6 +484,7 @@ int main(int argc, char* argv[]) {
|
|||
else if (strcmp(argv[1], "nonet") == 0) TestNoNet(argc, argv);
|
||||
else if (strcmp(argv[1], "checkpoint") == 0) TestCheckPoint(argc, argv, false);
|
||||
else if (strcmp(argv[1], "restore") == 0) TestCheckPoint(argc, argv, true);
|
||||
else if (strcmp(argv[1], "allreduce") == 0) TestAllreduce(argc, argv);
|
||||
else CHECK(false);
|
||||
}
|
||||
return 0;
|
||||
|
|
|
@ -5,13 +5,11 @@
|
|||
|
||||
namespace multiverso {
|
||||
|
||||
void MV_Init(int* argc = nullptr,
|
||||
char* argv[] = nullptr,
|
||||
int role = 3);
|
||||
void MV_Init(int* argc = nullptr, char* argv[] = nullptr);
|
||||
|
||||
void MV_Barrier();
|
||||
|
||||
void MV_ShutDown(bool finalize_mpi = true);
|
||||
void MV_ShutDown(bool finalize_net = true);
|
||||
|
||||
int MV_Rank();
|
||||
int MV_Size();
|
||||
|
@ -25,9 +23,9 @@ int MV_ServerId();
|
|||
int MV_WorkerIdToRank(int worker_id);
|
||||
int MV_ServerIdToRank(int server_id);
|
||||
|
||||
// Show the dashboard information about the monitored execution time
|
||||
// used for profile
|
||||
void MV_Dashboard();
|
||||
// inplace sum by allreduce
|
||||
template <typename ElemType>
|
||||
void MV_Aggregate(ElemType* data, int size);
|
||||
|
||||
// --- Net API -------------------------------------------------------------- //
|
||||
// NOTE(feiga): these APIs are only used in specific situations.
|
||||
|
|
|
@ -32,6 +32,8 @@ public:
|
|||
virtual int size() const = 0;
|
||||
virtual int rank() const = 0;
|
||||
|
||||
// virtual void Allreduce(void* data, size_t count, int type, int type_size);
|
||||
|
||||
// \return 1. > 0 sent size
|
||||
// 2. = 0 not sent
|
||||
// 3. < 0 net error
|
||||
|
@ -45,6 +47,12 @@ public:
|
|||
virtual int thread_level_support() = 0;
|
||||
};
|
||||
|
||||
// Free-function collective helpers layered on the net implementation.
namespace net {
|
||||
// In-place allreduce over `elem_count` elements starting at `data`; the
// reduced values are written back into `data`.
// NOTE(review): only a declaration is provided here, so callers can only link
// against element types that are explicitly instantiated next to the
// definition -- confirm the supported set there.
|
||||
template <typename Typename>
|
||||
void Allreduce(Typename* data, size_t elem_count);
|
||||
}
|
||||
|
||||
} // namespace multiverso
|
||||
|
||||
#endif // MULTIVERSO_NET_NET_H_
|
||||
|
|
|
@ -70,6 +70,7 @@ public:
|
|||
MV_MPI_CALL(MPI_Initialized(&inited_));
|
||||
if (!inited_) {
|
||||
MV_MPI_CALL(MPI_Init_thread(argc, &argv, MPI_THREAD_SERIALIZED, &thread_provided_));
|
||||
MV_MPI_CALL(MPI_Initialized(&inited_));
|
||||
}
|
||||
MV_MPI_CALL(MPI_Query_thread(&thread_provided_));
|
||||
if (thread_provided_ < MPI_THREAD_SERIALIZED) {
|
||||
|
@ -105,6 +106,9 @@ public:
|
|||
int size() const override { return size_; }
|
||||
std::string name() const override { return "MPI"; }
|
||||
|
||||
template <typename ElemType>
|
||||
static void Allreduce(ElemType* data, size_t elem_count, int op = MPI_SUM);
|
||||
|
||||
//size_t Send(MessagePtr& msg) override {
|
||||
// while (!msg_handles_.empty()) {
|
||||
// MPIMsgHandle* prev = msg_handles_.front();
|
||||
|
|
|
@ -4,15 +4,16 @@
|
|||
namespace multiverso {
|
||||
|
||||
enum Role {
|
||||
NONE = 0,
|
||||
WORKER = 1,
|
||||
SERVER = 2
|
||||
SERVER = 2,
|
||||
ALL = 3
|
||||
};
|
||||
|
||||
struct Node {
|
||||
int rank;
|
||||
// role can be 0, 1, 2, 3
|
||||
// 00 means neither worker nor server, should be controller, so at most
|
||||
// one node could use this value
|
||||
// 00 means neither worker nor server
|
||||
// 01 means worker
|
||||
// 10 means server
|
||||
// 11 means both server and worker, default value
|
||||
|
|
|
@ -9,21 +9,17 @@
|
|||
namespace multiverso {
|
||||
|
||||
class ServerTable;
|
||||
class Synchronizer;
|
||||
|
||||
// Server actor. Holds pointers to the registered ServerTable instances in
// store_ and exposes checkpoint store/load entry points for them.
class Server : public Actor {
|
||||
public:
|
||||
Server();
|
||||
// NOTE(review): presumably adds `table` to store_ and returns its id -- the
// definition is not visible here, confirm before relying on the return value.
int RegisterTable(ServerTable* table);
|
||||
// Store server data to file: writes `epoch` followed by every table in
// store_ to the file named by table_file_path_.
|
||||
void StoreTable(int epoch);
|
||||
// Load data from file and return the next iteration number; returns 0 when
// the file cannot be opened or contains no data.
|
||||
int LoadTable(const std::string& file_path);
|
||||
// Sets table_file_path_ to `table_file_path` with the server rank appended.
void SetTableFilePath(const std::string& table_file_path);
|
||||
private:
|
||||
void ProcessGet(MessagePtr& msg);
|
||||
void ProcessAdd(MessagePtr& msg);
|
||||
// File read by LoadTable and written by StoreTable.
std::string table_file_path_;
|
||||
// contains the parameter data structure and related handle method
|
||||
// Synchronizer* sync_;
|
||||
std::vector<ServerTable*> store_;
|
||||
};
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ public:
|
|||
static Zoo* Get() { static Zoo zoo; return &zoo; }
|
||||
|
||||
// Start all actors
|
||||
void Start(int* argc, char** argv, int role);
|
||||
void Start(int* argc, char** argv);
|
||||
// Stop all actors
|
||||
void Stop(bool finalize_net);
|
||||
|
||||
|
|
|
@ -195,7 +195,6 @@
|
|||
<ClInclude Include="..\include\multiverso\table_interface.h" />
|
||||
<ClInclude Include="..\include\multiverso\updater\adagrad_updater.h" />
|
||||
<ClInclude Include="..\include\multiverso\updater\sgd_updater.h" />
|
||||
<ClInclude Include="..\include\multiverso\updater\second_order_gradient_updater.h" />
|
||||
<ClInclude Include="..\include\multiverso\updater\momentum_updater.h" />
|
||||
<ClInclude Include="..\include\multiverso\updater\updater.h" />
|
||||
<ClInclude Include="..\include\multiverso\util\configure.h" />
|
||||
|
@ -221,6 +220,7 @@
|
|||
<ClCompile Include="net.cpp" />
|
||||
<ClCompile Include="net\allreduce_engine.cpp" />
|
||||
<ClCompile Include="net\allreduce_topo.cpp" />
|
||||
<ClCompile Include="net\mpi_net.cpp" />
|
||||
<ClCompile Include="node.cpp" />
|
||||
<ClCompile Include="server.cpp" />
|
||||
<ClCompile Include="table.cpp" />
|
||||
|
|
|
@ -94,9 +94,6 @@
|
|||
<ClInclude Include="..\include\multiverso\updater\adagrad_updater.h">
|
||||
<Filter>updater</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="..\include\multiverso\updater\second_order_gradient_updater.h">
|
||||
<Filter>updater</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="..\include\multiverso\table\sparse_matrix_table.h">
|
||||
<Filter>table</Filter>
|
||||
</ClInclude>
|
||||
|
@ -215,5 +212,8 @@
|
|||
<ClCompile Include="io\local_stream.cpp">
|
||||
<Filter>io</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="net\mpi_net.cpp">
|
||||
<Filter>net</Filter>
|
||||
</ClCompile>
|
||||
</ItemGroup>
|
||||
</Project>
|
|
@ -6,8 +6,8 @@
|
|||
|
||||
namespace multiverso {
|
||||
|
||||
void MV_Init(int* argc, char* argv[], int role) {
|
||||
Zoo::Get()->Start(argc, argv, role);
|
||||
void MV_Init(int* argc, char* argv[]) {
|
||||
Zoo::Get()->Start(argc, argv);
|
||||
}
|
||||
|
||||
void MV_ShutDown(bool finalize_net) {
|
||||
|
@ -42,8 +42,9 @@ int MV_ServerIdToRank(int server_id) {
|
|||
return Zoo::Get()->server_id_to_rank(server_id);
|
||||
}
|
||||
|
||||
void MV_Dashboard() {
|
||||
Dashboard::Display();
|
||||
template <typename ElemType>
|
||||
void MV_Aggregate(ElemType* data, int size) {
|
||||
net::Allreduce(data, size);
|
||||
}
|
||||
|
||||
int MV_NetBind(int rank, char* endpoint) {
|
||||
|
@ -54,4 +55,9 @@ int MV_NetConnect(int* ranks, char* endpoints[], int size) {
|
|||
return NetInterface::Get()->Connect(ranks, endpoints, size);
|
||||
}
|
||||
|
||||
// Explicit instantiations: the MV_Aggregate template is defined in this
// translation unit, so only the element types listed here are emitted for
// callers to link against.
template void MV_Aggregate<char>(char*, int);
|
||||
template void MV_Aggregate<int>(int*, int);
|
||||
template void MV_Aggregate<float>(float*, int);
|
||||
template void MV_Aggregate<double>(double*, int);
|
||||
|
||||
} // namespace multiverso
|
||||
|
|
19
src/net.cpp
19
src/net.cpp
|
@ -23,4 +23,23 @@ NetInterface* NetInterface::Get() {
|
|||
#endif
|
||||
}
|
||||
|
||||
// Definitions of the free-function collective helpers declared in net.h.
namespace net {
|
||||
// In-place allreduce of `elem_count` elements starting at `data`.
// With MULTIVERSO_USE_MPI defined this checks that the net is active and
// forwards to MPINetWrapper::Allreduce. Otherwise the body is empty and
// `data` is left untouched.
// NOTE(review): the non-MPI branch is a silent no-op -- confirm that callers
// in non-MPI builds do not expect aggregated data.
template <typename Typename>
|
||||
void Allreduce(Typename* data, size_t elem_count) {
|
||||
#ifdef MULTIVERSO_USE_MPI
|
||||
// The collective requires an initialized, active net.
CHECK(NetInterface::Get()->active());
|
||||
MPINetWrapper::Allreduce(data, elem_count);
|
||||
#else
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
// Explicit instantiations: the template is defined in this file, so these are
// the only element types available to other translation units.
template void Allreduce<char>(char*, size_t);
|
||||
template void Allreduce<int>(int*, size_t);
|
||||
template void Allreduce<float>(float*, size_t);
|
||||
template void Allreduce<double>(double*, size_t);
|
||||
|
||||
} // namespace net
|
||||
|
||||
|
||||
} // namespace multiverso
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
#include "multiverso/net/mpi_net.h"
|
||||
|
||||
namespace multiverso {

// NOTE(review): guarded with the same macro net.cpp tests before touching
// MPINetWrapper, so this file contributes nothing to non-MPI builds --
// confirm MULTIVERSO_USE_MPI is also the switch mpi_net.h relies on.
#ifdef MULTIVERSO_USE_MPI

namespace {
// Overload set mapping a C++ element pointer type to the matching MPI
// datatype. The argument is used for overload resolution only.
MPI_Datatype GetDataType(char*)   { return MPI_CHAR; }
MPI_Datatype GetDataType(int*)    { return MPI_INT; }
MPI_Datatype GetDataType(float*)  { return MPI_FLOAT; }
MPI_Datatype GetDataType(double*) { return MPI_DOUBLE; }
}  // namespace

// In-place allreduce over MPI_COMM_WORLD.
// \param data        buffer holding this rank's contribution; overwritten with
//                    the reduced result (MPI_IN_PLACE).
// \param elem_count  number of elements in `data`.
// \param op          MPI reduction operation handle.
template <typename ElemType>
void MPINetWrapper::Allreduce(ElemType* data, size_t elem_count, int op) {
  // MPI takes a signed int count.
  // NOTE(review): elem_count values above INT_MAX are truncated by this cast.
  const int count = static_cast<int>(elem_count);
  // Check the return code like every other MPI call in this wrapper instead
  // of silently ignoring a failed collective.
  MV_MPI_CALL(MPI_Allreduce(MPI_IN_PLACE, data, count,
                            GetDataType(data), op, MPI_COMM_WORLD));
}

// Explicit instantiations for the element types that have a GetDataType
// overload above.
template void MPINetWrapper::Allreduce<char>(char*, size_t, int);
template void MPINetWrapper::Allreduce<int>(int*, size_t, int);
template void MPINetWrapper::Allreduce<float>(float*, size_t, int);
template void MPINetWrapper::Allreduce<double>(double*, size_t, int);

#endif  // MULTIVERSO_USE_MPI

}  // namespace multiverso
|
|
@ -44,51 +44,4 @@ void Server::ProcessAdd(MessagePtr& msg) {
|
|||
MONITOR_END(SERVER_PROCESS_ADD)
|
||||
}
|
||||
|
||||
void Server::SetTableFilePath(const std::string& table_file_path) {
|
||||
int id = Zoo::Get()->server_rank();
|
||||
std::string server_id_str = (id == 0 ? "0" : "");
|
||||
while (id > 0) {
|
||||
server_id_str = static_cast<char>((id % 10) + '0') + server_id_str;
|
||||
id /= 10;
|
||||
}
|
||||
table_file_path_ = table_file_path + server_id_str;
|
||||
}
|
||||
|
||||
void Server::StoreTable(int epoch) {
|
||||
Stream* stream = StreamFactory::GetStream(URI(table_file_path_),
|
||||
FileOpenMode::Write);
|
||||
stream->Write(&epoch, sizeof(int));
|
||||
for (int i = 0; i < store_.size(); ++i) {
|
||||
store_[i]->Store(stream);
|
||||
}
|
||||
delete stream;
|
||||
}
|
||||
|
||||
int Server::LoadTable(const std::string& file_path) {
|
||||
Stream* stream = StreamFactory::GetStream(URI(table_file_path_),
|
||||
FileOpenMode::Read);
|
||||
if (!stream->Good()) {
|
||||
Log::Error("Rank %d open file %s error in Server::LoadTable\n",
|
||||
Zoo::Get()->rank(), file_path.c_str());
|
||||
delete stream;
|
||||
return 0; // open file error, may not exist
|
||||
}
|
||||
|
||||
int iter;
|
||||
size_t readsize = stream->Read(&iter, sizeof(int));
|
||||
if (readsize == 0) {
|
||||
Log::Error("Rank %d read file %s no data in Server::LoadTable\n",
|
||||
Zoo::Get()->rank(), file_path.c_str());
|
||||
delete stream;
|
||||
return 0; // no store data
|
||||
}
|
||||
|
||||
for (int i = 0; i < store_.size(); ++i) {
|
||||
store_[i]->Load(stream);
|
||||
}
|
||||
|
||||
delete stream;
|
||||
return iter + 1; // the next iteration number
|
||||
}
|
||||
|
||||
} // namespace multiverso
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
#include "multiverso/updater/adagrad_updater.h"
|
||||
#include "multiverso/updater/momentum_updater.h"
|
||||
#include "multiverso/updater/second_order_gradient_updater.h"
|
||||
#include "multiverso/updater/sgd_updater.h"
|
||||
#include "multiverso/util/configure.h"
|
||||
#include "multiverso/util/log.h"
|
||||
|
@ -35,7 +34,6 @@ Updater<T>* Updater<T>::GetUpdater(size_t size) {
|
|||
if (type == "sgd") return new SGDUpdater<T>(size);
|
||||
if (type == "adagrad") return new AdaGradUpdater<T>(size);
|
||||
if (type == "momentum_sgd") return new MomentumUpdater<T>(size);
|
||||
if (type == "second_order_sgd") return new SecondOrderUpdater<T>(size);
|
||||
// Default: simple updater
|
||||
return new Updater<T>();
|
||||
}
|
||||
|
|
24
src/zoo.cpp
24
src/zoo.cpp
|
@ -20,14 +20,31 @@ Zoo::Zoo() {}
|
|||
|
||||
Zoo::~Zoo() {}
|
||||
|
||||
void Zoo::Start(int* argc, char** argv, int role) {
|
||||
// Command-line flags consumed by Zoo when starting and stopping.
MV_DEFINE_string(ps_role, "default", "none / worker / server / default");
// The default must be the bool literal `false`: a quoted "false" is a string
// literal, and a non-null pointer converts to `true` wherever a bool is
// expected, which would enable model averaging by default.
// NOTE(review): assumes MV_DEFINE_bool takes a bool default rather than
// parsing text -- confirm against configure.h.
MV_DEFINE_bool(ma, false, "model average, will not start server if true");
|
||||
|
||||
namespace {
|
||||
int ParsePSRole(const std::string& ps_role) {
|
||||
if (ps_role == "none") return Role::NONE;
|
||||
if (ps_role == "worker") return Role::WORKER;
|
||||
if (ps_role == "server") return Role::SERVER;
|
||||
if (ps_role == "default") return Role::ALL;
|
||||
return -1;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void Zoo::Start(int* argc, char** argv) {
|
||||
Log::Debug("Zoo started\n");
|
||||
CHECK(role >= 0 && role <= 3);
|
||||
ParseCMDFlags(argc, argv);
|
||||
|
||||
// Init the network
|
||||
net_util_ = NetInterface::Get();
|
||||
net_util_->Init(argc, argv);
|
||||
|
||||
if (!MV_CONFIG_ma) {
|
||||
int role = ParsePSRole(MV_CONFIG_ps_role);
|
||||
CHECK(role != -1);
|
||||
|
||||
nodes_.resize(size());
|
||||
nodes_[rank()].rank = rank();
|
||||
nodes_[rank()].role = role;
|
||||
|
@ -44,15 +61,18 @@ void Zoo::Start(int* argc, char** argv, int role) {
|
|||
RegisterNode();
|
||||
Log::Info("Rank %d: Zoo start sucessfully\n", rank());
|
||||
}
|
||||
}
|
||||
|
||||
void Zoo::Stop(bool finalize_net) {
|
||||
// Stop the system
|
||||
if (!MV_CONFIG_ma) {
|
||||
Barrier();
|
||||
|
||||
Dashboard::Display();
|
||||
|
||||
// Stop all actors
|
||||
for (auto actor : zoo_) { actor.second->Stop(); }
|
||||
}
|
||||
// Stop the network
|
||||
if (finalize_net) net_util_->Finalize();
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче