This commit is contained in:
chicm-ms 2020-11-23 10:21:44 +08:00 committed by GitHub
Parent b6233e524b
Commit f7b7edac5b
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 79 additions and 48 deletions

View file

@@ -15,6 +15,7 @@ LIST_CONSTRUCT_KIND = 'prim::ListConstruct'
 LIST_UNPACK_KIND = 'prim::ListUnpack'
 TUPLE_CONSTRUCT_KIND = 'prim::TupleConstruct'
 TUPLE_UNPACK_KIND = 'prim::TupleUnpack'
+CONSTANT_KIND = 'prim::Constant'

 _logger = logging.getLogger(__name__)
@@ -68,9 +69,11 @@ class TorchGraph:
                 'Please provide model & dummy_input or the traced_model as inputs')

     def _trace(self, model, dummy_input):
-        with torch.onnx.set_training(model, False):
-            self.trace = torch.jit.trace(model, dummy_input)
-            torch._C._jit_pass_inline(self.trace.graph)
+        training = model.training
+        model.eval()
+        self.trace = torch.jit.trace(model, dummy_input)
+        torch._C._jit_pass_inline(self.trace.graph)
+        model.train(training)

 class TorchProtoGraph(TorchGraph):
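Note: the torch.onnx.set_training context manager is not usable in this form on newer PyTorch, so the trace is now wrapped in an explicit save/eval/restore of the module's training mode. A minimal standalone sketch of the same pattern (trace_in_eval_mode is a hypothetical helper, not part of this commit):

    import torch

    def trace_in_eval_mode(model, dummy_input):
        # remember the caller's mode, trace in eval mode, then restore it
        training = model.training
        model.eval()
        trace = torch.jit.trace(model, dummy_input)
        torch._C._jit_pass_inline(trace.graph)
        model.train(training)
        return trace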
@@ -282,27 +285,35 @@ class TorchModuleGraph(TorchGraph):
         self.global_count += 1
         op_type = node.kind()
         node_group = [node]
-        inputs = list()
-        outputs = list()
+        inputs = set()
+        outputs = set()
         node_queue = queue.Queue()
         node_queue.put(node)
         while not node_queue.empty():
             curr_node = node_queue.get()
             for _input in curr_node.inputs():
+                if _input.node().kind() == CONSTANT_KIND:
+                    continue
                 input_name = _input.debugName()
-                if input_name in output_to_node and output_to_node[input_name] in nodes:
-                    predecessor_node = output_to_node[input_name]
-                    if not self._is_key_func(predecessor_node):
-                        node_group.append(predecessor_node)
-                        node_queue.put(predecessor_node)
-                    else:
-                        inputs.append(input_name)
+                if input_name in output_to_node:
+                    for predecessor_node in output_to_node[input_name]:
+                        if predecessor_node in nodes:
+                            if not self._is_key_func(predecessor_node):
+                                if predecessor_node not in node_group:
+                                    node_group.append(predecessor_node)
+                                    node_queue.put(predecessor_node)
+                            else:
+                                inputs.add(input_name)
+                        else:
+                            inputs.add(input_name)
                 else:
-                    inputs.append(input_name)
+                    inputs.add(input_name)
         for output in node.outputs():
-            outputs.append(output.debugName())
+            if output.node().kind() == CONSTANT_KIND:
+                continue
+            outputs.add(output.debugName())
         nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
-                             node_group, inputs=inputs, outputs=outputs, key_node=node)
+                             node_group, inputs=list(inputs), outputs=list(outputs), key_node=node)
         return nodepy

     def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
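Switching inputs/outputs from list to set means a debug name reachable through several paths inside the group is recorded only once; roughly (the debug names are hypothetical):

    inputs = set()
    for name in ['x.1', 'w.3', 'x.1']:  # 'x.1' arrives via two predecessor paths
        inputs.add(name)
    assert sorted(inputs) == ['w.3', 'x.1']  # each name recorded once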
@@ -342,36 +353,46 @@ class TorchModuleGraph(TorchGraph):
         if not op_type:
             op_type = node.kind()
         node_group = [node]
-        inputs = list()
-        outputs = list()
+        inputs = set()
+        outputs = set()
         node_queue = queue.Queue()
         node_queue.put(node)
         visited = {node}
         while not node_queue.empty():
             curr_node = node_queue.get()
             for _input in curr_node.inputs():
+                if _input.node().kind() == CONSTANT_KIND:
+                    continue
                 input_name = _input.debugName()
-                if input_name in output_to_node and output_to_node[input_name] in nodes:
-                    predecessor_node = output_to_node[input_name]
-                    if predecessor_node not in visited:
-                        node_group.append(predecessor_node)
-                        node_queue.put(predecessor_node)
-                        visited.add(predecessor_node)
+                if input_name in output_to_node:
+                    for predecessor_node in output_to_node[input_name]:
+                        if predecessor_node in nodes:
+                            if predecessor_node not in visited:
+                                node_group.append(predecessor_node)
+                                node_queue.put(predecessor_node)
+                                visited.add(predecessor_node)
+                        else:
+                            inputs.add(input_name)
                 else:
-                    inputs.append(input_name)
+                    inputs.add(input_name)
             for _output in curr_node.outputs():
+                if _output.node().kind() == CONSTANT_KIND:
+                    continue
                 output_name = _output.debugName()
-                if output_name in input_to_node and input_to_node[output_name] in nodes:
-                    successor_node = input_to_node[output_name]
-                    if successor_node not in visited:
-                        node_group.append(successor_node)
-                        node_queue.put(successor_node)
-                        visited.add(successor_node)
+                if output_name in input_to_node:
+                    for successor_node in input_to_node[output_name]:
+                        if successor_node in nodes:
+                            if successor_node not in visited:
+                                node_group.append(successor_node)
+                                node_queue.put(successor_node)
+                                visited.add(successor_node)
+                        else:
+                            outputs.add(output_name)
                 else:
-                    outputs.append(output_name)
+                    outputs.add(output_name)
         nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
-                             node_group, inputs=inputs, outputs=outputs)
+                             node_group, inputs=list(inputs), outputs=list(outputs))
         return nodepy

     def _extract_cat_info(self, node_group, cpp_node):
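Both expansion routines above are a breadth-first search over producer/consumer edges, guarded by a visited set so each node is enqueued at most once. A self-contained sketch of that traversal on a toy adjacency map (the names and neighbors dict are hypothetical):

    import queue

    def expand_group(start, neighbors):
        # BFS from start; neighbors maps a node to its adjacent nodes
        group = [start]
        visited = {start}
        node_queue = queue.Queue()
        node_queue.put(start)
        while not node_queue.empty():
            curr = node_queue.get()
            for nxt in neighbors.get(curr, []):
                if nxt not in visited:
                    group.append(nxt)
                    node_queue.put(nxt)
                    visited.add(nxt)
        return group

    # expand_group('conv', {'conv': ['bn'], 'bn': ['relu']}) -> ['conv', 'bn', 'relu']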
@@ -544,7 +565,7 @@ class TorchModuleGraph(TorchGraph):
                 input_to_node[_input].append(node)
             for output in node.outputs:
                 assert not output in output_to_node, \
-                    "One output cannot be generated by multiple nodes"
+                    "One output cannot be generated by multiple nodes %s" % output
                 output_to_node[output] = node
         return name_to_node, input_to_node, output_to_node
@@ -642,12 +663,22 @@ class TorchModuleGraph(TorchGraph):
         omit_useless_nodes = True
         graph = self.trace.graph
         _logger.debug(graph)
-        # build output mapping, from output debugName to its node
-        output_to_node = {x.debugName(): n for n in graph.nodes()
-                          for x in n.outputs()}
-        # build input mapping, from input debugName to its node
-        input_to_node = {x.debugName(): n for n in graph.nodes()
-                         for x in n.inputs()}
+        # build input/output mapping, from input/output debugName to its node
+        input_to_node = defaultdict(list)
+        output_to_node = defaultdict(list)
+        for node in graph.nodes():
+            if node.kind() == CONSTANT_KIND:
+                continue
+            for x in node.outputs():
+                if x.node().kind() == CONSTANT_KIND:
+                    continue
+                output_to_node[x.debugName()].append(node)
+                assert len(output_to_node[x.debugName()]) <= 1, \
+                    "One output cannot be generated by multiple nodes %s" % x.debugName()
+            for x in node.inputs():
+                if x.node().kind() == CONSTANT_KIND:
+                    continue
+                input_to_node[x.debugName()].append(node)
         # build module mapping, from module name to all nodes (as list) under this module scope
         module_to_nodes = defaultdict(list)
         # the mapping of function (non-module in forward) to nodes, key is scope name
@@ -668,6 +699,8 @@ class TorchModuleGraph(TorchGraph):

         # associate module name with their trace graph nodes
         for node in graph.nodes():
+            if node.kind() == CONSTANT_KIND:
+                continue
             module_name = self._get_module_name(node.scopeName())
             if module_name in self.leaf_modules:
                 module_to_nodes[module_name].append(node)
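The rebuilt mapping lets one debug name be consumed by several nodes (hence defaultdict(list)), while the assert above still enforces a single producer per output. A toy version with hypothetical node/name pairs:

    from collections import defaultdict

    produces = [('conv1', ['out.1']), ('relu1', ['out.2'])]
    consumes = [('relu1', ['out.1']), ('cat1', ['out.1', 'out.2'])]

    output_to_node = defaultdict(list)
    for node, outs in produces:
        for name in outs:
            output_to_node[name].append(node)
            assert len(output_to_node[name]) <= 1, \
                "One output cannot be generated by multiple nodes %s" % name

    input_to_node = defaultdict(list)
    for node, ins in consumes:
        for name in ins:
            input_to_node[name].append(node)

    assert input_to_node['out.1'] == ['relu1', 'cat1']  # one value, two consumers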

View file

@@ -36,9 +36,11 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
     # this traced model.
     if traced is None:
         assert model is not None and dummy_input is not None
-        with torch.onnx.set_training(model, False):
-            # We need to trace the model in this way, else it will have problems
-            traced = torch.jit.trace(model, dummy_input)
+        training = model.training
+        model.eval()
+        # We need to trace the model in eval mode
+        traced = torch.jit.trace(model, dummy_input)
+        model.train(training)

     fix_group_mask = GroupMaskConflict(masks, model, dummy_input, traced)
     masks = fix_group_mask.fix_mask()
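The same save/eval/restore dance matters here because tracing a model left in train mode can bake training-time behavior (e.g. batch-norm batch statistics) into the graph. A runnable illustration of why the caller's mode is saved and restored (the tiny model is hypothetical):

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    model.train()
    training = model.training            # True: remember the caller's mode
    model.eval()                         # batch norm uses running stats while tracing
    traced = torch.jit.trace(model, torch.randn(1, 3, 8, 8))
    model.train(training)                # caller gets the model back in train mode
    assert model.training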

View file

@@ -34,7 +34,7 @@ jobs:
       set -e
       sudo apt-get install -y pandoc
       python3 -m pip install -U --upgrade pygments
-      python3 -m pip install -U torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+      python3 -m pip install -U torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
       python3 -m pip install -U tensorflow==2.3.1
       python3 -m pip install -U keras==2.4.2
       python3 -m pip install -U gym onnx peewee thop
@@ -96,7 +96,7 @@ jobs:
   - script: |
       set -e
-      python3 -m pip install -U torch==1.3.1+cpu torchvision==0.4.2+cpu -f https://download.pytorch.org/whl/torch_stable.html
+      python3 -m pip install -U torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
       python3 -m pip install -U tensorflow==1.15.2
       python3 -m pip install -U keras==2.1.6
       python3 -m pip install -U gym onnx peewee
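If the pins ever drift, a quick sanity check on the agent confirms which CPU wheels were actually resolved (an ad-hoc check, not part of the pipeline definition):

    import torch
    import torchvision

    print(torch.__version__)        # e.g. '1.7.0+cpu' on the first job
    print(torchvision.__version__)  # the matching torchvision build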

View file

@@ -61,7 +61,6 @@ channel_dependency_ground_truth = {
 unittest.TestLoader.sortTestMethodsUsing = None

-@unittest.skipIf(torch.__version__ >= '1.6.0', 'not supported')
 class AnalysisUtilsTest(TestCase):
     @unittest.skipIf(torch.__version__ < "1.3.0", "not supported")
     def test_channel_dependency(self):

View file

@@ -47,7 +47,6 @@ def generate_random_sparsity_v2(model):
     return cfg_list

-@unittest.skipIf(torch.__version__ >= '1.6.0', 'not supported')
 class DependencyawareTest(TestCase):
     @unittest.skipIf(torch.__version__ < "1.3.0", "not supported")
     def test_dependency_aware_pruning(self):

View file

@@ -177,7 +177,6 @@ def channel_prune(model):
     pruner.compress()
     pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE)

-@unittest.skipIf(torch.__version__ >= '1.6.0', 'not supported')
 class SpeedupTestCase(TestCase):
     def test_speedup_vgg16(self):
         prune_model_l1(vgg16())

View file

@@ -264,7 +264,6 @@ class SimpleDataset:
     def __len__(self):
         return 1000

-@unittest.skipIf(torch.__version__ >= '1.6.0', 'not supported')
 class PrunerTestCase(TestCase):
     def test_pruners(self):
         pruners_test(bias=True)
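All four test files drop the same class-level guard, presumably because with torch upgraded past 1.6 it would have skipped each suite entirely. The version-gating pattern itself, kept on individual methods, looks like this (toy test, illustrative threshold):

    import unittest
    import torch

    class ExampleTest(unittest.TestCase):
        @unittest.skipIf(torch.__version__ < "1.3.0", "not supported")
        def test_something(self):
            self.assertTrue(True)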