Revert "reverting commit of check-point for testing"

This reverts commit a6bd4606321afed88621e6c6e0e3d92dd13da3c1.
This commit is contained in:
Fei Gao 2016-03-10 15:17:36 -08:00
Родитель 740f93a8db
Коммит 97f3953bc0
13 изменённых файлов: 19 добавлений и 207 удалений

Просмотреть файл

@ -366,10 +366,9 @@ void TestMatrix(int argc, char* argv[]){
//test data_vec
std::vector<int*> data_rows = { &data[0], &data[num_col], &data[5 * num_col], &data[10*num_col] };
std::vector<int*> delta_rows = { &delta[0], &delta[num_col], &delta[5 * num_col], &delta[10 * num_col] };
worker_table->Add(v, delta_rows, num_col);
worker_table->Get(v, data_rows, num_col);
MV_Barrier();
worker_table->Add(v, delta_rows, num_col);
worker_table->Get(v, data_rows, num_col);
MV_Barrier();
printf("----------------------------\n");
for (int i = 0; i < num_row; ++i){
@ -383,62 +382,21 @@ void TestMatrix(int argc, char* argv[]){
MV_ShutDown();
}
void TestCheckPoint(int argc, char* argv[], bool restore){
Log::Info("Test CheckPoint\n");
MV_Init(&argc, argv, All, restore);
int num_row = 11, num_col = 10;
int size = num_row * num_col;
MatrixWorkerTable<int>* worker_table = new MatrixWorkerTable<int>(num_row, num_col);
MatrixServerTable<int>* server_table = new MatrixServerTable<int>(num_row, num_col);
//if restore = true, will restore server data and return the next iter number of last dump file
//else do nothing and return 0
int begin_iter = MV_RestoreTable("//5FTGDB2/tableData/serverTable_");
MV_Barrier();//won't dump data without parameters
std::vector<int> delta(size);
for (int i = 0; i < size; ++i)
delta[i] = i;
int * data = new int[size];
Log::Debug("rank %d start from iteration %d\n", MV_Rank(), begin_iter);
for (int i = begin_iter; i < 50; ++i){
worker_table->Add(delta.data(), size);
MV_Barrier(i); //dump table data with iteration i each k iterations
}
worker_table->Get(data, size);
printf("----------------------------\n");
for (int i = 0; i < num_row; ++i){
printf("rank %d, row %d: ", MV_Rank(), i);
for (int j = 0; j < num_col; ++j)
printf("%d ", data[i * num_col + j]);
printf("\n");
}
MV_ShutDown();
}
// Placeholder for a communication test; intentionally performs no work.
void TestComm(int argc, char* argv[]) {
  (void)argc;
  (void)argv;
}
int main(int argc, char* argv[]) {
Log::ResetLogLevel(LogLevel::Debug);
if (argc == 2) {
if (argc == 2) {
if (strcmp(argv[1], "kv") == 0) TestKV(argc, argv);
else if (strcmp(argv[1], "array") == 0) TestArray(argc, argv);
else if (strcmp(argv[1], "net") == 0) TestNet(argc, argv);
else if (strcmp(argv[1], "ip") == 0) TestIP();
else if (strcmp(argv[1], "momentum") == 0) TestMomentum(argc, argv);
else if (strcmp(argv[1], "threads") == 0) TestMultipleThread(argc, argv);
else if (strcmp(argv[1], "matrix") == 0) TestMatrix(argc, argv);
else if (strcmp(argv[1], "nonet") == 0) TestNoNet(argc, argv);
else if (strcmp(argv[1], "checkpoint") == 0) TestCheckPoint(argc, argv, false);
else if (strcmp(argv[1], "restore") == 0) TestCheckPoint(argc, argv, true);
else if (strcmp(argv[1], "momentum") == 0) TestMomentum(argc, argv);
else if (strcmp(argv[1], "threads") == 0) TestMultipleThread(argc, argv);
else if (strcmp(argv[1], "matrix") == 0) TestMatrix(argc, argv);
else if (strcmp(argv[1], "nonet") == 0) TestNoNet(argc, argv);
else CHECK(false);
}
// argc == 4 is for zeromq test, with two extra arguments: machinefile, port

Просмотреть файл

@ -1,8 +1,6 @@
#ifndef MULTIVERSO_INCLUDE_MULTIVERSO_H_
#define MULTIVERSO_INCLUDE_MULTIVERSO_H_
#include <string>
namespace multiverso {
enum Role {
@ -14,10 +12,9 @@ enum Role {
void MV_Init(int* argc = nullptr,
char* argv[] = nullptr,
int role = All,
bool restart = false);
int role = All);
void MV_Barrier(int iter = -1);
void MV_Barrier();
void MV_ShutDown(bool finalize_mpi = true);
@ -33,8 +30,6 @@ int MV_ServerId();
int MV_WorkerIdToRank(int worker_id);
int MV_ServerIdToRank(int server_id);
int MV_RestoreTable(const std::string& dump_file_path);
// --- Net API -------------------------------------------------------------- //
// NOTE(feiga): these APIs are only used in specific situations.
// Init Multiverso Net with the provided endpoint. Multiverso Net will bind

Просмотреть файл

@ -12,15 +12,9 @@ class Server : public Actor {
public:
Server();
int RegisterTable(ServerTable* table);
//dump server data to file
void DumpTable(const int& epoch);
//restore data from file and return next iteration number
int RestoreTable(const std::string& file_path);
void SetDumpFilePath(const std::string& dump_file_path);
private:
void ProcessGet(MessagePtr& msg);
void ProcessAdd(MessagePtr& msg);
std::string dump_file_path_;
// contains the parameter data structure and related handle method
std::vector<ServerTable*> store_;
};

Просмотреть файл

@ -138,29 +138,6 @@ public:
result->push_back(value);
}
// Writes this table's state to `os` as space-separated text, in this order:
// both momentum decay rates, the step size, every element of storage_, then
// every element of the first and second smoothed-gradient buffers.
// RecoverTable reads the values back in the same order.
void DumpTable(std::ofstream& os){
  os << decay_momentum_rate_first_ << ' ';
  os << decay_momentum_rate_second_ << ' ';
  os << stepsize_ << ' ';
  // The bound is the container size. It used to be storage_[i].size(), which
  // queried an element rather than the container being iterated.
  for (int i = 0; i < storage_.size(); ++i)
    os << storage_[i] << ' ';
  for (int i = 0; i < smooth_gradient_first_.size(); ++i)
    os << smooth_gradient_first_[i] << ' ';
  for (int i = 0; i < smooth_gradient_second_.size(); ++i)
    os << smooth_gradient_second_[i] << ' ';
}
// Reads this table's state from `in`, in the order DumpTable writes it:
// both momentum decay rates, the step size, every element of storage_, then
// every element of the first and second smoothed-gradient buffers.
// The containers are not resized; only their existing elements are filled.
void RecoverTable(std::ifstream& in){
  in >> decay_momentum_rate_first_;
  in >> decay_momentum_rate_second_;
  in >> stepsize_;
  // The bound is the container size. It used to be storage_[i].size(), which
  // queried an element rather than the container being iterated.
  for (int i = 0; i < storage_.size(); ++i)
    in >> storage_[i];
  for (int i = 0; i < smooth_gradient_first_.size(); ++i)
    in >> smooth_gradient_first_[i];
  for (int i = 0; i < smooth_gradient_second_.size(); ++i)
    in >> smooth_gradient_second_[i];
}
private:
int server_id_;
float decay_momentum_rate_first_;

Просмотреть файл

@ -118,15 +118,6 @@ public:
result->push_back(value);
}
// Streams every element of storage_ to `os`, each followed by one space.
void DumpTable(std::ofstream& os){
  const int element_count = static_cast<int>(storage_.size());
  for (int idx = 0; idx != element_count; ++idx) {
    os << storage_[idx] << ' ';
  }
}
// Fills the existing elements of storage_ from `in`, in index order.
// The container is not resized.
void RecoverTable(std::ifstream& in){
  const int element_count = static_cast<int>(storage_.size());
  for (int idx = 0; idx != element_count; ++idx) {
    in >> storage_[idx];
  }
}
private:
int server_id_;
// T* storage_;

Просмотреть файл

@ -99,23 +99,6 @@ public:
table_[keys.As<Key>(i)] += vals.As<Val>(i);
}
}
// Writes the entry count followed by every key/value pair of table_ to `os`,
// with a single space after each token.
void DumpTable(std::ofstream& os){
  os << table_.size() << ' ';
  for (auto& entry : table_) {
    os << entry.first << ' ';
    os << entry.second << ' ';
  }
}
// Reads an entry count followed by that many key/value pairs from `in` and
// stores each pair into table_ (existing keys are overwritten).
// Reading stops early if the stream fails, so a missing or truncated dump
// never inserts entries built from unread values.
void RecoverTable(std::ifstream& in){
  // Initialized to zero: if the stream is already in a failed state the
  // extraction below leaves `count` untouched, and nothing is restored.
  int count = 0;
  Key k;
  Val v;
  in >> count;
  for (int i = 0; i < count; ++i){
    // NOTE(review): formatted extraction skips whitespace, so char-typed
    // Key/Val may not round-trip through DumpTable -- confirm.
    if (!(in >> k >> v)) break;
    table_[k] = v;
  }
}
private:
std::unordered_map<Key, Val> table_;
};

Просмотреть файл

@ -1,4 +1,4 @@
#ifndef MULTIVERSO_MATRIX_TABLE_H_
#ifndef MULTIVERSO_MATRIX_TABLE_H_
#define MULTIVERSO_MATRIX_TABLE_H_
#include "multiverso/multiverso.h"
@ -262,19 +262,6 @@ public:
}
}
// Streams every element of storage_ to `os`, each followed by a tab.
void DumpTable(std::ofstream& os) override{
  const char kSeparator = '\t';
  const int element_count = static_cast<int>(storage_.size());
  int idx = 0;
  while (idx != element_count) {
    os << storage_[idx] << kSeparator;
    ++idx;
  }
}
// Fills the existing elements of storage_ from `in`, in index order.
// The container is not resized.
void RecoverTable(std::ifstream& in) override{
  const int element_count = static_cast<int>(storage_.size());
  int idx = 0;
  while (idx != element_count) {
    in >> storage_[idx];
    ++idx;
  }
}
private:
int server_id_;
int num_col_;

