[AUTOTVM] Core part of auto-tuning module (#1312)
This commit is contained in:
Родитель
7e7154f170
Коммит
6ea74d4119
|
@ -96,6 +96,7 @@ assign_source_group("Include" ${GROUP_INCLUDE})
|
|||
file(GLOB COMPILER_SRCS
|
||||
src/api/*.cc
|
||||
src/arithmetic/*.cc
|
||||
src/autotvm/*.cc
|
||||
src/codegen/*.cc
|
||||
src/codegen/stack_vm/*.cc
|
||||
src/lang/*.cc
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
tvm.autotvm
|
||||
-----------
|
||||
.. automodule:: tvm.autotvm
|
||||
|
||||
tvm.autotvm.measure
|
||||
~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: tvm.autotvm.measure.measure
|
||||
|
||||
.. autoclass:: tvm.autotvm.measure.MeasureInput
|
||||
:members:
|
||||
|
||||
.. autoclass:: tvm.autotvm.measure.MeasureResult
|
||||
:members:
|
||||
|
||||
.. autofunction:: tvm.autotvm.measure.measure_option
|
||||
|
||||
.. autofunction:: tvm.autotvm.measure.create_measure_batch
|
||||
|
||||
|
||||
tvm.autotvm.tuner
|
||||
~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: tvm.autotvm.tuner
|
||||
:members:
|
||||
|
||||
.. autoclass:: tvm.autotvm.tuner.Tuner
|
||||
:members:
|
||||
|
||||
.. autoclass:: tvm.autotvm.tuner.RandomTuner
|
||||
:members:
|
||||
:inherited-members:
|
||||
|
||||
.. autoclass:: tvm.autotvm.tuner.GridSearchTuner
|
||||
:members:
|
||||
:inherited-members:
|
||||
|
||||
.. autoclass:: tvm.autotvm.tuner.GATuner
|
||||
:members:
|
||||
:inherited-members:
|
||||
|
||||
.. autoclass:: tvm.autotvm.tuner.XGBTuner
|
||||
:members:
|
||||
:inherited-members:
|
||||
|
||||
.. automodule:: tvm.autotvm.tuner.callback
|
||||
:members:
|
||||
|
||||
tvm.autotvm.task
|
||||
~~~~~~~~~~~~~~~~
|
||||
.. automodule:: tvm.autotvm.task
|
||||
:members:
|
||||
|
||||
.. automodule:: tvm.autotvm.task.task
|
||||
:members:
|
||||
|
||||
.. automodule:: tvm.autotvm.task.space
|
||||
:members:
|
||||
|
||||
tvm.autotvm.record
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: tvm.autotvm.record
|
||||
:members:
|
|
@ -14,6 +14,7 @@ Python API
|
|||
ndarray
|
||||
container
|
||||
function
|
||||
autotvm
|
||||
graph_runtime
|
||||
rpc
|
||||
bridge
|
||||
|
|
|
@ -191,6 +191,7 @@ gallery_dirs = ["tutorials", "vta/tutorials"]
|
|||
subsection_order = ExplicitOrder(
|
||||
['../tutorials/language',
|
||||
'../tutorials/optimize',
|
||||
'../tutorials/autotvm',
|
||||
'../tutorials/vta',
|
||||
'../tutorials/topi',
|
||||
'../tutorials/deployment',
|
||||
|
|
|
@ -488,7 +488,7 @@ bool VerifyMemory(LoweredFunc func, int device_type);
|
|||
*
|
||||
* "max_local_memory_per_block": Total amount of local memory per block (in bytes).
|
||||
* "max_shared_memory_per_block": Total amount of shared memory per block (in bytes).
|
||||
* "max_thread_per_block": Maximum number of threads per block.
|
||||
* "max_threads_per_block": Maximum number of threads per block.
|
||||
* "max_thread_x": Maximum length of threadIdx.x.
|
||||
* "max_thread_y": Maximum length of threadIdx.y.
|
||||
* "max_thread_z": Maximum length of threadIdx.z.
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
"""The auto-tuning module of tvm
|
||||
|
||||
This module includes:
|
||||
|
||||
* Tuning space definition API
|
||||
|
||||
* Efficient auto-tuners
|
||||
|
||||
* Tuning result and database support
|
||||
|
||||
* Distributed measurement to scale up tuning
|
||||
"""
|
||||
|
||||
from . import database
|
||||
from . import feature
|
||||
from . import measure
|
||||
from . import record
|
||||
from . import task
|
||||
from . import tuner
|
||||
from . import util
|
||||
|
||||
# some shortcuts
|
||||
from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo
|
||||
from .tuner import callback
|
||||
from .task import template, get_config, create, ConfigSpace, ConfigEntity
|
||||
from .record import ApplyHistoryBest as apply_history_best
|
|
@ -0,0 +1,181 @@
|
|||
# pylint: disable=consider-using-enumerate,invalid-name
|
||||
"""
|
||||
Database of MeasureInput/MeasureResult pair.
|
||||
This can be used for replaying measurement.
|
||||
"""
|
||||
import os
|
||||
|
||||
from .record import encode, decode, measure_str_key
|
||||
|
||||
|
||||
class Database(object):
    """Abstract interface of a record database.

    Both methods raise NotImplementedError here; a concrete backend has to
    override them.
    """

    def load(self, inp, get_all=False):
        """Look up the saved result(s) for an input via its string key.

        Parameters
        ----------
        inp: MeasureInput
            The input that is translated into the lookup key
        get_all: bool, optional
            If True, return all matching results instead of only the latest one

        Returns
        -------
        rec: MeasureResult if previously saved, otherwise None
        """
        raise NotImplementedError()

    def save(self, inp, res, extend=False):
        """Store a result under the string key of an input.

        Parameters
        ----------
        inp: MeasureInput
            The input that is translated into the storage key
        res: MeasureResult
            The result to associate with the key
        extend: bool, optional
            Whether to extend existing MeasureResults if they exist
        """
        raise NotImplementedError()
|
||||
|
||||
|
||||
def filter_inputs(db, measure_inputs, retry=False):
    """Split a batch of measure inputs into already-saved and still-to-measure.

    Parameters
    ----------
    db: Database
        database object that is queried with `db.load(inp)`
    measure_inputs: Array of MeasureInput
        measure_inputs as expected in measure_batch
    retry: bool
        whether a saved result with a non-zero `error_no` counts as unsaved

    Returns
    -------
    partial_results: Array of MeasureResult
        one entry per input, where None denotes no usable saved result
    unsaved: Array of MeasureInput
        the inputs that still have to be measured, in their original order
    """
    partial_results = []
    unsaved = []
    for candidate in measure_inputs:
        saved = db.load(candidate)
        usable = saved is not None and not (retry and saved.error_no != 0)
        if usable:
            partial_results.append(saved)
            continue
        # keep a placeholder so that positions still line up with the inputs
        partial_results.append(None)
        unsaved.append(candidate)
    return partial_results, unsaved
|
||||
|
||||
class RedisDatabase(Database):
    """
    Redis version of record database

    All records that share one input key are stored in a single value,
    joined by MAGIC_SPLIT.
    """

    REDIS_PROD = 15
    REDIS_LOCA = 14
    REDIS_TEST = 13  # for unit test
    REDIS_NIGHT_TEMP = 12  # for nightly report (will be flushed after every workload)

    MAGIC_SPLIT = "$"

    def __init__(self, db_index=REDIS_PROD):
        import redis

        # the unit-test database lives on localhost; every other index uses
        # the host named by the TVM_FLEET_HOST environment variable
        if db_index == RedisDatabase.REDIS_TEST:
            host = 'localhost'
        else:
            host = os.environ.get('TVM_FLEET_HOST')
        self.db = redis.StrictRedis(host=host, port=6379, db=db_index)
        self.db_index = db_index

    def set(self, key, value):
        """Store `value` under `key` in the backing store."""
        self.db.set(key, value)

    def get(self, key):
        """Return the raw value stored under `key` (None if absent)."""
        return self.db.get(key)

    def load(self, inp, get_all=False):
        """Load the saved result(s) for `inp`.

        Returns the list of all decoded results when `get_all` is True,
        otherwise the result with the largest timestamp. Returns None when
        nothing is stored for the key.
        """
        current = self.get(measure_str_key(inp))
        if current is not None:
            current = str(current)
            records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)]
            results = [rec[1] for rec in records]
            if get_all:
                return results
            return max(results, key=lambda result: result.timestamp)
        return current

    def save(self, inp, res, extend=False):
        """Save `res` for `inp`.

        With `extend` set and an existing value present, the new record is
        appended to the stored ones; otherwise the stored value is replaced.
        """
        current = self.get(measure_str_key(inp))
        if not extend or current is None:
            self.set(measure_str_key(inp),
                     RedisDatabase.MAGIC_SPLIT.join([encode(inp, res)]))
        else:
            current = current.split(RedisDatabase.MAGIC_SPLIT)
            self.set(measure_str_key(inp),
                     RedisDatabase.MAGIC_SPLIT.join(current + [encode(inp, res)]))

    def filter(self, func):
        """
        Dump all of the records that are accepted by a filter function

        Parameters
        ----------
        func: callable
            The signature of the function is bool (MeasureInput, Array of MeasureResult)

        Returns
        -------
        list of records (inp, result) accepted by `func`, where result is
        the one with the latest timestamp for that input

        Examples
        --------
        get records for a target
        >>> db.filter(lambda inp, results: "cuda" in inp.target.keys)
        """
        matched_records = list()
        # may consider filtering in iterator in the future
        # NOTE: iterate over keys() explicitly; this works for a plain dict
        # (DummyDatabase) and for a client object exposing keys().
        for key in self.db.keys():
            current = self.get(key)
            try:
                records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)]
            except TypeError:  # got a badly formatted/old format record
                continue

            inps, results = zip(*records)
            inp = inps[0]
            if not func(inp, results):
                continue
            result = max(results, key=lambda res: res.timestamp)
            matched_records.append((inp, result))
        return matched_records

    def flush(self):
        """Remove every record of the selected database."""
        self.db.flushdb()
|
||||
|
||||
class DummyDatabase(RedisDatabase):
    """In-memory variant of the record database for testing.

    Records are kept in a python dictionary instead of a redis server.
    """

    def __init__(self):
        # pylint: disable=super-init-not-called
        self.db = dict()

    def set(self, key, value):
        """Store `value` under `key` in the dictionary."""
        self.db[key] = value

    def get(self, key):
        """Return the value stored under `key`, or None if it is absent."""
        return self.db.get(key)

    def flush(self):
        """Drop all records by rebinding to a fresh dictionary."""
        self.db = dict()
|
|
@ -0,0 +1,12 @@
|
|||
"""Global configuration/variable scope for autotvm"""
|
||||
|
||||
class AutotvmGlobalScope(object):
    """Holder of global autotvm settings.

    Creating an instance remembers the scope that was current before in
    `_old` and installs the new instance as `AutotvmGlobalScope.current`.
    """
    current = None

    def __init__(self):
        self.cuda_target_arch = None
        # remember the previous scope, then make this one the current scope
        self._old, AutotvmGlobalScope.current = AutotvmGlobalScope.current, self


GLOBAL_SCOPE = AutotvmGlobalScope()
|
|
@ -0,0 +1,181 @@
|
|||
# pylint: disable=invalid-name
|
||||
"""Extract feature of iter vars
|
||||
|
||||
There are two types of feature
|
||||
1) Itervar feature
|
||||
This feature is extracted based on loop variables.
|
||||
Different loop structures will result in different shapes of feature
|
||||
2) Curve sample feature (relation feature)
|
||||
This feature is extracted by sampling relation curve.
|
||||
This feature is invariant of loop structure.
|
||||
"""
|
||||
|
||||
import struct
|
||||
import numpy as np
|
||||
|
||||
from tvm import schedule, ir_pass, build_module, get_global_func, target as _target
|
||||
|
||||
def ana_lower(sch, args,
              binds=None,
              simple_mode=True):
    """Lower a schedule for analysis while keeping all axes in the IR.

    i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or
    inject virtual threads.

    Only `simple_mode=True` is supported; anything else fails the assertion
    at the end of the function.
    """
    bind_map, _ = build_module.get_binds(args, binds)
    normalized = sch.normalize()

    # Phase 0
    loop_bounds = schedule.InferBound(normalized)
    body = schedule.ScheduleOps(normalized, loop_bounds, True)
    body = ir_pass.StorageFlatten(body, bind_map, 64)
    body = ir_pass.CanonicalSimplify(body)

    assert simple_mode
    return body
|
||||
|
||||
# Bind the C++ feature extraction functions. If any of them is missing,
# bind every name to a stub that raises on first use instead of failing
# at import time.
try:
    _get_buffer_curve_sample_flatten = get_global_func(
        "autotvm.feature.GetCurveSampleFeatureFlatten")
    _get_itervar_feature = get_global_func(
        "autotvm.feature.GetItervarFeature")
    _get_itervar_feature_flatten = get_global_func(
        "autotvm.feature.GetItervarFeatureFlatten")
except ValueError as e:
    def raise_error(*args, **kwargs):  # pylint: disable=unused-argument
        """Stub used in place of the unavailable C++ functions."""
        raise RuntimeError("Cannot load autotvm c++ API")

    _get_buffer_curve_sample_flatten = raise_error
    _get_itervar_feature = raise_error
    _get_itervar_feature_flatten = raise_error
|
||||
|
||||
def get_itervar_feature(sch, args, take_log=False):
    """Get features of iter vars as nested python lists.

    Parameters
    ----------
    sch: tvm.schedule.Schedule
    args: Array of tvm.tensor.Tensor
        the buffer args for lower
    take_log: bool
        whether take log of numerical statistics

    Returns
    -------
    features of every axis in the IR, see doc/features.md for detail
    """
    lowered = ana_lower(sch, args, simple_mode=True)
    raw = _get_itervar_feature(lowered, take_log)

    # convert tvm node to python type
    converted = []
    for axis in raw:
        # the first entry of a row is [key, var]: only the key is unwrapped
        entry = [[axis[0][0].value, axis[0][1]]]
        # remaining entries are [key, number, number, ...]: unwrap all of them
        for pair in axis[1:]:
            entry.append([pair[0].value] + [num.value for num in pair[1:]])
        converted.append(entry)
    return converted
|
||||
|
||||
def flatten_itervar_feature(fea):
    """Flatten features into a one-dimensional feature vector.

    Parameters
    ----------
    fea: list
        return value of get_itervar_feature

    Returns
    -------
    flatten_feature: np.ndarray
        one-dimensional vector
    """
    # skip the leading entry of every axis and the key of every pair,
    # keeping only the numerical part
    chunks = [pair[1:] for axis in fea for pair in axis[1:]]
    return np.concatenate(chunks)
|
||||
|
||||
def get_itervar_feature_flatten(sch, args, take_log=True):
    """Get flatten features of iter vars.

    this is equivalent to get_itervar_feature + flatten_itervar_feature, but much faster.

    Parameters
    ----------
    sch: tvm.schedule.Schedule
    args: Array of tvm.tensor.Tensor
        the buffer args for lower
    take_log: bool
        whether take log of numerical statistics

    Returns
    -------
    flatten_feature: tuple of float
        one-dimensional vector, unpacked from the byte buffer returned by the
        C++ function (4 bytes per value)
    """
    lowered = ana_lower(sch, args, simple_mode=True)
    packed = _get_itervar_feature_flatten(lowered, take_log)
    n_values = len(packed) // 4
    return struct.unpack('%df' % n_values, packed)
|
||||
|
||||
def get_flatten_name(fea):
    """ Get names of feature after flatten.

    Parameters
    ----------
    fea: list or str
        return value of get_itervar_feature or a line of logfile

    Returns
    -------
    feature_names: Array of str
        one name per flattened value, in the form
        "f<index>.<var name>.<key>.<value name>"
    """

    feature_name = {
        "_attr_": ["length", "nest_level", "topdown", "bottomup"] +
                  ["ann_%d" % i for i in range(20)],
        "_arith_": ["add", "mul", "div"],
        "buf_touch": ["stride", "mod", "count", "reuse", "T_count", "T_reuse"],
    }

    if isinstance(fea, str):
        from .record import decode
        # flatten line to feature
        line = fea
        inp, _ = decode(line)
        target = _target.create(inp.target)
        with target:
            # MeasureInput carries the tunable template in its `task` field
            # (fields are target, task, config); it has no `template` field.
            s, args = inp.task.instantiate(inp.config)
        fea = get_itervar_feature(s, args)

    names = []
    ct = 0  # running index over all flattened values
    for row in fea:
        var_name = str(row[0][1])
        for pair in row[1:]:
            key = pair[0]
            # every key that is not a known category is treated as the name
            # of a touched buffer
            name_list = feature_name.get(key, feature_name["buf_touch"])

            for i in range(len(pair[1:])):
                names.append(".".join(["f%d" % ct, var_name, key, name_list[i]]))
                ct += 1
    return names
|
||||
|
||||
|
||||
def get_buffer_curve_sample_flatten(sch, args, sample_n=30):
    """Get flatten curve sample feature (relation feature).

    Parameters
    ----------
    sch: tvm.schedule.Schedule
    args: Array of tvm.tensor.Tensor
        the buffer args for lower
    sample_n: int
        number of sample points along one dimension

    Returns
    -------
    flatten_feature: tuple of float
        one-dimensional vector, unpacked from the byte buffer returned by the
        C++ function (4 bytes per value)
    """
    lowered = ana_lower(sch, args, simple_mode=True)
    packed = _get_buffer_curve_sample_flatten(lowered, sample_n, False)
    n_values = len(packed) // 4
    return struct.unpack('%df' % n_values, packed)
|
|
@ -0,0 +1,8 @@
|
|||
"""Distributed executor infrastructure to scale up the tuning"""
|
||||
|
||||
from .measure import MeasureInput, MeasureResult, MeasureErrorNo
|
||||
from .measure import create_measure_batch, measure_option
|
||||
from .measure_methods import request_remote
|
||||
|
||||
from .local_executor import LocalExecutor
|
||||
from .executor import Future, Executor
|
|
@ -0,0 +1,83 @@
|
|||
""" Abstraction for asynchronous job execution """
|
||||
|
||||
class Executor(object):
    """Abstract executor interface for asynchronous job submission.

    A job is submitted as a function plus its arguments; the executor hands
    back a Future object for it.
    """
    # timeout for jobs that may hang
    DEFAULT_TIMEOUT = 60

    def submit(self, func, *args, **kwargs):
        """Hand a task (function, arguments) over to the executor.

        Parameters
        ----------
        func : callable
            function to be run by a worker
        args : list or tuple, optional
            positional arguments passed to the function
        kwargs : dict, optional
            keyword arguments passed to the function

        Returns
        -------
        future : Future
            Future object wrapping the task which can be used to
            collect the task's result.
        """
        raise NotImplementedError()
|
||||
|
||||
|
||||
class Future(object):
    """Base class of the future object.

    Executor implementations can return objects of a subclass of this.
    A future encapsulates the asynchronous execution of a task that was
    submitted to another thread or another worker.

    It stores the state of the task: it can be polled with `done`, or the
    result can be retrieved with the blocking call `get`.
    """

    def done(self):
        """Return True if job was successfully cancelled or finished running."""
        raise NotImplementedError()

    def get(self, timeout=None):
        """Get the result. This will block until the result is available.

        Parameters
        ----------
        timeout : int or float, optional
            Maximum number of seconds to wait before it times out.
            If not specified, block until the result is available.

        Returns
        -------
        result : Any
            The result returned by the submitted function.

        Raises
        ------
        TimeoutError : if the call times out.
        """
        raise NotImplementedError()
|
||||
|
||||
class FutureError(RuntimeError):
    """Common base class of the errors reported through a future."""


# pylint:disable=redefined-builtin
class TimeoutError(FutureError):
    """Error raised when a task runs into its timeout."""


class ExecutionError(FutureError):
    """Error raised when the execution behind a future crashes or fails."""
|
|
@ -0,0 +1,131 @@
|
|||
"""Local based implementation of the executor using multiprocessing"""
|
||||
|
||||
import signal
|
||||
|
||||
from multiprocessing import Process, Queue
|
||||
try:
|
||||
from queue import Empty
|
||||
except ImportError:
|
||||
from Queue import Empty
|
||||
|
||||
import psutil
|
||||
|
||||
from . import executor
|
||||
|
||||
|
||||
def kill_child_processes(parent_pid, sig=signal.SIGTERM):
    """Send a signal to all child processes of a process, recursively.

    Parameters
    ----------
    parent_pid: int
        pid of the process whose descendants are signalled; the process
        itself does not receive the signal
    sig: int, optional
        the signal to send, SIGTERM by default

    Note
    ----
    A parent or child that no longer exists is silently ignored.
    """
    try:
        parent = psutil.Process(parent_pid)
    except psutil.NoSuchProcess:
        return
    for process in parent.children(recursive=True):
        try:
            process.send_signal(sig)
        except psutil.NoSuchProcess:
            # this child is already gone; the remaining ones still have to
            # be signalled, so do not stop here
            continue
|
||||
|
||||
def _execute_func(func, queue, args, kwargs):
    """Call `func(*args, **kwargs)` and put its outcome into `queue`.

    The outcome is the return value, or the exception object itself if the
    call raised an Exception.
    """
    try:
        outcome = func(*args, **kwargs)
    except Exception as err:  # pylint: disable=broad-except
        outcome = err
    queue.put(outcome)
|
||||
|
||||
def timeout_monitor(queue, timeout, func, args, kwargs):
    """Run a function call in a child process with a time limit.

    The outcome of the call is delivered through `queue`. If the child is
    still alive after `timeout` seconds a TimeoutError is put into the queue;
    if it ended without delivering anything an ExecutionError is put instead.
    """
    # start a new process for timeout (cannot use thread because we have c function)
    worker = Process(target=_execute_func, args=(func, queue, args, kwargs))
    worker.start()
    worker.join(timeout=timeout)

    # record the state before cleaning up, then always clean up
    timed_out = worker.is_alive()
    kill_child_processes(worker.pid)
    worker.terminate()
    worker.join()

    if timed_out:
        queue.put(executor.TimeoutError())
    elif queue.empty():
        queue.put(executor.ExecutionError("Fatal error in local executor"))
|
||||
|
||||
|
||||
class LocalFuture(executor.Future):
    """Future of a task that runs in a local process.

    Parameters
    ----------
    process: multiprocessing.Process
        process for running this task
    queue: multiprocessing.Queue
        queue for receiving the result of this task
    """
    def __init__(self, process, queue):
        self._queue = queue
        self._process = process
        self._done = False

    def done(self):
        # once done, stay done without touching the queue again
        if not self._done:
            self._done = not self._queue.empty()
        return self._done

    def get(self, timeout=None):
        try:
            outcome = self._queue.get(block=True, timeout=timeout)
        except Empty:
            raise executor.TimeoutError()

        # the result has arrived: tear down the process and the queue
        if self._process.is_alive():
            kill_child_processes(self._process.pid)
            self._process.terminate()
        self._process.join()
        self._queue.close()
        self._queue.join_thread()
        self._done = True
        del self._queue
        del self._process
        return outcome
|
||||
|
||||
|
||||
class LocalFutureNoFork(executor.Future):
    """Future that wraps an already computed result.

    This is the non-fork version of LocalFuture.
    Use this for the runtime that does not support fork (like cudnn)
    """
    def __init__(self, result):
        self._result = result

    def done(self):
        # the result exists from construction on
        return True

    def get(self, timeout=None):
        # `timeout` is accepted for interface compatibility and not used
        return self._result
|
||||
|
||||
|
||||
class LocalExecutor(executor.Executor):
    """Local executor that runs workers on the same machine with multiprocessing."""
    def __init__(self, timeout=None):
        if timeout:
            self.timeout = timeout
        else:
            self.timeout = executor.Executor.DEFAULT_TIMEOUT

    def submit(self, func, *args, **kwargs):
        """Run `func(*args, **kwargs)` and return a future for its result.

        Note
        ----
        By default, the executor will fork a new process for a new job
        But some runtime does not support fork (e.g. cuda runtime, cudnn).
        In this circumstance, you should set 'fork_new_process' to False in kwargs.
        The function is then called directly inside this method.
        """
        if kwargs.pop('fork_new_process', True):
            result_queue = Queue(1)
            monitor = Process(target=timeout_monitor,
                              args=(result_queue, self.timeout, func, args, kwargs))
            monitor.start()
            return LocalFuture(monitor, result_queue)

        return LocalFutureNoFork(func(*args, **kwargs))
|
|
@ -0,0 +1,338 @@
|
|||
# pylint: disable=pointless-string-statement,consider-using-enumerate,invalid-name
|
||||
"""User facing API for specifying how to measure the generated code"""
|
||||
import time
|
||||
from collections import namedtuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ... import build, nd, target as _target
|
||||
from ...contrib.util import tempdir
|
||||
from ...rpc.tracker import Tracker
|
||||
from ...rpc.server import Server
|
||||
|
||||
from ..util import get_const_tuple
|
||||
from .local_executor import LocalExecutor
|
||||
|
||||
|
||||
class MeasureInput(namedtuple("MeasureInput", "target task config")):
    """
    Bundle of everything that is needed as input for one measurement.

    Parameters
    ----------
    target : tvm.target.Target
        The target device
    task : task.Task
        Task function
    config : ConfigEntity
        Specific configuration.
    """
|
||||
|
||||
class MeasureResult(namedtuple("MeasureResult", "costs error_no all_cost timestamp")):
    """
    Bundle of everything that results from one measurement.

    Parameters
    ----------
    costs: Array of float or Array of Exception
        If no error occurs for this measurement, it is an array of measured running times.
        If some error occurs during the measurement, it is an array of the exception objects.
    error_no: int
        Denote error type, defined by MeasureErrorNo
    all_cost: float
        All cost of this measure, including rpc, compilation, test runs
    timestamp: float
        The absolute time stamp when we finish measurement.
    """
|
||||
|
||||
class MeasureErrorNo(object):
    """Error type for MeasureResult"""
    # no error
    NO_ERROR = 0
    # error when calling template function
    INSTANTIATION_ERROR = 1
    # error when compiling code on host (e.g. tvm.build)
    COMPILE_HOST = 2
    # error when compiling code on device (e.g. opencl JIT on device)
    COMPILE_DEVICE = 3
    # error when run program on device
    RUNTIME_DEVICE = 4
    # answer is wrong when compared to a golden output
    WRONG_ANSWER = 5
    # error of measure infrastructure
    FLEET_ERROR = 6
|
||||
|
||||
|
||||
def measure_option(mode,
                   number=1,
                   repeat=1,
                   timeout=60,
                   parallel_num=1,
                   pack_size=1,
                   check_correctness=False,
                   build_option=None,
                   replay_db=None,
                   save_to_replay_db=True,
                   rpc_device_key=None,
                   rpc_priority=1,
                   rpc_timeout=60,
                   rpc_tracker_addr=None,
                   use_ndk=False,
                   custom_measure_batch=None):
    """Collect the settings that describe how to do measurement.

    Parameters
    ----------
    mode: str
        'local': use the local device for measurement. In this mode,
        the tuner starts a tracker and a RPC server silently for the user.

        'rpc': request devices for measurement from rpc tracker. In this mode,
        you should start a rpc tracker in a separate process.

        'custom': use custom measure function

        'local-nofork': use local device for measure but does not use multiprocessing.
        This mode is suitable for debug, but does not support timeout and parallel.

    number : int, optional
        Number of times to do the measurement for average
    repeat : int, optional
        Number of times to repeat the measurement.
        In total, the generated code will be run (1 + number x repeat) times,
        where the first one is warm up. The returned result contains `repeat` costs,
        each of which is the average of `number` test run.
    timeout: int, optional
        Timeout for a whole batch. TimeoutError will be returned as the result if a
        task timeouts.
    parallel_num: int, optional
        The number of measurement task that can run in parallel.
        Set this according to the number of cpu cores (for compilation) and
        the number of devices you have (for measuring generate code).
    pack_size : int, optional
        Number of configs to measure in one RPC call.
        Usually this can be set to 1. If your device has high cost to establish a rpc
        connection, set this higher.
    check_correctness: bool
        Whether check correctness after measurement.
    build_option: Dict, optional
        Build options for tvm.build_config

    replay_db : Database, optional
        The database that we retrieve saved MeasureResults from
    save_to_replay_db: bool, optional
        Whether save measure result to database. This is useless when replay_db is None

    rpc_priority: int, optional
        Priority of this task, used by scheduler in tracker
    rpc_device_key: str, optional
        The device key of registered devices in tracker
    rpc_timeout: int, optional
        Timeout of rpc session
    rpc_tracker_addr: Tuple(str, int), optional
        The address of rpc tracker in Tuple(host, port) format.
        If is set, will use this address.
        If is not set, will use environment variable "TVM_TRACKER_HOST" and "TVM_TRACKER_PORT"

    use_ndk: bool, optional
        Whether export requires ndk
    custom_measure_batch: callable, optional
        custom measure function

    Returns
    -------
    options: dict
        A dict to store all options, keyed by the parameter names
    """
    return dict(
        # general measurement settings
        mode=mode,
        number=number,
        repeat=repeat,
        timeout=timeout,
        parallel_num=parallel_num,
        pack_size=pack_size,
        check_correctness=check_correctness,
        build_option=build_option,
        # replay database
        replay_db=replay_db,
        save_to_replay_db=save_to_replay_db,
        # rpc settings
        rpc_device_key=rpc_device_key,
        rpc_priority=rpc_priority,
        rpc_timeout=rpc_timeout,
        rpc_tracker_addr=rpc_tracker_addr,
        # miscellaneous
        use_ndk=use_ndk,
        custom_measure_batch=custom_measure_batch,
    )
|
||||
|
||||
|
||||
def create_measure_batch(task, options):
    """Get a standard measure_batch function.

    Parameters
    ----------
    task: tvm.autotvm.task.Task
        The tuning task
    options: dict
        The option for measuring generated code.
        You should use the return value of :any:`autotvm.measure_option` for this argument

    Returns
    -------
    measure_batch: callable
        a callback function to measure a batch of configs. It carries the
        attribute `parallel_num`, and in 'local' mode also `aux_objects`
        holding the temporary rpc server and tracker.

    Raises
    ------
    RuntimeError : if `options['mode']` is not one of the known modes.
    """
    from . import measure_methods
    from ..database import filter_inputs

    mode = options['mode']
    number, repeat = options['number'], options['repeat']
    timeout, parallel_num = options['timeout'], options['parallel_num']
    pack_size = options['pack_size']
    check_correctness = options['check_correctness']
    build_option = options['build_option']
    replay_db = options['replay_db']
    save_to_replay_db = options['save_to_replay_db']
    rpc_device_key = options['rpc_device_key']
    rpc_priority, rpc_timeout = options['rpc_priority'], options['rpc_timeout']
    use_ndk = options['use_ndk']
    custom_measure_batch = options['custom_measure_batch']

    # extra keyword arguments forwarded to the measure function
    kwargs = {}
    executor = LocalExecutor(timeout=timeout)

    if mode == 'local':
        # start temporary rpc tracker and rpc server for the user
        tracker = Tracker('localhost', port=9000, port_end=10000,
                          silent=True)
        rpc_device_key = '$local$device$%d' % tracker.port
        server = Server('localhost', port=9000, port_end=10000,
                        key=rpc_device_key,
                        use_popen=True, silent=True,
                        tracker_addr=(tracker.host, tracker.port))

        fmeasure = measure_methods.measure_rpc
        kwargs['rpc_device_key'] = rpc_device_key
        kwargs['rpc_tracker_addr'] = (tracker.host, tracker.port)
        kwargs['rpc_timeout'] = timeout
        kwargs['tmp_dir'] = tempdir()
    elif mode == 'rpc':
        fmeasure = measure_methods.measure_rpc
        kwargs['rpc_device_key'] = rpc_device_key
        kwargs['rpc_priority'] = rpc_priority
        kwargs['rpc_timeout'] = rpc_timeout
        kwargs['use_ndk'] = use_ndk
        kwargs['tmp_dir'] = tempdir()
        assert rpc_device_key, "In rpc mode, a rpc_device_key must be provided"
    elif mode == "custom":
        assert callable(custom_measure_batch), "In custom mode, custom_measure_func " \
                                               "must be a callable object"
    elif mode == 'local-nofork':
        fmeasure = measure_methods.measure_local
        kwargs['fork_new_process'] = False
    else:
        raise RuntimeError("Invalid mode: " + mode)

    if 'cuda' in task.target.keys and 'rpc_device_key' in kwargs:  # query cuda device info
        add_cuda_device_info(kwargs['rpc_device_key'], kwargs.get('rpc_tracker_addr'), kwargs)
    if 'opencl' in task.target.keys and 'rpc_device_key' in kwargs:
        add_opencl_device_info(kwargs['rpc_device_key'], kwargs.get('rpc_tracker_addr'), kwargs)

    if check_correctness:
        # use llvm to generate a reference input/output
        # this option works for tuning topi, but might not work for your custom op
        with _target.create("llvm"):
            s, arg_bufs = task.instantiate(task.config_space.get(0))
        ref_input = [np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype)
                     for x in arg_bufs]
        func = build(s, arg_bufs, "llvm")
        tvm_buf = [nd.array(x) for x in ref_input]
        func(*tvm_buf)
        ref_output = [x.asnumpy() for x in tvm_buf]
        # the key must be spelled 'ref_output', parallel to 'ref_input'
        kwargs['ref_input'], kwargs['ref_output'] = ref_input, ref_output

    def measure_batch(measure_inputs):
        """measure the time cost for a batch of configs in real machines"""
        if replay_db is not None:
            # only measure the inputs that have no saved result
            partial_results, measure_inputs =\
                filter_inputs(replay_db, measure_inputs, retry=False)

        # pack configs
        input_packs = [measure_inputs[i:i + pack_size]
                       for i in range(0, len(measure_inputs), pack_size)]

        # send to measure
        futures = [executor.submit(fmeasure, input_pack,
                                   number=number,
                                   repeat=repeat,
                                   build_option=build_option,
                                   **kwargs)
                   for input_pack in input_packs]

        # transform results
        results = []
        for input_pack, future in zip(input_packs, futures):
            result = future.get()
            if isinstance(result, Exception):
                if mode == 'local-nofork':
                    # debug usage, raise exception
                    raise result
                tstamp = time.time()
                # one error result per input of this pack; the last pack may
                # hold fewer than pack_size inputs
                results.extend([MeasureResult((result,), MeasureErrorNo.FLEET_ERROR,
                                              timeout, tstamp)] * len(input_pack))
            else:
                results.extend(result)

        if replay_db is not None:
            if save_to_replay_db:  # save result to database
                for measure_input, result in zip(measure_inputs, results):
                    replay_db.save(measure_input, result)

            # fill the freshly measured results into the open slots
            result_idx = 0
            for i in range(len(partial_results)):
                if partial_results[i] is None:
                    partial_results[i] = results[result_idx]
                    result_idx += 1
            return partial_results
        return results

    if mode == 'custom':
        measure_batch = custom_measure_batch

    measure_batch.parallel_num = parallel_num
    if mode == 'local':
        # keep the temporary server and tracker referenced from the callback
        measure_batch.aux_objects = {"server": server, "tracker": tracker}
    return measure_batch
|
||||
|
||||
|
||||
def add_cuda_device_info(device_key, rpc_tracker_addr, kwargs):
    """Query cuda device info and store it into ``kwargs``.

    The queried values are used to set the flags for the nvcc compiler
    and to check the validity of generated code.

    Parameters
    ----------
    device_key: str
        device key forwarded to ``request_remote``
    rpc_tracker_addr: Tuple(string, int) or None
        tracker address forwarded to ``request_remote``
    kwargs: dict
        updated in place; receives the entries ``check_gpu`` and ``cuda_arch``
    """
    from .measure_methods import request_remote

    device = request_remote(device_key, rpc_tracker_addr).context('cuda', 0)
    thread_dims = device.max_thread_dimensions

    # build the whole dictionary first so kwargs is only touched on success
    limits = {}
    limits['max_shared_memory_per_block'] = device.max_shared_memory_per_block
    limits['max_threads_per_block'] = device.max_threads_per_block
    for pos, axis in enumerate('xyz'):
        limits['max_thread_' + axis] = thread_dims[pos]
    kwargs['check_gpu'] = limits

    # the dots of the compute version are dropped, e.g. "6.1" gives "sm_61"
    kwargs["cuda_arch"] = "sm_" + device.compute_version.replace('.', '')
|
||||
|
||||
def add_opencl_device_info(device_key, rpc_tracker_addr, kwargs):
    """Query opencl device info and store it into ``kwargs``.

    The queried values are used to check the validity of generated code.

    Parameters
    ----------
    device_key: str
        device key forwarded to ``request_remote``
    rpc_tracker_addr: Tuple(string, int) or None
        tracker address forwarded to ``request_remote``
    kwargs: dict
        updated in place; receives the entry ``check_gpu``
    """
    from .measure_methods import request_remote

    device = request_remote(device_key, rpc_tracker_addr).context('opencl', 0)
    thread_dims = device.max_thread_dimensions

    # build the whole dictionary first so kwargs is only touched on success
    limits = {}
    limits['max_shared_memory_per_block'] = device.max_shared_memory_per_block
    limits['max_threads_per_block'] = device.max_threads_per_block
    for pos, axis in enumerate('xyz'):
        limits['max_thread_' + axis] = thread_dims[pos]
    kwargs['check_gpu'] = limits
|
|
@ -0,0 +1,296 @@
|
|||
# pylint: disable=consider-using-enumerate,invalid-name,too-many-function-args
|
||||
"""
|
||||
Functions that run on executor for measurement.
|
||||
These functions are responsible for building tvm module, uploading it to
|
||||
remote devices, recording the running time costs and checking the correctness of output
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from random import getrandbits
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...contrib import ndk, nvcc
|
||||
from ... import rpc, ir_pass, build, build_config, nd, context, TVMError, register_func
|
||||
|
||||
from ..util import get_const_tuple
|
||||
from ..env import AutotvmGlobalScope
|
||||
from .measure import MeasureResult, MeasureErrorNo
|
||||
from ..task.space import InstantiationError
|
||||
|
||||
|
||||
class HashMismatchError(ValueError):
    """Signals that the code hash carried by a submitted config differs from
    the code hash found on the measure side."""
|
||||
|
||||
def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
    """Request a remote session from a rpc tracker

    Parameters
    ----------
    device_key: string
        device key of registered device in tracker
    tracker_addr: Tuple(string, int), optional
        The address of rpc tracker in (host, port) format.
        When it is empty, the address is read from the environment variables
        TVM_TRACKER_HOST and TVM_TRACKER_PORT.
    priority: int, optional
        priority of this request, larger is more prior
    timeout: float, optional
        timeout of this session (units: seconds)

    Returns
    -------
    session: RPCSession
    """
    if not tracker_addr:
        # no explicit address: fall back to the environment (KeyError if unset)
        host = os.environ['TVM_TRACKER_HOST']
        port = int(os.environ['TVM_TRACKER_PORT'])
    else:
        host, port = tracker_addr[0], tracker_addr[1]

    return rpc.connect_tracker(host, port).request(
        device_key, priority=priority, session_timeout=timeout)
|
||||
|
||||
|
||||
def _measure_generic(fbuild, input_pack, ref_input, ref_output):
    """Build and time every input of a pack, one after another

    Parameters
    ----------
    fbuild : function takes MeasureInput returns tuple of (time_func, ctx, args)
        The build function used to build each input.
    input_pack : list of MeasureInput
        The inputs we need to evaluate
    ref_input: Array of np.ndarray
        Reference input for checking correctness.
        If it is empty, random inputs are generated from the argument buffers.
    ref_output: Array of np.ndarray
        Reference output for checking correctness.
        If it is empty, the correctness check is skipped.

    Returns
    -------
    res_pack : array of MeasureResult
        The list of execution result of measurement.
        A failing input yields a result holding the error instead of the costs.
    """
    def _trimmed_message(error):
        """Message of an error, cut off where the stack trace starts"""
        text = str(error)
        marker = "Stack trace returned"
        if marker in text:
            text = text[:text.index(marker)]
        return text

    res_pack = []
    for inp in input_pack:
        tic = time.time()

        # build phase
        try:
            time_f, ctx, arg_bufs = fbuild(inp)
        except TVMError as exc:
            tstamp = time.time()
            msg = _trimmed_message(exc)
            if "InstantiationError" not in msg:
                failure = RuntimeError(msg)
                errno = MeasureErrorNo.COMPILE_HOST
            else:
                # best effort: try to keep only the text after "<name>: "
                # on the second last line of the message
                try:
                    msg = msg.split('\n')[-2].split(": ")[1]
                except Exception:  # pylint: disable=broad-except
                    pass
                failure = InstantiationError(msg)
                errno = MeasureErrorNo.INSTANTIATION_ERROR
            res_pack.append(MeasureResult((failure,), errno, tstamp - tic, tstamp))
            continue
        except InstantiationError as exc:
            tstamp = time.time()
            res_pack.append(MeasureResult((exc,), MeasureErrorNo.INSTANTIATION_ERROR,
                                          tstamp - tic, tstamp))
            continue

        # timing phase
        errno = MeasureErrorNo.NO_ERROR
        try:
            if not ref_input:
                args = [nd.array(np.random.uniform(size=get_const_tuple(buf.shape))
                                 .astype(buf.dtype), ctx)
                        for buf in arg_bufs]
            else:
                args = [nd.array(data, ctx) for data in ref_input]

            costs = time_f(*args).results
            if len(costs) > 2:
                # drop the smallest and the largest value to reduce variance
                costs = tuple(sorted(costs)[1:-1])

            if ref_output:
                for expected, real in zip(ref_output, args):
                    if np.allclose(expected, real.asnumpy(), rtol=1e-4):
                        continue
                    logging.warning("Wrong Answer!")
                    errno = MeasureErrorNo.WRONG_ANSWER
        except TVMError as exc:
            costs = (RuntimeError(_trimmed_message(exc)),)
            errno = MeasureErrorNo.RUNTIME_DEVICE

        tstamp = time.time()
        res_pack.append(MeasureResult(costs, errno, tstamp - tic, tstamp))
    return res_pack
|
||||
|
||||
def _build_func(inp, build_option, kwargs):
    """Build function module. Exception will be raised when error occurs

    Parameters
    ----------
    inp: MeasureInput
        The input to build. Its target, task and config are used.
    build_option: dict or None
        Build options for tvm.build_config. The dictionary is not modified.
    kwargs: dict
        Additional options. The recognized keys are

        * ``check_gpu``: device limits; when present, a gpu verify pass is added
          as the ``add_lower_pass`` build option
        * ``cuda_arch``: passed to ``set_cuda_target_arch``

    Returns
    -------
    func: the built module
    args: the arguments returned by ``inp.task.instantiate``

    Raises
    ------
    InstantiationError
        if the config is not valid after instantiation
    HashMismatchError
        if the code hash of the config differs from the one of the schedule
    """
    with inp.target:
        s, args = inp.task.instantiate(inp.config)
        if not inp.config.valid():
            raise InstantiationError(inp.config.errors)
        code_hash = getattr(s, 'code_hash', None)
        if inp.config.code_hash != code_hash:
            raise HashMismatchError('got {0:s}, expected {1:s}'
                                    .format(str(inp.config.code_hash), str(code_hash)))

        # Work on a copy: the same build_option object is handed in for every
        # measurement, so it must not be modified here.
        opts = dict(build_option) if build_option else {}
        if "check_gpu" in kwargs:
            values = kwargs['check_gpu']
            # Add gpu verify pass to filter out invalid configs in advance.
            # This can accelerate the tuning process
            check_keys = ['max_shared_memory_per_block', 'max_threads_per_block',
                          'max_thread_x', 'max_thread_y', 'max_thread_z']
            opts["add_lower_pass"] = [
                (2, gpu_verify_pass(**{key: values[key] for key in check_keys}))]

        if 'cuda_arch' in kwargs:
            set_cuda_target_arch(kwargs['cuda_arch'])

        with build_config(**opts):
            func = build(s, args, target_host=inp.task.target_host)

    return func, args
|
||||
|
||||
|
||||
def measure_rpc(input_pack,
                rpc_device_key,
                number,
                repeat=1,
                build_option=None,
                rpc_tracker_addr=None,
                rpc_priority=1,
                rpc_timeout=60,
                tmp_dir=None,
                **kwargs):
    """Measure the time cost on a device by rpc

    Parameters
    ----------
    input_pack : list of MeasureInput
        The inputs we need to evaluate
    rpc_device_key: str
        The device key of registered devices in tracker
    number : int
        Number of times to get the running measurement
    repeat : int, optional
        How many times we want to repeat the measurement.
    build_option: Dict
        build options for tvm.build_config

    rpc_tracker_addr: Tuple(string, int), optional
        The address of rpc tracker in (host, port) format
        If is none, will use environment variable
    rpc_priority: int, optional
        priority of this task, used by scheduler in tracker
    rpc_timeout: int, optional
        timeout of the rpc session

    tmp_dir: tvm.contrib.util.TempDirectory, optional
        directory to store temp file

    kwargs: dict, optional
        Additional key word arguments.
        ``use_ndk`` selects the export format; ``ref_input`` and ``ref_output``
        are forwarded to the correctness check.

    Returns
    -------
    res_pack : Array of MeasureResult
        The list of execution results of measurement.
    """
    def _fbuild(inp):
        """Build locally, upload the library and create a remote timer."""
        func, args = _build_func(inp, build_option, kwargs)

        # export under a random file name inside the temp directory
        if kwargs.get('use_ndk', False):
            file_name = "tmp_func_%0x.so" % getrandbits(64)
            path = tmp_dir.relpath(file_name)
            func.export_library(path, ndk.create_shared)
        else:
            file_name = "tmp_func_%0x.tar" % getrandbits(64)
            path = tmp_dir.relpath(file_name)
            func.export_library(path)

        remote = request_remote(rpc_device_key, rpc_tracker_addr, rpc_priority, rpc_timeout)
        remote.upload(path)
        func = remote.load_module(file_name)
        ctx = remote.context(str(inp.target), 0)
        time_f = func.time_evaluator(
            func.entry_name, ctx, number=number, repeat=repeat)
        return time_f, ctx, args

    return _measure_generic(_fbuild, input_pack,
                            kwargs.get("ref_input"), kwargs.get("ref_output"))
|
||||
|
||||
|
||||
def measure_local(input_pack,
                  number,
                  repeat=1,
                  build_option=None,
                  **kwargs):
    """Measure the time cost on a local machine.

    Parameters
    ----------
    input_pack : list of MeasureInput
        The inputs we need to evaluate
    number : int
        Number of times to get the running measurement
    repeat : int, optional
        How many times we want to repeat the measurement.
    build_option: dict, optional
        Build options for tvm.build_config
    kwargs: dict, optional
        Additional key word arguments.
        ``ref_input`` and ``ref_output`` are forwarded to the correctness check.

    Returns
    -------
    res_pack : Array of MeasureResult
        The list of execution results of measurement.
    """
    def _fbuild(inp):
        """Build locally and create a timer on the local device 0."""
        func, args = _build_func(inp, build_option, kwargs)
        ctx = context(str(inp.target), 0)
        timer = func.time_evaluator(
            func.entry_name, ctx, number=number, repeat=repeat)
        return timer, ctx, args

    return _measure_generic(_fbuild, input_pack,
                            kwargs.get("ref_input"), kwargs.get("ref_output"))
|
||||
|
||||
|
||||
def gpu_verify_pass(**kwargs):
    """Create a pass that verifies the validity of a gpu kernel

    The created pass hands the statement together with ``kwargs`` to
    ``ir_pass.VerifyGPUCode`` (shared memory size and number of threads per
    block are checked). A valid statement is returned unchanged, an invalid
    one makes the pass raise InstantiationError.
    """
    def verify_pass(stmt):
        if ir_pass.VerifyGPUCode(stmt, kwargs):
            return stmt
        raise InstantiationError("Skipped because of invalid gpu kernel")
    return verify_pass
|
||||
|
||||
|
||||
@register_func
def tvm_callback_cuda_compile(code):
    """Compile cuda code to ptx with nvcc for better optimization.

    The target architecture is read from the current AutotvmGlobalScope.
    """
    arch = AutotvmGlobalScope.current.cuda_target_arch
    return nvcc.compile_cuda(code, target="ptx", arch=arch)
|
||||
|
||||
def set_cuda_target_arch(arch):
    """Store the target architecture of the nvcc compiler in the current
    AutotvmGlobalScope"""
    scope = AutotvmGlobalScope.current
    scope.cuda_target_arch = arch
|
|
@ -0,0 +1,332 @@
|
|||
# pylint: disable=superfluous-parens, redefined-outer-name, redefined-outer-name,pointless-string-statement
|
||||
# pylint: disable=consider-using-enumerate,invalid-name
|
||||
"""Tuning record and serialization format"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import logging
|
||||
import multiprocessing
|
||||
import pickle
|
||||
import json
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .. import target, build, lower
|
||||
|
||||
from . import task
|
||||
from .task import DispatchContext, ConfigEntity
|
||||
from .measure import MeasureInput, MeasureResult
|
||||
|
||||
# Version tag stored in the "v" field of every json record
AUTOTVM_LOG_VERSION = 0.1

# The name `unicode` can only be resolved on python2. When the lookup fails,
# an empty tuple is used, so that isinstance(x, _unicode) is always False.
try:
    unicode
except NameError:
    _unicode = ()
else:
    _unicode = unicode
|
||||
|
||||
|
||||
def measure_str_key(inp, include_config=True):
    """Build a unique str key for a MeasureInput

    The key is the concatenation of the target, the task name, the task
    arguments, the task keyword arguments and (optionally) the config.

    Parameters
    ----------
    inp: MeasureInput
        input for the measure
    include_config: bool, optional
        whether includes config in the str key

    Returns
    -------
    key: str
        The str representation of key
    """
    tsk = inp.task
    pieces = (str(inp.target), tsk.name, str(tsk.args), str(tsk.kwargs),
              str(inp.config) if include_config else "")
    return "".join(pieces)
|
||||
|
||||
|
||||
def encode(inp, result, protocol='json'):
    """Serialize a (MeasureInput, MeasureResult) pair to one row of a log file

    Parameters
    ----------
    inp: autotvm.tuner.MeasureInput
    result: autotvm.tuner.MeasureResult
        pair of input/result
    protocol: str
        log protocol, json or pickle

    Returns
    -------
    row: str
        a row in the logger file

    Raises
    ------
    RuntimeError
        if the protocol is neither 'json' nor 'pickle'
    """
    if protocol == 'json':
        tsk = inp.task
        input_part = (str(inp.target), tsk.name, tsk.args, tsk.kwargs,
                      tsk.workload, inp.config.to_json_dict())
        # the costs of a failed measurement are replaced by (1e9,)
        costs = result.costs if result.error_no == 0 else (1e9,)
        result_part = (costs, result.error_no, result.all_cost, result.timestamp)
        return json.dumps({"i": input_part, "r": result_part, "v": AUTOTVM_LOG_VERSION})

    if protocol == 'pickle':
        def _pack(obj):
            """pickle an object and return it as base64 text"""
            return str(base64.b64encode(pickle.dumps(obj)).decode())

        tsk = inp.task
        columns = (str(inp.target),
                   _pack([tsk.name, tsk.args, tsk.kwargs, tsk.workload]),
                   _pack(inp.config),
                   _pack(tuple(result)))
        # columns are separated by tabs
        return '\t'.join(columns)

    raise RuntimeError("Invalid log protocol: " + protocol)
|
||||
|
||||
|
||||
def decode(row, protocol='json'):
    """Rebuild the python objects from an encoded record string

    Parameters
    ----------
    row: str
        a row in the logger file
    protocol: str
        log protocol, json or pickle

    Returns
    -------
    input: autotvm.tuner.MeasureInput
    result: autotvm.tuner.MeasureResult

    Raises
    ------
    RuntimeError
        if the protocol is neither 'json' nor 'pickle'
    """
    # pylint: disable=unused-variable
    def clean_json_to_python(x):
        """1. convert all list in x to tuple (hashable)
           2. convert unicode to str for python2
        """
        if isinstance(x, list):
            return tuple(clean_json_to_python(a) for a in x)
        if isinstance(x, _unicode):
            return str(x)
        return x

    if protocol == 'json':
        record = json.loads(row)
        tgt, task_name, task_args, task_kwargs, workload, config = record['i']
        tgt = target.create(str(tgt))

        # note: the logged task_kwargs are not passed on to the task
        tsk = task.Task(clean_json_to_python(task_name), clean_json_to_python(task_args))
        tsk.workload = clean_json_to_python(workload)
        inp = MeasureInput(tgt, tsk, ConfigEntity.from_json_dict(config))

        # json has no tuples, so lists in the result fields are converted back
        fields = [tuple(x) if isinstance(x, list) else x for x in record["r"]]
        return inp, MeasureResult(*fields)

    if protocol == 'pickle':
        # NOTE(review): pickle.loads runs on the content of the log row;
        # rows of this protocol must only come from trusted files.
        def _unpack(text):
            """inverse of the base64 + pickle packing of a column"""
            return pickle.loads(base64.b64decode(text.encode()))

        items = row.split("\t")
        tgt = target.create(items[0])
        task_tuple = _unpack(items[1])
        config = _unpack(items[2])
        result = _unpack(items[3])

        tsk = task.Task(task_tuple[0], task_tuple[1])
        tsk.workload = task_tuple[3]
        return MeasureInput(tgt, tsk, config), MeasureResult(*result)

    raise RuntimeError("Invalid log protocol: " + protocol)
|
||||
|
||||
def load_from_file(filename):
    """Generator: load records from file.
    This is a generator that yields the records.

    The file is closed as soon as the generator is exhausted or closed.
    Rows that contain only whitespace (e.g. a trailing blank line) are skipped.

    Parameters
    ----------
    filename: str
        name of a log file; every row is decoded with the default protocol
        of ``decode``

    Yields
    ------
    input: autotvm.tuner.MeasureInput
    result: autotvm.tuner.MeasureResult
    """
    with open(filename) as fin:
        for row in fin:
            if not row.strip():
                # nothing to decode in a blank row
                continue
            yield decode(row)
|
||||
|
||||
|
||||
class ApplyHistoryBest(DispatchContext):
    """
    Apply the history best config

    Parameters
    ----------
    records : str or iterator of (MeasureInput, MeasureResult)
        Collection of tuning records.
        if is str, then it should be the filename of a records log file.
                   Each row of this file is an encoded record pair.
        otherwise, it is an iterator
    default: ConfigEntity, optional
        default config to return when no history records
    """
    def __init__(self, records, default=None):
        super(ApplyHistoryBest, self).__init__()

        if isinstance(records, str):
            records = load_from_file(records)

        counter = 0
        # maps (target key, workload) to the best (input, result) pair seen so far
        best_map = {}
        for inp, res in records:
            counter += 1
            if res.error_no != 0:
                # failed measurements never become the best record
                continue
            for k in inp.target.keys:
                key = (k, inp.task.workload)
                if key not in best_map:
                    best_map[key] = (inp, res)
                else:
                    _, other_res = best_map[key]
                    # a record only replaces the stored one if its mean cost is lower
                    if np.mean(other_res.costs) > np.mean(res.costs):
                        best_map[key] = (inp, res)
        logging.info(
            "Finish load %d records, %d entries selected", counter, len(best_map))
        self._best_map = best_map
        self._default = default

    def query(self, target, workload):
        """Return the best recorded config for a target and a workload

        The keys of the target are tried in order. If none of them has a
        record for the workload, the default config is returned.

        Raises
        ------
        RuntimeError
            if target is None, or if nothing is found and no default is set
        """
        if target is None:
            raise RuntimeError("Need a target context to find the history best. "
                               "Hint: If your target is llvm, use `with tvm.target.create('llvm'):`"
                               " above the dispatcher call. So does other target. ")

        for k in target.keys:
            key = (k, workload)
            if key in self._best_map:
                return self._best_map[key][0].config

        if self._default:
            return self._default
        raise RuntimeError(
            "Cannot find config for target=%s, workload=%s" % (target, workload))

    def dump_best(self, out_file):
        """Dump the best records for each workload to a file

        The records are appended to the file, one encoded record per row.

        Parameters
        ----------
        out_file: str
            filename
        """
        # the with-statement makes sure the rows are flushed and the file is closed
        with open(out_file, 'a') as fout:
            for inp, res in self._best_map.values():
                fout.write(encode(inp, res) + '\n')
|
||||
|
||||
|
||||
def split_workload(in_file, clean=True):
    """Split a log file into separate files, each of which contains only a single workload
    This function can also delete duplicated records in log file

    The output files are named ``in_file + ".%03d.wkl"``, numbered in the order
    in which the workloads first appear in the input file.

    Parameters
    ----------
    in_file: str
        input filename
    clean: bool
        whether delete duplicated items
    """
    tic = time.time()
    with open(in_file) as fin:
        lines = fin.readlines()

    logging.info("start convert...")
    pool = multiprocessing.Pool()
    try:
        lines = pool.map(decode, lines)
    finally:
        # release the worker processes, also when decoding fails
        pool.close()
        pool.join()
    logging.info("map done %.2f", time.time() - tic)

    # group the records by workload (the key without the config part)
    wkl_dict = OrderedDict()
    for inp, res in lines:
        wkl = measure_str_key(inp, False)
        wkl_dict.setdefault(wkl, []).append([inp, res])

    for i, (k, v) in enumerate(wkl_dict.items()):
        if clean:
            # clean duplicated items: keep the first record of every full key
            added = set()
            records = []
            for inp, res in v:
                str_key = measure_str_key(inp)
                if str_key in added:
                    continue
                added.add(str_key)
                records.append([inp, res])
            logging.info("Key: %s\tValid: %d\tDup: %d\t", k, len(records), len(v) - len(records))
        else:
            records = v
            logging.info("Key: %s\tNum: %d", k, len(v))

        # write to file; the name is derived from the in_file argument, so the
        # function does not depend on the command line arguments of this module
        with open(in_file + ".%03d.wkl" % i, 'w') as fout:
            for inp, res in records:
                fout.write(encode(inp, res) + '\n')
|
||||
|
||||
|
||||
"""
|
||||
Usage:
|
||||
This record executable module has three modes.
|
||||
|
||||
* Print log file in readable format
|
||||
    e.g. python -m tvm.autotvm.record --mode read --i collect_conv.tsv --begin 0 --end 5 --ir --code
|
||||
|
||||
* Extract history best from a large log file
|
||||
    e.g. python -m tvm.autotvm.record --mode best --i collect.tsv
|
||||
|
||||
* Split a log file into separate files, each of which contains only a single wkl
|
||||
    e.g. python -m tvm.autotvm.record --mode split --i collect.tsv
|
||||
"""
|
||||
if __name__ == '__main__':
    # command line interface of the record module
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=['read', 'best', 'split'], default='read')
    parser.add_argument("--i", type=str, help="input file")
    parser.add_argument("--o", type=str, default=None, help='output file')
    parser.add_argument("--begin", type=int, default=0)
    parser.add_argument("--end", type=int, default=5)
    parser.add_argument("--ir", action='store_true')
    parser.add_argument("--code", action='store_true')

    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)

    if args.mode == 'split':
        split_workload(args.i)
    elif args.mode == 'best':
        # the output name defaults to the input name plus ".best"
        args.o = args.o or args.i + ".best"
        ApplyHistoryBest(load_from_file(args.i)).dump_best(args.o)
    elif args.mode == 'read':
        for i, (inp, result) in enumerate(load_from_file(args.i)):
            # only the records with an index in [begin, end) are printed
            if not args.begin <= i < args.end:
                continue

            with inp.target:
                s, arg_bufs = inp.task.instantiate(inp.config)

            print("")
            print(inp.target, inp.task, inp.config)
            print(result)

            if args.ir:
                with inp.target:
                    print(lower(s, arg_bufs, simple_mode=True))

            if args.code:
                with inp.target:
                    func = build(s, arg_bufs)
                    print(func.imported_modules[0].get_source())
|
|
@ -0,0 +1,12 @@
|
|||
"""Task is a tunable composition of template functions.
|
||||
|
||||
Tuner takes a tunable task and optimizes the joint configuration
|
||||
space of all the template functions in the task.
|
||||
This module defines the task data structure, as well as a collection(zoo)
|
||||
of typical tasks of interest.
|
||||
"""
|
||||
|
||||
from .task import Task, create, register, template, get_config
|
||||
from .space import ConfigSpace, ConfigEntity
|
||||
from .code_hash import attach_code_hash, attach_code_hash_to_arg
|
||||
from .dispatcher import DispatchContext, ApplyConfig, dispatcher
|
|
@ -0,0 +1,43 @@
|
|||
"""
|
||||
Decorator functions for hashing schedule code
|
||||
|
||||
code hashing is used to check the consistency of schedule code and the parameters loaded from log
|
||||
"""
|
||||
import inspect
|
||||
import zlib
|
||||
|
||||
from tvm import schedule
|
||||
|
||||
def attach_code_hash(s):
    """Decorator for attaching a code hash to a schedule

    Every call of the decorated function stores the crc32 checksum of the
    source code of that function as lower case hex digits in ``s.code_hash``.
    The return value of the decorated function is passed through.

    Parameters
    ----------
    s: Schedule
        tvm.schedule.Schedule to attach the hash to
    """
    def decorator(func):
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            source = ''.join(inspect.getsourcelines(func)[0])
            # crc32 can be negative on python2; the mask gives the same
            # unsigned value on all python versions
            raw_hash = zlib.crc32(source.encode()) & 0xffffffff
            s.code_hash = '%x' % raw_hash
            return result
        return wrapper
    return decorator
|
||||
|
||||
def attach_code_hash_to_arg(arg_idx=1):
    """Decorator for attaching a code hash to a schedule

    Every call of the decorated function stores the crc32 checksum of the
    source code of that function as lower case hex digits in the ``code_hash``
    attribute of the positional argument number ``arg_idx``.
    The return value of the decorated function is passed through.

    Parameters
    ----------
    arg_idx: int
        index of the argument (expected to be a Schedule) to attach the code
        hash to
    """
    def decorator(func):
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            assert isinstance(args[arg_idx], schedule.Schedule)
            source = ''.join(inspect.getsourcelines(func)[0])
            # crc32 can be negative on python2; the mask gives the same
            # unsigned value on all python versions
            raw_hash = zlib.crc32(source.encode()) & 0xffffffff
            args[arg_idx].code_hash = '%x' % raw_hash
            return result
        return wrapper
    return decorator
|
|
@ -0,0 +1,139 @@
|
|||
"""
|
||||
Template dispatcher module.
|
||||
|
||||
A dispatcher is a function that can contain multiple behaviors.
|
||||
Its specific behavior can be controlled by DispatchContext.
|
||||
|
||||
DispatchContext is used in two ways, usually via different implementation
|
||||
of the DispatchContext base class.
|
||||
|
||||
- During search, we can use it to pass the current proposal from tuner.
|
||||
- During evaluation, we can use it to pick the best policy.
|
||||
"""
|
||||
from __future__ import absolute_import as _abs
|
||||
|
||||
from decorator import decorate
|
||||
|
||||
from tvm import target as _target
|
||||
|
||||
class DispatchContext(object):
    """
    Base class of dispatch context.

    A DispatchContext provides the target and workload specific dispatch
    mechanism for templates. It is used as a context manager: while the
    ``with`` block is active, the instance is stored in
    ``DispatchContext.current``; the previous value is restored on exit.
    """
    # the context that is active at the moment, None outside of any with block
    current = None

    def query(self, target, workload):
        """
        Query the context to get the specific implementation.
        Subclasses have to override this method.

        Parameters
        ----------
        target: Target
            The current target
        workload : Workload
            The current workload.

        Returns
        -------
        cfg : ConfigSpace
            The specific configuration.
        """
        raise NotImplementedError()

    def __enter__(self):
        # remember the active context and take its place
        self._old_ctx, DispatchContext.current = DispatchContext.current, self
        return self

    def __exit__(self, ptype, value, trace):
        # hand control back to the context that was active before
        DispatchContext.current = self._old_ctx
|
||||
|
||||
|
||||
class ApplyConfig(DispatchContext):
    """Dispatch context that answers every query with one fixed config.

    The workload of the most recent query is kept in ``self.workload``.

    Parameters
    ----------
    config : ConfigSpace or ConfigEntity
        The specific configuration we care about.
    """
    def __init__(self, config):
        super(ApplyConfig, self).__init__()
        self.workload = None
        self._config = config

    def query(self, target, workload):
        """Override query: record the workload and return the fixed config"""
        self.workload = workload
        return self._config
|
||||
|
||||
|
||||
def dispatcher(fworkload):
    """Wrap a workload dispatcher function.

    Parameters
    ----------
    fworkload : function
        The workload extraction function from arguments.

    Returns
    -------
    fdispatcher : function
        A wrapped dispatcher function, which will
        dispatch based on DispatchContext and
        the current workload.
        Templates are added through its ``register`` attribute.
    """
    # template key -> registered template function
    dispatch_dict = {}
    func_name = fworkload.__name__

    def register(key, func=None, override=False):
        """Register template function.

        Parameters
        ----------
        key : str or List of str
            The template key to identify the template
            under this dispatcher.
        func : function
            The function to be registered.
            The first argument of the function is always
            cfg returned by DispatchContext,
            the rest arguments are the same as the fworkload.
        override : bool
            Whether override existing registration.

        Returns
        -------
        The register function if necessary.
        """
        keys = [key] if isinstance(key, str) else key

        def _do_reg(myf):
            for name in keys:
                if not override and name in dispatch_dict:
                    raise ValueError(
                        "Key %s is already registered for %s" % (name, func_name))
                dispatch_dict[name] = myf
            return myf

        # without a function, the registration can be used as a decorator
        return _do_reg(func) if func else _do_reg

    def dispatch_func(func, *args, **kwargs):
        """The wrapped dispatch function"""
        tgt = _target.current_target()
        context = DispatchContext.current
        if context is None:
            raise RuntimeError("DispatchContext is not initialized")
        # the wrapped function extracts the workload from the arguments,
        # the context turns target and workload into a config
        cfg = context.query(tgt, func(*args, **kwargs))
        template = dispatch_dict[cfg.template_key]
        return template(cfg, *args, **kwargs)

    fdecorate = decorate(fworkload, dispatch_func)
    fdecorate.register = register
    return fdecorate
|
|
@ -0,0 +1,886 @@
|
|||
# pylint: disable=too-few-public-methods,invalid-name,unused-argument,arguments-differ
|
||||
# pylint: disable=consider-using-enumerate
|
||||
"""
|
||||
Template configuration space.
|
||||
|
||||
Each template function can be parametrized by a ConfigSpace.
|
||||
The space is declared when we invoke the template function with ConfigSpace.
|
||||
During evaluation, we pass in a ConfigEntity, which contains a specific
|
||||
entity in the space. This entity contains deterministic parameters.
|
||||
"""
|
||||
from __future__ import absolute_import as _abs
|
||||
|
||||
import itertools
|
||||
import functools
|
||||
import math
|
||||
from collections import namedtuple, OrderedDict
|
||||
import numpy as np
|
||||
|
||||
from tvm import schedule, thread_axis
|
||||
from tvm.autotvm.util import get_const_int
|
||||
|
||||
# Lightweight record with the two fields `space` and `index`.
# NOTE(review): presumably `index` selects one output axis of `space` -- confirm against the users of Axis.
Axis = namedtuple('Axis', ['space', 'index'])
|
||||
|
||||
|
||||
class InstantiationError(ValueError):
    """Error that is actively detected while a template is instantiated
    with a config; raised by cfg.raise_error

    Examples are too many unrolling or too many threads in a block.
    """
|
||||
|
||||
|
||||
class TransformSpace(object):
    """Base class of all transform spaces.

    A TransformSpace is a node in the computation graph of axes.

    Note
    ----
    Schedule code can be viewed as a transformation graph of axes.
    Starting from the raw axes in the definition of tvm.compute, axes are
    transformed by operators such as 'split', 'reorder' and 'annotate'.
    Each operator has tunable parameters (e.g. the split factor), and tuning
    is the search for good values of these parameters.

    All combinations of the parameters of these operators form the search space.

    Naming convention:
    The set of all possible values is called XXXSpace (XXX can be Split, Reorder, Config ...).
    One specific element of a space is called XXXEntity.
    """
    def __init__(self):
        self.entities = []   # concrete entities contained in this space
        self.ins = []        # input axes of this transform
        self.num_output = 0  # number of axes produced by this transform

    def __len__(self):
        """Return the number of entities in this space."""
        return len(self.entities)

    def __getitem__(self, index):
        """Look up one entity of the space by position.

        Parameters
        ----------
        index: int
            Position in the entity list

        Returns
        -------
        transform entity
        """
        return self.entities[index]

    @staticmethod
    def get_num_output():
        """Get the number of output axes after this transform.

        Returns
        -------
        n: int
            number of output axes (always 0 for the base class)
        """
        return 0
|
||||
|
||||
|
||||
class VirtualAxis(TransformSpace):
    """Placeholder for an axis inside a template.

    Parameters
    ----------
    var: int or tvm.schedule.IterVar or VirtualAxis
        If int, the virtual axis has exactly this length.
        If IterVar, the name is taken from the IterVar and the length from its
        extent domain (-1 when the IterVar has no domain).
        If VirtualAxis, its length is copied.
    name: str, optional
        Name of the axis. When omitted, a name 'axis_<k>' is generated from a
        class-wide counter.
    """
    name_ct = 0  # class-wide counter used to generate default names

    def __init__(self, var, name=None):
        super(VirtualAxis, self).__init__()
        self.num_output = 1

        if name is not None:
            self.name = name
        else:
            self.name = 'axis_%d' % VirtualAxis.name_ct
            VirtualAxis.name_ct += 1

        if isinstance(var, int):
            self.length = var
            return

        if isinstance(var, schedule.IterVar):
            # an IterVar always overrides the name chosen above
            self.name = var.var.name
            self.length = -1 if var.dom is None else get_const_int(var.dom.extent)
            return

        if isinstance(var, VirtualAxis):
            self.length = var.length
            return

        raise RuntimeError("Invalid type of axis")

    @staticmethod
    def get_num_output(var, name=None):
        """A virtual axis always yields exactly one axis."""
        return 1

    def __repr__(self):
        return "vaxis(%s)" % self.name
|
||||
|
||||
|
||||
def get_factors(n):
    """Return all factors of a positive integer.

    Parameters
    ----------
    n: int
        integer to factorize, must be >= 1

    Returns
    -------
    factors: list
        Sorted list of all factors of n

    Raises
    ------
    ValueError
        If n is smaller than 1.
    """
    if n < 1:
        raise ValueError("Cannot factorize %s: a positive integer is required" % n)

    # an odd number has only odd divisors, so even candidates can be skipped
    step = 2 if n % 2 else 1
    factors = set()
    for i in range(1, int(math.sqrt(n)) + 1, step):
        if n % i == 0:
            # every divisor i below sqrt(n) pairs with the divisor n // i
            factors.add(i)
            factors.add(n // i)
    return sorted(factors)
|
||||
|
||||
|
||||
class SplitSpace(TransformSpace):
    """Parameter space for splitting one axis several times.

    Parameters
    ----------
    axes: list
        Only the first element is used; it must provide ``length``.
    policy: str
        'all' enumerates the divisible factorizations of the axis length,
        'candidate' takes the explicitly listed sizes.
    kwargs: dict
        'num_outputs' (required) is the number of axes after the split.
        For 'all': optional 'max_factor' (upper bound of every inner factor)
        and 'filter' (predicate on the generated SplitEntity).
        For 'candidate': 'candidate', an iterable of size lists.
    """
    def __init__(self, axes, policy, **kwargs):
        super(SplitSpace, self).__init__()
        target_axis = axes[0]

        self.policy = policy
        self.entities = []

        if policy == 'all':
            n_out = kwargs["num_outputs"]
            upper = kwargs.get("max_factor", 1 << 31)
            keep = kwargs.get("filter", lambda x: True)

            extent = target_axis.length
            usable = [f for f in get_factors(extent) if f <= upper]

            self.product = extent
            self.num_outputs = n_out
            # every inner level shares the same list of candidate factors
            self.factors = [usable] * (n_out - 1)
            self._generate_space(0, [None] * (n_out - 1))
            self.entities = list(filter(keep, self.entities))
            self.num_output = n_out
        elif policy == 'candidate':
            self.product = target_axis.length
            self.num_outputs = kwargs["num_outputs"]
            # only the number of parts is checked, not the product of each candidate
            for size in kwargs["candidate"]:
                assert len(size) == self.num_outputs
                self.entities.append(SplitEntity(size))
            self.num_output = self.num_outputs
        else:
            raise RuntimeError("Invalid policy: " + policy)

    def _generate_space(self, now, tmp_stack):
        """Enumerate factor assignments by DFS; level `now` is filled next."""
        if now != self.num_outputs - 1:
            for factor in self.factors[now]:
                tmp_stack[now] = factor
                self._generate_space(now + 1, tmp_stack)
            return

        # all inner factors are chosen: keep them if they divide the axis length
        inner = np.prod(tmp_stack)
        if self.product % inner == 0:
            outermost = int(self.product // int(inner))
            self.entities.append(SplitEntity([outermost] + tmp_stack[::-1]))

    @staticmethod
    def get_num_output(axes, policy, **kwargs):
        """The number of output axes is given by the 'num_outputs' argument."""
        return kwargs["num_outputs"]

    def __repr__(self):
        fields = (self.policy, self.product, self.num_outputs, len(self))
        return "Split(policy=%s, product=%d, num_outputs=%d) len=%d" % fields
|
||||
|
||||
|
||||
class SplitEntity(object):
    """One concrete split of an axis.

    Parameters
    ----------
    size: Array of int
        The size of every axis after the split,
        e.g. an axis of extent 128 split into 3 axes may use
        the size [4, 4, 8] (4x4x8 = 128).
    """
    def __init__(self, size):
        self.size = size

    def apply(self, sch, op, axis):
        """Apply the split to an axis.

        Parameters
        ----------
        sch: tvm.schedule.Schedule
            The tvm schedule
        op: tvm.tensor.Operation
            The stage to be applied
        axis: tvm.schedule.IterVar
            axis to split

        Returns
        -------
        axes : list of Axis
            The transformed axes, outermost first.
        """
        pieces = []
        for level in range(1, len(self.size)):
            # split off the outer part; the factor is the product of all remaining sizes
            outer, axis = sch[op].split(axis, int(np.prod(self.size[level:])))
            pieces.append(outer)
        pieces.append(axis)
        return pieces

    def __repr__(self):
        return str(self.size)
|
||||
|
||||
|
||||
class ReorderSpace(TransformSpace):
    """The parameter space for ordering an array of axes.

    Parameters
    ----------
    axes: list
        The axes to reorder
    policy: str
        'identity': keep the given order.
        'all': every permutation of the axes.
        'interval_all': every permutation of the axes inside kwargs['interval'].
        'candidate': the orders listed in kwargs['candidate'].
        'interleave': merge the chains in kwargs['spatial'] and kwargs['reduce'].
        'interleave_cuda': like 'interleave', but the last axis of every spatial
        chain is kept innermost.
    kwargs: dict
        extra arguments of the policy
    """
    def __init__(self, axes, policy, **kwargs):
        super(ReorderSpace, self).__init__()
        self.ins = axes
        self.policy = policy
        self.num_output = len(axes)

        if policy == 'identity':
            # store a real list: a range object is not JSON serializable on Python 3
            self.entities = [ReorderEntity(list(range(len(axes))))]
        elif policy == 'all':
            self.entities = [
                ReorderEntity(x) for x in itertools.permutations(range(len(axes)))]
        elif policy == 'interval_all':
            begin, end = kwargs['interval']
            sub_space = list(itertools.permutations(range(begin, end)))
            prefix, suffix = tuple(range(begin)), tuple(range(end, len(axes)))
            self.entities = [ReorderEntity(prefix + x + suffix) for x in sub_space]
        elif policy == 'candidate':
            for can in kwargs["candidate"]:
                perm = [axes.index(x) for x in can]
                self.entities.append(ReorderEntity(perm))
        elif policy == 'interleave':
            spatial, reduce_chains = kwargs['spatial'], kwargs['reduce']

            # translate axes into their positions in `axes`
            spatial = [[axes.index(x) for x in ch] for ch in spatial]
            reduce_chains = [[axes.index(x) for x in ch] for ch in reduce_chains]

            outer_merged = self._merge_chain([x[:-1] for x in spatial])
            inner_merged = self._merge_chain([x[-1:] for x in spatial] + reduce_chains)

            for o in outer_merged:
                for i in inner_merged:
                    self.entities.append(ReorderEntity(o + i))
        elif policy == 'interleave_cuda':
            spatial, reduce_chains = kwargs['spatial'], kwargs['reduce']

            spatial = [[axes.index(x) for x in ch] for ch in spatial]
            reduce_chains = [[axes.index(x) for x in ch] for ch in reduce_chains]

            outer_merged = self._merge_chain([x[:-1] for x in spatial])
            reduce_merged = self._merge_chain(reduce_chains)
            # the last axis of every spatial chain stays innermost
            inner_merged = [x[-1] for x in spatial]

            for o in outer_merged:
                for r in reduce_merged:
                    self.entities.append(ReorderEntity(o + r + inner_merged))
        else:
            raise RuntimeError("Invalid policy: " + policy)

    @staticmethod
    def get_num_output(axes, policy, **kwargs):
        """Reordering keeps the number of axes."""
        return len(axes)

    def __repr__(self):
        return "Reorder(policy=%s) len=%d" % (self.policy, len(self))

    def _merge_chain(self, chains):
        """Generate merged orders of several chains.

        Parameters
        ----------
        chains: list of list of int
            Each chain is a list of axis positions whose relative order is kept.

        Returns
        -------
        merged: list of list of int
            The merged orders found by `_merge_dfs`.
        """
        merged = []
        tmp_pt = [0] * len(chains)  # number of consumed elements per chain
        tmp_stack = []

        size = np.sum([len(x) for x in chains])
        self._merge_dfs(chains, size, tmp_pt, tmp_stack, merged)
        return merged

    def _merge_dfs(self, chains, size, tmp_pt, tmp_stack, merged):
        """DFS helper of `_merge_chain`; appends complete orders to `merged`."""
        if np.sum(tmp_pt) == size:
            merged.append(list(tmp_stack))
            return

        for i in range(len(chains)):
            # use i == np.argmax(....) here to take spatial order into consideration
            # if we don't want to consider spatial order, we can use tmp_pt[i] == np.max(....)
            if (tmp_pt[i] < len(chains[i]) and
                    (i == np.argmax([len(chains[x]) - tmp_pt[x] for x in range(len(chains))]))):
                tmp_stack.append(chains[i][tmp_pt[i]])
                tmp_pt[i] += 1
                self._merge_dfs(chains, size, tmp_pt, tmp_stack, merged)
                # undo the choice before trying the next chain
                tmp_pt[i] -= 1
                tmp_stack.pop()
|
||||
|
||||
|
||||
class ReorderEntity(object):
    """One concrete reorder operation that can be applied to axes.

    Parameters
    ----------
    perm: Array of int
        The permutation: position k of the new order holds axes[perm[k]]
    """
    def __init__(self, perm):
        self.perm = perm

    def apply(self, sch, op, axes):
        """Apply the reorder to an array of axes.

        Parameters
        ----------
        sch: tvm.schedule.Schedule
            The tvm schedule
        op: tvm.tensor.Operation
            The stage to be applied
        axes: Array of tvm.schedule.IterVar
            axes to reorder

        Returns
        -------
        axes : list of Axis
            The transformed axes.
        """
        count = len(axes)
        if count == len(self.perm):
            picks = self.perm
        else:
            # lengths differ: silently drop positions that do not exist in `axes`
            picks = [i for i in self.perm if i < count]

        new_order = [axes[i] for i in picks]
        sch[op].reorder(*new_order)
        return new_order

    def __repr__(self):
        return str(self.perm)
|
||||
|
||||
|
||||
class AnnotateSpace(TransformSpace):
    """The parameter space for annotating an array of axes.

    Parameters
    ----------
    axes: list
        The axes to annotate
    policy: str
        'bind_gpu': bind the innermost axes to block/thread indices and fuse the rest.
        'bind_gpu_virtual': like 'bind_gpu' with virtual threads in between.
        'locate_cache': choose kwargs['num_anchor'] axis positions as attach points.
        Otherwise a '_'-separated combination of 'none', 'unroll', 'vec' and 'try'
        ('try' stands for 'none'), e.g. 'try_unroll_vec'.
    kwargs: dict
        extra arguments of the policy
    """
    # Thread tag layouts for 'bind_gpu', tried from the most to the fewest bound axes.
    # A layout is used when there are at least as many axes as tags.
    _GPU_LAYOUTS = (
        ['blockIdx.z', 'blockIdx.y', 'blockIdx.x',
         'threadIdx.z', 'threadIdx.y', 'threadIdx.x'],
        ['blockIdx.y', 'blockIdx.x',
         'threadIdx.y', 'threadIdx.x'],
        ['blockIdx.x', 'threadIdx.x'],
    )
    # Same for 'bind_gpu_virtual'.
    _GPU_VIRTUAL_LAYOUTS = (
        ['blockIdx.z', 'blockIdx.y', 'blockIdx.x',
         'vthread', 'vthread', 'vthread',
         'threadIdx.z', 'threadIdx.y', 'threadIdx.x'],
        ['blockIdx.y', 'blockIdx.x',
         'vthread', 'vthread',
         'threadIdx.y', 'threadIdx.x'],
        ['blockIdx.x', 'vthread', 'threadIdx.x'],
    )

    def __init__(self, axes, policy, **kwargs):
        super(AnnotateSpace, self).__init__()

        self.ins = axes
        self.policy = policy
        self.num_output = len(axes)
        self.num_axis = len(axes)

        if policy == 'bind_gpu':
            self.entities.append(self._bind_entity(self._GPU_LAYOUTS, policy))
        elif policy == 'bind_gpu_virtual':
            self.entities.append(self._bind_entity(self._GPU_VIRTUAL_LAYOUTS, policy))
        elif policy == 'locate_cache':
            num_anchor = kwargs["num_anchor"]
            # plain Python ints (not numpy integers) keep the entities JSON serializable
            self.anns = list(itertools.combinations(range(self.num_axis), num_anchor))
            self.entities = [AnnotateEntity(x) for x in self.anns]
        else:  # none, vec, unroll, try_vec, try_unroll, try_vec_unroll, ...
            anns = policy.replace('try', 'none').split('_')

            for ann in anns:
                if ann not in ['none', 'unroll', 'vec']:
                    raise RuntimeError("Invalid policy: " + policy)

            self.anns = [anns] * self.num_axis
            self._generate_space(0, [""] * self.num_axis)

    def _bind_entity(self, layouts, policy):
        """Build the binding entity from the first layout that fits `num_axis`.

        Leading axes that exceed the layout are annotated with 'fuse'.
        Raises RuntimeError when there are fewer axes than the smallest layout.
        """
        for tags in layouts:
            extra = self.num_axis - len(tags)
            if extra >= 0:
                return AnnotateEntity(['fuse'] * extra + tags)
        raise RuntimeError("Unhandled case in " + policy)

    def _generate_space(self, now, tmp_stack):
        """Generate space by DFS; axis `now` is annotated next."""
        if now == self.num_axis:
            # keep only assignments that vectorize at most one axis
            if tmp_stack.count('vec') <= 1:
                self.entities.append(AnnotateEntity(list(tmp_stack)))
        else:
            for ann in self.anns[now]:
                tmp_stack[now] = ann
                self._generate_space(now + 1, tmp_stack)

    @staticmethod
    def get_num_output(axes, policy, **kwargs):
        """Annotation keeps the number of axes."""
        return len(axes)

    def __repr__(self):
        return "Annotate(policy=%s) len=%d" % (self.policy, len(self))
|
||||
|
||||
|
||||
class AnnotateEntity(object):
    """One concrete annotation operation that can be applied to axes.

    Parameters
    ----------
    anns: Array of string
        The annotations of axes
    """
    def __init__(self, anns):
        self.anns = anns

    def apply(self, sch, op, axes, axis_lens=None,
              max_unroll=None, vec_size=None, cfg=None, source=None):
        """Apply the annotation to an array of axes.

        Parameters
        ----------
        sch: tvm.schedule.Schedule
            The tvm schedule
        op: tvm.tensor.Operation
            The stage to be applied
        axes: Array of tvm.schedule.IterVar
            axes to annotate
        axis_lens: Array of int, optional
            the length of axes
        max_unroll: int, optional
            maximum unroll step
        vec_size: Array of int, optional
            valid vector lanes for vectorization
        cfg: ConfigEntity, optional
            cfg for recording error
        source: Array of Array tensor, optional
            source tensor for attaching cache

        Returns
        -------
        axes : list of tvm.schedule.IterVar
            The transformed axes
        """
        if source is not None:
            # special case: self.anns holds axis positions to attach cache_read/cache_write at
            for tensors, position in zip(source, self.anns):
                for t in tensors:
                    sch[t].compute_at(sch[op], axes[position])
            return axes

        thread_tags = ('blockIdx.x', 'blockIdx.y', 'blockIdx.z',
                       'threadIdx.x', 'threadIdx.y', 'threadIdx.z',
                       'vthread')
        stage = None
        for i, ann in enumerate(self.anns):
            if ann == 'none':
                continue
            if ann == 'unroll':
                if max_unroll and axis_lens[i] > max_unroll:
                    cfg.raise_error("Too large factor for unrolling")
                sch[op].unroll(axes[i])
            elif ann == 'vec':
                if vec_size and axis_lens[i] not in vec_size:
                    cfg.raise_error("Wrong size of lanes in vectorization")
                sch[op].vectorize(axes[i])
            elif ann in thread_tags:
                sch[op].bind(axes[i], thread_axis(ann))
            elif ann == 'fuse':
                assert i < len(axes) - 1
                # the fused axis replaces the next axis in place
                axes[i+1] = sch[op].fuse(axes[i], axes[i+1])
            else:
                raise RuntimeError("Invalid annotation " + ann)
        del stage
        return axes

    def __repr__(self):
        return str(self.anns)
|
||||
|
||||
|
||||
class OtherOptionSpace(TransformSpace):
    """The parameter space for a general option.

    Only kwargs['candidate'] is used: every candidate value becomes one
    OtherOptionEntity. `axes` and `policy` are accepted for a uniform signature.
    """
    def __init__(self, axes, policy, **kwargs):
        super(OtherOptionSpace, self).__init__()

        self.entities = [OtherOptionEntity(value) for value in kwargs["candidate"]]

    @staticmethod
    def get_num_output(axes, policy, **kwargs):
        """A general option produces no axes."""
        return 0

    def __repr__(self):
        summary = (self.entities, len(self))
        return "OtherOption(%s) len=%d" % summary
|
||||
|
||||
|
||||
class OtherOptionEntity(object):
    """The parameter entity for a general option.

    Parameters
    ----------
    val: object
        The concrete value chosen for the option
    """
    def __init__(self, val):
        self.val = val

    def __repr__(self):
        # printed exactly like the wrapped value
        return str(self.val)
|
||||
|
||||
|
||||
class ConfigSpace(object):
    """The configuration space of a schedule.

    Pass it as config in a template to collect the transformation spaces and
    build the transform graph of axes.
    """
    def __init__(self):
        # private dict to provide sugar
        self.space_map = OrderedDict()    # name -> space
        self._collect = True              # True: calls to define_* record new spaces
        self._length = None               # cached result of __len__
        self._entity_map = OrderedDict()  # name -> entity
        self._constraints = []
        self.errors = []
        self.template_key = None
        self.code_hash = None
        self.flop = 0

    @staticmethod
    def axis(var):
        """get a virtual axis (axis placeholder)

        Parameters
        ----------
        var: int or tvm.schedule.IterVar
            If is int, return an axis whose length is the provided argument.
            If is IterVar, return an axis whose length is extracted from the
            IterVar's extent domain.
        """
        return VirtualAxis(var)

    reduce_axis = axis

    def define_split(self, name, axis, policy='all', **kwargs):
        """Define a new tunable knob which splits an axis into a list of axes

        Parameters
        ----------
        name: str
            name to index the entity of this space
        axis: tvm.schedule.IterVar
            axis to split
        policy: str
            name of policy.
            If is 'all', the tuner will try all divisible factors.
            If is 'candidate', try listed candidate.
        kwargs: dict
            extra arguments for policy
        """
        axes = [axis]
        return self._add_new_transform(SplitSpace, name, axes, policy, **kwargs)

    def define_reorder(self, name, axes, policy, **kwargs):
        """Define a new tunable knob which reorders a list of axes

        Parameters
        ----------
        name: str
            name to index the entity of this space
        axes: Array of tvm.schedule.IterVar
            axes to reorder
        policy: str
            name of policy
            If is 'identity', do an identity permutation.
            If is 'all', try all permutations.
            If is 'interval_all', try all permutations of an interval of axes.
            If is 'candidate', try listed candidate.
            If is 'interleave', interleave chains of spatial axes and chains of reduction axes.
        kwargs: dict
            extra arguments for policy
        """
        return self._add_new_transform(ReorderSpace, name, axes, policy, **kwargs)

    def define_annotate(self, name, axes, policy, **kwargs):
        """Define a new tunable knob which annotates a list of axes

        Parameters
        ----------
        name: str
            name to index the entity of this space
        axes: Array of tvm.schedule.IterVar
            axes to annotate
        policy: str
            name of policy
            If is 'unroll', unroll the axes.
            If is 'try_unroll', try to unroll the axes.
            If is 'try_unroll_vec', try to unroll or vectorize the axes.
            If is 'bind_gpu', bind the first few axes to gpu threads.
            If is 'locate_cache', choose n axes to attach shared/local cache.
        kwargs: dict
            extra arguments for policy
        """
        return self._add_new_transform(AnnotateSpace, name, axes, policy, **kwargs)

    def define_knob(self, name, candidate):
        """Define a tunable knob with a list of candidates

        Parameters
        ----------
        name: str
            name key of that option
        candidate: list
            list of candidates
        """
        return self._add_new_transform(OtherOptionSpace, name, [], None, candidate=candidate)

    def add_flop(self, flop):
        """Add float operation statistics for this tuning task

        Parameters
        ----------
        flop: int or float
            number of float operations
        """
        self.flop += flop

    def raise_error(self, msg):
        """Register an error in the config.

        Use this to actively detect errors when scheduling.
        Otherwise these errors will occur during runtime, which
        will cost more time.

        Parameters
        ----------
        msg: str
            Description of the error; it is appended to `self.errors`
        """
        self.errors.append(msg)

    def valid(self):
        """Check whether the config meets all the constraints

        Note: This check should be called after instantiation of task,
        because the ConfigEntity/ConfigSpace collects errors during instantiation

        Returns
        -------
        valid: bool
            True when no error has been registered
        """
        return not self.errors

    def _add_new_transform(self, space_class, name, axes, policy, **kwargs):
        """Add a new transform space in template

        In collecting mode the space is created and registered under `name`,
        and its first entity becomes the current entity. Otherwise only
        placeholder axes with the right count are returned.
        """
        if self._collect:
            # convert schedule axis to space definition axis
            axes = [x if isinstance(x, (VirtualAxis, Axis)) else self.axis(x) for x in axes]

            # add subspace (knob)
            space = space_class(axes, policy, **kwargs)
            self.space_map[name] = space
            self._entity_map[name] = space[0]
            return [Axis(space, i) for i in range(space.num_output)]
        return [Axis(None, i) for i in range(space_class.get_num_output(axes, policy, **kwargs))]

    def __len__(self):
        """Return the number of configs: the product of the sizes of all knobs.

        The value is computed once and cached.
        """
        if self._length is None:
            # exact Python integer product: a fixed-width product would wrap
            # around silently for very large spaces
            total = 1
            for space in self.space_map.values():
                total *= len(space)
            self._length = total
        return self._length

    def get(self, index):
        """Get a config entity with detailed parameters from this space

        Parameters
        ----------
        index: int
            index in the space

        Returns
        -------
        config: ConfigEntity
            The config whose knobs are decoded from `index`
        """
        entities = OrderedDict()
        t = index
        # mixed-radix decoding: the first knob is the least significant digit
        for name, space in self.space_map.items():
            entities[name] = space[t % len(space)]
            t //= len(space)
        ret = ConfigEntity(index, self.code_hash, self.template_key, entities, self._constraints)
        return ret

    def __iter__(self):
        """Iterate over the names of the knobs."""
        return self._entity_map.__iter__()

    def __getitem__(self, name):
        """get the transform entity(knob) of this entity by name
           do not use this to get a ConfigEntity of this space (should use ConfigSpace.get instead)

        Parameters
        ----------
        name: str
            name of the transform
        """
        return self._entity_map[name]

    def __repr__(self):
        # call __len__ directly: len() rejects values beyond the platform index size
        res = "ConfigSpace (len=%d, space_map=\n" % self.__len__()
        for i, (name, space) in enumerate(self.space_map.items()):
            res += "  %2d %s: %s\n" % (i, name, space)
        return res + ")"
|
||||
|
||||
|
||||
# index of every annotation in the one-hot encoding of get_flatten_feature
_ann_to_number = {
    'none': 0, 'vec': 1, 'unroll': 2,
    'blockIdx.x': 3, 'blockIdx.y': 4, 'blockIdx.z': 5,
    'threadIdx.x': 6, 'threadIdx.y': 7, 'threadIdx.z': 8,
    'vthread': 9, 'fuse': 10
}


class ConfigEntity(ConfigSpace):
    """A configuration with detailed parameters

    Parameters
    ----------
    index: int
        index of this config in space
    code_hash: str
        hash of schedule code
    template_key : str
        The specific template key
    entity_map: dict
        map name to transform entity
    constraints : list
        List of constraints
    """
    def __init__(self, index, code_hash, template_key, entity_map, constraints):
        super(ConfigEntity, self).__init__()
        self.index = index
        self.template_key = template_key
        # an entity only replays recorded choices, it does not collect new spaces
        self._collect = False
        self._entity_map = entity_map
        self._space_map = None
        self._constraints = constraints
        self.code_hash = code_hash

    def get_flatten_feature(self):
        """ flatten entities to a numerical one-dimensional feature vector

        Returns
        -------
        fea: np.array
            one dimensional float32 array
        """
        fea = []
        for _, v in self._entity_map.items():
            if isinstance(v, SplitEntity):
                fea.extend(v.size)
            elif isinstance(v, ReorderEntity):
                # use a naive way: directly copy the permutation
                fea.extend(v.perm)
            elif isinstance(v, AnnotateEntity):
                # one-hot encoding
                # NOTE(review): annotations that are not keys of _ann_to_number
                # raise KeyError here -- confirm this is never reached with such entities
                for ann in v.anns:
                    tmp = [0] * len(_ann_to_number)
                    tmp[_ann_to_number[ann]] = 1
                    fea.extend(tmp)
            elif isinstance(v, OtherOptionEntity):
                fea.append(v.val)
        return np.array(fea, dtype=np.float32)

    def get_other_option(self):
        """
        Returns
        -------
        other_option: dict
            other tunable parameters (tunable parameters defined by `cfg.define_knob`)
        """
        # NOTE(review): the keys are the entity objects themselves, not the knob
        # names -- kept as is for compatibility, confirm whether names were intended
        return {x: x.val for x in self._entity_map.values() if isinstance(x, OtherOptionEntity)}

    def to_json_dict(self):
        """convert to a json serializable dictionary

        Return
        ------
        json_dict: dict
            a json serializable dictionary

        Raises
        ------
        RuntimeError
            If the config contains an entity of an unknown class.
        """
        ret = {}
        ret['i'] = int(self.index)
        ret['t'] = self.template_key
        ret['c'] = self.code_hash
        entity_map = []
        for k, v in self._entity_map.items():
            if isinstance(v, SplitEntity):
                entity_map.append((k, 'sp', v.size))
            elif isinstance(v, ReorderEntity):
                entity_map.append((k, 're', v.perm))
            elif isinstance(v, AnnotateEntity):
                entity_map.append((k, 'an', v.anns))
            elif isinstance(v, OtherOptionEntity):
                entity_map.append((k, 'ot', v.val))
            else:
                # str() is required: concatenating the object itself raises TypeError
                raise RuntimeError("Invalid entity instance: " + str(v))
        ret['e'] = entity_map
        return ret

    @staticmethod
    def from_json_dict(json_dict):
        """Build a ConfigEntity from json serializable dictionary

        Parameters
        ----------
        json_dict: dict
            Json serializable dictionary. This should be the return value
            of :any:`to_json_dict`.

        Returns
        -------
        config: ConfigEntity
            The corresponding config object

        Raises
        ------
        RuntimeError
            If an item has an unknown knob type.
        """
        index = json_dict["i"]
        code_hash = json_dict["c"]
        template_key = json_dict["t"]
        constraints = []
        entity_map = OrderedDict()

        for item in json_dict["e"]:
            key, knob_type, knob_args = item
            if knob_type == 'sp':
                entity = SplitEntity(knob_args)
            elif knob_type == 're':
                entity = ReorderEntity(knob_args)
            elif knob_type == 'an':
                entity = AnnotateEntity(knob_args)
            elif knob_type == 'ot':
                entity = OtherOptionEntity(knob_args)
            else:
                raise RuntimeError("Invalid config knob type: " + str(knob_type))
            entity_map[str(key)] = entity

        return ConfigEntity(index, code_hash, template_key, entity_map, constraints)

    def __repr__(self):
        # [12:-1] strips the 'OrderedDict(' prefix and the closing ')' of the map's repr
        return "%s,%s,%s,%d" % (str(self._entity_map)[12:-1], self.template_key,
                                self.code_hash, self.index)
|
|
@ -0,0 +1,359 @@
|
|||
# pylint: disable=unused-variable
|
||||
"""Definition of task function.
|
||||
|
||||
Task can be constructed from tuple of func, args, and kwargs.
|
||||
func is a state-less function, or a string that
|
||||
registers the standard task.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ... import tensor, expr, container, target as _target
|
||||
|
||||
from ..util import get_const_int, get_const_tuple, get_func_name
|
||||
from .dispatcher import DispatchContext, ApplyConfig, dispatcher
|
||||
from .space import ConfigSpace
|
||||
|
||||
def _raise_error(*args, **kwargs): # pylint: disable=unused-argument
|
||||
raise RuntimeError("The function of this task is not found. Possibly the function "
|
||||
"of this task is registered in another python file "
|
||||
"which is not imported in this run")
|
||||
|
||||
class Task(object):
    """A Tunable Task

    Parameters
    ----------
    name: str
        The name of the task.
    args: Tuple
        Positional argument of func
    """
    def __init__(self, name, args):
        self.name = name
        self.args = args
        self.kwargs = {}  # currently unused

        # the space is empty until it is filled in from outside
        self.config_space = None
        # unknown names fall back to a function that raises on call
        self.func = TASK_TABLE.get(name, _raise_error)

        # auxiliary info, available after `init_space` is called
        self.workload = None
        self.flop = None
        self.target = None
        self.target_host = None

    def instantiate(self, config):
        """Instantiate this task function (template) with a config.
        Returns corresponding schedule.

        Parameters
        ----------
        config: template.ConfigEntity
            parameter config for this template

        Returns
        -------
        sch: tvm.schedule.Schedule
            The tvm schedule
        arg_bufs: Array of tvm.tensor.Tensor
            The input/output buffers
        """
        config.flop = 0
        with ApplyConfig(config):
            sch, arg_bufs = self.func(*self.args, **self.kwargs)

        if self.flop:
            return sch, arg_bufs

        # flop of the task is still unknown: take the count the template
        # reported, or compute it from the schedule
        config.flop = config.flop or compute_flop(sch)
        self.flop = config.flop
        return sch, arg_bufs

    def __repr__(self):
        fields = (self.name, self.args, self.kwargs, self.workload)
        return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % fields
|
||||
|
||||
# registry of task functions: name -> callable
TASK_TABLE = {}


def register(name, func=None, override=False):
    """Register a task function.

    Can be called directly with `func`, or used as a decorator when `func`
    is omitted.

    Parameters
    ----------
    name : str
        The name to identify the task.
    func : callable
        The function to be registered.
    override : bool
        Whether override existing registration.

    Returns
    -------
    func: callable
        The registered function (or the decorator when `func` is omitted)

    Raises
    ------
    ValueError
        If `name` is already registered and `override` is false.
    """
    def _store(task_func):
        if name in TASK_TABLE and not override:
            raise ValueError("Key %s is already registered" % name)
        TASK_TABLE[name] = task_func
        return task_func

    return _store(func) if func else _store
|
||||
|
||||
def create(func_name, args, target, target_host=None, template_key=None):
    """Create a tuning task and initialize its search space

    Parameters
    ----------
    func_name : str or callable
        The task function. A callable is registered under its own name
        if that name is not registered yet.
    args : List
        Positional arguments
    target : Target
        The compilation target. A string is converted to a target first.
    target_host: Target, optional
        The compilation target for host side
    template_key: str, optional
        The template key stored in the config space ("" when omitted)

    Returns
    -------
    tsk: Task
        a task object
    """
    if callable(func_name):
        # register this function if it is not registered before
        func = func_name
        if hasattr(func, 'func_name'):
            func_name = func.func_name
        else:
            func_name = func.__name__

        if func_name not in TASK_TABLE:
            register(func_name, func=func)
        else:
            assert func == TASK_TABLE[func_name], "Find name conflict in task registration. " \
                                                  "Consider to choose another name for this task"

    func = TASK_TABLE[func_name]
    task = Task(func_name, args)

    if isinstance(target, str):
        target = _target.create(target)

    # init config space
    space = ConfigSpace()
    task.config_space = space
    space.template_key = template_key or ""

    # run the template once so that it declares its knobs into the space
    ctx = ApplyConfig(space)
    with ctx:
        with target:
            sch, _ = func(*args)
            space.code_hash = getattr(sch, 'code_hash', None)

    task.workload = ctx.workload
    task.flop = space.flop or compute_flop(sch)
    task.target = target
    task.target_host = target_host

    return task
|
||||
|
||||
def args_to_workload(x):
    """Convert argument list to hashable workload tuple.
    This function will convert list to tuple, tvm node to python value and
    flatten tvm.tensor.Tensor to a tuple

    Parameters
    ----------
    x: primitive hashable types or tensor.Tensor
        The original value

    Returns
    -------
    ret: hashable
        The hashable value

    Raises
    ------
    RuntimeError
        If `x` (or a nested element) has an unsupported type.
    """
    if isinstance(x, tensor.Tensor):
        # a tensor is described by its constant shape followed by its dtype
        return get_const_tuple(x.shape) + (x.dtype, )
    if isinstance(x, (tuple, list, container.Array)):
        return tuple(args_to_workload(a) for a in x)
    # NOTE: `np.int` / `np.float` were plain aliases of the builtins and were
    # removed in NumPy 1.24, so checking the builtins alone is equivalent and
    # keeps working on current NumPy releases.
    if isinstance(x, (str, int, float)):
        return x
    if isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
        return x.value
    if x is None:
        return None
    raise RuntimeError('Do not support type "%s" in argument. Consider to use '
                       'primitive types only' % type(x))
|
||||
|
||||
def template(func):
    """
    Turn a schedule function into a tunable schedule template.

    The function is wrapped in a dispatcher that is registered as a task under
    the function's name. Calling the dispatcher maps the positional arguments
    to a workload and runs the function under the matching config.

    Parameters
    ----------
    func: callable
        A callable template function.
        Its argument should be hashable values.
        Its return value should be a Tuple(Schedule, Array of Tensor)

    Returns
    -------
    func: callable
        The decorated function

    Examples
    --------
    The following code is a tunable template for a blocked matrix multiplication

    .. code-block:: python

        @autotvm.template
        def matmul(N, L, M, dtype):
            A = tvm.placeholder((N, L), name='A', dtype=dtype)
            B = tvm.placeholder((L, M), name='B', dtype=dtype)

            k = tvm.reduce_axis((0, L), name='k')
            C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
            s = tvm.create_schedule(C.op)

            # schedule
            y, x = s[C].op.axis
            k = s[C].op.reduce_axis[0]

            ##### define space begin #####
            cfg = autotvm.get_config()
            cfg.define_split("tile_y", y, num_outputs=2)
            cfg.define_split("tile_x", x, num_outputs=2)
            ##### define space end #####

            # schedule according to config
            yo, yi = cfg["tile_y"].apply(s, C, y)
            xo, xi = cfg["tile_x"].apply(s, C, x)

            s[C].reorder(yo, xo, k, yi, xi)

            return s, [A, B, C]
    """
    # pylint: disable=unused-variable

    fname = get_func_name(func)

    @register(fname)
    @dispatcher
    def config_dispatcher(*args, **kwargs):
        # the workload key is the template name followed by the hashable arguments
        assert not kwargs, "Do not support kwargs in template function call"
        workload = (fname, ) + args_to_workload(args)
        return workload

    @config_dispatcher.register("")
    def template_call(cfg, *args, **kwargs):
        # run the user function with `cfg` installed as the current config
        assert not kwargs, "Do not support kwargs in template function call"
        with ApplyConfig(cfg):
            result = func(*args, **kwargs)
        return result

    config_dispatcher.func_name = fname
    return config_dispatcher
|
||||
|
||||
def get_config():
    """Query the current dispatch context for the active config.

    Returns
    -------
    cfg: ConfigSpace or ConfigEntity
        The current config
    """
    context = DispatchContext.current
    return context.query(None, None)
|
||||
|
||||
class FlopCalculationError(RuntimeError):
    """Raised when the FLOP count of a compute op cannot be estimated."""
|
||||
|
||||
def compute_flop(sch):
    """Estimate the number of FLOP (floating number operations) of the compute ops in a schedule

    Starting from the schedule outputs, every compute op contributes
    (number of output elements) * (FLOP of its body expression), and the ops
    producing its input tensors are visited recursively.

    Parameters
    ----------
    sch: tvm.schedule.Schedule
        schedule

    Returns
    -------
    flop: int
        number of FLOP in this schedule

    Raises
    ------
    RuntimeError
        If the estimation fails or no floating point operation is found.
    """
    def _prod_length(axes):
        """compute product of the lengths of a list of axes"""
        try:
            extents = [get_const_int(axis.dom.extent) for axis in axes]
            num_iter = int(np.prod(extents))
        except ValueError:
            raise FlopCalculationError("The length of axis is not constant. ")
        return num_iter

    def _count_flop(exp):
        """compute flop for a single expression"""
        if isinstance(exp, expr.Reduce):
            num_iter = _prod_length(exp.axis)
            combiner = exp.combiner.result
            source = exp.source
            if len(combiner) != 1:
                raise FlopCalculationError("Found multiple output in the combiner of reduce op")
            if len(source) != 1:
                raise FlopCalculationError("Found multiple output in the source of reduce op")
            per_iter = _count_flop(combiner[0]) + _count_flop(source[0])
            return num_iter * per_iter

        if isinstance(exp, (expr.FloatImm, expr.IntImm, expr.UIntImm)):
            return 0

        if isinstance(exp, expr.Cast):
            return _count_flop(exp.value)

        if isinstance(exp, expr.Var):
            return 0

        if isinstance(exp, (expr.Add, expr.Sub, expr.Mul, expr.Div, expr.Mod,
                            expr.Max, expr.Min,
                            expr.EQ, expr.NE, expr.LT, expr.LE, expr.GT, expr.GE,
                            expr.And, expr.Or, expr.Not)):
            # only an operation on a float operand is counted as one FLOP
            base = 1 if "float" in exp.a.dtype else 0
            operand_flop = _count_flop(exp.a)
            if not isinstance(exp, expr.Not):  # binary: add the second operand
                operand_flop = operand_flop + _count_flop(exp.b)
            return base + operand_flop

        if isinstance(exp, expr.Select):
            # count the condition plus the more expensive branch
            branch_flop = max(_count_flop(exp.true_value),
                              _count_flop(exp.false_value))
            return _count_flop(exp.condition) + branch_flop

        if isinstance(exp, expr.Call):
            return sum([_count_flop(x) for x in exp.args])

        raise FlopCalculationError("Found unsupported operator in the compute expr")

    def traverse(ops):
        """accumulate flops"""
        total = 0
        for op in ops:
            if isinstance(op, tensor.ComputeOp):
                num_element = _prod_length(op.axis)

                body = op.body
                if len(body) != 1:
                    raise FlopCalculationError("Found multiple output in the compute")

                total += num_element * _count_flop(body[0])
                total += traverse([sch[t].op for t in op.input_tensors])
            elif not isinstance(op, tensor.PlaceholderOp):
                raise FlopCalculationError("Only support tvm.compute currently. "
                                           "Other ops like tvm.scan is not supported")
        return total

    try:
        ret = traverse(sch.outputs)
    except FlopCalculationError as exc:
        raise RuntimeError("FLOP estimator fails for this operator. Error msg: "
                           + str(exc) + ". Please use `cfg.add_flop` to manually set "
                                        "FLOP for this operator")

    if ret == 0:
        raise RuntimeError("Cannot find float number operation in this operator. "
                           "Please use `cfg.add_flop` to manually set "
                           "FLOP for this operator")

    return ret
|
|
@ -0,0 +1,14 @@
|
|||
"""
|
||||
A tuner takes a task as input. It proposes some promising :any:`ConfigEntity`
|
||||
in the :any:`ConfigSpace` and measures them on the real hardware. Then it
|
||||
proposes the next batch of :any:`ConfigEntity` according to the measure results.
|
||||
This tuning loop is repeated.
|
||||
"""
|
||||
|
||||
from . import callback
|
||||
|
||||
from .tuner import Tuner
|
||||
|
||||
from .gridsearch_tuner import GridSearchTuner, RandomTuner
|
||||
from .ga_tuner import GATuner
|
||||
from .xgboost_tuner import XGBTuner
|
|
@ -0,0 +1,112 @@
|
|||
# pylint: disable=consider-using-enumerate,invalid-name
|
||||
"""Namespace of callback utilities of AutoTVM"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .. import record
|
||||
|
||||
def log_to_file(file_out, protocol='json'):
    """Create a callback that appends tuning records to a file.
    Each record is written as one line produced by autotvm.record.encode.

    Parameters
    ----------
    file_out : File or str
        The file to log to. A string is treated as a path and opened in
        append mode on every callback invocation; anything else is used
        directly through its ``write`` method.
    protocol: str, optional
        The log protocol. Can be 'json' or 'pickle'

    Returns
    -------
    callback : callable
        Callback function to do the logging.
    """
    def _write_records(stream, inputs, results):
        """Encode every (input, result) pair and write it as one line"""
        for inp, result in zip(inputs, results):
            stream.write(record.encode(inp, result, protocol) + "\n")

    def _callback(_, inputs, results):
        """Callback implementation"""
        if not isinstance(file_out, str):
            _write_records(file_out, inputs, results)
            return
        with open(file_out, "a") as f:
            _write_records(f, inputs, results)

    return _callback
|
||||
|
||||
|
||||
def save_tuner_state(prefix, save_every_sample=100):
    """Save the state of tuner

    Parameters
    ----------
    prefix : str
        prefix of the filename to store state
    save_every_sample: int
        save the state every x samples

    Returns
    -------
    callback : function
        Callback function to do the auto saving.
        The state is written to ``prefix + "_<count>.state"`` where ``<count>``
        is ``len(tuner.visited)`` (0 if the tuner has no ``visited`` attribute).
    """
    def _callback(tuner, inputs, results):
        # Nothing was measured in this round, so there is nothing new to save.
        if not any(True for _ in zip(inputs, results)):
            return

        # The visited count does not change while the callback runs, so it is
        # evaluated once and the state is saved at most once per invocation
        # (rather than rewriting the same file for every measured pair).
        try:
            ct = len(tuner.visited)
        except AttributeError:
            ct = 0
        if ct % save_every_sample == 0:
            tuner.save_state(prefix + "_%d.state" % ct)

    return _callback
|
||||
|
||||
|
||||
def log_to_redis(host="localhost", port=6379, dbn=11):
    """Create a callback that stores every tuning record in a redis DB.

    The connection object is created once, when this function is called;
    the callback then issues one ``set(input, result)`` per measured pair.

    Parameters
    ----------
    host: str, optional
        Host address of redis db
    port: int, optional
        Port of redis db
    dbn: int, optional
        which redis db to use, default 11
    """
    # import here so only depend on redis when necessary
    import redis
    client = redis.StrictRedis(host=host, port=port, db=dbn)

    def _callback(_, inputs, results):
        """Callback implementation"""
        for key, value in zip(inputs, results):
            client.set(key, value)

    return _callback
|
||||
|
||||
class Monitor(object):
    """Tuning callback that records a score and a timestamp for every trial"""
    def __init__(self):
        self.scores = []
        self.timestamps = []

    def __call__(self, tuner, inputs, results):
        # a successful trial is scored by task flop / mean cost; a failed one scores 0
        for inp, res in zip(inputs, results):
            succeeded = res.error_no == 0
            score = inp.task.flop / np.mean(res.costs) if succeeded else 0
            self.scores.append(score)
            self.timestamps.append(res.timestamp)

    def reset(self):
        """drop all collected scores and timestamps"""
        self.scores = []
        self.timestamps = []

    def trial_scores(self):
        """get scores (currently is flops) of all trials"""
        return np.array(self.scores)

    def trial_timestamps(self):
        """get wall clock time stamp of all trials"""
        return np.array(self.timestamps)
|
|
@ -0,0 +1,119 @@
|
|||
# pylint: disable=consider-using-enumerate,invalid-name,abstract-method
|
||||
|
||||
"""Tuner with genetic algorithm"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .tuner import Tuner
|
||||
from .model_based_tuner import knob2point, point2knob
|
||||
|
||||
|
||||
class GATuner(Tuner):
    """Tuner with genetic algorithm.
    This tuner does not have a cost model so it always run measurement on real machines.
    This tuner expands the :code:`ConfigEntity` as gene.

    Parameters
    ----------
    pop_size: int
        number of genes in one generation
    elite_num: int
        number of elite to keep
    mutation_prob: float
        probability of mutation of a knob in a gene
    """
    def __init__(self, task, pop_size, elite_num=3, mutation_prob=0.1):
        super(GATuner, self).__init__(task)

        # algorithm configurations
        self.pop_size = pop_size
        self.elite_num = elite_num
        self.mutation_prob = mutation_prob

        assert elite_num <= pop_size, "The number of elites must be less than population size"

        # space info
        self.space = task.config_space
        self.dims = [len(x) for x in self.space.space_map.values()]

        # point-form indexes of all genes generated so far
        self.visited = set()

        # current generation
        self.genes = []
        self.scores = []
        self.elites = []
        self.elite_scores = []
        self.trial_pt = 0

        # random initialization
        self.pop_size = min(self.pop_size, len(self.space))
        # The population may have been shrunk to the size of the space above;
        # keep the elite count consistent with it, otherwise the elite
        # selection in `update` asks for more elements than exist.
        self.elite_num = min(self.pop_size, self.elite_num)
        for _ in range(self.pop_size):
            tmp_gene = point2knob(np.random.randint(len(self.space)), self.dims)
            while knob2point(tmp_gene, self.dims) in self.visited:
                tmp_gene = point2knob(np.random.randint(len(self.space)), self.dims)

            self.genes.append(tmp_gene)
            self.visited.add(knob2point(tmp_gene, self.dims))

    def next_batch(self, batch_size):
        ret = []
        for _ in range(batch_size):
            # walk through the current generation, wrapping around if needed
            gene = self.genes[self.trial_pt % self.pop_size]
            self.trial_pt += 1
            ret.append(self.space.get(knob2point(gene, self.dims)))

        return ret

    def update(self, inputs, results):
        # score of a trial: task flop / mean cost, or 0 for a failed measurement
        for inp, res in zip(inputs, results):
            if res.error_no == 0:
                y = inp.task.flop / np.mean(res.costs)
                self.scores.append(y)
            else:
                self.scores.append(0)

        # once the whole generation has been measured, breed the next one
        if len(self.scores) >= len(self.genes):
            genes = self.genes + self.elites
            # Force a float array: if every trial failed the list holds only
            # integer zeros and the in-place float arithmetic below would fail.
            scores = np.array(self.scores[:len(self.genes)] + self.elite_scores,
                              dtype='float64')

            # reserve elite
            self.elites, self.elite_scores = [], []
            elite_indexes = np.argpartition(scores, -self.elite_num)[-self.elite_num:]
            for ind in elite_indexes:
                self.elites.append(genes[ind])
                self.elite_scores.append(scores[ind])

            # cross over
            indices = np.arange(len(genes))
            # A small offset keeps the normalization well defined when all
            # scores are zero and gives every gene a non-zero probability, so
            # sampling two distinct parents cannot fail for lack of candidates.
            scores += 1e-8
            scores /= np.max(scores)
            probs = scores / np.sum(scores)
            tmp_genes = []
            for _ in range(self.pop_size):
                p1, p2 = np.random.choice(indices, size=2, replace=False, p=probs)
                p1, p2 = genes[p1], genes[p2]
                point = np.random.randint(len(self.dims))
                tmp_gene = p1[:point] + p2[point:]
                tmp_genes.append(tmp_gene)

            # mutation
            next_genes = []
            for tmp_gene in tmp_genes:
                for j, dim in enumerate(self.dims):
                    if np.random.random() < self.mutation_prob:
                        tmp_gene[j] = np.random.randint(dim)

                if len(self.visited) < len(self.space):
                    # keep perturbing single knobs until the gene is unseen
                    while knob2point(tmp_gene, self.dims) in self.visited:
                        j = np.random.randint(len(self.dims))
                        tmp_gene[j] = np.random.randint(self.dims[j])
                    next_genes.append(tmp_gene)
                    self.visited.add(knob2point(tmp_gene, self.dims))
                else:
                    # the whole space has been generated already
                    break

            self.genes = next_genes
            self.trial_pt = 0
            self.scores = []

    def has_next(self):
        # genes of the current generation that were generated (and therefore
        # marked visited) but not handed out yet still count as remaining work
        return len(self.visited) - (len(self.genes) - self.trial_pt) < len(self.space)
|
|
@ -0,0 +1,63 @@
|
|||
# pylint: disable=abstract-method
|
||||
"""Grid search tuner and random tuner"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .tuner import Tuner
|
||||
|
||||
|
||||
class GridSearchTuner(Tuner):
    """Walk through the config space sequentially, by increasing index"""
    def __init__(self, task):
        super(GridSearchTuner, self).__init__(task)
        # index of the next config to hand out
        self.counter = 0

    def next_batch(self, batch_size):
        space = self.task.config_space
        batch = []
        for _ in range(batch_size):
            if self.counter >= len(space):
                # space exhausted: return what has been collected so far
                break
            batch.append(space.get(self.counter))
            self.counter += 1
        return batch

    def has_next(self):
        return self.counter < len(self.task.config_space)

    def __getstate__(self):
        # only the position in the space is persisted
        return {"counter": self.counter}

    def __setstate__(self, state):
        self.counter = state['counter']
|
||||
|
||||
|
||||
class RandomTuner(Tuner):
    """Enumerate the search space in a random order"""
    def __init__(self, task):
        super(RandomTuner, self).__init__(task)
        # indexes of the configs that have already been handed out
        self.visited = set()

    def next_batch(self, batch_size):
        ret = []
        counter = 0
        while counter < batch_size:
            if len(self.visited) >= len(self.task.config_space):
                # every config has been proposed already
                break
            # rejection sampling: redraw until an unvisited index is found
            index = np.random.randint(len(self.task.config_space))
            while index in self.visited:
                index = np.random.randint(len(self.task.config_space))

            ret.append(self.task.config_space.get(index))
            self.visited.add(index)
            counter += 1
        return ret

    def has_next(self):
        return len(self.visited) < len(self.task.config_space)

    def __getstate__(self):
        # The tuner's only state is the visited set. (This class never defines
        # a `counter` attribute, so reading it here raised AttributeError.)
        return {"visited": self.visited}

    def __setstate__(self, state):
        # restore into the attribute that next_batch/has_next actually use
        self.visited = state['visited']
|
|
@ -0,0 +1,106 @@
|
|||
# pylint: disable=invalid-name
|
||||
"""Metrics for evaluating tuning process"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..util import get_rank
|
||||
|
||||
def max_curve(trial_scores):
    """Running maximum of the scores: f(n) = max([s[i] for i <= n])

    The running maximum starts from -1e9, so no returned value is below it.

    Parameters
    ----------
    trial_scores: Array of float
        the score of i th trial

    Returns
    -------
    curve: Array of float
        function values
    """
    curve = np.empty(len(trial_scores))
    best = -1e9
    for pos, score in enumerate(trial_scores):
        if score > best:
            best = score
        curve[pos] = best
    return curve
|
||||
|
||||
def mean_curve(trial_scores):
    """Running mean of the scores: f(n) = mean([s[i] for i <= n])

    Parameters
    ----------
    trial_scores: Array of float
        the score of i th trial

    Returns
    -------
    curve: Array of float
        function values
    """
    curve = np.empty(len(trial_scores))
    total = 0
    for count, score in enumerate(trial_scores, 1):
        total += score
        curve[count - 1] = total / count
    return curve
|
||||
|
||||
def recall_curve(trial_ranks, top=None):
    """
    Recall of the best-ranked configs as a function of the number of trials.

    Entry n is computed from the first n trials (``trial_ranks[:n]``):
    if top is None, it is the count of ranks <= n divided by (n + 1);
    if top is K, it is the count of ranks < K divided by K.

    Parameters
    ----------
    trial_ranks: Array of int
        the rank of i th trial in labels
    top: int or None
        top-n recall

    Returns
    -------
    curve: Array of float
        function values
    """
    if not isinstance(trial_ranks, np.ndarray):
        trial_ranks = np.array(trial_ranks)

    n_trials = len(trial_ranks)
    curve = np.zeros(n_trials)
    for n in range(n_trials):
        seen = trial_ranks[:n]
        if top is None:
            curve[n] = np.sum(seen <= n) / (n + 1)
        else:
            curve[n] = 1.0 * np.sum(seen < top) / top
    return curve
|
||||
|
||||
def cover_curve(trial_ranks):
    """
    Fraction of the leading ranks covered so far.

    Entry n is k / len(trial_ranks), where k is the largest number such that
    the ranks 0, 1, ..., k-1 all appear among the first n + 1 trials.

    Parameters
    ----------
    trial_ranks: Array of int
        the rank of i th trial in labels

    Returns
    -------
    curve: Array of float
        function values
    """
    total = len(trial_ranks)
    curve = np.empty(total)
    seen = set()
    n_covered = 0  # ranks 0 .. n_covered-1 have all been seen
    for pos, rank in enumerate(trial_ranks):
        seen.add(rank)
        while n_covered in seen:
            n_covered += 1
        curve[pos] = n_covered
    return curve / total
|
||||
|
||||
def average_recall(preds, labels, N):
    """evaluate average recall-n for predictions and labels

    Trials are ordered by decreasing prediction, the labels are ranked in that
    order, and the first N values of the resulting recall curve are averaged.
    """
    order = np.argsort(preds)[::-1]
    ranks = get_rank(labels[order])
    return np.sum(recall_curve(ranks)[:N]) / N
|
|
@ -0,0 +1,343 @@
|
|||
# pylint: disable=no-else-return,invalid-name,consider-using-enumerate,abstract-method
|
||||
"""Base class for model-based tuner
|
||||
This type of tuner will fit a cost model and use some optimization methods to
|
||||
find optimums points of cost model in space.
|
||||
"""
|
||||
import gc
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .tuner import Tuner
|
||||
|
||||
|
||||
class FeatureCache(object):
    """Holds one cache dictionary per feature type so cost models can share them"""
    def __init__(self):
        # feature type key -> cache dictionary
        self.feature_cache = {}

    def get(self, key):
        """Return the cache dictionary of a feature type, creating it on first use

        Parameters
        ----------
        key: str
            The key of a feature type

        Returns
        -------
        fea_cache: dict
            cache dictionary
        """
        return self.feature_cache.setdefault(key, {})

    def size(self, key):
        """Return the number of entries cached for a feature type

        Parameters
        ----------
        key: str
            The key of a feature type

        Returns
        -------
        n: int
            0 if no cache exists for `key`
        """
        if key in self.feature_cache:
            return len(self.feature_cache[key])
        return 0

    def clear(self, key):
        """Replace the cache of a feature type with an empty one

        Raises KeyError if no cache exists for `key`.

        Parameters
        ----------
        key: str
            The key of a feature type
        """
        self.feature_cache.pop(key)
        self.feature_cache[key] = {}
        # run a collection right away after dropping the old cache
        gc.collect()
|
||||
|
||||
|
||||
class CostModel(object):
    """Interface of a cost model that predicts the speed of a config.

    Every method raises NotImplementedError here; subclasses provide the
    actual behavior.
    """
    def __init__(self):
        """The base class keeps no state"""

    def fit(self, xs, ys, plan_size):
        """Train the model on measured samples

        Parameters
        ----------
        xs: Array of int
            indexes of configs in the config space
        ys: Array of float
            The speed (flop, float number operations per second)
        plan_size: int
            The plan size of tuner
        """
        raise NotImplementedError()

    def fit_log(self, records, plan_size):
        """Train the model on records read from a tuning log

        Parameters
        ----------
        records: Array of Tuple(MeasureInput, MeasureResult)
            The tuning records
        plan_size: int
            The plan size of tuner
        """
        raise NotImplementedError()

    def predict(self, xs, output_margin=False):
        """Predict the speed of configs

        Parameters
        ----------
        xs: Array of int
            The indexes of configs to predict
        output_margin: bool, optional
            Whether output the untransformed margin.
            When a model is used as base model, it should output untransformed margin

        Returns
        -------
        preds: Array of float
            The prediction
        """
        raise NotImplementedError()

    def load_basemodel(self, base_model):
        """Attach a base model for transfer learning

        Parameters
        ----------
        base_model: CostModel
            base model
        """
        raise NotImplementedError()

    def clone_new(self):
        """Create a fresh model with the same hyperparameters.
        Only the hyperparameters are meant to be copied, not the trained model.

        This is used for deriving a base model conveniently

        Returns
        -------
        model: CostModel
            A model with the same hyperparameter (argument)
        """
        raise NotImplementedError()
|
||||
|
||||
|
||||
class ModelOptimizer(object):
    """Interface of an optimizer that searches for the best points of a cost model"""
    def __init__(self):
        """The base class keeps no state"""

    def find_maximums(self, model, num, exclusive):
        """Find maximum of a cost model

        Note we use cost model to predict GFLOPS, so we should find the maximum

        This base implementation always raises NotImplementedError.

        Parameters
        ----------
        model: CostModel
            Cost model
        num: int
            The number of returned maximum points
        exclusive: set, optional
            The excluded set of this optimizer. Return results won't include any
            elements in this set.
        """
        raise NotImplementedError()
|
||||
|
||||
|
||||
class ModelBasedTuner(Tuner):
    """Base class for model based tuner
    This type of tuner will fit a cost model and use an optimizer to
    find the maximums of the cost model as next trials

    Parameters
    ----------
    task: autotvm.task.Task
        The tuning task
    cost_model: CostModel
        The cost model that predicts the speed of a config (IR)
    model_optimizer:
        The optimizer to find local optimum points of cost model in tuning search space
    plan_size: int
        Tuner will re-fit model per `plan_size` new measure samples
    diversity_filter_ratio: int or float, optional
        If is not None, the tuner will first select
        top-(plan_size * diversity_filter_ratio) candidates according to the cost model
        and then pick plan_size of them according to the diversity metric.
    """

    def __init__(self, task, cost_model, model_optimizer, plan_size, diversity_filter_ratio=None):
        super(ModelBasedTuner, self).__init__(task)

        # space
        self.task = task
        self.target = task.target
        self.plan_size = plan_size
        self.space = task.config_space
        self.space_len = len(task.config_space)
        self.dims = [len(x) for x in self.space.space_map.values()]

        self.cost_model = cost_model
        self.model_optimizer = model_optimizer
        self.diversity_filter_ratio = diversity_filter_ratio

        if self.diversity_filter_ratio:
            assert self.diversity_filter_ratio >= 1, "Diversity filter ratio " \
                                                     "must be larger than one"

        # trial plan
        self.trials = []
        self.trial_pt = 0
        self.visited = set()

        # observed samples
        self.xs = []
        self.ys = []
        self.flops_max = 0.0
        self.train_ct = 0

    def next_batch(self, batch_size):
        ret = []

        counter = 0
        while counter < batch_size:
            if len(self.visited) >= len(self.space):
                break

            # advance to the first planned trial that has not been visited yet
            while self.trial_pt < len(self.trials):
                index = self.trials[self.trial_pt]
                if index not in self.visited:
                    break
                self.trial_pt += 1

            if self.trial_pt >= len(self.trials):  # trial list is empty, choose randomly
                index = np.random.randint(len(self.space))
                while index in self.visited:
                    index = np.random.randint(len(self.space))

            ret.append(self.space.get(index))
            self.visited.add(index)

            counter += 1
        return ret

    def update(self, inputs, results):
        # record every sample; a failed measurement is stored with speed 0
        for inp, res in zip(inputs, results):
            index = inp.config.index
            if res.error_no == 0:
                self.xs.append(index)
                flops = inp.task.flop / np.mean(res.costs)
                self.flops_max = max(self.flops_max, flops)
                self.ys.append(flops)
            else:
                self.xs.append(index)
                self.ys.append(0)

        # if we have enough new training samples
        if len(self.xs) >= self.plan_size * (self.train_ct + 1) \
                and self.flops_max > 1e-6:
            self.cost_model.fit(self.xs, self.ys, self.plan_size)
            if self.diversity_filter_ratio:
                # The ratio may be a float, while the optimizer needs an
                # integer number of points to return.
                num_candidates = int(self.plan_size * self.diversity_filter_ratio)
                candidate = self.model_optimizer.find_maximums(
                    self.cost_model, num_candidates, self.visited)
                scores = self.cost_model.predict(candidate)
                knobs = [point2knob(x, self.dims) for x in candidate]
                # scores are zeroed, so the pick is driven by knob diversity only
                pick_index = submodular_pick(0 * scores, knobs, self.plan_size, knob_weight=1)
                maximums = np.array(candidate)[pick_index]
            else:
                maximums = self.model_optimizer.find_maximums(
                    self.cost_model, self.plan_size, self.visited)

            self.trials = maximums
            self.trial_pt = 0
            self.train_ct += 1

    def load_history(self, data_set):
        base_model = self.cost_model.clone_new()
        base_model.fit_log(data_set, self.plan_size)

        # `self.trials` can be a numpy array after `update`, whose truth value
        # is ambiguous, so test its length explicitly.
        if len(self.trials) == 0:
            # no plan yet, use base model to select initial trials
            # (find_maximums takes the number of points as its second argument)
            maximums = self.model_optimizer.find_maximums(
                base_model, self.plan_size, self.visited)
            self.trials = maximums
            self.trial_pt = 0

        self.cost_model.load_basemodel(base_model)

    def has_next(self):
        return len(self.visited) < len(self.space)
|
||||
|
||||
|
||||
def point2knob(p, dims):
    """convert point form (single integer) to knob form (vector)

    The point is decomposed as a mixed-radix number with the first entry of
    `dims` as the least significant digit.
    """
    knob = []
    remainder = p
    for dim in dims:
        remainder, digit = divmod(remainder, dim)
        knob.append(digit)
    return knob
|
||||
|
||||
|
||||
def knob2point(knob, dims):
    """convert knob form (vector) to point form (single integer)

    Entry j of `knob` is weighted by the product of the first j entries of `dims`.
    """
    return sum(int(np.prod(dims[:j])) * k for j, k in enumerate(knob))
|
||||
|
||||
|
||||
def submodular_pick(scores, knobs, n_pick, knob_weight=1.0):
    """Run greedy optimization to pick points with regard to both score and diversity.
    DiversityScore = knob_weight * number of unique knobs in the selected set
    Obj = sum(scores[i] for i in pick) + DiversityScore
    Note that this objective function is a monotone submodular function.

    Parameters
    ----------
    scores: Array of float
        score of every points
    knobs: Array of Array of int
        feature vector (tunable knobs) of every points
    n_pick: int
        number of points to pick
    knob_weight: float
        weight of an unique knob feature

    Returns
    -------
    ret: List of int
        Indexes of the picked points in picking order. At most
        ``len(scores)`` indexes are returned, even if `n_pick` is larger.
    """
    n = len(scores)
    assert n == len(knobs)
    if n == 0:
        # nothing to pick from
        return []
    n_knobs = len(knobs[0])

    # knobs_set[i] holds the values of knob i used by the points picked so far
    knobs_set = [set() for _ in range(n_knobs)]

    ret = []
    remain = list(range(n))

    # Never try to pick more points than exist; once `remain` is empty there
    # is no valid candidate left.
    for _ in range(min(n_pick, n)):
        max_x = -1
        max_delta = -1e9

        for x in remain:
            # marginal gain: own score plus a bonus for every knob value
            # that is not yet present in the selected set
            tmp_delta = scores[x]
            for i in range(n_knobs):
                if knobs[x][i] not in knobs_set[i]:
                    tmp_delta += knob_weight

            if tmp_delta > max_delta:
                max_delta, max_x = tmp_delta, x

        ret.append(max_x)
        remain.remove(max_x)
        for i in range(n_knobs):
            knobs_set[i].add(knobs[max_x][i])

    return ret
|
|
@ -0,0 +1,148 @@
|
|||
# pylint: disable=consider-using-enumerate
|
||||
"""
|
||||
Cost model optimizer based on simulated annealing
|
||||
"""
|
||||
|
||||
import heapq
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..util import sample_ints
|
||||
from .model_based_tuner import ModelOptimizer, knob2point, point2knob
|
||||
|
||||
class SimulatedAnnealingOptimizer(ModelOptimizer):
    """Parallel simulated annealing optimization algorithm.

    A batch of points is annealed in parallel against a cost model, and the
    best-scoring distinct points seen during the walk are collected.

    Parameters
    ----------
    task: Task
        The tuning task
    n_iter: int
        The number of iterations of simulated annealing
    temp: float or Array of float
        If is a single float, then use a constant temperature.
        If is an Array, then perform linear cooling from temp[0] to temp[1]
    persistent: bool, optional
        If True, the points reached at the end of one `find_maximums` call
        are stored and reused as the starting points of the next call.
    parallel_size: int, optional
        The number of points annealed in parallel
    early_stop: int, optional
        Stop iteration if the optimal set do not change in `early_stop` rounds
    verbose: int, optional
        Print log every `verbose` iterations
    """
    def __init__(self, task, n_iter=500, temp=(1, 0), persistent=True, parallel_size=128,
                 early_stop=30, verbose=50):
        super(SimulatedAnnealingOptimizer, self).__init__()

        self.task = task
        # number of candidate values of every knob, in space_map order
        self.dims = [len(x) for x in self.task.config_space.space_map.values()]

        self.n_iter = n_iter
        self.temp = temp
        self.persistent = persistent
        self.parallel_size = parallel_size
        self.early_stop = early_stop
        self.verbose = verbose
        # walker positions kept between calls when `persistent` is set
        self.points = None

    def find_maximums(self, model, num, exclusive):
        """Search for the points with the highest predicted score.

        Parameters
        ----------
        model: object with a `predict(points)` method
            The cost model; `predict` is called with an array of point indexes
            and its result is used as an array of scores (higher is better).
        num: int
            The maximum number of points to return
        exclusive: collection of int
            Point indexes that must not be returned

        Returns
        -------
        maximums: list of int
            Up to `num` distinct point indexes, sorted by descending score.
            Fewer than `num` are returned when the walk did not find enough
            admissible points.
        """
        tic = time.time()
        temp, n_iter, early_stop, verbose = self.temp, self.n_iter, self.early_stop, self.verbose

        if self.persistent and self.points is not None:
            points = self.points
        else:
            points = np.array(sample_ints(0, len(self.task.config_space), self.parallel_size))

        scores = model.predict(points)

        # Build the heap with `num` placeholders. Placeholder ids are strictly
        # negative (-1 - i): using -i would make the first placeholder equal to
        # the valid point index 0, which would wrongly block point 0 from
        # entering the heap and, once popped, drop 0 from the exclusion set.
        heap_items = [(float('-inf'), -1 - i) for i in range(num)]
        heapq.heapify(heap_items)
        in_heap = set(exclusive)
        in_heap.update(item[1] for item in heap_items)

        for s, p in zip(scores, points):
            if s > heap_items[0][0] and p not in in_heap:
                pop = heapq.heapreplace(heap_items, (s, p))
                in_heap.remove(pop[1])
                in_heap.add(p)

        k = 0
        k_last_modify = 0

        if isinstance(temp, (tuple, list, np.ndarray)):
            # linear cooling; the (n_iter + 1) divisor keeps t above temp[1]
            # for all n_iter iterations
            t = temp[0]
            cool = 1.0 * (temp[0] - temp[1]) / (n_iter + 1)
        else:
            t = temp
            cool = 0

        while k < n_iter and k < k_last_modify + early_stop:
            new_points = np.empty_like(points)
            for i, p in enumerate(points):
                new_points[i] = random_walk(p, self.dims)

            new_scores = model.predict(new_points)

            # Metropolis acceptance: better moves always accepted (prob >= 1),
            # worse moves accepted with probability exp(delta / t)
            ac_prob = np.exp((new_scores - scores) / t)
            ac_index = np.random.random(len(ac_prob)) < ac_prob

            points[ac_index] = new_points[ac_index]
            scores[ac_index] = new_scores[ac_index]

            # candidates enter the top-`num` heap regardless of acceptance
            for s, p in zip(new_scores, new_points):
                if s > heap_items[0][0] and p not in in_heap:
                    pop = heapq.heapreplace(heap_items, (s, p))
                    in_heap.remove(pop[1])
                    in_heap.add(p)
                    k_last_modify = k

            k += 1
            t -= cool

            if verbose >= 1 and k % verbose == 0:
                t_str = "%.2f" % t
                logging.info("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\ttemp: %s\t"
                             "elapsed: %.2f",
                             k, k_last_modify, heap_items[0][0],
                             np.max([v for v, _ in heap_items]), t_str,
                             time.time() - tic)

        heap_items.sort(key=lambda item: -item[0])
        if verbose:
            logging.info("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\telapsed: %.2f",
                         k, k_last_modify, heap_items[-1][0], heap_items[0][0], time.time() - tic)
            logging.info("SA Maximums: %s", heap_items)

        if self.persistent:
            self.points = points

        # drop placeholders that were never replaced; they are not point indexes
        return [x[1] for x in heap_items if x[1] >= 0]
|
||||
|
||||
def random_walk(p, dims):
    """Random walk as local transition.

    Picks random knobs and assigns them random values until the resulting
    knob vector differs from the current one.

    Parameters
    ----------
    p: int
        index of the ConfigEntity
    dims: Array of int
        sizes of each dimension

    Returns
    -------
    new_p: int
        new neighborhood index. If no dimension offers more than one choice
        there is no neighbor, and `p` itself is returned.
    """
    # Without at least one knob with two or more values the mutation loop below
    # could never produce a different point and would spin forever.
    if all(dim <= 1 for dim in dims):
        return p

    # transform to knob form
    old = point2knob(p, dims)
    new = list(old)

    # mutate
    while new == old:
        from_i = np.random.randint(len(old))
        to_v = np.random.randint(dims[from_i])
        new[from_i] = to_v

    # transform to index form
    return knob2point(new, dims)
|
|
@ -0,0 +1,138 @@
|
|||
# pylint: disable=unused-argument, no-self-use, invalid-name
|
||||
"""Base class of tuner"""
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..measure import MeasureInput
|
||||
from ..measure import create_measure_batch
|
||||
|
||||
|
||||
class Tuner(object):
    """Base class for tuners

    Subclasses provide `has_next` and `next_batch` (and optionally `update`
    and `load_history`); `tune` drives the measure loop.

    Parameters
    ----------
    task: autotvm.task.Task
        Tuning Task
    """

    def __init__(self, task, **kwargs):
        self.param = kwargs
        self.recorder = None

        self.task = task

        # keep the current best
        self.best_config = None
        self.best_flops = 0
        self.best_measure_pair = None

    def has_next(self):
        """Whether has next untried config in the space

        Returns
        -------
        has_next: bool
        """
        raise NotImplementedError()

    def next_batch(self, batch_size):
        """get the next batch of configs to be measure on real hardware

        Parameters
        ----------
        batch_size: int
            The size of the batch

        Returns
        -------
        a batch of configs
        """
        raise NotImplementedError()

    def update(self, inputs, results):
        """Update parameters of the tuner according to measurement results

        The base implementation does nothing.

        Parameters
        ----------
        inputs: Array of autotvm.measure.MeasureInput
            The input for measurement
        results: Array of autotvm.measure.MeasureResult
            result for measurement
        """
        pass

    def tune(self, n_trial, measure_option, verbose=1, callbacks=()):
        """Begin tuning

        Parameters
        ----------
        n_trial: int
            Maximum number of configs to try (measure on real hardware)
        measure_option: dict
            The options for how to measure generated code.
            You should use the return value of autotvm.measure_option for this argument.
        verbose: int
            0: silent mode, no output
            1: print every measurement result
        callbacks: List of callable
            A list of callback functions. The signature of callback function is
            (Tuner, List of MeasureInput, List of MeasureResult)
            with no return value. These callback functions will be called on
            every measurement pair. See autotvm/tuner/callback.py for some examples.
        """
        measure_batch = create_measure_batch(self.task, measure_option)
        # batch size falls back to 1 when the measure function does not say
        parallel_num = getattr(measure_batch, 'parallel_num', 1)

        i = 0
        while i < n_trial:
            if not self.has_next():
                break

            configs = self.next_batch(min(parallel_num, n_trial - i))

            inputs = [MeasureInput(self.task.target, self.task, config) for config in configs]
            results = measure_batch(inputs)

            # Track the best result on every batch. This must not depend on
            # `verbose`: silent tuning still has to record best_config.
            for k, (inp, res) in enumerate(zip(inputs, results)):
                config = inp.config
                if res.error_no == 0:
                    flops = inp.task.flop / np.mean(res.costs)
                else:
                    # failed measurements count as zero throughput
                    flops = 0
                if flops > self.best_flops:
                    self.best_flops = flops
                    self.best_config = config
                    self.best_measure_pair = (inp, res)

                # print info
                if verbose >= 1:
                    logging.info("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
                                 i + k + 1, flops / 1e9, self.best_flops / 1e9,
                                 res, config)

            i += len(results)

            self.update(inputs, results)

            for callback in callbacks:
                callback(self, inputs, results)

        del measure_batch

    def reset(self):
        """reset the status of tuner"""
        self.best_config = None
        self.best_flops = 0
        self.best_measure_pair = None

    def load_history(self, data_set):
        """load history data for transfer learning

        Parameters
        ----------
        data_set: Array of (MeasureInput, MeasureResult) pair
            Previous tuning records
        """
        raise NotImplementedError()
|
|
@ -0,0 +1,482 @@
|
|||
# pylint: disable=invalid-name
|
||||
"""XGBoost as cost model"""
|
||||
|
||||
import multiprocessing
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
try:
|
||||
import xgboost as xgb
|
||||
except ImportError:
|
||||
xgb = None
|
||||
|
||||
from .. import feature
|
||||
from ..util import get_rank
|
||||
from .metric import max_curve, recall_curve, cover_curve
|
||||
from .model_based_tuner import CostModel, FeatureCache
|
||||
|
||||
class XGBoostCostModel(CostModel):
    """XGBoost as cost model

    Parameters
    ----------
    task: Task
        The tuning task
    feature_type: str, optional
        If is 'itervar', use features extracted from IterVar (loop variable).
        If is 'knob', use flatten ConfigEntity directly.
        If is 'curve', use sampled curve feature (relation feature).

        Note on choosing feature type:
        For single task tuning, 'itervar' and 'knob' is good.
                                'itervar' is more accurate but 'knob' is much faster.
        For cross-shape tuning (e.g. many convolutions with different shapes),
                               'itervar' and 'curve' has better transferability,
                               'knob' is faster.
        For cross-device or cross-operator tuning, you can use 'curve' only.
    loss_type: str
        If is 'reg', use regression loss to train cost model.
                     The cost model predicts the normalized flops.
        If is 'rank', use pairwise rank loss to train cost model.
                     The cost model predicts relative rank score.
    num_threads: int, optional
        The number of threads.
    verbose: int, optional
        If is not none, the cost model will print training log every `verbose` iterations.
    """
    def __init__(self, task, feature_type, loss_type, num_threads=None, verbose=20):
        super(XGBoostCostModel, self).__init__()

        if xgb is None:
            raise RuntimeError("XGBoost is required for XGBoostCostModel. "
                               "Please install its python package first. "
                               "Help: (https://xgboost.readthedocs.io/en/latest/) ")

        self.task = task
        self.target = task.target
        self.space = task.config_space

        self.fea_type = feature_type
        self.loss_type = loss_type
        self.num_threads = num_threads
        self.verbose = verbose

        # the two loss types share every booster parameter but the objective
        if loss_type == 'reg':
            objective = 'reg:linear'
        elif loss_type == 'rank':
            objective = 'rank:pairwise'
        else:
            raise RuntimeError("Invalid loss type: " + loss_type)

        self.xgb_params = {
            'max_depth': 3,
            'gamma': 0.0001,
            'min_child_weight': 1,

            'subsample': 1.0,

            'eta': 0.3,
            'lambda': 1.00,
            'alpha': 0,

            'objective': objective,
        }

        self.xgb_params['silent'] = 1
        if num_threads:
            self.xgb_params['nthread'] = num_threads
        self.bst = None

        if feature_type == 'itervar':
            self.feature_extract_func = _extract_itervar_feature_index
        elif feature_type == 'knob':
            self.feature_extract_func = _extract_knob_feature_index
        elif feature_type == 'curve':
            self.feature_extract_func = _extract_curve_feature_index
        else:
            raise RuntimeError("Invalid feature type " + feature_type)

        self.feature_cache = FeatureCache()
        self.feature_extra_ct = 0
        self.pool = None
        self.base_model = None

        self._reset_pool()

    def _reset_pool(self):
        """(Re)create the process pool used for feature extraction."""
        # reset processing pool for feature extraction
        if self.pool:
            self.pool.terminate()
            self.pool.join()
            del self.pool
        # use global variable to pass common arguments
        # (they are set before the pool is created, and read by the
        # module-level _extract_*_index functions)
        global _extract_space, _extract_target, _extract_task
        _extract_space = self.space
        _extract_target = self.target
        _extract_task = self.task
        self.pool = multiprocessing.Pool(self.num_threads)

    def fit(self, xs, ys, plan_size):
        """Train the booster on measured config indexes.

        Parameters
        ----------
        xs: Array of int
            indexes of the measured configs in the config space
        ys: Array of float
            measured scores of these configs (higher is better)
        plan_size: int
            the N of the 'a-recall@N' metric used for early stopping
        """
        tic = time.time()
        self._reset_pool()

        x_train = self._get_feature(xs)
        y_train = np.array(ys)
        # Normalize to [0, 1]. Out-of-place division also handles integer
        # labels, and the floor avoids 0/0 when every measurement failed.
        y_train = y_train / max(np.max(y_train), 1e-8)

        valid_index = y_train > 1e-6
        index = np.random.permutation(len(x_train))
        dtrain = xgb.DMatrix(x_train[index], y_train[index])

        if self.base_model:
            # rows of dtrain are shuffled by `index`, so the per-row base
            # margin has to be shuffled the same way
            base_margin = np.asarray(self.base_model.predict(xs, output_margin=True))
            dtrain.set_base_margin(base_margin[index])

        self.bst = xgb.train(self.xgb_params, dtrain,
                             num_boost_round=8000,
                             callbacks=[custom_callback(
                                 stopping_rounds=20,
                                 metric='tr-a-recall@%d' % plan_size,
                                 evals=[(dtrain, 'tr')],
                                 maximize=True,
                                 fevals=[
                                     xgb_average_recalln_curve_score(plan_size),
                                 ],
                                 verbose_eval=self.verbose)])

        logging.info("train: %.2f\tobs: %d\terror: %d\tn_cache: %d",
                     time.time() - tic, len(xs),
                     len(xs) - np.sum(valid_index),
                     self.feature_cache.size(self.fea_type))

    def fit_log(self, records, plan_size):
        """Train the booster on (MeasureInput, MeasureResult) records.

        Parameters
        ----------
        records: iterable of (MeasureInput, MeasureResult) pair
            tuning records; features and labels are extracted from them
        plan_size: int
            the early-stopping metric is 'a-recall@(2 * plan_size)'
        """
        tic = time.time()
        self._reset_pool()

        args = list(records)
        if self.fea_type == 'itervar':
            feature_extract_func = _extract_itervar_feature_log
        elif self.fea_type == 'knob':
            feature_extract_func = _extract_knob_feature_log
        elif self.fea_type == 'curve':
            feature_extract_func = _extract_curve_feature_log
        else:
            raise RuntimeError("Invalid feature type: " + self.fea_type)
        res = self.pool.map(feature_extract_func, args)
        xs, ys = zip(*res)
        xs, ys = np.array(xs), np.array(ys)

        x_train = xs
        # same normalization as in `fit`: safe for integer / all-zero labels
        y_train = ys / max(np.max(ys), 1e-8)

        index = np.random.permutation(len(x_train))
        dtrain = xgb.DMatrix(x_train[index], y_train[index])

        plan_size *= 2
        self.bst = xgb.train(self.xgb_params, dtrain,
                             num_boost_round=200,
                             callbacks=[custom_callback(
                                 stopping_rounds=100,
                                 metric='tr-a-recall@%d' % plan_size,
                                 evals=[(dtrain, 'tr')],
                                 maximize=True,
                                 fevals=[
                                     xgb_average_recalln_curve_score(plan_size),
                                 ],
                                 verbose_eval=self.verbose)])

        logging.info("train: %.2f\tobs: %d", time.time() - tic, len(xs))

    def predict(self, xs, output_margin=False):
        """Predict scores for config indexes `xs` with the trained booster."""
        feas = self._get_feature(xs)
        dtest = xgb.DMatrix(feas)

        if self.base_model:
            dtest.set_base_margin(self.base_model.predict(xs, output_margin=True))

        return self.bst.predict(dtest, output_margin=output_margin)

    def load_basemodel(self, base_model):
        """Set the model whose margin output is used as base margin."""
        self.base_model = base_model

    def clone_new(self):
        """Create an untrained cost model with the same construction arguments."""
        return XGBoostCostModel(self.task, self.fea_type, self.loss_type,
                                self.num_threads, self.verbose)

    def _get_feature(self, indexes):
        """get features for indexes, run extraction if we do not have cache for them"""
        # free feature cache
        if self.feature_cache.size(self.fea_type) >= 100000:
            self.feature_cache.clear(self.fea_type)

        fea_cache = self.feature_cache.get(self.fea_type)

        indexes = np.array(indexes)
        need_extract = [x for x in indexes if x not in fea_cache]

        if need_extract:
            feas = self.pool.map(self.feature_extract_func, need_extract)
            for i, fea in zip(need_extract, feas):
                fea_cache[i] = fea

        # the feature length is taken from the first entry
        ret = np.empty((len(indexes), fea_cache[indexes[0]].shape[-1]), dtype=np.float32)
        for i, ii in enumerate(indexes):
            ret[i, :] = fea_cache[ii]
        return ret
|
||||
|
||||
|
||||
_extract_space = None
|
||||
_extract_target = None
|
||||
_extract_task = None
|
||||
|
||||
def _extract_itervar_feature_index(index):
    """Build the itervar feature vector for one index of `_extract_space`.

    The config is instantiated under `_extract_target`; its flattened
    loop-variable features (log scale) are followed by the values of the
    config's other options.
    """
    cfg = _extract_space.get(index)
    with _extract_target:
        schedule, tensors = _extract_task.instantiate(cfg)
    itervar_fea = feature.get_itervar_feature_flatten(schedule, tensors, take_log=True)
    other_options = list(cfg.get_other_option().values())
    return np.concatenate((itervar_fea, other_options))
|
||||
|
||||
def _extract_itervar_feature_log(arg):
    """Turn one (input, result) log item into an (itervar feature, score) pair.

    The score is `task.flop / mean(costs)` for a successful measurement
    (`error_no == 0`) and 0 otherwise.
    """
    measure_input, measure_result = arg
    cfg = measure_input.config
    with measure_input.target:
        schedule, tensors = measure_input.task.instantiate(cfg)
    itervar_fea = feature.get_itervar_feature_flatten(schedule, tensors, take_log=True)
    x = np.concatenate((itervar_fea, list(cfg.get_other_option().values())))

    if measure_result.error_no != 0:
        return x, 0
    return x, measure_input.task.flop / np.mean(measure_result.costs)
|
||||
|
||||
def _extract_knob_feature_index(index):
    """Return the flattened knob feature of the config at `index` in `_extract_space`."""
    return _extract_space.get(index).get_flatten_feature()
|
||||
|
||||
def _extract_knob_feature_log(arg):
    """Turn one (input, result) log item into a (knob feature, score) pair.

    The score is `task.flop / mean(costs)` for a successful measurement
    (`error_no == 0`) and 0 otherwise.
    """
    measure_input, measure_result = arg
    cfg = measure_input.config
    x = cfg.get_flatten_feature()

    if measure_result.error_no != 0:
        return x, 0

    # necessary, for calculating flops of this task
    with measure_input.target:
        measure_input.task.instantiate(cfg)
    return x, measure_input.task.flop / np.mean(measure_result.costs)
|
||||
|
||||
def _extract_curve_feature_index(index):
    """Build the sampled-curve feature vector for one index of `_extract_space`.

    The config is instantiated under `_extract_target`; its flattened buffer
    curve samples (20 per curve) are followed by the values of the config's
    other options.
    """
    cfg = _extract_space.get(index)
    with _extract_target:
        schedule, tensors = _extract_task.instantiate(cfg)
    curve_fea = feature.get_buffer_curve_sample_flatten(schedule, tensors, sample_n=20)
    other_options = list(cfg.get_other_option().values())
    combined = np.concatenate((curve_fea, other_options))
    return np.array(combined)
|
||||
|
||||
def _extract_curve_feature_log(arg):
    """Turn one (input, result) log item into a (curve feature, score) pair.

    The score is `task.flop / mean(costs)` for a successful measurement
    (`error_no == 0`) and 0 otherwise.
    """
    measure_input, measure_result = arg
    cfg = measure_input.config
    with measure_input.target:
        schedule, tensors = measure_input.task.instantiate(cfg)
    curve_fea = feature.get_buffer_curve_sample_flatten(schedule, tensors, sample_n=20)
    x = np.concatenate((curve_fea, list(cfg.get_other_option().values())))

    if measure_result.error_no != 0:
        return x, 0
    return x, measure_input.task.flop / np.mean(measure_result.costs)
|
||||
|
||||
|
||||
def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
                    save_file="xgb_checkpoint", save_every=None,
                    maximize=False, verbose_eval=True):
    """callback function for xgboost to support multiple custom evaluation functions

    Parameters
    ----------
    stopping_rounds: int
        Stop training when `metric` has not improved for this many iterations
    metric: str
        Full name of the metric used for early stopping, e.g. 'tr-a-recall@32'.
        It must be produced by one of `fevals` on one of `evals`.
    fevals: list of callable
        Custom evaluation functions passed to the booster's `eval_set`
    evals: list of (DMatrix, str), optional
        Data sets to evaluate on
    log_file: str, optional
        If set, the evaluation line of every iteration is appended to this file
    save_file: str, optional
        Filename prefix for periodic checkpoints
    save_every: int, optional
        If set, save the model every `save_every` iterations
    maximize: bool, optional
        Whether a larger `metric` is better
    verbose_eval: bool or int, optional
        If is a non-zero int, log the evaluation line every `verbose_eval`
        iterations. None, 0 or a bool disable the periodic log.

    Returns
    -------
    callback: callable
        A callback taking the xgboost callback environment
    """
    from xgboost.core import EarlyStopException
    from xgboost.callback import _fmt_metric
    from xgboost.training import aggcv

    state = {}
    metric_shortname = metric.split("-")[1]

    def init(env):
        """internal function"""
        bst = env.model

        state['maximize_score'] = maximize
        state['best_iteration'] = 0
        if maximize:
            state['best_score'] = float('-inf')
        else:
            state['best_score'] = float('inf')

        if bst is not None:
            if bst.attr('best_score') is not None:
                # resume from the values stored in the booster attributes
                state['best_score'] = float(bst.attr('best_score'))
                state['best_iteration'] = int(bst.attr('best_iteration'))
                state['best_msg'] = bst.attr('best_msg')
            else:
                bst.set_attr(best_iteration=str(state['best_iteration']))
                bst.set_attr(best_score=str(state['best_score']))
        else:
            assert env.cvfolds is not None

    def callback(env):
        """internal function"""
        if not state:
            init(env)

        bst = env.model
        i = env.iteration
        cvfolds = env.cvfolds

        res_dict = {}

        ##### evaluation #####
        if cvfolds is not None:
            for feval in fevals:
                tmp = aggcv([f.eval(i, feval) for f in cvfolds])
                for k, mean, std in tmp:
                    res_dict[k] = [mean, std]
        else:
            for feval in fevals:
                bst_eval = bst.eval_set(evals, i, feval)
                res = [x.split(':') for x in bst_eval.split()]
                # the first token is the iteration tag, not a key:value pair
                for kv in res[1:]:
                    res_dict[kv[0]] = [float(kv[1])]

        eval_res = []
        keys = list(res_dict.keys())
        keys.sort(key=lambda x: x if metric_shortname not in x else "a" + x)
        for key in keys:
            v = res_dict[key]
            eval_res.append([key] + v)

        ##### print eval result #####
        infos = ["XGB iter: %3d" % i]
        for item in eval_res:
            if 'null' in item[0]:
                continue
            infos.append("%s: %.6f" % (item[0], item[1]))

        # `verbose_eval` may be None or 0 (periodic log disabled); guard the
        # modulo so those values do not raise
        if not isinstance(verbose_eval, bool) and verbose_eval and i % verbose_eval == 0:
            logging.info("\t".join(infos))
        if log_file:
            with open(log_file, "a") as fout:
                fout.write("\t".join(infos) + '\n')

        ##### save model #####
        if save_every and i % save_every == 0:
            filename = save_file + ".%05d.bst" % i
            logging.info("save model to %s ...", filename)
            bst.save_model(filename)

        ##### choose score and do early stopping #####
        score = None
        for item in eval_res:
            if item[0] == metric:
                score = item[1]
                break
        assert score is not None

        best_score = state['best_score']
        best_iteration = state['best_iteration']
        maximize_score = state['maximize_score']
        if (maximize_score and score > best_score) or \
                (not maximize_score and score < best_score):
            msg = '[%d] %s' % (
                env.iteration,
                '\t'.join([_fmt_metric(x) for x in eval_res]))
            state['best_msg'] = msg
            state['best_score'] = score
            state['best_iteration'] = env.iteration
            # save the property to attributes, so they will occur in checkpoint.
            if env.model is not None:
                env.model.set_attr(best_score=str(state['best_score']),
                                   best_iteration=str(state['best_iteration']),
                                   best_msg=state['best_msg'])
        elif env.iteration - best_iteration >= stopping_rounds:
            best_msg = state['best_msg']
            if verbose_eval and env.rank == 0:
                logging.info("Stopping. Best iteration: %s ", best_msg)
            raise EarlyStopException(best_iteration)

    return callback
|
||||
|
||||
|
||||
# feval wrapper for xgboost
|
||||
def xgb_max_curve_score(N):
    """Create an xgboost feval reporting the max curve score at position N.

    The returned function yields ("Smax@N", value), where value is the max
    curve of the labels ordered by descending prediction, taken at N and
    divided by the largest label.
    """
    def feval(preds, labels):
        truth = labels.get_label()
        by_pred_desc = np.argsort(preds)[::-1]
        curve = max_curve(truth[by_pred_desc])
        return "Smax@%d" % N, curve[N] / np.max(truth)
    return feval
|
||||
|
||||
def xgb_recalln_curve_score(N):
    """Create an xgboost feval reporting the recall-n curve at position N.

    The returned function yields ("recall@N", value) computed from the ranks
    of the labels ordered by descending prediction.
    """
    def feval(preds, labels):
        truth = labels.get_label()
        by_pred_desc = np.argsort(preds)[::-1]
        curve = recall_curve(get_rank(truth[by_pred_desc]))
        return "recall@%d" % N, curve[N]
    return feval
|
||||
|
||||
def xgb_average_recalln_curve_score(N):
    """Create an xgboost feval reporting the average recall-n over the first N points.

    The returned function yields ("a-recall@N", value), where value is the sum
    of the first N entries of the recall curve divided by N.
    """
    def feval(preds, labels):
        truth = labels.get_label()
        by_pred_desc = np.argsort(preds)[::-1]
        curve = recall_curve(get_rank(truth[by_pred_desc]))
        average = np.sum(curve[:N]) / N
        return "a-recall@%d" % N, average
    return feval
|
||||
|
||||
def xgb_recallk_curve_score(N, topk):
    """Create an xgboost feval reporting the recall-k curve at position N.

    The returned function yields ("recall@topk", value); the curve is computed
    with `topk` and read at index N.
    """
    def feval(preds, labels):
        truth = labels.get_label()
        by_pred_desc = np.argsort(preds)[::-1]
        curve = recall_curve(get_rank(truth[by_pred_desc]), topk)
        return "recall@%d" % topk, curve[N]
    return feval
|
||||
|
||||
def xgb_cover_curve_score(N):
    """Create an xgboost feval reporting the cover curve at position N.

    The returned function yields ("cover@N", value) computed from the ranks
    of the labels ordered by descending prediction.
    """
    def feval(preds, labels):
        truth = labels.get_label()
        by_pred_desc = np.argsort(preds)[::-1]
        curve = cover_curve(get_rank(truth[by_pred_desc]))
        return "cover@%d" % N, curve[N]
    return feval
|
||||
|
||||
def xgb_null_score(_):
    """Create a placeholder xgboost feval that always reports ("null", 0)."""
    def feval(_preds, _labels):
        name, value = "null", 0
        return name, value
    return feval
|
|
@ -0,0 +1,59 @@
|
|||
"""Tuner that uses xgboost as cost model"""
|
||||
|
||||
from .model_based_tuner import ModelBasedTuner, ModelOptimizer
|
||||
from .xgboost_cost_model import XGBoostCostModel
|
||||
from .sa_model_optimizer import SimulatedAnnealingOptimizer
|
||||
|
||||
class XGBTuner(ModelBasedTuner):
    """Tuner that uses xgboost as cost model

    Parameters
    ----------
    task: Task
        The tuning task
    plan_size: int
        The size of a plan. After `plan_size` trials, the tuner will refit a new cost model
        and do planing for the next `plan_size` trials.
    feature_type: str, optional
        If is 'itervar', use features extracted from IterVar (loop variable).
        If is 'knob', use flatten ConfigEntity directly.
        If is 'curve', use sampled curve feature (relation feature).

        Note on choosing feature type:
        For single task tuning, 'itervar' and 'knob' is good.
                                'itervar' is more accurate but 'knob' is much faster.
        For cross-shape tuning (e.g. many convolutions with different shapes),
                               'itervar' and 'curve' has better transferability,
                               'knob' is faster.
        For cross-device or cross-operator tuning, you can use 'curve' only.
    loss_type: str
        If is 'reg', use regression loss to train cost model.
                     The cost model predicts the normalized flops.
        If is 'rank', use pairwise rank loss to train cost model.
                     The cost model predicts relative rank score.
    num_threads: int, optional
        The number of threads.
    optimizer: str or ModelOptimizer, optional
        If is 'sa', use a default simulated annealing optimizer.
        Otherwise it should be a ModelOptimizer object.
    diversity_filter_ratio: int or float, optional
        If is not None, the tuner will first select
        top-(plan_size * diversity_filter_ratio) candidates according to the cost model
        and then pick batch_size of them according to the diversity metric.
    """
    def __init__(self, task, plan_size=32,
                 feature_type='itervar', loss_type='rank', num_threads=None,
                 optimizer='sa', diversity_filter_ratio=None):
        cost_model = XGBoostCostModel(task,
                                      feature_type=feature_type,
                                      loss_type=loss_type,
                                      num_threads=num_threads)
        if optimizer == 'sa':
            optimizer = SimulatedAnnealingOptimizer(task)
        else:
            # message pieces are joined by implicit concatenation, so each
            # piece needs its own trailing space
            assert isinstance(optimizer, ModelOptimizer), "Optimizer must be " \
                                                          "a supported name string " \
                                                          "or a ModelOptimizer object."

        super(XGBTuner, self).__init__(task, cost_model, optimizer,
                                       plan_size, diversity_filter_ratio)
|
|
@ -0,0 +1,149 @@
|
|||
# pylint: disable=invalid-name
|
||||
"""Utilities"""
|
||||
import logging
|
||||
import multiprocessing
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .. import expr, ir_pass
|
||||
|
||||
def get_rank(values):
    """Compute the descending rank of every item.

    Parameters
    ----------
    values: Array

    Returns
    -------
    ranks: Array of int
        ranks[i] is the position of values[i] when the input is sorted from
        largest to smallest (the largest value has rank 0)
    """
    order = np.argsort(-values)
    positions = np.arange(len(order))
    ranks = np.empty_like(order)
    # scatter: the item that sorts to position j receives rank j
    ranks[order] = positions
    return ranks
|
||||
|
||||
|
||||
def sample_ints(low, high, m):
    """Draw m distinct integers uniformly from [low, high).

    Rejection sampling is used instead of `np.random.choice`, which does not
    work when (high - low) > 2 ^ 32.

    Parameters
    ----------
    low: int
        low point of sample range
    high: int
        high point of sample range
    m: int
        The number of sampled int

    Returns
    -------
    ints: a list of size m
    """
    assert m <= high - low
    chosen = set()
    while len(chosen) < m:
        candidate = np.random.randint(low, high)
        if candidate in chosen:
            # already drawn, try again
            continue
        chosen.add(candidate)

    return list(chosen)
|
||||
|
||||
|
||||
def pool_map(func, args, batch_size, verbose=False, pool=None):
    """A wrapper of multiprocessing.pool.Pool.map to support small-batch mapping
    for large argument list. This can reduce memory usage

    Parameters
    ----------
    func: Func(arg) -> np.ndarray
        mapping function
    args: List
        list of arguments
    batch_size: int
        batch size in mapping
    verbose: bool, optional
        whether print progress
    pool: multiprocessing.Pool, optional
        pool object to use. If not given, a temporary pool is created and
        shut down before returning.

    Returns
    -------
    converted numpy array, or None if `args` is empty
    """
    tic = time.time()
    owns_pool = not pool
    local_pool = pool or multiprocessing.Pool()
    chunks = []

    try:
        if verbose:
            logging.info("mapping begin")
        for i in range(0, len(args), batch_size):
            if verbose:
                logging.info("mapping %d/%d elapsed %.2f", i, len(args),
                             time.time() - tic)
            chunks.append(np.array(local_pool.map(func, args[i:i+batch_size])))
        if verbose:
            logging.info("mapping done")
    finally:
        # Only shut down a pool created here, and do it even when `func`
        # raises so the worker processes are not leaked. All map calls are
        # blocking, so no result is pending when terminate() is reached.
        if owns_pool:
            local_pool.terminate()
            local_pool.join()

    if not chunks:
        return None
    # one concatenate at the end instead of re-copying the result every batch
    return chunks[0] if len(chunks) == 1 else np.concatenate(chunks)
|
||||
|
||||
def get_func_name(func):
    """Look up the name of a function.

    Parameters
    ----------
    func: Function
        The function

    Returns
    -------
    name: str
        `func.func_name` when that attribute exists, otherwise `func.__name__`
    """
    if hasattr(func, 'func_name'):
        return func.func_name
    return func.__name__
|
||||
|
||||
|
||||
def get_const_int(exp):
    """Verifies expr is integer and get the constant value.

    Parameters
    ----------
    exp : tvm.Expr or int
        The input expression.

    Returns
    -------
    out_value : int
        The output.

    Raises
    ------
    ValueError
        If the expression is not a constant integer after simplification.
    """
    if isinstance(exp, int):
        return exp
    if not isinstance(exp, (expr.IntImm, expr.UIntImm)):
        # simplify the argument itself; `expr` is the imported module
        exp = ir_pass.Simplify(exp)
    if not isinstance(exp, (expr.IntImm, expr.UIntImm)):
        raise ValueError("Expect value to be constant int")
    return exp.value
|
||||
|
||||
|
||||
def get_const_tuple(in_tuple):
|
||||
"""Verifies input tuple is IntImm, returns tuple of int.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_tuple : tuple of Expr
|
||||
The input.
|
||||
|
||||
Returns
|
||||
-------
|
||||
out_tuple : tuple of int
|
||||
The output.
|
||||
"""
|
||||
return tuple(get_const_int(x) for x in in_tuple)
|
|
@ -229,8 +229,14 @@ class TrackerSession(object):
|
|||
res += "----------------------------\n"
|
||||
res += "key\tfree\tpending\n"
|
||||
res += "----------------------------\n"
|
||||
for k, v in data["queue_info"].items():
|
||||
res += "%s\t%d\t%g\n" % (k, v["free"], v["pending"])
|
||||
queue_info = data['queue_info']
|
||||
keys = list(queue_info.keys())
|
||||
if keys:
|
||||
keys.sort()
|
||||
max_key_len = max([len(k) for k in keys])
|
||||
for k in keys:
|
||||
res += ("%%-%d" % max_key_len + "s\t%d\t%g\n") % \
|
||||
(k, queue_info[k]["free"], queue_info[k]["pending"])
|
||||
res += "----------------------------\n"
|
||||
return res
|
||||
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file feature_visitor.cc
|
||||
* \brief Base class for feature extractor.
|
||||
* These features are used for machine learning cost model
|
||||
*/
|
||||
|
||||
#include "feature_visitor.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace autotvm {
|
||||
|
||||
// for loop
|
||||
void FeatureVisitor::Visit_(const For *op) {
|
||||
const auto *extent = op->extent.as<IntImm>();
|
||||
int64_t loop_extent = -1;
|
||||
if (extent != nullptr)
|
||||
loop_extent = extent->value;
|
||||
AnnotationType ann = kSerial;
|
||||
switch (op->for_type) {
|
||||
case ForType ::Parallel:
|
||||
ann = kParallel;
|
||||
break;
|
||||
case ForType::Unrolled:
|
||||
ann = kUnrolled;
|
||||
break;
|
||||
case ForType::Vectorized:
|
||||
ann = kVectorized;
|
||||
break;
|
||||
case ForType::Serial:
|
||||
ann = kSerial;
|
||||
break;
|
||||
}
|
||||
|
||||
if (EnterItervar_(op->loop_var, loop_extent, ann)) {
|
||||
IRVisitor::Visit_(op);
|
||||
ExitItervar_();
|
||||
}
|
||||
}
|
||||
|
||||
// parallel axis, virtual thread
|
||||
void FeatureVisitor::Visit_(const AttrStmt *op) {
|
||||
if (op->attr_key == attr::thread_extent ||
|
||||
op->attr_key == attr::virtual_thread) {
|
||||
VarExpr var = op->node.as<tvm::IterVarNode>()->var;
|
||||
const auto *extent = op->value.as<IntImm>();
|
||||
CHECK(extent);
|
||||
|
||||
std::string name = var.get()->name_hint;
|
||||
AnnotationType ann = kParallel;
|
||||
if (op->attr_key == attr::thread_extent) {
|
||||
if (name == "blockIdx.x")
|
||||
ann = kBlockX;
|
||||
else if (name == "blockIdx.y")
|
||||
ann = kBlockY;
|
||||
else if (name == "blockIdx.z")
|
||||
ann = kBlockZ;
|
||||
else if (name == "threadIdx.x")
|
||||
ann = kThreadX;
|
||||
else if (name == "threadIdx.y")
|
||||
ann = kThreadY;
|
||||
else if (name == "threadIdx.z")
|
||||
ann = kThreadZ;
|
||||
else
|
||||
LOG(FATAL) << "invalid thread itervar " + name;
|
||||
} else {
|
||||
ann = kVirtualThread;
|
||||
}
|
||||
|
||||
if (EnterItervar_(var, extent->value, ann)) {
|
||||
IRVisitor::Visit_(op);
|
||||
ExitItervar_();
|
||||
}
|
||||
} else {
|
||||
IRVisitor::Visit_(op);
|
||||
}
|
||||
}
|
||||
|
||||
// memory access
|
||||
/*!
 * \brief Visit a load and report it to the subclass as a memory access.
 *
 * EnterMem_ receives the buffer variable and index expression of the load,
 * then the children are traversed by the default visitor, and finally
 * ExitMem_ is called.
 */
void FeatureVisitor::Visit_(const Load *op) {
  EnterMem_(op->buffer_var, op->index);
  IRVisitor::Visit_(op);
  ExitMem_();
}
|
||||
|
||||
/*!
 * \brief Visit a store and report it to the subclass as a memory access.
 *
 * EnterMem_ receives the buffer variable and index expression of the store,
 * then the children are traversed by the default visitor, and finally
 * ExitMem_ is called.
 */
void FeatureVisitor::Visit_(const Store *op) {
  EnterMem_(op->buffer_var, op->index);
  IRVisitor::Visit_(op);
  ExitMem_();
}
|
||||
|
||||
} // namespace autotvm
|
||||
} // namespace tvm
|
|
@ -0,0 +1,67 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file feature_visitor.h
|
||||
* \brief Base class for feature extractor.
|
||||
* These features are used for machine learning cost model
|
||||
*/
|
||||
|
||||
#ifndef TVM_AUTOTVM_FEATURE_VISITOR_H_
|
||||
#define TVM_AUTOTVM_FEATURE_VISITOR_H_
|
||||
|
||||
#include <tvm/ir.h>
|
||||
#include <tvm/ir_visitor.h>
|
||||
#include <string>
|
||||
|
||||
namespace tvm {
|
||||
namespace autotvm {
|
||||
|
||||
using namespace tvm::ir;
|
||||
|
||||
/*!
 * \brief Type of for loop, used as one-hot encoding in features
 *
 * The enumerators are assigned by FeatureVisitor:
 *  - kBlockX/Y/Z and kThreadX/Y/Z come from `thread_extent` attributes whose
 *    variable is named blockIdx.{x,y,z} / threadIdx.{x,y,z};
 *  - kUnrolled, kVectorized, kParallel and kSerial come from the kind of a
 *    `For` loop;
 *  - kVirtualThread comes from a `virtual_thread` attribute.
 */
enum AnnotationType {
  kBlockX, kBlockY, kBlockZ, kThreadX, kThreadY, kThreadZ,
  kUnrolled, kVectorized, kParallel, kSerial, kVirtualThread,
  kNum,  // number of annotation types; not an annotation itself
};
|
||||
|
||||
/*!
|
||||
* \brief A base class for feature extractor, used for processing
|
||||
* for loop and memory access in the IR
|
||||
*/
|
||||
class FeatureVisitor : public IRVisitor {
 public:
  // for loop
  // Both visitors report an axis through EnterItervar_/ExitItervar_:
  // `For` nodes directly, `AttrStmt` nodes only for thread_extent / virtual_thread.
  void Visit_(const For *op);
  void Visit_(const AttrStmt *op);

