diff --git a/README.md b/README.md index 2d9f68a8c..13711cde9 100644 --- a/README.md +++ b/README.md @@ -39,8 +39,9 @@ The following sections give basic examples of what you can do with TorchGeo. First we'll import various classes and functions used in the following sections: ```python -from lightning import Trainer +from lightning.pytorch import Trainer from torch.utils.data import DataLoader + from torchgeo.datamodules import InriaAerialImageLabelingDataModule from torchgeo.datasets import CDL, Landsat7, Landsat8, VHR10, stack_samples from torchgeo.samplers import RandomGeoSampler @@ -56,8 +57,8 @@ Many remote sensing applications involve working with [*geospatial datasets*](ht In this example, we show how easy it is to work with geospatial data and to sample small image patches from a combination of [Landsat](https://www.usgs.gov/landsat-missions) and [Cropland Data Layer (CDL)](https://data.nal.usda.gov/dataset/cropscape-cropland-data-layer) data using TorchGeo. First, we assume that the user has Landsat 7 and 8 imagery downloaded. Since Landsat 8 has more spectral bands than Landsat 7, we'll only use the bands that both satellites have in common. We'll create a single dataset including all images from both Landsat 7 and 8 data by taking the union between these two datasets. ```python -landsat7 = Landsat7(root="...") -landsat8 = Landsat8(root="...", bands=Landsat8.all_bands[1:-2]) +landsat7 = Landsat7(root="...", bands=["B1", ..., "B7"]) +landsat8 = Landsat8(root="...", bands=["B2", ..., "B8"]) landsat = landsat7 | landsat8 ``` @@ -124,8 +125,18 @@ In order to facilitate direct comparisons between results published in the liter ```python datamodule = InriaAerialImageLabelingDataModule(root="...", batch_size=64, num_workers=6) -task = SemanticSegmentationTask(segmentation_model="unet", encoder_weights="imagenet", learning_rate=0.1) -trainer = Trainer(gpus=1, default_root_dir="...") +task = SemanticSegmentationTask( + model="unet", + backbone="resnet50", + weights="imagenet", + in_channels=3, + num_classes=2, + loss="ce", + ignore_index=None, + learning_rate=0.1, + learning_rate_schedule_patience=6, +) +trainer = Trainer(default_root_dir="...") trainer.fit(model=task, datamodule=datamodule) ```