Added some relevant functionality to StreamingDatasets

This commit is contained in:
Caleb Robinson 2020-12-17 23:08:25 +00:00
Родитель b1e832245c
Коммит bf02d81add
1 изменённых файлов: 43 добавлений и 18 удалений

Просмотреть файл

@ -12,7 +12,7 @@ from torch.utils.data.dataset import IterableDataset
class StreamingGeospatialDataset(IterableDataset):
def __init__(self, imagery_fns, label_fns=None, groups=None, chip_size=256, num_chips_per_tile=200, windowed_sampling=False, transform=None, verbose=False):
def __init__(self, imagery_fns, label_fns=None, groups=None, chip_size=256, num_chips_per_tile=200, windowed_sampling=False, image_transform=None, label_transform=None, nodata_check=None, verbose=False):
"""A torch Dataset for randomly sampling chips from a list of tiles. When used in conjunction with a DataLoader that has `num_workers>1` this Dataset will assign each worker to sample chips from disjoint sets of tiles.
Args:
@ -22,7 +22,9 @@ class StreamingGeospatialDataset(IterableDataset):
chip_size: Desired size of chips (in pixels).
num_chips_per_tile: Desired number of chips to sample for each tile.
windowed_sampling: Flag indicating whether we should sample each chip with a read using `rasterio.windows.Window` or whether we should read the whole tile into memory, then sample chips.
transform: A function to apply to each chip object. If this is `None`, then the only transformation applied to the loaded imagery will be to convert it to a `torch.Tensor`. If this is not `None`, then the function should return a `Torch.tensor`. Further, if `groups` is not `None` then the transform function should expect the imagery as the first argument and the group as the second argument.
image_transform: A function to apply to each image chip object. If this is `None`, then the only transformation applied to the loaded imagery will be to convert it to a `torch.Tensor`. If this is not `None`, then the function should return a `Torch.tensor`. Further, if `groups` is not `None` then the transform function should expect the imagery as the first argument and the group as the second argument.
label_transform: Similar to image_transform, but applied to label chips.
nodata_check: A method that will check an `(image_chip)` or `(image_chip, label_chip)` (if `label_fns` are provided) and return whether or not the chip should be skipped. This can be used, for example, to skip chips that contain nodata values.
verbose: If `False` we will be quiet.
"""
@ -39,7 +41,10 @@ class StreamingGeospatialDataset(IterableDataset):
self.num_chips_per_tile = num_chips_per_tile
self.windowed_sampling = windowed_sampling
self.transform = transform
self.image_transform = image_transform
self.label_transform = label_transform
self.nodata_check = nodata_check
self.verbose = verbose
if self.verbose:
@ -94,10 +99,10 @@ class StreamingGeospatialDataset(IterableDataset):
img_fp = rasterio.open(img_fn, "r")
label_fp = rasterio.open(label_fn, "r") if self.use_labels else None
height, width = img_fp.shape
height, width = img_fp.shape
if self.use_labels: # garuntee that our label mask has the same dimensions as our imagery
t_height, t_width = label_fp.shape
assert height == t_height and width == t_width
assert height == t_height and width == t_width
try:
# If we aren't in windowed sampling mode then we should read the entire tile up front
@ -112,7 +117,9 @@ class StreamingGeospatialDataset(IterableDataset):
x = np.random.randint(0, width-self.chip_size)
y = np.random.randint(0, height-self.chip_size)
# Read imagery
# Read imagery / labels
img = None
labels = None
if self.windowed_sampling:
img = np.rollaxis(img_fp.read(window=Window(x, y, self.chip_size, self.chip_size)), 0, 3)
if self.use_labels:
@ -122,25 +129,43 @@ class StreamingGeospatialDataset(IterableDataset):
if self.use_labels:
labels = label_data[y:y+self.chip_size, x:x+self.chip_size]
# TODO: check for nodata and throw away the chip if necessary. Not sure how to do this in a dataset independent way.
if self.use_labels:
labels = torch.from_numpy(labels).squeeze()
if self.transform is not None:
if self.groups is None:
img = self.transform(img)
# Check for no data
if self.nodata_check is not None:
if self.use_labels:
skip_chip = self.nodata_check(img, labels)
else:
img = self.transform(img, group)
else:
img = torch.from_numpy(img)
skip_chip = self.nodata_check(img)
if skip_chip: # The current chip has been identified as invalid by the `nodata_check(...)` method
continue
# Transform the imagery
if self.image_transform is not None:
if self.groups is None:
img = self.image_transform(img)
else:
img = self.image_transform(img, group)
else:
img = torch.from_numpy(img).squeeze()
# Transform the labels
if self.use_labels:
if self.label_transform is not None:
if self.groups is None:
labels = self.label_transform(labels)
else:
labels = self.label_transform(labels, group)
else:
labels = torch.from_numpy(labels).squeeze()
# Note, that img should be a torch "Double" type (i.e. a np.float32) and labels should be a torch "Long" type (i.e. np.int64)
if self.use_labels:
yield img, labels
else:
yield img
except RasterioIOError as e: # NOTE(caleb): I put this here to catch weird errors that I was seeing occasionally when trying to read from COGS - I don't remember the details though
print("Reading %s failed, skipping..." % (fn))
print("Reading %s failed, skipping..." % (img_fn))
# Close file pointers
img_fp.close()