New table factory
This commit is contained in:
Родитель
a0e2d7c04c
Коммит
c7937dbcb7
|
@ -276,13 +276,12 @@ void TestCheckPoint(int argc, char* argv[], bool restore){
|
|||
|
||||
MatrixWorkerTable<int>* worker_table =
|
||||
static_cast<MatrixWorkerTable<int>*>((new MatrixTableHelper<int>(num_row, num_col))->CreateTable());
|
||||
//MatrixWorkerTable<int>* worker_table = new MatrixWorkerTable<int>(num_row, num_col);
|
||||
//MatrixServerTable<int>* server_table = new MatrixServerTable<int>(num_row, num_col);
|
||||
//if restore = true, will restore server data and return the next iter number of last dump file
|
||||
//else do nothing and return 0
|
||||
if (worker_table == nullptr) {
|
||||
//no worker in this node
|
||||
}
|
||||
|
||||
//if restore = true, will restore server data and return the next iter number of last dump file
|
||||
//else do nothing and return 0
|
||||
// int begin_iter = MV_LoadTable("serverTable_");
|
||||
MV_Barrier();//won't dump data without parameters
|
||||
|
||||
|
@ -319,7 +318,6 @@ void TestAllreduce(int argc, char* argv[]) {
|
|||
MV_ShutDown();
|
||||
}
|
||||
|
||||
|
||||
template<typename WT, typename ST>
|
||||
void TestmatrixPerformance(int argc, char* argv[],
|
||||
std::function<std::shared_ptr<WT>(int num_row, int num_col)>CreateWorkerTable,
|
||||
|
|
|
@ -92,29 +92,6 @@ protected:
|
|||
integer_t num_col_;
|
||||
};
|
||||
|
||||
////new implementation
|
||||
//template<typename T>
|
||||
//class MatrixTableFactory : public TableFactory {
|
||||
//public:
|
||||
// /*
|
||||
// * args[0] : num_row
|
||||
// * args[1] : num_col
|
||||
// */
|
||||
// MatrixTableFactory(const std::vector<void*>&args) {
|
||||
// CHECK(args.size() == 2);
|
||||
// num_row_ = *(int*)args[0];
|
||||
// num_col_ = *(int*)args[1];
|
||||
// }
|
||||
//protected:
|
||||
// WorkerTable* CreateWorkerTable() override{
|
||||
// return new MatrixWorkerTable<T>(num_row_, num_col_);
|
||||
// }
|
||||
// ServerTable* CreateServerTable() override{
|
||||
// return new MatrixServerTable<T>(num_row_, num_col_);
|
||||
// }
|
||||
// int num_row_;
|
||||
// int num_col_;
|
||||
//};
|
||||
}
|
||||
|
||||
#endif // MULTIVERSO_MATRIX_TABLE_H_
|
||||
|
|
|
@ -1,113 +1,168 @@
|
|||
#ifndef MULTIVERSO_INCLUDE_TABLE_FACTORY_H_
|
||||
#define MULTIVERSO_INCLUDE_TABLE_FACTORY_H_
|
||||
#ifndef MULTIVERSO_TABLE_FACTORY_H_
|
||||
#define MULTIVERSO_TABLE_FACTORY_H_
|
||||
|
||||
#include "multiverso/table_interface.h"
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
#include "multiverso/util/log.h"
|
||||
#include "multiverso/table/array_table.h"
|
||||
|
||||
namespace multiverso {
|
||||
enum EleType {
|
||||
kInt = 0, kFloat, kDouble
|
||||
};
|
||||
|
||||
class WorkerTable;
|
||||
typedef WorkerTable* (*worker_table_creater_t)(void**table_args);
|
||||
typedef ServerTable* (*server_table_creater_t)(void**table_args);
|
||||
|
||||
class TableFactory {
|
||||
public:
|
||||
virtual WorkerTable* CreateTable(const std::string& args) = 0;
|
||||
static TableFactory* GetFactory(const std::string& type);
|
||||
static WorkerTable* CreateTable(
|
||||
EleType ele_type,
|
||||
std::string& type,
|
||||
void**table_args);
|
||||
static WorkerTable* CreateTable(
|
||||
EleType ele_type1,
|
||||
EleType ele_type2,
|
||||
std::string& type,
|
||||
void**table_args);
|
||||
static void RegisterTable(
|
||||
std::string& type,
|
||||
worker_table_creater_t wt,
|
||||
server_table_creater_t st);
|
||||
private:
|
||||
static WorkerTable* CreateTable(
|
||||
std::string& type,
|
||||
void**table_args);
|
||||
TableFactory() = default;
|
||||
static std::unordered_map<
|
||||
std::string,
|
||||
std::pair<worker_table_creater_t,
|
||||
server_table_creater_t> > table_creaters_;
|
||||
};
|
||||
|
||||
|
||||
void MV_CreateTable(const std::string& type, const std::string& args, void** out) {
|
||||
TableFactory* tf = TableFactory::GetFactory(type);
|
||||
*out = static_cast<void*>(tf->CreateTable(args));
|
||||
}
|
||||
|
||||
class ArrayTableFactory : public TableFactory {
|
||||
public:
|
||||
WorkerTable* CreateTable(const std::string& args) override {
|
||||
// new ArrayServer()
|
||||
namespace table_factory {
|
||||
struct TableRegister {
|
||||
TableRegister(std::string type, worker_table_creater_t wt,
|
||||
server_table_creater_t st) {
|
||||
TableFactory::RegisterTable(type, wt, st);
|
||||
}
|
||||
}
|
||||
// TODO(feiga): Refine
|
||||
};
|
||||
} // namespace table_factory
|
||||
|
||||
// TODO(feiga): provide better table creator method
|
||||
// Abstract Factory to create server and worker
|
||||
//class TableFactory {
|
||||
// // static TableFactory* GetTableFactory();
|
||||
// virtual WorkerTable* CreateWorker() = 0;
|
||||
// virtual ServerTable* CreateServer() = 0;
|
||||
// static TableFactory* fatory_;
|
||||
//};
|
||||
} // namespace multiverso
|
||||
|
||||
//namespace table {
|
||||
#endif // MULTIVERSO_TABLE_FACTORY_H_
|
||||
|
||||
//}
|
||||
|
||||
//class TableBuilder {
|
||||
//public:
|
||||
// TableBuilder& SetArribute(const std::string& name, const std::string& val);
|
||||
// WorkerTable* WorkerTableBuild();
|
||||
// ServerTable* ServerTableBuild();
|
||||
//private:
|
||||
// std::string Get(const std::string& name) const;
|
||||
// std::unordered_map<std::string, std::string> params_;
|
||||
//};
|
||||
|
||||
//class Context {
|
||||
//public:
|
||||
// Context& SetArribute(const std::string& name, const std::string& val);
|
||||
//
|
||||
// int get_int(const std::string& name) {
|
||||
// CHECK(params_.find(name) != params_.end());
|
||||
// return atoi(params_[name].c_str());
|
||||
// }
|
||||
//#ifndef MULTIVERSO_INCLUDE_TABLE_FACTORY_H_
|
||||
//#define MULTIVERSO_INCLUDE_TABLE_FACTORY_H_
|
||||
//
|
||||
//private:
|
||||
// std::string get(const std::string& name) const;
|
||||
// std::map<std::string, std::string> params_;
|
||||
//};
|
||||
//#include <functional>
|
||||
//#include <string>
|
||||
//#include <map>
|
||||
//
|
||||
//#include "multiverso/util/log.h"
|
||||
//#include "multiverso/table/array_table.h"
|
||||
//
|
||||
//namespace multiverso {
|
||||
//
|
||||
//class WorkerTable;
|
||||
//
|
||||
//class TableRegistry {
|
||||
//class TableFactory {
|
||||
//public:
|
||||
// typedef WorkerTable* (*Creator)(const Context&);
|
||||
// typedef std::map<std::string, Creator> Registry;
|
||||
// static TableRegistry* Global();
|
||||
//
|
||||
// static void AddCreator(const std::string& type, Creator creator) {
|
||||
// Registry& r = registry();
|
||||
// r[type] = creator;
|
||||
// }
|
||||
//
|
||||
// static Registry& registry() {
|
||||
// static Registry instance;
|
||||
// return instance;
|
||||
// }
|
||||
//
|
||||
// static WorkerTable* CreateTable(const std::string& type, const Context& context) {
|
||||
// Registry& r = registry();
|
||||
// return r[type](context);
|
||||
//
|
||||
// }
|
||||
//
|
||||
//private:
|
||||
// TableRegistry() {}
|
||||
// virtual WorkerTable* CreateTable(const std::string& args) = 0;
|
||||
// static TableFactory* GetFactory(const std::string& type);
|
||||
//};
|
||||
//
|
||||
//class TableRegisterer {
|
||||
//public:
|
||||
// TableRegisterer(const std::string& type,
|
||||
// WorkerTable* (*creator)(const Context&)) {
|
||||
// TableRegistry::AddCreator(type, creator);
|
||||
// }
|
||||
//};
|
||||
//
|
||||
//#define REGISTER_TABLE_CREATOR(type, creator) \
|
||||
// static TableRegisterer(type, creator) g_creator_##type(#type, creator);
|
||||
//
|
||||
}
|
||||
|
||||
#endif // MULTIVERSO_INCLUDE_TABLE_FACTORY_H_
|
||||
//void MV_CreateTable(const std::string& type, const std::string& args, void** out) {
|
||||
// TableFactory* tf = TableFactory::GetFactory(type);
|
||||
// *out = static_cast<void*>(tf->CreateTable(args));
|
||||
//}
|
||||
//
|
||||
//class ArrayTableFactory : public TableFactory {
|
||||
//public:
|
||||
// WorkerTable* CreateTable(const std::string& args) override {
|
||||
// // new ArrayServer()
|
||||
// }
|
||||
//}
|
||||
//// TODO(feiga): Refine
|
||||
//
|
||||
//// TODO(feiga): provide better table creator method
|
||||
//// Abstract Factory to create server and worker
|
||||
////class TableFactory {
|
||||
//// // static TableFactory* GetTableFactory();
|
||||
//// virtual WorkerTable* CreateWorker() = 0;
|
||||
//// virtual ServerTable* CreateServer() = 0;
|
||||
//// static TableFactory* fatory_;
|
||||
////};
|
||||
//
|
||||
////namespace table {
|
||||
//
|
||||
////}
|
||||
//
|
||||
////class TableBuilder {
|
||||
////public:
|
||||
//// TableBuilder& SetArribute(const std::string& name, const std::string& val);
|
||||
//// WorkerTable* WorkerTableBuild();
|
||||
//// ServerTable* ServerTableBuild();
|
||||
////private:
|
||||
//// std::string Get(const std::string& name) const;
|
||||
//// std::unordered_map<std::string, std::string> params_;
|
||||
////};
|
||||
//
|
||||
////class Context {
|
||||
////public:
|
||||
//// Context& SetArribute(const std::string& name, const std::string& val);
|
||||
////
|
||||
//// int get_int(const std::string& name) {
|
||||
//// CHECK(params_.find(name) != params_.end());
|
||||
//// return atoi(params_[name].c_str());
|
||||
//// }
|
||||
////
|
||||
////private:
|
||||
//// std::string get(const std::string& name) const;
|
||||
//// std::map<std::string, std::string> params_;
|
||||
////};
|
||||
////
|
||||
////class WorkerTable;
|
||||
////
|
||||
////class TableRegistry {
|
||||
////public:
|
||||
//// typedef WorkerTable* (*Creator)(const Context&);
|
||||
//// typedef std::map<std::string, Creator> Registry;
|
||||
//// static TableRegistry* Global();
|
||||
////
|
||||
//// static void AddCreator(const std::string& type, Creator creator) {
|
||||
//// Registry& r = registry();
|
||||
//// r[type] = creator;
|
||||
//// }
|
||||
////
|
||||
//// static Registry& registry() {
|
||||
//// static Registry instance;
|
||||
//// return instance;
|
||||
//// }
|
||||
////
|
||||
//// static WorkerTable* CreateTable(const std::string& type, const Context& context) {
|
||||
//// Registry& r = registry();
|
||||
//// return r[type](context);
|
||||
////
|
||||
//// }
|
||||
////
|
||||
////private:
|
||||
//// TableRegistry() {}
|
||||
////};
|
||||
////
|
||||
////class TableRegisterer {
|
||||
////public:
|
||||
//// TableRegisterer(const std::string& type,
|
||||
//// WorkerTable* (*creator)(const Context&)) {
|
||||
//// TableRegistry::AddCreator(type, creator);
|
||||
//// }
|
||||
////};
|
||||
////
|
||||
////#define REGISTER_TABLE_CREATOR(type, creator) \
|
||||
//// static TableRegisterer(type, creator) g_creator_##type(#type, creator);
|
||||
////
|
||||
//}
|
||||
//
|
||||
//#endif // MULTIVERSO_INCLUDE_TABLE_FACTORY_H_
|
|
@ -70,23 +70,7 @@ public:
|
|||
virtual void ProcessGet(const std::vector<Blob>& data,
|
||||
std::vector<Blob>* result) = 0;
|
||||
};
|
||||
|
||||
// TODO(feiga): provide better table creator method
|
||||
// Abstract Factory to create server and worker
|
||||
// my new implementation
|
||||
class TableFactory {
|
||||
public:
|
||||
template<typename Key, typename Val = void>
|
||||
static WorkerTable* CreateTable(const std::string& table_type,
|
||||
const std::vector<void*>& table_args,
|
||||
const std::string& dump_file_path = "");
|
||||
virtual ~TableFactory() {}
|
||||
protected:
|
||||
virtual WorkerTable* CreateWorkerTable() = 0;
|
||||
virtual ServerTable* CreateServerTable() = 0;
|
||||
};
|
||||
|
||||
// older one
|
||||
|
||||
class TableHelper {
|
||||
public:
|
||||
TableHelper() {}
|
||||
|
@ -98,27 +82,6 @@ protected:
|
|||
virtual ServerTable* CreateServerTable() = 0;
|
||||
};
|
||||
|
||||
// template<typename T>
|
||||
// class MatrixTableFactory;
|
||||
// template function should be defined in the same file with declaration
|
||||
// template<typename Key, typename Val>
|
||||
// WorkerTable* TableFactory::CreateTable(const std::string& table_type,
|
||||
// const std::vector<void*>& table_args, const std::string& dump_file_path) {
|
||||
// bool worker = (MV_WorkerId() >= 0);
|
||||
// bool server = (MV_ServerId() >= 0);
|
||||
// TableFactory* factory;
|
||||
// if (table_type == "matrix") {
|
||||
// factory = new MatrixTableFactory<Key>(table_args);
|
||||
// }
|
||||
// else if (table_type == "array") {
|
||||
// }
|
||||
// else CHECK(false);
|
||||
//
|
||||
// if (server) factory->CreateServerTable();
|
||||
// if (worker) return factory->CreateWorkerTable();
|
||||
// return nullptr;
|
||||
// }
|
||||
|
||||
} // namespace multiverso
|
||||
|
||||
#endif // MULTIVERSO_TABLE_INTERFACE_H_
|
||||
|
|
|
@ -191,6 +191,7 @@
|
|||
<ClInclude Include="..\include\multiverso\table\kv_table.h" />
|
||||
<ClInclude Include="..\include\multiverso\table\matrix_table.h" />
|
||||
<ClInclude Include="..\include\multiverso\table\sparse_matrix_table.h" />
|
||||
<ClInclude Include="..\include\multiverso\table_factory.h" />
|
||||
<ClInclude Include="..\include\multiverso\table_interface.h" />
|
||||
<ClInclude Include="..\include\multiverso\updater\adagrad_updater.h" />
|
||||
<ClInclude Include="..\include\multiverso\updater\sgd_updater.h" />
|
||||
|
@ -226,6 +227,7 @@
|
|||
<ClCompile Include="table\array_table.cpp" />
|
||||
<ClCompile Include="table\matrix_table.cpp" />
|
||||
<ClCompile Include="table\sparse_matrix_table.cpp" />
|
||||
<ClCompile Include="table_factory.cpp" />
|
||||
<ClCompile Include="updater\updater.cpp" />
|
||||
<ClCompile Include="timer.cpp" />
|
||||
<ClCompile Include="util\log.cpp" />
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
#include "multiverso/table_factory.h"
|
||||
|
||||
#include "multiverso/table/array_table.h"
|
||||
#include "multiverso/table/matrix_table.h"
|
||||
|
||||
namespace multiverso {
|
||||
|
||||
std::unordered_map<std::string,
|
||||
std::pair<worker_table_creater_t, server_table_creater_t> >
|
||||
TableFactory::table_creaters_;
|
||||
|
||||
std::string ele_type_str(EleType ele_type) {
|
||||
switch (ele_type) {
|
||||
case kInt:
|
||||
return "int_";
|
||||
case kFloat:
|
||||
return "float_";
|
||||
case kDouble:
|
||||
return "kdouble_";
|
||||
}
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
WorkerTable* TableFactory::CreateTable(
|
||||
std::string& type,
|
||||
void**table_args) {
|
||||
CHECK(table_creaters_.find(type) != table_creaters_.end());
|
||||
|
||||
if (MV_ServerId() >= 0) {
|
||||
table_creaters_[type].second(table_args);
|
||||
}
|
||||
if (MV_WorkerId() >= 0) {
|
||||
return table_creaters_[type].first(table_args);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
WorkerTable* TableFactory::CreateTable(
|
||||
EleType ele_type,
|
||||
std::string& type,
|
||||
void**table_args) {
|
||||
std::string typestr = ele_type_str(ele_type) + type;
|
||||
return CreateTable(typestr, table_args);
|
||||
}
|
||||
|
||||
WorkerTable* TableFactory::CreateTable(
|
||||
EleType ele_type1,
|
||||
EleType ele_type2,
|
||||
std::string& type,
|
||||
void**table_args) {
|
||||
std::string typestr = ele_type_str(ele_type1)
|
||||
+ ele_type_str(ele_type2) + type;
|
||||
return CreateTable(typestr, table_args);
|
||||
}
|
||||
|
||||
void TableFactory::RegisterTable(
|
||||
std::string& type,
|
||||
worker_table_creater_t wt,
|
||||
server_table_creater_t st) {
|
||||
CHECK(table_creaters_.find(type) == table_creaters_.end());
|
||||
|
||||
table_creaters_[type] = std::make_pair(wt, st);
|
||||
}
|
||||
|
||||
#define MV_REGISTER_TABLE(type, worker_table_creater, \
|
||||
server_table_creater) \
|
||||
namespace table_factory { \
|
||||
TableRegister type##_table_register(#type, \
|
||||
worker_table_creater, \
|
||||
server_table_creater); \
|
||||
}
|
||||
|
||||
#define MV_REGISTER_TABLE_WITH_BASIC_TYPE(table_name, \
|
||||
worker_table_creater, server_table_creater) \
|
||||
MV_REGISTER_TABLE(int_##table_name, worker_table_creater<int>, \
|
||||
server_table_creater<int>); \
|
||||
MV_REGISTER_TABLE(float_##table_name, worker_table_creater<float>, \
|
||||
server_table_creater<float>); \
|
||||
MV_REGISTER_TABLE(double_##table_name, worker_table_creater<double>,\
|
||||
server_table_creater<double>);
|
||||
|
||||
template<typename T>
|
||||
WorkerTable* create_array_worker(void **args) {
|
||||
return new ArrayWorker<T>(*(size_t*)(*args));
|
||||
}
|
||||
template<typename T>
|
||||
ServerTable* create_array_server(void **args) {
|
||||
return new ArrayServer<T>(*(size_t*)(*args));
|
||||
}
|
||||
MV_REGISTER_TABLE_WITH_BASIC_TYPE(array, create_array_worker, create_array_server);
|
||||
|
||||
template<typename T>
|
||||
WorkerTable* create_matrix_worker(void **args) {
|
||||
return new MatrixWorkerTable<T>(*(integer_t*)(*args), *(integer_t*)(*(args+1)));
|
||||
}
|
||||
template<typename T>
|
||||
ServerTable* create_matrix_server(void **args) {
|
||||
return new MatrixServerTable<T>(*(integer_t*)(*args), *(integer_t*)(*(args + 1)));
|
||||
}
|
||||
MV_REGISTER_TABLE_WITH_BASIC_TYPE(matrix, create_matrix_worker, create_matrix_server);
|
||||
|
||||
} // namespace multiverso
|
Загрузка…
Ссылка в новой задаче