Attribute Attack on C-MNIST modifications

This commit is contained in:
divyat09 2021-10-01 12:00:55 +00:00
Родитель 29136f47f6
Коммит 142da80d36
7 изменённых файлов: 198 добавлений и 55 удалений

Просмотреть файл

@ -18,7 +18,7 @@ import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms
def generate_rotated_domain_data(imgs, labels, data_case, dataset, indices, domain, save_dir, img_w, img_h):
def generate_rotated_domain_data(imgs, labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute):
# Get total number of labeled examples
mnist_labels = labels[indices]
@ -64,9 +64,11 @@ def generate_rotated_domain_data(imgs, labels, data_case, dataset, indices, doma
if rand_var[i]:
# Change colors per label for test domains relative to the train domains
if data_case == 'test':
curr_image = ImageOps.colorize(curr_image, black ="black", white =color_list[mnist_labels[i].item()])
# Choose this for test domain with permuted colors
# curr_image = ImageOps.colorize(curr_image, black ="black", white =color_list[(mnist_labels[i].item()+1)%10] )
if cmnist_permute:
# Choose this for test domain with permuted colors
curr_image = ImageOps.colorize(curr_image, black ="black", white =color_list[(mnist_labels[i].item()+1)%10] )
else:
curr_image = ImageOps.colorize(curr_image, black ="black", white =color_list[mnist_labels[i].item()])
else:
curr_image = ImageOps.colorize(curr_image, black ="black", white =color_list[mnist_labels[i].item()])
else:
@ -107,6 +109,7 @@ parser.add_argument('--data_size', type=int, default=60000)
parser.add_argument('--subset_size', type=int, default=2000)
parser.add_argument('--img_w', type=int, default=224)
parser.add_argument('--img_h', type=int, default=224)
parser.add_argument('--cmnist_permute', type=int, default=0)
args = parser.parse_args()
@ -117,6 +120,7 @@ img_h= args.img_h
data_size= args.data_size
subset_size= args.subset_size
val_size= int(args.subset_size/5)
cmnist_permute= args.cmnist_permute
#Generate Dataset for Rotated / Fashion MNIST
#TODO: Manage OS Env from args
@ -184,7 +188,18 @@ for seed in seed_list:
res=np.random.choice(data_size, subset_size+val_size)
print('Seed: ', seed)
for domain in domains:
# The case of permuted test domain for colored rotated MNIST, only update test data
if dataset == 'rot_mnist_spur' and cmnist_permute:
#Test
data_case= 'test'
save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
indices= res[:subset_size]
if seed in [9] and domain in [0, 15, 30, 45, 60, 75, 90]:
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute)
continue
#Train
data_case= 'train'
if not os.path.exists(data_dir + data_case + '/'):
@ -195,10 +210,10 @@ for seed in seed_list:
if model == 'resnet18':
if seed in [0, 1, 2] and domain in [15, 30, 45, 60, 75]:
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute)
elif model in ['lenet']:
if seed in [0, 1, 2] and domain in [0, 15, 30, 45, 60, 75]:
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute)
#Val
data_case= 'val'
@ -210,10 +225,10 @@ for seed in seed_list:
if model == 'resnet18':
if seed in [0, 1, 2] and domain in [15, 30, 45, 60, 75]:
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute)
elif model in ['lenet']:
if seed in [0, 1, 2] and domain in [0, 15, 30, 45, 60, 75]:
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute)
#Test
data_case= 'test'
@ -222,10 +237,38 @@ for seed in seed_list:
save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
indices= res[:subset_size]
if model == 'resnet18':
if seed in [0, 1, 2, 9] and domain in [0, 90]:
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute)
elif model in ['lenet', 'lenet_mdg']:
if seed in [0, 1, 2] and domain in [0, 15, 30, 45, 60, 75]:
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute)
# Extra data sampling for carrying out the attribute attack on spurious rotated mnist
if dataset == 'rot_mnist_spur':
#Train
data_case= 'train'
save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
indices= res[:subset_size]
if seed in [0, 1, 2] and domain in [0, 90]:
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute)
#Val
data_case= 'val'
save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
indices= res[subset_size:]
if seed in [0, 1, 2] and domain in [0, 90]:
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute)
#Test
data_case= 'test'
save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
indices= res[:subset_size]
if seed in [9] and domain in [15, 30, 45, 60, 75]:
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, cmnist_permute)

