Changes to MNIST Plots
This commit is contained in:
Родитель
142da80d36
Коммит
30a4aca36e
|
@ -2859,9 +2859,9 @@
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "drwils_env",
|
"display_name": "matchdg-env",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "drwils_env"
|
"name": "matchdg-env"
|
||||||
},
|
},
|
||||||
"language_info": {
|
"language_info": {
|
||||||
"codemirror_mode": {
|
"codemirror_mode": {
|
||||||
|
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -61,7 +61,7 @@ dataset=sys.argv[1]
|
||||||
test_case=['test_diff']
|
test_case=['test_diff']
|
||||||
|
|
||||||
matplotlib.rcParams.update({'errorbar.capsize': 2})
|
matplotlib.rcParams.update({'errorbar.capsize': 2})
|
||||||
fig, ax = plt.subplots(1, 4, figsize=(33, 8))
|
fig, ax = plt.subplots(1, 3, figsize=(33, 8))
|
||||||
fontsize=35
|
fontsize=35
|
||||||
fontsize_lgd= fontsize/1.2
|
fontsize_lgd= fontsize/1.2
|
||||||
x=['ERM', 'Rand', 'MatchDG', 'CSD', 'IRM', 'Perf']
|
x=['ERM', 'Rand', 'MatchDG', 'CSD', 'IRM', 'Perf']
|
||||||
|
@ -70,7 +70,7 @@ methods=['erm', 'rand', 'matchdg', 'csd', 'irm', 'perf']
|
||||||
|
|
||||||
metrics= ['acc:train', 'acc:test', 'privacy_loss_attack', 'match_score:test']
|
metrics= ['acc:train', 'acc:test', 'privacy_loss_attack', 'match_score:test']
|
||||||
|
|
||||||
for idx in range(4):
|
for idx in range(3):
|
||||||
|
|
||||||
marker_list = ['o', '^', '*']
|
marker_list = ['o', '^', '*']
|
||||||
legend_count = 0
|
legend_count = 0
|
||||||
|
|
|
@ -87,11 +87,11 @@ for test_domain in [0.2, 0.9]:
|
||||||
ax[count].set_xticklabels(x, rotation=25)
|
ax[count].set_xticklabels(x, rotation=25)
|
||||||
|
|
||||||
ax[count].errorbar(x, acc, yerr=acc_err, marker= marker_list[0], markersize= fontsize_lgd, linewidth=4, fmt='o--', label='OOD Acc')
|
ax[count].errorbar(x, acc, yerr=acc_err, marker= marker_list[0], markersize= fontsize_lgd, linewidth=4, fmt='o--', label='OOD Acc')
|
||||||
ax[count].errorbar(x, s_auc, yerr=s_auc_err, marker= marker_list[1], markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Linear-RAUC')
|
ax[count].errorbar(x, s_auc, yerr=s_auc_err, marker= marker_list[1], markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Stable Features (Linear-RAUC)')
|
||||||
ax[count].errorbar(x, loss, yerr=loss_err, marker= marker_list[2], markersize= fontsize_lgd, linewidth=4, label='Loss Attack', fmt='o--')
|
ax[count].errorbar(x, loss, yerr=loss_err, marker= marker_list[2], markersize= fontsize_lgd, linewidth=4, label='MI Attack Acc', fmt='o--')
|
||||||
|
|
||||||
# gen_gap= np.array(train_acc) - np.array(acc)
|
# gen_gap= np.array(train_acc) - np.array(acc)
|
||||||
# ax[count].errorbar(x, gen_gap, yerr=0*gen_gap, marker= 's', markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Gen Gap')
|
# ax[count].errorbar(x, gen_gap, yerr=0*gen_gap, marker= 's', markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Generalization Gap')
|
||||||
|
|
||||||
ax[count].set_ylabel('Metric Score', fontsize=fontsize)
|
ax[count].set_ylabel('Metric Score', fontsize=fontsize)
|
||||||
ax[count].set_title('Test Domain: ' + str(test_domain), fontsize=fontsize)
|
ax[count].set_title('Test Domain: ' + str(test_domain), fontsize=fontsize)
|
||||||
|
@ -99,7 +99,7 @@ for test_domain in [0.2, 0.9]:
|
||||||
count+=1
|
count+=1
|
||||||
|
|
||||||
lines, labels = fig.axes[-1].get_legend_handles_labels()
|
lines, labels = fig.axes[-1].get_legend_handles_labels()
|
||||||
lgd= fig.legend(lines, labels, loc="lower center", bbox_to_anchor=(0.5, -0.15), fontsize=fontsize, ncol=3)
|
lgd= fig.legend(lines, labels, loc="lower center", bbox_to_anchor=(0.5, -0.15), fontsize=fontsize, ncol=4)
|
||||||
|
|
||||||
save_dir= 'results/slab/plots/'
|
save_dir= 'results/slab/plots/'
|
||||||
if not os.path.exists(save_dir):
|
if not os.path.exists(save_dir):
|
||||||
|
|
Загрузка…
Ссылка в новой задаче