[NHWC] InferShape Layout conversion fix. (#372)
This commit is contained in:
Родитель
50c20b76d3
Коммит
68c039442e
|
@ -48,7 +48,7 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
|
||||||
param.kernel_size[0],
|
param.kernel_size[0],
|
||||||
param.kernel_size[1]});
|
param.kernel_size[1]});
|
||||||
|
|
||||||
wshape = ConvertLayout(wshape, kNCHW, param.layout);
|
wshape = ConvertLayout(wshape, kNCHW, param.layout, true);
|
||||||
wshape[0] *= param.groups;
|
wshape[0] *= param.groups;
|
||||||
|
|
||||||
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape);
|
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape);
|
||||||
|
@ -189,7 +189,7 @@ inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs,
|
||||||
param.channels / param.groups,
|
param.channels / param.groups,
|
||||||
param.kernel_size[0],
|
param.kernel_size[0],
|
||||||
param.kernel_size[1]});
|
param.kernel_size[1]});
|
||||||
wshape = ConvertLayout(wshape, kNCHW, param.layout);
|
wshape = ConvertLayout(wshape, kNCHW, param.layout, true);
|
||||||
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DTransposeParam::kWeight, wshape);
|
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DTransposeParam::kWeight, wshape);
|
||||||
|
|
||||||
if (param.use_bias) {
|
if (param.use_bias) {
|
||||||
|
|
|
@ -40,7 +40,7 @@ inline std::vector<std::string> UseBiasListInputNames(const NodeAttrs& attrs) {
|
||||||
* \param dst_layout target layout
|
* \param dst_layout target layout
|
||||||
* \return shape in target layout
|
* \return shape in target layout
|
||||||
*/
|
*/
|
||||||
inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) {
|
inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout, bool is_weight = false) {
|
||||||
if (src_layout == dst_layout) return src;
|
if (src_layout == dst_layout) return src;
|
||||||
TShape dst = src;
|
TShape dst = src;
|
||||||
if (src.ndim() == 3) {
|
if (src.ndim() == 3) {
|
||||||
|
@ -68,9 +68,16 @@ inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) {
|
||||||
switch (src_layout) {
|
switch (src_layout) {
|
||||||
case kNCHW: break;
|
case kNCHW: break;
|
||||||
case kNHWC: {
|
case kNHWC: {
|
||||||
dst[2] = src[1];
|
if (is_weight) {
|
||||||
dst[3] = src[2];
|
dst[2] = src[0];
|
||||||
dst[1] = src[3];
|
dst[3] = src[1];
|
||||||
|
dst[1] = src[2];
|
||||||
|
dst[0] = src[3];
|
||||||
|
} else {
|
||||||
|
dst[2] = src[1];
|
||||||
|
dst[3] = src[2];
|
||||||
|
dst[1] = src[3];
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
default: {
|
default: {
|
||||||
|
@ -81,9 +88,16 @@ inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) {
|
||||||
switch (dst_layout) {
|
switch (dst_layout) {
|
||||||
case kNCHW: break;
|
case kNCHW: break;
|
||||||
case kNHWC: {
|
case kNHWC: {
|
||||||
dst[1] = src[2];
|
if (is_weight) {
|
||||||
dst[2] = src[3];
|
dst[0] = src[2];
|
||||||
dst[3] = src[1];
|
dst[1] = src[3];
|
||||||
|
dst[2] = src[1];
|
||||||
|
dst[3] = src[0];
|
||||||
|
} else {
|
||||||
|
dst[1] = src[2];
|
||||||
|
dst[2] = src[3];
|
||||||
|
dst[3] = src[1];
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
default: {
|
default: {
|
||||||
|
|
Загрузка…
Ссылка в новой задаче