DP ERM changes; code reorganization

This commit is contained in:
divyat09 2021-07-14 11:27:45 +00:00
Родитель d4faf8c445
Коммит 04db49e237
33 изменённых файлов: 5286 добавлений и 45 удалений

Просмотреть файл

@ -19,6 +19,59 @@ import torch.utils.data as data_utils
from utils.match_function import get_matched_pairs
def get_noise_multiplier(
    target_epsilon: float,
    target_delta: float,
    sample_rate: float,
    epochs: int,
    alphas: "list[float]",
    sigma_min: float = 0.01,
    sigma_max: float = 10.0,
) -> float:
    r"""
    Compute the noise level sigma needed to reach a total privacy budget of
    (target_epsilon, target_delta) at the end of ``epochs`` epochs, for a
    given sampling rate.

    The search has two phases: first ``sigma_max`` is doubled until the
    epsilon spent drops below ``target_epsilon`` (bracketing the answer),
    then a bisection narrows [sigma_min, sigma_max] down to width 0.01.

    Args:
        target_epsilon: the privacy budget's epsilon
        target_delta: the privacy budget's delta
        sample_rate: the sampling rate (usually batch_size / n_data)
        epochs: the number of epochs to run
        alphas: the list of RDP orders at which to compute the privacy spent
        sigma_min: lower bound of the bisection bracket
        sigma_max: initial upper bound; doubled until it brackets the answer

    Returns:
        The noise level sigma ensuring the (target_epsilon, target_delta) budget.

    Raises:
        ValueError: if a sigma above 2000 would be required, i.e. the
            requested privacy budget is too low to be satisfiable.
    """
    from opacus import privacy_analysis

    # Phase 1: grow sigma_max until the spent epsilon falls below the target.
    eps = float("inf")
    while eps > target_epsilon:
        sigma_max = 2 * sigma_max
        # epochs / sample_rate is the total number of optimizer steps taken.
        rdp = privacy_analysis.compute_rdp(
            sample_rate, sigma_max, epochs / sample_rate, alphas
        )
        eps = privacy_analysis.get_privacy_spent(alphas, rdp, target_delta)[0]
        if sigma_max > 2000:
            raise ValueError("The privacy budget is too low.")

    # Phase 2: bisect [sigma_min, sigma_max] down to a width of 0.01.
    # Initialize sigma so it is defined even if the bracket is already tight
    # (the original left it unbound in that case).
    sigma = sigma_max
    while sigma_max - sigma_min > 0.01:
        sigma = (sigma_min + sigma_max) / 2
        rdp = privacy_analysis.compute_rdp(
            sample_rate, sigma, epochs / sample_rate, alphas
        )
        eps = privacy_analysis.get_privacy_spent(alphas, rdp, target_delta)[0]
        if eps < target_epsilon:
            sigma_max = sigma
        else:
            sigma_min = sigma
    return sigma
class BaseAlgo():
def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, run, cuda):
self.args= args
@ -48,7 +101,8 @@ class BaseAlgo():
self.val_acc=[]
self.train_acc=[]
if self.args.method_name == 'dp_erm':
# if self.args.method_name == 'dp_erm':
if self.args.dp_noise:
self.privacy_engine= self.get_dp_noise()
def get_model(self):
@ -97,7 +151,7 @@ class BaseAlgo():
else:
fc_layer= self.args.fc_layer
phi= get_resnet(self.args.model_name, self.args.out_classes, fc_layer,
self.args.img_c, self.args.pre_trained, self.args.os_env)
self.args.img_c, self.args.pre_trained, self.args.dp_noise, self.args.os_env)
if 'densenet' in self.args.model_name:
from models.densenet import get_densenet
@ -179,6 +233,11 @@ class BaseAlgo():
return data_match_tensor, label_match_tensor, curr_batch_size
def get_test_accuracy(self, case):
import opacus
if self.args.dp_noise:
opacus.autograd_grad_sample.disable_hooks()
#self.privacy_engine.module.disable_hooks()
#Test Env Code
test_acc= 0.0
@ -190,6 +249,10 @@ class BaseAlgo():
for batch_idx, (x_e, y_e ,d_e, idx_e, obj_e) in enumerate(dataset):
with torch.no_grad():
self.opt.zero_grad()
# print(x_e.shape)
# print(torch.cuda.memory_allocated())
x_e= x_e.to(self.cuda)
y_e= torch.argmax(y_e, dim=1).to(self.cuda)
@ -198,9 +261,15 @@ class BaseAlgo():
test_acc+= torch.sum( torch.argmax(out, dim=1) == y_e ).item()
test_size+= y_e.shape[0]
# To avoid CUDA memory issues
if self.args.dp_noise:
self.opt.zero_grad()
print(' Accuracy: ', case, 100*test_acc/test_size )
#self.privacy_engine.module.enable_hooks()
opacus.autograd_grad_sample.enable_hooks()
return 100*test_acc/test_size
def get_dp_noise(self):
@ -212,23 +281,32 @@ class BaseAlgo():
from opacus.utils import module_modification
inspector = DPModelInspector()
print(self.phi)
self.phi = module_modification.convert_batchnorm_modules(self.phi)
print(self.phi)
# print(self.phi)
# self.phi = module_modification.convert_batchnorm_modules(self.phi)
inspector.validate(self.phi)
MAX_GRAD_NORM = 1.2
EPSILON = 50.0
NOISE_MULTIPLIER = .38
MAX_GRAD_NORM = 10.0
# NOISE_MULTIPLIER = 0.8
# NOISE_MULTIPLIER = 1.46
# NOISE_MULTIPLIER = 1.15
# NOISE_MULTIPLIER = 0.7
NOISE_MULTIPLIER = 0.0
DELTA = 1.0/(self.total_domains*self.domain_size)
BATCH_SIZE = self.args.batch_size
VIRTUAL_BATCH_SIZE = 2*BATCH_SIZE
BATCH_SIZE = self.args.batch_size * self.total_domains
VIRTUAL_BATCH_SIZE = 10*BATCH_SIZE
assert VIRTUAL_BATCH_SIZE % BATCH_SIZE == 0 # VIRTUAL_BATCH_SIZE should be divisible by BATCH_SIZE
N_ACCUMULATION_STEPS = int(VIRTUAL_BATCH_SIZE / BATCH_SIZE)
SAMPLE_RATE = BATCH_SIZE /(self.total_domains*self.domain_size)
DEFAULT_ALPHAS = [1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64))
print(BATCH_SIZE, SAMPLE_RATE, N_ACCUMULATION_STEPS, SAMPLE_RATE*N_ACCUMULATION_STEPS)
print(f"Using sigma={NOISE_MULTIPLIER} and C={MAX_GRAD_NORM}")
# epsilon=20.0
# print('Value of Noise Multiplier Needed')
# print(get_noise_multiplier(epsilon, DELTA, SAMPLE_RATE, self.args.epochs, DEFAULT_ALPHAS))
# sys.exit(-1)
from opacus import PrivacyEngine
# privacy_engine = PrivacyEngine(
# self.phi,
@ -236,10 +314,17 @@ class BaseAlgo():
# alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
# noise_multiplier=NOISE_MULTIPLIER,
# max_grad_norm=MAX_GRAD_NORM,
# )
# privacy_engine = PrivacyEngine(
# self.phi,
# sample_rate=SAMPLE_RATE * N_ACCUMULATION_STEPS,
# noise_multiplier=NOISE_MULTIPLIER,
# max_grad_norm=MAX_GRAD_NORM,
# )
privacy_engine = PrivacyEngine(
self.phi,
sample_rate=SAMPLE_RATE * N_ACCUMULATION_STEPS,
batch_size= BATCH_SIZE,
sample_size= self.total_domains*self.domain_size,
noise_multiplier=NOISE_MULTIPLIER,
max_grad_norm=MAX_GRAD_NORM,
)

