This commit is contained in:
liming-vie 2016-05-03 19:33:57 +08:00
Родитель a0e2d7c04c
Коммит c7937dbcb7
6 изменённых файлов: 253 добавлений и 156 удалений

Просмотреть файл

@ -276,13 +276,12 @@ void TestCheckPoint(int argc, char* argv[], bool restore){
MatrixWorkerTable<int>* worker_table =
static_cast<MatrixWorkerTable<int>*>((new MatrixTableHelper<int>(num_row, num_col))->CreateTable());
//MatrixWorkerTable<int>* worker_table = new MatrixWorkerTable<int>(num_row, num_col);
//MatrixServerTable<int>* server_table = new MatrixServerTable<int>(num_row, num_col);
//if restore = true, will restore server data and return the next iter number of last dump file
//else do nothing and return 0
if (worker_table == nullptr) {
//no worker in this node
}
//if restore = true, will restore server data and return the next iter number of last dump file
//else do nothing and return 0
// int begin_iter = MV_LoadTable("serverTable_");
MV_Barrier();//won't dump data without parameters
@ -319,7 +318,6 @@ void TestAllreduce(int argc, char* argv[]) {
MV_ShutDown();
}
template<typename WT, typename ST>
void TestmatrixPerformance(int argc, char* argv[],
std::function<std::shared_ptr<WT>(int num_row, int num_col)>CreateWorkerTable,

Просмотреть файл

@ -92,29 +92,6 @@ protected:
integer_t num_col_;
};
////new implementation
//template<typename T>
//class MatrixTableFactory : public TableFactory {
//public:
// /*
// * args[0] : num_row
// * args[1] : num_col
// */
// MatrixTableFactory(const std::vector<void*>&args) {
// CHECK(args.size() == 2);
// num_row_ = *(int*)args[0];
// num_col_ = *(int*)args[1];
// }
//protected:
// WorkerTable* CreateWorkerTable() override{
// return new MatrixWorkerTable<T>(num_row_, num_col_);
// }
// ServerTable* CreateServerTable() override{
// return new MatrixServerTable<T>(num_row_, num_col_);
// }
// int num_row_;
// int num_col_;
//};
}
#endif // MULTIVERSO_MATRIX_TABLE_H_

Просмотреть файл

