* 0.1.0 release

* Train deps needed for release testing

* Update development status

* setup.py should not be run directly

* Test more trainers

* Fix local docs build

* Update installation instructions

* Specify test data dir in config

* Fix tutorial docs

* Trainers should default to num_workers=0, download=False

* Correct location for root_dir

* Try different GDAL name

* Try again

* Various fixes to release tests

* Update pip installs in tutorials

* Fix some bugs

* Config file not being picked up

* Get back to 100% test coverage

* Added correct weight string to UCMerced

* yolo fix

* yolo fix pt 2

* yolo fix 2 pt. 1

* Simplify tests a bit

* Make the trainer notebook look stupid

* UCMerced should download by default in the trainers

* Revert

* Fix logo/author, include LICENSE in upload

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
This commit is contained in:
Adam J. Stewart 2021-11-07 22:05:58 -06:00 коммит произвёл GitHub
Родитель e5bbc738a3
Коммит 740d4f87a3
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
40 изменённых файлов: 188 добавлений и 146 удалений

14
.github/workflows/release.yaml поставляемый
Просмотреть файл

@ -17,9 +17,19 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: 3.9
- name: Install apt dependencies
run: |
sudo apt-add-repository ppa:ubuntugis/ubuntugis-unstable
sudo apt-get update
sudo apt-get install gdal-bin libgdal-dev
- name: Install pip dependencies
run: pip install .[tests]
run: |
pip install gdal tqdm # TODO: these deps shouldn't be needed
pip install .[datasets,tests,train]
pip install -r docs/requirements.txt
- name: Run notebook checks
env:
MLHUB_API_KEY: ${{ secrets.MLHUB_API_KEY }}
run: pytest --nbmake docs/tutorials
integration:
name: integration
@ -32,6 +42,6 @@ jobs:
with:
python-version: 3.9
- name: Install pip dependencies
run: pip install .[tests]
run: pip install .[datasets,tests,train]
- name: Run integration checks
run: pytest -m slow

Просмотреть файл

