Modify the inference script to batch chips when scoring. Code is duplicated between the inference script and ModelSession implementation.
This commit is contained in:
Parent 78b3e0cfeb
Commit 8aeff2123c
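For reference, below is a minimal sketch of the batched chip-scoring pattern this commit introduces in both the inference script and the ModelSession implementation. The helper function, its name, and the simplified chip handling are assumptions for illustration only, not code from the repository; the actual diff follows.

# Minimal sketch (assumed helper, not repository code): score pre-cut chips in fixed-size batches.
import numpy as np
import torch

def score_chips_in_batches(model, chips, batch_size=32, device='cpu'):
    # chips: list of (channels, chip_size, chip_size) numpy arrays cut from a padded tile
    batches = np.array(chips)  # (num_chips, channels, H, W)
    model.eval()
    hardmax_outputs = []
    with torch.no_grad():
        for i in range(0, len(batches), batch_size):
            t_batch = torch.from_numpy(batches[i:i + batch_size]).float().to(device)
            scores = model(t_batch)            # raw scores before the final softmax
            _, preds = scores.max(1)           # per-pixel hardmax class
            hardmax_outputs.append(preds.cpu().numpy().astype(np.uint8))
    return np.concatenate(hardmax_outputs, axis=0)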
@@ -39,11 +39,11 @@ sh run_training_local.sh

# modify selection of tiles at the top of inference.py

checkpoint_path=/disk/wcs/wcs_coarse_baseline_0_wrong_val_viz/outputs/wcs_coarse_baseline/checkpoints/model_best.pth.tar
checkpoint_path=/home/boto/wcs/mnt/wcs-orinoquia/useful_checkpoints/wcs_coarse_baseline/wcs_coarse_baseline_epoch93_best.pth.tar

out_dir=/home/<username>/wcs/mnt/wcs-orinoquia/delivered/20200715/results_coarse_baseline_201920
out_dir=/home/boto/wcs/mnt/wcs-orinoquia/delivered/20201221_timepoints/2017_2018

python training/inference.py --config_module_path training/experiments/coarse_baseline/coarse_baseline_config_refactored.py --checkpoint_path ${checkpoint_path} --out_dir ${out_dir} --output_softmax
python training_wcs/scripts/inference.py --config_module_path training_wcs/experiments/coarse_baseline/coarse_baseline_config.py --checkpoint_path ${checkpoint_path} --out_dir ${out_dir}


# Post-processing
@@ -136,6 +136,66 @@ Export.image.toDrive({
});


/* All months 2015, 2016. Surface reflectance, median composite */

var sr_images = ee.ImageCollection((landsat_8_sr))
.select(['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B10', 'B11', 'pixel_qa'])
// Filter to get only images in the rough region outline
.filterBounds(full_region)
// Filter to get images from the 2015-2016 period
.filterDate('2015-01-01', '2016-12-31')
// Sort by scene cloudiness, descending.
.sort('CLOUD_COVER', false)
// map is only available for ImageCollection; mosaic() or a composite reducer makes into an Image
.map(maskL8sr);

print(sr_images); // we can only export Image, not ImageCollection

var sr_images = sr_images.median();
print(sr_images);

Map.addLayer(sr_images, {bands: ['B4', 'B3', 'B2'], min: 0, max: 3000, gamma: 1.4}, 'L8 SR');

// Export over full region
Export.image.toDrive({
image: sr_images,
description: 'wcs_orinoquia_sr_median_2015_2016',
folder: 'wcs_orinoquia',
scale: 30,
region: full_region,
maxPixels: 651523504
});


/* All months 2017, 2018. Surface reflectance, median composite */

var sr_images = ee.ImageCollection((landsat_8_sr))
.select(['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B10', 'B11', 'pixel_qa'])
// Filter to get only images in the rough region outline
.filterBounds(full_region)
// Filter to get images from the 2017-2018 period
.filterDate('2017-01-01', '2018-12-31')
// Sort by scene cloudiness, descending.
.sort('CLOUD_COVER', false)
// map is only available for ImageCollection; mosaic() or a composite reducer makes into an Image
.map(maskL8sr);

print(sr_images); // we can only export Image, not ImageCollection

var sr_images = sr_images.median();
print(sr_images);

Map.addLayer(sr_images, {bands: ['B4', 'B3', 'B2'], min: 0, max: 3000, gamma: 1.4}, 'L8 SR');

// Export over full region
Export.image.toDrive({
image: sr_images,
description: 'wcs_orinoquia_sr_median_2017_2018',
folder: 'wcs_orinoquia',
scale: 30,
region: full_region,
maxPixels: 651523504
});


/* Elevation at 30m resolution from SRTM */
@@ -201,7 +261,6 @@ Export.image.toDrive({

/* Show outlines of regions */

// imports
var trial_region =
/* color: #d63000 */
/* shown: false */
@@ -107,7 +107,7 @@ class TorchFineTuningOrinoquia(ModelSession):

# form batches
batch_size = self.config.batch_size if self.config.batch_size is not None else 32
batch = []
batches = []
batch_indices = [] # cache these to save recalculating when filling in model predictions

chip_size = self.config.chip_size
@@ -133,21 +133,21 @@ class TorchFineTuningOrinoquia(ModelSession):
sat_mask = chip[0].squeeze() > 0.0 # mask out DEM data where there's no satellite data
chip = chip * sat_mask

batch.append(chip)
batches.append(chip)

valid_row_end = row_start + min(prediction_window_size, height - row_idx * prediction_window_size)
valid_col_end = col_start + min(prediction_window_size, width - col_idx * prediction_window_size)
batch_indices.append(
(row_start, valid_row_end, col_start, valid_col_end)) # explicit to be less confusing
batch = np.array(batch) # (num_chips, channels, height, width)
batches = np.array(batches) # (num_chips, channels, height, width)

# score chips in batches
model_output = []
model_features = [] # same spatial dims as model_output, but has 64 or 32 channels
self.model.eval()
with torch.no_grad():
for i in range(0, len(batch), batch_size):
t_batch = batch[i:i + batch_size]
for i in range(0, len(batches), batch_size):
t_batch = batches[i:i + batch_size]
t_batch = torch.from_numpy(t_batch).to(self.device)

scores, features = self.model.forward(t_batch,
@@ -7,6 +7,7 @@ mapped June 29 2020.

import os
import sys
from typing import Union

import numpy as np
import rasterio
@@ -315,12 +316,12 @@ def get_chip(tile_reader, chip_window, chip_for_display=True):
return stacked


def preprocess_tile(tile_array: np.ndarray) -> np.ndarray:
def preprocess_tile(tile: Union[rasterio.DatasetReader, np.ndarray]) -> np.ndarray:
"""Same functionality as get_chip(), but applies to a numpy array tile of arbitrary shape.
Currently only used with the landcover tool.

Args:
tile_array: A numpy array of dims (height, width, channels). Expect elevation to be the eleventh channel
tile: A numpy array of dims (height, width, channels) or the rasterio reader.
Expect elevation to be the eleventh channel

Returns:
Numpy array representation of the preprocessed chip of dims (6, height, width) - note that channels is
@@ -330,7 +331,7 @@ def preprocess_tile(tile_array: np.ndarray) -> np.ndarray:
for b in normal_bands:
# getting one band at a time because min, max and gamma may be different
band = ImageryVisualizer.show_landsat8_patch(
tile_array,
tile,
bands=[b], # pass in a list to get the batch dimension in the results
band_min=bands_normalization_params['min'][b],
band_max=bands_normalization_params['max'][b],
@@ -339,16 +340,16 @@ def preprocess_tile(tile_array: np.ndarray) -> np.ndarray:
)
bands_to_stack.append(band) # band is 2D (h, w), already squeezed, dtype is float 32

ndvi = ImageryVisualizer.get_landsat8_ndvi(tile_array) # 2D, dtype is float32
ndvi = ImageryVisualizer.get_landsat8_ndvi(tile) # 2D, dtype is float32
bands_to_stack.append(ndvi)

# for the interactive tool, elevation is band 11 (1-indexed) or 10 (0-indexed), and already normalized
# by elevation_standardization_params (could have done the normalization here too)
elevation = tile_array[:, :, 10]
elevation = tile[:, :, 10]
bands_to_stack.append(elevation)

stacked = np.stack(bands_to_stack)

assert stacked.shape == (
6, tile_array.shape[0], tile_array.shape[1]), f'preprocess_tile, wrong shape: {stacked.shape}'
6, tile.shape[0], tile.shape[1]), f'preprocess_tile, wrong shape: {stacked.shape}'
return stacked
@@ -17,6 +17,7 @@ from datetime import datetime
import numpy as np
import rasterio
import torch
from tqdm import tqdm

# SPECIFY input_tiles as a list of absolute paths to imagery .tif files
@@ -26,9 +27,9 @@ import torch
# input_tiles = [os.path.join('/boto_disk_0/wcs_data/tiles/full_sr_median_2013_2014/tiles', i) for i in
# val_tiles]

# all 2019 - 2020 tiles
tiles_dir = '/home/boto/wcs/mnt/wcs-orinoquia/images_sr_median/2019_202004'
input_tiles = [os.path.join(tiles_dir, i) for i in os.listdir(tiles_dir)]
# all tiles for the two-year composite
tiles_dir = '/home/boto/wcs/mnt/wcs-orinoquia/images_sr_median/2013_2014_dem'
input_tiles = [os.path.join(tiles_dir, i) for i in os.listdir(tiles_dir) if i.endswith('.tif')]


def write_colormap(tile_writer, config):
@@ -53,6 +54,12 @@ def main():
default='./outputs',
help='Path to a dir to put the prediction tiles.'
)
parser.add_argument(
'--batch_size',
default=32,
type=int,
help='How many chips form one batch during scoring.'
)
parser.add_argument(
'--output_softmax',
action='store_true',
@@ -61,6 +68,7 @@ def main():
args = parser.parse_args()

assert os.path.exists(args.checkpoint_path), f'Checkpoint at {args.checkpoint_path} does not exist.'
assert args.batch_size > 0, f'Invalid batch size: {args.batch_size}'

out_dir = args.out_dir
os.makedirs(out_dir, exist_ok=True)
@@ -99,164 +107,176 @@ def main():
f"val accuracy is {checkpoint.get('val_acc', 'Not Available')}")

model = model.to(device=device)
model.eval() # eval mode: norm or dropout layers will work in eval mode instead of training_wcs mode

with torch.no_grad(): # with autograd engine deactivated
for i_file, input_tile_path in enumerate(input_tiles):
# for each tile, form batches and score them; after scoring, write the output of each chip in the batch
# to the output geotiff
for i_file, input_tile_path in enumerate(input_tiles):

out_path_hardmax = os.path.join(out_dir, 'res_' + os.path.basename(input_tile_path))
if os.path.exists(out_path_hardmax):
print(f'Skipping already scored tile {out_path_hardmax}')
continue
out_path_hardmax = os.path.join(out_dir, 'res_' + os.path.basename(input_tile_path))
if os.path.exists(out_path_hardmax):
print(f'Skipping already scored tile {out_path_hardmax}')
continue

print(f'Scoring input tile {i_file} out of {len(input_tiles)}, {input_tile_path}')
print(f'Scoring input tile {i_file} out of {len(input_tiles)}, {input_tile_path}')

# dict_scores = {} # dict of window tuple to numpy array of scores
# dict_scores = {} # dict of window tuple to numpy array of scores

# load entire tile into memory
tile_reader = rasterio.open(input_tile_path)
# load entire tile into memory
tile_reader = rasterio.open(input_tile_path)

# use the get_chip function (to normalize the bands in a way that's consistent with training_wcs)
# but get a chip that's the size of the tile - all at once
whole_tile_window = (0, 0, tile_reader.width, tile_reader.height)
data_array: np.ndarray = config.get_chip(tile_reader, whole_tile_window, chip_for_display=False)
# set up hardmax prediction output tile
tile_writer_hardmax = rasterio.open(
out_path_hardmax,
'w',
driver='GTiff',
height=tile_reader.height,
width=tile_reader.width,
count=1, # only 1 "band", the hardmax predicted label at the pixel
dtype=np.uint8,
crs=tile_reader.crs,
transform=tile_reader.transform,
nodata=0,
compress='lzw',
blockxsize=prediction_window_size,
# reads and writes are most efficient when the windows match the dataset’s own block structure
blockysize=prediction_window_size
)
write_colormap(tile_writer_hardmax, config)

# pad by mirroring at the edges to facilitate predicting only on the center crop
data_array = np.pad(data_array,
[
(0, 0), # only pad height and width
(prediction_window_offset, prediction_window_offset), # height / rows
(prediction_window_offset, prediction_window_offset) # width / cols
],
mode='symmetric')
# set up the softmax output tile
if args.output_softmax:
out_path_softmax = os.path.join(out_dir, 'prob_' + os.path.basename(input_tile_path))
# probabilities projected into RGB for intuitive viewing
out_path_softmax_viz = os.path.join(out_dir, 'prob_viz_' + os.path.basename(input_tile_path))

# set up hardmax prediction output tile
tile_writer_hardmax = rasterio.open(
out_path_hardmax,
tile_writer_softmax = rasterio.open(
out_path_softmax,
'w',
driver='GTiff',
height=tile_reader.height,
width=tile_reader.width,
count=1, # only 1 "band", the hardmax predicted label at the pixel
count=config.num_classes, # as many "bands" as there are classes to house the softmax probabilities
# quantize probabilities scores so each can be stored as one byte instead of 4-byte float32
dtype=np.uint8,
crs=tile_reader.crs,
transform=tile_reader.transform,
nodata=0,
compress='lzw',
blockxsize=prediction_window_size,
blockysize=prediction_window_size
)
tile_writer_softmax_viz = rasterio.open(
out_path_softmax_viz,
'w',
driver='GTiff',
height=tile_reader.height,
width=tile_reader.width,
count=3, # RGB
dtype=np.uint8,
crs=tile_reader.crs,
transform=tile_reader.transform,
nodata=0,
compress='lzw',
blockxsize=prediction_window_size,
# reads and writes are most efficient when the windows match the dataset’s own block structure
blockysize=prediction_window_size
)
write_colormap(tile_writer_hardmax, config)

# set up the softmax output tile
if args.output_softmax:
out_path_softmax = os.path.join(out_dir, 'prob_' + os.path.basename(input_tile_path))
# probabilities projected into RGB for intuitive viewing
out_path_softmax_viz = os.path.join(out_dir, 'prob_viz_' + os.path.basename(input_tile_path))
# use the get_chip function (to normalize the bands in a way that's consistent with training_wcs)
# but get a chip that's the size of the tile - all at once
whole_tile_window = (0, 0, tile_reader.width, tile_reader.height)
data_array: np.ndarray = config.get_chip(tile_reader, whole_tile_window, chip_for_display=False)
# TODO - could simplify to
# data_array: np.ndarray = config.preprocess_tile(tile_reader)

tile_writer_softmax = rasterio.open(
out_path_softmax,
'w',
driver='GTiff',
height=tile_reader.height,
width=tile_reader.width,
count=config.num_classes, # as many "bands" as there are classes to house the softmax probabilities
# quantize probabilities scores so each can be stored as one byte instead of 4-byte float32
dtype=np.uint8,
crs=tile_reader.crs,
transform=tile_reader.transform,
nodata=0,
compress='lzw',
blockxsize=prediction_window_size,
blockysize=prediction_window_size
)
tile_writer_softmax_viz = rasterio.open(
out_path_softmax_viz,
'w',
driver='GTiff',
height=tile_reader.height,
width=tile_reader.width,
count=3, # RGB
dtype=np.uint8,
crs=tile_reader.crs,
transform=tile_reader.transform,
nodata=0,
compress='lzw',
blockxsize=prediction_window_size,
blockysize=prediction_window_size
)
# pad by mirroring at the edges to facilitate predicting only on the center crop
data_array = np.pad(data_array,
[
(0, 0), # only pad height and width
(prediction_window_offset, prediction_window_offset), # height / rows
(prediction_window_offset, prediction_window_offset) # width / cols
],
mode='symmetric')

# score the tile in windows
num_rows = math.ceil(tile_reader.height / prediction_window_size)
num_cols = math.ceil(tile_reader.width / prediction_window_size)
# form batches
batch_size = args.batch_size
batches = []
batch_indices = [] # cache these to save recalculating when filling in model predictions

for col_idx in range(num_cols):
# score the tile with overlapping sliding windows
num_rows = math.ceil(tile_reader.height / prediction_window_size)
num_cols = math.ceil(tile_reader.width / prediction_window_size)

col_start = col_idx * prediction_window_size
col_end = col_start + chip_size
print('Forming batches...')
for col_idx in tqdm(range(num_cols)):
col_start = col_idx * prediction_window_size
col_end = col_start + chip_size

for row_idx in range(num_rows):
for row_idx in range(num_rows):
row_start = row_idx * prediction_window_size
row_end = row_start + chip_size

row_start = row_idx * prediction_window_size
row_end = row_start + chip_size
chip = data_array[:, row_start:row_end, col_start: col_end]
# pad to (chip_size, chip_size)
chip = np.pad(chip,
[(0, 0), (0, chip_size - chip.shape[1]), (0, chip_size - chip.shape[2])])

chip = data_array[:, row_start:row_end, col_start: col_end]
# processing it as the dataset loader _get_chip does
chip = np.nan_to_num(chip, nan=0.0, posinf=1.0, neginf=-1.0)
sat_mask = chip[0].squeeze() > 0.0 # mask out DEM data where there's no satellite data
chip = chip * sat_mask

# pad to (chip_size, chip_size)
chip = np.pad(chip,
[(0, 0), (0, chip_size - chip.shape[1]), (0, chip_size - chip.shape[2])])
batches.append(chip)

# processing it as the dataset loader _get_chip does
chip = np.nan_to_num(chip, nan=0.0, posinf=1.0, neginf=-1.0)
sat_mask = chip[0].squeeze() > 0.0 # mask out DEM data where there's no satellite data
chip = chip * sat_mask
valid_height = min(prediction_window_size, tile_reader.height - row_idx * prediction_window_size)
valid_width = min(prediction_window_size, tile_reader.width - col_idx * prediction_window_size)
batch_indices.append(
(row_start, valid_height, col_start, valid_width)) # explicit to be less confusing

chip = np.expand_dims(chip, axis=0)
chip = torch.FloatTensor(chip).to(device=device)
batches = np.array(batches)

try:
scores = model(chip) # these are scores before the final softmax
except Exception as e:
print(f'Exception in scoring loop model() application: {e}')
print(f'Chip has shape {chip.shape}')
sys.exit(1)
# score chips in batches and write the results to the output geotiff
print('Scoring batches...')
model.eval()
with torch.no_grad():
for i in tqdm(range(0, len(batches), batch_size)):
t_batch = batches[i:i + batch_size]
t_batch = torch.from_numpy(t_batch).to(device=device, dtype=config.dtype)

_, preds = scores.max(1)
softmax_scores = torch.nn.functional.softmax(scores, dim=1)
try:
scores = model(t_batch) # these are scores before the final softmax
except Exception as e:
print(f'Exception in scoring loop model() application: {e}')
print(f't_batch has shape {t_batch.shape}')
sys.exit(1)

softmax_scores = softmax_scores.cpu().numpy() # (batch_size, num_classes, H, W)
preds = preds.cpu().numpy().astype(np.uint8)
_, preds = scores.max(1)
preds = preds.cpu().numpy().astype(np.uint8)
assert np.max(preds) < config.num_classes

assert np.max(preds) < config.num_classes
softmax_scores = torch.nn.functional.softmax(scores,
dim=1).cpu().numpy() # (batch_size, num_classes, H, W)

# write the output one chip at a time
cur_batch_indices = batch_indices[i:i + batch_size]
for chip_i, (row_start, valid_height, col_start, valid_width) in enumerate(cur_batch_indices):
# model output needs to be cropped to the window so they can be written correctly into the tile

# same order as rasterio window: (col_off x, row_off y, width delta_x, height delta_y)
valid_window_tup = (
col_start,
row_start,
min(prediction_window_size, tile_reader.width - col_idx * prediction_window_size),
min(prediction_window_size, tile_reader.height - row_idx * prediction_window_size)
)

# preds has a batch dim here
preds1 = preds[:,
prediction_window_offset:prediction_window_offset + valid_window_tup[3],
prediction_window_offset:prediction_window_offset + valid_window_tup[
2]] # last dim is the inner most, x, width

# preds has a batch dim here; last dim is the inner most, x, width
preds_chip = preds[chip_i,
prediction_window_offset:prediction_window_offset + valid_height,
prediction_window_offset:prediction_window_offset + valid_width]
preds_chip = np.expand_dims(preds_chip, axis=0)
# debug - print(f'col is {col_idx}, row is {row_idx}, valid_window_tup is {valid_window_tup}, preds shape: {preds.shape}, preds1 shape: {preds1.shape}')

window = rasterio.windows.Window(valid_window_tup[0], valid_window_tup[1],
valid_window_tup[2], valid_window_tup[3])
tile_writer_hardmax.write(preds1, window=window)
window = rasterio.windows.Window(col_start, row_start,
valid_width, valid_height)
# print(f'window is {window}')
tile_writer_hardmax.write(preds_chip, window=window)

if args.output_softmax:
# cropping, e.g. from (1, 14, 256, 256) to (1, 14, 128, 128)
softmax_scores = softmax_scores[:, :,
prediction_window_offset:prediction_window_offset + valid_window_tup[3],
prediction_window_offset:prediction_window_offset + valid_window_tup[2]]
prediction_window_offset:prediction_window_offset + valid_height,
prediction_window_offset:prediction_window_offset + valid_width]

# get rid of batch dim. First dim for TIFF writer needs to be number of bands to write
softmax_scores = softmax_scores.squeeze()
@@ -278,7 +298,8 @@ def main():

tile_writer_softmax_viz.write(softmax_scores_proj, window=window)

tile_writer_hardmax.close()
tile_writer_hardmax.close()
if args.output_softmax:
tile_writer_softmax.close()
tile_writer_softmax_viz.close()