@ -1,113 +1,168 @@
#ifndef MULTIVERSO_INCLUDE_TABLE_FACTORY_H_
#define MULTIVERSO_INCLUDE_TABLE_FACTORY_H_
#ifndef MULTIVERSO_TABLE_FACTORY_H_
#define MULTIVERSO_TABLE_FACTORY_H_
#include "multiverso/table_interface.h"
#include <functional>
#include <string>
#include <map>
#include "multiverso/util/log.h"
#include "multiverso/table/array_table.h"
namespace multiverso {
enum EleType {
kInt = 0, kFloat, kDouble
};
class WorkerTable;
typedef WorkerTable* (*worker_table_creater_t)(void**table_args);
typedef ServerTable* (*server_table_creater_t)(void**table_args);
class TableFactory {
public:
virtual WorkerTable* CreateTable(const std::string& args) = 0;
static TableFactory* GetFactory(const std::string& type);
static WorkerTable* CreateTable(
EleType ele_type,
std::string& type,
void**table_args);
static WorkerTable* CreateTable(
EleType ele_type1,
EleType ele_type2,
std::string& type,
void**table_args);
static void RegisterTable(
std::string& type,
worker_table_creater_t wt,
server_table_creater_t st);
private:
static WorkerTable* CreateTable(
std::string& type,
void**table_args);
TableFactory() = default;
static std::unordered_map<
std::string,
std::pair<worker_table_creater_t,
server_table_creater_t> > table_creaters_;
};
void MV_CreateTable(const std::string& type, const std::string& args, void** out) {
TableFactory* tf = TableFactory::GetFactory(type);
*out = static_cast<void*>(tf->CreateTable(args));
}
class ArrayTableFactory : public TableFactory {
public:
WorkerTable* CreateTable(const std::string& args) override {
// new ArrayServer()
namespace table_factory {
struct TableRegister {
TableRegister(std::string type, worker_table_creater_t wt,
server_table_creater_t st) {
TableFactory::RegisterTable(type, wt, st);
}
}
// TODO(feiga): Refine
};
} // namespace table_factory
// TODO(feiga): provide better table creator method
// Abstract Factory to create server and worker
//class TableFactory {
// // static TableFactory* GetTableFactory();
// virtual WorkerTable* CreateWorker() = 0;
// virtual ServerTable* CreateServer() = 0;
// static TableFactory* fatory_;
//};
} // namespace multiverso
//namespace table {
#endif // MULTIVERSO_TABLE_FACTORY_H_
//}
//class TableBuilder {
//public:
// TableBuilder& SetArribute(const std::string& name, const std::string& val);
// WorkerTable* WorkerTableBuild();
// ServerTable* ServerTableBuild();
//private:
// std::string Get(const std::string& name) const;
// std::unordered_map<std::string, std::string> params_;
//};
//class Context {
//public:
// Context& SetArribute(const std::string& name, const std::string& val);
//
// int get_int(const std::string& name) {
// CHECK(params_.find(name) != params_.end());
// return atoi(params_[name].c_str());
// }
//#ifndef MULTIVERSO_INCLUDE_TABLE_FACTORY_H_
//#define MULTIVERSO_INCLUDE_TABLE_FACTORY_H_
//
//private:
// std::string get(const std::string& name) const;
// std::map<std::string, std::string> params_;
//};
//#include <functional>
//#include <string>
//#include <map>
//
//#include "multiverso/util/log.h"
//#include "multiverso/table/array_table.h"
//
//namespace multiverso {
//
//class WorkerTable;
//
//class TableRegistry {
//class TableFactory {
//public:
// typedef WorkerTable* (*Creator)(const Context&);
// typedef std::map<std::string, Creator> Registry;
// static TableRegistry* Global();
//
// static void AddCreator(const std::string& type, Creator creator) {
// Registry& r = registry();
// r[type] = creator;
// }
//
// static Registry& registry() {
// static Registry instance;
// return instance;
// }
//
// static WorkerTable* CreateTable(const std::string& type, const Context& context) {
// Registry& r = registry();
// return r[type](context);
//
// }
//
//private:
// TableRegistry() {}
// virtual WorkerTable* CreateTable(const std::string& args) = 0;
// static TableFactory* GetFactory(const std::string& type);
//};
//
//class TableRegisterer {
//public:
// TableRegisterer(const std::string& type,
// WorkerTable* (*creator)(const Context&)) {
// TableRegistry::AddCreator(type, creator);
// }
//};
//
//#define REGISTER_TABLE_CREATOR(type, creator) \
// static TableRegisterer(type, creator) g_creator_##type(#type, creator);
//
}
#endif // MULTIVERSO_INCLUDE_TABLE_FACTORY_H_
//void MV_CreateTable(const std::string& type, const std::string& args, void** out) {
// TableFactory* tf = TableFactory::GetFactory(type);
// *out = static_cast<void*>(tf->CreateTable(args));
//}
//
//class ArrayTableFactory : public TableFactory {
//public:
// WorkerTable* CreateTable(const std::string& args) override {
// // new ArrayServer()
// }
//}
//// TODO(feiga): Refine
//
//// TODO(feiga): provide better table creator method
//// Abstract Factory to create server and worker
////class TableFactory {
//// // static TableFactory* GetTableFactory();
//// virtual WorkerTable* CreateWorker() = 0;
//// virtual ServerTable* CreateServer() = 0;
//// static TableFactory* fatory_;
////};
//
////namespace table {
//
////}
//
////class TableBuilder {
////public:
//// TableBuilder& SetArribute(const std::string& name, const std::string& val);
//// WorkerTable* WorkerTableBuild();
//// ServerTable* ServerTableBuild();
////private:
//// std::string Get(const std::string& name) const;
//// std::unordered_map<std::string, std::string> params_;
////};
//
////class Context {
////public:
//// Context& SetArribute(const std::string& name, const std::string& val);
////
//// int get_int(const std::string& name) {
//// CHECK(params_.find(name) != params_.end());
//// return atoi(params_[name].c_str());
//// }
////
////private:
//// std::string get(const std::string& name) const;
//// std::map<std::string, std::string> params_;
////};
////
////class WorkerTable;
////
////class TableRegistry {
////public:
//// typedef WorkerTable* (*Creator)(const Context&);
//// typedef std::map<std::string, Creator> Registry;
//// static TableRegistry* Global();
////
//// static void AddCreator(const std::string& type, Creator creator) {
//// Registry& r = registry();
//// r[type] = creator;
//// }
////
//// static Registry& registry() {
//// static Registry instance;
//// return instance;
//// }
////
//// static WorkerTable* CreateTable(const std::string& type, const Context& context) {
//// Registry& r = registry();
//// return r[type](context);
////
//// }
////
////private:
//// TableRegistry() {}
////};
////
////class TableRegisterer {
////public:
//// TableRegisterer(const std::string& type,
//// WorkerTable* (*creator)(const Context&)) {
//// TableRegistry::AddCreator(type, creator);
//// }
////};
////
////#define REGISTER_TABLE_CREATOR(type, creator) \
//// static TableRegisterer(type, creator) g_creator_##type(#type, creator);
////
//}
//
//#endif // MULTIVERSO_INCLUDE_TABLE_FACTORY_H_

