Add empirical sampling mode to the RCF model (#1339)

* Initial commit

* Fix practically all the problems

* Add docs

* Bruh

* np.typing in quotes because that makes more sense

* make really sure that mosaiks does stuff

* Consolidating RCFs

* Work

* formatting

* coverage

* Update torchgeo/models/rcf.py

* Minor style changes

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Caleb Robinson 2023-09-29 06:55:11 -07:00 коммит произвёл GitHub
Родитель c51014c656
Коммит 51ffb698ee
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файла: 130 добавлений и 17 удалений

Просмотреть файл

@ -1,25 +1,28 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
import torch
from torchgeo.datasets import EuroSAT
from torchgeo.models import RCF
class TestRCF:
def test_in_channels(self) -> None:
    """Check that RCF accepts a matching channel count and rejects a mismatch."""
    # A model built for 5 channels runs on a 5-channel batch.
    model = RCF(in_channels=5, features=4, kernel_size=3, mode="gaussian")
    x = torch.randn(2, 5, 64, 64)
    model(x)

    # A model built for 3 channels must refuse the same 5-channel batch.
    model = RCF(in_channels=3, features=4, kernel_size=3, mode="gaussian")
    match = "to have 3 channels, but got 5 channels instead"
    with pytest.raises(RuntimeError, match=match):
        model(x)
def test_num_features(self) -> None:
model = RCF(in_channels=5, features=4, kernel_size=3)
model = RCF(in_channels=5, features=4, kernel_size=3, mode="gaussian")
x = torch.randn(2, 5, 64, 64)
y = model(x)
assert y.shape[1] == 4
@ -29,14 +32,27 @@ class TestRCF:
assert y.shape[0] == 4
def test_untrainable(self) -> None:
    """Check that the model exposes no parameters at all."""
    model = RCF(in_channels=5, features=4, kernel_size=3, mode="gaussian")
    assert len(list(model.parameters())) == 0
def test_biases(self) -> None:
    """Check that every entry of the bias buffer equals the requested bias."""
    model = RCF(features=24, bias=10, mode="gaussian")
    assert torch.all(model.biases == 10)
def test_seed(self) -> None:
    """Check that two models built with the same seed get the same weights."""
    weights1 = RCF(seed=1, mode="gaussian").weights
    weights2 = RCF(seed=1, mode="gaussian").weights
    assert torch.allclose(weights1, weights2)
def test_empirical(self) -> None:
    """Build an RCF in "empirical" mode from the EuroSAT test data and run it."""
    data_root = os.path.join("tests", "data", "eurosat")
    train_ds = EuroSAT(root=data_root, bands=EuroSAT.rgb_bands, split="train")
    rcf = RCF(
        dataset=train_ds, mode="empirical", kernel_size=3, features=4, in_channels=3
    )
    # Only checks that a forward pass on a small RGB batch does not raise.
    batch = torch.randn(2, 3, 8, 8)
    rcf(batch)
def test_empirical_no_dataset(self) -> None:
    """Constructing an "empirical" RCF without a dataset must raise ValueError."""
    expected_message = "dataset must be provided when mode is 'empirical'"
    with pytest.raises(ValueError, match=expected_message):
        RCF(dataset=None, mode="empirical")

Просмотреть файл

@ -5,21 +5,33 @@
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules import Module
from ..datasets import NonGeoDataset
class RCF(Module):
"""This model extracts random convolutional features (RCFs) from its input.
RCFs are used in Multi-task Observation using Satellite Imagery & Kitchen Sinks
(MOSAIKS) method proposed in https://www.nature.com/articles/s41467-021-24638-z.
RCFs are used in the Multi-task Observation using Satellite Imagery & Kitchen Sinks
(MOSAIKS) method proposed in "A generalizable and accessible approach to machine
learning with global satellite imagery".
This class can operate in two modes, "gaussian" and "empirical". In "gaussian" mode,
the filters will be sampled from a Gaussian distribution, while in "empirical" mode,
the filters will be sampled from a dataset.
If you use this model in your research, please cite the following paper:
* https://www.nature.com/articles/s41467-021-24638-z
.. note::
This Module is *not* trainable. It is only used as a feature extractor.
This Module is *not* trainable. It is only used as a feature extractor.
"""
weights: Tensor
@ -32,6 +44,8 @@ class RCF(Module):
kernel_size: int = 3,
bias: float = -1.0,
seed: Optional[int] = None,
mode: str = "gaussian",
dataset: Optional[NonGeoDataset] = None,
) -> None:
"""Initializes the RCF model.
@ -41,21 +55,28 @@ class RCF(Module):
.. versionadded:: 0.2
The *seed* parameter.
.. versionadded:: 0.5
The *mode* and *dataset* parameters.
Args:
in_channels: number of input channels
features: number of features to compute, must be divisible by 2
kernel_size: size of the kernel used to compute the RCFs
bias: bias of the convolutional layer
seed: random seed used to initialize the convolutional layer
mode: "empirical" or "gaussian"
dataset: a NonGeoDataset to sample from when mode is "empirical"
"""
super().__init__()
assert mode in ["empirical", "gaussian"]
if mode == "empirical" and dataset is None:
raise ValueError("dataset must be provided when mode is 'empirical'")
assert features % 2 == 0
num_patches = features // 2
if seed is None:
generator = None
else:
generator = torch.Generator().manual_seed(seed)
generator = torch.Generator()
if seed:
generator = generator.manual_seed(seed)
# We register the weight and bias tensors as "buffers". This does two things:
# makes them behave correctly when we call .to(...) on the module, and makes
@ -64,7 +85,7 @@ class RCF(Module):
self.register_buffer(
"weights",
torch.randn(
features // 2,
num_patches,
in_channels,
kernel_size,
kernel_size,
@ -73,9 +94,85 @@ class RCF(Module):
),
)
self.register_buffer(
"biases", torch.zeros(features // 2, requires_grad=False) + bias
"biases", torch.zeros(num_patches, requires_grad=False) + bias
)
if mode == "empirical":
assert dataset is not None
num_channels, height, width = dataset[0]["image"].shape
assert num_channels == in_channels
patches = np.zeros(
(num_patches, num_channels, kernel_size, kernel_size), dtype=np.float32
)
idxs = torch.randint(
0, len(dataset), (num_patches,), generator=generator
).numpy()
ys = torch.randint(
0, height - kernel_size, (num_patches,), generator=generator
).numpy()
xs = torch.randint(
0, width - kernel_size, (num_patches,), generator=generator
).numpy()
for i in range(num_patches):
img = dataset[idxs[i]]["image"]
patches[i] = img[
:, ys[i] : ys[i] + kernel_size, xs[i] : xs[i] + kernel_size
]
patches = self._normalize(patches)
self.weights = torch.tensor(patches)
def _normalize(
    self,
    patches: "np.typing.NDArray[np.float32]",
    min_divisor: float = 1e-8,
    zca_bias: float = 0.001,
) -> "np.typing.NDArray[np.float32]":
    """Does ZCA whitening on a set of input patches.

    Copied from https://github.com/Global-Policy-Lab/mosaiks-paper/blob/7efb09ed455505562d6bb04c2aaa242ef59f0a82/code/mosaiks/featurization.py#L120

    .. versionadded:: 0.5

    Args:
        patches: a numpy array of size (N, C, H, W)
        min_divisor: a small number to guard against division by zero
        zca_bias: bias term for ZCA whitening

    Returns:
        a numpy array of size (N, C, H, W) containing the normalized patches
    """  # noqa: E501
    orig_shape = patches.shape
    n_patches = orig_shape[0]

    # Work on one flattened row per patch
    flat = patches.reshape(n_patches, -1)

    # Zero mean every patch (the mean is taken over each flattened row)
    flat = flat - np.mean(flat, axis=1, keepdims=True)

    # Make every patch unit norm; norms below min_divisor are replaced by 1 so
    # that (near-)constant patches are left as they are instead of being divided
    # by a value close to zero
    patch_norms = np.linalg.norm(flat, axis=1)
    patch_norms[patch_norms < min_divisor] = 1
    flat = flat / patch_norms[:, np.newaxis]

    covariance = 1.0 / n_patches * flat.T.dot(flat)

    # The covariance matrix is symmetric, so use the symmetric solver: eigh
    # returns real eigenvalues and orthonormal eigenvectors, whereas the general
    # purpose eig may return complex or non-orthogonal results when eigenvalues
    # are repeated (e.g. when there are fewer patches than patch dimensions)
    eigenvalues, eigenvectors = np.linalg.eigh(covariance)

    # zca_bias keeps the inverse square root finite for (near-)zero eigenvalues
    inv_sqrt_zca_eigs = np.diag(np.power(np.sqrt(eigenvalues + zca_bias), -1))
    global_zca = eigenvectors.dot(inv_sqrt_zca_eigs).dot(eigenvectors.T)

    patches_normalized: "np.typing.NDArray[np.float32]" = flat.dot(global_zca).dot(
        global_zca.T
    )
    return patches_normalized.reshape(orig_shape).astype("float32")
def forward(self, x: Tensor) -> Tensor:
"""Forward pass of the RCF model.