diff --git a/next/IMultiverso.vcxproj b/next/IMultiverso.vcxproj
index ce648cc..18bdadd 100644
--- a/next/IMultiverso.vcxproj
+++ b/next/IMultiverso.vcxproj
@@ -181,6 +181,7 @@
+
diff --git a/next/Test/main.cpp b/next/Test/main.cpp
index a0247aa..c25aea2 100644
--- a/next/Test/main.cpp
+++ b/next/Test/main.cpp
@@ -13,9 +13,46 @@
#include
#include
#include
+#include
#include
+
using namespace multiverso;
+void TestMatrix(int argc, char* argv[]){
+ Log::Info("Test Matrix\n");
+ MV_Init(&argc, argv);
+ MatrixWorkerTable* worker_table = new MatrixWorkerTable(10, 10);
+ MatrixServerTable* server_table = new MatrixServerTable(10, 10);
+ MV_Barrier();
+
+ std::vector delta(10 * 10);
+
+ for (int i = 0; i < 100; ++i)
+ delta[i] = 0;
+
+ int * data = new int[10 * 10];
+
+ for (int i = 0; i < 100; ++i)
+ delta[i] = 1.0;
+
+ std::vector v = { 0, 1, 5 };
+
+ worker_table->Add(v, delta.data());
+ worker_table->Add(-1, delta.data());
+
+ worker_table->Get(-1, data);
+
+ printf("-----------rank %d begin-----------\n", Zoo::Get()->rank());
+ for (int i = 0; i < 10; ++i){
+ for (int j = 0; j < 10; ++j)
+ printf("%d ", data[i*10+j]);
+ printf("\n");
+ }
+ MV_Barrier();
+
+ MV_ShutDown();
+}
+
void TestKV(int argc, char* argv[]) {
Log::Info("Test KV map \n");
// ----------------------------------------------------------------------- //
@@ -56,10 +93,10 @@ void TestKV(int argc, char* argv[]) {
// Get from the server
dht->Get(0);
// Check the result. Since no one added, this should be 0
- Log::Info("Get 0 from kv server: result = %d\n", kv[0]);
+ Log::Info("Get 0 from kv server: result = %d,%d\n", kv[0]);
// Add 1 to the server
- dht->Add(0, 1);
+ dht->Add(0,1);
// Check the result. Since just added one, this should be 1
dht->Get(0);
@@ -329,6 +366,7 @@ int main(int argc, char* argv[]) {
else if (strcmp(argv[1], "ip") == 0) TestIP();
else if (strcmp(argv[1], "momentum") == 0) TestMomentum(argc, argv);
else if (strcmp(argv[1], "threads") == 0) TestMultipleThread(argc, argv);
+ else if (strcmp(argv[1], "matrix") == 0) TestMatrix(argc, argv);
else if (strcmp(argv[1], "nonet") == 0) TestNoNet(argc, argv);
else CHECK(false);
}
diff --git a/next/include/multiverso/table/kv_table.h b/next/include/multiverso/table/kv_table.h
index 334fd31..e974983 100644
--- a/next/include/multiverso/table/kv_table.h
+++ b/next/include/multiverso/table/kv_table.h
@@ -57,6 +57,7 @@ public:
(*out)[dst][0].As(counts[dst]) = keys.As(i);
if (kv.size() == 2)
(*out)[dst][1].As(counts[dst]) = kv[1].As(i);
+ ++counts[dst];
}
return static_cast(out->size());
}
diff --git a/next/include/multiverso/table/matrix_table.h b/next/include/multiverso/table/matrix_table.h
new file mode 100644
index 0000000..66ce3c8
--- /dev/null
+++ b/next/include/multiverso/table/matrix_table.h
@@ -0,0 +1,213 @@
+#ifndef MULTIVERSO_MATRIX_TABLE_H_
+#define MULTIVERSO_MATRIX_TABLE_H_
+
+#include "multiverso/table_interface.h"
+#include "multiverso/util/log.h"
+#include "multiverso/zoo.h"
+
+#include
+
+namespace multiverso {
+
+ template
+ class MatrixWorkerTable : public WorkerTable {
+ public:
+ explicit MatrixWorkerTable(int num_row, int num_col) : WorkerTable(), num_row_(num_row), num_col_(num_col) {
+ num_server_ = Zoo::Get()->num_servers();
+ server_offsets_.push_back(0);
+ int length = num_row / num_server_;
+ for (int i = 1; i < num_server_; ++i) {
+ server_offsets_.push_back(i * length); // may not balance
+ }
+ server_offsets_.push_back(num_row);
+
+ Log::Debug("worker %d create matrixTable with %d rows %d colums.\n", Zoo::Get()->rank(), num_row, num_col);
+ }
+
+ T* raw() { return data_; }
+
+ // data is user-allocated memory
+ void Get(int row_id, T* data){
+ data_ = data;
+ WorkerTable::Get(Blob(&row_id, sizeof(int)));
+ Log::Debug("worker %d getting row with id %d.\n", Zoo::Get()->rank(), row_id);
+ }
+
+ void Add(int row_id, T* data) {
+ if (row_id == -1){
+ WorkerTable::Add(Blob(&row_id, sizeof(int)), Blob(data, sizeof(T) * num_col_ * num_row_));
+ }
+ else{
+ WorkerTable::Add(Blob(&row_id, sizeof(int)), Blob(data, sizeof(T)* num_col_));
+ }
+ Log::Debug("worker %d adding row with id %d.\n", Zoo::Get()->rank(), row_id);
+ }
+
+ // data is user-allocated memory
+ void Get(std::vector row_ids, T* data) {
+ data_ = data;
+ WorkerTable::Get(Blob(&row_ids[0], sizeof(int)* row_ids.size()));
+ Log::Debug("worker %d getting rows\n", Zoo::Get()->rank());
+ }
+
+ // Add some rows
+ void Add(std::vector row_ids, T* data) {
+ Blob ids_blob(&row_ids[0], sizeof(int)* row_ids.size());
+ Blob data_blob(data, sizeof(T)* row_ids.size() * num_col_);
+ WorkerTable::Add(ids_blob, data_blob);
+ Log::Debug("worker %d adding rows\n", Zoo::Get()->rank());
+ }
+
+ int Partition(const std::vector& kv,
+ std::unordered_map >* out) override {
+ CHECK(kv.size() == 1 || kv.size() == 2);
+ CHECK_NOTNULL(out);
+
+ //get all elements
+ if (kv[0].size() == 1 && kv[0].As(0) == -1){
+ for (int i = 0; i < num_server_; ++i){
+ (*out)[i].push_back(kv[0]);
+ }
+ if (kv.size() == 2){
+ for (int i = 0; i < num_server_; ++i){
+ Blob blob(kv[1].data() + server_offsets_[i] * num_col_ * sizeof(T),
+ (server_offsets_[i + 1] - server_offsets_[i]) * num_col_ * sizeof(T));
+ (*out)[i].push_back(blob);
+ }
+ }
+ return static_cast(out->size());
+ }
+
+ Blob row_ids = kv[0];
+ std::unordered_map count;
+ for (int i = 0; i < row_ids.size(); ++i){
+ int dst = row_ids.As(i) / (num_row_ / num_server_);
+ dst = (dst == num_server_ ? dst - 1 : dst);
+ ++count[dst];
+ }
+ for (auto& it : count) { // Allocate memory
+ std::vector& vec = (*out)[it.first];
+ vec.push_back(Blob(it.second * sizeof(int)));
+ if (kv.size() == 2) vec.push_back(Blob(it.second * sizeof(T)* num_col_));
+ }
+ count.clear();
+ for (int i = 0; i < row_ids.size(); ++i) {
+ int dst = row_ids.As(i) / (num_row_ / num_server_);
+ dst = (dst == num_server_ ? dst - 1 : dst);
+ (*out)[dst][0].As(count[dst]) = row_ids.As(i);
+ if (kv.size() == 2){
+ memcpy(&((*out)[dst][1].As(count[dst] * num_col_)), &(kv[1].As(i * num_col_)), num_col_ * sizeof(T));
+ }
+ ++count[dst];
+ }
+ return static_cast(out->size());
+ }
+
+ void ProcessReplyGet(std::vector& reply_data) override {
+ CHECK(reply_data.size() == 2 || reply_data.size() == 3);
+ Blob keys = reply_data[0], data = reply_data[1];
+
+ if (keys.size() == 1 && keys.As(0) == -1) {
+ int row_offset = reply_data[2].As();
+ memcpy(data_ + row_offset * num_col_, data.data(), data.size());
+ return;
+ }
+
+ CHECK(data.size() == keys.size() * sizeof(T)* num_col_);
+ for (int i = 0; i < keys.size(); ++i) {
+ int row_id = keys.As(i);
+ memcpy(data_ + row_id * num_col_, data.data() + i * num_col_ * sizeof(T), num_col_ * sizeof(T));
+ }
+ }
+
+ private:
+ T* data_; // not owned
+ int num_row_;
+ int num_col_;
+ int num_server_;
+ std::vector server_offsets_;
+ };
+
+ // TODO(feiga): rename. The name static is inherited from last version
+ // The storage is a continuous large chunk of memory
+ template
+ class MatrixServerTable : public ServerTable {
+ public:
+ explicit MatrixServerTable(int num_row, int num_col) : ServerTable(), num_col_(num_col) {
+ server_id_ = Zoo::Get()->rank();
+
+ int size = num_row / Zoo::Get()->num_servers();
+ row_offset_ = size * Zoo::Get()->rank();
+ if (server_id_ == Zoo::Get()->num_servers() - 1){
+ size = num_row - row_offset_;
+ }
+ storage_.resize(size * num_col);
+
+ Log::Debug("server %d create matrixTable with %d row %d colums of %d rows.\n", server_id_, num_row, num_col, size);
+ }
+
+ void ProcessAdd(const std::vector& data) override {
+#ifdef MULTIVERSO_USE_BLAS
+ // MKL update
+ Log::Fatal("Not implemented yet\n");
+#else
+ CHECK(data.size() == 2);
+ Blob values = data[1], keys = data[0];
+
+ if (keys.size() == 1 && keys.As() == -1){
+ CHECK(storage_.size() == values.size());
+ for (int i = 0; i < storage_.size(); ++i){
+ storage_[i] += values.As(i);
+ }
+ Log::Debug("server %d adding all rows with row offset %d with %d rows\n", server_id_, row_offset_, storage_.size() / num_col_);
+ return;
+ }
+
+ CHECK(values.size() == keys.size() * sizeof(T)* num_col_);
+ for (int i = 0; i < keys.size(); ++i) {
+ int offset_v = i * num_col_;
+ int offset_s = (keys.As(i) -row_offset_) * num_col_;
+ for (int j = 0; j < num_col_; ++j){
+ storage_[j + offset_s] += values.As(offset_v + j);
+ }
+ Log::Debug("server %d adding row with id %d\n", server_id_, keys.As(i));
+ }
+#endif
+ }
+
+ void ProcessGet(const std::vector& data,
+ std::vector* result) override {
+ CHECK(data.size() == 1);
+ CHECK_NOTNULL(result);
+
+ Blob keys = data[0];
+ result->push_back(keys); // also push the key
+
+ //get all rows
+ if (keys.size() == 1 && keys.As() == -1){
+ result->push_back(Blob(storage_.data(), sizeof(T)* storage_.size()));
+ result->push_back(Blob(&row_offset_, sizeof(int)));
+ Log::Debug("server %d getting all rows with row offset %d with %d rows\n", server_id_, row_offset_, storage_.size() / num_col_);
+ return;
+ }
+
+
+ result->push_back(Blob(keys.size() * sizeof(T)* num_col_));
+ Blob& vals = (*result)[1];
+ for (int i = 0; i < keys.size(); ++i) {
+ int offset_v = i * num_col_;
+ int offset_s = (keys.As(i) -row_offset_) * num_col_;
+ memcpy(&(vals.As(offset_v)), &storage_[offset_s], sizeof(T)*num_col_);
+ Log::Debug("server %d getting row with id %d\n", server_id_, keys.As(i));
+ }
+ }
+
+ private:
+ int server_id_;
+ int num_col_;
+ int row_offset_;
+ std::vector storage_;
+ };
+}
+
+#endif // MULTIVERSO_ARRAY_TABLE_H_