This commit is contained in:
Qiwei Ye 2016-04-21 13:55:29 +08:00
Parent 423d100daa 0c1f239a75
Commit b856ef8d3a
9 changed files with 307 additions and 129 deletions

View file

@ -88,8 +88,6 @@ void TestArray(int argc, char* argv[]) {
int iter = 1000;
if (argc == 2) iter = atoi(argv[1]);
for (int i = 0; i < iter; ++i) {
// std::vector<float>& vec = shared_array->raw();

View file

@ -43,8 +43,8 @@ public:
*/
enum RecursiveHalvingNodeType {
Normal, //normal node, 1 group only have 1 machine
ReciveNeighbor, //leader of group when number of machines in this group is 2.
SendNeighbor// non-leader machines in group
GroupLeader, //leader of group when number of machines in this group is 2.
Other// non-leader machines in group
};
/*! \brief Network structure for recursive halving algorithm */

View file

@ -1,6 +1,7 @@
#ifndef MULTIVERSO_TABLE_INTERFACE_H_
#define MULTIVERSO_TABLE_INTERFACE_H_
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
@ -40,7 +41,8 @@ public:
private:
std::string table_name_;
int table_id_;
std::unordered_map<int, Waiter*> waitings_;
std::mutex m_;
std::vector<Waiter*> waitings_;
int msg_id_;
};

View file

@ -102,7 +102,12 @@ void SetCMDFlag(const std::string& name, const T& value) {
#define MV_DECLARE_bool(name) \
DECLARE_CONFIGURE(bool, name)
#define MV_DEFINE_double(name, default_value, text) \
DEFINE_CONFIGURE(double, name, default_value, text)
#define MV_DECLARE_double(name) \
DECLARE_CONFIGURE(double, name)
} // namespace multiverso
#endif // MULTIVERSO_UTIL_CONFIGURE_H_

View file

@ -1,12 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup>
<ClInclude Include="..\include\multiverso\multiverso.h">
<Filter>include</Filter>
</ClInclude>
<ClInclude Include="..\include\multiverso\table_interface.h">
<Filter>include</Filter>
</ClInclude>
<ClInclude Include="..\include\multiverso\table\array_table.h">
<Filter>table</Filter>
</ClInclude>
@ -112,11 +106,14 @@
<ClInclude Include="..\include\multiverso\io\local_stream.h">
<Filter>io</Filter>
</ClInclude>
<ClInclude Include="..\include\multiverso\multiverso.h">
<Filter>system</Filter>
</ClInclude>
<ClInclude Include="..\include\multiverso\table_interface.h">
<Filter>system</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<Filter Include="include">
<UniqueIdentifier>{befa27d3-15da-409a-a02e-fc1d9e676f80}</UniqueIdentifier>
</Filter>
<Filter Include="system">
<UniqueIdentifier>{f1d6488a-2e66-4610-8676-304c070d7b49}</UniqueIdentifier>
</Filter>

View file

@ -123,11 +123,11 @@ void AllreduceEngine::ReduceScatter(char* input, int input_size, int type_size,
bool is_powerof_2 = (num_machines_ & (num_machines_ - 1)) == 0 ? true : false;
if (!is_powerof_2) {
if (recursive_halving_map_.type == RecursiveHalvingNodeType::SendNeighbor) {
if (recursive_halving_map_.type == RecursiveHalvingNodeType::Other) {
//send local data to neighbor first
linkers_->Send(recursive_halving_map_.neighbor, input, 0, input_size);
}
else if (recursive_halving_map_.type == RecursiveHalvingNodeType::ReciveNeighbor) {
else if (recursive_halving_map_.type == RecursiveHalvingNodeType::GroupLeader) {
//receive neighbor data first
int need_recv_cnt = input_size;
linkers_->Receive(recursive_halving_map_.neighbor, output, 0, need_recv_cnt);
@ -135,7 +135,7 @@ void AllreduceEngine::ReduceScatter(char* input, int input_size, int type_size,
}
}
//start recursive halving
if (recursive_halving_map_.type != RecursiveHalvingNodeType::SendNeighbor) {
if (recursive_halving_map_.type != RecursiveHalvingNodeType::Other) {
for (int i = 0; i < recursive_halving_map_.k; ++i) {
int target = recursive_halving_map_.ranks[i];
@ -160,11 +160,11 @@ void AllreduceEngine::ReduceScatter(char* input, int input_size, int type_size,
int my_reduce_block_idx = rank_;
if (!is_powerof_2) {
if (recursive_halving_map_.type == RecursiveHalvingNodeType::ReciveNeighbor) {
if (recursive_halving_map_.type == RecursiveHalvingNodeType::GroupLeader) {
//send result to neighbor
linkers_->Send(recursive_halving_map_.neighbor, input, block_start[recursive_halving_map_.neighbor], block_len[recursive_halving_map_.neighbor]);
}
else if (recursive_halving_map_.type == RecursiveHalvingNodeType::SendNeighbor) {
else if (recursive_halving_map_.type == RecursiveHalvingNodeType::Other) {
//receive result from neighbor
int need_recv_cnt = block_len[my_reduce_block_idx];
linkers_->Receive(recursive_halving_map_.neighbor, output, 0, need_recv_cnt);

View file

@ -5,13 +5,13 @@
namespace multiverso {
/*! \brief Create an empty Bruck map: zero communication rounds. */
BruckMap::BruckMap() : k(0) {}
BruckMap::BruckMap(int n) {
k = n;
// default set to -1
for (int i = 0; i < n; ++i) {
in_ranks.push_back(-1);
out_ranks.push_back(-1);
@ -19,19 +19,21 @@ BruckMap::BruckMap(int n) {
}
BruckMap BruckMap::Construct(int rank, int num_machines) {
int* dist = new int[num_machines];
// distance at k-th communication, distance[k] = 2^k
std::vector<int> distance;
int k = 0;
for (k = 0; (1 << k) < num_machines; k++) {
dist[k] = 1 << k;
distance.push_back(1 << k);
}
BruckMap bruckMap(k);
for (int j = 0; j < k; ++j) {
int ni = (rank + dist[j]) % num_machines;
bruckMap.in_ranks[j] = ni;
ni = (rank - dist[j] + num_machines) % num_machines;
bruckMap.out_ranks[j] = ni;
// set incoming rank at k-th communication
const int in_rank = (rank + distance[j]) % num_machines;
bruckMap.in_ranks[j] = in_rank;
// set outgoing rank at k-th communication
const int out_rank = (rank - distance[j] + num_machines) % num_machines;
bruckMap.out_ranks[j] = out_rank;
}
delete[] dist;
return bruckMap;
}
@ -41,9 +43,10 @@ RecursiveHalvingMap::RecursiveHalvingMap() {
}
RecursiveHalvingMap::RecursiveHalvingMap(RecursiveHalvingNodeType _type, int n) {
type = _type;
if (type != RecursiveHalvingNodeType::SendNeighbor) {
k = n;
k = n;
if (type != RecursiveHalvingNodeType::Other) {
for (int i = 0; i < n; ++i) {
// default set to -1
ranks.push_back(-1);
send_block_start.push_back(-1);
send_block_len.push_back(-1);
@ -52,107 +55,117 @@ RecursiveHalvingMap::RecursiveHalvingMap(RecursiveHalvingNodeType _type, int n)
}
}
}
RecursiveHalvingMap RecursiveHalvingMap::Construct(int rank, int num_machines) {
std::vector<RecursiveHalvingMap> rec_maps;
for (int i = 0; i < num_machines; ++i) {
rec_maps.push_back(RecursiveHalvingMap());
}
int* distance = new int[num_machines];
RecursiveHalvingNodeType* node_type = new RecursiveHalvingNodeType[num_machines];
// construct all recursive halving map for all machines
int k = 0;
for (k = 0; (1 << k) < num_machines; k++) {
distance[k] = 1 << k;
while ((1 << k) <= num_machines) { ++k; }
// let 1 << k <= num_machines
--k;
// distance of each communication
std::vector<int> distance;
for (int i = 0; i < k; ++i) {
distance.push_back(1 << (k - 1 - i));
}
if ((1 << k) == num_machines) {
RecursiveHalvingMap rec_map(RecursiveHalvingNodeType::Normal, k);
// if num_machines = 2^k, don't need to group machines
for (int i = 0; i < k; ++i) {
distance[i] = 1 << (k - 1 - i);
}
for (int i = 0; i < num_machines; ++i) {
rec_maps[i] = RecursiveHalvingMap(RecursiveHalvingNodeType::Normal, k);
for (int j = 0; j < k; ++j) {
int dir = ((i / distance[j]) % 2 == 0) ? 1 : -1;
int ni = i + dir * distance[j];
rec_maps[i].ranks[j] = ni;
int t = i / distance[j];
rec_maps[i].recv_block_start[j] = t * distance[j];
rec_maps[i].recv_block_len[j] = distance[j];
}
// communication direction, %2 == 0 is positive
const int dir = ((rank / distance[i]) % 2 == 0) ? 1 : -1;
// neighbor at k-th communication
const int next_node_idx = rank + dir * distance[i];
rec_map.ranks[i] = next_node_idx;
// receive data block at k-th communication
const int recv_block_start = rank / distance[i];
rec_map.recv_block_start[i] = recv_block_start * distance[i];
rec_map.recv_block_len[i] = distance[i];
// send data block at k-th communication
const int send_block_start = next_node_idx / distance[i];
rec_map.send_block_start[i] = send_block_start * distance[i];
rec_map.send_block_len[i] = distance[i];
}
return rec_map;
}
else {
k--;
// if num_machines != 2^k, need to group machines
int lower_power_of_2 = 1 << k;
int rest = num_machines - (1 << k);
int rest = num_machines - lower_power_of_2;
std::vector<RecursiveHalvingNodeType> node_type;
for (int i = 0; i < num_machines; ++i) {
node_type[i] = RecursiveHalvingNodeType::Normal;
node_type.push_back(RecursiveHalvingNodeType::Normal);
}
// group machines in pairs; in total, "rest" groups will have 2 machines.
for (int i = 0; i < rest; ++i) {
int r = num_machines - i * 2 - 1;
int l = num_machines - i * 2 - 2;
node_type[l] = RecursiveHalvingNodeType::ReciveNeighbor;
node_type[r] = RecursiveHalvingNodeType::SendNeighbor;
}
for (int i = 0; i < k; ++i) {
distance[i] = 1 << (k - 1 - i);
int right = num_machines - i * 2 - 1;
int left = num_machines - i * 2 - 2;
// let left machine as group leader
node_type[left] = RecursiveHalvingNodeType::GroupLeader;
node_type[right] = RecursiveHalvingNodeType::Other;
}
int group_cnt = 0;
// cache block information for groups, group with 2 machines will have double block size
std::vector<int> group_block_start(lower_power_of_2);
std::vector<int> group_block_len(lower_power_of_2, 0);
// convert from group to node leader
std::vector<int> group_to_node(lower_power_of_2);
// convert from node to group
std::vector<int> node_to_group(num_machines);
int group_idx = 0;
int* map_len = new int[lower_power_of_2];
int* map_start = new int[lower_power_of_2];
int* group_2_node = new int[lower_power_of_2];
int* node_to_group = new int[num_machines];
for (int i = 0; i < lower_power_of_2; ++i) {
map_len[i] = 0;
}
for (int i = 0; i < num_machines; ++i) {
if (node_type[i] == RecursiveHalvingNodeType::Normal || node_type[i] == RecursiveHalvingNodeType::ReciveNeighbor) {
group_2_node[group_idx++] = i;
// meet new group
if (node_type[i] == RecursiveHalvingNodeType::Normal || node_type[i] == RecursiveHalvingNodeType::GroupLeader) {
group_to_node[group_cnt++] = i;
}
map_len[group_idx - 1]++;
node_to_group[i] = group_idx - 1;
node_to_group[i] = group_cnt - 1;
// add block len for this group
group_block_len[group_cnt - 1]++;
}
map_start[0] = 0;
// calculate the group block start
group_block_start[0] = 0;
for (int i = 1; i < lower_power_of_2; ++i) {
map_start[i] = map_start[i - 1] + map_len[i - 1];
group_block_start[i] = group_block_start[i - 1] + group_block_len[i - 1];
}
for (int i = 0; i < num_machines; ++i) {
if (node_type[i] == RecursiveHalvingNodeType::SendNeighbor) {
rec_maps[i] = RecursiveHalvingMap(RecursiveHalvingNodeType::SendNeighbor, k);
rec_maps[i].neighbor = i - 1;
continue;
}
rec_maps[i] = RecursiveHalvingMap(node_type[i], k);
if (node_type[i] == RecursiveHalvingNodeType::ReciveNeighbor) {
rec_maps[i].neighbor = i + 1;
}
for (int j = 0; j < k; ++j) {
int dir = ((node_to_group[i] / distance[j]) % 2 == 0) ? 1 : -1;
group_idx = group_2_node[(node_to_group[i] + dir * distance[j])];
rec_maps[i].ranks[j] = group_idx;
int t = node_to_group[i] / distance[j];
rec_maps[i].recv_block_start[j] = map_start[t * distance[j]];
int tl = 0;
for (int tmp_i = 0; tmp_i < distance[j]; ++tmp_i) {
tl += map_len[t * distance[j] + tmp_i];
}
rec_maps[i].recv_block_len[j] = tl;
}
RecursiveHalvingMap rec_map(node_type[rank], k);
if (node_type[rank] == RecursiveHalvingNodeType::Other) {
rec_map.neighbor = rank - 1;
// no need to construct
return rec_map;
}
}
for (int i = 0; i < num_machines; ++i) {
if (rec_maps[i].type != RecursiveHalvingNodeType::SendNeighbor) {
for (int j = 0; j < k; ++j) {
int target = rec_maps[i].ranks[j];
rec_maps[i].send_block_start[j] = rec_maps[target].recv_block_start[j];
rec_maps[i].send_block_len[j] = rec_maps[target].recv_block_len[j];
}
if (node_type[rank] == RecursiveHalvingNodeType::GroupLeader) {
rec_map.neighbor = rank + 1;
}
const int cur_group_idx = node_to_group[rank];
for (int i = 0; i < k; ++i) {
const int dir = ((cur_group_idx / distance[i]) % 2 == 0) ? 1 : -1;
const int next_node_idx = group_to_node[(cur_group_idx + dir * distance[i])];
rec_map.ranks[i] = next_node_idx;
// get receive block information
const int recv_block_start = cur_group_idx / distance[i];
rec_map.recv_block_start[i] = group_block_start[recv_block_start * distance[i]];
int recv_block_len = 0;
// accumulate block len
for (int j = 0; j < distance[i]; ++j) {
recv_block_len += group_block_len[recv_block_start * distance[i] + j];
}
rec_map.recv_block_len[i] = recv_block_len;
// get send block information
const int send_block_start = (cur_group_idx + dir * distance[i]) / distance[i];
rec_map.send_block_start[i] = group_block_start[send_block_start * distance[i]];
int send_block_len = 0;
// accumulate block len
for (int j = 0; j < distance[i]; ++j) {
send_block_len += group_block_len[send_block_start * distance[i] + j];
}
rec_map.send_block_len[i] = send_block_len;
}
return rec_map;
}
return rec_maps[rank];
}
}

View file

@ -18,6 +18,7 @@
namespace multiverso {
MV_DEFINE_bool(sync, false, "sync or async");
MV_DEFINE_int(backup_worker_ratio, 0, "ratio% of backup workers, set 20 means 20%");
Server::Server() : Actor(actor::kServer) {
RegisterHandler(MsgType::Request_Get, std::bind(
@ -60,25 +61,149 @@ void Server::ProcessAdd(MessagePtr& msg) {
// If worker k has add delta to server j times when its i-th Get
// then the server will return the parameter after all K
// workers finished their j-th update
class SyncServer : public Server {
// TODO(feiga): to delete this, SyncServer is a special case for
// BackupWorkerSyncServer
//class SyncServer : public Server {
//public:
// SyncServer() : Server() {
// int num_worker = Zoo::Get()->num_workers();
// worker_get_clocks_.reset(new VectorClock(num_worker));
// worker_add_clocks_.reset(new VectorClock(num_worker));
// }
//
// // make some modification to suit to the sync server
// // please not use in other place, may different with the general vector clock
// class VectorClock {
// public:
// explicit VectorClock(int n) :
// local_clock_(n, 0), global_clock_(0), size_(0) {}
//
// // Return true when all clock reach a same number
// virtual bool Update(int i) {
// ++local_clock_[i];
// if (global_clock_ < *(std::min_element(std::begin(local_clock_),
// std::end(local_clock_)))) {
// ++global_clock_;
// if (global_clock_ == *(std::max_element(std::begin(local_clock_),
// std::end(local_clock_)))) {
// return true;
// }
// }
// return false;
// }
//
// std::string DebugString() {
// std::string os = "global ";
// os += std::to_string(global_clock_) + " local: ";
// for (auto i : local_clock_) os += std::to_string(i) + " ";
// return os;
// }
//
// int local_clock(int i) const { return local_clock_[i]; }
// int global_clock() const { return global_clock_; }
//
// protected:
// std::vector<int> local_clock_;
// int global_clock_;
// int size_;
// };
//protected:
// void ProcessAdd(MessagePtr& msg) override {
// // 1. Before add: cache faster worker
// int worker = Zoo::Get()->rank_to_worker_id(msg->src());
// if (worker_get_clocks_->local_clock(worker) >
// worker_get_clocks_->global_clock()) {
// msg_add_cache_.Push(msg);
// return;
// }
// // 2. Process Add
// Server::ProcessAdd(msg);
// // 3. After add: process cached process get if necessary
// if (worker_add_clocks_->Update(worker)) {
// CHECK(msg_add_cache_.Empty());
// while (!msg_get_cache_.Empty()) {
// MessagePtr get_msg;
// CHECK(msg_get_cache_.TryPop(get_msg));
// int get_worker = Zoo::Get()->rank_to_worker_id(get_msg->src());
// Server::ProcessGet(get_msg);
// worker_get_clocks_->Update(get_worker);
// }
// }
// }
//
// void ProcessGet(MessagePtr& msg) override {
// // 1. Before get: cache faster worker
// int worker = Zoo::Get()->rank_to_worker_id(msg->src());
// if (worker_add_clocks_->local_clock(worker) >
// worker_add_clocks_->global_clock()) {
// // Will wait for other worker finished Add
// msg_get_cache_.Push(msg);
// return;
// }
// // 2. Process Get
// Server::ProcessGet(msg);
// // 3. After get: process cached process add if necessary
// if (worker_get_clocks_->Update(worker)) {
// CHECK(msg_get_cache_.Empty());
// while (!msg_add_cache_.Empty()) {
// MessagePtr add_msg;
// CHECK(msg_add_cache_.TryPop(add_msg));
// int add_worker = Zoo::Get()->rank_to_worker_id(add_msg->src());
// Server::ProcessAdd(add_msg);
// worker_add_clocks_->Update(add_worker);
// }
// }
// }
//
//private:
// std::unique_ptr<VectorClock> worker_get_clocks_;
// std::unique_ptr<VectorClock> worker_add_clocks_;
//
// MtQueue<MessagePtr> msg_add_cache_;
// MtQueue<MessagePtr> msg_get_cache_;
//};
class WithBackupSyncServer : public Server {
public:
SyncServer() : Server() {
int num_worker = Zoo::Get()->num_workers();
worker_get_clocks_.reset(new VectorClock(num_worker));
worker_add_clocks_.reset(new VectorClock(num_worker));
WithBackupSyncServer() : Server() {
num_worker_ = Zoo::Get()->num_workers();
double backup_ratio = (double)MV_CONFIG_backup_worker_ratio / 100;
num_sync_worker_ = num_worker_ -
static_cast<int>(backup_ratio * num_worker_);
CHECK(num_sync_worker_ > 0 && num_sync_worker_ <= num_worker_);
if (num_sync_worker_ == num_worker_) {
Log::Info("No backup worker, using the sync mode\n");
}
Log::Info("Sync with backup worker start: num_sync_worker = %d,"
"num_total_worker = %d\n", num_sync_worker_, num_worker_);
worker_get_clocks_.reset(new VectorClock(num_worker_));
worker_add_clocks_.reset(new VectorClock(
num_worker_, num_worker_ - num_sync_worker_));
}
// modified to suit the sync server
// please do not use elsewhere; it may differ from the general vector clock
class VectorClock {
public:
explicit VectorClock(int n) : local_clock_(n, 0), global_clock_(0) {}
VectorClock(int num_worker, int num_backup_worker = 0) :
local_clock_(num_worker, 0), global_clock_(0), num_worker_(num_worker),
num_sync_worker_(num_worker - num_backup_worker), progress_(0) {}
bool Update(int i) {
++local_clock_[i];
if (global_clock_ < *(std::min_element(std::begin(local_clock_),
std::end(local_clock_)))) {
// Return true when global clock meet the sync condition
// sync: all worker reach the same clock
// backup-worker-sync: sync-workers reach the same clock
virtual bool Update(int i) {
if (local_clock_[i]++ == global_clock_) {
++progress_;
}
if (progress_ >= num_sync_worker_) {
++global_clock_;
progress_ = 0;
for (auto i : local_clock_) {
if (i > global_clock_) ++progress_;
}
if (global_clock_ == *(std::max_element(std::begin(local_clock_),
std::end(local_clock_)))) {
return true;
@ -97,21 +222,27 @@ public:
int local_clock(int i) const { return local_clock_[i]; }
int global_clock() const { return global_clock_; }
private:
protected:
std::vector<int> local_clock_;
int global_clock_;
int num_worker_;
int num_sync_worker_;
int progress_;
};
protected:
void ProcessAdd(MessagePtr& msg) override {
// 1. Before add: cache faster worker
int worker = Zoo::Get()->rank_to_worker_id(msg->src());
if (worker_get_clocks_->local_clock(worker) >
worker_get_clocks_->global_clock()) {
worker_get_clocks_->global_clock()) {
msg_add_cache_.Push(msg);
return;
}
// 2. Process Add
Server::ProcessAdd(msg);
if (worker_add_clocks_->local_clock(worker) >=
worker_add_clocks_->global_clock()) {
Server::ProcessAdd(msg);
}
// 3. After add: process cached process get if necessary
if (worker_add_clocks_->Update(worker)) {
CHECK(msg_add_cache_.Empty());
@ -129,7 +260,7 @@ protected:
// 1. Before get: cache faster worker
int worker = Zoo::Get()->rank_to_worker_id(msg->src());
if (worker_add_clocks_->local_clock(worker) >
worker_add_clocks_->global_clock()) {
worker_add_clocks_->global_clock()) {
// Will wait for other worker finished Add
msg_get_cache_.Push(msg);
return;
@ -143,7 +274,10 @@ protected:
MessagePtr add_msg;
CHECK(msg_add_cache_.TryPop(add_msg));
int add_worker = Zoo::Get()->rank_to_worker_id(add_msg->src());
Server::ProcessAdd(add_msg);
if (worker_add_clocks_->local_clock(add_worker) >=
worker_add_clocks_->global_clock()) {
Server::ProcessAdd(msg);
};
worker_add_clocks_->Update(add_worker);
}
}
@ -155,11 +289,21 @@ private:
MtQueue<MessagePtr> msg_add_cache_;
MtQueue<MessagePtr> msg_get_cache_;
// num_worker_ - num_sync_worker_ = num_backup_worker_
int num_sync_worker_;
int num_worker_;
};
Server* Server::GetServer() {
if (MV_CONFIG_sync) return new SyncServer();
return new Server();
if (!MV_CONFIG_sync) {
Log::Info("Create a async server\n");
return new Server();
}
// if (MV_CONFIG_backup_worker_ratio > 0.0) {
Log::Info("Create a sync server\n");
return new WithBackupSyncServer();
// }
}
} // namespace multiverso

View file

@ -1,13 +1,17 @@
#include "multiverso/table_interface.h"
#include <mutex>
#include "multiverso/dashboard.h"
#include "multiverso/updater/updater.h"
#include "multiverso/util/log.h"
#include "multiverso/util/waiter.h"
#include "multiverso/zoo.h"
#include "multiverso/dashboard.h"
#include "multiverso/updater/updater.h"
namespace multiverso {
// Construct a worker-side table: the async-message id counter starts at
// zero, and the table registers itself with the global Zoo, which hands
// back this table's id.
WorkerTable::WorkerTable() : msg_id_(0) {
  table_id_ = Zoo::Get()->RegisterTable(this);
}
@ -30,8 +34,10 @@ void WorkerTable::Add(Blob keys, Blob values,
}
int WorkerTable::GetAsync(Blob keys) {
m_.lock();
int id = msg_id_++;
waitings_[id] = new Waiter();
waitings_.push_back(new Waiter());
m_.unlock();
MessagePtr msg(new Message());
msg->set_src(Zoo::Get()->rank());
msg->set_type(MsgType::Request_Get);
@ -44,8 +50,10 @@ int WorkerTable::GetAsync(Blob keys) {
int WorkerTable::AddAsync(Blob keys, Blob values,
const UpdateOption* option) {
m_.lock();
int id = msg_id_++;
waitings_[id] = new Waiter();
waitings_.push_back(new Waiter());
m_.unlock();
MessagePtr msg(new Message());
msg->set_src(Zoo::Get()->rank());
msg->set_type(MsgType::Request_Add);
@ -63,21 +71,32 @@ int WorkerTable::AddAsync(Blob keys, Blob values,
}
void WorkerTable::Wait(int id) {
CHECK(waitings_.find(id) != waitings_.end());
// CHECK(waitings_.find(id) != waitings_.end());
m_.lock();
CHECK(waitings_[id] != nullptr);
waitings_[id]->Wait();
Waiter* w = waitings_[id];
m_.unlock();
w->Wait();
m_.lock();
delete waitings_[id];
waitings_[id] = nullptr;
m_.unlock();
}
// Re-arm the waiter registered under message id `msg_id` so that it will
// block until `num_wait` further notifications arrive.
// The waiter for `msg_id` must already exist (checked via CHECK_NOTNULL).
void WorkerTable::Reset(int msg_id, int num_wait) {
  // RAII guard instead of manual lock()/unlock(): the mutex is released
  // on every exit path (the file already includes <mutex>).
  std::lock_guard<std::mutex> lock(m_);
  CHECK_NOTNULL(waitings_[msg_id]);
  waitings_[msg_id]->Reset(num_wait);
}
// Deliver one notification to the waiter registered under message id `id`,
// waking a thread blocked in Wait(id) once its wait count is satisfied.
// The waiter for `id` must already exist (checked via CHECK_NOTNULL).
void WorkerTable::Notify(int id) {
  // RAII guard instead of manual lock()/unlock(): the mutex is released
  // on every exit path (the file already includes <mutex>).
  std::lock_guard<std::mutex> lock(m_);
  CHECK_NOTNULL(waitings_[id]);
  waitings_[id]->Notify();
}
WorkerTable* TableHelper::CreateTable() {