Update Graph Support for Batching, Fix Swapping (#37)

* fix graph transform for batch dimension
* fix
* fix

Parent: a96a4a9bcc
Commit: 8c9758b606
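Note: the graph packing transform now takes separate batch (bfactor) and channel (cfactor) packing factors instead of a single channel factor, producing the six-dimensional (N//b, C//c, H, W, b, c) layout the VTA graph pass expects. A minimal numpy sketch of the equivalent reshape/transpose (illustration only, not the VTA API; shapes assumed):

    import numpy as np

    def pack_batch_channel(data, bfactor, cfactor):
        # NCHW -> (N//b, C//c, H, W, b, c), mirroring _pack_batch_channel below
        n, c, h, w = data.shape
        assert n % bfactor == 0 and c % cfactor == 0
        data = data.reshape(n // bfactor, bfactor, c // cfactor, cfactor, h, w)
        return data.transpose(0, 2, 4, 5, 1, 3)

    x = np.zeros((1, 16, 224, 224), dtype="float32")
    print(pack_batch_channel(x, bfactor=1, cfactor=16).shape)
    # (1, 1, 224, 224, 1, 16)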
@@ -3,6 +3,7 @@ import nnvm
 import tvm
+from nnvm.compiler import graph_attr
 import vta
 import vta.testing
 import os
 import numpy as np
 from PIL import Image
@@ -12,7 +13,8 @@ import logging
 import wget
 from tvm.contrib import graph_runtime, rpc, util
 
-factor = 16
+bfactor = 1
+cfactor = 16
 host = "pynq"
 port = 9091
 verbose = False
@@ -38,6 +40,10 @@ if verbose:
 target = tvm.target.create("llvm -device=vta")
 target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
 
+if vta.get_env().TARGET == "sim":
+    target_host = "llvm"
+
+
 synset = eval(open(os.path.join(CATEG_FILE)).read())
 image = Image.open(os.path.join(TEST_FILE)).resize((224, 224))
 
@@ -105,7 +111,7 @@ sym = vta.graph.remove_stochastic(sym)
 sym = vta.graph.clean_cast(sym)
 sym = vta.graph.clean_conv_fuse(sym)
 if target.device_name == "vta":
-    sym = vta.graph.pack(sym, shape_dict, factor)
+    sym = vta.graph.pack(sym, shape_dict, bfactor, cfactor)
 
 graph_attr.set_shape_inputs(sym, shape_dict)
 sym = sym.apply("InferShape")
@@ -127,7 +133,13 @@ with nnvm.compiler.build_config(opt_level=3):
 assert tvm.module.enabled("rpc")
 temp = util.tempdir()
 lib.save(temp.relpath("graphlib.o"))
-remote = rpc.connect(host, port)
+
+if vta.get_env().TARGET == "sim":
+    remote = rpc.LocalSession()
+    print("local session")
+else:
+    remote = rpc.connect(host, port)
+
 remote.upload(temp.relpath("graphlib.o"))
 lib = remote.load_module("graphlib.o")
 ctx = remote.ext_dev(0) if target.device_name == "vta" else remote.cpu(0)
@@ -154,16 +166,17 @@ def run_e2e(graph):
     print("t-cost=%g" % tcost.mean)
 
 
-def run_layer(old_graph):
+def run_layer(old_graph, layer_begin, layer_end):
     """Run a certain layer."""
-    for layer_id in range(1, 2):
+    for layer_id in range(layer_begin, layer_end):
         print("run resnet[%d]..."% (layer_id))
         graph = mark_nop(old_graph, layer_id)
         m = graph_runtime.create(graph, lib, ctx)
         # set inputs
         m.set_input('data', tvm.nd.array(x.astype("float32")))
         m.set_input(**params)
         # execute
-        timer = m.module.time_evaluator("run", ctx, number=10)
+        timer = m.module.time_evaluator("run", ctx, number=1)
         tcost = timer()
         print("resnet[%d]: %g\n"% (layer_id, tcost.mean))
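Note: run_layer now takes an explicit layer range instead of the hardcoded range(1, 2). A hypothetical call, assuming the graph built earlier in the script:

    # time resnet layers 1..10 one at a time
    run_layer(graph, 1, 11)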
@@ -10,51 +10,58 @@ import nnvm
 from nnvm.compiler import graph_attr, graph_util
 
 
-def _pack_channel(data, dshape, factor):
+def _pack_batch_channel(data, dshape, bfactor, cfactor):
     """Pack the data channel dimension.
     """
-    assert dshape[1] % factor == 0
-    data = nnvm.sym.reshape(data,
-                            shape=(dshape[0], dshape[1] // factor,
-                                   factor, dshape[2], dshape[3]))
-    data = nnvm.sym.transpose(
-        data, axes=(0, 1, 3, 4, 2))
+    assert dshape[0] % bfactor == 0
+    assert dshape[1] % cfactor == 0
+    data = nnvm.sym.reshape(data,
+                            shape=(dshape[0] // bfactor, bfactor,
+                                   dshape[1] // cfactor, cfactor,
+                                   dshape[2], dshape[3]))
+    data = nnvm.sym.transpose(
+        data, axes=(0, 2, 4, 5, 1, 3))
     return data
 
 
-def _unpack_channel(data, old_shape):
+def _unpack_batch_channel(data, old_shape):
     """Unpack the data channel dimension.
     """
-    data = nnvm.sym.transpose(data, axes=(0, 1, 4, 2, 3))
+    data = nnvm.sym.transpose(data, axes=(0, 4, 1, 5, 2, 3))
     data = nnvm.sym.reshape(data, shape=old_shape)
     return data
 
 
-def _pack_weight(data, dshape, factor):
+def _pack_weight(data, dshape, cfactor):
     """Pack the weight into packed format.
     """
     assert len(dshape) == 4
-    assert dshape[0] % factor == 0
-    assert dshape[1] % factor == 0
+    assert dshape[0] % cfactor == 0
+    assert dshape[1] % cfactor == 0
     data = nnvm.sym.reshape(data,
-                            shape=(dshape[0] // factor, factor,
-                                   dshape[1] // factor, factor,
+                            shape=(dshape[0] // cfactor, cfactor,
+                                   dshape[1] // cfactor, cfactor,
                                    dshape[2], dshape[3]))
     data = nnvm.sym.transpose(
         data, axes=(0, 2, 4, 5, 1, 3))
     return data
 
 
-def _pack_bias(data, dshape, factor):
+def _pack_bias(data, dshape, bfactor, cfactor):
     """Pack the bias parameter.
     """
     assert len(dshape) == 3
-    assert dshape[0] % factor == 0
+    assert dshape[0] % cfactor == 0
     data = nnvm.sym.reshape(data,
-                            shape=(dshape[0] // factor,
-                                   factor, dshape[1], dshape[2]))
+                            shape=(dshape[0] // cfactor,
+                                   cfactor, dshape[1],
+                                   dshape[2], 1))
     data = nnvm.sym.transpose(
-        data, axes=(0, 2, 3, 1))
+        data, axes=(0, 2, 3, 4, 1))
+    # broadcast batch dimension to bfactor
+    data = nnvm.sym.broadcast_to(
+        data,
+        shape=(dshape[0] // cfactor, dshape[1], dshape[2], bfactor, cfactor))
     return data
@@ -245,8 +252,8 @@ def clean_cast(graph):
     return ret
 
 
-def pack(graph, shape_dict, factor, start_name=None):
-    """Pack the graph into channel packed format.
+def pack(graph, shape_dict, bfactor, cfactor, start_name=None):
+    """Pack the graph into batch & channel packed format.
 
     Parameters
     ----------
@@ -256,8 +263,11 @@ def pack(graph, shape_dict, factor, start_name=None):
     shape_dict : dict of str to shape
         The input shape.
 
-    factor : int
-        The packing factor
+    bfactor : int
+        The packing factor in batch
+
+    cfactor : int
+        The packing factor in channel
 
     start_name : str, optional
         Start packing from a certain known node.
@@ -290,42 +300,44 @@ def pack(graph, shape_dict, factor, start_name=None):
             new_node = nnvm.symbol.Variable(node_name)
             if start_name and node_name == start_name:
                 start_pack = True
-                new_node = _pack_channel(new_node, oshape, factor)
+                new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
         elif op_name == "max_pool2d":
             assert not start_pack
             start_pack = True
             new_node = get_clone(children, op_name, node_name, attrs)
-            new_node = _pack_channel(new_node, oshape, factor)
+            new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
         elif op_name == "global_avg_pool2d":
             if start_pack:
                 start_pack = False
-                children[0] = _unpack_channel(children[0], ishape[0])
+                children[0] = _unpack_batch_channel(children[0], ishape[0])
                 new_node = getattr(nnvm.symbol, op_name)(
                     *children, name=node_name, **attrs)
             else:
                 new_node = get_clone(children, op_name, node_name, attrs)
         elif op_name == "quantized_conv2d":
             if start_pack:
-                attrs["pack_channel"] = str(factor)
+                attrs["pack_batch"] = str(bfactor)
+                attrs["pack_channel"] = str(cfactor)
                 data, weight = children
-                weight = _pack_weight(weight, ishape[1], factor)
+                weight = _pack_weight(weight, ishape[1], cfactor)
                 new_node = nnvm.sym.quantized_conv2d(
                     data, weight, name=node_name, **attrs)
             elif counter == 1:
-                attrs["pack_channel"] = str(factor)
+                attrs["pack_batch"] = str(bfactor)
+                attrs["pack_channel"] = str(cfactor)
                 data, weight = children
-                data = _pack_channel(data, ishape[0], factor)
-                weight = _pack_weight(weight, ishape[1], factor)
+                data = _pack_batch_channel(data, ishape[0], bfactor, cfactor)
+                weight = _pack_weight(weight, ishape[1], cfactor)
                 new_node = nnvm.sym.quantized_conv2d(
                     data, weight, name=node_name, **attrs)
-                new_node = _unpack_channel(new_node, oshape)
+                new_node = _unpack_batch_channel(new_node, oshape)
                 counter = counter + 1
             else:
                 new_node = get_clone(children, op_name, node_name, attrs)
         elif op_name.startswith("broadcast"):
             if start_pack:
                 assert len(ishape[1]) == 3
-                children[1] = _pack_bias(children[1], ishape[1], factor)
+                children[1] = _pack_bias(children[1], ishape[1], bfactor, cfactor)
                 new_node = getattr(nnvm.symbol, op_name)(
                     *children, name=node_name, **attrs)
             else:
@@ -341,7 +353,7 @@ def pack(graph, shape_dict, factor, start_name=None):
     ret = node_map[graph.index.output_entries[0][0]]
     if start_pack:
         oshape = shape[graph.index.output_entries[0][0]]
-        ret = _unpack_channel(ret, oshape)
+        ret = _unpack_batch_channel(ret, oshape)
     graph = nnvm.graph.create(ret)
     graph = graph_attr.set_shape_inputs(graph, shape_dict)
     graph = graph.apply("InferShape")
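Note: in the broadcast branch, the 3-D bias is packed into the 5-D layout and its singleton batch sub-dimension broadcast up to bfactor. A numpy sketch of the same transform (assumed shapes, illustration only):

    import numpy as np

    c, h, w, bfactor, cfactor = 32, 1, 1, 2, 16
    bias = np.random.rand(c, h, w).astype("float32")
    # (C, H, W) -> (C//cf, cf, H, W, 1) -> (C//cf, H, W, 1, cf)
    packed = bias.reshape(c // cfactor, cfactor, h, w, 1).transpose(0, 2, 3, 4, 1)
    # broadcast the singleton batch sub-dimension to bfactor
    packed = np.broadcast_to(packed, (c // cfactor, h, w, bfactor, cfactor))
    print(packed.shape)  # (2, 1, 1, 2, 16)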
@@ -367,9 +367,10 @@ class UopQueue : public BaseQueue {
   }
   assert(num_op <= kMaxNumUop);
   uint32_t uop_begin = 0;
-  if (sram_end_ + num_op > kMaxElems) {
+  if (sram_end_ + num_op > kMaxNumUop) {
     // Need to evict
     cache_ptr_ = 0;
     sram_begin_ = 0;
     sram_end_ = num_op;
   } else {
     uop_begin = sram_end_;
@@ -388,6 +389,7 @@ class UopQueue : public BaseQueue {
   dram_end_ += num_op;
   kernel->sram_begin_ = uop_begin;
   kernel->sram_end_ = sram_end_;
   CHECK(kernel->cached());
+  assert(uop_begin != sram_end_);
   cache_.insert(cache_.begin() + cache_ptr_, kernel);
   cache_.erase(cache_.begin() + evict_begin, cache_.begin() + cache_ptr_);
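Note: this is the "fix swapping" half of the commit. The eviction test in UopQueue now compares against the UOP SRAM capacity (kMaxNumUop) rather than kMaxElems, so a kernel that no longer fits wraps to the start of SRAM instead of overflowing it. A rough Python model of that bookkeeping (constant and names assumed for illustration):

    K_MAX_NUM_UOP = 4096  # assumed UOP SRAM capacity, for illustration

    def alloc_uops(sram_end, num_op):
        """Return (uop_begin, new_sram_end); wrap to 0 when the kernel no longer fits."""
        assert num_op <= K_MAX_NUM_UOP
        if sram_end + num_op > K_MAX_NUM_UOP:
            # need to evict: restart from the beginning of UOP SRAM
            return 0, num_op
        return sram_end, sram_end + num_op

    print(alloc_uops(4000, 200))  # (0, 200): wraps instead of overflowing
    print(alloc_uops(100, 200))   # (100, 300)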
@@ -162,6 +162,7 @@ class DRAM {
   */
  void Free(void* data) {
    std::lock_guard<std::mutex> lock(mutex_);
+   if (pmap_.size() == 0) return;
    auto it = pmap_.find(data);
    CHECK(it != pmap_.end());
    Page* p = it->second.get();