SSL4EO: add script to measure flops of models (#1516)

* SSL4EO: add script to measure flops of models

* Calculate memory requirements too
This commit is contained in:
Adam J. Stewart 2023-08-31 11:50:16 -05:00 коммит произвёл GitHub
Родитель 055daa8978
Коммит effa992bd8
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 31 добавлений и 0 удалений

31
experiments/ssl4eo/flops.py Executable file
Просмотреть файл

@ -0,0 +1,31 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import timm
from deepspeed.accelerator import get_accelerator
from deepspeed.profiling.flops_profiler import get_model_profile
models = ["resnet18", "resnet50", "vit_small_patch16_224"]
num_classes = 14
in_channels = 11
batch_size = 64
patch_size = 224
input_shape = (batch_size, in_channels, patch_size, patch_size)
for model in models:
print(f"Model: {model}")
m = timm.create_model(model, num_classes=num_classes, in_chans=in_channels)
# Calculate memory requirements of model
mem_params = sum([p.nelement() * p.element_size() for p in m.parameters()])
mem_bufs = sum([b.nelement() * b.element_size() for b in m.buffers()])
mem = (mem_params + mem_bufs) / 2**20
print(f"Memory: {mem:.2f} MB")
with get_accelerator().device(0):
get_model_profile(
model=m, input_shape=input_shape, detailed=False, module_depth=0
)