Issue with perfect math with resolved
This commit is contained in:
Родитель
7929766bc7
Коммит
687be2e829
|
@ -101,16 +101,19 @@ class BaseAlgo():
|
|||
|
||||
|
||||
def get_match_function(self, epoch):
|
||||
|
||||
perfect_match= self.args.perfect_match
|
||||
|
||||
#Start initially with randomly defined batch; else find the local approximate batch
|
||||
if epoch > 0:
|
||||
inferred_match=1
|
||||
if self.args.match_flag:
|
||||
data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, self.phi, self.args.match_case, inferred_match )
|
||||
data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, self.phi, self.args.match_case, perfect_match, inferred_match )
|
||||
else:
|
||||
temp_1, temp_2, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, self.phi, self.args.match_case, inferred_match )
|
||||
temp_1, temp_2, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, self.phi, self.args.match_case, perfect_match, inferred_match )
|
||||
else:
|
||||
inferred_match=0
|
||||
data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, self.phi, self.args.match_case, inferred_match )
|
||||
data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, self.phi, self.args.match_case, perfect_match, inferred_match )
|
||||
|
||||
return data_match_tensor, label_match_tensor
|
||||
|
||||
|
|
|
@ -46,11 +46,14 @@ class MatchEval(BaseEval):
|
|||
domain_size_list= self.test_dataset['domain_size_list']
|
||||
|
||||
inferred_match=1
|
||||
|
||||
# Self Augmentation Match Function evaluation will always follow perfect matches
|
||||
if self.args.match_func_aug_case:
|
||||
self.args.perfect_match= 1
|
||||
perfect_match= 1
|
||||
else:
|
||||
perfect_match= self.args.perfect_match
|
||||
|
||||
data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, dataset, base_domain_size, total_domains, domain_size_list, self.phi, self.args.match_case, inferred_match )
|
||||
data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, dataset, base_domain_size, total_domains, domain_size_list, self.phi, self.args.match_case, perfect_match, inferred_match )
|
||||
|
||||
score= perfect_match_score(indices_matched)
|
||||
perfect_match_rank= np.array(perfect_match_rank)
|
||||
|
|
|
@ -38,7 +38,7 @@ def init_data_match_dict(args, keys, vals, variation):
|
|||
data[key]['idx']=torch.randint(0, 1, (val_dim, 1))
|
||||
return data
|
||||
|
||||
def get_matched_pairs(args, cuda, train_dataset, domain_size, total_domains, training_list_size, phi, match_case, inferred_match):
|
||||
def get_matched_pairs(args, cuda, train_dataset, domain_size, total_domains, training_list_size, phi, match_case, perfect_match, inferred_match):
|
||||
|
||||
#Making Data Matched pairs
|
||||
data_matched= init_data_match_dict( args, range(domain_size), total_domains, 0 )
|
||||
|
@ -128,7 +128,7 @@ def get_matched_pairs(args, cuda, train_dataset, domain_size, total_domains, tra
|
|||
curr_size= ordered_curr_indices.shape[0]
|
||||
|
||||
# Sanity check for perfect match case:
|
||||
if args.perfect_match:
|
||||
if perfect_match:
|
||||
if not torch.equal(ordered_base_indices, ordered_curr_indices):
|
||||
print('Issue: Different indices across domains for perfect match' )
|
||||
|
||||
|
@ -179,9 +179,15 @@ def get_matched_pairs(args, cuda, train_dataset, domain_size, total_domains, tra
|
|||
if domain_idx == base_domain_idx:
|
||||
curr_indice= perfect_indice
|
||||
else:
|
||||
if args.perfect_match:
|
||||
if perfect_match:
|
||||
if inferred_match:
|
||||
curr_indice= ordered_curr_indices[match_idx[idx]].item()
|
||||
# print('Curr Indice, Idx: ', curr_indice, idx)
|
||||
# print('Sort, OrdIndices: ', sort_idx.shape, ordered_curr_indices.shape, type(ordered_curr_indices))
|
||||
# print('Perfect Indice: ', perfect_indice )
|
||||
# print('Unique OrdIndices: ', len(torch.unique(ordered_curr_indices[sort_idx[idx]])))
|
||||
# print( perfect_indice in ordered_curr_indices)
|
||||
# print( perfect_indice in ordered_curr_indices[sort_idx[idx]] )
|
||||
#Find where does the perfect match lies in the sorted order of matches
|
||||
#In the situations where the perfect match is known; the ordered_curr_indices and ordered_base_indices are the same
|
||||
perfect_match_rank.append( (ordered_curr_indices[sort_idx[idx]] == perfect_indice).nonzero()[0,0].item() )
|
||||
|
@ -210,7 +216,7 @@ def get_matched_pairs(args, cuda, train_dataset, domain_size, total_domains, tra
|
|||
if total_data_idx != domain_size:
|
||||
print('Issue: Some data points left from data_matched dictionary', total_data_idx, domain_size)
|
||||
|
||||
if args.perfect_match and inferred_match ==0 and domain_idx != base_domain_idx and total_rand_counter < perm_size:
|
||||
if perfect_match and inferred_match ==0 and domain_idx != base_domain_idx and total_rand_counter < perm_size:
|
||||
print('Issue: Total random changes made are less than perm_size for domain', domain_idx, total_rand_counter, perm_size)
|
||||
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче