Slab Per Domain Validation Acc Metric
This commit is contained in:
Родитель
ab71903364
Коммит
c2e6c2dcc3
|
@ -175,6 +175,7 @@ class BaseAlgo():
|
|||
data_match_tensor= torch.stack( data_match_tensor )
|
||||
label_match_tensor= torch.stack( label_match_tensor )
|
||||
# print('Shape: ', data_match_tensor.shape, label_match_tensor.shape)
|
||||
|
||||
return data_match_tensor, label_match_tensor, curr_batch_size
|
||||
|
||||
def get_test_accuracy(self, case):
|
||||
|
|
|
@ -2,7 +2,8 @@ import os
|
|||
import sys
|
||||
|
||||
# methods=['erm', 'irm', 'csd', 'rand', 'matchdg_ctr', 'matchdg_erm', 'hybrid']
|
||||
methods=['matchdg_ctr', 'hybrid']
|
||||
# methods=['erm', 'csd', 'matchdg_ctr', 'hybrid', 'matchdg_erm']
|
||||
methods=['erm', 'csd']
|
||||
domains= ['nih', 'chex', 'kaggle']
|
||||
dataset= 'chestxray'
|
||||
|
||||
|
@ -74,7 +75,7 @@ for method in methods:
|
|||
script= base_script + ' --method_name matchdg_erm --epochs 40 --lr 0.001 --match_case -1 --penalty_ws 50.0 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121'
|
||||
|
||||
elif method == 'hybrid':
|
||||
script= base_script + ' --method_name hybrid --epochs 40 --lr 0.001 --match_case -1 --penalty_ws 1.0 --penalty_aug 50.0 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121'
|
||||
script= base_script + ' --method_name hybrid --epochs 40 --lr 0.001 --match_case -1 --penalty_ws 0.0 --penalty_aug 50.0 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121'
|
||||
|
||||
train_domains=''
|
||||
for d in domains:
|
||||
|
@ -83,5 +84,7 @@ for method in methods:
|
|||
|
||||
print('Method: ', method, ' Train Domains: ', train_domains, ' Test Domains: ', curr_test_domain)
|
||||
script= script + ' --train_domains ' + train_domains + ' --test_domains ' + curr_test_domain
|
||||
|
||||
script= script + ' --n_runs 2 '
|
||||
script= script + ' > ' + res_dir + str(method) + '.txt'
|
||||
os.system(script)
|
|
@ -33,10 +33,8 @@ if not os.path.exists(base_dir):
|
|||
|
||||
num_samples= 10000
|
||||
total_slabs= 7
|
||||
# slab_noise_list= [0.0, 0.05, 0.10, 0.15, 0.20, 0.25]
|
||||
# spur_corr_list= [0.0, 0.05, 0.10, 0.15, 0.30, 0.50, 0.70, 0.90]
|
||||
slab_noise_list= [0.0, 0.10]
|
||||
spur_corr_list= [0.20, 0.90]
|
||||
slab_noise_list= [0.0, 0.10, 0.20]
|
||||
spur_corr_list= [0.0, 0.05, 0.10, 0.15, 0.30, 0.50, 0.70, 0.90, 1.0]
|
||||
|
||||
for seed in range(10):
|
||||
np.random.seed(seed*10)
|
||||
|
|
|
@ -123,7 +123,7 @@ class BaseEval():
|
|||
|
||||
def load_model(self, run_matchdg_erm):
|
||||
|
||||
if self.args.method_name in ['erm_match', 'csd', 'csd_slab', 'irm', 'irm_slab', 'perf_match', 'rand_match', 'mask_linear']:
|
||||
if self.args.method_name in ['erm_match', 'csd', 'irm', 'perf_match', 'rand_match', 'mask_linear', 'mmd', 'dann']:
|
||||
self.save_path= self.base_res_dir + '/Model_' + self.post_string
|
||||
|
||||
elif self.args.method_name == 'matchdg_ctr':
|
||||
|
|
|
@ -25,6 +25,9 @@ elif method == 'matchdg_erm':
|
|||
|
||||
for test_domain in domains:
|
||||
|
||||
if test_domain not in [75]:
|
||||
continue
|
||||
|
||||
train_domains=''
|
||||
for d in domains:
|
||||
if d != test_domain:
|
||||
|
@ -36,7 +39,7 @@ for test_domain in domains:
|
|||
os.makedirs(res_dir)
|
||||
|
||||
script= base_script + ' --train_domains ' + str(train_domains) + ' --test_domains ' + str(test_domain)
|
||||
script= script + ' > ' + res_dir + method + '_' + str(test_domain) + '.txt'
|
||||
# script= script + ' > ' + res_dir + method + '_' + str(test_domain) + '.txt'
|
||||
|
||||
print('Method: ', method, ' Test Domain: ', test_domain)
|
||||
os.system(script)
|
|
@ -0,0 +1,47 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
method= sys.argv[1]
|
||||
case= sys.argv[2]
|
||||
total_seed= 10
|
||||
|
||||
if case == 'train':
|
||||
base_script= 'python train.py --dataset slab --model_name slab --batch_size 128 --lr 0.1 --epochs 100 --out_classes 2 --train_domains 0.0 0.10 --test_domains 1.0 --slab_data_dim 2 --slab_noise 0.1 ' + ' --n_runs ' + str(total_seed)
|
||||
|
||||
elif case == 'test':
|
||||
base_script= 'python test.py --test_metric per_domain_acc --acc_data_case train --dataset slab --model_name slab --batch_size 128 --lr 0.1 --epochs 100 --out_classes 2 --train_domains 0.0 0.10 --test_domains 1.0 --slab_data_dim 2 --slab_noise 0.1 ' + ' --n_runs ' + str(total_seed)
|
||||
|
||||
|
||||
res_dir= 'results/slab/logs/'
|
||||
if not os.path.exists(res_dir):
|
||||
os.makedirs(res_dir)
|
||||
|
||||
if method == 'mmd':
|
||||
script= base_script + ' --method_name mmd --gaussian 1 --conditional 0 --penalty_ws 0.1 '
|
||||
|
||||
elif method == 'c-mmd':
|
||||
script= base_script + ' --method_name mmd --gaussian 1 --conditional 1 --penalty_ws 0.1 '
|
||||
|
||||
elif method == 'coral':
|
||||
script= base_script + ' --method_name mmd --gaussian 0 --conditional 0 --penalty_ws 0.1 '
|
||||
|
||||
elif method == 'c-coral':
|
||||
script= base_script + ' --method_name mmd --gaussian 0 --conditional 1 --penalty_ws 0.1 '
|
||||
|
||||
elif method == 'dann':
|
||||
script= base_script + ' --method_name dann --conditional 0 --penalty_ws 0.01 --grad_penalty 0.1 --d_steps_per_g_step 4 '
|
||||
|
||||
elif method == 'c-dann':
|
||||
script= base_script + ' --method_name dann --conditional 1 --penalty_ws 0.01 --grad_penalty 1.0 --d_steps_per_g_step 2 '
|
||||
|
||||
elif method == 'erm':
|
||||
script= base_script + ' --method_name erm_match --match_case 0.0 --penalty_ws 0.0 '
|
||||
|
||||
elif method == 'rand':
|
||||
script= base_script + ' --method_name erm_match --match_case 0.0 --penalty_ws 1.0 '
|
||||
|
||||
elif method == 'perf':
|
||||
script= base_script + ' --method_name erm_match --match_case 1.0 --penalty_ws 1.0 '
|
||||
|
||||
# script= script + ' > ' + res_dir + str(method) + '.txt'
|
||||
os.system(script)
|
118
slab-tune.py
118
slab-tune.py
|
@ -2,13 +2,11 @@ import os
|
|||
import sys
|
||||
|
||||
method= sys.argv[1]
|
||||
slab_noise= float(sys.argv[2])
|
||||
case= sys.argv[3]
|
||||
total_seed= 3
|
||||
|
||||
base_script= 'python train.py --dataset slab --model_name slab --batch_size 128 --lr 0.1 --epochs 100 --out_classes 2 --train_domains 0.0 0.10 --test_domains 0.90 --slab_data_dim 2 ' + ' --slab_noise ' + str(slab_noise) + ' --n_runs ' + str(total_seed)
|
||||
base_script= 'python train.py --dataset slab --model_name slab --batch_size 128 --lr 0.1 --epochs 100 --out_classes 2 --train_domains 0.0 0.10 --test_domains 1.0 --slab_data_dim 2 --slab_noise 0.1 ' + ' --n_runs ' + str(total_seed)
|
||||
|
||||
res_dir= 'results/slab/htune/slab_noise_' + str(slab_noise) + '/'
|
||||
res_dir= 'results/slab/htune/'
|
||||
if not os.path.exists(res_dir):
|
||||
os.makedirs(res_dir)
|
||||
|
||||
|
@ -30,61 +28,73 @@ elif method == 'dann':
|
|||
elif method == 'c-dann':
|
||||
base_script= base_script + ' --method_name dann --conditional 1 '
|
||||
|
||||
if method in ['dann', 'c-dann']:
|
||||
elif method == 'erm':
|
||||
script= base_script + ' --method_name erm_match --match_case 0.0 '
|
||||
|
||||
elif method == 'rand':
|
||||
script= base_script + ' --method_name erm_match --match_case 0.0 '
|
||||
|
||||
elif method == 'perf':
|
||||
script= base_script + ' --method_name erm_match --match_case 1.0 '
|
||||
|
||||
|
||||
if method in ['erm']:
|
||||
penalty_glist= [0.0]
|
||||
elif method in ['rand', 'perf', 'mmd', 'coral', 'c-mmd', 'c-coral']:
|
||||
penalty_glist= [0.1, 1.0, 10.0]
|
||||
elif method in ['dann', 'c-dann']:
|
||||
penalty_glist= [0.01, 0.1, 1.0, 10.0, 100.0]
|
||||
else:
|
||||
penalty_glist= [0.1, 1.0, 10.0]
|
||||
|
||||
grad_penalty_glist= [0.01, 0.1, 1.0, 10.0]
|
||||
disc_steps_glist= [1, 2, 4, 8]
|
||||
|
||||
if case == 'train':
|
||||
if method in ['dann', 'c-dann']:
|
||||
for penalty in penalty_glist:
|
||||
for grad_penalty in grad_penalty_glist:
|
||||
for disc_steps in disc_steps_glist:
|
||||
script= base_script + ' --penalty_ws ' + str(penalty) + ' --grad_penalty ' + str(grad_penalty) + ' --d_steps_per_g_step ' + str(disc_steps)
|
||||
script= script + ' > ' + res_dir + str(method) + '_' + str(penalty) + '_' + str(grad_penalty) + '_' + str(disc_steps) + '.txt'
|
||||
os.system(script)
|
||||
# Train for different hyper configurations
|
||||
# if method in ['dann', 'c-dann']:
|
||||
# for penalty in penalty_glist:
|
||||
# for grad_penalty in grad_penalty_glist:
|
||||
# for disc_steps in disc_steps_glist:
|
||||
# script= base_script + ' --penalty_ws ' + str(penalty) + ' --grad_penalty ' + str(grad_penalty) + ' --d_steps_per_g_step ' + str(disc_steps)
|
||||
# script= script + ' > ' + res_dir + str(method) + '_' + str(penalty) + '_' + str(grad_penalty) + '_' + str(disc_steps) + '.txt'
|
||||
# os.system(script)
|
||||
|
||||
else:
|
||||
for penalty in penalty_glist:
|
||||
script= base_script + ' --penalty_ws ' + str(penalty)
|
||||
script= script + ' > ' + res_dir + str(method) + '_' + str(penalty) + '.txt'
|
||||
os.system(script)
|
||||
# else:
|
||||
# for penalty in penalty_glist:
|
||||
# script= base_script + ' --penalty_ws ' + str(penalty)
|
||||
# script= script + ' > ' + res_dir + str(method) + '_' + str(penalty) + '.txt'
|
||||
# os.system(script)
|
||||
|
||||
elif case == 'test':
|
||||
# Search over different hyper configurations and infer the best values
|
||||
|
||||
best_acc= -1
|
||||
best_err= -1
|
||||
best_case= ''
|
||||
|
||||
if method in ['dann', 'c-dann']:
|
||||
for penalty in penalty_glist:
|
||||
for grad_penalty in grad_penalty_glist:
|
||||
for disc_steps in disc_steps_glist:
|
||||
f_name= res_dir + str(method) + '_' + str(penalty) + '_' + str(grad_penalty) + '_' + str(disc_steps) + '.txt'
|
||||
f= open(f_name)
|
||||
data= f.readlines()
|
||||
# Source validation accuracy
|
||||
mean= float(data[-4].replace('\n', '').split(' ')[-2])
|
||||
err= float(data[-4].replace('\n', '').split(' ')[-1])
|
||||
if mean > best_acc:
|
||||
best_acc= mean
|
||||
best_err= err
|
||||
best_case= f_name
|
||||
else:
|
||||
for penalty in penalty_glist:
|
||||
f_name= res_dir + str(method) + '_' + str(penalty) + '.txt'
|
||||
f= open(f_name)
|
||||
data= f.readlines()
|
||||
# Source validation accuracy
|
||||
mean= float(data[-4].replace('\n', '').split(' ')[-2])
|
||||
err= float(data[-4].replace('\n', '').split(' ')[-1])
|
||||
if mean > best_acc:
|
||||
best_acc= mean
|
||||
best_err= err
|
||||
best_case= f_name
|
||||
|
||||
print('Best Hparam for method: ', method, best_case)
|
||||
print('Best Accuracy', best_acc, best_err)
|
||||
best_acc= -1
|
||||
best_err= -1
|
||||
best_case= ''
|
||||
|
||||
if method in ['dann', 'c-dann']:
|
||||
for penalty in penalty_glist:
|
||||
for grad_penalty in grad_penalty_glist:
|
||||
for disc_steps in disc_steps_glist:
|
||||
f_name= res_dir + str(method) + '_' + str(penalty) + '_' + str(grad_penalty) + '_' + str(disc_steps) + '.txt'
|
||||
f= open(f_name)
|
||||
data= f.readlines()
|
||||
# Source validation accuracy
|
||||
mean= float(data[-4].replace('\n', '').split(' ')[-2])
|
||||
err= float(data[-4].replace('\n', '').split(' ')[-1])
|
||||
if mean > best_acc:
|
||||
best_acc= mean
|
||||
best_err= err
|
||||
best_case= f_name
|
||||
else:
|
||||
for penalty in penalty_glist:
|
||||
f_name= res_dir + str(method) + '_' + str(penalty) + '.txt'
|
||||
f= open(f_name)
|
||||
data= f.readlines()
|
||||
# Source validation accuracy
|
||||
mean= float(data[-4].replace('\n', '').split(' ')[-2])
|
||||
err= float(data[-4].replace('\n', '').split(' ')[-1])
|
||||
if mean > best_acc:
|
||||
best_acc= mean
|
||||
best_err= err
|
||||
best_case= f_name
|
||||
|
||||
print('Best Hparam for method: ', method, best_case)
|
||||
print('Best Accuracy', best_acc, best_err)
|
28
test.py
28
test.py
|
@ -45,11 +45,6 @@ parser.add_argument('--img_h', type=int, default= 224,
|
|||
help='Height of the image in dataset')
|
||||
parser.add_argument('--img_w', type=int, default= 224,
|
||||
help='Width of the image in dataset')
|
||||
parser.add_argument('--slab_data_dim', type=int, default= 2,
|
||||
help='Number of features in the slab dataset')
|
||||
parser.add_argument('--slab_total_slabs', type=int, default=7)
|
||||
parser.add_argument('--slab_num_samples', type=int, default=1000)
|
||||
parser.add_argument('--slab_noise', type=float, default=0.0)
|
||||
parser.add_argument('--fc_layer', type=int, default= 1,
|
||||
help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet')
|
||||
parser.add_argument('--match_layer', type=str, default='logit_match',
|
||||
|
@ -141,6 +136,19 @@ parser.add_argument('--cuda_device', type=int, default=0,
|
|||
parser.add_argument('--os_env', type=int, default=0,
|
||||
help='0: Code execution on local server/machine; 1: Code execution in docker/clusters' )
|
||||
|
||||
#MMD, DANN
|
||||
parser.add_argument('--d_steps_per_g_step', type=int, default=1)
|
||||
parser.add_argument('--grad_penalty', type=float, default=0.0)
|
||||
parser.add_argument('--conditional', type=int, default=1)
|
||||
parser.add_argument('--gaussian', type=int, default=1)
|
||||
|
||||
#Slab Dataset
|
||||
parser.add_argument('--slab_data_dim', type=int, default= 2,
|
||||
help='Number of features in the slab dataset')
|
||||
parser.add_argument('--slab_total_slabs', type=int, default=7)
|
||||
parser.add_argument('--slab_num_samples', type=int, default=1000)
|
||||
parser.add_argument('--slab_noise', type=float, default=0.1)
|
||||
|
||||
#Differentiate between resnet, lenet, domainbed cases of mnist
|
||||
parser.add_argument('--mnist_case', type=str, default='resnet18',
|
||||
help='MNIST Dataset Case: resnet18; lenet, domainbed')
|
||||
|
@ -208,7 +216,7 @@ for run in range(args.n_runs):
|
|||
val_dataset= get_dataloader( args, run, train_domains, 'val', 1, kwargs )
|
||||
elif args.match_func_data_case== 'test':
|
||||
test_dataset= get_dataloader( args, run, test_domains, 'test', 1, kwargs )
|
||||
elif args.test_metric == 'acc':
|
||||
elif args.test_metric in ['acc', 'per_domain_acc']:
|
||||
if args.acc_data_case== 'train':
|
||||
train_dataset= get_dataloader( args, run, train_domains, 'train', 1, kwargs )
|
||||
elif args.acc_data_case== 'test':
|
||||
|
@ -233,6 +241,14 @@ for run in range(args.n_runs):
|
|||
test_dataset, base_res_dir,
|
||||
run, cuda
|
||||
)
|
||||
|
||||
elif args.test_metric == 'per_domain_acc':
|
||||
from evaluation.per_domain_acc import PerDomainAcc
|
||||
test_method= PerDomainAcc(
|
||||
args, train_dataset, val_dataset,
|
||||
test_dataset, base_res_dir,
|
||||
run, cuda
|
||||
)
|
||||
|
||||
elif args.test_metric == 'match_score':
|
||||
from evaluation.match_eval import MatchEval
|
||||
|
|
28
test_slab.py
28
test_slab.py
|
@ -40,17 +40,17 @@ def get_logits(model, loader, device, label=1):
|
|||
|
||||
# Input Parsing
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--dataset_name', type=str, default='slab',
|
||||
parser.add_argument('--dataset_name', type=str, default='rot_mnist',
|
||||
help='Datasets: rot_mnist; fashion_mnist; pacs')
|
||||
parser.add_argument('--method_name', type=str, default='erm_match',
|
||||
help=' Training Algorithm: erm_match; matchdg_ctr; matchdg_erm')
|
||||
parser.add_argument('--model_name', type=str, default='slab',
|
||||
parser.add_argument('--model_name', type=str, default='resnet18',
|
||||
help='Architecture of the model to be trained')
|
||||
parser.add_argument('--train_domains', nargs='+', type=str, default=["15", "30", "45", "60", "75"],
|
||||
help='List of train domains')
|
||||
parser.add_argument('--test_domains', nargs='+', type=str, default=["0", "90"],
|
||||
help='List of test domains')
|
||||
parser.add_argument('--out_classes', type=int, default=2,
|
||||
parser.add_argument('--out_classes', type=int, default=10,
|
||||
help='Total number of classes in the dataset')
|
||||
parser.add_argument('--img_c', type=int, default= 1,
|
||||
help='Number of channels of the image in dataset')
|
||||
|
@ -58,11 +58,6 @@ parser.add_argument('--img_h', type=int, default= 224,
|
|||
help='Height of the image in dataset')
|
||||
parser.add_argument('--img_w', type=int, default= 224,
|
||||
help='Width of the image in dataset')
|
||||
parser.add_argument('--slab_data_dim', type=int, default= 2,
|
||||
help='Number of features in the slab dataset')
|
||||
parser.add_argument('--slab_total_slabs', type=int, default=7)
|
||||
parser.add_argument('--slab_num_samples', type=int, default=1000)
|
||||
parser.add_argument('--slab_noise', type=float, default=0.0)
|
||||
parser.add_argument('--fc_layer', type=int, default= 1,
|
||||
help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet')
|
||||
parser.add_argument('--match_layer', type=str, default='logit_match',
|
||||
|
@ -89,6 +84,8 @@ parser.add_argument('--penalty_s', type=int, default=-1,
|
|||
help='Epoch threshold over which Matching Loss to be optimised')
|
||||
parser.add_argument('--penalty_irm', type=float, default=0.0,
|
||||
help='Penalty weight for IRM invariant classifier loss')
|
||||
parser.add_argument('--penalty_aug', type=float, default=1.0,
|
||||
help='Penalty weight for Augmentation in Hybrid approach loss')
|
||||
parser.add_argument('--penalty_ws', type=float, default=0.1,
|
||||
help='Penalty weight for Matching Loss')
|
||||
parser.add_argument('--penalty_diff_ctr',type=float, default=1.0,
|
||||
|
@ -152,6 +149,21 @@ parser.add_argument('--cuda_device', type=int, default=0,
|
|||
parser.add_argument('--os_env', type=int, default=0,
|
||||
help='0: Code execution on local server/machine; 1: Code execution in docker/clusters' )
|
||||
|
||||
#Slab Dataset
|
||||
parser.add_argument('--slab_data_dim', type=int, default= 2,
|
||||
help='Number of features in the slab dataset')
|
||||
parser.add_argument('--slab_total_slabs', type=int, default=7)
|
||||
parser.add_argument('--slab_num_samples', type=int, default=1000)
|
||||
parser.add_argument('--slab_noise', type=float, default=0.1)
|
||||
|
||||
#Differentiate between resnet, lenet, domainbed cases of mnist
|
||||
parser.add_argument('--mnist_case', type=str, default='resnet18',
|
||||
help='MNIST Dataset Case: resnet18; lenet, domainbed')
|
||||
|
||||
#Multiple random matches
|
||||
parser.add_argument('--total_matches_per_point', type=int, default=1,
|
||||
help='Multiple random matches')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
#GPU
|
||||
|
|
11
train.py
11
train.py
|
@ -42,11 +42,6 @@ parser.add_argument('--img_h', type=int, default= 224,
|
|||
help='Height of the image in dataset')
|
||||
parser.add_argument('--img_w', type=int, default= 224,
|
||||
help='Width of the image in dataset')
|
||||
parser.add_argument('--slab_data_dim', type=int, default= 2,
|
||||
help='Number of features in the slab dataset')
|
||||
parser.add_argument('--slab_total_slabs', type=int, default=7)
|
||||
parser.add_argument('--slab_num_samples', type=int, default=1000)
|
||||
parser.add_argument('--slab_noise', type=float, default=0.0)
|
||||
parser.add_argument('--fc_layer', type=int, default= 1,
|
||||
help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet')
|
||||
parser.add_argument('--match_layer', type=str, default='logit_match',
|
||||
|
@ -122,6 +117,12 @@ parser.add_argument('--grad_penalty', type=float, default=0.0)
|
|||
parser.add_argument('--conditional', type=int, default=1)
|
||||
parser.add_argument('--gaussian', type=int, default=1)
|
||||
|
||||
#Slab Dataset
|
||||
parser.add_argument('--slab_data_dim', type=int, default= 2,
|
||||
help='Number of features in the slab dataset')
|
||||
parser.add_argument('--slab_total_slabs', type=int, default=7)
|
||||
parser.add_argument('--slab_num_samples', type=int, default=1000)
|
||||
parser.add_argument('--slab_noise', type=float, default=0.1)
|
||||
|
||||
#Test Based Args
|
||||
parser.add_argument('--test_metric', type=str, default='match_score',
|
||||
|
|
Загрузка…
Ссылка в новой задаче