Initial commit
|
@ -0,0 +1,3 @@
|
|||
dockerfile
|
||||
gitignore
|
||||
README.md
|
|
@ -0,0 +1,110 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
|
||||
# PyCharm
|
||||
.idea/
|
||||
|
||||
# Mac file system
|
||||
.DS_Store
|
|
@ -0,0 +1,188 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# Building Damage Assessment Model
|
||||
|
||||
**Jump to: [Data sources](#data-sources) | [Setup](#setup) | [Data processing](#data-processing) | [Data splits & augmentation](#data-splits-&-augmentation)| [Overview of the model](#overview-of-the-model) | [Running experiments](#running-experiments) | [Results](#results) |**
|
||||
|
||||
|
||||
Natural disasters affect 350 million people each year. Allocating resources such as shelter, medical aid, and food would relieve people of pain most effectively if the impact of the disaster could be assessed in a short time frame after the disaster. The Netherlands Red Cross (NLRC) founded the [510](https://www.510.global/) initiative in 2016 to turn data into timely information and put it in the hands of aid workers. This study was a Microsoft AI for Humanitarian Action project in collaboration with the NLRC 510 global initiative. In this study, we leverage high-resolution satellite imagery to conduct building footprint segmentation and train a classifier to assign each building's damage severity level via an end-to-end deep learning pipeline. Knowing the damage to individual buildings will enable calculating accurately the number of shelters or most impacted areas by natural disasters required in large-scale disaster incidents such as a hurricane.
|
||||
|
||||
## Data Sources
|
||||
We used [xBD dataset](https://xview2.org/), a publicly available dataset, to train and evaluate our proposed network performance. Detailed information about this dataset is provided in ["xBD: A Dataset for Assessing Building Damage from Satellite Imagery"](https://arxiv.org/abs/1911.09296) by Ritwik Gupta et al.
|
||||
|
||||
## Dependencies
|
||||
|
||||
### Docker
|
||||
|
||||
This code uses Docker to allow portability, the only dependency is docker itself, you can get docker from [here] (https://docs.docker.com/get-docker/).
|
||||
|
||||
After you have installed, you will just need to run the build command from the root of this project (in the example bellow we add a -t tag, you can use whatever tag you want), same level as the dockerfile is:
|
||||
|
||||
```
|
||||
docker build . -t nlrc-building-damage-assessment:latest
|
||||
```
|
||||
|
||||
After the image is build, run the inference code by using below command, passing the parameters with the correspoding values. Use nvidia-docker for using GPUS.
|
||||
|
||||
```
|
||||
docker run --name "nlrc-model" --rm -v /datadrive/nlrc:/mnt nlrc-building-damage-assessment:latest "--output_dir" "/mnt" "--data_img_dir" "/mnt/dataset" "--data_inference_dict" "/mnt/constants/splits/all_disaster_splits_sliced_img_augmented_20.json" "--data_mean_stddev" "/mnt/constants/splits/all_disaster_mean_stddev_tiles_0_1.json" "--label_map_json" "/mnt/constants/class_lists/xBD_label_map.json" "--model" "/mnt/models/model_best.pth.tar"
|
||||
```
|
||||
|
||||
## Setup
|
||||
|
||||
### Creating the conda environment
|
||||
|
||||
At the root directory of this repo, use environment.yml to create a conda virtual environment called `nlrc`:
|
||||
|
||||
```
|
||||
conda env create --file environment.yml
|
||||
```
|
||||
|
||||
If you need additional packages, add them in environment.yml and update the environment:
|
||||
```
|
||||
conda env update --name nlrc --file environment.yml --prune
|
||||
```
|
||||
|
||||
### Installing `ai4eutils`
|
||||
|
||||
We make use of the `geospatial` module in the [ai4eutils](https://github.com/microsoft/ai4eutils) repo for some of the data processing steps, so you may need to clone it and add its path to the `PYTHONPATH`:
|
||||
|
||||
```
|
||||
export PYTHONPATH="${PYTHONPATH}:/path/to/ai4eutils"
|
||||
```
|
||||
|
||||
## Data processing
|
||||
|
||||
### Generate masks from polygons
|
||||
|
||||
We generate pixel masks based on the xBD dataset labels provided as polygons in geoJSON files since the tier3 disasters did not come with masks and the masks for the other disasters had a border value that was likely 0, which would not help to separate the buildings. To do that, we modified the xView baseline repo's [script](https://github.com/DIUx-xView/xView2_baseline/blob/master/utils/mask_polygons.py) for [create_label_masks.py](./data/inspect_masks.ipynb) to generate the masks for entire dataset. Commands that we ran:
|
||||
```
|
||||
python data/create_label_masks.py ./nlrc-damage-assessment/public_datasets/xBD/raw/hold -b 1
|
||||
|
||||
python data/create_label_masks.py ./nlrc-damage-assessment/public_datasets/xBD/raw/test -b 1
|
||||
|
||||
python data/create_label_masks.py ./nlrc-damage-assessment/public_datasets/xBD/raw/train -b 1
|
||||
|
||||
python data/create_label_masks.py ./nlrc-damage-assessment/public_datasets/xBD/raw_tier3 -b 1
|
||||
|
||||
python data/create_label_masks.py ./nlrc-damage-assessment/public_datasets/xBD/raw/hold -b 2
|
||||
|
||||
python data/create_label_masks.py ./nlrc-damage-assessment/public_datasets/xBD/raw/test -b 2
|
||||
|
||||
python data/create_label_masks.py ./nlrc-damage-assessment/public_datasets/xBD/raw/train -b 2
|
||||
|
||||
python data/create_label_masks.py ./nlrc-damage-assessment/public_datasets/xBD/raw_tier3 -b 2
|
||||
|
||||
```
|
||||
Masks for border widths of 1 and 2 were created in case we would like to experiment with both cases. We used a border of 2 but you can see their effects in the notebook [inspect_masks.ipynb](./data/inspect_masks.ipynb) and choose either case that you would prefer.
|
||||
|
||||
|
||||
### Generate smaller image patches
|
||||
To create smaller patches from xBD original tiles you can use this code: [make_smaller_tiles.py](./data/make_smaller_tiles.py). In the experiments, we cropped 1024x1024 images into 256x256 patches. In this document, we refer to each original 1024x1024 xBD image as a "tile" and any image of smaller size cropped from the original tile is referred to as a "patch".
|
||||
|
||||
```
|
||||
python data/make_smaller_tiles.py
|
||||
```
|
||||
|
||||
### Generate npy files
|
||||
Given the size of the dataset, loading the images one by one from the blob storage is very time-consuming. To resolve this low-speed issue, we loaded the entire dataset once and saved them as one npy file. Then for the experimnets, we loaded the data upfront, which helped with speed-up significantly.
|
||||
|
||||
To do that you can use this code: [make_data_shards.py](./data/make_data_shards.py).
|
||||
We saved the entire xBD dataset into one shard but if more shards are required due to the large size of the data, the same code can be used for that purpose as well.
|
||||
|
||||
```
|
||||
python ./data/make_data_shards.py
|
||||
```
|
||||
|
||||
### Data normalization
|
||||
Normalization for each patch is conducted based on its corresponding tile's mean and standard deviation. To compute mean & standard deviation for each tile, [data/compute_mean_stddev.ipynb](./data/compute_mean_stddev.ipynb) in the `data` dir of this repo can be used a reference.
|
||||
sample of a file that contains tile-based mean & standard deviation for xBD dataset: [constants/splits/all_disaster_mean_stddev_tiles_0_1.json](./constants/splits/all_disaster_mean_stddev_tiles_0_1.json)
|
||||
|
||||
### Constants
|
||||
|
||||
The xBD output classes and their numerical class used in the label polygon/masks are documented on the xView2_baseline [repo](https://github.com/DIUx-xView/xView2_baseline/tree/821d9f8b9201ee7952aef13b073c9fd38ce11d4b#output):
|
||||
|
||||
```
|
||||
0 for no building
|
||||
1 for building found and classified no-damaged
|
||||
2 for building found and classified minor-damage
|
||||
3 for building found and classified major-damage
|
||||
4 for building found and classified destroyed
|
||||
```
|
||||
|
||||
We have an extra class, `5`, to denote building features that were "unclassified" in xBD. They need to be discounted during training and evaluation.
|
||||
|
||||
|
||||
[constants/class_lists](./constants/class_lists/) contains files with information about mapping between actual damage category, class labels and color codes used for predicted classes vizulaization.
|
||||
![Damage class legend](./constants/class_lists/xBD_damage_class_legend.png)
|
||||
|
||||
## Data splits & augmentation
|
||||
|
||||
To make the split, [data/class_distribution_and_splits.ipynb](./data/class_distribution_and_splits.ipynb) in the `data` dir of this repo can be used a reference.
|
||||
|
||||
We do not use the train/test/hold splits that xBD used during the xView2 competition. We retain the folder structure of the datasets (which has train/test/hold as folder names). This JSON is a dictionary where the key is the name of the disaster, and the value is another dict with keys `train`, `val` and `test`, each pointing to a list of file paths, starting from the xBD root directory in the data storage container.
|
||||
|
||||
Files with information about dataset splits and their corresponding paths along with the mean and standard deviation for each image tile used for normalization are placed in [constants/splits](./constants/splits/) in this repo. All the splits used in our training are randomized based on the xBD 1024x1024 tiles before being cropped into smaller patches.
|
||||
|
||||
Sample of a file that contains xBD 1024x1024 image paths for train/val/test sets split at ratio 80:10:10 can be found via [constants/splits/all_disaster_splits.json](./constants/splits/all_disaster_splits.json)
|
||||
|
||||
Sample of a file that contains xBD 256x256 image paths for train/val/test sets split at ratio 80:10:10 where each 1024x1024 xBD tile has been cropped into 20 patches (16 non-overlapping and 4 overlapping patches) can be found via [constants/splits/all_disaster_splits_sliced_img_augmented_20.json](./constants/splits/all_disaster_splits_sliced_img_augmented_20.json)
|
||||
|
||||
Sample of a file that contains xBD 1024x1024 image paths for train/val/test sets split at ratio 90:10:0 can be found via [constants/splits/final_mdl_all_disaster_splits.json](./constants/splits/final_mdl_all_disaster_splits.json)
|
||||
|
||||
Sample of a file that contains xBD 256x256 image paths for for train/val/test sets split at ratio 90:10:0 where each 1024x1024 xBD tile has been cropped into 20 patches (16 non-overlapping and 4 overlapping): [constants/splits/final_mdl_all_disaster_splits_sliced_img_augmented_20.json](./constants/splits/final_mdl_all_disaster_splits_sliced_img_augmented_20.json)
|
||||
|
||||
Note: In our sample split jsons, '/labels/' is part of the string for each image paths. Please note that this component of the string is being replaced with appropriate strings to reflect the correct folder structure in xBD dataset when loading pre- and post-disaster images and their labels masks during training and inference.
|
||||
|
||||
As explined above, for our experiments, we used the augmented splits where each 1024x1024 xBD tile has been cropped into 20 patches (16 non-overlapping and 4 overlapping patches). Moreover, during the training, we conduct random vertical and horizontal flipping on-the-fly as implementd in [train/utils/dataset_shard_load.py](./train/utils/dataset_shard_load.py).
|
||||
|
||||
|
||||
## Overview of the model
|
||||
|
||||
Our proposed approach shares some characteristics with ["An Attention-Based System for Damage Assessment Using Satellite Imagery"](https://arxiv.org/pdf/2004.06643v1.pdf) by Hanxiang Hao et al. However, we do not incorporate any attention mechanism in the network and we use a fewer number of convolutional layers for the segmentation arm, which is a UNet approach. Details of our architecture are shown below:
|
||||
|
||||
![Network Architecture Schema](./images/model.PNG)
|
||||
<!--
|
||||
<p align="center">
|
||||
<img src="" width="800"/>
|
||||
</p> -->
|
||||
|
||||
Our proposed model, implemented in this repo, is an end-to-end model, which provides building masks on a pair of pre- & post-disaster satellite imagery along with the level of damage for each detected building, in case of natural disasters.
|
||||
|
||||
## Running experiments
|
||||
|
||||
### Training
|
||||
The model needs to be trained sequentially for building detection and damage classification tasks. The loss function has three components for penalizing mistakes on three different predicted outputs of the network that include: (I) building detection on pre-disaster imagery, (II) building detection on post-disaster imagery, and (III) the damage classification output. We use a Siamese approach where for segmentation tasks (I) & (II), the UNet network parameters are shared (shown in the upper and lower arm of the network in the Network Architecture Schema). UNet embeddings of pre- and post-disaster imagery, generated from the encoder part of the model, are differenced and run through several convolutional layers to give the final pixel-wise classification of the damage level for predicted buildings via the middle arm shown in Network Architecture Schema.
|
||||
|
||||
To train the network for building detection task, we first avoid penalizing the network for mistakes in the damage classification task by setting the weight for that task equal to 0 (e.g., in config dictionary, set `'weights_loss': [0.5, 0.5, 0] & 'mode': 'bld'`). In this model, if a checkpoint path is provided, the model resumes training based on that, otherwise, it starts training from scratch.
|
||||
Once model training progresses and reasonable results are achieved for building detection, the training can be stopped and the damage classification task can get started via freezing the parameters of the building segmentation task (e.g., in config dictionary, set `'weights_loss': [0, 0, 1] & 'mode': 'dmg'`). The checkpoint for the best epoch tuned for building segmentation tasks should be provided in the config dictionary to allow proper damage classification results. Please review the config dictionary in [train/train.py](./train/train.py) before running the code.
|
||||
|
||||
```
|
||||
python train/train.py
|
||||
```
|
||||
|
||||
The experiment's progress can be monitored via tensorboard.
|
||||
|
||||
```
|
||||
tensorboard --host 0.0.0.0 --logdir ./outputs/experiment_name/logs/ --port 8009
|
||||
```
|
||||
|
||||
### Inference
|
||||
The trained model can be used for inference via the following command. The best trained model file is presented under [models](./models/model_best.pth.tar) directory. Please review the paramters needed in [inference/inference.py](./inference/inference.py) before running the code.
|
||||
|
||||
```
|
||||
python inference/inference.py --output_dir outputs/ --data_img_dir xBD_sliced_augmented_20_alldisasters/ --data_inference_dict constants/splits/all_disaster_splits_sliced_img_augmented_20.json --data_mean_stddev constants/splits/all_disaster_mean_stddev_tiles_0_1.json --label_map_json constants/class_lists/xBD_label_map.json --model models/model_best.pth.tar
|
||||
```
|
||||
Samples of files that contain input images paths are shown in [constants/splits/final_mdl_all_disaster_splits_sliced_img_augmented_20.json](./constants/splits/final_mdl_all_disaster_splits_sliced_img_augmented_20.json) and [constants/splits/all_disaster_splits_sliced_img_augmented_20.json](./constants/splits/all_disaster_splits_sliced_img_augmented_20.json).
|
||||
This input file can contain one or multiple image paths for inference.
|
||||
|
||||
|
||||
### Evaluation
|
||||
During the development phase, our main evaluation metric was pixel based evaluation for both tasks, i.e., building segmentation and damage classification. However, building-level evaluation utilities are also provided in [eval](./eval) folder for your reference. We have incorporated the building-level evaluation metric into the inference code as well. The evaluation results are saved as CSV files. However, please note that the building predicted polygons might be connected in some patches and calculated true positive numbers are extremenly underestimated. Thus, we believe that pixel-level evaluation is a better metric that reflect more fairly on the performace of the model in damaged area detection.
|
||||
|
||||
## Results
|
||||
We show the results on validation and test sets of our splits along with some segmenation maps with damage level.
|
||||
![Pixel-level evaluation results on several different splits](./images/results_3.PNG)
|
||||
![Validation sample I visualzization results](./images/results_2.PNG)
|
||||
![Validation sample II visualzization results](./images/results_1.PNG)
|
После Ширина: | Высота: | Размер: 12 KiB |
|
@ -0,0 +1,25 @@
|
|||
{
|
||||
"num_to_name": {
|
||||
"0": "Background",
|
||||
"1": "No damage",
|
||||
"2": "Minor damage",
|
||||
"3": "Major damage",
|
||||
"4": "Destroyed",
|
||||
"5": "Unclassified"
|
||||
},
|
||||
"label_name_to_num": {
|
||||
"no-damage": 1,
|
||||
"minor-damage": 2,
|
||||
"major-damage": 3,
|
||||
"destroyed": 4,
|
||||
"un-classified": 5
|
||||
},
|
||||
"num_to_color": {
|
||||
"0": "black",
|
||||
"1": "limegreen",
|
||||
"2": "orange",
|
||||
"3": "mediumslateblue",
|
||||
"4": "mediumvioletred",
|
||||
"5": "lightgray"
|
||||
}
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# Data processing
|
||||
|
||||
## Generate masks from polygons
|
||||
|
||||
We generate pixel masks based on the xBD dataset labels provided as polygons in geoJSON files, since the tier3 disasters did not come with masks and the masks for the other disasters had a border value that was likely 0, which would not help to separate the buildings.
|
||||
|
||||
We modified the xView baseline repo's [script](https://github.com/DIUx-xView/xView2_baseline/blob/master/utils/mask_polygons.py) for `create_label_masks.py` to generate the masks for all wind disasters. Running the script only took < 10 minutes. Commands that we ran:
|
||||
```
|
||||
python data/create_label_masks.py /home/lynx/mnt/nlrc-damage-assessment/public_datasets/xBD/raw/hold -b 1
|
||||
|
||||
python data/create_label_masks.py /home/lynx/mnt/nlrc-damage-assessment/public_datasets/xBD/raw/test -b 1
|
||||
|
||||
python data/create_label_masks.py /home/lynx/mnt/nlrc-damage-assessment/public_datasets/xBD/raw/train -b 1
|
||||
|
||||
python data/create_label_masks.py /home/lynx/mnt/nlrc-damage-assessment/public_datasets/xBD/raw_tier3 -b 1
|
||||
|
||||
python data/create_label_masks.py /home/lynx/mnt/nlrc-damage-assessment/public_datasets/xBD/raw/hold -b 2
|
||||
|
||||
python data/create_label_masks.py /home/lynx/mnt/nlrc-damage-assessment/public_datasets/xBD/raw/test -b 2
|
||||
|
||||
python data/create_label_masks.py /home/lynx/mnt/nlrc-damage-assessment/public_datasets/xBD/raw/train -b 2
|
||||
|
||||
python data/create_label_masks.py /home/lynx/mnt/nlrc-damage-assessment/public_datasets/xBD/raw_tier3 -b 2
|
||||
|
||||
```
|
||||
|
||||
Masks for border width of 1 and 2 were created in case we would like to experiment. You can see their effects in the notebook [inspect_masks.ipynb](./inspect_masks.ipynb).
|
|
@ -0,0 +1,388 @@
|
|||
{
|
||||
"pre_areas": {
|
||||
"guatemala-volcano": {
|
||||
"no-subtype": 555083.5836875269
|
||||
},
|
||||
"hurricane-florence": {
|
||||
"no-subtype": 11596718.672202773
|
||||
},
|
||||
"hurricane-harvey": {
|
||||
"no-subtype": 64235072.75402863
|
||||
},
|
||||
"hurricane-matthew": {
|
||||
"no-subtype": 9551591.73552839
|
||||
},
|
||||
"hurricane-michael": {
|
||||
"no-subtype": 43913489.5134971
|
||||
},
|
||||
"mexico-earthquake": {
|
||||
"no-subtype": 52658622.98346215
|
||||
},
|
||||
"midwest-flooding": {
|
||||
"no-subtype": 13892135.398812711
|
||||
},
|
||||
"palu-tsunami": {
|
||||
"no-subtype": 33203878.023979597
|
||||
},
|
||||
"santa-rosa-wildfire": {
|
||||
"no-subtype": 25924688.820611726
|
||||
},
|
||||
"socal-fire": {
|
||||
"no-subtype": 23049244.366123896
|
||||
},
|
||||
"joplin-tornado": {
|
||||
"no-subtype": 16040713.727818588
|
||||
},
|
||||
"lower-puna-volcano": {
|
||||
"no-subtype": 1674418.6473476125
|
||||
},
|
||||
"moore-tornado": {
|
||||
"no-subtype": 24100625.420579486
|
||||
},
|
||||
"nepal-flooding": {
|
||||
"no-subtype": 18686861.90131992
|
||||
},
|
||||
"pinery-bushfire": {
|
||||
"no-subtype": 3616593.3513660347
|
||||
},
|
||||
"portugal-wildfire": {
|
||||
"no-subtype": 14850977.596319549
|
||||
},
|
||||
"sunda-tsunami": {
|
||||
"no-subtype": 8411636.819834696
|
||||
},
|
||||
"tuscaloosa-tornado": {
|
||||
"no-subtype": 16399750.244178453
|
||||
},
|
||||
"woolsey-fire": {
|
||||
"no-subtype": 5507821.958937849
|
||||
}
|
||||
},
|
||||
"post_areas": {
|
||||
"guatemala-volcano": {
|
||||
"no-damage": 424233.6886453207,
|
||||
"minor-damage": 29976.136637258915,
|
||||
"destroyed": 13941.99039685283,
|
||||
"major-damage": 24967.16713784182,
|
||||
"un-classified": 61147.989949917406
|
||||
},
|
||||
"hurricane-florence": {
|
||||
"no-damage": 8602392.32303186,
|
||||
"minor-damage": 576109.6347002991,
|
||||
"un-classified": 626750.2712243387,
|
||||
"major-damage": 1699878.4674861634,
|
||||
"destroyed": 68945.8519562309
|
||||
},
|
||||
"hurricane-harvey": {
|
||||
"major-damage": 19476393.610768512,
|
||||
"minor-damage": 7644750.508543043,
|
||||
"no-damage": 35160357.23794543,
|
||||
"un-classified": 1108755.1759903177,
|
||||
"destroyed": 689618.9005291539
|
||||
},
|
||||
"hurricane-matthew": {
|
||||
"minor-damage": 5636593.268183401,
|
||||
"no-damage": 1638989.3645763558,
|
||||
"major-damage": 1236911.176751667,
|
||||
"destroyed": 751569.3877616163,
|
||||
"un-classified": 269248.90720901743
|
||||
},
|
||||
"hurricane-michael": {
|
||||
"no-damage": 25807328.226440027,
|
||||
"minor-damage": 10552598.155848961,
|
||||
"un-classified": 177777.65172964666,
|
||||
"major-damage": 6311634.324490453,
|
||||
"destroyed": 987272.8873742732
|
||||
},
|
||||
"mexico-earthquake": {
|
||||
"no-damage": 52106590.83687903,
|
||||
"minor-damage": 293567.94942568743,
|
||||
"un-classified": 71036.76478397334,
|
||||
"major-damage": 80656.76917687584,
|
||||
"destroyed": 5357.650044842722
|
||||
},
|
||||
"midwest-flooding": {
|
||||
"no-damage": 13088260.501154136,
|
||||
"un-classified": 188512.98368551358,
|
||||
"destroyed": 128263.53270132157,
|
||||
"major-damage": 215626.0213974825,
|
||||
"minor-damage": 245182.62643771025
|
||||
},
|
||||
"palu-tsunami": {
|
||||
"no-damage": 28429408.79366035,
|
||||
"destroyed": 3327823.8268332803,
|
||||
"major-damage": 1111737.3222539243,
|
||||
"un-classified": 266830.25486142916,
|
||||
"minor-damage": 831.4207079555346
|
||||
},
|
||||
"santa-rosa-wildfire": {
|
||||
"no-damage": 18173724.881392688,
|
||||
"destroyed": 7355765.6933949385,
|
||||
"un-classified": 21348.4592846757,
|
||||
"major-damage": 169763.6860384565,
|
||||
"minor-damage": 157051.30029256907
|
||||
},
|
||||
"socal-fire": {
|
||||
"no-damage": 20839832.579554196,
|
||||
"un-classified": 238065.398824969,
|
||||
"destroyed": 1707512.377862279,
|
||||
"minor-damage": 95134.90757587964,
|
||||
"major-damage": 117743.7346977082
|
||||
},
|
||||
"joplin-tornado": {
|
||||
"destroyed": 2858702.8832999812,
|
||||
"minor-damage": 2451123.516412556,
|
||||
"major-damage": 1490180.1164239356,
|
||||
"un-classified": 305713.5343858696,
|
||||
"no-damage": 8902980.181803767
|
||||
},
|
||||
"lower-puna-volcano": {
|
||||
"un-classified": 92873.13689085834,
|
||||
"no-damage": 1278018.0225500248,
|
||||
"minor-damage": 28325.978727244514,
|
||||
"destroyed": 258886.20300476396,
|
||||
"major-damage": 12627.194320518138
|
||||
},
|
||||
"moore-tornado": {
|
||||
"no-damage": 20858990.4921232,
|
||||
"un-classified": 162590.68324451018,
|
||||
"destroyed": 1321422.0560918555,
|
||||
"minor-damage": 1123501.9486039935,
|
||||
"major-damage": 583883.1525602402
|
||||
},
|
||||
"nepal-flooding": {
|
||||
"no-damage": 13396893.557653563,
|
||||
"un-classified": 494409.1431600371,
|
||||
"minor-damage": 2526383.940554102,
|
||||
"major-damage": 2167094.243345501,
|
||||
"destroyed": 82101.82189437859
|
||||
},
|
||||
"pinery-bushfire": {
|
||||
"un-classified": 68172.28891898633,
|
||||
"no-damage": 3330864.1249928013,
|
||||
"major-damage": 56681.20608206688,
|
||||
"destroyed": 110390.71678650305,
|
||||
"minor-damage": 43455.85213152673
|
||||
},
|
||||
"portugal-wildfire": {
|
||||
"no-damage": 13827027.478102874,
|
||||
"un-classified": 253977.52260864992,
|
||||
"destroyed": 467906.2358453642,
|
||||
"minor-damage": 81083.11611925096,
|
||||
"major-damage": 189985.77147544394
|
||||
},
|
||||
"sunda-tsunami": {
|
||||
"no-damage": 6957184.433223553,
|
||||
"un-classified": 1339309.6059006073,
|
||||
"destroyed": 54783.33837448994,
|
||||
"major-damage": 46925.12787137661
|
||||
},
|
||||
"tuscaloosa-tornado": {
|
||||
"no-damage": 12509748.859454736,
|
||||
"un-classified": 377831.00611432316,
|
||||
"minor-damage": 1915863.4168710771,
|
||||
"destroyed": 877805.9944712821,
|
||||
"major-damage": 690388.5966991772
|
||||
},
|
||||
"woolsey-fire": {
|
||||
"destroyed": 1230804.3924022317,
|
||||
"minor-damage": 81576.74220822704,
|
||||
"no-damage": 4069087.176166785,
|
||||
"major-damage": 81166.70944708213,
|
||||
"un-classified": 33949.60325156282
|
||||
}
|
||||
},
|
||||
"pre_counts": {
|
||||
"guatemala-volcano": {
|
||||
"no-subtype": 991
|
||||
},
|
||||
"hurricane-florence": {
|
||||
"no-subtype": 11548
|
||||
},
|
||||
"hurricane-harvey": {
|
||||
"no-subtype": 37955
|
||||
},
|
||||
"hurricane-matthew": {
|
||||
"no-subtype": 23964
|
||||
},
|
||||
"hurricane-michael": {
|
||||
"no-subtype": 35501
|
||||
},
|
||||
"mexico-earthquake": {
|
||||
"no-subtype": 51473
|
||||
},
|
||||
"midwest-flooding": {
|
||||
"no-subtype": 13896
|
||||
},
|
||||
"palu-tsunami": {
|
||||
"no-subtype": 55789
|
||||
},
|
||||
"santa-rosa-wildfire": {
|
||||
"no-subtype": 21955
|
||||
},
|
||||
"socal-fire": {
|
||||
"no-subtype": 18969
|
||||
},
|
||||
"joplin-tornado": {
|
||||
"no-subtype": 15352
|
||||
},
|
||||
"lower-puna-volcano": {
|
||||
"no-subtype": 3410
|
||||
},
|
||||
"moore-tornado": {
|
||||
"no-subtype": 22958
|
||||
},
|
||||
"nepal-flooding": {
|
||||
"no-subtype": 43265
|
||||
},
|
||||
"pinery-bushfire": {
|
||||
"no-subtype": 5961
|
||||
},
|
||||
"portugal-wildfire": {
|
||||
"no-subtype": 23413
|
||||
},
|
||||
"sunda-tsunami": {
|
||||
"no-subtype": 16947
|
||||
},
|
||||
"tuscaloosa-tornado": {
|
||||
"no-subtype": 15006
|
||||
},
|
||||
"woolsey-fire": {
|
||||
"no-subtype": 7015
|
||||
}
|
||||
},
|
||||
"post_counts": {
|
||||
"guatemala-volcano": {
|
||||
"no-damage": 731,
|
||||
"minor-damage": 26,
|
||||
"destroyed": 33,
|
||||
"major-damage": 23,
|
||||
"un-classified": 178
|
||||
},
|
||||
"hurricane-florence": {
|
||||
"no-damage": 8466,
|
||||
"minor-damage": 232,
|
||||
"un-classified": 820,
|
||||
"major-damage": 1949,
|
||||
"destroyed": 81
|
||||
},
|
||||
"hurricane-harvey": {
|
||||
"major-damage": 13378,
|
||||
"minor-damage": 4510,
|
||||
"no-damage": 18638,
|
||||
"un-classified": 581,
|
||||
"destroyed": 848
|
||||
},
|
||||
"hurricane-matthew": {
|
||||
"minor-damage": 12331,
|
||||
"no-damage": 4058,
|
||||
"major-damage": 2717,
|
||||
"destroyed": 3524,
|
||||
"un-classified": 1334
|
||||
},
|
||||
"hurricane-michael": {
|
||||
"no-damage": 22692,
|
||||
"minor-damage": 8292,
|
||||
"un-classified": 373,
|
||||
"major-damage": 2919,
|
||||
"destroyed": 1225
|
||||
},
|
||||
"mexico-earthquake": {
|
||||
"no-damage": 51084,
|
||||
"minor-damage": 221,
|
||||
"un-classified": 111,
|
||||
"major-damage": 54,
|
||||
"destroyed": 3
|
||||
},
|
||||
"midwest-flooding": {
|
||||
"no-damage": 12819,
|
||||
"un-classified": 473,
|
||||
"destroyed": 165,
|
||||
"major-damage": 193,
|
||||
"minor-damage": 246
|
||||
},
|
||||
"palu-tsunami": {
|
||||
"no-damage": 46796,
|
||||
"destroyed": 7203,
|
||||
"major-damage": 1178,
|
||||
"un-classified": 611,
|
||||
"minor-damage": 1
|
||||
},
|
||||
"santa-rosa-wildfire": {
|
||||
"no-damage": 15843,
|
||||
"destroyed": 5810,
|
||||
"un-classified": 86,
|
||||
"major-damage": 95,
|
||||
"minor-damage": 121
|
||||
},
|
||||
"socal-fire": {
|
||||
"no-damage": 15697,
|
||||
"un-classified": 693,
|
||||
"destroyed": 2333,
|
||||
"minor-damage": 136,
|
||||
"major-damage": 110
|
||||
},
|
||||
"joplin-tornado": {
|
||||
"destroyed": 3274,
|
||||
"minor-damage": 2192,
|
||||
"major-damage": 1005,
|
||||
"un-classified": 656,
|
||||
"no-damage": 8225
|
||||
},
|
||||
"lower-puna-volcano": {
|
||||
"un-classified": 554,
|
||||
"no-damage": 2277,
|
||||
"minor-damage": 49,
|
||||
"destroyed": 504,
|
||||
"major-damage": 26
|
||||
},
|
||||
"moore-tornado": {
|
||||
"no-damage": 19453,
|
||||
"un-classified": 586,
|
||||
"destroyed": 1584,
|
||||
"minor-damage": 886,
|
||||
"major-damage": 449
|
||||
},
|
||||
"nepal-flooding": {
|
||||
"no-damage": 31225,
|
||||
"un-classified": 1683,
|
||||
"minor-damage": 5134,
|
||||
"major-damage": 4721,
|
||||
"destroyed": 502
|
||||
},
|
||||
"pinery-bushfire": {
|
||||
"un-classified": 524,
|
||||
"no-damage": 5027,
|
||||
"major-damage": 99,
|
||||
"destroyed": 229,
|
||||
"minor-damage": 82
|
||||
},
|
||||
"portugal-wildfire": {
|
||||
"no-damage": 20787,
|
||||
"un-classified": 1064,
|
||||
"destroyed": 1090,
|
||||
"minor-damage": 176,
|
||||
"major-damage": 296
|
||||
},
|
||||
"sunda-tsunami": {
|
||||
"no-damage": 14078,
|
||||
"un-classified": 2590,
|
||||
"destroyed": 179,
|
||||
"major-damage": 100
|
||||
},
|
||||
"tuscaloosa-tornado": {
|
||||
"no-damage": 10499,
|
||||
"un-classified": 908,
|
||||
"minor-damage": 2036,
|
||||
"destroyed": 1097,
|
||||
"major-damage": 466
|
||||
},
|
||||
"woolsey-fire": {
|
||||
"destroyed": 1876,
|
||||
"minor-damage": 189,
|
||||
"no-damage": 4638,
|
||||
"major-damage": 126,
|
||||
"un-classified": 186
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,234 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
create_label_mask.py
|
||||
|
||||
For each label json file, flat in one directory, outputs a 2D raster of labels.
|
||||
Have to run this for different root directories containing `labels` and `images` folders.
|
||||
|
||||
Manually fill out the disaster name (prefix to file names) in DISASTERS_OF_INTEREST at the top of the script.
|
||||
Masks will be generated for these disasters only.
|
||||
|
||||
Sample invocation:
|
||||
```
|
||||
python data/create_label_masks.py /home/lynx/data -b 2
|
||||
```
|
||||
|
||||
This script borrows code and functions from
|
||||
https://github.com/DIUx-xView/xView2_baseline/blob/master/utils/mask_polygons.py
|
||||
Below is their copyright statement:
|
||||
"""
|
||||
#####################################################################################################################################################################
|
||||
# xView2 #
|
||||
# Copyright 2019 Carnegie Mellon University. #
|
||||
# NO WARRANTY. THIS CARNEGIE MELLON UNIVERSITY AND SOFTWARE ENGINEERING INSTITUTE MATERIAL IS FURNISHED ON AN "AS-IS" BASIS. CARNEGIE MELLON UNIVERSITY MAKES NO #
|
||||
# WARRANTIES OF ANY KIND, EITHER EXPRESSED OR IMPLIED, AS TO ANY MATTER INCLUDING, BUT NOT LIMITED TO, WARRANTY OF FITNESS FOR PURPOSE OR MERCHANTABILITY, #
|
||||
# EXCLUSIVITY, OR RESULTS OBTAINED FROM USE OF THE MATERIAL. CARNEGIE MELLON UNIVERSITY DOES NOT MAKE ANY WARRANTY OF ANY KIND WITH RESPECT TO FREEDOM FROM PATENT, #
|
||||
# TRADEMARK, OR COPYRIGHT INFRINGEMENT. #
|
||||
# Released under a MIT (SEI)-style license, please see LICENSE.md or contact permission@sei.cmu.edu for full terms. #
|
||||
# [DISTRIBUTION STATEMENT A] This material has been approved for public release and unlimited distribution. Please see Copyright notice for non-US Government use #
|
||||
# and distribution. #
|
||||
# This Software includes and/or makes use of the following Third-Party Software subject to its own license: #
|
||||
# 1. SpaceNet (https://github.com/motokimura/spacenet_building_detection/blob/master/LICENSE) Copyright 2017 Motoki Kimura. #
|
||||
# DM19-0988 #
|
||||
#####################################################################################################################################################################
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
# documentation for cv2 fillPoly https://docs.opencv.org/master/d6/d6e/group__imgproc__draw.html#ga8c69b68fab5f25e2223b6496aa60dad5
|
||||
from cv2 import fillPoly, imwrite
|
||||
import numpy as np
|
||||
from shapely import wkt
|
||||
from shapely.geometry import mapping, Polygon
|
||||
from skimage.io import imread
|
||||
from tqdm import tqdm
|
||||
|
||||
# keep as a tuple, not a list
|
||||
# add a _ at the end of the disaster-name so they are prefix-free
|
||||
|
||||
DISASTERS_OF_INTEREST = ('guatemala-volcano_', 'hurricane-florence_', 'hurricane-harvey_', 'mexico-earthquake_', 'midwest-flooding_', 'palu-tsunami_', 'santa-rosa-wildfire_', 'socal-fire_', 'lower-puna-volcano_', 'nepal-flooding_', 'pinery-bushfire_', 'portugal-wildfire_', 'sunda-tsunami_', 'woolsey-fire_')
|
||||
|
||||
# running from repo root
|
||||
with open('constants/class_lists/xBD_label_map.json') as label_map_file:
|
||||
LABEL_NAME_TO_NUM = json.load(label_map_file)['label_name_to_num']
|
||||
|
||||
|
||||
def get_dimensions(file_path):
|
||||
""" Returns (width, height, channels) of the image at file_path
|
||||
"""
|
||||
pil_img = imread(file_path)
|
||||
img = np.array(pil_img)
|
||||
w, h, c = img.shape
|
||||
return (w, h, c)
|
||||
|
||||
|
||||
def read_json(json_path):
|
||||
with open(json_path) as f:
|
||||
j = json.load(f)
|
||||
return j
|
||||
|
||||
|
||||
def get_feature_info(feature):
|
||||
"""Reading coordinate and category information from the label json file
|
||||
Args:
|
||||
feature: a python dictionary of json labels
|
||||
Returns a dict mapping the uid of the polygon to a tuple
|
||||
(numpy array of coords, numerical category of the building)
|
||||
"""
|
||||
props = {}
|
||||
|
||||
for feat in feature['features']['xy']:
|
||||
# read the coordinates
|
||||
feat_shape = wkt.loads(feat['wkt'])
|
||||
coords = list(mapping(feat_shape)['coordinates'][0]) # a new, independent geometry with coordinates copied
|
||||
|
||||
# determine the damage type
|
||||
if 'subtype' in feat['properties']:
|
||||
damage_class = feat['properties']['subtype']
|
||||
else:
|
||||
damage_class = 'no-damage' # usually for pre images - assign them to the no-damage class
|
||||
|
||||
damage_class_num = LABEL_NAME_TO_NUM[damage_class] # get the numerical label
|
||||
|
||||
# maps to (numpy array of coords, numerical category of the building)
|
||||
props[feat['properties']['uid']] = (np.array(coords, np.int32), damage_class_num)
|
||||
return props
|
||||
|
||||
|
||||
def mask_polygons_together_with_border(size, polys, border):
|
||||
"""
|
||||
|
||||
Args:
|
||||
size: A tuple of (width, height, channels)
|
||||
polys: A dict of feature uid: (numpy array of coords, numerical category of the building), from
|
||||
get_feature_info()
|
||||
border: Pixel width to shrink each shape by to create some space between adjacent shapes
|
||||
|
||||
Returns:
|
||||
a dict of masked polygons with the shapes filled in from cv2.fillPoly
|
||||
"""
|
||||
|
||||
# For each WKT polygon, read the WKT format and fill the polygon as an image
|
||||
mask_img = np.zeros(size, np.uint8) # 0 is the background class
|
||||
|
||||
for uid, tup in polys.items():
|
||||
# poly is a np.ndarray
|
||||
poly, damage_class_num = tup
|
||||
|
||||
# blank = np.zeros(size, np.uint8)
|
||||
|
||||
# Creating a shapely polygon object out of the numpy array
|
||||
polygon = Polygon(poly)
|
||||
|
||||
# Getting the center points from the polygon and the polygon points
|
||||
(poly_center_x, poly_center_y) = polygon.centroid.coords[0]
|
||||
polygon_points = polygon.exterior.coords
|
||||
|
||||
# Setting a new polygon with each X,Y manipulated based off the center point
|
||||
shrunk_polygon = []
|
||||
for (x, y) in polygon_points:
|
||||
if x < poly_center_x:
|
||||
x += border
|
||||
elif x > poly_center_x:
|
||||
x -= border
|
||||
|
||||
if y < poly_center_y:
|
||||
y += border
|
||||
elif y > poly_center_y:
|
||||
y -= border
|
||||
|
||||
shrunk_polygon.append([x, y])
|
||||
|
||||
# Transforming the polygon back to a np.ndarray
|
||||
ns_poly = np.array(shrunk_polygon, np.int32)
|
||||
|
||||
# Filling the shrunken polygon to add a border between close polygons
|
||||
# Assuming there is no overlap!
|
||||
fillPoly(mask_img, [ns_poly], (damage_class_num, damage_class_num, damage_class_num))
|
||||
|
||||
mask_img = mask_img[:, :, 0].squeeze()
|
||||
print(f'shape of final mask_img: {mask_img.shape}')
|
||||
return mask_img
|
||||
|
||||
|
||||
def mask_tiles(images_dir, label_paths, targets_dir, border_width, overwrite_target):
|
||||
|
||||
for label_path in tqdm(label_paths):
|
||||
|
||||
tile_id = os.path.basename(label_path).split('.json')[0] # just the file name without extension
|
||||
image_path = os.path.join(images_dir, f'{tile_id}.png')
|
||||
target_path = os.path.join(targets_dir, f'{tile_id}_b{border_width}.png')
|
||||
|
||||
if os.path.exists(target_path) and not overwrite_target:
|
||||
continue
|
||||
|
||||
# read the label json
|
||||
label_json = read_json(label_path)
|
||||
|
||||
# read the image and get its size
|
||||
tile_size = get_dimensions(image_path)
|
||||
|
||||
# read in the polygons from the json file
|
||||
polys = get_feature_info(label_json)
|
||||
|
||||
mask_img = mask_polygons_together_with_border(tile_size, polys, border_width)
|
||||
|
||||
imwrite(target_path, mask_img)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Create masks for each label json file for disasters specified at the top of the script.')
|
||||
parser.add_argument(
|
||||
'root_dir',
|
||||
help=('Path to the directory that contains both the `images` and `labels` folders. '
|
||||
'The `targets_border{border_width}` folder will be created if it does not already exist.')
|
||||
)
|
||||
parser.add_argument(
|
||||
'-b', '--border_width',
|
||||
type=int,
|
||||
default=1
|
||||
)
|
||||
parser.add_argument(
|
||||
'-o', '--overwrite_target',
|
||||
help='flag if we want to generate all targets anew',
|
||||
action='store_true'
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
images_dir = os.path.join(args.root_dir, 'images')
|
||||
labels_dir = os.path.join(args.root_dir, 'labels')
|
||||
|
||||
assert os.path.exists(args.root_dir), 'root_dir does not exist'
|
||||
assert os.path.isdir(args.root_dir), 'root_dir needs to be path to a directory'
|
||||
assert os.path.exists(images_dir), 'root_dir does not contain the folder `images`'
|
||||
assert os.path.exists(labels_dir), 'root_dir does not contain the folder `labels`'
|
||||
assert args.border_width >= 0, 'border_width < 0'
|
||||
assert args.border_width < 5, 'specified border_width is > 4 pixels - are you sure?'
|
||||
|
||||
assert isinstance(DISASTERS_OF_INTEREST, tuple)
|
||||
for i in DISASTERS_OF_INTEREST:
|
||||
assert i.endswith('_')
|
||||
|
||||
print(f'Disasters to create the masks for: {DISASTERS_OF_INTEREST}')
|
||||
|
||||
targets_dir = os.path.join(args.root_dir, f'targets_border{args.border_width}')
|
||||
print(f'A targets directory is at {targets_dir}')
|
||||
os.makedirs(targets_dir, exist_ok=True)
|
||||
|
||||
# list out label files for the disaster of interest
|
||||
li_label_fn = os.listdir(labels_dir)
|
||||
li_label_fn = sorted([i for i in li_label_fn if i.endswith('.json')])
|
||||
li_label_paths = [os.path.join(labels_dir, i) for i in li_label_fn if i.startswith(DISASTERS_OF_INTEREST)]
|
||||
|
||||
print(f'{len(li_label_fn)} label jsons found in labels_dir, '
|
||||
f'{len(li_label_paths)} are for the disasters of interest.')
|
||||
|
||||
mask_tiles(images_dir, li_label_paths, targets_dir, args.border_width, args.overwrite_target)
|
||||
print('Done!')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
from PIL import Image
|
||||
import cv2
|
||||
from glob import glob
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
import torchvision.transforms.functional as TF
|
||||
from skimage import transform
|
||||
from torchvision import transforms
|
||||
import random
|
||||
|
||||
class DisasterDataset(Dataset):
|
||||
def __init__(self, data_dir, data_dir_ls, data_mean_stddev, transform:bool, normalize:bool):
|
||||
|
||||
self.data_dir = data_dir
|
||||
self.dataset_sub_dir = data_dir_ls
|
||||
self.data_mean_stddev = data_mean_stddev
|
||||
self.transform = transform
|
||||
self.normalize = normalize
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset_sub_dir)
|
||||
|
||||
def __getitem__(self, i):
|
||||
|
||||
imgs_dir = self.data_dir + self.dataset_sub_dir[i].replace('labels', 'images')
|
||||
masks_dir = self.data_dir + self.dataset_sub_dir[i].replace('labels', 'targets_border2')
|
||||
|
||||
idx = imgs_dir
|
||||
|
||||
img_suffix = '_' + imgs_dir.split('_')[-1]
|
||||
mask_suffix = '_' + masks_dir.split('_')[-1]
|
||||
|
||||
|
||||
pre_img_tile_name = imgs_dir[0:-1*(len(img_suffix))] + '_pre_disaster'
|
||||
pre_img_file_name = imgs_dir[0:-1*(len(img_suffix))] + '_pre_disaster' + img_suffix
|
||||
pre_img_file = glob(pre_img_file_name + '.*')
|
||||
|
||||
mask_file_name = masks_dir[0:-1*(len(mask_suffix))] + '_pre_disaster_b2' + mask_suffix
|
||||
mask_file = glob(mask_file_name + '.*')
|
||||
|
||||
post_img_tile_name = pre_img_tile_name.replace('pre', 'post')
|
||||
post_img_file_name = pre_img_file_name.replace('pre', 'post')
|
||||
post_img_file = glob(post_img_file_name + '.*')
|
||||
|
||||
damage_class_file_name = mask_file_name.replace('pre', 'post')
|
||||
damage_class_file = glob(damage_class_file_name + '.*')
|
||||
|
||||
assert len(mask_file) == 1, \
|
||||
f'Either no mask or multiple masks found for the ID {idx}: {mask_file_name}'
|
||||
assert len(pre_img_file) == 1, \
|
||||
f'Either no image or multiple images found for the ID {idx}: {pre_img_file_name}'
|
||||
assert len(post_img_file) == 1, \
|
||||
f'Either no post disaster image or multiple images found for the ID {idx}: {post_img_file_name}'
|
||||
assert len(damage_class_file) == 1, \
|
||||
f'Either no damage class image or multiple images found for the ID {idx}: {damage_class_file_name}'
|
||||
|
||||
mask = cv2.imread(mask_file[0], cv2.IMREAD_GRAYSCALE)
|
||||
pre_img = cv2.imread(pre_img_file[0])
|
||||
post_img = cv2.imread(post_img_file[0])
|
||||
damage_class = cv2.imread(damage_class_file[0], cv2.IMREAD_GRAYSCALE)
|
||||
|
||||
assert pre_img.shape[0] == mask.shape[0], \
|
||||
f'Image and building mask {idx} should be the same size, but are {pre_img.shape} and {mask.shape}'
|
||||
assert mask.size == damage_class.size, \
|
||||
f'Image and damage classes mask {idx} should be the same size, but are {mask.size} and {damage_class.size}'
|
||||
assert pre_img.size == post_img.size, \
|
||||
f'Pre_ & _post disaster Images {idx} should be the same size, but are {pre_img.size} and {post_img.size}'
|
||||
|
||||
data = {'pre_image': pre_img, 'post_image': post_img, 'building_mask': mask, 'damage_mask': damage_class, 'pre_img_tile_name': pre_img_tile_name}
|
||||
|
||||
return data
|
|
@ -0,0 +1,184 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
make_data_shards.py
|
||||
|
||||
This is an additional pre-processing step after tile_and_mask.py to cut chips out of the tiles
|
||||
and store them in large numpy arrays, so they can all be loaded in memory during training.
|
||||
|
||||
The train and val splits will be stored separately to distinguish them.
|
||||
|
||||
This is an improvement on the original approach of chipping during training using LandsatDataset, but it is an
|
||||
extra step, so each new experiment requiring a different input size/set of channels would need to re-run
|
||||
this step. Data augmentation is still added on-the-fly.
|
||||
|
||||
Example invocation:
|
||||
```
|
||||
export AZUREML_DATAREFERENCE_wcsorinoquia=/boto_disk_0/wcs_data/tiles/full_sr_median_2013_2014
|
||||
|
||||
python data/make_chip_shards.py --config_module_path training_wcs/experiments/elevation/elevation_2_config.py --out_dir /boto_disk_0/wcs_data/shards/full_sr_median_2013_2014_elevation
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import numpy as np
|
||||
from dataset_shard_save import DisasterDataset
|
||||
from train_utils import load_json_files, dump_json_files
|
||||
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
config = {'num_shards': 1,
|
||||
'out_dir': './xBD_sliced_augmented_20_alldisasters_final_mdl_npy/',
|
||||
'data_dir': './xBD_sliced_augmented_20_alldisasters/',
|
||||
'data_splits_json': './nlrc.building-damage-assessment/constants/splits/final_mdl_all_disaster_splits_sliced_img_augmented_20.json',
|
||||
'data_mean_stddev': './nlrc.building-damage-assessment/constants/splits/all_disaster_mean_stddev_tiles_0_1.json'}
|
||||
|
||||
def load_dataset():
|
||||
splits = load_json_files(config['data_splits_json'])
|
||||
data_mean_stddev = load_json_files(config['data_mean_stddev'])
|
||||
|
||||
train_ls = []
|
||||
val_ls = []
|
||||
test_ls = []
|
||||
for item, val in splits.items():
|
||||
train_ls += val['train']
|
||||
val_ls += val['val']
|
||||
test_ls += val['test']
|
||||
|
||||
xBD_train = DisasterDataset(config['data_dir'], train_ls, data_mean_stddev, transform=False, normalize=True)
|
||||
xBD_val = DisasterDataset(config['data_dir'], val_ls, data_mean_stddev, transform=False, normalize=True)
|
||||
xBD_test = DisasterDataset(config['data_dir'], test_ls, data_mean_stddev, transform=False, normalize=True)
|
||||
|
||||
print('xBD_disaster_dataset train length: {}'.format(len(xBD_train)))
|
||||
print('xBD_disaster_dataset val length: {}'.format(len(xBD_val)))
|
||||
print('xBD_disaster_dataset test length: {}'.format(len(xBD_test)))
|
||||
|
||||
return xBD_train, xBD_val, xBD_test
|
||||
|
||||
def create_shard(dataset, num_shards):
|
||||
"""Iterate through the dataset to produce shards of chips as numpy arrays, for imagery input and labels.
|
||||
|
||||
Args:
|
||||
dataset: an instance of LandsatDataset, which when iterated, each item contains fields
|
||||
'chip' and 'chip_label'
|
||||
data = {'pre_image': pre_img, 'post_image': post_img, 'building_mask': mask, 'damage_mask': damage_class}
|
||||
|
||||
num_shards: number of numpy arrays to store all chips in
|
||||
|
||||
Returns:
|
||||
returns a 2-tuple, where
|
||||
- the first item is a list of numpy arrays of dimension (num_chips, channel, height, width) with
|
||||
dtype float for the input imagery chips
|
||||
- the second item is a list of numpy arrays of dimension (num_chips, height, width) with
|
||||
dtype int for the label chips.
|
||||
"""
|
||||
pre_image_chips, post_image_chips, bld_mask_chips, dmg_mask_chips, pre_img_tile_name_chips = [], [], [], [], []
|
||||
for item in tqdm(dataset):
|
||||
# not using chip_id and chip_for_display fields
|
||||
pre_image_chips.append(item['pre_image'])
|
||||
post_image_chips.append(item['post_image'])
|
||||
bld_mask_chips.append(item['building_mask'])
|
||||
dmg_mask_chips.append(item['damage_mask'])
|
||||
pre_img_tile_name_chips.append(item['pre_img_tile_name'])
|
||||
|
||||
num_chips = len(pre_image_chips)
|
||||
print(f'Created {num_chips} chips.')
|
||||
|
||||
items_per_shards = math.ceil(num_chips / num_shards)
|
||||
shard_idx = []
|
||||
for i in range(num_shards):
|
||||
shard_idx.append(
|
||||
(i * items_per_shards, (1 + i) * items_per_shards)
|
||||
)
|
||||
|
||||
print('Stacking imagery and label chips into shards')
|
||||
pre_image_chip_shards, post_image_chip_shards, bld_mask_chip_shards, dmg_mask_chip_shards, pre_img_tile_name_chip_shards = [], [], [], [], []
|
||||
for begin_idx, end_idx in shard_idx:
|
||||
if begin_idx < num_chips:
|
||||
pre_image_chip_shard = pre_image_chips[begin_idx:end_idx]
|
||||
pre_image_chip_shard = np.stack(pre_image_chip_shard, axis=0)
|
||||
print(f'dim of input chip shard is {pre_image_chip_shard.shape}, dtype is {pre_image_chip_shard.dtype}')
|
||||
pre_image_chip_shards.append(pre_image_chip_shard)
|
||||
|
||||
post_image_chip_shard = post_image_chips[begin_idx:end_idx]
|
||||
post_image_chip_shard = np.stack(post_image_chip_shard, axis=0)
|
||||
print(f'dim of input chip shard is {post_image_chip_shard.shape}, dtype is {post_image_chip_shard.dtype}')
|
||||
post_image_chip_shards.append(post_image_chip_shard)
|
||||
|
||||
bld_mask_chip_shard = bld_mask_chips[begin_idx:end_idx]
|
||||
bld_mask_chip_shard = np.stack(bld_mask_chip_shard, axis=0)
|
||||
print(f'dim of label chip shard is {bld_mask_chip_shard.shape}, dtype is {bld_mask_chip_shard.dtype}')
|
||||
bld_mask_chip_shards.append(bld_mask_chip_shard)
|
||||
|
||||
dmg_mask_chip_shard = dmg_mask_chips[begin_idx:end_idx]
|
||||
dmg_mask_chip_shard = np.stack(dmg_mask_chip_shard, axis=0)
|
||||
print(f'dim of label chip shard is {dmg_mask_chip_shard.shape}, dtype is {dmg_mask_chip_shard.dtype}')
|
||||
dmg_mask_chip_shards.append(dmg_mask_chip_shard)
|
||||
|
||||
pre_img_tile_name_chip_shard = pre_img_tile_name_chips[begin_idx:end_idx]
|
||||
pre_img_tile_name_chip_shard = np.stack(pre_img_tile_name_chip_shard, axis=0)
|
||||
print(f'dim of pre_img_tile_name_chip_shard chip shard is {pre_img_tile_name_chip_shard.shape}, dtype is {pre_img_tile_name_chip_shard.dtype}')
|
||||
pre_img_tile_name_chip_shards.append(pre_img_tile_name_chip_shard)
|
||||
|
||||
return (pre_image_chip_shards, post_image_chip_shards, bld_mask_chip_shards, dmg_mask_chip_shards, pre_img_tile_name_chip_shards)
|
||||
|
||||
|
||||
def save_shards(out_dir, set_name, pre_image_chip_shards, post_image_chip_shards, bld_mask_chip_shards, dmg_mask_chip_shards, pre_img_tile_name_chip_shards):
|
||||
for i_shard, (pre_image_chip_shard, post_image_chip_shard, bld_mask_chip_shard, dmg_mask_chip_shard, pre_img_tile_name_chip_shard) in enumerate(zip(pre_image_chip_shards, post_image_chip_shards, bld_mask_chip_shards, dmg_mask_chip_shards, pre_img_tile_name_chip_shards)):
|
||||
shard_path = os.path.join(out_dir, f'{set_name}_pre_image_chips_{i_shard}.npy')
|
||||
np.save(shard_path, pre_image_chip_shard)
|
||||
print(f'Saved {shard_path}')
|
||||
|
||||
shard_path = os.path.join(out_dir, f'{set_name}_post_image_chips_{i_shard}.npy')
|
||||
np.save(shard_path, post_image_chip_shard)
|
||||
print(f'Saved {shard_path}')
|
||||
|
||||
shard_path = os.path.join(out_dir, f'{set_name}_bld_mask_chips_{i_shard}.npy')
|
||||
np.save(shard_path, bld_mask_chip_shard)
|
||||
print(f'Saved {shard_path}')
|
||||
|
||||
shard_path = os.path.join(out_dir, f'{set_name}_dmg_mask_chips_{i_shard}.npy')
|
||||
np.save(shard_path, dmg_mask_chip_shard)
|
||||
print(f'Saved {shard_path}')
|
||||
|
||||
shard_path = os.path.join(out_dir, f'{set_name}_pre_img_tile_chips_{i_shard}.npy')
|
||||
np.save(shard_path, pre_img_tile_name_chip_shard)
|
||||
print(f'Saved {shard_path}')
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
out_dir = config['out_dir']
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
train_set, val_set, test_set = load_dataset()
|
||||
|
||||
print('Iterating through the training set to generate chips...')
|
||||
train_pre_image_chip_shards, train_post_image_chip_shards, train_bld_mask_chip_shards, train_dmg_mask_chip_shards, train_pre_img_tile_name_chip_shards = create_shard(train_set, config['num_shards'])
|
||||
save_shards(out_dir, 'train', train_pre_image_chip_shards, train_post_image_chip_shards, train_bld_mask_chip_shards, train_dmg_mask_chip_shards, train_pre_img_tile_name_chip_shards)
|
||||
|
||||
del train_pre_image_chip_shards
|
||||
del train_post_image_chip_shards
|
||||
del train_bld_mask_chip_shards
|
||||
del train_dmg_mask_chip_shards
|
||||
|
||||
print('Iterating through the val set to generate chips...')
|
||||
val_pre_image_chip_shards, val_post_image_chip_shards, val_bld_mask_chip_shards, val_dmg_mask_chip_shards, val_pre_img_tile_name_chip_shards = create_shard(val_set, config['num_shards'])
|
||||
save_shards(out_dir, 'val', val_pre_image_chip_shards, val_post_image_chip_shards, val_bld_mask_chip_shards, val_dmg_mask_chip_shards, val_pre_img_tile_name_chip_shards)
|
||||
|
||||
del val_pre_image_chip_shards
|
||||
del val_post_image_chip_shards
|
||||
del val_bld_mask_chip_shards
|
||||
del val_dmg_mask_chip_shards
|
||||
|
||||
print('Done!')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,255 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from glob import glob
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import Dataset
|
||||
import torchvision.transforms.functional as TF
|
||||
|
||||
config = {
|
||||
'batch_size': 1,
|
||||
'data_dir': './xBD/',
|
||||
'sliced_data_dir': './final_mdl_all_disaster_splits/',
|
||||
'disaster_splits_json': './nlrc.building-damage-assessment/constants/splits/final_mdl_all_disaster_splits.json',
|
||||
'disaster_splits_json_sliced': './nlrc.building-damage-assessment/constants/splits/final_mdl_all_disaster_splits_sliced_img_augmented_20.json'
|
||||
}
|
||||
|
||||
logging.basicConfig(
|
||||
stream=sys.stdout,
|
||||
level=logging.INFO,
|
||||
format='{asctime} {levelname} {message}',
|
||||
style='{',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
xBD_train, xBD_val, xBD_test = load_dataset()
|
||||
|
||||
train_loader = DataLoader(xBD_train, batch_size=config['batch_size'], shuffle=False, num_workers=8)
|
||||
test_loader = DataLoader(xBD_test, batch_size=config['batch_size'], shuffle=False, num_workers=8)
|
||||
val_loader = DataLoader(xBD_val, batch_size=config['batch_size'], shuffle=False, num_workers=8)
|
||||
|
||||
with ThreadPool(3) as pool:
|
||||
loaders = [
|
||||
('TRAIN', train_loader),
|
||||
('TEST', test_loader),
|
||||
('VAL', val_loader)
|
||||
]
|
||||
pool.starmap(iterate_and_slice, loaders)
|
||||
|
||||
logging.info(f'Done')
|
||||
|
||||
return
|
||||
|
||||
|
||||
def iterate_and_slice(split_name: str, data_loader: DataLoader):
|
||||
for batch_idx, data in enumerate(data_loader):
|
||||
logging.info(f'{split_name}: batch_idx {batch_idx} sliced into 20 images.')
|
||||
logging.info(f'Done for split {split_name}')
|
||||
|
||||
|
||||
def load_dataset():
|
||||
splits_all_disasters = load_json_files(config['disaster_splits_json'])
|
||||
sliced_data_json = copy.deepcopy(splits_all_disasters)
|
||||
|
||||
train_ls = []
|
||||
val_ls = []
|
||||
test_ls = []
|
||||
for disaster_name, splits in splits_all_disasters.items():
|
||||
logging.info(f'disaster_name: {disaster_name}.')
|
||||
l = len(splits['train'])
|
||||
logging.info(f'training set number of tiles: {l}.')
|
||||
|
||||
train_ls += splits['train']
|
||||
val_ls += splits['val']
|
||||
test_ls += splits['test']
|
||||
|
||||
for disaster_name, splits in sliced_data_json.items():
|
||||
new_vals_tr = []
|
||||
new_vals_ts = []
|
||||
new_vals_val = []
|
||||
|
||||
for slice_sub in range(0,20):
|
||||
new_vals_tr += [sub + '_' + str(slice_sub) for sub in sliced_data_json[disaster_name]['train']]
|
||||
new_vals_ts += [sub + '_' + str(slice_sub) for sub in sliced_data_json[disaster_name]['test']]
|
||||
new_vals_val += [sub + '_' + str(slice_sub) for sub in sliced_data_json[disaster_name]['val']]
|
||||
|
||||
logging.info(f'disaster_name: {disaster_name}.')
|
||||
logging.info(f'training set number of chips: {len(new_vals_tr)}.')
|
||||
|
||||
sliced_data_json[disaster_name]['train'] = new_vals_tr
|
||||
sliced_data_json[disaster_name]['test'] = new_vals_ts
|
||||
sliced_data_json[disaster_name]['val'] = new_vals_val
|
||||
|
||||
dump_json_files(config['disaster_splits_json_sliced'], sliced_data_json)
|
||||
|
||||
logging.info(f'train dataset length before cropping: {len(train_ls)}.')
|
||||
xBD_train = SliceDataset(config['data_dir'], train_ls, config['sliced_data_dir'])
|
||||
logging.info(f'train dataset length after cropping: {len(xBD_train)}.')
|
||||
|
||||
logging.info(f'test dataset length before cropping: {len(test_ls)}.')
|
||||
xBD_test = SliceDataset(config['data_dir'], test_ls, config['sliced_data_dir'])
|
||||
logging.info(f'test dataset length after cropping: {len(xBD_test)}.')
|
||||
|
||||
logging.info(f'val dataset length before cropping: {len(val_ls)}.')
|
||||
xBD_val = SliceDataset(config['data_dir'], val_ls, config['sliced_data_dir'])
|
||||
logging.info(f'val dataset length after cropping: {len(xBD_val)}.')
|
||||
|
||||
print('xBD_disaster_dataset train length: {}'.format(len(xBD_train)))
|
||||
print('xBD_disaster_dataset val length: {}'.format(len(xBD_val)))
|
||||
print('xBD_disaster_dataset test length: {}'.format(len(xBD_test)))
|
||||
return xBD_train, xBD_val, xBD_test
|
||||
|
||||
|
||||
def load_json_files(json_filename):
|
||||
with open(json_filename) as f:
|
||||
file_content = json.load(f)
|
||||
return file_content
|
||||
|
||||
|
||||
def dump_json_files(json_filename, my_dict):
|
||||
with open(json_filename, 'w') as f:
|
||||
json.dump(my_dict, f, indent=4)
|
||||
return
|
||||
|
||||
|
||||
class SliceDataset(Dataset):
|
||||
def __init__(self, data_dir, data_dir_ls, sliced_data_dir, mask_suffix=''):
|
||||
"""
|
||||
Args:
|
||||
data_dir: root xBD directory
|
||||
data_dir_ls: list of all tiles in this split, across different disasters
|
||||
sliced_data_dir: output directory for chips
|
||||
"""
|
||||
|
||||
self.data_dir = data_dir
|
||||
self.data_dir_ls = data_dir_ls
|
||||
self.sliced_data_dir = sliced_data_dir
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_dir_ls)
|
||||
|
||||
@classmethod
|
||||
def slice_tile(self, mask, pre_img, post_img, damage_class):
|
||||
|
||||
|
||||
img_h = pre_img.size[0]
|
||||
img_w = pre_img.size[1]
|
||||
|
||||
h_idx = [0, 256, 512, 768]
|
||||
w_idx = [0, 256, 512, 768]
|
||||
|
||||
img_h = 256
|
||||
img_w = 256
|
||||
|
||||
sliced_sample_dic = {}
|
||||
counter = 0
|
||||
for i in h_idx:
|
||||
for j in w_idx:
|
||||
mask_sub = TF.crop(mask, i, j, img_h, img_w)
|
||||
pre_img_sub = TF.crop(pre_img, i, j, img_h, img_w)
|
||||
post_img_sub = TF.crop(post_img, i, j, img_h, img_w)
|
||||
damage_class_sub = TF.crop(damage_class, i, j, img_h, img_w)
|
||||
sliced_sample_dic[str(counter)] = {'mask': mask_sub, 'pre_img': pre_img_sub, 'post_img': post_img_sub, 'damage_class': damage_class_sub}
|
||||
counter += 1
|
||||
|
||||
# pick 4 random slices from each tile
|
||||
for item in range(0,4):
|
||||
i = random.randint(5, h_idx[-1]-5)
|
||||
j = random.randint(5, w_idx[-1]-5)
|
||||
mask_sub = TF.crop(mask, i, j, img_h, img_w)
|
||||
pre_img_sub = TF.crop(pre_img, i, j, img_h, img_w)
|
||||
post_img_sub = TF.crop(post_img, i, j, img_h, img_w)
|
||||
damage_class_sub = TF.crop(damage_class, i, j, img_h, img_w)
|
||||
sliced_sample_dic[str(counter)] = {'mask': mask_sub, 'pre_img': pre_img_sub, 'post_img': post_img_sub, 'damage_class': damage_class_sub}
|
||||
counter += 1
|
||||
|
||||
return sliced_sample_dic
|
||||
|
||||
def __getitem__(self, i):
|
||||
|
||||
imgs_dir = os.path.join(self.data_dir, self.data_dir_ls[i].replace('labels', 'images'))
|
||||
masks_dir = os.path.join(self.data_dir, self.data_dir_ls[i].replace('labels', 'targets_border2'))
|
||||
|
||||
idx = imgs_dir
|
||||
|
||||
pre_img_file_name = imgs_dir + '_pre_disaster'
|
||||
pre_img_file = glob(pre_img_file_name + '.*')
|
||||
|
||||
mask_file_name = masks_dir + '_pre_disaster_b2'
|
||||
mask_file = glob(mask_file_name + '.*')
|
||||
|
||||
post_img_file_name = pre_img_file_name.replace('pre', 'post')
|
||||
post_img_file = glob(post_img_file_name + '.*')
|
||||
|
||||
damage_class_file_name = mask_file_name.replace('pre', 'post')
|
||||
damage_class_file = glob(damage_class_file_name + '.*')
|
||||
|
||||
assert len(mask_file) == 1, \
|
||||
f'Either no mask or multiple masks found for the ID {idx}: {mask_file}'
|
||||
assert len(pre_img_file) == 1, \
|
||||
f'Either no image or multiple images found for the ID {idx}: {pre_img_file}'
|
||||
assert len(post_img_file) == 1, \
|
||||
f'Either no post disaster image or multiple images found for the ID {idx}: {post_img_file}'
|
||||
assert len(damage_class_file) == 1, \
|
||||
f'Either no damage class image or multiple images found for the ID {idx}: {damage_class_file}'
|
||||
|
||||
mask = Image.open(mask_file[0])
|
||||
pre_img = Image.open(pre_img_file[0])
|
||||
post_img = Image.open(post_img_file[0])
|
||||
damage_class = Image.open(damage_class_file[0])
|
||||
|
||||
assert pre_img.size == mask.size, \
|
||||
f'Image and building mask {idx} should be the same size, but are {pre_img.size} and {mask.size}'
|
||||
assert pre_img.size == damage_class.size, \
|
||||
f'Image and damage classes mask {idx} should be the same size, but are {pre_img.size} and {damage_class.size}'
|
||||
assert pre_img.size == post_img.size, \
|
||||
f'Pre & post disaster Images {idx} should be the same size, but are {pre_img.size} and {post_img.size}'
|
||||
|
||||
sliced_sample_dic = self.slice_tile(mask, pre_img, post_img, damage_class)
|
||||
|
||||
for item, val in sliced_sample_dic.items():
|
||||
os.makedirs(
|
||||
os.path.split(mask_file[0].replace(self.data_dir, self.sliced_data_dir))[0],
|
||||
exist_ok=True
|
||||
)
|
||||
val['mask'].save(mask_file[0].replace(self.data_dir, self.sliced_data_dir).replace('.png', '_' + item + '.png'))
|
||||
file_name = mask_file[0].replace(self.data_dir, self.sliced_data_dir).replace('.png', '_' + item + '.png')
|
||||
|
||||
os.makedirs(
|
||||
os.path.split(pre_img_file[0].replace(self.data_dir, self.sliced_data_dir))[0],
|
||||
exist_ok=True
|
||||
)
|
||||
val['pre_img'].save(pre_img_file[0].replace(self.data_dir, self.sliced_data_dir).replace('.png', '_' + item + '.png'))
|
||||
|
||||
os.makedirs(
|
||||
os.path.split(post_img_file[0].replace(self.data_dir, self.sliced_data_dir))[0],
|
||||
exist_ok=True
|
||||
)
|
||||
val['post_img'].save(post_img_file[0].replace(self.data_dir, self.sliced_data_dir).replace('.png', '_' + item + '.png'))
|
||||
|
||||
os.makedirs(
|
||||
os.path.split(damage_class_file[0].replace(self.data_dir, self.sliced_data_dir))[0],
|
||||
exist_ok=True
|
||||
)
|
||||
val['damage_class'].save(damage_class_file[0].replace(self.data_dir, self.sliced_data_dir).replace('.png', '_' + item + '.png'))
|
||||
|
||||
return {
|
||||
'pre_image': TF.to_tensor(pre_img),
|
||||
'post_image': TF.to_tensor(post_img),
|
||||
'building_mask': TF.to_tensor(mask),
|
||||
'damage_mask': TF.to_tensor(damage_class)
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,316 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
Class to visualize raster mask labels and hardmax or softmax model predictions, for semantic segmentation tasks.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from io import BytesIO
|
||||
from typing import Union, Tuple
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.colors as mcolors
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from PIL import Image, ImageColor
|
||||
|
||||
|
||||
class RasterLabelVisualizer(object):
|
||||
"""Visualizes raster mask labels and predictions."""
|
||||
|
||||
def __init__(self, label_map: Union[str, dict]):
|
||||
"""Constructs a raster label visualizer.
|
||||
|
||||
Args:
|
||||
label_map: a path to a JSON file containing a dict, or a dict. The dict needs to have two fields:
|
||||
|
||||
num_to_name {
|
||||
numerical category (str or int) : display name (str)
|
||||
}
|
||||
|
||||
num_to_color {
|
||||
numerical category (str or int) : color representation (an object that matplotlib.colors recognizes
|
||||
as a color; additionally a (R, G, B) tuple or list with uint8 values will also be parsed)
|
||||
}
|
||||
"""
|
||||
if isinstance(label_map, str):
|
||||
assert os.path.exists(label_map)
|
||||
with open(label_map) as f:
|
||||
label_map = json.load(f)
|
||||
|
||||
assert 'num_to_name' in label_map
|
||||
assert isinstance(label_map['num_to_name'], dict)
|
||||
assert 'num_to_color' in label_map
|
||||
assert isinstance(label_map['num_to_color'], dict)
|
||||
|
||||
self.num_to_name = RasterLabelVisualizer._dict_key_to_int(label_map['num_to_name'])
|
||||
self.num_to_color = RasterLabelVisualizer._dict_key_to_int(label_map['num_to_color'])
|
||||
|
||||
assert len(self.num_to_color) == len(self.num_to_name)
|
||||
self.num_classes = len(self.num_to_name)
|
||||
|
||||
# check for duplicate names or colors
|
||||
assert len(set(self.num_to_color.values())) == self.num_classes, 'There are duplicate colors in the colormap'
|
||||
assert len(set(self.num_to_name.values())) == self.num_classes, \
|
||||
'There are duplicate class names in the colormap'
|
||||
|
||||
self.num_to_color = RasterLabelVisualizer.standardize_colors(self.num_to_color)
|
||||
|
||||
# create the custom colormap according to colors defined in label_map
|
||||
required_colors = []
|
||||
# key is originally a string
|
||||
for num, color_name in sorted(self.num_to_color.items(), key=lambda x: x[0]): # num already cast to int
|
||||
rgb = mcolors.to_rgb(mcolors.CSS4_COLORS[color_name])
|
||||
# mcolors.to_rgb is to [0, 1] values; ImageColor.getrgb gets [1, 255] values
|
||||
required_colors.append(rgb)
|
||||
|
||||
self.colormap = mcolors.ListedColormap(required_colors)
|
||||
# vmin and vmax appear to be inclusive,
|
||||
# so if there are a total of 34 classes, class 0 to class 33 each maps to a color
|
||||
self.normalizer = mcolors.Normalize(vmin=0, vmax=self.num_classes - 1)
|
||||
|
||||
self.color_matrix = self._make_color_matrix()
|
||||
|
||||
@staticmethod
|
||||
def _dict_key_to_int(d: dict) -> dict:
|
||||
return {int(k): v for k, v in d.items()}
|
||||
|
||||
def _make_color_matrix(self) -> np.ndarray:
|
||||
"""Creates a color matrix of dims (num_classes, 3), where a row corresponds to the RGB values of each class.
|
||||
"""
|
||||
matrix = []
|
||||
for num, color in sorted(self.num_to_color.items(), key=lambda x: x[0]):
|
||||
rgb = RasterLabelVisualizer.matplotlib_color_to_uint8_rgb(color)
|
||||
matrix.append(rgb)
|
||||
matrix = np.array(matrix)
|
||||
|
||||
assert matrix.shape == (self.num_classes, 3)
|
||||
|
||||
return matrix
|
||||
|
||||
@staticmethod
|
||||
def standardize_colors(num_to_color: dict) -> dict:
|
||||
"""Return a new dict num_to_color with colors verified. uint8 RGB tuples are converted to a hex string
|
||||
as matplotlib.colors do not accepted uint8 intensity values"""
|
||||
new = {}
|
||||
for num, color in num_to_color.items():
|
||||
if mcolors.is_color_like(color):
|
||||
new[num] = color
|
||||
else:
|
||||
# try to see if it's a (r, g, b) tuple or list of uint8 values
|
||||
assert len(color) == 3 or len(
|
||||
color) == 4, f'Color {color} is specified as a tuple or list but is not of length 3 or 4'
|
||||
for c in color:
|
||||
assert isinstance(c, int) and 0 < c < 256, f'RGB value {c} is out of range'
|
||||
|
||||
new[num] = RasterLabelVisualizer.uint8_rgb_to_hex(color[0], color[1], color[3]) # drop any alpha values
|
||||
assert len(new) == len(num_to_color)
|
||||
return new
|
||||
|
||||
@staticmethod
|
||||
def uint8_rgb_to_hex(r: int, g: int, b: int) -> str:
|
||||
"""Convert RGB values in uint8 to a hex color string
|
||||
|
||||
Reference
|
||||
https://codereview.stackexchange.com/questions/229282/performance-for-simple-code-that-converts-a-rgb-tuple-to-hex-string
|
||||
"""
|
||||
return f'#{r:02x}{g:02x}{b:02x}'
|
||||
|
||||
@staticmethod
|
||||
def matplotlib_color_to_uint8_rgb(color: Union[str, tuple, list]) -> Tuple[int, int, int]:
|
||||
"""Converts any matplotlib recognized color representation to (R, G, B) uint intensity values
|
||||
|
||||
Need to use matplotlib, which recognizes different color formats, to convert to hex,
|
||||
then use PIL to convert to uint8 RGB. matplotlib does not support the uint8 RGB format
|
||||
"""
|
||||
color_hex = mcolors.to_hex(color)
|
||||
color_rgb = ImageColor.getcolor(color_hex, 'RGB') # '#DDA0DD' to (221, 160, 221); alpha silently dropped
|
||||
return color_rgb
|
||||
|
||||
def get_tiff_colormap(self) -> dict:
|
||||
"""Returns the object to pass to rasterio dataset object's write_colormap() function,
|
||||
which is a dict mapping int values to a tuple of (R, G, B)
|
||||
|
||||
See https://rasterio.readthedocs.io/en/latest/topics/color.html for writing the TIFF colormap
|
||||
"""
|
||||
colormap = {}
|
||||
for num, color in self.num_to_color.items():
|
||||
# uint8 RGB required by TIFF
|
||||
colormap[num] = RasterLabelVisualizer.matplotlib_color_to_uint8_rgb(color)
|
||||
return colormap
|
||||
|
||||
def get_tool_colormap(self) -> str:
|
||||
"""Returns a string that is a JSON of a list of items specifying the name and color
|
||||
of classes. Example:
|
||||
"[
|
||||
{"name": "Water", "color": "#0000FF"},
|
||||
{"name": "Tree Canopy", "color": "#008000"},
|
||||
{"name": "Field", "color": "#80FF80"},
|
||||
{"name": "Built", "color": "#806060"}
|
||||
]"
|
||||
"""
|
||||
classes = []
|
||||
for num, name in sorted(self.num_to_name.items(), key=lambda x: int(x[0])):
|
||||
color = self.num_to_color[num]
|
||||
color_hex = mcolors.to_hex(color)
|
||||
classes.append({
|
||||
'name': name,
|
||||
'color': color_hex
|
||||
})
|
||||
classes = json.dumps(classes, indent=4)
|
||||
return classes
|
||||
|
||||
@staticmethod
|
||||
def plot_colortable(name_to_color: dict, title: str, sort_colors: bool = False, emptycols: int = 0) -> plt.Figure:
|
||||
"""
|
||||
function taken from https://matplotlib.org/3.1.0/gallery/color/named_colors.html
|
||||
"""
|
||||
|
||||
cell_width = 212
|
||||
cell_height = 22
|
||||
swatch_width = 70
|
||||
margin = 12
|
||||
topmargin = 40
|
||||
|
||||
# Sort name_to_color by hue, saturation, value and name.
|
||||
if sort_colors is True:
|
||||
by_hsv = sorted((tuple(mcolors.rgb_to_hsv(mcolors.to_rgb(color))),
|
||||
name)
|
||||
for name, color in name_to_color.items())
|
||||
names = [name for hsv, name in by_hsv]
|
||||
else:
|
||||
names = list(name_to_color)
|
||||
|
||||
n = len(names)
|
||||
ncols = 4 - emptycols
|
||||
nrows = n // ncols + int(n % ncols > 0)
|
||||
|
||||
width = cell_width * 4 + 2 * margin
|
||||
height = cell_height * nrows + margin + topmargin
|
||||
dpi = 80 # other numbers don't seem to work well
|
||||
|
||||
fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
|
||||
fig.subplots_adjust(margin / width, margin / height,
|
||||
(width - margin) / width, (height - topmargin) / height)
|
||||
ax.set_xlim(0, cell_width * 4)
|
||||
ax.set_ylim(cell_height * (nrows - 0.5), -cell_height / 2.)
|
||||
ax.yaxis.set_visible(False)
|
||||
ax.xaxis.set_visible(False)
|
||||
ax.set_axis_off()
|
||||
ax.set_title(title, fontsize=24, loc='left', pad=10)
|
||||
|
||||
for i, name in enumerate(names):
|
||||
row = i % nrows
|
||||
col = i // nrows
|
||||
y = row * cell_height
|
||||
|
||||
swatch_start_x = cell_width * col
|
||||
swatch_end_x = cell_width * col + swatch_width
|
||||
text_pos_x = cell_width * col + swatch_width + 7
|
||||
|
||||
ax.text(text_pos_x, y, name, fontsize=14,
|
||||
horizontalalignment='left',
|
||||
verticalalignment='center')
|
||||
|
||||
ax.hlines(y, swatch_start_x, swatch_end_x,
|
||||
color=name_to_color[name], linewidth=18)
|
||||
|
||||
return fig
|
||||
|
||||
def plot_color_legend(self, legend_title: str = 'Categories') -> plt.Figure:
|
||||
"""Builds a legend of color block, numerical categories and names of the categories.
|
||||
|
||||
Returns:
|
||||
a matplotlib.pyplot Figure
|
||||
"""
|
||||
label_map = {}
|
||||
for num, color in self.num_to_color.items():
|
||||
label_map['{} {}'.format(num, self.num_to_name[num])] = color
|
||||
|
||||
fig = RasterLabelVisualizer.plot_colortable(label_map, legend_title, sort_colors=False, emptycols=3)
|
||||
return fig
|
||||
|
||||
def show_label_raster(self, label_raster: Union[Image.Image, np.ndarray],
|
||||
size: Tuple[int, int] = (10, 10)) -> Tuple[Image.Image, BytesIO]:
|
||||
"""Visualizes a label mask or hardmax predictions of a model, according to the category color map
|
||||
provided when the class was initialized.
|
||||
|
||||
The label_raster provided needs to contain values in [0, num_classes].
|
||||
|
||||
Args:
|
||||
label_raster: 2D numpy array or PIL Image where each number indicates the pixel's class
|
||||
size: matplotlib size in inches (h, w)
|
||||
|
||||
Returns:
|
||||
(im, buf) - PIL image of the matplotlib figure, and a BytesIO buf containing the matplotlib Figure
|
||||
saved as a PNG
|
||||
"""
|
||||
if not isinstance(label_raster, np.ndarray):
|
||||
label_raster = np.asarray(label_raster)
|
||||
|
||||
label_raster = label_raster.squeeze()
|
||||
assert len(label_raster.shape) == 2, 'label_raster provided has more than 2 dimensions after squeezing'
|
||||
|
||||
label_raster.astype(np.uint8)
|
||||
|
||||
# min of 0, which is usually empty / no label
|
||||
assert np.min(label_raster) >= 0, f'Invalid value for class label: {np.min(label_raster)}'
|
||||
|
||||
# non-empty, actual class labels start at 1
|
||||
assert np.max(label_raster) <= self.num_classes, f'Invalid value for class label: {np.max(label_raster)}'
|
||||
|
||||
_ = plt.figure(figsize=size)
|
||||
_ = plt.imshow(label_raster, cmap=self.colormap, norm=self.normalizer, interpolation='none')
|
||||
|
||||
buf = BytesIO()
|
||||
plt.savefig(buf, format='png')
|
||||
plt.close()
|
||||
buf.seek(0)
|
||||
im = Image.open(buf)
|
||||
return im, buf
|
||||
|
||||
@staticmethod
|
||||
def visualize_matrix(matrix: np.ndarray) -> Image.Image:
|
||||
"""Shows a 2D matrix of RGB or greyscale values as a PIL Image.
|
||||
|
||||
Args:
|
||||
matrix: a (H, W, 3) or (H, W) numpy array, representing a colored or greyscale image
|
||||
|
||||
Returns:
|
||||
a PIL Image object
|
||||
"""
|
||||
assert len(matrix.shape) in [2, 3]
|
||||
|
||||
image = Image.fromarray(matrix)
|
||||
return image
|
||||
|
||||
def visualize_softmax_predictions(self, softmax_preds: np.ndarray) -> np.ndarray:
|
||||
"""Visualizes softmax probabilities in RGB according to the class label's assigned colors
|
||||
|
||||
Args:
|
||||
softmax_preds: numpy array of dimensions (batch_size, num_classes, H, W) or (num_classes, H, W)
|
||||
|
||||
Returns:
|
||||
numpy array of size ((batch_size), H, W, 3). You may need to roll the last axis to in-front before
|
||||
writing to TIFF
|
||||
|
||||
Raises:
|
||||
ValueError when the dimension of softmax_preds is not compliant
|
||||
"""
|
||||
|
||||
assert len(softmax_preds.shape) == 4 or len(softmax_preds.shape) == 3
|
||||
|
||||
# row the num_classes dimension to the end
|
||||
if len(softmax_preds.shape) == 4:
|
||||
assert softmax_preds.shape[1] == self.num_classes
|
||||
softmax_preds_transposed = np.transpose(softmax_preds, axes=(0, 2, 3, 1))
|
||||
elif len(softmax_preds.shape) == 3:
|
||||
assert softmax_preds.shape[0] == self.num_classes
|
||||
softmax_preds_transposed = np.transpose(softmax_preds, axes=(1, 2, 0))
|
||||
else:
|
||||
raise ValueError('softmax_preds does not have the required length in the dimension of the classes')
|
||||
|
||||
# ((batch_size), H, W, num_classes) @ (num_classes * 3) = ((batch_size), H, W, 3)
|
||||
colored_view = softmax_preds_transposed @ self.color_matrix
|
||||
return colored_view
|
|
@ -0,0 +1,8 @@
|
|||
services:
|
||||
inference:
|
||||
build: .
|
||||
image: "nlrc-building-damage-assessment:latest"
|
||||
volumes:
|
||||
- type: bind
|
||||
source: ${dataset_path}
|
||||
target: "/inference_files/dataset/"
|
|
@ -0,0 +1,62 @@
|
|||
# syntax=docker/dockerfile:1
|
||||
FROM nvidia/cuda:11.0-cudnn8-devel-ubuntu18.04
|
||||
|
||||
# Miniconda archive to install
|
||||
ARG miniconda_version="4.9.2"
|
||||
ARG miniconda_checksum="122c8c9beb51e124ab32a0fa6426c656"
|
||||
ARG conda_version="4.9.2"
|
||||
ARG PYTHON_VERSION=default
|
||||
|
||||
ENV CONDA_DIR=/opt/conda \
|
||||
SHELL=/bin/bash \
|
||||
NB_USER=$NB_USER \
|
||||
NB_UID=$NB_UID \
|
||||
NB_GID=$NB_GID \
|
||||
LANG=en_US.UTF-8 \
|
||||
LANGUAGE=en_US.UTF-8
|
||||
ENV PATH=$CONDA_DIR/bin:$PATH \
|
||||
HOME=/home/$NB_USER
|
||||
|
||||
ENV MINICONDA_VERSION="${miniconda_version}" \
|
||||
CONDA_VERSION="${conda_version}"
|
||||
|
||||
# General OS dependencies
|
||||
ENV DEBIAN_FRONTEND noninteractive
|
||||
RUN apt-get update \
|
||||
&& apt-get install -yq --no-install-recommends \
|
||||
wget \
|
||||
apt-utils \
|
||||
unzip \
|
||||
bzip2 \
|
||||
ca-certificates \
|
||||
sudo \
|
||||
locales \
|
||||
fonts-liberation \
|
||||
unattended-upgrades \
|
||||
run-one \
|
||||
nano \
|
||||
libgl1-mesa-glx \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Miniconda installation
|
||||
WORKDIR /tmp
|
||||
RUN wget --quiet https://repo.continuum.io/miniconda/Miniconda3-py38_${MINICONDA_VERSION}-Linux-x86_64.sh && \
|
||||
echo "${miniconda_checksum} *Miniconda3-py38_${MINICONDA_VERSION}-Linux-x86_64.sh" | md5sum -c - && \
|
||||
/bin/bash Miniconda3-py38_${MINICONDA_VERSION}-Linux-x86_64.sh -f -b -p $CONDA_DIR && \
|
||||
rm Miniconda3-py38_${MINICONDA_VERSION}-Linux-x86_64.sh && \
|
||||
# Conda configuration see https://conda.io/projects/conda/en/latest/configuration.html
|
||||
echo "conda ${CONDA_VERSION}" >> $CONDA_DIR/conda-meta/pinned && \
|
||||
conda install --quiet --yes "conda=${CONDA_VERSION}" && \
|
||||
conda install --quiet --yes pip && \
|
||||
conda update --all --quiet --yes && \
|
||||
conda clean --all -f -y && \
|
||||
rm -rf /home/$NB_USER/.cache/yarn
|
||||
|
||||
# Inference code
|
||||
WORKDIR /inference_files
|
||||
COPY . /inference_files/
|
||||
|
||||
RUN /opt/conda/bin/conda init bash
|
||||
RUN conda env create --file environment.yml
|
||||
|
||||
ENTRYPOINT ["/opt/conda/envs/nlrc/bin/python", "./inference/inference.py"]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
name: nlrc
|
||||
|
||||
channels:
|
||||
- defaults
|
||||
- conda-forge
|
||||
- nvidia
|
||||
- pytorch
|
||||
|
||||
dependencies:
|
||||
- python==3.7.9
|
||||
- pip==21.0.1
|
||||
- rasterio==1.1.0
|
||||
- gdal==3.0.2
|
||||
- jupyterlab==3.0.7
|
||||
- numpy==1.20.1
|
||||
- nb_conda_kernels==2.3.1
|
||||
- tqdm==4.56.2
|
||||
- humanfriendly==9.1
|
||||
- geopandas==0.8.1
|
||||
- shapely==1.7.1
|
||||
- scikit-learn==0.23.2
|
||||
- tensorboard>=1.15
|
||||
- tensorflow>=1.15.0
|
||||
- pytorch
|
||||
- torchvision
|
||||
- cudatoolkit=11.1
|
||||
- pip:
|
||||
- azure-storage-blob==12.7.1
|
||||
- opencv-python==4.5.1.48
|
||||
- pandas==1.1.5
|
||||
- matplotlib==3.3.4
|
||||
- scikit-image==0.17.2
|
||||
|
|
@ -0,0 +1,336 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
|
||||
Input a set of raster model output files (same format as the masks) and original polygon JSON files, produces
|
||||
polygonized versions of the model output and compute metrics based on an IoU threshold.
|
||||
|
||||
TODO: we need the outer loop to iterate over tiles, and to draw confusion matrices etc.
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
from collections.abc import Iterable
|
||||
from collections import defaultdict
|
||||
import cv2
|
||||
import rasterio.features
|
||||
import numpy as np
|
||||
import shapely.geometry
|
||||
from shapely.geometry import mapping, Polygon
|
||||
from PIL import Image
|
||||
|
||||
from data.create_label_masks import get_feature_info, read_json
|
||||
|
||||
all_classes = set([1, 2, 3, 4, 5])
|
||||
allowed_classes = set([1, 2, 3, 4]) # 5 is Unclassified, not used during evaluation
|
||||
|
||||
|
||||
def _evaluate_tile(pred_polygons_and_class: list,
|
||||
label_polygons_and_class: list,
|
||||
allowed_classes,
|
||||
iou_threshold: float=0.5):
|
||||
"""
|
||||
Method
|
||||
- For each predicted polygon, we find the maximum value of IoU it has with any ground truth
|
||||
polygon within the tile. This ground truth polygon is its "match".
|
||||
- Using the threshold IoU specified (typically and by default 0.5), if a prediction has
|
||||
overlap above the threshold AND the correct class, it is considered a true positive.
|
||||
All other predictions, no matter what their IOU is with any gt, are false positives.
|
||||
- Note that it is possible for one ground truth polygon to be the match for
|
||||
multiple predictions, especially if the IoU threshold is low, but each prediction
|
||||
only has one matching ground truth polygon.
|
||||
- For ground truth polygon not matched by any predictions, it is a false negative.
|
||||
- Given the TP, FP, and FN counts for each class, we can calculate the precision and recall
|
||||
for each tile *for each class*.
|
||||
|
||||
|
||||
- To plot a confusion table, we output two lists, one for the predictions and one for the
|
||||
ground truth polygons (because the set of polygons to confuse over are not the same...)
|
||||
1. For the list of predictions, each item is associated with the ground truth class of
|
||||
the polygon that it matched, or a "false positive" attribute.
|
||||
2. For the list of ground truth polygons, each is associated with the predicted class of
|
||||
the polygon it matched, or a "false negative" attribute.
|
||||
|
||||
Args:
|
||||
pred_polygons_and_class: list of tuples of shapely Polygon representing the geometry of the prediction,
|
||||
and the predicted class
|
||||
label_polygons_and_class: list of tuples of shapely Polygon representing the ground truth geometry,
|
||||
and the class
|
||||
allowed_classes: which classes should be evaluated
|
||||
iou_threshold: Intersection over union threshold above which a predicted polygon is considered
|
||||
true positive
|
||||
|
||||
Returns:
|
||||
results: a dict of dicts, keyed by the class number, and points to a dict with counts of
|
||||
true positives "tp", false positives "fp", and false negatives "fn"
|
||||
list_preds: a list with one entry for each prediction. Each entry is of the form
|
||||
{'pred': 3, 'label': 3}. This information is for a confusion matrix based on the
|
||||
predicted polygons.
|
||||
list_labels: same as list_preds, while each entry corresponds to a ground truth polygon.
|
||||
The value for 'pred' is None if this polygon is a false negative.
|
||||
"""
|
||||
|
||||
# the matched label polygon's IoU with the pred polygon, and the label polygon's index
|
||||
pred_max_iou_w_label = [(0.0, None)] * len(pred_polygons_and_class)
|
||||
|
||||
for i_pred, (pred_poly, pred_class) in enumerate(pred_polygons_and_class):
|
||||
|
||||
# cannot skip pred_class if it's not in the allowed list, as the list above relies on their indices
|
||||
|
||||
for i_label, (label_poly, label_class) in enumerate(label_polygons_and_class):
|
||||
|
||||
if not pred_poly.is_valid:
|
||||
pred_poly = pred_poly.buffer(0)
|
||||
if not label_poly.is_valid:
|
||||
label_poly = label_poly.buffer(0)
|
||||
|
||||
intersection = pred_poly.intersection(label_poly)
|
||||
union = pred_poly.union(label_poly) # they should not have zero area
|
||||
iou = intersection.area / union.area
|
||||
|
||||
if iou > pred_max_iou_w_label[i_pred][0]:
|
||||
pred_max_iou_w_label[i_pred] = (iou, i_label)
|
||||
|
||||
results = defaultdict(lambda: defaultdict(int)) # class: {tp, fp, fn} counts
|
||||
results[-1] = len(pred_polygons_and_class)
|
||||
i_label_polygons_matched = set()
|
||||
list_preds = []
|
||||
list_labels = []
|
||||
|
||||
for i_pred, (pred_poly, pred_class) in enumerate(pred_polygons_and_class):
|
||||
|
||||
if pred_class not in allowed_classes:
|
||||
continue
|
||||
|
||||
max_iou, matched_i_label = pred_max_iou_w_label[i_pred]
|
||||
|
||||
item = {
|
||||
'pred': pred_class,
|
||||
'label': label_polygons_and_class[matched_i_label][1] if matched_i_label is not None else None
|
||||
}
|
||||
|
||||
if matched_i_label is not None:
|
||||
list_labels.append(item)
|
||||
|
||||
list_preds.append(item)
|
||||
|
||||
|
||||
if max_iou > iou_threshold and label_polygons_and_class[matched_i_label][1] == pred_class:
|
||||
# true positive
|
||||
i_label_polygons_matched.add(matched_i_label)
|
||||
results[pred_class]['tp'] += 1
|
||||
for cls in allowed_classes:
|
||||
if cls != pred_class:
|
||||
results[cls]['tn'] += 1
|
||||
else:
|
||||
# false positive - all other predictions
|
||||
results[pred_class]['fp'] += 1 # note that it is a FP for the prediction's class
|
||||
# print(matched_i_label)
|
||||
##results[matched_i_label]['fn'] += 1 # note that it is a FP for the prediction's class
|
||||
|
||||
# calculate the number of false negatives - how many label polygons are not matched by any predictions
|
||||
for i_label, (label_poly, label_class) in enumerate(label_polygons_and_class):
|
||||
|
||||
if label_class not in allowed_classes:
|
||||
continue
|
||||
|
||||
if i_label not in i_label_polygons_matched:
|
||||
results[label_class]['fn'] += 1
|
||||
list_labels.append({
|
||||
'pred': None,
|
||||
'label': label_class
|
||||
})
|
||||
|
||||
return results, list_preds, list_labels
|
||||
|
||||
|
||||
def get_label_and_pred_polygons_for_tile_json_input(path_label_json, path_pred_mask):
|
||||
"""
|
||||
For each tile, cast the polygons specified in the label JSON file to shapely Polygons, and
|
||||
polygonize the prediction mask.
|
||||
|
||||
Args:
|
||||
path_label_json: path to the label JSON file provided by xBD
|
||||
path_pred_mask: path to the PNG or TIFF mask predicted by the model, where each pixel is one
|
||||
of the allowed classes.
|
||||
|
||||
Returns:
|
||||
pred_polygons_and_class: list of tuples of shapely Polygon representing the geometry of the prediction,
|
||||
and the predicted class
|
||||
label_polygons_and_class: list of tuples of shapely Polygon representing the ground truth geometry,
|
||||
and the class
|
||||
"""
|
||||
assert path_label_json.endswith('.json')
|
||||
|
||||
# get the label polygons
|
||||
|
||||
label_json = read_json(path_label_json)
|
||||
polys = get_feature_info(label_json)
|
||||
|
||||
label_polygons_and_class = [] # tuples of (shapely polygon, damage_class_num)
|
||||
|
||||
for uid, tup in polys.items():
|
||||
poly, damage_class_num = tup # poly is a np.ndarray
|
||||
polygon = Polygon(poly)
|
||||
|
||||
if damage_class_num in allowed_classes:
|
||||
label_polygons_and_class.append((polygon, damage_class_num))
|
||||
|
||||
# polygonize the prediction mask
|
||||
|
||||
# 1. Detect the connected components by all non-background classes to determine the predicted
|
||||
# building blobs first (if we do this per class, a building with some pixels predicted to be
|
||||
# in another class will result in more buildings than connected components)
|
||||
mask_pred = np.asarray(Image.open(path_pred_mask))
|
||||
assert len(mask_pred.shape) == 2, 'mask should be 2D only.'
|
||||
|
||||
background_and_others_mask = np.where(mask_pred > 0, 1, 0).astype(np.int16) # all non-background classes become 1
|
||||
|
||||
# rasterio.features.shapes:
|
||||
# default is 4-connected for connectivity - see https://www.mathworks.com/help/images/pixel-connectivity.html
|
||||
# specify the `mask` parameter, otherwise the background will be returned as a shape
|
||||
connected_components = rasterio.features.shapes(background_and_others_mask, mask=mask_pred > 0)
|
||||
|
||||
polygons = []
|
||||
for component_geojson, pixel_val in connected_components:
|
||||
# reference: https://shapely.readthedocs.io/en/stable/manual.html#python-geo-interface
|
||||
shape = shapely.geometry.shape(component_geojson)
|
||||
assert isinstance(shape, Polygon)
|
||||
if shape.area >20:
|
||||
polygons.append(shape)
|
||||
|
||||
# 2. The majority class for each building blob is assigned to be that building's predicted class.
|
||||
polygons_by_class = []
|
||||
|
||||
for c in all_classes:
|
||||
|
||||
# default is 4-connected for connectivity
|
||||
shapes = rasterio.features.shapes(mask_pred, mask=mask_pred == c)
|
||||
|
||||
for shape_geojson, pixel_val in shapes:
|
||||
shape = shapely.geometry.shape(shape_geojson)
|
||||
assert isinstance(shape, Polygon)
|
||||
polygons_by_class.append((shape, int(pixel_val)))
|
||||
|
||||
# we take the class of the shape with the maximum overlap with the building polygon to be the class of the building - majority vote
|
||||
polygons_max_overlap = [0.0] * len(polygons) # indexed by polygon_i
|
||||
polygons_max_overlap_class = [None] * len(polygons)
|
||||
|
||||
assert isinstance(polygons, list) # need the order constant
|
||||
|
||||
for polygon_i, polygon in enumerate(polygons):
|
||||
for shape, shape_class in polygons_by_class:
|
||||
if not shape.is_valid:
|
||||
shape = shape.buffer(0)
|
||||
if not polygon.is_valid:
|
||||
polygon = polygon.buffer(0)
|
||||
intersection_area = polygon.intersection(shape).area
|
||||
if intersection_area > polygons_max_overlap[polygon_i]:
|
||||
polygons_max_overlap[polygon_i] = intersection_area
|
||||
polygons_max_overlap_class[polygon_i] = shape_class
|
||||
|
||||
pred_polygons_and_class = [] # include all classes
|
||||
|
||||
for polygon_i, (max_overlap_area, clss) in enumerate(zip(polygons_max_overlap, polygons_max_overlap_class)):
|
||||
pred_polygons_and_class.append(
|
||||
(polygons[polygon_i], clss)
|
||||
)
|
||||
return pred_polygons_and_class, label_polygons_and_class
|
||||
|
||||
def get_label_and_pred_polygons_for_tile_mask_input(label_mask, path_pred_mask):
|
||||
"""
|
||||
For each tile, polygonize the prediction and label mask.
|
||||
|
||||
Args:
|
||||
label_mask: array that contains label mask
|
||||
path_pred_mask: path to the PNG or TIFF mask predicted by the model, where each pixel is one
|
||||
of the allowed classes.
|
||||
|
||||
Returns:
|
||||
pred_polygons_and_class: list of tuples of shapely Polygon representing the geometry of the prediction,
|
||||
and the predicted class
|
||||
label_polygons_and_class: list of tuples of shapely Polygon representing the ground truth geometry,
|
||||
and the class
|
||||
"""
|
||||
# polygonize the label mask
|
||||
# mask_label = np.asarray(Image.open(path_label_mask))
|
||||
label_polygons_and_class = [] # tuples of (shapely polygon, damage_class_num)
|
||||
# print('label_mask')
|
||||
# print(label_mask.shape)
|
||||
for c in all_classes:
|
||||
|
||||
# default is 4-connected for connectivity
|
||||
shapes = rasterio.features.shapes(label_mask, mask=label_mask == c)
|
||||
|
||||
for shape_geojson, pixel_val in shapes:
|
||||
shape = shapely.geometry.shape(shape_geojson)
|
||||
assert isinstance(shape, Polygon)
|
||||
label_polygons_and_class.append((shape, int(pixel_val)))
|
||||
|
||||
|
||||
# polygonize the prediction mask
|
||||
|
||||
# 1. Detect the connected components by all non-background classes to determine the predicted
|
||||
# building blobs first (if we do this per class, a building with some pixels predicted to be
|
||||
# in another class will result in more buildings than connected components)
|
||||
mask_pred = np.asarray(Image.open(path_pred_mask))
|
||||
# mask_pred = cv2.medianBlur(mask_pred, 17)
|
||||
|
||||
# print('mask_pred')
|
||||
# print(mask_pred.shape)
|
||||
assert len(mask_pred.shape) == 2, 'mask should be 2D only.'
|
||||
|
||||
background_and_others_mask = np.where(mask_pred > 0, 1, 0).astype(np.int16) # all non-background classes become 1
|
||||
|
||||
# rasterio.features.shapes:
|
||||
# default is 4-connected for connectivity - see https://www.mathworks.com/help/images/pixel-connectivity.html
|
||||
# specify the `mask` parameter, otherwise the background will be returned as a shape
|
||||
connected_components = rasterio.features.shapes(background_and_others_mask, mask=mask_pred > 0)
|
||||
|
||||
polygons = []
|
||||
for component_geojson, pixel_val in connected_components:
|
||||
# reference: https://shapely.readthedocs.io/en/stable/manual.html#python-geo-interface
|
||||
shape = shapely.geometry.shape(component_geojson)
|
||||
assert isinstance(shape, Polygon)
|
||||
if shape.area >20:
|
||||
polygons.append(shape)
|
||||
|
||||
# 2. The majority class for each building blob is assigned to be that building's predicted class.
|
||||
polygons_by_class = []
|
||||
|
||||
for c in all_classes:
|
||||
|
||||
# default is 4-connected for connectivity
|
||||
shapes = rasterio.features.shapes(mask_pred, mask=mask_pred == c)
|
||||
|
||||
for shape_geojson, pixel_val in shapes:
|
||||
shape = shapely.geometry.shape(shape_geojson)
|
||||
assert isinstance(shape, Polygon)
|
||||
polygons_by_class.append((shape, int(pixel_val)))
|
||||
|
||||
# we take the class of the shape with the maximum overlap with the building polygon to be the class of the building - majority vote
|
||||
polygons_max_overlap = [0.0] * len(polygons) # indexed by polygon_i
|
||||
polygons_max_overlap_class = [None] * len(polygons)
|
||||
|
||||
assert isinstance(polygons, list) # need the order constant
|
||||
|
||||
for polygon_i, polygon in enumerate(polygons):
|
||||
for shape, shape_class in polygons_by_class:
|
||||
if not shape.is_valid:
|
||||
shape = shape.buffer(0)
|
||||
if not polygon.is_valid:
|
||||
polygon = polygon.buffer(0)
|
||||
intersection_area = polygon.intersection(shape).area
|
||||
if intersection_area > polygons_max_overlap[polygon_i]:
|
||||
polygons_max_overlap[polygon_i] = intersection_area
|
||||
polygons_max_overlap_class[polygon_i] = shape_class
|
||||
|
||||
pred_polygons_and_class = [] # include all classes
|
||||
|
||||
for polygon_i, (max_overlap_area, clss) in enumerate(zip(polygons_max_overlap, polygons_max_overlap_class)):
|
||||
pred_polygons_and_class.append(
|
||||
(polygons[polygon_i], clss)
|
||||
)
|
||||
return pred_polygons_and_class, label_polygons_and_class
|
После Ширина: | Высота: | Размер: 7.6 KiB |
После Ширина: | Высота: | Размер: 7.1 KiB |
После Ширина: | Высота: | Размер: 3.4 KiB |
После Ширина: | Высота: | Размер: 396 KiB |
После Ширина: | Высота: | Размер: 2.0 MiB |
После Ширина: | Высота: | Размер: 1.9 MiB |
После Ширина: | Высота: | Размер: 189 KiB |
|
@ -0,0 +1,344 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
import os
|
||||
import sys
|
||||
PACKAGE_PARENT = '..'
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
|
||||
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))
|
||||
|
||||
import json
|
||||
import torch
|
||||
import argparse
|
||||
import logging
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from datetime import datetime
|
||||
from torchvision import transforms
|
||||
from data.raster_label_visualizer import RasterLabelVisualizer
|
||||
from models.end_to_end_Siam_UNet import SiamUnet
|
||||
from utils.datasets import DisasterDataset
|
||||
from eval.eval_building_level import _evaluate_tile, get_label_and_pred_polygons_for_tile_mask_input, allowed_classes
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
|
||||
config = {'labels_dmg': [0, 1, 2, 3, 4],
|
||||
'labels_bld': [0, 1]}
|
||||
|
||||
parser = argparse.ArgumentParser(description='Building Damage Assessment Inference')
|
||||
parser.add_argument('--output_dir', type=str, required=True, help='Path to an empty directory where outputs will be saved. This directory will be created if it does not exist.')
|
||||
parser.add_argument('--data_img_dir', type=str, required=True, help='Path to a directory that contains input images.')
|
||||
parser.add_argument('--data_inference_dict', type=str, required=True, help='Path to a json file that contains a dict of path information for each individual image to be used for inference.')
|
||||
parser.add_argument('--data_mean_stddev', type=str, required=True, help='Path to a json file that contains mean and stddev for each tile used for normalization of each image patch.')
|
||||
parser.add_argument('--label_map_json', type=str, required=True, help='Path to a json file that contains information between actual labels and encoded labels for classification task.')
|
||||
parser.add_argument('--model', type=str, required=True, help='Path to a trained model to be used for inference.')
|
||||
parser.add_argument('--gpu', type=str, default="cuda:0", help='GPU to run on.')
|
||||
parser.add_argument('--experiment_name', type=str, default='infer', help='Choose a name for each inference test folder.')
|
||||
parser.add_argument('--num_chips_to_viz', type=int, default=1, help='Number of patches to visualize in tensorboard for monitoring.')
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(stream=sys.stdout,
|
||||
level=logging.INFO,
|
||||
format='{asctime} {levelname} {message}',
|
||||
style='{',
|
||||
datefmt='%Y-%m-%d %H:%M:%S')
|
||||
logging.info(f'Using PyTorch version {torch.__version__}.')
|
||||
device = torch.device(args.gpu if torch.cuda.is_available() else "cpu")
|
||||
logging.info(f'Using device: {device}.')
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
eval_results_val_dmg = pd.DataFrame(columns=['class', 'precision', 'recall', 'f1', 'accuracy'])
|
||||
eval_results_val_dmg_building_level = pd.DataFrame(columns=['class', 'precision', 'recall', 'f1', 'accuracy'])
|
||||
eval_results_val_bld = pd.DataFrame(columns=['class', 'precision', 'recall', 'f1', 'accuracy'])
|
||||
|
||||
# set up directories
|
||||
logger_dir = os.path.join(args.output_dir, args.experiment_name, 'logs')
|
||||
os.makedirs(logger_dir, exist_ok=True)
|
||||
|
||||
evals_dir = os.path.join(args.output_dir, args.experiment_name, 'evals')
|
||||
os.makedirs(evals_dir, exist_ok=True)
|
||||
|
||||
output_dir = os.path.join(args.output_dir, args.experiment_name, 'output')
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# initialize logger instances
|
||||
logger_test= SummaryWriter(log_dir=logger_dir)
|
||||
|
||||
# load test data
|
||||
global test_dataset, test_loader, labels_set_dmg, labels_set_bld, viz
|
||||
label_map = load_json_files(args.label_map_json)
|
||||
viz = RasterLabelVisualizer(label_map=label_map)
|
||||
test_dataset, samples_idx_list = load_dataset()
|
||||
|
||||
labels_set_dmg = config['labels_dmg']
|
||||
labels_set_bld = config['labels_bld']
|
||||
|
||||
#load model and its state from the given checkpoint
|
||||
model = SiamUnet()
|
||||
checkpoint_path = args.model
|
||||
if checkpoint_path and os.path.isfile(checkpoint_path):
|
||||
logging.info('Loading checkpoint from {}'.format(checkpoint_path))
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
model = model.to(device=device)
|
||||
logging.info(f"Using checkpoint at epoch {checkpoint['epoch']}, val f1 is {checkpoint.get('val_f1_avg', 'Not Available')}")
|
||||
else:
|
||||
logging.info('No valid checkpoint is provided.')
|
||||
return
|
||||
|
||||
# inference
|
||||
logging.info(f'Start model inference ...')
|
||||
inference_start_time = datetime.now()
|
||||
confusion_mtrx_df_val_dmg, confusion_mtrx_df_val_bld, confusion_mtrx_df_val_dmg_building_level = validate(test_dataset, model, logger_test, samples_idx_list, evals_dir)
|
||||
inference_duration = datetime.now() - inference_start_time
|
||||
logging.info((f'inference duration is {inference_duration.total_seconds()} seconds'))
|
||||
|
||||
logging.info(f'compute actual metrics for model evaluation based on validation set ...')
|
||||
|
||||
# damage level eval validation (pixelwise)
|
||||
eval_results_val_dmg = compute_eval_metrics(labels_set_dmg, confusion_mtrx_df_val_dmg, eval_results_val_dmg)
|
||||
f1_harmonic_mean = 0
|
||||
metrics = 'f1'
|
||||
for index, row in eval_results_val_dmg.iterrows():
|
||||
if (int(row['class']) in labels_set_dmg[1:]) & (metrics == 'f1'):
|
||||
f1_harmonic_mean += 1.0/(row[metrics]+1e-10)
|
||||
f1_harmonic_mean = 4.0/f1_harmonic_mean
|
||||
eval_results_val_dmg = eval_results_val_dmg.append({'class':'harmonic-mean-of-all', 'precision':'-', 'recall':'-', 'f1':f1_harmonic_mean, 'accuracy':'-'}, ignore_index=True)
|
||||
|
||||
# damage level eval validation (building-level)
|
||||
eval_results_val_dmg_building_level = compute_eval_metrics(labels_set_dmg, confusion_mtrx_df_val_dmg_building_level, eval_results_val_dmg_building_level)
|
||||
f1_harmonic_mean = 0
|
||||
metrics = 'f1'
|
||||
for index, row in eval_results_val_dmg_building_level.iterrows():
|
||||
if (int(row['class']) in labels_set_dmg[1:]) & (metrics == 'f1'):
|
||||
f1_harmonic_mean += 1.0/(row[metrics]+1e-10)
|
||||
f1_harmonic_mean = 4.0/f1_harmonic_mean
|
||||
eval_results_val_dmg_building_level = eval_results_val_dmg_building_level.append({'class':'harmonic-mean-of-all', 'precision':'-', 'recall':'-', 'f1':f1_harmonic_mean, 'accuracy':'-'}, ignore_index=True)
|
||||
|
||||
|
||||
# bld detection eval validation (pixelwise)
|
||||
eval_results_val_bld = compute_eval_metrics(labels_set_bld, confusion_mtrx_df_val_bld, eval_results_val_bld)
|
||||
|
||||
# save confusion metrices
|
||||
confusion_mtrx_df_val_bld.to_csv(os.path.join(evals_dir, 'confusion_mtrx_bld.csv'), index=False)
|
||||
confusion_mtrx_df_val_dmg.to_csv(os.path.join(evals_dir, 'confusion_mtrx_dmg.csv'), index=False)
|
||||
confusion_mtrx_df_val_dmg_building_level.to_csv(os.path.join(evals_dir, 'confusion_mtrx_dmg_building_level.csv'), index=False)
|
||||
|
||||
# save evalution metrics
|
||||
eval_results_val_bld.to_csv(os.path.join(evals_dir, 'eval_results_bld.csv'), index=False)
|
||||
eval_results_val_dmg.to_csv(os.path.join(evals_dir, 'eval_results_dmg.csv'), index=False)
|
||||
eval_results_val_dmg_building_level.to_csv(os.path.join(evals_dir, 'eval_results_dmg_building_level.csv'), index=False)
|
||||
|
||||
logging.info('Done')
|
||||
|
||||
return
|
||||
|
||||
def validate(loader, model, logger_test, samples_idx_list, evals_dir):
|
||||
|
||||
"""
|
||||
Evaluate the model on dataset of the loader
|
||||
"""
|
||||
softmax = torch.nn.Softmax(dim=1)
|
||||
model.eval() # put model to evaluation mode
|
||||
confusion_mtrx_df_val_dmg = pd.DataFrame(columns=['img_idx', 'class', 'true_pos', 'true_neg', 'false_pos', 'false_neg', 'total'])
|
||||
confusion_mtrx_df_val_bld = pd.DataFrame(columns=['img_idx', 'class', 'true_pos', 'true_neg', 'false_pos', 'false_neg', 'total'])
|
||||
confusion_mtrx_df_val_dmg_building_level = pd.DataFrame(columns=['img_idx', 'class', 'true_pos', 'true_neg', 'false_pos', 'false_neg', 'total'])
|
||||
|
||||
with torch.no_grad():
|
||||
for img_idx, data in enumerate(tqdm(loader)): # assume batch size is 1
|
||||
c = data['pre_image'].size()[0]
|
||||
h = data['pre_image'].size()[1]
|
||||
w = data['pre_image'].size()[2]
|
||||
|
||||
x_pre = data['pre_image'].reshape(1, c, h, w).to(device=device)
|
||||
x_post = data['post_image'].reshape(1, c, h, w).to(device=device)
|
||||
y_seg = data['building_mask'].to(device=device)
|
||||
y_cls = data['damage_mask'].to(device=device)
|
||||
|
||||
scores = model(x_pre, x_post)
|
||||
|
||||
# compute accuracy for segmenation model on pre_ images
|
||||
preds_seg_pre = torch.argmax(softmax(scores[0]), dim=1)
|
||||
|
||||
# modify damage prediction based on UNet arm predictions
|
||||
for c in range(0,scores[2].shape[1]):
|
||||
scores[2][:,c,:,:] = torch.mul(scores[2][:,c,:,:], preds_seg_pre)
|
||||
preds_cls = torch.argmax(softmax(scores[2]), dim=1)
|
||||
|
||||
path_pred_mask = data['preds_img_dir'] +'.png'
|
||||
logging.info('save png image for damage level predictions: ' + path_pred_mask)
|
||||
im = Image.fromarray(preds_cls.cpu().numpy()[0,:,:].astype(np.uint8))
|
||||
if not os.path.exists(os.path.split(data['preds_img_dir'])[0]):
|
||||
os.makedirs(os.path.split(data['preds_img_dir'])[0])
|
||||
im.save(path_pred_mask)
|
||||
logging.info(f'saved image size: {preds_cls.size()}')
|
||||
|
||||
# compute building-level confusion metrics
|
||||
pred_polygons_and_class, label_polygons_and_class = get_label_and_pred_polygons_for_tile_mask_input(y_cls.cpu().numpy().astype(np.uint8), path_pred_mask)
|
||||
results, list_preds, list_labels = _evaluate_tile(pred_polygons_and_class, label_polygons_and_class, allowed_classes, 0.1)
|
||||
total_objects = results[-1]
|
||||
for label_class in results:
|
||||
if label_class != -1:
|
||||
true_pos_cls = results[label_class]['tp'] if 'tp' in results[label_class].keys() else 0
|
||||
true_neg_cls = results[label_class]['tn'] if 'tn' in results[label_class].keys() else 0
|
||||
false_pos_cls = results[label_class]['fp'] if 'fp' in results[label_class].keys() else 0
|
||||
false_neg_cls = results[label_class]['fn'] if 'fn' in results[label_class].keys() else 0
|
||||
confusion_mtrx_df_val_dmg_building_level = confusion_mtrx_df_val_dmg_building_level.append({'img_idx':img_idx, 'class':label_class, 'true_pos':true_pos_cls, 'true_neg':true_neg_cls, 'false_pos':false_pos_cls, 'false_neg':false_neg_cls, 'total':total_objects}, ignore_index=True)
|
||||
|
||||
# compute comprehensive pixel-level comfusion metrics
|
||||
confusion_mtrx_df_val_dmg = compute_confusion_mtrx(confusion_mtrx_df_val_dmg, img_idx, labels_set_dmg, preds_cls, y_cls, y_seg)
|
||||
confusion_mtrx_df_val_bld = compute_confusion_mtrx(confusion_mtrx_df_val_bld, img_idx, labels_set_bld, preds_seg_pre, y_seg, [])
|
||||
|
||||
# add viz results to logger
|
||||
if img_idx in samples_idx_list:
|
||||
prepare_for_vis(img_idx, logger_test, model, device, softmax)
|
||||
|
||||
return confusion_mtrx_df_val_dmg, confusion_mtrx_df_val_bld, confusion_mtrx_df_val_dmg_building_level
|
||||
|
||||
def load_dataset():
|
||||
splits = load_json_files(args.data_inference_dict)
|
||||
data_mean_stddev = load_json_files(args.data_mean_stddev)
|
||||
test_ls = []
|
||||
for item, val in splits.items():
|
||||
test_ls += val['test']
|
||||
test_dataset = DisasterDataset(args.data_img_dir, test_ls, data_mean_stddev, transform=False, normalize=True)
|
||||
logging.info('xBD_disaster_dataset test length: {}'.format(len(test_dataset)))
|
||||
assert len(test_dataset) > 0
|
||||
samples_idx_list = get_sample_images(test_dataset)
|
||||
logging.info('items selected for viz: {}'.format(samples_idx_list))
|
||||
return test_dataset, samples_idx_list
|
||||
|
||||
def compute_eval_metrics(labels_set, confusion_mtrx_df, eval_results):
|
||||
for cls in labels_set:
|
||||
class_idx = (confusion_mtrx_df['class']==cls)
|
||||
precision = confusion_mtrx_df.loc[class_idx,'true_pos'].sum()/(confusion_mtrx_df.loc[class_idx,'true_pos'].sum() + confusion_mtrx_df.loc[class_idx,'false_pos'].sum() + sys.float_info.epsilon)
|
||||
recall = confusion_mtrx_df.loc[class_idx,'true_pos'].sum()/(confusion_mtrx_df.loc[class_idx,'true_pos'].sum() + confusion_mtrx_df.loc[class_idx,'false_neg'].sum() + sys.float_info.epsilon)
|
||||
f1 = 2 * (precision * recall)/(precision + recall + sys.float_info.epsilon)
|
||||
accuracy = (confusion_mtrx_df.loc[class_idx,'true_pos'].sum() + confusion_mtrx_df.loc[class_idx,'true_neg'].sum())/(confusion_mtrx_df.loc[class_idx,'total'].sum() + sys.float_info.epsilon)
|
||||
eval_results = eval_results.append({'class':cls, 'precision':precision, 'recall':recall, 'f1':f1, 'accuracy':accuracy}, ignore_index=True)
|
||||
return eval_results
|
||||
|
||||
def compute_confusion_mtrx(confusion_mtrx_df, img_idx, labels_set, y_preds, y_true, y_true_bld_mask):
|
||||
for cls in labels_set[1:]:
|
||||
confusion_mtrx_df = compute_confusion_mtrx_class(confusion_mtrx_df, img_idx, labels_set, y_preds, y_true, y_true_bld_mask, cls)
|
||||
return confusion_mtrx_df
|
||||
|
||||
def compute_confusion_mtrx_class(confusion_mtrx_df, img_idx, labels_set, y_preds, y_true, y_true_bld_mask, cls):
|
||||
|
||||
y_true_binary = y_true.detach().clone()
|
||||
y_preds_binary = y_preds.detach().clone()
|
||||
|
||||
if len(labels_set) > 2:
|
||||
# convert to 0/1
|
||||
y_true_binary[y_true_binary != cls] = -1
|
||||
y_preds_binary[y_preds_binary != cls] = -1
|
||||
|
||||
y_true_binary[y_true_binary == cls] = 1
|
||||
y_preds_binary[y_preds_binary == cls] = 1
|
||||
|
||||
y_true_binary[y_true_binary == -1] = 0
|
||||
y_preds_binary[y_preds_binary == -1] = 0
|
||||
|
||||
# compute confusion metric
|
||||
true_pos_cls = ((y_true_binary == y_preds_binary) & (y_true_binary == 1) & (y_true_bld_mask == 1)).float().sum().item()
|
||||
false_neg_cls = ((y_true_binary != y_preds_binary) & (y_true_binary == 1) & (y_true_bld_mask == 1)).float().sum().item()
|
||||
true_neg_cls = ((y_true_binary == y_preds_binary) & (y_true_binary == 0) & (y_true_bld_mask == 1)).float().sum().item()
|
||||
false_pos_cls = ((y_true_binary != y_preds_binary) & (y_true_binary == 0) & (y_true_bld_mask == 1)).float().sum().item()
|
||||
|
||||
# compute total pixels
|
||||
total_pixels = y_true_bld_mask.float().sum().item()
|
||||
|
||||
else:
|
||||
|
||||
# compute confusion metric
|
||||
true_pos_cls = ((y_true_binary == y_preds_binary) & (y_true_binary == 1)).float().sum().item()
|
||||
false_neg_cls = ((y_true_binary != y_preds_binary) & (y_true_binary == 1)).float().sum().item()
|
||||
true_neg_cls = ((y_true_binary == y_preds_binary) & (y_true_binary == 0)).float().sum().item()
|
||||
false_pos_cls = ((y_true_binary != y_preds_binary) & (y_true_binary == 0)).float().sum().item()
|
||||
|
||||
# compute total pixels
|
||||
total_pixels = 1
|
||||
for item in y_true_binary.size():
|
||||
total_pixels *= item
|
||||
|
||||
confusion_mtrx_df = confusion_mtrx_df.append({'class':cls, 'img_idx':img_idx, 'true_pos':true_pos_cls, 'true_neg':true_neg_cls, 'false_pos':false_pos_cls, 'false_neg':false_neg_cls, 'total':total_pixels}, ignore_index=True)
|
||||
return confusion_mtrx_df
|
||||
|
||||
def prepare_for_vis(item, logger, model, device, softmax):
|
||||
|
||||
iteration = 0
|
||||
data = test_dataset[item]
|
||||
c = data['pre_image'].size()[0]
|
||||
h = data['pre_image'].size()[1]
|
||||
w = data['pre_image'].size()[2]
|
||||
|
||||
pre = data['pre_image'].reshape(1, c, h, w)
|
||||
post = data['post_image'].reshape(1, c, h, w)
|
||||
|
||||
scores = model(pre.to(device=device), post.to(device=device))
|
||||
preds_seg_pre = torch.argmax(softmax(scores[0]), dim=1)
|
||||
preds_seg_post = torch.argmax(softmax(scores[1]), dim=1)
|
||||
|
||||
# modify damage prediction based on UNet arm predictions
|
||||
for c in range(0,scores[2].shape[1]):
|
||||
scores[2][:,c,:,:] = torch.mul(scores[2][:,c,:,:], preds_seg_pre)
|
||||
|
||||
# visualize predictions and add to tensorboard
|
||||
tag = 'pr_bld_mask_pre_test_id_' + str(item)
|
||||
logger.add_image(tag, preds_seg_pre, iteration, dataformats='CHW')
|
||||
|
||||
tag = 'pr_bld_mask_post_test_id_' + str(item)
|
||||
logger.add_image(tag, preds_seg_post, iteration, dataformats='CHW')
|
||||
|
||||
tag = 'pr_dmg_mask_test_id_' + str(item)
|
||||
im, buf = viz.show_label_raster(torch.argmax(softmax(scores[2]), dim=1).cpu().numpy(), size=(5, 5))
|
||||
preds_cls = transforms.ToTensor()(transforms.ToPILImage()(np.array(im)).convert("RGB"))
|
||||
logger.add_image(tag, preds_cls, iteration, dataformats='CHW')
|
||||
|
||||
# visualize GT and add to tensorboard
|
||||
pre = data['pre_image']
|
||||
tag = 'gt_img_pre_test_id_' + str(item)
|
||||
logger.add_image(tag, data['pre_image_orig'], iteration, dataformats='CHW')
|
||||
|
||||
post = data['post_image']
|
||||
tag = 'gt_img_post_test_id_' + str(item)
|
||||
logger.add_image(tag, data['post_image_orig'], iteration, dataformats='CHW')
|
||||
|
||||
gt_seg = data['building_mask'].reshape(1, h, w)
|
||||
tag = 'gt_bld_mask_test_id_' + str(item)
|
||||
logger.add_image(tag, gt_seg, iteration, dataformats='CHW')
|
||||
|
||||
im, buf = viz.show_label_raster(np.array(data['damage_mask']), size=(5, 5))
|
||||
gt_cls = transforms.ToTensor()(transforms.ToPILImage()(np.array(im)).convert("RGB"))
|
||||
tag = 'gt_dmg_mask_test_id_' + str(item)
|
||||
logger.add_image(tag, gt_cls, iteration, dataformats='CHW')
|
||||
return
|
||||
|
||||
def get_sample_images(dataset):
|
||||
|
||||
assert len(dataset) > args.num_chips_to_viz
|
||||
|
||||
samples_idx_list = []
|
||||
from random import randint
|
||||
for sample_idx in range(0, args.num_chips_to_viz):
|
||||
value = randint(0, len(dataset))
|
||||
samples_idx_list.append(value)
|
||||
|
||||
return samples_idx_list
|
||||
|
||||
def load_json_files(json_filename):
|
||||
with open(json_filename) as f:
|
||||
file_content = json.load(f)
|
||||
return file_content
|
||||
|
||||
def dump_json_files(json_filename, my_dict):
|
||||
with open(json_filename, 'w') as f:
|
||||
json.dump(my_dict, f, indent=4)
|
||||
return
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,148 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
from PIL import Image
|
||||
from glob import glob
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
import torchvision.transforms.functional as TF
|
||||
from torchvision import transforms
|
||||
import random
|
||||
import os
|
||||
class DisasterDataset(Dataset):
|
||||
def __init__(self, data_dir, data_dir_ls, data_mean_stddev, transform:bool, normalize:bool):
|
||||
|
||||
self.data_dir = data_dir
|
||||
self.dataset_sub_dir = data_dir_ls
|
||||
self.data_mean_stddev = data_mean_stddev
|
||||
self.transform = transform
|
||||
self.normalize = normalize
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset_sub_dir)
|
||||
|
||||
@classmethod
|
||||
def apply_transform(self, mask, pre_img, post_img, damage_class):
|
||||
'''
|
||||
apply tranformation functions on PIL images
|
||||
'''
|
||||
if random.random() > 0.5:
|
||||
# Resize
|
||||
img_h = pre_img.size[0]
|
||||
img_w = pre_img.size[1]
|
||||
|
||||
resize = transforms.Resize(size=(int(round(1.016*img_h)), int(round(1.016*img_w))))
|
||||
mask = resize(mask)
|
||||
pre_img = resize(pre_img)
|
||||
post_img = resize(post_img)
|
||||
damage_class = resize(damage_class)
|
||||
|
||||
# Random crop
|
||||
i, j, h, w = transforms.RandomCrop.get_params(pre_img, output_size=(img_h, img_w))
|
||||
mask = TF.crop(mask, i, j, h, w)
|
||||
pre_img = TF.crop(pre_img, i, j, h, w)
|
||||
post_img = TF.crop(post_img, i, j, h, w)
|
||||
damage_class = TF.crop(damage_class, i, j, h, w)
|
||||
|
||||
# Random horizontal flipping
|
||||
if random.random() > 0.5:
|
||||
mask = TF.hflip(mask)
|
||||
pre_img = TF.hflip(pre_img)
|
||||
post_img = TF.hflip(post_img)
|
||||
damage_class = TF.hflip(damage_class)
|
||||
|
||||
# Random vertical flipping
|
||||
if random.random() > 0.5:
|
||||
mask = TF.vflip(mask)
|
||||
pre_img = TF.vflip(pre_img)
|
||||
post_img = TF.vflip(post_img)
|
||||
damage_class = TF.vflip(damage_class)
|
||||
|
||||
return mask, pre_img, post_img, damage_class
|
||||
|
||||
def __getitem__(self, i):
|
||||
|
||||
imgs_dir = os.path.join(self.data_dir ,self.dataset_sub_dir[i].replace('labels', 'images'))
|
||||
imgs_dir_tile = self.dataset_sub_dir[i].replace('labels', 'images')
|
||||
masks_dir = os.path.join(self.data_dir, self.dataset_sub_dir[i].replace('labels', 'targets_border2'))
|
||||
preds_dir = os.path.join(self.data_dir ,self.dataset_sub_dir[i].replace('labels', 'predictions'))
|
||||
|
||||
idx = imgs_dir
|
||||
|
||||
img_suffix = '_' + imgs_dir.split('_')[-1]
|
||||
img_suffix_tile = '_' + imgs_dir_tile.split('_')[-1]
|
||||
mask_suffix = '_' + masks_dir.split('_')[-1]
|
||||
|
||||
pre_img_tile_name = imgs_dir_tile[0:-1*(len(img_suffix_tile))] + '_pre_disaster'
|
||||
pre_img_file_name = imgs_dir[0:-1*(len(img_suffix))] + '_pre_disaster' + img_suffix
|
||||
pre_img_file = glob(pre_img_file_name + '.*')
|
||||
|
||||
mask_file_name = masks_dir[0:-1*(len(mask_suffix))] + '_pre_disaster_b2' + mask_suffix
|
||||
mask_file = glob(mask_file_name + '.*')
|
||||
|
||||
post_img_tile_name = pre_img_tile_name.replace('pre', 'post')
|
||||
post_img_file_name = pre_img_file_name.replace('pre', 'post')
|
||||
post_img_file = glob(post_img_file_name + '.*')
|
||||
|
||||
damage_class_file_name = mask_file_name.replace('pre', 'post')
|
||||
damage_class_file = glob(damage_class_file_name + '.*')
|
||||
|
||||
assert len(mask_file) == 1, \
|
||||
f'Either no mask or multiple masks found for the ID {idx}: {mask_file_name}'
|
||||
assert len(pre_img_file) == 1, \
|
||||
f'Either no image or multiple images found for the ID {idx}: {pre_img_file_name}'
|
||||
assert len(post_img_file) == 1, \
|
||||
f'Either no post disaster image or multiple images found for the ID {idx}: {post_img_file_name}'
|
||||
assert len(damage_class_file) == 1, \
|
||||
f'Either no damage class image or multiple images found for the ID {idx}: {damage_class_file_name}'
|
||||
|
||||
mask = Image.open(mask_file[0])
|
||||
pre_img = Image.open(pre_img_file[0])
|
||||
post_img = Image.open(post_img_file[0])
|
||||
damage_class = Image.open(damage_class_file[0])
|
||||
|
||||
assert pre_img.size == mask.size, \
|
||||
f'Image and building mask {idx} should be the same size, but are {pre_img.size} and {mask.size}'
|
||||
assert pre_img.size == damage_class.size, \
|
||||
f'Image and damage classes mask {idx} should be the same size, but are {pre_img.size} and {damage_class.size}'
|
||||
assert pre_img.size == post_img.size, \
|
||||
f'Pre_ & _post disaster Images {idx} should be the same size, but are {pre_img.size} and {post_img.size}'
|
||||
|
||||
if self.transform is True:
|
||||
mask, pre_img, post_img, damage_class = self.apply_transform(mask, pre_img, post_img, damage_class)
|
||||
|
||||
# copy original image for viz
|
||||
pre_img_orig = pre_img
|
||||
post_img_orig = post_img
|
||||
|
||||
if self.normalize is True:
|
||||
# normalize the images based on a tilewise mean & std dev --> pre_
|
||||
mean_pre = self.data_mean_stddev[pre_img_tile_name][0]
|
||||
stddev_pre = self.data_mean_stddev[pre_img_tile_name][1]
|
||||
norm_pre = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean_pre, std=stddev_pre)
|
||||
])
|
||||
pre_img = norm_pre(np.array(pre_img).astype(dtype='float64')/255.0)
|
||||
|
||||
# normalize the images based on a tilewise mean & std dev --> post_
|
||||
mean_post = self.data_mean_stddev[post_img_tile_name][0]
|
||||
stddev_post = self.data_mean_stddev[post_img_tile_name][1]
|
||||
norm_post = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean_post, std=stddev_post)
|
||||
])
|
||||
post_img = norm_post(np.array(post_img).astype(dtype='float64')/255.0)
|
||||
|
||||
# convert eveything to arrays
|
||||
pre_img = np.array(pre_img)
|
||||
post_img = np.array(post_img)
|
||||
mask = np.array(mask)
|
||||
damage_class = np.array(damage_class)
|
||||
|
||||
# replace non-classified pixels with background
|
||||
damage_class = np.where(damage_class==5, 0, damage_class)
|
||||
|
||||
return {'pre_image': torch.from_numpy(pre_img).type(torch.FloatTensor), 'post_image': torch.from_numpy(post_img).type(torch.FloatTensor), 'building_mask': torch.from_numpy(mask).type(torch.LongTensor), 'damage_mask': torch.from_numpy(damage_class).type(torch.LongTensor), 'pre_image_orig': transforms.ToTensor()(pre_img_orig), 'post_image_orig': transforms.ToTensor()(post_img_orig), 'img_file_idx':imgs_dir[0:-1*(len(img_suffix))].split('/')[-1] + img_suffix, 'preds_img_dir':preds_dir}
|
|
@ -0,0 +1,155 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from collections import OrderedDict
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class SiamUnet(nn.Module):
|
||||
|
||||
def __init__(self, in_channels=3, out_channels_s=2, out_channels_c=5, init_features=16):
|
||||
super(SiamUnet, self).__init__()
|
||||
|
||||
features = init_features
|
||||
|
||||
# UNet layers
|
||||
self.encoder1 = SiamUnet._block(in_channels, features, name="enc1")
|
||||
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.encoder2 = SiamUnet._block(features, features * 2, name="enc2")
|
||||
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.encoder3 = SiamUnet._block(features * 2, features * 4, name="enc3")
|
||||
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.encoder4 = SiamUnet._block(features * 4, features * 8, name="enc4")
|
||||
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
||||
self.bottleneck = SiamUnet._block(features * 8, features * 16, name="bottleneck")
|
||||
|
||||
self.upconv4 = nn.ConvTranspose2d(features * 16, features * 8, kernel_size=2, stride=2)
|
||||
self.decoder4 = SiamUnet._block((features * 8) * 2, features * 8, name="dec4")
|
||||
self.upconv3 = nn.ConvTranspose2d(features * 8, features * 4, kernel_size=2, stride=2)
|
||||
self.decoder3 = SiamUnet._block((features * 4) * 2, features * 4, name="dec3")
|
||||
self.upconv2 = nn.ConvTranspose2d(features * 4, features * 2, kernel_size=2, stride=2)
|
||||
self.decoder2 = SiamUnet._block((features * 2) * 2, features * 2, name="dec2")
|
||||
self.upconv1 = nn.ConvTranspose2d(features * 2, features, kernel_size=2, stride=2)
|
||||
self.decoder1 = SiamUnet._block(features * 2, features, name="dec1")
|
||||
|
||||
self.conv_s = nn.Conv2d(in_channels=features, out_channels=out_channels_s, kernel_size=1)
|
||||
|
||||
# Siamese classifier layers
|
||||
self.upconv4_c = nn.ConvTranspose2d(features * 16, features * 8, kernel_size=2, stride=2)
|
||||
self.conv4_c = SiamUnet._block(features * 16, features * 16, name="conv4")
|
||||
|
||||
self.upconv3_c = nn.ConvTranspose2d(features * 16, features * 4, kernel_size=2, stride=2)
|
||||
self.conv3_c = SiamUnet._block(features * 8, features * 8, name="conv3")
|
||||
|
||||
self.upconv2_c = nn.ConvTranspose2d(features * 8, features * 2, kernel_size=2, stride=2)
|
||||
self.conv2_c = SiamUnet._block(features * 4, features * 4, name="conv2")
|
||||
|
||||
self.upconv1_c = nn.ConvTranspose2d(features * 4, features, kernel_size=2, stride=2)
|
||||
self.conv1_c = SiamUnet._block(features * 2, features * 2, name="conv1")
|
||||
|
||||
self.conv_c = nn.Conv2d(in_channels=features * 2, out_channels=out_channels_c, kernel_size=1)
|
||||
|
||||
|
||||
def forward(self, x1, x2):
|
||||
|
||||
# UNet on x1
|
||||
enc1_1 = self.encoder1(x1)
|
||||
enc2_1 = self.encoder2(self.pool1(enc1_1))
|
||||
enc3_1 = self.encoder3(self.pool2(enc2_1))
|
||||
enc4_1 = self.encoder4(self.pool3(enc3_1))
|
||||
|
||||
bottleneck_1 = self.bottleneck(self.pool4(enc4_1))
|
||||
|
||||
dec4_1 = self.upconv4(bottleneck_1)
|
||||
dec4_1 = torch.cat((dec4_1, enc4_1), dim=1)
|
||||
dec4_1 = self.decoder4(dec4_1)
|
||||
dec3_1 = self.upconv3(dec4_1)
|
||||
dec3_1 = torch.cat((dec3_1, enc3_1), dim=1)
|
||||
dec3_1 = self.decoder3(dec3_1)
|
||||
dec2_1 = self.upconv2(dec3_1)
|
||||
dec2_1 = torch.cat((dec2_1, enc2_1), dim=1)
|
||||
dec2_1 = self.decoder2(dec2_1)
|
||||
dec1_1 = self.upconv1(dec2_1)
|
||||
dec1_1 = torch.cat((dec1_1, enc1_1), dim=1)
|
||||
dec1_1 = self.decoder1(dec1_1)
|
||||
|
||||
# UNet on x2
|
||||
enc1_2 = self.encoder1(x2)
|
||||
enc2_2 = self.encoder2(self.pool1(enc1_2))
|
||||
enc3_2 = self.encoder3(self.pool2(enc2_2))
|
||||
enc4_2 = self.encoder4(self.pool3(enc3_2))
|
||||
|
||||
bottleneck_2 = self.bottleneck(self.pool4(enc4_2))
|
||||
|
||||
dec4_2 = self.upconv4(bottleneck_2)
|
||||
dec4_2 = torch.cat((dec4_2, enc4_2), dim=1)
|
||||
dec4_2 = self.decoder4(dec4_2)
|
||||
dec3_2 = self.upconv3(dec4_2)
|
||||
dec3_2 = torch.cat((dec3_2, enc3_2), dim=1)
|
||||
dec3_2 = self.decoder3(dec3_2)
|
||||
dec2_2 = self.upconv2(dec3_2)
|
||||
dec2_2 = torch.cat((dec2_2, enc2_2), dim=1)
|
||||
dec2_2 = self.decoder2(dec2_2)
|
||||
dec1_2 = self.upconv1(dec2_2)
|
||||
dec1_2 = torch.cat((dec1_2, enc1_2), dim=1)
|
||||
dec1_2 = self.decoder1(dec1_2)
|
||||
|
||||
# Siamese
|
||||
dec1_c = bottleneck_2 - bottleneck_1
|
||||
|
||||
dec1_c = self.upconv4_c(dec1_c) # features * 16 -> features * 8
|
||||
diff_2 = enc4_2 - enc4_1 # features * 16 -> features * 8
|
||||
dec2_c = torch.cat((diff_2, dec1_c), dim=1) #512
|
||||
dec2_c = self.conv4_c(dec2_c)
|
||||
|
||||
dec2_c = self.upconv3_c(dec2_c) # 512->256
|
||||
diff_3 = enc3_2 - enc3_1
|
||||
dec3_c = torch.cat((diff_3, dec2_c), dim=1) # ->512
|
||||
dec3_c = self.conv3_c(dec3_c)
|
||||
|
||||
dec3_c = self.upconv2_c(dec3_c) #512->256
|
||||
diff_4 = enc2_2 - enc2_1
|
||||
dec4_c = torch.cat((diff_4, dec3_c), dim=1) #
|
||||
dec4_c = self.conv2_c(dec4_c)
|
||||
|
||||
dec4_c = self.upconv1_c(dec4_c)
|
||||
diff_5 = enc1_2 - enc1_1
|
||||
dec5_c = torch.cat((diff_5, dec4_c), dim=1)
|
||||
dec5_c = self.conv1_c(dec5_c)
|
||||
|
||||
return self.conv_s(dec1_1), self.conv_s(dec1_2), self.conv_c(dec5_c)
|
||||
|
||||
@staticmethod
|
||||
def _block(in_channels, features, name):
|
||||
return nn.Sequential(
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
name + "conv1",
|
||||
nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=features,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=False,
|
||||
),
|
||||
),
|
||||
(name + "norm1", nn.BatchNorm2d(num_features=features)),
|
||||
(name + "relu1", nn.ReLU(inplace=True)),
|
||||
(
|
||||
name + "conv2",
|
||||
nn.Conv2d(
|
||||
in_channels=features,
|
||||
out_channels=features,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=False,
|
||||
),
|
||||
),
|
||||
(name + "norm2", nn.BatchNorm2d(num_features=features)),
|
||||
(name + "relu2", nn.ReLU(inplace=True)),
|
||||
]
|
||||
)
|
||||
)
|
|
@ -0,0 +1,704 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import torch
|
||||
import shutil
|
||||
import logging
|
||||
import torchvision
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import torch.nn as nn
|
||||
from datetime import datetime
|
||||
from torchvision import transforms
|
||||
from time import localtime, strftime
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from utils.raster_label_visualizer import RasterLabelVisualizer
|
||||
from models.end_to_end_Siam_UNet import SiamUnet
|
||||
from utils.train_utils import AverageMeter
|
||||
from utils.train_utils import load_json_files, dump_json_files
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
from utils.dataset_shard_load import DisasterDataset
|
||||
|
||||
config = {'labels_dmg': [0, 1, 2, 3, 4],
|
||||
'labels_bld': [0, 1],
|
||||
'weights_seg': [1, 15],
|
||||
'weights_damage': [1, 35, 70, 150, 120],
|
||||
'weights_loss': [0, 0, 1],
|
||||
'mode': 'dmg',
|
||||
'init_learning_rate': 0.0005,#dmg: 0.005, #UNet: 0.01,
|
||||
'device': 'cuda:2',
|
||||
'epochs': 1500,
|
||||
'batch_size': 32,
|
||||
'num_chips_to_viz': 1,
|
||||
'experiment_name': 'train_UNet', #train_dmg
|
||||
'out_dir': './nlrc_outputs/',
|
||||
'data_dir_shards': './xBD_sliced_augmented_20_alldisasters_final_mdl_npy/',
|
||||
'shard_no': 0,
|
||||
'disaster_splits_json': './nlrc.building-damage-assessment/constants/splits/final_mdl_all_disaster_splits_sliced_img_augmented_20.json',
|
||||
'disaster_mean_stddev': './nlrc.building-damage-assessment/constants/splits/all_disaster_mean_stddev_tiles_0_1.json',
|
||||
'label_map_json': './nlrc.building-damage-assessment/constants/class_lists/xBD_label_map.json',
|
||||
'starting_checkpoint_path': './nlrc_outputs/UNet_all_data_dmg/checkpoints/checkpoint_epoch120_2021-06-30-10-28-49.pth.tar'}
|
||||
|
||||
logging.basicConfig(stream=sys.stdout,
|
||||
level=logging.INFO,
|
||||
format='{asctime} {levelname} {message}',
|
||||
style='{',
|
||||
datefmt='%Y-%m-%d %H:%M:%S')
|
||||
logging.info(f'Using PyTorch version {torch.__version__}.')
|
||||
device = torch.device(config['device'] if torch.cuda.is_available() else "cpu")
|
||||
logging.info(f'Using device: {device}.')
|
||||
|
||||
def main():
|
||||
|
||||
global viz, labels_set_dmg, labels_set_bld
|
||||
global xBD_train, xBD_val
|
||||
global train_loader, val_loader, test_loader
|
||||
global weights_loss, mode
|
||||
|
||||
xBD_train, xBD_val = load_dataset()
|
||||
|
||||
train_loader = DataLoader(xBD_train, batch_size=config['batch_size'], shuffle=True, num_workers=8, pin_memory=False)
|
||||
val_loader = DataLoader(xBD_val, batch_size=config['batch_size'], shuffle=False, num_workers=8, pin_memory=False)
|
||||
|
||||
label_map = load_json_files(config['label_map_json'])
|
||||
viz = RasterLabelVisualizer(label_map=label_map)
|
||||
|
||||
labels_set_dmg = config['labels_dmg']
|
||||
labels_set_bld = config['labels_bld']
|
||||
mode = config['mode']
|
||||
|
||||
eval_results_tr_dmg = pd.DataFrame(columns=['epoch', 'class', 'precision', 'recall', 'f1', 'accuracy'])
|
||||
eval_results_tr_bld = pd.DataFrame(columns=['epoch', 'class', 'precision', 'recall', 'f1', 'accuracy'])
|
||||
eval_results_val_dmg = pd.DataFrame(columns=['epoch', 'class', 'precision', 'recall', 'f1', 'accuracy'])
|
||||
eval_results_val_bld = pd.DataFrame(columns=['epoch', 'class', 'precision', 'recall', 'f1', 'accuracy'])
|
||||
|
||||
# set up directories
|
||||
checkpoint_dir = os.path.join(config['out_dir'], config['experiment_name'], 'checkpoints')
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
logger_dir = os.path.join(config['out_dir'], config['experiment_name'], 'logs')
|
||||
os.makedirs(logger_dir, exist_ok=True)
|
||||
|
||||
evals_dir = os.path.join(config['out_dir'], config['experiment_name'], 'evals')
|
||||
os.makedirs(evals_dir, exist_ok=True)
|
||||
|
||||
config_dir = os.path.join(config['out_dir'], config['experiment_name'], 'configs')
|
||||
os.makedirs(config_dir, exist_ok=True)
|
||||
dump_json_files(os.path.join(config_dir,'config.txt') , config)
|
||||
|
||||
|
||||
# define model
|
||||
model = SiamUnet().to(device=device)
|
||||
model_summary(model)
|
||||
|
||||
|
||||
# resume from a checkpoint if provided
|
||||
starting_checkpoint_path = config['starting_checkpoint_path']
|
||||
if starting_checkpoint_path and os.path.isfile(starting_checkpoint_path):
|
||||
logging.info('Loading checkpoint from {}'.format(starting_checkpoint_path))
|
||||
checkpoint = torch.load(starting_checkpoint_path, map_location=device)
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
|
||||
#don't load the optimizer settings so that a newly specified lr can take effect
|
||||
if mode == 'dmg':
|
||||
print_network(model)
|
||||
model = freeze_model_param(model)
|
||||
print_network(model)
|
||||
|
||||
# monitor model
|
||||
logger_model = SummaryWriter(log_dir=logger_dir)
|
||||
for tag, value in model.named_parameters():
|
||||
tag = tag.replace('.', '/')
|
||||
logger_model.add_histogram(tag, value.data.cpu().numpy(), global_step=0)
|
||||
|
||||
reinitialize_Siamese(model)
|
||||
|
||||
for tag, value in model.named_parameters():
|
||||
tag = tag.replace('.', '/')
|
||||
logger_model.add_histogram(tag, value.data.cpu().numpy(), global_step=1)
|
||||
|
||||
logger_model.flush()
|
||||
logger_model.close()
|
||||
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config['init_learning_rate'])
|
||||
else:
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=config['init_learning_rate'])
|
||||
|
||||
starting_epoch = checkpoint['epoch'] + 1 # we did not increment epoch before saving it, so can just start here
|
||||
best_acc = checkpoint.get('best_f1', 0.0)
|
||||
logging.info(f'Loaded checkpoint, starting epoch is {starting_epoch}, '
|
||||
f'best f1 is {best_acc}')
|
||||
|
||||
else:
|
||||
logging.info('No valid checkpoint is provided. Start to train from scratch...')
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=config['init_learning_rate'])
|
||||
starting_epoch = 1
|
||||
best_acc = 0.0
|
||||
|
||||
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=2000, verbose=True)
|
||||
|
||||
# define loss functions and weights on classes
|
||||
weights_seg_tf = torch.FloatTensor(config['weights_seg'])
|
||||
weights_damage_tf = torch.FloatTensor(config['weights_damage'])
|
||||
weights_loss = config['weights_loss']
|
||||
|
||||
criterion_seg_1 = nn.CrossEntropyLoss(weight=weights_seg_tf).to(device=device)
|
||||
criterion_seg_2 = nn.CrossEntropyLoss(weight=weights_seg_tf).to(device=device)
|
||||
criterion_damage = nn.CrossEntropyLoss(weight=weights_damage_tf).to(device=device)
|
||||
|
||||
# initialize logger instances
|
||||
logger_train = SummaryWriter(log_dir=logger_dir)
|
||||
logger_val = SummaryWriter(log_dir=logger_dir)
|
||||
logger_test= SummaryWriter(log_dir=logger_dir)
|
||||
|
||||
|
||||
logging.info('Log image samples')
|
||||
logging.info('Get sample chips from train set...')
|
||||
sample_train_ids = get_sample_images(which_set='train')
|
||||
logging.info('Get sample chips from val set...')
|
||||
sample_val_ids = get_sample_images(which_set='val')
|
||||
|
||||
epoch = starting_epoch
|
||||
step_tr = 1
|
||||
epochs = config['epochs']
|
||||
|
||||
while (epoch <= epochs):
|
||||
|
||||
###### train
|
||||
logger_train.add_scalar( 'learning_rate', optimizer.param_groups[0]["lr"], epoch)
|
||||
logging.info(f'Model training for epoch {epoch}/{epochs}')
|
||||
train_start_time = datetime.now()
|
||||
model, optimizer, step_tr, confusion_mtrx_df_tr_dmg, confusion_mtrx_df_tr_bld = train(train_loader, model, criterion_seg_1, criterion_seg_2, criterion_damage, optimizer, epochs, epoch, step_tr, logger_train, logger_val, sample_train_ids, sample_val_ids, device)
|
||||
train_duration = datetime.now() - train_start_time
|
||||
logger_train.add_scalar('time_training', train_duration.total_seconds(), epoch)
|
||||
|
||||
logging.info(f'Compute actual metrics for model evaluation based on training set ...')
|
||||
|
||||
# damage level eval train
|
||||
eval_results_tr_dmg = compute_eval_metrics(epoch, labels_set_dmg, confusion_mtrx_df_tr_dmg, eval_results_tr_dmg)
|
||||
eval_results_tr_dmg_epoch = eval_results_tr_dmg.loc[eval_results_tr_dmg['epoch'] == epoch,:]
|
||||
f1_harmonic_mean = 0
|
||||
for metrics in ['f1']:
|
||||
for index, row in eval_results_tr_dmg_epoch.iterrows():
|
||||
if int(row['class']) in labels_set_dmg[1:]:
|
||||
logger_train.add_scalar( 'tr_dmg_class_' + str(int(row['class'])) + '_' + str(metrics), row[metrics], epoch)
|
||||
if metrics == 'f1':
|
||||
f1_harmonic_mean += 1.0/(row[metrics]+1e-10)
|
||||
f1_harmonic_mean = 4.0/f1_harmonic_mean
|
||||
logger_train.add_scalar( 'tr_dmg_harmonic_mean_f1', f1_harmonic_mean, epoch)
|
||||
|
||||
|
||||
# bld level eval train
|
||||
eval_results_tr_bld = compute_eval_metrics(epoch, labels_set_bld, confusion_mtrx_df_tr_bld, eval_results_tr_bld)
|
||||
eval_results_tr_bld_epoch = eval_results_tr_bld.loc[eval_results_tr_bld['epoch'] == epoch,:]
|
||||
for metrics in ['f1']:
|
||||
for index, row in eval_results_tr_bld_epoch.iterrows():
|
||||
if int(row['class']) in labels_set_dmg[1:]:
|
||||
logger_train.add_scalar( 'tr_bld_class_' + str(int(row['class'])) + '_' + str(metrics), row[metrics], epoch)
|
||||
|
||||
|
||||
####### validation
|
||||
logging.info(f'Model validation for epoch {epoch}/{epochs}')
|
||||
eval_start_time = datetime.now()
|
||||
confusion_mtrx_df_val_dmg, confusion_mtrx_df_val_bld, losses_val = validation(val_loader, model, criterion_seg_1, criterion_seg_2, criterion_damage, epochs, epoch, logger_val)
|
||||
eval_duration = datetime.now() - eval_start_time
|
||||
# decay Learning Rate
|
||||
scheduler.step(losses_val)
|
||||
logger_val.add_scalar('time_validation', eval_duration.total_seconds(), epoch)
|
||||
logging.info(f'Compute actual metrics for model evaluation based on validation set ...')
|
||||
|
||||
# damage level eval validation
|
||||
eval_results_val_dmg = compute_eval_metrics(epoch, labels_set_dmg, confusion_mtrx_df_val_dmg, eval_results_val_dmg)
|
||||
eval_results_val_dmg_epoch = eval_results_val_dmg.loc[eval_results_val_dmg['epoch'] == epoch,:]
|
||||
f1_harmonic_mean = 0
|
||||
for metrics in ['f1']:
|
||||
for index, row in eval_results_val_dmg_epoch.iterrows():
|
||||
if int(row['class']) in labels_set_dmg[1:]:
|
||||
logger_val.add_scalar( 'val_dmg_class_' + str(int(row['class'])) + '_' + str(metrics), row[metrics], epoch)
|
||||
if metrics == 'f1':
|
||||
f1_harmonic_mean += 1.0/(row[metrics]+1e-10)
|
||||
f1_harmonic_mean = 4.0/f1_harmonic_mean
|
||||
logger_val.add_scalar( 'val_dmg_harmonic_mean_f1', f1_harmonic_mean, epoch)
|
||||
|
||||
|
||||
# bld level eval validation
|
||||
eval_results_val_bld = compute_eval_metrics(epoch, labels_set_bld, confusion_mtrx_df_val_bld, eval_results_val_bld)
|
||||
eval_results_val_bld_epoch = eval_results_val_bld.loc[eval_results_val_bld['epoch'] == epoch,:]
|
||||
for metrics in ['f1']:
|
||||
for index, row in eval_results_val_bld_epoch.iterrows():
|
||||
if int(row['class']) in labels_set_bld[1:]:
|
||||
logger_val.add_scalar( 'val_bld_class_' + str(int(row['class'])) + '_' + str(metrics), row[metrics], epoch)
|
||||
|
||||
|
||||
####### compute average accuracy across all classes to select the best model
|
||||
val_acc_avg = f1_harmonic_mean
|
||||
is_best = val_acc_avg > best_acc
|
||||
best_acc = max(val_acc_avg, best_acc)
|
||||
|
||||
logging.info(f'Saved checkpoint for epoch {epoch}. Is it the highest f1 checkpoint so far: {is_best}\n')
|
||||
|
||||
save_checkpoint({
|
||||
'epoch': epoch,
|
||||
'state_dict': model.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'val_f1_avg': val_acc_avg,
|
||||
'best_f1': best_acc}, is_best, checkpoint_dir)
|
||||
|
||||
# log execution time for this epoch
|
||||
logging.info((f'epoch training duration is {train_duration.total_seconds()} seconds;'
|
||||
f'epoch evaluation duration is {eval_duration.total_seconds()} seconds'))
|
||||
|
||||
epoch += 1
|
||||
|
||||
logger_train.flush()
|
||||
logger_train.close()
|
||||
logger_val.flush()
|
||||
logger_val.close()
|
||||
|
||||
|
||||
logging.info('Done')
|
||||
|
||||
return
|
||||
|
||||
def train(loader, model, criterion_seg_1, criterion_seg_2, criterion_damage, optimizer, epochs, epoch, step_tr, logger_train, logger_val, sample_train_ids, sample_val_ids, device):
|
||||
"""
|
||||
Train the model on dataset of the loader
|
||||
"""
|
||||
confusion_mtrx_df_tr_dmg = pd.DataFrame(columns=['epoch', 'batch_idx', 'class', 'true_pos', 'true_neg', 'false_pos', 'false_neg', 'total_pixels'])
|
||||
confusion_mtrx_df_tr_bld = pd.DataFrame(columns=['epoch', 'batch_idx', 'class', 'true_pos', 'true_neg', 'false_pos', 'false_neg', 'total_pixels'])
|
||||
|
||||
losses_tr = AverageMeter()
|
||||
loss_seg_pre = AverageMeter()
|
||||
loss_seg_post = AverageMeter()
|
||||
loss_dmg = AverageMeter()
|
||||
|
||||
|
||||
for batch_idx, data in enumerate(tqdm(loader)):
|
||||
|
||||
x_pre = data['pre_image'].to(device=device) # move to device, e.g. GPU
|
||||
x_post = data['post_image'].to(device=device)
|
||||
y_seg = data['building_mask'].to(device=device)
|
||||
y_cls = data['damage_mask'].to(device=device)
|
||||
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
scores = model(x_pre, x_post)
|
||||
|
||||
# modify damage prediction based on UNet arm
|
||||
softmax = torch.nn.Softmax(dim=1)
|
||||
preds_seg_pre = torch.argmax(softmax(scores[0]), dim=1)
|
||||
for c in range(0,scores[2].shape[1]):
|
||||
scores[2][:,c,:,:] = torch.mul(scores[2][:,c,:,:], preds_seg_pre)
|
||||
|
||||
loss = weights_loss[0]*criterion_seg_1(scores[0], y_seg) + weights_loss[1]*criterion_seg_2(scores[1], y_seg) + weights_loss[2]*criterion_damage(scores[2], y_cls)
|
||||
loss_seg_pre_tr = criterion_seg_1(scores[0], y_seg)
|
||||
loss_seg_post_tr = criterion_seg_2(scores[1], y_seg)
|
||||
loss_dmg_tr = criterion_damage(scores[2], y_cls)
|
||||
|
||||
losses_tr.update(loss.item(), x_pre.size(0))
|
||||
loss_seg_pre.update(loss_seg_pre_tr.item(), x_pre.size(0))
|
||||
loss_seg_post.update(loss_seg_post_tr.item(), x_pre.size(0))
|
||||
loss_dmg.update(loss_dmg_tr.item(), x_pre.size(0))
|
||||
|
||||
loss.backward() # compute gradients
|
||||
optimizer.step()
|
||||
|
||||
# compute predictions & confusion metrics
|
||||
softmax = torch.nn.Softmax(dim=1)
|
||||
preds_seg_pre = torch.argmax(softmax(scores[0]), dim=1)
|
||||
preds_seg_post = torch.argmax(softmax(scores[1]), dim=1)
|
||||
preds_cls = torch.argmax(softmax(scores[2]), dim=1)
|
||||
|
||||
confusion_mtrx_df_tr_dmg = compute_confusion_mtrx(confusion_mtrx_df_tr_dmg, epoch, batch_idx, labels_set_dmg, preds_cls, y_cls, y_seg)
|
||||
confusion_mtrx_df_tr_bld = compute_confusion_mtrx(confusion_mtrx_df_tr_bld, epoch, batch_idx, labels_set_bld, preds_seg_pre, y_seg, [])
|
||||
|
||||
|
||||
# logging image viz
|
||||
prepare_for_vis(sample_train_ids, logger_train, model, 'train', epoch, device, softmax)
|
||||
prepare_for_vis(sample_val_ids, logger_val, model, 'val', epoch, device, softmax)
|
||||
step_tr += 1
|
||||
|
||||
logger_train.add_scalars('loss_tr', {'_total':losses_tr.avg, '_seg_pre': loss_seg_pre.avg, '_seg_post': loss_seg_post.avg, '_dmg': loss_dmg.avg}, epoch)
|
||||
|
||||
return model, optimizer, step_tr, confusion_mtrx_df_tr_dmg, confusion_mtrx_df_tr_bld
|
||||
|
||||
def validation(loader, model, criterion_seg_1, criterion_seg_2, criterion_damage, epochs, epoch, logger_val):
|
||||
|
||||
"""
|
||||
Evaluate the model on dataset of the loader
|
||||
"""
|
||||
confusion_mtrx_df_val_dmg = pd.DataFrame(columns=['epoch', 'batch_idx', 'class', 'true_pos', 'true_neg', 'false_pos', 'false_neg', 'total_pixels'])
|
||||
confusion_mtrx_df_val_bld = pd.DataFrame(columns=['epoch', 'batch_idx', 'class', 'true_pos', 'true_neg', 'false_pos', 'false_neg', 'total_pixels'])
|
||||
losses_val = AverageMeter()
|
||||
loss_seg_pre = AverageMeter()
|
||||
loss_seg_post = AverageMeter()
|
||||
loss_dmg = AverageMeter()
|
||||
|
||||
with torch.no_grad():
|
||||
for batch_idx, data in enumerate(tqdm(loader)):
|
||||
x_pre = data['pre_image'].to(device=device) # move to device, e.g. GPU
|
||||
x_post = data['post_image'].to(device=device)
|
||||
y_seg = data['building_mask'].to(device=device)
|
||||
y_cls = data['damage_mask'].to(device=device)
|
||||
|
||||
model.eval() # put model to evaluation mode
|
||||
scores = model(x_pre, x_post)
|
||||
|
||||
# modify damage prediction based on UNet arm
|
||||
softmax = torch.nn.Softmax(dim=1)
|
||||
preds_seg_pre = torch.argmax(softmax(scores[0]), dim=1)
|
||||
for c in range(0,scores[2].shape[1]):
|
||||
scores[2][:,c,:,:] = torch.mul(scores[2][:,c,:,:], preds_seg_pre)
|
||||
|
||||
loss = weights_loss[0]*criterion_seg_1(scores[0], y_seg) + weights_loss[1]*criterion_seg_2(scores[1], y_seg) + weights_loss[2]*criterion_damage(scores[2], y_cls)
|
||||
loss_seg_pre_val = criterion_seg_1(scores[0], y_seg)
|
||||
loss_seg_post_val = criterion_seg_2(scores[1], y_seg)
|
||||
loss_dmg_val = criterion_damage(scores[2], y_cls)
|
||||
|
||||
losses_val.update(loss.item(), x_pre.size(0))
|
||||
loss_seg_pre.update(loss_seg_pre_val.item(), x_pre.size(0))
|
||||
loss_seg_post.update(loss_seg_post_val.item(), x_pre.size(0))
|
||||
loss_dmg.update(loss_dmg_val.item(), x_pre.size(0))
|
||||
|
||||
# compute predictions & confusion metrics
|
||||
softmax = torch.nn.Softmax(dim=1)
|
||||
preds_seg_pre = torch.argmax(softmax(scores[0]), dim=1)
|
||||
preds_seg_post = torch.argmax(softmax(scores[1]), dim=1)
|
||||
preds_cls = torch.argmax(softmax(scores[2]), dim=1)
|
||||
|
||||
confusion_mtrx_df_val_dmg = compute_confusion_mtrx(confusion_mtrx_df_val_dmg, epoch, batch_idx, labels_set_dmg, preds_cls, y_cls, y_seg)
|
||||
confusion_mtrx_df_val_bld = compute_confusion_mtrx(confusion_mtrx_df_val_bld, epoch, batch_idx, labels_set_bld, preds_seg_pre, y_seg, [])
|
||||
|
||||
logger_val.add_scalars('loss_val', {'_total': losses_val.avg, '_seg_pre': loss_seg_pre.avg, '_seg_post': loss_seg_post.avg, '_dmg': loss_dmg.avg}, epoch)
|
||||
|
||||
return confusion_mtrx_df_val_dmg, confusion_mtrx_df_val_bld, losses_val.avg
|
||||
|
||||
def compute_eval_metrics(epoch, labels_set, confusion_mtrx_df, eval_results):
|
||||
for cls in labels_set:
|
||||
class_idx = (confusion_mtrx_df['class']==cls)
|
||||
precision = confusion_mtrx_df.loc[class_idx,'true_pos'].sum()/(confusion_mtrx_df.loc[class_idx,'true_pos'].sum() + confusion_mtrx_df.loc[class_idx,'false_pos'].sum())
|
||||
recall = confusion_mtrx_df.loc[class_idx,'true_pos'].sum()/(confusion_mtrx_df.loc[class_idx,'true_pos'].sum() + confusion_mtrx_df.loc[class_idx,'false_neg'].sum())
|
||||
f1 = 2 * (precision * recall)/(precision + recall)
|
||||
accuracy = (confusion_mtrx_df.loc[class_idx,'true_pos'].sum() + confusion_mtrx_df.loc[class_idx,'true_neg'].sum())/(confusion_mtrx_df.loc[class_idx,'total_pixels'].sum())
|
||||
eval_results = eval_results.append({'epoch':epoch, 'class':cls, 'precision':precision, 'recall':recall, 'f1':f1, 'accuracy':accuracy}, ignore_index=True)
|
||||
return eval_results
|
||||
|
||||
def compute_confusion_mtrx(confusion_mtrx_df, epoch, batch_idx, labels_set, y_preds, y_true, y_true_bld_mask):
|
||||
for cls in labels_set[1:]:
|
||||
confusion_mtrx_df = compute_confusion_mtrx_class(confusion_mtrx_df, epoch, batch_idx, labels_set, y_preds, y_true, y_true_bld_mask, cls)
|
||||
return confusion_mtrx_df
|
||||
|
||||
def compute_confusion_mtrx_class(confusion_mtrx_df, epoch, batch_idx, labels_set, y_preds, y_true, y_true_bld_mask, cls):
|
||||
|
||||
y_true_binary = y_true.detach().clone()
|
||||
y_preds_binary = y_preds.detach().clone()
|
||||
|
||||
if len(labels_set) > 2:
|
||||
# convert to 0/1
|
||||
|
||||
y_true_binary[y_true_binary != cls] = -1
|
||||
y_preds_binary[y_preds_binary != cls] = -1
|
||||
|
||||
y_true_binary[y_true_binary == cls] = 1
|
||||
y_preds_binary[y_preds_binary == cls] = 1
|
||||
|
||||
|
||||
y_true_binary[y_true_binary == -1] = 0
|
||||
y_preds_binary[y_preds_binary == -1] = 0
|
||||
|
||||
|
||||
# compute confusion metric
|
||||
true_pos_cls = ((y_true_binary == y_preds_binary) & (y_true_binary == 1) & (y_true_bld_mask == 1)).float().sum().item()
|
||||
false_neg_cls = ((y_true_binary != y_preds_binary) & (y_true_binary == 1) & (y_true_bld_mask == 1)).float().sum().item()
|
||||
true_neg_cls = ((y_true_binary == y_preds_binary) & (y_true_binary == 0) & (y_true_bld_mask == 1)).float().sum().item()
|
||||
false_pos_cls = ((y_true_binary != y_preds_binary) & (y_true_binary == 0) & (y_true_bld_mask == 1)).float().sum().item()
|
||||
|
||||
# compute total pixels
|
||||
total_pixels = y_true_bld_mask.float().sum().item()
|
||||
|
||||
else:
|
||||
|
||||
# compute confusion metric
|
||||
true_pos_cls = ((y_true_binary == y_preds_binary) & (y_true_binary == 1)).float().sum().item()
|
||||
false_neg_cls = ((y_true_binary != y_preds_binary) & (y_true_binary == 1)).float().sum().item()
|
||||
true_neg_cls = ((y_true_binary == y_preds_binary) & (y_true_binary == 0)).float().sum().item()
|
||||
false_pos_cls = ((y_true_binary != y_preds_binary) & (y_true_binary == 0)).float().sum().item()
|
||||
|
||||
# compute total pixels
|
||||
total_pixels = 1
|
||||
for item in y_true_binary.size():
|
||||
total_pixels *= item
|
||||
|
||||
confusion_mtrx_df = confusion_mtrx_df.append({'epoch':epoch, 'class':cls, 'batch_idx':batch_idx, 'true_pos':true_pos_cls, 'true_neg':true_neg_cls, 'false_pos':false_pos_cls, 'false_neg':false_neg_cls, 'total_pixels':total_pixels}, ignore_index=True)
|
||||
|
||||
return confusion_mtrx_df
|
||||
|
||||
def save_checkpoint(state, is_best, checkpoint_dir='../checkpoints'):
|
||||
"""
|
||||
checkpoint_dir is used to save the best checkpoint if this checkpoint is best one so far
|
||||
"""
|
||||
checkpoint_path = os.path.join(checkpoint_dir,
|
||||
f"checkpoint_epoch{state['epoch']}_"
|
||||
f"{strftime('%Y-%m-%d-%H-%M-%S', localtime())}.pth.tar")
|
||||
torch.save(state, checkpoint_path)
|
||||
if is_best:
|
||||
shutil.copyfile(checkpoint_path, os.path.join(checkpoint_dir, 'model_best.pth.tar'))
|
||||
|
||||
def get_sample_images(which_set='train'):
|
||||
|
||||
"""
|
||||
Get a deterministic set of images in the specified set (train or val) by using the dataset and
|
||||
not the dataloader. Only works if the dataset is not IterableDataset.
|
||||
|
||||
Args:
|
||||
which_set: one of 'train' or 'val'
|
||||
|
||||
Returns:
|
||||
samples: a dict with keys 'chip' and 'chip_label', pointing to torch Tensors of
|
||||
dims (num_chips_to_visualize, channels, height, width) and (num_chips_to_visualize, height, width)
|
||||
respectively
|
||||
"""
|
||||
assert which_set == 'train' or which_set == 'val'
|
||||
|
||||
dataset = xBD_train if which_set == 'train' else xBD_val
|
||||
|
||||
num_to_skip = 1 # first few chips might be mostly blank
|
||||
assert len(dataset) > num_to_skip + config['num_chips_to_viz']
|
||||
|
||||
keep_every = math.floor((len(dataset) - num_to_skip) / config['num_chips_to_viz'])
|
||||
samples_idx_list = []
|
||||
|
||||
for sample_idx in range(num_to_skip, len(dataset), keep_every):
|
||||
samples_idx_list.append(sample_idx)
|
||||
|
||||
return samples_idx_list
|
||||
|
||||
def prepare_for_vis(sample_train_ids, logger, model, which_set, iteration, device, softmax):
|
||||
|
||||
for item in sample_train_ids:
|
||||
data = xBD_train[item] if which_set == 'train' else xBD_val[item]
|
||||
|
||||
c = data['pre_image'].size()[0]
|
||||
h = data['pre_image'].size()[1]
|
||||
w = data['pre_image'].size()[2]
|
||||
|
||||
pre = data['pre_image'].reshape(1, c, h, w)
|
||||
post = data['post_image'].reshape(1, c, h, w)
|
||||
|
||||
|
||||
scores = model(pre.to(device=device), post.to(device=device))
|
||||
preds_seg_pre = torch.argmax(softmax(scores[0]), dim=1)
|
||||
preds_seg_post = torch.argmax(softmax(scores[1]), dim=1)
|
||||
|
||||
# modify damage prediction based on UNet arm
|
||||
for c in range(0,scores[2].shape[1]):
|
||||
scores[2][:,c,:,:] = torch.mul(scores[2][:,c,:,:], preds_seg_pre)
|
||||
|
||||
# add to tensorboard
|
||||
tag = 'pr_bld_mask_pre_train_id_' + str(item) if which_set == 'train' else 'pr_bld_mask_pre_val_id_' + str(item)
|
||||
logger.add_image(tag, preds_seg_pre, iteration, dataformats='CHW')
|
||||
|
||||
tag = 'pr_bld_mask_post_train_id_' + str(item) if which_set == 'train' else 'pr_bld_mask_post_val_id_' + str(item)
|
||||
logger.add_image(tag, preds_seg_post, iteration, dataformats='CHW')
|
||||
|
||||
tag = 'pr_dmg_mask_train_id_' + str(item) if which_set == 'train' else 'pr_dmg_mask_val_id_' + str(item)
|
||||
im, buf = viz.show_label_raster(torch.argmax(softmax(scores[2]), dim=1).cpu().numpy(), size=(5, 5))
|
||||
preds_cls = transforms.ToTensor()(transforms.ToPILImage()(np.array(im)).convert("RGB"))
|
||||
logger.add_image(tag, preds_cls, iteration, dataformats='CHW')
|
||||
|
||||
if iteration == 1:
|
||||
pre = data['pre_image']
|
||||
tag = 'gt_img_pre_train_id_' + str(item) if which_set == 'train' else 'gt_img_pre_val_id_' + str(item)
|
||||
logger.add_image(tag, data['pre_image_orig'], iteration, dataformats='CHW')
|
||||
|
||||
post = data['post_image']
|
||||
tag = 'gt_img_post_train_id_' + str(item) if which_set == 'train' else 'gt_img_post_val_id_' + str(item)
|
||||
logger.add_image(tag, data['post_image_orig'], iteration, dataformats='CHW')
|
||||
|
||||
gt_seg = data['building_mask'].reshape(1, h, w)
|
||||
tag = 'gt_bld_mask_train_id_' + str(item) if which_set == 'train' else 'gt_bld_mask_val_id_' + str(item)
|
||||
logger.add_image(tag, gt_seg, iteration, dataformats='CHW')
|
||||
|
||||
im, buf = viz.show_label_raster(np.array(data['damage_mask']), size=(5, 5))
|
||||
gt_cls = transforms.ToTensor()(transforms.ToPILImage()(np.array(im)).convert("RGB"))
|
||||
tag = 'gt_dmg_mask_train_id_' + str(item) if which_set == 'train' else 'gt_dmg_mask_val_id_' + str(item)
|
||||
logger.add_image(tag, gt_cls, iteration, dataformats='CHW')
|
||||
return
|
||||
|
||||
def load_dataset():
|
||||
splits = load_json_files(config['disaster_splits_json'])
|
||||
data_mean_stddev = load_json_files(config['disaster_mean_stddev'])
|
||||
|
||||
train_ls = []
|
||||
val_ls = []
|
||||
for item, val in splits.items():
|
||||
train_ls += val['train']
|
||||
val_ls += val['val']
|
||||
xBD_train = DisasterDataset(config['data_dir_shards'], config['shard_no'], 'train', data_mean_stddev, transform=True, normalize=True)
|
||||
xBD_val = DisasterDataset(config['data_dir_shards'], config['shard_no'], 'val', data_mean_stddev, transform=False, normalize=True)
|
||||
|
||||
print('xBD_disaster_dataset train length: {}'.format(len(xBD_train)))
|
||||
print('xBD_disaster_dataset val length: {}'.format(len(xBD_val)))
|
||||
|
||||
return xBD_train, xBD_val
|
||||
|
||||
def model_summary(model):
|
||||
print("model_summary")
|
||||
print()
|
||||
print("Layer_name"+"\t"*7+"Number of Parameters")
|
||||
print("="*100)
|
||||
model_parameters = [layer for layer in model.parameters() if layer.requires_grad]
|
||||
layer_name = [child for child in model.children()]
|
||||
j = 0
|
||||
total_params = 0
|
||||
print("\t"*10)
|
||||
for i in layer_name:
|
||||
print()
|
||||
param = 0
|
||||
try:
|
||||
bias = (i.bias is not None)
|
||||
except:
|
||||
bias = False
|
||||
if not bias:
|
||||
param =model_parameters[j].numel()+model_parameters[j+1].numel()
|
||||
j = j+2
|
||||
else:
|
||||
param =model_parameters[j].numel()
|
||||
j = j+1
|
||||
print(str(i)+"\t"*3+str(param))
|
||||
total_params+=param
|
||||
print("="*100)
|
||||
print(f"Total Params:{total_params}")
|
||||
|
||||
def freeze_model_param(model):
|
||||
for i in [0, 3]:
|
||||
model.encoder1[i].weight.requires_grad = False
|
||||
model.encoder2[i].weight.requires_grad = False
|
||||
model.encoder3[i].weight.requires_grad = False
|
||||
model.encoder4[i].weight.requires_grad = False
|
||||
|
||||
model.bottleneck[i].weight.requires_grad = False
|
||||
|
||||
model.decoder4[i].weight.requires_grad = False
|
||||
model.decoder3[i].weight.requires_grad = False
|
||||
model.decoder2[i].weight.requires_grad = False
|
||||
model.decoder1[i].weight.requires_grad = False
|
||||
|
||||
for i in [1, 4]:
|
||||
model.encoder1[i].weight.requires_grad = False
|
||||
model.encoder1[i].bias.requires_grad = False
|
||||
|
||||
model.encoder2[i].weight.requires_grad = False
|
||||
model.encoder2[i].bias.requires_grad = False
|
||||
|
||||
model.encoder3[i].weight.requires_grad = False
|
||||
model.encoder3[i].bias.requires_grad = False
|
||||
|
||||
model.encoder4[i].weight.requires_grad = False
|
||||
model.encoder4[i].bias.requires_grad = False
|
||||
|
||||
model.bottleneck[i].weight.requires_grad = False
|
||||
model.bottleneck[i].bias.requires_grad = False
|
||||
|
||||
model.decoder4[i].weight.requires_grad = False
|
||||
model.decoder4[i].bias.requires_grad = False
|
||||
|
||||
model.decoder3[i].weight.requires_grad = False
|
||||
model.decoder3[i].bias.requires_grad = False
|
||||
|
||||
model.decoder2[i].weight.requires_grad = False
|
||||
model.decoder2[i].bias.requires_grad = False
|
||||
|
||||
model.decoder1[i].weight.requires_grad = False
|
||||
model.decoder1[i].bias.requires_grad = False
|
||||
|
||||
|
||||
model.upconv4.weight.requires_grad = False
|
||||
model.upconv4.bias.requires_grad = False
|
||||
|
||||
model.upconv3.weight.requires_grad = False
|
||||
model.upconv3.bias.requires_grad = False
|
||||
|
||||
model.upconv2.weight.requires_grad = False
|
||||
model.upconv2.bias.requires_grad = False
|
||||
|
||||
model.upconv1.weight.requires_grad = False
|
||||
model.upconv1.bias.requires_grad = False
|
||||
|
||||
model.conv_s.weight.requires_grad = False
|
||||
model.conv_s.bias.requires_grad = False
|
||||
|
||||
return model
|
||||
|
||||
def print_network(model):
|
||||
print('model summary')
|
||||
for name, p in model.named_parameters():
|
||||
print(name)
|
||||
print(p.requires_grad)
|
||||
|
||||
def reinitialize_Siamese(model):
|
||||
torch.nn.init.xavier_uniform_(model.upconv4_c.weight)
|
||||
torch.nn.init.xavier_uniform_(model.upconv3_c.weight)
|
||||
torch.nn.init.xavier_uniform_(model.upconv2_c.weight)
|
||||
torch.nn.init.xavier_uniform_(model.upconv1_c.weight)
|
||||
torch.nn.init.xavier_uniform_(model.conv_c.weight)
|
||||
|
||||
model.upconv4_c.bias.data.fill_(0.01)
|
||||
model.upconv3_c.bias.data.fill_(0.01)
|
||||
model.upconv2_c.bias.data.fill_(0.01)
|
||||
model.upconv1_c.bias.data.fill_(0.01)
|
||||
model.conv_c.bias.data.fill_(0.01)
|
||||
|
||||
model.conv4_c.apply(init_weights)
|
||||
model.conv3_c.apply(init_weights)
|
||||
model.conv2_c.apply(init_weights)
|
||||
model.conv1_c.apply(init_weights)
|
||||
|
||||
return model
|
||||
|
||||
def init_weights(m):
|
||||
if type(m) == nn.Linear:
|
||||
torch.nn.init.xavier_uniform_(m.weight)
|
||||
m.bias.data.fill_(0.01)
|
||||
|
||||
def test(loader, model, epoch):
|
||||
|
||||
"""
|
||||
Evaluate the model on test dataset of the loader
|
||||
"""
|
||||
softmax = torch.nn.Softmax(dim=1)
|
||||
model.eval() # put model to evaluation mode
|
||||
confusion_mtrx_df_test_dmg = pd.DataFrame(columns=['img_idx', 'class', 'true_pos', 'true_neg', 'false_pos', 'false_neg', 'total_pixels'])
|
||||
confusion_mtrx_df_test_bld = pd.DataFrame(columns=['img_idx', 'class', 'true_pos', 'true_neg', 'false_pos', 'false_neg', 'total_pixels'])
|
||||
|
||||
with torch.no_grad():
|
||||
for batch_idx, data in enumerate(tqdm(loader)):
|
||||
c = data['pre_image'].size()[0]
|
||||
h = data['pre_image'].size()[1]
|
||||
w = data['pre_image'].size()[2]
|
||||
|
||||
x_pre = data['pre_image'].reshape(1, c, h, w).to(device=device)
|
||||
x_post = data['post_image'].reshape(1, c, h, w).to(device=device)
|
||||
|
||||
y_seg = data['building_mask'].to(device=device)
|
||||
y_cls = data['damage_mask'].to(device=device)
|
||||
|
||||
scores = model(x_pre, x_post)
|
||||
|
||||
preds_seg_pre = torch.argmax(softmax(scores[0]), dim=1)
|
||||
preds_seg_post = torch.argmax(softmax(scores[1]), dim=1)
|
||||
|
||||
for c in range(0,scores[2].shape[1]):
|
||||
scores[2][:,c,:,:] = torch.mul(scores[2][:,c,:,:], preds_seg_pre)
|
||||
preds_cls = torch.argmax(softmax(scores[2]), dim=1)
|
||||
|
||||
# compute comprehensive comfusion metrics
|
||||
confusion_mtrx_df_test_dmg = compute_confusion_mtrx(confusion_mtrx_df_test_dmg, epoch, batch_idx, labels_set_dmg, preds_cls, y_cls, y_seg)
|
||||
confusion_mtrx_df_test_bld = compute_confusion_mtrx(confusion_mtrx_df_test_bld, epoch, batch_idx, labels_set_bld, preds_seg_pre, y_seg, [])
|
||||
|
||||
return confusion_mtrx_df_test_dmg, confusion_mtrx_df_test_bld
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,118 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
import logging
|
||||
from PIL import Image
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
import torchvision.transforms.functional as TF
|
||||
from torchvision import transforms
|
||||
import random
|
||||
import os
|
||||
|
||||
class DisasterDataset(Dataset):
|
||||
def __init__(self, data_dir, i_shard, set_name, data_mean_stddev, transform:bool, normalize:bool):
|
||||
|
||||
self.data_dir = data_dir
|
||||
self.transform = transform
|
||||
self.normalize = normalize
|
||||
self.data_mean_stddev = data_mean_stddev
|
||||
|
||||
shard_path = os.path.join(data_dir, f'{set_name}_pre_image_chips_{i_shard}.npy')
|
||||
self.pre_image_chip_shard = np.load(shard_path)
|
||||
logging.info(f'pre_image_chips loaded{self.pre_image_chip_shard.shape}')
|
||||
|
||||
shard_path = os.path.join(data_dir, f'{set_name}_post_image_chips_{i_shard}.npy')
|
||||
self.post_image_chip_shard = np.load(shard_path)
|
||||
logging.info(f'post_image_chips loaded{self.post_image_chip_shard.shape}')
|
||||
|
||||
shard_path = os.path.join(data_dir, f'{set_name}_bld_mask_chips_{i_shard}.npy')
|
||||
self.bld_mask_chip_shard = np.load(shard_path)
|
||||
logging.info(f'bld_mask_chips loaded{self.bld_mask_chip_shard.shape}')
|
||||
|
||||
shard_path = os.path.join(data_dir, f'{set_name}_dmg_mask_chips_{i_shard}.npy')
|
||||
self.dmg_mask_chip_shard = np.load(shard_path)
|
||||
logging.info(f'dmg_mask_chips loaded{self.dmg_mask_chip_shard.shape}')
|
||||
|
||||
shard_path = os.path.join(data_dir, f'{set_name}_pre_img_tile_chips_{i_shard}.npy')
|
||||
self.pre_img_tile_chip_shard = np.load(shard_path)
|
||||
logging.info(f'pre_img_tile_chips loaded{self.pre_img_tile_chip_shard.shape}')
|
||||
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pre_image_chip_shard)
|
||||
|
||||
@classmethod
|
||||
def apply_transform(self, mask, pre_img, post_img, damage_class):
|
||||
|
||||
'''
|
||||
apply tranformation functions on cv2 arrays
|
||||
'''
|
||||
|
||||
# Random horizontal flipping
|
||||
if random.random() > 0.5:
|
||||
mask = cv2.flip(mask, flipCode=1)
|
||||
pre_img = cv2.flip(pre_img, flipCode=1)
|
||||
post_img = cv2.flip(post_img, flipCode=1)
|
||||
damage_class = cv2.flip(damage_class, flipCode=1)
|
||||
|
||||
# Random vertical flipping
|
||||
if random.random() > 0.5:
|
||||
mask = cv2.flip(mask, flipCode=0)
|
||||
pre_img = cv2.flip(pre_img, flipCode=0)
|
||||
post_img = cv2.flip(post_img, flipCode=0)
|
||||
damage_class = cv2.flip(damage_class, flipCode=0)
|
||||
|
||||
return mask, pre_img, post_img, damage_class
|
||||
|
||||
def __getitem__(self, i):
|
||||
|
||||
pre_img = self.pre_image_chip_shard[i]
|
||||
post_img = self.post_image_chip_shard[i]
|
||||
mask = self.bld_mask_chip_shard[i]
|
||||
damage_class= self.dmg_mask_chip_shard[i]
|
||||
|
||||
# copy original image for viz
|
||||
pre_img_orig = pre_img
|
||||
post_img_orig = post_img
|
||||
|
||||
if self.transform is True:
|
||||
mask, pre_img, post_img, damage_class = self.apply_transform(mask, pre_img, post_img, damage_class)
|
||||
|
||||
if self.normalize is True:
|
||||
pre_img_tile_name = self.pre_img_tile_chip_shard[i]
|
||||
post_img_tile_name = pre_img_tile_name.replace('pre', 'post')
|
||||
|
||||
# normalize the images based on a tilewise mean & std dev --> pre_
|
||||
mean_pre = self.data_mean_stddev[pre_img_tile_name][0]
|
||||
stddev_pre = self.data_mean_stddev[pre_img_tile_name][1]
|
||||
norm_pre = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean_pre, std=stddev_pre)
|
||||
])
|
||||
pre_img = norm_pre(np.array(pre_img).astype(dtype='float64')/255.0)
|
||||
|
||||
# normalize the images based on a tilewise mean & std dev --> post_
|
||||
mean_post = self.data_mean_stddev[post_img_tile_name][0]
|
||||
stddev_post = self.data_mean_stddev[post_img_tile_name][1]
|
||||
norm_post = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean_post, std=stddev_post)
|
||||
])
|
||||
post_img = norm_post(np.array(post_img).astype(dtype='float64')/255.0)
|
||||
else:
|
||||
pre_img = np.array(transforms.ToTensor()(pre_img)).astype(dtype='float64')/255.0
|
||||
post_img = np.array(transforms.ToTensor()(post_img)).astype(dtype='float64')/255.0
|
||||
|
||||
# convert eveything to arrays
|
||||
pre_img = np.array(pre_img)
|
||||
post_img = np.array(post_img)
|
||||
mask = np.array(mask)
|
||||
damage_class = np.array(damage_class)
|
||||
|
||||
# replace non-classified pixels with background
|
||||
damage_class = np.where(damage_class==5, 0, damage_class)
|
||||
|
||||
return {'pre_image': torch.from_numpy(pre_img).type(torch.FloatTensor), 'post_image': torch.from_numpy(post_img).type(torch.FloatTensor), 'building_mask': torch.from_numpy(mask).type(torch.LongTensor), 'damage_mask': torch.from_numpy(damage_class).type(torch.LongTensor), 'pre_image_orig': transforms.ToTensor()(pre_img_orig), 'post_image_orig': transforms.ToTensor()(post_img_orig)}
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import json
|
||||
|
||||
def load_json_files(json_filename):
|
||||
with open(json_filename) as f:
|
||||
file_content = json.load(f)
|
||||
return file_content
|
||||
|
||||
def dump_json_files(json_filename, my_dict):
|
||||
with open(json_filename, 'w') as f:
|
||||
json.dump(my_dict, f, indent=4)
|
||||
return
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value
|
||||
https://github.com/pytorch/examples/blob/master/imagenet/main.py
|
||||
"""
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
"""
|
||||
|
||||
Args:
|
||||
val: mini-batch loss or accuracy value
|
||||
n: mini-batch size
|
||||
"""
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|