Просмотреть файл

@ -70,23 +70,7 @@ public:
virtual void ProcessGet(const std::vector<Blob>& data,
std::vector<Blob>* result) = 0;
};
// TODO(feiga): provide better table creator method
// Abstract Factory to create server and worker
// my new implementation
class TableFactory {
public:
template<typename Key, typename Val = void>
static WorkerTable* CreateTable(const std::string& table_type,
const std::vector<void*>& table_args,
const std::string& dump_file_path = "");
virtual ~TableFactory() {}
protected:
virtual WorkerTable* CreateWorkerTable() = 0;
virtual ServerTable* CreateServerTable() = 0;
};
// older one
class TableHelper {
public:
TableHelper() {}
@ -98,27 +82,6 @@ protected:
virtual ServerTable* CreateServerTable() = 0;
};
// template<typename T>
// class MatrixTableFactory;
// template function should be defined in the same file with declaration
// template<typename Key, typename Val>
// WorkerTable* TableFactory::CreateTable(const std::string& table_type,
// const std::vector<void*>& table_args, const std::string& dump_file_path) {
// bool worker = (MV_WorkerId() >= 0);
// bool server = (MV_ServerId() >= 0);
// TableFactory* factory;
// if (table_type == "matrix") {
// factory = new MatrixTableFactory<Key>(table_args);
// }
// else if (table_type == "array") {
// }
// else CHECK(false);
//
// if (server) factory->CreateServerTable();
// if (worker) return factory->CreateWorkerTable();
// return nullptr;
// }
} // namespace multiverso
#endif // MULTIVERSO_TABLE_INTERFACE_H_

Просмотреть файл

