Fix for perfect match condition in self augmentation match score
This commit is contained in:
Родитель
ce271e0930
Коммит
7929766bc7
|
@ -28,11 +28,11 @@ class CSD(BaseAlgo):
|
||||||
self.K, m, self.num_classes = 1, H_dim, self.args.out_classes
|
self.K, m, self.num_classes = 1, H_dim, self.args.out_classes
|
||||||
num_domains = self.total_domains
|
num_domains = self.total_domains
|
||||||
|
|
||||||
self.sms = torch.nn.Parameter(torch.normal(0, 1e-3, size=[self.K+1, m, self.num_classes], dtype=torch.float, device='cuda:0'), requires_grad=True)
|
self.sms = torch.nn.Parameter(torch.normal(0, 1e-1, size=[self.K+1, m, self.num_classes], dtype=torch.float, device='cuda:0'), requires_grad=True)
|
||||||
self.sm_biases = torch.nn.Parameter(torch.normal(0, 1e-3, size=[self.K+1, self.num_classes], dtype=torch.float, device='cuda:0'), requires_grad=True)
|
self.sm_biases = torch.nn.Parameter(torch.normal(0, 1e-1, size=[self.K+1, self.num_classes], dtype=torch.float, device='cuda:0'), requires_grad=True)
|
||||||
|
|
||||||
self.embs = torch.nn.Parameter(torch.normal(mean=0., std=1e-1, size=[num_domains, self.K], dtype=torch.float, device='cuda:0'), requires_grad=True)
|
self.embs = torch.nn.Parameter(torch.normal(mean=0., std=1e-1, size=[num_domains, self.K], dtype=torch.float, device='cuda:0'), requires_grad=True)
|
||||||
self.cs_wt = torch.nn.Parameter(torch.normal(mean=0, std=1e-3, size=[], dtype=torch.float, device='cuda:0'), requires_grad=True)
|
self.cs_wt = torch.nn.Parameter(torch.normal(mean=.1, std=1e-4, size=[], dtype=torch.float, device='cuda:0'), requires_grad=True)
|
||||||
|
|
||||||
self.opt= optim.SGD([
|
self.opt= optim.SGD([
|
||||||
{'params': filter(lambda p: p.requires_grad, self.phi.parameters()) },
|
{'params': filter(lambda p: p.requires_grad, self.phi.parameters()) },
|
||||||
|
@ -48,7 +48,7 @@ class CSD(BaseAlgo):
|
||||||
x = self.phi(x)
|
x = self.phi(x)
|
||||||
w_c, b_c = self.sms[0, :, :], self.sm_biases[0, :]
|
w_c, b_c = self.sms[0, :, :], self.sm_biases[0, :]
|
||||||
logits_common = torch.matmul(x, w_c) + b_c
|
logits_common = torch.matmul(x, w_c) + b_c
|
||||||
|
|
||||||
if eval_case:
|
if eval_case:
|
||||||
return logits_common
|
return logits_common
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ class CSD(BaseAlgo):
|
||||||
cps = torch.stack([torch.matmul(sms[:, :, _], torch.transpose(sms[:, :, _], 0, 1)) for _ in range(self.num_classes)], dim=0)
|
cps = torch.stack([torch.matmul(sms[:, :, _], torch.transpose(sms[:, :, _], 0, 1)) for _ in range(self.num_classes)], dim=0)
|
||||||
orth_loss = torch.mean((1-diag_tensor)*(cps - diag_tensor)**2)
|
orth_loss = torch.mean((1-diag_tensor)*(cps - diag_tensor)**2)
|
||||||
|
|
||||||
loss = 0.5*class_loss + 0.5*specific_loss + orth_loss
|
loss = class_loss + specific_loss + orth_loss
|
||||||
return loss, logits_common
|
return loss, logits_common
|
||||||
|
|
||||||
def epoch_callback(self, nepoch, final=False):
|
def epoch_callback(self, nepoch, final=False):
|
||||||
|
|
|
@ -78,7 +78,8 @@ class Hybrid(BaseAlgo):
|
||||||
for epoch in range(self.args.epochs):
|
for epoch in range(self.args.epochs):
|
||||||
|
|
||||||
if epoch ==0:
|
if epoch ==0:
|
||||||
data_match_tensor, label_match_tensor= self.init_erm_phase()
|
# data_match_tensor, label_match_tensor= self.init_erm_phase()
|
||||||
|
data_match_tensor, label_match_tensor= self.get_match_function(epoch)
|
||||||
elif epoch % self.args.match_interrupt == 0 and self.args.match_flag:
|
elif epoch % self.args.match_interrupt == 0 and self.args.match_flag:
|
||||||
data_match_tensor, label_match_tensor= self.get_match_function(epoch)
|
data_match_tensor, label_match_tensor= self.get_match_function(epoch)
|
||||||
|
|
||||||
|
|
|
@ -154,8 +154,8 @@ class MatchDG(BaseAlgo):
|
||||||
pos_feat_match= feat_match[pos_indices]
|
pos_feat_match= feat_match[pos_indices]
|
||||||
neg_feat_match= feat_match[neg_indices]
|
neg_feat_match= feat_match[neg_indices]
|
||||||
|
|
||||||
if pos_feat_match.shape[0] > neg_feat_match.shape[0]:
|
# if pos_feat_match.shape[0] > neg_feat_match.shape[0]:
|
||||||
print('Weird! Positive Matches are more than the negative matches?', pos_feat_match.shape[0], neg_feat_match.shape[0])
|
# print('Weird! Positive Matches are more than the negative matches?', pos_feat_match.shape[0], neg_feat_match.shape[0])
|
||||||
|
|
||||||
# If no instances of label y_c in the current batch then continue
|
# If no instances of label y_c in the current batch then continue
|
||||||
if pos_feat_match.shape[0] ==0 or neg_feat_match.shape[0] == 0:
|
if pos_feat_match.shape[0] ==0 or neg_feat_match.shape[0] == 0:
|
||||||
|
|
|
@ -45,7 +45,11 @@ class MatchEval(BaseEval):
|
||||||
base_domain_size= self.test_dataset['base_domain_size']
|
base_domain_size= self.test_dataset['base_domain_size']
|
||||||
domain_size_list= self.test_dataset['domain_size_list']
|
domain_size_list= self.test_dataset['domain_size_list']
|
||||||
|
|
||||||
inferred_match=1
|
inferred_match=1
|
||||||
|
# Self Augmentation Match Function evaluation will always follow perfect matches
|
||||||
|
if self.args.match_func_aug_case:
|
||||||
|
self.args.perfect_match= 1
|
||||||
|
|
||||||
data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, dataset, base_domain_size, total_domains, domain_size_list, self.phi, self.args.match_case, inferred_match )
|
data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, dataset, base_domain_size, total_domains, domain_size_list, self.phi, self.args.match_case, inferred_match )
|
||||||
|
|
||||||
score= perfect_match_score(indices_matched)
|
score= perfect_match_score(indices_matched)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче