Changes to MNIST Plots
This commit is contained in:
Родитель
142da80d36
Коммит
30a4aca36e
|
@ -2859,9 +2859,9 @@
|
|||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "drwils_env",
|
||||
"display_name": "matchdg-env",
|
||||
"language": "python",
|
||||
"name": "drwils_env"
|
||||
"name": "matchdg-env"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
|
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -61,7 +61,7 @@ dataset=sys.argv[1]
|
|||
test_case=['test_diff']
|
||||
|
||||
matplotlib.rcParams.update({'errorbar.capsize': 2})
|
||||
fig, ax = plt.subplots(1, 4, figsize=(33, 8))
|
||||
fig, ax = plt.subplots(1, 3, figsize=(33, 8))
|
||||
fontsize=35
|
||||
fontsize_lgd= fontsize/1.2
|
||||
x=['ERM', 'Rand', 'MatchDG', 'CSD', 'IRM', 'Perf']
|
||||
|
@ -70,7 +70,7 @@ methods=['erm', 'rand', 'matchdg', 'csd', 'irm', 'perf']
|
|||
|
||||
metrics= ['acc:train', 'acc:test', 'privacy_loss_attack', 'match_score:test']
|
||||
|
||||
for idx in range(4):
|
||||
for idx in range(3):
|
||||
|
||||
marker_list = ['o', '^', '*']
|
||||
legend_count = 0
|
||||
|
|
|
@ -87,11 +87,11 @@ for test_domain in [0.2, 0.9]:
|
|||
ax[count].set_xticklabels(x, rotation=25)
|
||||
|
||||
ax[count].errorbar(x, acc, yerr=acc_err, marker= marker_list[0], markersize= fontsize_lgd, linewidth=4, fmt='o--', label='OOD Acc')
|
||||
ax[count].errorbar(x, s_auc, yerr=s_auc_err, marker= marker_list[1], markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Linear-RAUC')
|
||||
ax[count].errorbar(x, loss, yerr=loss_err, marker= marker_list[2], markersize= fontsize_lgd, linewidth=4, label='Loss Attack', fmt='o--')
|
||||
ax[count].errorbar(x, s_auc, yerr=s_auc_err, marker= marker_list[1], markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Stable Features (Linear-RAUC)')
|
||||
ax[count].errorbar(x, loss, yerr=loss_err, marker= marker_list[2], markersize= fontsize_lgd, linewidth=4, label='MI Attack Acc', fmt='o--')
|
||||
|
||||
# gen_gap= np.array(train_acc) - np.array(acc)
|
||||
# ax[count].errorbar(x, gen_gap, yerr=0*gen_gap, marker= 's', markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Gen Gap')
|
||||
# ax[count].errorbar(x, gen_gap, yerr=0*gen_gap, marker= 's', markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Generalization Gap')
|
||||
|
||||
ax[count].set_ylabel('Metric Score', fontsize=fontsize)
|
||||
ax[count].set_title('Test Domain: ' + str(test_domain), fontsize=fontsize)
|
||||
|
@ -99,7 +99,7 @@ for test_domain in [0.2, 0.9]:
|
|||
count+=1
|
||||
|
||||
lines, labels = fig.axes[-1].get_legend_handles_labels()
|
||||
lgd= fig.legend(lines, labels, loc="lower center", bbox_to_anchor=(0.5, -0.15), fontsize=fontsize, ncol=3)
|
||||
lgd= fig.legend(lines, labels, loc="lower center", bbox_to_anchor=(0.5, -0.15), fontsize=fontsize, ncol=4)
|
||||
|
||||
save_dir= 'results/slab/plots/'
|
||||
if not os.path.exists(save_dir):
|
||||
|
|
Загрузка…
Ссылка в новой задаче