Add distributed PyTorch training example on CIFAR-10 (#381)

* Add distributed PyTorch training example on CIFAR-10

* run readme

* move argparse section; add a few comments

* fix mistake in contributing

* fix args

Co-authored-by: Cody <54814569+lostmygithubaccount@users.noreply.github.com>
mx-iao 2021-02-17 15:51:15 -08:00 committed by GitHub
Parent 109d9fc7fb
Commit 32c287a518
5 changed files with 395 additions and 1 deletion

.github/workflows/train-pytorch-cifar-distributed-job.yml (vendored, normal file)

@@ -0,0 +1,32 @@
name: train-pytorch-cifar-distributed-job
on:
  schedule:
    - cron: "0 0/2 * * *"
  pull_request:
    branches:
      - main
    paths:
      - workflows/train/pytorch/cifar-distributed/**
      - .github/workflows/train-pytorch-cifar-distributed-job.yml
jobs:
  build:
    runs-on: ubuntu-latest
    steps:
    - name: check out repo
      uses: actions/checkout@v2
    - name: setup python
      uses: actions/setup-python@v2
      with:
        python-version: "3.8"
    - name: pip install
      run: pip install -r requirements.txt
    - name: azure login
      uses: azure/login@v1
      with:
        creds: ${{secrets.AZ_AE_CREDS}}
    - name: install azmlcli
      run: az extension add -n azure-cli-ml -y
    - name: attach to workspace
      run: az ml folder attach -w default -g azureml-examples
    - name: run workflow
      run: python workflows/train/pytorch/cifar-distributed/job.py


@@ -45,7 +45,7 @@ Pull requests (PRs) to this repo require review and approval by the Azure Machin
### Miscellaneous
-- to modify `README.md`, you need to modify `readme.py` and accompanying markdown files other files (`prefix.md` and `suffix.md`)
+- to modify `README.md`, you need to modify `readme.py` and accompanying files (`prefix.md` and `suffix.md`)
- develop on a branch, not a fork, for workflows to run properly
- use an existing environment where possible
- use an existing dataset where possible


@@ -81,6 +81,7 @@ path|status|description
[fastai/mnist/job.py](workflows/train/fastai/mnist/job.py)|[![train-fastai-mnist-job](https://github.com/Azure/azureml-examples/workflows/train-fastai-mnist-job/badge.svg)](https://github.com/Azure/azureml-examples/actions?query=workflow%3Atrain-fastai-mnist-job)|train fastai resnet18 model on mnist data
[fastai/pets/job.py](workflows/train/fastai/pets/job.py)|[![train-fastai-pets-job](https://github.com/Azure/azureml-examples/workflows/train-fastai-pets-job/badge.svg)](https://github.com/Azure/azureml-examples/actions?query=workflow%3Atrain-fastai-pets-job)|train fastai resnet34 model on pets data
[lightgbm/iris/job.py](workflows/train/lightgbm/iris/job.py)|[![train-lightgbm-iris-job](https://github.com/Azure/azureml-examples/workflows/train-lightgbm-iris-job/badge.svg)](https://github.com/Azure/azureml-examples/actions?query=workflow%3Atrain-lightgbm-iris-job)|train a lightgbm model on iris data
[pytorch/cifar-distributed/job.py](workflows/train/pytorch/cifar-distributed/job.py)|[![train-pytorch-cifar-distributed-job](https://github.com/Azure/azureml-examples/workflows/train-pytorch-cifar-distributed-job/badge.svg)](https://github.com/Azure/azureml-examples/actions?query=workflow%3Atrain-pytorch-cifar-distributed-job)|train CNN model on CIFAR-10 dataset with distributed PyTorch
[pytorch/mnist-mlproject/job.py](workflows/train/pytorch/mnist-mlproject/job.py)|[![train-pytorch-mnist-mlproject-job](https://github.com/Azure/azureml-examples/workflows/train-pytorch-mnist-mlproject-job/badge.svg)](https://github.com/Azure/azureml-examples/actions?query=workflow%3Atrain-pytorch-mnist-mlproject-job)|train a pytorch CNN model on mnist data via mlflow mlproject
[pytorch/mnist/job.py](workflows/train/pytorch/mnist/job.py)|[![train-pytorch-mnist-job](https://github.com/Azure/azureml-examples/workflows/train-pytorch-mnist-job/badge.svg)](https://github.com/Azure/azureml-examples/actions?query=workflow%3Atrain-pytorch-mnist-job)|train a pytorch CNN model on mnist data
[scikit-learn/diabetes-mlproject/job.py](workflows/train/scikit-learn/diabetes-mlproject/job.py)|[![train-scikit-learn-diabetes-mlproject-job](https://github.com/Azure/azureml-examples/workflows/train-scikit-learn-diabetes-mlproject-job/badge.svg)](https://github.com/Azure/azureml-examples/actions?query=workflow%3Atrain-scikit-learn-diabetes-mlproject-job)|train sklearn ridge model on diabetes data via mlflow mlproject


@@ -0,0 +1,123 @@
# description: train CNN model on CIFAR-10 dataset with distributed PyTorch
# imports
import os
import urllib.request  # import the submodule explicitly; `import urllib` alone does not guarantee it is loaded
import tarfile
from pathlib import Path
from azureml.core import Workspace
from azureml.core import ScriptRunConfig, Experiment, Environment, Dataset
from azureml.core.runconfig import PyTorchConfiguration
# get workspace
ws = Workspace.from_config()
# get directory containing this script
prefix = Path(__file__).parent
# training script
source_dir = str(prefix.joinpath("src"))
script_name = "train.py"
# azure ml settings
environment_name = "AzureML-PyTorch-1.6-GPU" # using curated environment
experiment_name = "pytorch-cifar10-distributed-example"
compute_name = "gpu-K80-2"
# get environment
env = Environment.get(ws, name=environment_name)
# download and extract cifar-10 data
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filename = "cifar-10-python.tar.gz"
data_root = "cifar-10"
filepath = os.path.join(data_root, filename)
if not os.path.isdir(data_root):
    os.makedirs(data_root, exist_ok=True)
    urllib.request.urlretrieve(url, filepath)
    with tarfile.open(filepath, "r:gz") as tar:
        tar.extractall(path=data_root)
    os.remove(filepath)  # delete tar.gz file after extraction
# create azureml dataset
datastore = ws.get_default_datastore()
dataset = Dataset.File.upload_directory(
    src_dir=data_root, target=(datastore, data_root)
)
# The training script in this example uses native PyTorch distributed training with DistributedDataParallel.
#
# To launch a distributed PyTorch job on Azure ML, you have two options:
# 1) Per-process launch - specify the total # of worker processes (typically one per GPU) you want to run, and
# Azure ML will handle launching each process.
# 2) Per-node launch with torch.distributed.launch - provide the torch.distributed.launch command you want to
# run on each node.
#
# Both options are demonstrated below.
###############################
# Option 1 - per-process launch
###############################
# To use the per-process launch option in which Azure ML will handle launching each of the processes to run
# your training script, create a `PyTorchConfiguration` and specify `node_count` and `process_count`.
# The `process_count` is the total number of processes you want to run for the job; this should typically
# equal the # of GPUs available on each node multiplied by the # of nodes.
#
# Azure ML will set the MASTER_ADDR, MASTER_PORT, NODE_RANK, WORLD_SIZE environment variables on each node, in addition
# to the process-level RANK and LOCAL_RANK environment variables, that are needed for distributed PyTorch training.
# create distributed config
distr_config = PyTorchConfiguration(process_count=4, node_count=2)
# create args
args = ["--data-dir", dataset.as_download(), "--epochs", 25]
# create job config
src = ScriptRunConfig(
    source_directory=source_dir,
    script=script_name,
    arguments=args,
    compute_target=compute_name,
    environment=env,
    distributed_job_config=distr_config,
)
###############################
# Option 2 - per-node launch
###############################
# If you would instead like to use the PyTorch-provided launch utility `torch.distributed.launch` to
# handle launching the worker processes on each node, you can do so as well. Create a
# `PyTorchConfiguration` and specify the `node_count`. You do not need to specify the `process_count`;
# by default Azure ML will launch one process per node to run the `command` you provided.
#
# Provide the launch command to the `command` parameter of ScriptRunConfig. For PyTorch jobs, Azure ML
# will set the MASTER_ADDR, MASTER_PORT, and NODE_RANK environment variables on each node, so you can
# simply reference those environment variables in your command.
#
# Uncomment the code below to configure a job with this method.
"""
# create distributed config
distr_config = PyTorchConfiguration(node_count=2)
# define command
launch_cmd = ["python -m torch.distributed.launch --nproc_per_node 2 --nnodes 2 " \
    "--node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT --use_env " \
    "train.py --data-dir", dataset.as_download(), "--epochs 25"]
# create job config
src = ScriptRunConfig(
    source_directory=source_dir,
    command=launch_cmd,
    compute_target=compute_name,
    environment=env,
    distributed_job_config=distr_config,
)
"""
# submit job
run = Experiment(ws, experiment_name).submit(src)
run.wait_for_completion(show_output=True)
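
Review note: the comments in job.py describe two things that are easy to get wrong, namely that `process_count` is typically GPUs per node times nodes, and that the per-node command references `$NODE_RANK`, `$MASTER_ADDR`, and `$MASTER_PORT`. The sketch below illustrates both with plain Python. The helper names `total_process_count` and `build_launch_command` are hypothetical and are not part of the Azure ML SDK or PyTorch; the flags are the ones already used in the commented-out command above.

```python
# Hypothetical helpers for illustration only; not part of azureml-core or torch.


def total_process_count(gpus_per_node: int, node_count: int) -> int:
    """Total worker processes when running one process per GPU."""
    if gpus_per_node < 1 or node_count < 1:
        raise ValueError("gpus_per_node and node_count must be positive")
    return gpus_per_node * node_count


def build_launch_command(nproc_per_node, nnodes, script, script_args):
    """Assemble the per-node launch command as a single shell string.

    The $NODE_RANK, $MASTER_ADDR and $MASTER_PORT placeholders are left
    unexpanded so the shell on each node substitutes its own values.
    """
    parts = [
        "python -m torch.distributed.launch",
        f"--nproc_per_node {nproc_per_node}",
        f"--nnodes {nnodes}",
        "--node_rank $NODE_RANK",
        "--master_addr $MASTER_ADDR",
        "--master_port $MASTER_PORT",
        "--use_env",
        script,
    ]
    parts.extend(str(arg) for arg in script_args)
    return " ".join(parts)


if __name__ == "__main__":
    # Two nodes with two GPUs each, matching process_count=4, node_count=2 above.
    print(total_process_count(gpus_per_node=2, node_count=2))
    print(build_launch_command(2, 2, "train.py", ["--epochs", 25]))
```

With two GPUs per node and two nodes, the first helper returns 4, which is the `process_count` passed to `PyTorchConfiguration` in Option 1.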


@@ -0,0 +1,238 @@
# Copyright (c) 2017 Facebook, Inc. All rights reserved.
# BSD 3-Clause License
#
# Script adapted from: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py
# ==============================================================================
# imports
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os, argparse
# define network architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 128, 3)
        self.fc1 = nn.Linear(128 * 6 * 6, 120)
        self.dropout = nn.Dropout(p=0.2)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 128 * 6 * 6)
        x = self.dropout(F.relu(self.fc1(x)))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# define functions
def train(train_loader, model, criterion, optimizer, epoch, device, print_freq, rank):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # print statistics
        running_loss += loss.item()
        # print once print_freq mini-batches have accumulated, so the average
        # below divides by the number of batches actually summed
        if i % print_freq == print_freq - 1:
            print(
                "Rank %d: [%d, %5d] loss: %.3f"
                % (rank, epoch + 1, i + 1, running_loss / print_freq)
            )
            running_loss = 0.0
def evaluate(test_loader, model, device):
    classes = (
        "plane",
        "car",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    )
    model.eval()
    correct = 0
    total = 0
    class_correct = list(0.0 for i in range(10))
    class_total = list(0.0 for i in range(10))
    with torch.no_grad():
        for data in test_loader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            c = predicted == labels
            # count every sample in the batch, not only the first 10
            for i in range(labels.size(0)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1
    # print total test set accuracy
    print(
        "Accuracy of the network on the 10000 test images: %d %%"
        % (100 * correct / total)
    )
    # print test accuracy for each of the classes
    for i in range(10):
        print(
            "Accuracy of %5s : %2d %%"
            % (classes[i], 100 * class_correct[i] / class_total[i])
        )
def main(args):
    # get PyTorch environment variables; the defaults let the script also run
    # as a single process when these variables are not set
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    distributed = world_size > 1
    # set device
    if distributed:
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # initialize distributed process group using default env:// method
    if distributed:
        torch.distributed.init_process_group(backend="nccl")
    # define train and test dataset DataLoaders
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    train_set = torchvision.datasets.CIFAR10(
        root=args.data_dir, train=True, download=False, transform=transform
    )
    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        sampler=train_sampler,
    )
    test_set = torchvision.datasets.CIFAR10(
        root=args.data_dir, train=False, download=False, transform=transform
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers
    )
    model = Net().to(device)
    # wrap model with DDP
    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank
        )
    # define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(), lr=args.learning_rate, momentum=args.momentum
    )
    # train the model
    for epoch in range(args.epochs):
        print("Rank %d: Starting epoch %d" % (rank, epoch))
        if distributed:
            # reseed the sampler so each epoch uses a different shuffle
            train_sampler.set_epoch(epoch)
        model.train()
        train(
            train_loader,
            model,
            criterion,
            optimizer,
            epoch,
            device,
            args.print_freq,
            rank,
        )
    print("Rank %d: Finished Training" % (rank))
    # only one process saves the model and runs evaluation
    if not distributed or rank == 0:
        os.makedirs(args.output_dir, exist_ok=True)
        model_path = os.path.join(args.output_dir, "cifar_net.pt")
        torch.save(model.state_dict(), model_path)
        # evaluate on full test dataset
        evaluate(test_loader, model, device)
# run script
if __name__ == "__main__":
    # setup argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-dir", type=str, help="directory containing CIFAR-10 dataset"
    )
    parser.add_argument("--epochs", default=10, type=int, help="number of epochs")
    parser.add_argument(
        "--batch-size",
        default=16,
        type=int,
        help="mini batch size for each gpu/process",
    )
    parser.add_argument(
        "--workers",
        default=2,
        type=int,
        help="number of data loading workers for each gpu/process",
    )
    parser.add_argument(
        "--learning-rate", default=0.001, type=float, help="learning rate"
    )
    parser.add_argument("--momentum", default=0.9, type=float, help="momentum")
    parser.add_argument(
        "--output-dir", default="outputs", type=str, help="directory to save model to"
    )
    parser.add_argument(
        "--print-freq",
        default=200,
        type=int,
        help="frequency of printing training statistics",
    )
    args = parser.parse_args()
    # call main function
    main(args)
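
Review note: train.py relies on `DistributedSampler` so that each rank trains on a different slice of CIFAR-10. The sketch below is a simplified, torch-free illustration of that idea with shuffling turned off: pad the index list so it divides evenly, then give each rank every `world_size`-th index. The function `shard_indices` is a hypothetical name used only here. It is not PyTorch's implementation, and it assumes the dataset has at least as many samples as there are ranks.

```python
import math


def shard_indices(dataset_len: int, world_size: int, rank: int) -> list:
    """Return the sample indices one rank would see (no shuffling).

    Simplified illustration only: assumes dataset_len >= world_size.
    """
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    if dataset_len < world_size:
        raise ValueError("this sketch assumes dataset_len >= world_size")
    per_rank = math.ceil(dataset_len / world_size)
    total = per_rank * world_size
    indices = list(range(dataset_len))
    # pad by repeating leading indices so every rank gets the same count
    indices += indices[: total - dataset_len]
    # each rank takes a strided slice starting at its own rank
    return indices[rank:total:world_size]


if __name__ == "__main__":
    # 10 samples over 4 ranks: 3 samples per rank, 2 indices repeated as padding
    for r in range(4):
        print(r, shard_indices(10, 4, r))
    # CIFAR-10 training set (50000 images) over the 4 processes used in job.py
    print(len(shard_indices(50000, 4, 0)))
```

In the real sampler the index list is also shuffled with a seed derived from the epoch, which is why train.py calls `train_sampler.set_epoch(epoch)` before each epoch.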