  // memory access
  // Both visitors report the access through EnterMem_/ExitMem_.
  void Visit_(const Load *op);
  void Visit_(const Store *op);

 protected:
  /*!
  * \brief Enter a for loop node
  * \param var The loop variable (or the variable bound to a thread axis).
  * \param length The extent of the axis; -1 when a for loop's extent is not
  *               an integer immediate.
  * \param ann_type The type for the for loop
  * \return Whether the subtree under this node should be visited. When false
  *         is returned the subtree is skipped and ExitItervar_ is not called.
  */
  virtual bool EnterItervar_(tvm::VarExpr var, int64_t length, AnnotationType ann_type) = 0;
  /*! \brief Exit a for loop subtree */
  virtual void ExitItervar_() = 0;
  /*!
   * \brief Enter a memory access node
   * \param buffer_var The buffer to access.
   * \param index Index expression
   */
  virtual void EnterMem_(tvm::VarExpr buffer_var, tvm::Expr index) = 0;
  /*! \brief Exit a memory access node */
  virtual void ExitMem_() = 0;
};
|
||||
|
||||
} // namespace autotvm
|
||||
} // namespace tvm
|
||||
|
||||
#endif // TVM_AUTOTVM_FEATURE_VISITOR_H_
|
|
@ -0,0 +1,510 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file touch_extractor.cc
|
||||
* \brief Extract feature of touch pattern of axes in lowered IR
|
||||
*/
|
||||
|
||||
#include "touch_extractor.h"
|
||||
|
||||
#include <set>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
||||
namespace tvm {
|
||||
namespace autotvm {
|
||||
|
||||
/*!
 * \brief Rank an annotation by its level in the parallel hierarchy.
 * \param ann The annotation of an axis.
 * \return 2 for block axes, 1 for thread axes and parallel loops, 0 otherwise.
 */
int ParallelLevel(AnnotationType ann) {
  const bool is_block_axis = (ann == kBlockX || ann == kBlockY || ann == kBlockZ);
  if (is_block_axis) {
    return 2;
  }
  const bool is_thread_axis =
      (ann == kThreadX || ann == kThreadY || ann == kThreadZ || ann == kParallel);
  return is_thread_axis ? 1 : 0;
}
|
||||
|
||||
// get touch pattern from index expression
|
||||
class IndexParser: public IRVisitor {
|
||||
public:
|
||||
void Parse(Expr expr) {
|
||||
pattern_map.clear();
|
||||
this->Visit(expr);
|
||||
}
|
||||
|
||||
void Visit_(const Variable *op) {
|
||||
// TODO(lmzheng): handle more index types (multiple occurrence)
|
||||
if (pattern_map.count(op) == 0) {
|
||||
pattern_map[op] = TouchPattern();
|
||||
pattern_map[op].stride = next_stride_;
|
||||
next_stride_ = 1;
|
||||
}
|
||||
}
|
||||
|
||||
void Visit_(const Mul *op) {
|
||||
if (op->a.as<Variable>()) {
|
||||
if (const auto stride = op->b.as<IntImm>()) {
|
||||
next_stride_ = stride->value;
|
||||
}
|
||||
}
|
||||
IRVisitor::Visit_(op);
|
||||
}
|
||||
|
||||
std::unordered_map<const Variable*, TouchPattern> pattern_map;
|
||||
|
||||
private:
|
||||
int64_t next_stride_ = 1;
|
||||
};
|
||||
|
||||
// extract iter vars and their touch pattern from ir
|
||||
/*!
 * \brief Register an axis when the visitor enters it.
 *
 * A virtual thread that was already registered is not inserted again; only
 * the current stack depth is remembered so that the matching ExitItervar_
 * can recognize and ignore it. Any other axis is pushed on the stack and
 * gets a fresh ItervarFeature record.
 *
 * \return Always true, i.e. the subtree is always visited.
 */
bool TouchExtractor::EnterItervar_(VarExpr var, int64_t length, AnnotationType ann_type) {
  const bool already_seen = (itervar_map.count(var) != 0);

  // do not insert duplicated occurrences of virtual thread
  if (already_seen && ann_type == kVirtualThread) {
    skip_stack_size_.push_back(itervar_stack_.size());
    return true;
  }

  itervar_stack_.push_back(var);
  topdown_product_ *= length;

  if (already_seen) {
    // The same variable is bound a second time, which happens when
    // tvm.thread_axis("threadIdx.x") is created once and bound twice.
    // Treat the two bindings as separate axes: move the existing record
    // to a new variable with the same name so that it is frozen.
    VarExpr old = VarExpr(var.get()->name_hint);
    itervar_map.insert({old, itervar_map[var]});
    itervar_map.erase(var);
  }

  const int nest_level = static_cast<int>(itervar_stack_.size());
  const int order = static_cast<int>(itervar_counter_++);
  itervar_map.insert({var, ItervarFeature(var, length, nest_level, ann_type,
                                          topdown_product_, order)});
  return true;
}
|
||||
|
||||
/*!
 * \brief Finalize the features of the innermost axis when the visitor leaves it.
 *
 * In order, this
 *  1. ignores the exit of a virtual thread that EnterItervar_ did not push,
 *  2. multiplies the length of the axis into `count` (stride != 0) or `reuse`
 *     (stride == 0) of every buffer it touches, for all axes on the stack,
 *  3. pops the axis and restores the running top-down product,
 *  4. stores the largest `count * reuse` over its buffers as bottom-up product,
 *  5. maintains `thread_count` / `thread_reuse` across parallel levels.
 */
void TouchExtractor::ExitItervar_() {
  // A skipped virtual thread recorded the stack depth at entry; if the depth
  // matches, this exit belongs to that skipped entry and nothing was pushed.
  if (!skip_stack_size_.empty() && skip_stack_size_.back() == itervar_stack_.size()) {
    skip_stack_size_.pop_back();
    return;
  }
  VarExpr var = itervar_stack_.back();

  // update count and reuse ratio for upper iter vars (includes self)
  // NOTE: `kv` is a copy of the map entry, so the stride read here is not
  // affected by the updates made through `touch_pattern` below.
  for (auto kv : itervar_map[var].touch_feature) {
    if (kv.second.stride != 0) {  // multiply count
      for (auto stack_var : itervar_stack_) {
        auto touch_pattern = itervar_map[stack_var].touch_feature.find(kv.first);
        CHECK(touch_pattern != itervar_map[stack_var].touch_feature.end());
        touch_pattern->second.count *= itervar_map[var].length;
      }
    } else {                      // multiply reuse ratio
      for (auto stack_var : itervar_stack_) {
        auto touch_pattern = itervar_map[stack_var].touch_feature.find(kv.first);
        CHECK(touch_pattern != itervar_map[stack_var].touch_feature.end());
        touch_pattern->second.reuse *= itervar_map[var].length;
      }
    }
  }
  itervar_stack_.pop_back();

  // undo the multiplication done in EnterItervar_
  // NOTE(review): this divides by the axis length; a length of 0 would be a
  // division by zero -- confirm such loops cannot reach this point.
  topdown_product_ /= itervar_map[var].length;
  // stays -1 when the axis touches no buffer
  int64_t bottomup_product = -1;
  for (auto kv : itervar_map[var].touch_feature) {
    bottomup_product = std::max(bottomup_product, kv.second.count * kv.second.reuse);
  }

  itervar_map[var].bottomup_product = bottomup_product;

  // push base to upper parallel axis
  int para_level = ParallelLevel(itervar_map[var].ann);
  // if is the separate line of parallel level, push the base to upper parallel level
  if (!itervar_stack_.empty() &&
      ParallelLevel(itervar_map[itervar_stack_.back()].ann) == para_level + 1) {
    for (auto kv : itervar_map[var].touch_feature) {
      for (auto stack_var : itervar_stack_) {
        if (ParallelLevel(itervar_map[stack_var].ann) == para_level + 1) {
          auto touch_pattern = itervar_map[stack_var].touch_feature.find(kv.first);
          CHECK(touch_pattern != itervar_map[stack_var].touch_feature.end());
          touch_pattern->second.thread_reuse = -kv.second.reuse;
          touch_pattern->second.thread_count = -kv.second.count;
          // NOTE: use minus as a flag to denote it is a base,
          // indicating it is not the final value
        }
      }
    }
  }

  // A negative thread_count is a base pushed up by an inner axis (see above);
  // turn it into the final value by dividing this axis' totals by the base.
  for (auto kv : itervar_map[var].touch_feature) {
    if (kv.second.thread_count < 0) {
      itervar_map[var].touch_feature[kv.first].thread_count =
          kv.second.count / (-kv.second.thread_count);
      itervar_map[var].touch_feature[kv.first].thread_reuse =
          kv.second.reuse / (-kv.second.thread_reuse);
    }
  }
}
|
||||
|
||||
/*!
 * \brief Record a memory access for every axis that currently encloses it.
 *
 * Each access gets a unique buffer key `<name>_<n>`, where `n` counts the
 * accesses seen so far for the buffer name. Axes whose variable appears in
 * the index receive the parsed touch pattern; all other enclosing axes
 * receive a default TouchPattern.
 *
 * \param buffer_var The buffer being accessed.
 * \param index The index expression of the access.
 */
void TouchExtractor::EnterMem_(VarExpr buffer_var, Expr index) {
  // build the unique key of this access
  const std::string name = buffer_var.get()->name_hint;
  const size_t access_id = buffer_counter_[name]++;
  TouchedBuffer buf = name + "_" + std::to_string(access_id);

  // extract touch pattern from index
  IndexParser parser;
  parser.Parse(index);

  // push up mem access info
  for (auto var : itervar_stack_) {
    auto found = parser.pattern_map.find(var.get());
    const bool in_index = (found != parser.pattern_map.end());
    itervar_map[var].touch_feature[buf] = in_index ? found->second : TouchPattern();
  }
}
|
||||
|
||||
/*! \brief Exit a memory access node; intentionally a no-op, all work is done in EnterMem_. */
void TouchExtractor::ExitMem_() {
}
|
||||
|
||||
/*!
|
||||
* \brief Get axis-based feature for all axes
|
||||
* \param stmt The statement to be extracted
|
||||
* \param bool Whether take log for numerical feature
|
||||
* \param ret_feature The buffer where the return value is stored
|
||||
*
|
||||
* \note The format of return value is
|
||||
* ((
|
||||
* ('_itervar_', var),
|
||||
* ('_attr_', length, nest_level, topdown, bottomup, one_hot_annotation),
|
||||
* ('_arith_', add_ct, mul_ct, div_ct),
|
||||
* ('data_vec_0', stride, mod, count, reuse, thread_count, thread_reuse),
|
||||
* ('conv_0', stride, mod, count, reuse, thread_count, thread_reuse),
|
||||
* ),
|
||||
* (
|
||||
* ('_itervar_', var2),
|
||||
* ('_attr_', length, nest_level, one_hot_annotation),
|
||||
* ('_arith_', add_ct, mul_ct, div_ct),
|
||||
* ('kernel_vec_0', stride, mod, count, reuse, thread_count, thread_reuse),
|
||||
* ('conv_1', stride, mod, count, reuse, thread_count, thread_reuse),
|
||||
* ))
|
||||
*
|
||||
* Itervars are sorted according to their first occurrence position in IR.
|
||||
* Buffers touched by an itervar are sorted by their unique names.
|
||||
*
|
||||
* \note If you want to flatten these features as the input of your model,
|
||||
* You can use the faster one GetItervarFeatureFlatten below.
|
||||
*/
|
||||
void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *ret_feature) {
|
||||
// extract
|
||||
TouchExtractor touch_analyzer;
|
||||
touch_analyzer.Analyze(stmt);
|
||||
|
||||
// sort according to order
|
||||
std::vector<VarExpr> vars;
|
||||
for (auto kv : touch_analyzer.itervar_map) {
|
||||
vars.push_back(kv.first);
|
||||
}
|
||||
std::sort(vars.begin(), vars.end(), [&](const VarExpr &lhs, const VarExpr &rhs) -> bool {
|
||||
return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order;
|
||||
});
|
||||
|
||||
// whether take log for numerical feature
|
||||
std::function<double(int64_t)> trans;
|
||||
if (take_log) {
|
||||
trans = [](int64_t x) {
|
||||
if (x < 0)
|
||||
return -std::log(-x+1) / std::log(2);
|
||||
x = x + 1;
|
||||
return std::log(x) / std::log(2);
|
||||
};
|
||||
} else {
|
||||
trans = [](int64_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
|
||||
// serialize for front end
|
||||
for (auto var : vars) {
|
||||
Array<Array<Expr> > feature_row;
|
||||
ItervarFeature &fea = touch_analyzer.itervar_map[var];
|
||||
feature_row.push_back(Array<Expr>{std::string("_itervar_"), var});
|
||||
|
||||
Array<Expr> attr{std::string("_attr_"),
|
||||
FloatImm::make(Float(32), trans(fea.length)),
|
||||
IntImm::make(Int(32), fea.nest_level),
|
||||
FloatImm::make(Float(32), trans(fea.topdown_product)),
|
||||
FloatImm::make(Float(32), trans(fea.bottomup_product)),
|
||||
};
|
||||
// one hot annotation
|
||||
for (int i = 0; i < kNum; i++) {
|
||||
attr.push_back(i == fea.ann);
|
||||
}
|
||||
feature_row.push_back(attr);
|
||||
|
||||
// arithmetic
|
||||
feature_row.push_back(Array<Expr>{std::string("_arith_"),
|
||||
FloatImm::make(Float(32), trans(fea.add_ct)),
|
||||
FloatImm::make(Float(32), trans(fea.mul_ct)),
|
||||
FloatImm::make(Float(32), trans(fea.div_ct)),
|
||||
});
|
||||
|
||||
// touch map
|
||||
std::vector<TouchedBuffer> bufs;
|
||||
for (auto kv : fea.touch_feature) {
|
||||
bufs.push_back(kv.first);
|
||||
}
|
||||
std::sort(bufs.begin(), bufs.end());
|
||||
for (auto k : bufs) {
|
||||
TouchPattern &v = fea.touch_feature[k];
|
||||
feature_row.push_back(Array<Expr>{k,
|
||||
FloatImm::make(Float(32), trans(v.stride)),
|
||||
FloatImm::make(Float(32), trans(v.mod)),
|
||||
FloatImm::make(Float(32), trans(v.count)),
|
||||
FloatImm::make(Float(32), trans(v.reuse)),
|
||||
FloatImm::make(Float(32), trans(v.thread_count)),
|
||||
FloatImm::make(Float(32), trans(v.thread_reuse)),
|
||||
});
|
||||
}
|
||||
|
||||
ret_feature->push_back(feature_row);
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Get axis-based feature for all axes and flatten them into a one-dimensional vector.
|
||||
* \param stmt The statement to be extracted
|
||||
* \param bool Whether take log for numerical feature
|
||||
* \param ret_feature The buffer where the return value is stored
|
||||
*
|
||||
* \note See GetItervarFeature for more details about the return value.
|
||||
* This is an optimized version of GetItervarFeature + Flatten. This runs much faster.
|
||||
*/
|
||||
void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector<float> *ret_feature) {
|
||||
// extract touch feature
|
||||
TouchExtractor touch_analyzer;
|
||||
touch_analyzer.Analyze(stmt);
|
||||
|
||||
// sort according to order
|
||||
std::vector<VarExpr> vars;
|
||||
for (auto kv : touch_analyzer.itervar_map) {
|
||||
vars.push_back(kv.first);
|
||||
}
|
||||
std::sort(vars.begin(), vars.end(), [&](const VarExpr &lhs, const VarExpr &rhs) -> bool {
|
||||
return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order;
|
||||
});
|
||||
|
||||
// whether take log for numerical feature
|
||||
std::function<float(int64_t)> trans;
|
||||
if (take_log) {
|
||||
trans = [](int64_t x) {
|
||||
if (x < 0)
|
||||
return -std::log(-x+1) / std::log(2);
|
||||
x = x + 1;
|
||||
return std::log(x) / std::log(2);
|
||||
};
|
||||
} else {
|
||||
trans = [](int64_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
|
||||
// serialize for front end
|
||||
for (auto var : vars) {
|
||||
ItervarFeature &fea = touch_analyzer.itervar_map[var];
|
||||
|
||||
ret_feature->push_back(trans(fea.length));
|
||||
ret_feature->push_back(fea.nest_level);
|
||||
ret_feature->push_back(trans(fea.topdown_product));
|
||||
ret_feature->push_back(trans(fea.bottomup_product));
|
||||
|
||||
// one hot annotation
|
||||
for (int i = 0; i < kNum; i++) {
|
||||
ret_feature->push_back(i == fea.ann);
|
||||
}
|
||||
|
||||
// arithmetic
|
||||
ret_feature->push_back(trans(fea.add_ct));
|
||||
ret_feature->push_back(trans(fea.mul_ct));
|
||||
ret_feature->push_back(trans(fea.div_ct));
|
||||
|
||||
// touch map
|
||||
std::vector<TouchedBuffer> bufs;
|
||||
for (auto kv : fea.touch_feature) {
|
||||
bufs.push_back(kv.first);
|
||||
}
|
||||
std::sort(bufs.begin(), bufs.end());
|
||||
for (auto k : bufs) {
|
||||
TouchPattern &v = fea.touch_feature[k];
|
||||
ret_feature->push_back(trans(v.stride));
|
||||
ret_feature->push_back(trans(v.mod));
|
||||
ret_feature->push_back(trans(v.count));
|
||||
ret_feature->push_back(trans(v.reuse));
|
||||
ret_feature->push_back(trans(v.thread_count));
|
||||
ret_feature->push_back(trans(v.thread_reuse));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Get curve sample feature (relation feature) and flatten them into a one-dimensional vector.
|
||||
* \param stmt The statement to be extracted
|
||||
* \param sample_n The number of points used for sampling a curve (along one dimension)
|
||||
* \param ret_feature The buffer where the return value is stored
|
||||
*/
|
||||
void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector<float> *ret_feature) {
  // extract touch feature
  TouchExtractor touch_ext;
  touch_ext.Analyze(stmt);

  // sort according to order
  std::vector<VarExpr> vars;
  for (auto kv : touch_ext.itervar_map) {
    vars.push_back(kv.first);
  }
  std::sort(vars.begin(), vars.end(), [&](const VarExpr &lhs, const VarExpr &rhs) -> bool {
    return touch_ext.itervar_map[lhs].order < touch_ext.itervar_map[rhs].order;
  });

  int max_depth = 0;
  // one curve per selected buffer; every curve holds log2 values, one point per axis
  std::map<TouchedBuffer, std::vector<double> > reuse_curve;
  std::map<TouchedBuffer, std::vector<double> > count_curve;
  std::map<TouchedBuffer, std::vector<double> > topdown_curve;
  // NOTE(review): bottomup_curve is filled below but never sampled into ret_feature
  std::map<TouchedBuffer, std::vector<double> > bottomup_curve;
  std::set<TouchedBuffer> innermost_buffers;
  // raw buffer names (without access number and memory scope) already selected
  std::set<std::string> added;

  // find maximum depth of loop nest
  for (auto var : vars) {
    ItervarFeature &fea = touch_ext.itervar_map[var];
    max_depth = std::max(max_depth, fea.nest_level);
  }

  // mark inner most buffer
  for (auto iter = vars.rbegin(); iter != vars.rend(); iter++) {
    auto var = *iter;
    ItervarFeature &fea = touch_ext.itervar_map[var];
    if (fea.nest_level == max_depth) {
      // NOTE(review): touch_feature is an unordered_map, so when one axis touches
      // several buffers with the same raw name, which one is kept depends on the
      // hash iteration order -- confirm this is acceptable.
      for (auto kv : fea.touch_feature) {
        // delete buffer no (e.g. 'A_0' -> 'A', 'A_1' -> 'A')
        std::string raw_name = kv.first.substr(0, kv.first.rfind("_"));

        // delete memory scope (e.g. 'A.local' -> 'A', 'A.shared' -> 'A')
        size_t pos = raw_name.find(".");
        // `find` yields npos when there is no '.', which is never below the size
        if (pos < kv.first.size())
          raw_name = raw_name.substr(0, pos);

        // If there are multiple innermost buffers that are derived from a same raw buffer
        // We only record the last occurrence (note the `iter` is in reverse order)
        // e.g. `A.local`, `A.shared` are derived from `A`, if they all occurred at the inner most
        // level, we will only record the last occurrence,
        if (added.find(raw_name) == added.end()) {
          innermost_buffers.insert(kv.first);
          added.insert(raw_name);
        }
      }
    }
  }

  // pad the first point (zero) for all curves
  for (auto buf : innermost_buffers) {
    reuse_curve[buf].push_back(0);
    count_curve[buf].push_back(0);
    topdown_curve[buf].push_back(0);
    bottomup_curve[buf].push_back(0);
  }

  // extract curves
  for (auto var : vars) {
    ItervarFeature &fea = touch_ext.itervar_map[var];
    for (auto kv : fea.touch_feature) {
      if (innermost_buffers.find(kv.first) != innermost_buffers.end()) {
        // append log2 of each quantity as the point contributed by this axis
        reuse_curve[kv.first].emplace_back(std::log(kv.second.reuse) / std::log(2));
        count_curve[kv.first].emplace_back(std::log(kv.second.count) / std::log(2));
        topdown_curve[kv.first].emplace_back(std::log(fea.topdown_product) / std::log(2));
        bottomup_curve[kv.first].emplace_back(std::log(fea.bottomup_product) / std::log(2));
      }
    }
  }

  // sample relation in the curve
  // For each of the sample_n positions xx = i * weight, find the last point j with
  // x[j] <= xx (within 1e-6) and emit the pair (y[j], xx - x[j]). Nothing is emitted
  // for a position that lies below every x[j].
  auto sample_curve = [&](const std::vector<double> &x, const std::vector<double> &y,
                          double weight) {
    for (int i = 0; i < sample_n; i++) {
      double xx = i * weight;
      for (int j = static_cast<int>(x.size()) - 1; j >= 0; j--) {
        if (xx > x[j] - 1e-6) {
          ret_feature->emplace_back(y[j]);
          ret_feature->emplace_back(xx - x[j]);
          break;
        }
      }
    }
  };

  // serialize to frontend
  for (auto k : innermost_buffers) {
    std::vector<double> &count = count_curve[k];
    std::vector<double> &reuse = reuse_curve[k];
    std::vector<double> &top_down = topdown_curve[k];

    // each curve is sorted independently before sampling
    std::sort(count.begin(), count.end());
    std::sort(reuse.begin(), reuse.end());
    std::sort(top_down.begin(), top_down.end());

    sample_curve(count, reuse, 1);
    sample_curve(reuse, count, 1);
    sample_curve(count, top_down, 1);
    sample_curve(top_down, count, 1);
  }
}
|
||||
|
||||
|
||||
// register API for front end
|
||||
// Front-end entry for GetItervarFeature.
// args[0]: the statement to analyze, args[1]: whether to take log of numerical features.
// Returns the nested array of features.
TVM_REGISTER_API("autotvm.feature.GetItervarFeature")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  Stmt stmt = args[0];
  bool take_log = args[1];
  Array<Array<Array<Expr > > > ret_feature;

  GetItervarFeature(stmt, take_log, &ret_feature);

  *ret = ret_feature;
});
|
||||
|
||||
|
||||
// Front-end entry for GetItervarFeatureFlatten.
// args[0]: the statement to analyze, args[1]: whether to take log of numerical features.
// Returns the flattened features as a byte array holding packed `float` values.
TVM_REGISTER_API("autotvm.feature.GetItervarFeatureFlatten")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  Stmt stmt = args[0];
  bool take_log = args[1];
  std::vector<float> ret_feature;

  GetItervarFeatureFlatten(stmt, take_log, &ret_feature);

  // expose the float vector as raw bytes
  // NOTE(review): `arr` points into the local vector; presumably the assignment
  // to *ret copies the bytes -- confirm against TVMRetValue.
  TVMByteArray arr;
  arr.size = sizeof(float) * ret_feature.size();
  arr.data = reinterpret_cast<char *>(ret_feature.data());
  *ret = arr;
});
|
||||
|
||||
|
||||
// Front-end entry for GetCurveSampleFeatureFlatten.
// args[0]: the statement to analyze,
// args[1]: the number of points used for sampling a curve (along one dimension).
// Returns the flattened features as a byte array holding packed `float` values.
TVM_REGISTER_API("autotvm.feature.GetCurveSampleFeatureFlatten")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  Stmt stmt = args[0];
  // The second argument is a sample count, not a flag. Reading it as `bool`
  // would collapse every value to 0 or 1 before it reaches the extractor.
  int sample_n = args[1];
  std::vector<float> ret_feature;

  GetCurveSampleFeatureFlatten(stmt, sample_n, &ret_feature);

  // expose the float vector as raw bytes
  TVMByteArray arr;
  arr.size = sizeof(float) * ret_feature.size();
  arr.data = reinterpret_cast<char *>(ret_feature.data());
  *ret = arr;
});
|
||||
|
||||
|
||||
} // namespace autotvm
|
||||
} // namespace tvm
|
|
@ -0,0 +1,124 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file touch_extractor.h
|
||||
* \brief Extract feature of touch pattern of axes in lowered IR
|
||||
*/
|
||||
|
||||
#ifndef TVM_AUTOTVM_TOUCH_EXTRACTOR_H_
|
||||
#define TVM_AUTOTVM_TOUCH_EXTRACTOR_H_
|
||||
|
||||
#include <tvm/ir.h>
|
||||
#include <tvm/ir_visitor.h>
|
||||
#include <tvm/api_registry.h>
|
||||
#include <stack>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <deque>
|
||||
#include "feature_visitor.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace autotvm {
|
||||
|
||||
using TouchedBuffer = std::string;
|
||||
|
||||
// touch pattern buf[((stride * var) % mod) + other]
// Describes how one axis touches one buffer access.
struct TouchPattern {
  int64_t stride{0};  // 0 when the axis variable does not appear in the index
  int64_t mod{-1};  // -1 for +inf

