unifying the interface
This commit is contained in:
Родитель
51735f98d7
Коммит
ea6c451fb6
|
@ -42,10 +42,10 @@ TEST_F(test_filter, test_should_compress_all_zero) {
|
|||
near_zero_value, near_zero_value, near_zero_value};
|
||||
multiverso::Blob input_blob(array, sizeof(float) * size);
|
||||
multiverso::Blob compressed_blob;
|
||||
auto compressed = try_compress(input_blob, &compressed_blob);
|
||||
auto compressed = TryCompress(input_blob, &compressed_blob);
|
||||
ASSERT_EQ(compressed, true);
|
||||
auto de_compressed_blob =
|
||||
de_compress(compressed_blob, size * sizeof(float));
|
||||
DeCompress(compressed_blob, size * sizeof(float));
|
||||
ASSERT_TRUE(is_same_content_with_clipping(array, de_compressed_blob, size));
|
||||
}
|
||||
|
||||
|
@ -55,10 +55,10 @@ TEST_F(test_filter, test_should_compress_most_zero_a) {
|
|||
near_zero_value, none_zero_value};
|
||||
multiverso::Blob input_blob(array, sizeof(float) * size);
|
||||
multiverso::Blob compressed_blob;
|
||||
auto compressed = try_compress(input_blob, &compressed_blob);
|
||||
auto compressed = TryCompress(input_blob, &compressed_blob);
|
||||
ASSERT_TRUE(compressed == true);
|
||||
auto de_compressed_blob =
|
||||
de_compress(compressed_blob, size * sizeof(float));
|
||||
DeCompress(compressed_blob, size * sizeof(float));
|
||||
ASSERT_TRUE(is_same_content_with_clipping(array, de_compressed_blob, size));
|
||||
}
|
||||
|
||||
|
@ -67,10 +67,10 @@ TEST_F(test_filter, test_should_compress_most_zero_b) {
|
|||
auto array = new float[size]{1, 2, 0, 0, 0, 0, 0, 3};
|
||||
multiverso::Blob input_blob(array, sizeof(float) * size);
|
||||
multiverso::Blob compressed_blob;
|
||||
auto compressed = try_compress(input_blob, &compressed_blob);
|
||||
auto compressed = TryCompress(input_blob, &compressed_blob);
|
||||
ASSERT_TRUE(compressed == true);
|
||||
auto de_compressed_blob =
|
||||
de_compress(compressed_blob, size * sizeof(float));
|
||||
DeCompress(compressed_blob, size * sizeof(float));
|
||||
ASSERT_TRUE(is_same_content(array, de_compressed_blob, size));
|
||||
}
|
||||
|
||||
|
@ -79,7 +79,7 @@ TEST_F(test_filter, test_should_not_compress_half_zero) {
|
|||
auto array = new float[size]{1, 0, 2, 0, 3, 0, 0, 4};
|
||||
multiverso::Blob input_blob(array, sizeof(float) * size);
|
||||
multiverso::Blob compressed_blob;
|
||||
auto compressed = try_compress(input_blob, &compressed_blob);
|
||||
auto compressed = TryCompress(input_blob, &compressed_blob);
|
||||
ASSERT_TRUE(compressed == false);
|
||||
}
|
||||
|
||||
|
@ -88,7 +88,7 @@ TEST_F(test_filter, test_should_not_compress_most_none_zero) {
|
|||
auto array = new float[size]{1, 2, 3, 4, 0, 0, 0, 5};
|
||||
multiverso::Blob input_blob(array, sizeof(float) * size);
|
||||
multiverso::Blob compressed_blob;
|
||||
auto compressed = try_compress(input_blob, &compressed_blob);
|
||||
auto compressed = TryCompress(input_blob, &compressed_blob);
|
||||
ASSERT_TRUE(compressed == false);
|
||||
}
|
||||
|
||||
|
@ -99,7 +99,7 @@ TEST_F(test_filter, test_should_not_compress_all_none_zero) {
|
|||
none_zero_value, none_zero_value, none_zero_value};
|
||||
multiverso::Blob input_blob(array, sizeof(float) * size);
|
||||
multiverso::Blob compressed_blob;
|
||||
auto compressed = try_compress(input_blob, &compressed_blob);
|
||||
auto compressed = TryCompress(input_blob, &compressed_blob);
|
||||
ASSERT_TRUE(compressed == false);
|
||||
}
|
||||
|
||||
|
|
|
@ -53,8 +53,8 @@ class SparseMatrixServerTable : public MatrixServerTable<T> {
|
|||
void ProcessGet(const std::vector<Blob>& data,
|
||||
std::vector<Blob>* result) override;
|
||||
private:
|
||||
void update_state_on_add(int worker_id, Blob keys);
|
||||
void update_state_on_get(int worker_id, int* keys, int key_size,
|
||||
void UpdateAddState(int worker_id, Blob keys);
|
||||
void UpdateGetState(int worker_id, int* keys, int key_size,
|
||||
std::vector<int>* out_rows);
|
||||
int get_global_row_id(int local_row_id) {
|
||||
return row_offset_ + local_row_id;
|
||||
|
|
|
@ -41,7 +41,7 @@ namespace multiverso {
|
|||
for (auto i = 0; i < blobs.size(); ++i) {
|
||||
auto& blob = blobs[i];
|
||||
Blob compressed_blob;
|
||||
auto compressed = try_compress(blob, &compressed_blob);
|
||||
auto compressed = TryCompress(blob, &compressed_blob);
|
||||
// size info (compressed ? size : -1)
|
||||
#pragma warning( push )
|
||||
#pragma warning( disable : 4267)
|
||||
|
@ -66,13 +66,13 @@ namespace multiverso {
|
|||
size_blob.As<int>(i - 1) : blobs[i].size();
|
||||
auto& blob = blobs[i];
|
||||
outputs->push_back(is_compressed ?
|
||||
std::move(de_compress(blob, size)) : blob);
|
||||
std::move(DeCompress(blob, size)) : blob);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
|
||||
bool try_compress(const Blob& in_blob,
|
||||
bool TryCompress(const Blob& in_blob,
|
||||
Blob* out_blob) {
|
||||
CHECK_NOTNULL(out_blob);
|
||||
#pragma warning( push )
|
||||
|
@ -121,7 +121,7 @@ protected:
|
|||
return true;
|
||||
}
|
||||
|
||||
Blob de_compress(const Blob& in_blob, size_t size) {
|
||||
Blob DeCompress(const Blob& in_blob, size_t size) {
|
||||
#pragma warning( push )
|
||||
#pragma warning( disable : 4127)
|
||||
CHECK(sizeof(data_type) == sizeof(index_type));
|
||||
|
|
|
@ -28,6 +28,7 @@ void SparseMatrixWorkerTable<T>::Get(int row_id, T* data, size_t size,
|
|||
if (row_id >= 0) CHECK(size == num_col_);
|
||||
row_index_[row_id] = data; // data_ = data;
|
||||
Blob keys(&row_id, sizeof(int) * 2);
|
||||
// TODO[qiwye]: to make worker_id an Option.
|
||||
keys.As<int>(1) = worker_id;
|
||||
WorkerTable::Get(keys);
|
||||
Log::Debug("[Get] worker = %d, #row = %d\n", MV_Rank(), row_id);
|
||||
|
@ -41,7 +42,7 @@ void SparseMatrixWorkerTable<T>::Get(const std::vector<int>& row_ids,
|
|||
for (int i = 0; i < row_ids.size(); ++i) {
|
||||
row_index_[row_ids[i]] = data_vec[i];
|
||||
}
|
||||
Blob keys(row_ids.data(), sizeof(int)* (row_ids.size() + 1));
|
||||
Blob keys(row_ids.data(), sizeof(int) * (row_ids.size() + 1));
|
||||
keys.As<int>(row_ids.size()) = worker_id;
|
||||
WorkerTable::Get(keys);
|
||||
Log::Debug("[Get] worker = %d, #rows_set = %d\n", MV_Rank(),
|
||||
|
@ -82,6 +83,7 @@ int SparseMatrixWorkerTable<T>::Partition(const std::vector<Blob>& kv,
|
|||
}
|
||||
|
||||
// space for workder_id
|
||||
// TODO[qiwye]: to make worker_id an Option.
|
||||
for (auto &kv : count) {
|
||||
++kv.second;
|
||||
}
|
||||
|
@ -102,6 +104,7 @@ int SparseMatrixWorkerTable<T>::Partition(const std::vector<Blob>& kv,
|
|||
}
|
||||
|
||||
// append workder_id
|
||||
// TODO[qiwye]: to make worker_id an new Blob.
|
||||
for (auto& kv : *out) {
|
||||
kv.second[0].As<int>(kv.second[0].size<int>() - 1) = keys[keys_size];
|
||||
}
|
||||
|
@ -158,7 +161,7 @@ SparseMatrixServerTable<T>::SparseMatrixServerTable(int num_row, int num_col,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void SparseMatrixServerTable<T>::update_state_on_add(int worker_id,
|
||||
void SparseMatrixServerTable<T>::UpdateAddState(int worker_id,
|
||||
Blob keys_blob) {
|
||||
size_t keys_size = keys_blob.size<int>();
|
||||
int *keys = reinterpret_cast<int*>(keys_blob.data());
|
||||
|
@ -180,7 +183,7 @@ void SparseMatrixServerTable<T>::update_state_on_add(int worker_id,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void SparseMatrixServerTable<T>::update_state_on_get(int worker_id, int* keys,
|
||||
void SparseMatrixServerTable<T>::UpdateGetState(int worker_id, int* keys,
|
||||
int key_size, std::vector<int>* out_rows) {
|
||||
|
||||
if (worker_id == -1) {
|
||||
|
@ -241,7 +244,7 @@ void SparseMatrixServerTable<T>::ProcessAdd(
|
|||
// must contain option that has worker id
|
||||
CHECK(data.size() == 3);
|
||||
UpdateOption option(data[2].data(), data[2].size());
|
||||
update_state_on_add(option.worker_id(), data[0]);
|
||||
UpdateAddState(option.worker_id(), data[0]);
|
||||
MatrixServerTable<T>::ProcessAdd(data);
|
||||
}
|
||||
|
||||
|
@ -263,7 +266,7 @@ void SparseMatrixServerTable<T>::ProcessGet(
|
|||
std::vector<int> outdate_rows;
|
||||
#pragma warning( push )
|
||||
#pragma warning( disable : 4267)
|
||||
update_state_on_get(workder_id, keys, keys_size, &outdate_rows);
|
||||
UpdateGetState(workder_id, keys, keys_size, &outdate_rows);
|
||||
#pragma warning( pop )
|
||||
Blob outdate_rows_blob(sizeof(int) * outdate_rows.size());
|
||||
for (auto i = 0; i < outdate_rows.size(); ++i) {
|
||||
|
|
Загрузка…
Ссылка в новой задаче