Modify the inference script to batch chips when scoring. Code is duplicated between the inference script and ModelSession implementation.

This commit is contained in:
Siyu Yang 2020-12-22 15:06:17 -08:00
Родитель 78b3e0cfeb
Коммит 8aeff2123c
5 изменённых файлов: 216 добавлений и 135 удалений

Просмотреть файл

@ -39,11 +39,11 @@ sh run_training_local.sh
# modify selection of tiles at the top of inference.py
checkpoint_path=/disk/wcs/wcs_coarse_baseline_0_wrong_val_viz/outputs/wcs_coarse_baseline/checkpoints/model_best.pth.tar
checkpoint_path=/home/boto/wcs/mnt/wcs-orinoquia/useful_checkpoints/wcs_coarse_baseline/wcs_coarse_baseline_epoch93_best.pth.tar
out_dir=/home/<username>/wcs/mnt/wcs-orinoquia/delivered/20200715/results_coarse_baseline_201920
out_dir=/home/boto/wcs/mnt/wcs-orinoquia/delivered/20201221_timepoints/2017_2018
python training/inference.py --config_module_path training/experiments/coarse_baseline/coarse_baseline_config_refactored.py --checkpoint_path ${checkpoint_path} --out_dir ${out_dir} --output_softmax
python training_wcs/scripts/inference.py --config_module_path training_wcs/experiments/coarse_baseline/coarse_baseline_config.py --checkpoint_path ${checkpoint_path} --out_dir ${out_dir}
# Post-processing

Просмотреть файл

