fix restore layout in AlterOpLayout (#460)

* fix restore layout in AlterOpLayout

* lint test case
This commit is contained in:
Yizhi Liu 2018-04-28 20:14:55 -07:00 коммит произвёл Tianqi Chen
Родитель 1e4bb2f887
Коммит 343eb82ca2
2 изменённых файлов: 8 добавлений и 3 удалений

Просмотреть файл

@ -119,8 +119,9 @@ Graph AlterOpLayout(const Graph& src) {
if (new_nodes.count(inode.source)) {
const std::vector<Layout>& in_layouts =
in_layouts_of_node[new_nodes[inode.source]];
for (const auto& e : inode.inputs) {
ret_layouts[ret_idx.entry_id(e)] = in_layouts[e.index];
for (uint32_t i = 0; i < inode.inputs.size(); ++i) {
const auto& e = inode.inputs[i];
ret_layouts[ret_idx.entry_id(e)] = in_layouts[i];
}
const std::vector<Layout>& out_layouts =
out_layouts_of_node[new_nodes[inode.source]];

Просмотреть файл

@ -19,10 +19,14 @@ def test_alter_conv2d_layout():
conv = sym.conv2d(data, name="conv", channels=16,
kernel_size=(3,3), padding=(1,1),
use_bias=False, layout="NCHW")
relu = sym.relu(conv, name="relu")
# split here
convs = sym.split(conv, indices_or_sections=2)
relus = [sym.relu(x, name="relu") for x in convs]
relu = sym.concatenate(*relus)
flatten = sym.flatten(relu, name="flatten")
softmax = sym.softmax(flatten, name="softmax")
g = graph.create(softmax)
g = g.apply("CorrectLayout")
g = graph_attr.set_dtype_inputs(g, "float32")
g = g.apply(["InferShape", "InferType"])