Add reproducibility details; update MatchDG based on slab dataset issues and noise in the invariant mechanism
Parent
0b8bb0597b
Commit
29136f47f6
|
@ -77,12 +77,6 @@ def get_noise_multiplier(
|
|||
class BaseAlgo():
|
||||
def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, run, cuda):
|
||||
|
||||
# from evaluation.base_eval import BaseEval
|
||||
# self.test_method= BaseEval(
|
||||
# args, train_dataset, val_dataset,
|
||||
# test_dataset, base_res_dir,
|
||||
# run, cuda
|
||||
# )
|
||||
|
||||
self.args= args
|
||||
self.train_dataset= train_dataset['data_loader']
|
||||
|
@ -111,7 +105,7 @@ class BaseAlgo():
|
|||
self.val_acc=[]
|
||||
self.train_acc=[]
|
||||
|
||||
# if self.args.method_name == 'dp_erm':
|
||||
# Differentially Private Noise
|
||||
if self.args.dp_noise:
|
||||
self.privacy_engine= self.get_dp_noise()
|
||||
|
||||
|
@ -297,28 +291,16 @@ class BaseAlgo():
|
|||
MAX_GRAD_NORM = 5.0
|
||||
DELTA = 1.0/(self.total_domains*self.domain_size)
|
||||
BATCH_SIZE = self.args.batch_size * self.total_domains
|
||||
VIRTUAL_BATCH_SIZE = 10*BATCH_SIZE
|
||||
assert VIRTUAL_BATCH_SIZE % BATCH_SIZE == 0 # VIRTUAL_BATCH_SIZE should be divisible by BATCH_SIZE
|
||||
N_ACCUMULATION_STEPS = int(VIRTUAL_BATCH_SIZE / BATCH_SIZE)
|
||||
SAMPLE_RATE = BATCH_SIZE /(self.total_domains*self.domain_size)
|
||||
DEFAULT_ALPHAS = [1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64))
|
||||
|
||||
|
||||
|
||||
NOISE_MULTIPLIER = get_noise_multiplier(self.args.dp_epsilon, DELTA, SAMPLE_RATE, self.args.epochs, DEFAULT_ALPHAS)
|
||||
|
||||
print("Target Epsilon: ", self.args.dp_epsilon)
|
||||
print(f"Using sigma={NOISE_MULTIPLIER} and C={MAX_GRAD_NORM}")
|
||||
|
||||
# sys.exit(-1)
|
||||
|
||||
from opacus import PrivacyEngine
|
||||
# privacy_engine = PrivacyEngine(
|
||||
# self.phi,
|
||||
# sample_rate=SAMPLE_RATE * N_ACCUMULATION_STEPS,
|
||||
# alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
|
||||
# noise_multiplier=NOISE_MULTIPLIER,
|
||||
# max_grad_norm=MAX_GRAD_NORM,
|
||||
# )
|
||||
privacy_engine = PrivacyEngine(
|
||||
self.phi,
|
||||
batch_size= BATCH_SIZE,
|
||||
|
@ -326,6 +308,11 @@ class BaseAlgo():
|
|||
noise_multiplier=NOISE_MULTIPLIER,
|
||||
max_grad_norm=MAX_GRAD_NORM,
|
||||
)
|
||||
|
||||
privacy_engine.attach(self.opt)
|
||||
|
||||
if self.args.dp_attach_opt:
|
||||
print('Standard DP Training with finite epsilon')
|
||||
privacy_engine.attach(self.opt)
|
||||
else:
|
||||
print('DP Training with infinite epsilon')
|
||||
|
||||
return privacy_engine
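# Illustrative sketch (not part of the class above): how the virtual-batching
# constants relate for a hypothetical dataset shape. DELTA is set to 1/N and
# SAMPLE_RATE to BATCH_SIZE/N, mirroring the expressions used in get_dp_noise().
total_domains, domain_size, batch_size = 5, 2000, 16    # hypothetical values
N = total_domains * domain_size                         # total training samples
DELTA = 1.0 / N
BATCH_SIZE = batch_size * total_domains                 # physical batch across domains
VIRTUAL_BATCH_SIZE = 10 * BATCH_SIZE                    # gradients accumulated over 10 physical batches
N_ACCUMULATION_STEPS = VIRTUAL_BATCH_SIZE // BATCH_SIZE
SAMPLE_RATE = BATCH_SIZE / N
print(DELTA, SAMPLE_RATE, N_ACCUMULATION_STEPS)         # 0.0001 0.008 10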
|
|
@ -110,15 +110,16 @@ class ErmMatch(BaseAlgo):
|
|||
|
||||
|
||||
loss_e.backward(retain_graph=False)
|
||||
|
||||
if batch_idx % 10 == 9:
|
||||
|
||||
if self.args.dp_noise and self.args.dp_attach_opt:
|
||||
if batch_idx % 10 == 9:
|
||||
self.opt.step()
|
||||
self.opt.zero_grad()
|
||||
else:
|
||||
self.opt.virtual_step()
|
||||
else:
|
||||
self.opt.step()
|
||||
self.opt.zero_grad()
|
||||
else:
|
||||
self.opt.virtual_step()
|
||||
|
||||
# self.opt.step()
|
||||
# self.opt.zero_grad()
|
||||
|
||||
#Gradient Norm Computation
|
||||
# batch_grad_norm=0.0
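# Sketch of the update pattern above with plain PyTorch (hypothetical model and
# loader; an accumulation step of 10 mirrors the `batch_idx % 10 == 9` check).
# When the Opacus engine is attached, the non-update branch instead calls
# self.opt.virtual_step() so per-sample gradients can still be clipped and
# accumulated between real optimizer steps.
import torch
model = torch.nn.Linear(10, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(30)]
ACCUMULATION_STEPS = 10
for batch_idx, (x, y) in enumerate(loader):
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss.backward()
    if batch_idx % ACCUMULATION_STEPS == ACCUMULATION_STEPS - 1:
        opt.step()        # real parameter update once per accumulated (virtual) batch
        opt.zero_grad()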
|
||||
|
@ -156,10 +157,6 @@ class ErmMatch(BaseAlgo):
|
|||
self.max_epoch= epoch
|
||||
self.save_model()
|
||||
|
||||
# Sanity check on the test accuracy
|
||||
# self.test_method.get_model()
|
||||
# self.test_method.get_metric_eval()
|
||||
# print( ' Sanity Check Test Accuracy: ', self.test_method.metric_score )
|
||||
print('Current Best Epoch: ', self.max_epoch, ' with Test Accuracy: ', self.final_acc[self.max_epoch])
|
||||
|
||||
if epoch > 0 and self.args.model_name in ['domain_bed_mnist', 'lenet']:
|
||||
|
|
|
@ -178,6 +178,9 @@ class MatchDG(BaseAlgo):
|
|||
# print('Weird! Positive Matches are more than the negative matches?', pos_feat_match.shape[0], neg_feat_match.shape[0])
|
||||
|
||||
# If no instances of label y_c in the current batch then continue
|
||||
|
||||
print(pos_feat_match.shape[0], neg_feat_match.shape[0], y_c)
|
||||
|
||||
if pos_feat_match.shape[0] ==0 or neg_feat_match.shape[0] == 0:
|
||||
continue
|
||||
|
||||
|
@ -229,6 +232,9 @@ class MatchDG(BaseAlgo):
|
|||
|
||||
loss_e += ( ( epoch- self.args.penalty_s )/(self.args.epochs -self.args.penalty_s) )*diff_hinge_loss
|
||||
|
||||
if not loss_e.requires_grad:
|
||||
continue
|
||||
|
||||
loss_e.backward(retain_graph=False)
|
||||
self.opt.step()
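# Illustrative: how the penalty weight applied to diff_hinge_loss above ramps up
# over training (hypothetical values for args.epochs and args.penalty_s).
epochs, penalty_s = 25, 5
for epoch in [5, 10, 15, 20, 24]:
    w = (epoch - penalty_s) / (epochs - penalty_s)
    print(epoch, round(w, 2))    # 0.0, 0.25, 0.5, 0.75, 0.95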
|
||||
|
||||
|
|
|
@ -0,0 +1,231 @@
|
|||
#Common imports
|
||||
import numpy as np
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
import random
|
||||
import copy
|
||||
import os
|
||||
|
||||
#Sklearn
|
||||
from scipy.stats import bernoulli
|
||||
|
||||
#Pillow
|
||||
from PIL import Image, ImageColor, ImageOps
|
||||
|
||||
#Pytorch
|
||||
import torch
|
||||
import torch.utils.data as data_utils
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
def generate_rotated_domain_data(imgs, labels, data_case, dataset, indices, domain, save_dir, img_w, img_h):
|
||||
|
||||
# Get total number of labeled examples
|
||||
mnist_labels = labels[indices]
|
||||
mnist_imgs = imgs[indices]
|
||||
mnist_size = mnist_labels.shape[0]
|
||||
|
||||
to_pil= transforms.Compose([
|
||||
transforms.ToPILImage(),
|
||||
transforms.Resize((img_w, img_h))
|
||||
])
|
||||
|
||||
to_augment= transforms.Compose([
|
||||
transforms.RandomResizedCrop(img_w, scale=(0.7,1.0)),
|
||||
transforms.RandomHorizontalFlip()
|
||||
])
|
||||
|
||||
to_tensor= transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307,), (0.3081,))
|
||||
])
|
||||
|
||||
if dataset == 'rot_mnist_spur':
|
||||
color_list=['red', 'blue', 'green', 'orange', 'yellow', 'brown', 'pink', 'magenta', 'olive', 'cyan']
|
||||
# Adding color with 70 percent probability
|
||||
rand_var= bernoulli.rvs(0.7, size=mnist_size)
|
||||
|
||||
# Run transforms
|
||||
if dataset == 'rot_mnist_spur':
|
||||
mnist_img_rot= torch.zeros((mnist_size, 3, img_w, img_h))
|
||||
mnist_img_rot_org= torch.zeros((mnist_size, 3, img_w, img_h))
|
||||
else:
|
||||
mnist_img_rot= torch.zeros((mnist_size, img_w, img_h))
|
||||
mnist_img_rot_org= torch.zeros((mnist_size, img_w, img_h))
|
||||
|
||||
mnist_idx=[]
|
||||
|
||||
for i in range(len(mnist_imgs)):
|
||||
|
||||
curr_image= to_pil(mnist_imgs[i])
|
||||
|
||||
#Color the image
|
||||
if dataset == 'rot_mnist_spur':
|
||||
if rand_var[i]:
|
||||
# Change colors per label for test domains relative to the train domains
|
||||
if data_case == 'test':
|
||||
curr_image = ImageOps.colorize(curr_image, black ="black", white =color_list[mnist_labels[i].item()])
|
||||
# Choose this for test domain with permuted colors
|
||||
# curr_image = ImageOps.colorize(curr_image, black ="black", white =color_list[(mnist_labels[i].item()+1)%10] )
|
||||
else:
|
||||
curr_image = ImageOps.colorize(curr_image, black ="black", white =color_list[mnist_labels[i].item()])
|
||||
else:
|
||||
curr_image = ImageOps.colorize(curr_image, black ="black", white ="white")
|
||||
|
||||
#Rotation
|
||||
if domain == '0':
|
||||
img_rotated= curr_image
|
||||
else:
|
||||
img_rotated= transforms.functional.rotate( curr_image, int(domain) )
|
||||
|
||||
mnist_img_rot_org[i]= to_tensor(img_rotated)
|
||||
#Augmentation
|
||||
mnist_img_rot[i]= to_tensor(to_augment(img_rotated))
|
||||
|
||||
if data_case == 'train' or data_case == 'val':
|
||||
torch.save(mnist_img_rot, save_dir + '_data.pt')
|
||||
|
||||
torch.save(mnist_img_rot_org, save_dir + '_org_data.pt')
|
||||
torch.save(mnist_labels, save_dir + '_label.pt')
|
||||
|
||||
if dataset == 'rot_mnist_spur':
|
||||
np.save(save_dir + '_spur.npy', rand_var)
|
||||
|
||||
print('Data Case: ', data_case, ' Source Domain: ', domain, ' Shape: ', mnist_img_rot.shape, mnist_img_rot_org.shape, mnist_labels.shape)
|
||||
|
||||
return
|
||||
|
||||
# Main Function
|
||||
|
||||
# Input Parsing
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--dataset', type=str, default='rot_mnist',
|
||||
help='Datasets: rot_mnist; fashion_mnist; rot_mnist_spur')
|
||||
parser.add_argument('--model', type=str, default='resnet18',
|
||||
help='Base Models: resnet18; lenet')
|
||||
parser.add_argument('--data_size', type=int, default=60000)
|
||||
parser.add_argument('--subset_size', type=int, default=2000)
|
||||
parser.add_argument('--img_w', type=int, default=224)
|
||||
parser.add_argument('--img_h', type=int, default=224)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset= args.dataset
|
||||
model= args.model
|
||||
img_w= args.img_w
|
||||
img_h= args.img_h
|
||||
data_size= args.data_size
|
||||
subset_size= args.subset_size
|
||||
val_size= int(args.subset_size/5)
|
||||
|
||||
#Generate Dataset for Rotated / Fashion MNIST
|
||||
#TODO: Manage OS Env from args
|
||||
os_env=0
|
||||
if os_env:
|
||||
base_dir= os.getenv('PT_DATA_DIR') + '/mnist/'
|
||||
else:
|
||||
base_dir= 'data/datasets/mnist/'
|
||||
|
||||
if not os.path.exists(base_dir):
|
||||
os.makedirs(base_dir)
|
||||
|
||||
|
||||
data_dir= base_dir + dataset + '_' + model + '/'
|
||||
if not os.path.exists(data_dir):
|
||||
os.makedirs(data_dir)
|
||||
|
||||
|
||||
if dataset =='rot_mnist' or dataset == 'rot_mnist_spur':
|
||||
data_obj_train= datasets.MNIST(base_dir,
|
||||
train=True,
|
||||
download=True,
|
||||
transform=transforms.ToTensor()
|
||||
)
|
||||
|
||||
data_obj_test= datasets.MNIST(base_dir,
|
||||
train=False,
|
||||
download=True,
|
||||
transform=transforms.ToTensor()
|
||||
)
|
||||
mnist_imgs= torch.cat((data_obj_train.data, data_obj_test.data))
|
||||
mnist_labels= torch.cat((data_obj_train.targets, data_obj_test.targets))
|
||||
|
||||
elif dataset == 'fashion_mnist':
|
||||
data_obj_train= datasets.FashionMNIST(base_dir,
|
||||
train=True,
|
||||
download=True,
|
||||
transform=transforms.ToTensor()
|
||||
)
|
||||
|
||||
data_obj_test= datasets.FashionMNIST(base_dir,
|
||||
train=False,
|
||||
download=True,
|
||||
transform=transforms.ToTensor()
|
||||
)
|
||||
mnist_imgs= torch.cat((data_obj_train.data, data_obj_test.data))
|
||||
mnist_labels= torch.cat((data_obj_train.targets, data_obj_test.targets))
|
||||
|
||||
|
||||
# For testing over different base objects; seed 9
|
||||
# Seed 9 only for test data, See 0:3 for train data
|
||||
seed_list= [0, 1, 2, 9]
|
||||
domains= [0, 15, 30, 45, 60, 75, 90]
|
||||
|
||||
for seed in seed_list:
|
||||
|
||||
# Random Seed
|
||||
random.seed(seed*10)
|
||||
np.random.seed(seed*10)
|
||||
torch.manual_seed(seed*10)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed*10)
|
||||
|
||||
# Indices
|
||||
res=np.random.choice(data_size, subset_size+val_size)
|
||||
print('Seed: ', seed)
|
||||
for domain in domains:
|
||||
|
||||
#Train
|
||||
data_case= 'train'
|
||||
if not os.path.exists(data_dir + data_case + '/'):
|
||||
os.makedirs(data_dir + data_case + '/')
|
||||
|
||||
save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
|
||||
indices= res[:subset_size]
|
||||
|
||||
if model == 'resnet18':
|
||||
if seed in [0, 1, 2] and domain in [15, 30, 45, 60, 75]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
|
||||
elif model in ['lenet']:
|
||||
if seed in [0, 1, 2] and domain in [0, 15, 30, 45, 60, 75]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
|
||||
|
||||
#Val
|
||||
data_case= 'val'
|
||||
if not os.path.exists(data_dir + data_case + '/'):
|
||||
os.makedirs(data_dir + data_case + '/')
|
||||
|
||||
save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
|
||||
indices= res[subset_size:]
|
||||
|
||||
if model == 'resnet18':
|
||||
if seed in [0, 1, 2] and domain in [15, 30, 45, 60, 75]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
|
||||
elif model in ['lenet']:
|
||||
if seed in [0, 1, 2] and domain in [0, 15, 30, 45, 60, 75]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
|
||||
|
||||
#Test
|
||||
data_case= 'test'
|
||||
if not os.path.exists(data_dir + data_case + '/'):
|
||||
os.makedirs(data_dir + data_case + '/')
|
||||
|
||||
save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
|
||||
indices= res[:subset_size]
|
||||
|
||||
if model == 'resnet18':
|
||||
if seed in [0, 1, 2, 9] and domain in [0, 90]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
|
||||
elif model in ['lenet', 'lenet_mdg']:
|
||||
if seed in [0, 1, 2] and domain in [0, 15, 30, 45, 60, 75]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
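# Sketch: loading one generated domain back (hypothetical seed/domain values;
# the path follows the save_dir convention used above for non-OS-env runs).
import torch
save_dir = 'data/datasets/mnist/rot_mnist_resnet18/train/seed_0_domain_15'
imgs = torch.load(save_dir + '_data.pt')           # augmented, normalized images
imgs_org = torch.load(save_dir + '_org_data.pt')   # un-augmented images
labels = torch.load(save_dir + '_label.pt')
print(imgs.shape, imgs_org.shape, labels.shape)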
|
|
@ -38,8 +38,8 @@ if not os.path.exists(base_dir):
|
|||
|
||||
num_samples= 10000
|
||||
total_slabs= 7
|
||||
slab_noise_list= [0.0, 0.10, 0.20]
|
||||
spur_corr_list= [0.0, 0.05, 0.10, 0.15, 0.30, 0.50, 0.70, 0.90, 1.0]
|
||||
slab_noise_list= [0.0, 0.10]
|
||||
spur_corr_list= [0.0, 0.10, 0.20, 0.90]
|
||||
|
||||
for seed in range(10):
|
||||
np.random.seed(seed*10)
|
||||
|
|
|
@ -187,7 +187,7 @@
|
|||
"Fracture doesn't exist. Adding nans instead.\n",
|
||||
"Lung Opacity doesn't exist. Adding nans instead.\n",
|
||||
"Enlarged Cardiomediastinum doesn't exist. Adding nans instead.\n",
|
||||
"{'Support Devices', 'Pleural Other'} will be dropped\n",
|
||||
"{'Pleural Other', 'Support Devices'} will be dropped\n",
|
||||
"Infiltration doesn't exist. Adding nans instead.\n",
|
||||
"Emphysema doesn't exist. Adding nans instead.\n",
|
||||
"Fibrosis doesn't exist. Adding nans instead.\n",
|
||||
|
@ -565,9 +565,29 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"(112120,)\n",
|
||||
"Error:\n",
|
||||
"torch.Size([800, 3, 224, 224]) torch.Size([800, 3, 224, 224]) torch.Size([800]) 400 400\n",
|
||||
"torch.Size([200, 3, 224, 224]) torch.Size([200, 3, 224, 224]) torch.Size([200]) 100 100\n",
|
||||
"torch.Size([400, 3, 224, 224]) torch.Size([400, 3, 224, 224]) torch.Size([400]) 200 200\n",
|
||||
"(191010,)\n",
|
||||
"torch.Size([800, 3, 224, 224]) torch.Size([800, 3, 224, 224]) torch.Size([800]) 400 400\n",
|
||||
"torch.Size([200, 3, 224, 224]) torch.Size([200, 3, 224, 224]) torch.Size([200]) 100 100\n",
|
||||
"torch.Size([400, 3, 224, 224]) torch.Size([400, 3, 224, 224]) torch.Size([400]) 200 200\n",
|
||||
"(26684,)\n",
|
||||
"torch.Size([800, 3, 224, 224]) torch.Size([800, 3, 224, 224]) torch.Size([800]) 400 400\n",
|
||||
"torch.Size([200, 3, 224, 224]) torch.Size([200, 3, 224, 224]) torch.Size([200]) 100 100\n",
|
||||
"torch.Size([400, 3, 224, 224]) torch.Size([400, 3, 224, 224]) torch.Size([400]) 200 200\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"base_dir=root_dir + '/data/datasets/chestxray/'\n",
|
||||
" \n",
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -37,7 +37,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -51,7 +51,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -77,7 +77,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -86,7 +86,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -107,7 +107,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
|
|
@ -0,0 +1,658 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Reproducing results\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"The following code reproduces results for Slab dataset, Rotated MNIST and Fashion-MNIST dataset, and PACS dataset corresponding to Tables 1, 2, 3, 4, 5, 6 in the main paper.\n",
|
||||
"\n",
|
||||
"## Note regarding hardware requirements\n",
|
||||
"\n",
|
||||
"The code requires a GPU device, also the batch size for MatchDG Phase 1 training might need to be adjusted according to the memory limits of the GPU device. In case of CUDA of out of memory issues, try with a smaller batch size."
|
||||
]
|
||||
},
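{
"cell_type": "markdown",
"metadata": {},
"source": [
"An optional sanity check before launching the longer runs below: confirm that PyTorch can see a CUDA device."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional check: the training commands below assume a CUDA device is available\n",
"import torch\n",
"print('CUDA available: ', torch.cuda.is_available())"
]
},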
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Installing Libraries\n",
|
||||
"\n",
|
||||
"List of all the required packages are mentioned in the file 'requirements.txt'\n",
|
||||
"\n",
|
||||
"You may install them as follows: `pip install -r requirements.txt`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Table 1: Slab Dataset\n",
|
||||
"\n",
|
||||
"## Prepare Slab Dataset\n",
|
||||
"\n",
|
||||
"Run the following command:\n",
|
||||
"\n",
|
||||
"`python3 data_gen_syn.py`\n",
|
||||
"\n",
|
||||
"## Table 1\n",
|
||||
"\n",
|
||||
"Run the following command:\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/reproduce_slab.py train`\n",
|
||||
"\n",
|
||||
"The results would be stored in the `results/slab/logs/` directory"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Table 2, 3: RotMNIST & Fashion MNIST\n",
|
||||
"\n",
|
||||
"## Prepare Data for Rot MNIST & Fashion MNIST\n",
|
||||
"\n",
|
||||
"Run the following command"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`python data/data_gen_mnist.py --dataset rot_mnist --model resnet18 --img_h 224 --img_w 224 --subset_size 2000` \n",
|
||||
"\n",
|
||||
"`python data/data_gen_mnist.py --dataset fashion_mnist --model resnet18 --img_h 224 --img_w 224 --subset_size 10000`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Table 2\n",
|
||||
"\n",
|
||||
"Rotated MNIST dataset with training domains set to [15, 30, 45, 60, 75] and the test domains set to [0, 90]. \n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_all --metric train`\n",
|
||||
"\n",
|
||||
"The results would be present in the `results/rot_mnist/train_logs/` directory\n",
|
||||
"\n",
|
||||
"To obtain results for the FashionMNIST dataset, change the dataset parameter `--dataset` from `rot_mnist` to `fashion_mnist`.\n",
|
||||
"\n",
|
||||
"To obtain results for the different set of training domains in the paper, change the input to the parameter `--train_case` to `train_abl_3` for training with domains [30, 45, 60], and `train_abl_2` for training with domains [30, 45] "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Table 3\n",
|
||||
"\n",
|
||||
"Run the following commands:\n",
|
||||
"\n",
|
||||
"`python test.py --dataset rot_mnist --method_name erm_match --match_case 0.0 --penalty_ws 0.0 --test_metric match_score`\n",
|
||||
"\n",
|
||||
"`python test.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --pos_metric cos --test_metric match_score`\n",
|
||||
"\n",
|
||||
"For MDG Perf, run the folllowing command to first train the model:\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_mdg_ctr_run.py --dataset rot_mnist --perf_init 1`\n",
|
||||
"\n",
|
||||
"Then run the following commands to evalute match function metrics:\n",
|
||||
"\n",
|
||||
"`python test.py --dataset rot_mnist --method_name matchdg_ctr --match_case 1.0 --match_flag 1 --pos_metric cos --test_metric match_score`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Table 4, 5: PACS\n",
|
||||
"\n",
|
||||
"## Prepare Data for PACS\n",
|
||||
"\n",
|
||||
"Download the PACS dataset (https://drive.google.com/drive/folders/0B6x7gtvErXgfUU1WcGY5SzdwZVk) and place it in the directory '/data/datasets/pacs/' "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Table 4\n",
|
||||
"\n",
|
||||
"* RandMatch: "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`python3 reproduce_scripts/pacs_run.py --method rand --model resnet18`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"* MatchDG:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"For contrastive phase, we train with the resnet50 model despite the model architecture in Phase 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`python3 reproduce_scripts/pacs_run.py --method matchdg_ctr --model resnet50`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`python3 reproduce_scripts/pacs_run.py --method matchdg_erm --model resnet18`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"* MDGHybrid:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"No need to train the contrastive phase again if already done while training MatchDG"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`python3 reproduce_scripts/pacs_run.py --method hybrid --model resnet18`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Table 5\n",
|
||||
"\n",
|
||||
"Repeat the above commands and replace the argument to flag --model with resnet50 with resnet18 "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Table 6: Chest X-Ray"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Prepare Data for Chest X-Ray\n",
|
||||
"\n",
|
||||
" -Follow the steps in the Preprocess.ipynb notebook to donwload and process the Chest X-Ray datasets\n",
|
||||
" -Then follow the steps in the ChestXRay_Translate.ipynb notebook to perform image translations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"* NIH: "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`python3 reproduce_scripts/cxray_run.py --test_domain nih --metric train`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"* Chexpert: "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`python3 reproduce_scripts/cxray_run.py --test_domain chex --metric train`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"* RSNA: "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`python3 reproduce_scripts/cxray_run.py --test_domain kaggle --metric train`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The results would be stored in the `results/cxray/train_logs` directory"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Table 11\n",
|
||||
"\n",
|
||||
"Run the following command for data generation:\n",
|
||||
"\n",
|
||||
"`python data/data_gen_mnist.py --dataset rot_mnist --model lenet --img_h 32 --img_w 32 --subset_size 1000`\n",
|
||||
"\n",
|
||||
"Run the following commands for training models:\n",
|
||||
"\n",
|
||||
"`python3 reproduce_rmnist_lenet.py`\n",
|
||||
"\n",
|
||||
"The results will be stored in the directory: `results/rmnist_lenet/`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Table 12\n",
|
||||
"\n",
|
||||
"Run the following command for data generation:\n",
|
||||
"\n",
|
||||
" \n",
|
||||
"Run the following command for training models:\n",
|
||||
"\n",
|
||||
"`python3 reproduce_rmnist_domainbed.py`\n",
|
||||
"\n",
|
||||
"The results will be stored in the directory: `results/rmnist_domain_bed/`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Table 13\n",
|
||||
"\n",
|
||||
"To obtain results for the FashionMNIST dataset, change the dataset parameter `--dataset` from `rot_mnist` to `fashion_mnist`.\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_all --metric train --methods approx_25 approx_50 approx_75`\n",
|
||||
"\n",
|
||||
"The results will be stored in the directory: `results/rot_mnist/train_logs/`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Table 14\n",
|
||||
"\n",
|
||||
"To obtain results for the FashionMNIST dataset, change the dataset parameter `--dataset` from `rot_mnist` to `fashion_mnist`.\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_all --metric match_score --data_case train --methods rand perf matchdg`\n",
|
||||
"\n",
|
||||
"The results would be stored in the directory: `results/rot_mnist/match_score_train/`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Table 15\n",
|
||||
"\n",
|
||||
"Generate data again for the Fashion MNIST 2k sample case by running the following command:\n",
|
||||
"\n",
|
||||
"`python data/data_gen_mnist.py --dataset fashion_mnist --model resnet18 --img_h 224 --img_w 224 --subset_size 2000`\n",
|
||||
"\n",
|
||||
"Then follow the same commands as mentioned in the Table 2 section"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Table 16\n",
|
||||
"\n",
|
||||
"To obtain results for the FashionMNIST dataset, change the dataset parameter `--dataset` from `rot_mnist` to `fashion_mnist`.\n",
|
||||
"\n",
|
||||
"MatchDG Iterative corresponds to the default MatchDG algorithm, with the same results as in Table 3\n",
|
||||
"\n",
|
||||
"For MatchDG Non Iterative, run the folllowing command to first the model\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_mdg_ctr_run.py --dataset rot_mnist --iterative 0`\n",
|
||||
"\n",
|
||||
"Then run the following command to evaluate match function metrics:\n",
|
||||
"\n",
|
||||
"`python test.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.0 --match_flag 0 --pos_metric cos --test_metric match_score`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Table 18\n",
|
||||
"\n",
|
||||
"Repeat the commands mentioned for PACS ResNet-18 (Table 4) and replace the argument to flag --model with alexnet with resnet18 "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Privacy & OOD: ICLR\n",
|
||||
"\n",
|
||||
"# Slab Dataset\n",
|
||||
"\n",
|
||||
"### Preparing Data\n",
|
||||
"\n",
|
||||
"`python3 data_gen_syn.py`\n",
|
||||
"\n",
|
||||
"### Training Models\n",
|
||||
"\n",
|
||||
"Run the following command to train models with no noise in the prediction mechanism based on slab features\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/slab-run.py --slab_noise 0.0`\n",
|
||||
"\n",
|
||||
"Run the following command to train models with noise in the prediction mechanism based on slab features\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/slab-run.py --slab_noise 0.10`\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"### Evaluating OOD Accuracy, Randomised-AUC, & Privacy Loss Attack\n",
|
||||
"\n",
|
||||
"Run the following command for the case of no noise in the prediction mechanism based on slab features\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/slab-run.py --slab_noise 0.0 --case test`\n",
|
||||
"\n",
|
||||
"Run the following command for the case of noise in the prediction mechanism based on slab features\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/slab-run.py --slab_noise 0.10 --case test`\n",
|
||||
"\n",
|
||||
"### Plotting Results\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/slab-plot.py 0.0`\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/slab-plot.py 0.1`\n",
|
||||
"\n",
|
||||
"The plots would be stored in the directory: `results/slab/plots/`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Rotated & Fashion MNIST \n",
|
||||
"\n",
|
||||
"For convenience we provide the commands for the Rotated MNIST dataset. To obtain results for the FashionMNIST dataset, change the dataset parameter `--dataset` from `rot_mnist` to `fashion_mnist`.\n",
|
||||
"\n",
|
||||
"### Preparing Data\n",
|
||||
"\n",
|
||||
"`python data/data_gen_mnist.py --dataset rot_mnist --model resnet18 --img_h 224 --img_w 224 --subset_size 2000` \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"### Training Models\n",
|
||||
"\n",
|
||||
"Training Domains: [15, 30, 45, 60, 75]\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_all --metric train --data_aug 0`\n",
|
||||
"\n",
|
||||
"Training Domains: [30, 45, 60]\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_abl_3 --metric train --data_aug 0`\n",
|
||||
"\n",
|
||||
"Training Domains: [30, 45]\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_abl_2 --metric train --data_aug 0`\n",
|
||||
"\n",
|
||||
"The results would be present in the results/rot_mnist/train_logs/ directory\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"### Evaluating OOD Accuracy\n",
|
||||
"\n",
|
||||
"Training Domains: [15, 30, 45, 60, 75]\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_all --metric acc --data_case train --data_aug 0 `\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_all --metric acc --data_case test --data_aug 0 `\n",
|
||||
"\n",
|
||||
"Training Domains: [30, 45, 60]\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_abl_3 --metric acc --data_case train --data_aug 0`\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_abl_3 --metric acc --data_case test --data_aug 0`\n",
|
||||
"\n",
|
||||
"Training Domains: [30, 45]\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_abl_2 --metric acc --data_case train --data_aug 0`\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_abl_2 --metric acc --data_case test --data_aug 0`\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"### Evaluating MI Attack Accuracy\n",
|
||||
"\n",
|
||||
"Training Domains: [15, 30, 45, 60, 75]\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_all --metric privacy_loss_attack --data_aug 0 `\n",
|
||||
"\n",
|
||||
"Training Domains: [30, 45, 60]\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_abl_3 --metric privacy_loss_attack --data_aug 0`\n",
|
||||
"\n",
|
||||
"Training Domains: [30, 45]\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_abl_2 --metric privacy_loss_attack --data_aug 0`\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"### Evaluating Mean Rank\n",
|
||||
"\n",
|
||||
"Training Domains: [15, 30, 45, 60, 75]\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_all --metric match_score --data_case test --data_aug 0 `\n",
|
||||
"\n",
|
||||
"Training Domains: [30, 45, 60]\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_abl_3 --metric match_score --data_case test --data_aug 0`\n",
|
||||
"\n",
|
||||
"Training Domains: [30, 45]\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_abl_2 --metric match_score --data_case test --data_aug 0`\n",
|
||||
"\n",
|
||||
"### Plotting Results\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_plot.py rot_mnist`\n",
|
||||
"\n",
|
||||
"The plots would be stored in the directory: `results/rot_mnist/plots/`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Differentially Private Noise\n",
|
||||
"\n",
|
||||
"For convenience we provide the commands for the Rotated MNIST dataset. To obtain results for the FashionMNIST dataset, change the dataset parameter `--dataset` from `rot_mnist` to `fashion_mnist`.\n",
|
||||
"\n",
|
||||
"The command below produces results for the case of epsilon 1.0; repeat the same command by changing the input to the paramter `--dp_epsilon` to the other values from the list: [1, 2, 5, 10]. \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"### Training Models\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --dp_noise 1 --dp_epsilon 1.0 --data_aug 0 --methods erm perf`\n",
|
||||
"\n",
|
||||
"### Evaluating OOD Accuracy\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --dp_noise 1 --dp_epsilon 1.0 --data_aug 0 --methods erm perf --metric acc --data_case test `\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"### Evaluating MI Attack Accuracy\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --dp_noise 1 --dp_epsilon 1.0 --data_aug 0 --methods erm perf --metric privacy_loss_attack`\n",
|
||||
"\n",
|
||||
"### Infinite Epsilon Case\n",
|
||||
"\n",
|
||||
"Append this extra parameter ` --dp_attach_opt 0 ` to all the commands above. This does not attach the differential privacy engine with the optimizer. Also, change the epsilon value to the parameter ` --dp_epsilon ` to any random value as it does not matter since the privacy engine is not attached to the optimizer\n",
|
||||
"\n",
|
||||
"### Plotting Results\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## ChestXRay Dataset\n",
|
||||
"\n",
|
||||
"### Prepare Data for Chest X-Ray\n",
|
||||
"\n",
|
||||
" -Follow the steps in the Preprocess.ipynb notebook to donwload and process the Chest X-Ray datasets\n",
|
||||
" -Then follow the steps in the ChestXRay_Translate.ipynb notebook to perform image translations\n",
|
||||
"\n",
|
||||
"### Training Models\n",
|
||||
"\n",
|
||||
"Test Domain NIH\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/cxray_run.py --test_domain nih --metric train`\n",
|
||||
"\n",
|
||||
"Test Domain Chexpert\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/cxray_run.py --test_domain chex --metric train`\n",
|
||||
"\n",
|
||||
"Test Domain RSNA\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/cxray_run.py --test_domain kaggle --metric train`\n",
|
||||
"\n",
|
||||
"### Evaluating OOD Accuracy\n",
|
||||
"\n",
|
||||
"Test Domain NIH\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/cxray_run.py --test_domain nih --metric acc --data_case train`\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/cxray_run.py --test_domain nih --metric acc --data_case test`\n",
|
||||
"\n",
|
||||
"Test Domain Chexpert\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/cxray_run.py --test_domain chex --metric acc --data_case train`\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/cxray_run.py --test_domain chex --metric acc --data_case test`\n",
|
||||
"\n",
|
||||
"Test Domain RSNA\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/cxray_run.py --test_domain kaggle --metric acc --data_case train`\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/cxray_run.py --test_domain kaggle --metric acc --data_case test`\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"### Evaluating MI Attack Accuracy\n",
|
||||
"\n",
|
||||
"Test Domain NIH\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/cxray_run.py --test_domain nih --metric privacy_loss_attack`\n",
|
||||
"\n",
|
||||
"Test Domain Chexpert\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/cxray_run.py --test_domain chex --metric privacy_loss_attack`\n",
|
||||
"\n",
|
||||
"Test Domain RSNA\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/cxray_run.py --test_domain kaggle --metric privacy_loss_attack`\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"### Plotting Results\n",
|
||||
"\n",
|
||||
"`python3 reproduce_scripts/cxray_plot.py`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Attribute Attack\n",
|
||||
"\n",
|
||||
"python data/data_gen_mnist.py --dataset rot_mnist_spur --model resnet18 --img_h 224 --img_w 224 --subset_size 2000"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
|
@ -351,7 +351,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.7"
|
||||
"version": "3.7.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
File diffs are hidden because one or more lines are too long
|
@ -16,7 +16,11 @@ metric= args.metric
|
|||
test_domain= args.test_domain
|
||||
data_case= args.data_case
|
||||
|
||||
methods=['erm', 'irm', 'csd', 'rand', 'matchdg_ctr', 'matchdg_erm', 'hybrid']
|
||||
if metric == 'train':
|
||||
methods=['erm', 'irm', 'csd', 'rand', 'matchdg_ctr', 'matchdg_erm', 'hybrid']
|
||||
else:
|
||||
methods=['erm', 'irm', 'csd', 'rand', 'matchdg_erm', 'hybrid']
|
||||
|
||||
domains= ['nih', 'chex', 'kaggle']
|
||||
dataset= 'chestxray'
|
||||
|
||||
|
|
|
@ -1,17 +1,23 @@
|
|||
import os
|
||||
#rot_mnist, fashion_mnist
|
||||
dataset=sys.argv[1]
|
||||
import argparse
|
||||
|
||||
#TODO
|
||||
#1) Add another argparse arugment for deciding between perfect and non-iterative
|
||||
#2) Script for evaluating the match function metrics
|
||||
# Input Parsing
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--dataset', type=str, default='rot_mnist',
|
||||
help='Datasets: rot_mnist; fashion_mnist')
|
||||
parser.add_argument('--iterative', type=int, default=1,
|
||||
help='Iterative updates to positive matches')
|
||||
parser.add_argument('--perf_init', type=int, default=0,
|
||||
help='Positive matches Initialization to perfect matches')
|
||||
|
||||
args = parser.parse_args()
|
||||
dataset= args.dataset
|
||||
iterative= args.iterative
|
||||
perf_init= args.perf_init
|
||||
|
||||
base_script= 'python train.py --epochs 50 --batch_size 64 --dataset ' + str(dataset)
|
||||
|
||||
#Perf MDG
|
||||
script= base_script + ' --method_name matchdg_ctr --match_case 1.0 --match_flag 1 --pos_metric cos '
|
||||
script= base_script + ' --method_name matchdg_ctr --pos_metric cos --match_case ' + str(perf_init) + ' --match_flag ' + str(iterative)
|
||||
os.system(script)
|
||||
|
||||
#Non Iterative MDG
|
||||
script= base_script + ' --method_name matchdg_ctr --match_case 0.0 --match_flag 0 --pos_metric cos '
|
||||
os.system(script)
|
||||
|
||||
#2) Script for evaluating the match function metrics
|
||||
|
|
|
@ -61,7 +61,7 @@ dataset=sys.argv[1]
|
|||
test_case=['test_diff']
|
||||
|
||||
matplotlib.rcParams.update({'errorbar.capsize': 2})
|
||||
fig, ax = plt.subplots(1, 3, figsize=(33, 8))
|
||||
fig, ax = plt.subplots(1, 4, figsize=(33, 8))
|
||||
fontsize=35
|
||||
fontsize_lgd= fontsize/1.2
|
||||
x=['ERM', 'Rand', 'MatchDG', 'CSD', 'IRM', 'Perf']
|
||||
|
@ -70,7 +70,7 @@ methods=['erm', 'rand', 'matchdg', 'csd', 'irm', 'perf']
|
|||
|
||||
metrics= ['acc:train', 'acc:test', 'privacy_loss_attack', 'match_score:test']
|
||||
|
||||
for idx in range(3):
|
||||
for idx in range(4):
|
||||
|
||||
marker_list = ['o', '^', '*']
|
||||
legend_count = 0
|
||||
|
@ -154,7 +154,8 @@ for idx in range(3):
|
|||
|
||||
if idx == 0:
|
||||
ax[idx].errorbar(x, acc_test, yerr=acc_test_err, label=legend_label, marker= marker_list[legend_count], markersize= fontsize_lgd, linewidth=4, fmt='o--')
|
||||
ax[idx].set_ylabel('OOD Accuracy', fontsize=fontsize)
|
||||
ax[idx].set_ylabel('OOD Accuracy', fontsize=fontsize)
|
||||
|
||||
|
||||
if idx == 1:
|
||||
ax[idx].errorbar(x, loss, yerr=loss_err, label=legend_label, marker= marker_list[legend_count], markersize= fontsize_lgd, linewidth=4, fmt='o--')
|
||||
|
@ -165,10 +166,8 @@ for idx in range(3):
|
|||
ax[idx].set_ylabel('Mean Rank', fontsize=fontsize)
|
||||
|
||||
if idx == 3:
|
||||
ax[idx].errorbar(x, np.array(acc_train) - np.array(acc_test), yerr=acc_train_err, label=legend_label, fmt='o--')
|
||||
# ax.set_xlabel('Models', fontsize=fontsize)
|
||||
ax[idx].set_ylabel('Train-Test Accuracy Gap of ML Model', fontsize=fontsize)
|
||||
# ax[idx].legend(fontsize=fontsize_lgd)
|
||||
ax[idx].errorbar(x, np.array(acc_train) - np.array(acc_test), yerr=acc_train_err, marker= 's', markersize= fontsize_lgd, linewidth=4, fmt='o--')
|
||||
ax[idx].set_ylabel('Generalization Gap', fontsize=fontsize)
|
||||
|
||||
legend_count+= 1
|
||||
|
||||
|
|
|
@ -14,10 +14,17 @@ parser.add_argument('--data_case', type=str, default='test',
|
|||
help='train: Evaluate the acc/match_score metrics on the train dataset; test: Evaluate the acc/match_score metrics on the test dataset')
|
||||
parser.add_argument('--data_aug', type=int, default=1,
|
||||
help='0: No data augmentation for fashion mnist; 1: Data augmentation for fashion mnist')
|
||||
|
||||
parser.add_argument('--methods', nargs='+', type=str, default=['erm', 'irm', 'csd', 'rand', 'perf', 'matchdg'],
|
||||
help='List of methods: erm, irm, csd, rand, approx_25, approx_50, approx_75, perf, matchdg')
|
||||
|
||||
|
||||
parser.add_argument('--dp_noise', type=int, default=0,
|
||||
help='0: No DP noise; 1: Add DP noise')
|
||||
parser.add_argument('--dp_epsilon', type=float, default=1.0,
|
||||
parser.add_argument('--dp_epsilon', type=float, default=100.0,
|
||||
help='Epsilon value for Differential Privacy')
|
||||
parser.add_argument('--dp_attach_opt', type=int, default=1,
|
||||
help='0: Infinite Epsilon; 1: Finite Epsilion')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -27,14 +34,13 @@ train_case= args.train_case
|
|||
metric= args.metric
|
||||
data_case= args.data_case
|
||||
data_aug= args.data_aug
|
||||
methods= args.methods
|
||||
|
||||
# test_diff, test_common
|
||||
test_case=['test_diff']
|
||||
|
||||
# List of methods to train/evaluate
|
||||
# methods=['erm', 'irm', 'csd', 'rand', 'approx_25', 'approx_50', 'approx_75', 'perf', 'matchdg']
|
||||
# methods=['erm', 'irm', 'csd', 'rand', 'perf', 'matchdg']
|
||||
methods=['erm', 'perf']
|
||||
# methods=[]
|
||||
|
||||
if metric == 'train':
|
||||
if dataset in ['rot_mnist', 'rot_mnist_spur']:
|
||||
|
@ -46,7 +52,6 @@ if metric == 'train':
|
|||
elif metric == 'mia':
|
||||
if dataset in ['rot_mnist', 'rot_mnist_spur']:
|
||||
base_script= 'python test.py --test_metric mia --mia_logit 1 --mia_sample_size 2000 --batch_size 64 ' + ' --dataset ' + str(dataset)
|
||||
elif dataset in ['fashion_mnist']:
|
||||
base_script= 'python test.py --test_metric mia --mia_logit 1 --mia_sample_size 2000 --batch_size 64 ' + ' --dataset ' + str(dataset)
|
||||
|
||||
res_dir= 'results/'+str(dataset)+'/privacy_clf/'
|
||||
|
@ -107,7 +112,7 @@ if test_case == 'test_common':
|
|||
|
||||
#Differential Privacy
|
||||
if args.dp_noise:
|
||||
base_script += ' --dp_noise ' + str(args.dp_noise) + ' --dp_epsilon ' + str(args.dp_epsilon) + ' '
|
||||
base_script += ' --dp_noise ' + str(args.dp_noise) + ' --dp_epsilon ' + str(args.dp_epsilon) + ' --dp_attach_opt ' + str(args.dp_attach_opt) + ' '
|
||||
res_dir= res_dir[:-1] + '_epsilon_' + str(args.dp_epsilon) + '/'
|
||||
|
||||
if not os.path.exists(res_dir):
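# Illustrative: the resulting directory name when DP noise is enabled
# (hypothetical starting value for res_dir; mirrors the rename above).
dp_epsilon = 1.0
res_dir = 'results/rot_mnist/train_logs/'
res_dir = res_dir[:-1] + '_epsilon_' + str(dp_epsilon) + '/'
print(res_dir)    # results/rot_mnist/train_logs_epsilon_1.0/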
|
||||
|
|
|
@ -1,33 +1,40 @@
|
|||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
# method: rand, matchdg, perf
|
||||
method= sys.argv[1]
|
||||
# Input Parsing
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--methods', nargs='+', type=str, default=['rand', 'matchdg_ctr', 'matchdg_erm'],
|
||||
help='List of methods')
|
||||
|
||||
args = parser.parse_args()
|
||||
methods= args.methods
|
||||
domains= [0, 15, 30, 45, 60, 75]
|
||||
|
||||
if method == 'rand':
|
||||
base_script= 'python train.py --dataset rot_mnist --mnist_case domain_bed --method_name erm_match --perfect_match 0 --match_case 0.0 --penalty_ws 1.0 --epochs 25 --batch_size 128 --model_name domain_bed_mnist --img_h 28 --img_w 28 --total_matches_per_point 1000 '
|
||||
|
||||
elif method == 'matchdg_ctr':
|
||||
base_script= 'python train.py --dataset rot_mnist --mnist_case domain_bed --method_name matchdg_ctr --perfect_match 0 --match_case 0.0 --match_flag 1 --epochs 50 --batch_size 512 --pos_metric cos --model_name domain_bed_mnist --img_h 28 --img_w 28 --match_func_aug_case 1 '
|
||||
for method in methods:
|
||||
|
||||
elif method == 'matchdg_erm':
|
||||
base_script= 'python train.py --dataset rot_mnist --mnist_case domain_bed --method_name matchdg_erm --perfect_match 0 --match_case -1 --penalty_ws 1.0 --epochs 25 --batch_size 128 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name domain_bed_mnist --model_name domain_bed_mnist --img_h 28 --img_w 28 --total_matches_per_point 1000 '
|
||||
if method == 'rand':
|
||||
base_script= 'python train.py --dataset rot_mnist --mnist_case domain_bed --method_name erm_match --perfect_match 0 --match_case 0.0 --penalty_ws 1.0 --epochs 25 --batch_size 128 --model_name domain_bed_mnist --img_h 28 --img_w 28 --total_matches_per_point 1000 '
|
||||
|
||||
for test_domain in domains:
|
||||
|
||||
train_domains=''
|
||||
for d in domains:
|
||||
if d != test_domain:
|
||||
train_domains+= str(d) + ' '
|
||||
print(train_domains)
|
||||
elif method == 'matchdg_ctr':
|
||||
base_script= 'python train.py --dataset rot_mnist --mnist_case domain_bed --method_name matchdg_ctr --perfect_match 0 --match_case 0.0 --match_flag 1 --epochs 50 --batch_size 512 --pos_metric cos --model_name domain_bed_mnist --img_h 28 --img_w 28 --match_func_aug_case 1 '
|
||||
|
||||
res_dir= 'results/rmnist_domain_bed/'
|
||||
if not os.path.exists(res_dir):
|
||||
os.makedirs(res_dir)
|
||||
|
||||
script= base_script + ' --train_domains ' + str(train_domains) + ' --test_domains ' + str(test_domain)
|
||||
script= script + ' > ' + res_dir + method + '_' + str(test_domain) + '.txt'
|
||||
|
||||
print('Method: ', method, ' Test Domain: ', test_domain)
|
||||
os.system(script)
|
||||
elif method == 'matchdg_erm':
|
||||
base_script= 'python train.py --dataset rot_mnist --mnist_case domain_bed --method_name matchdg_erm --perfect_match 0 --match_case -1 --penalty_ws 1.0 --epochs 25 --batch_size 128 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name domain_bed_mnist --model_name domain_bed_mnist --img_h 28 --img_w 28 --total_matches_per_point 1000 '
|
||||
|
||||
for test_domain in domains:
|
||||
|
||||
train_domains=''
|
||||
for d in domains:
|
||||
if d != test_domain:
|
||||
train_domains+= str(d) + ' '
|
||||
print(train_domains)
|
||||
|
||||
res_dir= 'results/rmnist_domain_bed/'
|
||||
if not os.path.exists(res_dir):
|
||||
os.makedirs(res_dir)
|
||||
|
||||
script= base_script + ' --train_domains ' + str(train_domains) + ' --test_domains ' + str(test_domain)
|
||||
script= script + ' > ' + res_dir + method + '_' + str(test_domain) + '.txt'
|
||||
|
||||
print('Method: ', method, ' Test Domain: ', test_domain)
|
||||
os.system(script)
|
|
@ -1,40 +1,46 @@
|
|||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
# method: rand, matchdg, perf
|
||||
method= sys.argv[1]
|
||||
# Input Parsing
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--methods', nargs='+', type=str, default=['erm', 'rand', 'matchdg_ctr', 'matchdg_erm', 'perf'],
|
||||
help='List of methods')
|
||||
|
||||
args = parser.parse_args()
|
||||
methods= args.methods
|
||||
domains= [0, 15, 30, 45, 60, 75]
|
||||
|
||||
for method in methods:
|
||||
|
||||
if method == 'perf':
|
||||
base_script= 'python train.py --dataset rot_mnist --mnist_case lenet --method_name erm_match --match_case 1.0 --penalty_ws 1.0 --epochs 100 --model_name lenet --img_h 32 --img_w 32 '
|
||||
|
||||
if method == 'perf':
|
||||
base_script= 'python train.py --dataset rot_mnist --mnist_case lenet --method_name erm_match --match_case 1.0 --penalty_ws 1.0 --epochs 100 --model_name lenet --img_h 32 --img_w 32 '
|
||||
elif method == 'erm':
|
||||
base_script= 'python train.py --dataset rot_mnist --mnist_case lenet --method_name erm_match --match_case 0.0 --penalty_ws 0.0 --epochs 100 --model_name lenet --img_h 32 --img_w 32 '
|
||||
|
||||
elif method == 'erm':
|
||||
base_script= 'python train.py --dataset rot_mnist --mnist_case lenet --method_name erm_match --match_case 0.0 --penalty_ws 0.0 --epochs 100 --model_name lenet --img_h 32 --img_w 32 '
|
||||
|
||||
elif method == 'rand':
|
||||
base_script= 'python train.py --dataset rot_mnist --mnist_case lenet --method_name erm_match --match_case 0.0 --penalty_ws 1.0 --epochs 100 --model_name lenet --img_h 32 --img_w 32 --total_matches_per_point 100 '
|
||||
elif method == 'rand':
|
||||
base_script= 'python train.py --dataset rot_mnist --mnist_case lenet --method_name erm_match --match_case 0.0 --penalty_ws 1.0 --epochs 100 --model_name lenet --img_h 32 --img_w 32 --total_matches_per_point 100 '
|
||||
|
||||
elif method == 'matchdg_ctr':
|
||||
base_script= 'python train.py --dataset rot_mnist --mnist_case lenet --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --epochs 50 --batch_size 512 --pos_metric cos --model_name lenet --img_h 32 --img_w 32 --match_func_aug_case 1 '
|
||||
|
||||
elif method == 'matchdg_erm':
|
||||
base_script= 'python train.py --dataset rot_mnist --mnist_case lenet --method_name matchdg_erm --match_case -1 --penalty_ws 1.0 --epochs 100 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name lenet --model_name lenet --img_h 32 --img_w 32 --total_matches_per_point 100 '
|
||||
elif method == 'matchdg_ctr':
|
||||
base_script= 'python train.py --dataset rot_mnist --mnist_case lenet --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --epochs 50 --batch_size 512 --pos_metric cos --model_name lenet --img_h 32 --img_w 32 --match_func_aug_case 1 '
|
||||
|
||||
for test_domain in domains:
|
||||
|
||||
train_domains=''
|
||||
for d in domains:
|
||||
if d != test_domain:
|
||||
train_domains+= str(d) + ' '
|
||||
print(train_domains)
|
||||
elif method == 'matchdg_erm':
|
||||
base_script= 'python train.py --dataset rot_mnist --mnist_case lenet --method_name matchdg_erm --match_case -1 --penalty_ws 1.0 --epochs 100 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name lenet --model_name lenet --img_h 32 --img_w 32 --total_matches_per_point 100 '
|
||||
|
||||
res_dir= 'results/rmnist_lenet/'
|
||||
if not os.path.exists(res_dir):
|
||||
os.makedirs(res_dir)
|
||||
|
||||
script= base_script + ' --train_domains ' + str(train_domains) + ' --test_domains ' + str(test_domain)
|
||||
script= script + ' > ' + res_dir + method + '_' + str(test_domain) + '.txt'
|
||||
|
||||
print('Method: ', method, ' Test Domain: ', test_domain)
|
||||
os.system(script)
|
||||
for test_domain in domains:
|
||||
|
||||
train_domains=''
|
||||
for d in domains:
|
||||
if d != test_domain:
|
||||
train_domains+= str(d) + ' '
|
||||
print(train_domains)
|
||||
|
||||
res_dir= 'results/rmnist_lenet/'
|
||||
if not os.path.exists(res_dir):
|
||||
os.makedirs(res_dir)
|
||||
|
||||
script= base_script + ' --train_domains ' + str(train_domains) + ' --test_domains ' + str(test_domain)
|
||||
script= script + ' > ' + res_dir + method + '_' + str(test_domain) + '.txt'
|
||||
|
||||
print('Method: ', method, ' Test Domain: ', test_domain)
|
||||
os.system(script)
|
|
@ -2,13 +2,9 @@ import os
|
|||
import sys
|
||||
|
||||
'''
|
||||
argv1: Allowed Values (train, test)
|
||||
argv1: Allowed Values (train, evaluate)
|
||||
'''
|
||||
|
||||
## TODO
|
||||
# Choose a better name for train, test as in the test stage we do not really evaluate the test accuracy
|
||||
# Maybe make the train set accuracy also as part of the logs in the train.py phase (Check once if it changes any results from the main paper)
|
||||
|
||||
case= sys.argv[1]
|
||||
methods=['erm', 'mmd', 'coral', 'dann', 'c-mmd', 'c-coral', 'c-dann', 'rand', 'perf']
|
||||
total_seed= 10
|
||||
|
@ -16,7 +12,7 @@ total_seed= 10
|
|||
if case == 'train':
|
||||
base_script= 'python train.py --dataset slab --model_name slab --batch_size 128 --lr 0.1 --epochs 100 --out_classes 2 --train_domains 0.0 0.10 --test_domains 1.0 --slab_data_dim 2 --slab_noise 0.1 ' + ' --n_runs ' + str(total_seed)
|
||||
|
||||
elif case == 'test':
|
||||
elif case == 'evaluate':
|
||||
base_script= 'python test.py --test_metric per_domain_acc --acc_data_case train --dataset slab --model_name slab --batch_size 128 --lr 0.1 --epochs 100 --out_classes 2 --train_domains 0.0 0.10 --test_domains 1.0 --slab_data_dim 2 --slab_noise 0.1 ' + ' --n_runs ' + str(total_seed)
|
||||
|
||||
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import sys
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
slab_noise= float(sys.argv[1])
|
||||
base_dir= 'slab_res/slab_noise_' + str(slab_noise) + '/'
|
||||
# methods=['erm', 'irm', 'csd', 'rand', 'matchdg', 'perf', 'mask_linear']
|
||||
base_dir= 'results/slab/slab_noise_' + str(slab_noise) + '/'
|
||||
methods=['erm', 'irm', 'csd', 'rand', 'perf', 'mask_linear']
|
||||
|
||||
# x=['ERM', 'IRM', 'CSD', 'Rand', 'MatchDG', 'Perf', 'Mask']
|
||||
x=['ERM', 'IRM', 'CSD', 'Rand', 'Perf', 'Mask']
|
||||
x=['ERM', 'IRM', 'CSD', 'Rand', 'Perf', 'Oracle']
|
||||
matplotlib.rcParams.update({'errorbar.capsize': 2})
|
||||
fig, ax = plt.subplots(1, 2, figsize=(24, 8))
|
||||
fontsize=40
|
||||
|
@@ -16,10 +17,12 @@ fontsize_lgd= fontsize/1.2
marker_list = ['o', '^', '*']
count= 0

for test_domain in [0.3, 0.9]:
for test_domain in [0.2, 0.9]:

acc=[]
acc_err=[]
train_acc =[]
train_acc_err =[]
auc=[]
auc_err=[]
s_auc=[]
@@ -40,6 +43,11 @@ for test_domain in [0.3, 0.9]:
sc_auc.append( float(data[-1].replace('\n', '').split(' ')[-2] ))
sc_auc_err.append( float( data[-1].replace('\n', '').split(' ')[-1] ) )

f= open(base_dir + method + '-train-auc-' + str(test_domain) + '.txt')
data= f.readlines()
train_acc.append( float( data[-4].replace('\n', '').split(' ')[-2] ))
train_acc_err.append( float( data[-4].replace('\n', '').split(' ')[-1] ))

#Privacy Metrics
mia=[]
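A hedged sketch of the log-parsing convention used above, where the mean and its error are assumed to be the last two space-separated tokens of a line counted from the end of the file; parse_mean_err and the example path are illustrative only:

def parse_mean_err(path, line_from_end):
    # e.g. line_from_end=-4 mirrors the data[-4] access in the snippet above.
    with open(path) as f:
        line = f.readlines()[line_from_end].strip()
    tokens = line.split(' ')
    return float(tokens[-2]), float(tokens[-1])

# Hypothetical usage with the file-naming convention of this script:
# mean, err = parse_mean_err('results/slab/slab_noise_0.1/erm-train-auc-0.2.txt', -4)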
@@ -81,6 +89,10 @@ for test_domain in [0.3, 0.9]:
ax[count].errorbar(x, acc, yerr=acc_err, marker= marker_list[0], markersize= fontsize_lgd, linewidth=4, fmt='o--', label='OOD Acc')
ax[count].errorbar(x, s_auc, yerr=s_auc_err, marker= marker_list[1], markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Linear-RAUC')
ax[count].errorbar(x, loss, yerr=loss_err, marker= marker_list[2], markersize= fontsize_lgd, linewidth=4, label='Loss Attack', fmt='o--')

gen_gap= np.array(train_acc) - np.array(acc)
ax[count].errorbar(x, gen_gap, yerr=0*gen_gap, marker= 's', markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Gen Gap')

ax[count].set_ylabel('Metric Score', fontsize=fontsize)
ax[count].set_title('Test Domain: ' + str(test_domain), fontsize=fontsize)
@@ -88,7 +100,11 @@ for test_domain in [0.3, 0.9]:

lines, labels = fig.axes[-1].get_legend_handles_labels()
lgd= fig.legend(lines, labels, loc="lower center", bbox_to_anchor=(0.5, -0.15), fontsize=fontsize, ncol=3)

plt.tight_layout()
plt.savefig('results/privacy_slab_' + str(slab_noise) + '.pdf', bbox_extra_artists=(lgd,), bbox_inches='tight', dpi=600)

save_dir= 'results/slab/plots/'
if not os.path.exists(save_dir):
os.makedirs(save_dir)

plt.tight_layout()
plt.savefig( save_dir + 'privacy_slab_' + str(slab_noise) + '.pdf', bbox_extra_artists=(lgd,), bbox_inches='tight', dpi=600)
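A small note on the directory check above: an equivalent, race-free variant uses the exist_ok flag (sketch, not part of the commit):

import os
os.makedirs('results/slab/plots/', exist_ok=True)  # no separate os.path.exists() check needed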
@@ -1,21 +1,25 @@
import os
import sys
import argparse

'''
argv1: Case (train, test)
# Input Parsing
parser = argparse.ArgumentParser()
parser.add_argument('--case', type=str, default='train',
help='Case: train; test')
parser.add_argument('--slab_noise', type=float, default=0.0,
help='Probability of corrupting slab features')
parser.add_argument('--methods', nargs='+', type=str, default=['erm', 'irm', 'csd', 'rand', 'matchdg', 'perf', 'mask_linear'],
help='List of methods')

argv2: Noise in the slab feature (Flip probability)
'''
args = parser.parse_args()

case= sys.argv[1]
slab_noise= float(sys.argv[2])
case= args.case
slab_noise= args.slab_noise
methods= args.methods

metrics= ['train-auc', 'auc', 'loss']
total_seed= 3

methods=['erm', 'irm', 'csd', 'rand', 'perf', 'mask_linear']
# methods=['matchdg']
# metrics= ['auc', 'mi', 'entropy', 'loss']
metrics= ['auc', 'loss']

if case == 'train':

base_script= 'python train.py --dataset slab --model_name slab --batch_size 128 --lr 0.1 --epochs 100 --out_classes 2 --train_domains 0.0 0.10 --test_domains 0.90 --slab_data_dim 2 '
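A self-contained sketch of the argparse interface introduced above, parsing an explicit argument list so it runs without command-line input; the example values are illustrative:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--case', type=str, default='train', help='Case: train; test')
parser.add_argument('--slab_noise', type=float, default=0.0, help='Probability of corrupting slab features')
parser.add_argument('--methods', nargs='+', type=str, default=['erm', 'irm'], help='List of methods')

# Parse an explicit list here so the sketch runs without a shell invocation.
args = parser.parse_args(['--case', 'train', '--slab_noise', '0.1', '--methods', 'erm', 'perf'])
print(args.case, args.slab_noise, args.methods)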
@@ -40,6 +44,7 @@ if case == 'train':
#CTR Phase
script = base_script + ' --method_name matchdg_ctr --batch_size 512 --match_case 0.0 --match_flag 1 --match_interrupt 5 --pos_metric cos '
os.system(script)

#ERM Phase
script = base_script + ' --method_name matchdg_erm --match_case -1 --penalty_ws 1.0 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name slab '
@@ -51,7 +56,9 @@ elif case == 'test':

if metric == 'auc':
base_script= 'python test_slab.py --train_domains 0.0 0.10 '

if metric == 'train-auc':
base_script= 'python test_slab.py --train_domains 0.0 0.10 --acc_data_case train '
elif metric == 'mi':
base_script= 'python test.py --test_metric mia --mia_logit 1 --mia_sample_size 400 --dataset slab --model_name slab --out_classes 2 --train_domains 0.0 0.10 '
elif metric == 'entropy':
@@ -62,7 +69,7 @@ elif case == 'test':
base_script= 'python test.py --test_metric attribute_attack --mia_logit 1 --attribute_domain 0 --dataset slab --model_name slab --out_classes 2 --train_domains 0.0 0.10 '

base_script= base_script + ' --slab_noise ' + str(slab_noise) + ' --n_runs ' + str(total_seed)
res_dir= 'slab_res/slab_noise_' + str(slab_noise) + '/'
res_dir= 'results/slab/slab_noise_' + str(slab_noise) + '/'

if not os.path.exists(res_dir):
os.makedirs(res_dir)
@@ -85,36 +92,6 @@ elif case == 'test':
elif method == 'matchdg':
upd_script = base_script + ' --method_name matchdg_erm --match_case -1 --penalty_ws 1.0 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name slab '

# for test_domain in [0.05, 0.15, 0.3, 0.5, 0.7, 0.9]:
for test_domain in [0.3, 0.9]:
for test_domain in [0.2, 0.9]:
script= upd_script + ' --test_domains ' + str(test_domain) + ' > ' + res_dir + str(method) + '-' + str(metric) + '-' + str(test_domain) + '.txt'
os.system(script)

# elif case == 'train_plot':

# for metric in metrics:

# if metric == 'auc':
# base_script= 'python logit_plot_slab.py --train_domains 0.0 0.10 '

# for method in methods:

# if method == 'erm':
# upd_script= base_script + ' --method_name perf_match --penalty_ws 0.0 '
# elif method == 'irm':
# upd_script= base_script + ' --method_name irm_slab --penalty_irm 10.0 --penalty_s 2 '
# elif method == 'csd':
# upd_script= base_script + ' --method_name csd_slab --penalty_ws 0.0 --rep_dim 100 '
# elif method == 'rand':
# upd_script= base_script + ' --method_name rand_match --penalty_ws 1.0 '
# elif method == 'perf':
# upd_script= base_script + ' --method_name perf_match --penalty_ws 1.0 '
# elif method == 'mask_linear':
# upd_script= base_script + ' --method_name mask_linear --penalty_ws 0.0 '

# for test_domain in [0.05, 0.15, 0.3, 0.5, 0.7, 0.9]:
# script= upd_script + ' --test_domains ' + str(test_domain) + ' > slab_temp/' + str(method) + '-' + str(metric) + '-' + str(test_domain) + '.txt'
# os.system(script)
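A hedged sketch of an alternative to os.system plus shell redirection for the command construction above; subprocess.run and the example paths are illustrative and assume the results directory already exists:

import subprocess

cmd = ['python', 'test_slab.py', '--train_domains', '0.0', '0.10', '--test_domains', '0.2']
with open('results/slab/slab_noise_0.1/erm-auc-0.2.txt', 'w') as out_file:
    subprocess.run(cmd, stdout=out_file, check=True)  # captures stdout like the '>' redirection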
test.py
@@ -119,6 +119,9 @@ parser.add_argument('--dp_noise', type=int, default=0,
help='0: No DP noise; 1: Add DP noise')
parser.add_argument('--dp_epsilon', type=float, default=1.0,
help='Epsilon value for Differential Privacy')
parser.add_argument('--dp_attach_opt', type=int, default=1,
help='0: Infinite Epsilon; 1: Finite Epsilon')

#MMD, DANN
@@ -276,7 +276,10 @@ for run in range(args.n_runs):
model= test_method.phi

test_method.get_metric_eval()
std_acc= test_method.metric_score['test accuracy']
if args.acc_data_case == 'train':
std_acc= test_method.metric_score['train accuracy']
elif args.acc_data_case == 'test':
std_acc= test_method.metric_score['test accuracy']
print('Test Accuracy: ', std_acc)

spur_prob= float(test_domains[0])
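A compact sketch of the accuracy-key selection added in this hunk; the dictionary keys follow the metric_score names used above, the values are made up:

def select_accuracy(metric_score, acc_data_case):
    # Pick the reported accuracy based on which split was evaluated.
    key = 'train accuracy' if acc_data_case == 'train' else 'test accuracy'
    return metric_score[key]

print(select_accuracy({'train accuracy': 0.98, 'test accuracy': 0.62}, 'train'))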
train.py
@@ -119,6 +119,9 @@ parser.add_argument('--dp_noise', type=int, default=0,
help='0: No DP noise; 1: Add DP noise')
parser.add_argument('--dp_epsilon', type=float, default=1.0,
help='Epsilon value for Differential Privacy')
# Special case when you want to check results with the DP setting for the infinite epsilon case
parser.add_argument('--dp_attach_opt', type=int, default=1,
help='0: Infinite Epsilon; 1: Finite Epsilon')

#MMD, DANN
@@ -286,18 +286,13 @@ def get_dataloader(args, run, domains, data_case, eval_case, kwargs):
data_obj= ChestXRay(args, domains, '/chestxray_spur/', data_case=data_case, match_func=match_func)

elif args.dataset_name in ['rot_mnist', 'fashion_mnist', 'rot_mnist_spur']:
if data_case == 'test' and args.mnist_case not in ['lenet', 'lenet_mdg']:
if data_case == 'test' and args.mnist_case not in ['lenet']:
#TODO: Infer this based on the total number of seed values for the mnist case
# By default, seeds 0, 1, 2 are used for training and seed 9 for testing; document this properly in the comments
mnist_subset= 9
else:
mnist_subset= run

#TODO: Only temporary, in order to see if it changes results on MNIST
# if eval_case:
# if args.test_metric in ['mia', 'privacy_entropy', 'privacy_loss_attack']:
# mnist_subset=run

print('MNIST Subset: ', mnist_subset)
data_obj= MnistRotated(args, domains, mnist_subset, '/mnist/', data_case=data_case, match_func=match_func)
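A hedged sketch of what the TODO above might look like once the held-out subset is inferred rather than hard-coded; infer_mnist_subset, total_train_seeds, and test_seed are hypothetical names following the seed 0, 1, 2 train / seed 9 test convention mentioned in the comment:

def infer_mnist_subset(data_case, run, mnist_case, total_train_seeds=3, test_seed=9):
    # Seeds 0..total_train_seeds-1 index the training subsets; test_seed indexes the held-out subset.
    if data_case == 'test' and mnist_case not in ['lenet']:
        return test_seed
    return run

print(infer_mnist_subset('test', run=1, mnist_case='resnet18'))   # -> 9
print(infer_mnist_subset('train', run=1, mnist_case='resnet18'))  # -> 1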
@@ -229,7 +229,9 @@ def get_matched_pairs(args, cuda, train_dataset, domain_size, total_domains, tra
if perfect_match == 1:
## Find all instances among curr_domain with the same object as obj_base[idx]
## .nonzero() converts True matches to match indices; [0, 0] takes the first match of the same base object in the current domain
perfect_match_rank.append( (obj_curr[sort_idx] == obj_base[idx]).nonzero()[0,0].item() )

if obj_base[idx] in obj_curr[sort_idx]:
perfect_match_rank.append( (obj_curr[sort_idx] == obj_base[idx]).nonzero()[0,0].item() )
# print('Time Taken in CTR Loop: ', time.time()-start_time)

elif inferred_match == 0 and perfect_match == 1:

@@ -242,6 +244,7 @@ def get_matched_pairs(args, cuda, train_dataset, domain_size, total_domains, tra

# Select random matches with perm_prob probability
if rand_vars[idx]:

rand_indices = np.arange(ordered_curr_indices.size()[0])
np.random.shuffle(rand_indices)
curr_indices= ordered_curr_indices[rand_indices][:total_matches_per_point]
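A minimal torch sketch (illustrative values, not taken from the repository) of the guarded rank computation in the first hunk above: the candidates are assumed to be already sorted by match score, and the membership check avoids an indexing error when the base object has no counterpart in the current domain:

import torch

obj_base = 7                                     # object id of the base instance
obj_curr_sorted = torch.tensor([3, 7, 7, 1, 5])  # candidate object ids, sorted by match score

if obj_base in obj_curr_sorted:
    # Position of the first occurrence of the true match in the sorted candidates.
    rank = (obj_curr_sorted == obj_base).nonzero()[0, 0].item()
    print('Perfect match rank:', rank)  # -> 1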