do hint insertion after aggregation (#81)
This commit is contained in:
Родитель
11883db571
Коммит
da78c4c5a3
|
@ -10,7 +10,7 @@ endif
|
|||
include $(config)
|
||||
|
||||
export LDFLAGS = -pthread -lm
|
||||
export CFLAGS = -std=c++11 -Wall -O2 -msse2 -Wno-unknown-pragmas -funroll-loops\
|
||||
export CFLAGS = -std=c++11 -Wall -O2 -msse2 -Wno-unknown-pragmas -funroll-loops\
|
||||
-Iinclude -fPIC
|
||||
|
||||
ifneq ($(ADD_CFLAGS), NONE)
|
||||
|
|
|
@ -20,15 +20,13 @@
|
|||
# choice of compiler
|
||||
#--------------------
|
||||
|
||||
export CC = gcc
|
||||
export CXX = g++
|
||||
export NVCC = nvcc
|
||||
|
||||
# the additional link flags you want to add
|
||||
ADD_LDFLAGS =
|
||||
ADD_LDFLAGS=
|
||||
|
||||
# the additional compile flags you want to add
|
||||
ADD_CFLAGS =
|
||||
ADD_CFLAGS=
|
||||
|
||||
#----------------------------
|
||||
# plugins
|
||||
|
|
|
@ -38,6 +38,7 @@ struct GradEntry {
|
|||
NodeEntry sum{nullptr, 0, 0};
|
||||
#endif
|
||||
std::vector<NodeEntry> grads;
|
||||
bool need_attr_hint{true};
|
||||
};
|
||||
|
||||
Graph Gradient(Graph src) {
|
||||
|
@ -85,9 +86,6 @@ Graph Gradient(Graph src) {
|
|||
CHECK_EQ(ys.size(), ys_out_grad.size());
|
||||
for (size_t i = 0; i < ys.size(); ++i) {
|
||||
NodeEntry ograd = ys_out_grad[i];
|
||||
if (attr_hint_fun != nullptr) {
|
||||
ograd = attr_hint_fun(ograd, ys[i]);
|
||||
}
|
||||
output_grads[ys[i].node.get()][ys[i].index].grads = { ograd };
|
||||
}
|
||||
|
||||
|
@ -121,27 +119,29 @@ Graph Gradient(Graph src) {
|
|||
const NodePtr& ptr = *rit;
|
||||
if (ptr->is_variable()) continue;
|
||||
out_agg_grads.clear();
|
||||
for (GradEntry& e : output_grads.at(ptr.get())) {
|
||||
auto& out_grad_vec = output_grads.at(ptr.get());
|
||||
for (uint32_t i = 0; i < out_grad_vec.size(); ++i) {
|
||||
GradEntry& e = out_grad_vec[i];
|
||||
e.sum = agg_fun(std::move(e.grads));
|
||||
if (e.need_attr_hint && attr_hint_fun != nullptr) {
|
||||
e.sum = attr_hint_fun(e.sum, NodeEntry{ptr, 0, i});
|
||||
}
|
||||
out_agg_grads.push_back(e.sum);
|
||||
}
|
||||
if ((*rit)->inputs.size() != 0) {
|
||||
NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()));
|
||||
std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op()](
|
||||
fwd_node, out_agg_grads);
|
||||
|
||||
if (attr_hint_fun != nullptr) {
|
||||
// only insert hint when shape inference function is not available.
|
||||
for (size_t i = 0; i < input_grads.size(); ++i) {
|
||||
if (finfer_shape.count(input_grads[i].node->op())) continue;
|
||||
input_grads[i] = attr_hint_fun(input_grads[i], fwd_node->inputs[i]);
|
||||
}
|
||||
}
|
||||
CHECK_EQ((*rit)->inputs.size(), input_grads.size())
|
||||
<< "Gradient function not returning enough gradient";
|
||||
auto git = input_grads.begin();
|
||||
for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
|
||||
output_grads[it->node.get()][it->index].grads.emplace_back(std::move(*git));
|
||||
auto& ge = output_grads[it->node.get()][it->index];
|
||||
// if any of the backward op can do shape inference, the hint is not necessary.
|
||||
if (finfer_shape.count(git->node->op())) {
|
||||
ge.need_attr_hint = false;
|
||||
}
|
||||
ge.grads.emplace_back(std::move(*git));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -153,6 +153,9 @@ Graph Gradient(Graph src) {
|
|||
// aggregate sum if there haven't been
|
||||
if (entry.sum.node.get() == nullptr) {
|
||||
entry.sum = agg_fun(std::move(entry.grads));
|
||||
if (entry.need_attr_hint && attr_hint_fun != nullptr) {
|
||||
entry.sum = attr_hint_fun(entry.sum, e);
|
||||
}
|
||||
}
|
||||
ret.outputs.emplace_back(std::move(entry.sum));
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче