[PASS] Enhance scale fold axis (#424)

Parent: 89c124bc89
Commit: a53d8d0172
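This commit reworks the FoldScaleAxis compiler pass: the single backward folding walk is split into paired backward and forward message-passing phases, per-node fold state moves into a shared FoldChainInfo record, and the operator hooks now return a bool reporting whether the scale signal was consumed. Folding through pooling, and forward folding into conv2d and broadcast add/sub, are newly supported. As a rough sketch of what the pass achieves, mirroring the updated unit test further below (shapes and channel counts here are arbitrary):

# Sketch only: the rewrite FoldScaleAxis performs on an NCHW conv.
from nnvm import symbol as sym

x = sym.Variable("x")                  # (N, C, H, W)
w = sym.Variable("weight")             # (8, C, 3, 3)
in_scale = sym.Variable("in_scale")    # (C,)
out_scale = sym.Variable("out_scale")  # (8,)

# Before: per-channel scales multiply the conv input and output ...
y = x * sym.expand_dims(in_scale, axis=1, num_newaxis=2)
y = sym.conv2d(y, w, use_bias=False, channels=8, kernel_size=(3, 3))
y = y * sym.expand_dims(out_scale, axis=1, num_newaxis=2)

# ... after the pass, both scales are folded into the weight instead:
w2 = w * sym.expand_dims(out_scale, axis=1, num_newaxis=3)  # output channels
w2 = w2 * sym.expand_dims(in_scale, axis=1, num_newaxis=2)  # input channels
y_folded = sym.conv2d(x, w2, use_bias=False, channels=8, kernel_size=(3, 3))

The diff touches src/compiler/fold_scale_axis.cc, a parenthesization fix in src/pass/plan_memory.cc, and the fold-axis unit tests.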
src/compiler/fold_scale_axis.cc

@@ -18,12 +18,10 @@ namespace compiler {
 enum FoldScaleKind {
   // No folding is applied
   kNone,
-  // The folding decision is pending
+  // The folding decision is pending, we can fold on a state.
   kPending,
   // The original operator that contains the scale.
   kProvider,
-  // Pass through the scale to parent/child to the first axis.
-  kPassTroughFirst,
   // The final consumer of axis scale using multiply
   // Likely be a conv or dense operator.
   kMulConsumer,
@@ -31,21 +29,23 @@ enum FoldScaleKind {
   kDivConsumer
 };
 
-// Input fold information
-struct FoldScaleInput {
-  uint32_t index;
-  int axis;
-};
-
-// The entry of folding chains on which
-// we should perform folding
-struct FoldChainEntry {
+struct FoldChainInfo {
   // Entry kind
   FoldScaleKind kind{kNone};
   // The output axis to be folded
   int axis{0};
   // Source node in the fold chain
   int source{0};
+};
+
+// The entry of folding chains on which
+// we should perform folding
+struct FoldChainEntry {
+  // Fold information
+  FoldChainInfo info;
+  // Number of outgoing forks
+  // in forward propagation.
+  int fork_count{0};
   // Following field only used by provider.
   // The input index
   int fold_input_index{1};
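In the new layout, the per-node fold state that both phases share (kind, axis, source) lives in FoldChainInfo, while FoldChainEntry keeps only the bookkeeping the chain endpoints need: fork_count counts how many forward paths must consume a pending scale before its source may become a provider, and fold_input_index plus scale_entry identify the scale operand on the provider node. A minimal sketch of how DetectScaleAxis (further below) seeds an entry `e` for a detected broadcast_mul:

// Sketch, not part of the diff: seeding a chain entry at detection time.
e.info.kind = kPending;    // fold decision deferred until a consumer is found
e.info.axis = axis.first;  // output axis that carries the scale
e.info.source = nid;       // node owning the scale (the future provider)
e.fork_count = 1;          // one forward path must consume the signal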
@@ -55,12 +55,26 @@ struct FoldChainEntry {
 
 // Try to pass axis scaling to backward,
 // Given that we know the status of the current fold axis.
+// return whether the backward signal is consumed.
 using FScaleAxisBackward = std::function<
-  FoldScaleKind(const NodeAttrs& attrs,
-                int axis,
-                const std::vector<TShape>& in_shape,
-                const std::vector<TShape>& out_shape,
-                std::vector<std::pair<uint32_t, int> >* in_axis)>;
+  bool(const NodeAttrs& attrs,
+       const std::vector<TShape>& in_shape,
+       const std::vector<TShape>& out_shape,
+       const FoldChainInfo& out_info,
+       std::vector<FoldChainInfo>* in_info)>;
+
+// Try to pass axis scaling to forward,
+// Given that we know the status of one of its inputs to be pending,
+// also update the other input info.
+// return whether the forward signal is consumed.
+using FScaleAxisForward = std::function<
+  bool(const NodeAttrs& attrs,
+       const std::vector<TShape>& in_shape,
+       const std::vector<TShape>& out_shape,
+       std::vector<FoldChainInfo>* in_info,
+       FoldChainInfo* out_info)>;
 
 // Detect if there is a scaling axis happening
 bool DetectScaleAxis(const IndexedGraph& idx,
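A hook that merely forwards the scale copies the pending info along and returns false, meaning the signal was propagated but not consumed; true consumers such as conv2d return true. A sketch for a hypothetical one-input elementwise op, mirroring ReluScaleAxisForward registered further below:

// Sketch only: pass-through forward hook for a hypothetical elementwise op.
bool ElemwiseScaleAxisForward(const NodeAttrs& attrs,
                              const std::vector<TShape>& in_shape,
                              const std::vector<TShape>& out_shape,
                              std::vector<FoldChainInfo>* in_info,
                              FoldChainInfo* out_info) {
  *out_info = (*in_info)[0];  // forward the pending scale unchanged
  return false;               // not consumed; keep searching downstream
}
// Registration would follow the same pattern as the relu ops below, e.g.
// NNVM_REGISTER_OP(my_op)
// .set_attr<FScaleAxisForward>("FScaleAxisForward", ElemwiseScaleAxisForward);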
@@ -99,15 +113,19 @@ bool DetectScaleAxis(const IndexedGraph& idx,
     } else {
       return false;
     }
-    e.axis = axis.first;
-    e.kind = kPending;
-    e.source = nid;
+    e.info.axis = axis.first;
+    e.info.kind = kPending;
+    e.info.source = nid;
+    e.fork_count = 1;
+    // In the backward message passing
+    // we need to eagerly pass it to the input;
+    // in the forward message passing
+    // we will "pull" the message from the input.
     if (!is_forward) {
-      // pass message to another input
       FoldChainEntry& enext = (*chain)[b.node_id];
-      enext.axis = e.axis;
-      enext.kind = kPending;
-      enext.source = nid;
+      enext.info.axis = e.info.axis;
+      enext.info.kind = kPending;
+      enext.info.source = nid;
     }
     return true;
   }
@@ -119,12 +137,16 @@ Graph FoldScaleAxis(Graph src) {
   // Operator pattern
   static auto& fbackward =
       nnvm::Op::GetAttr<FScaleAxisBackward>("FScaleAxisBackward");
+  static auto& fforward =
+      nnvm::Op::GetAttr<FScaleAxisForward>("FScaleAxisForward");
   const IndexedGraph& idx = src.indexed_graph();
   const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
   std::vector<uint32_t> ref_count = GetNodeRefCounts(idx);
   std::vector<FoldChainEntry> bwd_chain(idx.num_nodes());
+  std::vector<FoldChainEntry> fwd_chain(idx.num_nodes());
   // shape hint for the inference.
   std::vector<TShape> in_shape, out_shape;
 
   // perform backward folding.
   for (uint32_t i = idx.num_nodes(); i != 0; --i) {
     uint32_t nid = i - 1;
@@ -132,9 +154,10 @@ Graph FoldScaleAxis(Graph src) {
     if (inode.source->is_variable()) continue;
     if (DetectScaleAxis(idx, nid, shape_vec,
                         ref_count, false, &bwd_chain)) continue;
-    if (bwd_chain[nid].kind != kPending) continue;
+    if (bwd_chain[nid].info.kind != kPending) continue;
+    // if referred by multiple nodes, cannot do propagation
     if (ref_count[nid] != 1 || !fbackward.count(inode.source->op())) {
-      bwd_chain[nid].kind = kNone; continue;
+      bwd_chain[nid].info.kind = kNone; continue;
     }
     // get input shape and output shape.
     in_shape.clear(); out_shape.clear();
@@ -144,58 +167,151 @@ Graph FoldScaleAxis(Graph src) {
     for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
       out_shape.push_back(shape_vec[idx.entry_id(nid, i)]);
     }
-    std::vector<std::pair<uint32_t, int> > in_axis;
-    FoldScaleKind kind =
-        fbackward[inode.source->op()](
-            inode.source->attrs, bwd_chain[nid].axis,
-            in_shape, out_shape, &in_axis);
-    bwd_chain[nid].kind = kind;
-    if (kind == kNone) continue;
-    CHECK_GE(in_axis.size(), 1U);
-    CHECK(kind == kPassTroughFirst || kind == kMulConsumer);
+    std::vector<FoldChainInfo> in_info(in_shape.size(), FoldChainInfo());
+    bool consumed = fbackward[inode.source->op()](
+        inode.source->attrs,
+        in_shape,
+        out_shape,
+        bwd_chain[nid].info,
+        &in_info);
+    CHECK_EQ(in_info.size(), in_shape.size());
     // propagate back.
     bool can_prop = true;
-    for (size_t i = 0; i < in_axis.size(); ++i) {
-      const IndexedGraph::NodeEntry& e = inode.inputs[in_axis[0].first];
+    for (size_t i = 0; i < in_info.size(); ++i) {
+      const IndexedGraph::NodeEntry& e = inode.inputs[i];
       if (ref_count[e.node_id] != 1 ||
           idx[e.node_id].source->num_outputs() != 1) {
         can_prop = false; break;
       }
     }
     if (!can_prop) continue;
-    for (size_t i = 0; i < in_axis.size(); ++i) {
-      const IndexedGraph::NodeEntry& e = inode.inputs[in_axis[i].first];
-      if (kind == kPassTroughFirst && i == 0) {
-        bwd_chain[e.node_id].kind = kPending;
-      } else {
-        bwd_chain[nid].kind = kNone;
-        bwd_chain[e.node_id].kind = kMulConsumer;
-      }
-      bwd_chain[e.node_id].axis = in_axis[i].second;
-      bwd_chain[e.node_id].source = bwd_chain[nid].source;
+    for (size_t i = 0; i < in_info.size(); ++i) {
+      const IndexedGraph::NodeEntry& e = inode.inputs[i];
+      bwd_chain[e.node_id].info = in_info[i];
     }
-    if (kind == kMulConsumer) {
-      bwd_chain[bwd_chain[nid].source].kind = kProvider;
+    // mark consumed by making the source as provider.
+    if (consumed) {
+      bwd_chain[bwd_chain[nid].info.source].info.kind = kProvider;
     }
   }
-  auto transform = [&](uint32_t nid, const NodePtr& n, std::vector<NodeEntry>* ret) {
-    const FoldChainEntry& e = bwd_chain[nid];
-    if (e.kind == kMulConsumer && bwd_chain[e.source].kind == kProvider) {
-      const FoldChainEntry& se = bwd_chain[e.source];
-      CHECK_EQ(n->num_outputs(), 1);
-      NodeEntry scale = ExpandBiasToMatchAxis(
-          se.scale_entry,
-          shape_vec[idx.entry_id(nid, 0)].ndim(),
-          shape_vec[idx.entry_id(se.scale_entry)].ndim(),
-          e.axis);
-      *ret = {MakeNode("broadcast_mul", n->attrs.name + "_sc",
-                       {NodeEntry{n, 0, 0}, scale})};
-      return true;
-    } else if (e.kind == kProvider) {
-      *ret = {n->inputs[e.fold_input_index]};
-      return true;
+
+  // perform forward folding.
+  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
+    const auto& inode = idx[nid];
+    if (inode.source->is_variable()) continue;
+    // skip scales that are already folded in backward.
+    if (bwd_chain[nid].info.kind == kProvider) continue;
+    if (DetectScaleAxis(idx, nid, shape_vec,
+                        ref_count, true, &fwd_chain)) continue;
+    if (inode.source->num_outputs() != 1) continue;
+    // Do state update
+    // get input shape and output shape.
+    std::vector<FoldChainInfo> in_info;
+    FoldChainInfo out_info;
+    int num_inpending = 0;
+    in_shape.clear(); out_shape.clear();
+    for (const IndexedGraph::NodeEntry& e : inode.inputs) {
+      in_shape.push_back(shape_vec[idx.entry_id(e)]);
+      // input information
+      in_info.push_back(fwd_chain[e.node_id].info);
+      if (fwd_chain[e.node_id].info.kind == kPending) {
+        ++num_inpending;
+      }
+    }
+    for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
+      out_shape.push_back(shape_vec[idx.entry_id(nid, i)]);
+    }
+    if (num_inpending != 1 ||
+        !fforward.count(inode.source->op())) continue;
+    bool consumed = fforward[inode.source->op()](
+        inode.source->attrs,
+        in_shape,
+        out_shape,
+        &in_info,
+        &out_info);
+    // update input info
+    for (size_t i = 0; i < in_info.size(); ++i) {
+      fwd_chain[inode.inputs[i].node_id].info = in_info[i];
+    }
+    if (consumed) {
+      fwd_chain[nid].info = out_info;
+      for (size_t i = 0; i < in_info.size(); ++i) {
+        if (in_info[i].kind == kPending) {
+          if (--fwd_chain[in_info[i].source].fork_count == 0) {
+            fwd_chain[in_info[i].source].info.kind = kProvider;
+          }
+        }
+      }
     } else {
+      // can propagate condition
+      if (inode.source->num_outputs() == 1) {
+        fwd_chain[nid].info = out_info;
+        if (out_info.kind == kPending) {
+          // when there are multiple references to the input,
+          // every path has to be consumed
+          fwd_chain[out_info.source].fork_count += ref_count[nid] - 1;
+        }
+      }
+    }
+  }
+
+  auto transform = [&](uint32_t nid, const NodePtr& n, std::vector<NodeEntry>* ret) {
+    NodeEntry rvalue = NodeEntry{n, 0, 0};
+    {
+      // Backward chain
+      const FoldChainEntry& e = bwd_chain[nid];
+      if (e.info.kind == kMulConsumer &&
+          bwd_chain[e.info.source].info.kind == kProvider) {
+        const FoldChainEntry& se = bwd_chain[e.info.source];
+        CHECK_EQ(n->num_outputs(), 1);
+        NodeEntry scale = ExpandBiasToMatchAxis(
+            se.scale_entry,
+            shape_vec[idx.entry_id(nid, 0)].ndim(),
+            shape_vec[idx.entry_id(se.scale_entry)].ndim(),
+            e.info.axis);
+        rvalue = MakeNode("broadcast_mul", n->attrs.name + "_sc",
+                          {rvalue, scale});
+      } else if (e.info.kind == kProvider) {
+        rvalue = n->inputs[e.fold_input_index];
+      }
+    }
+    // Note that the value might get transformed twice if it
+    // folds value from both fwd and backward chain.
+    {
+      // forward chain
+      const FoldChainEntry& e = fwd_chain[nid];
+      if (e.info.kind == kMulConsumer &&
+          fwd_chain[e.info.source].info.kind == kProvider) {
+        const FoldChainEntry& se = fwd_chain[e.info.source];
+        CHECK_EQ(n->num_outputs(), 1);
+        NodeEntry scale = ExpandBiasToMatchAxis(
+            se.scale_entry,
+            shape_vec[idx.entry_id(nid, 0)].ndim(),
+            shape_vec[idx.entry_id(se.scale_entry)].ndim(),
+            e.info.axis);
+        rvalue = MakeNode("broadcast_mul", n->attrs.name + "_sc",
+                          {rvalue, scale});
+      } else if (e.info.kind == kDivConsumer &&
+                 fwd_chain[e.info.source].info.kind == kProvider) {
+        const FoldChainEntry& se = fwd_chain[e.info.source];
+        CHECK_EQ(n->num_outputs(), 1);
+        NodeEntry scale = ExpandBiasToMatchAxis(
+            se.scale_entry,
+            shape_vec[idx.entry_id(nid, 0)].ndim(),
+            shape_vec[idx.entry_id(se.scale_entry)].ndim(),
+            e.info.axis);
+        rvalue = MakeNode("broadcast_div", n->attrs.name + "_sc",
+                          {rvalue, scale});
+      } else if (e.info.kind == kProvider) {
+        rvalue = n->inputs[e.fold_input_index];
+      }
+    }
+    if (rvalue.node == n) {
       return false;
+    } else {
+      *ret = {rvalue};
+      return true;
     }
   };
   return GraphTransform(src, transform);
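The fork_count bookkeeping above is what makes forward folding safe under fan-out: a pending scale that reaches a node with several consumers must be consumed on every outgoing path before its source becomes a provider. A worked trace on a hypothetical graph:

// s = x * scale      -> DetectScaleAxis: kind = kPending, source = s, fork_count = 1
// r = relu(s)        -> pass-through; ref_count[r] == 2: fork_count += 1 (now 2)
// a = conv2d(r, w1)  -> consumed: --fork_count (now 1), s not yet a provider
// b = conv2d(r, w2)  -> consumed: --fork_count (now 0), s becomes kProvider

Only once the count reaches zero does the transform strip the multiplication at the provider and re-apply the scale inside each consumer.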
@@ -205,14 +321,24 @@ NNVM_REGISTER_PASS(FoldScaleAxis)
 .set_body(FoldScaleAxis);
 
 // property registration.
-FoldScaleKind ReluScaleAxisBackward(
+bool ReluScaleAxisBackward(
     const NodeAttrs& attrs,
-    int axis,
     const std::vector<TShape>& in_shape,
     const std::vector<TShape>& out_shape,
-    std::vector<std::pair<uint32_t, int> >* in_axis) {
-  in_axis->emplace_back(0, axis);
-  return kPassTroughFirst;
+    const FoldChainInfo& out_info,
+    std::vector<FoldChainInfo>* in_axis) {
+  (*in_axis)[0] = out_info;
+  return false;
+}
+
+bool ReluScaleAxisForward(
+    const NodeAttrs& attrs,
+    const std::vector<TShape>& in_shape,
+    const std::vector<TShape>& out_shape,
+    std::vector<FoldChainInfo>* in_info,
+    FoldChainInfo* out_info) {
+  *out_info = (*in_info)[0];
+  return false;
 }
 
 NNVM_REGISTER_OP(relu)
@@ -221,21 +347,102 @@ NNVM_REGISTER_OP(relu)
 NNVM_REGISTER_OP(leaky_relu)
 .set_attr<FScaleAxisBackward>("FScaleAxisBackward", ReluScaleAxisBackward);
 
-FoldScaleKind BroadcastAddSubScaleAxisBackward(
+NNVM_REGISTER_OP(relu)
+.set_attr<FScaleAxisForward>("FScaleAxisForward", ReluScaleAxisForward);
+
+NNVM_REGISTER_OP(leaky_relu)
+.set_attr<FScaleAxisForward>("FScaleAxisForward", ReluScaleAxisForward);
+
+// property registration.
+bool Pool2DBackward(
     const NodeAttrs& attrs,
-    int axis,
     const std::vector<TShape>& in_shape,
     const std::vector<TShape>& out_shape,
-    std::vector<std::pair<uint32_t, int> >* in_axis) {
+    const FoldChainInfo& out_info,
+    std::vector<FoldChainInfo>* in_axis) {
+  using top::Pool2DParam;
+  const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
+  if (out_info.axis == 1 && param.layout == top::kNCHW) {
+    (*in_axis)[0] = out_info;
+  }
+  return false;
+}
+
+bool Pool2DForward(
+    const NodeAttrs& attrs,
+    const std::vector<TShape>& in_shape,
+    const std::vector<TShape>& out_shape,
+    std::vector<FoldChainInfo>* in_info,
+    FoldChainInfo* out_info) {
+  using top::Pool2DParam;
+  const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
+  if ((*in_info)[0].axis == 1 && param.layout == top::kNCHW) {
+    *out_info = (*in_info)[0];
+  }
+  return false;
+}
+
+NNVM_REGISTER_OP(max_pool2d)
+.set_attr<FScaleAxisBackward>("FScaleAxisBackward", Pool2DBackward);
+
+NNVM_REGISTER_OP(avg_pool2d)
+.set_attr<FScaleAxisBackward>("FScaleAxisBackward", Pool2DBackward);
+
+NNVM_REGISTER_OP(max_pool2d)
+.set_attr<FScaleAxisForward>("FScaleAxisForward", Pool2DForward);
+
+NNVM_REGISTER_OP(avg_pool2d)
+.set_attr<FScaleAxisForward>("FScaleAxisForward", Pool2DForward);
+
+bool BroadcastAddSubScaleAxisBackward(
+    const NodeAttrs& attrs,
+    const std::vector<TShape>& in_shape,
+    const std::vector<TShape>& out_shape,
+    const FoldChainInfo& out_info,
+    std::vector<FoldChainInfo>* in_axis) {
+  if (out_info.kind != kPending) return false;
   for (int i = 0; i < 2; ++i) {
-    std::pair<int, int> m = MatchBroadcast1DAxis(out_shape[0], in_shape[i]);
-    if (m.second != -1 && in_shape[1 - i] == out_shape[0]) {
-      in_axis->emplace_back(i, axis);
-      in_axis->emplace_back(1 - i, m.second);
-      return kPassTroughFirst;
+    std::pair<int, int> m = MatchBroadcast1DAxis(out_shape[0], in_shape[1 - i]);
+    if (m.second != -1 &&
+        in_shape[i] == out_shape[0] &&
+        m.first == out_info.axis) {
+      (*in_axis)[i].kind = kPending;
+      (*in_axis)[i].axis = out_info.axis;
+      (*in_axis)[i].source = out_info.source;
+      (*in_axis)[1 - i].kind = kMulConsumer;
+      (*in_axis)[1 - i].axis = m.second;
+      (*in_axis)[1 - i].source = out_info.source;
+      return false;
     }
   }
-  return kNone;
+  return false;
+}
+
+bool BroadcastAddSubScaleAxisForward(
+    const NodeAttrs& attrs,
+    const std::vector<TShape>& in_shape,
+    const std::vector<TShape>& out_shape,
+    std::vector<FoldChainInfo>* in_info,
+    FoldChainInfo* out_info) {
+  for (int i = 0; i < 2; ++i) {
+    if ((*in_info)[i].kind == kPending) {
+      std::pair<int, int> m = MatchBroadcast1DAxis(out_shape[0], in_shape[1 - i]);
+      if (m.second != -1 &&
+          in_shape[i] == out_shape[0] &&
+          m.first == (*in_info)[i].axis) {
+        out_info->kind = kPending;
+        out_info->axis = m.first;
+        out_info->source = (*in_info)[i].source;
+        (*in_info)[1 - i].kind = kDivConsumer;
+        (*in_info)[1 - i].axis = m.second;
+        (*in_info)[1 - i].source = (*in_info)[i].source;
+        return false;
+      }
+    }
+  }
+  return false;
 }
 
 NNVM_REGISTER_OP(broadcast_add)

@@ -244,28 +451,62 @@ NNVM_REGISTER_OP(broadcast_add)
 NNVM_REGISTER_OP(broadcast_sub)
 .set_attr<FScaleAxisBackward>("FScaleAxisBackward", BroadcastAddSubScaleAxisBackward);
 
-FoldScaleKind Conv2DScaleAxisBackward(
+NNVM_REGISTER_OP(broadcast_add)
+.set_attr<FScaleAxisForward>("FScaleAxisForward", BroadcastAddSubScaleAxisForward);
+
+NNVM_REGISTER_OP(broadcast_sub)
+.set_attr<FScaleAxisForward>("FScaleAxisForward", BroadcastAddSubScaleAxisForward);
+
+bool Conv2DScaleAxisBackward(
     const NodeAttrs& attrs,
-    int axis,
     const std::vector<TShape>& in_shape,
     const std::vector<TShape>& out_shape,
-    std::vector<std::pair<uint32_t, int> >* in_axis) {
+    const FoldChainInfo& out_info,
+    std::vector<FoldChainInfo>* in_axis) {
   using top::Conv2DParam;
   const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
+  if (out_info.kind != kPending) return false;
   // only optimize for nchw for now
-  if (param.layout == top::kNCHW) {
-    in_axis->emplace_back(1, 0);
+  if (param.layout == top::kNCHW && out_info.axis == 1) {
+    (*in_axis)[1].kind = kMulConsumer;
+    (*in_axis)[1].axis = 0;
+    (*in_axis)[1].source = out_info.source;
     if (param.use_bias) {
-      in_axis->emplace_back(2, 0);
+      (*in_axis)[2].kind = kMulConsumer;
+      (*in_axis)[2].axis = 0;
+      (*in_axis)[2].source = out_info.source;
     }
-    return kMulConsumer;
+    return true;
   } else {
-    return kNone;
+    return false;
+  }
+}
+
+bool Conv2DScaleAxisForward(
+    const NodeAttrs& attrs,
+    const std::vector<TShape>& in_shape,
+    const std::vector<TShape>& out_shape,
+    std::vector<FoldChainInfo>* in_info,
+    FoldChainInfo* out_info) {
+  using top::Conv2DParam;
+  const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
+  if ((*in_info)[0].kind != kPending) return false;
+  // only optimize for nchw for now
+  if (param.layout == top::kNCHW && (*in_info)[0].axis == 1) {
+    (*in_info)[1].kind = kMulConsumer;
+    (*in_info)[1].axis = 1;
+    (*in_info)[1].source = (*in_info)[0].source;
+    return true;
+  } else {
+    return false;
   }
 }
 
 NNVM_REGISTER_OP(conv2d)
 .set_attr<FScaleAxisBackward>("FScaleAxisBackward", Conv2DScaleAxisBackward);
+
+NNVM_REGISTER_OP(conv2d)
+.set_attr<FScaleAxisForward>("FScaleAxisForward", Conv2DScaleAxisForward);
 
 } // namespace compiler
 } // namespace nnvm
src/pass/plan_memory.cc

@@ -196,7 +196,7 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx,
       if (taken[kv.first] == false &&
           sid_out == GraphAllocator::kBadStorageID &&
          sid_in >= 0 &&
-          (storage_ref_count[sid_in] == 1 && !ignore_all_inputs || identity[ipair]) &&
+          ((storage_ref_count[sid_in] == 1 && !ignore_all_inputs) || identity[ipair]) &&
          entry_ref_count[eid_out] > 0 &&
          shape_vec[eid_out].Size() == shape_vec[eid_in].Size() &&
          dtype_vec[eid_out] == dtype_vec[eid_in]) {
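The one-line plan_memory.cc change above is precedence hygiene only: since `&&` binds tighter than `||`, `a && b || c` already parses as `(a && b) || c`, so the added parentheses keep the in-place-reuse condition identical while making the intent explicit (and likely silencing compiler warnings such as GCC's -Wparentheses).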
tests/python/compiler/test_fold_axis.py

@@ -1,22 +1,26 @@
 """Unittest cases for fold_axis"""
 import nnvm
+import nnvm.testing.resnet
+import numpy as np
 from nnvm import symbol as sym
 from nnvm.compiler import graph_util, graph_attr
 
 def test_fold_axis_conv():
-    def before(x, conv_weight, conv_bias, scale, channels):
+    def before(x, conv_weight, conv_bias, in_scale, out_scale, channels):
+        x = x * sym.expand_dims(in_scale, axis=1, num_newaxis=2)
         y = sym.conv2d(x, conv_weight, conv_bias,
                        channels=channels,
                        kernel_size=(3, 3),
                        padding=(1, 1),
                        name="conv")
         y = sym.relu(y)
-        y = y * sym.expand_dims(scale, axis=1, num_newaxis=2)
+        y = y * sym.expand_dims(out_scale, axis=1, num_newaxis=2)
         return y
 
-    def expected(x, conv_weight, conv_bias, scale, channels):
-        conv_weight = conv_weight * sym.expand_dims(scale, axis=1, num_newaxis=3)
-        conv_bias = conv_bias * scale
+    def expected(x, conv_weight, conv_bias, in_scale, out_scale, channels):
+        conv_weight = conv_weight * sym.expand_dims(out_scale, axis=1, num_newaxis=3)
+        conv_weight = conv_weight * sym.expand_dims(in_scale, axis=1, num_newaxis=2)
+        conv_bias = conv_bias * out_scale
         y = sym.conv2d(x,
                        conv_weight,
                        conv_bias,
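The algebra behind the updated expected() graph above can be checked numerically: scaling the conv input per channel equals scaling the weight along its input-channel axis, and scaling the output equals scaling the weight (and bias) along its output-channel axis. A quick NumPy check for the bias-free 1x1-conv case, which reduces to a plain matrix product (shapes arbitrary):

# Sketch: conv(x * s_in) * s_out == conv'(x) with both scales folded into W.
import numpy as np

cin, cout, npix = 4, 3, 10
x = np.random.randn(cin, npix)   # flattened spatial pixels
W = np.random.randn(cout, cin)   # a 1x1 conv is a matrix multiply
s_in, s_out = np.random.randn(cin), np.random.randn(cout)

lhs = s_out[:, None] * (W @ (s_in[:, None] * x))  # scale input and output
rhs = (s_out[:, None] * W * s_in[None, :]) @ x    # fold both scales into W
np.testing.assert_allclose(lhs, rhs, rtol=1e-8, atol=1e-10)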
@@ -32,10 +36,11 @@ def test_fold_axis_conv():
         x = sym.Variable("x") + 1
         weight = sym.Variable("weight")
         bias = sym.Variable("bias")
-        scale = sym.Variable("scale")
-        y1 = before(x, weight, bias, scale, channels)
-        y2 = expected(x, weight, bias, scale, channels)
-        ishape = {"x": shape, "scale": (channels,)}
+        in_scale = sym.Variable("in_scale")
+        out_scale = sym.Variable("out_scale")
+        y1 = before(x, weight, bias, in_scale, out_scale, channels)
+        y2 = expected(x, weight, bias, in_scale, out_scale, channels)
+        ishape = {"x": shape, "out_scale": (channels,), "in_scale": (shape[1],)}
         g1 = nnvm.graph.create(y1)
         g2 = nnvm.graph.create(y2)
         graph_attr.set_shape_inputs(g1, ishape)
@@ -45,5 +50,61 @@ def test_fold_axis_conv():
 
     check((2, 4, 10, 10), 2)
 
+
+def test_fold_fail():
+    def before(x, scale, channels):
+        y = sym.conv2d(x,
+                       channels=channels,
+                       kernel_size=(3, 3),
+                       padding=(1, 1),
+                       name="conv")
+        y = y * sym.expand_dims(scale, axis=1, num_newaxis=1)
+        return y
+
+    # Before simplify
+    def check(shape, channels):
+        x = sym.Variable("x")
+        bias = sym.Variable("bias")
+        scale = sym.Variable("scale")
+        y1 = before(x, scale, channels)
+        ishape = {"x": shape, "scale": (channels,), "bias": (channels,)}
+        g1 = nnvm.graph.create(y1)
+        graph_attr.set_shape_inputs(g1, ishape)
+        g2 = g1.apply("InferShape").apply("FoldScaleAxis")
+        # assert graph equals as expected
+        graph_util.check_graph_equal(g1, g2)
+
+    check((2, 10, 10, 10), 10)
+
+
+def test_fold_resnet():
+    batch_size = 1
+    num_classes = 1000
+    image_shape = (3, 224, 224)
+    data_shape = (batch_size,) + image_shape
+    net, params = nnvm.testing.resnet.get_workload(
+        batch_size=1, image_shape=image_shape)
+    ishape = {"data": data_shape}
+    graph = nnvm.graph.create(net)
+    data = np.random.uniform(size=data_shape).astype("float32")
+    # Initial pass do shape type inference
+    shape, _ = graph_util.infer_shape(graph, **ishape)
+    ishape.update(zip(graph.index.input_names, shape))
+
+    def run_prune(graph, params, opt_level):
+        # Apply optimization
+        with nnvm.compiler.build_config(opt_level=opt_level):
+            graph = nnvm.compiler.optimize(graph, ishape)
+        graph, params = nnvm.compiler.build_module.precompute_prune(graph, params)
+        params["data"] = data
+        return nnvm.compiler.build_module._run_graph(graph, params)
+
+    x = run_prune(graph, params, 0)
+    y = run_prune(graph, params, 3)
+    np.testing.assert_allclose(y[0].asnumpy(), x[0].asnumpy())
+
+
 if __name__ == "__main__":
+    test_fold_resnet()
     test_fold_axis_conv()
+    test_fold_fail()