From a53d8d01725cc25c855c00e32ec99510479f4c41 Mon Sep 17 00:00:00 2001
From: Tianqi Chen
Date: Sun, 1 Apr 2018 10:43:29 -0700
Subject: [PATCH] [PASS] Enhance scale fold axis (#424)

---
 nnvm/src/compiler/fold_scale_axis.cc         | 415 +++++++++++++++----
 nnvm/src/pass/plan_memory.cc                 |   2 +-
 nnvm/tests/python/compiler/test_fold_axis.py |  79 +++-
 3 files changed, 399 insertions(+), 97 deletions(-)

diff --git a/nnvm/src/compiler/fold_scale_axis.cc b/nnvm/src/compiler/fold_scale_axis.cc
index 34383c4e..7b05153b 100644
--- a/nnvm/src/compiler/fold_scale_axis.cc
+++ b/nnvm/src/compiler/fold_scale_axis.cc
@@ -18,12 +18,10 @@ namespace compiler {
 enum FoldScaleKind {
   // No folding is applied
   kNone,
-  // The folding decision is pending
+  // The folding decision is pending; we can fold on a state.
   kPending,
   // The original operator that contains the scale.
   kProvider,
-  // Pass through the scale to parent/child to the first axis.
-  kPassTroughFirst,
   // The final consumer of the axis scale, using multiply.
   // Likely a conv or dense operator.
   kMulConsumer,
@@ -31,21 +29,23 @@ enum FoldScaleKind {
   kDivConsumer
 };

-// Input fold information
-struct FoldScaleInput {
-  uint32_t index;
-  int axis;
-};
-
-// The entry of folding chains on which
-// we should perform folding on
-struct FoldChainEntry {
+struct FoldChainInfo {
   // Entry kind
   FoldScaleKind kind{kNone};
   // The output axis to be folded
   int axis{0};
   // Source node in the fold chain
   int source{0};
+};
+
+// The entry of folding chains on which
+// we should perform folding
+struct FoldChainEntry {
+  // Fold information
+  FoldChainInfo info;
+  // Number of outgoing forks
+  // in forward propagation.
+  int fork_count{0};
   // Following field only used by provider.
   // The input index
   int fold_input_index{1};
@@ -55,12 +55,26 @@ struct FoldChainEntry {

 // Try to pass axis scaling to backward,
 // Given that we know the status of the current fold axis.
+// return whether the forward signal is consumed.
 using FScaleAxisBackward = std::function<
-  FoldScaleKind(const NodeAttrs& attrs,
-                int axis,
-                const std::vector<TShape>& in_shape,
-                const std::vector<TShape>& out_shape,
-                std::vector<std::pair<uint32_t, int> >* in_axis)>;
+  bool(const NodeAttrs& attrs,
+       const std::vector<TShape>& in_shape,
+       const std::vector<TShape>& out_shape,
+       const FoldChainInfo& out_info,
+       std::vector<FoldChainInfo>* in_info)>;
+
+// Try to pass axis scaling to forward,
+// Given that we know the status of one of its inputs to be pending;
+// also update the other inputs' info.
+// return whether the forward signal is consumed.
+using FScaleAxisForward = std::function<
+  bool(const NodeAttrs& attrs,
+       const std::vector<TShape>& in_shape,
+       const std::vector<TShape>& out_shape,
+       std::vector<FoldChainInfo>* in_info,
+       FoldChainInfo* out_info)>;
+

 // Detect if there is a scaling axis happening
 bool DetectScaleAxis(const IndexedGraph& idx,
@@ -99,15 +113,19 @@ bool DetectScaleAxis(const IndexedGraph& idx,
   } else {
     return false;
   }
-  e.axis = axis.first;
-  e.kind = kPending;
-  e.source = nid;
+  e.info.axis = axis.first;
+  e.info.kind = kPending;
+  e.info.source = nid;
+  e.fork_count = 1;
+  // In backward message passing,
+  // we need to eagerly pass it to the input.
+  // In forward message passing,
+  // we will "pull" the message from the input.
   if (!is_forward) {
-    // pass message to another input
     FoldChainEntry& enext = (*chain)[b.node_id];
-    enext.axis = e.axis;
-    enext.kind = kPending;
-    enext.source = nid;
+    enext.info.axis = e.info.axis;
+    enext.info.kind = kPending;
+    enext.info.source = nid;
   }
   return true;
 }
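A note on the helper used above and in the broadcast rules below: MatchBroadcast1DAxis decides which axis a broadcast scale rides on. The following standalone sketch illustrates the matching contract as I understand it; the function body and the {axis, bias_ndim} return convention are illustrative assumptions, not the nnvm implementation.

// Sketch (not from the patch): given an output shape and the shape of a
// broadcast operand, find the single axis k on which the operand carries
// out_shape[k] elements while being 1 everywhere else.
// Returns {axis, bias_ndim} on success, {-1, -1} otherwise.
#include <cstdint>
#include <utility>
#include <vector>

std::pair<int, int> MatchBroadcast1DAxisSketch(
    const std::vector<int64_t>& out_shape,
    const std::vector<int64_t>& bias_shape) {
  // The bias is right-aligned against the output, numpy-style.
  int pad = static_cast<int>(out_shape.size()) - static_cast<int>(bias_shape.size());
  if (pad < 0) return {-1, -1};
  int axis = -1;
  for (size_t i = 0; i < bias_shape.size(); ++i) {
    if (bias_shape[i] == 1) continue;
    if (axis != -1 || bias_shape[i] != out_shape[pad + i]) return {-1, -1};
    axis = pad + static_cast<int>(i);  // the one non-trivial axis
  }
  if (axis == -1) return {-1, -1};
  return {axis, static_cast<int>(bias_shape.size())};
}
// e.g. out = (n, c, h, w), bias = (c, 1, 1)  ->  axis = 1 (the channel axis).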
@@ -119,12 +137,16 @@ Graph FoldScaleAxis(Graph src) {
   // Operator pattern
   static auto& fbackward =
       nnvm::Op::GetAttr<FScaleAxisBackward>("FScaleAxisBackward");
+  static auto& fforward =
+      nnvm::Op::GetAttr<FScaleAxisForward>("FScaleAxisForward");
   const IndexedGraph& idx = src.indexed_graph();
   const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
   std::vector<uint32_t> ref_count = GetNodeRefCounts(idx);
   std::vector<FoldChainEntry> bwd_chain(idx.num_nodes());
+  std::vector<FoldChainEntry> fwd_chain(idx.num_nodes());
   // shape hint for the inference.
   std::vector<TShape> in_shape, out_shape;
+
   // perform backward folding.
   for (uint32_t i = idx.num_nodes(); i != 0; --i) {
     uint32_t nid = i - 1;
     const auto& inode = idx[nid];
     if (inode.source->is_variable()) continue;
     if (DetectScaleAxis(idx, nid, shape_vec,
                         ref_count, false, &bwd_chain)) continue;
-    if (bwd_chain[nid].kind != kPending) continue;
+    if (bwd_chain[nid].info.kind != kPending) continue;
+    // if referred to by multiple nodes, we cannot propagate
     if (ref_count[nid] != 1 || !fbackward.count(inode.source->op())) {
-      bwd_chain[nid].kind = kNone; continue;
+      bwd_chain[nid].info.kind = kNone; continue;
     }
     // get input shape and output shape.
     in_shape.clear(); out_shape.clear();
@@ -144,58 +167,151 @@ Graph FoldScaleAxis(Graph src) {
     for (const IndexedGraph::NodeEntry& e : inode.inputs) {
       in_shape.push_back(shape_vec[idx.entry_id(e)]);
     }
     for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
       out_shape.push_back(shape_vec[idx.entry_id(nid, i)]);
     }
-    std::vector<std::pair<uint32_t, int> > in_axis;
-    FoldScaleKind kind =
-        fbackward[inode.source->op()](
-            inode.source->attrs, bwd_chain[nid].axis,
-            in_shape, out_shape, &in_axis);
-    bwd_chain[nid].kind = kind;
-    if (kind == kNone) continue;
-    CHECK_GE(in_axis.size(), 1U);
-    CHECK(kind == kPassTroughFirst || kind == kMulConsumer);
+    std::vector<FoldChainInfo> in_info(in_shape.size(), FoldChainInfo());
+    bool consumed = fbackward[inode.source->op()](
+        inode.source->attrs,
+        in_shape,
+        out_shape,
+        bwd_chain[nid].info,
+        &in_info);
+    CHECK_EQ(in_info.size(), in_shape.size());
     // propagate back.
     bool can_prop = true;
-    for (size_t i = 0; i < in_axis.size(); ++i) {
-      const IndexedGraph::NodeEntry& e = inode.inputs[in_axis[0].first];
+    for (size_t i = 0; i < in_info.size(); ++i) {
+      const IndexedGraph::NodeEntry& e = inode.inputs[i];
       if (ref_count[e.node_id] != 1 ||
           idx[e.node_id].source->num_outputs() != 1) {
         can_prop = false; break;
       }
     }
     if (!can_prop) continue;
-    for (size_t i = 0; i < in_axis.size(); ++i) {
-      const IndexedGraph::NodeEntry& e = inode.inputs[in_axis[i].first];
-      if (kind == kPassTroughFirst && i == 0) {
-        bwd_chain[e.node_id].kind = kPending;
-      } else {
-        bwd_chain[nid].kind = kNone;
-        bwd_chain[e.node_id].kind = kMulConsumer;
-      }
-      bwd_chain[e.node_id].axis = in_axis[i].second;
-      bwd_chain[e.node_id].source = bwd_chain[nid].source;
+    for (size_t i = 0; i < in_info.size(); ++i) {
+      const IndexedGraph::NodeEntry& e = inode.inputs[i];
+      bwd_chain[e.node_id].info = in_info[i];
     }
-    if (kind == kMulConsumer) {
-      bwd_chain[bwd_chain[nid].source].kind = kProvider;
+    // mark consumed by marking the source as the provider.
+    if (consumed) {
+      bwd_chain[bwd_chain[nid].info.source].info.kind = kProvider;
     }
   }
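The backward direction rests on a simple identity: a per-output-channel scale applied after a convolution can be absorbed into the weight (and bias) rows of that output channel. Below is a minimal, self-contained numeric check of that identity, using a 1x1 convolution on a single pixel; it is an illustration of the math the pass exploits, not code from the patch.

// conv2d(x, W, b)[o] * s[o] == conv2d(x, W', b')[o]
// with W'[o][i] = W[o][i] * s[o] and b'[o] = b[o] * s[o].
#include <cassert>
#include <cmath>

int main() {
  const int I = 3, O = 2;  // input/output channels; 1x1 kernel, one pixel
  double x[I] = {1.0, 2.0, -3.0};
  double W[O][I] = {{0.5, -1.0, 2.0}, {1.5, 0.25, -0.5}};
  double b[O] = {0.1, -0.2};
  double s[O] = {3.0, 0.5};  // per-output-channel scale
  for (int o = 0; o < O; ++o) {
    double y = b[o];
    double y_folded = b[o] * s[o];
    for (int i = 0; i < I; ++i) {
      y += W[o][i] * x[i];
      y_folded += (W[o][i] * s[o]) * x[i];  // scale folded into the weight
    }
    assert(std::fabs(y * s[o] - y_folded) < 1e-9);
  }
  return 0;
}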
-  auto transform = [&](uint32_t nid, const NodePtr& n, std::vector<NodeEntry>* ret) {
-    const FoldChainEntry& e = bwd_chain[nid];
-    if (e.kind == kMulConsumer && bwd_chain[e.source].kind == kProvider) {
-      const FoldChainEntry& se = bwd_chain[e.source];
-      CHECK_EQ(n->num_outputs(), 1);
-      NodeEntry scale = ExpandBiasToMatchAxis(
-          se.scale_entry,
-          shape_vec[idx.entry_id(nid, 0)].ndim(),
-          shape_vec[idx.entry_id(se.scale_entry)].ndim(),
-          e.axis);
-      *ret = {MakeNode("broadcast_mul", n->attrs.name + "_sc",
-                       {NodeEntry{n, 0, 0}, scale})};
-      return true;
-    } else if (e.kind == kProvider) {
-      *ret = {n->inputs[e.fold_input_index]};
-      return true;
-    } else {
-      return false;
-    }
+
+
+  // perform forward folding.
+  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
+    const auto& inode = idx[nid];
+    if (inode.source->is_variable()) continue;
+    // skip scales that are already folded in backward.
+    if (bwd_chain[nid].info.kind == kProvider) continue;
+    if (DetectScaleAxis(idx, nid, shape_vec,
+                        ref_count, true, &fwd_chain)) continue;
+    if (inode.source->num_outputs() != 1) continue;
+    // Do state update:
+    // get input shape and output shape.
+    std::vector<FoldChainInfo> in_info;
+    FoldChainInfo out_info;
+    int num_inpending = 0;
+    in_shape.clear(); out_shape.clear();
+    for (const IndexedGraph::NodeEntry& e : inode.inputs) {
+      in_shape.push_back(shape_vec[idx.entry_id(e)]);
+      // input information
+      in_info.push_back(fwd_chain[e.node_id].info);
+      if (fwd_chain[e.node_id].info.kind == kPending) {
+        ++num_inpending;
+      }
+    }
+    for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
+      out_shape.push_back(shape_vec[idx.entry_id(nid, i)]);
+    }
+    if (num_inpending != 1 ||
+        !fforward.count(inode.source->op())) continue;
+    bool consumed = fforward[inode.source->op()](
+        inode.source->attrs,
+        in_shape,
+        out_shape,
+        &in_info,
+        &out_info);
+    // update input info
+    for (size_t i = 0; i < in_info.size(); ++i) {
+      fwd_chain[inode.inputs[i].node_id].info = in_info[i];
+    }
+    if (consumed) {
+      fwd_chain[nid].info = out_info;
+      for (size_t i = 0; i < in_info.size(); ++i) {
+        if (in_info[i].kind == kPending) {
+          if (--fwd_chain[in_info[i].source].fork_count == 0) {
+            fwd_chain[in_info[i].source].info.kind = kProvider;
+          }
+        }
+      }
+    } else {
+      // condition under which the pending state can still propagate
+      if (inode.source->num_outputs() == 1) {
+        fwd_chain[nid].info = out_info;
+        if (out_info.kind == kPending) {
+          // When there are multiple references to the input,
+          // every path has to be consumed.
+          fwd_chain[out_info.source].fork_count += ref_count[nid] - 1;
+        }
+      }
+    }
+  }
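The fork_count bookkeeping above guards against erasing a scale provider before every path of a fanned-out pending signal has folded it. The toy simulation below mirrors the two updates in the loop (the `+= ref_count[nid] - 1` charge and the `--fork_count == 0` finalization); the types here are illustrative stand-ins, not the pass's own structs.

#include <cassert>

enum Kind { kNone, kPending, kProvider };

struct Entry {
  Kind kind = kNone;
  int fork_count = 0;
};

int main() {
  Entry provider{kPending, 1};   // DetectScaleAxis starts with one outgoing path
  provider.fork_count += 2 - 1;  // the pending value is referenced twice (a fork)
  // Each consumer that folds the scale decrements the counter...
  for (int consumed_paths = 0; consumed_paths < 2; ++consumed_paths) {
    if (--provider.fork_count == 0) {
      provider.kind = kProvider;  // ...and the last one finalizes the provider
    }
  }
  assert(provider.kind == kProvider);
  return 0;
}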
+  auto transform = [&](uint32_t nid, const NodePtr& n, std::vector<NodeEntry>* ret) {
+    NodeEntry rvalue = NodeEntry{n, 0, 0};
+    {
+      // Backward chain
+      const FoldChainEntry& e = bwd_chain[nid];
+      if (e.info.kind == kMulConsumer &&
+          bwd_chain[e.info.source].info.kind == kProvider) {
+        const FoldChainEntry& se = bwd_chain[e.info.source];
+        CHECK_EQ(n->num_outputs(), 1);
+        NodeEntry scale = ExpandBiasToMatchAxis(
+            se.scale_entry,
+            shape_vec[idx.entry_id(nid, 0)].ndim(),
+            shape_vec[idx.entry_id(se.scale_entry)].ndim(),
+            e.info.axis);
+        rvalue = MakeNode("broadcast_mul", n->attrs.name + "_sc",
+                          {rvalue, scale});
+      } else if (e.info.kind == kProvider) {
+        rvalue = n->inputs[e.fold_input_index];
+      }
+    }
+    // Note that the value might get transformed twice if it
+    // folds values from both the forward and the backward chain.
+    {
+      // Forward chain
+      const FoldChainEntry& e = fwd_chain[nid];
+      if (e.info.kind == kMulConsumer &&
+          fwd_chain[e.info.source].info.kind == kProvider) {
+        const FoldChainEntry& se = fwd_chain[e.info.source];
+        CHECK_EQ(n->num_outputs(), 1);
+        NodeEntry scale = ExpandBiasToMatchAxis(
+            se.scale_entry,
+            shape_vec[idx.entry_id(nid, 0)].ndim(),
+            shape_vec[idx.entry_id(se.scale_entry)].ndim(),
+            e.info.axis);
+        rvalue = MakeNode("broadcast_mul", n->attrs.name + "_sc",
+                          {rvalue, scale});
+      } else if (e.info.kind == kDivConsumer &&
+                 fwd_chain[e.info.source].info.kind == kProvider) {
+        const FoldChainEntry& se = fwd_chain[e.info.source];
+        CHECK_EQ(n->num_outputs(), 1);
+        NodeEntry scale = ExpandBiasToMatchAxis(
+            se.scale_entry,
+            shape_vec[idx.entry_id(nid, 0)].ndim(),
+            shape_vec[idx.entry_id(se.scale_entry)].ndim(),
+            e.info.axis);
+        rvalue = MakeNode("broadcast_div", n->attrs.name + "_sc",
+                          {rvalue, scale});
+      } else if (e.info.kind == kProvider) {
+        rvalue = n->inputs[e.fold_input_index];
+      }
+    }
+    if (rvalue.node == n) {
+      return false;
+    } else {
+      *ret = {rvalue};
+      return true;
+    }
   };
   return GraphTransform(src, transform);
@@ -205,14 +321,24 @@ NNVM_REGISTER_PASS(FoldScaleAxis)
 .set_body(FoldScaleAxis);

 // property registration.
-FoldScaleKind ReluScaleAxisBackward(
+bool ReluScaleAxisBackward(
     const NodeAttrs& attrs,
-    int axis,
     const std::vector<TShape>& in_shape,
     const std::vector<TShape>& out_shape,
-    std::vector<std::pair<uint32_t, int> >* in_axis) {
-  in_axis->emplace_back(0, axis);
-  return kPassTroughFirst;
+    const FoldChainInfo& out_info,
+    std::vector<FoldChainInfo>* in_axis) {
+  (*in_axis)[0] = out_info;
+  return false;
+}
+
+bool ReluScaleAxisForward(
+    const NodeAttrs& attrs,
+    const std::vector<TShape>& in_shape,
+    const std::vector<TShape>& out_shape,
+    std::vector<FoldChainInfo>* in_info,
+    FoldChainInfo* out_info) {
+  *out_info = (*in_info)[0];
+  return false;
 }

 NNVM_REGISTER_OP(relu)
@@ -221,21 +347,102 @@ NNVM_REGISTER_OP(relu)
 .set_attr<FScaleAxisBackward>("FScaleAxisBackward", ReluScaleAxisBackward);

 NNVM_REGISTER_OP(leaky_relu)
 .set_attr<FScaleAxisBackward>("FScaleAxisBackward", ReluScaleAxisBackward);

+NNVM_REGISTER_OP(relu)
+.set_attr<FScaleAxisForward>("FScaleAxisForward", ReluScaleAxisForward);
+
+NNVM_REGISTER_OP(leaky_relu)
+.set_attr<FScaleAxisForward>("FScaleAxisForward", ReluScaleAxisForward);
+
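Since the relu handlers merely pass the fold request through unchanged, any single-input op f satisfying f(s * x) == s * f(x) for positive s could reuse them. A hypothetical registration for nnvm's elementwise negation (negation is linear, so the identity holds for any s); this is not part of the patch, assumes the op name `negative`, and would only compile inside fold_scale_axis.cc after the handlers above:

// Hypothetical illustration only -- not in this patch.
NNVM_REGISTER_OP(negative)
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", ReluScaleAxisBackward);

NNVM_REGISTER_OP(negative)
.set_attr<FScaleAxisForward>("FScaleAxisForward", ReluScaleAxisForward);

Note that this pass-through trick is not valid for general nonlinearities: sigmoid(s * x) != s * sigmoid(x), so such ops must simply not register the attributes.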
+// property registration.
+bool Pool2DBackward(
+    const NodeAttrs& attrs,
+    const std::vector<TShape>& in_shape,
+    const std::vector<TShape>& out_shape,
+    const FoldChainInfo& out_info,
+    std::vector<FoldChainInfo>* in_axis) {
+  using top::Pool2DParam;
+  const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
+  if (out_info.axis == 1 && param.layout == top::kNCHW) {
+    (*in_axis)[0] = out_info;
+  }
+  return false;
+}
+
+bool Pool2DForward(
+    const NodeAttrs& attrs,
+    const std::vector<TShape>& in_shape,
+    const std::vector<TShape>& out_shape,
+    std::vector<FoldChainInfo>* in_info,
+    FoldChainInfo* out_info) {
+  using top::Pool2DParam;
+  const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
+  if ((*in_info)[0].axis == 1 && param.layout == top::kNCHW) {
+    *out_info = (*in_info)[0];
+  }
+  return false;
+}
+
+NNVM_REGISTER_OP(max_pool2d)
+.set_attr<FScaleAxisBackward>("FScaleAxisBackward", Pool2DBackward);
+
+NNVM_REGISTER_OP(avg_pool2d)
+.set_attr<FScaleAxisBackward>("FScaleAxisBackward", Pool2DBackward);
+
+NNVM_REGISTER_OP(max_pool2d)
+.set_attr<FScaleAxisForward>("FScaleAxisForward", Pool2DForward);
+
+NNVM_REGISTER_OP(avg_pool2d)
+.set_attr<FScaleAxisForward>("FScaleAxisForward", Pool2DForward);
+
-FoldScaleKind BroadcastAddSubScaleAxisBackward(
-    const NodeAttrs& attrs,
-    int axis,
-    const std::vector<TShape>& in_shape,
-    const std::vector<TShape>& out_shape,
-    std::vector<std::pair<uint32_t, int> >* in_axis) {
-  for (int i = 0; i < 2; ++i) {
-    std::pair<int, int> m = MatchBroadcast1DAxis(out_shape[0], in_shape[i]);
-    if (m.second != -1 && in_shape[1 - i] == out_shape[0]) {
-      in_axis->emplace_back(i, axis);
-      in_axis->emplace_back(1 - i, m.second);
-      return kPassTroughFirst;
-    }
-  }
-  return kNone;
-}
+bool BroadcastAddSubScaleAxisBackward(
+    const NodeAttrs& attrs,
+    const std::vector<TShape>& in_shape,
+    const std::vector<TShape>& out_shape,
+    const FoldChainInfo& out_info,
+    std::vector<FoldChainInfo>* in_axis) {
+  if (out_info.kind != kPending) return false;
+  for (int i = 0; i < 2; ++i) {
+    std::pair<int, int> m = MatchBroadcast1DAxis(out_shape[0], in_shape[1 - i]);
+    if (m.second != -1 &&
+        in_shape[i] == out_shape[0] &&
+        m.first == out_info.axis) {
+      (*in_axis)[i].kind = kPending;
+      (*in_axis)[i].axis = out_info.axis;
+      (*in_axis)[i].source = out_info.source;
+      (*in_axis)[1 - i].kind = kMulConsumer;
+      (*in_axis)[1 - i].axis = m.second;
+      (*in_axis)[1 - i].source = out_info.source;
+      return false;
+    }
+  }
+  return false;
+}
+
+bool BroadcastAddSubScaleAxisForward(
+    const NodeAttrs& attrs,
+    const std::vector<TShape>& in_shape,
+    const std::vector<TShape>& out_shape,
+    std::vector<FoldChainInfo>* in_info,
+    FoldChainInfo* out_info) {
+  for (int i = 0; i < 2; ++i) {
+    if ((*in_info)[i].kind == kPending) {
+      std::pair<int, int> m = MatchBroadcast1DAxis(out_shape[0], in_shape[1 - i]);
+      if (m.second != -1 &&
+          in_shape[i] == out_shape[0] &&
+          m.first == (*in_info)[i].axis) {
+        out_info->kind = kPending;
+        out_info->axis = m.first;
+        out_info->source = (*in_info)[i].source;
+        (*in_info)[1 - i].kind = kDivConsumer;
+        (*in_info)[1 - i].axis = m.second;
+        (*in_info)[1 - i].source = (*in_info)[i].source;
+        return false;
+      }
+    }
+  }
+  return false;
+}

 NNVM_REGISTER_OP(broadcast_add)
@@ -244,28 +451,62 @@ NNVM_REGISTER_OP(broadcast_add)
 .set_attr<FScaleAxisBackward>("FScaleAxisBackward", BroadcastAddSubScaleAxisBackward);

 NNVM_REGISTER_OP(broadcast_sub)
 .set_attr<FScaleAxisBackward>("FScaleAxisBackward", BroadcastAddSubScaleAxisBackward);

+NNVM_REGISTER_OP(broadcast_add)
+.set_attr<FScaleAxisForward>("FScaleAxisForward", BroadcastAddSubScaleAxisForward);
+
+NNVM_REGISTER_OP(broadcast_sub)
+.set_attr<FScaleAxisForward>("FScaleAxisForward", BroadcastAddSubScaleAxisForward);
+
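The asymmetry in BroadcastAddSubScaleAxisForward -- the non-pending operand becomes a kDivConsumer rather than a kMulConsumer -- comes from deferring the scale past the addition: x * s + b == (x + b / s) * s, so the other operand must be pre-divided by the scale. A one-line numeric check of that identity (illustration only):

#include <cassert>
#include <cmath>

int main() {
  double x = 1.7, b = -0.4, s = 2.5;  // s is the pending per-channel scale
  double lhs = x * s + b;             // original: scaled input plus bias
  double rhs = (x + b / s) * s;       // scale deferred, bias pre-divided
  assert(std::fabs(lhs - rhs) < 1e-12);
  return 0;
}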
-FoldScaleKind Conv2DScaleAxisBackward(
+bool Conv2DScaleAxisBackward(
     const NodeAttrs& attrs,
-    int axis,
     const std::vector<TShape>& in_shape,
     const std::vector<TShape>& out_shape,
-    std::vector<std::pair<uint32_t, int> >* in_axis) {
+    const FoldChainInfo& out_info,
+    std::vector<FoldChainInfo>* in_axis) {
   using top::Conv2DParam;
   const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
+  if (out_info.kind != kPending) return false;
   // only optimize for nchw for now
-  if (param.layout == top::kNCHW) {
-    in_axis->emplace_back(1, 0);
+  if (param.layout == top::kNCHW && out_info.axis == 1) {
+    (*in_axis)[1].kind = kMulConsumer;
+    (*in_axis)[1].axis = 0;
+    (*in_axis)[1].source = out_info.source;
     if (param.use_bias) {
-      in_axis->emplace_back(2, 0);
+      (*in_axis)[2].kind = kMulConsumer;
+      (*in_axis)[2].axis = 0;
+      (*in_axis)[2].source = out_info.source;
     }
-    return kMulConsumer;
+    return true;
   } else {
-    return kNone;
+    return false;
   }
 }
+
+bool Conv2DScaleAxisForward(
+    const NodeAttrs& attrs,
+    const std::vector<TShape>& in_shape,
+    const std::vector<TShape>& out_shape,
+    std::vector<FoldChainInfo>* in_info,
+    FoldChainInfo* out_info) {
+  using top::Conv2DParam;
+  const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
+  if ((*in_info)[0].kind != kPending) return false;
+  // only optimize for nchw for now
+  if (param.layout == top::kNCHW && (*in_info)[0].axis == 1) {
+    (*in_info)[1].kind = kMulConsumer;
+    (*in_info)[1].axis = 1;
+    (*in_info)[1].source = (*in_info)[0].source;
+    return true;
+  } else {
+    return false;
+  }
+}

 NNVM_REGISTER_OP(conv2d)
 .set_attr<FScaleAxisBackward>("FScaleAxisBackward", Conv2DScaleAxisBackward);

+NNVM_REGISTER_OP(conv2d)
+.set_attr<FScaleAxisForward>("FScaleAxisForward", Conv2DScaleAxisForward);
+
 }  // namespace compiler
 }  // namespace nnvm
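Conv2DScaleAxisForward is the mirror image of the backward rule: a per-input-channel scale on the data folds into the weight along axis 1, the input-channel axis of an NCHW/OIHW weight. A minimal numeric check of that identity (illustration only, 1x1 kernel on a single pixel):

// conv2d(x * s_in, W) == conv2d(x, W') with W'[o][i] = W[o][i] * s_in[i].
#include <cassert>
#include <cmath>

int main() {
  const int I = 2, O = 2;
  double x[I] = {0.8, -1.2};
  double s[I] = {4.0, 0.5};  // per-input-channel scale
  double W[O][I] = {{1.0, 2.0}, {-0.5, 3.0}};
  for (int o = 0; o < O; ++o) {
    double scaled_input = 0, folded_weight = 0;
    for (int i = 0; i < I; ++i) {
      scaled_input += W[o][i] * (x[i] * s[i]);   // scale applied to the data
      folded_weight += (W[o][i] * s[i]) * x[i];  // scale folded into W, axis 1
    }
    assert(std::fabs(scaled_input - folded_weight) < 1e-12);
  }
  return 0;
}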
diff --git a/nnvm/src/pass/plan_memory.cc b/nnvm/src/pass/plan_memory.cc
index f96f061b..51448bcf 100644
--- a/nnvm/src/pass/plan_memory.cc
+++ b/nnvm/src/pass/plan_memory.cc
@@ -196,7 +196,7 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx,
       if (taken[kv.first] == false &&
           sid_out == GraphAllocator::kBadStorageID &&
           sid_in >= 0 &&
-          (storage_ref_count[sid_in] == 1 && !ignore_all_inputs || identity[ipair]) &&
+          ((storage_ref_count[sid_in] == 1 && !ignore_all_inputs) || identity[ipair]) &&
           entry_ref_count[eid_out] > 0 &&
           shape_vec[eid_out].Size() == shape_vec[eid_in].Size() &&
           dtype_vec[eid_out] == dtype_vec[eid_in]) {
diff --git a/nnvm/tests/python/compiler/test_fold_axis.py b/nnvm/tests/python/compiler/test_fold_axis.py
index f306b3a2..bbd50193 100644
--- a/nnvm/tests/python/compiler/test_fold_axis.py
+++ b/nnvm/tests/python/compiler/test_fold_axis.py
@@ -1,22 +1,26 @@
 """Unittest cases for fold_axis"""
 import nnvm
+import nnvm.testing.resnet
+import numpy as np
 from nnvm import symbol as sym
 from nnvm.compiler import graph_util, graph_attr

 def test_fold_axis_conv():
-    def before(x, conv_weight, conv_bias, scale, channels):
+    def before(x, conv_weight, conv_bias, in_scale, out_scale, channels):
+        x = x * sym.expand_dims(in_scale, axis=1, num_newaxis=2)
         y = sym.conv2d(x, conv_weight, conv_bias,
                        channels=channels,
                        kernel_size=(3, 3),
                        padding=(1, 1),
                        name="conv")
         y = sym.relu(y)
-        y = y * sym.expand_dims(scale, axis=1, num_newaxis=2)
+        y = y * sym.expand_dims(out_scale, axis=1, num_newaxis=2)
         return y

-    def expected(x, conv_weight, conv_bias, scale, channels):
-        conv_weight = conv_weight * sym.expand_dims(scale, axis=1, num_newaxis=3)
-        conv_bias = conv_bias * scale
+    def expected(x, conv_weight, conv_bias, in_scale, out_scale, channels):
+        conv_weight = conv_weight * sym.expand_dims(out_scale, axis=1, num_newaxis=3)
+        conv_weight = conv_weight * sym.expand_dims(in_scale, axis=1, num_newaxis=2)
+        conv_bias = conv_bias * out_scale
         y = sym.conv2d(x,
                        conv_weight,
                        conv_bias,
@@ -32,10 +36,11 @@ def test_fold_axis_conv():
         x = sym.Variable("x") + 1
         weight = sym.Variable("weight")
         bias = sym.Variable("bias")
-        scale = sym.Variable("scale")
-        y1 = before(x, weight, bias, scale, channels)
-        y2 = expected(x, weight, bias, scale, channels)
-        ishape = {"x": shape, "scale": (channels,)}
+        in_scale = sym.Variable("in_scale")
+        out_scale = sym.Variable("out_scale")
+        y1 = before(x, weight, bias, in_scale, out_scale, channels)
+        y2 = expected(x, weight, bias, in_scale, out_scale, channels)
+        ishape = {"x": shape, "out_scale": (channels,), "in_scale": (shape[1],)}
         g1 = nnvm.graph.create(y1)
         g2 = nnvm.graph.create(y2)
         graph_attr.set_shape_inputs(g1, ishape)
@@ -45,5 +50,61 @@ def test_fold_axis_conv():

     check((2, 4, 10, 10), 2)

+
+def test_fold_fail():
+    def before(x, scale, channels):
+        y = sym.conv2d(x,
+                       channels=channels,
+                       kernel_size=(3, 3),
+                       padding=(1, 1),
+                       name="conv")
+        y = y * sym.expand_dims(scale, axis=1, num_newaxis=1)
+        return y
+
+    # The scale rides on the wrong axis here, so folding must leave
+    # the graph untouched.
+    def check(shape, channels):
+        x = sym.Variable("x")
+        bias = sym.Variable("bias")
+        scale = sym.Variable("scale")
+        y1 = before(x, scale, channels)
+        ishape = {"x": shape, "scale": (channels,), "bias": (channels,)}
+        g1 = nnvm.graph.create(y1)
+        graph_attr.set_shape_inputs(g1, ishape)
+        g2 = g1.apply("InferShape").apply("FoldScaleAxis")
+        # assert that the graph is unchanged, as expected
+        graph_util.check_graph_equal(g1, g2)
+
+    check((2, 10, 10, 10), 10)
+
+
+def test_fold_resnet():
+    batch_size = 1
+    num_classes = 1000
+    image_shape = (3, 224, 224)
+    data_shape = (batch_size,) + image_shape
+    net, params = nnvm.testing.resnet.get_workload(
+        batch_size=1, image_shape=image_shape)
+    ishape = {"data": data_shape}
+    graph = nnvm.graph.create(net)
+    data = np.random.uniform(size=data_shape).astype("float32")
+    # Initial pass to do shape/type inference
+    shape, _ = graph_util.infer_shape(graph, **ishape)
+    ishape.update(zip(graph.index.input_names, shape))
+
+    def run_prune(graph, params, opt_level):
+        # Apply optimization at the requested level
+        with nnvm.compiler.build_config(opt_level=opt_level):
+            graph = nnvm.compiler.optimize(graph, ishape)
+        graph, params = nnvm.compiler.build_module.precompute_prune(graph, params)
+        params["data"] = data
+        return nnvm.compiler.build_module._run_graph(graph, params)

+    x = run_prune(graph, params, 0)
+    y = run_prune(graph, params, 3)
+    np.testing.assert_allclose(y[0].asnumpy(), x[0].asnumpy())
+
+
 if __name__ == "__main__":
+    test_fold_resnet()
     test_fold_axis_conv()
+    test_fold_fail()