From 7929766bc76555240360707ed31275adcdb457cc Mon Sep 17 00:00:00 2001 From: divyat09 Date: Mon, 28 Sep 2020 12:20:10 +0000 Subject: [PATCH] Fix for perfect match condition in self augmentation match score --- algorithms/csd.py | 10 +++++----- algorithms/hybrid.py | 3 ++- algorithms/match_dg.py | 4 ++-- evaluation/match_eval.py | 6 +++++- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/algorithms/csd.py b/algorithms/csd.py index bbea92a..1db10ee 100644 --- a/algorithms/csd.py +++ b/algorithms/csd.py @@ -28,11 +28,11 @@ class CSD(BaseAlgo): self.K, m, self.num_classes = 1, H_dim, self.args.out_classes num_domains = self.total_domains - self.sms = torch.nn.Parameter(torch.normal(0, 1e-3, size=[self.K+1, m, self.num_classes], dtype=torch.float, device='cuda:0'), requires_grad=True) - self.sm_biases = torch.nn.Parameter(torch.normal(0, 1e-3, size=[self.K+1, self.num_classes], dtype=torch.float, device='cuda:0'), requires_grad=True) + self.sms = torch.nn.Parameter(torch.normal(0, 1e-1, size=[self.K+1, m, self.num_classes], dtype=torch.float, device='cuda:0'), requires_grad=True) + self.sm_biases = torch.nn.Parameter(torch.normal(0, 1e-1, size=[self.K+1, self.num_classes], dtype=torch.float, device='cuda:0'), requires_grad=True) self.embs = torch.nn.Parameter(torch.normal(mean=0., std=1e-1, size=[num_domains, self.K], dtype=torch.float, device='cuda:0'), requires_grad=True) - self.cs_wt = torch.nn.Parameter(torch.normal(mean=0, std=1e-3, size=[], dtype=torch.float, device='cuda:0'), requires_grad=True) + self.cs_wt = torch.nn.Parameter(torch.normal(mean=.1, std=1e-4, size=[], dtype=torch.float, device='cuda:0'), requires_grad=True) self.opt= optim.SGD([ {'params': filter(lambda p: p.requires_grad, self.phi.parameters()) }, @@ -48,7 +48,7 @@ class CSD(BaseAlgo): x = self.phi(x) w_c, b_c = self.sms[0, :, :], self.sm_biases[0, :] logits_common = torch.matmul(x, w_c) + b_c - + if eval_case: return logits_common @@ -70,7 +70,7 @@ class CSD(BaseAlgo): cps = torch.stack([torch.matmul(sms[:, :, _], torch.transpose(sms[:, :, _], 0, 1)) for _ in range(self.num_classes)], dim=0) orth_loss = torch.mean((1-diag_tensor)*(cps - diag_tensor)**2) - loss = 0.5*class_loss + 0.5*specific_loss + orth_loss + loss = class_loss + specific_loss + orth_loss return loss, logits_common def epoch_callback(self, nepoch, final=False): diff --git a/algorithms/hybrid.py b/algorithms/hybrid.py index 6947d80..d3b4f5e 100644 --- a/algorithms/hybrid.py +++ b/algorithms/hybrid.py @@ -78,7 +78,8 @@ class Hybrid(BaseAlgo): for epoch in range(self.args.epochs): if epoch ==0: - data_match_tensor, label_match_tensor= self.init_erm_phase() +# data_match_tensor, label_match_tensor= self.init_erm_phase() + data_match_tensor, label_match_tensor= self.get_match_function(epoch) elif epoch % self.args.match_interrupt == 0 and self.args.match_flag: data_match_tensor, label_match_tensor= self.get_match_function(epoch) diff --git a/algorithms/match_dg.py b/algorithms/match_dg.py index 11a15e7..9ac2dd1 100644 --- a/algorithms/match_dg.py +++ b/algorithms/match_dg.py @@ -154,8 +154,8 @@ class MatchDG(BaseAlgo): pos_feat_match= feat_match[pos_indices] neg_feat_match= feat_match[neg_indices] - if pos_feat_match.shape[0] > neg_feat_match.shape[0]: - print('Weird! Positive Matches are more than the negative matches?', pos_feat_match.shape[0], neg_feat_match.shape[0]) +# if pos_feat_match.shape[0] > neg_feat_match.shape[0]: +# print('Weird! Positive Matches are more than the negative matches?', pos_feat_match.shape[0], neg_feat_match.shape[0]) # If no instances of label y_c in the current batch then continue if pos_feat_match.shape[0] ==0 or neg_feat_match.shape[0] == 0: diff --git a/evaluation/match_eval.py b/evaluation/match_eval.py index 02684be..7f0f10f 100644 --- a/evaluation/match_eval.py +++ b/evaluation/match_eval.py @@ -45,7 +45,11 @@ class MatchEval(BaseEval): base_domain_size= self.test_dataset['base_domain_size'] domain_size_list= self.test_dataset['domain_size_list'] - inferred_match=1 + inferred_match=1 + # Self Augmentation Match Function evaluation will always follow perfect matches + if self.args.match_func_aug_case: + self.args.perfect_match= 1 + data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, dataset, base_domain_size, total_domains, domain_size_list, self.phi, self.args.match_case, inferred_match ) score= perfect_match_score(indices_matched)