Просмотреть файл

@ -133,21 +133,6 @@ public:
result->push_back(value);
}
void DumpTable(std::ofstream& os){
os << smooth_momentum_ << ' ';
for (int i = 0; i < storage_.size(); ++i)
os << storage_[i] << ' ';
for (int i = 0; i < smooth_gradient_.size(); ++i)
os << smooth_gradient_[i] << ' ';
}
void RecoverTable(std::ifstream& in){
in >> smooth_momentum_;
for (int i = 0; i < storage_.size(); ++i)
in >> storage_[i];
for (int i = 0; i < smooth_gradient_.size(); ++i)
in >> smooth_gradient_[i];
}
private:
int server_id_;
float smooth_momentum_;

Просмотреть файл

@ -52,8 +52,6 @@ public:
virtual void ProcessAdd(const std::vector<Blob>& data) = 0;
virtual void ProcessGet(const std::vector<Blob>& data,
std::vector<Blob>* result) = 0;
virtual void DumpTable(std::ofstream& os) = 0;
virtual void RecoverTable(std::ifstream& in) = 0;
// Returns the implementation-defined type name that typeid reports for `this`.
// NOTE(review): typeid(this) describes the static pointer type, not the
// dynamic type of the object -- confirm whether typeid(*this) was intended.
const std::string name() const {
  const char* raw_name = typeid(this).name();
  return std::string(raw_name);
}

