# robustdg/algorithms/match_dg.py

# 393 lines
# 20 KiB
# Python

import sys
import numpy as np
import argparse
import copy
import random
import json
import os
import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils
from .algo import BaseAlgo
from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity
from utils.match_function import get_matched_pairs
class MatchDG(BaseAlgo):
    """MatchDG: contrastive matching followed by match-regularised ERM.

    The algorithm runs in one of two phases, selected by ``ctr_phase``:

    * CTR phase (``ctr_phase`` truthy): trains ``self.phi`` with a contrastive
      loss over matched samples from different training domains, optionally
      re-inferring the matches with the current model, and checkpoints the
      epoch with the best 'TopK Perfect Match Score'.
    * ERM phase (``ctr_phase`` falsy): loads a CTR-phase checkpoint, uses it
      to build matches, then trains ``self.phi`` with cross-entropy plus a
      penalty on the distance between the outputs of matched samples.

    NOTE(review): ``self.args``, ``self.phi``, ``self.opt``, ``self.run``,
    ``self.train_domains``, the ``train_acc``/``val_acc``/``final_acc`` lists
    and the ``get_match_function`` / ``get_match_function_batch`` /
    ``get_test_accuracy`` helpers are not defined here; presumably they come
    from ``BaseAlgo``.
    """

    def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda, ctr_phase=1):
        super().__init__(args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda)
        self.ctr_phase = ctr_phase
        # Checkpoint suffix for the CTR model trained by this instance.
        self.ctr_save_post_string = self._make_post_string(
            self.args.match_case, self.args.match_interrupt,
            self.args.match_flag, self.args.model_name)
        # Checkpoint suffix of the CTR model that the ERM phase loads.
        self.ctr_load_post_string = self._make_post_string(
            self.args.ctr_match_case, self.args.ctr_match_interrupt,
            self.args.ctr_match_flag, self.args.ctr_model_name)

    def _make_post_string(self, match_case, match_interrupt, match_flag, model_name):
        """Return '<match_case>_<match_interrupt>_<match_flag>_<run>_<model_name>'."""
        return '_'.join([str(match_case), str(match_interrupt), str(match_flag),
                         str(self.run), model_name])

    def train(self):
        """Run the CTR phase if ``self.ctr_phase`` is truthy, else the ERM phase."""
        if self.ctr_phase:
            self.train_ctr_phase()
        else:
            self.train_erm_phase()

    def save_model_ctr_phase(self, epoch):
        """Save the weights of ``self.phi`` as the CTR-phase checkpoint.

        ``epoch`` is accepted for interface compatibility but unused: the path
        does not depend on it, so every call overwrites the previous file.
        """
        torch.save(self.phi.state_dict(),
                   self.base_res_dir + '/Model_' + self.ctr_save_post_string + '.pth')

    def save_model_erm_phase(self, run):
        """Save the weights of ``self.phi`` for ERM run ``run``.

        The file goes into a sub-directory named after the CTR checkpoint that
        was loaded; the directory is created if it does not exist.
        """
        save_dir = self.base_res_dir + '/' + self.ctr_load_post_string
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(save_dir, exist_ok=True)
        torch.save(self.phi.state_dict(),
                   save_dir + '/Model_' + self.post_string + '_' + str(run) + '.pth')

    def _build_ctr_model(self):
        """Instantiate the CTR-phase architecture named by ``args.ctr_model_name``.

        The architecture must follow ``ctr_model_name`` (not ``model_name``)
        because that is the name embedded in the checkpoint path built from
        ``self.ctr_load_post_string``.

        Raises:
            ValueError: if ``args.ctr_model_name`` is not a supported name.
        """
        name = self.args.ctr_model_name
        fc_layer = 0
        if name == 'lenet':
            from models.lenet import LeNet5
            model = LeNet5()
        elif name == 'slab':
            from models.slab import SlabClf
            model = SlabClf(self.args.slab_data_dim, self.args.out_classes, fc_layer)
        elif name == 'domain_bed_mnist':
            from models.domain_bed_mnist import DomainBed
            model = DomainBed(self.args.img_c, fc_layer)
        elif name == 'alexnet':
            from models.alexnet import alexnet
            model = alexnet(self.args.out_classes, self.args.pre_trained, 'matchdg_ctr')
        elif name == 'fc':
            from models.fc import FC
            model = FC(self.args.out_classes, fc_layer)
        elif 'resnet' in name:
            from models.resnet import get_resnet
            model = get_resnet(name, self.args.out_classes, fc_layer, self.args.img_c,
                               self.args.pre_trained, self.args.dp_noise, self.args.os_env)
        elif 'densenet' in name:
            from models.densenet import get_densenet
            model = get_densenet(name, self.args.out_classes, fc_layer, self.args.img_c,
                                 self.args.pre_trained, self.args.os_env)
        else:
            raise ValueError('Unsupported ctr_model_name for the MatchDG ERM phase: ' + str(name))
        return model.to(self.cuda)

    def _ctr_model_dir(self):
        """Return the results directory that holds the CTR-phase checkpoint."""
        if self.args.os_env:
            root = os.getenv('PT_DATA_DIR') + '/' + self.args.dataset_name
        else:
            root = "results/" + self.args.dataset_name
        res_dir = (root + '/' + 'matchdg_ctr' + '/' + self.args.ctr_match_layer
                   + '/' + 'train_' + str(self.args.train_domains))
        # TODO: Handle slab noise case in helper functions
        if self.args.dataset_name == 'slab':
            res_dir = res_dir + '/slab_noise_' + str(self.args.slab_noise)
        return res_dir

    def init_erm_phase(self):
        """Load the saved CTR model and use it to build the ERM-phase matches.

        ``args.match_case == -1`` selects inferred matching (``inferred_match=1``);
        any other value selects the x% percentage match strategy (0).

        Returns:
            ``(data_matched, domain_data)`` as returned by ``get_match_function``.
        """
        ctr_phi = self._build_ctr_model()
        save_path = self._ctr_model_dir() + '/Model_' + self.ctr_load_post_string + '.pth'
        ctr_phi.load_state_dict(torch.load(save_path))
        ctr_phi.eval()

        inferred_match = 1 if self.args.match_case == -1 else 0
        data_matched, domain_data = self.get_match_function(inferred_match, ctr_phi)
        return data_matched, domain_data

    @staticmethod
    def _abort_if_nan(tensor, message):
        """Print ``message`` and exit with status 1 if ``tensor`` contains NaN."""
        if torch.isnan(tensor).any():
            print(message)
            # A bare sys.exit() reports success (status 0); signal failure instead.
            sys.exit(1)

    def _forward_matched_batch(self, batch_idx):
        """Run ``self.phi`` on matched batch ``batch_idx``.

        Returns ``(feat_match, label_match)``: the model output and labels with
        the first two axes of the ``get_match_function_batch`` tensors
        (presumably batch and domain) flattened into one; size-1 label axes
        are squeezed away.
        """
        data_match_tensor, label_match_tensor, _ = self.get_match_function_batch(batch_idx)
        data_match = data_match_tensor.to(self.cuda).flatten(start_dim=0, end_dim=1)
        feat_match = self.phi(data_match)
        label_match = label_match_tensor.to(self.cuda).flatten(start_dim=0, end_dim=1)
        return feat_match, torch.squeeze(label_match)

    def _split_by_domain(self, tensor):
        """Regroup flat rows into chunks of ``len(self.train_domains)`` stacked on a new axis 0.

        Given the flattening in ``_forward_matched_batch`` this presumably
        yields shape (batch size, total domains, ...).
        """
        return torch.stack(torch.split(tensor, len(self.train_domains)))

    def _contrastive_loss(self, feat_match, label_match):
        """Sum the cross-domain contrastive loss terms for one matched batch.

        For each class ``y_c``, groups whose column-0 label is ``y_c`` are the
        positives and all other groups are the negatives. Each domain's
        positive row serves as an anchor; its counterparts in the other
        domains are the positive pairs and every negative row (flattened over
        domains) feeds ``neg_dist``.

        Args:
            feat_match: features indexed as ``[group, domain, feature]``.
            label_match: labels grouped the same way as ``feat_match``.

        Returns:
            ``(hinge_sum, ctr_sum, n_terms)``: the summed loss, the summed
            ``neg_dist`` values, and the total number of rows over all
            ``pos_dist`` terms.
        """
        hinge_sum = torch.tensor(0.0).to(self.cuda)
        ctr_sum = torch.tensor(0.0).to(self.cuda)
        n_terms = 0
        num_domains = feat_match.shape[1]

        for y_c in range(self.args.out_classes):
            pos_feat = feat_match[label_match[:, 0] == y_c]
            neg_feat = feat_match[label_match[:, 0] != y_c]
            # Nothing to contrast if the class is absent or is the only one present.
            if pos_feat.shape[0] == 0 or neg_feat.shape[0] == 0:
                continue

            # Loop-invariant over anchors: check once, flatten once. Reshaping
            # cannot introduce NaNs, so the flattened tensor needs no re-check.
            self._abort_if_nan(neg_feat, 'Non Reshaped X2 is Nan')
            neg_flat = neg_feat.reshape(neg_feat.shape[0] * neg_feat.shape[1], neg_feat.shape[2])

            for d_i in range(num_domains):
                anchor = pos_feat[:, d_i, :]
                neg_dist = embedding_dist(anchor, neg_flat, self.args.pos_metric, self.args.tau, xent=True)
                self._abort_if_nan(neg_dist, 'Neg Dist Nan')

                for d_j in range(num_domains):
                    if d_j == d_i:
                        continue
                    pos_dist = (1.0 - embedding_dist(anchor, pos_feat[:, d_j, :], self.args.pos_metric)) / self.args.tau
                    self._abort_if_nan(pos_dist, 'Pos Dist Nan')
                    log_denom = torch.log(torch.exp(pos_dist) + neg_dist)
                    self._abort_if_nan(log_denom, 'Xent Nan')

                    hinge_sum -= torch.sum(pos_dist - log_denom)
                    ctr_sum += torch.sum(neg_dist)
                    n_terms += pos_dist.shape[0]

        return hinge_sum, ctr_sum, n_terms

    def train_ctr_phase(self):
        """Train ``self.phi`` with the contrastive matching loss.

        Matches are built at epoch 0 with ``inferred_match=0`` and, when
        ``args.match_flag`` is set, re-inferred from ``self.phi`` every
        ``args.match_interrupt`` epochs. The loss is applied only for epochs
        after ``args.penalty_s``, with a weight that grows linearly with the
        epoch. Every 5th epoch the model is scored with ``MatchEval`` and saved
        when its 'TopK Perfect Match Score' improves.

        NOTE(review): no checkpoint is written if fewer than 5 epochs run or the
        score never exceeds 0.0; ``init_erm_phase`` then cannot find the file.
        """
        self.max_epoch = -1
        self.max_val_score = 0.0
        for epoch in range(self.args.epochs):
            if epoch == 0:
                # x% percentage match initial strategy.
                inferred_match = 0
                self.data_matched, self.domain_data = self.get_match_function(inferred_match, self.phi)
            elif epoch % self.args.match_interrupt == 0 and self.args.match_flag:
                # Re-infer matches with the current representation.
                inferred_match = 1
                self.data_matched, self.domain_data = self.get_match_function(inferred_match, self.phi)

            # The same-class-same-domain ("same") terms are never computed; they
            # stay at zero so the per-epoch log line keeps its four columns.
            penalty_same_ctr = 0.0
            penalty_diff_ctr = 0.0
            penalty_same_hinge = 0.0
            penalty_diff_hinge = 0.0

            # Only matched data is used in this phase; the loader just drives
            # the batch index and the number of iterations.
            for batch_idx, _ in enumerate(self.train_dataset):
                self.opt.zero_grad()
                loss_e = torch.tensor(0.0).to(self.cuda)

                if epoch > self.args.penalty_s:
                    # To cover the varying size of the last batch, data_matched
                    # may hold fewer batches than train_dataset.
                    if batch_idx >= len(self.data_matched):
                        break
                    feat_match, label_match = self._forward_matched_batch(batch_idx)
                    feat_match = self._split_by_domain(feat_match)
                    label_match = self._split_by_domain(label_match)

                    hinge_sum, ctr_sum, n_terms = self._contrastive_loss(feat_match, label_match)
                    # The counter starts at 1, which also keeps the division
                    # finite when no class contributed a term.
                    diff_neg_counter = 1 + n_terms
                    diff_ctr_loss = ctr_sum / diff_neg_counter
                    diff_hinge_loss = hinge_sum / diff_neg_counter
                    penalty_diff_ctr += float(diff_ctr_loss)
                    penalty_diff_hinge += float(diff_hinge_loss)

                    ramp = (epoch - self.args.penalty_s) / (self.args.epochs - self.args.penalty_s)
                    loss_e += ramp * diff_hinge_loss
                    del diff_ctr_loss, diff_hinge_loss

                # loss_e has no autograd graph when no loss term was added
                # (epoch <= penalty_s, or no class had both positives and
                # negatives); backward() would raise in that case.
                if loss_e.requires_grad:
                    loss_e.backward(retain_graph=False)
                    self.opt.step()
                del loss_e
                torch.cuda.empty_cache()

            print('Train Loss Ctr : ', penalty_same_ctr, penalty_diff_ctr, penalty_same_hinge, penalty_diff_hinge)
            print('Done Training for epoch: ', epoch)

            if (epoch + 1) % 5 == 0:
                from evaluation.match_eval import MatchEval
                test_method = MatchEval(
                    self.args, self.train_dataset, self.val_dataset,
                    self.test_dataset, self.base_res_dir,
                    self.run, self.cuda
                )
                # Evaluate the match quality of the current representation.
                test_method.phi = self.phi
                test_method.get_metric_eval()
                score = test_method.metric_score['TopK Perfect Match Score']
                # Save the model's weights when the score improves.
                if score > self.max_val_score:
                    self.max_val_score = score
                    self.max_epoch = epoch
                    self.save_model_ctr_phase(epoch)
                print('Current Best Epoch: ', self.max_epoch, ' with TopK Overlap: ', self.max_val_score)

    def _match_penalty(self, feat_match):
        """Average distance between matched outputs over all domain pairs.

        For every domain pair ``d_i < d_j`` the per-row distance between
        ``feat_match[:, d_i]`` and ``feat_match[:, d_j]`` is summed (squared
        L2, L1, or the ``cosine_similarity`` helper, per ``args.pos_metric``),
        then divided by the number of rows compared. With fewer than two
        domains there is no pair and a zero tensor is returned.
        """
        num_rows, num_domains = feat_match.shape[0], feat_match.shape[1]
        penalty = torch.tensor(0.0).to(self.cuda)
        pair_rows = 0
        for d_i in range(num_domains):
            for d_j in range(d_i + 1, num_domains):
                feat_i, feat_j = feat_match[:, d_i, :], feat_match[:, d_j, :]
                if self.args.pos_metric == 'l2':
                    penalty += torch.sum(torch.sum((feat_i - feat_j) ** 2, dim=1))
                elif self.args.pos_metric == 'l1':
                    penalty += torch.sum(torch.sum(torch.abs(feat_i - feat_j), dim=1))
                elif self.args.pos_metric == 'cos':
                    penalty += torch.sum(cosine_similarity(feat_i, feat_j))
                pair_rows += num_rows
        # Dividing by zero here would produce NaN and poison the total loss.
        if pair_rows > 0:
            penalty = penalty / pair_rows
        return penalty

    def train_erm_phase(self):
        """Train ``self.phi`` with ERM plus the match penalty, ``args.n_runs_matchdg_erm`` times.

        At epoch 0 of each run the matches are built from the saved CTR model
        (``init_erm_phase``); with ``args.match_flag`` set they are re-inferred
        from ``self.phi`` every ``args.match_interrupt`` epochs. Every batch
        contributes cross-entropy on the regular batch. For epochs after
        ``args.penalty_s`` it also adds cross-entropy on the matched batch and
        the ``args.penalty_ws``-weighted, linearly ramped match penalty. The
        weights with the best validation accuracy of each run are saved.

        NOTE(review): ``self.phi`` and ``self.opt`` are not re-initialised here
        between runs, so later runs continue from the previous run's state.
        """
        for run_erm in range(self.args.n_runs_matchdg_erm):
            self.max_epoch = -1
            self.max_val_acc = 0.0
            # The accuracy lists keep growing across runs, so self.max_epoch (an
            # epoch index within this run) must be offset to index into them.
            run_start = len(self.final_acc)

            for epoch in range(self.args.epochs):
                if epoch == 0:
                    self.data_matched, self.domain_data = self.init_erm_phase()
                elif epoch % self.args.match_interrupt == 0 and self.args.match_flag:
                    # Store under the same names as at epoch 0 so the refreshed
                    # matches are actually used by the matched-batch loop below.
                    inferred_match = 1
                    self.data_matched, self.domain_data = self.get_match_function(inferred_match, self.phi)

                penalty_erm = 0
                penalty_erm_extra = 0
                penalty_ws = 0
                train_acc = 0.0
                train_size = 0

                for batch_idx, (x_e, y_e, _d_e, _idx_e, _obj_e) in enumerate(self.train_dataset):
                    self.opt.zero_grad()
                    loss_e = torch.tensor(0.0).to(self.cuda)

                    x_e = x_e.to(self.cuda)
                    y_e = torch.argmax(y_e, dim=1).to(self.cuda)

                    # Forward pass on the regular batch.
                    out = self.phi(x_e)
                    erm_loss_extra = F.cross_entropy(out, y_e.long()).to(self.cuda)
                    penalty_erm_extra += float(erm_loss_extra)

                    wasserstein_loss = torch.tensor(0.0).to(self.cuda)
                    erm_loss = torch.tensor(0.0).to(self.cuda)
                    if epoch > self.args.penalty_s:
                        # To cover the varying size of the last batch, data_matched
                        # may hold fewer batches than train_dataset.
                        if batch_idx >= len(self.data_matched):
                            break
                        feat_match, label_match = self._forward_matched_batch(batch_idx)

                        erm_loss += F.cross_entropy(feat_match, label_match.long()).to(self.cuda)
                        penalty_erm += float(erm_loss)
                        train_acc += torch.sum(torch.argmax(feat_match, dim=1) == label_match).item()
                        train_size += label_match.shape[0]

                        wasserstein_loss = self._match_penalty(self._split_by_domain(feat_match))
                        penalty_ws += float(wasserstein_loss)
                        loss_e += (self.args.penalty_ws * (epoch - self.args.penalty_s) / (self.args.epochs - self.args.penalty_s)) * wasserstein_loss

                    loss_e += erm_loss
                    loss_e += erm_loss_extra
                    loss_e.backward(retain_graph=False)
                    self.opt.step()

                    del erm_loss_extra
                    del erm_loss
                    del wasserstein_loss
                    del loss_e
                    torch.cuda.empty_cache()

                # train_size stays 0 when no matched batch ran this epoch
                # (epoch <= penalty_s); report 0.0 instead of dividing by zero.
                epoch_train_acc = 100 * train_acc / train_size if train_size else 0.0
                print('Train Loss Basic : ', penalty_erm_extra, penalty_erm, penalty_ws)
                print('Train Acc Env : ', epoch_train_acc)
                print('Done Training for epoch: ', epoch)

                self.train_acc.append(epoch_train_acc)
                self.val_acc.append(self.get_test_accuracy('val'))
                self.final_acc.append(self.get_test_accuracy('test'))

                # Save the model if this is the best epoch by validation accuracy.
                if self.val_acc[-1] > self.max_val_acc:
                    self.max_val_acc = self.val_acc[-1]
                    self.max_epoch = epoch
                    self.save_model_erm_phase(run_erm)
                best_idx = run_start + self.max_epoch if self.max_epoch >= 0 else -1
                print('Current Best Epoch: ', self.max_epoch, ' with Test Accuracy: ', self.final_acc[best_idx])

                if epoch > 0 and self.args.model_name in ['domain_bed_mnist', 'lenet']:
                    lr_schedule_step = 25 if self.args.model_name == 'lenet' else 10
                    if epoch % lr_schedule_step == 0:
                        lr = self.args.lr / (2 ** (int(epoch / lr_schedule_step)))
                        print('Learning Rate Scheduling; New LR: ', lr)
                        # NOTE(review): rebuilding the optimizer discards its
                        # momentum buffers; kept as in the original schedule.
                        self.opt = optim.SGD([
                            {'params': filter(lambda p: p.requires_grad, self.phi.parameters())},
                        ], lr=lr, weight_decay=self.args.weight_decay, momentum=0.9, nesterov=True)