Просмотреть файл

@ -180,6 +180,13 @@
"`python3 reproduce_scripts/pacs_run.py --method hybrid --model resnet18`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The results would be stored in the `results/pacs/logs/` directory"
]
},
{
"cell_type": "markdown",
"metadata": {},
@ -252,7 +259,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The results would be stored in the `results/cxray/train_logs` directory"
"The results would be stored in the `results/chestxray/train_logs` directory"
]
},
{
@ -280,7 +287,8 @@
"\n",
"Run the following command for data generation:\n",
"\n",
" \n",
"`python3 data/data_gen_domainbed.py`\n",
"\n",
"Run the following command for training models:\n",
"\n",
"`python3 reproduce_rmnist_domainbed.py`\n",
@ -607,7 +615,9 @@
"\n",
"### Plotting Results\n",
"\n",
"`python3 reproduce_scripts/cxray_plot.py`"
"`python3 reproduce_scripts/cxray_plot.py`\n",
"\n",
"The plots would be stored in the directory: `results/chestxray/plots/`"
]
},
{
@ -616,7 +626,33 @@
"source": [
"# Attribute Attack\n",
"\n",
"python data/data_gen_mnist.py --dataset rot_mnist_spur --model resnet18 --img_h 224 --img_w 224 --subset_size 2000"
"### Preparing Data\n",
"\n",
"`python data/data_gen_mnist.py --dataset rot_mnist_spur --model resnet18 --img_h 224 --img_w 224 --subset_size 2000`\n",
"\n",
"### Training Models\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist_spur --train_case train_all --metric train --data_aug 0`\n",
"\n",
"### Evaluating OOD Accuracy\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist_spur --train_case train_all --metric acc --data_case train --data_aug 0`\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist_spur --train_case train_all --metric acc --data_case test --data_aug 0`\n",
"\n",
"### Evaluating AI Attack Accuracy\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist_spur --train_case train_all --metric attribute_attack --data_case 0 --data_aug 0`\n",
"\n",
"### Evaluating OOD Accuracy on the permuted test domain\n",
"\n",
"Generate data with permuted test domain\n",
"\n",
"`python data/data_gen_mnist.py --dataset rot_mnist_spur --model resnet18 --img_h 224 --img_w 224 --subset_size 2000 --cmnist_permute 1`\n",
"\n",
"Run the following command to obtain OOD accuracy on permuted test domain\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist_spur --train_case train_all --metric acc --data_case test --data_aug 0`"
]
},
{

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -45,12 +45,11 @@ fontsize_lgd= fontsize/1.2
for idx in range(3):
marker_list = ['o', '^', '*']
legend_list = ['RSNA', 'NIH', 'Chex']
legend_count = 0
for test_domain in ['kaggle', 'nih', 'chex']:
# metrics= ['acc:train', 'acc:test', 'mia', 'privacy_entropy', 'privacy_loss_attack', 'match_score:train', 'match_score:test']
metrics= ['acc:train', 'acc:test', 'privacy_loss_attack', 'match_score:test']
metrics= ['acc:train', 'acc:test', 'privacy_loss_attack']
acc_train=[]
acc_train_err=[]
@ -110,11 +109,11 @@ for idx in range(3):
ax[idx].set_xticklabels(x, rotation=25)
if idx == 0:
ax[idx].errorbar(x, acc_test, yerr=acc_test_err, label=test_domain, marker= marker_list[legend_count], markersize= fontsize_lgd, linewidth=4, fmt='o--')
ax[idx].errorbar(x, acc_test, yerr=acc_test_err, label= legend_list[legend_count], marker= marker_list[legend_count], markersize= fontsize_lgd, linewidth=4, fmt='o--')
ax[idx].set_ylabel('OOD Accuracy', fontsize=fontsize)
if idx == 1:
ax[idx].errorbar(x, loss, yerr=loss_err, label=test_domain, marker= marker_list[legend_count], markersize= fontsize_lgd, linewidth=4, fmt='o--')
ax[idx].errorbar(x, loss, yerr=loss_err, label= legend_list[legend_count], marker= marker_list[legend_count], markersize= fontsize_lgd, linewidth=4, fmt='o--')
ax[idx].set_ylabel('MI Attack Accuracy', fontsize=fontsize)
# if idx == 2:
@ -125,7 +124,7 @@ for idx in range(3):
# ax[idx].legend(fontsize=fontsize_lgd)
if idx == 2:
ax[idx].errorbar(x, np.array(acc_train) - np.array(acc_test), yerr=acc_train_err, label=test_domain, marker= marker_list[legend_count], markersize= fontsize_lgd, linewidth=4, fmt='o--')
ax[idx].errorbar(x, np.array(acc_train) - np.array(acc_test), yerr=acc_train_err, label= legend_list[legend_count], marker= marker_list[legend_count], markersize= fontsize_lgd, linewidth=4, fmt='o--')
ax[idx].set_ylabel('Train-Test Accuracy Gap ', fontsize=fontsize)

Просмотреть файл

@ -39,9 +39,6 @@ methods= args.methods
# test_diff, test_common
test_case=['test_diff']
# List of methods to train/evaluate
# methods=[]
if metric == 'train':
if dataset in ['rot_mnist', 'rot_mnist_spur']:
base_script= 'python train.py --dataset ' + str(dataset)
@ -155,7 +152,7 @@ for method in methods:
script= base_script + ' --method_name erm_match --penalty_ws 10.0 --match_case 1.0 --epochs 25 ' + ' > ' + case + '.txt'
else:
script= base_script + ' --method_name erm_match --penalty_ws 0.1 --match_case 1.0 --epochs 25 ' + ' > ' + case + '.txt'
os.system(script)
os.system(script)
elif method == 'matchdg':
if metric == 'train':

Просмотреть файл

@ -90,8 +90,8 @@ for test_domain in [0.2, 0.9]:
ax[count].errorbar(x, s_auc, yerr=s_auc_err, marker= marker_list[1], markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Linear-RAUC')
ax[count].errorbar(x, loss, yerr=loss_err, marker= marker_list[2], markersize= fontsize_lgd, linewidth=4, label='Loss Attack', fmt='o--')
gen_gap= np.array(train_acc) - np.array(acc)
ax[count].errorbar(x, gen_gap, yerr=0*gen_gap, marker= 's', markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Gen Gap')
# gen_gap= np.array(train_acc) - np.array(acc)
# ax[count].errorbar(x, gen_gap, yerr=0*gen_gap, marker= 's', markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Gen Gap')
ax[count].set_ylabel('Metric Score', fontsize=fontsize)
ax[count].set_title('Test Domain: ' + str(test_domain), fontsize=fontsize)

Просмотреть файл

@ -287,7 +287,6 @@ def get_dataloader(args, run, domains, data_case, eval_case, kwargs):
elif args.dataset_name in ['rot_mnist', 'fashion_mnist', 'rot_mnist_spur']:
if data_case == 'test' and args.mnist_case not in ['lenet']:
#TODO: Infer this based on the total number of seed values for the mnist case
# Actually by default the seeds 0, 1, 2 are for training and seed 9 is for test; mention that properly in comments
mnist_subset= 9
else: