DP ERM changes; code reorganization
This commit is contained in:
Родитель
d4faf8c445
Коммит
04db49e237
|
@ -19,6 +19,59 @@ import torch.utils.data as data_utils
|
|||
|
||||
from utils.match_function import get_matched_pairs
|
||||
|
||||
|
||||
def get_noise_multiplier(
|
||||
target_epsilon: float,
|
||||
target_delta: float,
|
||||
sample_rate: float,
|
||||
epochs: int,
|
||||
alphas: [float],
|
||||
sigma_min: float = 0.01,
|
||||
sigma_max: float = 10.0,
|
||||
) -> float:
|
||||
r"""
|
||||
Computes the noise level sigma to reach a total budget of (target_epsilon, target_delta)
|
||||
at the end of epochs, with a given sample_rate
|
||||
|
||||
Args:
|
||||
target_epsilon: the privacy budget's epsilon
|
||||
target_delta: the privacy budget's delta
|
||||
sample_rate: the sampling rate (usually batch_size / n_data)
|
||||
epochs: the number of epochs to run
|
||||
alphas: the list of orders at which to compute RDP
|
||||
|
||||
Returns:
|
||||
The noise level sigma to ensure privacy budget of (target_epsilon, target_delta)
|
||||
|
||||
"""
|
||||
|
||||
from opacus import privacy_analysis
|
||||
|
||||
eps = float("inf")
|
||||
while eps > target_epsilon:
|
||||
sigma_max = 2 * sigma_max
|
||||
rdp = privacy_analysis.compute_rdp(
|
||||
sample_rate, sigma_max, epochs / sample_rate, alphas
|
||||
)
|
||||
eps = privacy_analysis.get_privacy_spent(alphas, rdp, target_delta)[0]
|
||||
if sigma_max > 2000:
|
||||
raise ValueError("The privacy budget is too low.")
|
||||
|
||||
while sigma_max - sigma_min > 0.01:
|
||||
sigma = (sigma_min + sigma_max) / 2
|
||||
rdp = privacy_analysis.compute_rdp(
|
||||
sample_rate, sigma, epochs / sample_rate, alphas
|
||||
)
|
||||
eps = privacy_analysis.get_privacy_spent(alphas, rdp, target_delta)[0]
|
||||
|
||||
if eps < target_epsilon:
|
||||
sigma_max = sigma
|
||||
else:
|
||||
sigma_min = sigma
|
||||
|
||||
return sigma
|
||||
|
||||
|
||||
class BaseAlgo():
|
||||
def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, run, cuda):
|
||||
self.args= args
|
||||
|
@ -48,7 +101,8 @@ class BaseAlgo():
|
|||
self.val_acc=[]
|
||||
self.train_acc=[]
|
||||
|
||||
if self.args.method_name == 'dp_erm':
|
||||
# if self.args.method_name == 'dp_erm':
|
||||
if self.args.dp_noise:
|
||||
self.privacy_engine= self.get_dp_noise()
|
||||
|
||||
def get_model(self):
|
||||
|
@ -97,7 +151,7 @@ class BaseAlgo():
|
|||
else:
|
||||
fc_layer= self.args.fc_layer
|
||||
phi= get_resnet(self.args.model_name, self.args.out_classes, fc_layer,
|
||||
self.args.img_c, self.args.pre_trained, self.args.os_env)
|
||||
self.args.img_c, self.args.pre_trained, self.args.dp_noise, self.args.os_env)
|
||||
|
||||
if 'densenet' in self.args.model_name:
|
||||
from models.densenet import get_densenet
|
||||
|
@ -179,6 +233,11 @@ class BaseAlgo():
|
|||
return data_match_tensor, label_match_tensor, curr_batch_size
|
||||
|
||||
def get_test_accuracy(self, case):
|
||||
import opacus
|
||||
|
||||
if self.args.dp_noise:
|
||||
opacus.autograd_grad_sample.disable_hooks()
|
||||
#self.privacy_engine.module.disable_hooks()
|
||||
|
||||
#Test Env Code
|
||||
test_acc= 0.0
|
||||
|
@ -190,6 +249,10 @@ class BaseAlgo():
|
|||
|
||||
for batch_idx, (x_e, y_e ,d_e, idx_e, obj_e) in enumerate(dataset):
|
||||
with torch.no_grad():
|
||||
|
||||
self.opt.zero_grad()
|
||||
# print(x_e.shape)
|
||||
# print(torch.cuda.memory_allocated())
|
||||
x_e= x_e.to(self.cuda)
|
||||
y_e= torch.argmax(y_e, dim=1).to(self.cuda)
|
||||
|
||||
|
@ -198,9 +261,15 @@ class BaseAlgo():
|
|||
|
||||
test_acc+= torch.sum( torch.argmax(out, dim=1) == y_e ).item()
|
||||
test_size+= y_e.shape[0]
|
||||
|
||||
# To avoid CUDA memory issues
|
||||
if self.args.dp_noise:
|
||||
self.opt.zero_grad()
|
||||
|
||||
print(' Accuracy: ', case, 100*test_acc/test_size )
|
||||
|
||||
|
||||
#self.privacy_engine.module.enable_hooks()
|
||||
opacus.autograd_grad_sample.enable_hooks()
|
||||
return 100*test_acc/test_size
|
||||
|
||||
def get_dp_noise(self):
|
||||
|
@ -212,23 +281,32 @@ class BaseAlgo():
|
|||
from opacus.utils import module_modification
|
||||
|
||||
inspector = DPModelInspector()
|
||||
print(self.phi)
|
||||
self.phi = module_modification.convert_batchnorm_modules(self.phi)
|
||||
print(self.phi)
|
||||
# print(self.phi)
|
||||
# self.phi = module_modification.convert_batchnorm_modules(self.phi)
|
||||
inspector.validate(self.phi)
|
||||
|
||||
MAX_GRAD_NORM = 1.2
|
||||
EPSILON = 50.0
|
||||
NOISE_MULTIPLIER = .38
|
||||
MAX_GRAD_NORM = 10.0
|
||||
# NOISE_MULTIPLIER = 0.8
|
||||
# NOISE_MULTIPLIER = 1.46
|
||||
# NOISE_MULTIPLIER = 1.15
|
||||
# NOISE_MULTIPLIER = 0.7
|
||||
NOISE_MULTIPLIER = 0.0
|
||||
DELTA = 1.0/(self.total_domains*self.domain_size)
|
||||
BATCH_SIZE = self.args.batch_size
|
||||
VIRTUAL_BATCH_SIZE = 2*BATCH_SIZE
|
||||
BATCH_SIZE = self.args.batch_size * self.total_domains
|
||||
VIRTUAL_BATCH_SIZE = 10*BATCH_SIZE
|
||||
assert VIRTUAL_BATCH_SIZE % BATCH_SIZE == 0 # VIRTUAL_BATCH_SIZE should be divisible by BATCH_SIZE
|
||||
N_ACCUMULATION_STEPS = int(VIRTUAL_BATCH_SIZE / BATCH_SIZE)
|
||||
SAMPLE_RATE = BATCH_SIZE /(self.total_domains*self.domain_size)
|
||||
DEFAULT_ALPHAS = [1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64))
|
||||
|
||||
print(BATCH_SIZE, SAMPLE_RATE, N_ACCUMULATION_STEPS, SAMPLE_RATE*N_ACCUMULATION_STEPS)
|
||||
|
||||
print(f"Using sigma={NOISE_MULTIPLIER} and C={MAX_GRAD_NORM}")
|
||||
|
||||
|
||||
# epsilon=20.0
|
||||
# print('Value of Noise Multiplier Needed')
|
||||
# print(get_noise_multiplier(epsilon, DELTA, SAMPLE_RATE, self.args.epochs, DEFAULT_ALPHAS))
|
||||
# sys.exit(-1)
|
||||
from opacus import PrivacyEngine
|
||||
# privacy_engine = PrivacyEngine(
|
||||
# self.phi,
|
||||
|
@ -236,10 +314,17 @@ class BaseAlgo():
|
|||
# alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
|
||||
# noise_multiplier=NOISE_MULTIPLIER,
|
||||
# max_grad_norm=MAX_GRAD_NORM,
|
||||
# )
|
||||
# privacy_engine = PrivacyEngine(
|
||||
# self.phi,
|
||||
# sample_rate=SAMPLE_RATE * N_ACCUMULATION_STEPS,
|
||||
# noise_multiplier=NOISE_MULTIPLIER,
|
||||
# max_grad_norm=MAX_GRAD_NORM,
|
||||
# )
|
||||
privacy_engine = PrivacyEngine(
|
||||
self.phi,
|
||||
sample_rate=SAMPLE_RATE * N_ACCUMULATION_STEPS,
|
||||
batch_size= BATCH_SIZE,
|
||||
sample_size= self.total_domains*self.domain_size,
|
||||
noise_multiplier=NOISE_MULTIPLIER,
|
||||
max_grad_norm=MAX_GRAD_NORM,
|
||||
)
|
||||
|
|
|
@ -26,8 +26,8 @@ class ErmMatch(BaseAlgo):
|
|||
|
||||
def train(self):
|
||||
|
||||
self.max_epoch=-1
|
||||
self.max_val_acc=0.0
|
||||
self.max_epoch= -1
|
||||
self.max_val_acc= 0.0
|
||||
for epoch in range(self.args.epochs):
|
||||
|
||||
if epoch ==0:
|
||||
|
@ -44,8 +44,6 @@ class ErmMatch(BaseAlgo):
|
|||
|
||||
#Batch iteration over single epoch
|
||||
for batch_idx, (x_e, y_e ,d_e, idx_e, obj_e) in enumerate(self.train_dataset):
|
||||
# print('Batch Idx: ', batch_idx)
|
||||
|
||||
self.opt.zero_grad()
|
||||
loss_e= torch.tensor(0.0).to(self.cuda)
|
||||
|
||||
|
@ -53,8 +51,6 @@ class ErmMatch(BaseAlgo):
|
|||
y_e= torch.argmax(y_e, dim=1).to(self.cuda)
|
||||
d_e= torch.argmax(d_e, dim=1).numpy()
|
||||
|
||||
#Forward Pass
|
||||
out= self.phi(x_e)
|
||||
|
||||
wasserstein_loss=torch.tensor(0.0).to(self.cuda)
|
||||
erm_loss= torch.tensor(0.0).to(self.cuda)
|
||||
|
@ -70,6 +66,7 @@ class ErmMatch(BaseAlgo):
|
|||
data_match= data_match_tensor.to(self.cuda)
|
||||
data_match= data_match.flatten(start_dim=0, end_dim=1)
|
||||
feat_match= self.phi( data_match )
|
||||
# print(feat_match.shape)
|
||||
|
||||
label_match= label_match_tensor.to(self.cuda)
|
||||
label_match= torch.squeeze( label_match.flatten(start_dim=0, end_dim=1) )
|
||||
|
@ -111,7 +108,9 @@ class ErmMatch(BaseAlgo):
|
|||
|
||||
loss_e.backward(retain_graph=False)
|
||||
self.opt.step()
|
||||
# self.opt.zero_grad()
|
||||
|
||||
# del out
|
||||
del erm_loss
|
||||
del wasserstein_loss
|
||||
del loss_e
|
||||
|
|
|
@ -74,7 +74,7 @@ class MatchDG(BaseAlgo):
|
|||
if 'resnet' in self.args.ctr_model_name:
|
||||
from models.resnet import get_resnet
|
||||
fc_layer=0
|
||||
ctr_phi= get_resnet(self.args.ctr_model_name, self.args.out_classes, fc_layer, self.args.img_c, self.args.pre_trained, self.args.os_env).to(self.cuda)
|
||||
ctr_phi= get_resnet(self.args.ctr_model_name, self.args.out_classes, fc_layer, self.args.img_c, self.args.pre_trained, self.args.dp_noise, self.args.os_env).to(self.cuda)
|
||||
|
||||
if 'densenet' in self.args.ctr_model_name:
|
||||
from models.densenet import get_densenet
|
||||
|
|
|
@ -248,7 +248,7 @@ for seed in seed_list:
|
|||
indices= res[:subset_size]
|
||||
|
||||
if model == 'resnet18':
|
||||
if seed in [9] and domain in [0, 90]:
|
||||
if seed in [0, 1, 2, 9] and domain in [0, 90]:
|
||||
generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices, domain, save_dir, img_w, img_h)
|
||||
elif model in ['lenet', 'lenet_mdg']:
|
||||
if seed in [0, 1, 2] and domain in [0, 15, 30, 45, 60, 75]:
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -36,11 +36,12 @@ class SpurCorrDataLoader(data_utils.Dataset):
|
|||
def __init__(self, dataloader):
|
||||
super(SpurCorrDataLoader, self).__init__()
|
||||
|
||||
self.x= dataloader.train_data
|
||||
self.y= dataloader.train_labels
|
||||
self.d= dataloader.train_domain
|
||||
self.indices= dataloader.train_indices
|
||||
self.spur_corr= dataloader.train_spur
|
||||
self.x= dataloader.data
|
||||
self.y= dataloader.labels
|
||||
self.d= dataloader.domains
|
||||
self.indices= dataloader.indices
|
||||
self.objects= dataloader.objects
|
||||
self.spur_corr= dataloader.spur
|
||||
|
||||
|
||||
def __len__(self):
|
||||
|
@ -51,9 +52,10 @@ class SpurCorrDataLoader(data_utils.Dataset):
|
|||
batch_y = self.y[index]
|
||||
batch_d = self.d[index]
|
||||
batch_idx = self.indices[index]
|
||||
batch_obj= self.objects[index]
|
||||
batch_spur= self.spur_corr[index]
|
||||
|
||||
return batch_x, batch_y, batch_d, batch_idx, batch_spur
|
||||
return batch_x, batch_y, batch_d, batch_idx, batch_obj, batch_spur
|
||||
|
||||
class AttributeAttack(BaseEval):
|
||||
|
||||
|
|
|
@ -103,7 +103,7 @@ class BaseEval():
|
|||
else:
|
||||
fc_layer= self.args.fc_layer
|
||||
phi= get_resnet(self.args.model_name, self.args.out_classes, fc_layer,
|
||||
self.args.img_c, self.args.pre_trained, self.args.os_env)
|
||||
self.args.img_c, self.args.pre_trained, self.args.dp_noise, self.args.os_env)
|
||||
|
||||
if 'densenet' in self.args.model_name:
|
||||
from models.densenet import get_densenet
|
||||
|
|
|
@ -112,6 +112,7 @@ class PrivacyEntropy(BaseEval):
|
|||
print(case, attack_data['logits'].shape, attack_data['labels'].shape, attack_data['members'].shape)
|
||||
|
||||
return attack_data
|
||||
|
||||
|
||||
def eval_entropy_attack(self, data, threshold_data, scale=1.0, case='train'):
|
||||
|
||||
|
@ -141,6 +142,8 @@ class PrivacyEntropy(BaseEval):
|
|||
# print('F_y, F_i', F_y.shape, F_i.shape)
|
||||
# print('Neg term: ', (F_i*torch.log(1.0-F_i)).shape, F_i[0])
|
||||
metric= -1*(1.0 - F_y)*torch.log(F_y) -1*torch.sum( F_i*torch.log(1.0-F_i), dim=1 )
|
||||
# metric= -1*(1.0 - F_y)*torch.log(F_y)
|
||||
# metric= -1*(1.0)*torch.log(F_y)
|
||||
|
||||
# threshold_data[y_c]= torch.max(metric)
|
||||
threshold_data[y_c]= torch.mean(metric)
|
||||
|
@ -166,6 +169,8 @@ class PrivacyEntropy(BaseEval):
|
|||
F_y= torch.sum( logits*labels, dim=1)
|
||||
F_i= logits*(1.0-labels)
|
||||
metric= -1*(1.0 - F_y)*torch.log(F_y) -1*torch.sum( F_i*torch.log(1.0-F_i), dim=1 )
|
||||
# metric= -1*(1.0 - F_y)*torch.log(F_y)
|
||||
# metric= -1*(1.0)*torch.log(F_y)
|
||||
|
||||
mem_predict= 1.0*(metric < (threshold_data[y_c]/scale))
|
||||
acc+= torch.sum( mem_predict == members ).item()
|
||||
|
|
|
@ -48,12 +48,13 @@ class PrivacyLossAttack(BaseEval):
|
|||
train_data={}
|
||||
train_data['loss']=[]
|
||||
train_data['labels']=[]
|
||||
train_data['obj']= []
|
||||
for batch_idx, (x_e, y_e ,d_e, idx_e, obj_e) in enumerate(self.train_dataset['data_loader']):
|
||||
#Random Shuffling along the batch axis
|
||||
rand_indices= torch.randperm(x_e.size()[0])
|
||||
x_e= x_e[rand_indices]
|
||||
y_e= y_e[rand_indices]
|
||||
|
||||
obj_e= obj_e[rand_indices]
|
||||
with torch.no_grad():
|
||||
x_e= x_e.to(self.cuda)
|
||||
y_e= y_e.to(self.cuda)
|
||||
|
@ -62,19 +63,24 @@ class PrivacyLossAttack(BaseEval):
|
|||
loss= cross_entropy(out, torch.argmax(y_e, dim=1).long()).to(self.cuda)
|
||||
train_data['loss'].append(loss)
|
||||
train_data['labels'].append(y_e)
|
||||
train_data['obj'].append(obj_e)
|
||||
|
||||
train_data['loss']= torch.cat(train_data['loss'], dim=0)
|
||||
train_data['labels']= torch.cat(train_data['labels'], dim=0)
|
||||
train_data['obj']= torch.cat(train_data['obj'], dim=0)
|
||||
|
||||
#Test Environment Data
|
||||
test_data={}
|
||||
test_data['loss']=[]
|
||||
test_data['labels']=[]
|
||||
test_data['obj']=[]
|
||||
test_data['free']=[]
|
||||
for batch_idx, (x_e, y_e ,d_e, idx_e, obj_e) in enumerate(self.test_dataset['data_loader']):
|
||||
#Random Shuffling along the batch axis
|
||||
rand_indices= torch.randperm(x_e.size()[0])
|
||||
x_e= x_e[rand_indices]
|
||||
y_e= y_e[rand_indices]
|
||||
obj_e= obj_e[rand_indices]
|
||||
|
||||
with torch.no_grad():
|
||||
x_e= x_e.to(self.cuda)
|
||||
|
@ -85,13 +91,20 @@ class PrivacyLossAttack(BaseEval):
|
|||
|
||||
test_data['loss'].append(loss)
|
||||
test_data['labels'].append(y_e)
|
||||
test_data['obj'].append(obj_e)
|
||||
test_data['free']= test_data['free'] + [1]*out.shape[0]
|
||||
|
||||
test_data['loss']= torch.cat(test_data['loss'], dim=0)
|
||||
test_data['labels']= torch.cat(test_data['labels'], dim=0)
|
||||
test_data['obj']= torch.cat(test_data['obj'], dim=0)
|
||||
test_data['free']= torch.tensor(test_data['free'])
|
||||
|
||||
print('Train Logits: ', train_data['loss'].shape, 'Train Labels: ', train_data['labels'].shape, ' Train Objs: ', train_data['obj'].shape )
|
||||
print(torch.unique(train_data['obj']))
|
||||
print('Test Logits: ', test_data['loss'].shape, 'Test Labels: ', test_data['labels'].shape, ' Test Objs: ', test_data['obj'].shape )
|
||||
print(torch.unique(test_data['obj']))
|
||||
print('Test Free: ', test_data['free'].shape, test_data['free'])
|
||||
|
||||
print('Train Logits: ', train_data['loss'].shape, 'Train Labels: ', train_data['labels'].shape )
|
||||
print('Test Logits: ', test_data['loss'].shape, 'Test Labels: ', test_data['labels'].shape )
|
||||
|
||||
return train_data, test_data
|
||||
|
||||
def create_attack_data(self, train_data, test_data, sample_size, case='train'):
|
||||
|
@ -102,6 +115,7 @@ class PrivacyLossAttack(BaseEval):
|
|||
|
||||
test_loss= test_data['loss'][:sample_size]
|
||||
test_labels= test_data['labels'][:sample_size]
|
||||
|
||||
elif case == 'test':
|
||||
train_loss= train_data['loss'][-1-sample_size:-1]
|
||||
train_labels= train_data['labels'][-1-sample_size:-1]
|
||||
|
@ -116,6 +130,67 @@ class PrivacyLossAttack(BaseEval):
|
|||
print(case, attack_data['loss'].shape, attack_data['labels'].shape, attack_data['members'].shape)
|
||||
|
||||
return attack_data
|
||||
|
||||
|
||||
# def create_attack_data(self, train_data, test_data, sample_size, case='train'):
|
||||
|
||||
# if case == 'train':
|
||||
# train_loss= train_data['loss'][:sample_size]
|
||||
# train_labels= train_data['labels'][:sample_size]
|
||||
# train_obj= train_data['obj'][:sample_size]
|
||||
|
||||
# test_loss= []
|
||||
# test_labels= []
|
||||
# for idx in range(sample_size):
|
||||
# obj= train_obj[idx]
|
||||
# indice= (test_data['obj'] == obj).nonzero()
|
||||
|
||||
# for idx_obj in range(indice.shape[0]):
|
||||
# curr_indice= indice[idx_obj, 0].item()
|
||||
# if test_data['free'][curr_indice] == 1:
|
||||
# test_loss.append(test_data['loss'][curr_indice].view(1))
|
||||
# #TODO: Change 10 to num_classes
|
||||
# test_labels.append(test_data['labels'][curr_indice].view(1, 10))
|
||||
# test_data['free'][curr_indice]= 0
|
||||
# break
|
||||
|
||||
# test_loss= torch.cat(test_loss, dim=0)
|
||||
# test_labels= torch.cat(test_labels, dim=0)
|
||||
|
||||
# elif case == 'test':
|
||||
# train_loss= train_data['loss'][-1-sample_size:-1]
|
||||
# train_labels= train_data['labels'][-1-sample_size:-1]
|
||||
# train_obj= train_data['obj'][-1-sample_size:-1]
|
||||
|
||||
# test_loss= []
|
||||
# test_labels= []
|
||||
# for idx in range(sample_size):
|
||||
# obj= train_obj[idx]
|
||||
# indice= (test_data['obj'] == obj).nonzero()
|
||||
|
||||
# for idx_obj in range(indice.shape[0]):
|
||||
# curr_indice= indice[idx_obj, 0].item()
|
||||
# if test_data['free'][curr_indice] == 1:
|
||||
# test_loss.append(test_data['loss'][curr_indice].view(1))
|
||||
# #TODO: Change 10 to num_classes
|
||||
# test_labels.append(test_data['labels'][curr_indice].view(1, 10))
|
||||
# test_data['free'][curr_indice]= 0
|
||||
# break
|
||||
|
||||
# test_loss= torch.cat(test_loss, dim=0)
|
||||
# test_labels= torch.cat(test_labels, dim=0)
|
||||
|
||||
# print('Attack Dataset Members: ', train_loss.shape, train_labels.shape)
|
||||
# print('Attack Dataset Non Members: ', test_loss.shape, test_labels.shape)
|
||||
# attack_data={}
|
||||
# attack_data['loss']= torch.cat( (train_loss, test_loss), dim=0 )
|
||||
# attack_data['labels']= torch.cat( (train_labels, test_labels), dim=0 )
|
||||
# # attack_data['members']= torch.cat( (torch.ones((sample_size,1)), torch.zeros((sample_size,1))), dim=0).to(self.cuda)
|
||||
# attack_data['members']= torch.cat( (torch.ones((train_loss.shape[0], 1)), torch.zeros((test_loss.shape[0],1))), dim=0).to(self.cuda)
|
||||
|
||||
# print(case, attack_data['loss'].shape, attack_data['labels'].shape, attack_data['members'].shape)
|
||||
|
||||
# return attack_data
|
||||
|
||||
def eval_entropy_attack(self, data, threshold_data, scale=1.0, case='train'):
|
||||
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
import sys
|
||||
import numpy as np
|
||||
import argparse
|
||||
import copy
|
||||
import random
|
||||
import json
|
||||
|
||||
import torch
|
||||
from torch.autograd import grad
|
||||
from torch import nn, optim
|
||||
from torch.nn import functional as F
|
||||
from torchvision import datasets, transforms
|
||||
from torchvision.utils import save_image
|
||||
from torch.autograd import Variable
|
||||
import torch.utils.data as data_utils
|
||||
|
||||
from .base_eval import BaseEval
|
||||
from utils.match_function import get_matched_pairs, perfect_match_score
|
||||
from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity
|
||||
|
||||
|
||||
def sim_matrix(a, b, eps=1e-8):
|
||||
"""
|
||||
added eps for numerical stability
|
||||
"""
|
||||
a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
|
||||
a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
|
||||
b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
|
||||
sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
|
||||
return sim_mt
|
||||
|
||||
class SlabFeatEval(BaseEval):
|
||||
|
||||
def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, run, cuda):
|
||||
super().__init__(args, train_dataset, val_dataset, test_dataset, base_res_dir, run, cuda)
|
||||
|
||||
def get_metric_eval(self):
|
||||
|
||||
if self.args.match_func_data_case=='train':
|
||||
dataset= self.train_dataset['data_loader']
|
||||
total_domains= self.train_dataset['total_domains']
|
||||
domain_list= self.train_dataset['domain_list']
|
||||
base_domain_size= self.train_dataset['base_domain_size']
|
||||
domain_size_list= self.train_dataset['domain_size_list']
|
||||
|
||||
elif self.args.match_func_data_case== 'val':
|
||||
dataset= self.val_dataset['data_loader']
|
||||
total_domains= self.val_dataset['total_domains']
|
||||
domain_list= self.val_dataset['domain_list']
|
||||
base_domain_size= self.val_dataset['base_domain_size']
|
||||
domain_size_list= self.val_dataset['domain_size_list']
|
||||
|
||||
elif self.args.match_func_data_case== 'test':
|
||||
dataset= self.test_dataset['data_loader']
|
||||
total_domains= self.test_dataset['total_domains']
|
||||
domain_list= self.test_dataset['domain_list']
|
||||
base_domain_size= self.test_dataset['base_domain_size']
|
||||
domain_size_list= self.test_dataset['domain_size_list']
|
||||
|
||||
pos_metric= 'cos'
|
||||
with torch.no_grad():
|
||||
|
||||
penalty_ws=0
|
||||
batch_size_counter=0
|
||||
#Batch iteration over single epoch
|
||||
for batch_idx, (x_e, y_e ,d_e, idx_e) in enumerate(dataset):
|
||||
|
||||
x_e= x_e.to(self.cuda)
|
||||
y_e= torch.argmax(y_e, dim=1).to(self.cuda)
|
||||
#Forward Pass
|
||||
out= self.phi(x_e)
|
||||
|
||||
match_objs= np.unique(idx_e)
|
||||
feat= self.phi.feat_net(x_e)
|
||||
for obj in match_objs:
|
||||
indices= idx_e == obj
|
||||
feat_obj= feat[indices]
|
||||
d_obj= d_e[indices]
|
||||
|
||||
match_domains= torch.unique(d_obj)
|
||||
|
||||
if len(match_domains) != len(torch.unique(d_e)):
|
||||
# print('Error: Positivty Violation, objects not present in all the domains')
|
||||
continue
|
||||
|
||||
for d_i in range(len(match_domains)):
|
||||
for d_j in range(len(match_domains)):
|
||||
if d_j <= d_i:
|
||||
continue
|
||||
x1= feat_obj[ d_obj == d_i ]
|
||||
x2= feat_obj[ d_obj == d_j ]
|
||||
|
||||
if pos_metric == 'l2':
|
||||
x1= x1.view(x1.shape[0], 1, x1.shape[1])
|
||||
penalty_ws+= float( torch.sum( torch.sum( torch.sum( (x1 -x2)**2, dim=2), dim=1 )) )
|
||||
elif pos_metric == 'l1':
|
||||
x1= x1.view(x1.shape[0], 1, x1.shape[1])
|
||||
penalty_ws+= float( torch.sum( torch.sum( torch.sum( torch.abs(x1 -x2), dim=2), dim=1 )) )
|
||||
elif pos_metric == 'cos':
|
||||
penalty_ws+= float( torch.sum( torch.sum( sim_matrix(x1, x2), dim=1)) )
|
||||
|
||||
batch_size_counter+= x1.shape[0]*x2.shape[0]
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.metric_score['Perfect Match Distance']= penalty_ws/batch_size_counter
|
||||
print('Perfect Match Distance: ', self.metric_score['Perfect Match Distance'])
|
||||
return
|
|
@ -0,0 +1,255 @@
|
|||
#Common imports
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
import argparse
|
||||
import copy
|
||||
import random
|
||||
import json
|
||||
import pickle
|
||||
|
||||
#Pytorch
|
||||
import torch
|
||||
from torch.autograd import grad
|
||||
from torch import nn, optim
|
||||
from torch.nn import functional as F
|
||||
from torchvision import datasets, transforms
|
||||
from torchvision.utils import save_image
|
||||
from torch.autograd import Variable
|
||||
import torch.utils.data as data_utils
|
||||
|
||||
#Sklearn
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
#robustdg
|
||||
from utils.helper import *
|
||||
from utils.match_function import *
|
||||
|
||||
#slab
|
||||
from utils.slab_data import *
|
||||
import utils.scripts.utils as slab_utils
|
||||
import utils.scripts.lms_utils as slab_lms_utils
|
||||
|
||||
def get_logits(model, loader, device, label=1):
|
||||
X, Y = slab_utils.extract_tensors_from_loader(loader)
|
||||
L = slab_utils.get_logits_given_tensor(X, model, device=device).detach()
|
||||
L = L[Y==label].cpu().numpy()
|
||||
S = L[:, 1] - L[:, 0] # compute score / difference to get scalar
|
||||
return S
|
||||
|
||||
|
||||
# Input Parsing
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--dataset_name', type=str, default='slab',
|
||||
help='Datasets: rot_mnist; fashion_mnist; pacs')
|
||||
parser.add_argument('--method_name', type=str, default='erm_match',
|
||||
help=' Training Algorithm: erm_match; matchdg_ctr; matchdg_erm')
|
||||
parser.add_argument('--model_name', type=str, default='slab',
|
||||
help='Architecture of the model to be trained')
|
||||
parser.add_argument('--train_domains', nargs='+', type=str, default=["15", "30", "45", "60", "75"],
|
||||
help='List of train domains')
|
||||
parser.add_argument('--test_domains', nargs='+', type=str, default=["0", "90"],
|
||||
help='List of test domains')
|
||||
parser.add_argument('--out_classes', type=int, default=2,
|
||||
help='Total number of classes in the dataset')
|
||||
parser.add_argument('--img_c', type=int, default= 1,
|
||||
help='Number of channels of the image in dataset')
|
||||
parser.add_argument('--img_h', type=int, default= 224,
|
||||
help='Height of the image in dataset')
|
||||
parser.add_argument('--img_w', type=int, default= 224,
|
||||
help='Width of the image in dataset')
|
||||
parser.add_argument('--slab_data_dim', type=int, default= 2,
|
||||
help='Number of features in the slab dataset')
|
||||
parser.add_argument('--slab_total_slabs', type=int, default=7)
|
||||
parser.add_argument('--slab_num_samples', type=int, default=1000)
|
||||
parser.add_argument('--fc_layer', type=int, default= 1,
|
||||
help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet')
|
||||
parser.add_argument('--match_layer', type=str, default='logit_match',
|
||||
help='rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
|
||||
parser.add_argument('--pos_metric', type=str, default='l2',
|
||||
help='Cost to function to evaluate distance between two representations; Options: l1; l2; cos')
|
||||
parser.add_argument('--rep_dim', type=int, default=250,
|
||||
help='Representation dimension for contrsative learning')
|
||||
parser.add_argument('--pre_trained',type=int, default=0,
|
||||
help='0: No Pretrained Architecture; 1: Pretrained Architecture')
|
||||
parser.add_argument('--perfect_match', type=int, default=1,
|
||||
help='0: No perfect match known (PACS); 1: perfect match known (MNIST)')
|
||||
parser.add_argument('--opt', type=str, default='sgd',
|
||||
help='Optimizer Choice: sgd; adam')
|
||||
parser.add_argument('--lr', type=float, default=0.01,
|
||||
help='Learning rate for training the model')
|
||||
parser.add_argument('--batch_size', type=int, default=16,
|
||||
help='Batch size foe training the model')
|
||||
parser.add_argument('--epochs', type=int, default=15,
|
||||
help='Total number of epochs for training the model')
|
||||
parser.add_argument('--penalty_w', type=float, default=0.0,
|
||||
help='Penalty weight for IRM invariant classifier loss')
|
||||
parser.add_argument('--penalty_s', type=int, default=-1,
|
||||
help='Epoch threshold over which Matching Loss to be optimised')
|
||||
parser.add_argument('--penalty_irm', type=float, default=0.0,
|
||||
help='Penalty weight for IRM invariant classifier loss')
|
||||
parser.add_argument('--penalty_ws', type=float, default=0.1,
|
||||
help='Penalty weight for Matching Loss')
|
||||
parser.add_argument('--penalty_diff_ctr',type=float, default=1.0,
|
||||
help='Penalty weight for Contrastive Loss')
|
||||
parser.add_argument('--tau', type=float, default=0.05,
|
||||
help='Temperature hyper param for NTXent contrastive loss ')
|
||||
parser.add_argument('--match_flag', type=int, default=0,
|
||||
help='0: No Update to Match Strategy; 1: Updates to Match Strategy')
|
||||
parser.add_argument('--match_case', type=float, default=1.0,
|
||||
help='0: Random Match; 1: Perfect Match. 0.x" x% correct Match')
|
||||
parser.add_argument('--match_interrupt', type=int, default=5,
|
||||
help='Number of epochs before inferring the match strategy')
|
||||
parser.add_argument('--n_runs', type=int, default=3,
|
||||
help='Number of iterations to repeat the training process')
|
||||
parser.add_argument('--n_runs_matchdg_erm', type=int, default=1,
|
||||
help='Number of iterations to repeat training process for matchdg_erm')
|
||||
parser.add_argument('--ctr_model_name', type=str, default='resnet18',
|
||||
help='(For matchdg_ctr phase) Architecture of the model to be trained')
|
||||
parser.add_argument('--ctr_match_layer', type=str, default='logit_match',
|
||||
help='(For matchdg_ctr phase) rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
|
||||
parser.add_argument('--ctr_match_flag', type=int, default=1,
|
||||
help='(For matchdg_ctr phase) 0: No Update to Match Strategy; 1: Updates to Match Strategy')
|
||||
parser.add_argument('--ctr_match_case', type=float, default=0.01,
|
||||
help='(For matchdg_ctr phase) 0: Random Match; 1: Perfect Match. 0.x" x% correct Match')
|
||||
parser.add_argument('--ctr_match_interrupt', type=int, default=5,
|
||||
help='(For matchdg_ctr phase) Number of epochs before inferring the match strategy')
|
||||
parser.add_argument('--mnist_seed', type=int, default=0,
|
||||
help='Change it between 0-6 for different subsets of Mnist and Fashion Mnist dataset')
|
||||
parser.add_argument('--retain', type=float, default=0,
|
||||
help='0: Train from scratch in MatchDG Phase 2; 2: Finetune from MatchDG Phase 1 in MatchDG is Phase 2')
|
||||
parser.add_argument('--test_metric', type=str, default='acc',
|
||||
help='Evaluation Metrics: acc; match_score, t_sne, mia')
|
||||
parser.add_argument('--acc_data_case', type=str, default='test',
|
||||
help='Dataset Train/Val/Test for the accuracy evaluation metric')
|
||||
parser.add_argument('--top_k', type=int, default=10,
|
||||
help='Top K matches to consider for the match score evaluation metric')
|
||||
parser.add_argument('--match_func_aug_case', type=int, default=0,
|
||||
help='0: Evaluate match func on train domains; 1: Evaluate match func on self augmentations')
|
||||
parser.add_argument('--match_func_data_case', type=str, default='train',
|
||||
help='Dataset Train/Val/Test for the match score evaluation metric')
|
||||
parser.add_argument('--mia_batch_size', default=64, type=int,
|
||||
help='batch size')
|
||||
parser.add_argument('--mia_dnn_steps', default=5000, type=int,
|
||||
help='number of training steps')
|
||||
parser.add_argument('--mia_sample_size', default=1000, type=int,
|
||||
help='number of samples from train/test dataset logits')
|
||||
parser.add_argument('--mia_logit', default=1, type=int,
|
||||
help='0: Softmax applied to logits; 1: No Softmax applied to logits')
|
||||
parser.add_argument('--attribute_domain', default=1, type=int,
|
||||
help='0: spur correlations as attribute; 1: domain as attribute')
|
||||
parser.add_argument('--adv_eps', default=0.3, type=float,
|
||||
help='Epsilon ball dimension for PGD attacks')
|
||||
parser.add_argument('--logit_plot_path', default='', type=str,
|
||||
help='File name to save logit/loss plots')
|
||||
parser.add_argument('--ctr_abl', type=int, default=0,
|
||||
help='0: Randomization til class level ; 1: Randomization completely')
|
||||
parser.add_argument('--match_abl', type=int, default=0,
|
||||
help='0: Randomization til class level ; 1: Randomization completely')
|
||||
parser.add_argument('--cuda_device', type=int, default=0,
|
||||
help='Select the cuda device by id among the avaliable devices' )
|
||||
parser.add_argument('--os_env', type=int, default=0,
|
||||
help='0: Code execution on local server/machine; 1: Code execution in docker/clusters' )
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
#GPU
|
||||
cuda= torch.device("cuda:" + str(args.cuda_device))
|
||||
if cuda:
|
||||
kwargs = {'num_workers': 1, 'pin_memory': False}
|
||||
else:
|
||||
kwargs= {}
|
||||
|
||||
args.kwargs= kwargs
|
||||
|
||||
#List of Train; Test domains
|
||||
train_domains= args.train_domains
|
||||
test_domains= args.test_domains
|
||||
|
||||
#Initialize
|
||||
final_acc= []
|
||||
final_auc= []
|
||||
final_s_auc= []
|
||||
final_sc_auc= []
|
||||
base_res_dir=(
|
||||
"results/" + args.dataset_name + '/' + args.method_name + '/' + args.match_layer
|
||||
+ '/' + 'train_' + str(args.train_domains)
|
||||
)
|
||||
if not os.path.exists(base_res_dir):
|
||||
os.makedirs(base_res_dir)
|
||||
|
||||
#Checks
|
||||
if args.method_name == 'matchdg_ctr' and args.test_metric == 'acc':
|
||||
raise ValueError('Match DG during the contrastive learning phase cannot be evaluted for test accuracy metric')
|
||||
sys.exit()
|
||||
|
||||
if args.perfect_match == 0 and args.test_metric == 'match_score' and args.match_func_aug_case==0:
|
||||
raise ValueError('Cannot evalute match function metrics when perfect match is not known')
|
||||
sys.exit()
|
||||
|
||||
#Execute the method for multiple runs ( total args.n_runs )
|
||||
for run in range(args.n_runs):
|
||||
|
||||
#Seed for reproducability
|
||||
np.random.seed(run*10)
|
||||
torch.manual_seed(run*10)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(run*10)
|
||||
|
||||
#DataLoader
|
||||
train_dataset= torch.empty(0)
|
||||
val_dataset= torch.empty(0)
|
||||
test_dataset= torch.empty(0)
|
||||
# if args.test_metric == 'match_score':
|
||||
# if args.match_func_data_case== 'train':
|
||||
# train_dataset= get_dataloader( args, run, train_domains, 'train', 1, kwargs )
|
||||
# elif args.match_func_data_case== 'val':
|
||||
# val_dataset= get_dataloader( args, run, train_domains, 'val', 1, kwargs )
|
||||
# elif args.match_func_data_case== 'test':
|
||||
# test_dataset= get_dataloader( args, run, test_domains, 'test', 1, kwargs )
|
||||
# elif args.test_metric == 'acc':
|
||||
# if args.acc_data_case== 'train':
|
||||
# train_dataset= get_dataloader( args, run, train_domains, 'train', 1, kwargs )
|
||||
# elif args.acc_data_case== 'test':
|
||||
# test_dataset= get_dataloader( args, run, test_domains, 'test', 1, kwargs )
|
||||
# elif args.test_metric in ['mia', 'privacy_entropy', 'privacy_loss_attack']:
|
||||
# train_dataset= get_dataloader( args, run, train_domains, 'train', 1, kwargs )
|
||||
# test_dataset= get_dataloader( args, run, test_domains, 'test', 1, kwargs )
|
||||
# elif args.test_metric == 'attribute_attack':
|
||||
# train_dataset= get_dataloader( args, run, train_domains + test_domains, 'train', 1, kwargs )
|
||||
# test_dataset= get_dataloader( args, run, train_domains + test_domains, 'test', 1, kwargs )
|
||||
# else:
|
||||
# test_dataset= get_dataloader( args, run, test_domains, 'test', 1, kwargs )
|
||||
|
||||
# print('Train Domains, Domain Size, BaseDomainIdx, Total Domains: ', train_domains, total_domains, domain_size, training_list_size)
|
||||
|
||||
#Import the testing module
|
||||
from evaluation.base_eval import BaseEval
|
||||
test_method= BaseEval(
|
||||
args, train_dataset, val_dataset,
|
||||
test_dataset, base_res_dir,
|
||||
run, cuda
|
||||
)
|
||||
|
||||
test_method.get_model()
|
||||
model= test_method.phi
|
||||
|
||||
for train_domain in train_domains:
|
||||
spur_prob= float(train_domain)
|
||||
data, temp1, _, _= get_data(args.slab_num_samples, spur_prob, args.slab_total_slabs)
|
||||
|
||||
# compute logit scores
|
||||
std_log = get_logits(model, data['te_dl'], cuda)
|
||||
break
|
||||
|
||||
# plot logit distributions
|
||||
kw = dict(kde=False, bins=20, norm_hist=True,
|
||||
hist_kws={"histtype": "step", "linewidth": 2,
|
||||
"alpha": 0.8, "ls": '-'})
|
||||
|
||||
fig, ax = plt.subplots(1,1,figsize=(6,4))
|
||||
ax = sns.distplot(std_log, label='Standard Logits', **kw)
|
||||
|
||||
slab_utils.update_ax(ax, 'Logit Distributions of Positive Data', 'Logits', 'Density',
|
||||
ticks_fs=13, label_fs=13, title_fs=16, legend_fs=14, legend_loc='upper left')
|
||||
plt.savefig('results/slab_train_logit_plot/' + str(args.method_name)+ '_' + str(args.penalty_ws) + '_' + str(run) + '.jpg')
|
|
@ -0,0 +1,36 @@
|
|||
#Entropy Attack
|
||||
|
||||
python test.py --test_metric privacy_entropy --mia_sample_size 1000 --batch_size 64 --dataset chestxray --method_name erm_match --match_case 0.01 --penalty_ws 0 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
|
||||
|
||||
python test.py --test_metric privacy_entropy --mia_sample_size 1000 --batch_size 64 --dataset chestxray --method_name erm_match --match_case 0.01 --penalty_ws 10 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
|
||||
|
||||
python test.py --test_metric privacy_entropy --mia_sample_size 1000 --batch_size 64 --dataset chestxray --method_name csd --match_case 0.01 --penalty_ws 0 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
|
||||
|
||||
python test.py --test_metric privacy_entropy --mia_sample_size 1000 --batch_size 64 --dataset chestxray --method_name irm --match_case 0.01 --penalty_s 5 --penalty_ws 10.0 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
|
||||
|
||||
python test.py --test_metric privacy_entropy --mia_sample_size 1000 --batch_size 64 --dataset chestxray --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --penalty_ws 50.0 --match_case -1 --train_domains nih_trans chex_trans --test_domains kaggle
|
||||
|
||||
#MIA Attack
|
||||
|
||||
python test.py --test_metric mia --mia_sample_size 1000 --mia_logit 1 --batch_size 64 --dataset chestxray --method_name erm_match --match_case 0.01 --penalty_ws 0 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
|
||||
|
||||
python test.py --test_metric mia --mia_sample_size 1000 --batch_size 64 --dataset chestxray --method_name erm_match --match_case 0.01 --penalty_ws 10 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
|
||||
|
||||
python test.py --test_metric mia --mia_sample_size 1000 --batch_size 64 --dataset chestxray --method_name csd --match_case 0.01 --penalty_ws 0 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
|
||||
|
||||
python test.py --test_metric mia --mia_sample_size 1000 --batch_size 64 --dataset chestxray --method_name irm --match_case 0.01 --penalty_s 5 --penalty_ws 10.0 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
|
||||
|
||||
python test.py --test_metric mia --mia_sample_size 1000 --batch_size 64 --dataset chestxray --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --penalty_ws 50.0 --match_case -1 --train_domains nih_trans chex_trans --test_domains kaggle
|
||||
|
||||
|
||||
#Mean Rank
|
||||
|
||||
python test.py --test_metric match_score --batch_size 64 --dataset chestxray --match_func_data_case test --match_func_aug_case 1 --method_name erm_match --match_case 0.01 --penalty_ws 0 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
|
||||
|
||||
python test.py --test_metric match_score --batch_size 64 --dataset chestxray --match_func_data_case test --match_func_aug_case 1 --method_name erm_match --match_case 0.01 --penalty_ws 10.0 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
|
||||
|
||||
python test.py --test_metric match_score --batch_size 64 --dataset chestxray --match_func_data_case test --match_func_aug_case 1 --method_name csd --match_case 0.01 --penalty_ws 0 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
|
||||
|
||||
python test.py --test_metric match_score --batch_size 64 --dataset chestxray --match_func_data_case test --match_func_aug_case 1 --method_name irm --match_case 0.01 --penalty_s 5 --penalty_ws 10.0 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
|
||||
|
||||
python test.py --test_metric match_score --batch_size 64 --dataset chestxray --match_func_data_case test --match_func_aug_case 1 --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121 --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --penalty_ws 50.0 --match_case -1 --train_domains nih_trans chex_trans --test_domains kaggle
|
|
@ -0,0 +1,104 @@
|
|||
Table 1:
|
||||
|
||||
python train.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --epochs 50 --batch_size 64 --pos_metric cos
|
||||
|
||||
python train.py --dataset rot_mnist --method_name matchdg_erm --match_case -1 --penalty_ws 0.1 --epochs 25 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18
|
||||
|
||||
|
||||
Table 2: (Match func evaluation on train domains)
|
||||
|
||||
python test.py --dataset rot_mnist --method_name erm_match --match_case 0.0 --penalty_ws 0.0 --test_metric match_score
|
||||
|
||||
python test.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --pos_metric cos --test_metric match_score
|
||||
|
||||
Table 6: (LeNet)
|
||||
|
||||
#MatchDG
|
||||
python train.py --dataset rot_mnist --method_name matchdg_erm --penalty_ws 1.0 --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --n_runs 3 --epochs 100 --lr 0.01 --model_name lenet --img_h 32 --img_w 32 --train_domains 15 30 45 60 75 --test_domains 0
|
||||
|
||||
#Perf
|
||||
python train.py --dataset rot_mnist --method_name erm_match --match_case 1.0 --penalty_ws 1.0 --epochs 100 --model_name lenet --img_h 32 --img_w 32 --train_domains 0 15 30 45 60 --test_domains 75
|
||||
|
||||
Table 7: (DomainBed)
|
||||
|
||||
#MatchDG
|
||||
python train.py --dataset rot_mnist --method_name matchdg_erm --penalty_ws 1.0 --match_case -1 --batch_size 64 --epochs 25 --model_name domain_bed_mnist --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --n_runs 3 --img_h 28 --img_w 28 --train_domains 15 30 45 60 75 --test_domains 0
|
||||
|
||||
#Perf
|
||||
python train.py --dataset rot_mnist --method_name erm_match --match_case 1.0 --penalty_ws 1.0 --batch_size 64 --epochs 25 --model_name domain_bed_mnist --img_h 28 --img_w 28 --train_domains 15 30 45 60 75 --test_domains 0
|
||||
|
||||
Table 8: (Phase 2 Rankings)
|
||||
|
||||
#RandMatch
|
||||
python test.py --dataset rot_mnist --method_name erm_match --match_case 0.01 --penalty_ws 0.1 --test_metric match_score
|
||||
|
||||
#MatchDG
|
||||
python test.py --dataset rot_mnist --method_name matchdg_erm --penalty_ws 0.1 --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --n_runs 3 --test_metric match_score
|
||||
|
||||
#PerfMatch
|
||||
python test.py --dataset rot_mnist --method_name erm_match --match_case 1.0 --penalty_ws 0.1 --test_metric match_score
|
||||
|
||||
|
||||
Chexpert: (Acc Evaluation)
|
||||
|
||||
#Matching Methods
|
||||
python test.py --dataset chestxray --method_name erm_match --match_case 0.01 --penalty_ws 0 --test_metric acc --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains nih_trans
|
||||
|
||||
#CSD
|
||||
python test.py --dataset chestxray --method_name csd --match_case 0.01 --penalty_ws 0 --test_metric acc --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
|
||||
|
||||
#IRM
|
||||
python test.py --dataset chestxray --method_name irm --match_case 0.01 --penalty_s 5 --penalty_ws 10.0 --test_metric acc --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --train_domains nih_trans chex_trans --test_domains kaggle
|
||||
|
||||
|
||||
python test.py --dataset chestxray --method_name hybrid --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121 --test_metric acc --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --penalty_ws 0.0 --match_case 0.01 --train_domains nih_trans chex_trans --test_domains chex_tran
|
||||
|
||||
#MatchDG
|
||||
python test.py --dataset chestxray --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121 --test_metric acc --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --penalty_ws 50.0 --match_case -1 --train_domains nih_trans chex_trans --test_domains kaggle
|
||||
|
||||
#Hybrid
|
||||
python test.py --dataset chestxray --method_name hybrid --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121 --test_metric acc --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --penalty_ws 1.0 --match_case -1 --train_domains nih_trans chex_trans --test_domains kaggle_trans
|
||||
|
||||
PACS: (Acc Evaluation)
|
||||
|
||||
#Matching Methods
|
||||
python test.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --train_domains art_painting cartoon sketch --test_domains photo
|
||||
|
||||
#MatchDG
|
||||
|
||||
python test.py --dataset pacs --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --penalty_ws 0.1 --match_case -1 --train_domains art_painting cartoon sketch --test_domains photo
|
||||
|
||||
python test.py --dataset pacs --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet50 --penalty_ws 0.1 --match_case -1 --train_domains art_painting cartoon sketch --test_domains photo
|
||||
|
||||
python test.py --dataset pacs --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --penalty_ws 0.5 --match_case -1 --train_domains photo cartoon sketch --test_domains art_painting
|
||||
|
||||
python test.py --dataset pacs --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet50 --penalty_ws 0.1 --match_case -1 --train_domains photo cartoon sketch --test_domains art_painting
|
||||
|
||||
python test.py --dataset pacs --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --penalty_ws 1.0 --match_case -1 --train_domains photo art_painting sketch --test_domains cartoon
|
||||
|
||||
python test.py --dataset pacs --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet50 --penalty_ws
|
||||
1.0 --match_case -1 --train_domains photo art_painting sketch --test_domains cartoon
|
||||
|
||||
python test.py --dataset pacs --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --penalty_ws 0.5 --match_case -1 --train_domains photo art_painting cartoon --test_domains sketch
|
||||
|
||||
python test.py --dataset pacs --method_name matchdg_erm --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet50 --penalty_ws
|
||||
0.5 --match_case -1 --train_domains photo art_painting cartoon --test_domains sketch
|
||||
|
||||
#Hybrid
|
||||
|
||||
python test.py --dataset pacs --method_name hybrid --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --penalty_ws 0.1 --match_case -1 --train_domains art_painting cartoon sketch --test_domains photo
|
||||
|
||||
python test.py --dataset pacs --method_name hybrid --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet50 --penalty_ws 0.1 --match_case -1 --train_domains art_painting cartoon sketch --test_domains photo
|
||||
|
||||
python test.py --dataset pacs --method_name hybrid --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --penalty_ws 0.01 --match_case -1 --train_domains photo cartoon sketch --test_domains art_painting
|
||||
|
||||
python test.py --dataset pacs --method_name hybrid --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet50 --penalty_ws 0.01 --match_case -1 --train_domains photo cartoon sketch --test_domains art_painting
|
||||
|
||||
python test.py --dataset pacs --method_name hybrid --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --penalty_ws 0.01
|
||||
--match_case -1 --train_domains photo art_painting sketch --test_domains cartoon
|
||||
|
||||
python test.py --dataset pacs --method_name hybrid --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet50 --penalty_ws 0.01 --match_case -1 --train_domains photo art_painting sketch --test_domains cartoon
|
||||
|
||||
python test.py --dataset pacs --method_name hybrid --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --penalty_ws 0.5 --match_case -1 --train_domains photo art_painting cartoon --test_domains sketch
|
||||
|
||||
python test.py --dataset pacs --method_name hybrid --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --test_metric acc --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet50 --penalty_ws 0.01 --match_case -1 --train_domains photo art_painting cartoon --test_domains sketch
|
|
@ -0,0 +1,3 @@
|
|||
python3 train.py --dataset adult --model fc --out_classes 2 --train_domains male female --test_domains male female --penalty_irm 100.0 --method_name irm --penalty_s -1
|
||||
|
||||
python test.py --test_metric attribute_attack --mia_logit 1 --batch_size 64 --dataset adult --model fc --out_classes 2 --train_domains male female --test_domains male female --penalty_ws 10.0
|
|
@ -0,0 +1,44 @@
|
|||
M2ExZ
|
||||
|
||||
|
||||
In the domain generalization literature, a common objective is to learn representations independent of the domain after conditioning on the class label. We show that this objective is not sufficient: there exist counter-examples where a model fails to generalize to unseen domains even after satisfying class-conditional domain invariance. We formalize this observation through a structural causal model and show the importance of modeling within-class variations for generalization. Specifically, classes contain objects that characterize specific causal features, and domains can be interpreted as interventions on these objects that change non-causal features. We highlight an alternative condition: inputs across domains should have the same representation if they are derived from the same object. Based on this objective, we propose matching-based algorithms when base objects are observed (e.g., through data augmentation) and approximate the objective when objects are not observed (MatchDG). Our simple matching-based algorithms are competitive to prior work on out-of-domain accuracy for rotated MNIST, Fashion-MNIST, PACS, and Chest-Xray datasets. Our method MatchDG also recovers ground-truth object matches: on MNIST and Fashion-MNIST, top-10 matches from MatchDG have over 50% overlap with ground-truth matches.
|
||||
|
||||
Table 11:
|
||||
|
||||
88.2 + 98.6 + 97.7 + 97.5 + 97.0 + 85.6
|
||||
91.0 + 99.7 + 99.6 + 99.4 + 99.7 + 93.1
|
||||
93.0 + 99.5 + 99.9 + 99.4 + 99.7 + 93.3
|
||||
96.5 + 99.1 + 99.2 + 98.6 + 98.6 + 94.9
|
||||
|
||||
Table 12:
|
||||
|
||||
95.4 + 98.2 + 97.9 + 98.5 + 98.1 + 94.3
|
||||
95.9 + 98.4 + 98.6 + 98.9 + 98.7 + 95.1
|
||||
|
||||
Table 18:
|
||||
|
||||
95.38 + 77.68 + 78.98 + 74.75
|
||||
|
||||
95.37 + 78.16 + 78.83 + 75.13
|
||||
95.93 + 79.77 + 80.03 + 77.11
|
||||
96.15 + 81.71 + 80.75 + 78.79
|
||||
|
||||
95.57 + 79.09 + 79.37 + 77.60
|
||||
96.53 + 81.32 + 80.70 + 79.72
|
||||
96.67 + 82.80 + 81.61 + 81.05
|
||||
|
||||
97.89 + 82.16 + 81.68 + 80.45
|
||||
97.94 + 85.61 + 82.12 + 78.76
|
||||
98.36 + 86.74 + 82.32 + 82.66
|
||||
|
||||
Table 19:
|
||||
|
||||
85.29 + 64.23 + 66.61 + 59.25
|
||||
|
||||
85.42 + 65.54 + 68.41 + 59.46
|
||||
85.41 + 66.21 + 68.47 + 59.56
|
||||
85.67 + 66.89 + 68.89 + 60.39
|
||||
|
||||
86.04 + 67.35 + 69.71 + 64.66
|
||||
86.52 + 67.99 + 69.92 + 65.64
|
||||
87.03 + 67.97 + 71.06 + 67.19
|
|
@ -5,6 +5,10 @@ import torchvision
|
|||
from torchvision.models.resnet import BasicBlock, model_urls, Bottleneck
|
||||
import os
|
||||
|
||||
class GroupNorm(torch.nn.GroupNorm):
|
||||
def __init__(self, num_channels, num_groups=32, **kwargs):
|
||||
super().__init__(num_groups, num_channels, **kwargs)
|
||||
|
||||
# bypass layer
|
||||
class Identity(nn.Module):
|
||||
def __init__(self,n_inputs):
|
||||
|
@ -15,14 +19,20 @@ class Identity(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
def get_resnet(model_name, classes, fc_layer, num_ch, pre_trained, os_env):
|
||||
def get_resnet(model_name, classes, fc_layer, num_ch, pre_trained, dp_noise, os_env):
|
||||
if model_name == 'resnet18':
|
||||
if os_env:
|
||||
model= torchvision.models.resnet18()
|
||||
if pre_trained:
|
||||
model.load_state_dict(torch.load( os.getenv('PT_DATA_DIR') + '/checkpoints/resnet18-5c106cde.pth' ))
|
||||
else:
|
||||
model= torchvision.models.resnet18(pre_trained)
|
||||
if dp_noise:
|
||||
#TODO
|
||||
#This cannot work with pre trained model, as the batch norm weights have been replaced by group norm
|
||||
model= torchvision.models.resnet18(pre_trained, norm_layer=GroupNorm)
|
||||
# model= torchvision.models.resnet18(pre_trained)
|
||||
else:
|
||||
model= torchvision.models.resnet18(pre_trained)
|
||||
|
||||
n_inputs = model.fc.in_features
|
||||
n_outputs= classes
|
||||
|
@ -33,7 +43,10 @@ def get_resnet(model_name, classes, fc_layer, num_ch, pre_trained, os_env):
|
|||
if pre_trained:
|
||||
model.load_state_dict(torch.load( os.getenv('PT_DATA_DIR') + '/checkpoints/resnet50-19c8e357.pth' ))
|
||||
else:
|
||||
model= torchvision.models.resnet50(pre_trained)
|
||||
if dp_noise:
|
||||
model= torchvision.models.resnet50(pre_trained, norm_layer=GroupNorm)
|
||||
else:
|
||||
model= torchvision.models.resnet50(pre_trained)
|
||||
|
||||
n_inputs = model.fc.in_features
|
||||
n_outputs= classes
|
||||
|
|
|
@ -88,7 +88,7 @@ for method in methods:
|
|||
# To train and test on the same domains (In Distribution Generalization)
|
||||
# train_domains= curr_test_domain
|
||||
|
||||
print('Method: ', method, ' Train Domains: ', train_domains, ' Test Domains: ', curr_test_domain)
|
||||
print('Metric', metric, ' Method: ', method, ' Train Domains: ', train_domains, ' Test Domains: ', curr_test_domain)
|
||||
script= script + ' --train_domains ' + train_domains + ' --test_domains ' + curr_test_domain
|
||||
|
||||
# 5 seeds to formally define the trends for the privacy part
|
|
@ -65,7 +65,8 @@ methods=['erm', 'rand', 'matchdg', 'csd', 'irm', 'perf']
|
|||
|
||||
# metrics= ['acc:train', 'acc:test', 'mia', 'privacy_entropy', 'privacy_loss_attack', 'match_score:train', 'match_score:test', 'feat_eval:train', 'feat_eval:test']
|
||||
|
||||
metrics= ['acc:train', 'acc:test', 'privacy_entropy', 'privacy_loss_attack', 'match_score:test']
|
||||
# metrics= ['acc:train', 'acc:test', 'privacy_entropy', 'privacy_loss_attack', 'match_score:test']
|
||||
metrics= ['privacy_loss_attack', 'privacy_entropy']
|
||||
|
||||
acc_train=[]
|
||||
acc_train_err=[]
|
||||
|
@ -133,7 +134,7 @@ for metric in metrics:
|
|||
feat_eval_test.append(mean)
|
||||
feat_eval_test_err.append(sd)
|
||||
|
||||
for idx in range(4):
|
||||
for idx in range(1,2):
|
||||
|
||||
matplotlib.rcParams.update({'errorbar.capsize': 2})
|
||||
fig, ax = plt.subplots(1, 1, figsize=(10, 8))
|
||||
|
@ -183,4 +184,4 @@ for idx in range(4):
|
|||
os.makedirs(save_dir)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(save_dir + 'privacy_' + str(dataset)+'_' + str(idx) + '.pdf', dpi=600)
|
||||
plt.savefig(save_dir + 'privacy_' + str(dataset)+'_' + str(idx) + '.pdf', dpi=600)
|
|
@ -92,7 +92,7 @@ if dataset == 'rot_mnist_spur':
|
|||
for method in methods:
|
||||
case= res_dir + str(method)
|
||||
|
||||
print('Method: ', method, ' Train Domains: ', train_case)
|
||||
print('Metric', metric, 'Method: ', method, ' Train Domains: ', train_case)
|
||||
|
||||
if method == 'erm':
|
||||
script= base_script + ' --method_name erm_match --penalty_ws 0.0 --match_case 0.0 --epochs 25 ' + ' > ' + case + '.txt'
|
16
test.py
16
test.py
|
@ -135,6 +135,8 @@ parser.add_argument('--cuda_device', type=int, default=0,
|
|||
help='Select the cuda device by id among the avaliable devices' )
|
||||
parser.add_argument('--os_env', type=int, default=0,
|
||||
help='0: Code execution on local server/machine; 1: Code execution in docker/clusters' )
|
||||
parser.add_argument('--dp_noise', type=int, default=0,
|
||||
help='0: No DP noise; 1: Add DP noise')
|
||||
|
||||
#MMD, DANN
|
||||
parser.add_argument('--d_steps_per_g_step', type=int, default=1)
|
||||
|
@ -174,9 +176,17 @@ test_domains= args.test_domains
|
|||
|
||||
#Initialize
|
||||
final_metric_score=[]
|
||||
base_res_dir=(
|
||||
"results/" + args.dataset_name + '/' + args.method_name + '/' + args.match_layer
|
||||
+ '/' + 'train_' + str(args.train_domains)
|
||||
|
||||
res_dir= 'results/'
|
||||
if args.dp_noise:
|
||||
base_res_dir=(
|
||||
res_dir + args.dataset_name + '/' + 'dp_' + args.method_name + '/' + args.match_layer
|
||||
+ '/' + 'train_' + str(args.train_domains)
|
||||
)
|
||||
else:
|
||||
base_res_dir=(
|
||||
res_dir + args.dataset_name + '/' + args.method_name + '/' + args.match_layer
|
||||
+ '/' + 'train_' + str(args.train_domains)
|
||||
)
|
||||
|
||||
#TODO: Handle slab noise case in helper functions
|
||||
|
|
9
train.py
9
train.py
|
@ -162,7 +162,14 @@ if args.os_env:
|
|||
res_dir= os.getenv('PT_OUTPUT_DIR') + '/'
|
||||
else:
|
||||
res_dir= 'results/'
|
||||
base_res_dir=(
|
||||
|
||||
if args.dp_noise:
|
||||
base_res_dir=(
|
||||
res_dir + args.dataset_name + '/' + 'dp_' + args.method_name + '/' + args.match_layer
|
||||
+ '/' + 'train_' + str(args.train_domains)
|
||||
)
|
||||
else:
|
||||
base_res_dir=(
|
||||
res_dir + args.dataset_name + '/' + args.method_name + '/' + args.match_layer
|
||||
+ '/' + 'train_' + str(args.train_domains)
|
||||
)
|
||||
|
|
|
@ -247,7 +247,12 @@ def get_dataloader(args, run, domains, data_case, eval_case, kwargs):
|
|||
match_func=False
|
||||
# Can select a higher batch size for val and test domains
|
||||
## TODO: If condition for test batch size less than total size
|
||||
batch_size= 512
|
||||
|
||||
#Don't try higher batch size in the case of dp-noise trained models to avoid CUDA errors
|
||||
if args.dp_noise:
|
||||
batch_size= args.batch_size*5
|
||||
else:
|
||||
batch_size= 512
|
||||
|
||||
# Set match_func to True in case of test metric as match_score
|
||||
try:
|
||||
|
@ -285,6 +290,12 @@ def get_dataloader(args, run, domains, data_case, eval_case, kwargs):
|
|||
mnist_subset=9
|
||||
else:
|
||||
mnist_subset=run
|
||||
|
||||
#TODO: Only Temporary, in order to see if it changes results on MNIST
|
||||
# if eval_case:
|
||||
# if args.test_metric in ['mia', 'privacy_entropy', 'privacy_loss_attack']:
|
||||
# mnist_subset=run
|
||||
|
||||
print('MNIST Subset: ', mnist_subset)
|
||||
data_obj= MnistRotated(args, domains, mnist_subset, '/mnist/', data_case=data_case, match_func=match_func)
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче