fix sample so it works on the GPU (#505)

This commit is contained in:
Chris Lovett 2022-11-17 09:23:33 -08:00 коммит произвёл GitHub
Родитель 673410d94f
Коммит b5157f8441
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 37 добавлений и 40 удалений

Просмотреть файл

@ -77,6 +77,7 @@ class Network(nn.Module):
# Instantiate a neural network model
model = Network()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Define the loss function with Classification Cross-Entropy loss and an optimizer with Adam optimizer
loss_fn = nn.CrossEntropyLoss()
@ -99,7 +100,8 @@ def testAccuracy():
for data in test_loader:
images, labels = data
# Run the model on the test set to predict labels
outputs = model(images)
outputs = model(images.to(device))
outputs = outputs.cpu()
# The label with the highest energy will be our prediction
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
@ -116,7 +118,6 @@ def train(num_epochs):
best_accuracy = 0.0
# Define your execution device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("The model will be running on", device, "device")
# Convert model parameters and buffers to CPU or Cuda
model.to(device)
@ -257,7 +258,3 @@ if __name__ == "__main__":
# Conversion to ONNX
Convert_ONNX()