[TOPI] Depth wise convolution backward methods for NHWC (#434)
* rename the nchw and pass the unit test; going to do it for nhwc depthwise * bug with fusion * nchw works fine; nhwc float32 problem remains * still cannot bind them together * fusion works * syntax fix * all bugs fixed; test cases pass * minor fix on nn.h * back wrt input * backward wrt input nhwc; only test case in recipe * test case for depthwise back wrt input * test case for depthwise backward wrt weight * tags * minor fixes * pylint test; add arch=3.7 * modify scheduler * better backward depthwise w.r.t weight scheduler * updated scheduler * test_topi_depthwise_conv2d_back_input.py and test_topi_depthwise_conv2d_back_weight.py success * all test cases wrt input pass * update * new test cases and scheduler * not working 1 and 2 * good wrt weight, bad wrt input * test cases added * remove tf lines * minor fix * compute arch changed * remove compile hook * minor change * pylint * fix the float for python case * fix cases for python3 case * except for memoize * fix most; memoize still wrong * memoize added * unexpected layout cases added for scheduler * error message layout other than NHWC added * improve padding * fix as pr requests * remove dilate in backward wrt weight
This commit is contained in:
Parent
f2ab736b61
Commit
ffff1e4932
|
@ -15,6 +15,8 @@ constexpr auto kConv2dNCHW = "conv2d_nchw";
|
|||
constexpr auto kConv2dHWCN = "conv2d_hwcn";
|
||||
constexpr auto kDepthwiseConv2dNCHW = "depthwise_conv2d_nchw";
|
||||
constexpr auto kDepthwiseConv2dNHWC = "depthwise_conv2d_nhwc";
|
||||
constexpr auto kDepthwiseConv2dBackInputNHWC = "depthwise_conv2d_back_input_nhwc";
|
||||
constexpr auto kDepthwiseConv2dBackWeightNHWC = "depthwise_conv2d_back_weight_nhwc";
|
||||
constexpr auto kGroupConv2d = "group_conv2d";
|
||||
|
||||
} // namespace topi
|
||||
|
|
|
@ -5,6 +5,8 @@ from __future__ import absolute_import as _abs
|
|||
from .conv2d_nchw import schedule_conv2d_nchw
|
||||
from .conv2d_hwcn import schedule_conv2d_hwcn
|
||||
from .depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc
|
||||
from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc
|
||||
from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
|
||||
from .reduction import schedule_reduce
|
||||
from .broadcast import schedule_broadcast_to
|
||||
from .softmax import schedule_softmax
|
||||
|
|
|
@ -186,3 +186,101 @@ def schedule_depthwise_conv2d_nhwc(outs):
|
|||
|
||||
traverse(outs[0].op)
|
||||
return s
|
||||
|
||||
|
||||
def schedule_depthwise_conv2d_backward_input_nhwc(outs):
    """Schedule for depthwise_conv2d nhwc backward wrt input.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of depthwise_conv2d
        backward wrt input in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for depthwise_conv2d backward
        wrt input with layout nhwc.
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    sch = tvm.create_schedule([t.op for t in outs])

    def _bind_in_grad(padded_grad, in_grad):
        # Consume the padded gradient in place instead of materializing it.
        sch[padded_grad].compute_inline()

        bx = tvm.thread_axis("blockIdx.x")
        tx = tvm.thread_axis("threadIdx.x")
        _, ax_h, ax_w, ax_c = in_grad.op.axis

        # Flatten spatial and channel axes, then split for GPU binding.
        fused = sch[in_grad].fuse(ax_h, ax_w, ax_c)
        outer, inner = sch[in_grad].split(fused, factor=128)

        sch[in_grad].bind(outer, bx)
        sch[in_grad].bind(inner, tx)

    def _visit(op):
        # Only the backward-wrt-input NHWC compute pattern can be scheduled here.
        if op.tag != 'depthwise_conv2d_backward_input_nhwc':
            raise ValueError("Depthwise conv backward wrt input for non-NHWC is not supported.")
        padded_grad = op.input_tensors[0]
        dilated_grad = padded_grad.op.input_tensors[0]
        sch[dilated_grad].compute_inline()
        _bind_in_grad(padded_grad, op.output(0))

    _visit(outs[0].op)
    return sch
|
||||
|
||||
def schedule_depthwise_conv2d_backward_weight_nhwc(outs):
    """Schedule for depthwise_conv2d nhwc backward wrt weight.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of depthwise_conv2d
        backward wrt weight in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for depthwise_conv2d backward
        wrt weight with layout nhwc.
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    sch = tvm.create_schedule([t.op for t in outs])

    def _bind_weight_grad(weight_grad):
        bx = tvm.thread_axis("blockIdx.x")
        ty = tvm.thread_axis("threadIdx.y")
        tx = tvm.thread_axis("threadIdx.x")

        ax_db, ax_dh, ax_dw = weight_grad.op.reduce_axis

        # Fuse the reduction axes and peel off a factor-8 chunk so the
        # partial sums can be factored out into their own stage.
        fused_red = sch[weight_grad].fuse(ax_db, ax_dh, ax_dw)
        _, inner_red = sch[weight_grad].split(fused_red, factor=8)
        partial = sch.rfactor(weight_grad, inner_red)

        # Flatten all output axes and split for block/thread binding.
        fused_out = sch[weight_grad].fuse(*sch[weight_grad].op.axis)
        outer, inner = sch[weight_grad].split(fused_out, factor=32)

        sch[weight_grad].bind(inner, tx)
        sch[weight_grad].bind(outer, bx)

        # The remaining reduction axis runs cross-thread along y.
        sch[weight_grad].bind(sch[weight_grad].op.reduce_axis[0], ty)
        sch[partial].compute_at(sch[weight_grad], sch[weight_grad].op.reduce_axis[0])

    def _visit(op):
        # Only the backward-wrt-weight NHWC compute pattern can be scheduled here.
        if op.tag != 'depthwise_conv2d_backward_weight_nhwc':
            raise ValueError("Depthwise conv backward wrt weight for non-NHWC is not supported.")
        padded = op.input_tensors[1]
        sch[padded].compute_inline()
        _bind_weight_grad(op.output(0))

    _visit(outs[0].op)
    return sch
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
# pylint: disable=invalid-name, unused-variable, too-many-locals
|
||||
"""Depthwise Convolution operators"""
|
||||
"""Depthwise convolution operators"""
|
||||
from __future__ import absolute_import as _abs
|
||||
import tvm
|
||||
|
||||
from .dilate import dilate
|
||||
from .pad import pad
|
||||
from .util import get_pad_tuple
|
||||
from ..util import simplify
|
||||
|
@ -55,6 +57,7 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding):
|
|||
name='DepthwiseConv2d', tag="depthwise_conv2d_nchw")
|
||||
return Output
|
||||
|
||||
|
||||
def depthwise_conv2d_nhwc(Input, Filter, stride, padding):
|
||||
"""Depthwise convolution nhwc forward operator.
|
||||
|
||||
|
@ -66,8 +69,8 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding):
|
|||
Filter : tvm.Tensor
|
||||
4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
|
||||
|
||||
Stride : tvm.Tensor
|
||||
1-D of size 2
|
||||
stride : tuple of two ints
|
||||
The spatial stride along height and width
|
||||
|
||||
padding : int or str
|
||||
Padding size, or ['VALID', 'SAME']
|
||||
|
@ -102,3 +105,105 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding):
|
|||
axis=[di, dj]),
|
||||
name='DepthwiseConv2d', tag="depthwise_conv2d_nhwc")
|
||||
return Output
|
||||
|
||||
def depthwise_conv2d_backward_input_nhwc(Filter, Out_grad, oshape, ishape, stride, padding):
    """Depthwise convolution nhwc backward wrt input operator.

    Parameters
    ----------
    Filter : tvm.Tensor
        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]

    Out_grad : tvm.Tensor
        4-D with shape [batch, out_height, out_width, out_channel]

    oshape : list or tuple of four ints
        Shape of Out_grad [batch, out_height, out_width, out_channel].
        Kept for API symmetry; the gradient shape is taken from ishape.

    ishape : list or tuple of four ints
        Shape of the forward input [batch, in_height, in_width, in_channel]

    stride : tuple of two ints
        The spatial stride along height and width

    padding : int or str
        Padding size, or ['VALID', 'SAME']

    Returns
    -------
    Output : tvm.Tensor
        4-D with shape [batch, in_height, in_width, in_channel]
    """
    batch, in_h, in_w, in_c = ishape
    filter_h, filter_w, _, channel_multiplier = Filter.shape
    stride_h, stride_w = stride

    # A forward stride becomes a dilation of the output gradient in backprop.
    dilated_out_grad = dilate(Out_grad, [1, stride_h, stride_w, 1], name='dilated_out_grad')

    # padding params in forward propagation
    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w))
    # padding params in backward propagation
    bpad_top = filter_h - 1 - fpad_top
    bpad_bottom = (filter_h - 1 - fpad_bottom) + (stride_h - 1)
    bpad_left = filter_w - 1 - fpad_left
    bpad_right = (filter_w - 1 - fpad_right) + (stride_w - 1)

    padded_out_grad = pad(dilated_out_grad,
                          [0, bpad_top, bpad_left, 0],
                          [0, bpad_bottom, bpad_right, 0],
                          name='padded_out_grad')

    dh = tvm.reduce_axis((0, filter_h), name='dh')
    dw = tvm.reduce_axis((0, filter_w), name='dw')
    dc = tvm.reduce_axis((0, channel_multiplier), name='dc')

    # Correlate the padded gradient with the spatially flipped filter;
    # each input channel c sums over its channel_multiplier output channels.
    In_grad = tvm.compute(
        (batch, in_h, in_w, in_c),
        lambda b, h, w, c: tvm.sum(padded_out_grad[b, h+dh, w+dw, c*channel_multiplier + dc] *
                                   Filter[filter_h-1-dh, filter_w-1-dw, c, dc],
                                   axis=[dh, dw, dc]),
        tag='depthwise_conv2d_backward_input_nhwc')

    return In_grad
|
||||
|
||||
|
||||
def depthwise_conv2d_backward_weight_nhwc(Input, Out_grad, oshape, fshape, stride, padding):
    """Depthwise convolution nhwc backward wrt weight operator.

    Parameters
    ----------
    Input : tvm.Tensor
        4-D with shape [batch, in_height, in_width, in_channel]

    Out_grad : tvm.Tensor
        4-D with shape [batch, out_height, out_width, out_channel]

    oshape : list or tuple of four ints
        Shape of Out_grad [batch, out_height, out_width, out_channel];
        only the batch size is read here.

    fshape : list or tuple of four ints
        Shape of the filter [filter_height, filter_width, in_channel, channel_multiplier]

    stride : tuple of two ints
        The spatial stride along height and width

    padding : int or str
        Padding size, or ['VALID', 'SAME']

    Returns
    -------
    Output : tvm.Tensor
        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
    """
    batch = oshape[0]
    filter_h, filter_w, _, channel_multiplier = fshape
    in_c = Input.shape[3].value
    stride_h, stride_w = stride

    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (filter_h, filter_w))

    padded_in = pad(Input,
                    [0, pad_top, pad_left, 0],
                    [0, pad_bottom, pad_right, 0],
                    name='padded_in')

    dh = tvm.reduce_axis((0, Out_grad.shape[1].value), name='dh')
    dw = tvm.reduce_axis((0, Out_grad.shape[2].value), name='dw')
    db = tvm.reduce_axis((0, batch), name='db')

    # m already ranges over [0, channel_multiplier), so the output-channel
    # index is simply c*channel_multiplier + m (no modulo needed).
    Weight_grad = tvm.compute(
        (filter_h, filter_w, in_c, channel_multiplier),
        lambda fh, fw, c, m: tvm.sum(
            Out_grad[db, dh, dw, c*channel_multiplier+m] *
            padded_in[db, fh+dh*stride_h, fw+dw*stride_w, c], axis=[db, dh, dw]),
        tag='depthwise_conv2d_backward_weight_nhwc')

    return Weight_grad
|
||||
|
|
|
@ -15,11 +15,10 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
|
|||
# placeholder
|
||||
Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input')
|
||||
Filter = tvm.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter')
|
||||
Stride = [stride_h, stride_w]
|
||||
Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
|
||||
Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
|
||||
# declare
|
||||
DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, Stride, padding)
|
||||
DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, stride=[stride_h, stride_w], padding=padding)
|
||||
ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
|
||||
Relu = topi.nn.relu(ScaleShift)
|
||||
# schedule
|
||||
|
@ -97,11 +96,10 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
|
|||
# placeholder
|
||||
Input = tvm.placeholder((batch, in_height, in_width, in_channel), name='Input')
|
||||
Filter = tvm.placeholder((filter_height, filter_width,filter_channel, channel_multiplier), name='Filter')
|
||||
Stride = [stride_h, stride_w]
|
||||
Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
|
||||
Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
|
||||
# declare
|
||||
DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, Filter, Stride, padding)
|
||||
DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, Filter, stride=[stride_h, stride_w], padding=padding)
|
||||
ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
|
||||
Relu = topi.nn.relu(ScaleShift)
|
||||
# schedule
|
||||
|
|
|
@ -0,0 +1,109 @@
|
|||
import tvm
|
||||
import topi
|
||||
import numpy as np
|
||||
from tvm.contrib.pickle_memoize import memoize
|
||||
from scipy import signal
|
||||
from topi.util import get_const_tuple
|
||||
from topi.nn.util import get_pad_tuple
|
||||
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc
|
||||
|
||||
|
||||
def verify_depthwise_conv2d_back_input(batch, in_channel, in_h, channel_multiplier, filter_h, stride_h, padding_h):
    """Check NHWC depthwise_conv2d backward wrt input against a scipy reference.

    Square geometry: width, filter width, stride and padding along width are
    derived from the corresponding height parameters.
    """
    in_w = in_h
    filter_channel = in_channel
    filter_w = filter_h
    stride_w = stride_h
    padding_w = padding_h

    # np.int was removed in NumPy 1.24; it was only an alias for the builtin int.
    out_h = int((in_h+2*padding_h-filter_h)/stride_h+1)
    out_w = int((in_w+2*padding_w-filter_w)/stride_w+1)
    out_channel = in_channel * channel_multiplier

    ishape = [batch, in_h, in_w, in_channel]
    oshape = [batch, out_h, out_w, out_channel]

    # placeholder
    Out_grad = tvm.placeholder(oshape, name='Out_grad')
    Filter = tvm.placeholder((filter_h, filter_w, filter_channel, channel_multiplier))
    # declare
    In_grad = topi.nn.depthwise_conv2d_backward_input_nhwc(Filter, Out_grad, oshape, ishape,
                                                           stride=[stride_h, stride_w],
                                                           padding=[padding_h, padding_w])
    # schedule
    schedule = schedule_depthwise_conv2d_backward_input_nhwc(In_grad)

    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        ctx = tvm.context(device, 0)
        # build the kernel
        f = tvm.build(schedule, [Filter, Out_grad, In_grad], device)
        # prepare pod type for test data closure
        dtype = Out_grad.dtype
        out_grad_shape = get_const_tuple(Out_grad.shape)
        filter_shape = get_const_tuple(Filter.shape)

        # use memoize to pickle the test data for next time use
        @memoize("topi.tests.test_topi_depthwise_conv2d_backward_input.nhwc")
        def get_ref_data():
            out_grad_np = np.random.uniform(size=out_grad_shape).astype(dtype)
            filter_np = np.random.uniform(size=filter_shape).astype(dtype)
            dilated_out_grad_np = topi.testing.dilate_python(out_grad_np, [1, stride_h, stride_w, 1])
            # padding params in forward propagation
            fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(
                [padding_h, padding_w], (filter_h, filter_w))
            # padding params in backward propagation
            bpad_top = filter_h - 1 - fpad_top
            bpad_bottom = (filter_h - 1 - fpad_bottom) + (stride_h - 1)
            bpad_left = filter_w - 1 - fpad_left
            bpad_right = (filter_w - 1 - fpad_right) + (stride_w - 1)

            padded_out_grad = np.zeros((batch, dilated_out_grad_np.shape[1]+bpad_top+bpad_bottom,
                                        dilated_out_grad_np.shape[2]+bpad_left+bpad_right, out_channel))
            padded_out_grad[:, bpad_top:dilated_out_grad_np.shape[1]+bpad_top,
                            bpad_left:dilated_out_grad_np.shape[2]+bpad_left, :] = dilated_out_grad_np

            # Reference: full correlation of the padded gradient with each
            # filter slice, accumulated over the channel multiplier.
            in_grad_np = np.zeros((batch, in_h, in_w, in_channel))
            for b in range(batch):
                for c in range(in_channel):
                    for m in range(channel_multiplier):
                        in_grad_np[b, :, :, c] += signal.convolve2d(
                            padded_out_grad[b, :, :, c*channel_multiplier+m],
                            filter_np[:, :, c, m], mode='valid')[0:in_h, 0:in_w]
            return (out_grad_np, filter_np, in_grad_np)

        (out_grad_np, filter_np, in_grad_np) = get_ref_data()

        out_grad_tvm = tvm.nd.array(out_grad_np, ctx)
        filter_tvm = tvm.nd.array(filter_np, ctx)
        in_grad_tvm = tvm.nd.array(np.zeros(shape=ishape, dtype=dtype), ctx)
        # launch the kernel (timing result is not needed, only the side effect)
        timer = f.time_evaluator(f.entry_name, ctx, number=1)
        timer(filter_tvm, out_grad_tvm, in_grad_tvm)
        np.testing.assert_allclose(in_grad_np, in_grad_tvm.asnumpy(), rtol=1e-5)

    check_device("opencl")
    check_device("cuda")
    check_device("metal")
|
||||
|
||||
|
||||
def test_topi_depthwise_conv2d_backward_input_nhwc():
    """Exercise backward-wrt-input over stride/filter/multiplier combinations."""
    # (filter, padding) pairs that preserve the spatial size at stride 1.
    for stride in (1, 2):
        for filt, pad_h in ((3, 1), (5, 2)):
            for mult in (1, 2):
                verify_depthwise_conv2d_back_input(16, 256, 56, mult, filt, stride, pad_h)

    # The same sweep with zero padding.
    for stride in (1, 2):
        for filt in (3, 5):
            for mult in (1, 2):
                verify_depthwise_conv2d_back_input(16, 256, 56, mult, filt, stride, 0)
|
||||
|
||||
|
||||
# Allow running this test file directly as a script.
if __name__ == "__main__":
    test_topi_depthwise_conv2d_backward_input_nhwc()
|
|
@ -0,0 +1,101 @@
|
|||
import tvm
|
||||
import topi
|
||||
import numpy as np
|
||||
from tvm.contrib.pickle_memoize import memoize
|
||||
from scipy import signal
|
||||
from topi.util import get_const_tuple
|
||||
from topi.nn.util import get_pad_tuple
|
||||
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
|
||||
|
||||
|
||||
def verify_depthwise_conv2d_back_weight(batch, in_channel, in_h, channel_multiplier, filter_h, stride_h, padding_h):
    """Check NHWC depthwise_conv2d backward wrt weight against a scipy reference.

    Square geometry: width, filter width, stride and padding along width are
    derived from the corresponding height parameters.
    """
    in_w = in_h
    filter_w = filter_h
    stride_w = stride_h
    padding_w = padding_h

    # np.int was removed in NumPy 1.24; it was only an alias for the builtin int.
    out_h = int((in_h+2*padding_h-filter_h)/stride_h+1)
    out_w = int((in_w+2*padding_w-filter_w)/stride_w+1)
    out_channel = in_channel * channel_multiplier

    oshape = [batch, out_h, out_w, out_channel]
    fshape = [filter_h, filter_w, in_channel, channel_multiplier]

    # placeholder
    Out_grad = tvm.placeholder(oshape, name='Out_grad')
    # NOTE(review): the placeholder name 'In_grad' looks like a copy-paste
    # leftover (this is the forward input) — kept to preserve behavior.
    Input = tvm.placeholder((batch, in_h, in_w, in_channel), name='In_grad')
    # declare
    Weight_grad = topi.nn.depthwise_conv2d_backward_weight_nhwc(Input, Out_grad, oshape, fshape,
                                                                stride=[stride_h, stride_w],
                                                                padding=[padding_h, padding_w])
    # schedule
    schedule = schedule_depthwise_conv2d_backward_weight_nhwc(Weight_grad)

    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        ctx = tvm.context(device, 0)
        # build the kernel
        f = tvm.build(schedule, [Input, Out_grad, Weight_grad], device)
        # prepare pod type for test data closure
        dtype = Out_grad.dtype
        out_grad_shape = get_const_tuple(Out_grad.shape)
        in_shape = get_const_tuple(Input.shape)

        # use memoize to pickle the test data for next time use
        @memoize("topi.tests.test_topi_depthwise_conv2d_backward_weight.nhwc")
        def get_ref_data():
            out_grad_np = np.random.uniform(size=out_grad_shape).astype(dtype)
            input_np = np.random.uniform(size=in_shape).astype(dtype)
            dilated_out_grad_np = topi.testing.dilate_python(out_grad_np, [1, stride_h, stride_w, 1])

            pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
                [padding_h, padding_w], (filter_h, filter_w))
            padded_input_np = np.zeros((batch, in_h+pad_top+pad_bottom,
                                        in_w+pad_left+pad_right, in_channel))
            padded_input_np[:, pad_top:in_h+pad_top, pad_left:in_w+pad_left, :] = input_np

            # Reference: valid correlation of the padded input with the
            # 180-degree-rotated dilated gradient, accumulated over the batch.
            weight_grad_np = np.zeros((filter_h, filter_w, in_channel, channel_multiplier))
            for c in range(in_channel):
                for m in range(channel_multiplier):
                    for b in range(batch):
                        weight_grad_np[:, :, c, m] += signal.convolve2d(
                            padded_input_np[b, :, :, c],
                            np.rot90(dilated_out_grad_np[b, :, :, c*channel_multiplier+m], 2),
                            mode='valid')[0:filter_h, 0:filter_w]
            return (out_grad_np, input_np, weight_grad_np)

        (out_grad_np, input_np, weight_grad_np) = get_ref_data()

        out_grad_tvm = tvm.nd.array(out_grad_np, ctx)
        input_tvm = tvm.nd.array(input_np, ctx)
        weight_grad_tvm = tvm.nd.array(np.zeros(shape=fshape, dtype=dtype), ctx)
        # launch the kernel (timing result is not needed, only the side effect)
        timer = f.time_evaluator(f.entry_name, ctx, number=1)
        timer(input_tvm, out_grad_tvm, weight_grad_tvm)
        np.testing.assert_allclose(weight_grad_np, weight_grad_tvm.asnumpy(), rtol=1e-4)

    check_device("opencl")
    check_device("cuda")
    check_device("metal")
|
||||
|
||||
|
||||
def test_topi_depthwise_conv2d_backward_weight_nhwc():
    """Exercise backward-wrt-weight over stride/filter/multiplier combinations."""
    verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 3, 1, 1)
    verify_depthwise_conv2d_back_weight(16, 256, 56, 2, 3, 1, 1)
    verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 5, 1, 2)
    verify_depthwise_conv2d_back_weight(16, 256, 56, 2, 5, 1, 2)
    verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 3, 2, 1)
    verify_depthwise_conv2d_back_weight(16, 256, 56, 2, 3, 2, 1)
    verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 5, 2, 2)
    verify_depthwise_conv2d_back_weight(16, 256, 56, 2, 5, 2, 2)

    verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 3, 1, 0)
    verify_depthwise_conv2d_back_weight(16, 256, 56, 2, 3, 1, 0)
    verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 5, 1, 0)
    verify_depthwise_conv2d_back_weight(16, 256, 56, 2, 5, 1, 0)
    verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 3, 2, 0)
    verify_depthwise_conv2d_back_weight(16, 256, 56, 2, 3, 2, 0)
    verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 5, 2, 0)
    # Batch fixed to 16 for consistency with every other case (was 15,
    # which looked like a typo breaking the parameter sweep pattern).
    verify_depthwise_conv2d_back_weight(16, 256, 56, 2, 5, 2, 0)
|
||||
|
||||
# Allow running this test file directly as a script.
if __name__ == "__main__":
    test_topi_depthwise_conv2d_backward_weight_nhwc()
|
Loading…
Reference in new issue