GraphTuner supports relay.module as input (#3434)
This commit is contained in:
Родитель
a074dafc5b
Коммит
6c43019b4c
|
@ -141,6 +141,9 @@ class BaseGraphTuner(object):
|
||||||
self._logger.propagate = False
|
self._logger.propagate = False
|
||||||
|
|
||||||
# Generate workload and schedule dictionaries.
|
# Generate workload and schedule dictionaries.
|
||||||
|
if isinstance(graph, relay.Module):
|
||||||
|
graph = graph[graph.entry_func]
|
||||||
|
|
||||||
if isinstance(graph, relay.expr.Function):
|
if isinstance(graph, relay.expr.Function):
|
||||||
node_dict = {}
|
node_dict = {}
|
||||||
graph = bind_inputs(graph, input_shapes, dtype)
|
graph = bind_inputs(graph, input_shapes, dtype)
|
||||||
|
|
|
@ -159,6 +159,8 @@ def test_DPTuner_run():
|
||||||
target_ops = [relay.nn.conv2d]
|
target_ops = [relay.nn.conv2d]
|
||||||
|
|
||||||
g, records, ltf_records, ltf_keys, tasks = _create_data(target, dshape, dtype, layout)
|
g, records, ltf_records, ltf_keys, tasks = _create_data(target, dshape, dtype, layout)
|
||||||
|
mod = relay.module.Module()
|
||||||
|
mod[mod.entry_func] = g
|
||||||
costs = [0.02, 0.02, 0.045]
|
costs = [0.02, 0.02, 0.045]
|
||||||
config_list = []
|
config_list = []
|
||||||
cfg_dict = {"i": -1,
|
cfg_dict = {"i": -1,
|
||||||
|
@ -190,7 +192,7 @@ def test_DPTuner_run():
|
||||||
ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
|
ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
|
||||||
records.append((ms_input, ms_output))
|
records.append((ms_input, ms_output))
|
||||||
|
|
||||||
executor = DPTuner(g, {"data": dshape}, records, target_ops, target, log_file=log_file)
|
executor = DPTuner(mod, {"data": dshape}, records, target_ops, target, log_file=log_file)
|
||||||
executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
|
executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
|
||||||
executor.run()
|
executor.run()
|
||||||
out = [record[0].config for record in executor.get_optimal_records()]
|
out = [record[0].config for record in executor.get_optimal_records()]
|
||||||
|
|
Загрузка…
Ссылка в новой задаче