diff --git a/nnvm/Makefile b/nnvm/Makefile index 48e03e0f..e73cba04 100644 --- a/nnvm/Makefile +++ b/nnvm/Makefile @@ -10,7 +10,7 @@ endif include $(config) export LDFLAGS = -pthread -lm -export CFLAGS = -std=c++11 -Wall -O2 -msse2 -Wno-unknown-pragmas -funroll-loops\ +export CFLAGS = -std=c++11 -Wall -O2 -msse2 -Wno-unknown-pragmas -funroll-loops\ -Iinclude -fPIC ifneq ($(ADD_CFLAGS), NONE) diff --git a/nnvm/make/config.mk b/nnvm/make/config.mk index 92dc67c8..b0989b42 100644 --- a/nnvm/make/config.mk +++ b/nnvm/make/config.mk @@ -20,15 +20,13 @@ # choice of compiler #-------------------- -export CC = gcc -export CXX = g++ export NVCC = nvcc # the additional link flags you want to add -ADD_LDFLAGS = +ADD_LDFLAGS= # the additional compile flags you want to add -ADD_CFLAGS = +ADD_CFLAGS= #---------------------------- # plugins diff --git a/nnvm/src/pass/gradient.cc b/nnvm/src/pass/gradient.cc index 2318763d..80df9038 100644 --- a/nnvm/src/pass/gradient.cc +++ b/nnvm/src/pass/gradient.cc @@ -38,6 +38,7 @@ struct GradEntry { NodeEntry sum{nullptr, 0, 0}; #endif std::vector grads; + bool need_attr_hint{true}; }; Graph Gradient(Graph src) { @@ -85,9 +86,6 @@ Graph Gradient(Graph src) { CHECK_EQ(ys.size(), ys_out_grad.size()); for (size_t i = 0; i < ys.size(); ++i) { NodeEntry ograd = ys_out_grad[i]; - if (attr_hint_fun != nullptr) { - ograd = attr_hint_fun(ograd, ys[i]); - } output_grads[ys[i].node.get()][ys[i].index].grads = { ograd }; } @@ -121,27 +119,29 @@ Graph Gradient(Graph src) { const NodePtr& ptr = *rit; if (ptr->is_variable()) continue; out_agg_grads.clear(); - for (GradEntry& e : output_grads.at(ptr.get())) { + auto& out_grad_vec = output_grads.at(ptr.get()); + for (uint32_t i = 0; i < out_grad_vec.size(); ++i) { + GradEntry& e = out_grad_vec[i]; e.sum = agg_fun(std::move(e.grads)); + if (e.need_attr_hint && attr_hint_fun != nullptr) { + e.sum = attr_hint_fun(e.sum, NodeEntry{ptr, 0, i}); + } out_agg_grads.push_back(e.sum); } if ((*rit)->inputs.size() != 0) 
{ NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get())); std::vector input_grads = grad_fun_map[ptr->op()]( fwd_node, out_agg_grads); - - if (attr_hint_fun != nullptr) { - // only insert hint when shape inference function is not available. - for (size_t i = 0; i < input_grads.size(); ++i) { - if (finfer_shape.count(input_grads[i].node->op())) continue; - input_grads[i] = attr_hint_fun(input_grads[i], fwd_node->inputs[i]); - } - } CHECK_EQ((*rit)->inputs.size(), input_grads.size()) << "Gradient function not returning enough gradient"; auto git = input_grads.begin(); for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) { - output_grads[it->node.get()][it->index].grads.emplace_back(std::move(*git)); + auto& ge = output_grads[it->node.get()][it->index]; + // if any of the backward ops can do shape inference, the hint is not necessary. + if (finfer_shape.count(git->node->op())) { + ge.need_attr_hint = false; + } + ge.grads.emplace_back(std::move(*git)); } } } @@ -153,6 +153,9 @@ Graph Gradient(Graph src) { // aggregate sum if there haven't been if (entry.sum.node.get() == nullptr) { entry.sum = agg_fun(std::move(entry.grads)); + if (entry.need_attr_hint && attr_hint_fun != nullptr) { + entry.sum = attr_hint_fun(entry.sum, e); + } } ret.outputs.emplace_back(std::move(entry.sum)); }