[AUTOTVM] Improve tutorial and logging (#1544)
This commit is contained in:
Родитель
33606741bd
Коммит
136061dcdc
|
@ -1,7 +1,7 @@
|
|||
"""Distributed executor infrastructure to scale up the tuning"""
|
||||
|
||||
from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option
|
||||
from .measure_methods import request_remote, create_measure_batch, use_rpc
|
||||
from .measure_methods import request_remote, check_remote, create_measure_batch, use_rpc
|
||||
|
||||
from .local_executor import LocalExecutor
|
||||
from .executor import Future, Executor
|
||||
|
|
|
@ -9,6 +9,7 @@ import logging
|
|||
import os
|
||||
import time
|
||||
from random import getrandbits
|
||||
import threading
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -23,6 +24,7 @@ from ..task.space import InstantiationError
|
|||
from .measure import MeasureResult, MeasureErrorNo
|
||||
from .local_executor import LocalExecutor
|
||||
|
||||
logger = logging.getLogger('autotvm')
|
||||
|
||||
class HashMismatchError(ValueError):
|
||||
"""Raised when the code hash of a submitted config doesn't match that on the
|
||||
|
@ -42,9 +44,9 @@ def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
|
|||
If is none, will use environment variable "TVM_TRACKER_HOST"
|
||||
and "TVM_TRACKER_PORT"
|
||||
priority: int, optional
|
||||
priority of this request, larger is more prior
|
||||
The priority of this request; a larger value means higher priority
|
||||
timeout: float, optional
|
||||
timeout of this session (units: seconds)
|
||||
The timeout of this session (units: seconds)
|
||||
|
||||
Returns
|
||||
------
|
||||
|
@ -63,6 +65,33 @@ def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
|
|||
session_timeout=timeout)
|
||||
return remote
|
||||
|
||||
def check_remote(target, device_key, tracker_addr=None, priority=2, timeout=10):
    """
    Check the availability of a remote device

    Parameters
    ----------
    target: Target
        The wanted compilation target
    device_key: string
        device key of registered device in tracker
    tracker_addr: Tuple(string, int), optional
        The address of rpc tracker in (host, port) format.
        If is none, will use environment variable "TVM_TRACKER_HOST"
        and "TVM_TRACKER_PORT"
    priority: int, optional
        The priority of this request, larger is more prior
    timeout: float, optional
        The timeout of this check (units: seconds).

    Returns
    -------
    available: bool
        True if a remote session was requested and a context for `target`
        was created within `timeout` seconds. False if the check timed out
        or the request failed with an exception. This function does not
        raise on timeout; it is up to the caller to react to a False result.
    """
    # Exceptions raised in the worker thread are collected here. Without this,
    # a failed request would simply end the thread and be indistinguishable
    # from a successful check.
    errors = []

    def _check():
        try:
            remote = request_remote(device_key, tracker_addr, priority)
            remote.context(str(target))
        except Exception as exc:  # pylint: disable=broad-except
            errors.append(exc)

    t = threading.Thread(target=_check)
    # Daemonize so that a request which never returns cannot keep the
    # interpreter alive after the caller has given up.
    t.daemon = True
    t.start()
    t.join(timeout)

    if t.is_alive():
        # timed out: the request is still pending
        return False
    if errors:
        logger.warning("Failed to request remote device '%s': %s", device_key, errors[0])
        return False
    return True
|
||||
|
||||
def create_measure_batch(task, option):
|
||||
"""Get a standard measure_batch function.
|
||||
|
@ -115,6 +144,17 @@ def create_measure_batch(task, option):
|
|||
build_func = default_build_func
|
||||
build_kwargs['use_ndk'] = True
|
||||
|
||||
# check the availability of remote devices
|
||||
if hasattr(measure_func, 'rpc_info'):
|
||||
rpc_info = measure_func.rpc_info
|
||||
if check_remote(task.target, rpc_info['key'], (rpc_info['host'], rpc_info['port'])):
|
||||
logger.info("Get devices for measurement successfully!")
|
||||
else:
|
||||
raise RuntimeError("Cannot get remote devices from the tracker. "
|
||||
"Please check the status of tracker by "
|
||||
"'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
|
||||
"and make sure you have free devices on the queue status.")
|
||||
|
||||
# add device info of cuda and opencl target
|
||||
if ('cuda' in task.target.keys or 'opencl' in task.target.keys) \
|
||||
and hasattr(measure_func, 'rpc_info'):
|
||||
|
@ -313,7 +353,7 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
|
|||
continue
|
||||
except InstantiationError as e:
|
||||
tstamp = time.time()
|
||||
res_pack.append(MeasureResult((e,),
|
||||
res_pack.append(MeasureResult((InstantiationError(str(e)),),
|
||||
MeasureErrorNo.INSTANTIATION_ERROR,
|
||||
tstamp - tic, tstamp))
|
||||
continue
|
||||
|
@ -346,7 +386,7 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
|
|||
if ref_output:
|
||||
for expected, real in zip(ref_output, args):
|
||||
if not np.allclose(expected, real.asnumpy(), rtol=1e-4):
|
||||
logging.warning("Wrong Answer!")
|
||||
logger.warning("Wrong Answer!")
|
||||
errno = MeasureErrorNo.WRONG_ANSWER
|
||||
except TVMError as exc:
|
||||
msg = str(exc)
|
||||
|
|
|
@ -18,6 +18,7 @@ from .task import ConfigEntity, ApplyHistoryBest
|
|||
from .measure import MeasureInput, MeasureResult
|
||||
|
||||
AUTOTVM_LOG_VERSION = 0.1
|
||||
logger = logging.getLogger('autotvm')
|
||||
|
||||
try: # convert unicode to str for python2
|
||||
_unicode = unicode
|
||||
|
@ -181,10 +182,10 @@ def split_workload(in_file, clean=True):
|
|||
tic = time.time()
|
||||
lines = list(open(in_file).readlines())
|
||||
|
||||
logging.info("start converting...")
|
||||
logger.info("start converting...")
|
||||
pool = multiprocessing.Pool()
|
||||
lines = pool.map(decode, lines)
|
||||
logging.info("map done %.2f", time.time() - tic)
|
||||
logger.info("map done %.2f", time.time() - tic)
|
||||
|
||||
wkl_dict = OrderedDict()
|
||||
for inp, res in lines:
|
||||
|
@ -206,13 +207,13 @@ def split_workload(in_file, clean=True):
|
|||
cleaned.append([inp, res])
|
||||
|
||||
# write to file
|
||||
logging.info("Key: %s\tValid: %d\tDup: %d\t", k, len(cleaned), len(v) - len(cleaned))
|
||||
logger.info("Key: %s\tValid: %d\tDup: %d\t", k, len(cleaned), len(v) - len(cleaned))
|
||||
with open(args.i + ".%03d.wkl" % i, 'w') as fout:
|
||||
for inp, res in cleaned:
|
||||
fout.write(encode(inp, res) + '\n')
|
||||
else:
|
||||
for i, (k, v) in enumerate(wkl_dict.items()):
|
||||
logging.info("Key: %s\tNum: %d", k, len(v))
|
||||
logger.info("Key: %s\tNum: %d", k, len(v))
|
||||
with open(args.i + ".%03d.wkl" % i, 'w') as fout:
|
||||
for inp, res in v:
|
||||
fout.write(encode(inp, res) + '\n')
|
||||
|
@ -238,7 +239,7 @@ def pick_best(in_file, out_file):
|
|||
for v in best_context.best_by_targetkey.values():
|
||||
best_set.add(measure_str_key(v[0]))
|
||||
|
||||
logging.info("Extract %d best records from the %s", len(best_set), in_file)
|
||||
logger.info("Extract %d best records from the %s", len(best_set), in_file)
|
||||
fout = open(out_file, 'w') if isinstance(out_file, str) else out_file
|
||||
|
||||
for inp, res in load_from_file(in_file):
|
||||
|
@ -270,7 +271,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument("--code", action='store_true')
|
||||
|
||||
args = parser.parse_args()
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger.basicConfig(level=logger.INFO)
|
||||
|
||||
if args.mode == 'pick':
|
||||
args.o = args.o or args.i + ".best.log"
|
||||
|
|
|
@ -10,6 +10,8 @@ of the DispatchContext base class.
|
|||
- During search, we can use it to pass the current proposal from tuner.
|
||||
- During evaluation, we can use it to pick the best policy.
|
||||
"""
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
from __future__ import absolute_import as _abs
|
||||
|
||||
import logging
|
||||
|
@ -19,6 +21,8 @@ import numpy as np
|
|||
|
||||
from tvm import target as _target
|
||||
|
||||
logger = logging.getLogger('autotvm')
|
||||
|
||||
class DispatchContext(object):
|
||||
"""
|
||||
Base class of dispatch context.
|
||||
|
@ -216,7 +220,7 @@ class ApplyHistoryBest(DispatchContext):
|
|||
best_by_model[key] = (inp, res)
|
||||
break
|
||||
|
||||
logging.debug("Finish loading %d records", counter)
|
||||
logger.debug("Finish loading %d records", counter)
|
||||
|
||||
def query(self, target, workload):
|
||||
if target is None:
|
||||
|
|
|
@ -4,6 +4,7 @@ To get the best performance, we typically need auto-tuning for the specific devi
|
|||
TVM releases pre-tuned parameters in TopHub for some common networks and hardware targets.
|
||||
TVM will download these parameters for you when you create the target for the first time.
|
||||
"""
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
@ -16,6 +17,7 @@ from ..contrib.download import download
|
|||
|
||||
AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub")
|
||||
|
||||
logger = logging.getLogger('autotvm')
|
||||
|
||||
def _alias(name):
|
||||
"""convert alias for some packages"""
|
||||
|
@ -79,7 +81,7 @@ def download_package(backend):
|
|||
os.mkdir(path)
|
||||
|
||||
backend = _alias(backend)
|
||||
logging.info("Download pre-tuned parameters for %s", backend)
|
||||
logger.info("Download pre-tuned parameters for %s", backend)
|
||||
download("https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/%s.log" % backend,
|
||||
os.path.join(rootpath, backend + ".log"), True, verbose=0)
|
||||
|
||||
|
@ -110,7 +112,7 @@ def list_packages():
|
|||
"""
|
||||
path = tempdir()
|
||||
filename = path.relpath("info.json")
|
||||
logging.info("Download meta info for pre-tuned parameters")
|
||||
logger.info("Download meta info for pre-tuned parameters")
|
||||
download("https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/info.json",
|
||||
filename, True, verbose=0)
|
||||
|
||||
|
|
|
@ -2,11 +2,13 @@
|
|||
"""Namespace of callback utilities of AutoTVM"""
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .. import record
|
||||
|
||||
logger = logging.getLogger('autotvm')
|
||||
|
||||
def log_to_file(file_out, protocol='json'):
|
||||
"""Log the tuning records into file.
|
||||
|
@ -90,7 +92,7 @@ def progress_bar(total, prefix=''):
|
|||
prefix: str
|
||||
The prefix of output message
|
||||
"""
|
||||
class _Context:
|
||||
class _Context(object):
|
||||
"""Context to store local variables"""
|
||||
def __init__(self):
|
||||
self.best_flops = 0
|
||||
|
@ -112,13 +114,14 @@ def progress_bar(total, prefix=''):
|
|||
if res.error_no == 0:
|
||||
flops = inp.task.flop / np.mean(res.costs)
|
||||
|
||||
ctx.cur_flops = flops
|
||||
ctx.best_flops = tuner.best_flops
|
||||
if logger.level < logging.DEBUG: # only print progress bar in non-debug mode
|
||||
ctx.cur_flops = flops
|
||||
ctx.best_flops = tuner.best_flops
|
||||
|
||||
sys.stdout.write('\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) '
|
||||
'| %.2f s' %
|
||||
(prefix, ctx.cur_flops/1e9, ctx.best_flops/1e9, ctx.ct, ctx.total,
|
||||
time.time() - tic))
|
||||
sys.stdout.flush()
|
||||
sys.stdout.write('%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) '
|
||||
'| %.2f s\r' %
|
||||
(prefix, ctx.cur_flops/1e9, ctx.best_flops/1e9, ctx.ct, ctx.total,
|
||||
time.time() - tic))
|
||||
sys.stdout.flush()
|
||||
|
||||
return _callback
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# pylint: disable=consider-using-enumerate
|
||||
# pylint: disable=consider-using-enumerate, invalid-name
|
||||
"""
|
||||
Cost model optimizer based on simulated annealing
|
||||
"""
|
||||
|
@ -12,6 +12,8 @@ import numpy as np
|
|||
from ..util import sample_ints
|
||||
from .model_based_tuner import ModelOptimizer, knob2point, point2knob
|
||||
|
||||
logger = logging.getLogger('autotvm')
|
||||
|
||||
class SimulatedAnnealingOptimizer(ModelOptimizer):
|
||||
"""parallel simulated annealing optimization algorithm
|
||||
|
||||
|
@ -103,16 +105,16 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
|
|||
|
||||
if log_interval and k % log_interval == 0:
|
||||
t_str = "%.2f" % t
|
||||
logging.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\ttemp: %s\t"
|
||||
"elapsed: %.2f",
|
||||
k, k_last_modify, heap_items[0][0],
|
||||
np.max([v for v, _ in heap_items]), t_str,
|
||||
time.time() - tic)
|
||||
logger.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\ttemp: %s\t"
|
||||
"elapsed: %.2f",
|
||||
k, k_last_modify, heap_items[0][0],
|
||||
np.max([v for v, _ in heap_items]), t_str,
|
||||
time.time() - tic)
|
||||
|
||||
heap_items.sort(key=lambda item: -item[0])
|
||||
logging.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\telapsed: %.2f",
|
||||
k, k_last_modify, heap_items[-1][0], heap_items[0][0], time.time() - tic)
|
||||
logging.debug("SA Maximums: %s", heap_items)
|
||||
logger.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\telapsed: %.2f",
|
||||
k, k_last_modify, heap_items[-1][0], heap_items[0][0], time.time() - tic)
|
||||
logger.debug("SA Maximums: %s", heap_items)
|
||||
|
||||
if self.persistent:
|
||||
self.points = points
|
||||
|
|
|
@ -4,11 +4,12 @@ import logging
|
|||
|
||||
import numpy as np
|
||||
|
||||
from ..measure import MeasureInput
|
||||
from ..measure import create_measure_batch
|
||||
from ..measure import MeasureInput, create_measure_batch
|
||||
|
||||
from ..env import GLOBAL_SCOPE
|
||||
|
||||
logger = logging.getLogger('autotvm')
|
||||
|
||||
class Tuner(object):
|
||||
"""Base class for tuners
|
||||
|
||||
|
@ -86,9 +87,10 @@ class Tuner(object):
|
|||
measure_batch = create_measure_batch(self.task, measure_option)
|
||||
parallel_num = getattr(measure_batch, 'parallel_num', 1)
|
||||
early_stopping = early_stopping or 1e9
|
||||
old_level = logger.level
|
||||
|
||||
GLOBAL_SCOPE.in_tuning = True
|
||||
i = 0
|
||||
i = error_ct = 0
|
||||
while i < n_trial:
|
||||
if not self.has_next():
|
||||
break
|
||||
|
@ -103,17 +105,20 @@ class Tuner(object):
|
|||
config = inp.config
|
||||
if res.error_no == 0:
|
||||
flops = inp.task.flop / np.mean(res.costs)
|
||||
error_ct = 0
|
||||
else:
|
||||
flops = 0
|
||||
error_ct += 1
|
||||
|
||||
if flops > self.best_flops:
|
||||
self.best_flops = flops
|
||||
self.best_config = config
|
||||
self.best_measure_pair = (inp, res)
|
||||
self.best_iter = i + k
|
||||
|
||||
logging.debug("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
|
||||
i + k + 1, flops / 1e9, self.best_flops / 1e9,
|
||||
res, config)
|
||||
logger.debug("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
|
||||
i + k + 1, flops / 1e9, self.best_flops / 1e9,
|
||||
res, config)
|
||||
|
||||
i += len(results)
|
||||
|
||||
|
@ -123,11 +128,16 @@ class Tuner(object):
|
|||
callback(self, inputs, results)
|
||||
|
||||
if i > self.best_iter + early_stopping:
|
||||
logging.debug("Early stopped. Best iter: %d.", self.best_iter)
|
||||
logger.debug("Early stopped. Best iter: %d.", self.best_iter)
|
||||
break
|
||||
|
||||
GLOBAL_SCOPE.in_tuning = False
|
||||
if error_ct > 50:
|
||||
logger.warning("Too many errors happen in the tuning. Now is in debug mode")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
else:
|
||||
logger.setLevel(old_level)
|
||||
|
||||
GLOBAL_SCOPE.in_tuning = False
|
||||
del measure_batch
|
||||
|
||||
def reset(self):
|
||||
|
|
|
@ -16,6 +16,8 @@ from ..util import get_rank
|
|||
from .metric import max_curve, recall_curve, cover_curve
|
||||
from .model_based_tuner import CostModel, FeatureCache
|
||||
|
||||
logger = logging.getLogger('autotvm')
|
||||
|
||||
class XGBoostCostModel(CostModel):
|
||||
"""XGBoost as cost model
|
||||
|
||||
|
@ -163,17 +165,17 @@ class XGBoostCostModel(CostModel):
|
|||
],
|
||||
verbose_eval=self.log_interval)])
|
||||
|
||||
logging.debug("XGB train: %.2f\tobs: %d\terror: %d\tn_cache: %d",
|
||||
time.time() - tic, len(xs),
|
||||
len(xs) - np.sum(valid_index),
|
||||
self.feature_cache.size(self.fea_type))
|
||||
logger.debug("XGB train: %.2f\tobs: %d\terror: %d\tn_cache: %d",
|
||||
time.time() - tic, len(xs),
|
||||
len(xs) - np.sum(valid_index),
|
||||
self.feature_cache.size(self.fea_type))
|
||||
|
||||
def fit_log(self, records, plan_size):
|
||||
tic = time.time()
|
||||
self._reset_pool()
|
||||
|
||||
args = list(records)
|
||||
logging.debug("XGB load %d entries from history log file", len(args))
|
||||
logger.debug("XGB load %d entries from history log file", len(args))
|
||||
|
||||
if self.fea_type == 'itervar':
|
||||
feature_extract_func = _extract_itervar_feature_log
|
||||
|
@ -208,7 +210,7 @@ class XGBoostCostModel(CostModel):
|
|||
],
|
||||
verbose_eval=self.log_interval)])
|
||||
|
||||
logging.debug("XGB train: %.2f\tobs: %d", time.time() - tic, len(xs))
|
||||
logger.debug("XGB train: %.2f\tobs: %d", time.time() - tic, len(xs))
|
||||
|
||||
def predict(self, xs, output_margin=False):
|
||||
feas = self._get_feature(xs)
|
||||
|
@ -403,7 +405,7 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
|
|||
infos.append("%s: %.6f" % (item[0], item[1]))
|
||||
|
||||
if not isinstance(verbose_eval, bool) and verbose_eval and i % verbose_eval == 0:
|
||||
logging.debug("\t".join(infos))
|
||||
logger.debug("\t".join(infos))
|
||||
if log_file:
|
||||
with open(log_file, "a") as fout:
|
||||
fout.write("\t".join(infos) + '\n')
|
||||
|
@ -435,7 +437,7 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
|
|||
elif env.iteration - best_iteration >= stopping_rounds:
|
||||
best_msg = state['best_msg']
|
||||
if verbose_eval and env.rank == 0:
|
||||
logging.debug("XGB stopped. Best iteration: %s ", best_msg)
|
||||
logger.debug("XGB stopped. Best iteration: %s ", best_msg)
|
||||
raise EarlyStopException(best_iteration)
|
||||
|
||||
return callback
|
||||
|
|
|
@ -8,6 +8,7 @@ import numpy as np
|
|||
|
||||
from .. import expr, ir_pass
|
||||
|
||||
logger = logging.getLogger('autotvm')
|
||||
|
||||
class EmptyContext(object):
|
||||
"""An empty context"""
|
||||
|
@ -92,15 +93,15 @@ def pool_map(func, args, batch_size, verbose=False, pool=None):
|
|||
tic = time.time()
|
||||
local_pool = pool or multiprocessing.Pool()
|
||||
if verbose:
|
||||
logging.info("mapping begin")
|
||||
logger.info("mapping begin")
|
||||
for i in range(0, len(args), batch_size):
|
||||
if verbose:
|
||||
logging.info("mapping %d/%d elapsed %.2f", i, len(args),
|
||||
time.time() - tic)
|
||||
logger.info("mapping %d/%d elapsed %.2f", i, len(args),
|
||||
time.time() - tic)
|
||||
tmp = np.array(local_pool.map(func, args[i:i+batch_size]))
|
||||
ret = tmp if ret is None else np.concatenate((ret, tmp))
|
||||
if verbose:
|
||||
logging.info("mapping done")
|
||||
logger.info("mapping done")
|
||||
if not pool:
|
||||
local_pool.close()
|
||||
return ret
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
"""Base definitions for RPC."""
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
import socket
|
||||
|
@ -23,6 +25,7 @@ RPC_CODE_DUPLICATE = RPC_MAGIC + 1
|
|||
# cannot find matched key in server
|
||||
RPC_CODE_MISMATCH = RPC_MAGIC + 2
|
||||
|
||||
logger = logging.getLogger('RPCServer')
|
||||
|
||||
class TrackerCode(object):
|
||||
"""Enumeration code for the RPC tracker"""
|
||||
|
@ -120,7 +123,7 @@ def random_key(prefix, cmap=None):
|
|||
return prefix + str(random.random())
|
||||
|
||||
|
||||
def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
|
||||
def connect_with_retry(addr, timeout=60, retry_period=5):
|
||||
"""Connect to a TCP address with retry
|
||||
|
||||
This function is only reliable for a short period of server restart.
|
||||
|
@ -135,9 +138,6 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
|
|||
|
||||
retry_period : float
|
||||
Number of seconds before we retry again.
|
||||
|
||||
silent: bool
|
||||
whether run in silent mode
|
||||
"""
|
||||
tstart = time.time()
|
||||
while True:
|
||||
|
@ -152,9 +152,8 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
|
|||
if period > timeout:
|
||||
raise RuntimeError(
|
||||
"Failed to connect to server %s" % str(addr))
|
||||
if not silent:
|
||||
logging.info("Cannot connect to tracker%s, retry in %g secs...",
|
||||
str(addr), retry_period)
|
||||
logger.warning("Cannot connect to tracker %s, retry in %g secs...",
|
||||
str(addr), retry_period)
|
||||
time.sleep(retry_period)
|
||||
|
||||
|
||||
|
|
|
@ -23,7 +23,8 @@ try:
|
|||
from tornado import ioloop
|
||||
from . import tornado_util
|
||||
except ImportError as error_msg:
|
||||
raise ImportError("RPCProxy module requires tornado package %s" % error_msg)
|
||||
raise ImportError(
|
||||
"RPCProxy module requires tornado package %s. Try 'pip install tornado'." % error_msg)
|
||||
|
||||
from . import base
|
||||
from .base import TrackerCode
|
||||
|
@ -540,7 +541,7 @@ def websocket_proxy_server(url, key=""):
|
|||
def _connect(key):
|
||||
conn = yield websocket.websocket_connect(url)
|
||||
on_message = create_on_message(conn)
|
||||
temp = _server_env(None, None)
|
||||
temp = _server_env(None)
|
||||
# Start connection
|
||||
conn.write_message(struct.pack('<i', base.RPC_MAGIC), binary=True)
|
||||
key = "server:" + key
|
||||
|
|
|
@ -8,6 +8,8 @@ Server is TCP based with the following protocol:
|
|||
- The key is in format
|
||||
- {server|client}:device-type[:random-key] [-timeout=timeout]
|
||||
"""
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
import os
|
||||
|
@ -30,11 +32,11 @@ from ..contrib import util
|
|||
from . import base
|
||||
from . base import TrackerCode
|
||||
|
||||
def _server_env(load_library, logger):
|
||||
logger = logging.getLogger('RPCServer')
|
||||
|
||||
def _server_env(load_library):
|
||||
"""Server environment function return temp dir"""
|
||||
temp = util.tempdir()
|
||||
if logger is None:
|
||||
logger = logging.getLogger()
|
||||
|
||||
# pylint: disable=unused-variable
|
||||
@register_func("tvm.rpc.server.workpath")
|
||||
|
@ -59,13 +61,10 @@ def _server_env(load_library, logger):
|
|||
return temp
|
||||
|
||||
|
||||
def _serve_loop(sock, addr, load_library, silent):
|
||||
def _serve_loop(sock, addr, load_library):
|
||||
"""Server loop"""
|
||||
logger = logging.getLogger("RPCServer")
|
||||
if silent:
|
||||
logger.disabled = True
|
||||
sockfd = sock.fileno()
|
||||
temp = _server_env(load_library, logger)
|
||||
temp = _server_env(load_library)
|
||||
base._ServerLoop(sockfd)
|
||||
temp.remove()
|
||||
logger.info("Finish serving %s", addr)
|
||||
|
@ -79,12 +78,8 @@ def _parse_server_opt(opts):
|
|||
ret["timeout"] = float(kv[9:])
|
||||
return ret
|
||||
|
||||
def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, silent):
|
||||
def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
|
||||
"""Listening loop of the server master."""
|
||||
logger = logging.getLogger("RPCServer")
|
||||
if silent:
|
||||
logger.disabled = True
|
||||
|
||||
def _accept_conn(listen_sock, tracker_conn, ping_period=2):
|
||||
"""Accept connection from the other places.
|
||||
|
||||
|
@ -148,7 +143,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
|
|||
if arr[0] != expect_header:
|
||||
conn.sendall(struct.pack("<i", base.RPC_CODE_MISMATCH))
|
||||
conn.close()
|
||||
logger.info("mismatch key from %s", addr)
|
||||
logger.warning("mismatch key from %s", addr)
|
||||
continue
|
||||
else:
|
||||
conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS))
|
||||
|
@ -162,7 +157,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
|
|||
try:
|
||||
# step 1: setup tracker and report to tracker
|
||||
if tracker_addr and tracker_conn is None:
|
||||
tracker_conn = base.connect_with_retry(tracker_addr, silent=silent)
|
||||
tracker_conn = base.connect_with_retry(tracker_addr)
|
||||
tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
|
||||
magic = struct.unpack("<i", base.recvall(tracker_conn, 4))[0]
|
||||
if magic != base.RPC_TRACKER_MAGIC:
|
||||
|
@ -182,15 +177,12 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
|
|||
tracker_conn = None
|
||||
continue
|
||||
except RuntimeError as exc:
|
||||
if silent:
|
||||
return
|
||||
else:
|
||||
raise exc
|
||||
raise exc
|
||||
|
||||
# step 3: serving
|
||||
logger.info("connection from %s", addr)
|
||||
server_proc = multiprocessing.Process(target=_serve_loop,
|
||||
args=(conn, addr, load_library, silent))
|
||||
args=(conn, addr, load_library))
|
||||
server_proc.deamon = True
|
||||
server_proc.start()
|
||||
# close from our side.
|
||||
|
@ -202,10 +194,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
|
|||
server_proc.terminate()
|
||||
|
||||
|
||||
def _connect_proxy_loop(addr, key, load_library, silent):
|
||||
logger = logging.getLogger("RPCProxy")
|
||||
if silent:
|
||||
logger.disabled = True
|
||||
def _connect_proxy_loop(addr, key, load_library):
|
||||
key = "server:" + key
|
||||
retry_count = 0
|
||||
max_retry = 5
|
||||
|
@ -221,7 +210,7 @@ def _connect_proxy_loop(addr, key, load_library, silent):
|
|||
if magic == base.RPC_CODE_DUPLICATE:
|
||||
raise RuntimeError("key: %s has already been used in proxy" % key)
|
||||
elif magic == base.RPC_CODE_MISMATCH:
|
||||
logger.info("RPCProxy do not have matching client key %s", key)
|
||||
logger.warning("RPCProxy do not have matching client key %s", key)
|
||||
elif magic != base.RPC_CODE_SUCCESS:
|
||||
raise RuntimeError("%s is not RPC Proxy" % str(addr))
|
||||
keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
|
||||
|
@ -229,7 +218,7 @@ def _connect_proxy_loop(addr, key, load_library, silent):
|
|||
opts = _parse_server_opt(remote_key.split()[1:])
|
||||
logger.info("connected to %s", str(addr))
|
||||
process = multiprocessing.Process(
|
||||
target=_serve_loop, args=(sock, addr, load_library, silent))
|
||||
target=_serve_loop, args=(sock, addr, load_library))
|
||||
process.deamon = True
|
||||
process.start()
|
||||
sock.close()
|
||||
|
@ -240,7 +229,7 @@ def _connect_proxy_loop(addr, key, load_library, silent):
|
|||
retry_count = 0
|
||||
except (socket.error, IOError) as err:
|
||||
retry_count += 1
|
||||
logger.info("Error encountered %s, retry in %g sec", str(err), retry_period)
|
||||
logger.warning("Error encountered %s, retry in %g sec", str(err), retry_period)
|
||||
if retry_count > max_retry:
|
||||
raise RuntimeError("Maximum retry error: last error: %s" % str(err))
|
||||
time.sleep(retry_period)
|
||||
|
@ -323,9 +312,8 @@ class Server(object):
|
|||
self.custom_addr = custom_addr
|
||||
self.use_popen = use_popen
|
||||
|
||||
self.logger = logging.getLogger("RPCServer")
|
||||
if silent:
|
||||
self.logger.disabled = True
|
||||
logger.setLevel(logging.WARN)
|
||||
|
||||
if use_popen:
|
||||
cmd = [sys.executable,
|
||||
|
@ -360,18 +348,18 @@ class Server(object):
|
|||
raise sock_err
|
||||
if not self.port:
|
||||
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
|
||||
self.logger.info("bind to %s:%d", host, self.port)
|
||||
logger.info("bind to %s:%d", host, self.port)
|
||||
sock.listen(1)
|
||||
self.sock = sock
|
||||
self.proc = multiprocessing.Process(
|
||||
target=_listen_loop, args=(
|
||||
self.sock, self.port, key, tracker_addr, load_library,
|
||||
self.custom_addr, silent))
|
||||
self.custom_addr))
|
||||
self.proc.deamon = True
|
||||
self.proc.start()
|
||||
else:
|
||||
self.proc = multiprocessing.Process(
|
||||
target=_connect_proxy_loop, args=((host, port), key, load_library, silent))
|
||||
target=_connect_proxy_loop, args=((host, port), key, load_library))
|
||||
self.proc.deamon = True
|
||||
self.proc.start()
|
||||
|
||||
|
|
|
@ -23,6 +23,8 @@ List of available APIs:
|
|||
- input: [TrackerCode.REQUEST, [key, user, priority]]
|
||||
- return: [TrackerCode.SUCCESS, [url, port, match-key]]
|
||||
"""
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
import heapq
|
||||
import time
|
||||
import logging
|
||||
|
@ -37,12 +39,13 @@ try:
|
|||
from . import tornado_util
|
||||
except ImportError as error_msg:
|
||||
raise ImportError(
|
||||
"RPCTracker module requires tornado package %s" % error_msg)
|
||||
"RPCTracker module requires tornado package %s. Try 'pip install tornado'." % error_msg)
|
||||
|
||||
from .._ffi.base import py_str
|
||||
from . import base
|
||||
from .base import RPC_TRACKER_MAGIC, TrackerCode
|
||||
|
||||
logger = logging.getLogger("RPCTracker")
|
||||
|
||||
class Scheduler(object):
|
||||
"""Abstract interface of scheduler."""
|
||||
|
@ -141,11 +144,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
|
|||
def _init_conn(self, message):
|
||||
"""Initialize the connection"""
|
||||
if len(message) != 4:
|
||||
logging.info("Invalid connection from %s", self.name())
|
||||
logger.warning("Invalid connection from %s", self.name())
|
||||
self.close()
|
||||
magic = struct.unpack('<i', message)[0]
|
||||
if magic != RPC_TRACKER_MAGIC:
|
||||
logging.info("Invalid magic from %s", self.name())
|
||||
logger.warning("Invalid magic from %s", self.name())
|
||||
self.close()
|
||||
self.write_message(struct.pack('<i', RPC_TRACKER_MAGIC), binary=True)
|
||||
self._init_req_nbytes = 0
|
||||
|
@ -232,14 +235,14 @@ class TCPEventHandler(tornado_util.TCPHandler):
|
|||
status = self._tracker.summary()
|
||||
self.ret_value([TrackerCode.SUCCESS, status])
|
||||
else:
|
||||
logging.info("Unknown tracker code %d", code)
|
||||
logger.warning("Unknown tracker code %d", code)
|
||||
self.close()
|
||||
|
||||
def on_close(self):
|
||||
self._tracker._connections.remove(self)
|
||||
|
||||
def on_error(self, err):
|
||||
logging.info("%s: Error in RPC Tracker: %s", self.name(), err)
|
||||
logger.warning("%s: Error in RPC Tracker: %s", self.name(), err)
|
||||
self.close()
|
||||
|
||||
|
||||
|
@ -335,9 +338,8 @@ class Tracker(object):
|
|||
port=9190,
|
||||
port_end=9199,
|
||||
silent=False):
|
||||
self.logger = logging.getLogger("RPCTracker")
|
||||
if silent:
|
||||
self.logger.disabled = True
|
||||
logger.setLevel(logging.WARN)
|
||||
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.port = None
|
||||
|
@ -354,7 +356,7 @@ class Tracker(object):
|
|||
raise sock_err
|
||||
if not self.port:
|
||||
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
|
||||
self.logger.info("bind to %s:%d", host, self.port)
|
||||
logger.info("bind to %s:%d", host, self.port)
|
||||
sock.listen(1)
|
||||
self.proc = multiprocessing.Process(
|
||||
target=_tracker_server, args=(sock, self.stop_key))
|
||||
|
@ -380,7 +382,7 @@ class Tracker(object):
|
|||
self._stop_tracker()
|
||||
self.proc.join(1)
|
||||
if self.proc.is_alive():
|
||||
self.logger.info("Terminating Tracker Server...")
|
||||
logger.info("Terminating Tracker Server...")
|
||||
self.proc.terminate()
|
||||
self.proc = None
|
||||
|
||||
|
|
|
@ -154,7 +154,8 @@ def conv2d_no_batching(N, H, W, CI, CO, KH, KW, stride, padding):
|
|||
# for this template
|
||||
|
||||
# logging config (for printing tuning log to screen)
|
||||
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
|
||||
logging.getLogger('autotvm').setLevel(logging.DEBUG)
|
||||
logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))
|
||||
|
||||
# the last layer in resnet
|
||||
N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
|
||||
|
|
|
@ -163,8 +163,10 @@ def get_network(name, batch_size):
|
|||
# Set Tuning Options
|
||||
# ------------------
|
||||
# Before tuning, we should do some configurations. Here I use an RK3399 board
|
||||
# in our environment as example. In your setting, you should modify the target
|
||||
# and device_key accordingly.
|
||||
# as example. In your setting, you should modify the target and device_key accordingly.
|
||||
# set :code:`use_android` to True if you use android phone.
|
||||
|
||||
#### DEVICE CONFIG ####
|
||||
|
||||
# Replace "aarch64-linux-gnu" with the correct target of your board.
|
||||
# This target is used for cross compilation. You can query it by :code:`gcc -v` on your device.
|
||||
|
@ -173,7 +175,10 @@ target = tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu')
|
|||
# Also replace this with the device key in your tracker
|
||||
device_key = 'rk3399'
|
||||
|
||||
# tuning option
|
||||
# Set this to True if you use an Android phone
|
||||
use_android = False
|
||||
|
||||
#### TUNING OPTION ####
|
||||
network = 'resnet-18'
|
||||
log_file = "%s.%s.log" % (device_key, network)
|
||||
dtype = 'float32'
|
||||
|
@ -181,17 +186,17 @@ dtype = 'float32'
|
|||
tuning_option = {
|
||||
'log_filename': log_file,
|
||||
|
||||
'tuner':'xgb',
|
||||
'tuner': 'xgb',
|
||||
'n_trial': 1000,
|
||||
'early_stopping': 200,
|
||||
'early_stopping': 250,
|
||||
|
||||
'measure_option': autotvm.measure_option(
|
||||
autotvm.use_rpc(device_key, host='localhost', port=9190),
|
||||
number=4,
|
||||
parallel_num=1,
|
||||
timeout=10),
|
||||
|
||||
'use_transfer_learning': True,
|
||||
timeout=10,
|
||||
build_func='ndk' if use_android else 'default',
|
||||
),
|
||||
}
|
||||
|
||||
####################################################################
|
||||
|
@ -208,9 +213,6 @@ tuning_option = {
|
|||
# If your device is very slow or a single conv2d operator in your network has large FLOPs,
|
||||
# consider setting a larger timeout.
|
||||
#
|
||||
# **For android phone**, add :code:`build_func='ndk'` to the argument list of
|
||||
# :code:`autotvm.measure_option` to use Android NDK for creating shared library.
|
||||
#
|
||||
|
||||
###################################################################
|
||||
# Begin Tuning
|
||||
|
@ -280,12 +282,14 @@ def tune_tasks(tasks,
|
|||
|
||||
def tune_and_evaluate():
|
||||
# extract workloads from nnvm graph
|
||||
print("Extract tasks...")
|
||||
net, params, shape, out_shape = get_network(network, batch_size=1)
|
||||
tasks = autotvm.task.extract_from_graph(net, shape=shape, dtype=dtype,
|
||||
symbols=(nnvm.sym.conv2d,),
|
||||
target=target)
|
||||
|
||||
# run tuning tasks
|
||||
print("Tuning...")
|
||||
tune_tasks(tasks, **tuning_option)
|
||||
|
||||
# compile kernels with history best records
|
||||
|
@ -325,10 +329,11 @@ def tune_and_evaluate():
|
|||
ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=10)
|
||||
prof_res = np.array(ftimer().results) * 1000 # convert to millisecond
|
||||
print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
|
||||
(np.mean(prof_res), np.std(prof_res)))
|
||||
(np.mean(prof_res), np.std(prof_res)))
|
||||
|
||||
# We do not run the tuning on our webpage server since it takes too long.
|
||||
# Uncomment the following line to run it yourself.
|
||||
|
||||
# tune_and_evaluate()
|
||||
|
||||
######################################################################
|
||||
|
@ -341,6 +346,8 @@ def tune_and_evaluate():
|
|||
#
|
||||
# .. code-block:: bash
|
||||
#
|
||||
# Extract tasks...
|
||||
# Tuning...
|
||||
# [Task 1/16] Current/Best: 13.15/ 20.49 GFLOPS | Progress: (297/1000) | 348.51 s Done.
|
||||
# [Task 2/16] Current/Best: 16.66/ 22.64 GFLOPS | Progress: (475/1000) | 415.42 s Done.
|
||||
# [Task 3/16] Current/Best: 10.33/ 14.19 GFLOPS | Progress: (306/1000) | 239.61 s Done.
|
||||
|
@ -362,3 +369,23 @@ def tune_and_evaluate():
|
|||
# Evaluate inference time cost...
|
||||
# Mean inference time (std dev): 156.51 ms (0.89 ms)
|
||||
#
|
||||
|
||||
|
||||
######################################################################
|
||||
#
|
||||
# .. note:: **Having problems?**
|
||||
#
|
||||
# The auto tuning module is error-prone. If you always see " 0.00/ 0.00 GFLOPS",
|
||||
# then there must be something wrong.
|
||||
#
|
||||
# First, make sure you set the correct configuration of your device.
|
||||
# Then, you can print debug information by adding these lines at the beginning
|
||||
# of the script. It will print every measurement result, where you can find useful
|
||||
# error messages.
|
||||
#
|
||||
# .. code-block:: python
|
||||
#
|
||||
# import logging
|
||||
# logging.getLogger('autotvm').setLevel(logging.DEBUG)
|
||||
#
|
||||
# Finally, always feel free to ask our community for help on https://discuss.tvm.ai
|
||||
|
|
|
@ -267,8 +267,9 @@ print(task.config_space)
|
|||
# We will log the tuning results into a log file. This file can be
|
||||
# used to get the best config later.
|
||||
|
||||
# logging config (for printing tuning log to screen)
|
||||
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
|
||||
# logging config (for printing tuning log to the screen)
|
||||
logging.getLogger('autotvm').setLevel(logging.DEBUG)
|
||||
logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))
|
||||
|
||||
# use local cpu, measure 5 times for every config to reduce variance
|
||||
measure_option = autotvm.measure_option('local',
|
||||
|
|
Загрузка…
Ссылка в новой задаче