зеркало из https://github.com/microsoft/torchgeo.git
Add scripts to generate plots from paper (#186)
* Add scripts to generate plots from paper * Style fix
This commit is contained in:
Родитель
b578cb73a2
Коммит
5b77ded218
|
@ -0,0 +1,77 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
|
||||
df1 = pd.read_csv("original-benchmark-results.csv")
|
||||
df2 = pd.read_csv("warped-benchmark-results.csv")
|
||||
|
||||
mean1 = df1.groupby("sampler").mean()
|
||||
mean2 = df2.groupby("sampler").mean()
|
||||
|
||||
cached1 = (
|
||||
df1[(df1["cached"]) & (df1["sampler"] != "resnet18")].groupby("sampler").mean()
|
||||
)
|
||||
cached2 = (
|
||||
df2[(df2["cached"]) & (df2["sampler"] != "resnet18")].groupby("sampler").mean()
|
||||
)
|
||||
not_cached1 = (
|
||||
df1[(~df1["cached"]) & (df1["sampler"] != "resnet18")].groupby("sampler").mean()
|
||||
)
|
||||
not_cached2 = (
|
||||
df2[(~df2["cached"]) & (df2["sampler"] != "resnet18")].groupby("sampler").mean()
|
||||
)
|
||||
|
||||
print("cached, original\n", cached1)
|
||||
print("cached, warped\n", cached2)
|
||||
print("not cached, original\n", not_cached1)
|
||||
print("not cached, warped\n", not_cached2)
|
||||
|
||||
cmap = sns.color_palette()
|
||||
|
||||
labels = ["GridGeoSampler", "RandomBatchGeoSampler", "RandomGeoSampler"]
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
x = np.arange(3)
|
||||
width = 0.2
|
||||
|
||||
rects1 = ax.bar(
|
||||
x - width * 3 / 2,
|
||||
not_cached1["rate"],
|
||||
width,
|
||||
label="Raw Data, Not Cached",
|
||||
color=cmap[0],
|
||||
)
|
||||
rects2 = ax.bar(
|
||||
x - width * 1 / 2,
|
||||
not_cached2["rate"],
|
||||
width,
|
||||
label="Preprocessed, Not Cached",
|
||||
color=cmap[1],
|
||||
)
|
||||
rects2 = ax.bar(
|
||||
x + width * 1 / 2, cached1["rate"], width, label="Raw Data, Cached", color=cmap[2]
|
||||
)
|
||||
rects3 = ax.bar(
|
||||
x + width * 3 / 2,
|
||||
cached2["rate"],
|
||||
width,
|
||||
label="Preprocessed, Cached",
|
||||
color=cmap[3],
|
||||
)
|
||||
|
||||
ax.set_ylabel("sampling rate (patches/sec)", fontsize=12)
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(labels, fontsize=12)
|
||||
ax.tick_params(axis="x", labelrotation=10)
|
||||
ax.legend(fontsize="large")
|
||||
|
||||
plt.gca().spines.right.set_visible(False)
|
||||
plt.gca().spines.top.set_visible(False)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
|
@ -0,0 +1,43 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
|
||||
df = pd.read_csv("warped-benchmark-results.csv")
|
||||
|
||||
random_cached = df[(df["sampler"] == "RandomGeoSampler") & (df["cached"])]
|
||||
random_batch_cached = df[(df["sampler"] == "RandomBatchGeoSampler") & (df["cached"])]
|
||||
grid_cached = df[(df["sampler"] == "GridGeoSampler") & (df["cached"])]
|
||||
other = [
|
||||
("RandomGeoSampler", random_cached),
|
||||
("RandomBatchGeoSampler", random_batch_cached),
|
||||
("GridGeoSampler", grid_cached),
|
||||
]
|
||||
|
||||
cmap = sns.color_palette()
|
||||
|
||||
ax = plt.gca()
|
||||
|
||||
for i, (label, df) in enumerate(other):
|
||||
df = df.groupby("batch_size")
|
||||
ax.plot(df.mean().index, df.mean()["rate"], color=cmap[i], label=label)
|
||||
ax.fill_between(
|
||||
df.mean().index, df.min()["rate"], df.max()["rate"], color=cmap[i], alpha=0.2
|
||||
)
|
||||
|
||||
|
||||
ax.set_xscale("log")
|
||||
ax.set_xticks([16, 32, 64, 128, 256])
|
||||
ax.set_xticklabels([16, 32, 64, 128, 256], fontsize=12)
|
||||
ax.set_xlabel("batch size", fontsize=12)
|
||||
ax.set_ylabel("sampling rate (patches/sec)", fontsize=12)
|
||||
ax.legend(loc="center right", fontsize="large")
|
||||
|
||||
plt.gca().spines.right.set_visible(False)
|
||||
plt.gca().spines.top.set_visible(False)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
|
@ -0,0 +1,56 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
|
||||
df1 = pd.read_csv("original-benchmark-results.csv")
|
||||
df2 = pd.read_csv("warped-benchmark-results.csv")
|
||||
|
||||
random_cached1 = df1[(df1["sampler"] == "RandomGeoSampler") & (df1["cached"])]
|
||||
random_cached2 = df2[(df2["sampler"] == "RandomGeoSampler") & (df2["cached"])]
|
||||
random_cachedp = random_cached1
|
||||
random_cachedp["rate"] /= random_cached2["rate"]
|
||||
|
||||
random_batch_cached1 = df1[
|
||||
(df1["sampler"] == "RandomBatchGeoSampler") & (df1["cached"])
|
||||
]
|
||||
random_batch_cached2 = df2[
|
||||
(df2["sampler"] == "RandomBatchGeoSampler") & (df2["cached"])
|
||||
]
|
||||
random_batch_cachedp = random_batch_cached1
|
||||
random_batch_cachedp["rate"] /= random_batch_cached2["rate"]
|
||||
|
||||
grid_cached1 = df1[(df1["sampler"] == "GridGeoSampler") & (df1["cached"])]
|
||||
grid_cached2 = df2[(df2["sampler"] == "GridGeoSampler") & (df2["cached"])]
|
||||
grid_cachedp = grid_cached1
|
||||
grid_cachedp["rate"] /= grid_cached2["rate"]
|
||||
|
||||
other = [
|
||||
("RandomGeoSampler (cached)", random_cachedp),
|
||||
("RandomBatchGeoSampler (cached)", random_batch_cachedp),
|
||||
("GridGeoSampler (cached)", grid_cachedp),
|
||||
]
|
||||
|
||||
cmap = sns.color_palette()
|
||||
|
||||
ax = plt.gca()
|
||||
|
||||
for i, (label, df) in enumerate(other):
|
||||
df = df.groupby("batch_size")
|
||||
ax.plot([16, 32, 64, 128, 256], df.mean()["rate"], color=cmap[i], label=label)
|
||||
ax.fill_between(
|
||||
df.mean().index, df.min()["rate"], df.max()["rate"], color=cmap[i], alpha=0.2
|
||||
)
|
||||
|
||||
|
||||
ax.set_xscale("log")
|
||||
ax.set_xticks([16, 32, 64, 128, 256])
|
||||
ax.set_xticklabels([16, 32, 64, 128, 256])
|
||||
ax.set_xlabel("batch size")
|
||||
ax.set_ylabel("% sampling rate (patches/sec)")
|
||||
ax.legend()
|
||||
plt.show()
|
Загрузка…
Ссылка в новой задаче