Slab Per Domain Validation Acc Metric

divyat09 2021-06-11 05:23:05 +00:00
Parent ab71903364
Commit c2e6c2dcc3
10 changed files with 172 additions and 81 deletions

View file

@@ -175,6 +175,7 @@ class BaseAlgo():
data_match_tensor= torch.stack( data_match_tensor )
label_match_tensor= torch.stack( label_match_tensor )
# print('Shape: ', data_match_tensor.shape, label_match_tensor.shape)
return data_match_tensor, label_match_tensor, curr_batch_size
def get_test_accuracy(self, case):
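The hunk above only introduces the get_test_accuracy(self, case) signature; its body lies outside this excerpt. A minimal sketch of what such a helper could look like, assuming case selects between validation and test loaders and that attributes like self.phi, self.val_dataset, self.test_dataset and self.cuda exist on BaseAlgo (all of these names are assumptions, not shown in the diff):

import torch

def get_test_accuracy(self, case):
    # Hypothetical method sketch (would live inside BaseAlgo): pick the loader
    # named by `case`; attribute names below are assumed, not taken from the diff.
    loader = self.val_dataset if case == 'val' else self.test_dataset
    correct, total = 0, 0
    with torch.no_grad():
        for batch in loader:
            x_e, y_e = batch[0].to(self.cuda), batch[1].to(self.cuda)
            out = self.phi(x_e)
            # Labels are assumed to be class indices; adjust if one-hot.
            correct += (out.argmax(dim=1) == y_e).sum().item()
            total += y_e.shape[0]
    return 100.0 * correct / total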

View file

@@ -2,7 +2,8 @@ import os
import sys
# methods=['erm', 'irm', 'csd', 'rand', 'matchdg_ctr', 'matchdg_erm', 'hybrid']
methods=['matchdg_ctr', 'hybrid']
# methods=['erm', 'csd', 'matchdg_ctr', 'hybrid', 'matchdg_erm']
methods=['erm', 'csd']
domains= ['nih', 'chex', 'kaggle']
dataset= 'chestxray'
@@ -74,7 +75,7 @@ for method in methods:
script= base_script + ' --method_name matchdg_erm --epochs 40 --lr 0.001 --match_case -1 --penalty_ws 50.0 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121'
elif method == 'hybrid':
script= base_script + ' --method_name hybrid --epochs 40 --lr 0.001 --match_case -1 --penalty_ws 1.0 --penalty_aug 50.0 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121'
script= base_script + ' --method_name hybrid --epochs 40 --lr 0.001 --match_case -1 --penalty_ws 0.0 --penalty_aug 50.0 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121'
train_domains=''
for d in domains:
@@ -83,5 +84,7 @@ for method in methods:
print('Method: ', method, ' Train Domains: ', train_domains, ' Test Domains: ', curr_test_domain)
script= script + ' --train_domains ' + train_domains + ' --test_domains ' + curr_test_domain
script= script + ' --n_runs 2 '
script= script + ' > ' + res_dir + str(method) + '.txt'
os.system(script)

View file

@@ -33,10 +33,8 @@ if not os.path.exists(base_dir):
num_samples= 10000
total_slabs= 7
# slab_noise_list= [0.0, 0.05, 0.10, 0.15, 0.20, 0.25]
# spur_corr_list= [0.0, 0.05, 0.10, 0.15, 0.30, 0.50, 0.70, 0.90]
slab_noise_list= [0.0, 0.10]
spur_corr_list= [0.20, 0.90]
slab_noise_list= [0.0, 0.10, 0.20]
spur_corr_list= [0.0, 0.05, 0.10, 0.15, 0.30, 0.50, 0.70, 0.90, 1.0]
for seed in range(10):
np.random.seed(seed*10)

View file

@@ -123,7 +123,7 @@ class BaseEval():
def load_model(self, run_matchdg_erm):
if self.args.method_name in ['erm_match', 'csd', 'csd_slab', 'irm', 'irm_slab', 'perf_match', 'rand_match', 'mask_linear']:
if self.args.method_name in ['erm_match', 'csd', 'irm', 'perf_match', 'rand_match', 'mask_linear', 'mmd', 'dann']:
self.save_path= self.base_res_dir + '/Model_' + self.post_string
elif self.args.method_name == 'matchdg_ctr':

View file

@@ -25,6 +25,9 @@ elif method == 'matchdg_erm':
for test_domain in domains:
if test_domain not in [75]:
continue
train_domains=''
for d in domains:
if d != test_domain:
@@ -36,7 +39,7 @@ for test_domain in domains:
os.makedirs(res_dir)
script= base_script + ' --train_domains ' + str(train_domains) + ' --test_domains ' + str(test_domain)
script= script + ' > ' + res_dir + method + '_' + str(test_domain) + '.txt'
# script= script + ' > ' + res_dir + method + '_' + str(test_domain) + '.txt'
print('Method: ', method, ' Test Domain: ', test_domain)
os.system(script)

reproduce_slab.py (new file)
View file

@@ -0,0 +1,47 @@
import os
import sys
method= sys.argv[1]
case= sys.argv[2]
total_seed= 10
if case == 'train':
base_script= 'python train.py --dataset slab --model_name slab --batch_size 128 --lr 0.1 --epochs 100 --out_classes 2 --train_domains 0.0 0.10 --test_domains 1.0 --slab_data_dim 2 --slab_noise 0.1 ' + ' --n_runs ' + str(total_seed)
elif case == 'test':
base_script= 'python test.py --test_metric per_domain_acc --acc_data_case train --dataset slab --model_name slab --batch_size 128 --lr 0.1 --epochs 100 --out_classes 2 --train_domains 0.0 0.10 --test_domains 1.0 --slab_data_dim 2 --slab_noise 0.1 ' + ' --n_runs ' + str(total_seed)
res_dir= 'results/slab/logs/'
if not os.path.exists(res_dir):
os.makedirs(res_dir)
if method == 'mmd':
script= base_script + ' --method_name mmd --gaussian 1 --conditional 0 --penalty_ws 0.1 '
elif method == 'c-mmd':
script= base_script + ' --method_name mmd --gaussian 1 --conditional 1 --penalty_ws 0.1 '
elif method == 'coral':
script= base_script + ' --method_name mmd --gaussian 0 --conditional 0 --penalty_ws 0.1 '
elif method == 'c-coral':
script= base_script + ' --method_name mmd --gaussian 0 --conditional 1 --penalty_ws 0.1 '
elif method == 'dann':
script= base_script + ' --method_name dann --conditional 0 --penalty_ws 0.01 --grad_penalty 0.1 --d_steps_per_g_step 4 '
elif method == 'c-dann':
script= base_script + ' --method_name dann --conditional 1 --penalty_ws 0.01 --grad_penalty 1.0 --d_steps_per_g_step 2 '
elif method == 'erm':
script= base_script + ' --method_name erm_match --match_case 0.0 --penalty_ws 0.0 '
elif method == 'rand':
script= base_script + ' --method_name erm_match --match_case 0.0 --penalty_ws 1.0 '
elif method == 'perf':
script= base_script + ' --method_name erm_match --match_case 1.0 --penalty_ws 1.0 '
# script= script + ' > ' + res_dir + str(method) + '.txt'
os.system(script)
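For reference, the script is driven by the two positional arguments read from sys.argv above (the method name, then the train/test case); a typical pair of invocations, assuming it is run from the repository root where train.py and test.py live, would be:

python reproduce_slab.py mmd train   # train MMD on slab train domains 0.0/0.10 for 10 seeds
python reproduce_slab.py mmd test    # evaluate via test.py with --test_metric per_domain_acc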

View file

@@ -2,13 +2,11 @@ import os
import sys
method= sys.argv[1]
slab_noise= float(sys.argv[2])
case= sys.argv[3]
total_seed= 3
base_script= 'python train.py --dataset slab --model_name slab --batch_size 128 --lr 0.1 --epochs 100 --out_classes 2 --train_domains 0.0 0.10 --test_domains 0.90 --slab_data_dim 2 ' + ' --slab_noise ' + str(slab_noise) + ' --n_runs ' + str(total_seed)
base_script= 'python train.py --dataset slab --model_name slab --batch_size 128 --lr 0.1 --epochs 100 --out_classes 2 --train_domains 0.0 0.10 --test_domains 1.0 --slab_data_dim 2 --slab_noise 0.1 ' + ' --n_runs ' + str(total_seed)
res_dir= 'results/slab/htune/slab_noise_' + str(slab_noise) + '/'
res_dir= 'results/slab/htune/'
if not os.path.exists(res_dir):
os.makedirs(res_dir)
@@ -30,61 +28,73 @@ elif method == 'dann':
elif method == 'c-dann':
base_script= base_script + ' --method_name dann --conditional 1 '
if method in ['dann', 'c-dann']:
elif method == 'erm':
script= base_script + ' --method_name erm_match --match_case 0.0 '
elif method == 'rand':
script= base_script + ' --method_name erm_match --match_case 0.0 '
elif method == 'perf':
script= base_script + ' --method_name erm_match --match_case 1.0 '
if method in ['erm']:
penalty_glist= [0.0]
elif method in ['rand', 'perf', 'mmd', 'coral', 'c-mmd', 'c-coral']:
penalty_glist= [0.1, 1.0, 10.0]
elif method in ['dann', 'c-dann']:
penalty_glist= [0.01, 0.1, 1.0, 10.0, 100.0]
else:
penalty_glist= [0.1, 1.0, 10.0]
grad_penalty_glist= [0.01, 0.1, 1.0, 10.0]
disc_steps_glist= [1, 2, 4, 8]
if case == 'train':
if method in ['dann', 'c-dann']:
for penalty in penalty_glist:
for grad_penalty in grad_penalty_glist:
for disc_steps in disc_steps_glist:
script= base_script + ' --penalty_ws ' + str(penalty) + ' --grad_penalty ' + str(grad_penalty) + ' --d_steps_per_g_step ' + str(disc_steps)
script= script + ' > ' + res_dir + str(method) + '_' + str(penalty) + '_' + str(grad_penalty) + '_' + str(disc_steps) + '.txt'
os.system(script)
# Train for different hyper configurations
# if method in ['dann', 'c-dann']:
# for penalty in penalty_glist:
# for grad_penalty in grad_penalty_glist:
# for disc_steps in disc_steps_glist:
# script= base_script + ' --penalty_ws ' + str(penalty) + ' --grad_penalty ' + str(grad_penalty) + ' --d_steps_per_g_step ' + str(disc_steps)
# script= script + ' > ' + res_dir + str(method) + '_' + str(penalty) + '_' + str(grad_penalty) + '_' + str(disc_steps) + '.txt'
# os.system(script)
else:
for penalty in penalty_glist:
script= base_script + ' --penalty_ws ' + str(penalty)
script= script + ' > ' + res_dir + str(method) + '_' + str(penalty) + '.txt'
os.system(script)
# else:
# for penalty in penalty_glist:
# script= base_script + ' --penalty_ws ' + str(penalty)
# script= script + ' > ' + res_dir + str(method) + '_' + str(penalty) + '.txt'
# os.system(script)
elif case == 'test':
# Search over different hyper configurations and infer the best values
best_acc= -1
best_err= -1
best_case= ''
if method in ['dann', 'c-dann']:
for penalty in penalty_glist:
for grad_penalty in grad_penalty_glist:
for disc_steps in disc_steps_glist:
f_name= res_dir + str(method) + '_' + str(penalty) + '_' + str(grad_penalty) + '_' + str(disc_steps) + '.txt'
f= open(f_name)
data= f.readlines()
# Source validation accuracy
mean= float(data[-4].replace('\n', '').split(' ')[-2])
err= float(data[-4].replace('\n', '').split(' ')[-1])
if mean > best_acc:
best_acc= mean
best_err= err
best_case= f_name
else:
for penalty in penalty_glist:
f_name= res_dir + str(method) + '_' + str(penalty) + '.txt'
f= open(f_name)
data= f.readlines()
# Source validation accuracy
mean= float(data[-4].replace('\n', '').split(' ')[-2])
err= float(data[-4].replace('\n', '').split(' ')[-1])
if mean > best_acc:
best_acc= mean
best_err= err
best_case= f_name
print('Best Hparam for method: ', method, best_case)
print('Best Accuracy', best_acc, best_err)
best_acc= -1
best_err= -1
best_case= ''
if method in ['dann', 'c-dann']:
for penalty in penalty_glist:
for grad_penalty in grad_penalty_glist:
for disc_steps in disc_steps_glist:
f_name= res_dir + str(method) + '_' + str(penalty) + '_' + str(grad_penalty) + '_' + str(disc_steps) + '.txt'
f= open(f_name)
data= f.readlines()
# Source validation accuracy
mean= float(data[-4].replace('\n', '').split(' ')[-2])
err= float(data[-4].replace('\n', '').split(' ')[-1])
if mean > best_acc:
best_acc= mean
best_err= err
best_case= f_name
else:
for penalty in penalty_glist:
f_name= res_dir + str(method) + '_' + str(penalty) + '.txt'
f= open(f_name)
data= f.readlines()
# Source validation accuracy
mean= float(data[-4].replace('\n', '').split(' ')[-2])
err= float(data[-4].replace('\n', '').split(' ')[-1])
if mean > best_acc:
best_acc= mean
best_err= err
best_case= f_name
print('Best Hparam for method: ', method, best_case)
print('Best Accuracy', best_acc, best_err)
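The test branch above repeats one log-parsing step for every hyperparameter setting: read the fourth-from-last line of a run's log and take its last two whitespace-separated fields as the mean and standard error of the source validation accuracy. Purely as a reading aid, a hedged sketch of that step factored into a helper (the log layout is inferred from the indexing used above, not from any log file shown here):

def read_source_val_acc(f_name):
    # Assumes the training log ends with a summary block whose fourth-from-last
    # line ends in "<mean> <err>" for source validation accuracy, mirroring
    # data[-4].split(' ')[-2] and [-1] in the selection loops above.
    with open(f_name) as f:
        data = f.readlines()
    fields = data[-4].rstrip('\n').split(' ')
    return float(fields[-2]), float(fields[-1])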

test.py
View file

@@ -45,11 +45,6 @@ parser.add_argument('--img_h', type=int, default= 224,
help='Height of the image in dataset')
parser.add_argument('--img_w', type=int, default= 224,
help='Width of the image in dataset')
parser.add_argument('--slab_data_dim', type=int, default= 2,
help='Number of features in the slab dataset')
parser.add_argument('--slab_total_slabs', type=int, default=7)
parser.add_argument('--slab_num_samples', type=int, default=1000)
parser.add_argument('--slab_noise', type=float, default=0.0)
parser.add_argument('--fc_layer', type=int, default= 1,
help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet')
parser.add_argument('--match_layer', type=str, default='logit_match',
@@ -141,6 +136,19 @@ parser.add_argument('--cuda_device', type=int, default=0,
parser.add_argument('--os_env', type=int, default=0,
help='0: Code execution on local server/machine; 1: Code execution in docker/clusters' )
#MMD, DANN
parser.add_argument('--d_steps_per_g_step', type=int, default=1)
parser.add_argument('--grad_penalty', type=float, default=0.0)
parser.add_argument('--conditional', type=int, default=1)
parser.add_argument('--gaussian', type=int, default=1)
#Slab Dataset
parser.add_argument('--slab_data_dim', type=int, default= 2,
help='Number of features in the slab dataset')
parser.add_argument('--slab_total_slabs', type=int, default=7)
parser.add_argument('--slab_num_samples', type=int, default=1000)
parser.add_argument('--slab_noise', type=float, default=0.1)
#Differentiate between resnet, lenet, domainbed cases of mnist
parser.add_argument('--mnist_case', type=str, default='resnet18',
help='MNIST Dataset Case: resnet18; lenet, domainbed')
@@ -208,7 +216,7 @@ for run in range(args.n_runs):
val_dataset= get_dataloader( args, run, train_domains, 'val', 1, kwargs )
elif args.match_func_data_case== 'test':
test_dataset= get_dataloader( args, run, test_domains, 'test', 1, kwargs )
elif args.test_metric == 'acc':
elif args.test_metric in ['acc', 'per_domain_acc']:
if args.acc_data_case== 'train':
train_dataset= get_dataloader( args, run, train_domains, 'train', 1, kwargs )
elif args.acc_data_case== 'test':
@@ -233,6 +241,14 @@ for run in range(args.n_runs):
test_dataset, base_res_dir,
run, cuda
)
elif args.test_metric == 'per_domain_acc':
from evaluation.per_domain_acc import PerDomainAcc
test_method= PerDomainAcc(
args, train_dataset, val_dataset,
test_dataset, base_res_dir,
run, cuda
)
elif args.test_metric == 'match_score':
from evaluation.match_eval import MatchEval
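The new evaluation/per_domain_acc.py module itself is not part of this excerpt; only its import and constructor call appear above. A minimal sketch consistent with that call signature, assuming PerDomainAcc follows the BaseEval interface from the earlier hunk (the import path, the get_metric_eval method name, attributes such as self.phi and self.cuda, and the assumption that each batch carries a one-hot domain label as its third element are all guesses, not confirmed by the diff):

import torch
from evaluation.base_eval import BaseEval  # path assumed from the BaseEval class shown earlier

class PerDomainAcc(BaseEval):
    def get_metric_eval(self):
        # Accumulate correct/total counts separately for each domain index.
        loader = self.train_dataset if self.args.acc_data_case == 'train' else self.test_dataset
        correct, total = {}, {}
        with torch.no_grad():
            for batch in loader:
                x_e, y_e, d_e = batch[0], batch[1], batch[2]
                preds = self.phi(x_e.to(self.cuda)).argmax(dim=1).cpu()
                labels = y_e.argmax(dim=1) if y_e.dim() > 1 else y_e
                domains = d_e.argmax(dim=1)  # assumed one-hot domain labels
                for dom in torch.unique(domains):
                    mask = domains == dom
                    correct[int(dom)] = correct.get(int(dom), 0) + (preds[mask] == labels[mask]).sum().item()
                    total[int(dom)] = total.get(int(dom), 0) + int(mask.sum())
        for dom in sorted(total):
            print('Domain', dom, 'Acc:', 100.0 * correct[dom] / total[dom])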

View file

@@ -40,17 +40,17 @@ def get_logits(model, loader, device, label=1):
# Input Parsing
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='slab',
parser.add_argument('--dataset_name', type=str, default='rot_mnist',
help='Datasets: rot_mnist; fashion_mnist; pacs')
parser.add_argument('--method_name', type=str, default='erm_match',
help=' Training Algorithm: erm_match; matchdg_ctr; matchdg_erm')
parser.add_argument('--model_name', type=str, default='slab',
parser.add_argument('--model_name', type=str, default='resnet18',
help='Architecture of the model to be trained')
parser.add_argument('--train_domains', nargs='+', type=str, default=["15", "30", "45", "60", "75"],
help='List of train domains')
parser.add_argument('--test_domains', nargs='+', type=str, default=["0", "90"],
help='List of test domains')
parser.add_argument('--out_classes', type=int, default=2,
parser.add_argument('--out_classes', type=int, default=10,
help='Total number of classes in the dataset')
parser.add_argument('--img_c', type=int, default= 1,
help='Number of channels of the image in dataset')
@@ -58,11 +58,6 @@ parser.add_argument('--img_h', type=int, default= 224,
help='Height of the image in dataset')
parser.add_argument('--img_w', type=int, default= 224,
help='Width of the image in dataset')
parser.add_argument('--slab_data_dim', type=int, default= 2,
help='Number of features in the slab dataset')
parser.add_argument('--slab_total_slabs', type=int, default=7)
parser.add_argument('--slab_num_samples', type=int, default=1000)
parser.add_argument('--slab_noise', type=float, default=0.0)
parser.add_argument('--fc_layer', type=int, default= 1,
help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet')
parser.add_argument('--match_layer', type=str, default='logit_match',
@@ -89,6 +84,8 @@ parser.add_argument('--penalty_s', type=int, default=-1,
help='Epoch threshold over which Matching Loss to be optimised')
parser.add_argument('--penalty_irm', type=float, default=0.0,
help='Penalty weight for IRM invariant classifier loss')
parser.add_argument('--penalty_aug', type=float, default=1.0,
help='Penalty weight for Augmentation in Hybrid approach loss')
parser.add_argument('--penalty_ws', type=float, default=0.1,
help='Penalty weight for Matching Loss')
parser.add_argument('--penalty_diff_ctr',type=float, default=1.0,
@@ -152,6 +149,21 @@ parser.add_argument('--cuda_device', type=int, default=0,
parser.add_argument('--os_env', type=int, default=0,
help='0: Code execution on local server/machine; 1: Code execution in docker/clusters' )
#Slab Dataset
parser.add_argument('--slab_data_dim', type=int, default= 2,
help='Number of features in the slab dataset')
parser.add_argument('--slab_total_slabs', type=int, default=7)
parser.add_argument('--slab_num_samples', type=int, default=1000)
parser.add_argument('--slab_noise', type=float, default=0.1)
#Differentiate between resnet, lenet, domainbed cases of mnist
parser.add_argument('--mnist_case', type=str, default='resnet18',
help='MNIST Dataset Case: resnet18; lenet, domainbed')
#Multiple random matches
parser.add_argument('--total_matches_per_point', type=int, default=1,
help='Multiple random matches')
args = parser.parse_args()
#GPU

View file

@@ -42,11 +42,6 @@ parser.add_argument('--img_h', type=int, default= 224,
help='Height of the image in dataset')
parser.add_argument('--img_w', type=int, default= 224,
help='Width of the image in dataset')
parser.add_argument('--slab_data_dim', type=int, default= 2,
help='Number of features in the slab dataset')
parser.add_argument('--slab_total_slabs', type=int, default=7)
parser.add_argument('--slab_num_samples', type=int, default=1000)
parser.add_argument('--slab_noise', type=float, default=0.0)
parser.add_argument('--fc_layer', type=int, default= 1,
help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet')
parser.add_argument('--match_layer', type=str, default='logit_match',
@@ -122,6 +117,12 @@ parser.add_argument('--grad_penalty', type=float, default=0.0)
parser.add_argument('--conditional', type=int, default=1)
parser.add_argument('--gaussian', type=int, default=1)
#Slab Dataset
parser.add_argument('--slab_data_dim', type=int, default= 2,
help='Number of features in the slab dataset')
parser.add_argument('--slab_total_slabs', type=int, default=7)
parser.add_argument('--slab_num_samples', type=int, default=1000)
parser.add_argument('--slab_noise', type=float, default=0.1)
#Test Based Args
parser.add_argument('--test_metric', type=str, default='match_score',