fix restore layout in AlterOpLayout (#460)
* fix restore layout in AlterOpLayout * lint test case
This commit is contained in:
Родитель
1e4bb2f887
Коммит
343eb82ca2
|
@ -119,8 +119,9 @@ Graph AlterOpLayout(const Graph& src) {
|
|||
if (new_nodes.count(inode.source)) {
|
||||
const std::vector<Layout>& in_layouts =
|
||||
in_layouts_of_node[new_nodes[inode.source]];
|
||||
for (const auto& e : inode.inputs) {
|
||||
ret_layouts[ret_idx.entry_id(e)] = in_layouts[e.index];
|
||||
for (uint32_t i = 0; i < inode.inputs.size(); ++i) {
|
||||
const auto& e = inode.inputs[i];
|
||||
ret_layouts[ret_idx.entry_id(e)] = in_layouts[i];
|
||||
}
|
||||
const std::vector<Layout>& out_layouts =
|
||||
out_layouts_of_node[new_nodes[inode.source]];
|
||||
|
|
|
@ -19,10 +19,14 @@ def test_alter_conv2d_layout():
|
|||
conv = sym.conv2d(data, name="conv", channels=16,
|
||||
kernel_size=(3,3), padding=(1,1),
|
||||
use_bias=False, layout="NCHW")
|
||||
relu = sym.relu(conv, name="relu")
|
||||
# split here
|
||||
convs = sym.split(conv, indices_or_sections=2)
|
||||
relus = [sym.relu(x, name="relu") for x in convs]
|
||||
relu = sym.concatenate(*relus)
|
||||
flatten = sym.flatten(relu, name="flatten")
|
||||
softmax = sym.softmax(flatten, name="softmax")
|
||||
g = graph.create(softmax)
|
||||
|
||||
g = g.apply("CorrectLayout")
|
||||
g = graph_attr.set_dtype_inputs(g, "float32")
|
||||
g = g.apply(["InferShape", "InferType"])
|
||||
|
|
Загрузка…
Ссылка в новой задаче