* [PASS] FoldScaleAxis

* Move FoldAxis to O3

* Set unroll to 0 when ready
Tianqi Chen 2017-09-28 22:33:19 -07:00
Parent 3f599a60be
Commit d25138e6c9
17 changed files with 457 additions and 43 deletions
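In effect, the new FoldScaleAxis pass detects a per-channel broadcast multiply that follows a convolution and folds the scale into the convolution's weight and bias. A minimal sketch of the rewrite in symbol form, mirroring the unit test added by this commit (shapes and names are illustrative):

    from nnvm import symbol as sym

    x = sym.Variable("x")
    w = sym.Variable("weight")
    b = sym.Variable("bias")
    scale = sym.Variable("scale")  # shape (channels,)

    # Before: conv2d -> relu -> per-channel broadcast multiply
    y = sym.conv2d(x, w, b, channels=2, kernel_size=(3, 3),
                   padding=(1, 1), name="conv")
    y = sym.relu(y)
    y = y * sym.expand_dims(scale, axis=1, num_newaxis=2)

    # After FoldScaleAxis: the multiply is absorbed into weight and bias
    w_fold = w * sym.expand_dims(scale, axis=1, num_newaxis=3)
    b_fold = b * scale
    y_fold = sym.relu(sym.conv2d(x, w_fold, b_fold, channels=2,
                                 kernel_size=(3, 3), padding=(1, 1), name="conv"))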

View file

@@ -98,6 +98,8 @@ class IndexedGraph {
array_view<NodeEntry> inputs;
/*! \brief control flow dependencies to the node */
array_view<uint32_t> control_deps;
+/*! \brief weak reference to node */
+std::weak_ptr<nnvm::Node> weak_ref;
};
/*! \return number of nodes in the graph */
inline size_t num_nodes() const {

View file

@@ -11,7 +11,8 @@ from .. import graph as _graph
OPT_PASS_LEVEL = {
    "SimplifyInference": 0,
    "PrecomputePrune": 2,
-    "OpFusion": 1
+    "OpFusion": 1,
+    "FoldScaleAxis": 3
}
# List of optimization pass and level when switch on
@@ -144,6 +145,10 @@ def optimize(graph, shape, dtype="float32"):
    if cfg.pass_enabled("SimplifyInference"):
        graph = graph_attr.set_shape_inputs(graph, shape)
        graph = graph.apply(["InferShape", "SimplifyInference"])
+    if cfg.pass_enabled("FoldScaleAxis"):
+        graph = graph_attr.set_shape_inputs(graph, shape)
+        graph = graph.apply(["InferShape", "FoldScaleAxis"])
    return graph
@@ -291,5 +296,6 @@ def precompute_prune(graph, params):
    out_names = pre_graph.json_attr("output_names")
    if not pre_graph.symbol.list_output_names():
        return graph, params
-    out_arrs = _run_graph(pre_graph, params)
+    with tvm.build_config(auto_unroll_max_step=0):
+        out_arrs = _run_graph(pre_graph, params)
    return graph, dict(zip(out_names, out_arrs))
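Since FoldScaleAxis is registered at level 3 in OPT_PASS_LEVEL above, it stays off at the default optimization level and is enabled through the build config. A usage sketch, assuming net, target, shape_dict, and params come from a normal frontend import (the MXNet frontend test updated at the end of this commit does exactly this):

    import nnvm.compiler

    # opt_level >= 3 switches on FoldScaleAxis inside optimize()
    with nnvm.compiler.build_config(opt_level=3):
        graph, lib, params = nnvm.compiler.build(
            net, target, shape_dict, params=params)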

View file

@@ -81,7 +81,6 @@ class Xavier(Initializer):
        self.factor_type = factor_type
        self.magnitude = float(magnitude)
-
    def _init_weight(self, name, arr):
        shape = arr.shape
        hw_scale = 1.

View file

@@ -30,7 +30,8 @@ def separable_conv_block(data, name, depthwise_channels,
    # depthwise convolution + bn + relu
    conv1 = sym.conv2d(data=data, channels=depthwise_channels,
                       groups=depthwise_channels, kernel_size=kernel_size, strides=strides,
-                      padding=padding, use_bias=False, layout="NCHW", name=name + "_depthwise_conv1")
+                      padding=padding, use_bias=False, layout="NCHW",
+                      name=name + "_depthwise_conv1")
    bn1 = sym.batch_norm(data=conv1, epsilon=epsilon, name=name + "_bn1")
    act1 = sym.relu(data=bn1, name=name + "_relu1")
    # pointwise convolution + bn + relu

View file

@@ -46,7 +46,7 @@ def create_workload(net, batch_size, image_shape=(3, 224, 224),
    input_shapes, _ = graph_util.infer_shape(g, data=data_shape)
    shape_dict = dict(zip(g.index.input_names, input_shapes))
    np.random.seed(seed)
-    initializer = initializer if initializer else Xavier(magnitude=3)
+    initializer = initializer if initializer else Xavier()
    for k, v in shape_dict.items():
        if k == "data":
            continue

View file

@@ -7,8 +7,7 @@ The following components are operator invariant.
- core: NNVM core data structure
- pass: NNVM pass
-The following components are generic graph compiler for NNVM-TOP
-- top: NNVM-TOP core operator defs
-- tvm: NNVM-TOP to TVM compiler toolchain
-- runtime: NNVM-TOP runtime
+The following components are generic NNVM compiler and defines tensor operator set
+- top: NNVM core tensor operators
+- compiler: NNVM compiler toolchain

View file

@@ -58,7 +58,7 @@ class CompileEngine {
return it->second->graph_func;
}
GraphFunc f = DoLower(key->graph, key->inputs, key->target,
-                   schedule_op_key, schedule_op_attr);
+                      schedule_op_key, schedule_op_attr);
std::shared_ptr<GraphCacheEntryNode> n = std::make_shared<GraphCacheEntryNode>();
n->graph_func = f;
n->use_count = 1;

View file

@@ -0,0 +1,271 @@
/*!
* Copyright (c) 2017 by Contributors
* \file fold_scale_axis.cc
* \brief Fold scaling parameter of axis into weight of conv/dense
*/
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/top/nn.h>
#include "./pattern_util.h"
#include "./graph_transform.h"
namespace nnvm {
namespace compiler {
enum FoldScaleKind {
// No folding is applied
kNone,
// The folding decision is pending
kPending,
// The original operator that contains the scale.
kProvider,
// Pass the scale through to the first axis of parent/child.
kPassTroughFirst,
// The final consumer of the axis scale, using multiply.
// Likely a conv or dense operator.
kMulConsumer,
// The final consumer of the axis scale, using division.
kDivConsumer
};
// Input fold information
struct FoldScaleInput {
uint32_t index;
int axis;
};
// The entry of the folding chain on which
// we should perform folding
struct FoldChainEntry {
// Entry kind
FoldScaleKind kind{kNone};
// The output axis to be folded
int axis{0};
// Source node in the fold chain
int source{0};
// Following field only used by provider.
// The input index
int fold_input_index{1};
// The scale entry
NodeEntry scale_entry;
};
// Try to pass axis scaling backward,
// given that we know the status of the current fold axis.
using FScaleAxisBackward = std::function<
FoldScaleKind(const NodeAttrs& attrs,
int axis,
const std::vector<TShape>& in_shape,
const std::vector<TShape>& out_shape,
std::vector<std::pair<uint32_t, int> >* in_axis)>;
// Detect whether an axis scaling pattern occurs at this node
bool DetectScaleAxis(const IndexedGraph& idx,
uint32_t nid,
const ShapeVector& shape_vec,
const std::vector<uint32_t>& ref_count,
bool is_forward,
std::vector<FoldChainEntry>* chain) {
const IndexedGraph::Node& inode = idx[nid];
static const Op* bcast_mul = Op::Get("broadcast_mul");
static const Op* expand_dims = Op::Get("expand_dims");
if (inode.source->op() != bcast_mul) return false;
const TShape& oshape = shape_vec[idx.entry_id(nid, 0)];
CHECK_NE(oshape.ndim(), 0);
if (oshape.ndim() <= 1) return false;
for (int i = 0; i < 2; ++i) {
const IndexedGraph::NodeEntry& a = inode.inputs[i];
const IndexedGraph::NodeEntry& b = inode.inputs[1 - i];
std::pair<int, int> axis =
MatchBroadcast1DAxis(oshape, shape_vec[idx.entry_id(a)]);
if (axis.first != -1 &&
shape_vec[idx.entry_id(b)] == oshape) {
if (ref_count[a.node_id] != 1) return false;
if (is_forward && ref_count[nid] != 1) return false;
if (!is_forward && ref_count[b.node_id] != 1) return false;
const IndexedGraph::Node& anode = idx[a.node_id];
// mark the current entry.
FoldChainEntry& e = (*chain)[nid];
if (anode.source->is_variable()) {
e.fold_input_index = 1 - i;
e.scale_entry = inode.source->inputs[1 - i];
} else if (anode.source->op() == expand_dims &&
shape_vec[idx.entry_id(anode.source->inputs[0])].ndim() == 1) {
e.fold_input_index = 1 - i;
e.scale_entry = anode.source->inputs[0];
} else {
return false;
}
e.axis = axis.first;
e.kind = kPending;
e.source = nid;
if (!is_forward) {
// pass message to another input
FoldChainEntry& enext = (*chain)[b.node_id];
enext.axis = e.axis;
enext.kind = kPending;
enext.source = nid;
}
return true;
}
}
return false;
}
Graph FoldScaleAxis(Graph src) {
// Operator pattern
static auto& fbackward =
nnvm::Op::GetAttr<FScaleAxisBackward>("FScaleAxisBackward");
const IndexedGraph& idx = src.indexed_graph();
const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
std::vector<uint32_t> ref_count = GetNodeRefCounts(idx);
std::vector<FoldChainEntry> bwd_chain(idx.num_nodes());
// shape hint for the inference.
std::vector<TShape> in_shape, out_shape;
// perform backward folding.
for (uint32_t i = idx.num_nodes(); i != 0; --i) {
uint32_t nid = i - 1;
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
if (DetectScaleAxis(idx, nid, shape_vec,
ref_count, false, &bwd_chain)) continue;
if (bwd_chain[nid].kind != kPending) continue;
if (ref_count[nid] != 1 || !fbackward.count(inode.source->op())) {
bwd_chain[nid].kind = kNone; continue;
}
// get input shape and output shape.
in_shape.clear(); out_shape.clear();
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
in_shape.push_back(shape_vec[idx.entry_id(e)]);
}
for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
out_shape.push_back(shape_vec[idx.entry_id(nid, i)]);
}
std::vector<std::pair<uint32_t, int> > in_axis;
FoldScaleKind kind =
fbackward[inode.source->op()](
inode.source->attrs, bwd_chain[nid].axis,
in_shape, out_shape, &in_axis);
bwd_chain[nid].kind = kind;
if (kind == kNone) continue;
CHECK_GE(in_axis.size(), 1U);
CHECK(kind == kPassTroughFirst || kind == kMulConsumer);
// propagate back.
bool can_prop = true;
for (size_t i = 0; i < in_axis.size(); ++i) {
const IndexedGraph::NodeEntry& e = inode.inputs[in_axis[i].first];
if (ref_count[e.node_id] != 1 ||
idx[e.node_id].source->num_outputs() != 1) {
can_prop = false; break;
}
}
if (!can_prop) continue;
for (size_t i = 0; i < in_axis.size(); ++i) {
const IndexedGraph::NodeEntry& e = inode.inputs[in_axis[i].first];
if (kind == kPassTroughFirst && i == 0) {
bwd_chain[e.node_id].kind = kPending;
} else {
bwd_chain[nid].kind = kNone;
bwd_chain[e.node_id].kind = kMulConsumer;
}
bwd_chain[e.node_id].axis = in_axis[i].second;
bwd_chain[e.node_id].source = bwd_chain[nid].source;
}
if (kind == kMulConsumer) {
bwd_chain[bwd_chain[nid].source].kind = kProvider;
}
}
auto transform = [&](uint32_t nid, const NodePtr& n, std::vector<NodeEntry>* ret) {
const FoldChainEntry& e = bwd_chain[nid];
if (e.kind == kMulConsumer && bwd_chain[e.source].kind == kProvider) {
const FoldChainEntry& se = bwd_chain[e.source];
CHECK_EQ(n->num_outputs(), 1);
NodeEntry scale = ExpandBiasToMatchAxis(
se.scale_entry,
shape_vec[idx.entry_id(nid, 0)].ndim(),
shape_vec[idx.entry_id(se.scale_entry)].ndim(),
e.axis);
*ret = {MakeNode("broadcast_mul", n->attrs.name + "_sc",
{NodeEntry{n, 0, 0}, scale})};
return true;
} else if (e.kind == kProvider) {
*ret = {n->inputs[e.fold_input_index]};
return true;
} else {
return false;
}
};
return GraphTransform(src, transform);
}
NNVM_REGISTER_PASS(FoldScaleAxis)
.set_body(FoldScaleAxis);
// property registration.
FoldScaleKind ReluScaleAxisBackward(
const NodeAttrs& attrs,
int axis,
const std::vector<TShape>& in_shape,
const std::vector<TShape>& out_shape,
std::vector<std::pair<uint32_t, int> >* in_axis) {
in_axis->emplace_back(0, axis);
return kPassTroughFirst;
}
NNVM_REGISTER_OP(relu)
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", ReluScaleAxisBackward);
NNVM_REGISTER_OP(leaky_relu)
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", ReluScaleAxisBackward);
FoldScaleKind BroadcastAddSubScaleAxisBackward(
const NodeAttrs& attrs,
int axis,
const std::vector<TShape>& in_shape,
const std::vector<TShape>& out_shape,
std::vector<std::pair<uint32_t, int> >* in_axis) {
for (int i = 0; i < 2; ++i) {
std::pair<int, int> m = MatchBroadcast1DAxis(out_shape[0], in_shape[i]);
if (m.second != -1 && in_shape[1 - i] == out_shape[0]) {
in_axis->emplace_back(i, axis);
in_axis->emplace_back(1 - i, m.second);
return kPassTroughFirst;
}
}
return kNone;
}
NNVM_REGISTER_OP(broadcast_add)
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", BroadcastAddSubScaleAxisBackward);
NNVM_REGISTER_OP(broadcast_sub)
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", BroadcastAddSubScaleAxisBackward);
FoldScaleKind Conv2DScaleAxisBackward(
const NodeAttrs& attrs,
int axis,
const std::vector<TShape>& in_shape,
const std::vector<TShape>& out_shape,
std::vector<std::pair<uint32_t, int> >* in_axis) {
using top::Conv2DParam;
const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
// only optimize for nchw for now
if (param.layout == top::kNCHW) {
in_axis->emplace_back(1, 0);
if (param.use_bias) {
in_axis->emplace_back(2, 0);
}
return kMulConsumer;
} else {
return kNone;
}
}
NNVM_REGISTER_OP(conv2d)
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", Conv2DScaleAxisBackward);
} // namespace compiler
} // namespace nnvm
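The pass can also be driven directly at graph level, as the new unit test further below does: shapes must be inferred first, since the pass reads the graph's "shape" attribute. A minimal sketch:

    import nnvm
    from nnvm import symbol as sym
    from nnvm.compiler import graph_attr

    x = sym.Variable("x")
    w = sym.Variable("weight")
    scale = sym.Variable("scale")
    y = sym.relu(sym.conv2d(x, w, use_bias=False, channels=2,
                            kernel_size=(3, 3), padding=(1, 1), name="conv"))
    y = y * sym.expand_dims(scale, axis=1, num_newaxis=2)

    g = nnvm.graph.create(y)
    graph_attr.set_shape_inputs(g, {"x": (2, 4, 10, 10), "scale": (2,)})
    g = g.apply("InferShape").apply("FoldScaleAxis")  # the scale is folded into w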

View file

@@ -16,6 +16,7 @@
#include <dmlc/parameter.h>
#include "./compile_engine.h"
#include "./graph_runtime.h"
+#include "./pattern_util.h"
namespace nnvm {
namespace compiler {
@@ -56,17 +57,10 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
// Reference counter of each op node
// For now, always store result when an op is referred more than once.
-std::vector<uint32_t> ref_count(idx.num_nodes(), 0);
-for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
-const auto& inode = idx[nid];
-if (inode.source->is_variable()) continue;
-for (const auto& e : inode.inputs) {
-++ref_count[e.node_id];
-}
-}
+std::vector<uint32_t> ref_count = GetNodeRefCounts(idx);
for (const auto& e : idx.outputs()) {
// this line will realize all the outputs
-ref_count[e.node_id] += 2;
+ref_count[e.node_id] += 1;
}
// Pattern for the subgraph
std::vector<TOpPattern> pattern_vec(idx.num_nodes(), kOpaque);

View file

@@ -20,7 +20,7 @@ namespace compiler {
*
* \param graph The original graph
*
-* \param ftransform Function of (int nid, const Node* node, std::vector<NodeEntry>* out) -> bool
+* \param ftransform Function of (int nid, const NodePtr& node, std::vector<NodeEntry>* out) -> bool
*
* If empty vector is returned, it means original entries should be kept.
*
@@ -36,7 +36,6 @@ Graph GraphTransform(Graph graph, FTransform ftransform) {
// setup inputs and placeholder.
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
-if (inode.source->is_variable()) continue;
bool need_copy = false;
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
if (updated[idx.entry_id(e)]) {
@@ -57,7 +56,7 @@ Graph GraphTransform(Graph graph, FTransform ftransform) {
if (!need_copy) {
std::vector<NodeEntry> ret;
-if (ftransform(nid, inode.source, &ret)) {
+if (ftransform(nid, inode.weak_ref.lock(), &ret)) {
CHECK_EQ(ret.size(), static_cast<size_t>(inode.source->num_outputs()));
for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) {
updated[idx.entry_id(nid, i)] = true;
@@ -93,7 +92,7 @@ Graph GraphTransform(Graph graph, FTransform ftransform) {
node->control_deps.push_back(selected_ptr);
}
std::vector<NodeEntry> ret;
-if (ftransform(nid, node.get(), &ret)) {
+if (ftransform(nid, node, &ret)) {
CHECK_EQ(ret.size(), static_cast<size_t>(inode.source->num_outputs()));
for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) {
updated[idx.entry_id(nid, i)] = true;

View file

@@ -0,0 +1,99 @@
/*!
* Copyright (c) 2017 by Contributors
* \file pattern_util.h
* \brief Utilities for doing various pattern matching in graph.
*/
#ifndef NNVM_COMPILER_PATTERN_UTIL_H_
#define NNVM_COMPILER_PATTERN_UTIL_H_
#include <nnvm/graph.h>
#include <unordered_map>
#include <vector>
#include <utility>
#include <string>
namespace nnvm {
namespace compiler {
/*!
* \brief find axis in oshape, such that:
* bias_shape = [1,1, ... oshape[axis], 1,1,]
*
* This is used to detect bias or scaling factor on channel dimension.
* \param oshape The output shape
* \param bias_shape The shape of bias or scaling factor.
* \return Pair of matched axes in oshape and bias_shape if found, or {-1, -1} otherwise.
*/
inline std::pair<int, int> MatchBroadcast1DAxis(
const TShape& oshape, const TShape& bias_shape) {
dim_t axis_dim = bias_shape.ndim();
for (dim_t i = bias_shape.ndim(); i != 0; --i, --axis_dim) {
if (bias_shape[i - 1] != 1) break;
}
// everything is 1
if (axis_dim == 0) {
return {oshape.ndim() - bias_shape.ndim(), 0};
}
axis_dim = axis_dim - 1;
// check the remaining dims are all 1, otherwise the bias shape is not 1-D
for (dim_t i = 0; i < axis_dim; ++i) {
if (bias_shape[i] != 1) return {-1, -1};
}
int axis = static_cast<int>(
oshape.ndim() - bias_shape.ndim() + axis_dim);
if (oshape[axis] != bias_shape[axis_dim]) return {-1, -1};
return {axis, axis_dim};
}
/*!
* \brief Expand bias dimension to match needed axis.
*
* \param bias The bias NodeEntry
* \param out_dim output dimension.
* \param bias_dim The current bias dimension.
* \param axis The axis we want to match on.
* \return The bias entry expanded so it broadcasts along the given axis of an out_dim-dimensional tensor.
*/
inline NodeEntry
ExpandBiasToMatchAxis(NodeEntry bias,
int out_dim,
int bias_dim,
int axis) {
if (bias_dim != 1) {
bias = MakeNode("squeeze", bias.node->attrs.name + "_sqz", {bias});
}
int num_pad_axis = out_dim - axis - 1;
if (num_pad_axis > 0) {
std::unordered_map<std::string, std::string> kwargs{
{"axis", "1"},
{"num_newaxis", std::to_string(num_pad_axis)}};
return MakeNode("expand_dims", bias.node->attrs.name + "_expand",
{bias}, kwargs);
} else {
return bias;
}
}
/*!
* \brief Get the reference count of each node.
* \param idx The IndexedGraph
* \return ref_count vector of length idx.num_nodes().
*/
inline std::vector<uint32_t>
GetNodeRefCounts(const IndexedGraph& idx) {
std::vector<uint32_t> ref_count(idx.num_nodes(), 0);
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
for (const auto& e : inode.inputs) {
++ref_count[e.node_id];
}
}
for (const auto& e : idx.outputs()) {
// this line will realize all the outputs
ref_count[e.node_id] += 1;
}
return ref_count;
}
} // namespace compiler
} // namespace nnvm
#endif // NNVM_COMPILER_PATTERN_UTIL_H_
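For hand-checking cases, the matching rule implemented by MatchBroadcast1DAxis can be restated in plain Python; this is an illustrative re-implementation, not code from the commit:

    def match_broadcast_1d_axis(oshape, bias_shape):
        # Returns (axis in oshape, axis in bias_shape), or (-1, -1) on no match.
        axis_dim = len(bias_shape)
        while axis_dim > 0 and bias_shape[axis_dim - 1] == 1:
            axis_dim -= 1                  # strip trailing 1s
        if axis_dim == 0:                  # bias is all ones
            return len(oshape) - len(bias_shape), 0
        axis_dim -= 1                      # index of the last non-1 dim
        if any(d != 1 for d in bias_shape[:axis_dim]):
            return -1, -1                  # more than one non-1 dim: not 1-D
        axis = len(oshape) - len(bias_shape) + axis_dim
        if oshape[axis] != bias_shape[axis_dim]:
            return -1, -1
        return axis, axis_dim

    # e.g. a (4,)-channel scale expanded to (4, 1, 1) against an NCHW output:
    assert match_broadcast_1d_axis((2, 4, 10, 10), (4, 1, 1)) == (1, 0)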

View file

@@ -1,6 +1,6 @@
/*!
* Copyright (c) 2017 by Contributors
-* \file simplify_batch_norm.cc
+* \file simplify_inference.cc
* \author Ziheng Jiang
*/
#include <nnvm/graph.h>
@@ -10,6 +10,7 @@
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/top/nn.h>
#include "./graph_transform.h"
+#include "./pattern_util.h"
namespace nnvm {
namespace compiler {
@@ -58,15 +59,9 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
shift = MakeNode(
"elemwise_add", bn_name + "_add_beta", {shift, beta});
}
-// use expand dims to pad lower dims to 1
-int num_pad_axis = static_cast<int>(dshape.ndim() - param.axis) - 1;
-if (num_pad_axis != 0) {
-std::unordered_map<std::string, std::string> kwargs{
-{"axis", std::to_string(param.axis)},
-{"num_newaxis", std::to_string(num_pad_axis)}};
-scale = MakeNode("expand_dims", bn_name + "_sc_expand", {scale}, kwargs);
-shift = MakeNode("expand_dims", bn_name + "_sh_expand", {shift}, kwargs);
-}
+int axis = param.axis;
+scale = ExpandBiasToMatchAxis(scale, dshape.ndim(), 1, axis);
+shift = ExpandBiasToMatchAxis(shift, dshape.ndim(), 1, axis);
NodeEntry out = MakeNode("broadcast_mul", bn_name + "_a_mul_data",
{data, scale});
out = MakeNode("broadcast_add", bn_name + "_out",
@@ -80,7 +75,7 @@ Graph SimplifyInference(nnvm::Graph src) {
// Get attributes from the graph
const IndexedGraph& idx = src.indexed_graph();
const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
-auto transform = [&](uint32_t nid, const Node* n, std::vector<NodeEntry>* ret) {
+auto transform = [&](uint32_t nid, const NodePtr& n, std::vector<NodeEntry>* ret) {
if (n->is_variable()) return false;
static const Op* bn_op = Op::Get("batch_norm");
static const Op* dropout_op = Op::Get("dropout");
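With the shared helper, the inference-time batch_norm rewrite now pads the per-channel scale and shift starting at axis 1 (the fixed axis used by ExpandBiasToMatchAxis) rather than at the channel axis itself. In symbol terms, the simplified graph for 4-D NCHW data computes roughly the following (a sketch mirroring the updated test helper below; the scale/shift formulas are the standard inference-time batch-norm folding):

    from nnvm import symbol as sym

    data = sym.Variable("data")    # NCHW, so ndim = 4 and channel axis = 1
    scale = sym.Variable("scale")  # shape (C,), i.e. gamma / sqrt(moving_var + eps)
    shift = sym.Variable("shift")  # shape (C,), i.e. beta - moving_mean * scale

    ndim, axis = 4, 1
    num_newaxis = ndim - axis - 1  # pad (C,) -> (C, 1, 1)
    if num_newaxis:
        scale = sym.expand_dims(scale, axis=1, num_newaxis=num_newaxis)
        shift = sym.expand_dims(shift, axis=1, num_newaxis=num_newaxis)
    out = data * scale + shift     # broadcast mul/add replace batch_norm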

View file

@@ -28,6 +28,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
// nodes_
IndexedGraph::Node new_node;
new_node.source = n.get();
+new_node.weak_ref = n;
nodes_.emplace_back(std::move(new_node));
// arg_nodes_
if (n->is_variable()) {

View file

@@ -460,21 +460,21 @@ inline bool SqueezeShape(const nnvm::NodeAttrs& attrs,
std::vector<int64_t> oshape;
if (param.axis.ndim() == 0) {
for (dim_t i = 0; i < shp.ndim(); ++i) {
-if(shp[i] != 1) {
+if (shp[i] != 1) {
oshape.emplace_back(shp[i]);
}
}
} else {
std::unordered_set<dim_t> axis_checker;
for (size_t i = 0; i < param.axis.ndim(); ++i) {
-if(param.axis[i] < 0) {
+if (param.axis[i] < 0) {
int real_axis = param.axis[i] + static_cast<int>(shp.ndim());
CHECK(real_axis < static_cast<int>(shp.ndim()) && real_axis >= 0);
axis_checker.insert(real_axis);
}
}
for (size_t i = 0; i < shp.ndim(); ++i) {
-if(axis_checker.find(i) == axis_checker.end()) {
+if (axis_checker.find(i) == axis_checker.end()) {
oshape.emplace_back(shp[i]);
} else {
CHECK_EQ(shp[i], 1) << "The squeezed axis must have shape 1!"
@@ -483,7 +483,7 @@
}
}
}
-if(oshape.size() == 0) {
+if (oshape.size() == 0) {
// Handles the case where all axes are squeezed.
oshape.push_back(1);
}

View file

@@ -0,0 +1,49 @@
"""Unittest cases for fold_axis"""
import nnvm
from nnvm import symbol as sym
from nnvm.compiler import graph_util, graph_attr

def test_fold_axis_conv():
    def before(x, conv_weight, conv_bias, scale, channels):
        y = sym.conv2d(x, conv_weight, conv_bias,
                       channels=channels,
                       kernel_size=(3, 3),
                       padding=(1, 1),
                       name="conv")
        y = sym.relu(y)
        y = y * sym.expand_dims(scale, axis=1, num_newaxis=2)
        return y

    def expected(x, conv_weight, conv_bias, scale, channels):
        conv_weight = conv_weight * sym.expand_dims(scale, axis=1, num_newaxis=3)
        conv_bias = conv_bias * scale
        y = sym.conv2d(x,
                       conv_weight,
                       conv_bias,
                       channels=channels,
                       kernel_size=(3, 3),
                       padding=(1, 1),
                       name="conv")
        y = sym.relu(y)
        return y

    # Before simplify
    def check(shape, channels):
        x = sym.Variable("x") + 1
        weight = sym.Variable("weight")
        bias = sym.Variable("bias")
        scale = sym.Variable("scale")
        y1 = before(x, weight, bias, scale, channels)
        y2 = expected(x, weight, bias, scale, channels)
        ishape = {"x": shape, "scale": (channels,)}
        g1 = nnvm.graph.create(y1)
        g2 = nnvm.graph.create(y2)
        graph_attr.set_shape_inputs(g1, ishape)
        g1 = g1.apply("InferShape").apply("FoldScaleAxis")
        # assert graph equals as expected
        graph_util.check_graph_equal(g1, g2)

    check((2, 4, 10, 10), 2)

if __name__ == "__main__":
    test_fold_axis_conv()

View file

@@ -14,8 +14,8 @@ def test_simplify_batchnorm():
        # for 2D
        num_newaxis=len(shape) - axis - 1
        if num_newaxis:
-            scale = sym.expand_dims(scale, axis=axis, num_newaxis=num_newaxis)
-            shift = sym.expand_dims(shift, axis=axis, num_newaxis=num_newaxis)
+            scale = sym.expand_dims(scale, axis=1, num_newaxis=num_newaxis)
+            shift = sym.expand_dims(shift, axis=1, num_newaxis=num_newaxis)
        return x * scale + shift
@@ -39,8 +39,6 @@ def test_simplify_batchnorm():
        g2 = nnvm.graph.create(y2)
        graph_attr.set_shape_inputs(g, ishape)
        g1 = g.apply("InferShape").apply("SimplifyInference")
-        # Some prints for debug
-        # print(g1.ir())
        # assert graph equals as expected
        graph_util.check_graph_equal(g1, g2)

View file

@@ -28,7 +28,8 @@ def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape
    new_sym, params = frontend.from_mxnet(symbol, args, auxs)
    dshape = x.shape
    shape_dict = {'data': dshape}
-    graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
+    with nnvm.compiler.build_config(opt_level=3):
+        graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
    m = graph_runtime.create(graph, lib, ctx)
    # set inputs
    m.set_input("data", tvm.nd.array(x.astype(dtype)))