Small changes to the geospatial datasets

This commit is contained in:
Caleb Robinson 2020-12-18 00:59:09 +00:00
Родитель bf02d81add
Коммит e64ff96e94
2 изменённых файлов: 21 добавлений и 15 удалений

Просмотреть файл

@ -16,7 +16,7 @@ class StreamingGeospatialDataset(IterableDataset):
"""A torch Dataset for randomly sampling chips from a list of tiles. When used in conjunction with a DataLoader that has `num_workers>1` this Dataset will assign each worker to sample chips from disjoint sets of tiles.
Args:
imagery_fns: A list of filenames (or web addresses -- anything that `rasterio.open()` can read) pointing to imagery tiles.
imagery_fns: A list of filenames (or URLs -- anything that `rasterio.open()` can read) pointing to imagery tiles.
label_fns: A list of filenames of the same size as `imagery_fns` pointing to label mask tiles or `None` if the Dataset should operate in "imagery only mode". Note that we expect `imagery_fns[i]` and `label_fns[i]` to have the same dimension and coordinate system.
groups: Optional: A list of integers of the same size as `imagery_fns` that gives the "group" membership of each tile. This can be used to normalize imagery from different groups differently.
chip_size: Desired size of chips (in pixels).
@ -40,17 +40,17 @@ class StreamingGeospatialDataset(IterableDataset):
self.chip_size = chip_size
self.num_chips_per_tile = num_chips_per_tile
self.windowed_sampling = windowed_sampling
self.image_transform = image_transform
self.label_transform = label_transform
self.nodata_check = nodata_check
self.verbose = verbose
if self.verbose:
print("Constructed StreamingGeospatialDataset")
def stream_tile_fns(self):
def stream_tile_fns(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None: # In this case we are not loading through a DataLoader with multiple workers
worker_id = 0
@ -58,7 +58,7 @@ class StreamingGeospatialDataset(IterableDataset):
else:
worker_id = worker_info.id
num_workers = worker_info.num_workers
# We only want to shuffle the order we traverse the files if we are the first worker (else, every worker will shuffle the files...)
if worker_id == 0:
np.random.shuffle(self.fns) # in place
@ -68,14 +68,14 @@ class StreamingGeospatialDataset(IterableDataset):
if self.verbose:
print("Creating a filename stream for worker %d" % (worker_id))
# This logic splits up the list of filenames into `num_workers` chunks. Each worker will receive ceil(num_filenames / num_workers) filenames to generate chips from. If the number of workers doesn't divide the number of filenames evenly then the last worker will have fewer filenames.
N = len(self.fns)
num_files_per_worker = int(np.ceil(N / num_workers))
lower_idx = worker_id * num_files_per_worker
upper_idx = min(N, (worker_id+1) * num_files_per_worker)
for idx in range(lower_idx, upper_idx):
label_fn = None
if self.use_labels:
img_fn, label_fn = self.fns[idx]
@ -89,12 +89,14 @@ class StreamingGeospatialDataset(IterableDataset):
if self.verbose:
print("Worker %d, yielding file %d" % (worker_id, idx))
yield (img_fn, label_fn, group)
def stream_chips(self):
for img_fn, label_fn, group in self.stream_tile_fns():
num_skipped_chips = 0
# Open file pointers
img_fp = rasterio.open(img_fn, "r")
label_fp = rasterio.open(label_fn, "r") if self.use_labels else None
@ -103,7 +105,7 @@ class StreamingGeospatialDataset(IterableDataset):
if self.use_labels: # guarantee that our label mask has the same dimensions as our imagery
t_height, t_width = label_fp.shape
assert height == t_height and width == t_width
try:
# If we aren't in windowed sampling mode then we should read the entire tile up front
if not self.windowed_sampling:
@ -137,6 +139,7 @@ class StreamingGeospatialDataset(IterableDataset):
skip_chip = self.nodata_check(img)
if skip_chip: # The current chip has been identified as invalid by the `nodata_check(...)` method
num_skipped_chips += 1
continue
# Transform the imagery
@ -165,13 +168,16 @@ class StreamingGeospatialDataset(IterableDataset):
else:
yield img
except RasterioIOError as e: # NOTE(caleb): I put this here to catch weird errors that I was seeing occasionally when trying to read from COGS - I don't remember the details though
print("Reading %s failed, skipping..." % (img_fn))
print("WARNING: Reading %s failed, skipping..." % (img_fn))
# Close file pointers
img_fp.close()
if self.use_labels:
label_fp.close()
# In verbose mode, report how many chips from this tile were rejected by `nodata_check`.
# The format string has two placeholders (%d for the count, %s for the filename), so both
# values must be passed; supplying only `img_fn` raises a TypeError at runtime.
if num_skipped_chips>0 and self.verbose:
    print("We skipped %d chips on %s" % (num_skipped_chips, img_fn))
def __iter__(self):
if self.verbose:
print("Creating a new StreamingGeospatialDataset iterator")

Просмотреть файл

@ -43,7 +43,7 @@ class TileInferenceDataset(Dataset):
for y in list(range(0, height - self.chip_size, stride)) + [height - self.chip_size]:
for x in list(range(0, width - self.chip_size, stride)) + [width - self.chip_size]:
self.chip_coordinates.append((y,x))
self.num_chips = len(self.chip_coordinates)
self.num_chips = len(self.chip_coordinates)
if self.verbose:
print("Constructed TileInferenceDataset -- we have %d by %d file with %d channels with a dtype of %s. We are sampling %d chips from it." % (