[VTA][Relay] Extending Vision model coverage compilation for VTA (#3740)
* adding support for graphpack over multiply op
* increasing resnet model coverage
* fix indentation
* lint
* moving recursion limit fix into graphpack pass
* moving recursionlimit to relay init
* pooling on NCHWnc format
* adding more models
* deploy_resnet_on_vta.py
* trailing line
* generalizing to vision models
* merge conflicts
* fix, apply quantization to VTA only
* improving comments
* trimming models that have runtime issues for the moment
* lint
* lint
* lint
Parent: dee11b4198
Commit: 028f47ce65
@@ -17,6 +17,7 @@
 # pylint: disable=wildcard-import, redefined-builtin, invalid-name
 """The Relay IR namespace containing the IR definition and compiler."""
 from __future__ import absolute_import
+from sys import setrecursionlimit
 from ..api import register_func
 from . import base
 from . import ty

@@ -59,6 +60,9 @@ from . import qnn
 
 from .scope_builder import ScopeBuilder
 
+# Required to traverse large programs
+setrecursionlimit(10000)
+
 # Span
 Span = base.Span
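These two hunks raise Python's recursion limit once, at Relay import time. A minimal sketch (not part of the diff) of why this matters: CPython's default limit is 1000, and Relay's Python expression visitors recurse once per nested expression, so traversing a deep enough model would otherwise raise RecursionError.

import sys
from tvm import relay

# Build an expression chain far deeper than the default recursion limit.
x = relay.var("x", shape=(1, 16))
body = x
for _ in range(2000):
    body = relay.add(body, relay.const(1.0))
# Any recursive walk over `body` (type inference, an ExprMutator pass)
# would overflow the default limit of 1000; setrecursionlimit(10000)
# gives such passes the headroom they need.
print(sys.getrecursionlimit())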
@@ -161,9 +161,12 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
   CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
     << "max_pool2d does not support input split on width";
 
-  CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
+  CHECK(inputs[0].ndim() == 4U ||
+        inputs[0].ndim() == 5U ||
+        inputs[0].ndim() == 6U)
     << "Pool2D only support 4-D input (e.g., NCHW)"
-    << " or 5-D input (last dimension is a split of channel)";
+    << " or 5-D input (e.g. NCHWc for vector instructions)"
+    << " or 6-D input (e.g. NCHWnc for tensor accelerators)";
 
   if (param->padding.size() == 1) {
     padding.push_back(padding[0]);
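This hunk admits the 6-D NCHWnc layout that VTA's graph packing produces, where batch and channel each carry an inner block dimension. A plain-numpy shape sketch (the reshape/transpose recipe below is illustrative, not quoted from the commit):

import numpy as np

# 4-D NCHW input
nchw = np.zeros((1, 16, 56, 56), "int8")
# 5-D NCHWc: channels split into blocks of 16 for vector units
nchw16c = nchw.reshape(1, 1, 16, 56, 56).transpose(0, 1, 3, 4, 2)
# 6-D NCHWnc: batch and channel both split for a tensor accelerator
nchw1n16c = nchw.reshape(1, 1, 1, 16, 56, 56).transpose(0, 2, 4, 5, 1, 3)
print(nchw16c.shape)    # (1, 1, 56, 56, 16)
print(nchw1n16c.shape)  # (1, 1, 56, 56, 1, 16)

Pooling over such data only touches the two spatial axes; the inner n/c blocks ride along unchanged, which is why only the rank check needed relaxing here.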
@@ -85,8 +85,8 @@ def _pack_weight_conv2d_transpose(data, dshape, cfactor):
     return data
 
 
-def _pack_bias(data, dshape, dtype, bfactor, cfactor):
-    """Pack the bias parameter.
+def _pack_const(data, dshape, dtype, bfactor, cfactor):
+    """Pack a constant parameter.
     """
     dshape = _to_shape(dshape)
     assert len(dshape) == 3
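The rename from `_pack_bias` to `_pack_const` reflects that any 3-D per-channel constant, bias or scale alike, now gets repacked to broadcast against the packed activations. A numpy sketch of the shape transformation as I read it (the real pass expresses this with Relay reshape/transpose/broadcast ops):

import numpy as np

def pack_const_sketch(data, bfactor, cfactor):
    # data: per-channel constant of shape (C, 1, 1), C divisible by cfactor
    c, _, _ = data.shape
    assert c % cfactor == 0
    # Split the channel into (C//cfactor, cfactor) blocks, move the block
    # innermost, then broadcast over the batch block to line up with the
    # 6-D NCHWnc activations.
    data = data.reshape(c // cfactor, cfactor, 1, 1)
    data = data.transpose(0, 2, 3, 1)                  # (C//cf, 1, 1, cf)
    return np.broadcast_to(data[:, :, :, np.newaxis, :],
                           (c // cfactor, 1, 1, bfactor, cfactor))

bias = np.arange(16, dtype="int8").reshape(16, 1, 1)
print(pack_const_sketch(bias, 1, 16).shape)   # (1, 1, 1, 1, 16)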
@@ -124,6 +124,7 @@ class ExprPack(ExprMutator):
         self.conv2d = op.op.get("nn.conv2d")
         self.conv2d_transpose = op.op.get("nn.conv2d_transpose")
         self.add = op.op.get("add")
+        self.multiply = op.op.get("multiply")
         self.bias_add = op.op.get("nn.bias_add")
         self.number_of_conv2d = 0
         super().__init__()
@@ -203,23 +204,35 @@ class ExprPack(ExprMutator):
                     output_padding=call.attrs.output_padding,
                     out_dtype=call.attrs.out_dtype)
                 return conv2d
-            elif call.op == self.add and tuple(input_types[0].shape) == tuple(input_types[1].shape):
+            elif call.op == self.add and \
+                    tuple(input_types[0].shape) == tuple(input_types[1].shape):
                 pass
             elif call.op == self.add and len(input_types[1].shape) == 3:
-                data, bias = args
-                bias = _pack_bias(bias,
-                                  _to_shape(input_types[1].shape),
-                                  input_types[1].dtype,
-                                  self.bfactor,
-                                  self.cfactor)
-                return relay.Call(self.add, [data, bias])
+                data, const = args
+                const = _pack_const(const,
+                                    _to_shape(input_types[1].shape),
+                                    input_types[1].dtype,
+                                    self.bfactor,
+                                    self.cfactor)
+                return relay.Call(self.add, [data, const])
+            elif call.op == self.multiply and \
+                    tuple(input_types[0].shape) == tuple(input_types[1].shape):
+                pass
+            elif call.op == self.multiply and len(input_types[1].shape) == 3:
+                data, const = args
+                const = _pack_const(const,
+                                    _to_shape(input_types[1].shape),
+                                    input_types[1].dtype,
+                                    self.bfactor,
+                                    self.cfactor)
+                return relay.Call(self.multiply, [data, const])
             elif self.start_pack and call.op == self.bias_add:
                 data, bias = args
-                bias = _pack_bias(bias,
-                                  _to_shape(input_types[1].shape),
-                                  input_types[1].dtype,
-                                  self.bfactor,
-                                  self.cfactor)
+                bias = _pack_const(bias,
+                                   _to_shape(input_types[1].shape),
+                                   input_types[1].dtype,
+                                   self.bfactor,
+                                   self.cfactor)
                 return relay.Call(self.add, [data, bias])
             elif self.start_pack and call.op == op.op.get('cast') and \
                     input_types[0].dtype == 'int32':
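The new `multiply` branches mirror the `add` ones: a multiply whose operands match shapes exactly passes through untouched, while one that carries a 3-D per-channel constant gets that constant repacked by `_pack_const`. A hypothetical fragment of the kind that now packs instead of stopping the pass, e.g. a scale left behind by folded batch norm:

import numpy as np
from tvm import relay

# Packing runs after quantization, hence the integer dtypes here.
data = relay.var("data", shape=(1, 16, 56, 56), dtype="int32")
scale = relay.const(np.ones((16, 1, 1), dtype="int32"))
out = relay.multiply(data, scale)   # hits the len(input_types[1].shape) == 3 branch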
@@ -15,12 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 """
-Deploy Pretrained ResNet Model from MxNet on VTA
+Deploy Pretrained Vision Model from MxNet on VTA
 ================================================
 **Author**: `Thierry Moreau <https://homes.cs.washington.edu/~moreau/>`_
 
-This tutorial provides an end-to-end demo, on how to run ResNet-18 inference
-onto the VTA accelerator design to perform ImageNet classification tasks.
+This tutorial provides an end-to-end demo on how to run ImageNet classification
+inference on the VTA accelerator design.
 It showcases Relay as a front end compiler that can perform quantization (VTA
 only supports int8/32 inference) as well as graph packing (in order to enable
 tensorization in the core) to massage the compute graph for the hardware target.
@@ -40,7 +40,7 @@ tensorization in the core) to massage the compute graph for the hardware target.
 
 from __future__ import absolute_import, print_function
 
-import argparse, json, os, requests, time
+import argparse, json, os, requests, sys, time
 from io import BytesIO
 from os.path import join, isfile
 from PIL import Image
@@ -53,6 +53,7 @@ import tvm
 from tvm import rpc, autotvm, relay
 from tvm.contrib import graph_runtime, util, download
+from tvm.contrib.debugger import debug_runtime
 from tvm.relay import transform
 
 import vta
 from vta.testing import simulator
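The newly imported `debug_runtime` is a drop-in replacement for `graph_runtime` when per-operator timing is wanted. A usage sketch, reusing the `graph`, `lib`, and `ctx` the tutorial builds further down (so not self-contained here):

m = debug_runtime.create(graph, lib, ctx)
m.run()  # executes the graph and reports per-operator execution times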
@@ -61,7 +62,6 @@ from vta.top import graph_pack
 # Make sure that TVM was compiled with RPC=1
 assert tvm.module.enabled("rpc")
-
 
 ######################################################################
 # Define the platform and model targets
 # -------------------------------------
@@ -75,13 +75,22 @@ env = vta.get_env()
 device = "vta"
 target = env.target if device == "vta" else env.target_vta_cpu
 
+# Dictionary lookup for when to start/end bit packing
+pack_dict = {
+    "resnet18_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet34_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet18_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet34_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet50_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet101_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+}
+
 # Name of Gluon model to compile
 # The ``start_pack`` and ``stop_pack`` labels indicate where
 # to start and end the graph packing relay pass: in other words
 # where to start and finish offloading to VTA.
 model = "resnet18_v1"
-start_pack="nn.max_pool2d"
-stop_pack="nn.global_avg_pool2d"
+assert model in pack_dict
 
 ######################################################################
 # Obtain an execution remote
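Every entry currently shares the same boundaries; the dictionary exists so future model families can pack over different regions without touching the pass invocation. Usage is a plain lookup, as the later `graph_pack` hunk shows:

model = "resnet50_v2"                      # any key of pack_dict
assert model in pack_dict
start_name, stop_name = pack_dict[model]   # "nn.max_pool2d", "nn.global_avg_pool2d"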
@@ -125,7 +134,7 @@ ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
 ######################################################################
 # Build the inference graph runtime
 # ---------------------------------
-# Grab ResNet-18 model from Gluon model zoo and compile with Relay.
+# Grab vision model from Gluon model zoo and compile with Relay.
 # The compilation steps are:
 # 1) Front end translation from MxNet into Relay module.
 # 2) Apply 8-bit quantization: here we skip the first conv layer,
@@ -140,7 +149,7 @@ ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
 # Load pre-configured AutoTVM schedules
 with autotvm.tophub.context(target):
 
-    # Populate the shape and data type dictionary for ResNet input
+    # Populate the shape and data type dictionary for ImageNet classifier input
     dtype_dict = {"data": 'float32'}
     shape_dict = {"data": (env.BATCH, 3, 224, 224)}
 
|
@ -157,21 +166,22 @@ with autotvm.tophub.context(target):
|
|||
shape_dict.update({k: v.shape for k, v in params.items()})
|
||||
dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
|
||||
|
||||
# Perform quantization in Relay
|
||||
with relay.quantize.qconfig(global_scale=8.0,
|
||||
skip_conv_layers=[0]):
|
||||
relay_prog = relay.quantize.quantize(mod["main"], params=params)
|
||||
|
||||
# Perform graph packing and constant folding for VTA target
|
||||
if target.device_name == "vta":
|
||||
# Perform quantization in Relay
|
||||
with relay.quantize.qconfig(global_scale=8.0,
|
||||
skip_conv_layers=[0]):
|
||||
relay_prog = relay.quantize.quantize(mod["main"], params=params)
|
||||
# Perform graph packing and constant folding for VTA target
|
||||
assert env.BLOCK_IN == env.BLOCK_OUT
|
||||
relay_prog = graph_pack(
|
||||
relay_prog,
|
||||
env.BATCH,
|
||||
env.BLOCK_OUT,
|
||||
env.WGT_WIDTH,
|
||||
start_name=start_pack,
|
||||
stop_name=stop_pack)
|
||||
start_name=pack_dict[model][0],
|
||||
stop_name=pack_dict[model][1])
|
||||
else:
|
||||
relay_prog = mod["main"]
|
||||
|
||||
# Compile Relay program with AlterOpLayout disabled
|
||||
with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
|
||||
|
@@ -199,8 +209,8 @@ with autotvm.tophub.context(target):
     m = graph_runtime.create(graph, lib, ctx)
 
 ######################################################################
-# Perform ResNet-18 inference
-# ---------------------------
+# Perform image classification inference
+# --------------------------------------
 # We run classification on an image sample from ImageNet
 # We just need to download the categories files, `synset.txt`
 # and an input test image.
@@ -256,7 +266,6 @@ else:
 tvm_output = m.get_output(0, tvm.nd.empty((env.BATCH, 1000), "float32", remote.cpu(0)))
 for b in range(env.BATCH):
     top_categories = np.argsort(tvm_output.asnumpy()[b])
-
     # Report top-5 classification results
     print("\n{} prediction for sample {}".format(model, b))
     print("\t#1:", synset[top_categories[-1]])
@@ -264,7 +273,6 @@ for b in range(env.BATCH):
     print("\t#3:", synset[top_categories[-3]])
     print("\t#4:", synset[top_categories[-4]])
     print("\t#5:", synset[top_categories[-5]])
-
     # This just checks that one of the 5 top categories
     # is one variety of cat; this is by no means an accurate
     # assessment of how quantization affects classification