Check to make sure bands is a tuple, fixed some typing issues

This commit is contained in:
Caleb Robinson 2021-06-05 22:22:33 +00:00 коммит произвёл Adam J. Stewart
Родитель 3f13606837
Коммит 7edc3d4bf9
1 изменённых файлов: 6 добавлений и 5 удалений

Просмотреть файл

@ -97,7 +97,7 @@ class CV4AKenyaCropType(VisionDataset):
root: str = "data",
chip_size: int = 256,
stride: int = 128,
bands: Optional[Tuple[str]] = None,
bands: Optional[Tuple[str, ...]] = None,
transform: Optional[Callable[[Image.Image], Any]] = None,
target_transform: Optional[Callable[[Image.Image], Any]] = None,
transforms: Optional[Callable[[Image.Image, Image.Image], Any]] = None,
@ -153,7 +153,7 @@ class CV4AKenyaCropType(VisionDataset):
]:
self.chips_metadata.append((tile_index, y, x))
def __getitem__(self, index: int) -> Dict:
def __getitem__(self, index: int) -> Dict[str, Any]:
"""Return an index within the dataset.
Parameters:
@ -223,12 +223,13 @@ class CV4AKenyaCropType(VisionDataset):
return (labels, field_ids)
def _validate_bands(self, bands: Optional[Tuple[str]]) -> Tuple[str]:
def _validate_bands(self, bands: Optional[Tuple[str, ...]]) -> Tuple[str, ...]:
"""Routine for validating a list of bands / filling in a default value"""
if bands is None:
return self.band_names
else:
assert isinstance(bands, tuple), "The list of bands must be a tuple"
for band in bands:
if band not in self.band_names:
raise ValueError(f"'{band}' is an invalid band name.")
@ -236,7 +237,7 @@ class CV4AKenyaCropType(VisionDataset):
@lru_cache
def _load_all_image_tiles(
self, tile_name_: str, bands: Optional[Tuple[str]] = None
self, tile_name_: str, bands: Optional[Tuple[str, ...]] = None
) -> np.ndarray:
"""Load all the imagery (across time) for a single _tile_. Optionally allows
for subsetting of the bands that are loaded.
@ -263,7 +264,7 @@ class CV4AKenyaCropType(VisionDataset):
@lru_cache
def _load_single_image_tile(
self, tile_name_: str, date_: str, bands: Optional[Tuple[str]] = None
self, tile_name_: str, date_: str, bands: Optional[Tuple[str, ...]] = None
) -> np.ndarray:
"""Loads the imagery for a single tile for a single date. Optionally allows
for subsetting of the bands that are loaded."""