GraphTuner supports relay.module as input (#3434)

This commit is contained in:
Yao Wang 2019-06-27 10:00:08 -07:00 коммит произвёл Tianqi Chen
Родитель a074dafc5b
Коммит 6c43019b4c
2 изменённых файлов: 6 добавлений и 1 удалений

Просмотреть файл

@ -141,6 +141,9 @@ class BaseGraphTuner(object):
self._logger.propagate = False self._logger.propagate = False
# Generate workload and schedule dictionaries. # Generate workload and schedule dictionaries.
if isinstance(graph, relay.Module):
graph = graph[graph.entry_func]
if isinstance(graph, relay.expr.Function): if isinstance(graph, relay.expr.Function):
node_dict = {} node_dict = {}
graph = bind_inputs(graph, input_shapes, dtype) graph = bind_inputs(graph, input_shapes, dtype)

Просмотреть файл

@ -159,6 +159,8 @@ def test_DPTuner_run():
target_ops = [relay.nn.conv2d] target_ops = [relay.nn.conv2d]
g, records, ltf_records, ltf_keys, tasks = _create_data(target, dshape, dtype, layout) g, records, ltf_records, ltf_keys, tasks = _create_data(target, dshape, dtype, layout)
mod = relay.module.Module()
mod[mod.entry_func] = g
costs = [0.02, 0.02, 0.045] costs = [0.02, 0.02, 0.045]
config_list = [] config_list = []
cfg_dict = {"i": -1, cfg_dict = {"i": -1,
@ -190,7 +192,7 @@ def test_DPTuner_run():
ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1) ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
records.append((ms_input, ms_output)) records.append((ms_input, ms_output))
executor = DPTuner(g, {"data": dshape}, records, target_ops, target, log_file=log_file) executor = DPTuner(mod, {"data": dshape}, records, target_ops, target, log_file=log_file)
executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True) executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
executor.run() executor.run()
out = [record[0].config for record in executor.get_optimal_records()] out = [record[0].config for record in executor.get_optimal_records()]