Added some relevant functionality to StreamingDatasets
This commit is contained in:
Родитель
b1e832245c
Коммит
bf02d81add
|
@ -12,7 +12,7 @@ from torch.utils.data.dataset import IterableDataset
|
|||
|
||||
class StreamingGeospatialDataset(IterableDataset):
|
||||
|
||||
def __init__(self, imagery_fns, label_fns=None, groups=None, chip_size=256, num_chips_per_tile=200, windowed_sampling=False, transform=None, verbose=False):
|
||||
def __init__(self, imagery_fns, label_fns=None, groups=None, chip_size=256, num_chips_per_tile=200, windowed_sampling=False, image_transform=None, label_transform=None, nodata_check=None, verbose=False):
|
||||
"""A torch Dataset for randomly sampling chips from a list of tiles. When used in conjunction with a DataLoader that has `num_workers>1` this Dataset will assign each worker to sample chips from disjoint sets of tiles.
|
||||
|
||||
Args:
|
||||
|
@ -22,7 +22,9 @@ class StreamingGeospatialDataset(IterableDataset):
|
|||
chip_size: Desired size of chips (in pixels).
|
||||
num_chips_per_tile: Desired number of chips to sample for each tile.
|
||||
windowed_sampling: Flag indicating whether we should sample each chip with a read using `rasterio.windows.Window` or whether we should read the whole tile into memory, then sample chips.
|
||||
transform: A function to apply to each chip object. If this is `None`, then the only transformation applied to the loaded imagery will be to convert it to a `torch.Tensor`. If this is not `None`, then the function should return a `Torch.tensor`. Further, if `groups` is not `None` then the transform function should expect the imagery as the first argument and the group as the second argument.
|
||||
image_transform: A function to apply to each image chip object. If this is `None`, then the only transformation applied to the loaded imagery will be to convert it to a `torch.Tensor`. If this is not `None`, then the function should return a `torch.Tensor`. Further, if `groups` is not `None` then the transform function should expect the imagery as the first argument and the group as the second argument.
|
||||
label_transform: Similar to image_transform, but applied to label chips.
|
||||
nodata_check: A method that will check an `(image_chip)` or `(image_chip, label_chip)` (if `label_fns` are provided) and return whether or not the chip should be skipped. This can be used, for example, to skip chips that contain nodata values.
|
||||
verbose: If `False` we will be quiet.
|
||||
"""
|
||||
|
||||
|
@ -39,7 +41,10 @@ class StreamingGeospatialDataset(IterableDataset):
|
|||
self.num_chips_per_tile = num_chips_per_tile
|
||||
self.windowed_sampling = windowed_sampling
|
||||
|
||||
self.transform = transform
|
||||
self.image_transform = image_transform
|
||||
self.label_transform = label_transform
|
||||
self.nodata_check = nodata_check
|
||||
|
||||
self.verbose = verbose
|
||||
|
||||
if self.verbose:
|
||||
|
@ -94,10 +99,10 @@ class StreamingGeospatialDataset(IterableDataset):
|
|||
img_fp = rasterio.open(img_fn, "r")
|
||||
label_fp = rasterio.open(label_fn, "r") if self.use_labels else None
|
||||
|
||||
height, width = img_fp.shape
|
||||
height, width = img_fp.shape
|
||||
if self.use_labels: # guarantee that our label mask has the same dimensions as our imagery
|
||||
t_height, t_width = label_fp.shape
|
||||
assert height == t_height and width == t_width
|
||||
assert height == t_height and width == t_width
|
||||
|
||||
try:
|
||||
# If we aren't in windowed sampling mode then we should read the entire tile up front
|
||||
|
@ -112,7 +117,9 @@ class StreamingGeospatialDataset(IterableDataset):
|
|||
x = np.random.randint(0, width-self.chip_size)
|
||||
y = np.random.randint(0, height-self.chip_size)
|
||||
|
||||
# Read imagery
|
||||
# Read imagery / labels
|
||||
img = None
|
||||
labels = None
|
||||
if self.windowed_sampling:
|
||||
img = np.rollaxis(img_fp.read(window=Window(x, y, self.chip_size, self.chip_size)), 0, 3)
|
||||
if self.use_labels:
|
||||
|
@ -122,25 +129,43 @@ class StreamingGeospatialDataset(IterableDataset):
|
|||
if self.use_labels:
|
||||
labels = label_data[y:y+self.chip_size, x:x+self.chip_size]
|
||||
|
||||
# TODO: check for nodata and throw away the chip if necessary. Not sure how to do this in a dataset independent way.
|
||||
|
||||
if self.use_labels:
|
||||
labels = torch.from_numpy(labels).squeeze()
|
||||
|
||||
if self.transform is not None:
|
||||
if self.groups is None:
|
||||
img = self.transform(img)
|
||||
# Check for no data
|
||||
if self.nodata_check is not None:
|
||||
if self.use_labels:
|
||||
skip_chip = self.nodata_check(img, labels)
|
||||
else:
|
||||
img = self.transform(img, group)
|
||||
else:
|
||||
img = torch.from_numpy(img)
|
||||
skip_chip = self.nodata_check(img)
|
||||
|
||||
if skip_chip: # The current chip has been identified as invalid by the `nodata_check(...)` method
|
||||
continue
|
||||
|
||||
# Transform the imagery
|
||||
if self.image_transform is not None:
|
||||
if self.groups is None:
|
||||
img = self.image_transform(img)
|
||||
else:
|
||||
img = self.image_transform(img, group)
|
||||
else:
|
||||
img = torch.from_numpy(img).squeeze()
|
||||
|
||||
# Transform the labels
|
||||
if self.use_labels:
|
||||
if self.label_transform is not None:
|
||||
if self.groups is None:
|
||||
labels = self.label_transform(labels)
|
||||
else:
|
||||
labels = self.label_transform(labels, group)
|
||||
else:
|
||||
labels = torch.from_numpy(labels).squeeze()
|
||||
|
||||
|
||||
# Note that img should be a torch "Float" type (i.e. np.float32) and labels should be a torch "Long" type (i.e. np.int64)
|
||||
if self.use_labels:
|
||||
yield img, labels
|
||||
else:
|
||||
yield img
|
||||
except RasterioIOError as e: # NOTE(caleb): I put this here to catch weird errors that I was seeing occasionally when trying to read from COGS - I don't remember the details though
|
||||
print("Reading %s failed, skipping..." % (fn))
|
||||
print("Reading %s failed, skipping..." % (img_fn))
|
||||
|
||||
# Close file pointers
|
||||
img_fp.close()
|
||||
|
|
Загрузка…
Ссылка в новой задаче