# robustdg/test.py
#
# 402 lines
# 19 KiB
# Python

#Common imports
import os
import sys
import numpy as np
import argparse
import copy
import random
import json
import pickle
#Sklearn
import sklearn
from sklearn.manifold import TSNE
#Pytorch
import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils
#robustdg
from utils.helper import *
from utils.match_function import *
# Input Parsing
# NOTE: argparse applies %-formatting to help strings, so a literal percent
# sign must be written as '%%' or `--help` raises a TypeError.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='rot_mnist',
                    help='Datasets: rot_mnist; fashion_mnist; pacs; slab')
parser.add_argument('--method_name', type=str, default='erm_match',
                    help='Training Algorithm, e.g.: erm_match; matchdg_ctr; matchdg_erm; hybrid')
parser.add_argument('--model_name', type=str, default='resnet18',
                    help='Architecture of the model to be trained')
parser.add_argument('--train_domains', nargs='+', type=str, default=["15", "30", "45", "60", "75"],
                    help='List of train domains')
parser.add_argument('--test_domains', nargs='+', type=str, default=["0", "90"],
                    help='List of test domains')
parser.add_argument('--out_classes', type=int, default=10,
                    help='Total number of classes in the dataset')
parser.add_argument('--img_c', type=int, default=1,
                    help='Number of channels of the image in dataset')
parser.add_argument('--img_h', type=int, default=224,
                    help='Height of the image in dataset')
parser.add_argument('--img_w', type=int, default=224,
                    help='Width of the image in dataset')
