ChestXray reproduce scripts
This commit is contained in:
Родитель
a3b709d61f
Коммит
ed82a73abd
|
@ -48,7 +48,7 @@ class BaseAlgo():
|
|||
self.val_acc=[]
|
||||
self.train_acc=[]
|
||||
|
||||
if self.args.dp_noise:
|
||||
if self.args.method_name == 'dp_erm':
|
||||
self.privacy_engine= self.get_dp_noise()
|
||||
|
||||
def get_model(self):
|
||||
|
@ -206,8 +206,8 @@ class BaseAlgo():
|
|||
from opacus.dp_model_inspector import DPModelInspector
|
||||
from opacus.utils import module_modification
|
||||
|
||||
inspector = DPModelInspector()
|
||||
|
||||
inspector = DPModelInspector()
|
||||
print(self.phi)
|
||||
self.phi = module_modification.convert_batchnorm_modules(self.phi)
|
||||
print(self.phi)
|
||||
inspector.validate(self.phi)
|
||||
|
|
|
@ -0,0 +1,145 @@
|
|||
import os
|
||||
import sys
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
def get_base_dir(test_domain, dataset, metric, data_case=None):
    """Return the results directory for one metric of one test domain.

    The path has the form ``results/<dataset>/<metric dir>/<test_domain>/``.

    Args:
        test_domain: Name of the held-out domain; used as the last path
            component.
        dataset: Dataset name; used as the first component under ``results/``.
        metric: One of ``'mia'``, ``'privacy_entropy'``,
            ``'privacy_loss_attack'``, ``'attribute_attack'``, ``'acc:train'``,
            ``'acc:test'``, ``'match_score:train'``, ``'match_score:test'``.
        data_case: Suffix of the ``attribute_attack_<data_case>`` directory.
            Only consulted (and then required) for ``'attribute_attack'``.

    Returns:
        The directory path as a string, with a trailing ``'/'``.

    Raises:
        ValueError: If ``metric`` is not recognised, or if ``metric`` is
            ``'attribute_attack'`` and ``data_case`` was not supplied.
    """
    # Metric name -> directory name under results/<dataset>/.
    metric_dirs = {
        'mia': 'privacy',
        'privacy_entropy': 'privacy_entropy',
        'privacy_loss_attack': 'privacy_loss',
        'acc:train': 'acc_' + 'train',
        'acc:test': 'acc_' + 'test',
        'match_score:train': 'match_score_' + 'train',
        'match_score:test': 'match_score_' + 'test',
    }

    if metric == 'attribute_attack':
        # This branch previously read an undefined name `data_case`; it is now
        # an explicit argument so the failure is a clear error, not NameError.
        if data_case is None:
            raise ValueError(
                "data_case is required when metric == 'attribute_attack'"
            )
        metric_dir = 'attribute_attack_' + data_case
    elif metric in metric_dirs:
        metric_dir = metric_dirs[metric]
    else:
        # Previously fell through to `return res_dir` with res_dir unbound.
        raise ValueError('Unknown metric: ' + repr(metric))

    return 'results/' + str(dataset) + '/' + metric_dir + '/' + str(test_domain) + '/'
|
||||
|
||||
|
||||
# Usage: <this script> <dataset> <test_domain>
#
# Reads one result file per (metric, method) from the directory returned by
# get_base_dir() and writes four PDF plots to results/<dataset>/plots/<test_domain>/.

# Dataset name as used in the results/ layout
# (original note: rot_mnist, fashion_mnist, rot_mnist_spur).
dataset = sys.argv[1]

# Test domain (original note: kaggle, nih, chex).
test_domain = sys.argv[2]

# Display labels and the matching result-file stems; the two lists are index-aligned.
x = ['ERM', 'Rand', 'MatchDG', 'CSD', 'IRM', 'Hybrid']
methods = ['erm', 'rand', 'matchdg_erm', 'csd', 'irm', 'hybrid']

# Metrics actually loaded. 'mia' and 'match_score:train' are supported by the
# table below but are not loaded (and not plotted) in this configuration.
metrics = ['acc:train', 'acc:test', 'privacy_entropy', 'privacy_loss_attack', 'match_score:test']

acc_train, acc_train_err = [], []
acc_test, acc_test_err = [], []
mia, mia_err = [], []
entropy, entropy_err = [], []
loss, loss_err = [], []
rank_train, rank_train_err = [], []
rank_test, rank_test_err = [], []