@ -1,4 +1,4 @@
<img src="logo/logo-color.svg" width="400" alt="TorchGeo"/>
<img src="https://raw.githubusercontent.com/microsoft/torchgeo/main/logo/logo-color.svg" width="400" alt="TorchGeo"/>
TorchGeo is a [PyTorch](https://pytorch.org/) domain library, similar to [torchvision](https://pytorch.org/vision), that provides datasets, transforms, samplers, and pre-trained models specific to geospatial data.
@ -23,7 +23,7 @@ Tests:
The recommended way to install TorchGeo is with [pip](https://pip.pypa.io/):
```console
$ pip install git+https://github.com/microsoft/torchgeo.git
$ pip install torchgeo
```
For [conda](https://docs.conda.io/) and [spack](https://spack.io/) installation instructions, see the [documentation](https://torchgeo.readthedocs.io/en/latest/user/installation.html).

Просмотреть файл

@ -9,8 +9,8 @@ experiment:
in_channels: 14
num_classes: 19
datamodule:
num_classes: 19
batch_size: 128
num_workers: 6
root_dir: "tests/data/bigearthnet"
bands: "all"
num_classes: 19
num_classes: ${experiment.module.num_classes}
batch_size: 128
num_workers: 0

Просмотреть файл

@ -9,11 +9,12 @@ experiment:
learning_rate: 1e-3
learning_rate_schedule_patience: 6
datamodule:
batch_size: 64
num_workers: 6
root_dir: "tests/data/chesapeake/cvpr"
train_splits:
- "de-train"
- "de-test"
val_splits:
- "de-val"
- "de-test"
test_splits:
- "de-test"
batch_size: 64
num_workers: 0

Просмотреть файл

@ -12,9 +12,15 @@ experiment:
num_classes: 7
num_filters: 256
datamodule:
train_state: "de"
root_dir: "tests/data/chesapeake/cvpr"
train_splits:
- "de-test"
val_splits:
- "de-test"
test_splits:
- "de-test"
patches_per_tile: 200
patch_size: 256
batch_size: 64
num_workers: 4
num_classes: ${experiment.module.num_classes}
num_workers: 0
class_set: ${experiment.module.num_classes}

Просмотреть файл

@ -5,5 +5,6 @@ experiment:
learning_rate: 1e-3
learning_rate_schedule_patience: 2
datamodule:
root_dir: "tests/data/cowc_counting"
batch_size: 32
num_workers: 4
num_workers: 0

Просмотреть файл

@ -5,5 +5,6 @@ experiment:
learning_rate: 1e-3
learning_rate_schedule_patience: 2
datamodule:
root_dir: "tests/data/cyclone"
batch_size: 32
num_workers: 4
num_workers: 0

Просмотреть файл

@ -12,5 +12,6 @@ experiment:
num_classes: 6
num_filters: 256
datamodule:
root_dir: "tests/data/landcoverai"
batch_size: 32
num_workers: 4
num_workers: 0

Просмотреть файл

@ -12,5 +12,8 @@ experiment:
num_classes: 13
num_filters: 64
datamodule:
naip_root_dir: "tests/data/naip"
chesapeake_root_dir: "tests/data/chesapeake/BAYWIDE"
batch_size: 32
num_workers: 4
num_workers: 0
patch_size: 32

Просмотреть файл

@ -9,9 +9,6 @@ experiment:
in_channels: 3
num_classes: 45
datamodule:
root_dir: "tests/data/resisc45"
batch_size: 128
num_workers: 6
weights: ${experiment.module.weights}
unsupervised_mode: false
val_split_pct: 0.2
test_split_pct: 0.2
num_workers: 0

Просмотреть файл

@ -11,5 +11,6 @@ experiment:
in_channels: 15
num_classes: 11
datamodule:
root_dir: "tests/data/sen12ms"
batch_size: 32
num_workers: 4
num_workers: 0

Просмотреть файл

@ -9,6 +9,7 @@ experiment:
in_channels: 3
num_classes: 17
datamodule:
root_dir: "tests/data/so2sat"
batch_size: 128
num_workers: 6
num_workers: 0
bands: "rgb"

Просмотреть файл

@ -3,14 +3,12 @@ experiment:
module:
loss: "ce"
classification_model: "resnet18"
weights: null
weights: "random"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
num_classes: 21
datamodule:
root_dir: "tests/data/ucmerced"
batch_size: 128
num_workers: 6
unsupervised_mode: false
val_split_pct: 0.1
test_split_pct: 0.1
num_workers: 0

Просмотреть файл

@ -37,7 +37,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install git+https://github.com/microsoft/torchgeo.git"
"%pip install torchgeo"
]
},
{

Просмотреть файл

@ -43,7 +43,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install git+https://github.com/microsoft/torchgeo.git"
"%pip install torchgeo"
]
},
{

Просмотреть файл

@ -23,7 +23,6 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.11"
}
},
@ -86,7 +85,7 @@
"id": "wOwsb8KT_uXR"
},
"source": [
"%pip install git+https://github.com/microsoft/torchgeo.git"
"%pip install torchgeo"
],
"execution_count": null,
"outputs": []
@ -117,7 +116,6 @@
"import rasterio\n",
"import rasterio.features\n",
"import shapely\n",
"import tifffile\n",
"import torch\n",
"import torch.nn as nn\n",
"import torchvision.transforms as T\n",
@ -194,9 +192,9 @@
" compress=\"DEFLATE\"\n",
" ))\n",
"\n",
" # Write to geotiff\n",
" with rasterio.open(path, \"w\", **metadata) as ds_windowed:\n",
" ds_windowed.write(ds.read(1, window=window), 1)\n",
" # Write to geotiff\n",
" with rasterio.open(path, \"w\", **metadata) as ds_windowed:\n",
" ds_windowed.write(ds.read(1, window=window), 1)\n",
"\n",
"def download(root: str, url: str, bands: List[str], geometry: shapely.geometry.Polygon) -> None:\n",
" \"\"\"Extract windows from each band COG file in s3 and save locally.\"\"\"\n",

Просмотреть файл

