{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Simple Image Classifier\n",
"\n",
"Beginner-friendly image classifier built with PyTorch and CIFAR-10.\n",
"\n",
"<img alt=\"A photo of a man on an elephant with an ML-generated overlay showing objects in the frame\" src=\"https://upload.wikimedia.org/wikipedia/commons/a/ae/DenseCap_%28Johnson_et_al.%2C_2016%29_%28cropped%29.png\" width=450px>\n",
"\n",
"An image classifier is an ML model that recognizes objects in images. We can build image classifiers by feeding tens of thousands of labelled images to a neural network. Tools like PyTorch train these networks by repeatedly comparing the network's predictions with the correct labels and adjusting its parameters to reduce the errors.\n",
"\n",
"Let's build an image classifier that detects planes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. We'll download a dataset, configure a neural network, train a model, and evaluate its performance."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install matplotlib numpy torch torchvision tqdm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 1: Download a dataset and preview images\n",
"\n",
"A model is only as good as its dataset.\n",
"\n",
"Training tools need lots of high-quality data to build accurate models. We'll use the [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) of 60,000 small 32x32-pixel photos (50,000 for training and 10,000 for testing) to build our image classifier. Get started by downloading the dataset with `torchvision` and previewing a handful of images from it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Download the CIFAR-10 dataset to ./data\n",
"batch_size=10\n",
"# Convert images to tensors, then scale each RGB channel from [0, 1] to [-1, 1]\n",
"transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
"print(\"Downloading training data...\")\n",
"trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)\n",
"print(\"Downloading testing data...\")\n",
"testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
"testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)\n",
"\n",
"# Our model will recognize these kinds of objects\n",
"classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
"\n",
"# Grab one batch of images from our training data\n",
"dataiter = iter(trainloader)\n",
"images, labels = next(dataiter)\n",
"\n",
"for i in range(batch_size):\n",
"    # Add a new subplot to a 2-row grid\n",
"    plt.subplot(2, batch_size // 2, i + 1)\n",
"    # Undo the normalization so the image displays with its original colors\n",
"    img = images[i] / 2 + 0.5\n",
"    npimg = img.numpy()\n",
"    # PyTorch stores images as (channels, height, width); Matplotlib expects (height, width, channels)\n",
"    plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
"    plt.axis('off')\n",
"    # Add the image's label\n",
"    plt.title(classes[labels[i]])\n",
"\n",
"plt.suptitle('Preview of Training Data', size=20)\n",
"plt.show()"
]
},
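{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` transform and the `img / 2 + 0.5` line in the preview loop undo each other. `ToTensor` scales pixel values into [0, 1]; `Normalize` then computes `(x - mean) / std` for each channel, which maps them into [-1, 1]. Here is a minimal sketch of that arithmetic for a single channel value, assuming plain Python floats instead of tensors. The helpers `normalize` and `unnormalize` are hypothetical and are not torchvision functions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helpers that mirror the per-channel arithmetic of\n",
"# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) on plain floats\n",
"MEAN, STD = 0.5, 0.5\n",
"\n",
"def normalize(x):\n",
"    # What Normalize computes: (x - mean) / std, mapping [0, 1] to [-1, 1]\n",
"    return (x - MEAN) / STD\n",
"\n",
"def unnormalize(y):\n",
"    # The inverse used before plotting: y * std + mean, i.e. y / 2 + 0.5\n",
"    return y * STD + MEAN\n",
"\n",
"for pixel in (0.0, 0.5, 1.0):  # ToTensor output lies in [0, 1]\n",
"    z = normalize(pixel)\n",
"    print(f'{pixel:.2f} -> {z:+.2f} -> {unnormalize(z):.2f}')"
]
},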
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 2: Configure the neural network\n",
"\n",
"Now that we have our dataset, we need to set up a neural network for PyTorch. Our network is a small convolutional neural network (CNN): it takes in a 32x32 image and outputs ten scores, one per class. The class with the highest score is the model's prediction."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"\n",
"# Define a convolutional neural network\n",
"class Net(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.conv1 = nn.Conv2d(3, 6, 5)        # 3 color channels in, 6 feature maps out, 5x5 kernel\n",
"        self.pool = nn.MaxPool2d(2, 2)         # 2x2 window, stride 2: halves height and width\n",
"        self.conv2 = nn.Conv2d(6, 16, 5)       # 6 feature maps in, 16 out, 5x5 kernel\n",
"        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 feature maps of 5x5 pixels remain after the second pool\n",
"        self.fc2 = nn.Linear(120, 84)\n",
"        self.fc3 = nn.Linear(84, 10)           # one score per class\n",
"\n",
"    def forward(self, x):\n",
"        x = self.pool(F.relu(self.conv1(x)))\n",
"        x = self.pool(F.relu(self.conv2(x)))\n",
"        x = torch.flatten(x, 1)  # flatten every dimension except the batch\n",
"        x = F.relu(self.fc1(x))\n",
"        x = F.relu(self.fc2(x))\n",
"        x = self.fc3(x)\n",
"        return x\n",
"\n",
"net = Net()\n",
"\n",
"# Define a loss function and optimizer\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)\n",
"\n",
"print(\"Your network is ready for training!\")"
]
},
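{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why does `fc1` expect `16 * 5 * 5` inputs? Each 5x5 convolution (stride 1, no padding) shrinks a side by 4 pixels, and each 2x2 max-pool halves it, so a 32x32 image becomes 28, 14, 10, and finally 5 pixels wide, with 16 channels after `conv2`. The sketch below traces those sizes with the standard output-size formula for `Conv2d` and `MaxPool2d`, floor((size + 2 * padding - kernel) / stride) + 1, assuming dilation 1. The helper `conv_output_size` is hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper: output side length of Conv2d / MaxPool2d (dilation 1)\n",
"def conv_output_size(size, kernel, stride=1, padding=0):\n",
"    return (size + 2 * padding - kernel) // stride + 1\n",
"\n",
"side = 32                                          # CIFAR-10 images are 32x32\n",
"side = conv_output_size(side, kernel=5)            # conv1: 32 -> 28\n",
"side = conv_output_size(side, kernel=2, stride=2)  # pool:  28 -> 14\n",
"side = conv_output_size(side, kernel=5)            # conv2: 14 -> 10\n",
"side = conv_output_size(side, kernel=2, stride=2)  # pool:  10 -> 5\n",
"flat_features = 16 * side * side                   # conv2 outputs 16 channels\n",
"print(side, flat_features)"
]
},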
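{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 3: Train the network and save the model\n",
"\n",
"PyTorch trains our network by showing it batches of labelled images, measuring how far its predictions are from the correct labels (the loss), and adjusting its parameters to reduce that loss. We'll make two full passes (epochs) over the 50,000 training images, then save the trained parameters to a file."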
]
},
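{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loss that training minimizes is the `nn.CrossEntropyLoss` defined in Step 2. For one image, it takes the ten raw scores and the index t of the correct class and returns -log(softmax(scores)[t]). The loss is near zero when the correct class scores much higher than the others, and large when the model confidently picks the wrong class. Here is a minimal pure-Python sketch for a single image. `cross_entropy` is a hypothetical helper; the real `nn.CrossEntropyLoss` also averages over the batch by default."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"# Hypothetical helper: cross-entropy for one image, computed as\n",
"# log(sum(exp(scores))) - scores[target], i.e. -log(softmax(scores)[target])\n",
"def cross_entropy(scores, target):\n",
"    m = max(scores)  # subtract the max before exp() for numerical stability\n",
"    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))\n",
"    return log_sum_exp - scores[target]\n",
"\n",
"uniform = [0.0] * 10            # no preference: loss is log(10), about 2.303\n",
"confident = [5.0] + [0.0] * 9   # strongly prefers class 0 ('plane')\n",
"print(cross_entropy(uniform, target=3))\n",
"print(cross_entropy(confident, target=0))  # right answer: small loss\n",
"print(cross_entropy(confident, target=1))  # wrong answer: large loss"
]
},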
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tqdm import tqdm\n",
"\n",
"EPOCHS = 2\n",
"print(\"Training...\")\n",
"for epoch in range(EPOCHS):\n",
"    running_loss = 0.0\n",
"    for inputs, labels in tqdm(trainloader, desc=f\"Epoch {epoch + 1} of {EPOCHS}\", leave=True, ncols=80):\n",
"        # Reset gradients, run a forward pass, measure the loss, backpropagate, update the weights\n",
"        optimizer.zero_grad()\n",
"        outputs = net(inputs)\n",
"        loss = criterion(outputs, labels)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        running_loss += loss.item()\n",
"    # Report the average loss per batch for this epoch\n",
"    print(f\"Epoch {epoch + 1}: average loss {running_loss / len(trainloader):.3f}\")\n",
"\n",
"# Save our trained model\n",
"PATH = './cifar_net.pth'\n",
"torch.save(net.state_dict(), PATH)"
]
},
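{
"cell_type": "markdown",
"metadata": {},
"source": [
"Inside the training loop, `optimizer.step()` is where the parameters are adjusted. With `momentum=0.9`, `torch.optim.SGD` (default `dampening=0`, `nesterov=False`) keeps a running velocity for each parameter: `buf = momentum * buf + grad`, then `param = param - lr * buf`. The sketch below applies that update to a single hypothetical parameter `w` on a toy loss (w - 3)^2, whose minimum is at w = 3. It uses `lr=0.1` rather than the notebook's `lr=0.001` only so it converges within a few hundred steps. `sgd_momentum` is a hypothetical helper, not a PyTorch function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper: the SGD-with-momentum update on one scalar parameter\n",
"def sgd_momentum(grad_fn, w, lr=0.1, momentum=0.9, steps=300):\n",
"    buf = 0.0  # running velocity (momentum buffer)\n",
"    for _ in range(steps):\n",
"        grad = grad_fn(w)\n",
"        buf = momentum * buf + grad  # blend the new gradient into the velocity\n",
"        w = w - lr * buf             # step against the velocity\n",
"    return w\n",
"\n",
"# Toy loss (w - 3)^2 has gradient 2 * (w - 3)\n",
"w_final = sgd_momentum(lambda w: 2 * (w - 3), w=0.0)\n",
"print(round(w_final, 4))"
]
},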
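{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 4: Test the trained model\n",
"\n",
"Let's test our model on images from the test set, which it never saw during training. Correct predictions appear in green; incorrect ones appear in red, in parentheses."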
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Grab a batch of photos from the test set, which the model never saw during training\n",
"dataiter = iter(testloader)\n",
"images, labels = next(dataiter)\n",
"\n",
"# Load our model\n",
"net = Net()\n",
"net.load_state_dict(torch.load(PATH))\n",
"\n",
"# Analyze images (gradients are not needed for inference)\n",
"with torch.no_grad():\n",
"    outputs = net(images)\n",
"# The predicted class is the index of the highest score in each row\n",
"_, predicted = torch.max(outputs, 1)\n",
"\n",
"# Show results\n",
"for i in range(batch_size):\n",
"    # Add a new subplot to a 2-row grid\n",
"    plt.subplot(2, batch_size // 2, i + 1)\n",
"    # Plot the image, undoing the normalization\n",
"    img = images[i] / 2 + 0.5\n",
"    npimg = img.numpy()\n",
"    plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
"    plt.axis('off')\n",
"    # Label correct predictions in green; wrong ones in red and in parentheses\n",
"    color = \"green\"\n",
"    label = classes[predicted[i]]\n",
"    if labels[i] != predicted[i]:\n",
"        color = \"red\"\n",
"        label = \"(\" + label + \")\"\n",
"    plt.title(label, color=color)\n",
"\n",
"plt.suptitle('Objects Found by Model', size=20)\n",
"plt.show()"
]
},
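{
"cell_type": "markdown",
"metadata": {},
"source": [
"For each image in the batch, `torch.max(outputs, 1)` returns the largest score and its index; the index is the predicted class. The raw scores are not probabilities, but softmax turns them into values between 0 and 1 that sum to 1 without changing which class wins. Here is a minimal pure-Python sketch for one image. The scores are made-up illustrative values, not output from the trained model, and `softmax` and `predict` are hypothetical helpers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"# Same class names as Step 1, repeated so this cell runs on its own\n",
"classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
"\n",
"def softmax(scores):\n",
"    # Exponentiate (shifted by the max for stability), then normalize to sum to 1\n",
"    m = max(scores)\n",
"    exps = [math.exp(s - m) for s in scores]\n",
"    total = sum(exps)\n",
"    return [e / total for e in exps]\n",
"\n",
"def predict(scores):\n",
"    # Like torch.max(outputs, 1) for one row: take the index of the top score\n",
"    best = max(range(len(scores)), key=lambda i: scores[i])\n",
"    return classes[best], softmax(scores)[best]\n",
"\n",
"# Made-up scores for one image (hypothetical, not real model output)\n",
"example_scores = [0.1, -1.2, 0.4, 2.3, 0.2, 1.9, -0.5, 0.0, -0.8, -1.0]\n",
"name, confidence = predict(example_scores)\n",
"print(name, round(confidence, 3))"
]
},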
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 5: Evaluate model accuracy\n",
"\n",
"Let's conclude by measuring how accurately our model classifies each kind of object across all 10,000 test images."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Measure accuracy for each class\n",
"correct_pred = {classname: 0 for classname in classes}\n",
"total_pred = {classname: 0 for classname in classes}\n",
"with torch.no_grad():\n",
"    for data in testloader:\n",
"        images, labels = data\n",
"        outputs = net(images)\n",
"        _, predictions = torch.max(outputs, 1)\n",
"        # Collect the correct predictions for each class\n",
"        for label, prediction in zip(labels, predictions):\n",
"            if label == prediction:\n",
"                correct_pred[classes[label]] += 1\n",
"            total_pred[classes[label]] += 1\n",
"\n",
"# Print accuracy statistics\n",
"for classname, correct_count in correct_pred.items():\n",
"    accuracy = 100 * float(correct_count) / total_pred[classname]\n",
"    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')"
]
}
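,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The numbers above are per-class accuracies. To get a single overall figure, pool the counts: total correct predictions divided by total test images. Because the CIFAR-10 test set contains exactly 1,000 images per class, this equals the average of the ten per-class accuracies. Here is a minimal sketch. `overall_accuracy` is a hypothetical helper, and the two-class example counts are made up purely to show the arithmetic."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper: pool per-class counts into one percentage\n",
"def overall_accuracy(correct_counts, total_counts):\n",
"    return 100 * sum(correct_counts.values()) / sum(total_counts.values())\n",
"\n",
"# Made-up counts for two classes, only to illustrate the arithmetic\n",
"example_correct = {'cat': 30, 'dog': 50}\n",
"example_total = {'cat': 100, 'dog': 100}\n",
"print(f'Example: {overall_accuracy(example_correct, example_total):.1f} %')  # Example: 40.0 %\n",
"\n",
"# Apply it to the counts from the previous cell, if that cell has been run\n",
"if 'correct_pred' in globals() and 'total_pred' in globals():\n",
"    print(f'Overall accuracy: {overall_accuracy(correct_pred, total_pred):.1f} %')"
]
}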
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.7"
},
"vscode": {
"interpreter": {
"hash": "eb4a0ac80907d7f44e1a5e88d3d3381b33e3dbedd3a24d113e876f30a0c46bee"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}