# Common imports
import os
import sys
import numpy as np
import argparse
import copy
import random
import json
import sklearn

# PyTorch
import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

# robustdg
from utils.helper import *
from utils.match_function import *

# Input Parsing
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='rot_mnist',
                    help='Datasets: rot_mnist; fashion_mnist; pacs')
parser.add_argument('--method_name', type=str, default='erm_match',
                    help='Training algorithm: erm_match; matchdg_ctr; matchdg_erm')
parser.add_argument('--model_name', type=str, default='resnet18',
                    help='Architecture of the model to be trained')
parser.add_argument('--train_domains', nargs='+', type=str, default=["15", "30", "45", "60", "75"],
                    help='List of train domains')
parser.add_argument('--test_domains', nargs='+', type=str, default=["0", "90"],
                    help='List of test domains')
parser.add_argument('--out_classes', type=int, default=10,
                    help='Total number of classes in the dataset')
parser.add_argument('--img_c', type=int, default=1,
                    help='Number of channels of the images in the dataset')
parser.add_argument('--img_h', type=int, default=224,
                    help='Height of the images in the dataset')
parser.add_argument('--img_w', type=int, default=224,
                    help='Width of the images in the dataset')
parser.add_argument('--fc_layer', type=int, default=1,
                    help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet')
parser.add_argument('--match_layer', type=str, default='logit_match',
                    help='rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
parser.add_argument('--pos_metric', type=str, default='l2',
                    help='Cost function to evaluate the distance between two representations; Options: l1; l2; cos')
parser.add_argument('--rep_dim', type=int, default=250,
                    help='Representation dimension for contrastive learning')
parser.add_argument('--pre_trained', type=int, default=0,
                    help='0: No pretrained architecture; 1: Pretrained architecture')
parser.add_argument('--perfect_match', type=int, default=1,
                    help='0: No perfect match known (PACS); 1: Perfect match known (MNIST)')
parser.add_argument('--opt', type=str, default='sgd',
                    help='Optimizer choice: sgd; adam')
parser.add_argument('--weight_decay', type=float, default=5e-4,
                    help='Weight decay in SGD')
parser.add_argument('--lr', type=float, default=0.01,
                    help='Learning rate for training the model')
parser.add_argument('--batch_size', type=int, default=16,
                    help='Batch size for training the model')
parser.add_argument('--epochs', type=int, default=15,
                    help='Total number of epochs for training the model')
parser.add_argument('--penalty_s', type=int, default=-1,
                    help='Epoch threshold after which the matching loss is optimised')
parser.add_argument('--penalty_irm', type=float, default=0.0,
                    help='Penalty weight for the IRM invariant-classifier loss')
parser.add_argument('--penalty_aug', type=float, default=1.0,
                    help='Penalty weight for the augmentation loss in the hybrid approach')
parser.add_argument('--penalty_ws', type=float, default=0.1,
                    help='Penalty weight for the matching loss')
parser.add_argument('--penalty_diff_ctr', type=float, default=1.0,
                    help='Penalty weight for the contrastive loss')
parser.add_argument('--tau', type=float, default=0.05,
                    help='Temperature hyperparameter for the NT-Xent contrastive loss')
parser.add_argument('--match_flag', type=int, default=0,
                    help='0: No update to the match strategy; 1: Update the match strategy')
parser.add_argument('--match_case', type=float, default=1.0,
                    help='0: Random match; 1: Perfect match; a fraction in between gives that proportion of correct matches')
parser.add_argument('--match_interrupt', type=int, default=5,
                    help='Number of epochs before inferring the match strategy')
parser.add_argument('--ctr_abl', type=int, default=0,
                    help='0: Randomization up to the class level; 1: Complete randomization')
parser.add_argument('--match_abl', type=int, default=0,
                    help='0: Randomization up to the class level; 1: Complete randomization')
parser.add_argument('--n_runs', type=int, default=3,
                    help='Number of iterations to repeat the training process')
parser.add_argument('--n_runs_matchdg_erm', type=int, default=1,
                    help='Number of iterations to repeat the training process for matchdg_erm')
parser.add_argument('--ctr_model_name', type=str, default='resnet18',
                    help='(For the matchdg_ctr phase) Architecture of the model to be trained')
parser.add_argument('--ctr_match_layer', type=str, default='logit_match',
                    help='(For the matchdg_ctr phase) rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
parser.add_argument('--ctr_match_flag', type=int, default=1,
                    help='(For the matchdg_ctr phase) 0: No update to the match strategy; 1: Update the match strategy')
parser.add_argument('--ctr_match_case', type=float, default=0.01,
                    help='(For the matchdg_ctr phase) 0: Random match; 1: Perfect match; a fraction in between gives that proportion of correct matches')
parser.add_argument('--ctr_match_interrupt', type=int, default=5,
                    help='(For the matchdg_ctr phase) Number of epochs before inferring the match strategy')
parser.add_argument('--mnist_seed', type=int, default=0,
                    help='Values 0-6 select different subsets of the MNIST and Fashion-MNIST datasets')
parser.add_argument('--retain', type=float, default=0,
                    help='0: Train from scratch in MatchDG Phase 2; 2: Fine-tune from the MatchDG Phase 1 model in Phase 2')
parser.add_argument('--cuda_device', type=int, default=0,
                    help='Select the CUDA device by id among the available devices')
parser.add_argument('--os_env', type=int, default=0,
                    help='0: Code execution on a local server/machine; 1: Code execution in docker/clusters')
parser.add_argument('--dp_noise', type=int, default=0,
                    help='0: No DP noise; 1: Add DP noise')

# MMD, DANN
parser.add_argument('--d_steps_per_g_step', type=int, default=1)
parser.add_argument('--grad_penalty', type=float, default=0.0)
parser.add_argument('--conditional', type=int, default=1)
parser.add_argument('--gaussian', type=int, default=1)

# Slab Dataset
parser.add_argument('--slab_data_dim', type=int, default=2,
                    help='Number of features in the slab dataset')
parser.add_argument('--slab_total_slabs', type=int, default=7)
parser.add_argument('--slab_num_samples', type=int, default=1000)
parser.add_argument('--slab_noise', type=float, default=0.1)

# Test Based Args
parser.add_argument('--test_metric', type=str, default='match_score',
                    help='Evaluation metrics: acc; match_score; t_sne; mia')
parser.add_argument('--top_k', type=int, default=10,
                    help='Top-K matches to consider for the match score evaluation metric')
parser.add_argument('--match_func_aug_case', type=int, default=0,
                    help='0: Evaluate the match function on train domains; 1: Evaluate the match function on self augmentations')
parser.add_argument('--match_func_data_case', type=str, default='val',
                    help='Dataset split (train/val/test) for the match score evaluation metric')

# Differentiate between the resnet, lenet and domainbed cases of MNIST
parser.add_argument('--mnist_case', type=str, default='resnet18',
                    help='MNIST dataset case: resnet18; lenet; domainbed')

# Multiple random matches
parser.add_argument('--total_matches_per_point', type=int, default=1,
                    help='Number of random matches per data point')

args = parser.parse_args()

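# Example invocations (illustrative only; assumes this script is saved as train.py and run
# from the repository root, and the flag values are sample settings, not prescribed ones):
#   python train.py --dataset_name rot_mnist --method_name erm_match --penalty_ws 0.1 --epochs 25
#   python train.py --dataset_name rot_mnist --method_name matchdg_ctr --match_flag 1 --match_case 0.01 --pos_metric cos --epochs 50
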
# GPU
cuda = torch.device("cuda:" + str(args.cuda_device))
# A torch.device object is always truthy, so check CUDA availability explicitly
if torch.cuda.is_available():
    kwargs = {'num_workers': 0, 'pin_memory': False}
else:
    kwargs = {}

# Lists of train and test domains
train_domains = args.train_domains
test_domains = args.test_domains

# Initialize
final_accuracy_target_val = []
final_accuracy_source_val = []

if args.os_env:
    res_dir = os.getenv('PT_OUTPUT_DIR') + '/'
else:
    res_dir = 'results/'

if args.dp_noise:
    base_res_dir = (
        res_dir + args.dataset_name + '/' + 'dp_' + args.method_name + '/' + args.match_layer
        + '/' + 'train_' + str(args.train_domains)
    )
else:
    base_res_dir = (
        res_dir + args.dataset_name + '/' + args.method_name + '/' + args.match_layer
        + '/' + 'train_' + str(args.train_domains)
    )

# TODO: Handle the slab noise case in the helper functions
if args.dataset_name == 'slab':
    base_res_dir = base_res_dir + '/slab_noise_' + str(args.slab_noise)

if not os.path.exists(base_res_dir):
    os.makedirs(base_res_dir)

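# Note: with the default arguments, base_res_dir resolves to
# results/rot_mnist/erm_match/logit_match/train_['15', '30', '45', '60', '75']
# since the train-domain list is embedded verbatim via str().
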
# Execute the method for multiple runs (total of args.n_runs)
for run in range(args.n_runs):
    print('Run', run)

    # Seed for reproducibility
    random.seed(run*10)
    np.random.seed(run*10)
    torch.manual_seed(run*10)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(run*10)

    # DataLoader
    train_dataset = get_dataloader(args, run, train_domains, 'train', 0, kwargs)
    if args.method_name == 'matchdg_ctr':
        val_dataset = get_dataloader(args, run, train_domains, 'val', 1, kwargs)
    else:
        val_dataset = get_dataloader(args, run, train_domains, 'val', 0, kwargs)
    test_dataset = get_dataloader(args, run, test_domains, 'test', 0, kwargs)
    # print('Train Domains, Domain Size, BaseDomainIdx, Total Domains: ', train_domains, total_domains, domain_size, training_list_size)

    # Import the module as per the current training method
    if args.method_name == 'erm_match' or args.method_name == 'mask_linear' or args.method_name == 'dp_erm':
        from algorithms.erm_match import ErmMatch
        train_method = ErmMatch(
            args, train_dataset, val_dataset,
            test_dataset, base_res_dir,
            run, cuda
        )
    elif args.method_name == 'matchdg_ctr':
        from algorithms.match_dg import MatchDG
        ctr_phase = 1
        train_method = MatchDG(
            args, train_dataset, val_dataset,
            test_dataset, base_res_dir,
            run, cuda, ctr_phase
        )
    elif args.method_name == 'matchdg_erm':
        from algorithms.match_dg import MatchDG
        ctr_phase = 0
        train_method = MatchDG(
            args, train_dataset, val_dataset,
            test_dataset, base_res_dir,
            run, cuda, ctr_phase
        )
    elif args.method_name == 'hybrid':
        from algorithms.hybrid import Hybrid
        train_method = Hybrid(
            args, train_dataset, val_dataset,
            test_dataset, base_res_dir,
            run, cuda
        )
    elif args.method_name == 'erm':
        from algorithms.erm import Erm
        train_method = Erm(
            args, train_dataset, val_dataset,
            test_dataset, base_res_dir,
            run, cuda
        )
    elif args.method_name == 'irm':
        from algorithms.irm import Irm
        train_method = Irm(
            args, train_dataset, val_dataset,
            test_dataset, base_res_dir,
            run, cuda
        )
    elif args.method_name == 'dro':
        from algorithms.dro import DRO
        train_method = DRO(
            args, train_dataset, val_dataset,
            test_dataset, base_res_dir,
            run, cuda
        )
    elif args.method_name == 'csd':
        from algorithms.csd import CSD
        train_method = CSD(
            args, train_dataset, val_dataset,
            test_dataset, base_res_dir,
            run, cuda
        )
    elif args.method_name == 'mmd':
        from algorithms.mmd import MMD
        train_method = MMD(
            args, train_dataset, val_dataset,
            test_dataset, base_res_dir,
            run, cuda
        )
    elif args.method_name == 'dann':
        from algorithms.dann import DANN
        train_method = DANN(
            args, train_dataset, val_dataset,
            test_dataset, base_res_dir,
            run, cuda
        )

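    # Note: each algorithm class above is constructed with the same arguments and is assumed
    # to expose a common train() interface; for MatchDG, ctr_phase=1 runs the contrastive
    # (Phase 1) training and ctr_phase=0 the ERM (Phase 2) training.
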
    # Train the method: it saves the model weights after training and evaluates test accuracy
    train_method.train()

    # Report the final accuracy
    if args.method_name != 'matchdg_ctr':
        final_acc = np.max(train_method.final_acc)
        final_accuracy_target_val.append(final_acc)

        idx = np.argmax(train_method.val_acc)
        final_acc = train_method.final_acc[idx]
        final_accuracy_source_val.append(final_acc)


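# Reporting note (an inference from the aggregation above): 'Target Validation' assumes
# train_method.final_acc holds per-epoch test accuracy and takes its maximum over epochs,
# while 'Source Validation' picks the epoch with the best validation accuracy
# (train_method.val_acc) and reports the test accuracy at that epoch.
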
if args.method_name != 'matchdg_ctr':
    print('\n')
    print('Done for the Model..')
    print('Final Test Accuracy (Source Validation)', np.mean(final_accuracy_source_val), np.std(final_accuracy_source_val))
    print('Final Test Accuracy (Target Validation)', np.mean(final_accuracy_target_val), np.std(final_accuracy_target_val))
    print('\n')