Просмотреть файл

@ -23,11 +23,11 @@ public:
// Returns a pointer to the single Zoo object, held as a function-local static.
static Zoo* Get() {
  static Zoo instance;
  return &instance;
}
// Start all actors
void Start(int* argc, char** argv, int role, bool restart);
void Start(int* argc, char** argv, int role);
// Stop all actors
void Stop(bool finalize_net);
void Barrier(const int& iter = -1);
void Barrier();
void SendTo(const std::string& name, MessagePtr&);
void Receive(MessagePtr& msg);
@ -58,8 +58,6 @@ public:
CHECK(zoo_[name] == nullptr);
zoo_[name] = actor;
}
int RestoreTable(const std::string& dump_file_path);
private:
// private constructor
Zoo();
@ -77,9 +75,6 @@ private:
int num_workers_;
int num_servers_;
bool restart_;
int dump_each_k_;
};
}

Просмотреть файл

@ -5,15 +5,15 @@
namespace multiverso {
void MV_Init(int* argc, char* argv[], int role, bool restart) {
Zoo::Get()->Start(argc, argv, role, restart);
void MV_Init(int* argc, char* argv[], int role) {
Zoo::Get()->Start(argc, argv, role);
}
// Stops the Zoo; `finalize_net` is forwarded unchanged to Zoo::Stop.
void MV_ShutDown(bool finalize_net) {
  auto* zoo = Zoo::Get();
  zoo->Stop(finalize_net);
}
void MV_Barrier(int iter) { Zoo::Get()->Barrier(iter); }
void MV_Barrier() { Zoo::Get()->Barrier(); }
// Returns the rank reported by the Zoo singleton.
int MV_Rank() {
  auto* zoo = Zoo::Get();
  return zoo->rank();
}
@ -49,7 +49,4 @@ int MV_NetConnect(int* ranks, char* endpoints[], int size) {
return NetInterface::Get()->Connect(ranks, endpoints, size);
}
// Forwards to Zoo::RestoreTable with the given dump-file path and returns
// its result.
int MV_RestoreTable(const std::string& dump_file_path){
  auto* zoo = Zoo::Get();
  return zoo->RestoreTable(dump_file_path);
}
}

Просмотреть файл

@ -35,35 +35,4 @@ void Server::ProcessAdd(MessagePtr& msg) {
SendTo(actor::kCommunicator, reply);
}
// Stores the dump file path for this server: `dump_file_path` with this
// server's rank appended in decimal.
void Server::SetDumpFilePath(const std::string& dump_file_path){
  const int rank = Zoo::Get()->server_rank();
  // A negative rank contributes no suffix at all.
  const std::string suffix = (rank >= 0) ? std::to_string(rank) : std::string();
  dump_file_path_ = dump_file_path + suffix;
}
// Writes a checkpoint to dump_file_path_: the epoch on the first line, then
// one line per registered table, each produced by that table's DumpTable.
void Server::DumpTable(const int& epoch){
  const char kLineBreak = '\n';
  std::ofstream out(dump_file_path_, std::ios::out);
  out << epoch << kLineBreak;
  for (auto* table : store_) {
    table->DumpTable(out);
    out << kLineBreak;
  }
  out.close();
}
// Restores every registered table from the checkpoint at dump_file_path_.
//
// file_path: not used; the file actually read is dump_file_path_, which is
//            set by SetDumpFilePath.
// Returns the iteration after the one recorded in the checkpoint, or 0 when
// the file cannot be opened or does not start with an iteration number
// (0 is also what Zoo::RestoreTable returns when no restore is requested).
int Server::RestoreTable(const std::string& file_path){
  (void)file_path;
  std::ifstream in(dump_file_path_, std::ios::in);
  // Initialized and checked: `iter` used to be returned (+1) without ever
  // being written when the file was missing or unreadable.
  int iter = 0;
  if (!in.is_open() || !(in >> iter)) {
    return 0;
  }
  for (int i = 0; i < store_.size(); ++i){
    store_[i]->RecoverTable(in);
  }
  in.close();
  return iter + 1; // the next iteration number
}
}

