Issue with perfect math with resolved

This commit is contained in:
divyat09 2020-09-28 17:26:30 +00:00
Родитель 7929766bc7
Коммит 687be2e829
3 изменённых файлов: 22 добавлений и 10 удалений

Просмотреть файл

@ -101,16 +101,19 @@ class BaseAlgo():
def get_match_function(self, epoch):
perfect_match= self.args.perfect_match
#Start initially with randomly defined batch; else find the local approximate batch
if epoch > 0:
inferred_match=1
if self.args.match_flag:
data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, self.phi, self.args.match_case, inferred_match )
data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, self.phi, self.args.match_case, perfect_match, inferred_match )
else:
temp_1, temp_2, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, self.phi, self.args.match_case, inferred_match )
temp_1, temp_2, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, self.phi, self.args.match_case, perfect_match, inferred_match )
else:
inferred_match=0
data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, self.phi, self.args.match_case, inferred_match )
data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, self.phi, self.args.match_case, perfect_match, inferred_match )
return data_match_tensor, label_match_tensor

Просмотреть файл

@ -46,11 +46,14 @@ class MatchEval(BaseEval):
domain_size_list= self.test_dataset['domain_size_list']
inferred_match=1
# Self Augmentation Match Function evaluation will always follow perfect matches
if self.args.match_func_aug_case:
self.args.perfect_match= 1
perfect_match= 1
else:
perfect_match= self.args.perfect_match
data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, dataset, base_domain_size, total_domains, domain_size_list, self.phi, self.args.match_case, inferred_match )
data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, dataset, base_domain_size, total_domains, domain_size_list, self.phi, self.args.match_case, perfect_match, inferred_match )
score= perfect_match_score(indices_matched)
perfect_match_rank= np.array(perfect_match_rank)

Просмотреть файл

@ -38,7 +38,7 @@ def init_data_match_dict(args, keys, vals, variation):
data[key]['idx']=torch.randint(0, 1, (val_dim, 1))
return data
def get_matched_pairs(args, cuda, train_dataset, domain_size, total_domains, training_list_size, phi, match_case, inferred_match):
def get_matched_pairs(args, cuda, train_dataset, domain_size, total_domains, training_list_size, phi, match_case, perfect_match, inferred_match):
#Making Data Matched pairs
data_matched= init_data_match_dict( args, range(domain_size), total_domains, 0 )
@ -128,7 +128,7 @@ def get_matched_pairs(args, cuda, train_dataset, domain_size, total_domains, tra
curr_size= ordered_curr_indices.shape[0]
# Sanity check for perfect match case:
if args.perfect_match:
if perfect_match:
if not torch.equal(ordered_base_indices, ordered_curr_indices):
print('Issue: Different indices across domains for perfect match' )
@ -179,9 +179,15 @@ def get_matched_pairs(args, cuda, train_dataset, domain_size, total_domains, tra
if domain_idx == base_domain_idx:
curr_indice= perfect_indice
else:
if args.perfect_match:
if perfect_match:
if inferred_match:
curr_indice= ordered_curr_indices[match_idx[idx]].item()
# print('Curr Indice, Idx: ', curr_indice, idx)
# print('Sort, OrdIndices: ', sort_idx.shape, ordered_curr_indices.shape, type(ordered_curr_indices))
# print('Perfect Indice: ', perfect_indice )
# print('Unique OrdIndices: ', len(torch.unique(ordered_curr_indices[sort_idx[idx]])))
# print( perfect_indice in ordered_curr_indices)
# print( perfect_indice in ordered_curr_indices[sort_idx[idx]] )
#Find where does the perfect match lies in the sorted order of matches
#In the situations where the perfect match is known; the ordered_curr_indices and ordered_base_indices are the same
perfect_match_rank.append( (ordered_curr_indices[sort_idx[idx]] == perfect_indice).nonzero()[0,0].item() )
@ -210,7 +216,7 @@ def get_matched_pairs(args, cuda, train_dataset, domain_size, total_domains, tra
if total_data_idx != domain_size:
print('Issue: Some data points left from data_matched dictionary', total_data_idx, domain_size)
if args.perfect_match and inferred_match ==0 and domain_idx != base_domain_idx and total_rand_counter < perm_size:
if perfect_match and inferred_match ==0 and domain_idx != base_domain_idx and total_rand_counter < perm_size:
print('Issue: Total random changes made are less than perm_size for domain', domain_idx, total_rand_counter, perm_size)