Added helper functions, data_loader and base_model classes
This commit is contained in:
Parent e0ff521887
Commit 965c78e5a0
@@ -8,19 +8,13 @@ import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms

class MnistRotated(data_utils.Dataset):
class MnistRotated(BaseDataLoader):
    def __init__(self, dataset_name, list_train_domains, mnist_subset, root, transform=None, data_case='train', download=True):
        self.dataset_name= dataset_name
        self.list_train_domains = list_train_domains

        super().__init__(dataset_name, list_train_domains, root, transform, data_case)
        self.mnist_subset = mnist_subset
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.data_case = data_case
        self.download = download
        self.base_domain_idx= -1
        self.base_domain_size= 0
        self.training_list_size=[]

        self.train_data, self.train_labels, self.train_domain, self.train_indices = self._get_data()

    def load_inds(self):
@@ -37,17 +31,8 @@ class MnistRotated(data_utils.Dataset):
            else:
                res= np.random.choice(60000, 1000)
                return res
                # res=[]
                # for subset in range(10):
                #     temp= np.load(self.root + '/supervised_inds_' + str(subset) + '.npy')
                #     res.append(temp)
                # res= np.array(res)
                # res= np.reshape(res, (res.shape[0]*res.shape[1]))
                # print(res.shape)
                # return res
        else:
            if self.dataset_name == 'rot_mnist':
                # data_dir= self.root + '/rot_mnist_lenet_indices'
                data_dir= self.root + '/rot_mnist_indices'
            elif self.dataset_name == 'fashion_mnist':
                data_dir= self.root + '/fashion_mnist_indices'
@@ -193,75 +178,3 @@ class MnistRotated(data_utils.Dataset):

        print(train_imgs.shape, train_labels.shape, train_domains.shape, train_indices.shape)
        return train_imgs.unsqueeze(1), train_labels, train_domains, train_indices

    def __len__(self):
        return self.train_labels.shape[0]

    def __getitem__(self, index):
        x = self.train_data[index]
        y = self.train_labels[index]
        d = self.train_domain[index]
        idx = self.train_indices[index]

        if self.transform is not None:
            x = self.transform(x)
        return x, y, d, idx

    def get_size(self):
        return self.train_labels.shape[0]

if __name__ == "__main__":
    from torchvision.utils import save_image

    seed = 1

    torch.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)

    list_train_domains = ['0', '15', '30', '45', '60']
    num_supervised = 1000

    train_loader = data_utils.DataLoader(
        MnistRotated(list_train_domains, num_supervised, seed, '../dataset/', train=True),
        batch_size=100,
        shuffle=False)

    y_array = np.zeros(10)
    d_array = np.zeros(5)

    for i, (x, y, d) in enumerate(train_loader):
        y_array += y.sum(dim=0).cpu().numpy()
        d_array += d.sum(dim=0).cpu().numpy()

        if i == 0:
            print(y)
            print(d)
            n = min(x.size(0), 8)
            comparison = x[:n].view(-1, 1, 16, 16)
            save_image(comparison.cpu(),
                       'reconstruction_rotation_train.png', nrow=n)

    print(y_array, d_array)

    test_loader = data_utils.DataLoader(
        MnistRotated(list_train_domains, seed, '../dataset/', train=False),
        batch_size=100,
        shuffle=False)

    y_array = np.zeros(10)
    d_array = np.zeros(5)

    for i, (x, y, d) in enumerate(test_loader):
        y_array += y.sum(dim=0).cpu().numpy()
        d_array += d.sum(dim=0).cpu().numpy()

        if i == 0:
            print(y)
            print(d)
            n = min(x.size(0), 8)
            comparison = x[:n].view(-1, 1, 16, 16)
            save_image(comparison.cpu(),
                       'reconstruction_rotation_test.png', nrow=n)

    print(y_array, d_array)
@@ -8,18 +8,12 @@ import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms

class MnistRotated(data_utils.Dataset):
class MnistRotated(BaseDataLoader):
    def __init__(self, dataset_name, list_train_domains, mnist_subset, root, transform=None, data_case='train', download=True):
        self.dataset_name= dataset_name
        self.list_train_domains = list_train_domains

        super().__init__(dataset_name, list_train_domains, root, transform, data_case)
        self.mnist_subset = mnist_subset
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.data_case = data_case
        self.download = download
        self.base_domain_idx= -1
        self.base_domain_size= 0
        self.training_list_size=[]

        self.train_data, self.train_labels, self.train_domain, self.train_indices = self._get_data()
@@ -32,18 +26,9 @@ class MnistRotated(data_utils.Dataset):
            else:
                res= np.random.choice(60000, 100)
                return res
                # res=[]
                # for subset in range(10):
                #     temp= np.load(self.root + '/supervised_inds_' + str(subset) + '.npy')
                #     res.append(temp)
                # res= np.array(res)
                # res= np.reshape(res, (res.shape[0]*res.shape[1]))
                # print(res.shape)
                # return res
        else:
            if self.dataset_name == 'rot_mnist':
                data_dir= self.root + '/rot_mnist_lenet_indices'
                # data_dir= self.root + '/rot_mnist_indices'

            if self.data_case != 'val':
                return np.load(data_dir + '/supervised_inds_' + str(self.mnist_subset) + '.npy')
@@ -180,75 +165,3 @@ class MnistRotated(data_utils.Dataset):

        print(train_imgs.shape, train_labels.shape, train_domains.shape, train_indices.shape)
        return train_imgs.unsqueeze(1), train_labels, train_domains, train_indices

    def __len__(self):
        return self.train_labels.shape[0]

    def __getitem__(self, index):
        x = self.train_data[index]
        y = self.train_labels[index]
        d = self.train_domain[index]
        idx = self.train_indices[index]

        if self.transform is not None:
            x = self.transform(x)
        return x, y, d, idx

    def get_size(self):
        return self.train_labels.shape[0]

if __name__ == "__main__":
    from torchvision.utils import save_image

    seed = 1

    torch.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)

    list_train_domains = ['0', '15', '30', '45', '60']
    num_supervised = 1000

    train_loader = data_utils.DataLoader(
        MnistRotated(list_train_domains, num_supervised, seed, '../dataset/', train=True),
        batch_size=100,
        shuffle=False)

    y_array = np.zeros(10)
    d_array = np.zeros(5)

    for i, (x, y, d) in enumerate(train_loader):
        y_array += y.sum(dim=0).cpu().numpy()
        d_array += d.sum(dim=0).cpu().numpy()

        if i == 0:
            print(y)
            print(d)
            n = min(x.size(0), 8)
            comparison = x[:n].view(-1, 1, 16, 16)
            save_image(comparison.cpu(),
                       'reconstruction_rotation_test.png', nrow=n)

    print(y_array, d_array)

    test_loader = data_utils.DataLoader(
        MnistRotated(list_train_domains, seed, '../dataset/', train=False),
        batch_size=100,
        shuffle=False)

    y_array = np.zeros(10)
    d_array = np.zeros(5)

    for i, (x, y, d) in enumerate(test_loader):
        y_array += y.sum(dim=0).cpu().numpy()
        d_array += d.sum(dim=0).cpu().numpy()

        if i == 0:
            print(y)
            print(d)
            n = min(x.size(0), 8)
            comparison = x[:n].view(-1, 1, 16, 16)
            save_image(comparison.cpu(),
                       'reconstruction_rotation_test.png', nrow=n)

    print(y_array, d_array)
@@ -0,0 +1,39 @@
import os
import random
import copy
import numpy as np
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms

class BaseDataLoader(data_utils.Dataset):
    def __init__(self, dataset_name, list_train_domains, root, transform=None, data_case='train'):
        self.dataset_name= dataset_name
        self.list_train_domains = list_train_domains
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.data_case = data_case
        self.base_domain_idx= -1
        self.base_domain_size= 0
        self.training_list_size=[]
        self.train_data= []
        self.train_labels= []
        self.train_domain= []
        self.train_indices= []

    def __len__(self):
        return self.train_labels.shape[0]

    def __getitem__(self, index):
        x = self.train_data[index]
        y = self.train_labels[index]
        d = self.train_domain[index]
        idx = self.train_indices[index]

        if self.transform is not None:
            x = self.transform(x)
        return x, y, d, idx

    def get_size(self):
        return self.train_labels.shape[0]
@@ -15,336 +15,7 @@ import torch.utils.data as data_utils

from sklearn.manifold import TSNE

def t_sne_plot(X):
    # X= X.view(X.shape[0], X.shape[1]*X.shape[2]*X.shape[3])
    X= X.detach().cpu().numpy()
    X= TSNE(n_components=2).fit_transform(X)
    return X

def classifier(x_e, phi, w):
    return torch.matmul(phi(x_e), w)

def erm_loss(temp_logits, target_label):
    #mse= torch.nn.MSELoss(reduction="none")
    #print(torch.argmax(temp_logits, dim=1), target_label)
    loss= F.cross_entropy(temp_logits, target_label.long()).to(cuda)
    return loss

def cosine_similarity( x1, x2 ):
    cos= torch.nn.CosineSimilarity(dim=1, eps=1e-08)
    return 1.0 - cos(x1, x2)

def l1_dist(x1, x2):

    # Broadcasting
    if len(x1.shape) == len(x2.shape) - 1:
        x1=x1.unsqueeze(1)
    if len(x2.shape) == len(x1.shape) - 1:
        x2=x2.unsqueeze(1)

    if len(x1.shape) == 3 and len(x2.shape) ==3:
        # Tensor shapes: (N,1,D) and (N,K,D) so x1-x2 would result in (N,K,D)
        return torch.sum( torch.sum(torch.abs(x1 - x2), dim=2) , dim=1 )
    elif len(x1.shape) ==2 and len(x2.shape) ==2:
        return torch.sum( torch.abs(x1 - x2), dim=1 )
    elif len(x1.shape) ==1 and len(x2.shape) ==1:
        return torch.sum( torch.abs(x1 - x2), dim=0 )
    else:
        print('Error: Expect 1, 2 or 3 rank tensors to compute L1 Norm')
        return
def l2_dist(x1, x2):

    # Broadcasting
    if len(x1.shape) == len(x2.shape) - 1:
        x1=x1.unsqueeze(1)
    if len(x2.shape) == len(x1.shape) - 1:
        x2=x2.unsqueeze(1)

    if len(x1.shape) == 3 and len(x2.shape) ==3:
        # Tensor shapes: (N,1,D) and (N,K,D) so x1-x2 would result in (N,K,D)
        return torch.sum( torch.sum((x1 - x2)**2, dim=2) , dim=1 )
    elif len(x1.shape) ==2 and len(x2.shape) ==2:
        return torch.sum( (x1 - x2)**2, dim=1 )
    elif len(x1.shape) ==1 and len(x2.shape) ==1:
        return torch.sum( (x1 - x2)**2, dim=0 )
    else:
        print('Error: Expect 1, 2 or 3 rank tensors to compute L2 Norm')
        return

def embedding_dist(x1, x2):
    if args.pos_metric == 'l1':
        return l1_dist(x1, x2)
    elif args.pos_metric == 'l2':
        return l2_dist(x1, x2)
    elif args.pos_metric == 'cos':
        return cosine_similarity( x1, x2 )

# def wasserstein_penalty( ):
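Note: a quick shape check of the broadcasting convention used by l1_dist and l2_dist above (a sketch with random tensors): a rank-2 anchor set against a rank-3 candidate set is unsqueezed to (N, 1, D), and the result is summed over both the feature and the candidate axes.

x1 = torch.randn(4, 8)       # N=4 anchors, D=8 features
x2 = torch.randn(4, 5, 8)    # K=5 candidates per anchor
out = l2_dist(x1, x2)        # x1 -> (4, 1, 8), broadcast against (4, 5, 8)
print(out.shape)             # torch.Size([4]): summed over dims 2 and 1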
def compute_penalty( model, feature, target_label, domain_label):
    curr_domains= np.unique(domain_label)
    ret= torch.tensor(0.).to(cuda)
    for domain in curr_domains:
        indices= domain_label == domain
        temp_logits= model(feature[indices])
        labels= target_label[indices]
        scale = torch.tensor(1.).to(cuda).requires_grad_()
        loss = F.cross_entropy(temp_logits*scale, labels.long()).to(cuda)
        g = grad(loss, [scale], create_graph=True)[0].to(cuda)
        # g is a 0-dim (scalar) tensor, so torch.sum is a no-op here, kept for safety
        ret+= torch.sum(g**2)

    return ret
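Note: the penalty above is the IRMv1-style gradient penalty: the per-domain risk is differentiated with respect to a dummy classifier scale fixed at 1.0, and the squared gradient is accumulated. Reduced to a single domain, the computation is (a self-contained sketch with random stand-ins):

logits = torch.randn(6, 10)                 # stand-in for model(feature[indices])
labels = torch.randint(0, 10, (6,))
scale = torch.tensor(1.).requires_grad_()
loss = F.cross_entropy(logits * scale, labels)
g = grad(loss, [scale], create_graph=True)[0]
penalty = torch.sum(g**2)                   # one term of the summed penalty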
def init_data_match_dict(keys, vals, variation):
    data={}
    for key in keys:
        data[key]={}
        if variation:
            val_dim= vals[key]
        else:
            val_dim= vals

        if args.dataset == 'color_mnist':
            data[key]['data']=torch.rand((val_dim, 2, 28, 28))
        elif args.dataset == 'rot_mnist' or args.dataset == 'fashion_mnist':
            data[key]['data']=torch.rand((val_dim, 1, 224, 224))
        elif args.dataset == 'pacs':
            data[key]['data']=torch.rand((val_dim, 3, 227, 227))

        data[key]['label']=torch.rand((val_dim, 1))
        data[key]['idx']=torch.randint(0, 1, (val_dim, 1))
    return data
def perfect_match_score(indices_matched):
    counter=0
    score=0
    for key in indices_matched:
        for match in indices_matched[key]:
            if key == match:
                score+=1
            counter+=1
    if counter:
        return 100*score/counter
    else:
        return 0
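Note: a worked example of the score with illustrative values: key 0 is matched to itself for two of its three matches and key 1 for one of its two, so the score is 100 * 3 / 5 = 60.0.

matches = {0: [0, 2, 0], 1: [1, 3]}
print(perfect_match_score(matches))   # 60.0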
def get_dataloader(train_data_obj, val_data_obj, test_data_obj):
    # Load supervised training
    train_dataset = data_utils.DataLoader(train_data_obj, batch_size=args.batch_size, shuffle=True, **kwargs )

    # Can select a higher batch size for val and test domains
    test_batch=512
    val_dataset = data_utils.DataLoader(val_data_obj, batch_size=test_batch, shuffle=True, **kwargs )
    test_dataset = data_utils.DataLoader(test_data_obj, batch_size=test_batch, shuffle=True, **kwargs )

    # elif args.dataset == 'color_mnist' or args.dataset =='rot_color_mnist':
    #     # Load supervised training
    #     train_dataset = data_utils.DataLoader( MnistRotated(train_domains, -1, 0.25, 1, 'data/rot_mnist', train=True), batch_size=args.batch_size, shuffle=True, **kwargs )
    #     test_dataset = data_utils.DataLoader( MnistRotated(test_domains, -1, 0.25, 1, 'data/rot_mnist', train=True), batch_size=args.batch_size, shuffle=True, **kwargs )

    return train_dataset, val_dataset, test_dataset
def get_matched_pairs(train_dataset, domain_size, total_domains, training_list_size, phi, match_case, inferred_match):

    # Making Data Matched pairs
    data_matched= init_data_match_dict( range(domain_size), total_domains, 0 )
    domain_data= init_data_match_dict( range(total_domains), training_list_size, 1)
    indices_matched={}
    for key in range(domain_size):
        indices_matched[key]=[]
    perfect_match_rank=[]

    domain_count={}
    for domain in range(total_domains):
        domain_count[domain]= 0

    # Create dictionary: class label -> list of ordered indices
    for batch_idx, (x_e, y_e ,d_e, idx_e) in enumerate(train_dataset):
        x_e= x_e
        y_e= torch.argmax(y_e, dim=1)
        d_e= torch.argmax(d_e, dim=1).numpy()

        domain_indices= np.unique(d_e)
        for domain_idx in domain_indices:
            indices= d_e == domain_idx
            ordered_indices= idx_e[indices]
            for idx in range(ordered_indices.shape[0]):
                # Matching points across domains
                perfect_indice= ordered_indices[idx].item()
                domain_data[domain_idx]['data'][perfect_indice]= x_e[indices][idx]
                domain_data[domain_idx]['label'][perfect_indice]= y_e[indices][idx]
                domain_data[domain_idx]['idx'][perfect_indice]= idx_e[indices][idx]
                domain_count[domain_idx]+= 1

    # Sanity Check: To check if the domain_data was updated for all the data points
    # for domain in range(total_domains):
    #     if domain_count[domain] != training_list_size[domain]:
    #         print('Error: Some data points are missing from domain_data dictionary')

    # Creating the random permutation tensor for each domain
    perm_size= int(domain_size*(1-match_case))
    # Determine the base_domain_idx as the domain with the max samples of the current class
    base_domain_dict={}
    for y_c in range(args.out_classes):
        base_domain_size=0
        base_domain_idx=-1
        for domain_idx in range(total_domains):
            class_idx= domain_data[domain_idx]['label'] == y_c
            curr_size= domain_data[domain_idx]['label'][class_idx].shape[0]
            if base_domain_size < curr_size:
                base_domain_size= curr_size
                base_domain_idx= domain_idx

        base_domain_dict[y_c]= base_domain_idx
        #print('Base Domain: ', base_domain_size, base_domain_idx, y_c )
    # Applying the random permutation tensor
    for domain_idx in range(total_domains):
        total_rand_counter=0
        total_data_idx=0
        for y_c in range(args.out_classes):
            base_domain_idx= base_domain_dict[y_c]

            indices_base= domain_data[base_domain_idx]['label'] == y_c
            indices_base= indices_base[:,0]
            ordered_base_indices= domain_data[base_domain_idx]['idx'][indices_base]

            indices_curr= domain_data[domain_idx]['label'] == y_c
            indices_curr= indices_curr[:,0]
            ordered_curr_indices= domain_data[domain_idx]['idx'][indices_curr]
            curr_size= ordered_curr_indices.shape[0]

            # Sanity check for perfect match case:
            # if args.perfect_match:
            #     if not torch.equal(ordered_base_indices, ordered_curr_indices):
            #         print('Error: Different indices across domains for perfect match' )

            # Only for the perfect match case to generate x% correct match strategy
            rand_base_indices= ordered_base_indices[ ordered_base_indices < perm_size ]
            idx_perm= torch.randperm( rand_base_indices.shape[0] )
            rand_base_indices= rand_base_indices[idx_perm]
            rand_counter=0

            base_feat_data=domain_data[base_domain_idx]['data'][indices_base]
            base_feat_data_split= torch.split( base_feat_data, args.batch_size, dim=0 )
            base_feat=[]
            for batch_feat in base_feat_data_split:
                with torch.no_grad():
                    batch_feat=batch_feat.to(cuda)
                    out= phi(batch_feat)
                    base_feat.append(out.cpu())
            base_feat= torch.cat(base_feat)
            if inferred_match:
                feat_x_data= domain_data[domain_idx]['data'][indices_curr]
                feat_x_data_split= torch.split(feat_x_data, args.batch_size, dim=0)
                feat_x=[]
                for batch_feat in feat_x_data_split:
                    with torch.no_grad():
                        batch_feat= batch_feat.to(cuda)
                        out= phi(batch_feat)
                        feat_x.append(out.cpu())
                feat_x= torch.cat(feat_x)

            base_feat= base_feat.unsqueeze(1)
            base_feat_split= torch.split(base_feat, args.batch_size, dim=0)

            data_idx=0
            for batch_feat in base_feat_split:

                if inferred_match:
                    # Need to compute over batches of base_feat due to CUDA out-of-memory errors.
                    # Otherwise there would be no need to loop over base_feat_split; we could simply compute feat_x - base_feat.
                    ws_dist= torch.sum( (feat_x - batch_feat)**2, dim=2)
                    match_idx= torch.argmin( ws_dist, dim=1 )
                    sort_val, sort_idx= torch.sort( ws_dist, dim=1 )
                    del ws_dist
                for idx in range(batch_feat.shape[0]):
                    perfect_indice= ordered_base_indices[data_idx].item()

                    if domain_idx == base_domain_idx:
                        curr_indice= perfect_indice
                    else:
                        if args.perfect_match:
                            if inferred_match:
                                curr_indice= ordered_curr_indices[match_idx[idx]].item()
                                # Find where the perfect match lies in the sorted order of matches.
                                # In situations where the perfect match is known, ordered_curr_indices and ordered_base_indices are the same.
                                perfect_match_rank.append( (ordered_curr_indices[sort_idx[idx]] == perfect_indice).nonzero()[0,0].item() )
                            else:
                                # To allow x% match case type permutations for datasets where the perfect match is known.
                                # In perfect match settings, the same ordered indice implies a perfect match across domains.
                                if perfect_indice < perm_size:
                                    curr_indice= rand_base_indices[rand_counter].item()
                                    rand_counter+=1
                                    total_rand_counter+=1
                                else:
                                    curr_indice= perfect_indice
                                indices_matched[perfect_indice].append(curr_indice)

                        else:
                            if inferred_match:
                                curr_indice= ordered_curr_indices[match_idx[idx]].item()
                            else:
                                curr_indice= ordered_curr_indices[data_idx%curr_size].item()

                    data_matched[total_data_idx]['data'][domain_idx]= domain_data[domain_idx]['data'][curr_indice]
                    data_matched[total_data_idx]['label'][domain_idx]= domain_data[domain_idx]['label'][curr_indice]
                    data_idx+=1
                    total_data_idx+=1

        # if total_data_idx != domain_size:
        #     print('Error: Some data points left from data_matched dictionary', total_data_idx, domain_size)

        # if args.perfect_match and inferred_match ==0 and domain_idx != base_domain_idx and total_rand_counter < perm_size:
        #     print('Error: Total random changes made are less than perm_size for domain', domain_idx, total_rand_counter, perm_size)
    # Sanity Check: N keys; K vals per key
    # for key in data_matched.keys():
    #     if data_matched[key]['label'].shape[0] != total_domains:
    #         print('Error with data matching')

    # Sanity Check: Ensure paired points have the same class label
    wrong_case=0
    for key in data_matched.keys():
        for d_i in range(data_matched[key]['label'].shape[0]):
            for d_j in range(data_matched[key]['label'].shape[0]):
                if d_j > d_i:
                    if data_matched[key]['label'][d_i] != data_matched[key]['label'][d_j]:
                        wrong_case+=1
    #print('Total Label MisMatch across pairs: ', wrong_case )

    data_match_tensor=[]
    label_match_tensor=[]
    for key in data_matched.keys():
        data_match_tensor.append( data_matched[key]['data'] )
        label_match_tensor.append( data_matched[key]['label'] )

    data_match_tensor= torch.stack( data_match_tensor )
    label_match_tensor= torch.stack( label_match_tensor )

    # Creating tensor of shape ( domain_size * total domains, feat size )
    # data_match_tensor= data_match_tensor.view( data_match_tensor.shape[0]*data_match_tensor.shape[1], data_match_tensor.shape[2], data_match_tensor.shape[3], data_match_tensor.shape[4] )
    # label_match_tensor= label_match_tensor.view( label_match_tensor.shape[0]*label_match_tensor.shape[1] )

    #print(data_match_tensor.shape, label_match_tensor.shape)

    del domain_data
    del data_matched

    return data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank


def test(test_dataset, phi, epoch):
    # Test Env Code
    test_acc= 0.0
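Note: stripped of the bookkeeping, the core of the inferred-match step in get_matched_pairs is a nearest-neighbour search in feature space; a minimal sketch with random features (names illustrative, not the function's API):

base_feat = torch.randn(32, 128)             # anchor features from the base domain
curr_feat = torch.randn(40, 128)             # candidate features from another domain

# (32, 1, 128) - (40, 128) broadcasts to (32, 40, 128)
dist = torch.sum((base_feat.unsqueeze(1) - curr_feat)**2, dim=2)
match_idx = torch.argmin(dist, dim=1)        # index of the closest candidate per anchor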
1368 main-train.py (file diff not shown because it is too large)
@@ -1,122 +0,0 @@
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable

import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.init as init

from collections import OrderedDict

__all__ = ['AlexNet', 'alexnet']

model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}


class Id(nn.Module):
    def __init__(self):
        super(Id, self).__init__()

    def forward(self, x):
        return x


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, dropout=True):
        super(AlexNet, self).__init__()

        self.features = nn.Sequential(OrderedDict([
            ("conv1", nn.Conv2d(3, 96, kernel_size=11, stride=4)),
            ("relu1", nn.ReLU(inplace=True)),
            ("pool1", nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)),
            ("norm1", nn.LocalResponseNorm(5, 1.e-4, 0.75)),
            ("conv2", nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2)),
            ("relu2", nn.ReLU(inplace=True)),
            ("pool2", nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)),
            ("norm2", nn.LocalResponseNorm(5, 1.e-4, 0.75)),
            ("conv3", nn.Conv2d(256, 384, kernel_size=3, padding=1)),
            ("relu3", nn.ReLU(inplace=True)),
            ("conv4", nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2)),
            ("relu4", nn.ReLU(inplace=True)),
            ("conv5", nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2)),
            ("relu5", nn.ReLU(inplace=True)),
            ("pool5", nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)),
        ]))

        self.classifier = nn.Sequential(OrderedDict([
            ("fc6", nn.Linear(256 * 6 * 6, 4096)),
            ("relu6", nn.ReLU(inplace=True)),
            ("drop6", nn.Dropout()),
            ("fc7", nn.Linear(4096, 4096)),
            ("relu7", nn.ReLU(inplace=True)),
            ("drop7", nn.Dropout()),
            ("fc8", nn.Linear(4096, num_classes))
        ]))

        self.initialize_params()

    def initialize_params(self):

        for layer in self.modules():
            #if isinstance(layer, torch.nn.Conv2d):
            #    init.kaiming_normal_(layer.weight, a=0, mode='fan_out')
            #    layer.bias.data.zero_()
            if isinstance(layer, torch.nn.Linear):
                init.xavier_uniform_(layer.weight, 0.1)
                layer.bias.data.zero_()
            #elif isinstance(layer, torch.nn.BatchNorm2d) or isinstance(layer, torch.nn.BatchNorm1d):
            #    layer.weight.data.fill_(1)
            #    layer.bias.data.zero_()

    def forward(self, x):
        x = self.features(x*57.6)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x


def alexnet(classes, pretrained=False, erm_base=1):
    r"""AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = AlexNet(classes, erm_base)

    if pretrained:
        state_dict = torch.load("models/alexnet_caffe.pth.tar")
        del state_dict["classifier.fc8.weight"]
        del state_dict["classifier.fc8.bias"]
        model.load_state_dict(state_dict, strict = False)

    module=[]
    if erm_base==0:
        for idx in range(4):
            layer= model.classifier[idx]
            module.append(layer)
        model.classifier= nn.Sequential( *module )
    print(model)
    return model


class ClfNet(nn.Module):
    def __init__(self, rep_net, rep_dim, out_dim):
        super(ClfNet, self).__init__()
        self.rep_net= rep_net
        self.erm_net=nn.Sequential(
            nn.Linear(rep_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, x):
        out= self.rep_net(x)
        return self.erm_net(out)
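Note: a sketch of how this (now removed) ClfNet wrapper composed a representation network with the small ERM head; the wiring below is hypothetical but follows this file's own conventions (erm_base=0 truncates the classifier after fc7, whose output width is 4096):

rep = alexnet(4096, pretrained=False, erm_base=0)    # AlexNet trimmed to the fc7 features
clf = ClfNet(rep_net=rep, rep_dim=4096, out_dim=7)   # e.g. 7 PACS classes
logits = clf(torch.randn(2, 3, 227, 227))            # shape (2, 7)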
@@ -7,7 +7,6 @@ from torchvision.utils import save_image
from torch.autograd import Variable
from torchvision.models.resnet import ResNet, BasicBlock

# Defining the network (LeNet-5)
class LeNet5(torch.nn.Module):

@@ -26,9 +25,10 @@ class LeNet5(torch.nn.Module):
            # Max-pooling
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.predict_fc_net= nn.Sequential(
            # Fully connected layer
            # convert matrix with 16*5*5 (= 400) features to a matrix of 120 features (columns)
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            # convert matrix with 120 features to a matrix of 84 features (columns)

@@ -43,25 +43,4 @@ class LeNet5(torch.nn.Module):
        out= self.predict_conv_net(x)
        out= out.view(-1,out.shape[1]*out.shape[2]*out.shape[3])
        out= self.predict_fc_net(out)
        return out

class ClfNet(nn.Module):
    def __init__(self, rep_net, rep_dim, out_dim):
        super(ClfNet, self).__init__()
        self.rep_net= rep_net
        self.erm_net=nn.Sequential(
            nn.Linear(rep_dim, out_dim)
            # nn.Linear(rep_dim, 200),
            # nn.BatchNorm1d(200),
            # nn.Dropout(),
            # nn.ReLU(),
            # nn.Linear(200, 100),
            # nn.BatchNorm1d(100),
            # nn.Dropout(),
            # nn.ReLU(),
            # nn.Linear(100, out_dim)
        )

    def forward(self, x):
        out= self.rep_net(x)
        return self.erm_net(out)
        return out
107 models/ResNet.py
@@ -3,88 +3,6 @@ from torch.utils import model_zoo
import torchvision
from torchvision.models.resnet import BasicBlock, model_urls, Bottleneck

class ResNet(nn.Module):
    def __init__(self, block, layers, jigsaw_classes=1000, classes=100, domains=3):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.jigsaw_classifier = nn.Linear(512 * block.expansion, jigsaw_classes)
        self.class_classifier = nn.Linear(512 * block.expansion, classes)
        #self.domain_classifier = nn.Linear(512 * block.expansion, domains)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def is_patch_based(self):
        return False

    def forward(self, x, **kwargs):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.jigsaw_classifier(x), self.class_classifier(x)


def resnet18(pretrained=True, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)
    return model

def resnet50(pretrained=True, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)
    return model


# bypass layer
class Identity(nn.Module):
    def __init__(self,n_inputs):
@@ -100,8 +18,6 @@ def get_resnet(model_name, classes, erm_base, num_ch, pre_trained):
        model= torchvision.models.resnet18(pre_trained)
        n_inputs = model.fc.in_features
        n_outputs= classes
        #print(n_inputs)
        # model.fc = nn.Linear(n_inputs, n_outputs)
        if erm_base:
            model.fc = nn.Linear(n_inputs, n_outputs)
        else:

@@ -117,25 +33,4 @@ def get_resnet(model_name, classes, erm_base, num_ch, pre_trained):
                              padding=(3, 3),
                              bias=False)

    return model

class ClfNet(nn.Module):
    def __init__(self, rep_net, rep_dim, out_dim):
        super(ClfNet, self).__init__()
        self.rep_net= rep_net
        self.erm_net=nn.Sequential(
            nn.Linear(rep_dim, out_dim)
            # nn.Linear(rep_dim, 200),
            # nn.BatchNorm1d(200),
            # nn.Dropout(),
            # nn.ReLU(),
            # nn.Linear(200, 100),
            # nn.BatchNorm1d(100),
            # nn.Dropout(),
            # nn.ReLU(),
            # nn.Linear(100, out_dim)
        )

    def forward(self, x):
        out= self.rep_net(x)
        return self.erm_net(out)
        return model
@@ -0,0 +1,35 @@
import os
import argparse

# Input Parsing
parser = argparse.ArgumentParser(description='Evaluation')
parser.add_argument('--dataset', type=str, default='rot_mnist')
parser.add_argument('--domain_abl', type=int, default=0, help='0: No Abl; x: Train with x domains only')
parser.add_argument('--penalty_ws', type=float, default=0.1)
parser.add_argument('--match_case', type=float, default=1, help='0.01: Random Match; 1: Perfect Match')
parser.add_argument('--match_dg', type=int, default=0, help='0: ERM, ERM+Match; 1: MatchDG Phase 1; 2: MatchDG Phase 2')

args = parser.parse_args()

# Get results for ERM, ERM_RandomMatch, ERM_PerfectMatch
if args.match_dg == 0:
    base_script= "python3 main-train.py --lr 0.01 --epochs 15 --batch_size 16 --penalty_w 0.0 --penalty_s -1 --penalty_same_ctr 0.0 --penalty_diff_ctr 0.0 --penalty_erm 1.0 --same_margin 1.0 --diff_margin 100.0 --save_logs 0 --test_domain 1 --seed -1 --match_flag 0 --match_interrupt 35 --pre_trained 1 --method_name phi_match --out_classes 10 --n_runs 3 --pos_metric l2 --model mnist --perfect_match 1 --erm_base 1 --ctr_phase 0 --erm_phase 0 --penalty_ws_erm 0.0"

    script= base_script + ' --match_case ' + str(args.match_case) + ' --dataset ' + str(args.dataset) + ' --domain_abl ' + str(args.domain_abl) + ' --penalty_ws ' + str(args.penalty_ws)
    os.system(script)

# Train MatchDG
else:

    base_script= "python3 main-train.py --lr 0.01 --epochs 30 --batch_size 64 --penalty_w 0.0 --penalty_s -1 --penalty_ws 0.0 --penalty_same_ctr 0.0 --penalty_diff_ctr 1.0 --penalty_erm 1.0 --same_margin 1.0 --diff_margin 1.5 --save_logs 0 --test_domain 1 --seed -1 --match_flag 1 --match_interrupt 5 --pre_trained 1 --method_name phi_match --pos_metric cos --out_classes 10 --n_runs 2 --model mnist --perfect_match 1 --erm_base 0 --ctr_phase 1 --erm_phase 0 --epochs_erm 25 --penalty_ws_erm 0.1 --match_case_erm -1 --opt sgd"

    script = base_script + ' --domain_abl ' + str(args.domain_abl) + ' --dataset ' + str(args.dataset) + ' --match_case ' + str(args.match_case)
    os.system(script)

    # Don't need MatchDG Phase 2 for Perfect Match Seed
    if args.match_case == 0.01:

        base_script= "python3 main-train.py --lr 0.01 --epochs 30 --batch_size 16 --penalty_w 0.0 --penalty_s -1 --penalty_ws 0.0 --penalty_same_ctr 0.0 --penalty_diff_ctr 1.0 --penalty_erm 1.0 --same_margin 1.0 --diff_margin 1.5 --save_logs 0 --test_domain 1 --seed -1 --match_case 0.01 --match_flag 1 --match_interrupt 5 --pre_trained 1 --method_name phi_match --pos_metric cos --out_classes 10 --n_runs 2 --model mnist --perfect_match 1 --erm_base 0 --ctr_phase 0 --erm_phase 1 --epochs_erm 15 --penalty_ws_erm 0.1 --match_case_erm -1 --opt sgd"

        script= base_script + ' --dataset ' + str(args.dataset) + ' --domain_abl ' + str(args.domain_abl)
        os.system(script)
829 train.py
@@ -1,35 +1,822 @@
import os
import sys
import numpy as np
import argparse
import copy
import random
import json

import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

from sklearn.manifold import TSNE
def train(train_dataset, data_match_tensor, label_match_tensor, phi, opt, opt_ws, scheduler, epoch, base_domain_idx, bool_erm, bool_ws, bool_ctr):

    penalty_erm=0
    penalty_erm_2=0
    penalty_irm=0
    penalty_ws=0
    penalty_same_ctr=0
    penalty_diff_ctr=0
    penalty_same_hinge=0
    penalty_diff_hinge=0
    train_acc= 0.0
    train_size=0

    perm = torch.randperm(data_match_tensor.size(0))
    data_match_tensor_split= torch.split(data_match_tensor[perm], args.batch_size, dim=0)
    label_match_tensor_split= torch.split(label_match_tensor[perm], args.batch_size, dim=0)
    print('Split Matched Data: ', len(data_match_tensor_split), data_match_tensor_split[0].shape, len(label_match_tensor_split))
    # Loop Over One Environment
    for batch_idx, (x_e, y_e ,d_e, idx_e) in enumerate(train_dataset):
        # print('Batch Idx: ', batch_idx)

        opt.zero_grad()
        opt_ws.zero_grad()

        x_e= x_e.to(cuda)
        y_e= torch.argmax(y_e, dim=1).to(cuda)
        d_e= torch.argmax(d_e, dim=1).numpy()

        # Forward Pass
        out= phi(x_e)
        loss_e= torch.tensor(0.0).to(cuda)
        penalty_e= torch.tensor(0.0).to(cuda)

        if bool_erm:
            # torch.mean is not really required here, since mean reduction is the default in the ERM loss
            if args.method_name in ['erm', 'irm'] or ( args.erm_phase==1 and args.match_case_erm == -1 ):
                loss_e= erm_loss(out, y_e)
            else:
                #####
                ## Experimenting for now: keep the standard ERM loss for all the different methods
                ####
                loss_e= 0*erm_loss(out, y_e)

            # penalty_e= compute_penalty(phi, x_e, y_e, d_e)
            # penalty_erm+= float(loss_e)
            # penalty_irm+= float(penalty_e)

            # weight_norm = torch.tensor(0.).to(cuda)
            # for w in phi.erm_net[-1].parameters():
            #     weight_norm += w.norm().pow(2)

            if epoch > anneal_iter:
                loss_e+= lmd*penalty_e
                if lmd > 1.0:
                    # Rescale the entire loss to keep gradients in a reasonable range
                    loss_e /= lmd

            # loss_e+=0.001*weight_norm
        wasserstein_loss=torch.tensor(0.0).to(cuda)
        erm_loss_2= torch.tensor(0.0).to(cuda)
        same_ctr_loss = torch.tensor(0.0).to(cuda)
        diff_ctr_loss = torch.tensor(0.0).to(cuda)
        same_hinge_loss = torch.tensor(0.0).to(cuda)
        diff_hinge_loss = torch.tensor(0.0).to(cuda)
        if epoch > anneal_iter and args.method_name in ['rep_match', 'phi_match', 'phi_match_abl']:
            # sample_size= args.batch_size
            # perm = torch.randperm(data_match_tensor.size(0))
            # idx = perm[:sample_size]

            # To cover the varying size of the last batch for data_match_tensor_split, label_match_tensor_split
            total_batch_size= len(data_match_tensor_split)
            if batch_idx >= total_batch_size:
                break
            curr_batch_size= data_match_tensor_split[batch_idx].shape[0]

            # data_match= data_match_tensor[idx].to(cuda)
            data_match= data_match_tensor_split[batch_idx].to(cuda)
            data_match= data_match.view( data_match.shape[0]*data_match.shape[1], data_match.shape[2], data_match.shape[3], data_match.shape[4] )
            feat_match= phi( data_match )

            # label_match= label_match_tensor[idx].to(cuda)
            label_match= label_match_tensor_split[batch_idx].to(cuda)
            label_match= label_match.view( label_match.shape[0]*label_match.shape[1] )

            if bool_erm:
                erm_loss_2+= erm_loss(feat_match, label_match)
                penalty_erm_2+= float(erm_loss_2)

            if args.method_name=="rep_match":
                temp_out= phi.predict_conv_net( data_match )
                temp_out= temp_out.view(-1, temp_out.shape[1]*temp_out.shape[2]*temp_out.shape[3])
                feat_match= phi.predict_fc_net(temp_out)
                del temp_out

            # Creating tensor of shape ( domain size, total domains, feat size )
            if len(feat_match.shape) == 4:
                feat_match= feat_match.view( curr_batch_size, len(train_domains), feat_match.shape[1]*feat_match.shape[2]*feat_match.shape[3] )
            else:
                feat_match= feat_match.view( curr_batch_size, len(train_domains), feat_match.shape[1] )

            label_match= label_match.view( curr_batch_size, len(train_domains) )

            # print(feat_match.shape)
            data_match= data_match.view( curr_batch_size, len(train_domains), data_match.shape[1], data_match.shape[2], data_match.shape[3] )
            # Positive Match Loss
            if bool_ws:
                pos_match_counter=0
                for d_i in range(feat_match.shape[1]):
                    # if d_i != base_domain_idx:
                    #     continue
                    for d_j in range(feat_match.shape[1]):
                        if d_j > d_i:

                            if args.erm_phase:
                                wasserstein_loss+= torch.sum( torch.sum( (feat_match[:, d_i, :] - feat_match[:, d_j, :])**2, dim=1 ) )
                            else:
                                if args.pos_metric == 'l2':
                                    wasserstein_loss+= torch.sum( torch.sum( (feat_match[:, d_i, :] - feat_match[:, d_j, :])**2, dim=1 ) )
                                elif args.pos_metric == 'l1':
                                    wasserstein_loss+= torch.sum( torch.sum( torch.abs(feat_match[:, d_i, :] - feat_match[:, d_j, :]), dim=1 ) )
                                elif args.pos_metric == 'cos':
                                    wasserstein_loss+= torch.sum( cosine_similarity( feat_match[:, d_i, :], feat_match[:, d_j, :] ) )

                            pos_match_counter += feat_match.shape[0]

                wasserstein_loss = wasserstein_loss / pos_match_counter
                penalty_ws+= float(wasserstein_loss)
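Note: in isolation, the positive-match term above is just an average pairwise distance over all ordered domain pairs of matched points; a minimal sketch of the l2 branch (random features, illustrative only):

feat = torch.randn(16, 3, 64)    # (matched points, domains, feature dim)
loss, count = torch.tensor(0.0), 0
for d_i in range(feat.shape[1]):
    for d_j in range(feat.shape[1]):
        if d_j > d_i:
            loss = loss + torch.sum(torch.sum((feat[:, d_i, :] - feat[:, d_j, :])**2, dim=1))
            count += feat.shape[0]
loss = loss / count              # mean matched-pair distance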
            # Contrastive Loss
            if bool_ctr:
                same_neg_counter=1
                diff_neg_counter=1
                for y_c in range(args.out_classes):

                    pos_indices= label_match[:, 0] == y_c
                    neg_indices= label_match[:, 0] != y_c
                    pos_feat_match= feat_match[pos_indices]
                    neg_feat_match= feat_match[neg_indices]

                    if pos_feat_match.shape[0] > neg_feat_match.shape[0]:
                        print('Weird! Positive Matches are more than the negative matches?', pos_feat_match.shape[0], neg_feat_match.shape[0])

                    # If there are no instances of label y_c in the current batch then continue
                    if pos_feat_match.shape[0] ==0 or neg_feat_match.shape[0] == 0:
                        continue

                    # Iterating over anchors from different domains
                    for d_i in range(pos_feat_match.shape[1]):
                        if torch.sum( torch.isnan(neg_feat_match) ):
                            print('Non Reshaped X2 is Nan')
                            sys.exit()

                        diff_neg_feat_match= neg_feat_match.view( neg_feat_match.shape[0]*neg_feat_match.shape[1], neg_feat_match.shape[2] )

                        if torch.sum( torch.isnan(diff_neg_feat_match) ):
                            print('Reshaped X2 is Nan')
                            sys.exit()

                        neg_dist= embedding_dist( pos_feat_match[:, d_i, :], diff_neg_feat_match[:, :], args.tau, xent=True)
                        if torch.sum(torch.isnan(neg_dist)):
                            print('Neg Dist Nan')
                            sys.exit()

                        # Iterating over pos dist for the current anchor
                        for d_j in range(pos_feat_match.shape[1]):
                            if d_i != d_j:
                                pos_dist= 1.0 - embedding_dist( pos_feat_match[:, d_i, :], pos_feat_match[:, d_j, :] )
                                pos_dist= pos_dist / args.tau
                                if torch.sum(torch.isnan(pos_dist)):
                                    print('Pos Dist Nan')
                                    sys.exit()

                                if torch.sum( torch.isnan( torch.log( torch.exp(pos_dist) + neg_dist ) ) ):
                                    print('Xent Nan')
                                    sys.exit()

                                # print( 'Pos Dist', pos_dist )
                                # print( 'Log Dist ', torch.log( torch.exp(pos_dist) + neg_dist ))
                                diff_hinge_loss+= -1*torch.sum( pos_dist - torch.log( torch.exp(pos_dist) + neg_dist ) )
                                diff_ctr_loss+= torch.sum(neg_dist)
                                diff_neg_counter+= pos_dist.shape[0]

                same_ctr_loss = same_ctr_loss / same_neg_counter
                diff_ctr_loss = diff_ctr_loss / diff_neg_counter
                same_hinge_loss = same_hinge_loss / same_neg_counter
                diff_hinge_loss = diff_hinge_loss / diff_neg_counter

                penalty_same_ctr+= float(same_ctr_loss)
                penalty_diff_ctr+= float(diff_ctr_loss)
                penalty_same_hinge+= float(same_hinge_loss)
                penalty_diff_hinge+= float(diff_hinge_loss)
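Note: diff_hinge_loss above follows an NT-Xent-style form: a temperature-scaled positive similarity is contrasted against the summed exponentiated similarities of the negatives. A self-contained sketch of that objective (assuming, as the xent=True call suggests, that neg_dist holds the summed exp-similarities):

tau = 0.05
anchor, positive = torch.randn(8, 64), torch.randn(8, 64)
negatives = torch.randn(8, 20, 64)

pos_sim = F.cosine_similarity(anchor, positive, dim=1) / tau                  # (8,)
neg_sim = F.cosine_similarity(anchor.unsqueeze(1), negatives, dim=2) / tau    # (8, 20)
neg_exp_sum = torch.exp(neg_sim).sum(dim=1)                                   # (8,)
loss = -(pos_sim - torch.log(torch.exp(pos_sim) + neg_exp_sum)).mean()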
        if args.erm_base:
            if epoch >= args.match_interrupt and args.match_flag==1:
                loss_e += ( args.penalty_ws*( epoch - anneal_iter - args.match_interrupt )/(args.epochs - anneal_iter - args.match_interrupt) )*wasserstein_loss
                loss_e += ( args.penalty_same_ctr*( epoch - anneal_iter - args.match_interrupt )/(args.epochs - anneal_iter - args.match_interrupt) )*same_hinge_loss
                loss_e += ( args.penalty_diff_ctr*( epoch - anneal_iter - args.match_interrupt )/(args.epochs - anneal_iter - args.match_interrupt) )*diff_hinge_loss
            else:
                loss_e += ( args.penalty_ws*( epoch-anneal_iter )/(args.epochs -anneal_iter) )*wasserstein_loss
                loss_e += ( args.penalty_same_ctr*( epoch-anneal_iter )/(args.epochs -anneal_iter) )*same_hinge_loss
                loss_e += ( args.penalty_diff_ctr*( epoch-anneal_iter )/(args.epochs -anneal_iter) )*diff_hinge_loss

            loss_e += args.penalty_erm*erm_loss_2

        # No CTR and No Match Update Case here
        elif args.erm_phase:
            loss_e += ( args.penalty_ws_erm*( epoch-anneal_iter )/(args.epochs_erm -anneal_iter) )*wasserstein_loss
            loss_e += args.penalty_erm*erm_loss_2

        elif args.ctr_phase:
            if epoch >= args.match_interrupt:
                loss_e += ( args.penalty_ws*( epoch - anneal_iter - args.match_interrupt )/(args.epochs - anneal_iter - args.match_interrupt) )*wasserstein_loss
                # loss_e += ( args.penalty_same_ctr*( epoch - anneal_iter - args.match_interrupt )/(args.epochs - anneal_iter - args.match_interrupt) )*same_hinge_loss
                loss_e += ( args.penalty_diff_ctr*( epoch-anneal_iter )/(args.epochs -anneal_iter) )*diff_hinge_loss

        loss_e.backward(retain_graph=False)
        opt.step()

        # opt.zero_grad()
        # opt_ws.zero_grad()
        # wasserstein_loss.backward()
        # opt_ws.step()

        del penalty_e
        del erm_loss_2
        del wasserstein_loss
        del same_ctr_loss
        del diff_ctr_loss
        del same_hinge_loss
        del diff_hinge_loss
        del loss_e
        torch.cuda.empty_cache()

        if bool_erm:
            train_acc+= torch.sum(torch.argmax(out, dim=1) == y_e ).item()
            train_size+= y_e.shape[0]

    print('Train Loss Basic : ', penalty_erm, penalty_irm, penalty_ws, penalty_erm_2 )
    print('Train Loss Ctr : ', penalty_same_ctr, penalty_diff_ctr, penalty_same_hinge, penalty_diff_hinge)
    if bool_erm:
        print('Train Acc Env : ', 100*train_acc/train_size )
    print('Done Training for epoch: ', epoch)

    # scheduler.step()

    return penalty_erm_2, penalty_irm, penalty_ws, penalty_same_ctr, penalty_diff_ctr
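Note: the penalty weights in train() follow a simple linear ramp from the anneal point to the final epoch; factored out, the schedule is (a sketch with hypothetical names):

def anneal_weight(penalty, epoch, start, total):
    # Linear ramp: 0 at epoch == start, the full penalty weight at epoch == total.
    return penalty * (epoch - start) / (total - start)

# e.g. the wasserstein term: anneal_weight(args.penalty_ws, epoch, anneal_iter, args.epochs)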
def test(test_dataset, phi, epoch, case='Test'):
    # Test Env Code
    test_acc= 0.0
    test_size=0

    for batch_idx, (x_e, y_e ,d_e, idx_e) in enumerate(test_dataset):
        with torch.no_grad():
            x_e= x_e.to(cuda)
            y_e= torch.argmax(y_e, dim=1).to(cuda)
            d_e = torch.argmax(d_e, dim=1).numpy()
            #print(type(x_e), x_e.shape, y_e.shape, d_e.shape)

            # Forward Pass
            out= phi(x_e)
            loss_e= torch.mean(erm_loss(out, y_e))

            test_acc+= torch.sum( torch.argmax(out, dim=1) == y_e ).item()
            test_size+= y_e.shape[0]
            #print('Test Loss Env : ', loss_e)

    print( case + ' Accuracy: Epoch ', epoch, 100*test_acc/test_size )

    return (100*test_acc/test_size)
# Input Parsing
parser = argparse.ArgumentParser(description='Evaluation')
parser = argparse.ArgumentParser(description='PACS')
parser.add_argument('--dataset', type=str, default='rot_mnist')
parser.add_argument('--method_name', type=str, default='erm')
parser.add_argument('--pos_metric', type=str, default='l2')
parser.add_argument('--model_name', type=str, default='alexnet')
parser.add_argument('--opt', type=str, default='sgd')
parser.add_argument('--out_classes', type=int, default=10)
parser.add_argument('--rep_dim', type=int, default=250)
parser.add_argument('--test_domain', type=int, default=0, help='0: In angles; 1: out angles')
parser.add_argument('--perfect_match', type=int, default=0, help='0: No perfect match known (PACS); 1: perfect match known (MNIST)')
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--epochs', type=int, default=25)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--penalty_w', type=float, default=100)
parser.add_argument('--penalty_s', type=int, default=5)
parser.add_argument('--penalty_ws', type=float, default=0.001)
parser.add_argument('--penalty_same_ctr',type=float, default=0.001)
parser.add_argument('--penalty_diff_ctr',type=float, default=0.001)
parser.add_argument('--penalty_erm',type=float, default=1)
parser.add_argument('--same_margin', type=float, default=1.0)
parser.add_argument('--diff_margin', type=float, default=1.0)
parser.add_argument('--epochs_erm', type=int, default=25)
parser.add_argument('--n_runs_erm', type=int, default=2)
parser.add_argument('--penalty_ws_erm', type=float, default=0.1)
parser.add_argument('--match_case_erm', type=float, default=1.0)
parser.add_argument('--pre_trained',type=int, default=1)
parser.add_argument('--match_flag', type=int, default=1, help='0: No Update to Match Strategy; 1: Updates to Match Strategy')
parser.add_argument('--match_case', type=float, default=1, help='0: Random Match; 1: Perfect Match')
parser.add_argument('--match_interrupt', type=int, default=10)
parser.add_argument('--base_domain_idx', type=int, default=1)
parser.add_argument('--ctr_abl', type=int, default=0, help='0: Randomization til class level ; 1: Randomization completely')
parser.add_argument('--match_abl', type=int, default=0, help='0: Randomization til class level ; 1: Randomization completely')
parser.add_argument('--domain_abl', type=int, default=0, help='0: No Abl; x: Train with x domains only')
parser.add_argument('--penalty_ws', type=float, default=0.1)
parser.add_argument('--match_case', type=float, default=1, help='0.01: Random Match; 1: Perfect Match')
parser.add_argument('--match_dg', type=int, default=0, help='0: ERM, ERM+Match; 1: MatchDG Phase 1; 2: MatchDG Phase 2')

parser.add_argument('--erm_base', type=int, default=1, help='0: ERM loss added gradually; 1: ERM weight constant')
parser.add_argument('--ctr_phase', type=int, default=1, help='0: No Metric Learning; 1: Metric Learning')
parser.add_argument('--erm_phase', type=int, default=1, help='0: No ERM Learning; 1: ERM Learning')
parser.add_argument('--domain_gen', type=int, default=1)
parser.add_argument('--save_logs', type=int, default=0)
parser.add_argument('--n_runs', type=int, default=3)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--tau', type=float, default=0.05)
parser.add_argument('--retain', type=float, default=0, help='0: Train from scratch in ERM Phase; 1: Finetune from CTR Phase in ERM Phase')
parser.add_argument('--cuda_device', type=int, default=0 )
args = parser.parse_args()

# Get results for ERM, ERM_RandomMatch, ERM_PerfectMatch
if args.match_dg == 0:
    base_script= "python3 main-train.py --lr 0.01 --epochs 15 --batch_size 16 --penalty_w 0.0 --penalty_s -1 --penalty_same_ctr 0.0 --penalty_diff_ctr 0.0 --penalty_erm 1.0 --same_margin 1.0 --diff_margin 100.0 --save_logs 0 --test_domain 1 --seed -1 --match_flag 0 --match_interrupt 35 --pre_trained 1 --method_name phi_match --out_classes 10 --n_runs 3 --pos_metric l2 --model mnist --perfect_match 1 --erm_base 1 --ctr_phase 0 --erm_phase 0 --penalty_ws_erm 0.0"
if args.dataset == 'color_mnist':
    from models.color_mnist import *
    from models.ResNet import *
    from data.color_mnist.mnist_loader import MnistRotated
if args.dataset == 'rot_color_mnist':
    from models.rot_mnist import *
    from data.rot_color_mnist.mnist_loader import MnistRotated
# elif args.dataset == 'fashion_mnist':
#     from models.rot_mnist import *
#     from data.rot_fashion_mnist.fashion_mnist_loader import MnistRotated
elif args.dataset == 'rot_mnist' or args.dataset == 'fashion_mnist':
    # from models.rot_mnist import *
    # from models.metric_rot_mnist import *
    if args.model_name == 'lenet':
        from models.LeNet import *
        from models.ResNet import *
        from data.rot_mnist.mnist_loader_lenet import MnistRotated
        # from data.rot_mnist.mnist_loader import MnistRotated
    else:
        from models.ResNet import *
        from data.rot_mnist.mnist_loader import MnistRotated
elif args.dataset == 'pacs':
    from models.AlexNet import *
    from models.ResNet import *
    from data.pacs.pacs_loader import PACS

# GPU
cuda= torch.device("cuda:" + str(args.cuda_device))

    script= base_script + ' --match_case ' + str(args.match_case) + ' --dataset ' + str(args.dataset) + ' --domain_abl ' + str(args.domain_abl) + ' --penalty_ws ' + str(args.penalty_ws)
    os.system(script)
# Environments
base_data_dir='data/' + args.dataset +'/'
base_logs_dir="results/" + args.dataset +'/'
base_res_dir="results/" + args.dataset + '/'

# Train MatchDG
if cuda:
    kwargs = {'num_workers': 1, 'pin_memory': False}
else:

    base_script= "python3 main-train.py --lr 0.01 --epochs 30 --batch_size 64 --penalty_w 0.0 --penalty_s -1 --penalty_ws 0.0 --penalty_same_ctr 0.0 --penalty_diff_ctr 1.0 --penalty_erm 1.0 --same_margin 1.0 --diff_margin 1.5 --save_logs 0 --test_domain 1 --seed -1 --match_flag 1 --match_interrupt 5 --pre_trained 1 --method_name phi_match --pos_metric cos --out_classes 10 --n_runs 2 --model mnist --perfect_match 1 --erm_base 0 --ctr_phase 1 --erm_phase 0 --epochs_erm 25 --penalty_ws_erm 0.1 --match_case_erm -1 --opt sgd"
    kwargs= {}

    script = base_script + ' --domain_abl ' + str(args.domain_abl) + ' --dataset ' + str(args.dataset) + ' --match_case ' + str(args.match_case)
    os.system(script)
if args.dataset == 'rot_mnist' or args.dataset == 'fashion_mnist':

    # Don't need MatchDG Phase 2 for the perfect match seed
    if args.match_case == 0.01:

        if args.model_name == 'lenet':
            #Train and Test Domains
            if args.test_domain == 0:
                test_domains = ["0"]
            elif args.test_domain == 1:
                test_domains = ["15"]
            elif args.test_domain == 2:
                test_domains = ["30"]
            elif args.test_domain == 3:
                test_domains = ["45"]
            elif args.test_domain == 4:
                test_domains = ["60"]
            elif args.test_domain == 5:
                test_domains = ["75"]

            if args.domain_abl == 0:
                train_domains = ["0", "15", "30", "45", "60", "75"]
            elif args.domain_abl == 2:
                train_domains = ["30", "45"]
            elif args.domain_abl == 3:
                train_domains = ["30", "45", "60"]

            for angle in test_domains:
                if angle in train_domains:
                    train_domains.remove(angle)

            base_script = "python3 main-train.py --lr 0.01 --epochs 30 --batch_size 16 --penalty_w 0.0 --penalty_s -1 --penalty_ws 0.0 --penalty_same_ctr 0.0 --penalty_diff_ctr 1.0 --penalty_erm 1.0 --same_margin 1.0 --diff_margin 1.5 --save_logs 0 --test_domain 1 --seed -1 --match_case 0.01 --match_flag 1 --match_interrupt 5 --pre_trained 1 --method_name phi_match --pos_metric cos --out_classes 10 --n_runs 2 --model mnist --perfect_match 1 --erm_base 0 --ctr_phase 0 --erm_phase 1 --epochs_erm 15 --penalty_ws_erm 0.1 --match_case_erm -1 --opt sgd"
        else:
            #Train and Test Domains
            if args.test_domain == 0:
                test_domains = ["30", "45"]
            elif args.test_domain == 1:
                test_domains = ["0", "90"]
            elif args.test_domain == 2:
                test_domains = ["45"]
            elif args.test_domain == 3:
                test_domains = ["0"]

            script = base_script + ' --dataset ' + str(args.dataset) + ' --domain_abl ' + str(args.domain_abl)
            os.system(script)

            if args.domain_abl == 0:
                train_domains = ["0", "15", "30", "45", "60", "75", "90"]
            elif args.domain_abl == 2:
                train_domains = ["30", "45"]
            elif args.domain_abl == 3:
                train_domains = ["30", "45", "60"]
            # train_domains = ["0", "30", "60", "90"]

            for angle in test_domains:
                if angle in train_domains:
                    train_domains.remove(angle)
elif args.dataset == 'color_mnist' or args.dataset == 'rot_color_mnist':
    train_domains = [0.1, 0.2]
    test_domains = [0.9]

elif args.dataset == 'pacs':
    #Train and Test Domains
    if args.test_domain == 0:
        test_domains = ["photo"]
    elif args.test_domain == 1:
        test_domains = ["art_painting"]
    elif args.test_domain == 2:
        test_domains = ["cartoon"]
    elif args.test_domain == 3:
        test_domains = ["sketch"]
    elif args.test_domain == -1:
        test_domains = ["sketch"]

    train_domains = ["photo", "art_painting", "cartoon", "sketch"]
    for angle in test_domains:
        if angle in train_domains:
            train_domains.remove(angle)

    # With test_domain == -1, train and test on the same (sketch) domain
    if args.test_domain == -1:
        train_domains = test_domains

final_report_accuracy = []
for run in range(args.n_runs):

    #Seed for reproducibility
    torch.manual_seed(run*10)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(run*10)

    # Path to save results
    post_string = str(args.penalty_erm) + '_' + str(args.penalty_ws) + '_' + str(args.penalty_same_ctr) + '_' + str(args.penalty_diff_ctr) + '_' + str(args.rep_dim) + '_' + str(args.match_case) + '_' + str(args.match_interrupt) + '_' + str(args.match_flag) + '_' + str(args.test_domain) + '_' + str(run) + '_' + args.pos_metric + '_' + args.model_name

    # Parameters
    feature_dim = 28*28
    rep_dim = args.rep_dim
    num_classes = args.out_classes
    pre_trained = args.pre_trained

    if args.dataset in ['rot_mnist', 'color_mnist', 'fashion_mnist']:
        feature_dim = 28*28
        num_ch = 1
        pre_trained = 0
        # phi = RepNet(feature_dim, rep_dim)
        if args.erm_base:
            if args.model_name == 'lenet':
                phi = LeNet5().to(cuda)
            else:
                phi = get_resnet('resnet18', num_classes, args.erm_base, num_ch, pre_trained).to(cuda)
        else:
            rep_dim = 512
            phi = get_resnet('resnet18', rep_dim, args.erm_base, num_ch, pre_trained).to(cuda)
    elif args.dataset in ['pacs', 'vlcs']:
        if args.model_name == 'alexnet':
            if args.erm_base:
                phi = alexnet(num_classes, pre_trained, args.erm_base).to(cuda)
            else:
                rep_dim = 4096
                phi = alexnet(rep_dim, pre_trained, args.erm_base).to(cuda)
        elif args.model_name == 'resnet18':
            num_ch = 3
            if args.erm_base:
                phi = get_resnet('resnet18', num_classes, args.erm_base, num_ch, pre_trained).to(cuda)
            else:
                rep_dim = 512
                phi = get_resnet('resnet18', rep_dim, args.erm_base, num_ch, pre_trained).to(cuda)
    print('Model Architecture: ', args.model_name)

    # Ensure that rep_dim matches the architecture:
    # for alexnet/resnet, the rep dim is predetermined by the second-to-last layer
    phi_erm = ClfNet(phi, rep_dim, num_classes).to(cuda)
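For orientation, a minimal sketch of what the ClfNet wrapper presumably looks like; the real class lives in the models package and is not part of this hunk, so this is an assumption. The attribute names rep_net and erm_net are taken from the commented-out optimizer further below:

import torch.nn as nn

class ClfNetSketch(nn.Module):
    # Hypothetical stand-in for ClfNet: representation network plus a linear classification head
    def __init__(self, rep_net, rep_dim, num_classes):
        super().__init__()
        self.rep_net = rep_net
        self.erm_net = nn.Linear(rep_dim, num_classes)

    def forward(self, x):
        return self.erm_net(self.rep_net(x))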

    #Main Code
    epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.lr
    lmd = args.penalty_w
    anneal_iter = args.penalty_s

    if args.opt == 'sgd':
        opt = optim.SGD([
            {'params': filter(lambda p: p.requires_grad, phi.parameters())},
        ], lr=learning_rate, weight_decay=5e-4, momentum=0.9, nesterov=True)
    elif args.opt == 'adam':
        opt = optim.Adam([
            {'params': filter(lambda p: p.requires_grad, phi.parameters())},
        ], lr=learning_rate)

    patience = 25
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=patience)

    if args.model_name == 'alexnet':
        opt_ws = optim.SGD([
            {'params': filter(lambda p: p.requires_grad, phi.classifier[-1].parameters()), 'lr': learning_rate, 'weight_decay': 1e-5, 'momentum': 0.9},
        ])
        opt_all = optim.SGD([
            {'params': filter(lambda p: p.requires_grad, phi.classifier[-1].parameters()), 'lr': learning_rate, 'weight_decay': 1e-5, 'momentum': 0.9},
        ])
    elif args.model_name == 'resnet18' or args.model_name == 'resnet50':
        opt_ws = optim.SGD([
            {'params': filter(lambda p: p.requires_grad, phi.fc.parameters()), 'lr': learning_rate, 'weight_decay': 1e-5, 'momentum': 0.9},
        ])
        opt_all = optim.SGD([
            {'params': filter(lambda p: p.requires_grad, phi.fc.parameters()), 'lr': learning_rate, 'weight_decay': 1e-5, 'momentum': 0.9},
        ])
    else:
        opt_ws = opt
        opt_all = opt
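The filter(lambda p: p.requires_grad, ...) pattern above only has an effect once part of the network is frozen. A minimal self-contained sketch of that head-only fine-tuning setup (illustrative only; not taken from this commit):

import torch.optim as optim
import torchvision

backbone = torchvision.models.resnet18(pretrained=True)
for p in backbone.parameters():
    p.requires_grad = False      # freeze the representation layers
for p in backbone.fc.parameters():
    p.requires_grad = True       # keep only the final layer trainable

# Only the fc parameters survive the requires_grad filter
opt_head = optim.SGD([
    {'params': filter(lambda p: p.requires_grad, backbone.parameters()),
     'lr': 0.01, 'weight_decay': 1e-5, 'momentum': 0.9}])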

    # opt_all = optim.SGD([
    #     {'params': filter(lambda p: p.requires_grad, phi.features.parameters()), 'lr': learning_rate/100, 'weight_decay': 5e-4, 'momentum': 0.9 },
    #     {'params': filter(lambda p: p.requires_grad, phi.classifier[-1].parameters()), 'lr': learning_rate, 'weight_decay': 5e-4, 'momentum': 0.9 },
    #     {'params': filter(lambda p: p.requires_grad, phi.classifier[-4].parameters()), 'lr': learning_rate, 'weight_decay': 5e-4, 'momentum': 0.9 },
    #     {'params': filter(lambda p: p.requires_grad, phi.classifier[-7].parameters()), 'lr': learning_rate, 'weight_decay': 5e-4, 'momentum': 0.9 },
    #     {'params': filter(lambda p: p.requires_grad, phi.classifier[-10].parameters()), 'lr': learning_rate, 'weight_decay': 5e-4, 'momentum': 0.9 },
    #     {'params': filter(lambda p: p.requires_grad, phi.classifier[-12].parameters()), 'lr': learning_rate/100, 'weight_decay': 5e-4, 'momentum': 0.9 },
    #     {'params': filter(lambda p: p.requires_grad, phi.classifier[-15].parameters()), 'lr': learning_rate/100, 'weight_decay': 5e-4, 'momentum': 0.9 }
    # ] )

    # opt = optim.Adam([
    #     # {'params': filter(lambda p: p.requires_grad, phi.features.parameters()) },
    #     {'params': filter(lambda p: p.requires_grad, phi_erm.rep_net.classifier[-1].parameters()) },
    #     {'params': filter(lambda p: p.requires_grad, phi_erm.erm_net.parameters()) }
    # ], lr=learning_rate)

    loss_erm = []
    loss_irm = []
    loss_ws = []
    loss_same_ctr = []
    loss_diff_ctr = []
    match_diff = []
    match_acc = []
    match_rank = []
    match_top_k = []
    final_acc = []
    val_acc = []

    match_flag = args.match_flag
    match_interrupt = args.match_interrupt
    base_domain_idx = args.base_domain_idx
    match_counter = 0

    # DataLoader
    if args.dataset in ['pacs', 'vlcs']:
        ## TODO: Change the dataloader of PACS to incorporate the val indices
        train_data_obj = PACS(train_domains, '/pacs/train_val_splits/', data_case='train')
        val_data_obj = PACS(train_domains, '/pacs/train_val_splits/', data_case='val')
        test_data_obj = PACS(test_domains, '/pacs/train_val_splits/', data_case='test')
    elif args.dataset in ['rot_mnist', 'fashion_mnist']:
        train_data_obj = MnistRotated(args.dataset, train_domains, 3+run, 'data/rot_mnist', data_case='train')
        val_data_obj = MnistRotated(args.dataset, train_domains, 3+run, 'data/rot_mnist', data_case='val')
        test_data_obj = MnistRotated(args.dataset, test_domains, 3+run, 'data/rot_mnist', data_case='test')

    train_dataset, val_dataset, test_dataset = get_dataloader(train_data_obj, val_data_obj, test_data_obj)

    total_domains = len(train_domains)
    domain_size = train_data_obj.base_domain_size
    base_domain_idx = train_data_obj.base_domain_idx
    training_list_size = train_data_obj.training_list_size
    print('Train Domains, Domain Size, BaseDomainIdx, Total Domains: ', train_domains, domain_size, base_domain_idx, total_domains, training_list_size)

    # Either end-to-end training (erm_base) or the contrastive representation learning phase (ctr_phase)
    if args.erm_base or args.ctr_phase:

        for epoch in range(epochs):

            if epoch % match_interrupt == 0:
                #Start with the initially defined match; else find the local approximate match
                if epoch > 0:
                    inferred_match = 1
                    if args.match_flag and match_counter < 100:
                        data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank = get_matched_pairs(args, train_dataset, domain_size, total_domains, training_list_size, phi, args.match_case, inferred_match)
                        match_counter += 1
                        #Reset the weights after every match strategy update
                        # phi = RotMNIST(feature_dim, num_classes).to(cuda)
                        # opt = optim.Adam([
                        #     {'params': filter(lambda p: p.requires_grad, phi.predict_conv_net.parameters()) },
                        #     {'params': filter(lambda p: p.requires_grad, phi.predict_fc_net.parameters()) },
                        #     {'params': filter(lambda p: p.requires_grad, phi.predict_final_net.parameters()) }
                        # ], lr=learning_rate)
                    elif args.match_flag == 0 or match_counter >= 1:
                        temp_1, temp_2, indices_matched, perfect_match_rank = get_matched_pairs(args, train_dataset, domain_size, total_domains, training_list_size, phi, args.match_case, inferred_match)

                    perfect_match_rank = np.array(perfect_match_rank)
                    if args.perfect_match:
                        print('Mean Perfect Match Score: ', np.mean(perfect_match_rank), 100*np.sum(perfect_match_rank < 10)/perfect_match_rank.shape[0])
                        match_rank.append(np.mean(perfect_match_rank))
                        match_top_k.append(100*np.sum(perfect_match_rank < 10)/perfect_match_rank.shape[0])

                else:
                    inferred_match = 0
                    data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank = get_matched_pairs(args, train_dataset, domain_size, total_domains, training_list_size, phi, args.match_case, inferred_match)

                ## To ensure a random match keeps happening after every match interrupt
                # data_match_tensor, label_match_tensor, indices_matched = get_matched_pairs(args, train_dataset, domain_size, total_domains, base_domain_idx, args.match_case)
                if args.perfect_match:
                    score = perfect_match_score(indices_matched)
                    print('Perfect Match Score: ', score)
                    match_acc.append(score)

            # Decide which losses to optimize: end-to-end (erm_base == 1) or block-wise (erm_base == 0)
            if args.erm_base:
                bool_erm = 1
                bool_ws = 1
                bool_ctr = 0
            else:
                bool_erm = 0
                bool_ws = 1
                bool_ctr = 1

            # To decide till which layer to fine-tune
            if epoch > -1:
                penalty_erm, penalty_irm, penalty_ws, penalty_same_ctr, penalty_diff_ctr = train(train_dataset, data_match_tensor, label_match_tensor, phi, opt, opt_ws, scheduler, epoch, base_domain_idx, bool_erm, bool_ws, bool_ctr)
            else:
                penalty_erm, penalty_irm, penalty_ws, penalty_same_ctr, penalty_diff_ctr = train(train_dataset, data_match_tensor, label_match_tensor, phi, opt_all, opt_ws, scheduler, epoch, base_domain_idx, bool_erm, bool_ws, bool_ctr)

            loss_erm.append(penalty_erm)
            loss_irm.append(penalty_irm)
            loss_ws.append(penalty_ws)
            loss_same_ctr.append(penalty_same_ctr)
            loss_diff_ctr.append(penalty_diff_ctr)

            if bool_erm:
                #Validation Phase
                test_acc = test(val_dataset, phi, epoch, 'Val')
                val_acc.append(test_acc)
                #Testing Phase
                test_acc = test(test_dataset, phi, epoch, 'Test')
                final_acc.append(test_acc)

        loss_erm = np.array(loss_erm)
        loss_irm = np.array(loss_irm)
        loss_ws = np.array(loss_ws)
        loss_same_ctr = np.array(loss_same_ctr)
        loss_diff_ctr = np.array(loss_diff_ctr)
        final_acc = np.array(final_acc)
        val_acc = np.array(val_acc)
        match_rank = np.array(match_rank)
        match_top_k = np.array(match_top_k)

        if args.erm_base:
            if args.domain_abl == 0:
                sub_dir = '/ERM_Base'
            elif args.domain_abl == 2:
                sub_dir = '/ERM_Base/' + train_domains[0] + '_' + train_domains[1]
            elif args.domain_abl == 3:
                sub_dir = '/ERM_Base/' + train_domains[0] + '_' + train_domains[1] + '_' + train_domains[2]
        elif args.ctr_phase:
            if args.domain_abl == 0:
                sub_dir = '/CTR'
            elif args.domain_abl == 2:
                sub_dir = '/CTR/' + train_domains[0] + '_' + train_domains[1]
            elif args.domain_abl == 3:
                sub_dir = '/CTR/' + train_domains[0] + '_' + train_domains[1] + '_' + train_domains[2]

        np.save(base_res_dir + args.method_name + sub_dir + '/ERM_' + post_string + '.npy', loss_erm)
        np.save(base_res_dir + args.method_name + sub_dir + '/WS_' + post_string + '.npy', loss_ws)
        np.save(base_res_dir + args.method_name + sub_dir + '/S_CTR_' + post_string + '.npy', loss_same_ctr)
        np.save(base_res_dir + args.method_name + sub_dir + '/D_CTR_' + post_string + '.npy', loss_diff_ctr)
        np.save(base_res_dir + args.method_name + sub_dir + '/ACC_' + post_string + '.npy', final_acc)
        np.save(base_res_dir + args.method_name + sub_dir + '/Val_' + post_string + '.npy', val_acc)

        if args.perfect_match:
            np.save(base_res_dir + args.method_name + sub_dir + '/Match_Acc_' + post_string + '.npy', match_acc)
            np.save(base_res_dir + args.method_name + sub_dir + '/Match_Rank_' + post_string + '.npy', match_rank)
            np.save(base_res_dir + args.method_name + sub_dir + '/Match_TopK_' + post_string + '.npy', match_top_k)

        # Store the weights of the model
        torch.save(phi.state_dict(), base_res_dir + args.method_name + sub_dir + '/Model_' + post_string + '.pth')

        # Final Report Accuracy
        if args.erm_base:
            final_report_accuracy.append(final_acc[-1])

    if args.erm_phase:

        for run_erm in range(args.n_runs_erm):

            # Load the representation network from the saved weights
            sub_dir = '/CTR'
            save_path = base_res_dir + args.method_name + sub_dir + '/Model_' + post_string + '.pth'
            phi.load_state_dict(torch.load(save_path))
            phi.eval()

            #Inferred Match Case
            if args.match_case_erm == -1:
                inferred_match = 1
                data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank = get_matched_pairs(args, train_dataset, domain_size, total_domains, training_list_size, phi, args.match_case_erm, inferred_match)

                if args.perfect_match:
                    score = perfect_match_score(indices_matched)
                    print('Perfect Match Score: ', score)
                    perfect_match_rank = np.array(perfect_match_rank)
                    print('Mean Perfect Match Score: ', np.mean(perfect_match_rank), 100*np.sum(perfect_match_rank < 10)/perfect_match_rank.shape[0])

            else:
                inferred_match = 0
                # x% match initial strategy
                data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank = get_matched_pairs(args, train_dataset, domain_size, total_domains, training_list_size, phi, args.match_case_erm, inferred_match)

                if args.perfect_match:
                    score = perfect_match_score(indices_matched)
                    print('Perfect Match Score: ', score)
                    perfect_match_rank = np.array(perfect_match_rank)
                    print('Mean Perfect Match Score: ', np.mean(perfect_match_rank), 100*np.sum(perfect_match_rank < 10)/perfect_match_rank.shape[0])

            # Model and parameters
            if args.retain:
                phi_erm = ClfNet(phi, rep_dim, num_classes).to(cuda)
            else:
                if args.dataset in ['rot_mnist', 'color_mnist', 'fashion_mnist']:
                    feature_dim = 28*28
                    num_ch = 1
                    pre_trained = 0
                    if args.model_name == 'lenet':
                        phi_erm = LeNet5().to(cuda)
                    else:
                        phi_erm = get_resnet('resnet18', num_classes, 1, num_ch, pre_trained).to(cuda)
                elif args.dataset in ['pacs', 'vlcs']:
                    if args.model_name == 'alexnet':
                        phi_erm = alexnet(num_classes, pre_trained, 1).to(cuda)
                    elif args.model_name == 'resnet18':
                        num_ch = 3
                        phi_erm = get_resnet('resnet18', num_classes, 1, num_ch, pre_trained).to(cuda)

            learning_rate = args.lr
            opt = optim.SGD([
                {'params': filter(lambda p: p.requires_grad, phi_erm.parameters())}
            ], lr=learning_rate, weight_decay=5e-4, momentum=0.9)

            #Training and Evaluation
            final_acc = []
            val_acc = []
            for epoch in range(args.epochs_erm):

                #Train Specifications
                bool_erm = 1
                bool_ws = 1
                bool_ctr = 0
                #Train
                penalty_erm, penalty_irm, penalty_ws, penalty_same_ctr, penalty_diff_ctr = train(train_dataset, data_match_tensor, label_match_tensor, phi_erm, opt, opt_ws, scheduler, epoch, base_domain_idx, bool_erm, bool_ws, bool_ctr)
                #Test
                #Validation Phase
                test_acc = test(val_dataset, phi_erm, epoch, 'Val')
                val_acc.append(test_acc)
                #Testing Phase
                test_acc = test(test_dataset, phi_erm, epoch, 'Test')
                final_acc.append(test_acc)

            post_string_erm = str(args.penalty_erm) + '_' + str(args.penalty_ws) + '_' + str(args.penalty_same_ctr) + '_' + str(args.penalty_diff_ctr) + '_' + str(args.rep_dim) + '_' + str(args.match_case) + '_' + str(args.match_interrupt) + '_' + str(args.match_flag) + '_' + str(args.test_domain) + '_' + str(run) + '_' + args.pos_metric + '_' + args.model_name + '_' + str(args.penalty_ws_erm) + '_' + str(args.match_case_erm) + '_' + str(run_erm)

            final_acc = np.array(final_acc)

            if args.domain_abl == 0:
                sub_dir = '/ERM'
            elif args.domain_abl == 2:
                sub_dir = '/ERM/' + train_domains[0] + '_' + train_domains[1]
            elif args.domain_abl == 3:
                sub_dir = '/ERM/' + train_domains[0] + '_' + train_domains[1] + '_' + train_domains[2]

            np.save(base_res_dir + args.method_name + sub_dir + '/ACC_' + post_string_erm + '.npy', final_acc)

            # Store the weights of the model
            torch.save(phi_erm.state_dict(), base_res_dir + args.method_name + sub_dir + '/Model_' + post_string_erm + '.pth')

            # Final Report Accuracy
            if args.erm_phase:
                final_report_accuracy.append(final_acc[-1])

if args.erm_base or args.erm_phase:
    print('\n')
    print('Done for the Model..')
    print('Final Test Accuracy', np.mean(final_report_accuracy), np.std(final_report_accuracy))
    print('\n')

@@ -0,0 +1,149 @@
# Module-level imports assumed by the helpers below (the extracted hunk omits
# the top of this file); `args`, `cuda` and `kwargs` are globals supplied by
# the training script.
import sys

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data as data_utils
from torch.autograd import grad
from sklearn.manifold import TSNE

def t_sne_plot(X):
    X = X.detach().cpu().numpy()
    X = TSNE(n_components=2).fit_transform(X)
    return X

def classifier(x_e, phi, w):
    return torch.matmul(phi(x_e), w)

def erm_loss(temp_logits, target_label):
    loss = F.cross_entropy(temp_logits, target_label.long()).to(cuda)
    return loss

def cosine_similarity(x1, x2):
    # Returns a distance: 1 - cos(x1, x2)
    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-08)
    return 1.0 - cos(x1, x2)

def l1_dist(x1, x2):

    #Broadcasting
    if len(x1.shape) == len(x2.shape) - 1:
        x1 = x1.unsqueeze(1)
    if len(x2.shape) == len(x1.shape) - 1:
        x2 = x2.unsqueeze(1)

    if len(x1.shape) == 3 and len(x2.shape) == 3:
        # Tensor shapes: (N,1,D) and (N,K,D), so x1-x2 broadcasts to (N,K,D)
        return torch.sum(torch.sum(torch.abs(x1 - x2), dim=2), dim=1)
    elif len(x1.shape) == 2 and len(x2.shape) == 2:
        return torch.sum(torch.abs(x1 - x2), dim=1)
    elif len(x1.shape) == 1 and len(x2.shape) == 1:
        return torch.sum(torch.abs(x1 - x2), dim=0)
    else:
        print('Error: Expect rank 1, 2 or 3 tensors to compute L1 Norm')
        return

def l2_dist(x1, x2):

    #Broadcasting
    if len(x1.shape) == len(x2.shape) - 1:
        x1 = x1.unsqueeze(1)
    if len(x2.shape) == len(x1.shape) - 1:
        x2 = x2.unsqueeze(1)

    if len(x1.shape) == 3 and len(x2.shape) == 3:
        # Tensor shapes: (N,1,D) and (N,K,D), so x1-x2 broadcasts to (N,K,D)
        return torch.sum(torch.sum((x1 - x2)**2, dim=2), dim=1)
    elif len(x1.shape) == 2 and len(x2.shape) == 2:
        return torch.sum((x1 - x2)**2, dim=1)
    elif len(x1.shape) == 1 and len(x2.shape) == 1:
        return torch.sum((x1 - x2)**2, dim=0)
    else:
        print('Error: Expect rank 1, 2 or 3 tensors to compute L2 Norm')
        return
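A minimal shape sketch (illustrative tensors, not from the commit) of how these distance helpers broadcast a batch of anchors against several candidates per anchor:

import torch

x1 = torch.randn(8, 16)      # (N, D): 8 anchor embeddings
x2 = torch.randn(8, 5, 16)   # (N, K, D): 5 candidate embeddings per anchor
# x1 is unsqueezed to (8, 1, 16); (x1 - x2)**2 broadcasts to (8, 5, 16),
# then the sums over D and K reduce it to a rank-1 tensor of size (8,)
dists = l2_dist(x1, x2)
print(dists.shape)           # torch.Size([8])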

def embedding_dist(x1, x2, tau=0.05, xent=False):

    if xent:
        #X1 denotes the batch of anchors while X2 denotes all the negative matches
        #Broadcasting to compute the loss for each anchor over all the negative matches

        #Only implemented for rank-2 tensors x1, x2
        if len(x1.shape) != 2 or len(x2.shape) != 2:
            print('Error: both should be rank 2 tensors for NT-Xent loss computation')

        #Normalizing each vector
        ## Take care to reshape the norm: for an (N,D) tensor the norm is (N,),
        ## which must be reshaped to (N,1) so that row-wise L2 normalization takes place
        if torch.sum(torch.isnan(x1)):
            print('X1 is nan')
            sys.exit()

        if torch.sum(torch.isnan(x2)):
            print('X2 is nan')
            sys.exit()

        eps = 1e-8

        norm = x1.norm(dim=1)
        norm = norm.view(norm.shape[0], 1)
        temp = eps*torch.ones_like(norm)

        x1 = x1/torch.max(norm, temp)

        if torch.sum(torch.isnan(x1)):
            print('X1 Norm is nan')
            sys.exit()

        norm = x2.norm(dim=1)
        norm = norm.view(norm.shape[0], 1)
        temp = eps*torch.ones_like(norm)

        x2 = x2/torch.max(norm, temp)

        if torch.sum(torch.isnan(x2)):
            print('Norm: ', norm, x2)
            print('X2 Norm is nan')
            sys.exit()

        # Broadcasting the anchor vectors to compute the loss over all negative matches
        x1 = x1.unsqueeze(1)
        cos_sim = torch.sum(x1*x2, dim=2)
        cos_sim = cos_sim / tau

        if torch.sum(torch.isnan(cos_sim)):
            print('Cos is nan')
            sys.exit()

        loss = torch.sum(torch.exp(cos_sim), dim=1)

        if torch.sum(torch.isnan(loss)):
            print('Loss is nan')
            sys.exit()

        return loss

    else:
        if args.pos_metric == 'l1':
            return l1_dist(x1, x2)
        elif args.pos_metric == 'l2':
            return l2_dist(x1, x2)
        elif args.pos_metric == 'cos':
            return cosine_similarity(x1, x2)
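For the xent branch, a small self-contained sketch (illustrative) of the expected shapes: anchors are (N, D), negatives are (K, D), and the result is the per-anchor sum of exponentiated scaled cosine similarities, i.e. the denominator terms of an NT-Xent style contrastive loss:

import torch

anchors = torch.randn(4, 128)     # (N, D)
negatives = torch.randn(16, 128)  # (K, D)
# After row-wise normalization, anchors become (4, 1, 128) and broadcast
# against (16, 128) into cosine similarities of shape (4, 16);
# the exp/sum over dim=1 returns a rank-1 tensor of size (4,)
den = embedding_dist(anchors, negatives, tau=0.05, xent=True)
print(den.shape)                  # torch.Size([4])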

def compute_penalty(model, feature, target_label, domain_label):
    curr_domains = np.unique(domain_label)
    ret = torch.tensor(0.).to(cuda)
    for domain in curr_domains:
        indices = domain_label == domain
        temp_logits = model(feature[indices])
        labels = target_label[indices]
        scale = torch.tensor(1.).to(cuda).requires_grad_()
        loss = F.cross_entropy(temp_logits*scale, labels.long()).to(cuda)
        g = grad(loss, [scale], create_graph=True)[0].to(cuda)
        # Since g is a scalar output, do we need torch.sum?
        ret += torch.sum(g**2)

    return ret
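This has the shape of the IRMv1 gradient penalty: per training domain, the squared gradient of the cross-entropy risk with respect to a fixed scalar classifier multiplier held at 1.0. A hedged usage sketch with dummy tensors; the model below is hypothetical and `cuda` is the device global from the training script:

import numpy as np
import torch

model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(32*32, 10)).to(cuda)
feature = torch.randn(32, 1, 32, 32).to(cuda)
target_label = torch.randint(0, 10, (32,)).to(cuda)
domain_label = np.random.randint(0, 3, 32)   # three training domains
penalty = compute_penalty(model, feature, target_label, domain_label)  # scalar tensor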

def get_dataloader(train_data_obj, val_data_obj, test_data_obj):
    # Load supervised training data
    train_dataset = data_utils.DataLoader(train_data_obj, batch_size=args.batch_size, shuffle=True, **kwargs)

    # Can select a higher batch size for the val and test domains
    test_batch = 512
    val_dataset = data_utils.DataLoader(val_data_obj, batch_size=test_batch, shuffle=True, **kwargs)
    test_dataset = data_utils.DataLoader(test_data_obj, batch_size=test_batch, shuffle=True, **kwargs)

    return train_dataset, val_dataset, test_dataset
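A brief usage sketch: the loaders wrap the dataset objects built in the training script, and each batch yields an (image, label, domain, index) tuple, matching the unpacking in get_matched_pairs below. `train_data_obj` and friends are assumed to come from the training script:

train_loader, val_loader, test_loader = get_dataloader(train_data_obj, val_data_obj, test_data_obj)
for x_e, y_e, d_e, idx_e in train_loader:
    print(x_e.shape, y_e.shape, d_e.shape, idx_e.shape)
    break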

@@ -0,0 +1,224 @@
# Module-level imports assumed by the matching utilities below; `args` and
# `cuda` are globals supplied by the training script.
import numpy as np
import torch

def perfect_match_score(indices_matched):
    counter = 0
    score = 0
    for key in indices_matched:
        for match in indices_matched[key]:
            if key == match:
                score += 1
            counter += 1
    if counter:
        return 100*score/counter
    else:
        return 0
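A small worked example with a hypothetical match dictionary: data point 0 matched to indices [0, 0, 2] and point 1 to [1, 5, 5] gives 3 self-matches out of 6 total matches, i.e. a score of 50.0:

indices_matched = {0: [0, 0, 2], 1: [1, 5, 5]}
print(perfect_match_score(indices_matched))   # 50.0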

def init_data_match_dict(args, keys, vals, variation):
    data = {}
    for key in keys:
        data[key] = {}
        if variation:
            val_dim = vals[key]
        else:
            val_dim = vals

        if args.dataset == 'color_mnist':
            data[key]['data'] = torch.rand((val_dim, 2, 28, 28))
        elif args.dataset == 'rot_mnist' or args.dataset == 'fashion_mnist':
            if args.model_name == 'lenet':
                data[key]['data'] = torch.rand((val_dim, 1, 32, 32))
            elif args.model_name == 'resnet18':
                data[key]['data'] = torch.rand((val_dim, 1, 224, 224))
        elif args.dataset == 'pacs':
            data[key]['data'] = torch.rand((val_dim, 3, 227, 227))

        data[key]['label'] = torch.rand((val_dim, 1))
        data[key]['idx'] = torch.randint(0, 1, (val_dim, 1))
    return data
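A sketch of the resulting structure for a hypothetical rot_mnist/lenet configuration; the `_Args` stub below is illustrative and only mimics the two attributes the function reads:

class _Args:
    dataset = 'rot_mnist'
    model_name = 'lenet'

d = init_data_match_dict(_Args(), keys=range(3), vals=10, variation=0)
print(d[0]['data'].shape, d[0]['label'].shape, d[0]['idx'].shape)
# torch.Size([10, 1, 32, 32]) torch.Size([10, 1]) torch.Size([10, 1])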

def get_matched_pairs(args, train_dataset, domain_size, total_domains, training_list_size, phi, match_case, inferred_match):

    #Making matched data pairs
    data_matched = init_data_match_dict(args, range(domain_size), total_domains, 0)
    domain_data = init_data_match_dict(args, range(total_domains), training_list_size, 1)
    indices_matched = {}
    for key in range(domain_size):
        indices_matched[key] = []
    perfect_match_rank = []

    domain_count = {}
    for domain in range(total_domains):
        domain_count[domain] = 0

    # Create dictionary: class label -> list of ordered indices
    for batch_idx, (x_e, y_e, d_e, idx_e) in enumerate(train_dataset):
        y_e = torch.argmax(y_e, dim=1)
        d_e = torch.argmax(d_e, dim=1).numpy()

        domain_indices = np.unique(d_e)
        for domain_idx in domain_indices:
            indices = d_e == domain_idx
            ordered_indices = idx_e[indices]
            for idx in range(ordered_indices.shape[0]):
                #Matching points across domains
                perfect_indice = ordered_indices[idx].item()
                domain_data[domain_idx]['data'][perfect_indice] = x_e[indices][idx]
                domain_data[domain_idx]['label'][perfect_indice] = y_e[indices][idx]
                domain_data[domain_idx]['idx'][perfect_indice] = idx_e[indices][idx]
                domain_count[domain_idx] += 1

    #Sanity check: domain_data should have been updated for all the data points
    for domain in range(total_domains):
        if domain_count[domain] != training_list_size[domain]:
            print('Issue: Some data points are missing from the domain_data dictionary')

    # Creating the random permutation tensor for each domain
    perm_size = int(domain_size*(1 - match_case))

    #Determine base_domain_idx as the domain with the most samples of the current class
    base_domain_dict = {}
    for y_c in range(args.out_classes):
        base_domain_size = 0
        base_domain_idx = -1
        for domain_idx in range(total_domains):
            class_idx = domain_data[domain_idx]['label'] == y_c
            curr_size = domain_data[domain_idx]['label'][class_idx].shape[0]
            if base_domain_size < curr_size:
                base_domain_size = curr_size
                base_domain_idx = domain_idx

        base_domain_dict[y_c] = base_domain_idx
        print('Base Domain: ', base_domain_size, base_domain_idx, y_c)

    # Applying the random permutation tensor
    for domain_idx in range(total_domains):
        total_rand_counter = 0
        total_data_idx = 0
        for y_c in range(args.out_classes):
            base_domain_idx = base_domain_dict[y_c]

            indices_base = domain_data[base_domain_idx]['label'] == y_c
            indices_base = indices_base[:, 0]
            ordered_base_indices = domain_data[base_domain_idx]['idx'][indices_base]

            indices_curr = domain_data[domain_idx]['label'] == y_c
            indices_curr = indices_curr[:, 0]
            ordered_curr_indices = domain_data[domain_idx]['idx'][indices_curr]
            curr_size = ordered_curr_indices.shape[0]

            # Sanity check for the perfect match case
            if args.perfect_match:
                if not torch.equal(ordered_base_indices, ordered_curr_indices):
                    print('Issue: Different indices across domains for perfect match')

            # Only for the perfect match case, to generate the x% correct match strategy
            rand_base_indices = ordered_base_indices[ordered_base_indices < perm_size]
            idx_perm = torch.randperm(rand_base_indices.shape[0])
            rand_base_indices = rand_base_indices[idx_perm]
            rand_counter = 0

            base_feat_data = domain_data[base_domain_idx]['data'][indices_base]
            base_feat_data_split = torch.split(base_feat_data, args.batch_size, dim=0)
            base_feat = []
            for batch_feat in base_feat_data_split:
                with torch.no_grad():
                    batch_feat = batch_feat.to(cuda)
                    out = phi(batch_feat)
                    base_feat.append(out.cpu())
            base_feat = torch.cat(base_feat)

            if inferred_match:
                feat_x_data = domain_data[domain_idx]['data'][indices_curr]
                feat_x_data_split = torch.split(feat_x_data, args.batch_size, dim=0)
                feat_x = []
                for batch_feat in feat_x_data_split:
                    with torch.no_grad():
                        batch_feat = batch_feat.to(cuda)
                        out = phi(batch_feat)
                        feat_x.append(out.cpu())
                feat_x = torch.cat(feat_x)

            base_feat = base_feat.unsqueeze(1)
            base_feat_split = torch.split(base_feat, args.batch_size, dim=0)

            data_idx = 0
            for batch_feat in base_feat_split:

                if inferred_match:
                    # Computed over batches of base_feat due to CUDA out-of-memory errors;
                    # otherwise there would be no need to loop over base_feat_split and we
                    # could simply have computed feat_x - base_feat
                    ws_dist = torch.sum((feat_x - batch_feat)**2, dim=2)
                    match_idx = torch.argmin(ws_dist, dim=1)
                    sort_val, sort_idx = torch.sort(ws_dist, dim=1)
                    del ws_dist

                for idx in range(batch_feat.shape[0]):
                    perfect_indice = ordered_base_indices[data_idx].item()

                    if domain_idx == base_domain_idx:
                        curr_indice = perfect_indice
                    else:
                        if args.perfect_match:
                            if inferred_match:
                                curr_indice = ordered_curr_indices[match_idx[idx]].item()
                                #Find where the perfect match lies in the sorted order of matches
                                #When the perfect match is known, ordered_curr_indices and ordered_base_indices are the same
                                perfect_match_rank.append((ordered_curr_indices[sort_idx[idx]] == perfect_indice).nonzero()[0, 0].item())
                            else:
                                # To allow x% match case type permutations for datasets where the perfect match is known
                                # In perfect match settings, the same ordered indice implies a perfect match across domains
                                if perfect_indice < perm_size:
                                    curr_indice = rand_base_indices[rand_counter].item()
                                    rand_counter += 1
                                    total_rand_counter += 1
                                else:
                                    curr_indice = perfect_indice
                            indices_matched[perfect_indice].append(curr_indice)
                        else:
                            if inferred_match:
                                curr_indice = ordered_curr_indices[match_idx[idx]].item()
                            else:
                                curr_indice = ordered_curr_indices[data_idx % curr_size].item()

                    data_matched[total_data_idx]['data'][domain_idx] = domain_data[domain_idx]['data'][curr_indice]
                    data_matched[total_data_idx]['label'][domain_idx] = domain_data[domain_idx]['label'][curr_indice]
                    data_idx += 1
                    total_data_idx += 1

        if total_data_idx != domain_size:
            print('Issue: Some data points left out of the data_matched dictionary', total_data_idx, domain_size)

        if args.perfect_match and inferred_match == 0 and domain_idx != base_domain_idx and total_rand_counter < perm_size:
            print('Issue: Total random changes made are less than perm_size for domain', domain_idx, total_rand_counter, perm_size)

    # Sanity check: N keys; K values per key
    for key in data_matched.keys():
        if data_matched[key]['label'].shape[0] != total_domains:
            print('Issue with data matching')

    #Sanity check: ensure paired points have the same class label
    wrong_case = 0
    for key in data_matched.keys():
        for d_i in range(data_matched[key]['label'].shape[0]):
            for d_j in range(data_matched[key]['label'].shape[0]):
                if d_j > d_i:
                    if data_matched[key]['label'][d_i] != data_matched[key]['label'][d_j]:
                        wrong_case += 1
    print('Total Label MisMatch across pairs: ', wrong_case)

    data_match_tensor = []
    label_match_tensor = []
    for key in data_matched.keys():
        data_match_tensor.append(data_matched[key]['data'])
        label_match_tensor.append(data_matched[key]['label'])

    data_match_tensor = torch.stack(data_match_tensor)
    label_match_tensor = torch.stack(label_match_tensor)

    print(data_match_tensor.shape, label_match_tensor.shape)

    del domain_data
    del data_matched

    return data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank
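For orientation, a hedged usage sketch; all inputs come from the training script, and the shapes follow from the construction above, with one entry per base-domain data point stacked across domains:

# `args`, `cuda`, `train_dataset`, `phi`, etc. are supplied by the training script
data_mt, label_mt, idx_matched, pm_rank = get_matched_pairs(
    args, train_dataset, domain_size, total_domains,
    training_list_size, phi, args.match_case, inferred_match=0)
# e.g. rot_mnist/lenet with 5 train domains and 1000 base points:
# data_mt.shape  -> torch.Size([1000, 5, 1, 32, 32])
# label_mt.shape -> torch.Size([1000, 5, 1])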