do hint insertion after aggregation (#81)

This commit is contained in:
Tianqi Chen 2016-11-20 22:54:58 -08:00
Parent 11883db571
Commit da78c4c5a3
3 changed files with 19 additions and 18 deletions
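In short: the gradient pass used to wrap every incoming output gradient with an attribute hint before summing; after this change the gradients are aggregated first, a single hint is attached to the aggregated sum, and the hint is skipped entirely when any contributing backward op can already run shape inference (tracked by the new need_attr_hint flag). Below is a minimal, self-contained sketch of that pattern; the Entry struct and the SumThenHint / agg / hint names are hypothetical stand-ins for illustration, not the NNVM types or API.

#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for nnvm::NodeEntry in this sketch.
struct Entry {
  std::string expr;      // symbolic expression, for printing only
  bool can_infer_shape;  // stands in for "finfer_shape.count(op)" in the real pass
};

using AggFun  = std::function<Entry(std::vector<Entry>&&)>;
using HintFun = std::function<Entry(const Entry&)>;

// Behaviour after this commit (simplified): sum first, then hint the single
// aggregated result, and skip the hint when any producer can infer shapes.
Entry SumThenHint(std::vector<Entry>&& grads, AggFun agg, HintFun hint) {
  bool need_attr_hint = true;
  for (const Entry& g : grads) {
    if (g.can_infer_shape) need_attr_hint = false;  // hint would be redundant
  }
  Entry sum = agg(std::move(grads));
  if (need_attr_hint && hint != nullptr) {
    sum = hint(sum);  // one hint node on the aggregated sum, not one per gradient
  }
  return sum;
}

int main() {
  AggFun agg = [](std::vector<Entry>&& v) {
    std::string s;
    for (size_t i = 0; i < v.size(); ++i) s += (i ? " + " : "") + v[i].expr;
    return Entry{s, false};
  };
  HintFun hint = [](const Entry& e) { return Entry{"hint(" + e.expr + ")", true}; };

  std::vector<Entry> grads = {{"g0", false}, {"g1", false}};
  std::cout << SumThenHint(std::move(grads), agg, hint).expr << "\n";
  // Prints "hint(g0 + g1)": one hint after aggregation, instead of hinting
  // g0 and g1 separately before the sum as the old code did.
  return 0;
}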

View file

@@ -10,7 +10,7 @@ endif
 include $(config)
 export LDFLAGS = -pthread -lm
-export CFLAGS = -std=c++11 -Wall -O2 -msse2 -Wno-unknown-pragmas -funroll-loops\
+export CFLAGS = -std=c++11 -Wall -O2 -msse2 -Wno-unknown-pragmas -funroll-loops\
  -Iinclude -fPIC
 ifneq ($(ADD_CFLAGS), NONE)

View file

@@ -20,15 +20,13 @@
 # choice of compiler
 #--------------------
 export CC = gcc
 export CXX = g++
 export NVCC = nvcc
 # the additional link flags you want to add
-ADD_LDFLAGS =
+ADD_LDFLAGS=
 # the additional compile flags you want to add
-ADD_CFLAGS =
+ADD_CFLAGS=
 #----------------------------
 # plugins
View file

@@ -38,6 +38,7 @@ struct GradEntry {
   NodeEntry sum{nullptr, 0, 0};
 #endif
   std::vector<NodeEntry> grads;
+  bool need_attr_hint{true};
 };

 Graph Gradient(Graph src) {
@@ -85,9 +86,6 @@ Graph Gradient(Graph src) {
   CHECK_EQ(ys.size(), ys_out_grad.size());
   for (size_t i = 0; i < ys.size(); ++i) {
     NodeEntry ograd = ys_out_grad[i];
-    if (attr_hint_fun != nullptr) {
-      ograd = attr_hint_fun(ograd, ys[i]);
-    }
     output_grads[ys[i].node.get()][ys[i].index].grads = { ograd };
   }
@@ -121,27 +119,29 @@ Graph Gradient(Graph src) {
     const NodePtr& ptr = *rit;
     if (ptr->is_variable()) continue;
     out_agg_grads.clear();
-    for (GradEntry& e : output_grads.at(ptr.get())) {
+    auto& out_grad_vec = output_grads.at(ptr.get());
+    for (uint32_t i = 0; i < out_grad_vec.size(); ++i) {
+      GradEntry& e = out_grad_vec[i];
       e.sum = agg_fun(std::move(e.grads));
+      if (e.need_attr_hint && attr_hint_fun != nullptr) {
+        e.sum = attr_hint_fun(e.sum, NodeEntry{ptr, 0, i});
+      }
       out_agg_grads.push_back(e.sum);
     }
     if ((*rit)->inputs.size() != 0) {
       NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()));
       std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op()](
           fwd_node, out_agg_grads);
-      if (attr_hint_fun != nullptr) {
-        // only insert hint when shape inference function is not available.
-        for (size_t i = 0; i < input_grads.size(); ++i) {
-          if (finfer_shape.count(input_grads[i].node->op())) continue;
-          input_grads[i] = attr_hint_fun(input_grads[i], fwd_node->inputs[i]);
-        }
-      }
       CHECK_EQ((*rit)->inputs.size(), input_grads.size())
           << "Gradient function not returning enough gradient";
       auto git = input_grads.begin();
       for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
-        output_grads[it->node.get()][it->index].grads.emplace_back(std::move(*git));
+        auto& ge = output_grads[it->node.get()][it->index];
+        // if any of the backward op can do shape inference, the hint is not necessary.
+        if (finfer_shape.count(git->node->op())) {
+          ge.need_attr_hint = false;
+        }
+        ge.grads.emplace_back(std::move(*git));
       }
     }
   }
@@ -153,6 +153,9 @@ Graph Gradient(Graph src) {
     // aggregate sum if there haven't been
     if (entry.sum.node.get() == nullptr) {
      entry.sum = agg_fun(std::move(entry.grads));
+      if (entry.need_attr_hint && attr_hint_fun != nullptr) {
+        entry.sum = attr_hint_fun(entry.sum, e);
+      }
     }
     ret.outputs.emplace_back(std::move(entry.sum));
   }