[NHWC] InferShape Layout conversion fix. (#372)

This commit is contained in:
Siva 2018-02-18 23:45:25 +05:30 коммит произвёл Tianqi Chen
Родитель 50c20b76d3
Коммит 68c039442e
2 изменённых файлов: 23 добавлений и 9 удалений

Просмотреть файл

@ -48,7 +48,7 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
param.kernel_size[0],
param.kernel_size[1]});
wshape = ConvertLayout(wshape, kNCHW, param.layout);
wshape = ConvertLayout(wshape, kNCHW, param.layout, true);
wshape[0] *= param.groups;
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape);
@ -189,7 +189,7 @@ inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs,
param.channels / param.groups,
param.kernel_size[0],
param.kernel_size[1]});
wshape = ConvertLayout(wshape, kNCHW, param.layout);
wshape = ConvertLayout(wshape, kNCHW, param.layout, true);
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DTransposeParam::kWeight, wshape);
if (param.use_bias) {

Просмотреть файл

@ -40,7 +40,7 @@ inline std::vector<std::string> UseBiasListInputNames(const NodeAttrs& attrs) {
* \param dst_layout target layout
* \return shape in target layout
*/
inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) {
inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout, bool is_weight = false) {
if (src_layout == dst_layout) return src;
TShape dst = src;
if (src.ndim() == 3) {
@ -68,9 +68,16 @@ inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) {
switch (src_layout) {
case kNCHW: break;
case kNHWC: {
dst[2] = src[1];
dst[3] = src[2];
dst[1] = src[3];
if (is_weight) {
dst[2] = src[0];
dst[3] = src[1];
dst[1] = src[2];
dst[0] = src[3];
} else {
dst[2] = src[1];
dst[3] = src[2];
dst[1] = src[3];
}
break;
}
default: {
@ -81,9 +88,16 @@ inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) {
switch (dst_layout) {
case kNCHW: break;
case kNHWC: {
dst[1] = src[2];
dst[2] = src[3];
dst[3] = src[1];
if (is_weight) {
dst[0] = src[2];
dst[1] = src[3];
dst[2] = src[1];
dst[3] = src[0];
} else {
dst[1] = src[2];
dst[2] = src[3];
dst[3] = src[1];
}
break;
}
default: {