This commit is contained in:
divyat09 2021-10-04 11:55:27 +00:00
Родитель 142da80d36
Коммит 30a4aca36e
4 изменённых файлов: 46 добавлений и 71 удалений

Просмотреть файл

@@ -2859,9 +2859,9 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "drwils_env", "display_name": "matchdg-env",
"language": "python", "language": "python",
"name": "drwils_env" "name": "matchdg-env"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@@ -61,7 +61,7 @@ dataset=sys.argv[1]
test_case=['test_diff'] test_case=['test_diff']
matplotlib.rcParams.update({'errorbar.capsize': 2}) matplotlib.rcParams.update({'errorbar.capsize': 2})
fig, ax = plt.subplots(1, 4, figsize=(33, 8)) fig, ax = plt.subplots(1, 3, figsize=(33, 8))
fontsize=35 fontsize=35
fontsize_lgd= fontsize/1.2 fontsize_lgd= fontsize/1.2
x=['ERM', 'Rand', 'MatchDG', 'CSD', 'IRM', 'Perf'] x=['ERM', 'Rand', 'MatchDG', 'CSD', 'IRM', 'Perf']
@@ -70,7 +70,7 @@ methods=['erm', 'rand', 'matchdg', 'csd', 'irm', 'perf']
metrics= ['acc:train', 'acc:test', 'privacy_loss_attack', 'match_score:test'] metrics= ['acc:train', 'acc:test', 'privacy_loss_attack', 'match_score:test']
for idx in range(4): for idx in range(3):
marker_list = ['o', '^', '*'] marker_list = ['o', '^', '*']
legend_count = 0 legend_count = 0

Просмотреть файл

@@ -87,11 +87,11 @@ for test_domain in [0.2, 0.9]:
ax[count].set_xticklabels(x, rotation=25) ax[count].set_xticklabels(x, rotation=25)
ax[count].errorbar(x, acc, yerr=acc_err, marker= marker_list[0], markersize= fontsize_lgd, linewidth=4, fmt='o--', label='OOD Acc') ax[count].errorbar(x, acc, yerr=acc_err, marker= marker_list[0], markersize= fontsize_lgd, linewidth=4, fmt='o--', label='OOD Acc')
ax[count].errorbar(x, s_auc, yerr=s_auc_err, marker= marker_list[1], markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Linear-RAUC') ax[count].errorbar(x, s_auc, yerr=s_auc_err, marker= marker_list[1], markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Stable Features (Linear-RAUC)')
ax[count].errorbar(x, loss, yerr=loss_err, marker= marker_list[2], markersize= fontsize_lgd, linewidth=4, label='Loss Attack', fmt='o--') ax[count].errorbar(x, loss, yerr=loss_err, marker= marker_list[2], markersize= fontsize_lgd, linewidth=4, label='MI Attack Acc', fmt='o--')
# gen_gap= np.array(train_acc) - np.array(acc) # gen_gap= np.array(train_acc) - np.array(acc)
# ax[count].errorbar(x, gen_gap, yerr=0*gen_gap, marker= 's', markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Gen Gap') # ax[count].errorbar(x, gen_gap, yerr=0*gen_gap, marker= 's', markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Generalization Gap')
ax[count].set_ylabel('Metric Score', fontsize=fontsize) ax[count].set_ylabel('Metric Score', fontsize=fontsize)
ax[count].set_title('Test Domain: ' + str(test_domain), fontsize=fontsize) ax[count].set_title('Test Domain: ' + str(test_domain), fontsize=fontsize)
@@ -99,7 +99,7 @@ for test_domain in [0.2, 0.9]:
count+=1 count+=1
lines, labels = fig.axes[-1].get_legend_handles_labels() lines, labels = fig.axes[-1].get_legend_handles_labels()
lgd= fig.legend(lines, labels, loc="lower center", bbox_to_anchor=(0.5, -0.15), fontsize=fontsize, ncol=3) lgd= fig.legend(lines, labels, loc="lower center", bbox_to_anchor=(0.5, -0.15), fontsize=fontsize, ncol=4)
save_dir= 'results/slab/plots/' save_dir= 'results/slab/plots/'
if not os.path.exists(save_dir): if not os.path.exists(save_dir):