From 3cc63def02f6cbf40ab7d84f6afe13284e044651 Mon Sep 17 00:00:00 2001 From: isaac <22203655+isaaccorley@users.noreply.github.com> Date: Tue, 2 Nov 2021 10:45:38 -0500 Subject: [PATCH] BigEarthNet Trainers (#211) * add additional bigearthnet test data for train/val/test split * update bigearthnet dataset length test * add MultiLabelClassificationTask * add BigEarthNet trainer and datamodule * add bigearthnet and multilabelclassificationtask tests * mypy and format * add estimated band min/max values for normalization * softmax outputs to correctly compute metrics * update min/max stats for 100k samples * organize imports in torchgeo.trainers.__init__.py * clean up fixtures in test_tasks.py * added bigearthnet to train.py * format * move fixtures into class methods * consolidate bigearthnet fixtures * refactor tasks tests * add scope=class * style/mypy fixes * mypy fixes --- conf/bigearthnet.yaml | 18 ++ conf/task_defaults/bigearthnet.yaml | 13 ++ .../bigearthnet/BigEarthNet-S1-v1.0.tar.gz | Bin 1086 -> 1358 bytes .../bigearthnet/BigEarthNet-S2-v1.0.tar.gz | Bin 1110 -> 1762 bytes tests/datasets/test_bigearthnet.py | 2 +- tests/trainers/test_bigearthnet.py | 39 ++++ tests/trainers/test_tasks.py | 177 +++++++++++++--- torchgeo/trainers/__init__.py | 23 ++- torchgeo/trainers/bigearthnet.py | 193 ++++++++++++++++++ torchgeo/trainers/tasks.py | 122 ++++++++++- train.py | 3 + 11 files changed, 554 insertions(+), 36 deletions(-) create mode 100644 conf/bigearthnet.yaml create mode 100644 conf/task_defaults/bigearthnet.yaml create mode 100644 tests/trainers/test_bigearthnet.py create mode 100644 torchgeo/trainers/bigearthnet.py diff --git a/conf/bigearthnet.yaml b/conf/bigearthnet.yaml new file mode 100644 index 000000000..7f8b52f9d --- /dev/null +++ b/conf/bigearthnet.yaml @@ -0,0 +1,18 @@ +trainer: + gpus: 1 # single GPU training + min_epochs: 10 + max_epochs: 40 + benchmark: True + +experiment: + task: "bigearthnet" + module: + loss: "bce" + classification_model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 14 + datamodule: + batch_size: 128 + num_workers: 6 + bands: "all" diff --git a/conf/task_defaults/bigearthnet.yaml b/conf/task_defaults/bigearthnet.yaml new file mode 100644 index 000000000..723d4e7d9 --- /dev/null +++ b/conf/task_defaults/bigearthnet.yaml @@ -0,0 +1,13 @@ +experiment: + task: "bigearthnet" + module: + loss: "bce" + classification_model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: "random" + in_channels: 14 + datamodule: + batch_size: 128 + num_workers: 6 + bands: "all" diff --git a/tests/data/bigearthnet/BigEarthNet-S1-v1.0.tar.gz b/tests/data/bigearthnet/BigEarthNet-S1-v1.0.tar.gz index 9169a8a932c58a3e93789b8f209d8ff6a590043a..d9df455f105d44fc1fab8b90f739c1a585089166 100644 GIT binary patch literal 1358 zcmV-U1+n@ciwFP!000001MOT*ZzDArp0w3&Sz5FlSS}!0;fDG#p0BW6XiL#a&RT-p62{ER*G6R;`#zFm$A0~h zc z+0!1#Fr;xp7(yELA{LxucTxEj#nzDduj5U3XWxDB<@Q4t*SGgvgU^@lVDm9{Y~35>_jK#&( zcP;1`J37|T?j?<<58e4^QIwo#cTxEj#l-7>9AnqQ?%u<-gf?gWu>m<`%zt7158eOm zn$yq!hHcgO`qyQA0XKaA-`V+4TjcRwO{>gn6`q>op;AcC;)ppHJWe#~3{N%X_ZHa5l^-(>*!ANXG`|L+0+ zD+DF@-_c9_?-;=UGEnCKEELjR{ujLcy8JIfm;VKK595Ckt~>vqjyE6)P=Wsq-2wie z5v22fKpK7!yGMRP4oE`Qp2kr)>o4I3Lka#j3;b{A{BHvP%fRXK9mG-XEX(T0=<7~K zmiZ_OS1Dt(MH2r25k?w;j{*`NR96aCpGF~V`YcLg)Qp%PC)G~jH@Q*Oiz0SFL*9t! z;IvRR4SwJ``buBc$zWE%N>)L~s+04q+fkcEPkk?;QOFCc)n5OM533!8d=M=+HrzHz zyhCA9H#XhR>pQzV>Rq-v)=k{Gx?eRBdw%GV7kGBEM_AIBKPALSrp{zuk- zUhV(3#k0WwlK{|v%Jo41nJ$3;)rDmK2k}QS0RIF31OE@&1OHDG7wbRfApeUk0r+16 zQvDt$^k|y|!3l~yFKts^p>cvpi#M4DNy_{c^gT@uQA}GQZBmbf32)IS zF^ZZfB;0}o=-6-Z?IRO;3O#JGh#ag2J~`%%<}y^FcqbG?LPvzfhkjfv+2`YC$cs`I z^MYKs8me2BWvjAkRTiyEnpG}Y4Ogrt6s+`a3uV3P9J%W1s8xVzW3v6 zDFFW;{Qrs8|2W3r|0@R>^MB#*e~#7l{~cRA3nG611OHzNrf~jisO!D_e;)QVIU8o) zzaD49RN#Nm|7Hid{15t{YEXm!2k-wl|JQ*3WdNT4D%XSOzorY|e{~_5|Ka?vVo-wr z%jbV=9nSwq0i6FS*SksQe@cWN6fd5ynca!`-;D6wMiTn@xk77`Q|DD16KihJE|D^!*pK?9Wf2NCz_#e)H z&JJ?5{kx17%XzYY8^1LvOqJnspxsUH))zJDW#ug6hQ0000000000 Q0C2_eA1;maY5;fu0KvP>q5uE@ literal 1086 zcmV-E1i||siwFP!000001MQmcZ`(u|$DOv?%Ba!yg7F5VQ+S8k)W6O4hO{c$B83T& ztwln}x;RdH!?nZNE>vX_0;KSYm;42Yw`~6if5u+(C-9u(EN&8-mbQ+!J|F3F=ex&u zXP@uBPki2Hy-z8R_aAxjMoZl|RBOtnbjc`yZJHuhZBxm|BP7*Ob^Tp<*m>xvt({#*gYi-*D{4awuw@vd@t%wo zc&bi33yP{*iYbxi74_!?qw(iq7@uZ0ulzOT!t?(;#?J1;TFknOzGh@h5#yhm|1;Nr z)3zqof75_0q7-ohuKD`kXnZ8CLVR12D$7y@QcDmjx%5(&CJ>ghv5Wj@**I+qq|1KClAFQS4?Lzv+ zU@7f=x01i6e?B=GpOvwFCGF8oNm>!p}ec2Y>K7Bh<&W?OboLq&@2YWin6yHASD)f3YN>{ue;8`tQ@W=SR+g7t;=nY3*qg28+G}*P9aPzY3+} z_y3xj>OYvH{ujXUEdo@B6SI1se3c%jc{?m}!ysls5+)Hzc$)^KOWT~e)%CH}qcGTD zoU@Es>XSpS2g^sy!XV$mHV7sdJFMj*OG4k(ZI zy(nL@2mR*64HF(gL7M;du(;}mA@49~*AsWGHvZ5i&3KkI z(dXe)&rL7pKB-n;HO%1x`e6XA_nb&``ZRX;g+)u-az1M{nh=fW8O$JE9Zttl4Uy{w zo)f`Uxn90z>VBKBh=~TBm>mquZarJ{Q-~~zi>?3qy{p>)vvL2g5YYce|Nlbs|2)R% z|Cf#;#((bnKg$^T|CTAf{}J#1(El%lV|nrqfP5z3UY_H6%m3wJ%bANU&-Ah7X(wBr z_Oj(^H(Q?av*paumS=m~^2Jk#Zj#Omz8B2>7re9kw~+%X`*?fY#p`vtzt{F6$1ABneKe#`!*X6n>|^P%E1Hp6x2% z+vAp>YdxU&eIQ7=+CMaKx8W9#Lh)=ocPM4VNZ@qH{^!iA^f3Lr=!Na%*1p`!=d0l8 zUeP|pd^xNURJ+&oH|D&EIM{fQynM$>)>uY;OqcQ~KO3qC?m5G1LCq?(+iidCj}{O^L&vV0mB-+s^^T9BR0<&d9hZ?+)wD45(xSyH*G9i z)plnzZ7^44-1%4?9c#y{h;m%k80zxmXn}{Ti2P#Y-{7Dg|1(kMkweSe)1plWy%QeD zuKTe@$pMu;%d!n@XyabxC;P_wO@{|d1AOqF+W%b8L3cBxF-&IRCs-8Q4#j|b+ymKUX9?Sr=FzDO7L8%{tlpNDkrG_?zy>M` zqx{h#+yIt$zakp8WpaJZk%=l-0A4h>!Nk9!xhQcLWmiDC&hgUkk5jcEOX24=}ex`vH#Hq1T zqF5uC`JW9-{)5-=>1&HzEa)lLj8)F4br^Md0jU`)*TQ=5K9Y>O9i^p?!C%m;o3R-7 zoP%2>e2B={iwr0rqU*`1`y!%mAF`WsjO*MUG0QLx)(y4s*Ld=zPTW^>!Ldq+*cMP) z4J(jQ<_0ar%wS+1hkKHuD5x}=G6+jbtcglMw07AA2~p(GM_uWPA;J3ZEul%i0+|$l z&_z3zF3FNmYu{T#%Y&G0=o6Y}C7cpMr1_|dp9~97*3ePFU-23-0GG zX?NoDTmmLo*SD(oUeRb5*wu}EwMn`LTCK{Urjfdv02?re;~%6ytJ3aN4DKJD%7Sas ztO@E2_Chu?U?G9+jB_YF0hyiCdNC!wHo5(r3@J&DnNBBDz9Wfm)x8Ul{^d|EojE%7xoioJj_(`2xGJu5AHCvKK8~bU9{L`(Puts*}uKD%zIc! zd9%#HCnT|haI>9grF8qVcDNdzUi+2vXR$7e^8Lsg^~+Trug>7NY}dvV&aG^{{J>U5 zej?R4amfQdtK`@uLA?eCX3MA6ki+I~=vPs#ArZX~XF0(rR)=G{c$RZi^Y9~GDcqk& zwH1?TErEK@L;(DFv@q0qGURpgK)dD81UxEBz}k-S_S zi0?OCY_XLKgS6`(#7Pz19tA@M=^BZ$?5`Wl6Ncwes_dAK@Jdi6hiMLa0YdL6T4LS@ zRa`l&Se6lyGGcWJxUct?@7UzluuFsRfT7&~DL>z~9UgAg zDjA_Jn<|cpAH5Cc^a1Za`A(~JDxqe40TPUgueI1+MrD zNp|%uA9LT-&47gWj1dBwO>2(9OIg=r8TOosT$CoYZc`fN`f#C;mAQ8Fc30EFlh+kt zoy{v9+Oz3prn&dVV*T9@iGE4%I>clATDD1_uXq~dj+#)#@`e-KrY3P3$0y9KH6czH zcR+UuBZ3Gl_1jY5`!yg|(|&^?3>=FB1dBaE{5HsdiyXZx%FBoON}2wV!%jR9S^Q3g ztCls`Q%hPbTd2i~nU!MZX%KAv$(dpu-cAMeev)|GfaWTG02dDK-f1U*CqOFjcEsJ8 zSp4z|5WR-Dp5+RBnjl-QLEG*6^uZUj*{I9V4Cf4Y+40K?H5$_fuz28Xfa2%?%MOVE z%!V!IwxV?r>IE@H;5!$6QyPd;*Q)TT2)HUkM(i4;kZI=>qk%00d5{=4{}UO%3G!5{kEh;C%`FVY`s A1ONa4 diff --git a/tests/datasets/test_bigearthnet.py b/tests/datasets/test_bigearthnet.py index 89416828f..f30d7c7d1 100644 --- a/tests/datasets/test_bigearthnet.py +++ b/tests/datasets/test_bigearthnet.py @@ -71,7 +71,7 @@ class TestBigEarthNet: assert x["image"].shape == (12, 120, 120) def test_len(self, dataset: BigEarthNet) -> None: - assert len(dataset) == 2 + assert len(dataset) == 4 def test_already_downloaded(self, dataset: BigEarthNet, tmp_path: Path) -> None: BigEarthNet(root=str(tmp_path), bands=dataset.bands, download=True) diff --git a/tests/trainers/test_bigearthnet.py b/tests/trainers/test_bigearthnet.py new file mode 100644 index 000000000..add7d8861 --- /dev/null +++ b/tests/trainers/test_bigearthnet.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest +from _pytest.fixtures import SubRequest + +from torchgeo.trainers import BigEarthNetDataModule + + +class TestBigEarthNetDataModule: + @pytest.fixture(scope="class", params=zip(["s1", "s2", "all"], [True, True, False])) + def datamodule(self, request: SubRequest) -> BigEarthNetDataModule: + bands, unsupervised_mode = request.param + root = os.path.join("tests", "data", "bigearthnet") + batch_size = 1 + num_workers = 0 + dm = BigEarthNetDataModule( + root, + bands, + batch_size, + num_workers, + unsupervised_mode, + val_split_pct=0.3, + test_split_pct=0.3, + ) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: BigEarthNetDataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: BigEarthNetDataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: BigEarthNetDataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/trainers/test_tasks.py b/tests/trainers/test_tasks.py index 0f1973cb0..869cb230c 100644 --- a/tests/trainers/test_tasks.py +++ b/tests/trainers/test_tasks.py @@ -2,50 +2,114 @@ # Licensed under the MIT License. import os -from typing import Any, Dict, Generator, Tuple, cast +from typing import Any, Dict, Generator, Optional, cast import pytest +import pytorch_lightning as pl +import torch +import torch.nn.functional as F from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf +from torch import Tensor +from torch.utils.data import DataLoader, Dataset, TensorDataset from torchgeo.trainers import ( ClassificationTask, CycloneDataModule, + MultiLabelClassificationTask, RegressionTask, - So2SatDataModule, ) from .test_utils import mocked_log -@pytest.fixture(scope="module", params=[("rgb", 3), ("s2", 10)]) -def bands(request: SubRequest) -> Tuple[str, int]: - return cast(Tuple[str, int], request.param) +class DummyDataset(Dataset): # type: ignore[type-arg] + def __init__(self, num_channels: int, num_classes: int, multilabel: bool) -> None: + x = torch.randn(10, num_channels, 128, 128) # (b, c, h, w) + y = torch.randint( # type: ignore[attr-defined] + 0, num_classes, size=(10,) + ) # (b,) + + if multilabel: + y = F.one_hot(y, num_classes=num_classes) # (b, classes) + + self.dataset = TensorDataset(x, y) + + def __len__(self) -> int: + return len(self.dataset) + + def __getitem__(self, idx: int) -> Dict[str, Tensor]: + x, y = self.dataset[idx] + sample = {"image": x, "label": y} + return sample -@pytest.fixture(scope="module", params=[True, False]) -def datamodule(bands: Tuple[str, int], request: SubRequest) -> So2SatDataModule: - band_set = bands[0] - unsupervised_mode = request.param - root = os.path.join("tests", "data", "so2sat") - batch_size = 2 - num_workers = 0 - dm = So2SatDataModule(root, batch_size, num_workers, band_set, unsupervised_mode) - dm.prepare_data() - dm.setup() - return dm +class DummyDataModule(pl.LightningDataModule): + def __init__( + self, + num_channels: int, + num_classes: int, + multilabel: bool, + batch_size: int = 1, + num_workers: int = 0, + ) -> None: + super().__init__() # type: ignore[no-untyped-call] + self.num_channels = num_channels + self.num_classes = num_classes + self.multilabel = multilabel + self.batch_size = batch_size + self.num_workers = num_workers + + def setup(self, stage: Optional[str] = None) -> None: + self.dataset = DummyDataset( + num_channels=self.num_channels, + num_classes=self.num_classes, + multilabel=self.multilabel, + ) + + def train_dataloader(self) -> DataLoader: # type: ignore[type-arg] + return DataLoader( + self.dataset, batch_size=self.batch_size, num_workers=self.num_workers + ) + + def val_dataloader(self) -> DataLoader: # type: ignore[type-arg] + return DataLoader( + self.dataset, batch_size=self.batch_size, num_workers=self.num_workers + ) + + def test_dataloader(self) -> DataLoader: # type: ignore[type-arg] + return DataLoader( + self.dataset, batch_size=self.batch_size, num_workers=self.num_workers + ) class TestClassificationTask: + @pytest.fixture(scope="class", params=[2, 3, 5]) + def datamodule(self, request: SubRequest) -> DummyDataModule: + dm = DummyDataModule( + num_channels=request.param, + num_classes=45, + multilabel=False, + batch_size=2, + num_workers=0, + ) + dm.prepare_data() + dm.setup() + return dm + @pytest.fixture( - params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"]) + scope="class", + params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"]), ) - def config(self, request: SubRequest, bands: Tuple[str, int]) -> Dict[str, Any]: - task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "so2sat.yaml")) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - task_args["in_channels"] = bands[1] + def config( + self, request: SubRequest, datamodule: DummyDataModule + ) -> Dict[str, Any]: + task_args = {} + task_args["classification_model"] = "resnet18" + task_args["learning_rate"] = 3e-4 # type: ignore[assignment] + task_args["learning_rate_schedule_patience"] = 6 # type: ignore[assignment] + task_args["in_channels"] = datamodule.num_channels # type: ignore[assignment] loss, weights = request.param task_args["loss"] = loss task_args["weights"] = weights @@ -65,20 +129,20 @@ class TestClassificationTask: assert "lr_scheduler" in out def test_training( - self, datamodule: So2SatDataModule, task: ClassificationTask + self, datamodule: DummyDataModule, task: ClassificationTask ) -> None: batch = next(iter(datamodule.train_dataloader())) task.training_step(batch, 0) task.training_epoch_end(0) def test_validation( - self, datamodule: So2SatDataModule, task: ClassificationTask + self, datamodule: DummyDataModule, task: ClassificationTask ) -> None: batch = next(iter(datamodule.val_dataloader())) task.validation_step(batch, 0) task.validation_epoch_end(0) - def test_test(self, datamodule: So2SatDataModule, task: ClassificationTask) -> None: + def test_test(self, datamodule: DummyDataModule, task: ClassificationTask) -> None: batch = next(iter(datamodule.test_dataloader())) task.test_step(batch, 0) task.test_epoch_end(0) @@ -99,6 +163,7 @@ class TestClassificationTask: def test_invalid_loss(self, config: Dict[str, Any]) -> None: config["loss"] = "invalid_loss" + config["classification_model"] = "resnet18" error_message = "Loss type 'invalid_loss' is not valid." with pytest.raises(ValueError, match=error_message): ClassificationTask(**config) @@ -117,6 +182,68 @@ class TestClassificationTask: ClassificationTask(**config) +class TestMultiLabelClassificationTask: + @pytest.fixture(scope="class") + def datamodule(self, request: SubRequest) -> DummyDataModule: + dm = DummyDataModule( + num_channels=3, + num_classes=43, + multilabel=True, + batch_size=2, + num_workers=0, + ) + dm.prepare_data() + dm.setup() + return dm + + @pytest.fixture(scope="class", params=zip(["bce", "bce"], ["imagenet", "random"])) + def config( + self, datamodule: DummyDataModule, request: SubRequest + ) -> Dict[str, Any]: + task_args = {} + task_args["classification_model"] = "resnet18" + task_args["learning_rate"] = 3e-4 # type: ignore[assignment] + task_args["learning_rate_schedule_patience"] = 6 # type: ignore[assignment] + task_args["in_channels"] = datamodule.num_channels # type: ignore[assignment] + loss, weights = request.param + task_args["loss"] = loss + task_args["weights"] = weights + return task_args + + @pytest.fixture + def task( + self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] + ) -> MultiLabelClassificationTask: + task = MultiLabelClassificationTask(**config) + monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] + return task + + def test_training( + self, datamodule: DummyDataModule, task: ClassificationTask + ) -> None: + batch = next(iter(datamodule.train_dataloader())) + task.training_step(batch, 0) + task.training_epoch_end(0) + + def test_validation( + self, datamodule: DummyDataModule, task: ClassificationTask + ) -> None: + batch = next(iter(datamodule.val_dataloader())) + task.validation_step(batch, 0) + task.validation_epoch_end(0) + + def test_test(self, datamodule: DummyDataModule, task: ClassificationTask) -> None: + batch = next(iter(datamodule.test_dataloader())) + task.test_step(batch, 0) + task.test_epoch_end(0) + + def test_invalid_loss(self, config: Dict[str, Any]) -> None: + config["loss"] = "invalid_loss" + error_message = "Loss type 'invalid_loss' is not valid." + with pytest.raises(ValueError, match=error_message): + MultiLabelClassificationTask(**config) + + class TestRegressionTask: @pytest.fixture(scope="class") def datamodule(self) -> CycloneDataModule: diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py index a07eee5ef..20148f752 100644 --- a/torchgeo/trainers/__init__.py +++ b/torchgeo/trainers/__init__.py @@ -3,6 +3,7 @@ """TorchGeo trainers.""" +from .bigearthnet import BigEarthNetClassificationTask, BigEarthNetDataModule from .byol import BYOLTask from .chesapeake import ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask from .cyclone import CycloneDataModule @@ -11,29 +12,35 @@ from .naipchesapeake import NAIPChesapeakeDataModule, NAIPChesapeakeSegmentation from .resisc45 import RESISC45ClassificationTask, RESISC45DataModule from .sen12ms import SEN12MSDataModule, SEN12MSSegmentationTask from .so2sat import So2SatClassificationTask, So2SatDataModule -from .tasks import ClassificationTask, RegressionTask +from .tasks import ClassificationTask, MultiLabelClassificationTask, RegressionTask from .ucmerced import UCMercedClassificationTask, UCMercedDataModule __all__ = ( # Tasks - "ClassificationTask", - "RegressionTask", - # Trainers + "BigEarthNetClassificationTask", "BYOLTask", "ChesapeakeCVPRSegmentationTask", "ChesapeakeCVPRDataModule", + "ClassificationTask", "CycloneDataModule", "LandcoverAIDataModule", "LandcoverAISegmentationTask", - "NAIPChesapeakeDataModule", + "MultiLabelClassificationTask", "NAIPChesapeakeSegmentationTask", "RESISC45ClassificationTask", - "RESISC45DataModule", - "SEN12MSDataModule", + "RegressionTask", "SEN12MSSegmentationTask", - "So2SatDataModule", "So2SatClassificationTask", "UCMercedClassificationTask", + # DataModules + "BigEarthNetDataModule", + "ChesapeakeCVPRDataModule", + "CycloneDataModule", + "LandcoverAIDataModule", + "NAIPChesapeakeDataModule", + "RESISC45DataModule", + "SEN12MSDataModule", + "So2SatDataModule", "UCMercedDataModule", ) diff --git a/torchgeo/trainers/bigearthnet.py b/torchgeo/trainers/bigearthnet.py new file mode 100644 index 000000000..a0abaadd2 --- /dev/null +++ b/torchgeo/trainers/bigearthnet.py @@ -0,0 +1,193 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""BigEarthNet trainer.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +import torch +from torch.utils.data import DataLoader +from torchvision.transforms import Compose + +from ..datasets import BigEarthNet +from ..datasets.utils import dataset_split +from .tasks import MultiLabelClassificationTask + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class BigEarthNetClassificationTask(MultiLabelClassificationTask): + """LightningModule for training models on the BigEarthNet Dataset.""" + + num_classes = 43 + + +class BigEarthNetDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the BigEarthNet dataset. + + Uses the train/val/test splits from the dataset. + """ + + # (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12) + # min/max band statistics computed on 100k random samples + band_mins_raw = torch.tensor( # type: ignore[attr-defined] + [-70.0, -72.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0] + ) + band_maxs_raw = torch.tensor( # type: ignore[attr-defined] + [ + 31.0, + 35.0, + 18556.0, + 20528.0, + 18976.0, + 17874.0, + 16611.0, + 16512.0, + 16394.0, + 16672.0, + 16141.0, + 16097.0, + 15336.0, + 15203.0, + ] + ) + + # min/max band statistics computed by percentile clipping the + # above to samples to [2, 98] + band_mins = torch.tensor( # type: ignore[attr-defined] + [-48.0, -42.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + ) + band_maxs = torch.tensor( # type: ignore[attr-defined] + [ + 6.0, + 16.0, + 9859.0, + 12872.0, + 13163.0, + 14445.0, + 12477.0, + 12563.0, + 12289.0, + 15596.0, + 12183.0, + 9458.0, + 5897.0, + 5544.0, + ] + ) + + def __init__( + self, + root_dir: str, + bands: str = "all", + batch_size: int = 64, + num_workers: int = 4, + unsupervised_mode: bool = False, + val_split_pct: float = 0.2, + test_split_pct: float = 0.2, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for BigEarthNet based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the BigEarthNet Dataset classes + bands: load Sentinel-1 bands, Sentinel-2, or both. one of {s1, s2, all} + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + unsupervised_mode: Makes the train dataloader return imagery from the train, + val, and test sets + val_split_pct: What percentage of the dataset to use as a validation set + test_split_pct: What percentage of the dataset to use as a test set + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.bands = bands + self.batch_size = batch_size + self.num_workers = num_workers + self.unsupervised_mode = unsupervised_mode + + self.val_split_pct = val_split_pct + self.test_split_pct = test_split_pct + + if bands == "all": + self.mins = self.band_mins[:, None, None] + self.maxs = self.band_maxs[:, None, None] + elif bands == "s1": + self.mins = self.band_mins[:2, None, None] + self.maxs = self.band_maxs[:2, None, None] + else: + self.mins = self.band_mins[2:, None, None] + self.maxs = self.band_maxs[2:, None, None] + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset.""" + sample["image"] = sample["image"].float() + sample["image"] = (sample["image"] - self.mins) / (self.maxs - self.mins) + sample["image"] = torch.clip( # type: ignore[attr-defined] + sample["image"], min=0.0, max=1.0 + ) + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + BigEarthNet(self.root_dir, bands=self.bands, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + """ + transforms = Compose([self.preprocess]) + + if not self.unsupervised_mode: + + dataset = BigEarthNet( + self.root_dir, bands=self.bands, transforms=transforms + ) + self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( + dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct + ) + else: + self.train_dataset = BigEarthNet( # type: ignore[assignment] + self.root_dir, bands=self.bands, transforms=transforms + ) + self.val_dataset, self.test_dataset = None, None # type: ignore[assignment] + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training.""" + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation.""" + if self.unsupervised_mode or self.val_split_pct == 0: + return self.train_dataloader() + else: + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing.""" + if self.unsupervised_mode or self.test_split_pct == 0: + return self.train_dataloader() + else: + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/trainers/tasks.py b/torchgeo/trainers/tasks.py index 3554ce832..037633445 100644 --- a/torchgeo/trainers/tasks.py +++ b/torchgeo/trainers/tasks.py @@ -57,7 +57,7 @@ class ClassificationTask(pl.LightningModule): # Update first layer if in_channels != 3: - w_old = None + w_old = torch.empty(0) # type: ignore[attr-defined] if pretrained: w_old = torch.clone( # type: ignore[attr-defined] self.model.conv1.weight @@ -75,7 +75,11 @@ class ClassificationTask(pl.LightningModule): w_new = torch.clone( # type: ignore[attr-defined] self.model.conv1.weight ).detach() - w_new[:, :3, :, :] = w_old + if in_channels > 3: + w_new[:, :3, :, :] = w_old + else: + w_old = w_old[:, :in_channels, :, :] + w_new[:, :in_channels, :, :] = w_old self.model.conv1.weight = nn.Parameter( # type: ignore[attr-defined] # noqa: E501 w_new ) @@ -266,6 +270,120 @@ class ClassificationTask(pl.LightningModule): } +class MultiLabelClassificationTask(ClassificationTask): + """Abstract base class for multi label image classification LightningModules.""" + + #: number of classes in dataset + num_classes: int = 43 + + def config_task(self) -> None: + """Configures the task based on kwargs parameters passed to the constructor.""" + self.config_model() + + if self.hparams["loss"] == "bce": + self.loss = nn.BCEWithLogitsLoss() # type: ignore[attr-defined] + else: + raise ValueError(f"Loss type '{self.hparams['loss']}' is not valid.") + + def __init__(self, **kwargs: Any) -> None: + """Initialize the LightningModule with a model and loss function. + + Keyword Args: + classification_model: Name of the classification model use + loss: Name of the loss function + weights: Either "random", "imagenet_only", "imagenet_and_random", or + "random_rgb" + """ + super().__init__(**kwargs) + self.save_hyperparameters() # creates `self.hparams` from kwargs + + self.config_task() + + self.train_metrics = MetricCollection( + { + "OverallAccuracy": Accuracy( + num_classes=self.num_classes, average="micro", multiclass=False + ), + "AverageAccuracy": Accuracy( + num_classes=self.num_classes, average="macro", multiclass=False + ), + "F1Score": FBeta( + num_classes=self.num_classes, + beta=1.0, + average="micro", + multiclass=False, + ), + }, + prefix="train_", + ) + self.val_metrics = self.train_metrics.clone(prefix="val_") + self.test_metrics = self.train_metrics.clone(prefix="test_") + + def training_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> Tensor: + """Training step. + + Args: + batch: Current batch + batch_idx: Index of current batch + Returns: + training loss + """ + x = batch["image"] + y = batch["label"] + y_hat = self.forward(x) + y_hat_hard = torch.softmax(y_hat, dim=-1) # type: ignore[attr-defined] + + loss = self.loss(y_hat, y.to(torch.float)) # type: ignore[attr-defined] + + # by default, the train step logs every `log_every_n_steps` steps where + # `log_every_n_steps` is a parameter to the `Trainer` object + self.log("train_loss", loss, on_step=True, on_epoch=False) + self.train_metrics(y_hat_hard, y) + + return cast(Tensor, loss) + + def validation_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> None: + """Validation step. + + Args: + batch: Current batch + batch_idx: Index of current batch + """ + x = batch["image"] + y = batch["label"] + y_hat = self.forward(x) + y_hat_hard = torch.softmax(y_hat, dim=-1) # type: ignore[attr-defined] + + loss = self.loss(y_hat, y.to(torch.float)) # type: ignore[attr-defined] + + self.log("val_loss", loss, on_step=False, on_epoch=True) + self.val_metrics(y_hat_hard, y) + + def test_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> None: + """Test step. + + Args: + batch: Current batch + batch_idx: Index of current batch + """ + x = batch["image"] + y = batch["label"] + y_hat = self.forward(x) + y_hat_hard = torch.softmax(y_hat, dim=-1) # type: ignore[attr-defined] + + loss = self.loss(y_hat, y.to(torch.float)) # type: ignore[attr-defined] + + # by default, the test and validation steps only log per *epoch* + self.log("test_loss", loss, on_step=False, on_epoch=True) + self.test_metrics(y_hat_hard, y) + + class RegressionTask(pl.LightningModule): """LightningModule for training models on regression datasets.""" diff --git a/train.py b/train.py index 2d646b30b..dc17a1b1f 100755 --- a/train.py +++ b/train.py @@ -14,6 +14,8 @@ from pytorch_lightning import loggers as pl_loggers from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from torchgeo.trainers import ( + BigEarthNetClassificationTask, + BigEarthNetDataModule, BYOLTask, ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask, @@ -36,6 +38,7 @@ from torchgeo.trainers import ( TASK_TO_MODULES_MAPPING: Dict[ str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]] ] = { + "bigearthnet": (BigEarthNetClassificationTask, BigEarthNetDataModule), "byol": (BYOLTask, ChesapeakeCVPRDataModule), "chesapeake_cvpr": (ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRDataModule), "cyclone": (RegressionTask, CycloneDataModule),