@ -43,7 +43,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install git+https://github.com/microsoft/torchgeo.git"
"%pip install torchgeo"
]
},
{
@ -82,6 +82,18 @@
"from torchgeo.trainers import RegressionTask"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8bc0d83",
"metadata": {},
"outputs": [],
"source": [
"# we set a flag to check to see whether the notebook is currently being run by PyTest, if this is the case then we'll\n",
"# skip the expensive training.\n",
"in_tests = (\"PYTEST_CURRENT_TEST\" in os.environ)"
]
},
{
"cell_type": "markdown",
"id": "e6e1d9b6",
@ -226,12 +238,12 @@
],
"source": [
"trainer = pl.Trainer(\n",
" gpus=1,\n",
" callbacks=[checkpoint_callback, early_stopping_callback],\n",
" logger=[csv_logger],\n",
" default_root_dir=experiment_dir,\n",
" min_epochs=1,\n",
" max_epochs=10\n",
" max_epochs=10,\n",
" fast_dev_run=in_tests\n",
")"
]
},
@ -469,25 +481,26 @@
"metadata": {},
"outputs": [],
"source": [
"train_steps = []\n",
"train_rmse = []\n",
"if not in_tests:\n",
" train_steps = []\n",
" train_rmse = []\n",
"\n",
"val_steps = []\n",
"val_rmse = []\n",
"with open(os.path.join(experiment_dir, \"tutorial_logs\", \"version_0\", \"metrics.csv\"), \"r\") as f:\n",
" csv_reader = csv.DictReader(f, delimiter=',')\n",
" for i, row in enumerate(csv_reader):\n",
" try:\n",
" train_rmse.append(float(row[\"train_rmse\"]))\n",
" train_steps.append(i)\n",
" except ValueError: # Ignore rows where train RMSE is empty\n",
" pass\n",
" \n",
" try:\n",
" val_rmse.append(float(row[\"val_rmse\"]))\n",
" val_steps.append(i)\n",
" except ValueError: # Ignore rows where val RMSE is empty\n",
" pass"
" val_steps = []\n",
" val_rmse = []\n",
" with open(os.path.join(experiment_dir, \"tutorial_logs\", \"version_0\", \"metrics.csv\"), \"r\") as f:\n",
" csv_reader = csv.DictReader(f, delimiter=',')\n",
" for i, row in enumerate(csv_reader):\n",
" try:\n",
" train_rmse.append(float(row[\"train_rmse\"]))\n",
" train_steps.append(i)\n",
" except ValueError: # Ignore rows where train RMSE is empty\n",
" pass\n",
"\n",
" try:\n",
" val_rmse.append(float(row[\"val_rmse\"]))\n",
" val_steps.append(i)\n",
" except ValueError: # Ignore rows where val RMSE is empty\n",
" pass"
]
},
{
@ -510,14 +523,15 @@
}
],
"source": [
"plt.figure()\n",
"plt.plot(train_steps, train_rmse, label=\"Train RMSE\")\n",
"plt.plot(val_steps, val_rmse, label=\"Validation RMSE\")\n",
"plt.legend(fontsize=15)\n",
"plt.xlabel(\"Batches\", fontsize=15)\n",
"plt.ylabel(\"RMSE\", fontsize=15)\n",
"plt.show()\n",
"plt.close()"
"if not in_tests:\n",
" plt.figure()\n",
" plt.plot(train_steps, train_rmse, label=\"Train RMSE\")\n",
" plt.plot(val_steps, val_rmse, label=\"Validation RMSE\")\n",
" plt.legend(fontsize=15)\n",
" plt.xlabel(\"Batches\", fontsize=15)\n",
" plt.ylabel(\"RMSE\", fontsize=15)\n",
" plt.show()\n",
" plt.close()"
]
},
{

Просмотреть файл

@ -23,7 +23,6 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.11"
},
"accelerator": "GPU"
@ -84,7 +83,7 @@
"id": "wOwsb8KT_uXR"
},
"source": [
"%pip install git+https://github.com/microsoft/torchgeo.git"
"%pip install torchgeo"
],
"execution_count": null,
"outputs": []
@ -538,7 +537,7 @@
"outputId": "fa0443da-8b4d-47f7-e713-93e1e4976e87"
},
"source": [
"!nvidia-smi"
"%nvidia-smi"
],
"execution_count": 8,
"outputs": [
@ -724,4 +723,4 @@
"outputs": []
}
]
}
}

Просмотреть файл