# Metric name -> (list of means, list of standard deviations), one entry per method.
series_by_metric = {
    'acc:train': (acc_train, acc_train_err),
    'acc:test': (acc_test, acc_test_err),
    'mia': (mia, mia_err),
    'privacy_entropy': (entropy, entropy_err),
    'privacy_loss_attack': (loss, loss_err),
    'match_score:train': (rank_train, rank_train_err),
    'match_score:test': (rank_test, rank_test_err),
}

for metric in metrics:
    res_dir = get_base_dir(test_domain, dataset, metric)
    means, errs = series_by_metric[metric]

    for method in methods:
        # Context manager so the file handle is released (it was leaked before).
        with open(res_dir + method + '.txt') as f:
            data = f.readlines()

        # The summary is taken from the third-to-last line: after the final
        # ':' the last two space-separated tokens are the mean and the sd.
        summary = data[-3].replace('\n', '').split(':')[-1].split(' ')
        means.append(float(summary[-2]))
        errs.append(float(summary[-1]))

# Loop-invariant plot settings and output location.
matplotlib.rcParams.update({'errorbar.capsize': 2})
fontsize = 30
fontsize_lgd = fontsize / 1.2

save_dir = 'results/' + dataset + '/plots/' + test_domain + '/'
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(save_dir, exist_ok=True)

for idx in range(4):

    fig, ax = plt.subplots(1, 1, figsize=(10, 8))
    ax.tick_params(labelsize=fontsize)
    ax.set_xticklabels(x, rotation=25)

    if idx == 0:
        # Train vs. test accuracy per method.
        ax.errorbar(x, acc_train, yerr=acc_train_err, label='Train Accuracy', fmt='o--')
        ax.errorbar(x, acc_test, yerr=acc_test_err, label='Test Accuracy', fmt='o--')
        ax.set_ylabel('OOD Accuracy of ML Model', fontsize=fontsize)
        ax.legend(fontsize=fontsize_lgd)

    if idx == 1:
        # Membership-inference attack accuracies (classifier attack not plotted).
        ax.errorbar(x, entropy, yerr=entropy_err, label='Entropy Attack', color='red', fmt='o--')
        ax.errorbar(x, loss, yerr=loss_err, label='Loss Attack', color='orange', fmt='o--')
        ax.set_ylabel('MI Attack Accuracy', fontsize=fontsize)
        ax.legend(fontsize=fontsize_lgd)

    if idx == 2:
        # Match-score results on the test split only.
        ax.errorbar(x, rank_test, yerr=rank_test_err, label='Test', fmt='o--', color='green')
        ax.set_ylabel('Mean Rank of Perfect Match', fontsize=fontsize)
        ax.legend(fontsize=fontsize_lgd)

    if idx == 3:
        # Train-minus-test accuracy gap.
        # NOTE(review): the error bars use only the train sd and the legend
        # label still reads 'Train Accuracy' -- confirm both are intended.
        ax.errorbar(x, np.array(acc_train) - np.array(acc_test), yerr=acc_train_err, label='Train Accuracy', fmt='o--')
        ax.set_ylabel('Train-Test Accuracy Gap of ML Model', fontsize=fontsize)
        ax.legend(fontsize=fontsize_lgd)

    plt.tight_layout()
    plt.savefig(save_dir + 'privacy_' + str(dataset) + '_' + str(idx) + '.pdf', dpi=600)
    # Release the figure once written; it is not used again.
    plt.close(fig)
|
|
@ -0,0 +1,83 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
# Usage: <this script> <test_domain> <metric> [data_case]
#
# Builds one train.py / test.py command line per method for the ChestXray
# dataset and runs it through the shell, redirecting stdout to
# results/chestxray/<metric dir>/<test_domain>/<method>.txt.

methods = ['erm', 'irm', 'csd', 'rand', 'matchdg_ctr', 'matchdg_erm', 'hybrid']
domains = ['nih', 'chex', 'kaggle']
dataset = 'chestxray'

test_domain = sys.argv[1]
# Supported: train, mia, privacy_entropy, privacy_loss_attack, acc,
# match_score, feat_eval, feat_eval_rand
metric = sys.argv[2]
if metric in ['acc', 'match_score', 'feat_eval', 'feat_eval_rand', 'attribute_attack']:
    data_case = sys.argv[3]

# Arguments shared by every command (trailing space is part of the string).
common_args = '--dataset chestxray --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 '
test_script = 'python test.py ' + common_args

if metric == 'train':
    base_script = 'python train.py ' + common_args
    res_dir = 'results/' + str(dataset) + '/train_logs/'

