This commit is contained in:
Junjie Li 2016-04-20 18:10:21 +08:00
Родитель 0aa21f4dd4
Коммит 9619c084e5
3 изменённых файлов: 17 добавлений и 5 удалений

Просмотреть файл

@ -489,7 +489,7 @@ void TestSparseMatrixTable(int argc, char* argv[]) {
void TestMatrixPerformance(int argc, char* argv[], bool sparse) {
Log::ResetLogLevel(LogLevel::Debug);
Log::ResetLogLevel(LogLevel::Error);
Log::Info("Test Sparse Matrix\n");
Timer timmer;
@ -543,7 +543,7 @@ void TestMatrixPerformance(int argc, char* argv[], bool sparse) {
}
timmer.Start();
worker_table->Get(data, size, worker_id);
std::cout << " " << timmer.elapse() << "s:\t" << "get all rows after adding to rows" << std::endl;
std::cout << " " << 1.0 * timmer.elapse() / 1000 << "s:\t" << "get all rows after adding to rows" << std::endl;
}
}
else {
@ -579,7 +579,7 @@ void TestMatrixPerformance(int argc, char* argv[], bool sparse) {
}
timmer.Start();
worker_table->Get(data, size);
std::cout << " " << timmer.elapse() << "s:\t" << "get all rows after adding to rows" << std::endl;
std::cout << " " << 1.0 * timmer.elapse() / 1000 << "s:\t" << "get all rows after adding to rows" << std::endl;
}
}
Log::ResetLogLevel(LogLevel::Info);

Просмотреть файл

@ -41,6 +41,8 @@ class SparseMatrixWorkerTable : public MatrixWorkerTable<T> {
void Get(const std::vector<int>& row_ids,
const std::vector<T*>& data_vec, size_t size) = delete;
T* base_buf;
};
template <typename T>

Просмотреть файл

@ -16,6 +16,7 @@ bool split_rows = true;
template <typename T>
void SparseMatrixWorkerTable<T>::Get(T* data, size_t size,
int worker_id) {
base_buf == data;
CHECK(size == num_col_ * num_row_);
int whole_table = -1;
Get(whole_table, data, size, worker_id);
@ -27,6 +28,11 @@ void SparseMatrixWorkerTable<T>::Get(int row_id, T* data, size_t size,
int worker_id) {
if (row_id >= 0) CHECK(size == num_col_);
row_index_[row_id] = data; // data_ = data;
if (row_id == -1) {
base_buf = data;
} else {
base_buf = nullptr;
}
Blob keys(&row_id, sizeof(int) * 2);
// TODO[qiwye]: to make worker_id an Option.
keys.As<int>(1) = worker_id;
@ -37,6 +43,11 @@ void SparseMatrixWorkerTable<T>::Get(int row_id, T* data, size_t size,
template <typename T>
void SparseMatrixWorkerTable<T>::Get(const std::vector<int>& row_ids,
const std::vector<T*>& data_vec, size_t size, int worker_id) {
if (row_ids.size() == 1 && row_ids[0] == -1) {
base_buf = data_vec[0];
} else {
base_buf = nullptr;
}
CHECK(size == num_col_);
CHECK(row_ids.size() == data_vec.size());
for (int i = 0; i < row_ids.size(); ++i) {
@ -133,8 +144,7 @@ void SparseMatrixWorkerTable<T>::ProcessReplyGet(
std::vector<Blob>& reply_data) {
if (split_rows) {
// replace row_index when original key == -1
if (row_index_.size() == 1 && row_index_.find(-1) != row_index_.end()) {
T* base_buf = row_index_[-1];
if (base_buf != nullptr) {
size_t keys_size = reply_data[0].size<int>();
int *keys = reinterpret_cast<int*>(reply_data[0].data());
row_index_.clear();