@ -1,17 +1,16 @@
Installation
============
TorchGeo is simple and easy to install. We support installation using the `pip <https://pip.pypa.io/>`_, `conda <https://docs.conda.io/>`_, and `spack <https://spack.io/>`_ package managers, although you can also install from source if you want to.
TorchGeo is simple and easy to install. We support installation using the `pip <https://pip.pypa.io/>`_, `conda <https://docs.conda.io/>`_, and `spack <https://spack.io/>`_ package managers.
pip
---
..
Since TorchGeo is written in pure-Python, the easiest way to install it is using pip:
Since TorchGeo is written in pure-Python, the easiest way to install it is using pip:
.. code-block:: console
.. code-block:: console
$ pip install torchgeo
$ pip install torchgeo
If you want to install a development version, you can use a VCS project URL:
@ -32,16 +31,10 @@ or a local git checkout:
By default, only required dependencies are installed. TorchGeo has a number of optional dependencies for specific datasets or development. These can be installed with a comma-separated list:
..
.. code-block:: console
$ pip install torchgeo[datasets]
$ pip install torchgeo[style,tests]
.. code-block:: console
$ pip install .[datasets]
$ pip install .[style,tests]
$ pip install torchgeo[datasets]
$ pip install torchgeo[style,tests]
See the ``setup.cfg`` for a complete list of options. See the `pip documentation <https://pip.pypa.io/>`_ for more details.
@ -57,12 +50,11 @@ If you need to install non-Python dependencies like PyTorch, it's better to use
$ conda config --set channel_priority strict
..
Now, you can install the latest stable release using:
Now, you can install the latest stable release using:
.. code-block:: console
.. code-block:: console
$ conda install torchgeo
$ conda install torchgeo
Conda does not directly support installing development versions, but you can use conda to install our dependencies, then use pip to install TorchGeo itself.
@ -110,15 +102,3 @@ Optional dependencies can be installed by enabling build variants:
Run ``spack info py-torchgeo`` for a complete list of variants.
See the `spack documentation <https://spack.readthedocs.io/>`_ for more details.
source
------
TorchGeo can also be installed from source using the ``setup.py`` file and setuptools.
.. code-block:: console
$ git clone https://github.com/microsoft/torchgeo.git
$ cd torchgeo
$ python setup.py build
$ python setup.py install

Просмотреть файл

@ -39,7 +39,7 @@ dependencies:
- scikit-learn>=0.18
- scipy>=0.9
- segmentation-models-pytorch>=0.2
- setuptools>=30.4
- setuptools>=42
- sphinx>=3
- timm>=0.2.1
- torchmetrics

Просмотреть файл

@ -1,6 +1,6 @@
[build-system]
requires = [
"setuptools>=30.4",
"setuptools>=42",
"wheel",
]
build-backend = "setuptools.build_meta"

Просмотреть файл

@ -2,14 +2,15 @@
[metadata]
name = torchgeo
version = attr: torchgeo.__version__
author = attr: torchgeo.__author__
author = Adam J. Stewart
author_email = ajstewart426@gmail.com
description = TorchGeo: datasets, transforms, and models for geospatial data
long_description = file: README.md
long_description_content_type = text/markdown
url = https://github.com/microsoft/torchgeo
license_files = LICENSE
classifiers =
Development Status :: 1 - Planning
Development Status :: 3 - Alpha
Intended Audience :: Science/Research
Programming Language :: Python :: 3
Programming Language :: Python :: 3.6
@ -20,12 +21,12 @@ classifiers =
Operating System :: OS Independent
Topic :: Scientific/Engineering :: Artificial Intelligence
Topic :: Scientific/Engineering :: GIS
keywords = pytorch, deep learning, machine learning
keywords = pytorch, deep learning, machine learning, remote sensing, satellite imagery, geospatial
[options]
setup_requires =
# setuptools 30.4+ required for options.packages.find section in setup.cfg
setuptools>=30.4
# setuptools 42+ required for metadata.license_files support in setup.cfg
setuptools>=42
install_requires =
einops
# fiona 1.5+ required for fiona.transform module

Просмотреть файл

@ -1,16 +0,0 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import subprocess
import sys
from pathlib import Path
import pytest
pytestmark = pytest.mark.slow
def test_install(tmp_path: Path) -> None:
subprocess.run(
[sys.executable, "setup.py", "build", "--build-base", str(tmp_path)], check=True
)

