From 4d5e07ce372361dda4cd1ae5b39a730358843b29 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 6 Sep 2023 16:05:02 -0500 Subject: [PATCH] Add script to plot example predictions (#1534) --- .../ssl4eo/plot_example_predictions.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100755 experiments/ssl4eo/plot_example_predictions.py diff --git a/experiments/ssl4eo/plot_example_predictions.py b/experiments/ssl4eo/plot_example_predictions.py new file mode 100755 index 000000000..54e79e0cd --- /dev/null +++ b/experiments/ssl4eo/plot_example_predictions.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +# import matplotlib.pyplot as plt +import segmentation_models_pytorch as smp +import torch +from PIL import Image + +from torchgeo.datamodules import L7IrishDataModule +from torchgeo.datasets import unbind_samples + +device = torch.device("cpu") + +# Load weights +path = "data/l7irish/checkpoint-epoch=26-val_loss=0.68.ckpt" +state_dict = torch.load(path, map_location=device)["state_dict"] +state_dict = {key.replace("model.", ""): value for key, value in state_dict.items()} + +# Initialize model +model = smp.Unet(encoder_name="resnet18", in_channels=9, classes=5) +model.to(device) +model.load_state_dict(state_dict) + +# Initialize data loaders +datamodule = L7IrishDataModule( + root="data/l7irish", crs="epsg:3857", download=True, batch_size=1, patch_size=224 +) +datamodule.setup("test") + +i = 0 +for batch in datamodule.test_dataloader(): + image = batch["image"] + mask = batch["mask"] + image.to(device) + + # Skip nodata pixels + if 0 in mask: + continue + + # Skip boring images + if len(mask.unique()) < 4: + continue + + # Make a prediction + prediction = model(image) + prediction = prediction.argmax(dim=1) + prediction.detach().to("cpu") + + batch["prediction"] = prediction + + for sample in unbind_samples(batch): + # Plot + # datamodule.test_dataset.plot(sample) + # plt.show() + path = f"data/l7irish_predictions/{i}" + print(f"Saving {path}...") + os.makedirs(path, exist_ok=True) + for key in ["image", "mask", "prediction"]: + data = sample[key] + if key == "image": + data = data[[2, 1, 0]].permute(1, 2, 0).numpy().astype("uint8") + Image.fromarray(data, "RGB").save(f"{path}/{key}.png") + else: + data = data * 255 / 4 + data = data.numpy().astype("uint8").squeeze() + Image.fromarray(data, "L").save(f"{path}/{key}.png") + i += 1