add interface for model average

This commit is contained in:
feiga 2016-04-17 19:56:58 +08:00
Родитель 2af610bf25
Коммит 468948ea44
16 изменённых файлов: 239 добавлений и 205 удалений

Просмотреть файл

@ -83,7 +83,7 @@
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<AdditionalDependencies>IMultiverso.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies);</AdditionalDependencies>
<AdditionalDependencies>Multiverso.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies);</AdditionalDependencies>
<AdditionalLibraryDirectories>D:\multiverso-next\x64\release</AdditionalLibraryDirectories>
</Link>
<ProjectReference>

Просмотреть файл

@ -10,6 +10,7 @@
#include <multiverso/net.h>
#include <multiverso/util/log.h>
#include <multiverso/util/net_util.h>
#include <multiverso/util/configure.h>
#include <multiverso/table/array_table.h>
#include <multiverso/table/kv_table.h>
@ -112,7 +113,6 @@ void TestArray(int argc, char* argv[]) {
std::cout << data[i] << " "; std::cout << std::endl;
MV_Barrier();
if (iter % 100 == 0) MV_Dashboard();
}
MV_ShutDown();
}
@ -248,7 +248,8 @@ void TestNet(int argc, char* argv[]) {
Log::Info("recv from srv %d: %s\n", msg->src(), recv_data[i].data());
};
}
} else {// other rank
}
else {// other rank
MessagePtr msg(new Message());// = std::make_unique<Message>();
while (net->Recv(&msg) == 0) {
// Log::Info("recv return 0\n");
@ -412,7 +413,7 @@ void TestMatrix(int argc, char* argv[]){
void TestCheckPoint(int argc, char* argv[], bool restore){
Log::Info("Test CheckPoint\n");
MV_Init(&argc, argv, 3 /*, restore */);
MV_Init(&argc, argv);
int num_row = 11, num_col = 10;
int size = num_row * num_col;
@ -453,10 +454,16 @@ void TestCheckPoint(int argc, char* argv[], bool restore){
MV_ShutDown();
}
void TestComm(int argc, char* argv[]) {
void TestAllreduce(int argc, char* argv[]) {
multiverso::SetCMDFlag("ps_role", std::string("none"));
MV_Init(&argc, argv);
int a = 1;
MV_Aggregate(&a, 1);
std::cout << "a = " << a << std::endl;
MV_ShutDown();
}
int main(int argc, char* argv[]) {
Log::ResetLogLevel(LogLevel::Debug);
@ -477,6 +484,7 @@ int main(int argc, char* argv[]) {
else if (strcmp(argv[1], "nonet") == 0) TestNoNet(argc, argv);
else if (strcmp(argv[1], "checkpoint") == 0) TestCheckPoint(argc, argv, false);
else if (strcmp(argv[1], "restore") == 0) TestCheckPoint(argc, argv, true);
else if (strcmp(argv[1], "allreduce") == 0) TestAllreduce(argc, argv);
else CHECK(false);
}
return 0;

Просмотреть файл

@ -5,13 +5,11 @@
namespace multiverso {
void MV_Init(int* argc = nullptr,
char* argv[] = nullptr,
int role = 3);
void MV_Init(int* argc = nullptr, char* argv[] = nullptr);
void MV_Barrier();
void MV_ShutDown(bool finalize_mpi = true);
void MV_ShutDown(bool finalize_net = true);
int MV_Rank();
int MV_Size();
@ -25,9 +23,9 @@ int MV_ServerId();
int MV_WorkerIdToRank(int worker_id);
int MV_ServerIdToRank(int server_id);
// Show the dashboard information about the monitored excuation time
// used for profile
void MV_Dashboard();
// inplace sum by allreduce
template <typename ElemType>
void MV_Aggregate(ElemType* data, int size);
// --- Net API -------------------------------------------------------------- //
// NOTE(feiga): these API is only used for specific situation.

Просмотреть файл

@ -32,6 +32,8 @@ public:
virtual int size() const = 0;
virtual int rank() const = 0;
// virtual void Allreduce(void* data, size_t count, int type, int type_size);
// \return 1. > 0 sended size
// 2. = 0 not sended
// 3. < 0 net error
@ -45,6 +47,12 @@ public:
virtual int thread_level_support() = 0;
};
namespace net {
// inplace allreduce
template <typename Typename>
void Allreduce(Typename* data, size_t elem_count);
}
} // namespace multiverso
#endif // MULTIVERSO_NET_NET_H_

Просмотреть файл

@ -70,6 +70,7 @@ public:
MV_MPI_CALL(MPI_Initialized(&inited_));
if (!inited_) {
MV_MPI_CALL(MPI_Init_thread(argc, &argv, MPI_THREAD_SERIALIZED, &thread_provided_));
MV_MPI_CALL(MPI_Initialized(&inited_));
}
MV_MPI_CALL(MPI_Query_thread(&thread_provided_));
if (thread_provided_ < MPI_THREAD_SERIALIZED) {
@ -105,6 +106,9 @@ public:
int size() const override { return size_; }
std::string name() const override { return "MPI"; }
template <typename ElemType>
static void Allreduce(ElemType* data, size_t elem_count, int op = MPI_SUM);
//size_t Send(MessagePtr& msg) override {
// while (!msg_handles_.empty()) {
// MPIMsgHandle* prev = msg_handles_.front();

Просмотреть файл

@ -4,15 +4,16 @@
namespace multiverso {
enum Role {
NONE = 0,
WORKER = 1,
SERVER = 2
SERVER = 2,
ALL = 3
};
struct Node {
int rank;
// role can be 0, 1, 2, 3
// 00 means neither worker nor server, should be controllor, so at most
// one node could use this value
// 00 means neither worker nor server
// 01 means worker
// 10 means server
// 11 means both server and worker, default value

Просмотреть файл

@ -9,21 +9,17 @@
namespace multiverso {
class ServerTable;
class Synchronizer;
class Server : public Actor {
public:
Server();
int RegisterTable(ServerTable* table);
// store server data to file
void StoreTable(int epoch);
// load data from file and return next iteration number
int LoadTable(const std::string& file_path);
void SetTableFilePath(const std::string& table_file_path);
private:
void ProcessGet(MessagePtr& msg);
void ProcessAdd(MessagePtr& msg);
std::string table_file_path_;
// contains the parameter data structure and related handle method
// Synchronizer* sync_;
std::vector<ServerTable*> store_;
};

Просмотреть файл

@ -24,7 +24,7 @@ public:
static Zoo* Get() { static Zoo zoo; return &zoo; }
// Start all actors
void Start(int* argc, char** argv, int role);
void Start(int* argc, char** argv);
// Stop all actors
void Stop(bool finalize_net);

Просмотреть файл

@ -195,7 +195,6 @@
<ClInclude Include="..\include\multiverso\table_interface.h" />
<ClInclude Include="..\include\multiverso\updater\adagrad_updater.h" />
<ClInclude Include="..\include\multiverso\updater\sgd_updater.h" />
<ClInclude Include="..\include\multiverso\updater\second_order_gradient_updater.h" />
<ClInclude Include="..\include\multiverso\updater\momentum_updater.h" />
<ClInclude Include="..\include\multiverso\updater\updater.h" />
<ClInclude Include="..\include\multiverso\util\configure.h" />
@ -221,6 +220,7 @@
<ClCompile Include="net.cpp" />
<ClCompile Include="net\allreduce_engine.cpp" />
<ClCompile Include="net\allreduce_topo.cpp" />
<ClCompile Include="net\mpi_net.cpp" />
<ClCompile Include="node.cpp" />
<ClCompile Include="server.cpp" />
<ClCompile Include="table.cpp" />

Просмотреть файл

@ -94,9 +94,6 @@
<ClInclude Include="..\include\multiverso\updater\adagrad_updater.h">
<Filter>updater</Filter>
</ClInclude>
<ClInclude Include="..\include\multiverso\updater\second_order_gradient_updater.h">
<Filter>updater</Filter>
</ClInclude>
<ClInclude Include="..\include\multiverso\table\sparse_matrix_table.h">
<Filter>table</Filter>
</ClInclude>
@ -215,5 +212,8 @@
<ClCompile Include="io\local_stream.cpp">
<Filter>io</Filter>
</ClCompile>
<ClCompile Include="net\mpi_net.cpp">
<Filter>net</Filter>
</ClCompile>
</ItemGroup>
</Project>

Просмотреть файл

@ -6,8 +6,8 @@
namespace multiverso {
void MV_Init(int* argc, char* argv[], int role) {
Zoo::Get()->Start(argc, argv, role);
void MV_Init(int* argc, char* argv[]) {
Zoo::Get()->Start(argc, argv);
}
void MV_ShutDown(bool finalize_net) {
@ -42,8 +42,9 @@ int MV_ServerIdToRank(int server_id) {
return Zoo::Get()->server_id_to_rank(server_id);
}
void MV_Dashboard() {
Dashboard::Display();
template <typename ElemType>
void MV_Aggregate(ElemType* data, int size) {
net::Allreduce(data, size);
}
int MV_NetBind(int rank, char* endpoint) {
@ -54,4 +55,9 @@ int MV_NetConnect(int* ranks, char* endpoints[], int size) {
return NetInterface::Get()->Connect(ranks, endpoints, size);
}
template void MV_Aggregate<char>(char*, int);
template void MV_Aggregate<int>(int*, int);
template void MV_Aggregate<float>(float*, int);
template void MV_Aggregate<double>(double*, int);
} // namespace multiverso

Просмотреть файл

@ -23,4 +23,23 @@ NetInterface* NetInterface::Get() {
#endif
}
namespace net {
template <typename Typename>
void Allreduce(Typename* data, size_t elem_count) {
#ifdef MULTIVERSO_USE_MPI
CHECK(NetInterface::Get()->active());
MPINetWrapper::Allreduce(data, elem_count);
#else
#endif
}
template void Allreduce<char>(char*, size_t);
template void Allreduce<int>(int*, size_t);
template void Allreduce<float>(float*, size_t);
template void Allreduce<double>(double*, size_t);
} // namespace net
} // namespace multiverso

23
src/net/mpi_net.cpp Normal file
Просмотреть файл

@ -0,0 +1,23 @@
#include "multiverso/net/mpi_net.h"
namespace multiverso {
namespace {
MPI_Datatype GetDataType(char*) { return MPI_CHAR; }
MPI_Datatype GetDataType(int*) { return MPI_INT; }
MPI_Datatype GetDataType(float*) { return MPI_FLOAT; }
MPI_Datatype GetDataType(double*) { return MPI_DOUBLE; }
}
template <typename ElemType>
void MPINetWrapper::Allreduce(ElemType* data, size_t elem_count, int op) {
MPI_Allreduce(MPI_IN_PLACE, data, (int)elem_count,
GetDataType(data), op, MPI_COMM_WORLD);
}
template void MPINetWrapper::Allreduce<char>(char*, size_t, int);
template void MPINetWrapper::Allreduce<int>(int*, size_t, int);
template void MPINetWrapper::Allreduce<float>(float*, size_t, int);
template void MPINetWrapper::Allreduce<double>(double*, size_t, int);
} // namespace multiverso

Просмотреть файл

@ -44,51 +44,4 @@ void Server::ProcessAdd(MessagePtr& msg) {
MONITOR_END(SERVER_PROCESS_ADD)
}
void Server::SetTableFilePath(const std::string& table_file_path) {
int id = Zoo::Get()->server_rank();
std::string server_id_str = (id == 0 ? "0" : "");
while (id > 0) {
server_id_str = static_cast<char>((id % 10) + '0') + server_id_str;
id /= 10;
}
table_file_path_ = table_file_path + server_id_str;
}
void Server::StoreTable(int epoch) {
Stream* stream = StreamFactory::GetStream(URI(table_file_path_),
FileOpenMode::Write);
stream->Write(&epoch, sizeof(int));
for (int i = 0; i < store_.size(); ++i) {
store_[i]->Store(stream);
}
delete stream;
}
int Server::LoadTable(const std::string& file_path) {
Stream* stream = StreamFactory::GetStream(URI(table_file_path_),
FileOpenMode::Read);
if (!stream->Good()) {
Log::Error("Rank %d open file %s error in Server::LoadTable\n",
Zoo::Get()->rank(), file_path.c_str());
delete stream;
return 0; // open file error, may not exist
}
int iter;
size_t readsize = stream->Read(&iter, sizeof(int));
if (readsize == 0) {
Log::Error("Rank %d read file %s no data in Server::LoadTable\n",
Zoo::Get()->rank(), file_path.c_str());
delete stream;
return 0; // no store data
}
for (int i = 0; i < store_.size(); ++i) {
store_[i]->Load(stream);
}
delete stream;
return iter + 1; // the next iteration number
}
} // namespace multiverso

Просмотреть файл

@ -2,7 +2,6 @@
#include "multiverso/updater/adagrad_updater.h"
#include "multiverso/updater/momentum_updater.h"
#include "multiverso/updater/second_order_gradient_updater.h"
#include "multiverso/updater/sgd_updater.h"
#include "multiverso/util/configure.h"
#include "multiverso/util/log.h"
@ -35,7 +34,6 @@ Updater<T>* Updater<T>::GetUpdater(size_t size) {
if (type == "sgd") return new SGDUpdater<T>(size);
if (type == "adagrad") return new AdaGradUpdater<T>(size);
if (type == "momentum_sgd") return new MomentumUpdater<T>(size);
if (type == "second_order_sgd") return new SecondOrderUpdater<T>(size);
// Default: simple updater
return new Updater<T>();
}

Просмотреть файл

@ -20,14 +20,31 @@ Zoo::Zoo() {}
Zoo::~Zoo() {}
void Zoo::Start(int* argc, char** argv, int role) {
MV_DEFINE_string(ps_role, "default", "none / worker / server / default");
MV_DEFINE_bool(ma, "false", "model average, will not start server if true");
namespace {
int ParsePSRole(const std::string& ps_role) {
if (ps_role == "none") return Role::NONE;
if (ps_role == "worker") return Role::WORKER;
if (ps_role == "server") return Role::SERVER;
if (ps_role == "default") return Role::ALL;
return -1;
}
} // namespace
void Zoo::Start(int* argc, char** argv) {
Log::Debug("Zoo started\n");
CHECK(role >= 0 && role <= 3);
ParseCMDFlags(argc, argv);
// Init the network
net_util_ = NetInterface::Get();
net_util_->Init(argc, argv);
if (!MV_CONFIG_ma) {
int role = ParsePSRole(MV_CONFIG_ps_role);
CHECK(role != -1);
nodes_.resize(size());
nodes_[rank()].rank = rank();
nodes_[rank()].role = role;
@ -44,15 +61,18 @@ void Zoo::Start(int* argc, char** argv, int role) {
RegisterNode();
Log::Info("Rank %d: Zoo start sucessfully\n", rank());
}
}
void Zoo::Stop(bool finalize_net) {
// Stop the system
if (!MV_CONFIG_ma) {
Barrier();
Dashboard::Display();
// Stop all actors
for (auto actor : zoo_) { actor.second->Stop(); }
}
// Stop the network
if (finalize_net) net_util_->Finalize();
}