Updating the dataset class
This commit is contained in:
Parent 5331c860e9
Commit 22cf561b5b
@@ -10,42 +10,39 @@ import torch
 from torchvision import transforms
 from torch.utils.data.dataset import IterableDataset


-class ClusterTileStreamingDataset(IterableDataset):
+class StreamingGeospatialDataset(IterableDataset):

-    def __init__(self, fns, means, stdevs, cluster_model, patch_size=256, num_patches_per_tile=200, windowed_sampling=False, infinite_sample=False, transform=None, verbose=False):
-        self.fns = sorted(list(fns))
-
-        # I pass the means and stdevs into the Dataset (instead of through a torchvision.transform) as I want
-        # to standardize the images before using the cluster_model. After the image goes through the transform it will
-        # be a torch tensor and not useable in the sklearn model.
-        self.means = means
-        self.stdevs = stdevs
-
-        # This is an sklearn.cluster.MiniBatchKMeans model
-        self.cluster_model = cluster_model
-        self.cluster_model.verbose = False
-
-        self.patch_size = patch_size
-        self.num_patches_per_tile = num_patches_per_tile
-        self.windowed_sampling = windowed_sampling # if windowed_sampling is True, then we will crop chips from the tiles by using
-        # rasterio's Window reader, if False, then we will read the entire raster into memory and crop chips by indexing.
-
-        self.infinite = infinite_sample
+    def __init__(self, imagery_fns, label_fns=None, chip_size=256, num_chips_per_tile=200, windowed_sampling=False, transform=None, verbose=False):
+        """A torch Dataset for randomly sampling chips from a list of tiles. When used in conjunction with a DataLoader that has `num_workers>1`, this Dataset will assign each worker to sample chips from disjoint sets of tiles.
+
+        Args:
+            imagery_fns: A list of filenames (or web addresses -- anything that `rasterio.open()` can read) pointing to imagery tiles.
+            label_fns: A list of filenames of the same size as `imagery_fns` pointing to label mask tiles, or `None` if the Dataset should operate in "imagery only mode". Note that we expect `imagery_fns[i]` and `label_fns[i]` to have the same dimensions and coordinate system.
+            chip_size: Desired size of chips (in pixels).
+            num_chips_per_tile: Desired number of chips to sample from each tile.
+            windowed_sampling: Flag indicating whether we should sample each chip with a read using `rasterio.windows.Window` or whether we should read the whole tile into memory, then sample chips.
+            transform: The torchvision.transform object to apply to each chip.
+            verbose: If `False`, we will be quiet.
+        """
+
+        if label_fns is None:
+            self.fns = imagery_fns
+            self.use_labels = False
+        else:
+            self.fns = list(zip(imagery_fns, label_fns))
+            self.use_labels = True
+
+        self.chip_size = chip_size
+        self.num_chips_per_tile = num_chips_per_tile
+        self.windowed_sampling = windowed_sampling
         self.transform = transform
         self.verbose = verbose

         if self.verbose:
-            print("Constructed ClusterTileStreamingDataset")
+            print("Constructed StreamingGeospatialDataset")

     def stream_tile_fns(self):
-        # This method is called every time we call something like an `enumerate(dataloader)` (e.g. which happens every epoch).
-        # Shuffling the filenames at the beginning of this method will ensure that we don't traverse the same tiles in every worker.
-        seed = torch.randint(low=0,high=2**32-1,size=(1,)).item()
-        np.random.seed(seed) # when different workers spawn, they have the same numpy random seed...
-        local_fns = list(self.fns)
-        np.random.shuffle(local_fns)
-
         worker_info = torch.utils.data.get_worker_info()
         if worker_info is None: # In this case we are not loading through a DataLoader with multiple workers
             worker_id = 0
@@ -54,69 +51,88 @@ class ClusterTileStreamingDataset(IterableDataset):
             worker_id = worker_info.id
             num_workers = worker_info.num_workers

-        # We only want to shuffle the order we traverse the files if we are the first worker (else, every worker will shuffle the files...)
-        if worker_id == 0:
-            np.random.shuffle(self.fns) # in place
+        # NOTE: A warning: when different workers are created they will all have the same numpy random seed, however they will have different torch random seeds. If you want to use numpy random functions, seed appropriately.
+        #seed = torch.randint(low=0,high=2**32-1,size=(1,)).item()
+        #np.random.seed(seed) # when different workers spawn, they have the same numpy random seed...

         if self.verbose:
             print("Creating a filename stream for worker %d" % (worker_id))

-        N = len(local_fns)
-        idx = 0
-        if self.infinite:
-            while True:
-                if self.verbose:
-                    print("Worker %d, using %s" % (worker_id, local_fns[idx]))
-                yield local_fns[idx]
-                idx = (idx + 1) % N
-        else:
-            for fn in local_fns:
-                if self.verbose:
-                    print("Worker %d, using %s" % (worker_id, fn))
-                yield fn
+        N = len(self.fns)
+        num_files_per_worker = int(np.ceil(N / num_workers))
+        lower_idx = worker_id * num_files_per_worker
+        upper_idx = min(N, (worker_id+1) * num_files_per_worker)
+        for idx in range(lower_idx, upper_idx):
+
+            label_fn = None
+            if self.use_labels:
+                img_fn, label_fn = self.fns[idx]
+            else:
+                img_fn = self.fns[idx]
+
+            if self.verbose:
+                print("Worker %d, yielding file %d" % (worker_id, idx))
+
+            yield (img_fn, label_fn)

     def stream_chips(self):
-        for fn in self.stream_tile_fns():
-            with rasterio.open(fn, "r") as f:
-                height, width = f.shape
-                try:
-                    if not self.windowed_sampling:
-                        data = np.rollaxis(f.read(), 0, 3)
-
-                    for i in range(self.num_patches_per_tile):
-                        x = np.random.randint(0, width-self.patch_size)
-                        y = np.random.randint(0, height-self.patch_size)
-
-                        if self.windowed_sampling:
-                            img = np.rollaxis(f.read(window=Window(x, y, self.patch_size, self.patch_size)), 0, 3)
-                        else:
-                            img = data[y:y+self.patch_size, x:x+self.patch_size, :]
-
-                        # Throw away the chip if it has >50% of nodata
-                        num_nan = np.sum(np.sum(img == 0, axis=2) == 12)
-                        if num_nan / (self.patch_size * self.patch_size) > 0.5:
-                            continue
-                        img[np.isnan(img)] = 0
-
-                        # standardize
-                        img = (img - means) / stdevs
-                        img = img.astype(np.float32)
-
-                        # assign each pixel in the input a label based on its cluster index as determined by self.cluster_model
-                        targets = img.copy().reshape(-1,12)
-                        targets = self.cluster_model.predict(targets)
-                        targets = targets.reshape(self.patch_size, self.patch_size)
-                        targets = targets.astype(np.int64)
-                        targets = transforms.ToTensor()(targets).squeeze()
-
-                        if self.transform is not None:
-                            img = self.transform(img)
-
-                        yield img, targets
-                except RasterioIOError as e:
-                    print("Reading %s failed, skipping..." % (fn))
+        for img_fn, label_fn in self.stream_tile_fns():
+
+            # Open file pointers
+            img_fp = rasterio.open(img_fn, "r")
+            label_fp = rasterio.open(label_fn, "r") if self.use_labels else None
+
+            height, width = img_fp.shape
+            if self.use_labels: # guarantee that our label mask has the same dimensions as our imagery
+                t_height, t_width = label_fp.shape
+                assert height == t_height and width == t_width
+
+            try:
+                # If we aren't in windowed sampling mode then we should read the entire tile up front
+                if not self.windowed_sampling:
+                    img_data = np.rollaxis(img_fp.read(), 0, 3)
+                    if self.use_labels:
+                        label_data = label_fp.read().squeeze() # assume the label geotiff has a single channel
+
+                for i in range(self.num_chips_per_tile):
+                    # Select the top left pixel of our chip randomly
+                    x = np.random.randint(0, width-self.chip_size)
+                    y = np.random.randint(0, height-self.chip_size)
+
+                    if self.windowed_sampling:
+                        img = np.rollaxis(img_fp.read(window=Window(x, y, self.chip_size, self.chip_size)), 0, 3)
+                        if self.use_labels:
+                            labels = label_fp.read(window=Window(x, y, self.chip_size, self.chip_size)).squeeze()
+                    else:
+                        img = img_data[y:y+self.chip_size, x:x+self.chip_size, :]
+                        if self.use_labels:
+                            labels = label_data[y:y+self.chip_size, x:x+self.chip_size]
+
+                    # TODO: check for nodata and throw away the chip if necessary. Not sure how to do this in a dataset independent way.
+
+                    if self.use_labels:
+                        labels = transforms.ToTensor()(labels).squeeze()
+
+                    if self.transform is not None:
+                        img = self.transform(img)
+
+                    if self.use_labels:
+                        yield img, labels
+                    else:
+                        yield img
+            except RasterioIOError as e: # NOTE(caleb): I put this here to catch weird errors that I was seeing occasionally when trying to read from COGs -- I don't remember the details though
+                print("Reading %s failed, skipping..." % (img_fn))
+
+            # Close file pointers
+            img_fp.close()
+            if self.use_labels:
+                label_fp.close()

     def __iter__(self):
+        if self.verbose:
+            print("Creating a new StreamingGeospatialDataset iterator")
         return iter(self.stream_chips())

-    def __len__(self):
-        if self.infinite:
-            return sys.maxint
-        else:
-            return len(self.fns) * self.num_patches_per_tile
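For reviewers who want to exercise the new class, here is a minimal usage sketch. It is not part of the commit: the tile paths are hypothetical placeholders, and the `ToTensor` transform is just one reasonable choice.

# Usage sketch (not part of the commit). The tile paths below are
# hypothetical; anything rasterio.open() accepts would work.
import torch
from torchvision import transforms

imagery_fns = ["tiles/tile_%d.tif" % i for i in range(8)]        # hypothetical imagery tiles
label_fns = ["tiles/tile_%d_labels.tif" % i for i in range(8)]   # hypothetical label masks

dataset = StreamingGeospatialDataset(
    imagery_fns,
    label_fns=label_fns,
    chip_size=256,
    num_chips_per_tile=200,
    windowed_sampling=True,
    transform=transforms.ToTensor(),  # converts each (H, W, C) chip to a (C, H, W) float tensor
)

# Each DataLoader worker streams chips from a disjoint slice of the tile list.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=4)
imgs, labels = next(iter(dataloader))  # imgs: (32, C, 256, 256); labels: (32, 256, 256)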
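To make the slicing arithmetic in the new `stream_tile_fns` concrete: with N = 10 tiles and num_workers = 4, num_files_per_worker = ceil(10 / 4) = 3, so the workers traverse index ranges [0, 3), [3, 6), [6, 9), and [9, 10). The slices are disjoint and cover every tile, but the last worker may receive fewer tiles (here, one), so workers can exhaust their streams at different times.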
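The commented-out seeding lines point at the numpy-seed pitfall the NOTE describes. A common way to handle it outside the Dataset is a DataLoader `worker_init_fn`; the sketch below is an assumption on my part, not something this commit adds, and the helper name is made up.

# Per-worker numpy seeding sketch (assumption, not part of the commit).
import numpy as np
import torch

def seed_numpy_per_worker(worker_id):
    # torch already gives every DataLoader worker a distinct seed; fold it
    # into numpy so np.random calls (e.g. the chip coordinates) differ too.
    np.random.seed(torch.initial_seed() % 2**32)

# dataloader = torch.utils.data.DataLoader(dataset, num_workers=4,
#                                          worker_init_fn=seed_numpy_per_worker)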
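The new code leaves the nodata check as a TODO. The deleted class suggests one heuristic: discard a chip when more than half of its pixels are zero in every band. A hedged sketch of that idea as a standalone helper, where the function name and the 0.5 threshold are assumptions rather than anything the commit specifies:

# Possible shape for the nodata TODO (assumption; mirrors the deleted
# ClusterTileStreamingDataset heuristic, not code from this commit).
import numpy as np

def mostly_nodata(chip, threshold=0.5):
    """Return True if more than `threshold` of the pixels are 0 in every band."""
    nodata_mask = np.all(chip == 0, axis=2)  # (H, W) boolean mask
    return nodata_mask.mean() > threshold

# Inside the chip loop one could then write:
#     if mostly_nodata(img):
#         continue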