From a96a4a9bcc771e8252a2dbe9cc01bf0357d09926 Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Fri, 11 May 2018 17:17:33 -0700 Subject: [PATCH] [TOPI] Automated schedule in conv2d TOPI lib, moving to GEMM intrinsic (#35) * removing programming out of end to end example for now * updating TOPI library to use gemm tensor intrinsic * bug fix, autoschedule in TOPI conv lib * removing the deprecated GEVM intrinsic * refactoring, fixed lint test * fix for integer division bug * python3 bug fix for non matching types due to float division * comment --- .../resnet18/pynq/imagenet_predict.py | 7 +- vta/python/vta/environment.py | 6 - vta/python/vta/intrin.py | 82 ------ vta/python/vta/top/vta_conv2d.py | 272 +++++++++++++----- vta/tests/hardware/common/test_lib.cc | 6 +- .../python/integration/test_benchmark_gemm.py | 4 +- .../integration/test_benchmark_topi_conv2d.py | 56 ++-- 7 files changed, 229 insertions(+), 204 deletions(-) diff --git a/vta/examples/resnet18/pynq/imagenet_predict.py b/vta/examples/resnet18/pynq/imagenet_predict.py index ae8d4901..e4f82b17 100644 --- a/vta/examples/resnet18/pynq/imagenet_predict.py +++ b/vta/examples/resnet18/pynq/imagenet_predict.py @@ -31,11 +31,6 @@ for file in [TEST_FILE, CATEG_FILE, RESNET_GRAPH_FILE, RESNET_PARAMS_FILE, BITST print ("Downloading {}".format(file)) wget.download(url+file) -# Program the FPGA remotely -assert tvm.module.enabled("rpc") -remote = rpc.connect(host, port) -vta.program_fpga(remote, BITSTREAM_FILE) - if verbose: logging.basicConfig(level=logging.DEBUG) @@ -129,8 +124,10 @@ with nnvm.compiler.build_config(opt_level=3): params=params, target_host=target_host) +assert tvm.module.enabled("rpc") temp = util.tempdir() lib.save(temp.relpath("graphlib.o")) +remote = rpc.connect(host, port) remote.upload(temp.relpath("graphlib.o")) lib = remote.load_module("graphlib.o") ctx = remote.ext_dev(0) if target.device_name == "vta" else remote.cpu(0) diff --git a/vta/python/vta/environment.py b/vta/python/vta/environment.py index feaacbe4..41ec38ae 100644 --- a/vta/python/vta/environment.py +++ b/vta/python/vta/environment.py @@ -55,7 +55,6 @@ class DevContext(object): self.DEBUG_NO_SYNC = False env._dev_ctx = self self.gemm = intrin.gemm(env, env.mock_mode) - self.gevm = intrin.gevm(env, env.mock_mode) def get_task_qid(self, qid): """Get transformed queue index.""" @@ -204,11 +203,6 @@ class Environment(object): """GEMM intrinsic""" return self.dev.gemm - @property - def gevm(self): - """GEVM intrinsic""" - return self.dev.gevm - @property def target_host(self): """The target host""" diff --git a/vta/python/vta/intrin.py b/vta/python/vta/intrin.py index 8373a2c0..b3662875 100644 --- a/vta/python/vta/intrin.py +++ b/vta/python/vta/intrin.py @@ -3,88 +3,6 @@ from __future__ import absolute_import as _abs import tvm -def gevm(env, mock=False): - """Vector-matrix multiply intrinsic - - Parameters - ---------- - env : Environment - The Environment - - mock : bool - Whether create a mock version. - """ - wgt_lanes = env.WGT_ELEM_BITS // env.WGT_WIDTH - assert wgt_lanes == env.BLOCK_OUT * env.BLOCK_IN - wgt_shape = (env.BLOCK_OUT, env.BLOCK_IN) - assert wgt_shape[0] * wgt_shape[1] == wgt_lanes - inp_lanes = env.INP_ELEM_BITS // env.INP_WIDTH - out_lanes = env.ACC_ELEM_BITS // env.ACC_WIDTH - wgt = tvm.placeholder((wgt_shape[0], wgt_shape[1]), - dtype="int%d" % env.WGT_WIDTH, - name=env.wgt_scope) - inp = tvm.placeholder((wgt_shape[1], ), - dtype="int%d" % env.INP_WIDTH, - name=env.inp_scope) - k = tvm.reduce_axis((0, wgt_shape[1]), name="k") - out_dtype = "int%d" % env.ACC_WIDTH - out = tvm.compute((wgt_shape[0],), - lambda i: tvm.sum(inp[k].astype(out_dtype) * - wgt[i, k].astype(out_dtype), - axis=[k]), - name="out") - wgt_layout = tvm.decl_buffer( - wgt.shape, wgt.dtype, env.wgt_scope, - scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes) - inp_layout = tvm.decl_buffer( - inp.shape, inp.dtype, env.inp_scope, - scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes) - out_layout = tvm.decl_buffer( - out.shape, out.dtype, env.acc_scope, - scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes) - - def intrin_func(ins, outs): - """Vector-matrix multiply intrinsic function""" - dinp, dwgt = ins - dout = outs[0] - def instr(index): - """Generate vector-matrix multiply VTA instruction""" - irb = tvm.ir_builder.create() - dev = env.dev - irb.scope_attr(dev.vta_axis, "coproc_scope", - dev.get_task_qid(dev.QID_COMPUTE)) - irb.scope_attr(dev.vta_axis, "coproc_uop_scope", - dev.vta_push_uop) - if index == 0 or index == 2: - irb.emit(tvm.call_extern( - "int32", "VTAUopPush", - 0, 0, - dout.access_ptr("rw", "int32"), - dinp.access_ptr("r", "int32"), - dwgt.access_ptr("r", "int32"), - 0, 0, 0)) - else: - irb.emit(tvm.call_extern( - "int32", "VTAUopPush", - 0, 1, - dout.access_ptr("rw", "int32"), - 0, - 0, - 0, 0, 0)) - return irb.get() - # return a triple of normal-set, reset, update - nop = tvm.make.Evaluate(0) - if mock: - return (nop, nop, nop) - return (instr(0), instr(1), instr(2)) - - return tvm.decl_tensor_intrin(out.op, intrin_func, - name="GEVM", - binds={inp: inp_layout, - wgt: wgt_layout, - out: out_layout}) - - def gemm(env, mock=False): """Matrix-matrix multiply intrinsic diff --git a/vta/python/vta/top/vta_conv2d.py b/vta/python/vta/top/vta_conv2d.py index 577eac8e..489acb1c 100644 --- a/vta/python/vta/top/vta_conv2d.py +++ b/vta/python/vta/top/vta_conv2d.py @@ -12,9 +12,146 @@ from ..environment import get_env Workload = namedtuple("Conv2DWorkload", - ['height', 'width', 'in_filter', 'out_filter', + ['batch', 'height', 'width', 'in_filter', 'out_filter', 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) +def find_schedules(layer, vt_only=False, best_only=False): + """ Returns a schedule for a given a layer. + + Parameters + ---------- + layer : Workload + Convolutional layer description. + vt_only : Boolean + Produce a schedule plan with virtual threading. + best_only : Boolean + Return the "best" schedule plan. + + Returns + ------- + fil_sched : list + List of valid schedules. + + """ + # pylint: disable=too-many-nested-blocks + env = get_env() + + # Helper function to get factors + def _find_factors(n): + factors = [] + for f in range(1, n + 1): + if n % f == 0: + factors.append(f) + return factors + + def _get_data_movement_byte(schedule, layer): + """ Estimate data movement in bytes for the schedule plan + """ + env = get_env() + b_f = schedule.b_factor + h_f = schedule.h_factor + w_f = schedule.w_factor + ci_f = schedule.ic_factor + co_f = schedule.oc_factor + # Derive data movement + inp_elem_sizeb = env.BATCH * env.BLOCK_IN * env.INP_WIDTH + wgt_elem_sizeb = env.BLOCK_IN * env.BLOCK_OUT * env.WGT_WIDTH + out_elem_sizeb = env.BATCH * env.BLOCK_OUT * env.OUT_WIDTH + input_tile_elems = b_f * \ + ((h_f - 1) * layer.hstride + layer.hkernel) * \ + ((w_f - 1) * layer.wstride + layer.wkernel) * ci_f + weight_tile_elems = layer.hkernel * layer.wkernel * ci_f + output_tile_elems = b_f * h_f * w_f * co_f + # Derive tiling factors + b_factor = layer.batch // (b_f * env.BATCH) + h_factor = (layer.height // layer.hstride) // h_f + w_factor = (layer.width // layer.wstride) // w_f + ci_factor = layer.in_filter // (ci_f * env.BLOCK_IN) + co_factor = layer.out_filter // (co_f * env.BLOCK_OUT) + # Compute input transaction count + input_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor + weight_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor + output_xfers = b_factor * h_factor * w_factor * co_factor + # Compute total transfer sizes + input_xfer_byte = input_tile_elems * input_xfers * inp_elem_sizeb // 8 + weight_xfer_byte = weight_tile_elems * weight_xfers * wgt_elem_sizeb // 8 + output_xfer_byte = output_tile_elems * output_xfers * out_elem_sizeb // 8 + total_xfer_byte = input_xfer_byte + weight_xfer_byte + output_xfer_byte + return total_xfer_byte + + # Scheduling exploration + batch_factors = _find_factors(layer.batch // env.BATCH) + height_factors = _find_factors(layer.height // layer.hstride) + width_factors = _find_factors(layer.width // layer.wstride) + cin_factors = _find_factors(layer.in_filter // env.BLOCK_IN) + cout_factors = _find_factors(layer.out_filter // env.BLOCK_OUT) + ht_factors = [1, 2] + cot_factors = [1, 2] + + # Explore schedules + schedules = [] + for b_f in batch_factors: + for h_f in height_factors: + for w_f in width_factors: + for ci_f in cin_factors: + for co_f in cout_factors: + # FIXME: 2D load pattern matching imposes restrictions on schedule + valid = (w_f == layer.width // layer.wstride) or \ + (w_f != layer.width // layer.wstride and co_f == 1) and \ + ci_f == 1 + if valid: + schedules.append([b_f, h_f, w_f, ci_f, co_f]) + + # Filter the schedules that wouldn't work in the available BRAM sizes + inp_elem_sizeb = env.BATCH * env.BLOCK_IN * env.INP_WIDTH + wgt_elem_sizeb = env.BLOCK_IN * env.BLOCK_OUT * env.WGT_WIDTH + out_elem_sizeb = env.BATCH * env.BLOCK_OUT * env.OUT_WIDTH + inp_brams_sizeb = env.INP_BUFF_SIZE * 8 + wgt_brams_sizeb = env.WGT_BUFF_SIZE * 8 + out_brams_sizeb = env.OUT_BUFF_SIZE * 8 + fil_sched = [] + xfer_size = [] + for sched in schedules: + b_f, h_f, w_f, ci_f, co_f = sched + for h_t in ht_factors: + for co_t in cot_factors: + # Make sure to filter cases where we apply threading on two axes + # or cases where the threading factors for h and co are not + # factors of h and co + if (h_t == 2 and co_t == 2) or (h_f % h_t != 0) or (co_f % co_t != 0): + continue + # Adjust tile sizes if threading is applied + h_f //= h_t + co_f //= co_t + # Derive tile sizes + input_tile_elems = b_f * \ + ((h_f - 1) * layer.hstride + layer.hkernel) * \ + ((w_f - 1) * layer.wstride + layer.wkernel) * ci_f + weight_tile_elems = layer.hkernel * layer.wkernel * ci_f * co_f + output_tile_elems = b_f * h_f * w_f * co_f + + # Derive valid schedule filter + valid = True + # If in vitrual-threaded mode, only allow for threaded plans + valid &= (vt_only and (h_t == 2 or co_t == 2)) or not vt_only + # Check that we don't exceed input/weight/output capacity + valid &= input_tile_elems * inp_elem_sizeb <= inp_brams_sizeb // (co_t * h_t) + valid &= weight_tile_elems * wgt_elem_sizeb <= wgt_brams_sizeb + valid &= output_tile_elems * out_elem_sizeb <= out_brams_sizeb // (co_t * h_t) + # Make sure that we don't write to the same acc location within 2 consecutive cycles + valid &= h_f > 2 and w_f > 2 + # TODO: check that we don't exceed instruction or micro-op count + + if valid: + schedule = Schedule(b_factor=b_f, oc_factor=co_f, ic_factor=ci_f, h_factor=h_f, + w_factor=w_f, oc_nthread=co_t, h_nthread=h_t) + fil_sched.append(schedule) + xfer_size.append(_get_data_movement_byte(schedule, layer)) + + if best_only: + return [fil_sched[xfer_size.index(min(xfer_size))]] + return fil_sched + def packed_conv2d(data, kernel, padding, @@ -23,14 +160,14 @@ def packed_conv2d(data, """ Packed conv2d function. """ if padding[0]: - pad_data = topi.nn.pad(data, [0, 0, padding[0], padding[1], 0], name="pad_data") + pad_data = topi.nn.pad(data, [0, 0, padding[0], padding[1], 0, 0], name="pad_data") else: pad_data = data - assert len(data.shape) == 5 + assert len(data.shape) == 6 assert len(kernel.shape) == 6 oheight = topi.util.simplify((pad_data.shape[2] - kernel.shape[2]) // strides[0] + 1) owidth = topi.util.simplify((pad_data.shape[3] - kernel.shape[3]) // strides[1] + 1) - oshape = (data.shape[0], kernel.shape[0], oheight, owidth, kernel.shape[4]) + oshape = (data.shape[0], kernel.shape[0], oheight, owidth, data.shape[4], kernel.shape[4]) ishape = topi.util.get_const_tuple(data.shape) kshape = topi.util.get_const_tuple(kernel.shape) @@ -43,14 +180,13 @@ def packed_conv2d(data, hstride, wstride = strides res = tvm.compute( oshape, - lambda b, co, i, j, ci: tvm.sum( - pad_data[b, k_o, i*hstride+d_i, j*wstride+d_j, k_i].astype(out_dtype) * - kernel[co, k_o, d_i, d_j, ci, k_i].astype(out_dtype), + lambda b_o, c_o, i, j, b_i, c_i: tvm.sum( + pad_data[b_o, k_o, i*hstride+d_i, j*wstride+d_j, b_i, k_i].astype(out_dtype) * + kernel[c_o, k_o, d_i, d_j, c_i, k_i].astype(out_dtype), axis=[k_o, d_i, d_j, k_i]), name="res", tag="packed_conv2d") return res - @tvm.register_func("nnvm.compiler.build_target", override=True) def _build(funcs, target, target_host): tvm_t = tvm.target.create(target) @@ -155,12 +291,13 @@ def _get_workload(data, pad_data, kernel, output): o_shape = topi.util.get_const_tuple(output.shape) d_shape = topi.util.get_const_tuple(data.shape) k_shape = topi.util.get_const_tuple(kernel.shape) - o_b, o_c, o_h, o_w, o_blk = o_shape - i_b, i_c, i_h, i_w, i_blk = d_shape + o_b, o_c, o_h, o_w, ob_blk, o_blk = o_shape + i_b, i_c, i_h, i_w, ib_blk, i_blk = d_shape k_o, k_i, k_h, k_w, ko_blk, ki_blk = k_shape # For now we need to assume that input channel blocking is the same # as the output channel blocking assert o_blk == i_blk + assert ob_blk == ib_blk # Make sure that dimensions match assert o_b == i_b assert o_blk == ko_blk @@ -178,7 +315,7 @@ def _get_workload(data, pad_data, kernel, output): h_pad, w_pad = 0, 0 h_str = (i_h + h_pad*2 - k_h) // (o_h - 1) w_str = (i_w + w_pad*2 - k_w) // (o_w - 1) - return Workload(i_h, i_w, i_c, o_c, k_h, k_w, h_pad, w_pad, h_str, w_str) + return Workload(i_b, i_h, i_w, i_c, o_c, k_h, k_w, h_pad, w_pad, h_str, w_str) _WL2PLAN = {} @@ -224,7 +361,7 @@ def schedule_packed_conv2d(outs): load_inp = load_wgt = load_out = store_out = env.dma_copy alu = env.alu - gevm = env.gevm + gemm = env.gemm # schedule1 oshape = topi.util.get_const_tuple(output.shape) @@ -256,11 +393,12 @@ def schedule_packed_conv2d(outs): h_factor = (plan.h_factor if plan.h_factor else oshape[2]) w_factor = (plan.w_factor if plan.w_factor else oshape[3]) - x_b, x_oc, x_i, x_j, x_ic = s[output].op.axis - x_oc0, x_oc1 = s[output].split(x_oc, factor=oc_factor) + + x_bo, x_co, x_i, x_j, x_bi, x_ci = s[output].op.axis + x_co0, x_co1 = s[output].split(x_co, factor=oc_factor) x_i0, x_i1 = s[output].split(x_i, factor=h_factor) x_j0, x_j1 = s[output].split(x_j, factor=w_factor) - s[output].reorder(x_b, x_oc0, x_i0, x_j0, x_oc1, x_i1, x_j1, x_ic) + s[output].reorder(x_bo, x_i0, x_co0, x_j0, x_co1, x_i1, x_j1, x_bi, x_ci) store_pt = x_j0 # set all compute scopes @@ -273,100 +411,78 @@ def schedule_packed_conv2d(outs): s[tensor].pragma(s[tensor].op.axis[0], load_out) # virtual threading along output channel axes - if plan.oc_nthread: - _, v_t = s[output].split(x_oc0, factor=plan.oc_nthread) - s[output].reorder(v_t, x_b) + if plan.oc_nthread > 1: + _, v_t = s[output].split(x_co0, factor=plan.oc_nthread) + s[output].reorder(v_t, x_bo) s[output].bind(v_t, tvm.thread_axis("cthread")) # virtual threading along spatial rows - if plan.h_nthread: + if plan.h_nthread > 1: _, v_t = s[output].split(x_i0, factor=plan.h_nthread) - s[output].reorder(v_t, x_b) + s[output].reorder(v_t, x_bo) s[output].bind(v_t, tvm.thread_axis("cthread")) - x_b, x_oc, x_i, x_j, x_ic = s[conv2d_stage].op.axis + x_bo, x_co, x_i, x_j, x_bi, x_ci = s[conv2d_stage].op.axis k_o, d_i, d_j, k_i = s[conv2d_stage].op.reduce_axis - s[conv2d_stage].reorder(k_o, x_j, d_j, d_i, x_oc, x_i, x_ic, k_i) + s[conv2d_stage].reorder(x_bo, k_o, x_j, d_j, d_i, x_co, x_i, x_bi, x_ci, k_i) - if plan.ko_factor: - k_o, _ = s[conv2d_stage].split(k_o, factor=plan.ko_factor) + if plan.ic_factor: + k_o, _ = s[conv2d_stage].split(k_o, factor=plan.ic_factor) s[cdata].compute_at(s[conv2d_stage], k_o) s[ckernel].compute_at(s[conv2d_stage], k_o) # Use VTA instructions s[cdata].pragma(s[cdata].op.axis[0], load_inp) s[ckernel].pragma(s[ckernel].op.axis[0], load_wgt) - s[conv2d_stage].tensorize(x_ic, gevm) - s[output].pragma(x_oc1, store_out) + s[conv2d_stage].tensorize(x_bi, gemm) + s[output].pragma(x_co1, store_out) return s - class Conv2DSchedule(object): """ 2D convolution schedule object. """ def __init__(self, - oc_factor, - ko_factor=1, + b_factor=1, + oc_factor=1, + ic_factor=1, h_factor=1, w_factor=0, oc_nthread=0, - h_nthread=0): + h_nthread=0, + debug_sync=False): + self.b_factor = b_factor self.oc_factor = oc_factor - self.ko_factor = ko_factor + self.ic_factor = ic_factor self.h_factor = h_factor self.w_factor = w_factor self.oc_nthread = oc_nthread self.h_nthread = h_nthread + self.debug_sync = debug_sync + def __str__(self): + return "{}.{}.{}.{}.{}.{}.{}".format( + self.b_factor, self.oc_factor, self.ic_factor, + self.h_factor, self.w_factor, + self.oc_nthread, self.h_nthread) Schedule = Conv2DSchedule -# ResNet18 workloads +# Layer description of the ResNet18 RESNET = { - # Workloads of resnet18 on imagenet - 0: Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2), - 1: Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1), - 2: Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1), - 3: Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2), - 4: Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2), - 5: Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1), - 6: Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2), - 7: Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2), - 8: Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1), - 9: Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2), - 10: Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2), - 11: Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1), + 0: Workload(1, 224, 224, 16, 64, 7, 7, 3, 3, 2, 2), + 1: Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1), + 2: Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1), + 3: Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2), + 4: Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2), + 5: Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1), + 6: Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2), + 7: Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2), + 8: Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1), + 9: Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2), + 10: Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2), + 11: Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1), } -# Serial schedule -RESNET_SERIAL = { - RESNET[0]: Schedule(oc_factor=1, ko_factor=1, h_factor=4, w_factor=56), - RESNET[1]: Schedule(oc_factor=2, ko_factor=1, h_factor=14, w_factor=0), - RESNET[2]: Schedule(oc_factor=4, ko_factor=4, h_factor=8, w_factor=0), - RESNET[3]: Schedule(oc_factor=4, ko_factor=1, h_factor=14, w_factor=0), - RESNET[4]: Schedule(oc_factor=8, ko_factor=1, h_factor=4, w_factor=0), - RESNET[5]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0), - RESNET[6]: Schedule(oc_factor=8, ko_factor=1, h_factor=14, w_factor=0), - RESNET[7]: Schedule(oc_factor=16, ko_factor=1, h_factor=7, w_factor=0), - RESNET[8]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0), - RESNET[9]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0), - RESNET[10]: Schedule(oc_factor=16, ko_factor=1, h_factor=7, w_factor=0), - RESNET[11]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0), -} - -# Latency hiding schedule -RESNET_OPT = { - RESNET[0]: Schedule(oc_factor=1, ko_factor=1, h_factor=4, w_factor=56), - RESNET[1]: Schedule(oc_factor=2, ko_factor=1, h_factor=7, h_nthread=2), - RESNET[2]: Schedule(oc_factor=4, ko_factor=2, h_factor=4, w_factor=0, h_nthread=2), - RESNET[3]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, h_nthread=2), - RESNET[4]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, h_nthread=2), - RESNET[5]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, h_nthread=2), - RESNET[6]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2), - RESNET[7]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2), - RESNET[8]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2), - RESNET[9]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2), - RESNET[10]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2), - RESNET[11]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2), -} - -_WL2PLAN = RESNET_OPT +_WL2PLAN = {} +for idx in RESNET: + scheds = find_schedules(RESNET[idx], vt_only=True, best_only=True)[0] + _WL2PLAN[RESNET[idx]] = scheds diff --git a/vta/tests/hardware/common/test_lib.cc b/vta/tests/hardware/common/test_lib.cc index 1cb84576..6c6d28ec 100644 --- a/vta/tests/hardware/common/test_lib.cc +++ b/vta/tests/hardware/common/test_lib.cc @@ -315,7 +315,7 @@ VTAGenericInsn getGEMMInsn(int uop_offset, int batch, int in_feat, int out_feat, int push_next_dep) { // Converter union VTAInsn converter; - // GEVM instruction initialization + // GEMM instruction initialization VTAGemInsn insn; insn.opcode = VTA_OPCODE_GEMM; insn.pop_prev_dep = pop_prev_dep; @@ -394,7 +394,7 @@ VTAGenericInsn getALUInsn(int opcode, int vector_size, bool use_imm, int imm, bo VTAGenericInsn getFinishInsn(bool pop_prev, bool pop_next) { // Converter union VTAInsn converter; - // GEVM instruction initialization + // GEMM instruction initialization VTAGemInsn insn; insn.opcode = VTA_OPCODE_FINISH; insn.pop_prev_dep = pop_prev; @@ -649,7 +649,7 @@ void printInstruction(int num_insn, VTAGenericInsn *insns) { } } else if (c.mem.opcode == VTA_OPCODE_GEMM) { // Print instruction field information - printf("GEVM\n"); + printf("GEMM\n"); printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n", static_cast(c.mem.pop_prev_dep), static_cast(c.mem.pop_next_dep), diff --git a/vta/tests/python/integration/test_benchmark_gemm.py b/vta/tests/python/integration/test_benchmark_gemm.py index b651fad0..4668acff 100644 --- a/vta/tests/python/integration/test_benchmark_gemm.py +++ b/vta/tests/python/integration/test_benchmark_gemm.py @@ -168,7 +168,7 @@ def test_gemm(): with vta.build_config(): run_test("NORMAL", print_ir, True) - def gevm_unittest(print_ir): + def gemm_unittest(print_ir): mock = env.mock print("----- GEMM Unit Test-------") def run_test(header, print_ir): @@ -244,7 +244,7 @@ def test_gemm(): gemm_normal(False) - gevm_unittest(False) + gemm_unittest(False) alu_unittest(False) def _run(env, remote): diff --git a/vta/tests/python/integration/test_benchmark_topi_conv2d.py b/vta/tests/python/integration/test_benchmark_topi_conv2d.py index 0a5edfdc..9721487a 100644 --- a/vta/tests/python/integration/test_benchmark_topi_conv2d.py +++ b/vta/tests/python/integration/test_benchmark_topi_conv2d.py @@ -23,14 +23,11 @@ def my_clip(x, a_min, a_max): def test_vta_conv2d(): def run_vta_conv2d(env, remote, key, batch_size, wl, profile=True): - data_shape = (batch_size, wl.in_filter // env.BLOCK_IN, - wl.height, wl.width, env.BLOCK_IN) - kernel_shape = (wl.out_filter // env.BLOCK_OUT, - wl.in_filter // env.BLOCK_IN, - wl.hkernel, wl.wkernel, - env.BLOCK_OUT, env.BLOCK_IN) - bias_shape = (wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT) - + data_shape = (batch_size//env.BATCH, wl.in_filter//env.BLOCK_IN, + wl.height, wl.width, env.BATCH, env.BLOCK_IN) + kernel_shape = (wl.out_filter//env.BLOCK_OUT, wl.in_filter//env.BLOCK_IN, + wl.hkernel, wl.wkernel, env.BLOCK_OUT, env.BLOCK_IN) + bias_shape = (1, wl.out_filter//env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT) fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1 fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1 @@ -45,12 +42,13 @@ def test_vta_conv2d(): res = my_clip(res, 0, 127) res = topi.cast(res, "int8") - num_ops = fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter + num_ops = 2 * batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter a_shape = (batch_size, wl.in_filter, wl.height, wl.width) w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel) stride = (wl.hstride, wl.wstride) data_dtype = data.dtype + kernel_dtype = kernel.dtype acc_dtype = env.acc_dtype assert wl.hpad == wl.wpad padding = wl.hpad @@ -58,7 +56,7 @@ def test_vta_conv2d(): @memoize("vta.tests.test_benchmark_topi.conv2d,verify_nhwc") def get_ref_data(): a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype) - w_np = (np.random.uniform(size=w_shape) * 4).astype(data_dtype) + w_np = (np.random.uniform(size=w_shape) * 4).astype(kernel_dtype) a_np = np.abs(a_np) w_np = np.abs(w_np) b_np = topi.testing.conv2d_nchw_python( @@ -82,14 +80,15 @@ def test_vta_conv2d(): bias_orig = np.abs(bias_orig) data_packed = data_orig.reshape( - batch_size, wl.in_filter // env.BLOCK_IN, env.BLOCK_IN, - wl.height, wl.width).transpose((0, 1, 3, 4, 2)) + batch_size//env.BATCH, env.BATCH, + wl.in_filter//env.BLOCK_IN, env.BLOCK_IN, + wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3)) kernel_packed = kernel_orig.reshape( - wl.out_filter // env.BLOCK_OUT, env.BLOCK_OUT, - wl.in_filter // env.BLOCK_IN, env.BLOCK_IN, + wl.out_filter//env.BLOCK_OUT, env.BLOCK_OUT, + wl.in_filter//env.BLOCK_IN, env.BLOCK_IN, wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3)) bias_packed = bias_orig.reshape( - wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT) + 1, wl.out_filter // env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT) res_shape = topi.util.get_const_tuple(res.shape) res_np = np.zeros(res_shape).astype(res.dtype) @@ -100,7 +99,7 @@ def test_vta_conv2d(): time_f = f.time_evaluator("conv2d", ctx, number=5) cost = time_f(data_arr, kernel_arr, bias_arr, res_arr) res_unpack = res_arr.asnumpy().transpose( - (0, 1, 4, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width) + (0, 4, 1, 5, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width) if check_correctness: assert wl.hpad == wl.wpad stride = (wl.hstride, wl.wstride) @@ -127,18 +126,18 @@ def test_vta_conv2d(): # ResNet18 workloads resnet = { # Workloads of resnet18 on imagenet - 0: Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2), - 1: Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1), - 2: Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1), - 3: Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2), - 4: Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2), - 5: Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1), - 6: Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2), - 7: Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2), - 8: Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1), - 9: Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2), - 10: Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2), - 11: Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1), + 0: Workload(1, 224, 224, 16, 64, 7, 7, 3, 3, 2, 2), + 1: Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1), + 2: Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1), + 3: Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2), + 4: Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2), + 5: Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1), + 6: Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2), + 7: Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2), + 8: Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1), + 9: Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2), + 10: Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2), + 11: Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1), } batch_size = 1 @@ -148,6 +147,7 @@ def test_vta_conv2d(): print("key=%s" % key) print(wl) run_vta_conv2d(env, remote, key, batch_size, wl) + vta.testing.run(_run)