fix leaky model for model_stats

Shital Shah 2020-03-03 17:19:25 -08:00
Parent cbf1594f6b
Commit 9299eb8d3a
6 changed files with 193 additions and 205 deletions
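For context on "leaky": the ModelHook class deleted later in this commit instruments a model by replacing each leaf module class's __call__ and registering stat buffers on the modules, and never undoes either, so the model stays permanently instrumented after a single stats pass. A minimal sketch of that failure mode, and of the removable forward hooks this commit switches to (illustrative only, not part of the diff):

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

# Old approach (leaky): patching the *class* __call__ affects every
# nn.Linear instance anywhere in the process, and nothing restores it.
orig_call = nn.Linear.__call__
def wrap_call(module, *inputs, **kwargs):
    # ... stats would be collected here ...
    return orig_call(module, *inputs, **kwargs)
nn.Linear.__call__ = wrap_call
net(torch.rand(1, 4))
assert nn.Linear.__call__ is wrap_call  # still patched: the "leak"
nn.Linear.__call__ = orig_call          # manual cleanup, easy to forget

# New approach: forward hooks return a handle that can always be removed.
handle = net[0].register_forward_hook(lambda m, inp, out: None)
net(torch.rand(1, 4))
handle.remove()  # the model is back to its original state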

View file

@@ -0,0 +1,164 @@
import time
from collections import OrderedDict
from typing import Dict, Sequence
import functools
import itertools

import numpy as np
import torch
import torch.nn as nn

from .compute_madd import compute_madd
from .compute_flops import compute_flops
from .compute_memory import compute_memory
from .stat_tree import StatTree, StatNode
from .reporter import report_format


class ModuleStats:
    def __init__(self, name) -> None:
        self.name = name
        self.start_time = 0.0
        self.end_time = 0.0
        self.inference_memory = 0
        self.input_shape:Sequence[int] = []
        self.output_shape:Sequence[int] = []
        self.MAdd = 0
        self.duration = 0.0
        self.Flops = 0
        self.Memory = 0
        self.parameter_quantity = 0
        self.done = False


def print_report(collected_nodes):
    report = report_format(collected_nodes)
    print(report)


def analyze(model:nn.Module, input_size, query_granularity:int):
    assert isinstance(model, nn.Module)
    assert isinstance(input_size, (list, tuple))

    pre_hooks, post_hooks = [], []
    stats:OrderedDict[str, ModuleStats] = OrderedDict()

    try:
        # instrument only leaf modules; composite modules are aggregated later
        _for_leaf(model, _register_hooks, pre_hooks, post_hooks, stats)

        x = torch.rand(*input_size)  # dummy input for the timed forward pass
        x = x.to(next(model.parameters()).device)
        model.eval()
        model(x)

        stat_tree = _convert_leaf_modules_to_stat_tree(stats)
        return stat_tree.get_collected_stat_nodes(query_granularity)
    finally:
        # always unhook so the model is left unmodified, even on error
        for stat in stats.values():
            stat.done = True
        for hook in itertools.chain(pre_hooks, post_hooks):
            hook.remove()


def _for_leaf(model, fn, *args):
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:
            fn(name, module, *args)


def _register_hooks(name:str, module:nn.Module, pre_hooks, post_hooks, stats):
    assert isinstance(module, nn.Module) and len(list(module.children())) == 0

    if name in stats:
        return  # don't register twice
    module_stats = ModuleStats(name)
    stats[name] = module_stats

    post_hook = module.register_forward_hook(
        functools.partial(_forward_post_hook, module_stats))
    post_hooks.append(post_hook)

    pre_hook = module.register_forward_pre_hook(
        functools.partial(_forward_pre_hook, module_stats))
    pre_hooks.append(pre_hook)


def _flatten(x):
    """Flattens a nested structure of tensors into a flat list of tensors"""
    if isinstance(x, torch.Tensor):
        return [x]
    if isinstance(x, Sequence):
        res = []
        for xi in x:
            res += _flatten(xi)
        return res
    return []


def _forward_pre_hook(module_stats:ModuleStats, module:nn.Module, input):
    assert not module_stats.done
    module_stats.start_time = time.time()


def _forward_post_hook(module_stats:ModuleStats, module:nn.Module, input, output):
    assert not module_stats.done

    module_stats.end_time = time.time()
    module_stats.duration = module_stats.end_time - module_stats.start_time

    inputs, outputs = _flatten(input), _flatten(output)
    module_stats.input_shape = inputs[0].size()
    module_stats.output_shape = outputs[0].size()

    parameter_quantity = 0
    # iterate through parameters and count num params
    for name, p in module.named_parameters():
        parameter_quantity += (0 if p is None else torch.numel(p.data))
    module_stats.parameter_quantity = parameter_quantity

    inference_memory = 1
    for oi in outputs:
        for s in oi.size():
            inference_memory *= s
    # memory += parameters_number  # exclude parameter memory
    inference_memory = inference_memory * 4 / (1024 ** 2)  # shown as MB unit
    module_stats.inference_memory = inference_memory

    module_stats.MAdd = compute_madd(module, inputs, outputs)
    module_stats.Flops = compute_flops(module, inputs, outputs)
    module_stats.Memory = compute_memory(module, inputs, outputs)
    return output


def get_parent_node(root_node, stat_node_name):
    assert isinstance(root_node, StatNode)

    node = root_node
    names = stat_node_name.split('.')
    for i in range(len(names) - 1):
        node_name = '.'.join(names[0:i+1])
        child_index = node.find_child_index(node_name)
        assert child_index != -1
        node = node.children[child_index]
    return node


def _convert_leaf_modules_to_stat_tree(leaf_modules):
    assert isinstance(leaf_modules, OrderedDict)

    create_index = 1
    root_node = StatNode(name='root', parent=None)
    for name, module_stats in leaf_modules.items():
        names = name.split('.')
        for i in range(len(names)):
            create_index += 1
            stat_node_name = '.'.join(names[0:i+1])
            parent_node = get_parent_node(root_node, stat_node_name)
            node = StatNode(name=stat_node_name, parent=parent_node)
            parent_node.add_child(node)
            if i == len(names) - 1:  # leaf module itself
                node.input_shape = module_stats.input_shape
                node.output_shape = module_stats.output_shape
                node.parameter_quantity = module_stats.parameter_quantity
                node.inference_memory = module_stats.inference_memory
                node.MAdd = module_stats.MAdd
                node.Flops = module_stats.Flops
                node.duration = module_stats.duration
                node.Memory = module_stats.Memory
    return StatTree(root_node)
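A usage sketch for the new entry point; the package path is an assumption inferred from the "from .torchstat import analyzer" line later in this commit:

import torchvision.models
from tensorwatch.model_graph.torchstat import analyzer  # assumed path

model = torchvision.models.resnet18()
nodes = analyzer.analyze(model, [1, 3, 224, 224], query_granularity=1)
analyzer.print_report(nodes)  # the helper defined above
# Whether analyze() returns or raises, its finally block has already
# removed every hook, leaving `model` exactly as it was before the call.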

View file

@@ -32,7 +32,7 @@ def compute_ReLU_memory(module, inp, out):
     mread = inp.numel()
     mwrite = out.numel()
-    return mread, mwrite
+    return mread*inp.element_size(), mwrite*out.element_size()


 def compute_PReLU_memory(module, inp, out):
@@ -42,7 +42,7 @@ def compute_PReLU_memory(module, inp, out):
     mread = batch_size * (inp[0].numel() + num_params(module))
     mwrite = out.numel()
-    return mread, mwrite
+    return mread*inp.element_size(), mwrite*out.element_size()


 def compute_Conv2d_memory(module, inp, out):
@@ -55,7 +55,8 @@ def compute_Conv2d_memory(module, inp, out):
     # This includes weights with bias if the module contains it.
     mread = batch_size * (inp[0].numel() + num_params(module))
     mwrite = out.numel()
-    return mread, mwrite
+    return mread*inp.element_size(), mwrite*out.element_size()


 def compute_BatchNorm2d_memory(module, inp, out):
@@ -66,7 +67,7 @@ def compute_BatchNorm2d_memory(module, inp, out):
     mread = batch_size * (inp[0].numel() + 2 * in_c)
     mwrite = out.numel()
-    return mread, mwrite
+    return mread*inp.element_size(), mwrite*out.element_size()


 def compute_Linear_memory(module, inp, out):
@@ -79,8 +80,7 @@ def compute_Linear_memory(module, inp, out):
     mread = batch_size * (inp[0].numel() + num_params(module))
     mwrite = out.numel()
-    return mread, mwrite
+    return mread*inp.element_size(), mwrite*out.element_size()


 def compute_Pool2d_memory(module, inp, out):
     assert isinstance(module, (nn.MaxPool2d, nn.AvgPool2d))
@@ -89,4 +89,4 @@ def compute_Pool2d_memory(module, inp, out):
     mread = inp.numel()
     mwrite = out.numel()
-    return mread, mwrite
+    return mread*inp.element_size(), mwrite*out.element_size()
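The change is the same in every function of this file: the old code returned element counts, while the new code multiplies by Tensor.element_size() (bytes per element, e.g. 4 for float32), so reads and writes are reported in bytes. A quick illustration:

import torch

t = torch.rand(2, 3)                 # float32
print(t.numel())                     # 6 elements
print(t.element_size())              # 4 bytes per element
print(t.numel() * t.element_size())  # 24 bytes actually read or written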

View file

@@ -1,122 +0,0 @@
import time
from collections import OrderedDict
from typing import Sequence

import numpy as np
import torch
import torch.nn as nn

from .compute_madd import compute_madd
from .compute_flops import compute_flops
from .compute_memory import compute_memory


class ModelHook(object):
    def __init__(self, model, input_size):
        assert isinstance(model, nn.Module)
        assert isinstance(input_size, (list, tuple))

        self._model = model
        self._input_size = input_size
        self._origin_call = dict()  # sub module call hook

        self._hook_model()
        x = torch.rand(*self._input_size)  # add module duration time
        x = x.to(next(model.parameters()).device)
        self._model.eval()
        self._model(x)

    @staticmethod
    def _register_buffer(module):
        assert isinstance(module, nn.Module)

        if len(list(module.children())) > 0:
            return
        if hasattr(module, 'parameter_quantity'):
            return

        # register variables for each module to hold values we will compute
        module.register_buffer('parameter_quantity', torch.zeros(1).int())
        module.register_buffer('inference_memory', torch.zeros(1).long())
        module.register_buffer('input_shape', torch.zeros(3).int())
        module.register_buffer('output_shape', torch.zeros(3).int())
        module.register_buffer('MAdd', torch.zeros(1).long())
        module.register_buffer('duration', torch.zeros(1).float())
        module.register_buffer('Flops', torch.zeros(1).long())
        module.register_buffer('Memory', torch.zeros(2).long())

    def _to_seq(self, x):
        if isinstance(x, torch.Tensor):
            return [x]
        if isinstance(x, Sequence):
            res = []
            for xi in x:
                res += self._to_seq(xi)
            return res
        return []

    def _sub_module_call_hook(self):
        def wrap_call(module, *input, **kwargs):
            assert module.__class__ in self._origin_call

            start = time.time()
            output = self._origin_call[module.__class__](module, *input, **kwargs)
            end = time.time()
            module.duration = torch.from_numpy(
                np.array([end - start], dtype=np.float32))

            inputs, outputs = self._to_seq(input), self._to_seq(output)
            module.input_shape = torch.from_numpy(
                np.array(inputs[0].size(), dtype=np.int32))
            module.output_shape = torch.from_numpy(
                np.array(outputs[0].size(), dtype=np.int32))

            parameter_quantity = 0
            # iterate through parameters and count num params
            for name, p in module._parameters.items():
                parameter_quantity += (0 if p is None else torch.numel(p.data))
            module.parameter_quantity = torch.from_numpy(
                np.array([parameter_quantity], dtype=np.long))

            inference_memory = 1
            for oi in outputs:
                for s in oi.size():
                    inference_memory *= s
            # memory += parameters_number  # exclude parameter memory
            inference_memory = inference_memory * 4 / (1024 ** 2)  # shown as MB unit
            module.inference_memory = torch.from_numpy(
                np.array([inference_memory], dtype=np.float32))

            madd = compute_madd(module, inputs, outputs)
            flops = compute_flops(module, inputs, outputs)
            Memory = compute_memory(module, inputs, outputs)
            module.MAdd = torch.from_numpy(
                np.array([madd], dtype=np.int64))
            module.Flops = torch.from_numpy(
                np.array([flops], dtype=np.int64))
            Memory = np.array(Memory, dtype=np.int32) * \
                sum(oi.cpu().detach().numpy().itemsize for oi in outputs)
            module.Memory = torch.from_numpy(Memory)

            return output

        for module in self._model.modules():
            if len(list(module.children())) == 0 and module.__class__ not in self._origin_call:
                self._origin_call[module.__class__] = module.__class__.__call__
                module.__class__.__call__ = wrap_call

    def _hook_model(self):
        self._model.apply(self._register_buffer)
        self._sub_module_call_hook()

    @staticmethod
    def _retrieve_leaf_modules(model):
        leaf_modules = []
        for name, m in model.named_modules():
            if len(list(m.children())) == 0:
                leaf_modules.append((name, m))
        return leaf_modules

    def retrieve_leaf_modules(self):
        return OrderedDict(self._retrieve_leaf_modules(self._model))
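Beyond the patched __call__, the buffers registered by _register_buffer above also stayed on the model permanently; registered buffers become part of state_dict(), so an analyzed model would even serialize differently. A small demonstration:

import torch
import torch.nn as nn

m = nn.Linear(2, 2)
m.register_buffer('parameter_quantity', torch.zeros(1).int())
# the stat buffer now travels with the model's weights
print('parameter_quantity' in m.state_dict())  # True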

View file

@@ -1,72 +0,0 @@
import torch
import torch.nn as nn
from .model_hook import ModelHook
from collections import OrderedDict
from .stat_tree import StatTree, StatNode
from .reporter import report_format


def get_parent_node(root_node, stat_node_name):
    assert isinstance(root_node, StatNode)

    node = root_node
    names = stat_node_name.split('.')
    for i in range(len(names) - 1):
        node_name = '.'.join(names[0:i+1])
        child_index = node.find_child_index(node_name)
        assert child_index != -1
        node = node.children[child_index]
    return node


def convert_leaf_modules_to_stat_tree(leaf_modules):
    assert isinstance(leaf_modules, OrderedDict)

    create_index = 1
    root_node = StatNode(name='root', parent=None)
    for leaf_module_name, leaf_module in leaf_modules.items():
        names = leaf_module_name.split('.')
        for i in range(len(names)):
            create_index += 1
            stat_node_name = '.'.join(names[0:i+1])
            parent_node = get_parent_node(root_node, stat_node_name)
            node = StatNode(name=stat_node_name, parent=parent_node)
            parent_node.add_child(node)
            if i == len(names) - 1:  # leaf module itself
                node.input_shape = leaf_module.input_shape.numpy().tolist()
                node.output_shape = leaf_module.output_shape.numpy().tolist()
                node.parameter_quantity = leaf_module.parameter_quantity.numpy()[0]
                node.inference_memory = leaf_module.inference_memory.numpy()[0]
                node.MAdd = leaf_module.MAdd.numpy()[0]
                node.Flops = leaf_module.Flops.numpy()[0]
                node.duration = leaf_module.duration.numpy()[0]
                node.Memory = leaf_module.Memory.numpy().tolist()
    return StatTree(root_node)


class ModelStat(object):
    def __init__(self, model, input_size, query_granularity=1):
        assert isinstance(model, nn.Module)
        assert isinstance(input_size, (tuple, list))
        self._model = model
        self._input_size = input_size
        self._query_granularity = query_granularity

    def _analyze_model(self):
        model_hook = ModelHook(self._model, self._input_size)
        leaf_modules = model_hook.retrieve_leaf_modules()
        stat_tree = convert_leaf_modules_to_stat_tree(leaf_modules)
        collected_nodes = stat_tree.get_collected_stat_nodes(self._query_granularity)
        return collected_nodes

    def show_report(self):
        collected_nodes = self._analyze_model()
        report = report_format(collected_nodes)
        print(report)


def stat(model, input_size, query_granularity=1):
    ms = ModelStat(model, input_size, query_granularity)
    ms.show_report()

View file

@@ -1,7 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from .torchstat import statistics
+from .torchstat import analyzer

 import pandas as pd
 import copy
@@ -18,11 +18,10 @@ class LayerStats:
         self.duration = node.duration

 class ModelStats(LayerStats):
-    def __init__(self, model, input_shape, clone_model=True) -> None:
+    def __init__(self, model, input_shape, clone_model=False) -> None:
         if clone_model:
             model = copy.deepcopy(model)
-        ms = statistics.ModelStat(model, input_shape, 1)
-        collected_nodes = ms._analyze_model()
+        collected_nodes = analyzer.analyze(model, input_shape, 1)
         self.layer_stats = []
         for node in collected_nodes:
             self.layer_stats.append(LayerStats(node))

View file

@@ -0,0 +1,19 @@
import copy
import tensorwatch as tw
import torchvision.models
import torch
import time

model = getattr(torchvision.models, 'densenet201')()

def model_timing(model):
    st = time.time()
    for _ in range(20):
        batch = torch.rand([64, 3, 224, 224])
        y = model(batch)
    return time.time()-st

print(model_timing(model))  # baseline timing before collecting stats
model_stats = tw.ModelStats(model, [1, 3, 224, 224], clone_model=False)
print(f'flops={model_stats.Flops}, parameters={model_stats.parameters}, memory={model_stats.inference_memory}')
print(model_timing(model))  # should match the baseline: analysis leaves no hooks behind