From 8c9758b6064bc2ebb2ea82a86b31701f03e1067f Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 21 May 2018 12:23:31 -0700 Subject: [PATCH] Update Graph Support for Batching, Fix Swapping (#37) * fix graph transform for batch dimension * fix * fix --- .../resnet18/pynq/imagenet_predict.py | 25 ++++- vta/python/vta/graph.py | 104 ++++++++++-------- vta/src/runtime.cc | 4 +- vta/src/sim/sim_driver.cc | 1 + 4 files changed, 81 insertions(+), 53 deletions(-) diff --git a/vta/examples/resnet18/pynq/imagenet_predict.py b/vta/examples/resnet18/pynq/imagenet_predict.py index e4f82b17..554cceab 100644 --- a/vta/examples/resnet18/pynq/imagenet_predict.py +++ b/vta/examples/resnet18/pynq/imagenet_predict.py @@ -3,6 +3,7 @@ import nnvm import tvm from nnvm.compiler import graph_attr import vta +import vta.testing import os import numpy as np from PIL import Image @@ -12,7 +13,8 @@ import logging import wget from tvm.contrib import graph_runtime, rpc, util -factor = 16 +bfactor = 1 +cfactor = 16 host = "pynq" port = 9091 verbose = False @@ -38,6 +40,10 @@ if verbose: target = tvm.target.create("llvm -device=vta") target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon" +if vta.get_env().TARGET == "sim": + target_host = "llvm" + + synset = eval(open(os.path.join(CATEG_FILE)).read()) image = Image.open(os.path.join(TEST_FILE)).resize((224, 224)) @@ -105,7 +111,7 @@ sym = vta.graph.remove_stochastic(sym) sym = vta.graph.clean_cast(sym) sym = vta.graph.clean_conv_fuse(sym) if target.device_name == "vta": - sym = vta.graph.pack(sym, shape_dict, factor) + sym = vta.graph.pack(sym, shape_dict, bfactor, cfactor) graph_attr.set_shape_inputs(sym, shape_dict) sym = sym.apply("InferShape") @@ -127,7 +133,13 @@ with nnvm.compiler.build_config(opt_level=3): assert tvm.module.enabled("rpc") temp = util.tempdir() lib.save(temp.relpath("graphlib.o")) -remote = rpc.connect(host, port) + +if vta.get_env().TARGET == "sim": + remote = rpc.LocalSession() + print("local session") +else: + remote = rpc.connect(host, port) + remote.upload(temp.relpath("graphlib.o")) lib = remote.load_module("graphlib.o") ctx = remote.ext_dev(0) if target.device_name == "vta" else remote.cpu(0) @@ -154,16 +166,17 @@ def run_e2e(graph): print("t-cost=%g" % tcost.mean) -def run_layer(old_graph): +def run_layer(old_graph, layer_begin, layer_end): """Run a certain layer.""" - for layer_id in range(1, 2): + for layer_id in range(layer_begin, layer_end): + print("run resnet[%d]..."% (layer_id)) graph = mark_nop(old_graph, layer_id) m = graph_runtime.create(graph, lib, ctx) # set inputs m.set_input('data', tvm.nd.array(x.astype("float32"))) m.set_input(**params) # execute - timer = m.module.time_evaluator("run", ctx, number=10) + timer = m.module.time_evaluator("run", ctx, number=1) tcost = timer() print("resnet[%d]: %g\n"% (layer_id, tcost.mean)) diff --git a/vta/python/vta/graph.py b/vta/python/vta/graph.py index b8237980..41a38c2b 100644 --- a/vta/python/vta/graph.py +++ b/vta/python/vta/graph.py @@ -10,51 +10,58 @@ import nnvm from nnvm.compiler import graph_attr, graph_util -def _pack_channel(data, dshape, factor): +def _pack_batch_channel(data, dshape, bfactor, cfactor): """Pack the data channel dimension. 
""" - assert dshape[1] % factor == 0 + assert dshape[0] % bfactor == 0 + assert dshape[1] % cfactor == 0 data = nnvm.sym.reshape(data, - shape=(dshape[0], dshape[1] // factor, - factor, dshape[2], dshape[3])) - data = nnvm.sym.transpose( - data, axes=(0, 1, 3, 4, 2)) - return data - - -def _unpack_channel(data, old_shape): - """Unpack the data channel dimension. - """ - data = nnvm.sym.transpose(data, axes=(0, 1, 4, 2, 3)) - data = nnvm.sym.reshape(data, shape=old_shape) - return data - - -def _pack_weight(data, dshape, factor): - """Pack the weight into packed format. - """ - assert len(dshape) == 4 - assert dshape[0] % factor == 0 - assert dshape[1] % factor == 0 - data = nnvm.sym.reshape(data, - shape=(dshape[0] // factor, factor, - dshape[1] // factor, factor, + shape=(dshape[0] // bfactor, bfactor, + dshape[1] // cfactor, cfactor, dshape[2], dshape[3])) data = nnvm.sym.transpose( data, axes=(0, 2, 4, 5, 1, 3)) return data -def _pack_bias(data, dshape, factor): +def _unpack_batch_channel(data, old_shape): + """Unpack the data channel dimension. + """ + data = nnvm.sym.transpose(data, axes=(0, 4, 1, 5, 2, 3)) + data = nnvm.sym.reshape(data, shape=old_shape) + return data + + +def _pack_weight(data, dshape, cfactor): + """Pack the weight into packed format. + """ + assert len(dshape) == 4 + assert dshape[0] % cfactor == 0 + assert dshape[1] % cfactor == 0 + data = nnvm.sym.reshape(data, + shape=(dshape[0] // cfactor, cfactor, + dshape[1] // cfactor, cfactor, + dshape[2], dshape[3])) + data = nnvm.sym.transpose( + data, axes=(0, 2, 4, 5, 1, 3)) + return data + + +def _pack_bias(data, dshape, bfactor, cfactor): """Pack the bias parameter. """ assert len(dshape) == 3 - assert dshape[0] % factor == 0 + assert dshape[0] % cfactor == 0 data = nnvm.sym.reshape(data, - shape=(dshape[0] // factor, - factor, dshape[1], dshape[2])) + shape=(dshape[0] // cfactor, + cfactor, dshape[1], + dshape[2], 1)) data = nnvm.sym.transpose( - data, axes=(0, 2, 3, 1)) + data, axes=(0, 2, 3, 4, 1)) + # broadcast batch dimension to bfactor + data = nnvm.sym.broadcast_to( + data, + shape=(dshape[0] // cfactor, dshape[1], dshape[2], bfactor, cfactor)) return data @@ -245,8 +252,8 @@ def clean_cast(graph): return ret -def pack(graph, shape_dict, factor, start_name=None): - """Pack the graph into channel packed format. +def pack(graph, shape_dict, bfactor, cfactor, start_name=None): + """Pack the graph into batch&channel packed format. Parameters ---------- @@ -256,8 +263,11 @@ def pack(graph, shape_dict, factor, start_name=None): shape_dict : dict of str to shapex The input shape. - factor : int - The packing factor + bfactor : int + The packing factor in batch + + cfactor : int + The packing factor in channel start_name: str, optional Start name start packing from certain known node. 
@@ -290,42 +300,44 @@ def pack(graph, shape_dict, factor, start_name=None): new_node = nnvm.symbol.Variable(node_name) if start_name and node_name == start_name: start_pack = True - new_node = _pack_channel(new_node, oshape, factor) + new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor) elif op_name == "max_pool2d": assert not start_pack start_pack = True new_node = get_clone(children, op_name, node_name, attrs) - new_node = _pack_channel(new_node, oshape, factor) + new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor) elif op_name == "global_avg_pool2d": if start_pack: start_pack = False - children[0] = _unpack_channel(children[0], ishape[0]) + children[0] = _unpack_batch_channel(children[0], ishape[0]) new_node = getattr(nnvm.symbol, op_name)( *children, name=node_name, **attrs) else: new_node = get_clone(children, op_name, node_name, attrs) elif op_name == "quantized_conv2d": if start_pack: - attrs["pack_channel"] = str(factor) + attrs["pack_batch"] = str(bfactor) + attrs["pack_channel"] = str(cfactor) data, weight = children - weight = _pack_weight(weight, ishape[1], factor) + weight = _pack_weight(weight, ishape[1], cfactor) new_node = nnvm.sym.quantized_conv2d( data, weight, name=node_name, **attrs) elif counter == 1: - attrs["pack_channel"] = str(factor) + attrs["pack_batch"] = str(bfactor) + attrs["pack_channel"] = str(cfactor) data, weight = children - data = _pack_channel(data, ishape[0], factor) - weight = _pack_weight(weight, ishape[1], factor) + data = _pack_batch_channel(data, ishape[0], bfactor, cfactor) + weight = _pack_weight(weight, ishape[1], cfactor) new_node = nnvm.sym.quantized_conv2d( data, weight, name=node_name, **attrs) - new_node = _unpack_channel(new_node, oshape) + new_node = _unpack_batch_channel(new_node, oshape) counter = counter + 1 else: new_node = get_clone(children, op_name, node_name, attrs) elif op_name.startswith("broadcast"): if start_pack: assert len(ishape[1]) == 3 - children[1] = _pack_bias(children[1], ishape[1], factor) + children[1] = _pack_bias(children[1], ishape[1], bfactor, cfactor) new_node = getattr(nnvm.symbol, op_name)( *children, name=node_name, **attrs) else: @@ -341,7 +353,7 @@ def pack(graph, shape_dict, factor, start_name=None): ret = node_map[graph.index.output_entries[0][0]] if start_pack: oshape = shape[graph.index.output_entries[0][0]] - ret = _unpack_channel(ret, oshape) + ret = _unpack_batch_channel(ret, oshape) graph = nnvm.graph.create(ret) graph = graph_attr.set_shape_inputs(graph, shape_dict) graph = graph.apply("InferShape") diff --git a/vta/src/runtime.cc b/vta/src/runtime.cc index c0de87fa..9e84acfc 100644 --- a/vta/src/runtime.cc +++ b/vta/src/runtime.cc @@ -367,9 +367,10 @@ class UopQueue : public BaseQueue { } assert(num_op <= kMaxNumUop); uint32_t uop_begin = 0; - if (sram_end_ + num_op > kMaxElems) { + if (sram_end_ + num_op > kMaxNumUop) { // Need to evict cache_ptr_ = 0; + sram_begin_ = 0; sram_end_ = num_op; } else { uop_begin = sram_end_; @@ -388,6 +389,7 @@ class UopQueue : public BaseQueue { dram_end_ += num_op; kernel->sram_begin_ = uop_begin; kernel->sram_end_ = sram_end_; + CHECK(kernel->cached()); assert(uop_begin != sram_end_); cache_.insert(cache_.begin() + cache_ptr_, kernel); cache_.erase(cache_.begin() + evict_begin, cache_.begin() + cache_ptr_); diff --git a/vta/src/sim/sim_driver.cc b/vta/src/sim/sim_driver.cc index 57bc21c9..9a953e7a 100644 --- a/vta/src/sim/sim_driver.cc +++ b/vta/src/sim/sim_driver.cc @@ -162,6 +162,7 @@ class DRAM { */ void Free(void* data) { 
     std::lock_guard<std::mutex> lock(mutex_);
+    if (pmap_.size() == 0) return;
     auto it = pmap_.find(data);
     CHECK(it != pmap_.end());
     Page* p = it->second.get();
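
The core layout change above, in vta/python/vta/graph.py, splits both the batch and channel dimensions: an NCHW tensor of shape (N, C, H, W) is packed to (N//bfactor, C//cfactor, H, W, bfactor, cfactor), and bias terms gain a broadcast batch axis so they still line up under the packed layout. Below is a minimal standalone NumPy sketch of the same transforms, for illustration only; it is not part of the patch, and it assumes nnvm.sym.reshape/transpose/broadcast_to behave like their NumPy counterparts for these static shapes, which is how pack() uses them:

    import numpy as np

    def pack_batch_channel(x, bfactor, cfactor):
        # Mirrors _pack_batch_channel: NCHW -> (N/b, C/c, H, W, b, c).
        n, c, h, w = x.shape
        assert n % bfactor == 0 and c % cfactor == 0
        x = x.reshape(n // bfactor, bfactor, c // cfactor, cfactor, h, w)
        return x.transpose(0, 2, 4, 5, 1, 3)

    def unpack_batch_channel(x, old_shape):
        # Mirrors _unpack_batch_channel: exact inverse of the pack above.
        x = x.transpose(0, 4, 1, 5, 2, 3)  # -> (N/b, b, C/c, c, H, W)
        return x.reshape(old_shape)

    def pack_bias(bias, bfactor, cfactor):
        # Mirrors _pack_bias: (C, H, W) -> (C/c, H, W, b, c).
        c, h, w = bias.shape
        assert c % cfactor == 0
        bias = bias.reshape(c // cfactor, cfactor, h, w, 1)
        bias = bias.transpose(0, 2, 3, 4, 1)  # -> (C/c, H, W, 1, c)
        return np.broadcast_to(bias, (c // cfactor, h, w, bfactor, cfactor))

    # Round trip: unpacking a packed tensor recovers the original exactly.
    x = np.random.randn(2, 32, 14, 14).astype("float32")
    packed = pack_batch_channel(x, bfactor=1, cfactor=16)
    assert packed.shape == (2, 2, 14, 14, 1, 16)
    np.testing.assert_array_equal(unpack_batch_channel(packed, x.shape), x)

    # Packed bias broadcasts against packed data on the trailing axes:
    # (C/c, H, W, b, c) vs (N/b, C/c, H, W, b, c).
    bias = np.random.randn(32, 1, 1).astype("float32")
    assert pack_bias(bias, bfactor=1, cfactor=16).shape == (2, 1, 1, 1, 16)

This trailing-axis alignment is why _pack_bias now takes bfactor in addition to cfactor: without broadcasting the bias out to the bfactor tile, the broadcast_* ops rewritten by pack() would no longer match the packed data layout.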