Add finetuning and training code.

This commit is contained in:
Siyu Yang 2020-09-28 13:33:11 -07:00
Parent 19cd9dbb4c
Commit f19568dc53
11 changed files with 2435 additions and 0 deletions

3
.gitignore vendored
View file

@@ -130,3 +130,6 @@ dmypy.json
# Mac file system
.DS_Store
# IDE / PyCharm
.idea/

545
evaluation/evaluate.ipynb Normal file
View file

@@ -0,0 +1,545 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from IPython.core.interactiveshell import InteractiveShell\n",
"InteractiveShell.ast_node_interactivity = 'all' # default is last_expr\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('/Source/repos/GitHub_MSFT/landcover-orinoquiaa')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"import pickle\n",
"from collections import defaultdict\n",
"\n",
"import rasterio\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"from sklearn.metrics import confusion_matrix, accuracy_score\n",
"from sklearn.calibration import calibration_curve\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"\n",
"from geospatial.visualization.raster_label_visualizer import RasterLabelVisualizer\n",
"\n",
"plt.rcParams['figure.figsize'] = (10.0, 10.0)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# from data/tile_and_mask.py - which needs to be run in the Solaris env\n",
"\n",
"def get_lon_lat_from_tile_name(tile_name):\n",
" \"\"\"Returns _lon_lat\"\"\"\n",
" parts = tile_name.split('_')\n",
" lon_lat = f'_{parts[-2]}_{parts[-1].split(\".tif\")[0]}'\n",
" return lon_lat"
]
},
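{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# quick check of the helper above; tile name taken from the commented-out example in the evaluation loop below\n",
"get_lon_lat_from_tile_name('res_wcs_orinoquia_sr_median_2013_2014-0000000000-0000022272_-68.962_6.593.tif')"
]
},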
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Evaluate a tiles of model predictions"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"viz_util= RasterLabelVisualizer('../constants/class_lists/wcs_coarse_label_map.json')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Empty of data',\n",
" 'Urban and infrastructure',\n",
" 'Agriculture',\n",
" 'Arboreal and forestry crops',\n",
" 'Pasture',\n",
" 'Vegetation',\n",
" 'Forest',\n",
" 'Savanna',\n",
" 'Sand, rocks and bare land',\n",
" 'Unavailable',\n",
" 'Swamp',\n",
" 'Water',\n",
" 'Seasonal savanna',\n",
" 'Seasonally flooded savanna']"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"label_names = sorted(viz_util.num_to_name.items(), key=lambda x: int(x[0]))\n",
"label_names = [i[1] for i in label_names]\n",
"label_names"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"output_paths = '/Data/WCS_land_use/delivered/20200701/results_coarse_baseline_201314'\n",
"\n",
"mask_paths = '/Data/WCS_land_use/train_full_region_median/tiles_masks_coarse'\n",
"\n",
"eval_saved_to = '/Data/WCS_land_use/train_full_region_median/result_val_analysis_coarse_baseline'\n",
"\n",
"num_classes = viz_util.num_classes"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tile_accuracies = {}\n",
"\n",
"cm = np.zeros((num_classes, num_classes), dtype=np.int64)\n",
"\n",
"true_counts = np.zeros((num_classes), dtype=np.int64)\n",
"pred_counts = np.zeros((num_classes), dtype=np.int64)\n",
"\n",
"classes_present_in_gt = set()\n",
"\n",
"for output_tile_fn in os.listdir(output_paths):\n",
" if not output_tile_fn.endswith('.tif'):\n",
" continue\n",
"# for output_tile_fn in ['res_wcs_orinoquia_sr_median_2013_2014-0000000000-0000022272_-68.962_6.593.tif']:\n",
" \n",
" output_tile_path = os.path.join(output_paths, output_tile_fn)\n",
" out_reader = rasterio.open(output_tile_path)\n",
" output_tile = np.array(Image.open(output_tile_path), dtype=np.uint8)\n",
" \n",
" # mask_-68.423_6.054.png\n",
" lon_lat = get_lon_lat_from_tile_name(output_tile_path)\n",
" label_mask_path = os.path.join(mask_paths, f'mask{lon_lat}.tif')\n",
" label_mask = np.array(Image.open(label_mask_path), dtype=np.uint8)\n",
" \n",
" output = output_tile.flatten()\n",
" labels = label_mask.flatten()\n",
" \n",
" # mask out where labels is 0, which is outside of boundary of region\n",
" # and also where output is 0, which is where no imagery is available on the tile\n",
" # now get rid of such entries\n",
" labels_masked = labels * (output != 0)\n",
" no_label_entries = np.where(labels_masked == 0)\n",
" \n",
" labels = np.delete(labels, no_label_entries)\n",
" output = np.delete(output, no_label_entries)\n",
" \n",
" classes_present_in_gt.update(labels)\n",
" \n",
" tile_accuracy = accuracy_score(labels, output, normalize=True)\n",
" tile_accuracies[lon_lat] = tile_accuracy\n",
"\n",
" for y_true, y_pred in tqdm(zip(labels, output)):\n",
" cm[y_true][y_pred] += 1\n",
" true_counts[y_true] += 1\n",
" pred_counts[y_pred] += 1\n",
" \n",
"overall_accuracy = sum(tile_accuracies.values())/len(tile_accuracies)\n",
"print(f'Overall accuracy is {overall_accuracy}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tile_accuracies"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Accurate distribution of land types\n",
"The shapefile's area attribute did not look correct"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"true_counts"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Confusion matrix"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# normalize by ground truth label counts\n",
"cm_norm = np.zeros((num_classes, num_classes), dtype=np.float)\n",
"for y_true in range(num_classes):\n",
" for y_pred in range(num_classes):\n",
" if true_counts[y_true] == 0:\n",
" cm_norm[y_true][y_pred] = 0.0\n",
" else:\n",
" cm_norm[y_true][y_pred] = cm[y_true][y_pred] / true_counts[y_true]"
]
},
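{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# equivalent vectorized form of the loop above; rows of cm with no ground-truth pixels are all zeros, so dividing by 1 is safe\n",
"cm_norm_vec = cm / np.maximum(true_counts, 1)[:, np.newaxis]\n",
"assert np.allclose(cm_norm, cm_norm_vec)"
]
},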
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# docs: https://matplotlib.org/3.1.3/gallery/images_contours_and_fields/image_annotated_heatmap.html#sphx-glr-gallery-images-contours-and-fields-image-annotated-heatmap-py\n",
"\n",
"cm_to_plot = cm_norm\n",
"\n",
"\n",
"fig = plt.figure(figsize=(10, 10), dpi=200) # set dpi to 300 to look good\n",
"ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])\n",
"im = ax.matshow(cm_to_plot, cmap=plt.cm.YlGnBu)\n",
"_ = ax.set_xticks(np.array(range(num_classes)))\n",
"_ = ax.set_yticks(np.array(range(num_classes)))\n",
"_ = ax.set_xticklabels(label_names)\n",
"_ = ax.set_yticklabels(label_names)\n",
"_ = ax.set_ylabel('Provided labels')\n",
"_ = ax.set_xlabel('Predicted by model')\n",
"ax.xaxis.tick_top()\n",
"\n",
"# Rotate the tick labels\n",
"_ = plt.setp(ax.get_xticklabels(), rotation=90)\n",
"\n",
"_ = ax.set_xticks(np.array(range(num_classes)) - 0.5, minor=True)\n",
"_ = ax.set_yticks(np.array(range(num_classes)) - 0.5, minor=True)\n",
"ax.grid(which='minor', color='white', linestyle='-', linewidth=3)\n",
"\n",
"cbar = ax.figure.colorbar(im, ax=ax)\n",
"\n",
"# no border\n",
"for edge, spine in ax.spines.items():\n",
" spine.set_visible(False)\n",
"\n",
"# right-click save - layout isn't right otherwise\n",
" \n",
"#fig.tight_layout()\n",
"#plt.savefig('/Users/siyuyang/Source/temp_data/WCS_land_use/train_200218/result_val/evaluation/cm.png')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cm_norm[32][32]\n",
"cm_norm[33][33]\n",
"\n",
"cm_norm[30, 33] # row, col - ground truth, predicted\n",
"cm_norm[33][30] "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Side by side label and output counts, in log scale"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Per-class accuracy, precision and recall"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# per-class accuracy\n",
"total_obs = cm.sum()\n",
"\n",
"per_class_accuracy = {}\n",
"per_class_recall = {}\n",
"per_class_precision = {}\n",
"\n",
"for cls in range(num_classes):\n",
" if cls not in classes_present_in_gt:\n",
" continue\n",
" \n",
" true_pos = cm[cls, cls]\n",
" \n",
" true_neg = total_obs - cm[cls, :].sum() - cm[:, cls].sum() + true_pos\n",
" \n",
" false_pos = cm[:, cls].sum() - true_pos\n",
" \n",
" false_neg = cm[cls, :].sum() - true_pos\n",
" \n",
" per_class_accuracy[cls] = (true_pos + true_neg) / total_obs\n",
" \n",
" per_class_precision[cls] = true_pos / (true_pos + false_pos)\n",
" \n",
" per_class_recall[cls] = true_pos / (true_pos + false_neg)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print('Category, Accuracy, Precision, Recall')\n",
"for cls, acc in per_class_accuracy.items():\n",
" prec = per_class_precision[cls]\n",
" recall = per_class_recall[cls]\n",
" print(f'{cls} {viz_util.num_to_name[str(cls)]},{acc},{prec},{recall}')\n",
" \n",
"# paste the result into Pages, and fix the row for \"27 Lakes, lagoons, and natural cienaga\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the dataset is so unbalanced (mostly 12 - dense forest) and accuracy counts \"true negatives\" as a win, this is not a good measure of performance."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Save the evaluation findings - not yet done"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"saved = {\n",
" 'overall_accuracy': overall_accuracy,\n",
" 'per_class_accuracy': per_class_accuracy,\n",
" # 'calibration_summary': calibration_summary\n",
"}\n",
"\n",
"with open(eval_saved_to, 'w') as f:\n",
" json.dump(saved, f, indent=4)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Is the model well-calibrated?\n",
"\n",
"We can also just record a 2D shape - each cell is the confidence of the most confident class?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(output_scores_path, 'rb') as f:\n",
" dict_scores = pickle.load(f)"
]
},
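{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sketch of the idea above: a 2D map where each cell is the confidence of the most confident class,\n",
"# shown for one chip; assumes each value in dict_scores has shape (1, num_model_classes, 256, 256)\n",
"window, chip_scores = next(iter(dict_scores.items()))\n",
"max_confidence = chip_scores.squeeze().max(axis=0)\n",
"_ = plt.imshow(max_confidence, vmin=0.0, vmax=1.0)\n",
"_ = plt.colorbar()"
]
},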
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"classes_to_plot = [0, 11, 12, 17, 19, 26, 32]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y_true = defaultdict(list)\n",
"y_prob = defaultdict(list)\n",
"\n",
"for window, chip_scores in tqdm(dict_scores.items()):\n",
" # rasterio window is (col_off x, row_off y, width, height)\n",
" \n",
" chip_scores = chip_scores.squeeze() # chip_scores have shape (1, 33, 256, 256)\n",
" chip_scores = chip_scores.reshape((33, -1))\n",
"\n",
" chip_labels = label_mask[window[0]:window[0] + 256, window[1]:window[1] + 256]\n",
" chip_labels = chip_labels.reshape((1, -1))\n",
" # we pad 0 to the end of chips after the tile ends\n",
" chip_labels = np.pad(chip_labels, ((0, 0), (0, 256*256 - chip_labels.shape[1]))).squeeze()\n",
" \n",
" assert chip_scores.shape == (33, 256*256), chip_scores.shape\n",
" assert chip_labels.shape == (256*256,), chip_labels.shape\n",
" \n",
" for cls in classes_to_plot:\n",
" cls_y_true = chip_labels == cls\n",
" cls_y_prob = chip_scores[cls]\n",
" assert len(list(cls_y_true)) == len(list(cls_y_prob)), '{}, {}'.format(\n",
" len(list(cls_y_true)), len(list(cls_y_prob))\n",
" )\n",
" y_true[cls].extend(list(cls_y_true))\n",
" y_prob[cls].extend(list(cls_y_prob))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"len(y_true[12])\n",
"len(y_prob[12])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"_ = plt.plot([0.0, 1.0], color='grey', linestyle=':')\n",
"\n",
"for cls in classes_to_plot:\n",
" _ = frac_positives, mean_prob_in_bin = calibration_curve(y_true[cls], y_prob[cls], n_bins=10)\n",
" _ = plt.plot(mean_prob_in_bin, frac_positives, label=cls, color=viz_util.num_to_color[str(cls)])\n",
"_ = plt.legend()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mean_prob_in_bin\n",
"frac_positives"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Expected number of pixels for the whole validation area"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"probability_sum = np.zeros(num_classes, dtype=np.float)\n",
"\n",
"for window, chip_scores in dict_scores.items():\n",
" # print(chip_scores.shape) # (1, 33, 256, 256)\n",
" chip_scores = chip_scores.squeeze()\n",
" chip_scores = chip_scores.sum(axis=(1, 2)) # height and width dims\n",
" probability_sum += chip_scores"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"calibration_summary = {}\n",
"\n",
"for cls, (prob_sum, label_sum) in enumerate(zip(probability_sum, true_counts)):\n",
" calibration_summary[cls] = {\n",
" 'prediction_probability_sum': prob_sum,\n",
" 'label_sum': int(label_sum)\n",
" }\n",
" print('Class {} - {}, prob_sum {}, label_sum {}'.format(cls, viz_util.num_to_name[str(cls)], round(prob_sum), label_sum))\n",
" if label_sum > 0:\n",
" print(' diff is {}%'.format(100 * round((prob_sum - label_sum)/label_sum, 3)))\n",
" calibration_summary[cls]['difference_wrt_label_sum'] = (prob_sum - label_sum)/label_sum"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:wcs] *",
"language": "python",
"name": "conda-env-wcs-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

294
finetuning/ModelSessionOrinoquia.py Normal file

View file

@@ -0,0 +1,294 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Implementation of ModelSessionAbstract for the Orinoquia land cover mapping project. Heavily references
wildlife-conservation-society.orinoquia-land-use/training/inference.py
"wildlife-conservation-society.orinoquia-land-use" and "ai4eutils" need to be on the PYTHONPATH.
"""
import importlib
import logging
import sys
import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from web_tool.ModelSessionAbstract import ModelSession
LOGGER = logging.getLogger("server")
class TorchFineTuningOrinoquia(ModelSession):
def __init__(self, gpu_id, **kwargs):
# setting up device to use
LOGGER.debug(f"TorchFineTuningOrinoquia init, gpu_id is {gpu_id}")
self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() and gpu_id is not None else "cpu")
# load experiment configuration as a module
config_module_path = kwargs["config_module_path"]
try:
module_name = "config"
spec = importlib.util.spec_from_file_location(module_name, config_module_path)
self.config = importlib.util.module_from_spec(spec)
sys.modules[module_name] = self.config
spec.loader.exec_module(self.config)
except Exception as e:
LOGGER.error(f"Failed to import experiment and model configuration. Exception: {e}")
sys.exit(1)
LOGGER.info(f"config is for experiment {self.config.experiment_name}")
# check that the necessary fields are present in the config
assert self.config.num_classes > 1 # this is the number of classes the initial model supports
self.num_classes = self.config.num_classes # self.num_classes can evolve during a session
assert self.config.chip_size > 1
assert self.config.feature_scale in [1, 2]
chip_size = self.config.chip_size
self.prediction_window_size = self.config.prediction_window_size if self.config.prediction_window_size else 128
self.prediction_window_offset = int((chip_size - self.prediction_window_size) / 2)
print((f"Using chip_size {chip_size} and window_size {self.prediction_window_size}. "
f"So window_offset is {self.prediction_window_offset}"))
# obtain the model that the config has initialized
self.checkpoint_path = kwargs["fn"]
assert os.path.exists(self.checkpoint_path), f"Checkpoint at {self.checkpoint_path} does not exist."
self.model = self.config.model
self._init_model()
# other instance variables
self._last_tile = None
# recording the current feature map (before final layer), and the corrections made
self.current_features = None
self.corr_features = []
self.corr_labels = []
def _init_model(self):
checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint["state_dict"])
print(f"Using checkpoint at epoch {checkpoint['epoch']}, step {checkpoint['step']}, "
f"val accuracy is {checkpoint.get('val_acc', 'Not Available')}")
self.model = self.model.to(device=self.device)
self.model.eval()
@property
def last_tile(self):
return self._last_tile
def run(self, tile, inference_mode=False):
assert tile.shape[2] == 11 # 10 Landsat bands + DEM
height = tile.shape[0] # tile is of dims (height, width, channels)
width = tile.shape[1]
self._last_tile = tile
prediction_window_offset = self.prediction_window_offset
prediction_window_size = self.prediction_window_size
# apply the preprocessing of bands to the tile
data_array = self.config.preprocess_tile(tile) # dim is (6, H, W)
LOGGER.debug(f"run, tile shape is {tile.shape}")
LOGGER.debug(f"run, data_array shape is {data_array.shape}")
# pad by mirroring at the edges to predict only on the center crop
data_array = np.pad(data_array,
[
(0, 0), # only pad height and width
(prediction_window_offset, prediction_window_offset), # height / rows
(prediction_window_offset, prediction_window_offset) # width / cols
],
mode="symmetric")
# form batches
batch_size = self.config.batch_size if self.config.batch_size is not None else 32
batch = []
batch_indices = [] # cache these to save recalculating when filling in model predictions
chip_size = self.config.chip_size
num_rows = math.ceil(height / prediction_window_size)
num_cols = math.ceil(width / prediction_window_size)
for col_idx in range(num_cols):
col_start = col_idx * prediction_window_size
col_end = col_start + chip_size
for row_idx in range(num_rows):
row_start = row_idx * prediction_window_size
row_end = row_start + chip_size
chip = data_array[:, row_start:row_end, col_start: col_end]
# pad to (chip_size, chip_size)
chip = np.pad(chip,
[(0, 0), (0, chip_size - chip.shape[1]), (0, chip_size - chip.shape[2])])
# processing it as the dataset loader _get_chip does
chip = np.nan_to_num(chip, nan=0.0, posinf=1.0, neginf=-1.0)
sat_mask = chip[0].squeeze() > 0.0 # mask out DEM data where there's no satellite data
chip = chip * sat_mask
batch.append(chip)
valid_row_end = row_start + min(prediction_window_size, height - row_idx * prediction_window_size)
valid_col_end = col_start + min(prediction_window_size, width - col_idx * prediction_window_size)
batch_indices.append(
(row_start, valid_row_end, col_start, valid_col_end)) # explicit to be less confusing
batch = np.array(batch) # (num_chips, channels, height, width)
# score chips in batches
model_output = []
model_features = [] # same spatial dims as model_output, but has 64 or 32 channels
self.model.eval()
with torch.no_grad():
for i in range(0, len(batch), batch_size):
t_batch = batch[i:i + batch_size]
t_batch = torch.from_numpy(t_batch).to(self.device)
scores, features = self.model.forward(t_batch,
return_features=True) # these are scores before the final softmax
softmax_scores = torch.nn.functional.softmax(scores,
dim=1).cpu().numpy() # (batch_size, num_classes, H, W)
softmax_scores = np.transpose(softmax_scores, axes=(0, 2, 3, 1)) # (batch_size, H, W, num_classes)
# only save the center crop
softmax_scores = softmax_scores[
:,
prediction_window_offset:prediction_window_offset + prediction_window_size,
prediction_window_offset:prediction_window_offset + prediction_window_size,
:
]
model_output.append(softmax_scores)
features = features.cpu().numpy()
features = np.transpose(features, axes=(0, 2, 3, 1)) # (batch_size, H, W, num_features)
features = features[
:,
prediction_window_offset:prediction_window_offset + prediction_window_size,
prediction_window_offset:prediction_window_offset + prediction_window_size,
:
]
model_features.append(features)
model_output = np.concatenate(model_output, axis=0)
model_features = np.concatenate(model_features, axis=0)
# fill in the output array; self.num_classes is the number of classes supported by the current model
output = np.zeros((height, width, self.num_classes), dtype=np.float32)
for i, (row_start, row_end, col_start, col_end) in enumerate(batch_indices):
h = row_end - row_start
w = col_end - col_start
output[row_start:row_end, col_start:col_end, :] = model_output[i, :h, :w, :]
print(f"--- Orinoquia ModelSession, output[-1, -1, :] is {output[-1, -1, :]}")
num_features = 64 if self.config.feature_scale == 1 else 32
output_features = np.zeros((height, width, num_features), dtype=np.float32) # float32 used during training too
for i, (row_start, row_end, col_start, col_end) in enumerate(batch_indices):
h = row_end - row_start
w = col_end - col_start
output_features[row_start:row_end, col_start:col_end, :] = model_features[i, :h, :w, :]
print(f"--- Orinoquia ModelSession, output_features[-1, -1, :] is {output_features[-1, -1, :]}")
# save the features
self.current_features = output_features
return output
def add_sample_point(self, row, col, class_idx):
self.corr_labels.append(class_idx)
self.corr_features.append(self.current_features[row, col, :])
print(f"After add_sample_point, corr_labels length is {len(self.corr_labels)}")
return {
"message": f"Training sample for class {class_idx} added",
"success": True
}
def undo(self):
if len(self.corr_features) > 0:
self.corr_features = self.corr_features[:-1]
self.corr_labels = self.corr_labels[:-1]
return {
"message": "Undid training sample",
"success": True
}
else:
return {
"message": "Nothing to undo",
"success": False
}
def reset(self):
self._init_model()
self.num_classes = self.config.num_classes
self.corr_features = []
self.corr_labels = []
return {
"message": f"Model is reset and support {self.num_classes} classes",
"success": True
}
def retrain(self, train_steps=100, learning_rate=1e-3):
print_every = 10
# if no corrections have been made, there is nothing to retrain on
if len(self.corr_labels) == 0:
return {
"message": "No training samples have been added; nothing to retrain",
"success": False
}
# if any new classes have been added, update self.num_classes and re-initialize the final layer
# classes start from 0
if max(self.corr_labels) >= self.num_classes:
self.num_classes = max(self.corr_labels) + 1
self.model.change_num_classes(self.num_classes)
self.model.final.to(device=self.device)
LOGGER.debug(f"New classes have been added (total {self.num_classes}) and final layer re-initialized.")
# all corrections since the last reset are used
batch_x = torch.from_numpy(np.array(self.corr_features)).float().to(self.device)
batch_y = torch.from_numpy(np.array(self.corr_labels)).to(self.device)
# make the last layer `final` trainable (TODO: is this necessary? parameters require gradients by default)
for param in self.model.final.parameters(): # see UNet implementation in WCS project repo
param.requires_grad = True
optimizer = optim.Adam(self.model.final.parameters(), lr=learning_rate) # only the last layer
# during re-training, we use equal weight for all classes
criterion = nn.CrossEntropyLoss().to(device=self.device)
self.model.train()
for step in range(train_steps):
with torch.enable_grad():
# forward pass
batch_x_reshaped = batch_x.unsqueeze(2).unsqueeze(3)
scores = self.model.final.forward(batch_x_reshaped).squeeze(3).squeeze(2)
loss = criterion(scores, batch_y)
# backward pass
optimizer.zero_grad()
loss.backward() # compute gradients
optimizer.step() # update parameters
if step % print_every == 0:
preds = scores.argmax(1)
accuracy = (batch_y == preds).float().mean()
print(f'step {step}, loss: {loss.item()}, accuracy: {accuracy.item()}')
return {
"message": f"Fine-tuned model with {len(self.corr_features)} samples for {train_steps} steps",
"success": True
}
def save_state_to(self, directory):
# number of classes could be different from what's in the config for the initial model
return {
"message": "Saving not yet implemented",
"success": False
}
def load_state_from(self, directory):
return {
"message": "Saving and loading not yet implemented",
"success": False
}

45
finetuning/README.md Normal file
View file

@@ -0,0 +1,45 @@
# Finetuning the land cover model interactively
Files in this folder are configurations and implementations of required classes for finetuning the model interactively in an instance of the [Land Cover Mapping tool](https://github.com/microsoft/landcover) (this repo will be referred to as the `landcover` repo below).
## Setup
Because our implementation of [ModelSession](https://github.com/microsoft/landcover/blob/master/web_tool/ModelSessionAbstract.py) relies on the experiment configuration file (a `.py` file) that produced the model, this repo needs to be on the `PYTHONPATH` when running the Land Cover Mapping tool's server, in addition to the AI for Earth utilities [repo](https://github.com/microsoft/ai4eutils):
```
export PYTHONPATH="${PYTHONPATH}:/home/boto/wcs/pycharm:/home/boto/lib/ai4eutils"
```
Note that directory names in this repo should not clash with ones in the `landcover` repo.
The basemap needs to be in the `landcover` repo's root directory for the server to serve the data. We can create a symbolic link to the file stored in a blob storage container:
```
ln -s /home/boto/wcs/mnt/wcs-orinoquia/images_sr_median/2013_2014_dem/wcs_orinoquia_sr_median_2013_2014_dem.vrt wcs_orinoquia_sr_median_2013_2014_dem.vrt
```
The configuration files for the dataset and model in this folder and `ModelSessionOrinoquia.py` should be copied to the `landcover` repo's `web_tool` directory.
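For example (file names and paths are illustrative):
```
cp finetuning/*.json finetuning/ModelSessionOrinoquia.py /path/to/landcover/web_tool/
```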
We also need to make the following changes:
- Modify [`worker.py`](https://github.com/microsoft/landcover/blob/master/worker.py) to add a case for the `model_type` `pytorch_landsat` (a sketch follows this list).
- Modify `web_tool/DataLoader.py`, changing
```python
resolution=(x_res, y_res)
```
to
```python
resolution=(30, 30) # hardcode Landsat 8 resolution
```
- In `web_tool/js/components.js`, modify these values in `addInferenceWindowSizeSlider` to:
```js
min: 7680,
max: 23040,
```
- In `web_tool/js/globals.js`, change the default inference window size
```js
var INFERENCE_WINDOW_SIZE = 300;
```
to
```js
var INFERENCE_WINDOW_SIZE = 7680;
```
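A minimal sketch of the `worker.py` addition (hypothetical: the exact dispatch structure and the `model_type`, `gpu_id` and keyword-argument names should be checked against the `landcover` repo):
```python
# sketch: extend worker.py's model_type dispatch with a pytorch_landsat case
if model_type == "pytorch_landsat":
    from web_tool.ModelSessionOrinoquia import TorchFineTuningOrinoquia
    # kwargs come from the model's entry in models.json: "config_module_path" and "fn"
    model = TorchFineTuningOrinoquia(gpu_id, **model_kwargs)
```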
## Operation
![Using the Land Cover Mapping tool on Landsat 8 imagery](../visuals/in_the_tool.png)

View file

@@ -0,0 +1,38 @@
{
"landsat_orinoquia": {
"metadata": {
"displayName": "Orinoquia Colombia, 2013-2014, Landsat 8",
"locationName": null
},
"dataLayer": {
"type": "CUSTOM",
"path": "data/imagery/wcs_orinoquia_sr_median_2013_2014_dem.vrt",
"padding": 0,
"resolution": 30
},
"basemapLayers": [
{
"layerName": "Landsat 8 2013-2014",
"initialZoom": 8,
"url": "data/basemaps/wcs_2013_2014_basemap_exp07/{z}/{x}/{y}.png",
"initialLocation": [4.400127, -72.304724],
"args": {
"attribution": "Georeferenced Image",
"tms": true,
"maxNativeZoom": 13,
"maxZoom": 16,
"minZoom": 8,
"bounds": [[1.4752987731256113, -67.25191674463355], [7.516203982228033, -74.96450029130249]]
}
}
],
"shapeLayers": [
{
"shapesFn": "data/zones/Orinoquia_outline.geojson",
"zoneNameKey": "NAME_1",
"name": "Region outline"
}
],
"validModels": ["landsat_orinoquia_coarse"]
}
}

View file

@@ -0,0 +1,77 @@
{
"landsat_orinoquia_coarse": {
"metadata": {
"displayName": "Landsat Orinoquia Coarse-class"
},
"model": {
"type": "pytorch_landsat",
"fineTuneLayer": -1,
"numParameters": -1,
"inputShape": [
256,
256,
6
],
"config_module_path": "/home/boto/wcs/pycharm/training_wcs/experiments/coarse_baseline/coarse_baseline_config_refactored.py",
"fn": "/home/boto/wcs/mnt/wcs-orinoquia/useful_checkpoints/wcs_coarse_baseline/wcs_coarse_baseline_epoch93_best.pth.tar"
},
"classes": [
{
"name": "Empty of data",
"color": "#000000"
},
{
"name": "Urban and infrastructure",
"color": "#d3d3d3"
},
{
"name": "Agriculture",
"color": "#ffc0cb"
},
{
"name": "Arboreal and forestry crops",
"color": "#008080"
},
{
"name": "Pasture",
"color": "#fa8072"
},
{
"name": "Vegetation",
"color": "#daa520"
},
{
"name": "Forest",
"color": "#8fbc8f"
},
{
"name": "Savanna",
"color": "#ffd700"
},
{
"name": "Sand, rocks and bare land",
"color": "#ffebcd"
},
{
"name": "Unavailable",
"color": "#f5f5f5"
},
{
"name": "Swamp",
"color": "#556b2f"
},
{
"name": "Water",
"color": "#00bfff"
},
{
"name": "Seasonal savanna",
"color": "#f0e68c"
},
{
"name": "Seasonally flooded savanna",
"color": "#d8bfd8"
}
]
}
}

View file

@@ -0,0 +1,200 @@
"""
Configurations for the 20200505_mini_baseline experiment
"""
import json
import os
import torch
import numpy as np
from viz_utils import VizUtils
from training_wcs.scripts.models.unet.unet import Unet # models.unet.unet import Unet
experiment_name = 'wcs_baseline_202005'
eval_mode = True
# I/O -------------------------------------------------------------------------------------------------
if not eval_mode:
aml_data_ref = os.environ.get('AZUREML_DATAREFERENCE_wcsorinoquia', '')
assert len(aml_data_ref) > 0, 'Reading aml_data_ref from environment vars resulted in empty string.'
data_dir = os.path.join(aml_data_ref, 'tiles', 'full_sr_median_2013_2014')
assert 'tiles' in os.listdir(data_dir)
assert 'tiles_masks' in os.listdir(data_dir)
# a dir with experiment_name will be created in here and checkpoints are saved here
# set as './outputs' for AML to stream to this Run's folder
out_dir = '/boto_disk_0/wcs/20190518_feature_scale_1/outputs' # './outputs'
os.makedirs(out_dir, exist_ok=True)
# TF events go here. Set it as './logs' if using AML so they can be streamed
log_dir = '/boto_disk_0/wcs/20190518_feature_scale_1/logs' # './logs'
# train/val splits are stored in
# on AML, this needs to be relative to the source_directory level
splits_file = './constants/splits/full_sr_median_2013_2014_splits.json' # '../training_wcs/scripts/constants/splits/full_sr_median_2013_2014_splits.json'
with open(splits_file) as f:
splits = json.load(f)
train_split = splits['train']
val_split = splits['val']
print(f'Train set has {len(train_split)} tiles; val set has {len(val_split)} tiles.')
# Training ----------------------------------------------------------------------------------------------
evaluate_only = False # Only evaluate the model on the val set once
# this is the *total* epoch; if restarting from a checkpoint, be sure to add the additional number of epochs
# to fine-tune on top of the original value of this var
total_epochs = 1000
print_every = 100 # print the minibatch loss and accuracy every this many steps
assert print_every >= 1, 'print_every needs to be greater than or equal to 1'
starting_checkpoint_path = None
init_learning_rate = 1e-4
batch_size = 24
# probability a chip is kept in the sample while sampling train and val chips at the start of training
# this should be smaller if we now have more training examples
# prob_keep_chip = 0.006
# 133 training tiles * 64 chips per tile = 8512 chips. Should keep every 177th if visualizing 48, which is 0.0056
keep_every = 30 # a balance between not covering all training tiles vs iterating through val tiles too many times
num_chips_to_viz = 48
# Hardware and framework --------------------------------------------------------------------------------
dtype = torch.float32
# Model -------------------------------------------------------------------------------------------------
num_classes = 34 # empty plus the 33 WCS classes; this is the number of output nodes
num_in_channels = 5 # 2, 3, 6, 7, NDVI
# the smallest number of filters is 64 when feature_scale is 1, and it is 32 when feature_scale is 2
feature_scale = 1
is_deconv = True # True to use transpose convolution filters to learn upsampling; otherwise upsampling is not learnt
is_batchnorm = True
model = Unet(feature_scale=feature_scale,
n_classes=num_classes,
in_channels=num_in_channels,
is_deconv=is_deconv,
is_batchnorm=is_batchnorm)
# Data ---------------------------------------------------------------------------------------------------
common_classes = [
12
]
less_common_classes = [
32, 33
]
weights = []
for i in range(num_classes):
if i in common_classes:
weights.append(1)
elif i in less_common_classes:
weights.append(2)
else:
weights.append(10)
loss_weights = torch.FloatTensor(weights) # None if no weighting for classes
print('Weights on loss per class used:')
print(loss_weights)
# how many subprocesses to use for data loading
# None for now - need to modify datasets.py to use
data_loader_num_workers = None
# data_loader_shuffle = True  # not available for IterableDataset; True to have the data reshuffled at every epoch
chip_size = 256
# based on min and max values from the sample tile
# wcs_orinoquia_sr_median_2013_2014-0000007424-0000007424_-71.347_4.593.tif in the training set
# bands 4 and 5 are combined to get the NDVI, so the normalization params for 4 and 5 are
# not used during training data generation, only for visualization (actually not yet used for viz either).
bands_normalization_params = {
# these are the min and max to clip to for the band
'min': {
2: 0,
3: 0,
4: 0,
5: 0,
6: 0,
7: 0
},
'max': {
2: 700,
3: 1500,
4: 1500,
5: 5000,
6: 5000,
7: 3000
},
'gamma': { # all the same in this experiment with value 1 which means no effect
2: 1.0,
3: 1.0,
4: 1.0,
5: 1.0,
6: 1.0,
7: 1.0
}
}
viz_util = VizUtils()
def get_chip(tile_reader, chip_window):
"""
Returns:
stacked: a numpy array of dims (5, chip_size, chip_size)
chip_for_display: an RGB array scaled to look good, for visualization
"""
normal_bands = [2, 3, 6, 7] # bands to be used without calculating other indices e.g. NDVI
bands_to_stack = []
for b in normal_bands:
# getting one band at a time because min, max and gamma may be different
band = viz_util.show_landsat8_tile(
tile_reader,
bands=[b], # pass in a list to get the batch dimension in the results
window=chip_window,
band_min=bands_normalization_params['min'][b],
band_max=bands_normalization_params['max'][b],
gamma=bands_normalization_params['gamma'][b],
return_array=True
)
bands_to_stack.append(band) # band is 2D (h, w), already squeezed, dtype is float 32
ndvi = viz_util.show_landsat8_ndvi(tile_reader, window=chip_window) # 2D, dtype is float32
bands_to_stack.append(ndvi)
stacked = np.stack(bands_to_stack)
if stacked.shape != (5, chip_size, chip_size):
# default pad constant value is 0
stacked = np.pad(stacked,
[(0, 0), (0, chip_size - stacked.shape[1]), (0, chip_size - stacked.shape[2])])
assert stacked.shape == (5, chip_size, chip_size), f'Landsat chip has wrong shape: {stacked.shape}, should be (5, h, w)'
# prepare the chip for display
chip_for_display = viz_util.show_landsat8_tile(tile_reader,
window=chip_window,
band_max=3000, # what looks good for RGB
gamma=0.5,
return_array=True)
return stacked, chip_for_display

View file

@@ -0,0 +1,352 @@
"""
First experiment using 6 channels (2, 3, 6, 7, NDVI, elevation) with the 13 + 1 coarse categories
mapped June 29 2020.
"""
import os
import sys
import numpy as np
import rasterio
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# ai4eutils needs to be on the PYTHONPATH
from geospatial.enums import ExperimentConfigMode
from geospatial.visualization.imagery_visualizer import ImageryVisualizer
from geospatial.visualization.raster_label_visualizer import RasterLabelVisualizer
from training_wcs.scripts.models.unet.unet import Unet
from training_wcs.scripts.utils.data_transforms import ToTensor, RandomHorizontalFlip, RandomVerticalFlip
from training_wcs.scripts.utils.datasets import SingleShardChipsDataset
experiment_name = 'wcs_coarse_baseline'
config_mode = ExperimentConfigMode.SCORING
# I/O -------------------------------------------------------------------------------------------------
if config_mode in [ExperimentConfigMode.PREPROCESSING, ExperimentConfigMode.TRAINING]:
data_shard_dir = '/boto_disk_0/wcs_data/shards/full_sr_median_2013_2014_elevation'
# a dir with experiment_name will be created in here and checkpoints are saved here
# set as './outputs' for AML to stream to this Run's folder
out_dir = f'/boto_disk_0/wcs/{experiment_name}/outputs'
os.makedirs(out_dir, exist_ok=True)
# TF events go here. Set it as './logs' if using AML so they can be streamed
log_dir = f'/boto_disk_0/wcs/{experiment_name}/logs' # './logs'
os.makedirs(log_dir, exist_ok=True)
# for scoring script and make_chip_shards
if config_mode in [ExperimentConfigMode.PREPROCESSING, ExperimentConfigMode.SCORING]:
prediction_window_size = 128
label_viz = RasterLabelVisualizer(
label_map='/home/boto/wcs/pycharm/constants/class_lists/wcs_coarse_label_map.json')
data_dir = '/boto_disk_0/wcs_data' # which contains images_srtm
# Training ----------------------------------------------------------------------------------------------
# this is the *total* epoch; if restarting from a checkpoint, be sure to add the additional number of epochs
# to fine-tune on top of the original value of this var
total_epochs = 500
print_every = 150 # print the minibatch loss and accuracy every this many steps
assert print_every >= 1, 'print_every needs to be greater than or equal to 1'
starting_checkpoint_path = None
init_learning_rate = 5e-5
batch_size = 28
# visualizing results on a sample of chips during training
num_chips_to_viz = 48
# Hardware and framework --------------------------------------------------------------------------------
dtype = torch.float32
# Model -------------------------------------------------------------------------------------------------
num_classes = 14 # empty plus the 13 *coarse* WCS classes; this is the number of output nodes
num_in_channels = 6 # 2, 3, 6, 7, NDVI, elevation
# the smallest number of filters is 64 when feature_scale is 1, and it is 32 when feature_scale is 2
feature_scale = 1
is_deconv = True # True to use transpose convolution filters to learn upsampling; otherwise upsampling is not learnt
is_batchnorm = True
model = Unet(feature_scale=feature_scale,
n_classes=num_classes,
in_channels=num_in_channels,
is_deconv=is_deconv,
is_batchnorm=is_batchnorm)
# Data ---------------------------------------------------------------------------------------------------
common_classes = [
]
less_common_classes = [
]
weights = []
for i in range(num_classes):
if i in common_classes:
weights.append(0)
elif i in less_common_classes:
weights.append(0)
else:
weights.append(1)
loss_weights = torch.FloatTensor(weights) # None if no weighting for classes
print('Weights on loss per class used:')
print(loss_weights)
# how many subprocesses to use for data loading
# None for now - need to modify datasets.py to use
data_loader_num_workers = None
# data_loader_shuffle = True  # not available for IterableDataset; True to have the data reshuffled at every epoch
chip_size = 256
# datasets and dataloaders
if config_mode == ExperimentConfigMode.PREPROCESSING:
dset_train = SingleShardChipsDataset(data_shard_dir, shard_prefix='train', channels=None,
transform=transforms.Compose([
ToTensor(),
RandomHorizontalFlip(), # these operate on Tensors, not PIL images
RandomVerticalFlip()
]))
loader_train = DataLoader(dset_train,
batch_size=batch_size,
num_workers=4,
shuffle=True) # currently num_workers is None
dset_val = SingleShardChipsDataset(data_shard_dir, shard_prefix='val', channels=None,
transform=transforms.Compose([
ToTensor(),
RandomHorizontalFlip(), # these operate on Tensors, not PIL images
RandomVerticalFlip()
]))
loader_val = DataLoader(dset_val,
num_workers=4,
batch_size=batch_size)
# Data shards generation configurations --------------------------------------------------------------------
# These configurations are copied from training_wcs/experiments/elevation/elevation_2_config.py
# They are only used with make_chip_shards.py and infer.py
# train.py only uses the generated chip shards as numpy arrays
if config_mode in [ExperimentConfigMode.PREPROCESSING, ExperimentConfigMode.SCORING]:
elevation_path = os.path.join(data_dir, 'images_srtm', 'wcs_orinoquia_srtm.tif')
elevation_reader = rasterio.open(elevation_path)
# based on min and max values from the sample tile
# wcs_orinoquia_sr_median_2013_2014-0000007424-0000007424_-71.347_4.593.tif in the training set
# bands 4 and 5 are combined to get the NDVI, so the normalization params for 4 and 5 are
# not used during training data generation, only for visualization (actually not yet used for viz either).
bands_normalization_params = {
# these are the min and max to clip to for the band
'min': {
2: 0,
3: 0,
4: 0,
5: 0,
6: 0,
7: 0
},
'max': {
2: 700,
3: 1500,
4: 1500,
5: 5000,
6: 5000,
7: 3000
},
'gamma': { # all the same in this experiment with value 1 which means no effect
2: 1.0,
3: 1.0,
4: 1.0,
5: 1.0,
6: 1.0,
7: 1.0
}
}
elevation_standardization_params = {
# from calculations done in GEE
'mean': 399.78,
'std_dev': 714.78
}
def get_elevation_chip(tile_reader, chip_window):
x, y = (tile_reader.bounds.left, tile_reader.bounds.top)
# getting the pixel array indices corresponding to points in georeferenced space
row, col = elevation_reader.index(x, y)
# tile wcs_orinoquia_sr_median_2013_2014-0000000000-0000007424_-72.425_7.671.tif
# top left corner looks up to a negative row index. Clipping to 0 seems to be okay visually
row = max(0, row)
col = max(0, col)
# resolution and projection are the same for the elevation data and the Landsat imagery
row = row + chip_window[1]
col = col + chip_window[0] # x is col
try:
w = rasterio.windows.Window.from_slices((row, row + chip_window[3]), (col, col + chip_window[2]))
except Exception as e:
print(str(e))
print('chip window:', str(chip_window))
print('original row and col: ', str(elevation_reader.index(x, y)))
print('row:', row)
print('col:', col)
print(tile_reader.bounds)
print('x:', x)
print('y:', y)
sys.exit(1)  # sys is already imported at the top of this module
chip_elevation = elevation_reader.read(1, window=w) # only 1 band
# standardize
chip_elevation = (chip_elevation - elevation_standardization_params['mean']) / elevation_standardization_params[
'std_dev']
return chip_elevation
def _pad_chip(band, chip_window):
"""
Args:
band: numpy array of dims (h, w)
chip_window: (col_off x, row_off y, width, height)
Returns:
band padded to dims (width, height) of the chip_window provided
"""
width = chip_window[2]
height = chip_window[3]
# check for smaller than because we use get_chip to get the entire tile during scoring
if band.shape[0] < height or band.shape[1] < width:
# default pad constant value is 0
try:
band = np.pad(band,
[(0, height - band.shape[0]), (0, width - band.shape[1])])
except Exception as e:
print(f'coarse_baseline_config, _pad_chip exception: {e}')
sys.exit(1)
return band
normal_bands = [2, 3, 6, 7] # bands to be used without calculating other indices e.g. NDVI
def get_chip(tile_reader, chip_window, chip_for_display=True):
"""
Get an area (chip) specified by the chip_window. Is not related to chip_size
Args:
tile_reader: rasterio dataset object of the imagery tile
chip_window: (col_off x, row_off y, width, height)
chip_for_display: True to also return an RGB chip scaled to look good for display
Returns:
stacked: A numpy array of dims (6, H, W) - note that height and width are switched from chip_window
chip_for_display: If chip_for_display is True, also a 3-band array of the RGB channels scaled to
look good (R channel not included in stacked)
"""
bands_to_stack = []
for b in normal_bands:
# getting one band at a time because min, max and gamma may be different
band = ImageryVisualizer.show_landsat8_patch(
tile_reader,
bands=[b], # pass in a list to get the batch dimension in the results
window=chip_window,
band_min=bands_normalization_params['min'][b],
band_max=bands_normalization_params['max'][b],
gamma=bands_normalization_params['gamma'][b],
return_array=True
)
# deal with incomplete chips
band = _pad_chip(band, chip_window)
bands_to_stack.append(band) # band is 2D (h, w), already squeezed, dtype is float 32
ndvi = ImageryVisualizer.get_landsat8_ndvi(tile_reader, window=chip_window) # 2D, dtype is float32
ndvi = _pad_chip(ndvi, chip_window)
bands_to_stack.append(ndvi)
elevation = get_elevation_chip(tile_reader, chip_window) # scene covers entire region, not tiled, so no gaps
elevation = _pad_chip(elevation, chip_window)
bands_to_stack.append(elevation)
try:
stacked = np.stack(bands_to_stack)
except Exception as e:
print(f'Exception in get_chip: {e}')
for b in bands_to_stack:
print(b.shape)
print('')
raise # stacked would be undefined below, so re-raise after logging the band shapes
assert stacked.shape == (6, chip_window[3], chip_window[2]), \
f'Chip has wrong shape: {stacked.shape}, should be (6, h, w)'
if chip_for_display:
# chip for display, getting the RBG bands (default) with a different gamma and band_max that look good
chip_for_display = ImageryVisualizer.show_landsat8_patch(tile_reader,
window=chip_window,
band_max=3000, # what looks good for RGB
gamma=0.5,
return_array=True)
return stacked, chip_for_display
else:
return stacked
def preprocess_tile(tile_array: np.ndarray) -> np.ndarray:
"""Same functionality as get_chip(), but applies to a numpy array tile of arbitrary shape.
Currently only used with the landcover tool.
Args:
tile_array: A numpy array of dims (height, width, channels). Expect elevation to be the eleventh channel
Returns:
Numpy array representing the preprocessed chip, of dims (6, height, width) - note that the channel
dimension comes first.
"""
bands_to_stack = []
for b in normal_bands:
# getting one band at a time because min, max and gamma may be different
band = ImageryVisualizer.show_landsat8_patch(
tile_array,
bands=[b], # pass in a list to get the batch dimension in the results
band_min=bands_normalization_params['min'][b],
band_max=bands_normalization_params['max'][b],
gamma=bands_normalization_params['gamma'][b],
return_array=True
)
bands_to_stack.append(band) # band is 2D (h, w), already squeezed, dtype is float 32
ndvi = ImageryVisualizer.get_landsat8_ndvi(tile_array) # 2D, dtype is float32
bands_to_stack.append(ndvi)
# for the interactive tool, elevation is band 11 (1-indexed) or 10 (0-indexed), and already normalized
# by elevation_standardization_params (could have done the normalization here too)
elevation = tile_array[:, :, 10]
bands_to_stack.append(elevation)
stacked = np.stack(bands_to_stack)
assert stacked.shape == (
6, tile_array.shape[0], tile_array.shape[1]), f'preprocess_tile, wrong shape: {stacked.shape}'
return stacked

File diff suppressed because one or more lines are too long

105
training_wcs/scripts/models/unet/unet.py Normal file

View file

@@ -0,0 +1,105 @@
import logging
import torch.nn as nn
from .unet_utils import *
"""
Unet model definition.
Code mostly taken from https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/models/unet.py
"""
class Unet(nn.Module):
def __init__(self, feature_scale=1,
n_classes=3, in_channels=3,
is_deconv=True, is_batchnorm=False):
"""A U-Net implementation.
Args:
feature_scale: the smallest number of filters (depth c) is 64 when feature_scale is 1,
and it is 32 when feature_scale is 2
n_classes: number of output classes
in_channels: number of channels in input
is_deconv: True to use transposed convolutions to learn upsampling; otherwise bilinear interpolation is used
is_batchnorm: True to apply batch normalization after each convolution
"""
super(Unet, self).__init__()
self._num_classes = n_classes
assert 64 % feature_scale == 0, f'feature_scale {feature_scale} does not work with this UNet'
filters = [64, 128, 256, 512, 1024] # this is `c` in the diagram, [c, 2c, 4c, 8c, 16c]
filters = [int(x / feature_scale) for x in filters]
logging.info('filters used are: {}'.format(filters))
# downsampling
self.conv1 = UnetConv2(in_channels, filters[0], is_batchnorm)
self.maxpool1 = nn.MaxPool2d(kernel_size=2)
self.conv2 = UnetConv2(filters[0], filters[1], is_batchnorm)
self.maxpool2 = nn.MaxPool2d(kernel_size=2)
self.conv3 = UnetConv2(filters[1], filters[2], is_batchnorm)
self.maxpool3 = nn.MaxPool2d(kernel_size=2)
self.conv4 = UnetConv2(filters[2], filters[3], is_batchnorm)
self.maxpool4 = nn.MaxPool2d(kernel_size=2)
self.center = UnetConv2(filters[3], filters[4], is_batchnorm)
# upsampling
self.up_concat4 = UnetUp(filters[4], filters[3], is_deconv)
self.up_concat3 = UnetUp(filters[3], filters[2], is_deconv)
self.up_concat2 = UnetUp(filters[2], filters[1], is_deconv)
self.up_concat1 = UnetUp(filters[1], filters[0], is_deconv)
# final conv (without any concat)
self.final = nn.Conv2d(filters[0], self._num_classes, kernel_size=1)
self._filters = filters # we need this info for re-training
def forward(self, inputs, return_features=False):
"""If return_features is True, returns tuple (final outputs, last feature map),
else returns final outputs only.
"""
conv1 = self.conv1(inputs)
maxpool1 = self.maxpool1(conv1)
conv2 = self.conv2(maxpool1)
maxpool2 = self.maxpool2(conv2)
conv3 = self.conv3(maxpool2)
maxpool3 = self.maxpool3(conv3)
conv4 = self.conv4(maxpool3)
maxpool4 = self.maxpool4(conv4)
center = self.center(maxpool4)
up4 = self.up_concat4(conv4, center)
up3 = self.up_concat3(conv3, up4)
up2 = self.up_concat2(conv2, up3)
up1 = self.up_concat1(conv1, up2)
final = self.final(up1)
if return_features:
return final, up1
else:
return final
def change_num_classes(self, new_num_classes: int):
"""Re-initialize the final layer with another number of output classes if different from
existing number of classes
"""
if new_num_classes == self._num_classes:
return
assert new_num_classes > 1, 'Number of classes needs to be > 1'
self._num_classes = new_num_classes
self.final = nn.Conv2d(self._filters[0], self._num_classes, kernel_size=1)
nn.init.kaiming_uniform_(self.final.weight)
self.final.bias.data.zero_()

74
training_wcs/scripts/models/unet/unet_utils.py Normal file

View file

@@ -0,0 +1,74 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class UnetConv2(nn.Module):
def __init__(self, in_channels, out_channels, is_batchnorm):
super(UnetConv2, self).__init__()
if is_batchnorm:
self.conv1 = nn.Sequential(
# this amount of padding/stride/kernel_size preserves width/height
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
self.conv2 = nn.Sequential(
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
else:
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
nn.ReLU()
)
self.conv2 = nn.Sequential(
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
nn.ReLU()
)
def forward(self, inputs):
outputs = self.conv1(inputs)
outputs = self.conv2(outputs)
return outputs
class UnetUp(nn.Module):
def __init__(self, in_channels, out_channels, is_deconv):
"""
is_deconv: use transposed conv layer to upsample - parameters are learnt; otherwise use
bilinear interpolation to upsample.
"""
super(UnetUp, self).__init__()
self.conv = UnetConv2(in_channels, out_channels, False)
self.is_deconv = is_deconv
if is_deconv:
self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
# nn.UpsamplingBilinear2d is deprecated in favor of F.interpolate()
# else:
# self.up = nn.UpsamplingBilinear2d(scale_factor=2)
def forward(self, inputs1, inputs2):
"""
inputs1 is from the downward path, of higher resolution
inputs2 is from the 'lower' layer. It gets upsampled (spatial size increases) and its depth (channels) halves
to match the depth of inputs1, before being concatenated in the depth dimension.
"""
if self.is_deconv:
outputs2 = self.up(inputs2)
else:
# scale_factor is the multiplier for spatial size
outputs2 = F.interpolate(inputs2, scale_factor=2, mode='bilinear', align_corners=True)
offset = outputs2.size()[2] - inputs1.size()[2]
padding = 2 * [offset // 2, offset // 2]
outputs1 = F.pad(inputs1, padding)
return self.conv(torch.cat([outputs1, outputs2], dim=1))