Revert "reverting commit of check-point for testing"
This reverts commit a6bd4606321afed88621e6c6e0e3d92dd13da3c1.
This commit is contained in:
Родитель
740f93a8db
Коммит
97f3953bc0
|
@ -366,10 +366,9 @@ void TestMatrix(int argc, char* argv[]){
|
|||
//test data_vec
|
||||
std::vector<int*> data_rows = { &data[0], &data[num_col], &data[5 * num_col], &data[10*num_col] };
|
||||
std::vector<int*> delta_rows = { &delta[0], &delta[num_col], &delta[5 * num_col], &delta[10 * num_col] };
|
||||
|
||||
worker_table->Add(v, delta_rows, num_col);
|
||||
worker_table->Get(v, data_rows, num_col);
|
||||
MV_Barrier();
|
||||
worker_table->Add(v, delta_rows, num_col);
|
||||
worker_table->Get(v, data_rows, num_col);
|
||||
MV_Barrier();
|
||||
|
||||
printf("----------------------------\n");
|
||||
for (int i = 0; i < num_row; ++i){
|
||||
|
@ -383,62 +382,21 @@ void TestMatrix(int argc, char* argv[]){
|
|||
MV_ShutDown();
|
||||
}
|
||||
|
||||
void TestCheckPoint(int argc, char* argv[], bool restore){
|
||||
Log::Info("Test CheckPoint\n");
|
||||
|
||||
MV_Init(&argc, argv, All, restore);
|
||||
|
||||
int num_row = 11, num_col = 10;
|
||||
int size = num_row * num_col;
|
||||
|
||||
MatrixWorkerTable<int>* worker_table = new MatrixWorkerTable<int>(num_row, num_col);
|
||||
MatrixServerTable<int>* server_table = new MatrixServerTable<int>(num_row, num_col);
|
||||
//if restore = true, will restore server data and return the next iter number of last dump file
|
||||
//else do nothing and return 0
|
||||
int begin_iter = MV_RestoreTable("//5FTGDB2/tableData/serverTable_");
|
||||
MV_Barrier();//won't dump data without parameters
|
||||
|
||||
std::vector<int> delta(size);
|
||||
for (int i = 0; i < size; ++i)
|
||||
delta[i] = i;
|
||||
int * data = new int[size];
|
||||
|
||||
Log::Debug("rank %d start from iteration %d\n", MV_Rank(), begin_iter);
|
||||
|
||||
for (int i = begin_iter; i < 50; ++i){
|
||||
worker_table->Add(delta.data(), size);
|
||||
MV_Barrier(i); //dump table data with iteration i each k iterations
|
||||
}
|
||||
worker_table->Get(data, size);
|
||||
|
||||
printf("----------------------------\n");
|
||||
for (int i = 0; i < num_row; ++i){
|
||||
printf("rank %d, row %d: ", MV_Rank(), i);
|
||||
for (int j = 0; j < num_col; ++j)
|
||||
printf("%d ", data[i * num_col + j]);
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
MV_ShutDown();
|
||||
}
|
||||
|
||||
void TestComm(int argc, char* argv[]) {
|
||||
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
Log::ResetLogLevel(LogLevel::Debug);
|
||||
if (argc == 2) {
|
||||
if (argc == 2) {
|
||||
if (strcmp(argv[1], "kv") == 0) TestKV(argc, argv);
|
||||
else if (strcmp(argv[1], "array") == 0) TestArray(argc, argv);
|
||||
else if (strcmp(argv[1], "net") == 0) TestNet(argc, argv);
|
||||
else if (strcmp(argv[1], "ip") == 0) TestIP();
|
||||
else if (strcmp(argv[1], "momentum") == 0) TestMomentum(argc, argv);
|
||||
else if (strcmp(argv[1], "threads") == 0) TestMultipleThread(argc, argv);
|
||||
else if (strcmp(argv[1], "matrix") == 0) TestMatrix(argc, argv);
|
||||
else if (strcmp(argv[1], "nonet") == 0) TestNoNet(argc, argv);
|
||||
else if (strcmp(argv[1], "checkpoint") == 0) TestCheckPoint(argc, argv, false);
|
||||
else if (strcmp(argv[1], "restore") == 0) TestCheckPoint(argc, argv, true);
|
||||
else if (strcmp(argv[1], "momentum") == 0) TestMomentum(argc, argv);
|
||||
else if (strcmp(argv[1], "threads") == 0) TestMultipleThread(argc, argv);
|
||||
else if (strcmp(argv[1], "matrix") == 0) TestMatrix(argc, argv);
|
||||
else if (strcmp(argv[1], "nonet") == 0) TestNoNet(argc, argv);
|
||||
else CHECK(false);
|
||||
}
|
||||
// argc == 4 is for zeromq test, with two extra arguments: machinefile, port
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
#ifndef MULTIVERSO_INCLUDE_MULTIVERSO_H_
|
||||
#define MULTIVERSO_INCLUDE_MULTIVERSO_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace multiverso {
|
||||
|
||||
enum Role {
|
||||
|
@ -14,10 +12,9 @@ enum Role {
|
|||
|
||||
void MV_Init(int* argc = nullptr,
|
||||
char* argv[] = nullptr,
|
||||
int role = All,
|
||||
bool restart = false);
|
||||
int role = All);
|
||||
|
||||
void MV_Barrier(int iter = -1);
|
||||
void MV_Barrier();
|
||||
|
||||
void MV_ShutDown(bool finalize_mpi = true);
|
||||
|
||||
|
@ -33,8 +30,6 @@ int MV_ServerId();
|
|||
int MV_WorkerIdToRank(int worker_id);
|
||||
int MV_ServerIdToRank(int server_id);
|
||||
|
||||
int MV_RestoreTable(const std::string& dump_file_path);
|
||||
|
||||
// --- Net API -------------------------------------------------------------- //
|
||||
// NOTE(feiga): these API is only used for specific situation.
|
||||
// Init Multiverso Net with the provided endpoint. Multiverso Net will bind
|
||||
|
|
|
@ -12,15 +12,9 @@ class Server : public Actor {
|
|||
public:
|
||||
Server();
|
||||
int RegisterTable(ServerTable* table);
|
||||
//dump server data to file
|
||||
void DumpTable(const int& epoch);
|
||||
//restore data from file and return next iteration number
|
||||
int RestoreTable(const std::string& file_path);
|
||||
void SetDumpFilePath(const std::string& dump_file_path);
|
||||
private:
|
||||
void ProcessGet(MessagePtr& msg);
|
||||
void ProcessAdd(MessagePtr& msg);
|
||||
std::string dump_file_path_;
|
||||
// contains the parameter data structure and related handle method
|
||||
std::vector<ServerTable*> store_;
|
||||
};
|
||||
|
|
|
@ -138,29 +138,6 @@ public:
|
|||
result->push_back(value);
|
||||
}
|
||||
|
||||
void DumpTable(std::ofstream& os){
|
||||
os << decay_momentum_rate_first_ << ' ';
|
||||
os << decay_momentum_rate_second_ << ' ';
|
||||
os << stepsize_ << ' ';
|
||||
for (int i = 0; i < storage_[i].size(); ++i)
|
||||
os << storage_[i] << ' ';
|
||||
for (int i = 0; i < smooth_gradient_first_.size(); ++i)
|
||||
os << smooth_gradient_first_[i] << ' ';
|
||||
for (int i = 0; i < smooth_gradient_second_.size(); ++i)
|
||||
os << smooth_gradient_second_[i] << ' ';
|
||||
}
|
||||
void RecoverTable(std::ifstream& in){
|
||||
in >> decay_momentum_rate_first_;
|
||||
in >> decay_momentum_rate_second_;
|
||||
in >> stepsize_;
|
||||
for (int i = 0; i < storage_[i].size(); ++i)
|
||||
in >> storage_[i];
|
||||
for (int i = 0; i < smooth_gradient_first_.size(); ++i)
|
||||
in >> smooth_gradient_first_[i];
|
||||
for (int i = 0; i < smooth_gradient_second_.size(); ++i)
|
||||
in >> smooth_gradient_second_[i];
|
||||
}
|
||||
|
||||
private:
|
||||
int server_id_;
|
||||
float decay_momentum_rate_first_;
|
||||
|
|
|
@ -118,15 +118,6 @@ public:
|
|||
result->push_back(value);
|
||||
}
|
||||
|
||||
void DumpTable(std::ofstream& os){
|
||||
for (int i = 0; i < storage_.size(); ++i)
|
||||
os << storage_[i] << ' ';
|
||||
}
|
||||
void RecoverTable(std::ifstream& in){
|
||||
for (int i = 0; i < storage_.size(); ++i)
|
||||
in >> storage_[i];
|
||||
}
|
||||
|
||||
private:
|
||||
int server_id_;
|
||||
// T* storage_;
|
||||
|
|
|
@ -99,23 +99,6 @@ public:
|
|||
table_[keys.As<Key>(i)] += vals.As<Val>(i);
|
||||
}
|
||||
}
|
||||
|
||||
void DumpTable(std::ofstream& os){
|
||||
os << table_.size() << ' ';
|
||||
for (auto& i : table_){
|
||||
os << i.first << ' ' << i.second << ' ';
|
||||
}
|
||||
}
|
||||
void RecoverTable(std::ifstream& in){
|
||||
int count;
|
||||
Key k;
|
||||
Val v;
|
||||
in >> count;
|
||||
for (int i = 0; i < count; ++i){//may get wrong when Key or Val is char?
|
||||
in >> k >> v;
|
||||
table_[k] = v;
|
||||
}
|
||||
}
|
||||
private:
|
||||
std::unordered_map<Key, Val> table_;
|
||||
};
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
#ifndef MULTIVERSO_MATRIX_TABLE_H_
|
||||
#ifndef MULTIVERSO_MATRIX_TABLE_H_
|
||||
#define MULTIVERSO_MATRIX_TABLE_H_
|
||||
|
||||
#include "multiverso/multiverso.h"
|
||||
|
@ -262,19 +262,6 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
void DumpTable(std::ofstream& os) override{
|
||||
char c = '\t';
|
||||
for (int i = 0; i < storage_.size(); ++i){
|
||||
os << storage_[i] << c;
|
||||
}
|
||||
}
|
||||
|
||||
void RecoverTable(std::ifstream& in) override{
|
||||
for (int i = 0; i < storage_.size(); ++i){
|
||||
in >> storage_[i];
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
int server_id_;
|
||||
int num_col_;
|
||||
|
|
|
@ -133,21 +133,6 @@ public:
|
|||
result->push_back(value);
|
||||
}
|
||||
|
||||
void DumpTable(std::ofstream& os){
|
||||
os << smooth_momentum_ << ' ';
|
||||
for (int i = 0; i < storage_.size(); ++i)
|
||||
os << storage_[i] << ' ';
|
||||
for (int i = 0; i < smooth_gradient_.size(); ++i)
|
||||
os << smooth_gradient_[i] << ' ';
|
||||
}
|
||||
void RecoverTable(std::ifstream& in){
|
||||
in >> smooth_momentum_;
|
||||
for (int i = 0; i < storage_.size(); ++i)
|
||||
in >> storage_[i];
|
||||
for (int i = 0; i < smooth_gradient_.size(); ++i)
|
||||
in >> smooth_gradient_[i];
|
||||
}
|
||||
|
||||
private:
|
||||
int server_id_;
|
||||
float smooth_momentum_;
|
||||
|
|
|
@ -52,8 +52,6 @@ public:
|
|||
virtual void ProcessAdd(const std::vector<Blob>& data) = 0;
|
||||
virtual void ProcessGet(const std::vector<Blob>& data,
|
||||
std::vector<Blob>* result) = 0;
|
||||
virtual void DumpTable(std::ofstream& os) = 0;
|
||||
virtual void RecoverTable(std::ifstream& in) = 0;
|
||||
|
||||
const std::string name() const { return std::string(typeid(this).name());};
|
||||
|
||||
|
|
|
@ -23,11 +23,11 @@ public:
|
|||
static Zoo* Get() { static Zoo zoo; return &zoo; };
|
||||
|
||||
// Start all actors
|
||||
void Start(int* argc, char** argv, int role, bool restart);
|
||||
void Start(int* argc, char** argv, int role);
|
||||
// Stop all actors
|
||||
void Stop(bool finalize_net);
|
||||
|
||||
void Barrier(const int& iter = -1);
|
||||
void Barrier();
|
||||
|
||||
void SendTo(const std::string& name, MessagePtr&);
|
||||
void Receive(MessagePtr& msg);
|
||||
|
@ -58,8 +58,6 @@ public:
|
|||
CHECK(zoo_[name] == nullptr);
|
||||
zoo_[name] = actor;
|
||||
}
|
||||
|
||||
int RestoreTable(const std::string& dump_file_path);
|
||||
private:
|
||||
// private constructor
|
||||
Zoo();
|
||||
|
@ -77,9 +75,6 @@ private:
|
|||
|
||||
int num_workers_;
|
||||
int num_servers_;
|
||||
|
||||
bool restart_;
|
||||
int dump_each_k_;
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
@ -5,15 +5,15 @@
|
|||
|
||||
namespace multiverso {
|
||||
|
||||
void MV_Init(int* argc, char* argv[], int role, bool restart) {
|
||||
Zoo::Get()->Start(argc, argv, role, restart);
|
||||
void MV_Init(int* argc, char* argv[], int role) {
|
||||
Zoo::Get()->Start(argc, argv, role);
|
||||
}
|
||||
|
||||
void MV_ShutDown(bool finalize_net) {
|
||||
Zoo::Get()->Stop(finalize_net);
|
||||
}
|
||||
|
||||
void MV_Barrier(int iter) { Zoo::Get()->Barrier(iter); }
|
||||
void MV_Barrier() { Zoo::Get()->Barrier(); }
|
||||
|
||||
int MV_Rank() { return Zoo::Get()->rank(); }
|
||||
|
||||
|
@ -49,7 +49,4 @@ int MV_NetConnect(int* ranks, char* endpoints[], int size) {
|
|||
return NetInterface::Get()->Connect(ranks, endpoints, size);
|
||||
}
|
||||
|
||||
int MV_RestoreTable(const std::string& dump_file_path){
|
||||
return Zoo::Get()->RestoreTable(dump_file_path);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,35 +35,4 @@ void Server::ProcessAdd(MessagePtr& msg) {
|
|||
SendTo(actor::kCommunicator, reply);
|
||||
}
|
||||
|
||||
void Server::SetDumpFilePath(const std::string& dump_file_path){
|
||||
int id = Zoo::Get()->server_rank();
|
||||
std::string server_id_str = (id == 0 ? "0" : "");
|
||||
while (id > 0){
|
||||
server_id_str = (char)((id % 10) + '0') + server_id_str;
|
||||
id /= 10;
|
||||
}
|
||||
dump_file_path_ = dump_file_path + server_id_str;
|
||||
}
|
||||
|
||||
void Server::DumpTable(const int& epoch){
|
||||
std::ofstream os(dump_file_path_, std::ios::out);
|
||||
char c = '\n';
|
||||
os << epoch << c;
|
||||
for (int i = 0; i < store_.size(); ++i){
|
||||
store_[i]->DumpTable(os);
|
||||
os << c;
|
||||
}
|
||||
os.close();
|
||||
}
|
||||
|
||||
int Server::RestoreTable(const std::string& file_path){
|
||||
std::ifstream in(dump_file_path_, std::ios::in);
|
||||
int iter;
|
||||
in >> iter;
|
||||
for (int i = 0; i < store_.size(); ++i){
|
||||
store_[i]->RecoverTable(in);
|
||||
}
|
||||
in.close();
|
||||
return iter + 1; //the next iteration number
|
||||
}
|
||||
}
|
|
@ -11,17 +11,13 @@
|
|||
|
||||
namespace multiverso {
|
||||
|
||||
Zoo::Zoo() {
|
||||
dump_each_k_ = 5;
|
||||
restart_ = false;
|
||||
}
|
||||
Zoo::Zoo() {}
|
||||
|
||||
Zoo::~Zoo() {}
|
||||
|
||||
void Zoo::Start(int* argc, char** argv, int role, bool restart) {
|
||||
void Zoo::Start(int* argc, char** argv, int role) {
|
||||
Log::Debug("Zoo started\n");
|
||||
CHECK(role >= 0 && role <= 3);
|
||||
restart_ = restart;
|
||||
// Init the network
|
||||
net_util_ = NetInterface::Get();
|
||||
net_util_->Init(argc, argv);
|
||||
|
@ -93,7 +89,7 @@ void Zoo::RegisterNode() {
|
|||
}
|
||||
}
|
||||
|
||||
void Zoo::Barrier(const int& iter) {
|
||||
void Zoo::Barrier() {
|
||||
MessagePtr msg(new Message());
|
||||
msg->set_src(rank());
|
||||
msg->set_dst(0); // rank 0 acts as the controller master.
|
||||
|
@ -106,11 +102,6 @@ void Zoo::Barrier(const int& iter) {
|
|||
mailbox_->Pop(msg);
|
||||
CHECK(msg->type() == MsgType::Control_Reply_Barrier);
|
||||
Log::Debug("rank %d reached barrier\n", rank());
|
||||
|
||||
if (iter >= 0 && iter % dump_each_k_ == 0){
|
||||
static_cast<Server*>(zoo_[actor::kServer])->DumpTable(iter);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
int Zoo::RegisterTable(WorkerTable* worker_table) {
|
||||
|
@ -123,12 +114,4 @@ int Zoo::RegisterTable(ServerTable* server_table) {
|
|||
->RegisterTable(server_table);
|
||||
}
|
||||
|
||||
int Zoo::RestoreTable(const std::string& dump_file_path){
|
||||
auto server = static_cast<Server*>(zoo_[actor::kServer]);
|
||||
server->SetDumpFilePath(dump_file_path);
|
||||
if (restart_){
|
||||
return server->RestoreTable(dump_file_path);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
}
|
Загрузка…
Ссылка в новой задаче