GraphTuner supports relay.module as input (#3434)
This commit is contained in:
Родитель
a074dafc5b
Коммит
6c43019b4c
|
@ -141,6 +141,9 @@ class BaseGraphTuner(object):
|
|||
self._logger.propagate = False
|
||||
|
||||
# Generate workload and schedule dictionaries.
|
||||
if isinstance(graph, relay.Module):
|
||||
graph = graph[graph.entry_func]
|
||||
|
||||
if isinstance(graph, relay.expr.Function):
|
||||
node_dict = {}
|
||||
graph = bind_inputs(graph, input_shapes, dtype)
|
||||
|
|
|
@ -159,6 +159,8 @@ def test_DPTuner_run():
|
|||
target_ops = [relay.nn.conv2d]
|
||||
|
||||
g, records, ltf_records, ltf_keys, tasks = _create_data(target, dshape, dtype, layout)
|
||||
mod = relay.module.Module()
|
||||
mod[mod.entry_func] = g
|
||||
costs = [0.02, 0.02, 0.045]
|
||||
config_list = []
|
||||
cfg_dict = {"i": -1,
|
||||
|
@ -190,7 +192,7 @@ def test_DPTuner_run():
|
|||
ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
|
||||
records.append((ms_input, ms_output))
|
||||
|
||||
executor = DPTuner(g, {"data": dshape}, records, target_ops, target, log_file=log_file)
|
||||
executor = DPTuner(mod, {"data": dshape}, records, target_ops, target, log_file=log_file)
|
||||
executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
|
||||
executor.run()
|
||||
out = [record[0].config for record in executor.get_optimal_records()]
|
||||
|
|
Загрузка…
Ссылка в новой задаче