[AUTOTVM] Misc bug fix (#1467)

Parent: 9026f3fcd4
Commit: ad28f5ca3e
--- a/python/tvm/autotvm/measure/measure.py
+++ b/python/tvm/autotvm/measure/measure.py
@@ -6,7 +6,6 @@ from collections import namedtuple
 import numpy as np
 
 from ... import build, nd, target as _target
-from ...contrib.util import tempdir
 from ...rpc.tracker import Tracker
 from ...rpc.server import Server
 
@@ -209,14 +208,12 @@ def create_measure_batch(task, options):
         kwargs['rpc_device_key'] = rpc_device_key
         kwargs['rpc_tracker_addr'] = (tracker.host, tracker.port)
         kwargs['rpc_timeout'] = timeout
-        kwargs['tmp_dir'] = tempdir()
     elif mode == 'rpc':
         fmeasure = measure_methods.measure_rpc
         kwargs['rpc_device_key'] = rpc_device_key
         kwargs['rpc_priority'] = rpc_priority
         kwargs['rpc_timeout'] = rpc_timeout
         kwargs['use_ndk'] = use_ndk
-        kwargs['tmp_dir'] = tempdir()
         assert rpc_device_key, "In rpc mode, a rpc_device_key must be provided"
     elif mode == "custom":
         assert callable(custom_measure_batch), "In custom mode, custom_measure_func " \
@@ -243,7 +240,7 @@ def create_measure_batch(task, options):
         tvm_buf = [nd.array(x) for x in ref_input]
         func(*tvm_buf)
         ref_output = [x.asnumpy() for x in tvm_buf]
-        kwargs['ref_input'], kwargs['ref_outpu'] = ref_input, ref_output
+        kwargs['ref_input'], kwargs['ref_output'] = ref_input, ref_output
 
     def measure_batch(measure_inputs):
         """measure the time cost for a batch of configs in real machines"""
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -12,7 +12,7 @@ from random import getrandbits
 
 import numpy as np
 
-from ...contrib import ndk, nvcc
+from ...contrib import ndk, nvcc, util
 from ... import rpc, ir_pass, build, build_config, nd, context, TVMError, register_func
 
 from ..util import get_const_tuple
@@ -113,8 +113,8 @@ def _measure_generic(fbuild, input_pack, ref_input, ref_output):
         if ref_input:
             args = [nd.array(x, ctx) for x in ref_input]
         else:
-            args = [nd.array(np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype),
-                             ctx) for x in arg_bufs]
+            args = [nd.empty(get_const_tuple(x.shape), dtype=x.dtype,
+                             ctx=ctx) for x in arg_bufs]
         costs = time_f(*args).results
         if len(costs) > 2:  # remove largest and smallest value to reduce variance
             costs = list(costs)
@@ -173,7 +173,6 @@ def measure_rpc(input_pack,
                 rpc_tracker_addr=None,
                 rpc_priority=1,
                 rpc_timeout=60,
-                tmp_dir=None,
                 **kwargs):
     """Measure the time cost on a device by rpc
 
@@ -198,9 +197,6 @@ def measure_rpc(input_pack,
     rpc_timeout: int, optional
        timeout of the rpc session
 
-    tmp_dir: tvm.contrib.util.TempDirectory, optional
-        directory to store temp file
-
     kwargs: dict, optional
         Additional key word arguments
 
@@ -213,6 +209,7 @@ def measure_rpc(input_pack,
         """ Local build function."""
         func, args = _build_func(inp, build_option, kwargs)
 
+        tmp_dir = util.tempdir()
         if not kwargs.get('use_ndk', False):
            file_name = "tmp_func_%0x.tar" % getrandbits(64)
            path = tmp_dir.relpath(file_name)
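The two measure changes belong together: create_measure_batch previously created one shared temporary directory and passed it to every measurement job through kwargs['tmp_dir']; now each RPC build creates its own util.tempdir(), so parallel jobs cannot collide on file names. A minimal sketch of the per-call pattern using only the standard library (build_one is a hypothetical stand-in for the local build step):

    import os
    import tempfile
    from random import getrandbits

    def build_one(payload):
        # a fresh directory per call: concurrent builds cannot
        # overwrite each other's artifacts
        tmp_dir = tempfile.mkdtemp()
        file_name = "tmp_func_%0x.tar" % getrandbits(64)
        path = os.path.join(tmp_dir, file_name)
        with open(path, "wb") as f:
            f.write(payload)
        return path

    print(build_one(b"fake module bytes"))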
--- a/python/tvm/autotvm/record.py
+++ b/python/tvm/autotvm/record.py
@@ -9,11 +9,12 @@ import multiprocessing
 import pickle
 import json
 import time
+import os
 from collections import OrderedDict
 
 import numpy as np
 
-from .. import target, build, lower
+from .. import build, lower, target as _target
 
 from . import task
 from .task import DispatchContext, ConfigEntity
@@ -26,6 +27,11 @@ try:  # convert unicode to str for python2
 except NameError:
     _unicode = ()
 
+try:
+    _long = long
+except NameError:
+    _long = int
+
 
 def measure_str_key(inp, include_config=True):
     """ get unique str key for MeasureInput
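The new try/except gives record.py a name for Python 2's long type; on Python 3, where long is gone, the NameError branch falls back to int. decode() uses the (_long, int) pair below to normalize JSON-decoded integers. The same idiom in isolation:

    try:
        _long = long          # Python 2
    except NameError:
        _long = int           # Python 3 removed `long`

    def normalize(x):
        if isinstance(x, (_long, int)):
            return int(x)
        return x

    print(normalize(3))       # 3 on either interpreter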
@@ -111,7 +117,7 @@ def decode(row, protocol='json'):
     if protocol == 'json':
         row = json.loads(row)
         tgt, task_name, task_args, task_kwargs, workload, config = row['i']
-        tgt = target.create(str(tgt))
+        tgt = _target.create(str(tgt))
 
         def clean_json_to_python(x):
             """1. convert all list in x to tuple (hashable)
@@ -121,6 +127,8 @@ def decode(row, protocol='json'):
                 return tuple([clean_json_to_python(a) for a in x])
             if isinstance(x, _unicode):
                 return str(x)
+            if isinstance(x, (_long, int)):
+                return int(x)
             return x
 
         tsk = task.Task(clean_json_to_python(task_name), clean_json_to_python(task_args))
@@ -132,7 +140,7 @@ def decode(row, protocol='json'):
         return inp, result
     elif protocol == 'pickle':
         items = row.split("\t")
-        tgt = target.create(items[0])
+        tgt = _target.create(items[0])
         task_tuple = pickle.loads(base64.b64decode(items[1].encode()))
         config = pickle.loads(base64.b64decode(items[2].encode()))
         result = pickle.loads(base64.b64decode(items[3].encode()))
@@ -168,36 +176,70 @@ class ApplyHistoryBest(DispatchContext):
     ----------
     records : str or iterator of (MeasureInput, MeasureResult)
         Collection of tuning records.
-        if is str, then it should be the filename of a records log file.
+        If is str, then it should be the filename of a records log file.
         Each row of this file is an encoded record pair.
-        otherwise, it is an iterator
+        Otherwise, it is an iterator.
     default: ConfigEntity, optional
-        default config to return when no history records
+        The default config to return when no history records
     """
     def __init__(self, records, default=None):
         super(ApplyHistoryBest, self).__init__()
 
+        self.best_by_targetkey = {}
+        self.best_by_model = {}
+        self._default = default
+
+        self.load(records)
+
+    def load(self, records):
+        """Load records to this dispatch context
+
+        Parameters
+        ----------
+        records : str or iterator of (MeasureInput, MeasureResult)
+            Collection of tuning records.
+            If is str, then it should be the filename of a records log file.
+            Each row of this file is an encoded record pair.
+            Otherwise, it is an iterator.
+        """
         if isinstance(records, str):
             records = load_from_file(records)
+        if not records:
+            return
+
+        best_by_targetkey = self.best_by_targetkey
+        best_by_model = self.best_by_model
+
         counter = 0
-        best_map = {}
         for inp, res in records:
             counter += 1
             if res.error_no != 0:
                 continue
 
+            # use target keys in tvm target system as key to build best map
             for k in inp.target.keys:
                 key = (k, inp.task.workload)
-                if key not in best_map:
-                    best_map[key] = (inp, res)
+                if key not in best_by_targetkey:
+                    best_by_targetkey[key] = (inp, res)
                 else:
-                    _, other_res = best_map[key]
+                    _, other_res = best_by_targetkey[key]
                     if np.mean(other_res.costs) > np.mean(res.costs):
-                        best_map[key] = (inp, res)
-        logging.info(
-            "Finish load %d records, %d entries selected", counter, len(best_map))
-        self._best_map = best_map
-        self._default = default
+                        best_by_targetkey[key] = (inp, res)
+
+            # use model as key to build best map
+            for opt in inp.target.options:
+                if opt.startswith("-model"):
+                    model = opt[7:]
+                    key = (model, inp.task.workload)
+                    if key not in best_by_model:
+                        best_by_model[key] = (inp, res)
+                    else:
+                        _, other_res = best_by_model[key]
+                        if np.mean(other_res.costs) > np.mean(res.costs):
+                            best_by_model[key] = (inp, res)
+                    break
+
+        logging.info("Finish loading %d records", counter)
 
     def query(self, target, workload):
         if target is None:
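ApplyHistoryBest now maintains two indexes as instance attributes instead of a single local best_map: best_by_model, keyed by the device model parsed from the target's -model=... option, and best_by_targetkey, keyed by generic target keys. Folding the work into load() also lets more records be appended to a live context later (load_op_param below relies on this). The model string is taken with opt[7:] because len("-model=") is 7:

    opt = "-model=titanx"     # option string as it appears in a TVM target
    if opt.startswith("-model"):
        print(opt[7:])        # titanx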
|
@ -205,29 +247,25 @@ class ApplyHistoryBest(DispatchContext):
|
||||||
"Hint: If your target is llvm, use `with tvm.target.create('llvm'):`"
|
"Hint: If your target is llvm, use `with tvm.target.create('llvm'):`"
|
||||||
" above the dispatcher call. So does other target. ")
|
" above the dispatcher call. So does other target. ")
|
||||||
|
|
||||||
|
# first try matching by model
|
||||||
|
for opt in target.options:
|
||||||
|
if opt.startswith("-model"):
|
||||||
|
model = opt[7:]
|
||||||
|
key = (model, workload)
|
||||||
|
if key in self.best_by_model:
|
||||||
|
return self.best_by_model[key][0].config
|
||||||
|
|
||||||
|
# then try matching by target key
|
||||||
for k in target.keys:
|
for k in target.keys:
|
||||||
key = (k, workload)
|
key = (k, workload)
|
||||||
if key in self._best_map:
|
if key in self.best_by_targetkey:
|
||||||
return self._best_map[key][0].config
|
return self.best_by_targetkey[key][0].config
|
||||||
|
|
||||||
if self._default:
|
if self._default:
|
||||||
return self._default
|
return self._default
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Cannot find config for target=%s, workload=%s" % (target, workload))
|
"Cannot find config for target=%s, workload=%s" % (target, workload))
|
||||||
|
|
||||||
def dump_best(self, out_file):
|
|
||||||
"""Dump the best records for each workload to a file
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
out_file: str
|
|
||||||
filename
|
|
||||||
"""
|
|
||||||
fout = open(out_file, 'a')
|
|
||||||
for val in self._best_map.values():
|
|
||||||
inp, res = val
|
|
||||||
fout.write(encode(inp, res) + '\n')
|
|
||||||
|
|
||||||
|
|
||||||
def split_workload(in_file, clean=True):
|
def split_workload(in_file, clean=True):
|
||||||
"""Split a log file into separate files, each of which contains only a single workload
|
"""Split a log file into separate files, each of which contains only a single workload
|
||||||
|
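query() resolves a workload in two passes: an exact device-model match wins, and generic target keys are the fallback. A minimal pure-Python sketch of that lookup order (dict contents are placeholders):

    best_by_model = {("titanx", "wkl"): "model-specific config"}
    best_by_targetkey = {("cuda", "wkl"): "generic config"}

    def query(model, target_keys, workload):
        # first try matching by model
        if (model, workload) in best_by_model:
            return best_by_model[(model, workload)]
        # then try matching by target key
        for k in target_keys:
            if (k, workload) in best_by_targetkey:
                return best_by_targetkey[(k, workload)]
        return None

    print(query("titanx", ["cuda", "gpu"], "wkl"))   # model-specific config
    print(query("1080ti", ["cuda", "gpu"], "wkl"))   # generic config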
@@ -243,7 +281,7 @@ def split_workload(in_file, clean=True):
     tic = time.time()
     lines = list(open(in_file).readlines())
 
-    logging.info("start convert...")
+    logging.info("start converting...")
     pool = multiprocessing.Pool()
     lines = pool.map(decode, lines)
     logging.info("map done %.2f", time.time() - tic)
@@ -279,23 +317,69 @@ def split_workload(in_file, clean=True):
         for inp, res in v:
             fout.write(encode(inp, res) + '\n')
 
 
+def pick_best(in_file, out_file):
+    """
+    Pick best entries from a file and store it to another file.
+    This distill the useful log entries from a large log file.
+
+    Parameters
+    ----------
+    in_file: str
+        The filename of input
+    out_file:
+        The filename of output
+    """
+    best_context = ApplyHistoryBest(load_from_file(in_file))
+    best_set = set()
+
+    for v in best_context.best_by_model.values():
+        best_set.add(measure_str_key(v[0]))
+
+    for v in best_context.best_by_targetkey.values():
+        best_set.add(measure_str_key(v[0]))
+
+    logging.info("Extract %d best records from the log file", len(best_set))
+
+    fout = open(out_file, 'w')
+    for inp, res in load_from_file(in_file):
+        if measure_str_key(inp) in best_set:
+            fout.write(encode(inp, res) + "\n")
+
+
+def load_op_param(rootpath=os.path.join(os.path.expanduser('~'), ".tvm", "op_params")):
+    """Load pre-tuned parameters of operators.
+    This function will load all "*.log" file under root path and select best configs.
+
+    Parameters
+    ----------
+    rootpath: str
+        The root path of stored parameters
+    """
+    best_context = ApplyHistoryBest([])
+    for dirpath, _, filenames in os.walk(rootpath):
+        for filename in filenames:
+            if os.path.splitext(filename)[1] == '.log':
+                best_context.load(os.path.join(dirpath, filename))
+
+    assert not DispatchContext.current, "Cannot load pre-tuned parameters inside a dispatch context"
+    DispatchContext.current = best_context
+
+
 """
 Usage:
 This record executable module has three modes.
 
 * Print log file in readable format
-e.g. python -m autotvm.record --mode read --i collect_conv.tsv --begin 0 --end 5 --ir --code
+e.g. python -m autotvm.record --mode read --i collect_conv.log --begin 0 --end 5 --ir --code
 
 * Extract history best from a large log file
-e.g. python -m autotvm.record --mode best --i collect.tsv
+e.g. python -m autotvm.record --mode pick --i collect.log
 
 * Split a log file into separate files, each of which contains only a single wkl
-e.g. python -m autotvm.record --mode split --i collect.tsv
+e.g. python -m autotvm.record --mode split --i collect.log
 """
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument("--mode", choices=['read', 'best', 'split'], default='read')
+    parser.add_argument("--mode", choices=['read', 'pick', 'split'], default='read')
     parser.add_argument("--i", type=str, help="input file")
     parser.add_argument("--o", type=str, default=None, help='output file')
     parser.add_argument("--begin", type=int, default=0)
|
@ -306,10 +390,9 @@ if __name__ == '__main__':
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
if args.mode == 'best':
|
if args.mode == 'pick':
|
||||||
args.o = args.o or args.i + ".best"
|
args.o = args.o or args.i + ".best.log"
|
||||||
hist_best = ApplyHistoryBest(load_from_file(args.i))
|
pick_best(args.i, args.o)
|
||||||
hist_best.dump_best(args.o)
|
|
||||||
elif args.mode == 'read':
|
elif args.mode == 'read':
|
||||||
for i, (inp, result) in enumerate(load_from_file(args.i)):
|
for i, (inp, result) in enumerate(load_from_file(args.i)):
|
||||||
if args.begin <= i < args.end:
|
if args.begin <= i < args.end:
|
||||||
|
|
|
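With dump_best removed, distilling a large log now goes through pick_best, which is also what the new --mode pick invokes. A usage sketch (the file names are placeholders for an existing records log; the import path follows the module's location):

    # shell: python -m autotvm.record --mode pick --i collect.log
    # or programmatically:
    from tvm.autotvm.record import pick_best

    # keep only the best entry per (model, workload) and (target key, workload)
    pick_best("collect.log", "collect.best.log")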
--- a/python/tvm/autotvm/task/__init__.py
+++ b/python/tvm/autotvm/task/__init__.py
@@ -6,7 +6,7 @@ This module defines the task data structure, as well as a collection(zoo)
 of typical tasks of interest.
 """
 
-from .task import Task, create, register, template, get_config
+from .task import Task, create, register, template, get_config, args_to_workload
 from .space import ConfigSpace, ConfigEntity
 from .code_hash import attach_code_hash, attach_code_hash_to_arg
 from .dispatcher import DispatchContext, ApplyConfig, dispatcher
--- a/python/tvm/autotvm/task/task.py
+++ b/python/tvm/autotvm/task/task.py
@@ -68,6 +68,33 @@ class Task(object):
             self.flop = config.flop
         return sch, arg_bufs
 
+    def __getstate__(self):
+        # custom pickle implementation is required for
+        # some unpickable local task functions.
+        # So we only pickle the name of the function
+        # and restore the function by name when unpickling it.
+        return {
+            "name": self.name,
+            "args": self.args,
+            "kwargs": self.kwargs,
+            "config_space": self.config_space,
+            "workload": self.workload,
+            "flop": self.flop,
+            "target": self.target,
+            "target_host": self.target_host
+        }
+
+    def __setstate__(self, state):
+        self.name = state["name"]
+        self.args = state["args"]
+        self.kwargs = state["kwargs"]
+        self.config_space = state["config_space"]
+        self.func = TASK_TABLE.get(state["name"], _raise_error)
+        self.workload = state["workload"]
+        self.flop = state["flop"]
+        self.target = state["target"]
+        self.target_host = state["target_host"]
+
     def __repr__(self):
         return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % (
             self.name, self.args, self.kwargs, self.workload
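Task.__getstate__/__setstate__ implement pickle-by-name: a task's template function may be a local, unpicklable callable, so only its registered name crosses the pickle boundary and the function is re-resolved from TASK_TABLE on restore (falling back to _raise_error for unknown names). A self-contained sketch of the idiom, with Job and the registry as stand-ins for the autotvm classes:

    import pickle

    TASK_TABLE = {}

    def _raise_error(*args, **kwargs):
        raise RuntimeError("task function not registered")

    def register(name):
        def _do(func):
            TASK_TABLE[name] = func
            return func
        return _do

    @register("square")
    def square(n):
        return n * n

    class Job(object):
        def __init__(self, name):
            self.name = name
            self.func = TASK_TABLE[name]

        def __getstate__(self):
            return {"name": self.name}      # pickle the name, not the function

        def __setstate__(self, state):
            self.name = state["name"]
            self.func = TASK_TABLE.get(state["name"], _raise_error)

    restored = pickle.loads(pickle.dumps(Job("square")))
    print(restored.func(4))                 # 16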
--- a/python/tvm/autotvm/tuner/model_based_tuner.py
+++ b/python/tvm/autotvm/tuner/model_based_tuner.py
@@ -264,12 +264,23 @@ class ModelBasedTuner(Tuner):
         self.train_ct += 1
 
     def load_history(self, data_set):
-        base_model = self.cost_model.clone_new()
-        base_model.fit_log(data_set, self.plan_size)
+        # filter data, only pick the data with a same task
+        data = []
+        for inp, res in data_set:
+            if inp.task.name == self.task.name and \
+                inp.config.template_key == self.task.config_space.template_key:
+                data.append((inp, res))
+        if not data:
+            return
+
+        # fit base model
+        base_model = self.cost_model.clone_new()
+        base_model.fit_log(data, self.plan_size)
 
+        # use base model to select initial points
         if not self.trials:
             # no plan yet, use base model to select initial trials
-            maximums = self.model_optimizer.find_maximums(base_model, self.visited)
+            maximums = self.model_optimizer.find_maximums(base_model, self.plan_size, self.visited)
             self.trials = maximums
             self.trial_pt = 0
 
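load_history previously fed every record straight into the cost model, so history from a different task or template could skew the fit (and the find_maximums call was missing its plan_size argument). The filter keeps only records whose task name and template key match the current task. The same filter in isolation, over placeholder records:

    records = [
        ({"task": "conv2d", "template": "direct"},   0.9),
        ({"task": "dense",  "template": "direct"},   0.7),
        ({"task": "conv2d", "template": "winograd"}, 0.8),
    ]
    current = {"task": "conv2d", "template": "direct"}

    data = [(inp, res) for inp, res in records
            if inp["task"] == current["task"]
            and inp["template"] == current["template"]]
    print(len(data))   # 1 -- only the matching record survives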
--- a/python/tvm/autotvm/tuner/sa_model_optimizer.py
+++ b/python/tvm/autotvm/tuner/sa_model_optimizer.py
@@ -30,7 +30,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
         Print log every `verbose` iterations
     """
     def __init__(self, task, n_iter=500, temp=(1, 0), persistent=True, parallel_size=128,
-                 early_stop=30, verbose=50):
+                 early_stop=50, verbose=50):
         super(SimulatedAnnealingOptimizer, self).__init__()
 
         self.task = task
@@ -39,8 +39,8 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
         self.n_iter = n_iter
         self.temp = temp
         self.persistent = persistent
-        self.parallel_size = parallel_size
-        self.early_stop = early_stop
+        self.parallel_size = min(parallel_size, len(self.task.config_space))
+        self.early_stop = early_stop or 1e9
         self.verbose = verbose
         self.points = None
 
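Two small guards in the simulated-annealing optimizer: parallel_size is clamped so it never exceeds the number of configs that actually exist, and a falsy early_stop is promoted to an effectively infinite bound. Both guards in isolation:

    config_space_len = 10                       # pretend the space holds 10 configs
    parallel_size = min(128, config_space_len)  # clamp to the space size
    early_stop = None
    early_stop = early_stop or 1e9              # falsy -> effectively disabled
    print(parallel_size, early_stop)            # 10 1000000000.0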
--- a/python/tvm/autotvm/tuner/tuner.py
+++ b/python/tvm/autotvm/tuner/tuner.py
@@ -27,6 +27,7 @@ class Tuner(object):
         self.best_config = None
         self.best_flops = 0
         self.best_measure_pair = None
+        self.best_iter = 0
 
     def has_next(self):
         """Whether has next untried config in the space
@@ -63,7 +64,7 @@ class Tuner(object):
         """
         pass
 
-    def tune(self, n_trial, measure_option, verbose=1, callbacks=()):
+    def tune(self, n_trial, measure_option, early_stop=None, verbose=1, callbacks=()):
         """Begin tuning
 
         Parameters
@@ -73,6 +74,8 @@ class Tuner(object):
         measure_option: dict
             The options for how to measure generated code.
             You should use the return value ot autotvm.measure_option for this argument.
+        early_stop: int
+            Early stop the tuning when not finding better configs in this number of trials
         verbose: int
             0: silent mode, no output
             1: print every measurement result
@@ -84,6 +87,7 @@ class Tuner(object):
         """
         measure_batch = create_measure_batch(self.task, measure_option)
         parallel_num = getattr(measure_batch, 'parallel_num', 1)
+        early_stop = early_stop or 1e9
 
         i = 0
         while i < n_trial:
@@ -107,6 +111,7 @@ class Tuner(object):
                     self.best_flops = flops
                     self.best_config = config
                     self.best_measure_pair = (inp, res)
+                    self.best_iter = i + k
 
                 logging.info("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
                              i + k + 1, flops / 1e9, self.best_flops / 1e9,
@@ -119,6 +124,10 @@ class Tuner(object):
             for callback in callbacks:
                 callback(self, inputs, results)
 
+            if i > self.best_iter + early_stop:
+                logging.info("Early stopped. Best iter: %d.", self.best_iter)
+                break
+
         del measure_batch
 
     def reset(self):
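Tuner.tune now tracks the trial index of the best result (best_iter) and breaks out of the measurement loop once early_stop further trials pass without an improvement. The bookkeeping, reduced to a sketch over a plain list of scores:

    def tune(scores, early_stop=None):
        early_stop = early_stop or 1e9
        best_score, best_iter = 0, 0
        for i, score in enumerate(scores):
            if score > best_score:
                best_score, best_iter = score, i
            if i > best_iter + early_stop:
                print("Early stopped. Best iter: %d." % best_iter)
                break
        return best_score

    print(tune([1, 5, 2, 2, 2, 2, 2], early_stop=3))   # stops early, returns 5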
--- a/python/tvm/autotvm/tuner/xgboost_cost_model.py
+++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py
@@ -111,6 +111,9 @@ class XGBoostCostModel(CostModel):
         self.feature_extra_ct = 0
         self.pool = None
         self.base_model = None
+        self.upper_model = None
+
+        self._sample_size = 0
 
         self._reset_pool()
 
@@ -127,20 +130,25 @@ class XGBoostCostModel(CostModel):
         _extract_task = self.task
         self.pool = multiprocessing.Pool(self.num_threads)
 
+    def _base_model_discount(self):
+        return 1.0 / (2 ** (self._sample_size / 50.0))
+
     def fit(self, xs, ys, plan_size):
         tic = time.time()
         self._reset_pool()
 
         x_train = self._get_feature(xs)
         y_train = np.array(ys)
-        y_train /= np.max(y_train)
+        y_train = y_train / np.max(y_train)
 
         valid_index = y_train > 1e-6
         index = np.random.permutation(len(x_train))
         dtrain = xgb.DMatrix(x_train[index], y_train[index])
+        self._sample_size = len(x_train)
 
         if self.base_model:
-            dtrain.set_base_margin(self.base_model.predict(xs, output_margin=True))
+            dtrain.set_base_margin(self._base_model_discount() *
+                                   self.base_model.predict(xs, output_margin=True))
 
         self.bst = xgb.train(self.xgb_params, dtrain,
                              num_boost_round=8000,
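This is the transfer-learning piece: when a model is trained on top of a base model, the base model's raw predictions are installed as the booster's per-row prior via set_base_margin, scaled by 1 / 2**(n / 50) where n is the number of in-task samples. The prior halves every 50 fresh measurements, so it fades as real data accumulates:

    for n in (0, 50, 100, 200):
        print(n, 1.0 / (2 ** (n / 50.0)))
    # 0 1.0
    # 50 0.5
    # 100 0.25
    # 200 0.0625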
@@ -164,6 +172,7 @@ class XGBoostCostModel(CostModel):
         self._reset_pool()
 
         args = list(records)
+        logging.info("Load %d entries from history log file", len(args))
         if self.fea_type == 'itervar':
             feature_extract_func = _extract_itervar_feature_log
         elif self.fea_type == 'knob':
@@ -185,7 +194,7 @@ class XGBoostCostModel(CostModel):
 
         plan_size *= 2
         self.bst = xgb.train(self.xgb_params, dtrain,
-                             num_boost_round=200,
+                             num_boost_round=400,
                              callbacks=[custom_callback(
                                  stopping_rounds=100,
                                  metric='tr-a-recall@%d' % plan_size,
@@ -203,12 +212,23 @@ class XGBoostCostModel(CostModel):
         dtest = xgb.DMatrix(feas)
 
         if self.base_model:
-            dtest.set_base_margin(self.base_model.predict(xs, output_margin=True))
+            dtest.set_base_margin(self._base_model_discount() *
+                                  self.base_model.predict(xs, output_margin=True))
 
         return self.bst.predict(dtest, output_margin=output_margin)
 
     def load_basemodel(self, base_model):
         self.base_model = base_model
+        if isinstance(base_model, XGBoostCostModel):
+            # share feature cache
+            base_model.feature_cache = self.feature_cache
+
+            # close thread pool
+            if base_model.pool:
+                base_model.pool.terminate()
+                base_model.pool.join()
+                del base_model.pool
+        self.base_model.upper_model = self
 
     def clone_new(self):
         return XGBoostCostModel(self.task, self.fea_type, self.loss_type,
@@ -226,7 +246,8 @@ class XGBoostCostModel(CostModel):
         need_extract = [x for x in indexes if x not in fea_cache]
 
         if need_extract:
-            feas = self.pool.map(self.feature_extract_func, need_extract)
+            pool = self.pool if self.upper_model is None else self.upper_model.pool
+            feas = pool.map(self.feature_extract_func, need_extract)
             for i, fea in zip(need_extract, feas):
                 fea_cache[i] = fea
 
--- a/python/tvm/target.py
+++ b/python/tvm/target.py
@@ -346,6 +346,7 @@ def generic_func(fdefault):
             return func(*args, **kwargs)
     fdecorate = decorate(fdefault, dispatch_func)
     fdecorate.register = register
+    fdecorate.fdefault = fdefault
     return fdecorate
 
 
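Exposing fdecorate.fdefault gives callers a handle on the undecorated default implementation of a generic function even after target-specific overrides are registered. A sketch of the pattern (not the actual tvm.target code, which goes through decorator.decorate):

    def generic_func(fdefault):
        table = {}
        def dispatcher(key, *args):
            return table.get(key, fdefault)(*args)
        dispatcher.register = lambda key, f: table.__setitem__(key, f)
        dispatcher.fdefault = fdefault       # the attribute this commit adds
        return dispatcher

    @generic_func
    def schedule(x):
        return "default schedule for " + x

    schedule.register("cuda", lambda x: "cuda schedule for " + x)
    print(schedule("cuda", "conv2d"))        # cuda schedule for conv2d
    print(schedule.fdefault("conv2d"))       # default schedule for conv2d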
--- /dev/null
+++ b/tests/python/unittest/test_autotvm_xgboost_model.py
@@ -0,0 +1,40 @@
+import time
+
+import numpy as np
+
+import tvm
+from tvm import autotvm
+from tvm.autotvm import MeasureInput, MeasureResult
+from tvm.autotvm.tuner.xgboost_cost_model import XGBoostCostModel
+
+from test_autotvm_common import get_sample_task, get_sample_records
+
+
+def test_fit():
+    task, target = get_sample_task()
+    records = get_sample_records(n=100)
+
+    base_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
+    base_model.fit_log(records, plan_size=32)
+
+    upper_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
+    upper_model.load_basemodel(base_model)
+
+    xs = np.arange(100)
+    ys = np.arange(100)
+
+    upper_model.fit(xs, ys, plan_size=32)
+
+
+def test_tuner():
+    task, target = get_sample_task()
+    records = get_sample_records(n=100)
+
+    tuner = autotvm.tuner.XGBTuner(task)
+    tuner.load_history(records)
+
+
+if __name__ == "__main__":
+    test_fit()
+    test_tuner()
--- a/tutorials/autotvm/tune_conv2d_cuda.py
+++ b/tutorials/autotvm/tune_conv2d_cuda.py
@@ -1,5 +1,5 @@
 """
-How to get high performance convolution kernel on NVIDIA GPU by auto-tuning
+Tuning High Performance Convolution on NVIDIA GPUs
 =========================================================================
 **Author**: `Lianmin Zheng <https://https://github.com/merrymercy>`_
 
@@ -10,9 +10,11 @@ vendor provided library CuDNN in many cases.
 
 import logging
 import sys
+import numpy as np
 
 import tvm
 import topi
+from topi.testing import conv2d_nchw_python
 
 from tvm import autotvm
 
@@ -133,9 +135,10 @@ def conv2d_no_batching(N, H, W, CI, CO, KH, KW, stride, padding):
 # logging config (for printing tuning log to screen)
 logging.basicConfig(level=logging.INFO, stream=sys.stdout)
 
 # the last layer in resnet
+N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
 task = autotvm.task.create(conv2d_no_batching,
-                           args=(1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)),
+                           args=(N, H, W, CO, CI, KH, KW, strides, padding),
                            target='cuda')
 print(task.config_space)
 
@@ -146,15 +149,43 @@ measure_option = autotvm.measure_option(mode='local',
                                         parallel_num=8,
                                         timeout=20)
 
-# begin tuning, log records to file `cache.tsv`
+# begin tuning, log records to file `conv2d.log`
 tuner = autotvm.tuner.XGBTuner(task)
 tuner.tune(n_trial=20,
            measure_option=measure_option,
-           callbacks=[autotvm.callback.log_to_file('cache.tsv')])
+           callbacks=[autotvm.callback.log_to_file('conv2d.log')])
 
-# get best config from cache file
-dispatch_context = autotvm.apply_history_best("cache.tsv")
+#########################################################################
+# Finally we can inspect the best config from log file, check correctness,
+# and measure running time.
+
+# inspect the best config
+dispatch_context = autotvm.apply_history_best("conv2d.log")
 best_config = dispatch_context.query(task.target, task.workload)
 print("\nBest config:")
 print(best_config)
+
+# apply history best from log file
+with autotvm.apply_history_best('conv2d.log'):
+    with tvm.target.create("cuda"):
+        s, arg_bufs = conv2d_no_batching(N, H, W, CO, CI, KH, KW, strides, padding)
+        func = tvm.build(s, arg_bufs)
+
+# check correctness
+a_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
+w_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
+c_np = conv2d_nchw_python(a_np, w_np, strides, padding)
+
+ctx = tvm.gpu()
+a_tvm = tvm.nd.array(a_np, ctx=ctx)
+w_tvm = tvm.nd.array(w_np, ctx=ctx)
+c_tvm = tvm.nd.empty(c_np.shape, ctx=ctx)
+func(a_tvm, w_tvm, c_tvm)
+
+np.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)
+
+# Evaluate running time. Here we choose a large repeat number (200) to reduce the noise
+# and the overhead of kernel launch. You can also use nvprof to validate the result.
+evaluator = func.time_evaluator(func.entry_name, ctx, number=200)
+print('Time cost of this operator: %f' % evaluator(a_tvm, w_tvm, c_tvm).mean)
--- a/tutorials/autotvm/tune_simple_template.py
+++ b/tutorials/autotvm/tune_simple_template.py
@@ -243,7 +243,7 @@ print(task.config_space)
 #
 # We only make 10 trials in this tutorial for demonstration. In practice,
 # you can do more trials according to your time budget.
-# We will log the tuning results into a cache file. This file can be
+# We will log the tuning results into a log file. This file can be
 # used to get the best config later.
 
 # logging config (for printing tuning log to screen)
@@ -253,11 +253,11 @@ logging.basicConfig(level=logging.INFO, stream=sys.stdout)
 measure_option = autotvm.measure_option(mode='local',
                                         number=5)
 
-# begin tuning, log records to file `cache.tsv`
+# begin tuning, log records to file `matmul.log`
 tuner = autotvm.tuner.RandomTuner(task)
 tuner.tune(n_trial=10,
            measure_option=measure_option,
-           callbacks=[autotvm.callback.log_to_file('cache.tsv')])
+           callbacks=[autotvm.callback.log_to_file('matmul.log')])
 
 #########################################################################
 # Finally we apply history best from the cache file and check its correctness.
@@ -267,7 +267,7 @@ tuner.tune(n_trial=10,
 # with the same argument.
 
 # apply history best from log file
-with autotvm.apply_history_best('cache.tsv'):
+with autotvm.apply_history_best('matmul.log'):
     with tvm.target.create("llvm"):
         s, arg_bufs = matmul(N, L, M, 'float32')
         func = tvm.build(s, arg_bufs)
@@ -281,4 +281,3 @@ c_tvm = tvm.nd.empty(c_np.shape)
 func(tvm.nd.array(a_np), tvm.nd.array(b_np), c_tvm)
 
 np.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)
-