diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index a3ed9b83c..a322036de 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -17,9 +17,19 @@ jobs: uses: actions/setup-python@v2 with: python-version: 3.9 + - name: Install apt dependencies + run: | + sudo apt-add-repository ppa:ubuntugis/ubuntugis-unstable + sudo apt-get update + sudo apt-get install gdal-bin libgdal-dev - name: Install pip dependencies - run: pip install .[tests] + run: | + pip install gdal tqdm # TODO: these deps shouldn't be needed + pip install .[datasets,tests,train] + pip install -r docs/requirements.txt - name: Run notebook checks + env: + MLHUB_API_KEY: ${{ secrets.MLHUB_API_KEY }} run: pytest --nbmake docs/tutorials integration: name: integration @@ -32,6 +42,6 @@ jobs: with: python-version: 3.9 - name: Install pip dependencies - run: pip install .[tests] + run: pip install .[datasets,tests,train] - name: Run integration checks run: pytest -m slow diff --git a/README.md b/README.md index 5d6950706..80adc5ea8 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -TorchGeo +TorchGeo TorchGeo is a [PyTorch](https://pytorch.org/) domain library, similar to [torchvision](https://pytorch.org/vision), that provides datasets, transforms, samplers, and pre-trained models specific to geospatial data. @@ -23,7 +23,7 @@ Tests: The recommended way to install TorchGeo is with [pip](https://pip.pypa.io/): ```console -$ pip install git+https://github.com/microsoft/torchgeo.git +$ pip install torchgeo ``` For [conda](https://docs.conda.io/) and [spack](https://spack.io/) installation instructions, see the [documentation](https://torchgeo.readthedocs.io/en/latest/user/installation.html). diff --git a/conf/task_defaults/bigearthnet.yaml b/conf/task_defaults/bigearthnet.yaml index 745fb41a5..c5c352861 100644 --- a/conf/task_defaults/bigearthnet.yaml +++ b/conf/task_defaults/bigearthnet.yaml @@ -9,8 +9,8 @@ experiment: in_channels: 14 num_classes: 19 datamodule: - num_classes: 19 - batch_size: 128 - num_workers: 6 + root_dir: "tests/data/bigearthnet" bands: "all" - num_classes: 19 + num_classes: ${experiment.module.num_classes} + batch_size: 128 + num_workers: 0 diff --git a/conf/task_defaults/byol.yaml b/conf/task_defaults/byol.yaml index 2a4aa84bd..95066aa28 100644 --- a/conf/task_defaults/byol.yaml +++ b/conf/task_defaults/byol.yaml @@ -9,11 +9,12 @@ experiment: learning_rate: 1e-3 learning_rate_schedule_patience: 6 datamodule: - batch_size: 64 - num_workers: 6 + root_dir: "tests/data/chesapeake/cvpr" train_splits: - - "de-train" + - "de-test" val_splits: - - "de-val" + - "de-test" test_splits: - "de-test" + batch_size: 64 + num_workers: 0 diff --git a/conf/task_defaults/chesapeake_cvpr.yaml b/conf/task_defaults/chesapeake_cvpr.yaml index 1087c96f9..6cb0b7a90 100644 --- a/conf/task_defaults/chesapeake_cvpr.yaml +++ b/conf/task_defaults/chesapeake_cvpr.yaml @@ -12,9 +12,15 @@ experiment: num_classes: 7 num_filters: 256 datamodule: - train_state: "de" + root_dir: "tests/data/chesapeake/cvpr" + train_splits: + - "de-test" + val_splits: + - "de-test" + test_splits: + - "de-test" patches_per_tile: 200 patch_size: 256 batch_size: 64 - num_workers: 4 - num_classes: ${experiment.module.num_classes} + num_workers: 0 + class_set: ${experiment.module.num_classes} diff --git a/conf/task_defaults/cowc_counting.yaml b/conf/task_defaults/cowc_counting.yaml index b99d20728..a0d43a234 100644 --- a/conf/task_defaults/cowc_counting.yaml +++ b/conf/task_defaults/cowc_counting.yaml @@ -5,5 +5,6 @@ experiment: learning_rate: 1e-3 learning_rate_schedule_patience: 2 datamodule: + root_dir: "tests/data/cowc_counting" batch_size: 32 - num_workers: 4 + num_workers: 0 diff --git a/conf/task_defaults/cyclone.yaml b/conf/task_defaults/cyclone.yaml index 1b50c8f25..c42cdb4c2 100644 --- a/conf/task_defaults/cyclone.yaml +++ b/conf/task_defaults/cyclone.yaml @@ -5,5 +5,6 @@ experiment: learning_rate: 1e-3 learning_rate_schedule_patience: 2 datamodule: + root_dir: "tests/data/cyclone" batch_size: 32 - num_workers: 4 + num_workers: 0 diff --git a/conf/task_defaults/landcoverai.yaml b/conf/task_defaults/landcoverai.yaml index d3efd7bc6..d471df596 100644 --- a/conf/task_defaults/landcoverai.yaml +++ b/conf/task_defaults/landcoverai.yaml @@ -12,5 +12,6 @@ experiment: num_classes: 6 num_filters: 256 datamodule: + root_dir: "tests/data/landcoverai" batch_size: 32 - num_workers: 4 + num_workers: 0 diff --git a/conf/task_defaults/naipchesapeake.yaml b/conf/task_defaults/naipchesapeake.yaml index f793fd56d..1720a9bef 100644 --- a/conf/task_defaults/naipchesapeake.yaml +++ b/conf/task_defaults/naipchesapeake.yaml @@ -12,5 +12,8 @@ experiment: num_classes: 13 num_filters: 64 datamodule: + naip_root_dir: "tests/data/naip" + chesapeake_root_dir: "tests/data/chesapeake/BAYWIDE" batch_size: 32 - num_workers: 4 + num_workers: 0 + patch_size: 32 diff --git a/conf/task_defaults/resisc45.yaml b/conf/task_defaults/resisc45.yaml index 1f2089f78..a657cc5d0 100644 --- a/conf/task_defaults/resisc45.yaml +++ b/conf/task_defaults/resisc45.yaml @@ -9,9 +9,6 @@ experiment: in_channels: 3 num_classes: 45 datamodule: + root_dir: "tests/data/resisc45" batch_size: 128 - num_workers: 6 - weights: ${experiment.module.weights} - unsupervised_mode: false - val_split_pct: 0.2 - test_split_pct: 0.2 + num_workers: 0 diff --git a/conf/task_defaults/sen12ms.yaml b/conf/task_defaults/sen12ms.yaml index bd4ceb029..c0a36ac0d 100644 --- a/conf/task_defaults/sen12ms.yaml +++ b/conf/task_defaults/sen12ms.yaml @@ -11,5 +11,6 @@ experiment: in_channels: 15 num_classes: 11 datamodule: + root_dir: "tests/data/sen12ms" batch_size: 32 - num_workers: 4 + num_workers: 0 diff --git a/conf/task_defaults/so2sat.yaml b/conf/task_defaults/so2sat.yaml index ef72dc9de..be6a4e0f5 100644 --- a/conf/task_defaults/so2sat.yaml +++ b/conf/task_defaults/so2sat.yaml @@ -9,6 +9,7 @@ experiment: in_channels: 3 num_classes: 17 datamodule: + root_dir: "tests/data/so2sat" batch_size: 128 - num_workers: 6 + num_workers: 0 bands: "rgb" diff --git a/conf/task_defaults/ucmerced.yaml b/conf/task_defaults/ucmerced.yaml index 993e36759..2742fa329 100644 --- a/conf/task_defaults/ucmerced.yaml +++ b/conf/task_defaults/ucmerced.yaml @@ -3,14 +3,12 @@ experiment: module: loss: "ce" classification_model: "resnet18" - weights: null + weights: "random" learning_rate: 1e-3 learning_rate_schedule_patience: 6 in_channels: 3 num_classes: 21 datamodule: + root_dir: "tests/data/ucmerced" batch_size: 128 - num_workers: 6 - unsupervised_mode: false - val_split_pct: 0.1 - test_split_pct: 0.1 + num_workers: 0 diff --git a/docs/tutorials/benchmarking.ipynb b/docs/tutorials/benchmarking.ipynb index d3bf3aa32..20befc7c1 100644 --- a/docs/tutorials/benchmarking.ipynb +++ b/docs/tutorials/benchmarking.ipynb @@ -37,7 +37,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install git+https://github.com/microsoft/torchgeo.git" + "%pip install torchgeo" ] }, { diff --git a/docs/tutorials/getting_started.ipynb b/docs/tutorials/getting_started.ipynb index 347a1f495..fc25832c4 100644 --- a/docs/tutorials/getting_started.ipynb +++ b/docs/tutorials/getting_started.ipynb @@ -43,7 +43,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install git+https://github.com/microsoft/torchgeo.git" + "%pip install torchgeo" ] }, { diff --git a/docs/tutorials/indices.ipynb b/docs/tutorials/indices.ipynb index 96d2a0c1b..fd710d394 100644 --- a/docs/tutorials/indices.ipynb +++ b/docs/tutorials/indices.ipynb @@ -23,7 +23,6 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", "version": "3.8.11" } }, @@ -86,7 +85,7 @@ "id": "wOwsb8KT_uXR" }, "source": [ - "%pip install git+https://github.com/microsoft/torchgeo.git" + "%pip install torchgeo" ], "execution_count": null, "outputs": [] @@ -117,7 +116,6 @@ "import rasterio\n", "import rasterio.features\n", "import shapely\n", - "import tifffile\n", "import torch\n", "import torch.nn as nn\n", "import torchvision.transforms as T\n", @@ -194,9 +192,9 @@ " compress=\"DEFLATE\"\n", " ))\n", "\n", - " # Write to geotiff\n", - " with rasterio.open(path, \"w\", **metadata) as ds_windowed:\n", - " ds_windowed.write(ds.read(1, window=window), 1)\n", + " # Write to geotiff\n", + " with rasterio.open(path, \"w\", **metadata) as ds_windowed:\n", + " ds_windowed.write(ds.read(1, window=window), 1)\n", "\n", "def download(root: str, url: str, bands: List[str], geometry: shapely.geometry.Polygon) -> None:\n", " \"\"\"Extract windows from each band COG file in s3 and save locally.\"\"\"\n", diff --git a/docs/tutorials/trainers.ipynb b/docs/tutorials/trainers.ipynb index 790c7495e..4359f0241 100644 --- a/docs/tutorials/trainers.ipynb +++ b/docs/tutorials/trainers.ipynb @@ -43,7 +43,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install git+https://github.com/microsoft/torchgeo.git" + "%pip install torchgeo" ] }, { @@ -82,6 +82,18 @@ "from torchgeo.trainers import RegressionTask" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "e8bc0d83", + "metadata": {}, + "outputs": [], + "source": [ + "# we set a flag to check to see whether the notebook is currently being run by PyTest, if this is the case then we'll\n", + "# skip the expensive training.\n", + "in_tests = (\"PYTEST_CURRENT_TEST\" in os.environ)" + ] + }, { "cell_type": "markdown", "id": "e6e1d9b6", @@ -226,12 +238,12 @@ ], "source": [ "trainer = pl.Trainer(\n", - " gpus=1,\n", " callbacks=[checkpoint_callback, early_stopping_callback],\n", " logger=[csv_logger],\n", " default_root_dir=experiment_dir,\n", " min_epochs=1,\n", - " max_epochs=10\n", + " max_epochs=10,\n", + " fast_dev_run=in_tests\n", ")" ] }, @@ -469,25 +481,26 @@ "metadata": {}, "outputs": [], "source": [ - "train_steps = []\n", - "train_rmse = []\n", + "if not in_tests:\n", + " train_steps = []\n", + " train_rmse = []\n", "\n", - "val_steps = []\n", - "val_rmse = []\n", - "with open(os.path.join(experiment_dir, \"tutorial_logs\", \"version_0\", \"metrics.csv\"), \"r\") as f:\n", - " csv_reader = csv.DictReader(f, delimiter=',')\n", - " for i, row in enumerate(csv_reader):\n", - " try:\n", - " train_rmse.append(float(row[\"train_rmse\"]))\n", - " train_steps.append(i)\n", - " except ValueError: # Ignore rows where train RMSE is empty\n", - " pass\n", - " \n", - " try:\n", - " val_rmse.append(float(row[\"val_rmse\"]))\n", - " val_steps.append(i)\n", - " except ValueError: # Ignore rows where val RMSE is empty\n", - " pass" + " val_steps = []\n", + " val_rmse = []\n", + " with open(os.path.join(experiment_dir, \"tutorial_logs\", \"version_0\", \"metrics.csv\"), \"r\") as f:\n", + " csv_reader = csv.DictReader(f, delimiter=',')\n", + " for i, row in enumerate(csv_reader):\n", + " try:\n", + " train_rmse.append(float(row[\"train_rmse\"]))\n", + " train_steps.append(i)\n", + " except ValueError: # Ignore rows where train RMSE is empty\n", + " pass\n", + "\n", + " try:\n", + " val_rmse.append(float(row[\"val_rmse\"]))\n", + " val_steps.append(i)\n", + " except ValueError: # Ignore rows where val RMSE is empty\n", + " pass" ] }, { @@ -510,14 +523,15 @@ } ], "source": [ - "plt.figure()\n", - "plt.plot(train_steps, train_rmse, label=\"Train RMSE\")\n", - "plt.plot(val_steps, val_rmse, label=\"Validation RMSE\")\n", - "plt.legend(fontsize=15)\n", - "plt.xlabel(\"Batches\", fontsize=15)\n", - "plt.ylabel(\"RMSE\", fontsize=15)\n", - "plt.show()\n", - "plt.close()" + "if not in_tests:\n", + " plt.figure()\n", + " plt.plot(train_steps, train_rmse, label=\"Train RMSE\")\n", + " plt.plot(val_steps, val_rmse, label=\"Validation RMSE\")\n", + " plt.legend(fontsize=15)\n", + " plt.xlabel(\"Batches\", fontsize=15)\n", + " plt.ylabel(\"RMSE\", fontsize=15)\n", + " plt.show()\n", + " plt.close()" ] }, { diff --git a/docs/tutorials/transforms.ipynb b/docs/tutorials/transforms.ipynb index e7685a3d8..7e347c02d 100644 --- a/docs/tutorials/transforms.ipynb +++ b/docs/tutorials/transforms.ipynb @@ -23,7 +23,6 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", "version": "3.8.11" }, "accelerator": "GPU" @@ -84,7 +83,7 @@ "id": "wOwsb8KT_uXR" }, "source": [ - "%pip install git+https://github.com/microsoft/torchgeo.git" + "%pip install torchgeo" ], "execution_count": null, "outputs": [] @@ -538,7 +537,7 @@ "outputId": "fa0443da-8b4d-47f7-e713-93e1e4976e87" }, "source": [ - "!nvidia-smi" + "%nvidia-smi" ], "execution_count": 8, "outputs": [ @@ -724,4 +723,4 @@ "outputs": [] } ] -} \ No newline at end of file +} diff --git a/docs/user/installation.rst b/docs/user/installation.rst index 5c89a1a28..6b23ba7d7 100644 --- a/docs/user/installation.rst +++ b/docs/user/installation.rst @@ -1,17 +1,16 @@ Installation ============ -TorchGeo is simple and easy to install. We support installation using the `pip `_, `conda `_, and `spack `_ package managers, although you can also install from source if you want to. +TorchGeo is simple and easy to install. We support installation using the `pip `_, `conda `_, and `spack `_ package managers. pip --- -.. - Since TorchGeo is written in pure-Python, the easiest way to install it is using pip: +Since TorchGeo is written in pure-Python, the easiest way to install it is using pip: - .. code-block:: console +.. code-block:: console - $ pip install torchgeo + $ pip install torchgeo If you want to install a development version, you can use a VCS project URL: @@ -32,16 +31,10 @@ or a local git checkout: By default, only required dependencies are installed. TorchGeo has a number of optional dependencies for specific datasets or development. These can be installed with a comma-separated list: -.. - .. code-block:: console - - $ pip install torchgeo[datasets] - $ pip install torchgeo[style,tests] - .. code-block:: console - $ pip install .[datasets] - $ pip install .[style,tests] + $ pip install torchgeo[datasets] + $ pip install torchgeo[style,tests] See the ``setup.cfg`` for a complete list of options. See the `pip documentation `_ for more details. @@ -57,12 +50,11 @@ If you need to install non-Python dependencies like PyTorch, it's better to use $ conda config --set channel_priority strict -.. - Now, you can install the latest stable release using: +Now, you can install the latest stable release using: - .. code-block:: console +.. code-block:: console - $ conda install torchgeo + $ conda install torchgeo Conda does not directly support installing development versions, but you can use conda to install our dependencies, then use pip to install TorchGeo itself. @@ -110,15 +102,3 @@ Optional dependencies can be installed by enabling build variants: Run ``spack info py-torchgeo`` for a complete list of variants. See the `spack documentation `_ for more details. - -source ------- - -TorchGeo can also be installed from source using the ``setup.py`` file and setuptools. - -.. code-block:: console - - $ git clone https://github.com/microsoft/torchgeo.git - $ cd torchgeo - $ python setup.py build - $ python setup.py install diff --git a/environment.yml b/environment.yml index 263a748a2..aef403329 100644 --- a/environment.yml +++ b/environment.yml @@ -39,7 +39,7 @@ dependencies: - scikit-learn>=0.18 - scipy>=0.9 - segmentation-models-pytorch>=0.2 - - setuptools>=30.4 + - setuptools>=42 - sphinx>=3 - timm>=0.2.1 - torchmetrics diff --git a/pyproject.toml b/pyproject.toml index 1dd8029e0..f98f5c8ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [build-system] requires = [ - "setuptools>=30.4", + "setuptools>=42", "wheel", ] build-backend = "setuptools.build_meta" diff --git a/setup.cfg b/setup.cfg index 715e19667..1636a342d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,14 +2,15 @@ [metadata] name = torchgeo version = attr: torchgeo.__version__ -author = attr: torchgeo.__author__ +author = Adam J. Stewart author_email = ajstewart426@gmail.com description = TorchGeo: datasets, transforms, and models for geospatial data long_description = file: README.md long_description_content_type = text/markdown url = https://github.com/microsoft/torchgeo +license_files = LICENSE classifiers = - Development Status :: 1 - Planning + Development Status :: 3 - Alpha Intended Audience :: Science/Research Programming Language :: Python :: 3 Programming Language :: Python :: 3.6 @@ -20,12 +21,12 @@ classifiers = Operating System :: OS Independent Topic :: Scientific/Engineering :: Artificial Intelligence Topic :: Scientific/Engineering :: GIS -keywords = pytorch, deep learning, machine learning +keywords = pytorch, deep learning, machine learning, remote sensing, satellite imagery, geospatial [options] setup_requires = - # setuptools 30.4+ required for options.packages.find section in setup.cfg - setuptools>=30.4 + # setuptools 42+ required for metadata.license_files support in setup.cfg + setuptools>=42 install_requires = einops # fiona 1.5+ required for fiona.transform module diff --git a/tests/test_setup.py b/tests/test_setup.py deleted file mode 100644 index ed104721f..000000000 --- a/tests/test_setup.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import subprocess -import sys -from pathlib import Path - -import pytest - -pytestmark = pytest.mark.slow - - -def test_install(tmp_path: Path) -> None: - subprocess.run( - [sys.executable, "setup.py", "build", "--build-base", str(tmp_path)], check=True - ) diff --git a/tests/test_train.py b/tests/test_train.py index ee7b2f6bb..786dd6ad9 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -134,20 +134,34 @@ trainer: subprocess.run(args, check=True) -@pytest.mark.parametrize("task", ["cowc_counting", "cyclone", "sen12ms", "landcoverai"]) +@pytest.mark.parametrize( + "task", + [ + "bigearthnet", + "byol", + "chesapeake_cvpr", + "cowc_counting", + "cyclone", + "landcoverai", + "naipchesapeake", + "resisc45", + "sen12ms", + "so2sat", + "ucmerced", + ], +) def test_tasks(task: str, tmp_path: Path) -> None: output_dir = tmp_path / "output" - data_dir = os.path.join("tests", "data", task) log_dir = tmp_path / "logs" args = [ sys.executable, "train.py", "experiment.name=test", "program.output_dir=" + str(output_dir), - "program.data_dir=" + data_dir, "program.log_dir=" + str(log_dir), "trainer.fast_dev_run=1", "experiment.task=" + task, "program.overwrite=True", + "config_file=" + os.path.join("conf", "task_defaults", task + ".yaml"), ] subprocess.run(args, check=True) diff --git a/tests/trainers/test_landcoverai.py b/tests/trainers/test_landcoverai.py index ed8d4448a..b14d5b466 100644 --- a/tests/trainers/test_landcoverai.py +++ b/tests/trainers/test_landcoverai.py @@ -58,3 +58,10 @@ class TestLandCoverAISegmentationTask: batch = next(iter(datamodule.val_dataloader())) task.validation_step(batch, 0) task.validation_epoch_end(0) + + def test_test( + self, datamodule: LandCoverAIDataModule, task: LandCoverAISegmentationTask + ) -> None: + batch = next(iter(datamodule.test_dataloader())) + task.test_step(batch, 0) + task.test_epoch_end(0) diff --git a/torchgeo/__init__.py b/torchgeo/__init__.py index dd5c053b5..5da98111a 100644 --- a/torchgeo/__init__.py +++ b/torchgeo/__init__.py @@ -11,4 +11,4 @@ common image transformations for geospatial data. """ __author__ = "Adam J. Stewart" -__version__ = "0.1.0.dev0" +__version__ = "0.1.0" diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py index fc09ee71c..1f629ad8e 100644 --- a/torchgeo/datasets/bigearthnet.py +++ b/torchgeo/datasets/bigearthnet.py @@ -582,7 +582,7 @@ class BigEarthNetDataModule(pl.LightningDataModule): bands: str = "all", num_classes: int = 19, batch_size: int = 64, - num_workers: int = 4, + num_workers: int = 0, **kwargs: Any, ) -> None: """Initialize a LightningDataModule for BigEarthNet based DataLoaders. diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index bd96a739f..79f28bd4a 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -553,7 +553,7 @@ class ChesapeakeCVPRDataModule(LightningDataModule): patches_per_tile: int = 200, patch_size: int = 256, batch_size: int = 64, - num_workers: int = 4, + num_workers: int = 0, class_set: int = 7, **kwargs: Any, ) -> None: @@ -711,7 +711,7 @@ class ChesapeakeCVPRDataModule(LightningDataModule): splits=self.train_splits, layers=self.layers, transforms=None, - download=True, + download=False, checksum=False, ) diff --git a/torchgeo/datasets/cowc.py b/torchgeo/datasets/cowc.py index de993861b..35bbdc54b 100644 --- a/torchgeo/datasets/cowc.py +++ b/torchgeo/datasets/cowc.py @@ -278,7 +278,7 @@ class COWCCountingDataModule(pl.LightningDataModule): root_dir: str, seed: int, batch_size: int = 64, - num_workers: int = 4, + num_workers: int = 0, **kwargs: Any, ) -> None: """Initialize a LightningDataModule for COWC Counting based DataLoaders. @@ -314,7 +314,7 @@ class COWCCountingDataModule(pl.LightningDataModule): This includes optionally downloading the dataset. This is done once per node, while :func:`setup` is done once per GPU. """ - COWCCounting(self.root_dir, download=True) + COWCCounting(self.root_dir, download=False) def setup(self, stage: Optional[str] = None) -> None: """Create the train/val/test splits based on the original Dataset objects. diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index 357b314b3..487677e83 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -227,7 +227,7 @@ class CycloneDataModule(pl.LightningDataModule): root_dir: str, seed: int, batch_size: int = 64, - num_workers: int = 4, + num_workers: int = 0, api_key: Optional[str] = None, **kwargs: Any, ) -> None: diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py index 2e9b58e77..771e87f5c 100644 --- a/torchgeo/datasets/landcoverai.py +++ b/torchgeo/datasets/landcoverai.py @@ -215,7 +215,7 @@ class LandCoverAIDataModule(pl.LightningDataModule): """ def __init__( - self, root_dir: str, batch_size: int = 64, num_workers: int = 4, **kwargs: Any + self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: """Initialize a LightningDataModule for LandCover.ai based DataLoaders. @@ -250,7 +250,7 @@ class LandCoverAIDataModule(pl.LightningDataModule): This method is only called once per run. """ - _ = LandCoverAI(self.root_dir, download=True, checksum=False) + _ = LandCoverAI(self.root_dir, download=False, checksum=False) def setup(self, stage: Optional[str] = None) -> None: """Initialize the main ``Dataset`` objects. diff --git a/torchgeo/datasets/naip.py b/torchgeo/datasets/naip.py index 44bdac951..ef299988a 100644 --- a/torchgeo/datasets/naip.py +++ b/torchgeo/datasets/naip.py @@ -64,7 +64,6 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule): """ # TODO: tune these hyperparams - patch_size = 256 length = 1000 stride = 128 @@ -73,7 +72,8 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule): naip_root_dir: str, chesapeake_root_dir: str, batch_size: int = 64, - num_workers: int = 4, + num_workers: int = 0, + patch_size: int = 256, **kwargs: Any, ) -> None: """Initialize a LightningDataModule for NAIP and Chesapeake based DataLoaders. @@ -83,12 +83,14 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule): chesapeake_root_dir: directory containing Chesapeake data batch_size: The batch size to use in all created DataLoaders num_workers: The number of workers to use in all created DataLoaders + patch_size: size of patches to sample """ super().__init__() # type: ignore[no-untyped-call] self.naip_root_dir = naip_root_dir self.chesapeake_root_dir = chesapeake_root_dir self.batch_size = batch_size self.num_workers = num_workers + self.patch_size = patch_size def naip_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: """Transform a single sample from the NAIP Dataset. @@ -120,7 +122,7 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule): This method is only called once per run. """ - Chesapeake13(self.chesapeake_root_dir, download=True, checksum=False) + Chesapeake13(self.chesapeake_root_dir, download=False, checksum=False) def setup(self, stage: Optional[str] = None) -> None: """Initialize the main ``Dataset`` objects. diff --git a/torchgeo/datasets/resisc45.py b/torchgeo/datasets/resisc45.py index 5b6cc1cce..5a68715a8 100644 --- a/torchgeo/datasets/resisc45.py +++ b/torchgeo/datasets/resisc45.py @@ -219,7 +219,7 @@ class RESISC45DataModule(pl.LightningDataModule): self, root_dir: str, batch_size: int = 64, - num_workers: int = 4, + num_workers: int = 0, **kwargs: Any, ) -> None: """Initialize a LightningDataModule for RESISC45 based DataLoaders. diff --git a/torchgeo/datasets/sen12ms.py b/torchgeo/datasets/sen12ms.py index 5dfdeed48..f1cf8b2ad 100644 --- a/torchgeo/datasets/sen12ms.py +++ b/torchgeo/datasets/sen12ms.py @@ -289,7 +289,7 @@ class SEN12MSDataModule(pl.LightningDataModule): seed: int, band_set: str = "all", batch_size: int = 64, - num_workers: int = 4, + num_workers: int = 0, **kwargs: Any, ) -> None: """Initialize a LightningDataModule for SEN12MS based DataLoaders. diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py index 46b907c20..d0e03cef2 100644 --- a/torchgeo/datasets/so2sat.py +++ b/torchgeo/datasets/so2sat.py @@ -243,7 +243,7 @@ class So2SatDataModule(pl.LightningDataModule): self, root_dir: str, batch_size: int = 64, - num_workers: int = 4, + num_workers: int = 0, bands: str = "rgb", unsupervised_mode: bool = False, **kwargs: Any, diff --git a/torchgeo/datasets/ucmerced.py b/torchgeo/datasets/ucmerced.py index 853e3df20..5a07c24b8 100644 --- a/torchgeo/datasets/ucmerced.py +++ b/torchgeo/datasets/ucmerced.py @@ -225,7 +225,7 @@ class UCMercedDataModule(pl.LightningDataModule): self, root_dir: str, batch_size: int = 64, - num_workers: int = 4, + num_workers: int = 0, **kwargs: Any, ) -> None: """Initialize a LightningDataModule for UCMerced based DataLoaders. @@ -266,7 +266,7 @@ class UCMercedDataModule(pl.LightningDataModule): This method is only called once per run. """ - UCMerced(self.root_dir, download=True, checksum=False) + UCMerced(self.root_dir, download=False, checksum=False) def setup(self, stage: Optional[str] = None) -> None: """Initialize the main ``Dataset`` objects. diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py index 23690cd29..fb3565a78 100644 --- a/torchgeo/trainers/byol.py +++ b/torchgeo/trainers/byol.py @@ -451,3 +451,6 @@ class BYOLTask(LightningModule): ) self.log("val_loss", loss, on_step=False, on_epoch=True) + + def test_step(self, *args: Any) -> None: # type: ignore[override] + """No-op, does nothing.""" diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index c0b447b74..f5494cabb 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -34,7 +34,7 @@ class ClassificationTask(pl.LightningModule): imagenet_pretrained = False custom_pretrained = False - if not os.path.exists(self.hparams["weights"]): + if self.hparams["weights"] and not os.path.exists(self.hparams["weights"]): if self.hparams["weights"] == "imagenet": imagenet_pretrained = True elif self.hparams["weights"] == "random": diff --git a/torchgeo/trainers/landcoverai.py b/torchgeo/trainers/landcoverai.py index 85d4fd813..c95f03135 100644 --- a/torchgeo/trainers/landcoverai.py +++ b/torchgeo/trainers/landcoverai.py @@ -109,3 +109,23 @@ class LandCoverAISegmentationTask(SemanticSegmentationTask): ) plt.close() + + def test_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> None: + """Test step identical to the validation step. + + Args: + batch: Current batch + batch_idx: Index of current batch + """ + x = batch["image"] + y = batch["mask"].long().squeeze() + y_hat = self.forward(x) + y_hat_hard = y_hat.argmax(dim=1) + + loss = self.loss(y_hat, y) + + # by default, the test and validation steps only log per *epoch* + self.log("test_loss", loss, on_step=False, on_epoch=True) + self.test_metrics(y_hat_hard, y) diff --git a/torchgeo/trainers/so2sat.py b/torchgeo/trainers/so2sat.py index 3bda62acb..6de070851 100644 --- a/torchgeo/trainers/so2sat.py +++ b/torchgeo/trainers/so2sat.py @@ -27,7 +27,7 @@ class So2SatClassificationTask(ClassificationTask): in_channels = self.hparams["in_channels"] pretrained = False - if not os.path.exists(self.hparams["weights"]): + if self.hparams["weights"] and not os.path.exists(self.hparams["weights"]): if self.hparams["weights"] == "imagenet": pretrained = True elif self.hparams["weights"] == "random":