Просмотреть файл

@ -26,8 +26,8 @@ class ErmMatch(BaseAlgo):
def train(self):
self.max_epoch=-1
self.max_val_acc=0.0
self.max_epoch= -1
self.max_val_acc= 0.0
for epoch in range(self.args.epochs):
if epoch ==0:
@ -44,8 +44,6 @@ class ErmMatch(BaseAlgo):
#Batch iteration over single epoch
for batch_idx, (x_e, y_e ,d_e, idx_e, obj_e) in enumerate(self.train_dataset):
# print('Batch Idx: ', batch_idx)
self.opt.zero_grad()
loss_e= torch.tensor(0.0).to(self.cuda)
@ -53,8 +51,6 @@ class ErmMatch(BaseAlgo):
y_e= torch.argmax(y_e, dim=1).to(self.cuda)
d_e= torch.argmax(d_e, dim=1).numpy()
#Forward Pass
out= self.phi(x_e)
wasserstein_loss=torch.tensor(0.0).to(self.cuda)
erm_loss= torch.tensor(0.0).to(self.cuda)
@ -70,6 +66,7 @@ class ErmMatch(BaseAlgo):
data_match= data_match_tensor.to(self.cuda)
data_match= data_match.flatten(start_dim=0, end_dim=1)
feat_match= self.phi( data_match )
# print(feat_match.shape)
label_match= label_match_tensor.to(self.cuda)
label_match= torch.squeeze( label_match.flatten(start_dim=0, end_dim=1) )
@ -111,7 +108,9 @@ class ErmMatch(BaseAlgo):
loss_e.backward(retain_graph=False)
self.opt.step()
# self.opt.zero_grad()
# del out
del erm_loss
del wasserstein_loss
del loss_e

Просмотреть файл

@ -74,7 +74,7 @@ class MatchDG(BaseAlgo):
if 'resnet' in self.args.ctr_model_name:
from models.resnet import get_resnet
fc_layer=0
ctr_phi= get_resnet(self.args.ctr_model_name, self.args.out_classes, fc_layer, self.args.img_c, self.args.pre_trained, self.args.os_env).to(self.cuda)
ctr_phi= get_resnet(self.args.ctr_model_name, self.args.out_classes, fc_layer, self.args.img_c, self.args.pre_trained, self.args.dp_noise, self.args.os_env).to(self.cuda)
if 'densenet' in self.args.ctr_model_name:
from models.densenet import get_densenet

Просмотреть файл

@ -248,7 +248,7 @@ for seed in seed_list:
indices= res[:subset_size]
if model == 'resnet18':
if seed in [9] and domain in [0, 90]:
if seed in [0, 1, 2, 9] and domain in [0, 90]:
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
elif model in ['lenet', 'lenet_mdg']:
if seed in [0, 1, 2] and domain in [0, 15, 30, 45, 60, 75]:

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -36,11 +36,12 @@ class SpurCorrDataLoader(data_utils.Dataset):
def __init__(self, dataloader):
super(SpurCorrDataLoader, self).__init__()
self.x= dataloader.train_data
self.y= dataloader.train_labels
self.d= dataloader.train_domain
self.indices= dataloader.train_indices
self.spur_corr= dataloader.train_spur
self.x= dataloader.data
self.y= dataloader.labels
self.d= dataloader.domains
self.indices= dataloader.indices
self.objects= dataloader.objects
self.spur_corr= dataloader.spur
def __len__(self):
@ -51,9 +52,10 @@ class SpurCorrDataLoader(data_utils.Dataset):
batch_y = self.y[index]
batch_d = self.d[index]
batch_idx = self.indices[index]
batch_obj= self.objects[index]
batch_spur= self.spur_corr[index]
return batch_x, batch_y, batch_d, batch_idx, batch_spur
return batch_x, batch_y, batch_d, batch_idx, batch_obj, batch_spur
class AttributeAttack(BaseEval):

Просмотреть файл

@ -103,7 +103,7 @@ class BaseEval():
else:
fc_layer= self.args.fc_layer
phi= get_resnet(self.args.model_name, self.args.out_classes, fc_layer,
self.args.img_c, self.args.pre_trained, self.args.os_env)
self.args.img_c, self.args.pre_trained, self.args.dp_noise, self.args.os_env)
if 'densenet' in self.args.model_name:
from models.densenet import get_densenet

Просмотреть файл

@ -112,6 +112,7 @@ class PrivacyEntropy(BaseEval):
print(case, attack_data['logits'].shape, attack_data['labels'].shape, attack_data['members'].shape)
return attack_data
def eval_entropy_attack(self, data, threshold_data, scale=1.0, case='train'):
@ -141,6 +142,8 @@ class PrivacyEntropy(BaseEval):
# print('F_y, F_i', F_y.shape, F_i.shape)
# print('Neg term: ', (F_i*torch.log(1.0-F_i)).shape, F_i[0])
metric= -1*(1.0 - F_y)*torch.log(F_y) -1*torch.sum( F_i*torch.log(1.0-F_i), dim=1 )
# metric= -1*(1.0 - F_y)*torch.log(F_y)
# metric= -1*(1.0)*torch.log(F_y)
# threshold_data[y_c]= torch.max(metric)
threshold_data[y_c]= torch.mean(metric)
@ -166,6 +169,8 @@ class PrivacyEntropy(BaseEval):
F_y= torch.sum( logits*labels, dim=1)
F_i= logits*(1.0-labels)
metric= -1*(1.0 - F_y)*torch.log(F_y) -1*torch.sum( F_i*torch.log(1.0-F_i), dim=1 )
# metric= -1*(1.0 - F_y)*torch.log(F_y)
# metric= -1*(1.0)*torch.log(F_y)
mem_predict= 1.0*(metric < (threshold_data[y_c]/scale))
acc+= torch.sum( mem_predict == members ).item()

Просмотреть файл