Просмотреть файл

@ -134,20 +134,34 @@ trainer:
subprocess.run(args, check=True)
@pytest.mark.parametrize("task", ["cowc_counting", "cyclone", "sen12ms", "landcoverai"])
@pytest.mark.parametrize(
"task",
[
"bigearthnet",
"byol",
"chesapeake_cvpr",
"cowc_counting",
"cyclone",
"landcoverai",
"naipchesapeake",
"resisc45",
"sen12ms",
"so2sat",
"ucmerced",
],
)
def test_tasks(task: str, tmp_path: Path) -> None:
output_dir = tmp_path / "output"
data_dir = os.path.join("tests", "data", task)
log_dir = tmp_path / "logs"
args = [
sys.executable,
"train.py",
"experiment.name=test",
"program.output_dir=" + str(output_dir),
"program.data_dir=" + data_dir,
"program.log_dir=" + str(log_dir),
"trainer.fast_dev_run=1",
"experiment.task=" + task,
"program.overwrite=True",
"config_file=" + os.path.join("conf", "task_defaults", task + ".yaml"),
]
subprocess.run(args, check=True)

Просмотреть файл

@ -58,3 +58,10 @@ class TestLandCoverAISegmentationTask:
batch = next(iter(datamodule.val_dataloader()))
task.validation_step(batch, 0)
task.validation_epoch_end(0)
def test_test(
self, datamodule: LandCoverAIDataModule, task: LandCoverAISegmentationTask
) -> None:
batch = next(iter(datamodule.test_dataloader()))
task.test_step(batch, 0)
task.test_epoch_end(0)

Просмотреть файл

@ -11,4 +11,4 @@ common image transformations for geospatial data.
"""
__author__ = "Adam J. Stewart"
__version__ = "0.1.0.dev0"
__version__ = "0.1.0"

Просмотреть файл

@ -582,7 +582,7 @@ class BigEarthNetDataModule(pl.LightningDataModule):
bands: str = "all",
num_classes: int = 19,
batch_size: int = 64,
num_workers: int = 4,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for BigEarthNet based DataLoaders.

Просмотреть файл

@ -553,7 +553,7 @@ class ChesapeakeCVPRDataModule(LightningDataModule):
patches_per_tile: int = 200,
patch_size: int = 256,
batch_size: int = 64,
num_workers: int = 4,
num_workers: int = 0,
class_set: int = 7,
**kwargs: Any,
) -> None:
@ -711,7 +711,7 @@ class ChesapeakeCVPRDataModule(LightningDataModule):
splits=self.train_splits,
layers=self.layers,
transforms=None,
download=True,
download=False,
checksum=False,
)

Просмотреть файл

@ -278,7 +278,7 @@ class COWCCountingDataModule(pl.LightningDataModule):
root_dir: str,
seed: int,
batch_size: int = 64,
num_workers: int = 4,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for COWC Counting based DataLoaders.
@ -314,7 +314,7 @@ class COWCCountingDataModule(pl.LightningDataModule):
This includes optionally downloading the dataset. This is done once per node,
while :func:`setup` is done once per GPU.
"""
COWCCounting(self.root_dir, download=True)
COWCCounting(self.root_dir, download=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Create the train/val/test splits based on the original Dataset objects.

Просмотреть файл

@ -227,7 +227,7 @@ class CycloneDataModule(pl.LightningDataModule):
root_dir: str,
seed: int,
batch_size: int = 64,
num_workers: int = 4,
num_workers: int = 0,
api_key: Optional[str] = None,
**kwargs: Any,
) -> None:

Просмотреть файл

