This commit is contained in:
yukirora 2022-11-10 07:11:21 +00:00
Parent 8174df90ff
Commit 29c7d2078b
3 changed files with 9 additions and 8 deletions

View File

@@ -8,5 +8,5 @@ from superbench.benchmarks.model_benchmarks.pytorch_bert import PytorchBERT
from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import PytorchGPT2
from superbench.benchmarks.model_benchmarks.pytorch_cnn import PytorchCNN
from superbench.benchmarks.model_benchmarks.pytorch_lstm import PytorchLSTM
__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM']
#import superbench.benchmarks.model_benchmarks.sync_hooks as sync_hooks
__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'sync_hooks']
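Note: __all__ in a package __init__.py only controls what a star import of the package exposes; an entry that names a submodule (like the commented-out sync_hooks here) is imported on demand during the star import. A minimal illustrative sketch with a hypothetical package, not SuperBench code:

# pkg/__init__.py -- hypothetical package, for illustration only
from pkg.widgets import Widget   # re-export a class defined in a submodule

# 'Widget' is already an attribute of the package; 'helpers' names a submodule
# that `from pkg import *` will import on demand even without an import above.
__all__ = ['Widget', 'helpers']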

View File

@@ -159,10 +159,10 @@ class PytorchBase(ModelBenchmark):
self._model = torch.nn.parallel.DistributedDataParallel(
self._model, device_ids=[self._local_rank], output_device=self._local_rank, bucket_cap_mb=81920, broadcast_buffers=False
)
from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook
self._model.register_comm_hook(None, noop_hook)
#from superbench.benchmarks.model_benchmarks.sync_hooks import noop_barrier_hook
#self._model.register_comm_hook(None, noop_barrier_hook)
if self._optimizer_type == Optimizer.SGD:
self._optimizer = torch.optim.SGD(
self._model.parameters(), lr=1e-5, momentum=0.9, weight_decay=1e-4, nesterov=True
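Note: registering the built-in noop_hook as the DDP communication hook replaces the gradient all-reduce with a no-op, so the measured step time reflects pure computation without inter-rank communication. A minimal standalone sketch of the same pattern, assuming a process group is already initialized (e.g. via torchrun); the model below is a placeholder, not SuperBench code:

import torch
from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook

def build_ddp_model(local_rank):
    # Placeholder model; any nn.Module is wrapped the same way.
    model = torch.nn.Linear(1024, 1024).cuda(local_rank)
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False
    )
    # noop_hook returns each gradient bucket unchanged in an already-completed
    # future, so backward() performs no cross-rank communication.
    ddp_model.register_comm_hook(None, noop_hook)
    return ddp_model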

View File

@@ -11,7 +11,7 @@ from superbench.benchmarks import BenchmarkRegistry, Precision
from superbench.benchmarks.model_benchmarks.model_base import Optimizer
from superbench.benchmarks.model_benchmarks.pytorch_base import PytorchBase
from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDataset
from superbench.benchmarks import Framework, ReturnCode, DistributedBackend, DistributedImpl
def _keep_BatchNorm_as_float(module):
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
@@ -110,7 +110,8 @@ class PytorchCNN(PytorchBase):
loss.backward()
self._optimizer.step()
end = self._timer()
torch.distributed.barrier()
if self._args.distributed_impl == DistributedImpl.DDP:
torch.distributed.barrier()
curr_step += 1
if curr_step > self._args.num_warmup:
# Save the step time of every training/inference step, unit is millisecond.
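Note: guarding torch.distributed.barrier() with the DistributedImpl.DDP check avoids calling into torch.distributed when no process group exists (single-GPU runs or other distributed implementations), which would otherwise raise a runtime error; when it is called, it keeps all ranks aligned so the recorded per-step times stay comparable. A minimal sketch of the guarded timing loop, using placeholder names rather than SuperBench APIs:

import time
import torch.distributed as dist

def timed_steps(model, optimizer, loss_fn, batches, is_ddp):
    step_times = []
    for data, target in batches:
        start = time.time()
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        loss.backward()
        optimizer.step()
        end = time.time()
        # Synchronize ranks only when torch.distributed has a process group;
        # calling barrier() without one raises an error.
        if is_ddp and dist.is_initialized():
            dist.barrier()
        step_times.append((end - start) * 1000)  # per-step time in milliseconds
    return step_times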