Commit cfafd212c0 by Lianmin Zheng, 2018-08-22 20:21:15 -07:00, committed by Tianqi Chen
Parent: 56c50d2d07
10 changed files with 194 additions and 139 deletions

View File

@@ -1,45 +1,26 @@
"""Benchmark script for performance on ARM CPU.
"""Benchmark script for ARM CPU.
See README.md for the usage and results of this script.
"""
import argparse
import time
import numpy as np
import nnvm.testing
import nnvm.compiler
import tvm
from tvm import autotvm
from tvm.contrib.util import tempdir
import tvm.contrib.graph_runtime as runtime
import nnvm.compiler
import nnvm.testing
def get_network(name, batch_size):
"""Get the symbol definition and random weight of a network"""
input_shape = (batch_size, 3, 224, 224)
output_shape = (batch_size, 1000)
if name == 'resnet-18':
net, params = nnvm.testing.resnet.get_workload(num_layers=18,
batch_size=batch_size, image_shape=(3, 224, 224))
elif name == 'mobilenet':
net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size)
elif name == 'squeezenet v1.1':
net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size,
version='1.1')
elif name == 'vgg-16':
net, params = nnvm.testing.vgg.get_workload(batch_size=batch_size, num_layers=16)
else:
raise RuntimeError("Unsupported network: " + name)
return net, params, input_shape, output_shape
from util import get_network, print_progress
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--network", type=str, choices=['resnet-18', 'mobilenet', 'squeezenet v1.1', 'vgg-16'])
parser.add_argument("--device", type=str, required=True, choices=['rk3399', 'mate10', 'mate10pro', 'p20', 'p20pro',
'pixel2', 'rasp3b', 'pynq'])
parser.add_argument("--network", type=str, choices=
['resnet-18', 'resnet-34', 'vgg-16', 'mobilenet', 'squeezenet_v1.1'])
parser.add_argument("--device", type=str, required=True, choices=
['rk3399', 'mate10', 'mate10pro', 'p20', 'p20pro',
'pixel2', 'rasp3b', 'pynq'])
parser.add_argument("--host", type=str, default='localhost')
parser.add_argument("--port", type=int, default=9190)
parser.add_argument("--rpc-key", type=str, required=True)
@@ -49,7 +30,7 @@ if __name__ == "__main__":
dtype = 'float32'
if args.network is None:
networks = ['squeezenet v1.1', 'mobilenet', 'resnet-18', 'vgg-16']
networks = ['squeezenet_v1.1', 'mobilenet', 'resnet-18', 'vgg-16']
else:
networks = [args.network]
@@ -63,8 +44,10 @@ if __name__ == "__main__":
print("%-20s %-20s" % ("Network Name", "Mean Inference Time (std dev)"))
print("--------------------------------------------------")
for network in networks:
print_progress(network)
net, params, input_shape, output_shape = get_network(network, batch_size=1)
print_progress("%-20s building..." % network)
with nnvm.compiler.build_config(opt_level=2, add_pass=['AlterOpLayout']):
graph, lib, params = nnvm.compiler.build(
net, target=target, shape={'data': input_shape}, params=params, dtype=dtype)
@@ -79,6 +62,7 @@ if __name__ == "__main__":
lib.export_library(tmp.relpath(filename))
# upload library and params
print_progress("%-20s uploading..." % network)
ctx = remote.context(str(target), 0)
remote.upload(tmp.relpath(filename))
rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
@@ -90,7 +74,7 @@ if __name__ == "__main__":
module.set_input(**rparams)
# evaluate
print_progress("%-20s evaluating..." % network)
ftimer = module.module.time_evaluator("run", ctx, number=args.number, repeat=3)
prof_res = np.array(ftimer().results) * 1000  # convert to milliseconds
print("%-20s %-19s (%s)" % (network, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)))
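For readers skimming the diff, here is a minimal local (non-RPC) sketch of the build-and-measure path this script follows; the network, the local 'llvm' target, and the iteration counts are illustrative choices, not values from the commit.

import numpy as np
import tvm
import nnvm.compiler
import nnvm.testing
import tvm.contrib.graph_runtime as runtime

# build a workload locally (mobilenet chosen as an example)
net, params = nnvm.testing.mobilenet.get_workload(batch_size=1)
input_shape = (1, 3, 224, 224)
target = 'llvm'  # local stand-in for the ARM targets benchmarked above

with nnvm.compiler.build_config(opt_level=2, add_pass=['AlterOpLayout']):
    graph, lib, params = nnvm.compiler.build(
        net, target=target, shape={'data': input_shape}, params=params, dtype='float32')

# run and time it with the graph runtime, as the script does on the remote device
ctx = tvm.context(target, 0)
module = runtime.create(graph, lib, ctx)
module.set_input('data', tvm.nd.array(
    np.random.uniform(size=input_shape).astype('float32')))
module.set_input(**params)
ftimer = module.module.time_evaluator("run", ctx, number=10, repeat=3)
prof_res = np.array(ftimer().results) * 1000  # convert to milliseconds
print("%.2f ms (std %.2f ms)" % (np.mean(prof_res), np.std(prof_res)))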

apps/benchmark/util.py (new file, 72 lines)
View File

@@ -0,0 +1,72 @@
"""Utility for benchmark"""
import sys
import nnvm
def get_network(name, batch_size):
"""Get the symbol definition and random weight of a network
Parameters
----------
name: str
The name of the network; can be 'resnet-18', 'resnet-50', 'vgg-16', 'inception_v3', 'mobilenet', ...
batch_size: int
The batch size
Returns
-------
net: nnvm.symbol
The NNVM symbol of network definition
params: dict
The random parameters for benchmark
input_shape: tuple
The shape of input tensor
output_shape: tuple
The shape of output tensor
"""
input_shape = (batch_size, 3, 224, 224)
output_shape = (batch_size, 1000)
if "resnet" in name:
n_layer = int(name.split('-')[1])
net, params = nnvm.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size)
elif "vgg" in name:
n_layer = int(name.split('-')[1])
net, params = nnvm.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size)
elif name == 'mobilenet':
net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size)
elif "squeezenet" in name:
version = name.split("_v")[1]
net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size, version=version)
elif name == 'inception_v3':
input_shape = (batch_size, 3, 299, 299)
net, params = nnvm.testing.inception_v3.get_workload(batch_size=batch_size)
elif name == 'custom':
# an example for custom network
from nnvm.testing import utils
net = nnvm.sym.Variable('data')
net = nnvm.sym.conv2d(net, channels=4, kernel_size=(3,3), padding=(1,1))
net = nnvm.sym.flatten(net)
net = nnvm.sym.dense(net, units=1000)
net, params = utils.create_workload(net, batch_size, (3, 224, 224))
elif name == 'mxnet':
# an example for mxnet model
from mxnet.gluon.model_zoo.vision import get_model
block = get_model('resnet18_v1', pretrained=True)
net, params = nnvm.frontend.from_mxnet(block)
net = nnvm.sym.softmax(net)
else:
raise ValueError("Unsupported network: " + name)
return net, params, input_shape, output_shape
def print_progress(msg):
"""print progress message
Parameters
----------
msg: str
The message to print
"""
sys.stdout.write(msg + "\r")
sys.stdout.flush()
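A small usage sketch for these helpers, assuming it is run from apps/benchmark so that util.py is importable; the network name and batch size are just examples.

from util import get_network, print_progress

print_progress("%-20s building..." % "resnet-18")
net, params, input_shape, output_shape = get_network("resnet-18", batch_size=1)
print(input_shape, output_shape)  # (1, 3, 224, 224) (1, 1000)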

View File

@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs
from .config import ctx_list
from .utils import create_workload
from . import mobilenet
from . import mobilenet_v2
from . import mlp
from . import resnet
from . import vgg

View File

@@ -0,0 +1,51 @@
"""
MobileNetV2, loaded from the Gluon model zoo
Reference:
Inverted Residuals and Linear Bottlenecks:
Mobile Networks for Classification, Detection and Segmentation
https://arxiv.org/abs/1801.04381
"""
from .utils import create_workload
from ..frontend.mxnet import _from_mxnet_impl
def get_workload(batch_size, num_classes=1000, multiplier=1.0, dtype="float32"):
"""Get benchmark workload for mobilenet
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of classes
multiplier : float, optional
The channel width multiplier of the network
dtype : str, optional
The data type
Returns
-------
net : nnvm.Symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
import mxnet as mx
from mxnet.gluon.model_zoo.vision.mobilenet import MobileNetV2
image_shape = (1, 3, 224, 224)
block = MobileNetV2(multiplier=multiplier, classes=num_classes)
data = mx.sym.Variable('data')
sym = block(data)
sym = mx.sym.SoftmaxOutput(sym)
net = _from_mxnet_impl(sym, {})
return create_workload(net, batch_size, image_shape[1:], dtype)
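A sketch of how this new workload could be consumed; it assumes mxnet with the Gluon model zoo is installed, and the 'llvm' target is an illustrative choice.

import nnvm.compiler
from nnvm.testing import mobilenet_v2

# build the MobileNetV2 workload for a generic CPU target
net, params = mobilenet_v2.get_workload(batch_size=1)
graph, lib, params = nnvm.compiler.build(
    net, target='llvm', shape={'data': (1, 3, 224, 224)}, params=params)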

View File

@@ -8,16 +8,22 @@ TVM will download these parameters for you when you create the target for the first time.
import logging
import os
import json
import sys
from .task import ApplyHistoryBest
from .. import target as _target
from ..contrib.util import tempdir
from ..contrib.download import download
# root path to store TopHub files
AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub")
# the version of each package
PACKAGE_VERSION = {
'vta': "v0.01",
'arm_cpu': "v0.01",
'cuda': "v0.01",
}
logger = logging.getLogger('autotvm')
def _alias(name):
@@ -30,7 +36,8 @@ def _alias(name):
def context(target, extra_files=None):
"""Return the dispatch context with pre-tuned parameters.
The corresponding downloaded *.log files under tophub root path will be loaded.
This function will load the corresponding *.log files in AUTOTVM_TOPHUB_ROOT_PATH.
If they cannot be found, it will download them from the TopHub github repo.
Users can also add their own files via the argument `extra_files`.
Parameters
@@ -40,21 +47,24 @@ def context(target, extra_files=None):
extra_files: list of str, optional
Extra log files to load
"""
rootpath = AUTOTVM_TOPHUB_ROOT_PATH
best_context = ApplyHistoryBest([])
if isinstance(target, str):
target = _target.create(target)
big_target = str(target).split()[0]
if os.path.isfile(os.path.join(rootpath, big_target + ".log")):
best_context.load(os.path.join(rootpath, big_target + ".log"))
possible_names = [str(target).split()[0]]
for opt in target.options:
if opt.startswith("-device"):
model = _alias(opt[8:])
if os.path.isfile(os.path.join(rootpath, model) + ".log"):
best_context.load(os.path.join(rootpath, model) + ".log")
device = _alias(opt[8:])
possible_names.append(device)
all_packages = list(PACKAGE_VERSION.keys())
for name in possible_names:
if name in all_packages:
check_backend(name)
filename = "%s_%s.log" % (name, PACKAGE_VERSION[name])
best_context.load(os.path.join(AUTOTVM_TOPHUB_ROOT_PATH, filename))
if extra_files:
for filename in extra_files:
@@ -63,12 +73,39 @@ def context(target, extra_files=None):
return best_context
def download_package(backend):
"""Download pre-tuned parameters of operators for a backend
def check_backend(backend):
"""Check whether have pre-tuned parameters of the certain target.
If not, will download it.
Parameters
----------
backend: str
The name of backend.
"""
backend = _alias(backend)
assert backend in PACKAGE_VERSION, 'Cannot find backend "%s" in TopHub' % backend
version = PACKAGE_VERSION[backend]
package_name = "%s_%s.log" % (backend, version)
if os.path.isfile(os.path.join(AUTOTVM_TOPHUB_ROOT_PATH, package_name)):
return
if sys.version_info >= (3,):
import urllib.request as urllib2
else:
import urllib2
try:
download_package(package_name)
except urllib2.URLError as e:
logging.warning("Failed to download tophub package for %s: %s", backend, e)
def download_package(package_name):
"""Download pre-tuned parameters of operators for a backend
Parameters
----------
package_name: str
The name of package
"""
rootpath = AUTOTVM_TOPHUB_ROOT_PATH
@@ -81,54 +118,6 @@ def download_package(backend):
if not os.path.isdir(path):
os.mkdir(path)
backend = _alias(backend)
logger.info("Download pre-tuned parameters for %s", backend)
download("https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/%s.log" % backend,
os.path.join(rootpath, backend + ".log"), True, verbose=0)
def check_package(backend):
"""Check whether have pre-tuned parameters of the certain target.
If not, will download it.
Parameters
----------
backend: str
The name of package
"""
backend = _alias(backend)
if os.path.isfile(os.path.join(AUTOTVM_TOPHUB_ROOT_PATH, backend + ".log")):
return
if sys.version_info >= (3,):
import urllib.request as urllib2
else:
import urllib2
try:
download_package(backend)
except urllib2.URLError:
logging.warning("Failed to download tophub package for %s", backend)
def list_packages():
"""List all available pre-tuned op parameters for targets
Returns
-------
ret: List
All available packets
"""
path = tempdir()
filename = path.relpath("info.json")
logger.info("Download meta info for pre-tuned parameters")
download("https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/info.json",
filename, True, verbose=0)
with open(filename, "r") as fin:
text = "".join(fin.readlines())
info = json.loads(text)
keys = list(info.keys())
keys.sort()
return [(k, info[k]) for k in keys]
logger.info("Download pre-tuned parameters package %s", package_name)
download("https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/%s"
% package_name, os.path.join(rootpath, package_name), True, verbose=0)
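For context, this is how the new versioned packages get picked up in practice; the target string is an example, and the resolved file name follows PACKAGE_VERSION above (the first use downloads the file, so network access is needed).

from tvm import autotvm

# resolves ~/.tvm/tophub/arm_cpu_v0.01.log, downloading it on first use
with autotvm.tophub.context('llvm -device=arm_cpu'):
    pass  # e.g. nnvm.compiler.build(...) here would pick up the tuned schedules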

View File

@@ -1,37 +0,0 @@
# pylint: disable=invalid-name
"""Download pre-tuned parameters of ops"""
import argparse
import logging
from ..autotvm.tophub import list_packages, download_package
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--download", type=str, nargs='+',
help="The targets to download. Use 'all' to download for all targets")
parser.add_argument("-l", "--list", action='store_true', help="List available packages")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
if args.list:
info = list_packages()
print("\n%-20s %-20s" % ("Target", "Size"))
print("-" * 41)
for target, info in info:
print("%-20s %-20s" % (target, "%.2f MB" % (info['size']/1000000)))
elif args.download:
info = list_packages()
all_targets = [x[0] for x in info]
if 'all' in args.download:
targets = all_targets
else:
targets = args.download
for t in targets:
if t not in all_targets:
print("Warning : cannot find tuned parameters of " + t + ". (ignored)")
download_package(t)
else:
parser.print_help()

View File

@@ -313,7 +313,7 @@ class Server(object):
self.use_popen = use_popen
if silent:
logger.setLevel(logging.WARN)
logger.setLevel(logging.ERROR)
if use_popen:
cmd = [sys.executable,

View File

@@ -425,8 +425,6 @@ def arm_cpu(model='unknown', options=None):
options : str or list of str
Additional options
"""
from . import autotvm
trans_table = {
"pixel2": ["-model=snapdragon835", "-target=arm64-linux-android -mattr=+neon"],
"mate10": ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"],
@@ -439,9 +437,6 @@
}
pre_defined_opt = trans_table.get(model, ["-model=%s" % model])
# download pre-tuned parameters for arm_cpu if there is not any.
autotvm.tophub.check_package('arm_cpu')
opts = ["-device=arm_cpu"] + pre_defined_opt
opts = _merge_opts(opts, options)
return _api_internal._TargetCreate("llvm", *opts)
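For illustration, creating the 'pixel2' target from the table above; the expansion shown in the comment simply spells out the trans_table entry and is not a new API.

import tvm

# 'pixel2' is one of the entries shown in trans_table above
target = tvm.target.arm_cpu('pixel2')
# roughly equivalent to:
# tvm.target.create('llvm -device=arm_cpu -model=snapdragon835 '
#                   '-target=arm64-linux-android -mattr=+neon')
print(target)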

View File

@@ -128,7 +128,7 @@ def test_cpu_conv2d():
run_cpu_conv2d(env, remote, key, batch_size, wl)
# load pre-tuned operator parameters for ARM CPU
autotvm.tophub.check_package('vta')
autotvm.tophub.check_backend('vta')
with autotvm.tophub.context('llvm -device=vtacpu'):
vta.testing.run(_run)

View File

@@ -154,7 +154,7 @@ for file in [categ_fn, graph_fn, params_fn]:
synset = eval(open(os.path.join(data_dir, categ_fn)).read())
# Download pre-tuned op parameters of conv2d for ARM CPU used in VTA
autotvm.tophub.check_package('vta')
autotvm.tophub.check_backend('vta')
######################################################################