Add files via upload
This commit is contained in:
Родитель
63dd898061
Коммит
8792466be8
|
@ -0,0 +1,186 @@
|
|||
# Copyright (c) Microsoft. All rights reserved.
|
||||
# Authors: Mary Wahl, Kolya Malkin, Nebojsa Jojic
|
||||
#
|
||||
# Licensed under the MIT license. See LICENSE.md file in the project root
|
||||
# for full license information.
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import os, argparse, cntk, tifffile, warnings, osr
|
||||
from osgeo import gdal
|
||||
from gdalconst import *
|
||||
from mpl_toolkits.basemap import Basemap
|
||||
from collections import namedtuple
|
||||
from PIL import Image
|
||||
|
||||
|
||||
# Lookup table from land use label (row index) to the RGB color (columns)
# used when rendering label images; every channel is a float in [0, 1].
color_map = np.array(
    [
        [0.0, 0.0, 0.0],        # 0: No data
        [0.0, 0.0, 1.0],        # 1: Water
        [0.0, 0.5, 0.0],        # 2: Trees and shrubs
        [0.5, 1.0, 0.5],        # 3: Herbaceous vegetation
        [0.5, 0.375, 0.375],    # 4: Barren or Impervious
    ],
    dtype=np.float32)
|
||||
|
||||
|
||||
def load_image_pair(tile_name):
    ''' Read the NAIP image and the LandCover image that share a tile name.

    tile_name is the common path prefix: the files read are
    "<tile_name>_NAIP.tif" and "<tile_name>_LandCover.tif".

    Returns the tuple (naip_image, landcover_image). Both arrays have their
    axes reversed relative to what tifffile returns. NAIP values are divided
    by 256.0, and every label value above 4 is replaced by 4. '''
    naip_path = '{}_NAIP.tif'.format(tile_name)
    landcover_path = '{}_LandCover.tif'.format(tile_name)

    # With the currently-available training data, tifffile reports the
    # following under normal operating conditions, so all warnings are
    # silenced while the two files are read:
    # - RuntimeWarning: py_decodelzw encountered unexpected end of stream
    # - UserWarning: unpack: string size must be a multiple of element size
    # - UserWarning: invalid tile data
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        naip_image = np.transpose(tifffile.imread(naip_path)) / 256.0
        landcover_image = np.transpose(tifffile.imread(landcover_path))

    # Collapse all labels beyond the last category into category 4
    too_large = landcover_image > 4
    landcover_image[too_large] = 4
    return (naip_image, landcover_image)
|
||||
|
||||
|
||||
def find_pixel_from_latlon(img_filename, lat, lon):
    ''' Find the pixel indices of a latitude/longitude in a georeferenced image.

    The point is first projected with a Basemap instance, then transformed
    from that projection into the image's own spatial reference, and finally
    converted to pixel indices using the image's geotransform (upper-left
    corner and per-pixel step).

    Returns the tuple (x, y) of integer indices. '''
    raster = gdal.Open(img_filename, GA_ReadOnly)
    origin_x, step_x, _, origin_y, _, step_y = raster.GetGeoTransform()
    raster_srs = osr.SpatialReference()
    raster_srs.ImportFromWkt(raster.GetProjection())

    # Projection used as the intermediate "world" coordinate system
    globe = Basemap(lat_0=0,
                    lon_0=0,
                    llcrnrlat=-90, urcrnrlat=90,
                    llcrnrlon=-180, urcrnrlon=180,
                    resolution='c', projection='stere')
    globe_srs = osr.SpatialReference()
    globe_srs.ImportFromProj4(globe.proj4string)
    globe_to_raster = osr.CoordinateTransformation(globe_srs, raster_srs)

    # lon/lat -> world projection -> image projection
    world_x, world_y = globe(lon, lat, inverse=False)
    img_x, img_y, _ = globe_to_raster.TransformPoint(world_x, world_y)

    # Offset from the upper-left corner, measured in pixel steps
    x = int((img_x - origin_x) / step_x)
    y = int((img_y - origin_y) / step_y)
    return (x, y)
|
||||
|
||||
|
||||
def save_naip_image(input_image, output_filename):
    ''' Write an image array to a TIFF file after reversing its axes.

    The pixel values are written as-is; no rescaling or dtype conversion is
    applied. '''
    tifffile.imsave(output_filename, np.transpose(input_image))
