# Mirrored from https://github.com/microsoft/DeepSpeed.git
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

#!/usr/bin/env python
# run the benchmark under timeit (-t), cProfile (-c), line_profiler (-l)
#
# usage:
# ./flatten_bench.py -t
# ./flatten_bench.py -c
# kernprof -l flatten_bench.py -l; python -m line_profiler flatten_bench.py.lprof
|
import argparse
|
|
|
|
import gc
|
|
|
|
import torch
|
|
from torch._utils import _flatten_dense_tensors
|
|
from deepspeed.accelerator import get_accelerator
|
|
from deepspeed.ops.op_builder import UtilsBuilder
|
|
|
|
from apex_C import flatten as flatten_apex
|
|
|
|
util_ops = UtilsBuilder().load()
|
|
flatten = util_ops.flatten
|
|
unflatten = util_ops.unflatten
|
|
|
|
torch.manual_seed(0)
|
|
# emulate a small typical model weights
|
|
x = [
|
|
torch.rand((512, 512)).to(get_accelerator().device_name()),
|
|
torch.rand((512, 1024)).to(get_accelerator().device_name()),
|
|
torch.rand((512, 30000)).to(get_accelerator().device_name())
|
|
]
|
|
t = x * 30
|
|
|
|
# warm up and check that the same output is produced
|
|
flat_py = _flatten_dense_tensors(t)
|
|
flat_cpp = flatten(t)
|
|
flat_apex = flatten_apex(t)
|
|
#numel = flat_cpp.numel()
|
|
assert torch.eq(flat_py, flat_cpp).all(), "both produce the same tensor"
|
|
assert torch.eq(flat_py, flat_apex).all(), "both produce the same tensor"
|
|
|
|
TIMES = 1000
|
|
|
|
|
|
# the programs being tested
|
|
def py():
|
|
for i in range(TIMES):
|
|
flat = _flatten_dense_tensors(t)
|
|
|
|
|
|
def cpp():
|
|
for i in range(TIMES):
|
|
flat = flatten(t)
|
|
|
|
|
|
def apex():
|
|
for i in range(TIMES):
|
|
flat = flatten_apex(t)
|
|
|
|
|
|
#### cProfile ####
|
|
|
|
import cProfile
|
|
|
|
|
|
def cprofileme():
|
|
print("--------------- cProfile -----------------")
|
|
print("py")
|
|
cProfile.run("py()", sort=-1)
|
|
gc.collect()
|
|
get_accelerator().empty_cache()
|
|
print("cpp")
|
|
cProfile.run("cpp()", sort=-1)
|
|
gc.collect()
|
|
get_accelerator().empty_cache()
|
|
print("apex")
|
|
cProfile.run("apex()", sort=-1)
|
|
gc.collect()
|
|
get_accelerator().empty_cache()
|
|
|
|
|
|
#### timeit ####
|
|
|
|
import timeit
|
|
|
|
|
|
def timeme():
|
|
print("--------------- timeit -----------------")
|
|
print(f'py ={timeit.Timer("py()", globals=globals()).timeit(number=1)}')
|
|
gc.collect()
|
|
get_accelerator().empty_cache()
|
|
print(f'cpp ={timeit.Timer("cpp()", globals=globals()).timeit(number=1)}')
|
|
gc.collect()
|
|
get_accelerator().empty_cache()
|
|
print(f'apex={timeit.Timer("apex()", globals=globals()).timeit(number=1)}')
|
|
gc.collect()
|
|
get_accelerator().empty_cache()
|
|
|
|
|
|
#### line_profiler ####
|
|
# this one requires a special way to be called
|
|
# pip install line_profiler
|
|
# kernprof -l flatten_bench.py -l; python -m line_profiler flatten_bench.py.lprof
|
|
|
|
|
|
def line_profileme():
|
|
print("--------------- line_profiler -----------------")
|
|
print("py")
|
|
profile(py)() # noqa: F821 # type: ignore
|
|
gc.collect()
|
|
get_accelerator().empty_cache()
|
|
print("cpp")
|
|
profile(cpp)() # noqa: F821 # type: ignore
|
|
gc.collect()
|
|
get_accelerator().empty_cache()
|
|
print("apex")
|
|
profile(apex)() # noqa: F821 # type: ignore
|
|
gc.collect()
|
|
get_accelerator().empty_cache()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("-l", action='store_true')
|
|
parser.add_argument("-c", action='store_true')
|
|
parser.add_argument("-t", action='store_true')
|
|
args = parser.parse_args()
|
|
if args.l:
|
|
line_profileme()
|
|
elif args.c:
|
|
cprofileme()
|
|
elif args.t:
|
|
timeme()
|