diff --git a/Test/main.cpp b/Test/main.cpp index 693a7d9..1142144 100644 --- a/Test/main.cpp +++ b/Test/main.cpp @@ -276,13 +276,12 @@ void TestCheckPoint(int argc, char* argv[], bool restore){ MatrixWorkerTable* worker_table = static_cast*>((new MatrixTableHelper(num_row, num_col))->CreateTable()); - //MatrixWorkerTable* worker_table = new MatrixWorkerTable(num_row, num_col); - //MatrixServerTable* server_table = new MatrixServerTable(num_row, num_col); - //if restore = true, will restore server data and return the next iter number of last dump file - //else do nothing and return 0 if (worker_table == nullptr) { //no worker in this node } + + //if restore = true, will restore server data and return the next iter number of last dump file + //else do nothing and return 0 // int begin_iter = MV_LoadTable("serverTable_"); MV_Barrier();//won't dump data without parameters @@ -319,7 +318,6 @@ void TestAllreduce(int argc, char* argv[]) { MV_ShutDown(); } - template void TestmatrixPerformance(int argc, char* argv[], std::function(int num_row, int num_col)>CreateWorkerTable, diff --git a/include/multiverso/table/matrix_table.h b/include/multiverso/table/matrix_table.h index b251bd7..591c55f 100644 --- a/include/multiverso/table/matrix_table.h +++ b/include/multiverso/table/matrix_table.h @@ -92,29 +92,6 @@ protected: integer_t num_col_; }; -////new implementation -//template -//class MatrixTableFactory : public TableFactory { -//public: -// /* -// * args[0] : num_row -// * args[1] : num_col -// */ -// MatrixTableFactory(const std::vector&args) { -// CHECK(args.size() == 2); -// num_row_ = *(int*)args[0]; -// num_col_ = *(int*)args[1]; -// } -//protected: -// WorkerTable* CreateWorkerTable() override{ -// return new MatrixWorkerTable(num_row_, num_col_); -// } -// ServerTable* CreateServerTable() override{ -// return new MatrixServerTable(num_row_, num_col_); -// } -// int num_row_; -// int num_col_; -//}; } #endif // MULTIVERSO_MATRIX_TABLE_H_ diff --git a/include/multiverso/table_factory.h b/include/multiverso/table_factory.h index 0133a64..f564efa 100644 --- a/include/multiverso/table_factory.h +++ b/include/multiverso/table_factory.h @@ -1,113 +1,168 @@ -#ifndef MULTIVERSO_INCLUDE_TABLE_FACTORY_H_ -#define MULTIVERSO_INCLUDE_TABLE_FACTORY_H_ +#ifndef MULTIVERSO_TABLE_FACTORY_H_ +#define MULTIVERSO_TABLE_FACTORY_H_ + +#include "multiverso/table_interface.h" -#include #include -#include - -#include "multiverso/util/log.h" -#include "multiverso/table/array_table.h" namespace multiverso { +enum EleType { + kInt = 0, kFloat, kDouble +}; -class WorkerTable; +typedef WorkerTable* (*worker_table_creater_t)(void**table_args); +typedef ServerTable* (*server_table_creater_t)(void**table_args); class TableFactory { public: - virtual WorkerTable* CreateTable(const std::string& args) = 0; - static TableFactory* GetFactory(const std::string& type); + static WorkerTable* CreateTable( + EleType ele_type, + std::string& type, + void**table_args); + static WorkerTable* CreateTable( + EleType ele_type1, + EleType ele_type2, + std::string& type, + void**table_args); + static void RegisterTable( + std::string& type, + worker_table_creater_t wt, + server_table_creater_t st); +private: + static WorkerTable* CreateTable( + std::string& type, + void**table_args); + TableFactory() = default; + static std::unordered_map< + std::string, + std::pair > table_creaters_; }; - -void MV_CreateTable(const std::string& type, const std::string& args, void** out) { - TableFactory* tf = TableFactory::GetFactory(type); - *out = static_cast(tf->CreateTable(args)); -} - -class ArrayTableFactory : public TableFactory { -public: - WorkerTable* CreateTable(const std::string& args) override { - // new ArrayServer() +namespace table_factory { +struct TableRegister { + TableRegister(std::string type, worker_table_creater_t wt, + server_table_creater_t st) { + TableFactory::RegisterTable(type, wt, st); } -} -// TODO(feiga): Refine +}; +} // namespace table_factory -// TODO(feiga): provide better table creator method -// Abstract Factory to create server and worker -//class TableFactory { -// // static TableFactory* GetTableFactory(); -// virtual WorkerTable* CreateWorker() = 0; -// virtual ServerTable* CreateServer() = 0; -// static TableFactory* fatory_; -//}; +} // namespace multiverso -//namespace table { +#endif // MULTIVERSO_TABLE_FACTORY_H_ -//} -//class TableBuilder { -//public: -// TableBuilder& SetArribute(const std::string& name, const std::string& val); -// WorkerTable* WorkerTableBuild(); -// ServerTable* ServerTableBuild(); -//private: -// std::string Get(const std::string& name) const; -// std::unordered_map params_; -//}; - -//class Context { -//public: -// Context& SetArribute(const std::string& name, const std::string& val); -// -// int get_int(const std::string& name) { -// CHECK(params_.find(name) != params_.end()); -// return atoi(params_[name].c_str()); -// } +//#ifndef MULTIVERSO_INCLUDE_TABLE_FACTORY_H_ +//#define MULTIVERSO_INCLUDE_TABLE_FACTORY_H_ // -//private: -// std::string get(const std::string& name) const; -// std::map params_; -//}; +//#include +//#include +//#include +// +//#include "multiverso/util/log.h" +//#include "multiverso/table/array_table.h" +// +//namespace multiverso { // //class WorkerTable; // -//class TableRegistry { +//class TableFactory { //public: -// typedef WorkerTable* (*Creator)(const Context&); -// typedef std::map Registry; -// static TableRegistry* Global(); -// -// static void AddCreator(const std::string& type, Creator creator) { -// Registry& r = registry(); -// r[type] = creator; -// } -// -// static Registry& registry() { -// static Registry instance; -// return instance; -// } -// -// static WorkerTable* CreateTable(const std::string& type, const Context& context) { -// Registry& r = registry(); -// return r[type](context); -// -// } -// -//private: -// TableRegistry() {} +// virtual WorkerTable* CreateTable(const std::string& args) = 0; +// static TableFactory* GetFactory(const std::string& type); //}; // -//class TableRegisterer { -//public: -// TableRegisterer(const std::string& type, -// WorkerTable* (*creator)(const Context&)) { -// TableRegistry::AddCreator(type, creator); -// } -//}; // -//#define REGISTER_TABLE_CREATOR(type, creator) \ -// static TableRegisterer(type, creator) g_creator_##type(#type, creator); -// -} - -#endif // MULTIVERSO_INCLUDE_TABLE_FACTORY_H_ \ No newline at end of file +//void MV_CreateTable(const std::string& type, const std::string& args, void** out) { +// TableFactory* tf = TableFactory::GetFactory(type); +// *out = static_cast(tf->CreateTable(args)); +//} +// +//class ArrayTableFactory : public TableFactory { +//public: +// WorkerTable* CreateTable(const std::string& args) override { +// // new ArrayServer() +// } +//} +//// TODO(feiga): Refine +// +//// TODO(feiga): provide better table creator method +//// Abstract Factory to create server and worker +////class TableFactory { +//// // static TableFactory* GetTableFactory(); +//// virtual WorkerTable* CreateWorker() = 0; +//// virtual ServerTable* CreateServer() = 0; +//// static TableFactory* fatory_; +////}; +// +////namespace table { +// +////} +// +////class TableBuilder { +////public: +//// TableBuilder& SetArribute(const std::string& name, const std::string& val); +//// WorkerTable* WorkerTableBuild(); +//// ServerTable* ServerTableBuild(); +////private: +//// std::string Get(const std::string& name) const; +//// std::unordered_map params_; +////}; +// +////class Context { +////public: +//// Context& SetArribute(const std::string& name, const std::string& val); +//// +//// int get_int(const std::string& name) { +//// CHECK(params_.find(name) != params_.end()); +//// return atoi(params_[name].c_str()); +//// } +//// +////private: +//// std::string get(const std::string& name) const; +//// std::map params_; +////}; +//// +////class WorkerTable; +//// +////class TableRegistry { +////public: +//// typedef WorkerTable* (*Creator)(const Context&); +//// typedef std::map Registry; +//// static TableRegistry* Global(); +//// +//// static void AddCreator(const std::string& type, Creator creator) { +//// Registry& r = registry(); +//// r[type] = creator; +//// } +//// +//// static Registry& registry() { +//// static Registry instance; +//// return instance; +//// } +//// +//// static WorkerTable* CreateTable(const std::string& type, const Context& context) { +//// Registry& r = registry(); +//// return r[type](context); +//// +//// } +//// +////private: +//// TableRegistry() {} +////}; +//// +////class TableRegisterer { +////public: +//// TableRegisterer(const std::string& type, +//// WorkerTable* (*creator)(const Context&)) { +//// TableRegistry::AddCreator(type, creator); +//// } +////}; +//// +////#define REGISTER_TABLE_CREATOR(type, creator) \ +//// static TableRegisterer(type, creator) g_creator_##type(#type, creator); +//// +//} +// +//#endif // MULTIVERSO_INCLUDE_TABLE_FACTORY_H_ \ No newline at end of file diff --git a/include/multiverso/table_interface.h b/include/multiverso/table_interface.h index e576a35..e8e5071 100644 --- a/include/multiverso/table_interface.h +++ b/include/multiverso/table_interface.h @@ -70,23 +70,7 @@ public: virtual void ProcessGet(const std::vector& data, std::vector* result) = 0; }; - -// TODO(feiga): provide better table creator method -// Abstract Factory to create server and worker -// my new implementation -class TableFactory { -public: - template - static WorkerTable* CreateTable(const std::string& table_type, - const std::vector& table_args, - const std::string& dump_file_path = ""); - virtual ~TableFactory() {} -protected: - virtual WorkerTable* CreateWorkerTable() = 0; - virtual ServerTable* CreateServerTable() = 0; -}; - -// older one + class TableHelper { public: TableHelper() {} @@ -98,27 +82,6 @@ protected: virtual ServerTable* CreateServerTable() = 0; }; -// template -// class MatrixTableFactory; -// template function should be defined in the same file with declaration -// template -// WorkerTable* TableFactory::CreateTable(const std::string& table_type, -// const std::vector& table_args, const std::string& dump_file_path) { -// bool worker = (MV_WorkerId() >= 0); -// bool server = (MV_ServerId() >= 0); -// TableFactory* factory; -// if (table_type == "matrix") { -// factory = new MatrixTableFactory(table_args); -// } -// else if (table_type == "array") { -// } -// else CHECK(false); -// -// if (server) factory->CreateServerTable(); -// if (worker) return factory->CreateWorkerTable(); -// return nullptr; -// } - } // namespace multiverso #endif // MULTIVERSO_TABLE_INTERFACE_H_ diff --git a/src/Multiverso.vcxproj b/src/Multiverso.vcxproj index 304cba8..7b6ac00 100644 --- a/src/Multiverso.vcxproj +++ b/src/Multiverso.vcxproj @@ -191,6 +191,7 @@ + @@ -226,6 +227,7 @@ + diff --git a/src/table_factory.cpp b/src/table_factory.cpp new file mode 100644 index 0000000..94342da --- /dev/null +++ b/src/table_factory.cpp @@ -0,0 +1,102 @@ +#include "multiverso/table_factory.h" + +#include "multiverso/table/array_table.h" +#include "multiverso/table/matrix_table.h" + +namespace multiverso { + +std::unordered_map > + TableFactory::table_creaters_; + +std::string ele_type_str(EleType ele_type) { + switch (ele_type) { + case kInt: + return "int_"; + case kFloat: + return "float_"; + case kDouble: + return "kdouble_"; + } + return "unknown"; +} + +WorkerTable* TableFactory::CreateTable( + std::string& type, + void**table_args) { + CHECK(table_creaters_.find(type) != table_creaters_.end()); + + if (MV_ServerId() >= 0) { + table_creaters_[type].second(table_args); + } + if (MV_WorkerId() >= 0) { + return table_creaters_[type].first(table_args); + } + return nullptr; +} + +WorkerTable* TableFactory::CreateTable( + EleType ele_type, + std::string& type, + void**table_args) { + std::string typestr = ele_type_str(ele_type) + type; + return CreateTable(typestr, table_args); +} + +WorkerTable* TableFactory::CreateTable( + EleType ele_type1, + EleType ele_type2, + std::string& type, + void**table_args) { + std::string typestr = ele_type_str(ele_type1) + + ele_type_str(ele_type2) + type; + return CreateTable(typestr, table_args); +} + +void TableFactory::RegisterTable( + std::string& type, + worker_table_creater_t wt, + server_table_creater_t st) { + CHECK(table_creaters_.find(type) == table_creaters_.end()); + + table_creaters_[type] = std::make_pair(wt, st); +} + +#define MV_REGISTER_TABLE(type, worker_table_creater, \ + server_table_creater) \ + namespace table_factory { \ + TableRegister type##_table_register(#type, \ + worker_table_creater, \ + server_table_creater); \ + } + +#define MV_REGISTER_TABLE_WITH_BASIC_TYPE(table_name, \ + worker_table_creater, server_table_creater) \ + MV_REGISTER_TABLE(int_##table_name, worker_table_creater, \ + server_table_creater); \ + MV_REGISTER_TABLE(float_##table_name, worker_table_creater, \ + server_table_creater); \ + MV_REGISTER_TABLE(double_##table_name, worker_table_creater,\ + server_table_creater); + +template +WorkerTable* create_array_worker(void **args) { + return new ArrayWorker(*(size_t*)(*args)); +} +template +ServerTable* create_array_server(void **args) { + return new ArrayServer(*(size_t*)(*args)); +} +MV_REGISTER_TABLE_WITH_BASIC_TYPE(array, create_array_worker, create_array_server); + +template +WorkerTable* create_matrix_worker(void **args) { + return new MatrixWorkerTable(*(integer_t*)(*args), *(integer_t*)(*(args+1))); +} +template +ServerTable* create_matrix_server(void **args) { + return new MatrixServerTable(*(integer_t*)(*args), *(integer_t*)(*(args + 1))); +} +MV_REGISTER_TABLE_WITH_BASIC_TYPE(matrix, create_matrix_worker, create_matrix_server); + +} // namespace multiverso \ No newline at end of file