[TOPI] Automated schedule in conv2d TOPI lib, moving to GEMM intrinsic (#35)
* removing FPGA programming from the end-to-end example for now
* updating the TOPI library to use the GEMM tensor intrinsic
* bug fix: auto-schedule in the TOPI conv lib
* removing the deprecated GEVM intrinsic
* refactoring; fixed lint test
* fix for an integer division bug
* Python 3 bug fix for non-matching types due to float division
* comment
This commit is contained in:
Parent: dae77cdb4c
Commit: a96a4a9bcc
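The headline change is the switch from the vector-matrix (GEVM) intrinsic to a matrix-matrix (GEMM) intrinsic that carries an explicit batch axis. A plain-numpy sketch of the semantic difference (illustrative only; the block sizes are assumptions, not read from a VTA config):

    import numpy as np

    BATCH, BLOCK_IN, BLOCK_OUT = 1, 16, 16
    wgt = np.random.randint(-8, 8, (BLOCK_OUT, BLOCK_IN)).astype("int8")

    # GEVM (removed below): vector-matrix product, one input row at a time.
    inp_vec = np.random.randint(-8, 8, (BLOCK_IN,)).astype("int8")
    out_vec = wgt.astype("int32") @ inp_vec.astype("int32")      # (BLOCK_OUT,)

    # GEMM (kept): matrix-matrix product over a BATCH of input rows.
    inp_mat = np.random.randint(-8, 8, (BATCH, BLOCK_IN)).astype("int8")
    out_mat = inp_mat.astype("int32") @ wgt.astype("int32").T    # (BATCH, BLOCK_OUT)
    assert out_mat.shape == (BATCH, BLOCK_OUT)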
@@ -31,11 +31,6 @@ for file in [TEST_FILE, CATEG_FILE, RESNET_GRAPH_FILE, RESNET_PARAMS_FILE, BITST
         print ("Downloading {}".format(file))
         wget.download(url+file)
 
-# Program the FPGA remotely
-assert tvm.module.enabled("rpc")
-remote = rpc.connect(host, port)
-vta.program_fpga(remote, BITSTREAM_FILE)
-
 if verbose:
     logging.basicConfig(level=logging.DEBUG)
 
@@ -129,8 +124,10 @@ with nnvm.compiler.build_config(opt_level=3):
                                      params=params, target_host=target_host)
 
 
+assert tvm.module.enabled("rpc")
 temp = util.tempdir()
 lib.save(temp.relpath("graphlib.o"))
+remote = rpc.connect(host, port)
 remote.upload(temp.relpath("graphlib.o"))
 lib = remote.load_module("graphlib.o")
 ctx = remote.ext_dev(0) if target.device_name == "vta" else remote.cpu(0)
@@ -55,7 +55,6 @@ class DevContext(object):
         self.DEBUG_NO_SYNC = False
         env._dev_ctx = self
         self.gemm = intrin.gemm(env, env.mock_mode)
-        self.gevm = intrin.gevm(env, env.mock_mode)
 
     def get_task_qid(self, qid):
         """Get transformed queue index."""
@@ -204,11 +203,6 @@ class Environment(object):
         """GEMM intrinsic"""
         return self.dev.gemm
 
-    @property
-    def gevm(self):
-        """GEVM intrinsic"""
-        return self.dev.gevm
-
     @property
     def target_host(self):
         """The target host"""
@@ -3,88 +3,6 @@ from __future__ import absolute_import as _abs
 
 import tvm
 
-def gevm(env, mock=False):
-    """Vector-matrix multiply intrinsic
-
-    Parameters
-    ----------
-    env : Environment
-        The Environment
-
-    mock : bool
-        Whether to create a mock version.
-    """
-    wgt_lanes = env.WGT_ELEM_BITS // env.WGT_WIDTH
-    assert wgt_lanes == env.BLOCK_OUT * env.BLOCK_IN
-    wgt_shape = (env.BLOCK_OUT, env.BLOCK_IN)
-    assert wgt_shape[0] * wgt_shape[1] == wgt_lanes
-    inp_lanes = env.INP_ELEM_BITS // env.INP_WIDTH
-    out_lanes = env.ACC_ELEM_BITS // env.ACC_WIDTH
-    wgt = tvm.placeholder((wgt_shape[0], wgt_shape[1]),
-                          dtype="int%d" % env.WGT_WIDTH,
-                          name=env.wgt_scope)
-    inp = tvm.placeholder((wgt_shape[1], ),
-                          dtype="int%d" % env.INP_WIDTH,
-                          name=env.inp_scope)
-    k = tvm.reduce_axis((0, wgt_shape[1]), name="k")
-    out_dtype = "int%d" % env.ACC_WIDTH
-    out = tvm.compute((wgt_shape[0],),
-                      lambda i: tvm.sum(inp[k].astype(out_dtype) *
-                                        wgt[i, k].astype(out_dtype),
-                                        axis=[k]),
-                      name="out")
-    wgt_layout = tvm.decl_buffer(
-        wgt.shape, wgt.dtype, env.wgt_scope,
-        scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes)
-    inp_layout = tvm.decl_buffer(
-        inp.shape, inp.dtype, env.inp_scope,
-        scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes)
-    out_layout = tvm.decl_buffer(
-        out.shape, out.dtype, env.acc_scope,
-        scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes)
-
-    def intrin_func(ins, outs):
-        """Vector-matrix multiply intrinsic function"""
-        dinp, dwgt = ins
-        dout = outs[0]
-        def instr(index):
-            """Generate vector-matrix multiply VTA instruction"""
-            irb = tvm.ir_builder.create()
-            dev = env.dev
-            irb.scope_attr(dev.vta_axis, "coproc_scope",
-                           dev.get_task_qid(dev.QID_COMPUTE))
-            irb.scope_attr(dev.vta_axis, "coproc_uop_scope",
-                           dev.vta_push_uop)
-            if index == 0 or index == 2:
-                irb.emit(tvm.call_extern(
-                    "int32", "VTAUopPush",
-                    0, 0,
-                    dout.access_ptr("rw", "int32"),
-                    dinp.access_ptr("r", "int32"),
-                    dwgt.access_ptr("r", "int32"),
-                    0, 0, 0))
-            else:
-                irb.emit(tvm.call_extern(
-                    "int32", "VTAUopPush",
-                    0, 1,
-                    dout.access_ptr("rw", "int32"),
-                    0,
-                    0,
-                    0, 0, 0))
-            return irb.get()
-        # return a triple of normal-set, reset, update
-        nop = tvm.make.Evaluate(0)
-        if mock:
-            return (nop, nop, nop)
-        return (instr(0), instr(1), instr(2))
-
-    return tvm.decl_tensor_intrin(out.op, intrin_func,
-                                  name="GEVM",
-                                  binds={inp: inp_layout,
-                                         wgt: wgt_layout,
-                                         out: out_layout})
-
-
 def gemm(env, mock=False):
     """Matrix-matrix multiply intrinsic
 
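For orientation: the intrinsic's intrin_func returns the triple (instr(0), instr(1), instr(2)) — normal, reset, update — which tensorize splices around the reduction loop. A rough numpy model of that contract (illustrative only; the real VTAUopPush arguments encode scratchpad offsets, which are opaque here):

    import numpy as np

    BLOCK_IN, BLOCK_OUT = 16, 16
    acc = np.empty(BLOCK_OUT, dtype="int32")   # stands in for the accumulator scratchpad

    def reset():                               # instr(1): VTAUopPush(0, 1, ...)
        acc[:] = 0

    def update(inp, wgt):                      # instr(0)/instr(2): VTAUopPush(0, 0, ...)
        acc[:] += wgt.astype("int32") @ inp.astype("int32")

    reset()
    for _ in range(4):                         # four reduction tiles
        update(np.ones(BLOCK_IN, "int8"), np.eye(BLOCK_OUT, BLOCK_IN, dtype="int8"))
    assert (acc == 4).all()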
@@ -12,9 +12,146 @@ from ..environment import get_env
 
 Workload = namedtuple("Conv2DWorkload",
-                      ['height', 'width', 'in_filter', 'out_filter',
+                      ['batch', 'height', 'width', 'in_filter', 'out_filter',
                        'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
 
+def find_schedules(layer, vt_only=False, best_only=False):
+    """Returns a list of valid schedules for a given layer.
+
+    Parameters
+    ----------
+    layer : Workload
+        Convolutional layer description.
+    vt_only : Boolean
+        Only produce schedule plans that use virtual threading.
+    best_only : Boolean
+        Only return the "best" schedule plan.
+
+    Returns
+    -------
+    fil_sched : list
+        List of valid schedules.
+    """
+    # pylint: disable=too-many-nested-blocks
+    env = get_env()
+
+    # Helper function to get factors
+    def _find_factors(n):
+        factors = []
+        for f in range(1, n + 1):
+            if n % f == 0:
+                factors.append(f)
+        return factors
+
+    def _get_data_movement_byte(schedule, layer):
+        """Estimate data movement in bytes for the schedule plan"""
+        env = get_env()
+        b_f = schedule.b_factor
+        h_f = schedule.h_factor
+        w_f = schedule.w_factor
+        ci_f = schedule.ic_factor
+        co_f = schedule.oc_factor
+        # Derive data movement
+        inp_elem_sizeb = env.BATCH * env.BLOCK_IN * env.INP_WIDTH
+        wgt_elem_sizeb = env.BLOCK_IN * env.BLOCK_OUT * env.WGT_WIDTH
+        out_elem_sizeb = env.BATCH * env.BLOCK_OUT * env.OUT_WIDTH
+        input_tile_elems = b_f * \
+            ((h_f - 1) * layer.hstride + layer.hkernel) * \
+            ((w_f - 1) * layer.wstride + layer.wkernel) * ci_f
+        weight_tile_elems = layer.hkernel * layer.wkernel * ci_f
+        output_tile_elems = b_f * h_f * w_f * co_f
+        # Derive tiling factors
+        b_factor = layer.batch // (b_f * env.BATCH)
+        h_factor = (layer.height // layer.hstride) // h_f
+        w_factor = (layer.width // layer.wstride) // w_f
+        ci_factor = layer.in_filter // (ci_f * env.BLOCK_IN)
+        co_factor = layer.out_filter // (co_f * env.BLOCK_OUT)
+        # Compute input transaction count
+        input_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor
+        weight_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor
+        output_xfers = b_factor * h_factor * w_factor * co_factor
+        # Compute total transfer sizes
+        input_xfer_byte = input_tile_elems * input_xfers * inp_elem_sizeb // 8
+        weight_xfer_byte = weight_tile_elems * weight_xfers * wgt_elem_sizeb // 8
+        output_xfer_byte = output_tile_elems * output_xfers * out_elem_sizeb // 8
+        total_xfer_byte = input_xfer_byte + weight_xfer_byte + output_xfer_byte
+        return total_xfer_byte
+
+    # Scheduling exploration
+    batch_factors = _find_factors(layer.batch // env.BATCH)
+    height_factors = _find_factors(layer.height // layer.hstride)
+    width_factors = _find_factors(layer.width // layer.wstride)
+    cin_factors = _find_factors(layer.in_filter // env.BLOCK_IN)
+    cout_factors = _find_factors(layer.out_filter // env.BLOCK_OUT)
+    ht_factors = [1, 2]
+    cot_factors = [1, 2]
+
+    # Explore schedules
+    schedules = []
+    for b_f in batch_factors:
+        for h_f in height_factors:
+            for w_f in width_factors:
+                for ci_f in cin_factors:
+                    for co_f in cout_factors:
+                        # FIXME: 2D load pattern matching imposes restrictions on schedule
+                        valid = (w_f == layer.width // layer.wstride) or \
+                                (w_f != layer.width // layer.wstride and co_f == 1) and \
+                                ci_f == 1
+                        if valid:
+                            schedules.append([b_f, h_f, w_f, ci_f, co_f])
+
+    # Filter out schedules that won't fit in the available BRAM sizes
+    inp_elem_sizeb = env.BATCH * env.BLOCK_IN * env.INP_WIDTH
+    wgt_elem_sizeb = env.BLOCK_IN * env.BLOCK_OUT * env.WGT_WIDTH
+    out_elem_sizeb = env.BATCH * env.BLOCK_OUT * env.OUT_WIDTH
+    inp_brams_sizeb = env.INP_BUFF_SIZE * 8
+    wgt_brams_sizeb = env.WGT_BUFF_SIZE * 8
+    out_brams_sizeb = env.OUT_BUFF_SIZE * 8
+    fil_sched = []
+    xfer_size = []
+    for sched in schedules:
+        b_f, h_f, w_f, ci_f, co_f = sched
+        for h_t in ht_factors:
+            for co_t in cot_factors:
+                # Make sure to filter cases where we apply threading on two axes
+                # or cases where the threading factors for h and co are not
+                # factors of h and co
+                if (h_t == 2 and co_t == 2) or (h_f % h_t != 0) or (co_f % co_t != 0):
+                    continue
+                # Adjust tile sizes if threading is applied
+                h_f //= h_t
+                co_f //= co_t
+                # Derive tile sizes
+                input_tile_elems = b_f * \
+                    ((h_f - 1) * layer.hstride + layer.hkernel) * \
+                    ((w_f - 1) * layer.wstride + layer.wkernel) * ci_f
+                weight_tile_elems = layer.hkernel * layer.wkernel * ci_f * co_f
+                output_tile_elems = b_f * h_f * w_f * co_f
+
+                # Derive valid schedule filter
+                valid = True
+                # If in virtual-threaded mode, only allow threaded plans
+                valid &= (vt_only and (h_t == 2 or co_t == 2)) or not vt_only
+                # Check that we don't exceed input/weight/output capacity
+                valid &= input_tile_elems * inp_elem_sizeb <= inp_brams_sizeb // (co_t * h_t)
+                valid &= weight_tile_elems * wgt_elem_sizeb <= wgt_brams_sizeb
+                valid &= output_tile_elems * out_elem_sizeb <= out_brams_sizeb // (co_t * h_t)
+                # Make sure that we don't write to the same acc location within 2 consecutive cycles
+                valid &= h_f > 2 and w_f > 2
+                # TODO: check that we don't exceed instruction or micro-op count
+
+                if valid:
+                    schedule = Schedule(b_factor=b_f, oc_factor=co_f, ic_factor=ci_f, h_factor=h_f,
+                                        w_factor=w_f, oc_nthread=co_t, h_nthread=h_t)
+                    fil_sched.append(schedule)
+                    xfer_size.append(_get_data_movement_byte(schedule, layer))
+
+    if best_only:
+        return [fil_sched[xfer_size.index(min(xfer_size))]]
+    return fil_sched
+
 def packed_conv2d(data,
                   kernel,
                   padding,
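A usage sketch for the new scheduler (not part of the diff; it assumes a VTA environment is already configured, since plan validity depends on the BRAM sizes the env reports):

    wkl = Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)   # ResNet-18 layer 1
    plans = find_schedules(wkl, vt_only=True, best_only=False)
    best = find_schedules(wkl, vt_only=True, best_only=True)[0]
    print(len(plans), "valid plans; best:", best)   # factors print as b.oc.ic.h.w.oc_nthread.h_nthread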
@@ -23,14 +160,14 @@ def packed_conv2d(data,
     """ Packed conv2d function.
     """
     if padding[0]:
-        pad_data = topi.nn.pad(data, [0, 0, padding[0], padding[1], 0], name="pad_data")
+        pad_data = topi.nn.pad(data, [0, 0, padding[0], padding[1], 0, 0], name="pad_data")
     else:
         pad_data = data
-    assert len(data.shape) == 5
+    assert len(data.shape) == 6
     assert len(kernel.shape) == 6
     oheight = topi.util.simplify((pad_data.shape[2] - kernel.shape[2]) // strides[0] + 1)
     owidth = topi.util.simplify((pad_data.shape[3] - kernel.shape[3]) // strides[1] + 1)
-    oshape = (data.shape[0], kernel.shape[0], oheight, owidth, kernel.shape[4])
+    oshape = (data.shape[0], kernel.shape[0], oheight, owidth, data.shape[4], kernel.shape[4])
 
     ishape = topi.util.get_const_tuple(data.shape)
     kshape = topi.util.get_const_tuple(kernel.shape)
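A worked example of the output-shape arithmetic above, in plain Python, for ResNet-18 layer 1 (56x56 input, 3x3 kernel, pad 1, stride 1):

    height, hkernel, hpad, hstride = 56, 3, 1, 1
    padded_h = height + 2 * hpad                     # 58 after topi.nn.pad
    oheight = (padded_h - hkernel) // hstride + 1    # (58 - 3) // 1 + 1 = 56
    assert oheight == 56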
@@ -43,14 +180,13 @@
     hstride, wstride = strides
     res = tvm.compute(
         oshape,
-        lambda b, co, i, j, ci: tvm.sum(
-            pad_data[b, k_o, i*hstride+d_i, j*wstride+d_j, k_i].astype(out_dtype) *
-            kernel[co, k_o, d_i, d_j, ci, k_i].astype(out_dtype),
+        lambda b_o, c_o, i, j, b_i, c_i: tvm.sum(
+            pad_data[b_o, k_o, i*hstride+d_i, j*wstride+d_j, b_i, k_i].astype(out_dtype) *
+            kernel[c_o, k_o, d_i, d_j, c_i, k_i].astype(out_dtype),
             axis=[k_o, d_i, d_j, k_i]),
         name="res", tag="packed_conv2d")
     return res
 
 
 @tvm.register_func("nnvm.compiler.build_target", override=True)
 def _build(funcs, target, target_host):
     tvm_t = tvm.target.create(target)
@@ -155,12 +291,13 @@ def _get_workload(data, pad_data, kernel, output):
     o_shape = topi.util.get_const_tuple(output.shape)
     d_shape = topi.util.get_const_tuple(data.shape)
     k_shape = topi.util.get_const_tuple(kernel.shape)
-    o_b, o_c, o_h, o_w, o_blk = o_shape
-    i_b, i_c, i_h, i_w, i_blk = d_shape
+    o_b, o_c, o_h, o_w, ob_blk, o_blk = o_shape
+    i_b, i_c, i_h, i_w, ib_blk, i_blk = d_shape
     k_o, k_i, k_h, k_w, ko_blk, ki_blk = k_shape
     # For now we need to assume that input channel blocking is the same
     # as the output channel blocking
     assert o_blk == i_blk
+    assert ob_blk == ib_blk
     # Make sure that dimensions match
     assert o_b == i_b
     assert o_blk == ko_blk
@@ -178,7 +315,7 @@ def _get_workload(data, pad_data, kernel, output):
         h_pad, w_pad = 0, 0
     h_str = (i_h + h_pad*2 - k_h) // (o_h - 1)
     w_str = (i_w + w_pad*2 - k_w) // (o_w - 1)
-    return Workload(i_h, i_w, i_c, o_c, k_h, k_w, h_pad, w_pad, h_str, w_str)
+    return Workload(i_b, i_h, i_w, i_c, o_c, k_h, k_w, h_pad, w_pad, h_str, w_str)
 
 _WL2PLAN = {}
 
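A worked check of the stride-recovery formula above (plain Python); note that the division by o_h - 1 assumes o_h > 1:

    i_h, k_h, h_pad, o_h = 56, 3, 1, 56        # ResNet-18 layer 1
    h_str = (i_h + h_pad * 2 - k_h) // (o_h - 1)   # 55 // 55
    assert h_str == 1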
@@ -224,7 +361,7 @@ def schedule_packed_conv2d(outs):
 
     load_inp = load_wgt = load_out = store_out = env.dma_copy
     alu = env.alu
-    gevm = env.gevm
+    gemm = env.gemm
 
     # schedule1
     oshape = topi.util.get_const_tuple(output.shape)
@@ -256,11 +393,12 @@ def schedule_packed_conv2d(outs):
     h_factor = (plan.h_factor if plan.h_factor else oshape[2])
     w_factor = (plan.w_factor if plan.w_factor else oshape[3])
 
-    x_b, x_oc, x_i, x_j, x_ic = s[output].op.axis
-    x_oc0, x_oc1 = s[output].split(x_oc, factor=oc_factor)
+    x_bo, x_co, x_i, x_j, x_bi, x_ci = s[output].op.axis
+    x_co0, x_co1 = s[output].split(x_co, factor=oc_factor)
     x_i0, x_i1 = s[output].split(x_i, factor=h_factor)
     x_j0, x_j1 = s[output].split(x_j, factor=w_factor)
-    s[output].reorder(x_b, x_oc0, x_i0, x_j0, x_oc1, x_i1, x_j1, x_ic)
+    s[output].reorder(x_bo, x_i0, x_co0, x_j0, x_co1, x_i1, x_j1, x_bi, x_ci)
     store_pt = x_j0
 
     # set all compute scopes
@@ -273,100 +411,78 @@
         s[tensor].pragma(s[tensor].op.axis[0], load_out)
 
     # virtual threading along output channel axes
-    if plan.oc_nthread:
-        _, v_t = s[output].split(x_oc0, factor=plan.oc_nthread)
-        s[output].reorder(v_t, x_b)
+    if plan.oc_nthread > 1:
+        _, v_t = s[output].split(x_co0, factor=plan.oc_nthread)
+        s[output].reorder(v_t, x_bo)
         s[output].bind(v_t, tvm.thread_axis("cthread"))
 
     # virtual threading along spatial rows
-    if plan.h_nthread:
+    if plan.h_nthread > 1:
         _, v_t = s[output].split(x_i0, factor=plan.h_nthread)
-        s[output].reorder(v_t, x_b)
+        s[output].reorder(v_t, x_bo)
         s[output].bind(v_t, tvm.thread_axis("cthread"))
 
-    x_b, x_oc, x_i, x_j, x_ic = s[conv2d_stage].op.axis
+    x_bo, x_co, x_i, x_j, x_bi, x_ci = s[conv2d_stage].op.axis
     k_o, d_i, d_j, k_i = s[conv2d_stage].op.reduce_axis
-    s[conv2d_stage].reorder(k_o, x_j, d_j, d_i, x_oc, x_i, x_ic, k_i)
+    s[conv2d_stage].reorder(x_bo, k_o, x_j, d_j, d_i, x_co, x_i, x_bi, x_ci, k_i)
 
-    if plan.ko_factor:
-        k_o, _ = s[conv2d_stage].split(k_o, factor=plan.ko_factor)
+    if plan.ic_factor:
+        k_o, _ = s[conv2d_stage].split(k_o, factor=plan.ic_factor)
         s[cdata].compute_at(s[conv2d_stage], k_o)
         s[ckernel].compute_at(s[conv2d_stage], k_o)
 
     # Use VTA instructions
     s[cdata].pragma(s[cdata].op.axis[0], load_inp)
     s[ckernel].pragma(s[ckernel].op.axis[0], load_wgt)
-    s[conv2d_stage].tensorize(x_ic, gevm)
-    s[output].pragma(x_oc1, store_out)
+    s[conv2d_stage].tensorize(x_bi, gemm)
+    s[output].pragma(x_co1, store_out)
     return s
 
 
 class Conv2DSchedule(object):
     """ 2D convolution schedule object.
     """
     def __init__(self,
-                 oc_factor,
-                 ko_factor=1,
+                 b_factor=1,
+                 oc_factor=1,
+                 ic_factor=1,
                  h_factor=1,
                  w_factor=0,
                  oc_nthread=0,
-                 h_nthread=0):
+                 h_nthread=0,
+                 debug_sync=False):
+        self.b_factor = b_factor
         self.oc_factor = oc_factor
-        self.ko_factor = ko_factor
+        self.ic_factor = ic_factor
         self.h_factor = h_factor
         self.w_factor = w_factor
         self.oc_nthread = oc_nthread
         self.h_nthread = h_nthread
+        self.debug_sync = debug_sync
+
+    def __str__(self):
+        return "{}.{}.{}.{}.{}.{}.{}".format(
+            self.b_factor, self.oc_factor, self.ic_factor,
+            self.h_factor, self.w_factor,
+            self.oc_nthread, self.h_nthread)
 
 Schedule = Conv2DSchedule
 
-# ResNet18 workloads
+# Layer description of the ResNet18
 RESNET = {
     # Workloads of resnet18 on imagenet
-    0: Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
-    1: Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
-    2: Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
-    3: Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
-    4: Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
-    5: Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
-    6: Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
-    7: Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
-    8: Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
-    9: Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
-    10: Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
-    11: Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
+    0: Workload(1, 224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
+    1: Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
+    2: Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
+    3: Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
+    4: Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
+    5: Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
+    6: Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
+    7: Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
+    8: Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
+    9: Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
+    10: Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
+    11: Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
 }
 
-# Serial schedule
-RESNET_SERIAL = {
-    RESNET[0]: Schedule(oc_factor=1, ko_factor=1, h_factor=4, w_factor=56),
-    RESNET[1]: Schedule(oc_factor=2, ko_factor=1, h_factor=14, w_factor=0),
-    RESNET[2]: Schedule(oc_factor=4, ko_factor=4, h_factor=8, w_factor=0),
-    RESNET[3]: Schedule(oc_factor=4, ko_factor=1, h_factor=14, w_factor=0),
-    RESNET[4]: Schedule(oc_factor=8, ko_factor=1, h_factor=4, w_factor=0),
-    RESNET[5]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0),
-    RESNET[6]: Schedule(oc_factor=8, ko_factor=1, h_factor=14, w_factor=0),
-    RESNET[7]: Schedule(oc_factor=16, ko_factor=1, h_factor=7, w_factor=0),
-    RESNET[8]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0),
-    RESNET[9]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0),
-    RESNET[10]: Schedule(oc_factor=16, ko_factor=1, h_factor=7, w_factor=0),
-    RESNET[11]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0),
-}
-
-# Latency hiding schedule
-RESNET_OPT = {
-    RESNET[0]: Schedule(oc_factor=1, ko_factor=1, h_factor=4, w_factor=56),
-    RESNET[1]: Schedule(oc_factor=2, ko_factor=1, h_factor=7, h_nthread=2),
-    RESNET[2]: Schedule(oc_factor=4, ko_factor=2, h_factor=4, w_factor=0, h_nthread=2),
-    RESNET[3]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, h_nthread=2),
-    RESNET[4]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, h_nthread=2),
-    RESNET[5]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, h_nthread=2),
-    RESNET[6]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2),
-    RESNET[7]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2),
-    RESNET[8]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2),
-    RESNET[9]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2),
-    RESNET[10]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2),
-    RESNET[11]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2),
-}
-
-_WL2PLAN = RESNET_OPT
+_WL2PLAN = {}
+for idx in RESNET:
+    scheds = find_schedules(RESNET[idx], vt_only=True, best_only=True)[0]
+    _WL2PLAN[RESNET[idx]] = scheds
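The best_only path used here boils down to an argmin over estimated data movement; a plain-Python mirror of that selection (the plan names and byte counts are made up for the example):

    fil_sched = ["plan_a", "plan_b", "plan_c"]
    xfer_size = [1536, 1024, 2048]            # _get_data_movement_byte per plan
    best = fil_sched[xfer_size.index(min(xfer_size))]
    assert best == "plan_b"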
@@ -315,7 +315,7 @@ VTAGenericInsn getGEMMInsn(int uop_offset, int batch, int in_feat, int out_feat,
                            int push_next_dep) {
   // Converter
   union VTAInsn converter;
-  // GEVM instruction initialization
+  // GEMM instruction initialization
   VTAGemInsn insn;
   insn.opcode = VTA_OPCODE_GEMM;
   insn.pop_prev_dep = pop_prev_dep;
@@ -394,7 +394,7 @@ VTAGenericInsn getALUInsn(int opcode, int vector_size, bool use_imm, int imm, bo
 VTAGenericInsn getFinishInsn(bool pop_prev, bool pop_next) {
   // Converter
   union VTAInsn converter;
-  // GEVM instruction initialization
+  // GEMM instruction initialization
   VTAGemInsn insn;
   insn.opcode = VTA_OPCODE_FINISH;
   insn.pop_prev_dep = pop_prev;
@@ -649,7 +649,7 @@ void printInstruction(int num_insn, VTAGenericInsn *insns) {
       }
     } else if (c.mem.opcode == VTA_OPCODE_GEMM) {
       // Print instruction field information
-      printf("GEVM\n");
+      printf("GEMM\n");
       printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n",
              static_cast<int>(c.mem.pop_prev_dep),
              static_cast<int>(c.mem.pop_next_dep),
@@ -168,7 +168,7 @@ def test_gemm():
         with vta.build_config():
             run_test("NORMAL", print_ir, True)
 
-    def gevm_unittest(print_ir):
+    def gemm_unittest(print_ir):
         mock = env.mock
         print("----- GEMM Unit Test-------")
         def run_test(header, print_ir):
@@ -244,7 +244,7 @@ def test_gemm():
 
 
     gemm_normal(False)
-    gevm_unittest(False)
+    gemm_unittest(False)
     alu_unittest(False)
 
 def _run(env, remote):
@@ -23,14 +23,11 @@ def my_clip(x, a_min, a_max):
 
 def test_vta_conv2d():
     def run_vta_conv2d(env, remote, key, batch_size, wl, profile=True):
-        data_shape = (batch_size, wl.in_filter // env.BLOCK_IN,
-                      wl.height, wl.width, env.BLOCK_IN)
-        kernel_shape = (wl.out_filter // env.BLOCK_OUT,
-                        wl.in_filter // env.BLOCK_IN,
-                        wl.hkernel, wl.wkernel,
-                        env.BLOCK_OUT, env.BLOCK_IN)
-        bias_shape = (wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)
+        data_shape = (batch_size//env.BATCH, wl.in_filter//env.BLOCK_IN,
+                      wl.height, wl.width, env.BATCH, env.BLOCK_IN)
+        kernel_shape = (wl.out_filter//env.BLOCK_OUT, wl.in_filter//env.BLOCK_IN,
+                        wl.hkernel, wl.wkernel, env.BLOCK_OUT, env.BLOCK_IN)
+        bias_shape = (1, wl.out_filter//env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT)
 
         fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
         fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
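A worked example of the packed-shape arithmetic above (plain Python, assuming BATCH=1 and BLOCK_IN=BLOCK_OUT=16, a common VTA config not read from the env here):

    batch_size, BATCH, BLOCK_IN, BLOCK_OUT = 1, 1, 16, 16
    in_filter, height, width = 64, 56, 56                  # ResNet-18 layer 1
    data_shape = (batch_size // BATCH, in_filter // BLOCK_IN,
                  height, width, BATCH, BLOCK_IN)
    assert data_shape == (1, 4, 56, 56, 1, 16)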
@@ -45,12 +42,13 @@ def test_vta_conv2d():
         res = my_clip(res, 0, 127)
         res = topi.cast(res, "int8")
 
-        num_ops = fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
+        num_ops = 2 * batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
 
         a_shape = (batch_size, wl.in_filter, wl.height, wl.width)
         w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)
         stride = (wl.hstride, wl.wstride)
         data_dtype = data.dtype
+        kernel_dtype = kernel.dtype
         acc_dtype = env.acc_dtype
         assert wl.hpad == wl.wpad
         padding = wl.hpad
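A worked example of the new op count (plain Python): each output pixel of each filter performs hkernel*wkernel*in_filter multiply-accumulates, and the leading 2 counts the multiply and the add as separate ops:

    batch_size, fout_height, fout_width = 1, 56, 56
    hkernel, wkernel, in_filter, out_filter = 3, 3, 64, 64
    num_ops = 2 * batch_size * fout_height * fout_width * \
        hkernel * wkernel * out_filter * in_filter
    assert num_ops == 231211008   # ~0.23 GOPs for this layer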
@@ -58,7 +56,7 @@ def test_vta_conv2d():
         @memoize("vta.tests.test_benchmark_topi.conv2d,verify_nhwc")
         def get_ref_data():
             a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype)
-            w_np = (np.random.uniform(size=w_shape) * 4).astype(data_dtype)
+            w_np = (np.random.uniform(size=w_shape) * 4).astype(kernel_dtype)
             a_np = np.abs(a_np)
             w_np = np.abs(w_np)
             b_np = topi.testing.conv2d_nchw_python(
@@ -82,14 +80,15 @@ def test_vta_conv2d():
         bias_orig = np.abs(bias_orig)
 
         data_packed = data_orig.reshape(
-            batch_size, wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
-            wl.height, wl.width).transpose((0, 1, 3, 4, 2))
+            batch_size//env.BATCH, env.BATCH,
+            wl.in_filter//env.BLOCK_IN, env.BLOCK_IN,
+            wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3))
         kernel_packed = kernel_orig.reshape(
-            wl.out_filter // env.BLOCK_OUT, env.BLOCK_OUT,
-            wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
+            wl.out_filter//env.BLOCK_OUT, env.BLOCK_OUT,
+            wl.in_filter//env.BLOCK_IN, env.BLOCK_IN,
             wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
         bias_packed = bias_orig.reshape(
-            wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)
+            1, wl.out_filter // env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT)
         res_shape = topi.util.get_const_tuple(res.shape)
 
         res_np = np.zeros(res_shape).astype(res.dtype)
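The packing above is a reshape plus transpose; a standalone numpy round-trip check (shapes assume BATCH=1, BLOCK_IN=16, not taken from the diff):

    import numpy as np

    BATCH, BLOCK_IN = 1, 16
    n, c, h, w = 1, 64, 56, 56
    data_orig = np.arange(n * c * h * w).reshape(n, c, h, w)
    data_packed = data_orig.reshape(
        n // BATCH, BATCH, c // BLOCK_IN, BLOCK_IN, h, w).transpose(
            (0, 2, 4, 5, 1, 3))
    assert data_packed.shape == (1, 4, 56, 56, 1, 16)
    # Invert with the same transpose order the unpacking below uses:
    restored = data_packed.transpose((0, 4, 1, 5, 2, 3)).reshape(n, c, h, w)
    assert (restored == data_orig).all()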
@@ -100,7 +99,7 @@ def test_vta_conv2d():
         time_f = f.time_evaluator("conv2d", ctx, number=5)
         cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
         res_unpack = res_arr.asnumpy().transpose(
-            (0, 1, 4, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width)
+            (0, 4, 1, 5, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width)
         if check_correctness:
             assert wl.hpad == wl.wpad
             stride = (wl.hstride, wl.wstride)
@@ -127,18 +126,18 @@ def test_vta_conv2d():
     # ResNet18 workloads
     resnet = {
         # Workloads of resnet18 on imagenet
-        0: Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
-        1: Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
-        2: Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
-        3: Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
-        4: Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
-        5: Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
-        6: Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
-        7: Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
-        8: Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
-        9: Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
-        10: Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
-        11: Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
+        0: Workload(1, 224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
+        1: Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
+        2: Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
+        3: Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
+        4: Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
+        5: Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
+        6: Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
+        7: Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
+        8: Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
+        9: Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
+        10: Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
+        11: Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
     }
 
     batch_size = 1
@@ -148,6 +147,7 @@ def test_vta_conv2d():
             print("key=%s" % key)
+            print(wl)
             run_vta_conv2d(env, remote, key, batch_size, wl)
 
 vta.testing.run(_run)