From 68c039442e40ce060bd6f391788876464f1f5714 Mon Sep 17 00:00:00 2001 From: Siva Date: Sun, 18 Feb 2018 23:45:25 +0530 Subject: [PATCH] [NHWC] InferShape Layout conversion fix. (#372) --- nnvm/src/top/nn/convolution.cc | 4 ++-- nnvm/src/top/nn/nn_common.h | 28 +++++++++++++++++++++------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/nnvm/src/top/nn/convolution.cc b/nnvm/src/top/nn/convolution.cc index 8664fa0a..d517e7e5 100644 --- a/nnvm/src/top/nn/convolution.cc +++ b/nnvm/src/top/nn/convolution.cc @@ -48,7 +48,7 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs, param.kernel_size[0], param.kernel_size[1]}); - wshape = ConvertLayout(wshape, kNCHW, param.layout); + wshape = ConvertLayout(wshape, kNCHW, param.layout, true); wshape[0] *= param.groups; NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape); @@ -189,7 +189,7 @@ inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs, param.channels / param.groups, param.kernel_size[0], param.kernel_size[1]}); - wshape = ConvertLayout(wshape, kNCHW, param.layout); + wshape = ConvertLayout(wshape, kNCHW, param.layout, true); NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DTransposeParam::kWeight, wshape); if (param.use_bias) { diff --git a/nnvm/src/top/nn/nn_common.h b/nnvm/src/top/nn/nn_common.h index a75077b8..e9176d17 100644 --- a/nnvm/src/top/nn/nn_common.h +++ b/nnvm/src/top/nn/nn_common.h @@ -40,7 +40,7 @@ inline std::vector UseBiasListInputNames(const NodeAttrs& attrs) { * \param dst_layout target layout * \return shape in target layout */ -inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) { +inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout, bool is_weight = false) { if (src_layout == dst_layout) return src; TShape dst = src; if (src.ndim() == 3) { @@ -68,9 +68,16 @@ inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) { switch (src_layout) { case kNCHW: break; case kNHWC: { - dst[2] = src[1]; - dst[3] = src[2]; - dst[1] = src[3]; + if (is_weight) { + dst[2] = src[0]; + dst[3] = src[1]; + dst[1] = src[2]; + dst[0] = src[3]; + } else { + dst[2] = src[1]; + dst[3] = src[2]; + dst[1] = src[3]; + } break; } default: { @@ -81,9 +88,16 @@ inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) { switch (dst_layout) { case kNCHW: break; case kNHWC: { - dst[1] = src[2]; - dst[2] = src[3]; - dst[3] = src[1]; + if (is_weight) { + dst[0] = src[2]; + dst[1] = src[3]; + dst[2] = src[1]; + dst[3] = src[0]; + } else { + dst[1] = src[2]; + dst[2] = src[3]; + dst[3] = src[1]; + } break; } default: {