{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Simple Image Classifier\n",
"\n",
"Beginner-friendly image classifier built with PyTorch and CIFAR-10.\n",
"\n",
"<img alt=\"A photo of a man on an elephant with an ML-generated overlay showing objects in the frame\" src=\"https://upload.wikimedia.org/wikipedia/commons/a/ae/DenseCap_%28Johnson_et_al.%2C_2016%29_%28cropped%29.png\" width=450px>\n",
"\n",
"An image classifier is an ML model that recognizes objects in images. We can build one by feeding tens of thousands of labelled images to a neural network. Tools like PyTorch train these networks by repeatedly adjusting their parameters to reduce their error on the labelled dataset.\n",
"\n",
"Let's build an image classifier that detects planes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. We'll download a dataset, configure a neural network, train a model, and evaluate its performance."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install matplotlib numpy torch torchvision tqdm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 1: Download a dataset and preview images\n",
"\n",
"A model is only as good as its dataset.\n",
"\n",
"Training tools need lots of high-quality data to build accurate models. We'll use the [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) of 60,000 photos to build our image classifier. Get started by downloading the dataset with `torchvision` and previewing a handful of images from it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Download the CIFAR-10 dataset to ./data\n",
"batch_size = 10\n",
"transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
"print(\"Downloading training data...\")\n",
"trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)\n",
"print(\"Downloading testing data...\")\n",
"testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
"testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)\n",
"\n",
"# Our model will recognize these kinds of objects\n",
"classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
"\n",
"# Grab a batch of images from our training data\n",
"dataiter = iter(trainloader)\n",
"images, labels = next(dataiter)\n",
"\n",
"for i in range(batch_size):\n",
"    # Add new subplot\n",
"    plt.subplot(2, batch_size // 2, i + 1)\n",
"    # Plot the image (un-normalize it first)\n",
"    img = images[i] / 2 + 0.5\n",
"    npimg = img.numpy()\n",
"    plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
"    plt.axis('off')\n",
"    # Add the image's label\n",
"    plt.title(classes[labels[i]])\n",
"\n",
"plt.suptitle('Preview of Training Data', size=20)\n",
"plt.show()"
]
},
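{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each batch from `trainloader` is a tensor of shape `(batch_size, 3, 32, 32)`. The `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` transform maps each channel from [0, 1] to roughly [-1, 1] via `(x - 0.5) / 0.5`, which is why the preview code above un-normalizes with `img / 2 + 0.5`. An optional sanity check (just a sketch, not required for training):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect the batch we just grabbed: shape and value range\n",
"print(images.shape)  # expect torch.Size([10, 3, 32, 32])\n",
"print(images.min().item(), images.max().item())  # roughly -1.0 and 1.0"
]
},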
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 2: Configure the neural network\n",
"\n",
"Now that we have our dataset, we need to set up a neural network for PyTorch. Our network will take an image as input and output a score for each of the ten classes; the highest-scoring class is the model's prediction."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"\n",
"# Define a convolutional neural network\n",
"class Net(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.conv1 = nn.Conv2d(3, 6, 5)\n",
"        self.pool = nn.MaxPool2d(2, 2)\n",
"        self.conv2 = nn.Conv2d(6, 16, 5)\n",
"        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
"        self.fc2 = nn.Linear(120, 84)\n",
"        self.fc3 = nn.Linear(84, 10)\n",
"\n",
"    def forward(self, x):\n",
"        x = self.pool(F.relu(self.conv1(x)))\n",
"        x = self.pool(F.relu(self.conv2(x)))\n",
"        x = torch.flatten(x, 1)\n",
"        x = F.relu(self.fc1(x))\n",
"        x = F.relu(self.fc2(x))\n",
"        x = self.fc3(x)\n",
"        return x\n",
"\n",
"net = Net()\n",
"\n",
"# Define a loss function and optimizer\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)\n",
"\n",
"print(\"Your network is ready for training!\")"
]
},
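{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `16 * 5 * 5` input size of `fc1` comes from the shape flow through the network: a 32x32 image becomes 28x28 after `conv1` (5x5 kernel, no padding), 14x14 after pooling, 10x10 after `conv2`, and 5x5 after the second pool, with 16 channels. As an optional sanity check (just a sketch), pass a dummy batch through the untrained network and confirm we get one score per class:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A fake batch of 4 RGB 32x32 images should yield 4 rows of 10 class scores\n",
"dummy = torch.randn(4, 3, 32, 32)\n",
"print(net(dummy).shape)  # torch.Size([4, 10])"
]
},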
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 3: Train the network and save the model\n",
"\n",
"PyTorch trains our network by repeatedly adjusting its parameters to reduce the loss on our labelled dataset. One full pass over the training set is called an epoch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tqdm import tqdm\n",
"\n",
"EPOCHS = 2\n",
"print(\"Training...\")\n",
"for epoch in range(EPOCHS):\n",
"    running_loss = 0.0\n",
"    for i, data in enumerate(tqdm(trainloader, desc=f\"Epoch {epoch + 1} of {EPOCHS}\", leave=True, ncols=80)):\n",
"        inputs, labels = data\n",
"\n",
"        # Reset gradients, run a forward pass, then update the weights\n",
"        optimizer.zero_grad()\n",
"        outputs = net(inputs)\n",
"        loss = criterion(outputs, labels)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        running_loss += loss.item()\n",
"    print(f\"Epoch {epoch + 1} average loss: {running_loss / len(trainloader):.3f}\")\n",
"\n",
"# Save our trained model\n",
"PATH = './cifar_net.pth'\n",
"torch.save(net.state_dict(), PATH)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 4: Test the trained model\n",
"\n",
"Let's load the saved model and see how it labels a batch of images from the test set, which it never saw during training. Correct predictions are titled in green; incorrect ones in red, with the wrong label in parentheses."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Grab a batch of photos from the test set\n",
"dataiter = iter(testloader)\n",
"images, labels = next(dataiter)\n",
"\n",
"# Load our trained model\n",
"net = Net()\n",
"net.load_state_dict(torch.load(PATH))\n",
"\n",
"# Analyze images\n",
"with torch.no_grad():\n",
"    outputs = net(images)\n",
"_, predicted = torch.max(outputs, 1)\n",
"\n",
"# Show results\n",
"for i in range(batch_size):\n",
"    # Add new subplot\n",
"    plt.subplot(2, batch_size // 2, i + 1)\n",
"    # Plot the image (un-normalize it first)\n",
"    img = images[i] / 2 + 0.5\n",
"    npimg = img.numpy()\n",
"    plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
"    plt.axis('off')\n",
"    # Add the predicted label: green if correct, red in parentheses if wrong\n",
"    color = \"green\"\n",
"    label = classes[predicted[i]]\n",
"    if labels[i] != predicted[i]:\n",
"        color = \"red\"\n",
"        label = \"(\" + label + \")\"\n",
"    plt.title(label, color=color)\n",
"\n",
"plt.suptitle('Objects Found by Model', size=20)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 5: Evaluate model accuracy\n",
"\n",
"Let's conclude by evaluating our model's overall performance: how often it predicts the right class for each kind of object across the whole test set. Random guessing would score 10 %, so anything well above that means the network learned something."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Measure accuracy for each class\n",
"correct_pred = {classname: 0 for classname in classes}\n",
"total_pred = {classname: 0 for classname in classes}\n",
"with torch.no_grad():\n",
"    for data in testloader:\n",
"        images, labels = data\n",
"        outputs = net(images)\n",
"        _, predictions = torch.max(outputs, 1)\n",
"        # Collect the correct predictions for each class\n",
"        for label, prediction in zip(labels, predictions):\n",
"            if label == prediction:\n",
"                correct_pred[classes[label]] += 1\n",
"            total_pred[classes[label]] += 1\n",
"\n",
"# Print accuracy statistics\n",
"for classname, correct_count in correct_pred.items():\n",
"    accuracy = 100 * float(correct_count) / total_pred[classname]\n",
"    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')\n",
"\n",
"overall = 100 * sum(correct_pred.values()) / sum(total_pred.values())\n",
"print(f'Overall accuracy: {overall:.1f} %')"
]
},
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.7"
},
"vscode": {
"interpreter": {
"hash": "eb4a0ac80907d7f44e1a5e88d3d3381b33e3dbedd3a24d113e876f30a0c46bee"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}