Mirror of https://github.com/microsoft/torchgeo.git
README: update Lightning trainer examples (#1211)
Parent: 0d8d320e86
Commit: 8c98167bda
1 changed file: README.md (21 changed lines)
@@ -39,8 +39,9 @@ The following sections give basic examples of what you can do with TorchGeo.
First we'll import various classes and functions used in the following sections:
```python
-from lightning import Trainer
+from lightning.pytorch import Trainer
from torch.utils.data import DataLoader
|
from torchgeo.datamodules import InriaAerialImageLabelingDataModule
from torchgeo.datasets import CDL, Landsat7, Landsat8, VHR10, stack_samples
from torchgeo.samplers import RandomGeoSampler
|
@@ -56,8 +57,8 @@ Many remote sensing applications involve working with [*geospatial datasets*](ht
In this example, we show how easy it is to work with geospatial data and to sample small image patches from a combination of [Landsat](https://www.usgs.gov/landsat-missions) and [Cropland Data Layer (CDL)](https://data.nal.usda.gov/dataset/cropscape-cropland-data-layer) data using TorchGeo. First, we assume that the user has Landsat 7 and 8 imagery downloaded. Since Landsat 8 has more spectral bands than Landsat 7, we'll only use the bands that both satellites have in common. We'll create a single dataset including all images from both Landsat 7 and 8 data by taking the union between these two datasets.
```python
-landsat7 = Landsat7(root="...")
-landsat8 = Landsat8(root="...", bands=Landsat8.all_bands[1:-2])
+landsat7 = Landsat7(root="...", bands=["B1", ..., "B7"])
+landsat8 = Landsat8(root="...", bands=["B2", ..., "B8"])
landsat = landsat7 | landsat8
```
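
The paragraph above also mentions sampling small image patches from the combined Landsat and CDL data; that part of the README lies outside this hunk. Below is a minimal sketch of that next step, using only the classes imported earlier (`CDL`, `RandomGeoSampler`, `stack_samples`, `DataLoader`); the root path, patch size, sample count, and batch size are placeholder values, not text from this commit:

```python
# Sketch only: pair the Landsat mosaic with CDL labels and sample patches.
# The "..." root and the size/length/batch_size values are placeholders.
cdl = CDL(root="...", download=True, checksum=True)

# The & operator forms the intersection, so samples are only drawn where
# Landsat imagery and CDL labels overlap in space and time.
dataset = landsat & cdl

# Randomly sample fixed-size patches from the combined dataset.
sampler = RandomGeoSampler(dataset, size=256, length=10000)
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler, collate_fn=stack_samples)

for batch in dataloader:
    image = batch["image"]  # Landsat bands
    target = batch["mask"]  # CDL land cover labels
```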
@@ -124,8 +125,18 @@ In order to facilitate direct comparisons between results published in the liter
|
```python
datamodule = InriaAerialImageLabelingDataModule(root="...", batch_size=64, num_workers=6)
-task = SemanticSegmentationTask(segmentation_model="unet", encoder_weights="imagenet", learning_rate=0.1)
-trainer = Trainer(gpus=1, default_root_dir="...")
+task = SemanticSegmentationTask(
+    model="unet",
+    backbone="resnet50",
+    weights="imagenet",
+    in_channels=3,
+    num_classes=2,
+    loss="ce",
+    ignore_index=None,
+    learning_rate=0.1,
+    learning_rate_schedule_patience=6,
+)
+trainer = Trainer(default_root_dir="...")
|
trainer.fit(model=task, datamodule=datamodule)
```
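
A side effect of this hunk is that the `gpus=1` argument disappears from the `Trainer` call; Lightning 2.x selects hardware with `accelerator` and `devices` instead. The following is a hedged sketch of GPU training plus a follow-up evaluation under that API, reusing the `task` and `datamodule` objects from the snippet above; it is not part of this commit, and `default_root_dir` remains a placeholder:

```python
# Sketch only: Lightning 2.x replaces the removed gpus=1 flag with
# accelerator/devices; the default_root_dir path is a placeholder.
trainer = Trainer(accelerator="gpu", devices=1, default_root_dir="...")
trainer.fit(model=task, datamodule=datamodule)

# Evaluate the fitted task on the datamodule's test split.
trainer.test(model=task, datamodule=datamodule)
```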