fix leaky model for model_stats
This commit is contained in:
Parent: cbf1594f6b
Commit: 9299eb8d3a
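Summary of the change: the old ModelHook collected statistics by overwriting `__class__.__call__` on every leaf module's class and registering stat buffers on the modules themselves, and neither change was ever undone, so analyzing a model permanently altered it (the "leak" in the commit title). The `analyze` function added below attaches ordinary per-module forward pre/post hooks and removes them in a `finally` block, leaving the model untouched. The memory computations also now return bytes (via `element_size()`) instead of element counts, and `ModelStats` no longer deep-copies the model by default.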
@@ -0,0 +1,164 @@
+import time
+from collections import OrderedDict
+from typing import Dict, Sequence
+import functools
+import itertools
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from .compute_madd import compute_madd
+from .compute_flops import compute_flops
+from .compute_memory import compute_memory
+from .stat_tree import StatTree, StatNode
+from .reporter import report_format
+
+
+class ModuleStats:
+    def __init__(self, name) -> None:
+        self.name = name
+        self.start_time = 0.0
+        self.end_time = 0.0
+        self.inference_memory = 0
+        self.input_shape: Sequence[int] = []
+        self.output_shape: Sequence[int] = []
+        self.MAdd = 0
+        self.duration = 0.0
+        self.Flops = 0
+        self.Memory = 0
+        self.parameter_quantity = 0
+        self.done = False
+
+    def print_report(self, collected_nodes):
+        report = report_format(collected_nodes)
+        print(report)
+
+
+def analyze(model: nn.Module, input_size, query_granularity: int):
+    assert isinstance(model, nn.Module)
+    assert isinstance(input_size, (list, tuple))
+
+    pre_hooks, post_hooks = [], []
+    stats: OrderedDict[str, ModuleStats] = OrderedDict()
+
+    try:
+        _for_leaf(model, _register_hooks, pre_hooks, post_hooks, stats)
+
+        x = torch.rand(*input_size)  # random input to trigger one forward pass
+        x = x.to(next(model.parameters()).device)
+        model.eval()
+        model(x)
+
+        stat_tree = _convert_leaf_modules_to_stat_tree(stats)
+
+        return stat_tree.get_collected_stat_nodes(query_granularity)
+    finally:
+        for stat in stats.values():
+            stat.done = True
+        for hook in itertools.chain(pre_hooks, post_hooks):
+            hook.remove()
+
+
+def _for_leaf(model, fn, *args):
+    for name, module in model.named_modules():
+        if len(list(module.children())) == 0:
+            fn(name, module, *args)
+
+
+def _register_hooks(name: str, module: nn.Module, pre_hooks, post_hooks, stats):
+    assert isinstance(module, nn.Module) and len(list(module.children())) == 0
+
+    if name in stats:
+        return
+
+    module_stats = ModuleStats(name)
+    stats[name] = module_stats
+
+    post_hook = module.register_forward_hook(
+        functools.partial(_forward_post_hook, module_stats))
+    post_hooks.append(post_hook)
+
+    pre_hook = module.register_forward_pre_hook(
+        functools.partial(_forward_pre_hook, module_stats))
+    pre_hooks.append(pre_hook)
+
+
+def _flatten(x):
+    """Flattens a tree of tensors into a flat list of tensors."""
+    if isinstance(x, torch.Tensor):
+        return [x]
+    if isinstance(x, Sequence):
+        res = []
+        for xi in x:
+            res += _flatten(xi)
+        return res
+    return []
+
+
+def _forward_pre_hook(module_stats: ModuleStats, module: nn.Module, input):
+    assert not module_stats.done
+    module_stats.start_time = time.time()
+
+
+def _forward_post_hook(module_stats: ModuleStats, module: nn.Module, input, output):
+    assert not module_stats.done
+
+    module_stats.end_time = time.time()
+    module_stats.duration = module_stats.end_time - module_stats.start_time
+
+    inputs, outputs = _flatten(input), _flatten(output)
+    module_stats.input_shape = inputs[0].size()
+    module_stats.output_shape = outputs[0].size()
+
+    # iterate through parameters and count num params
+    parameter_quantity = 0
+    for name, p in module.named_parameters():
+        parameter_quantity += (0 if p is None else torch.numel(p.data))
+    module_stats.parameter_quantity = parameter_quantity
+
+    inference_memory = 1
+    for oi in outputs:
+        for s in oi.size():
+            inference_memory *= s
+    # memory += parameters_number  # exclude parameter memory
+    inference_memory = inference_memory * 4 / (1024 ** 2)  # shown as MB unit
+    module_stats.inference_memory = inference_memory
+    module_stats.MAdd = compute_madd(module, inputs, outputs)
+    module_stats.Flops = compute_flops(module, inputs, outputs)
+    module_stats.Memory = compute_memory(module, inputs, outputs)
+
+    return output
+
+
+def get_parent_node(root_node, stat_node_name):
+    assert isinstance(root_node, StatNode)
+
+    node = root_node
+    names = stat_node_name.split('.')
+    for i in range(len(names) - 1):
+        node_name = '.'.join(names[0:i+1])
+        child_index = node.find_child_index(node_name)
+        assert child_index != -1
+        node = node.children[child_index]
+    return node
+
+
+def _convert_leaf_modules_to_stat_tree(leaf_modules):
+    assert isinstance(leaf_modules, OrderedDict)
+
+    create_index = 1
+    root_node = StatNode(name='root', parent=None)
+    for name, module_stats in leaf_modules.items():
+        names = name.split('.')
+        for i in range(len(names)):
+            create_index += 1
+            stat_node_name = '.'.join(names[0:i+1])
+            parent_node = get_parent_node(root_node, stat_node_name)
+            node = StatNode(name=stat_node_name, parent=parent_node)
+            parent_node.add_child(node)
+            if i == len(names) - 1:  # leaf module itself
+                node.input_shape = module_stats.input_shape
+                node.output_shape = module_stats.output_shape
+                node.parameter_quantity = module_stats.parameter_quantity
+                node.inference_memory = module_stats.inference_memory
+                node.MAdd = module_stats.MAdd
+                node.Flops = module_stats.Flops
+                node.duration = module_stats.duration
+                node.Memory = module_stats.Memory
+    return StatTree(root_node)
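For reference, a minimal usage sketch of the hook-based analyzer above. The module path `torchstat.analyzer` is an assumption inferred from the `from .torchstat import analyzer` change later in this commit, and the node attributes follow the assignments in `_convert_leaf_modules_to_stat_tree`:

```python
import torch.nn as nn
from torchstat.analyzer import analyze  # assumed module path for the file above

# Toy model; any nn.Module with at least one parameter works, since
# analyze() reads the target device from next(model.parameters()).
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())

# Hooks are attached, one forward pass runs, and the finally block removes
# every hook, so the model is left exactly as it was before the call.
nodes = analyze(model, input_size=[1, 3, 32, 32], query_granularity=1)
for node in nodes:
    print(node.Flops, node.duration)
```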
@@ -32,7 +32,7 @@ def compute_ReLU_memory(module, inp, out):
     mread = inp.numel()
     mwrite = out.numel()
 
-    return mread, mwrite
+    return mread*inp.element_size(), mwrite*out.element_size()
 
 
 def compute_PReLU_memory(module, inp, out):
@@ -42,7 +42,7 @@ def compute_PReLU_memory(module, inp, out):
     mread = batch_size * (inp[0].numel() + num_params(module))
     mwrite = out.numel()
 
-    return mread, mwrite
+    return mread*inp.element_size(), mwrite*out.element_size()
 
 
 def compute_Conv2d_memory(module, inp, out):
@@ -55,7 +55,8 @@ def compute_Conv2d_memory(module, inp, out):
     # This includes weights with bias if the module contains it.
     mread = batch_size * (inp[0].numel() + num_params(module))
     mwrite = out.numel()
-    return mread, mwrite
+
+    return mread*inp.element_size(), mwrite*out.element_size()
 
 
 def compute_BatchNorm2d_memory(module, inp, out):
@@ -66,7 +67,7 @@ def compute_BatchNorm2d_memory(module, inp, out):
     mread = batch_size * (inp[0].numel() + 2 * in_c)
     mwrite = out.numel()
 
-    return mread, mwrite
+    return mread*inp.element_size(), mwrite*out.element_size()
 
 
 def compute_Linear_memory(module, inp, out):
@@ -79,8 +80,7 @@ def compute_Linear_memory(module, inp, out):
     mread = batch_size * (inp[0].numel() + num_params(module))
     mwrite = out.numel()
 
-    return mread, mwrite
-
+    return mread*inp.element_size(), mwrite*out.element_size()
 
 def compute_Pool2d_memory(module, inp, out):
     assert isinstance(module, (nn.MaxPool2d, nn.AvgPool2d))
@@ -89,4 +89,4 @@ def compute_Pool2d_memory(module, inp, out):
     mread = inp.numel()
    mwrite = out.numel()
 
-    return mread, mwrite
+    return mread*inp.element_size(), mwrite*out.element_size()
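These hunks switch the (mread, mwrite) return values from element counts to bytes. A quick sanity check of the conversion, as a standalone sketch independent of the package:

```python
import torch

inp = torch.rand(1, 3, 224, 224)   # float32 input
assert inp.element_size() == 4     # 4 bytes per float32 element
mread_bytes = inp.numel() * inp.element_size()
print(mread_bytes)                 # 1*3*224*224*4 = 602112 bytes read
```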
@@ -1,122 +0,0 @@
-import time
-from collections import OrderedDict
-from typing import Sequence
-import numpy as np
-import torch
-import torch.nn as nn
-
-from .compute_madd import compute_madd
-from .compute_flops import compute_flops
-from .compute_memory import compute_memory
-
-
-class ModelHook(object):
-    def __init__(self, model, input_size):
-        assert isinstance(model, nn.Module)
-        assert isinstance(input_size, (list, tuple))
-
-        self._model = model
-        self._input_size = input_size
-        self._origin_call = dict()  # sub module call hook
-
-        self._hook_model()
-        x = torch.rand(*self._input_size)  # add module duration time
-        x = x.to(next(model.parameters()).device)
-        self._model.eval()
-        self._model(x)
-
-    @staticmethod
-    def _register_buffer(module):
-        assert isinstance(module, nn.Module)
-
-        if len(list(module.children())) > 0:
-            return
-
-        if hasattr(module, 'parameter_quantity'):
-            return
-
-        # register variables for each module to hold values we will compute
-        module.register_buffer('parameter_quantity', torch.zeros(1).int())
-        module.register_buffer('inference_memory', torch.zeros(1).long())
-        module.register_buffer('input_shape', torch.zeros(3).int())
-        module.register_buffer('output_shape', torch.zeros(3).int())
-        module.register_buffer('MAdd', torch.zeros(1).long())
-        module.register_buffer('duration', torch.zeros(1).float())
-        module.register_buffer('Flops', torch.zeros(1).long())
-        module.register_buffer('Memory', torch.zeros(2).long())
-
-    def _to_seq(self, x):
-        if isinstance(x, torch.Tensor):
-            return [x]
-        if isinstance(x, Sequence):
-            res = []
-            for xi in x:
-                res += self._to_seq(xi)
-            return res
-        return []
-
-    def _sub_module_call_hook(self):
-        def wrap_call(module, *input, **kwargs):
-            assert module.__class__ in self._origin_call
-
-            start = time.time()
-            output = self._origin_call[module.__class__](module, *input, **kwargs)
-            end = time.time()
-            module.duration = torch.from_numpy(
-                np.array([end - start], dtype=np.float32))
-
-            inputs, outputs = self._to_seq(input), self._to_seq(output)
-            module.input_shape = torch.from_numpy(
-                np.array(inputs[0].size(), dtype=np.int32))
-            module.output_shape = torch.from_numpy(
-                np.array(outputs[0].size(), dtype=np.int32))
-
-            parameter_quantity = 0
-            # iterate through parameters and count num params
-            for name, p in module._parameters.items():
-                parameter_quantity += (0 if p is None else torch.numel(p.data))
-            module.parameter_quantity = torch.from_numpy(
-                np.array([parameter_quantity], dtype=np.long))
-
-            inference_memory = 1
-            for oi in outputs:
-                for s in oi.size():
-                    inference_memory *= s
-            # memory += parameters_number  # exclude parameter memory
-            inference_memory = inference_memory * 4 / (1024 ** 2)  # shown as MB unit
-            module.inference_memory = torch.from_numpy(
-                np.array([inference_memory], dtype=np.float32))
-
-            madd = compute_madd(module, inputs, outputs)
-            flops = compute_flops(module, inputs, outputs)
-            Memory = compute_memory(module, inputs, outputs)
-
-            module.MAdd = torch.from_numpy(
-                np.array([madd], dtype=np.int64))
-            module.Flops = torch.from_numpy(
-                np.array([flops], dtype=np.int64))
-            Memory = np.array(Memory, dtype=np.int32) * \
-                sum(oi.cpu().detach().numpy().itemsize for oi in outputs)
-            module.Memory = torch.from_numpy(Memory)
-
-            return output
-
-        for module in self._model.modules():
-            if len(list(module.children())) == 0 and module.__class__ not in self._origin_call:
-                self._origin_call[module.__class__] = module.__class__.__call__
-                module.__class__.__call__ = wrap_call
-
-    def _hook_model(self):
-        self._model.apply(self._register_buffer)
-        self._sub_module_call_hook()
-
-    @staticmethod
-    def _retrieve_leaf_modules(model):
-        leaf_modules = []
-        for name, m in model.named_modules():
-            if len(list(m.children())) == 0:
-                leaf_modules.append((name, m))
-        return leaf_modules
-
-    def retrieve_leaf_modules(self):
-        return OrderedDict(self._retrieve_leaf_modules(self._model))
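Why the deleted ModelHook was leaky: `wrap_call` was installed on `module.__class__.__call__`, i.e. on the class itself, so it affected every instance of that class process-wide and was never removed. A minimal sketch of the effect (hypothetical demonstration, not part of the commit):

```python
import torch
import torch.nn as nn

# Grab the original (inherited) __call__ before patching.
original_call = nn.ReLU.__call__

def wrap_call(module, *inputs, **kwargs):
    # Stand-in for ModelHook's timing wrapper.
    return original_call(module, *inputs, **kwargs)

# Patch the *class*, exactly as _sub_module_call_hook did.
nn.ReLU.__call__ = wrap_call

analyzed = nn.ReLU()
unrelated = nn.ReLU()  # never handed to ModelHook, yet also wrapped now
assert unrelated(torch.rand(2)) is not None  # routed through wrap_call too

nn.ReLU.__call__ = original_call  # the cleanup the old code never performed
```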
@@ -1,72 +0,0 @@
-import torch
-import torch.nn as nn
-from .model_hook import ModelHook
-from collections import OrderedDict
-from .stat_tree import StatTree, StatNode
-from .reporter import report_format
-
-
-def get_parent_node(root_node, stat_node_name):
-    assert isinstance(root_node, StatNode)
-
-    node = root_node
-    names = stat_node_name.split('.')
-    for i in range(len(names) - 1):
-        node_name = '.'.join(names[0:i+1])
-        child_index = node.find_child_index(node_name)
-        assert child_index != -1
-        node = node.children[child_index]
-    return node
-
-
-def convert_leaf_modules_to_stat_tree(leaf_modules):
-    assert isinstance(leaf_modules, OrderedDict)
-
-    create_index = 1
-    root_node = StatNode(name='root', parent=None)
-    for leaf_module_name, leaf_module in leaf_modules.items():
-        names = leaf_module_name.split('.')
-        for i in range(len(names)):
-            create_index += 1
-            stat_node_name = '.'.join(names[0:i+1])
-            parent_node = get_parent_node(root_node, stat_node_name)
-            node = StatNode(name=stat_node_name, parent=parent_node)
-            parent_node.add_child(node)
-            if i == len(names) - 1:  # leaf module itself
-                input_shape = leaf_module.input_shape.numpy().tolist()
-                output_shape = leaf_module.output_shape.numpy().tolist()
-                node.input_shape = input_shape
-                node.output_shape = output_shape
-                node.parameter_quantity = leaf_module.parameter_quantity.numpy()[0]
-                node.inference_memory = leaf_module.inference_memory.numpy()[0]
-                node.MAdd = leaf_module.MAdd.numpy()[0]
-                node.Flops = leaf_module.Flops.numpy()[0]
-                node.duration = leaf_module.duration.numpy()[0]
-                node.Memory = leaf_module.Memory.numpy().tolist()
-    return StatTree(root_node)
-
-
-class ModelStat(object):
-    def __init__(self, model, input_size, query_granularity=1):
-        assert isinstance(model, nn.Module)
-        assert isinstance(input_size, (tuple, list))
-        self._model = model
-        self._input_size = input_size
-        self._query_granularity = query_granularity
-
-    def _analyze_model(self):
-        model_hook = ModelHook(self._model, self._input_size)
-        leaf_modules = model_hook.retrieve_leaf_modules()
-        stat_tree = convert_leaf_modules_to_stat_tree(leaf_modules)
-        collected_nodes = stat_tree.get_collected_stat_nodes(self._query_granularity)
-        return collected_nodes
-
-    def show_report(self):
-        collected_nodes = self._analyze_model()
-        report = report_format(collected_nodes)
-        print(report)
-
-
-def stat(model, input_size, query_granularity=1):
-    ms = ModelStat(model, input_size, query_granularity)
-    ms.show_report()
@@ -1,7 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
-from .torchstat import statistics
+from .torchstat import analyzer
 import pandas as pd
 import copy
 
@@ -18,11 +18,10 @@ class LayerStats:
         self.duration = node.duration
 
 class ModelStats(LayerStats):
-    def __init__(self, model, input_shape, clone_model=True) -> None:
+    def __init__(self, model, input_shape, clone_model=False) -> None:
         if clone_model:
             model = copy.deepcopy(model)
-        ms = statistics.ModelStat(model, input_shape, 1)
-        collected_nodes = ms._analyze_model()
+        collected_nodes = analyzer.analyze(model, input_shape, 1)
         self.layer_stats = []
         for node in collected_nodes:
             self.layer_stats.append(LayerStats(node))
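Because `analyze()` now removes its hooks and leaves the model unmodified, the defensive deep copy is no longer needed by default, which is why the `clone_model` default flips from `True` to `False`.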
@@ -0,0 +1,19 @@
+import copy
+import tensorwatch as tw
+import torchvision.models
+import torch
+import time
+
+model = getattr(torchvision.models, 'densenet201')()
+
+def model_timing(model):
+    st = time.time()
+    for _ in range(20):
+        batch = torch.rand([64, 3, 224, 224])
+        y = model(batch)
+    return time.time() - st
+
+print(model_timing(model))
+model_stats = tw.ModelStats(model, [1, 3, 224, 224], clone_model=False)
+print(f'flops={model_stats.Flops}, parameters={model_stats.parameters}, memory={model_stats.inference_memory}')
+print(model_timing(model))
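The script times 20 forward passes before and after collecting statistics on the same model instance; with the leak fixed, the two printed timings should be comparable, confirming that `ModelStats` with `clone_model=False` leaves no instrumentation behind.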