Add reproducibility details; update MatchDG based on slab dataset issues and noise in the invariant mechanism

divyat09 2021-09-29 17:06:27 +00:00
Parent 0b8bb0597b
Commit 29136f47f6
26 changed files with 1527 additions and 351 deletions

View File

@ -77,12 +77,6 @@ def get_noise_multiplier(
class BaseAlgo():
def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, run, cuda):
# from evaluation.base_eval import BaseEval
# self.test_method= BaseEval(
# args, train_dataset, val_dataset,
# test_dataset, base_res_dir,
# run, cuda
# )
self.args= args
self.train_dataset= train_dataset['data_loader']
@ -111,7 +105,7 @@ class BaseAlgo():
self.val_acc=[]
self.train_acc=[]
# if self.args.method_name == 'dp_erm':
# Differentially Private Noise
if self.args.dp_noise:
self.privacy_engine= self.get_dp_noise()
@ -297,28 +291,16 @@ class BaseAlgo():
MAX_GRAD_NORM = 5.0
DELTA = 1.0/(self.total_domains*self.domain_size)
BATCH_SIZE = self.args.batch_size * self.total_domains
VIRTUAL_BATCH_SIZE = 10*BATCH_SIZE
assert VIRTUAL_BATCH_SIZE % BATCH_SIZE == 0 # VIRTUAL_BATCH_SIZE should be divisible by BATCH_SIZE
N_ACCUMULATION_STEPS = int(VIRTUAL_BATCH_SIZE / BATCH_SIZE)
SAMPLE_RATE = BATCH_SIZE /(self.total_domains*self.domain_size)
DEFAULT_ALPHAS = [1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64))
NOISE_MULTIPLIER = get_noise_multiplier(self.args.dp_epsilon, DELTA, SAMPLE_RATE, self.args.epochs, DEFAULT_ALPHAS)
print("Target Epsilon: ", self.args.dp_epsilon)
print(f"Using sigma={NOISE_MULTIPLIER} and C={MAX_GRAD_NORM}")
# sys.exit(-1)
from opacus import PrivacyEngine
# privacy_engine = PrivacyEngine(
# self.phi,
# sample_rate=SAMPLE_RATE * N_ACCUMULATION_STEPS,
# alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
# noise_multiplier=NOISE_MULTIPLIER,
# max_grad_norm=MAX_GRAD_NORM,
# )
privacy_engine = PrivacyEngine(
self.phi,
batch_size= BATCH_SIZE,
@ -326,6 +308,11 @@ class BaseAlgo():
noise_multiplier=NOISE_MULTIPLIER,
max_grad_norm=MAX_GRAD_NORM,
)
privacy_engine.attach(self.opt)
if self.args.dp_attach_opt:
print('Standard DP Training with finite epsilon')
privacy_engine.attach(self.opt)
else:
print('DP Training with infinite epsilon')
return privacy_engine
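For reference, a minimal sketch of the parameter arithmetic in `get_dp_noise` above, with purely illustrative dataset and batch sizes (the real code derives these from `args` and the loaded training domains):

```python
# Illustrative values only; the actual code computes these from args and the train domains.
total_domains, domain_size = 5, 2000                      # N = 10,000 training points
batch_size_per_domain = 16                                # args.batch_size

N = total_domains * domain_size
DELTA = 1.0 / N                                           # 1e-4: delta target set to 1/N
BATCH_SIZE = batch_size_per_domain * total_domains        # 80: one physical batch pools all train domains
VIRTUAL_BATCH_SIZE = 10 * BATCH_SIZE                      # 800: logical batch built by gradient accumulation
assert VIRTUAL_BATCH_SIZE % BATCH_SIZE == 0
N_ACCUMULATION_STEPS = VIRTUAL_BATCH_SIZE // BATCH_SIZE   # 10 virtual steps per real optimizer step
SAMPLE_RATE = BATCH_SIZE / N                              # 0.008: subsampling rate fed to the RDP accountant

print(DELTA, N_ACCUMULATION_STEPS, SAMPLE_RATE)           # 0.0001 10 0.008
```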

View File

@ -110,15 +110,16 @@ class ErmMatch(BaseAlgo):
loss_e.backward(retain_graph=False)
if batch_idx % 10 == 9:
if self.args.dp_noise and self.args.dp_attach_opt:
if batch_idx % 10 == 9:
self.opt.step()
self.opt.zero_grad()
else:
self.opt.virtual_step()
else:
self.opt.step()
self.opt.zero_grad()
else:
self.opt.virtual_step()
# self.opt.step()
# self.opt.zero_grad()
#Gradient Norm Computation
# batch_grad_norm=0.0
@ -156,10 +157,6 @@ class ErmMatch(BaseAlgo):
self.max_epoch= epoch
self.save_model()
# Sanity check on the test accuracy
# self.test_method.get_model()
# self.test_method.get_metric_eval()
# print( ' Sanity Check Test Accuracy: ', self.test_method.metric_score )
print('Current Best Epoch: ', self.max_epoch, ' with Test Accuracy: ', self.final_acc[self.max_epoch])
if epoch > 0 and self.args.model_name in ['domain_bed_mnist', 'lenet']:
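Since the hunk above loses indentation, here is a self-contained sketch of the accumulation pattern it appears to implement, assuming the opacus 0.x API that this commit's `PrivacyEngine` / `attach` / `virtual_step` calls correspond to; the model, data, and hyperparameters are placeholders:

```python
import torch
import torch.nn as nn
from opacus import PrivacyEngine                           # opacus 0.x API assumed

model = nn.Linear(10, 2)                                   # placeholder for self.phi
opt = torch.optim.SGD(model.parameters(), lr=0.1)
engine = PrivacyEngine(
    model,
    sample_rate=0.008 * 10,                                # SAMPLE_RATE * N_ACCUMULATION_STEPS (illustrative)
    alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
    noise_multiplier=1.0,                                  # NOISE_MULTIPLIER (illustrative)
    max_grad_norm=5.0,                                     # MAX_GRAD_NORM
)
engine.attach(opt)                                         # corresponds to args.dp_attach_opt == 1

for batch_idx in range(100):
    x, y = torch.randn(80, 10), torch.randint(0, 2, (80,)) # placeholder batch
    loss_e = nn.functional.cross_entropy(model(x), y)
    loss_e.backward(retain_graph=False)
    if batch_idx % 10 == 9:
        opt.step()                                         # real, noised optimizer step every 10th batch
        opt.zero_grad()
    else:
        opt.virtual_step()                                 # accumulate clipped per-sample gradients
```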

View File

@ -178,6 +178,9 @@ class MatchDG(BaseAlgo):
# print('Weird! Positive Matches are more than the negative matches?', pos_feat_match.shape[0], neg_feat_match.shape[0])
# If no instances of label y_c in the current batch then continue
print(pos_feat_match.shape[0], neg_feat_match.shape[0], y_c)
if pos_feat_match.shape[0] ==0 or neg_feat_match.shape[0] == 0:
continue
@ -229,6 +232,9 @@ class MatchDG(BaseAlgo):
loss_e += ( ( epoch- self.args.penalty_s )/(self.args.epochs -self.args.penalty_s) )*diff_hinge_loss
if not loss_e.requires_grad:
continue
loss_e.backward(retain_graph=False)
self.opt.step()
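The `loss_e.requires_grad` guard added above skips batches where no match term produced a differentiable loss (for example, when every class in the batch lacked positive or negative matches); calling `.backward()` on such a tensor raises an error, as this small standalone check shows:

```python
import torch

loss_e = torch.tensor(0.0)        # e.g. a batch where no matches contributed to the loss
print(loss_e.requires_grad)       # False: nothing is attached to the autograd graph
try:
    loss_e.backward()
except RuntimeError as err:
    print('Skipping batch:', err) # "element 0 of tensors does not require grad ..."
```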

231
data/data_gen_mnist.py Normal file
View File

@ -0,0 +1,231 @@
#Common imports
import numpy as np
import sys
import os
import argparse
import random
import copy
import os
#Sklearn
from scipy.stats import bernoulli
#Pillow
from PIL import Image, ImageColor, ImageOps
#Pytorch
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms
def generate_rotated_domain_data(imgs, labels, data_case, dataset, indices, domain, save_dir, img_w, img_h):
# Get total number of labeled examples
mnist_labels = labels[indices]
mnist_imgs = imgs[indices]
mnist_size = mnist_labels.shape[0]
to_pil= transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((img_w, img_h))
])
to_augment= transforms.Compose([
transforms.RandomResizedCrop(img_w, scale=(0.7,1.0)),
transforms.RandomHorizontalFlip()
])
to_tensor= transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
if dataset == 'rot_mnist_spur':
color_list=['red', 'blue', 'green', 'orange', 'yellow', 'brown', 'pink', 'magenta', 'olive', 'cyan']
# Adding color with 70 percent probability
rand_var= bernoulli.rvs(0.7, size=mnist_size)
# Run transforms
if dataset == 'rot_mnist_spur':
mnist_img_rot= torch.zeros((mnist_size, 3, img_w, img_h))
mnist_img_rot_org= torch.zeros((mnist_size, 3, img_w, img_h))
else:
mnist_img_rot= torch.zeros((mnist_size, img_w, img_h))
mnist_img_rot_org= torch.zeros((mnist_size, img_w, img_h))
mnist_idx=[]
for i in range(len(mnist_imgs)):
curr_image= to_pil(mnist_imgs[i])
#Color the image
if dataset == 'rot_mnist_spur':
if rand_var[i]:
# Change colors per label for test domains relative to the train domains
if data_case == 'test':
curr_image = ImageOps.colorize(curr_image, black ="black", white =color_list[mnist_labels[i].item()])
# Choose this for test domain with permuted colors
# curr_image = ImageOps.colorize(curr_image, black ="black", white =color_list[(mnist_labels[i].item()+1)%10] )
else:
curr_image = ImageOps.colorize(curr_image, black ="black", white =color_list[mnist_labels[i].item()])
else:
curr_image = ImageOps.colorize(curr_image, black ="black", white ="white")
#Rotation
if domain == '0':
img_rotated= curr_image
else:
img_rotated= transforms.functional.rotate( curr_image, int(domain) )
mnist_img_rot_org[i]= to_tensor(img_rotated)
#Augmentation
mnist_img_rot[i]= to_tensor(to_augment(img_rotated))
if data_case == 'train' or data_case == 'val':
torch.save(mnist_img_rot, save_dir + '_data.pt')
torch.save(mnist_img_rot_org, save_dir + '_org_data.pt')
torch.save(mnist_labels, save_dir + '_label.pt')
if dataset == 'rot_mnist_spur':
np.save(save_dir + '_spur.npy', rand_var)
print('Data Case: ', data_case, ' Source Domain: ', domain, ' Shape: ', mnist_img_rot.shape, mnist_img_rot_org.shape, mnist_labels.shape)
return
# Main Function
# Input Parsing
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='rot_mnist',
help='Datasets: rot_mnist; fashion_mnist; rot_mnist_spur')
parser.add_argument('--model', type=str, default='resnet18',
help='Base Models: resnet18; lenet')
parser.add_argument('--data_size', type=int, default=60000)
parser.add_argument('--subset_size', type=int, default=2000)
parser.add_argument('--img_w', type=int, default=224)
parser.add_argument('--img_h', type=int, default=224)
args = parser.parse_args()
dataset= args.dataset
model= args.model
img_w= args.img_w
img_h= args.img_h
data_size= args.data_size
subset_size= args.subset_size
val_size= int(args.subset_size/5)
#Generate Dataset for Rotated / Fashion MNIST
#TODO: Manage OS Env from args
os_env=0
if os_env:
base_dir= os.getenv('PT_DATA_DIR') + '/mnist/'
else:
base_dir= 'data/datasets/mnist/'
if not os.path.exists(base_dir):
os.makedirs(base_dir)
data_dir= base_dir + dataset + '_' + model + '/'
if not os.path.exists(data_dir):
os.makedirs(data_dir)
if dataset =='rot_mnist' or dataset == 'rot_mnist_spur':
data_obj_train= datasets.MNIST(base_dir,
train=True,
download=True,
transform=transforms.ToTensor()
)
data_obj_test= datasets.MNIST(base_dir,
train=False,
download=True,
transform=transforms.ToTensor()
)
mnist_imgs= torch.cat((data_obj_train.data, data_obj_test.data))
mnist_labels= torch.cat((data_obj_train.targets, data_obj_test.targets))
elif dataset == 'fashion_mnist':
data_obj_train= datasets.FashionMNIST(base_dir,
train=True,
download=True,
transform=transforms.ToTensor()
)
data_obj_test= datasets.FashionMNIST(base_dir,
train=False,
download=True,
transform=transforms.ToTensor()
)
mnist_imgs= torch.cat((data_obj_train.data, data_obj_test.data))
mnist_labels= torch.cat((data_obj_train.targets, data_obj_test.targets))
# For testing over different base objects; seed 9
# Seed 9 only for test data; seeds 0, 1, 2 for train data
seed_list= [0, 1, 2, 9]
domains= [0, 15, 30, 45, 60, 75, 90]
for seed in seed_list:
# Random Seed
random.seed(seed*10)
np.random.seed(seed*10)
torch.manual_seed(seed*10)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed*10)
# Indices
res=np.random.choice(data_size, subset_size+val_size)
print('Seed: ', seed)
for domain in domains:
#Train
data_case= 'train'
if not os.path.exists(data_dir + data_case + '/'):
os.makedirs(data_dir + data_case + '/')
save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
indices= res[:subset_size]
if model == 'resnet18':
if seed in [0, 1, 2] and domain in [15, 30, 45, 60, 75]:
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
elif model in ['lenet']:
if seed in [0, 1, 2] and domain in [0, 15, 30, 45, 60, 75]:
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
#Val
data_case= 'val'
if not os.path.exists(data_dir + data_case + '/'):
os.makedirs(data_dir + data_case + '/')
save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
indices= res[subset_size:]
if model == 'resnet18':
if seed in [0, 1, 2] and domain in [15, 30, 45, 60, 75]:
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
elif model in ['lenet']:
if seed in [0, 1, 2] and domain in [0, 15, 30, 45, 60, 75]:
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
#Test
data_case= 'test'
if not os.path.exists(data_dir + data_case + '/'):
os.makedirs(data_dir + data_case + '/')
save_dir= data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
indices= res[:subset_size]
if model == 'resnet18':
if seed in [0, 1, 2, 9] and domain in [0, 90]:
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
elif model in ['lenet', 'lenet_mdg']:
if seed in [0, 1, 2] and domain in [0, 15, 30, 45, 60, 75]:
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
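A quick sanity check of the tensors written by `data_gen_mnist.py`; the path below assumes the default (non-`PT_DATA_DIR`) base directory and the `rot_mnist`/`resnet18` setting with seed 0 and domain 15:

```python
import torch

# Path layout: data/datasets/mnist/<dataset>_<model>/<data_case>/seed_<s>_domain_<d>
prefix = 'data/datasets/mnist/rot_mnist_resnet18/train/seed_0_domain_15'
imgs_aug = torch.load(prefix + '_data.pt')            # augmented, normalized images
imgs_org = torch.load(prefix + '_org_data.pt')        # rotated but un-augmented images
labels   = torch.load(prefix + '_label.pt')
print(imgs_aug.shape, imgs_org.shape, labels.shape)   # e.g. [2000, 224, 224] x2 and [2000]
```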

View File

@ -38,8 +38,8 @@ if not os.path.exists(base_dir):
num_samples= 10000
total_slabs= 7
slab_noise_list= [0.0, 0.10, 0.20]
spur_corr_list= [0.0, 0.05, 0.10, 0.15, 0.30, 0.50, 0.70, 0.90, 1.0]
slab_noise_list= [0.0, 0.10]
spur_corr_list= [0.0, 0.10, 0.20, 0.90]
for seed in range(10):
np.random.seed(seed*10)

View File

@ -187,7 +187,7 @@
"Fracture doesn't exist. Adding nans instead.\n",
"Lung Opacity doesn't exist. Adding nans instead.\n",
"Enlarged Cardiomediastinum doesn't exist. Adding nans instead.\n",
"{'Support Devices', 'Pleural Other'} will be dropped\n",
"{'Pleural Other', 'Support Devices'} will be dropped\n",
"Infiltration doesn't exist. Adding nans instead.\n",
"Emphysema doesn't exist. Adding nans instead.\n",
"Fibrosis doesn't exist. Adding nans instead.\n",
@ -565,9 +565,29 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(112120,)\n",
"Error:\n",
"torch.Size([800, 3, 224, 224]) torch.Size([800, 3, 224, 224]) torch.Size([800]) 400 400\n",
"torch.Size([200, 3, 224, 224]) torch.Size([200, 3, 224, 224]) torch.Size([200]) 100 100\n",
"torch.Size([400, 3, 224, 224]) torch.Size([400, 3, 224, 224]) torch.Size([400]) 200 200\n",
"(191010,)\n",
"torch.Size([800, 3, 224, 224]) torch.Size([800, 3, 224, 224]) torch.Size([800]) 400 400\n",
"torch.Size([200, 3, 224, 224]) torch.Size([200, 3, 224, 224]) torch.Size([200]) 100 100\n",
"torch.Size([400, 3, 224, 224]) torch.Size([400, 3, 224, 224]) torch.Size([400]) 200 200\n",
"(26684,)\n",
"torch.Size([800, 3, 224, 224]) torch.Size([800, 3, 224, 224]) torch.Size([800]) 400 400\n",
"torch.Size([200, 3, 224, 224]) torch.Size([200, 3, 224, 224]) torch.Size([200]) 100 100\n",
"torch.Size([400, 3, 224, 224]) torch.Size([400, 3, 224, 224]) torch.Size([400]) 200 200\n"
]
}
],
"source": [
"base_dir=root_dir + '/data/datasets/chestxray/'\n",
" \n",

View File

@ -26,7 +26,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -37,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -51,7 +51,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -77,7 +77,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -86,7 +86,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -107,7 +107,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [

View File

@ -0,0 +1,658 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Reproducing results\n",
"\n",
"\n",
"The following code reproduces results for Slab dataset, Rotated MNIST and Fashion-MNIST dataset, and PACS dataset corresponding to Tables 1, 2, 3, 4, 5, 6 in the main paper.\n",
"\n",
"## Note regarding hardware requirements\n",
"\n",
"The code requires a GPU device, also the batch size for MatchDG Phase 1 training might need to be adjusted according to the memory limits of the GPU device. In case of CUDA of out of memory issues, try with a smaller batch size."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Installing Libraries\n",
"\n",
"List of all the required packages are mentioned in the file 'requirements.txt'\n",
"\n",
"You may install them as follows: `pip install -r requirements.txt`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Table 1: Slab Dataset\n",
"\n",
"## Prepare Slab Dataset\n",
"\n",
"Run the following command:\n",
"\n",
"`python3 data_gen_syn.py`\n",
"\n",
"## Table 1\n",
"\n",
"Run the following command:\n",
"\n",
"`python3 reproduce_scripts/reproduce_slab.py train`\n",
"\n",
"The results would be stored in the `results/slab/logs/` directory"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Table 2, 3: RotMNIST & Fashion MNIST\n",
"\n",
"## Prepare Data for Rot MNIST & Fashion MNIST\n",
"\n",
"Run the following command"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`python data/data_gen_mnist.py --dataset rot_mnist --model resnet18 --img_h 224 --img_w 224 --subset_size 2000` \n",
"\n",
"`python data/data_gen_mnist.py --dataset fashion_mnist --model resnet18 --img_h 224 --img_w 224 --subset_size 10000`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Table 2\n",
"\n",
"Rotated MNIST dataset with training domains set to [15, 30, 45, 60, 75] and the test domains set to [0, 90]. \n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_all --metric train`\n",
"\n",
"The results would be present in the `results/rot_mnist/train_logs/` directory\n",
"\n",
"To obtain results for the FashionMNIST dataset, change the dataset parameter `--dataset` from `rot_mnist` to `fashion_mnist`.\n",
"\n",
"To obtain results for the different set of training domains in the paper, change the input to the parameter `--train_case` to `train_abl_3` for training with domains [30, 45, 60], and `train_abl_2` for training with domains [30, 45] "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Table 3\n",
"\n",
"Run the following commands:\n",
"\n",
"`python test.py --dataset rot_mnist --method_name erm_match --match_case 0.0 --penalty_ws 0.0 --test_metric match_score`\n",
"\n",
"`python test.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --pos_metric cos --test_metric match_score`\n",
"\n",
"For MDG Perf, run the folllowing command to first train the model:\n",
"\n",
"`python3 reproduce_scripts/mnist_mdg_ctr_run.py --dataset rot_mnist --perf_init 1`\n",
"\n",
"Then run the following commands to evalute match function metrics:\n",
"\n",
"`python test.py --dataset rot_mnist --method_name matchdg_ctr --match_case 1.0 --match_flag 1 --pos_metric cos --test_metric match_score`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Table 4, 5: PACS\n",
"\n",
"## Prepare Data for PACS\n",
"\n",
"Download the PACS dataset (https://drive.google.com/drive/folders/0B6x7gtvErXgfUU1WcGY5SzdwZVk) and place it in the directory '/data/datasets/pacs/' "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Table 4\n",
"\n",
"* RandMatch: "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`python3 reproduce_scripts/pacs_run.py --method rand --model resnet18`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* MatchDG:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For contrastive phase, we train with the resnet50 model despite the model architecture in Phase 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`python3 reproduce_scripts/pacs_run.py --method matchdg_ctr --model resnet50`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`python3 reproduce_scripts/pacs_run.py --method matchdg_erm --model resnet18`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* MDGHybrid:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"No need to train the contrastive phase again if already done while training MatchDG"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`python3 reproduce_scripts/pacs_run.py --method hybrid --model resnet18`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Table 5\n",
"\n",
"Repeat the above commands and replace the argument to flag --model with resnet50 with resnet18 "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Table 6: Chest X-Ray"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prepare Data for Chest X-Ray\n",
"\n",
" -Follow the steps in the Preprocess.ipynb notebook to donwload and process the Chest X-Ray datasets\n",
" -Then follow the steps in the ChestXRay_Translate.ipynb notebook to perform image translations"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* NIH: "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`python3 reproduce_scripts/cxray_run.py --test_domain nih --metric train`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* Chexpert: "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`python3 reproduce_scripts/cxray_run.py --test_domain chex --metric train`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* RSNA: "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`python3 reproduce_scripts/cxray_run.py --test_domain kaggle --metric train`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The results would be stored in the `results/cxray/train_logs` directory"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Table 11\n",
"\n",
"Run the following command for data generation:\n",
"\n",
"`python data/data_gen_mnist.py --dataset rot_mnist --model lenet --img_h 32 --img_w 32 --subset_size 1000`\n",
"\n",
"Run the following commands for training models:\n",
"\n",
"`python3 reproduce_rmnist_lenet.py`\n",
"\n",
"The results will be stored in the directory: `results/rmnist_lenet/`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Table 12\n",
"\n",
"Run the following command for data generation:\n",
"\n",
" \n",
"Run the following command for training models:\n",
"\n",
"`python3 reproduce_rmnist_domainbed.py`\n",
"\n",
"The results will be stored in the directory: `results/rmnist_domain_bed/`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Table 13\n",
"\n",
"To obtain results for the FashionMNIST dataset, change the dataset parameter `--dataset` from `rot_mnist` to `fashion_mnist`.\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_all --metric train --methods approx_25 approx_50 approx_75`\n",
"\n",
"The results will be stored in the directory: `results/rot_mnist/train_logs/`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Table 14\n",
"\n",
"To obtain results for the FashionMNIST dataset, change the dataset parameter `--dataset` from `rot_mnist` to `fashion_mnist`.\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_all --metric match_score --data_case train --methods rand perf matchdg`\n",
"\n",
"The results would be stored in the directory: `results/rot_mnist/match_score_train/`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Table 15\n",
"\n",
"Generate data again for the Fashion MNIST 2k sample case by running the following command:\n",
"\n",
"`python data/data_gen_mnist.py --dataset fashion_mnist --model resnet18 --img_h 224 --img_w 224 --subset_size 2000`\n",
"\n",
"Then follow the same commands as mentioned in the Table 2 section"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Table 16\n",
"\n",
"To obtain results for the FashionMNIST dataset, change the dataset parameter `--dataset` from `rot_mnist` to `fashion_mnist`.\n",
"\n",
"MatchDG Iterative corresponds to the default MatchDG algorithm, with the same results as in Table 3\n",
"\n",
"For MatchDG Non Iterative, run the folllowing command to first the model\n",
"\n",
"`python3 reproduce_scripts/mnist_mdg_ctr_run.py --dataset rot_mnist --iterative 0`\n",
"\n",
"Then run the following command to evaluate match function metrics:\n",
"\n",
"`python test.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.0 --match_flag 0 --pos_metric cos --test_metric match_score`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Table 18\n",
"\n",
"Repeat the commands mentioned for PACS ResNet-18 (Table 4) and replace the argument to flag --model with alexnet with resnet18 "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Privacy & OOD: ICLR\n",
"\n",
"# Slab Dataset\n",
"\n",
"### Preparing Data\n",
"\n",
"`python3 data_gen_syn.py`\n",
"\n",
"### Training Models\n",
"\n",
"Run the following command to train models with no noise in the prediction mechanism based on slab features\n",
"\n",
"`python3 reproduce_scripts/slab-run.py --slab_noise 0.0`\n",
"\n",
"Run the following command to train models with noise in the prediction mechanism based on slab features\n",
"\n",
"`python3 reproduce_scripts/slab-run.py --slab_noise 0.10`\n",
"\n",
"\n",
"### Evaluating OOD Accuracy, Randomised-AUC, & Privacy Loss Attack\n",
"\n",
"Run the following command for the case of no noise in the prediction mechanism based on slab features\n",
"\n",
"`python3 reproduce_scripts/slab-run.py --slab_noise 0.0 --case test`\n",
"\n",
"Run the following command for the case of noise in the prediction mechanism based on slab features\n",
"\n",
"`python3 reproduce_scripts/slab-run.py --slab_noise 0.10 --case test`\n",
"\n",
"### Plotting Results\n",
"\n",
"`python3 reproduce_scripts/slab-plot.py 0.0`\n",
"\n",
"`python3 reproduce_scripts/slab-plot.py 0.1`\n",
"\n",
"The plots would be stored in the directory: `results/slab/plots/`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Rotated & Fashion MNIST \n",
"\n",
"For convenience we provide the commands for the Rotated MNIST dataset. To obtain results for the FashionMNIST dataset, change the dataset parameter `--dataset` from `rot_mnist` to `fashion_mnist`.\n",
"\n",
"### Preparing Data\n",
"\n",
"`python data/data_gen_mnist.py --dataset rot_mnist --model resnet18 --img_h 224 --img_w 224 --subset_size 2000` \n",
"\n",
"\n",
"### Training Models\n",
"\n",
"Training Domains: [15, 30, 45, 60, 75]\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_all --metric train --data_aug 0`\n",
"\n",
"Training Domains: [30, 45, 60]\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_abl_3 --metric train --data_aug 0`\n",
"\n",
"Training Domains: [30, 45]\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_abl_2 --metric train --data_aug 0`\n",
"\n",
"The results would be present in the results/rot_mnist/train_logs/ directory\n",
"\n",
"\n",
"### Evaluating OOD Accuracy\n",
"\n",
"Training Domains: [15, 30, 45, 60, 75]\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_all --metric acc --data_case train --data_aug 0 `\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_all --metric acc --data_case test --data_aug 0 `\n",
"\n",
"Training Domains: [30, 45, 60]\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_abl_3 --metric acc --data_case train --data_aug 0`\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_abl_3 --metric acc --data_case test --data_aug 0`\n",
"\n",
"Training Domains: [30, 45]\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_abl_2 --metric acc --data_case train --data_aug 0`\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_abl_2 --metric acc --data_case test --data_aug 0`\n",
"\n",
"\n",
"\n",
"### Evaluating MI Attack Accuracy\n",
"\n",
"Training Domains: [15, 30, 45, 60, 75]\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_all --metric privacy_loss_attack --data_aug 0 `\n",
"\n",
"Training Domains: [30, 45, 60]\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_abl_3 --metric privacy_loss_attack --data_aug 0`\n",
"\n",
"Training Domains: [30, 45]\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_abl_2 --metric privacy_loss_attack --data_aug 0`\n",
"\n",
"\n",
"### Evaluating Mean Rank\n",
"\n",
"Training Domains: [15, 30, 45, 60, 75]\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_all --metric match_score --data_case test --data_aug 0 `\n",
"\n",
"Training Domains: [30, 45, 60]\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_abl_3 --metric match_score --data_case test --data_aug 0`\n",
"\n",
"Training Domains: [30, 45]\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --train_case train_abl_2 --metric match_score --data_case test --data_aug 0`\n",
"\n",
"### Plotting Results\n",
"\n",
"`python3 reproduce_scripts/mnist_plot.py rot_mnist`\n",
"\n",
"The plots would be stored in the directory: `results/rot_mnist/plots/`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Differentially Private Noise\n",
"\n",
"For convenience we provide the commands for the Rotated MNIST dataset. To obtain results for the FashionMNIST dataset, change the dataset parameter `--dataset` from `rot_mnist` to `fashion_mnist`.\n",
"\n",
"The command below produces results for the case of epsilon 1.0; repeat the same command by changing the input to the paramter `--dp_epsilon` to the other values from the list: [1, 2, 5, 10]. \n",
"\n",
"\n",
"### Training Models\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --dp_noise 1 --dp_epsilon 1.0 --data_aug 0 --methods erm perf`\n",
"\n",
"### Evaluating OOD Accuracy\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --dp_noise 1 --dp_epsilon 1.0 --data_aug 0 --methods erm perf --metric acc --data_case test `\n",
"\n",
"\n",
"### Evaluating MI Attack Accuracy\n",
"\n",
"`python3 reproduce_scripts/mnist_run.py --dataset rot_mnist --dp_noise 1 --dp_epsilon 1.0 --data_aug 0 --methods erm perf --metric privacy_loss_attack`\n",
"\n",
"### Infinite Epsilon Case\n",
"\n",
"Append this extra parameter ` --dp_attach_opt 0 ` to all the commands above. This does not attach the differential privacy engine with the optimizer. Also, change the epsilon value to the parameter ` --dp_epsilon ` to any random value as it does not matter since the privacy engine is not attached to the optimizer\n",
"\n",
"### Plotting Results\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ChestXRay Dataset\n",
"\n",
"### Prepare Data for Chest X-Ray\n",
"\n",
" -Follow the steps in the Preprocess.ipynb notebook to donwload and process the Chest X-Ray datasets\n",
" -Then follow the steps in the ChestXRay_Translate.ipynb notebook to perform image translations\n",
"\n",
"### Training Models\n",
"\n",
"Test Domain NIH\n",
"\n",
"`python3 reproduce_scripts/cxray_run.py --test_domain nih --metric train`\n",
"\n",
"Test Domain Chexpert\n",
"\n",
"`python3 reproduce_scripts/cxray_run.py --test_domain chex --metric train`\n",
"\n",
"Test Domain RSNA\n",
"\n",
"`python3 reproduce_scripts/cxray_run.py --test_domain kaggle --metric train`\n",
"\n",
"### Evaluating OOD Accuracy\n",
"\n",
"Test Domain NIH\n",
"\n",
"`python3 reproduce_scripts/cxray_run.py --test_domain nih --metric acc --data_case train`\n",
"\n",
"`python3 reproduce_scripts/cxray_run.py --test_domain nih --metric acc --data_case test`\n",
"\n",
"Test Domain Chexpert\n",
"\n",
"`python3 reproduce_scripts/cxray_run.py --test_domain chex --metric acc --data_case train`\n",
"\n",
"`python3 reproduce_scripts/cxray_run.py --test_domain chex --metric acc --data_case test`\n",
"\n",
"Test Domain RSNA\n",
"\n",
"`python3 reproduce_scripts/cxray_run.py --test_domain kaggle --metric acc --data_case train`\n",
"\n",
"`python3 reproduce_scripts/cxray_run.py --test_domain kaggle --metric acc --data_case test`\n",
"\n",
"\n",
"### Evaluating MI Attack Accuracy\n",
"\n",
"Test Domain NIH\n",
"\n",
"`python3 reproduce_scripts/cxray_run.py --test_domain nih --metric privacy_loss_attack`\n",
"\n",
"Test Domain Chexpert\n",
"\n",
"`python3 reproduce_scripts/cxray_run.py --test_domain chex --metric privacy_loss_attack`\n",
"\n",
"Test Domain RSNA\n",
"\n",
"`python3 reproduce_scripts/cxray_run.py --test_domain kaggle --metric privacy_loss_attack`\n",
"\n",
"\n",
"### Plotting Results\n",
"\n",
"`python3 reproduce_scripts/cxray_plot.py`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Attribute Attack\n",
"\n",
"python data/data_gen_mnist.py --dataset rot_mnist_spur --model resnet18 --img_h 224 --img_w 224 --subset_size 2000"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@ -351,7 +351,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
"version": "3.7.9"
}
},
"nbformat": 4,

File diff suppressed because one or more lines are too long

View File

@ -16,7 +16,11 @@ metric= args.metric
test_domain= args.test_domain
data_case= args.data_case
methods=['erm', 'irm', 'csd', 'rand', 'matchdg_ctr', 'matchdg_erm', 'hybrid']
if metric == 'train':
methods=['erm', 'irm', 'csd', 'rand', 'matchdg_ctr', 'matchdg_erm', 'hybrid']
else:
methods=['erm', 'irm', 'csd', 'rand', 'matchdg_erm', 'hybrid']
domains= ['nih', 'chex', 'kaggle']
dataset= 'chestxray'

View File

@ -1,17 +1,23 @@
import os
#rot_mnist, fashion_mnist
dataset=sys.argv[1]
import argparse
#TODO
#1) Add another argparse argument for deciding between perfect and non-iterative
#2) Script for evaluating the match function metrics
# Input Parsing
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='rot_mnist',
help='Datasets: rot_mnist; fashion_mnist')
parser.add_argument('--iterative', type=int, default=1,
help='Iterative updates to positive matches')
parser.add_argument('--perf_init', type=int, default=0,
help='Positive matches Initialization to perfect matches')
args = parser.parse_args()
dataset= args.dataset
iterative= args.iterative
perf_init= args.perf_init
base_script= 'python train.py --epochs 50 --batch_size 64 --dataset ' + str(dataset)
#Perf MDG
script= base_script + ' --method_name matchdg_ctr --match_case 1.0 --match_flag 1 --pos_metric cos '
script= base_script + ' --method_name matchdg_ctr --pos_metric cos --match_case ' + str(perf_init) + ' --match_flag ' + str(iterative)
os.system(script)
#Non Iterative MDG
script= base_script + ' --method_name matchdg_ctr --match_case 0.0 --match_flag 0 --pos_metric cos '
os.system(script)
#2) Script for evaluating the match function metrics

View File

@ -61,7 +61,7 @@ dataset=sys.argv[1]
test_case=['test_diff']
matplotlib.rcParams.update({'errorbar.capsize': 2})
fig, ax = plt.subplots(1, 3, figsize=(33, 8))
fig, ax = plt.subplots(1, 4, figsize=(33, 8))
fontsize=35
fontsize_lgd= fontsize/1.2
x=['ERM', 'Rand', 'MatchDG', 'CSD', 'IRM', 'Perf']
@ -70,7 +70,7 @@ methods=['erm', 'rand', 'matchdg', 'csd', 'irm', 'perf']
metrics= ['acc:train', 'acc:test', 'privacy_loss_attack', 'match_score:test']
for idx in range(3):
for idx in range(4):
marker_list = ['o', '^', '*']
legend_count = 0
@ -154,7 +154,8 @@ for idx in range(3):
if idx == 0:
ax[idx].errorbar(x, acc_test, yerr=acc_test_err, label=legend_label, marker= marker_list[legend_count], markersize= fontsize_lgd, linewidth=4, fmt='o--')
ax[idx].set_ylabel('OOD Accuracy', fontsize=fontsize)
ax[idx].set_ylabel('OOD Accuracy', fontsize=fontsize)
if idx == 1:
ax[idx].errorbar(x, loss, yerr=loss_err, label=legend_label, marker= marker_list[legend_count], markersize= fontsize_lgd, linewidth=4, fmt='o--')
@ -165,10 +166,8 @@ for idx in range(3):
ax[idx].set_ylabel('Mean Rank', fontsize=fontsize)
if idx == 3:
ax[idx].errorbar(x, np.array(acc_train) - np.array(acc_test), yerr=acc_train_err, label=legend_label, fmt='o--')
# ax.set_xlabel('Models', fontsize=fontsize)
ax[idx].set_ylabel('Train-Test Accuracy Gap of ML Model', fontsize=fontsize)
# ax[idx].legend(fontsize=fontsize_lgd)
ax[idx].errorbar(x, np.array(acc_train) - np.array(acc_test), yerr=acc_train_err, marker= 's', markersize= fontsize_lgd, linewidth=4, fmt='o--')
ax[idx].set_ylabel('Generalization Gap', fontsize=fontsize)
legend_count+= 1

View File

@ -14,10 +14,17 @@ parser.add_argument('--data_case', type=str, default='test',
help='train: Evaluate the acc/match_score metrics on the train dataset; test: Evaluate the acc/match_score metrics on the test dataset')
parser.add_argument('--data_aug', type=int, default=1,
help='0: No data augmentation for fashion mnist; 1: Data augmentation for fashion mnist')
parser.add_argument('--methods', nargs='+', type=str, default=['erm', 'irm', 'csd', 'rand', 'perf', 'matchdg'],
help='List of methods: erm, irm, csd, rand, approx_25, approx_50, approx_75, perf, matchdg')
parser.add_argument('--dp_noise', type=int, default=0,
help='0: No DP noise; 1: Add DP noise')
parser.add_argument('--dp_epsilon', type=float, default=1.0,
parser.add_argument('--dp_epsilon', type=float, default=100.0,
help='Epsilon value for Differential Privacy')
parser.add_argument('--dp_attach_opt', type=int, default=1,
help='0: Infinite Epsilon; 1: Finite Epsilon')
args = parser.parse_args()
@ -27,14 +34,13 @@ train_case= args.train_case
metric= args.metric
data_case= args.data_case
data_aug= args.data_aug
methods= args.methods
# test_diff, test_common
test_case=['test_diff']
# List of methods to train/evaluate
# methods=['erm', 'irm', 'csd', 'rand', 'approx_25', 'approx_50', 'approx_75', 'perf', 'matchdg']
# methods=['erm', 'irm', 'csd', 'rand', 'perf', 'matchdg']
methods=['erm', 'perf']
# methods=[]
if metric == 'train':
if dataset in ['rot_mnist', 'rot_mnist_spur']:
@ -46,7 +52,6 @@ if metric == 'train':
elif metric == 'mia':
if dataset in ['rot_mnist', 'rot_mnist_spur']:
base_script= 'python test.py --test_metric mia --mia_logit 1 --mia_sample_size 2000 --batch_size 64 ' + ' --dataset ' + str(dataset)
elif dataset in ['fashion_mnist']:
base_script= 'python test.py --test_metric mia --mia_logit 1 --mia_sample_size 2000 --batch_size 64 ' + ' --dataset ' + str(dataset)
res_dir= 'results/'+str(dataset)+'/privacy_clf/'
@ -107,7 +112,7 @@ if test_case == 'test_common':
#Differential Privacy
if args.dp_noise:
base_script += ' --dp_noise ' + str(args.dp_noise) + ' --dp_epsilon ' + str(args.dp_epsilon) + ' '
base_script += ' --dp_noise ' + str(args.dp_noise) + ' --dp_epsilon ' + str(args.dp_epsilon) + ' --dp_attach_opt ' + str(args.dp_attach_opt) + ' '
res_dir= res_dir[:-1] + '_epsilon_' + str(args.dp_epsilon) + '/'
if not os.path.exists(res_dir):

View File

@ -1,33 +1,40 @@
import os
import sys
import argparse
# method: rand, matchdg, perf
method= sys.argv[1]
# Input Parsing
parser = argparse.ArgumentParser()
parser.add_argument('--methods', nargs='+', type=str, default=['rand', 'matchdg_ctr', 'matchdg_erm'],
help='List of methods')
args = parser.parse_args()
methods= args.methods
domains= [0, 15, 30, 45, 60, 75]
if method == 'rand':
base_script= 'python train.py --dataset rot_mnist --mnist_case domain_bed --method_name erm_match --perfect_match 0 --match_case 0.0 --penalty_ws 1.0 --epochs 25 --batch_size 128 --model_name domain_bed_mnist --img_h 28 --img_w 28 --total_matches_per_point 1000 '
elif method == 'matchdg_ctr':
base_script= 'python train.py --dataset rot_mnist --mnist_case domain_bed --method_name matchdg_ctr --perfect_match 0 --match_case 0.0 --match_flag 1 --epochs 50 --batch_size 512 --pos_metric cos --model_name domain_bed_mnist --img_h 28 --img_w 28 --match_func_aug_case 1 '
for method in methods:
elif method == 'matchdg_erm':
base_script= 'python train.py --dataset rot_mnist --mnist_case domain_bed --method_name matchdg_erm --perfect_match 0 --match_case -1 --penalty_ws 1.0 --epochs 25 --batch_size 128 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name domain_bed_mnist --model_name domain_bed_mnist --img_h 28 --img_w 28 --total_matches_per_point 1000 '
if method == 'rand':
base_script= 'python train.py --dataset rot_mnist --mnist_case domain_bed --method_name erm_match --perfect_match 0 --match_case 0.0 --penalty_ws 1.0 --epochs 25 --batch_size 128 --model_name domain_bed_mnist --img_h 28 --img_w 28 --total_matches_per_point 1000 '
for test_domain in domains:
train_domains=''
for d in domains:
if d != test_domain:
train_domains+= str(d) + ' '
print(train_domains)
elif method == 'matchdg_ctr':
base_script= 'python train.py --dataset rot_mnist --mnist_case domain_bed --method_name matchdg_ctr --perfect_match 0 --match_case 0.0 --match_flag 1 --epochs 50 --batch_size 512 --pos_metric cos --model_name domain_bed_mnist --img_h 28 --img_w 28 --match_func_aug_case 1 '
res_dir= 'results/rmnist_domain_bed/'
if not os.path.exists(res_dir):
os.makedirs(res_dir)
script= base_script + ' --train_domains ' + str(train_domains) + ' --test_domains ' + str(test_domain)
script= script + ' > ' + res_dir + method + '_' + str(test_domain) + '.txt'
print('Method: ', method, ' Test Domain: ', test_domain)
os.system(script)
elif method == 'matchdg_erm':
base_script= 'python train.py --dataset rot_mnist --mnist_case domain_bed --method_name matchdg_erm --perfect_match 0 --match_case -1 --penalty_ws 1.0 --epochs 25 --batch_size 128 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name domain_bed_mnist --model_name domain_bed_mnist --img_h 28 --img_w 28 --total_matches_per_point 1000 '
for test_domain in domains:
train_domains=''
for d in domains:
if d != test_domain:
train_domains+= str(d) + ' '
print(train_domains)
res_dir= 'results/rmnist_domain_bed/'
if not os.path.exists(res_dir):
os.makedirs(res_dir)
script= base_script + ' --train_domains ' + str(train_domains) + ' --test_domains ' + str(test_domain)
script= script + ' > ' + res_dir + method + '_' + str(test_domain) + '.txt'
print('Method: ', method, ' Test Domain: ', test_domain)
os.system(script)

View File

@ -1,40 +1,46 @@
import os
import sys
import argparse
# method: rand, matchdg, perf
method= sys.argv[1]
# Input Parsing
parser = argparse.ArgumentParser()
parser.add_argument('--methods', nargs='+', type=str, default=['erm', 'rand', 'matchdg_ctr', 'matchdg_erm', 'perf'],
help='List of methods')
args = parser.parse_args()
methods= args.methods
domains= [0, 15, 30, 45, 60, 75]
for method in methods:
if method == 'perf':
base_script= 'python train.py --dataset rot_mnist --mnist_case lenet --method_name erm_match --match_case 1.0 --penalty_ws 1.0 --epochs 100 --model_name lenet --img_h 32 --img_w 32 '
if method == 'perf':
base_script= 'python train.py --dataset rot_mnist --mnist_case lenet --method_name erm_match --match_case 1.0 --penalty_ws 1.0 --epochs 100 --model_name lenet --img_h 32 --img_w 32 '
elif method == 'erm':
base_script= 'python train.py --dataset rot_mnist --mnist_case lenet --method_name erm_match --match_case 0.0 --penalty_ws 0.0 --epochs 100 --model_name lenet --img_h 32 --img_w 32 '
elif method == 'erm':
base_script= 'python train.py --dataset rot_mnist --mnist_case lenet --method_name erm_match --match_case 0.0 --penalty_ws 0.0 --epochs 100 --model_name lenet --img_h 32 --img_w 32 '
elif method == 'rand':
base_script= 'python train.py --dataset rot_mnist --mnist_case lenet --method_name erm_match --match_case 0.0 --penalty_ws 1.0 --epochs 100 --model_name lenet --img_h 32 --img_w 32 --total_matches_per_point 100 '
elif method == 'rand':
base_script= 'python train.py --dataset rot_mnist --mnist_case lenet --method_name erm_match --match_case 0.0 --penalty_ws 1.0 --epochs 100 --model_name lenet --img_h 32 --img_w 32 --total_matches_per_point 100 '
elif method == 'matchdg_ctr':
base_script= 'python train.py --dataset rot_mnist --mnist_case lenet --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --epochs 50 --batch_size 512 --pos_metric cos --model_name lenet --img_h 32 --img_w 32 --match_func_aug_case 1 '
elif method == 'matchdg_erm':
base_script= 'python train.py --dataset rot_mnist --mnist_case lenet --method_name matchdg_erm --match_case -1 --penalty_ws 1.0 --epochs 100 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name lenet --model_name lenet --img_h 32 --img_w 32 --total_matches_per_point 100 '
elif method == 'matchdg_ctr':
base_script= 'python train.py --dataset rot_mnist --mnist_case lenet --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --epochs 50 --batch_size 512 --pos_metric cos --model_name lenet --img_h 32 --img_w 32 --match_func_aug_case 1 '
for test_domain in domains:
train_domains=''
for d in domains:
if d != test_domain:
train_domains+= str(d) + ' '
print(train_domains)
elif method == 'matchdg_erm':
base_script= 'python train.py --dataset rot_mnist --mnist_case lenet --method_name matchdg_erm --match_case -1 --penalty_ws 1.0 --epochs 100 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name lenet --model_name lenet --img_h 32 --img_w 32 --total_matches_per_point 100 '
res_dir= 'results/rmnist_lenet/'
if not os.path.exists(res_dir):
os.makedirs(res_dir)
script= base_script + ' --train_domains ' + str(train_domains) + ' --test_domains ' + str(test_domain)
script= script + ' > ' + res_dir + method + '_' + str(test_domain) + '.txt'
print('Method: ', method, ' Test Domain: ', test_domain)
os.system(script)
for test_domain in domains:
train_domains=''
for d in domains:
if d != test_domain:
train_domains+= str(d) + ' '
print(train_domains)
res_dir= 'results/rmnist_lenet/'
if not os.path.exists(res_dir):
os.makedirs(res_dir)
script= base_script + ' --train_domains ' + str(train_domains) + ' --test_domains ' + str(test_domain)
script= script + ' > ' + res_dir + method + '_' + str(test_domain) + '.txt'
print('Method: ', method, ' Test Domain: ', test_domain)
os.system(script)

View File

@ -2,13 +2,9 @@ import os
import sys
'''
argv1: Allowed Values (train, test)
argv1: Allowed Values (train, evaluate)
'''
## TODO
# Choose a better name for train, test as in the test stage we do not really evaluate the test accuracy
# Maybe make the train set accuracy also as part of the logs in the train.py phase (Check once if it changes any results from the main paper)
case= sys.argv[1]
methods=['erm', 'mmd', 'coral', 'dann', 'c-mmd', 'c-coral', 'c-dann', 'rand', 'perf']
total_seed= 10
@ -16,7 +12,7 @@ total_seed= 10
if case == 'train':
base_script= 'python train.py --dataset slab --model_name slab --batch_size 128 --lr 0.1 --epochs 100 --out_classes 2 --train_domains 0.0 0.10 --test_domains 1.0 --slab_data_dim 2 --slab_noise 0.1 ' + ' --n_runs ' + str(total_seed)
elif case == 'test':
elif case == 'evaluate':
base_script= 'python test.py --test_metric per_domain_acc --acc_data_case train --dataset slab --model_name slab --batch_size 128 --lr 0.1 --epochs 100 --out_classes 2 --train_domains 0.0 0.10 --test_domains 1.0 --slab_data_dim 2 --slab_noise 0.1 ' + ' --n_runs ' + str(total_seed)

View File

@ -1,14 +1,15 @@
import matplotlib
import matplotlib.pyplot as plt
import sys
import os
import numpy as np
slab_noise= float(sys.argv[1])
base_dir= 'slab_res/slab_noise_' + str(slab_noise) + '/'
# methods=['erm', 'irm', 'csd', 'rand', 'matchdg', 'perf', 'mask_linear']
base_dir= 'results/slab/slab_noise_' + str(slab_noise) + '/'
methods=['erm', 'irm', 'csd', 'rand', 'perf', 'mask_linear']
# x=['ERM', 'IRM', 'CSD', 'Rand', 'MatchDG', 'Perf', 'Mask']
x=['ERM', 'IRM', 'CSD', 'Rand', 'Perf', 'Mask']
x=['ERM', 'IRM', 'CSD', 'Rand', 'Perf', 'Oracle']
matplotlib.rcParams.update({'errorbar.capsize': 2})
fig, ax = plt.subplots(1, 2, figsize=(24, 8))
fontsize=40
@ -16,10 +17,12 @@ fontsize_lgd= fontsize/1.2
marker_list = ['o', '^', '*']
count= 0
for test_domain in [0.3, 0.9]:
for test_domain in [0.2, 0.9]:
acc=[]
acc_err=[]
train_acc =[]
train_acc_err =[]
auc=[]
auc_err=[]
s_auc=[]
@ -40,6 +43,11 @@ for test_domain in [0.3, 0.9]:
sc_auc.append( float(data[-1].replace('\n', '').split(' ')[-2] ))
sc_auc_err.append( float( data[-1].replace('\n', '').split(' ')[-1] ) )
f= open(base_dir + method + '-train-auc-' + str(test_domain) + '.txt')
data= f.readlines()
train_acc.append( float( data[-4].replace('\n', '').split(' ')[-2] ))
train_acc_err.append( float( data[-4].replace('\n', '').split(' ')[-1] ))
#Privacy Metrics
mia=[]
@ -81,6 +89,10 @@ for test_domain in [0.3, 0.9]:
ax[count].errorbar(x, acc, yerr=acc_err, marker= marker_list[0], markersize= fontsize_lgd, linewidth=4, fmt='o--', label='OOD Acc')
ax[count].errorbar(x, s_auc, yerr=s_auc_err, marker= marker_list[1], markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Linear-RAUC')
ax[count].errorbar(x, loss, yerr=loss_err, marker= marker_list[2], markersize= fontsize_lgd, linewidth=4, label='Loss Attack', fmt='o--')
gen_gap= np.array(train_acc) - np.array(acc)
ax[count].errorbar(x, gen_gap, yerr=0*gen_gap, marker= 's', markersize= fontsize_lgd, linewidth=4, fmt='o--', label='Gen Gap')
ax[count].set_ylabel('Metric Score', fontsize=fontsize)
ax[count].set_title('Test Domain: ' + str(test_domain), fontsize=fontsize)
@ -88,7 +100,11 @@ for test_domain in [0.3, 0.9]:
lines, labels = fig.axes[-1].get_legend_handles_labels()
lgd= fig.legend(lines, labels, loc="lower center", bbox_to_anchor=(0.5, -0.15), fontsize=fontsize, ncol=3)
plt.tight_layout()
plt.savefig('results/privacy_slab_' + str(slab_noise) + '.pdf', bbox_extra_artists=(lgd,), bbox_inches='tight', dpi=600)
save_dir= 'results/slab/plots/'
if not os.path.exists(save_dir):
os.makedirs(save_dir)
plt.tight_layout()
plt.savefig( save_dir + 'privacy_slab_' + str(slab_noise) + '.pdf', bbox_extra_artists=(lgd,), bbox_inches='tight', dpi=600)

View File

@ -1,21 +1,25 @@
import os
import sys
import argparse
'''
argv1: Case (train, test)
# Input Parsing
parser = argparse.ArgumentParser()
parser.add_argument('--case', type=str, default='train',
help='Case: train; test')
parser.add_argument('--slab_noise', type=float, default=0.0,
help='Probability of corrupting slab features')
parser.add_argument('--methods', nargs='+', type=str, default=['erm', 'irm', 'csd', 'rand', 'matchdg', 'perf', 'mask_linear'],
help='List of methods')
argv2: Noise in the slab feature (Flip probability)
'''
args = parser.parse_args()
case= sys.argv[1]
slab_noise= float(sys.argv[2])
case= args.case
slab_noise= args.slab_noise
methods= args.methods
metrics= ['train-auc', 'auc', 'loss']
total_seed= 3
methods=['erm', 'irm', 'csd', 'rand', 'perf', 'mask_linear']
# methods=['matchdg']
# metrics= ['auc', 'mi', 'entropy', 'loss']
metrics= ['auc', 'loss']
if case == 'train':
base_script= 'python train.py --dataset slab --model_name slab --batch_size 128 --lr 0.1 --epochs 100 --out_classes 2 --train_domains 0.0 0.10 --test_domains 0.90 --slab_data_dim 2 '
@ -40,6 +44,7 @@ if case == 'train':
#CTR Phase
script = base_script + ' --method_name matchdg_ctr --batch_size 512 --match_case 0.0 --match_flag 1 --match_interrupt 5 --pos_metric cos '
os.system(script)
#ERM Phase
script = base_script + ' --method_name matchdg_erm --match_case -1 --penalty_ws 1.0 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name slab '
@ -51,7 +56,9 @@ elif case == 'test':
if metric == 'auc':
base_script= 'python test_slab.py --train_domains 0.0 0.10 '
if metric == 'train-auc':
base_script= 'python test_slab.py --train_domains 0.0 0.10 --acc_data_case train '
elif metric == 'mi':
base_script= 'python test.py --test_metric mia --mia_logit 1 --mia_sample_size 400 --dataset slab --model_name slab --out_classes 2 --train_domains 0.0 0.10 '
elif metric == 'entropy':
@ -62,7 +69,7 @@ elif case == 'test':
base_script= 'python test.py --test_metric attribute_attack --mia_logit 1 --attribute_domain 0 --dataset slab --model_name slab --out_classes 2 --train_domains 0.0 0.10 '
base_script= base_script + ' --slab_noise ' + str(slab_noise) + ' --n_runs ' + str(total_seed)
res_dir= 'slab_res/slab_noise_' + str(slab_noise) + '/'
res_dir= 'results/slab/slab_noise_' + str(slab_noise) + '/'
if not os.path.exists(res_dir):
os.makedirs(res_dir)
@ -85,36 +92,6 @@ elif case == 'test':
elif method == 'matchdg':
upd_script = base_script + ' --method_name matchdg_erm --match_case -1 --penalty_ws 1.0 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name slab '
# for test_domain in [0.05, 0.15, 0.3, 0.5, 0.7, 0.9]:
for test_domain in [0.3, 0.9]:
for test_domain in [0.2, 0.9]:
script= upd_script + ' --test_domains ' + str(test_domain) + ' > ' + res_dir + str(method) + '-' + str(metric) + '-' + str(test_domain) + '.txt'
os.system(script)
# elif case == 'train_plot':
# for metric in metrics:
# if metric == 'auc':
# base_script= 'python logit_plot_slab.py --train_domains 0.0 0.10 '
# for method in methods:
# if method == 'erm':
# upd_script= base_script + ' --method_name perf_match --penalty_ws 0.0 '
# elif method == 'irm':
# upd_script= base_script + ' --method_name irm_slab --penalty_irm 10.0 --penalty_s 2 '
# elif method == 'csd':
# upd_script= base_script + ' --method_name csd_slab --penalty_ws 0.0 --rep_dim 100 '
# elif method == 'rand':
# upd_script= base_script + ' --method_name rand_match --penalty_ws 1.0 '
# elif method == 'perf':
# upd_script= base_script + ' --method_name perf_match --penalty_ws 1.0 '
# elif method == 'mask_linear':
# upd_script= base_script + ' --method_name mask_linear --penalty_ws 0.0 '
# for test_domain in [0.05, 0.15, 0.3, 0.5, 0.7, 0.9]:
# script= upd_script + ' --test_domains ' + str(test_domain) + ' > slab_temp/' + str(method) + '-' + str(metric) + '-' + str(test_domain) + '.txt'
# os.system(script)

View File

@ -119,6 +119,9 @@ parser.add_argument('--dp_noise', type=int, default=0,
help='0: No DP noise; 1: Add DP noise')
parser.add_argument('--dp_epsilon', type=float, default=1.0,
help='Epsilon value for Differential Privacy')
parser.add_argument('--dp_attach_opt', type=int, default=1,
help='0: Infinite Epsilon; 1: Finite Epsilon')
#MMD, DANN

View File

@ -276,7 +276,10 @@ for run in range(args.n_runs):
model= test_method.phi
test_method.get_metric_eval()
std_acc= test_method.metric_score['test accuracy']
if args.acc_data_case == 'train':
std_acc= test_method.metric_score['train accuracy']
elif args.acc_data_case == 'test':
std_acc= test_method.metric_score['test accuracy']
print('Test Accuracy: ', std_acc)
spur_prob= float(test_domains[0])

View File

@ -119,6 +119,9 @@ parser.add_argument('--dp_noise', type=int, default=0,
help='0: No DP noise; 1: Add DP noise')
parser.add_argument('--dp_epsilon', type=float, default=1.0,
help='Epsilon value for Differential Privacy')
# Special case when you want to check results with the dp setting for the infinite epsilon case
parser.add_argument('--dp_attach_opt', type=int, default=1,
help='0: Infinite Epsilon; 1: Finite Epsilon')
#MMD, DANN

0
utils/__init__.py Normal file
View File

View File

@ -286,18 +286,13 @@ def get_dataloader(args, run, domains, data_case, eval_case, kwargs):
data_obj= ChestXRay(args, domains, '/chestxray_spur/', data_case=data_case, match_func=match_func)
elif args.dataset_name in ['rot_mnist', 'fashion_mnist', 'rot_mnist_spur']:
if data_case == 'test' and args.mnist_case not in ['lenet', 'lenet_mdg']:
if data_case == 'test' and args.mnist_case not in ['lenet']:
#TODO: Infer this based on the total number of seed values for the mnist case
# Actually by default the seeds 0, 1, 2 are for training and seed 9 is for test; mention that properly in comments
mnist_subset= 9
else:
mnist_subset= run
#TODO: Only Temporary, in order to see if it changes results on MNIST
# if eval_case:
# if args.test_metric in ['mia', 'privacy_entropy', 'privacy_loss_attack']:
# mnist_subset=run
print('MNIST Subset: ', mnist_subset)
data_obj= MnistRotated(args, domains, mnist_subset, '/mnist/', data_case=data_case, match_func=match_func)

View File

@ -229,7 +229,9 @@ def get_matched_pairs(args, cuda, train_dataset, domain_size, total_domains, tra
if perfect_match == 1:
## Find all instances among the curr_domain with same object as obj_base[idx]
## .nonzero() converts True matches to match indexes; [0, 0] takes the first match of the same base object in the curr domain
perfect_match_rank.append( (obj_curr[sort_idx] == obj_base[idx]).nonzero()[0,0].item() )
if obj_base[idx] in obj_curr[sort_idx]:
perfect_match_rank.append( (obj_curr[sort_idx] == obj_base[idx]).nonzero()[0,0].item() )
# print('Time Taken in CTR Loop: ', time.time()-start_time)
elif inferred_match == 0 and perfect_match == 1:
@ -242,6 +244,7 @@ def get_matched_pairs(args, cuda, train_dataset, domain_size, total_domains, tra
# Select random matches with perm_prob probability
if rand_vars[idx]:
rand_indices = np.arange(ordered_curr_indices.size()[0])
np.random.shuffle(rand_indices)
curr_indices= ordered_curr_indices[rand_indices][:total_matches_per_point]
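A toy illustration (made-up values) of the perfect-match rank computed above: among the points of the other domain sorted by feature distance, the rank is the position of the first point sharing the anchor's base object, and the new membership check skips anchors whose object never appears in that domain:

```python
import torch

obj_base_idx = 7                                  # base object id of the anchor point (illustrative)
obj_curr = torch.tensor([3, 7, 5, 7])             # object ids present in the current domain
sort_idx = torch.tensor([2, 1, 0, 3])             # indices sorted by feature-space distance

if obj_base_idx in obj_curr[sort_idx]:            # guard added in this commit
    # first occurrence of the same object among the sorted neighbours
    rank = (obj_curr[sort_idx] == obj_base_idx).nonzero()[0, 0].item()
    print(rank)                                   # 1: the true match is the second-nearest neighbour
```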

View File