[PASS] FoldScaleAxis (#55)
* [PASS] FoldScaleAxis * Move FoldAxis to O3 * Set unroll to 0 when ready
This commit is contained in:
Родитель
3f599a60be
Коммит
d25138e6c9
|
@ -98,6 +98,8 @@ class IndexedGraph {
|
|||
array_view<NodeEntry> inputs;
|
||||
/*! \brief control flow dependencies to the node */
|
||||
array_view<uint32_t> control_deps;
|
||||
/*! \brief weak reference to node */
|
||||
std::weak_ptr<nnvm::Node> weak_ref;
|
||||
};
|
||||
/*! \return number of nodes in the graph */
|
||||
inline size_t num_nodes() const {
|
||||
|
|
|
@ -11,7 +11,8 @@ from .. import graph as _graph
|
|||
OPT_PASS_LEVEL = {
|
||||
"SimplifyInference": 0,
|
||||
"PrecomputePrune": 2,
|
||||
"OpFusion": 1
|
||||
"OpFusion": 1,
|
||||
"FoldScaleAxis": 3
|
||||
}
|
||||
|
||||
# List of optimization pass and level when switch on
|
||||
|
@ -144,6 +145,10 @@ def optimize(graph, shape, dtype="float32"):
|
|||
if cfg.pass_enabled("SimplifyInference"):
|
||||
graph = graph_attr.set_shape_inputs(graph, shape)
|
||||
graph = graph.apply(["InferShape", "SimplifyInference"])
|
||||
|
||||
if cfg.pass_enabled("FoldScaleAxis"):
|
||||
graph = graph_attr.set_shape_inputs(graph, shape)
|
||||
graph = graph.apply(["InferShape", "FoldScaleAxis"])
|
||||
return graph
|
||||
|
||||
|
||||
|
@ -291,5 +296,6 @@ def precompute_prune(graph, params):
|
|||
out_names = pre_graph.json_attr("output_names")
|
||||
if not pre_graph.symbol.list_output_names():
|
||||
return graph, params
|
||||
out_arrs = _run_graph(pre_graph, params)
|
||||
with tvm.build_config(auto_unroll_max_step=0):
|
||||
out_arrs = _run_graph(pre_graph, params)
|
||||
return graph, dict(zip(out_names, out_arrs))
|
||||
|
|
|
@ -81,7 +81,6 @@ class Xavier(Initializer):
|
|||
self.factor_type = factor_type
|
||||
self.magnitude = float(magnitude)
|
||||
|
||||
|
||||
def _init_weight(self, name, arr):
|
||||
shape = arr.shape
|
||||
hw_scale = 1.
|
||||
|
|
|
@ -30,7 +30,8 @@ def separable_conv_block(data, name, depthwise_channels,
|
|||
# depthwise convolution + bn + relu
|
||||
conv1 = sym.conv2d(data=data, channels=depthwise_channels,
|
||||
groups=depthwise_channels, kernel_size=kernel_size, strides=strides,
|
||||
padding=padding, use_bias=False, layout="NCHW", name=name + "_depthwise_conv1")
|
||||
padding=padding, use_bias=False, layout="NCHW",
|
||||
name=name + "_depthwise_conv1")
|
||||
bn1 = sym.batch_norm(data=conv1, epsilon=epsilon, name=name + "_bn1")
|
||||
act1 = sym.relu(data=bn1, name=name + "_relu1")
|
||||
# pointwise convolution + bn + relu
|
||||
|
|
|
@ -46,7 +46,7 @@ def create_workload(net, batch_size, image_shape=(3, 224, 224),
|
|||
input_shapes, _ = graph_util.infer_shape(g, data=data_shape)
|
||||
shape_dict = dict(zip(g.index.input_names, input_shapes))
|
||||
np.random.seed(seed)
|
||||
initializer = initializer if initializer else Xavier(magnitude=3)
|
||||
initializer = initializer if initializer else Xavier()
|
||||
for k, v in shape_dict.items():
|
||||
if k == "data":
|
||||
continue
|
||||
|
|
|
@ -7,8 +7,7 @@ The following components are operator invariant.
|
|||
- core: NNVM core data structure
|
||||
- pass: NNVM pass
|
||||
|
||||
The following components are generic graph compiler for NNVM-TOP
|
||||
The following components are generic NNVM compiler and defines tensor operator set
|
||||
|
||||
- top: NNVM-TOP core operator defs
|
||||
- tvm: NNVM-TOP to TVM compiler toolchain
|
||||
- runtime: NNVM-TOP runtime
|
||||
- top: NNVM core tensor operators
|
||||
- compiler: NNVM compiler toolchain
|
||||
|
|
|
@ -58,7 +58,7 @@ class CompileEngine {
|
|||
return it->second->graph_func;
|
||||
}
|
||||
GraphFunc f = DoLower(key->graph, key->inputs, key->target,
|
||||
schedule_op_key, schedule_op_attr);
|
||||
schedule_op_key, schedule_op_attr);
|
||||
std::shared_ptr<GraphCacheEntryNode> n = std::make_shared<GraphCacheEntryNode>();
|
||||
n->graph_func = f;
|
||||
n->use_count = 1;
|
||||
|
|
|
@ -0,0 +1,271 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file fold_scale_axis.cc
|
||||
* \author Fold scaling parameter of axis into weight of conv/dense
|
||||
*/
|
||||
#include <nnvm/graph.h>
|
||||
#include <nnvm/op_attr_types.h>
|
||||
#include <nnvm/graph_attr_types.h>
|
||||
#include <nnvm/pass.h>
|
||||
#include <nnvm/compiler/op_attr_types.h>
|
||||
#include <nnvm/top/nn.h>
|
||||
#include "./pattern_util.h"
|
||||
#include "./graph_transform.h"
|
||||
|
||||
namespace nnvm {
|
||||
namespace compiler {
|
||||
|
||||
enum FoldScaleKind {
|
||||
// No folding is applied
|
||||
kNone,
|
||||
// The folding decision is pending
|
||||
kPending,
|
||||
// The original operator that contains the scale.
|
||||
kProvider,
|
||||
// Pass through the scale to parent/child to the first axis.
|
||||
kPassTroughFirst,
|
||||
// The final conumer of axis scale using multiply
|
||||
// Likely be a conv or dense operator.
|
||||
kMulConsumer,
|
||||
// The final conumer of axis scale using division
|
||||
kDivConsumer
|
||||
};
|
||||
|
||||
// Input fold information
|
||||
struct FoldScaleInput {
|
||||
uint32_t index;
|
||||
int axis;
|
||||
};
|
||||
|
||||
// The entry of folding chains on which
|
||||
// we should perform folding on
|
||||
struct FoldChainEntry {
|
||||
// Entry kind
|
||||
FoldScaleKind kind{kNone};
|
||||
// The output axis to be folded
|
||||
int axis{0};
|
||||
// Source node in the fold chain
|
||||
int source{0};
|
||||
// Following field only used by provider.
|
||||
// The input index
|
||||
int fold_input_index{1};
|
||||
// The scale entry
|
||||
NodeEntry scale_entry;
|
||||
};
|
||||
|
||||
// Try to pass axis scaling to backward,
|
||||
// Given that we we know the status of current fold axis.
|
||||
using FScaleAxisBackward = std::function<
|
||||
FoldScaleKind(const NodeAttrs& attrs,
|
||||
int axis,
|
||||
const std::vector<TShape>& in_shape,
|
||||
const std::vector<TShape>& out_shape,
|
||||
std::vector<std::pair<uint32_t, int> >* in_axis)>;
|
||||
|
||||
// Detect if there is a scaling axis happening
|
||||
bool DetectScaleAxis(const IndexedGraph& idx,
|
||||
uint32_t nid,
|
||||
const ShapeVector& shape_vec,
|
||||
const std::vector<uint32_t>& ref_count,
|
||||
bool is_forward,
|
||||
std::vector<FoldChainEntry>* chain) {
|
||||
const IndexedGraph::Node& inode = idx[nid];
|
||||
static const Op* bcast_mul = Op::Get("broadcast_mul");
|
||||
static const Op* expand_dims = Op::Get("expand_dims");
|
||||
if (inode.source->op() != bcast_mul) return false;
|
||||
const TShape& oshape = shape_vec[idx.entry_id(nid, 0)];
|
||||
CHECK_NE(oshape.ndim(), 0);
|
||||
if (oshape.ndim() <= 1) return false;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
const IndexedGraph::NodeEntry& a = inode.inputs[i];
|
||||
const IndexedGraph::NodeEntry& b = inode.inputs[1 - i];
|
||||
std::pair<int, int> axis =
|
||||
MatchBroadcast1DAxis(oshape, shape_vec[idx.entry_id(a)]);
|
||||
if (axis.first != -1 &&
|
||||
shape_vec[idx.entry_id(b)] == oshape) {
|
||||
if (ref_count[a.node_id] != 1) return false;
|
||||
if (is_forward && ref_count[nid] != 1) return false;
|
||||
if (!is_forward && ref_count[b.node_id] != 1) return false;
|
||||
const IndexedGraph::Node& anode = idx[a.node_id];
|
||||
// mark the current entry.
|
||||
FoldChainEntry& e = (*chain)[nid];
|
||||
if (anode.source->is_variable()) {
|
||||
e.fold_input_index = 1 - i;
|
||||
e.scale_entry = inode.source->inputs[1 - i];
|
||||
} else if (anode.source->op() == expand_dims &&
|
||||
shape_vec[idx.entry_id(anode.source->inputs[0])].ndim() == 1) {
|
||||
e.fold_input_index = 1 - i;
|
||||
e.scale_entry = anode.source->inputs[0];
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
e.axis = axis.first;
|
||||
e.kind = kPending;
|
||||
e.source = nid;
|
||||
if (!is_forward) {
|
||||
// pass message to another input
|
||||
FoldChainEntry& enext = (*chain)[b.node_id];
|
||||
enext.axis = e.axis;
|
||||
enext.kind = kPending;
|
||||
enext.source = nid;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Graph FoldScaleAxis(Graph src) {
|
||||
// Operator pattern
|
||||
static auto& fbackward =
|
||||
nnvm::Op::GetAttr<FScaleAxisBackward>("FScaleAxisBackward");
|
||||
const IndexedGraph& idx = src.indexed_graph();
|
||||
const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
|
||||
std::vector<uint32_t> ref_count = GetNodeRefCounts(idx);
|
||||
std::vector<FoldChainEntry> bwd_chain(idx.num_nodes());
|
||||
// shape hint for the inference.
|
||||
std::vector<TShape> in_shape, out_shape;
|
||||
// perform backward folding.
|
||||
for (uint32_t i = idx.num_nodes(); i != 0; --i) {
|
||||
uint32_t nid = i - 1;
|
||||
const auto& inode = idx[nid];
|
||||
if (inode.source->is_variable()) continue;
|
||||
if (DetectScaleAxis(idx, nid, shape_vec,
|
||||
ref_count, false, &bwd_chain)) continue;
|
||||
if (bwd_chain[nid].kind != kPending) continue;
|
||||
if (ref_count[nid] != 1 || !fbackward.count(inode.source->op())) {
|
||||
bwd_chain[nid].kind = kNone; continue;
|
||||
}
|
||||
// get input shape and output shape.
|
||||
in_shape.clear(); out_shape.clear();
|
||||
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
|
||||
in_shape.push_back(shape_vec[idx.entry_id(e)]);
|
||||
}
|
||||
for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
|
||||
out_shape.push_back(shape_vec[idx.entry_id(nid, i)]);
|
||||
}
|
||||
std::vector<std::pair<uint32_t, int> > in_axis;
|
||||
FoldScaleKind kind =
|
||||
fbackward[inode.source->op()](
|
||||
inode.source->attrs, bwd_chain[nid].axis,
|
||||
in_shape, out_shape, &in_axis);
|
||||
bwd_chain[nid].kind = kind;
|
||||
if (kind == kNone) continue;
|
||||
CHECK_GE(in_axis.size(), 1U);
|
||||
CHECK(kind == kPassTroughFirst || kMulConsumer);
|
||||
// propagate back.
|
||||
bool can_prop = true;
|
||||
for (size_t i = 0; i < in_axis.size(); ++i) {
|
||||
const IndexedGraph::NodeEntry& e = inode.inputs[in_axis[0].first];
|
||||
if (ref_count[e.node_id] != 1 ||
|
||||
idx[e.node_id].source->num_outputs() != 1) {
|
||||
can_prop = false; break;
|
||||
}
|
||||
}
|
||||
if (!can_prop) continue;
|
||||
for (size_t i = 0; i < in_axis.size(); ++i) {
|
||||
const IndexedGraph::NodeEntry& e = inode.inputs[in_axis[i].first];
|
||||
if (kind == kPassTroughFirst && i == 0) {
|
||||
bwd_chain[e.node_id].kind = kPending;
|
||||
} else {
|
||||
bwd_chain[nid].kind = kNone;
|
||||
bwd_chain[e.node_id].kind = kMulConsumer;
|
||||
}
|
||||
bwd_chain[e.node_id].axis = in_axis[i].second;
|
||||
bwd_chain[e.node_id].source = bwd_chain[nid].source;
|
||||
}
|
||||
if (kind == kMulConsumer) {
|
||||
bwd_chain[bwd_chain[nid].source].kind = kProvider;
|
||||
}
|
||||
}
|
||||
auto transform = [&](uint32_t nid, const NodePtr& n, std::vector<NodeEntry>* ret) {
|
||||
const FoldChainEntry& e = bwd_chain[nid];
|
||||
if (e.kind == kMulConsumer && bwd_chain[e.source].kind == kProvider) {
|
||||
const FoldChainEntry& se = bwd_chain[e.source];
|
||||
CHECK_EQ(n->num_outputs(), 1);
|
||||
NodeEntry scale = ExpandBiasToMatchAxis(
|
||||
se.scale_entry,
|
||||
shape_vec[idx.entry_id(nid, 0)].ndim(),
|
||||
shape_vec[idx.entry_id(se.scale_entry)].ndim(),
|
||||
e.axis);
|
||||
*ret = {MakeNode("broadcast_mul", n->attrs.name + "_sc",
|
||||
{NodeEntry{n, 0, 0}, scale})};
|
||||
return true;
|
||||
} else if (e.kind == kProvider) {
|
||||
*ret = {n->inputs[e.fold_input_index]};
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
return GraphTransform(src, transform);
|
||||
}
|
||||
|
||||
NNVM_REGISTER_PASS(FoldScaleAxis)
|
||||
.set_body(FoldScaleAxis);
|
||||
|
||||
// property registration.
|
||||
FoldScaleKind ReluScaleAxisBackward(
|
||||
const NodeAttrs& attrs,
|
||||
int axis,
|
||||
const std::vector<TShape>& in_shape,
|
||||
const std::vector<TShape>& out_shape,
|
||||
std::vector<std::pair<uint32_t, int> >* in_axis) {
|
||||
in_axis->emplace_back(0, axis);
|
||||
return kPassTroughFirst;
|
||||
}
|
||||
|
||||
NNVM_REGISTER_OP(relu)
|
||||
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", ReluScaleAxisBackward);
|
||||
|
||||
NNVM_REGISTER_OP(leaky_relu)
|
||||
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", ReluScaleAxisBackward);
|
||||
|
||||
FoldScaleKind BroadcastAddSubScaleAxisBackward(
|
||||
const NodeAttrs& attrs,
|
||||
int axis,
|
||||
const std::vector<TShape>& in_shape,
|
||||
const std::vector<TShape>& out_shape,
|
||||
std::vector<std::pair<uint32_t, int> >* in_axis) {
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
std::pair<int, int> m = MatchBroadcast1DAxis(out_shape[0], in_shape[i]);
|
||||
if (m.second != -1 && in_shape[1 - i] == out_shape[0]) {
|
||||
in_axis->emplace_back(i, axis);
|
||||
in_axis->emplace_back(1 - i, m.second);
|
||||
return kPassTroughFirst;
|
||||
}
|
||||
}
|
||||
return kNone;
|
||||
}
|
||||
|
||||
NNVM_REGISTER_OP(broadcast_add)
|
||||
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", BroadcastAddSubScaleAxisBackward);
|
||||
|
||||
NNVM_REGISTER_OP(broadcast_sub)
|
||||
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", BroadcastAddSubScaleAxisBackward);
|
||||
|
||||
FoldScaleKind Conv2DScaleAxisBackward(
|
||||
const NodeAttrs& attrs,
|
||||
int axis,
|
||||
const std::vector<TShape>& in_shape,
|
||||
const std::vector<TShape>& out_shape,
|
||||
std::vector<std::pair<uint32_t, int> >* in_axis) {
|
||||
using top::Conv2DParam;
|
||||
const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
|
||||
// only optimize for nchw for now
|
||||
if (param.layout == top::kNCHW) {
|
||||
in_axis->emplace_back(1, 0);
|
||||
if (param.use_bias) {
|
||||
in_axis->emplace_back(2, 0);
|
||||
}
|
||||
return kMulConsumer;
|
||||
} else {
|
||||
return kNone;
|
||||
}
|
||||
}
|
||||
|
||||
NNVM_REGISTER_OP(conv2d)
|
||||
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", Conv2DScaleAxisBackward);
|
||||
|
||||
} // namespace compiler
|
||||
} // namespace nnvm
|
|
@ -16,6 +16,7 @@
|
|||
#include <dmlc/parameter.h>
|
||||
#include "./compile_engine.h"
|
||||
#include "./graph_runtime.h"
|
||||
#include "./pattern_util.h"
|
||||
|
||||
namespace nnvm {
|
||||
namespace compiler {
|
||||
|
@ -56,17 +57,10 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
|
|||
|
||||
// Reference counter of each op node
|
||||
// For now, always store result when an op is referred more than once.
|
||||
std::vector<uint32_t> ref_count(idx.num_nodes(), 0);
|
||||
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
|
||||
const auto& inode = idx[nid];
|
||||
if (inode.source->is_variable()) continue;
|
||||
for (const auto& e : inode.inputs) {
|
||||
++ref_count[e.node_id];
|
||||
}
|
||||
}
|
||||
std::vector<uint32_t> ref_count = GetNodeRefCounts(idx);
|
||||
for (const auto& e : idx.outputs()) {
|
||||
// this line will realize all the outputs
|
||||
ref_count[e.node_id] += 2;
|
||||
ref_count[e.node_id] += 1;
|
||||
}
|
||||
// Pattern for the subgraph
|
||||
std::vector<TOpPattern> pattern_vec(idx.num_nodes(), kOpaque);
|
||||
|
|
|
@ -20,7 +20,7 @@ namespace compiler {
|
|||
*
|
||||
* \param graph The original graph
|
||||
*
|
||||
* \param ftransform Function of (int nid, const Node* node, std::vector<NodeEntry>* out) -> bool
|
||||
* \param ftransform Function of (int nid, const NodePtr& node, std::vector<NodeEntry>* out) -> bool
|
||||
*
|
||||
* If empty vector is returned, it means original entries should be kept.
|
||||
*
|
||||
|
@ -36,7 +36,6 @@ Graph GraphTransform(Graph graph, FTransform ftransform) {
|
|||
// setup inputs and placeholder.
|
||||
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
|
||||
const auto& inode = idx[nid];
|
||||
if (inode.source->is_variable()) continue;
|
||||
bool need_copy = false;
|
||||
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
|
||||
if (updated[idx.entry_id(e)]) {
|
||||
|
@ -57,7 +56,7 @@ Graph GraphTransform(Graph graph, FTransform ftransform) {
|
|||
|
||||
if (!need_copy) {
|
||||
std::vector<NodeEntry> ret;
|
||||
if (ftransform(nid, inode.source, &ret)) {
|
||||
if (ftransform(nid, inode.weak_ref.lock(), &ret)) {
|
||||
CHECK_EQ(ret.size(), static_cast<size_t>(inode.source->num_outputs()));
|
||||
for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) {
|
||||
updated[idx.entry_id(nid, i)] = true;
|
||||
|
@ -93,7 +92,7 @@ Graph GraphTransform(Graph graph, FTransform ftransform) {
|
|||
node->control_deps.push_back(selected_ptr);
|
||||
}
|
||||
std::vector<NodeEntry> ret;
|
||||
if (ftransform(nid, node.get(), &ret)) {
|
||||
if (ftransform(nid, node, &ret)) {
|
||||
CHECK_EQ(ret.size(), static_cast<size_t>(inode.source->num_outputs()));
|
||||
for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) {
|
||||
updated[idx.entry_id(nid, i)] = true;
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file pattern_util.h
|
||||
* \brief Utilities for doing various pattern matching in graph.
|
||||
*/
|
||||
#ifndef NNVM_COMPILER_PATTERN_UTIL_H_
|
||||
#define NNVM_COMPILER_PATTERN_UTIL_H_
|
||||
|
||||
#include <nnvm/graph.h>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
|
||||
namespace nnvm {
|
||||
namespace compiler {
|
||||
|
||||
/*!
|
||||
* \brief find axis in oshape, such that:
|
||||
* bias_shape = [1,1, ... oshape[axis], 1,1,]
|
||||
*
|
||||
* This is used to detect bias or scaling factor on channel dimension.
|
||||
* \param oshape The output shape
|
||||
* \param bias_shape The shape of bias or scaling factor.
|
||||
* \return Pair of matched axis in o shape and bias_shape if found.
|
||||
*/
|
||||
inline std::pair<int, int> MatchBroadcast1DAxis(
|
||||
const TShape& oshape, const TShape& bias_shape) {
|
||||
dim_t axis_dim = bias_shape.ndim();
|
||||
for (dim_t i = bias_shape.ndim(); i != 0; --i, --axis_dim) {
|
||||
if (bias_shape[i - 1] != 1) break;
|
||||
}
|
||||
// everything is 1
|
||||
if (axis_dim == 0) {
|
||||
return {oshape.ndim() - bias_shape.ndim(), 0};
|
||||
}
|
||||
axis_dim = axis_dim - 1;
|
||||
// The bias shape is not 1D
|
||||
for (dim_t i = 0; i < axis_dim; ++i) {
|
||||
if (bias_shape[i] != 1) return {-1, -1};
|
||||
}
|
||||
int axis = static_cast<int>(
|
||||
oshape.ndim() - bias_shape.ndim() + axis_dim);
|
||||
if (oshape[axis] != bias_shape[axis_dim]) return {-1, -1};
|
||||
return {axis, axis_dim};
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Expand bias dimension to match needed axis.
|
||||
*
|
||||
* \param bias The bias NodeEntry
|
||||
* \param out_dim output dimension.
|
||||
* \param bias_dim The current bias dimension.
|
||||
* \param axis The axis we want to match on.
|
||||
*/
|
||||
inline NodeEntry
|
||||
ExpandBiasToMatchAxis(NodeEntry bias,
|
||||
int out_dim,
|
||||
int bias_dim,
|
||||
int axis) {
|
||||
if (bias_dim != 1) {
|
||||
bias = MakeNode("squeeze", bias.node->attrs.name + "_sqz", {bias});
|
||||
}
|
||||
int num_pad_axis = out_dim - axis - 1;
|
||||
if (num_pad_axis > 0) {
|
||||
std::unordered_map<std::string, std::string> kwargs{
|
||||
{"axis", "1"},
|
||||
{"num_newaxis", std::to_string(num_pad_axis)}};
|
||||
return MakeNode("expand_dims", bias.node->attrs.name + "_expand",
|
||||
{bias}, kwargs);
|
||||
|
||||
} else {
|
||||
return bias;
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Get the reference count of each node.
|
||||
* \param idx The IndexedGraph
|
||||
* \return ref_count vector of length number nodes.
|
||||
*/
|
||||
inline std::vector<uint32_t>
|
||||
GetNodeRefCounts(const IndexedGraph& idx) {
|
||||
std::vector<uint32_t> ref_count(idx.num_nodes(), 0);
|
||||
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
|
||||
const auto& inode = idx[nid];
|
||||
if (inode.source->is_variable()) continue;
|
||||
for (const auto& e : inode.inputs) {
|
||||
++ref_count[e.node_id];
|
||||
}
|
||||
}
|
||||
for (const auto& e : idx.outputs()) {
|
||||
// this line will realize all the outputs
|
||||
ref_count[e.node_id] += 1;
|
||||
}
|
||||
return ref_count;
|
||||
}
|
||||
} // namespace compiler
|
||||
} // namespace nnvm
|
||||
#endif // NNVM_COMPILER_PATTERN_UTIL_H_
|
|
@ -1,6 +1,6 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file simplify_batch_norm.cc
|
||||
* \file simplify_inference.cc
|
||||
* \author Ziheng Jiang
|
||||
*/
|
||||
#include <nnvm/graph.h>
|
||||
|
@ -10,6 +10,7 @@
|
|||
#include <nnvm/compiler/op_attr_types.h>
|
||||
#include <nnvm/top/nn.h>
|
||||
#include "./graph_transform.h"
|
||||
#include "./pattern_util.h"
|
||||
|
||||
namespace nnvm {
|
||||
namespace compiler {
|
||||
|
@ -58,15 +59,9 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
|
|||
shift = MakeNode(
|
||||
"elemwise_add", bn_name + "_add_beta", {shift, beta});
|
||||
}
|
||||
// use expand dims to pad lower dims to 1
|
||||
int num_pad_axis = static_cast<int>(dshape.ndim() - param.axis) - 1;
|
||||
if (num_pad_axis != 0) {
|
||||
std::unordered_map<std::string, std::string> kwargs{
|
||||
{"axis", std::to_string(param.axis)},
|
||||
{"num_newaxis", std::to_string(num_pad_axis)}};
|
||||
scale = MakeNode("expand_dims", bn_name + "_sc_expand", {scale}, kwargs);
|
||||
shift = MakeNode("expand_dims", bn_name + "_sh_expand", {shift}, kwargs);
|
||||
}
|
||||
int axis = param.axis;
|
||||
scale = ExpandBiasToMatchAxis(scale, dshape.ndim(), 1, axis);
|
||||
shift = ExpandBiasToMatchAxis(shift, dshape.ndim(), 1, axis);
|
||||
NodeEntry out = MakeNode("broadcast_mul", bn_name + "_a_mul_data",
|
||||
{data, scale});
|
||||
out = MakeNode("broadcast_add", bn_name + "_out",
|
||||
|
@ -80,7 +75,7 @@ Graph SimplifyInference(nnvm::Graph src) {
|
|||
// Get attributes from the graph
|
||||
const IndexedGraph& idx = src.indexed_graph();
|
||||
const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
|
||||
auto transform = [&](uint32_t nid, const Node* n, std::vector<NodeEntry>* ret) {
|
||||
auto transform = [&](uint32_t nid, const NodePtr& n, std::vector<NodeEntry>* ret) {
|
||||
if (n->is_variable()) return false;
|
||||
static const Op* bn_op = Op::Get("batch_norm");
|
||||
static const Op* dropout_op = Op::Get("dropout");
|
||||
|
|
|
@ -28,6 +28,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
|
|||
// nodes_
|
||||
IndexedGraph::Node new_node;
|
||||
new_node.source = n.get();
|
||||
new_node.weak_ref = n;
|
||||
nodes_.emplace_back(std::move(new_node));
|
||||
// arg_nodes_
|
||||
if (n->is_variable()) {
|
||||
|
|
|
@ -460,21 +460,21 @@ inline bool SqueezeShape(const nnvm::NodeAttrs& attrs,
|
|||
std::vector<int64_t> oshape;
|
||||
if (param.axis.ndim() == 0) {
|
||||
for (dim_t i = 0; i < shp.ndim(); ++i) {
|
||||
if(shp[i] != 1) {
|
||||
if (shp[i] != 1) {
|
||||
oshape.emplace_back(shp[i]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::unordered_set<dim_t> axis_checker;
|
||||
for (size_t i = 0; i < param.axis.ndim(); ++i) {
|
||||
if(param.axis[i] < 0) {
|
||||
if (param.axis[i] < 0) {
|
||||
int real_axis = param.axis[i] + static_cast<int>(shp.ndim());
|
||||
CHECK(real_axis < static_cast<int>(shp.ndim()) && real_axis >= 0);
|
||||
axis_checker.insert(real_axis);
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < shp.ndim(); ++i) {
|
||||
if(axis_checker.find(i) == axis_checker.end()) {
|
||||
if (axis_checker.find(i) == axis_checker.end()) {
|
||||
oshape.emplace_back(shp[i]);
|
||||
} else {
|
||||
CHECK_EQ(shp[i], 1) << "The squeezed axis must have shape 1!"
|
||||
|
@ -483,7 +483,7 @@ inline bool SqueezeShape(const nnvm::NodeAttrs& attrs,
|
|||
}
|
||||
}
|
||||
}
|
||||
if(oshape.size() == 0) {
|
||||
if (oshape.size() == 0) {
|
||||
// Handles the case where all axes are squeezed.
|
||||
oshape.push_back(1);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
"""Unittest cases for fold_axis"""
|
||||
import nnvm
|
||||
from nnvm import symbol as sym
|
||||
from nnvm.compiler import graph_util, graph_attr
|
||||
|
||||
def test_fold_axis_conv():
|
||||
def before(x, conv_weight, conv_bias, scale, channels):
|
||||
y = sym.conv2d(x, conv_weight, conv_bias,
|
||||
channels=channels,
|
||||
kernel_size=(3, 3),
|
||||
padding=(1, 1),
|
||||
name="conv")
|
||||
y = sym.relu(y)
|
||||
y = y * sym.expand_dims(scale, axis=1, num_newaxis=2)
|
||||
return y
|
||||
|
||||
def expected(x, conv_weight, conv_bias, scale, channels):
|
||||
conv_weight = conv_weight * sym.expand_dims(scale, axis=1, num_newaxis=3)
|
||||
conv_bias = conv_bias * scale
|
||||
y = sym.conv2d(x,
|
||||
conv_weight,
|
||||
conv_bias,
|
||||
channels=channels,
|
||||
kernel_size=(3, 3),
|
||||
padding=(1, 1),
|
||||
name="conv")
|
||||
y = sym.relu(y)
|
||||
return y
|
||||
|
||||
# Before simplify
|
||||
def check(shape, channels):
|
||||
x = sym.Variable("x") + 1
|
||||
weight = sym.Variable("weight")
|
||||
bias = sym.Variable("bias")
|
||||
scale = sym.Variable("scale")
|
||||
y1 = before(x, weight, bias, scale, channels)
|
||||
y2 = expected(x, weight, bias, scale, channels)
|
||||
ishape = {"x": shape, "scale": (channels,)}
|
||||
g1 = nnvm.graph.create(y1)
|
||||
g2 = nnvm.graph.create(y2)
|
||||
graph_attr.set_shape_inputs(g1, ishape)
|
||||
g1 = g1.apply("InferShape").apply("FoldScaleAxis")
|
||||
# assert graph equals as expected
|
||||
graph_util.check_graph_equal(g1, g2)
|
||||
|
||||
check((2, 4, 10, 10), 2)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_fold_axis_conv()
|
|
@ -14,8 +14,8 @@ def test_simplify_batchnorm():
|
|||
# for 2D
|
||||
num_newaxis=len(shape) - axis - 1
|
||||
if num_newaxis:
|
||||
scale = sym.expand_dims(scale, axis=axis, num_newaxis=num_newaxis)
|
||||
shift = sym.expand_dims(shift, axis=axis, num_newaxis=num_newaxis)
|
||||
scale = sym.expand_dims(scale, axis=1, num_newaxis=num_newaxis)
|
||||
shift = sym.expand_dims(shift, axis=1, num_newaxis=num_newaxis)
|
||||
return x * scale + shift
|
||||
|
||||
|
||||
|
@ -39,8 +39,6 @@ def test_simplify_batchnorm():
|
|||
g2 = nnvm.graph.create(y2)
|
||||
graph_attr.set_shape_inputs(g, ishape)
|
||||
g1 = g.apply("InferShape").apply("SimplifyInference")
|
||||
# Some prints for debug
|
||||
# print(g1.ir())
|
||||
# assert graph equals as expected
|
||||
graph_util.check_graph_equal(g1, g2)
|
||||
|
||||
|
|
|
@ -28,7 +28,8 @@ def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape
|
|||
new_sym, params = frontend.from_mxnet(symbol, args, auxs)
|
||||
dshape = x.shape
|
||||
shape_dict = {'data': dshape}
|
||||
graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
|
||||
with nnvm.compiler.build_config(opt_level=3):
|
||||
graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
|
||||
m = graph_runtime.create(graph, lib, ctx)
|
||||
# set inputs
|
||||
m.set_input("data", tvm.nd.array(x.astype(dtype)))
|
||||
|
|
Загрузка…
Ссылка в новой задаче