* Add info

* Address comments
This commit is contained in:
Robin Cole 2023-10-06 15:04:56 +01:00 коммит произвёл Nils Lehmann
Родитель fb706b0b78
Коммит 77940a137d
1 изменённых файлов: 56 добавлений и 2 удалений

Просмотреть файл

@ -146,12 +146,66 @@ trainer.fit(model=task, datamodule=datamodule)
<img src="https://raw.githubusercontent.com/microsoft/torchgeo/main/images/inria.png" alt="Building segmentations produced by a U-Net model trained on the Inria Aerial Image Labeling dataset"/>
In our GitHub repo, we provide `train.py` and `evaluate.py` scripts to train and evaluate the performance of models using these datamodules and trainers. These scripts are configurable via the command line and/or via YAML configuration files. See the [conf](https://github.com/microsoft/torchgeo/blob/main/conf) directory for example configuration files that can be customized for different training runs.
TorchGeo also supports command-line interface training using [LightningCLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html). It can be invoked in two ways:
```console
$ python train.py config_file=conf/landcoverai.yaml
# If torchgeo has been installed
torchgeo
# If torchgeo has been installed, or if it has been cloned to the current directory
python3 -m torchgeo
```
It supports command-line configuration or YAML/JSON config files. Valid options can be found from the help messages:
```console
# See valid stages
torchgeo --help
# See valid trainer options
torchgeo fit --help
# See valid model options
torchgeo fit --model.help ClassificationTask
# See valid data options
torchgeo fit --data.help EuroSAT100DataModule
```
Using the following config file:
```yaml
trainer:
max_epochs: 20
model:
class_path: ClassificationTask
init_args:
model: "resnet18"
in_channels: 13
num_classes: 10
data:
class_path: EuroSAT100DataModule
init_args:
batch_size: 8
dict_kwargs:
download: true
```
we can see the script in action:
```console
# Train and validate a model
torchgeo fit --config config.yaml
# Validate-only
torchgeo validate --config config.yaml
# Calculate and report test accuracy
torchgeo test --config config.yaml --trainer.ckpt_path=...
```
It can also be imported and used in a Python script if you need to extend it to add new features:
```python
from torchgeo.main import main
main(["fit", "--config", "config.yaml"])
```
See the [Lightning documentation](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html) for more details.
## Citation
If you use this software in your work, please cite our [paper](https://dl.acm.org/doi/10.1145/3557915.3560953):