зеркало из https://github.com/microsoft/torchgeo.git
0.1.0 release (#226)
* 0.1.0 release * Train deps needed for release testing * Update development status * setup.py should not be run directly * Test more trainers * Fix local docs build * Update installation instructions * Specify test data dir in config * Fix tutorial docs * Trainers should default to num_workers=0, download=False * Correct location for root_dir * Try different GDAL name * Try again * Various fixes to release tests * Update pip installs in tutorials * Fix some bugs * Config file not being picked up * Get back to 100% test coverage * Added correct weight string to UCMerced * yolo fix * yolo fix pt 2 * yolo fix 2 pt. 1 * Simplify tests a bit * Make the trainer notebook look stupid * UCMerced should download by default in the trainers * Revert * Fix logo/author, include LICENSE in upload Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
This commit is contained in:
Родитель
e5bbc738a3
Коммит
740d4f87a3
|
@ -17,9 +17,19 @@ jobs:
|
|||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.9
|
||||
- name: Install apt dependencies
|
||||
run: |
|
||||
sudo apt-add-repository ppa:ubuntugis/ubuntugis-unstable
|
||||
sudo apt-get update
|
||||
sudo apt-get install gdal-bin libgdal-dev
|
||||
- name: Install pip dependencies
|
||||
run: pip install .[tests]
|
||||
run: |
|
||||
pip install gdal tqdm # TODO: these deps shouldn't be needed
|
||||
pip install .[datasets,tests,train]
|
||||
pip install -r docs/requirements.txt
|
||||
- name: Run notebook checks
|
||||
env:
|
||||
MLHUB_API_KEY: ${{ secrets.MLHUB_API_KEY }}
|
||||
run: pytest --nbmake docs/tutorials
|
||||
integration:
|
||||
name: integration
|
||||
|
@ -32,6 +42,6 @@ jobs:
|
|||
with:
|
||||
python-version: 3.9
|
||||
- name: Install pip dependencies
|
||||
run: pip install .[tests]
|
||||
run: pip install .[datasets,tests,train]
|
||||
- name: Run integration checks
|
||||
run: pytest -m slow
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
<img src="logo/logo-color.svg" width="400" alt="TorchGeo"/>
|
||||
<img src="https://raw.githubusercontent.com/microsoft/torchgeo/main/logo/logo-color.svg" width="400" alt="TorchGeo"/>
|
||||
|
||||
TorchGeo is a [PyTorch](https://pytorch.org/) domain library, similar to [torchvision](https://pytorch.org/vision), that provides datasets, transforms, samplers, and pre-trained models specific to geospatial data.
|
||||
|
||||
|
@ -23,7 +23,7 @@ Tests:
|
|||
The recommended way to install TorchGeo is with [pip](https://pip.pypa.io/):
|
||||
|
||||
```console
|
||||
$ pip install git+https://github.com/microsoft/torchgeo.git
|
||||
$ pip install torchgeo
|
||||
```
|
||||
|
||||
For [conda](https://docs.conda.io/) and [spack](https://spack.io/) installation instructions, see the [documentation](https://torchgeo.readthedocs.io/en/latest/user/installation.html).
|
||||
|
|
|
@ -9,8 +9,8 @@ experiment:
|
|||
in_channels: 14
|
||||
num_classes: 19
|
||||
datamodule:
|
||||
num_classes: 19
|
||||
batch_size: 128
|
||||
num_workers: 6
|
||||
root_dir: "tests/data/bigearthnet"
|
||||
bands: "all"
|
||||
num_classes: 19
|
||||
num_classes: ${experiment.module.num_classes}
|
||||
batch_size: 128
|
||||
num_workers: 0
|
||||
|
|
|
@ -9,11 +9,12 @@ experiment:
|
|||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
datamodule:
|
||||
batch_size: 64
|
||||
num_workers: 6
|
||||
root_dir: "tests/data/chesapeake/cvpr"
|
||||
train_splits:
|
||||
- "de-train"
|
||||
- "de-test"
|
||||
val_splits:
|
||||
- "de-val"
|
||||
- "de-test"
|
||||
test_splits:
|
||||
- "de-test"
|
||||
batch_size: 64
|
||||
num_workers: 0
|
||||
|
|
|
@ -12,9 +12,15 @@ experiment:
|
|||
num_classes: 7
|
||||
num_filters: 256
|
||||
datamodule:
|
||||
train_state: "de"
|
||||
root_dir: "tests/data/chesapeake/cvpr"
|
||||
train_splits:
|
||||
- "de-test"
|
||||
val_splits:
|
||||
- "de-test"
|
||||
test_splits:
|
||||
- "de-test"
|
||||
patches_per_tile: 200
|
||||
patch_size: 256
|
||||
batch_size: 64
|
||||
num_workers: 4
|
||||
num_classes: ${experiment.module.num_classes}
|
||||
num_workers: 0
|
||||
class_set: ${experiment.module.num_classes}
|
||||
|
|
|
@ -5,5 +5,6 @@ experiment:
|
|||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 2
|
||||
datamodule:
|
||||
root_dir: "tests/data/cowc_counting"
|
||||
batch_size: 32
|
||||
num_workers: 4
|
||||
num_workers: 0
|
||||
|
|
|
@ -5,5 +5,6 @@ experiment:
|
|||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 2
|
||||
datamodule:
|
||||
root_dir: "tests/data/cyclone"
|
||||
batch_size: 32
|
||||
num_workers: 4
|
||||
num_workers: 0
|
||||
|
|
|
@ -12,5 +12,6 @@ experiment:
|
|||
num_classes: 6
|
||||
num_filters: 256
|
||||
datamodule:
|
||||
root_dir: "tests/data/landcoverai"
|
||||
batch_size: 32
|
||||
num_workers: 4
|
||||
num_workers: 0
|
||||
|
|
|
@ -12,5 +12,8 @@ experiment:
|
|||
num_classes: 13
|
||||
num_filters: 64
|
||||
datamodule:
|
||||
naip_root_dir: "tests/data/naip"
|
||||
chesapeake_root_dir: "tests/data/chesapeake/BAYWIDE"
|
||||
batch_size: 32
|
||||
num_workers: 4
|
||||
num_workers: 0
|
||||
patch_size: 32
|
||||
|
|
|
@ -9,9 +9,6 @@ experiment:
|
|||
in_channels: 3
|
||||
num_classes: 45
|
||||
datamodule:
|
||||
root_dir: "tests/data/resisc45"
|
||||
batch_size: 128
|
||||
num_workers: 6
|
||||
weights: ${experiment.module.weights}
|
||||
unsupervised_mode: false
|
||||
val_split_pct: 0.2
|
||||
test_split_pct: 0.2
|
||||
num_workers: 0
|
||||
|
|
|
@ -11,5 +11,6 @@ experiment:
|
|||
in_channels: 15
|
||||
num_classes: 11
|
||||
datamodule:
|
||||
root_dir: "tests/data/sen12ms"
|
||||
batch_size: 32
|
||||
num_workers: 4
|
||||
num_workers: 0
|
||||
|
|
|
@ -9,6 +9,7 @@ experiment:
|
|||
in_channels: 3
|
||||
num_classes: 17
|
||||
datamodule:
|
||||
root_dir: "tests/data/so2sat"
|
||||
batch_size: 128
|
||||
num_workers: 6
|
||||
num_workers: 0
|
||||
bands: "rgb"
|
||||
|
|
|
@ -3,14 +3,12 @@ experiment:
|
|||
module:
|
||||
loss: "ce"
|
||||
classification_model: "resnet18"
|
||||
weights: null
|
||||
weights: "random"
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
in_channels: 3
|
||||
num_classes: 21
|
||||
datamodule:
|
||||
root_dir: "tests/data/ucmerced"
|
||||
batch_size: 128
|
||||
num_workers: 6
|
||||
unsupervised_mode: false
|
||||
val_split_pct: 0.1
|
||||
test_split_pct: 0.1
|
||||
num_workers: 0
|
||||
|
|
|
@ -37,7 +37,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install git+https://github.com/microsoft/torchgeo.git"
|
||||
"%pip install torchgeo"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -43,7 +43,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install git+https://github.com/microsoft/torchgeo.git"
|
||||
"%pip install torchgeo"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -23,7 +23,6 @@
|
|||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.11"
|
||||
}
|
||||
},
|
||||
|
@ -86,7 +85,7 @@
|
|||
"id": "wOwsb8KT_uXR"
|
||||
},
|
||||
"source": [
|
||||
"%pip install git+https://github.com/microsoft/torchgeo.git"
|
||||
"%pip install torchgeo"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
|
@ -117,7 +116,6 @@
|
|||
"import rasterio\n",
|
||||
"import rasterio.features\n",
|
||||
"import shapely\n",
|
||||
"import tifffile\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torchvision.transforms as T\n",
|
||||
|
@ -194,9 +192,9 @@
|
|||
" compress=\"DEFLATE\"\n",
|
||||
" ))\n",
|
||||
"\n",
|
||||
" # Write to geotiff\n",
|
||||
" with rasterio.open(path, \"w\", **metadata) as ds_windowed:\n",
|
||||
" ds_windowed.write(ds.read(1, window=window), 1)\n",
|
||||
" # Write to geotiff\n",
|
||||
" with rasterio.open(path, \"w\", **metadata) as ds_windowed:\n",
|
||||
" ds_windowed.write(ds.read(1, window=window), 1)\n",
|
||||
"\n",
|
||||
"def download(root: str, url: str, bands: List[str], geometry: shapely.geometry.Polygon) -> None:\n",
|
||||
" \"\"\"Extract windows from each band COG file in s3 and save locally.\"\"\"\n",
|
||||
|
|
|
@ -43,7 +43,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install git+https://github.com/microsoft/torchgeo.git"
|
||||
"%pip install torchgeo"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -82,6 +82,18 @@
|
|||
"from torchgeo.trainers import RegressionTask"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e8bc0d83",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# we set a flag to check to see whether the notebook is currently being run by PyTest, if this is the case then we'll\n",
|
||||
"# skip the expensive training.\n",
|
||||
"in_tests = (\"PYTEST_CURRENT_TEST\" in os.environ)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e6e1d9b6",
|
||||
|
@ -226,12 +238,12 @@
|
|||
],
|
||||
"source": [
|
||||
"trainer = pl.Trainer(\n",
|
||||
" gpus=1,\n",
|
||||
" callbacks=[checkpoint_callback, early_stopping_callback],\n",
|
||||
" logger=[csv_logger],\n",
|
||||
" default_root_dir=experiment_dir,\n",
|
||||
" min_epochs=1,\n",
|
||||
" max_epochs=10\n",
|
||||
" max_epochs=10,\n",
|
||||
" fast_dev_run=in_tests\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
|
@ -469,25 +481,26 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_steps = []\n",
|
||||
"train_rmse = []\n",
|
||||
"if not in_tests:\n",
|
||||
" train_steps = []\n",
|
||||
" train_rmse = []\n",
|
||||
"\n",
|
||||
"val_steps = []\n",
|
||||
"val_rmse = []\n",
|
||||
"with open(os.path.join(experiment_dir, \"tutorial_logs\", \"version_0\", \"metrics.csv\"), \"r\") as f:\n",
|
||||
" csv_reader = csv.DictReader(f, delimiter=',')\n",
|
||||
" for i, row in enumerate(csv_reader):\n",
|
||||
" try:\n",
|
||||
" train_rmse.append(float(row[\"train_rmse\"]))\n",
|
||||
" train_steps.append(i)\n",
|
||||
" except ValueError: # Ignore rows where train RMSE is empty\n",
|
||||
" pass\n",
|
||||
" \n",
|
||||
" try:\n",
|
||||
" val_rmse.append(float(row[\"val_rmse\"]))\n",
|
||||
" val_steps.append(i)\n",
|
||||
" except ValueError: # Ignore rows where val RMSE is empty\n",
|
||||
" pass"
|
||||
" val_steps = []\n",
|
||||
" val_rmse = []\n",
|
||||
" with open(os.path.join(experiment_dir, \"tutorial_logs\", \"version_0\", \"metrics.csv\"), \"r\") as f:\n",
|
||||
" csv_reader = csv.DictReader(f, delimiter=',')\n",
|
||||
" for i, row in enumerate(csv_reader):\n",
|
||||
" try:\n",
|
||||
" train_rmse.append(float(row[\"train_rmse\"]))\n",
|
||||
" train_steps.append(i)\n",
|
||||
" except ValueError: # Ignore rows where train RMSE is empty\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" val_rmse.append(float(row[\"val_rmse\"]))\n",
|
||||
" val_steps.append(i)\n",
|
||||
" except ValueError: # Ignore rows where val RMSE is empty\n",
|
||||
" pass"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -510,14 +523,15 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"plt.figure()\n",
|
||||
"plt.plot(train_steps, train_rmse, label=\"Train RMSE\")\n",
|
||||
"plt.plot(val_steps, val_rmse, label=\"Validation RMSE\")\n",
|
||||
"plt.legend(fontsize=15)\n",
|
||||
"plt.xlabel(\"Batches\", fontsize=15)\n",
|
||||
"plt.ylabel(\"RMSE\", fontsize=15)\n",
|
||||
"plt.show()\n",
|
||||
"plt.close()"
|
||||
"if not in_tests:\n",
|
||||
" plt.figure()\n",
|
||||
" plt.plot(train_steps, train_rmse, label=\"Train RMSE\")\n",
|
||||
" plt.plot(val_steps, val_rmse, label=\"Validation RMSE\")\n",
|
||||
" plt.legend(fontsize=15)\n",
|
||||
" plt.xlabel(\"Batches\", fontsize=15)\n",
|
||||
" plt.ylabel(\"RMSE\", fontsize=15)\n",
|
||||
" plt.show()\n",
|
||||
" plt.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -23,7 +23,6 @@
|
|||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.11"
|
||||
},
|
||||
"accelerator": "GPU"
|
||||
|
@ -84,7 +83,7 @@
|
|||
"id": "wOwsb8KT_uXR"
|
||||
},
|
||||
"source": [
|
||||
"%pip install git+https://github.com/microsoft/torchgeo.git"
|
||||
"%pip install torchgeo"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
|
@ -538,7 +537,7 @@
|
|||
"outputId": "fa0443da-8b4d-47f7-e713-93e1e4976e87"
|
||||
},
|
||||
"source": [
|
||||
"!nvidia-smi"
|
||||
"%nvidia-smi"
|
||||
],
|
||||
"execution_count": 8,
|
||||
"outputs": [
|
||||
|
@ -724,4 +723,4 @@
|
|||
"outputs": []
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,17 +1,16 @@
|
|||
Installation
|
||||
============
|
||||
|
||||
TorchGeo is simple and easy to install. We support installation using the `pip <https://pip.pypa.io/>`_, `conda <https://docs.conda.io/>`_, and `spack <https://spack.io/>`_ package managers, although you can also install from source if you want to.
|
||||
TorchGeo is simple and easy to install. We support installation using the `pip <https://pip.pypa.io/>`_, `conda <https://docs.conda.io/>`_, and `spack <https://spack.io/>`_ package managers.
|
||||
|
||||
pip
|
||||
---
|
||||
|
||||
..
|
||||
Since TorchGeo is written in pure-Python, the easiest way to install it is using pip:
|
||||
Since TorchGeo is written in pure-Python, the easiest way to install it is using pip:
|
||||
|
||||
.. code-block:: console
|
||||
.. code-block:: console
|
||||
|
||||
$ pip install torchgeo
|
||||
$ pip install torchgeo
|
||||
|
||||
|
||||
If you want to install a development version, you can use a VCS project URL:
|
||||
|
@ -32,16 +31,10 @@ or a local git checkout:
|
|||
|
||||
By default, only required dependencies are installed. TorchGeo has a number of optional dependencies for specific datasets or development. These can be installed with a comma-separated list:
|
||||
|
||||
..
|
||||
.. code-block:: console
|
||||
|
||||
$ pip install torchgeo[datasets]
|
||||
$ pip install torchgeo[style,tests]
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ pip install .[datasets]
|
||||
$ pip install .[style,tests]
|
||||
$ pip install torchgeo[datasets]
|
||||
$ pip install torchgeo[style,tests]
|
||||
|
||||
|
||||
See the ``setup.cfg`` for a complete list of options. See the `pip documentation <https://pip.pypa.io/>`_ for more details.
|
||||
|
@ -57,12 +50,11 @@ If you need to install non-Python dependencies like PyTorch, it's better to use
|
|||
$ conda config --set channel_priority strict
|
||||
|
||||
|
||||
..
|
||||
Now, you can install the latest stable release using:
|
||||
Now, you can install the latest stable release using:
|
||||
|
||||
.. code-block:: console
|
||||
.. code-block:: console
|
||||
|
||||
$ conda install torchgeo
|
||||
$ conda install torchgeo
|
||||
|
||||
|
||||
Conda does not directly support installing development versions, but you can use conda to install our dependencies, then use pip to install TorchGeo itself.
|
||||
|
@ -110,15 +102,3 @@ Optional dependencies can be installed by enabling build variants:
|
|||
Run ``spack info py-torchgeo`` for a complete list of variants.
|
||||
|
||||
See the `spack documentation <https://spack.readthedocs.io/>`_ for more details.
|
||||
|
||||
source
|
||||
------
|
||||
|
||||
TorchGeo can also be installed from source using the ``setup.py`` file and setuptools.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ git clone https://github.com/microsoft/torchgeo.git
|
||||
$ cd torchgeo
|
||||
$ python setup.py build
|
||||
$ python setup.py install
|
||||
|
|
|
@ -39,7 +39,7 @@ dependencies:
|
|||
- scikit-learn>=0.18
|
||||
- scipy>=0.9
|
||||
- segmentation-models-pytorch>=0.2
|
||||
- setuptools>=30.4
|
||||
- setuptools>=42
|
||||
- sphinx>=3
|
||||
- timm>=0.2.1
|
||||
- torchmetrics
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[build-system]
|
||||
requires = [
|
||||
"setuptools>=30.4",
|
||||
"setuptools>=42",
|
||||
"wheel",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
|
11
setup.cfg
11
setup.cfg
|
@ -2,14 +2,15 @@
|
|||
[metadata]
|
||||
name = torchgeo
|
||||
version = attr: torchgeo.__version__
|
||||
author = attr: torchgeo.__author__
|
||||
author = Adam J. Stewart
|
||||
author_email = ajstewart426@gmail.com
|
||||
description = TorchGeo: datasets, transforms, and models for geospatial data
|
||||
long_description = file: README.md
|
||||
long_description_content_type = text/markdown
|
||||
url = https://github.com/microsoft/torchgeo
|
||||
license_files = LICENSE
|
||||
classifiers =
|
||||
Development Status :: 1 - Planning
|
||||
Development Status :: 3 - Alpha
|
||||
Intended Audience :: Science/Research
|
||||
Programming Language :: Python :: 3
|
||||
Programming Language :: Python :: 3.6
|
||||
|
@ -20,12 +21,12 @@ classifiers =
|
|||
Operating System :: OS Independent
|
||||
Topic :: Scientific/Engineering :: Artificial Intelligence
|
||||
Topic :: Scientific/Engineering :: GIS
|
||||
keywords = pytorch, deep learning, machine learning
|
||||
keywords = pytorch, deep learning, machine learning, remote sensing, satellite imagery, geospatial
|
||||
|
||||
[options]
|
||||
setup_requires =
|
||||
# setuptools 30.4+ required for options.packages.find section in setup.cfg
|
||||
setuptools>=30.4
|
||||
# setuptools 42+ required for metadata.license_files support in setup.cfg
|
||||
setuptools>=42
|
||||
install_requires =
|
||||
einops
|
||||
# fiona 1.5+ required for fiona.transform module
|
||||
|
|
|
@ -1,16 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.slow
|
||||
|
||||
|
||||
def test_install(tmp_path: Path) -> None:
|
||||
subprocess.run(
|
||||
[sys.executable, "setup.py", "build", "--build-base", str(tmp_path)], check=True
|
||||
)
|
|
@ -134,20 +134,34 @@ trainer:
|
|||
subprocess.run(args, check=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("task", ["cowc_counting", "cyclone", "sen12ms", "landcoverai"])
|
||||
@pytest.mark.parametrize(
|
||||
"task",
|
||||
[
|
||||
"bigearthnet",
|
||||
"byol",
|
||||
"chesapeake_cvpr",
|
||||
"cowc_counting",
|
||||
"cyclone",
|
||||
"landcoverai",
|
||||
"naipchesapeake",
|
||||
"resisc45",
|
||||
"sen12ms",
|
||||
"so2sat",
|
||||
"ucmerced",
|
||||
],
|
||||
)
|
||||
def test_tasks(task: str, tmp_path: Path) -> None:
|
||||
output_dir = tmp_path / "output"
|
||||
data_dir = os.path.join("tests", "data", task)
|
||||
log_dir = tmp_path / "logs"
|
||||
args = [
|
||||
sys.executable,
|
||||
"train.py",
|
||||
"experiment.name=test",
|
||||
"program.output_dir=" + str(output_dir),
|
||||
"program.data_dir=" + data_dir,
|
||||
"program.log_dir=" + str(log_dir),
|
||||
"trainer.fast_dev_run=1",
|
||||
"experiment.task=" + task,
|
||||
"program.overwrite=True",
|
||||
"config_file=" + os.path.join("conf", "task_defaults", task + ".yaml"),
|
||||
]
|
||||
subprocess.run(args, check=True)
|
||||
|
|
|
@ -58,3 +58,10 @@ class TestLandCoverAISegmentationTask:
|
|||
batch = next(iter(datamodule.val_dataloader()))
|
||||
task.validation_step(batch, 0)
|
||||
task.validation_epoch_end(0)
|
||||
|
||||
def test_test(
|
||||
self, datamodule: LandCoverAIDataModule, task: LandCoverAISegmentationTask
|
||||
) -> None:
|
||||
batch = next(iter(datamodule.test_dataloader()))
|
||||
task.test_step(batch, 0)
|
||||
task.test_epoch_end(0)
|
||||
|
|
|
@ -11,4 +11,4 @@ common image transformations for geospatial data.
|
|||
"""
|
||||
|
||||
__author__ = "Adam J. Stewart"
|
||||
__version__ = "0.1.0.dev0"
|
||||
__version__ = "0.1.0"
|
||||
|
|
|
@ -582,7 +582,7 @@ class BigEarthNetDataModule(pl.LightningDataModule):
|
|||
bands: str = "all",
|
||||
num_classes: int = 19,
|
||||
batch_size: int = 64,
|
||||
num_workers: int = 4,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a LightningDataModule for BigEarthNet based DataLoaders.
|
||||
|
|
|
@ -553,7 +553,7 @@ class ChesapeakeCVPRDataModule(LightningDataModule):
|
|||
patches_per_tile: int = 200,
|
||||
patch_size: int = 256,
|
||||
batch_size: int = 64,
|
||||
num_workers: int = 4,
|
||||
num_workers: int = 0,
|
||||
class_set: int = 7,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
@ -711,7 +711,7 @@ class ChesapeakeCVPRDataModule(LightningDataModule):
|
|||
splits=self.train_splits,
|
||||
layers=self.layers,
|
||||
transforms=None,
|
||||
download=True,
|
||||
download=False,
|
||||
checksum=False,
|
||||
)
|
||||
|
||||
|
|
|
@ -278,7 +278,7 @@ class COWCCountingDataModule(pl.LightningDataModule):
|
|||
root_dir: str,
|
||||
seed: int,
|
||||
batch_size: int = 64,
|
||||
num_workers: int = 4,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a LightningDataModule for COWC Counting based DataLoaders.
|
||||
|
@ -314,7 +314,7 @@ class COWCCountingDataModule(pl.LightningDataModule):
|
|||
This includes optionally downloading the dataset. This is done once per node,
|
||||
while :func:`setup` is done once per GPU.
|
||||
"""
|
||||
COWCCounting(self.root_dir, download=True)
|
||||
COWCCounting(self.root_dir, download=False)
|
||||
|
||||
def setup(self, stage: Optional[str] = None) -> None:
|
||||
"""Create the train/val/test splits based on the original Dataset objects.
|
||||
|
|
|
@ -227,7 +227,7 @@ class CycloneDataModule(pl.LightningDataModule):
|
|||
root_dir: str,
|
||||
seed: int,
|
||||
batch_size: int = 64,
|
||||
num_workers: int = 4,
|
||||
num_workers: int = 0,
|
||||
api_key: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
|
|
@ -215,7 +215,7 @@ class LandCoverAIDataModule(pl.LightningDataModule):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, root_dir: str, batch_size: int = 64, num_workers: int = 4, **kwargs: Any
|
||||
self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
|
||||
) -> None:
|
||||
"""Initialize a LightningDataModule for LandCover.ai based DataLoaders.
|
||||
|
||||
|
@ -250,7 +250,7 @@ class LandCoverAIDataModule(pl.LightningDataModule):
|
|||
|
||||
This method is only called once per run.
|
||||
"""
|
||||
_ = LandCoverAI(self.root_dir, download=True, checksum=False)
|
||||
_ = LandCoverAI(self.root_dir, download=False, checksum=False)
|
||||
|
||||
def setup(self, stage: Optional[str] = None) -> None:
|
||||
"""Initialize the main ``Dataset`` objects.
|
||||
|
|
|
@ -64,7 +64,6 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
|
|||
"""
|
||||
|
||||
# TODO: tune these hyperparams
|
||||
patch_size = 256
|
||||
length = 1000
|
||||
stride = 128
|
||||
|
||||
|
@ -73,7 +72,8 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
|
|||
naip_root_dir: str,
|
||||
chesapeake_root_dir: str,
|
||||
batch_size: int = 64,
|
||||
num_workers: int = 4,
|
||||
num_workers: int = 0,
|
||||
patch_size: int = 256,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a LightningDataModule for NAIP and Chesapeake based DataLoaders.
|
||||
|
@ -83,12 +83,14 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
|
|||
chesapeake_root_dir: directory containing Chesapeake data
|
||||
batch_size: The batch size to use in all created DataLoaders
|
||||
num_workers: The number of workers to use in all created DataLoaders
|
||||
patch_size: size of patches to sample
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
self.naip_root_dir = naip_root_dir
|
||||
self.chesapeake_root_dir = chesapeake_root_dir
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
self.patch_size = patch_size
|
||||
|
||||
def naip_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Transform a single sample from the NAIP Dataset.
|
||||
|
@ -120,7 +122,7 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
|
|||
|
||||
This method is only called once per run.
|
||||
"""
|
||||
Chesapeake13(self.chesapeake_root_dir, download=True, checksum=False)
|
||||
Chesapeake13(self.chesapeake_root_dir, download=False, checksum=False)
|
||||
|
||||
def setup(self, stage: Optional[str] = None) -> None:
|
||||
"""Initialize the main ``Dataset`` objects.
|
||||
|
|
|
@ -219,7 +219,7 @@ class RESISC45DataModule(pl.LightningDataModule):
|
|||
self,
|
||||
root_dir: str,
|
||||
batch_size: int = 64,
|
||||
num_workers: int = 4,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a LightningDataModule for RESISC45 based DataLoaders.
|
||||
|
|
|
@ -289,7 +289,7 @@ class SEN12MSDataModule(pl.LightningDataModule):
|
|||
seed: int,
|
||||
band_set: str = "all",
|
||||
batch_size: int = 64,
|
||||
num_workers: int = 4,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a LightningDataModule for SEN12MS based DataLoaders.
|
||||
|
|
|
@ -243,7 +243,7 @@ class So2SatDataModule(pl.LightningDataModule):
|
|||
self,
|
||||
root_dir: str,
|
||||
batch_size: int = 64,
|
||||
num_workers: int = 4,
|
||||
num_workers: int = 0,
|
||||
bands: str = "rgb",
|
||||
unsupervised_mode: bool = False,
|
||||
**kwargs: Any,
|
||||
|
|
|
@ -225,7 +225,7 @@ class UCMercedDataModule(pl.LightningDataModule):
|
|||
self,
|
||||
root_dir: str,
|
||||
batch_size: int = 64,
|
||||
num_workers: int = 4,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a LightningDataModule for UCMerced based DataLoaders.
|
||||
|
@ -266,7 +266,7 @@ class UCMercedDataModule(pl.LightningDataModule):
|
|||
|
||||
This method is only called once per run.
|
||||
"""
|
||||
UCMerced(self.root_dir, download=True, checksum=False)
|
||||
UCMerced(self.root_dir, download=False, checksum=False)
|
||||
|
||||
def setup(self, stage: Optional[str] = None) -> None:
|
||||
"""Initialize the main ``Dataset`` objects.
|
||||
|
|
|
@ -451,3 +451,6 @@ class BYOLTask(LightningModule):
|
|||
)
|
||||
|
||||
self.log("val_loss", loss, on_step=False, on_epoch=True)
|
||||
|
||||
def test_step(self, *args: Any) -> None: # type: ignore[override]
|
||||
"""No-op, does nothing."""
|
||||
|
|
|
@ -34,7 +34,7 @@ class ClassificationTask(pl.LightningModule):
|
|||
|
||||
imagenet_pretrained = False
|
||||
custom_pretrained = False
|
||||
if not os.path.exists(self.hparams["weights"]):
|
||||
if self.hparams["weights"] and not os.path.exists(self.hparams["weights"]):
|
||||
if self.hparams["weights"] == "imagenet":
|
||||
imagenet_pretrained = True
|
||||
elif self.hparams["weights"] == "random":
|
||||
|
|
|
@ -109,3 +109,23 @@ class LandCoverAISegmentationTask(SemanticSegmentationTask):
|
|||
)
|
||||
|
||||
plt.close()
|
||||
|
||||
def test_step( # type: ignore[override]
|
||||
self, batch: Dict[str, Any], batch_idx: int
|
||||
) -> None:
|
||||
"""Test step identical to the validation step.
|
||||
|
||||
Args:
|
||||
batch: Current batch
|
||||
batch_idx: Index of current batch
|
||||
"""
|
||||
x = batch["image"]
|
||||
y = batch["mask"].long().squeeze()
|
||||
y_hat = self.forward(x)
|
||||
y_hat_hard = y_hat.argmax(dim=1)
|
||||
|
||||
loss = self.loss(y_hat, y)
|
||||
|
||||
# by default, the test and validation steps only log per *epoch*
|
||||
self.log("test_loss", loss, on_step=False, on_epoch=True)
|
||||
self.test_metrics(y_hat_hard, y)
|
||||
|
|
|
@ -27,7 +27,7 @@ class So2SatClassificationTask(ClassificationTask):
|
|||
in_channels = self.hparams["in_channels"]
|
||||
|
||||
pretrained = False
|
||||
if not os.path.exists(self.hparams["weights"]):
|
||||
if self.hparams["weights"] and not os.path.exists(self.hparams["weights"]):
|
||||
if self.hparams["weights"] == "imagenet":
|
||||
pretrained = True
|
||||
elif self.hparams["weights"] == "random":
|
||||
|
|
Загрузка…
Ссылка в новой задаче