diff --git a/Test/main.cpp b/Test/main.cpp
index 23710b3..ec1ff3c 100644
--- a/Test/main.cpp
+++ b/Test/main.cpp
@@ -88,8 +88,6 @@ void TestArray(int argc, char* argv[]) {
   int iter = 1000;
 
-  if (argc == 2) iter = atoi(argv[1]);
-
   for (int i = 0; i < iter; ++i) {
     // std::vector& vec = shared_array->raw();
diff --git a/include/multiverso/net/allreduce_engine.h b/include/multiverso/net/allreduce_engine.h
index f76bb3e..b5e5e0f 100644
--- a/include/multiverso/net/allreduce_engine.h
+++ b/include/multiverso/net/allreduce_engine.h
@@ -43,8 +43,8 @@ public:
 */
 enum RecursiveHalvingNodeType {
   Normal, //normal node, 1 group only have 1 machine
-  ReciveNeighbor, //leader of group when number of machines in this group is 2.
-  SendNeighbor// non-leader machines in group
+  GroupLeader, // leader of a group when the group has 2 machines
+  Other // non-leader machine in a group
};
 
 /*! \brief Network structure for recursive halving algorithm */
diff --git a/include/multiverso/table_interface.h b/include/multiverso/table_interface.h
index a9a9470..36e4d59 100644
--- a/include/multiverso/table_interface.h
+++ b/include/multiverso/table_interface.h
@@ -1,6 +1,7 @@
 #ifndef MULTIVERSO_TABLE_INTERFACE_H_
 #define MULTIVERSO_TABLE_INTERFACE_H_
 
+#include <mutex>
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -40,7 +41,8 @@ public:
 private:
   std::string table_name_;
   int table_id_;
-  std::unordered_map<int, Waiter*> waitings_;
+  std::mutex m_;
+  std::vector<Waiter*> waitings_;
   int msg_id_;
 };
diff --git a/include/multiverso/util/configure.h b/include/multiverso/util/configure.h
index 59780e6..a3b3702 100644
--- a/include/multiverso/util/configure.h
+++ b/include/multiverso/util/configure.h
@@ -102,7 +102,12 @@ void SetCMDFlag(const std::string& name, const T& value) {
 #define MV_DECLARE_bool(name) \
   DECLARE_CONFIGURE(bool, name)
+
+#define MV_DEFINE_double(name, default_value, text) \
+  DEFINE_CONFIGURE(double, name, default_value, text)
+#define MV_DECLARE_double(name) \
+  DECLARE_CONFIGURE(double, name)
 
 } // namespace multiverso
 
 #endif // MULTIVERSO_UTIL_CONFIGURE_H_
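The new double macros follow the existing MV_DEFINE_bool/MV_DEFINE_int pattern, where DEFINE_CONFIGURE exposes the value as MV_CONFIG_<name> (as server.cpp does with MV_CONFIG_sync below). A minimal sketch; "momentum" is a made-up flag name, not one defined anywhere in this patch:

    MV_DEFINE_double(momentum, 0.9, "momentum for SGD");
    // after flag parsing, inside namespace multiverso:
    double m = MV_CONFIG_momentum;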
diff --git a/src/Multiverso.vcxproj.filters b/src/Multiverso.vcxproj.filters
index 75165dd..035facb 100644
--- a/src/Multiverso.vcxproj.filters
+++ b/src/Multiverso.vcxproj.filters
@@ -1,12 +1,6 @@
 
-
-      include
-
-
-      include
-
 
       table
@@ -112,11 +106,14 @@
       io
 
+
+      system
+
+
+      system
+
-
-      {befa27d3-15da-409a-a02e-fc1d9e676f80}
-
       {f1d6488a-2e66-4610-8676-304c070d7b49}
diff --git a/src/net/allreduce_engine.cpp b/src/net/allreduce_engine.cpp
index 73ff6e4..aff1f2d 100644
--- a/src/net/allreduce_engine.cpp
+++ b/src/net/allreduce_engine.cpp
@@ -123,11 +123,11 @@ void AllreduceEngine::ReduceScatter(char* input, int input_size, int type_size,
   bool is_powerof_2 = (num_machines_ & (num_machines_ - 1)) == 0 ?
     true : false;
   if (!is_powerof_2) {
-    if (recursive_halving_map_.type == RecursiveHalvingNodeType::SendNeighbor) {
+    if (recursive_halving_map_.type == RecursiveHalvingNodeType::Other) {
       //send local data to neighbor first
       linkers_->Send(recursive_halving_map_.neighbor, input, 0, input_size);
     }
-    else if (recursive_halving_map_.type == RecursiveHalvingNodeType::ReciveNeighbor) {
+    else if (recursive_halving_map_.type == RecursiveHalvingNodeType::GroupLeader) {
       //recieve neighbor data first
       int need_recv_cnt = input_size;
       linkers_->Receive(recursive_halving_map_.neighbor, output, 0, need_recv_cnt);
@@ -135,7 +135,7 @@ void AllreduceEngine::ReduceScatter(char* input, int input_size, int type_size,
     }
   }
   //start recursive halfing
-  if (recursive_halving_map_.type != RecursiveHalvingNodeType::SendNeighbor) {
+  if (recursive_halving_map_.type != RecursiveHalvingNodeType::Other) {
 
     for (int i = 0; i < recursive_halving_map_.k; ++i) {
       int target = recursive_halving_map_.ranks[i];
@@ -160,11 +160,11 @@ void AllreduceEngine::ReduceScatter(char* input, int input_size, int type_size,
   int my_reduce_block_idx = rank_;
 
   if (!is_powerof_2) {
-    if (recursive_halving_map_.type == RecursiveHalvingNodeType::ReciveNeighbor) {
+    if (recursive_halving_map_.type == RecursiveHalvingNodeType::GroupLeader) {
       //send result to neighbor
       linkers_->Send(recursive_halving_map_.neighbor, input, block_start[recursive_halving_map_.neighbor], block_len[recursive_halving_map_.neighbor]);
     }
-    else if (recursive_halving_map_.type == RecursiveHalvingNodeType::SendNeighbor) {
+    else if (recursive_halving_map_.type == RecursiveHalvingNodeType::Other) {
       //receive result from neighbor
       int need_recv_cnt = block_len[my_reduce_block_idx];
       linkers_->Receive(recursive_halving_map_.neighbor, output, 0, need_recv_cnt);
diff --git a/src/net/allreduce_topo.cpp b/src/net/allreduce_topo.cpp
index 31600eb..232fe23 100644
--- a/src/net/allreduce_topo.cpp
+++ b/src/net/allreduce_topo.cpp
@@ -5,13 +5,13 @@
 
 namespace multiverso {
 
-
 BruckMap::BruckMap() {
   k = 0;
 }
 
 BruckMap::BruckMap(int n) {
   k = n;
+  // default: initialize all ranks to -1
   for (int i = 0; i < n; ++i) {
     in_ranks.push_back(-1);
     out_ranks.push_back(-1);
@@ -19,19 +21,21 @@ BruckMap::BruckMap(int n) {
 }
 
 BruckMap BruckMap::Construct(int rank, int num_machines) {
-  int* dist = new int[num_machines];
+  // distance at the k-th communication: distance[k] = 2^k
+  std::vector<int> distance;
   int k = 0;
   for (k = 0; (1 << k) < num_machines; k++) {
-    dist[k] = 1 << k;
+    distance.push_back(1 << k);
   }
   BruckMap bruckMap(k);
   for (int j = 0; j < k; ++j) {
-    int ni = (rank + dist[j]) % num_machines;
-    bruckMap.in_ranks[j] = ni;
-    ni = (rank - dist[j] + num_machines) % num_machines;
-    bruckMap.out_ranks[j] = ni;
+    // set the incoming rank at the j-th communication
+    const int in_rank = (rank + distance[j]) % num_machines;
+    bruckMap.in_ranks[j] = in_rank;
+    // set the outgoing rank at the j-th communication
+    const int out_rank = (rank - distance[j] + num_machines) % num_machines;
+    bruckMap.out_ranks[j] = out_rank;
   }
-  delete[] dist;
   return bruckMap;
 }
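A worked example of the Bruck topology above: with num_machines = 4 and rank = 0, k = 2 and distance = {1, 2}, so in_ranks = {1, 2} and out_ranks = {3, 2}; at step j a node receives from (rank + 2^j) % num_machines and sends to (rank - 2^j + num_machines) % num_machines.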
@@ -41,9 +43,10 @@ RecursiveHalvingMap::RecursiveHalvingMap() {
 }
 
 RecursiveHalvingMap::RecursiveHalvingMap(RecursiveHalvingNodeType _type, int n) {
   type = _type;
-  if (type != RecursiveHalvingNodeType::SendNeighbor) {
-    k = n;
+  k = n;
+  if (type != RecursiveHalvingNodeType::Other) {
     for (int i = 0; i < n; ++i) {
+      // default: set to -1
       ranks.push_back(-1);
       send_block_start.push_back(-1);
       send_block_len.push_back(-1);
@@ -52,107 +55,117 @@ RecursiveHalvingMap::RecursiveHalvingMap(RecursiveHalvingNodeType _type, int n)
     }
   }
 }
+
 RecursiveHalvingMap RecursiveHalvingMap::Construct(int rank, int num_machines) {
-  std::vector<RecursiveHalvingMap> rec_maps;
-  for (int i = 0; i < num_machines; ++i) {
-    rec_maps.push_back(RecursiveHalvingMap());
-  }
-  int* distance = new int[num_machines];
-  RecursiveHalvingNodeType* node_type = new RecursiveHalvingNodeType[num_machines];
+  // construct the recursive halving map for this machine (rank)
   int k = 0;
-  for (k = 0; (1 << k) < num_machines; k++) {
-    distance[k] = 1 << k;
+  while ((1 << k) <= num_machines) { ++k; }
+  // step back so that 1 << k <= num_machines
+  --k;
+  // distance at each communication step
+  std::vector<int> distance;
+  for (int i = 0; i < k; ++i) {
+    distance.push_back(1 << (k - 1 - i));
   }
+
   if ((1 << k) == num_machines) {
+    RecursiveHalvingMap rec_map(RecursiveHalvingNodeType::Normal, k);
+    // if num_machines == 2^k, there is no need to group machines
     for (int i = 0; i < k; ++i) {
-      distance[i] = 1 << (k - 1 - i);
-    }
-    for (int i = 0; i < num_machines; ++i) {
-      rec_maps[i] = RecursiveHalvingMap(RecursiveHalvingNodeType::Normal, k);
-      for (int j = 0; j < k; ++j) {
-        int dir = ((i / distance[j]) % 2 == 0) ? 1 : -1;
-        int ni = i + dir * distance[j];
-        rec_maps[i].ranks[j] = ni;
-        int t = i / distance[j];
-        rec_maps[i].recv_block_start[j] = t * distance[j];
-        rec_maps[i].recv_block_len[j] = distance[j];
-      }
+      // communication direction; positive when (rank / distance[i]) % 2 == 0
+      const int dir = ((rank / distance[i]) % 2 == 0) ? 1 : -1;
+      // neighbor at the i-th communication
+      const int next_node_idx = rank + dir * distance[i];
+      rec_map.ranks[i] = next_node_idx;
+      // receive data block at the i-th communication
+      const int recv_block_start = rank / distance[i];
+      rec_map.recv_block_start[i] = recv_block_start * distance[i];
+      rec_map.recv_block_len[i] = distance[i];
+      // send data block at the i-th communication
+      const int send_block_start = next_node_idx / distance[i];
+      rec_map.send_block_start[i] = send_block_start * distance[i];
+      rec_map.send_block_len[i] = distance[i];
     }
+    return rec_map;
   } else {
-    k--;
+    // if num_machines != 2^k, machines need to be grouped
 
     int lower_power_of_2 = 1 << k;
-    int rest = num_machines - (1 << k);
+    int rest = num_machines - lower_power_of_2;
+
+    std::vector<RecursiveHalvingNodeType> node_type;
     for (int i = 0; i < num_machines; ++i) {
-      node_type[i] = RecursiveHalvingNodeType::Normal;
+      node_type.push_back(RecursiveHalvingNodeType::Normal);
     }
+    // pair machines into groups; the last "rest" groups contain 2 machines each.
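+    // e.g. num_machines = 6: k = 2, lower_power_of_2 = 4, rest = 2, so
+    // machines (2,3) and (4,5) are paired, giving groups {0}, {1}, {2,3},
+    // {4,5} with machines 2 and 4 as the GroupLeader of their groups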
     for (int i = 0; i < rest; ++i) {
-      int r = num_machines - i * 2 - 1;
-      int l = num_machines - i * 2 - 2;
-      node_type[l] = RecursiveHalvingNodeType::ReciveNeighbor;
-      node_type[r] = RecursiveHalvingNodeType::SendNeighbor;
-    }
-    for (int i = 0; i < k; ++i) {
-      distance[i] = 1 << (k - 1 - i);
+      int right = num_machines - i * 2 - 1;
+      int left = num_machines - i * 2 - 2;
+      // make the left machine the group leader
+      node_type[left] = RecursiveHalvingNodeType::GroupLeader;
+      node_type[right] = RecursiveHalvingNodeType::Other;
     }
+    int group_cnt = 0;
+    // cache block information for each group; a group with 2 machines covers a double-sized block
+    std::vector<int> group_block_start(lower_power_of_2);
+    std::vector<int> group_block_len(lower_power_of_2, 0);
+    // map from group index to its leader node
+    std::vector<int> group_to_node(lower_power_of_2);
+    // map from node to its group index
+    std::vector<int> node_to_group(num_machines);
-    int group_idx = 0;
-    int* map_len = new int[lower_power_of_2];
-    int* map_start = new int[lower_power_of_2];
-    int* group_2_node = new int[lower_power_of_2];
-    int* node_to_group = new int[num_machines];
-    for (int i = 0; i < lower_power_of_2; ++i) {
-      map_len[i] = 0;
-    }
     for (int i = 0; i < num_machines; ++i) {
-      if (node_type[i] == RecursiveHalvingNodeType::Normal || node_type[i] == RecursiveHalvingNodeType::ReciveNeighbor) {
-        group_2_node[group_idx++] = i;
-
+      // start of a new group
+      if (node_type[i] == RecursiveHalvingNodeType::Normal || node_type[i] == RecursiveHalvingNodeType::GroupLeader) {
+        group_to_node[group_cnt++] = i;
       }
-      map_len[group_idx - 1]++;
-      node_to_group[i] = group_idx - 1;
+      node_to_group[i] = group_cnt - 1;
+      // add to this group's block length
+      group_block_len[group_cnt - 1]++;
     }
-    map_start[0] = 0;
+    // calculate each group's block start
+    group_block_start[0] = 0;
     for (int i = 1; i < lower_power_of_2; ++i) {
-      map_start[i] = map_start[i - 1] + map_len[i - 1];
+      group_block_start[i] = group_block_start[i - 1] + group_block_len[i - 1];
     }
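+    // continuing the num_machines = 6 example: group_to_node = {0, 1, 2, 4},
+    // node_to_group = {0, 1, 2, 2, 3, 3}, group_block_len = {1, 1, 2, 2},
+    // and group_block_start = {0, 1, 2, 4}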
-    for (int i = 0; i < num_machines; ++i) {
-
-      if (node_type[i] == RecursiveHalvingNodeType::SendNeighbor) {
-        rec_maps[i] = RecursiveHalvingMap(RecursiveHalvingNodeType::SendNeighbor, k);
-        rec_maps[i].neighbor = i - 1;
-        continue;
-      }
-      rec_maps[i] = RecursiveHalvingMap(node_type[i], k);
-      if (node_type[i] == RecursiveHalvingNodeType::ReciveNeighbor) {
-        rec_maps[i].neighbor = i + 1;
-      }
-      for (int j = 0; j < k; ++j) {
-        int dir = ((node_to_group[i] / distance[j]) % 2 == 0) ? 1 : -1;
-        group_idx = group_2_node[(node_to_group[i] + dir * distance[j])];
-        rec_maps[i].ranks[j] = group_idx;
-        int t = node_to_group[i] / distance[j];
-        rec_maps[i].recv_block_start[j] = map_start[t * distance[j]];
-        int tl = 0;
-        for (int tmp_i = 0; tmp_i < distance[j]; ++tmp_i) {
-          tl += map_len[t * distance[j] + tmp_i];
-        }
-        rec_maps[i].recv_block_len[j] = tl;
-      }
+    RecursiveHalvingMap rec_map(node_type[rank], k);
+    if (node_type[rank] == RecursiveHalvingNodeType::Other) {
+      rec_map.neighbor = rank - 1;
+      // no need to construct the rest of the map
+      return rec_map;
     }
-  }
-  for (int i = 0; i < num_machines; ++i) {
-    if (rec_maps[i].type != RecursiveHalvingNodeType::SendNeighbor) {
-      for (int j = 0; j < k; ++j) {
-        int target = rec_maps[i].ranks[j];
-        rec_maps[i].send_block_start[j] = rec_maps[target].recv_block_start[j];
-        rec_maps[i].send_block_len[j] = rec_maps[target].recv_block_len[j];
-      }
+    if (node_type[rank] == RecursiveHalvingNodeType::GroupLeader) {
+      rec_map.neighbor = rank + 1;
     }
+    const int cur_group_idx = node_to_group[rank];
+    for (int i = 0; i < k; ++i) {
+      const int dir = ((cur_group_idx / distance[i]) % 2 == 0) ? 1 : -1;
+      const int next_node_idx = group_to_node[(cur_group_idx + dir * distance[i])];
+      rec_map.ranks[i] = next_node_idx;
+      // receive block information at the i-th communication
+      const int recv_block_start = cur_group_idx / distance[i];
+      rec_map.recv_block_start[i] = group_block_start[recv_block_start * distance[i]];
+      int recv_block_len = 0;
+      // accumulate block length
+      for (int j = 0; j < distance[i]; ++j) {
+        recv_block_len += group_block_len[recv_block_start * distance[i] + j];
+      }
+      rec_map.recv_block_len[i] = recv_block_len;
+      // send block information at the i-th communication
+      const int send_block_start = (cur_group_idx + dir * distance[i]) / distance[i];
+      rec_map.send_block_start[i] = group_block_start[send_block_start * distance[i]];
+      int send_block_len = 0;
+      // accumulate block length
+      for (int j = 0; j < distance[i]; ++j) {
+        send_block_len += group_block_len[send_block_start * distance[i] + j];
+      }
+      rec_map.send_block_len[i] = send_block_len;
+    }
+    return rec_map;
   }
-  return rec_maps[rank];
 }
 
 }
diff --git a/src/server.cpp b/src/server.cpp
index a6fe926..f123613 100644
--- a/src/server.cpp
+++ b/src/server.cpp
@@ -18,6 +18,7 @@ namespace multiverso {
 
 MV_DEFINE_bool(sync, false, "sync or async");
+MV_DEFINE_int(backup_worker_ratio, 0, "percentage of backup workers, e.g. 20 means 20%");
 
 Server::Server() : Actor(actor::kServer) {
   RegisterHandler(MsgType::Request_Get, std::bind(
@@ -60,25 +61,149 @@ void Server::ProcessAdd(MessagePtr& msg) {
 // If worker k has add delta to server j times when its i-th Get
 // then the server will return the parameter after all K
 // workers finished their j-th update
-class SyncServer : public Server {
+
+// TODO(feiga): delete this; SyncServer is a special case of
+// WithBackupSyncServer
+
+//class SyncServer : public Server {
+//public:
+//  SyncServer() : Server() {
+//    int num_worker = Zoo::Get()->num_workers();
+//    worker_get_clocks_.reset(new VectorClock(num_worker));
+//    worker_add_clocks_.reset(new VectorClock(num_worker));
+//  }
+//
+//  // make some modification to suit to the sync server
+//  // please not use in other place, may different with the general vector clock
+//  class VectorClock {
+//   public:
+//    explicit VectorClock(int n) :
+//      local_clock_(n, 0), global_clock_(0), size_(0) {}
+//
+//    // Return true when all clock reach a same number
+//    virtual bool Update(int i) {
+//      ++local_clock_[i];
+//      if (global_clock_ < *(std::min_element(std::begin(local_clock_),
+//          std::end(local_clock_)))) {
+//        ++global_clock_;
+//        if (global_clock_ == *(std::max_element(std::begin(local_clock_),
+//            std::end(local_clock_)))) {
+//          return true;
+//        }
+//      }
+//      return false;
+//    }
+//
+//    std::string DebugString() {
+//      std::string os = "global ";
+//      os += std::to_string(global_clock_) + " local: ";
+//      for (auto i : local_clock_) os += std::to_string(i) + " ";
+//      return os;
+//    }
+//
+//    int local_clock(int i) const { return local_clock_[i]; }
+//    int global_clock() const { return global_clock_; }
+//
+//   protected:
+//    std::vector<int> local_clock_;
+//    int global_clock_;
+//    int size_;
+//  };
+//protected:
+//  void ProcessAdd(MessagePtr& msg) override {
+//    // 1. Before add: cache faster worker
+//    int worker = Zoo::Get()->rank_to_worker_id(msg->src());
+//    if (worker_get_clocks_->local_clock(worker) >
+//        worker_get_clocks_->global_clock()) {
+//      msg_add_cache_.Push(msg);
+//      return;
+//    }
+//    // 2. Process Add
+//    Server::ProcessAdd(msg);
+//    // 3. After add: process cached process get if necessary
+//    if (worker_add_clocks_->Update(worker)) {
+//      CHECK(msg_add_cache_.Empty());
+//      while (!msg_get_cache_.Empty()) {
+//        MessagePtr get_msg;
+//        CHECK(msg_get_cache_.TryPop(get_msg));
+//        int get_worker = Zoo::Get()->rank_to_worker_id(get_msg->src());
+//        Server::ProcessGet(get_msg);
+//        worker_get_clocks_->Update(get_worker);
+//      }
+//    }
+//  }
+//
+//  void ProcessGet(MessagePtr& msg) override {
+//    // 1. Before get: cache faster worker
+//    int worker = Zoo::Get()->rank_to_worker_id(msg->src());
+//    if (worker_add_clocks_->local_clock(worker) >
+//        worker_add_clocks_->global_clock()) {
+//      // Will wait for other worker finished Add
+//      msg_get_cache_.Push(msg);
+//      return;
+//    }
+//    // 2. Process Get
+//    Server::ProcessGet(msg);
+//    // 3. After get: process cached process add if necessary
+//    if (worker_get_clocks_->Update(worker)) {
+//      CHECK(msg_get_cache_.Empty());
+//      while (!msg_add_cache_.Empty()) {
+//        MessagePtr add_msg;
+//        CHECK(msg_add_cache_.TryPop(add_msg));
+//        int add_worker = Zoo::Get()->rank_to_worker_id(add_msg->src());
+//        Server::ProcessAdd(add_msg);
+//        worker_add_clocks_->Update(add_worker);
+//      }
+//    }
+//  }
+//
+//private:
+//  std::unique_ptr<VectorClock> worker_get_clocks_;
+//  std::unique_ptr<VectorClock> worker_add_clocks_;
+//
+//  MtQueue<MessagePtr> msg_add_cache_;
+//  MtQueue<MessagePtr> msg_get_cache_;
+//};
+
+class WithBackupSyncServer : public Server {
 public:
-  SyncServer() : Server() {
-    int num_worker = Zoo::Get()->num_workers();
-    worker_get_clocks_.reset(new VectorClock(num_worker));
-    worker_add_clocks_.reset(new VectorClock(num_worker));
+  WithBackupSyncServer() : Server() {
+    num_worker_ = Zoo::Get()->num_workers();
+    double backup_ratio = (double)MV_CONFIG_backup_worker_ratio / 100;
+    num_sync_worker_ = num_worker_ -
+      static_cast<int>(backup_ratio * num_worker_);
+    CHECK(num_sync_worker_ > 0 && num_sync_worker_ <= num_worker_);
+    if (num_sync_worker_ == num_worker_) {
+      Log::Info("No backup worker, using the sync mode\n");
+    }
+    Log::Info("Sync with backup worker start: num_sync_worker = %d, "
+              "num_total_worker = %d\n", num_sync_worker_, num_worker_);
+    worker_get_clocks_.reset(new VectorClock(num_worker_));
+    worker_add_clocks_.reset(new VectorClock(
+      num_worker_, num_worker_ - num_sync_worker_));
   }
 
   // make some modification to suit to the sync server
   // please not use in other place, may different with the general vector clock
   class VectorClock {
    public:
-    explicit VectorClock(int n) : local_clock_(n, 0), global_clock_(0) {}
+    VectorClock(int num_worker, int num_backup_worker = 0) :
+      local_clock_(num_worker, 0), global_clock_(0), num_worker_(num_worker),
+      num_sync_worker_(num_worker - num_backup_worker), progress_(0) {}
 
-    bool Update(int i) {
-      ++local_clock_[i];
-      if (global_clock_ < *(std::min_element(std::begin(local_clock_),
-          std::end(local_clock_)))) {
+    // Return true when the global clock meets the sync condition
+    // sync: all workers reach the same clock
+    // backup-worker-sync: all sync workers reach the same clock
+    virtual bool Update(int i) {
+      if (local_clock_[i]++ == global_clock_) {
+        ++progress_;
+      }
+      if (progress_ >= num_sync_worker_) {
         ++global_clock_;
+        progress_ = 0;
+        for (auto c : local_clock_) {
+          if (c > global_clock_) ++progress_;
+        }
         if (global_clock_ == *(std::max_element(std::begin(local_clock_),
             std::end(local_clock_)))) {
           return true;
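A worked example of the clock above: with num_worker_ = 8 and backup_worker_ratio = 25, the constructor yields num_sync_worker_ = 8 - (int)(0.25 * 8) = 6, so worker_add_clocks_ is built with 2 backup workers and Update() advances the global clock as soon as the 6 fastest workers have passed it; the 2 slowest workers never hold a Get back.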
@@ -97,21 +222,27 @@ public:
     int local_clock(int i) const { return local_clock_[i]; }
     int global_clock() const { return global_clock_; }
 
-   private:
+   protected:
     std::vector<int> local_clock_;
     int global_clock_;
+    int num_worker_;
+    int num_sync_worker_;
+    int progress_;
   };
 
 protected:
   void ProcessAdd(MessagePtr& msg) override {
     // 1. Before add: cache faster worker
     int worker = Zoo::Get()->rank_to_worker_id(msg->src());
     if (worker_get_clocks_->local_clock(worker) >
-      worker_get_clocks_->global_clock()) {
+        worker_get_clocks_->global_clock()) {
       msg_add_cache_.Push(msg);
       return;
     }
     // 2. Process Add
-    Server::ProcessAdd(msg);
+    if (worker_add_clocks_->local_clock(worker) >=
+        worker_add_clocks_->global_clock()) {
+      Server::ProcessAdd(msg);
+    }
     // 3. After add: process cached process get if necessary
     if (worker_add_clocks_->Update(worker)) {
       CHECK(msg_add_cache_.Empty());
@@ -129,7 +260,7 @@ protected:
     // 1. Before get: cache faster worker
     int worker = Zoo::Get()->rank_to_worker_id(msg->src());
     if (worker_add_clocks_->local_clock(worker) >
-      worker_add_clocks_->global_clock()) {
+        worker_add_clocks_->global_clock()) {
       // Will wait for other worker finished Add
       msg_get_cache_.Push(msg);
       return;
@@ -143,7 +274,10 @@ protected:
         MessagePtr add_msg;
         CHECK(msg_add_cache_.TryPop(add_msg));
         int add_worker = Zoo::Get()->rank_to_worker_id(add_msg->src());
-        Server::ProcessAdd(add_msg);
+        if (worker_add_clocks_->local_clock(add_worker) >=
+            worker_add_clocks_->global_clock()) {
+          Server::ProcessAdd(add_msg);
+        }
         worker_add_clocks_->Update(add_worker);
       }
     }
@@ -155,11 +289,21 @@ private:
 
   MtQueue<MessagePtr> msg_add_cache_;
   MtQueue<MessagePtr> msg_get_cache_;
+
+  // num_worker_ - num_sync_worker_ = num_backup_worker_
+  int num_sync_worker_;
+  int num_worker_;
 };
 
 Server* Server::GetServer() {
-  if (MV_CONFIG_sync) return new SyncServer();
-  return new Server();
+  if (!MV_CONFIG_sync) {
+    Log::Info("Create an async server\n");
+    return new Server();
+  }
+  // if (MV_CONFIG_backup_worker_ratio > 0.0) {
+  Log::Info("Create a sync server\n");
+  return new WithBackupSyncServer();
+  // }
 }
 
 } // namespace multiverso
diff --git a/src/table.cpp b/src/table.cpp
index 1301213..b85ba1d 100644
--- a/src/table.cpp
+++ b/src/table.cpp
@@ -1,13 +1,17 @@
 #include "multiverso/table_interface.h"
+
+#include <mutex>
+
+#include "multiverso/dashboard.h"
+#include "multiverso/updater/updater.h"
 #include "multiverso/util/log.h"
 #include "multiverso/util/waiter.h"
 #include "multiverso/zoo.h"
-#include "multiverso/dashboard.h"
-#include "multiverso/updater/updater.h"
 
 namespace multiverso {
 
 WorkerTable::WorkerTable() {
+  msg_id_ = 0;
   table_id_ = Zoo::Get()->RegisterTable(this);
 }
 
@@ -30,8 +34,10 @@ void WorkerTable::Add(Blob keys, Blob values,
 }
 
 int WorkerTable::GetAsync(Blob keys) {
+  m_.lock();
   int id = msg_id_++;
-  waitings_[id] = new Waiter();
+  waitings_.push_back(new Waiter());
+  m_.unlock();
   MessagePtr msg(new Message());
   msg->set_src(Zoo::Get()->rank());
   msg->set_type(MsgType::Request_Get);
@@ -44,8 +50,10 @@ int WorkerTable::GetAsync(Blob keys) {
 
 int WorkerTable::AddAsync(Blob keys, Blob values,
                           const UpdateOption* option) {
+  m_.lock();
   int id = msg_id_++;
-  waitings_[id] = new Waiter();
+  waitings_.push_back(new Waiter());
+  m_.unlock();
   MessagePtr msg(new Message());
   msg->set_src(Zoo::Get()->rank());
   msg->set_type(MsgType::Request_Add);
@@ -63,21 +71,32 @@ int WorkerTable::AddAsync(Blob keys, Blob values,
 }
 
 void WorkerTable::Wait(int id) {
-  CHECK(waitings_.find(id) != waitings_.end());
+  // CHECK(waitings_.find(id) != waitings_.end());
+  m_.lock();
   CHECK(waitings_[id] != nullptr);
-  waitings_[id]->Wait();
+  Waiter* w = waitings_[id];
+  m_.unlock();
+
+  w->Wait();
+
+  m_.lock();
   delete waitings_[id];
   waitings_[id] = nullptr;
+  m_.unlock();
 }
 
 void WorkerTable::Reset(int msg_id, int num_wait) {
+  m_.lock();
   CHECK_NOTNULL(waitings_[msg_id]);
   waitings_[msg_id]->Reset(num_wait);
+  m_.unlock();
 }
 
 void WorkerTable::Notify(int id) {
+  m_.lock();
   CHECK_NOTNULL(waitings_[id]);
   waitings_[id]->Notify();
+  m_.unlock();
 }
 
 WorkerTable* TableHelper::CreateTable() {
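Usage sketch for the new flags (illustrative only; the SetCMDFlag template appears in the configure.h hunk above, while MV_Init and MV_ShutDown are assumed here to be Multiverso's usual init and shutdown entry points and are not part of this patch):

    #include "multiverso/multiverso.h"
    #include "multiverso/util/configure.h"

    int main(int argc, char* argv[]) {
      // Request the sync server with 20% backup workers; equivalently,
      // pass the flags on the command line for MV_Init to parse.
      multiverso::SetCMDFlag("sync", true);
      multiverso::SetCMDFlag("backup_worker_ratio", 20);
      multiverso::MV_Init(&argc, argv);
      // ... create worker tables, Get/Add in the training loop ...
      multiverso::MV_ShutDown();
      return 0;
    }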