# robustdg/algorithms/erm_match.py

# 153 lines
# 7.1 KiB
# Python

import sys
import numpy as np
import argparse
import copy
import random
import json
import time
import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils
from .algo import BaseAlgo
from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity
class ErmMatch(BaseAlgo):
    """Train ``self.phi`` with cross-entropy on matched samples plus a match penalty.

    Each training step draws a batch of cross-domain matched samples via
    ``self.get_match_function_batch`` (not defined in this class; presumably
    inherited from ``BaseAlgo``), applies cross-entropy to the outputs of
    ``self.phi``, and adds a linearly ramped penalty on the pairwise distance
    between the outputs of matched samples coming from different domains.
    """

    def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda):
        # All shared state (self.args, self.phi, self.opt, self.cuda, the
        # accuracy lists, ...) is presumably created by BaseAlgo; this
        # subclass only adds the training loop.
        super().__init__(args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda)

    def _pairwise_match_loss(self, feat_match):
        """Return the mean pairwise distance between outputs of different domains.

        ``feat_match[:, d, :]`` holds the outputs for domain ``d`` (``train``
        builds this by stacking chunks of ``len(self.train_domains)`` rows).
        The distance selected by ``self.args.pos_metric`` is summed over every
        domain pair ``(d_i, d_j)`` with ``d_j > d_i`` and divided by the number
        of (sample, pair) terms. With fewer than two domains there is no pair;
        a zero tensor is returned instead of dividing 0 by 0 (which gave NaN).

        Raises:
            ValueError: if ``self.args.pos_metric`` is not 'l2', 'l1' or 'cos'.
        """
        metric = self.args.pos_metric
        if metric not in ('l2', 'l1', 'cos'):
            # Previously an unknown metric silently disabled the match penalty.
            raise ValueError(
                f"Unsupported pos_metric {metric!r}; expected 'l2', 'l1' or 'cos'"
            )

        match_loss = torch.tensor(0.0).to(self.cuda)
        pair_terms = 0
        num_domains = feat_match.shape[1]
        for d_i in range(num_domains):
            for d_j in range(d_i + 1, num_domains):
                feat_i = feat_match[:, d_i, :]
                feat_j = feat_match[:, d_j, :]
                if metric == 'l2':
                    match_loss += torch.sum(torch.sum((feat_i - feat_j) ** 2, dim=1))
                elif metric == 'l1':
                    match_loss += torch.sum(torch.sum(torch.abs(feat_i - feat_j), dim=1))
                else:
                    # cosine_similarity comes from utils.helper; its exact
                    # definition (similarity vs. 1 - similarity) is not visible here.
                    match_loss += torch.sum(cosine_similarity(feat_i, feat_j))
                pair_terms += feat_match.shape[0]

        if pair_terms > 0:
            match_loss = match_loss / pair_terms
        return match_loss

    def _match_penalty_weight(self, epoch):
        """Return the coefficient that multiplies the match loss at ``epoch``.

        The coefficient ramps linearly as
        ``penalty_ws * (epoch - start) / (epochs - start)`` where ``start`` is
        ``penalty_s``, or ``penalty_s + match_interrupt`` once
        ``epoch >= match_interrupt`` and ``match_flag == 1``.

        NOTE(review): with a non-negative ``penalty_s`` the second form can be
        negative (e.g. penalty_s=3, match_interrupt=5, epoch=5 gives -3), and
        ``epochs == start`` divides by zero; kept as-is, confirm intent.
        """
        args = self.args
        if epoch >= args.match_interrupt and args.match_flag == 1:
            numerator = epoch - args.penalty_s - args.match_interrupt
            denominator = args.epochs - args.penalty_s - args.match_interrupt
        else:
            numerator = epoch - args.penalty_s
            denominator = args.epochs - args.penalty_s
        return args.penalty_ws * numerator / denominator

    def _maybe_decay_lr(self, epoch):
        """Halve the learning rate on a fixed epoch schedule for some models.

        Only ``model_name`` 'lenet' (every 25 epochs) and 'domain_bed_mnist'
        (every 10 epochs) are scheduled. On a schedule boundary ``self.opt`` is
        replaced by a fresh Nesterov SGD (momentum 0.9) over the trainable
        parameters of ``self.phi`` with ``lr = args.lr / 2**(epoch // step)``;
        any state of the previous optimizer (e.g. momentum buffers) is dropped.
        """
        lr_schedule_step = {'lenet': 25, 'domain_bed_mnist': 10}.get(self.args.model_name)
        if epoch == 0 or lr_schedule_step is None or epoch % lr_schedule_step != 0:
            return
        lr = self.args.lr / (2 ** (epoch // lr_schedule_step))
        print('Learning Rate Scheduling; New LR: ', lr)
        self.opt = optim.SGD(
            [{'params': filter(lambda p: p.requires_grad, self.phi.parameters())}],
            lr=lr,
            weight_decay=self.args.weight_decay,
            momentum=0.9,
            nesterov=True,
        )

    def train(self):
        """Run ``self.args.epochs`` epochs of matched ERM training.

        Per epoch:
          * matches are (re)computed with ``get_match_function`` at epoch 0
            (``inferred_match=0``) and, when ``args.match_flag`` is set, every
            ``args.match_interrupt`` epochs afterwards (``inferred_match=1``);
          * once ``epoch > args.penalty_s``, for each loader batch index the
            matched batch with that index is trained on (cross-entropy plus the
            weighted pairwise match loss); the epoch ends early when the matched
            data holds fewer batches than the loader;
          * train / val / test accuracies are appended to ``self.train_acc``,
            ``self.val_acc`` and ``self.final_acc``, and ``save_model()`` is
            called whenever validation accuracy beats the best seen so far;
          * the learning rate may be decayed (see ``_maybe_decay_lr``).
        """
        self.max_epoch = -1
        self.max_val_acc = 0.0
        for epoch in range(self.args.epochs):
            if epoch == 0:
                self.data_matched, self.domain_data = self.get_match_function(0, self.phi)
            elif epoch % self.args.match_interrupt == 0 and self.args.match_flag:
                self.data_matched, self.domain_data = self.get_match_function(1, self.phi)

            penalty_erm = 0
            penalty_ws = 0
            train_acc = 0.0
            train_size = 0

            # The loader only paces the epoch: its batch (x, y, d, idx, obj) is
            # unused, all loss terms come from the matched batch with the same index.
            for batch_idx, _loader_batch in enumerate(self.train_dataset):
                self.opt.zero_grad()
                loss_e = torch.tensor(0.0).to(self.cuda)
                wasserstein_loss = torch.tensor(0.0).to(self.cuda)
                erm_loss = torch.tensor(0.0).to(self.cuda)

                # NOTE(review): when epoch <= penalty_s nothing is added to
                # loss_e, so it has no grad_fn and backward() below raises a
                # RuntimeError; presumably penalty_s is expected to be negative.
                if epoch > self.args.penalty_s:
                    # The matched data may hold fewer batches than the loader
                    # (e.g. a different last-batch size); stop once exhausted.
                    if batch_idx >= len(self.data_matched):
                        break

                    data_match_tensor, label_match_tensor, _ = self.get_match_function_batch(batch_idx)
                    # Merge the first two dims (presumably batch x domain) so phi
                    # sees one flat batch.
                    data_match = data_match_tensor.to(self.cuda).flatten(start_dim=0, end_dim=1)
                    feat_match = self.phi(data_match)
                    label_match = torch.squeeze(
                        label_match_tensor.to(self.cuda).flatten(start_dim=0, end_dim=1)
                    )

                    erm_loss += F.cross_entropy(feat_match, label_match.long()).to(self.cuda)
                    penalty_erm += float(erm_loss)
                    loss_e += erm_loss
                    train_acc += torch.sum(torch.argmax(feat_match, dim=1) == label_match).item()
                    train_size += label_match.shape[0]

                    # Regroup to (num_chunks, len(self.train_domains), out_dim)
                    # so that feat_match[:, d, :] selects domain d.
                    feat_match = torch.stack(torch.split(feat_match, len(self.train_domains)))
                    wasserstein_loss = self._pairwise_match_loss(feat_match)
                    penalty_ws += float(wasserstein_loss)
                    loss_e += self._match_penalty_weight(epoch) * wasserstein_loss

                loss_e.backward(retain_graph=False)
                self.opt.step()

                # Drop per-batch graph references before releasing cached GPU memory.
                del erm_loss
                del wasserstein_loss
                del loss_e
                torch.cuda.empty_cache()

            if train_size > 0:
                epoch_train_acc = 100 * train_acc / train_size
            else:
                # Previously a ZeroDivisionError when no matched batch was trained on.
                print('Warning: no matched batches were trained on in epoch', epoch)
                epoch_train_acc = 0.0

            print('Train Loss Basic : ', penalty_erm, penalty_ws )
            print('Train Acc Env : ', epoch_train_acc )
            print('Done Training for epoch: ', epoch)

            self.train_acc.append(epoch_train_acc)
            self.val_acc.append(self.get_test_accuracy('val'))
            self.final_acc.append(self.get_test_accuracy('test'))

            # Keep the checkpoint of the epoch with the best validation accuracy.
            if self.val_acc[-1] > self.max_val_acc:
                self.max_val_acc = self.val_acc[-1]
                self.max_epoch = epoch
                self.save_model()

            # Indexing by epoch assumes final_acc was empty when train() started.
            print('Current Best Epoch: ', self.max_epoch, ' with Test Accuracy: ', self.final_acc[self.max_epoch])

            self._maybe_decay_lr(epoch)