зеркало из https://github.com/microsoft/nni.git
graphutils supports torch17 (#3076)
This commit is contained in:
Родитель
b6233e524b
Коммит
f7b7edac5b
|
@ -15,6 +15,7 @@ LIST_CONSTRUCT_KIND = 'prim::ListConstruct'
|
|||
LIST_UNPACK_KIND = 'prim::ListUnpack'
|
||||
TUPLE_CONSTRUCT_KIND = 'prim::TupleConstruct'
|
||||
TUPLE_UNPACK_KIND = 'prim::TupleUnpack'
|
||||
CONSTANT_KIND = 'prim::Constant'
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -68,9 +69,11 @@ class TorchGraph:
|
|||
'Please provide model & dummy_input or the traced_model as inputs')
|
||||
|
||||
def _trace(self, model, dummy_input):
|
||||
with torch.onnx.set_training(model, False):
|
||||
self.trace = torch.jit.trace(model, dummy_input)
|
||||
torch._C._jit_pass_inline(self.trace.graph)
|
||||
training = model.training
|
||||
model.eval()
|
||||
self.trace = torch.jit.trace(model, dummy_input)
|
||||
torch._C._jit_pass_inline(self.trace.graph)
|
||||
model.train(training)
|
||||
|
||||
|
||||
class TorchProtoGraph(TorchGraph):
|
||||
|
@ -282,27 +285,35 @@ class TorchModuleGraph(TorchGraph):
|
|||
self.global_count += 1
|
||||
op_type = node.kind()
|
||||
node_group = [node]
|
||||
inputs = list()
|
||||
outputs = list()
|
||||
inputs = set()
|
||||
outputs = set()
|
||||
node_queue = queue.Queue()
|
||||
node_queue.put(node)
|
||||
while not node_queue.empty():
|
||||
curr_node = node_queue.get()
|
||||
for _input in curr_node.inputs():
|
||||
if _input.node().kind() == CONSTANT_KIND:
|
||||
continue
|
||||
input_name = _input.debugName()
|
||||
if input_name in output_to_node and output_to_node[input_name] in nodes:
|
||||
predecessor_node = output_to_node[input_name]
|
||||
if not self._is_key_func(predecessor_node):
|
||||
node_group.append(predecessor_node)
|
||||
node_queue.put(predecessor_node)
|
||||
else:
|
||||
inputs.append(input_name)
|
||||
if input_name in output_to_node:
|
||||
for predecessor_node in output_to_node[input_name]:
|
||||
if predecessor_node in nodes:
|
||||
if not self._is_key_func(predecessor_node):
|
||||
if predecessor_node not in node_group:
|
||||
node_group.append(predecessor_node)
|
||||
node_queue.put(predecessor_node)
|
||||
else:
|
||||
inputs.add(input_name)
|
||||
else:
|
||||
inputs.add(input_name)
|
||||
else:
|
||||
inputs.append(input_name)
|
||||
inputs.add(input_name)
|
||||
for output in node.outputs():
|
||||
outputs.append(output.debugName())
|
||||
if output.node().kind() == CONSTANT_KIND:
|
||||
continue
|
||||
outputs.add(output.debugName())
|
||||
nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
|
||||
node_group, inputs=inputs, outputs=outputs, key_node=node)
|
||||
node_group, inputs=list(inputs), outputs=list(outputs), key_node=node)
|
||||
return nodepy
|
||||
|
||||
def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
|
||||
|
@ -342,36 +353,46 @@ class TorchModuleGraph(TorchGraph):
|
|||
if not op_type:
|
||||
op_type = node.kind()
|
||||
node_group = [node]
|
||||
inputs = list()
|
||||
outputs = list()
|
||||
inputs = set()
|
||||
outputs = set()
|
||||
node_queue = queue.Queue()
|
||||
node_queue.put(node)
|
||||
visited = {node}
|
||||
while not node_queue.empty():
|
||||
curr_node = node_queue.get()
|
||||
for _input in curr_node.inputs():
|
||||
if _input.node().kind() == CONSTANT_KIND:
|
||||
continue
|
||||
input_name = _input.debugName()
|
||||
if input_name in output_to_node and output_to_node[input_name] in nodes:
|
||||
predecessor_node = output_to_node[input_name]
|
||||
if predecessor_node not in visited:
|
||||
node_group.append(predecessor_node)
|
||||
node_queue.put(predecessor_node)
|
||||
visited.add(predecessor_node)
|
||||
if input_name in output_to_node:
|
||||
for predecessor_node in output_to_node[input_name]:
|
||||
if predecessor_node in nodes:
|
||||
if predecessor_node not in visited:
|
||||
node_group.append(predecessor_node)
|
||||
node_queue.put(predecessor_node)
|
||||
visited.add(predecessor_node)
|
||||
else:
|
||||
inputs.add(input_name)
|
||||
else:
|
||||
inputs.append(input_name)
|
||||
inputs.add(input_name)
|
||||
for _output in curr_node.outputs():
|
||||
if _output.node().kind() == CONSTANT_KIND:
|
||||
continue
|
||||
output_name = _output.debugName()
|
||||
if output_name in input_to_node and input_to_node[output_name] in nodes:
|
||||
successor_node = input_to_node[output_name]
|
||||
if successor_node not in visited:
|
||||
node_group.append(successor_node)
|
||||
node_queue.put(successor_node)
|
||||
visited.add(successor_node)
|
||||
if output_name in input_to_node:
|
||||
for successor_node in input_to_node[output_name]:
|
||||
if successor_node in nodes:
|
||||
if successor_node not in visited:
|
||||
node_group.append(successor_node)
|
||||
node_queue.put(successor_node)
|
||||
visited.add(successor_node)
|
||||
else:
|
||||
outputs.add(output_name)
|
||||
else:
|
||||
outputs.append(output_name)
|
||||
outputs.add(output_name)
|
||||
|
||||
nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
|
||||
node_group, inputs=inputs, outputs=outputs)
|
||||
node_group, inputs=list(inputs), outputs=list(outputs))
|
||||
return nodepy
|
||||
|
||||
def _extract_cat_info(self, node_group, cpp_node):
|
||||
|
@ -544,7 +565,7 @@ class TorchModuleGraph(TorchGraph):
|
|||
input_to_node[_input].append(node)
|
||||
for output in node.outputs:
|
||||
assert not output in output_to_node, \
|
||||
"One output cannot be generated by multiple nodes"
|
||||
"One output cannot be generated by multiple nodes %s" % output
|
||||
output_to_node[output] = node
|
||||
return name_to_node, input_to_node, output_to_node
|
||||
|
||||
|
@ -642,12 +663,22 @@ class TorchModuleGraph(TorchGraph):
|
|||
omit_useless_nodes = True
|
||||
graph = self.trace.graph
|
||||
_logger.debug(graph)
|
||||
# build output mapping, from output debugName to its node
|
||||
output_to_node = {x.debugName(): n for n in graph.nodes()
|
||||
for x in n.outputs()}
|
||||
# build input mapping, from input debugName to its node
|
||||
input_to_node = {x.debugName(): n for n in graph.nodes()
|
||||
for x in n.inputs()}
|
||||
# build input/output mapping, from input/output debugName to its node
|
||||
input_to_node = defaultdict(list)
|
||||
output_to_node = defaultdict(list)
|
||||
for node in graph.nodes():
|
||||
if node.kind() == CONSTANT_KIND:
|
||||
continue
|
||||
for x in node.outputs():
|
||||
if x.node().kind() == CONSTANT_KIND:
|
||||
continue
|
||||
output_to_node[x.debugName()].append(node)
|
||||
assert len(output_to_node[x.debugName()]) <= 1, "One output cannot be generated by multiple nodes %s" % x.debugName()
|
||||
for x in node.inputs():
|
||||
if x.node().kind() == CONSTANT_KIND:
|
||||
continue
|
||||
input_to_node[x.debugName()].append(node)
|
||||
|
||||
# build module mapping, from module name to all nodes (as list) under this module scope
|
||||
module_to_nodes = defaultdict(list)
|
||||
# the mapping of function (non-module in forward) to nodes, key is scope name
|
||||
|
@ -668,6 +699,8 @@ class TorchModuleGraph(TorchGraph):
|
|||
|
||||
# associate module name with their trace graph nodes
|
||||
for node in graph.nodes():
|
||||
if node.kind() == CONSTANT_KIND:
|
||||
continue
|
||||
module_name = self._get_module_name(node.scopeName())
|
||||
if module_name in self.leaf_modules:
|
||||
module_to_nodes[module_name].append(node)
|
||||
|
|
|
@ -36,9 +36,11 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
|
|||
# this traced model.
|
||||
if traced is None:
|
||||
assert model is not None and dummy_input is not None
|
||||
with torch.onnx.set_training(model, False):
|
||||
# We need to trace the model in this way, else it will have problems
|
||||
traced = torch.jit.trace(model, dummy_input)
|
||||
training = model.training
|
||||
model.eval()
|
||||
# We need to trace the model in eval mode
|
||||
traced = torch.jit.trace(model, dummy_input)
|
||||
model.train(training)
|
||||
|
||||
fix_group_mask = GroupMaskConflict(masks, model, dummy_input, traced)
|
||||
masks = fix_group_mask.fix_mask()
|
||||
|
|
|
@ -34,7 +34,7 @@ jobs:
|
|||
set -e
|
||||
sudo apt-get install -y pandoc
|
||||
python3 -m pip install -U --upgrade pygments
|
||||
python3 -m pip install -U torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
python3 -m pip install -U torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
python3 -m pip install -U tensorflow==2.3.1
|
||||
python3 -m pip install -U keras==2.4.2
|
||||
python3 -m pip install -U gym onnx peewee thop
|
||||
|
@ -96,7 +96,7 @@ jobs:
|
|||
|
||||
- script: |
|
||||
set -e
|
||||
python3 -m pip install -U torch==1.3.1+cpu torchvision==0.4.2+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
python3 -m pip install -U torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
python3 -m pip install -U tensorflow==1.15.2
|
||||
python3 -m pip install -U keras==2.1.6
|
||||
python3 -m pip install -U gym onnx peewee
|
||||
|
|
|
@ -61,7 +61,6 @@ channel_dependency_ground_truth = {
|
|||
unittest.TestLoader.sortTestMethodsUsing = None
|
||||
|
||||
|
||||
@unittest.skipIf(torch.__version__ >= '1.6.0', 'not supported')
|
||||
class AnalysisUtilsTest(TestCase):
|
||||
@unittest.skipIf(torch.__version__ < "1.3.0", "not supported")
|
||||
def test_channel_dependency(self):
|
||||
|
|
|
@ -47,7 +47,6 @@ def generate_random_sparsity_v2(model):
|
|||
return cfg_list
|
||||
|
||||
|
||||
@unittest.skipIf(torch.__version__ >= '1.6.0', 'not supported')
|
||||
class DependencyawareTest(TestCase):
|
||||
@unittest.skipIf(torch.__version__ < "1.3.0", "not supported")
|
||||
def test_dependency_aware_pruning(self):
|
||||
|
|
|
@ -177,7 +177,6 @@ def channel_prune(model):
|
|||
pruner.compress()
|
||||
pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE)
|
||||
|
||||
@unittest.skipIf(torch.__version__ >= '1.6.0', 'not supported')
|
||||
class SpeedupTestCase(TestCase):
|
||||
def test_speedup_vgg16(self):
|
||||
prune_model_l1(vgg16())
|
||||
|
|
|
@ -264,7 +264,6 @@ class SimpleDataset:
|
|||
def __len__(self):
|
||||
return 1000
|
||||
|
||||
@unittest.skipIf(torch.__version__ >= '1.6.0', 'not supported')
|
||||
class PrunerTestCase(TestCase):
|
||||
def test_pruners(self):
|
||||
pruners_test(bias=True)
|
||||
|
|
Загрузка…
Ссылка в новой задаче