diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py index 436b06c5..3080baf1 100644 --- a/nnvm/python/nnvm/compiler/build_module.py +++ b/nnvm/python/nnvm/compiler/build_module.py @@ -264,7 +264,8 @@ def build(graph, target=None, shape=None, dtype="float32", if _all_var_init: init_var = initialize_variables(shape, dtype) # Apply optimization - graph = optimize(graph, shape, dtype, layout) + with target: + graph = optimize(graph, shape, dtype, layout) # Precompute prune if params and cfg.pass_enabled("PrecomputePrune"): graph, params = precompute_prune(graph, params)