Multiverso/Test/main.cpp

519 строки
17 KiB
C++
Исходник Обычный вид История

#include <iostream>
#include <thread>
2016-02-23 09:26:05 +03:00
#include <random>
#include <chrono>
2016-03-15 01:14:20 +03:00
#include <ctime>
2016-04-20 16:52:49 +03:00
#include <algorithm>
2016-04-21 11:07:58 +03:00
#include <numeric>
2016-05-24 14:45:29 +03:00
#include <memory>
2016-08-01 09:29:12 +03:00
#include <cassert>
2016-02-02 15:40:07 +03:00
2016-02-29 09:27:59 +03:00
#include <mpi.h>
#include <multiverso/multiverso.h>
2016-02-19 05:46:07 +03:00
#include <multiverso/net.h>
2016-02-04 06:03:53 +03:00
#include <multiverso/util/log.h>
2016-02-19 05:46:07 +03:00
#include <multiverso/util/net_util.h>
2016-04-17 14:56:58 +03:00
#include <multiverso/util/configure.h>
2016-04-19 12:25:20 +03:00
#include <multiverso/util/timer.h>
2016-04-19 18:30:18 +03:00
#include <multiverso/dashboard.h>
2016-02-02 15:40:07 +03:00
#include <multiverso/table/array_table.h>
#include <multiverso/table/kv_table.h>
2016-03-30 08:40:20 +03:00
#include <multiverso/table/matrix_table.h>
2016-05-27 10:50:31 +03:00
#include <multiverso/table/matrix.h>
2016-04-19 12:25:20 +03:00
#include <multiverso/table/sparse_matrix_table.h>
2016-04-07 15:05:32 +03:00
#include <multiverso/updater/updater.h>
2016-05-08 19:09:19 +03:00
#include <multiverso/table_factory.h>
2016-03-03 09:11:04 +03:00
2016-02-02 15:40:07 +03:00
using namespace multiverso;
void TestKV(int argc, char* argv[]) {
2016-02-02 15:40:07 +03:00
Log::Info("Test KV map \n");
// ----------------------------------------------------------------------- //
// this is a demo of distributed hash table to show how to use the multiverso
// ----------------------------------------------------------------------- //
// 1. Start the Multiverso engine ---------------------------------------- //
2016-03-01 06:34:23 +03:00
MV_Init(&argc, argv);
2016-02-02 15:40:07 +03:00
// 2. To create the shared table ----------------------------------------- //
// TODO(feiga): This table must be create at both worker and server endpoint
// simultaneously, since they should share same type and same created order
// (same order means same table id). So it's better to create them with some
// specific creator, instead of current way
// TODO(feiga): should add some if statesment
// if the node is worker, then create a worker cache table
KVWorkerTable<int, int>* dht = new KVWorkerTable<int, int>();
// if the node is server, then create a server storage table
KVServerTable<int, int>* server_dht = new KVServerTable<int, int>();
2016-03-01 06:34:23 +03:00
MV_Barrier();
2016-02-02 15:40:07 +03:00
// 3. User program ------------------------------------------------------- //
// all this interface is related with the KVWorkerTable
// the data structure and the interface can be both defined by users
// We also provides several common implementations
// For specific program, user-defined table may provides better performance
// access the local cache
std::unordered_map<int, int>& kv = dht->raw();
// The Get/Add are sync operation, when the function call returns, we already
// Get from server, or the server has Added the update
// Get from the server
dht->Get(0);
// Check the result. Since no one added, this should be 0
2016-03-04 09:18:15 +03:00
Log::Info("Get 0 from kv server: result = %d\n", kv[0]);
2016-02-02 15:40:07 +03:00
// Add 1 to the server
2016-03-04 09:18:15 +03:00
dht->Add(0, 1);
2016-02-02 15:40:07 +03:00
// Check the result. Since just added one, this should be 1
dht->Get(0);
Log::Info("Get 0 from kv server after add 1: result = %d\n", kv[0]);
// 4. Shutdown the Multiverso engine. ------------------------------------ //
2016-03-01 06:34:23 +03:00
MV_ShutDown();
2016-02-02 15:40:07 +03:00
}
void TestArray(int argc, char* argv[]) {
2016-02-02 15:40:07 +03:00
Log::Info("Test Array \n");
2016-05-24 14:45:29 +03:00
multiverso::SetCMDFlag("sync", true);
2016-03-01 06:34:23 +03:00
MV_Init(&argc, argv);
size_t array_size = 500;
2016-05-08 19:26:26 +03:00
2016-05-24 14:45:29 +03:00
ArrayWorker<int>* shared_array = MV_CreateTable(ArrayTableOption<int>(array_size));
2016-02-02 15:40:07 +03:00
2016-03-01 06:34:23 +03:00
MV_Barrier();
2016-05-24 14:45:29 +03:00
Log::Info("Create tables OK. Rank = %d, worker_id = %d\n", MV_Rank(), MV_WorkerId());
2016-02-02 15:40:07 +03:00
2016-05-24 14:45:29 +03:00
std::vector<int> delta(array_size);
2016-05-08 19:26:26 +03:00
for (int i = 0; i < array_size; ++i)
2016-05-24 14:45:29 +03:00
delta[i] = static_cast<int>(i);
int* data = new int[array_size];
int iter = 10 * (MV_Rank() + 10);
2016-02-23 09:15:22 +03:00
for (int i = 0; i < iter; ++i) {
2016-05-24 14:45:29 +03:00
shared_array->Add(delta.data(), array_size);
2016-06-16 10:09:02 +03:00
shared_array->Add(delta.data(), array_size);
shared_array->Add(delta.data(), array_size);
shared_array->Get(data, array_size);
shared_array->Get(data, array_size);
2016-05-08 19:26:26 +03:00
shared_array->Get(data, array_size);
2016-05-24 14:45:29 +03:00
for (int k = 0; k < array_size; ++k) {
if (data[k] != delta[k] * (i + 1) * MV_NumWorkers()) {
// std::cout << "i + 1 = " << i + 1 << " k = " << k << std::endl;
// for (int j = 0; j < array_size; ++j) {
// std::cout << data[j] << " ";
// }
// exit(1);
2016-05-24 14:45:29 +03:00
}
}
{ printf("iter = %d\n", i); fflush(stdout); }
}
2016-03-01 06:34:23 +03:00
MV_ShutDown();
2016-02-02 15:40:07 +03:00
}
void TestNet(int argc, char* argv[]) {
2016-02-02 15:40:07 +03:00
NetInterface* net = NetInterface::Get();
net->Init(&argc, argv);
2016-02-02 15:40:07 +03:00
2016-03-02 11:02:03 +03:00
const char* chi1 = std::string("hello, world").c_str();
const char* chi2 = std::string("hello, c++").c_str();
const char* chi3 = std::string("hello, multiverso").c_str();
char* hi1 = new char[14];
strcpy(hi1, chi1);
char* hi2 = new char[12];
strcpy(hi2, chi2);
char* hi3 = new char[19];
strcpy(hi3, chi3);
2016-02-02 15:40:07 +03:00
if (net->rank() == 0) {
2016-03-01 17:51:27 +03:00
for (int rank = 1; rank < net->size(); ++rank) {
MessagePtr msg(new Message());// = std::make_unique<Message>();
msg->set_src(0);
msg->set_dst(rank);
msg->Push(Blob(hi1, 13));
msg->Push(Blob(hi2, 11));
msg->Push(Blob(hi3, 18));
2016-03-08 22:31:22 +03:00
for (int i = 0; i < msg->size(); ++i) {
Log::Info("In Send: %s\n", msg->data()[i].data());
};
2016-04-17 14:56:58 +03:00
while (net->Send(msg) == 0);
Log::Info("rank 0 send\n");
2016-03-01 17:51:27 +03:00
}
2016-02-23 09:15:22 +03:00
2016-03-01 17:51:27 +03:00
for (int i = 1; i < net->size(); ++i) {
MessagePtr msg(new Message());
msg.reset(new Message());
while (net->Recv(&msg) == 0) {
// Log::Info("recv return 0\n");
}
Log::Info("rank 0 recv\n");
// CHECK(strcmp(msg->data()[0].data(), hi) == 0);
std::vector<Blob> recv_data = msg->data();
CHECK(recv_data.size() == 3);
for (int i = 0; i < msg->size(); ++i) {
Log::Info("recv from srv %d: %s\n", msg->src(), recv_data[i].data());
};
2016-02-23 09:15:22 +03:00
}
2016-04-21 14:52:35 +03:00
} else {// other rank
2016-02-26 09:36:38 +03:00
MessagePtr msg(new Message());// = std::make_unique<Message>();
2016-02-23 09:15:22 +03:00
while (net->Recv(&msg) == 0) {
2016-03-01 17:51:27 +03:00
// Log::Info("recv return 0\n");
2016-02-23 09:15:22 +03:00
}
2016-03-01 17:51:27 +03:00
Log::Info("rank %d recv\n", net->rank());
std::vector<Blob>& recv_data = msg->data();
CHECK(recv_data.size() == 3);
for (int i = 0; i < msg->size(); ++i) {
Log::Info("%s\n", recv_data[i].data());
}
2016-02-23 09:15:22 +03:00
msg.reset(new Message());
2016-03-01 17:51:27 +03:00
msg->set_src(net->rank());
2016-02-23 09:15:22 +03:00
msg->set_dst(0);
msg->Push(Blob(hi1, 13));
msg->Push(Blob(hi2, 11));
msg->Push(Blob(hi3, 18));
2016-04-17 14:56:58 +03:00
while (net->Send(msg) == 0);
2016-03-01 17:51:27 +03:00
Log::Info("rank %d send\n", net->rank());
2016-02-02 15:40:07 +03:00
}
2016-03-01 17:51:27 +03:00
// while (!net->Test()) {
// // wait all message process finished
// }
2016-02-02 15:40:07 +03:00
net->Finalize();
}
2016-02-19 05:46:07 +03:00
void TestIP() {
std::unordered_set<std::string> ip_list;
2016-02-26 09:36:38 +03:00
// net::GetLocalIPAddress(&ip_list);
for (auto ip : ip_list) Log::Info("%s\n", ip.c_str());
2016-02-19 05:46:07 +03:00
}
2016-03-03 11:30:09 +03:00
2016-05-24 10:38:58 +03:00
void TestMatrix(int argc, char* argv[]){
2016-05-27 13:44:43 +03:00
//Log::ResetLogLevel(LogLevel::Debug);
2016-05-11 10:55:19 +03:00
multiverso::SetCMDFlag("sync", true);
2016-04-17 14:56:58 +03:00
MV_Init(&argc, argv);
2016-03-03 11:30:09 +03:00
2016-05-27 13:44:43 +03:00
int num_row = 11, num_col = 3592;
int num_tables = 2;
2016-05-24 10:38:58 +03:00
std::vector<int> num_table_size;
2016-05-27 10:50:31 +03:00
std::vector<MatrixOption<int>* > table_options;
std::vector<MatrixWorker<int>* > worker_tables;
2016-05-24 10:38:58 +03:00
2016-05-27 13:44:43 +03:00
for (auto i = 0; i < num_tables-1; i++)
2016-05-24 10:38:58 +03:00
{
2016-05-27 10:50:31 +03:00
table_options.push_back(new MatrixOption<int>());
2016-05-24 10:38:58 +03:00
table_options[i]->num_col = num_col;
table_options[i]->num_row = num_row + i;
2016-05-27 10:50:31 +03:00
table_options[i]->is_sparse = true;
2016-05-24 10:38:58 +03:00
num_table_size.push_back(num_col * (num_row + i));
worker_tables.push_back(multiverso::MV_CreateTable(*table_options[i]));
}
2016-05-27 10:50:31 +03:00
table_options.push_back(new MatrixOption<int>());
2016-05-27 13:44:43 +03:00
table_options[num_tables - 1]->num_col = num_col;
table_options[num_tables - 1]->num_row = 1;
2016-05-24 10:38:58 +03:00
num_table_size.push_back(num_col * (1));
2016-05-27 13:44:43 +03:00
worker_tables.push_back(multiverso::MV_CreateTable(*table_options[num_tables-1]));
2016-05-24 10:38:58 +03:00
2016-04-17 14:56:58 +03:00
std::thread* m_prefetchThread = nullptr;
MV_Barrier();
2016-04-27 16:33:27 +03:00
int count = 0;
2016-04-17 14:56:58 +03:00
while (true)
{
2016-04-27 16:33:27 +03:00
count++;
std::vector<int> v = { 0, 1, 3, 7 };
2016-04-17 14:56:58 +03:00
// test data
2016-05-24 10:38:58 +03:00
std::vector<std::vector<int>> delta(num_tables);
std::vector<std::vector<int>> data(num_tables);
for (auto j =0; j < num_tables; j++)
{
delta[j].resize(num_table_size[j]);
data[j].resize(num_table_size[j], 0);
for (auto i = 0; i < num_table_size[j]; ++i)
2016-05-27 13:44:43 +03:00
delta[j][i] = (int)i + 1;
2016-05-24 10:38:58 +03:00
}
2016-04-17 14:56:58 +03:00
2016-05-24 10:38:58 +03:00
for (auto j = 0; j < num_tables; j++)
{
worker_tables[j]->Add(delta[j].data(), num_table_size[j]);
worker_tables[j]->Get(data[j].data(), num_table_size[j]);
}
2016-05-11 10:55:19 +03:00
if (count % 1000 == 0)
{
printf("Dense Add/Get, #test: %d.\n", count);
fflush(stdout);
2016-04-17 14:56:58 +03:00
}
2016-05-11 10:55:19 +03:00
2016-05-24 10:38:58 +03:00
std::vector<int*> data_rows = { &data[0][0], &data[0][num_col], &data[0][3 * num_col], &data[0][7 * num_col] };
std::vector<int*> delta_rows = { &delta[0][0], &delta[0][num_col], &delta[0][3 * num_col], &delta[0][7 * num_col] };
for (auto j = 0; j < num_tables - 1; j++)
{
worker_tables[j]->Add(v, delta_rows, num_col);
worker_tables[j]->Get(v, data_rows, num_col);
}
//MV_Barrier();
//worker_table->Get(v, data_rows, num_col);
2016-04-17 14:56:58 +03:00
2016-05-11 10:55:19 +03:00
if (count % 1000 == 0)
{
printf("Sparse Add/Get, #test: %d.\n", count);
fflush(stdout);
}
for (auto i = 0; i < num_row; ++i) {
for (auto j = 0; j < num_col; ++j) {
2016-05-27 13:44:43 +03:00
int expected = (int)(i * num_col + j + 1) * count * MV_NumWorkers();
if (i == 0 || i == 1 || i == 3 || i == 7) {
2016-05-27 13:44:43 +03:00
expected += (int)(i * num_col + j + 1) * count * MV_NumWorkers();
2016-04-27 16:33:27 +03:00
}
2016-05-24 10:38:58 +03:00
int actual = data[0][i* num_col + j];
2016-08-01 09:29:12 +03:00
assert(expected == actual); // << "Should be equal after adding, row: "
// << i << ", col:" << j << ", expected: " << expected << ", actual: " << actual;
2016-04-27 16:33:27 +03:00
}
2016-04-17 14:56:58 +03:00
}
}
2016-05-24 10:38:58 +03:00
worker_tables.clear();
2016-04-17 14:56:58 +03:00
MV_ShutDown();
2016-03-03 11:30:09 +03:00
}
void TestCheckPoint(int argc, char* argv[], bool restore){
Log::Info("Test CheckPoint\n");
2016-04-17 14:56:58 +03:00
MV_Init(&argc, argv);
int num_row = 11, num_col = 10;
int size = num_row * num_col;
2016-05-08 19:09:19 +03:00
MatrixWorkerTable<int>* worker_table = new MatrixWorkerTable<int>(num_row, num_col);
MatrixServerTable<int>* server_table = new MatrixServerTable<int>(num_row, num_col);
2016-05-03 14:33:57 +03:00
//if restore = true, will restore server data and return the next iter number of last dump file
//else do nothing and return 0
// int begin_iter = MV_LoadTable("serverTable_");
MV_Barrier();//won't dump data without parameters
std::vector<int> delta(size);
for (int i = 0; i < size; ++i)
delta[i] = i;
int * data = new int[size];
// Log::Debug("rank %d start from iteration %d\n", MV_Rank(), begin_iter);
for (int i = 0 /*begin_iter*/; i < 50; ++i){
worker_table->Add(delta.data(), size);
MV_Barrier(); //dump table data with iteration i each k iterations
}
worker_table->Get(data, size);
printf("----------------------------\n");
for (int i = 0; i < num_row; ++i){
printf("rank %d, row %d: ", MV_Rank(), i);
for (int j = 0; j < num_col; ++j)
printf("%d ", data[i * num_col + j]);
printf("\n");
}
MV_ShutDown();
}
2016-04-17 14:56:58 +03:00
void TestAllreduce(int argc, char* argv[]) {
2016-04-17 15:34:23 +03:00
multiverso::SetCMDFlag("ma", true);
2016-04-17 14:56:58 +03:00
MV_Init(&argc, argv);
int a = 1;
MV_Aggregate(&a, 1);
std::cout << "a = " << a << std::endl;
MV_ShutDown();
2016-02-29 09:27:59 +03:00
}
2016-04-21 06:30:47 +03:00
template<typename WT, typename ST>
void TestmatrixPerformance(int argc, char* argv[],
2016-04-21 06:30:47 +03:00
std::function<std::shared_ptr<WT>(int num_row, int num_col)>CreateWorkerTable,
std::function<std::shared_ptr<ST>(int num_row, int num_col)>CreateServerTable,
2016-04-26 12:40:45 +03:00
std::function<void(const std::shared_ptr<WT>& worker_table, const std::vector<int>& row_ids, const std::vector<float*>& data_vec, size_t size, const AddOption* option, int worker_id)> Add,
2016-04-21 07:57:49 +03:00
std::function<void(const std::shared_ptr<WT>& worker_table, float* data, size_t size, int worker_id)> Get) {
2016-04-21 06:30:47 +03:00
Log::ResetLogLevel(LogLevel::Info);
Log::Info("Test Matrix\n");
2016-04-19 12:25:20 +03:00
Timer timmer;
2016-05-24 10:38:58 +03:00
//multiverso::SetCMDFlag("sync", true);
2016-05-11 10:55:19 +03:00
MV_Init(&argc, argv);
2016-04-22 08:19:40 +03:00
int per = 0;
2016-04-20 16:52:49 +03:00
int num_row = 1000000, num_col = 50;
if (argc == 3){
num_row = atoi(argv[2]);
}
2016-04-22 08:19:40 +03:00
2016-04-19 12:25:20 +03:00
int size = num_row * num_col;
int worker_id = MV_Rank();
2016-04-21 14:52:35 +03:00
int worker_num = MV_Size();
2016-04-19 12:25:20 +03:00
// test data
2016-04-20 16:52:49 +03:00
float* data = new float[size];
float* delta = new float[size];
2016-04-19 12:25:20 +03:00
int* keys = new int[num_row];
2016-04-20 16:45:58 +03:00
for (auto row = 0; row < num_row; ++row) {
for (auto col = 0; col < num_col; ++col) {
2016-04-21 06:30:47 +03:00
delta[row * num_col + col] = row * num_col + col;
2016-04-20 16:45:58 +03:00
}
2016-04-19 12:25:20 +03:00
}
2016-04-26 12:40:45 +03:00
AddOption option;
2016-04-19 12:25:20 +03:00
option.set_worker_id(worker_id);
2016-04-22 09:13:14 +03:00
//std::mt19937_64 eng{ std::random_device{}() };
//std::vector<int> unique_index;
//for (int i = 0; i < num_row; i++){
// unique_index.push_back(i);
//}
2016-04-22 09:00:36 +03:00
for (auto percent = 0; percent < 10; ++percent)
2016-04-22 09:13:14 +03:00
for (auto turn = 0; turn < 10; ++turn)
2016-04-22 09:00:36 +03:00
{
2016-04-22 09:13:14 +03:00
//std::shuffle(unique_index.begin(), unique_index.end(), eng);
2016-04-22 09:00:36 +03:00
if (worker_id == 0) {
std::cout << "\nTesting: Get All Rows => Add "
2016-04-22 09:13:14 +03:00
<< percent + 1 << "0% Rows to Server => Get All Rows" << std::endl;
2016-04-22 09:00:36 +03:00
}
2016-04-21 08:54:59 +03:00
2016-04-22 09:00:36 +03:00
auto worker_table = CreateWorkerTable(num_row, num_col);
auto server_table = CreateServerTable(num_row, num_col);
MV_Barrier();
2016-04-19 12:25:20 +03:00
2016-04-22 09:00:36 +03:00
timmer.Start();
Get(worker_table, data, size, worker_id);
std::cout << " " << 1.0 * timmer.elapse() / 1000 << "s:\t" << "get all rows first time, worker id: " << worker_id << std::endl;
MV_Barrier();
2016-04-21 15:30:34 +03:00
2016-04-22 09:00:36 +03:00
std::vector<int> row_ids;
std::vector<float*> data_vec;
for (auto i = 0; i < num_row; ++i) {
if (i % 10 <= percent && i % worker_num == worker_id) {
row_ids.push_back(i);
data_vec.push_back(delta + i * num_col);
}
}
//for (auto i = 0; i < (percent + 1) * num_row / 10; i++)
//{
// row_ids.push_back(unique_index[i]);
// data_vec.push_back(delta + unique_index[i] * num_col);
//}
if (worker_id == 0) {
std::cout << "adding " << percent + 1 << " /10 rows to matrix server" << std::endl;
2016-04-21 14:52:35 +03:00
}
2016-04-21 15:30:34 +03:00
2016-04-22 09:00:36 +03:00
if (row_ids.size() > 0) {
Add(worker_table, row_ids, data_vec, num_col, &option, worker_id);
}
Get(worker_table, data, size, -1);
MV_Barrier();
timmer.Start();
Get(worker_table, data, size, worker_id);
std::cout << " " << 1.0 * timmer.elapse() / 1000 << "s:\t" << "get all rows after adding to rows, worker id: " << worker_id << std::endl;
for (auto i = 0; i < num_row; ++i) {
auto row_start = data + i * num_col;
for (auto col = 0; col < num_col; ++col) {
float expected = (float) i * num_col + col;
float actual = *(row_start + col);
2016-04-22 09:00:36 +03:00
if (i % 10 <= percent) {
2016-08-01 09:57:15 +03:00
assert(expected == actual); // << "Should be updated after adding, worker_id:"
2016-08-01 09:29:12 +03:00
//<< worker_id << ",row: " << i << ",col:" << col << ",expected: " << expected << ",actual: " << actual;
2016-04-22 09:00:36 +03:00
}
else {
2016-08-01 09:57:15 +03:00
assert(0 == *(row_start + col)); // << "Should be 0 for non update row values, worker_id:"
2016-08-01 09:29:12 +03:00
// << worker_id << ",row: " << i << ",col:" << col << ",expected: " << expected << ",actual: " << actual;
2016-04-22 09:00:36 +03:00
}
2016-04-21 14:52:35 +03:00
}
}
2016-05-27 10:50:31 +03:00
//MV_Barrier();
2016-04-22 09:00:36 +03:00
}
2016-04-21 14:52:35 +03:00
2016-04-19 12:25:20 +03:00
MV_Barrier();
Log::ResetLogLevel(LogLevel::Info);
2016-04-21 06:30:47 +03:00
Dashboard::Display();
Log::ResetLogLevel(LogLevel::Error);
2016-04-19 12:25:20 +03:00
MV_ShutDown();
}
2016-04-21 06:30:47 +03:00
void TestSparsePerf(int argc, char* argv[]) {
2016-05-27 10:50:31 +03:00
TestmatrixPerformance<MatrixWorker<float>, MatrixServer<float>>(argc,
2016-04-21 06:30:47 +03:00
argv,
[](int num_row, int num_col) {
2016-05-27 10:50:31 +03:00
return std::shared_ptr<MatrixWorker<float>>(
new MatrixWorker<float>(num_row, num_col, true));
2016-04-22 09:00:36 +03:00
},
[](int num_row, int num_col) {
2016-05-27 10:50:31 +03:00
return std::shared_ptr<MatrixServer<float>>(
new MatrixServer<float>(num_row, num_col, true, false));
2016-04-22 09:00:36 +03:00
},
2016-05-27 10:50:31 +03:00
[](const std::shared_ptr<MatrixWorker<float>>& worker_table, const std::vector<int>& row_ids, const std::vector<float*>& data_vec, size_t size, const AddOption* option, const int worker_id) {
2016-04-22 09:00:36 +03:00
worker_table->Add(row_ids, data_vec, size, option);
},
2016-05-27 10:50:31 +03:00
[](const std::shared_ptr<MatrixWorker<float>>& worker_table, float* data, size_t size, int worker_id) {
2016-04-26 12:40:45 +03:00
GetOption get_option;
get_option.set_worker_id(worker_id);
worker_table->Get(data, size, &get_option);
2016-04-22 09:00:36 +03:00
});
2016-04-21 06:30:47 +03:00
}
void TestDensePerf(int argc, char* argv[]) {
2016-04-21 07:57:49 +03:00
TestmatrixPerformance<MatrixWorkerTable<float>, MatrixServerTable<float>>(argc,
2016-04-21 06:30:47 +03:00
argv,
[](int num_row, int num_col) {
2016-04-21 07:57:49 +03:00
return std::shared_ptr<MatrixWorkerTable<float>>(
new MatrixWorkerTable<float>(num_row, num_col));
2016-04-21 06:30:47 +03:00
},
[](int num_row, int num_col) {
2016-04-21 07:57:49 +03:00
return std::shared_ptr<MatrixServerTable<float>>(
new MatrixServerTable<float>(num_row, num_col));
2016-04-21 06:30:47 +03:00
},
2016-04-26 12:40:45 +03:00
[](const std::shared_ptr<MatrixWorkerTable<float>>& worker_table, const std::vector<int>& row_ids, const std::vector<float*>& data_vec, size_t size, const AddOption* option, const int worker_id) {
2016-04-21 06:30:47 +03:00
worker_table->Add(row_ids, data_vec, size, option);
},
2016-04-21 07:57:49 +03:00
[](const std::shared_ptr<MatrixWorkerTable<float>>& worker_table, float* data, size_t size, int worker_id) {
2016-04-21 06:30:47 +03:00
worker_table->Get(data, size);
});
}
2016-04-17 14:56:58 +03:00
2016-02-02 15:40:07 +03:00
int main(int argc, char* argv[]) {
Log::ResetLogLevel(LogLevel::Info);
2016-03-30 08:40:20 +03:00
if (argc == 1){
2016-04-17 14:56:58 +03:00
multiverso::MV_Init();
multiverso::MV_ShutDown();
2016-08-01 09:29:12 +03:00
return 0;
2016-04-21 14:52:35 +03:00
} else {
if (strcmp(argv[1], "kv") == 0) TestKV(argc, argv);
else if (strcmp(argv[1], "array") == 0) TestArray(argc, argv);
else if (strcmp(argv[1], "net") == 0) TestNet(argc, argv);
2016-02-19 05:46:07 +03:00
else if (strcmp(argv[1], "ip") == 0) TestIP();
else if (strcmp(argv[1], "matrix") == 0) TestMatrix(argc, argv);
else if (strcmp(argv[1], "checkpoint") == 0) TestCheckPoint(argc, argv, false);
else if (strcmp(argv[1], "restore") == 0) TestCheckPoint(argc, argv, true);
2016-04-17 14:56:58 +03:00
else if (strcmp(argv[1], "allreduce") == 0) TestAllreduce(argc, argv);
2016-04-21 06:30:47 +03:00
else if (strcmp(argv[1], "TestSparsePerf") == 0) TestSparsePerf(argc, argv);
else if (strcmp(argv[1], "TestDensePerf") == 0) TestDensePerf(argc, argv);
2016-02-02 15:40:07 +03:00
else CHECK(false);
}
2016-02-02 15:40:07 +03:00
return 0;
2016-02-26 09:36:38 +03:00
}