|
||||
|
||||
|
||||
def save_label_image(input_image, output_filename, hard=True):
    ''' Render per-label scores as an RGB image and save it as a TIFF.

    input_image has shape (num_labels, height, width). With hard=True each
    pixel takes the color of its highest-scoring label; otherwise each pixel
    is the sum of the label colors weighted by that pixel's scores. Colors
    come from color_map. The result is saved as float32 with its axes
    reversed. '''
    num_labels, height, width = input_image.shape

    # One weight plane per label: a 0/1 winner mask (hard) or the raw scores
    if hard:
        winners = input_image.argmax(axis=0)
        planes = [winners == k for k in range(num_labels)]
    else:
        planes = [input_image[k, :, :] for k in range(num_labels)]

    # Accumulate label by label, channel by channel
    rgb = np.zeros((3, height, width))
    for label_idx, plane in enumerate(planes):
        for channel in range(3):
            rgb[channel, :, :] += plane * color_map[label_idx, channel]

    tifffile.imsave(output_filename, np.transpose(rgb).astype(np.float32))
|
||||
|
||||
|
||||
def eval(input_filename, model_filename, output_dir, center_lat, center_lon,
         region_dim):
    ''' Coordinates model evaluation

    Loads the trained model and the NAIP/LandCover image pair, crops a square
    region of interest (ROI) with side length region_dim centered on
    (center_lat, center_lon), labels the ROI tile by tile, and writes three
    files into output_dir: NAIP.tif (the ROI imagery, first three channels),
    pred_labels.tif (predicted labels) and true_labels.tif (true labels).

    Arguments:
        input_filename: path of the NAIP image; the image pair is located by
            stripping the '_NAIP.tif' suffix from this path.
        model_filename: path of the trained CNTK model.
        output_dir: directory that receives the three output files.
        center_lat, center_lon: coordinates of the ROI center.
        region_dim: side length of the ROI in pixels; only whole 128-pixel
            tiles are labeled.

    Raises:
        ValueError: if the ROI plus its padding does not fit inside the
            NAIP image.
    '''
    tile = 128      # side of the square labeled by one model evaluation
    window = 256    # side of the NAIP square fed to the model per tile
    padding = 64    # extra NAIP border kept around the ROI

    model = cntk.load_model(model_filename)
    naip_image, true_lc_image = load_image_pair(
        input_filename.replace('_NAIP.tif', ''))

    delta = int(region_dim / 2)
    center_x, center_y = find_pixel_from_latlon(input_filename, center_lat,
                                                center_lon)

    # Reject regions that run off the image. Without this check a negative
    # slice start would silently wrap around and an overshooting slice end
    # would be truncated, producing crops of the wrong size.
    reach = delta + padding
    extent_x, extent_y = naip_image.shape[1], naip_image.shape[2]
    if (center_x - reach < 0 or center_y - reach < 0
            or center_x + reach > extent_x or center_y + reach > extent_y):
        raise ValueError(
            'The {0} x {0} pixel region centered on pixel ({1}, {2}), plus '
            '{3} pixels of padding, does not fit inside the {4} x {5} pixel '
            'input image.'.format(region_dim, center_x, center_y, padding,
                                  extent_x, extent_y))

    # Crop the input image and its true labels to the ROI. Include padding on
    # the NAIP image so that we have enough info to label the whole ROI.
    true_lc_image = true_lc_image[
        center_x - delta:center_x + delta,
        center_y - delta:center_y + delta].astype(np.float32)
    naip_image = naip_image[
        :,
        center_x - reach:center_x + reach,
        center_y - reach:center_y + reach].astype(np.float32)

    # Iterate over the squares
    n_rows = int(region_dim / tile)  # = n_cols, since the region is square
    pred_lc_image = np.zeros((5, true_lc_image.shape[0],
                              true_lc_image.shape[1]))
    for row_idx in range(n_rows):
        for col_idx in range(n_rows):
            row_start = row_idx * tile
            col_start = col_idx * tile
            # Extract a 256 x 256 region from the NAIP image, to feed into
            # the model; its prediction fills the matching 128 x 128 tile.
            sq_naip = naip_image[:,
                                 row_start:row_start + window,
                                 col_start:col_start + window]
            sq_pred_lc = np.squeeze(
                model.eval({model.arguments[0]: [sq_naip]}))
            pred_lc_image[:,
                          row_start:row_start + tile,
                          col_start:col_start + tile] = sq_pred_lc

    # Save the extracted images in human-viewable form. Will drop the near-
    # infrared channel from the NAIP imagery so that it won't wind up being
    # rendered as transparency. Note that the true labels must be expanded up
    # to one-hot before using the same function to save them.
    save_naip_image(naip_image[:3, padding:-padding, padding:-padding],
                    os.path.join(output_dir, 'NAIP.tif'))
    save_label_image(pred_lc_image,
                     os.path.join(output_dir, 'pred_labels.tif'),
                     hard=True)

    true_one_hot = np.transpose(
        np.eye(5)[true_lc_image.astype(np.int32)], [2, 0, 1])
    save_label_image(true_one_hot,
                     os.path.join(output_dir, 'true_labels.tif'),
                     hard=True)
    return
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Command-line entry point: parse and validate the arguments, create the
    # output directory, then run the evaluation.
    parser = argparse.ArgumentParser(description='''
Applies a trained model to segment a subregion of an input NAIP image
according to land use in five categories:
- 0: No data
- 1: Water
- 2: Trees and shrubs
- 3: Herbaceous vegetation
- 4: Barren or Impervious (roads and other)
Expects the input NAIP image "[filename_base]_NAIP.tif" to be accompanied by
its label image "[filename_base]_LandCover.tif". Writes the cropped imagery
(NAIP.tif), the predicted labels (pred_labels.tif) and the true labels
(true_labels.tif) to the specified output directory.
''')
    parser.add_argument('-i', '--input_filename', type=str, required=True,
                        help='Filepath to the input NAIP image')
    parser.add_argument('-m', '--model_filename', type=str, required=True,
                        help='Filepath to the trained model')
    parser.add_argument('-o', '--output_dir', type=str, required=True,
                        help='Directory where output will be written')
    parser.add_argument('-t', '--center_lat', type=float, required=True,
                        help='The latitude at the center of the ROI')
    parser.add_argument('-n', '--center_lon', type=float, required=True,
                        help='The longitude at the center of the ROI')
    parser.add_argument('-r', '--region_dim', type=int, required=False,
                        default=1024,
                        help='The side length of the ROI in pixels (meters)')
    args = parser.parse_args()

    # Validate with parser.error rather than assert: assert statements are
    # removed when Python runs with -O, which would skip these checks.
    if not os.path.exists(args.input_filename):
        parser.error('Input file {} could not be accessed.'.format(
            args.input_filename))
    if not os.path.exists(args.model_filename):
        parser.error('Model file {} could not be accessed.'.format(
            args.model_filename))
    if args.region_dim <= 0 or args.region_dim % 128 != 0:
        parser.error('Region dimension must be a positive multiple of 128.')
    os.makedirs(args.output_dir, exist_ok=True)

    eval(args.input_filename, args.model_filename, args.output_dir,
         args.center_lat, args.center_lon, args.region_dim)