@ -215,7 +215,7 @@ class LandCoverAIDataModule(pl.LightningDataModule):
"""
def __init__(
self, root_dir: str, batch_size: int = 64, num_workers: int = 4, **kwargs: Any
self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
"""Initialize a LightningDataModule for LandCover.ai based DataLoaders.
@ -250,7 +250,7 @@ class LandCoverAIDataModule(pl.LightningDataModule):
This method is only called once per run.
"""
_ = LandCoverAI(self.root_dir, download=True, checksum=False)
_ = LandCoverAI(self.root_dir, download=False, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.

Просмотреть файл

@ -64,7 +64,6 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
"""
# TODO: tune these hyperparams
patch_size = 256
length = 1000
stride = 128
@ -73,7 +72,8 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
naip_root_dir: str,
chesapeake_root_dir: str,
batch_size: int = 64,
num_workers: int = 4,
num_workers: int = 0,
patch_size: int = 256,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for NAIP and Chesapeake based DataLoaders.
@ -83,12 +83,14 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
chesapeake_root_dir: directory containing Chesapeake data
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
patch_size: size of patches to sample
"""
super().__init__() # type: ignore[no-untyped-call]
self.naip_root_dir = naip_root_dir
self.chesapeake_root_dir = chesapeake_root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.patch_size = patch_size
def naip_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the NAIP Dataset.
@ -120,7 +122,7 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
This method is only called once per run.
"""
Chesapeake13(self.chesapeake_root_dir, download=True, checksum=False)
Chesapeake13(self.chesapeake_root_dir, download=False, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.

Просмотреть файл

@ -219,7 +219,7 @@ class RESISC45DataModule(pl.LightningDataModule):
self,
root_dir: str,
batch_size: int = 64,
num_workers: int = 4,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for RESISC45 based DataLoaders.

Просмотреть файл

@ -289,7 +289,7 @@ class SEN12MSDataModule(pl.LightningDataModule):
seed: int,
band_set: str = "all",
batch_size: int = 64,
num_workers: int = 4,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for SEN12MS based DataLoaders.

Просмотреть файл

@ -243,7 +243,7 @@ class So2SatDataModule(pl.LightningDataModule):
self,
root_dir: str,
batch_size: int = 64,
num_workers: int = 4,
num_workers: int = 0,
bands: str = "rgb",
unsupervised_mode: bool = False,
**kwargs: Any,

Просмотреть файл

@ -225,7 +225,7 @@ class UCMercedDataModule(pl.LightningDataModule):
self,
root_dir: str,
batch_size: int = 64,
num_workers: int = 4,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for UCMerced based DataLoaders.
@ -266,7 +266,7 @@ class UCMercedDataModule(pl.LightningDataModule):
This method is only called once per run.
"""
UCMerced(self.root_dir, download=True, checksum=False)
UCMerced(self.root_dir, download=False, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.

Просмотреть файл

@ -451,3 +451,6 @@ class BYOLTask(LightningModule):
)
self.log("val_loss", loss, on_step=False, on_epoch=True)
def test_step(self, *args: Any) -> None: # type: ignore[override]
"""No-op, does nothing."""

Просмотреть файл

@ -34,7 +34,7 @@ class ClassificationTask(pl.LightningModule):
imagenet_pretrained = False
custom_pretrained = False
if not os.path.exists(self.hparams["weights"]):
if self.hparams["weights"] and not os.path.exists(self.hparams["weights"]):
if self.hparams["weights"] == "imagenet":
imagenet_pretrained = True
elif self.hparams["weights"] == "random":

Просмотреть файл

@ -109,3 +109,23 @@ class LandCoverAISegmentationTask(SemanticSegmentationTask):
)
plt.close()
def test_step( # type: ignore[override]
self, batch: Dict[str, Any], batch_idx: int
) -> None:
"""Test step identical to the validation step.
Args:
batch: Current batch
batch_idx: Index of current batch
"""
x = batch["image"]
y = batch["mask"].long().squeeze()
y_hat = self.forward(x)
y_hat_hard = y_hat.argmax(dim=1)
loss = self.loss(y_hat, y)
# by default, the test and validation steps only log per *epoch*
self.log("test_loss", loss, on_step=False, on_epoch=True)
self.test_metrics(y_hat_hard, y)

Просмотреть файл

@ -27,7 +27,7 @@ class So2SatClassificationTask(ClassificationTask):
in_channels = self.hparams["in_channels"]
pretrained = False
if not os.path.exists(self.hparams["weights"]):
if self.hparams["weights"] and not os.path.exists(self.hparams["weights"]):
if self.hparams["weights"] == "imagenet":
pretrained = True
elif self.hparams["weights"] == "random":