@ -48,12 +48,13 @@ class PrivacyLossAttack(BaseEval):
train_data={}
train_data['loss']=[]
train_data['labels']=[]
train_data['obj']= []
for batch_idx, (x_e, y_e ,d_e, idx_e, obj_e) in enumerate(self.train_dataset['data_loader']):
#Random Shuffling along the batch axis
rand_indices= torch.randperm(x_e.size()[0])
x_e= x_e[rand_indices]
y_e= y_e[rand_indices]
obj_e= obj_e[rand_indices]
with torch.no_grad():
x_e= x_e.to(self.cuda)
y_e= y_e.to(self.cuda)
@ -62,19 +63,24 @@ class PrivacyLossAttack(BaseEval):
loss= cross_entropy(out, torch.argmax(y_e, dim=1).long()).to(self.cuda)
train_data['loss'].append(loss)
train_data['labels'].append(y_e)
train_data['obj'].append(obj_e)
train_data['loss']= torch.cat(train_data['loss'], dim=0)
train_data['labels']= torch.cat(train_data['labels'], dim=0)
train_data['obj']= torch.cat(train_data['obj'], dim=0)
#Test Environment Data
test_data={}
test_data['loss']=[]
test_data['labels']=[]
test_data['obj']=[]
test_data['free']=[]
for batch_idx, (x_e, y_e ,d_e, idx_e, obj_e) in enumerate(self.test_dataset['data_loader']):
#Random Shuffling along the batch axis
rand_indices= torch.randperm(x_e.size()[0])
x_e= x_e[rand_indices]
y_e= y_e[rand_indices]
obj_e= obj_e[rand_indices]
with torch.no_grad():
x_e= x_e.to(self.cuda)
@ -85,13 +91,20 @@ class PrivacyLossAttack(BaseEval):
test_data['loss'].append(loss)
test_data['labels'].append(y_e)
test_data['obj'].append(obj_e)
test_data['free']= test_data['free'] + [1]*out.shape[0]
test_data['loss']= torch.cat(test_data['loss'], dim=0)
test_data['labels']= torch.cat(test_data['labels'], dim=0)
test_data['obj']= torch.cat(test_data['obj'], dim=0)
test_data['free']= torch.tensor(test_data['free'])
print('Train Logits: ', train_data['loss'].shape, 'Train Labels: ', train_data['labels'].shape, ' Train Objs: ', train_data['obj'].shape )
print(torch.unique(train_data['obj']))
print('Test Logits: ', test_data['loss'].shape, 'Test Labels: ', test_data['labels'].shape, ' Test Objs: ', test_data['obj'].shape )
print(torch.unique(test_data['obj']))
print('Test Free: ', test_data['free'].shape, test_data['free'])
print('Train Logits: ', train_data['loss'].shape, 'Train Labels: ', train_data['labels'].shape )
print('Test Logits: ', test_data['loss'].shape, 'Test Labels: ', test_data['labels'].shape )
return train_data, test_data
def create_attack_data(self, train_data, test_data, sample_size, case='train'):
@ -102,6 +115,7 @@ class PrivacyLossAttack(BaseEval):
test_loss= test_data['loss'][:sample_size]
test_labels= test_data['labels'][:sample_size]
elif case == 'test':
train_loss= train_data['loss'][-1-sample_size:-1]
train_labels= train_data['labels'][-1-sample_size:-1]
@ -116,6 +130,67 @@ class PrivacyLossAttack(BaseEval):
print(case, attack_data['loss'].shape, attack_data['labels'].shape, attack_data['members'].shape)
return attack_data
# def create_attack_data(self, train_data, test_data, sample_size, case='train'):
# if case == 'train':
# train_loss= train_data['loss'][:sample_size]
# train_labels= train_data['labels'][:sample_size]
# train_obj= train_data['obj'][:sample_size]
# test_loss= []
# test_labels= []
# for idx in range(sample_size):
# obj= train_obj[idx]
# indice= (test_data['obj'] == obj).nonzero()
# for idx_obj in range(indice.shape[0]):
# curr_indice= indice[idx_obj, 0].item()
# if test_data['free'][curr_indice] == 1:
# test_loss.append(test_data['loss'][curr_indice].view(1))
# #TODO: Change 10 to num_classes
# test_labels.append(test_data['labels'][curr_indice].view(1, 10))
# test_data['free'][curr_indice]= 0
# break
# test_loss= torch.cat(test_loss, dim=0)
# test_labels= torch.cat(test_labels, dim=0)
# elif case == 'test':
# train_loss= train_data['loss'][-1-sample_size:-1]
# train_labels= train_data['labels'][-1-sample_size:-1]
# train_obj= train_data['obj'][-1-sample_size:-1]
# test_loss= []
# test_labels= []
# for idx in range(sample_size):
# obj= train_obj[idx]
# indice= (test_data['obj'] == obj).nonzero()
# for idx_obj in range(indice.shape[0]):
# curr_indice= indice[idx_obj, 0].item()
# if test_data['free'][curr_indice] == 1:
# test_loss.append(test_data['loss'][curr_indice].view(1))
# #TODO: Change 10 to num_classes
# test_labels.append(test_data['labels'][curr_indice].view(1, 10))
# test_data['free'][curr_indice]= 0
# break
# test_loss= torch.cat(test_loss, dim=0)
# test_labels= torch.cat(test_labels, dim=0)
# print('Attack Dataset Members: ', train_loss.shape, train_labels.shape)
# print('Attack Dataset Non Members: ', test_loss.shape, test_labels.shape)
# attack_data={}
# attack_data['loss']= torch.cat( (train_loss, test_loss), dim=0 )
# attack_data['labels']= torch.cat( (train_labels, test_labels), dim=0 )
# # attack_data['members']= torch.cat( (torch.ones((sample_size,1)), torch.zeros((sample_size,1))), dim=0).to(self.cuda)
# attack_data['members']= torch.cat( (torch.ones((train_loss.shape[0], 1)), torch.zeros((test_loss.shape[0],1))), dim=0).to(self.cuda)
# print(case, attack_data['loss'].shape, attack_data['labels'].shape, attack_data['members'].shape)
# return attack_data
def eval_entropy_attack(self, data, threshold_data, scale=1.0, case='train'):

Просмотреть файл

@ -0,0 +1,108 @@
import sys
import numpy as np
import argparse
import copy
import random
import json
import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils
from .base_eval import BaseEval
from utils.match_function import get_matched_pairs, perfect_match_score
from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity
def sim_matrix(a, b, eps=1e-8):
    """Pairwise cosine-similarity matrix between the rows of ``a`` and ``b``.

    Each row is divided by its L2 norm (clamped below by ``eps`` for
    numerical stability) before the dot products are taken, so entry
    ``[i, j]`` is the cosine similarity of ``a[i]`` and ``b[j]``.
    """
    def _unit_rows(mat):
        # Row norms as a column vector; the clamp avoids division by ~zero.
        norms = mat.norm(dim=1)[:, None]
        return mat / norms.clamp(min=eps)

    return torch.mm(_unit_rows(a), _unit_rows(b).t())
