diff --git a/docs/api/python/autotvm.rst b/docs/api/python/autotvm.rst index f03406db..93d69050 100644 --- a/docs/api/python/autotvm.rst +++ b/docs/api/python/autotvm.rst @@ -16,6 +16,11 @@ tvm.autotvm.measure .. autofunction:: tvm.autotvm.measure.create_measure_batch +.. autoclass:: tvm.autotvm.measure.measure_methods.LocalBuilder + +.. autoclass:: tvm.autotvm.measure.measure_methods.RPCRunner + +.. autoclass:: tvm.autotvm.measure.measure_methods.LocalRunner tvm.autotvm.tuner ~~~~~~~~~~~~~~~~~ diff --git a/python/tvm/autotvm/__init__.py b/python/tvm/autotvm/__init__.py index 625b50c1..7170dbdd 100644 --- a/python/tvm/autotvm/__init__.py +++ b/python/tvm/autotvm/__init__.py @@ -22,7 +22,8 @@ from . import env from . import tophub # some shortcuts -from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo +from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo, \ + LocalBuilder, LocalRunner, RPCRunner from .tuner import callback from .task import template, get_config, create, ConfigSpace, ConfigEntity, \ register_topi_compute, register_topi_schedule, \ diff --git a/python/tvm/autotvm/measure/__init__.py b/python/tvm/autotvm/measure/__init__.py index 880dfd1f..8a612664 100644 --- a/python/tvm/autotvm/measure/__init__.py +++ b/python/tvm/autotvm/measure/__init__.py @@ -1,7 +1,7 @@ """Distributed executor infrastructure to scale up the tuning""" -from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option -from .measure_methods import request_remote, check_remote, create_measure_batch, rpc - +from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option, \ + create_measure_batch +from .measure_methods import LocalBuilder, LocalRunner, RPCRunner, request_remote +from .executor import Executor from .local_executor import LocalExecutor -from .executor import Future, Executor diff --git a/python/tvm/autotvm/measure/local_executor.py b/python/tvm/autotvm/measure/local_executor.py index 55f1dc75..63d995c3 100644 --- a/python/tvm/autotvm/measure/local_executor.py +++ b/python/tvm/autotvm/measure/local_executor.py @@ -37,7 +37,8 @@ def _execute_func(func, queue, args, kwargs): res = exc queue.put(res) -def timeout_monitor(queue, timeout, func, args, kwargs): + +def call_with_timeout(queue, timeout, func, args, kwargs): """A wrapper to support timeout of a function call""" # start a new process for timeout (cannot use thread because we have c function) @@ -45,17 +46,12 @@ def timeout_monitor(queue, timeout, func, args, kwargs): p.start() p.join(timeout=timeout) - alive = p.is_alive() + queue.put(executor.TimeoutError()) + kill_child_processes(p.pid) p.terminate() p.join() - if alive: - queue.put(executor.TimeoutError()) - else: - if queue.empty(): - queue.put(executor.ExecutionError("Fatal error in local executor")) - class LocalFuture(executor.Future): """Local wrapper for the future @@ -134,7 +130,7 @@ class LocalExecutor(executor.Executor): return LocalFutureNoFork(func(*args, **kwargs)) queue = Queue(2) - process = Process(target=timeout_monitor, + process = Process(target=call_with_timeout, args=(queue, self.timeout, func, args, kwargs)) process.start() return LocalFuture(process, queue) diff --git a/python/tvm/autotvm/measure/measure.py b/python/tvm/autotvm/measure/measure.py index 2d780eea..38b5f99e 100644 --- a/python/tvm/autotvm/measure/measure.py +++ b/python/tvm/autotvm/measure/measure.py @@ -1,5 +1,6 @@ # pylint: disable=pointless-string-statement,consider-using-enumerate,invalid-name """User facing API 
for specifying how to measure the generated code""" +import multiprocessing from collections import namedtuple class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])): @@ -16,6 +17,7 @@ class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])): Specific configuration. """ + class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost", "timestamp"])): """ Stores all the results of a measurement @@ -23,8 +25,8 @@ class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost" Parameters ---------- costs: Array of float or Array of Exception - If no error occurs for this measurement, it is an array of measured running times. - If some error occurs during the measurement, it is an array of the exception objections. + If no error occurs during measurement, it is an array of measured running times. + If an error occurs during measurement, it is an array of the exception objections. error_no: int Denote error type, defined by MeasureErrorNo all_cost: float @@ -37,92 +39,185 @@ class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost" class MeasureErrorNo(object): """Error type for MeasureResult""" NO_ERROR = 0 # no error - INSTANTIATION_ERROR = 1 # error when calling template function + INSTANTIATION_ERROR = 1 # actively detected error in instantiating a template with a config COMPILE_HOST = 2 # error when compiling code on host (e.g. tvm.build) - COMPILE_DEVICE = 3 # error when compiling code on device (e.g. opencl JIT on device) + COMPILE_DEVICE = 3 # error when compiling code on device (e.g. OpenCL JIT on the device) RUNTIME_DEVICE = 4 # error when run program on device WRONG_ANSWER = 5 # answer is wrong when compared to a golden output - FLEET_ERROR = 6 # error of measure infrastructure + BUILD_TIMEOUT = 6 # timeout during compilation + RUN_TIMEOUT = 7 # timeout during run + UNKNOWN_ERROR = 8 # unknown error -def measure_option(measure_func, - number=1, - repeat=1, - timeout=60, - n_parallel=1, - do_fork=True, - build_func='default', - check_correctness=False, - replay_db=None): - """Configure how to do measurement +class Builder(object): + """Builder that builds programs in tuning Parameters ---------- - measure_func: str or callable - 'local': use the local device for measurement. The tuner will start a tracker - and a RPC server silently for the user. - - callable: It is a callable function for measurement. - See the return value of measure/measure_methods.py::rpc for example. - number : int, optional - Number of times to do the measurement for average - repeat : int, optional - Number of times to repeat the measurement. - In total, the generated code will be run (1 + number x repeat) times, - where the first one is warm up. The returned result contains `repeat` costs, - each of which is the average of `number` test run. - timeout: int, optional - Timeout for a whole batch. TimeoutError will be returned as the result if a - task timeouts. + timeout: float, optional + The timeout of a build task n_parallel: int, optional - The number of measurement task that can run in parallel. - Set this according to the number of cpu cores (for compilation) and - the number of devices you have (for measuring generate code). - do_fork: bool, optional - Whether use multiprocessing (based on fork) for running measure jobs in parallel. - Set this to False if you want to debug (see trackback) or using fork is not suitable. - NOTE: If this is False, parallel and timeout do not work. 
- build_func: str or callable, optional - 'default': call default builder. This works for normal target (llvm, cuda) + The number of tasks submitted in parallel + By default it will use all cpu cores + """ + def __init__(self, timeout=10, n_parallel=None): + self.timeout = timeout + self.n_parallel = n_parallel or multiprocessing.cpu_count() + self.build_kwargs = {} + self.task = None - 'ndk': use Android NDK to create shared library. Use this for android target. + def set_task(self, task, build_kwargs=None): + """ + Initialize for a new tuning task - callable: customized build function for other backends (e.g. VTA). - See measure/measure_methods.py::default_build_func for example. - check_correctness: bool, optional - Whether check correctness after measurement. This will use llvm cpu target to generate - reference output. - replay_db : Database, optional - The database that we retrieve saved MeasureResult from. + Parameters + ---------- + task: Task + The tuning task + build_kwargs: dict, optional + The additional kwargs for build function + """ + self.task = task + self.build_kwargs = build_kwargs + + def build(self, measure_inputs): + """Build programs + + Parameters + ---------- + measure_inputs: List of MeasureInput + The measure input + + Returns + ------- + build_results: List of BuildResult + The build result. + """ + raise NotImplementedError() + + +class Runner(object): + """Runner that runs and measures the time cost of a generated program in tuning + + Parameters + ---------- + timeout: float, optional + The timeout of a build task + n_parallel: int, optional + The number of tasks submitted in parallel + By default it will use all cpu cores + """ + def __init__(self, timeout=5, n_parallel=None): + self.timeout = timeout + self.n_parallel = n_parallel or multiprocessing.cpu_count() + self.task = None + + def set_task(self, task): + """ + Initialize for a new tuning task + + Parameters + ---------- + task: Task + The tuning task + """ + self.task = task + + def get_build_kwargs(self): + """ + Get device specific build arguments (e.g. maximum shared memory size) + + Returns + ---------- + kwargs: dict + The additional keyword arguments + """ + raise NotImplementedError() + + def run(self, measure_inputs, build_results): + """Run amd measure built programs + + Parameters + ---------- + measure_inputs: List of MeasureInput + The raw measure input + build_results: List of BuildResults + The build results + + Returns + ------- + measure_results: List of MeasureResult + The final results of measurement + """ + raise NotImplementedError() + + +def measure_option(builder, runner): + """ + Set options for measure. To measure a config, we will build it and run it. + So we have to set options for these two steps. + They have their own options on timeout, parallel, etc. + + Parameters + ---------- + builder: Builder + Specify how to build programs + runner: Runner + Specify how to run programs + """ + from .measure_methods import LocalBuilder, LocalRunner + + if isinstance(builder, str): + if builder == 'local': + builder = LocalBuilder() + else: + raise ValueError("Invalid builder: " + builder) + + if isinstance(runner, str): + if runner == 'local': + runner = LocalRunner() + else: + raise ValueError("Invalid runner: " + runner) + + opt = { + 'builder': builder, + 'runner': runner, + } + + return opt + + +def create_measure_batch(task, option): + """Get a standard measure_batch function. 
+ + Parameters + ---------- + task: tvm.autotvm.task.Task + The tuning task + option: dict + The option for measuring generated code. + You should use the return value of function :any:`measure_option` for this argument. Returns ------- - options: dict - A dict to store all options - - Note - ---- - To support customized measure, you can pass callable `measure_func` or - `build_func` in. The `measure_func` will call `build_func` to build binary library - and handle the logic of measurement. - - Signature: - * measure_func (see the return value of measure/measure_methods.py::rpc for example) - def measure_func(input_pack, build_func, build_kwargs, number, repeat, ref_input, ref_output): - return measure_results - - * build_func (see measure/measure_methods.py::default_build_func for example) - def build_func(inp, tmp_dir, **kwargs): - return func, args, filename + measure_batch: callable + a callback function to measure a batch of configs """ - return { - 'measure_func': measure_func, - 'number': number, - 'repeat': repeat, - 'timeout': timeout, - 'n_parallel': n_parallel, - 'do_fork': do_fork, - 'build_func': build_func, - 'check_correctness': check_correctness, - 'replay_db': replay_db, - } + builder = option['builder'] + runner = option['runner'] + + attach_objects = runner.set_task(task) + + # feed device related information from runner to builder + # (e.g. max shared memory for validity checking) + build_kwargs = runner.get_build_kwargs() + builder.set_task(task, build_kwargs) + + def measure_batch(measure_inputs): + build_results = builder.build(measure_inputs) + results = runner.run(measure_inputs, build_results) + return results + + measure_batch.n_parallel = builder.n_parallel + measure_batch.attach_objects = attach_objects + return measure_batch diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index 2d740b94..6a3cd028 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -1,129 +1,339 @@ -# pylint: disable=consider-using-enumerate,invalid-name,too-many-function-args +# pylint: disable=invalid-name,too-many-function-args,too-many-nested-blocks """ Functions that run on executor for measurement. -These functions are responsible for building tvm module, uploading it to -remote devices, recording the running time costs and checking the correctness of output + +These functions are responsible for building the tvm module, uploading it to +remote devices, recording the running time costs, and checking the correctness of the output. """ import logging +import shutil import os +import threading import time from random import getrandbits -import threading +from collections import namedtuple +import tempfile import numpy as np -from ... import ir_pass, build, build_config, nd, context, TVMError, register_func, \ - target as _target, rpc as _rpc -from ...contrib import nvcc, util, ndk +from ... 
import ir_pass, build, build_config, nd, TVMError, register_func, \ + rpc as _rpc, target as _target +from ...contrib import nvcc, ndk from ..util import get_const_tuple from ..env import AutotvmGlobalScope from ..task.space import InstantiationError -from .measure import MeasureResult, MeasureErrorNo +from .measure import MeasureResult, MeasureErrorNo, Builder, Runner from .local_executor import LocalExecutor logger = logging.getLogger('autotvm') -class HashMismatchError(ValueError): - """Raised when the code hash of a submitted config doesn't match that on the - measure side """ - pass - - -def request_remote(device_key, tracker_addr=None, priority=1, timeout=60): - """request a remote session +class BuildResult(namedtuple("BuildResult", ('filename', 'arg_info', 'error', 'time_cost'))): + """ + Stores all the necessary inputs for a measurement. Parameters ---------- - device_key: string - device key of registered device in tracker - tracker_addr: Tuple(string, int), optional - The address of rpc tracker in (host, port) format. - If is none, will use environment variable "TVM_TRACKER_HOST" - and "TVM_TRACKER_PORT" - priority: int, optional - The priority of this request, larger is more prior - timeout: float, optional - The timeout of this session (units: seconds) - - Returns - ------ - session: RPCSession + filename : str + The filename of generated library + arg_info : Tuple + The shape and dtype information of tvm tensor arguments + error : Exception + The error happens during compilation. + time_cost : float + The time cost of building """ - # connect to the tracker - if tracker_addr: - host = tracker_addr[0] or os.environ['TVM_TRACKER_HOST'] - port = tracker_addr[1] or int(os.environ['TVM_TRACKER_PORT']) - else: - host = os.environ['TVM_TRACKER_HOST'] - port = int(os.environ['TVM_TRACKER_PORT']) - tracker = _rpc.connect_tracker(host, port) - remote = tracker.request(device_key, priority=priority, - session_timeout=timeout) - return remote - -def check_remote(target, device_key, tracker_addr=None, priority=2, timeout=10): - """ - Check the availability of a remote device +class LocalBuilder(Builder): + """Run compilation on local machine Parameters ---------- - target: Target - The wanted compilation target - device_key: string - device key of registered device in tracker - tracker_addr: Tuple(string, int), optional - The address of rpc tracker in (host, port) format. - If is none, will use environment variable "TVM_TRACKER_HOST" - and "TVM_TRACKER_PORT" - priority: int, optional - The priority of this request, larger is more prior - timeout: float, optional - The timeout of this check (units: seconds). - If time is out, a RuntimeError will be raised. + timeout: float + The timeout of a compilation + n_parallel: int + The number of tasks run in parallel. "None" will use all cpu cores + build_func: callable or str + If is 'default', use default build function + If is 'ndk', use function for android ndk + If is callable, use it as custom build function """ - def _check(): - remote = request_remote(device_key, tracker_addr, priority) - remote.context(str(target)) - t = threading.Thread(target=_check,) - t.start() - t.join(timeout) - return not t.is_alive() + def __init__(self, timeout=10, n_parallel=None, build_func='default'): + super(LocalBuilder, self).__init__(timeout, n_parallel) -def create_measure_batch(task, option): - """Get a standard measure_batch function. 
+        if isinstance(build_func, str):
+            if build_func == 'default':
+                build_func = default_build_func
+            elif build_func == 'ndk':
+                build_func = android_ndk_build_func
+            else:
+                raise ValueError("Invalid build_func: " + build_func)
+
+        self.build_func = build_func
+        self.tmp_dir = tempfile.mkdtemp()
+        self.executor = LocalExecutor(timeout=timeout)
+
+    def build(self, measure_inputs):
+        results = []
+
+        for i in range(0, len(measure_inputs), self.n_parallel):
+            futures = []
+            for inp in measure_inputs[i:i + self.n_parallel]:
+                ret = self.executor.submit(self.build_func,
+                                           inp,
+                                           self.tmp_dir,
+                                           **self.build_kwargs)
+                futures.append(ret)
+
+            for future in futures:
+                res = future.get()
+
+                if isinstance(res, Exception):
+                    # timeout or executor error, return MeasureResult directly
+                    results.append(MeasureResult((res,), MeasureErrorNo.BUILD_TIMEOUT,
+                                                 self.timeout, time.time()))
+                elif res.error is not None:
+                    # instantiation error
+                    if isinstance(res.error, InstantiationError):
+                        results.append(MeasureResult((res.error,),
+                                                     MeasureErrorNo.INSTANTIATION_ERROR,
+                                                     res.time_cost, time.time()))
+                    else:
+                        if "InstantiationError" in str(res.error):
+                            msg = str(res.error)
+                            try:
+                                msg = msg.split('\n')[-2].split(": ")[1]
+                            except Exception:  # pylint: disable=broad-except
+                                pass
+                            results.append(MeasureResult((InstantiationError(msg),),
+                                                         MeasureErrorNo.INSTANTIATION_ERROR,
+                                                         res.time_cost, time.time()))
+                        else:  # tvm error
+                            results.append(MeasureResult((res.error,),
+                                                         MeasureErrorNo.COMPILE_HOST,
+                                                         res.time_cost, time.time()))
+                else:
+                    # return BuildResult
+                    results.append(res)
+
+        return results
+
+    def __del__(self):
+        shutil.rmtree(self.tmp_dir)
+
+
+class RPCRunner(Runner):
+    """Run generated code on remote devices.
+    This runner will ask an RPC Tracker to get devices for measurement.
+
+    Parameters
+    ----------
+    timeout: float
+        The timeout of a single measurement run
+    n_parallel: int
+        The number of tasks run in parallel. "None" will use all cpu cores
+    key: str
+        The key of the device registered in the tracker
+    host: str
+        The host address of RPC Tracker
+    port: int
+        The port of RPC Tracker
+    number : int, optional
+        Number of times to run the generated code for taking the average
+    repeat : int, optional
+        Number of times to repeat the measurement.
+        In total, the generated code will be run (1 + number x repeat) times,
+        where the first one is warm up. The returned result contains `repeat` costs,
+        each of which is the average of `number` test run.
+    min_repeat_ms : float, optional
+        Minimum duration of a timer measurement in milliseconds.
+        When the run time of a measurement trial falls below this time, the
+        `number` parameter will be automatically increased.
+        Set this to improve the accuracy of perf measurement, e.g., when timers
+        are not precise enough to capture short-running tasks. This parameter is
+        also critical when devices need a certain minimum running time to "warm
+        up," such as GPUs that need time to reach a performance power state.
+    cooldown_interval: float, optional
+        The cool down interval between two measurements.
+    check_correctness: bool, optional
+        Whether to check correctness after measurement. This will use llvm cpu target to
+        call your template and get the reference output.
+        This can work for TOPI templates, but may not work for your custom template.
""" - from ..database import filter_inputs + def __init__(self, + key, host, port, priority=1, + timeout=10, n_parallel=None, + number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1, + check_correctness=False): + super(RPCRunner, self).__init__(timeout, n_parallel) - measure_func = option['measure_func'] - number, repeat = option['number'], option['repeat'] - timeout, n_parallel, do_fork = option['timeout'], option['n_parallel'], option['do_fork'] - build_func = option['build_func'] - check_correctness = option['check_correctness'] - replay_db = option['replay_db'] + self.key = key + self.host = host + self.port = port + self.priority = priority + self.timeout = timeout - executor = LocalExecutor(timeout=timeout, do_fork=do_fork) + self.number = number + self.repeat = repeat + self.min_repeat_ms = min_repeat_ms + self.cur_number = number + + self.ref_input = None + self.ref_output = None + self.check_correctness = check_correctness + self.cooldown_interval = cooldown_interval + + self.executor = LocalExecutor() + + def set_task(self, task): + self.task = task + self.cur_number = self.number + + if check_remote(task.target, self.key, self.host, self.port): + logger.info("Get devices for measurement successfully!") + else: + raise RuntimeError("Cannot get remote devices from the tracker. " + "Please check the status of tracker by " + "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' " + "and make sure you have free devices on the queue status.") + + if self.check_correctness: + # use llvm cpu to generate a reference input/output + # this option works for tuning topi, but might not work for you custom op + with _target.create("llvm"): + s, arg_bufs = task.instantiate(task.config_space.get(0)) + self.ref_input = [np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype) + for x in arg_bufs] + func = build(s, arg_bufs, "llvm") + tvm_buf = [nd.array(x) for x in self.ref_input] + func(*tvm_buf) + self.ref_output = [x.asnumpy() for x in tvm_buf] + + def get_build_kwargs(self): + kwargs = {} + if 'cuda' in self.task.target.keys or 'opencl' in self.task.target.keys: + remote = request_remote(self.key, self.host, self.port) + ctx = remote.context(str(self.task.target), 0) + max_dims = ctx.max_thread_dimensions + kwargs['check_gpu'] = { + 'max_shared_memory_per_block': ctx.max_shared_memory_per_block, + 'max_threads_per_block': ctx.max_threads_per_block, + 'max_thread_x': max_dims[0], + 'max_thread_y': max_dims[1], + 'max_thread_z': max_dims[2], + } + + if 'cuda' in self.task.target.keys: + kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split('.')) + + return kwargs + + def run(self, measure_inputs, build_results): + results = [] + remote_args = (self.key, self.host, self.port, self.priority, self.timeout) + + for i in range(0, len(measure_inputs), self.n_parallel): + futures = [] + for measure_inp, build_res in zip(measure_inputs[i:i+self.n_parallel], + build_results[i:i+self.n_parallel]): + ret = self.executor.submit(run_through_rpc, + measure_inp, + build_res, + self.cur_number, + self.repeat, + self.cooldown_interval, + remote_args, + self.ref_input, + self.ref_output) + futures.append(ret) + + for future in futures: + res = future.get() + if isinstance(res, Exception): # executor error or timeout + results.append(MeasureResult((str(res),), MeasureErrorNo.RUN_TIMEOUT, + self.timeout, time.time())) + else: + results.append(res) + + # If some runs were too fast, do remeasure for them + # to meet the requirement of `min_repeat_ms` + remeasure = 
np.zeros((len(measure_inputs),), dtype=np.bool) + pre_number = next_number = self.cur_number + min_repeat_duration = self.min_repeat_ms / 1000.0 + for i, res in enumerate(results): + if res.error_no == MeasureErrorNo.NO_ERROR: + if np.mean(res.costs) * pre_number <= min_repeat_duration: + next_number = max(next_number, + int(np.ceil(min_repeat_duration / np.mean(res.costs)))) + remeasure[i] = True + + if pre_number != next_number: + self.cur_number = next_number + msg = "increasing number to %d" % self.cur_number + logger.info(msg) + + re_measure_inputs = [x for i, x in enumerate(measure_inputs) if remeasure[i]] + re_build_results = [x for i, x in enumerate(build_results) if remeasure[i]] + re_res = self.run(re_measure_inputs, re_build_results) + ct = 0 + for i, rerun in enumerate(remeasure): + if rerun: + results[i] = re_res[ct] + ct += 1 + + return results + +class LocalRunner(RPCRunner): + """Run generated code on local devices. + + Parameters + ---------- + timeout: float + The timeout of a compilation + number : int, optional + Number of times to do measurement for tasking average + repeat : int, optional + Number of times to repeat the measurement. + In total, the generated code will be run (1 + number x repeat) times, + where the first one is warm up. The returned result contains `repeat` costs, + each of which is the average of `number` test run. + min_repeat_ms : float, optional + Minimum duration of a timer measurement in milliseconds. + When the run time of a measurement trial falls below this time, the + `number` parameter will be automatically increased. + Set this to improve the accuracy of perf measurement, e.g., when timers + are not precise enough to capture short-running tasks. This parameter is + also critical when devices need a certain minimum running time to "warm + up," such as GPUs that need time to reach a performance power state. + cooldown_interval: float, optional + The cool down interval between two measurements. + check_correctness: bool, optional + Whether check correctness after measurement. This will use llvm cpu target to + call your template and get the reference output. + This can work for TOPI templates, but may not work for your custom template. + + Note + ---- + This is a "fake" local mode. We start a silent rpc tracker and rpc server + for the user. In this way we reuse timeout/isolation mechanism in RPC infrastructure. 
+ """ + def __init__(self, + timeout=10, + number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1, + check_correctness=False): + super(LocalRunner, self).__init__('', None, None, 0, + timeout=timeout, n_parallel=1, + number=number, repeat=repeat, + min_repeat_ms=min_repeat_ms, + cooldown_interval=cooldown_interval, + check_correctness=check_correctness) + self.tracker = None + self.server = None + + def set_task(self, task): + self.task = task - # convert convenient string to function object - attach_objects = None - if measure_func == 'local': - # start temporary rpc tracker and rpc server for the user from ...rpc.tracker import Tracker from ...rpc.server import Server @@ -133,360 +343,215 @@ def create_measure_batch(task, option): key=device_key, use_popen=True, silent=True, tracker_addr=(tracker.host, tracker.port)) + self.key = device_key + self.host = tracker.host + self.port = tracker.port - measure_func = rpc(device_key, tracker.host, tracker.port) - attach_objects = (server, tracker) - - build_kwargs = {} - if build_func == 'default': - build_func = default_build_func - if build_func == 'ndk': - build_func = default_build_func - build_kwargs['use_ndk'] = True - - # check the availability of remote devices - if hasattr(measure_func, 'rpc_info'): - rpc_info = measure_func.rpc_info - if check_remote(task.target, rpc_info['key'], (rpc_info['host'], rpc_info['port'])): - logger.info("Get devices for measurement successfully!") - else: - raise RuntimeError("Cannot get remote devices from the tracker. " - "Please check the status of tracker by " - "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' " - "and make sure you have free devices on the queue status.") - - # add device info of cuda and opencl target - if ('cuda' in task.target.keys or 'opencl' in task.target.keys) \ - and hasattr(measure_func, 'rpc_info'): - rpc_info = measure_func.rpc_info - add_gpu_target_info(task.target, rpc_info["key"], (rpc_info["host"], rpc_info["port"]), - build_kwargs) - - if check_correctness: - # use llvm cpu to generate a reference input/output - # this option works for tuning topi, but might not work for you custom op - with _target.create("llvm"): - s, arg_bufs = task.instantiate(task.config_space.get(0)) - ref_input = [np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype) - for x in arg_bufs] - func = build(s, arg_bufs, "llvm") - tvm_buf = [nd.array(x) for x in ref_input] - func(*tvm_buf) - ref_output = [x.asnumpy() for x in tvm_buf] - else: - ref_input = ref_output = None - - def measure_batch(measure_inputs): - """measure the time cost for a batch of configs in real machines""" - if replay_db is not None: - partial_results, measure_inputs = \ - filter_inputs(replay_db, measure_inputs, retry=False) - - # launch measure jobs in parallel - pack_size = getattr(measure_func, "pack_size", 1) # measure `pack_size` inputs in one job - futures = [] - for i in range(0, len(measure_inputs), pack_size): - input_pack = measure_inputs[i:i + pack_size] - ret = executor.submit( - measure_func, - input_pack, - build_func, - build_kwargs, - number, - repeat, - ref_input, - ref_output) - futures.append(ret) - - # transform results - results = [] - for future in futures: - result = future.get() - if isinstance(result, Exception): - tstamp = time.time() - results.extend([MeasureResult((result,), MeasureErrorNo.FLEET_ERROR, - timeout, tstamp)] * pack_size) - else: - results.extend(result) - - if replay_db is not None: - result_idx = 0 - for i in range(len(partial_results)): - if 
partial_results[i] is None: - partial_results[i] = results[result_idx] - result_idx += 1 - return partial_results - return results - - measure_batch.n_parallel = n_parallel - # attach server and tracker object to avoid them of being garbage-collected - measure_batch.attach_objects = attach_objects - return measure_batch + super(LocalRunner, self).set_task(task) + return server, tracker -def rpc(key, - host=None, - port=None, - priority=1, - session_timeout=60, - pack_size=1): +def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_option=None): + """Common part for building a configuration""" + target, task, config = measure_input + + with target: + s, args = task.instantiate(config) + + # check invalidity of template and code hash consistency + if not config.valid(): + raise InstantiationError(config.errors) + + opts = build_option or {} + if check_gpu: # Add verify pass to filter out invalid configs in advance. + opts["add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))] + if cuda_arch: + set_cuda_target_arch(cuda_arch) + + with build_config(**opts): + func = build(s, args, target_host=task.target_host) + return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args) + + +def default_build_func(measure_input, tmp_dir, **kwargs): """ - Create a standard measure_func which uses RPC Tracker for measurement. - This measure_func will request a device from the RPC Tracker and - upload the built binary library to that device for measurement. + Default build func. This can work for cuda, opencl, llvm backend Parameters ---------- - key: str - The registered key of the device in tracker. The tuner will request devices for - measurement by this key. - host: str, optional - The hostname of RPC Tracker. If not set, will use environment variable "TVM_TRACKER_HOST" - port: int, optional - The port of RPC Tracker. If not set, will use environment variable "TVM_TRACKER_PORT" - priority: int, optional - Priority of this task, used by scheduler in tracker - session_timeout: int, optional - Timeout of rpc session - pack_size: int, optional - The number of configs measure in one RPC session. - Usually this can be set to 1. If your device has high overhead to establish a - rpc connection, set this higher. + measure_input: MeasureInput + The input of measurement + tmp_dir: str + The path of temporary directory to export generated library """ - def fmeasure(input_pack, build_func, build_kwargs, number, repeat, ref_input, ref_output): - """Do measurement for a list of inputs inside a same RPC session. - - Parameters - ---------- - input_pack: List of MeasureInput - The inputs of measurement - build_func: callable - Function for building the code. see :any:`default_build_func` for example - build_kwargs: dict - Extra arguments for build_func - number : int, optional - Number of times to do the measurement for average - repeat : int, optional - Number of times to repeat the measurement. - In total, the generated code will be run (1 + number x repeat) times, - where the first one is warm up. The returned result contains `repeat` costs, - each of which is the average of `number` test run. 
- ref_input: List of numpy array - Reference input for correctness check - ref_output: List of numpy array - Reference output for correctness check - - Returns - ------- - results: List of MeasureResult - The results for input_pack - """ - remote_args = (key, (host, port), priority, session_timeout) - - res = _measure_common(input_pack, build_func, build_kwargs, number, repeat, - ref_input, ref_output, - remote_args) - return res - - fmeasure.pack_size = pack_size - fmeasure.rpc_info = {"key": key, "host": host, "port": port} - return fmeasure + tic = time.time() + try: + filename = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64)) + func, arg_info = _build_func_common(measure_input, **kwargs) + func.export_library(filename) + except Exception as e: # pylint: disable=broad-except + return BuildResult(None, None, e, time.time() - tic) + return BuildResult(filename, arg_info, None, time.time() - tic) -def _measure_common(input_pack, build_func, build_kwargs, number, repeat, - ref_input=None, ref_output=None, remote_args=None): - """Measure the time cost for a pack of inputs. - - (Note: A pack is a list of inputs which will be measured inside a same RPC session) +def android_ndk_build_func(measure_input, tmp_dir, **kwargs): + """ + Build function for android device using ndk. Parameters ---------- - input_pack : list of MeasureInput - The inputs we need to evaluate - build_func : function takes MeasureInput returns tuple of (time_func, ctx, args) - The build function used to build each input. - build_kwargs: Dict - The extra keyword arguments to build_func + measure_input: MeasureInput + The input of measurement + tmp_dir: str + The path of temporary directory to export generated library + """ + tic = time.time() + try: + filename = os.path.join(tmp_dir, "tmp_func_%0x.so" % getrandbits(64)) + func, arg_info = _build_func_common(measure_input, **kwargs) + func.export_library(filename, ndk.create_shared) + except Exception as e: # pylint: disable=broad-except + return BuildResult(None, None, e, time.time() - tic) + return BuildResult(filename, arg_info, None, time.time() - tic) + + +def run_through_rpc(measure_input, build_result, + number, repeat, cooldown_interval, + remote_args, ref_input=None, ref_output=None): + """Run a generated library through rpc + + Parameters + ---------- + measure_input: MeasureInput + The raw measure input + build_result: BuildResult + The result returned from Builder. This contains the path to the generated library. number : int, optional - Number of times to do the measurement for average + Number of times to do measurement for tasking average repeat : int, optional Number of times to repeat the measurement. In total, the generated code will be run (1 + number x repeat) times, where the first one is warm up. The returned result contains `repeat` costs, each of which is the average of `number` test run. - ref_input: Array of np.ndarray, optional - Reference input for checking correctness - ref_output: Array of np.ndarray, optional - Reference output for checking correctness - remote_args: Tuple, optional - The arguments to request_remote. If is not None, will use remote rpc devices. - - Returns - ------- - res_pack : Array of MeasureResult - The list of results of measurement. 
+ cooldown_interval: float + The cool down interval between two measurements + remote_args: Tuple + The argument for request_remote + ref_input: List of np.ndarray + The reference input used for checking correctness + ref_output: List of np.ndarray + The reference output used for checking correctness """ - res_pack = [] - tmp_dir = util.tempdir() if remote_args else None - assert len(input_pack) == 1, "Only supports input_pack == 1 for now" + if isinstance(build_result, MeasureResult): + return build_result - for inp in input_pack: - tic = time.time() + tic = time.time() + errno = MeasureErrorNo.NO_ERROR + try: + # upload built module + remote = request_remote(*remote_args) + remote.upload(build_result.filename) + func = remote.load_module(os.path.split(build_result.filename)[1]) + ctx = remote.context(str(measure_input.target), 0) + time_f = func.time_evaluator( + func.entry_name, ctx, number=number, repeat=repeat) - # build function - try: - func, arg_bufs, filename = build_func(inp, tmp_dir, **build_kwargs) - except TVMError as exc: - tstamp = time.time() - msg = str(exc) - if "Stack trace returned" in msg: - msg = msg[:msg.index("Stack trace returned")] - if "InstantiationError" in msg: - try: - msg = msg.split('\n')[-2].split(": ")[1] - except Exception: # pylint: disable=broad-except - pass - res_pack.append(MeasureResult((InstantiationError(msg),), - MeasureErrorNo.INSTANTIATION_ERROR, - tstamp - tic, tstamp)) - else: - res_pack.append(MeasureResult((RuntimeError(msg),), - MeasureErrorNo.COMPILE_HOST, - tstamp - tic, tstamp)) - continue - except InstantiationError as e: - tstamp = time.time() - res_pack.append(MeasureResult((InstantiationError(str(e)),), - MeasureErrorNo.INSTANTIATION_ERROR, - tstamp - tic, tstamp)) - continue + # set input + if ref_input: + args = [nd.array(x, ctx=ctx) for x in ref_input] + else: + args = [nd.empty(x[0], dtype=x[1], ctx=ctx) for x in build_result.arg_info] - # measure time - errno = MeasureErrorNo.NO_ERROR - try: - # upload built module - if remote_args: - remote = request_remote(*remote_args) - remote.upload(tmp_dir.relpath(filename)) - func = remote.load_module(filename) - ctx = remote.context(str(inp.target), 0) - time_f = func.time_evaluator( - func.entry_name, ctx, number=number, repeat=repeat) - else: - ctx = context(str(inp.target), 0) - time_f = func.time_evaluator( - func.entry_name, ctx, number=number, repeat=repeat) + costs = time_f(*args).results + if len(costs) > 2: # remove largest and smallest value to reduce variance + costs = list(costs) + costs.sort() + costs = tuple(costs[1:-1]) - # set input - if ref_input: - args = [nd.array(x, ctx=ctx) for x in ref_input] - else: - args = [nd.empty(get_const_tuple(x.shape), dtype=x.dtype, ctx=ctx) - for x in arg_bufs] - - costs = time_f(*args).results - if len(costs) > 2: # remove largest and smallest value to reduce variance - costs = list(costs) - costs.sort() - costs = tuple(costs[1:-1]) - - # check correctness of output - if ref_output: - for expected, real in zip(ref_output, args): - if not np.allclose(expected, real.asnumpy(), rtol=1e-4): - logger.warning("Wrong Answer!") - errno = MeasureErrorNo.WRONG_ANSWER - except TVMError as exc: - msg = str(exc) - if "Stack trace returned" in msg: - msg = msg[:msg.index("Stack trace returned")] - if "CUDA Source" in msg: - msg = msg[:msg.index("CUDA Source")] - costs = (RuntimeError(msg),) - errno = MeasureErrorNo.RUNTIME_DEVICE - tstamp = time.time() - res_pack.append(MeasureResult(costs, errno, tstamp - tic, tstamp)) - return res_pack + # check 
correctness of output + if ref_output: + for expected, real in zip(ref_output, args): + if not np.allclose(expected, real.asnumpy(), rtol=1e-4): + logger.warning("Wrong Answer!") + errno = MeasureErrorNo.WRONG_ANSWER + except TVMError as exc: + msg = str(exc) + if "Stack trace returned" in msg: + msg = msg[:msg.index("Stack trace returned")] + if "CUDA Source" in msg: + msg = msg[:msg.index("CUDA Source")] + costs = (RuntimeError(msg[:1024]),) + errno = MeasureErrorNo.RUNTIME_DEVICE + tstamp = time.time() + time.sleep(cooldown_interval) + return MeasureResult(costs, errno, tstamp - tic + build_result.time_cost, tstamp) -def default_build_func(inp, tmp_dir=None, **kwargs): - """Build function module. Exception will be raised when any error occurs +def request_remote(device_key, host=None, port=None, priority=1, timeout=60): + """Request a remote session Parameters ---------- - inp: MeasureInput - The input of this measurement - tmp_dir: tvm.contrib.util.TempDirectory, optional - The temporary directory for exporting built binary library. - If is not None (in RPC mode), the library in this directory will be uploaded to - remote devices. - kwargs: Dict, optional - Other extra arguments + device_key: string + The device key of registered device in tracker + host: host, optional + The host address of rpc tracker. + If is none, will use environment variable "TVM_TRACKER_HOST" + port: int, optional + The port of rpc tracker. + If is none, will use environment variable "TVM_TRACKER_PORT" + priority: int, optional + The priority of this request, larger is more prior + timeout: float, optional + The timeout of this session (units: second) + + Returns + ------ + session: RPCSession + """ + # connect to the tracker + host = host or os.environ['TVM_TRACKER_HOST'] + port = port or int(os.environ['TVM_TRACKER_PORT']) + + tracker = _rpc.connect_tracker(host, port) + remote = tracker.request(device_key, priority=priority, + session_timeout=timeout) + return remote + + +def check_remote(target, device_key, host=None, port=None, priority=2, timeout=10): + """ + Check the availability of a remote device + + Parameters + ---------- + target: Target + The wanted compilation target + device_key: string + device key of registered device in tracker + host: host, optional + The host address of rpc tracker. + If is none, will use environment variable "TVM_TRACKER_HOST" + port: int, optional + The port address of rpc tracker. + If is none, will use environment variable "TVM_TRACKER_PORT" + priority: int, optional + The priority of this request, larger is more prior + timeout: float, optional + The timeout of this check (units: seconds). Returns ------- - func: Function - TVM built function. Typically this is the return value of tvm.build. - args: Array of Buffer or Tensor - The argument list for the function. Typically this is the second argument of tvm.build. - filename: str - The filename of the output build library + available: bool + True if can find available device """ - # build function - with inp.target: - s, args = inp.task.instantiate(inp.config) - - # check invalidity of template and code hash consistency - if not inp.config.valid(): - raise InstantiationError(inp.config.errors) - code_hash = getattr(s, 'code_hash', None) - if inp.config.code_hash != code_hash: - raise HashMismatchError('got {0:s}, expected {1:s}' - .format(str(inp.config.code_hash), str(code_hash))) - - opts = {} - if "check_gpu" in kwargs: # Add verify pass to filter out invalid configs in advance. 
- opts["add_lower_pass"] = [(2, gpu_verify_pass(**kwargs['check_gpu']))] - if 'cuda_arch' in kwargs: - set_cuda_target_arch(kwargs['cuda_arch']) - - with build_config(**opts): - func = build(s, args, target_host=inp.task.target_host) - - # export library to temp directory - if tmp_dir: - if kwargs.get('use_ndk', False): # for Android NDK - filename = "tmp_func_%0x.so" % getrandbits(64) - func.export_library(tmp_dir.relpath(filename), ndk.create_shared) - else: - filename = "tmp_func_%0x.tar" % getrandbits(64) - func.export_library(tmp_dir.relpath(filename)) - else: - filename = None - - return func, args, filename - - -def add_gpu_target_info(target, device_key, rpc_tracker_addr, kwargs): - """Add device info for gpu target. - The info will be used to check the validity of generated code.""" - remote = request_remote(device_key, rpc_tracker_addr) - ctx = remote.context(str(target), 0) - max_dims = ctx.max_thread_dimensions - kwargs['check_gpu'] = { - 'max_shared_memory_per_block': ctx.max_shared_memory_per_block, - 'max_threads_per_block': ctx.max_threads_per_block, - 'max_thread_x': max_dims[0], - 'max_thread_y': max_dims[1], - 'max_thread_z': max_dims[2], - } - - if 'cuda' in target.keys: - kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split('.')) - -def set_cuda_target_arch(arch): - """set target architecture of nvcc compiler""" - AutotvmGlobalScope.current.cuda_target_arch = arch + def _check(): + remote = request_remote(device_key, host, port, priority) + remote.context(str(target)) + t = threading.Thread(target=_check,) + t.start() + t.join(timeout) + return not t.is_alive() @register_func @@ -496,6 +561,17 @@ def tvm_callback_cuda_compile(code): return ptx +def set_cuda_target_arch(arch): + """set target architecture of nvcc compiler + + Parameters + ---------- + arch: str + The argument of nvcc -arch. (e.g. "sm_51", "sm_62") + """ + AutotvmGlobalScope.current.cuda_target_arch = arch + + def gpu_verify_pass(**kwargs): """Verify the validity of a gpu kernel. This pass will check memory usage and number of threads per block. 
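
The refactored measurement API above is driven through measure_option(builder, runner) instead of the old measure_func arguments. Below is a minimal usage sketch, not part of the patch: the matmul template mirrors the sample template in tests/python/unittest/test_autotvm_common.py, and the tile splits, trial count, and log file name are illustrative placeholders.

    import tvm
    from tvm import autotvm

    @autotvm.template
    def matmul(N, L, M, dtype):
        A = tvm.placeholder((N, L), name='A', dtype=dtype)
        B = tvm.placeholder((L, M), name='B', dtype=dtype)

        k = tvm.reduce_axis((0, L), name='k')
        C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
        s = tvm.create_schedule(C.op)

        # define a small tuning space over the two spatial axes
        y, x = s[C].op.axis
        k = s[C].op.reduce_axis[0]
        cfg = autotvm.get_config()
        cfg.define_split("tile_y", y, num_outputs=2)
        cfg.define_split("tile_x", x, num_outputs=2)
        yo, yi = cfg["tile_y"].apply(s, C, y)
        xo, xi = cfg["tile_x"].apply(s, C, x)
        s[C].reorder(yo, xo, k, yi, xi)
        return s, [A, B, C]

    task = autotvm.task.create(matmul, args=(128, 128, 128, 'float32'), target='llvm')

    # LocalBuilder compiles candidates on the local machine;
    # LocalRunner starts a silent RPC tracker/server and runs them locally
    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(timeout=10),
        runner=autotvm.LocalRunner(number=4, repeat=3, min_repeat_ms=100, timeout=4))

    tuner = autotvm.tuner.RandomTuner(task)
    tuner.tune(n_trial=10, measure_option=measure_option,
               callbacks=[autotvm.callback.log_to_file('matmul.log')])
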
diff --git a/python/tvm/autotvm/tuner/ga_tuner.py b/python/tvm/autotvm/tuner/ga_tuner.py index b9d900e4..1afaca73 100644 --- a/python/tvm/autotvm/tuner/ga_tuner.py +++ b/python/tvm/autotvm/tuner/ga_tuner.py @@ -22,7 +22,7 @@ class GATuner(Tuner): mutation_prob: float probability of mutation of a knob in a gene """ - def __init__(self, task, pop_size, elite_num=3, mutation_prob=0.1): + def __init__(self, task, pop_size=100, elite_num=3, mutation_prob=0.1): super(GATuner, self).__init__(task) # algorithm configurations diff --git a/python/tvm/autotvm/tuner/sa_model_optimizer.py b/python/tvm/autotvm/tuner/sa_model_optimizer.py index 1947c6dd..77c7e919 100644 --- a/python/tvm/autotvm/tuner/sa_model_optimizer.py +++ b/python/tvm/autotvm/tuner/sa_model_optimizer.py @@ -87,7 +87,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer): new_scores = model.predict(new_points) - ac_prob = np.exp((new_scores - scores) / (t + 1e-2)) + ac_prob = np.exp(np.minimum((new_scores - scores) / (t + 1e-5), 1)) ac_index = np.random.random(len(ac_prob)) < ac_prob points[ac_index] = new_points[ac_index] diff --git a/tests/python/integration/test_tuning.py b/tests/python/integration/test_tuning.py index 87da86a4..8e1b458a 100644 --- a/tests/python/integration/test_tuning.py +++ b/tests/python/integration/test_tuning.py @@ -103,34 +103,7 @@ def get_sample_task(target=tvm.target.cuda(), target_host=None): target=target, target_host=target_host) return task, target - -def test_task_tuner_without_measurement(): - """test task and tuner without measurement""" - task, target = get_sample_task() - - def custom_measure(input_pack, build_func, build_args, number, repeat, - ref_input, ref_output): - from tvm.autotvm import MeasureResult - - results = [] - for inp in input_pack: - tic = time.time() - # do nothing - time.sleep(0.001) - results.append(MeasureResult([time.time() - tic], 0, - time.time() - tic, time.time())) - return results - measure_option = autotvm.measure_option(custom_measure) - - logging.info("%s", task.config_space) - - # new tuner and recorder - for tuner_class in [autotvm.tuner.RandomTuner, autotvm.tuner.GridSearchTuner]: - tuner = tuner_class(task) - tuner.tune(n_trial=10, measure_option=measure_option) - assert tuner.best_flops > 1 - -def test_tuning_with_measure(): +def test_tuning(): def check(target, target_host): ctx = tvm.context(target, 0) if not ctx.exist: @@ -141,12 +114,12 @@ def test_tuning_with_measure(): task, target = get_sample_task(target, target_host) logging.info("%s", task.config_space) - measure_option = autotvm.measure_option('local', - timeout=4, - number=2) + measure_option = autotvm.measure_option( + autotvm.LocalBuilder(), + autotvm.LocalRunner()) tuner = RandomTuner(task) - tuner.tune(n_trial=10, measure_option=measure_option) + tuner.tune(n_trial=20, measure_option=measure_option) check("cuda", None) check("opencl", None) @@ -155,6 +128,4 @@ if __name__ == "__main__": # only print log when invoked from main logging.basicConfig(level=logging.DEBUG) - test_task_tuner_without_measurement() - test_tuning_with_measure() - + test_tuning() diff --git a/tests/python/unittest/test_autotvm_common.py b/tests/python/unittest/test_autotvm_common.py index 3a6883f6..ed39c384 100644 --- a/tests/python/unittest/test_autotvm_common.py +++ b/tests/python/unittest/test_autotvm_common.py @@ -32,6 +32,25 @@ def matmul(N, L, M, dtype): return s, [A, B, C] +@autotvm.template +def bad_matmul(N, L, M, dtype): + if 'bad_device' in tvm.target.current_target().keys: + A = tvm.placeholder((N, L), name='A', 
dtype=dtype) + B = tvm.placeholder((L, M), name='B', dtype=dtype) + + k = tvm.reduce_axis((0, L-1), name='k') + C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C') + s = tvm.create_schedule(C.op) + + # schedule + y, x = s[C].op.axis + cfg = autotvm.get_config() + cfg.define_split("tile_y", y, num_outputs=2) + cfg.define_split("tile_x", x, num_outputs=2) + return s, [A, B, C] + + return matmul(N, L, M, dtype) + def get_sample_task(n=128): """return a sample task for testing""" target = tvm.target.create("llvm") diff --git a/tests/python/unittest/test_autotvm_database.py b/tests/python/unittest/test_autotvm_database.py index af4704d9..aa956f61 100644 --- a/tests/python/unittest/test_autotvm_database.py +++ b/tests/python/unittest/test_autotvm_database.py @@ -1,17 +1,11 @@ """Test database""" import copy import logging -import time -import numpy as np -import tvm - -from tvm import autotvm from tvm.autotvm import database -from tvm.autotvm.measure.measure_methods import HashMismatchError -from tvm.autotvm.record import encode, MeasureInput, MeasureResult +from tvm.autotvm.record import encode, MeasureResult -from test_autotvm_common import get_sample_task, get_sample_records +from test_autotvm_common import get_sample_records def test_save_load(): logging.info("test basic db load/save ...") @@ -35,66 +29,6 @@ def test_save_load(): TRIAL_LIMIT = 2 -def test_db_filter(): - logging.info("test db filter ...") - - # Pick a GPU target because there are more likely to be failures/invalid configs - task, target = get_sample_task() - - ctx = tvm.context(str(target)) - if not ctx.exist: - logging.warning("Skip this test because there is no supported device for test") - - batch_size = 2 - - measure_option = autotvm.measure_option('local', do_fork=False, timeout=2) - measure_batch = autotvm.measure.create_measure_batch(task, measure_option) - - ct = 0 - all_inputs = list() - all_results = list() - batches = list() - tuner = autotvm.tuner.RandomTuner(task) - while ct < TRIAL_LIMIT: - inputs = list() - for i in range(batch_size): - cfg = tuner.next_batch(1)[0] - inputs.append((MeasureInput(target, task, cfg))) - all_inputs.append(inputs[-1]) - batches.append(inputs) - results = measure_batch(inputs) - all_results += results - ct += 1 - - del measure_batch - - db = database.DummyDatabase() - db.flush() - - # First setting, memoize one input at a time, check that each is saved and replayed - measure_option = autotvm.measure_option('local', do_fork=False, timeout=2, replay_db=db) - measure_batch = autotvm.measure.create_measure_batch(task, measure_option) - - for i in range(len(all_inputs)+1): - db.flush() - for j in range(i): - db.save(all_inputs[j], all_results[j]) - - for k in range(len(batches)): - batch = batches[k] - batch_result = measure_batch(batch) - for l in range(batch_size): - all_idx = k*batch_size + l - assert batch_result[l] is not None - if all_idx < i: - assert encode(batch[l], batch_result[l]) == encode(batch[l], all_results[all_idx]), \ - "(no retry) EXPECTED MATCH, GOT MISMATCH" - else: - assert encode(batch[l], batch_result[l]) != encode(batch[l], all_results[all_idx]), \ - "(no retry) EXPECTED MISMATCH, GOT MATCH" - - del measure_batch - def test_db_hash(): logging.info("test db hash check ...") inp1, res1 = get_sample_records(1)[0] @@ -149,89 +83,8 @@ def test_db_latest_all(): assert encode(inp1, load4[1]) == encode(inp1, res2) assert encode(inp1, load4[2]) == encode(inp1, res3) -def test_db_save_replay(): - logging.info("test db save (from measure_batch) 
and replay ...") - _db = database.DummyDatabase() - _db.flush() - - task, target = get_sample_task() - - ctx = tvm.context(str(target)) - if not ctx.exist: - logging.warning("Skip this test because there is no supported device for test") - - measure_option = autotvm.measure_option('local', - do_fork=False, - timeout=2, - replay_db=_db) - measure_batch = autotvm.measure.create_measure_batch(task, measure_option) - - batch_size = 2 - - ct = 0 - all_inputs = list() - all_results = list() - batches = list() - tuner = autotvm.tuner.RandomTuner(task) - while ct < TRIAL_LIMIT: - inputs = list() - for i in range(batch_size): - cfg = tuner.next_batch(1)[0] - inputs.append((MeasureInput(target, task, cfg))) - all_inputs.append(inputs[-1]) - batches.append(inputs) - results = measure_batch(inputs) - all_results += results - ct += 1 - callback = autotvm.callback.log_to_database(_db) - callback(None, all_inputs, all_results) - - assert len(_db.db.keys()) == batch_size * TRIAL_LIMIT, \ - "%d vs %d" % (len(_db.db.keys()), batch_size * TRIAL_LIMIT) - - all_results_2 = measure_batch(all_inputs) - all_results_3 = measure_batch(all_inputs) - - for i in range(len(all_results)): - encr1 = encode(all_inputs[i], all_results[i]) - encr2 = encode(all_inputs[i], all_results_2[i]) - encr3 = encode(all_inputs[i], all_results_3[i]) - assert encr1 == encr2, "EXPECTED MATCH WITH SAVE REPLAY (first replay), got MISMATCH" - assert encr2 == encr3, "EXPECTED MATCH WITH SAVE REPLAY (second replay), got MISMATCH" - - del measure_batch - -def test_check_hashmismatch(): - logging.info("test hash mismatch check") - - task, target = get_sample_task() - - ctx = tvm.context(str(target)) - if not ctx.exist: - logging.warning("Skip this test because there is no supported device for test") - - measure_option = autotvm.measure_option('local', do_fork=False) - measure_batch = autotvm.measure.create_measure_batch(task, measure_option) - - inputs = list() - cfg = task.config_space.get(np.random.randint(len(task.config_space))) - # notvalidh is not a valid CRC32 hash (not hex) - cfg.code_hash = 'notvalidh' - inputs.append((MeasureInput(target, task, cfg))) - - try: - results = measure_batch(inputs) - assert False, "HashMismatchError should be raised" - except HashMismatchError: - pass - - del measure_batch - if __name__ == '__main__': logging.basicConfig(level=logging.INFO) test_save_load() - test_db_filter() test_db_hash() test_db_latest_all() - test_db_save_replay() - test_check_hashmismatch() diff --git a/tests/python/unittest/test_autotvm_measure.py b/tests/python/unittest/test_autotvm_measure.py new file mode 100644 index 00000000..e29cc2c5 --- /dev/null +++ b/tests/python/unittest/test_autotvm_measure.py @@ -0,0 +1,97 @@ +"""Test builder and runner""" +import logging +import time + +import numpy as np + +import tvm +from tvm import autotvm +from test_autotvm_common import get_sample_task, bad_matmul +from tvm.autotvm.measure.measure import Runner, MeasureResult, MeasureErrorNo + +def test_task_tuner_without_measurement(): + """test task and tuner without measurement""" + task, target = get_sample_task() + + class DummyRunner(Runner): + def __init__(self): + super(DummyRunner, self).__init__(1, 1) + + def run(self, measure_inputs, build_results): + return [MeasureResult((np.random.random(),), 0, 0.2, time.time()) + for _ in range(len(measure_inputs))] + + def get_build_kwargs(self): + return {} + + measure_option = autotvm.measure_option( + builder=autotvm.LocalBuilder(), + runner=DummyRunner() + ) + + logging.info("%s", 
task.config_space) + + for tuner_class in [autotvm.tuner.RandomTuner, + autotvm.tuner.GridSearchTuner, + autotvm.tuner.GATuner, + autotvm.tuner.XGBTuner]: + tuner = tuner_class(task) + tuner.tune(n_trial=10, measure_option=measure_option) + assert tuner.best_flops > 1 + +def test_check_correctness(): + task, target = get_sample_task() + + measure_option = autotvm.measure_option( + builder=autotvm.LocalBuilder(), + runner=autotvm.LocalRunner(check_correctness=True) + ) + + def _callback_correct(tuner, measure_inputs, measure_results): + for inp, res in zip(measure_inputs, measure_results): + assert res.error_no == 0 + + tuner = autotvm.tuner.RandomTuner(task) + tuner.tune(n_trial=2, measure_option=measure_option, + callbacks=[_callback_correct]) + + # a bad template + n = 128 + target = tvm.target.create("llvm -device=bad_device") + task = autotvm.task.create(bad_matmul, args=(n, n, n, 'float32'), target=target) + + def _callback_wrong(tuner, measure_inputs, measure_results): + for inp, res in zip(measure_inputs, measure_results): + assert res.error_no == MeasureErrorNo.WRONG_ANSWER + + tuner = autotvm.tuner.RandomTuner(task) + tuner.tune(n_trial=2, measure_option=measure_option, + callbacks=[_callback_wrong]) + + +def test_min_repeat_ms(): + task, target = get_sample_task() + + measure_option = autotvm.measure_option( + builder=autotvm.LocalBuilder(), + runner=autotvm.LocalRunner(number=1, min_repeat_ms=100) + ) + + def _callback(tuner, measure_inputs, measure_results): + for inp, res in zip(measure_inputs, measure_results): + if res.error_no != 0: + continue + + assert 1000 * np.mean(res.costs) * \ + measure_option['runner'].cur_number >= 100 + + tuner = autotvm.tuner.RandomTuner(task) + tuner.tune(n_trial=5, measure_option=measure_option, + callbacks=[_callback]) + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + + test_task_tuner_without_measurement() + test_check_correctness() + test_min_repeat_ms() diff --git a/topi/recipe/gemm/gemm_int8.py b/topi/recipe/gemm/gemm_int8.py index 61ef97d0..4cce2735 100644 --- a/topi/recipe/gemm/gemm_int8.py +++ b/topi/recipe/gemm/gemm_int8.py @@ -137,12 +137,15 @@ if __name__ == '__main__': print(task.config_space) measure_option = autotvm.measure_option( - measure_func='local', number=10, n_parallel=8, timeout=20) + builder=autotvm.LocalBuilder(), + runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4) + ) + log_name = 'gemm_int8.log' if DO_TUNING: tuner = autotvm.tuner.XGBTuner(task) tuner.tune(n_trial=1000, measure_option=measure_option, - callbacks=[autotvm.callback.log_to_file(log_name)]) + callbacks=[autotvm.callback.log_to_file(log_name)]) dispatch_context = autotvm.apply_history_best(log_name) best_config = dispatch_context.query(task.target, task.workload) diff --git a/tutorials/autotvm/tune_conv2d_cuda.py b/tutorials/autotvm/tune_conv2d_cuda.py index 3ff26a05..3cd63d03 100644 --- a/tutorials/autotvm/tune_conv2d_cuda.py +++ b/tutorials/autotvm/tune_conv2d_cuda.py @@ -164,12 +164,12 @@ task = autotvm.task.create(conv2d_no_batching, target='cuda') print(task.config_space) -# use local gpu, measure 5 times for every config to reduce variance -# run 8 parallel threads for compilation -measure_option = autotvm.measure_option('local', - number=5, - n_parallel=8, - timeout=20) +# use local gpu, measure 10 times for every config to reduce variance +# The timeout of compiling a program is 10 seconds, the timeout for running is 4 seconds +measure_option = autotvm.measure_option( + builder=autotvm.LocalBuilder(), + 
runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4) +) # begin tuning, log records to file `conv2d.log` tuner = autotvm.tuner.XGBTuner(task) diff --git a/tutorials/autotvm/tune_nnvm_arm.py b/tutorials/autotvm/tune_nnvm_arm.py index a080681f..8ab7bb2f 100644 --- a/tutorials/autotvm/tune_nnvm_arm.py +++ b/tutorials/autotvm/tune_nnvm_arm.py @@ -65,15 +65,20 @@ def get_network(name, batch_size): input_shape = (batch_size, 3, 224, 224) output_shape = (batch_size, 1000) - if name =='resnet-18': - net, params = nnvm.testing.resnet.get_workload(num_layers=18, batch_size=batch_size) - elif name =='mobilenet': + if "resnet" in name: + n_layer = int(name.split('-')[1]) + net, params = nnvm.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size) + elif "vgg" in name: + n_layer = int(name.split('-')[1]) + net, params = nnvm.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size) + elif name == 'mobilenet': net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size) - elif name =='squeezenet v1.1': + elif name == 'squeezenet_v1.1': net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1') - elif name =='vgg-16': - net, params = nnvm.testing.vgg.get_workload(num_layers=16, batch_size=batch_size) - elif name =='custom': + elif name == 'inception_v3': + input_shape = (1, 3, 299, 299) + net, params = nnvm.testing.inception_v3.get_workload(batch_size=batch_size) + elif name == 'custom': # an example for custom network from nnvm.testing import utils net = nnvm.sym.Variable('data') @@ -92,6 +97,7 @@ def get_network(name, batch_size): return net, params, input_shape, output_shape + ################################################################# # Start RPC Tracker # ----------------- @@ -158,6 +164,8 @@ def get_network(name, batch_size): # rk3399 2 2 0 # rpi3b 11 11 0 # ---------------------------------- +# +# You can register multiple devices to the tracker to accelerate the measurement in tuning. ########################################### # Set Tuning Options @@ -184,34 +192,30 @@ log_file = "%s.%s.log" % (device_key, network) dtype = 'float32' tuning_option = { - 'log_filename': log_file, + 'log_filename': log_file, - 'tuner': 'xgb', - 'n_trial': 1000, - 'early_stopping': 250, + 'tuner': 'xgb', + 'n_trial': 1000, + 'early_stopping': 400, - 'measure_option': autotvm.measure_option( - autotvm.measure.rpc(device_key, host='localhost', port=9190), - number=4, - n_parallel=1, - timeout=10, - build_func='ndk' if use_android else 'default', - ), + 'measure_option': autotvm.measure_option( + builder=autotvm.LocalBuilder( + build_func='ndk' if use_android else 'default'), + runner=autotvm.RPCRunner( + device_key, host='localhost', port=9190, + number=5, + timeout=4, + ), + ), } #################################################################### # # .. note:: How to set tuning options # -# In general, the default value provided here works well. It is the same -# value that we used to generate pre-tuned parameters. -# If you have multiple devices, you can set :code:`n_parallel` to -# the number of devices you have. (e.g. set it to 3 if you register 3 rk3399 -# boards to the tracker). +# In general, the default value provided here works well. # If you have large time budget, you can set :code:`n_trial`, :code:`early_stopping` larger, # which makes the tuning run longer. -# If your device is very slow or a single conv2d operator in your network has large FLOPs, -# consider setting timeout larger. 
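
The builder/runner split above replaces the old single `measure_func` argument: `LocalBuilder` compiles candidate kernels on the host (with `build_func='ndk'` for Android targets), while `RPCRunner` executes them on a board registered with the RPC tracker. Below is a minimal sketch of tuning a single extracted task with this interface; the `task` object, `device_key`, and log file name are placeholders, and the tracker is assumed to listen on localhost:9190 as in the tutorial:

    # Minimal sketch: tune one task over RPC with the new builder/runner interface.
    # `task`, `device_key`, and the log file name are placeholders for illustration.
    from tvm import autotvm

    def tune_one_task(task, device_key, log_file='one_task.log'):
        measure_option = autotvm.measure_option(
            builder=autotvm.LocalBuilder(build_func='default'),  # use 'ndk' for Android
            runner=autotvm.RPCRunner(
                device_key, host='localhost', port=9190,
                number=5, timeout=4),
        )
        tuner = autotvm.tuner.XGBTuner(task)
        tuner.tune(n_trial=1000,
                   measure_option=measure_option,
                   callbacks=[autotvm.callback.log_to_file(log_file)])
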
# ################################################################### @@ -219,7 +223,7 @@ tuning_option = { # ------------ # Now we can extract tuning tasks from the network and begin tuning. # Here we provide a simple utility function to tune a list of tasks. -# This function is just an initial implementation which tune them in sequential order. +# This function is just an initial implementation which tunes them in sequential order. # Later we will bring more sophisticated tuner scheduler. # You can skip the implementation of this function for this tutorial. @@ -236,7 +240,9 @@ def tune_tasks(tasks, try: # try winograd template tsk = autotvm.task.create(tasks[i].name, tasks[i].args, tasks[i].target, tasks[i].target_host, 'winograd') - tasks.append(tsk) + input_channel = tsk.workload[1][1] + if input_channel >= 64: + tasks[i] = tsk except Exception: pass @@ -245,8 +251,8 @@ def tune_tasks(tasks, if os.path.exists(tmp_log_file): os.remove(tmp_log_file) - for i, tsk in enumerate(tasks): - prefix = "[Task %2d/%2d] " %(i+1, len(tasks)) + for i, tsk in enumerate(reversed(tasks)): + prefix = "[Task %2d/%2d] " % (i+1, len(tasks)) # create tuner if tuner == 'xgb' or tuner == 'xgb-rank': @@ -280,7 +286,7 @@ def tune_tasks(tasks, ######################################################################## # Finally we launch tuning jobs and evaluate the end-to-end performance. -def tune_and_evaluate(): +def tune_and_evaluate(tuning_opt): # extract workloads from nnvm graph print("Extract tasks...") net, params, input_shape, out_shape = get_network(network, batch_size=1) @@ -290,19 +296,18 @@ def tune_and_evaluate(): # run tuning tasks print("Tuning...") - tune_tasks(tasks, **tuning_option) + tune_tasks(tasks, **tuning_opt) # compile kernels with history best records with autotvm.apply_history_best(log_file): print("Compile...") with nnvm.compiler.build_config(opt_level=2, add_pass=['AlterOpLayout']): graph, lib, params = nnvm.compiler.build( - net, target=target, - shape={'data': input_shape}, params=params, dtype=dtype) + net, target=target, shape={'data': input_shape}, params=params, dtype=dtype) # export library tmp = tempdir() - if tuning_option['measure_option']['build_func'] == 'ndk': # for android + if use_android: from tvm.contrib import ndk filename = "net.so" lib.export_library(tmp.relpath(filename), ndk.create_shared) @@ -312,8 +317,7 @@ def tune_and_evaluate(): # upload module to device print("Upload...") - remote = autotvm.measure.request_remote(device_key, - tracker_addr=('localhost', 9190), + remote = autotvm.measure.request_remote(device_key, 'localhost', 9190, timeout=10000) remote.upload(tmp.relpath(filename)) rlib = remote.load_module(filename) @@ -328,47 +332,44 @@ def tune_and_evaluate(): # evaluate print("Evaluate inference time cost...") - ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=10) - prof_res = np.array(ftimer().results) * 1000 # convert to millisecond + ftimer = module.module.time_evaluator("run", ctx, number=8, repeat=3) + prof_res = np.array(ftimer().results) * 1000 # convert to millisecond print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res))) # We do not run the tuning in our webpage server since it takes too long. # Uncomment the following line to run by yourself. -# tune_and_evaluate() + +# tune_and_evaluate(tuning_option) ###################################################################### # Sample Output # ------------- -# The tuning needs to train xgboost models and use them for prediction. 
+# The tuning needs to compile many programs and extract feature from them. # So a high performance CPU is recommended. -# It takes about 2 hours on a 32T AMD Ryzen CPU. -# One sample output is +# One sample output is listed below. +# It takes about 2 hours on a 32T AMD Ryzen Threadripper. # # .. code-block:: bash # # Extract tasks... # Tuning... -# [Task 1/16] Current/Best: 18.85/ 19.67 GFLOPS | Progress: (353/1000) | 387.05 s Done. -# [Task 2/16] Current/Best: 16.10/ 23.50 GFLOPS | Progress: (444/1000) | 379.99 s Done. -# [Task 3/16] Current/Best: 5.49/ 13.96 GFLOPS | Progress: (610/1000) | 485.87 s Done. -# [Task 4/16] Current/Best: 10.07/ 20.48 GFLOPS | Progress: (430/1000) | 391.66 s Done. -# [Task 5/16] Current/Best: 11.50/ 15.50 GFLOPS | Progress: (374/1000) | 356.03 s Done. -# [Task 6/16] Current/Best: 10.76/ 23.77 GFLOPS | Progress: (526/1000) | 526.42 s Done. -# [Task 7/16] Current/Best: 12.71/ 22.03 GFLOPS | Progress: (341/1000) | 322.96 s Done. -# [Task 8/16] Current/Best: 8.60/ 17.91 GFLOPS | Progress: (272/1000) | 236.08 s Done. -# [Task 9/16] Current/Best: 15.37/ 23.62 GFLOPS | Progress: (275/1000) | 275.18 s Done. -# [Task 10/16] Current/Best: 6.62/ 23.01 GFLOPS | Progress: (330/1000) | 315.02 s Done. -# [Task 11/16] Current/Best: 1.85/ 21.39 GFLOPS | Progress: (281/1000) | 239.19 s Done. -# [Task 12/16] Current/Best: 15.41/ 24.02 GFLOPS | Progress: (258/1000) | 270.82 s Done. -# [Task 13/16] Current/Best: 17.96/ 25.79 GFLOPS | Progress: (380/1000) | 738.29 s Done. -# [Task 14/16] Current/Best: 14.81/ 31.17 GFLOPS | Progress: (413/1000) | 799.21 s Done. -# [Task 15/16] Current/Best: 24.39/ 40.97 GFLOPS | Progress: (355/1000) | 700.25 s Done. -# [Task 16/16] Current/Best: 9.42/ 49.90 GFLOPS | Progress: (348/1000) | 603.84 s Done. +# [Task 1/12] Current/Best: 22.37/ 52.19 GFLOPS | Progress: (544/1000) | 406.59 s Done. +# [Task 2/12] Current/Best: 6.51/ 18.77 GFLOPS | Progress: (608/1000) | 325.05 s Done. +# [Task 3/12] Current/Best: 4.67/ 24.87 GFLOPS | Progress: (480/1000) | 372.31 s Done. +# [Task 4/12] Current/Best: 11.35/ 46.83 GFLOPS | Progress: (736/1000) | 602.39 s Done. +# [Task 5/12] Current/Best: 1.01/ 19.80 GFLOPS | Progress: (448/1000) | 262.16 s Done. +# [Task 6/12] Current/Best: 2.47/ 23.76 GFLOPS | Progress: (672/1000) | 563.85 s Done. +# [Task 7/12] Current/Best: 14.57/ 33.97 GFLOPS | Progress: (544/1000) | 465.15 s Done. +# [Task 8/12] Current/Best: 1.13/ 17.65 GFLOPS | Progress: (576/1000) | 365.08 s Done. +# [Task 9/12] Current/Best: 14.45/ 22.66 GFLOPS | Progress: (928/1000) | 724.25 s Done. +# [Task 10/12] Current/Best: 3.22/ 15.36 GFLOPS | Progress: (864/1000) | 564.27 s Done. +# [Task 11/12] Current/Best: 11.03/ 32.23 GFLOPS | Progress: (736/1000) | 635.15 s Done. +# [Task 12/12] Current/Best: 8.00/ 21.65 GFLOPS | Progress: (1000/1000) | 1111.81 s Done. # Compile... # Upload... # Evaluate inference time cost... 
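
Once tuning finishes, the best record for each workload lives in the log file, and `apply_history_best` (already used in the gemm recipe above) can be queried directly to inspect the chosen config. A short sketch, assuming `task` and `log_file` refer to objects defined earlier in the tutorial:

    # Sketch: query the best tuned config from the log produced by tuning.
    # `task` and `log_file` are assumed to come from the surrounding tutorial.
    from tvm import autotvm

    dispatch_context = autotvm.apply_history_best(log_file)
    best_config = dispatch_context.query(task.target, task.workload)
    print("Best config for this task:", best_config)
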
-# Mean inference time (std dev): 157.29 ms (1.74 ms)
+# Mean inference time (std dev): 162.59 ms (0.06 ms)
 ######################################################################
 #
diff --git a/tutorials/autotvm/tune_simple_template.py b/tutorials/autotvm/tune_simple_template.py
index 8d4aab0b..5b3ddaaf 100644
--- a/tutorials/autotvm/tune_simple_template.py
+++ b/tutorials/autotvm/tune_simple_template.py
@@ -271,9 +271,12 @@ print(task.config_space)
 logging.getLogger('autotvm').setLevel(logging.DEBUG)
 logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))
 
-# use local cpu, measure 5 times for every config to reduce variance
-measure_option = autotvm.measure_option('local',
-                                        number=5)
+# There are two steps in measuring a config: build and run.
+# By default, we use all CPU cores to compile the programs in parallel, then measure them sequentially.
+# We measure 5 times and take the average to reduce variance.
+measure_option = autotvm.measure_option(
+    builder='local',
+    runner=autotvm.LocalRunner(number=5))
 
 # begin tuning, log records to file `matmul.log`
 tuner = autotvm.tuner.RandomTuner(task)
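
The tutorial above passes `builder='local'`, while the other tutorials in this change construct explicit `LocalBuilder()`/`LocalRunner()` objects, which expose additional knobs such as `repeat`, `min_repeat_ms`, and `timeout`. A small sketch combining them with illustrative values, assuming `task` is the matmul tuning task created earlier in that tutorial:

    # Sketch: a local measure_option with explicit runner options; values are illustrative.
    # `task` is assumed to be the matmul tuning task created earlier in the tutorial.
    from tvm import autotvm

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=5, repeat=3, min_repeat_ms=100, timeout=4))

    tuner = autotvm.tuner.RandomTuner(task)
    tuner.tune(n_trial=10,
               measure_option=measure_option,
               callbacks=[autotvm.callback.log_to_file('matmul.log')])
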