  int64_t count{1};  // scaled by axis lengths when the stride is non-zero
  int64_t reuse{1};  // scaled by axis lengths when the stride is zero
  int64_t thread_count{0};  // count when move thread axis into innermost
  int64_t thread_reuse{0};  // reuse ratio move thread axis into innermost
};
|
||||
|
||||
// all the feature of an iter var
|
||||
struct ItervarFeature {
|
||||
ItervarFeature(VarExpr var,
|
||||
int64_t extent,
|
||||
int nest,
|
||||
AnnotationType ann_type,
|
||||
int64_t topdown,
|
||||
int counter)
|
||||
: length(extent), nest_level(nest), ann(ann_type), topdown_product(topdown), order(counter) {}
|
||||
ItervarFeature() {}
|
||||
|
||||
// Axis Attributes
|
||||
int64_t length;
|
||||
int nest_level;
|
||||
AnnotationType ann; // one-hot axis type
|
||||
int64_t topdown_product; // accumulative product of axis length, in top-down order
|
||||
int64_t bottomup_product; // accumulative product of axis length, in bottom-up order
|
||||
// bottomup_product = reuse * count for any touched buffer
|
||||
|
||||
int order; // used for soring axis
|
||||
|
||||
// Arithmetic feature
|
||||
int add_ct{0};
|
||||
int mul_ct{0};
|
||||
int div_ct{0};
|
||||
|
||||
// Memory Touch Feature
|
||||
std::unordered_map<TouchedBuffer, TouchPattern> touch_feature;
|
||||
};
|
||||
|
||||
// extract iter vars and their touch pattern from ir
|
||||
class TouchExtractor : public FeatureVisitor {
|
||||
public:
|
||||
void Analyze(Stmt stmt) {
|
||||
this->Visit(stmt);
|
||||
}
|
||||
|
||||
// arithmetic stats
|
||||
void Visit_(const Add *op) {
|
||||
if (op->type.is_float())
|
||||
itervar_map[itervar_stack_.back()].add_ct++;
|
||||
IRVisitor::Visit_(op);
|
||||
}
|
||||
|
||||
void Visit_(const Sub *op) {
|
||||
if (op->type.is_float())
|
||||
itervar_map[itervar_stack_.back()].add_ct++;
|
||||
IRVisitor::Visit_(op);
|
||||
}
|
||||
|
||||
void Visit_(const Mul *op) {
|
||||
if (op->type.is_float())
|
||||
itervar_map[itervar_stack_.back()].mul_ct++;
|
||||
IRVisitor::Visit_(op);
|
||||
}
|
||||
|
||||
void Visit_(const Div *op) {
|
||||
if (op->type.is_float())
|
||||
itervar_map[itervar_stack_.back()].div_ct++;
|
||||
IRVisitor::Visit_(op);
|
||||
}
|
||||
|
||||
void Visit_(const Mod *op) {
|
||||
if (op->type.is_float())
|
||||
itervar_map[itervar_stack_.back()].div_ct++;
|
||||
IRVisitor::Visit_(op);
|
||||
}
|
||||
|
||||
std::unordered_map<VarExpr, ItervarFeature, tvm::ExprHash, tvm::ExprEqual> itervar_map;
|
||||
|
||||
private:
|
||||
bool EnterItervar_(VarExpr var, int64_t length, AnnotationType ann_type);
|
||||
void ExitItervar_();
|
||||
void EnterMem_(VarExpr buffer_var, Expr index);
|
||||
void ExitMem_();
|
||||
|
||||
int64_t topdown_product_{1};
|
||||
std::map<std::string, size_t> buffer_counter_;
|
||||
size_t itervar_counter_{0};
|
||||
std::deque<VarExpr> itervar_stack_; // use deque instead of stack for indexing
|
||||
std::deque<size_t> skip_stack_size_;
|
||||
|
||||
using IRVisitor::Visit_;
|
||||
};
|
||||
|
||||
} // namespace autotvm
|
||||
} // namespace tvm
|
||||
|
||||
#endif // TVM_AUTOTVM_TOUCH_EXTRACTOR_H_
|
|
@ -73,7 +73,7 @@ Target CreateTarget(const std::string& target_name,
|
|||
} else {
|
||||
t->device_type = kDLROCM;
|
||||
}
|
||||
t->keys_array.push_back(ir::StringImm::make("rocm"));
|
||||
t->keys_array.push_back(ir::StringImm::make(target_name));
|
||||
t->keys_array.push_back(ir::StringImm::make("gpu"));
|
||||
t->max_num_threads = 256;
|
||||
if (t->device_name == "intel_graphics") {
|
||||
|
@ -195,11 +195,7 @@ Target Target::create(const std::string& target_str) {
|
|||
options.push_back(item);
|
||||
}
|
||||
|
||||
if (device_name == "rasp") {
|
||||
return target::rasp(options);
|
||||
} else {
|
||||
return CreateTarget(target_name, options);
|
||||
}
|
||||
return CreateTarget(target_name, options);
|
||||
}
|
||||
|
||||
/*! \brief Entry to hold the Target context stack. */
|
||||
|
|
|
@ -18,13 +18,13 @@ class GPUCodeVerifier : public IRVisitor {
|
|||
bool Verify(tvm::Stmt stmt,
|
||||
int64_t max_local_memory_per_block,
|
||||
int64_t max_shared_memory_per_block,
|
||||
int64_t max_thread_per_block,
|
||||
int64_t max_threads_per_block,
|
||||
int64_t max_thread_x,
|
||||
int64_t max_thread_y,
|
||||
int64_t max_thread_z) {
|
||||
max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
|
||||
max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
|
||||
max_thread_per_block_ = static_cast<size_t>(max_thread_per_block);
|
||||
max_threads_per_block_ = static_cast<size_t>(max_threads_per_block);
|
||||
max_thread_x_ = static_cast<size_t>(max_thread_x);
|
||||
max_thread_y_ = static_cast<size_t>(max_thread_y);
|
||||
max_thread_z_ = static_cast<size_t>(max_thread_z);
|
||||
|
@ -52,7 +52,7 @@ class GPUCodeVerifier : public IRVisitor {
|
|||
|
||||
if (nest_level_ == 0) {
|
||||
// exit a kernel, check the validity
|
||||
valid_ &= thread_per_block_ <= max_thread_per_block_;
|
||||
valid_ &= thread_per_block_ <= max_threads_per_block_;
|
||||
|
||||
valid_ &= local_memory_per_block_ <= max_local_memory_per_block_;
|
||||
valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_;
|
||||
|
@ -117,7 +117,7 @@ class GPUCodeVerifier : public IRVisitor {
|
|||
|
||||
size_t max_local_memory_per_block_;
|
||||
size_t max_shared_memory_per_block_;
|
||||
size_t max_thread_per_block_;
|
||||
size_t max_threads_per_block_;
|
||||
size_t max_thread_x_, max_thread_y_, max_thread_z_;
|
||||
|
||||
bool valid_{true};
|
||||
|
@ -137,26 +137,34 @@ bool VerifyGPUCode(Stmt stmt,
|
|||
Map<std::string, Expr> constraints) {
|
||||
GPUCodeVerifier verifier;
|
||||
|
||||
auto get_int = [&constraints](std::string key, int64_t def) {
|
||||
auto iter = constraints.find(key);
|
||||
if (iter != constraints.end()) {
|
||||
return ((*iter).second).as<IntImm>()->value;
|
||||
} else {
|
||||
return def;
|
||||
}
|
||||
};
|
||||
int64_t max_local_memory_per_block = INT64_MAX;
|
||||
int64_t max_shared_memory_per_block = INT64_MAX;
|
||||
int64_t max_threads_per_block = INT64_MAX;
|
||||
int64_t max_thread_x = INT64_MAX;
|
||||
int64_t max_thread_y = INT64_MAX;
|
||||
int64_t max_thread_z = INT64_MAX;
|
||||
|
||||
int64_t max_local_memory_per_block = get_int("max_local_memory_per_block", INT64_MAX);
|
||||
int64_t max_shared_memory_per_block = get_int("max_shared_memory_per_block", INT64_MAX);
|
||||
int64_t max_thread_per_block = get_int("max_thread_per_block", INT64_MAX);
|
||||
int64_t max_thread_x = get_int("max_thread_x", INT64_MAX);
|
||||
int64_t max_thread_y = get_int("max_thread_y", INT64_MAX);
|
||||
int64_t max_thread_z = get_int("max_thread_z", INT64_MAX);
|
||||
for (auto iter : constraints) {
|
||||
if (iter.first == "max_local_memory_per_block")
|
||||
max_local_memory_per_block = (iter.second).as<IntImm>()->value;
|
||||
else if (iter.first == "max_shared_memory_per_block")
|
||||
max_shared_memory_per_block = (iter.second).as<IntImm>()->value;
|
||||
else if (iter.first == "max_threads_per_block")
|
||||
max_threads_per_block = (iter.second).as<IntImm>()->value;
|
||||
else if (iter.first == "max_thread_x")
|
||||
max_thread_x = (iter.second).as<IntImm>()->value;
|
||||
else if (iter.first == "max_thread_y")
|
||||
max_thread_y = (iter.second).as<IntImm>()->value;
|
||||
else if (iter.first == "max_thread_z")
|
||||
max_thread_z = (iter.second).as<IntImm>()->value;
|
||||
else
|
||||
LOG(FATAL) << "Invalid check item: " << iter.first;
|
||||
}
|
||||
|
||||
return verifier.Verify(stmt,
|
||||
max_local_memory_per_block,
|
||||
max_shared_memory_per_block,
|
||||
max_thread_per_block,
|
||||
max_threads_per_block,
|
||||
max_thread_x,
|
||||
max_thread_y,
|
||||
max_thread_z);
|
||||
|
|
|
@ -0,0 +1,158 @@
|
|||
"""
|
||||
Test the tuner
|
||||
"""
|
||||
import logging
|
||||
import time
|
||||
|
||||
import tvm
|
||||
|
||||
from tvm import autotvm
|
||||
from tvm.autotvm.tuner import RandomTuner
|
||||
|
||||
@autotvm.template
def conv2d_no_batching(N, H, W, CI, CO, KH, KW):
    """An example tunable conv2d template for testing.

    Declares a direct NCHW convolution (output size ``H - KH + 1`` by
    ``W - KW + 1``) and a schedule whose tiling and unrolling choices are
    read from the autotvm config.

    Parameters
    ----------
    N : int
        Batch size; must be 1.
    H, W : int
        Input height and width.
    CI, CO : int
        Number of input and output channels.
    KH, KW : int
        Kernel height and width.

    Returns
    -------
    s, [data, kernel, conv]
        The schedule and the tensors of the computation.
    """
    assert N == 1, "Only consider batch_size = 1 in this template"

    data = tvm.placeholder((N, CI, H, W), name='data')
    kernel = tvm.placeholder((CO, CI, KH, KW), name='kernel')

    rc = tvm.reduce_axis((0, CI), name='rc')
    ry = tvm.reduce_axis((0, KH), name='ry')
    rx = tvm.reduce_axis((0, KW), name='rx')

    conv = tvm.compute(
        (N, CO, H - KH + 1, W - KW + 1),
        lambda nn, ff, yy, xx: tvm.sum(
            data[nn, rc, yy + ry, xx + rx] * kernel[ff, rc, ry, rx],
            axis=[rc, ry, rx]), tag="conv2d_nchw")

    s = tvm.create_schedule([conv.op])

    output = conv
    OL = s.cache_write(conv, 'local')

    # create cache stage
    AA = s.cache_read(data, 'shared', [OL])
    WW = s.cache_read(kernel, 'shared', [OL])
    AL = s.cache_read(AA, 'local', [OL])
    WL = s.cache_read(WW, 'local', [OL])

    # tile and bind spatial axes
    n, f, y, x = s[output].op.axis
    cfg = autotvm.get_config()
    cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
    cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
    cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
    kernel_scope = n  # this is the scope to attach global config inside this kernel

    s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
    s[output].bind(by, tvm.thread_axis("blockIdx.y"))
    s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[output].bind(vf, tvm.thread_axis("vthread"))
    s[output].bind(vy, tvm.thread_axis("vthread"))
    s[output].bind(vx, tvm.thread_axis("vthread"))
    s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
    s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
    s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
    s[OL].compute_at(s[output], tx)

    # tile and bind reduction axes
    n, f, y, x = s[OL].op.axis
    rc, ry, rx = s[OL].op.reduce_axis
    cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
    cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=3)
    cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=3)
    # Each split is applied to the axis it was defined on; crossing
    # tile_ry/tile_rx is only harmless when KH == KW.
    rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc)
    ryo, rym, ryi = cfg['tile_ry'].apply(s, OL, ry)
    rxo, rxm, rxi = cfg['tile_rx'].apply(s, OL, rx)
    s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x)

    s[AA].compute_at(s[OL], rxo)
    s[WW].compute_at(s[OL], rxo)
    s[AL].compute_at(s[OL], rxm)
    s[WL].compute_at(s[OL], rxm)

    # cooperative fetching
    for load in [AA, WW]:
        n, f, y, x = s[load].op.axis
        fused = s[load].fuse(n, f, y, x)
        # split the fused axis by the thread-level factors chosen for the output
        tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
        ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
        tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
        s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
        s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
        s[load].bind(tx, tvm.thread_axis("threadIdx.x"))

    # tune unroll
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
    cfg.define_knob("unroll_explicit", [0, 1])
    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
    s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)

    return s, [data, kernel, conv]
|
||||
|
||||
def get_sample_task(target=tvm.target.cuda(), target_host=None):
    """Create the sample conv2d tuning task; returns ``(task, target)``."""
    conv_args = (1, 7, 7, 512, 512, 3, 3)
    sample = autotvm.task.create(conv2d_no_batching, args=conv_args,
                                 target=target, target_host=target_host)
    return sample, target
|
||||
|
||||
|
||||
def test_task_tuner_without_measurement():
    """Drive the tuners with a stand-in measurement function."""
    task, _ = get_sample_task()

    def measure_batch(inputs):
        from tvm.autotvm import MeasureResult

        def fake_measure():
            start = time.time()
            # stand-in for a real measurement: just wait briefly
            time.sleep(0.001)
            return MeasureResult([time.time() - start], 0,
                                 time.time() - start, time.time())

        return [fake_measure() for _ in inputs]

    option = autotvm.measure_option(mode='custom',
                                    custom_measure_batch=measure_batch)

    logging.info("%s", task.config_space)

    # a fresh tuner of each class runs on the same task
    for make_tuner in (autotvm.tuner.RandomTuner, autotvm.tuner.GridSearchTuner):
        make_tuner(task).tune(n_trial=10, measure_option=option)
|
||||
|
||||
def test_tuning_with_measure():
    """Tune the sample task with local measurement on each device that exists."""
    def check(target, target_host):
        if not tvm.context(target, 0).exist:
            logging.info("Skip test because %s is not available" % target)
            return

        # init task
        task, _ = get_sample_task(target, target_host)
        logging.info("%s", task.config_space)

        option = autotvm.measure_option(mode='local', timeout=4, number=2)
        RandomTuner(task).tune(n_trial=10, measure_option=option)

    for device in ("cuda", "opencl"):
        check(device, None)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# only print log when invoked from main
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
test_task_tuner_without_measurement()
|
||||
test_tuning_with_measure()
|
|
@ -0,0 +1,50 @@
|
|||
"""Common utilities for testing autotvm"""
|
||||
import time
|
||||
|
||||
import tvm
|
||||
from tvm import autotvm
|
||||
from tvm.autotvm import MeasureInput, MeasureResult
|
||||
|
||||
@autotvm.template
def matmul(N, L, M, dtype):
    """Tunable matmul template: C[i, j] = sum over k of A[i, k] * B[k, j]."""
    lhs = tvm.placeholder((N, L), name='A', dtype=dtype)
    rhs = tvm.placeholder((L, M), name='B', dtype=dtype)

    red = tvm.reduce_axis((0, L), name='k')
    out = tvm.compute((N, M),
                      lambda i, j: tvm.sum(lhs[i, red] * rhs[red, j], axis=red),
                      name='C')
    sched = tvm.create_schedule(out.op)

    # axes as seen by the schedule
    row, col = sched[out].op.axis
    red = sched[out].op.reduce_axis[0]

    # search space: one two-way split per spatial axis
    cfg = autotvm.get_config()
    cfg.define_split("tile_y", row, num_outputs=2)
    cfg.define_split("tile_x", col, num_outputs=2)

    # apply whatever the config selected
    row_o, row_i = cfg["tile_y"].apply(sched, out, row)
    col_o, col_i = cfg["tile_x"].apply(sched, out, col)
    sched[out].reorder(row_o, col_o, red, row_i, col_i)

    return sched, [lhs, rhs, out]
|
||||
|
||||
def get_sample_task(n=128):
    """Create an (n, n, n) float32 matmul task on llvm; returns ``(task, target)``."""
    llvm_target = tvm.target.create("llvm")
    shape_args = (n, n, n, 'float32')
    sample = autotvm.task.create(matmul, args=shape_args, target=llvm_target)
    return sample, llvm_target
|
||||
|
||||
def get_sample_records(n):
    """Return ``n`` (MeasureInput, MeasureResult) pairs built from the sample task."""
    task, target = get_sample_task()
    return [(MeasureInput(target, task, task.config_space.get(idx)),
             MeasureResult((idx + 1,), 0, idx, time.time()))
            for idx in range(n)]
|
||||
|
|
@ -0,0 +1,234 @@
|
|||
"""Test database"""
|
||||
import copy
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import tvm
|
||||
|
||||
from tvm import autotvm
|
||||
from tvm.autotvm import database
|
||||
from tvm.autotvm.measure.measure_methods import HashMismatchError
|
||||
from tvm.autotvm.record import encode, MeasureInput, MeasureResult
|
||||
|
||||
from test_autotvm_common import get_sample_task, get_sample_records
|
||||
|
||||
def test_save_load():
    """Saved records load back equal; an input never saved loads as None."""
    logging.info("test basic db load/save ...")
    (inp1, res1), (inp2, res2), (inp3, _) = get_sample_records(3)

    _db = database.DummyDatabase()
    _db.flush()
    for inp, res in ((inp1, res1), (inp2, res2)):
        _db.save(inp, res)

    load1, load2, load3 = [_db.load(inp) for inp in (inp1, inp2, inp3)]
    assert load1 == res1
    assert load2 == res2
    assert load3 is None
    assert load1 != load2
|
||||
|
||||
TRIAL_LIMIT = 2
|
||||
|
||||
def test_db_filter():
    """Check that a replay database answers only for the inputs saved in it.

    A set of inputs is measured once without a database. Then, for every
    prefix length ``i`` of that set, the first ``i`` records are saved and all
    batches are measured again with ``replay_db``: saved inputs must encode
    identically to the stored result, the others must not.
    """
    logging.info("test db filter ...")

    # Sample task/target provided by test_autotvm_common.
    task, target = get_sample_task()

    ctx = tvm.context(str(target))
    if not ctx.exist:
        logging.warning("Skip this test because there is no supported device for test")
        # Actually skip: nothing below can run without the device.
        return

    batch_size = 2

    measure_option = autotvm.measure_option(mode='local-nofork', timeout=2)
    measure_batch = autotvm.measure.create_measure_batch(task, measure_option)

    all_inputs = list()
    all_results = list()
    batches = list()
    tuner = autotvm.tuner.RandomTuner(task)
    for _ in range(TRIAL_LIMIT):
        inputs = list()
        for _ in range(batch_size):
            cfg = tuner.next_batch(1)[0]
            inputs.append((MeasureInput(target, task, cfg)))
            all_inputs.append(inputs[-1])
        batches.append(inputs)
        results = measure_batch(inputs)
        all_results += results

    del measure_batch

    db = database.DummyDatabase()
    db.flush()

    # First setting, memoize one input at a time, check that each is saved and replayed
    measure_option = autotvm.measure_option(mode='local-nofork', timeout=2, replay_db=db)
    measure_batch = autotvm.measure.create_measure_batch(task, measure_option)

    for i in range(len(all_inputs)+1):
        # database holds exactly the first i records for this round
        db.flush()
        for j in range(i):
            db.save(all_inputs[j], all_results[j])

        for k, batch in enumerate(batches):
            batch_result = measure_batch(batch)
            for l in range(batch_size):
                all_idx = k*batch_size + l
                assert batch_result[l] is not None
                if all_idx < i:
                    assert encode(batch[l], batch_result[l]) == encode(batch[l], all_results[all_idx]), \
                        "(no retry) EXPECTED MATCH, GOT MISMATCH"
                else:
                    assert encode(batch[l], batch_result[l]) != encode(batch[l], all_results[all_idx]), \
                        "(no retry) EXPECTED MISMATCH, GOT MATCH"

    del measure_batch
|
||||
|
||||
def test_db_hash():
    """Records that differ only in ``code_hash`` are stored and loaded separately."""
    logging.info("test db hash check ...")
    inp1, res1 = get_sample_records(1)[0]
    inp2 = copy.deepcopy(inp1)
    inp1.config.code_hash = 'cafecafe'
    inp2.config.code_hash = 'dbffdbff'

    # same result fields, but with the timestamp (last field) forced to -1
    res2 = MeasureResult(*(list(tuple(res1))[:-1] + [-1]))

    _db = database.DummyDatabase()
    _db.flush()
    for inp, res in ((inp1, res1), (inp2, res2)):
        _db.save(inp, res, extend=True)

    load1 = _db.load(inp1)
    load2 = _db.load(inp2)
    assert load1 != load2
    assert load1.timestamp != -1
    assert load2.timestamp == -1
|
||||
|
||||
def test_db_latest_all():
    """A plain load returns the last saved result; ``get_all`` returns all of them."""
    logging.info("test db load w/ multiple results ...")
    inp1, base = get_sample_records(1)[0]

    # three copies of the same result that differ only in timestamp (last field)
    head = list(tuple(base))[:-1]
    stamps = (0.0, 1.1, 9999.9999)
    res1, res2, res3 = [MeasureResult(*(head + [stamp])) for stamp in stamps]

    _db = database.DummyDatabase()
    _db.flush()
    for res, stamp in zip((res1, res2, res3), stamps):
        _db.save(inp1, res, extend=True)
        assert _db.load(inp1).timestamp == stamp

    load4 = _db.load(inp1, get_all=True)
    for idx, res in enumerate((res1, res2, res3)):
        assert encode(inp1, load4[idx]) == encode(inp1, res)
|
||||
|
||||
def test_db_save_replay():
    """Check that results saved by measure_batch are replayed on later calls.

    Measures ``TRIAL_LIMIT`` batches with ``save_to_replay_db=True``, verifies
    one database key per input, then measures the same inputs twice more and
    requires all three encodings to match.
    """
    logging.info("test db save (from measure_batch) and replay ...")
    _db = database.DummyDatabase()
    _db.flush()

    task, target = get_sample_task()

    ctx = tvm.context(str(target))
    if not ctx.exist:
        logging.warning("Skip this test because there is no supported device for test")
        # Actually skip: nothing below can run without the device.
        return

    measure_option = autotvm.measure_option(mode='local-nofork',
                                            timeout=2,
                                            replay_db=_db, save_to_replay_db=True)
    measure_batch = autotvm.measure.create_measure_batch(task, measure_option)

    batch_size = 2

    all_inputs = list()
    all_results = list()
    tuner = autotvm.tuner.RandomTuner(task)
    for _ in range(TRIAL_LIMIT):
        inputs = list()
        for _ in range(batch_size):
            cfg = tuner.next_batch(1)[0]
            inputs.append((MeasureInput(target, task, cfg)))
            all_inputs.append(inputs[-1])
        results = measure_batch(inputs)
        all_results += results

    assert len(_db.db.keys()) == batch_size * TRIAL_LIMIT, \
        "%d vs %d" % (len(_db.db.keys()), batch_size * TRIAL_LIMIT)

    all_results_2 = measure_batch(all_inputs)
    all_results_3 = measure_batch(all_inputs)

    for i in range(len(all_results)):
        encr1 = encode(all_inputs[i], all_results[i])
        encr2 = encode(all_inputs[i], all_results_2[i])
        encr3 = encode(all_inputs[i], all_results_3[i])
        assert encr1 == encr2, "EXPECTED MATCH WITH SAVE REPLAY (first replay), got MISMATCH"
        assert encr2 == encr3, "EXPECTED MATCH WITH SAVE REPLAY (second replay), got MISMATCH"

    del measure_batch
|
||||
|
||||
def test_check_hashmismatch():
    """Measuring a config whose ``code_hash`` is bogus must raise HashMismatchError."""
    logging.info("test hash mismatch check")

    task, target = get_sample_task()

    ctx = tvm.context(str(target))
    if not ctx.exist:
        logging.warning("Skip this test because there is no supported device for test")
        # Actually skip: nothing below can run without the device.
        return

    measure_option = autotvm.measure_option(mode='local-nofork')
    measure_batch = autotvm.measure.create_measure_batch(task, measure_option)

    inputs = list()
    cfg = task.config_space.get(np.random.randint(len(task.config_space)))
    # notvalidh is not a valid CRC32 hash (not hex)
    cfg.code_hash = 'notvalidh'
    inputs.append((MeasureInput(target, task, cfg)))

    try:
        measure_batch(inputs)
        assert False, "HashMismatchError should be raised"
    except HashMismatchError:
        pass

    del measure_batch
|
||||
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
test_save_load()
|
||||
test_db_filter()
|
||||
test_db_hash()
|
||||
test_db_latest_all()
|
||||
test_db_save_replay()
|
||||
test_check_hashmismatch()
|
|
@ -0,0 +1,36 @@
|
|||
"""Test dispatcher.
|
||||
The dispatcher can choose which template to use according
|
||||
to the parameters of workload"""
|
||||
|
||||
from collections import namedtuple
|
||||
from tvm.autotvm.task import dispatcher, DispatchContext
|
||||
|
||||
SimpleWorkload = namedtuple("SimpleWorkload", ["key"])
|
||||
SimpleConfig = namedtuple("SimpleConfig", ["template_key"])
|
||||
|
||||
def test_dispatch():
    """Calls go to the implementation named by the active context's config."""
    @dispatcher
    def my_dispatcher(a, b):
        return SimpleWorkload(key=a + b)

    @my_dispatcher.register("spatial_pack")
    def _sp_pack_add(cfg, a, b):
        return b + 100

    @my_dispatcher.register("im2col")
    def _im2col_add(cfg, a, b):
        return a + 1

    class SimpleDispatcher(DispatchContext):
        def query(self, target, workload):
            if workload.key > 2:
                return SimpleConfig("spatial_pack")
            return SimpleConfig("im2col")

    with SimpleDispatcher():
        # key 1 -> im2col -> a + 1
        assert my_dispatcher(1, 0) == 2
        # key 101 -> spatial_pack -> b + 100
        assert my_dispatcher(1, 100) == 200
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_dispatch()
|
|
@ -0,0 +1,47 @@
|
|||
"""Test local executor"""
|
||||
import time
|
||||
|
||||
from tvm.autotvm.measure import LocalExecutor, executor
|
||||
|
||||
def slow(n):
    """Sum the integers 0..n one term at a time."""
    total = 0
    for term in range(n + 1):
        total = total + term
    return total
|
||||
|
||||
def fast(n):
    """Closed form n*(n+1)//2."""
    product = n * (n + 1)
    return product // 2
|
||||
|
||||
def test_local_measure_async():
    """The job submitted second (fast) must be observed done before the slow one."""
    pool = LocalExecutor()
    slow_job = pool.submit(slow, 9999999)
    fast_job = pool.submit(fast, 9999999)

    # 0 means "completion not observed yet"
    slow_done = 0
    fast_done = 0
    while slow_done == 0 or fast_done == 0:
        if slow_done == 0 and slow_job.done():
            slow_done = time.time()
        if fast_done == 0 and fast_job.done():
            fast_done = time.time()

    assert fast_done < slow_done, "Expected fast async job to finish first!"
    assert slow_job.get() == fast_job.get()
|
||||
|
||||
def timeout_job(n):
    """Sleep for one and a half times ``n`` seconds."""
    duration = n * 1.5
    time.sleep(duration)
|
||||
|
||||
def test_timeout():
    """A job that sleeps longer than the executor timeout yields a TimeoutError."""
    timeout = 0.5
    pool = LocalExecutor(timeout=timeout)

    job = pool.submit(timeout_job, timeout)
    # spin until the executor reports the job as finished
    while True:
        if job.done():
            break

    outcome = job.get()
    assert isinstance(outcome, executor.TimeoutError)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_local_measure_async()
|
||||
test_timeout()
|
|
@ -0,0 +1,99 @@
|
|||
"""Test feature extraction"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import tvm
|
||||
from tvm.autotvm import feature
|
||||
|
||||
def test_iter_feature_gemm():
    """Compare itervar features of an unscheduled 128x128 gemm with fixed values."""
    N = 128

    k = tvm.reduce_axis((0, N), 'k')
    A = tvm.placeholder((N, N), name='A')
    B = tvm.placeholder((N, N), name='B')
    C = tvm.compute(A.shape,
                    lambda y, x: tvm.sum(A[y, k] * B[k, x], axis=k),
                    name='C')

    s = tvm.create_schedule(C.op)

    feas = feature.get_itervar_feature(s, [A, B, C], take_log=False)

    # one dict per feature row, keyed by the first element of each entry
    expected = [
        {
            '_attr_': [128, 1, 128, 2097152, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
            'A_0': [128, -1, 16384, 128, 0, 0], 'B_0': [0, -1, 16384, 128, 0, 0],
            'C_0': [128, -1, 16384, 128, 0, 0], 'C_1': [128, -1, 16384, 128, 0, 0],
        },
        {
            '_attr_': [128, 2, 16384, 16384, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
            'A_0': [0, -1, 128, 128, 0, 0], 'B_0': [1, -1, 16384, 1, 0, 0],
            'C_0': [1, -1, 128, 128, 0, 0], 'C_1': [1, -1, 128, 128, 0, 0],
        },
        {
            '_attr_': [128, 3, 2097152, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
            'A_0': [1, -1, 128, 1, 0, 0], 'B_0': [128, -1, 128, 1, 0, 0],
            'C_1': [0, -1, 1, 128, 0, 0], 'C_2': [0, -1, 1, 128, 0, 0],
        }
    ]

    for ans, row in zip(expected, feas):
        for pair in row:
            key = pair[0]
            # entries without an expectation are not checked
            if key in ans:
                assert ans[key] == pair[1:], "%s: %s vs %s" % (key, ans[key], pair[1:])
|
||||
|
||||
|
||||
def test_feature_shape():
    """Flattened feature length must not depend on the (random) loop order."""
    N = 1024
    n_sample = 100

    def get_gemm_feature(target):
        k = tvm.reduce_axis((0, N), 'k')
        A = tvm.placeholder((N, N), name='A')
        B = tvm.placeholder((N, N), name='B')
        C = tvm.compute(A.shape, lambda y, x: tvm.sum(A[y, k] * B[k, x], axis=k),
                        name='C')

        s = tvm.create_schedule(C.op)

        y, x = s[C].op.axis
        tiled = list(s[C].tile(y, x, 8, 8)) + [k]
        perm = np.random.permutation(5)
        axes = [tiled[idx] for idx in perm]
        s[C].reorder(*axes)

        if "gpu" in target.keys:
            # position 4 of the unpermuted list is the reduction axis; leave it out
            pick = [axis for axis, src in zip(axes, perm) if src != 4]
            thread_tags = ("blockIdx.x", "vthread", "threadIdx.y")
            for axis, tag in zip(pick, thread_tags):
                s[C].bind(axis, tvm.thread_axis(tag))

        with target:
            raw = feature.get_itervar_feature(s, [A, B, C])
            flat = feature.flatten_itervar_feature(raw)
        return flat

    targets = [
        tvm.target.cuda(),
        tvm.target.mali(),
        tvm.target.rasp(),
    ]

    for target in targets:
        dim = len(get_gemm_feature(target))
        for _ in range(n_sample):
            assert dim == len(get_gemm_feature(target)), "dimensions of feature do not match" \
                                                         " for different configurations"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_iter_feature_gemm()
|
||||
test_feature_shape()
|
|
@ -0,0 +1,77 @@
|
|||
"""Test flop calculation"""
|
||||
|
||||
import tvm
|
||||
import numpy as np
|
||||
|
||||
from tvm.autotvm.task.task import compute_flop
|
||||
|
||||
def test_conv():
    """Flop count of a conv-shaped reduction over randomly drawn sizes."""
    for _ in range(5):
        dims = [np.random.randint(10, 32) for _ in range(7)]
        N, H, W, CO, CI, KH, KW = dims
        D = tvm.placeholder((N, CI, H, W))
        K = tvm.placeholder((CO, CI, KH, KW))

        # clamp the kernel extents used by the reduction to the input size
        KH = min(H, KH)
        KW = min(W, KW)

        ci = tvm.reduce_axis((0, CI))
        kh = tvm.reduce_axis((0, KH))
        kw = tvm.reduce_axis((0, KW))

        OH = (H - KH) + 1
        OW = (W - KW) + 1

        C = tvm.compute((N, CO, OH, OW), lambda n, co, h, w:
                        tvm.sum(D[n][ci][h][w] * K[co][ci][h][w], axis=[ci, kh, kw]))

        s = tvm.create_schedule([C.op])

        expected_flop = 2 * N * CO * OH * OW * CI * KH * KW
        assert compute_flop(s) == expected_flop
|
||||
|
||||
def test_pack_gemm():
    """Flop count of a gemm computed through packed intermediate tensors."""
    for _ in range(5):
        N, L, M = [np.random.randint(10, 128) * 4 for _ in range(3)]
        A = tvm.placeholder((N, L))
        B = tvm.placeholder((M, L))
        k = tvm.reduce_axis((0, L))

        bn = 4
        A_pack = tvm.compute((N // bn, L, bn), lambda i, j, k: A[i * bn + k][j])
        B_pack = tvm.compute((M // bn, L, bn), lambda i, j, k: B[i * bn + k][j])
        C_pack = tvm.compute((N // bn, M // bn, bn, bn), lambda i, j, ii, jj:
                             tvm.sum(A_pack[i, k, ii] * B_pack[j, k, jj], axis=[k]))
        C = tvm.compute((N, M), lambda i, j: C_pack[i // bn][j // bn][i % bn][j % bn])

        s = tvm.create_schedule([C.op])
        expected_flop = 2 * N * L * M
        assert compute_flop(s) == expected_flop
|
||||
|
||||
def test_outer_dot():
    """An outer product counts one flop per output element."""
    for _ in range(5):
        N, M = [np.random.randint(10, 128) * 4 for _ in range(2)]
        A = tvm.placeholder((N,))
        B = tvm.placeholder((M,))

        C = tvm.compute((N, M), lambda i, j: A[i] * B[j])

        s = tvm.create_schedule([C.op])
        expected_flop = N * M
        assert compute_flop(s) == expected_flop
|
||||
|
||||
def test_move():
    """A plain copy has no float operation, so compute_flop must raise RuntimeError."""
    N = 1024

    A = tvm.placeholder((N,))
    C = tvm.compute((N,), lambda i: A[i])
    s = tvm.create_schedule([C.op])

    raised = False
    try:
        compute_flop(s)
    except RuntimeError:
        raised = True
    assert raised
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_conv()
|
||||
test_pack_gemm()
|
||||
test_outer_dot()
|
||||
test_move()
|
|
@ -0,0 +1,65 @@
|
|||
"""test the correctness of dump and load of data log"""
|
||||
import time
|
||||
|
||||
import tvm
|
||||
from tvm.contrib import util
|
||||
|
||||
from tvm import autotvm
|
||||
from tvm.autotvm.measure import MeasureInput, MeasureResult, MeasureErrorNo
|
||||
from tvm.autotvm.record import encode, decode, ApplyHistoryBest, measure_str_key
|
||||
|
||||
from test_autotvm_common import get_sample_task
|
||||
|
||||
def test_load_dump():
    """encode followed by decode keeps the record, for both protocols."""
    task, target = get_sample_task()

    inp = MeasureInput(target, task, task.config_space.get(0))
    costs = (2.0, 2.23, 0.23, 0.123, 0.234, 0.123)
    result = MeasureResult(costs, MeasureErrorNo.NO_ERROR, 2.3, time.time())

    for protocol in ('json', 'pickle'):
        inp_2, result_2 = decode(encode(inp, result, protocol=protocol),
                                 protocol=protocol)

        assert measure_str_key(inp) == measure_str_key(inp_2), \
            "%s vs %s" % (measure_str_key(inp), measure_str_key(inp_2))
        for field in ('costs', 'error_no', 'timestamp'):
            assert getattr(result, field) == getattr(result_2, field)
|
||||
|
||||
|
||||
def test_file_io():
    """Records written through log_to_file are read back by load_from_file."""
    temp = util.tempdir()
    file_path = temp.relpath("temp.log")

    tsk, target = get_sample_task()
    indices = range(0, 10)
    inputs = [MeasureInput(target, tsk, tsk.config_space.get(i)) for i in indices]
    results = [MeasureResult((i, ), 0, 0, 0) for i in indices]

    with open(file_path, "w") as fo:
        autotvm.callback.log_to_file(fo)(None, inputs, results)

    # compare the result half of every written record with what was loaded
    written = zip(inputs, results)
    for record, loaded in zip(written, autotvm.record.load_from_file(file_path)):
        assert record[1] == loaded[1]
|
||||
|
||||
|
||||
def test_apply_history_best():
    """ApplyHistoryBest.query returns the config whose record has the lowest cost."""
    tsk, target = get_sample_task()

    # (config index, measured cost); index 2 has the smallest cost
    samples = [(0, 0.1), (1, 0.3), (2, 0.01), (4, 0.4)]
    records = [
        (MeasureInput(target, tsk, tsk.config_space.get(idx)),
         MeasureResult((cost,), 0, 2.3, 0))
        for idx, cost in samples
    ]
    hist_best = ApplyHistoryBest(records)
    best = hist_best.query(target, tsk.workload)
    assert str(best) == str(tsk.config_space.get(2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_load_dump()
|
||||
test_apply_history_best()
|
||||
test_file_io()
|
|
@ -0,0 +1,30 @@
|
|||
"""Test space definition primitives"""
|
||||
|
||||
import tvm
|
||||
from tvm.autotvm.task.space import ConfigSpace
|
||||
|
||||
def gemm_func(cfg, N):
    """Declare an N x N gemm and define a two-way split knob on each spatial axis."""
    lhs = tvm.placeholder((N, N), name='A')
    rhs = tvm.placeholder((N, N), name='B')

    red = tvm.reduce_axis((0, N), name='k')
    out = tvm.compute((N, N),
                      lambda i, j: tvm.sum(lhs[i, red] * rhs[red, j], axis=[red]),
                      name='C')

    sched = tvm.create_schedule([out.op])

    row, col = sched[out].op.axis
    for knob, axis in (('tile_y', row), ('tile_x', col)):
        cfg.define_split(knob, cfg.axis(axis), num_outputs=2)

    return sched, [lhs, rhs, out]
|
||||
|
||||
def test_split():
    """Two 2-way splits of a length-128 axis give 8 options each, 64 in total."""
    space = ConfigSpace()
    gemm_func(space, 128)

    assert len(space) == 64
    assert len(space.space_map['tile_y']) == 8
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_split()
|
|
@ -31,14 +31,14 @@ def test_shared_memory():
|
|||
with tvm.build_config(**{"add_lower_pass": [
|
||||
(2, get_verify_pass(valid,
|
||||
max_shared_memory_per_block=4 * M - 1,
|
||||
max_thread_per_block=M))]}):
|
||||
max_threads_per_block=M))]}):
|
||||
tvm.build(s, [A, B], target)
|
||||
assert not valid[0]
|
||||
|
||||
with tvm.build_config(**{"add_lower_pass": [
|
||||
(2, get_verify_pass(valid,
|
||||
max_shared_memory_per_block=4 * M,
|
||||
max_thread_per_block=M))]}):
|
||||
max_threads_per_block=M))]}):
|
||||
tvm.build(s, [A, B], target)
|
||||
assert valid[0]
|
||||
|
||||
|
@ -66,14 +66,14 @@ def test_local_memory():
|
|||
with tvm.build_config(**{"add_lower_pass": [
|
||||
(2, get_verify_pass(valid,
|
||||
max_local_memory_per_block=4 * M - 1,
|
||||
max_thread_per_block=1))]}):
|
||||
max_threads_per_block=1))]}):
|
||||
tvm.build(s, [A, B], target)
|
||||
assert not valid[0]
|
||||
|
||||
with tvm.build_config(**{"add_lower_pass": [
|
||||
(2, get_verify_pass(valid,
|
||||
max_local_memory_per_block=4 * M,
|
||||
max_thread_per_block=1))]}):
|
||||
max_threads_per_block=1))]}):
|
||||
tvm.build(s, [A, B], target)
|
||||
assert valid[0]
|
||||
|
||||
|
@ -101,21 +101,21 @@ def test_num_thread():
|
|||
with tvm.build_config(**{"add_lower_pass": [
|
||||
(2, get_verify_pass(valid,
|
||||
max_shared_memory_per_block=0,
|
||||
max_thread_per_block=N - 1))]}):
|
||||
max_threads_per_block=N - 1))]}):
|
||||
tvm.build(s, [A, B], target)
|
||||
assert not valid[0]
|
||||
|
||||
with tvm.build_config(**{"add_lower_pass": [
|
||||
(2, get_verify_pass(valid,
|
||||
max_shared_memory_per_block=0,
|
||||
max_thread_per_block=N))]}):
|
||||
max_threads_per_block=N))]}):
|
||||
tvm.build(s, [A, B], target)
|
||||
assert valid[0]
|
||||
|
||||
with tvm.build_config(**{"add_lower_pass": [
|
||||
(2, get_verify_pass(valid,
|
||||
max_shared_memory_per_block=0,
|
||||
max_thread_per_block=N,
|
||||
max_threads_per_block=N,
|
||||
max_thread_y=M-1))]}):
|
||||
tvm.build(s, [A, B], target)
|
||||
assert not valid[0]
|
||||
|
@ -123,7 +123,7 @@ def test_num_thread():
|
|||
with tvm.build_config(**{"add_lower_pass": [
|
||||
(2, get_verify_pass(valid,
|
||||
max_shared_memory_per_block=0,
|
||||
max_thread_per_block=N,
|
||||
max_threads_per_block=N,
|
||||
max_thread_y=M))]}):
|
||||
tvm.build(s, [A, B], target)
|
||||
assert valid[0]
|
||||
|
@ -151,14 +151,14 @@ def test_multiple_kernels():
|
|||
with tvm.build_config(**{"add_lower_pass": [
|
||||
(2, get_verify_pass(valid,
|
||||
max_shared_memory_per_block=0,
|
||||
max_thread_per_block=N - 1))]}):
|
||||
max_threads_per_block=N - 1))]}):
|
||||
tvm.build(s, [A, C], target)
|
||||
assert not valid[0]
|
||||
|
||||
with tvm.build_config(**{"add_lower_pass": [
|
||||
(2, get_verify_pass(valid,
|
||||
max_shared_memory_per_block=0,
|
||||
max_thread_per_block=N))]}):
|
||||
max_threads_per_block=N))]}):
|
||||
tvm.build(s, [A, C], target)
|
||||
assert valid[0]
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ make doc
|
|||
jsdoc web/tvm_runtime.js web/README.md || exit -1
|
||||
mv out docs/_build/html/jsdoc || exit -1
|
||||
|
||||
rm -rf python/tvm/*.pyc python/tvm/*/*.pyc
|
||||
rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc
|
||||
|
||||
cd docs
|
||||
PYTHONPATH=`pwd`/../python make html || exit -1
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
export PYTHONPATH=python:apps/extension/python
|
||||
export LD_LIBRARY_PATH=build:${LD_LIBRARY_PATH}
|
||||
|
||||
rm -rf python/tvm/*.pyc python/tvm/*/*.pyc
|
||||
rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc
|
||||
|
||||
# Test TVM
|
||||
make cython || exit -1
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
export PYTHONPATH=python:topi/python
|
||||
|
||||
rm -rf python/tvm/*.pyc python/tvm/*/*.pyc
|
||||
rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc
|
||||
|
||||
TVM_FFI=ctypes python -m nose -v tests/python/unittest || exit -1
|
||||
TVM_FFI=ctypes python3 -m nose -v tests/python/unittest || exit -1
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
Auto tuning
|
||||
-------------
|
||||
|
|
@ -0,0 +1,160 @@
|
|||
"""
|
||||
How to get high performance convolution kernel on NVIDIA GPU by auto-tuning
|
||||
=========================================================================
|
||||
**Author**: `Lianmin Zheng <https://github.com/merrymercy>`_
|
||||
|
||||
This is an advanced tutorial for writing a high performance tunable template for
|
||||
NVIDIA GPU. By running auto-tuner on this template, we can outperform the
|
||||
vendor provided library CuDNN in many cases.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import tvm
|
||||
import topi
|
||||
|
||||
from tvm import autotvm
|
||||
|
||||
######################################################################
|
||||
# Step 1: Define the search space
|
||||
# ---------------------------------
|
||||
# There are plenty of useful schedule primitives in tvm. You can also find
|
||||
# some tutorials that describe them in more details, such as
|
||||
# (1). :doc:`Optimizing Conv2d on NVIDIA GPU <../optimize/opt_conv_cuda>`
|
||||
# (2). `Optimizing DepthwiseConv on NVIDIA GPU <https://tvm.ai/2017/08/22/Optimize-Deep-Learning-GPU-Operators-with-TVM-A-Depthwise-Convolution-Example.html>`_
|
||||
#
|
||||
# However, their implementations are manually tuned for some special input
|
||||
# shapes. In this section, we build a large enough space to cover
|
||||
# the techniques used in these tutorials. Then we rely on the efficient auto-tuner
|
||||
# to search through this space and pick some good configurations.
|
||||
#
|
||||
# If you are familiar with writing cuda schedule, you can find the following
|
||||
# template is very general. Actually this template can be easily modified
|
||||
# to tune other operators such as depthwise convolution and gemm.
|
||||
# In order to fully understand this template, you should be familiar with
|
||||
# the schedule primitives and auto tuning API. You can refer to the above
|
||||
# tutorials and :doc:`autotvm tutorial <tune_simple_template>`
|
||||
#
|
||||
# It is worth noting that the search space for a conv2d operator
|
||||
# can be very large (at the level of 10^9 for some input shapes)
|
||||
#
|
||||
|
||||
@autotvm.template
def conv2d_no_batching(N, H, W, CI, CO, KH, KW, stride, padding):
    """Tunable CUDA schedule template for an NCHW conv2d with batch size 1.

    Parameters
    ----------
    N : int
        Batch size. Must be 1.
    H, W : int
        Height and width of the input feature map.
    CI, CO : int
        Number of input and output channels.
    KH, KW : int
        Height and width of the convolution kernel.
    stride, padding
        Passed through unchanged to ``topi.nn.conv2d_nchw``.

    Returns
    -------
    s
        The schedule, built according to the current config.
    args : list
        ``[data placeholder, kernel placeholder, conv output tensor]``.
    """
    assert N == 1, "Only consider batch_size = 1 in this template"

    data = tvm.placeholder((N, CI, H, W), name='data')
    kernel = tvm.placeholder((CO, CI, KH, KW), name='kernel')
    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, 'float32')
    s = tvm.create_schedule([conv.op])

    # inline padding
    pad_data = s[conv].op.input_tensors[0]
    s[pad_data].compute_inline()
    # from here on `data` is the conv's first input; `raw_data` keeps the
    # original placeholder so it can be returned as the function argument
    data, raw_data = pad_data, data

    output = conv
    OL = s.cache_write(conv, 'local')

    # create cache stage
    AA = s.cache_read(data, 'shared', [OL])
    WW = s.cache_read(kernel, 'shared', [OL])
    AL = s.cache_read(AA, 'local', [OL])
    WL = s.cache_read(WW, 'local', [OL])

    # tile and bind spatial axes
    n, f, y, x = s[output].op.axis
    cfg = autotvm.get_config()
    cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
    cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
    cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
    kernel_scope = n  # this is the scope to attach global config inside this kernel

    s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
    s[output].bind(by, tvm.thread_axis("blockIdx.y"))
    s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[output].bind(vf, tvm.thread_axis("vthread"))
    s[output].bind(vy, tvm.thread_axis("vthread"))
    s[output].bind(vx, tvm.thread_axis("vthread"))
    s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
    s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
    s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
    s[OL].compute_at(s[output], tx)

    # tile and bind reduction axes
    n, f, y, x = s[OL].op.axis
    rc, ry, rx = s[OL].op.reduce_axis
    cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
    cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=3)
    cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=3)
    # Each knob must be applied to the axis it was defined from: the split
    # factors of "tile_ry" come from the extent of ry and those of "tile_rx"
    # from the extent of rx, which differ whenever KH != KW.
    rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc)
    ryo, rym, ryi = cfg['tile_ry'].apply(s, OL, ry)
    rxo, rxm, rxi = cfg['tile_rx'].apply(s, OL, rx)
    s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x)

    s[AA].compute_at(s[OL], rxo)
    s[WW].compute_at(s[OL], rxo)
    s[AL].compute_at(s[OL], rxm)
    s[WL].compute_at(s[OL], rxm)

    # cooperative fetching: split the fused load loop by the same thread
    # extents (index 2 of each spatial split) used for the output stage
    for load in [AA, WW]:
        n, f, y, x = s[load].op.axis
        fused = s[load].fuse(n, f, y, x)
        tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
        ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
        tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
        s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
        s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
        s[load].bind(tx, tvm.thread_axis("threadIdx.x"))

    # tune unroll
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
    cfg.define_knob("unroll_explicit", [0, 1])
    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
    s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)

    return s, [raw_data, kernel, conv]