class SlabFeatEval(BaseEval):
    """Evaluation module that measures feature-space distance between
    perfectly matched pairs (same object across domains) on the slab dataset.

    Writes the averaged distance into ``self.metric_score`` under the key
    'Perfect Match Distance'.
    """

    def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, run, cuda):
        super().__init__(args, train_dataset, val_dataset, test_dataset, base_res_dir, run, cuda)

    def get_metric_eval(self):
        # Pick the split to evaluate on from args.match_func_data_case.
        # NOTE(review): if the value is none of 'train'/'val'/'test',
        # `dataset` below is unbound and this method raises — confirm the
        # argument is validated upstream.
        if self.args.match_func_data_case=='train':
            dataset= self.train_dataset['data_loader']
            total_domains= self.train_dataset['total_domains']
            domain_list= self.train_dataset['domain_list']
            base_domain_size= self.train_dataset['base_domain_size']
            domain_size_list= self.train_dataset['domain_size_list']
        elif self.args.match_func_data_case== 'val':
            dataset= self.val_dataset['data_loader']
            total_domains= self.val_dataset['total_domains']
            domain_list= self.val_dataset['domain_list']
            base_domain_size= self.val_dataset['base_domain_size']
            domain_size_list= self.val_dataset['domain_size_list']
        elif self.args.match_func_data_case== 'test':
            dataset= self.test_dataset['data_loader']
            total_domains= self.test_dataset['total_domains']
            domain_list= self.test_dataset['domain_list']
            base_domain_size= self.test_dataset['base_domain_size']
            domain_size_list= self.test_dataset['domain_size_list']

        # Distance metric is hard-coded here; args.pos_metric is not consulted.
        pos_metric= 'cos'

        with torch.no_grad():
            penalty_ws=0
            batch_size_counter=0

            #Batch iteration over single epoch
            # NOTE(review): this unpacks 4 items per batch while other loops
            # in this codebase unpack 5 (..., idx_e, obj_e) — confirm the
            # slab data loader yields 4-tuples.
            for batch_idx, (x_e, y_e ,d_e, idx_e) in enumerate(dataset):
                x_e= x_e.to(self.cuda)
                y_e= torch.argmax(y_e, dim=1).to(self.cuda)

                #Forward Pass
                # NOTE(review): `out` is computed but never used below.
                out= self.phi(x_e)
                # Object ids present in this batch; matched pairs share an id.
                match_objs= np.unique(idx_e)
                # Intermediate features used for the distance computation.
                feat= self.phi.feat_net(x_e)

                for obj in match_objs:
                    indices= idx_e == obj
                    feat_obj= feat[indices]
                    d_obj= d_e[indices]
                    match_domains= torch.unique(d_obj)

                    # Skip objects that do not appear in every domain of the
                    # batch (positivity violation).
                    if len(match_domains) != len(torch.unique(d_e)):
                        # print('Error: Positivty Violation, objects not present in all the domains')
                        continue

                    # All unordered domain pairs (d_i, d_j) with d_i < d_j.
                    # NOTE(review): d_i/d_j are range indices compared against
                    # domain labels in d_obj — assumes labels are 0..k-1;
                    # confirm.
                    for d_i in range(len(match_domains)):
                        for d_j in range(len(match_domains)):
                            if d_j <= d_i:
                                continue
                            x1= feat_obj[ d_obj == d_i ]
                            x2= feat_obj[ d_obj == d_j ]
                            if pos_metric == 'l2':
                                # Broadcast to all cross pairs, sum squared diffs.
                                x1= x1.view(x1.shape[0], 1, x1.shape[1])
                                penalty_ws+= float( torch.sum( torch.sum( torch.sum( (x1 -x2)**2, dim=2), dim=1 )) )
                            elif pos_metric == 'l1':
                                x1= x1.view(x1.shape[0], 1, x1.shape[1])
                                penalty_ws+= float( torch.sum( torch.sum( torch.sum( torch.abs(x1 -x2), dim=2), dim=1 )) )
                            elif pos_metric == 'cos':
                                penalty_ws+= float( torch.sum( torch.sum( sim_matrix(x1, x2), dim=1)) )
                            # Number of cross pairs contributing above.
                            batch_size_counter+= x1.shape[0]*x2.shape[0]

                torch.cuda.empty_cache()

        # Average over all matched cross pairs seen.
        self.metric_score['Perfect Match Distance']= penalty_ws/batch_size_counter
        print('Perfect Match Distance: ', self.metric_score['Perfect Match Distance'])
        return

Просмотреть файл

@ -0,0 +1,255 @@
#Common imports
import os
import sys
import numpy as np
import argparse
import copy
import random
import json
import pickle
#Pytorch
import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils
#Sklearn
from sklearn.manifold import TSNE
#robustdg
from utils.helper import *
from utils.match_function import *
#slab
from utils.slab_data import *
import utils.scripts.utils as slab_utils
import utils.scripts.lms_utils as slab_lms_utils
def get_logits(model, loader, device, label=1):
    """Scalar logit scores (class-1 minus class-0 logit) for samples of ``label``.

    Pulls the full tensor dataset out of ``loader``, runs ``model`` on it,
    keeps only the rows whose target equals ``label``, and collapses the
    two-class logits into a single score per sample.
    """
    inputs, targets = slab_utils.extract_tensors_from_loader(loader)
    logits = slab_utils.get_logits_given_tensor(inputs, model, device=device).detach()
    selected = logits[targets == label].cpu().numpy()
    # Difference of the two class logits gives one scalar score per sample.
    return selected[:, 1] - selected[:, 0]
# Input Parsing
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='slab',
                    help='Datasets: rot_mnist; fashion_mnist; pacs')
parser.add_argument('--method_name', type=str, default='erm_match',
                    help=' Training Algorithm: erm_match; matchdg_ctr; matchdg_erm')
parser.add_argument('--model_name', type=str, default='slab',
                    help='Architecture of the model to be trained')
parser.add_argument('--train_domains', nargs='+', type=str, default=["15", "30", "45", "60", "75"],
                    help='List of train domains')
parser.add_argument('--test_domains', nargs='+', type=str, default=["0", "90"],
                    help='List of test domains')
parser.add_argument('--out_classes', type=int, default=2,
                    help='Total number of classes in the dataset')
parser.add_argument('--img_c', type=int, default= 1,
                    help='Number of channels of the image in dataset')
parser.add_argument('--img_h', type=int, default= 224,
                    help='Height of the image in dataset')
parser.add_argument('--img_w', type=int, default= 224,
                    help='Width of the image in dataset')
parser.add_argument('--slab_data_dim', type=int, default= 2,
                    help='Number of features in the slab dataset')
