[AUTOTVM] Allow fallback for template & Fix bugs in tuners (#1615)

* support fallback & fix bugs in tuners & clean topi test
* update task extraction
* fix arm tutorial
* Update tune_nnvm_arm.py

Parent: 729224b17f
Commit: b7beb1ebef
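The core user-visible change: when a template's workload has no tuned record, autotvm now hands the template a fallback configuration instead of raising. Below is a minimal sketch, not part of this commit, of how a template author can use that; the toy matmul workload and the tile sizes are illustrative assumptions, while get_config/is_fallback/fallback_split are the APIs this diff introduces or touches.

import tvm
from tvm import autotvm

@autotvm.template
def matmul(N):
    # toy workload, for illustration only
    A = tvm.placeholder((N, N), name='A')
    B = tvm.placeholder((N, N), name='B')
    k = tvm.reduce_axis((0, N), name='k')
    C = tvm.compute((N, N), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
    s = tvm.create_schedule(C.op)
    y, x = s[C].op.axis

    cfg = autotvm.get_config()
    cfg.define_split('tile_y', y, num_outputs=2)
    if cfg.is_fallback:
        # no tuned record for this workload: install a conservative default
        cfg.fallback_split('tile_y', [-1, 8])
    yo, yi = cfg['tile_y'].apply(s, C, y)
    return s, [A, B, C]

Under the root FallbackContext installed by this commit, calling the template directly just works (the new test_fallback test below does exactly that); under apply_history_best a tuned entity is used whenever a record matches.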
@ -239,8 +239,9 @@ def build(graph, target=None, shape=None, dtype="float32",
raise ValueError("Target is not set in env or passed as argument.")
target = tvm.target.create(target)

# if not inside an autotvm config dispatch context, load pre-tuned parameters from TopHub
if autotvm.task.DispatchContext.current is None:
# If current dispatch context is fallback context (the default root context),
# then load pre-tuned parameters from TopHub
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
tophub_context = autotvm.tophub.context(target)
else:
tophub_context = autotvm.util.EmptyContext()
@ -0,0 +1,63 @@
|
|||
"""Test task extraction for autotvm"""
|
||||
|
||||
import nnvm.testing
|
||||
import nnvm.compiler
|
||||
from tvm import autotvm
|
||||
|
||||
def get_network(name, batch_size):
|
||||
"""Get the symbol definition and random weight of a network"""
|
||||
input_shape = (batch_size, 3, 224, 224)
|
||||
output_shape = (batch_size, 1000)
|
||||
|
||||
if name == 'resnet-18':
|
||||
net, params = nnvm.testing.resnet.get_workload(num_layers=18, batch_size=batch_size)
|
||||
elif name == 'mobilenet':
|
||||
net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size)
|
||||
elif name == 'squeezenet v1.1':
|
||||
net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1')
|
||||
elif name == 'vgg-16':
|
||||
net, params = nnvm.testing.vgg.get_workload(num_layers=16, batch_size=batch_size)
|
||||
elif name == 'dcgan':
|
||||
net, params = nnvm.testing.dcgan.get_workload(batch_size=batch_size)
|
||||
input_shape = (batch_size, 100)
|
||||
else:
|
||||
raise ValueError("Unsupported network: " + name)
|
||||
|
||||
return net, params, input_shape, output_shape
|
||||
|
||||
def test_task_extraction():
|
||||
target = 'llvm'
|
||||
dtype = 'float32'
|
||||
|
||||
net, params, input_shape, out_shape = get_network('resnet-18', batch_size=1)
|
||||
tasks = autotvm.task.extract_from_graph(net, target=target,
|
||||
shape={'data': input_shape}, dtype=dtype,
|
||||
symbols=(nnvm.sym.conv2d,))
|
||||
assert len(tasks) == 12
|
||||
|
||||
net, params, input_shape, out_shape = get_network('resnet-18', batch_size=1)
|
||||
tasks = autotvm.task.extract_from_graph(net, target=target,
|
||||
shape={'data': input_shape}, dtype=dtype,
|
||||
symbols=(nnvm.sym.dense,))
|
||||
assert len(tasks) == 1
|
||||
|
||||
net, params, input_shape, out_shape = get_network('resnet-18', batch_size=1)
|
||||
tasks = autotvm.task.extract_from_graph(net, target=target,
|
||||
shape={'data': input_shape}, dtype=dtype,
|
||||
symbols=(nnvm.sym.conv2d, nnvm.sym.dense))
|
||||
assert len(tasks) == 13
|
||||
|
||||
net, params, input_shape, out_shape = get_network('mobilenet', batch_size=1)
|
||||
tasks = autotvm.task.extract_from_graph(net, target=target,
|
||||
shape={'data': input_shape}, dtype=dtype,
|
||||
symbols=(nnvm.sym.conv2d, nnvm.sym.dense))
|
||||
assert len(tasks) == 20
|
||||
|
||||
net, params, input_shape, out_shape = get_network('dcgan', batch_size=1)
|
||||
tasks = autotvm.task.extract_from_graph(net, target=target,
|
||||
shape={'data': input_shape}, dtype=dtype,
|
||||
symbols=(nnvm.sym.conv2d_transpose,))
|
||||
assert len(tasks) == 4
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_task_extraction()
|
|
@ -25,5 +25,6 @@ from . import tophub
from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo
from .tuner import callback
from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
ApplyHistoryBest as apply_history_best
register_topi_compute, register_topi_schedule, \
DispatchContext, FallbackContext, ApplyHistoryBest as apply_history_best
from .env import GLOBAL_SCOPE
@ -89,8 +89,9 @@ def measure_option(measure_func,
|
|||
|
||||
callable: customized build function for other backends (e.g. VTA).
|
||||
See measure/measure_methods.py::default_build_func for example.
|
||||
check_correctness: bool
|
||||
Whether check correctness after measurement. This will use llvm cpu as reference.
|
||||
check_correctness: bool, optional
|
||||
Whether check correctness after measurement. This will use llvm cpu target to generate
|
||||
reference output.
|
||||
replay_db : Database, optional
|
||||
The database that we retrieve saved MeasureResult from.
|
||||
|
||||
|
|
|
@ -83,7 +83,7 @@ def check_remote(target, device_key, tracker_addr=None, priority=2, timeout=10):
|
|||
The priority of this request, larger is more prior
|
||||
timeout: float, optional
|
||||
The timeout of this check (units: seconds).
|
||||
If time is out, a RuntimerError will be raised.
|
||||
If time is out, a RuntimeError will be raised.
|
||||
"""
|
||||
def _check():
|
||||
remote = request_remote(device_key, tracker_addr, priority)
|
||||
|
@ -281,11 +281,11 @@ def rpc(key,
|
|||
results: List of MeasureResult
|
||||
The results for input_pack
|
||||
"""
|
||||
remote = request_remote(key, (host, port), priority, session_timeout)
|
||||
remote_args = (key, (host, port), priority, session_timeout)
|
||||
|
||||
res = _measure_common(input_pack, build_func, build_kwargs, number, repeat,
|
||||
ref_input, ref_output,
|
||||
remote)
|
||||
remote_args)
|
||||
return res
|
||||
|
||||
fmeasure.pack_size = pack_size
|
||||
|
@ -294,7 +294,7 @@ def rpc(key,
|
|||
|
||||
|
||||
def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
|
||||
ref_input=None, ref_output=None, remote=None):
|
||||
ref_input=None, ref_output=None, remote_args=None):
|
||||
"""Measure the time cost for a pack of inputs.
|
||||
|
||||
(Note: A pack is a list of inputs which will be measured inside a same RPC session)
|
||||
|
@ -318,8 +318,8 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
|
|||
Reference input for checking correctness
|
||||
ref_output: Array of np.ndarray, optional
|
||||
Reference output for checking correctness
|
||||
remote: RPCSession, optional
|
||||
The remote RPC session
|
||||
remote_args: Tuple, optional
|
||||
The arguments to request_remote. If is not None, will use remote rpc devices.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -327,7 +327,8 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
|
|||
The list of results of measurement.
|
||||
"""
|
||||
res_pack = []
|
||||
tmp_dir = util.tempdir() if remote else None
|
||||
tmp_dir = util.tempdir() if remote_args else None
|
||||
assert len(input_pack) == 1, "Only supports input_pack == 1 for now"
|
||||
|
||||
for inp in input_pack:
|
||||
tic = time.time()
|
||||
|
@ -360,31 +361,36 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
|
|||
tstamp - tic, tstamp))
|
||||
continue
|
||||
|
||||
# upload built module
|
||||
if remote:
|
||||
remote.upload(tmp_dir.relpath(filename))
|
||||
func = remote.load_module(filename)
|
||||
ctx = remote.context(str(inp.target), 0)
|
||||
time_f = func.time_evaluator(
|
||||
func.entry_name, ctx, number=number, repeat=repeat)
|
||||
else:
|
||||
ctx = context(str(inp.target), 0)
|
||||
time_f = func.time_evaluator(
|
||||
func.entry_name, ctx, number=number, repeat=repeat)
|
||||
|
||||
# measure time
|
||||
errno = MeasureErrorNo.NO_ERROR
|
||||
try:
|
||||
# upload built module
|
||||
if remote_args:
|
||||
remote = request_remote(*remote_args)
|
||||
remote.upload(tmp_dir.relpath(filename))
|
||||
func = remote.load_module(filename)
|
||||
ctx = remote.context(str(inp.target), 0)
|
||||
time_f = func.time_evaluator(
|
||||
func.entry_name, ctx, number=number, repeat=repeat)
|
||||
else:
|
||||
ctx = context(str(inp.target), 0)
|
||||
time_f = func.time_evaluator(
|
||||
func.entry_name, ctx, number=number, repeat=repeat)
|
||||
|
||||
# set input
|
||||
if ref_input:
|
||||
args = [nd.array(x, ctx=ctx) for x in ref_input]
|
||||
else:
|
||||
args = [nd.empty(get_const_tuple(x.shape), dtype=x.dtype, ctx=ctx)
|
||||
for x in arg_bufs]
|
||||
|
||||
costs = time_f(*args).results
|
||||
if len(costs) > 2: # remove largest and smallest value to reduce variance
|
||||
costs = list(costs)
|
||||
costs.sort()
|
||||
costs = tuple(costs[1:-1])
|
||||
|
||||
# check correctness of output
|
||||
if ref_output:
|
||||
for expected, real in zip(ref_output, args):
|
||||
if not np.allclose(expected, real.asnumpy(), rtol=1e-4):
|
||||
|
|
|
@ -9,7 +9,7 @@ of typical tasks of interest.
from .task import Task, create, register, template, get_config, args_to_workload
from .space import ConfigSpace, ConfigEntity
from .code_hash import attach_code_hash, attach_code_hash_to_arg
from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, dispatcher
from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, FallbackContext, dispatcher

from .topi_integration import register_topi_compute, register_topi_schedule
from .nnvm_integration import extract_from_graph

@ -21,7 +21,7 @@ import numpy as np

from tvm import target as _target

from .space import ConfigSpace
from .space import FallbackConfigEntity

logger = logging.getLogger('autotvm')
@ -34,9 +34,36 @@ class DispatchContext(object):
"""
current = None

def __init__(self):
self._old_ctx = DispatchContext.current

def query(self, target, workload):
"""
Query the context to get the specific implementation.
Query the context to get the specific config for a template.
If cannot find the result inside this context, this function will query it
from the upper contexts.

Parameters
----------
target: Target
The current target
workload : Workload
The current workload.

Returns
-------
cfg : ConfigSpace
The specific configuration.
"""
ret = self._query_inside(target, workload)
if ret is None:
ret = self._old_ctx.query(target, workload)
return ret

def _query_inside(self, target, workload):
"""
Query the context to get the specific config for a template.
This function only query config inside this context.

Parameters
----------

@ -117,17 +144,17 @@ def dispatcher(fworkload):
def dispatch_func(func, *args, **kwargs):
"""The wrapped dispatch function"""
tgt = _target.current_target()
context = DispatchContext.current
if context is None:
raise RuntimeError("DispatchContext is not initialized")
workload = func(*args, **kwargs)
cfg = context.query(tgt, workload)
if cfg.template_key:
return dispatch_dict[cfg.template_key](cfg, *args, **kwargs)
else:
assert dispatch_dict, "No func registered for this dispatcher"
cfg = DispatchContext.current.query(tgt, workload)
if cfg.is_fallback and not cfg.template_key:
# first try 'direct' template
if 'direct' in dispatch_dict:
return dispatch_dict['direct'](cfg, *args, **kwargs)
# otherwise pick a random template
for v in dispatch_dict.values():
return v(cfg, *args, **kwargs)
else:
return dispatch_dict[cfg.template_key](cfg, *args, **kwargs)

fdecorate = decorate(fworkload, dispatch_func)
fdecorate.register = register
@ -135,7 +162,7 @@ def dispatcher(fworkload):


class ApplyConfig(DispatchContext):
"""Apply a specific config entity during query.
"""Apply a deterministic config entity for all queries.

Parameters
----------

@ -147,7 +174,7 @@ class ApplyConfig(DispatchContext):
self._config = config
self.workload = None

def query(self, target, workload):
def _query_inside(self, target, workload):
"""Override query"""
self.workload = workload
return self._config

@ -164,20 +191,12 @@ class ApplyHistoryBest(DispatchContext):
If is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair.
Otherwise, it is an iterator.
default: ConfigEntity, optional
The default config to return when no history records
allow_fallback: bool
Whether allow to use a fallback configuration if cannot find
tuned result.
"""
def __init__(self, records, default=None, allow_fallback=False):
def __init__(self, records):
super(ApplyHistoryBest, self).__init__()

self.best_by_targetkey = {}
self.best_by_model = {}
self._default = default
self._allow_fallback = allow_fallback
self.fallback = {}

if records:
self.load(records)

@ -234,7 +253,7 @@ class ApplyHistoryBest(DispatchContext):

logger.debug("Finish loading %d records", counter)

def query(self, target, workload):
def _query_inside(self, target, workload):
if target is None:
raise RuntimeError("Need a target context to find the history best. "
"Hint: If your target is llvm, use `with tvm.target.create('llvm'):`"
@ -254,20 +273,50 @@ class ApplyHistoryBest(DispatchContext):
if key in self.best_by_targetkey:
return self.best_by_targetkey[key][0].config

if self._default:
return self._default
return None

if self._allow_fallback:
key = (target, workload)
if key in self.fallback:
return self.fallback[key]

class FallbackContext(DispatchContext):
"""
A fallback dispatch context.

Any tunable template can be called under this context.
This is the root context.
"""

def __init__(self):
super(FallbackContext, self).__init__()
self.memory = {}
self.silent = False

def _query_inside(self, target, workload):
key = (str(target), workload)
if key in self.memory:
return self.memory[key]

if not self.silent:
logger.warning(
"Cannot find config for target=%s, workload=%s. A fallback configuration "
"is used, which may bring great performance regression.", target, workload)
cfg = ConfigSpace()
self.fallback[key] = cfg
return cfg
cfg = FallbackConfigEntity()

raise RuntimeError(
"Cannot find config for target=%s, workload=%s. You need to do tuning "
"for this workload to get the config." % (target, workload))
# cache this config
self.memory[key] = cfg
return cfg

def clear_cache(self, target, workload):
"""Clear fallback cache. Pass the same argument as _query_inside to this function
to clean the cache.

Parameters
----------
target: Target
The current target
workload : Workload
The current workload.
"""
key = (str(target), workload)
if key in self.memory:
del self.memory[key]

DispatchContext.current = FallbackContext()
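For orientation, a rough sketch, not from the diff, of the lookup chain the query()/_query_inside() split creates: each context looks inside itself first and otherwise defers to the context that was active when it was created (_old_ctx), and every chain bottoms out at the root FallbackContext registered above. The workload argument and the log file name below are placeholders.

import tvm
from tvm import autotvm

def current_config(workload):
    # `workload` is a placeholder for an autotvm workload tuple
    tgt = tvm.target.current_target()
    return autotvm.DispatchContext.current.query(tgt, workload)

# Typical layering: tuned records first, the fallback root only as a last resort.
# with autotvm.apply_history_best('tuning.log'):      # hypothetical log file
#     with tvm.target.create('llvm'):
#         cfg = current_config(workload)
#         # cfg.is_fallback is True only when no tuned record matched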
|
||||
|
|
|
@ -7,11 +7,10 @@ import warnings
|
|||
import logging
|
||||
|
||||
|
||||
from ... import tensor, placeholder, target as _target
|
||||
from ... import tensor, placeholder, create_schedule, target as _target
|
||||
|
||||
from ..util import get_const_tuple
|
||||
from .task import create, register
|
||||
from .dispatcher import ApplyHistoryBest
|
||||
|
||||
logger = logging.getLogger('autotvm')
|
||||
|
||||
|
@ -56,40 +55,68 @@ class TaskExtractEnv:
|
|||
import topi
|
||||
import nnvm
|
||||
|
||||
# NOTE: To add more symbols, you only need to change the following lists
|
||||
# nnvm symbol -> topi compute
|
||||
self.symbol2topi = {
|
||||
nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw],
|
||||
nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose],
|
||||
nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
|
||||
nnvm.sym.dense: [topi.nn.dense],
|
||||
}
|
||||
|
||||
# topi compute -> autotvm task name
|
||||
self.topi_to_task = {
|
||||
topi.nn.conv2d: "topi_nn_conv2d",
|
||||
topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw",
|
||||
topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
|
||||
topi.nn.dense: "topi_nn_dense",
|
||||
}
|
||||
|
||||
self._register_dummy()
|
||||
self.topi_to_schedule = {
|
||||
topi.nn.conv2d: [topi.generic.schedule_conv2d_nchw,
|
||||
topi.generic.schedule_conv2d_nhwc],
|
||||
topi.nn.depthwise_conv2d_nchw: [topi.generic.schedule_depthwise_conv2d_nchw,
|
||||
topi.generic.schedule_depthwise_conv2d_nhwc],
|
||||
topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
|
||||
topi.nn.dense: [topi.generic.schedule_dense],
|
||||
}
|
||||
|
||||
self._register_tracing()
|
||||
self._register_topi_task()
|
||||
self.task_collection = []
|
||||
self.wanted_topi_funcs = list(self.topi_to_task.keys())
|
||||
|
||||
def _register_dummy(self):
|
||||
"""Register dummy function to track the topi function call"""
|
||||
for func in self.topi_to_task:
|
||||
def _local_scope(local_func):
|
||||
"""build a scope to holds the function"""
|
||||
@local_func.register("dummy", )
|
||||
def _dummy_func(*args, **kwargs):
|
||||
def _register_tracing(self):
|
||||
"""Register tracing function to track the topi function call"""
|
||||
# register topi compute for "tracing" target
|
||||
for topi_compute in self.topi_to_task:
|
||||
def _local_scope(compute_func):
|
||||
"""start a scope to hold the local function in for loop"""
|
||||
|
||||
@compute_func.register("tracing", )
|
||||
def _tracing_topi_compute(*args, **kwargs):
|
||||
assert not kwargs, "Do not support extracting tuning tasks when" \
|
||||
"kwargs is used in TOPI function call." \
|
||||
"Please modify it to use only positional args."
|
||||
|
||||
if (self.topi_to_task[local_func], serialize_args(args)) \
|
||||
not in self.task_collection:
|
||||
self.task_collection.append((self.topi_to_task[local_func],
|
||||
serialize_args(args)))
|
||||
with _target.create("opencl"):
|
||||
return local_func(*args)
|
||||
if compute_func in self.wanted_topi_funcs: # record this call
|
||||
key = (self.topi_to_task[compute_func], serialize_args(args))
|
||||
if key not in self.task_collection:
|
||||
self.task_collection.append(key)
|
||||
|
||||
_local_scope(func)
|
||||
return compute_func.fdefault(*args)
|
||||
_local_scope(topi_compute)
|
||||
|
||||
# register topi schedule for "tracing" target
|
||||
for topi_compute in self.topi_to_task:
|
||||
for topi_schedule in self.topi_to_schedule[topi_compute]:
|
||||
def _local_scope_(schedule_func):
|
||||
"""start a scope to hold the local function in for loop"""
|
||||
|
||||
@schedule_func.register("tracing", )
|
||||
def _tracing_topi_compute(outs):
|
||||
outs = [outs] if isinstance(outs, tensor.Tensor) else outs
|
||||
return create_schedule([x.op for x in outs])
|
||||
_local_scope_(topi_schedule)
|
||||
|
||||
def _register_topi_task(self):
|
||||
"""register tuning wrapper for topi function"""
|
||||
|
@ -125,17 +152,47 @@ class TaskExtractEnv:
|
|||
s = topi.generic.schedule_conv2d_transpose_nchw([C])
|
||||
return s, [A, W, C]
|
||||
|
||||
def reset(self):
|
||||
"""Reset task collections"""
|
||||
@register("topi_nn_dense")
|
||||
def _topi_nn_dense(*args, **kwargs):
|
||||
assert not kwargs, "Do not support kwargs in template function call"
|
||||
args = deserialize_args(args)
|
||||
data, weight, bias = args
|
||||
C = topi.nn.dense(*args, **kwargs)
|
||||
s = topi.generic.schedule_dense([C])
|
||||
if bias is not None:
|
||||
return s, [data, weight, bias, C]
|
||||
return s, [data, weight, C]
|
||||
|
||||
def reset(self, wanted_topi_funcs):
|
||||
"""Reset task collections
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wanted_topi_funcs: List of function
|
||||
The topi function to be extracted
|
||||
"""
|
||||
self.task_collection = []
|
||||
self.wanted_topi_funcs = wanted_topi_funcs
|
||||
|
||||
def get_tasks(self):
|
||||
"""Get collected tasks"""
|
||||
"""Get collected tasks
|
||||
|
||||
Returns
|
||||
-------
|
||||
tasks: List of tuple(name, args)
|
||||
A list of tasks extracted from the nnvm graph
|
||||
"""
|
||||
return self.task_collection
|
||||
|
||||
@staticmethod
|
||||
def get():
|
||||
"""Get the single instance of TaskExtractEnv"""
|
||||
"""Get the single instance of TaskExtractEnv
|
||||
|
||||
Returns
|
||||
-------
|
||||
env: TaskExtractEnv
|
||||
The single instance of TaskExtractEnv
|
||||
"""
|
||||
if not TaskExtractEnv.current:
|
||||
TaskExtractEnv.current = TaskExtractEnv()
|
||||
return TaskExtractEnv.current
|
||||
|
@ -144,8 +201,8 @@ class TaskExtractEnv:
|
|||
def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
|
||||
""" Extract tuning tasks from a nnvm graph.
|
||||
|
||||
This function collects tunning tasks by building the graph
|
||||
with a "dummy" target and tracing all the calls to topi.
|
||||
This function collects tuning tasks by building the graph
|
||||
with a "tracing" target and tracing all the calls to topi.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -158,7 +215,7 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
|
|||
target: tvm.target.Target
|
||||
The compilation target
|
||||
symbols : Array of nnvm.symbol
|
||||
Array of nnvm symbols
|
||||
Array of nnvm symbols want to be tuned
|
||||
target_host: tvm.target.Target
|
||||
The host compilation target
|
||||
|
||||
|
@ -179,16 +236,16 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
|
|||
warnings.warn("Symbol %s is not tunable, ignored" % sym_name)
|
||||
|
||||
# run compiler to collect all TOPI calls during compilation
|
||||
env.reset()
|
||||
env.reset(topi_funcs)
|
||||
|
||||
# disable logger temporarily
|
||||
old_state = logger.disabled
|
||||
logger.disabled = True
|
||||
|
||||
# use a dummy target to do a fake compile for collecting topi calls
|
||||
dummy_target = _target.create("opencl -device=dummy")
|
||||
with ApplyHistoryBest([], allow_fallback=True):
|
||||
nnvm.compiler.build(graph, target=dummy_target, shape=shape, dtype=dtype)
|
||||
# use a "tracing" target to do a fake compile for collecting topi calls
|
||||
tracing_target = _target.create("llvm -device=tracing")
|
||||
nnvm.compiler.engine.clear_cache()
|
||||
nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype)
|
||||
|
||||
logger.disabled = old_state
|
||||
|
||||
|
|
|
@ -567,15 +567,16 @@ class ConfigSpace(object):
"""
def __init__(self):
# private dict to provide sugar
self.space_map = OrderedDict() # name -> space
self.space_map = OrderedDict()  # name -> space
self._collect = True
self._length = None
self._entity_map = OrderedDict()
self._entity_map = OrderedDict()  # name -> entity
self._constraints = []
self.errors = []
self.template_key = None
self.code_hash = None
self.flop = 0
self.is_fallback = False

@staticmethod
def axis(var):
|
||||
|
@ -607,6 +608,15 @@ class ConfigSpace(object):
|
|||
If is 'candidate', try listed candidate.
|
||||
kwargs: dict
|
||||
extra arguments for policy
|
||||
see examples below for how to use filter
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> # use custom candidates
|
||||
>>> cfg.define_split('tile_x', x, policy='candidate', candidate=[[1, 4, 4], [4, 1, 4]])
|
||||
|
||||
>>> # use a filter that only accepts the split scheme whose inner most tile is less then 4
|
||||
>>> cfg.define_split('tile_y', y, policy='all', filter=lambda x: x.size[-1] <= 4)
|
||||
"""
|
||||
axes = [axis]
|
||||
return self._add_new_transform(SplitSpace, name, axes, policy, **kwargs)
|
||||
|
@ -889,3 +899,45 @@ class ConfigEntity(ConfigSpace):
def __repr__(self):
return "%s,%s,%s,%d" % (str(self._entity_map)[12:-1], self.template_key,
self.code_hash, self.index)

class FallbackConfigEntity(ConfigSpace):
"""The config entity created to support fallback"""

def __init__(self):
super(FallbackConfigEntity, self).__init__()
self.is_fallback = True

def fallback_split(self, name, constraints):
"""Fallback a split knob

Parameters
----------
name: str
name of the knob
constraints: List of int
The maximum tile size for every dimension. Value `-1` means no constraint.

Examples
--------
If you use cfg.define_split('tile_0', 128, num_outputs=3),
Then cfg.fallback_split('tile_0', [-1, 8, 4]) will give you cfg['tile_0'].size = [4, 8, 4]

If you use cfg.define_split('tile_0', 49, num_outputs=3),
Then cfg.fallback_split('tile_0', [-1, 8, 4]) will give you cfg['tile_0'].size = [7, 7, 1]
"""
space = self.space_map[name]
assert len(constraints) == space.num_outputs
indices = np.arange(space.num_outputs)

# '-1' means no constraint
constraints = [x if x != -1 else 1e10 for x in constraints]

for entity in reversed(space.entities):
if all([entity.size[i] <= constraints[i] for i in indices]):
self._entity_map[name] = entity
return

raise RuntimeError("Cannot find feasible fallback split entity for node: " + name)

def __repr__(self):
return "%s,%s,%s" % (str(self._entity_map)[12:-1], self.template_key, self.code_hash)
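A quick illustration, not part of the commit, of what fallback_split actually picks, mirroring the docstring examples above: it walks the split space from the back and keeps the first factorization whose tiles all fit under the given caps.

from tvm.autotvm.task.space import FallbackConfigEntity

cfg = FallbackConfigEntity()
cfg.define_split('tile_0', cfg.axis(128), num_outputs=3)
cfg.fallback_split('tile_0', [-1, 8, 4])
assert cfg['tile_0'].size == [4, 8, 4]   # inner tiles capped at 8 and 4, outer takes the rest

cfg = FallbackConfigEntity()
cfg.define_split('tile_0', cfg.axis(49), num_outputs=3)
cfg.fallback_split('tile_0', [-1, 8, 4])
assert cfg['tile_0'].size == [7, 7, 1]   # 49 = 7 * 7, so tiles of 8 and 4 cannot divide it exactly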
|
||||
|
|
|
@ -206,7 +206,7 @@ def args_to_workload(x):
|
|||
elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
|
||||
return x.value
|
||||
elif x is None:
|
||||
return None
|
||||
return 0
|
||||
else:
|
||||
raise RuntimeError('Do not support type "%s" in argument. Consider to use'
|
||||
'primitive types only' % type(x))
|
||||
|
|
|
@ -28,7 +28,7 @@ def _alias(name):
|
|||
return table.get(name, name)
|
||||
|
||||
|
||||
def context(target, extra_files=None, allow_fallback=False):
|
||||
def context(target, extra_files=None):
|
||||
"""Return the dispatch context with pre-tuned parameters.
|
||||
The corresponding downloaded *.log files under tophub root path will be loaded.
|
||||
Users can also add their own files in argument `extra_files`.
|
||||
|
@ -39,12 +39,9 @@ def context(target, extra_files=None, allow_fallback=False):
|
|||
The compilation target
|
||||
extra_files: list of str, optional
|
||||
Extra log files to load
|
||||
allow_fallback: bool
|
||||
Whether allow to use a fallback configuration if cannot find
|
||||
tuned result.
|
||||
"""
|
||||
rootpath = AUTOTVM_TOPHUB_ROOT_PATH
|
||||
best_context = ApplyHistoryBest([], allow_fallback=allow_fallback)
|
||||
best_context = ApplyHistoryBest([])
|
||||
|
||||
if isinstance(target, str):
|
||||
target = _target.create(target)
|
||||
|
|
|
@ -86,13 +86,9 @@ class GATuner(Tuner):
|
|||
|
||||
# cross over
|
||||
indices = np.arange(len(genes))
|
||||
max_score = np.max(scores)
|
||||
if max_score < 1e-8:
|
||||
probs = np.empty_like(scores)
|
||||
probs[:] = 1.0 / len(scores)
|
||||
else:
|
||||
scores /= max_score
|
||||
probs = scores / np.sum(scores)
|
||||
scores += 1e-8
|
||||
scores /= np.max(scores)
|
||||
probs = scores / np.sum(scores)
|
||||
tmp_genes = []
|
||||
for _ in range(self.pop_size):
|
||||
p1, p2 = np.random.choice(indices, size=2, replace=False, p=probs)
|
||||
|
|
|
@ -8,7 +8,7 @@ import gc
|
|||
import numpy as np
|
||||
|
||||
from .tuner import Tuner
|
||||
|
||||
from ..env import GLOBAL_SCOPE
|
||||
|
||||
class FeatureCache(object):
|
||||
"""Feature cache manager for cache sharing between different cost models"""
|
||||
|
@ -119,11 +119,9 @@ class CostModel(object):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def clone_new(self):
|
||||
"""Clone a new model with the same parameters.
|
||||
This function will only copy hyperparameters of the tuner, not all the trained model
|
||||
|
||||
This is used for deriving a base model conveniently
|
||||
def spawn_base_model(self):
|
||||
"""Clone a base model with the same parameters.
|
||||
The base model is used to fit history data in transfer learning.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -221,7 +219,9 @@ class ModelBasedTuner(Tuner):
|
|||
break
|
||||
self.trial_pt += 1
|
||||
|
||||
if self.trial_pt >= len(self.trials): # trial list is empty, choose randomly
|
||||
if self.trial_pt >= len(self.trials) - int(0.05 * self.plan_size):
|
||||
# if the trial list is empty or
|
||||
# the tuner is doing the last 5% trials (e-greedy), choose randomly
|
||||
index = np.random.randint(len(self.space))
|
||||
while index in self.visited:
|
||||
index = np.random.randint(len(self.space))
|
||||
|
@ -264,18 +264,16 @@ class ModelBasedTuner(Tuner):
|
|||
self.train_ct += 1
|
||||
|
||||
def load_history(self, data_set):
|
||||
# filter data, only pick the data with a same task
|
||||
data = []
|
||||
for inp, res in data_set:
|
||||
if inp.task.name == self.task.name and \
|
||||
inp.config.template_key == self.task.config_space.template_key:
|
||||
data.append((inp, res))
|
||||
if not data:
|
||||
return
|
||||
# set in_tuning as True to make the feature extraction consistent
|
||||
GLOBAL_SCOPE.in_tuning = True
|
||||
|
||||
# fit base model
|
||||
base_model = self.cost_model.clone_new()
|
||||
base_model.fit_log(data, self.plan_size)
|
||||
base_model = self.cost_model.spawn_base_model()
|
||||
success = base_model.fit_log(data_set, self.plan_size)
|
||||
|
||||
if not success:
|
||||
GLOBAL_SCOPE.in_tuning = False
|
||||
return
|
||||
|
||||
# use base model to select initial points
|
||||
if not self.trials:
|
||||
|
@ -285,6 +283,7 @@ class ModelBasedTuner(Tuner):
|
|||
self.trial_pt = 0
|
||||
|
||||
self.cost_model.load_basemodel(base_model)
|
||||
GLOBAL_SCOPE.in_tuning = False
|
||||
|
||||
def has_next(self):
|
||||
return len(self.visited) < len(self.space)
|
||||
|
|
|
@ -87,7 +87,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
|
|||
|
||||
new_scores = model.predict(new_points)
|
||||
|
||||
ac_prob = np.exp((new_scores - scores) / t)
|
||||
ac_prob = np.exp((new_scores - scores) / (t + 1e-2))
|
||||
ac_index = np.random.random(len(ac_prob)) < ac_prob
|
||||
|
||||
points[ac_index] = new_points[ac_index]
|
||||
|
|
|
@ -31,6 +31,10 @@ class Tuner(object):
|
|||
self.best_measure_pair = None
|
||||
self.best_iter = 0
|
||||
|
||||
# time to leave
|
||||
self.ttl = None
|
||||
self.n_trial = None
|
||||
|
||||
def has_next(self):
|
||||
"""Whether has next untried config in the space
|
||||
|
||||
|
@ -76,7 +80,7 @@ class Tuner(object):
|
|||
measure_option: dict
|
||||
The options for how to measure generated code.
|
||||
You should use the return value ot autotvm.measure_option for this argument.
|
||||
early_stopping: int
|
||||
early_stopping: int, optional
|
||||
Early stop the tuning when not finding better configs in this number of trials
|
||||
callbacks: List of callable
|
||||
A list of callback functions. The signature of callback function is
|
||||
|
@ -87,6 +91,8 @@ class Tuner(object):
|
|||
measure_batch = create_measure_batch(self.task, measure_option)
|
||||
n_parallel = getattr(measure_batch, 'n_parallel', 1)
|
||||
early_stopping = early_stopping or 1e9
|
||||
self.n_trial = n_trial
|
||||
|
||||
old_level = logger.level
|
||||
|
||||
GLOBAL_SCOPE.in_tuning = True
|
||||
|
@ -127,11 +133,12 @@ class Tuner(object):
|
|||
for callback in callbacks:
|
||||
callback(self, inputs, results)
|
||||
|
||||
if i > self.best_iter + early_stopping:
|
||||
self.ttl = min(early_stopping + self.best_iter, n_trial) - i
|
||||
if i >= self.best_iter + early_stopping:
|
||||
logger.debug("Early stopped. Best iter: %d.", self.best_iter)
|
||||
break
|
||||
|
||||
if error_ct > 50:
|
||||
if error_ct > 150:
|
||||
logger.warning("Too many errors happen in the tuning. Now is in debug mode")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
else:
|
||||
|
|
|
@ -31,8 +31,12 @@ class XGBoostCostModel(CostModel):
|
|||
If is 'curve', use sampled curve feature (relation feature).
|
||||
|
||||
Note on choosing feature type:
|
||||
For single task tuning, 'itervar' and 'knob' is good.
|
||||
For single task tuning, 'itervar' and 'knob' are good.
|
||||
'itervar' is more accurate but 'knob' is much faster.
|
||||
There are some constraints on 'itervar', if you meet
|
||||
problems with feature extraction when using 'itervar',
|
||||
you can swith to 'knob'.
|
||||
|
||||
For cross-shape tuning (e.g. many convolutions with different shapes),
|
||||
'itervar' and 'curve' has better transferability,
|
||||
'knob' is faster.
|
||||
|
@ -46,8 +50,11 @@ class XGBoostCostModel(CostModel):
|
|||
The number of threads.
|
||||
log_interval: int, optional
|
||||
If is not none, the cost model will print training log every `log_interval` iterations.
|
||||
upper_model: XGBoostCostModel, optional
|
||||
The upper model used in transfer learning
|
||||
"""
|
||||
def __init__(self, task, feature_type, loss_type, num_threads=None, log_interval=25):
|
||||
def __init__(self, task, feature_type, loss_type, num_threads=4, log_interval=25,
|
||||
upper_model=None):
|
||||
super(XGBoostCostModel, self).__init__()
|
||||
|
||||
if xgb is None:
|
||||
|
@ -109,35 +116,51 @@ class XGBoostCostModel(CostModel):
|
|||
else:
|
||||
raise RuntimeError("Invalid feature type " + feature_type)
|
||||
|
||||
self.feature_cache = FeatureCache()
|
||||
if upper_model: # share a same feature cache with upper model
|
||||
self.feature_cache = upper_model.feature_cache
|
||||
else:
|
||||
self.feature_cache = FeatureCache()
|
||||
self.upper_model = upper_model
|
||||
self.feature_extra_ct = 0
|
||||
self.pool = None
|
||||
self.base_model = None
|
||||
self.upper_model = None
|
||||
|
||||
self._sample_size = 0
|
||||
self._reset_pool(self.space, self.target, self.task)
|
||||
|
||||
self._reset_pool()
|
||||
def _reset_pool(self, space, target, task):
|
||||
"""reset processing pool for feature extraction"""
|
||||
|
||||
def _reset_pool(self):
|
||||
# reset processing pool for feature extraction
|
||||
if self.upper_model: # base model will reuse upper model's pool,
|
||||
self.upper_model._reset_pool(space, target, task)
|
||||
return
|
||||
|
||||
self._close_pool()
|
||||
|
||||
# use global variable to pass common arguments
|
||||
global _extract_space, _extract_target, _extract_task
|
||||
_extract_space = space
|
||||
_extract_target = target
|
||||
_extract_task = task
|
||||
self.pool = multiprocessing.Pool(self.num_threads)
|
||||
|
||||
def _close_pool(self):
|
||||
if self.pool:
|
||||
self.pool.terminate()
|
||||
self.pool.join()
|
||||
del self.pool
|
||||
# use global variable to pass common arguments
|
||||
global _extract_space, _extract_target, _extract_task
|
||||
_extract_space = self.space
|
||||
_extract_target = self.target
|
||||
_extract_task = self.task
|
||||
self.pool = multiprocessing.Pool(self.num_threads)
|
||||
self.pool = None
|
||||
|
||||
def _get_pool(self):
|
||||
if self.upper_model:
|
||||
return self.upper_model._get_pool()
|
||||
return self.pool
|
||||
|
||||
def _base_model_discount(self):
|
||||
return 1.0 / (2 ** (self._sample_size / 50.0))
|
||||
return 1.0 / (2 ** (self._sample_size / 64.0))
|
||||
|
||||
def fit(self, xs, ys, plan_size):
|
||||
tic = time.time()
|
||||
self._reset_pool()
|
||||
self._reset_pool(self.space, self.target, self.task)
|
||||
|
||||
x_train = self._get_feature(xs)
|
||||
y_train = np.array(ys)
|
||||
|
@ -150,8 +173,12 @@ class XGBoostCostModel(CostModel):
|
|||
self._sample_size = len(x_train)
|
||||
|
||||
if self.base_model:
|
||||
dtrain.set_base_margin(self._base_model_discount() *
|
||||
self.base_model.predict(xs, output_margin=True))
|
||||
discount = self._base_model_discount()
|
||||
if discount < 0.05: # discard base model
|
||||
self.base_model.upper_model = None
|
||||
self.base_model = None
|
||||
else:
|
||||
dtrain.set_base_margin(discount * self.base_model.predict(xs, output_margin=True))
|
||||
|
||||
self.bst = xgb.train(self.xgb_params, dtrain,
|
||||
num_boost_round=8000,
|
||||
|
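As a side note, not in the diff itself: with the change above the transfer-learning discount halves every 64 new measurements instead of every 50, and fit() now drops the borrowed base model entirely once the discount falls under 0.05. A small worked check of that schedule:

import math

def base_model_discount(sample_size):
    # mirrors XGBoostCostModel._base_model_discount() after this change
    return 1.0 / (2 ** (sample_size / 64.0))

print(base_model_discount(64))     # 0.5
print(base_model_discount(128))    # 0.25
# discount < 0.05 once sample_size > 64 * log2(20) ~= 277 measurements,
# at which point the base model is discarded
print(64 * math.log2(20))          # ~276.6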
@ -172,11 +199,19 @@ class XGBoostCostModel(CostModel):
|
|||
|
||||
def fit_log(self, records, plan_size):
|
||||
tic = time.time()
|
||||
self._reset_pool()
|
||||
|
||||
args = list(records)
|
||||
logger.debug("XGB load %d entries from history log file", len(args))
|
||||
# filter data, only pick the data with a same task
|
||||
data = []
|
||||
for inp, res in records:
|
||||
if inp.task.name == self.task.name and \
|
||||
inp.config.template_key == self.task.config_space.template_key:
|
||||
data.append((inp, res))
|
||||
|
||||
logger.debug("XGB load %d entries from history log file", len(data))
|
||||
|
||||
# extract feature
|
||||
self._reset_pool(self.space, self.target, self.task)
|
||||
pool = self._get_pool()
|
||||
if self.fea_type == 'itervar':
|
||||
feature_extract_func = _extract_itervar_feature_log
|
||||
elif self.fea_type == 'knob':
|
||||
|
@ -185,10 +220,21 @@ class XGBoostCostModel(CostModel):
|
|||
feature_extract_func = _extract_curve_feature_log
|
||||
else:
|
||||
raise RuntimeError("Invalid feature type: " + self.fea_type)
|
||||
res = self.pool.map(feature_extract_func, args)
|
||||
xs, ys = zip(*res)
|
||||
xs, ys = np.array(xs), np.array(ys)
|
||||
res = pool.map(feature_extract_func, data)
|
||||
|
||||
# filter out feature with different shapes
|
||||
fea_len = len(self._get_feature([0])[0])
|
||||
|
||||
xs, ys = [], []
|
||||
for x, y in res:
|
||||
if len(x) == fea_len:
|
||||
xs.append(x)
|
||||
ys.append(y)
|
||||
|
||||
if len(xs) < 500: # no enough samples
|
||||
return False
|
||||
|
||||
xs, ys = np.array(xs), np.array(ys)
|
||||
x_train = xs
|
||||
y_train = ys
|
||||
y_max = np.max(y_train)
|
||||
|
@ -212,6 +258,8 @@ class XGBoostCostModel(CostModel):
|
|||
|
||||
logger.debug("XGB train: %.2f\tobs: %d", time.time() - tic, len(xs))
|
||||
|
||||
return True
|
||||
|
||||
def predict(self, xs, output_margin=False):
|
||||
feas = self._get_feature(xs)
|
||||
dtest = xgb.DMatrix(feas)
|
||||
|
@ -224,20 +272,12 @@ class XGBoostCostModel(CostModel):
|
|||
|
||||
def load_basemodel(self, base_model):
|
||||
self.base_model = base_model
|
||||
if isinstance(base_model, XGBoostCostModel):
|
||||
# share feature cache
|
||||
base_model.feature_cache = self.feature_cache
|
||||
self.base_model._close_pool()
|
||||
self.base_model.upper_model = self
|
||||
|
||||
# close thread pool
|
||||
if base_model.pool:
|
||||
base_model.pool.terminate()
|
||||
base_model.pool.join()
|
||||
del base_model.pool
|
||||
self.base_model.upper_model = self
|
||||
|
||||
def clone_new(self):
|
||||
def spawn_base_model(self):
|
||||
return XGBoostCostModel(self.task, self.fea_type, self.loss_type,
|
||||
self.num_threads, self.log_interval)
|
||||
self.num_threads, self.log_interval, self)
|
||||
|
||||
def _get_feature(self, indexes):
|
||||
"""get features for indexes, run extraction if we do not have cache for them"""
|
||||
|
@ -251,7 +291,7 @@ class XGBoostCostModel(CostModel):
|
|||
need_extract = [x for x in indexes if x not in fea_cache]
|
||||
|
||||
if need_extract:
|
||||
pool = self.pool if self.upper_model is None else self.upper_model.pool
|
||||
pool = self._get_pool()
|
||||
feas = pool.map(self.feature_extract_func, need_extract)
|
||||
for i, fea in zip(need_extract, feas):
|
||||
fea_cache[i] = fea
|
||||
|
@ -261,6 +301,9 @@ class XGBoostCostModel(CostModel):
|
|||
ret[i, :] = fea_cache[ii]
|
||||
return ret
|
||||
|
||||
def __del__(self):
|
||||
self._close_pool()
|
||||
|
||||
|
||||
_extract_space = None
|
||||
_extract_target = None
|
||||
|
|
|
@ -20,8 +20,12 @@ class XGBTuner(ModelBasedTuner):
|
|||
If is 'curve', use sampled curve feature (relation feature).
|
||||
|
||||
Note on choosing feature type:
|
||||
For single task tuning, 'itervar' and 'knob' is good.
|
||||
For single task tuning, 'itervar' and 'knob' are good.
|
||||
'itervar' is more accurate but 'knob' is much faster.
|
||||
There are some constraints on 'itervar', if you meet
|
||||
problems with feature extraction when using 'itervar',
|
||||
you can swith to 'knob'.
|
||||
|
||||
For cross-shape tuning (e.g. many convolutions with different shapes),
|
||||
'itervar' and 'curve' has better transferability,
|
||||
'knob' is faster.
|
||||
|
@ -32,8 +36,7 @@ class XGBTuner(ModelBasedTuner):
|
|||
If is 'rank', use pairwise rank loss to train cost model.
|
||||
The cost model predicts relative rank score.
|
||||
num_threads: int, optional
|
||||
The number of threads.
|
||||
optimizer: str or ModelOptimizer, optional
|
||||
The number of threads. optimizer: str or ModelOptimizer, optional
|
||||
If is 'sa', use a default simulated annealing optimizer.
|
||||
Otherwise it should be a ModelOptimizer object.
|
||||
diversity_filter_ratio: int or float, optional
|
||||
|
@ -45,7 +48,7 @@ class XGBTuner(ModelBasedTuner):
|
|||
If is 0, output nothing.
|
||||
Otherwise, output debug information every `verbose` iterations.
|
||||
"""
|
||||
def __init__(self, task, plan_size=32,
|
||||
def __init__(self, task, plan_size=64,
|
||||
feature_type='itervar', loss_type='rank', num_threads=None,
|
||||
optimizer='sa', diversity_filter_ratio=None, log_interval=50):
|
||||
cost_model = XGBoostCostModel(task,
|
||||
|
@ -62,3 +65,9 @@ class XGBTuner(ModelBasedTuner):
|
|||
|
||||
super(XGBTuner, self).__init__(task, cost_model, optimizer,
|
||||
plan_size, diversity_filter_ratio)
|
||||
|
||||
def tune(self, *args, **kwargs): # pylint: disable=arguments-differ
|
||||
super(XGBTuner, self).tune(*args, **kwargs)
|
||||
|
||||
# manually close pool to avoid multiprocessing issues
|
||||
self.cost_model._close_pool()
|
||||
|
|
|
@ -8,8 +8,8 @@ from ..autotvm.tophub import list_packages, download_package
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--download", type=str, nargs='+',
|
||||
help="Target to download. Use 'all' to download for all targets")
|
||||
parser.add_argument("-d", "--download", type=str, nargs='+',
|
||||
help="The targets to download. Use 'all' to download for all targets")
|
||||
parser.add_argument("-l", "--list", action='store_true', help="List available packages")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -21,8 +21,7 @@ if __name__ == '__main__':
|
|||
print("-" * 41)
|
||||
for target, info in info:
|
||||
print("%-20s %-20s" % (target, "%.2f MB" % (info['size']/1000000)))
|
||||
|
||||
if args.download:
|
||||
elif args.download:
|
||||
info = list_packages()
|
||||
all_targets = [x[0] for x in info]
|
||||
if 'all' in args.download:
|
||||
|
@ -34,3 +33,5 @@ if __name__ == '__main__':
|
|||
if t not in all_targets:
|
||||
print("Warning : cannot find tuned parameters of " + t + ". (ignored)")
|
||||
download_package(t)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
|
|
@ -263,6 +263,7 @@ def override_native_generic_func(func_name):
|
|||
"Keyword arguments cannot be used when invoking generic_func %s" % func_name)
|
||||
return generic_func_node(*args)
|
||||
fresult = decorate(fdefault, dispatch_func)
|
||||
fresult.fdefault = fdefault
|
||||
fresult.register = register
|
||||
return fresult
|
||||
return fdecorate
|
||||
|
|
|
@ -3,34 +3,48 @@ The dispatcher can choose which template to use according
|
|||
to the parameters of workload"""
|
||||
|
||||
from collections import namedtuple
|
||||
from tvm import autotvm
|
||||
from tvm.autotvm.task import dispatcher, DispatchContext
|
||||
|
||||
SimpleWorkload = namedtuple("SimpleWorkload", ["key"])
|
||||
SimpleConfig = namedtuple("SimpleConfig", ["template_key"])
|
||||
SimpleConfig = namedtuple('SimpleConfig', ('template_key', 'is_fallback'))
|
||||
|
||||
def test_dispatch():
|
||||
@dispatcher
|
||||
def my_dispatcher(a, b):
|
||||
return SimpleWorkload(key=a + b)
|
||||
|
||||
@my_dispatcher.register("spatial_pack")
|
||||
def _sp_pack_add(cfg, a, b):
|
||||
return b + 100
|
||||
return (a, b)
|
||||
|
||||
@my_dispatcher.register("im2col")
|
||||
def _im2col_add(cfg, a, b):
|
||||
return a + 1
|
||||
def _im2col(cfg, a, b):
|
||||
return a
|
||||
|
||||
@my_dispatcher.register("spatial_pack")
|
||||
def _spatial_pack(cfg, a, b):
|
||||
return b
|
||||
|
||||
class SimpleDispatcher(DispatchContext):
|
||||
def query(self, target, workload):
|
||||
tkey = "spatial_pack" if workload.key > 2 else "im2col"
|
||||
return SimpleConfig(tkey)
|
||||
a, b = workload
|
||||
tkey = "spatial_pack" if a + b > 2 else "im2col"
|
||||
cfg = SimpleConfig(tkey, False)
|
||||
return cfg
|
||||
|
||||
with SimpleDispatcher():
|
||||
# im2col
|
||||
assert my_dispatcher(1, 0) == 2
|
||||
# spack
|
||||
assert my_dispatcher(1, 100) == 200
|
||||
# this will call im2col
|
||||
assert my_dispatcher(1, 0) == 1
|
||||
|
||||
# this will call spatial pack
|
||||
assert my_dispatcher(1, 100) == 100
|
||||
|
||||
def test_fallback():
|
||||
|
||||
@autotvm.template
|
||||
def simple_template(a, b):
|
||||
cfg = autotvm.get_config()
|
||||
assert cfg.is_fallback
|
||||
|
||||
simple_template(2, 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_dispatch()
|
||||
test_fallback()
|
||||
|
|
|
@ -1,7 +1,7 @@
"""Test space definition primitives"""

import tvm
from tvm.autotvm.task.space import ConfigSpace
from tvm.autotvm.task.space import ConfigSpace, FallbackConfigEntity

def gemm_func(cfg, N):
A = tvm.placeholder((N, N), name='A')

@ -26,5 +26,18 @@ def test_split():
assert len(cfg) == 64
assert len(cfg.space_map['tile_y']) == 8

# test fallback
cfg = FallbackConfigEntity()
cfg.define_split('tile_n', cfg.axis(128), num_outputs=3)
cfg.fallback_split('tile_n', [-1, 8, 4])

assert cfg['tile_n'].size == [4, 8, 4]

cfg = FallbackConfigEntity()
cfg.define_split('tile_n', cfg.axis(49), num_outputs=3)
cfg.fallback_split('tile_n', [-1, 8, 4])

assert cfg['tile_n'].size == [7, 7, 1]

if __name__ == '__main__':
test_split()
|
||||
|
|
|
@ -12,7 +12,7 @@ from test_autotvm_common import get_sample_task, get_sample_records
|
|||
|
||||
def test_fit():
|
||||
task, target = get_sample_task()
|
||||
records = get_sample_records(n=100)
|
||||
records = get_sample_records(n=500)
|
||||
|
||||
base_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
|
||||
base_model.fit_log(records, plan_size=32)
|
||||
|
@ -20,8 +20,8 @@ def test_fit():
|
|||
upper_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
|
||||
upper_model.load_basemodel(base_model)
|
||||
|
||||
xs = np.arange(100)
|
||||
ys = np.arange(100)
|
||||
xs = np.arange(10)
|
||||
ys = np.arange(10)
|
||||
|
||||
upper_model.fit(xs, ys, plan_size=32)
|
||||
|
||||
|
|
|
@ -27,7 +27,14 @@ def _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype):
|
|||
@autotvm.task.dispatcher
|
||||
def conv2d_arm_cpu(data, kernel, strides, padding, layout, out_dtype):
|
||||
"""TOPI compute callback. Mark this function as a dispatcher, so
|
||||
this template can assign config according to workload"""
|
||||
this template can assign config according to workload
|
||||
|
||||
Returns
|
||||
-------
|
||||
workload: Tuple
|
||||
Dispatcher will use this workload to query corresponding config.
|
||||
Then use cfg.template_key to call a registered template.
|
||||
"""
|
||||
return _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype)
|
||||
|
||||
@conv2d_arm_cpu.register(['direct'])
|
||||
|
@ -70,8 +77,10 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs):
|
|||
|
||||
def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile):
|
||||
assert layout == "NCHW", "Only support NCHW"
|
||||
out_dtype = out_dtype or data.dtype
|
||||
# create workload according to raw arguments
|
||||
wkl = _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype)
|
||||
|
||||
out_dtype = out_dtype or data.dtype
|
||||
N, CI, IH, IW = get_const_tuple(data.shape)
|
||||
if len(kernel.shape) == 4:
|
||||
pre_packed = False
|
||||
|
@ -113,6 +122,18 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n
|
|||
cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec')
|
||||
# ====================================================================
|
||||
|
||||
if cfg.is_fallback:
|
||||
if num_tile == 2:
|
||||
cfg.fallback_split('tile_co', [-1, 8])
|
||||
cfg.fallback_split('tile_oh', [-1, 2])
|
||||
cfg.fallback_split('tile_ow', [-1, 8])
|
||||
else:
|
||||
cfg.fallback_split('tile_co', [-1, 16, 4])
|
||||
cfg.fallback_split('tile_oh', [-1, 1, 1])
|
||||
cfg.fallback_split('tile_ow', [-1, 1, 4])
|
||||
cfg['ann_reduce'].anns = ['unroll', 'unroll']
|
||||
cfg['ann_spatial'].anns = ['none', 'unroll', 'vec']
|
||||
|
||||
VC = cfg["tile_co"].size[-1]
|
||||
VH = cfg["tile_oh"].size[-1]
|
||||
VW = cfg["tile_ow"].size[-1]
|
||||
|
@ -145,8 +166,7 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n
|
|||
output = tvm.compute(oshape, lambda n, co, h, w:
|
||||
conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
|
||||
name='output_unpack', tag='spatial_conv2d_output',
|
||||
attrs={'workload': _conv_arg_to_workload(data, kernel, strides, padding,
|
||||
layout, out_dtype)})
|
||||
attrs={'workload': wkl})
|
||||
return output
|
||||
|
||||
def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
|
||||
|
@ -212,6 +232,10 @@ def decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype):
|
|||
return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size)
|
||||
|
||||
def _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size):
|
||||
# create workload according to raw arguments
|
||||
wkl = _winograd_conv_arg_to_workload(data, kernel, strides, padding, layout,
|
||||
out_dtype, tile_size)
|
||||
|
||||
N, CI, IH, IW = get_const_tuple(data.shape)
|
||||
if len(kernel.shape) == 4:
|
||||
pre_computed = False
|
||||
|
@ -333,10 +357,9 @@ def _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_
|
|||
output = tvm.compute((N, K, H, W), lambda n, k, h, w:
|
||||
Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m],
|
||||
name='output', tag='winograd_conv2d_output',
|
||||
attrs={'workload': _winograd_conv_arg_to_workload(
|
||||
data, kernel, strides, padding, layout, out_dtype, tile_size)})
|
||||
attrs={'workload': wkl})
|
||||
|
||||
# we have to manually assign effective GFLOP for winogard
|
||||
# we have to manually assign effective GFLOP for winograd
|
||||
cfg.add_flop(2 * N * K * H * W * KH * KW * C)
|
||||
return output
|
||||
|
||||
|
@ -358,30 +381,29 @@ def _schedule_winograd(cfg, s, output, last):
|
|||
kernel, G = U.op.input_tensors
|
||||
s[G].compute_inline()
|
||||
eps, nu, k, c, kk, = s[U].op.axis
|
||||
r_kh, r_kw = s[U].op.reduce_axis
|
||||
s[U].reorder(k, c, eps, nu, r_kh, r_kw, kk)
|
||||
s[U].unroll(eps)
|
||||
s[U].unroll(nu)
|
||||
s[U].unroll(r_kh)
|
||||
s[U].unroll(r_kw)
|
||||
s[U].vectorize(kk)
|
||||
if autotvm.GLOBAL_SCOPE.in_tuning:
|
||||
# kernel transformation will be pre-computed during compilation, so we skip
|
||||
# this part to make tuning records correct
|
||||
s[U].pragma(k, 'debug_skip_region')
|
||||
s[U].pragma(eps, 'debug_skip_region')
|
||||
else:
|
||||
r_kh, r_kw = s[U].op.reduce_axis
|
||||
s[U].reorder(k, c, eps, nu, r_kh, r_kw, kk)
|
||||
for axis in [eps, nu, r_kh, r_kw]:
|
||||
s[U].unroll(axis)
|
||||
s[U].vectorize(kk)
|
||||
s[U].parallel(k)
|
||||
|
||||
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
|
||||
s[kernel].compute_inline()
|
||||
|
||||
# transform image
|
||||
DD = s.cache_read(d, 'global', [V])
|
||||
s[B].compute_inline()
|
||||
eps, nu, b, c, bb = s[V].op.axis
|
||||
r_eps, r_nu = s[V].op.reduce_axis
|
||||
s[V].reorder(b, c, eps, nu, r_eps, r_nu, bb)
|
||||
s[V].unroll(eps)
|
||||
s[V].unroll(nu)
|
||||
s[V].unroll(r_eps)
|
||||
s[V].unroll(r_nu)
|
||||
for axis in [eps, nu, r_eps, r_nu]:
|
||||
s[V].unroll(axis)
|
||||
s[DD].compute_at(s[V], c)
|
||||
s[V].vectorize(bb)
|
||||
s[V].parallel(b)
|
||||
|
@ -405,10 +427,8 @@ def _schedule_winograd(cfg, s, output, last):
|
|||
s[A].compute_inline()
|
||||
k, b, vh, vw = s[Y].op.axis
|
||||
r_eps, r_nu = s[Y].op.reduce_axis
|
||||
s[Y].unroll(vh)
|
||||
s[Y].unroll(vw)
|
||||
s[Y].unroll(r_eps)
|
||||
s[Y].unroll(r_nu)
|
||||
for axis in [vh, vw, r_eps, r_nu]:
|
||||
s[Y].unroll(axis)
|
||||
|
||||
# output
|
||||
n, co, h, w = s[last].op.axis
|
||||
|
@ -444,6 +464,7 @@ def _winograd_conv_arg_to_workload(data, kernel, strides, padding, layout, out_d
|
|||
[data, raw_kernel, strides, padding, layout, out_dtype])
|
||||
|
||||
|
||||
##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH WEIGHT TRANSFORM #####
|
||||
@conv2d_winograd_without_weight_transform.register(['arm_cpu'])
|
||||
@autotvm.task.dispatcher
|
||||
def winograd_ww_config_dispatcher_(data, kernel, strides, padding, layout, out_dtype, tile_size):
|
||||
|
@ -472,6 +493,7 @@ def schedule_conv2d_winograd_without_weight_transform_(cfg, outs):
|
|||
return s
|
||||
|
||||
|
||||
##### REGISTER ALTER OP LAYOUT #####
|
||||
@conv2d_alter_layout.register(["arm_cpu", "mali"])
|
||||
def _alter_conv2d_layout(attrs, inputs, tinfos):
|
||||
"""Alter op layout for pre-computing kernel transformation"""
|
||||
|
@ -493,18 +515,30 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
|
|||
# query config of this workload
|
||||
workload = _conv_arg_to_workload(tinfos[0], tinfos[1], strides, padding,
|
||||
layout, out_dtype)
|
||||
cfg = autotvm.task.DispatchContext.current.query(tvm.target.current_target(), workload)
|
||||
cfg = autotvm.DispatchContext.current.query(tvm.target.current_target(), workload)
|
||||
|
||||
if cfg.is_fallback: # if is fallback, clear query cache and return None
|
||||
context = autotvm.DispatchContext.current
|
||||
while not isinstance(context, autotvm.FallbackContext):
|
||||
context = context._old_ctx
|
||||
context.clear_cache(tvm.target.current_target(), workload)
|
||||
return None
|
||||
|
||||
if cfg.template_key == 'direct': # packing weight tensor
|
||||
new_attrs['kernel_layout'] = 'OIHW%do' % (cfg['tile_co'].size[-1])
|
||||
return sym.conv2d(*copy_inputs, **new_attrs)
|
||||
else: # pre-compute weight transformation in winograd
|
||||
tile_size = 4
|
||||
if "-device=arm_cpu" in tvm.target.current_target().options:
|
||||
tile_size = 4
|
||||
VC = cfg['tile_k'].size[-1]
|
||||
else:
|
||||
from ..mali.conv2d import _pick_tile_size
|
||||
tile_size = _pick_tile_size(tinfos[0], tinfos[1])
|
||||
VC = cfg['tile_bna'].val
|
||||
|
||||
weight = sym.contrib.conv2d_winograd_weight_transform(copy_inputs[1],
|
||||
tile_size=tile_size)
|
||||
CO, CI, KH, KW = get_const_tuple(tinfos[1].shape)
|
||||
VC = cfg['tile_k'].size[-1]
|
||||
weight = sym.reshape(weight,
|
||||
shape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI))
|
||||
weight = sym.transpose(weight, axes=[0, 1, 2, 4, 3])
|
||||
|
|
|
@@ -14,16 +14,21 @@ autotvm.task.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', 'direct',

# register customized schedule for arm cpu.
@autotvm.task.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'arm_cpu', 'direct')
def schedule_depthwise_conv2d_nchw_(cfg, outs):
def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
"""Schedule depthwise conv2d

Parameters
----------
cfg: ConfigEntity
The configuration of this tempalte
The configuration of this template
outs: Array of Tensor
The computation graph description of depthwise convolution2d
in the format of an array of tensors.

Returns
-------
s: Schedule
The computation schedule for depthwise_conv2d nchw.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])

@@ -38,6 +43,11 @@ def schedule_depthwise_conv2d_nchw_(cfg, outs):
cfg.define_split('tile_h', h, num_outputs=2)
cfg.define_split('tile_w', w, num_outputs=2)

if cfg.is_fallback:
cfg.fallback_split('tile_c', [-1, 8])
cfg.fallback_split('tile_h', [-1, 2])
cfg.fallback_split('tile_w', [-1, 8])

# park data to vector form  [n, c, h, w] -> [n, C, h, w, VC]
A0 = s.cache_read(data_pad, "global", C)
_, c, h, w = s[A0].op.axis
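The hunk above shows the new fallback path for schedule templates. A minimal, self-contained sketch of the same define_split / fallback_split pairing, using a toy template (the template name, axis names, and split factors below are illustrative, not taken from the depthwise schedule):

    import tvm
    from tvm import autotvm

    @autotvm.template
    def vectorized_copy(n):
        """Toy template: copy a 1-D buffer, splitting its axis for vectorization."""
        A = tvm.placeholder((n,), name='A')
        B = tvm.compute((n,), lambda i: A[i], name='B')
        s = tvm.create_schedule(B.op)

        cfg = autotvm.get_config()
        x = B.op.axis[0]
        cfg.define_split('tile_x', x, num_outputs=2)
        if cfg.is_fallback:
            # no tuned record: fill in a safe default; -1 lets autotvm infer
            # the outer factor from the axis length
            cfg.fallback_split('tile_x', [-1, 8])

        xo, xi = cfg['tile_x'].apply(s, B, x)
        s[B].vectorize(xi)
        return s, [A, B]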
@@ -29,7 +29,7 @@ def schedule_injective(outs):
elif len(s[x].op.axis) >= 3:
fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
s[x].parallel(fused)
else:
elif len(s[x].op.axis) >= 1:
s[x].parallel(s[x].op.axis[0])
return s
@@ -0,0 +1,12 @@
"""Common utility for topi test"""

def get_all_backend():
"""return all supported target

Returns
-------
targets: list
A list of all supported targets
"""
return ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx',
'llvm -device=arm_cpu']
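The topi tests below consume this helper with the usual skip-if-unavailable pattern; a minimal sketch (the body of check_device stands in for whatever each test actually builds and verifies):

    import tvm
    from common import get_all_backend

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        # ... build the workload with tvm.target.create(device) and verify it ...

    for device in get_all_backend():
        check_device(device)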
@@ -1,11 +1,8 @@
import os
import numpy as np
import tvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from tvm.contrib import util
from tvm.contrib.pickle_memoize import memoize

def generate_quantized_np(shape, bits, out_dtype):

@@ -16,23 +13,23 @@ def generate_quantized_np(shape, bits, out_dtype):
def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, stride, padding,
activation_bits, weight_bits, dorefa):
in_height = in_width = in_size
input_type='uint32'
out_dtype='int32'
input_type = 'uint32'
out_dtype = 'int32'

with tvm.target.create('llvm'):
A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_type, name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_type, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits,
out_dtype=out_dtype, layout="NCHW", dorefa=dorefa)
out_dtype=out_dtype, layout="NCHW", dorefa=dorefa)
s = topi.generic.schedule_bitserial_conv2d_nchw([B])

a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
dtype = A.dtype

@memoize("topi.tests.test_topi_bitseral_conv2d_nchw")
def get_ref_data():
a_np = generate_quantized_np(get_const_tuple(A.shape), activation_bits, input_type)
w_np = generate_quantized_np(get_const_tuple(W.shape), weight_bits, input_type)
a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_type)
w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_type)
if dorefa:
w_ = np.copy(w_np).astype(out_dtype)
for x in np.nditer(w_, op_flags=['readwrite']):

@@ -61,16 +58,16 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype,
layout="NHWC", dorefa=dorefa)
layout="NHWC", dorefa=dorefa)
s = topi.generic.schedule_bitserial_conv2d_nhwc([B])

a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
dtype = A.dtype

@memoize("topi.tests.test_topi_bitseral_conv2d_nhwc")
def get_ref_data():
a_np = generate_quantized_np(get_const_tuple(A.shape), activation_bits, input_type)
w_np = generate_quantized_np(get_const_tuple(W.shape), weight_bits, input_type)
a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_type)
w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_type)
if dorefa:
w_ = np.copy(w_np).astype(out_dtype)
for x in np.nditer(w_, op_flags=['readwrite']):

@@ -109,4 +106,4 @@ def test_bitserial_conv2d():
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 2, False)

if __name__ == "__main__":
test_bitserial_conv2d()
test_bitserial_conv2d()
@@ -4,10 +4,6 @@ import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple
from tvm.contrib import util

target = 'llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon'

def generate_quantized_np(shape, bits, out_dtype):
np.random.seed(0)

@@ -17,20 +13,19 @@ def generate_quantized_np(shape, bits, out_dtype):

# Verify that certain special instructions from the tensorize pass exist
def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding,
activation_bits, weight_bits, dorefa):
activation_bits, weight_bits, dorefa):
in_height = in_width = in_size
input_type='uint32'
out_dtype='int32'
input_type = 'uint32'
out_dtype = 'int32'

with tvm.target.arm_cpu('rasp3b'):
A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype,
layout="NHWC", dorefa=dorefa)
layout="NHWC", dorefa=dorefa)
s = topi.generic.schedule_bitserial_conv2d_nhwc([B])

func = tvm.build(s, [A, W, B], target)
func = tvm.build(s, [A, W, B], tvm.target.arm_cpu('rasp3b'))

assembly = func.get_source('asm')
matches = re.findall("vpadal", assembly)

@@ -47,7 +42,6 @@ def test_bitserial_conv2d():
stride = 1
pad = 1

verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, False)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, False)
@@ -28,7 +28,7 @@ def verify_binary_dense(batch, in_dim, out_dim):
a_np = (np.random.randint(2, size=(batch, in_dim)) * 2 - 1).astype(dtype)
b_np = (np.random.randint(2, size=(out_dim, in_dim)) * 2 - 1).astype(dtype)
c_np = np.dot(a_np, b_np.T)
return (a_np, b_np, c_np)
return a_np, b_np, c_np

a_np, b_np, c_np = get_ref_data()
@@ -1,5 +1,5 @@
"""Test code for broadcasting operators."""
import os
from common import get_all_backend
import numpy as np
import tvm
import topi

@@ -8,6 +8,7 @@ def verify_broadcast_to_ele(in_shape, out_shape, fbcast):
# Build the logic and compile the function
A = tvm.placeholder(shape=in_shape, name="A")
B = fbcast(A, out_shape)

def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:

@@ -21,16 +22,11 @@ def verify_broadcast_to_ele(in_shape, out_shape, fbcast):
out_npy = np.broadcast_to(data_npy, out_shape)
data_nd = tvm.nd.array(data_npy, ctx)
out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), ctx)
for _ in range(1):
foo(data_nd, out_nd)
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)

check_device("vulkan")
check_device("opencl")
check_device("cuda")
check_device("metal")
check_device("rocm")
check_device("nvptx")
for target in get_all_backend():
check_device(target)
check_device("sdaccel")

@@ -45,9 +41,10 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
B = (tvm.var("B", dtype=dtype) if rhs_shape is None
else tvm.placeholder(shape=rhs_shape, name="B", dtype=dtype))
C = ftopi(A, B)
if (isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr)):
if isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr):
assert(isinstance(C, tvm.expr.Expr))
return

def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:

@@ -82,12 +79,8 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
foo(lhs_nd, rhs_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)

check_device("opencl")
check_device("vulkan")
check_device("cuda")
check_device("metal")
check_device("rocm")
check_device("nvptx")
for target in get_all_backend():
check_device(target)
check_device("sdaccel")

def test_broadcast_to():
@@ -5,6 +5,7 @@ import topi
from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize

from common import get_all_backend

def verify_clip(N, a_min, a_max, dtype):
A = tvm.placeholder((N, N), dtype=dtype, name='A')

@@ -34,7 +35,7 @@ def verify_clip(N, a_min, a_max, dtype):
f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

for device in ['llvm', 'opencl', 'sdaccel']:
for device in get_all_backend():
check_device(device)

def test_clip():
@@ -1,47 +0,0 @@
"""Example code to do conv2d."""
import os
import numpy as np
import tvm
from tvm import autotvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple


def verify_conv2d(batch, in_size, in_channel, num_filter, kernel, stride, padding):
in_height = in_width = in_size

with tvm.target.arm_cpu():
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
B = topi.nn.conv2d(A, W, (stride, stride), (padding, padding), 'NCHW', 'float32')
s = topi.generic.schedule_conv2d_nchw([B])

a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
dtype = A.dtype

@memoize("topi.tests.test_topi_conv2d.verify_conv2d")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
return a_np, w_np, b_np

a_np, w_np, b_np = get_ref_data()

ctx = tvm.cpu(0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, W, B], "llvm")
func(a, w, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

def test_conv2d():
with autotvm.tophub.context(tvm.target.arm_cpu('rasp3b'), allow_fallback=True):
verify_conv2d(1, 56, 64, 64, 3, 1, 1)

if __name__ == "__main__":
test_conv2d()
@@ -43,14 +43,12 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
with tvm.build_config(auto_unroll_max_step=128,
unroll_explicit=(device != "cuda")):
func1 = tvm.build(s1, [A, W, B], device)
func2 = tvm.build(s2, [A, W, C], device)
func1(a, w, b)
func2(a, w, c)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
func1 = tvm.build(s1, [A, W, B], device)
func2 = tvm.build(s2, [A, W, C], device)
func1(a, w, b)
func2(a, w, c)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)

for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
check_device(device)
@@ -1,31 +1,41 @@
"""Example code to do convolution."""
import os

import numpy as np
import tvm
from tvm import autotvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple

def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
from common import get_all_backend

def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
print("Workload: (%d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))

in_height = in_width = in_size

A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
bias = tvm.placeholder((num_filter, 1, 1), name='bias')

a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
bias_shape = get_const_tuple(bias.shape)
dtype = A.dtype

@memoize("topi.tests.test_topi_conv2d_nchw.verify_conv2d_nchw")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = np.random.uniform(size=bias_shape).astype(dtype)
dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
b_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
c_np = np.maximum(b_np, 0)
c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
if add_bias:
b_np = np.random.uniform(size=bias_shape).astype(dtype)
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)
return a_np, w_np, b_np, c_np

a_np, w_np, b_np, c_np = get_ref_data()
@@ -38,66 +48,103 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
print("Running on target: %s" % device)
with tvm.target.create(device):
dW = topi.nn.dilate(W, (1, 1, dilation, dilation))
B = topi.nn.conv2d(A, dW, stride, padding, layout='NCHW')
C = topi.nn.relu(B)
s1 = topi.generic.schedule_conv2d_nchw([B])
s2 = topi.generic.schedule_conv2d_nchw([C])
C = topi.nn.conv2d(A, dW, stride, padding, layout='NCHW', out_dtype=dtype)
if add_bias:
C = topi.add(C, bias)
if add_relu:
C = topi.nn.relu(C)
s = topi.generic.schedule_conv2d_nchw([C])

a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
no_unroll_explicit = device in ["cuda", "nvptx", "rocm"]
with tvm.build_config(auto_unroll_max_step=1400,
unroll_explicit=not no_unroll_explicit):
func1 = tvm.build(s1, [A, W, B], device, name="conv2d_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func2 = tvm.build(s2, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func1(a, w, b)
func2(a, w, c)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
if add_bias:
func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, b, c)
else:
func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, c)
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)

for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
for device in get_all_backend():
check_device(device)
def test_conv2d_nchw():
autotvm.DispatchContext.current.silent = True

# ResNet18 workloads
verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3)
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1)
verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0)
verify_conv2d_nchw(1, 64, 56, 128, 3, 2, 1)
verify_conv2d_nchw(1, 64, 56, 128, 1, 2, 0)
verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1)
verify_conv2d_nchw(1, 128, 28, 256, 3, 2, 1)
verify_conv2d_nchw(1, 128, 28, 256, 1, 2, 0)
verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1)
verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1)
verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0)
verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1)
# ResNet50 workloads
verify_conv2d_nchw(1, 64, 56, 256, 1, 1, 0)
verify_conv2d_nchw(1, 256, 56, 64, 1, 1, 0)
verify_conv2d_nchw(1, 256, 56, 128, 1, 2, 0)
verify_conv2d_nchw(1, 128, 28, 512, 1, 1, 0)
verify_conv2d_nchw(1, 256, 56, 512, 1, 2, 0)
verify_conv2d_nchw(1, 512, 28, 128, 1, 1, 0)
verify_conv2d_nchw(1, 512, 28, 256, 1, 2, 0)
verify_conv2d_nchw(1, 256, 14, 1024, 1, 1, 0)
verify_conv2d_nchw(1, 512, 28, 1024, 1, 2, 0)
verify_conv2d_nchw(1, 1024, 14, 256, 1, 1, 0)
verify_conv2d_nchw(1, 1024, 14, 512, 1, 2, 0)
verify_conv2d_nchw(1, 512, 7, 2048, 1, 2, 0)
verify_conv2d_nchw(1, 1024, 14, 2048, 1, 2, 0)
verify_conv2d_nchw(1, 2048, 7, 512, 1, 1, 0)
# Vgg16 workloads
verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1)
# Super resolution workloads
verify_conv2d_nchw(1, 1, 224, 64, 5, 1, 2)
verify_conv2d_nchw(1, 64, 224, 64, 3, 1, 1)
verify_conv2d_nchw(1, 64, 224, 32, 3, 1, 1)
verify_conv2d_nchw(1, 32, 224, 9, 3, 1, 1)
verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3)
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1)
verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0)
verify_conv2d_nchw(1, 64, 56, 128, 3, 2, 1)
verify_conv2d_nchw(1, 64, 56, 128, 1, 2, 0)
verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1)
verify_conv2d_nchw(1, 128, 28, 256, 3, 2, 1)
verify_conv2d_nchw(1, 128, 28, 256, 1, 2, 0)
verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1)
verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1)
verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0)
verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1)

# bias, relu
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_relu=True)
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_bias=True)
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True)

# dilation = 2
verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1, dilation=2)
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, dilation=2)

# weird workloads
verify_conv2d_nchw(1, 1, 1, 1, 1, 1, 1, dilation=1)
verify_conv2d_nchw(1, 1, 1, 1, 1, 1, 1, dilation=2)

# inception v3 workloads
verify_conv2d_nchw(1, 3, 299, 32, 3, 2, 0)
verify_conv2d_nchw(1, 32, 149, 32, 3, 1, 0)
verify_conv2d_nchw(1, 32, 147, 64, 3, 1, 1)
verify_conv2d_nchw(1, 64, 73, 80, 1, 1, 0)
verify_conv2d_nchw(1, 80, 73, 192, 3, 1, 0)
verify_conv2d_nchw(1, 192, 35, 64, 1, 1, 0)
verify_conv2d_nchw(1, 192, 35, 48, 1, 1, 0)
verify_conv2d_nchw(1, 48, 35, 64, 5, 1, 2)
verify_conv2d_nchw(1, 64, 35, 96, 3, 1, 1)
verify_conv2d_nchw(1, 96, 35, 96, 3, 1, 1)
verify_conv2d_nchw(1, 192, 35, 32, 1, 1, 0)
verify_conv2d_nchw(1, 256, 35, 64, 1, 1, 0)
verify_conv2d_nchw(1, 256, 35, 48, 1, 1, 0)
verify_conv2d_nchw(1, 288, 35, 64, 1, 1, 0)
verify_conv2d_nchw(1, 288, 35, 48, 1, 1, 0)
verify_conv2d_nchw(1, 288, 35, 384, 3, 2, 0)
# verify_conv2d_nchw(1, 96, 35, 96, 3, 2, 0)
# verify_conv2d_nchw(1, 768, 17, 192, 1, 1, 0)
# verify_conv2d_nchw(1, 768, 17, 128, 1, 1, 0)
# verify_conv2d_nchw(1, 128, 17, 128, 1, 1, 0)
# verify_conv2d_nchw(1, 128, 17, 192, 7, 1, 3)
# verify_conv2d_nchw(1, 128, 17, 128, 7, 1, 3)
# verify_conv2d_nchw(1, 128, 17, 192, 1, 1, 0)
# verify_conv2d_nchw(1, 768, 17, 160, 1, 1, 0)
# verify_conv2d_nchw(1, 160, 17, 160, 1, 1, 0)
# verify_conv2d_nchw(1, 160, 17, 192, 7, 1, 3)
# verify_conv2d_nchw(1, 160, 17, 160, 7, 1, 3)
# verify_conv2d_nchw(1, 160, 17, 192, 1, 1, 0)
# verify_conv2d_nchw(1, 192, 17, 192, 1, 1, 0)
# verify_conv2d_nchw(1, 192, 17, 192, 7, 1, 3)
# verify_conv2d_nchw(1, 192, 17, 320, 3, 2, 0)
# verify_conv2d_nchw(1, 192, 17, 192, 3, 2, 0)
verify_conv2d_nchw(1, 1280, 8, 320, 1, 1, 0)
verify_conv2d_nchw(1, 1280, 8, 384, 1, 1, 0)
verify_conv2d_nchw(1, 384, 8, 384, 1, 1, 0)
verify_conv2d_nchw(1, 384, 8, 384, 3, 1, 1)
verify_conv2d_nchw(1, 1280, 8, 448, 1, 1, 0)
verify_conv2d_nchw(1, 448, 8, 384, 3, 1, 1)
verify_conv2d_nchw(1, 1280, 8, 192, 1, 1, 0)
verify_conv2d_nchw(1, 2048, 8, 320, 1, 1, 0)
verify_conv2d_nchw(1, 2048, 8, 384, 1, 1, 0)
verify_conv2d_nchw(1, 2048, 8, 448, 1, 1, 0)
verify_conv2d_nchw(1, 2048, 8, 192, 1, 1, 0)


if __name__ == "__main__":
test_conv2d_nchw()
@@ -6,14 +6,13 @@ import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple

from common import get_all_backend

def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
in_height = in_width = in_size

A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((in_channel, num_filter, kernel, kernel), name='W')
B = topi.nn.conv2d_transpose_nchw(A, W, [stride, stride], padding, A.dtype)
C = topi.nn.relu(B)

a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)

@@ -36,22 +35,23 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
return
print("Running on target: %s" % device)
with tvm.target.create(device):
B = topi.nn.conv2d_transpose_nchw(A, W, [stride, stride], [padding, padding], A.dtype)
C = topi.nn.relu(B)
s1 = topi.generic.schedule_conv2d_transpose_nchw([B])
s2 = topi.generic.schedule_conv2d_transpose_nchw([C])
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
with tvm.build_config(auto_unroll_max_step=128,
unroll_explicit=(device != "cuda")):
func1 = tvm.build(s1, [A, W, B], device)
func2 = tvm.build(s2, [A, W, C], device)
func1(a, w, b)
func2(a, w, c)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)

for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
func1 = tvm.build(s1, [A, W, B], device)
func2 = tvm.build(s2, [A, W, C], device)
func1(a, w, b)
func2(a, w, c)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)

for device in get_all_backend():
check_device(device)
@@ -6,13 +6,12 @@ import topi.testing
from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize

from common import get_all_backend

def verify_dense(batch, in_dim, out_dim, use_bias=True):
A = tvm.placeholder((batch, in_dim), name='A')
B = tvm.placeholder((out_dim, in_dim), name='B')
C = tvm.placeholder((out_dim,), name='C')
D = topi.nn.dense(A, B, C if use_bias else None)
D = topi.nn.relu(D)
dtype = A.dtype

# use memoize to pickle the test data for next time use

@@ -36,6 +35,8 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
return
print("Running on target: %s" % device)
with tvm.target.create(device):
D = topi.nn.dense(A, B, C if use_bias else None)
D = topi.nn.relu(D)
s = topi.generic.schedule_dense(D)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)

@@ -45,13 +46,15 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
f(a, b, c, d)
np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)

for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
for device in get_all_backend():
check_device(device)

def test_dense():
verify_dense(1, 1024, 1000, use_bias=True)
verify_dense(1, 1024, 1000, use_bias=False)

verify_dense(2, 1024, 1000, use_bias=True)


if __name__ == "__main__":
test_dense()
@@ -2,11 +2,10 @@ import tvm
import topi
import topi.testing
import numpy as np
from scipy import signal
from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nhwc

from common import get_all_backend

def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1):
in_width = in_height

@@ -18,10 +17,6 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
DilatedFilter = topi.nn.dilate(Filter, (1, 1, dilation, dilation), name='DilatedFilter')
Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
# declare
DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, DilatedFilter, stride=stride, padding=padding)
ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)

def check_device(device):
ctx = tvm.context(device, 0)

@@ -30,6 +25,10 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
return
print("Running on target: %s" % device)
with tvm.target.create(device):
# declare
DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, DilatedFilter, stride=stride, padding=padding)
ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)
# schedule
s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
s2 = topi.generic.schedule_depthwise_conv2d_nchw(ScaleShift)

@@ -88,12 +87,8 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)

check_device("opencl")
check_device("cuda")
check_device("metal")
check_device("rocm")
check_device("vulkan")
check_device("nvptx")
for device in get_all_backend():
check_device(device)


def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding, dilation=1):

@@ -107,11 +102,6 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
DilatedFilter = topi.nn.dilate(Filter, (1, 1, dilation, dilation), name='DilatedFilter')
Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
# declare
DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, DilatedFilter, stride=[stride_h, stride_w], padding=padding)
ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)
# schedule

def check_device(device):
ctx = tvm.context(device, 0)

@@ -121,6 +111,11 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
print("Running on target: %s" % device)

with tvm.target.create(device):
# declare
DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, DilatedFilter, stride=[stride_h, stride_w], padding=padding)
ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)
# schedule
s1 = topi.generic.schedule_depthwise_conv2d_nhwc(DepthwiseConv2d)
s2 = topi.generic.schedule_depthwise_conv2d_nhwc(ScaleShift)
s3 = topi.generic.schedule_depthwise_conv2d_nhwc(Relu)

@@ -180,12 +175,9 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)

check_device("opencl")
check_device("cuda")
check_device("metal")
check_device("rocm")
check_device("vulkan")
check_device("nvptx")
for device in get_all_backend():
check_device(device)


def test_depthwise_conv2d():
print("testing nchw")
@@ -312,7 +312,9 @@ def tune_and_evaluate():

# upload module to device
print("Upload...")
remote = autotvm.measure.request_remote(device_key, timeout=10000)
remote = autotvm.measure.request_remote(device_key,
tracker_addr=('localhost', 9190),
timeout=10000)
remote.upload(tmp.relpath(filename))
rlib = remote.load_module(filename)
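For context, a minimal sketch of the upload-and-load sequence with the tracker address now passed explicitly; it assumes an RPC tracker listening on localhost:9190 and that device_key, tmp, filename, and target are defined earlier in the tutorial:

    # request a remote session from the RPC tracker, upload the compiled
    # library, and load it back as a module on the target device
    remote = autotvm.measure.request_remote(device_key,
                                            tracker_addr=('localhost', 9190),
                                            timeout=10000)
    remote.upload(tmp.relpath(filename))
    rlib = remote.load_module(filename)
    ctx = remote.context(str(target), 0)  # device context on the remote board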
@@ -333,7 +335,6 @@ def tune_and_evaluate():

# We do not run the tuning in our webpage server since it takes too long.
# Uncomment the following line to run by yourself.

# tune_and_evaluate()

######################################################################