Attribute Attack on C-MNIST modifications
This commit is contained in:
Родитель
29136f47f6
Коммит
142da80d36
|
@ -18,7 +18,7 @@ import torch
|
|||
import torch.utils.data as data_utils
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
def generate_rotated_domain_data(imgs, labels, data_case, dataset, indices, domain, save_dir, img_w, img_h):
|
||||
def generate_rotated_domain_data(imgs, labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute):
|
||||
|
||||
# Get total number of labeled examples
|
||||
mnist_labels = labels[indices]
|
||||
|
@ -64,9 +64,11 @@ def generate_rotated_domain_data(imgs, labels, data_case, dataset, indices, doma
|
|||
if rand_var[i]:
|
||||
# Change colors per label for test domains relative to the train domains
|
||||
if data_case == 'test':
|
||||
curr_image = ImageOps.colorize(curr_image, black ="black", white =color_list[mnist_labels[i].item()])
|
||||
# Choose this for test domain with permuted colors
|
||||
# curr_image = ImageOps.colorize(curr_image, black ="black", white =color_list[(mnist_labels[i].item()+1)%10] )
|
||||
if cmnist_permute:
|
||||
# Choose this for test domain with permuted colors
|
||||
curr_image = ImageOps.colorize(curr_image, black ="black", white =color_list[(mnist_labels[i].item()+1)%10] )
|
||||
else:
|
||||
curr_image = ImageOps.colorize(curr_image, black ="black", white =color_list[mnist_labels[i].item()])
|
||||
else:
|
||||
curr_image = ImageOps.colorize(curr_image, black ="black", white =color_list[mnist_labels[i].item()])
|
||||
else:
|
||||
|
@ -107,6 +109,7 @@ parser.add_argument('--data_size', type=int, default=60000)
|
|||
parser.add_argument('--subset_size', type=int, default=2000)
|
||||
parser.add_argument('--img_w', type=int, default=224)
|
||||
parser.add_argument('--img_h', type=int, default=224)
|
||||
parser.add_argument('--cmnist_permute', type=int, default=0)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -117,6 +120,7 @@ img_h= args.img_h
|
|||
data_size= args.data_size
|
||||
subset_size= args.subset_size
|
||||
val_size= int(args.subset_size/5)
|
||||
cmnist_permute= args.cmnist_permute
|
||||
|
||||
#Generate Dataset for Rotated / Fashion MNIST
|
||||
#TODO: Manage OS Env from args
|
||||
|
@ -184,7 +188,18 @@ for seed in seed_list:
|
|||
res=np.random.choice(data_size, subset_size+val_size)
|
||||
print('Seed: ', seed)
|
||||
for domain in domains:
|
||||
|
||||
|
||||
# The case of permuted test domain for colored rotated MNIST, only update test data
|
||||
if dataset == 'rot_mnist_spur' and cmnist_permute:
|
||||
#Test
|
||||
data_case= 'test'
|
||||
save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
|
||||
indices= res[:subset_size]
|
||||
if seed in [9] and domain in [0, 15, 30, 45, 60, 75, 90]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute)
|
||||
|
||||
continue
|
||||
|
||||
#Train
|
||||
data_case= 'train'
|
||||
if not os.path.exists(data_dir + data_case + '/'):
|
||||
|
@ -195,10 +210,10 @@ for seed in seed_list:
|
|||
|
||||
if model == 'resnet18':
|
||||
if seed in [0, 1, 2] and domain in [15, 30, 45, 60, 75]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute)
|
||||
elif model in ['lenet']:
|
||||
if seed in [0, 1, 2] and domain in [0, 15, 30, 45, 60, 75]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute)
|
||||
|
||||
#Val
|
||||
data_case= 'val'
|
||||
|
@ -210,10 +225,10 @@ for seed in seed_list:
|
|||
|
||||
if model == 'resnet18':
|
||||
if seed in [0, 1, 2] and domain in [15, 30, 45, 60, 75]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute)
|
||||
elif model in ['lenet']:
|
||||
if seed in [0, 1, 2] and domain in [0, 15, 30, 45, 60, 75]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute)
|
||||
|
||||
#Test
|
||||
data_case= 'test'
|
||||
|
@ -222,10 +237,38 @@ for seed in seed_list:
|
|||
|
||||
save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
|
||||
indices= res[:subset_size]
|
||||
|
||||
|
||||
if model == 'resnet18':
|
||||
if seed in [0, 1, 2, 9] and domain in [0, 90]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute)
|
||||
elif model in ['lenet', 'lenet_mdg']:
|
||||
if seed in [0, 1, 2] and domain in [0, 15, 30, 45, 60, 75]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute)
|
||||
|
||||
|
||||
# Extra data sampling for carrying out the attribute attack on spurious rotated mnist
|
||||
if dataset == 'rot_mnist_spur':
|
||||
|
||||
#Train
|
||||
data_case= 'train'
|
||||
save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
|
||||
indices= res[:subset_size]
|
||||
if seed in [0, 1, 2] and domain in [0, 90]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute)
|
||||
|
||||
#Val
|
||||
data_case= 'val'
|
||||
save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
|
||||
indices= res[subset_size:]
|
||||
if seed in [0, 1, 2] and domain in [0, 90]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute)
|
||||
|
||||
#Test
|
||||
data_case= 'test'
|
||||
save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
|
||||
indices= res[:subset_size]
|
||||
if seed in [9] and domain in [15, 30, 45, 60, 75]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute)
|
||||
|
||||
|
||||
|
|
@ -180,6 +180,13 @@
|
|||
"`python3 reproduce_scripts/pacs_run.py --method hybrid --model resnet18`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The results will be stored in the `results/pacs/logs/` directory"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
@ -252,7 +259,7 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The results would be stored in the `results/cxray/train_logs` directory"
|
||||
"The results will be stored in the `results/chestxray/train_logs` directory"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -280,7 +287,8 @@
|
|||
"\n",
|
||||
"Run the following command for data generation:\n",
|
||||
"\n",
|
||||
" \n",
|
||||
"`python3 data/data_gen_domainbed.py`\n",
|
||||
"\n",
|
||||
"Run the following command for training models:\n",
|
||||
"\n",
|
||||
"`python3 reproduce_rmnist_domainbed.py`\n",
|
||||
|
@ -607,7 +615,9 @@
|
|||
"\n",
|
||||
"### Plotting Results\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/cxray_plot.py`"
|
||||
"`python3 reproduce_scripts/cxray_plot.py`\n",
|
||||
"\n",
|
||||
"The plots will be stored in the directory: `results/chestxray/plots/`"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -616,7 +626,33 @@
|
|||
"source": [
|
||||
"# Attribute Attack\n",
|
||||
"\n",
|
||||
"python data/data_gen_mnist.py --dataset rot_mnist_spur --model resnet18 --img_h 224 --img_w 224 --subset_size 2000"
|
||||
"### Preparing Data\n",
|
||||
"\n",
|
||||
"`python data/data_gen_mnist.py --dataset rot_mnist_spur --model resnet18 --img_h 224 --img_w 224 --subset_size 2000`\n",
|
||||
"\n",
|
||||
"### Training Models\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist_spur --train_case train_all --metric train --data_aug 0`\n",
|
||||
"\n",
|
||||
"### Evaluating OOD Accuracy\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist_spur --train_case train_all --metric acc --data_case train --data_aug 0`\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist_spur --train_case train_all --metric acc --data_case test --data_aug 0`\n",
|
||||
"\n",
|
||||
"### Evaluating Attribute Inference (AI) Attack Accuracy\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist_spur --train_case train_all --metric attribute_attack --data_case 0 --data_aug 0`\n",
|
||||
"\n",
|
||||
"### Evaluating OOD Accuracy on the permuted test domain\n",
|
||||
"\n",
|
||||
"Generate data with a permuted test domain\n",
|
||||
"\n",
|
||||
"`python data/data_gen_mnist.py --dataset rot_mnist_spur --model resnet18 --img_h 224 --img_w 224 --subset_size 2000 --cmnist_permute 1`\n",
|
||||
"\n",
|
||||
"Run the following command to obtain OOD accuracy on the permuted test domain\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist_spur --train_case train_all --metric acc --data_case test --data_aug 0`"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -45,12 +45,11 @@ fontsize_lgd= fontsize/1.2
|
|||
for idx in range(3):
|
||||
|
||||
marker_list = ['o', '^', '*']
|
||||
legend_list = ['RSNA', 'NIH', 'Chex']
|
||||
legend_count = 0
|
||||
for test_domain in ['kaggle', 'nih', 'chex']:
|
||||
|
||||
# metrics= ['acc:train', 'acc:test', 'mia', 'privacy_entropy', 'privacy_loss_attack', 'match_score:train', 'match_score:test']
|
||||
|
||||
metrics= ['acc:train', 'acc:test', 'privacy_loss_attack', 'match_score:test']
|
||||
metrics= ['acc:train', 'acc:test', 'privacy_loss_attack']
|
||||
|
||||
acc_train=[]
|
||||
acc_train_err=[]
|
||||
|
@ -110,11 +109,11 @@ for idx in range(3):
|
|||
ax[idx].set_xticklabels(x, rotation=25)
|
||||
|
||||
if idx == 0:
|
||||
ax[idx].errorbar(x, acc_test, yerr=acc_test_err, label=test_domain, marker= marker_list[legend_count], markersize= fontsize_lgd, linewidth=4, fmt='o--')
|
||||
ax[idx].errorbar(x, acc_test, yerr=acc_test_err, label= legend_list[legend_count], marker= marker_list[legend_count], markersize= fontsize_lgd, linewidth=4, fmt='o--')
|
||||
ax[idx].set_ylabel('OOD Accuracy', fontsize=fontsize)
|
||||
|
||||
if idx == 1:
|
||||
ax[idx].errorbar(x, loss, yerr=loss_err, label=test_domain, marker= marker_list[legend_count], markersize= fontsize_lgd, linewidth=4, fmt='o--')
|
||||
ax[idx].errorbar(x, loss, yerr=loss_err, label= legend_list[legend_count], marker= marker_list[legend_count], markersize= fontsize_lgd, linewidth=4, fmt='o--')
|
||||
ax[idx].set_ylabel('MI Attack Accuracy', fontsize=fontsize)
|
||||
|
||||
# if idx == 2:
|
||||
|
@ -125,7 +124,7 @@ for idx in range(3):
|
|||
# ax[idx].legend(fontsize=fontsize_lgd)
|
||||
|
||||
if idx == 2:
|
||||
ax[idx].errorbar(x, np.array(acc_train) - np.array(acc_test), yerr=acc_train_err, label=test_domain, marker= marker_list[legend_count], markersize= fontsize_lgd, linewidth=4, fmt='o--')
|
||||
ax[idx].errorbar(x, np.array(acc_train) - np.array(acc_test), yerr=acc_train_err, label= legend_list[legend_count], marker= marker_list[legend_count], markersize= fontsize_lgd, linewidth=4, fmt='o--')
|
||||
ax[idx].set_ylabel('Train-Test Accuracy Gap ', fontsize=fontsize)
|
||||
|
||||
|
||||
|
|
|
@ -39,9 +39,6 @@ methods= args.methods
|
|||
# test_diff, test_common
|
||||
test_case=['test_diff']
|
||||
|
||||
# List of methods to train/evaluate
|
||||
# methods=[]
|
||||
|
||||
if metric == 'train':
|
||||
if dataset in ['rot_mnist', 'rot_mnist_spur']:
|
||||
base_script= 'python train.py --dataset ' + str(dataset)
|
||||
|
@ -155,7 +152,7 @@ for method in methods:
|
|||
script= base_script + ' --method_name erm_match --penalty_ws 10.0 --match_case 1.0 --epochs 25 ' + ' > ' + case + '.txt'
|
||||
else:
|
||||
script= base_script + ' --method_name erm_match --penalty_ws 0.1 --match_case 1.0 --epochs 25 ' + ' > ' + case + '.txt'
|
||||
os.system(script)
|
||||
os.system(script)
|
||||
|
||||
elif method == 'matchdg':
|
||||
if metric == 'train':
|
||||
|
|
|
@ -90,8 +90,8 @@ for test_domain in [0.2, 0.9]:
|
|||
ax[count].errorbar(x, s_auc, yerr=s_auc_err, marker= marker_list[1], markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Linear-RAUC')
|
||||
ax[count].errorbar(x, loss, yerr=loss_err, marker= marker_list[2], markersize= fontsize_lgd, linewidth=4, label='Loss Attack', fmt='o--')
|
||||
|
||||
gen_gap= np.array(train_acc) - np.array(acc)
|
||||
ax[count].errorbar(x, gen_gap, yerr=0*gen_gap, marker= 's', markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Gen Gap')
|
||||
# gen_gap= np.array(train_acc) - np.array(acc)
|
||||
# ax[count].errorbar(x, gen_gap, yerr=0*gen_gap, marker= 's', markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Gen Gap')
|
||||
|
||||
ax[count].set_ylabel('Metric Score', fontsize=fontsize)
|
||||
ax[count].set_title('Test Domain: ' + str(test_domain), fontsize=fontsize)
|
||||
|
|
|
@ -287,7 +287,6 @@ def get_dataloader(args, run, domains, data_case, eval_case, kwargs):
|
|||
|
||||
elif args.dataset_name in ['rot_mnist', 'fashion_mnist', 'rot_mnist_spur']:
|
||||
if data_case == 'test' and args.mnist_case not in ['lenet']:
|
||||
#TODO: Infer this based on the total number of seed values for the mnist case
|
||||
# Actually by default the seeds 0, 1, 2 are for training and seed 9 is for test; mention that properly in comments
|
||||
mnist_subset= 9
|
||||
else:
|
||||
|
|
Загрузка…
Ссылка в новой задаче