зеркало из https://github.com/microsoft/torchgeo.git
Add __len__ method to GeoDataset and ZipDataset (#175)
* Add __len__ method to GeoDataset and ZipDataset * Fix type hints
This commit is contained in:
Родитель
bc6d0358ca
Коммит
7ac915b765
|
@ -57,11 +57,15 @@ class TestGeoDataset:
|
|||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
assert dataset[query] == {"index": query}
|
||||
|
||||
def test_len(self, dataset: GeoDataset) -> None:
|
||||
assert len(dataset) == 1
|
||||
|
||||
def test_add_two(self) -> None:
|
||||
ds1 = CustomGeoDataset()
|
||||
ds2 = CustomGeoDataset()
|
||||
dataset = ds1 + ds2
|
||||
assert isinstance(dataset, ZipDataset)
|
||||
assert len(dataset) == 2
|
||||
|
||||
def test_add_three(self) -> None:
|
||||
ds1 = CustomGeoDataset()
|
||||
|
@ -69,6 +73,7 @@ class TestGeoDataset:
|
|||
ds3 = CustomGeoDataset()
|
||||
dataset = ds1 + ds2 + ds3
|
||||
assert isinstance(dataset, ZipDataset)
|
||||
assert len(dataset) == 3
|
||||
|
||||
def test_add_four(self) -> None:
|
||||
ds1 = CustomGeoDataset()
|
||||
|
@ -77,10 +82,13 @@ class TestGeoDataset:
|
|||
ds4 = CustomGeoDataset()
|
||||
dataset = (ds1 + ds2) + (ds3 + ds4)
|
||||
assert isinstance(dataset, ZipDataset)
|
||||
assert len(dataset) == 4
|
||||
|
||||
def test_str(self, dataset: GeoDataset) -> None:
|
||||
assert "type: GeoDataset" in str(dataset)
|
||||
assert "bbox: BoundingBox" in str(dataset)
|
||||
out = str(dataset)
|
||||
assert "type: GeoDataset" in out
|
||||
assert "bbox: BoundingBox" in out
|
||||
assert "size: 1" in out
|
||||
|
||||
def test_abstract(self) -> None:
|
||||
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
|
||||
|
@ -223,9 +231,14 @@ class TestZipDataset:
|
|||
query = BoundingBox(0, 1, 2, 3, 4, 5)
|
||||
assert dataset[query] == {"index": query}
|
||||
|
||||
def test_len(self, dataset: ZipDataset) -> None:
|
||||
assert len(dataset) == 2
|
||||
|
||||
def test_str(self, dataset: ZipDataset) -> None:
|
||||
assert "type: ZipDataset" in str(dataset)
|
||||
assert "bbox: BoundingBox" in str(dataset)
|
||||
out = str(dataset)
|
||||
assert "type: ZipDataset" in out
|
||||
assert "bbox: BoundingBox" in out
|
||||
assert "size: 2" in out
|
||||
|
||||
def test_vision_dataset(self) -> None:
|
||||
ds1 = CustomVisionDataset()
|
||||
|
|
|
@ -104,6 +104,15 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
|
|||
"""
|
||||
return ZipDataset([self, other])
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of files in the dataset.
|
||||
|
||||
Returns:
|
||||
length of the dataset
|
||||
"""
|
||||
count: int = self.index.count(self.index.bounds)
|
||||
return count
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return the informal string representation of the object.
|
||||
|
||||
|
@ -113,7 +122,8 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
|
|||
return f"""\
|
||||
{self.__class__.__name__} Dataset
|
||||
type: GeoDataset
|
||||
bbox: {self.bounds}"""
|
||||
bbox: {self.bounds}
|
||||
size: {len(self)}"""
|
||||
|
||||
@property
|
||||
def bounds(self) -> BoundingBox:
|
||||
|
@ -719,6 +729,14 @@ class ZipDataset(GeoDataset):
|
|||
sample.update(ds[query])
|
||||
return sample
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of files in the dataset.
|
||||
|
||||
Returns:
|
||||
length of the dataset
|
||||
"""
|
||||
return sum(map(len, self.datasets))
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return the informal string representation of the object.
|
||||
|
||||
|
@ -728,7 +746,8 @@ class ZipDataset(GeoDataset):
|
|||
return f"""\
|
||||
{self.__class__.__name__} Dataset
|
||||
type: ZipDataset
|
||||
bbox: {self.bounds}"""
|
||||
bbox: {self.bounds}
|
||||
size: {len(self)}"""
|
||||
|
||||
@property
|
||||
def bounds(self) -> BoundingBox:
|
||||
|
|
Загрузка…
Ссылка в новой задаче