Use template trait for table creater
This commit is contained in:
Родитель
85c9eaba4b
Коммит
d71e04451f
|
@ -21,6 +21,7 @@
|
|||
#include <multiverso/table/matrix_table.h>
|
||||
#include <multiverso/table/sparse_matrix_table.h>
|
||||
#include <multiverso/updater/updater.h>
|
||||
#include <multiverso/table_factory.h>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
@ -86,8 +87,10 @@ void TestArray(int argc, char* argv[]) {
|
|||
|
||||
MV_Init(&argc, argv);
|
||||
|
||||
ArrayWorker<float>* shared_array = new ArrayWorker<float>(50000000);
|
||||
ArrayServer<float>* server_array = new ArrayServer<float>(50000000);
|
||||
ArrayTableInitOption option{ 50000000 };
|
||||
ArrayWorker<float>* shared_array = TableFactory::CreateTable<float>(option);
|
||||
//ArrayWorker<float>* shared_array = new ArrayWorker<float>(50000000);
|
||||
// ArrayServer<float>* server_array = new ArrayServer<float>(50000000);
|
||||
|
||||
MV_Barrier();
|
||||
Log::Info("Create tables OK\n");
|
||||
|
@ -274,11 +277,8 @@ void TestCheckPoint(int argc, char* argv[], bool restore){
|
|||
int num_row = 11, num_col = 10;
|
||||
int size = num_row * num_col;
|
||||
|
||||
MatrixWorkerTable<int>* worker_table =
|
||||
static_cast<MatrixWorkerTable<int>*>((new MatrixTableHelper<int>(num_row, num_col))->CreateTable());
|
||||
if (worker_table == nullptr) {
|
||||
//no worker in this node
|
||||
}
|
||||
MatrixWorkerTable<int>* worker_table = new MatrixWorkerTable<int>(num_row, num_col);
|
||||
MatrixServerTable<int>* server_table = new MatrixServerTable<int>(num_row, num_col);
|
||||
|
||||
//if restore = true, will restore server data and return the next iter number of last dump file
|
||||
//else do nothing and return 0
|
||||
|
|
|
@ -69,6 +69,20 @@ protected:
|
|||
}
|
||||
size_t size_;
|
||||
};
|
||||
|
||||
struct ArrayTableInitOption {
|
||||
size_t size;
|
||||
};
|
||||
|
||||
namespace trait {
|
||||
template<typename EleType>
|
||||
struct OptionTrait<EleType, ArrayTableInitOption> {
|
||||
static std::string type;
|
||||
typedef ArrayWorker<EleType> worker_table_type;
|
||||
typedef ArrayServer<EleType> server_table_type;
|
||||
};
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif // MULTIVERSO_ARRAY_TABLE_H_
|
||||
|
|
|
@ -2,36 +2,44 @@
|
|||
#define MULTIVERSO_TABLE_FACTORY_H_
|
||||
|
||||
#include "multiverso/table_interface.h"
|
||||
#include "multiverso.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace multiverso {
|
||||
enum EleType {
|
||||
kInt = 0, kFloat, kDouble
|
||||
};
|
||||
|
||||
typedef WorkerTable* (*worker_table_creater_t)(void**table_args);
|
||||
typedef ServerTable* (*server_table_creater_t)(void**table_args);
|
||||
typedef WorkerTable* (*worker_table_creater_t)(void*table_args);
|
||||
typedef ServerTable* (*server_table_creater_t)(void*table_args);
|
||||
|
||||
class TableFactory {
|
||||
public:
|
||||
static WorkerTable* CreateTable(
|
||||
EleType ele_type,
|
||||
std::string& type,
|
||||
void**table_args);
|
||||
static WorkerTable* CreateTable(
|
||||
EleType ele_type1,
|
||||
EleType ele_type2,
|
||||
std::string& type,
|
||||
void**table_args);
|
||||
template <typename EleType, typename OptionType>
|
||||
static typename trait::OptionTrait<EleType, OptionType>::worker_table_type*
|
||||
CreateTable(const OptionType& option) {
|
||||
std::string typestr = typeid(EleType).name() + '_' +
|
||||
trait::OptionTrait<EleType, OptionType>::type;
|
||||
return InnerCreateTable<EleType, OptionType>(typestr, option);
|
||||
}
|
||||
static void RegisterTable(
|
||||
std::string& type,
|
||||
worker_table_creater_t wt,
|
||||
server_table_creater_t st);
|
||||
private:
|
||||
static WorkerTable* CreateTable(
|
||||
std::string& type,
|
||||
void**table_args);
|
||||
template <typename EleType, typename OptionType>
|
||||
static typename trait::OptionTrait<EleType, OptionType>::worker_table_type*
|
||||
InnerCreateTable(const std::string& type,
|
||||
const OptionType& table_args) {
|
||||
CHECK(table_creaters_.find(type) != table_creaters_.end());
|
||||
|
||||
if (MV_ServerId() >= 0) {
|
||||
table_creaters_[type].second((void*)&table_args);
|
||||
}
|
||||
if (MV_WorkerId() >= 0) {
|
||||
return reinterpret_cast<trait::OptionTrait<EleType, OptionType>::worker_table_type*>
|
||||
(table_creaters_[type].first((void*)&table_args));
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
TableFactory() = default;
|
||||
static std::unordered_map<
|
||||
std::string,
|
||||
|
@ -50,119 +58,4 @@ struct TableRegister {
|
|||
|
||||
} // namespace multiverso
|
||||
|
||||
#endif // MULTIVERSO_TABLE_FACTORY_H_
|
||||
|
||||
|
||||
//#ifndef MULTIVERSO_INCLUDE_TABLE_FACTORY_H_
|
||||
//#define MULTIVERSO_INCLUDE_TABLE_FACTORY_H_
|
||||
//
|
||||
//#include <functional>
|
||||
//#include <string>
|
||||
//#include <map>
|
||||
//
|
||||
//#include "multiverso/util/log.h"
|
||||
//#include "multiverso/table/array_table.h"
|
||||
//
|
||||
//namespace multiverso {
|
||||
//
|
||||
//class WorkerTable;
|
||||
//
|
||||
//class TableFactory {
|
||||
//public:
|
||||
// virtual WorkerTable* CreateTable(const std::string& args) = 0;
|
||||
// static TableFactory* GetFactory(const std::string& type);
|
||||
//};
|
||||
//
|
||||
//
|
||||
//void MV_CreateTable(const std::string& type, const std::string& args, void** out) {
|
||||
// TableFactory* tf = TableFactory::GetFactory(type);
|
||||
// *out = static_cast<void*>(tf->CreateTable(args));
|
||||
//}
|
||||
//
|
||||
//class ArrayTableFactory : public TableFactory {
|
||||
//public:
|
||||
// WorkerTable* CreateTable(const std::string& args) override {
|
||||
// // new ArrayServer()
|
||||
// }
|
||||
//}
|
||||
//// TODO(feiga): Refine
|
||||
//
|
||||
//// TODO(feiga): provide better table creator method
|
||||
//// Abstract Factory to create server and worker
|
||||
////class TableFactory {
|
||||
//// // static TableFactory* GetTableFactory();
|
||||
//// virtual WorkerTable* CreateWorker() = 0;
|
||||
//// virtual ServerTable* CreateServer() = 0;
|
||||
//// static TableFactory* fatory_;
|
||||
////};
|
||||
//
|
||||
////namespace table {
|
||||
//
|
||||
////}
|
||||
//
|
||||
////class TableBuilder {
|
||||
////public:
|
||||
//// TableBuilder& SetArribute(const std::string& name, const std::string& val);
|
||||
//// WorkerTable* WorkerTableBuild();
|
||||
//// ServerTable* ServerTableBuild();
|
||||
////private:
|
||||
//// std::string Get(const std::string& name) const;
|
||||
//// std::unordered_map<std::string, std::string> params_;
|
||||
////};
|
||||
//
|
||||
////class Context {
|
||||
////public:
|
||||
//// Context& SetArribute(const std::string& name, const std::string& val);
|
||||
////
|
||||
//// int get_int(const std::string& name) {
|
||||
//// CHECK(params_.find(name) != params_.end());
|
||||
//// return atoi(params_[name].c_str());
|
||||
//// }
|
||||
////
|
||||
////private:
|
||||
//// std::string get(const std::string& name) const;
|
||||
//// std::map<std::string, std::string> params_;
|
||||
////};
|
||||
////
|
||||
////class WorkerTable;
|
||||
////
|
||||
////class TableRegistry {
|
||||
////public:
|
||||
//// typedef WorkerTable* (*Creator)(const Context&);
|
||||
//// typedef std::map<std::string, Creator> Registry;
|
||||
//// static TableRegistry* Global();
|
||||
////
|
||||
//// static void AddCreator(const std::string& type, Creator creator) {
|
||||
//// Registry& r = registry();
|
||||
//// r[type] = creator;
|
||||
//// }
|
||||
////
|
||||
//// static Registry& registry() {
|
||||
//// static Registry instance;
|
||||
//// return instance;
|
||||
//// }
|
||||
////
|
||||
//// static WorkerTable* CreateTable(const std::string& type, const Context& context) {
|
||||
//// Registry& r = registry();
|
||||
//// return r[type](context);
|
||||
////
|
||||
//// }
|
||||
////
|
||||
////private:
|
||||
//// TableRegistry() {}
|
||||
////};
|
||||
////
|
||||
////class TableRegisterer {
|
||||
////public:
|
||||
//// TableRegisterer(const std::string& type,
|
||||
//// WorkerTable* (*creator)(const Context&)) {
|
||||
//// TableRegistry::AddCreator(type, creator);
|
||||
//// }
|
||||
////};
|
||||
////
|
||||
////#define REGISTER_TABLE_CREATOR(type, creator) \
|
||||
//// static TableRegisterer(type, creator) g_creator_##type(#type, creator);
|
||||
////
|
||||
//}
|
||||
//
|
||||
//#endif // MULTIVERSO_INCLUDE_TABLE_FACTORY_H_
|
||||
#endif // MULTIVERSO_TABLE_FACTORY_H_
|
|
@ -70,17 +70,11 @@ public:
|
|||
virtual void ProcessGet(const std::vector<Blob>& data,
|
||||
std::vector<Blob>* result) = 0;
|
||||
};
|
||||
|
||||
class TableHelper {
|
||||
public:
|
||||
TableHelper() {}
|
||||
WorkerTable* CreateTable();
|
||||
virtual ~TableHelper() {}
|
||||
|
||||
protected:
|
||||
virtual WorkerTable* CreateWorkerTable() = 0;
|
||||
virtual ServerTable* CreateServerTable() = 0;
|
||||
};
|
||||
namespace trait {
|
||||
template<typename EleType, typename OptionType>
|
||||
struct OptionTrait;
|
||||
}
|
||||
|
||||
} // namespace multiverso
|
||||
|
||||
|
|
|
@ -106,14 +106,4 @@ void WorkerTable::Notify(int id) {
|
|||
m_.unlock();
|
||||
}
|
||||
|
||||
WorkerTable* TableHelper::CreateTable() {
|
||||
if (MV_ServerId() >= 0) {
|
||||
CreateServerTable();
|
||||
}
|
||||
if (MV_WorkerId() >= 0) {
|
||||
return CreateWorkerTable();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace multiverso
|
||||
|
|
|
@ -9,50 +9,6 @@ std::unordered_map<std::string,
|
|||
std::pair<worker_table_creater_t, server_table_creater_t> >
|
||||
TableFactory::table_creaters_;
|
||||
|
||||
std::string ele_type_str(EleType ele_type) {
|
||||
switch (ele_type) {
|
||||
case kInt:
|
||||
return "int_";
|
||||
case kFloat:
|
||||
return "float_";
|
||||
case kDouble:
|
||||
return "double_";
|
||||
}
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
WorkerTable* TableFactory::CreateTable(
|
||||
std::string& type,
|
||||
void**table_args) {
|
||||
CHECK(table_creaters_.find(type) != table_creaters_.end());
|
||||
|
||||
if (MV_ServerId() >= 0) {
|
||||
table_creaters_[type].second(table_args);
|
||||
}
|
||||
if (MV_WorkerId() >= 0) {
|
||||
return table_creaters_[type].first(table_args);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
WorkerTable* TableFactory::CreateTable(
|
||||
EleType ele_type,
|
||||
std::string& type,
|
||||
void**table_args) {
|
||||
std::string typestr = ele_type_str(ele_type) + type;
|
||||
return CreateTable(typestr, table_args);
|
||||
}
|
||||
|
||||
WorkerTable* TableFactory::CreateTable(
|
||||
EleType ele_type1,
|
||||
EleType ele_type2,
|
||||
std::string& type,
|
||||
void**table_args) {
|
||||
std::string typestr = ele_type_str(ele_type1)
|
||||
+ ele_type_str(ele_type2) + type;
|
||||
return CreateTable(typestr, table_args);
|
||||
}
|
||||
|
||||
void TableFactory::RegisterTable(
|
||||
std::string& type,
|
||||
worker_table_creater_t wt,
|
||||
|
@ -80,23 +36,26 @@ void TableFactory::RegisterTable(
|
|||
server_table_creater<double>);
|
||||
|
||||
template<typename T>
|
||||
WorkerTable* create_array_worker(void **args) {
|
||||
return new ArrayWorker<T>(*(size_t*)(*args));
|
||||
WorkerTable* create_array_worker(void *args) {
|
||||
return new ArrayWorker<T>(((ArrayTableInitOption*)args)->size);
|
||||
}
|
||||
template<typename T>
|
||||
ServerTable* create_array_server(void **args) {
|
||||
return new ArrayServer<T>(*(size_t*)(*args));
|
||||
ServerTable* create_array_server(void *args) {
|
||||
return new ArrayServer<T>(((ArrayTableInitOption*)args)->size);
|
||||
}
|
||||
MV_REGISTER_TABLE_WITH_BASIC_TYPE(array, create_array_worker, create_array_server);
|
||||
|
||||
template<typename T>
|
||||
WorkerTable* create_matrix_worker(void **args) {
|
||||
return new MatrixWorkerTable<T>(*(integer_t*)(*args), *(integer_t*)(*(args+1)));
|
||||
namespace trait {
|
||||
#define DECLARE_OPTION_TRAIT_TYPE(eletype, optiontype, str) \
|
||||
template<> \
|
||||
std::string OptionTrait<eletype, optiontype>::type = #str;
|
||||
|
||||
#define DECLARE_OPTION_TRAIT_WITH_BASIC_TYPE(optiontype, str) \
|
||||
DECLARE_OPTION_TRAIT_TYPE(int, optiontype, str) \
|
||||
DECLARE_OPTION_TRAIT_TYPE(float, optiontype, str) \
|
||||
DECLARE_OPTION_TRAIT_TYPE(double, optiontype, str) \
|
||||
|
||||
DECLARE_OPTION_TRAIT_WITH_BASIC_TYPE(ArrayTableInitOption, array);
|
||||
}
|
||||
template<typename T>
|
||||
ServerTable* create_matrix_server(void **args) {
|
||||
return new MatrixServerTable<T>(*(integer_t*)(*args), *(integer_t*)(*(args + 1)));
|
||||
}
|
||||
MV_REGISTER_TABLE_WITH_BASIC_TYPE(matrix, create_matrix_worker, create_matrix_server);
|
||||
|
||||
} // namespace multiverso
|
Загрузка…
Ссылка в новой задаче