Introducing PyTorch+DirectML Samples (#161)

* Main readme updates.

* Fix typo

* Add squeezenet and resnet50 docs

* Add predict

* Minor updates

* Install scripts

Co-authored-by: Sheil Kumar <sheilk@microsoft.com>
Sheil Kumar 2021-10-21 11:58:05 -07:00 committed by GitHub
Parent: 6681088a94
Commit: aa904f774e
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
19 changed files: 2170 additions and 0 deletions

8
PyTorch/.gitignore vendored Normal file

@@ -0,0 +1,8 @@
/squeezenet/data
.vscode
__pycache__
traces/
!/op_report/data/traces
*.xlsx
data/cifar-10-python
checkpoints

42
PyTorch/README.md Normal file

@@ -0,0 +1,42 @@
# PyTorch with DirectML Samples <!-- omit in toc -->
For detailed instructions on getting started with PyTorch with DirectML, see [GPU accelerated ML training](https://docs.microsoft.com/en-us/windows/ai/directml/gpu-pytorch-windows).
- [Setup](#setup)
- [Samples](#samples)
- [External Links](#external-links)
## Setup
Follow the steps below to get set up with PyTorch on DirectML.
1. Download and install [Python 3.8](https://www.python.org/downloads/release/python-380/).
2. Clone this repo.
3. Install prerequisites
```
pip install -r pytorch\requirements.txt
pip uninstall torch
```
> Note: The torchvision package automatically installs torch==1.8.0 as a dependency, but that package is not needed and collides with the pytorch-directml package, so uninstall torch after installing the requirements.
4. _(optional)_ Run `pip list`. The following packages should be installed:
```
pytorch-directml 1.8.0a0.dev211019
torchvision 0.9.0
```
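5. _(optional)_ Verify that tensor operations run on the DirectML device. The snippet below is a minimal sketch, assuming the pytorch-directml build exposes the `dml` device string used throughout these samples:
```
import torch

# Create a tensor and move it to the DirectML device.
t = torch.tensor([[1.0, 2.0], [3.0, 4.0]]).to("dml")

# Run a simple op on the device and copy the result back to the CPU.
print(torch.add(t, t).to("cpu"))
```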
## Samples
The following sample models are included in this repo to help you get started. Each sample includes both inference and training scripts, and you can either train the models from scratch or use the supplied pre-trained weights.
* [squeezenet - a small image classification model](./squeezenet)
* [resnet50 - an image classification model](./resnet50)
* *more coming soon*
## External Links
* [pytorch-directml PyPI project](https://pypi.org/project/pytorch-directml/)
* [PyTorch homepage](https://pytorch.org/)

69
PyTorch/classification/dataloader_classification.py Normal file

@@ -0,0 +1,69 @@
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose, transforms
import torchvision.models as models
import collections
import matplotlib.pyplot as plt
import argparse
import time
import os
import pathlib
def get_pytorch_root(path):
return pathlib.Path(__file__).parent.parent.resolve()
def get_pytorch_data():
return str(os.path.join(pathlib.Path(__file__).parent.parent.resolve(), 'data'))
def get_data_path(path):
if (os.path.isabs(path)):
return path
else:
return str(os.path.join(get_pytorch_data(), path))
def print_dataloader(dataloader, mode):
for X, y in dataloader:
print("\t{} data X [N, C, H, W]: \n\t\tshape={}, \n\t\tdtype={}".format(mode, X.shape, X.dtype))
print("\t{} data Y: \n\t\tshape={}, \n\t\tdtype={}".format(mode, y.shape, y.dtype))
break
def create_training_data_transform():
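# Random 224x224 crop and horizontal flip for augmentation, then normalization with the standard ImageNet mean/std.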
return transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
def create_training_dataloader(path, batch_size):
path = get_data_path(path)
print('Loading the training dataset from: {}'.format(path))
train_transform = create_training_data_transform()
training_set = datasets.CIFAR10(root=path, train=True, download=False, transform=train_transform)
data_loader = DataLoader(dataset=training_set, batch_size=batch_size, shuffle=True, num_workers=0)
print_dataloader(data_loader, 'Train')
return data_loader
def create_testing_data_transform():
return transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def create_testing_dataloader(path, batch_size):
path = get_data_path(path)
print('Loading the testing dataset from: {}'.format(path))
test_transform = create_testing_data_transform()
test_set = datasets.CIFAR10(root=path, train=False, download=False, transform=test_transform)
data_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False, num_workers=0)
print_dataloader(data_loader, 'Test')
return data_loader

158
PyTorch/classification/test_classification.py Normal file

@@ -0,0 +1,158 @@
import torch
from torch import nn
from torch.utils import data
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose, transforms
import torchvision.models as models
import collections
import matplotlib.pyplot as plt
import argparse
import time
import os
import pathlib
import dataloader_classification
import torch.autograd.profiler as profiler
from PIL import Image
from os.path import exists
def get_checkpoint_folder(model_str, device):
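# Checkpoints are written to PyTorch/checkpoints/<model>/<device>/checkpoint.pth.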
checkpoint_folder = str(os.path.join(pathlib.Path(__file__).parent.parent.resolve(),
'checkpoints', model_str, str(device)))
os.makedirs(checkpoint_folder, exist_ok=True)
return str(os.path.join(checkpoint_folder, 'checkpoint.pth'))
def eval(dataloader, model_str, model, device, loss, highest_accuracy, save_model, trace):
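# Evaluate the model on the test set; when trace is True, profile a single batch and print the operator table instead of computing accuracy.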
size = len(dataloader.dataset)
num_batches = len(dataloader)
# Switch model to evaluation mode
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
X = X.to(device)
y = y.to(device)
# Evaluate the model on the test input
if (trace):
with profiler.profile(record_shapes=True, with_stack=True, profile_memory=True) as prof:
with profiler.record_function("model_inference"):
pred = model(X)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=1000))
break
else:
pred = model(X)
test_loss += loss(pred, y).to("cpu")
correct += (pred.to("cpu").argmax(1) == y.to("cpu")).type(torch.float).sum()
if not trace:
test_loss /= num_batches
correct /= size
if (correct.item() > highest_accuracy):
highest_accuracy = correct.item()
print("current highest_accuracy: ", highest_accuracy)
# save model
if save_model:
state_dict = collections.OrderedDict()
for key in model.state_dict().keys():
state_dict[key] = model.state_dict()[key].to("cpu")
checkpoint = get_checkpoint_folder(model_str, device)
torch.save(state_dict, checkpoint)
print(f"Test Error: \n Accuracy: {(100*correct.item()):>0.1f}%, Avg loss: {test_loss.item():>8f} \n")
return highest_accuracy
def get_model(model_str, device):
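# Build the requested torchvision model with 10 output classes (CIFAR-10) and load the latest checkpoint if one exists.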
if (model_str == 'squeezenet1_1'):
model = models.squeezenet1_1(num_classes=10).to(device)
elif (model_str == 'resnet50'):
model = models.resnet50(num_classes=10).to(device)
else:
raise Exception(f"Model {model_str} is not supported yet!")
checkpoint = get_checkpoint_folder(model_str, device)
if (exists(checkpoint)):
model.load_state_dict(torch.load(checkpoint))
return model
def preprocess(filename, device):
input_image = Image.open(filename)
preprocess_transform = dataloader_classification.create_testing_data_transform()
input_tensor = preprocess_transform(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
input_batch = input_batch.to(device)
return input_batch
def predict(filename, model_str, device):
# Get the model
model = get_model(model_str, device)
model.eval()
# Preprocess input
input_batch = preprocess(filename, device)
# Evaluate
with torch.no_grad():
pred = model(input_batch).to('cpu')
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(pred[0], dim=0)
data_folder = dataloader_classification.get_pytorch_data()
classes_file = str(os.path.join(data_folder, 'imagenet_classes.txt'))
with open(classes_file, "r") as f:
categories = [s.strip() for s in f.readlines()]
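# Note: with num_classes=10, prediction indices fall in [0, 10), so only the first ten entries of imagenet_classes.txt can appear.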
# Show top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):
print(categories[top5_catid[i]], top5_prob[i].item())
def main(path, batch_size, device, model_str, trace):
# Load the dataset
testing_dataloader = dataloader_classification.create_testing_dataloader(path, batch_size)
# Create the device
device = torch.device(device)
# Load the model on the device
start = time.time()
if (model_str == 'squeezenet1_1'):
model = models.squeezenet1_1(num_classes=10).to(device)
elif (model_str == 'resnet50'):
model = models.resnet50(num_classes=10).to(device)
else:
raise Exception(f"Model {model_str} is not supported yet!")
print('Finished moving {} to device: {} in {}s.'.format(model_str, device, time.time() - start))
# Test
highest_accuracy = eval(testing_dataloader,
model_str,
model,
device,
nn.CrossEntropyLoss().to(device),
0,
False,
trace)
if __name__ == "__main__":
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--path", type=str, default="cifar-10-python", help="Path to cifar dataset.")
parser.add_argument('--batch_size', type=int, default=32, metavar='N', help='Batch size to train with.')
parser.add_argument('--device', type=str, default='dml', help='The device to use for training.')
parser.add_argument('--model', type=str, default='squeezenet1_1', help='The model to use.')
parser.add_argument('--trace', type=bool, default=False, help='Trace performance.')
args = parser.parse_args()
main(args.path, args.batch_size, args.device, args.model, args.trace)

132
PyTorch/classification/train_classification.py Normal file

@@ -0,0 +1,132 @@
#!/usr/bin/env python
# Copyright (c) Microsoft Corporation. All rights reserved.
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose, transforms
import torchvision.models as models
import collections
import matplotlib.pyplot as plt
import argparse
import time
import os
import pathlib
import test_classification
import dataloader_classification
import torch.autograd.profiler as profiler
def train(dataloader, model, device, loss, learning_rate, momentum, weight_decay, trace):
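# Train for one epoch; when trace is True, profile a single batch and print the operator table instead.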
size = len(dataloader.dataset)
# Define optimizer
optimizer = torch.optim.SGD(
model.parameters(),
lr=learning_rate,
momentum=momentum,
weight_decay=weight_decay)
optimize_after_batches = 1
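# With optimize_after_batches == 1, the optimizer steps after every batch; a larger value would accumulate gradients across batches.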
for batch, (X, y) in enumerate(dataloader):
X = X.to(device)
y = y.to(device)
if (trace):
with profiler.profile(record_shapes=True, with_stack=True, profile_memory=True) as prof:
with profiler.record_function("model_inference"):
# Compute loss and perform backpropagation
batch_loss = loss(model(X), y)
batch_loss.backward()
if batch % optimize_after_batches == 0:
optimizer.step()
optimizer.zero_grad()
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=1000))
break
else:
# Compute loss and perform backpropagation
batch_loss = loss(model(X), y)
batch_loss.backward()
if batch % optimize_after_batches == 0:
optimizer.step()
optimizer.zero_grad()
if batch % 100 == 0:
batch_loss_cpu, current = batch_loss.to('cpu'), batch * len(X)
print(f"loss: {batch_loss_cpu.item():>7f} [{current:>5d}/{size:>5d}]")
def main(path, batch_size, epochs, learning_rate,
momentum, weight_decay, device, model_str, save_model, trace):
batch_size = 1 if trace else batch_size
epochs = 1 if trace else epochs
# Load the dataset
training_dataloader = dataloader_classification.create_training_dataloader(path, batch_size)
testing_dataloader = dataloader_classification.create_testing_dataloader(path, batch_size)
# Create the device
device = torch.device(device)
# Load the model on the device
start = time.time()
if (model_str == 'squeezenet1_1'):
model = models.squeezenet1_1(num_classes=10).to(device)
elif (model_str == 'resnet50'):
model = models.resnet50(num_classes=10).to(device)
else:
raise Exception(f"Model {model_str} is not supported yet!")
print('Finished moving {} to device: {} in {}s.'.format(model_str, device, time.time() - start))
cross_entropy_loss = nn.CrossEntropyLoss().to(device)
highest_accuracy = 0
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
# Train
train(training_dataloader,
model,
device,
cross_entropy_loss,
learning_rate,
momentum,
weight_decay,
trace)
if not trace:
# Test
highest_accuracy = test_classification.eval(testing_dataloader,
model_str,
model,
device,
cross_entropy_loss,
highest_accuracy,
save_model,
False)
print("Done! with highest_accuracy: ", highest_accuracy)
if __name__ == "__main__":
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--path", type=str, default="cifar-10-python", help="Path to cifar dataset.")
parser.add_argument('--batch_size', type=int, default=32, metavar='N', help='Batch size to train with.')
parser.add_argument('--epochs', type=int, default=50, metavar='N', help='The number of epochs to train for.')
parser.add_argument('--learning_rate', type=float, default=0.001, metavar='LR', help='The learning rate.')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='The percentage of past parameters to store.')
parser.add_argument('--weight_decay', default=0.0001, type=float, help='The parameter to decay weights.')
parser.add_argument('--device', type=str, default='dml', help='The device to use for training.')
parser.add_argument('--model', type=str, default='squeezenet1_1', help='The model to use.')
parser.add_argument('--save_model', action='store_true', help='save model state_dict to file')
parser.add_argument('--trace', type=bool, default=False, help='Trace performance.')
args = parser.parse_args()
main(args.path, args.batch_size, args.epochs, args.learning_rate,
args.momentum, args.weight_decay, args.device, args.model, args.save_model, args.trace)

24
PyTorch/data/cifar.py Normal file

@@ -0,0 +1,24 @@
#!/usr/bin/env python
# Copyright (c) Microsoft Corporation. All rights reserved.
import argparse
import pathlib
import os
import subprocess
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose, transforms
def get_training_path(args):
if (os.path.isabs(args.path)):
return args.path
else:
return str(os.path.join(pathlib.Path(__file__).parent.resolve(), args.path))
if __name__ == "__main__":
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("-path", help="Path to cifar dataset.", default="cifar-10-python")
args = parser.parse_args()
path = get_training_path(args)
datasets.CIFAR10(root=path, download=True)

File diff not shown because it is too large.

2
PyTorch/requirements.txt Normal file

@@ -0,0 +1,2 @@
torchvision==0.9.0
pytorch-directml

263
PyTorch/resnet50/README.md Normal file

@@ -0,0 +1,263 @@
# Resnet50 Model <!-- omit in toc -->
Sample scripts for training the Resnet50 model using PyTorch on DirectML.
These scripts were forked from https://github.com/pytorch/benchmark. The original code is Copyright (c) 2019, pytorch, and is used here under the terms of the BSD 3-Clause License. See [LICENSE](https://github.com/pytorch/benchmark/blob/main/LICENSE) for more information.
The original paper can be found at: https://arxiv.org/abs/1512.03385
- [Setup](#setup)
- [Prepare Data](#prepare-data)
- [Training](#training)
- [Testing](#testing)
- [Predict](#predict)
- [Tracing](#tracing)
- [External Links](#external-links)
## Setup
Install the following prerequisites:
```
pip install -r pytorch\resnet50\requirements.txt
```
## Prepare Data
After installing the PyTorch on DirectML package (see [GPU accelerated ML training](../README.md)), open a console in the repo root and run the setup script to download and convert the data:
```
python pytorch\data\cifar.py
```
Running `cifar.py` should take at least a minute or so, since it downloads the CIFAR-10 dataset. The output should look similar to the following:
```
>python pytorch\data\cifar.py
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to E:\work\dml\PyTorch\data\cifar-10-python\cifar-10-python.tar.gz
Failed download. Trying https -> http instead. Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to E:\work\dml\PyTorch\data\cifar-10-python\cifar-10-python.tar.gz
170499072it [00:32, 5250164.09it/s]
Extracting E:\work\dml\PyTorch\data\cifar-10-python\cifar-10-python.tar.gz to E:\work\dml\PyTorch\data\cifar-10-python
```
## Training
A helper script exists to train Resnet50 with default data, batch size, and so on:
```
python pytorch\resnet50\train.py
```
The first few lines of output should look similar to the following (exact numbers may change):
```
>python pytorch\resnet50\train.py
Loading the training dataset from: E:\work\dml\PyTorch\data\cifar-10-python
Train data X [N, C, H, W]:
shape=torch.Size([32, 3, 224, 224]),
dtype=torch.float32
Train data Y:
shape=torch.Size([32]),
dtype=torch.int64
Loading the testing dataset from: E:\work\dml\PyTorch\data\cifar-10-python
Test data X [N, C, H, W]:
shape=torch.Size([32, 3, 224, 224]),
dtype=torch.float32
Test data Y:
shape=torch.Size([32]),
dtype=torch.int64
Finished moving resnet50 to device: dml in 0.2560007572174072s.
Epoch 1
-------------------------------
loss: 2.309573 [ 0/50000]
loss: 2.324099 [ 3200/50000]
loss: 2.297763 [ 6400/50000]
loss: 2.292575 [ 9600/50000]
loss: 2.251738 [12800/50000]
loss: 2.183397 [16000/50000]
loss: 2.130508 [19200/50000]
loss: 2.000042 [22400/50000]
loss: 2.183213 [25600/50000]
loss: 2.250935 [28800/50000]
loss: 1.730087 [32000/50000]
loss: 1.999480 [35200/50000]
loss: 1.865684 [38400/50000]
loss: 2.058377 [41600/50000]
loss: 2.059475 [44800/50000]
loss: 2.279521 [48000/50000]
current highest_accuracy: 0.2856000065803528
Test Error:
Accuracy: 28.6%, Avg loss: 1.862064
```
By default, the script runs for 50 epochs with a batch size of 32 and prints the training loss after every 100 batches. The training script can be run multiple times and, when launched with `--save_model`, saves progress after each epoch. The accuracy should increase over time.
> When discrete memory or shared GPU memory is insufficient, consider running the same scripts with a smaller batch size (use the `--batch_size` argument). For example:
```
python pytorch\resnet50\train.py --batch_size 8
```
You can inspect `train.py` (and the real script, `pytorch/classification/train_classification.py`) to see the command line it invokes, or to adjust some of the parameters.
You can save the model for testing by passing the `--save_model` flag. Checkpoints are then saved to `pytorch\checkpoints\<model>\<device>\checkpoint.pth`.
```
python pytorch\resnet50\train.py --save_model
```
## Testing
Once the model is trained and saved, you can test it using the following steps. The test script uses the latest trained model from the checkpoints folder.
```
python pytorch\resnet50\test.py
```
You should see output such as the following:
```
>python pytorch\resnet50\test.py
Loading the testing dataset from: E:\work\dml\PyTorch\data\cifar-10-python
Test data X [N, C, H, W]:
shape=torch.Size([32, 3, 224, 224]),
dtype=torch.float32
Test data Y:
shape=torch.Size([32]),
dtype=torch.int64
Finished moving resnet50 to device: dml in 0.6159994602203369s.
current highest_accuracy: 0.10559999942779541
Test Error:
Accuracy: 10.0%, Avg loss: 2.321213
```
## Predict
Once the model is trained and saved, you can run prediction using the following steps. The predict script uses the latest trained model from the checkpoints folder.
```
python pytorch\resnet50\predict.py --image E:\a.jpeg
```
You should see output such as the following:
```
E:\work\dml>python pytorch\resnet50\predict.py --image E:\a.jpeg
hammerhead 0.35642221570014954
stingray 0.34619468450546265
electric ray 0.09593362361192703
cock 0.07319413870573044
great white shark 0.06555310636758804
```
## Tracing
It may be useful to get a trace during training or evaluation.
```
python pytorch\resnet50\test.py --trace True
python pytorch\resnet50\train.py --trace True
```
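Both commands funnel into the same profiling code in `test_classification.py` and `train_classification.py`: a single batch is wrapped in the PyTorch autograd profiler and the aggregated operator table is printed. A minimal sketch of that pattern, assuming `model` and an input batch `X` are already on the device:
```
import torch.autograd.profiler as profiler

with profiler.profile(record_shapes=True, with_stack=True, profile_memory=True) as prof:
    with profiler.record_function("model_inference"):
        pred = model(X)

# Print aggregated per-operator timings, sorted by total CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=1000))
```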
With default settings, you'll see output like the following:
```
>python pytorch\resnet50\train.py --trace True
Loading the training dataset from: E:\work\dml\PyTorch\data\cifar-10-python
Train data X [N, C, H, W]:
shape=torch.Size([1, 3, 224, 224]),
dtype=torch.float32
Train data Y:
shape=torch.Size([1]),
dtype=torch.int64
Loading the testing dataset from: E:\work\dml\PyTorch\data\cifar-10-python
Test data X [N, C, H, W]:
shape=torch.Size([1, 3, 224, 224]),
dtype=torch.float32
Test data Y:
shape=torch.Size([1]),
dtype=torch.int64
Finished moving resnet50 to device: dml in 0.594947338104248s.
Epoch 1
-------------------------------
------------------------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
------------------------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
model_inference 34.65% 823.161ms 66.84% 1.588s 1.588s -4 b -20 b 1
ThnnConv2DBackward 0.05% 1.119ms 21.18% 503.098ms 9.492ms 0 b 0 b 53
aten::thnn_conv2d_backward 21.04% 499.936ms 21.13% 501.979ms 9.471ms 0 b 0 b 53
Optimizer.step#SGD.step 0.24% 5.683ms 10.84% 257.530ms 257.530ms -4 b -20 b 1
aten::batch_norm 0.09% 2.118ms 8.96% 212.849ms 4.016ms 0 b 0 b 53
aten::_batch_norm_impl_index 0.08% 1.846ms 8.87% 210.731ms 3.976ms 0 b 0 b 53
aten::native_batch_norm 3.82% 90.859ms 8.73% 207.468ms 3.914ms 0 b 0 b 53
aten::add 6.64% 157.698ms 7.77% 184.523ms 862.258us 0 b 0 b 214
aten::empty 5.60% 133.136ms 5.60% 133.136ms 166.005us 60 b 60 b 802
aten::conv2d 0.08% 1.843ms 5.59% 132.890ms 2.507ms 0 b 0 b 53
aten::convolution 0.07% 1.559ms 5.52% 131.047ms 2.473ms 0 b 0 b 53
aten::_convolution 0.22% 5.117ms 5.45% 129.488ms 2.443ms 0 b 0 b 53
aten::_convolution_nogroup 0.08% 1.810ms 5.24% 124.371ms 2.347ms 0 b 0 b 53
aten::thnn_conv2d 0.07% 1.760ms 5.16% 122.562ms 2.312ms 0 b 0 b 53
aten::thnn_conv2d_forward 4.92% 116.862ms 5.08% 120.802ms 2.279ms 0 b 0 b 53
NativeBatchNormBackward 0.05% 1.202ms 4.86% 115.441ms 2.178ms 0 b 0 b 53
aten::native_batch_norm_backward 3.06% 72.769ms 4.81% 114.239ms 2.155ms 0 b 0 b 53
aten::empty_strided 4.68% 111.158ms 4.68% 111.158ms 295.634us 0 b 0 b 376
aten::clone 0.67% 15.835ms 3.07% 73.035ms 453.637us 0 b 0 b 161
aten::empty_like 0.12% 2.741ms 3.00% 71.267ms 334.588us 0 b 0 b 213
aten::add_ 2.92% 69.436ms 2.92% 69.436ms 392.292us 0 b 0 b 177
struct torch::autograd::AccumulateGrad 0.12% 2.960ms 2.62% 62.349ms 387.258us 0 b 0 b 161
aten::new_empty_strided 0.06% 1.337ms 2.10% 49.896ms 309.912us 0 b 0 b 161
AddmmBackward 0.00% 56.400us 1.84% 43.649ms 43.649ms 0 b 0 b 1
aten::mm 1.79% 42.570ms 1.83% 43.489ms 21.745ms 0 b 0 b 2
ReluBackward1 0.02% 394.800us 1.73% 40.983ms 836.398us 0 b 0 b 49
aten::threshold_backward 1.71% 40.589ms 1.71% 40.589ms 828.341us 0 b 0 b 49
aten::copy_ 1.68% 39.820ms 1.68% 39.820ms 82.787us 0 b 0 b 481
aten::to 0.08% 1.928ms 1.13% 26.825ms 506.126us 0 b 0 b 53
aten::log_softmax 0.00% 42.400us 0.82% 19.532ms 19.532ms 0 b 0 b 1
aten::_log_softmax 0.82% 19.489ms 0.82% 19.489ms 19.489ms 0 b 0 b 1
Optimizer.zero_grad#SGD.zero_grad 0.52% 12.294ms 0.80% 19.066ms 19.066ms -4 b -20 b 1
aten::reshape 0.54% 12.869ms 0.78% 18.629ms 49.811us 0 b 0 b 374
aten::nll_loss 0.03% 645.100us 0.56% 13.385ms 13.385ms 0 b 0 b 1
aten::nll_loss_forward 0.53% 12.600ms 0.54% 12.740ms 12.740ms 0 b 0 b 1
aten::relu_ 0.36% 8.556ms 0.36% 8.556ms 174.618us 0 b 0 b 49
aten::linear 0.00% 49.400us 0.31% 7.462ms 7.462ms 0 b 0 b 1
aten::max_pool2d 0.01% 324.600us 0.31% 7.409ms 7.409ms 0 b 0 b 1
aten::addmm 0.29% 6.982ms 0.31% 7.312ms 7.312ms 0 b 0 b 1
aten::max_pool2d_with_indices 0.30% 7.085ms 0.30% 7.085ms 7.085ms 0 b 0 b 1
aten::zero_ 0.29% 6.806ms 0.29% 6.806ms 41.498us 0 b 0 b 164
MaxPool2DWithIndicesBackward 0.00% 30.200us 0.28% 6.579ms 6.579ms 0 b 0 b 1
aten::max_pool2d_with_indices_backward 0.28% 6.548ms 0.28% 6.548ms 6.548ms 0 b 0 b 1
aten::view 0.24% 5.794ms 0.24% 5.794ms 15.451us 0 b 0 b 375
aten::detach 0.13% 3.044ms 0.24% 5.736ms 35.630us 0 b 0 b 161
LogSoftmaxBackward 0.03% 601.300us 0.24% 5.601ms 5.601ms 0 b 0 b 1
AdaptiveAvgPool2DBackward 0.00% 13.700us 0.23% 5.370ms 5.370ms 0 b 0 b 1
aten::_adaptive_avg_pool2d_backward 0.23% 5.357ms 0.23% 5.357ms 5.357ms 0 b 0 b 1
aten::_log_softmax_backward_data 0.21% 5.000ms 0.21% 5.000ms 5.000ms 0 b 0 b 1
aten::ones_like 0.00% 27.700us 0.20% 4.692ms 4.692ms 0 b 0 b 1
aten::fill_ 0.18% 4.363ms 0.18% 4.363ms 4.363ms 0 b 0 b 1
NllLossBackward 0.04% 917.500us 0.15% 3.485ms 3.485ms 0 b 0 b 1
detach 0.11% 2.692ms 0.11% 2.692ms 16.721us 0 b 0 b 161
aten::nll_loss_backward 0.11% 2.556ms 0.11% 2.567ms 2.567ms 0 b 0 b 1
aten::as_strided 0.05% 1.290ms 0.05% 1.290ms 3.402us 0 b 0 b 379
aten::transpose 0.02% 579.900us 0.04% 898.000us 5.476us 0 b 0 b 164
aten::zeros 0.02% 575.100us 0.04% 865.200us 288.400us 12 b 0 b 3
TBackward 0.02% 376.000us 0.02% 398.500us 398.500us 0 b 0 b 1
aten::broadcast_to 0.01% 281.000us 0.01% 329.400us 329.400us 0 b 0 b 1
aten::adaptive_avg_pool2d 0.00% 30.400us 0.01% 321.500us 321.500us 0 b 0 b 1
aten::_adaptive_avg_pool2d 0.01% 291.100us 0.01% 291.100us 291.100us 0 b 0 b 1
aten::t 0.00% 111.200us 0.01% 204.400us 40.880us 0 b 0 b 5
aten::squeeze 0.00% 61.200us 0.00% 95.700us 47.850us 0 b 0 b 2
aten::flatten 0.00% 26.800us 0.00% 61.000us 61.000us 0 b 0 b 1
AddBackward0 0.00% 56.200us 0.00% 56.200us 3.513us 0 b 0 b 16
aten::expand 0.00% 29.300us 0.00% 48.400us 48.400us 0 b 0 b 1
aten::conj 0.00% 21.900us 0.00% 21.900us 10.950us 0 b 0 b 2
ViewBackward 0.00% 8.200us 0.00% 20.400us 20.400us 0 b 0 b 1
------------------------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 2.376s
Done! with highest_accuracy: 0
```
## External Links
- [Original training data (LSVRC 2012)](http://www.image-net.org/challenges/LSVRC/2012/)
- [Alternative training data (CIFAR-10)](https://www.cs.toronto.edu/~kriz/cifar.html)
Alternative implementations:
- [ONNX](https://github.com/onnx/models/tree/master/vision/classification/resnet)

26
PyTorch/resnet50/predict.py Normal file

@@ -0,0 +1,26 @@
#!/usr/bin/env python
# Copyright (c) Microsoft Corporation. All rights reserved.
import argparse
import subprocess
import os
import pathlib
import sys
classification_folder = str(os.path.join(pathlib.Path(__file__).parent.parent.resolve(), 'classification'))
# insert at 1, 0 is the script path (or '' in REPL)
sys.path.insert(1, classification_folder)
from test_classification import predict
def main():
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--image", type=str, help="Image to classify.")
parser.add_argument('--device', type=str, default='dml', help='The device to use for training.')
args = parser.parse_args()
predict(args.image, 'resnet50', args.device)
if __name__ == "__main__":
main()

7
PyTorch/resnet50/requirements.txt Normal file

@@ -0,0 +1,7 @@
pandas
tensorboard
matplotlib
tqdm
pyyaml
opencv-python
wget

31
PyTorch/resnet50/test.py Normal file

@@ -0,0 +1,31 @@
#!/usr/bin/env python
# Copyright (c) Microsoft Corporation. All rights reserved.
import argparse
import subprocess
import os
import pathlib
import sys
classification_folder = str(os.path.join(pathlib.Path(__file__).parent.parent.resolve(), 'classification'))
# insert at 1, 0 is the script path (or '' in REPL)
sys.path.insert(1, classification_folder)
from test_classification import main as test
def main():
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--path", type=str, default="cifar-10-python", help="Path to cifar dataset.")
parser.add_argument('--batch_size', type=int, default=32, metavar='N', help='Batch size to train with.')
parser.add_argument('--device', type=str, default='dml', help='The device to use for training.')
parser.add_argument('--trace', type=bool, default=False, help='Trace performance.')
args = parser.parse_args()
batch_size = 1 if args.trace else args.batch_size
test(args.path, batch_size, args.device, 'resnet50', args.trace)
if __name__ == "__main__":
main()

35
PyTorch/resnet50/train.py Normal file

@@ -0,0 +1,35 @@
#!/usr/bin/env python
# Copyright (c) Microsoft Corporation. All rights reserved.
import argparse
import subprocess
import os
import pathlib
import sys
classification_folder = str(os.path.join(pathlib.Path(__file__).parent.parent.resolve(), 'classification'))
# insert at 1, 0 is the script path (or '' in REPL)
sys.path.insert(1, classification_folder)
from train_classification import main as train
def main():
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--path", type=str, default="cifar-10-python", help="Path to cifar dataset.")
parser.add_argument('--batch_size', type=int, default=32, metavar='N', help='Batch size to train with.')
parser.add_argument('--epochs', type=int, default=50, metavar='N', help='The number of epochs to train for.')
parser.add_argument('--learning_rate', type=float, default=0.001, metavar='LR', help='The learning rate.')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='The percentage of past parameters to store.')
parser.add_argument('--weight_decay', default=0.0001, type=float, help='The parameter to decay weights.')
parser.add_argument('--device', type=str, default='dml', help='The device to use for training.')
parser.add_argument('--save_model', action='store_true', help='Save the model state_dict to file')
parser.add_argument('--trace', type=bool, default=False, help='Trace performance.')
args = parser.parse_args()
train(args.path, args.batch_size, args.epochs, args.learning_rate,
args.momentum, args.weight_decay, args.device, 'resnet50', args.save_model, args.trace)
if __name__ == "__main__":
main()

260
PyTorch/squeezenet/README.md Normal file

@@ -0,0 +1,260 @@
# SqueezeNet Model <!-- omit in toc -->
Sample scripts for training the SqueezeNet model using PyTorch on DirectML.
These scripts were forked from https://github.com/pytorch/benchmark. The original code is Copyright (c) 2019, pytorch, and is used here under the terms of the BSD 3-Clause License. See [LICENSE](https://github.com/pytorch/benchmark/blob/main/LICENSE) for more information.
The original paper can be found at: https://arxiv.org/abs/1602.07360
- [Setup](#setup)
- [Prepare Data](#prepare-data)
- [Training](#training)
- [Testing](#testing)
- [Predict](#predict)
- [Tracing](#tracing)
- [External Links](#external-links)
## Setup
Install the following prerequisites:
```
pip install -r pytorch\squeezenet\requirements.txt
```
## Prepare Data
After installing the PyTorch on DirectML package (see [GPU accelerated ML training](../README.md)), open a console in the repo root and run the setup script to download and convert the data:
```
python pytorch\data\cifar.py
```
Running `cifar.py` should take at least a minute or so, since it downloads the CIFAR-10 dataset. The output should look similar to the following:
```
>python pytorch\data\cifar.py
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to E:\work\dml\PyTorch\data\cifar-10-python\cifar-10-python.tar.gz
Failed download. Trying https -> http instead. Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to E:\work\dml\PyTorch\data\cifar-10-python\cifar-10-python.tar.gz
170499072it [00:32, 5250164.09it/s]
Extracting E:\work\dml\PyTorch\data\cifar-10-python\cifar-10-python.tar.gz to E:\work\dml\PyTorch\data\cifar-10-python
```
## Training
A helper script exists to train SqueezeNet with default data, batch size, and so on:
```
python pytorch\squeezenet\train.py
```
The first few lines of output should look similar to the following (exact numbers may change):
```
>python pytorch\squeezenet\train.py
Loading the training dataset from: E:\work\dml\PyTorch\data\cifar-10-python
Train data X [N, C, H, W]:
shape=torch.Size([32, 3, 224, 224]),
dtype=torch.float32
Train data Y:
shape=torch.Size([32]),
dtype=torch.int64
Loading the testing dataset from: E:\work\dml\PyTorch\data\cifar-10-python
Test data X [N, C, H, W]:
shape=torch.Size([32, 3, 224, 224]),
dtype=torch.float32
Test data Y:
shape=torch.Size([32]),
dtype=torch.int64
Finished moving squeezenet1_1 to device: dml in 0.2560007572174072s.
Epoch 1
-------------------------------
loss: 2.309573 [ 0/50000]
loss: 2.324099 [ 3200/50000]
loss: 2.297763 [ 6400/50000]
loss: 2.292575 [ 9600/50000]
loss: 2.251738 [12800/50000]
loss: 2.183397 [16000/50000]
loss: 2.130508 [19200/50000]
loss: 2.000042 [22400/50000]
loss: 2.183213 [25600/50000]
loss: 2.250935 [28800/50000]
loss: 1.730087 [32000/50000]
loss: 1.999480 [35200/50000]
loss: 1.865684 [38400/50000]
loss: 2.058377 [41600/50000]
loss: 2.059475 [44800/50000]
loss: 2.279521 [48000/50000]
current highest_accuracy: 0.2856000065803528
Test Error:
Accuracy: 28.6%, Avg loss: 1.862064
```
By default, the script runs for 50 epochs with a batch size of 32 and prints the training loss after every 100 batches. The training script can be run multiple times and, when launched with `--save_model`, saves progress after each epoch. The accuracy should increase over time.
> When discrete memory or shared GPU memory is insufficient, consider running the same scripts with a smaller batch size (use the `--batch_size` argument). For example:
```
python pytorch\squeezenet\train.py --batch_size 8
```
You can inspect `train.py` (and the real script, `pytorch/classification/train_classification.py`) to see the command line it invokes, or to adjust some of the parameters.
You can save the model for testing by passing the `--save_model` flag. Checkpoints are then saved to `pytorch\checkpoints\<model>\<device>\checkpoint.pth`.
```
python pytorch\squeezenet\train.py --save_model
```
## Testing
Once the model is trained and saved, you can test it using the following steps. The test script uses the latest trained model from the checkpoints folder.
```
python pytorch\squeezenet\test.py
```
You should see output such as the following:
```
>python pytorch\squeezenet\test.py
Loading the testing dataset from: E:\work\dml\PyTorch\data\cifar-10-python
Test data X [N, C, H, W]:
shape=torch.Size([32, 3, 224, 224]),
dtype=torch.float32
Test data Y:
shape=torch.Size([32]),
dtype=torch.int64
Finished moving squeezenet1_1 to device: dml in 0.22499728202819824s.
current highest_accuracy: 0.10000000149011612
Test Error:
Accuracy: 10.0%, Avg loss: 2.321213
```
## Predict
Once the model is trained and saved, you can run prediction using the following steps. The predict script uses the latest trained model from the checkpoints folder.
```
python pytorch\squeezenet\predict.py --image E:\a.jpeg
```
You should see output such as the following:
```
E:\work\dml>python pytorch\squeezenet\predict.py --image E:\a.jpeg
hammerhead 0.35642221570014954
stingray 0.34619468450546265
electric ray 0.09593362361192703
cock 0.07319413870573044
great white shark 0.06555310636758804
```
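The predict script is a thin wrapper over the shared classification helper, so you can also invoke it from your own Python code. A minimal sketch, assuming it is run from the repo root and using the image path above as a placeholder:
```
import sys

# Make the shared classification helpers importable.
sys.path.insert(1, 'pytorch/classification')

from test_classification import predict

# Classify an image with the latest squeezenet1_1 checkpoint on the DirectML device.
predict(r'E:\a.jpeg', 'squeezenet1_1', 'dml')
```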
## Tracing
It may be useful to get a trace during training or evaluation.
```
python pytorch\squeezenet\test.py --trace True
python pytorch\squeezenet\train.py --trace True
```
With default settings, you'll see output like the following:
```
>python pytorch\squeezenet\train.py --trace True
Loading the training dataset from: E:\work\dml\PyTorch\data\cifar-10-python
Train data X [N, C, H, W]:
shape=torch.Size([1, 3, 224, 224]),
dtype=torch.float32
Train data Y:
shape=torch.Size([1]),
dtype=torch.int64
Loading the testing dataset from: E:\work\dml\PyTorch\data\cifar-10-python
Test data X [N, C, H, W]:
shape=torch.Size([1, 3, 224, 224]),
dtype=torch.float32
Test data Y:
shape=torch.Size([1]),
dtype=torch.int64
Finished moving squeezenet1_1 to device: dml in 0.2282116413116455s.
Epoch 1
-------------------------------
------------------------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
------------------------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
model_inference 33.98% 244.942ms 67.93% 489.574ms 489.574ms -4 b -20 b 1
ThnnConv2DBackward 0.06% 435.700us 21.73% 156.616ms 6.024ms 0 b 0 b 26
aten::thnn_conv2d_backward 21.52% 155.095ms 21.67% 156.180ms 6.007ms 0 b 0 b 26
aten::conv2d 0.15% 1.070ms 13.12% 94.566ms 3.637ms 0 b 0 b 26
aten::convolution 0.12% 877.800us 12.97% 93.496ms 3.596ms 0 b 0 b 26
aten::_convolution 0.14% 975.500us 12.85% 92.618ms 3.562ms 0 b 0 b 26
aten::_convolution_nogroup 0.12% 889.600us 12.71% 91.643ms 3.525ms 0 b 0 b 26
aten::thnn_conv2d 0.12% 858.900us 12.59% 90.753ms 3.491ms 0 b 0 b 26
aten::thnn_conv2d_forward 12.01% 86.566ms 12.47% 89.894ms 3.457ms 0 b 0 b 26
Optimizer.step#SGD.step 0.52% 3.769ms 10.38% 74.808ms 74.808ms -4 b -20 b 1
aten::add 4.57% 32.967ms 4.57% 32.967ms 633.988us 0 b 0 b 52
ReluBackward1 0.03% 219.000us 4.01% 28.888ms 1.111ms 0 b 0 b 26
aten::threshold_backward 3.98% 28.669ms 3.98% 28.669ms 1.103ms 0 b 0 b 26
aten::empty_strided 3.82% 27.552ms 3.82% 27.552ms 257.492us 4 b 4 b 107
struct torch::autograd::AccumulateGrad 0.13% 905.400us 3.19% 22.985ms 442.012us 0 b 0 b 52
aten::clone 0.52% 3.726ms 2.79% 20.118ms 386.875us 0 b 0 b 52
aten::add_ 2.23% 16.077ms 2.23% 16.077ms 309.167us 0 b 0 b 52
aten::new_empty_strided 0.06% 450.100us 2.02% 14.575ms 280.285us 0 b 0 b 52
aten::log_softmax 0.00% 31.800us 1.95% 14.039ms 14.039ms 0 b 0 b 1
aten::_log_softmax 1.94% 14.007ms 1.94% 14.007ms 14.007ms 0 b 0 b 1
aten::copy_ 1.59% 11.450ms 1.59% 11.450ms 107.012us 0 b 0 b 107
aten::nll_loss 0.01% 51.200us 1.52% 10.988ms 10.988ms 0 b 0 b 1
aten::cat 0.06% 439.200us 1.52% 10.964ms 1.370ms 0 b 0 b 8
aten::nll_loss_forward 1.50% 10.779ms 1.52% 10.937ms 10.937ms 0 b 0 b 1
aten::dropout 0.01% 97.400us 1.50% 10.809ms 10.809ms 0 b 0 b 1
aten::_cat 1.46% 10.525ms 1.46% 10.525ms 1.316ms 0 b 0 b 8
aten::max_pool2d 0.02% 143.300us 1.10% 7.919ms 2.640ms 0 b 0 b 3
aten::max_pool2d_with_indices 1.08% 7.776ms 1.08% 7.776ms 2.592ms 0 b 0 b 3
aten::relu_ 0.98% 7.045ms 0.98% 7.045ms 270.969us 0 b 0 b 26
MaxPool2DWithIndicesBackward 0.01% 55.600us 0.87% 6.302ms 2.101ms 0 b 0 b 3
aten::max_pool2d_with_indices_backward 0.87% 6.246ms 0.87% 6.246ms 2.082ms 0 b 0 b 3
aten::adaptive_avg_pool2d 0.01% 43.100us 0.85% 6.109ms 6.109ms 0 b 0 b 1
aten::_adaptive_avg_pool2d 0.84% 6.066ms 0.84% 6.066ms 6.066ms 0 b 0 b 1
aten::as_strided 0.82% 5.932ms 0.82% 5.932ms 26.249us 0 b 0 b 226
aten::div_ 0.57% 4.096ms 0.64% 4.628ms 4.628ms 0 b -4 b 1
LogSoftmaxBackward 0.00% 21.700us 0.64% 4.585ms 4.585ms 0 b 0 b 1
aten::mul 0.64% 4.579ms 0.64% 4.579ms 2.290ms 0 b 0 b 2
aten::_log_softmax_backward_data 0.63% 4.563ms 0.63% 4.563ms 4.563ms 0 b 0 b 1
AdaptiveAvgPool2DBackward 0.00% 13.000us 0.62% 4.496ms 4.496ms 0 b 0 b 1
CatBackward 0.01% 63.400us 0.62% 4.486ms 560.712us 0 b 0 b 8
aten::_adaptive_avg_pool2d_backward 0.62% 4.483ms 0.62% 4.483ms 4.483ms 0 b 0 b 1
aten::ones_like 0.00% 29.700us 0.62% 4.478ms 4.478ms 0 b 0 b 1
aten::narrow 0.01% 54.100us 0.61% 4.422ms 276.394us 0 b 0 b 16
aten::slice 0.02% 155.500us 0.61% 4.368ms 273.013us 0 b 0 b 16
aten::fill_ 0.57% 4.120ms 0.57% 4.120ms 4.120ms 0 b 0 b 1
aten::empty 0.43% 3.080ms 0.43% 3.080ms 16.296us 338.06 Kb 338.06 Kb 189
aten::bernoulli_ 0.17% 1.233ms 0.37% 2.636ms 1.318ms 0 b -338.00 Kb 2
Optimizer.zero_grad#SGD.zero_grad 0.06% 400.600us 0.36% 2.623ms 2.623ms -4 b -20 b 1
aten::zero_ 0.31% 2.250ms 0.31% 2.250ms 40.902us 0 b 0 b 55
NllLossBackward 0.01% 48.700us 0.30% 2.180ms 2.180ms 0 b 0 b 1
aten::nll_loss_backward 0.29% 2.118ms 0.30% 2.132ms 2.132ms 0 b 0 b 1
aten::detach 0.14% 1.012ms 0.26% 1.864ms 35.854us 0 b 0 b 52
aten::empty_like 0.02% 110.800us 0.13% 914.800us 304.933us 338.00 Kb 0 b 3
detach 0.12% 852.700us 0.12% 852.700us 16.398us 0 b 0 b 52
MulBackward0 0.00% 16.700us 0.07% 539.700us 539.700us 0 b 0 b 1
aten::to 0.01% 57.700us 0.07% 532.200us 266.100us 4 b 0 b 2
aten::transpose 0.04% 277.300us 0.06% 413.300us 5.299us 0 b 0 b 78
aten::zeros 0.02% 117.300us 0.03% 249.100us 83.033us 12 b 0 b 3
aten::reshape 0.01% 57.100us 0.01% 96.300us 32.100us 0 b 0 b 3
aten::flatten 0.01% 54.100us 0.01% 95.300us 95.300us 0 b 0 b 1
aten::squeeze 0.01% 60.300us 0.01% 93.900us 46.950us 0 b 0 b 2
aten::view 0.01% 80.400us 0.01% 80.400us 20.100us 0 b 0 b 4
ViewBackward 0.00% 7.400us 0.00% 25.400us 25.400us 0 b 0 b 1
aten::conj 0.00% 7.500us 0.00% 7.500us 7.500us 0 b 0 b 1
------------------------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 720.753ms
Done! with highest_accuracy: 0
```
## External Links
- [Original paper](https://arxiv.org/abs/1602.07360)
- [Original training data (LSVRC 2012)](http://www.image-net.org/challenges/LSVRC/2012/)
- [Alternative training data (CIFAR-10)](https://www.cs.toronto.edu/~kriz/cifar.html)
Alternative implementations:
- [ONNX](https://github.com/onnx/models/tree/master/vision/classification/squeezenet)

26
PyTorch/squeezenet/predict.py Normal file

@@ -0,0 +1,26 @@
#!/usr/bin/env python
# Copyright (c) Microsoft Corporation. All rights reserved.
import argparse
import subprocess
import os
import pathlib
import sys
classification_folder = str(os.path.join(pathlib.Path(__file__).parent.parent.resolve(), 'classification'))
# insert at 1, 0 is the script path (or '' in REPL)
sys.path.insert(1, classification_folder)
from test_classification import predict
def main():
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--image", type=str, help="Image to classify.")
parser.add_argument('--device', type=str, default='dml', help='The device to use for training.')
args = parser.parse_args()
predict(args.image, 'squeezenet1_1', args.device)
if __name__ == "__main__":
main()

7
PyTorch/squeezenet/requirements.txt Normal file

@@ -0,0 +1,7 @@
pandas
tensorboard
matplotlib
tqdm
pyyaml
opencv-python
wget

31
PyTorch/squeezenet/test.py Normal file

@@ -0,0 +1,31 @@
#!/usr/bin/env python
# Copyright (c) Microsoft Corporation. All rights reserved.
import argparse
import subprocess
import os
import pathlib
import sys
classification_folder = str(os.path.join(pathlib.Path(__file__).parent.parent.resolve(), 'classification'))
# insert at 1, 0 is the script path (or '' in REPL)
sys.path.insert(1, classification_folder)
from test_classification import main as test
def main():
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--path", type=str, default="cifar-10-python", help="Path to cifar dataset.")
parser.add_argument('--batch_size', type=int, default=32, metavar='N', help='Batch size to train with.')
parser.add_argument('--device', type=str, default='dml', help='The device to use for training.')
parser.add_argument('--trace', type=bool, default=False, help='Trace performance.')
args = parser.parse_args()
batch_size = 1 if args.trace else args.batch_size
test(args.path, batch_size, args.device, 'squeezenet1_1', args.trace)
if __name__ == "__main__":
main()

35
PyTorch/squeezenet/train.py Normal file

@@ -0,0 +1,35 @@
#!/usr/bin/env python
# Copyright (c) Microsoft Corporation. All rights reserved.
import argparse
import subprocess
import os
import pathlib
import sys
classification_folder = str(os.path.join(pathlib.Path(__file__).parent.parent.resolve(), 'classification'))
# insert at 1, 0 is the script path (or '' in REPL)
sys.path.insert(1, classification_folder)
from train_classification import main as train
def main():
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--path", type=str, default="cifar-10-python", help="Path to cifar dataset.")
parser.add_argument('--batch_size', type=int, default=32, metavar='N', help='Batch size to train with.')
parser.add_argument('--epochs', type=int, default=50, metavar='N', help='The number of epochs to train for.')
parser.add_argument('--learning_rate', type=float, default=0.001, metavar='LR', help='The learning rate.')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='The percentage of past parameters to store.')
parser.add_argument('--weight_decay', default=0.0001, type=float, help='The parameter to decay weights.')
parser.add_argument('--device', type=str, default='dml', help='The device to use for training.')
parser.add_argument('--save_model', action='store_true', help='Save the model state_dict to file')
parser.add_argument('--trace', type=bool, default=False, help='Trace performance.')
args = parser.parse_args()
train(args.path, args.batch_size, args.epochs, args.learning_rate,
args.momentum, args.weight_decay, args.device, 'squeezenet1_1', args.save_model, args.trace)
if __name__ == "__main__":
main()

README.md

@@ -14,6 +14,7 @@ More information about DirectML can be found in [Introduction to DirectML](https
- [Windows ML on DirectML](#windows-ml-on-directml)
- [ONNX Runtime on DirectML](#onnx-runtime-on-directml)
- [TensorFlow with DirectML](#tensorflow-with-directml-preview)
- [PyTorch with DirectML](#pytorch-with-directml)
- [Feedback](#feedback)
- [External Links](#external-links)
- [Documentation](#documentation)
@@ -55,6 +56,7 @@ See the following sections for more information:
* [Windows ML on DirectML](#Windows-ML-on-DirectML)
* [ONNX Runtime on DirectML](#ONNX-Runtime-on-DirectML)
* [TensorFlow with DirectML (Preview)](#TensorFlow-with-DirectML-Preview)
* [PyTorch with DirectML (Preview)](#pytorch-with-directml)
## DirectML Samples
@@ -110,12 +112,24 @@ TensorFlow on DirectML is supported on both the latest versions of Windows 10 an
* [TensorFlow GitHub | RFC: TensorFlow on DirectML](https://github.com/tensorflow/community/pull/243)
* [TensorFlow homepage](https://www.tensorflow.org/)
## PyTorch with DirectML
DirectML acceleration for PyTorch 1.8.0 is currently available for Public Preview. PyTorch with DirectML enables training and inference of complex machine learning models on a wide range of DirectX 12-compatible hardware.
PyTorch on DirectML is supported on both the latest versions of Windows 10 and the [Windows Subsystem for Linux](https://docs.microsoft.com/windows/wsl/about), and is available for download as a PyPI package. For more information about getting started, see [GPU accelerated ML training (docs.microsoft.com)](http://aka.ms/gpuinwsldocs).
* [PyTorch on DirectML samples](./PyTorch)
* [pytorch-directml PyPI project](https://pypi.org/project/pytorch-directml/)
* [PyTorch homepage](https://pytorch.org/)
## Feedback
We look forward to hearing from you!
* For TensorFlow with DirectML issues, bugs, and feedback; or for general DirectML issues and feedback, please [file an issue](https://github.com/microsoft/DirectML-Samples/issues) or contact us directly at askdirectml@microsoft.com.
* For PyTorch with DirectML issues, bugs, and feedback; or for general DirectML issues and feedback, please [file an issue](https://github.com/microsoft/DirectML-Samples/issues) or contact us directly at askdirectml@microsoft.com.
* For Windows ML issues, please file a GitHub issue at [microsoft/Windows-Machine-Learning](https://github.com/Microsoft/Windows-Machine-Learning/issues) or contact us directly at askwindowsml@microsoft.com.
* For ONNX Runtime issues, please file an issue at [microsoft/onnxruntime](https://github.com/microsoft/onnxruntime/issues).