|
|
@ -0,0 +1,85 @@
|
|||
# Copyright (c) Microsoft. All rights reserved.
|
||||
# Authors: Kolya Malkin, Nebojsa Jojic
|
||||
#
|
||||
# Licensed under the MIT license. See LICENSE.md file in the project root
|
||||
# for full license information.
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
import io
|
||||
import cntk
|
||||
from cntk.layers import *
|
||||
from cntk.initializer import *
|
||||
from cntk.ops import *
|
||||
|
||||
def conv_bn(input, filter_size, num_filters, strides=(1,1), init=uniform(0.00001)):
    ''' Apply a padded, bias-free convolution without activation, followed by
    batch normalization, and return the normalized output. '''
    convolved = Convolution(filter_size, num_filters, activation=None,
                            init=init, pad=True, strides=strides,
                            bias=False)(input)
    return BatchNormalization(map_rank=1, normalization_time_constant=4096,
                              use_cntk_engine=True)(convolved)
|
||||
|
||||
def conv_bn_relu(input, filter_size, num_filters, strides=(1,1), init=he_normal()):
    ''' Same as conv_bn, with a ReLU applied to the normalized output. '''
    return relu(conv_bn(input, filter_size, num_filters, strides, init))
|
||||
|
||||
def resnet_basic(input, num_filters):
    ''' Two 3x3 convolution + batch-norm layers with a ReLU after each.

    No shortcut from the input is added: the output depends only on the
    convolutional path. '''
    hidden = conv_bn_relu(input, (3,3), num_filters)
    return relu(conv_bn(hidden, (3,3), num_filters))
