performance improvement and fix bug
This commit is contained in:
Родитель
628b37d70b
Коммит
0a99f5e3e8
|
@ -503,8 +503,10 @@ void TestMatrixPerformance(int argc, char* argv[], bool sparse) {
|
|||
int* data = new int[size];
|
||||
int* delta = new int[size];
|
||||
int* keys = new int[num_row];
|
||||
for (auto i = 0; i < size; ++i) {
|
||||
delta[i] = 1;
|
||||
for (auto row = 0; row < num_row; ++row) {
|
||||
for (auto col = 0; col < num_col; ++col) {
|
||||
delta[row * num_col + col] = row + 2;
|
||||
}
|
||||
}
|
||||
|
||||
UpdateOption option;
|
||||
|
@ -534,7 +536,9 @@ void TestMatrixPerformance(int argc, char* argv[], bool sparse) {
|
|||
auto row_start = data + i * num_col;
|
||||
for (auto col = 0; col < num_col; ++col) {
|
||||
if (i % 10 <= p) {
|
||||
ASSERT_EQ(1, *(row_start + col)) << "Should be 1 after adding";
|
||||
auto expected = i + 2;
|
||||
auto actual = *(row_start + col);
|
||||
ASSERT_EQ(expected, actual) << "Should be updated after adding";
|
||||
}
|
||||
else {
|
||||
ASSERT_EQ(0, *(row_start + col)) << "Should be 0 for non update row values";
|
||||
|
@ -570,7 +574,9 @@ void TestMatrixPerformance(int argc, char* argv[], bool sparse) {
|
|||
auto row_start = data + i * num_col;
|
||||
for (auto col = 0; col < num_col; ++col) {
|
||||
if (i % 10 <= p) {
|
||||
ASSERT_EQ(1, *(row_start + col)) << "Should be 1 after adding";
|
||||
auto expected = i + 2;
|
||||
auto actual = *(row_start + col);
|
||||
ASSERT_EQ(expected, actual) << "Should be updated after adding";
|
||||
}
|
||||
else {
|
||||
ASSERT_EQ(0, *(row_start + col)) << "Should be 0 for non update row values";
|
||||
|
|
|
@ -49,6 +49,7 @@ template <typename T>
|
|||
class SparseMatrixServerTable : public MatrixServerTable<T> {
|
||||
public:
|
||||
SparseMatrixServerTable(int num_row, int num_col, bool using_pipeline);
|
||||
~SparseMatrixServerTable();
|
||||
void ProcessAdd(const std::vector<Blob>& data) override;
|
||||
void ProcessGet(const std::vector<Blob>& data,
|
||||
std::vector<Blob>* result) override;
|
||||
|
@ -63,7 +64,9 @@ class SparseMatrixServerTable : public MatrixServerTable<T> {
|
|||
return global_row_id - row_offset_;
|
||||
}
|
||||
private:
|
||||
std::vector<std::vector<bool>> up_to_date_;
|
||||
bool** up_to_date_;
|
||||
int server_count_;
|
||||
// std::vector<std::vector<bool>> up_to_date_;
|
||||
};
|
||||
|
||||
} // namespace multiverso
|
||||
|
|
|
@ -262,6 +262,7 @@ void MatrixServerTable<T>::ProcessAdd(const std::vector<Blob>& data) {
|
|||
for (int i = 0; i < keys_size; ++i) {
|
||||
int offset_s = (keys[i] - row_offset_) * num_col_;
|
||||
updater_->Update(num_col_, storage_.data(), values + offset_v, option, offset_s);
|
||||
offset_v += num_col_;
|
||||
Log::Debug("[ProcessAdd] Server = %d, adding #row = %d\n",
|
||||
server_id_, keys[i]);
|
||||
}
|
||||
|
|
|
@ -147,17 +147,28 @@ void SparseMatrixWorkerTable<T>::ProcessReplyGet(
|
|||
MatrixWorkerTable<T>::ProcessReplyGet(reply_data);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
SparseMatrixServerTable<T>::~SparseMatrixServerTable() {
|
||||
for (auto i = 0; i < server_count_; ++i) {
|
||||
delete[]up_to_date_[i];
|
||||
}
|
||||
delete[]up_to_date_;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
SparseMatrixServerTable<T>::SparseMatrixServerTable(int num_row, int num_col,
|
||||
bool using_pipeline) : MatrixServerTable<T>(num_row, num_col) {
|
||||
auto server_count = multiverso::MV_Size();
|
||||
server_count_ = multiverso::MV_Size();
|
||||
if (using_pipeline) {
|
||||
server_count *= 2;
|
||||
server_count_ *= 2;
|
||||
}
|
||||
|
||||
for (auto i = 0; i < my_num_row_; ++i) {
|
||||
up_to_date_.push_back(std::move(std::vector<bool>(server_count, false)));
|
||||
|
||||
up_to_date_ = new bool*[server_count_];
|
||||
for (auto i = 0; i < server_count_; ++i) {
|
||||
up_to_date_[i] = new bool[my_num_row_];
|
||||
memset(up_to_date_[i], 0, sizeof(bool) * my_num_row_);
|
||||
}
|
||||
Log::Info("SparseMatrixServerTable server_count_ %d", server_count_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -167,16 +178,16 @@ void SparseMatrixServerTable<T>::UpdateAddState(int worker_id,
|
|||
int *keys = reinterpret_cast<int*>(keys_blob.data());
|
||||
// add all values
|
||||
if (keys_size == 1 && keys[0] == -1) {
|
||||
for (auto local_row_id = 0; local_row_id < my_num_row_; ++local_row_id) {
|
||||
for (auto id = 0; id < up_to_date_[local_row_id].size(); ++id) {
|
||||
up_to_date_[local_row_id][id] = (id == worker_id);
|
||||
for (auto id = 0; id < server_count_; ++id) {
|
||||
for (auto local_row_id = 0; local_row_id < my_num_row_; ++local_row_id) {
|
||||
up_to_date_[id][local_row_id] = (id == worker_id);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < keys_size; ++i) {
|
||||
auto local_row_id = get_local_row_id(keys[i]);
|
||||
for (auto id = 0; id < up_to_date_[local_row_id].size(); ++id) {
|
||||
up_to_date_[local_row_id][id] = (id == worker_id);
|
||||
for (auto id = 0; id < server_count_; ++id) {
|
||||
for (int i = 0; i < keys_size; ++i) {
|
||||
auto local_row_id = get_local_row_id(keys[i]);
|
||||
up_to_date_[id][local_row_id] = (id == worker_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -190,26 +201,22 @@ void SparseMatrixServerTable<T>::UpdateGetState(int worker_id, int* keys,
|
|||
for (auto local_row_id = 0; local_row_id < my_num_row_; ++local_row_id) {
|
||||
out_rows->push_back(get_global_row_id(local_row_id));
|
||||
}
|
||||
}
|
||||
|
||||
// do not update flags for worker_id == -1
|
||||
if (worker_id == -1) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (key_size == 1 && keys[0] == -1) {
|
||||
for (auto local_row_id = 0; local_row_id < my_num_row_; ++local_row_id) {
|
||||
if (!up_to_date_[local_row_id][worker_id]) {
|
||||
if (!up_to_date_[worker_id][local_row_id]) {
|
||||
out_rows->push_back(get_global_row_id(local_row_id));
|
||||
up_to_date_[local_row_id][worker_id] = true;
|
||||
up_to_date_[worker_id][local_row_id] = true;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (auto i = 0; i < key_size; ++i) {
|
||||
auto global_row_id = keys[i];
|
||||
auto local_row_id = get_local_row_id(global_row_id);
|
||||
if (!up_to_date_[local_row_id][worker_id]) {
|
||||
up_to_date_[local_row_id][worker_id] = true;
|
||||
if (!up_to_date_[worker_id][local_row_id]) {
|
||||
up_to_date_[worker_id][local_row_id] = true;
|
||||
out_rows->push_back(global_row_id);
|
||||
}
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче