Merge branch 'main' of https://github.com/microsoft/torchgeo into feature/ssl_experiments_sampling

This commit is contained in:
mtrazzak 2021-10-21 15:02:21 +01:00
Parent 7a346b6d9d f68cec5a2e
Commit bae8d0e979
76 changed files with 2318 additions and 421 deletions

.gitignore vendored

@ -3,6 +3,7 @@
/logs/
/output/
*.csv
*.pdf
# Spack
.spack-env/


@ -12,9 +12,9 @@ import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.models import resnet34
from torchgeo.datasets import CDL, BoundingBox, Landsat8
from torchgeo.datasets import CDL, Landsat8
from torchgeo.samplers import GridGeoSampler, RandomBatchGeoSampler, RandomGeoSampler
@ -83,7 +83,7 @@ def set_up_parser() -> argparse.ArgumentParser:
parser.add_argument(
"-s",
"--stride",
default=2 ** 7,
default=112,
type=int,
help="sampling stride for GridGeoSampler",
)
@ -145,26 +145,15 @@ def main(args: argparse.Namespace) -> None:
length = args.num_batches * args.batch_size
num_batches = args.num_batches
# Workaround for https://github.com/microsoft/torchgeo/issues/149
roi = BoundingBox(
-2000000, 2200000, 280000, 3170000, dataset.bounds.mint, dataset.bounds.maxt
)
# Convert from pixel coords to CRS coords
size = args.patch_size * cdl.res
stride = args.stride * cdl.res
samplers = [
RandomGeoSampler(
landsat.index,
size=args.patch_size,
length=length,
roi=roi,
),
GridGeoSampler(
landsat.index, size=args.patch_size, stride=args.stride, roi=roi
),
RandomGeoSampler(landsat, size=size, length=length),
GridGeoSampler(landsat, size=size, stride=stride),
RandomBatchGeoSampler(
landsat.index,
size=args.patch_size,
batch_size=args.batch_size,
length=length,
roi=roi,
landsat, size=size, batch_size=args.batch_size, length=length
),
]
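The hunk above moves the samplers to the new API: each sampler now takes the dataset itself rather than its spatial index, and sizes and strides are given in CRS units rather than pixels, which is why the BoundingBox region-of-interest workaround could be dropped. As a worked example of the unit conversion (assuming the 30 m resolution of the Landsat 8 and CDL rasters, which is not stated in this hunk), the new default stride of 112 pixels becomes:
stride = args.stride * cdl.res  # 112 px * 30 m/px = 3360 m in CRS units
sampler = GridGeoSampler(landsat, size=args.patch_size * cdl.res, stride=stride)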
@ -224,7 +213,7 @@ def main(args: argparse.Namespace) -> None:
)
# Benchmark model
model = resnet18()
model = resnet34()
# Change number of input channels to match Landsat
model.conv1 = nn.Conv2d( # type: ignore[attr-defined]
len(bands), 64, kernel_size=7, stride=2, padding=3, bias=False
@ -259,7 +248,7 @@ def main(args: argparse.Namespace) -> None:
duration = toc - tic
if args.verbose:
print("\nResNet-18:")
print("\nResNet-34:")
print(f" duration: {duration:.3f} sec")
print(f" count: {num_total_patches} patches")
print(f" rate: {num_total_patches / duration:.3f} patches/sec")
@ -271,7 +260,7 @@ def main(args: argparse.Namespace) -> None:
"duration": duration,
"count": num_total_patches,
"rate": num_total_patches / duration,
"sampler": "resnet18",
"sampler": "ResNet-34",
"batch_size": args.batch_size,
"num_workers": args.num_workers,
}
@ -297,8 +286,6 @@ def main(args: argparse.Namespace) -> None:
if __name__ == "__main__":
os.environ["GDAL_CACHEMAX"] = "50%"
parser = set_up_parser()
args = parser.parse_args()

conf/resisc45.yaml Normal file

@ -0,0 +1,18 @@
trainer:
gpus: 1 # single GPU training
min_epochs: 10
max_epochs: 40
benchmark: True
experiment:
task: "resisc45"
module:
loss: "ce"
classification_model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
datamodule:
batch_size: 128
num_workers: 6
val_split_pct: 0.2
test_split_pct: 0.2


@ -11,6 +11,8 @@ experiment:
classification_model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
datamodule:
batch_size: 128
num_workers: 6
bands: "rgb"


@ -0,0 +1,14 @@
experiment:
task: "resisc45"
module:
loss: "ce"
classification_model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
datamodule:
batch_size: 128
num_workers: 6
weights: ${experiment.module.weights}
val_split_pct: 0.2
test_split_pct: 0.2


@ -6,7 +6,8 @@ experiment:
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
in_channels: 3
datamodule:
batch_size: 128
num_workers: 6
weights: ${experiment.module.weights}
bands: "rgb"


@ -6,7 +6,7 @@
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINX_BUILD ?= sphinx-build
SPHINX_BUILD_OPTS ?= -W --keep-going
SPHINX_BUILD_OPTS ?= -W -j auto --keep-going
SOURCE_DIR = .
BUILD_DIR = _build


@ -80,6 +80,11 @@ Smallholder Cashew Plantations in Benin
.. autoclass:: BeninSmallHolderCashews
BigEarthNet
^^^^^^^^^^^
.. autoclass:: BigEarthNet
Cars Overhead With Context (COWC)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@ -143,6 +148,7 @@ SpaceNet
.. autoclass:: SpaceNet
.. autoclass:: SpaceNet1
.. autoclass:: SpaceNet2
.. autoclass:: SpaceNet4
Tropical Cyclone Wind Estimation Competition
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^


@ -1,4 +1,33 @@
torchgeo.models
=================
.. automodule:: torchgeo.models
.. module:: torchgeo.models
Change Star
^^^^^^^^^^^
.. autoclass:: ChangeStar
.. autoclass:: ChangeStarFarSeg
.. autoclass:: ChangeMixin
Foreground-aware Relation Network (FarSeg)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: FarSeg
Fully-convolutional Network (FCN)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: FCN
Fully Convolutional Siamese Networks for Change Detection
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: FCEF
.. autoclass:: FCSiamConc
.. autoclass:: FCSiamDiff
Random-convolutional feature (RCF) extractor
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: RCF


@ -16,7 +16,7 @@ Samplers are used to index a dataset, retrieving a single query at a time. For :
from torchgeo.samplers import RandomGeoSampler
dataset = Landsat(...)
sampler = RandomGeoSampler(dataset.index, size=1000, length=100)
sampler = RandomGeoSampler(dataset, size=1000, length=100)
dataloader = DataLoader(dataset, sampler=sampler)
@ -43,7 +43,7 @@ When working with large tile-based datasets, randomly sampling patches from each
from torchgeo.samplers import RandomBatchGeoSampler
dataset = Landsat(...)
sampler = RandomBatchGeoSampler(dataset.index, size=1000, batch_size=10, length=100)
sampler = RandomBatchGeoSampler(dataset, size=1000, batch_size=10, length=100)
dataloader = DataLoader(dataset, batch_sampler=sampler)
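The same pattern applies to exhaustive tiling with GridGeoSampler, which this page of the docs does not show but the benchmark notebook later in this diff does; a minimal sketch mirroring that usage (Landsat constructor arguments elided as above):
dataset = Landsat(...)
sampler = GridGeoSampler(dataset, size=1000, stride=500)
dataloader = DataLoader(dataset, sampler=sampler)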


@ -78,7 +78,7 @@ html_theme_options = {
"logo_only": True,
"pytorch_project": "docs",
"navigation_with_keys": True,
"analytics_id": "UA-117752657-2",
"analytics_id": "UA-209075005-1",
}
html_favicon = os.path.join("..", "logo", "favicon.ico")


@ -211,7 +211,7 @@
" chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)\n",
" naip = NAIP(naip_root, crs=chesapeake.crs, res=chesapeake.res, cache=cache)\n",
" dataset = chesapeake + naip\n",
" sampler = RandomGeoSampler(naip.index, size=1000, length=888)\n",
" sampler = RandomGeoSampler(naip, size=1000, length=888)\n",
" dataloader = DataLoader(dataset, batch_size=12, sampler=sampler)\n",
" duration, count = time_epoch(dataloader)\n",
" print(duration, count)"
@ -262,7 +262,7 @@
" chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)\n",
" naip = NAIP(naip_root, crs=chesapeake.crs, res=chesapeake.res, cache=cache)\n",
" dataset = chesapeake + naip\n",
" sampler = GridGeoSampler(naip.index, size=1000, stride=500)\n",
" sampler = GridGeoSampler(naip, size=1000, stride=500)\n",
" dataloader = DataLoader(dataset, batch_size=12, sampler=sampler)\n",
" duration, count = time_epoch(dataloader)\n",
" print(duration, count)"
@ -313,7 +313,7 @@
" chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)\n",
" naip = NAIP(naip_root, crs=chesapeake.crs, res=chesapeake.res, cache=cache)\n",
" dataset = chesapeake + naip\n",
" sampler = RandomBatchGeoSampler(naip.index, size=1000, batch_size=12, length=888)\n",
" sampler = RandomBatchGeoSampler(naip, size=1000, batch_size=12, length=888)\n",
" dataloader = DataLoader(dataset, batch_sampler=sampler)\n",
" duration, count = time_epoch(dataloader)\n",
" print(duration, count)"


@ -329,7 +329,7 @@
},
"outputs": [],
"source": [
"sampler = RandomGeoSampler(naip.index, size=1000, length=10)"
"sampler = RandomGeoSampler(naip, size=1000, length=10)"
]
},
{


@ -224,7 +224,7 @@
" min = np.percentile(x, 100 - percentile, axis=-1)[:, None, None]\n",
" max = np.percentile(x, percentile, axis=-1)[:, None, None]\n",
" x = x.reshape(c, h, w)\n",
" x = np.clamp(x, min, max)\n",
" x = np.clip(x, min, max)\n",
" return (x - min) / (max - min)"
],
"execution_count": null,
@ -687,4 +687,4 @@
]
}
]
}
}

experiments/plot_bar_chart.py Executable file

@ -0,0 +1,77 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
df1 = pd.read_csv("original-benchmark-results.csv")
df2 = pd.read_csv("warped-benchmark-results.csv")
mean1 = df1.groupby("sampler").mean()
mean2 = df2.groupby("sampler").mean()
cached1 = (
df1[(df1["cached"]) & (df1["sampler"] != "resnet18")].groupby("sampler").mean()
)
cached2 = (
df2[(df2["cached"]) & (df2["sampler"] != "resnet18")].groupby("sampler").mean()
)
not_cached1 = (
df1[(~df1["cached"]) & (df1["sampler"] != "resnet18")].groupby("sampler").mean()
)
not_cached2 = (
df2[(~df2["cached"]) & (df2["sampler"] != "resnet18")].groupby("sampler").mean()
)
print("cached, original\n", cached1)
print("cached, warped\n", cached2)
print("not cached, original\n", not_cached1)
print("not cached, warped\n", not_cached2)
cmap = sns.color_palette()
labels = ["GridGeoSampler", "RandomBatchGeoSampler", "RandomGeoSampler"]
fig, ax = plt.subplots()
x = np.arange(3)
width = 0.2
rects1 = ax.bar(
x - width * 3 / 2,
not_cached1["rate"],
width,
label="Raw Data, Not Cached",
color=cmap[0],
)
rects2 = ax.bar(
x - width * 1 / 2,
not_cached2["rate"],
width,
label="Preprocessed, Not Cached",
color=cmap[1],
)
rects2 = ax.bar(
x + width * 1 / 2, cached1["rate"], width, label="Raw Data, Cached", color=cmap[2]
)
rects3 = ax.bar(
x + width * 3 / 2,
cached2["rate"],
width,
label="Preprocessed, Cached",
color=cmap[3],
)
ax.set_ylabel("sampling rate (patches/sec)", fontsize=12)
ax.set_xticks(x)
ax.set_xticklabels(labels, fontsize=12)
ax.tick_params(axis="x", labelrotation=10)
ax.legend(fontsize="large")
plt.gca().spines.right.set_visible(False)
plt.gca().spines.top.set_visible(False)
plt.tight_layout()
plt.show()


@ -0,0 +1,43 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
df = pd.read_csv("warped-benchmark-results.csv")
random_cached = df[(df["sampler"] == "RandomGeoSampler") & (df["cached"])]
random_batch_cached = df[(df["sampler"] == "RandomBatchGeoSampler") & (df["cached"])]
grid_cached = df[(df["sampler"] == "GridGeoSampler") & (df["cached"])]
other = [
("RandomGeoSampler", random_cached),
("RandomBatchGeoSampler", random_batch_cached),
("GridGeoSampler", grid_cached),
]
cmap = sns.color_palette()
ax = plt.gca()
for i, (label, df) in enumerate(other):
df = df.groupby("batch_size")
ax.plot(df.mean().index, df.mean()["rate"], color=cmap[i], label=label)
ax.fill_between(
df.mean().index, df.min()["rate"], df.max()["rate"], color=cmap[i], alpha=0.2
)
ax.set_xscale("log")
ax.set_xticks([16, 32, 64, 128, 256])
ax.set_xticklabels([16, 32, 64, 128, 256], fontsize=12)
ax.set_xlabel("batch size", fontsize=12)
ax.set_ylabel("sampling rate (patches/sec)", fontsize=12)
ax.legend(loc="center right", fontsize="large")
plt.gca().spines.right.set_visible(False)
plt.gca().spines.top.set_visible(False)
plt.tight_layout()
plt.show()


@ -0,0 +1,56 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
df1 = pd.read_csv("original-benchmark-results.csv")
df2 = pd.read_csv("warped-benchmark-results.csv")
random_cached1 = df1[(df1["sampler"] == "RandomGeoSampler") & (df1["cached"])]
random_cached2 = df2[(df2["sampler"] == "RandomGeoSampler") & (df2["cached"])]
random_cachedp = random_cached1
random_cachedp["rate"] /= random_cached2["rate"]
random_batch_cached1 = df1[
(df1["sampler"] == "RandomBatchGeoSampler") & (df1["cached"])
]
random_batch_cached2 = df2[
(df2["sampler"] == "RandomBatchGeoSampler") & (df2["cached"])
]
random_batch_cachedp = random_batch_cached1
random_batch_cachedp["rate"] /= random_batch_cached2["rate"]
grid_cached1 = df1[(df1["sampler"] == "GridGeoSampler") & (df1["cached"])]
grid_cached2 = df2[(df2["sampler"] == "GridGeoSampler") & (df2["cached"])]
grid_cachedp = grid_cached1
grid_cachedp["rate"] /= grid_cached2["rate"]
other = [
("RandomGeoSampler (cached)", random_cachedp),
("RandomBatchGeoSampler (cached)", random_batch_cachedp),
("GridGeoSampler (cached)", grid_cachedp),
]
cmap = sns.color_palette()
ax = plt.gca()
for i, (label, df) in enumerate(other):
df = df.groupby("batch_size")
ax.plot([16, 32, 64, 128, 256], df.mean()["rate"], color=cmap[i], label=label)
ax.fill_between(
df.mean().index, df.min()["rate"], df.max()["rate"], color=cmap[i], alpha=0.2
)
ax.set_xscale("log")
ax.set_xticks([16, 32, 64, 128, 256])
ax.set_xticklabels([16, 32, 64, 128, 256])
ax.set_xlabel("batch size")
ax.set_ylabel("% sampling rate (patches/sec)")
ax.legend()
plt.show()


@ -3,12 +3,14 @@
# Licensed under the MIT License.
"""Script for running the benchmark script over a sweep of different options."""
import itertools
import os
import subprocess
import time
from typing import List
NUM_BATCHES = 100
EPOCH_SIZE = 4096
SEED_OPTIONS = [0, 1, 2]
CACHE_OPTIONS = [True, False]
@ -23,11 +25,14 @@ CDL_DATA_ROOT = ""
total_num_experiments = len(SEED_OPTIONS) * len(CACHE_OPTIONS) * len(BATCH_SIZE_OPTIONS)
if __name__ == "__main__":
# With 6 workers, this will use ~60% of available RAM
os.environ["GDAL_CACHEMAX"] = "10%"
tic = time.time()
for i, (cache, batch_size, seed) in enumerate(
itertools.product(CACHE_OPTIONS, BATCH_SIZE_OPTIONS, SEED_OPTIONS)
):
print(f"{i}/{total_num_experiments} -- {time.time() - tic}")
print(f"\n{i}/{total_num_experiments} -- {time.time() - tic}")
tic = time.time()
command: List[str] = [
"python",
@ -40,8 +45,8 @@ if __name__ == "__main__":
"6",
"--batch-size",
str(batch_size),
"--num-batches",
str(NUM_BATCHES),
"--epoch-size",
str(EPOCH_SIZE),
"--seed",
str(seed),
"--verbose",


@ -15,11 +15,10 @@ DATA_DIR = "" # path to the LandcoverAI data directory
# Hyperparameter options
model_options = ["unet"]
encoder_options = ["resnet50"]
lr_options = [1e-4]
loss_options = ["ce"]
weight_init_options = ["imagenet"]
seeds = list(range(15))
encoder_options = ["resnet18", "resnet50"]
lr_options = [1e-2, 1e-3, 1e-4]
loss_options = ["ce", "jaccard"]
weight_init_options = ["null", "imagenet"]
def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
@ -36,18 +35,13 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
if __name__ == "__main__":
work: "Queue[str]" = Queue()
for (model, encoder, lr, loss, weight_init, seed) in itertools.product(
model_options,
encoder_options,
lr_options,
loss_options,
weight_init_options,
seeds,
for (model, encoder, lr, loss, weight_init) in itertools.product(
model_options, encoder_options, lr_options, loss_options, weight_init_options
):
experiment_name = f"{model}_{encoder}_{lr}_{loss}_{weight_init}_{seed}"
experiment_name = f"{model}_{encoder}_{lr}_{loss}_{weight_init}"
output_dir = os.path.join("output", "landcoverai_seed_experiments")
output_dir = os.path.join("output", "landcoverai_experiments")
log_dir = os.path.join(output_dir, "logs")
config_file = os.path.join("conf", "landcoverai.yaml")
@ -63,7 +57,6 @@ if __name__ == "__main__":
+ f" experiment.module.encoder_name={encoder}"
+ f" experiment.module.encoder_weights={weight_init}"
+ f" program.output_dir={output_dir}"
+ f" program.seed={seed}"
+ f" program.log_dir={log_dir}"
+ f" program.data_dir={DATA_DIR}"
+ " trainer.gpus=[GPU]"


@ -0,0 +1,82 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Runs the train script with a grid of hyperparameters."""
import itertools
import os
import subprocess
from multiprocessing import Process, Queue
# list of GPU IDs that we want to use, one job will be started for every ID in the list
GPUS = [0]
DRY_RUN = False # if True then only print the commands that would be run, if False then run them
DATA_DIR = "" # path to the RESISC45 data directory
# Hyperparameter options
model_options = ["resnet18", "resnet50"]
lr_options = [1e-2, 1e-3, 1e-4]
loss_options = ["ce"]
weight_options = ["imagenet_only", "random"]
def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
"""Process for each ID in GPUS."""
while not work.empty():
experiment = work.get()
experiment = experiment.replace("GPU", str(gpu_idx))
print(experiment)
if not DRY_RUN:
subprocess.call(experiment.split(" "))
return True
if __name__ == "__main__":
work: "Queue[str]" = Queue()
for (model, lr, loss, weights) in itertools.product(
model_options,
lr_options,
loss_options,
weight_options,
):
experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_','-')}"
output_dir = os.path.join("output", "resisc45_experiments")
log_dir = os.path.join(output_dir, "logs")
config_file = os.path.join("conf", "resisc45.yaml")
if not os.path.exists(os.path.join(output_dir, experiment_name)):
command = (
"python train.py"
+ f" config_file={config_file}"
+ f" experiment.name={experiment_name}"
+ f" experiment.module.classification_model={model}"
+ f" experiment.module.learning_rate={lr}"
+ f" experiment.module.loss={loss}"
+ f" experiment.module.weights={weights}"
+ f" experiment.datamodule.weights={weights}"
+ f" program.output_dir={output_dir}"
+ f" program.log_dir={log_dir}"
+ f" program.data_dir={DATA_DIR}"
+ " trainer.gpus=[GPU]"
)
command = command.strip()
work.put(command)
processes = []
for gpu_idx in GPUS:
p = Process(
target=do_work,
args=(
work,
gpu_idx,
),
)
processes.append(p)
p.start()
for p in processes:
p.join()


@ -0,0 +1,86 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Runs the train script with a grid of hyperparameters."""
import itertools
import os
import subprocess
from multiprocessing import Process, Queue
from typing import List
# list of GPU IDs that we want to use, one job will be started for every ID in the list
GPUS = [0, 1, 2, 3, 3]
DRY_RUN = False # if True then only print the commands that would be run, if False then run them
DATA_DIR = "" # path to the So2Sat data directory
# Hyperparameter options
model_options = ["resnet50"]
lr_options = [1e-4]
loss_options = ["ce"]
weight_options: List[str] = [] # set paths to checkpoint files
bands_options = ["s2"]
def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
"""Process for each ID in GPUS."""
while not work.empty():
experiment = work.get()
experiment = experiment.replace("GPU", str(gpu_idx))
print(experiment)
if not DRY_RUN:
subprocess.call(experiment.split(" "))
return True
if __name__ == "__main__":
work: "Queue[str]" = Queue()
for (model, lr, loss, weights, bands) in itertools.product(
model_options,
lr_options,
loss_options,
weight_options,
bands_options,
):
experiment_name = f"{model}_{lr}_{loss}_byol_{bands}-{weights.split('/')[-2]}"
output_dir = os.path.join("output", "so2sat_experiments")
log_dir = os.path.join(output_dir, "logs")
config_file = os.path.join("conf", "so2sat.yaml")
if not os.path.exists(os.path.join(output_dir, experiment_name)):
command = (
"python train.py"
+ f" config_file={config_file}"
+ f" experiment.name={experiment_name}"
+ f" experiment.module.classification_model={model}"
+ f" experiment.module.learning_rate={lr}"
+ f" experiment.module.loss={loss}"
+ f" experiment.module.weights={weights}"
+ " experiment.module.in_channels=10"
+ f" experiment.datamodule.bands={bands}"
+ f" program.output_dir={output_dir}"
+ f" program.log_dir={log_dir}"
+ f" program.data_dir={DATA_DIR}"
+ " trainer.gpus=[GPU]"
)
command = command.strip()
work.put(command)
processes = []
for gpu_idx in GPUS:
p = Process(
target=do_work,
args=(
work,
gpu_idx,
),
)
processes.append(p)
p.start()
for p in processes:
p.join()


@ -69,8 +69,6 @@ datasets =
scipy>=0.9
# Optional trainer requirements
train =
# kornia 0.5.4+ required for AugmentationSequential
kornia>=0.5.4
# omegaconf 2.1+ required for to_object method
omegaconf>=2.1
# pytorch-lightning 1.3+ required for gradient_clip_algorithm argument to Trainer

Binary data
tests/data/bigearthnet/BigEarthNet-S1-v1.0.tar.gz Normal file

Binary file not shown.

Binary data
tests/data/bigearthnet/BigEarthNet-S2-v1.0.tar.gz Normal file

Binary file not shown.

Binary data
tests/data/chesapeake/cvpr/cvpr_chesapeake_landcover.zip Normal file

Binary file not shown.

Binary data
tests/data/spacenet/sn4_AOI_6_Atlanta.tar.gz Normal file

Binary file not shown.


@ -9,11 +9,11 @@ from typing import Any, Generator
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import ADVANCE
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str) -> None:
@ -40,7 +40,7 @@ class TestADVANCE:
monkeypatch.setattr(ADVANCE, "urls", urls) # type: ignore[attr-defined]
monkeypatch.setattr(ADVANCE, "md5s", md5s) # type: ignore[attr-defined]
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return ADVANCE(root, transforms, download=True, checksum=True)
@pytest.fixture


@ -9,11 +9,11 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
from torchgeo.datasets import BeninSmallHolderCashews
from torchgeo.transforms import Identity
class Dataset:
@ -50,7 +50,7 @@ class TestBeninSmallHolderCashews:
BeninSmallHolderCashews, "dates", ("2019_11_05",)
)
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return BeninSmallHolderCashews(
root,
transforms=transforms,


@ -0,0 +1,113 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import shutil
from pathlib import Path
from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import BigEarthNet
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
class TestBigEarthNet:
@pytest.fixture(params=["all", "s1", "s2"])
def dataset(
self,
monkeypatch: Generator[MonkeyPatch, None, None],
tmp_path: Path,
request: SubRequest,
) -> BigEarthNet:
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.bigearthnet, "download_url", download_url
)
data_dir = os.path.join("tests", "data", "bigearthnet")
metadata = {
"s1": {
"url": os.path.join(data_dir, "BigEarthNet-S1-v1.0.tar.gz"),
"md5": "5a64e9ce38deb036a435a7b59494924c",
"filename": "BigEarthNet-S1-v1.0.tar.gz",
"directory": "BigEarthNet-S1-v1.0",
},
"s2": {
"url": os.path.join(data_dir, "BigEarthNet-S2-v1.0.tar.gz"),
"md5": "ef5f41129b8308ca178b04d7538dbacf",
"filename": "BigEarthNet-S2-v1.0.tar.gz",
"directory": "BigEarthNet-v1.0",
},
}
monkeypatch.setattr( # type: ignore[attr-defined]
BigEarthNet, "metadata", metadata
)
bands = request.param
root = str(tmp_path)
transforms = nn.Identity() # type: ignore[attr-defined]
return BigEarthNet(root, bands, transforms, download=True, checksum=True)
def test_getitem(self, dataset: BigEarthNet) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["label"], torch.Tensor)
assert x["label"].shape == (43,)
assert x["image"].dtype == torch.int32 # type: ignore[attr-defined]
assert x["label"].dtype == torch.int64 # type: ignore[attr-defined]
if dataset.bands == "all":
assert x["image"].shape == (14, 120, 120)
elif dataset.bands == "s1":
assert x["image"].shape == (2, 120, 120)
else:
assert x["image"].shape == (12, 120, 120)
def test_len(self, dataset: BigEarthNet) -> None:
assert len(dataset) == 2
def test_already_downloaded(self, dataset: BigEarthNet, tmp_path: Path) -> None:
BigEarthNet(root=str(tmp_path), bands=dataset.bands, download=True)
def test_already_downloaded_not_extracted(
self, dataset: BigEarthNet, tmp_path: Path
) -> None:
if dataset.bands == "all":
shutil.rmtree(
os.path.join(dataset.root, dataset.metadata["s1"]["directory"])
)
shutil.rmtree(
os.path.join(dataset.root, dataset.metadata["s2"]["directory"])
)
download_url(dataset.metadata["s1"]["url"], root=str(tmp_path))
download_url(dataset.metadata["s2"]["url"], root=str(tmp_path))
elif dataset.bands == "s1":
shutil.rmtree(
os.path.join(dataset.root, dataset.metadata["s1"]["directory"])
)
download_url(dataset.metadata["s1"]["url"], root=str(tmp_path))
else:
shutil.rmtree(
os.path.join(dataset.root, dataset.metadata["s2"]["directory"])
)
download_url(dataset.metadata["s2"]["url"], root=str(tmp_path))
BigEarthNet(
root=str(tmp_path),
bands=dataset.bands,
download=False,
)
def test_not_downloaded(self, tmp_path: Path) -> None:
err = "Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automaticaly download the dataset."
with pytest.raises(RuntimeError, match=err):
BigEarthNet(str(tmp_path))


@ -9,12 +9,12 @@ from typing import Generator
import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS
import torchgeo.datasets.utils
from torchgeo.datasets import BoundingBox, CanadianBuildingFootprints, ZipDataset
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str) -> None:
@ -57,7 +57,7 @@ class TestCanadianBuildingFootprints:
plt, "show", lambda *args: None
)
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return CanadianBuildingFootprints(
root, res=0.1, transforms=transforms, download=True, checksum=True
)


@ -11,12 +11,12 @@ from typing import Generator
import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS
import torchgeo.datasets.utils
from torchgeo.datasets import CDL, BoundingBox, ZipDataset
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -44,7 +44,7 @@ class TestCDL:
plt, "show", lambda *args: None
)
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return CDL(root, transforms=transforms, download=True, checksum=True)
def test_getitem(self, dataset: CDL) -> None:


@ -9,15 +9,16 @@ from typing import Generator
import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS
import torchgeo.datasets.utils
from torchgeo.datasets import BoundingBox, Chesapeake13, ZipDataset
from torchgeo.transforms import Identity
from torchgeo.datasets import BoundingBox, Chesapeake13, ChesapeakeCVPR, ZipDataset
def download_url(url: str, root: str, *args: str) -> None:
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -41,7 +42,7 @@ class TestChesapeake13:
plt, "show", lambda *args: None
)
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return Chesapeake13(root, transforms=transforms, download=True, checksum=True)
def test_getitem(self, dataset: Chesapeake13) -> None:
@ -76,3 +77,84 @@ class TestChesapeake13:
IndexError, match="query: .* not found in index with bounds:"
):
dataset[query]
class TestChesapeakeCVPR:
@pytest.fixture(
params=[
("naip-new", "naip-old", "nlcd"),
("landsat-leaf-on", "landsat-leaf-off", "lc"),
("naip-new", "landsat-leaf-on", "lc", "nlcd", "buildings"),
]
)
def dataset(
self,
request: SubRequest,
monkeypatch: Generator[MonkeyPatch, None, None],
tmp_path: Path,
) -> ChesapeakeCVPR:
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.chesapeake, "download_url", download_url
)
md5 = "564b8d944a941b0b65db9f56c92b93a2"
monkeypatch.setattr(ChesapeakeCVPR, "md5", md5) # type: ignore[attr-defined]
url = os.path.join(
"tests", "data", "chesapeake", "cvpr", "cvpr_chesapeake_landcover.zip"
)
monkeypatch.setattr(ChesapeakeCVPR, "url", url) # type: ignore[attr-defined]
monkeypatch.setattr( # type: ignore[attr-defined]
ChesapeakeCVPR,
"files",
["de_1m_2013_extended-debuffered-test_tiles", "spatial_index.geojson"],
)
root = str(tmp_path)
transforms = nn.Identity() # type: ignore[attr-defined]
return ChesapeakeCVPR(
root,
splits=["de-test"],
layers=request.param,
transforms=transforms,
download=True,
checksum=True,
)
def test_getitem(self, dataset: ChesapeakeCVPR) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)
def test_add(self, dataset: ChesapeakeCVPR) -> None:
ds = dataset + dataset
assert isinstance(ds, ZipDataset)
def test_already_extracted(self, dataset: ChesapeakeCVPR) -> None:
ChesapeakeCVPR(root=dataset.root, download=True)
def test_already_downloaded(self, tmp_path: Path) -> None:
url = os.path.join(
"tests", "data", "chesapeake", "cvpr", "cvpr_chesapeake_landcover.zip"
)
root = str(tmp_path)
shutil.copy(url, root)
ChesapeakeCVPR(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
ChesapeakeCVPR(str(tmp_path), checksum=True)
def test_out_of_bounds_query(self, dataset: ChesapeakeCVPR) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
IndexError, match="query: .* not found in index with bounds:"
):
dataset[query]
def test_multiple_hits_query(self, dataset: ChesapeakeCVPR) -> None:
ds = ChesapeakeCVPR(
root=dataset.root, splits=["de-train", "de-test"], layers=dataset.layers
)
with pytest.raises(
IndexError, match="query: .* spans multiple tiles which is not valid"
):
ds[dataset.bounds]


@ -8,6 +8,7 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
@ -15,7 +16,6 @@ from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import COWCCounting, COWCDetection
from torchgeo.datasets.cowc import COWC
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -56,7 +56,7 @@ class TestCOWCCounting:
monkeypatch.setattr(COWCCounting, "md5s", md5s) # type: ignore[attr-defined]
root = str(tmp_path)
split = request.param
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return COWCCounting(root, split, transforms, download=True, checksum=True)
def test_getitem(self, dataset: COWC) -> None:
@ -114,7 +114,7 @@ class TestCOWCDetection:
monkeypatch.setattr(COWCDetection, "md5s", md5s) # type: ignore[attr-defined]
root = str(tmp_path)
split = "train"
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return COWCDetection(root, split, transforms, download=True, checksum=True)
def test_getitem(self, dataset: COWC) -> None:


@ -9,11 +9,11 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
from torchgeo.datasets import CV4AKenyaCropType
from torchgeo.transforms import Identity
class Dataset:
@ -55,7 +55,7 @@ class TestCV4AKenyaCropType:
CV4AKenyaCropType, "dates", ["20190606"]
)
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return CV4AKenyaCropType(
root,
transforms=transforms,


@ -9,12 +9,12 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
from torchgeo.datasets import TropicalCycloneWindEstimation
from torchgeo.transforms import Identity
class Dataset:
@ -57,7 +57,7 @@ class TestTropicalCycloneWindEstimation:
)
root = str(tmp_path)
split = request.param
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return TropicalCycloneWindEstimation(
root, split, transforms, download=True, api_key="", checksum=True
)


@ -8,12 +8,12 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import ETCI2021
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str) -> None:
@ -55,7 +55,7 @@ class TestETCI2021:
monkeypatch.setattr(ETCI2021, "metadata", metadata) # type: ignore[attr-defined] # noqa: E501
root = str(tmp_path)
split = request.param
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return ETCI2021(root, split, transforms, download=True, checksum=True)
def test_getitem(self, dataset: ETCI2021) -> None:


@ -7,6 +7,7 @@ from typing import Dict
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from torch.utils.data import ConcatDataset
@ -21,7 +22,6 @@ from torchgeo.datasets import (
VisionDataset,
ZipDataset,
)
from torchgeo.transforms import Identity
class CustomGeoDataset(GeoDataset):
@ -57,11 +57,15 @@ class TestGeoDataset:
query = BoundingBox(0, 0, 0, 0, 0, 0)
assert dataset[query] == {"index": query}
def test_len(self, dataset: GeoDataset) -> None:
assert len(dataset) == 1
def test_add_two(self) -> None:
ds1 = CustomGeoDataset()
ds2 = CustomGeoDataset()
dataset = ds1 + ds2
assert isinstance(dataset, ZipDataset)
assert len(dataset) == 2
def test_add_three(self) -> None:
ds1 = CustomGeoDataset()
@ -69,6 +73,7 @@ class TestGeoDataset:
ds3 = CustomGeoDataset()
dataset = ds1 + ds2 + ds3
assert isinstance(dataset, ZipDataset)
assert len(dataset) == 3
def test_add_four(self) -> None:
ds1 = CustomGeoDataset()
@ -77,10 +82,13 @@ class TestGeoDataset:
ds4 = CustomGeoDataset()
dataset = (ds1 + ds2) + (ds3 + ds4)
assert isinstance(dataset, ZipDataset)
assert len(dataset) == 4
def test_str(self, dataset: GeoDataset) -> None:
assert "type: GeoDataset" in str(dataset)
assert "bbox: BoundingBox" in str(dataset)
out = str(dataset)
assert "type: GeoDataset" in out
assert "bbox: BoundingBox" in out
assert "size: 1" in out
def test_abstract(self) -> None:
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
@ -98,7 +106,7 @@ class TestRasterDataset:
root = os.path.join("tests", "data", "landsat8")
bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"]
crs = CRS.from_epsg(3005)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
cache = request.param
return Landsat8(root, bands=bands, crs=crs, transforms=transforms, cache=cache)
@ -223,9 +231,14 @@ class TestZipDataset:
query = BoundingBox(0, 1, 2, 3, 4, 5)
assert dataset[query] == {"index": query}
def test_len(self, dataset: ZipDataset) -> None:
assert len(dataset) == 2
def test_str(self, dataset: ZipDataset) -> None:
assert "type: ZipDataset" in str(dataset)
assert "bbox: BoundingBox" in str(dataset)
out = str(dataset)
assert "type: ZipDataset" in out
assert "bbox: BoundingBox" in out
assert "size: 2" in out
def test_vision_dataset(self) -> None:
ds1 = CustomVisionDataset()


@ -8,12 +8,12 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import GID15
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str) -> None:
@ -37,7 +37,7 @@ class TestGID15:
monkeypatch.setattr(GID15, "url", url) # type: ignore[attr-defined]
root = str(tmp_path)
split = request.param
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return GID15(root, split, transforms, download=True, checksum=True)
def test_getitem(self, dataset: GID15) -> None:


@ -8,13 +8,13 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import LandCoverAI
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str) -> None:
@ -40,7 +40,7 @@ class TestLandCoverAI:
monkeypatch.setattr(LandCoverAI, "sha256", sha256) # type: ignore[attr-defined]
root = str(tmp_path)
split = request.param
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return LandCoverAI(root, split, transforms, download=True, checksum=True)
def test_getitem(self, dataset: LandCoverAI) -> None:


@ -8,11 +8,11 @@ from typing import Generator
import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS
from torchgeo.datasets import BoundingBox, Landsat8, ZipDataset
from torchgeo.transforms import Identity
class TestLandsat8:
@ -23,7 +23,7 @@ class TestLandsat8:
)
root = os.path.join("tests", "data", "landsat8")
bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"]
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return Landsat8(root, bands=bands, transforms=transforms)
def test_separate_files(self, dataset: Landsat8) -> None:


@ -8,12 +8,12 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import LEVIRCDPlus
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str) -> None:
@ -37,7 +37,7 @@ class TestLEVIRCDPlus:
monkeypatch.setattr(LEVIRCDPlus, "url", url) # type: ignore[attr-defined]
root = str(tmp_path)
split = request.param
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return LEVIRCDPlus(root, split, transforms, download=True, checksum=True)
def test_getitem(self, dataset: LEVIRCDPlus) -> None:


@ -8,11 +8,11 @@ from typing import Generator
import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS
from torchgeo.datasets import NAIP, BoundingBox, ZipDataset
from torchgeo.transforms import Identity
class TestNAIP:
@ -22,7 +22,7 @@ class TestNAIP:
plt, "show", lambda *args: None
)
root = os.path.join("tests", "data", "naip")
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return NAIP(root, transforms=transforms)
def test_getitem(self, dataset: NAIP) -> None:


@ -9,13 +9,13 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import VHR10
from torchgeo.transforms import Identity
pytest.importorskip("rarfile")
pytest.importorskip("pycocotools")
@ -51,7 +51,7 @@ class TestVHR10:
monkeypatch.setitem(VHR10.target_meta, "md5", md5) # type: ignore[attr-defined]
root = str(tmp_path)
split = request.param
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return VHR10(root, split, transforms, download=True, checksum=True)
def test_getitem(self, dataset: VHR10) -> None:


@ -7,12 +7,12 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
from torchgeo.datasets import SEN12MS
from torchgeo.transforms import Identity
class TestSEN12MS:
@ -38,7 +38,7 @@ class TestSEN12MS:
monkeypatch.setattr(SEN12MS, "md5s", md5s) # type: ignore[attr-defined]
root = os.path.join("tests", "data", "sen12ms")
split = request.param
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return SEN12MS(root, split, transforms=transforms, checksum=True)
def test_getitem(self, dataset: SEN12MS) -> None:


@ -6,10 +6,10 @@ from pathlib import Path
import pytest
import torch
import torch.nn as nn
from rasterio.crs import CRS
from torchgeo.datasets import BoundingBox, Sentinel2, ZipDataset
from torchgeo.transforms import Identity
class TestSentinel2:
@ -17,7 +17,7 @@ class TestSentinel2:
def dataset(self) -> Sentinel2:
root = os.path.join("tests", "data", "sentinel2")
bands = ["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B09", "B11"]
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return Sentinel2(root, bands=bands, transforms=transforms)
def test_separate_files(self, dataset: Sentinel2) -> None:


@ -7,11 +7,11 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torchgeo.datasets import So2Sat
from torchgeo.transforms import Identity
class TestSo2Sat:
@ -28,7 +28,7 @@ class TestSo2Sat:
monkeypatch.setattr(So2Sat, "md5s", md5s) # type: ignore[attr-defined]
root = os.path.join("tests", "data", "so2sat")
split = request.param
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return So2Sat(root, split, transforms, checksum=True)
def test_getitem(self, dataset: So2Sat) -> None:


@ -9,11 +9,11 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torchgeo.datasets import SpaceNet1, SpaceNet2
from torchgeo.transforms import Identity
from torchgeo.datasets import SpaceNet1, SpaceNet2, SpaceNet4
TEST_DATA_DIR = "tests/data/spacenet"
@ -51,7 +51,7 @@ class TestSpaceNet1:
SpaceNet1, "collection_md5_dict", test_md5
)
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return SpaceNet1(
root,
image=request.param,
@ -104,7 +104,7 @@ class TestSpaceNet2:
SpaceNet2, "collection_md5_dict", test_md5
)
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return SpaceNet2(
root,
image=request.param,
@ -141,3 +141,67 @@ class TestSpaceNet2:
dataset.collection_md5_dict["sn2_AOI_2_Vegas"] = "randommd5hash123"
with pytest.raises(RuntimeError, match="Collection sn2_AOI_2_Vegas corrupted"):
SpaceNet2(root=dataset.root, download=True, checksum=True)
class TestSpaceNet4:
@pytest.fixture(params=["PAN", "MS", "PS-RGBNIR"])
def dataset(
self,
request: SubRequest,
monkeypatch: Generator[MonkeyPatch, None, None],
tmp_path: Path,
) -> SpaceNet4:
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
monkeypatch.setattr( # type: ignore[attr-defined]
radiant_mlhub.Collection, "fetch", fetch_collection
)
test_md5 = {
"sn4_AOI_6_Atlanta": "ea37c2d87e2c3a1d8b2a7c2230080d46",
}
test_angles = ["nadir", "off-nadir", "very-off-nadir"]
monkeypatch.setattr( # type: ignore[attr-defined]
SpaceNet4, "collection_md5_dict", test_md5
)
root = str(tmp_path)
transforms = nn.Identity() # type: ignore[attr-defined]
return SpaceNet4(
root,
image=request.param,
angles=test_angles,
transforms=transforms,
download=True,
api_key="",
)
def test_getitem(self, dataset: SpaceNet4) -> None:
# Get image-label pair with empty label to
# ensure coverage
x = dataset[2]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["mask"], torch.Tensor)
if dataset.image == "PS-RGBNIR":
assert x["image"].shape[0] == 4
elif dataset.image == "MS":
assert x["image"].shape[0] == 8
else:
assert x["image"].shape[0] == 1
def test_len(self, dataset: SpaceNet4) -> None:
assert len(dataset) == 4
def test_already_downloaded(self, dataset: SpaceNet4) -> None:
SpaceNet4(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
SpaceNet4(str(tmp_path))
def test_collection_checksum(self, dataset: SpaceNet4) -> None:
dataset.collection_md5_dict["sn4_AOI_6_Atlanta"] = "randommd5hash123"
with pytest.raises(
RuntimeError, match="Collection sn4_AOI_6_Atlanta corrupted"
):
SpaceNet4(root=dataset.root, download=True, checksum=True)


@ -16,11 +16,13 @@ import pytest
import torch
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS
from torch.utils.data import TensorDataset
import torchgeo.datasets.utils
from torchgeo.datasets.utils import (
BoundingBox,
collate_dict,
dataset_split,
disambiguate_timestamp,
download_and_extract_archive,
download_radiant_mlhub_collection,
@ -256,6 +258,12 @@ class TestBoundingBox:
datetime(2021, 9, 1, 0, 0, 0, 0).timestamp(),
datetime(2021, 9, 30, 23, 59, 59, 999999).timestamp(),
),
(
"Dec 21",
"%b %y",
datetime(2021, 12, 1, 0, 0, 0, 0).timestamp(),
datetime(2021, 12, 31, 23, 59, 59, 999999).timestamp(),
),
(
"2021-09-13",
"%Y-%m-%d",
@ -335,3 +343,21 @@ def test_nonexisting_directory(tmp_path: Path) -> None:
with working_dir(str(subdir), create=True):
assert subdir.cwd() == subdir
def test_dataset_split() -> None:
num_samples = 24
x = torch.ones(num_samples, 5) # type: ignore[attr-defined]
y = torch.randint(low=0, high=2, size=(num_samples,)) # type: ignore[attr-defined]
ds = TensorDataset(x, y)
# Test only train/val set split
train_ds, val_ds = dataset_split(ds, val_pct=1 / 2)
assert len(train_ds) == num_samples // 2
assert len(val_ds) == num_samples // 2
# Test train/val/test set split
train_ds, val_ds, test_ds = dataset_split(ds, val_pct=1 / 3, test_pct=1 / 3)
assert len(train_ds) == num_samples // 3
assert len(val_ds) == num_samples // 3
assert len(test_ds) == num_samples // 3
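The val_split_pct/test_split_pct options added to the RESISC45 config presumably feed this helper inside the datamodule (that wiring is not part of this diff); a minimal usage sketch with those defaults:
train_ds, val_ds, test_ds = dataset_split(dataset, val_pct=0.2, test_pct=0.2)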


@ -9,11 +9,11 @@ from typing import Any, Generator
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import ZueriCrop
from torchgeo.transforms import Identity
pytest.importorskip("h5py")
@ -41,7 +41,7 @@ class TestZueriCrop:
monkeypatch.setattr(ZueriCrop, "urls", urls) # type: ignore[attr-defined]
monkeypatch.setattr(ZueriCrop, "md5s", md5s) # type: ignore[attr-defined]
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return ZueriCrop(root, transforms, download=True, checksum=True)
@pytest.fixture

tests/models/test_rcf.py Normal file

@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import pytest
import torch
from torchgeo.models import RCF
class TestRCF:
def test_in_channels(self) -> None:
model = RCF(in_channels=5, features=4, kernel_size=3)
x = torch.randn(2, 5, 64, 64)
model(x)
model = RCF(in_channels=3, features=4, kernel_size=3)
match = "to have 3 channels, but got 5 channels instead"
with pytest.raises(RuntimeError, match=match):
model(x)
def test_num_features(self) -> None:
model = RCF(in_channels=5, features=4, kernel_size=3)
x = torch.randn(2, 5, 64, 64)
y = model(x)
assert y.shape[1] == 4
x = torch.randn(1, 5, 64, 64)
y = model(x)
assert y.shape[0] == 4
def test_untrainable(self) -> None:
model = RCF(in_channels=5, features=4, kernel_size=3)
assert len(list(model.parameters())) == 0
def test_biases(self) -> None:
model = RCF(features=24, bias=10)
assert torch.all(model.biases == 10) # type: ignore[attr-defined]
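These tests also outline how the new RCF module is meant to be used: as a frozen, randomly initialized feature extractor with no trainable parameters. A minimal sketch (channel and feature counts are illustrative, not taken from this diff):
model = RCF(in_channels=4, features=16, kernel_size=3)
features = model(torch.randn(8, 4, 64, 64))  # -> tensor of shape (8, 16)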


@ -7,7 +7,6 @@ from typing import Dict, Iterator, List
import pytest
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from rtree.index import Index, Property
from torch.utils.data import DataLoader
from torchgeo.datasets import BoundingBox, GeoDataset
@ -29,12 +28,10 @@ class CustomBatchGeoSampler(BatchGeoSampler):
class CustomGeoDataset(GeoDataset):
def __init__(
self,
bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
crs: CRS = CRS.from_epsg(3005),
res: float = 1,
) -> None:
super().__init__()
self.index.insert(0, bounds)
self.crs = crs
self.res = res
@ -54,6 +51,7 @@ class TestBatchGeoSampler:
def test_len(self, sampler: CustomBatchGeoSampler) -> None:
assert len(sampler) == 2
@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader(self, sampler: CustomBatchGeoSampler, num_workers: int) -> None:
ds = CustomGeoDataset()
@ -71,11 +69,11 @@ class TestBatchGeoSampler:
class TestRandomBatchGeoSampler:
@pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)])
def sampler(self, request: SubRequest) -> RandomBatchGeoSampler:
index = Index(interleaved=False, properties=Property(dimension=3))
index.insert(0, (0, 10, 20, 30, 40, 50))
index.insert(1, (0, 10, 20, 30, 40, 50))
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 20, 30, 40, 50))
ds.index.insert(1, (0, 10, 20, 30, 40, 50))
size = request.param
return RandomBatchGeoSampler(index, size, batch_size=2, length=10)
return RandomBatchGeoSampler(ds, size, batch_size=2, length=10)
def test_iter(self, sampler: RandomBatchGeoSampler) -> None:
for batch in sampler:
@ -93,6 +91,7 @@ class TestRandomBatchGeoSampler:
def test_len(self, sampler: RandomBatchGeoSampler) -> None:
assert len(sampler) == sampler.length // sampler.batch_size
@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader(self, sampler: RandomBatchGeoSampler, num_workers: int) -> None:
ds = CustomGeoDataset()


@ -7,7 +7,6 @@ from typing import Dict, Iterator
import pytest
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from rtree.index import Index, Property
from torch.utils.data import DataLoader
from torchgeo.datasets import BoundingBox, GeoDataset
@ -29,12 +28,10 @@ class CustomGeoSampler(GeoSampler):
class CustomGeoDataset(GeoDataset):
def __init__(
self,
bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
crs: CRS = CRS.from_epsg(3005),
res: float = 1,
) -> None:
super().__init__()
self.index.insert(0, bounds)
self.crs = crs
self.res = res
@ -53,6 +50,7 @@ class TestGeoSampler:
def test_len(self, sampler: CustomGeoSampler) -> None:
assert len(sampler) == 2
@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader(self, sampler: CustomGeoSampler, num_workers: int) -> None:
ds = CustomGeoDataset()
@ -70,11 +68,11 @@ class TestGeoSampler:
class TestRandomGeoSampler:
@pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)])
def sampler(self, request: SubRequest) -> RandomGeoSampler:
index = Index(interleaved=False, properties=Property(dimension=3))
index.insert(0, (0, 10, 20, 30, 40, 50))
index.insert(1, (0, 10, 20, 30, 40, 50))
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 20, 30, 40, 50))
ds.index.insert(1, (0, 10, 20, 30, 40, 50))
size = request.param
return RandomGeoSampler(index, size, length=10)
return RandomGeoSampler(ds, size, length=10)
def test_iter(self, sampler: RandomGeoSampler) -> None:
for query in sampler:
@ -91,6 +89,7 @@ class TestRandomGeoSampler:
def test_len(self, sampler: RandomGeoSampler) -> None:
assert len(sampler) == sampler.length
@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader(self, sampler: RandomGeoSampler, num_workers: int) -> None:
ds = CustomGeoDataset()
@ -114,11 +113,11 @@ class TestGridGeoSampler:
],
)
def sampler(self, request: SubRequest) -> GridGeoSampler:
index = Index(interleaved=False, properties=Property(dimension=3))
index.insert(0, (0, 20, 0, 10, 40, 50))
index.insert(1, (0, 20, 0, 10, 40, 50))
ds = CustomGeoDataset()
ds.index.insert(0, (0, 20, 0, 10, 40, 50))
ds.index.insert(1, (0, 20, 0, 10, 40, 50))
size, stride = request.param
return GridGeoSampler(index, size, stride)
return GridGeoSampler(ds, size, stride)
def test_iter(self, sampler: GridGeoSampler) -> None:
for query in sampler:
@ -132,6 +131,7 @@ class TestGridGeoSampler:
query.maxt - query.mint, sampler.roi.maxt - sampler.roi.mint
)
@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader(self, sampler: GridGeoSampler, num_workers: int) -> None:
ds = CustomGeoDataset()


@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from typing import Any, Dict, cast
import pytest
import torch.nn as nn
import torchvision
from omegaconf import OmegaConf
from torchgeo.trainers import RESISC45ClassificationTask
class TestRESISC45Trainer:
@pytest.fixture
def default_config(self) -> Dict[str, Any]:
task_conf = OmegaConf.load("conf/task_defaults/resisc45.yaml")
task_args = OmegaConf.to_object(task_conf.experiment.module)
task_args = cast(Dict[str, Any], task_args)
return task_args
def test_resnet_ce(self, default_config: Dict[str, Any]) -> None:
default_config["classification_model"] = "resnet18"
default_config["loss"] = "ce"
task = RESISC45ClassificationTask(**default_config)
assert isinstance(task.model, torchvision.models.ResNet)
assert isinstance(task.loss, nn.CrossEntropyLoss) # type: ignore[attr-defined]
def test_invalid_model(self, default_config: Dict[str, Any]) -> None:
default_config["classification_model"] = "invalid_model"
error_message = "Model type 'invalid_model' is not valid."
with pytest.raises(ValueError, match=error_message):
RESISC45ClassificationTask(**default_config)
def test_invalid_loss(self, default_config: Dict[str, Any]) -> None:
default_config["loss"] = "invalid_loss"
error_message = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=error_message):
RESISC45ClassificationTask(**default_config)


@ -0,0 +1,106 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
from collections import OrderedDict
from pathlib import Path
from typing import Dict
import pytest
import torch
import torch.nn as nn
import torchvision
from _pytest.fixtures import SubRequest
from torch import Tensor
from torch.nn.modules import Module
from torchgeo.trainers.utils import extract_encoder, load_state_dict
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
Module.__module__ = "nn.Module"
@pytest.fixture
def model() -> Module:
model: Module = torchvision.models.resnet18(pretrained=False)
return model
@pytest.fixture
def state_dict(model: Module) -> Dict[str, Tensor]:
return model.state_dict()
@pytest.fixture(params=["classification_model", "encoder"])
def checkpoint(
state_dict: Dict[str, Tensor], request: SubRequest, tmp_path: Path
) -> str:
if request.param == "classification_model":
state_dict = OrderedDict({"model." + k: v for k, v in state_dict.items()})
else:
state_dict = OrderedDict(
{"model.encoder.model." + k: v for k, v in state_dict.items()}
)
checkpoint = {
"hyper_parameters": {request.param: "resnet18"},
"state_dict": state_dict,
}
path = os.path.join(str(tmp_path), f"model_{request.param}.ckpt")
torch.save(checkpoint, path)
return path
def test_extract_encoder_unsupported_model(tmp_path: Path) -> None:
checkpoint = {"hyper_parameters": {"some_unsupported_model": "resnet18"}}
path = os.path.join(str(tmp_path), "dummy.ckpt")
torch.save(checkpoint, path)
err = """Unknown checkpoint task. Only encoder or classification_model"""
"""extraction is supported"""
with pytest.raises(ValueError, match=err):
extract_encoder(path)
def test_extract_encoder(checkpoint: str) -> None:
extract_encoder(checkpoint)
def test_load_state_dict(checkpoint: str, model: Module) -> None:
_, state_dict = extract_encoder(checkpoint)
model = load_state_dict(model, state_dict)
def test_load_state_dict_unequal_input_channels(checkpoint: str, model: Module) -> None:
_, state_dict = extract_encoder(checkpoint)
expected_in_channels = state_dict["conv1.weight"].shape[1]
in_channels = 7
model.conv1 = nn.Conv2d( # type: ignore[attr-defined]
in_channels,
out_channels=64,
kernel_size=7,
stride=1,
padding=2,
bias=False,
)
warning = f"""input channels {in_channels} != input channels in pretrained"""
f"""model {expected_in_channels}. Overriding with new input channels"""
with pytest.warns(UserWarning, match=warning):
model = load_state_dict(model, state_dict)
def test_load_state_dict_unequal_classes(checkpoint: str, model: Module) -> None:
_, state_dict = extract_encoder(checkpoint)
expected_num_classes = state_dict["fc.weight"].shape[0]
num_classes = 10
in_features = model.fc.in_features # type: ignore[union-attr]
model.fc = nn.Linear( # type: ignore[attr-defined]
in_features, out_features=num_classes
)
warning = f"""num classes {num_classes} != num classes in pretrained model"""
f"""{expected_num_classes}. Overriding with new num classes"""
with pytest.warns(UserWarning, match=warning):
model = load_state_dict(model, state_dict)


@ -105,46 +105,6 @@ def assert_matching(output: Dict[str, Tensor], expected: Dict[str, Tensor]) -> N
assert equal, err
def test_random_horizontal_flip(sample: Dict[str, Tensor]) -> None:
tr = transforms.RandomHorizontalFlip(p=1)
output = tr(sample)
expected = {
"image": torch.tensor( # type: ignore[attr-defined]
[[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]
),
"mask": torch.tensor( # type: ignore[attr-defined]
[[1, 0, 0], [1, 1, 0], [1, 1, 1]]
),
"boxes": torch.tensor( # type: ignore[attr-defined]
[[1, 0, 3, 2], [0, 1, 2, 3]]
),
}
assert_matching(output, expected)
def test_random_vertical_flip(sample: Dict[str, Tensor]) -> None:
tr = transforms.RandomVerticalFlip(p=1)
output = tr(sample)
expected = {
"image": torch.tensor( # type: ignore[attr-defined]
[[[7, 8, 9], [4, 5, 6], [1, 2, 3]]]
),
"mask": torch.tensor( # type: ignore[attr-defined]
[[1, 1, 1], [0, 1, 1], [0, 0, 1]]
),
"boxes": torch.tensor( # type: ignore[attr-defined]
[[0, 1, 2, 3], [1, 0, 3, 2]]
),
}
assert_matching(output, expected)
def test_identity(sample: Dict[str, Tensor]) -> None:
tr = transforms.Identity()
output = tr(sample)
assert_matching(output, sample)
def test_augmentation_sequential_gray(batch_gray: Dict[str, Tensor]) -> None:
expected = {
"image": torch.tensor( # type: ignore[attr-defined]


@ -5,6 +5,7 @@
from .advance import ADVANCE
from .benin_cashews import BeninSmallHolderCashews
from .bigearthnet import BigEarthNet
from .cbf import CanadianBuildingFootprints
from .cdl import CDL
from .chesapeake import (
@ -56,7 +57,7 @@ from .resisc45 import RESISC45
from .sen12ms import SEN12MS
from .sentinel import Sentinel, Sentinel2
from .so2sat import So2Sat
from .spacenet import SpaceNet, SpaceNet1, SpaceNet2
from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4
from .ucmerced import UCMerced
from .utils import BoundingBox, collate_dict
from .zuericrop import ZueriCrop
@ -93,6 +94,7 @@ __all__ = (
# VisionDataset
"ADVANCE",
"BeninSmallHolderCashews",
"BigEarthNet",
"COWC",
"COWCCounting",
"COWCDetection",
@ -109,6 +111,7 @@ __all__ = (
"SpaceNet",
"SpaceNet1",
"SpaceNet2",
"SpaceNet4",
"TropicalCycloneWindEstimation",
"UCMerced",
"VHR10",


@ -0,0 +1,380 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""BigEarthNet dataset."""
import glob
import json
import os
from typing import Callable, Dict, List, Optional
import numpy as np
import rasterio
import torch
from rasterio.enums import Resampling
from torch import Tensor
from .geo import VisionDataset
from .utils import download_url, extract_archive
def sort_bands(x: str) -> str:
"""Sort Sentinel-2 band files in the correct order."""
x = os.path.basename(x).split("_")[-1]
x = os.path.splitext(x)[0]
if x == "B8A":
x = "B08A"
return x
class BigEarthNet(VisionDataset):
"""BigEarthNet dataset.
The `BigEarthNet <http://bigearth.net/>`_
dataset is a multilabel dataset for remote sensing image scene classification.
Dataset features:
* 590,326 patches from 125 Sentinel-1 and Sentinel-2 tiles
* Imagery from tiles in Europe between Jun 2017 - May 2018
* 12 spectral bands with 10-60 m per pixel resolution (base 120x120 px)
* 2 synthetic aperture radar bands (120x120 px)
* 43 scene classes from the 2018 CORINE Land Cover database (CLC 2018)
Dataset format:
* images are composed of multiple single channel geotiffs
* labels are multiclass, stored in a single json file per image
* mapping of Sentinel-1 to Sentinel-2 patches is stored in the Sentinel-1 json files
* Sentinel-1 bands: (VV, VH)
* Sentinel-2 bands: (B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12)
* All bands: (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12)
* Sentinel-2 bands have different spatial resolutions and are upsampled to 10 m
Dataset classes:
0. Agro-forestry areas
1. Airports
2. Annual crops associated with permanent crops
3. Bare rock
4. Beaches, dunes, sands
5. Broad-leaved forest
6. Burnt areas
7. Coastal lagoons
8. Complex cultivation patterns
9. Coniferous forest
10. Construction sites
11. Continuous urban fabric
12. Discontinuous urban fabric
13. Dump sites
14. Estuaries
15. Fruit trees and berry plantations
16. Green urban areas
17. Industrial or commercial units
18. Inland marshes
19. Intertidal flats
20. Land principally occupied by agriculture, with
significant areas of natural vegetation
21. Mineral extraction sites
22. Mixed forest
23. Moors and heathland
24. Natural grassland
25. Non-irrigated arable land
26. Olive groves
27. Pastures
28. Peatbogs
29. Permanently irrigated land
30. Port areas
31. Rice fields
32. Road and rail networks and associated land
33. Salines
34. Salt marshes
35. Sclerophyllous vegetation
36. Sea and ocean
37. Sparsely vegetated areas
38. Sport and leisure facilities
39. Transitional woodland/shrub
40. Vineyards
41. Water bodies
42. Water courses
If you use this dataset in your research, please cite the following paper:
* https://doi.org/10.1109/IGARSS.2019.8900532
"""
classes = [
"Agro-forestry areas",
"Airports",
"Annual crops associated with permanent crops",
"Bare rock",
"Beaches, dunes, sands",
"Broad-leaved forest",
"Burnt areas",
"Coastal lagoons",
"Complex cultivation patterns",
"Coniferous forest",
"Construction sites",
"Continuous urban fabric",
"Discontinuous urban fabric",
"Dump sites",
"Estuaries",
"Fruit trees and berry plantations",
"Green urban areas",
"Industrial or commercial units",
"Inland marshes",
"Intertidal flats",
"Land principally occupied by agriculture, with significant areas of "
"natural vegetation",
"Mineral extraction sites",
"Mixed forest",
"Moors and heathland",
"Natural grassland",
"Non-irrigated arable land",
"Olive groves",
"Pastures",
"Peatbogs",
"Permanently irrigated land",
"Port areas",
"Rice fields",
"Road and rail networks and associated land",
"Salines",
"Salt marshes",
"Sclerophyllous vegetation",
"Sea and ocean",
"Sparsely vegetated areas",
"Sport and leisure facilities",
"Transitional woodland/shrub",
"Vineyards",
"Water bodies",
"Water courses",
]
metadata = {
"s1": {
"url": "http://bigearth.net/downloads/BigEarthNet-S1-v1.0.tar.gz",
"md5": "5a64e9ce38deb036a435a7b59494924c",
"filename": "BigEarthNet-S1-v1.0.tar.gz",
"directory": "BigEarthNet-S1-v1.0",
},
"s2": {
"url": "http://bigearth.net/downloads/BigEarthNet-S2-v1.0.tar.gz",
"md5": "5a64e9ce38deb036a435a7b59494924c",
"filename": "BigEarthNet-S2-v1.0.tar.gz",
"directory": "BigEarthNet-v1.0",
},
}
image_size = (120, 120)
def __init__(
self,
root: str = "data",
bands: str = "all",
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
download: bool = False,
checksum: bool = False,
) -> None:
"""Initialize a new BigEarthNet dataset instance.
Args:
root: root directory where dataset can be found
bands: load Sentinel-1 bands, Sentinel-2 bands, or both. One of {s1, s2, all}
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)
"""
assert bands in ["s1", "s2", "all"]
self.root = root
self.bands = bands
self.transforms = transforms
self.download = download
self.checksum = checksum
self.class2idx = {c: i for i, c in enumerate(self.classes)}
self.num_classes = len(self.classes)
self._verify()
if bands == "s2":
self.files = glob.glob(
os.path.join(self.root, self.metadata["s2"]["directory"], "*")
)
else:
self.files = glob.glob(
os.path.join(self.root, self.metadata["s1"]["directory"], "*")
)
def __getitem__(self, index: int) -> Dict[str, Tensor]:
"""Return an index within the dataset.
Args:
index: index to return
Returns:
data and label at that index
"""
image = self._load_image(index)
label = self._load_target(index)
sample: Dict[str, Tensor] = {
"image": image,
"label": label,
}
if self.transforms is not None:
sample = self.transforms(sample)
return sample
def __len__(self) -> int:
"""Return the number of data points in the dataset.
Returns:
length of the dataset
"""
return len(self.files)
def _load_paths(self, index: int) -> List[str]:
"""Load paths to band files.
Args:
index: index to return
Returns:
list of file paths
"""
folder = self.files[index]
paths = glob.glob(os.path.join(folder, "*.tif"))
# S1->S2 patch mapping is in S1 patch metadata json file
if self.bands == "all":
paths = sorted(paths)
metadata_path = glob.glob(os.path.join(folder, "*.json"))[0]
with open(metadata_path, "r") as f:
name_s2 = json.load(f)["corresponding_s2_patch"]
folder_s2 = os.path.join(
self.root, self.metadata["s2"]["directory"], name_s2
)
paths_s2 = glob.glob(os.path.join(folder_s2, "*.tif"))
paths_s2 = sorted(paths_s2, key=sort_bands)
paths.extend(paths_s2)
elif self.bands == "s1":
paths = sorted(paths)
else:
paths = sorted(paths, key=sort_bands)
return paths
def _load_image(self, index: int) -> Tensor:
"""Load a single image.
Args:
index: index to return
Returns:
the raster image
"""
paths = self._load_paths(index)
images = []
for path in paths:
# Bands are of different spatial resolutions
# Resample to (120, 120)
with rasterio.open(path) as dataset:
array = dataset.read(
indexes=1,
out_shape=self.image_size,
out_dtype="int32",
resampling=Resampling.bilinear,
)
images.append(array)
arrays = np.stack(images, axis=0)
tensor: Tensor = torch.from_numpy(arrays) # type: ignore[attr-defined]
return tensor
def _load_target(self, index: int) -> Tensor:
"""Load the target mask for a single image.
Args:
index: index to return
Returns:
the target label
"""
folder = self.files[index]
path = glob.glob(os.path.join(folder, "*.json"))[0]
with open(path, "r") as f:
labels = json.load(f)["labels"]
indices = [self.class2idx[label] for label in labels]
target: Tensor = torch.zeros( # type: ignore[attr-defined]
self.num_classes, dtype=torch.long # type: ignore[attr-defined]
)
target[indices] = 1
return target
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
keys = ["s1", "s2"] if self.bands == "all" else [self.bands]
urls = [self.metadata[k]["url"] for k in keys]
md5s = [self.metadata[k]["md5"] for k in keys]
filenames = [self.metadata[k]["filename"] for k in keys]
directories = [self.metadata[k]["directory"] for k in keys]
# Check if the files already exist
exists = [
os.path.exists(os.path.join(self.root, directory))
for directory in directories
]
if all(exists):
return
# Check if zip file already exists (if so then extract)
exists = []
for filename in filenames:
filepath = os.path.join(self.root, filename)
if os.path.exists(filepath):
exists.append(True)
self._extract(filepath)
else:
exists.append(False)
if all(exists):
return
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
"Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automaticaly download the dataset."
)
# Download and extract the dataset
for url, filename, md5 in zip(urls, filenames, md5s):
self._download(url, filename, md5)
filepath = os.path.join(self.root, filename)
self._extract(filepath)
def _download(self, url: str, filename: str, md5: str) -> None:
"""Download the dataset.
Args:
url: url to download file
filename: output filename to write downloaded file
md5: md5 of downloaded file
"""
download_url(
url,
self.root,
filename=filename,
md5=md5 if self.checksum else None,
)
def _extract(self, filepath: str) -> None:
"""Extract the dataset.
Args:
filepath: path to file to be extracted
"""
extract_archive(filepath)
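A minimal usage sketch of the new dataset, assuming the Sentinel-2 archive has already been extracted under an illustrative root path:

# Hypothetical usage of the BigEarthNet dataset added above; the root path is a placeholder.
from torchgeo.datasets import BigEarthNet

ds = BigEarthNet(root="data/bigearthnet", bands="s2", download=False)
sample = ds[0]
# "image" stacks the 12 Sentinel-2 bands resampled to 120 x 120,
# "label" is a 43-dimensional multi-hot vector over the CLC 2018 classes
print(sample["image"].shape, int(sample["label"].sum()))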


@ -19,7 +19,13 @@ import torch
from rasterio.crs import CRS
from .geo import GeoDataset, RasterDataset
from .utils import BoundingBox, check_integrity, download_and_extract_archive
from .utils import (
BoundingBox,
check_integrity,
download_and_extract_archive,
download_url,
extract_archive,
)
class Chesapeake(RasterDataset, abc.ABC):
@ -376,21 +382,15 @@ class ChesapeakeCVPR(GeoDataset):
for split in splits:
assert split in self.splits
assert all([layer in self.valid_layers for layer in layers])
super().__init__(transforms) # creates self.index and self.transform
self.root = root
self.layers = layers
self.cache = cache
self.download = download
self.checksum = checksum
if download and not self._check_structure():
self._download()
self._verify()
if checksum:
if not self._check_integrity():
raise RuntimeError(
"Dataset not found or corrupted. "
+ "You can use download=True to download it"
)
super().__init__(transforms)
# Add all tiles into the index in epsg:3857 based on the included geojson
mint: float = 0
@ -496,40 +496,45 @@ class ChesapeakeCVPR(GeoDataset):
return sample
def _check_integrity(self) -> bool:
"""Check integrity of the dataset archive.
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Returns:
True if dataset archive is found and/or MD5s match, else False
Raises:
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
integrity: bool = check_integrity(
os.path.join(self.root, self.filename),
self.md5 if self.checksum else None,
)
# Check if the extracted files already exist
def exists(filename: str) -> bool:
return os.path.exists(os.path.join(self.root, filename))
return integrity
def _check_structure(self) -> bool:
"""Checks to see if the dataset files exist in the root directory.
Returns:
True if the dataset files are found, else False
"""
dataset_files = os.listdir(self.root)
for file in self.files:
if file not in dataset_files:
return False
return True
def _download(self) -> None:
"""Download the dataset and extract it."""
if self._check_integrity():
print("Files already downloaded and verified")
if all(map(exists, self.files)):
return
download_and_extract_archive(
# Check if the zip files have already been downloaded
if os.path.exists(os.path.join(self.root, self.filename)):
self._extract()
return
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
f"Dataset not found in `root={self.root}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automaticaly download the dataset."
)
# Download the dataset
self._download()
self._extract()
def _download(self) -> None:
"""Download the dataset."""
download_url(
self.url,
self.root,
filename=self.filename,
md5=self.md5,
)
def _extract(self) -> None:
"""Extract the dataset."""
extract_archive(os.path.join(self.root, self.filename))


@ -5,7 +5,6 @@
import glob
import os
import shutil
from typing import Callable, Dict, List, Optional
import numpy as np
@ -219,7 +218,7 @@ class ETCI2021(VisionDataset):
with Image.open(filename) as img:
array = np.array(img.convert("L"))
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
tensor = torch.clip(tensor, min=0, max=1) # type: ignore[attr-defined]
tensor = torch.clamp(tensor, min=0, max=1) # type: ignore[attr-defined]
tensor = tensor.to(torch.long) # type: ignore[attr-defined]
return tensor
@ -251,6 +250,3 @@ class ETCI2021(VisionDataset):
filename=self.metadata[self.split]["filename"],
md5=self.metadata[self.split]["md5"] if self.checksum else None,
)
if os.path.exists(os.path.join(self.root, "__MACOSX")):
shutil.rmtree(os.path.join(self.root, "__MACOSX"))


@ -22,6 +22,7 @@ import torch
from rasterio.crs import CRS
from rasterio.io import DatasetReader
from rasterio.vrt import WarpedVRT
from rasterio.windows import from_bounds
from rtree.index import Index, Property
from torch import Tensor
from torch.utils.data import Dataset
@ -104,6 +105,15 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
"""
return ZipDataset([self, other])
def __len__(self) -> int:
"""Return the number of files in the dataset.
Returns:
length of the dataset
"""
count: int = self.index.count(self.index.bounds)
return count
def __str__(self) -> str:
"""Return the informal string representation of the object.
@ -113,7 +123,8 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
return f"""\
{self.__class__.__name__} Dataset
type: GeoDataset
bbox: {self.bounds}"""
bbox: {self.bounds}
size: {len(self)}"""
@property
def bounds(self) -> BoundingBox:
@ -286,7 +297,7 @@ class RasterDataset(GeoDataset):
filepath = glob.glob(os.path.join(directory, filename))[0]
band_filepaths.append(filepath)
data_list.append(self._merge_files(band_filepaths, query))
data = torch.stack(data_list)
data = torch.cat(data_list) # type: ignore[attr-defined]
else:
data = self._merge_files(filepaths, query)
@ -318,7 +329,16 @@ class RasterDataset(GeoDataset):
vrt_fhs = [self._load_warp_file(fp) for fp in filepaths]
bounds = (query.minx, query.miny, query.maxx, query.maxy)
dest, _ = rasterio.merge.merge(vrt_fhs, bounds, self.res)
if len(vrt_fhs) == 1:
src = vrt_fhs[0]
out_width = int(round((query.maxx - query.minx) / self.res))
out_height = int(round((query.maxy - query.miny) / self.res))
out_shape = (src.count, out_height, out_width)
dest = src.read(
out_shape=out_shape, window=from_bounds(*bounds, src.transform)
)
else:
dest, _ = rasterio.merge.merge(vrt_fhs, bounds, self.res)
dest = dest.astype(np.int32)
tensor: Tensor = torch.tensor(dest) # type: ignore[attr-defined]
@ -447,7 +467,7 @@ class VectorDataset(GeoDataset):
(minx, maxx), (miny, maxy) = fiona.transform.transform(
src.crs, crs.to_dict(), [minx, maxx], [miny, maxy]
)
except (fiona.errors.DriverError, fiona.errors.FionaValueError):
except fiona.errors.FionaValueError:
# Skip files that fiona is unable to read
continue
else:
@ -719,6 +739,14 @@ class ZipDataset(GeoDataset):
sample.update(ds[query])
return sample
def __len__(self) -> int:
"""Return the number of files in the dataset.
Returns:
length of the dataset
"""
return sum(map(len, self.datasets))
def __str__(self) -> str:
"""Return the informal string representation of the object.
@ -728,7 +756,8 @@ class ZipDataset(GeoDataset):
return f"""\
{self.__class__.__name__} Dataset
type: ZipDataset
bbox: {self.bounds}"""
bbox: {self.bounds}
size: {len(self)}"""
@property
def bounds(self) -> BoundingBox:
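A short sketch of what the new ``__len__`` hooks report, assuming two raster datasets whose root paths are placeholders (ZipDataset is assumed importable from torchgeo.datasets):

# Hypothetical illustration of the new __len__ support on GeoDataset and ZipDataset.
from torchgeo.datasets import CDL, Landsat8, ZipDataset

landsat = Landsat8(root="data/landsat8")
cdl = CDL(root="data/cdl")
zipped = ZipDataset([landsat, cdl])
# GeoDataset.__len__ counts entries in the R-tree index;
# ZipDataset.__len__ sums the lengths of its sub-datasets
print(len(landsat), len(cdl), len(zipped))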


@ -170,7 +170,7 @@ class LEVIRCDPlus(VisionDataset):
with Image.open(filename) as img:
array = np.array(img.convert("L"))
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
tensor = torch.clip(tensor, min=0, max=1) # type: ignore[attr-defined]
tensor = torch.clamp(tensor, min=0, max=1) # type: ignore[attr-defined]
tensor = tensor.to(torch.long) # type: ignore[attr-defined]
return tensor


@ -14,6 +14,7 @@ import numpy as np
import rasterio as rio
import torch
from affine import Affine
from fiona.errors import FionaValueError
from rasterio.features import rasterize
from torch import Tensor
@ -154,8 +155,11 @@ class SpaceNet(VisionDataset, abc.ABC):
Returns:
Tensor: label tensor
"""
with fiona.open(path) as src:
labels = [feature["geometry"] for feature in src]
try:
with fiona.open(path) as src:
labels = [feature["geometry"] for feature in src]
except FionaValueError:
labels = []
if not labels:
mask_data = np.zeros(shape=shape)
@ -490,3 +494,181 @@ class SpaceNet2(SpaceNet):
)
files.append({"image_path": imgpath, "label_path": lbl_path})
return files
class SpaceNet4(SpaceNet):
"""SpaceNet 4: Off-Nadir Buildings Dataset.
`SpaceNet 4 <https://spacenet.ai/off-nadir-building-detection/>`_ is a
dataset of 27 WV-2 image collections captured at varying off-nadir angles and
associated building footprints over the city of Atlanta. The off-nadir angle
ranges from 7 degrees to 54 degrees.
Dataset features:
* No. of chipped images: 28,728 (PAN/MS/PS-RGBNIR)
* No. of label files: 1064
* No. of building footprints: >120,000
* Area Coverage: 665 sq km
* Chip size: 225 x 225 (MS), 900 x 900 (PAN/PS-RGBNIR)
Dataset format:
* Imagery - Worldview-2 GeoTIFFs
* PAN.tif (Panchromatic)
* MS.tif (Multispectral)
* PS-RGBNIR (Pansharpened RGBNIR)
* Labels - GeoJSON
* labels.geojson
If you use this dataset in your research, please cite the following paper:
* https://arxiv.org/abs/1903.12239
.. note::
This dataset requires the following additional library to be installed:
* `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
imagery and labels from the Radiant Earth MLHub
"""
dataset_id = "spacenet4"
collection_md5_dict = {
"sn4_AOI_6_Atlanta": "c597d639cba5257927a97e3eff07b753",
}
imagery = {
"MS": "MS.tif",
"PAN": "PAN.tif",
"PS-RGBNIR": "PS-RGBNIR.tif",
}
chip_size = {
"MS": (225, 225),
"PAN": (900, 900),
"PS-RGBNIR": (900, 900),
}
label_glob = "labels.geojson"
angle_catalog_map = {
"nadir": [
"1030010003D22F00",
"10300100023BC100",
"1030010003993E00",
"1030010003CAF100",
"1030010002B7D800",
"10300100039AB000",
"1030010002649200",
"1030010003C92000",
"1030010003127500",
"103001000352C200",
"103001000307D800",
],
"off-nadir": [
"1030010003472200",
"1030010003315300",
"10300100036D5200",
"103001000392F600",
"1030010003697400",
"1030010003895500",
"1030010003832800",
],
"very-off-nadir": [
"10300100035D1B00",
"1030010003CCD700",
"1030010003713C00",
"10300100033C5200",
"1030010003492700",
"10300100039E6200",
"1030010003BDDC00",
"1030010003CD4300",
"1030010003193D00",
],
}
def __init__(
self,
root: str,
image: str = "PS-RGBNIR",
angles: List[str] = [],
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
download: bool = False,
api_key: Optional[str] = None,
checksum: bool = False,
) -> None:
"""Initialize a new SpaceNet 4 Dataset instance.
Args:
root: root directory where dataset can be found
image: image selection which must be in ["MS", "PAN", "PS-RGBNIR"]
angles: angle selection which must be in ["nadir", "off-nadir",
"very-off-nadir"]
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory.
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
RuntimeError: if ``download=False`` but dataset is missing
"""
collections = ["sn4_AOI_6_Atlanta"]
assert image in {"MS", "PAN", "PS-RGBNIR"}
self.angles = angles
if self.angles:
for angle in self.angles:
assert angle in self.angle_catalog_map.keys()
super().__init__(
root, image, collections, transforms, download, api_key, checksum
)
def _load_files(self, root: str) -> List[Dict[str, str]]:
"""Return the paths of the files in the dataset.
Args:
root: root dir of dataset
Returns:
list of dicts containing paths for each pair of image and label
"""
files = []
nadir = []
offnadir = []
veryoffnadir = []
images = glob.glob(os.path.join(root, self.collections[0], "*", self.filename))
images = sorted(images)
catalog_id_pattern = re.compile(r"(_[A-Z0-9])\w+$")
for imgpath in images:
imgdir = os.path.basename(os.path.dirname(imgpath))
match = catalog_id_pattern.search(imgdir)
assert match is not None, "Invalid image directory"
catalog_id = match.group()[1:]
lbl_dir = os.path.dirname(imgpath).split("-nadir")[0]
lbl_path = os.path.join(lbl_dir + "-labels", self.label_glob)
assert os.path.exists(lbl_path)
_file = {"image_path": imgpath, "label_path": lbl_path}
if catalog_id in self.angle_catalog_map["very-off-nadir"]:
veryoffnadir.append(_file)
elif catalog_id in self.angle_catalog_map["off-nadir"]:
offnadir.append(_file)
elif catalog_id in self.angle_catalog_map["nadir"]:
nadir.append(_file)
angle_file_map = {
"nadir": nadir,
"off-nadir": offnadir,
"very-off-nadir": veryoffnadir,
}
if not self.angles:
files.extend(nadir + offnadir + veryoffnadir)
else:
for angle in self.angles:
files.extend(angle_file_map[angle])
return files
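A hedged usage sketch of the new SpaceNet4 class; the root directory and the angle filter are illustrative, and the data is assumed to be already on disk:

# Hypothetical usage of SpaceNet4; download=True additionally requires a Radiant MLHub API key.
from torchgeo.datasets import SpaceNet4

ds = SpaceNet4(
    root="data/spacenet4",
    image="PS-RGBNIR",              # one of "MS", "PAN", "PS-RGBNIR"
    angles=["nadir", "off-nadir"],  # an empty list keeps every viewing angle
    download=False,
)
# sample is a dict with an "image" tensor and a rasterized "mask"
# (key names assumed from the SpaceNet base class)
sample = ds[0]
print(len(ds), sample["image"].shape)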


@ -18,6 +18,7 @@ import numpy as np
import rasterio
import torch
from torch import Tensor
from torch.utils.data import Dataset, Subset, random_split
from torchvision.datasets.utils import check_integrity, download_url
__all__ = (
@ -30,6 +31,7 @@ __all__ = (
"working_dir",
"collate_dict",
"rasterio_loader",
"dataset_split",
)
@ -394,3 +396,28 @@ def rasterio_loader(path: str) -> np.ndarray: # type: ignore[type-arg]
# VisionClassificationDataset expects images returned with channels last (HWC)
array = array.transpose(1, 2, 0)
return array
def dataset_split(
dataset: Dataset[Any], val_pct: float, test_pct: Optional[float] = None
) -> List[Subset[Any]]:
"""Split a torch Dataset into train/val/test sets.
If ``test_pct`` is not set then only train and validation splits are returned.
Args:
dataset: dataset to be split into train/val or train/val/test subsets
val_pct: percentage of samples to be in validation set
test_pct: (Optional) percentage of samples to be in test set
Returns:
a list of the subset datasets. Either [train, val] or [train, val, test]
"""
if test_pct is None:
val_length = int(len(dataset) * val_pct) # type: ignore[arg-type]
train_length = len(dataset) - val_length # type: ignore[arg-type]
return random_split(dataset, [train_length, val_length])
else:
val_length = int(len(dataset) * val_pct) # type: ignore[arg-type]
test_length = int(len(dataset) * test_pct) # type: ignore[arg-type]
train_length = len(dataset) - (val_length + test_length) # type: ignore[arg-type] # noqa: E501
return random_split(dataset, [train_length, val_length, test_length])
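A quick sketch of the new helper on a toy dataset of 100 items (the TensorDataset is purely illustrative):

# Hypothetical illustration of dataset_split; 20% val + 20% test leaves 60% for training.
import torch
from torch.utils.data import TensorDataset
from torchgeo.datasets.utils import dataset_split

toy = TensorDataset(torch.arange(100))
train, val, test = dataset_split(toy, val_pct=0.2, test_pct=0.2)
print(len(train), len(val), len(test))  # 60 20 20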


@ -7,6 +7,7 @@ from .changestar import ChangeMixin, ChangeStar, ChangeStarFarSeg
from .farseg import FarSeg
from .fccd import FCEF, FCSiamConc, FCSiamDiff
from .fcn import FCN
from .rcf import RCF
__all__ = (
"ChangeMixin",
@ -17,6 +18,7 @@ __all__ = (
"FCEF",
"FCSiamConc",
"FCSiamDiff",
"RCF",
)
# https://stackoverflow.com/questions/40018681

99 torchgeo/models/rcf.py Normal file

@ -0,0 +1,99 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Implementation of a random convolutional feature projection model."""
from typing import cast
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules import Conv2d, Module
Module.__module__ = "torch.nn"
Conv2d.__module__ = "torch.nn"
class RCF(Module):
"""This model extracts random convolutional features (RCFs) from its input.
RCFs are used in the Multi-task Observation using Satellite Imagery & Kitchen Sinks
(MOSAIKS) method proposed in https://www.nature.com/articles/s41467-021-24638-z.
.. note::
This Module is *not* trainable. It is only used as a feature extractor.
"""
def __init__(
self,
in_channels: int = 4,
features: int = 16,
kernel_size: int = 3,
bias: float = -1.0,
) -> None:
"""Initializes the RCF model.
This is a static model that serves to extract fixed length feature vectors from
input patches.
Args:
in_channels: number of input channels
features: number of features to compute, must be divisible by 2
kernel_size: size of the kernel used to compute the RCFs
bias: bias of the convolutional layer
"""
super().__init__() # type: ignore[no-untyped-call]
assert features % 2 == 0
# We register the weight and bias tensors as "buffers". This does two things:
# makes them behave correctly when we call .to(...) on the module, and makes
# them explicitly _not_ Parameters of the model (which might get updated) if
# a user tries to train with this model.
self.register_buffer(
"weights",
torch.randn(
features // 2,
in_channels,
kernel_size,
kernel_size,
requires_grad=False,
),
)
self.register_buffer(
"biases",
torch.zeros( # type: ignore[attr-defined]
features // 2, requires_grad=False
)
+ bias,
)
def forward(self, x: Tensor) -> Tensor:
"""Forward pass of the RCF model.
Args:
x: a tensor with shape (B, C, H, W)
Returns:
a tensor of size (B, ``features``)
"""
x1a = F.relu(
F.conv2d(x, self.weights, bias=self.biases, stride=1, padding=0),
inplace=True,
)
x1b = F.relu(
-F.conv2d(x, self.weights, bias=self.biases, stride=1, padding=0),
inplace=False,
)
x1a = F.adaptive_avg_pool2d(x1a, (1, 1)).squeeze()
x1b = F.adaptive_avg_pool2d(x1b, (1, 1)).squeeze()
if len(x1a.shape) == 1: # case where we passed a single input
output = torch.cat((x1a, x1b), dim=0) # type: ignore[attr-defined]
return cast(Tensor, output)
else: # case where we passed a batch of > 1 inputs
assert len(x1a.shape) == 2
output = torch.cat((x1a, x1b), dim=1) # type: ignore[attr-defined]
return cast(Tensor, output)
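A short sketch of extracting random convolutional features with the new module; the batch shape below is illustrative:

# Hypothetical usage of the (non-trainable) RCF feature extractor.
import torch
from torchgeo.models import RCF

model = RCF(in_channels=4, features=16, kernel_size=3)
x = torch.randn(8, 4, 64, 64)  # a batch of 8 four-band patches
feats = model(x)
print(feats.shape)  # torch.Size([8, 16]): 8 positive + 8 negative pooled responses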


@ -7,10 +7,9 @@ import abc
import random
from typing import Iterator, List, Optional, Tuple, Union
from rtree.index import Index
from torch.utils.data import Sampler
from torchgeo.datasets import BoundingBox
from torchgeo.datasets import BoundingBox, GeoDataset
from .utils import _to_tuple, get_random_bounding_box
@ -43,13 +42,13 @@ class RandomBatchGeoSampler(BatchGeoSampler):
This is particularly useful during training when you want to maximize the size of
the dataset and return as many random :term:`chips <chip>` as possible.
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``index`` should come
from a tile-based dataset if possible.
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``dataset`` should be
a tile-based dataset if possible.
"""
def __init__(
self,
index: Index,
dataset: GeoDataset,
size: Union[Tuple[float, float], float],
batch_size: int,
length: int,
@ -65,21 +64,22 @@ class RandomBatchGeoSampler(BatchGeoSampler):
height dimension, and the second *float* for the width dimension
Args:
index: index of a :class:`~torchgeo.datasets.GeoDataset`
dataset: dataset to index from
size: dimensions of each :term:`patch` in units of CRS
batch_size: number of samples per batch
length: number of samples per epoch
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``index``)
(defaults to the bounds of ``dataset.index``)
"""
self.index = index
self.index = dataset.index
self.res = dataset.res
self.size = _to_tuple(size)
self.batch_size = batch_size
self.length = length
if roi is None:
roi = BoundingBox(*index.bounds)
roi = BoundingBox(*self.index.bounds)
self.roi = roi
self.hits = list(index.intersection(roi, objects=True))
self.hits = list(self.index.intersection(roi, objects=True))
def __iter__(self) -> Iterator[List[BoundingBox]]:
"""Return the indices of a dataset.
@ -96,7 +96,7 @@ class RandomBatchGeoSampler(BatchGeoSampler):
batch = []
for _ in range(self.batch_size):
bounding_box = get_random_bounding_box(bounds, self.size)
bounding_box = get_random_bounding_box(bounds, self.size, self.res)
batch.append(bounding_box)
yield batch


@ -7,10 +7,9 @@ import abc
import random
from typing import Iterator, Optional, Tuple, Union
from rtree.index import Index
from torch.utils.data import Sampler
from torchgeo.datasets import BoundingBox
from torchgeo.datasets import BoundingBox, GeoDataset
from .utils import _to_tuple, get_random_bounding_box
@ -49,7 +48,7 @@ class RandomGeoSampler(GeoSampler):
def __init__(
self,
index: Index,
dataset: GeoDataset,
size: Union[Tuple[float, float], float],
length: int,
roi: Optional[BoundingBox] = None,
@ -64,19 +63,20 @@ class RandomGeoSampler(GeoSampler):
height dimension, and the second *float* for the width dimension
Args:
index: index of a :class:`~torchgeo.datasets.GeoDataset`
dataset: dataset to index from
size: dimensions of each :term:`patch` in units of CRS
length: number of random samples to draw per epoch
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``index``)
(defaults to the bounds of ``dataset.index``)
"""
self.index = index
self.index = dataset.index
self.res = dataset.res
self.size = _to_tuple(size)
self.length = length
if roi is None:
roi = BoundingBox(*index.bounds)
roi = BoundingBox(*self.index.bounds)
self.roi = roi
self.hits = list(index.intersection(roi, objects=True))
self.hits = list(self.index.intersection(roi, objects=True))
def __iter__(self) -> Iterator[BoundingBox]:
"""Return the index of a dataset.
@ -90,7 +90,7 @@ class RandomGeoSampler(GeoSampler):
bounds = BoundingBox(*hit.bounds)
# Choose a random index within that tile
bounding_box = get_random_bounding_box(bounds, self.size)
bounding_box = get_random_bounding_box(bounds, self.size, self.res)
yield bounding_box
@ -117,13 +117,13 @@ class GridGeoSampler(GeoSampler):
to the `receptive field <https://distill.pub/2019/computing-receptive-fields/>`_ of
the CNN.
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``index`` should come
from a non-tile-based dataset if possible.
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``dataset`` should be
a non-tile-based dataset if possible.
"""
def __init__(
self,
index: Index,
dataset: GeoDataset,
size: Union[Tuple[float, float], float],
stride: Union[Tuple[float, float], float],
roi: Optional[BoundingBox] = None,
@ -138,18 +138,19 @@ class GridGeoSampler(GeoSampler):
height dimension, and the second *float* for the width dimension
Args:
index: index of a :class:`~torchgeo.datasets.GeoDataset`
dataset: dataset to index from
size: dimensions of each :term:`patch` in units of CRS
stride: distance to skip between each patch
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
"""
self.index = index
self.index = dataset.index
self.size = _to_tuple(size)
self.stride = _to_tuple(stride)
if roi is None:
roi = BoundingBox(*index.bounds)
roi = BoundingBox(*self.index.bounds)
self.roi = roi
self.hits = list(index.intersection(roi, objects=True))
self.hits = list(self.index.intersection(roi, objects=True))
self.length: int = 0
for hit in self.hits:
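To summarize the sampler API change above: samplers are now built from the dataset itself, which supplies both the R-tree index and the resolution used for pixel-aligned sampling. A hedged sketch with placeholder roots, sizes given in CRS units:

# Hypothetical usage of the reworked sampler constructors.
from torch.utils.data import DataLoader
from torchgeo.datasets import CDL, Landsat8, ZipDataset, collate_dict
from torchgeo.samplers import GridGeoSampler, RandomGeoSampler

landsat = Landsat8(root="data/landsat8")
cdl = CDL(root="data/cdl")
dataset = ZipDataset([landsat, cdl])

size = 224 * cdl.res  # convert a 224-pixel patch to CRS units
train_sampler = RandomGeoSampler(landsat, size=size, length=1000)
eval_sampler = GridGeoSampler(landsat, size=size, stride=size)

train_loader = DataLoader(dataset, sampler=train_sampler, collate_fn=collate_dict)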


@ -25,7 +25,7 @@ def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]:
def get_random_bounding_box(
bounds: BoundingBox, size: Union[Tuple[float, float], float]
bounds: BoundingBox, size: Union[Tuple[float, float], float], res: float
) -> BoundingBox:
"""Returns a random bounding box within a given bounding box.
@ -45,13 +45,16 @@ def get_random_bounding_box(
"""
t_size: Tuple[float, float] = _to_tuple(size)
minx = random.uniform(bounds.minx, bounds.maxx - t_size[1])
width = (bounds.maxx - bounds.minx - t_size[1]) // res
minx = random.randrange(int(width)) * res + bounds.minx
maxx = minx + t_size[1]
miny = random.uniform(bounds.miny, bounds.maxy - t_size[0])
height = (bounds.maxy - bounds.miny - t_size[0]) // res
miny = random.randrange(int(height)) * res + bounds.miny
maxy = miny + t_size[0]
mint = bounds.mint
maxt = bounds.maxt
return BoundingBox(minx, maxx, miny, maxy, mint, maxt)
query = BoundingBox(minx, maxx, miny, maxy, mint, maxt)
return query
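The rewrite above snaps the random patch origin to the dataset resolution, so every sampled window covers an exact number of pixels. A worked example with illustrative numbers:

# Illustrative arithmetic only: a 30 m tile spanning x in [0, 3000), sampling a 300 m patch.
import random

res = 30.0                                           # dataset resolution in CRS units
minx_bound, maxx_bound, patch = 0.0, 3000.0, 300.0   # tile extent and a 10-pixel-wide patch
width = (maxx_bound - minx_bound - patch) // res     # 90 valid pixel-aligned origins
minx = random.randrange(int(width)) * res + minx_bound
maxx = minx + patch
# minx is always a multiple of res and maxx - minx == 300.0 exactly,
# whereas random.uniform could land on a sub-pixel offset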


@ -8,6 +8,7 @@ from .chesapeake import ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask
from .cyclone import CycloneDataModule, CycloneSimpleRegressionTask
from .landcoverai import LandcoverAIDataModule, LandcoverAISegmentationTask
from .naipchesapeake import NAIPChesapeakeDataModule, NAIPChesapeakeSegmentationTask
from .resisc45 import RESISC45ClassificationTask, RESISC45DataModule
from .sen12ms import SEN12MSDataModule, SEN12MSSegmentationTask
from .so2sat import So2SatClassificationTask, So2SatDataModule
@ -21,6 +22,8 @@ __all__ = (
"LandcoverAISegmentationTask",
"NAIPChesapeakeDataModule",
"NAIPChesapeakeSegmentationTask",
"RESISC45ClassificationTask",
"RESISC45DataModule",
"SEN12MSDataModule",
"SEN12MSSegmentationTask",
"So2SatDataModule",


@ -0,0 +1,341 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""RESISC45 trainer."""
from typing import Any, Dict, Optional, cast
import kornia.augmentation as K
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision.models
from torch import Tensor
from torch.nn.modules import Conv2d, Linear, Module
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchmetrics import Accuracy, IoU, MetricCollection
from torchvision.transforms import Compose, Normalize
from ..datasets import RESISC45
from ..datasets.utils import dataset_split
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
Module.__module__ = "torch.nn"
Conv2d.__module__ = "nn.Conv2d"
Linear.__module__ = "nn.Linear"
IN_CHANNELS = 3
NUM_CLASSES = 45
class RESISC45ClassificationTask(pl.LightningModule):
"""LightningModule for training models on the RESISC45 Dataset."""
def config_task(self) -> None:
"""Configures the task based on kwargs parameters passed to the constructor."""
pretrained = "imagenet" in self.hparams["weights"]
if "resnet" in self.hparams["classification_model"]:
self.model = getattr(
torchvision.models.resnet, self.hparams["classification_model"]
)(pretrained=pretrained)
in_features = self.model.fc.in_features
self.model.fc = Linear(in_features, out_features=NUM_CLASSES)
else:
raise ValueError(
f"Model type '{self.hparams['classification_model']}' is not valid."
)
if "resnet" in self.hparams["classification_model"]:
if self.hparams["weights"] in ["imagenet_only", "random"]:
pass
else:
raise ValueError(
f"Weight type '{self.hparams['weights']}' is not valid."
)
else:
pass # stub for initializing the weights of other models
if self.hparams["loss"] == "ce":
self.loss = nn.CrossEntropyLoss() # type: ignore[attr-defined]
else:
raise ValueError(f"Loss type '{self.hparams['loss']}' is not valid.")
def __init__(
self,
**kwargs: Any,
) -> None:
"""Initialize the LightningModule with a model and loss function.
Keyword Args:
classification_model: Name of the classification model to use
loss: Name of the loss function
weights: Either "random", "imagenet_only", "imagenet_and_random", or
"random_rgb"
"""
super().__init__()
self.save_hyperparameters() # creates `self.hparams` from kwargs
self.config_task()
self.train_metrics = MetricCollection(
{
"OverallAccuracy": Accuracy(num_classes=NUM_CLASSES, average="micro"),
"AverageAccuracy": Accuracy(num_classes=NUM_CLASSES, average="macro"),
"IoU": IoU(num_classes=NUM_CLASSES),
},
prefix="train_",
)
self.val_metrics = self.train_metrics.clone(prefix="val_")
self.test_metrics = self.train_metrics.clone(prefix="test_")
def forward(self, x: Tensor) -> Any: # type: ignore[override]
"""Forward pass of the model."""
return self.model(x)
def training_step( # type: ignore[override]
self, batch: Dict[str, Any], batch_idx: int
) -> Tensor:
"""Training step - reports average accuracy and average IoU.
Args:
batch: Current batch
batch_idx: Index of current batch
Returns:
training loss
"""
x = batch["image"]
y = batch["label"]
y_hat = self.forward(x)
y_hat_hard = y_hat.argmax(dim=1)
loss = self.loss(y_hat, y)
# by default, the train step logs every `log_every_n_steps` steps where
# `log_every_n_steps` is a parameter to the `Trainer` object
self.log("train_loss", loss, on_step=True, on_epoch=False)
self.train_metrics(y_hat_hard, y)
return cast(Tensor, loss)
def training_epoch_end(self, outputs: Any) -> None:
"""Logs epoch-level training metrics.
Args:
outputs: list of items returned by training_step
"""
self.log_dict(self.train_metrics.compute())
self.train_metrics.reset()
def validation_step( # type: ignore[override]
self, batch: Dict[str, Any], batch_idx: int
) -> None:
"""Validation step - reports average accuracy and average IoU.
Args:
batch: Current batch
batch_idx: Index of current batch
"""
x = batch["image"]
y = batch["label"]
y_hat = self.forward(x)
y_hat_hard = y_hat.argmax(dim=1)
loss = self.loss(y_hat, y)
self.log("val_loss", loss, on_step=False, on_epoch=True)
self.val_metrics(y_hat_hard, y)
def validation_epoch_end(self, outputs: Any) -> None:
"""Logs epoch level validation metrics.
Args:
outputs: list of items returned by validation_step
"""
self.log_dict(self.val_metrics.compute())
self.val_metrics.reset()
def test_step( # type: ignore[override]
self, batch: Dict[str, Any], batch_idx: int
) -> None:
"""Test step identical to the validation step.
Args:
batch: Current batch
batch_idx: Index of current batch
"""
x = batch["image"]
y = batch["label"]
y_hat = self.forward(x)
y_hat_hard = y_hat.argmax(dim=1)
loss = self.loss(y_hat, y)
# by default, the test and validation steps only log per *epoch*
self.log("test_loss", loss, on_step=False, on_epoch=True)
self.test_metrics(y_hat_hard, y)
def test_epoch_end(self, outputs: Any) -> None:
"""Logs epoch level test metrics.
Args:
outputs: list of items returned by test_step
"""
self.log_dict(self.test_metrics.compute())
self.test_metrics.reset()
def configure_optimizers(self) -> Dict[str, Any]:
"""Initialize the optimizer and learning rate scheduler.
Returns:
a "lr dict" according to the pytorch lightning documentation --
https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
"""
optimizer = torch.optim.Adam(
self.model.parameters(),
lr=self.hparams["learning_rate"],
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": ReduceLROnPlateau(
optimizer,
patience=self.hparams["learning_rate_schedule_patience"],
),
"monitor": "val_loss",
},
}
class RESISC45DataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the RESISC45 dataset.
Uses the train/val/test splits from the dataset.
"""
band_means = torch.tensor( # type: ignore[attr-defined]
[0.36801773, 0.38097873, 0.343583]
)
band_stds = torch.tensor( # type: ignore[attr-defined]
[0.14540215, 0.13558227, 0.13203649]
)
def __init__(
self,
root_dir: str,
batch_size: int = 64,
num_workers: int = 4,
weights: str = "random",
unsupervised_mode: bool = False,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for RESISC45 based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the RESISC45 Dataset classes
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
weights: Either "random", "imagenet_only", "imagenet_and_random", or
"random_rgb"
unsupervised_mode: Makes the train dataloader return imagery from the train,
val, and test sets
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.weights = weights
self.unsupervised_mode = unsupervised_mode
self.val_split_pct = kwargs["val_split_pct"]
self.test_split_pct = kwargs["test_split_pct"]
self.norm = Normalize(self.band_means, self.band_stds)
self.transforms = K.AugmentationSequential(
K.RandomAffine(degrees=30),
K.RandomHorizontalFlip(),
K.RandomVerticalFlip(),
data_keys=["input"],
)
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset."""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
sample["image"] = self.norm(sample["image"])
return sample
def kornia_pipeline(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset with Kornia."""
sample["image"] = self.transforms(sample["image"]).squeeze()
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
RESISC45(self.root_dir, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
"""
transforms = Compose([self.preprocess])
if not self.unsupervised_mode:
dataset = RESISC45(
self.root_dir,
transforms=transforms,
)
self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
)
else:
self.train_dataset = RESISC45(
self.root_dir,
transforms=transforms,
)
self.val_dataset, self.test_dataset = None, None # type: ignore[assignment]
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training."""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation."""
if self.val_dataset is None or len(self.val_dataset) == 0:
return self.train_dataloader()
else:
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing."""
if self.test_dataset is None or len(self.test_dataset) == 0:
return self.train_dataloader()
else:
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
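A hedged end-to-end sketch of the new task and datamodule; every value below is illustrative rather than a recommended configuration:

# Hypothetical wiring of the RESISC45 task into a Lightning trainer.
import pytorch_lightning as pl
from torchgeo.trainers import RESISC45ClassificationTask, RESISC45DataModule

datamodule = RESISC45DataModule(
    root_dir="data/resisc45",
    batch_size=64,
    num_workers=4,
    val_split_pct=0.2,
    test_split_pct=0.2,
)
task = RESISC45ClassificationTask(
    classification_model="resnet18",
    loss="ce",
    weights="random",
    learning_rate=1e-3,
    learning_rate_schedule_patience=6,
)
trainer = pl.Trainer(gpus=0, max_epochs=1)
trainer.fit(model=task, datamodule=datamodule)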


@ -222,12 +222,12 @@ class SEN12MSDataModule(pl.LightningDataModule):
sample["image"] = sample["image"].float()
if self.band_set == "all":
sample["image"][:2] = sample["image"][:2].clip(-25, 0) / -25
sample["image"][2:] = sample["image"][2:].clip(0, 10000) / 10000
sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25
sample["image"][2:] = sample["image"][2:].clamp(0, 10000) / 10000
elif self.band_set == "s1":
sample["image"][:2] = sample["image"][:2].clip(-25, 0) / -25
sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25
else:
sample["image"][:] = sample["image"][:].clip(0, 10000) / 10000
sample["image"][:] = sample["image"][:].clamp(0, 10000) / 10000
sample["mask"] = sample["mask"][0, :, :].long()
sample["mask"] = torch.take( # type: ignore[attr-defined]


@ -3,6 +3,7 @@
"""So2Sat trainer."""
import os
from typing import Any, Dict, Optional, cast
import kornia.augmentation as K
@ -19,6 +20,7 @@ from torchmetrics import Accuracy, IoU, MetricCollection
from torchvision.transforms import Compose
from ..datasets import So2Sat
from . import utils
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
@ -27,7 +29,6 @@ Module.__module__ = "torch.nn"
Conv2d.__module__ = "nn.Conv2d"
Linear.__module__ = "nn.Linear"
IN_CHANNELS = 10
NUM_CLASSES = 17
@ -36,83 +37,65 @@ class So2SatClassificationTask(pl.LightningModule):
def config_task(self) -> None:
"""Configures the task based on kwargs parameters passed to the constructor."""
pretrained = "imagenet" in self.hparams["weights"]
pretrained = ("imagenet" in self.hparams["weights"]) and not os.path.exists(
self.hparams["weights"]
)
in_channels = self.hparams["in_channels"]
if self.hparams["classification_model"] == "resnet18":
self.model = torchvision.models.resnet18(pretrained=pretrained)
self.model.fc = Linear(512, out_features=NUM_CLASSES)
elif self.hparams["classification_model"] == "resnet50":
self.model = torchvision.models.resnet50(pretrained=pretrained)
self.model.fc = Linear(2048, out_features=NUM_CLASSES)
elif self.hparams["classification_model"] == "resnet152":
self.model = torchvision.models.resnet152(pretrained=pretrained)
self.model.fc = Linear(2048, out_features=NUM_CLASSES)
# Create the model
if "resnet" in self.hparams["classification_model"]:
self.model = getattr(
torchvision.models.resnet, self.hparams["classification_model"]
)(pretrained=pretrained)
# Update first layer
w_old = None
if pretrained:
w_old = torch.clone( # type: ignore[attr-defined]
self.model.conv1.weight
).detach()
# Create the new layer
self.model.conv1 = Conv2d(
in_channels,
64,
kernel_size=7,
stride=1,
padding=2,
bias=False,
)
nn.init.kaiming_normal_( # type: ignore[no-untyped-call]
self.model.conv1.weight, mode="fan_out", nonlinearity="relu"
)
if pretrained:
w_new = torch.clone( # type: ignore[attr-defined]
self.model.conv1.weight
).detach()
w_new[:, :3, :, :] = w_old
self.model.conv1.weight = nn.Parameter( # type: ignore[attr-defined]
w_new
)
# Update last layer
in_features = self.model.fc.in_features
self.model.fc = Linear(in_features, out_features=NUM_CLASSES)
else:
raise ValueError(
f"Model type '{self.hparams['classification_model']}' is not valid."
)
if "resnet" in self.hparams["classification_model"]:
if os.path.exists(self.hparams["weights"]):
name, state_dict = utils.extract_encoder(self.hparams["weights"])
if self.hparams["weights"] == "imagenet_only":
pass
elif self.hparams["weights"] == "imagenet_and_random":
# save the initial imagenet weights
w_old = torch.clone( # type: ignore[attr-defined]
self.model.conv1.weight
).detach()
if self.hparams["classification_model"] != name:
raise ValueError(
f"Trying to load {name} weights into a "
f"{self.hparams['classification_model']}"
)
# replace the first conv layer (with random weights)
self.model.conv1 = Conv2d(
IN_CHANNELS,
64,
kernel_size=7,
stride=1,
padding=2,
bias=False,
)
nn.init.kaiming_normal_( # type: ignore[no-untyped-call]
self.model.conv1.weight, mode="fan_out", nonlinearity="relu"
)
self.model = utils.load_state_dict(self.model, state_dict)
w_new = torch.clone( # type: ignore[attr-defined]
self.model.conv1.weight
).detach()
# graft the imagenet weights into the first 3 channels
w_new[:, :3, :, :] = w_old
self.model.conv1.weight = nn.Parameter( # type: ignore[attr-defined]
w_new
)
elif self.hparams["weights"] == "random":
self.model.conv1 = Conv2d(
IN_CHANNELS,
64,
kernel_size=7,
stride=1,
padding=2,
bias=False,
)
nn.init.kaiming_normal_( # type: ignore[no-untyped-call]
self.model.conv1.weight, mode="fan_out", nonlinearity="relu"
)
elif self.hparams["weights"] == "random_rgb":
self.model.conv1 = Conv2d(
3,
64,
kernel_size=7,
stride=1,
padding=2,
bias=False,
)
nn.init.kaiming_normal_( # type: ignore[no-untyped-call]
self.model.conv1.weight, mode="fan_out", nonlinearity="relu"
)
else:
raise ValueError(
f"Weight type '{self.hparams['weights']}' is not valid."
)
else:
pass # stub for initializing the weights of other models
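The reworked config_task grafts the pretrained RGB kernel into a wider first convolution and leaves the extra input channels randomly initialized. A standalone sketch of that grafting step on a plain torchvision resnet (channel count and conv hyperparameters are illustrative):

# Hypothetical illustration of grafting 3-channel pretrained weights into a 10-channel conv1.
import torch.nn as nn
import torchvision

model = torchvision.models.resnet18(pretrained=False)  # set pretrained=True to graft real weights
w_old = model.conv1.weight.detach().clone()            # shape (64, 3, 7, 7)

model.conv1 = nn.Conv2d(10, 64, kernel_size=7, stride=2, padding=3, bias=False)
nn.init.kaiming_normal_(model.conv1.weight, mode="fan_out", nonlinearity="relu")

w_new = model.conv1.weight.detach().clone()            # shape (64, 10, 7, 7)
w_new[:, :3, :, :] = w_old                              # keep the RGB filters, random elsewhere
model.conv1.weight = nn.Parameter(w_new)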
@ -351,7 +334,7 @@ class So2SatDataModule(pl.LightningDataModule):
root_dir: str,
batch_size: int = 64,
num_workers: int = 4,
weights: str = "random",
bands: str = "rgb",
unsupervised_mode: bool = False,
**kwargs: Any,
) -> None:
@ -361,8 +344,7 @@ class So2SatDataModule(pl.LightningDataModule):
root_dir: The ``root`` argument to pass to the So2Sat Dataset classes
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
weights: Either "random", "imagenet_only", "imagenet_and_random", or
"random_rgb"
bands: Either "rgb" or "s2"
unsupervised_mode: Makes the train dataloader return imagery from the train,
val, and test sets
"""
@ -370,7 +352,7 @@ class So2SatDataModule(pl.LightningDataModule):
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.weights = weights
self.bands = bands
self.unsupervised_mode = unsupervised_mode
self.transforms = K.AugmentationSequential(
@ -386,7 +368,7 @@ class So2SatDataModule(pl.LightningDataModule):
sample["image"] = sample["image"].float()
sample["image"] = sample["image"][self.reindex_to_rgb_first, :, :]
if self.weights == "imagenet_only" or self.weights == "random_rgb":
if self.bands == "rgb":
sample["image"] = sample["image"][:3, :, :]
return sample


@ -0,0 +1,97 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Common trainer utilities."""
import warnings
from collections import OrderedDict
from typing import Dict, Tuple
import torch
from torch import Tensor
from torch.nn.modules import Module
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
Module.__module__ = "nn.Module"
def extract_encoder(path: str) -> Tuple[str, Dict[str, Tensor]]:
"""Extracts an encoder from a pytorch lightning checkpoint file.
Args:
path: path to checkpoint file (.ckpt)
Returns:
name: str representing model name (generally a torchvision resnet model)
state_dict: dict of layer names and weight tensors
Raises:
ValueError: if 'classification_model' or 'encoder' not in
checkpoint['hyper_parameters']
"""
checkpoint = torch.load( # type: ignore[no-untyped-call]
path, map_location=torch.device("cpu") # type: ignore[attr-defined]
)
if "classification_model" in checkpoint["hyper_parameters"]:
name = checkpoint["hyper_parameters"]["classification_model"]
state_dict = checkpoint["state_dict"]
state_dict = OrderedDict({k: v for k, v in state_dict.items() if "model." in k})
state_dict = OrderedDict(
{k.replace("model.", ""): v for k, v in state_dict.items()}
)
elif "encoder" in checkpoint["hyper_parameters"]:
name = checkpoint["hyper_parameters"]["encoder"]
state_dict = checkpoint["state_dict"]
state_dict = OrderedDict(
{k: v for k, v in state_dict.items() if "model.encoder.model" in k}
)
state_dict = OrderedDict(
{k.replace("model.encoder.model.", ""): v for k, v in state_dict.items()}
)
else:
raise ValueError(
"""Unknown checkpoint task. Only encoder or classification_model"""
"""extraction is supported"""
)
return name, state_dict
def load_state_dict(model: Module, state_dict: Dict[str, Tensor]) -> Module:
"""Load pretrained resnet weights to a model.
Args:
model: model to load the pretrained weights to
state_dict: dict containing tensor parameters
Returns:
the model with pretrained weights
Warns:
If input channels in model != pretrained model input channels
If num output classes in model != pretrained model num classes
"""
in_channels = model.conv1.in_channels # type: ignore[union-attr]
expected_in_channels = state_dict["conv1.weight"].shape[1]
num_classes = model.fc.out_features # type: ignore[union-attr]
expected_num_classes = state_dict["fc.weight"].shape[0]
if in_channels != expected_in_channels:
warnings.warn(
f"""input channels {in_channels} != input channels in pretrained"""
"""model {expected_in_channels}. Overriding with new input channels"""
)
del state_dict["conv1.weight"]
if num_classes != expected_num_classes:
warnings.warn(
f"""num classes {num_classes} != num classes in pretrained model"""
"""{expected_num_classes}. Overriding with new num classes"""
)
del state_dict["fc.weight"], state_dict["fc.bias"]
model.load_state_dict(state_dict, strict=False) # type: ignore[arg-type]
return model
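A brief sketch of how these helpers compose, assuming a Lightning checkpoint written by one of the classification tasks; the checkpoint path is a placeholder:

# Hypothetical usage of extract_encoder and load_state_dict.
import torchvision
from torchgeo.trainers.utils import extract_encoder, load_state_dict

name, state_dict = extract_encoder("output/so2sat_resnet18.ckpt")  # e.g. name == "resnet18"
model = getattr(torchvision.models, name)(pretrained=False)
# mismatched conv1 / fc tensors are dropped with a warning and keep their random init
model = load_state_dict(model, state_dict)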


@ -4,12 +4,7 @@
"""TorchGeo transforms."""
from .indices import AppendNDBI, AppendNDSI, AppendNDVI, AppendNDWI
from .transforms import (
AugmentationSequential,
Identity,
RandomHorizontalFlip,
RandomVerticalFlip,
)
from .transforms import AugmentationSequential
__all__ = (
"AppendNDBI",
@ -17,9 +12,6 @@ __all__ = (
"AppendNDVI",
"AppendNDWI",
"AugmentationSequential",
"Identity",
"RandomHorizontalFlip",
"RandomVerticalFlip",
)
# https://stackoverflow.com/questions/40018681


@ -15,91 +15,6 @@ from torch.nn import Module # type: ignore[attr-defined]
Module.__module__ = "torch.nn"
class RandomHorizontalFlip(Module): # type: ignore[misc,name-defined]
"""Horizontally flip the given sample randomly with a given probability."""
def __init__(self, p: float = 0.5) -> None:
"""Initialize a new transform instance.
Args:
p: probability of the sample being flipped
"""
super().__init__()
self.p = p
def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""Randomly flip the image and target tensors.
Args:
sample: a single data sample
Returns:
a possibly flipped sample
"""
if torch.rand(1) < self.p:
if "image" in sample:
sample["image"] = sample["image"].flip(-1)
if "boxes" in sample:
height, width = sample["image"].shape[-2:]
sample["boxes"][:, [0, 2]] = width - sample["boxes"][:, [2, 0]]
if "mask" in sample:
sample["mask"] = sample["mask"].flip(-1)
return sample
class RandomVerticalFlip(Module): # type: ignore[misc,name-defined]
"""Vertically flip the given sample randomly with a given probability."""
def __init__(self, p: float = 0.5) -> None:
"""Initialize a new transform instance.
Args:
p: probability of the sample being flipped
"""
super().__init__()
self.p = p
def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""Randomly flip the image and target tensors.
Args:
sample: a single data sample
Returns:
a possibly flipped sample
"""
if torch.rand(1) < self.p:
if "image" in sample:
sample["image"] = sample["image"].flip(-2)
if "boxes" in sample:
height, width = sample["image"].shape[-2:]
sample["boxes"][:, [1, 3]] = height - sample["boxes"][:, [3, 1]]
if "mask" in sample:
sample["mask"] = sample["mask"].flip(-2)
return sample
class Identity(Module): # type: ignore[misc,name-defined]
"""Identity function used for testing purposes."""
def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""Do nothing.
Args:
sample: the input
Returns:
the unchanged input
"""
return sample
class AugmentationSequential(Module): # type: ignore[misc]
"""Wrapper around kornia AugmentationSequential to handle input dicts."""


@ -23,6 +23,8 @@ from torchgeo.trainers import (
LandcoverAISegmentationTask,
NAIPChesapeakeDataModule,
NAIPChesapeakeSegmentationTask,
RESISC45ClassificationTask,
RESISC45DataModule,
SEN12MSDataModule,
SEN12MSSegmentationTask,
So2SatClassificationTask,
@ -39,6 +41,7 @@ TASK_TO_MODULES_MAPPING: Dict[
"cyclone": (CycloneSimpleRegressionTask, CycloneDataModule),
"landcoverai": (LandcoverAISegmentationTask, LandcoverAIDataModule),
"naipchesapeake": (NAIPChesapeakeSegmentationTask, NAIPChesapeakeDataModule),
"resisc45": (RESISC45ClassificationTask, RESISC45DataModule),
"sen12ms": (SEN12MSSegmentationTask, SEN12MSDataModule),
"so2sat": (So2SatClassificationTask, So2SatDataModule),
}