Просмотреть файл

@ -11,17 +11,13 @@
namespace multiverso {
// Default state: no restart requested, dump interval of 5 iterations.
Zoo::Zoo()
    : restart_(false),
      dump_each_k_(5) {
}
Zoo::Zoo() {}
Zoo::~Zoo() {}
void Zoo::Start(int* argc, char** argv, int role, bool restart) {
void Zoo::Start(int* argc, char** argv, int role) {
Log::Debug("Zoo started\n");
CHECK(role >= 0 && role <= 3);
restart_ = restart;
// Init the network
net_util_ = NetInterface::Get();
net_util_->Init(argc, argv);
@ -93,7 +89,7 @@ void Zoo::RegisterNode() {
}
}
void Zoo::Barrier(const int& iter) {
void Zoo::Barrier() {
MessagePtr msg(new Message());
msg->set_src(rank());
msg->set_dst(0); // rank 0 acts as the controller master.
@ -106,11 +102,6 @@ void Zoo::Barrier(const int& iter) {
mailbox_->Pop(msg);
CHECK(msg->type() == MsgType::Control_Reply_Barrier);
Log::Debug("rank %d reached barrier\n", rank());
if (iter >= 0 && iter % dump_each_k_ == 0){
static_cast<Server*>(zoo_[actor::kServer])->DumpTable(iter);
}
}
int Zoo::RegisterTable(WorkerTable* worker_table) {
@ -123,12 +114,4 @@ int Zoo::RegisterTable(ServerTable* server_table) {
->RegisterTable(server_table);
}
// Configures the server actor's dump path and, if this Zoo was started with
// restart enabled, restores its tables from that dump.
//
// dump_file_path: path prefix handed to Server::SetDumpFilePath.
// Returns the value of Server::RestoreTable when restart_ is set; otherwise 0.
// Also returns 0 when no server actor is registered.
int Zoo::RestoreTable(const std::string& dump_file_path){
  auto server = static_cast<Server*>(zoo_[actor::kServer]);
  // zoo_ entries are null until an actor is registered under that name, so
  // guard against dereferencing a missing server actor.
  if (server == nullptr) {
    return 0;
  }
  server->SetDumpFilePath(dump_file_path);
  if (restart_){
    return server->RestoreTable(dump_file_path);
  }
  return 0;
}
}