This commit is contained in:
Qiwei Ye 2016-04-20 13:43:17 +08:00
Родитель 51735f98d7
Коммит ea6c451fb6
4 изменённых файлов: 23 добавлений и 20 удалений

Просмотреть файл

@ -42,10 +42,10 @@ TEST_F(test_filter, test_should_compress_all_zero) {
near_zero_value, near_zero_value, near_zero_value};
multiverso::Blob input_blob(array, sizeof(float) * size);
multiverso::Blob compressed_blob;
auto compressed = try_compress(input_blob, &compressed_blob);
auto compressed = TryCompress(input_blob, &compressed_blob);
ASSERT_EQ(compressed, true);
auto de_compressed_blob =
de_compress(compressed_blob, size * sizeof(float));
DeCompress(compressed_blob, size * sizeof(float));
ASSERT_TRUE(is_same_content_with_clipping(array, de_compressed_blob, size));
}
@ -55,10 +55,10 @@ TEST_F(test_filter, test_should_compress_most_zero_a) {
near_zero_value, none_zero_value};
multiverso::Blob input_blob(array, sizeof(float) * size);
multiverso::Blob compressed_blob;
auto compressed = try_compress(input_blob, &compressed_blob);
auto compressed = TryCompress(input_blob, &compressed_blob);
ASSERT_TRUE(compressed == true);
auto de_compressed_blob =
de_compress(compressed_blob, size * sizeof(float));
DeCompress(compressed_blob, size * sizeof(float));
ASSERT_TRUE(is_same_content_with_clipping(array, de_compressed_blob, size));
}
@ -67,10 +67,10 @@ TEST_F(test_filter, test_should_compress_most_zero_b) {
auto array = new float[size]{1, 2, 0, 0, 0, 0, 0, 3};
multiverso::Blob input_blob(array, sizeof(float) * size);
multiverso::Blob compressed_blob;
auto compressed = try_compress(input_blob, &compressed_blob);
auto compressed = TryCompress(input_blob, &compressed_blob);
ASSERT_TRUE(compressed == true);
auto de_compressed_blob =
de_compress(compressed_blob, size * sizeof(float));
DeCompress(compressed_blob, size * sizeof(float));
ASSERT_TRUE(is_same_content(array, de_compressed_blob, size));
}
@ -79,7 +79,7 @@ TEST_F(test_filter, test_should_not_compress_half_zero) {
auto array = new float[size]{1, 0, 2, 0, 3, 0, 0, 4};
multiverso::Blob input_blob(array, sizeof(float) * size);
multiverso::Blob compressed_blob;
auto compressed = try_compress(input_blob, &compressed_blob);
auto compressed = TryCompress(input_blob, &compressed_blob);
ASSERT_TRUE(compressed == false);
}
@ -88,7 +88,7 @@ TEST_F(test_filter, test_should_not_compress_most_none_zero) {
auto array = new float[size]{1, 2, 3, 4, 0, 0, 0, 5};
multiverso::Blob input_blob(array, sizeof(float) * size);
multiverso::Blob compressed_blob;
auto compressed = try_compress(input_blob, &compressed_blob);
auto compressed = TryCompress(input_blob, &compressed_blob);
ASSERT_TRUE(compressed == false);
}
@ -99,7 +99,7 @@ TEST_F(test_filter, test_should_not_compress_all_none_zero) {
none_zero_value, none_zero_value, none_zero_value};
multiverso::Blob input_blob(array, sizeof(float) * size);
multiverso::Blob compressed_blob;
auto compressed = try_compress(input_blob, &compressed_blob);
auto compressed = TryCompress(input_blob, &compressed_blob);
ASSERT_TRUE(compressed == false);
}

Просмотреть файл

