[AUTOTVM] Misc bug fix (#1467)
Parent: 9026f3fcd4
Commit: ad28f5ca3e
@@ -6,7 +6,6 @@ from collections import namedtuple

 import numpy as np

 from ... import build, nd, target as _target
-from ...contrib.util import tempdir
 from ...rpc.tracker import Tracker
 from ...rpc.server import Server

@@ -209,14 +208,12 @@ def create_measure_batch(task, options):
         kwargs['rpc_device_key'] = rpc_device_key
         kwargs['rpc_tracker_addr'] = (tracker.host, tracker.port)
         kwargs['rpc_timeout'] = timeout
-        kwargs['tmp_dir'] = tempdir()
     elif mode == 'rpc':
         fmeasure = measure_methods.measure_rpc
         kwargs['rpc_device_key'] = rpc_device_key
         kwargs['rpc_priority'] = rpc_priority
         kwargs['rpc_timeout'] = rpc_timeout
         kwargs['use_ndk'] = use_ndk
-        kwargs['tmp_dir'] = tempdir()
         assert rpc_device_key, "In rpc mode, a rpc_device_key must be provided"
     elif mode == "custom":
         assert callable(custom_measure_batch), "In custom mode, custom_measure_func " \
@@ -243,7 +240,7 @@ def create_measure_batch(task, options):
         tvm_buf = [nd.array(x) for x in ref_input]
         func(*tvm_buf)
         ref_output = [x.asnumpy() for x in tvm_buf]
-        kwargs['ref_input'], kwargs['ref_outpu'] = ref_input, ref_output
+        kwargs['ref_input'], kwargs['ref_output'] = ref_input, ref_output

     def measure_batch(measure_inputs):
         """measure the time cost for a batch of configs in real machines"""

@@ -12,7 +12,7 @@ from random import getrandbits

 import numpy as np

-from ...contrib import ndk, nvcc
+from ...contrib import ndk, nvcc, util
 from ... import rpc, ir_pass, build, build_config, nd, context, TVMError, register_func

 from ..util import get_const_tuple

@@ -113,8 +113,8 @@ def _measure_generic(fbuild, input_pack, ref_input, ref_output):
     if ref_input:
         args = [nd.array(x, ctx) for x in ref_input]
     else:
-        args = [nd.array(np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype),
-                         ctx) for x in arg_bufs]
+        args = [nd.empty(get_const_tuple(x.shape), dtype=x.dtype,
+                         ctx=ctx) for x in arg_bufs]
     costs = time_f(*args).results
     if len(costs) > 2:  # remove largest and smallest value to reduce variance
         costs = list(costs)
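Note: the replacement above allocates uninitialized device buffers with `nd.empty` instead of copying host-generated random arrays, so each measurement no longer pays for random-number generation. A standalone sketch of the trim-the-extremes step that follows (example timings are made up, not TVM source):

    costs = [3.1, 2.9, 3.0, 9.5, 2.8]   # measured times in seconds (made up)
    if len(costs) > 2:                  # remove largest and smallest value
        costs = list(costs)
        costs.remove(max(costs))
        costs.remove(min(costs))
    print(sum(costs) / len(costs))      # 3.0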
@@ -173,7 +173,6 @@ measure_rpc(input_pack,
                 rpc_tracker_addr=None,
                 rpc_priority=1,
                 rpc_timeout=60,
-                tmp_dir=None,
                 **kwargs):
    """Measure the time cost on a device by rpc

@@ -198,9 +197,6 @@ def measure_rpc(input_pack,
     rpc_timeout: int, optional
         timeout of the rpc session

-    tmp_dir: tvm.contrib.util.TempDirectory, optional
-        directory to store temp file
-
     kwargs: dict, optional
         Additional key word arguments

@@ -213,6 +209,7 @@ def measure_rpc(input_pack,
        """ Local build function."""
        func, args = _build_func(inp, build_option, kwargs)

+        tmp_dir = util.tempdir()
        if not kwargs.get('use_ndk', False):
            file_name = "tmp_func_%0x.tar" % getrandbits(64)
            path = tmp_dir.relpath(file_name)

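Note: creating `util.tempdir()` inside the local build function gives every build its own temporary workspace instead of a shared, passed-in directory. A small standalone use of the same utility (the file name is illustrative):

    from tvm.contrib import util

    tmp_dir = util.tempdir()                  # auto-removed when garbage collected
    path = tmp_dir.relpath("tmp_func_0.tar")  # absolute path inside the temp dir
    print(path)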
@@ -9,11 +9,12 @@ import multiprocessing
 import pickle
 import json
 import time
+import os
 from collections import OrderedDict

 import numpy as np

-from .. import target, build, lower
+from .. import build, lower, target as _target

 from . import task
 from .task import DispatchContext, ConfigEntity
@@ -26,6 +27,11 @@ try:  # convert unicode to str for python2
 except NameError:
     _unicode = ()

+try:
+    _long = long
+except NameError:
+    _long = int
+

 def measure_str_key(inp, include_config=True):
     """ get unique str key for MeasureInput
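Note: the added `_long` alias mirrors the existing `_unicode` shim so decoded integers behave the same on Python 2 and 3; the shim in isolation:

    try:
        _long = long          # Python 2 has a separate arbitrary-precision type
    except NameError:
        _long = int           # Python 3 folds it into int

    print(isinstance(2 ** 64, (_long, int)))   # True on both versions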
@@ -111,7 +117,7 @@ def decode(row, protocol='json'):
     if protocol == 'json':
         row = json.loads(row)
         tgt, task_name, task_args, task_kwargs, workload, config = row['i']
-        tgt = target.create(str(tgt))
+        tgt = _target.create(str(tgt))

         def clean_json_to_python(x):
             """1. convert all list in x to tuple (hashable)
@@ -121,6 +127,8 @@ def decode(row, protocol='json'):
                 return tuple([clean_json_to_python(a) for a in x])
             if isinstance(x, _unicode):
                 return str(x)
+            if isinstance(x, (_long, int)):
+                return int(x)
             return x

         tsk = task.Task(clean_json_to_python(task_name), clean_json_to_python(task_args))
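Note: the new branch guarantees decoded numbers come back as plain `int`. A standalone version of the helper shows why the conversions matter (tuples are hashable, so a workload can key a dict); this is an illustration, not the TVM source:

    def clean_json_to_python(x):
        # json.loads returns lists; tuples make the result hashable
        if isinstance(x, list):
            return tuple(clean_json_to_python(a) for a in x)
        if isinstance(x, int):
            return int(x)
        return x

    workload = clean_json_to_python(["conv2d", [1, 512], [7, 7]])
    print({workload: "best config"})   # usable as a dictionary key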
@@ -132,7 +140,7 @@ def decode(row, protocol='json'):
         return inp, result
     elif protocol == 'pickle':
         items = row.split("\t")
-        tgt = target.create(items[0])
+        tgt = _target.create(items[0])
         task_tuple = pickle.loads(base64.b64decode(items[1].encode()))
         config = pickle.loads(base64.b64decode(items[2].encode()))
         result = pickle.loads(base64.b64decode(items[3].encode()))
@@ -168,36 +176,70 @@ class ApplyHistoryBest(DispatchContext):
     ----------
     records : str or iterator of (MeasureInput, MeasureResult)
         Collection of tuning records.
-        if is str, then it should be the filename of a records log file.
+        If is str, then it should be the filename of a records log file.
         Each row of this file is an encoded record pair.
-        otherwise, it is an iterator
+        Otherwise, it is an iterator.
     default: ConfigEntity, optional
-        default config to return when no history records
+        The default config to return when no history records
     """
     def __init__(self, records, default=None):
         super(ApplyHistoryBest, self).__init__()

+        self.best_by_targetkey = {}
+        self.best_by_model = {}
+        self._default = default
+
+        self.load(records)
+
+    def load(self, records):
+        """Load records to this dispatch context
+
+        Parameters
+        ----------
+        records : str or iterator of (MeasureInput, MeasureResult)
+            Collection of tuning records.
+            If is str, then it should be the filename of a records log file.
+            Each row of this file is an encoded record pair.
+            Otherwise, it is an iterator.
+        """
         if isinstance(records, str):
             records = load_from_file(records)
+        if not records:
+            return
+
+        best_by_targetkey = self.best_by_targetkey
+        best_by_model = self.best_by_model

         counter = 0
-        best_map = {}
         for inp, res in records:
             counter += 1
             if res.error_no != 0:
                 continue
+
+            # use target keys in tvm target system as key to build best map
             for k in inp.target.keys:
                 key = (k, inp.task.workload)
-                if key not in best_map:
-                    best_map[key] = (inp, res)
+                if key not in best_by_targetkey:
+                    best_by_targetkey[key] = (inp, res)
                 else:
-                    _, other_res = best_map[key]
+                    _, other_res = best_by_targetkey[key]
                     if np.mean(other_res.costs) > np.mean(res.costs):
-                        best_map[key] = (inp, res)
-        logging.info(
-            "Finish load %d records, %d entries selected", counter, len(best_map))
-        self._best_map = best_map
-        self._default = default
+                        best_by_targetkey[key] = (inp, res)
+
+            # use model as key to build best map
+            for opt in inp.target.options:
+                if opt.startswith("-model"):
+                    model = opt[7:]
+                    key = (model, inp.task.workload)
+                    if key not in best_by_model:
+                        best_by_model[key] = (inp, res)
+                    else:
+                        _, other_res = best_by_model[key]
+                        if np.mean(other_res.costs) > np.mean(res.costs):
+                            best_by_model[key] = (inp, res)
+                    break
+
+        logging.info("Finish loading %d records", counter)

     def query(self, target, workload):
         if target is None:
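Note: the model-keyed map relies on the fixed `-model=` prefix of a target option; the slicing in isolation (the target string here is made up):

    opt = "-model=rk3399"
    if opt.startswith("-model"):
        model = opt[7:]        # len("-model=") == 7
        print(model)           # rk3399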
@@ -205,29 +247,25 @@ class ApplyHistoryBest(DispatchContext):
                 "Hint: If your target is llvm, use `with tvm.target.create('llvm'):`"
                 " above the dispatcher call. So does other target. ")

+        # first try matching by model
+        for opt in target.options:
+            if opt.startswith("-model"):
+                model = opt[7:]
+                key = (model, workload)
+                if key in self.best_by_model:
+                    return self.best_by_model[key][0].config
+
+        # then try matching by target key
         for k in target.keys:
             key = (k, workload)
-            if key in self._best_map:
-                return self._best_map[key][0].config
+            if key in self.best_by_targetkey:
+                return self.best_by_targetkey[key][0].config

         if self._default:
             return self._default
         raise RuntimeError(
             "Cannot find config for target=%s, workload=%s" % (target, workload))

-    def dump_best(self, out_file):
-        """Dump the best records for each workload to a file
-
-        Parameters
-        ----------
-        out_file: str
-            filename
-        """
-        fout = open(out_file, 'a')
-        for val in self._best_map.values():
-            inp, res = val
-            fout.write(encode(inp, res) + '\n')
-

 def split_workload(in_file, clean=True):
     """Split a log file into separate files, each of which contains only a single workload
@@ -243,7 +281,7 @@ def split_workload(in_file, clean=True):
     tic = time.time()
     lines = list(open(in_file).readlines())

-    logging.info("start convert...")
+    logging.info("start converting...")
     pool = multiprocessing.Pool()
     lines = pool.map(decode, lines)
     logging.info("map done %.2f", time.time() - tic)
@@ -279,23 +317,69 @@ def split_workload(in_file, clean=True):
         for inp, res in v:
             fout.write(encode(inp, res) + '\n')

+
+def pick_best(in_file, out_file):
+    """
+    Pick best entries from a file and store it to another file.
+    This distill the useful log entries from a large log file.
+
+    Parameters
+    ----------
+    in_file: str
+        The filename of input
+    out_file:
+        The filename of output
+    """
+    best_context = ApplyHistoryBest(load_from_file(in_file))
+    best_set = set()
+
+    for v in best_context.best_by_model.values():
+        best_set.add(measure_str_key(v[0]))
+
+    for v in best_context.best_by_targetkey.values():
+        best_set.add(measure_str_key(v[0]))
+
+    logging.info("Extract %d best records from the log file", len(best_set))
+
+    fout = open(out_file, 'w')
+    for inp, res in load_from_file(in_file):
+        if measure_str_key(inp) in best_set:
+            fout.write(encode(inp, res) + "\n")
+
+
+def load_op_param(rootpath=os.path.join(os.path.expanduser('~'), ".tvm", "op_params")):
+    """Load pre-tuned parameters of operators.
+    This function will load all "*.log" file under root path and select best configs.
+
+    Parameters
+    ----------
+    rootpath: str
+        The root path of stored parameters
+    """
+    best_context = ApplyHistoryBest([])
+    for dirpath, _, filenames in os.walk(rootpath):
+        for filename in filenames:
+            if os.path.splitext(filename)[1] == '.log':
+                best_context.load(os.path.join(dirpath, filename))
+
+    assert not DispatchContext.current, "Cannot load pre-tuned parameters inside a dispatch context"
+    DispatchContext.current = best_context
+

 """
 Usage:
 This record executable module has three modes.

 * Print log file in readable format
-e.g. python -m autotvm.record --mode read --i collect_conv.tsv --begin 0 --end 5 --ir --code
+e.g. python -m autotvm.record --mode read --i collect_conv.log --begin 0 --end 5 --ir --code

 * Extract history best from a large log file
-e.g. python -m autotvm.record --mode best --i collect.tsv
+e.g. python -m autotvm.record --mode pick --i collect.log

 * Split a log file into separate files, each of which contains only a single wkl
-e.g. python -m autotvm.record --mode split --i collect.tsv
+e.g. python -m autotvm.record --mode split --i collect.log
 """
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument("--mode", choices=['read', 'best', 'split'], default='read')
+    parser.add_argument("--mode", choices=['read', 'pick', 'split'], default='read')
     parser.add_argument("--i", type=str, help="input file")
     parser.add_argument("--o", type=str, default=None, help='output file')
     parser.add_argument("--begin", type=int, default=0)
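Note: the new `pick` mode is also callable from Python; a short usage sketch (file names are placeholders):

    from tvm.autotvm import record

    # keep only the best record per workload from a large tuning log
    record.pick_best("collect.log", "collect.best.log")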
@@ -306,10 +390,9 @@ if __name__ == '__main__':
     args = parser.parse_args()
     logging.basicConfig(level=logging.INFO)

-    if args.mode == 'best':
-        args.o = args.o or args.i + ".best"
-        hist_best = ApplyHistoryBest(load_from_file(args.i))
-        hist_best.dump_best(args.o)
+    if args.mode == 'pick':
+        args.o = args.o or args.i + ".best.log"
+        pick_best(args.i, args.o)
     elif args.mode == 'read':
         for i, (inp, result) in enumerate(load_from_file(args.i)):
             if args.begin <= i < args.end:

@@ -6,7 +6,7 @@ This module defines the task data structure, as well as a collection(zoo)
 of typical tasks of interest.
 """

-from .task import Task, create, register, template, get_config
+from .task import Task, create, register, template, get_config, args_to_workload
 from .space import ConfigSpace, ConfigEntity
 from .code_hash import attach_code_hash, attach_code_hash_to_arg
 from .dispatcher import DispatchContext, ApplyConfig, dispatcher

@@ -68,6 +68,33 @@ class Task(object):
         self.flop = config.flop
         return sch, arg_bufs

+    def __getstate__(self):
+        # custom pickle implementation is required for
+        # some unpickable local task functions.
+        # So we only pickle the name of the function
+        # and restore the function by name when unpickling it.
+        return {
+            "name": self.name,
+            "args": self.args,
+            "kwargs": self.kwargs,
+            "config_space": self.config_space,
+            "workload": self.workload,
+            "flop": self.flop,
+            "target": self.target,
+            "target_host": self.target_host
+        }
+
+    def __setstate__(self, state):
+        self.name = state["name"]
+        self.args = state["args"]
+        self.kwargs = state["kwargs"]
+        self.config_space = state["config_space"]
+        self.func = TASK_TABLE.get(state["name"], _raise_error)
+        self.workload = state["workload"]
+        self.flop = state["flop"]
+        self.target = state["target"]
+        self.target_host = state["target_host"]
+
     def __repr__(self):
         return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % (
             self.name, self.args, self.kwargs, self.workload

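Note: `__getstate__`/`__setstate__` drop the possibly-unpicklable task function and restore it by name from `TASK_TABLE`. A self-contained sketch of the same pattern (registry and class names are invented for illustration):

    import pickle

    REGISTRY = {"double": lambda x: 2 * x}    # stands in for TASK_TABLE

    class NamedTask(object):
        def __init__(self, name):
            self.name = name
            self.func = REGISTRY[name]

        def __getstate__(self):
            return {"name": self.name}        # omit the unpicklable function

        def __setstate__(self, state):
            self.name = state["name"]
            self.func = REGISTRY[self.name]   # restore by name

    restored = pickle.loads(pickle.dumps(NamedTask("double")))
    print(restored.func(21))                  # 42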
@@ -264,12 +264,23 @@ class ModelBasedTuner(Tuner):
         self.train_ct += 1

     def load_history(self, data_set):
-        base_model = self.cost_model.clone_new()
-        base_model.fit_log(data_set, self.plan_size)
+        # filter data, only pick the data with a same task
+        data = []
+        for inp, res in data_set:
+            if inp.task.name == self.task.name and \
+                    inp.config.template_key == self.task.config_space.template_key:
+                data.append((inp, res))
+        if not data:
+            return
+
+        # fit base model
+        base_model = self.cost_model.clone_new()
+        base_model.fit_log(data, self.plan_size)

         # use base model to select initial points
         if not self.trials:
-            maximums = self.model_optimizer.find_maximums(base_model, self.visited)
+            # no plan yet, use base model to select initial trials
+            maximums = self.model_optimizer.find_maximums(base_model, self.plan_size, self.visited)
             self.trials = maximums
             self.trial_pt = 0

@@ -30,7 +30,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
         Print log every `verbose` iterations
     """
     def __init__(self, task, n_iter=500, temp=(1, 0), persistent=True, parallel_size=128,
-                 early_stop=30, verbose=50):
+                 early_stop=50, verbose=50):
         super(SimulatedAnnealingOptimizer, self).__init__()

         self.task = task
|
|||
self.n_iter = n_iter
|
||||
self.temp = temp
|
||||
self.persistent = persistent
|
||||
self.parallel_size = parallel_size
|
||||
self.early_stop = early_stop
|
||||
self.parallel_size = min(parallel_size, len(self.task.config_space))
|
||||
self.early_stop = early_stop or 1e9
|
||||
self.verbose = verbose
|
||||
self.points = None
|
||||
|
||||
|
|
|
@@ -27,6 +27,7 @@ class Tuner(object):
         self.best_config = None
         self.best_flops = 0
         self.best_measure_pair = None
+        self.best_iter = 0

     def has_next(self):
         """Whether has next untried config in the space
|
|||
"""
|
||||
pass
|
||||
|
||||
def tune(self, n_trial, measure_option, verbose=1, callbacks=()):
|
||||
def tune(self, n_trial, measure_option, early_stop=None, verbose=1, callbacks=()):
|
||||
"""Begin tuning
|
||||
|
||||
Parameters
|
||||
|
@ -73,6 +74,8 @@ class Tuner(object):
|
|||
measure_option: dict
|
||||
The options for how to measure generated code.
|
||||
You should use the return value ot autotvm.measure_option for this argument.
|
||||
early_stop: int
|
||||
Early stop the tuning when not finding better configs in this number of trials
|
||||
verbose: int
|
||||
0: silent mode, no output
|
||||
1: print every measurement result
|
||||
|
@@ -84,6 +87,7 @@ class Tuner(object):
         """
         measure_batch = create_measure_batch(self.task, measure_option)
         parallel_num = getattr(measure_batch, 'parallel_num', 1)
+        early_stop = early_stop or 1e9

         i = 0
         while i < n_trial:
@@ -107,6 +111,7 @@ class Tuner(object):
                     self.best_flops = flops
                     self.best_config = config
                     self.best_measure_pair = (inp, res)
+                    self.best_iter = i + k

                 logging.info("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
                              i + k + 1, flops / 1e9, self.best_flops / 1e9,
@@ -119,6 +124,10 @@ class Tuner(object):
             for callback in callbacks:
                 callback(self, inputs, results)

+            if i > self.best_iter + early_stop:
+                logging.info("Early stopped. Best iter: %d.", self.best_iter)
+                break
+
         del measure_batch

     def reset(self):

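Note: the loop now tracks `best_iter` and stops once `early_stop` trials pass without improvement; the rule reduced to a toy loop (random scores stand in for measured GFLOPS):

    import random

    random.seed(0)
    best, best_iter, early_stop = 0.0, 0, 50
    for i in range(1000):
        score = random.random()            # placeholder for a measurement
        if score > best:
            best, best_iter = score, i
        if i > best_iter + early_stop:     # no new best for `early_stop` trials
            print("early stopped at %d, best at %d" % (i, best_iter))
            break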
@@ -111,6 +111,9 @@ class XGBoostCostModel(CostModel):
         self.feature_extra_ct = 0
         self.pool = None
         self.base_model = None
+        self.upper_model = None
+
+        self._sample_size = 0

         self._reset_pool()

|
|||
_extract_task = self.task
|
||||
self.pool = multiprocessing.Pool(self.num_threads)
|
||||
|
||||
def _base_model_discount(self):
|
||||
return 1.0 / (2 ** (self._sample_size / 50.0))
|
||||
|
||||
def fit(self, xs, ys, plan_size):
|
||||
tic = time.time()
|
||||
self._reset_pool()
|
||||
|
||||
x_train = self._get_feature(xs)
|
||||
y_train = np.array(ys)
|
||||
y_train /= np.max(y_train)
|
||||
y_train = y_train / np.max(y_train)
|
||||
|
||||
valid_index = y_train > 1e-6
|
||||
index = np.random.permutation(len(x_train))
|
||||
dtrain = xgb.DMatrix(x_train[index], y_train[index])
|
||||
self._sample_size = len(x_train)
|
||||
|
||||
if self.base_model:
|
||||
dtrain.set_base_margin(self.base_model.predict(xs, output_margin=True))
|
||||
dtrain.set_base_margin(self._base_model_discount() *
|
||||
self.base_model.predict(xs, output_margin=True))
|
||||
|
||||
self.bst = xgb.train(self.xgb_params, dtrain,
|
||||
num_boost_round=8000,
|
||||
|
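Note: `_base_model_discount` halves the transferred model's weight every 50 newly measured samples, so fresh data gradually overrides the history; the decay in numbers:

    for n in (0, 50, 100, 200):
        print(n, 1.0 / (2 ** (n / 50.0)))
    # 0 1.0
    # 50 0.5
    # 100 0.25
    # 200 0.0625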
@@ -164,6 +172,7 @@ class XGBoostCostModel(CostModel):
         self._reset_pool()

         args = list(records)
+        logging.info("Load %d entries from history log file", len(args))
         if self.fea_type == 'itervar':
             feature_extract_func = _extract_itervar_feature_log
         elif self.fea_type == 'knob':
@@ -185,7 +194,7 @@ class XGBoostCostModel(CostModel):

         plan_size *= 2
         self.bst = xgb.train(self.xgb_params, dtrain,
-                             num_boost_round=200,
+                             num_boost_round=400,
                              callbacks=[custom_callback(
                                  stopping_rounds=100,
                                  metric='tr-a-recall@%d' % plan_size,
@@ -203,12 +212,23 @@ class XGBoostCostModel(CostModel):
         dtest = xgb.DMatrix(feas)

         if self.base_model:
-            dtest.set_base_margin(self.base_model.predict(xs, output_margin=True))
+            dtest.set_base_margin(self._base_model_discount() *
+                                  self.base_model.predict(xs, output_margin=True))

         return self.bst.predict(dtest, output_margin=output_margin)

     def load_basemodel(self, base_model):
         self.base_model = base_model
+        if isinstance(base_model, XGBoostCostModel):
+            # share feature cache
+            base_model.feature_cache = self.feature_cache
+
+            # close thread pool
+            if base_model.pool:
+                base_model.pool.terminate()
+                base_model.pool.join()
+                del base_model.pool
+            self.base_model.upper_model = self

     def clone_new(self):
         return XGBoostCostModel(self.task, self.fea_type, self.loss_type,
@@ -226,7 +246,8 @@ class XGBoostCostModel(CostModel):
         need_extract = [x for x in indexes if x not in fea_cache]

         if need_extract:
-            feas = self.pool.map(self.feature_extract_func, need_extract)
+            pool = self.pool if self.upper_model is None else self.upper_model.pool
+            feas = pool.map(self.feature_extract_func, need_extract)
             for i, fea in zip(need_extract, feas):
                 fea_cache[i] = fea

@@ -346,6 +346,7 @@ def generic_func(fdefault):
             return func(*args, **kwargs)
     fdecorate = decorate(fdefault, dispatch_func)
     fdecorate.register = register
+    fdecorate.fdefault = fdefault
     return fdecorate

@@ -0,0 +1,40 @@
+import time
+
+import numpy as np
+
+import tvm
+from tvm import autotvm
+from tvm.autotvm import MeasureInput, MeasureResult
+from tvm.autotvm.tuner.xgboost_cost_model import XGBoostCostModel
+
+from test_autotvm_common import get_sample_task, get_sample_records
+
+
+def test_fit():
+    task, target = get_sample_task()
+    records = get_sample_records(n=100)
+
+    base_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
+    base_model.fit_log(records, plan_size=32)
+
+    upper_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
+    upper_model.load_basemodel(base_model)
+
+    xs = np.arange(100)
+    ys = np.arange(100)
+
+    upper_model.fit(xs, ys, plan_size=32)
+
+
+def test_tuner():
+    task, target = get_sample_task()
+    records = get_sample_records(n=100)
+
+    tuner = autotvm.tuner.XGBTuner(task)
+    tuner.load_history(records)
+
+
+if __name__ == "__main__":
+    test_fit()
+    test_tuner()

@@ -1,5 +1,5 @@
 """
-How to get high performance convolution kernel on NVIDIA GPU by auto-tuning
+Tuning High Performance Convolution on NVIDIA GPUs
 =========================================================================
 **Author**: `Lianmin Zheng <https://https://github.com/merrymercy>`_

|
@ -10,9 +10,11 @@ vendor provided library CuDNN in many cases.
|
|||
|
||||
import logging
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
import tvm
|
||||
import topi
|
||||
from topi.testing import conv2d_nchw_python
|
||||
|
||||
from tvm import autotvm
|
||||
|
||||
|
@@ -133,9 +135,10 @@ def conv2d_no_batching(N, H, W, CI, CO, KH, KW, stride, padding):
 # logging config (for printing tuning log to screen)
 logging.basicConfig(level=logging.INFO, stream=sys.stdout)

 # the last layer in resnet
+N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
 task = autotvm.task.create(conv2d_no_batching,
-                           args=(1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)),
+                           args=(N, H, W, CO, CI, KH, KW, strides, padding),
                            target='cuda')
 print(task.config_space)

@@ -146,15 +149,43 @@ measure_option = autotvm.measure_option(mode='local',
                                         parallel_num=8,
                                         timeout=20)

-# begin tuning, log records to file `cache.tsv`
+# begin tuning, log records to file `conv2d.log`
 tuner = autotvm.tuner.XGBTuner(task)
 tuner.tune(n_trial=20,
            measure_option=measure_option,
-           callbacks=[autotvm.callback.log_to_file('cache.tsv')])
+           callbacks=[autotvm.callback.log_to_file('conv2d.log')])

-# get best config from cache file
-dispatch_context = autotvm.apply_history_best("cache.tsv")
+#########################################################################
+# Finally we can inspect the best config from log file, check correctness,
+# and measure running time.
+
+# inspect the best config
+dispatch_context = autotvm.apply_history_best("conv2d.log")
 best_config = dispatch_context.query(task.target, task.workload)
 print("\nBest config:")
 print(best_config)
+
+# apply history best from log file
+with autotvm.apply_history_best('conv2d.log'):
+    with tvm.target.create("cuda"):
+        s, arg_bufs = conv2d_no_batching(N, H, W, CO, CI, KH, KW, strides, padding)
+        func = tvm.build(s, arg_bufs)
+
+# check correctness
+a_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
+w_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
+c_np = conv2d_nchw_python(a_np, w_np, strides, padding)
+
+ctx = tvm.gpu()
+a_tvm = tvm.nd.array(a_np, ctx=ctx)
+w_tvm = tvm.nd.array(w_np, ctx=ctx)
+c_tvm = tvm.nd.empty(c_np.shape, ctx=ctx)
+func(a_tvm, w_tvm, c_tvm)
+
+np.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)
+
+# Evaluate running time. Here we choose a large repeat number (200) to reduce the noise
+# and the overhead of kernel launch. You can also use nvprof to validate the result.
+evaluator = func.time_evaluator(func.entry_name, ctx, number=200)
+print('Time cost of this operator: %f' % evaluator(a_tvm, w_tvm, c_tvm).mean)

@@ -243,7 +243,7 @@ print(task.config_space)
 #
 # We only make 10 trials in this tutorial for demonstration. In practice,
 # you can do more trials according to your time budget.
-# We will log the tuning results into a cache file. This file can be
+# We will log the tuning results into a log file. This file can be
 # used to get the best config later.

 # logging config (for printing tuning log to screen)
|
|||
measure_option = autotvm.measure_option(mode='local',
|
||||
number=5)
|
||||
|
||||
# begin tuning, log records to file `cache.tsv`
|
||||
# begin tuning, log records to file `matmul.log`
|
||||
tuner = autotvm.tuner.RandomTuner(task)
|
||||
tuner.tune(n_trial=10,
|
||||
measure_option=measure_option,
|
||||
callbacks=[autotvm.callback.log_to_file('cache.tsv')])
|
||||
callbacks=[autotvm.callback.log_to_file('matmul.log')])
|
||||
|
||||
#########################################################################
|
||||
# Finally we apply history best from the cache file and check its correctness.
|
||||
|
@@ -267,7 +267,7 @@ tuner.tune(n_trial=10,
 # with the same argument.

 # apply history best from log file
-with autotvm.apply_history_best('cache.tsv'):
+with autotvm.apply_history_best('matmul.log'):
     with tvm.target.create("llvm"):
         s, arg_bufs = matmul(N, L, M, 'float32')
         func = tvm.build(s, arg_bufs)
@@ -281,4 +281,3 @@ c_tvm = tvm.nd.empty(c_np.shape)
 func(tvm.nd.array(a_np), tvm.nd.array(b_np), c_tvm)

 np.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)
-