diff --git a/algorithms/algo.py b/algorithms/algo.py index df6b6e8..5632f65 100644 --- a/algorithms/algo.py +++ b/algorithms/algo.py @@ -101,16 +101,19 @@ class BaseAlgo(): def get_match_function(self, epoch): + + perfect_match= self.args.perfect_match + #Start initially with randomly defined batch; else find the local approximate batch if epoch > 0: inferred_match=1 if self.args.match_flag: - data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, self.phi, self.args.match_case, inferred_match ) + data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, self.phi, self.args.match_case, perfect_match, inferred_match ) else: - temp_1, temp_2, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, self.phi, self.args.match_case, inferred_match ) + temp_1, temp_2, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, self.phi, self.args.match_case, perfect_match, inferred_match ) else: inferred_match=0 - data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, self.phi, self.args.match_case, inferred_match ) + data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, self.phi, self.args.match_case, perfect_match, inferred_match ) return data_match_tensor, label_match_tensor diff --git a/evaluation/match_eval.py b/evaluation/match_eval.py index 7f0f10f..5bbcf12 100644 --- a/evaluation/match_eval.py +++ b/evaluation/match_eval.py @@ -46,11 +46,14 @@ class MatchEval(BaseEval): domain_size_list= self.test_dataset['domain_size_list'] inferred_match=1 + # Self Augmentation Match Function evaluation will always follow perfect matches if self.args.match_func_aug_case: - self.args.perfect_match= 1 - - data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, dataset, base_domain_size, total_domains, domain_size_list, self.phi, self.args.match_case, inferred_match ) + perfect_match= 1 + else: + perfect_match= self.args.perfect_match + + data_match_tensor, label_match_tensor, indices_matched, perfect_match_rank= get_matched_pairs( self.args, self.cuda, dataset, base_domain_size, total_domains, domain_size_list, self.phi, self.args.match_case, perfect_match, inferred_match ) score= perfect_match_score(indices_matched) perfect_match_rank= np.array(perfect_match_rank) diff --git a/utils/match_function.py b/utils/match_function.py index e6a36ad..04194cd 100644 --- a/utils/match_function.py +++ b/utils/match_function.py @@ -38,7 +38,7 @@ def init_data_match_dict(args, keys, vals, variation): data[key]['idx']=torch.randint(0, 1, (val_dim, 1)) return data -def get_matched_pairs(args, cuda, train_dataset, domain_size, total_domains, training_list_size, phi, match_case, inferred_match): +def get_matched_pairs(args, cuda, train_dataset, domain_size, total_domains, training_list_size, phi, match_case, perfect_match, inferred_match): #Making Data Matched pairs data_matched= init_data_match_dict( args, range(domain_size), total_domains, 0 ) @@ -128,7 +128,7 @@ def get_matched_pairs(args, cuda, train_dataset, domain_size, total_domains, tra curr_size= ordered_curr_indices.shape[0] # Sanity check for perfect match case: - if args.perfect_match: + if perfect_match: if not torch.equal(ordered_base_indices, ordered_curr_indices): print('Issue: Different indices across domains for perfect match' ) @@ -179,9 +179,15 @@ def get_matched_pairs(args, cuda, train_dataset, domain_size, total_domains, tra if domain_idx == base_domain_idx: curr_indice= perfect_indice else: - if args.perfect_match: + if perfect_match: if inferred_match: curr_indice= ordered_curr_indices[match_idx[idx]].item() +# print('Curr Indice, Idx: ', curr_indice, idx) +# print('Sort, OrdIndices: ', sort_idx.shape, ordered_curr_indices.shape, type(ordered_curr_indices)) +# print('Perfect Indice: ', perfect_indice ) +# print('Unique OrdIndices: ', len(torch.unique(ordered_curr_indices[sort_idx[idx]]))) +# print( perfect_indice in ordered_curr_indices) +# print( perfect_indice in ordered_curr_indices[sort_idx[idx]] ) #Find where does the perfect match lies in the sorted order of matches #In the situations where the perfect match is known; the ordered_curr_indices and ordered_base_indices are the same perfect_match_rank.append( (ordered_curr_indices[sort_idx[idx]] == perfect_indice).nonzero()[0,0].item() ) @@ -210,7 +216,7 @@ def get_matched_pairs(args, cuda, train_dataset, domain_size, total_domains, tra if total_data_idx != domain_size: print('Issue: Some data points left from data_matched dictionary', total_data_idx, domain_size) - if args.perfect_match and inferred_match ==0 and domain_idx != base_domain_idx and total_rand_counter < perm_size: + if perfect_match and inferred_match ==0 and domain_idx != base_domain_idx and total_rand_counter < perm_size: print('Issue: Total random changes made are less than perm_size for domain', domain_idx, total_rand_counter, perm_size)