|
||||
|
||||
######################################################################
|
||||
# Step 2: Search through the space
|
||||
# ---------------------------------
|
||||
# We pick the last layer on resnet as test case.
|
||||
# Since our space is very large, :code:`XGBTuner` is most suitable
|
||||
# for our case. Here we only do 20 trials for demonstration.
|
||||
# In practice, making 1000 trials usually can find some good kernels
|
||||
# for this template
|
||||
|
||||
# print the tuning log to the screen
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# workload: the last layer in resnet
# argument order is (N, H, W, CI, CO, KH, KW, stride, padding)
task = autotvm.task.create(
    conv2d_no_batching,
    args=(1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)),
    target='cuda',
)
print(task.config_space)

# measure on the local gpu, 10 times for every config to reduce variance,
# and run 8 parallel threads for compilation
measure_option = autotvm.measure_option(
    mode='local', number=10, parallel_num=8, timeout=20)

# run the tuner and log every record to the file `cache.tsv`
tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(
    n_trial=20,
    measure_option=measure_option,
    callbacks=[autotvm.callback.log_to_file('cache.tsv')],
)

# look up the best config recorded in the cache file
dispatch_context = autotvm.apply_history_best("cache.tsv")
best_config = dispatch_context.query(task.target, task.workload)
print("\nBest config:")
print(best_config)
|
||||
|
|
@ -0,0 +1,284 @@
|
|||
"""
|
||||
Writing tunable template and Using auto-tuner
|
||||
=============================================
|
||||
**Author**: `Lianmin Zheng <https://github.com/merrymercy>`_
|
||||
|
||||
This is an introduction tutorial to the auto-tuning module in tvm.
|
||||
|
||||
There are two steps in auto-tuning.
|
||||
The first step is defining a search space.
|
||||
The second step is running a search algorithm to explore through this space.
|
||||
In this tutorial, you can learn how to perform these two steps in tvm.
|
||||
The whole workflow is illustrated by a matrix multiplication example.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import tvm
|
||||
|
||||
# the module is called `autotvm`
|
||||
from tvm import autotvm
|
||||
|
||||
######################################################################
|
||||
# Step 1: Define the search space
|
||||
# ---------------------------------
|
||||
# In this section, we will rewrite a deterministic tvm schedule code to a
|
||||
# tunable schedule template. You can regard the process of search space definition
|
||||
# as the parametrization of our existing schedule code.
|
||||
#
|
||||
# To begin with, here is how we implement a blocked matrix multiplication in tvm
|
||||
|
||||
# Matmul V0: Constant tiling factor
|
||||
def matmul_v0(N, L, M, dtype):
    """Blocked matrix multiplication with a constant tiling factor of 8.

    Declares ``A`` with shape (N, L) and ``B`` with shape (L, M), computes
    ``C[i, j] = sum_k A[i, k] * B[k, j]`` and returns the schedule together
    with the argument list ``[A, B, C]``.
    """
    tile = 8

    A = tvm.placeholder((N, L), dtype=dtype, name='A')
    B = tvm.placeholder((L, M), dtype=dtype, name='B')

    rk = tvm.reduce_axis((0, L), name='k')
    C = tvm.compute(
        (N, M),
        lambda i, j: tvm.sum(A[i, rk] * B[rk, j], axis=rk),
        name='C')
    s = tvm.create_schedule(C.op)

    # schedule: tile both output axes and place the reduction loop
    # between the outer and the inner tile loops
    stage = s[C]
    y, x = stage.op.axis
    k = stage.op.reduce_axis[0]
    yo, yi = stage.split(y, tile)
    xo, xi = stage.split(x, tile)
    stage.reorder(yo, xo, k, yi, xi)

    return s, [A, B, C]
|
||||
|
||||
#####################################################################
|
||||
# Parametrize the schedule
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
# In the previous schedule code, we use a constant "8" as tiling factor.
|
||||
# However, it might not be the best one because the best tiling factor depends
|
||||
# on real hardware environment and input shape.
|
||||
#
|
||||
# If you want the schedule code to be portable across a wider range of input shapes
|
||||
# and target hardware, it is better to define a set of candidate values and
|
||||
# pick the best one according to the measurement results on target hardware.
|
||||
#
|
||||
# In autotvm, we can define a tunable parameter, or a "knob" for such kind of value.
|
||||
|
||||
# Matmul V1: List candidate values
|
||||
@autotvm.template  # 1. use a decorator
def matmul_v1(N, L, M, dtype):
    """Blocked matrix multiplication whose tiling factors are tunable knobs.

    Same computation as the constant-factor version, but the split factors
    of the two output axes are read from the autotvm config. Returns the
    schedule together with the argument list ``[A, B, C]``.
    """
    candidates = (1, 2, 4, 8, 16)

    A = tvm.placeholder((N, L), dtype=dtype, name='A')
    B = tvm.placeholder((L, M), dtype=dtype, name='B')

    rk = tvm.reduce_axis((0, L), name='k')
    C = tvm.compute(
        (N, M),
        lambda i, j: tvm.sum(A[i, rk] * B[rk, j], axis=rk),
        name='C')
    s = tvm.create_schedule(C.op)

    # schedule
    y, x = s[C].op.axis
    k = s[C].op.reduce_axis[0]

    # 2. get the config object
    cfg = autotvm.get_config()

    # 3. define search space
    cfg.define_knob("tile_y", list(candidates))
    cfg.define_knob("tile_x", list(candidates))

    # 4. schedule according to config
    factor_y = cfg['tile_y'].val
    factor_x = cfg['tile_x'].val
    yo, yi = s[C].split(y, factor_y)
    xo, xi = s[C].split(x, factor_x)
    s[C].reorder(yo, xo, k, yi, xi)

    return s, [A, B, C]
|
||||
|
||||
###############################################################################
|
||||
# Here we make four modifications to the previous schedule code and get
|
||||
# a tunable "template". We can explain the modifications one by one.
|
||||
#
|
||||
# 1. Use a decorator to mark this function as a simple template
|
||||
# 2. Get a config object:
|
||||
# You can regard this :code:`cfg` as an argument of this function but
|
||||
# we obtain it in a different way. With this argument, this function is no longer
|
||||
# a deterministic schedule code. Instead, we can pass different configurations to
|
||||
# this function and get different schedules, so this function is a "template".
|
||||
#
|
||||
# To make the template function more compact, we do two things in a single function.
|
||||
# (1) define a search space and (2) schedule according to an entity in this space.
|
||||
# To achieve this, we make :code:`cfg` be either
|
||||
# a :any:`ConfigSpace` or a :any:`ConfigEntity` object.
|
||||
#
|
||||
# When it is a :any:`ConfigSpace`, it will collect all tunable knobs in this function and
|
||||
# build the search space.
|
||||
# When it is a :any:`ConfigEntity`, it will ignore all space definition API
|
||||
# (namely, :code:`cfg.define_XXXXX(...)`). Instead, it stores deterministic values for
|
||||
# all tunable knobs, and we schedule according to these values.
|
||||
#
|
||||
# During auto-tuning, we will first call this template with a :any:`ConfigSpace`
|
||||
# object to build the search space. Then we call this template with different :any:`ConfigEntity`
|
||||
# in the built space to get different schedules. Finally we will measure the code generated by
|
||||
# different schedules and pick the best one.
|
||||
#
|
||||
# 3. Define two tunable knobs. The first one is :code:`tile_y` with
|
||||
# 5 possible values. The second one is :code:`tile_x` with the same
|
||||
# list of possible values. These two knobs are independent, so they
|
||||
# span a search space with size = 5x5 = 25
|
||||
# 4. Schedule according to the deterministic values in :code:`cfg`
|
||||
#
|
||||
|
||||
#####################################################################
|
||||
# Use better space definition API
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
# In the previous template, we manually list all possible values for a knob.
|
||||
# This is the lowest level API to define the space.
|
||||
# However, we also provide another set of API to make the space definition
|
||||
# easier and smarter. It is recommended to use this set of high level API.
|
||||
#
|
||||
# In the following example, we use :any:`ConfigSpace.define_split` to define a split
|
||||
# knob. It will enumerate all the possible ways to split an axis and construct
|
||||
# the space.
|
||||
#
|
||||
# We also have :any:`ConfigSpace.define_reorder` for reorder knob and
|
||||
# :any:`ConfigSpace.define_annotate` for annotation like unroll, vectorization,
|
||||
# thread binding.
|
||||
# When the high level API cannot meet your requirement, you can always fall
|
||||
# back to use low level API.
|
||||
|
||||
@autotvm.template
def matmul(N, L, M, dtype):
    """Blocked matrix multiplication tuned through split knobs.

    The two output axes are each split in two according to the ``tile_y``
    and ``tile_x`` split knobs of the autotvm config. Returns the schedule
    together with the argument list ``[A, B, C]``.
    """
    A = tvm.placeholder((N, L), dtype=dtype, name='A')
    B = tvm.placeholder((L, M), dtype=dtype, name='B')

    rk = tvm.reduce_axis((0, L), name='k')
    C = tvm.compute(
        (N, M),
        lambda i, j: tvm.sum(A[i, rk] * B[rk, j], axis=rk),
        name='C')
    s = tvm.create_schedule(C.op)

    # schedule
    out_op = s[C].op
    y, x = out_op.axis
    k = out_op.reduce_axis[0]

    ##### define space begin #####
    cfg = autotvm.get_config()
    cfg.define_split("tile_y", y, num_outputs=2)
    cfg.define_split("tile_x", x, num_outputs=2)
    ##### define space end #####

    # schedule according to config
    yo, yi = cfg["tile_y"].apply(s, C, y)
    xo, xi = cfg["tile_x"].apply(s, C, x)
    s[C].reorder(yo, xo, k, yi, xi)

    return s, [A, B, C]
|
||||
|
||||
######################################################################
|
||||
# .. note:: More Explanation on :code:`cfg.define_split`
|
||||
#
|
||||
# In this template, :code:`cfg.define_split("tile_y", y, num_outputs=2)` will enumerate
|
||||
# all possible combinations that can split axis y into two axes with factors of the length of y.
|
||||
# For example, if the length of y is 32 and we want to split it into two axes
|
||||
# using factors of 32, then there are 6 possible values for
|
||||
# (length of outer axis, length of inner axis) pair, namely
|
||||
# (32, 1), (16, 2), (8, 4), (4, 8), (2, 16) or (1, 32).
|
||||
# They are just the 6 possible values of `tile_y`.
|
||||
#
|
||||
# During schedule, :code:`cfg["tile_y"]` is a :code:`SplitEntity` object.
|
||||
# We store the lengths of outer axes and inner axes in :code:`cfg['tile_y'].size`
|
||||
# (a tuple with two elements).
|
||||
# In this template, we apply it by using :code:`yo, yi = cfg['tile_y'].apply(s, C, y)`.
|
||||
# Actually, this is equivalent to
|
||||
# :code:`yo, yi = s[C].split(y, cfg["tile_y"].size[1])`
|
||||
# or :code:`yo, yi = s[C].split(y, nparts=cfg["tile_y"].size[0])`
|
||||
#
|
||||
# The advantage of using cfg.apply API is that it makes multi-level split
|
||||
# (when num_outputs >= 3) easier.
|
||||
|
||||
######################################################################
|
||||
# Step 2: Search through the space
|
||||
# ---------------------------------
|
||||
# In step 1, we build the search space by extending our old schedule code
|
||||
# into a template. The next step is to pick a tuner and explore in this space.
|
||||
#
|
||||
# Auto-tuners in tvm
|
||||
# ^^^^^^^^^^^^^^^^^^
|
||||
# The job for a tuner can be described by following pseudo code
|
||||
#
|
||||
# .. code-block:: c
|
||||
#
|
||||
# ct = 0
|
||||
# while ct < max_number_of_trials:
|
||||
# propose a batch of configs
|
||||
# measure this batch of configs on real hardware and get results
|
||||
# ct += batch_size
|
||||
#
|
||||
# When proposing the next batch of configs, the tuner can take different strategies. We
|
||||
# provide four tuners with different strategies in autotvm.
|
||||
#
|
||||
# * :any:`RandomTuner`: Enumerate the space in a random order
|
||||
# * :any:`GridSearchTuner`: Enumerate the space in a grid search order
|
||||
# * :any:`GATuner`: Using genetic algorithm to search through the space
|
||||
# * :any:`XGBTuner`: Uses a model based method. Train a XGBoost model to predict the speed of lowered IR and pick the next batch according to the prediction.
|
||||
#
|
||||
# You can choose the tuner according to the size of your space, your time budget and other factors.
|
||||
# For example, if your space is very small (less than 1000), a gridsearch tuner or a
|
||||
# random tuner is good enough. If your space is at the level of 10^9 (this is the space
|
||||
# size of a conv2d operator on CUDA GPU), XGBTuner can explore more efficiently
|
||||
# and find better configs.
|
||||
|
||||
################################################################
|
||||
# Begin tuning
|
||||
# ^^^^^^^^^^^^
|
||||
# Here we continue our matrix multiplication example.
|
||||
# First we should create a tuning task.
|
||||
# We can also inspect the initialized search space.
|
||||
# In this case, for a 512x512 square matrix multiplication, the space size
|
||||
# is 10x10=100
|
||||
# square 512x512 problem: all three dimensions are equal
N = L = M = 512
task = autotvm.task.create(
    matmul,
    args=(N, L, M, 'float32'),
    target='llvm',
)
print(task.config_space)
|
||||
|
||||
################################################################
|
||||
# Then we need to define how to measure the generated code and pick a tuner.
|
||||
# Since our space is small, a random tuner is just okay.
|
||||
#
|
||||
# We only make 10 trials in this tutorial for demonstration. In practice,
|
||||
# you can do more trials according to your time budget.
|
||||
# We will log the tuning results into a cache file. This file can be
|
||||
# used to get the best config later.
|
||||
|
||||
# print the tuning log to the screen
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# measure on the local cpu, 5 times for every config to reduce variance
measure_option = autotvm.measure_option(mode='local', number=5)

# run the tuner and log every record to the file `cache.tsv`
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(
    n_trial=10,
    measure_option=measure_option,
    callbacks=[autotvm.callback.log_to_file('cache.tsv')],
)
|
||||
|
||||
#########################################################################
|
||||
# Finally we apply history best from the cache file and check its correctness.
|
||||
# We can call the function :code:`matmul` directly under the
|
||||
# :any:`autotvm.apply_history_best` context. When we call this function,
|
||||
# it will query the dispatch context with its argument and get the best config
|
||||
# with the same argument.
|
||||
|
||||
# build `matmul` with the best config found in the log file
with autotvm.apply_history_best('cache.tsv'), tvm.target.create("llvm"):
    s, arg_bufs = matmul(N, L, M, 'float32')
    func = tvm.build(s, arg_bufs)

# check correctness against a numpy reference result
a_np = np.random.uniform(size=(N, L)).astype(np.float32)
b_np = np.random.uniform(size=(L, M)).astype(np.float32)
c_np = np.dot(a_np, b_np)

c_tvm = tvm.nd.empty(c_np.shape)
a_tvm = tvm.nd.array(a_np)
b_tvm = tvm.nd.array(b_np)
func(a_tvm, b_tvm, c_tvm)

np.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)
|
||||
|
Загрузка…
Ссылка в новой задаче