From 26343ddc3610b4f04f6a63f8bf0bf56e22207adc Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sun, 19 Dec 2021 09:23:18 -0600 Subject: [PATCH] Fix GeoDataset pickling (#304) * Fix GeoDataset pickling * mypy fixes --- tests/datasets/test_geo.py | 9 +++++++++ torchgeo/datasets/geo.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index 8deb5d030..698590e33 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import os +import pickle from pathlib import Path from typing import Dict @@ -90,6 +91,14 @@ class TestGeoDataset: assert "bbox: BoundingBox" in out assert "size: 1" in out + def test_picklable(self, dataset: GeoDataset) -> None: + x = pickle.dumps(dataset) + y = pickle.loads(x) + assert dataset.crs == y.crs + assert dataset.res == y.res + assert len(dataset) == len(y) + assert dataset.bounds == y.bounds + def test_abstract(self) -> None: with pytest.raises(TypeError, match="Can't instantiate abstract class"): GeoDataset() # type: ignore[abstract] diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 826a715e0..345a58a4d 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -126,6 +126,41 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC): bbox: {self.bounds} size: {len(self)}""" + # NOTE: This hack should be removed once the following issue is fixed: + # https://github.com/Toblerity/rtree/issues/87 + + def __getstate__( + self, + ) -> Tuple[ + Dict[Any, Any], + List[Tuple[int, Tuple[float, float, float, float, float, float], str]], + ]: + """Define how instances are pickled. + + Returns: + the state necessary to unpickle the instance + """ + objects = self.index.intersection(self.index.bounds, objects=True) + tuples = [(item.id, item.bounds, item.object) for item in objects] + return self.__dict__, tuples + + def __setstate__( + self, + state: Tuple[ + Dict[Any, Any], + List[Tuple[int, Tuple[float, float, float, float, float, float], str]], + ], + ) -> None: + """Define how to unpickle an instance. + + Args: + state: the state of the instance when it was pickled + """ + attrs, tuples = state + self.__dict__.update(attrs) + for item in tuples: + self.index.insert(*item) + @property def bounds(self) -> BoundingBox: """Bounds of the index.