Fix issue in single-GPU mode
Parent: 8174df90ff
Commit: 29c7d2078b
superbench/benchmarks/model_benchmarks/__init__.py

@@ -8,5 +8,5 @@ from superbench.benchmarks.model_benchmarks.pytorch_bert import PytorchBERT
 from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import PytorchGPT2
 from superbench.benchmarks.model_benchmarks.pytorch_cnn import PytorchCNN
 from superbench.benchmarks.model_benchmarks.pytorch_lstm import PytorchLSTM
-
-__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM']
+#import superbench.benchmarks.model_benchmarks.sync_hooks as sync_hooks
+__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'sync_hooks']
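A caveat on the hunk above: __all__ only affects star-imports, and every name it lists must actually be bound in the module when the star-import runs. With the sync_hooks import left commented out, exporting 'sync_hooks' breaks star-imports of the package. A minimal illustration, with a hypothetical module name:

# mini_module.py -- hypothetical, mirroring the state after this commit
#import sync_hooks                 # the binding stays commented out...
__all__ = ['sync_hooks']           # ...but the name is still exported

# client.py
# from mini_module import *       # AttributeError: module 'mini_module'
#                                 # has no attribute 'sync_hooks'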

superbench/benchmarks/model_benchmarks/pytorch_base.py

@@ -159,10 +159,10 @@ class PytorchBase(ModelBenchmark):
             self._model = torch.nn.parallel.DistributedDataParallel(
                 self._model, device_ids=[self._local_rank], output_device=self._local_rank, bucket_cap_mb=81920, broadcast_buffers=False
             )
-        from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook
-        self._model.register_comm_hook(None, noop_hook)
-        #from superbench.benchmarks.model_benchmarks.sync_hooks import noop_barrier_hook
-        #self._model.register_comm_hook(None, noop_barrier_hook)
+            from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook
+            self._model.register_comm_hook(None, noop_hook)
+            #from superbench.benchmarks.model_benchmarks.sync_hooks import noop_barrier_hook
+            #self._model.register_comm_hook(None, noop_barrier_hook)
         if self._optimizer_type == Optimizer.SGD:
             self._optimizer = torch.optim.SGD(
                 self._model.parameters(), lr=1e-5, momentum=0.9, weight_decay=1e-4, nesterov=True
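For context on the hunk above: noop_hook is PyTorch's debugging communication hook; it completes each gradient bucket's future immediately without performing the allreduce, which isolates compute time from communication time. register_comm_hook exists only on a DDP-wrapped module, which is presumably why the registration moves inside the DDP branch here. A minimal sketch of the same pattern, assuming a use_ddp flag and a function name of my own (not from the repo):

import torch
from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook


def wrap_model(model, local_rank, use_ddp):
    """Optionally wrap with DDP and disable gradient communication."""
    if use_ddp:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False
        )
        # noop_hook hands each bucket's gradients back unchanged and skips the
        # allreduce; calling register_comm_hook on a plain nn.Module would
        # raise AttributeError, hence the guard.
        model.register_comm_hook(None, noop_hook)
    return model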

superbench/benchmarks/model_benchmarks/pytorch_cnn.py

@@ -11,7 +11,7 @@ from superbench.benchmarks import BenchmarkRegistry, Precision
 from superbench.benchmarks.model_benchmarks.model_base import Optimizer
 from superbench.benchmarks.model_benchmarks.pytorch_base import PytorchBase
 from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDataset
-from superbench.benchmarks import Framework, ReturnCode, DistributedBackend
+from superbench.benchmarks import Framework, ReturnCode, DistributedBackend, DistributedImpl

 def _keep_BatchNorm_as_float(module):
     if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
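The context lines above show the start of a _keep_BatchNorm_as_float helper. The usual motivation for this pattern is that BatchNorm running statistics are numerically fragile in fp16, so when a model is cast to half precision the norm layers are kept in float32. The body below is a sketch inferred from the name and first line, not the repo's actual implementation:

import torch


def _keep_BatchNorm_as_float(module):
    """Recursively cast BatchNorm layers back to float32 after model.half()."""
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module.float()
    for child in module.children():
        _keep_BatchNorm_as_float(child)
    return module


# Hypothetical usage: cast the whole model to fp16, then restore the norms.
# model = _keep_BatchNorm_as_float(model.half())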

@@ -110,7 +110,8 @@ class PytorchCNN(PytorchBase):
             loss.backward()
             self._optimizer.step()
             end = self._timer()
-            torch.distributed.barrier()
+            if self._args.distributed_impl == DistributedImpl.DDP:
+                torch.distributed.barrier()
             curr_step += 1
             if curr_step > self._args.num_warmup:
                 # Save the step time of every training/inference step, unit is millisecond.
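This guard is the heart of the single-GPU fix: torch.distributed.barrier() requires an initialized process group, so calling it unconditionally raises a RuntimeError on a non-distributed run, while under DDP it keeps per-step timing honest by synchronizing all ranks between steps. A standalone sketch of the same timing pattern, with illustrative names (not the benchmark's actual code):

import time

import torch.distributed as dist


def timed_train_step(model, optimizer, loss_fn, sample, target, use_ddp):
    """Run one training step and return its duration in milliseconds."""
    start = time.time()
    optimizer.zero_grad()
    loss = loss_fn(model(sample), target)
    loss.backward()
    optimizer.step()
    end = time.time()
    # Synchronize ranks only when a process group actually exists; on a
    # single-GPU run dist.barrier() would raise because no group is set up.
    if use_ddp and dist.is_initialized():
        dist.barrier()
    return (end - start) * 1000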