|
||||
|
||||
def resnet_basic_inc(input, num_filters, strides=(2,2)):
    ''' Strided block: a 3x3 convolution + batch-norm + ReLU applied with the
    given strides, then a second 3x3 convolution + batch-norm, then ReLU.

    No shortcut is added to the output. A 1x1 strided projection of the input
    used to be constructed here but was never part of the returned
    expression, so it is no longer built. To enable a projection shortcut,
    add conv_bn(input, (1,1), num_filters, strides) to the second
    convolution's output before the final ReLU. '''
    c1 = conv_bn_relu(input, (3,3), num_filters, strides)
    c2 = conv_bn(c1, (3,3), num_filters)
    return relu(c2)
|
||||
|
||||
def resnet_basic_stack(input, num_stack_layers, num_filters):
    ''' Chain num_stack_layers resnet_basic blocks, feeding each block's
    output into the next. With zero layers the input is returned unchanged. '''
    assert (num_stack_layers >= 0)
    stacked = input
    for _ in range(num_stack_layers):
        stacked = resnet_basic(stacked, num_filters)
    return stacked
|
||||
|
||||
def resu_model(input, num_stack_layers, c_map, num_classes, block_size):
    ''' Build the segmentation network on a four-channel input.

    The network has a downsampling path of four stages (the last three each
    start with a stride-2 block) and an upsampling path that doubles the
    resolution three times with transposed convolutions, merging in the
    matching downsampling stage through 1x1 convolutions at every scale.

    Arguments:
        input: the input node; its axis 0 must hold at least four channels.
        num_stack_layers: number of basic blocks in the first stage; each
            later stage has one strided block plus num_stack_layers - 1
            basic blocks.
        c_map: four filter counts, one per stage.
        num_classes: number of output channels.
        block_size: side length of the (square) input, used to size the
            outputs of the transposed convolutions.

    Returns the output node with num_classes channels (ReLU-activated).
    '''
    r = cntk.slice(input, 0, 0, 1)
    g = cntk.slice(input, 0, 1, 2)
    b = cntk.slice(input, 0, 2, 3)
    i = cntk.slice(input, 0, 3, 4)

    # Subtract each of the first three channels' own mean; the fourth
    # channel is passed through unchanged.
    r -= reduce_mean(r)
    g -= reduce_mean(g)
    b -= reduce_mean(b)

    input_do = splice(splice(splice(r, g, axis=0), b, axis=0), i, axis=0)

    # Downsampling path
    conv = conv_bn(input_do, (3, 3), c_map[0])
    r1 = resnet_basic_stack(conv, num_stack_layers, c_map[0])

    r2_1 = resnet_basic_inc(r1, c_map[1])
    r2_2 = resnet_basic_stack(r2_1, num_stack_layers-1, c_map[1])

    r3_1 = resnet_basic_inc(r2_2, c_map[2])
    r3_2 = resnet_basic_stack(r3_1, num_stack_layers-1, c_map[2])

    r4_1 = resnet_basic_inc(r3_2, c_map[3])
    r4_2 = resnet_basic_stack(r4_1, num_stack_layers-1, c_map[3])

    # Output shapes are tensor dimensions and must be integers; plain "/"
    # would produce floats under Python 3.
    half_size = int(block_size // 2)
    quarter_size = int(block_size // 4)

    # Upsampling path
    r4_us = layers.ConvolutionTranspose((3, 3), c_map[3], strides=2, output_shape=(quarter_size, quarter_size), pad=True, bias=False, init=bilinear(3, 3))(r4_2)

    o3 = relu(layers.Convolution((1, 1), c_map[2])(r3_2) + layers.Convolution((1, 1), c_map[2])(r4_us))
    o3_us = layers.ConvolutionTranspose((3, 3), c_map[2], strides=2, output_shape=(half_size, half_size), pad=True, bias=False, init=bilinear(3, 3))(o3)

    o2 = relu(layers.Convolution((1, 1), c_map[1])(r2_2) + layers.Convolution((1, 1), c_map[1])(o3_us))
    o2_us = layers.ConvolutionTranspose((3, 3), c_map[1], strides=2, output_shape=(block_size, block_size), pad=True, bias=False, init=bilinear(3, 3))(o2)

    # Full-resolution merge of the centered input, the first stage and the
    # upsampled features
    o1 = relu(layers.Convolution((3, 3), c_map[0], pad=True)(input_do) + layers.Convolution((1, 1), c_map[0])(r1) + layers.Convolution((1, 1), c_map[0])(o2_us))

    return layers.Convolution((3, 3), num_classes, pad=True, activation=relu)(o1)
|
||||
|
||||
def model(c_classes, block_size, num_stack_layers, c_map):
    ''' Return a one-argument function that builds the network on an input.

    The returned function forwards its input, together with the settings
    captured here, to resu_model. '''
    def tower(input):
        return resu_model(input,
                          num_stack_layers=num_stack_layers,
                          c_map=c_map,
                          num_classes=c_classes,
                          block_size=block_size)
    return tower
|
|
@ -0,0 +1,313 @@
|
|||
# Copyright (c) Microsoft. All rights reserved.
|
||||
# Authors: Mary Wahl, Kolya Malkin, Nebojsa Jojic
|
||||
#
|
||||
# Licensed under the MIT license. See LICENSE.md file in the project root
|
||||
# for full license information.
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import os, argparse, cntk, tifffile, model_mini_pub, warnings
|
||||
import cntk.train.distributed as distributed
|
||||
from cntk.train.training_session import CheckpointConfig, training_session
|
||||
|
||||
|
||||
def load_image_pair(tile_name):
    ''' Read the NAIP image and the LandCover image that share a tile name.

    tile_name is the common path prefix: the files read are
    "<tile_name>_NAIP.tif" and "<tile_name>_LandCover.tif".

    Returns the tuple (naip_image, landcover_image). Both arrays have their
    axes reversed relative to what tifffile returns. NAIP values are divided
    by 256.0, and every label value above 4 is replaced by 4. '''
    naip_path = '{}_NAIP.tif'.format(tile_name)
    landcover_path = '{}_LandCover.tif'.format(tile_name)

    # With the currently-available training data, tifffile reports the
    # following under normal operating conditions, so all warnings are
    # silenced while the two files are read:
    # - RuntimeWarning: py_decodelzw encountered unexpected end of stream
    # - UserWarning: unpack: string size must be a multiple of element size
    # - UserWarning: invalid tile data
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        naip_image = np.transpose(tifffile.imread(naip_path)) / 256.0
        landcover_image = np.transpose(tifffile.imread(landcover_path))

    # Collapse all labels beyond the last category into category 4
    too_large = landcover_image > 4
    landcover_image[too_large] = 4
    return (naip_image, landcover_image)
|
||||
|
||||
|
||||
def get_cropped_data(image, bounds, rescale=False):
    ''' Crop a rectangular window out of a NAIP or LandCover image.

    bounds is (start_a, start_b, length_a, length_b). With rescale=False the
    window is taken over the first two axes and returned as int32 (label
    images). With rescale=True the image has a leading axis (for color) that
    is kept whole, the window is taken over the next two axes, and the
    result is returned as float32. Despite the parameter's name, values are
    only converted, not scaled. '''
    start_a, start_b, length_a, length_b = bounds
    window = (slice(start_a, start_a + length_a),
              slice(start_b, start_b + length_b))
    if not rescale:
        return image[window].astype(np.int32)
    return image[(slice(None),) + window].astype(np.float32)
|
||||
|
||||
|
||||
def interesting_patch(label_slice):
    ''' Decide whether a label patch is kept, to upsample less common labels.

    A patch is always kept when pixels labeled 1 or 4 together exceed 0.3%
    of the patch; any other patch is kept when a fresh uniform random draw
    exceeds 0.5. The random draw is only made when the first test fails. '''
    w, h = label_slice.shape
    rare_pixels = (label_slice == 1).sum() + (label_slice == 4).sum()
    has_enough_rare = rare_pixels > 0.003 * w * h
    if has_enough_rare:
        return has_enough_rare
    return np.random.random_sample() > 0.5
|
||||
|
||||
|
||||
class MyDataSource(cntk.io.UserMinibatchSource):
    ''' A minibatch source for NAIP and label data

    Each worker is assigned a subset of the tiles found in the input
    directory. On its first request a worker loads all of its tiles into
    memory; afterwards every minibatch is made of randomly positioned square
    crops of one tile, and the source moves on to the worker's next tile
    after minibatches_per_image minibatches. '''

    def __init__(self, f_dim, l_dim, number_of_workers, input_dir,
                 minibatches_per_image):
        ''' Divvy up images between workers at initialization

        Arguments:
            f_dim: feature shape (num_color_channels, block_size, block_size).
            l_dim: label shape (num_landcover_classes, block_size, block_size).
            number_of_workers: number of workers sharing the tiles.
            input_dir: directory scanned for "*_NAIP.tif" files.
            minibatches_per_image: minibatches drawn from a tile before
                switching to the next one.

        Raises:
            ValueError: if input_dir contains no "*_NAIP.tif" file.
        '''
        # Record the image dimensions for later
        self.f_dim, self.l_dim = f_dim, l_dim
        self.minibatches_per_image = minibatches_per_image
        self.num_color_channels, self.block_size, _ = self.f_dim
        self.num_landcover_classes, _, _ = self.l_dim

        # Record the stream information
        self.fsi = cntk.io.StreamInformation(
            'features', 0, 'dense', np.float32, self.f_dim)
        self.lsi = cntk.io.StreamInformation(
            'labels', 1, 'dense', np.float32, self.l_dim)

        # Create a transform for converting labels to one-hot
        self.x = cntk.input_variable((self.block_size, self.block_size))
        self.oh_tf = cntk.one_hot(self.x, self.num_landcover_classes, False,
                                  axis=0)

        # Decide which tiles each worker will process
        self.tile_names = {}
        all_tiles = np.sort(
            [os.path.join(input_dir, i.replace('_NAIP.tif', ''))
             for i in os.listdir(input_dir) if i.endswith('_NAIP.tif')])
        if len(all_tiles) == 0:
            raise ValueError(
                'No "*_NAIP.tif" files were found in {}'.format(input_dir))
        if number_of_workers > len(all_tiles):
            # Fewer tiles than workers: every worker gets two random tiles
            for i in range(number_of_workers):
                self.tile_names[i] = all_tiles[np.random.randint(
                    len(all_tiles), size=(2))]
        else:
            # Split the sorted tiles between workers, at most five each
            for i, tile_subset in enumerate(np.array_split(
                    all_tiles, number_of_workers)):
                if len(tile_subset) > 5:
                    tile_subset = tile_subset[:5]
                self.tile_names[i] = tile_subset

        # Per-worker progress counters
        self.current_mb_indices = {
            rank: 0 for rank in range(number_of_workers)}
        self.current_image_indices = {
            rank: 0 for rank in range(number_of_workers)}
        # One independent list per worker. ([[]] * n would make every entry
        # refer to the same list, mixing the images of all workers.)
        self.naip_images = [[] for _ in range(number_of_workers)]
        self.landcover_images = [[] for _ in range(number_of_workers)]
        self.already_loaded_images = [False] * number_of_workers

        super(MyDataSource, self).__init__()

    def stream_infos(self):
        ''' Return the feature and label stream descriptions. '''
        return [self.fsi, self.lsi]

    def next_minibatch(self, mb_size_in_samples, number_of_workers=1, worker_rank=0,
                       device=None):
        ''' Worker loads TIF images and extracts samples from them

        Returns a dict mapping the feature and label stream descriptions to
        MinibatchData holding mb_size_in_samples random crops of the current
        tile and their one-hot labels.

        Raises:
            RuntimeError: if none of this worker's tiles could be loaded.
        '''
        if not self.already_loaded_images[worker_rank]:
            # It's time to load all images into memory. This can take time, so
            # we log our progress to stdout
            self.already_loaded_images[worker_rank] = True
            for i, tile_name in enumerate(self.tile_names[worker_rank]):
                try:
                    naip_image, landcover_image = load_image_pair(tile_name)
                except ValueError:
                    # Best effort: skip unreadable pairs and carry on
                    print('Failed to load TIF pair: {}'.format(tile_name))
                    continue
                self.naip_images[worker_rank].append(naip_image)
                self.landcover_images[worker_rank].append(landcover_image)
                print('Worker {} loaded its {}th image'.format(
                    worker_rank, i))
            print('Worker {} completed image loading'.format(worker_rank))

        if not self.naip_images[worker_rank]:
            # Fail with a clear message instead of a modulo-by-zero below
            raise RuntimeError(
                'Worker {} has no usable training images.'.format(
                    worker_rank))

        if self.current_mb_indices[worker_rank] == 0:
            # It's time to advance the image index
            self.current_image_indices[worker_rank] = (
                self.current_image_indices[worker_rank] + 1) % len(
                    self.naip_images[worker_rank])
        idx = self.current_image_indices[worker_rank]
        naip_image = self.naip_images[worker_rank][idx]
        landcover_image = self.landcover_images[worker_rank][idx]

        # Feature data have dimensions: num_color_channels x block size
        # x block size
        # Label data have dimensions: block_size x block_size
        features = np.zeros((mb_size_in_samples, self.num_color_channels,
                             self.block_size, self.block_size),
                            dtype=np.float32)
        labels = np.zeros((mb_size_in_samples, self.block_size,
                           self.block_size), dtype=np.float32)

        # Randomly select subsets of the image for training. Each offset is
        # drawn from the extent of the axis it indexes, and the "+ 1" makes
        # the last valid offset reachable (randint's upper bound is
        # exclusive).
        extent_a, extent_b = naip_image.shape[1:]
        offsets_a = extent_a - self.block_size + 1
        offsets_b = extent_b - self.block_size + 1
        samples_retained = 0
        while samples_retained < mb_size_in_samples:
            i = np.random.randint(0, offsets_a)
            j = np.random.randint(0, offsets_b)
            bounds = (i, j, self.block_size, self.block_size)
            label_slice = get_cropped_data(landcover_image, bounds, False)
            if interesting_patch(label_slice):
                features[samples_retained, :, :, :] = get_cropped_data(
                    naip_image, bounds, True)
                labels[samples_retained, :, :] = label_slice
                samples_retained += 1

        # Convert the label data to one-hot, then convert arrays to Values
        f_data = cntk.Value(batch=features)
        l_data = cntk.Value(batch=self.oh_tf.eval({self.x: labels}))

        result = {self.fsi: cntk.io.MinibatchData(
                      f_data, mb_size_in_samples, mb_size_in_samples, False),
                  self.lsi: cntk.io.MinibatchData(
                      l_data, mb_size_in_samples, mb_size_in_samples, False)}

        # Minibatch collection complete: update minibatch index so we know
        # how many more minibatches to collect using this TIFF pair
        self.current_mb_indices[worker_rank] = (
            1 + self.current_mb_indices[worker_rank]
        ) % self.minibatches_per_image
        return(result)
|
||||
|
||||
|
||||
def center_square(output, block_size, padding):
    ''' Keep indices padding .. block_size - padding along axes 1 and 2 of
    output, dropping a border of width padding on each side. '''
    begin, end = padding, block_size - padding
    trimmed_axis_1 = cntk.slice(output, 1, begin, end)
    return(cntk.slice(trimmed_axis_1, 2, begin, end))
|
||||
|
||||
|
||||
def criteria(label, output, block_size, c_classes, weights):
    ''' Define the loss function and metric

    The loss is the mean of the per-class weighted cross-entropy between the
    softmax of output (taken over axis 0) and label. The metric is the
    classification error minus the fraction of the label total that falls in
    class 0. block_size and c_classes are accepted but not used.

    Returns the tuple (loss, metric). '''
    # Unpacking requires label to have exactly three axes
    _, _, _ = label.shape

    probs = cntk.softmax(output, axis=0)
    negative_log_likelihood = -cntk.element_times(cntk.log(probs), label)
    weighted_ce = cntk.times(weights, negative_log_likelihood,
                             output_rank=2)
    mean_ce = cntk.reduce_mean(weighted_ce)

    error_rate = cntk.classification_error(probs, label, axis=0)
    class_0_fraction = (cntk.reduce_sum(cntk.slice(label, 0, 0, 1))
                        / cntk.reduce_sum(label))
    pe = error_rate - class_0_fraction
    return(mean_ce, pe)
|
||||
|
||||
|
||||
def train(input_dir, output_dir, num_epochs):
    ''' Coordinates model creation and training; minibatch creation

    Builds the network, connects it to a MyDataSource reading tiles from
    input_dir, and runs a CNTK training session for num_epochs epochs.
    Checkpoints are written into output_dir once per epoch and the rank-0
    worker saves the final model there as "trained.model". '''
    # Problem dimensions
    num_landcover_classes = 5
    num_color_channels = 4
    block_size = 256
    padding = int(block_size / 4)

    # Distributed setup
    my_rank = distributed.Communicator.rank()
    number_of_workers = distributed.Communicator.num_workers()
    os.makedirs(output_dir, exist_ok=True)

    # 160 minibatches (of 10 samples each) are drawn from an input image
    # before moving along to the next image file. An epoch is 1,600
    # minibatches, i.e. 16,000 samples.
    minibatch_size = 10
    minibatches_per_image = 160
    minibatches_per_epoch = 1600
    epoch_size = minibatch_size * minibatches_per_epoch

    # Input variables
    f_dim = (num_color_channels, block_size, block_size)
    l_dim = (num_landcover_classes, block_size, block_size)
    feature = cntk.input_variable(f_dim, np.float32)
    label = cntk.input_variable(l_dim, np.float32)

    # Minibatch source and its mapping onto the input variables
    minibatch_source = MyDataSource(f_dim, l_dim, number_of_workers,
                                    input_dir, minibatches_per_image)
    input_map = {feature: minibatch_source.streams.features,
                 label: minibatch_source.streams.labels}

    # Model
    build_network = model_mini_pub.model(num_landcover_classes, block_size,
                                         2, [64, 32, 32, 32])
    model = build_network(feature)

    # Loss function and metric. Note that loss is not computed directly on
    # the model's output; the edges are first dropped.
    reshaped_model = cntk.reshape(
        model, (num_landcover_classes, block_size, block_size))
    output = center_square(reshaped_model, block_size, padding)
    label_center = center_square(label, block_size, padding)
    class_weights = [0.0, 1.0, 1.0, 1.0, 1.0]
    mean_ce, pe = criteria(label_center, output, block_size,
                           num_landcover_classes, class_weights)

    # Progress writer
    progress_writers = [cntk.logging.progress_print.ProgressPrinter(
        tag='Training',
        num_epochs=num_epochs,
        freq=epoch_size,
        rank=my_rank)]

    # Learning rate: stepped down after 30 and after 60 epochs
    lr_per_mb = [0.0001] * 30 + [0.00001] * 30 + [0.000001]
    lr_per_sample = [lr / minibatch_size for lr in lr_per_mb]
    lr_schedule = cntk.learning_rate_schedule(lr_per_sample,
                                              epoch_size=epoch_size,
                                              unit=cntk.UnitType.sample)

    # Learner, wrapped in a distributed learner when there are several
    # workers
    learner = cntk.rmsprop(model.parameters, lr_schedule, 0.95, 1.1, 0.9,
                           1.1, 0.9, l2_regularization_weight=0.00001)
    if number_of_workers > 1:
        learner = distributed.data_parallel_distributed_learner(
            learner, num_quantization_bits=32)
    trainer = cntk.Trainer(output, (mean_ce, pe), learner, progress_writers)

    # Perform the training! Note that some progress output will be generated
    # by each of the workers.
    if my_rank == 0:
        print('Retraining model for {} epochs.'.format(num_epochs))
        print('Found {} workers'.format(number_of_workers))
        print('Printing progress every {} minibatches'.format(
            minibatches_per_epoch))
        cntk.logging.progress_print.log_number_of_parameters(model)

    checkpoint_config = CheckpointConfig(
        frequency=epoch_size,
        filename=os.path.join(output_dir, 'trained_checkpoint.model'),
        preserve_all=True)
    session = training_session(
        trainer=trainer,
        max_samples=num_epochs * epoch_size,
        mb_source=minibatch_source,
        mb_size=minibatch_size,
        model_inputs_to_streams=input_map,
        checkpoint_config=checkpoint_config,
        progress_frequency=epoch_size)
    session.train()

    distributed.Communicator.finalize()
    if my_rank == 0:
        trainer.model.save(os.path.join(output_dir, 'trained.model'))
    return
|
||||
|
||||
if __name__ == '__main__':
    # Command-line entry point: parse and validate the arguments, create the
    # model directory, then run the training.
    parser = argparse.ArgumentParser(description='''
Trains a model to segment NAIP images according to land use in five categories:
- 0: No data
- 1: Water
- 2: Trees and shrubs
- 3: Herbaceous vegetation
- 4: Barren or Impervious (roads and other)
Expects an input directory containing pairs of images with naming convention
"[filename_base]_NAIP.tif" and "[filename_base]_LandCover.tif". Outputs the
trained model and training checkpoints to a specified model directory (will
load a checkpoint from this directory if a checkpoint is found there).
''')
    parser.add_argument('-i', '--input_dir', type=str, required=True,
                        help='Directory containing all training image files.')
    parser.add_argument('-o', '--model_dir', type=str, required=True,
                        help='Directory where model outputs will be stored.')
    parser.add_argument('-n', '--num_epochs', type=int, required=False,
                        default=1,
                        help='Specifies the number of epochs of training to ' +
                             'be performed.')
    args = parser.parse_args()

    # Validate with parser.error rather than assert: assert statements are
    # removed when Python runs with -O, which would skip these checks. All
    # checks run before the model directory is created, so invalid
    # arguments leave no directory behind.
    if not os.path.isdir(args.input_dir):
        parser.error('Input directory {} could not be accessed.'.format(
            args.input_dir))
    if args.num_epochs <= 0:
        parser.error('The number of epochs must be greater than zero')
    os.makedirs(args.model_dir, exist_ok=True)

    train(args.input_dir, args.model_dir, args.num_epochs)
|
Загрузка…
Ссылка в новой задаче