update ddp example (#113)
Parent: 8f47462be4
Commit: c1e0602f41
README.md: 82 lines changed
@@ -123,48 +123,64 @@ Module metric usage remains the same when using multiple GPUs or multiple nodes.

The updated DDP example:

``` python
import os
import torch
from torch import nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import torchmetrics

def metric_ddp(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # initialize metric
    metric = torchmetrics.Accuracy()

    # define a model and append your metric to it
    # this allows metric states to be placed on correct accelerators when
    # .to(device) is called on the model
    model = nn.Linear(10, 10)
    model.metric = metric
    model = model.to(rank)

    # initialize DDP
    model = DDP(model, device_ids=[rank])

    n_epochs = 5
    # this shows iteration over multiple training epochs
    for n in range(n_epochs):

        # this will be replaced by a DataLoader with a DistributedSampler
        n_batches = 10
        for i in range(n_batches):
            # simulate a classification problem
            preds = torch.randn(10, 5).softmax(dim=-1)
            target = torch.randint(5, (10,))

            # metric on current batch
            acc = metric(preds, target)
            if rank == 0:  # print only for rank 0
                print(f"Accuracy on batch {i}: {acc}")

        # metric on all batches and all accelerators using custom accumulation
        # accuracy is the same across both accelerators
        acc = metric.compute()
        print(f"Accuracy on all data: {acc}, accelerator rank: {rank}")

        # resetting internal state so that the metric is ready for new data
        metric.reset()

    # cleanup
    dist.destroy_process_group()

world_size = 2  # number of gpus to parallelize over
mp.spawn(metric_ddp, args=(world_size,), nprocs=world_size, join=True)
```
</details>
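The `# this will be replaced by a DataLoader with a DistributedSampler` comment above marks where real data loading would go. A minimal sketch of that replacement, assuming a hypothetical `TensorDataset` of random inputs and labels (the dataset and `batch_size` below are illustrative, not part of the example); it slots into the body of `metric_ddp`, reusing its `rank`, `world_size`, `metric`, and `n_epochs`:

``` python
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# hypothetical stand-in for a real dataset: 100 samples, 5 classes
dataset = TensorDataset(torch.randn(100, 5).softmax(dim=-1), torch.randint(5, (100,)))

# each rank sees a disjoint shard of the dataset
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
loader = DataLoader(dataset, batch_size=10, sampler=sampler)

for n in range(n_epochs):
    sampler.set_epoch(n)  # reshuffle the shards each epoch
    for i, (preds, target) in enumerate(loader):
        # metric on current batch, as in the loop above
        acc = metric(preds, target)
```

Because torchmetrics synchronizes metric state across processes inside `compute()`, the accumulated accuracy printed at the end of each epoch is still identical on every rank.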
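One more usage note, not part of the diff: when the snippet is saved as a standalone script, the `mp.spawn` call is usually placed behind a main-module guard, since the `spawn` start method re-imports the script in each child process:

``` python
if __name__ == "__main__":
    world_size = 2  # number of gpus to parallelize over
    mp.spawn(metric_ddp, args=(world_size,), nprocs=world_size, join=True)
```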