parser.add_argument('--fc_layer', type=int, default=1,
                    help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet')
parser.add_argument('--match_layer', type=str, default='logit_match',
                    help='rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
parser.add_argument('--pos_metric', type=str, default='l2',
                    help='Cost function to evaluate the distance between two representations; Options: l1; l2; cos')
parser.add_argument('--rep_dim', type=int, default=250,
                    help='Representation dimension for contrastive learning')
parser.add_argument('--pre_trained', type=int, default=0,
                    help='0: No Pretrained Architecture; 1: Pretrained Architecture')
parser.add_argument('--perfect_match', type=int, default=1,
                    help='0: No perfect match known (PACS); 1: perfect match known (MNIST)')
parser.add_argument('--opt', type=str, default='sgd',
                    help='Optimizer Choice: sgd; adam')
parser.add_argument('--weight_decay', type=float, default=5e-4,
                    help='Weight Decay in SGD')
parser.add_argument('--lr', type=float, default=0.01,
                    help='Learning rate for training the model')
parser.add_argument('--batch_size', type=int, default=16,
                    help='Batch size for training the model')
parser.add_argument('--epochs', type=int, default=15,
                    help='Total number of epochs for training the model')
parser.add_argument('--penalty_s', type=int, default=-1,
                    help='Epoch threshold over which the Matching Loss is optimised')
parser.add_argument('--penalty_irm', type=float, default=0.0,
                    help='Penalty weight for IRM invariant classifier loss')
parser.add_argument('--penalty_aug', type=float, default=1.0,
                    help='Penalty weight for Augmentation in Hybrid approach loss')
parser.add_argument('--penalty_ws', type=float, default=0.1,
                    help='Penalty weight for Matching Loss')
parser.add_argument('--penalty_diff_ctr', type=float, default=1.0,
                    help='Penalty weight for Contrastive Loss')
parser.add_argument('--tau', type=float, default=0.05,
                    help='Temperature hyper param for NTXent contrastive loss')
parser.add_argument('--match_flag', type=int, default=0,
                    help='0: No Update to Match Strategy; 1: Updates to Match Strategy')
parser.add_argument('--match_case', type=float, default=1.0,
                    help='0: Random Match; 1: Perfect Match; 0.x: x%% correct Match')
parser.add_argument('--match_interrupt', type=int, default=5,
                    help='Number of epochs before inferring the match strategy')
parser.add_argument('--ctr_abl', type=int, default=0,
                    help='0: Randomization until class level; 1: Randomization completely')
parser.add_argument('--match_abl', type=int, default=0,
                    help='0: Randomization until class level; 1: Randomization completely')
parser.add_argument('--n_runs', type=int, default=3,
                    help='Number of iterations to repeat the training process')
parser.add_argument('--n_runs_matchdg_erm', type=int, default=1,
                    help='Number of iterations to repeat training process for matchdg_erm')
parser.add_argument('--ctr_model_name', type=str, default='resnet18',
                    help='(For matchdg_ctr phase) Architecture of the model to be trained')
parser.add_argument('--ctr_match_layer', type=str, default='logit_match',
                    help='(For matchdg_ctr phase) rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
parser.add_argument('--ctr_match_flag', type=int, default=1,
                    help='(For matchdg_ctr phase) 0: No Update to Match Strategy; 1: Updates to Match Strategy')
parser.add_argument('--ctr_match_case', type=float, default=0.01,
                    help='(For matchdg_ctr phase) 0: Random Match; 1: Perfect Match; 0.x: x%% correct Match')
parser.add_argument('--ctr_match_interrupt', type=int, default=5,
                    help='(For matchdg_ctr phase) Number of epochs before inferring the match strategy')
parser.add_argument('--mnist_seed', type=int, default=0,
                    help='Change it between 0-6 for different subsets of Mnist and Fashion Mnist dataset')
parser.add_argument('--retain', type=float, default=0,
                    help='0: Train from scratch in MatchDG Phase 2; 2: Finetune from MatchDG Phase 1 in MatchDG Phase 2')
parser.add_argument('--cuda_device', type=int, default=0,
                    help='Select the cuda device by id among the available devices')
parser.add_argument('--os_env', type=int, default=0,
                    help='0: Code execution on local server/machine; 1: Code execution in docker/clusters')
# Differential Privacy
parser.add_argument('--dp_noise', type=int, default=0,
                    help='0: No DP noise; 1: Add DP noise')
parser.add_argument('--dp_epsilon', type=float, default=1.0,
                    help='Epsilon value for Differential Privacy')
parser.add_argument('--dp_attach_opt', type=int, default=1,
                    help='0: Infinite Epsilon; 1: Finite Epsilon')
# MMD, DANN
parser.add_argument('--d_steps_per_g_step', type=int, default=1)
parser.add_argument('--grad_penalty', type=float, default=0.0)
parser.add_argument('--conditional', type=int, default=1)
parser.add_argument('--gaussian', type=int, default=1)
# Slab Dataset
parser.add_argument('--slab_data_dim', type=int, default=2,
                    help='Number of features in the slab dataset')
parser.add_argument('--slab_total_slabs', type=int, default=7)
parser.add_argument('--slab_num_samples', type=int, default=1000)
parser.add_argument('--slab_noise', type=float, default=0.1)
# Differentiate between resnet, lenet, domainbed cases of mnist
parser.add_argument('--mnist_case', type=str, default='resnet18',
                    help='MNIST Dataset Case: resnet18; lenet, domainbed')
parser.add_argument('--mnist_aug', type=int, default=0,
                    help='MNIST Data Augmentation: 0 (MNIST, FMNIST Privacy Evaluation); 1 (FMNIST)')
# Multiple random matches
parser.add_argument('--total_matches_per_point', type=int, default=1,
                    help='Multiple random matches')
# Evaluation specific
parser.add_argument('--test_metric', type=str, default='acc',
                    help='Evaluation Metrics: acc; per_domain_acc; match_score; feat_eval; slab_feat_eval; '
                         't_sne; mia; attribute_attack; privacy_loss_attack; privacy_entropy; logit_hist; adv_attack')
parser.add_argument('--acc_data_case', type=str, default='test',
                    help='Dataset Train/Val/Test for the accuracy evaluation metric')
parser.add_argument('--top_k', type=int, default=10,
                    help='Top K matches to consider for the match score evaluation metric')
parser.add_argument('--match_func_aug_case', type=int, default=0,
                    help='0: Evaluate match func on train domains; 1: Evaluate match func on self augmentations')
parser.add_argument('--match_func_data_case', type=str, default='train',
                    help='Dataset Train/Val/Test for the match score evaluation metric')
parser.add_argument('--mia_batch_size', default=64, type=int,
                    help='batch size')
parser.add_argument('--mia_dnn_steps', default=5000, type=int,
                    help='number of training steps')
parser.add_argument('--mia_sample_size', default=1000, type=int,
                    help='number of samples from train/test dataset logits')
parser.add_argument('--mia_logit', default=1, type=int,
                    help='0: Softmax applied to logits; 1: No Softmax applied to logits')
parser.add_argument('--attribute_domain', default=1, type=int,
                    help='0: spur correlations as attribute; 1: domain as attribute')
parser.add_argument('--adv_eps', default=0.3, type=float,
                    help='Epsilon ball dimension for PGD attacks')
parser.add_argument('--logit_plot_path', default='', type=str,
                    help='File name to save logit/loss plots')
args = parser.parse_args()
# GPU
# A torch.device object is always truthy, so testing `if cuda:` can never
# select the CPU path. Decide from actual CUDA availability instead, and fall
# back to a CPU device so the script does not target a non-existent GPU.
use_cuda = torch.cuda.is_available()
cuda = torch.device("cuda:" + str(args.cuda_device) if use_cuda else "cpu")
# Extra DataLoader options, only used when running on a GPU.
kwargs = {'num_workers': 1, 'pin_memory': False} if use_cuda else {}
args.kwargs = kwargs
# Train and test domain lists, taken directly from the command line.
train_domains, test_domains = args.train_domains, args.test_domains
# Accumulates one metric_score entry per evaluated model, across all runs.
final_metric_score = list()
# Results layout:
#   results/<dataset>/[dp_<eps>_]<method>/<match_layer>/train_<train_domains>[/slab_noise_<noise>]
res_dir = 'results/'
if args.dp_noise:
    method_dir = 'dp_' + str(args.dp_epsilon) + '_' + args.method_name
else:
    method_dir = args.method_name
base_res_dir = res_dir + '/'.join([
    args.dataset_name,
    method_dir,
    args.match_layer,
    'train_' + str(args.train_domains),
])

# TODO: Handle slab noise case in helper functions
if args.dataset_name == 'slab':
    base_res_dir = base_res_dir + '/slab_noise_' + str(args.slab_noise)

# Print only after the slab suffix is applied so the reported path is the one
# actually used.
print('Result Base Dir: ', base_res_dir)
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(base_res_dir, exist_ok=True)
# Checks: reject metric/method combinations that cannot be evaluated.
# (`raise` already terminates the script, so no sys.exit() is needed after it.)
if args.method_name == 'matchdg_ctr' and args.test_metric == 'acc':
    raise ValueError('Match DG during the contrastive learning phase cannot be evaluated for test accuracy metric')
if args.perfect_match == 0 and args.test_metric == 'match_score' and args.match_func_aug_case == 0:
    # Scoring matches on the train domains needs the ground-truth (perfect) matches.
    raise ValueError('Cannot evaluate match function metrics when perfect match is not known')
# Evaluator class for each supported --test_metric, given as
# (module path, class name). Modules are imported lazily, so only the selected
# evaluator's dependencies are loaded, and importing evaluation.t_sne.TSNE no
# longer shadows sklearn's TSNE in this module's namespace.
_EVALUATORS = {
    'acc': ('evaluation.base_eval', 'BaseEval'),
    'per_domain_acc': ('evaluation.per_domain_acc', 'PerDomainAcc'),
    'match_score': ('evaluation.match_eval', 'MatchEval'),
    'feat_eval': ('evaluation.feat_eval', 'FeatEval'),
    'slab_feat_eval': ('evaluation.slab_feat_eval', 'SlabFeatEval'),
    't_sne': ('evaluation.t_sne', 'TSNE'),
    'mia': ('evaluation.privacy_attack', 'PrivacyAttack'),
    'attribute_attack': ('evaluation.attribute_attack', 'AttributeAttack'),
    'privacy_loss_attack': ('evaluation.privacy_loss_attack', 'PrivacyLossAttack'),
    'privacy_entropy': ('evaluation.privacy_entropy', 'PrivacyEntropy'),
    'logit_hist': ('evaluation.logit_hist', 'LogitHist'),
    'adv_attack': ('evaluation.adv_attack', 'AdvAttack'),
}

# Metric families that share a data-loading strategy.
_MATCH_FUNC_METRICS = ('match_score', 'feat_eval', 'slab_feat_eval')
_ACC_METRICS = ('acc', 'per_domain_acc')
_PRIVACY_METRICS = ('mia', 'privacy_entropy', 'privacy_loss_attack')


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and every CUDA device) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _load_datasets(args, run, train_domains, test_domains, loader_kwargs):
    """Return the (train, val, test) data loaders needed by args.test_metric.

    Splits the metric does not use are left as empty tensors; that is the
    placeholder the evaluator classes receive.
    """
    train_dataset = torch.empty(0)
    val_dataset = torch.empty(0)
    test_dataset = torch.empty(0)

    def load(domains, split):
        return get_dataloader(args, run, domains, split, 1, loader_kwargs)

    metric = args.test_metric
    if metric in _MATCH_FUNC_METRICS or metric in _ACC_METRICS:
        # Both families evaluate on exactly one split, chosen by a CLI flag;
        # an unrecognised value leaves every split empty (original behaviour).
        if metric in _MATCH_FUNC_METRICS:
            data_case = args.match_func_data_case
        else:
            data_case = args.acc_data_case
        if data_case == 'train':
            train_dataset = load(train_domains, 'train')
        elif data_case == 'val':
            val_dataset = load(train_domains, 'val')
        elif data_case == 'test':
            test_dataset = load(test_domains, 'test')
    elif metric in _PRIVACY_METRICS:
        train_dataset = load(train_domains, 'train')
        test_dataset = load(test_domains, 'test')
    elif metric == 'attribute_attack':
        # Both the train and the test split are drawn from all domains.
        all_domains = train_domains + test_domains
        print(all_domains)
        train_dataset = load(all_domains, 'train')
        test_dataset = load(all_domains, 'test')
    else:
        test_dataset = load(test_domains, 'test')
    return train_dataset, val_dataset, test_dataset


def _build_evaluator(args, datasets, base_res_dir, run, cuda):
    """Import and instantiate the evaluator class registered for args.test_metric."""
    from importlib import import_module

    module_name, class_name = _EVALUATORS[args.test_metric]
    evaluator_cls = getattr(import_module(module_name), class_name)
    train_dataset, val_dataset, test_dataset = datasets
    return evaluator_cls(
        args, train_dataset, val_dataset,
        test_dataset, base_res_dir,
        run, cuda
    )


def _evaluate_into(test_method, args, scores):
    """Load the model(s), compute the metric, and append each metric_score to scores.

    matchdg_erm / hybrid have args.n_runs_matchdg_erm trained models per run,
    selected by index through get_model(); other methods have a single model.
    """
    if args.method_name in ['matchdg_erm', 'hybrid']:
        for run_matchdg_erm in range(args.n_runs_matchdg_erm):
            test_method.get_model(run_matchdg_erm)
            test_method.get_metric_eval()
            scores.append(test_method.metric_score)
    else:
        test_method.get_model()
        test_method.get_metric_eval()
        scores.append(test_method.metric_score)


# Fail fast on an unknown metric instead of loading data and then hitting a
# NameError on an undefined evaluator.
if args.test_metric not in _EVALUATORS:
    raise ValueError(
        'Unsupported test_metric: ' + str(args.test_metric)
        + '; expected one of: ' + ', '.join(_EVALUATORS)
    )

# Execute the method for multiple runs (total args.n_runs)
for run in range(args.n_runs):
    # Seed for reproducibility
    _seed_everything(run * 10)
    datasets = _load_datasets(args, run, train_domains, test_domains, kwargs)
    test_method = _build_evaluator(args, datasets, base_res_dir, run, cuda)

    # The 'mia' metric is evaluated twice per run, as in the original script;
    # the reason is not visible here. Every other metric is evaluated once.
    n_repeats = 2 if args.test_metric == 'mia' else 1

    # Testing Phase
    # NOTE(review): adv_attack (PGD) and mia's attack-model training need
    # autograd; presumably those evaluators re-enable it internally (e.g. via
    # torch.enable_grad) — confirm.
    with torch.no_grad():
        for _ in range(n_repeats):
            _evaluate_into(test_method, args, final_metric_score)
# Report mean and standard error of every metric across all collected scores
# (t_sne and logit_hist produce plots rather than numeric scores).
if args.test_metric not in ['t_sne', 'logit_hist']:
    print('\n')
    print('Done for Model..')
    if final_metric_score:
        for key in final_metric_score[0].keys():
            values = np.array([score[key] for score in final_metric_score])
            # Standard error uses np.std's default population std (ddof=0).
            print(key, ' : ', np.mean(values), np.std(values) / np.sqrt(values.shape[0]))
    else:
        # e.g. --n_runs 0: indexing final_metric_score[0] would raise IndexError.
        print('No metric scores were collected')
    print('\n')