diff --git a/nnvm/src/compiler/alter_op_layout.cc b/nnvm/src/compiler/alter_op_layout.cc
index 893a0d29..9fdc71fa 100644
--- a/nnvm/src/compiler/alter_op_layout.cc
+++ b/nnvm/src/compiler/alter_op_layout.cc
@@ -119,8 +119,9 @@ Graph AlterOpLayout(const Graph& src) {
     if (new_nodes.count(inode.source)) {
       const std::vector<Layout>& in_layouts =
           in_layouts_of_node[new_nodes[inode.source]];
-      for (const auto& e : inode.inputs) {
-        ret_layouts[ret_idx.entry_id(e)] = in_layouts[e.index];
+      for (uint32_t i = 0; i < inode.inputs.size(); ++i) {
+        const auto& e = inode.inputs[i];
+        ret_layouts[ret_idx.entry_id(e)] = in_layouts[i];
       }
       const std::vector<Layout>& out_layouts =
           out_layouts_of_node[new_nodes[inode.source]];
diff --git a/nnvm/tests/python/compiler/test_alter_op_layout.py b/nnvm/tests/python/compiler/test_alter_op_layout.py
index d921d4a6..bfda1807 100644
--- a/nnvm/tests/python/compiler/test_alter_op_layout.py
+++ b/nnvm/tests/python/compiler/test_alter_op_layout.py
@@ -19,10 +19,14 @@ def test_alter_conv2d_layout():
     conv = sym.conv2d(data, name="conv", channels=16,
                       kernel_size=(3,3), padding=(1,1),
                       use_bias=False, layout="NCHW")
-    relu = sym.relu(conv, name="relu")
+    # split here
+    convs = sym.split(conv, indices_or_sections=2)
+    relus = [sym.relu(x, name="relu") for x in convs]
+    relu = sym.concatenate(*relus)
     flatten = sym.flatten(relu, name="flatten")
     softmax = sym.softmax(flatten, name="softmax")
     g = graph.create(softmax)
+    g = g.apply("CorrectLayout")
     g = graph_attr.set_dtype_inputs(g, "float32")
     g = g.apply(["InferShape", "InferType"])