fix bug
This commit is contained in:
Родитель
0aa21f4dd4
Коммит
9619c084e5
|
@ -489,7 +489,7 @@ void TestSparseMatrixTable(int argc, char* argv[]) {
|
|||
|
||||
|
||||
void TestMatrixPerformance(int argc, char* argv[], bool sparse) {
|
||||
Log::ResetLogLevel(LogLevel::Debug);
|
||||
Log::ResetLogLevel(LogLevel::Error);
|
||||
Log::Info("Test Sparse Matrix\n");
|
||||
Timer timmer;
|
||||
|
||||
|
@ -543,7 +543,7 @@ void TestMatrixPerformance(int argc, char* argv[], bool sparse) {
|
|||
}
|
||||
timmer.Start();
|
||||
worker_table->Get(data, size, worker_id);
|
||||
std::cout << " " << timmer.elapse() << "s:\t" << "get all rows after adding to rows" << std::endl;
|
||||
std::cout << " " << 1.0 * timmer.elapse() / 1000 << "s:\t" << "get all rows after adding to rows" << std::endl;
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
@ -579,7 +579,7 @@ void TestMatrixPerformance(int argc, char* argv[], bool sparse) {
|
|||
}
|
||||
timmer.Start();
|
||||
worker_table->Get(data, size);
|
||||
std::cout << " " << timmer.elapse() << "s:\t" << "get all rows after adding to rows" << std::endl;
|
||||
std::cout << " " << 1.0 * timmer.elapse() / 1000 << "s:\t" << "get all rows after adding to rows" << std::endl;
|
||||
}
|
||||
}
|
||||
Log::ResetLogLevel(LogLevel::Info);
|
||||
|
|
|
@ -41,6 +41,8 @@ class SparseMatrixWorkerTable : public MatrixWorkerTable<T> {
|
|||
|
||||
void Get(const std::vector<int>& row_ids,
|
||||
const std::vector<T*>& data_vec, size_t size) = delete;
|
||||
|
||||
T* base_buf;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
|
|
@ -16,6 +16,7 @@ bool split_rows = true;
|
|||
template <typename T>
|
||||
void SparseMatrixWorkerTable<T>::Get(T* data, size_t size,
|
||||
int worker_id) {
|
||||
base_buf == data;
|
||||
CHECK(size == num_col_ * num_row_);
|
||||
int whole_table = -1;
|
||||
Get(whole_table, data, size, worker_id);
|
||||
|
@ -27,6 +28,11 @@ void SparseMatrixWorkerTable<T>::Get(int row_id, T* data, size_t size,
|
|||
int worker_id) {
|
||||
if (row_id >= 0) CHECK(size == num_col_);
|
||||
row_index_[row_id] = data; // data_ = data;
|
||||
if (row_id == -1) {
|
||||
base_buf = data;
|
||||
} else {
|
||||
base_buf = nullptr;
|
||||
}
|
||||
Blob keys(&row_id, sizeof(int) * 2);
|
||||
// TODO[qiwye]: to make worker_id an Option.
|
||||
keys.As<int>(1) = worker_id;
|
||||
|
@ -37,6 +43,11 @@ void SparseMatrixWorkerTable<T>::Get(int row_id, T* data, size_t size,
|
|||
template <typename T>
|
||||
void SparseMatrixWorkerTable<T>::Get(const std::vector<int>& row_ids,
|
||||
const std::vector<T*>& data_vec, size_t size, int worker_id) {
|
||||
if (row_ids.size() == 1 && row_ids[0] == -1) {
|
||||
base_buf = data_vec[0];
|
||||
} else {
|
||||
base_buf = nullptr;
|
||||
}
|
||||
CHECK(size == num_col_);
|
||||
CHECK(row_ids.size() == data_vec.size());
|
||||
for (int i = 0; i < row_ids.size(); ++i) {
|
||||
|
@ -133,8 +144,7 @@ void SparseMatrixWorkerTable<T>::ProcessReplyGet(
|
|||
std::vector<Blob>& reply_data) {
|
||||
if (split_rows) {
|
||||
// replace row_index when original key == -1
|
||||
if (row_index_.size() == 1 && row_index_.find(-1) != row_index_.end()) {
|
||||
T* base_buf = row_index_[-1];
|
||||
if (base_buf != nullptr) {
|
||||
size_t keys_size = reply_data[0].size<int>();
|
||||
int *keys = reinterpret_cast<int*>(reply_data[0].data());
|
||||
row_index_.clear();
|
||||
|
|
Загрузка…
Ссылка в новой задаче