[NHWC] InferShape Layout conversion fix. (#372)
This commit is contained in:
Родитель
50c20b76d3
Коммит
68c039442e
|
@ -48,7 +48,7 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
|
|||
param.kernel_size[0],
|
||||
param.kernel_size[1]});
|
||||
|
||||
wshape = ConvertLayout(wshape, kNCHW, param.layout);
|
||||
wshape = ConvertLayout(wshape, kNCHW, param.layout, true);
|
||||
wshape[0] *= param.groups;
|
||||
|
||||
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape);
|
||||
|
@ -189,7 +189,7 @@ inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs,
|
|||
param.channels / param.groups,
|
||||
param.kernel_size[0],
|
||||
param.kernel_size[1]});
|
||||
wshape = ConvertLayout(wshape, kNCHW, param.layout);
|
||||
wshape = ConvertLayout(wshape, kNCHW, param.layout, true);
|
||||
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DTransposeParam::kWeight, wshape);
|
||||
|
||||
if (param.use_bias) {
|
||||
|
|
|
@ -40,7 +40,7 @@ inline std::vector<std::string> UseBiasListInputNames(const NodeAttrs& attrs) {
|
|||
* \param dst_layout target layout
|
||||
* \return shape in target layout
|
||||
*/
|
||||
inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) {
|
||||
inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout, bool is_weight = false) {
|
||||
if (src_layout == dst_layout) return src;
|
||||
TShape dst = src;
|
||||
if (src.ndim() == 3) {
|
||||
|
@ -68,9 +68,16 @@ inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) {
|
|||
switch (src_layout) {
|
||||
case kNCHW: break;
|
||||
case kNHWC: {
|
||||
if (is_weight) {
|
||||
dst[2] = src[0];
|
||||
dst[3] = src[1];
|
||||
dst[1] = src[2];
|
||||
dst[0] = src[3];
|
||||
} else {
|
||||
dst[2] = src[1];
|
||||
dst[3] = src[2];
|
||||
dst[1] = src[3];
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
|
@ -81,9 +88,16 @@ inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) {
|
|||
switch (dst_layout) {
|
||||
case kNCHW: break;
|
||||
case kNHWC: {
|
||||
if (is_weight) {
|
||||
dst[0] = src[2];
|
||||
dst[1] = src[3];
|
||||
dst[2] = src[1];
|
||||
dst[3] = src[0];
|
||||
} else {
|
||||
dst[1] = src[2];
|
||||
dst[2] = src[3];
|
||||
dst[3] = src[1];
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
|
|
Загрузка…
Ссылка в новой задаче