{
"cells": [
{
"cell_type": "markdown",
"id": "b13c2251",
"metadata": {
"id": "b13c2251"
},
"source": [
"Copyright (c) Microsoft Corporation. All rights reserved.\n",
"\n",
"Licensed under the MIT License."
]
},
{
"cell_type": "markdown",
"id": "e563313d",
"metadata": {
"id": "e563313d"
},
"source": [
"# Lightning Trainers\n",
"\n",
"In this tutorial, we demonstrate how to use TorchGeo trainers to train and test a model. We use the [EuroSAT](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) dataset throughout, specifically a subset containing only 100 images, and train models to predict land cover classes.\n",
"\n",
"It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the \"Open in Colab\" button above to get started."
]
},
{
"cell_type": "markdown",
"id": "8c1f4156",
"metadata": {
"id": "8c1f4156"
},
"source": [
"## Setup\n",
"\n",
"First, we install TorchGeo and TensorBoard."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f0d31a8",
"metadata": {
"id": "3f0d31a8"
},
"outputs": [],
"source": [
"%pip install torchgeo tensorboard"
]
},
{
"cell_type": "markdown",
"id": "c90c94c7",
"metadata": {
"id": "c90c94c7"
},
"source": [
"## Imports\n",
"\n",
"Next, we import TorchGeo and any other libraries we need."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bd39f485",
"metadata": {
"id": "bd39f485"
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"%load_ext tensorboard\n",
"\n",
"import os\n",
"import tempfile\n",
"\n",
"import torch\n",
"from lightning.pytorch import Trainer\n",
"from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint\n",
"from lightning.pytorch.loggers import TensorBoardLogger\n",
"\n",
"from torchgeo.datamodules import EuroSAT100DataModule\n",
"from torchgeo.models import ResNet18_Weights\n",
"from torchgeo.trainers import ClassificationTask"
]
},
{
"cell_type": "markdown",
"id": "e6e1d9b6",
"metadata": {
"id": "e6e1d9b6"
},
"source": [
"## Lightning modules\n",
"\n",
"Our trainers use [Lightning](https://lightning.ai/docs/pytorch/stable/) to organize both the training code and the dataloader setup code. This makes it easy to create and share reproducible experiments and results.\n",
"\n",
"First, we'll create a `EuroSAT100DataModule` object, which is simply a wrapper around the [EuroSAT100](https://torchgeo.readthedocs.io/en/latest/api/datasets.html#eurosat) dataset. This object (1) ensures that the data is downloaded, (2) sets up PyTorch `DataLoader` objects for the train, validation, and test splits, and (3) ensures that data from the same region **is not** shared between the training and validation sets, so that you can properly evaluate the generalization performance of your model."
]
},
{
"cell_type": "markdown",
"id": "9f2daa0d",
"metadata": {
"id": "9f2daa0d"
},
"source": [
"The following variables can be modified to control training."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8e100f8b",
"metadata": {
"id": "8e100f8b",
"nbmake": {
"mock": {
"batch_size": 1,
"fast_dev_run": true,
"max_epochs": 1,
"num_workers": 0
}
}
},
"outputs": [],
"source": [
"batch_size = 10\n",
"num_workers = 2\n",
"max_epochs = 50\n",
"fast_dev_run = False"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f2a04c7",
"metadata": {
"id": "0f2a04c7"
},
"outputs": [],
"source": [
"root = os.path.join(tempfile.gettempdir(), 'eurosat100')\n",
"datamodule = EuroSAT100DataModule(\n",
" root=root, batch_size=batch_size, num_workers=num_workers, download=True\n",
")"
]
},
{
"cell_type": "markdown",
"id": "056b7b4c",
"metadata": {
"id": "056b7b4c"
},
"source": [
"Next, we create a `ClassificationTask` object that holds the model object, optimizer object, and training logic. We will use a ResNet-18 model that has been pre-trained on Sentinel-2 imagery."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba5c5442",
"metadata": {
"id": "ba5c5442"
},
"outputs": [],
"source": [
"task = ClassificationTask(\n",
" loss='ce',\n",
" model='resnet18',\n",
" weights=ResNet18_Weights.SENTINEL2_ALL_MOCO,\n",
" in_channels=13,\n",
" num_classes=10,\n",
" lr=0.1,\n",
" patience=5,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "d4b67f3e",
"metadata": {
"id": "d4b67f3e"
},
"source": [
"## Training\n",
"\n",
"Now that we have the Lightning modules set up, we can use a Lightning [Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html) to run the training and evaluation loops. There are many useful pieces of configuration that can be set in the `Trainer`. Below, we set up model checkpointing based on the validation loss, early stopping based on the validation loss, and a TensorBoard-based logger. We encourage you to see the [Lightning docs](https://lightning.ai/docs/pytorch/stable/) for other options that can be set here, e.g. CSV logging, automatically selecting your optimizer's learning rate, and easy multi-GPU training."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ffe26e5c",
"metadata": {
"id": "ffe26e5c"
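},
"outputs": [],
"source": [
"# Use a GPU if one is available, otherwise fall back to the CPU\n",
"accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'\n",
"default_root_dir = os.path.join(tempfile.gettempdir(), 'experiments')\n",
"\n",
"# Keep the checkpoint with the lowest validation loss, plus the most recent one\n",
"checkpoint_callback = ModelCheckpoint(\n",
"    monitor='val_loss', dirpath=default_root_dir, save_top_k=1, save_last=True\n",
")\n",
"\n",
"# Stop once the validation loss has not improved for 10 consecutive validation checks\n",
"early_stopping_callback = EarlyStopping(monitor='val_loss', min_delta=0.00, patience=10)\n",
"logger = TensorBoardLogger(save_dir=default_root_dir, name='tutorial_logs')"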
]
},
{
"cell_type": "markdown",
"id": "06afd8c7",
"metadata": {
"id": "06afd8c7"
},
"source": [
"For tutorial purposes, we deliberately lower the maximum number of training epochs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "225a6d36",
"metadata": {
"id": "225a6d36"
},
"outputs": [],
"source": [
"trainer = Trainer(\n",
" accelerator=accelerator,\n",
" callbacks=[checkpoint_callback, early_stopping_callback],\n",
" fast_dev_run=fast_dev_run,\n",
" log_every_n_steps=1,\n",
" logger=logger,\n",
" min_epochs=1,\n",
" max_epochs=max_epochs,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "44d71e8f",
"metadata": {
"id": "44d71e8f"
},
"source": [
"When we first call `.fit(...)`, the dataset will be downloaded and checksummed (if it hasn't been already). After this, the training process will start, and results will be saved so that TensorBoard can read them."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "00e08790",
"metadata": {
"id": "00e08790"
},
"outputs": [],
"source": [
"trainer.fit(model=task, datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"id": "73700fb5",
"metadata": {
"id": "73700fb5"
},
"source": [
"We launch TensorBoard to visualize various performance metrics across training and validation epochs. We can see that our model is just starting to converge, and would probably benefit from additional training time and a lower initial learning rate."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3e95ee0a",
"metadata": {},
"outputs": [],
"source": [
"%tensorboard --logdir \"$default_root_dir\""
]
},
{
"cell_type": "markdown",
"id": "04cfc7a8",
"metadata": {
"id": "04cfc7a8"
},
"source": [
"Finally, after the model has been trained, we can easily evaluate it on the test set."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "604a3b2f",
"metadata": {
"id": "604a3b2f"
},
"outputs": [],
"source": [
"trainer.test(model=task, datamodule=datamodule)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": []
},
"execution": {
"timeout": 1200
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}