Small changes to the geospatial datasets
This commit is contained in:
Родитель
bf02d81add
Коммит
e64ff96e94
|
@ -16,7 +16,7 @@ class StreamingGeospatialDataset(IterableDataset):
|
|||
"""A torch Dataset for randomly sampling chips from a list of tiles. When used in conjunction with a DataLoader that has `num_workers>1` this Dataset will assign each worker to sample chips from disjoint sets of tiles.
|
||||
|
||||
Args:
|
||||
imagery_fns: A list of filenames (or web addresses -- anything that `rasterio.open()` can read) pointing to imagery tiles.
|
||||
imagery_fns: A list of filenames (or URLs -- anything that `rasterio.open()` can read) pointing to imagery tiles.
|
||||
label_fns: A list of filenames of the same size as `imagery_fns` pointing to label mask tiles or `None` if the Dataset should operate in "imagery only mode". Note that we expect `imagery_fns[i]` and `label_fns[i]` to have the same dimension and coordinate system.
|
||||
groups: Optional: A list of integers of the same size as `imagery_fns` that gives the "group" membership of each tile. This can be used to normalize imagery from different groups differently.
|
||||
chip_size: Desired size of chips (in pixels).
|
||||
|
@ -40,17 +40,17 @@ class StreamingGeospatialDataset(IterableDataset):
|
|||
self.chip_size = chip_size
|
||||
self.num_chips_per_tile = num_chips_per_tile
|
||||
self.windowed_sampling = windowed_sampling
|
||||
|
||||
|
||||
self.image_transform = image_transform
|
||||
self.label_transform = label_transform
|
||||
self.nodata_check = nodata_check
|
||||
|
||||
self.verbose = verbose
|
||||
|
||||
|
||||
if self.verbose:
|
||||
print("Constructed StreamingGeospatialDataset")
|
||||
|
||||
def stream_tile_fns(self):
|
||||
|
||||
def stream_tile_fns(self):
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
if worker_info is None: # In this case we are not loading through a DataLoader with multiple workers
|
||||
worker_id = 0
|
||||
|
@ -58,7 +58,7 @@ class StreamingGeospatialDataset(IterableDataset):
|
|||
else:
|
||||
worker_id = worker_info.id
|
||||
num_workers = worker_info.num_workers
|
||||
|
||||
|
||||
# We only want to shuffle the order we traverse the files if we are the first worker (else, every worker will shuffle the files...)
|
||||
if worker_id == 0:
|
||||
np.random.shuffle(self.fns) # in place
|
||||
|
@ -68,14 +68,14 @@ class StreamingGeospatialDataset(IterableDataset):
|
|||
|
||||
if self.verbose:
|
||||
print("Creating a filename stream for worker %d" % (worker_id))
|
||||
|
||||
|
||||
# This logic splits up the list of filenames into `num_workers` chunks. Each worker will receive ceil(num_filenames / num_workers) filenames to generate chips from. If the number of workers doesn't divide the number of filenames evenly then the last worker will have fewer filenames.
|
||||
N = len(self.fns)
|
||||
num_files_per_worker = int(np.ceil(N / num_workers))
|
||||
lower_idx = worker_id * num_files_per_worker
|
||||
upper_idx = min(N, (worker_id+1) * num_files_per_worker)
|
||||
for idx in range(lower_idx, upper_idx):
|
||||
|
||||
|
||||
label_fn = None
|
||||
if self.use_labels:
|
||||
img_fn, label_fn = self.fns[idx]
|
||||
|
@ -89,12 +89,14 @@ class StreamingGeospatialDataset(IterableDataset):
|
|||
|
||||
if self.verbose:
|
||||
print("Worker %d, yielding file %d" % (worker_id, idx))
|
||||
|
||||
|
||||
yield (img_fn, label_fn, group)
|
||||
|
||||
|
||||
def stream_chips(self):
|
||||
for img_fn, label_fn, group in self.stream_tile_fns():
|
||||
|
||||
num_skipped_chips = 0
|
||||
|
||||
# Open file pointers
|
||||
img_fp = rasterio.open(img_fn, "r")
|
||||
label_fp = rasterio.open(label_fn, "r") if self.use_labels else None
|
||||
|
@ -103,7 +105,7 @@ class StreamingGeospatialDataset(IterableDataset):
|
|||
if self.use_labels: # guarantee that our label mask has the same dimensions as our imagery
|
||||
t_height, t_width = label_fp.shape
|
||||
assert height == t_height and width == t_width
|
||||
|
||||
|
||||
try:
|
||||
# If we aren't in windowed sampling mode then we should read the entire tile up front
|
||||
if not self.windowed_sampling:
|
||||
|
@ -137,6 +139,7 @@ class StreamingGeospatialDataset(IterableDataset):
|
|||
skip_chip = self.nodata_check(img)
|
||||
|
||||
if skip_chip: # The current chip has been identified as invalid by the `nodata_check(...)` method
|
||||
num_skipped_chips += 1
|
||||
continue
|
||||
|
||||
# Transform the imagery
|
||||
|
@ -165,13 +168,16 @@ class StreamingGeospatialDataset(IterableDataset):
|
|||
else:
|
||||
yield img
|
||||
except RasterioIOError as e: # NOTE(caleb): I put this here to catch weird errors that I was seeing occasionally when trying to read from COGS - I don't remember the details though
|
||||
print("Reading %s failed, skipping..." % (img_fn))
|
||||
|
||||
print("WARNING: Reading %s failed, skipping..." % (img_fn))
|
||||
|
||||
# Close file pointers
|
||||
img_fp.close()
|
||||
if self.use_labels:
|
||||
label_fp.close()
|
||||
|
||||
|
||||
if num_skipped_chips>0 and self.verbose:
|
||||
print("We skipped %d chips on %s" % (img_fn))
|
||||
|
||||
def __iter__(self):
|
||||
if self.verbose:
|
||||
print("Creating a new StreamingGeospatialDataset iterator")
|
||||
|
|
|
@ -43,7 +43,7 @@ class TileInferenceDataset(Dataset):
|
|||
for y in list(range(0, height - self.chip_size, stride)) + [height - self.chip_size]:
|
||||
for x in list(range(0, width - self.chip_size, stride)) + [width - self.chip_size]:
|
||||
self.chip_coordinates.append((y,x))
|
||||
self.num_chips = len(self.chip_coordinates)
|
||||
self.num_chips = len(self.chip_coordinates)
|
||||
|
||||
if self.verbose:
|
||||
print("Constructed TileInferenceDataset -- we have %d by %d file with %d channels with a dtype of %s. We are sampling %d chips from it." % (
|
||||
|
|
Загрузка…
Ссылка в новой задаче