import sys
import numpy as np
import argparse
import copy
import random
import json
import os

import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

from .algo import BaseAlgo
from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity, get_dataloader
from utils.match_function import get_matched_pairs


class Hybrid(BaseAlgo):

    def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda):
        super().__init__(args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda)

        # Suffixes identifying the checkpoint to save (current model) and the
        # contrastive (CTR) phase checkpoint to load
        self.ctr_save_post_string = (
            str(self.args.match_case) + '_' + str(self.args.match_interrupt) + '_'
            + str(self.args.match_flag) + '_' + str(self.run) + '_' + self.args.model_name
        )
        self.ctr_load_post_string = (
            str(self.args.ctr_match_case) + '_' + str(self.args.ctr_match_interrupt) + '_'
            + str(self.args.ctr_match_flag) + '_' + str(self.run) + '_' + self.args.ctr_model_name
        )

    def save_model_erm_phase(self, run):

        if not os.path.exists(self.base_res_dir + '/' + self.ctr_load_post_string):
            os.makedirs(self.base_res_dir + '/' + self.ctr_load_post_string)

        # Store the weights of the model
        torch.save(
            self.phi.state_dict(),
            self.base_res_dir + '/' + self.ctr_load_post_string + '/Model_'
            + self.post_string + '_' + str(run) + '.pth'
        )

    def init_erm_phase(self):

        # Build the backbone specified for the contrastive (CTR) phase
        if self.args.ctr_model_name == 'lenet':
            from models.lenet import LeNet5
            ctr_phi = LeNet5().to(self.cuda)

        if self.args.ctr_model_name == 'slab':
            from models.slab import SlabClf
            fc_layer = 0
            ctr_phi = SlabClf(self.args.slab_data_dim, self.args.out_classes, fc_layer).to(self.cuda)

        if self.args.ctr_model_name == 'alexnet':
            from models.alexnet import alexnet
            ctr_phi = alexnet(self.args.out_classes, self.args.pre_trained, 'matchdg_ctr').to(self.cuda)

        if self.args.ctr_model_name == 'fc':
            from models.fc import FC
            fc_layer = 0
            ctr_phi = FC(self.args.out_classes, fc_layer).to(self.cuda)

        if 'resnet' in self.args.ctr_model_name:
            from models.resnet import get_resnet
            fc_layer = 0
            ctr_phi = get_resnet(
                self.args.ctr_model_name, self.args.out_classes, fc_layer,
                self.args.img_c, self.args.pre_trained, self.args.os_env
            ).to(self.cuda)

        if 'densenet' in self.args.ctr_model_name:
            from models.densenet import get_densenet
            fc_layer = 0
            ctr_phi = get_densenet(
                self.args.ctr_model_name, self.args.out_classes, fc_layer,
                self.args.img_c, self.args.pre_trained, self.args.os_env
            ).to(self.cuda)

        # Load the MatchDG CTR phase model from the saved weights
        if self.args.os_env:
            base_res_dir = (
                os.getenv('PT_DATA_DIR') + '/' + self.args.dataset_name + '/' + 'matchdg_ctr'
                + '/' + self.args.ctr_match_layer + '/' + 'train_' + str(self.args.train_domains)
            )
        else:
            base_res_dir = (
                "results/" + self.args.dataset_name + '/' + 'matchdg_ctr'
                + '/' + self.args.ctr_match_layer + '/' + 'train_' + str(self.args.train_domains)
            )

        # TODO: Handle the slab noise case in the helper functions
        if self.args.dataset_name == 'slab':
            base_res_dir = base_res_dir + '/slab_noise_' + str(self.args.slab_noise)

        save_path = base_res_dir + '/Model_' + self.ctr_load_post_string + '.pth'
        ctr_phi.load_state_dict(torch.load(save_path))
        ctr_phi.eval()

        # Inferred match case (-1); otherwise the x% initial match strategy
        if self.args.match_case == -1:
            inferred_match = 1
        else:
            inferred_match = 0

        data_matched, domain_data = self.get_match_function(inferred_match, ctr_phi)

        return data_matched, domain_data
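    # Editorial note with illustrative values (not from the original source):
    # for dataset_name='rot_mnist', ctr_match_layer='logit_match',
    # train_domains=['15', '30'], ctr_match_case=0.01, ctr_match_interrupt=5,
    # ctr_match_flag=1, run=0 and ctr_model_name='resnet18', init_erm_phase()
    # above would look for the CTR checkpoint at
    #   results/rot_mnist/matchdg_ctr/logit_match/train_['15', '30']/Model_0.01_5_1_0_resnet18.pth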
    def train(self):

        for run_erm in range(self.args.n_runs_matchdg_erm):

            self.max_epoch = -1
            self.max_val_acc = 0.0
            for epoch in range(self.args.epochs):

                # Compute the initial matched pairs from the CTR phase model,
                # then refresh them periodically with the current classifier phi
                if epoch == 0:
                    self.data_matched, self.domain_data = self.init_erm_phase()
                elif epoch % self.args.match_interrupt == 0 and self.args.match_flag:
                    inferred_match = 1
                    self.data_matched, self.domain_data = self.get_match_function(inferred_match, self.phi)

                penalty_erm = 0
                penalty_erm_extra = 0
                penalty_ws = 0
                penalty_aug = 0
                train_acc = 0.0
                train_size = 0

                # Batch iteration over a single epoch
                for batch_idx, (x_e, x_org_e, y_e, d_e, idx_e, obj_e) in enumerate(self.train_dataset):

                    self.opt.zero_grad()
                    loss_e = torch.tensor(0.0).to(self.cuda)

                    x_e = x_e.to(self.cuda)
                    x_org_e = x_org_e.to(self.cuda)
                    y_e = torch.argmax(y_e, dim=1).to(self.cuda)
                    d_e = torch.argmax(d_e, dim=1).numpy()

                    # Forward pass on the augmented images
                    out = self.phi(x_e)
                    erm_loss_extra = F.cross_entropy(out, y_e.long()).to(self.cuda)
                    penalty_erm_extra += float(erm_loss_extra)

                    # Perfect match on augmentations: an image and its augmented
                    # version should map to nearby representations
                    out_org = self.phi(x_org_e)
                    augmentation_loss = torch.tensor(0.0).to(self.cuda)
                    if self.args.pos_metric == 'l2':
                        augmentation_loss += torch.sum(torch.sum((out - out_org) ** 2, dim=1))
                    elif self.args.pos_metric == 'l1':
                        augmentation_loss += torch.sum(torch.sum(torch.abs(out - out_org), dim=1))
                    elif self.args.pos_metric == 'cos':
                        augmentation_loss += torch.sum(cosine_similarity(out, out_org))

                    augmentation_loss = augmentation_loss / out.shape[0]
                    penalty_aug += float(augmentation_loss)

                    wasserstein_loss = torch.tensor(0.0).to(self.cuda)
                    erm_loss = torch.tensor(0.0).to(self.cuda)
                    if epoch > self.args.penalty_s:

                        # To cover the varying size of the last batch of matched data points
                        total_batch_size = len(self.data_matched)
                        if batch_idx >= total_batch_size:
                            break

                        # Sample a batch from the matched data points
                        data_match_tensor, label_match_tensor, curr_batch_size = self.get_match_function_batch(batch_idx)

                        data_match = data_match_tensor.to(self.cuda)
                        data_match = data_match.flatten(start_dim=0, end_dim=1)
                        feat_match = self.phi(data_match)

                        label_match = label_match_tensor.to(self.cuda)
                        label_match = torch.squeeze(label_match.flatten(start_dim=0, end_dim=1))

                        erm_loss += F.cross_entropy(feat_match, label_match.long()).to(self.cuda)
                        penalty_erm += float(erm_loss)

                        train_acc += torch.sum(torch.argmax(feat_match, dim=1) == label_match).item()
                        train_size += label_match.shape[0]

                        # Create tensors of shape (domain size, total domains, feat size)
                        feat_match = torch.stack(torch.split(feat_match, len(self.train_domains)))
                        label_match = torch.stack(torch.split(label_match, len(self.train_domains)))

                        # Positive match loss over all pairs of training domains
                        pos_match_counter = 0
                        for d_i in range(feat_match.shape[1]):
                            for d_j in range(feat_match.shape[1]):
                                if d_j > d_i:
                                    if self.args.pos_metric == 'l2':
                                        wasserstein_loss += torch.sum(torch.sum((feat_match[:, d_i, :] - feat_match[:, d_j, :]) ** 2, dim=1))
                                    elif self.args.pos_metric == 'l1':
                                        wasserstein_loss += torch.sum(torch.sum(torch.abs(feat_match[:, d_i, :] - feat_match[:, d_j, :]), dim=1))
                                    elif self.args.pos_metric == 'cos':
                                        wasserstein_loss += torch.sum(cosine_similarity(feat_match[:, d_i, :], feat_match[:, d_j, :]))

                                    pos_match_counter += feat_match.shape[0]

                        wasserstein_loss = wasserstein_loss / pos_match_counter
                        penalty_ws += float(wasserstein_loss)
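                    # Editorial note (illustrative numbers, not from the original source):
                    # the match penalty below is annealed linearly once epoch > penalty_s.
                    # E.g. with penalty_s=5, epochs=25 and penalty_ws=0.1, the coefficient
                    # grows from 0.1*(6-5)/(25-5) = 0.005 at epoch 6 up to
                    # 0.1*(24-5)/(25-5) = 0.095 at the final epoch.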
                    loss_e += (self.args.penalty_ws * (epoch - self.args.penalty_s) / (self.args.epochs - self.args.penalty_s)) * wasserstein_loss
                    loss_e += self.args.penalty_aug * augmentation_loss
                    loss_e += erm_loss
                    loss_e += erm_loss_extra

                    loss_e.backward(retain_graph=False)
                    self.opt.step()

                    del erm_loss_extra
                    del erm_loss
                    del wasserstein_loss
                    del loss_e
                    torch.cuda.empty_cache()

                print('Train Loss Basic : ', penalty_erm_extra, penalty_aug, penalty_erm, penalty_ws)
                print('Train Acc Env : ', 100 * train_acc / train_size)
                print('Done Training for epoch: ', epoch)

                # Val dataset accuracy
                self.val_acc.append(self.get_test_accuracy('val'))

                # Test dataset accuracy
                self.final_acc.append(self.get_test_accuracy('test'))

                # Save the model if the current epoch is the best so far as per validation accuracy
                if self.val_acc[-1] > self.max_val_acc:
                    self.max_val_acc = self.val_acc[-1]
                    self.max_epoch = epoch
                    self.save_model_erm_phase(run_erm)

                # Optional per-epoch evaluations (match-function mean rank and
                # privacy attacks), currently disabled:
                # if epoch > 0:
                #     # GPU
                #     cuda = torch.device("cuda:" + str(self.args.cuda_device))
                #     if cuda:
                #         kwargs = {'num_workers': 1, 'pin_memory': False}
                #     else:
                #         kwargs = {}
                #
                #     train_dataset_temp = get_dataloader(self.args, self.run, self.args.train_domains, 'train', 1, kwargs)
                #     val_dataset_temp = get_dataloader(self.args, self.run, self.args.train_domains, 'val', 1, kwargs)
                #     test_dataset_temp = get_dataloader(self.args, self.run, self.args.test_domains, 'test', 1, kwargs)
                #
                #     # Compute test metrics: mean rank of the match function
                #     from evaluation.match_eval import MatchEval
                #     test_method = MatchEval(
                #         self.args, train_dataset_temp, val_dataset_temp,
                #         test_dataset_temp, self.base_res_dir,
                #         self.run, self.cuda
                #     )
                #     test_method.phi = self.phi
                #     test_method.get_metric_eval()
                #     print('Match Function: ', test_method.metric_score)
                #
                #     # Compute test metrics: membership inference attack (MIA)
                #     from evaluation.privacy_attack import PrivacyAttack
                #     test_method = PrivacyAttack(
                #         self.args, train_dataset_temp, val_dataset_temp,
                #         test_dataset_temp, self.base_res_dir,
                #         self.run, self.cuda
                #     )
                #     test_method.phi = self.phi
                #     test_method.get_metric_eval()
                #     print('MIA: ', test_method.metric_score)
                #
                #     # Compute test metrics: prediction entropy
                #     from evaluation.privacy_entropy import PrivacyEntropy
                #     test_method = PrivacyEntropy(
                #         self.args, train_dataset_temp, val_dataset_temp,
                #         test_dataset_temp, self.base_res_dir,
                #         self.run, self.cuda
                #     )
                #     test_method.phi = self.phi
                #     test_method.get_metric_eval()
                #     print('Entropy: ', test_method.metric_score)

            print('Current Best Epoch: ', self.max_epoch, ' with Test Accuracy: ', self.final_acc[self.max_epoch])
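
# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original module): the positive match
# penalty in train() stacks the matched features into a tensor of shape
# (n_matches, n_domains, feat_dim) and averages a pairwise distance over all
# domain pairs. Because this module uses relative imports, the standalone
# reproduction below is left as a comment; pasted into its own file it runs
# as-is. The sizes and the 'l2' metric are illustrative assumptions.
#
#     import torch
#
#     n_matches, n_domains, feat_dim = 8, 3, 16
#     # Matched features arrive flattened as (n_matches * n_domains, feat_dim)
#     feat = torch.randn(n_matches * n_domains, feat_dim)
#     # Split into chunks of n_domains and stack: (n_matches, n_domains, feat_dim)
#     feat = torch.stack(torch.split(feat, n_domains))
#
#     loss = torch.tensor(0.0)
#     counter = 0
#     for d_i in range(n_domains):
#         for d_j in range(d_i + 1, n_domains):
#             # Squared L2 distance between matched points across the two domains
#             loss += torch.sum(torch.sum((feat[:, d_i, :] - feat[:, d_j, :]) ** 2, dim=1))
#             counter += n_matches
#     loss = loss / counter
#     print('Positive match penalty on dummy features:', float(loss))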