Add finetuning and training code.
This commit is contained in:
Родитель
19cd9dbb4c
Коммит
f19568dc53
|
@ -130,3 +130,6 @@ dmypy.json
|
|||
|
||||
# Mac file system
|
||||
.DS_Store
|
||||
|
||||
# IDE / PyCharm
|
||||
.idea/
|
||||
|
|
|
@ -0,0 +1,545 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from IPython.core.interactiveshell import InteractiveShell\n",
|
||||
"InteractiveShell.ast_node_interactivity = 'all' # default is ‘last_expr’\n",
|
||||
"\n",
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"sys.path.append('/Source/repos/GitHub_MSFT/landcover-orinoquiaa')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"import pickle\n",
|
||||
"from collections import defaultdict\n",
|
||||
"\n",
|
||||
"import rasterio\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"import numpy as np\n",
|
||||
"from sklearn.metrics import confusion_matrix, accuracy_score\n",
|
||||
"from sklearn.calibration import calibration_curve\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from PIL import Image\n",
|
||||
"\n",
|
||||
"from geospatial.visualization.raster_label_visualizer import RasterLabelVisualizer\n",
|
||||
"\n",
|
||||
"plt.rcParams['figure.figsize'] = (10.0, 10.0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# from data/tile_and_mask.py - which needs to be run in the Solaris env\n",
|
||||
"\n",
|
||||
"def get_lon_lat_from_tile_name(tile_name):\n",
|
||||
" \"\"\"Returns _lon_lat\"\"\"\n",
|
||||
" parts = tile_name.split('_')\n",
|
||||
" lon_lat = f'_{parts[-2]}_{parts[-1].split(\".tif\")[0]}'\n",
|
||||
" return lon_lat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Evaluate a tiles of model predictions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"viz_util= RasterLabelVisualizer('../constants/class_lists/wcs_coarse_label_map.json')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['Empty of data',\n",
|
||||
" 'Urban and infrastructure',\n",
|
||||
" 'Agriculture',\n",
|
||||
" 'Arboreal and forestry crops',\n",
|
||||
" 'Pasture',\n",
|
||||
" 'Vegetation',\n",
|
||||
" 'Forest',\n",
|
||||
" 'Savanna',\n",
|
||||
" 'Sand, rocks and bare land',\n",
|
||||
" 'Unavailable',\n",
|
||||
" 'Swamp',\n",
|
||||
" 'Water',\n",
|
||||
" 'Seasonal savanna',\n",
|
||||
" 'Seasonally flooded savanna']"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"label_names = sorted(viz_util.num_to_name.items(), key=lambda x: int(x[0]))\n",
|
||||
"label_names = [i[1] for i in label_names]\n",
|
||||
"label_names"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"output_paths = '/Data/WCS_land_use/delivered/20200701/results_coarse_baseline_201314'\n",
|
||||
"\n",
|
||||
"mask_paths = '/Data/WCS_land_use/train_full_region_median/tiles_masks_coarse'\n",
|
||||
"\n",
|
||||
"eval_saved_to = '/Data/WCS_land_use/train_full_region_median/result_val_analysis_coarse_baseline'\n",
|
||||
"\n",
|
||||
"num_classes = viz_util.num_classes"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tile_accuracies = {}\n",
|
||||
"\n",
|
||||
"cm = np.zeros((num_classes, num_classes), dtype=np.int64)\n",
|
||||
"\n",
|
||||
"true_counts = np.zeros((num_classes), dtype=np.int64)\n",
|
||||
"pred_counts = np.zeros((num_classes), dtype=np.int64)\n",
|
||||
"\n",
|
||||
"classes_present_in_gt = set()\n",
|
||||
"\n",
|
||||
"for output_tile_fn in os.listdir(output_paths):\n",
|
||||
" if not output_tile_fn.endswith('.tif'):\n",
|
||||
" continue\n",
|
||||
"# for output_tile_fn in ['res_wcs_orinoquia_sr_median_2013_2014-0000000000-0000022272_-68.962_6.593.tif']:\n",
|
||||
" \n",
|
||||
" output_tile_path = os.path.join(output_paths, output_tile_fn)\n",
|
||||
" out_reader = rasterio.open(output_tile_path)\n",
|
||||
" output_tile = np.array(Image.open(output_tile_path), dtype=np.uint8)\n",
|
||||
" \n",
|
||||
" # mask_-68.423_6.054.png\n",
|
||||
" lon_lat = get_lon_lat_from_tile_name(output_tile_path)\n",
|
||||
" label_mask_path = os.path.join(mask_paths, f'mask{lon_lat}.tif')\n",
|
||||
" label_mask = np.array(Image.open(label_mask_path), dtype=np.uint8)\n",
|
||||
" \n",
|
||||
" output = output_tile.flatten()\n",
|
||||
" labels = label_mask.flatten()\n",
|
||||
" \n",
|
||||
" # mask out where labels is 0, which is outside of boundary of region\n",
|
||||
" # and also where output is 0, which is where no imagery is available on the tile\n",
|
||||
" # now get rid of such entries\n",
|
||||
" labels_masked = labels * (output != 0)\n",
|
||||
" no_label_entries = np.where(labels_masked == 0)\n",
|
||||
" \n",
|
||||
" labels = np.delete(labels, no_label_entries)\n",
|
||||
" output = np.delete(output, no_label_entries)\n",
|
||||
" \n",
|
||||
" classes_present_in_gt.update(labels)\n",
|
||||
" \n",
|
||||
" tile_accuracy = accuracy_score(labels, output, normalize=True)\n",
|
||||
" tile_accuracies[lon_lat] = tile_accuracy\n",
|
||||
"\n",
|
||||
" for y_true, y_pred in tqdm(zip(labels, output)):\n",
|
||||
" cm[y_true][y_pred] += 1\n",
|
||||
" true_counts[y_true] += 1\n",
|
||||
" pred_counts[y_pred] += 1\n",
|
||||
" \n",
|
||||
"overall_accuracy = sum(tile_accuracies.values())/len(tile_accuracies)\n",
|
||||
"print(f'Overall accuracy is {overall_accuracy}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tile_accuracies"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Accurate distribution of land types\n",
|
||||
"The shapefile's area attribute did not look correct"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"true_counts"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Confusion matrix"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# normalize by ground truth label counts\n",
|
||||
"cm_norm = np.zeros((num_classes, num_classes), dtype=np.float)\n",
|
||||
"for y_true in range(num_classes):\n",
|
||||
" for y_pred in range(num_classes):\n",
|
||||
" if true_counts[y_true] == 0:\n",
|
||||
" cm_norm[y_true][y_pred] = 0.0\n",
|
||||
" else:\n",
|
||||
" cm_norm[y_true][y_pred] = cm[y_true][y_pred] / true_counts[y_true]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# docs: https://matplotlib.org/3.1.3/gallery/images_contours_and_fields/image_annotated_heatmap.html#sphx-glr-gallery-images-contours-and-fields-image-annotated-heatmap-py\n",
|
||||
"\n",
|
||||
"cm_to_plot = cm_norm\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"fig = plt.figure(figsize=(10, 10), dpi=200) # set dpi to 300 to look good\n",
|
||||
"ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])\n",
|
||||
"im = ax.matshow(cm_to_plot, cmap=plt.cm.YlGnBu)\n",
|
||||
"_ = ax.set_xticks(np.array(range(num_classes)))\n",
|
||||
"_ = ax.set_yticks(np.array(range(num_classes)))\n",
|
||||
"_ = ax.set_xticklabels(label_names)\n",
|
||||
"_ = ax.set_yticklabels(label_names)\n",
|
||||
"_ = ax.set_ylabel('Provided labels')\n",
|
||||
"_ = ax.set_xlabel('Predicted by model')\n",
|
||||
"ax.xaxis.tick_top()\n",
|
||||
"\n",
|
||||
"# Rotate the tick labels\n",
|
||||
"_ = plt.setp(ax.get_xticklabels(), rotation=90)\n",
|
||||
"\n",
|
||||
"_ = ax.set_xticks(np.array(range(num_classes)) - 0.5, minor=True)\n",
|
||||
"_ = ax.set_yticks(np.array(range(num_classes)) - 0.5, minor=True)\n",
|
||||
"ax.grid(which='minor', color='white', linestyle='-', linewidth=3)\n",
|
||||
"\n",
|
||||
"cbar = ax.figure.colorbar(im, ax=ax)\n",
|
||||
"\n",
|
||||
"# no border\n",
|
||||
"for edge, spine in ax.spines.items():\n",
|
||||
" spine.set_visible(False)\n",
|
||||
"\n",
|
||||
"# right-click save - layout isn't right otherwise\n",
|
||||
" \n",
|
||||
"#fig.tight_layout()\n",
|
||||
"#plt.savefig('/Users/siyuyang/Source/temp_data/WCS_land_use/train_200218/result_val/evaluation/cm.png')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cm_norm[32][32]\n",
|
||||
"cm_norm[33][33]\n",
|
||||
"\n",
|
||||
"cm_norm[30, 33] # row, col - ground truth, predicted\n",
|
||||
"cm_norm[33][30] "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Side by side label and output counts, in log scale"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Per-class accuracy, precision and recall"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# per-class accuracy\n",
|
||||
"total_obs = cm.sum()\n",
|
||||
"\n",
|
||||
"per_class_accuracy = {}\n",
|
||||
"per_class_recall = {}\n",
|
||||
"per_class_precision = {}\n",
|
||||
"\n",
|
||||
"for cls in range(num_classes):\n",
|
||||
" if cls not in classes_present_in_gt:\n",
|
||||
" continue\n",
|
||||
" \n",
|
||||
" true_pos = cm[cls, cls]\n",
|
||||
" \n",
|
||||
" true_neg = total_obs - cm[cls, :].sum() - cm[:, cls].sum() + true_pos\n",
|
||||
" \n",
|
||||
" false_pos = cm[:, cls].sum() - true_pos\n",
|
||||
" \n",
|
||||
" false_neg = cm[cls, :].sum() - true_pos\n",
|
||||
" \n",
|
||||
" per_class_accuracy[cls] = (true_pos + true_neg) / total_obs\n",
|
||||
" \n",
|
||||
" per_class_precision[cls] = true_pos / (true_pos + false_pos)\n",
|
||||
" \n",
|
||||
" per_class_recall[cls] = true_pos / (true_pos + false_neg)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print('Category, Accuracy, Precision, Recall')\n",
|
||||
"for cls, acc in per_class_accuracy.items():\n",
|
||||
" prec = per_class_precision[cls]\n",
|
||||
" recall = per_class_recall[cls]\n",
|
||||
" print(f'{cls} {viz_util.num_to_name[str(cls)]},{acc},{prec},{recall}')\n",
|
||||
" \n",
|
||||
"# paste the result into Pages, and fix the row for \"27 Lakes, lagoons, and natural cienaga\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Since the dataset is so unbalanced (mostly 12 - dense forest) and accuracy counts \"true negatives\" as a win, this is not a good measure of performance."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Save the evaluation findings - not yet done"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"saved = {\n",
|
||||
" 'overall_accuracy': overall_accuracy,\n",
|
||||
" 'per_class_accuracy': per_class_accuracy,\n",
|
||||
" # 'calibration_summary': calibration_summary\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"with open(eval_saved_to, 'w') as f:\n",
|
||||
" json.dump(saved, f, indent=4)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Is the model well-calibrated?\n",
|
||||
"\n",
|
||||
"We can also just record a 2D shape - each cell is the confidence of the most confident class?"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(output_scores_path, 'rb') as f:\n",
|
||||
" dict_scores = pickle.load(f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"classes_to_plot = [0, 11, 12, 17, 19, 26, 32]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"y_true = defaultdict(list)\n",
|
||||
"y_prob = defaultdict(list)\n",
|
||||
"\n",
|
||||
"for window, chip_scores in tqdm(dict_scores.items()):\n",
|
||||
" # rasterio window is (col_off x, row_off y, width, height)\n",
|
||||
" \n",
|
||||
" chip_scores = chip_scores.squeeze() # chip_scores have shape (1, 33, 256, 256)\n",
|
||||
" chip_scores = chip_scores.reshape((33, -1))\n",
|
||||
"\n",
|
||||
" chip_labels = label_mask[window[0]:window[0] + 256, window[1]:window[1] + 256]\n",
|
||||
" chip_labels = chip_labels.reshape((1, -1))\n",
|
||||
" # we pad 0 to the end of chips after the tile ends\n",
|
||||
" chip_labels = np.pad(chip_labels, ((0, 0), (0, 256*256 - chip_labels.shape[1]))).squeeze()\n",
|
||||
" \n",
|
||||
" assert chip_scores.shape == (33, 256*256), chip_scores.shape\n",
|
||||
" assert chip_labels.shape == (256*256,), chip_labels.shape\n",
|
||||
" \n",
|
||||
" for cls in classes_to_plot:\n",
|
||||
" cls_y_true = chip_labels == cls\n",
|
||||
" cls_y_prob = chip_scores[cls]\n",
|
||||
" assert len(list(cls_y_true)) == len(list(cls_y_prob)), '{}, {}'.format(\n",
|
||||
" len(list(cls_y_true)), len(list(cls_y_prob))\n",
|
||||
" )\n",
|
||||
" y_true[cls].extend(list(cls_y_true))\n",
|
||||
" y_prob[cls].extend(list(cls_y_prob))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"len(y_true[12])\n",
|
||||
"len(y_prob[12])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"_ = plt.plot([0.0, 1.0], color='grey', linestyle=':')\n",
|
||||
"\n",
|
||||
"for cls in classes_to_plot:\n",
|
||||
" _ = frac_positives, mean_prob_in_bin = calibration_curve(y_true[cls], y_prob[cls], n_bins=10)\n",
|
||||
" _ = plt.plot(mean_prob_in_bin, frac_positives, label=cls, color=viz_util.num_to_color[str(cls)])\n",
|
||||
"_ = plt.legend()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mean_prob_in_bin\n",
|
||||
"frac_positives"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Expected number of pixels for the whole validation area"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"probability_sum = np.zeros(num_classes, dtype=np.float)\n",
|
||||
"\n",
|
||||
"for window, chip_scores in dict_scores.items():\n",
|
||||
" # print(chip_scores.shape) # (1, 33, 256, 256)\n",
|
||||
" chip_scores = chip_scores.squeeze()\n",
|
||||
" chip_scores = chip_scores.sum(axis=(1, 2)) # height and width dims\n",
|
||||
" probability_sum += chip_scores"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"calibration_summary = {}\n",
|
||||
"\n",
|
||||
"for cls, (prob_sum, label_sum) in enumerate(zip(probability_sum, true_counts)):\n",
|
||||
" calibration_summary[cls] = {\n",
|
||||
" 'prediction_probability_sum': prob_sum,\n",
|
||||
" 'label_sum': int(label_sum)\n",
|
||||
" }\n",
|
||||
" print('Class {} - {}, prob_sum {}, label_sum {}'.format(cls, viz_util.num_to_name[str(cls)], round(prob_sum), label_sum))\n",
|
||||
" if label_sum > 0:\n",
|
||||
" print(' diff is {}%'.format(100 * round((prob_sum - label_sum)/label_sum, 3)))\n",
|
||||
" calibration_summary[cls]['difference_wrt_label_sum'] = (prob_sum - label_sum)/label_sum"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python [conda env:wcs] *",
|
||||
"language": "python",
|
||||
"name": "conda-env-wcs-py"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -0,0 +1,294 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
Implementation of ModelSessionAbstract for the Orinoquia land cover mapping project. Heavily references
|
||||
wildlife-conservation-society.orinoquia-land-use/training/inference.py
|
||||
|
||||
"wildlife-conservation-society.orinoquia-land-use" and "ai4eutils" need to be on the PYTHONPATH.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from web_tool.ModelSessionAbstract import ModelSession
|
||||
|
||||
LOGGER = logging.getLogger("server")
|
||||
|
||||
|
||||
class TorchFineTuningOrinoquia(ModelSession):
|
||||
|
||||
def __init__(self, gpu_id, **kwargs):
|
||||
# setting up device to use
|
||||
LOGGER.debug(f"TorchFineTuningOrinoquia init, gpu_id is {gpu_id}")
|
||||
self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() and gpu_id is not None else "cpu")
|
||||
|
||||
# load experiment configuration as a module
|
||||
config_module_path = kwargs["config_module_path"]
|
||||
try:
|
||||
module_name = "config"
|
||||
spec = importlib.util.spec_from_file_location(module_name, config_module_path)
|
||||
self.config = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = self.config
|
||||
spec.loader.exec_module(self.config)
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to import experiment and model configuration. Exception: {e}")
|
||||
sys.exit(1)
|
||||
LOGGER.info(f"config is for experiment {self.config.experiment_name}")
|
||||
|
||||
# check that the necessary fields are present in the config
|
||||
assert self.config.num_classes > 1 # this is the number of classes the initial model supports
|
||||
self.num_classes = self.config.num_classes # self.num_classes can evolve during a session
|
||||
|
||||
assert self.config.chip_size > 1
|
||||
assert self.config.feature_scale in [1, 2]
|
||||
|
||||
chip_size = self.config.chip_size
|
||||
self.prediction_window_size = self.config.prediction_window_size if self.config.prediction_window_size else 128
|
||||
self.prediction_window_offset = int((chip_size - self.prediction_window_size) / 2)
|
||||
print((f"Using chip_size {chip_size} and window_size {self.prediction_window_size}. "
|
||||
f"So window_offset is {self.prediction_window_offset}"))
|
||||
|
||||
# obtain the model that the config has initialized
|
||||
self.checkpoint_path = kwargs["fn"]
|
||||
assert os.path.exists(self.checkpoint_path), f"Checkpoint at {self.checkpoint_path} does not exist."
|
||||
self.model = self.config.model
|
||||
self._init_model()
|
||||
|
||||
# other instance variables
|
||||
self._last_tile = None
|
||||
|
||||
# recording the current feature map (before final layer), and the corrections made
|
||||
self.current_features = None
|
||||
self.corr_features = []
|
||||
self.corr_labels = []
|
||||
|
||||
def _init_model(self):
|
||||
checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
|
||||
self.model.load_state_dict(checkpoint["state_dict"])
|
||||
print(f"Using checkpoint at epoch {checkpoint['epoch']}, step {checkpoint['step']}, "
|
||||
f"val accuracy is {checkpoint.get('val_acc', 'Not Available')}")
|
||||
self.model = self.model.to(device=self.device)
|
||||
self.model.eval()
|
||||
|
||||
@property
|
||||
def last_tile(self):
|
||||
return self._last_tile
|
||||
|
||||
def run(self, tile, inference_mode=False):
|
||||
assert tile.shape[2] == 11 # 10 Landsat bands + DEM
|
||||
height = tile.shape[0] # tile is of dims (height, width, channels)
|
||||
width = tile.shape[1]
|
||||
self._last_tile = tile
|
||||
prediction_window_offset = self.prediction_window_offset
|
||||
prediction_window_size = self.prediction_window_size
|
||||
|
||||
# apply the preprocessing of bands to the tile
|
||||
data_array = self.config.preprocess_tile(tile) # dim is (6, H, W)
|
||||
|
||||
LOGGER.debug(f"run, tile shape is {tile.shape}")
|
||||
LOGGER.debug(f"run, data_array shape is {data_array.shape}")
|
||||
|
||||
# pad by mirroring at the edges to predict only on the center crop
|
||||
data_array = np.pad(data_array,
|
||||
[
|
||||
(0, 0), # only pad height and width
|
||||
(prediction_window_offset, prediction_window_offset), # height / rows
|
||||
(prediction_window_offset, prediction_window_offset) # width / cols
|
||||
],
|
||||
mode="symmetric")
|
||||
|
||||
# form batches
|
||||
batch_size = self.config.batch_size if self.config.batch_size is not None else 32
|
||||
batch = []
|
||||
batch_indices = [] # cache these to save recalculating when filling in model predictions
|
||||
|
||||
chip_size = self.config.chip_size
|
||||
|
||||
num_rows = math.ceil(height / prediction_window_size)
|
||||
num_cols = math.ceil(width / prediction_window_size)
|
||||
|
||||
for col_idx in range(num_cols):
|
||||
col_start = col_idx * prediction_window_size
|
||||
col_end = col_start + chip_size
|
||||
|
||||
for row_idx in range(num_rows):
|
||||
row_start = row_idx * prediction_window_size
|
||||
row_end = row_start + chip_size
|
||||
|
||||
chip = data_array[:, row_start:row_end, col_start: col_end]
|
||||
# pad to (chip_size, chip_size)
|
||||
chip = np.pad(chip,
|
||||
[(0, 0), (0, chip_size - chip.shape[1]), (0, chip_size - chip.shape[2])])
|
||||
|
||||
# processing it as the dataset loader _get_chip does
|
||||
chip = np.nan_to_num(chip, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||
sat_mask = chip[0].squeeze() > 0.0 # mask out DEM data where there's no satellite data
|
||||
chip = chip * sat_mask
|
||||
|
||||
batch.append(chip)
|
||||
|
||||
valid_row_end = row_start + min(prediction_window_size, height - row_idx * prediction_window_size)
|
||||
valid_col_end = col_start + min(prediction_window_size, width - col_idx * prediction_window_size)
|
||||
batch_indices.append(
|
||||
(row_start, valid_row_end, col_start, valid_col_end)) # explicit to be less confusing
|
||||
batch = np.array(batch) # (num_chips, channels, height, width)
|
||||
|
||||
# score chips in batches
|
||||
model_output = []
|
||||
model_features = [] # same spatial dims as model_output, but has 64 or 32 channels
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
for i in range(0, len(batch), batch_size):
|
||||
t_batch = batch[i:i + batch_size]
|
||||
t_batch = torch.from_numpy(t_batch).to(self.device)
|
||||
|
||||
scores, features = self.model.forward(t_batch,
|
||||
return_features=True) # these are scores before the final softmax
|
||||
softmax_scores = torch.nn.functional.softmax(scores,
|
||||
dim=1).cpu().numpy() # (batch_size, num_classes, H, W)
|
||||
|
||||
softmax_scores = np.transpose(softmax_scores, axes=(0, 2, 3, 1)) # (batch_size, H, W, num_classes)
|
||||
# only save the center crop
|
||||
softmax_scores = softmax_scores[
|
||||
:,
|
||||
prediction_window_offset:prediction_window_offset + prediction_window_size,
|
||||
prediction_window_offset:prediction_window_offset + prediction_window_size,
|
||||
:
|
||||
]
|
||||
model_output.append(softmax_scores)
|
||||
|
||||
features = features.cpu().numpy()
|
||||
features = np.transpose(features, axes=(0, 2, 3, 1)) # (batch_size, H, W, num_features)
|
||||
features = features[
|
||||
:,
|
||||
prediction_window_offset:prediction_window_offset + prediction_window_size,
|
||||
prediction_window_offset:prediction_window_offset + prediction_window_size,
|
||||
:
|
||||
]
|
||||
model_features.append(features)
|
||||
|
||||
model_output = np.concatenate(model_output, axis=0)
|
||||
model_features = np.concatenate(model_features, axis=0)
|
||||
|
||||
# fill in the output array; self.num_classes is the number of classes supported by the current model
|
||||
output = np.zeros((height, width, self.num_classes), dtype=np.float32)
|
||||
for i, (row_start, row_end, col_start, col_end) in enumerate(batch_indices):
|
||||
h = row_end - row_start
|
||||
w = col_end - col_start
|
||||
output[row_start:row_end, col_start:col_end, :] = model_output[i, :h, :w, :]
|
||||
print(f"--- Orinoquia ModelSession, output[-1, -1, :] is {output[-1, -1, :]}")
|
||||
|
||||
num_features = 64 if self.config.feature_scale == 1 else 32
|
||||
output_features = np.zeros((height, width, num_features), dtype=np.float32) # float32 used during training too
|
||||
for i, (row_start, row_end, col_start, col_end) in enumerate(batch_indices):
|
||||
h = row_end - row_start
|
||||
w = col_end - col_start
|
||||
output_features[row_start:row_end, col_start:col_end, :] = model_features[i, :h, :w, :]
|
||||
print(f"--- Orinoquia ModelSession, output_features[-1, -1, :] is {output_features[-1, -1, :]}")
|
||||
# save the features
|
||||
self.current_features = output_features
|
||||
|
||||
return output
|
||||
|
||||
def add_sample_point(self, row, col, class_idx):
|
||||
self.corr_labels.append(class_idx)
|
||||
self.corr_features.append(self.current_features[row, col, :])
|
||||
|
||||
print(f"After add_sample_point, corr_labels length is {len(self.corr_labels)}")
|
||||
return {
|
||||
"message": f"Training sample for class {class_idx} added",
|
||||
"success": True
|
||||
}
|
||||
|
||||
def undo(self):
|
||||
if len(self.corr_features) > 0:
|
||||
self.corr_features = self.corr_features[:-1]
|
||||
self.corr_labels = self.corr_labels[:-1]
|
||||
return {
|
||||
"message": "Undid training sample",
|
||||
"success": True
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"message": "Nothing to undo",
|
||||
"success": False
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
self._init_model()
|
||||
self.num_classes = self.config.num_classes
|
||||
self.corr_features = []
|
||||
self.corr_labels = []
|
||||
return {
|
||||
"message": f"Model is reset and support {self.num_classes} classes",
|
||||
"success": True
|
||||
}
|
||||
|
||||
def retrain(self, train_steps=100, learning_rate=1e-3):
|
||||
print_every = 10
|
||||
|
||||
# if any new classes have been added, update self.num_classes and re-initialize final layer
|
||||
# class start from 0
|
||||
if max(self.corr_labels) >= self.num_classes:
|
||||
self.num_classes = max(self.corr_labels) + 1
|
||||
self.model.change_num_classes(self.num_classes)
|
||||
self.model.final.to(device=self.device)
|
||||
LOGGER.debug(f"New classes have been added (total {self.num_classes}) and final layer re-initialized.")
|
||||
|
||||
# all corrections since the last reset are used
|
||||
batch_x = torch.from_numpy(np.array(self.corr_features)).float().to(self.device)
|
||||
batch_y = torch.from_numpy(np.array(self.corr_labels)).to(self.device)
|
||||
|
||||
# make the last layer `final` trainable TODO do we need to do this - default trainable?
|
||||
for param in self.model.final.parameters(): # see UNet implementation in WCS project repo
|
||||
param.requires_grad = True
|
||||
|
||||
optimizer = optim.Adam(self.model.final.parameters(), lr=learning_rate) # only the last layer
|
||||
|
||||
# during re-training, we use equal weight for all classes
|
||||
criterion = nn.CrossEntropyLoss().to(device=self.device)
|
||||
|
||||
self.model.train()
|
||||
for step in range(train_steps):
|
||||
with torch.enable_grad():
|
||||
# forward pass
|
||||
batch_x_reshaped = batch_x.unsqueeze(2).unsqueeze(3)
|
||||
scores = self.model.final.forward(batch_x_reshaped).squeeze(3).squeeze(2)
|
||||
loss = criterion(scores, batch_y)
|
||||
|
||||
# backward pass
|
||||
optimizer.zero_grad()
|
||||
loss.backward() # compute gradients
|
||||
optimizer.step() # update parameters
|
||||
|
||||
if step % print_every == 0:
|
||||
preds = scores.argmax(1)
|
||||
accuracy = (batch_y == preds).float().mean()
|
||||
|
||||
print(f'step {step}, loss: {loss.item()}, accuracy: {accuracy.item()}')
|
||||
|
||||
return {
|
||||
"message": f"Fine-tuned model with {len(self.corr_features)} samples for {train_steps} steps",
|
||||
"success": True
|
||||
}
|
||||
|
||||
def save_state_to(self, directory):
|
||||
# number of classes could be different from what's in the config for the initial model
|
||||
return {
|
||||
"message": "Saving not yet implemented",
|
||||
"success": False
|
||||
}
|
||||
|
||||
def load_state_from(self, directory):
|
||||
return {
|
||||
"message": "Saving and loading not yet implemented",
|
||||
"success": False
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
# Finetuning the land cover model interactively
|
||||
|
||||
Files in this folder are configurations and implementations of required classes for finetuning the model interactively in an instance of the [Land Cover Mapping tool](https://github.com/microsoft/landcover) (this repo will be referred to as the `landcover` repo below).
|
||||
|
||||
## Setup
|
||||
Because our implementation of [ModelSession](https://github.com/microsoft/landcover/blob/master/web_tool/ModelSessionAbstract.py) relies on the experiment configuration file (a `.py` file) that produced the model, this repo needs to be on the `PYTHONPATH` when running the Land Cover Mapping tool's server, in additional to the AI for Earth utilities [repo](https://github.com/microsoft/ai4eutils):
|
||||
```
|
||||
export PYTHONPATH="${PYTHONPATH}:/home/boto/wcs/pycharm:/home/boto/lib/ai4eutils"
|
||||
```
|
||||
|
||||
Note that directory names in this repo should not clash with ones in the `landcover` repo.
|
||||
|
||||
Basemap needs to be in the `landcover` repo's root directory for the server to serve the data. We can create a symbolic link to the files stored in a blob storage container:
|
||||
```
|
||||
ln -s /home/boto/wcs/mnt/wcs-orinoquia/images_sr_median/2013_2014_dem/wcs_orinoquia_sr_median_2013_2014_dem.vrt wcs_orinoquia_sr_median_2013_2014_dem.vrt
|
||||
```
|
||||
|
||||
The configuration files for the dataset and model in this folder and `ModelSessionOrinoquia.py` should be copied to the `landcover` repo's `web_tool` directory.
|
||||
|
||||
We also need to make the following changes:
|
||||
- Modify [`worker.py`](https://github.com/microsoft/landcover/blob/master/worker.py) to add a case for the `model_type` `pytorch_landsat`.
|
||||
- Modify `web_tool/DataLoader.py`, changing
|
||||
```python
|
||||
resolution=(x_res, y_res)
|
||||
```
|
||||
to
|
||||
```python
|
||||
resolution=(30, 30) # hardcode Landsat 8 resolution
|
||||
```
|
||||
|
||||
- In `web_tool/js/components.js`, modify these values in `addInferenceWindowSizeSlider` to:
|
||||
```python
|
||||
min: 7680,
|
||||
max: 23040,
|
||||
```
|
||||
- In `web_tool/js/globals.js`, set the following variables:
|
||||
```python
|
||||
var INFERENCE_WINDOW_SIZE = 300;
|
||||
var INFERENCE_WINDOW_SIZE = 7680;
|
||||
```
|
||||
|
||||
|
||||
## Operation
|
||||
|
||||
![Using the Land Cover Mapping tool on Landsat 8 imagery](../visuals/in_the_tool.png)
|
|
@ -0,0 +1,38 @@
|
|||
{
|
||||
"landsat_orinoquia": {
|
||||
"metadata": {
|
||||
"displayName": "Orinoquia Colombia, 2013-2014, Landsat 8",
|
||||
"locationName": null
|
||||
},
|
||||
"dataLayer": {
|
||||
"type": "CUSTOM",
|
||||
"path": "data/imagery/wcs_orinoquia_sr_median_2013_2014_dem.vrt",
|
||||
"padding": 0,
|
||||
"resolution": 30
|
||||
},
|
||||
"basemapLayers": [
|
||||
{
|
||||
"layerName": "Landsat 8 2013-2014",
|
||||
"initialZoom": 8,
|
||||
"url": "data/basemaps/wcs_2013_2014_basemap_exp07/{z}/{x}/{y}.png",
|
||||
"initialLocation": [4.400127, -72.304724],
|
||||
"args": {
|
||||
"attribution": "Georeferenced Image",
|
||||
"tms": true,
|
||||
"maxNativeZoom": 13,
|
||||
"maxZoom": 16,
|
||||
"minZoom": 8,
|
||||
"bounds": [[1.4752987731256113, -67.25191674463355], [7.516203982228033, -74.96450029130249]]
|
||||
}
|
||||
}
|
||||
],
|
||||
"shapeLayers": [
|
||||
{
|
||||
"shapesFn": "data/zones/Orinoquia_outline.geojson",
|
||||
"zoneNameKey": "NAME_1",
|
||||
"name": "Region outline"
|
||||
}
|
||||
],
|
||||
"validModels": ["landsat_orinoquia_coarse"]
|
||||
}
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
{
|
||||
"landsat_orinoquia_coarse": {
|
||||
"metadata": {
|
||||
"displayName": "Landsat Orinoquia Coarse-class"
|
||||
},
|
||||
"model": {
|
||||
"type": "pytorch_landsat",
|
||||
"fineTuneLayer": -1,
|
||||
"numParameters": -1,
|
||||
"inputShape": [
|
||||
256,
|
||||
256,
|
||||
6
|
||||
],
|
||||
"config_module_path": "/home/boto/wcs/pycharm/training_wcs/experiments/coarse_baseline/coarse_baseline_config_refactored.py",
|
||||
"fn": "/home/boto/wcs/mnt/wcs-orinoquia/useful_checkpoints/wcs_coarse_baseline/wcs_coarse_baseline_epoch93_best.pth.tar"
|
||||
},
|
||||
"classes": [
|
||||
{
|
||||
"name": "Empty of data",
|
||||
"color": "#000000"
|
||||
},
|
||||
{
|
||||
"name": "Urban and infrastructure",
|
||||
"color": "#d3d3d3"
|
||||
},
|
||||
{
|
||||
"name": "Agriculture",
|
||||
"color": "#ffc0cb"
|
||||
},
|
||||
{
|
||||
"name": "Arboreal and forestry crops",
|
||||
"color": "#008080"
|
||||
},
|
||||
{
|
||||
"name": "Pasture",
|
||||
"color": "#fa8072"
|
||||
},
|
||||
{
|
||||
"name": "Vegetation",
|
||||
"color": "#daa520"
|
||||
},
|
||||
{
|
||||
"name": "Forest",
|
||||
"color": "#8fbc8f"
|
||||
},
|
||||
{
|
||||
"name": "Savanna",
|
||||
"color": "#ffd700"
|
||||
},
|
||||
{
|
||||
"name": "Sand, rocks and bare land",
|
||||
"color": "#ffebcd"
|
||||
},
|
||||
{
|
||||
"name": "Unavailable",
|
||||
"color": "#f5f5f5"
|
||||
},
|
||||
{
|
||||
"name": "Swamp",
|
||||
"color": "#556b2f"
|
||||
},
|
||||
{
|
||||
"name": "Water",
|
||||
"color": "#00bfff"
|
||||
},
|
||||
{
|
||||
"name": "Seasonal savanna",
|
||||
"color": "#f0e68c"
|
||||
},
|
||||
{
|
||||
"name": "Seasonally flooded savanna",
|
||||
"color": "#d8bfd8"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
|
@ -0,0 +1,200 @@
|
|||
"""
|
||||
Configurations for the 20200505_mini_baseline experiment
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from viz_utils import VizUtils
|
||||
from training_wcs.scripts.models.unet.unet import Unet # models.unet.unet import Unet
|
||||
|
||||
|
||||
experiment_name = 'wcs_baseline_202005'
|
||||
|
||||
eval_mode = True
|
||||
|
||||
# I/O -------------------------------------------------------------------------------------------------
|
||||
if not eval_mode:
|
||||
aml_data_ref = os.environ.get('AZUREML_DATAREFERENCE_wcsorinoquia', '')
|
||||
assert len(aml_data_ref) > 0, 'Reading aml_data_ref from environment vars resulted in empty string.'
|
||||
|
||||
data_dir = os.path.join(aml_data_ref, 'tiles', 'full_sr_median_2013_2014')
|
||||
assert 'tiles' in os.listdir(data_dir)
|
||||
assert 'tiles_masks' in os.listdir(data_dir)
|
||||
|
||||
# a dir with experiment_name will be created in here and checkpoints are saved here
|
||||
# set as './outputs' for AML to stream to this Run's folder
|
||||
out_dir = '/boto_disk_0/wcs/20190518_feature_scale_1/outputs' # './outputs'
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
# TF events go here. Set it as './logs' if using AML so they can be streamed
|
||||
log_dir = '/boto_disk_0/wcs/20190518_feature_scale_1/logs' # './logs'
|
||||
|
||||
# train/val splits are stored in
|
||||
# on AML, this needs to be relative to the source_directory level
|
||||
splits_file = './constants/splits/full_sr_median_2013_2014_splits.json' # '../training_wcs/scripts/constants/splits/full_sr_median_2013_2014_splits.json'
|
||||
|
||||
with open(splits_file) as f:
|
||||
splits = json.load(f)
|
||||
train_split = splits['train']
|
||||
val_split = splits['val']
|
||||
print(f'Train set has {len(train_split)} tiles; val set has {len(val_split)} tiles.')
|
||||
|
||||
|
||||
# Training ----------------------------------------------------------------------------------------------
|
||||
|
||||
evaluate_only = False # Only evaluate the model on the val set once
|
||||
|
||||
# this is the *total* epoch; if restarting from a checkpoint, be sure to add the additional number of epochs
|
||||
# to fine-tune on top of the original value of this var
|
||||
total_epochs = 1000
|
||||
print_every = 100 # print every how many steps; just the minibatch loss and accuracy
|
||||
assert print_every >= 1, 'print_every needs to be greater than or equal 1'
|
||||
|
||||
starting_checkpoint_path = None
|
||||
|
||||
init_learning_rate = 1e-4
|
||||
|
||||
batch_size = 24
|
||||
|
||||
# probability a chip is kept in the sample while sampling train and val chips at the start of training_wcs
|
||||
# this should be smaller if we now have more training_wcs examples
|
||||
# prob_keep_chip = 0.006
|
||||
# 133 training_wcs tiles * 64 chips per tile = 8512 chips. Should keep every 177 if visualizing 48, which is 0.0056
|
||||
keep_every = 30 # a balance between not covering all training_wcs tiles vs iterating through val tiles too many times
|
||||
|
||||
num_chips_to_viz = 48
|
||||
|
||||
|
||||
# Hardware and framework --------------------------------------------------------------------------------
|
||||
dtype = torch.float32
|
||||
|
||||
|
||||
# Model -------------------------------------------------------------------------------------------------
|
||||
|
||||
num_classes = 34 # empty plus the 33 WCS classes; this is the number of output nodes
|
||||
|
||||
num_in_channels = 5 # 2, 3, 6, 7, NDVI
|
||||
|
||||
# the smallest number of filters is 64 when feature_scale is 1, and it is 32 when feature_scale is 2
|
||||
feature_scale = 1
|
||||
|
||||
is_deconv = True # True to use transpose convolution filters to learn upsampling; otherwise upsampling is not learnt
|
||||
|
||||
is_batchnorm = True
|
||||
|
||||
model = Unet(feature_scale=feature_scale,
|
||||
n_classes=num_classes,
|
||||
in_channels=num_in_channels,
|
||||
is_deconv=is_deconv,
|
||||
is_batchnorm=is_batchnorm)
|
||||
|
||||
|
||||
# Data ---------------------------------------------------------------------------------------------------
|
||||
|
||||
common_classes = [
|
||||
12
|
||||
]
|
||||
less_common_classes = [
|
||||
32, 33
|
||||
]
|
||||
|
||||
weights = []
|
||||
for i in range(num_classes):
|
||||
if i in common_classes:
|
||||
weights.append(1)
|
||||
elif i in less_common_classes:
|
||||
weights.append(2)
|
||||
else:
|
||||
weights.append(10)
|
||||
loss_weights = torch.FloatTensor(weights) # None if no weighting for classes
|
||||
print('Weights on loss per class used:')
|
||||
print(loss_weights)
|
||||
|
||||
# how many subprocesses to use for data loading
|
||||
# None for now - need to modify datasets.py to use
|
||||
data_loader_num_workers = None
|
||||
|
||||
# not available in IterableDataset data_loader_shuffle = True # True to have the data reshuffled at every epoch
|
||||
|
||||
chip_size = 256
|
||||
|
||||
|
||||
# based on min and max values from the sample tile
|
||||
# wcs_orinoquia_sr_median_2013_2014-0000007424-0000007424_-71.347_4.593.tif in training_wcs set
|
||||
# bands 4 and 5 are combined to get the NDVI, so the normalization params for 4 and 5 are
|
||||
# not used during training_wcs data generation, only for visualization (actually not yet used for viz either).
|
||||
bands_normalization_params = {
|
||||
# these are the min and max to clip to for the band
|
||||
'min': {
|
||||
2: 0,
|
||||
3: 0,
|
||||
4: 0,
|
||||
5: 0,
|
||||
6: 0,
|
||||
7: 0
|
||||
},
|
||||
'max': {
|
||||
2: 700,
|
||||
3: 1500,
|
||||
4: 1500,
|
||||
5: 5000,
|
||||
6: 5000,
|
||||
7: 3000
|
||||
},
|
||||
'gamma': { # all the same in this experiment with value 1 which means no effect
|
||||
2: 1.0,
|
||||
3: 1.0,
|
||||
4: 1.0,
|
||||
5: 1.0,
|
||||
6: 1.0,
|
||||
7: 1.0
|
||||
}
|
||||
}
|
||||
|
||||
viz_util = VizUtils()
|
||||
|
||||
def get_chip(tile_reader, chip_window):
|
||||
"""
|
||||
|
||||
Returns:
|
||||
A numpy array of dims (5, H, W)
|
||||
"""
|
||||
normal_bands = [2, 3, 6, 7] # bands to be used without calculating other indices e.g. NDVI
|
||||
bands_to_stack = []
|
||||
for b in normal_bands:
|
||||
# getting one band at a time because min, max and gamma may be different
|
||||
band = viz_util.show_landsat8_tile(
|
||||
tile_reader,
|
||||
bands=[b], # pass in a list to get the batch dimension in the results
|
||||
window=chip_window,
|
||||
band_min=bands_normalization_params['min'][b],
|
||||
band_max=bands_normalization_params['max'][b],
|
||||
gamma=bands_normalization_params['gamma'][b],
|
||||
return_array=True
|
||||
)
|
||||
bands_to_stack.append(band) # band is 2D (h, w), already squeezed, dtype is float 32
|
||||
|
||||
ndvi = viz_util.show_landsat8_ndvi(tile_reader, window=chip_window) # 2D, dtype is float32
|
||||
bands_to_stack.append(ndvi)
|
||||
|
||||
stacked = np.stack(bands_to_stack)
|
||||
|
||||
if stacked.shape != (5, chip_size, chip_size):
|
||||
# default pad constant value is 0
|
||||
stacked = np.pad(stacked,
|
||||
[(0, 0), (0, chip_size - stacked.shape[1]), (0, chip_size - stacked.shape[2])])
|
||||
|
||||
assert stacked.shape == (5, chip_size, chip_size), f'Landsat chip has wrong shape: {stacked.shape}, should be (5, h, w)'
|
||||
|
||||
# prepare the chip for display
|
||||
chip_for_display = viz_util.show_landsat8_tile(tile_reader,
|
||||
window=chip_window,
|
||||
band_max=3000, # what looks good for RGB
|
||||
gamma=0.5,
|
||||
return_array=True)
|
||||
return stacked, chip_for_display
|
||||
|
|
@ -0,0 +1,352 @@
|
|||
"""
|
||||
First experiment using 6 channels (2, 3, 6, 7, NDVI, elevation) with the 13 + 1 coarse categories
|
||||
mapped June 29 2020.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import rasterio
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# ai4eutils needs to be on the PYTHONPATH
|
||||
from geospatial.enums import ExperimentConfigMode
|
||||
from geospatial.visualization.imagery_visualizer import ImageryVisualizer
|
||||
from geospatial.visualization.raster_label_visualizer import RasterLabelVisualizer
|
||||
|
||||
from training_wcs.scripts.models.unet.unet import Unet
|
||||
from training_wcs.scripts.utils.data_transforms import ToTensor, RandomHorizontalFlip, RandomVerticalFlip
|
||||
from training_wcs.scripts.utils.datasets import SingleShardChipsDataset
|
||||
|
||||
experiment_name = 'wcs_coarse_baseline'
|
||||
|
||||
config_mode = ExperimentConfigMode.SCORING
|
||||
|
||||
# I/O -------------------------------------------------------------------------------------------------
|
||||
if config_mode in [ExperimentConfigMode.PREPROCESSING, ExperimentConfigMode.TRAINING]:
|
||||
data_shard_dir = '/boto_disk_0/wcs_data/shards/full_sr_median_2013_2014_elevation'
|
||||
|
||||
# a dir with experiment_name will be created in here and checkpoints are saved here
|
||||
# set as './outputs' for AML to stream to this Run's folder
|
||||
out_dir = f'/boto_disk_0/wcs/{experiment_name}/outputs'
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
# TF events go here. Set it as './logs' if using AML so they can be streamed
|
||||
log_dir = f'/boto_disk_0/wcs/{experiment_name}/logs' # './logs'
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
# for scoring script and make_chip_shards
|
||||
if config_mode in [ExperimentConfigMode.PREPROCESSING, ExperimentConfigMode.SCORING]:
|
||||
prediction_window_size = 128
|
||||
|
||||
label_viz = RasterLabelVisualizer(
|
||||
label_map='/home/boto/wcs/pycharm/constants/class_lists/wcs_coarse_label_map.json')
|
||||
|
||||
data_dir = '/boto_disk_0/wcs_data' # which contains images_srtm
|
||||
|
||||
# Training ----------------------------------------------------------------------------------------------
|
||||
|
||||
# this is the *total* epoch; if restarting from a checkpoint, be sure to add the additional number of epochs
|
||||
# to fine-tune on top of the original value of this var
|
||||
total_epochs = 500
|
||||
print_every = 150 # print every how many steps; just the minibatch loss and accuracy
|
||||
assert print_every >= 1, 'print_every needs to be greater than or equal 1'
|
||||
|
||||
starting_checkpoint_path = None
|
||||
|
||||
init_learning_rate = 5e-5
|
||||
|
||||
batch_size = 28
|
||||
|
||||
# visualizing results on a sample of chips during training_wcs
|
||||
num_chips_to_viz = 48
|
||||
|
||||
# Hardware and framework --------------------------------------------------------------------------------
|
||||
dtype = torch.float32
|
||||
|
||||
# Model -------------------------------------------------------------------------------------------------
|
||||
|
||||
num_classes = 14 # empty plus the 13 *coarse* WCS classes; this is the number of output nodes
|
||||
|
||||
num_in_channels = 6 # 2, 3, 6, 7, NDVI, elevation
|
||||
|
||||
# the smallest number of filters is 64 when feature_scale is 1, and it is 32 when feature_scale is 2
|
||||
feature_scale = 1
|
||||
|
||||
is_deconv = True # True to use transpose convolution filters to learn upsampling; otherwise upsampling is not learnt
|
||||
|
||||
is_batchnorm = True
|
||||
|
||||
model = Unet(feature_scale=feature_scale,
|
||||
n_classes=num_classes,
|
||||
in_channels=num_in_channels,
|
||||
is_deconv=is_deconv,
|
||||
is_batchnorm=is_batchnorm)
|
||||
|
||||
# Data ---------------------------------------------------------------------------------------------------
|
||||
|
||||
common_classes = [
|
||||
]
|
||||
less_common_classes = [
|
||||
]
|
||||
|
||||
weights = []
|
||||
for i in range(num_classes):
|
||||
if i in common_classes:
|
||||
weights.append(0)
|
||||
elif i in less_common_classes:
|
||||
weights.append(0)
|
||||
else:
|
||||
weights.append(1)
|
||||
|
||||
loss_weights = torch.FloatTensor(weights) # None if no weighting for classes
|
||||
print('Weights on loss per class used:')
|
||||
print(loss_weights)
|
||||
|
||||
# how many subprocesses to use for data loading
|
||||
# None for now - need to modify datasets.py to use
|
||||
data_loader_num_workers = None
|
||||
|
||||
# not available in IterableDataset data_loader_shuffle = True # True to have the data reshuffled at every epoch
|
||||
|
||||
chip_size = 256
|
||||
|
||||
# datasets and dataloaders
|
||||
if config_mode == ExperimentConfigMode.PREPROCESSING:
|
||||
dset_train = SingleShardChipsDataset(data_shard_dir, shard_prefix='train', channels=None,
|
||||
transform=transforms.Compose([
|
||||
ToTensor(),
|
||||
RandomHorizontalFlip(), # these operate on Tensors, not PIL images
|
||||
RandomVerticalFlip()
|
||||
]))
|
||||
loader_train = DataLoader(dset_train,
|
||||
batch_size=batch_size,
|
||||
num_workers=4,
|
||||
shuffle=True) # currently num_workers is None
|
||||
|
||||
dset_val = SingleShardChipsDataset(data_shard_dir, shard_prefix='val', channels=None,
|
||||
transform=transforms.Compose([
|
||||
ToTensor(),
|
||||
RandomHorizontalFlip(), # these operate on Tensors, not PIL images
|
||||
RandomVerticalFlip()
|
||||
]))
|
||||
loader_val = DataLoader(dset_val,
|
||||
num_workers=4,
|
||||
batch_size=batch_size)
|
||||
|
||||
# Data shards generation configurations --------------------------------------------------------------------
|
||||
|
||||
# These configurations are copied from training_wcs/experiments/elevation/elevation_2_config.py
|
||||
# They are only used with make_chip_shards.py and infer.py
|
||||
# train.py only use the generated chip shards as numpy arrays
|
||||
if config_mode in [ExperimentConfigMode.PREPROCESSING, ExperimentConfigMode.SCORING]:
|
||||
elevation_path = os.path.join(data_dir, 'images_srtm', 'wcs_orinoquia_srtm.tif')
|
||||
elevation_reader = rasterio.open(elevation_path)
|
||||
|
||||
# based on min and max values from the sample tile
|
||||
# wcs_orinoquia_sr_median_2013_2014-0000007424-0000007424_-71.347_4.593.tif in training_wcs set
|
||||
# bands 4 and 5 are combined to get the NDVI, so the normalization params for 4 and 5 are
|
||||
# not used during training_wcs data generation, only for visualization (actually not yet used for viz either).
|
||||
bands_normalization_params = {
|
||||
# these are the min and max to clip to for the band
|
||||
'min': {
|
||||
2: 0,
|
||||
3: 0,
|
||||
4: 0,
|
||||
5: 0,
|
||||
6: 0,
|
||||
7: 0
|
||||
},
|
||||
'max': {
|
||||
2: 700,
|
||||
3: 1500,
|
||||
4: 1500,
|
||||
5: 5000,
|
||||
6: 5000,
|
||||
7: 3000
|
||||
},
|
||||
'gamma': { # all the same in this experiment with value 1 which means no effect
|
||||
2: 1.0,
|
||||
3: 1.0,
|
||||
4: 1.0,
|
||||
5: 1.0,
|
||||
6: 1.0,
|
||||
7: 1.0
|
||||
}
|
||||
}
|
||||
|
||||
elevation_standardization_params = {
|
||||
# from calculations done in GEE
|
||||
'mean': 399.78,
|
||||
'std_dev': 714.78
|
||||
}
|
||||
|
||||
|
||||
def get_elevation_chip(tile_reader, chip_window):
|
||||
x, y = (tile_reader.bounds.left, tile_reader.bounds.top)
|
||||
|
||||
# getting the pixel array indices corresponding to points in georeferenced space
|
||||
row, col = elevation_reader.index(x, y)
|
||||
|
||||
# tile wcs_orinoquia_sr_median_2013_2014-0000000000-0000007424_-72.425_7.671.tif
|
||||
# top left corner looks up to a negative row index. Clipping to 0 seems to be okay visually
|
||||
row = max(0, row)
|
||||
col = max(0, col)
|
||||
|
||||
# resolution and project are the same for the elevation data and the Landsat imagery
|
||||
row = row + chip_window[1]
|
||||
col = col + chip_window[0] # x is col
|
||||
|
||||
try:
|
||||
w = rasterio.windows.Window.from_slices((row, row + chip_window[3]), (col, col + chip_window[2]))
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
print('chip window:', str(chip_window))
|
||||
print('original row and col: ', str(elevation_reader.index(x, y)))
|
||||
print('row:', row)
|
||||
print('col:', col)
|
||||
print(tile_reader.bounds)
|
||||
print('x:', x)
|
||||
print('y:', y)
|
||||
import sys
|
||||
sys.exit(1)
|
||||
|
||||
chip_elevation = elevation_reader.read(1, window=w) # only 1 band
|
||||
|
||||
# standardize
|
||||
chip_elevation = (chip_elevation - elevation_standardization_params['mean']) / elevation_standardization_params[
|
||||
'std_dev']
|
||||
return chip_elevation
|
||||
|
||||
|
||||
def _pad_chip(band, chip_window):
|
||||
"""
|
||||
|
||||
Args:
|
||||
band: numpy array of dims (h, w)
|
||||
chip_window: (col_off x, row_off y, width, height)
|
||||
|
||||
Returns:
|
||||
band padded to dims (width, height) of the chip_window provided
|
||||
"""
|
||||
width = chip_window[2]
|
||||
height = chip_window[3]
|
||||
# check for smaller than because we use get_chip to get the entire tile during scoring
|
||||
if band.shape[0] < height or band.shape[1] < width:
|
||||
# default pad constant value is 0
|
||||
try:
|
||||
band = np.pad(band,
|
||||
[(0, height - band.shape[0]), (0, width - band.shape[1])])
|
||||
except Exception as e:
|
||||
print(f'coarse_baseline_config, _pad_chip exception: {e}')
|
||||
|
||||
sys.exit(1)
|
||||
return band
|
||||
|
||||
|
||||
normal_bands = [2, 3, 6, 7] # bands to be used without calculating other indices e.g. NDVI
|
||||
|
||||
|
||||
def get_chip(tile_reader, chip_window, chip_for_display=True):
|
||||
"""
|
||||
Get an area (chip) specified by the chip_window. Is not related to chip_size
|
||||
Args:
|
||||
tile_reader: rasterio dataset object of the imagery tile
|
||||
chip_window: (col_off x, row_off y, width, height)
|
||||
chip_for_display: True if also return a chip that looks good
|
||||
Returns:
|
||||
stacked: A numpy array of dims (6, H, W) - note that height and width are switched from chip_window
|
||||
chip_for_display: If chip_for_display is True, also a 3-band array of the RGB channels scaled to
|
||||
look good (R channel not included in stacked)
|
||||
"""
|
||||
|
||||
bands_to_stack = []
|
||||
for b in normal_bands:
|
||||
# getting one band at a time because min, max and gamma may be different
|
||||
band = ImageryVisualizer.show_landsat8_patch(
|
||||
tile_reader,
|
||||
bands=[b], # pass in a list to get the batch dimension in the results
|
||||
window=chip_window,
|
||||
band_min=bands_normalization_params['min'][b],
|
||||
band_max=bands_normalization_params['max'][b],
|
||||
gamma=bands_normalization_params['gamma'][b],
|
||||
return_array=True
|
||||
)
|
||||
|
||||
# deal with incomplete chips
|
||||
band = _pad_chip(band, chip_window)
|
||||
|
||||
bands_to_stack.append(band) # band is 2D (h, w), already squeezed, dtype is float 32
|
||||
|
||||
ndvi = ImageryVisualizer.get_landsat8_ndvi(tile_reader, window=chip_window) # 2D, dtype is float32
|
||||
ndvi = _pad_chip(ndvi, chip_window)
|
||||
bands_to_stack.append(ndvi)
|
||||
|
||||
elevation = get_elevation_chip(tile_reader, chip_window) # scene covers entire region, not tiled, so no gaps
|
||||
elevation = _pad_chip(elevation, chip_window)
|
||||
bands_to_stack.append(elevation)
|
||||
|
||||
try:
|
||||
stacked = np.stack(bands_to_stack)
|
||||
except Exception as e:
|
||||
print(f'Exception in get_chip: {e}')
|
||||
for b in bands_to_stack:
|
||||
print(b.shape)
|
||||
print('')
|
||||
|
||||
assert stacked.shape == (6, chip_window[3], chip_window[2]), \
|
||||
f'Chip has wrong shape: {stacked.shape}, should be (6, h, w)'
|
||||
|
||||
if chip_for_display:
|
||||
# chip for display, getting the RBG bands (default) with a different gamma and band_max that look good
|
||||
chip_for_display = ImageryVisualizer.show_landsat8_patch(tile_reader,
|
||||
window=chip_window,
|
||||
band_max=3000, # what looks good for RGB
|
||||
gamma=0.5,
|
||||
return_array=True)
|
||||
|
||||
return stacked, chip_for_display
|
||||
else:
|
||||
return stacked
|
||||
|
||||
|
||||
def preprocess_tile(tile_array: np.ndarray) -> np.ndarray:
|
||||
"""Same functionality as get_chip(), but applies to a numpy array tile of arbitrary shape.
|
||||
Currently only used with the landcover tool.
|
||||
|
||||
Args:
|
||||
tile_array: A numpy array of dims (height, width, channels). Expect elevation to be the eleventh channel
|
||||
|
||||
Returns:
|
||||
Numpy array representing of the preprocessed chip of dims (6, height, width) - note that channels is
|
||||
in-front.
|
||||
"""
|
||||
bands_to_stack = []
|
||||
for b in normal_bands:
|
||||
# getting one band at a time because min, max and gamma may be different
|
||||
band = ImageryVisualizer.show_landsat8_patch(
|
||||
tile_array,
|
||||
bands=[b], # pass in a list to get the batch dimension in the results
|
||||
band_min=bands_normalization_params['min'][b],
|
||||
band_max=bands_normalization_params['max'][b],
|
||||
gamma=bands_normalization_params['gamma'][b],
|
||||
return_array=True
|
||||
)
|
||||
bands_to_stack.append(band) # band is 2D (h, w), already squeezed, dtype is float 32
|
||||
|
||||
ndvi = ImageryVisualizer.get_landsat8_ndvi(tile_array) # 2D, dtype is float32
|
||||
bands_to_stack.append(ndvi)
|
||||
|
||||
# for the interactive tool, elevation is band 11 (1-indexed) or 10 (0-indexed), and already normalized
|
||||
# by elevation_standardization_params (could have done the normalization here too)
|
||||
elevation = tile_array[:, :, 10]
|
||||
bands_to_stack.append(elevation)
|
||||
|
||||
stacked = np.stack(bands_to_stack)
|
||||
|
||||
assert stacked.shape == (
|
||||
6, tile_array.shape[0], tile_array.shape[1]), f'preprocess_tile, wrong shape: {stacked.shape}'
|
||||
return stacked
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -0,0 +1,105 @@
|
|||
import logging
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from .unet_utils import *
|
||||
|
||||
"""
|
||||
Unet model definition.
|
||||
|
||||
Code mostly taken from https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/models/unet.py
|
||||
"""
|
||||
|
||||
|
||||
class Unet(nn.Module):
|
||||
|
||||
def __init__(self, feature_scale=1,
|
||||
n_classes=3, in_channels=3,
|
||||
is_deconv=True, is_batchnorm=False):
|
||||
"""A U-Net implementation.
|
||||
|
||||
Args:
|
||||
feature_scale: the smallest number of filters (depth c) is 64 when feature_scale is 1,
|
||||
and it is 32 when feature_scale is 2
|
||||
n_classes: number of output classes
|
||||
in_channels: number of channels in input
|
||||
is_deconv:
|
||||
is_batchnorm:
|
||||
"""
|
||||
super(Unet, self).__init__()
|
||||
|
||||
self._num_classes = n_classes
|
||||
|
||||
assert 64 % feature_scale == 0, f'feature_scale {feature_scale} does not work with this UNet'
|
||||
|
||||
filters = [64, 128, 256, 512, 1024] # this is `c` in the diagram, [c, 2c, 4c, 8c, 16c]
|
||||
filters = [int(x / feature_scale) for x in filters]
|
||||
logging.info('filters used are: {}'.format(filters))
|
||||
|
||||
# downsampling
|
||||
self.conv1 = UnetConv2(in_channels, filters[0], is_batchnorm)
|
||||
self.maxpool1 = nn.MaxPool2d(kernel_size=2)
|
||||
|
||||
self.conv2 = UnetConv2(filters[0], filters[1], is_batchnorm)
|
||||
self.maxpool2 = nn.MaxPool2d(kernel_size=2)
|
||||
|
||||
self.conv3 = UnetConv2(filters[1], filters[2], is_batchnorm)
|
||||
self.maxpool3 = nn.MaxPool2d(kernel_size=2)
|
||||
|
||||
self.conv4 = UnetConv2(filters[2], filters[3], is_batchnorm)
|
||||
self.maxpool4 = nn.MaxPool2d(kernel_size=2)
|
||||
|
||||
self.center = UnetConv2(filters[3], filters[4], is_batchnorm)
|
||||
|
||||
# upsampling
|
||||
self.up_concat4 = UnetUp(filters[4], filters[3], is_deconv)
|
||||
self.up_concat3 = UnetUp(filters[3], filters[2], is_deconv)
|
||||
self.up_concat2 = UnetUp(filters[2], filters[1], is_deconv)
|
||||
self.up_concat1 = UnetUp(filters[1], filters[0], is_deconv)
|
||||
|
||||
# final conv (without any concat)
|
||||
self.final = nn.Conv2d(filters[0], self._num_classes, kernel_size=1)
|
||||
self._filters = filters # we need this info for re-training
|
||||
|
||||
def forward(self, inputs, return_features=False):
|
||||
"""If return_features is True, returns tuple (final outputs, last feature map),
|
||||
else returns final outputs only.
|
||||
"""
|
||||
conv1 = self.conv1(inputs)
|
||||
maxpool1 = self.maxpool1(conv1)
|
||||
|
||||
conv2 = self.conv2(maxpool1)
|
||||
maxpool2 = self.maxpool2(conv2)
|
||||
|
||||
conv3 = self.conv3(maxpool2)
|
||||
maxpool3 = self.maxpool3(conv3)
|
||||
|
||||
conv4 = self.conv4(maxpool3)
|
||||
maxpool4 = self.maxpool4(conv4)
|
||||
|
||||
center = self.center(maxpool4)
|
||||
up4 = self.up_concat4(conv4, center)
|
||||
up3 = self.up_concat3(conv3, up4)
|
||||
up2 = self.up_concat2(conv2, up3)
|
||||
up1 = self.up_concat1(conv1, up2)
|
||||
|
||||
final = self.final(up1)
|
||||
|
||||
if return_features:
|
||||
return final, up1
|
||||
else:
|
||||
return final
|
||||
|
||||
def change_num_classes(self, new_num_classes: int):
|
||||
"""Re-initialize the final layer with another number of output classes if different from
|
||||
existing number of classes
|
||||
"""
|
||||
if new_num_classes == self._num_classes:
|
||||
return
|
||||
|
||||
assert new_num_classes > 1, 'Number of classes need to be > 1'
|
||||
self._num_classes = new_num_classes
|
||||
|
||||
self.final = nn.Conv2d(self._filters[0], self._num_classes, kernel_size=1)
|
||||
nn.init.kaiming_uniform_(self.final.weight)
|
||||
self.final.bias.data.zero_()
|
|
@ -0,0 +1,74 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class UnetConv2(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, is_batchnorm):
|
||||
super(UnetConv2, self).__init__()
|
||||
|
||||
if is_batchnorm:
|
||||
self.conv1 = nn.Sequential(
|
||||
# this amount of padding/stride/kernel_size preserves width/height
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.conv2 = nn.Sequential(
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU()
|
||||
)
|
||||
else:
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.conv2 = nn.Sequential(
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
outputs = self.conv1(inputs)
|
||||
outputs = self.conv2(outputs)
|
||||
return outputs
|
||||
|
||||
|
||||
class UnetUp(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, is_deconv):
|
||||
"""
|
||||
|
||||
is_deconv: use transposed conv layer to upsample - parameters are learnt; otherwise use
|
||||
bilinear interpolation to upsample.
|
||||
"""
|
||||
super(UnetUp, self).__init__()
|
||||
|
||||
self.conv = UnetConv2(in_channels, out_channels, False)
|
||||
|
||||
self.is_deconv = is_deconv
|
||||
if is_deconv:
|
||||
self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
|
||||
# nn.UpsamplingBilinear2d is deprecated in favor of F.interpolate()
|
||||
# else:
|
||||
# self.up = nn.UpsamplingBilinear2d(scale_factor=2)
|
||||
|
||||
def forward(self, inputs1, inputs2):
|
||||
"""
|
||||
inputs1 is from the downward path, of higher resolution
|
||||
inputs2 is from the 'lower' layer. It gets upsampled (spatial size increases) and its depth (channels) halves
|
||||
to match the depth of inputs1, before being concatenated in the depth dimension.
|
||||
"""
|
||||
if self.is_deconv:
|
||||
outputs2 = self.up(inputs2)
|
||||
else:
|
||||
# scale_factor is the multiplier for spatial size
|
||||
outputs2 = F.interpolate(inputs2, scale_factor=2, mode='bilinear', align_corners=True)
|
||||
|
||||
offset = outputs2.size()[2] - inputs1.size()[2]
|
||||
padding = 2 * [offset // 2, offset // 2]
|
||||
outputs1 = F.pad(inputs1, padding)
|
||||
|
||||
return self.conv(torch.cat([outputs1, outputs2], dim=1))
|
Загрузка…
Ссылка в новой задаче