README: update Lightning trainer examples (#1211)

This commit is contained in:
Adam J. Stewart 2023-04-03 13:44:23 -05:00 коммит произвёл GitHub
Родитель 0d8d320e86
Коммит 8c98167bda
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 16 добавлений и 5 удалений

Просмотреть файл

@ -39,8 +39,9 @@ The following sections give basic examples of what you can do with TorchGeo.
First we'll import various classes and functions used in the following sections:
```python
from lightning import Trainer
from lightning.pytorch import Trainer
from torch.utils.data import DataLoader
from torchgeo.datamodules import InriaAerialImageLabelingDataModule
from torchgeo.datasets import CDL, Landsat7, Landsat8, VHR10, stack_samples
from torchgeo.samplers import RandomGeoSampler
@ -56,8 +57,8 @@ Many remote sensing applications involve working with [*geospatial datasets*](ht
In this example, we show how easy it is to work with geospatial data and to sample small image patches from a combination of [Landsat](https://www.usgs.gov/landsat-missions) and [Cropland Data Layer (CDL)](https://data.nal.usda.gov/dataset/cropscape-cropland-data-layer) data using TorchGeo. First, we assume that the user has Landsat 7 and 8 imagery downloaded. Since Landsat 8 has more spectral bands than Landsat 7, we'll only use the bands that both satellites have in common. We'll create a single dataset including all images from both Landsat 7 and 8 data by taking the union between these two datasets.
```python
landsat7 = Landsat7(root="...")
landsat8 = Landsat8(root="...", bands=Landsat8.all_bands[1:-2])
landsat7 = Landsat7(root="...", bands=["B1", ..., "B7"])
landsat8 = Landsat8(root="...", bands=["B2", ..., "B8"])
landsat = landsat7 | landsat8
```
@ -124,8 +125,18 @@ In order to facilitate direct comparisons between results published in the liter
```python
datamodule = InriaAerialImageLabelingDataModule(root="...", batch_size=64, num_workers=6)
task = SemanticSegmentationTask(segmentation_model="unet", encoder_weights="imagenet", learning_rate=0.1)
trainer = Trainer(gpus=1, default_root_dir="...")
task = SemanticSegmentationTask(
model="unet",
backbone="resnet50",
weights="imagenet",
in_channels=3,
num_classes=2,
loss="ce",
ignore_index=None,
learning_rate=0.1,
learning_rate_schedule_patience=6,
)
trainer = Trainer(default_root_dir="...")
trainer.fit(model=task, datamodule=datamodule)
```