parser.add_argument('--slab_total_slabs', type=int, default=7)
parser.add_argument('--slab_num_samples', type=int, default=1000)
parser.add_argument('--fc_layer', type=int, default= 1,
                    help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet')
parser.add_argument('--match_layer', type=str, default='logit_match',
                    help='rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
parser.add_argument('--pos_metric', type=str, default='l2',
                    help='Cost to function to evaluate distance between two representations; Options: l1; l2; cos')
parser.add_argument('--rep_dim', type=int, default=250,
                    help='Representation dimension for contrsative learning')
parser.add_argument('--pre_trained',type=int, default=0,
                    help='0: No Pretrained Architecture; 1: Pretrained Architecture')
parser.add_argument('--perfect_match', type=int, default=1,
                    help='0: No perfect match known (PACS); 1: perfect match known (MNIST)')
parser.add_argument('--opt', type=str, default='sgd',
                    help='Optimizer Choice: sgd; adam')
parser.add_argument('--lr', type=float, default=0.01,
                    help='Learning rate for training the model')
parser.add_argument('--batch_size', type=int, default=16,
                    help='Batch size foe training the model')
parser.add_argument('--epochs', type=int, default=15,
                    help='Total number of epochs for training the model')
parser.add_argument('--penalty_w', type=float, default=0.0,
                    help='Penalty weight for IRM invariant classifier loss')
parser.add_argument('--penalty_s', type=int, default=-1,
                    help='Epoch threshold over which Matching Loss to be optimised')
parser.add_argument('--penalty_irm', type=float, default=0.0,
                    help='Penalty weight for IRM invariant classifier loss')
parser.add_argument('--penalty_ws', type=float, default=0.1,
                    help='Penalty weight for Matching Loss')
parser.add_argument('--penalty_diff_ctr',type=float, default=1.0,
                    help='Penalty weight for Contrastive Loss')
parser.add_argument('--tau', type=float, default=0.05,
                    help='Temperature hyper param for NTXent contrastive loss ')
parser.add_argument('--match_flag', type=int, default=0,
                    help='0: No Update to Match Strategy; 1: Updates to Match Strategy')
parser.add_argument('--match_case', type=float, default=1.0,
                    help='0: Random Match; 1: Perfect Match. 0.x" x% correct Match')
parser.add_argument('--match_interrupt', type=int, default=5,
                    help='Number of epochs before inferring the match strategy')
parser.add_argument('--n_runs', type=int, default=3,
                    help='Number of iterations to repeat the training process')
parser.add_argument('--n_runs_matchdg_erm', type=int, default=1,
                    help='Number of iterations to repeat training process for matchdg_erm')
parser.add_argument('--ctr_model_name', type=str, default='resnet18',
                    help='(For matchdg_ctr phase) Architecture of the model to be trained')
parser.add_argument('--ctr_match_layer', type=str, default='logit_match',
                    help='(For matchdg_ctr phase) rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
parser.add_argument('--ctr_match_flag', type=int, default=1,
                    help='(For matchdg_ctr phase) 0: No Update to Match Strategy; 1: Updates to Match Strategy')
parser.add_argument('--ctr_match_case', type=float, default=0.01,
                    help='(For matchdg_ctr phase) 0: Random Match; 1: Perfect Match. 0.x" x% correct Match')
parser.add_argument('--ctr_match_interrupt', type=int, default=5,
                    help='(For matchdg_ctr phase) Number of epochs before inferring the match strategy')
parser.add_argument('--mnist_seed', type=int, default=0,
                    help='Change it between 0-6 for different subsets of Mnist and Fashion Mnist dataset')
parser.add_argument('--retain', type=float, default=0,
                    help='0: Train from scratch in MatchDG Phase 2; 2: Finetune from MatchDG Phase 1 in MatchDG is Phase 2')
parser.add_argument('--test_metric', type=str, default='acc',
                    help='Evaluation Metrics: acc; match_score, t_sne, mia')
parser.add_argument('--acc_data_case', type=str, default='test',
                    help='Dataset Train/Val/Test for the accuracy evaluation metric')
parser.add_argument('--top_k', type=int, default=10,
                    help='Top K matches to consider for the match score evaluation metric')
parser.add_argument('--match_func_aug_case', type=int, default=0,
                    help='0: Evaluate match func on train domains; 1: Evaluate match func on self augmentations')
parser.add_argument('--match_func_data_case', type=str, default='train',
                    help='Dataset Train/Val/Test for the match score evaluation metric')
parser.add_argument('--mia_batch_size', default=64, type=int,
                    help='batch size')
parser.add_argument('--mia_dnn_steps', default=5000, type=int,
                    help='number of training steps')
parser.add_argument('--mia_sample_size', default=1000, type=int,
                    help='number of samples from train/test dataset logits')
parser.add_argument('--mia_logit', default=1, type=int,
                    help='0: Softmax applied to logits; 1: No Softmax applied to logits')
parser.add_argument('--attribute_domain', default=1, type=int,
                    help='0: spur correlations as attribute; 1: domain as attribute')
parser.add_argument('--adv_eps', default=0.3, type=float,
                    help='Epsilon ball dimension for PGD attacks')
parser.add_argument('--logit_plot_path', default='', type=str,
                    help='File name to save logit/loss plots')
parser.add_argument('--ctr_abl', type=int, default=0,
                    help='0: Randomization til class level ; 1: Randomization completely')
parser.add_argument('--match_abl', type=int, default=0,
                    help='0: Randomization til class level ; 1: Randomization completely')
parser.add_argument('--cuda_device', type=int, default=0,
                    help='Select the cuda device by id among the avaliable devices' )
parser.add_argument('--os_env', type=int, default=0,
                    help='0: Code execution on local server/machine; 1: Code execution in docker/clusters' )

args = parser.parse_args()

#GPU
# NOTE(review): torch.device objects are always truthy, so the else branch
# below can never run — presumably this was meant to test
# torch.cuda.is_available(); confirm.
cuda= torch.device("cuda:" + str(args.cuda_device))
if cuda:
    kwargs = {'num_workers': 1, 'pin_memory': False}
else:
    kwargs= {}
args.kwargs= kwargs

#List of Train; Test domains
train_domains= args.train_domains
test_domains= args.test_domains

#Initialize result accumulators (written elsewhere / kept for symmetry with test.py)
final_acc= []
final_auc= []
final_s_auc= []
final_sc_auc= []

# Results are grouped per dataset/method/match-layer/train-domain combination.
base_res_dir=(
    "results/" + args.dataset_name + '/' + args.method_name + '/' + args.match_layer
    + '/' + 'train_' + str(args.train_domains)
)
if not os.path.exists(base_res_dir):
    os.makedirs(base_res_dir)

#Checks
if args.method_name == 'matchdg_ctr' and args.test_metric == 'acc':
    raise ValueError('Match DG during the contrastive learning phase cannot be evaluted for test accuracy metric')
    # NOTE(review): unreachable — the raise above already exits.
    sys.exit()
if args.perfect_match == 0 and args.test_metric == 'match_score' and args.match_func_aug_case==0:
    raise ValueError('Cannot evalute match function metrics when perfect match is not known')
    # NOTE(review): unreachable — the raise above already exits.
    sys.exit()

#Execute the method for multiple runs ( total args.n_runs )
for run in range(args.n_runs):

    #Seed for reproducibility
    np.random.seed(run*10)
    torch.manual_seed(run*10)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(run*10)

    #DataLoader
    # Placeholders: the real dataloaders are built by get_data() below /
    # the commented-out block that mirrors test.py.
    train_dataset= torch.empty(0)
    val_dataset= torch.empty(0)
    test_dataset= torch.empty(0)

    # if args.test_metric == 'match_score':
    #     if args.match_func_data_case== 'train':
    #         train_dataset= get_dataloader( args, run, train_domains, 'train', 1, kwargs )
    #     elif args.match_func_data_case== 'val':
    #         val_dataset= get_dataloader( args, run, train_domains, 'val', 1, kwargs )
    #     elif args.match_func_data_case== 'test':
    #         test_dataset= get_dataloader( args, run, test_domains, 'test', 1, kwargs )
    # elif args.test_metric == 'acc':
    #     if args.acc_data_case== 'train':
    #         train_dataset= get_dataloader( args, run, train_domains, 'train', 1, kwargs )
    #     elif args.acc_data_case== 'test':
    #         test_dataset= get_dataloader( args, run, test_domains, 'test', 1, kwargs )
    # elif args.test_metric in ['mia', 'privacy_entropy', 'privacy_loss_attack']:
    #     train_dataset= get_dataloader( args, run, train_domains, 'train', 1, kwargs )
    #     test_dataset= get_dataloader( args, run, test_domains, 'test', 1, kwargs )
    # elif args.test_metric == 'attribute_attack':
    #     train_dataset= get_dataloader( args, run, train_domains + test_domains, 'train', 1, kwargs )
    #     test_dataset= get_dataloader( args, run, train_domains + test_domains, 'test', 1, kwargs )
    # else:
    #     test_dataset= get_dataloader( args, run, test_domains, 'test', 1, kwargs )
    # print('Train Domains, Domain Size, BaseDomainIdx, Total Domains: ', train_domains, total_domains, domain_size, training_list_size)

    #Import the testing module
    from evaluation.base_eval import BaseEval
    # BaseEval is used only to build/load the trained model (phi).
    test_method= BaseEval(
        args, train_dataset, val_dataset,
        test_dataset, base_res_dir,
        run, cuda
    )
    test_method.get_model()
    model= test_method.phi

    # Only the first train domain is scored (unconditional break below).
    for train_domain in train_domains:
        spur_prob= float(train_domain)
        data, temp1, _, _= get_data(args.slab_num_samples, spur_prob, args.slab_total_slabs)

        # compute logit scores on the slab test dataloader
        std_log = get_logits(model, data['te_dl'], cuda)
        break

    # plot logit distributions
    # NOTE(review): plt and sns are used below but never imported in this
    # file — presumably matplotlib.pyplot and seaborn; confirm.
    kw = dict(kde=False, bins=20, norm_hist=True,
              hist_kws={"histtype": "step", "linewidth": 2,
                        "alpha": 0.8, "ls": '-'})
    fig, ax = plt.subplots(1,1,figsize=(6,4))
    ax = sns.distplot(std_log, label='Standard Logits', **kw)
    slab_utils.update_ax(ax, 'Logit Distributions of Positive Data', 'Logits', 'Density',
                         ticks_fs=13, label_fs=13, title_fs=16, legend_fs=14, legend_loc='upper left')
    plt.savefig('results/slab_train_logit_plot/' + str(args.method_name)+ '_' + str(args.penalty_ws) + '_' + str(run) + '.jpg')

36
misc_scripts/privacy.txt Normal file
Просмотреть файл

@ -0,0 +1,36 @@
#Entropy Attack
python test.py --test_metric privacy_entropy --mia_sample_size 1000 --batch_size 64 --dataset chestxray --method_name erm_match --match_case 0.01 --penalty_ws 0 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
python test.py --test_metric privacy_entropy --mia_sample_size 1000 --batch_size 64 --dataset chestxray --method_name erm_match --match_case 0.01 --penalty_ws 10 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
python test.py --test_metric privacy_entropy --mia_sample_size 1000 --batch_size 64 --dataset chestxray --method_name csd --match_case 0.01 --penalty_ws 0 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
python test.py --test_metric privacy_entropy --mia_sample_size 1000 --batch_size 64 --dataset chestxray --method_name irm --match_case 0.01 --penalty_s 5 --penalty_ws 10.0 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
python test.py --test_metric privacy_entropy --mia_sample_size 1000 --batch_size 64 --dataset chestxray --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --penalty_ws 50.0 --match_case -1 --train_domains nih_trans chex_trans --test_domains kaggle
#MIA Attack
python test.py --test_metric mia --mia_sample_size 1000 --mia_logit 1 --batch_size 64 --dataset chestxray --method_name erm_match --match_case 0.01 --penalty_ws 0 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
python test.py --test_metric mia --mia_sample_size 1000 --batch_size 64 --dataset chestxray --method_name erm_match --match_case 0.01 --penalty_ws 10 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
python test.py --test_metric mia --mia_sample_size 1000 --batch_size 64 --dataset chestxray --method_name csd --match_case 0.01 --penalty_ws 0 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
python test.py --test_metric mia --mia_sample_size 1000 --batch_size 64 --dataset chestxray --method_name irm --match_case 0.01 --penalty_s 5 --penalty_ws 10.0 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
python test.py --test_metric mia --mia_sample_size 1000 --batch_size 64 --dataset chestxray --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --penalty_ws 50.0 --match_case -1 --train_domains nih_trans chex_trans --test_domains kaggle
#Mean Rank
python test.py --test_metric match_score --batch_size 64 --dataset chestxray --match_func_data_case test --match_func_aug_case 1 --method_name erm_match --match_case 0.01 --penalty_ws 0 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
python test.py --test_metric match_score --batch_size 64 --dataset chestxray --match_func_data_case test --match_func_aug_case 1 --method_name erm_match --match_case 0.01 --penalty_ws 10.0 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
python test.py --test_metric match_score --batch_size 64 --dataset chestxray --match_func_data_case test --match_func_aug_case 1 --method_name csd --match_case 0.01 --penalty_ws 0 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
python test.py --test_metric match_score --batch_size 64 --dataset chestxray --match_func_data_case test --match_func_aug_case 1 --method_name irm --match_case 0.01 --penalty_s 5 --penalty_ws 10.0 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
python test.py --test_metric match_score --batch_size 64 --dataset chestxray --match_func_data_case test --match_func_aug_case 1 --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --penalty_ws 50.0 --match_case -1 --train_domains nih_trans chex_trans --test_domains kaggle

104
misc_scripts/reproduce.txt Normal file
Просмотреть файл

@ -0,0 +1,104 @@
Table 1:
python train.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --epochs 50 --batch_size 64 --pos_metric cos
python train.py --dataset rot_mnist --method_name matchdg_erm --match_case -1 --penalty_ws 0.1 --epochs 25 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18
Table 2: (Match func evaluation on train domains)
python test.py --dataset rot_mnist --method_name erm_match --match_case 0.0 --penalty_ws 0.0 --test_metric match_score
python test.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --pos_metric cos --test_metric match_score
Table 6: (LeNet)
#MatchDG
python train.py --dataset rot_mnist --method_name matchdg_erm --penalty_ws 1.0 --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --n_runs 3 --epochs 100 --lr 0.01 --model_name lenet --img_h 32 --img_w 32 --train_domains 15 30 45 60 75 --test_domains 0
#Perf
python train.py --dataset rot_mnist --method_name erm_match --match_case 1.0 --penalty_ws 1.0 --epochs 100 --model_name lenet --img_h 32 --img_w 32 --train_domains 0 15 30 45 60 --test_domains 75
Table 7: (DomainBed)
#MatchDG
python train.py --dataset rot_mnist --method_name matchdg_erm --penalty_ws 1.0 --match_case -1 --batch_size 64 --epochs 25 --model_name domain_bed_mnist --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --n_runs 3 --img_h 28 --img_w 28 --train_domains 15 30 45 60 75 --test_domains 0
#Perf
python train.py --dataset rot_mnist --method_name erm_match --match_case 1.0 --penalty_ws 1.0 --batch_size 64 --epochs 25 --model_name domain_bed_mnist --img_h 28 --img_w 28 --train_domains 15 30 45 60 75 --test_domains 0
Table 8: (Phase 2 Rankings)
#RandMatch
python test.py --dataset rot_mnist --method_name erm_match --match_case 0.01 --penalty_ws 0.1 --test_metric match_score
#MatchDG
python test.py --dataset rot_mnist --method_name matchdg_erm --penalty_ws 0.1 --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --n_runs 3 --test_metric match_score
#PerfMatch
python test.py --dataset rot_mnist --method_name erm_match --match_case 1.0 --penalty_ws 0.1 --test_metric match_score
Chexpert: (Acc Evaluation)
#Matching Methods
python test.py --dataset chestxray --method_name erm_match --match_case 0.01 --penalty_ws 0 --test_metric acc --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains nih_trans
#CSD
python test.py --dataset chestxray --method_name csd --match_case 0.01 --penalty_ws 0 --test_metric acc --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
#IRM
python test.py --dataset chestxray --method_name irm --match_case 0.01 --penalty_s 5 --penalty_ws 10.0 --test_metric acc --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
python test.py --dataset chestxray --method_name hybrid --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121 --test_metric acc --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --penalty_ws 0.0 --match_case 0.01 --train_domains nih_trans chex_trans --test_domains chex_trans
#MatchDG
python test.py --dataset chestxray --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121 --test_metric acc --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --penalty_ws 50.0 --match_case -1 --train_domains nih_trans chex_trans --test_domains kaggle
#Hybrid
python test.py --dataset chestxray --method_name hybrid --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121 --test_metric acc --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --penalty_ws 1.0 --match_case -1 --train_domains nih_trans chex_trans --test_domains kaggle_trans
PACS: (Acc Evaluation)
#Matching Methods
python test.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --train_domains art_painting cartoon sketch --test_domains photo
#MatchDG
python test.py --dataset pacs --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --penalty_ws 0.1 --match_case -1 --train_domains art_painting cartoon sketch --test_domains photo
python test.py --dataset pacs --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet50 --penalty_ws 0.1 --match_case -1 --train_domains art_painting cartoon sketch --test_domains photo
python test.py --dataset pacs --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --penalty_ws 0.5 --match_case -1 --train_domains photo cartoon sketch --test_domains art_painting
python test.py --dataset pacs --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet50 --penalty_ws 0.1 --match_case -1 --train_domains photo cartoon sketch --test_domains art_painting
python test.py --dataset pacs --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --penalty_ws 1.0 --match_case -1 --train_domains photo art_painting sketch --test_domains cartoon
python test.py --dataset pacs --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet50 --penalty_ws 1.0 --match_case -1 --train_domains photo art_painting sketch --test_domains cartoon
python test.py --dataset pacs --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --penalty_ws 0.5 --match_case -1 --train_domains photo art_painting cartoon --test_domains sketch
python test.py --dataset pacs --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet50 --penalty_ws 0.5 --match_case -1 --train_domains photo art_painting cartoon --test_domains sketch
#Hybrid
python test.py --dataset pacs --method_name hybrid --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --penalty_ws 0.1 --match_case -1 --train_domains art_painting cartoon sketch --test_domains photo
python test.py --dataset pacs --method_name hybrid --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet50 --penalty_ws 0.1 --match_case -1 --train_domains art_painting cartoon sketch --test_domains photo
python test.py --dataset pacs --method_name hybrid --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --penalty_ws 0.01 --match_case -1 --train_domains photo cartoon sketch --test_domains art_painting
python test.py --dataset pacs --method_name hybrid --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet50 --penalty_ws 0.01 --match_case -1 --train_domains photo cartoon sketch --test_domains art_painting
python test.py --dataset pacs --method_name hybrid --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --penalty_ws 0.01 --match_case -1 --train_domains photo art_painting sketch --test_domains cartoon
python test.py --dataset pacs --method_name hybrid --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet50 --penalty_ws 0.01 --match_case -1 --train_domains photo art_painting sketch --test_domains cartoon
python test.py --dataset pacs --method_name hybrid --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --penalty_ws 0.5 --match_case -1 --train_domains photo art_painting cartoon --test_domains sketch
python test.py --dataset pacs --method_name hybrid --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet50 --penalty_ws 0.01 --match_case -1 --train_domains photo art_painting cartoon --test_domains sketch

Просмотреть файл

@ -0,0 +1,3 @@
python3 train.py --dataset adult --model fc --out_classes 2 --train_domains male female --test_domains male female --penalty_irm 100.0 --method_name irm --penalty_s -1
python test.py --test_metric attribute_attack --mia_logit 1 --batch_size 64 --dataset adult --model fc --out_classes 2 --train_domains male female --test_domains male female --penalty_ws 10.0

Просмотреть файл

@ -0,0 +1,44 @@
M2ExZ
In the domain generalization literature, a common objective is to learn representations independent of the domain after conditioning on the class label. We show that this objective is not sufficient: there exist counter-examples where a model fails to generalize to unseen domains even after satisfying class-conditional domain invariance. We formalize this observation through a structural causal model and show the importance of modeling within-class variations for generalization. Specifically, classes contain objects that characterize specific causal features, and domains can be interpreted as interventions on these objects that change non-causal features. We highlight an alternative condition: inputs across domains should have the same representation if they are derived from the same object. Based on this objective, we propose matching-based algorithms when base objects are observed (e.g., through data augmentation) and approximate the objective when objects are not observed (MatchDG). Our simple matching-based algorithms are competitive to prior work on out-of-domain accuracy for rotated MNIST, Fashion-MNIST, PACS, and Chest-Xray datasets. Our method MatchDG also recovers ground-truth object matches: on MNIST and Fashion-MNIST, top-10 matches from MatchDG have over 50% overlap with ground-truth matches.
Table 11:
88.2 + 98.6 + 97.7 + 97.5 + 97.0 + 85.6
91.0 + 99.7 + 99.6 + 99.4 + 99.7 + 93.1
93.0 + 99.5 + 99.9 + 99.4 + 99.7 + 93.3
96.5 + 99.1 + 99.2 + 98.6 + 98.6 + 94.9
Table 12:
95.4 + 98.2 + 97.9 + 98.5 + 98.1 + 94.3
95.9 + 98.4 + 98.6 + 98.9 + 98.7 + 95.1
Table 18:
95.38 + 77.68 + 78.98 + 74.75
95.37 + 78.16 + 78.83 + 75.13
95.93 + 79.77 + 80.03 + 77.11
96.15 + 81.71 + 80.75 + 78.79
95.57 + 79.09 + 79.37 + 77.60
96.53 + 81.32 + 80.70 + 79.72
96.67 + 82.80 + 81.61 + 81.05
97.89 + 82.16 + 81.68 + 80.45
97.94 + 85.61 + 82.12 + 78.76
98.36 + 86.74 + 82.32 + 82.66
Table 19:
85.29 + 64.23 + 66.61 + 59.25
85.42 + 65.54 + 68.41 + 59.46
85.41 + 66.21 + 68.47 + 59.56
85.67 + 66.89 + 68.89 + 60.39
86.04 + 67.35 + 69.71 + 64.66
86.52 + 67.99 + 69.92 + 65.64
87.03 + 67.97 + 71.06 + 67.19

Просмотреть файл

@ -5,6 +5,10 @@ import torchvision
from torchvision.models.resnet import BasicBlock, model_urls, Bottleneck
import os
class GroupNorm(torch.nn.GroupNorm):
def __init__(self, num_channels, num_groups=32, **kwargs):
super().__init__(num_groups, num_channels, **kwargs)
# bypass layer
class Identity(nn.Module):
def __init__(self,n_inputs):
@ -15,14 +19,20 @@ class Identity(nn.Module):
return x
def get_resnet(model_name, classes, fc_layer, num_ch, pre_trained, os_env):
def get_resnet(model_name, classes, fc_layer, num_ch, pre_trained, dp_noise, os_env):
if model_name == 'resnet18':
if os_env:
model= torchvision.models.resnet18()
if pre_trained:
model.load_state_dict(torch.load( os.getenv('PT_DATA_DIR') + '/checkpoints/resnet18-5c106cde.pth' ))
else:
model= torchvision.models.resnet18(pre_trained)
if dp_noise:
#TODO
#This cannot work with pre trained model, as the batch norm weights have been replaced by group norm
model= torchvision.models.resnet18(pre_trained, norm_layer=GroupNorm)
# model= torchvision.models.resnet18(pre_trained)
else:
model= torchvision.models.resnet18(pre_trained)
n_inputs = model.fc.in_features
n_outputs= classes
@ -33,7 +43,10 @@ def get_resnet(model_name, classes, fc_layer, num_ch, pre_trained, os_env):
if pre_trained:
model.load_state_dict(torch.load( os.getenv('PT_DATA_DIR') + '/checkpoints/resnet50-19c8e357.pth' ))
else:
model= torchvision.models.resnet50(pre_trained)
if dp_noise:
model= torchvision.models.resnet50(pre_trained, norm_layer=GroupNorm)
else:
model= torchvision.models.resnet50(pre_trained)
n_inputs = model.fc.in_features
n_outputs= classes

Просмотреть файл

Просмотреть файл

@ -88,7 +88,7 @@ for method in methods:
# To train and test on the same domains (In Distribution Generalization)
# train_domains= curr_test_domain
print('Method: ', method, ' Train Domains: ', train_domains, ' Test Domains: ', curr_test_domain)
print('Metric', metric, ' Method: ', method, ' Train Domains: ', train_domains, ' Test Domains: ', curr_test_domain)
script= script + ' --train_domains ' + train_domains + ' --test_domains ' + curr_test_domain
# 5 seeds to formally define the trends for the privacy part

Просмотреть файл

Просмотреть файл

@ -65,7 +65,8 @@ methods=['erm', 'rand', 'matchdg', 'csd', 'irm', 'perf']
# metrics= ['acc:train', 'acc:test', 'mia', 'privacy_entropy', 'privacy_loss_attack', 'match_score:train', 'match_score:test', 'feat_eval:train', 'feat_eval:test']
metrics= ['acc:train', 'acc:test', 'privacy_entropy', 'privacy_loss_attack', 'match_score:test']
# metrics= ['acc:train', 'acc:test', 'privacy_entropy', 'privacy_loss_attack', 'match_score:test']
metrics= ['privacy_loss_attack', 'privacy_entropy']
acc_train=[]
acc_train_err=[]
@ -133,7 +134,7 @@ for metric in metrics:
feat_eval_test.append(mean)
feat_eval_test_err.append(sd)
for idx in range(4):
for idx in range(1,2):
matplotlib.rcParams.update({'errorbar.capsize': 2})
fig, ax = plt.subplots(1, 1, figsize=(10, 8))
@ -183,4 +184,4 @@ for idx in range(4):
os.makedirs(save_dir)
plt.tight_layout()
plt.savefig(save_dir + 'privacy_' + str(dataset)+'_' + str(idx) + '.pdf', dpi=600)
plt.savefig(save_dir + 'privacy_' + str(dataset)+'_' + str(idx) + '.pdf', dpi=600)

Просмотреть файл

@ -92,7 +92,7 @@ if dataset == 'rot_mnist_spur':
for method in methods:
case= res_dir + str(method)
print('Method: ', method, ' Train Domains: ', train_case)
print('Metric', metric, 'Method: ', method, ' Train Domains: ', train_case)
if method == 'erm':
script= base_script + ' --method_name erm_match --penalty_ws 0.0 --match_case 0.0 --epochs 25 ' + ' > ' + case + '.txt'

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

16
test.py
Просмотреть файл

@ -135,6 +135,8 @@ parser.add_argument('--cuda_device', type=int, default=0,
help='Select the cuda device by id among the avaliable devices' )
parser.add_argument('--os_env', type=int, default=0,
help='0: Code execution on local server/machine; 1: Code execution in docker/clusters' )
parser.add_argument('--dp_noise', type=int, default=0,
help='0: No DP noise; 1: Add DP noise')
#MMD, DANN
parser.add_argument('--d_steps_per_g_step', type=int, default=1)
@ -174,9 +176,17 @@ test_domains= args.test_domains
#Initialize
final_metric_score=[]
base_res_dir=(
"results/" + args.dataset_name + '/' + args.method_name + '/' + args.match_layer
+ '/' + 'train_' + str(args.train_domains)
res_dir= 'results/'
if args.dp_noise:
base_res_dir=(
res_dir + args.dataset_name + '/' + 'dp_' + args.method_name + '/' + args.match_layer
+ '/' + 'train_' + str(args.train_domains)
)
else:
base_res_dir=(
res_dir + args.dataset_name + '/' + args.method_name + '/' + args.match_layer
+ '/' + 'train_' + str(args.train_domains)
)
#TODO: Handle slab noise case in helper functions

Просмотреть файл

@ -162,7 +162,14 @@ if args.os_env:
res_dir= os.getenv('PT_OUTPUT_DIR') + '/'
else:
res_dir= 'results/'
base_res_dir=(
if args.dp_noise:
base_res_dir=(
res_dir + args.dataset_name + '/' + 'dp_' + args.method_name + '/' + args.match_layer
+ '/' + 'train_' + str(args.train_domains)
)
else:
base_res_dir=(
res_dir + args.dataset_name + '/' + args.method_name + '/' + args.match_layer
+ '/' + 'train_' + str(args.train_domains)
)

Просмотреть файл

@ -247,7 +247,12 @@ def get_dataloader(args, run, domains, data_case, eval_case, kwargs):
match_func=False
# Can select a higher batch size for val and test domains
## TODO: If condition for test batch size less than total size
batch_size= 512
#Don't try higher batch size in the case of dp-noise trained models to avoid CUDA errors
if args.dp_noise:
batch_size= args.batch_size*5
else:
batch_size= 512
# Set match_func to True in case of test metric as match_score
try:
@ -285,6 +290,12 @@ def get_dataloader(args, run, domains, data_case, eval_case, kwargs):
mnist_subset=9
else:
mnist_subset=run
#TODO: Only Temporary, in order to see if it changes results on MNIST
# if eval_case:
# if args.test_metric in ['mia', 'privacy_entropy', 'privacy_loss_attack']:
# mnist_subset=run
print('MNIST Subset: ', mnist_subset)
data_obj= MnistRotated(args, domains, mnist_subset, '/mnist/', data_case=data_case, match_func=match_func)