Add __len__ method to GeoDataset and ZipDataset (#175)

* Add __len__ method to GeoDataset and ZipDataset

* Fix type hints
This commit is contained in:
Adam J. Stewart 2021-09-30 22:36:53 -05:00 коммит произвёл GitHub
Родитель bc6d0358ca
Коммит 7ac915b765
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 38 добавлений и 6 удалений

Просмотреть файл

@ -57,11 +57,15 @@ class TestGeoDataset:
query = BoundingBox(0, 0, 0, 0, 0, 0)
assert dataset[query] == {"index": query}
def test_len(self, dataset: GeoDataset) -> None:
assert len(dataset) == 1
def test_add_two(self) -> None:
ds1 = CustomGeoDataset()
ds2 = CustomGeoDataset()
dataset = ds1 + ds2
assert isinstance(dataset, ZipDataset)
assert len(dataset) == 2
def test_add_three(self) -> None:
ds1 = CustomGeoDataset()
@ -69,6 +73,7 @@ class TestGeoDataset:
ds3 = CustomGeoDataset()
dataset = ds1 + ds2 + ds3
assert isinstance(dataset, ZipDataset)
assert len(dataset) == 3
def test_add_four(self) -> None:
ds1 = CustomGeoDataset()
@ -77,10 +82,13 @@ class TestGeoDataset:
ds4 = CustomGeoDataset()
dataset = (ds1 + ds2) + (ds3 + ds4)
assert isinstance(dataset, ZipDataset)
assert len(dataset) == 4
def test_str(self, dataset: GeoDataset) -> None:
assert "type: GeoDataset" in str(dataset)
assert "bbox: BoundingBox" in str(dataset)
out = str(dataset)
assert "type: GeoDataset" in out
assert "bbox: BoundingBox" in out
assert "size: 1" in out
def test_abstract(self) -> None:
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
@ -223,9 +231,14 @@ class TestZipDataset:
query = BoundingBox(0, 1, 2, 3, 4, 5)
assert dataset[query] == {"index": query}
def test_len(self, dataset: ZipDataset) -> None:
assert len(dataset) == 2
def test_str(self, dataset: ZipDataset) -> None:
assert "type: ZipDataset" in str(dataset)
assert "bbox: BoundingBox" in str(dataset)
out = str(dataset)
assert "type: ZipDataset" in out
assert "bbox: BoundingBox" in out
assert "size: 2" in out
def test_vision_dataset(self) -> None:
ds1 = CustomVisionDataset()

Просмотреть файл

@ -104,6 +104,15 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
"""
return ZipDataset([self, other])
def __len__(self) -> int:
"""Return the number of files in the dataset.
Returns:
length of the dataset
"""
count: int = self.index.count(self.index.bounds)
return count
def __str__(self) -> str:
"""Return the informal string representation of the object.
@ -113,7 +122,8 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
return f"""\
{self.__class__.__name__} Dataset
type: GeoDataset
bbox: {self.bounds}"""
bbox: {self.bounds}
size: {len(self)}"""
@property
def bounds(self) -> BoundingBox:
@ -719,6 +729,14 @@ class ZipDataset(GeoDataset):
sample.update(ds[query])
return sample
def __len__(self) -> int:
"""Return the number of files in the dataset.
Returns:
length of the dataset
"""
return sum(map(len, self.datasets))
def __str__(self) -> str:
"""Return the informal string representation of the object.
@ -728,7 +746,8 @@ class ZipDataset(GeoDataset):
return f"""\
{self.__class__.__name__} Dataset
type: ZipDataset
bbox: {self.bounds}"""
bbox: {self.bounds}
size: {len(self)}"""
@property
def bounds(self) -> BoundingBox: