diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index a3ed9b83c..a322036de 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -17,9 +17,19 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: 3.9
+ - name: Install apt dependencies
+ run: |
+ sudo apt-add-repository -y ppa:ubuntugis/ubuntugis-unstable
+ sudo apt-get update
+ sudo apt-get install -y gdal-bin libgdal-dev
- name: Install pip dependencies
- run: pip install .[tests]
+ run: |
+ pip install gdal tqdm # TODO: these deps shouldn't be needed
+ pip install .[datasets,tests,train]
+ pip install -r docs/requirements.txt
- name: Run notebook checks
+ env:
+ MLHUB_API_KEY: ${{ secrets.MLHUB_API_KEY }}
run: pytest --nbmake docs/tutorials
integration:
name: integration
@@ -32,6 +42,6 @@ jobs:
with:
python-version: 3.9
- name: Install pip dependencies
- run: pip install .[tests]
+ run: pip install .[datasets,tests,train]
- name: Run integration checks
run: pytest -m slow
diff --git a/README.md b/README.md
index 5d6950706..80adc5ea8 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-
+
TorchGeo is a [PyTorch](https://pytorch.org/) domain library, similar to [torchvision](https://pytorch.org/vision), that provides datasets, transforms, samplers, and pre-trained models specific to geospatial data.
@@ -23,7 +23,7 @@ Tests:
The recommended way to install TorchGeo is with [pip](https://pip.pypa.io/):
```console
-$ pip install git+https://github.com/microsoft/torchgeo.git
+$ pip install torchgeo
```
For [conda](https://docs.conda.io/) and [spack](https://spack.io/) installation instructions, see the [documentation](https://torchgeo.readthedocs.io/en/latest/user/installation.html).
diff --git a/conf/task_defaults/bigearthnet.yaml b/conf/task_defaults/bigearthnet.yaml
index 745fb41a5..c5c352861 100644
--- a/conf/task_defaults/bigearthnet.yaml
+++ b/conf/task_defaults/bigearthnet.yaml
@@ -9,8 +9,8 @@ experiment:
in_channels: 14
num_classes: 19
datamodule:
- num_classes: 19
- batch_size: 128
- num_workers: 6
+ root_dir: "tests/data/bigearthnet"
bands: "all"
- num_classes: 19
+ num_classes: ${experiment.module.num_classes}
+ batch_size: 128
+ num_workers: 0
diff --git a/conf/task_defaults/byol.yaml b/conf/task_defaults/byol.yaml
index 2a4aa84bd..95066aa28 100644
--- a/conf/task_defaults/byol.yaml
+++ b/conf/task_defaults/byol.yaml
@@ -9,11 +9,12 @@ experiment:
learning_rate: 1e-3
learning_rate_schedule_patience: 6
datamodule:
- batch_size: 64
- num_workers: 6
+ root_dir: "tests/data/chesapeake/cvpr"
train_splits:
- - "de-train"
+ - "de-test"
val_splits:
- - "de-val"
+ - "de-test"
test_splits:
- "de-test"
+ batch_size: 64
+ num_workers: 0
diff --git a/conf/task_defaults/chesapeake_cvpr.yaml b/conf/task_defaults/chesapeake_cvpr.yaml
index 1087c96f9..6cb0b7a90 100644
--- a/conf/task_defaults/chesapeake_cvpr.yaml
+++ b/conf/task_defaults/chesapeake_cvpr.yaml
@@ -12,9 +12,15 @@ experiment:
num_classes: 7
num_filters: 256
datamodule:
- train_state: "de"
+ root_dir: "tests/data/chesapeake/cvpr"
+ train_splits:
+ - "de-test"
+ val_splits:
+ - "de-test"
+ test_splits:
+ - "de-test"
patches_per_tile: 200
patch_size: 256
batch_size: 64
- num_workers: 4
- num_classes: ${experiment.module.num_classes}
+ num_workers: 0
+ class_set: ${experiment.module.num_classes}
diff --git a/conf/task_defaults/cowc_counting.yaml b/conf/task_defaults/cowc_counting.yaml
index b99d20728..a0d43a234 100644
--- a/conf/task_defaults/cowc_counting.yaml
+++ b/conf/task_defaults/cowc_counting.yaml
@@ -5,5 +5,6 @@ experiment:
learning_rate: 1e-3
learning_rate_schedule_patience: 2
datamodule:
+ root_dir: "tests/data/cowc_counting"
batch_size: 32
- num_workers: 4
+ num_workers: 0
diff --git a/conf/task_defaults/cyclone.yaml b/conf/task_defaults/cyclone.yaml
index 1b50c8f25..c42cdb4c2 100644
--- a/conf/task_defaults/cyclone.yaml
+++ b/conf/task_defaults/cyclone.yaml
@@ -5,5 +5,6 @@ experiment:
learning_rate: 1e-3
learning_rate_schedule_patience: 2
datamodule:
+ root_dir: "tests/data/cyclone"
batch_size: 32
- num_workers: 4
+ num_workers: 0
diff --git a/conf/task_defaults/landcoverai.yaml b/conf/task_defaults/landcoverai.yaml
index d3efd7bc6..d471df596 100644
--- a/conf/task_defaults/landcoverai.yaml
+++ b/conf/task_defaults/landcoverai.yaml
@@ -12,5 +12,6 @@ experiment:
num_classes: 6
num_filters: 256
datamodule:
+ root_dir: "tests/data/landcoverai"
batch_size: 32
- num_workers: 4
+ num_workers: 0
diff --git a/conf/task_defaults/naipchesapeake.yaml b/conf/task_defaults/naipchesapeake.yaml
index f793fd56d..1720a9bef 100644
--- a/conf/task_defaults/naipchesapeake.yaml
+++ b/conf/task_defaults/naipchesapeake.yaml
@@ -12,5 +12,8 @@ experiment:
num_classes: 13
num_filters: 64
datamodule:
+ naip_root_dir: "tests/data/naip"
+ chesapeake_root_dir: "tests/data/chesapeake/BAYWIDE"
batch_size: 32
- num_workers: 4
+ num_workers: 0
+ patch_size: 32
diff --git a/conf/task_defaults/resisc45.yaml b/conf/task_defaults/resisc45.yaml
index 1f2089f78..a657cc5d0 100644
--- a/conf/task_defaults/resisc45.yaml
+++ b/conf/task_defaults/resisc45.yaml
@@ -9,9 +9,6 @@ experiment:
in_channels: 3
num_classes: 45
datamodule:
+ root_dir: "tests/data/resisc45"
batch_size: 128
- num_workers: 6
- weights: ${experiment.module.weights}
- unsupervised_mode: false
- val_split_pct: 0.2
- test_split_pct: 0.2
+ num_workers: 0
diff --git a/conf/task_defaults/sen12ms.yaml b/conf/task_defaults/sen12ms.yaml
index bd4ceb029..c0a36ac0d 100644
--- a/conf/task_defaults/sen12ms.yaml
+++ b/conf/task_defaults/sen12ms.yaml
@@ -11,5 +11,6 @@ experiment:
in_channels: 15
num_classes: 11
datamodule:
+ root_dir: "tests/data/sen12ms"
batch_size: 32
- num_workers: 4
+ num_workers: 0
diff --git a/conf/task_defaults/so2sat.yaml b/conf/task_defaults/so2sat.yaml
index ef72dc9de..be6a4e0f5 100644
--- a/conf/task_defaults/so2sat.yaml
+++ b/conf/task_defaults/so2sat.yaml
@@ -9,6 +9,7 @@ experiment:
in_channels: 3
num_classes: 17
datamodule:
+ root_dir: "tests/data/so2sat"
batch_size: 128
- num_workers: 6
+ num_workers: 0
bands: "rgb"
diff --git a/conf/task_defaults/ucmerced.yaml b/conf/task_defaults/ucmerced.yaml
index 993e36759..2742fa329 100644
--- a/conf/task_defaults/ucmerced.yaml
+++ b/conf/task_defaults/ucmerced.yaml
@@ -3,14 +3,12 @@ experiment:
module:
loss: "ce"
classification_model: "resnet18"
- weights: null
+ weights: "random"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
num_classes: 21
datamodule:
+ root_dir: "tests/data/ucmerced"
batch_size: 128
- num_workers: 6
- unsupervised_mode: false
- val_split_pct: 0.1
- test_split_pct: 0.1
+ num_workers: 0
diff --git a/docs/tutorials/benchmarking.ipynb b/docs/tutorials/benchmarking.ipynb
index d3bf3aa32..20befc7c1 100644
--- a/docs/tutorials/benchmarking.ipynb
+++ b/docs/tutorials/benchmarking.ipynb
@@ -37,7 +37,7 @@
"metadata": {},
"outputs": [],
"source": [
- "%pip install git+https://github.com/microsoft/torchgeo.git"
+ "%pip install torchgeo"
]
},
{
diff --git a/docs/tutorials/getting_started.ipynb b/docs/tutorials/getting_started.ipynb
index 347a1f495..fc25832c4 100644
--- a/docs/tutorials/getting_started.ipynb
+++ b/docs/tutorials/getting_started.ipynb
@@ -43,7 +43,7 @@
"metadata": {},
"outputs": [],
"source": [
- "%pip install git+https://github.com/microsoft/torchgeo.git"
+ "%pip install torchgeo"
]
},
{
diff --git a/docs/tutorials/indices.ipynb b/docs/tutorials/indices.ipynb
index 96d2a0c1b..fd710d394 100644
--- a/docs/tutorials/indices.ipynb
+++ b/docs/tutorials/indices.ipynb
@@ -23,7 +23,6 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
"version": "3.8.11"
}
},
@@ -86,7 +85,7 @@
"id": "wOwsb8KT_uXR"
},
"source": [
- "%pip install git+https://github.com/microsoft/torchgeo.git"
+ "%pip install torchgeo"
],
"execution_count": null,
"outputs": []
@@ -117,7 +116,6 @@
"import rasterio\n",
"import rasterio.features\n",
"import shapely\n",
- "import tifffile\n",
"import torch\n",
"import torch.nn as nn\n",
"import torchvision.transforms as T\n",
@@ -194,9 +192,9 @@
" compress=\"DEFLATE\"\n",
" ))\n",
"\n",
- " # Write to geotiff\n",
- " with rasterio.open(path, \"w\", **metadata) as ds_windowed:\n",
- " ds_windowed.write(ds.read(1, window=window), 1)\n",
+ " # Write to geotiff\n",
+ " with rasterio.open(path, \"w\", **metadata) as ds_windowed:\n",
+ " ds_windowed.write(ds.read(1, window=window), 1)\n",
"\n",
"def download(root: str, url: str, bands: List[str], geometry: shapely.geometry.Polygon) -> None:\n",
" \"\"\"Extract windows from each band COG file in s3 and save locally.\"\"\"\n",
diff --git a/docs/tutorials/trainers.ipynb b/docs/tutorials/trainers.ipynb
index 790c7495e..4359f0241 100644
--- a/docs/tutorials/trainers.ipynb
+++ b/docs/tutorials/trainers.ipynb
@@ -43,7 +43,7 @@
"metadata": {},
"outputs": [],
"source": [
- "%pip install git+https://github.com/microsoft/torchgeo.git"
+ "%pip install torchgeo"
]
},
{
@@ -82,6 +82,18 @@
"from torchgeo.trainers import RegressionTask"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e8bc0d83",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Flag indicating whether this notebook is being run by pytest (which sets PYTEST_CURRENT_TEST).\n",
+ "# When set, training is cut short via fast_dev_run and the metrics parsing/plotting cells are skipped.\n",
+ "in_tests = (\"PYTEST_CURRENT_TEST\" in os.environ)"
+ ]
+ },
{
"cell_type": "markdown",
"id": "e6e1d9b6",
@@ -226,12 +238,12 @@
],
"source": [
"trainer = pl.Trainer(\n",
- " gpus=1,\n",
" callbacks=[checkpoint_callback, early_stopping_callback],\n",
" logger=[csv_logger],\n",
" default_root_dir=experiment_dir,\n",
" min_epochs=1,\n",
- " max_epochs=10\n",
+ " max_epochs=10,\n",
+ " fast_dev_run=in_tests\n",
")"
]
},
@@ -469,25 +481,26 @@
"metadata": {},
"outputs": [],
"source": [
- "train_steps = []\n",
- "train_rmse = []\n",
+ "if not in_tests:\n",
+ " train_steps = []\n",
+ " train_rmse = []\n",
"\n",
- "val_steps = []\n",
- "val_rmse = []\n",
- "with open(os.path.join(experiment_dir, \"tutorial_logs\", \"version_0\", \"metrics.csv\"), \"r\") as f:\n",
- " csv_reader = csv.DictReader(f, delimiter=',')\n",
- " for i, row in enumerate(csv_reader):\n",
- " try:\n",
- " train_rmse.append(float(row[\"train_rmse\"]))\n",
- " train_steps.append(i)\n",
- " except ValueError: # Ignore rows where train RMSE is empty\n",
- " pass\n",
- " \n",
- " try:\n",
- " val_rmse.append(float(row[\"val_rmse\"]))\n",
- " val_steps.append(i)\n",
- " except ValueError: # Ignore rows where val RMSE is empty\n",
- " pass"
+ " val_steps = []\n",
+ " val_rmse = []\n",
+ " with open(os.path.join(experiment_dir, \"tutorial_logs\", \"version_0\", \"metrics.csv\"), \"r\") as f:\n",
+ " csv_reader = csv.DictReader(f, delimiter=',')\n",
+ " for i, row in enumerate(csv_reader):\n",
+ " try:\n",
+ " train_rmse.append(float(row[\"train_rmse\"]))\n",
+ " train_steps.append(i)\n",
+ " except ValueError: # Ignore rows where train RMSE is empty\n",
+ " pass\n",
+ "\n",
+ " try:\n",
+ " val_rmse.append(float(row[\"val_rmse\"]))\n",
+ " val_steps.append(i)\n",
+ " except ValueError: # Ignore rows where val RMSE is empty\n",
+ " pass"
]
},
{
@@ -510,14 +523,15 @@
}
],
"source": [
- "plt.figure()\n",
- "plt.plot(train_steps, train_rmse, label=\"Train RMSE\")\n",
- "plt.plot(val_steps, val_rmse, label=\"Validation RMSE\")\n",
- "plt.legend(fontsize=15)\n",
- "plt.xlabel(\"Batches\", fontsize=15)\n",
- "plt.ylabel(\"RMSE\", fontsize=15)\n",
- "plt.show()\n",
- "plt.close()"
+ "if not in_tests:\n",
+ " plt.figure()\n",
+ " plt.plot(train_steps, train_rmse, label=\"Train RMSE\")\n",
+ " plt.plot(val_steps, val_rmse, label=\"Validation RMSE\")\n",
+ " plt.legend(fontsize=15)\n",
+ " plt.xlabel(\"Batches\", fontsize=15)\n",
+ " plt.ylabel(\"RMSE\", fontsize=15)\n",
+ " plt.show()\n",
+ " plt.close()"
]
},
{
diff --git a/docs/tutorials/transforms.ipynb b/docs/tutorials/transforms.ipynb
index e7685a3d8..7e347c02d 100644
--- a/docs/tutorials/transforms.ipynb
+++ b/docs/tutorials/transforms.ipynb
@@ -23,7 +23,6 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
"version": "3.8.11"
},
"accelerator": "GPU"
@@ -84,7 +83,7 @@
"id": "wOwsb8KT_uXR"
},
"source": [
- "%pip install git+https://github.com/microsoft/torchgeo.git"
+ "%pip install torchgeo"
],
"execution_count": null,
"outputs": []
@@ -538,7 +537,7 @@
"outputId": "fa0443da-8b4d-47f7-e713-93e1e4976e87"
},
"source": [
- "!nvidia-smi"
+ "!nvidia-smi"
],
"execution_count": 8,
"outputs": [
@@ -724,4 +723,4 @@
"outputs": []
}
]
-}
\ No newline at end of file
+}
diff --git a/docs/user/installation.rst b/docs/user/installation.rst
index 5c89a1a28..6b23ba7d7 100644
--- a/docs/user/installation.rst
+++ b/docs/user/installation.rst
@@ -1,17 +1,16 @@
Installation
============
-TorchGeo is simple and easy to install. We support installation using the `pip `_, `conda `_, and `spack `_ package managers, although you can also install from source if you want to.
+TorchGeo is simple and easy to install. We support installation using the `pip <https://pip.pypa.io/>`_, `conda <https://docs.conda.io/>`_, and `spack <https://spack.io/>`_ package managers.
pip
---
-..
- Since TorchGeo is written in pure-Python, the easiest way to install it is using pip:
+Since TorchGeo is written in pure-Python, the easiest way to install it is using pip:
- .. code-block:: console
+.. code-block:: console
- $ pip install torchgeo
+ $ pip install torchgeo
If you want to install a development version, you can use a VCS project URL:
@@ -32,16 +31,10 @@ or a local git checkout:
By default, only required dependencies are installed. TorchGeo has a number of optional dependencies for specific datasets or development. These can be installed with a comma-separated list:
-..
- .. code-block:: console
-
- $ pip install torchgeo[datasets]
- $ pip install torchgeo[style,tests]
-
.. code-block:: console
- $ pip install .[datasets]
- $ pip install .[style,tests]
+ $ pip install torchgeo[datasets]
+ $ pip install torchgeo[style,tests]
See the ``setup.cfg`` for a complete list of options. See the `pip documentation `_ for more details.
@@ -57,12 +50,11 @@ If you need to install non-Python dependencies like PyTorch, it's better to use
$ conda config --set channel_priority strict
-..
- Now, you can install the latest stable release using:
+Now, you can install the latest stable release using:
- .. code-block:: console
+.. code-block:: console
- $ conda install torchgeo
+ $ conda install torchgeo
Conda does not directly support installing development versions, but you can use conda to install our dependencies, then use pip to install TorchGeo itself.
@@ -110,15 +102,3 @@ Optional dependencies can be installed by enabling build variants:
Run ``spack info py-torchgeo`` for a complete list of variants.
See the `spack documentation `_ for more details.
-
-source
-------
-
-TorchGeo can also be installed from source using the ``setup.py`` file and setuptools.
-
-.. code-block:: console
-
- $ git clone https://github.com/microsoft/torchgeo.git
- $ cd torchgeo
- $ python setup.py build
- $ python setup.py install
diff --git a/environment.yml b/environment.yml
index 263a748a2..aef403329 100644
--- a/environment.yml
+++ b/environment.yml
@@ -39,7 +39,7 @@ dependencies:
- scikit-learn>=0.18
- scipy>=0.9
- segmentation-models-pytorch>=0.2
- - setuptools>=30.4
+ - setuptools>=42
- sphinx>=3
- timm>=0.2.1
- torchmetrics
diff --git a/pyproject.toml b/pyproject.toml
index 1dd8029e0..f98f5c8ad 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[build-system]
requires = [
- "setuptools>=30.4",
+ "setuptools>=42",
"wheel",
]
build-backend = "setuptools.build_meta"
diff --git a/setup.cfg b/setup.cfg
index 715e19667..1636a342d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -2,14 +2,15 @@
[metadata]
name = torchgeo
version = attr: torchgeo.__version__
-author = attr: torchgeo.__author__
+author = Adam J. Stewart
author_email = ajstewart426@gmail.com
description = TorchGeo: datasets, transforms, and models for geospatial data
long_description = file: README.md
long_description_content_type = text/markdown
url = https://github.com/microsoft/torchgeo
+license_files = LICENSE
classifiers =
- Development Status :: 1 - Planning
+ Development Status :: 3 - Alpha
Intended Audience :: Science/Research
Programming Language :: Python :: 3
Programming Language :: Python :: 3.6
@@ -20,12 +21,12 @@ classifiers =
Operating System :: OS Independent
Topic :: Scientific/Engineering :: Artificial Intelligence
Topic :: Scientific/Engineering :: GIS
-keywords = pytorch, deep learning, machine learning
+keywords = pytorch, deep learning, machine learning, remote sensing, satellite imagery, geospatial
[options]
setup_requires =
- # setuptools 30.4+ required for options.packages.find section in setup.cfg
- setuptools>=30.4
+ # setuptools 42+ required for metadata.license_files support in setup.cfg
+ setuptools>=42
install_requires =
einops
# fiona 1.5+ required for fiona.transform module
diff --git a/tests/test_setup.py b/tests/test_setup.py
deleted file mode 100644
index ed104721f..000000000
--- a/tests/test_setup.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-
-import subprocess
-import sys
-from pathlib import Path
-
-import pytest
-
-pytestmark = pytest.mark.slow
-
-
-def test_install(tmp_path: Path) -> None:
- subprocess.run(
- [sys.executable, "setup.py", "build", "--build-base", str(tmp_path)], check=True
- )
diff --git a/tests/test_train.py b/tests/test_train.py
index ee7b2f6bb..786dd6ad9 100644
--- a/tests/test_train.py
+++ b/tests/test_train.py
@@ -134,20 +134,34 @@ trainer:
subprocess.run(args, check=True)
-@pytest.mark.parametrize("task", ["cowc_counting", "cyclone", "sen12ms", "landcoverai"])
+@pytest.mark.parametrize(
+ "task",
+ [
+ "bigearthnet",
+ "byol",
+ "chesapeake_cvpr",
+ "cowc_counting",
+ "cyclone",
+ "landcoverai",
+ "naipchesapeake",
+ "resisc45",
+ "sen12ms",
+ "so2sat",
+ "ucmerced",
+ ],
+)
def test_tasks(task: str, tmp_path: Path) -> None:
output_dir = tmp_path / "output"
- data_dir = os.path.join("tests", "data", task)
log_dir = tmp_path / "logs"
args = [
sys.executable,
"train.py",
"experiment.name=test",
"program.output_dir=" + str(output_dir),
- "program.data_dir=" + data_dir,
"program.log_dir=" + str(log_dir),
"trainer.fast_dev_run=1",
"experiment.task=" + task,
"program.overwrite=True",
+ "config_file=" + os.path.join("conf", "task_defaults", task + ".yaml"),
]
subprocess.run(args, check=True)
diff --git a/tests/trainers/test_landcoverai.py b/tests/trainers/test_landcoverai.py
index ed8d4448a..b14d5b466 100644
--- a/tests/trainers/test_landcoverai.py
+++ b/tests/trainers/test_landcoverai.py
@@ -58,3 +58,10 @@ class TestLandCoverAISegmentationTask:
batch = next(iter(datamodule.val_dataloader()))
task.validation_step(batch, 0)
task.validation_epoch_end(0)
+
+ def test_test(
+ self, datamodule: LandCoverAIDataModule, task: LandCoverAISegmentationTask
+ ) -> None:
+ batch = next(iter(datamodule.test_dataloader()))
+ task.test_step(batch, 0)
+ task.test_epoch_end(0)
diff --git a/torchgeo/__init__.py b/torchgeo/__init__.py
index dd5c053b5..5da98111a 100644
--- a/torchgeo/__init__.py
+++ b/torchgeo/__init__.py
@@ -11,4 +11,4 @@ common image transformations for geospatial data.
"""
__author__ = "Adam J. Stewart"
-__version__ = "0.1.0.dev0"
+__version__ = "0.1.0"
diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py
index fc09ee71c..1f629ad8e 100644
--- a/torchgeo/datasets/bigearthnet.py
+++ b/torchgeo/datasets/bigearthnet.py
@@ -582,7 +582,7 @@ class BigEarthNetDataModule(pl.LightningDataModule):
bands: str = "all",
num_classes: int = 19,
batch_size: int = 64,
- num_workers: int = 4,
+ num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for BigEarthNet based DataLoaders.
diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py
index bd96a739f..79f28bd4a 100644
--- a/torchgeo/datasets/chesapeake.py
+++ b/torchgeo/datasets/chesapeake.py
@@ -553,7 +553,7 @@ class ChesapeakeCVPRDataModule(LightningDataModule):
patches_per_tile: int = 200,
patch_size: int = 256,
batch_size: int = 64,
- num_workers: int = 4,
+ num_workers: int = 0,
class_set: int = 7,
**kwargs: Any,
) -> None:
@@ -711,7 +711,7 @@ class ChesapeakeCVPRDataModule(LightningDataModule):
splits=self.train_splits,
layers=self.layers,
transforms=None,
- download=True,
+ download=False,
checksum=False,
)
diff --git a/torchgeo/datasets/cowc.py b/torchgeo/datasets/cowc.py
index de993861b..35bbdc54b 100644
--- a/torchgeo/datasets/cowc.py
+++ b/torchgeo/datasets/cowc.py
@@ -278,7 +278,7 @@ class COWCCountingDataModule(pl.LightningDataModule):
root_dir: str,
seed: int,
batch_size: int = 64,
- num_workers: int = 4,
+ num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for COWC Counting based DataLoaders.
@@ -314,7 +314,7 @@ class COWCCountingDataModule(pl.LightningDataModule):
This includes optionally downloading the dataset. This is done once per node,
while :func:`setup` is done once per GPU.
"""
- COWCCounting(self.root_dir, download=True)
+ COWCCounting(self.root_dir, download=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Create the train/val/test splits based on the original Dataset objects.
diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py
index 357b314b3..487677e83 100644
--- a/torchgeo/datasets/cyclone.py
+++ b/torchgeo/datasets/cyclone.py
@@ -227,7 +227,7 @@ class CycloneDataModule(pl.LightningDataModule):
root_dir: str,
seed: int,
batch_size: int = 64,
- num_workers: int = 4,
+ num_workers: int = 0,
api_key: Optional[str] = None,
**kwargs: Any,
) -> None:
diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py
index 2e9b58e77..771e87f5c 100644
--- a/torchgeo/datasets/landcoverai.py
+++ b/torchgeo/datasets/landcoverai.py
@@ -215,7 +215,7 @@ class LandCoverAIDataModule(pl.LightningDataModule):
"""
def __init__(
- self, root_dir: str, batch_size: int = 64, num_workers: int = 4, **kwargs: Any
+ self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
"""Initialize a LightningDataModule for LandCover.ai based DataLoaders.
@@ -250,7 +250,7 @@ class LandCoverAIDataModule(pl.LightningDataModule):
This method is only called once per run.
"""
- _ = LandCoverAI(self.root_dir, download=True, checksum=False)
+ _ = LandCoverAI(self.root_dir, download=False, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
diff --git a/torchgeo/datasets/naip.py b/torchgeo/datasets/naip.py
index 44bdac951..ef299988a 100644
--- a/torchgeo/datasets/naip.py
+++ b/torchgeo/datasets/naip.py
@@ -64,7 +64,6 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
"""
# TODO: tune these hyperparams
- patch_size = 256
length = 1000
stride = 128
@@ -73,7 +72,8 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
naip_root_dir: str,
chesapeake_root_dir: str,
batch_size: int = 64,
- num_workers: int = 4,
+ num_workers: int = 0,
+ patch_size: int = 256,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for NAIP and Chesapeake based DataLoaders.
@@ -83,12 +83,14 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
chesapeake_root_dir: directory containing Chesapeake data
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
+ patch_size: size of patches to sample
"""
super().__init__() # type: ignore[no-untyped-call]
self.naip_root_dir = naip_root_dir
self.chesapeake_root_dir = chesapeake_root_dir
self.batch_size = batch_size
self.num_workers = num_workers
+ self.patch_size = patch_size
def naip_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the NAIP Dataset.
@@ -120,7 +122,7 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
This method is only called once per run.
"""
- Chesapeake13(self.chesapeake_root_dir, download=True, checksum=False)
+ Chesapeake13(self.chesapeake_root_dir, download=False, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
diff --git a/torchgeo/datasets/resisc45.py b/torchgeo/datasets/resisc45.py
index 5b6cc1cce..5a68715a8 100644
--- a/torchgeo/datasets/resisc45.py
+++ b/torchgeo/datasets/resisc45.py
@@ -219,7 +219,7 @@ class RESISC45DataModule(pl.LightningDataModule):
self,
root_dir: str,
batch_size: int = 64,
- num_workers: int = 4,
+ num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for RESISC45 based DataLoaders.
diff --git a/torchgeo/datasets/sen12ms.py b/torchgeo/datasets/sen12ms.py
index 5dfdeed48..f1cf8b2ad 100644
--- a/torchgeo/datasets/sen12ms.py
+++ b/torchgeo/datasets/sen12ms.py
@@ -289,7 +289,7 @@ class SEN12MSDataModule(pl.LightningDataModule):
seed: int,
band_set: str = "all",
batch_size: int = 64,
- num_workers: int = 4,
+ num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for SEN12MS based DataLoaders.
diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py
index 46b907c20..d0e03cef2 100644
--- a/torchgeo/datasets/so2sat.py
+++ b/torchgeo/datasets/so2sat.py
@@ -243,7 +243,7 @@ class So2SatDataModule(pl.LightningDataModule):
self,
root_dir: str,
batch_size: int = 64,
- num_workers: int = 4,
+ num_workers: int = 0,
bands: str = "rgb",
unsupervised_mode: bool = False,
**kwargs: Any,
diff --git a/torchgeo/datasets/ucmerced.py b/torchgeo/datasets/ucmerced.py
index 853e3df20..5a07c24b8 100644
--- a/torchgeo/datasets/ucmerced.py
+++ b/torchgeo/datasets/ucmerced.py
@@ -225,7 +225,7 @@ class UCMercedDataModule(pl.LightningDataModule):
self,
root_dir: str,
batch_size: int = 64,
- num_workers: int = 4,
+ num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for UCMerced based DataLoaders.
@@ -266,7 +266,7 @@ class UCMercedDataModule(pl.LightningDataModule):
This method is only called once per run.
"""
- UCMerced(self.root_dir, download=True, checksum=False)
+ UCMerced(self.root_dir, download=False, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py
index 23690cd29..fb3565a78 100644
--- a/torchgeo/trainers/byol.py
+++ b/torchgeo/trainers/byol.py
@@ -451,3 +451,6 @@ class BYOLTask(LightningModule):
)
self.log("val_loss", loss, on_step=False, on_epoch=True)
+
+ def test_step(self, *args: Any) -> None: # type: ignore[override]
+ """No-op, does nothing."""
diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py
index c0b447b74..f5494cabb 100644
--- a/torchgeo/trainers/classification.py
+++ b/torchgeo/trainers/classification.py
@@ -34,7 +34,7 @@ class ClassificationTask(pl.LightningModule):
imagenet_pretrained = False
custom_pretrained = False
- if not os.path.exists(self.hparams["weights"]):
+ if self.hparams["weights"] and not os.path.exists(self.hparams["weights"]):
if self.hparams["weights"] == "imagenet":
imagenet_pretrained = True
elif self.hparams["weights"] == "random":
diff --git a/torchgeo/trainers/landcoverai.py b/torchgeo/trainers/landcoverai.py
index 85d4fd813..c95f03135 100644
--- a/torchgeo/trainers/landcoverai.py
+++ b/torchgeo/trainers/landcoverai.py
@@ -109,3 +109,23 @@ class LandCoverAISegmentationTask(SemanticSegmentationTask):
)
plt.close()
+
+ def test_step( # type: ignore[override]
+ self, batch: Dict[str, Any], batch_idx: int
+ ) -> None:
+ """Test step identical to the validation step.
+
+ Args:
+ batch: Current batch
+ batch_idx: Index of current batch
+ """
+ x = batch["image"]
+ y = batch["mask"].long().squeeze()
+ y_hat = self.forward(x)
+ y_hat_hard = y_hat.argmax(dim=1)
+
+ loss = self.loss(y_hat, y)
+
+ # by default, the test and validation steps only log per *epoch*
+ self.log("test_loss", loss, on_step=False, on_epoch=True)
+ self.test_metrics(y_hat_hard, y)
diff --git a/torchgeo/trainers/so2sat.py b/torchgeo/trainers/so2sat.py
index 3bda62acb..6de070851 100644
--- a/torchgeo/trainers/so2sat.py
+++ b/torchgeo/trainers/so2sat.py
@@ -27,7 +27,7 @@ class So2SatClassificationTask(ClassificationTask):
in_channels = self.hparams["in_channels"]
pretrained = False
- if not os.path.exists(self.hparams["weights"]):
+ if self.hparams["weights"] and not os.path.exists(self.hparams["weights"]):
if self.hparams["weights"] == "imagenet":
pretrained = True
elif self.hparams["weights"] == "random":