elif metric == 'mia':
    base_script = test_script + '--test_metric mia --mia_sample_size 200 --mia_logit 1 --batch_size 64 '
    res_dir = 'results/' + str(dataset) + '/privacy/'

elif metric == 'privacy_entropy':
    base_script = test_script + '--test_metric privacy_entropy --mia_sample_size 200 --batch_size 64 '
    res_dir = 'results/' + str(dataset) + '/privacy_entropy/'

elif metric == 'privacy_loss_attack':
    base_script = test_script + '--test_metric privacy_loss_attack --mia_sample_size 200 --batch_size 64 '
    res_dir = 'results/' + str(dataset) + '/privacy_loss/'

elif metric == 'acc':
    base_script = test_script + '--test_metric acc ' + ' --acc_data_case ' + data_case
    res_dir = 'results/' + str(dataset) + '/acc_' + str(data_case) + '/'

elif metric == 'match_score':
    base_script = test_script + '--test_metric match_score --match_func_aug_case 1' + ' --match_func_data_case ' + data_case
    res_dir = 'results/' + str(dataset) + '/match_score_' + data_case + '/'

elif metric == 'feat_eval':
    base_script = test_script + '--test_metric feat_eval --match_func_aug_case 1' + ' --match_func_data_case ' + data_case
    res_dir = 'results/' + str(dataset) + '/feat_eval_' + data_case + '/'

elif metric == 'feat_eval_rand':
    base_script = test_script + '--test_metric feat_eval --match_func_aug_case 1' + ' --match_func_data_case ' + data_case + ' --match_case 0.0 '
    res_dir = 'results/' + str(dataset) + '/feat_eval_rand_' + data_case + '/'

else:
    # Previously an unsupported metric (including 'attribute_attack', which is
    # accepted by the data_case check above but has no command defined here)
    # crashed further down with a NameError on res_dir.
    raise ValueError('Unsupported metric: ' + repr(metric))

# Results are grouped per test domain.
res_dir = res_dir + test_domain + '/'
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(res_dir, exist_ok=True)

# Method name -> method-specific command-line arguments.
method_args = {
    'erm': ' --method_name erm_match --epochs 40 --lr 0.001 --match_case 0.0 --penalty_ws 0.0 ',
    'rand': ' --method_name erm_match --epochs 40 --lr 0.001 --match_case 0.0 --penalty_ws 10.0 ',
    'csd': ' --method_name csd --epochs 40 --lr 0.001 --match_case 0.0 --penalty_ws 0.0 --rep_dim 1024',
    'irm': ' --method_name irm --epochs 40 --lr 0.001 --match_case 0.0 --penalty_irm 10.0 --penalty_s 5',
    'matchdg_ctr': ' --method_name matchdg_ctr --epochs 50 --batch_size 32 --match_case 0.0 --match_flag 1 --pos_metric cos',
    'matchdg_erm': ' --method_name matchdg_erm --epochs 40 --lr 0.001 --match_case -1 --penalty_ws 50.0 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121',
    'hybrid': ' --method_name hybrid --epochs 40 --lr 0.001 --match_case -1 --penalty_ws 1.0 --penalty_aug 50.0 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121',
}

# Every domain except the test domain is a training domain, passed with a
# '_trans' suffix. Loop-invariant, so built once; each entry keeps its
# trailing space, exactly as the per-method loop produced it before.
train_domains = ''.join(str(d) + '_trans' + ' ' for d in domains if d != test_domain)

for method in methods:

    if method == 'hybrid':
        # NOTE(review): looks like leftover debug output -- confirm before removing.
        print('Yes')

    # A method missing from the table now fails with KeyError instead of
    # silently re-running the previous method's command.
    script = base_script + method_args[method]

    print('Method: ', method, ' Train Domains: ', train_domains, ' Test Domains: ', test_domain)
    script = script + ' --train_domains ' + train_domains + ' --test_domains ' + test_domain
    # Shell redirection: the command's stdout becomes the per-method result file.
    script = script + ' > ' + res_dir + str(method) + '.txt'

    status = os.system(script)
    if status != 0:
        # Best-effort batch run: report the failure but continue with the
        # remaining methods, as the original loop did.
        print('Warning: command for method', method, 'exited with status', status, file=sys.stderr)
