[AUTOTVM] Decouple build and run in measurement (#1661)
This commit is contained in:
Parent
38203a860b
Commit
12839e6d2b
@@ -16,6 +16,11 @@ tvm.autotvm.measure

.. autofunction:: tvm.autotvm.measure.create_measure_batch

.. autoclass:: tvm.autotvm.measure.measure_methods.LocalBuilder

.. autoclass:: tvm.autotvm.measure.measure_methods.RPCRunner

.. autoclass:: tvm.autotvm.measure.measure_methods.LocalRunner

tvm.autotvm.tuner
~~~~~~~~~~~~~~~~~

@@ -22,7 +22,8 @@ from . import env
from . import tophub

# some shortcuts
from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo
from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo, \
    LocalBuilder, LocalRunner, RPCRunner
from .tuner import callback
from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
    register_topi_compute, register_topi_schedule, \

@@ -1,7 +1,7 @@
"""Distributed executor infrastructure to scale up the tuning"""

from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option
from .measure_methods import request_remote, check_remote, create_measure_batch, rpc

from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option, \
    create_measure_batch
from .measure_methods import LocalBuilder, LocalRunner, RPCRunner, request_remote
from .executor import Executor
from .local_executor import LocalExecutor
from .executor import Future, Executor

@@ -37,7 +37,8 @@ def _execute_func(func, queue, args, kwargs):
        res = exc
    queue.put(res)

def timeout_monitor(queue, timeout, func, args, kwargs):

def call_with_timeout(queue, timeout, func, args, kwargs):
    """A wrapper to support timeout of a function call"""

    # start a new process for timeout (cannot use thread because we have c function)

@@ -45,17 +46,12 @@ def timeout_monitor(queue, timeout, func, args, kwargs):
    p.start()
    p.join(timeout=timeout)

    alive = p.is_alive()
    queue.put(executor.TimeoutError())

    kill_child_processes(p.pid)
    p.terminate()
    p.join()

    if alive:
        queue.put(executor.TimeoutError())
    else:
        if queue.empty():
            queue.put(executor.ExecutionError("Fatal error in local executor"))


class LocalFuture(executor.Future):
    """Local wrapper for the future

@@ -134,7 +130,7 @@ class LocalExecutor(executor.Executor):
            return LocalFutureNoFork(func(*args, **kwargs))

        queue = Queue(2)
        process = Process(target=timeout_monitor,
        process = Process(target=call_with_timeout,
                          args=(queue, self.timeout, func, args, kwargs))
        process.start()
        return LocalFuture(process, queue)

@ -1,5 +1,6 @@
|
|||
# pylint: disable=pointless-string-statement,consider-using-enumerate,invalid-name
|
||||
"""User facing API for specifying how to measure the generated code"""
|
||||
import multiprocessing
|
||||
from collections import namedtuple
|
||||
|
||||
class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])):
|
||||
|
@ -16,6 +17,7 @@ class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])):
|
|||
Specific configuration.
|
||||
"""
|
||||
|
||||
|
||||
class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost", "timestamp"])):
|
||||
"""
|
||||
Stores all the results of a measurement
|
||||
|
@ -23,8 +25,8 @@ class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost"
|
|||
Parameters
|
||||
----------
|
||||
costs: Array of float or Array of Exception
|
||||
If no error occurs for this measurement, it is an array of measured running times.
|
||||
If some error occurs during the measurement, it is an array of the exception objections.
|
||||
If no error occurs during measurement, it is an array of measured running times.
|
||||
If an error occurs during measurement, it is an array of the exception objects.
|
||||
error_no: int
|
||||
Denote error type, defined by MeasureErrorNo
|
||||
all_cost: float
|
||||
|
@ -37,92 +39,185 @@ class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost"
|
|||
class MeasureErrorNo(object):
|
||||
"""Error type for MeasureResult"""
|
||||
NO_ERROR = 0 # no error
|
||||
INSTANTIATION_ERROR = 1 # error when calling template function
|
||||
INSTANTIATION_ERROR = 1 # actively detected error in instantiating a template with a config
|
||||
COMPILE_HOST = 2 # error when compiling code on host (e.g. tvm.build)
|
||||
COMPILE_DEVICE = 3 # error when compiling code on device (e.g. opencl JIT on device)
|
||||
COMPILE_DEVICE = 3 # error when compiling code on device (e.g. OpenCL JIT on the device)
|
||||
RUNTIME_DEVICE = 4 # error when run program on device
|
||||
WRONG_ANSWER = 5 # answer is wrong when compared to a golden output
|
||||
FLEET_ERROR = 6 # error of measure infrastructure
|
||||
BUILD_TIMEOUT = 6 # timeout during compilation
|
||||
RUN_TIMEOUT = 7 # timeout during run
|
||||
UNKNOWN_ERROR = 8 # unknown error
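These codes let callers branch on the failure mode without parsing error messages. As a minimal sketch (not part of this patch), a caller could collect only the clean timings like this, using the MeasureErrorNo class exported by this module:

    from tvm.autotvm import MeasureErrorNo

    def successful_costs(results):
        """Keep only the timing tuples of measurements that ran cleanly."""
        good = []
        for res in results:
            if res.error_no == MeasureErrorNo.NO_ERROR:
                good.append(res.costs)          # tuple of measured run times
            # BUILD_TIMEOUT / RUN_TIMEOUT and other errors are simply skipped
        return good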
|
||||
|
||||
|
||||
def measure_option(measure_func,
|
||||
number=1,
|
||||
repeat=1,
|
||||
timeout=60,
|
||||
n_parallel=1,
|
||||
do_fork=True,
|
||||
build_func='default',
|
||||
check_correctness=False,
|
||||
replay_db=None):
|
||||
"""Configure how to do measurement
|
||||
class Builder(object):
|
||||
"""Builder that builds programs in tuning
|
||||
|
||||
Parameters
|
||||
----------
|
||||
measure_func: str or callable
|
||||
'local': use the local device for measurement. The tuner will start a tracker
|
||||
and a RPC server silently for the user.
|
||||
|
||||
callable: It is a callable function for measurement.
|
||||
See the return value of measure/measure_methods.py::rpc for example.
|
||||
number : int, optional
|
||||
Number of times to do the measurement for average
|
||||
repeat : int, optional
|
||||
Number of times to repeat the measurement.
|
||||
In total, the generated code will be run (1 + number x repeat) times,
|
||||
where the first one is warm up. The returned result contains `repeat` costs,
|
||||
each of which is the average of `number` test run.
|
||||
timeout: int, optional
|
||||
Timeout for a whole batch. TimeoutError will be returned as the result if a
|
||||
task timeouts.
|
||||
timeout: float, optional
|
||||
The timeout of a build task
|
||||
n_parallel: int, optional
|
||||
The number of measurement task that can run in parallel.
|
||||
Set this according to the number of cpu cores (for compilation) and
|
||||
the number of devices you have (for measuring generate code).
|
||||
do_fork: bool, optional
|
||||
Whether use multiprocessing (based on fork) for running measure jobs in parallel.
|
||||
Set this to False if you want to debug (see trackback) or using fork is not suitable.
|
||||
NOTE: If this is False, parallel and timeout do not work.
|
||||
build_func: str or callable, optional
|
||||
'default': call default builder. This works for normal target (llvm, cuda)
|
||||
The number of tasks submitted in parallel
|
||||
By default it will use all cpu cores
|
||||
"""
|
||||
def __init__(self, timeout=10, n_parallel=None):
|
||||
self.timeout = timeout
|
||||
self.n_parallel = n_parallel or multiprocessing.cpu_count()
|
||||
self.build_kwargs = {}
|
||||
self.task = None
|
||||
|
||||
'ndk': use Android NDK to create shared library. Use this for android target.
|
||||
def set_task(self, task, build_kwargs=None):
|
||||
"""
|
||||
Initialize for a new tuning task
|
||||
|
||||
callable: customized build function for other backends (e.g. VTA).
|
||||
See measure/measure_methods.py::default_build_func for example.
|
||||
check_correctness: bool, optional
|
||||
Whether check correctness after measurement. This will use llvm cpu target to generate
|
||||
reference output.
|
||||
replay_db : Database, optional
|
||||
The database that we retrieve saved MeasureResult from.
|
||||
Parameters
|
||||
----------
|
||||
task: Task
|
||||
The tuning task
|
||||
build_kwargs: dict, optional
|
||||
The additional kwargs for build function
|
||||
"""
|
||||
self.task = task
|
||||
self.build_kwargs = build_kwargs
|
||||
|
||||
def build(self, measure_inputs):
|
||||
"""Build programs
|
||||
|
||||
Parameters
|
||||
----------
|
||||
measure_inputs: List of MeasureInput
|
||||
The measure input
|
||||
|
||||
Returns
|
||||
-------
|
||||
build_results: List of BuildResult
|
||||
The build result.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class Runner(object):
|
||||
"""Runner that runs and measures the time cost of a generated program in tuning
|
||||
|
||||
Parameters
|
||||
----------
|
||||
timeout: float, optional
|
||||
The timeout of a measurement task
|
||||
n_parallel: int, optional
|
||||
The number of tasks submitted in parallel
|
||||
By default it will use all cpu cores
|
||||
"""
|
||||
def __init__(self, timeout=5, n_parallel=None):
|
||||
self.timeout = timeout
|
||||
self.n_parallel = n_parallel or multiprocessing.cpu_count()
|
||||
self.task = None
|
||||
|
||||
def set_task(self, task):
|
||||
"""
|
||||
Initialize for a new tuning task
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task: Task
|
||||
The tuning task
|
||||
"""
|
||||
self.task = task
|
||||
|
||||
def get_build_kwargs(self):
|
||||
"""
|
||||
Get device specific build arguments (e.g. maximum shared memory size)
|
||||
|
||||
Returns
|
||||
----------
|
||||
kwargs: dict
|
||||
The additional keyword arguments
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def run(self, measure_inputs, build_results):
|
||||
"""Run amd measure built programs
|
||||
|
||||
Parameters
|
||||
----------
|
||||
measure_inputs: List of MeasureInput
|
||||
The raw measure input
|
||||
build_results: List of BuildResults
|
||||
The build results
|
||||
|
||||
Returns
|
||||
-------
|
||||
measure_results: List of MeasureResult
|
||||
The final results of measurement
|
||||
"""
|
||||
raise NotImplementedError()
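Builder and Runner are the two extension points introduced by this change: a custom measurement backend now subclasses them instead of supplying a monolithic measure_func. A hedged sketch of a trivial runner (illustration only; a real runner must actually execute the built library):

    import time
    from tvm.autotvm.measure.measure import Runner, MeasureResult, MeasureErrorNo

    class DummyRunner(Runner):
        """Pretends every program runs in 1 ms; useful only for smoke tests."""
        def get_build_kwargs(self):
            return {}                      # no device-specific compile limits

        def run(self, measure_inputs, build_results):
            return [MeasureResult((0.001,), MeasureErrorNo.NO_ERROR,
                                  0.001, time.time())
                    for _ in measure_inputs]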
|
||||
|
||||
|
||||
def measure_option(builder, runner):
|
||||
"""
|
||||
Set options for measure. To measure a config, we will build it and run it.
|
||||
So we have to set options for these two steps.
|
||||
They have their own options on timeout, parallel, etc.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
builder: Builder
|
||||
Specify how to build programs
|
||||
runner: Runner
|
||||
Specify how to run programs
|
||||
"""
|
||||
from .measure_methods import LocalBuilder, LocalRunner
|
||||
|
||||
if isinstance(builder, str):
|
||||
if builder == 'local':
|
||||
builder = LocalBuilder()
|
||||
else:
|
||||
raise ValueError("Invalid builder: " + builder)
|
||||
|
||||
if isinstance(runner, str):
|
||||
if runner == 'local':
|
||||
runner = LocalRunner()
|
||||
else:
|
||||
raise ValueError("Invalid runner: " + runner)
|
||||
|
||||
opt = {
|
||||
'builder': builder,
|
||||
'runner': runner,
|
||||
}
|
||||
|
||||
return opt
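For reference, a typical call under the new API pairs a local builder with an RPC runner. A short sketch; the device key and tracker address below are placeholders:

    from tvm import autotvm

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(timeout=10),
        runner=autotvm.RPCRunner('rk3399',               # hypothetical device key
                                 host='0.0.0.0', port=9190,
                                 number=4, repeat=3, timeout=10))
    # or, when tuning on the machine that will run the code:
    # measure_option = autotvm.measure_option('local', 'local')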
|
||||
|
||||
|
||||
def create_measure_batch(task, option):
|
||||
"""Get a standard measure_batch function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task: tvm.autotvm.task.Task
|
||||
The tuning task
|
||||
option: dict
|
||||
The option for measuring generated code.
|
||||
You should use the return value of function :any:`measure_option` for this argument.
|
||||
|
||||
Returns
|
||||
-------
|
||||
options: dict
|
||||
A dict to store all options
|
||||
|
||||
Note
|
||||
----
|
||||
To support customized measure, you can pass callable `measure_func` or
|
||||
`build_func` in. The `measure_func` will call `build_func` to build binary library
|
||||
and handle the logic of measurement.
|
||||
|
||||
Signature:
|
||||
* measure_func (see the return value of measure/measure_methods.py::rpc for example)
|
||||
def measure_func(input_pack, build_func, build_kwargs, number, repeat, ref_input, ref_output):
|
||||
return measure_results
|
||||
|
||||
* build_func (see measure/measure_methods.py::default_build_func for example)
|
||||
def build_func(inp, tmp_dir, **kwargs):
|
||||
return func, args, filename
|
||||
measure_batch: callable
|
||||
a callback function to measure a batch of configs
|
||||
"""
|
||||
return {
|
||||
'measure_func': measure_func,
|
||||
'number': number,
|
||||
'repeat': repeat,
|
||||
'timeout': timeout,
|
||||
'n_parallel': n_parallel,
|
||||
'do_fork': do_fork,
|
||||
'build_func': build_func,
|
||||
'check_correctness': check_correctness,
|
||||
'replay_db': replay_db,
|
||||
}
|
||||
builder = option['builder']
|
||||
runner = option['runner']
|
||||
|
||||
attach_objects = runner.set_task(task)
|
||||
|
||||
# feed device related information from runner to builder
|
||||
# (e.g. max shared memory for validity checking)
|
||||
build_kwargs = runner.get_build_kwargs()
|
||||
builder.set_task(task, build_kwargs)
|
||||
|
||||
def measure_batch(measure_inputs):
|
||||
build_results = builder.build(measure_inputs)
|
||||
results = runner.run(measure_inputs, build_results)
|
||||
return results
|
||||
|
||||
measure_batch.n_parallel = builder.n_parallel
|
||||
measure_batch.attach_objects = attach_objects
|
||||
return measure_batch
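Tuners call create_measure_batch internally, but it can also be driven by hand. A minimal sketch, assuming a tuning `task` and the `measure_option` dict from above already exist:

    measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
    tuner = autotvm.tuner.RandomTuner(task)
    configs = tuner.next_batch(4)                    # ask for 4 candidate configs
    inputs = [autotvm.MeasureInput(task.target, task, cfg) for cfg in configs]
    results = measure_batch(inputs)                  # builds, runs, returns MeasureResults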
|
||||
|
|
|
@ -1,129 +1,339 @@
|
|||
# pylint: disable=consider-using-enumerate,invalid-name,too-many-function-args
|
||||
# pylint: disable=invalid-name,too-many-function-args,too-many-nested-blocks
|
||||
"""
|
||||
Functions that run on executor for measurement.
|
||||
These functions are responsible for building tvm module, uploading it to
|
||||
remote devices, recording the running time costs and checking the correctness of output
|
||||
|
||||
These functions are responsible for building the tvm module, uploading it to
|
||||
remote devices, recording the running time costs, and checking the correctness of the output.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from random import getrandbits
|
||||
import threading
|
||||
from collections import namedtuple
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ... import ir_pass, build, build_config, nd, context, TVMError, register_func, \
|
||||
target as _target, rpc as _rpc
|
||||
from ...contrib import nvcc, util, ndk
|
||||
from ... import ir_pass, build, build_config, nd, TVMError, register_func, \
|
||||
rpc as _rpc, target as _target
|
||||
from ...contrib import nvcc, ndk
|
||||
|
||||
from ..util import get_const_tuple
|
||||
from ..env import AutotvmGlobalScope
|
||||
from ..task.space import InstantiationError
|
||||
|
||||
from .measure import MeasureResult, MeasureErrorNo
|
||||
from .measure import MeasureResult, MeasureErrorNo, Builder, Runner
|
||||
from .local_executor import LocalExecutor
|
||||
|
||||
logger = logging.getLogger('autotvm')
|
||||
|
||||
class HashMismatchError(ValueError):
|
||||
"""Raised when the code hash of a submitted config doesn't match that on the
|
||||
measure side """
|
||||
pass
|
||||
|
||||
|
||||
def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
|
||||
"""request a remote session
|
||||
class BuildResult(namedtuple("BuildResult", ('filename', 'arg_info', 'error', 'time_cost'))):
|
||||
"""
|
||||
Stores all the results of a build.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
device_key: string
|
||||
device key of registered device in tracker
|
||||
tracker_addr: Tuple(string, int), optional
|
||||
The address of rpc tracker in (host, port) format.
|
||||
If is none, will use environment variable "TVM_TRACKER_HOST"
|
||||
and "TVM_TRACKER_PORT"
|
||||
priority: int, optional
|
||||
The priority of this request, larger is more prior
|
||||
timeout: float, optional
|
||||
The timeout of this session (units: seconds)
|
||||
|
||||
Returns
|
||||
------
|
||||
session: RPCSession
|
||||
filename : str
|
||||
The filename of generated library
|
||||
arg_info : Tuple
|
||||
The shape and dtype information of tvm tensor arguments
|
||||
error : Exception
|
||||
The error happens during compilation.
|
||||
time_cost : float
|
||||
The time cost of building
|
||||
"""
|
||||
# connect to the tracker
|
||||
if tracker_addr:
|
||||
host = tracker_addr[0] or os.environ['TVM_TRACKER_HOST']
|
||||
port = tracker_addr[1] or int(os.environ['TVM_TRACKER_PORT'])
|
||||
else:
|
||||
host = os.environ['TVM_TRACKER_HOST']
|
||||
port = int(os.environ['TVM_TRACKER_PORT'])
|
||||
|
||||
tracker = _rpc.connect_tracker(host, port)
|
||||
remote = tracker.request(device_key, priority=priority,
|
||||
session_timeout=timeout)
|
||||
return remote
|
||||
|
||||
def check_remote(target, device_key, tracker_addr=None, priority=2, timeout=10):
|
||||
"""
|
||||
Check the availability of a remote device
|
||||
class LocalBuilder(Builder):
|
||||
"""Run compilation on local machine
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target: Target
|
||||
The wanted compilation target
|
||||
device_key: string
|
||||
device key of registered device in tracker
|
||||
tracker_addr: Tuple(string, int), optional
|
||||
The address of rpc tracker in (host, port) format.
|
||||
If is none, will use environment variable "TVM_TRACKER_HOST"
|
||||
and "TVM_TRACKER_PORT"
|
||||
priority: int, optional
|
||||
The priority of this request, larger is more prior
|
||||
timeout: float, optional
|
||||
The timeout of this check (units: seconds).
|
||||
If time is out, a RuntimeError will be raised.
|
||||
timeout: float
|
||||
The timeout of a compilation
|
||||
n_parallel: int
|
||||
The number of tasks run in parallel. "None" will use all cpu cores
|
||||
build_func: callable or str
|
||||
If is 'default', use default build function
|
||||
If is 'ndk', use function for android ndk
|
||||
If is callable, use it as custom build function
|
||||
"""
|
||||
def _check():
|
||||
remote = request_remote(device_key, tracker_addr, priority)
|
||||
remote.context(str(target))
|
||||
t = threading.Thread(target=_check,)
|
||||
t.start()
|
||||
t.join(timeout)
|
||||
return not t.is_alive()
|
||||
def __init__(self, timeout=10, n_parallel=None, build_func='default'):
|
||||
super(LocalBuilder, self).__init__(timeout, n_parallel)
|
||||
|
||||
def create_measure_batch(task, option):
|
||||
"""Get a standard measure_batch function.
|
||||
if isinstance(build_func, str):
|
||||
if build_func == 'default':
|
||||
build_func = default_build_func
|
||||
elif build_func == 'ndk':
|
||||
build_func = android_ndk_build_func
|
||||
else:
|
||||
raise ValueError("Invalid build_func" + build_func)
|
||||
|
||||
self.build_func = build_func
|
||||
self.tmp_dir = tempfile.mkdtemp()
|
||||
self.executor = LocalExecutor(timeout=timeout)
|
||||
|
||||
def build(self, measure_inputs):
|
||||
results = []
|
||||
|
||||
for i in range(0, len(measure_inputs), self.n_parallel):
|
||||
futures = []
|
||||
for inp in measure_inputs[i:i + self.n_parallel]:
|
||||
ret = self.executor.submit(self.build_func,
|
||||
inp,
|
||||
self.tmp_dir,
|
||||
**self.build_kwargs)
|
||||
futures.append(ret)
|
||||
|
||||
for future in futures:
|
||||
res = future.get()
|
||||
|
||||
if isinstance(res, Exception):
|
||||
# timeout or fleet error, return MeasureResult directly
|
||||
results.append(MeasureResult((res,), MeasureErrorNo.BUILD_TIMEOUT,
|
||||
self.timeout, time.time()))
|
||||
elif res.error is not None:
|
||||
# instantiation error
|
||||
if isinstance(res.error, InstantiationError):
|
||||
results.append(MeasureResult((res.error,),
|
||||
MeasureErrorNo.INSTANTIATION_ERROR,
|
||||
res.time_cost, time.time()))
|
||||
else:
|
||||
if "InstantiationError" in str(res.error):
|
||||
msg = str(res.error)
|
||||
try:
|
||||
msg = msg.split('\n')[-2].split(": ")[1]
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
results.append(MeasureResult((InstantiationError(msg),),
|
||||
MeasureErrorNo.INSTANTIATION_ERROR,
|
||||
res.time_cost, time.time()))
|
||||
else: # tvm error
|
||||
results.append(MeasureResult((res.error,),
|
||||
MeasureErrorNo.COMPILE_HOST,
|
||||
res.time_cost, time.time()))
|
||||
else:
|
||||
# return BuildResult
|
||||
results.append(res)
|
||||
|
||||
return results
|
||||
|
||||
def __del__(self):
|
||||
shutil.rmtree(self.tmp_dir)
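The build_func hook is how cross-compilation is selected; 'ndk' switches the export step to ndk.create_shared for Android targets. A hedged usage sketch (the device key is a placeholder, and the NDK toolchain is assumed to be configured, typically via TVM_NDK_CC):

    builder = autotvm.LocalBuilder(timeout=15, build_func='ndk')
    runner = autotvm.RPCRunner('android', host='0.0.0.0', port=9190)   # hypothetical key
    measure_option = autotvm.measure_option(builder, runner)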
|
||||
|
||||
|
||||
class RPCRunner(Runner):
|
||||
"""Run generated code on remove devices.
|
||||
This function will ask a RPC Tracker to get device for measurement.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task: tvm.autotvm.task.Task
|
||||
The tuning task
|
||||
option: dict
|
||||
The option for measuring generated code.
|
||||
You should use the return value of function :any:`measure_option` for this argument.
|
||||
|
||||
Returns
|
||||
-------
|
||||
measure_batch: callable
|
||||
a callback function to measure a batch of configs
|
||||
timeout: float
|
||||
The timeout of a measurement task
|
||||
n_parallel: int
|
||||
The number of tasks run in parallel. "None" will use all cpu cores
|
||||
key: str
|
||||
The key of the device registered in the tracker
|
||||
host: str
|
||||
The host address of RPC Tracker
|
||||
port: int
|
||||
The port of RPC Tracker
|
||||
number : int, optional
|
||||
Number of times to do measurement for taking average
|
||||
repeat : int, optional
|
||||
Number of times to repeat the measurement.
|
||||
In total, the generated code will be run (1 + number x repeat) times,
|
||||
where the first one is warm up. The returned result contains `repeat` costs,
|
||||
min_repeat_ms : float, optional
|
||||
Minimum duration of a timer measurement in milliseconds.
|
||||
When the run time of a measurement trial falls below this time, the
|
||||
`number` parameter will be automatically increased.
|
||||
Set this to improve the accuracy of perf measurement, e.g., when timers
|
||||
are not precise enough to capture short-running tasks. This parameter is
|
||||
also critical when devices need a certain minimum running time to "warm
|
||||
up," such as GPUs that need time to reach a performance power state.
|
||||
cooldown_interval: float, optional
|
||||
The cool down interval between two measurements.
|
||||
check_correctness: bool, optional
|
||||
Whether to check correctness after measurement. This will use llvm cpu target to
|
||||
call your template and get the reference output.
|
||||
This can work for TOPI templates, but may not work for your custom template.
|
||||
"""
|
||||
from ..database import filter_inputs
|
||||
def __init__(self,
|
||||
key, host, port, priority=1,
|
||||
timeout=10, n_parallel=None,
|
||||
number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1,
|
||||
check_correctness=False):
|
||||
super(RPCRunner, self).__init__(timeout, n_parallel)
|
||||
|
||||
measure_func = option['measure_func']
|
||||
number, repeat = option['number'], option['repeat']
|
||||
timeout, n_parallel, do_fork = option['timeout'], option['n_parallel'], option['do_fork']
|
||||
build_func = option['build_func']
|
||||
check_correctness = option['check_correctness']
|
||||
replay_db = option['replay_db']
|
||||
self.key = key
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.priority = priority
|
||||
self.timeout = timeout
|
||||
|
||||
executor = LocalExecutor(timeout=timeout, do_fork=do_fork)
|
||||
self.number = number
|
||||
self.repeat = repeat
|
||||
self.min_repeat_ms = min_repeat_ms
|
||||
self.cur_number = number
|
||||
|
||||
self.ref_input = None
|
||||
self.ref_output = None
|
||||
self.check_correctness = check_correctness
|
||||
self.cooldown_interval = cooldown_interval
|
||||
|
||||
self.executor = LocalExecutor()
|
||||
|
||||
def set_task(self, task):
|
||||
self.task = task
|
||||
self.cur_number = self.number
|
||||
|
||||
if check_remote(task.target, self.key, self.host, self.port):
|
||||
logger.info("Get devices for measurement successfully!")
|
||||
else:
|
||||
raise RuntimeError("Cannot get remote devices from the tracker. "
|
||||
"Please check the status of tracker by "
|
||||
"'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
|
||||
"and make sure you have free devices on the queue status.")
|
||||
|
||||
if self.check_correctness:
|
||||
# use llvm cpu to generate a reference input/output
|
||||
# this option works for tuning topi, but might not work for your custom op
|
||||
with _target.create("llvm"):
|
||||
s, arg_bufs = task.instantiate(task.config_space.get(0))
|
||||
self.ref_input = [np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype)
|
||||
for x in arg_bufs]
|
||||
func = build(s, arg_bufs, "llvm")
|
||||
tvm_buf = [nd.array(x) for x in self.ref_input]
|
||||
func(*tvm_buf)
|
||||
self.ref_output = [x.asnumpy() for x in tvm_buf]
|
||||
|
||||
def get_build_kwargs(self):
|
||||
kwargs = {}
|
||||
if 'cuda' in self.task.target.keys or 'opencl' in self.task.target.keys:
|
||||
remote = request_remote(self.key, self.host, self.port)
|
||||
ctx = remote.context(str(self.task.target), 0)
|
||||
max_dims = ctx.max_thread_dimensions
|
||||
kwargs['check_gpu'] = {
|
||||
'max_shared_memory_per_block': ctx.max_shared_memory_per_block,
|
||||
'max_threads_per_block': ctx.max_threads_per_block,
|
||||
'max_thread_x': max_dims[0],
|
||||
'max_thread_y': max_dims[1],
|
||||
'max_thread_z': max_dims[2],
|
||||
}
|
||||
|
||||
if 'cuda' in self.task.target.keys:
|
||||
kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split('.'))
|
||||
|
||||
return kwargs
|
||||
|
||||
def run(self, measure_inputs, build_results):
|
||||
results = []
|
||||
remote_args = (self.key, self.host, self.port, self.priority, self.timeout)
|
||||
|
||||
for i in range(0, len(measure_inputs), self.n_parallel):
|
||||
futures = []
|
||||
for measure_inp, build_res in zip(measure_inputs[i:i+self.n_parallel],
|
||||
build_results[i:i+self.n_parallel]):
|
||||
ret = self.executor.submit(run_through_rpc,
|
||||
measure_inp,
|
||||
build_res,
|
||||
self.cur_number,
|
||||
self.repeat,
|
||||
self.cooldown_interval,
|
||||
remote_args,
|
||||
self.ref_input,
|
||||
self.ref_output)
|
||||
futures.append(ret)
|
||||
|
||||
for future in futures:
|
||||
res = future.get()
|
||||
if isinstance(res, Exception): # executor error or timeout
|
||||
results.append(MeasureResult((str(res),), MeasureErrorNo.RUN_TIMEOUT,
|
||||
self.timeout, time.time()))
|
||||
else:
|
||||
results.append(res)
|
||||
|
||||
# If some runs were too fast, re-measure them
|
||||
# to meet the requirement of `min_repeat_ms`
|
||||
remeasure = np.zeros((len(measure_inputs),), dtype=np.bool)
|
||||
pre_number = next_number = self.cur_number
|
||||
min_repeat_duration = self.min_repeat_ms / 1000.0
|
||||
for i, res in enumerate(results):
|
||||
if res.error_no == MeasureErrorNo.NO_ERROR:
|
||||
if np.mean(res.costs) * pre_number <= min_repeat_duration:
|
||||
next_number = max(next_number,
|
||||
int(np.ceil(min_repeat_duration / np.mean(res.costs))))
|
||||
remeasure[i] = True
|
||||
|
||||
if pre_number != next_number:
|
||||
self.cur_number = next_number
|
||||
msg = "increasing number to %d" % self.cur_number
|
||||
logger.info(msg)
|
||||
|
||||
re_measure_inputs = [x for i, x in enumerate(measure_inputs) if remeasure[i]]
|
||||
re_build_results = [x for i, x in enumerate(build_results) if remeasure[i]]
|
||||
re_res = self.run(re_measure_inputs, re_build_results)
|
||||
ct = 0
|
||||
for i, rerun in enumerate(remeasure):
|
||||
if rerun:
|
||||
results[i] = re_res[ct]
|
||||
ct += 1
|
||||
|
||||
return results
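The min_repeat_ms logic above is plain arithmetic: whenever the timed region (mean cost times the current `number`) is shorter than min_repeat_ms, `number` is raised and the affected configs are measured again. A worked example of the update rule:

    import math
    min_repeat_duration = 100 / 1000.0      # min_repeat_ms = 100 ms, in seconds
    mean_cost = 5e-5                        # a trial averaging 0.05 ms per run
    cur_number = 4                          # 4 * 0.05 ms = 0.2 ms < 100 ms, so...
    next_number = max(cur_number,
                      int(math.ceil(min_repeat_duration / mean_cost)))  # -> 2000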
|
||||
|
||||
class LocalRunner(RPCRunner):
|
||||
"""Run generated code on local devices.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
timeout: float
|
||||
The timeout of a measurement task
|
||||
number : int, optional
|
||||
Number of times to do measurement for taking average
|
||||
repeat : int, optional
|
||||
Number of times to repeat the measurement.
|
||||
In total, the generated code will be run (1 + number x repeat) times,
|
||||
where the first one is warm up. The returned result contains `repeat` costs,
|
||||
each of which is the average of `number` test runs.
|
||||
min_repeat_ms : float, optional
|
||||
Minimum duration of a timer measurement in milliseconds.
|
||||
When the run time of a measurement trial falls below this time, the
|
||||
`number` parameter will be automatically increased.
|
||||
Set this to improve the accuracy of perf measurement, e.g., when timers
|
||||
are not precise enough to capture short-running tasks. This parameter is
|
||||
also critical when devices need a certain minimum running time to "warm
|
||||
up," such as GPUs that need time to reach a performance power state.
|
||||
cooldown_interval: float, optional
|
||||
The cool down interval between two measurements.
|
||||
check_correctness: bool, optional
|
||||
Whether to check correctness after measurement. This will use llvm cpu target to
|
||||
call your template and get the reference output.
|
||||
This can work for TOPI templates, but may not work for your custom template.
|
||||
|
||||
Note
|
||||
----
|
||||
This is a "fake" local mode. We start a silent rpc tracker and rpc server
|
||||
for the user. In this way we reuse the timeout/isolation mechanism of the RPC infrastructure.
|
||||
"""
|
||||
def __init__(self,
|
||||
timeout=10,
|
||||
number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1,
|
||||
check_correctness=False):
|
||||
super(LocalRunner, self).__init__('', None, None, 0,
|
||||
timeout=timeout, n_parallel=1,
|
||||
number=number, repeat=repeat,
|
||||
min_repeat_ms=min_repeat_ms,
|
||||
cooldown_interval=cooldown_interval,
|
||||
check_correctness=check_correctness)
|
||||
self.tracker = None
|
||||
self.server = None
|
||||
|
||||
def set_task(self, task):
|
||||
self.task = task
|
||||
|
||||
# convert convenient string to function object
|
||||
attach_objects = None
|
||||
if measure_func == 'local':
|
||||
# start temporary rpc tracker and rpc server for the user
|
||||
from ...rpc.tracker import Tracker
|
||||
from ...rpc.server import Server
|
||||
|
||||
|
@ -133,360 +343,215 @@ def create_measure_batch(task, option):
|
|||
key=device_key,
|
||||
use_popen=True, silent=True,
|
||||
tracker_addr=(tracker.host, tracker.port))
|
||||
self.key = device_key
|
||||
self.host = tracker.host
|
||||
self.port = tracker.port
|
||||
|
||||
measure_func = rpc(device_key, tracker.host, tracker.port)
|
||||
attach_objects = (server, tracker)
|
||||
|
||||
build_kwargs = {}
|
||||
if build_func == 'default':
|
||||
build_func = default_build_func
|
||||
if build_func == 'ndk':
|
||||
build_func = default_build_func
|
||||
build_kwargs['use_ndk'] = True
|
||||
|
||||
# check the availability of remote devices
|
||||
if hasattr(measure_func, 'rpc_info'):
|
||||
rpc_info = measure_func.rpc_info
|
||||
if check_remote(task.target, rpc_info['key'], (rpc_info['host'], rpc_info['port'])):
|
||||
logger.info("Get devices for measurement successfully!")
|
||||
else:
|
||||
raise RuntimeError("Cannot get remote devices from the tracker. "
|
||||
"Please check the status of tracker by "
|
||||
"'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
|
||||
"and make sure you have free devices on the queue status.")
|
||||
|
||||
# add device info of cuda and opencl target
|
||||
if ('cuda' in task.target.keys or 'opencl' in task.target.keys) \
|
||||
and hasattr(measure_func, 'rpc_info'):
|
||||
rpc_info = measure_func.rpc_info
|
||||
add_gpu_target_info(task.target, rpc_info["key"], (rpc_info["host"], rpc_info["port"]),
|
||||
build_kwargs)
|
||||
|
||||
if check_correctness:
|
||||
# use llvm cpu to generate a reference input/output
|
||||
# this option works for tuning topi, but might not work for you custom op
|
||||
with _target.create("llvm"):
|
||||
s, arg_bufs = task.instantiate(task.config_space.get(0))
|
||||
ref_input = [np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype)
|
||||
for x in arg_bufs]
|
||||
func = build(s, arg_bufs, "llvm")
|
||||
tvm_buf = [nd.array(x) for x in ref_input]
|
||||
func(*tvm_buf)
|
||||
ref_output = [x.asnumpy() for x in tvm_buf]
|
||||
else:
|
||||
ref_input = ref_output = None
|
||||
|
||||
def measure_batch(measure_inputs):
|
||||
"""measure the time cost for a batch of configs in real machines"""
|
||||
if replay_db is not None:
|
||||
partial_results, measure_inputs = \
|
||||
filter_inputs(replay_db, measure_inputs, retry=False)
|
||||
|
||||
# launch measure jobs in parallel
|
||||
pack_size = getattr(measure_func, "pack_size", 1) # measure `pack_size` inputs in one job
|
||||
futures = []
|
||||
for i in range(0, len(measure_inputs), pack_size):
|
||||
input_pack = measure_inputs[i:i + pack_size]
|
||||
ret = executor.submit(
|
||||
measure_func,
|
||||
input_pack,
|
||||
build_func,
|
||||
build_kwargs,
|
||||
number,
|
||||
repeat,
|
||||
ref_input,
|
||||
ref_output)
|
||||
futures.append(ret)
|
||||
|
||||
# transform results
|
||||
results = []
|
||||
for future in futures:
|
||||
result = future.get()
|
||||
if isinstance(result, Exception):
|
||||
tstamp = time.time()
|
||||
results.extend([MeasureResult((result,), MeasureErrorNo.FLEET_ERROR,
|
||||
timeout, tstamp)] * pack_size)
|
||||
else:
|
||||
results.extend(result)
|
||||
|
||||
if replay_db is not None:
|
||||
result_idx = 0
|
||||
for i in range(len(partial_results)):
|
||||
if partial_results[i] is None:
|
||||
partial_results[i] = results[result_idx]
|
||||
result_idx += 1
|
||||
return partial_results
|
||||
return results
|
||||
|
||||
measure_batch.n_parallel = n_parallel
|
||||
# attach server and tracker object to avoid them of being garbage-collected
|
||||
measure_batch.attach_objects = attach_objects
|
||||
return measure_batch
|
||||
super(LocalRunner, self).set_task(task)
|
||||
return server, tracker
|
||||
|
||||
|
||||
def rpc(key,
|
||||
host=None,
|
||||
port=None,
|
||||
priority=1,
|
||||
session_timeout=60,
|
||||
pack_size=1):
|
||||
def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_option=None):
|
||||
"""Common part for building a configuration"""
|
||||
target, task, config = measure_input
|
||||
|
||||
with target:
|
||||
s, args = task.instantiate(config)
|
||||
|
||||
# check invalidity of template and code hash consistency
|
||||
if not config.valid():
|
||||
raise InstantiationError(config.errors)
|
||||
|
||||
opts = build_option or {}
|
||||
if check_gpu: # Add verify pass to filter out invalid configs in advance.
|
||||
opts["add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))]
|
||||
if cuda_arch:
|
||||
set_cuda_target_arch(cuda_arch)
|
||||
|
||||
with build_config(**opts):
|
||||
func = build(s, args, target_host=task.target_host)
|
||||
return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)
|
||||
|
||||
|
||||
def default_build_func(measure_input, tmp_dir, **kwargs):
|
||||
"""
|
||||
Create a standard measure_func which uses RPC Tracker for measurement.
|
||||
This measure_func will request a device from the RPC Tracker and
|
||||
upload the built binary library to that device for measurement.
|
||||
Default build func. This works for the cuda, opencl, and llvm backends
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key: str
|
||||
The registered key of the device in tracker. The tuner will request devices for
|
||||
measurement by this key.
|
||||
host: str, optional
|
||||
The hostname of RPC Tracker. If not set, will use environment variable "TVM_TRACKER_HOST"
|
||||
port: int, optional
|
||||
The port of RPC Tracker. If not set, will use environment variable "TVM_TRACKER_PORT"
|
||||
priority: int, optional
|
||||
Priority of this task, used by scheduler in tracker
|
||||
session_timeout: int, optional
|
||||
Timeout of rpc session
|
||||
pack_size: int, optional
|
||||
The number of configs measure in one RPC session.
|
||||
Usually this can be set to 1. If your device has high overhead to establish a
|
||||
rpc connection, set this higher.
|
||||
measure_input: MeasureInput
|
||||
The input of measurement
|
||||
tmp_dir: str
|
||||
The path of temporary directory to export generated library
|
||||
"""
|
||||
def fmeasure(input_pack, build_func, build_kwargs, number, repeat, ref_input, ref_output):
|
||||
"""Do measurement for a list of inputs inside a same RPC session.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_pack: List of MeasureInput
|
||||
The inputs of measurement
|
||||
build_func: callable
|
||||
Function for building the code. see :any:`default_build_func` for example
|
||||
build_kwargs: dict
|
||||
Extra arguments for build_func
|
||||
number : int, optional
|
||||
Number of times to do the measurement for average
|
||||
repeat : int, optional
|
||||
Number of times to repeat the measurement.
|
||||
In total, the generated code will be run (1 + number x repeat) times,
|
||||
where the first one is warm up. The returned result contains `repeat` costs,
|
||||
each of which is the average of `number` test run.
|
||||
ref_input: List of numpy array
|
||||
Reference input for correctness check
|
||||
ref_output: List of numpy array
|
||||
Reference output for correctness check
|
||||
|
||||
Returns
|
||||
-------
|
||||
results: List of MeasureResult
|
||||
The results for input_pack
|
||||
"""
|
||||
remote_args = (key, (host, port), priority, session_timeout)
|
||||
|
||||
res = _measure_common(input_pack, build_func, build_kwargs, number, repeat,
|
||||
ref_input, ref_output,
|
||||
remote_args)
|
||||
return res
|
||||
|
||||
fmeasure.pack_size = pack_size
|
||||
fmeasure.rpc_info = {"key": key, "host": host, "port": port}
|
||||
return fmeasure
|
||||
tic = time.time()
|
||||
try:
|
||||
filename = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64))
|
||||
func, arg_info = _build_func_common(measure_input, **kwargs)
|
||||
func.export_library(filename)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
return BuildResult(None, None, e, time.time() - tic)
|
||||
return BuildResult(filename, arg_info, None, time.time() - tic)
|
||||
|
||||
|
||||
def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
|
||||
ref_input=None, ref_output=None, remote_args=None):
|
||||
"""Measure the time cost for a pack of inputs.
|
||||
|
||||
(Note: A pack is a list of inputs which will be measured inside a same RPC session)
|
||||
def android_ndk_build_func(measure_input, tmp_dir, **kwargs):
|
||||
"""
|
||||
Build function for android device using ndk.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_pack : list of MeasureInput
|
||||
The inputs we need to evaluate
|
||||
build_func : function takes MeasureInput returns tuple of (time_func, ctx, args)
|
||||
The build function used to build each input.
|
||||
build_kwargs: Dict
|
||||
The extra keyword arguments to build_func
|
||||
measure_input: MeasureInput
|
||||
The input of measurement
|
||||
tmp_dir: str
|
||||
The path of temporary directory to export generated library
|
||||
"""
|
||||
tic = time.time()
|
||||
try:
|
||||
filename = os.path.join(tmp_dir, "tmp_func_%0x.so" % getrandbits(64))
|
||||
func, arg_info = _build_func_common(measure_input, **kwargs)
|
||||
func.export_library(filename, ndk.create_shared)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
return BuildResult(None, None, e, time.time() - tic)
|
||||
return BuildResult(filename, arg_info, None, time.time() - tic)
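Any callable with the same (measure_input, tmp_dir, **kwargs) signature that returns a BuildResult can be passed as LocalBuilder(build_func=...). A hedged sketch for a backend with its own export routine; it reuses this module's _build_func_common helper and BuildResult type, and my_create_shared is a hypothetical stand-in that is not part of this patch:

    def my_build_func(measure_input, tmp_dir, **kwargs):
        """Build with the common lowering path, then export with a custom packer."""
        tic = time.time()
        try:
            filename = os.path.join(tmp_dir, "tmp_func_%0x.so" % getrandbits(64))
            func, arg_info = _build_func_common(measure_input, **kwargs)
            func.export_library(filename, my_create_shared)   # hypothetical export fn
        except Exception as e:  # pylint: disable=broad-except
            return BuildResult(None, None, e, time.time() - tic)
        return BuildResult(filename, arg_info, None, time.time() - tic)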
|
||||
|
||||
|
||||
def run_through_rpc(measure_input, build_result,
|
||||
number, repeat, cooldown_interval,
|
||||
remote_args, ref_input=None, ref_output=None):
|
||||
"""Run a generated library through rpc
|
||||
|
||||
Parameters
|
||||
----------
|
||||
measure_input: MeasureInput
|
||||
The raw measure input
|
||||
build_result: BuildResult
|
||||
The result returned from Builder. This contains the path to the generated library.
|
||||
number : int, optional
|
||||
Number of times to do the measurement for average
|
||||
Number of times to do measurement for taking average
|
||||
repeat : int, optional
|
||||
Number of times to repeat the measurement.
|
||||
In total, the generated code will be run (1 + number x repeat) times,
|
||||
where the first one is warm up. The returned result contains `repeat` costs,
|
||||
each of which is the average of `number` test run.
|
||||
ref_input: Array of np.ndarray, optional
|
||||
Reference input for checking correctness
|
||||
ref_output: Array of np.ndarray, optional
|
||||
Reference output for checking correctness
|
||||
remote_args: Tuple, optional
|
||||
The arguments to request_remote. If is not None, will use remote rpc devices.
|
||||
|
||||
Returns
|
||||
-------
|
||||
res_pack : Array of MeasureResult
|
||||
The list of results of measurement.
|
||||
cooldown_interval: float
|
||||
The cool down interval between two measurements
|
||||
remote_args: Tuple
|
||||
The argument for request_remote
|
||||
ref_input: List of np.ndarray
|
||||
The reference input used for checking correctness
|
||||
ref_output: List of np.ndarray
|
||||
The reference output used for checking correctness
|
||||
"""
|
||||
res_pack = []
|
||||
tmp_dir = util.tempdir() if remote_args else None
|
||||
assert len(input_pack) == 1, "Only supports input_pack == 1 for now"
|
||||
if isinstance(build_result, MeasureResult):
|
||||
return build_result
|
||||
|
||||
for inp in input_pack:
|
||||
tic = time.time()
|
||||
tic = time.time()
|
||||
errno = MeasureErrorNo.NO_ERROR
|
||||
try:
|
||||
# upload built module
|
||||
remote = request_remote(*remote_args)
|
||||
remote.upload(build_result.filename)
|
||||
func = remote.load_module(os.path.split(build_result.filename)[1])
|
||||
ctx = remote.context(str(measure_input.target), 0)
|
||||
time_f = func.time_evaluator(
|
||||
func.entry_name, ctx, number=number, repeat=repeat)
|
||||
|
||||
# build function
|
||||
try:
|
||||
func, arg_bufs, filename = build_func(inp, tmp_dir, **build_kwargs)
|
||||
except TVMError as exc:
|
||||
tstamp = time.time()
|
||||
msg = str(exc)
|
||||
if "Stack trace returned" in msg:
|
||||
msg = msg[:msg.index("Stack trace returned")]
|
||||
if "InstantiationError" in msg:
|
||||
try:
|
||||
msg = msg.split('\n')[-2].split(": ")[1]
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
res_pack.append(MeasureResult((InstantiationError(msg),),
|
||||
MeasureErrorNo.INSTANTIATION_ERROR,
|
||||
tstamp - tic, tstamp))
|
||||
else:
|
||||
res_pack.append(MeasureResult((RuntimeError(msg),),
|
||||
MeasureErrorNo.COMPILE_HOST,
|
||||
tstamp - tic, tstamp))
|
||||
continue
|
||||
except InstantiationError as e:
|
||||
tstamp = time.time()
|
||||
res_pack.append(MeasureResult((InstantiationError(str(e)),),
|
||||
MeasureErrorNo.INSTANTIATION_ERROR,
|
||||
tstamp - tic, tstamp))
|
||||
continue
|
||||
# set input
|
||||
if ref_input:
|
||||
args = [nd.array(x, ctx=ctx) for x in ref_input]
|
||||
else:
|
||||
args = [nd.empty(x[0], dtype=x[1], ctx=ctx) for x in build_result.arg_info]
|
||||
|
||||
# measure time
|
||||
errno = MeasureErrorNo.NO_ERROR
|
||||
try:
|
||||
# upload built module
|
||||
if remote_args:
|
||||
remote = request_remote(*remote_args)
|
||||
remote.upload(tmp_dir.relpath(filename))
|
||||
func = remote.load_module(filename)
|
||||
ctx = remote.context(str(inp.target), 0)
|
||||
time_f = func.time_evaluator(
|
||||
func.entry_name, ctx, number=number, repeat=repeat)
|
||||
else:
|
||||
ctx = context(str(inp.target), 0)
|
||||
time_f = func.time_evaluator(
|
||||
func.entry_name, ctx, number=number, repeat=repeat)
|
||||
costs = time_f(*args).results
|
||||
if len(costs) > 2: # remove largest and smallest value to reduce variance
|
||||
costs = list(costs)
|
||||
costs.sort()
|
||||
costs = tuple(costs[1:-1])
|
||||
|
||||
# set input
|
||||
if ref_input:
|
||||
args = [nd.array(x, ctx=ctx) for x in ref_input]
|
||||
else:
|
||||
args = [nd.empty(get_const_tuple(x.shape), dtype=x.dtype, ctx=ctx)
|
||||
for x in arg_bufs]
|
||||
|
||||
costs = time_f(*args).results
|
||||
if len(costs) > 2: # remove largest and smallest value to reduce variance
|
||||
costs = list(costs)
|
||||
costs.sort()
|
||||
costs = tuple(costs[1:-1])
|
||||
|
||||
# check correctness of output
|
||||
if ref_output:
|
||||
for expected, real in zip(ref_output, args):
|
||||
if not np.allclose(expected, real.asnumpy(), rtol=1e-4):
|
||||
logger.warning("Wrong Answer!")
|
||||
errno = MeasureErrorNo.WRONG_ANSWER
|
||||
except TVMError as exc:
|
||||
msg = str(exc)
|
||||
if "Stack trace returned" in msg:
|
||||
msg = msg[:msg.index("Stack trace returned")]
|
||||
if "CUDA Source" in msg:
|
||||
msg = msg[:msg.index("CUDA Source")]
|
||||
costs = (RuntimeError(msg),)
|
||||
errno = MeasureErrorNo.RUNTIME_DEVICE
|
||||
tstamp = time.time()
|
||||
res_pack.append(MeasureResult(costs, errno, tstamp - tic, tstamp))
|
||||
return res_pack
|
||||
# check correctness of output
|
||||
if ref_output:
|
||||
for expected, real in zip(ref_output, args):
|
||||
if not np.allclose(expected, real.asnumpy(), rtol=1e-4):
|
||||
logger.warning("Wrong Answer!")
|
||||
errno = MeasureErrorNo.WRONG_ANSWER
|
||||
except TVMError as exc:
|
||||
msg = str(exc)
|
||||
if "Stack trace returned" in msg:
|
||||
msg = msg[:msg.index("Stack trace returned")]
|
||||
if "CUDA Source" in msg:
|
||||
msg = msg[:msg.index("CUDA Source")]
|
||||
costs = (RuntimeError(msg[:1024]),)
|
||||
errno = MeasureErrorNo.RUNTIME_DEVICE
|
||||
tstamp = time.time()
|
||||
time.sleep(cooldown_interval)
|
||||
return MeasureResult(costs, errno, tstamp - tic + build_result.time_cost, tstamp)
|
||||
|
||||
|
||||
def default_build_func(inp, tmp_dir=None, **kwargs):
|
||||
"""Build function module. Exception will be raised when any error occurs
|
||||
def request_remote(device_key, host=None, port=None, priority=1, timeout=60):
|
||||
"""Request a remote session
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inp: MeasureInput
|
||||
The input of this measurement
|
||||
tmp_dir: tvm.contrib.util.TempDirectory, optional
|
||||
The temporary directory for exporting built binary library.
|
||||
If is not None (in RPC mode), the library in this directory will be uploaded to
|
||||
remote devices.
|
||||
kwargs: Dict, optional
|
||||
Other extra arguments
|
||||
device_key: string
|
||||
The device key of registered device in tracker
|
||||
host: str, optional
|
||||
The host address of rpc tracker.
|
||||
If is none, will use environment variable "TVM_TRACKER_HOST"
|
||||
port: int, optional
|
||||
The port of rpc tracker.
|
||||
If is none, will use environment variable "TVM_TRACKER_PORT"
|
||||
priority: int, optional
|
||||
The priority of this request, larger is more prior
|
||||
timeout: float, optional
|
||||
The timeout of this session (units: second)
|
||||
|
||||
Returns
|
||||
------
|
||||
session: RPCSession
|
||||
"""
|
||||
# connect to the tracker
|
||||
host = host or os.environ['TVM_TRACKER_HOST']
|
||||
port = port or int(os.environ['TVM_TRACKER_PORT'])
|
||||
|
||||
tracker = _rpc.connect_tracker(host, port)
|
||||
remote = tracker.request(device_key, priority=priority,
|
||||
session_timeout=timeout)
|
||||
return remote
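A short usage sketch (the device key and tracker address are placeholders):

    # Pass the tracker address explicitly ...
    remote = request_remote('rk3399', host='0.0.0.0', port=9190, timeout=60)
    ctx = remote.context('opencl', 0)
    # ... or omit host/port and export TVM_TRACKER_HOST / TVM_TRACKER_PORT instead.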
|
||||
|
||||
|
||||
def check_remote(target, device_key, host=None, port=None, priority=2, timeout=10):
|
||||
"""
|
||||
Check the availability of a remote device
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target: Target
|
||||
The wanted compilation target
|
||||
device_key: string
|
||||
device key of registered device in tracker
|
||||
host: str, optional
|
||||
The host address of rpc tracker.
|
||||
If is none, will use environment variable "TVM_TRACKER_HOST"
|
||||
port: int, optional
|
||||
The port address of rpc tracker.
|
||||
If is none, will use environment variable "TVM_TRACKER_PORT"
|
||||
priority: int, optional
|
||||
The priority of this request, larger is more prior
|
||||
timeout: float, optional
|
||||
The timeout of this check (units: seconds).
|
||||
|
||||
Returns
|
||||
-------
|
||||
func: Function
|
||||
TVM built function. Typically this is the return value of tvm.build.
|
||||
args: Array of Buffer or Tensor
|
||||
The argument list for the function. Typically this is the second argument of tvm.build.
|
||||
filename: str
|
||||
The filename of the output build library
|
||||
available: bool
|
||||
True if can find available device
|
||||
"""
|
||||
# build function
|
||||
with inp.target:
|
||||
s, args = inp.task.instantiate(inp.config)
|
||||
|
||||
# check invalidity of template and code hash consistency
|
||||
if not inp.config.valid():
|
||||
raise InstantiationError(inp.config.errors)
|
||||
code_hash = getattr(s, 'code_hash', None)
|
||||
if inp.config.code_hash != code_hash:
|
||||
raise HashMismatchError('got {0:s}, expected {1:s}'
|
||||
.format(str(inp.config.code_hash), str(code_hash)))
|
||||
|
||||
opts = {}
|
||||
if "check_gpu" in kwargs: # Add verify pass to filter out invalid configs in advance.
|
||||
opts["add_lower_pass"] = [(2, gpu_verify_pass(**kwargs['check_gpu']))]
|
||||
if 'cuda_arch' in kwargs:
|
||||
set_cuda_target_arch(kwargs['cuda_arch'])
|
||||
|
||||
with build_config(**opts):
|
||||
func = build(s, args, target_host=inp.task.target_host)
|
||||
|
||||
# export library to temp directory
|
||||
if tmp_dir:
|
||||
if kwargs.get('use_ndk', False): # for Android NDK
|
||||
filename = "tmp_func_%0x.so" % getrandbits(64)
|
||||
func.export_library(tmp_dir.relpath(filename), ndk.create_shared)
|
||||
else:
|
||||
filename = "tmp_func_%0x.tar" % getrandbits(64)
|
||||
func.export_library(tmp_dir.relpath(filename))
|
||||
else:
|
||||
filename = None
|
||||
|
||||
return func, args, filename
|
||||
|
||||
|
||||
def add_gpu_target_info(target, device_key, rpc_tracker_addr, kwargs):
|
||||
"""Add device info for gpu target.
|
||||
The info will be used to check the validity of generated code."""
|
||||
remote = request_remote(device_key, rpc_tracker_addr)
|
||||
ctx = remote.context(str(target), 0)
|
||||
max_dims = ctx.max_thread_dimensions
|
||||
kwargs['check_gpu'] = {
|
||||
'max_shared_memory_per_block': ctx.max_shared_memory_per_block,
|
||||
'max_threads_per_block': ctx.max_threads_per_block,
|
||||
'max_thread_x': max_dims[0],
|
||||
'max_thread_y': max_dims[1],
|
||||
'max_thread_z': max_dims[2],
|
||||
}
|
||||
|
||||
if 'cuda' in target.keys:
|
||||
kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split('.'))
|
||||
|
||||
def set_cuda_target_arch(arch):
|
||||
"""set target architecture of nvcc compiler"""
|
||||
AutotvmGlobalScope.current.cuda_target_arch = arch
|
||||
def _check():
|
||||
remote = request_remote(device_key, host, port, priority)
|
||||
remote.context(str(target))
|
||||
t = threading.Thread(target=_check,)
|
||||
t.start()
|
||||
t.join(timeout)
|
||||
return not t.is_alive()
|
||||
|
||||
|
||||
@register_func
|
||||
|
@ -496,6 +561,17 @@ def tvm_callback_cuda_compile(code):
|
|||
return ptx
|
||||
|
||||
|
||||
def set_cuda_target_arch(arch):
|
||||
"""set target architecture of nvcc compiler
|
||||
|
||||
Parameters
|
||||
----------
|
||||
arch: str
|
||||
The argument of nvcc -arch. (e.g. "sm_51", "sm_62")
|
||||
"""
|
||||
AutotvmGlobalScope.current.cuda_target_arch = arch
|
||||
|
||||
|
||||
def gpu_verify_pass(**kwargs):
|
||||
"""Verify the validity of a gpu kernel.
|
||||
This pass will check memory usage and number of threads per block.
|
||||
|
|
|
@@ -22,7 +22,7 @@ class GATuner(Tuner):
    mutation_prob: float
        probability of mutation of a knob in a gene
    """
    def __init__(self, task, pop_size, elite_num=3, mutation_prob=0.1):
    def __init__(self, task, pop_size=100, elite_num=3, mutation_prob=0.1):
        super(GATuner, self).__init__(task)

        # algorithm configurations

@@ -87,7 +87,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):

        new_scores = model.predict(new_points)

        ac_prob = np.exp((new_scores - scores) / (t + 1e-2))
        ac_prob = np.exp(np.minimum((new_scores - scores) / (t + 1e-5), 1))
        ac_index = np.random.random(len(ac_prob)) < ac_prob

        points[ac_index] = new_points[ac_index]

@ -103,34 +103,7 @@ def get_sample_task(target=tvm.target.cuda(), target_host=None):
|
|||
target=target, target_host=target_host)
|
||||
return task, target
|
||||
|
||||
|
||||
def test_task_tuner_without_measurement():
|
||||
"""test task and tuner without measurement"""
|
||||
task, target = get_sample_task()
|
||||
|
||||
def custom_measure(input_pack, build_func, build_args, number, repeat,
|
||||
ref_input, ref_output):
|
||||
from tvm.autotvm import MeasureResult
|
||||
|
||||
results = []
|
||||
for inp in input_pack:
|
||||
tic = time.time()
|
||||
# do nothing
|
||||
time.sleep(0.001)
|
||||
results.append(MeasureResult([time.time() - tic], 0,
|
||||
time.time() - tic, time.time()))
|
||||
return results
|
||||
measure_option = autotvm.measure_option(custom_measure)
|
||||
|
||||
logging.info("%s", task.config_space)
|
||||
|
||||
# new tuner and recorder
|
||||
for tuner_class in [autotvm.tuner.RandomTuner, autotvm.tuner.GridSearchTuner]:
|
||||
tuner = tuner_class(task)
|
||||
tuner.tune(n_trial=10, measure_option=measure_option)
|
||||
assert tuner.best_flops > 1
|
||||
|
||||
def test_tuning_with_measure():
|
||||
def test_tuning():
|
||||
def check(target, target_host):
|
||||
ctx = tvm.context(target, 0)
|
||||
if not ctx.exist:
|
||||
|
@ -141,12 +114,12 @@ def test_tuning_with_measure():
|
|||
task, target = get_sample_task(target, target_host)
|
||||
logging.info("%s", task.config_space)
|
||||
|
||||
measure_option = autotvm.measure_option('local',
|
||||
timeout=4,
|
||||
number=2)
|
||||
measure_option = autotvm.measure_option(
|
||||
autotvm.LocalBuilder(),
|
||||
autotvm.LocalRunner())
|
||||
|
||||
tuner = RandomTuner(task)
|
||||
tuner.tune(n_trial=10, measure_option=measure_option)
|
||||
tuner.tune(n_trial=20, measure_option=measure_option)
|
||||
|
||||
check("cuda", None)
|
||||
check("opencl", None)
|
||||
|
@ -155,6 +128,4 @@ if __name__ == "__main__":
|
|||
# only print log when invoked from main
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
test_task_tuner_without_measurement()
|
||||
test_tuning_with_measure()
|
||||
|
||||
test_tuning()
|
||||
|
|
|
@ -32,6 +32,25 @@ def matmul(N, L, M, dtype):
|
|||
|
||||
return s, [A, B, C]
|
||||
|
||||
@autotvm.template
|
||||
def bad_matmul(N, L, M, dtype):
|
||||
if 'bad_device' in tvm.target.current_target().keys:
|
||||
A = tvm.placeholder((N, L), name='A', dtype=dtype)
|
||||
B = tvm.placeholder((L, M), name='B', dtype=dtype)
|
||||
|
||||
k = tvm.reduce_axis((0, L-1), name='k')
|
||||
C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
|
||||
s = tvm.create_schedule(C.op)
|
||||
|
||||
# schedule
|
||||
y, x = s[C].op.axis
|
||||
cfg = autotvm.get_config()
|
||||
cfg.define_split("tile_y", y, num_outputs=2)
|
||||
cfg.define_split("tile_x", x, num_outputs=2)
|
||||
return s, [A, B, C]
|
||||
|
||||
return matmul(N, L, M, dtype)
|
||||
|
||||
def get_sample_task(n=128):
|
||||
"""return a sample task for testing"""
|
||||
target = tvm.target.create("llvm")
|
||||
|
|
|
@ -1,17 +1,11 @@
|
|||
"""Test database"""
|
||||
import copy
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import tvm
|
||||
|
||||
from tvm import autotvm
|
||||
from tvm.autotvm import database
|
||||
from tvm.autotvm.measure.measure_methods import HashMismatchError
|
||||
from tvm.autotvm.record import encode, MeasureInput, MeasureResult
|
||||
from tvm.autotvm.record import encode, MeasureResult
|
||||
|
||||
from test_autotvm_common import get_sample_task, get_sample_records
|
||||
from test_autotvm_common import get_sample_records
|
||||
|
||||
def test_save_load():
|
||||
logging.info("test basic db load/save ...")
|
||||
|
@@ -35,66 +29,6 @@ def test_save_load():

TRIAL_LIMIT = 2

def test_db_filter():
logging.info("test db filter ...")

# Pick a GPU target because there are more likely to be failures/invalid configs
task, target = get_sample_task()

ctx = tvm.context(str(target))
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")

batch_size = 2

measure_option = autotvm.measure_option('local', do_fork=False, timeout=2)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)

ct = 0
all_inputs = list()
all_results = list()
batches = list()
tuner = autotvm.tuner.RandomTuner(task)
while ct < TRIAL_LIMIT:
inputs = list()
for i in range(batch_size):
cfg = tuner.next_batch(1)[0]
inputs.append((MeasureInput(target, task, cfg)))
all_inputs.append(inputs[-1])
batches.append(inputs)
results = measure_batch(inputs)
all_results += results
ct += 1

del measure_batch

db = database.DummyDatabase()
db.flush()

# First setting, memoize one input at a time, check that each is saved and replayed
measure_option = autotvm.measure_option('local', do_fork=False, timeout=2, replay_db=db)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)

for i in range(len(all_inputs)+1):
db.flush()
for j in range(i):
db.save(all_inputs[j], all_results[j])

for k in range(len(batches)):
batch = batches[k]
batch_result = measure_batch(batch)
for l in range(batch_size):
all_idx = k*batch_size + l
assert batch_result[l] is not None
if all_idx < i:
assert encode(batch[l], batch_result[l]) == encode(batch[l], all_results[all_idx]), \
"(no retry) EXPECTED MATCH, GOT MISMATCH"
else:
assert encode(batch[l], batch_result[l]) != encode(batch[l], all_results[all_idx]), \
"(no retry) EXPECTED MISMATCH, GOT MATCH"

del measure_batch

def test_db_hash():
logging.info("test db hash check ...")
inp1, res1 = get_sample_records(1)[0]
@@ -149,89 +83,8 @@ def test_db_latest_all():
assert encode(inp1, load4[1]) == encode(inp1, res2)
assert encode(inp1, load4[2]) == encode(inp1, res3)

def test_db_save_replay():
logging.info("test db save (from measure_batch) and replay ...")
_db = database.DummyDatabase()
_db.flush()

task, target = get_sample_task()

ctx = tvm.context(str(target))
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")

measure_option = autotvm.measure_option('local',
do_fork=False,
timeout=2,
replay_db=_db)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)

batch_size = 2

ct = 0
all_inputs = list()
all_results = list()
batches = list()
tuner = autotvm.tuner.RandomTuner(task)
while ct < TRIAL_LIMIT:
inputs = list()
for i in range(batch_size):
cfg = tuner.next_batch(1)[0]
inputs.append((MeasureInput(target, task, cfg)))
all_inputs.append(inputs[-1])
batches.append(inputs)
results = measure_batch(inputs)
all_results += results
ct += 1
callback = autotvm.callback.log_to_database(_db)
callback(None, all_inputs, all_results)

assert len(_db.db.keys()) == batch_size * TRIAL_LIMIT, \
"%d vs %d" % (len(_db.db.keys()), batch_size * TRIAL_LIMIT)

all_results_2 = measure_batch(all_inputs)
all_results_3 = measure_batch(all_inputs)

for i in range(len(all_results)):
encr1 = encode(all_inputs[i], all_results[i])
encr2 = encode(all_inputs[i], all_results_2[i])
encr3 = encode(all_inputs[i], all_results_3[i])
assert encr1 == encr2, "EXPECTED MATCH WITH SAVE REPLAY (first replay), got MISMATCH"
assert encr2 == encr3, "EXPECTED MATCH WITH SAVE REPLAY (second replay), got MISMATCH"

del measure_batch

def test_check_hashmismatch():
logging.info("test hash mismatch check")

task, target = get_sample_task()

ctx = tvm.context(str(target))
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")

measure_option = autotvm.measure_option('local', do_fork=False)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)

inputs = list()
cfg = task.config_space.get(np.random.randint(len(task.config_space)))
# notvalidh is not a valid CRC32 hash (not hex)
cfg.code_hash = 'notvalidh'
inputs.append((MeasureInput(target, task, cfg)))

try:
results = measure_batch(inputs)
assert False, "HashMismatchError should be raised"
except HashMismatchError:
pass

del measure_batch

if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
test_save_load()
test_db_filter()
test_db_hash()
test_db_latest_all()
test_db_save_replay()
test_check_hashmismatch()
@@ -0,0 +1,97 @@
"""Test builder and runner"""
import logging
import time

import numpy as np

import tvm
from tvm import autotvm
from test_autotvm_common import get_sample_task, bad_matmul
from tvm.autotvm.measure.measure import Runner, MeasureResult, MeasureErrorNo

def test_task_tuner_without_measurement():
"""test task and tuner without measurement"""
task, target = get_sample_task()

class DummyRunner(Runner):
def __init__(self):
super(DummyRunner, self).__init__(1, 1)

def run(self, measure_inputs, build_results):
return [MeasureResult((np.random.random(),), 0, 0.2, time.time())
for _ in range(len(measure_inputs))]

def get_build_kwargs(self):
return {}

measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(),
runner=DummyRunner()
)

logging.info("%s", task.config_space)

for tuner_class in [autotvm.tuner.RandomTuner,
autotvm.tuner.GridSearchTuner,
autotvm.tuner.GATuner,
autotvm.tuner.XGBTuner]:
tuner = tuner_class(task)
tuner.tune(n_trial=10, measure_option=measure_option)
assert tuner.best_flops > 1

def test_check_correctness():
task, target = get_sample_task()

measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(),
runner=autotvm.LocalRunner(check_correctness=True)
)

def _callback_correct(tuner, measure_inputs, measure_results):
for inp, res in zip(measure_inputs, measure_results):
assert res.error_no == 0

tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=2, measure_option=measure_option,
callbacks=[_callback_correct])

# a bad template
n = 128
target = tvm.target.create("llvm -device=bad_device")
task = autotvm.task.create(bad_matmul, args=(n, n, n, 'float32'), target=target)

def _callback_wrong(tuner, measure_inputs, measure_results):
for inp, res in zip(measure_inputs, measure_results):
assert res.error_no == MeasureErrorNo.WRONG_ANSWER

tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=2, measure_option=measure_option,
callbacks=[_callback_wrong])


def test_min_repeat_ms():
task, target = get_sample_task()

measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(),
runner=autotvm.LocalRunner(number=1, min_repeat_ms=100)
)

def _callback(tuner, measure_inputs, measure_results):
for inp, res in zip(measure_inputs, measure_results):
if res.error_no != 0:
continue

assert 1000 * np.mean(res.costs) * \
measure_option['runner'].cur_number >= 100

tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=5, measure_option=measure_option,
callbacks=[_callback])

if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)

test_task_tuner_without_measurement()
test_check_correctness()
test_min_repeat_ms()
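The min_repeat_ms test above leans on the new LocalRunner behavior: the runner is expected to scale its effective `number` up until one measurement batch lasts at least `min_repeat_ms`, which is what the assertion on `res.costs` encodes. A small worked sketch of that same arithmetic, using an invented per-run cost of 0.4 ms (all values here are illustrative, not taken from the test):

    min_repeat_ms = 100
    mean_cost_s = 0.0004                       # hypothetical per-run cost reported in res.costs (0.4 ms)
    cur_number = 250                           # value the runner is assumed to have settled on
    batch_ms = 1000 * mean_cost_s * cur_number
    assert batch_ms >= min_repeat_ms           # mirrors the check in test_min_repeat_ms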
@@ -137,12 +137,15 @@ if __name__ == '__main__':
print(task.config_space)

measure_option = autotvm.measure_option(
measure_func='local', number=10, n_parallel=8, timeout=20)
builder=autotvm.LocalBuilder(),
runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4)
)

log_name = 'gemm_int8.log'
if DO_TUNING:
tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(n_trial=1000, measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file(log_name)])
callbacks=[autotvm.callback.log_to_file(log_name)])

dispatch_context = autotvm.apply_history_best(log_name)
best_config = dispatch_context.query(task.target, task.workload)
@@ -164,12 +164,12 @@ task = autotvm.task.create(conv2d_no_batching,
target='cuda')
print(task.config_space)

# use local gpu, measure 5 times for every config to reduce variance
# run 8 parallel threads for compilation
measure_option = autotvm.measure_option('local',
number=5,
n_parallel=8,
timeout=20)
# use local gpu, measure 10 times for every config to reduce variance
# The timeout of compiling a program is 10 seconds, the timeout for running is 4 seconds
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(),
runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4)
)

# begin tuning, log records to file `conv2d.log`
tuner = autotvm.tuner.XGBTuner(task)
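The comment above quotes a 10-second build timeout and a 4-second run timeout, but only the latter is passed explicitly. If the defaults do not fit a workload, both limits can be set per step; a minimal sketch, assuming LocalBuilder accepts a timeout argument analogous to LocalRunner's:

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(timeout=10),   # assumed parameter: abort a compilation after 10 s
        runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4))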
@@ -65,15 +65,20 @@ def get_network(name, batch_size):
input_shape = (batch_size, 3, 224, 224)
output_shape = (batch_size, 1000)

if name =='resnet-18':
net, params = nnvm.testing.resnet.get_workload(num_layers=18, batch_size=batch_size)
elif name =='mobilenet':
if "resnet" in name:
n_layer = int(name.split('-')[1])
net, params = nnvm.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size)
elif "vgg" in name:
n_layer = int(name.split('-')[1])
net, params = nnvm.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size)
elif name == 'mobilenet':
net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size)
elif name =='squeezenet v1.1':
elif name == 'squeezenet_v1.1':
net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1')
elif name =='vgg-16':
net, params = nnvm.testing.vgg.get_workload(num_layers=16, batch_size=batch_size)
elif name =='custom':
elif name == 'inception_v3':
input_shape = (1, 3, 299, 299)
net, params = nnvm.testing.inception_v3.get_workload(batch_size=batch_size)
elif name == 'custom':
# an example for custom network
from nnvm.testing import utils
net = nnvm.sym.Variable('data')

@@ -92,6 +97,7 @@ def get_network(name, batch_size):

return net, params, input_shape, output_shape


#################################################################
# Start RPC Tracker
# -----------------

@@ -158,6 +164,8 @@ def get_network(name, batch_size):
# rk3399 2 2 0
# rpi3b 11 11 0
# ----------------------------------
#
# You can register multiple devices to the tracker to accelerate the measurement in tuning.

###########################################
# Set Tuning Options
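Before starting a long tuning run, it can help to sanity-check that a board is actually registered with the tracker described in the hunk above by requesting a remote session for its key; a minimal sketch, where the key 'rk3399' and the tracker address are placeholders for your own setup:

    from tvm import autotvm

    # Ask the tracker for a free device registered under this key.
    remote = autotvm.measure.request_remote('rk3399', 'localhost', 9190, timeout=10)
    print("Obtained a remote session:", remote)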
@@ -184,34 +192,30 @@ log_file = "%s.%s.log" % (device_key, network)
dtype = 'float32'

tuning_option = {
'log_filename': log_file,
'log_filename': log_file,

'tuner': 'xgb',
'n_trial': 1000,
'early_stopping': 250,
'tuner': 'xgb',
'n_trial': 1000,
'early_stopping': 400,

'measure_option': autotvm.measure_option(
autotvm.measure.rpc(device_key, host='localhost', port=9190),
number=4,
n_parallel=1,
timeout=10,
build_func='ndk' if use_android else 'default',
),
'measure_option': autotvm.measure_option(
builder=autotvm.LocalBuilder(
build_func='ndk' if use_android else 'default'),
runner=autotvm.RPCRunner(
device_key, host='localhost', port=9190,
number=5,
timeout=4,
),
),
}

####################################################################
#
# .. note:: How to set tuning options
#
# In general, the default value provided here works well. It is the same
# value that we used to generate pre-tuned parameters.
# If you have multiple devices, you can set :code:`n_parallel` to
# the number of devices you have. (e.g. set it to 3 if you register 3 rk3399
# boards to the tracker).
# In general, the default value provided here works well.
# If you have large time budget, you can set :code:`n_trial`, :code:`early_stopping` larger,
# which makes the tuning run longer.
# If your device is very slow or a single conv2d operator in your network has large FLOPs,
# consider setting timeout larger.
#

###################################################################
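The removed note used to suggest raising `n_parallel` when several boards share one key; with the builder/runner interface the same idea attaches to the runner. A minimal sketch, assuming RPCRunner accepts an `n_parallel` argument like the old measure_option did:

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(build_func='ndk' if use_android else 'default'),
        runner=autotvm.RPCRunner(
            device_key, host='localhost', port=9190,
            number=5, timeout=4,
            n_parallel=3))   # assumed parameter: measure on 3 registered boards concurrently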
@@ -219,7 +223,7 @@ tuning_option = {
# ------------
# Now we can extract tuning tasks from the network and begin tuning.
# Here we provide a simple utility function to tune a list of tasks.
# This function is just an initial implementation which tune them in sequential order.
# This function is just an initial implementation which tunes them in sequential order.
# Later we will bring more sophisticated tuner scheduler.

# You can skip the implementation of this function for this tutorial.

@@ -236,7 +240,9 @@ def tune_tasks(tasks,
try: # try winograd template
tsk = autotvm.task.create(tasks[i].name, tasks[i].args,
tasks[i].target, tasks[i].target_host, 'winograd')
tasks.append(tsk)
input_channel = tsk.workload[1][1]
if input_channel >= 64:
tasks[i] = tsk
except Exception:
pass

@@ -245,8 +251,8 @@ def tune_tasks(tasks,
if os.path.exists(tmp_log_file):
os.remove(tmp_log_file)

for i, tsk in enumerate(tasks):
prefix = "[Task %2d/%2d] " %(i+1, len(tasks))
for i, tsk in enumerate(reversed(tasks)):
prefix = "[Task %2d/%2d] " % (i+1, len(tasks))

# create tuner
if tuner == 'xgb' or tuner == 'xgb-rank':

@@ -280,7 +286,7 @@ def tune_tasks(tasks,
########################################################################
# Finally we launch tuning jobs and evaluate the end-to-end performance.

def tune_and_evaluate():
def tune_and_evaluate(tuning_opt):
# extract workloads from nnvm graph
print("Extract tasks...")
net, params, input_shape, out_shape = get_network(network, batch_size=1)
@@ -290,19 +296,18 @@ def tune_and_evaluate():

# run tuning tasks
print("Tuning...")
tune_tasks(tasks, **tuning_option)
tune_tasks(tasks, **tuning_opt)

# compile kernels with history best records
with autotvm.apply_history_best(log_file):
print("Compile...")
with nnvm.compiler.build_config(opt_level=2, add_pass=['AlterOpLayout']):
graph, lib, params = nnvm.compiler.build(
net, target=target,
shape={'data': input_shape}, params=params, dtype=dtype)
net, target=target, shape={'data': input_shape}, params=params, dtype=dtype)

# export library
tmp = tempdir()
if tuning_option['measure_option']['build_func'] == 'ndk': # for android
if use_android:
from tvm.contrib import ndk
filename = "net.so"
lib.export_library(tmp.relpath(filename), ndk.create_shared)

@@ -312,8 +317,7 @@ def tune_and_evaluate():

# upload module to device
print("Upload...")
remote = autotvm.measure.request_remote(device_key,
tracker_addr=('localhost', 9190),
remote = autotvm.measure.request_remote(device_key, 'localhost', 9190,
timeout=10000)
remote.upload(tmp.relpath(filename))
rlib = remote.load_module(filename)

@@ -328,47 +332,44 @@ def tune_and_evaluate():

# evaluate
print("Evaluate inference time cost...")
ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=10)
prof_res = np.array(ftimer().results) * 1000 # convert to millisecond
ftimer = module.module.time_evaluator("run", ctx, number=8, repeat=3)
prof_res = np.array(ftimer().results) * 1000 # convert to millisecond
print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
(np.mean(prof_res), np.std(prof_res)))

# We do not run the tuning in our webpage server since it takes too long.
# Uncomment the following line to run by yourself.
# tune_and_evaluate()

# tune_and_evaluate(tuning_option)
######################################################################
# Sample Output
# -------------
# The tuning needs to train xgboost models and use them for prediction.
# The tuning needs to compile many programs and extract feature from them.
# So a high performance CPU is recommended.
# It takes about 2 hours on a 32T AMD Ryzen CPU.
# One sample output is
# One sample output is listed below.
# It takes about 2 hours on a 32T AMD Ryzen Threadripper.
#
# .. code-block:: bash
#
# Extract tasks...
# Tuning...
# [Task 1/16] Current/Best: 18.85/ 19.67 GFLOPS | Progress: (353/1000) | 387.05 s Done.
# [Task 2/16] Current/Best: 16.10/ 23.50 GFLOPS | Progress: (444/1000) | 379.99 s Done.
# [Task 3/16] Current/Best: 5.49/ 13.96 GFLOPS | Progress: (610/1000) | 485.87 s Done.
# [Task 4/16] Current/Best: 10.07/ 20.48 GFLOPS | Progress: (430/1000) | 391.66 s Done.
# [Task 5/16] Current/Best: 11.50/ 15.50 GFLOPS | Progress: (374/1000) | 356.03 s Done.
# [Task 6/16] Current/Best: 10.76/ 23.77 GFLOPS | Progress: (526/1000) | 526.42 s Done.
# [Task 7/16] Current/Best: 12.71/ 22.03 GFLOPS | Progress: (341/1000) | 322.96 s Done.
# [Task 8/16] Current/Best: 8.60/ 17.91 GFLOPS | Progress: (272/1000) | 236.08 s Done.
# [Task 9/16] Current/Best: 15.37/ 23.62 GFLOPS | Progress: (275/1000) | 275.18 s Done.
# [Task 10/16] Current/Best: 6.62/ 23.01 GFLOPS | Progress: (330/1000) | 315.02 s Done.
# [Task 11/16] Current/Best: 1.85/ 21.39 GFLOPS | Progress: (281/1000) | 239.19 s Done.
# [Task 12/16] Current/Best: 15.41/ 24.02 GFLOPS | Progress: (258/1000) | 270.82 s Done.
# [Task 13/16] Current/Best: 17.96/ 25.79 GFLOPS | Progress: (380/1000) | 738.29 s Done.
# [Task 14/16] Current/Best: 14.81/ 31.17 GFLOPS | Progress: (413/1000) | 799.21 s Done.
# [Task 15/16] Current/Best: 24.39/ 40.97 GFLOPS | Progress: (355/1000) | 700.25 s Done.
# [Task 16/16] Current/Best: 9.42/ 49.90 GFLOPS | Progress: (348/1000) | 603.84 s Done.
# [Task 1/12] Current/Best: 22.37/ 52.19 GFLOPS | Progress: (544/1000) | 406.59 s Done.
# [Task 2/12] Current/Best: 6.51/ 18.77 GFLOPS | Progress: (608/1000) | 325.05 s Done.
# [Task 3/12] Current/Best: 4.67/ 24.87 GFLOPS | Progress: (480/1000) | 372.31 s Done.
# [Task 4/12] Current/Best: 11.35/ 46.83 GFLOPS | Progress: (736/1000) | 602.39 s Done.
# [Task 5/12] Current/Best: 1.01/ 19.80 GFLOPS | Progress: (448/1000) | 262.16 s Done.
# [Task 6/12] Current/Best: 2.47/ 23.76 GFLOPS | Progress: (672/1000) | 563.85 s Done.
# [Task 7/12] Current/Best: 14.57/ 33.97 GFLOPS | Progress: (544/1000) | 465.15 s Done.
# [Task 8/12] Current/Best: 1.13/ 17.65 GFLOPS | Progress: (576/1000) | 365.08 s Done.
# [Task 9/12] Current/Best: 14.45/ 22.66 GFLOPS | Progress: (928/1000) | 724.25 s Done.
# [Task 10/12] Current/Best: 3.22/ 15.36 GFLOPS | Progress: (864/1000) | 564.27 s Done.
# [Task 11/12] Current/Best: 11.03/ 32.23 GFLOPS | Progress: (736/1000) | 635.15 s Done.
# [Task 12/12] Current/Best: 8.00/ 21.65 GFLOPS | Progress: (1000/1000) | 1111.81 s Done.
# Compile...
# Upload...
# Evaluate inference time cost...
# Mean inference time (std dev): 157.29 ms (1.74 ms)
# Mean inference time (std dev): 162.59 ms (0.06 ms)

######################################################################
#
@@ -271,9 +271,12 @@ print(task.config_space)
logging.getLogger('autotvm').setLevel(logging.DEBUG)
logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))

# use local cpu, measure 5 times for every config to reduce variance
measure_option = autotvm.measure_option('local',
number=5)
# There are two steps for measuring a config: build and run.
# By default, we use all cpu cores to compile program. Then measure them sequentially.
# We measure 5 times and take average to reduce variance.
measure_option = autotvm.measure_option(
builder='local',
runner=autotvm.LocalRunner(number=5))
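Because the build and run steps are decoupled here, the same tuning script can keep compiling locally and measure on a remote device by swapping only the runner; a minimal sketch, assuming a board registered with an RPC tracker under the placeholder key 'my-device':

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),               # compile on the local machine
        runner=autotvm.RPCRunner('my-device',         # measure on a tracker-registered device
                                 host='localhost', port=9190,
                                 number=5, timeout=4))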
# begin tuning, log records to file `matmul.log`
tuner = autotvm.tuner.RandomTuner(task)