@ -53,8 +53,8 @@ class SparseMatrixServerTable : public MatrixServerTable<T> {
void ProcessGet(const std::vector<Blob>& data,
std::vector<Blob>* result) override;
private:
void update_state_on_add(int worker_id, Blob keys);
void update_state_on_get(int worker_id, int* keys, int key_size,
void UpdateAddState(int worker_id, Blob keys);
void UpdateGetState(int worker_id, int* keys, int key_size,
std::vector<int>* out_rows);
int get_global_row_id(int local_row_id) {
return row_offset_ + local_row_id;

Просмотреть файл

@ -41,7 +41,7 @@ namespace multiverso {
for (auto i = 0; i < blobs.size(); ++i) {
auto& blob = blobs[i];
Blob compressed_blob;
auto compressed = try_compress(blob, &compressed_blob);
auto compressed = TryCompress(blob, &compressed_blob);
// size info (compressed ? size : -1)
#pragma warning( push )
#pragma warning( disable : 4267)
@ -66,13 +66,13 @@ namespace multiverso {
size_blob.As<int>(i - 1) : blobs[i].size();
auto& blob = blobs[i];
outputs->push_back(is_compressed ?
std::move(de_compress(blob, size)) : blob);
std::move(DeCompress(blob, size)) : blob);
}
}
protected:
bool try_compress(const Blob& in_blob,
bool TryCompress(const Blob& in_blob,
Blob* out_blob) {
CHECK_NOTNULL(out_blob);
#pragma warning( push )
@ -121,7 +121,7 @@ protected:
return true;
}
Blob de_compress(const Blob& in_blob, size_t size) {
Blob DeCompress(const Blob& in_blob, size_t size) {
#pragma warning( push )
#pragma warning( disable : 4127)
CHECK(sizeof(data_type) == sizeof(index_type));

Просмотреть файл

@ -28,6 +28,7 @@ void SparseMatrixWorkerTable<T>::Get(int row_id, T* data, size_t size,
if (row_id >= 0) CHECK(size == num_col_);
row_index_[row_id] = data; // data_ = data;
Blob keys(&row_id, sizeof(int) * 2);
// TODO[qiwye]: to make worker_id an Option.
keys.As<int>(1) = worker_id;
WorkerTable::Get(keys);
Log::Debug("[Get] worker = %d, #row = %d\n", MV_Rank(), row_id);
@ -41,7 +42,7 @@ void SparseMatrixWorkerTable<T>::Get(const std::vector<int>& row_ids,
for (int i = 0; i < row_ids.size(); ++i) {
row_index_[row_ids[i]] = data_vec[i];
}
Blob keys(row_ids.data(), sizeof(int)* (row_ids.size() + 1));
Blob keys(row_ids.data(), sizeof(int) * (row_ids.size() + 1));
keys.As<int>(row_ids.size()) = worker_id;
WorkerTable::Get(keys);
Log::Debug("[Get] worker = %d, #rows_set = %d\n", MV_Rank(),
@ -82,6 +83,7 @@ int SparseMatrixWorkerTable<T>::Partition(const std::vector<Blob>& kv,
}
// space for workder_id
// TODO[qiwye]: to make worker_id an Option.
for (auto &kv : count) {
++kv.second;
}
@ -102,6 +104,7 @@ int SparseMatrixWorkerTable<T>::Partition(const std::vector<Blob>& kv,
}
// append workder_id
// TODO[qiwye]: to make worker_id an new Blob.
for (auto& kv : *out) {
kv.second[0].As<int>(kv.second[0].size<int>() - 1) = keys[keys_size];
}
@ -158,7 +161,7 @@ SparseMatrixServerTable<T>::SparseMatrixServerTable(int num_row, int num_col,
}
template <typename T>
void SparseMatrixServerTable<T>::update_state_on_add(int worker_id,
void SparseMatrixServerTable<T>::UpdateAddState(int worker_id,
Blob keys_blob) {
size_t keys_size = keys_blob.size<int>();
int *keys = reinterpret_cast<int*>(keys_blob.data());
@ -180,7 +183,7 @@ void SparseMatrixServerTable<T>::update_state_on_add(int worker_id,
}
template <typename T>
void SparseMatrixServerTable<T>::update_state_on_get(int worker_id, int* keys,
void SparseMatrixServerTable<T>::UpdateGetState(int worker_id, int* keys,
int key_size, std::vector<int>* out_rows) {
if (worker_id == -1) {
@ -241,7 +244,7 @@ void SparseMatrixServerTable<T>::ProcessAdd(
// must contain option that has worker id
CHECK(data.size() == 3);
UpdateOption option(data[2].data(), data[2].size());
update_state_on_add(option.worker_id(), data[0]);
UpdateAddState(option.worker_id(), data[0]);
MatrixServerTable<T>::ProcessAdd(data);
}
@ -263,7 +266,7 @@ void SparseMatrixServerTable<T>::ProcessGet(
std::vector<int> outdate_rows;
#pragma warning( push )
#pragma warning( disable : 4267)
update_state_on_get(workder_id, keys, keys_size, &outdate_rows);
UpdateGetState(workder_id, keys, keys_size, &outdate_rows);
#pragma warning( pop )
Blob outdate_rows_blob(sizeof(int) * outdate_rows.size());
for (auto i = 0; i < outdate_rows.size(); ++i) {