|
|
@ -18,7 +18,7 @@ from .data_loader import BaseDataLoader
|
|||
class ChestXRayAug(BaseDataLoader):
|
||||
def __init__(self, args, list_domains, root, transform=None, data_case='train', match_func=False):
|
||||
|
||||
super().__init__(args, list_train_domains, root, transform, data_case, match_func)
|
||||
super().__init__(args, list_domains, root, transform, data_case, match_func)
|
||||
self.data, self.data_org, self.labels, self.domains, self.indices, self.objects = self._get_data()
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
@ -69,8 +69,8 @@ class ChestXRayAug(BaseDataLoader):
|
|||
base_class_size=0
|
||||
base_class_idx=-1
|
||||
for d_idx, domain in enumerate( self.list_domains ):
|
||||
class_idx= training_labels[d_idx] == y_c
|
||||
curr_class_size= training_labels[d_idx][class_idx].shape[0]
|
||||
class_idx= list_labels[d_idx] == y_c
|
||||
curr_class_size= list_labels[d_idx][class_idx].shape[0]
|
||||
if base_class_size < curr_class_size:
|
||||
base_class_size= curr_class_size
|
||||
base_class_idx= d_idx
|
||||
|
@ -92,7 +92,7 @@ class ChestXRayAug(BaseDataLoader):
|
|||
data_domains = torch.zeros(data_labels.size())
|
||||
domain_start=0
|
||||
for idx in range(len(self.list_domains)):
|
||||
curr_domain_size= self.training_size[idx]
|
||||
curr_domain_size= self.training_list_size[idx]
|
||||
data_domains[ domain_start: domain_start+ curr_domain_size ] += idx
|
||||
domain_start+= curr_domain_size
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import sys
|
|||
import os
|
||||
import random
|
||||
import copy
|
||||
import os
|
||||
|
||||
#Sklearn
|
||||
from scipy.stats import bernoulli
|
||||
|
@ -99,7 +100,13 @@ dataset= sys.argv[1]
|
|||
model= sys.argv[2]
|
||||
|
||||
#Generate Dataset for Rotated / Fashion MNIST
|
||||
base_dir= 'datasets/mnist/'
|
||||
#TODO: Manage OS Env from args
|
||||
os_env=0
|
||||
if os_env:
|
||||
base_dir= os.getenv('PT_DATA_DIR') + '/mnist/'
|
||||
else:
|
||||
base_dir= 'data/datasets/mnist/'
|
||||
|
||||
if not os.path.exists(base_dir):
|
||||
os.makedirs(base_dir)
|
||||
|
||||
|
@ -207,9 +214,14 @@ for seed in seed_list:
|
|||
|
||||
save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
|
||||
indices= res[:subset_size]
|
||||
if seed in [0, 1, 2]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
|
||||
|
||||
if model == 'resnet18':
|
||||
if seed in [0, 1, 2] and domain in [15, 30, 45, 60, 75]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
|
||||
elif model == 'lenet':
|
||||
if seed in [0, 1, 2] and domain in [0, 15, 30, 45, 60, 75]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
|
||||
|
||||
#Val
|
||||
data_case= 'val'
|
||||
if not os.path.exists(data_dir + data_case + '/'):
|
||||
|
@ -217,8 +229,13 @@ for seed in seed_list:
|
|||
|
||||
save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
|
||||
indices= res[subset_size:]
|
||||
if seed in [0, 1, 2]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
|
||||
|
||||
if model == 'resnet18':
|
||||
if seed in [0, 1, 2] and domain in [15, 30, 45, 60, 75]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
|
||||
elif model == 'lenet':
|
||||
if seed in [0, 1, 2] and domain in [0, 15, 30, 45, 60, 75]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
|
||||
|
||||
#Test
|
||||
data_case= 'test'
|
||||
|
@ -228,10 +245,9 @@ for seed in seed_list:
|
|||
save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
|
||||
indices= res[:subset_size]
|
||||
|
||||
if model== 'lenet':
|
||||
if seed in [0, 1, 2]:
|
||||
if model == 'resnet18':
|
||||
if seed in [9] and domain in [0, 90]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
|
||||
else:
|
||||
if seed in [9]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
|
||||
|
||||
elif model == 'lenet':
|
||||
if seed in [0, 1, 2] and domain in [0, 15, 30, 45, 60, 75]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
|
||||
|
|
|
@ -33,8 +33,10 @@ if not os.path.exists(base_dir):
|
|||
|
||||
num_samples= 10000
|
||||
total_slabs= 7
|
||||
slab_noise_list= [0.0, 0.05, 0.10, 0.15, 0.20, 0.25]
|
||||
spur_corr_list= [0.0, 0.05, 0.10, 0.15, 0.30, 0.50, 0.70, 0.90]
|
||||
# slab_noise_list= [0.0, 0.05, 0.10, 0.15, 0.20, 0.25]
|
||||
# spur_corr_list= [0.0, 0.05, 0.10, 0.15, 0.30, 0.50, 0.70, 0.90]
|
||||
slab_noise_list= [0.0, 0.10]
|
||||
spur_corr_list= [0.20, 0.90]
|
||||
|
||||
for seed in range(10):
|
||||
np.random.seed(seed*10)
|
||||
|
|
|
@ -2,6 +2,7 @@ import os
|
|||
import sys
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
def get_base_dir(train_case, test_case, dataset, metric):
|
||||
|
||||
|
@ -31,6 +32,9 @@ def get_base_dir(train_case, test_case, dataset, metric):
|
|||
|
||||
elif metric == 'feat_eval:train':
|
||||
res_dir= 'results/' + str(dataset) + '/feat_eval_' + 'train' + '/'
|
||||
|
||||
elif metric == 'feat_eval:test':
|
||||
res_dir= 'results/' + str(dataset) + '/feat_eval_' + 'test' + '/'
|
||||
|
||||
#Train Domains 30, 45 case
|
||||
if train_case == 'train_abl_2':
|
||||
|
@ -59,7 +63,9 @@ test_case=['test_diff']
|
|||
x=['ERM', 'Rand', 'MatchDG', 'CSD', 'IRM', 'Perf']
|
||||
methods=['erm', 'rand', 'matchdg', 'csd', 'irm', 'perf']
|
||||
|
||||
metrics= ['acc:train', 'acc:test', 'mia', 'privacy_entropy', 'privacy_loss_attack', 'match_score:train', 'match_score:test', 'feat_eval:train']
|
||||
# metrics= ['acc:train', 'acc:test', 'mia', 'privacy_entropy', 'privacy_loss_attack', 'match_score:train', 'match_score:test', 'feat_eval:train', 'feat_eval:test']
|
||||
|
||||
metrics= ['acc:train', 'acc:test', 'privacy_entropy', 'privacy_loss_attack', 'match_score:test']
|
||||
|
||||
acc_train=[]
|
||||
acc_train_err=[]
|
||||
|
@ -85,6 +91,9 @@ rank_test_err=[]
|
|||
feat_eval_train=[]
|
||||
feat_eval_train_err=[]
|
||||
|
||||
feat_eval_test=[]
|
||||
feat_eval_test_err=[]
|
||||
|
||||
for metric in metrics:
|
||||
for method in methods:
|
||||
|
||||
|
@ -120,6 +129,9 @@ for metric in metrics:
|
|||
elif metric == 'feat_eval:train':
|
||||
feat_eval_train.append(mean)
|
||||
feat_eval_train_err.append(sd)
|
||||
elif metric == 'feat_eval:test':
|
||||
feat_eval_test.append(mean)
|
||||
feat_eval_test_err.append(sd)
|
||||
|
||||
for idx in range(4):
|
||||
|
||||
|
@ -138,23 +150,32 @@ for idx in range(4):
|
|||
ax.legend(fontsize=fontsize_lgd)
|
||||
|
||||
if idx == 1:
|
||||
ax.errorbar(x, mia, yerr=mia_err, label='Classifier Attack', color='blue', fmt='o--')
|
||||
# ax.errorbar(x, mia, yerr=mia_err, label='Classifier Attack', color='blue', fmt='o--')
|
||||
ax.errorbar(x, entropy, yerr=entropy_err, label='Entropy Attack', color='red', fmt='o--')
|
||||
ax.errorbar(x, loss, yerr=loss_err, label='Loss Attack', color='orange', fmt='o--')
|
||||
ax.set_ylabel('MI Attack Accuracy', fontsize=fontsize)
|
||||
ax.legend(fontsize=fontsize_lgd)
|
||||
|
||||
if idx == 2:
|
||||
ax.errorbar(x, rank_train, yerr=rank_train_err, label='Train', fmt='o--', color='brown')
|
||||
# ax.errorbar(x, rank_train, yerr=rank_train_err, label='Train', fmt='o--', color='brown')
|
||||
ax.errorbar(x, rank_test, yerr=rank_test_err, label='Test', fmt='o--', color='green')
|
||||
# ax.set_xlabel('Models', fontsize=fontsize)
|
||||
ax.set_ylabel('Mean Rank of Perfect Match', fontsize=fontsize)
|
||||
ax.legend(fontsize=fontsize_lgd)
|
||||
|
||||
|
||||
if idx == 3:
|
||||
ax.errorbar(x, feat_eval_train, yerr=feat_eval_train_err, fmt='o--', color='brown')
|
||||
ax.errorbar(x, np.array(acc_train) - np.array(acc_test), yerr=acc_train_err, label='Train Accuracy', fmt='o--')
|
||||
# ax.set_xlabel('Models', fontsize=fontsize)
|
||||
ax.set_ylabel('Cosine Similarity of same object features', fontsize=fontsize)
|
||||
ax.set_ylabel('Train-Test Accuracy Gap of ML Model', fontsize=fontsize)
|
||||
ax.legend(fontsize=fontsize_lgd)
|
||||
|
||||
# if idx == 3:
|
||||
# ax.errorbar(x, feat_eval_train, yerr=feat_eval_train_err, label='Train', fmt='o--', color='brown')
|
||||
# ax.errorbar(x, feat_eval_test, yerr=feat_eval_test_err, label='Test', fmt='o--', color='brown')
|
||||
# # ax.set_xlabel('Models', fontsize=fontsize)
|
||||
# ax.set_ylabel('Cosine Similarity of same object features', fontsize=fontsize)
|
||||
# ax.legend(fontsize=fontsize_lgd)
|
||||
|
||||
|
||||
save_dir= 'results/' + dataset+ '/plots_' + train_case + '/'
|
||||
|
|
|
@ -7,14 +7,14 @@ dataset=sys.argv[1]
|
|||
train_case= sys.argv[2]
|
||||
# train, acc, mia, privacy_entropy, privacy_loss_attack, match_score, feat_eval, attribute_attack
|
||||
metric=sys.argv[3]
|
||||
if metric in ['acc', 'match_score', 'feat_eval', 'attribute_attack']:
|
||||
if metric in ['acc', 'match_score', 'feat_eval', 'feat_eval_rand', 'attribute_attack']:
|
||||
data_case= sys.argv[4]
|
||||
|
||||
# test_diff, test_common
|
||||
test_case=['test_diff']
|
||||
|
||||
methods=['erm', 'irm', 'csd', 'rand', 'perf', 'matchdg']
|
||||
# methods= ['matchdg']
|
||||
# methods=['erm', 'irm', 'csd', 'rand', 'matchdg']
|
||||
# methods=['approx_25', 'approx_50', 'approx_75']
|
||||
|
||||
if metric == 'train':
|
||||
|
@ -61,6 +61,10 @@ elif metric == 'match_score':
|
|||
elif metric == 'feat_eval':
|
||||
base_script= 'python test.py --test_metric feat_eval ' + ' --dataset ' + str(dataset) + ' --match_func_data_case ' + data_case
|
||||
res_dir= 'results/' + str(dataset) + '/feat_eval_' + data_case + '/'
|
||||
|
||||
elif metric == 'feat_eval_rand':
|
||||
base_script= 'python test.py --test_metric feat_eval ' + ' --dataset ' + str(dataset) + ' --match_func_data_case ' + data_case + ' --match_case 0.0 '
|
||||
res_dir= 'results/' + str(dataset) + '/feat_eval_rand_' + data_case + '/'
|
||||
|
||||
|
||||
#Train Domains 30, 45 case
|
||||
|
|
|
@ -11,7 +11,7 @@ if method == 'erm' or method == 'rand':
|
|||
base_script= 'python train.py --dataset pacs --method_name erm_match --match_case 0.0 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --weight_decay 0.001 '
|
||||
|
||||
elif method == 'matchdg_ctr':
|
||||
base_script= 'python train.py --dataset pacs --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --pos_metric cos --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --batch_size 64 '
|
||||
base_script= 'python train.py --dataset pacs --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --pos_metric cos --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --batch_size 32 '
|
||||
|
||||
elif method == 'matchdg_erm':
|
||||
base_script= 'python train.py --dataset pacs --method_name matchdg_erm --match_case -1 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --weight_decay 0.001 '
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
absl-py==0.12.0
|
||||
advertorch==0.2.3
|
||||
astor==0.8.1
|
||||
backcall==0.2.0
|
||||
cached-property==1.5.2
|
||||
certifi==2020.12.5
|
||||
chardet==4.0.0
|
||||
|
@ -13,12 +14,19 @@ h5py==3.2.1
|
|||
idna==2.10
|
||||
imageio==2.9.0
|
||||
importlib-metadata==3.10.0
|
||||
ipykernel==5.5.5
|
||||
ipython==7.24.0
|
||||
ipython-genutils==0.2.0
|
||||
jedi==0.18.0
|
||||
joblib==1.0.1
|
||||
jupyter-client==6.1.12
|
||||
jupyter-core==4.7.1
|
||||
Keras-Applications==1.0.8
|
||||
Keras-Preprocessing==1.1.2
|
||||
kiwisolver==1.3.1
|
||||
Markdown==3.3.4
|
||||
matplotlib==3.4.1
|
||||
matplotlib-inline==0.1.2
|
||||
mia==0.1.2
|
||||
more-itertools==8.8.0
|
||||
networkx==2.5
|
||||
|
@ -26,13 +34,20 @@ numpy==1.20.2
|
|||
opacus==0.13.0
|
||||
opt-einsum==3.3.0
|
||||
pandas==1.2.3
|
||||
parso==0.8.2
|
||||
pexpect==4.8.0
|
||||
pickleshare==0.7.5
|
||||
Pillow==8.1.2
|
||||
prompt-toolkit==3.0.18
|
||||
protobuf==3.15.6
|
||||
ptyprocess==0.7.0
|
||||
pydicom==2.1.2
|
||||
Pygments==2.9.0
|
||||
pyparsing==2.4.7
|
||||
python-dateutil==2.8.1
|
||||
pytz==2021.1
|
||||
PyWavelets==1.1.1
|
||||
pyzmq==22.1.0
|
||||
requests==2.25.1
|
||||
scikit-image==0.18.1
|
||||
scikit-learn==0.24.1
|
||||
|
@ -48,9 +63,12 @@ tifffile==2021.3.17
|
|||
torch==1.8.1
|
||||
torchvision==0.9.1
|
||||
torchxrayvision==0.0.24
|
||||
tornado==6.1
|
||||
tqdm==4.59.0
|
||||
traitlets==5.0.5
|
||||
typing-extensions==3.7.4.3
|
||||
urllib3==1.26.4
|
||||
wcwidth==0.2.5
|
||||
Werkzeug==1.0.1
|
||||
wrapt==1.12.1
|
||||
zipp==3.4.1
|
||||
|
|
|
@ -40,7 +40,7 @@ attribute=[]
|
|||
attribute_err=[]
|
||||
|
||||
# eval_metrics= ['mi', 'entropy', 'loss', 'attribute']
|
||||
eval_metrics= ['mi', 'entropy', 'loss']
|
||||
eval_metrics= ['entropy', 'loss']
|
||||
for metric in eval_metrics:
|
||||
for method in methods:
|
||||
f= open(base_dir+method+'-'+metric+ '-' + str(test_domain)+'.txt')
|
||||
|
@ -74,16 +74,16 @@ for idx in range(0, 2):
|
|||
if idx == 0:
|
||||
# ax.errorbar(x, train_acc, yerr=train_acc_err, fmt='o--', label='Train-Acc')
|
||||
ax.errorbar(x, acc, yerr=acc_err, fmt='o--', label='Acc')
|
||||
ax.errorbar(x, auc, yerr=auc_err, fmt='o--', label='AUC')
|
||||
# ax.errorbar(x, auc, yerr=auc_err, fmt='o--', label='AUC')
|
||||
ax.errorbar(x, s_auc, yerr=s_auc_err, fmt='o--', label='Linear-RAUC')
|
||||
ax.errorbar(x, sc_auc, yerr=sc_auc_err, fmt='o--', label='Slab-RAUC')
|
||||
# ax.errorbar(x, sc_auc, yerr=sc_auc_err, fmt='o--', label='Slab-RAUC')
|
||||
# ax.set_xlabel('Models', fontsize=fontsize)
|
||||
ax.set_ylabel('ML Model Acc/ AUC', fontsize=fontsize)
|
||||
ax.set_title('OOD Evaluation', fontsize=fontsize)
|
||||
ax.legend(fontsize=fontsize)
|
||||
|
||||
if idx == 1:
|
||||
ax.errorbar(x, mia, yerr=mia_err, label='Classifier Attack', color='blue', fmt='o--')
|
||||
# ax.errorbar(x, mia, yerr=mia_err, label='Classifier Attack', color='blue', fmt='o--')
|
||||
ax.errorbar(x, entropy, yerr=entropy_err, label='Entropy Attack', color='red', fmt='o--')
|
||||
ax.errorbar(x, loss, yerr=loss_err, label='Loss Attack', color='orange', fmt='o--')
|
||||
ax.set_ylabel('Attack Model Accuracy', fontsize=fontsize)
|
||||
|
|
12
slab-run.py
12
slab-run.py
|
@ -5,10 +5,11 @@ case= sys.argv[1]
|
|||
slab_noise= float(sys.argv[2])
|
||||
total_seed= 3
|
||||
|
||||
# methods=['erm', 'irm', 'csd', 'rand', 'perf', 'mask_linear']
|
||||
# metrics= ['auc', 'mi', 'entropy', 'loss', 'attribute']
|
||||
methods=['matchdg']
|
||||
metrics= ['auc', 'mi', 'entropy', 'loss']
|
||||
methods=['erm', 'irm', 'csd', 'rand', 'perf', 'matchdg', 'mask_linear']
|
||||
# metrics= ['auc', 'mi', 'entropy', 'loss']
|
||||
metrics= ['auc', 'entropy', 'loss']
|
||||
# methods=['matchdg']
|
||||
# metrics= ['entropy', 'loss']
|
||||
|
||||
if case == 'train':
|
||||
|
||||
|
@ -78,7 +79,8 @@ elif case == 'test':
|
|||
elif method == 'matchdg':
|
||||
upd_script = base_script + ' --method_name matchdg_erm --match_case -1 --penalty_ws 1.0 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name slab '
|
||||
|
||||
for test_domain in [0.05, 0.15, 0.3, 0.5, 0.7, 0.9]:
|
||||
# for test_domain in [0.05, 0.15, 0.3, 0.5, 0.7, 0.9]:
|
||||
for test_domain in [0.2, 0.9]:
|
||||
script= upd_script + ' --test_domains ' + str(test_domain) + ' > ' + res_dir + str(method) + '-' + str(metric) + '-' + str(test_domain) + '.txt'
|
||||
os.system(script)
|
||||
|
||||
|
|
2
test.py
2
test.py
|
@ -76,6 +76,8 @@ parser.add_argument('--penalty_s', type=int, default=-1,
|
|||
help='Epoch threshold over which Matching Loss to be optimised')
|
||||
parser.add_argument('--penalty_irm', type=float, default=0.0,
|
||||
help='Penalty weight for IRM invariant classifier loss')
|
||||
parser.add_argument('--penalty_aug', type=float, default=1.0,
|
||||
help='Penalty weight for Augmentation in Hybrid approach loss')
|
||||
parser.add_argument('--penalty_ws', type=float, default=0.1,
|
||||
help='Penalty weight for Matching Loss')
|
||||
parser.add_argument('--penalty_diff_ctr',type=float, default=1.0,
|
||||
|
|
4
train.py
4
train.py
|
@ -121,7 +121,7 @@ parser.add_argument('--test_metric', type=str, default='match_score',
|
|||
help='Evaluation Metrics: acc; match_score, t_sne, mia')
|
||||
parser.add_argument('--top_k', type=int, default=10,
|
||||
help='Top K matches to consider for the match score evaluation metric')
|
||||
parser.add_argument('--match_func_aug_case', type=int, default=1,
|
||||
parser.add_argument('--match_func_aug_case', type=int, default=0,
|
||||
help='0: Evaluate match func on train domains; 1: Evaluate match func on self augmentations')
|
||||
parser.add_argument('--match_func_data_case', type=str, default='val',
|
||||
help='Dataset Train/Val/Test for the match score evaluation metric')
|
||||
|
@ -177,7 +177,7 @@ for run in range(args.n_runs):
|
|||
# print('Train Domains, Domain Size, BaseDomainIdx, Total Domains: ', train_domains, total_domains, domain_size, training_list_size)
|
||||
|
||||
#Import the module as per the current training method
|
||||
if args.method_name == 'erm_match' or args.method_name == 'mask_linear':
|
||||
if args.method_name == 'erm_match' or args.method_name == 'mask_linear' or args.method_name == 'dp_erm':
|
||||
from algorithms.erm_match import ErmMatch
|
||||
train_method= ErmMatch(
|
||||
args, train_dataset, val_dataset,
|
||||
|
|
|
@ -67,7 +67,7 @@ def get_matched_pairs(args, cuda, train_dataset, domain_size, total_domains, tra
|
|||
domain_count[domain]= 0
|
||||
|
||||
# Create dictionary: class label -> list of ordered indices
|
||||
if args.method_name == 'hybrid':
|
||||
if args.method_name == 'hybrid' and args.match_func_aug_case == 0:
|
||||
for batch_idx, (x_e, _, y_e ,d_e, idx_e, obj_e) in enumerate(train_dataset):
|
||||
x_e= x_e
|
||||
y_e= torch.argmax(y_e, dim=1)
|
||||
|
|
Загрузка…
Ссылка в новой задаче