@ -136,6 +136,66 @@ Export.image.toDrive({
});
/* All months 2015, 2016. Surface reflectance, median composite */
var sr_images = ee.ImageCollection((landsat_8_sr))
.select(['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B10', 'B11', 'pixel_qa'])
// Filter to get only images in the rough region outline
.filterBounds(full_region)
// Filter to get images within the date range of this composite (2015-2016)
.filterDate('2015-01-01', '2016-12-31')
// Sort by scene cloudiness, descending (the second argument `false` means not ascending).
.sort('CLOUD_COVER', false)
// map is only available for ImageCollection; mosaic() or a composite reducer makes into an Image
.map(maskL8sr);
print(sr_images); // we can only export Image, not ImageCollection
var sr_images = sr_images.median();
print(sr_images);
Map.addLayer(sr_images, {bands: ['B4', 'B3', 'B2'], min: 0, max: 3000, gamma: 1.4}, 'L8 SR');
// Export over full region
Export.image.toDrive({
image: sr_images,
description: 'wcs_orinoquia_sr_median_2015_2016',
folder: 'wcs_orinoquia',
scale: 30,
region: full_region,
maxPixels: 651523504
});
/* All months 2017, 2018. Surface reflectance, median composite */
var sr_images = ee.ImageCollection((landsat_8_sr))
.select(['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B10', 'B11', 'pixel_qa'])
// Filter to get only images in the rough region outline
.filterBounds(full_region)
// Filter to get images within the date range of this composite (2017-2018)
.filterDate('2017-01-01', '2018-12-31')
// Sort by scene cloudiness, descending (the second argument `false` means not ascending).
.sort('CLOUD_COVER', false)
// map is only available for ImageCollection; mosaic() or a composite reducer makes into an Image
.map(maskL8sr);
print(sr_images); // we can only export Image, not ImageCollection
var sr_images = sr_images.median();
print(sr_images);
Map.addLayer(sr_images, {bands: ['B4', 'B3', 'B2'], min: 0, max: 3000, gamma: 1.4}, 'L8 SR');
// Export over full region
Export.image.toDrive({
image: sr_images,
description: 'wcs_orinoquia_sr_median_2017_2018',
folder: 'wcs_orinoquia',
scale: 30,
region: full_region,
maxPixels: 651523504
});
/* Elevation at 30m resolution from SRTM */
@ -201,7 +261,6 @@ Export.image.toDrive({
/* Show outlines of regions */
// imports
var trial_region =
/* color: #d63000 */
/* shown: false */

Просмотреть файл

@ -107,7 +107,7 @@ class TorchFineTuningOrinoquia(ModelSession):
# form batches
batch_size = self.config.batch_size if self.config.batch_size is not None else 32
batch = []
batches = []
batch_indices = [] # cache these to save recalculating when filling in model predictions
chip_size = self.config.chip_size
@ -133,21 +133,21 @@ class TorchFineTuningOrinoquia(ModelSession):
sat_mask = chip[0].squeeze() > 0.0 # mask out DEM data where there's no satellite data
chip = chip * sat_mask
batch.append(chip)
batches.append(chip)
valid_row_end = row_start + min(prediction_window_size, height - row_idx * prediction_window_size)
valid_col_end = col_start + min(prediction_window_size, width - col_idx * prediction_window_size)
batch_indices.append(
(row_start, valid_row_end, col_start, valid_col_end)) # explicit to be less confusing
batch = np.array(batch) # (num_chips, channels, height, width)
batches = np.array(batches) # (num_chips, channels, height, width)
# score chips in batches
model_output = []
model_features = [] # same spatial dims as model_output, but has 64 or 32 channels
self.model.eval()
with torch.no_grad():
for i in range(0, len(batch), batch_size):
t_batch = batch[i:i + batch_size]
for i in range(0, len(batches), batch_size):
t_batch = batches[i:i + batch_size]
t_batch = torch.from_numpy(t_batch).to(self.device)
scores, features = self.model.forward(t_batch,

Просмотреть файл

@ -7,6 +7,7 @@ mapped June 29 2020.
import os
import sys
from typing import Union
import numpy as np
import rasterio
@ -315,12 +316,12 @@ def get_chip(tile_reader, chip_window, chip_for_display=True):
return stacked
def preprocess_tile(tile_array: np.ndarray) -> np.ndarray:
def preprocess_tile(tile: Union[rasterio.DatasetReader, np.ndarray]) -> np.ndarray:
"""Same functionality as get_chip(), but applies to a numpy array tile of arbitrary shape.
Currently only used with the landcover tool.
Args:
tile_array: A numpy array of dims (height, width, channels). Expect elevation to be the eleventh channel
tile: A numpy array of dims (height, width, channels) or the rasterio reader.
Expect elevation to be the eleventh channel
Returns:
Numpy array representing the preprocessed chip, of dims (6, height, width) - note that channels is
@ -330,7 +331,7 @@ def preprocess_tile(tile_array: np.ndarray) -> np.ndarray:
for b in normal_bands:
# getting one band at a time because min, max and gamma may be different
band = ImageryVisualizer.show_landsat8_patch(
tile_array,
tile,
bands=[b], # pass in a list to get the batch dimension in the results
band_min=bands_normalization_params['min'][b],
band_max=bands_normalization_params['max'][b],
@ -339,16 +340,16 @@ def preprocess_tile(tile_array: np.ndarray) -> np.ndarray:
)
bands_to_stack.append(band) # band is 2D (h, w), already squeezed, dtype is float 32
ndvi = ImageryVisualizer.get_landsat8_ndvi(tile_array) # 2D, dtype is float32
ndvi = ImageryVisualizer.get_landsat8_ndvi(tile) # 2D, dtype is float32
bands_to_stack.append(ndvi)
# for the interactive tool, elevation is band 11 (1-indexed) or 10 (0-indexed), and already normalized
# by elevation_standardization_params (could have done the normalization here too)
elevation = tile_array[:, :, 10]
elevation = tile[:, :, 10]
bands_to_stack.append(elevation)
stacked = np.stack(bands_to_stack)
assert stacked.shape == (
6, tile_array.shape[0], tile_array.shape[1]), f'preprocess_tile, wrong shape: {stacked.shape}'
6, tile.shape[0], tile.shape[1]), f'preprocess_tile, wrong shape: {stacked.shape}'
return stacked

Просмотреть файл

@ -17,6 +17,7 @@ from datetime import datetime
import numpy as np
import rasterio
import torch
from tqdm import tqdm
# SPECIFY input_tiles as a list of absolute paths to imagery .tif files
@ -26,9 +27,9 @@ import torch
# input_tiles = [os.path.join('/boto_disk_0/wcs_data/tiles/full_sr_median_2013_2014/tiles', i) for i in
# val_tiles]
# all 2019 - 2020 tiles
tiles_dir = '/home/boto/wcs/mnt/wcs-orinoquia/images_sr_median/2019_202004'
input_tiles = [os.path.join(tiles_dir, i) for i in os.listdir(tiles_dir)]
# all tiles for the two-year composite
tiles_dir = '/home/boto/wcs/mnt/wcs-orinoquia/images_sr_median/2013_2014_dem'
input_tiles = [os.path.join(tiles_dir, i) for i in os.listdir(tiles_dir) if i.endswith('.tif')]
def write_colormap(tile_writer, config):
@ -53,6 +54,12 @@ def main():
default='./outputs',
help='Path to a dir to put the prediction tiles.'
)
parser.add_argument(
'--batch_size',
default=32,
type=int,
help='How many chips form one batch during scoring.'
)
parser.add_argument(
'--output_softmax',
action='store_true',
@ -61,6 +68,7 @@ def main():
args = parser.parse_args()
assert os.path.exists(args.checkpoint_path), f'Checkpoint at {args.checkpoint_path} does not exist.'
assert args.batch_size > 0, f'Invalid batch size: {args.batch_size}'
out_dir = args.out_dir
os.makedirs(out_dir, exist_ok=True)
@ -99,164 +107,176 @@ def main():
f"val accuracy is {checkpoint.get('val_acc', 'Not Available')}")
model = model.to(device=device)
model.eval() # eval mode: norm or dropout layers will work in eval mode instead of training_wcs mode
with torch.no_grad(): # with autograd engine deactivated
for i_file, input_tile_path in enumerate(input_tiles):
# for each tile, form batches and score them; after scoring, write the output of each chip in the batch
# to the output geotiff
for i_file, input_tile_path in enumerate(input_tiles):
out_path_hardmax = os.path.join(out_dir, 'res_' + os.path.basename(input_tile_path))
if os.path.exists(out_path_hardmax):
print(f'Skipping already scored tile {out_path_hardmax}')
continue
out_path_hardmax = os.path.join(out_dir, 'res_' + os.path.basename(input_tile_path))
if os.path.exists(out_path_hardmax):
print(f'Skipping already scored tile {out_path_hardmax}')
continue
print(f'Scoring input tile {i_file} out of {len(input_tiles)}, {input_tile_path}')
print(f'Scoring input tile {i_file} out of {len(input_tiles)}, {input_tile_path}')
# dict_scores = {} # dict of window tuple to numpy array of scores
# dict_scores = {} # dict of window tuple to numpy array of scores
# load entire tile into memory
tile_reader = rasterio.open(input_tile_path)
# load entire tile into memory
tile_reader = rasterio.open(input_tile_path)
# use the get_chip function (to normalize the bands in a way that's consistent with training_wcs)
# but get a chip that's the size of the tile - all at once
whole_tile_window = (0, 0, tile_reader.width, tile_reader.height)
data_array: np.ndarray = config.get_chip(tile_reader, whole_tile_window, chip_for_display=False)
# set up hardmax prediction output tile
tile_writer_hardmax = rasterio.open(
out_path_hardmax,
'w',
driver='GTiff',
height=tile_reader.height,
width=tile_reader.width,
count=1, # only 1 "band", the hardmax predicted label at the pixel
dtype=np.uint8,
crs=tile_reader.crs,
transform=tile_reader.transform,
nodata=0,
compress='lzw',
blockxsize=prediction_window_size,
# reads and writes are most efficient when the windows match the datasets own block structure
blockysize=prediction_window_size
)
write_colormap(tile_writer_hardmax, config)
# pad by mirroring at the edges to facilitate predicting only on the center crop
data_array = np.pad(data_array,
[
(0, 0), # only pad height and width
(prediction_window_offset, prediction_window_offset), # height / rows
(prediction_window_offset, prediction_window_offset) # width / cols
],
mode='symmetric')
# set up the softmax output tile
if args.output_softmax:
out_path_softmax = os.path.join(out_dir, 'prob_' + os.path.basename(input_tile_path))
# probabilities projected into RGB for intuitive viewing
out_path_softmax_viz = os.path.join(out_dir, 'prob_viz_' + os.path.basename(input_tile_path))
# set up hardmax prediction output tile
tile_writer_hardmax = rasterio.open(
out_path_hardmax,
tile_writer_softmax = rasterio.open(
out_path_softmax,
'w',
driver='GTiff',
height=tile_reader.height,
width=tile_reader.width,
count=1, # only 1 "band", the hardmax predicted label at the pixel
count=config.num_classes, # as many "bands" as there are classes to house the softmax probabilities
# quantize probabilities scores so each can be stored as one byte instead of 4-byte float32
dtype=np.uint8,
crs=tile_reader.crs,
transform=tile_reader.transform,
nodata=0,
compress='lzw',
blockxsize=prediction_window_size,
blockysize=prediction_window_size
)
tile_writer_softmax_viz = rasterio.open(
out_path_softmax_viz,
'w',
driver='GTiff',
height=tile_reader.height,
width=tile_reader.width,
count=3, # RGB
dtype=np.uint8,
crs=tile_reader.crs,
transform=tile_reader.transform,
nodata=0,
compress='lzw',
blockxsize=prediction_window_size,
# reads and writes are most efficient when the windows match the datasets own block structure
blockysize=prediction_window_size
)
write_colormap(tile_writer_hardmax, config)
# set up the softmax output tile
if args.output_softmax:
out_path_softmax = os.path.join(out_dir, 'prob_' + os.path.basename(input_tile_path))
# probabilities projected into RGB for intuitive viewing
out_path_softmax_viz = os.path.join(out_dir, 'prob_viz_' + os.path.basename(input_tile_path))
# use the get_chip function (to normalize the bands in a way that's consistent with training_wcs)
# but get a chip that's the size of the tile - all at once
whole_tile_window = (0, 0, tile_reader.width, tile_reader.height)
data_array: np.ndarray = config.get_chip(tile_reader, whole_tile_window, chip_for_display=False)
# TODO - could simplify to
# data_array: np.ndarray = config.preprocess_tile(tile_reader)
tile_writer_softmax = rasterio.open(
out_path_softmax,
'w',
driver='GTiff',
height=tile_reader.height,
width=tile_reader.width,
count=config.num_classes, # as many "bands" as there are classes to house the softmax probabilities
# quantize probabilities scores so each can be stored as one byte instead of 4-byte float32
dtype=np.uint8,
crs=tile_reader.crs,
transform=tile_reader.transform,
nodata=0,
compress='lzw',
blockxsize=prediction_window_size,
blockysize=prediction_window_size
)
tile_writer_softmax_viz = rasterio.open(
out_path_softmax_viz,
'w',
driver='GTiff',
height=tile_reader.height,
width=tile_reader.width,
count=3, # RGB
dtype=np.uint8,
crs=tile_reader.crs,
transform=tile_reader.transform,
nodata=0,
compress='lzw',
blockxsize=prediction_window_size,
blockysize=prediction_window_size
)
# pad by mirroring at the edges to facilitate predicting only on the center crop
data_array = np.pad(data_array,
[
(0, 0), # only pad height and width
(prediction_window_offset, prediction_window_offset), # height / rows
(prediction_window_offset, prediction_window_offset) # width / cols
],
mode='symmetric')
# score the tile in windows
num_rows = math.ceil(tile_reader.height / prediction_window_size)
num_cols = math.ceil(tile_reader.width / prediction_window_size)
# form batches
batch_size = args.batch_size
batches = []
batch_indices = [] # cache these to save recalculating when filling in model predictions
for col_idx in range(num_cols):
# score the tile with overlapping sliding windows
num_rows = math.ceil(tile_reader.height / prediction_window_size)
num_cols = math.ceil(tile_reader.width / prediction_window_size)
col_start = col_idx * prediction_window_size
col_end = col_start + chip_size
print('Forming batches...')
for col_idx in tqdm(range(num_cols)):
col_start = col_idx * prediction_window_size
col_end = col_start + chip_size
for row_idx in range(num_rows):
for row_idx in range(num_rows):
row_start = row_idx * prediction_window_size
row_end = row_start + chip_size
row_start = row_idx * prediction_window_size
row_end = row_start + chip_size
chip = data_array[:, row_start:row_end, col_start: col_end]
# pad to (chip_size, chip_size)
chip = np.pad(chip,
[(0, 0), (0, chip_size - chip.shape[1]), (0, chip_size - chip.shape[2])])
chip = data_array[:, row_start:row_end, col_start: col_end]
# processing it as the dataset loader _get_chip does
chip = np.nan_to_num(chip, nan=0.0, posinf=1.0, neginf=-1.0)
sat_mask = chip[0].squeeze() > 0.0 # mask out DEM data where there's no satellite data
chip = chip * sat_mask
# pad to (chip_size, chip_size)
chip = np.pad(chip,
[(0, 0), (0, chip_size - chip.shape[1]), (0, chip_size - chip.shape[2])])
batches.append(chip)
# processing it as the dataset loader _get_chip does
chip = np.nan_to_num(chip, nan=0.0, posinf=1.0, neginf=-1.0)
sat_mask = chip[0].squeeze() > 0.0 # mask out DEM data where there's no satellite data
chip = chip * sat_mask
valid_height = min(prediction_window_size, tile_reader.height - row_idx * prediction_window_size)
valid_width = min(prediction_window_size, tile_reader.width - col_idx * prediction_window_size)
batch_indices.append(
(row_start, valid_height, col_start, valid_width)) # explicit to be less confusing
chip = np.expand_dims(chip, axis=0)
chip = torch.FloatTensor(chip).to(device=device)
batches = np.array(batches)
try:
scores = model(chip) # these are scores before the final softmax
except Exception as e:
print(f'Exception in scoring loop model() application: {e}')
print(f'Chip has shape {chip.shape}')
sys.exit(1)
# score chips in batches and write the results to the output geotiff
print('Scoring batches...')
model.eval()
with torch.no_grad():
for i in tqdm(range(0, len(batches), batch_size)):
t_batch = batches[i:i + batch_size]
t_batch = torch.from_numpy(t_batch).to(device=device, dtype=config.dtype)
_, preds = scores.max(1)
softmax_scores = torch.nn.functional.softmax(scores, dim=1)
try:
scores = model(t_batch) # these are scores before the final softmax
except Exception as e:
print(f'Exception in scoring loop model() application: {e}')
print(f't_batch has shape {t_batch.shape}')
sys.exit(1)
softmax_scores = softmax_scores.cpu().numpy() # (batch_size, num_classes, H, W)
preds = preds.cpu().numpy().astype(np.uint8)
_, preds = scores.max(1)
preds = preds.cpu().numpy().astype(np.uint8)
assert np.max(preds) < config.num_classes
assert np.max(preds) < config.num_classes
softmax_scores = torch.nn.functional.softmax(scores,
dim=1).cpu().numpy() # (batch_size, num_classes, H, W)
# write the output one chip at a time
cur_batch_indices = batch_indices[i:i + batch_size]
for chip_i, (row_start, valid_height, col_start, valid_width) in enumerate(cur_batch_indices):
# model output needs to be cropped to the window so they can be written correctly into the tile
# same order as rasterio window: (col_off x, row_off y, width delta_x, height delta_y)
valid_window_tup = (
col_start,
row_start,
min(prediction_window_size, tile_reader.width - col_idx * prediction_window_size),
min(prediction_window_size, tile_reader.height - row_idx * prediction_window_size)
)
# preds has a batch dim here
preds1 = preds[:,
prediction_window_offset:prediction_window_offset + valid_window_tup[3],
prediction_window_offset:prediction_window_offset + valid_window_tup[
2]] # last dim is the inner most, x, width
# preds has a batch dim here; last dim is the inner most, x, width
preds_chip = preds[chip_i,
prediction_window_offset:prediction_window_offset + valid_height,
prediction_window_offset:prediction_window_offset + valid_width]
preds_chip = np.expand_dims(preds_chip, axis=0)
# debug - print(f'col is {col_idx}, row is {row_idx}, valid_window_tup is {valid_window_tup}, preds shape: {preds.shape}, preds1 shape: {preds1.shape}')
window = rasterio.windows.Window(valid_window_tup[0], valid_window_tup[1],
valid_window_tup[2], valid_window_tup[3])
tile_writer_hardmax.write(preds1, window=window)
window = rasterio.windows.Window(col_start, row_start,
valid_width, valid_height)
# print(f'window is {window}')
tile_writer_hardmax.write(preds_chip, window=window)
if args.output_softmax:
# cropping, e.g. from (1, 14, 256, 256) to (1, 14, 128, 128)
softmax_scores = softmax_scores[:, :,
prediction_window_offset:prediction_window_offset + valid_window_tup[3],
prediction_window_offset:prediction_window_offset + valid_window_tup[2]]
prediction_window_offset:prediction_window_offset + valid_height,
prediction_window_offset:prediction_window_offset + valid_width]
# get rid of batch dim. First dim for TIFF writer needs to be number of bands to write
softmax_scores = softmax_scores.squeeze()
@ -278,7 +298,8 @@ def main():
tile_writer_softmax_viz.write(softmax_scores_proj, window=window)
tile_writer_hardmax.close()
tile_writer_hardmax.close()
if args.output_softmax:
tile_writer_softmax.close()
tile_writer_softmax_viz.close()