Fix for perfect match condition in self augmentation match score

Author: divyat09
Date: 2020-09-28 12:20:10 +00:00
Parent: ce271e0930
Commit: 7929766bc7
4 changed files with 14 additions and 9 deletions

View file

@@ -28,11 +28,11 @@ class CSD(BaseAlgo):
         self.K, m, self.num_classes = 1, H_dim, self.args.out_classes
         num_domains = self.total_domains
-        self.sms = torch.nn.Parameter(torch.normal(0, 1e-3, size=[self.K+1, m, self.num_classes], dtype=torch.float, device='cuda:0'), requires_grad=True)
-        self.sm_biases = torch.nn.Parameter(torch.normal(0, 1e-3, size=[self.K+1, self.num_classes], dtype=torch.float, device='cuda:0'), requires_grad=True)
+        self.sms = torch.nn.Parameter(torch.normal(0, 1e-1, size=[self.K+1, m, self.num_classes], dtype=torch.float, device='cuda:0'), requires_grad=True)
+        self.sm_biases = torch.nn.Parameter(torch.normal(0, 1e-1, size=[self.K+1, self.num_classes], dtype=torch.float, device='cuda:0'), requires_grad=True)
         self.embs = torch.nn.Parameter(torch.normal(mean=0., std=1e-1, size=[num_domains, self.K], dtype=torch.float, device='cuda:0'), requires_grad=True)
-        self.cs_wt = torch.nn.Parameter(torch.normal(mean=0, std=1e-3, size=[], dtype=torch.float, device='cuda:0'), requires_grad=True)
+        self.cs_wt = torch.nn.Parameter(torch.normal(mean=.1, std=1e-4, size=[], dtype=torch.float, device='cuda:0'), requires_grad=True)
         self.opt= optim.SGD([
                         {'params': filter(lambda p: p.requires_grad, self.phi.parameters()) },
@@ -48,7 +48,7 @@ class CSD(BaseAlgo):
         x = self.phi(x)
         w_c, b_c = self.sms[0, :, :], self.sm_biases[0, :]
         logits_common = torch.matmul(x, w_c) + b_c
         if eval_case:
             return logits_common
@@ -70,7 +70,7 @@ class CSD(BaseAlgo):
         cps = torch.stack([torch.matmul(sms[:, :, _], torch.transpose(sms[:, :, _], 0, 1)) for _ in range(self.num_classes)], dim=0)
         orth_loss = torch.mean((1-diag_tensor)*(cps - diag_tensor)**2)
-        loss = 0.5*class_loss + 0.5*specific_loss + orth_loss
+        loss = class_loss + specific_loss + orth_loss
         return loss, logits_common

     def epoch_callback(self, nepoch, final=False):
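
For context, a minimal sketch of how the CSD training loss is assembled after this change: the common and domain-specific classification losses now enter with weight 1.0 each instead of 0.5, alongside the orthogonality penalty from the hunk above. This is not the repository's exact code; csd_total_loss, logits_specific, the cross-entropy calls, and the identity-matrix construction of diag_tensor are assumptions made for illustration.

import torch
import torch.nn.functional as F

def csd_total_loss(logits_common, logits_specific, y, sms, num_classes):
    # Cross-entropy for the common (shared) and domain-specific classifiers.
    class_loss = F.cross_entropy(logits_common, y)
    specific_loss = F.cross_entropy(logits_specific, y)

    # Orthogonality penalty over the K+1 weight components per class,
    # mirroring the cps / diag_tensor computation in the hunk above;
    # sms has shape [K+1, m, num_classes].
    k_plus_1 = sms.shape[0]
    diag_tensor = torch.stack([torch.eye(k_plus_1, device=sms.device)
                               for _ in range(num_classes)], dim=0)
    cps = torch.stack([sms[:, :, c] @ sms[:, :, c].T for c in range(num_classes)], dim=0)
    orth_loss = torch.mean((1 - diag_tensor) * (cps - diag_tensor) ** 2)

    # Before this commit: 0.5*class_loss + 0.5*specific_loss + orth_loss
    return class_loss + specific_loss + orth_loss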

View file

@@ -78,7 +78,8 @@ class Hybrid(BaseAlgo):
         for epoch in range(self.args.epochs):
             if epoch ==0:
-                data_match_tensor, label_match_tensor= self.init_erm_phase()
+                # data_match_tensor, label_match_tensor= self.init_erm_phase()
+                data_match_tensor, label_match_tensor= self.get_match_function(epoch)
             elif epoch % self.args.match_interrupt == 0 and self.args.match_flag:
                 data_match_tensor, label_match_tensor= self.get_match_function(epoch)
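
A small standalone sketch (a hypothetical helper, not code from the repository) of the match-refresh schedule this hunk now implements: matched pairs are computed with the same matching function at epoch 0 and again every match_interrupt epochs when match_flag is set.

def refresh_epochs(epochs, match_interrupt, match_flag):
    # Return the epochs at which get_match_function would be invoked.
    out = []
    for epoch in range(epochs):
        if epoch == 0:
            out.append(epoch)
        elif match_flag and epoch % match_interrupt == 0:
            out.append(epoch)
    return out

# Example: with 10 epochs and match_interrupt=3, matches are refreshed
# at epochs 0, 3, 6, and 9.
print(refresh_epochs(10, 3, True))   # [0, 3, 6, 9]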

View file

@@ -154,8 +154,8 @@ class MatchDG(BaseAlgo):
                     pos_feat_match= feat_match[pos_indices]
                     neg_feat_match= feat_match[neg_indices]
-                    if pos_feat_match.shape[0] > neg_feat_match.shape[0]:
-                        print('Weird! Positive Matches are more than the negative matches?', pos_feat_match.shape[0], neg_feat_match.shape[0])
+                    # if pos_feat_match.shape[0] > neg_feat_match.shape[0]:
+                    #     print('Weird! Positive Matches are more than the negative matches?', pos_feat_match.shape[0], neg_feat_match.shape[0])
                     # If no instances of label y_c in the current batch then continue
                     if pos_feat_match.shape[0] ==0 or neg_feat_match.shape[0] == 0:

View file

@@ -45,7 +45,11 @@ class MatchEval(BaseEval):
         base_domain_size= self.test_dataset['base_domain_size']
         domain_size_list= self.test_dataset['domain_size_list']
         inferred_match=1
+        # Self Augmentation Match Function evaluation will always follow perfect matches
+        if self.args.match_func_aug_case:
+            self.args.perfect_match= 1
+
         data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, dataset, base_domain_size, total_domains, domain_size_list, self.phi, self.args.match_case, inferred_match )
         score= perfect_match_score(indices_matched)
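
The added guard encodes the idea in its comment: when the match function being evaluated is a self-augmentation one, every sample's augmented copy is its ground-truth counterpart, so perfect-match bookkeeping is forced on before get_matched_pairs retrieves the pairs and perfect_match_score scores them. perfect_match_score itself is not shown in this diff; a toy stand-in for the underlying idea (hypothetical fraction_perfect, not the repository's implementation) is the share of samples whose retrieved match equals the ground-truth match.

import torch

def fraction_perfect(retrieved_idx, ground_truth_idx):
    # Toy stand-in: fraction of samples whose retrieved match index
    # coincides with the ground-truth (perfect) match index.
    retrieved_idx = torch.as_tensor(retrieved_idx)
    ground_truth_idx = torch.as_tensor(ground_truth_idx)
    return (retrieved_idx == ground_truth_idx).float().mean().item()

# Example: 3 of 4 samples are matched to their ground-truth counterpart.
print(fraction_perfect([5, 2, 7, 1], [5, 2, 7, 3]))   # 0.75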