This commit is contained in:
divyat09 2021-06-03 11:30:34 +00:00
Parent a3b709d61f
Commit ed82a73abd
15 changed files with 333 additions and 40 deletions

View file

@@ -48,7 +48,7 @@ class BaseAlgo():
         self.val_acc=[]
         self.train_acc=[]
-        if self.args.dp_noise:
+        if self.args.method_name == 'dp_erm':
             self.privacy_engine= self.get_dp_noise()

     def get_model(self):
@@ -206,8 +206,8 @@ class BaseAlgo():
         from opacus.dp_model_inspector import DPModelInspector
         from opacus.utils import module_modification
         inspector = DPModelInspector()
         print(self.phi)
         self.phi = module_modification.convert_batchnorm_modules(self.phi)
         print(self.phi)
         inspector.validate(self.phi)

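The two hunks above tie DP training to --method_name dp_erm and sanity-check the network with opacus before training (batch norm has to be swapped out first, since per-sample gradients are undefined under it). Below is a minimal sketch of what get_dp_noise() is assumed to set up with the opacus 0.13 API pinned in requirements.txt; the model, batch/sample sizes, and noise settings are placeholders, not values from this repo:

import torch
from opacus import PrivacyEngine

model = torch.nn.Linear(10, 2)                      # stand-in for self.phi
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# opacus 0.x style: build the engine, then attach it to the optimizer so each
# step clips per-sample gradients and adds Gaussian noise.
privacy_engine = PrivacyEngine(
    model,
    batch_size=64,
    sample_size=10000,
    alphas=[1 + x / 10.0 for x in range(1, 100)],
    noise_multiplier=1.0,
    max_grad_norm=1.0,
)
privacy_engine.attach(optimizer)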
cxray_plot.py (new normal file, 145 lines)
View file

@@ -0,0 +1,145 @@
import os
import sys
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

def get_base_dir(test_domain, dataset, metric, data_case='train'):
    # data_case is only used by the attribute_attack branch; it was previously
    # referenced here without being defined
    if metric == 'mia':
        res_dir= 'results/'+str(dataset)+'/privacy/' + str(test_domain) + '/'
    elif metric == 'privacy_entropy':
        res_dir= 'results/'+str(dataset)+'/privacy_entropy/' + str(test_domain) + '/'
    elif metric == 'privacy_loss_attack':
        res_dir= 'results/'+str(dataset)+'/privacy_loss/' + str(test_domain) + '/'
    elif metric == 'attribute_attack':
        res_dir= 'results/'+str(dataset)+'/attribute_attack_' + data_case + '/' + str(test_domain) + '/'
    elif metric == 'acc:train':
        res_dir= 'results/' + str(dataset) + '/acc_' + 'train' + '/' + str(test_domain) + '/'
    elif metric == 'acc:test':
        res_dir= 'results/' + str(dataset) + '/acc_' + 'test' + '/' + str(test_domain) + '/'
    elif metric == 'match_score:train':
        res_dir= 'results/' + str(dataset) + '/match_score_' + 'train' + '/' + str(test_domain) + '/'
    elif metric == 'match_score:test':
        res_dir= 'results/' + str(dataset) + '/match_score_' + 'test' + '/' + str(test_domain) + '/'
    return res_dir

#rot_mnist, fashion_mnist, rot_mnist_spur
dataset= sys.argv[1]
# kaggle, nih, chex
test_domain= sys.argv[2]

x=['ERM', 'Rand', 'MatchDG', 'CSD', 'IRM', 'Hybrid']
methods=['erm', 'rand', 'matchdg_erm', 'csd', 'irm', 'hybrid']
# metrics= ['acc:train', 'acc:test', 'mia', 'privacy_entropy', 'privacy_loss_attack', 'match_score:train', 'match_score:test']
metrics= ['acc:train', 'acc:test', 'privacy_entropy', 'privacy_loss_attack', 'match_score:test']

acc_train=[]
acc_train_err=[]
acc_test=[]
acc_test_err=[]
mia=[]
mia_err=[]
entropy=[]
entropy_err=[]
loss=[]
loss_err=[]
rank_train=[]
rank_train_err=[]
rank_test=[]
rank_test_err=[]

for metric in metrics:
    for method in methods:
        res_dir= get_base_dir(test_domain, dataset, metric)
        with open(res_dir + method + '.txt') as f:
            data= f.readlines()
        # the third-to-last line is expected to end with ': <mean> <sd>'
        stats= data[-3].rstrip('\n').split(':')[-1].split(' ')
        mean= float(stats[-2])
        sd= float(stats[-1])
        if metric == 'acc:train':
            acc_train.append(mean)
            acc_train_err.append(sd)
        elif metric == 'acc:test':
            acc_test.append(mean)
            acc_test_err.append(sd)
        elif metric == 'mia':
            mia.append(mean)
            mia_err.append(sd)
        elif metric == 'privacy_entropy':
            entropy.append(mean)
            entropy_err.append(sd)
        elif metric == 'privacy_loss_attack':
            loss.append(mean)
            loss_err.append(sd)
        elif metric == 'match_score:train':
            rank_train.append(mean)
            rank_train_err.append(sd)
        elif metric == 'match_score:test':
            rank_test.append(mean)
            rank_test_err.append(sd)

for idx in range(4):
    matplotlib.rcParams.update({'errorbar.capsize': 2})
    fig, ax = plt.subplots(1, 1, figsize=(10, 8))
    fontsize=30
    fontsize_lgd= fontsize/1.2
    ax.tick_params(labelsize=fontsize)
    ax.set_xticklabels(x, rotation=25)

    if idx == 0:
        ax.errorbar(x, acc_train, yerr=acc_train_err, label='Train Accuracy', fmt='o--')
        ax.errorbar(x, acc_test, yerr=acc_test_err, label='Test Accuracy', fmt='o--')
        # ax.set_xlabel('Models', fontsize=fontsize)
        ax.set_ylabel('OOD Accuracy of ML Model', fontsize=fontsize)
        ax.legend(fontsize=fontsize_lgd)

    if idx == 1:
        # ax.errorbar(x, mia, yerr=mia_err, label='Classifier Attack', color='blue', fmt='o--')
        ax.errorbar(x, entropy, yerr=entropy_err, label='Entropy Attack', color='red', fmt='o--')
        ax.errorbar(x, loss, yerr=loss_err, label='Loss Attack', color='orange', fmt='o--')
        ax.set_ylabel('MI Attack Accuracy', fontsize=fontsize)
        ax.legend(fontsize=fontsize_lgd)

    if idx == 2:
        # ax.errorbar(x, rank_train, yerr=rank_train_err, label='Train', fmt='o--', color='brown')
        ax.errorbar(x, rank_test, yerr=rank_test_err, label='Test', fmt='o--', color='green')
        # ax.set_xlabel('Models', fontsize=fontsize)
        ax.set_ylabel('Mean Rank of Perfect Match', fontsize=fontsize)
        ax.legend(fontsize=fontsize_lgd)

    if idx == 3:
        ax.errorbar(x, np.array(acc_train) - np.array(acc_test), yerr=acc_train_err, label='Train Accuracy', fmt='o--')
        # ax.set_xlabel('Models', fontsize=fontsize)
        ax.set_ylabel('Train-Test Accuracy Gap of ML Model', fontsize=fontsize)
        ax.legend(fontsize=fontsize_lgd)

    save_dir= 'results/' + dataset+ '/plots/'+ test_domain +'/'
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    plt.tight_layout()
    plt.savefig(save_dir + 'privacy_' + str(dataset)+'_' + str(idx) + '.pdf', dpi=600)

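A hypothetical invocation, following the sys.argv reads at the top of the script (dataset first, held-out test domain second); it assumes cxray_run.py below has already populated the results/ tree:

import subprocess
# plot accuracy, MI-attack, and match-rank summaries with 'nih' held out
subprocess.run(['python', 'cxray_plot.py', 'chestxray', 'nih'], check=True)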
cxray_run.py (new normal file, 83 lines)
View file

@@ -0,0 +1,83 @@
import os
import sys

methods=['erm', 'irm', 'csd', 'rand', 'matchdg_ctr', 'matchdg_erm', 'hybrid']
domains= ['nih', 'chex', 'kaggle']
dataset= 'chestxray'

test_domain= sys.argv[1]
# train, acc, mia, privacy_entropy, privacy_loss_attack, match_score, feat_eval
metric= sys.argv[2]
if metric in ['acc', 'match_score', 'feat_eval', 'feat_eval_rand', 'attribute_attack']:
    data_case= sys.argv[3]

if metric == 'train':
    base_script= 'python train.py --dataset chestxray --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 '
    res_dir= 'results/' + str(dataset) + '/train_logs/'
elif metric == 'mia':
    base_script= 'python test.py --dataset chestxray --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --test_metric mia --mia_sample_size 200 --mia_logit 1 --batch_size 64 '
    res_dir= 'results/'+str(dataset)+'/privacy/'
elif metric == 'privacy_entropy':
    base_script= 'python test.py --dataset chestxray --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --test_metric privacy_entropy --mia_sample_size 200 --batch_size 64 '
    res_dir= 'results/'+str(dataset)+'/privacy_entropy/'
elif metric == 'privacy_loss_attack':
    base_script= 'python test.py --dataset chestxray --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --test_metric privacy_loss_attack --mia_sample_size 200 --batch_size 64 '
    res_dir= 'results/'+str(dataset)+'/privacy_loss/'
elif metric == 'acc':
    base_script= 'python test.py --dataset chestxray --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --test_metric acc ' + ' --acc_data_case ' + data_case
    res_dir= 'results/' + str(dataset) + '/acc_' + str(data_case) + '/'
elif metric == 'match_score':
    base_script= 'python test.py --dataset chestxray --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --test_metric match_score --match_func_aug_case 1' + ' --match_func_data_case ' + data_case
    res_dir= 'results/' + str(dataset) + '/match_score_' + data_case + '/'
elif metric == 'feat_eval':
    base_script= 'python test.py --dataset chestxray --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --test_metric feat_eval --match_func_aug_case 1' + ' --match_func_data_case ' + data_case
    res_dir= 'results/' + str(dataset) + '/feat_eval_' + data_case + '/'
elif metric == 'feat_eval_rand':
    base_script= 'python test.py --dataset chestxray --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --test_metric feat_eval --match_func_aug_case 1' + ' --match_func_data_case ' + data_case + ' --match_case 0.0 '
    res_dir= 'results/' + str(dataset) + '/feat_eval_rand_' + data_case + '/'

#Test Domain
res_dir= res_dir+ test_domain + '/'
if not os.path.exists(res_dir):
    os.makedirs(res_dir)

for method in methods:
    if method == 'erm':
        script= base_script + ' --method_name erm_match --epochs 40 --lr 0.001 --match_case 0.0 --penalty_ws 0.0 '
    elif method == 'rand':
        script= base_script + ' --method_name erm_match --epochs 40 --lr 0.001 --match_case 0.0 --penalty_ws 10.0 '
    elif method == 'csd':
        script= base_script + ' --method_name csd --epochs 40 --lr 0.001 --match_case 0.0 --penalty_ws 0.0 --rep_dim 1024'
    elif method == 'irm':
        script= base_script + ' --method_name irm --epochs 40 --lr 0.001 --match_case 0.0 --penalty_irm 10.0 --penalty_s 5'
    elif method == 'matchdg_ctr':
        script= base_script + ' --method_name matchdg_ctr --epochs 50 --batch_size 32 --match_case 0.0 --match_flag 1 --pos_metric cos'
    elif method == 'matchdg_erm':
        script= base_script + ' --method_name matchdg_erm --epochs 40 --lr 0.001 --match_case -1 --penalty_ws 50.0 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121'
    elif method == 'hybrid':
        print('Yes')
        script= base_script + ' --method_name hybrid --epochs 40 --lr 0.001 --match_case -1 --penalty_ws 1.0 --penalty_aug 50.0 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121'

    train_domains=''
    for d in domains:
        if d != test_domain:
            train_domains+= str(d) + '_trans' + ' '
    print('Method: ', method, ' Train Domains: ', train_domains, ' Test Domains: ', test_domain)

    script= script + ' --train_domains ' + train_domains + ' --test_domains ' + test_domain
    script= script + ' > ' + res_dir + str(method) + '.txt'
    os.system(script)

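The script reads the held-out domain and the metric positionally, with an extra data_case argument for the metrics listed near the top. A hedged example of a typical sequence (the ordering of phases is an assumption):

import subprocess
# train every method with the NIH domain held out ...
subprocess.run(['python', 'cxray_run.py', 'nih', 'train'], check=True)
# ... then score OOD accuracy on the test split (third argument is data_case)
subprocess.run(['python', 'cxray_run.py', 'nih', 'acc', 'test'], check=True)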
View file

@@ -18,7 +18,7 @@ from .data_loader import BaseDataLoader
 class ChestXRayAug(BaseDataLoader):
     def __init__(self, args, list_domains, root, transform=None, data_case='train', match_func=False):
-        super().__init__(args, list_train_domains, root, transform, data_case, match_func)
+        super().__init__(args, list_domains, root, transform, data_case, match_func)
         self.data, self.data_org, self.labels, self.domains, self.indices, self.objects = self._get_data()

     def __getitem__(self, index):
@@ -69,8 +69,8 @@ class ChestXRayAug(BaseDataLoader):
             base_class_size=0
             base_class_idx=-1
             for d_idx, domain in enumerate( self.list_domains ):
-                class_idx= training_labels[d_idx] == y_c
-                curr_class_size= training_labels[d_idx][class_idx].shape[0]
+                class_idx= list_labels[d_idx] == y_c
+                curr_class_size= list_labels[d_idx][class_idx].shape[0]
                 if base_class_size < curr_class_size:
                     base_class_size= curr_class_size
                     base_class_idx= d_idx
@@ -92,7 +92,7 @@ class ChestXRayAug(BaseDataLoader):
         data_domains = torch.zeros(data_labels.size())
         domain_start=0
         for idx in range(len(self.list_domains)):
-            curr_domain_size= self.training_size[idx]
+            curr_domain_size= self.training_list_size[idx]
             data_domains[ domain_start: domain_start+ curr_domain_size ] += idx
             domain_start+= curr_domain_size

View file

@@ -4,6 +4,7 @@ import sys
 import os
 import random
 import copy
+import os

 #Sklearn
 from scipy.stats import bernoulli
@@ -99,7 +100,13 @@ dataset= sys.argv[1]
 model= sys.argv[2]

 #Generate Dataset for Rotated / Fashion MNIST
-base_dir= 'datasets/mnist/'
+#TODO: Manage OS Env from args
+os_env=0
+if os_env:
+    base_dir= os.getenv('PT_DATA_DIR') + '/mnist/'
+else:
+    base_dir= 'data/datasets/mnist/'

 if not os.path.exists(base_dir):
     os.makedirs(base_dir)
@@ -207,9 +214,14 @@ for seed in seed_list:
         save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
         indices= res[:subset_size]
-        if seed in [0, 1, 2]:
-            generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
+        if model == 'resnet18':
+            if seed in [0, 1, 2] and domain in [15, 30, 45, 60, 75]:
+                generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
+        elif model == 'lenet':
+            if seed in [0, 1, 2] and domain in [0, 15, 30, 45, 60, 75]:
+                generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)

         #Val
         data_case= 'val'
         if not os.path.exists(data_dir + data_case + '/'):
@@ -217,8 +229,13 @@ for seed in seed_list:
         save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
         indices= res[subset_size:]
-        if seed in [0, 1, 2]:
-            generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
+        if model == 'resnet18':
+            if seed in [0, 1, 2] and domain in [15, 30, 45, 60, 75]:
+                generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
+        elif model == 'lenet':
+            if seed in [0, 1, 2] and domain in [0, 15, 30, 45, 60, 75]:
+                generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)

         #Test
         data_case= 'test'
@@ -228,10 +245,9 @@ for seed in seed_list:
         save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
         indices= res[:subset_size]
-        if model== 'lenet':
-            if seed in [0, 1, 2]:
+        if model == 'resnet18':
+            if seed in [9] and domain in [0, 90]:
                 generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
-        else:
-            if seed in [9]:
-                generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
+        elif model == 'lenet':
+            if seed in [0, 1, 2] and domain in [0, 15, 30, 45, 60, 75]:
+                generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)

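The hunks above only change which seeds and rotation angles get materialized per backbone: LeNet keeps the 0-75 degree training domains, while ResNet-18 drops angle 0 from training and evaluates seed 9 on the unseen 0 and 90 degree domains. The generate_rotated_domain_data helper itself is defined earlier in this file; a stripped-down sketch of the rotation step it is assumed to perform (function name, tensor layout, and return value are illustrative guesses, and saving to disk is omitted):

import torch
from torchvision.transforms import functional as TF

def rotate_domain(imgs, indices, angle, img_w, img_h):
    # rotate each selected HxW image by the domain's angle, then resize
    out = []
    for i in indices:
        img = TF.to_pil_image(imgs[i])
        img = TF.rotate(img, angle)
        out.append(TF.to_tensor(TF.resize(img, (img_h, img_w))))
    return torch.stack(out)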
View file

@@ -33,8 +33,10 @@ if not os.path.exists(base_dir):
 num_samples= 10000
 total_slabs= 7

-slab_noise_list= [0.0, 0.05, 0.10, 0.15, 0.20, 0.25]
-spur_corr_list= [0.0, 0.05, 0.10, 0.15, 0.30, 0.50, 0.70, 0.90]
+# slab_noise_list= [0.0, 0.05, 0.10, 0.15, 0.20, 0.25]
+# spur_corr_list= [0.0, 0.05, 0.10, 0.15, 0.30, 0.50, 0.70, 0.90]
+slab_noise_list= [0.0, 0.10]
+spur_corr_list= [0.20, 0.90]

 for seed in range(10):
     np.random.seed(seed*10)

View file

@@ -2,6 +2,7 @@ import os
 import sys
 import matplotlib
 import matplotlib.pyplot as plt
+import numpy as np

 def get_base_dir(train_case, test_case, dataset, metric):
@@ -31,6 +32,9 @@ def get_base_dir(train_case, test_case, dataset, metric):
     elif metric == 'feat_eval:train':
         res_dir= 'results/' + str(dataset) + '/feat_eval_' + 'train' + '/'
+    elif metric == 'feat_eval:test':
+        res_dir= 'results/' + str(dataset) + '/feat_eval_' + 'test' + '/'

     #Train Domains 30, 45 case
     if train_case == 'train_abl_2':
@@ -59,7 +63,9 @@ test_case=['test_diff']
 x=['ERM', 'Rand', 'MatchDG', 'CSD', 'IRM', 'Perf']
 methods=['erm', 'rand', 'matchdg', 'csd', 'irm', 'perf']
-metrics= ['acc:train', 'acc:test', 'mia', 'privacy_entropy', 'privacy_loss_attack', 'match_score:train', 'match_score:test', 'feat_eval:train']
+# metrics= ['acc:train', 'acc:test', 'mia', 'privacy_entropy', 'privacy_loss_attack', 'match_score:train', 'match_score:test', 'feat_eval:train', 'feat_eval:test']
+metrics= ['acc:train', 'acc:test', 'privacy_entropy', 'privacy_loss_attack', 'match_score:test']

 acc_train=[]
 acc_train_err=[]
@@ -85,6 +91,9 @@ rank_test_err=[]
 feat_eval_train=[]
 feat_eval_train_err=[]
+feat_eval_test=[]
+feat_eval_test_err=[]

 for metric in metrics:
     for method in methods:
@@ -120,6 +129,9 @@ for metric in metrics:
         elif metric == 'feat_eval:train':
             feat_eval_train.append(mean)
             feat_eval_train_err.append(sd)
+        elif metric == 'feat_eval:test':
+            feat_eval_test.append(mean)
+            feat_eval_test_err.append(sd)

 for idx in range(4):
@@ -138,23 +150,32 @@ for idx in range(4):
         ax.legend(fontsize=fontsize_lgd)

     if idx == 1:
-        ax.errorbar(x, mia, yerr=mia_err, label='Classifier Attack', color='blue', fmt='o--')
+        # ax.errorbar(x, mia, yerr=mia_err, label='Classifier Attack', color='blue', fmt='o--')
         ax.errorbar(x, entropy, yerr=entropy_err, label='Entropy Attack', color='red', fmt='o--')
         ax.errorbar(x, loss, yerr=loss_err, label='Loss Attack', color='orange', fmt='o--')
         ax.set_ylabel('MI Attack Accuracy', fontsize=fontsize)
         ax.legend(fontsize=fontsize_lgd)

     if idx == 2:
-        ax.errorbar(x, rank_train, yerr=rank_train_err, label='Train', fmt='o--', color='brown')
+        # ax.errorbar(x, rank_train, yerr=rank_train_err, label='Train', fmt='o--', color='brown')
         ax.errorbar(x, rank_test, yerr=rank_test_err, label='Test', fmt='o--', color='green')
         # ax.set_xlabel('Models', fontsize=fontsize)
         ax.set_ylabel('Mean Rank of Perfect Match', fontsize=fontsize)
         ax.legend(fontsize=fontsize_lgd)

     if idx == 3:
-        ax.errorbar(x, feat_eval_train, yerr=feat_eval_train_err, fmt='o--', color='brown')
+        ax.errorbar(x, np.array(acc_train) - np.array(acc_test), yerr=acc_train_err, label='Train Accuracy', fmt='o--')
         # ax.set_xlabel('Models', fontsize=fontsize)
-        ax.set_ylabel('Cosine Similarity of same object features', fontsize=fontsize)
+        ax.set_ylabel('Train-Test Accuracy Gap of ML Model', fontsize=fontsize)
         ax.legend(fontsize=fontsize_lgd)

+    # if idx == 3:
+    #     ax.errorbar(x, feat_eval_train, yerr=feat_eval_train_err, label='Train', fmt='o--', color='brown')
+    #     ax.errorbar(x, feat_eval_test, yerr=feat_eval_test_err, label='Test', fmt='o--', color='brown')
+    #     # ax.set_xlabel('Models', fontsize=fontsize)
+    #     ax.set_ylabel('Cosine Similarity of same object features', fontsize=fontsize)
+    #     ax.legend(fontsize=fontsize_lgd)
+
 save_dir= 'results/' + dataset+ '/plots_' + train_case + '/'

View file

@@ -7,14 +7,14 @@ dataset=sys.argv[1]
 train_case= sys.argv[2]
 # train, acc, mia, privacy_entropy, privacy_loss_attack, match_score, feat_eval, attribute_attack
 metric=sys.argv[3]
-if metric in ['acc', 'match_score', 'feat_eval', 'attribute_attack']:
+if metric in ['acc', 'match_score', 'feat_eval', 'feat_eval_rand', 'attribute_attack']:
     data_case= sys.argv[4]

 # test_diff, test_common
 test_case=['test_diff']
 methods=['erm', 'irm', 'csd', 'rand', 'perf', 'matchdg']
-# methods= ['matchdg']
+# methods=['erm', 'irm', 'csd', 'rand', 'matchdg']
 # methods=['approx_25', 'approx_50', 'approx_75']

 if metric == 'train':
@@ -61,6 +61,10 @@ elif metric == 'match_score':
 elif metric == 'feat_eval':
     base_script= 'python test.py --test_metric feat_eval ' + ' --dataset ' + str(dataset) + ' --match_func_data_case ' + data_case
     res_dir= 'results/' + str(dataset) + '/feat_eval_' + data_case + '/'
+elif metric == 'feat_eval_rand':
+    base_script= 'python test.py --test_metric feat_eval ' + ' --dataset ' + str(dataset) + ' --match_func_data_case ' + data_case + ' --match_case 0.0 '
+    res_dir= 'results/' + str(dataset) + '/feat_eval_rand_' + data_case + '/'

 #Train Domains 30, 45 case

View file

@@ -11,7 +11,7 @@ if method == 'erm' or method == 'rand':
     base_script= 'python train.py --dataset pacs --method_name erm_match --match_case 0.0 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --weight_decay 0.001 '
 elif method == 'matchdg_ctr':
-    base_script= 'python train.py --dataset pacs --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --pos_metric cos --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --batch_size 64 '
+    base_script= 'python train.py --dataset pacs --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --pos_metric cos --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --batch_size 32 '
 elif method == 'matchdg_erm':
     base_script= 'python train.py --dataset pacs --method_name matchdg_erm --match_case -1 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --weight_decay 0.001 '

View file

@@ -1,6 +1,7 @@
 absl-py==0.12.0
 advertorch==0.2.3
 astor==0.8.1
+backcall==0.2.0
 cached-property==1.5.2
 certifi==2020.12.5
 chardet==4.0.0
@@ -13,12 +14,19 @@ h5py==3.2.1
 idna==2.10
 imageio==2.9.0
 importlib-metadata==3.10.0
+ipykernel==5.5.5
+ipython==7.24.0
+ipython-genutils==0.2.0
+jedi==0.18.0
 joblib==1.0.1
+jupyter-client==6.1.12
+jupyter-core==4.7.1
 Keras-Applications==1.0.8
 Keras-Preprocessing==1.1.2
 kiwisolver==1.3.1
 Markdown==3.3.4
 matplotlib==3.4.1
+matplotlib-inline==0.1.2
 mia==0.1.2
 more-itertools==8.8.0
 networkx==2.5
@@ -26,13 +34,20 @@ numpy==1.20.2
 opacus==0.13.0
 opt-einsum==3.3.0
 pandas==1.2.3
+parso==0.8.2
+pexpect==4.8.0
+pickleshare==0.7.5
 Pillow==8.1.2
+prompt-toolkit==3.0.18
 protobuf==3.15.6
+ptyprocess==0.7.0
 pydicom==2.1.2
+Pygments==2.9.0
 pyparsing==2.4.7
 python-dateutil==2.8.1
 pytz==2021.1
 PyWavelets==1.1.1
+pyzmq==22.1.0
 requests==2.25.1
 scikit-image==0.18.1
 scikit-learn==0.24.1
@@ -48,9 +63,12 @@ tifffile==2021.3.17
 torch==1.8.1
 torchvision==0.9.1
 torchxrayvision==0.0.24
+tornado==6.1
 tqdm==4.59.0
+traitlets==5.0.5
 typing-extensions==3.7.4.3
 urllib3==1.26.4
+wcwidth==0.2.5
 Werkzeug==1.0.1
 wrapt==1.12.1
 zipp==3.4.1

View file

@@ -40,7 +40,7 @@ attribute=[]
 attribute_err=[]

 # eval_metrics= ['mi', 'entropy', 'loss', 'attribute']
-eval_metrics= ['mi', 'entropy', 'loss']
+eval_metrics= ['entropy', 'loss']

 for metric in eval_metrics:
     for method in methods:
         f= open(base_dir+method+'-'+metric+ '-' + str(test_domain)+'.txt')
@@ -74,16 +74,16 @@ for idx in range(0, 2):
     if idx == 0:
         # ax.errorbar(x, train_acc, yerr=train_acc_err, fmt='o--', label='Train-Acc')
         ax.errorbar(x, acc, yerr=acc_err, fmt='o--', label='Acc')
-        ax.errorbar(x, auc, yerr=auc_err, fmt='o--', label='AUC')
+        # ax.errorbar(x, auc, yerr=auc_err, fmt='o--', label='AUC')
         ax.errorbar(x, s_auc, yerr=s_auc_err, fmt='o--', label='Linear-RAUC')
-        ax.errorbar(x, sc_auc, yerr=sc_auc_err, fmt='o--', label='Slab-RAUC')
+        # ax.errorbar(x, sc_auc, yerr=sc_auc_err, fmt='o--', label='Slab-RAUC')
         # ax.set_xlabel('Models', fontsize=fontsize)
         ax.set_ylabel('ML Model Acc/ AUC', fontsize=fontsize)
         ax.set_title('OOD Evaluation', fontsize=fontsize)
         ax.legend(fontsize=fontsize)

     if idx == 1:
-        ax.errorbar(x, mia, yerr=mia_err, label='Classifier Attack', color='blue', fmt='o--')
+        # ax.errorbar(x, mia, yerr=mia_err, label='Classifier Attack', color='blue', fmt='o--')
         ax.errorbar(x, entropy, yerr=entropy_err, label='Entropy Attack', color='red', fmt='o--')
         ax.errorbar(x, loss, yerr=loss_err, label='Loss Attack', color='orange', fmt='o--')
         ax.set_ylabel('Attack Model Accuracy', fontsize=fontsize)

View file

@@ -5,10 +5,11 @@ case= sys.argv[1]
 slab_noise= float(sys.argv[2])
 total_seed= 3

 # methods=['erm', 'irm', 'csd', 'rand', 'perf', 'mask_linear']
 # metrics= ['auc', 'mi', 'entropy', 'loss', 'attribute']
-methods=['matchdg']
-metrics= ['auc', 'mi', 'entropy', 'loss']
+methods=['erm', 'irm', 'csd', 'rand', 'perf', 'matchdg', 'mask_linear']
+# metrics= ['auc', 'mi', 'entropy', 'loss']
+metrics= ['auc', 'entropy', 'loss']
+# methods=['matchdg']
+# metrics= ['entropy', 'loss']

 if case == 'train':
if case == 'train':
@@ -78,7 +79,8 @@ elif case == 'test':
         elif method == 'matchdg':
             upd_script = base_script + ' --method_name matchdg_erm --match_case -1 --penalty_ws 1.0 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name slab '

-        for test_domain in [0.05, 0.15, 0.3, 0.5, 0.7, 0.9]:
+        # for test_domain in [0.05, 0.15, 0.3, 0.5, 0.7, 0.9]:
+        for test_domain in [0.2, 0.9]:
             script= upd_script + ' --test_domains ' + str(test_domain) + ' > ' + res_dir + str(method) + '-' + str(metric) + '-' + str(test_domain) + '.txt'
             os.system(script)

View file

@@ -76,6 +76,8 @@ parser.add_argument('--penalty_s', type=int, default=-1,
                     help='Epoch threshold over which Matching Loss to be optimised')
 parser.add_argument('--penalty_irm', type=float, default=0.0,
                     help='Penalty weight for IRM invariant classifier loss')
+parser.add_argument('--penalty_aug', type=float, default=1.0,
+                    help='Penalty weight for Augmentation in Hybrid approach loss')
 parser.add_argument('--penalty_ws', type=float, default=0.1,
                     help='Penalty weight for Matching Loss')
 parser.add_argument('--penalty_diff_ctr',type=float, default=1.0,

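The new flag gives the hybrid method a second weight alongside --penalty_ws (cxray_run.py above passes --penalty_ws 1.0 --penalty_aug 50.0 for it). A one-line sketch of how the two are presumably combined; the variable names are assumptions, and the actual loss lives in the hybrid algorithm class:

# hypothetical wiring of the two penalty weights in the hybrid objective
total_loss = erm_loss + args.penalty_ws * match_loss + args.penalty_aug * aug_match_loss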
View file

@@ -121,7 +121,7 @@ parser.add_argument('--test_metric', type=str, default='match_score',
                     help='Evaluation Metrics: acc; match_score, t_sne, mia')
 parser.add_argument('--top_k', type=int, default=10,
                     help='Top K matches to consider for the match score evaluation metric')
-parser.add_argument('--match_func_aug_case', type=int, default=1,
+parser.add_argument('--match_func_aug_case', type=int, default=0,
                     help='0: Evaluate match func on train domains; 1: Evaluate match func on self augmentations')
 parser.add_argument('--match_func_data_case', type=str, default='val',
                     help='Dataset Train/Val/Test for the match score evaluation metric')
@@ -177,7 +177,7 @@ for run in range(args.n_runs):
     # print('Train Domains, Domain Size, BaseDomainIdx, Total Domains: ', train_domains, total_domains, domain_size, training_list_size)

     #Import the module as per the current training method
-    if args.method_name == 'erm_match' or args.method_name == 'mask_linear':
+    if args.method_name == 'erm_match' or args.method_name == 'mask_linear' or args.method_name == 'dp_erm':
         from algorithms.erm_match import ErmMatch
         train_method= ErmMatch(
             args, train_dataset, val_dataset,

View file

@@ -67,7 +67,7 @@ def get_matched_pairs(args, cuda, train_dataset, domain_size, total_domains, tra
         domain_count[domain]= 0

     # Create dictionary: class label -> list of ordered indices
-    if args.method_name == 'hybrid':
+    if args.method_name == 'hybrid' and args.match_func_aug_case == 0:
         for batch_idx, (x_e, _, y_e ,d_e, idx_e, obj_e) in enumerate(train_dataset):
             x_e= x_e
             y_e= torch.argmax(y_e, dim=1)
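With the extra condition, the dictionary announced by the comment is only built when matches are computed across training domains (match_func_aug_case == 0), not against self-augmentations. A stripped-down sketch of that construction, with the batch tuple shape taken from the loop above and the dictionary name a hypothetical:

import torch

indices_by_class = {}
for batch_idx, (x_e, _, y_e, d_e, idx_e, obj_e) in enumerate(train_dataset):
    labels = torch.argmax(y_e, dim=1)          # y_e arrives one-hot
    for y, i in zip(labels.tolist(), idx_e.tolist()):
        indices_by_class.setdefault(y, []).append(i)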