@ -191,6 +191,7 @@
<ClInclude Include="..\include\multiverso\table\kv_table.h" />
<ClInclude Include="..\include\multiverso\table\matrix_table.h" />
<ClInclude Include="..\include\multiverso\table\sparse_matrix_table.h" />
<ClInclude Include="..\include\multiverso\table_factory.h" />
<ClInclude Include="..\include\multiverso\table_interface.h" />
<ClInclude Include="..\include\multiverso\updater\adagrad_updater.h" />
<ClInclude Include="..\include\multiverso\updater\sgd_updater.h" />
@ -226,6 +227,7 @@
<ClCompile Include="table\array_table.cpp" />
<ClCompile Include="table\matrix_table.cpp" />
<ClCompile Include="table\sparse_matrix_table.cpp" />
<ClCompile Include="table_factory.cpp" />
<ClCompile Include="updater\updater.cpp" />
<ClCompile Include="timer.cpp" />
<ClCompile Include="util\log.cpp" />

102
src/table_factory.cpp Normal file
Просмотреть файл

@ -0,0 +1,102 @@
#include "multiverso/table_factory.h"
#include "multiverso/table/array_table.h"
#include "multiverso/table/matrix_table.h"
namespace multiverso {
std::unordered_map<std::string,
std::pair<worker_table_creater_t, server_table_creater_t> >
TableFactory::table_creaters_;
std::string ele_type_str(EleType ele_type) {
switch (ele_type) {
case kInt:
return "int_";
case kFloat:
return "float_";
case kDouble:
return "kdouble_";
}
return "unknown";
}
WorkerTable* TableFactory::CreateTable(
std::string& type,
void**table_args) {
CHECK(table_creaters_.find(type) != table_creaters_.end());
if (MV_ServerId() >= 0) {
table_creaters_[type].second(table_args);
}
if (MV_WorkerId() >= 0) {
return table_creaters_[type].first(table_args);
}
return nullptr;
}
WorkerTable* TableFactory::CreateTable(
EleType ele_type,
std::string& type,
void**table_args) {
std::string typestr = ele_type_str(ele_type) + type;
return CreateTable(typestr, table_args);
}
WorkerTable* TableFactory::CreateTable(
EleType ele_type1,
EleType ele_type2,
std::string& type,
void**table_args) {
std::string typestr = ele_type_str(ele_type1)
+ ele_type_str(ele_type2) + type;
return CreateTable(typestr, table_args);
}
void TableFactory::RegisterTable(
std::string& type,
worker_table_creater_t wt,
server_table_creater_t st) {
CHECK(table_creaters_.find(type) == table_creaters_.end());
table_creaters_[type] = std::make_pair(wt, st);
}
#define MV_REGISTER_TABLE(type, worker_table_creater, \
server_table_creater) \
namespace table_factory { \
TableRegister type##_table_register(#type, \
worker_table_creater, \
server_table_creater); \
}
#define MV_REGISTER_TABLE_WITH_BASIC_TYPE(table_name, \
worker_table_creater, server_table_creater) \
MV_REGISTER_TABLE(int_##table_name, worker_table_creater<int>, \
server_table_creater<int>); \
MV_REGISTER_TABLE(float_##table_name, worker_table_creater<float>, \
server_table_creater<float>); \
MV_REGISTER_TABLE(double_##table_name, worker_table_creater<double>,\
server_table_creater<double>);
template<typename T>
WorkerTable* create_array_worker(void **args) {
return new ArrayWorker<T>(*(size_t*)(*args));
}
template<typename T>
ServerTable* create_array_server(void **args) {
return new ArrayServer<T>(*(size_t*)(*args));
}
MV_REGISTER_TABLE_WITH_BASIC_TYPE(array, create_array_worker, create_array_server);
template<typename T>
WorkerTable* create_matrix_worker(void **args) {
return new MatrixWorkerTable<T>(*(integer_t*)(*args), *(integer_t*)(*(args+1)));
}
template<typename T>
ServerTable* create_matrix_server(void **args) {
return new MatrixServerTable<T>(*(integer_t*)(*args), *(integer_t*)(*(args + 1)));
}
MV_REGISTER_TABLE_WITH_BASIC_TYPE(matrix, create_matrix_worker, create_matrix_server);
} // namespace multiverso