Mirror of https://github.com/microsoft/torchgeo.git
Merge branch 'main' of https://github.com/microsoft/torchgeo into feature/ssl_experiments_sampling
This commit is contained in:
Commit
bae8d0e979
@@ -3,6 +3,7 @@
/logs/
/output/
*.csv
*.pdf

# Spack
.spack-env/
benchmark.py
@@ -12,9 +12,9 @@ import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.models import resnet34

from torchgeo.datasets import CDL, BoundingBox, Landsat8
from torchgeo.datasets import CDL, Landsat8
from torchgeo.samplers import GridGeoSampler, RandomBatchGeoSampler, RandomGeoSampler

@@ -83,7 +83,7 @@ def set_up_parser() -> argparse.ArgumentParser:
parser.add_argument(
"-s",
"--stride",
default=2 ** 7,
default=112,
type=int,
help="sampling stride for GridGeoSampler",
)

@@ -145,26 +145,15 @@ def main(args: argparse.Namespace) -> None:
length = args.num_batches * args.batch_size
num_batches = args.num_batches

# Workaround for https://github.com/microsoft/torchgeo/issues/149
roi = BoundingBox(
-2000000, 2200000, 280000, 3170000, dataset.bounds.mint, dataset.bounds.maxt
)
# Convert from pixel coords to CRS coords
size = args.patch_size * cdl.res
stride = args.stride * cdl.res

samplers = [
RandomGeoSampler(
landsat.index,
size=args.patch_size,
length=length,
roi=roi,
),
GridGeoSampler(
landsat.index, size=args.patch_size, stride=args.stride, roi=roi
),
RandomGeoSampler(landsat, size=size, length=length),
GridGeoSampler(landsat, size=size, stride=stride),
RandomBatchGeoSampler(
landsat.index,
size=args.patch_size,
batch_size=args.batch_size,
length=length,
roi=roi,
landsat, size=size, batch_size=args.batch_size, length=length
),
]

@@ -224,7 +213,7 @@ def main(args: argparse.Namespace) -> None:
)

# Benchmark model
model = resnet18()
model = resnet34()
# Change number of input channels to match Landsat
model.conv1 = nn.Conv2d( # type: ignore[attr-defined]
len(bands), 64, kernel_size=7, stride=2, padding=3, bias=False

@@ -259,7 +248,7 @@ def main(args: argparse.Namespace) -> None:
duration = toc - tic

if args.verbose:
print("\nResNet-18:")
print("\nResNet-34:")
print(f" duration: {duration:.3f} sec")
print(f" count: {num_total_patches} patches")
print(f" rate: {num_total_patches / duration:.3f} patches/sec")

@@ -271,7 +260,7 @@ def main(args: argparse.Namespace) -> None:
"duration": duration,
"count": num_total_patches,
"rate": num_total_patches / duration,
"sampler": "resnet18",
"sampler": "ResNet-34",
"batch_size": args.batch_size,
"num_workers": args.num_workers,
}

@@ -297,8 +286,6 @@ def main(args: argparse.Namespace) -> None:


if __name__ == "__main__":
os.environ["GDAL_CACHEMAX"] = "50%"

parser = set_up_parser()
args = parser.parse_args()
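The patch above switches the samplers from pixel units to CRS units: the requested patch size and stride are multiplied by the dataset resolution before being handed to the samplers. A quick worked sketch of that conversion, using an assumed 30 m resolution (the real value comes from cdl.res at runtime):

    # Assumed values for illustration only
    patch_size = 256   # --patch-size, in pixels
    stride = 112       # --stride, in pixels
    res = 30.0         # metres per pixel, read from cdl.res in the actual script

    size_crs = patch_size * res   # 7680.0 -> patch edge length in CRS units (metres)
    stride_crs = stride * res     # 3360.0 -> grid stride in CRS units (metres)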
@@ -0,0 +1,18 @@
trainer:
  gpus: 1  # single GPU training
  min_epochs: 10
  max_epochs: 40
  benchmark: True

experiment:
  task: "resisc45"
  module:
    loss: "ce"
    classification_model: "resnet18"
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
  datamodule:
    batch_size: 128
    num_workers: 6
    val_split_pct: 0.2
    test_split_pct: 0.2

@@ -11,6 +11,8 @@ experiment:
    classification_model: "resnet18"
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
    in_channels: 3
  datamodule:
    batch_size: 128
    num_workers: 6
    bands: "rgb"

@@ -0,0 +1,14 @@
experiment:
  task: "resisc45"
  module:
    loss: "ce"
    classification_model: "resnet18"
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
    weights: "random"
  datamodule:
    batch_size: 128
    num_workers: 6
    weights: ${experiment.module.weights}
    val_split_pct: 0.2
    test_split_pct: 0.2

@@ -6,7 +6,8 @@ experiment:
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
    weights: "random"
    in_channels: 3
  datamodule:
    batch_size: 128
    num_workers: 6
    weights: ${experiment.module.weights}
    bands: "rgb"
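The experiment scripts later in this diff override these keys from the command line (for example experiment.module.learning_rate=1e-2). As a rough sketch of how such dotted overrides are usually merged over a config file with OmegaConf — the merge code here is an illustration, not the project's actual train.py:

    from omegaconf import OmegaConf

    base = OmegaConf.load("conf/resisc45.yaml")  # one of the configs shown above
    overrides = OmegaConf.from_dotlist(["experiment.module.learning_rate=1e-2"])
    conf = OmegaConf.merge(base, overrides)      # later values win on conflict

    print(conf.experiment.module.learning_rate)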
@@ -6,7 +6,7 @@
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINX_BUILD ?= sphinx-build
SPHINX_BUILD_OPTS ?= -W --keep-going
SPHINX_BUILD_OPTS ?= -W -j auto --keep-going
SOURCE_DIR = .
BUILD_DIR = _build
@@ -80,6 +80,11 @@ Smallholder Cashew Plantations in Benin

.. autoclass:: BeninSmallHolderCashews

BigEarthNet
^^^^^^^^^^^

.. autoclass:: BigEarthNet

Cars Overhead With Context (COWC)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

@@ -143,6 +148,7 @@ SpaceNet

.. autoclass:: SpaceNet
.. autoclass:: SpaceNet1
.. autoclass:: SpaceNet2
.. autoclass:: SpaceNet4

Tropical Cyclone Wind Estimation Competition
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -1,4 +1,33 @@
torchgeo.models
=================

.. automodule:: torchgeo.models
.. module:: torchgeo.models

Change Star
^^^^^^^^^^^

.. autoclass:: ChangeStar
.. autoclass:: ChangeStarFarSeg
.. autoclass:: ChangeMixin

Foreground-aware Relation Network (FarSeg)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: FarSeg

Fully-convolutional Network (FCN)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: FCN

Fully Convolutional Siamese Networks for Change Detection
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: FCEF
.. autoclass:: FCSiamConc
.. autoclass:: FCSiamDiff

Random-convolutional feature (RCF) extractor
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: RCF
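The RCF entry is new in this change. A short usage sketch, with the constructor arguments and output shape taken from the accompanying tests (the random input tensor is made up for illustration):

    import torch

    from torchgeo.models import RCF

    # Untrainable random convolutional feature extractor:
    # 5 input channels, 4 output features, 3x3 kernels
    model = RCF(in_channels=5, features=4, kernel_size=3)

    x = torch.randn(2, 5, 64, 64)  # hypothetical batch of 5-band patches
    y = model(x)
    assert y.shape[1] == 4         # one column per requested feature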
@@ -16,7 +16,7 @@ Samplers are used to index a dataset, retrieving a single query at a time. For :

from torchgeo.samplers import RandomGeoSampler

dataset = Landsat(...)
sampler = RandomGeoSampler(dataset.index, size=1000, length=100)
sampler = RandomGeoSampler(dataset, size=1000, length=100)
dataloader = DataLoader(dataset, sampler=sampler)

@@ -43,7 +43,7 @@ When working with large tile-based datasets, randomly sampling patches from each

from torchgeo.samplers import RandomBatchGeoSampler

dataset = Landsat(...)
sampler = RandomBatchGeoSampler(dataset.index, size=1000, batch_size=10, length=100)
sampler = RandomBatchGeoSampler(dataset, size=1000, batch_size=10, length=100)
dataloader = DataLoader(dataset, batch_sampler=sampler)
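Both hunks reflect the same API change: samplers are now constructed from the dataset itself rather than from dataset.index, and sizes are interpreted in the units of the dataset's CRS. A minimal end-to-end sketch of the new call pattern — the data root is a placeholder, and using collate_dict as the collate function is an assumption based on the utilities imported elsewhere in this diff:

    from torch.utils.data import DataLoader

    from torchgeo.datasets import Landsat8
    from torchgeo.datasets.utils import collate_dict
    from torchgeo.samplers import RandomGeoSampler

    # Hypothetical directory of Landsat 8 scenes
    dataset = Landsat8("/path/to/landsat8", bands=["B4", "B3", "B2"])

    # 1000 CRS units per patch edge, 100 random patches per epoch
    sampler = RandomGeoSampler(dataset, size=1000, length=100)
    dataloader = DataLoader(dataset, sampler=sampler, collate_fn=collate_dict)

    for batch in dataloader:
        image = batch["image"]  # (batch, channels, height, width) tensor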
@@ -78,7 +78,7 @@ html_theme_options = {
"logo_only": True,
"pytorch_project": "docs",
"navigation_with_keys": True,
"analytics_id": "UA-117752657-2",
"analytics_id": "UA-209075005-1",
}

html_favicon = os.path.join("..", "logo", "favicon.ico")
@@ -211,7 +211,7 @@
" chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)\n",
" naip = NAIP(naip_root, crs=chesapeake.crs, res=chesapeake.res, cache=cache)\n",
" dataset = chesapeake + naip\n",
" sampler = RandomGeoSampler(naip.index, size=1000, length=888)\n",
" sampler = RandomGeoSampler(naip, size=1000, length=888)\n",
" dataloader = DataLoader(dataset, batch_size=12, sampler=sampler)\n",
" duration, count = time_epoch(dataloader)\n",
" print(duration, count)"

@@ -262,7 +262,7 @@
" chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)\n",
" naip = NAIP(naip_root, crs=chesapeake.crs, res=chesapeake.res, cache=cache)\n",
" dataset = chesapeake + naip\n",
" sampler = GridGeoSampler(naip.index, size=1000, stride=500)\n",
" sampler = GridGeoSampler(naip, size=1000, stride=500)\n",
" dataloader = DataLoader(dataset, batch_size=12, sampler=sampler)\n",
" duration, count = time_epoch(dataloader)\n",
" print(duration, count)"

@@ -313,7 +313,7 @@
" chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)\n",
" naip = NAIP(naip_root, crs=chesapeake.crs, res=chesapeake.res, cache=cache)\n",
" dataset = chesapeake + naip\n",
" sampler = RandomBatchGeoSampler(naip.index, size=1000, batch_size=12, length=888)\n",
" sampler = RandomBatchGeoSampler(naip, size=1000, batch_size=12, length=888)\n",
" dataloader = DataLoader(dataset, batch_sampler=sampler)\n",
" duration, count = time_epoch(dataloader)\n",
" print(duration, count)"
@@ -329,7 +329,7 @@
},
"outputs": [],
"source": [
"sampler = RandomGeoSampler(naip.index, size=1000, length=10)"
"sampler = RandomGeoSampler(naip, size=1000, length=10)"
]
},
{
@@ -224,7 +224,7 @@
" min = np.percentile(x, 100 - percentile, axis=-1)[:, None, None]\n",
" max = np.percentile(x, percentile, axis=-1)[:, None, None]\n",
" x = x.reshape(c, h, w)\n",
" x = np.clamp(x, min, max)\n",
" x = np.clip(x, min, max)\n",
" return (x - min) / (max - min)"
],
"execution_count": null,

@@ -687,4 +687,4 @@
]
}
]
}
}
@@ -0,0 +1,77 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

df1 = pd.read_csv("original-benchmark-results.csv")
df2 = pd.read_csv("warped-benchmark-results.csv")

mean1 = df1.groupby("sampler").mean()
mean2 = df2.groupby("sampler").mean()

cached1 = (
    df1[(df1["cached"]) & (df1["sampler"] != "resnet18")].groupby("sampler").mean()
)
cached2 = (
    df2[(df2["cached"]) & (df2["sampler"] != "resnet18")].groupby("sampler").mean()
)
not_cached1 = (
    df1[(~df1["cached"]) & (df1["sampler"] != "resnet18")].groupby("sampler").mean()
)
not_cached2 = (
    df2[(~df2["cached"]) & (df2["sampler"] != "resnet18")].groupby("sampler").mean()
)

print("cached, original\n", cached1)
print("cached, warped\n", cached2)
print("not cached, original\n", not_cached1)
print("not cached, warped\n", not_cached2)

cmap = sns.color_palette()

labels = ["GridGeoSampler", "RandomBatchGeoSampler", "RandomGeoSampler"]

fig, ax = plt.subplots()
x = np.arange(3)
width = 0.2

rects1 = ax.bar(
    x - width * 3 / 2,
    not_cached1["rate"],
    width,
    label="Raw Data, Not Cached",
    color=cmap[0],
)
rects2 = ax.bar(
    x - width * 1 / 2,
    not_cached2["rate"],
    width,
    label="Preprocessed, Not Cached",
    color=cmap[1],
)
rects2 = ax.bar(
    x + width * 1 / 2, cached1["rate"], width, label="Raw Data, Cached", color=cmap[2]
)
rects3 = ax.bar(
    x + width * 3 / 2,
    cached2["rate"],
    width,
    label="Preprocessed, Cached",
    color=cmap[3],
)

ax.set_ylabel("sampling rate (patches/sec)", fontsize=12)
ax.set_xticks(x)
ax.set_xticklabels(labels, fontsize=12)
ax.tick_params(axis="x", labelrotation=10)
ax.legend(fontsize="large")

plt.gca().spines.right.set_visible(False)
plt.gca().spines.top.set_visible(False)
plt.tight_layout()
plt.show()
@@ -0,0 +1,43 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

df = pd.read_csv("warped-benchmark-results.csv")

random_cached = df[(df["sampler"] == "RandomGeoSampler") & (df["cached"])]
random_batch_cached = df[(df["sampler"] == "RandomBatchGeoSampler") & (df["cached"])]
grid_cached = df[(df["sampler"] == "GridGeoSampler") & (df["cached"])]
other = [
    ("RandomGeoSampler", random_cached),
    ("RandomBatchGeoSampler", random_batch_cached),
    ("GridGeoSampler", grid_cached),
]

cmap = sns.color_palette()

ax = plt.gca()

for i, (label, df) in enumerate(other):
    df = df.groupby("batch_size")
    ax.plot(df.mean().index, df.mean()["rate"], color=cmap[i], label=label)
    ax.fill_between(
        df.mean().index, df.min()["rate"], df.max()["rate"], color=cmap[i], alpha=0.2
    )


ax.set_xscale("log")
ax.set_xticks([16, 32, 64, 128, 256])
ax.set_xticklabels([16, 32, 64, 128, 256], fontsize=12)
ax.set_xlabel("batch size", fontsize=12)
ax.set_ylabel("sampling rate (patches/sec)", fontsize=12)
ax.legend(loc="center right", fontsize="large")

plt.gca().spines.right.set_visible(False)
plt.gca().spines.top.set_visible(False)
plt.tight_layout()
plt.show()
@@ -0,0 +1,56 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

df1 = pd.read_csv("original-benchmark-results.csv")
df2 = pd.read_csv("warped-benchmark-results.csv")

random_cached1 = df1[(df1["sampler"] == "RandomGeoSampler") & (df1["cached"])]
random_cached2 = df2[(df2["sampler"] == "RandomGeoSampler") & (df2["cached"])]
random_cachedp = random_cached1
random_cachedp["rate"] /= random_cached2["rate"]

random_batch_cached1 = df1[
    (df1["sampler"] == "RandomBatchGeoSampler") & (df1["cached"])
]
random_batch_cached2 = df2[
    (df2["sampler"] == "RandomBatchGeoSampler") & (df2["cached"])
]
random_batch_cachedp = random_batch_cached1
random_batch_cachedp["rate"] /= random_batch_cached2["rate"]

grid_cached1 = df1[(df1["sampler"] == "GridGeoSampler") & (df1["cached"])]
grid_cached2 = df2[(df2["sampler"] == "GridGeoSampler") & (df2["cached"])]
grid_cachedp = grid_cached1
grid_cachedp["rate"] /= grid_cached2["rate"]

other = [
    ("RandomGeoSampler (cached)", random_cachedp),
    ("RandomBatchGeoSampler (cached)", random_batch_cachedp),
    ("GridGeoSampler (cached)", grid_cachedp),
]

cmap = sns.color_palette()

ax = plt.gca()

for i, (label, df) in enumerate(other):
    df = df.groupby("batch_size")
    ax.plot([16, 32, 64, 128, 256], df.mean()["rate"], color=cmap[i], label=label)
    ax.fill_between(
        df.mean().index, df.min()["rate"], df.max()["rate"], color=cmap[i], alpha=0.2
    )


ax.set_xscale("log")
ax.set_xticks([16, 32, 64, 128, 256])
ax.set_xticklabels([16, 32, 64, 128, 256])
ax.set_xlabel("batch size")
ax.set_ylabel("% sampling rate (patches/sec)")
ax.legend()
plt.show()
@@ -3,12 +3,14 @@
# Licensed under the MIT License.

"""Script for running the benchmark script over a sweep of different options."""

import itertools
import os
import subprocess
import time
from typing import List

NUM_BATCHES = 100
EPOCH_SIZE = 4096

SEED_OPTIONS = [0, 1, 2]
CACHE_OPTIONS = [True, False]

@@ -23,11 +25,14 @@ CDL_DATA_ROOT = ""
total_num_experiments = len(SEED_OPTIONS) * len(CACHE_OPTIONS) * len(BATCH_SIZE_OPTIONS)

if __name__ == "__main__":
    # With 6 workers, this will use ~60% of available RAM
    os.environ["GDAL_CACHEMAX"] = "10%"

    tic = time.time()
    for i, (cache, batch_size, seed) in enumerate(
        itertools.product(CACHE_OPTIONS, BATCH_SIZE_OPTIONS, SEED_OPTIONS)
    ):
        print(f"{i}/{total_num_experiments} -- {time.time() - tic}")
        print(f"\n{i}/{total_num_experiments} -- {time.time() - tic}")
        tic = time.time()
        command: List[str] = [
            "python",

@@ -40,8 +45,8 @@ if __name__ == "__main__":
            "6",
            "--batch-size",
            str(batch_size),
            "--num-batches",
            str(NUM_BATCHES),
            "--epoch-size",
            str(EPOCH_SIZE),
            "--seed",
            str(seed),
            "--verbose",
@@ -15,11 +15,10 @@ DATA_DIR = "" # path to the LandcoverAI data directory

# Hyperparameter options
model_options = ["unet"]
encoder_options = ["resnet50"]
lr_options = [1e-4]
loss_options = ["ce"]
weight_init_options = ["imagenet"]
seeds = list(range(15))
encoder_options = ["resnet18", "resnet50"]
lr_options = [1e-2, 1e-3, 1e-4]
loss_options = ["ce", "jaccard"]
weight_init_options = ["null", "imagenet"]


def do_work(work: "Queue[str]", gpu_idx: int) -> bool:

@@ -36,18 +35,13 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
if __name__ == "__main__":
    work: "Queue[str]" = Queue()

    for (model, encoder, lr, loss, weight_init, seed) in itertools.product(
        model_options,
        encoder_options,
        lr_options,
        loss_options,
        weight_init_options,
        seeds,
    for (model, encoder, lr, loss, weight_init) in itertools.product(
        model_options, encoder_options, lr_options, loss_options, weight_init_options
    ):

        experiment_name = f"{model}_{encoder}_{lr}_{loss}_{weight_init}_{seed}"
        experiment_name = f"{model}_{encoder}_{lr}_{loss}_{weight_init}"

        output_dir = os.path.join("output", "landcoverai_seed_experiments")
        output_dir = os.path.join("output", "landcoverai_experiments")
        log_dir = os.path.join(output_dir, "logs")
        config_file = os.path.join("conf", "landcoverai.yaml")

@@ -63,7 +57,6 @@ if __name__ == "__main__":
            + f" experiment.module.encoder_name={encoder}"
            + f" experiment.module.encoder_weights={weight_init}"
            + f" program.output_dir={output_dir}"
            + f" program.seed={seed}"
            + f" program.log_dir={log_dir}"
            + f" program.data_dir={DATA_DIR}"
            + " trainer.gpus=[GPU]"
@@ -0,0 +1,82 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Runs the train script with a grid of hyperparameters."""
import itertools
import os
import subprocess
from multiprocessing import Process, Queue

# list of GPU IDs that we want to use, one job will be started for every ID in the list
GPUS = [0]
DRY_RUN = False  # if True, only print the commands that would be run; if False, run them
DATA_DIR = ""  # path to the RESISC45 data directory

# Hyperparameter options
model_options = ["resnet18", "resnet50"]
lr_options = [1e-2, 1e-3, 1e-4]
loss_options = ["ce"]
weight_options = ["imagenet_only", "random"]


def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
    """Process for each ID in GPUS."""
    while not work.empty():
        experiment = work.get()
        experiment = experiment.replace("GPU", str(gpu_idx))
        print(experiment)
        if not DRY_RUN:
            subprocess.call(experiment.split(" "))
    return True


if __name__ == "__main__":
    work: "Queue[str]" = Queue()

    for (model, lr, loss, weights) in itertools.product(
        model_options,
        lr_options,
        loss_options,
        weight_options,
    ):

        experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_','-')}"

        output_dir = os.path.join("output", "resisc45_experiments")
        log_dir = os.path.join(output_dir, "logs")
        config_file = os.path.join("conf", "resisc45.yaml")

        if not os.path.exists(os.path.join(output_dir, experiment_name)):

            command = (
                "python train.py"
                + f" config_file={config_file}"
                + f" experiment.name={experiment_name}"
                + f" experiment.module.classification_model={model}"
                + f" experiment.module.learning_rate={lr}"
                + f" experiment.module.loss={loss}"
                + f" experiment.module.weights={weights}"
                + f" experiment.datamodule.weights={weights}"
                + f" program.output_dir={output_dir}"
                + f" program.log_dir={log_dir}"
                + f" program.data_dir={DATA_DIR}"
                + " trainer.gpus=[GPU]"
            )
            command = command.strip()

            work.put(command)

    processes = []
    for gpu_idx in GPUS:
        p = Process(
            target=do_work,
            args=(
                work,
                gpu_idx,
            ),
        )
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
@@ -0,0 +1,86 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Runs the train script with a grid of hyperparameters."""
import itertools
import os
import subprocess
from multiprocessing import Process, Queue
from typing import List

# list of GPU IDs that we want to use, one job will be started for every ID in the list
GPUS = [0, 1, 2, 3, 3]
DRY_RUN = False  # if True, only print the commands that would be run; if False, run them
DATA_DIR = ""  # path to the So2Sat data directory

# Hyperparameter options
model_options = ["resnet50"]
lr_options = [1e-4]
loss_options = ["ce"]
weight_options: List[str] = []  # set paths to checkpoint files
bands_options = ["s2"]


def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
    """Process for each ID in GPUS."""
    while not work.empty():
        experiment = work.get()
        experiment = experiment.replace("GPU", str(gpu_idx))
        print(experiment)
        if not DRY_RUN:
            subprocess.call(experiment.split(" "))
    return True


if __name__ == "__main__":
    work: "Queue[str]" = Queue()

    for (model, lr, loss, weights, bands) in itertools.product(
        model_options,
        lr_options,
        loss_options,
        weight_options,
        bands_options,
    ):

        experiment_name = f"{model}_{lr}_{loss}_byol_{bands}-{weights.split('/')[-2]}"

        output_dir = os.path.join("output", "so2sat_experiments")
        log_dir = os.path.join(output_dir, "logs")
        config_file = os.path.join("conf", "so2sat.yaml")

        if not os.path.exists(os.path.join(output_dir, experiment_name)):

            command = (
                "python train.py"
                + f" config_file={config_file}"
                + f" experiment.name={experiment_name}"
                + f" experiment.module.classification_model={model}"
                + f" experiment.module.learning_rate={lr}"
                + f" experiment.module.loss={loss}"
                + f" experiment.module.weights={weights}"
                + " experiment.module.in_channels=10"
                + f" experiment.datamodule.bands={bands}"
                + f" program.output_dir={output_dir}"
                + f" program.log_dir={log_dir}"
                + f" program.data_dir={DATA_DIR}"
                + " trainer.gpus=[GPU]"
            )
            command = command.strip()

            work.put(command)

    processes = []
    for gpu_idx in GPUS:
        p = Process(
            target=do_work,
            args=(
                work,
                gpu_idx,
            ),
        )
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
@@ -69,8 +69,6 @@ datasets =
    scipy>=0.9
# Optional trainer requirements
train =
    # kornia 0.5.4+ required for AugmentationSequential
    kornia>=0.5.4
    # omegaconf 2.1+ required for to_object method
    omegaconf>=2.1
    # pytorch-lightning 1.3+ required for gradient_clip_algorithm argument to Trainer
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -9,11 +9,11 @@ from typing import Any, Generator

import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch

import torchgeo.datasets.utils
from torchgeo.datasets import ADVANCE
from torchgeo.transforms import Identity


def download_url(url: str, root: str, *args: str) -> None:

@@ -40,7 +40,7 @@ class TestADVANCE:
monkeypatch.setattr(ADVANCE, "urls", urls) # type: ignore[attr-defined]
monkeypatch.setattr(ADVANCE, "md5s", md5s) # type: ignore[attr-defined]
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return ADVANCE(root, transforms, download=True, checksum=True)

@pytest.fixture
@@ -9,11 +9,11 @@ from typing import Generator

import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset

from torchgeo.datasets import BeninSmallHolderCashews
from torchgeo.transforms import Identity


class Dataset:

@@ -50,7 +50,7 @@ class TestBeninSmallHolderCashews:
BeninSmallHolderCashews, "dates", ("2019_11_05",)
)
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return BeninSmallHolderCashews(
root,
transforms=transforms,
@ -0,0 +1,113 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import BigEarthNet
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
class TestBigEarthNet:
|
||||
@pytest.fixture(params=["all", "s1", "s2"])
|
||||
def dataset(
|
||||
self,
|
||||
monkeypatch: Generator[MonkeyPatch, None, None],
|
||||
tmp_path: Path,
|
||||
request: SubRequest,
|
||||
) -> BigEarthNet:
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
torchgeo.datasets.bigearthnet, "download_url", download_url
|
||||
)
|
||||
data_dir = os.path.join("tests", "data", "bigearthnet")
|
||||
metadata = {
|
||||
"s1": {
|
||||
"url": os.path.join(data_dir, "BigEarthNet-S1-v1.0.tar.gz"),
|
||||
"md5": "5a64e9ce38deb036a435a7b59494924c",
|
||||
"filename": "BigEarthNet-S1-v1.0.tar.gz",
|
||||
"directory": "BigEarthNet-S1-v1.0",
|
||||
},
|
||||
"s2": {
|
||||
"url": os.path.join(data_dir, "BigEarthNet-S2-v1.0.tar.gz"),
|
||||
"md5": "ef5f41129b8308ca178b04d7538dbacf",
|
||||
"filename": "BigEarthNet-S2-v1.0.tar.gz",
|
||||
"directory": "BigEarthNet-v1.0",
|
||||
},
|
||||
}
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
BigEarthNet, "metadata", metadata
|
||||
)
|
||||
bands = request.param
|
||||
root = str(tmp_path)
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return BigEarthNet(root, bands, transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: BigEarthNet) -> None:
|
||||
x = dataset[0]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x["image"], torch.Tensor)
|
||||
assert isinstance(x["label"], torch.Tensor)
|
||||
assert x["label"].shape == (43,)
|
||||
assert x["image"].dtype == torch.int32 # type: ignore[attr-defined]
|
||||
assert x["label"].dtype == torch.int64 # type: ignore[attr-defined]
|
||||
|
||||
if dataset.bands == "all":
|
||||
assert x["image"].shape == (14, 120, 120)
|
||||
elif dataset.bands == "s1":
|
||||
assert x["image"].shape == (2, 120, 120)
|
||||
else:
|
||||
assert x["image"].shape == (12, 120, 120)
|
||||
|
||||
def test_len(self, dataset: BigEarthNet) -> None:
|
||||
assert len(dataset) == 2
|
||||
|
||||
def test_already_downloaded(self, dataset: BigEarthNet, tmp_path: Path) -> None:
|
||||
BigEarthNet(root=str(tmp_path), bands=dataset.bands, download=True)
|
||||
|
||||
def test_already_downloaded_not_extracted(
|
||||
self, dataset: BigEarthNet, tmp_path: Path
|
||||
) -> None:
|
||||
if dataset.bands == "all":
|
||||
shutil.rmtree(
|
||||
os.path.join(dataset.root, dataset.metadata["s1"]["directory"])
|
||||
)
|
||||
shutil.rmtree(
|
||||
os.path.join(dataset.root, dataset.metadata["s2"]["directory"])
|
||||
)
|
||||
download_url(dataset.metadata["s1"]["url"], root=str(tmp_path))
|
||||
download_url(dataset.metadata["s2"]["url"], root=str(tmp_path))
|
||||
elif dataset.bands == "s1":
|
||||
shutil.rmtree(
|
||||
os.path.join(dataset.root, dataset.metadata["s1"]["directory"])
|
||||
)
|
||||
download_url(dataset.metadata["s1"]["url"], root=str(tmp_path))
|
||||
else:
|
||||
shutil.rmtree(
|
||||
os.path.join(dataset.root, dataset.metadata["s2"]["directory"])
|
||||
)
|
||||
download_url(dataset.metadata["s2"]["url"], root=str(tmp_path))
|
||||
|
||||
BigEarthNet(
|
||||
root=str(tmp_path),
|
||||
bands=dataset.bands,
|
||||
download=False,
|
||||
)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
err = "Dataset not found in `root` directory and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automaticaly download the dataset."
|
||||
with pytest.raises(RuntimeError, match=err):
|
||||
BigEarthNet(str(tmp_path))
|
|
@ -9,12 +9,12 @@ from typing import Generator
|
|||
import matplotlib.pyplot as plt
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from rasterio.crs import CRS
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import BoundingBox, CanadianBuildingFootprints, ZipDataset
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -57,7 +57,7 @@ class TestCanadianBuildingFootprints:
|
|||
plt, "show", lambda *args: None
|
||||
)
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return CanadianBuildingFootprints(
|
||||
root, res=0.1, transforms=transforms, download=True, checksum=True
|
||||
)
|
||||
|
|
|
@ -11,12 +11,12 @@ from typing import Generator
|
|||
import matplotlib.pyplot as plt
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from rasterio.crs import CRS
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import CDL, BoundingBox, ZipDataset
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -44,7 +44,7 @@ class TestCDL:
|
|||
plt, "show", lambda *args: None
|
||||
)
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return CDL(root, transforms=transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: CDL) -> None:
|
||||
|
|
|
@ -9,15 +9,16 @@ from typing import Generator
|
|||
import matplotlib.pyplot as plt
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from rasterio.crs import CRS
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import BoundingBox, Chesapeake13, ZipDataset
|
||||
from torchgeo.transforms import Identity
|
||||
from torchgeo.datasets import BoundingBox, Chesapeake13, ChesapeakeCVPR, ZipDataset
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -41,7 +42,7 @@ class TestChesapeake13:
|
|||
plt, "show", lambda *args: None
|
||||
)
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return Chesapeake13(root, transforms=transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: Chesapeake13) -> None:
|
||||
|
@ -76,3 +77,84 @@ class TestChesapeake13:
|
|||
IndexError, match="query: .* not found in index with bounds:"
|
||||
):
|
||||
dataset[query]
|
||||
|
||||
|
||||
class TestChesapeakeCVPR:
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
("naip-new", "naip-old", "nlcd"),
|
||||
("landsat-leaf-on", "landsat-leaf-off", "lc"),
|
||||
("naip-new", "landsat-leaf-on", "lc", "nlcd", "buildings"),
|
||||
]
|
||||
)
|
||||
def dataset(
|
||||
self,
|
||||
request: SubRequest,
|
||||
monkeypatch: Generator[MonkeyPatch, None, None],
|
||||
tmp_path: Path,
|
||||
) -> ChesapeakeCVPR:
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
torchgeo.datasets.chesapeake, "download_url", download_url
|
||||
)
|
||||
md5 = "564b8d944a941b0b65db9f56c92b93a2"
|
||||
monkeypatch.setattr(ChesapeakeCVPR, "md5", md5) # type: ignore[attr-defined]
|
||||
url = os.path.join(
|
||||
"tests", "data", "chesapeake", "cvpr", "cvpr_chesapeake_landcover.zip"
|
||||
)
|
||||
monkeypatch.setattr(ChesapeakeCVPR, "url", url) # type: ignore[attr-defined]
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
ChesapeakeCVPR,
|
||||
"files",
|
||||
["de_1m_2013_extended-debuffered-test_tiles", "spatial_index.geojson"],
|
||||
)
|
||||
root = str(tmp_path)
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return ChesapeakeCVPR(
|
||||
root,
|
||||
splits=["de-test"],
|
||||
layers=request.param,
|
||||
transforms=transforms,
|
||||
download=True,
|
||||
checksum=True,
|
||||
)
|
||||
|
||||
def test_getitem(self, dataset: ChesapeakeCVPR) -> None:
|
||||
x = dataset[dataset.bounds]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x["crs"], CRS)
|
||||
assert isinstance(x["mask"], torch.Tensor)
|
||||
|
||||
def test_add(self, dataset: ChesapeakeCVPR) -> None:
|
||||
ds = dataset + dataset
|
||||
assert isinstance(ds, ZipDataset)
|
||||
|
||||
def test_already_extracted(self, dataset: ChesapeakeCVPR) -> None:
|
||||
ChesapeakeCVPR(root=dataset.root, download=True)
|
||||
|
||||
def test_already_downloaded(self, tmp_path: Path) -> None:
|
||||
url = os.path.join(
|
||||
"tests", "data", "chesapeake", "cvpr", "cvpr_chesapeake_landcover.zip"
|
||||
)
|
||||
root = str(tmp_path)
|
||||
shutil.copy(url, root)
|
||||
ChesapeakeCVPR(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
ChesapeakeCVPR(str(tmp_path), checksum=True)
|
||||
|
||||
def test_out_of_bounds_query(self, dataset: ChesapeakeCVPR) -> None:
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
with pytest.raises(
|
||||
IndexError, match="query: .* not found in index with bounds:"
|
||||
):
|
||||
dataset[query]
|
||||
|
||||
def test_multiple_hits_query(self, dataset: ChesapeakeCVPR) -> None:
|
||||
ds = ChesapeakeCVPR(
|
||||
root=dataset.root, splits=["de-train", "de-test"], layers=dataset.layers
|
||||
)
|
||||
with pytest.raises(
|
||||
IndexError, match="query: .* spans multiple tiles which is not valid"
|
||||
):
|
||||
ds[dataset.bounds]
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
@ -15,7 +16,6 @@ from torch.utils.data import ConcatDataset
|
|||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import COWCCounting, COWCDetection
|
||||
from torchgeo.datasets.cowc import COWC
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -56,7 +56,7 @@ class TestCOWCCounting:
|
|||
monkeypatch.setattr(COWCCounting, "md5s", md5s) # type: ignore[attr-defined]
|
||||
root = str(tmp_path)
|
||||
split = request.param
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return COWCCounting(root, split, transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: COWC) -> None:
|
||||
|
@ -114,7 +114,7 @@ class TestCOWCDetection:
|
|||
monkeypatch.setattr(COWCDetection, "md5s", md5s) # type: ignore[attr-defined]
|
||||
root = str(tmp_path)
|
||||
split = "train"
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return COWCDetection(root, split, transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: COWC) -> None:
|
||||
|
|
|
@ -9,11 +9,11 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from torchgeo.datasets import CV4AKenyaCropType
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
class Dataset:
|
||||
|
@ -55,7 +55,7 @@ class TestCV4AKenyaCropType:
|
|||
CV4AKenyaCropType, "dates", ["20190606"]
|
||||
)
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return CV4AKenyaCropType(
|
||||
root,
|
||||
transforms=transforms,
|
||||
|
|
|
@ -9,12 +9,12 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from torchgeo.datasets import TropicalCycloneWindEstimation
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
class Dataset:
|
||||
|
@ -57,7 +57,7 @@ class TestTropicalCycloneWindEstimation:
|
|||
)
|
||||
root = str(tmp_path)
|
||||
split = request.param
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return TropicalCycloneWindEstimation(
|
||||
root, split, transforms, download=True, api_key="", checksum=True
|
||||
)
|
||||
|
|
|
@ -8,12 +8,12 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import ETCI2021
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -55,7 +55,7 @@ class TestETCI2021:
|
|||
monkeypatch.setattr(ETCI2021, "metadata", metadata) # type: ignore[attr-defined] # noqa: E501
|
||||
root = str(tmp_path)
|
||||
split = request.param
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return ETCI2021(root, split, transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: ETCI2021) -> None:
|
||||
|
|
|
@ -7,6 +7,7 @@ from typing import Dict
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from rasterio.crs import CRS
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
@ -21,7 +22,6 @@ from torchgeo.datasets import (
|
|||
VisionDataset,
|
||||
ZipDataset,
|
||||
)
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
class CustomGeoDataset(GeoDataset):
|
||||
|
@ -57,11 +57,15 @@ class TestGeoDataset:
|
|||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
assert dataset[query] == {"index": query}
|
||||
|
||||
def test_len(self, dataset: GeoDataset) -> None:
|
||||
assert len(dataset) == 1
|
||||
|
||||
def test_add_two(self) -> None:
|
||||
ds1 = CustomGeoDataset()
|
||||
ds2 = CustomGeoDataset()
|
||||
dataset = ds1 + ds2
|
||||
assert isinstance(dataset, ZipDataset)
|
||||
assert len(dataset) == 2
|
||||
|
||||
def test_add_three(self) -> None:
|
||||
ds1 = CustomGeoDataset()
|
||||
|
@ -69,6 +73,7 @@ class TestGeoDataset:
|
|||
ds3 = CustomGeoDataset()
|
||||
dataset = ds1 + ds2 + ds3
|
||||
assert isinstance(dataset, ZipDataset)
|
||||
assert len(dataset) == 3
|
||||
|
||||
def test_add_four(self) -> None:
|
||||
ds1 = CustomGeoDataset()
|
||||
|
@ -77,10 +82,13 @@ class TestGeoDataset:
|
|||
ds4 = CustomGeoDataset()
|
||||
dataset = (ds1 + ds2) + (ds3 + ds4)
|
||||
assert isinstance(dataset, ZipDataset)
|
||||
assert len(dataset) == 4
|
||||
|
||||
def test_str(self, dataset: GeoDataset) -> None:
|
||||
assert "type: GeoDataset" in str(dataset)
|
||||
assert "bbox: BoundingBox" in str(dataset)
|
||||
out = str(dataset)
|
||||
assert "type: GeoDataset" in out
|
||||
assert "bbox: BoundingBox" in out
|
||||
assert "size: 1" in out
|
||||
|
||||
def test_abstract(self) -> None:
|
||||
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
|
||||
|
@ -98,7 +106,7 @@ class TestRasterDataset:
|
|||
root = os.path.join("tests", "data", "landsat8")
|
||||
bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"]
|
||||
crs = CRS.from_epsg(3005)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
cache = request.param
|
||||
return Landsat8(root, bands=bands, crs=crs, transforms=transforms, cache=cache)
|
||||
|
||||
|
@ -223,9 +231,14 @@ class TestZipDataset:
|
|||
query = BoundingBox(0, 1, 2, 3, 4, 5)
|
||||
assert dataset[query] == {"index": query}
|
||||
|
||||
def test_len(self, dataset: ZipDataset) -> None:
|
||||
assert len(dataset) == 2
|
||||
|
||||
def test_str(self, dataset: ZipDataset) -> None:
|
||||
assert "type: ZipDataset" in str(dataset)
|
||||
assert "bbox: BoundingBox" in str(dataset)
|
||||
out = str(dataset)
|
||||
assert "type: ZipDataset" in out
|
||||
assert "bbox: BoundingBox" in out
|
||||
assert "size: 2" in out
|
||||
|
||||
def test_vision_dataset(self) -> None:
|
||||
ds1 = CustomVisionDataset()
|
||||
|
|
|
@ -8,12 +8,12 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import GID15
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -37,7 +37,7 @@ class TestGID15:
|
|||
monkeypatch.setattr(GID15, "url", url) # type: ignore[attr-defined]
|
||||
root = str(tmp_path)
|
||||
split = request.param
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return GID15(root, split, transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: GID15) -> None:
|
||||
|
|
|
@ -8,13 +8,13 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import LandCoverAI
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -40,7 +40,7 @@ class TestLandCoverAI:
|
|||
monkeypatch.setattr(LandCoverAI, "sha256", sha256) # type: ignore[attr-defined]
|
||||
root = str(tmp_path)
|
||||
split = request.param
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return LandCoverAI(root, split, transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: LandCoverAI) -> None:
|
||||
|
|
|
@ -8,11 +8,11 @@ from typing import Generator
|
|||
import matplotlib.pyplot as plt
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from torchgeo.datasets import BoundingBox, Landsat8, ZipDataset
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
class TestLandsat8:
|
||||
|
@ -23,7 +23,7 @@ class TestLandsat8:
|
|||
)
|
||||
root = os.path.join("tests", "data", "landsat8")
|
||||
bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"]
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return Landsat8(root, bands=bands, transforms=transforms)
|
||||
|
||||
def test_separate_files(self, dataset: Landsat8) -> None:
|
||||
|
|
|
@ -8,12 +8,12 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import LEVIRCDPlus
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -37,7 +37,7 @@ class TestLEVIRCDPlus:
|
|||
monkeypatch.setattr(LEVIRCDPlus, "url", url) # type: ignore[attr-defined]
|
||||
root = str(tmp_path)
|
||||
split = request.param
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return LEVIRCDPlus(root, split, transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: LEVIRCDPlus) -> None:
|
||||
|
|
|
@ -8,11 +8,11 @@ from typing import Generator
|
|||
import matplotlib.pyplot as plt
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from torchgeo.datasets import NAIP, BoundingBox, ZipDataset
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
class TestNAIP:
|
||||
|
@ -22,7 +22,7 @@ class TestNAIP:
|
|||
plt, "show", lambda *args: None
|
||||
)
|
||||
root = os.path.join("tests", "data", "naip")
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return NAIP(root, transforms=transforms)
|
||||
|
||||
def test_getitem(self, dataset: NAIP) -> None:
|
||||
|
|
|
@ -9,13 +9,13 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import VHR10
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
pytest.importorskip("rarfile")
|
||||
pytest.importorskip("pycocotools")
|
||||
|
@ -51,7 +51,7 @@ class TestVHR10:
|
|||
monkeypatch.setitem(VHR10.target_meta, "md5", md5) # type: ignore[attr-defined]
|
||||
root = str(tmp_path)
|
||||
split = request.param
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return VHR10(root, split, transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: VHR10) -> None:
|
||||
|
|
|
@ -7,12 +7,12 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from torchgeo.datasets import SEN12MS
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
class TestSEN12MS:
|
||||
|
@ -38,7 +38,7 @@ class TestSEN12MS:
|
|||
monkeypatch.setattr(SEN12MS, "md5s", md5s) # type: ignore[attr-defined]
|
||||
root = os.path.join("tests", "data", "sen12ms")
|
||||
split = request.param
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return SEN12MS(root, split, transforms=transforms, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: SEN12MS) -> None:
|
||||
|
|
|
@ -6,10 +6,10 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from torchgeo.datasets import BoundingBox, Sentinel2, ZipDataset
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
class TestSentinel2:
|
||||
|
@ -17,7 +17,7 @@ class TestSentinel2:
|
|||
def dataset(self) -> Sentinel2:
|
||||
root = os.path.join("tests", "data", "sentinel2")
|
||||
bands = ["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B09", "B11"]
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return Sentinel2(root, bands=bands, transforms=transforms)
|
||||
|
||||
def test_separate_files(self, dataset: Sentinel2) -> None:
|
||||
|
|
|
@ -7,11 +7,11 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import So2Sat
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
class TestSo2Sat:
|
||||
|
@ -28,7 +28,7 @@ class TestSo2Sat:
|
|||
monkeypatch.setattr(So2Sat, "md5s", md5s) # type: ignore[attr-defined]
|
||||
root = os.path.join("tests", "data", "so2sat")
|
||||
split = request.param
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return So2Sat(root, split, transforms, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: So2Sat) -> None:
|
||||
|
|
|
@ -9,11 +9,11 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import SpaceNet1, SpaceNet2
|
||||
from torchgeo.transforms import Identity
|
||||
from torchgeo.datasets import SpaceNet1, SpaceNet2, SpaceNet4
|
||||
|
||||
TEST_DATA_DIR = "tests/data/spacenet"
|
||||
|
||||
|
@ -51,7 +51,7 @@ class TestSpaceNet1:
|
|||
SpaceNet1, "collection_md5_dict", test_md5
|
||||
)
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return SpaceNet1(
|
||||
root,
|
||||
image=request.param,
|
||||
|
@ -104,7 +104,7 @@ class TestSpaceNet2:
|
|||
SpaceNet2, "collection_md5_dict", test_md5
|
||||
)
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return SpaceNet2(
|
||||
root,
|
||||
image=request.param,
|
||||
|
@ -141,3 +141,67 @@ class TestSpaceNet2:
|
|||
dataset.collection_md5_dict["sn2_AOI_2_Vegas"] = "randommd5hash123"
|
||||
with pytest.raises(RuntimeError, match="Collection sn2_AOI_2_Vegas corrupted"):
|
||||
SpaceNet2(root=dataset.root, download=True, checksum=True)
|
||||
|
||||
|
||||
class TestSpaceNet4:
|
||||
@pytest.fixture(params=["PAN", "MS", "PS-RGBNIR"])
|
||||
def dataset(
|
||||
self,
|
||||
request: SubRequest,
|
||||
monkeypatch: Generator[MonkeyPatch, None, None],
|
||||
tmp_path: Path,
|
||||
) -> SpaceNet4:
|
||||
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
radiant_mlhub.Collection, "fetch", fetch_collection
|
||||
)
|
||||
test_md5 = {
|
||||
"sn4_AOI_6_Atlanta": "ea37c2d87e2c3a1d8b2a7c2230080d46",
|
||||
}
|
||||
|
||||
test_angles = ["nadir", "off-nadir", "very-off-nadir"]
|
||||
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
SpaceNet4, "collection_md5_dict", test_md5
|
||||
)
|
||||
root = str(tmp_path)
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return SpaceNet4(
|
||||
root,
|
||||
image=request.param,
|
||||
angles=test_angles,
|
||||
transforms=transforms,
|
||||
download=True,
|
||||
api_key="",
|
||||
)
|
||||
|
||||
def test_getitem(self, dataset: SpaceNet4) -> None:
|
||||
# Get image-label pair with empty label to
# ensure coverage
|
||||
x = dataset[2]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x["image"], torch.Tensor)
|
||||
assert isinstance(x["mask"], torch.Tensor)
|
||||
if dataset.image == "PS-RGBNIR":
|
||||
assert x["image"].shape[0] == 4
|
||||
elif dataset.image == "MS":
|
||||
assert x["image"].shape[0] == 8
|
||||
else:
|
||||
assert x["image"].shape[0] == 1
|
||||
|
||||
def test_len(self, dataset: SpaceNet4) -> None:
|
||||
assert len(dataset) == 4
|
||||
|
||||
def test_already_downloaded(self, dataset: SpaceNet4) -> None:
|
||||
SpaceNet4(root=dataset.root, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
SpaceNet4(str(tmp_path))
|
||||
|
||||
def test_collection_checksum(self, dataset: SpaceNet4) -> None:
|
||||
dataset.collection_md5_dict["sn4_AOI_6_Atlanta"] = "randommd5hash123"
|
||||
with pytest.raises(
|
||||
RuntimeError, match="Collection sn4_AOI_6_Atlanta corrupted"
|
||||
):
|
||||
SpaceNet4(root=dataset.root, download=True, checksum=True)
|
||||
|
|
|
@ -16,11 +16,13 @@ import pytest
|
|||
import torch
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from rasterio.crs import CRS
|
||||
from torch.utils.data import TensorDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets.utils import (
|
||||
BoundingBox,
|
||||
collate_dict,
|
||||
dataset_split,
|
||||
disambiguate_timestamp,
|
||||
download_and_extract_archive,
|
||||
download_radiant_mlhub_collection,
|
||||
|
@ -256,6 +258,12 @@ class TestBoundingBox:
|
|||
datetime(2021, 9, 1, 0, 0, 0, 0).timestamp(),
|
||||
datetime(2021, 9, 30, 23, 59, 59, 999999).timestamp(),
|
||||
),
|
||||
(
|
||||
"Dec 21",
|
||||
"%b %y",
|
||||
datetime(2021, 12, 1, 0, 0, 0, 0).timestamp(),
|
||||
datetime(2021, 12, 31, 23, 59, 59, 999999).timestamp(),
|
||||
),
|
||||
(
|
||||
"2021-09-13",
|
||||
"%Y-%m-%d",
|
||||
|
@ -335,3 +343,21 @@ def test_nonexisting_directory(tmp_path: Path) -> None:
|
|||
|
||||
with working_dir(str(subdir), create=True):
|
||||
assert subdir.cwd() == subdir
|
||||
|
||||
|
||||
def test_dataset_split() -> None:
|
||||
num_samples = 24
|
||||
x = torch.ones(num_samples, 5) # type: ignore[attr-defined]
|
||||
y = torch.randint(low=0, high=2, size=(num_samples,)) # type: ignore[attr-defined]
|
||||
ds = TensorDataset(x, y)
|
||||
|
||||
# Test only train/val set split
|
||||
train_ds, val_ds = dataset_split(ds, val_pct=1 / 2)
|
||||
assert len(train_ds) == num_samples // 2
|
||||
assert len(val_ds) == num_samples // 2
|
||||
|
||||
# Test train/val/test set split
|
||||
train_ds, val_ds, test_ds = dataset_split(ds, val_pct=1 / 3, test_pct=1 / 3)
|
||||
assert len(train_ds) == num_samples // 3
|
||||
assert len(val_ds) == num_samples // 3
|
||||
assert len(test_ds) == num_samples // 3
|
||||
@ -9,11 +9,11 @@ from typing import Any, Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import ZueriCrop
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
pytest.importorskip("h5py")
|
||||
|
||||
|
@ -41,7 +41,7 @@ class TestZueriCrop:
|
|||
monkeypatch.setattr(ZueriCrop, "urls", urls) # type: ignore[attr-defined]
|
||||
monkeypatch.setattr(ZueriCrop, "md5s", md5s) # type: ignore[attr-defined]
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return ZueriCrop(root, transforms, download=True, checksum=True)
|
||||
|
||||
@pytest.fixture
|
||||
@ -0,0 +1,37 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from torchgeo.models import RCF
|
||||
|
||||
|
||||
class TestRCF:
|
||||
def test_in_channels(self) -> None:
|
||||
model = RCF(in_channels=5, features=4, kernel_size=3)
|
||||
x = torch.randn(2, 5, 64, 64)
|
||||
model(x)
|
||||
|
||||
model = RCF(in_channels=3, features=4, kernel_size=3)
|
||||
match = "to have 3 channels, but got 5 channels instead"
|
||||
with pytest.raises(RuntimeError, match=match):
|
||||
model(x)
|
||||
|
||||
def test_num_features(self) -> None:
|
||||
model = RCF(in_channels=5, features=4, kernel_size=3)
|
||||
x = torch.randn(2, 5, 64, 64)
|
||||
y = model(x)
|
||||
assert y.shape[1] == 4
|
||||
|
||||
x = torch.randn(1, 5, 64, 64)
|
||||
y = model(x)
|
||||
assert y.shape[0] == 4
|
||||
|
||||
def test_untrainable(self) -> None:
|
||||
model = RCF(in_channels=5, features=4, kernel_size=3)
|
||||
assert len(list(model.parameters())) == 0
|
||||
|
||||
def test_biases(self) -> None:
|
||||
model = RCF(features=24, bias=10)
|
||||
assert torch.all(model.biases == 10) # type: ignore[attr-defined]
@ -7,7 +7,6 @@ from typing import Dict, Iterator, List
|
|||
import pytest
|
||||
from _pytest.fixtures import SubRequest
|
||||
from rasterio.crs import CRS
|
||||
from rtree.index import Index, Property
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from torchgeo.datasets import BoundingBox, GeoDataset
|
||||
|
@ -29,12 +28,10 @@ class CustomBatchGeoSampler(BatchGeoSampler):
|
|||
class CustomGeoDataset(GeoDataset):
|
||||
def __init__(
|
||||
self,
|
||||
bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
|
||||
crs: CRS = CRS.from_epsg(3005),
|
||||
res: float = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.index.insert(0, bounds)
|
||||
self.crs = crs
|
||||
self.res = res
|
||||
|
||||
|
@ -54,6 +51,7 @@ class TestBatchGeoSampler:
|
|||
def test_len(self, sampler: CustomBatchGeoSampler) -> None:
|
||||
assert len(sampler) == 2
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("num_workers", [0, 1, 2])
|
||||
def test_dataloader(self, sampler: CustomBatchGeoSampler, num_workers: int) -> None:
|
||||
ds = CustomGeoDataset()
|
||||
|
@ -71,11 +69,11 @@ class TestBatchGeoSampler:
|
|||
class TestRandomBatchGeoSampler:
|
||||
@pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)])
|
||||
def sampler(self, request: SubRequest) -> RandomBatchGeoSampler:
|
||||
index = Index(interleaved=False, properties=Property(dimension=3))
|
||||
index.insert(0, (0, 10, 20, 30, 40, 50))
|
||||
index.insert(1, (0, 10, 20, 30, 40, 50))
|
||||
ds = CustomGeoDataset()
|
||||
ds.index.insert(0, (0, 10, 20, 30, 40, 50))
|
||||
ds.index.insert(1, (0, 10, 20, 30, 40, 50))
|
||||
size = request.param
|
||||
return RandomBatchGeoSampler(index, size, batch_size=2, length=10)
|
||||
return RandomBatchGeoSampler(ds, size, batch_size=2, length=10)
|
||||
|
||||
def test_iter(self, sampler: RandomBatchGeoSampler) -> None:
|
||||
for batch in sampler:
|
||||
|
@ -93,6 +91,7 @@ class TestRandomBatchGeoSampler:
|
|||
def test_len(self, sampler: RandomBatchGeoSampler) -> None:
|
||||
assert len(sampler) == sampler.length // sampler.batch_size
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("num_workers", [0, 1, 2])
|
||||
def test_dataloader(self, sampler: RandomBatchGeoSampler, num_workers: int) -> None:
|
||||
ds = CustomGeoDataset()
|
||||
|
|
|
@ -7,7 +7,6 @@ from typing import Dict, Iterator
|
|||
import pytest
|
||||
from _pytest.fixtures import SubRequest
|
||||
from rasterio.crs import CRS
|
||||
from rtree.index import Index, Property
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from torchgeo.datasets import BoundingBox, GeoDataset
|
||||
|
@ -29,12 +28,10 @@ class CustomGeoSampler(GeoSampler):
|
|||
class CustomGeoDataset(GeoDataset):
|
||||
def __init__(
|
||||
self,
|
||||
bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
|
||||
crs: CRS = CRS.from_epsg(3005),
|
||||
res: float = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.index.insert(0, bounds)
|
||||
self.crs = crs
|
||||
self.res = res
|
||||
|
||||
|
@ -53,6 +50,7 @@ class TestGeoSampler:
|
|||
def test_len(self, sampler: CustomGeoSampler) -> None:
|
||||
assert len(sampler) == 2
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("num_workers", [0, 1, 2])
|
||||
def test_dataloader(self, sampler: CustomGeoSampler, num_workers: int) -> None:
|
||||
ds = CustomGeoDataset()
|
||||
|
@ -70,11 +68,11 @@ class TestGeoSampler:
|
|||
class TestRandomGeoSampler:
|
||||
@pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)])
|
||||
def sampler(self, request: SubRequest) -> RandomGeoSampler:
|
||||
index = Index(interleaved=False, properties=Property(dimension=3))
|
||||
index.insert(0, (0, 10, 20, 30, 40, 50))
|
||||
index.insert(1, (0, 10, 20, 30, 40, 50))
|
||||
ds = CustomGeoDataset()
|
||||
ds.index.insert(0, (0, 10, 20, 30, 40, 50))
|
||||
ds.index.insert(1, (0, 10, 20, 30, 40, 50))
|
||||
size = request.param
|
||||
return RandomGeoSampler(index, size, length=10)
|
||||
return RandomGeoSampler(ds, size, length=10)
|
||||
|
||||
def test_iter(self, sampler: RandomGeoSampler) -> None:
|
||||
for query in sampler:
|
||||
|
@ -91,6 +89,7 @@ class TestRandomGeoSampler:
|
|||
def test_len(self, sampler: RandomGeoSampler) -> None:
|
||||
assert len(sampler) == sampler.length
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("num_workers", [0, 1, 2])
|
||||
def test_dataloader(self, sampler: RandomGeoSampler, num_workers: int) -> None:
|
||||
ds = CustomGeoDataset()
|
||||
|
@ -114,11 +113,11 @@ class TestGridGeoSampler:
|
|||
],
|
||||
)
|
||||
def sampler(self, request: SubRequest) -> GridGeoSampler:
|
||||
index = Index(interleaved=False, properties=Property(dimension=3))
|
||||
index.insert(0, (0, 20, 0, 10, 40, 50))
|
||||
index.insert(1, (0, 20, 0, 10, 40, 50))
|
||||
ds = CustomGeoDataset()
|
||||
ds.index.insert(0, (0, 20, 0, 10, 40, 50))
|
||||
ds.index.insert(1, (0, 20, 0, 10, 40, 50))
|
||||
size, stride = request.param
|
||||
return GridGeoSampler(index, size, stride)
|
||||
return GridGeoSampler(ds, size, stride)
|
||||
|
||||
def test_iter(self, sampler: GridGeoSampler) -> None:
|
||||
for query in sampler:
|
||||
|
@ -132,6 +131,7 @@ class TestGridGeoSampler:
|
|||
query.maxt - query.mint, sampler.roi.maxt - sampler.roi.mint
|
||||
)
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("num_workers", [0, 1, 2])
|
||||
def test_dataloader(self, sampler: GridGeoSampler, num_workers: int) -> None:
|
||||
ds = CustomGeoDataset()
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Any, Dict, cast
|
||||
|
||||
import pytest
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from torchgeo.trainers import RESISC45ClassificationTask
|
||||
|
||||
|
||||
class TestRESISC45Trainer:
|
||||
@pytest.fixture
|
||||
def default_config(self) -> Dict[str, Any]:
|
||||
task_conf = OmegaConf.load("conf/task_defaults/resisc45.yaml")
|
||||
task_args = OmegaConf.to_object(task_conf.experiment.module)
|
||||
task_args = cast(Dict[str, Any], task_args)
|
||||
return task_args
|
||||
|
||||
def test_resnet_ce(self, default_config: Dict[str, Any]) -> None:
|
||||
default_config["classification_model"] = "resnet18"
|
||||
default_config["loss"] = "ce"
|
||||
task = RESISC45ClassificationTask(**default_config)
|
||||
assert isinstance(task.model, torchvision.models.ResNet)
|
||||
assert isinstance(task.loss, nn.CrossEntropyLoss) # type: ignore[attr-defined]
|
||||
|
||||
def test_invalid_model(self, default_config: Dict[str, Any]) -> None:
|
||||
default_config["classification_model"] = "invalid_model"
|
||||
error_message = "Model type 'invalid_model' is not valid."
|
||||
with pytest.raises(ValueError, match=error_message):
|
||||
RESISC45ClassificationTask(**default_config)
|
||||
|
||||
def test_invalid_loss(self, default_config: Dict[str, Any]) -> None:
|
||||
default_config["loss"] = "invalid_loss"
|
||||
error_message = "Loss type 'invalid_loss' is not valid."
|
||||
with pytest.raises(ValueError, match=error_message):
|
||||
RESISC45ClassificationTask(**default_config)
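# Illustrative sketch (not part of the diff): constructing the task outside the
# tests from the same default config the fixture above loads. Assumes the YAML
# file ships with the repo, as the fixture does.
from typing import Any, Dict, cast

from omegaconf import OmegaConf

from torchgeo.trainers import RESISC45ClassificationTask

conf = OmegaConf.load("conf/task_defaults/resisc45.yaml")
kwargs = cast(Dict[str, Any], OmegaConf.to_object(conf.experiment.module))
task = RESISC45ClassificationTask(**kwargs)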
@ -0,0 +1,106 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from _pytest.fixtures import SubRequest
|
||||
from torch import Tensor
|
||||
from torch.nn.modules import Module
|
||||
|
||||
from torchgeo.trainers.utils import extract_encoder, load_state_dict
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
Module.__module__ = "nn.Module"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model() -> Module:
|
||||
model: Module = torchvision.models.resnet18(pretrained=False)
|
||||
return model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def state_dict(model: Module) -> Dict[str, Tensor]:
|
||||
return model.state_dict()
|
||||
|
||||
|
||||
@pytest.fixture(params=["classification_model", "encoder"])
|
||||
def checkpoint(
|
||||
state_dict: Dict[str, Tensor], request: SubRequest, tmp_path: Path
|
||||
) -> str:
|
||||
if request.param == "classification_model":
|
||||
state_dict = OrderedDict({"model." + k: v for k, v in state_dict.items()})
|
||||
else:
|
||||
state_dict = OrderedDict(
|
||||
{"model.encoder.model." + k: v for k, v in state_dict.items()}
|
||||
)
|
||||
checkpoint = {
|
||||
"hyper_parameters": {request.param: "resnet18"},
|
||||
"state_dict": state_dict,
|
||||
}
|
||||
path = os.path.join(str(tmp_path), f"model_{request.param}.ckpt")
|
||||
torch.save(checkpoint, path)
|
||||
return path
|
||||
|
||||
|
||||
def test_extract_encoder_unsupported_model(tmp_path: Path) -> None:
|
||||
checkpoint = {"hyper_parameters": {"some_unsupported_model": "resnet18"}}
|
||||
path = os.path.join(str(tmp_path), "dummy.ckpt")
|
||||
torch.save(checkpoint, path)
|
||||
err = """Unknown checkpoint task. Only encoder or classification_model"""
|
||||
"""extraction is supported"""
|
||||
with pytest.raises(ValueError, match=err):
|
||||
extract_encoder(path)
|
||||
|
||||
|
||||
def test_extract_encoder(checkpoint: str) -> None:
|
||||
extract_encoder(checkpoint)
|
||||
|
||||
|
||||
def test_load_state_dict(checkpoint: str, model: Module) -> None:
|
||||
_, state_dict = extract_encoder(checkpoint)
|
||||
model = load_state_dict(model, state_dict)
|
||||
|
||||
|
||||
def test_load_state_dict_unequal_input_channels(checkpoint: str, model: Module) -> None:
|
||||
_, state_dict = extract_encoder(checkpoint)
|
||||
expected_in_channels = state_dict["conv1.weight"].shape[1]
|
||||
|
||||
in_channels = 7
|
||||
model.conv1 = nn.Conv2d( # type: ignore[attr-defined]
|
||||
in_channels,
|
||||
out_channels=64,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=2,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
warning = f"""input channels {in_channels} != input channels in pretrained"""
|
||||
f"""model {expected_in_channels}. Overriding with new input channels"""
|
||||
with pytest.warns(UserWarning, match=warning):
|
||||
model = load_state_dict(model, state_dict)
|
||||
|
||||
|
||||
def test_load_state_dict_unequal_classes(checkpoint: str, model: Module) -> None:
|
||||
_, state_dict = extract_encoder(checkpoint)
|
||||
expected_num_classes = state_dict["fc.weight"].shape[0]
|
||||
|
||||
num_classes = 10
|
||||
in_features = model.fc.in_features # type: ignore[union-attr]
|
||||
model.fc = nn.Linear( # type: ignore[attr-defined]
|
||||
in_features, out_features=num_classes
|
||||
)
|
||||
|
||||
warning = f"""num classes {num_classes} != num classes in pretrained model"""
|
||||
f"""{expected_num_classes}. Overriding with new num classes"""
|
||||
with pytest.warns(UserWarning, match=warning):
|
||||
model = load_state_dict(model, state_dict)
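# Illustrative sketch (not part of the diff): the end-to-end use of the two
# helpers exercised above. The checkpoint path is a hypothetical placeholder.
import torchvision

from torchgeo.trainers.utils import extract_encoder, load_state_dict

backbone = torchvision.models.resnet18(pretrained=False)
_, state_dict = extract_encoder("path/to/checkpoint.ckpt")  # hypothetical path
backbone = load_state_dict(backbone, state_dict)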
|
|
@ -105,46 +105,6 @@ def assert_matching(output: Dict[str, Tensor], expected: Dict[str, Tensor]) -> N
|
|||
assert equal, err
|
||||
|
||||
|
||||
def test_random_horizontal_flip(sample: Dict[str, Tensor]) -> None:
|
||||
tr = transforms.RandomHorizontalFlip(p=1)
|
||||
output = tr(sample)
|
||||
expected = {
|
||||
"image": torch.tensor( # type: ignore[attr-defined]
|
||||
[[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]
|
||||
),
|
||||
"mask": torch.tensor( # type: ignore[attr-defined]
|
||||
[[1, 0, 0], [1, 1, 0], [1, 1, 1]]
|
||||
),
|
||||
"boxes": torch.tensor( # type: ignore[attr-defined]
|
||||
[[1, 0, 3, 2], [0, 1, 2, 3]]
|
||||
),
|
||||
}
|
||||
assert_matching(output, expected)
|
||||
|
||||
|
||||
def test_random_vertical_flip(sample: Dict[str, Tensor]) -> None:
|
||||
tr = transforms.RandomVerticalFlip(p=1)
|
||||
output = tr(sample)
|
||||
expected = {
|
||||
"image": torch.tensor( # type: ignore[attr-defined]
|
||||
[[[7, 8, 9], [4, 5, 6], [1, 2, 3]]]
|
||||
),
|
||||
"mask": torch.tensor( # type: ignore[attr-defined]
|
||||
[[1, 1, 1], [0, 1, 1], [0, 0, 1]]
|
||||
),
|
||||
"boxes": torch.tensor( # type: ignore[attr-defined]
|
||||
[[0, 1, 2, 3], [1, 0, 3, 2]]
|
||||
),
|
||||
}
|
||||
assert_matching(output, expected)
|
||||
|
||||
|
||||
def test_identity(sample: Dict[str, Tensor]) -> None:
|
||||
tr = transforms.Identity()
|
||||
output = tr(sample)
|
||||
assert_matching(output, sample)
|
||||
|
||||
|
||||
def test_augmentation_sequential_gray(batch_gray: Dict[str, Tensor]) -> None:
|
||||
expected = {
|
||||
"image": torch.tensor( # type: ignore[attr-defined]
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
from .advance import ADVANCE
|
||||
from .benin_cashews import BeninSmallHolderCashews
|
||||
from .bigearthnet import BigEarthNet
|
||||
from .cbf import CanadianBuildingFootprints
|
||||
from .cdl import CDL
|
||||
from .chesapeake import (
|
||||
|
@ -56,7 +57,7 @@ from .resisc45 import RESISC45
|
|||
from .sen12ms import SEN12MS
|
||||
from .sentinel import Sentinel, Sentinel2
|
||||
from .so2sat import So2Sat
|
||||
from .spacenet import SpaceNet, SpaceNet1, SpaceNet2
|
||||
from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4
|
||||
from .ucmerced import UCMerced
|
||||
from .utils import BoundingBox, collate_dict
|
||||
from .zuericrop import ZueriCrop
|
||||
|
@ -93,6 +94,7 @@ __all__ = (
|
|||
# VisionDataset
|
||||
"ADVANCE",
|
||||
"BeninSmallHolderCashews",
|
||||
"BigEarthNet",
|
||||
"COWC",
|
||||
"COWCCounting",
|
||||
"COWCDetection",
|
||||
|
@ -109,6 +111,7 @@ __all__ = (
|
|||
"SpaceNet",
|
||||
"SpaceNet1",
|
||||
"SpaceNet2",
|
||||
"SpaceNet4",
|
||||
"TropicalCycloneWindEstimation",
|
||||
"UCMerced",
|
||||
"VHR10",
|
||||
@ -0,0 +1,380 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""BigEarthNet dataset."""
|
||||
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import rasterio
|
||||
import torch
|
||||
from rasterio.enums import Resampling
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import VisionDataset
|
||||
from .utils import download_url, extract_archive
|
||||
|
||||
|
||||
def sort_bands(x: str) -> str:
|
||||
"""Sort Sentinel-2 band files in the correct order."""
|
||||
x = os.path.basename(x).split("_")[-1]
|
||||
x = os.path.splitext(x)[0]
|
||||
if x == "B8A":
|
||||
x = "B08A"
|
||||
return x
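# Illustrative behaviour of sort_bands (file names are hypothetical):
#   sorted(["T1_B02.tif", "T1_B8A.tif", "T1_B01.tif"], key=sort_bands)
#   -> ["T1_B01.tif", "T1_B02.tif", "T1_B8A.tif"]
# i.e. "B8A" is keyed as "B08A" so it lands between B08 and B09 instead of
# sorting after B12.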
|
||||
|
||||
|
||||
class BigEarthNet(VisionDataset):
|
||||
"""BigEarthNet dataset.
|
||||
|
||||
The `BigEarthNet <http://bigearth.net/>`_
|
||||
dataset is a dataset for multilabel remote sensing image scene classification.
|
||||
|
||||
Dataset features:
|
||||
|
||||
* 590,326 patches from 125 Sentinel-1 and Sentinel-2 tiles
|
||||
* Imagery from tiles in Europe between Jun 2017 and May 2018
|
||||
* 12 spectral bands with 10-60 m per pixel resolution (base 120x120 px)
|
||||
* 2 synthetic aperture radar bands (120x120 px)
|
||||
* 43 scene classes from the 2018 CORINE Land Cover database (CLC 2018)
|
||||
|
||||
Dataset format:
|
||||
|
||||
* images are composed of multiple single channel geotiffs
|
||||
* labels are multilabel and stored in a single json file per image
|
||||
* mapping of Sentinel-1 to Sentinel-2 patches is within Sentinel-1 json files
|
||||
* Sentinel-1 bands: (VV, VH)
|
||||
* Sentinel-2 bands: (B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12)
|
||||
* All bands: (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12)
|
||||
* Sentinel-2 bands are of different spatial resolutions and upsampled to 10m
|
||||
|
||||
Dataset classes:
|
||||
|
||||
0. Agro-forestry areas
|
||||
1. Airports
|
||||
2. Annual crops associated with permanent crops
|
||||
3. Bare rock
|
||||
4. Beaches, dunes, sands
|
||||
5. Broad-leaved forest
|
||||
6. Burnt areas
|
||||
7. Coastal lagoons
|
||||
8. Complex cultivation patterns
|
||||
9. Coniferous forest
|
||||
10. Construction sites
|
||||
11. Continuous urban fabric
|
||||
12. Discontinuous urban fabric
|
||||
13. Dump sites
|
||||
14. Estuaries
|
||||
15. Fruit trees and berry plantations
|
||||
16. Green urban areas
|
||||
17. Industrial or commercial units
|
||||
18. Inland marshes
|
||||
19. Intertidal flats
|
||||
20. Land principally occupied by agriculture, with
|
||||
significant areas of natural vegetation
|
||||
21. Mineral extraction sites
|
||||
22. Mixed forest
|
||||
23. Moors and heathland
|
||||
24. Natural grassland
|
||||
25. Non-irrigated arable land
|
||||
26. Olive groves
|
||||
27. Pastures
|
||||
28. Peatbogs
|
||||
29. Permanently irrigated land
|
||||
30. Port areas
|
||||
31. Rice fields
|
||||
32. Road and rail networks and associated land
|
||||
33. Salines
|
||||
34. Salt marshes
|
||||
35. Sclerophyllous vegetation
|
||||
36. Sea and ocean
|
||||
37. Sparsely vegetated areas
|
||||
38. Sport and leisure facilities
|
||||
39. Transitional woodland/shrub
|
||||
40. Vineyards
|
||||
41. Water bodies
|
||||
42. Water courses
|
||||
|
||||
If you use this dataset in your research, please cite the following paper:
|
||||
|
||||
* https://doi.org/10.1109/IGARSS.2019.8900532
|
||||
|
||||
"""
|
||||
|
||||
classes = [
|
||||
"Agro-forestry areas",
|
||||
"Airports",
|
||||
"Annual crops associated with permanent crops",
|
||||
"Bare rock",
|
||||
"Beaches, dunes, sands",
|
||||
"Broad-leaved forest",
|
||||
"Burnt areas",
|
||||
"Coastal lagoons",
|
||||
"Complex cultivation patterns",
|
||||
"Coniferous forest",
|
||||
"Construction sites",
|
||||
"Continuous urban fabric",
|
||||
"Discontinuous urban fabric",
|
||||
"Dump sites",
|
||||
"Estuaries",
|
||||
"Fruit trees and berry plantations",
|
||||
"Green urban areas",
|
||||
"Industrial or commercial units",
|
||||
"Inland marshes",
|
||||
"Intertidal flats",
|
||||
"Land principally occupied by agriculture, with significant areas of "
|
||||
"natural vegetation",
|
||||
"Mineral extraction sites",
|
||||
"Mixed forest",
|
||||
"Moors and heathland",
|
||||
"Natural grassland",
|
||||
"Non-irrigated arable land",
|
||||
"Olive groves",
|
||||
"Pastures",
|
||||
"Peatbogs",
|
||||
"Permanently irrigated land",
|
||||
"Port areas",
|
||||
"Rice fields",
|
||||
"Road and rail networks and associated land",
|
||||
"Salines",
|
||||
"Salt marshes",
|
||||
"Sclerophyllous vegetation",
|
||||
"Sea and ocean",
|
||||
"Sparsely vegetated areas",
|
||||
"Sport and leisure facilities",
|
||||
"Transitional woodland/shrub",
|
||||
"Vineyards",
|
||||
"Water bodies",
|
||||
"Water courses",
|
||||
]
|
||||
metadata = {
|
||||
"s1": {
|
||||
"url": "http://bigearth.net/downloads/BigEarthNet-S1-v1.0.tar.gz",
|
||||
"md5": "5a64e9ce38deb036a435a7b59494924c",
|
||||
"filename": "BigEarthNet-S1-v1.0.tar.gz",
|
||||
"directory": "BigEarthNet-S1-v1.0",
|
||||
},
|
||||
"s2": {
|
||||
"url": "http://bigearth.net/downloads/BigEarthNet-S2-v1.0.tar.gz",
|
||||
"md5": "5a64e9ce38deb036a435a7b59494924c",
|
||||
"filename": "BigEarthNet-S2-v1.0.tar.gz",
|
||||
"directory": "BigEarthNet-v1.0",
|
||||
},
|
||||
}
|
||||
image_size = (120, 120)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = "data",
|
||||
bands: str = "all",
|
||||
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new BigEarthNet dataset instance.
|
||||
|
||||
Args:
|
||||
root: root directory where dataset can be found
|
||||
bands: load Sentinel-1 bands, Sentinel-2 bands, or both. One of {s1, s2, all}
|
||||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
download: if True, download dataset and store it in the root directory
|
||||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
"""
|
||||
assert bands in ["s1", "s2", "all"]
|
||||
self.root = root
|
||||
self.bands = bands
|
||||
self.transforms = transforms
|
||||
self.download = download
|
||||
self.checksum = checksum
|
||||
self.class2idx = {c: i for i, c in enumerate(self.classes)}
|
||||
self.num_classes = len(self.classes)
|
||||
self._verify()
|
||||
|
||||
if bands == "s2":
|
||||
self.files = glob.glob(
|
||||
os.path.join(self.root, self.metadata["s2"]["directory"], "*")
|
||||
)
|
||||
else:
|
||||
self.files = glob.glob(
|
||||
os.path.join(self.root, self.metadata["s1"]["directory"], "*")
|
||||
)
|
||||
|
||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||
"""Return an index within the dataset.
|
||||
|
||||
Args:
|
||||
index: index to return
|
||||
|
||||
Returns:
|
||||
data and label at that index
|
||||
"""
|
||||
image = self._load_image(index)
|
||||
label = self._load_target(index)
|
||||
sample: Dict[str, Tensor] = {
|
||||
"image": image,
|
||||
"label": label,
|
||||
}
|
||||
|
||||
if self.transforms is not None:
|
||||
sample = self.transforms(sample)
|
||||
|
||||
return sample
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of data points in the dataset.
|
||||
|
||||
Returns:
|
||||
length of the dataset
|
||||
"""
|
||||
return len(self.files)
|
||||
|
||||
def _load_paths(self, index: int) -> List[str]:
|
||||
"""Load paths to band files.
|
||||
|
||||
Args:
|
||||
index: index to return
|
||||
|
||||
Returns:
|
||||
list of file paths
|
||||
"""
|
||||
folder = self.files[index]
|
||||
paths = glob.glob(os.path.join(folder, "*.tif"))
|
||||
# S1->S2 patch mapping is in S1 patch metadata json file
|
||||
if self.bands == "all":
|
||||
paths = sorted(paths)
|
||||
|
||||
metadata_path = glob.glob(os.path.join(folder, "*.json"))[0]
|
||||
with open(metadata_path, "r") as f:
|
||||
name_s2 = json.load(f)["corresponding_s2_patch"]
|
||||
|
||||
folder_s2 = os.path.join(
|
||||
self.root, self.metadata["s2"]["directory"], name_s2
|
||||
)
|
||||
paths_s2 = glob.glob(os.path.join(folder_s2, "*.tif"))
|
||||
paths_s2 = sorted(paths_s2, key=sort_bands)
|
||||
paths.extend(paths_s2)
|
||||
elif self.bands == "s1":
|
||||
paths = sorted(paths)
|
||||
else:
|
||||
paths = sorted(paths, key=sort_bands)
|
||||
return paths
|
||||
|
||||
def _load_image(self, index: int) -> Tensor:
|
||||
"""Load a single image.
|
||||
|
||||
Args:
|
||||
index: index to return
|
||||
|
||||
Returns:
|
||||
the raster image or target
|
||||
"""
|
||||
paths = self._load_paths(index)
|
||||
images = []
|
||||
for path in paths:
|
||||
# Bands are of different spatial resolutions
|
||||
# Resample to (120, 120)
|
||||
with rasterio.open(path) as dataset:
|
||||
array = dataset.read(
|
||||
indexes=1,
|
||||
out_shape=self.image_size,
|
||||
out_dtype="int32",
|
||||
resampling=Resampling.bilinear,
|
||||
)
|
||||
images.append(array)
|
||||
arrays = np.stack(images, axis=0)
|
||||
tensor: Tensor = torch.from_numpy(arrays) # type: ignore[attr-defined]
|
||||
return tensor
|
||||
|
||||
def _load_target(self, index: int) -> Tensor:
|
||||
"""Load the target mask for a single image.
|
||||
|
||||
Args:
|
||||
index: index to return
|
||||
|
||||
Returns:
|
||||
the target label
|
||||
"""
|
||||
folder = self.files[index]
|
||||
path = glob.glob(os.path.join(folder, "*.json"))[0]
|
||||
with open(path, "r") as f:
|
||||
labels = json.load(f)["labels"]
|
||||
indices = [self.class2idx[label] for label in labels]
|
||||
target: Tensor = torch.zeros( # type: ignore[attr-defined]
|
||||
self.num_classes, dtype=torch.long # type: ignore[attr-defined]
|
||||
)
|
||||
target[indices] = 1
|
||||
return target
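# Illustrative sketch of the multi-hot encoding above (label names hypothetical):
#   labels = ["Pastures", "Water bodies"]
#   indices = [self.class2idx[c] for c in labels]  # [27, 41] per the class list
#   target[indices] = 1                            # 43-element multi-hot vector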
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
"""
|
||||
keys = ["s1", "s2"] if self.bands == "all" else [self.bands]
|
||||
urls = [self.metadata[k]["url"] for k in keys]
|
||||
md5s = [self.metadata[k]["md5"] for k in keys]
|
||||
filenames = [self.metadata[k]["filename"] for k in keys]
|
||||
directories = [self.metadata[k]["directory"] for k in keys]
|
||||
|
||||
# Check if the files already exist
|
||||
exists = [
|
||||
os.path.exists(os.path.join(self.root, directory))
|
||||
for directory in directories
|
||||
]
|
||||
if all(exists):
|
||||
return
|
||||
|
||||
# Check if zip file already exists (if so then extract)
|
||||
exists = []
|
||||
for filename in filenames:
|
||||
filepath = os.path.join(self.root, filename)
|
||||
if os.path.exists(filepath):
|
||||
exists.append(True)
|
||||
self._extract(filepath)
|
||||
else:
|
||||
exists.append(False)
|
||||
|
||||
if all(exists):
|
||||
return
|
||||
|
||||
# Check if the user requested to download the dataset
|
||||
if not self.download:
|
||||
raise RuntimeError(
|
||||
"Dataset not found in `root` directory and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automaticaly download the dataset."
|
||||
)
|
||||
|
||||
# Download and extract the dataset
|
||||
for url, filename, md5 in zip(urls, filenames, md5s):
|
||||
self._download(url, filename, md5)
|
||||
filepath = os.path.join(self.root, filename)
|
||||
self._extract(filepath)
|
||||
|
||||
def _download(self, url: str, filename: str, md5: str) -> None:
|
||||
"""Download the dataset.
|
||||
|
||||
Args:
|
||||
url: url to download file
|
||||
filename: output filename to write downloaded file
|
||||
md5: md5 of downloaded file
|
||||
"""
|
||||
download_url(
|
||||
url,
|
||||
self.root,
|
||||
filename=filename,
|
||||
md5=md5 if self.checksum else None,
|
||||
)
|
||||
|
||||
def _extract(self, filepath: str) -> None:
|
||||
"""Extract the dataset.
|
||||
|
||||
Args:
|
||||
filepath: path to file to be extracted
|
||||
"""
|
||||
extract_archive(filepath)
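# Illustrative usage sketch (not part of the diff). The root path is a
# placeholder; with bands="s2" each sample holds the 12 Sentinel-2 bands
# resampled to 120 x 120 plus a 43-element multi-hot label vector.
from torch.utils.data import DataLoader

from torchgeo.datasets import BigEarthNet

ds = BigEarthNet(root="data/bigearthnet", bands="s2", download=False)
sample = ds[0]  # {"image": Tensor, "label": Tensor}
loader = DataLoader(ds, batch_size=32, shuffle=True, num_workers=4)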
@ -19,7 +19,13 @@ import torch
|
|||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import GeoDataset, RasterDataset
|
||||
from .utils import BoundingBox, check_integrity, download_and_extract_archive
|
||||
from .utils import (
|
||||
BoundingBox,
|
||||
check_integrity,
|
||||
download_and_extract_archive,
|
||||
download_url,
|
||||
extract_archive,
|
||||
)
|
||||
|
||||
|
||||
class Chesapeake(RasterDataset, abc.ABC):
|
||||
|
@ -376,21 +382,15 @@ class ChesapeakeCVPR(GeoDataset):
|
|||
for split in splits:
|
||||
assert split in self.splits
|
||||
assert all([layer in self.valid_layers for layer in layers])
|
||||
super().__init__(transforms) # creates self.index and self.transform
|
||||
self.root = root
|
||||
self.layers = layers
|
||||
self.cache = cache
|
||||
self.download = download
|
||||
self.checksum = checksum
|
||||
|
||||
if download and not self._check_structure():
|
||||
self._download()
|
||||
self._verify()
|
||||
|
||||
if checksum:
|
||||
if not self._check_integrity():
|
||||
raise RuntimeError(
|
||||
"Dataset not found or corrupted. "
|
||||
+ "You can use download=True to download it"
|
||||
)
|
||||
super().__init__(transforms)
|
||||
|
||||
# Add all tiles into the index in epsg:3857 based on the included geojson
|
||||
mint: float = 0
|
||||
|
@ -496,40 +496,45 @@ class ChesapeakeCVPR(GeoDataset):
|
|||
|
||||
return sample
|
||||
|
||||
def _check_integrity(self) -> bool:
|
||||
"""Check integrity of the dataset archive.
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
Returns:
|
||||
True if dataset archive is found and/or MD5s match, else False
|
||||
Raises:
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
"""
|
||||
integrity: bool = check_integrity(
|
||||
os.path.join(self.root, self.filename),
|
||||
self.md5 if self.checksum else None,
|
||||
)
|
||||
# Check if the extracted files already exist
|
||||
def exists(filename: str) -> bool:
|
||||
return os.path.exists(os.path.join(self.root, filename))
|
||||
|
||||
return integrity
|
||||
|
||||
def _check_structure(self) -> bool:
|
||||
"""Checks to see if the dataset files exist in the root directory.
|
||||
|
||||
Returns:
|
||||
True if the dataset files are found, else False
|
||||
"""
|
||||
dataset_files = os.listdir(self.root)
|
||||
for file in self.files:
|
||||
if file not in dataset_files:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _download(self) -> None:
|
||||
"""Download the dataset and extract it."""
|
||||
if self._check_integrity():
|
||||
print("Files already downloaded and verified")
|
||||
if all(map(exists, self.files)):
|
||||
return
|
||||
|
||||
download_and_extract_archive(
|
||||
# Check if the zip files have already been downloaded
|
||||
if os.path.exists(os.path.join(self.root, self.filename)):
|
||||
self._extract()
|
||||
return
|
||||
|
||||
# Check if the user requested to download the dataset
|
||||
if not self.download:
|
||||
raise RuntimeError(
|
||||
f"Dataset not found in `root={self.root}` and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automaticaly download the dataset."
|
||||
)
|
||||
|
||||
# Download the dataset
|
||||
self._download()
|
||||
self._extract()
|
||||
|
||||
def _download(self) -> None:
|
||||
"""Download the dataset."""
|
||||
download_url(
|
||||
self.url,
|
||||
self.root,
|
||||
filename=self.filename,
|
||||
md5=self.md5,
|
||||
)
|
||||
|
||||
def _extract(self) -> None:
|
||||
"""Extract the dataset."""
|
||||
extract_archive(os.path.join(self.root, self.filename))
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
@ -219,7 +218,7 @@ class ETCI2021(VisionDataset):
|
|||
with Image.open(filename) as img:
|
||||
array = np.array(img.convert("L"))
|
||||
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
|
||||
tensor = torch.clip(tensor, min=0, max=1) # type: ignore[attr-defined]
|
||||
tensor = torch.clamp(tensor, min=0, max=1) # type: ignore[attr-defined]
|
||||
tensor = tensor.to(torch.long) # type: ignore[attr-defined]
|
||||
return tensor
|
||||
|
||||
|
@ -251,6 +250,3 @@ class ETCI2021(VisionDataset):
|
|||
filename=self.metadata[self.split]["filename"],
|
||||
md5=self.metadata[self.split]["md5"] if self.checksum else None,
|
||||
)
|
||||
|
||||
if os.path.exists(os.path.join(self.root, "__MACOSX")):
|
||||
shutil.rmtree(os.path.join(self.root, "__MACOSX"))
|
||||
|
|
|
@ -22,6 +22,7 @@ import torch
|
|||
from rasterio.crs import CRS
|
||||
from rasterio.io import DatasetReader
|
||||
from rasterio.vrt import WarpedVRT
|
||||
from rasterio.windows import from_bounds
|
||||
from rtree.index import Index, Property
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
|
@ -104,6 +105,15 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
|
|||
"""
|
||||
return ZipDataset([self, other])
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of files in the dataset.
|
||||
|
||||
Returns:
|
||||
length of the dataset
|
||||
"""
|
||||
count: int = self.index.count(self.index.bounds)
|
||||
return count
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return the informal string representation of the object.
|
||||
|
||||
|
@ -113,7 +123,8 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
|
|||
return f"""\
|
||||
{self.__class__.__name__} Dataset
|
||||
type: GeoDataset
|
||||
bbox: {self.bounds}"""
|
||||
bbox: {self.bounds}
|
||||
size: {len(self)}"""
|
||||
|
||||
@property
|
||||
def bounds(self) -> BoundingBox:
|
||||
|
@ -286,7 +297,7 @@ class RasterDataset(GeoDataset):
|
|||
filepath = glob.glob(os.path.join(directory, filename))[0]
|
||||
band_filepaths.append(filepath)
|
||||
data_list.append(self._merge_files(band_filepaths, query))
|
||||
data = torch.stack(data_list)
|
||||
data = torch.cat(data_list) # type: ignore[attr-defined]
|
||||
else:
|
||||
data = self._merge_files(filepaths, query)
|
||||
|
||||
|
@ -318,7 +329,16 @@ class RasterDataset(GeoDataset):
|
|||
vrt_fhs = [self._load_warp_file(fp) for fp in filepaths]
|
||||
|
||||
bounds = (query.minx, query.miny, query.maxx, query.maxy)
|
||||
dest, _ = rasterio.merge.merge(vrt_fhs, bounds, self.res)
|
||||
if len(vrt_fhs) == 1:
|
||||
src = vrt_fhs[0]
|
||||
out_width = int(round((query.maxx - query.minx) / self.res))
|
||||
out_height = int(round((query.maxy - query.miny) / self.res))
|
||||
out_shape = (src.count, out_height, out_width)
|
||||
dest = src.read(
|
||||
out_shape=out_shape, window=from_bounds(*bounds, src.transform)
|
||||
)
|
||||
else:
|
||||
dest, _ = rasterio.merge.merge(vrt_fhs, bounds, self.res)
|
||||
dest = dest.astype(np.int32)
|
||||
|
||||
tensor: Tensor = torch.tensor(dest) # type: ignore[attr-defined]
|
||||
|
@ -447,7 +467,7 @@ class VectorDataset(GeoDataset):
|
|||
(minx, maxx), (miny, maxy) = fiona.transform.transform(
|
||||
src.crs, crs.to_dict(), [minx, maxx], [miny, maxy]
|
||||
)
|
||||
except (fiona.errors.DriverError, fiona.errors.FionaValueError):
|
||||
except fiona.errors.FionaValueError:
|
||||
# Skip files that fiona is unable to read
|
||||
continue
|
||||
else:
|
||||
|
@ -719,6 +739,14 @@ class ZipDataset(GeoDataset):
|
|||
sample.update(ds[query])
|
||||
return sample
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of files in the dataset.
|
||||
|
||||
Returns:
|
||||
length of the dataset
|
||||
"""
|
||||
return sum(map(len, self.datasets))
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return the informal string representation of the object.
|
||||
|
||||
|
@ -728,7 +756,8 @@ class ZipDataset(GeoDataset):
|
|||
return f"""\
|
||||
{self.__class__.__name__} Dataset
|
||||
type: ZipDataset
|
||||
bbox: {self.bounds}"""
|
||||
bbox: {self.bounds}
|
||||
size: {len(self)}"""
|
||||
|
||||
@property
|
||||
def bounds(self) -> BoundingBox:
|
||||
|
|
|
@ -170,7 +170,7 @@ class LEVIRCDPlus(VisionDataset):
|
|||
with Image.open(filename) as img:
|
||||
array = np.array(img.convert("L"))
|
||||
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
|
||||
tensor = torch.clip(tensor, min=0, max=1) # type: ignore[attr-defined]
|
||||
tensor = torch.clamp(tensor, min=0, max=1) # type: ignore[attr-defined]
|
||||
tensor = tensor.to(torch.long) # type: ignore[attr-defined]
|
||||
return tensor
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ import numpy as np
|
|||
import rasterio as rio
|
||||
import torch
|
||||
from affine import Affine
|
||||
from fiona.errors import FionaValueError
|
||||
from rasterio.features import rasterize
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -154,8 +155,11 @@ class SpaceNet(VisionDataset, abc.ABC):
|
|||
Returns:
|
||||
Tensor: label tensor
|
||||
"""
|
||||
with fiona.open(path) as src:
|
||||
labels = [feature["geometry"] for feature in src]
|
||||
try:
|
||||
with fiona.open(path) as src:
|
||||
labels = [feature["geometry"] for feature in src]
|
||||
except FionaValueError:
|
||||
labels = []
|
||||
|
||||
if not labels:
|
||||
mask_data = np.zeros(shape=shape)
|
||||
|
@ -490,3 +494,181 @@ class SpaceNet2(SpaceNet):
|
|||
)
|
||||
files.append({"image_path": imgpath, "label_path": lbl_path})
|
||||
return files
|
||||
|
||||
|
||||
class SpaceNet4(SpaceNet):
|
||||
"""SpaceNet 4: Off-Nadir Buildings Dataset.
|
||||
|
||||
`SpaceNet 4 <https://spacenet.ai/off-nadir-building-detection/>`_ is a
|
||||
dataset of 27 WV-2 collects captured at varying off-nadir angles and
|
||||
associated building footprints over the city of Atlanta. The off-nadir angle
|
||||
ranges from 7 degrees to 54 degrees.
|
||||
|
||||
|
||||
Dataset features
|
||||
|
||||
* No. of chipped images: 28,728 (PAN/MS/PS-RGBNIR)
|
||||
* No. of label files: 1064
|
||||
* No. of building footprints: >120,000
|
||||
* Area Coverage: 665 sq km
|
||||
* Chip size: 225 x 225 (MS), 900 x 900 (PAN/PS-RGBNIR)
|
||||
|
||||
Dataset format
|
||||
|
||||
* Imagery - Worldview-2 GeoTIFFs
|
||||
* PAN.tif (Panchromatic)
|
||||
* MS.tif (Multispectral)
|
||||
* PS-RGBNIR (Pansharpened RGBNIR)
|
||||
* Labels - GeoJSON
|
||||
* labels.geojson
|
||||
|
||||
If you use this dataset in your research, please cite the following paper:
|
||||
|
||||
* https://arxiv.org/abs/1903.12239
|
||||
|
||||
.. note::
|
||||
|
||||
This dataset requires the following additional library to be installed:
|
||||
|
||||
* `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
|
||||
imagery and labels from the Radiant Earth MLHub
|
||||
|
||||
"""
|
||||
|
||||
dataset_id = "spacenet4"
|
||||
collection_md5_dict = {
|
||||
"sn4_AOI_6_Atlanta": "c597d639cba5257927a97e3eff07b753",
|
||||
}
|
||||
|
||||
imagery = {
|
||||
"MS": "MS.tif",
|
||||
"PAN": "PAN.tif",
|
||||
"PS-RGBNIR": "PS-RGBNIR.tif",
|
||||
}
|
||||
chip_size = {
|
||||
"MS": (225, 225),
|
||||
"PAN": (900, 900),
|
||||
"PS-RGBNIR": (900, 900),
|
||||
}
|
||||
label_glob = "labels.geojson"
|
||||
|
||||
angle_catalog_map = {
|
||||
"nadir": [
|
||||
"1030010003D22F00",
|
||||
"10300100023BC100",
|
||||
"1030010003993E00",
|
||||
"1030010003CAF100",
|
||||
"1030010002B7D800",
|
||||
"10300100039AB000",
|
||||
"1030010002649200",
|
||||
"1030010003C92000",
|
||||
"1030010003127500",
|
||||
"103001000352C200",
|
||||
"103001000307D800",
|
||||
],
|
||||
"off-nadir": [
|
||||
"1030010003472200",
|
||||
"1030010003315300",
|
||||
"10300100036D5200",
|
||||
"103001000392F600",
|
||||
"1030010003697400",
|
||||
"1030010003895500",
|
||||
"1030010003832800",
|
||||
],
|
||||
"very-off-nadir": [
|
||||
"10300100035D1B00",
|
||||
"1030010003CCD700",
|
||||
"1030010003713C00",
|
||||
"10300100033C5200",
|
||||
"1030010003492700",
|
||||
"10300100039E6200",
|
||||
"1030010003BDDC00",
|
||||
"1030010003CD4300",
|
||||
"1030010003193D00",
|
||||
],
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: str,
|
||||
image: str = "PS-RGBNIR",
|
||||
angles: List[str] = [],
|
||||
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
|
||||
download: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new SpaceNet 4 Dataset instance.
|
||||
|
||||
Args:
|
||||
root: root directory where dataset can be found
|
||||
image: image selection which must be in ["MS", "PAN", "PS-RGBNIR"]
|
||||
angles: angle selection which must be in ["nadir", "off-nadir",
|
||||
"very-off-nadir"]
|
||||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
download: if True, download dataset and store it in the root directory.
|
||||
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
|
||||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
RuntimeError: if ``download=False`` but dataset is missing
|
||||
"""
|
||||
collections = ["sn4_AOI_6_Atlanta"]
|
||||
assert image in {"MS", "PAN", "PS-RGBNIR"}
|
||||
self.angles = angles
|
||||
if self.angles:
|
||||
for angle in self.angles:
|
||||
assert angle in self.angle_catalog_map.keys()
|
||||
super().__init__(
|
||||
root, image, collections, transforms, download, api_key, checksum
|
||||
)
|
||||
|
||||
def _load_files(self, root: str) -> List[Dict[str, str]]:
|
||||
"""Return the paths of the files in the dataset.
|
||||
|
||||
Args:
|
||||
root: root dir of dataset
|
||||
|
||||
Returns:
|
||||
list of dicts containing paths for each pair of image and label
|
||||
"""
|
||||
files = []
|
||||
nadir = []
|
||||
offnadir = []
|
||||
veryoffnadir = []
|
||||
images = glob.glob(os.path.join(root, self.collections[0], "*", self.filename))
|
||||
images = sorted(images)
|
||||
|
||||
catalog_id_pattern = re.compile(r"(_[A-Z0-9])\w+$")
|
||||
for imgpath in images:
|
||||
imgdir = os.path.basename(os.path.dirname(imgpath))
|
||||
match = catalog_id_pattern.search(imgdir)
|
||||
assert match is not None, "Invalid image directory"
|
||||
catalog_id = match.group()[1:]
|
||||
|
||||
lbl_dir = os.path.dirname(imgpath).split("-nadir")[0]
|
||||
|
||||
lbl_path = os.path.join(lbl_dir + "-labels", self.label_glob)
|
||||
assert os.path.exists(lbl_path)
|
||||
|
||||
_file = {"image_path": imgpath, "label_path": lbl_path}
|
||||
if catalog_id in self.angle_catalog_map["very-off-nadir"]:
|
||||
veryoffnadir.append(_file)
|
||||
elif catalog_id in self.angle_catalog_map["off-nadir"]:
|
||||
offnadir.append(_file)
|
||||
elif catalog_id in self.angle_catalog_map["nadir"]:
|
||||
nadir.append(_file)
|
||||
|
||||
angle_file_map = {
|
||||
"nadir": nadir,
|
||||
"off-nadir": offnadir,
|
||||
"very-off-nadir": veryoffnadir,
|
||||
}
|
||||
|
||||
if not self.angles:
|
||||
files.extend(nadir + offnadir + veryoffnadir)
|
||||
else:
|
||||
for angle in self.angles:
|
||||
files.extend(angle_file_map[angle])
|
||||
return files
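# Illustrative usage sketch (not part of the diff). The root path and API key
# are placeholders; the arguments mirror the constructor defined above.
from torchgeo.datasets import SpaceNet4

ds = SpaceNet4(
    root="data/spacenet4",
    image="PS-RGBNIR",
    angles=["nadir", "off-nadir"],
    download=True,
    api_key="YOUR_MLHUB_API_KEY",  # placeholder
)
sample = ds[0]  # {"image": Tensor, "mask": Tensor}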
|
||||
|
|
|
@ -18,6 +18,7 @@ import numpy as np
|
|||
import rasterio
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset, Subset, random_split
|
||||
from torchvision.datasets.utils import check_integrity, download_url
|
||||
|
||||
__all__ = (
|
||||
|
@ -30,6 +31,7 @@ __all__ = (
|
|||
"working_dir",
|
||||
"collate_dict",
|
||||
"rasterio_loader",
|
||||
"dataset_split",
|
||||
)
|
||||
|
||||
|
||||
|
@ -394,3 +396,28 @@ def rasterio_loader(path: str) -> np.ndarray: # type: ignore[type-arg]
|
|||
# VisionClassificationDataset expects images returned with channels last (HWC)
|
||||
array = array.transpose(1, 2, 0)
|
||||
return array
|
||||
|
||||
|
||||
def dataset_split(
|
||||
dataset: Dataset[Any], val_pct: float, test_pct: Optional[float] = None
|
||||
) -> List[Subset[Any]]:
|
||||
"""Split a torch Dataset into train/val/test sets.
|
||||
|
||||
If ``test_pct`` is not set then only train and validation splits are returned.
|
||||
|
||||
Args:
|
||||
dataset: dataset to be split into train/val or train/val/test subsets
|
||||
val_pct: percentage of samples to be in validation set
|
||||
test_pct: (Optional) percentage of samples to be in test set
|
||||
Returns:
|
||||
a list of the subset datasets. Either [train, val] or [train, val, test]
|
||||
"""
|
||||
if test_pct is None:
|
||||
val_length = int(len(dataset) * val_pct) # type: ignore[arg-type]
|
||||
train_length = len(dataset) - val_length # type: ignore[arg-type]
|
||||
return random_split(dataset, [train_length, val_length])
|
||||
else:
|
||||
val_length = int(len(dataset) * val_pct) # type: ignore[arg-type]
|
||||
test_length = int(len(dataset) * test_pct) # type: ignore[arg-type]
|
||||
train_length = len(dataset) - (val_length + test_length) # type: ignore[arg-type] # noqa: E501
|
||||
return random_split(dataset, [train_length, val_length, test_length])
|
||||
|
|
|
@ -7,6 +7,7 @@ from .changestar import ChangeMixin, ChangeStar, ChangeStarFarSeg
|
|||
from .farseg import FarSeg
|
||||
from .fccd import FCEF, FCSiamConc, FCSiamDiff
|
||||
from .fcn import FCN
|
||||
from .rcf import RCF
|
||||
|
||||
__all__ = (
|
||||
"ChangeMixin",
|
||||
|
@ -17,6 +18,7 @@ __all__ = (
|
|||
"FCEF",
|
||||
"FCSiamConc",
|
||||
"FCSiamDiff",
|
||||
"RCF",
|
||||
)
|
||||
|
||||
# https://stackoverflow.com/questions/40018681
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Implementation of a random convolutional feature projection model."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.nn.modules import Conv2d, Module
|
||||
|
||||
Module.__module__ = "torch.nn"
|
||||
Conv2d.__module__ = "torch.nn"
|
||||
|
||||
|
||||
class RCF(Module):
|
||||
"""This model extracts random convolutional features (RCFs) from its input.
|
||||
|
||||
RCFs are used in the Multi-task Observation using Satellite Imagery & Kitchen Sinks
|
||||
(MOSAIKS) method proposed in https://www.nature.com/articles/s41467-021-24638-z.
|
||||
|
||||
.. note::
|
||||
|
||||
This Module is *not* trainable. It is only used as a feature extractor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
features: int = 16,
|
||||
kernel_size: int = 3,
|
||||
bias: float = -1.0,
|
||||
) -> None:
|
||||
"""Initializes the RCF model.
|
||||
|
||||
This is a static model that serves to extract fixed-length feature vectors from
|
||||
input patches.
|
||||
|
||||
Args:
|
||||
in_channels: number of input channels
|
||||
features: number of features to compute, must be divisible by 2
|
||||
kernel_size: size of the kernel used to compute the RCFs
|
||||
bias: bias of the convolutional layer
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
|
||||
assert features % 2 == 0
|
||||
|
||||
# We register the weight and bias tensors as "buffers". This does two things:
|
||||
# makes them behave correctly when we call .to(...) on the module, and makes
|
||||
# them explicitly _not_ Parameters of the model (which might get updated) if
|
||||
# a user tries to train with this model.
|
||||
self.register_buffer(
|
||||
"weights",
|
||||
torch.randn(
|
||||
features // 2,
|
||||
in_channels,
|
||||
kernel_size,
|
||||
kernel_size,
|
||||
requires_grad=False,
|
||||
),
|
||||
)
|
||||
self.register_buffer(
|
||||
"biases",
|
||||
torch.zeros( # type: ignore[attr-defined]
|
||||
features // 2, requires_grad=False
|
||||
)
|
||||
+ bias,
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Forward pass of the RCF model.
|
||||
|
||||
Args:
|
||||
x: a tensor with shape (B, C, H, W)
|
||||
|
||||
Returns:
|
||||
a tensor of size (B, ``self.num_features``)
|
||||
"""
|
||||
x1a = F.relu(
|
||||
F.conv2d(x, self.weights, bias=self.biases, stride=1, padding=0),
|
||||
inplace=True,
|
||||
)
|
||||
x1b = F.relu(
|
||||
-F.conv2d(x, self.weights, bias=self.biases, stride=1, padding=0),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
x1a = F.adaptive_avg_pool2d(x1a, (1, 1)).squeeze()
|
||||
x1b = F.adaptive_avg_pool2d(x1b, (1, 1)).squeeze()
|
||||
|
||||
if len(x1a.shape) == 1: # case where we passed a single input
|
||||
output = torch.cat((x1a, x1b), dim=0) # type: ignore[attr-defined]
|
||||
return cast(Tensor, output)
|
||||
else: # case where we passed a batch of > 1 inputs
|
||||
assert len(x1a.shape) == 2
|
||||
output = torch.cat((x1a, x1b), dim=1) # type: ignore[attr-defined]
|
||||
return cast(Tensor, output)
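# Illustrative usage sketch (not part of the diff): extracting MOSAIKS-style
# features from a batch of 4-band patches. Shapes are made up.
import torch

from torchgeo.models import RCF

model = RCF(in_channels=4, features=16, kernel_size=3)
x = torch.randn(8, 4, 64, 64)
feats = model(x)  # shape (8, 16); RCF itself has no trainable parameters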
@ -7,10 +7,9 @@ import abc
|
|||
import random
|
||||
from typing import Iterator, List, Optional, Tuple, Union
|
||||
|
||||
from rtree.index import Index
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
from torchgeo.datasets import BoundingBox
|
||||
from torchgeo.datasets import BoundingBox, GeoDataset
|
||||
|
||||
from .utils import _to_tuple, get_random_bounding_box
|
||||
|
||||
|
@ -43,13 +42,13 @@ class RandomBatchGeoSampler(BatchGeoSampler):
|
|||
This is particularly useful during training when you want to maximize the size of
|
||||
the dataset and return as many random :term:`chips <chip>` as possible.
|
||||
|
||||
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``index`` should come
|
||||
from a tile-based dataset if possible.
|
||||
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``dataset`` should be
|
||||
a tile-based dataset if possible.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index: Index,
|
||||
dataset: GeoDataset,
|
||||
size: Union[Tuple[float, float], float],
|
||||
batch_size: int,
|
||||
length: int,
|
||||
|
@ -65,21 +64,22 @@ class RandomBatchGeoSampler(BatchGeoSampler):
|
|||
height dimension, and the second *float* for the width dimension
|
||||
|
||||
Args:
|
||||
index: index of a :class:`~torchgeo.datasets.GeoDataset`
|
||||
dataset: dataset to index from
|
||||
size: dimensions of each :term:`patch` in units of CRS
|
||||
batch_size: number of samples per batch
|
||||
length: number of samples per epoch
|
||||
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
|
||||
(defaults to the bounds of ``index``)
|
||||
(defaults to the bounds of ``dataset.index``)
|
||||
"""
|
||||
self.index = index
|
||||
self.index = dataset.index
|
||||
self.res = dataset.res
|
||||
self.size = _to_tuple(size)
|
||||
self.batch_size = batch_size
|
||||
self.length = length
|
||||
if roi is None:
|
||||
roi = BoundingBox(*index.bounds)
|
||||
roi = BoundingBox(*self.index.bounds)
|
||||
self.roi = roi
|
||||
self.hits = list(index.intersection(roi, objects=True))
|
||||
self.hits = list(self.index.intersection(roi, objects=True))
|
||||
|
||||
def __iter__(self) -> Iterator[List[BoundingBox]]:
|
||||
"""Return the indices of a dataset.
|
||||
|
@ -96,7 +96,7 @@ class RandomBatchGeoSampler(BatchGeoSampler):
|
|||
batch = []
|
||||
for _ in range(self.batch_size):
|
||||
|
||||
bounding_box = get_random_bounding_box(bounds, self.size)
|
||||
bounding_box = get_random_bounding_box(bounds, self.size, self.res)
|
||||
batch.append(bounding_box)
|
||||
|
||||
yield batch
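# Illustrative sketch of the revised sampler API (not part of the diff): samplers
# now take the dataset itself, and sizes are given in CRS units. The dataset root
# and numbers below are placeholders.
from torch.utils.data import DataLoader

from torchgeo.datasets import Landsat8, collate_dict
from torchgeo.samplers import RandomBatchGeoSampler

landsat = Landsat8(root="data/landsat8")
size = 224 * landsat.res  # convert a 224-pixel patch to CRS units
sampler = RandomBatchGeoSampler(landsat, size=size, batch_size=8, length=1000)
loader = DataLoader(landsat, batch_sampler=sampler, collate_fn=collate_dict)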
|
||||
|
|
|
@ -7,10 +7,9 @@ import abc
|
|||
import random
|
||||
from typing import Iterator, Optional, Tuple, Union
|
||||
|
||||
from rtree.index import Index
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
from torchgeo.datasets import BoundingBox
|
||||
from torchgeo.datasets import BoundingBox, GeoDataset
|
||||
|
||||
from .utils import _to_tuple, get_random_bounding_box
|
||||
|
||||
|
@ -49,7 +48,7 @@ class RandomGeoSampler(GeoSampler):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
index: Index,
|
||||
dataset: GeoDataset,
|
||||
size: Union[Tuple[float, float], float],
|
||||
length: int,
|
||||
roi: Optional[BoundingBox] = None,
|
||||
|
@ -64,19 +63,20 @@ class RandomGeoSampler(GeoSampler):
|
|||
height dimension, and the second *float* for the width dimension
|
||||
|
||||
Args:
|
||||
index: index of a :class:`~torchgeo.datasets.GeoDataset`
|
||||
dataset: dataset to index from
|
||||
size: dimensions of each :term:`patch` in units of CRS
|
||||
length: number of random samples to draw per epoch
|
||||
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
|
||||
(defaults to the bounds of ``index``)
|
||||
(defaults to the bounds of ``dataset.index``)
|
||||
"""
|
||||
self.index = index
|
||||
self.index = dataset.index
|
||||
self.res = dataset.res
|
||||
self.size = _to_tuple(size)
|
||||
self.length = length
|
||||
if roi is None:
|
||||
roi = BoundingBox(*index.bounds)
|
||||
roi = BoundingBox(*self.index.bounds)
|
||||
self.roi = roi
|
||||
self.hits = list(index.intersection(roi, objects=True))
|
||||
self.hits = list(self.index.intersection(roi, objects=True))
|
||||
|
||||
def __iter__(self) -> Iterator[BoundingBox]:
|
||||
"""Return the index of a dataset.
|
||||
|
@ -90,7 +90,7 @@ class RandomGeoSampler(GeoSampler):
|
|||
bounds = BoundingBox(*hit.bounds)
|
||||
|
||||
# Choose a random index within that tile
|
||||
bounding_box = get_random_bounding_box(bounds, self.size)
|
||||
bounding_box = get_random_bounding_box(bounds, self.size, self.res)
|
||||
|
||||
yield bounding_box
|
||||
|
||||
|
@ -117,13 +117,13 @@ class GridGeoSampler(GeoSampler):
|
|||
to the `receptive field <https://distill.pub/2019/computing-receptive-fields/>`_ of
|
||||
the CNN.
|
||||
|
||||
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``index`` should come
|
||||
from a non-tile-based dataset if possible.
|
||||
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``dataset`` should be
|
||||
a non-tile-based dataset if possible.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index: Index,
|
||||
dataset: GeoDataset,
|
||||
size: Union[Tuple[float, float], float],
|
||||
stride: Union[Tuple[float, float], float],
|
||||
roi: Optional[BoundingBox] = None,
|
||||
|
@ -138,18 +138,19 @@ class GridGeoSampler(GeoSampler):
|
|||
height dimension, and the second *float* for the width dimension
|
||||
|
||||
Args:
|
||||
index: index of a :class:`~torchgeo.datasets.GeoDataset`
|
||||
dataset: dataset to index from
|
||||
size: dimensions of each :term:`patch` in units of CRS
|
||||
stride: distance to skip between each patch
|
||||
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
|
||||
(defaults to the bounds of ``dataset.index``)
|
||||
"""
|
||||
self.index = index
|
||||
self.index = dataset.index
|
||||
self.size = _to_tuple(size)
|
||||
self.stride = _to_tuple(stride)
|
||||
if roi is None:
|
||||
roi = BoundingBox(*index.bounds)
|
||||
roi = BoundingBox(*self.index.bounds)
|
||||
self.roi = roi
|
||||
self.hits = list(index.intersection(roi, objects=True))
|
||||
self.hits = list(self.index.intersection(roi, objects=True))
|
||||
|
||||
self.length: int = 0
|
||||
for hit in self.hits:
|
||||
|
|
|
@ -25,7 +25,7 @@ def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]:
|
|||
|
||||
|
||||
def get_random_bounding_box(
|
||||
bounds: BoundingBox, size: Union[Tuple[float, float], float]
|
||||
bounds: BoundingBox, size: Union[Tuple[float, float], float], res: float
|
||||
) -> BoundingBox:
|
||||
"""Returns a random bounding box within a given bounding box.
|
||||
|
||||
|
@ -45,13 +45,16 @@ def get_random_bounding_box(
|
|||
"""
|
||||
t_size: Tuple[float, float] = _to_tuple(size)
|
||||
|
||||
minx = random.uniform(bounds.minx, bounds.maxx - t_size[1])
|
||||
width = (bounds.maxx - bounds.minx - t_size[1]) // res
|
||||
minx = random.randrange(int(width)) * res + bounds.minx
|
||||
maxx = minx + t_size[1]
|
||||
|
||||
miny = random.uniform(bounds.miny, bounds.maxy - t_size[0])
|
||||
height = (bounds.maxy - bounds.miny - t_size[0]) // res
|
||||
miny = random.randrange(int(height)) * res + bounds.miny
|
||||
maxy = miny + t_size[0]
|
||||
|
||||
mint = bounds.mint
|
||||
maxt = bounds.maxt
|
||||
|
||||
return BoundingBox(minx, maxx, miny, maxy, mint, maxt)
|
||||
query = BoundingBox(minx, maxx, miny, maxy, mint, maxt)
|
||||
return query
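# Illustrative sketch of the resolution-aware behaviour (numbers are made up;
# the import path assumes this module is torchgeo/samplers/utils.py):
from torchgeo.datasets import BoundingBox
from torchgeo.samplers.utils import get_random_bounding_box

bounds = BoundingBox(0, 100, 0, 100, 0, 1)
bbox = get_random_bounding_box(bounds, size=10, res=2)
# bbox.minx is now a multiple of 2 in [0, 88] and bbox.maxx - bbox.minx == 10,
# i.e. patches are snapped to the pixel grid rather than arbitrary float offsets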
|
||||
|
|
|
@@ -8,6 +8,7 @@ from .chesapeake import ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask
from .cyclone import CycloneDataModule, CycloneSimpleRegressionTask
from .landcoverai import LandcoverAIDataModule, LandcoverAISegmentationTask
from .naipchesapeake import NAIPChesapeakeDataModule, NAIPChesapeakeSegmentationTask
from .resisc45 import RESISC45ClassificationTask, RESISC45DataModule
from .sen12ms import SEN12MSDataModule, SEN12MSSegmentationTask
from .so2sat import So2SatClassificationTask, So2SatDataModule

@@ -21,6 +22,8 @@ __all__ = (
    "LandcoverAISegmentationTask",
    "NAIPChesapeakeDataModule",
    "NAIPChesapeakeSegmentationTask",
    "RESISC45ClassificationTask",
    "RESISC45DataModule",
    "SEN12MSDataModule",
    "SEN12MSSegmentationTask",
    "So2SatDataModule",
@ -0,0 +1,341 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""RESISC45 trainer."""
|
||||
|
||||
from typing import Any, Dict, Optional, cast
|
||||
|
||||
import kornia.augmentation as K
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.models
|
||||
from torch import Tensor
|
||||
from torch.nn.modules import Conv2d, Linear, Module
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
from torch.utils.data import DataLoader
|
||||
from torchmetrics import Accuracy, IoU, MetricCollection
|
||||
from torchvision.transforms import Compose, Normalize
|
||||
|
||||
from ..datasets import RESISC45
|
||||
from ..datasets.utils import dataset_split
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
Module.__module__ = "torch.nn"
|
||||
Conv2d.__module__ = "nn.Conv2d"
|
||||
Linear.__module__ = "nn.Linear"
|
||||
|
||||
IN_CHANNELS = 3
|
||||
NUM_CLASSES = 45
|
||||
|
||||
|
||||
class RESISC45ClassificationTask(pl.LightningModule):
|
||||
"""LightningModule for training models on the RESISC45 Dataset."""
|
||||
|
||||
def config_task(self) -> None:
|
||||
"""Configures the task based on kwargs parameters passed to the constructor."""
|
||||
pretrained = "imagenet" in self.hparams["weights"]
|
||||
|
||||
if "resnet" in self.hparams["classification_model"]:
|
||||
self.model = getattr(
|
||||
torchvision.models.resnet, self.hparams["classification_model"]
|
||||
)(pretrained=pretrained)
|
||||
in_features = self.model.fc.in_features
|
||||
self.model.fc = Linear(in_features, out_features=NUM_CLASSES)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Model type '{self.hparams['classification_model']}' is not valid."
|
||||
)
|
||||
|
||||
if "resnet" in self.hparams["classification_model"]:
|
||||
|
||||
if self.hparams["weights"] in ["imagenet_only", "random"]:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Weight type '{self.hparams['weights']}' is not valid."
|
||||
)
|
||||
else:
|
||||
pass # stub for initializing the weights of other models
|
||||
|
||||
if self.hparams["loss"] == "ce":
|
||||
self.loss = nn.CrossEntropyLoss() # type: ignore[attr-defined]
|
||||
else:
|
||||
raise ValueError(f"Loss type '{self.hparams['loss']}' is not valid.")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the LightningModule with a model and loss function.
|
||||
|
||||
Keyword Args:
|
||||
classification_model: Name of the classification model to use
|
||||
loss: Name of the loss function
|
||||
weights: Either "random", "imagenet_only", "imagenet_and_random", or
|
||||
"random_rgb"
|
||||
"""
|
||||
super().__init__()
|
||||
self.save_hyperparameters() # creates `self.hparams` from kwargs
|
||||
|
||||
self.config_task()
|
||||
|
||||
self.train_metrics = MetricCollection(
|
||||
{
|
||||
"OverallAccuracy": Accuracy(num_classes=NUM_CLASSES, average="micro"),
|
||||
"AverageAccuracy": Accuracy(num_classes=NUM_CLASSES, average="macro"),
|
||||
"IoU": IoU(num_classes=NUM_CLASSES),
|
||||
},
|
||||
prefix="train_",
|
||||
)
|
||||
self.val_metrics = self.train_metrics.clone(prefix="val_")
|
||||
self.test_metrics = self.train_metrics.clone(prefix="test_")
|
||||
|
||||
def forward(self, x: Tensor) -> Any: # type: ignore[override]
|
||||
"""Forward pass of the model."""
|
||||
return self.model(x)
|
||||
|
||||
def training_step( # type: ignore[override]
|
||||
self, batch: Dict[str, Any], batch_idx: int
|
||||
) -> Tensor:
|
||||
"""Training step - reports average accuracy and average IoU.
|
||||
|
||||
Args:
|
||||
batch: Current batch
|
||||
batch_idx: Index of current batch
|
||||
|
||||
Returns:
|
||||
training loss
|
||||
"""
|
||||
x = batch["image"]
|
||||
y = batch["label"]
|
||||
y_hat = self.forward(x)
|
||||
y_hat_hard = y_hat.argmax(dim=1)
|
||||
|
||||
loss = self.loss(y_hat, y)
|
||||
|
||||
# by default, the train step logs every `log_every_n_steps` steps where
|
||||
# `log_every_n_steps` is a parameter to the `Trainer` object
|
||||
self.log("train_loss", loss, on_step=True, on_epoch=False)
|
||||
self.train_metrics(y_hat_hard, y)
|
||||
|
||||
return cast(Tensor, loss)
|
||||
|
||||
def training_epoch_end(self, outputs: Any) -> None:
|
||||
"""Logs epoch-level training metrics.
|
||||
|
||||
Args:
|
||||
outputs: list of items returned by training_step
|
||||
"""
|
||||
self.log_dict(self.train_metrics.compute())
|
||||
self.train_metrics.reset()
|
||||
|
||||
def validation_step( # type: ignore[override]
|
||||
self, batch: Dict[str, Any], batch_idx: int
|
||||
) -> None:
|
||||
"""Validation step - reports average accuracy and average IoU.
|
||||
|
||||
Args:
|
||||
batch: Current batch
|
||||
batch_idx: Index of current batch
|
||||
"""
|
||||
x = batch["image"]
|
||||
y = batch["label"]
|
||||
y_hat = self.forward(x)
|
||||
y_hat_hard = y_hat.argmax(dim=1)
|
||||
|
||||
loss = self.loss(y_hat, y)
|
||||
|
||||
self.log("val_loss", loss, on_step=False, on_epoch=True)
|
||||
self.val_metrics(y_hat_hard, y)
|
||||
|
||||
def validation_epoch_end(self, outputs: Any) -> None:
|
||||
"""Logs epoch level validation metrics.
|
||||
|
||||
Args:
|
||||
outputs: list of items returned by validation_step
|
||||
"""
|
||||
self.log_dict(self.val_metrics.compute())
|
||||
self.val_metrics.reset()
|
||||
|
||||
def test_step( # type: ignore[override]
|
||||
self, batch: Dict[str, Any], batch_idx: int
|
||||
) -> None:
|
||||
"""Test step identical to the validation step.
|
||||
|
||||
Args:
|
||||
batch: Current batch
|
||||
batch_idx: Index of current batch
|
||||
"""
|
||||
x = batch["image"]
|
||||
y = batch["label"]
|
||||
y_hat = self.forward(x)
|
||||
y_hat_hard = y_hat.argmax(dim=1)
|
||||
|
||||
loss = self.loss(y_hat, y)
|
||||
|
||||
# by default, the test and validation steps only log per *epoch*
|
||||
self.log("test_loss", loss, on_step=False, on_epoch=True)
|
||||
self.test_metrics(y_hat_hard, y)
|
||||
|
||||
def test_epoch_end(self, outputs: Any) -> None:
|
||||
"""Logs epoch level test metrics.
|
||||
|
||||
Args:
|
||||
outputs: list of items returned by test_step
|
||||
"""
|
||||
self.log_dict(self.test_metrics.compute())
|
||||
self.test_metrics.reset()
|
||||
|
||||
def configure_optimizers(self) -> Dict[str, Any]:
|
||||
"""Initialize the optimizer and learning rate scheduler.
|
||||
|
||||
Returns:
|
||||
a "lr dict" according to the pytorch lightning documentation --
|
||||
https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
|
||||
"""
|
||||
optimizer = torch.optim.Adam(
|
||||
self.model.parameters(),
|
||||
lr=self.hparams["learning_rate"],
|
||||
)
|
||||
return {
|
||||
"optimizer": optimizer,
|
||||
"lr_scheduler": {
|
||||
"scheduler": ReduceLROnPlateau(
|
||||
optimizer,
|
||||
patience=self.hparams["learning_rate_schedule_patience"],
|
||||
),
|
||||
"monitor": "val_loss",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class RESISC45DataModule(pl.LightningDataModule):
|
||||
"""LightningDataModule implementation for the RESISC45 dataset.
|
||||
|
||||
Uses the train/val/test splits from the dataset.
|
||||
"""
|
||||
|
||||
band_means = torch.tensor( # type: ignore[attr-defined]
|
||||
[0.36801773, 0.38097873, 0.343583]
|
||||
)
|
||||
|
||||
band_stds = torch.tensor( # type: ignore[attr-defined]
|
||||
[0.14540215, 0.13558227, 0.13203649]
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_dir: str,
|
||||
batch_size: int = 64,
|
||||
num_workers: int = 4,
|
||||
weights: str = "random",
|
||||
unsupervised_mode: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a LightningDataModule for RESISC45 based DataLoaders.
|
||||
|
||||
Args:
|
||||
root_dir: The ``root`` argument to pass to the RESISC45 Dataset classes
|
||||
batch_size: The batch size to use in all created DataLoaders
|
||||
num_workers: The number of workers to use in all created DataLoaders
|
||||
weights: Either "random", "imagenet_only", "imagenet_and_random", or
|
||||
"random_rgb"
|
||||
unsupervised_mode: Makes the train dataloader return imagery from the train,
|
||||
val, and test sets
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
self.root_dir = root_dir
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
self.weights = weights
|
||||
self.unsupervised_mode = unsupervised_mode
|
||||
|
||||
self.val_split_pct = kwargs["val_split_pct"]
|
||||
self.test_split_pct = kwargs["test_split_pct"]
|
||||
|
||||
self.norm = Normalize(self.band_means, self.band_stds)
|
||||
self.transforms = K.AugmentationSequential(
|
||||
K.RandomAffine(degrees=30),
|
||||
K.RandomHorizontalFlip(),
|
||||
K.RandomVerticalFlip(),
|
||||
data_keys=["input"],
|
||||
)
|
||||
|
||||
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Transform a single sample from the Dataset."""
|
||||
sample["image"] = sample["image"].float()
|
||||
sample["image"] /= 255.0
|
||||
sample["image"] = self.norm(sample["image"])
|
||||
return sample
|
||||
|
||||
def kornia_pipeline(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Transform a single sample from the Dataset with Kornia."""
|
||||
sample["image"] = self.transforms(sample["image"]).squeeze()
|
||||
return sample
|
||||
|
||||
def prepare_data(self) -> None:
|
||||
"""Make sure that the dataset is downloaded.
|
||||
|
||||
This method is only called once per run.
|
||||
"""
|
||||
RESISC45(self.root_dir, checksum=False)
|
||||
|
||||
def setup(self, stage: Optional[str] = None) -> None:
|
||||
"""Initialize the main ``Dataset`` objects.
|
||||
|
||||
This method is called once per GPU per run.
|
||||
"""
|
||||
transforms = Compose([self.preprocess])
|
||||
|
||||
if not self.unsupervised_mode:
|
||||
|
||||
dataset = RESISC45(
|
||||
self.root_dir,
|
||||
transforms=transforms,
|
||||
)
|
||||
self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
|
||||
dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
|
||||
)
|
||||
else:
|
||||
|
||||
self.train_dataset = RESISC45(
|
||||
self.root_dir,
|
||||
transforms=transforms,
|
||||
)
|
||||
self.val_dataset, self.test_dataset = None, None # type: ignore[assignment]
|
||||
|
||||
def train_dataloader(self) -> DataLoader[Any]:
|
||||
"""Return a DataLoader for training."""
|
||||
return DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
def val_dataloader(self) -> DataLoader[Any]:
|
||||
"""Return a DataLoader for validation."""
|
||||
if self.val_dataset is None or len(self.val_dataset) == 0:
|
||||
return self.train_dataloader()
|
||||
else:
|
||||
return DataLoader(
|
||||
self.val_dataset,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
def test_dataloader(self) -> DataLoader[Any]:
|
||||
"""Return a DataLoader for testing."""
|
||||
if self.test_dataset is None or len(self.test_dataset) == 0:
|
||||
return self.train_dataloader()
|
||||
else:
|
||||
return DataLoader(
|
||||
self.test_dataset,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
shuffle=False,
|
||||
)
|
|
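
Taken together, the new task and datamodule plug into a standard PyTorch Lightning training loop. A sketch of how they might be wired up; the data root, hyperparameter values, and Trainer settings are illustrative and not part of this commit:

import pytorch_lightning as pl

from torchgeo.trainers import RESISC45ClassificationTask, RESISC45DataModule

datamodule = RESISC45DataModule(
    root_dir="data/resisc45",        # placeholder path
    batch_size=64,
    num_workers=4,
    weights="imagenet_only",
    val_split_pct=0.1,               # consumed via **kwargs in __init__
    test_split_pct=0.1,
)
task = RESISC45ClassificationTask(
    classification_model="resnet18",
    loss="ce",
    weights="imagenet_only",
    learning_rate=1e-3,
    learning_rate_schedule_patience=5,
)
trainer = pl.Trainer(gpus=1, max_epochs=10)
trainer.fit(model=task, datamodule=datamodule)
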
@@ -222,12 +222,12 @@ class SEN12MSDataModule(pl.LightningDataModule):
        sample["image"] = sample["image"].float()

        if self.band_set == "all":
            sample["image"][:2] = sample["image"][:2].clip(-25, 0) / -25
            sample["image"][2:] = sample["image"][2:].clip(0, 10000) / 10000
            sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25
            sample["image"][2:] = sample["image"][2:].clamp(0, 10000) / 10000
        elif self.band_set == "s1":
            sample["image"][:2] = sample["image"][:2].clip(-25, 0) / -25
            sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25
        else:
            sample["image"][:] = sample["image"][:].clip(0, 10000) / 10000
            sample["image"][:] = sample["image"][:].clamp(0, 10000) / 10000

        sample["mask"] = sample["mask"][0, :, :].long()
        sample["mask"] = torch.take(  # type: ignore[attr-defined]
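
``Tensor.clamp`` is the canonical torch method (``clip`` was only added as an alias in later releases), and the scaling maps each band into [0, 1]. For the Sentinel-1 case, backscatter in dB is first limited to [-25, 0] and then divided by -25, as a small worked example shows:

import torch

s1_db = torch.tensor([-40.0, -25.0, -10.0, 0.0, 5.0])
normalized = s1_db.clamp(-25, 0) / -25
# tensor([1.0000, 1.0000, 0.4000, 0.0000, 0.0000]): weak backscatter maps to 1, strong to 0
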
@@ -3,6 +3,7 @@

"""So2Sat trainer."""

import os
from typing import Any, Dict, Optional, cast

import kornia.augmentation as K

@@ -19,6 +20,7 @@ from torchmetrics import Accuracy, IoU, MetricCollection
from torchvision.transforms import Compose

from ..datasets import So2Sat
from . import utils

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045

@@ -27,7 +29,6 @@ Module.__module__ = "torch.nn"
Conv2d.__module__ = "nn.Conv2d"
Linear.__module__ = "nn.Linear"

IN_CHANNELS = 10
NUM_CLASSES = 17

@ -36,83 +37,65 @@ class So2SatClassificationTask(pl.LightningModule):
|
|||
|
||||
def config_task(self) -> None:
|
||||
"""Configures the task based on kwargs parameters passed to the constructor."""
|
||||
pretrained = "imagenet" in self.hparams["weights"]
|
||||
pretrained = ("imagenet" in self.hparams["weights"]) and not os.path.exists(
|
||||
self.hparams["weights"]
|
||||
)
|
||||
in_channels = self.hparams["in_channels"]
|
||||
|
||||
if self.hparams["classification_model"] == "resnet18":
|
||||
self.model = torchvision.models.resnet18(pretrained=pretrained)
|
||||
self.model.fc = Linear(512, out_features=NUM_CLASSES)
|
||||
elif self.hparams["classification_model"] == "resnet50":
|
||||
self.model = torchvision.models.resnet50(pretrained=pretrained)
|
||||
self.model.fc = Linear(2048, out_features=NUM_CLASSES)
|
||||
elif self.hparams["classification_model"] == "resnet152":
|
||||
self.model = torchvision.models.resnet152(pretrained=pretrained)
|
||||
self.model.fc = Linear(2048, out_features=NUM_CLASSES)
|
||||
# Create the model
|
||||
if "resnet" in self.hparams["classification_model"]:
|
||||
self.model = getattr(
|
||||
torchvision.models.resnet, self.hparams["classification_model"]
|
||||
)(pretrained=pretrained)
|
||||
|
||||
# Update first layer
|
||||
w_old = None
|
||||
if pretrained:
|
||||
w_old = torch.clone( # type: ignore[attr-defined]
|
||||
self.model.conv1.weight
|
||||
).detach()
|
||||
# Create the new layer
|
||||
self.model.conv1 = Conv2d(
|
||||
in_channels,
|
||||
64,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=2,
|
||||
bias=False,
|
||||
)
|
||||
nn.init.kaiming_normal_( # type: ignore[no-untyped-call]
|
||||
self.model.conv1.weight, mode="fan_out", nonlinearity="relu"
|
||||
)
|
||||
|
||||
if pretrained:
|
||||
w_new = torch.clone( # type: ignore[attr-defined]
|
||||
self.model.conv1.weight
|
||||
).detach()
|
||||
w_new[:, :3, :, :] = w_old
|
||||
self.model.conv1.weight = nn.Parameter( # type: ignore[attr-defined]
|
||||
w_new
|
||||
)
|
||||
|
||||
# Update last layer
|
||||
in_features = self.model.fc.in_features
|
||||
self.model.fc = Linear(in_features, out_features=NUM_CLASSES)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Model type '{self.hparams['classification_model']}' is not valid."
|
||||
)
|
||||
|
||||
if "resnet" in self.hparams["classification_model"]:
|
||||
if os.path.exists(self.hparams["weights"]):
|
||||
name, state_dict = utils.extract_encoder(self.hparams["weights"])
|
||||
|
||||
if self.hparams["weights"] == "imagenet_only":
|
||||
pass
|
||||
elif self.hparams["weights"] == "imagenet_and_random":
|
||||
# save the initial imagenet weights
|
||||
w_old = torch.clone( # type: ignore[attr-defined]
|
||||
self.model.conv1.weight
|
||||
).detach()
|
||||
if self.hparams["classification_model"] != name:
|
||||
raise ValueError(
|
||||
f"Trying to load {name} weights into a "
|
||||
f"{self.hparams['classification_model']}"
|
||||
)
|
||||
|
||||
# replace the first conv layer (with random weights)
|
||||
self.model.conv1 = Conv2d(
|
||||
IN_CHANNELS,
|
||||
64,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=2,
|
||||
bias=False,
|
||||
)
|
||||
nn.init.kaiming_normal_( # type: ignore[no-untyped-call]
|
||||
self.model.conv1.weight, mode="fan_out", nonlinearity="relu"
|
||||
)
|
||||
self.model = utils.load_state_dict(self.model, state_dict)
|
||||
|
||||
w_new = torch.clone( # type: ignore[attr-defined]
|
||||
self.model.conv1.weight
|
||||
).detach()
|
||||
# graft the imagenet weights into the first 3 channels
|
||||
w_new[:, :3, :, :] = w_old
|
||||
|
||||
self.model.conv1.weight = nn.Parameter( # type: ignore[attr-defined]
|
||||
w_new
|
||||
)
|
||||
|
||||
elif self.hparams["weights"] == "random":
|
||||
self.model.conv1 = Conv2d(
|
||||
IN_CHANNELS,
|
||||
64,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=2,
|
||||
bias=False,
|
||||
)
|
||||
nn.init.kaiming_normal_( # type: ignore[no-untyped-call]
|
||||
self.model.conv1.weight, mode="fan_out", nonlinearity="relu"
|
||||
)
|
||||
elif self.hparams["weights"] == "random_rgb":
|
||||
self.model.conv1 = Conv2d(
|
||||
3,
|
||||
64,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=2,
|
||||
bias=False,
|
||||
)
|
||||
nn.init.kaiming_normal_( # type: ignore[no-untyped-call]
|
||||
self.model.conv1.weight, mode="fan_out", nonlinearity="relu"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Weight type '{self.hparams['weights']}' is not valid."
|
||||
)
|
||||
else:
|
||||
pass # stub for initializing the weights of other models
|
||||
|
||||
|
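
The refactored ``config_task`` above keeps the pretrained ImageNet RGB filters and randomly initializes the extra input channels of the first convolution. A self-contained sketch of that grafting step; the channel count and backbone are assumptions for illustration:

import torch
import torch.nn as nn
import torchvision

in_channels = 10  # e.g. multispectral input; assumption for illustration
model = torchvision.models.resnet18(pretrained=True)

w_old = torch.clone(model.conv1.weight).detach()  # (64, 3, 7, 7) ImageNet RGB filters

model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=1, padding=2, bias=False)
nn.init.kaiming_normal_(model.conv1.weight, mode="fan_out", nonlinearity="relu")

w_new = torch.clone(model.conv1.weight).detach()  # (64, 10, 7, 7) randomly initialized filters
w_new[:, :3, :, :] = w_old                        # graft RGB filters into the first 3 channels
model.conv1.weight = nn.Parameter(w_new)
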
@ -351,7 +334,7 @@ class So2SatDataModule(pl.LightningDataModule):
|
|||
root_dir: str,
|
||||
batch_size: int = 64,
|
||||
num_workers: int = 4,
|
||||
weights: str = "random",
|
||||
bands: str = "rgb",
|
||||
unsupervised_mode: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
@ -361,8 +344,7 @@ class So2SatDataModule(pl.LightningDataModule):
|
|||
root_dir: The ``root`` argument to pass to the So2Sat Dataset classes
|
||||
batch_size: The batch size to use in all created DataLoaders
|
||||
num_workers: The number of workers to use in all created DataLoaders
|
||||
weights: Either "random", "imagenet_only", "imagenet_and_random", or
|
||||
"random_rgb"
|
||||
bands: Either "rgb" or "s2"
|
||||
unsupervised_mode: Makes the train dataloader return imagery from the train,
|
||||
val, and test sets
|
||||
"""
|
||||
|
@ -370,7 +352,7 @@ class So2SatDataModule(pl.LightningDataModule):
|
|||
self.root_dir = root_dir
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
self.weights = weights
|
||||
self.bands = bands
|
||||
self.unsupervised_mode = unsupervised_mode
|
||||
|
||||
self.transforms = K.AugmentationSequential(
|
||||
|
@ -386,7 +368,7 @@ class So2SatDataModule(pl.LightningDataModule):
|
|||
sample["image"] = sample["image"].float()
|
||||
sample["image"] = sample["image"][self.reindex_to_rgb_first, :, :]
|
||||
|
||||
if self.weights == "imagenet_only" or self.weights == "random_rgb":
|
||||
if self.bands == "rgb":
|
||||
sample["image"] = sample["image"][:3, :, :]
|
||||
|
||||
return sample
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Common trainer utilities."""
|
||||
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn.modules import Module
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
Module.__module__ = "nn.Module"
|
||||
|
||||
|
||||
def extract_encoder(path: str) -> Tuple[str, Dict[str, Tensor]]:
|
||||
"""Extracts an encoder from a pytorch lightning checkpoint file.
|
||||
|
||||
Args:
|
||||
path: path to checkpoint file (.ckpt)
|
||||
|
||||
Returns:
|
||||
name: str representing model name (generally a torchvision resnet model)
|
||||
state_dict: dict of layer names and weight tensors
|
||||
|
||||
Raises:
|
||||
ValueError: if 'classification_model' or 'encoder' not in
|
||||
checkpoint['hyper_parameters']
|
||||
"""
|
||||
checkpoint = torch.load( # type: ignore[no-untyped-call]
|
||||
path, map_location=torch.device("cpu") # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
if "classification_model" in checkpoint["hyper_parameters"]:
|
||||
name = checkpoint["hyper_parameters"]["classification_model"]
|
||||
state_dict = checkpoint["state_dict"]
|
||||
state_dict = OrderedDict({k: v for k, v in state_dict.items() if "model." in k})
|
||||
state_dict = OrderedDict(
|
||||
{k.replace("model.", ""): v for k, v in state_dict.items()}
|
||||
)
|
||||
elif "encoder" in checkpoint["hyper_parameters"]:
|
||||
name = checkpoint["hyper_parameters"]["encoder"]
|
||||
state_dict = checkpoint["state_dict"]
|
||||
state_dict = OrderedDict(
|
||||
{k: v for k, v in state_dict.items() if "model.encoder.model" in k}
|
||||
)
|
||||
state_dict = OrderedDict(
|
||||
{k.replace("model.encoder.model.", ""): v for k, v in state_dict.items()}
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"""Unknown checkpoint task. Only encoder or classification_model"""
|
||||
"""extraction is supported"""
|
||||
)
|
||||
|
||||
return name, state_dict
|
||||
|
||||
|
||||
def load_state_dict(model: Module, state_dict: Dict[str, Tensor]) -> Module:
|
||||
"""Load pretrained resnet weights to a model.
|
||||
|
||||
Args:
|
||||
model: model to load the pretrained weights to
|
||||
state_dict: dict containing tensor parameters
|
||||
|
||||
Returns:
|
||||
the model with pretrained weights
|
||||
|
||||
Warns:
|
||||
If input channels in model != pretrained model input channels
|
||||
If num output classes in model != pretrained model num classes
|
||||
"""
|
||||
in_channels = model.conv1.in_channels # type: ignore[union-attr]
|
||||
expected_in_channels = state_dict["conv1.weight"].shape[1]
|
||||
num_classes = model.fc.out_features # type: ignore[union-attr]
|
||||
expected_num_classes = state_dict["fc.weight"].shape[0]
|
||||
|
||||
if in_channels != expected_in_channels:
|
||||
warnings.warn(
|
||||
f"""input channels {in_channels} != input channels in pretrained"""
|
||||
"""model {expected_in_channels}. Overriding with new input channels"""
|
||||
)
|
||||
del state_dict["conv1.weight"]
|
||||
|
||||
if num_classes != expected_num_classes:
|
||||
warnings.warn(
|
||||
f"""num classes {num_classes} != num classes in pretrained model"""
|
||||
"""{expected_num_classes}. Overriding with new num classes"""
|
||||
)
|
||||
del state_dict["fc.weight"], state_dict["fc.bias"]
|
||||
|
||||
model.load_state_dict(state_dict, strict=False) # type: ignore[arg-type]
|
||||
|
||||
return model
|
|
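
A sketch of how the two helpers are meant to be combined; the checkpoint path is a placeholder and the returned ``name`` is assumed to be a torchvision ResNet variant:

import torchvision

from torchgeo.trainers import utils

name, state_dict = utils.extract_encoder("checkpoints/pretrained.ckpt")  # placeholder path
model = getattr(torchvision.models, name)(pretrained=False)              # e.g. "resnet18"
model = utils.load_state_dict(model, state_dict)
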
@@ -4,12 +4,7 @@

"""TorchGeo transforms."""

from .indices import AppendNDBI, AppendNDSI, AppendNDVI, AppendNDWI
from .transforms import (
    AugmentationSequential,
    Identity,
    RandomHorizontalFlip,
    RandomVerticalFlip,
)
from .transforms import AugmentationSequential

__all__ = (
    "AppendNDBI",

@@ -17,9 +12,6 @@ __all__ = (
    "AppendNDVI",
    "AppendNDWI",
    "AugmentationSequential",
    "Identity",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
)

# https://stackoverflow.com/questions/40018681
@ -15,91 +15,6 @@ from torch.nn import Module # type: ignore[attr-defined]
|
|||
Module.__module__ = "torch.nn"
|
||||
|
||||
|
||||
class RandomHorizontalFlip(Module): # type: ignore[misc,name-defined]
|
||||
"""Horizontally flip the given sample randomly with a given probability."""
|
||||
|
||||
def __init__(self, p: float = 0.5) -> None:
|
||||
"""Initialize a new transform instance.
|
||||
|
||||
Args:
|
||||
p: probability of the sample being flipped
|
||||
"""
|
||||
super().__init__()
|
||||
self.p = p
|
||||
|
||||
def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
"""Randomly flip the image and target tensors.
|
||||
|
||||
Args:
|
||||
sample: a single data sample
|
||||
|
||||
Returns:
|
||||
a possibly flipped sample
|
||||
"""
|
||||
if torch.rand(1) < self.p:
|
||||
if "image" in sample:
|
||||
sample["image"] = sample["image"].flip(-1)
|
||||
|
||||
if "boxes" in sample:
|
||||
height, width = sample["image"].shape[-2:]
|
||||
sample["boxes"][:, [0, 2]] = width - sample["boxes"][:, [2, 0]]
|
||||
|
||||
if "mask" in sample:
|
||||
sample["mask"] = sample["mask"].flip(-1)
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class RandomVerticalFlip(Module): # type: ignore[misc,name-defined]
|
||||
"""Vertically flip the given sample randomly with a given probability."""
|
||||
|
||||
def __init__(self, p: float = 0.5) -> None:
|
||||
"""Initialize a new transform instance.
|
||||
|
||||
Args:
|
||||
p: probability of the sample being flipped
|
||||
"""
|
||||
super().__init__()
|
||||
self.p = p
|
||||
|
||||
def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
"""Randomly flip the image and target tensors.
|
||||
|
||||
Args:
|
||||
sample: a single data sample
|
||||
|
||||
Returns:
|
||||
a possibly flipped sample
|
||||
"""
|
||||
if torch.rand(1) < self.p:
|
||||
if "image" in sample:
|
||||
sample["image"] = sample["image"].flip(-2)
|
||||
|
||||
if "boxes" in sample:
|
||||
height, width = sample["image"].shape[-2:]
|
||||
sample["boxes"][:, [1, 3]] = height - sample["boxes"][:, [3, 1]]
|
||||
|
||||
if "mask" in sample:
|
||||
sample["mask"] = sample["mask"].flip(-2)
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class Identity(Module): # type: ignore[misc,name-defined]
|
||||
"""Identity function used for testing purposes."""
|
||||
|
||||
def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
"""Do nothing.
|
||||
|
||||
Args:
|
||||
sample: the input
|
||||
|
||||
Returns:
|
||||
the unchanged input
|
||||
"""
|
||||
return sample
|
||||
|
||||
|
||||
class AugmentationSequential(Module): # type: ignore[misc]
|
||||
"""Wrapper around kornia AugmentationSequential to handle input dicts."""
|
||||
|
||||
|
|
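
The hunk ends before the body of ``AugmentationSequential``. A minimal sketch of what a dict-aware wrapper around ``kornia.augmentation.AugmentationSequential`` typically looks like; this is illustrative only, not the code from this commit:

from typing import Dict, List

import kornia.augmentation as K
from torch import Tensor
from torch.nn import Module


class DictAugmentationSequential(Module):
    """Illustrative wrapper: apply Kornia augmentations to the tensors in a sample dict."""

    def __init__(self, *augs: Module, data_keys: List[str]) -> None:
        super().__init__()
        self.data_keys = data_keys
        # Kornia uses "input" where torchgeo samples use "image"
        keys = ["input" if key == "image" else key for key in data_keys]
        self.augs = K.AugmentationSequential(*augs, data_keys=keys)

    def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
        inputs = [sample[key] for key in self.data_keys]
        outputs = self.augs(*inputs)
        outputs = outputs if isinstance(outputs, list) else [outputs]
        for key, value in zip(self.data_keys, outputs):
            sample[key] = value
        return sample
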
3 train.py

@@ -23,6 +23,8 @@ from torchgeo.trainers import (
    LandcoverAISegmentationTask,
    NAIPChesapeakeDataModule,
    NAIPChesapeakeSegmentationTask,
    RESISC45ClassificationTask,
    RESISC45DataModule,
    SEN12MSDataModule,
    SEN12MSSegmentationTask,
    So2SatClassificationTask,

@@ -39,6 +41,7 @@ TASK_TO_MODULES_MAPPING: Dict[
    "cyclone": (CycloneSimpleRegressionTask, CycloneDataModule),
    "landcoverai": (LandcoverAISegmentationTask, LandcoverAIDataModule),
    "naipchesapeake": (NAIPChesapeakeSegmentationTask, NAIPChesapeakeDataModule),
    "resisc45": (RESISC45ClassificationTask, RESISC45DataModule),
    "sen12ms": (SEN12MSSegmentationTask, SEN12MSDataModule),
    "so2sat": (So2SatClassificationTask, So2SatDataModule),
}
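
The mapping lets the training script resolve a task name from the experiment config into its LightningModule / LightningDataModule pair. A short sketch of that lookup, with the task name hard-coded for illustration:

task_name = "resisc45"  # would normally come from the experiment config
task_class, datamodule_class = TASK_TO_MODULES_MAPPING[task_name]
# task_class and datamodule_class are then instantiated with the config's keyword arguments
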