Fix for perfect match condition in self augmentation match score
This commit is contained in:
Родитель
ce271e0930
Коммит
7929766bc7
|
@ -28,11 +28,11 @@ class CSD(BaseAlgo):
|
|||
self.K, m, self.num_classes = 1, H_dim, self.args.out_classes
|
||||
num_domains = self.total_domains
|
||||
|
||||
self.sms = torch.nn.Parameter(torch.normal(0, 1e-3, size=[self.K+1, m, self.num_classes], dtype=torch.float, device='cuda:0'), requires_grad=True)
|
||||
self.sm_biases = torch.nn.Parameter(torch.normal(0, 1e-3, size=[self.K+1, self.num_classes], dtype=torch.float, device='cuda:0'), requires_grad=True)
|
||||
self.sms = torch.nn.Parameter(torch.normal(0, 1e-1, size=[self.K+1, m, self.num_classes], dtype=torch.float, device='cuda:0'), requires_grad=True)
|
||||
self.sm_biases = torch.nn.Parameter(torch.normal(0, 1e-1, size=[self.K+1, self.num_classes], dtype=torch.float, device='cuda:0'), requires_grad=True)
|
||||
|
||||
self.embs = torch.nn.Parameter(torch.normal(mean=0., std=1e-1, size=[num_domains, self.K], dtype=torch.float, device='cuda:0'), requires_grad=True)
|
||||
self.cs_wt = torch.nn.Parameter(torch.normal(mean=0, std=1e-3, size=[], dtype=torch.float, device='cuda:0'), requires_grad=True)
|
||||
self.cs_wt = torch.nn.Parameter(torch.normal(mean=.1, std=1e-4, size=[], dtype=torch.float, device='cuda:0'), requires_grad=True)
|
||||
|
||||
self.opt= optim.SGD([
|
||||
{'params': filter(lambda p: p.requires_grad, self.phi.parameters()) },
|
||||
|
@ -48,7 +48,7 @@ class CSD(BaseAlgo):
|
|||
x = self.phi(x)
|
||||
w_c, b_c = self.sms[0, :, :], self.sm_biases[0, :]
|
||||
logits_common = torch.matmul(x, w_c) + b_c
|
||||
|
||||
|
||||
if eval_case:
|
||||
return logits_common
|
||||
|
||||
|
@ -70,7 +70,7 @@ class CSD(BaseAlgo):
|
|||
cps = torch.stack([torch.matmul(sms[:, :, _], torch.transpose(sms[:, :, _], 0, 1)) for _ in range(self.num_classes)], dim=0)
|
||||
orth_loss = torch.mean((1-diag_tensor)*(cps - diag_tensor)**2)
|
||||
|
||||
loss = 0.5*class_loss + 0.5*specific_loss + orth_loss
|
||||
loss = class_loss + specific_loss + orth_loss
|
||||
return loss, logits_common
|
||||
|
||||
def epoch_callback(self, nepoch, final=False):
|
||||
|
|
|
@ -78,7 +78,8 @@ class Hybrid(BaseAlgo):
|
|||
for epoch in range(self.args.epochs):
|
||||
|
||||
if epoch ==0:
|
||||
data_match_tensor, label_match_tensor= self.init_erm_phase()
|
||||
# data_match_tensor, label_match_tensor= self.init_erm_phase()
|
||||
data_match_tensor, label_match_tensor= self.get_match_function(epoch)
|
||||
elif epoch % self.args.match_interrupt == 0 and self.args.match_flag:
|
||||
data_match_tensor, label_match_tensor= self.get_match_function(epoch)
|
||||
|
||||
|
|
|
@ -154,8 +154,8 @@ class MatchDG(BaseAlgo):
|
|||
pos_feat_match= feat_match[pos_indices]
|
||||
neg_feat_match= feat_match[neg_indices]
|
||||
|
||||
if pos_feat_match.shape[0] > neg_feat_match.shape[0]:
|
||||
print('Weird! Positive Matches are more than the negative matches?', pos_feat_match.shape[0], neg_feat_match.shape[0])
|
||||
# if pos_feat_match.shape[0] > neg_feat_match.shape[0]:
|
||||
# print('Weird! Positive Matches are more than the negative matches?', pos_feat_match.shape[0], neg_feat_match.shape[0])
|
||||
|
||||
# If no instances of label y_c in the current batch then continue
|
||||
if pos_feat_match.shape[0] ==0 or neg_feat_match.shape[0] == 0:
|
||||
|
|
|
@ -45,7 +45,11 @@ class MatchEval(BaseEval):
|
|||
base_domain_size= self.test_dataset['base_domain_size']
|
||||
domain_size_list= self.test_dataset['domain_size_list']
|
||||
|
||||
inferred_match=1
|
||||
inferred_match=1
|
||||
# Self Augmentation Match Function evaluation will always follow perfect matches
|
||||
if self.args.match_func_aug_case:
|
||||
self.args.perfect_match= 1
|
||||
|
||||
data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, dataset, base_domain_size, total_domains, domain_size_list, self.phi, self.args.match_case, inferred_match )
|
||||
|
||||
score= perfect_match_score(indices_matched)
|
||||
|
|
Загрузка…
Ссылка в новой задаче