Slab Synthetic Data integrated with RobustDG
Parent: 514a3d92c8
Commit: 7ee0caa5a9
@@ -53,6 +53,10 @@ class BaseAlgo():
            from models.lenet import LeNet5
            phi= LeNet5()

        if self.args.model_name == 'slab':
            from models.slab import SlabClf
            phi= SlabClf(self.args.slab_data_dim, self.args.out_classes)

        if self.args.model_name == 'fc':
            from models.fc import FC
            if self.args.method_name in ['csd', 'matchdg_ctr']:

@@ -148,7 +152,6 @@ class BaseAlgo():
        with torch.no_grad():
            x_e= x_e.to(self.cuda)
            y_e= torch.argmax(y_e, dim=1).to(self.cuda)
            d_e = torch.argmax(d_e, dim=1).numpy()

            #Forward Pass
            out= self.phi(x_e)

@@ -0,0 +1,121 @@
import sys
import numpy as np
import argparse
import copy
import random
import json

import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

from .algo import BaseAlgo
from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity, slab_batch_process

class PerfMatch(BaseAlgo):
    def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda):

        super().__init__(args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda)

    def train(self):

        self.max_epoch=-1
        self.max_val_acc=0.0
        for epoch in range(self.args.epochs):

            penalty_erm=0
            penalty_ws=0
            train_acc= 0.0
            train_size=0

            #Batch iteration over single epoch
            for batch_idx, (x_e, y_e ,d_e, idx_e) in enumerate(self.train_dataset):
                # print('Batch Idx: ', batch_idx)

                #Process current batch as per slab dataset
                # x_e, y_e ,d_e, idx_e= slab_batch_process(x_e, y_e ,d_e, idx_e)

                self.opt.zero_grad()
                loss_e= torch.tensor(0.0).to(self.cuda)

                x_e= x_e.to(self.cuda)
                y_e= torch.argmax(y_e, dim=1).to(self.cuda)

                #Forward Pass
                out= self.phi(x_e)
                erm_loss= F.cross_entropy(out, y_e.long()).to(self.cuda)
                loss_e+= erm_loss
                penalty_erm += float(loss_e)

                #Perfect Match Penalty
                ws_loss= torch.tensor(0.0).to(self.cuda)
                counter=0
                match_objs= np.unique(idx_e)
                feat= self.phi.feat_net(x_e)
                for obj in match_objs:
                    indices= idx_e == obj
                    feat_obj= feat[indices]
                    d_obj= d_e[indices]

                    match_domains= torch.unique(d_obj)

                    if len(match_domains) != len(torch.unique(d_e)):
                        # print('Error: Positivity Violation, objects not present in all the domains')
                        continue

                    for d_i in range(len(match_domains)):
                        for d_j in range(len(match_domains)):
                            if d_j <= d_i:
                                continue
                            x1= feat_obj[ d_obj == d_i ]
                            x2= feat_obj[ d_obj == d_j ]

                            #Typecasting
                            # print(x1.shape, x2.shape)
                            x1= x1.view(x1.shape[0], 1, x1.shape[1])
                            ws_loss= torch.sum( torch.sum( torch.sum( (x1 -x2)**2, dim=2), dim=1 ) )
                            # ws_loss= torch.sum( torch.sum( torch.sum( torch.abs(x1 -x2), dim=2), dim=1 ) )
                            counter+= x1.shape[0]*x2.shape[0]

                ws_loss= ws_loss/counter
                penalty_ws += float(ws_loss)

                #Backprop
                loss_e+= self.args.penalty_ws*ws_loss*((epoch+1)/self.args.epochs)
                loss_e.backward(retain_graph=False)
                self.opt.step()

                del erm_loss
                del loss_e
                torch.cuda.empty_cache()

                train_acc+= torch.sum(torch.argmax(out, dim=1) == y_e ).item()
                train_size+= y_e.shape[0]

            print('Train Loss Basic : ', penalty_erm, penalty_ws )
            print('Train Acc Env : ', 100*train_acc/train_size )
            print('Done Training for epoch: ', epoch)

            #Train Dataset Accuracy
            self.train_acc.append( 100*train_acc/train_size )

            #Val Dataset Accuracy
            self.val_acc.append( self.get_test_accuracy('val') )

            #Test Dataset Accuracy
            self.final_acc.append( self.get_test_accuracy('test') )

            #Save the model if current best epoch as per validation loss
            if self.val_acc[-1] > self.max_val_acc:
                self.max_val_acc=self.val_acc[-1]
                self.max_epoch= epoch
                self.save_model()

        # Save the model's weights post training
        self.save_model()

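In essence, the matching penalty computed in the loop above is an average pairwise squared distance between feat_net representations of the same object across pairs of domains. A minimal standalone sketch of that computation (illustrative names only, not part of the commit; it accumulates over all domain pairs before normalizing):

import torch

def perfect_match_penalty(feat, domains, obj_ids):
    """Mean squared distance between features of the same object across domain pairs.

    feat:    (N, F) features, e.g. from phi.feat_net
    domains: (N,) domain index per sample
    obj_ids: (N,) match/object index per sample
    """
    loss, counter = feat.new_tensor(0.0), 0
    for obj in torch.unique(obj_ids):
        mask = obj_ids == obj
        f_obj, d_obj = feat[mask], domains[mask]
        doms = torch.unique(d_obj)
        for i in range(len(doms)):
            for j in range(i + 1, len(doms)):
                x1, x2 = f_obj[d_obj == doms[i]], f_obj[d_obj == doms[j]]
                # broadcast to all (x1, x2) pairs and accumulate squared distances
                diff = x1.unsqueeze(1) - x2.unsqueeze(0)
                loss = loss + (diff ** 2).sum()
                counter += x1.shape[0] * x2.shape[0]
    return loss / max(counter, 1)
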
@@ -0,0 +1,102 @@
#Common imports
import os
import random
import copy
import numpy as np
import h5py
from PIL import Image

#Pytorch
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms
from torchvision import datasets, transforms

#Base Class
from .data_loader import BaseDataLoader

#Specific Modules
from utils.slab_data import *


class SlabData(BaseDataLoader):
    def __init__(self, args, list_train_domains, root, transform=None, data_case='train', match_func=False, base_size=10000, freq_ratio=50, data_dim=2, total_slabs=5):

        super().__init__(args, list_train_domains, root, transform, data_case, match_func)

        self.base_size = base_size
        self.freq_ratio= freq_ratio
        self.data_dim= data_dim
        self.total_slabs= total_slabs

        print(list_train_domains)
        if self.data_case == 'train':
            self.domain_size = [self.base_size, self.base_size]
            # Default Train Domains: [0.0, 0.10]
            self.spur_probs= [ float(domain) for domain in list_train_domains ]

        elif self.data_case == 'val':
            self.domain_size = [int(self.base_size/4), int(self.base_size/4)]
            self.spur_probs= [ float(domain) for domain in list_train_domains ]

        elif self.data_case == 'test':
            self.domain_size = [self.base_size]
            self.spur_probs= [1.0]

        print('\n')
        print('Data Case: ', self.data_case)

        self.train_data, self.train_labels, self.train_domain, self.train_indices = self._get_data(self.domain_size, self.data_dim, self.total_slabs, self.spur_probs)

    def _get_data(self, domain_size, data_dim, total_slabs, spur_probs):

        list_data = []
        list_labels = []
        list_objs= []
        total_domains= len(domain_size)

        for idx in range(total_domains):

            num_samples= domain_size[idx]
            spur_prob= spur_probs[idx]

            _, data, labels, match_obj= get_data(num_samples, spur_prob, total_slabs)
            print('Source Domain: ', idx, ' Size: ', data.shape, labels.shape, match_obj.shape)
            list_data.append(torch.tensor(data))
            list_labels.append(torch.tensor(labels))
            list_objs.append(match_obj)

        # Stack data from the different domains
        data_feat = torch.cat(list_data)
        data_labels = torch.cat(list_labels)
        data_objs= np.hstack(list_objs)

        # Create domain labels
        data_domains = torch.zeros(data_labels.size())
        domain_start=0
        for idx in range(total_domains):
            curr_domain_size= domain_size[idx]
            data_domains[ domain_start: domain_start+ curr_domain_size ] += idx
            domain_start+= curr_domain_size

        # Shuffle everything one more time
        # inds = np.arange(data_labels.size()[0])
        # np.random.shuffle(inds)
        # data_feat = data_feat[inds]
        # data_labels = data_labels[inds].long()
        # data_domains = data_domains[inds].long()

        # Convert to onehot
        y = torch.eye(2)
        data_labels = y[data_labels]
        # # Convert to onehot
        # d = torch.eye(len(self.list_train_domains))
        # data_domains = d[data_domains]

        #Type Casting
        data_feat= data_feat.type(torch.FloatTensor)
        data_labels = data_labels.long()
        data_domains = data_domains.long()

        print('Final Dataset: ', data_feat.shape, data_labels.shape, data_domains.shape, data_objs.shape)
        return data_feat, data_labels, data_domains, data_objs

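Two small idioms do most of the work in _get_data above: indexing an identity matrix to one-hot encode the labels, and tagging domain membership by incrementing zero-initialized slices. A toy illustration (values made up):

import torch

labels = torch.tensor([0, 1, 1, 0])           # class per sample
onehot = torch.eye(2)[labels]                  # -> [[1,0],[0,1],[0,1],[1,0]]

domain_size = [3, 1]                           # e.g. two source domains
domains = torch.zeros(4)
start = 0
for idx, size in enumerate(domain_size):
    domains[start:start + size] += idx         # samples 0-2 -> domain 0, sample 3 -> domain 1
    start += size
print(onehot.long(), domains.long())
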
@@ -63,6 +63,10 @@ class BaseEval():
            from models.lenet import LeNet5
            phi= LeNet5()

        if self.args.model_name == 'slab':
            from models.slab import SlabClf
            phi= SlabClf(self.args.slab_data_dim, self.args.out_classes)

        if self.args.model_name == 'fc':
            from models.fc import FC
            if self.args.method_name in ['csd', 'matchdg_ctr']:

@@ -111,7 +115,7 @@ class BaseEval():

    def load_model(self, run_matchdg_erm):

-        if self.args.method_name in ['erm_match', 'csd', 'irm']:
+        if self.args.method_name in ['erm_match', 'csd', 'irm', 'perf_match']:
            self.save_path= self.base_res_dir + '/Model_' + self.post_string

        elif self.args.method_name == 'matchdg_ctr':

@@ -0,0 +1,35 @@
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable


class SlabClf(nn.Module):
    def __init__(self, inp_shape, out_shape):

        super(SlabClf, self).__init__()
        self.inp_shape = inp_shape
        self.out_shape = out_shape
        self.hidden_dim = 100
        self.feat_net= nn.Sequential(
            nn.Linear( self.inp_shape, self.hidden_dim),
            nn.ReLU(),
        )

        self.fc= nn.Sequential(
            nn.Linear( self.hidden_dim, self.hidden_dim),
            nn.Linear( self.hidden_dim, self.out_shape),
        )

        self.disc= nn.Sequential(
            nn.Linear( self.hidden_dim, self.hidden_dim),
            nn.Linear( self.hidden_dim, 2),
        )

        self.embedding = nn.Embedding(2, self.hidden_dim)

    def forward(self, x):
        return self.fc(self.feat_net(x))

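As a quick sanity check on the new model: SlabClf maps slab features to class logits, while feat_net exposes the 100-dimensional hidden representation that the matching penalty operates on. A usage sketch (batch size and feature dimension assumed for illustration):

import torch

model = SlabClf(inp_shape=2, out_shape=2)      # 2-D slab features, binary task
x = torch.randn(8, 2)                          # a toy batch
logits = model(x)                              # shape (8, 2)
feat = model.feat_net(x)                       # shape (8, 100), used for matching
print(logits.shape, feat.shape)
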
test.py (+4)
@@ -45,6 +45,10 @@ parser.add_argument('--img_h', type=int, default= 224,
                    help='Height of the image in dataset')
parser.add_argument('--img_w', type=int, default= 224,
                    help='Width of the image in dataset')
parser.add_argument('--slab_data_dim', type=int, default= 2,
                    help='Number of features in the slab dataset')
parser.add_argument('--slab_total_slabs', type=int, default=7)
parser.add_argument('--slab_num_samples', type=int, default=1000)
parser.add_argument('--fc_layer', type=int, default= 1,
                    help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet')
parser.add_argument('--match_layer', type=str, default='logit_match',

train.py (+11)
@@ -42,6 +42,10 @@ parser.add_argument('--img_h', type=int, default= 224,
                    help='Height of the image in dataset')
parser.add_argument('--img_w', type=int, default= 224,
                    help='Width of the image in dataset')
parser.add_argument('--slab_data_dim', type=int, default= 2,
                    help='Number of features in the slab dataset')
parser.add_argument('--slab_total_slabs', type=int, default=7)
parser.add_argument('--slab_num_samples', type=int, default=1000)
parser.add_argument('--fc_layer', type=int, default= 1,
                    help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet')
parser.add_argument('--match_layer', type=str, default='logit_match',

@@ -172,6 +176,13 @@ for run in range(args.n_runs):
                        test_dataset, base_res_dir,
                        run, cuda
                    )
    if args.method_name == 'perf_match':
        from algorithms.perf_match import PerfMatch
        train_method= PerfMatch(
                        args, train_dataset, val_dataset,
                        test_dataset, base_res_dir,
                        run, cuda
                    )
    elif args.method_name == 'matchdg_ctr':
        from algorithms.match_dg import MatchDG
        ctr_phase=1

@@ -14,6 +14,23 @@ from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

# Slab Dataset: Flatten the tensor along batch and domain axis
# Input of the shape (Batch, Domain, Feat)
def slab_batch_process(x, y, d, o):
    if len(x.shape) > 2:
        x= x.flatten(start_dim=0, end_dim=1)

    if len(y.shape) > 1:
        y= y.flatten(start_dim=0, end_dim=1)

    if len(d.shape) > 1:
        d= d.flatten(start_dim=0, end_dim=1)

    if len(o.shape) > 1:
        o= o.flatten(start_dim=0, end_dim=1)

    return x, y, d, o

def t_sne_plot(X):
    X= X.detach().cpu().numpy()
    X= TSNE(n_components=2).fit_transform(X)

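slab_batch_process simply merges the batch and domain axes when the loader yields stacked matched samples. For example (shapes are illustrative, not taken from the commit):

import torch

x = torch.randn(32, 2, 2)      # (batch, domain, feat)
y = torch.zeros(32, 2)         # (batch, domain) labels
d = torch.zeros(32, 2)         # (batch, domain) domain ids
o = torch.zeros(32, 2)         # (batch, domain) match/object ids

x, y, d, o = slab_batch_process(x, y, d, o)
print(x.shape, y.shape)        # torch.Size([64, 2]) torch.Size([64])
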
@@ -214,6 +231,8 @@ def get_dataloader(args, run, domains, data_case, eval_case, kwargs):
        else:
            from data.pacs_loader import PACS

    elif args.dataset_name == 'slab':
        from data.slab_loader import SlabData

    if data_case == 'train':
        match_func=True

@@ -237,7 +256,10 @@ def get_dataloader(args, run, domains, data_case, eval_case, kwargs):
    except AttributeError:
        batch_size= batch_size

-    if args.dataset_name in ['pacs', 'vlcs']:
+    if args.dataset_name == 'slab':
+        data_obj= SlabData(args, domains, '/slab/', data_case=data_case, match_func=match_func, base_size=args.slab_num_samples, freq_ratio=50, data_dim=args.slab_data_dim, total_slabs=args.slab_total_slabs)
+
+    elif args.dataset_name in ['pacs', 'vlcs']:
        data_obj= PACS(args, domains, '/pacs/train_val_splits/', data_case=data_case, match_func=match_func)

    elif args.dataset_name in ['chestxray']:

@ -0,0 +1,202 @@
|
|||
import sys
|
||||
|
||||
import random, os, copy, pickle, time, random, argparse, itertools
|
||||
from collections import defaultdict, Counter, OrderedDict
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import optim, nn
|
||||
import torch.nn.functional as F
|
||||
from sklearn import metrics
|
||||
from torch.utils.data import TensorDataset, DataLoader
|
||||
|
||||
import utils.scripts.gpu_utils as gu
|
||||
import utils.scripts.lms_utils as au
|
||||
import utils.scripts.synth_models as synth_models
|
||||
import utils.scripts.utils as utils
|
||||
import matplotlib.pyplot as plt
|
||||
import pathlib
|
||||
|
||||
try:
|
||||
sys.path.append('../../cifar10_models/')
|
||||
import cifar10_models as c10
|
||||
c10_not_found = False
|
||||
except:
|
||||
c10_not_found = True
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.enabled = True
|
||||
|
||||
REPO_DIR = pathlib.Path(__file__).parent.parent.absolute()
|
||||
DOWNLOAD_DIR = os.path.join(REPO_DIR, 'datasets')
|
||||
|
||||
def msd(x, r=3):
|
||||
return np.round(np.mean(x), r), np.round(np.std(x), r)
|
||||
|
||||
def _get_dataloaders(trd, ted, bs, pm=True, shuffle=True):
|
||||
train_dl = DataLoader(trd, batch_size=bs, shuffle=shuffle, pin_memory=pm)
|
||||
test_dl = DataLoader(ted, batch_size=bs, pin_memory=pm)
|
||||
return train_dl, test_dl
|
||||
|
||||
def get_cifar10_models(device=None, pretrained=True):
|
||||
if c10_not_found: return {}
|
||||
device = gu.get_device(None) if device is None else device
|
||||
get_lmbda = lambda cls: (lambda: cls(pretrained=pretrained).eval().to(device))
|
||||
return {
|
||||
'vgg11_bn': get_lmbda(c10.vgg11_bn),
|
||||
'vgg13_bn': get_lmbda(c10.vgg13_bn),
|
||||
'vgg16_bn': get_lmbda(c10.vgg16_bn),
|
||||
'vgg19_bn': get_lmbda(c10.vgg19_bn),
|
||||
'resnet18': get_lmbda(c10.resnet18),
|
||||
'resnet34': get_lmbda(c10.resnet34),
|
||||
'resnet50': get_lmbda(c10.resnet50),
|
||||
'densenet121': get_lmbda(c10.densenet121),
|
||||
'densenet161': get_lmbda(c10.densenet161),
|
||||
'densenet169': get_lmbda(c10.densenet169),
|
||||
'mobilenet_v2': get_lmbda(c10.mobilenet_v2),
|
||||
'googlenet': get_lmbda(c10.googlenet),
|
||||
'inception_v3': get_lmbda(c10.inception_v3)
|
||||
}
|
||||
|
||||
def plot_decision_boundary(dl, model, c1, c2, ax=None, print_info=True):
|
||||
if ax is None: fig, ax = plt.subplots(1,1,figsize=(6,4))
|
||||
model = model.cpu()
|
||||
deps = sorted(au.get_feature_deps(dl, model).items(), key=lambda t: t[-1])
|
||||
|
||||
if print_info:
|
||||
for k, v in deps: print ('{}:{:.3f}'.format(k,v), end=', ')
|
||||
print ("")
|
||||
|
||||
X, Y = utils.extract_numpy_from_loader(dl)
|
||||
K = 100_000
|
||||
U = np.random.uniform(low=X.min(), high=X.max(), size=(K, X.shape[1])) # copy.deepcopy(X)
|
||||
U[:, c1] = np.random.uniform(low=X[:, c1].min(), high=X[:, c1].max(), size=K)
|
||||
U[:, c2] = np.random.uniform(low=X[:, c2].min(), high=X[:, c2].max(), size=K)
|
||||
U = torch.Tensor(U)
|
||||
|
||||
with torch.no_grad():
|
||||
out = model(U)
|
||||
Yu = torch.argmax(out, 1)
|
||||
|
||||
ax.scatter(U[:,c1], U[:,c2], c=Yu, alpha=0.3, s=24)
|
||||
ax.scatter(X[:,c1], X[:,c2], c=Y, cmap='coolwarm', s=12)
|
||||
|
||||
def get_binary_datasets(X, Y, y1, y2, image_width=28, use_cnn=False):
|
||||
assert type(X) is np.ndarray and type(Y) is np.ndarray
|
||||
idx0 = (Y==y1).nonzero()[0]
|
||||
idx1 = (Y==y2).nonzero()[0]
|
||||
idx = np.concatenate((idx0, idx1))
|
||||
X_, Y_ = X[idx,:], (Y[idx]==y2).astype(int)
|
||||
P = np.random.permutation(len(X_))
|
||||
X_, Y_ = X_[P,:], Y_[P]
|
||||
if use_cnn: X_ = X_.reshape(X.shape[0], -1, image_width)[:, None, :, :]
|
||||
return X_[P,:], Y_[P]
|
||||
|
||||
def get_binary_loader(dl, y1, y2):
|
||||
X, Y = utils.extract_numpy_from_loader(dl)
|
||||
X, Y = get_binary_datasets(X, Y, y1, y2)
|
||||
return utils._to_dl(X, Y, bs=dl.batch_size)
|
||||
|
||||
def get_mnist(fpath=DOWNLOAD_DIR, flatten=False, binarize=False, normalize=True, y0={0,1,2,3,4}):
|
||||
"""get preprocessed mnist torch.TensorDataset class"""
|
||||
def _to_torch(d):
|
||||
X, Y = [], []
|
||||
for xb, yb in d:
|
||||
X.append(xb)
|
||||
Y.append(yb)
|
||||
return torch.Tensor(np.stack(X)), torch.LongTensor(np.stack(Y))
|
||||
|
||||
to_tensor = torchvision.transforms.ToTensor()
|
||||
to_flat = torchvision.transforms.Lambda(lambda X: X.reshape(-1).squeeze())
|
||||
to_norm = torchvision.transforms.Normalize((0.5, ), (0.5, ))
|
||||
to_binary = torchvision.transforms.Lambda(lambda y: 0 if y in y0 else 1)
|
||||
|
||||
transforms = [to_tensor]
|
||||
if normalize: transforms.append(to_norm)
|
||||
if flatten: transforms.append(to_flat)
|
||||
tf = torchvision.transforms.Compose(transforms)
|
||||
ttf = to_binary if binarize else None
|
||||
|
||||
X_tr = torchvision.datasets.MNIST(fpath, download=True, transform=tf, target_transform=ttf)
|
||||
X_te = torchvision.datasets.MNIST(fpath, download=True, train=False, transform=tf, target_transform=ttf)
|
||||
|
||||
return _to_torch(X_tr), _to_torch(X_te)
|
||||
|
||||
def get_mnist_dl(fpath=DOWNLOAD_DIR, to_np=False, bs=128, pm=False, shuffle=False,
|
||||
normalize=True, flatten=False, binarize=False, y0={0,1,2,3,4}):
|
||||
(X_tr, Y_tr), (X_te, Y_te) = get_mnist(fpath, normalize=normalize, flatten=flatten, binarize=binarize, y0=y0)
|
||||
tr_dl = DataLoader(TensorDataset(X_tr, Y_tr), batch_size=bs, shuffle=shuffle, pin_memory=pm)
|
||||
te_dl = DataLoader(TensorDataset(X_te, Y_te), batch_size=bs, pin_memory=pm)
|
||||
return tr_dl, te_dl
|
||||
|
||||
def get_cifar(fpath=DOWNLOAD_DIR, use_cifar10=False, flatten_data=False, transform_type='none',
|
||||
means=None, std=None, use_grayscale=False, binarize=False, normalize=True, y0={0,1,2,3,4}):
|
||||
"""get preprocessed cifar torch.Dataset class"""
|
||||
|
||||
if transform_type == 'none':
|
||||
normalize_cifar = lambda: torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
|
||||
tensorize = torchvision.transforms.ToTensor()
|
||||
to_grayscale = torchvision.transforms.Grayscale()
|
||||
flatten = torchvision.transforms.Lambda(lambda X: X.reshape(-1).squeeze())
|
||||
|
||||
transforms = [tensorize]
|
||||
if use_grayscale: transforms = [to_grayscale] + transforms
|
||||
if normalize: transforms.append(normalize_cifar())
|
||||
if flatten_data: transforms.append(flatten)
|
||||
tr_transforms = te_transforms = torchvision.transforms.Compose(transforms)
|
||||
|
||||
if transform_type == 'basic':
|
||||
normalize_cifar = lambda: torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
|
||||
|
||||
tr_transforms= [
|
||||
torchvision.transforms.RandomCrop(32, padding=4),
|
||||
torchvision.transforms.RandomHorizontalFlip(),
|
||||
torchvision.transforms.ToTensor()
|
||||
]
|
||||
|
||||
te_transforms = [
|
||||
torchvision.transforms.Resize(32),
|
||||
torchvision.transforms.CenterCrop(32),
|
||||
torchvision.transforms.ToTensor(),
|
||||
]
|
||||
|
||||
if normalize:
|
||||
tr_transforms.append(normalize_cifar())
|
||||
te_transforms.append(normalize_cifar())
|
||||
|
||||
tr_transforms = torchvision.transforms.Compose(tr_transforms)
|
||||
te_transforms = torchvision.transforms.Compose(te_transforms)
|
||||
|
||||
to_binary = torchvision.transforms.Lambda(lambda y: 0 if y in y0 else 1)
|
||||
target_transforms = to_binary if binarize else None
|
||||
dset = 'cifar10' if use_cifar10 else 'cifar100'
|
||||
func = torchvision.datasets.CIFAR10 if use_cifar10 else torchvision.datasets.CIFAR100
|
||||
|
||||
X_tr = func(fpath, download=True, transform=tr_transforms, target_transform=target_transforms)
|
||||
X_te = func(fpath, download=True, train=False, transform=te_transforms, target_transform=target_transforms)
|
||||
|
||||
return X_tr, X_te
|
||||
|
||||
def get_cifar_dl(fpath=DOWNLOAD_DIR, use_cifar10=False, bs=128, shuffle=True, transform_type='none',
|
||||
means=None, std=None, normalize=True, flatten_data=False, use_grayscale=False, nw=4, pm=False, binarize=False, y0={0,1,2,3,4}):
|
||||
"""data in dataloaders have has shape (B, C, W, H)"""
|
||||
d_tr, d_te = get_cifar(fpath, use_cifar10=use_cifar10, use_grayscale=use_grayscale, transform_type=transform_type, normalize=normalize, means=means, std=std, flatten_data=flatten_data, binarize=binarize, y0=y0)
|
||||
tr_dl = DataLoader(d_tr, batch_size=bs, shuffle=shuffle, num_workers=nw, pin_memory=pm)
|
||||
te_dl = DataLoader(d_te, batch_size=bs, num_workers=nw, pin_memory=pm)
|
||||
return tr_dl, te_dl
|
||||
|
||||
def get_cifar_np(fpath=DOWNLOAD_DIR, use_cifar10=False, flatten_data=False, transform_type='none', normalize=True, binarize=False, y0={0,1,2,3,4}, use_grayscale=False):
|
||||
"""get numpy matrices of preprocessed cifar data"""
|
||||
|
||||
def _to_np(d):
|
||||
X, Y = [], []
|
||||
for xb, yb in d:
|
||||
X.append(xb)
|
||||
Y.append(yb)
|
||||
return map(np.stack, [X,Y])
|
||||
|
||||
d_tr, d_te = get_cifar(fpath, use_cifar10=use_cifar10, use_grayscale=use_grayscale, transform_type=transform_type, normalize=normalize, flatten_data=flatten_data, binarize=binarize, y0=y0)
|
||||
return _to_np(d_tr), _to_np(d_te)
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
|
@ -0,0 +1,130 @@
|
|||
import os, copy, pickle, time
|
||||
import random, itertools
|
||||
from collections import defaultdict, Counter, OrderedDict
|
||||
import numpy as np
|
||||
import torch
|
||||
import pandas as pd
|
||||
import torchvision
|
||||
from torch.utils.data import TensorDataset, DataLoader
|
||||
from torch import optim, nn
|
||||
import torch.nn.functional as F
|
||||
import dill
|
||||
import gpu_utils as gu
|
||||
import data_utils as du
|
||||
import synth_models as sm
|
||||
import utils
|
||||
|
||||
class Ensemble(nn.Module):
|
||||
|
||||
def _get_dummy_classifier(self):
|
||||
def dummy(x):
|
||||
return x
|
||||
return dummy
|
||||
|
||||
def __init__(self, models, num_classes, use_softmax=False):
|
||||
super(Ensemble, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.use_softmax = use_softmax
|
||||
|
||||
# register models as pytorch modules
|
||||
self.models = []
|
||||
for idx, m in enumerate(models,1):
|
||||
setattr(self, 'm{}'.format(idx), m.eval())
|
||||
self.models.append(getattr(self, 'm{}'.format(idx)))
|
||||
|
||||
self.classifier = self._get_dummy_classifier()
|
||||
|
||||
def _forward(self, x):
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
outs = self._forward(x)
|
||||
return self.classifier(outs)
|
||||
|
||||
def get_output_loader(self, dl, device=gu.get_device(None), bs=None):
|
||||
"""return dataloader of model output (logit or softmax prob)"""
|
||||
X, Y = [], []
|
||||
with torch.no_grad():
|
||||
for xb, yb in dl:
|
||||
xb = xb.to(device)
|
||||
out = self._forward(xb).cpu()
|
||||
X.append(out)
|
||||
Y.append(yb)
|
||||
X, Y = torch.cat(X), torch.cat(Y)
|
||||
return DataLoader(TensorDataset(X, Y), batch_size=bs or dl.batch_size)
|
||||
|
||||
def fit_classifier(self, tr_dl, te_dl, lr=0.05, adam=False, wd=5e-5, device=None, **fit_kw):
|
||||
device = gu.get_device(None) if device is None else device
|
||||
self = self.to(device)
|
||||
|
||||
c = dict(gap=1000, epsilon=1e-2, wd=5e-5, is_loss_epsilon=True)
|
||||
c.update(**fit_kw)
|
||||
|
||||
tro_dl = self.get_output_loader(tr_dl, device)
|
||||
teo_dl = self.get_output_loader(te_dl, device)
|
||||
|
||||
if adam: opt = optim.Adam(self.classifier.parameters())
|
||||
else: opt = optim.SGD(self.classifier.parameters(), lr=lr, weight_decay=wd)
|
||||
stats = utils.fit_model(self.classifier, F.cross_entropy, opt, tro_dl, teo_dl, device=device, **c)
|
||||
|
||||
self.classifier = stats['best_model'][-1].to(device)
|
||||
self = self.cpu()
|
||||
return stats
|
||||
|
||||
class EnsembleLinear(Ensemble):
|
||||
|
||||
def _get_classifier(self):
|
||||
# linear with equal weights and zero bias
|
||||
nl = self.num_classes*len(self.models)
|
||||
linear = nn.Linear(nl, self.num_classes, bias=self.use_bias)
|
||||
nn.init.ones_(linear.weight.data)
|
||||
linear.weight.data /= float(nl)
|
||||
if self.use_bias: linear.bias.data.zero_()
|
||||
return linear
|
||||
|
||||
def __init__(self, models, num_classes=2, use_softmax=False, use_bias=True):
|
||||
super(EnsembleLinear, self).__init__(models, num_classes, use_softmax)
|
||||
self.use_bias = use_bias
|
||||
self.classifier = self._get_classifier()
|
||||
|
||||
def _forward(self, x):
|
||||
outs = [m(x) for m in self.models]
|
||||
if self.use_softmax: outs = [F.softmax(o, dim=1) for o in outs]
|
||||
outs = torch.stack(outs, dim=2)
|
||||
outs = outs.reshape(outs.shape[0], -1)
|
||||
return outs
|
||||
|
||||
class EnsembleMLP(Ensemble):
|
||||
|
||||
def _get_classifier(self):
|
||||
nl = self.num_classes*len(self.models)
|
||||
fcn = sm.get_fcn(nl, self.hdim or nl, self.num_classes, hl=self.hl)
|
||||
return fcn
|
||||
|
||||
def __init__(self, models, num_classes=2, use_softmax=False, hdim=None, hl=1):
|
||||
super(EnsembleMLP, self).__init__(models, num_classes, use_softmax)
|
||||
self.hdim = hdim
|
||||
self.hl = hl
|
||||
self.classifier = self._get_classifier()
|
||||
|
||||
def _forward(self, x):
|
||||
outs = [m(x) for m in self.models]
|
||||
if self.use_softmax: outs = [F.softmax(o, dim=1) for o in outs]
|
||||
outs = torch.stack(outs, dim=2)
|
||||
outs = outs.reshape(outs.shape[0], -1)
|
||||
return outs
|
||||
|
||||
class EnsembleAverage(Ensemble):
|
||||
|
||||
def __init__(self, models, num_classes=2, use_softmax=False):
|
||||
super(EnsembleAverage, self).__init__(models, num_classes, use_softmax)
|
||||
self.classifier = self._get_dummy_classifier()
|
||||
|
||||
def _forward(self, x):
|
||||
outs = [m(x) for m in self.models]
|
||||
if self.use_softmax: outs = [F.softmax(o, dim=1) for o in outs]
|
||||
outs = torch.stack(outs)
|
||||
return outs.mean(dim=0)
|
||||
|
||||
def fit_classifier(self, *args, **kwargs):
|
||||
return None
|
|
@ -0,0 +1,316 @@
|
|||
import numpy as np
|
||||
import scipy.stats as scs
|
||||
import random
|
||||
from collections import Counter
|
||||
import torch
|
||||
from torch.utils.data import TensorDataset, DataLoader
|
||||
import utils.scripts.utils as utils
|
||||
import utils.scripts.gpu_utils as gu
|
||||
|
||||
def _prep_data(X, Y, N_tr, N_te, bs, nw, pm, w, orth_matrix=None):
|
||||
X_te, Y_te = torch.Tensor(X[:N_te,:]), torch.Tensor(Y[:N_te])
|
||||
X_tr, Y_tr = torch.Tensor(X[N_te:,:]), torch.Tensor(Y[N_te:])
|
||||
Y_te, Y_tr = map(lambda Z: Z.long(), [Y_te, Y_tr])
|
||||
|
||||
tr_dl = DataLoader(TensorDataset(X_tr, Y_tr), batch_size=bs, num_workers=nw, pin_memory=pm, shuffle=True)
|
||||
te_dl = DataLoader(TensorDataset(X_te, Y_te), batch_size=bs, num_workers=nw, pin_memory=pm, shuffle=False)
|
||||
|
||||
return {
|
||||
'X': torch.tensor(X).float(),
|
||||
'Y': torch.tensor(Y).long(),
|
||||
'w': w,
|
||||
'tr_dl': tr_dl,
|
||||
'te_dl': te_dl,
|
||||
'N': (N_tr, N_te),
|
||||
'W': orth_matrix
|
||||
}
|
||||
|
||||
def _get_random_data(N, dim, scale):
|
||||
X = np.random.uniform(size=(N, dim))
|
||||
X *= scale
|
||||
Y = np.random.choice([0,1], size=N)
|
||||
return X, Y
|
||||
|
||||
def generate_linsep_data_v2(N_tr, dim, eff_margin, width=10., bs=256, scale_noise=True, pm=True, nw=0, no_width=False, N_te=5000): # no unif_max.
|
||||
assert eff_margin < 1, "equal range constraint"
|
||||
margin = eff_margin if no_width else eff_margin*width
|
||||
|
||||
N = N_tr + N_te
|
||||
w = np.zeros(shape=dim)
|
||||
w[0] = 1
|
||||
|
||||
X, Y = _get_random_data(N, dim, width if scale_noise else 1.)
|
||||
|
||||
U = np.random.uniform(size=N)
|
||||
if no_width: X[:,0] = (2*Y-1)*margin
|
||||
else: X[:, 0] = (2*Y-1)*(margin + (width-margin)*U)
|
||||
|
||||
P = np.random.permutation(X.shape[0])
|
||||
X, Y = X[P,:], Y[P]
|
||||
|
||||
return _prep_data(X, Y, N_tr, N_te, bs, nw, pm, w)
|
||||
|
||||
def sample_from_unif_union_of_unifs(unifs, size):
|
||||
x = []
|
||||
choices = Counter(np.random.choice(list(range(len(unifs))), size=size))
|
||||
for choice, sz in choices.items():
|
||||
s = np.random.uniform(low=unifs[choice][0], high=unifs[choice][1], size=sz)
|
||||
x.append(s)
|
||||
x = np.concatenate(x)
|
||||
return x
|
||||
|
||||
def generate_ub_linslab_data_diffmargin_v2(N_tr, dim, eff_lin_margins, eff_slab_margins,
|
||||
slabs_per_coord, slab_p_vals, corrupt_lin=0., corrupt_slab=0.,
|
||||
corrupt_slab7=0., scale_noise=True, width=10., lin_coord=0, lin_shift=0.,
|
||||
slab_shift=0., indep_slabs=True, bs=256, pm=True, nw=0, N_te=10000,
|
||||
random_transform=False, corrupt_lin_margin=False, corrupt_5slab_margin=False):
|
||||
get_unif = lambda a: np.random.uniform(size=a)
|
||||
get_bool = lambda a: np.random.choice([0,1], size=a)
|
||||
get_sign = lambda a: 2*get_bool(a)-1.
|
||||
|
||||
def get_slab_width(NS, B, SM):
|
||||
if NS==3: return (2.*B-4.*SM)/3.
|
||||
if NS==5: return (2.*B-8.*SM)/5.
|
||||
if NS==7: return (2.*B-12.*SM)/7.
|
||||
return None
|
||||
|
||||
num_lin, num_slabs = map(len, [eff_lin_margins, eff_slab_margins])
|
||||
assert 0 <= corrupt_lin <= 1, "input is probability"
|
||||
assert num_lin + num_slabs <= dim, "dim constraint, num_lin: {}, num_slabs: {}, dim: {}".format(num_lin, num_slabs, dim)
|
||||
for elm in eff_lin_margins: assert 0 < elm < 1, "equal range constraint (0 < eff_lin_margin={} < 1)".format(elm)
|
||||
for esm in eff_slab_margins: assert 0 < esm < 1, "equal range constraint (0 < eff_slab_margin={} < 0.25)".format(esm)
|
||||
|
||||
lin_margins = list(map(lambda x: x*width, eff_lin_margins))
|
||||
slab_margins = list(map(lambda x: x*width, eff_slab_margins))
|
||||
|
||||
# hyperplane
|
||||
N = N_tr + N_te
|
||||
half_N = N//2
|
||||
w = np.zeros(shape=dim); w[0] = 1
|
||||
|
||||
X, Y = _get_random_data(N, dim, width if scale_noise else 1.)
|
||||
nrange = list(range(N))
|
||||
# linear
|
||||
total_corrupt = int(round(N*corrupt_lin))
|
||||
no_linear = num_lin == 0
|
||||
if not no_linear:
|
||||
for coord, lin_margin in enumerate(lin_margins):
|
||||
if indep_slabs:
|
||||
P = np.random.permutation(N)
|
||||
X, Y = X[P, :], Y[P]
|
||||
X[:, coord] = (2*Y-1)*(lin_margin+(width-lin_margin)*get_unif(N)) + lin_shift*width
|
||||
|
||||
# corrupt linear coordinate
|
||||
if total_corrupt > 0:
|
||||
corrupt_sample = np.random.choice(nrange, size=total_corrupt, replace=False)
|
||||
if corrupt_lin_margin:
|
||||
X[corrupt_sample, 0] = np.random.uniform(low=-lin_margin, high=lin_margin, size=total_corrupt)
|
||||
else:
|
||||
X[corrupt_sample, 0] *= -1
|
||||
|
||||
# slabs
|
||||
i = (num_lin)*int(not no_linear)
|
||||
for idx, coord in enumerate(range(i, i+num_slabs)):
|
||||
slab_per = slabs_per_coord[idx]
|
||||
assert slab_per in [3, 5, 7], "Invalid slabs_per_coord"
|
||||
|
||||
slab_pval = slab_p_vals[idx]
|
||||
slab_margin = slab_margins[idx]
|
||||
slab_width = get_slab_width(slab_per, width, slab_margin)
|
||||
|
||||
if indep_slabs:
|
||||
P = np.random.permutation(N)
|
||||
X, Y = X[P, :], Y[P]
|
||||
|
||||
if slab_per == 3:
|
||||
# positive slabs
|
||||
idx_p = (Y==1).nonzero()[0]
|
||||
offset = 0.5*slab_width + 2*slab_margin
|
||||
X[idx_p, coord] = get_sign(len(idx_p))*(offset+slab_width*get_unif(len(idx_p)))
|
||||
|
||||
# negative center
|
||||
idx_n = (Y==0).nonzero()[0]
|
||||
X[idx_n, coord] = 0.5*get_sign(len(idx_n))*slab_width*get_unif(len(idx_n))
|
||||
|
||||
if slab_per == 5:
|
||||
# positive slabs
|
||||
idx_p = (Y==1).nonzero()[0]
|
||||
offset = (width+6*slab_margin)/5.
|
||||
X[idx_p, coord] = get_sign(len(idx_p))*(offset+slab_width*get_unif(len(idx_p)))
|
||||
|
||||
# negative slabs partitioned using p val
|
||||
idx_n = (Y==0).nonzero()[0]
|
||||
in_ctr = np.random.choice([0,1], p=[1-slab_pval, slab_pval], size=len(idx_n))
|
||||
idx_nc, idx_ns = idx_n[(in_ctr==1)], idx_n[(in_ctr==0)]
|
||||
|
||||
# negative center
|
||||
X[idx_nc, coord] = 0.5*get_sign(len(idx_nc))*slab_width*get_unif(len(idx_nc))
|
||||
|
||||
# negative sides
|
||||
offset = (8*slab_margin+3*width)/5.
|
||||
X[idx_ns, coord] = get_sign(len(idx_ns))*(offset+slab_width*get_unif(len(idx_ns)))
|
||||
|
||||
# corrupt slab 5
|
||||
total_corrupt = int(round(N*corrupt_slab))
|
||||
if total_corrupt > 0:
|
||||
if corrupt_5slab_margin:
|
||||
offset1 = (width+6*slab_margin)/5.
|
||||
offset2 = (8*slab_margin+3*width)/5.
|
||||
unifs = [
|
||||
(0.5*slab_width, offset1),
|
||||
(offset1+slab_width, offset2),
|
||||
(-offset1, -0.5*slab_width),
|
||||
(-offset2, -(offset1+slab_width))
|
||||
]
|
||||
|
||||
idx = np.random.choice(range(N), size=total_corrupt, replace=False)
|
||||
X[idx, coord] = sample_from_unif_union_of_unifs(unifs, total_corrupt)
|
||||
else:
|
||||
# get corrupt sample
|
||||
idx = np.random.choice(range(N), size=total_corrupt, replace=False)
|
||||
idx_p = idx[np.argwhere((Y[idx]==1))].reshape(-1)
|
||||
idx_n = idx[np.argwhere((Y[idx]==0))].reshape(-1)
|
||||
|
||||
# move negative points to random positive slabs
|
||||
offset = (0.5*slab_width+2*slab_margin)
|
||||
X[idx_n, coord] = torch.Tensor(get_sign(len(idx_n))*(offset+slab_width*get_unif(len(idx_n))))
|
||||
|
||||
# pick negative slab for each positve point
|
||||
mv_to_ctr = np.random.choice([0, 1], size=len(idx_p))
|
||||
idx_p_ctr = idx_p[np.argwhere(mv_to_ctr==1)].reshape(-1)
|
||||
idx_p_sid = idx_p[np.argwhere(mv_to_ctr==0)].reshape(-1)
|
||||
|
||||
# move positive points to negative slabs
|
||||
X[idx_p_ctr, coord] = torch.Tensor(0.5*get_sign(len(idx_p_ctr))*slab_width*get_unif(len(idx_p_ctr)))
|
||||
|
||||
# move negative points to positve slabs
|
||||
offset = 1.5*slab_width + 4*slab_margin
|
||||
X[idx_p_sid, coord] = torch.Tensor(get_sign(len(idx_p_sid))*(offset+slab_width*get_unif(len(idx_p_sid))))
|
||||
|
||||
if slab_per == 7:
|
||||
# positive slabs
|
||||
idx_p = (Y==1).nonzero()[0]
|
||||
in_s0 = np.random.choice([0,1], p=[1-slab_pval, slab_pval], size=len(idx_p))
|
||||
idx_p0, idx_p1 = idx_p[(in_s0==1)], idx_p[(in_s0==0)]
|
||||
|
||||
# positive slab 0 (inner)
|
||||
offset = 0.5*slab_width+2*slab_margin
|
||||
X[idx_p0, coord] = get_sign(len(idx_p0))*(offset+slab_width*get_unif(len(idx_p0)))
|
||||
|
||||
# positive slab 1 (outer)
|
||||
offset = 2.5*slab_width+6*slab_margin
|
||||
X[idx_p1, coord] = get_sign(len(idx_p1))*(offset+slab_width*get_unif(len(idx_p1)))
|
||||
|
||||
# negative slabs
|
||||
idx_n = (Y==0).nonzero()[0]
|
||||
in_s0 = get_bool(len(idx_n))
|
||||
idx_n0, idx_n1 = idx_n[(in_s0==1)], idx_n[(in_s0==0)]
|
||||
|
||||
# negative slab 0 (center)
|
||||
X[idx_n0, coord] = 0.5*get_sign(len(idx_n0))*slab_width*get_unif(len(idx_n0))
|
||||
|
||||
# negative slab 1 (outer)
|
||||
offset = 1.5*slab_width+4*slab_margin
|
||||
X[idx_n1, coord] = get_sign(len(idx_n1))*(offset+slab_width*get_unif(len(idx_n1)))
|
||||
|
||||
# corrupt slab7
|
||||
total_corrupt = int(round(N*corrupt_slab7))
|
||||
if total_corrupt > 0:
|
||||
# corrupt data
|
||||
idx = np.random.choice(range(len(X)), size=total_corrupt, replace=False)
|
||||
idx_p = idx[np.argwhere((Y[idx]==1))].reshape(-1)
|
||||
idx_n = idx[np.argwhere((Y[idx]==0))].reshape(-1)
|
||||
|
||||
# pick positive slab for each negative slab
|
||||
mv_to_inner = get_bool(len(idx_n))
|
||||
idx_n_inner = idx_n[np.argwhere(mv_to_inner==1)].reshape(-1)
|
||||
idx_n_outer = idx_n[np.argwhere(mv_to_inner==0)].reshape(-1)
|
||||
|
||||
# move to idx_n_inner and outer
|
||||
offset = 0.5*slab_width+2*slab_margin
|
||||
X[idx_n_inner, coord] = torch.Tensor(get_sign(len(idx_n_inner))*(offset+slab_width*get_unif(len(idx_n_inner))))
|
||||
offset = 2.5*slab_width+6*slab_margin
|
||||
X[idx_n_outer, coord] = torch.Tensor(get_sign(len(idx_n_outer))*(offset+slab_width*get_unif(len(idx_n_outer))))
|
||||
|
||||
# pick negative slab for each positive point
|
||||
mv_to_ctr = get_bool(len(idx_p))
|
||||
idx_p_ctr = idx_p[np.argwhere(mv_to_ctr==1)].reshape(-1)
|
||||
idx_p_sid = idx_p[np.argwhere(mv_to_ctr==0)].reshape(-1)
|
||||
|
||||
# move to idx_n_inner and outer
|
||||
X[idx_p_ctr, coord] = torch.Tensor(0.5*get_sign(len(idx_p_ctr))*(slab_width*get_unif(len(idx_p_ctr))))
|
||||
offset = 1.5*slab_width+4*slab_margin
|
||||
X[idx_p_sid, coord] = torch.Tensor(get_sign(len(idx_p_sid))*(offset+slab_width*get_unif(len(idx_p_sid))))
|
||||
|
||||
# shift
|
||||
X[:, coord] += slab_shift*width
|
||||
|
||||
# reshuffle
|
||||
P = np.random.permutation(N)
|
||||
X, Y = X[P,:], Y[P]
|
||||
|
||||
# lin coord position
|
||||
if not random_transform and lin_coord != 0:
|
||||
X[:, [0, lin_coord]] = X[:, [lin_coord, 0]]
|
||||
|
||||
# transform
|
||||
W = np.eye(dim)
|
||||
if random_transform: W = utils.get_orthonormal_matrix(dim)
|
||||
X = X.dot(W)
|
||||
|
||||
return _prep_data(X, Y, N_tr, N_te, bs, nw, pm, w, orth_matrix=W)
|
||||
|
||||
|
||||
def generate_ub_linslab_data_v2(N_tr, dim, eff_lin_margin, eff_slab_margin=None, lin_coord=0,
|
||||
corrupt_lin=0., corrupt_slab=0., corrupt_slab3=0., corrupt_slab7=0.,
|
||||
scale_noise=True, num_lin=1, lin_shift=0., slab_shift=0., random_transform=False,
|
||||
num_slabs=1, slabs_per_coord=5, width=10., indep_slabs=True, no_linear=False,
|
||||
bs=256, pm=True, nw=0, N_te=10000, corrupt_lin_margin=False, slab5_pval=3/4.,
|
||||
slab3_pval=1/2., slab7_pval=7/8., corrupt_5slab_margin=False):
|
||||
slab_p_map = {5: slab5_pval, 7: slab7_pval, 3: slab3_pval}
|
||||
slabs_per_coord = [slabs_per_coord]*num_slabs if type(slabs_per_coord) is int else slabs_per_coord[:]
|
||||
for x in slabs_per_coord: assert x in slab_p_map
|
||||
slab_p_vals = [slab_p_map[x] for x in slabs_per_coord]
|
||||
lms = [eff_lin_margin]*num_lin
|
||||
sms = eff_slab_margin if type(eff_slab_margin) is list else [eff_slab_margin]*num_slabs
|
||||
return generate_ub_linslab_data_diffmargin_v2(N_tr, dim, lms, sms, slabs_per_coord, slab_p_vals, lin_coord=lin_coord, corrupt_slab=corrupt_slab,
|
||||
corrupt_slab7=corrupt_slab7, corrupt_lin=corrupt_lin, scale_noise=scale_noise, width=width,
|
||||
lin_shift=lin_shift, slab_shift=slab_shift, random_transform=random_transform, indep_slabs=indep_slabs,
|
||||
pm=pm, bs=bs, corrupt_lin_margin=corrupt_lin_margin, nw=nw, N_te=N_te, corrupt_5slab_margin=corrupt_5slab_margin)
|
||||
|
||||
|
||||
def get_lms_data(**kw):
|
||||
|
||||
c = config = {
|
||||
'num_train': 100_000,
|
||||
'dim': 20,
|
||||
'lin_margin': 0.1,
|
||||
'slab_margin': 0.1,
|
||||
'same_margin': False,
|
||||
'random_transform': False,
|
||||
'width': 1, # data width
|
||||
'bs': 256,
|
||||
'corrupt_lin': 0.0,
|
||||
'corrupt_lin_margin': False,
|
||||
'corrupt_slab': 0.0,
|
||||
'num_test': 2_000,
|
||||
'hdim': 200, # model width
|
||||
'hl': 2, # model depth
|
||||
'device': gu.get_device(0),
|
||||
'input_dropout': 0,
|
||||
'num_lin': 1,
|
||||
'num_slabs': 19,
|
||||
'num_slabs7': 0,
|
||||
'num_slabs3': 0,
|
||||
}
|
||||
|
||||
c.update(kw)
|
||||
|
||||
smargin = c['lin_margin'] if c['same_margin'] else c['slab_margin']
|
||||
data_func = generate_ub_linslab_data_v2
|
||||
spc = [3]*c['num_slabs3']+[5]*c['num_slabs'] + [7]*c['num_slabs7']
|
||||
data = data_func(c['num_train'], c['dim'], c['lin_margin'], slabs_per_coord=spc, eff_slab_margin=smargin, random_transform=c['random_transform'], N_te=c['num_test'],
|
||||
corrupt_lin_margin=c['corrupt_lin_margin'], num_lin=c['num_lin'], num_slabs=c['num_slabs3']+c['num_slabs']+c['num_slabs7'], width=c['width'], bs=c['bs'],
|
||||
corrupt_lin=c['corrupt_lin'], corrupt_slab=c['corrupt_slab'])
|
||||
return data, c
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
try: import pycuda.driver as cuda
|
||||
except: print ("pycuda not available")
|
||||
|
||||
import torch
|
||||
import sys, os, glob, subprocess
|
||||
|
||||
def get_gpu_info(print_info=True, get_specs=False):
|
||||
cuda.init()
|
||||
if get_specs: gpu_specs = cuda.Device(0).get_attributes() # assume same for all (dnnx)
|
||||
else: gpu_specs = None
|
||||
|
||||
gpu_info = {
|
||||
'available': torch.cuda.is_available(),
|
||||
'num_devices': torch.cuda.device_count(),
|
||||
'devices': set([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]),
|
||||
'current device id': torch.cuda.current_device(),
|
||||
'allocated memory': torch.cuda.memory_allocated(),
|
||||
'cached memory': torch.cuda.memory_cached()
|
||||
}
|
||||
|
||||
if print_info:
|
||||
for k,v in gpu_info.items(): print ("{}: {}".format(k, v))
|
||||
|
||||
return gpu_info, gpu_specs
|
||||
|
||||
|
||||
def get_device(device_id=None): # None -> cpu
|
||||
device = 'cuda:{}'.format(device_id) if device_id is not None else 'cpu'
|
||||
device = torch.device(device if torch.cuda.is_available() and device_id is not None else 'cpu')
|
||||
return device
|
||||
|
||||
def get_gpu_name():
|
||||
try:
|
||||
out_str = subprocess.run(["nvidia-smi", "--query-gpu=gpu_name", "--format=csv"], stdout=subprocess.PIPE).stdout
|
||||
out_list = out_str.decode("utf-8").split('\n')
|
||||
out_list = out_list[1:-1]
|
||||
return out_list
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
def get_cuda_version():
|
||||
"""Get CUDA version"""
|
||||
if sys.platform == 'win32':
|
||||
raise NotImplementedError("Implement this!")
|
||||
# This breaks on linux:
|
||||
#cuda=!ls "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA"
|
||||
#path = "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\" + str(cuda[0]) +"\\version.txt"
|
||||
elif sys.platform == 'linux' or sys.platform == 'darwin':
|
||||
path = '/usr/local/cuda/version.txt'
|
||||
else:
|
||||
raise ValueError("Not in Windows, Linux or Mac")
|
||||
if os.path.isfile(path):
|
||||
with open(path, 'r') as f:
|
||||
data = f.read().replace('\n','')
|
||||
return data
|
||||
else:
|
||||
return "No CUDA in this machine"
|
||||
|
||||
def get_cudnn_version():
|
||||
"""Get CUDNN version"""
|
||||
if sys.platform == 'win32':
|
||||
raise NotImplementedError("Implement this!")
|
||||
# This breaks on linux:
|
||||
#cuda=!ls "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA"
|
||||
#candidates = ["C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\" + str(cuda[0]) +"\\include\\cudnn.h"]
|
||||
elif sys.platform == 'linux':
|
||||
candidates = ['/usr/include/x86_64-linux-gnu/cudnn_v[0-99].h',
|
||||
'/usr/local/cuda/include/cudnn.h',
|
||||
'/usr/include/cudnn.h']
|
||||
elif sys.platform == 'darwin':
|
||||
candidates = ['/usr/local/cuda/include/cudnn.h',
|
||||
'/usr/include/cudnn.h']
|
||||
else:
|
||||
raise ValueError("Not in Windows, Linux or Mac")
|
||||
for c in candidates:
|
||||
file = glob.glob(c)
|
||||
if file: break
|
||||
if file:
|
||||
with open(file[0], 'r') as f:
|
||||
version = ''
|
||||
for line in f:
|
||||
if "#define CUDNN_MAJOR" in line:
|
||||
version = line.split()[-1]
|
||||
if "#define CUDNN_MINOR" in line:
|
||||
version += '.' + line.split()[-1]
|
||||
if "#define CUDNN_PATCHLEVEL" in line:
|
||||
version += '.' + line.split()[-1]
|
||||
if version:
|
||||
return version
|
||||
else:
|
||||
return "Cannot find CUDNN version"
|
||||
else:
|
||||
return "No CUDNN in this machine"
|
||||
|
||||
if __name__=='__main__':
|
||||
print ('gpu name', get_gpu_name())
|
||||
print ('cuda', get_cuda_version())
|
||||
print ('cudnn', get_cudnn_version())
|
||||
print ('device0', get_device(0))
|
||||
print ('available', torch.cuda.is_available())
|
|
@ -0,0 +1,307 @@
|
|||
import seaborn as sns
|
||||
import utils.scripts.gpu_utils as gu
|
||||
import utils.scripts.data_utils as du
|
||||
import utils
|
||||
import random
|
||||
import os, copy, pickle, time
|
||||
import itertools
|
||||
from collections import defaultdict, Counter, OrderedDict
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import pandas as pd
|
||||
#import foolbox
|
||||
from torch.utils.data import TensorDataset, DataLoader
|
||||
from torch import optim, nn
|
||||
import torch.nn.functional as F
|
||||
from sklearn.metrics import roc_auc_score
|
||||
|
||||
def parse_data(exps=None, root='/', **funcs_kw):
|
||||
"""
|
||||
main function (parse data files and run added functions on it)
|
||||
"""
|
||||
exps = exps if exps is not None else os.listdir(root)
|
||||
total = len(exps)
|
||||
print ("total: {}".format(total))
|
||||
parsed = defaultdict(dict)
|
||||
if total == 0: return parsed
|
||||
for idx, exp in enumerate(exps):
|
||||
if (idx+1) % 1 == 0: print (idx+1, end=' ', flush=True)
|
||||
# load data
|
||||
fpath = os.path.join(root, exp)
|
||||
try:
|
||||
data = torch.load(fpath, map_location=lambda storage, loc: storage)
|
||||
except:
|
||||
print ("File {} corrupted, skip.".format(fpath))
|
||||
continue
|
||||
config = data['config']
|
||||
|
||||
# make exp config key
|
||||
config['run'] = int(exp.rsplit('.', 1)[0][-1])
|
||||
config['fname'] = exp
|
||||
ckeys = ['exp_name', 'dim', 'num_train', 'lin_margin', 'slab_margin', 'num_slabs',
|
||||
'num_slabs', 'width', 'hdim', 'hl', 'linear', 'use_bn', 'run', 'fname',
|
||||
'weight_decay', 'dropout']
|
||||
ckeys = [c for c in ckeys if c in config]
|
||||
cvals = [config[k] for k in ckeys]
|
||||
ckv = tuple(zip(ckeys, cvals))
|
||||
|
||||
# save config
|
||||
parsed[ckv]['config'] = config
|
||||
|
||||
# save functions
|
||||
for func_name, func in funcs_kw.items():
|
||||
parsed[ckv][func_name] = func(data)
|
||||
|
||||
return parsed
|
||||
|
||||
def parse_exp_stats(data):
|
||||
"""training summary statistics"""
|
||||
stats = data['stats']
|
||||
s = {}
|
||||
|
||||
# loss + accuracy
|
||||
for t1, t2 in itertools.product(['acc', 'loss'], ['tr', 'te']):
|
||||
s['{}_{}'.format(t1,t2)] = stats['{}_{}'.format(t1, t2)][-1]
|
||||
s['orig_stats'] = stats
|
||||
s['acc_gap'] = s['acc_tr']-s['acc_te']
|
||||
s['loss_gap'] = s['loss_te']-s['loss_tr']
|
||||
s['fin_acc_te'] = stats['acc_te'][-1]
|
||||
s['fin_acc_tr'] = stats['acc_tr'][-1]
|
||||
|
||||
# updates
|
||||
s['update_gap'] = stats['update_gap']
|
||||
s['num_updates'] = stats['num_updates']
|
||||
|
||||
# effective number of updates
|
||||
for acc_threshold in [0.96, 0.97, 0.98, 0.99, 1]:
|
||||
eff = np.argmin(np.abs(np.array(stats['acc_tr'])-acc_threshold))*s['update_gap']
|
||||
s['effective_num_updates{}'.format(int(acc_threshold*100))] = eff
|
||||
|
||||
return s
|
||||
|
||||
def parse_exp_model(data):
|
||||
"""model parameter stats"""
|
||||
depth = data['config']['hl']
|
||||
linear = data['config']['linear']
|
||||
mtype = data['config'].get('mtype', 'fcn')
|
||||
if mtype == 'fcn' and depth == 1 and not linear: d = parse_exp_depth1_model(data)
|
||||
if mtype == 'fcn' and depth == 1 and linear: d = parse_exp_linear_model(data)
|
||||
return {}
|
||||
|
||||
def parse_exp_depth1_model(data):
|
||||
"""cosine + w2"""
|
||||
device = gu.get_device()
|
||||
model = data['model'].to(device)
|
||||
p = W1, b1, w2, b2 = list(map(lambda x: x.detach().numpy(), model.parameters()))
|
||||
s = {}
|
||||
s['params'] = p
|
||||
s['cosine'] = W1[:, 0]/np.linalg.norm(W1, axis=1)
|
||||
s['l2'] = np.linalg.norm(W1, axis=1)
|
||||
s['w2'] = w2
|
||||
s['corr0'] = np.corrcoef(s['cosine'], w2[0, :])[0,1]
|
||||
s['corr1'] = np.corrcoef(s['cosine'], w2[1, :])[0,1]
|
||||
s['max_weight_cosine'] = s['cosine'][np.argmax(s['w2'][1,:])]
|
||||
return s
|
||||
|
||||
def parse_exp_linear_model(data):
|
||||
"""cosine"""
|
||||
device = gu.get_device()
|
||||
model = data['model'].to(device)
|
||||
p = W,b = list(map(lambda x: x.detach().numpy(), model.parameters()))
|
||||
s = {}
|
||||
s['cosine0'], s['cosine1'] = W[:, 0]/np.linalg.norm(W, axis=1)
|
||||
return s
|
||||
|
||||
def parse_exp_data(data, load_X=False):
|
||||
s = {}
|
||||
model = data['model'].to(gu.get_device())
|
||||
data = data['data']
|
||||
X, Y = data['X'], data['Y']
|
||||
|
||||
if type(X) != np.ndarray:
|
||||
X = data['X'].detach().cpu()
|
||||
|
||||
if type(X) != np.ndarray:
|
||||
Y = data['Y'].detach().cpu()
|
||||
|
||||
s['Y'] = Y
|
||||
if load_X: s['X'] = X
|
||||
s['Y_'] = get_yhat(model, X)
|
||||
s['model'] = model
|
||||
return s
|
||||
|
||||
def get_yhat(model, data):
|
||||
if type(data)==np.ndarray: data = torch.Tensor(data)
|
||||
return torch.argmax(model(data), 1)
|
||||
|
||||
def get_acc(y,yhat):
|
||||
n = float(len(y))
|
||||
return (y==yhat).sum().item()/n
|
||||
|
||||
def parse_and_get_df(root, prefix, files=None, device_id=None, only_load=False, only_linear=False, sample_pct=0.5, load_X=False, use_model_pred=False):
|
||||
exps = files if files is not None else [f for f in os.listdir(root) if f.startswith(prefix)]
|
||||
|
||||
funcs = {
|
||||
'config': lambda d: d['config'],
|
||||
'stats': parse_exp_stats,
|
||||
'model': parse_exp_model,
|
||||
'data': lambda x: parse_exp_data(x, load_X=load_X),
|
||||
'random_dep': lambda d: get_feature_deps(d['data']['te_dl'], d['model'], only_linear=only_linear, W=d['data'].get('W', None), dep_type='random', use_model_pred=use_model_pred, print_info=False, sample_pct=sample_pct, device_id=device_id),
|
||||
'swap_dep': lambda d: get_feature_deps(d['data']['te_dl'], d['model'], only_linear=only_linear, W=d['data'].get('W', None), dep_type='swap', use_model_pred=use_model_pred, print_info=False, sample_pct=sample_pct, device_id=device_id),
|
||||
}
|
||||
|
||||
P = parse_data(root=root, exps=exps, **funcs)
|
||||
if only_load: return P
|
||||
|
||||
D = []
|
||||
for idx, (k,v) in enumerate(P.items(),1):
|
||||
d = OrderedDict()
|
||||
for a,b in k: d[a] = b
|
||||
for vk in ['model', 'data', 'stats', 'config']:
|
||||
for a,b in v[vk].items(): d[a] = b
|
||||
for vk in ['random_dep', 'swap_dep']:
|
||||
for coord, dep in v[vk].items():
|
||||
d[f'{vk[0]}dep_{coord}'] = dep
|
||||
D.append(d)
|
||||
|
||||
df = pd.DataFrame(D)
|
||||
if len(df): df['nd'] = df['num_train']/df['dim']
|
||||
return df
|
||||
|
||||
def viz(d, c1, c2, k=80_000, info=True, plot_dm=True, plot_data=True, use_yhat=False, unif_k=False, width=10, title=None, is_binary=False, dep_type='swap', ax=None):
|
||||
if 'W' not in d['data']: W = np.eye(d['config']['dim'])
|
||||
else: W = d['data']['W']
|
||||
if W is None: W = np.eye(d['config']['dim'])
|
||||
|
||||
z = parse_exp_data(d)
|
||||
X = d['data']['X']
|
||||
|
||||
# visualize un-transformed data...
|
||||
X_ = np.array(X).dot(W.T)
|
||||
Y, Y_ = z['Y'], z['Y_']
|
||||
model = d['model'].cpu()
|
||||
D = X.shape[1]
|
||||
kn = k if unif_k else len(X)
|
||||
K = torch.Tensor(np.random.uniform(size=(k, D)))*width if unif_k else np.array(X_)
|
||||
K[:, c1] = torch.Tensor(np.random.uniform(low=min(X_[:,c1]), high=max(X_[:,c1]), size=kn))
|
||||
K[:, c2] = torch.Tensor(np.random.uniform(low=min(X_[:,c2]), high=max(X_[:,c2]), size=kn))
|
||||
KO = model(torch.Tensor(np.array(K).dot(W)))
|
||||
if is_binary: KY = (KO > 0).squeeze().numpy()
|
||||
else: KY = torch.argmax(KO, 1).numpy()
|
||||
|
||||
if info:
|
||||
deps = get_feature_deps(d['data']['te_dl'], d['model'], W=d['data'].get('W', None), dep_type=dep_type)
|
||||
for k,v in sorted(deps.items(), reverse=False, key=lambda t: t[-1]): print ('{}:{:.3f}'.format(k,v), end=' ')
|
||||
print ("\n")
|
||||
|
||||
if ax is None: fig, ax = plt.subplots(1,1,figsize=(6,4))
|
||||
|
||||
if plot_dm: ax.scatter(K[:, c1], K[:, c2], c=KY, cmap='binary', s=8, alpha=.2)
|
||||
if plot_data: ax.scatter(X_[:, c1], X_[:, c2], c=Y_ if use_yhat else Y, cmap='coolwarm', s=8, alpha=.4)
|
||||
|
||||
ax.set_xlabel('e_{}'.format(c1))
|
||||
ax.set_ylabel('e_{}'.format(c2))
|
||||
ax.set_title(title if title else '')
|
||||
plt.tight_layout()
|
||||
return ax
|
||||
|
||||
def visualize_boundary(model, data, c1, c2, dim, ax=None, is_binary=False, use_yhat=False, width=1, unif_k=True, k=100_000, print_info=True, dep_type='random'):
|
||||
agg = {'model': model, 'data': data, 'config': dict(dim=dim)}
|
||||
return viz(agg, c1, c2, unif_k=unif_k, width=width, dep_type=dep_type, is_binary=is_binary, use_yhat=use_yhat, ax=ax, info=print_info)
|
||||
|
||||
def get_randomized_loader(dl, W, coordinates):
|
||||
"""
|
||||
dl: dataloader
|
||||
W: rotation matrix
|
||||
coordinates: list of coordinates to randomize
|
||||
output: randomized dataloader
|
||||
"""
|
||||
|
||||
def _randomize(X, coords):
|
||||
p = torch.randperm(len(X))
|
||||
for c in coords: X[:, c] = X[p, c]
|
||||
return X
|
||||
|
||||
# rotate data
|
||||
X, Y = map(copy.deepcopy, dl.dataset.tensors)
|
||||
dim = X.shape[1]
|
||||
if W is None: W = np.eye(dim)
|
||||
|
||||
rt_X = torch.Tensor(X.numpy().dot(W.T))
|
||||
rand_rt_X = _randomize(rt_X, coordinates)
|
||||
rand_X = torch.Tensor(rand_rt_X.numpy().dot(W))
|
||||
|
||||
return utils._to_dl(rand_X, Y, dl.batch_size)
|
||||
|
||||
|
||||
def get_feature_deps(dl, model, W=None, dep_type='random', only_linear=False, coords=None, metric='accuracy',
|
||||
use_model_pred=False, print_info=False, sample_pct=1.0, device_id=None):
|
||||
"""Compute feature dependencies using randomization or swapping"""
|
||||
def _randomize(X, Y, coords):
|
||||
p = torch.randperm(len(X))
|
||||
for c in coords: X[:, c] = X[p, c]
|
||||
return X
|
||||
|
||||
def _swap(X, Y, coords):
|
||||
idx0, idx1 = map(lambda c: (Y.numpy()==c).nonzero()[0], [0, 1])
|
||||
idx0_new = np.random.choice(idx1, size=len(idx0), replace=True)
|
||||
idx1_new = np.random.choice(idx0, size=len(idx1), replace=True)
|
||||
for c in coords: X[idx0, c], X[idx1, c] = X[idx0_new, c], X[idx1_new, c]
|
||||
return X
|
||||
|
||||
def _get_dep_data(X, Y, coords):
|
||||
return dict(random=_randomize, swap=_swap)[dep_type](X, Y, coords)
|
||||
|
||||
|
||||
assert metric in {'accuracy', 'loss', 'auc'}
|
||||
|
||||
# setup data
|
||||
device = gu.get_device(device_id)
|
||||
model = model.to(device)
|
||||
X, Y = map(lambda Z: Z.to(device), dl.dataset.tensors)
|
||||
Yh = get_yhat(model, X)
|
||||
dim = X.shape[1]
|
||||
if W is None: W = np.eye(dim)
|
||||
W = torch.Tensor(W).to(device)
|
||||
rt_X = torch.mm(X, torch.transpose(W,0,1))
|
||||
|
||||
# subsample data
|
||||
n_samp = int(round(sample_pct*len(rt_X)))
|
||||
perm = torch.randperm(len(rt_X))[:n_samp]
|
||||
rt_X, Y, Yh = rt_X[perm, :], Y[perm], Yh[perm]
|
||||
|
||||
# compute deps
|
||||
deps = {}
|
||||
|
||||
dims = list(range(dim))
|
||||
if coords is None and not only_linear: coords = dims
|
||||
if coords is None and only_linear: coords = [0,1]
|
||||
|
||||
for idx, coord in enumerate(coords):
|
||||
if print_info: print ('{}/{}'.format(idx, len(coords)), end=' ')
|
||||
rt_X_ = copy.deepcopy(rt_X).to(device)
|
||||
rt_X_ = _get_dep_data(rt_X_, Y, coord if type(coord) in (list, tuple) else [coord])
|
||||
X_ = torch.mm(rt_X_, W)
|
||||
Ys = get_yhat(model, X_)
|
||||
|
||||
key = tuple(coord) if type(coord) in (list, tuple) else coord
|
||||
|
||||
if metric == 'auc':
|
||||
L = utils.get_logits_given_tensor(X_, model, device=device, bs=250)
|
||||
S = L[:,1]-L[:,0]
|
||||
auc = roc_auc_score(Y.cpu().numpy(), S.cpu().numpy())
|
||||
deps[key] = auc
|
||||
elif metric == 'accuracy':
|
||||
deps[key] = get_acc(Yh if use_model_pred else Y, Ys)
|
||||
elif metric == 'loss':
|
||||
L = utils.get_logits_given_tensor(X_, model, device=device, bs=250)
|
||||
with torch.no_grad():
|
||||
loss_val = F.cross_entropy(L, Y).item()
|
||||
deps[key] = loss_val
|
||||
|
||||
return deps
|
||||
|
||||
def get_subset_feature_deps(dl, model, coords_set, comb_size, W=None, dep_type='random', sample_pct=0.5, device_id=None, print_info=False):
|
||||
coords = list(itertools.combinations(coords_set, comb_size))
|
||||
return get_feature_deps(dl, model, W=W, dep_type=dep_type, coords=coords, print_info=print_info, sample_pct=sample_pct, device_id=device_id)
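
# Usage sketch (illustrative addition; model_demo and dl_demo are assumed stand-ins
# for a trained classifier and a TensorDataset-backed loader):
def _demo_get_feature_deps(model_demo, dl_demo):
    # accuracy after randomizing each coordinate; a large drop on coordinate c
    # indicates that the model relies on feature c.
    deps = get_feature_deps(dl_demo, model_demo, dep_type='random', metric='accuracy')
    return sorted(deps.items(), key=lambda kv: kv[1])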
|
|
@@ -0,0 +1,98 @@
|
|||
import random
|
||||
import os, copy, pickle, time
|
||||
import itertools
|
||||
from collections import defaultdict, Counter, OrderedDict
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import TensorDataset, DataLoader
|
||||
import utils
|
||||
import gpu_utils as gu
|
||||
import data_utils as du
|
||||
|
||||
def get_binary_mnist(y1=0, y2=1, apply_padding=True, repeat_channels=True):
|
||||
|
||||
def _make_cifar_compatible(X):
|
||||
if apply_padding: X = np.stack([np.pad(X[i][0], 2)[None,:] for i in range(len(X))]) # pad
|
||||
if repeat_channels: X = np.repeat(X, 3, axis=1) # add channels
|
||||
return X
|
||||
|
||||
binarize = lambda X,Y: du.get_binary_datasets(X, Y, y1=y1, y2=y2)
|
||||
|
||||
tr_dl, te_dl = du.get_mnist_dl(normalize=False)
|
||||
Xtr, Ytr = binarize(*utils.extract_numpy_from_loader(tr_dl))
|
||||
Xte, Yte = binarize(*utils.extract_numpy_from_loader(te_dl))
|
||||
Xtr, Xte = map(_make_cifar_compatible, [Xtr, Xte])
|
||||
return (Xtr, Ytr), (Xte, Yte)
|
||||
|
||||
def get_binary_cifar(y1=3, y2=5, c={0,1,2,3,4}, use_cifar10=True):
|
||||
binarize = lambda X,Y: du.get_binary_datasets(X, Y, y1=y1, y2=y2)
|
||||
binary = False if y1 is not None and y2 is not None else True
|
||||
if binary: print ("grouping cifar classes")
|
||||
tr_dl, te_dl = du.get_cifar_dl(use_cifar10=use_cifar10, shuffle=False, normalize=False, binarize=binary, y0=c)
|
||||
|
||||
Xtr, Ytr = binarize(*utils.extract_numpy_from_loader(tr_dl))
|
||||
Xte, Yte = binarize(*utils.extract_numpy_from_loader(te_dl))
|
||||
return (Xtr, Ytr), (Xte, Yte)
|
||||
|
||||
def combine_datasets(Xm, Ym, Xc, Yc, randomize_order=False, randomize_first_block=False, randomize_second_block=False):
|
||||
"""combine two datasets"""
|
||||
|
||||
def partition(X, Y, randomize=False):
|
||||
"""partition randomly or using labels"""
|
||||
if randomize:
|
||||
n = len(Y)
|
||||
p = np.random.permutation(n)
|
||||
ni, pi = p[:n//2], p[n//2:]
|
||||
else:
|
||||
ni, pi = (Y==0).nonzero()[0], (Y==1).nonzero()[0]
|
||||
return X[pi], X[ni]
|
||||
|
||||
def _combine(X1, X2):
|
||||
"""concatenate images from two sources"""
|
||||
X = []
|
||||
for i in range(min(len(X1), len(X2))):
|
||||
x1, x2 = X1[i], X2[i]
|
||||
# randomize order
|
||||
if randomize_order and random.random() < 0.5:
|
||||
x1, x2 = x2, x1
|
||||
x = np.concatenate((x1,x2), axis=1)
|
||||
X.append(x)
|
||||
return np.stack(X)
|
||||
|
||||
Xmp, Xmn = partition(Xm, Ym, randomize=randomize_first_block)
|
||||
Xcp, Xcn = partition(Xc, Yc, randomize=randomize_second_block)
|
||||
n = min(map(len, [Xmp, Xmn, Xcp, Xcn]))
|
||||
Xmp, Xmn, Xcp, Xcn = map(lambda Z: Z[:n], [Xmp, Xmn, Xcp, Xcn])
|
||||
|
||||
Xp = _combine(Xmp, Xcp)
|
||||
Yp = np.ones(len(Xp))
|
||||
|
||||
Xn = _combine(Xmn, Xcn)
|
||||
Yn = np.zeros(len(Xn))
|
||||
|
||||
X = np.concatenate([Xp, Xn], axis=0)
|
||||
Y = np.concatenate([Yp, Yn], axis=0)
|
||||
P = np.random.permutation(len(X))
|
||||
X, Y = X[P], Y[P]
|
||||
return X, Y
|
||||
|
||||
def get_mnist_cifar(mnist_classes=(0,1), cifar_classes=None, c={0,1,2,3,4},
|
||||
randomize_mnist=False, randomize_cifar=False):
|
||||
|
||||
y1, y2 = mnist_classes
|
||||
(Xtrm, Ytrm), (Xtem, Ytem) = get_binary_mnist(y1=y1, y2=y2)
|
||||
|
||||
y1, y2 = (None, None) if cifar_classes is None else cifar_classes
|
||||
(Xtrc, Ytrc), (Xtec, Ytec) = get_binary_cifar(c=c, y1=y1, y2=y2)
|
||||
|
||||
Xtr, Ytr = combine_datasets(Xtrm, Ytrm, Xtrc, Ytrc, randomize_first_block=randomize_mnist, randomize_second_block=randomize_cifar)
|
||||
Xte, Yte = combine_datasets(Xtem, Ytem, Xtec, Ytec, randomize_first_block=randomize_mnist, randomize_second_block=randomize_cifar)
|
||||
return (Xtr, Ytr), (Xte, Yte)
|
||||
|
||||
def get_mnist_cifar_dl(mnist_classes=(0,1), cifar_classes=None, c={0,1,2,3,4}, bs=256,
|
||||
randomize_mnist=False, randomize_cifar=False):
|
||||
(Xtr, Ytr), (Xte, Yte) = get_mnist_cifar(mnist_classes=mnist_classes, cifar_classes=cifar_classes,
|
||||
c=c, randomize_mnist=randomize_mnist, randomize_cifar=randomize_cifar)
|
||||
tr_dl = utils._to_dl(Xtr, Ytr, bs=bs, shuffle=True)
|
||||
te_dl = utils._to_dl(Xte, Yte, bs=100, shuffle=False)
|
||||
return tr_dl, te_dl
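
# Usage sketch (illustrative addition): build the concatenated MNIST-CIFAR loaders,
# where each image stacks an MNIST block (classes 0/1) and a binarized CIFAR block
# along the height dimension.
def _demo_mnist_cifar_loaders():
    tr_dl, te_dl = get_mnist_cifar_dl(mnist_classes=(0, 1), c={0, 1, 2, 3, 4}, bs=256)
    xb, yb = next(iter(tr_dl))
    return xb.shape, yb.shape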
|
|
@@ -0,0 +1,236 @@
|
|||
import seaborn as sns
|
||||
import utils
|
||||
import random
|
||||
import os, copy, pickle, time
|
||||
import itertools
|
||||
from collections import defaultdict, Counter, OrderedDict
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import pandas as pd
|
||||
from torch.utils.data import TensorDataset, DataLoader
|
||||
from torch import optim, nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import utils.scripts.gpu_utils as gu
|
||||
import utils.scripts.data_utils as du
|
||||
import utils.scripts.synth_models as synth_models
|
||||
|
||||
#import foolbox as fb
|
||||
#from autoattack import AutoAttack
|
||||
|
||||
# Misc
|
||||
def get_yhat(model, data): return torch.argmax(model(data), 1)
|
||||
def get_acc(y,yhat): return (y==yhat).sum().item()/float(len(y))
|
||||
|
||||
class PGD_Attack(object):
|
||||
|
||||
def __init__(self, eps, lr, num_iter, loss_type, rand_eps=1e-3,
|
||||
num_classes=2, bounds=(0.,1.), minimal=False, restarts=1, device=None):
|
||||
self.eps = eps
|
||||
self.lr = lr
|
||||
self.num_iter = num_iter
|
||||
self.B = bounds
|
||||
self.restarts = restarts
|
||||
self.rand_eps = rand_eps
|
||||
self.device = device or gu.get_device(None)
|
||||
self.loss_type = loss_type
|
||||
self.num_classes = num_classes
|
||||
self.classes = list(range(self.num_classes))
|
||||
self.delta = None
|
||||
self.minimal = minimal # early stop + no eps
|
||||
self.project = not self.minimal
|
||||
self.loss = -np.inf
|
||||
|
||||
def evaluate_attack(self, dl, model):
|
||||
model = model.to(self.device)
|
||||
Xa, Ya, Yh, P = [], [], [], []
|
||||
|
||||
for xb, yb in dl:
|
||||
xb, yb = xb.to(self.device), yb.to(self.device)
|
||||
delta = self.perturb(xb, yb, model)
|
||||
xba = xb+delta
|
||||
|
||||
with torch.no_grad():
|
||||
out = model(xba).detach()
|
||||
yh = torch.argmax(out, dim=1)
|
||||
xb, yb, yh, xba, delta = xb.cpu(), yb.cpu(), yh.cpu(), xba.cpu(), delta.cpu()
|
||||
|
||||
Ya.append(yb)
|
||||
Yh.append(yh)
|
||||
Xa.append(xba)
|
||||
P.append(delta)
|
||||
|
||||
Xa, Ya, Yh, P = map(torch.cat, [Xa, Ya, Yh, P])
|
||||
ta_dl = utils._to_dl(Xa, Ya, dl.batch_size)
|
||||
acc, loss = utils.compute_loss_and_accuracy_from_dl(ta_dl, model,
|
||||
F.cross_entropy,
|
||||
device=self.device)
|
||||
return {
|
||||
'acc': acc.item(),
|
||||
'loss': loss.item(),
|
||||
'ta_dl': ta_dl,
|
||||
'Xa': Xa.numpy(),
|
||||
'Ya': Ya.numpy(),
|
||||
'Yh': Yh.numpy(),
|
||||
'P': P.numpy()
|
||||
}
|
||||
|
||||
def perturb(self, xb, yb, model, cpu=False):
|
||||
model, xb, yb = model.to(self.device), xb.to(self.device), yb.to(self.device)
|
||||
if self.eps == 0: return torch.zeros_like(xb)
|
||||
|
||||
# compute perturbations and track best perturbations
|
||||
self.loss = -np.inf
|
||||
max_delta = self._perturb_once(xb, yb, model)
|
||||
|
||||
with torch.no_grad():
|
||||
out = model(xb+max_delta)
|
||||
max_loss = nn.CrossEntropyLoss(reduction='none')(out, yb)
|
||||
|
||||
for _ in range(self.restarts-1):
|
||||
delta = self._perturb_once(xb, yb, model)
|
||||
|
||||
with torch.no_grad():
|
||||
out = model(xb+delta)
|
||||
all_loss = nn.CrossEntropyLoss(reduction='none')(out, yb)
|
||||
|
||||
loss_flag = all_loss >= max_loss
|
||||
max_delta[loss_flag] = delta[loss_flag]
|
||||
max_loss = torch.max(max_loss, all_loss)
|
||||
|
||||
if cpu: max_delta = max_delta.cpu()
|
||||
return max_delta
|
||||
|
||||
def _perturb_once(self, xb, yb, model, track_scores=False, stop_const=1e-5):
|
||||
self.delta = self._init_delta(xb, yb)
|
||||
scores = []
|
||||
|
||||
# (minimal) mask perturbations if model already misclassifies
|
||||
for t in range(self.num_iter):
|
||||
loss, out = self._get_loss(xb, yb, model, get_scores=True)
|
||||
|
||||
if self.minimal:
|
||||
yh = torch.argmax(out, dim=1).detach()
|
||||
not_flipped = yh == yb
|
||||
not_flipped_ratio = not_flipped.sum().item()/float(len(yb))
|
||||
else:
|
||||
not_flipped = None
|
||||
not_flipped_ratio = 1.0
|
||||
|
||||
# stop if almost all examples in the batch misclassified
|
||||
if not_flipped_ratio < stop_const:
|
||||
break
|
||||
|
||||
if track_scores:
|
||||
scores.append(out.detach().cpu().numpy())
|
||||
|
||||
# compute loss, update + clamp delta
|
||||
loss.backward()
|
||||
self.loss = max(self.loss, loss.item())
|
||||
|
||||
self.delta = self._update_delta(xb, yb, update_mask=not_flipped)
|
||||
self.delta = self._clamp_input(xb, yb)
|
||||
|
||||
d = self.delta.detach()
|
||||
|
||||
if track_scores:
|
||||
scores = np.stack(scores).swapaxes(0, 1)
|
||||
return d, scores
|
||||
|
||||
return d
|
||||
|
||||
def _init_delta(self, xb, yb):
|
||||
delta = torch.empty_like(xb)
|
||||
delta = delta.uniform_(-self.rand_eps, self.rand_eps)
|
||||
delta = delta.to(self.device)
|
||||
delta.requires_grad = True
|
||||
return delta
|
||||
|
||||
def _clamp_input(self, xb, yb):
|
||||
# clamp delta s.t. X+delta in valid input range
|
||||
self.delta.data = torch.max(self.B[0]-xb,
|
||||
torch.min(self.B[1]-xb,
|
||||
self.delta.data))
|
||||
return self.delta
|
||||
|
||||
def _get_loss(self, xb, yb, model, get_scores=False):
|
||||
out = model(xb+self.delta)
|
||||
|
||||
if self.loss_type == 'untargeted':
|
||||
L = -1*F.cross_entropy(out, yb)
|
||||
|
||||
elif self.loss_type == 'targeted':
|
||||
L = nn.CrossEntropyLoss()(out, yb)
|
||||
|
||||
elif self.loss_type == 'random_targeted':
|
||||
rand_yb = torch.randint(low=0, high=self.num_classes, size=(len(yb),), device=self.device)
|
||||
#rand_yb[rand_yb==yb] = (yb[rand_yb==yb]+1) % self.num_classes
|
||||
L = nn.CrossEntropyLoss()(out, rand_yb)
|
||||
|
||||
elif self.loss_type == 'plusone_targeted':
|
||||
next_yb = (yb+1) % self.num_classes
|
||||
L = nn.CrossEntropyLoss()(out, next_yb)
|
||||
|
||||
elif self.loss_type == 'binary_targeted':
|
||||
yb_opp = 1-yb
|
||||
L = nn.CrossEntropyLoss()(out, yb_opp)
|
||||
|
||||
elif self.loss_type == 'binary_hybrid':
|
||||
yb_opp = 1-yb
|
||||
L = nn.CrossEntropyLoss()(out, yb_opp) - nn.CrossEntropyLoss()(out, yb)
|
||||
|
||||
else:
|
||||
assert False, "unknown loss type"
|
||||
|
||||
if get_scores: return L, out
|
||||
return L
|
||||
|
||||
class L2_PGD_Attack(PGD_Attack):
|
||||
|
||||
OVERFLOW_CONST = 1e-10
|
||||
|
||||
def get_norms(self, X):
|
||||
nch = len(X.shape)
|
||||
return X.view(X.shape[0], -1).norm(dim=1)[(...,) + (None,)*(nch-1)]
|
||||
|
||||
def _update_delta(self, xb, yb, update_mask=None):
|
||||
# normalize gradients
|
||||
grad = self.delta.grad.detach()
|
||||
norms = self.get_norms(grad)
|
||||
grad = grad/(norms+self.OVERFLOW_CONST) # add const to avoid overflow
|
||||
|
||||
# steepest descent
|
||||
if self.minimal and update_mask is not None:
|
||||
um = update_mask
|
||||
self.delta.data[um] = self.delta.data[um] - self.lr*grad[um]
|
||||
else:
|
||||
self.delta.data = self.delta.data - self.lr*grad
|
||||
|
||||
# l2 ball projection
|
||||
if self.project:
|
||||
delta_norms = self.get_norms(self.delta.data)
|
||||
self.delta.data = self.eps*self.delta.data / (delta_norms.clamp(min=self.eps))
|
||||
|
||||
self.delta.grad.zero_()
|
||||
return self.delta
|
||||
|
||||
def _init_delta(self, xb, yb):
|
||||
# random vector with L2 norm rand_eps
|
||||
delta = torch.zeros_like(xb)
|
||||
delta = delta.uniform_(-self.rand_eps, self.rand_eps)
|
||||
delta_norm = self.get_norms(delta)
|
||||
delta = self.rand_eps*delta/(delta_norm+self.OVERFLOW_CONST)
|
||||
delta = delta.to(self.device)
|
||||
delta.requires_grad = True
|
||||
return delta
|
||||
|
||||
class Linf_PGD_Attack(PGD_Attack):
|
||||
|
||||
def _update_delta(self, xb, yb, **kw):
|
||||
# steepest descent + linf projection (GD)
|
||||
self.delta.data = self.delta.data - self.lr*(self.delta.grad.detach().sign())
|
||||
self.delta.data = self.delta.data.clamp(-self.eps, self.eps)
|
||||
self.delta.grad.zero_()
|
||||
return self.delta
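
# Usage sketch (illustrative addition; the attack hyperparameters below are
# arbitrary placeholders, and model_demo / te_dl_demo are assumed stand-ins):
def _demo_l2_pgd(model_demo, te_dl_demo):
    attack = L2_PGD_Attack(eps=0.5, lr=0.1, num_iter=20, loss_type='untargeted',
                           num_classes=2, bounds=(0., 1.), restarts=1)
    results = attack.evaluate_attack(te_dl_demo, model_demo)
    return results['acc'], results['loss']   # accuracy / loss under the L2 attack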
|
||||
|
|
@@ -0,0 +1,160 @@
|
|||
import sys, copy
|
||||
import torch, torchvision
|
||||
from torch import optim, nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import TensorDataset, DataLoader
|
||||
import utils.scripts.gendata as gendata
|
||||
import utils.scripts.utils as utils
|
||||
import numpy as np
|
||||
import utils.scripts.gpu_utils as gu
|
||||
import utils.scripts.ptb_utils as pu
|
||||
|
||||
def kaiming_init(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.kaiming_uniform_(m.weight.data)
|
||||
# kaiming init is undefined for 1-d bias tensors; zero the bias instead
if m.bias is not None: m.bias.data.zero_()
|
||||
|
||||
class SequenceClassifier(nn.Module):
|
||||
|
||||
def __init__(self, seq_model, idim, hdim, hl, input_size, num_classes=2, many_to_many=False, unsqueeze_input=True):
|
||||
super(SequenceClassifier, self).__init__()
|
||||
self.seq_model = seq_model
|
||||
self.hdim = hdim
|
||||
self.hl = hl
|
||||
self.input_size = input_size
|
||||
self.idim = idim
|
||||
self.num_classes = num_classes
|
||||
self.unsqueeze_input = unsqueeze_input
|
||||
self.many_to_many = many_to_many
|
||||
|
||||
self.seq_length = self.idim//self.input_size
|
||||
self.seq = self.seq_model(input_size=input_size, hidden_size=hdim, num_layers=hl, batch_first=True)
|
||||
self.lin_idim = hdim*self.seq_length if many_to_many else hdim
|
||||
self.lin = nn.Linear(self.lin_idim, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
if self.unsqueeze_input: x = x.unsqueeze(2)
|
||||
bsize, idim, _ = x.shape
|
||||
seq_length = idim//self.input_size
|
||||
x = x.view((bsize, seq_length, self.input_size))
|
||||
out, hidden = self.seq(x)
|
||||
lin_in = out[:,-1,:]
|
||||
if self.many_to_many: lin_in = out.contiguous().view((bsize, -1))
|
||||
lin_out = self.lin(lin_in)
|
||||
return lin_out
|
||||
|
||||
class GRUClassifier(SequenceClassifier):
|
||||
|
||||
def __init__(self, idim, hdim, hl, input_size, num_classes=2, many_to_many=False, unsqueeze_input=True):
|
||||
super(GRUClassifier, self).__init__(nn.GRU, idim, hdim, hl, input_size, many_to_many=many_to_many, num_classes=num_classes, unsqueeze_input=unsqueeze_input)
|
||||
|
||||
class LSTMClassifier(SequenceClassifier):
|
||||
|
||||
def __init__(self, idim, hdim, hl, input_size, num_classes=2, many_to_many=False, unsqueeze_input=True):
|
||||
super(LSTMClassifier, self).__init__(nn.LSTM, idim, hdim, hl, input_size, many_to_many=many_to_many, num_classes=num_classes, unsqueeze_input=unsqueeze_input)
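
# Usage sketch (illustrative addition): with unsqueeze_input=True (the default),
# each of the idim input coordinates is treated as one timestep of a length-idim sequence.
def _demo_gru_classifier():
    model = GRUClassifier(idim=20, hdim=32, hl=1, input_size=1, num_classes=2)
    x = torch.randn(8, 20)      # batch of 8 flat 20-d inputs
    return model(x).shape       # torch.Size([8, 2])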
|
||||
|
||||
class CNNClassifier(nn.Module):
|
||||
|
||||
def __init__(self, out_channels, hl, kernel_size, idim, num_classes=2, padding=None, stride=1, maxpool_kernel_size=None, use_maxpool=False):
|
||||
"""
|
||||
Fixed architecture:
|
||||
- default max pool kernel size half of convolution kernel size
|
||||
- default padding = kernel size - 1 // 2 to maintain same dimension
|
||||
- stride = 1
|
||||
- 1 FC layer
|
||||
"""
|
||||
if padding == None: assert kernel_size % 2 == 1, "use odd kernel size, equal padding constraint"
|
||||
super(CNNClassifier, self).__init__()
|
||||
self.out_channels = out_channels
|
||||
self.num_conv = hl
|
||||
self.kernel_size = kernel_size
|
||||
self.padding = padding or (self.kernel_size-1)//2
|
||||
self.stride = 1
|
||||
self.num_classes = num_classes
|
||||
self.idim = idim
|
||||
self.use_maxpool = use_maxpool
|
||||
self.maxpool_kernel_size = maxpool_kernel_size or self.kernel_size//2
|
||||
|
||||
self.maxpool = nn.MaxPool1d(self.maxpool_kernel_size)
|
||||
self.ih_conv = nn.Conv1d(1, self.out_channels, self.kernel_size, padding=self.padding, stride=self.stride)
|
||||
|
||||
self.hh_convs = []
|
||||
for _ in range(self.num_conv-1):
|
||||
self.hh_convs.append(nn.Conv1d(self.out_channels, self.out_channels, self.kernel_size, padding=self.padding, stride=self.stride))
|
||||
self.hh_convs.append(nn.ReLU())
|
||||
self.hh_convs = nn.Sequential(*self.hh_convs)
|
||||
|
||||
fc_idim = int(self.idim/self.maxpool_kernel_size) if self.use_maxpool else self.idim
|
||||
self.fc_layer = nn.Linear(self.out_channels*fc_idim, self.idim)
|
||||
self.out_layer = nn.Linear(self.idim, self.num_classes)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
bs = x.shape[0]
|
||||
x_ = x.unsqueeze(1)
|
||||
|
||||
x_ = self.relu(self.ih_conv(x_))
|
||||
x_ = self.hh_convs(x_)
|
||||
|
||||
if self.use_maxpool: x_ = self.maxpool(x_)
|
||||
x_ = self.relu(self.fc_layer(x_.view(bs, -1)))
|
||||
|
||||
return self.out_layer(x_)
|
||||
|
||||
class CNN2DClassifier(nn.Module):
|
||||
|
||||
def __init__(self, num_filters, filter_size, num_layers, input_shape, input_channels=1, stride=2, padding=None, num_stride2_layers=2, fc_idim=None, fc_odim=None, num_classes=2, use_avgpool=True, avgpool_ksize=5):
|
||||
super(CNN2DClassifier, self).__init__()
|
||||
self.outch = num_filters
|
||||
self.fsize = filter_size
|
||||
self.input_channels = input_channels
|
||||
self.hl = num_layers
|
||||
self.padding = (self.fsize-1)//2 if padding is None else padding
|
||||
self.num_classes = num_classes
|
||||
self.num_stride2_layers = num_stride2_layers
|
||||
self.strides = iter([stride]*num_stride2_layers+[1]*(num_layers-num_stride2_layers))
|
||||
self.use_avgpool = use_avgpool
|
||||
self.avgpool_ksize = avgpool_ksize
|
||||
|
||||
self.convs = [nn.Conv2d(self.input_channels, self.outch, self.fsize, padding=self.padding, stride=next(self.strides)), nn.ReLU()]
|
||||
if self.use_avgpool: self.convs.append(nn.AvgPool2d(self.avgpool_ksize))
|
||||
|
||||
for _ in range(self.hl-1):
|
||||
self.convs.append(nn.Conv2d(self.outch, self.outch, self.fsize, stride=next(self.strides), padding=self.padding))
|
||||
self.convs.append(nn.ReLU())
|
||||
if self.use_avgpool: self.convs.append(nn.AvgPool2d(self.avgpool_ksize))
|
||||
|
||||
self.convs = nn.Sequential(*self.convs) # need to wrap for gpu
|
||||
sl = min(self.hl, num_stride2_layers)
|
||||
self.fc_idim = int(num_filters*input_shape[0]*input_shape[1]/float(4**sl)) if fc_idim is None else fc_idim
|
||||
self.fc_odim = fc_odim if fc_odim is not None else self.fc_idim
|
||||
self.fc = nn.Linear(self.fc_idim, self.fc_odim)
|
||||
self.out = nn.Linear(self.fc_odim, self.num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.convs(x)
|
||||
x = x.reshape(x.shape[0], -1)
|
||||
return self.out(F.relu(self.fc(x)))
|
||||
|
||||
def get_linear(input_dim, num_classes):
|
||||
return nn.Sequential(nn.Linear(input_dim, num_classes))
|
||||
|
||||
def get_fcn(idim, hdim, odim, hl=1, init=False, activation=nn.ReLU, use_activation=True, use_bn=False, input_dropout=0, dropout=0):
|
||||
use_dropout = dropout > 0
|
||||
layers = []
|
||||
if input_dropout > 0: layers.append(nn.Dropout(input_dropout))
|
||||
layers.append(nn.Linear(idim, hdim))
|
||||
if use_activation: layers.append(activation())
|
||||
if use_dropout: layers.append(nn.Dropout(dropout))
|
||||
if use_bn: layers.append(nn.BatchNorm1d(hdim))
|
||||
for _ in range(hl-1):
|
||||
l = [nn.Linear(hdim, hdim)]
|
||||
if use_activation: l.append(activation())
|
||||
if use_dropout: l.append(nn.Dropout(dropout))
|
||||
if use_bn: l.append(nn.BatchNorm1d(hdim))
|
||||
layers.extend(l)
|
||||
layers.append(nn.Linear(hdim, odim))
|
||||
model = nn.Sequential(*layers)
|
||||
|
||||
if init: model.apply(kaiming_init)
|
||||
return model
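
# Usage sketch (illustrative addition): a 2-hidden-layer ReLU net of width 100 over
# 2-d inputs, matching the hdim=100, hl=2, dim=2 slab settings used in this commit.
def _demo_get_fcn():
    model = get_fcn(idim=2, hdim=100, odim=2, hl=2)
    x = torch.randn(16, 2)
    return model(x).shape       # torch.Size([16, 2])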
|
|
@@ -0,0 +1,717 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import pickle
|
||||
import copy
|
||||
from collections import defaultdict, Counter, OrderedDict
|
||||
import time
|
||||
from torch.utils.data import TensorDataset, DataLoader
|
||||
import torchvision
|
||||
from torch import optim, nn
|
||||
import torch.nn.functional as F
|
||||
from scipy.linalg import qr
|
||||
import utils.scripts.lms_utils as au
|
||||
import utils.scripts.ptb_utils as pu
|
||||
import utils.scripts.gpu_utils as gu
|
||||
from sklearn import metrics
|
||||
import collections
|
||||
from sklearn.metrics import roc_auc_score
|
||||
|
||||
plt.style.use('seaborn-ticks')
|
||||
import matplotlib.ticker as ticker
|
||||
|
||||
def get_orthonormal_matrix(n):
|
||||
H = np.random.randn(n, n)
|
||||
s = np.linalg.svd(H)[1]
|
||||
s = s[s>1e-7]
|
||||
if len(s) != n: return get_orthonormal_matrix(n)
|
||||
Q, R = qr(H)
|
||||
return Q
|
||||
|
||||
def get_dataloader(X, Y, bs, **kw):
|
||||
return DataLoader(TensorDataset(X, Y), batch_size=bs, **kw)
|
||||
|
||||
def split_dataloader(dl, frac=0.5):
|
||||
bs = dl.batch_size
|
||||
X, Y = dl.dataset.tensors
|
||||
p = torch.randperm(len(X))
|
||||
X, Y = X[p, :], Y[p]
|
||||
n = int(round(len(X)*frac))
|
||||
X0, Y0 = X[:n, :], Y[:n]
|
||||
X1, Y1 = X[n:, :], Y[n:]
|
||||
dl0 = DataLoader(TensorDataset(torch.Tensor(X0), torch.LongTensor(Y0)), batch_size=bs, shuffle=True)
|
||||
dl1 = DataLoader(TensorDataset(torch.Tensor(X1), torch.LongTensor(Y1)), batch_size=bs, shuffle=True)
|
||||
return dl0, dl1
|
||||
|
||||
def _to_dl(X, Y, bs, shuffle=True):
|
||||
return DataLoader(TensorDataset(torch.Tensor(X), torch.LongTensor(Y)), batch_size=bs, shuffle=shuffle)
|
||||
|
||||
def extract_tensors_from_loader(dl, repeat=1, transform_fn=None):
|
||||
X, Y = [], []
|
||||
for _ in range(repeat):
|
||||
for xb, yb in dl:
|
||||
if transform_fn:
|
||||
xb, yb = transform_fn(xb, yb)
|
||||
X.append(xb)
|
||||
Y.append(yb)
|
||||
X = torch.FloatTensor(torch.cat(X))
|
||||
Y = torch.LongTensor(torch.cat(Y))
|
||||
return X, Y
|
||||
|
||||
def extract_numpy_from_loader(dl, repeat=1, transform_fn=None):
|
||||
X, Y = extract_tensors_from_loader(dl, repeat=repeat, transform_fn=transform_fn)
|
||||
return X.numpy(), Y.numpy()
|
||||
|
||||
def _to_tensor_dl(dl, repeat=1, bs=None):
|
||||
X, Y = extract_numpy_from_loader(dl, repeat=repeat)
|
||||
dl = _to_dl(X, Y, bs if bs else dl.batch_size)
|
||||
return dl
|
||||
|
||||
def flatten_loader(dl, bs=None):
|
||||
X, Y = extract_numpy_from_loader(dl)
|
||||
X = X.reshape(X.shape[0], -1)
|
||||
return _to_dl(X, Y, bs=bs if bs else dl.batch_size)
|
||||
|
||||
def merge_loaders(dla, dlb):
|
||||
bs = dla.batch_size
|
||||
Xa, Ya = extract_numpy_from_loader(dla)
|
||||
Xb, Yb = extract_numpy_from_loader(dlb)
|
||||
return _to_dl(np.concatenate([Xa, Xb]), np.concatenate([Ya, Yb]), bs)
|
||||
|
||||
def transform_loader(dl, func, shuffle=True):
|
||||
#assert type(dl.sampler) is torch.utils.data.sampler.SequentialSampler
|
||||
X, Y = extract_numpy_from_loader(dl, transform_fn=func)
|
||||
return _to_dl(X, Y, dl.batch_size, shuffle=shuffle)
|
||||
|
||||
def visualize_tensors(P, size=8, normalize=True, scale_each=False, permute=True, ax=None, pad_value=0.):
|
||||
if ax is None: _, ax = plt.subplots(1,1,figsize=(20,4))
|
||||
if permute:
|
||||
s = np.random.choice(len(P), size=size, replace=False)
|
||||
p = P[s]
|
||||
else:
|
||||
p = P[:size]
|
||||
g = torchvision.utils.make_grid(torch.FloatTensor(p), nrow=size, normalize=normalize, scale_each=scale_each, pad_value=pad_value)
|
||||
g = g.permute(1,2,0).numpy()
|
||||
ax.imshow(g)
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
return ax
|
||||
|
||||
def visualize_loader(dl, ax=None, size=8, normalize=True, scale_each=False, reshape=None):
|
||||
if ax is None: _, ax = plt.subplots(1,1,figsize=(20,4))
|
||||
for xb, yb in dl: break
|
||||
if reshape: xb = xb.reshape(len(xb), *reshape)
|
||||
return visualize_tensors(xb, size=size, normalize=normalize, scale_each=scale_each, permute=True, ax=ax)
|
||||
|
||||
def visualize_loader_by_class(dl, ax=None, size=8, normalize=True, scale_each=False, reshape=None):
|
||||
for xb, yb in dl: break
|
||||
if reshape: xb = xb.reshape(len(xb), *reshape)
|
||||
|
||||
classes = list(set(list(yb.numpy())))
|
||||
fig, axs = plt.subplots(len(classes), 1, figsize=(15, 3*len(classes)))
|
||||
|
||||
for y, ax in zip(classes, axs):
|
||||
xb_ = xb[yb==y]
|
||||
ax = visualize_tensors(xb_, size=size, normalize=normalize, scale_each=scale_each, permute=True, ax=ax)
|
||||
ax.set_title('Class: {}'.format(y))
|
||||
|
||||
return fig
|
||||
|
||||
def visualize_perturbations(P, transform_fn=None):
|
||||
if transform_fn is not None:
|
||||
P = transform_fn(P)
|
||||
plt.figure(figsize=(20,4))
|
||||
s = np.random.choice(len(P), size=8, replace=False)
|
||||
p = P[s]
|
||||
g = torchvision.utils.make_grid(torch.FloatTensor(p))
|
||||
g = g.permute(1,2,0).numpy()
|
||||
g = (g-g.min())/(g.max()-g.min())  # rescale to [0, 1]
|
||||
plt.imshow(g)
|
||||
|
||||
def get_logits_given_tensor(X, model, device=None, bs=250, softmax=False):
|
||||
if device is None: device = gu.get_device(None)
|
||||
sampler = torch.utils.data.SequentialSampler(X)
|
||||
sampler = torch.utils.data.BatchSampler(sampler, bs, False)
|
||||
|
||||
logits = []
|
||||
|
||||
with torch.no_grad():
|
||||
model = model.to(device)
|
||||
for idx in sampler:
|
||||
xb = X[idx].to(device)
|
||||
out = model(xb)
|
||||
logits.append(out)
|
||||
|
||||
L = torch.cat(logits)
|
||||
if softmax: return F.softmax(L, 1)
|
||||
return L
|
||||
|
||||
def get_predictions_given_tensor(X, model, device=None, bs=250):
|
||||
out = get_logits_given_tensor(X, model, device=device, bs=bs)
|
||||
return torch.argmax(out, 1)
|
||||
|
||||
def get_accuracy_given_tensor(X, Y, model, device=None, bs=250):
|
||||
if device is None: device = gu.get_device(None)
|
||||
Y = torch.LongTensor(Y).to(device)
|
||||
yhat = get_predictions_given_tensor(X, model, device=device, bs=bs)
|
||||
return (Y==yhat).float().mean().item()
|
||||
|
||||
def compute_accuracy(X, Y, model):
|
||||
with torch.no_grad():
|
||||
pred = torch.argmax(model(X),1)
|
||||
correct = (pred == Y).sum().item()
|
||||
accuracy = correct/float(len(Y))
|
||||
return accuracy
|
||||
|
||||
def compute_loss_and_accuracy_from_dl(dl, model, loss_fn, sample_pct=1.0, device=None, transform_fn=None):
|
||||
in_tr_mode = model.training
|
||||
model = model.eval()
|
||||
data_size = float(len(dl.dataset))
|
||||
samp_size = int(np.ceil(sample_pct*data_size))
|
||||
num_eval = 0.
|
||||
bs = dl.batch_size
|
||||
accs, losses, bss = [], [], []
|
||||
|
||||
with torch.no_grad():
|
||||
for xb, yb in dl:
|
||||
xb, yb = xb.to(device, non_blocking=False), yb.to(device, non_blocking=False)
|
||||
|
||||
if transform_fn:
|
||||
xb, yb = transform_fn(xb, yb)
|
||||
|
||||
sc = model(xb)
|
||||
|
||||
if loss_fn is F.cross_entropy:
|
||||
loss = loss_fn(sc, yb, reduction='mean')
|
||||
pred = torch.argmax(sc, 1)
|
||||
elif loss_fn is F.binary_cross_entropy_with_logits:
|
||||
loss = loss_fn(sc, yb.float().unsqueeze(1))
|
||||
pred = (sc > 0.).long().squeeze()
|
||||
elif loss_fn is hinge_loss:
|
||||
loss = loss_fn(sc, yb)
|
||||
pred = (sc > 0).long().squeeze()
|
||||
else:
|
||||
try:
|
||||
loss = loss_fn(sc, yb)
|
||||
pred = torch.argmax(sc, 1)
|
||||
except:
|
||||
assert False, "unknown loss function"
|
||||
|
||||
correct = (pred==yb).sum().float()
|
||||
n = float(len(xb))
|
||||
losses.append(loss)
|
||||
accs.append(correct/n)
|
||||
bss.append(n)
|
||||
|
||||
num_eval += n
|
||||
if num_eval >= samp_size: break
|
||||
|
||||
accs, losses, bss = map(np.array, [accs, losses, bss])
|
||||
if in_tr_mode: model = model.train()
|
||||
return np.sum(bss*accs)/num_eval, np.sum(bss*losses)/num_eval
|
||||
|
||||
def count_parameters(model):
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
def get_logits(model, loader, device):
|
||||
S, Y = [], []
|
||||
with torch.no_grad():
|
||||
for xb, yb in loader:
|
||||
xb = xb.to(device)
|
||||
out = model(xb).cpu().numpy()
|
||||
S.append(out)
|
||||
Y.append(list(yb))
|
||||
S, Y = map(np.concatenate, [S, Y])
|
||||
return S, Y
|
||||
|
||||
def get_scores(model, loader, device):
|
||||
"""binary tasks only"""
|
||||
S, Y = get_logits(model, loader, device)
|
||||
return S[:,1]-S[:,0], Y
|
||||
|
||||
def get_multiclass_logit_score(L, Y):
|
||||
scores = []
|
||||
for idx, (l, y) in enumerate(zip(L, Y)):
|
||||
sc_y = l[y]
|
||||
|
||||
indices = np.argsort(l)
|
||||
best2_idx, best1_idx = indices[-2:]
|
||||
sc_max = l[best2_idx] if y == best1_idx else l[best1_idx]
|
||||
|
||||
score = sc_y - sc_max
|
||||
scores.append(score)
|
||||
|
||||
return np.array(scores)
|
||||
|
||||
def get_binary_auc(model, loader, device):
|
||||
S, Y = get_scores(model, loader, device)
|
||||
return roc_auc_score(Y, S)
|
||||
|
||||
def get_multiclass_auc(model, loader, device, one_vs_rest=True):
|
||||
X, Y = extract_tensors_from_loader(loader)
|
||||
S = get_logits_given_tensor(X, model, device=device, softmax=True).cpu()
|
||||
mc = 'ovr' if one_vs_rest is True else 'ovo'
|
||||
S, Y = S.numpy(), Y.numpy()
|
||||
return roc_auc_score(Y, S, multi_class=mc)
|
||||
|
||||
def clip_gradient(model, clip_value):
|
||||
params = list(filter(lambda p: p.grad is not None, model.parameters()))
|
||||
for p in params: p.grad.data.clamp_(-clip_value, clip_value)
|
||||
|
||||
def print_model_gradients(model, print_bias=True):
|
||||
for name, params in model.named_parameters():
|
||||
if not print_bias and 'bias' in name: continue
|
||||
if not params.requires_grad: continue
|
||||
avg_grad = np.mean(params.grad.cpu().numpy())
|
||||
print (name, params.shape, avg_grad)
|
||||
|
||||
def hinge_loss(out, y):
|
||||
y_ = (2*y.float()-1).unsqueeze(1)
|
||||
return torch.mean(F.relu(1-out*y_))
|
||||
|
||||
def pgd_adv_fit_model(model, opt, tr_dl, te_dl, attack, eval_attack=None, device=None, sch=None, max_epochs=100, epoch_gap=2,
|
||||
min_loss=0.001, print_info=True, save_init_model=True):
|
||||
|
||||
# setup tracking
|
||||
PR = lambda x: print (x) if print_info else None
|
||||
stop_training = False
|
||||
stats = defaultdict(list)
|
||||
best_val, best_model = np.inf, None
|
||||
adv_epoch_timer = []
|
||||
epoch_gap_timer = [time.time()]
|
||||
init_model = copy.deepcopy(model).cpu() if save_init_model else None
|
||||
|
||||
# eval attack
|
||||
eval_attack = eval_attack or attack
|
||||
|
||||
print ("Min loss: {}".format(min_loss))
|
||||
|
||||
def standard_epoch(loader, model, optimizer=None, sch=None):
|
||||
"""compute accuracy and loss. Backprop if optimizer provided"""
|
||||
total_loss, total_err = 0.,0.
|
||||
model = model.eval() if optimizer is None else model.train()
|
||||
model = model.to(device)
|
||||
update_params = optimizer is not None
|
||||
|
||||
with torch.set_grad_enabled(update_params):
|
||||
for xb, yb in loader:
|
||||
xb, yb = xb.to(device), yb.to(device)
|
||||
yp = model(xb)
|
||||
loss = F.cross_entropy(yp, yb)
|
||||
if update_params:
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_err += (yp.max(dim=1)[1] != yb).sum().item()
|
||||
total_loss += loss.item() * xb.shape[0]
|
||||
return total_err / len(loader.dataset), total_loss / len(loader.dataset)
|
||||
|
||||
def adv_epoch(loader, model, attack, optimizer=None, sch=None):
|
||||
"""compute adv accuracy and loss. Backprop if optimizer provided"""
|
||||
start_time = time.time()
|
||||
total_loss, total_err = 0.,0.
|
||||
model = model.eval() if optimizer is None else model.train()
|
||||
model = model.to(device)
|
||||
update_params = optimizer is not None
|
||||
|
||||
for xb, yb in loader:
|
||||
torch.set_grad_enabled(True)
|
||||
xb, yb = xb.to(device), yb.to(device)
|
||||
delta = attack.perturb(xb, yb, model).to(device)
|
||||
xb = xb + delta
|
||||
with torch.set_grad_enabled(update_params):
|
||||
yp = model(xb)
|
||||
loss = F.cross_entropy(yp, yb)
|
||||
if update_params:
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_err += (yp.max(dim=1)[1] != yb).sum().item()
|
||||
total_loss += loss.item() * xb.shape[0]
|
||||
|
||||
if optimizer is not None and sch is not None:
|
||||
cur_lr = next(iter(opt.param_groups))['lr']
|
||||
sch.step()
|
||||
new_lr = next(iter(opt.param_groups))['lr']
|
||||
if new_lr != cur_lr:
|
||||
PR('Epoch {}, LR : {} -> {}'.format(epoch, cur_lr, new_lr))
|
||||
|
||||
total_time = time.time()-start_time
|
||||
adv_epoch_timer.append(total_time)
|
||||
return total_err / len(loader.dataset), total_loss / len(loader.dataset)
|
||||
|
||||
epoch = 0
|
||||
while epoch < max_epochs:
|
||||
if stop_training:
|
||||
break
|
||||
try:
|
||||
stat = {}
|
||||
model = model.train()
|
||||
train_err, train_loss = adv_epoch(tr_dl, model, attack, optimizer=opt, sch=sch)
|
||||
|
||||
if epoch % epoch_gap == 0:
|
||||
model = model.eval()
|
||||
test_err, test_loss = standard_epoch(te_dl, model, optimizer=None, sch=None)
|
||||
adv_err, adv_loss = adv_epoch(te_dl, model, eval_attack, optimizer=None, sch=None)
|
||||
stat['acc_te'], stat['acc_te_std'] = adv_err, test_err
|
||||
stat['loss_te'], stat['loss_te_std'] = adv_loss, test_loss
|
||||
|
||||
if adv_err < best_val:
|
||||
best_val = adv_err
|
||||
best_model = copy.deepcopy(model).eval()
|
||||
|
||||
if print_info:
|
||||
if epoch==0: print ("Epoch", "l-tr", "a-tr", "a-te", "s-te", "time", sep='\t')
|
||||
#print (epoch, *("{:.4f}".format(i) for i in (train_loss, train_err)), sep=' ')
|
||||
diff_time = time.time()-epoch_gap_timer[-1]
|
||||
epoch_gap_timer.append(time.time())
|
||||
print (epoch, *("{:.4f}".format(i) for i in (train_loss, 1.-train_err, 1.-adv_err, 1.-test_err, diff_time)), sep=' ')
|
||||
|
||||
if train_loss < min_loss:
|
||||
stop_training = True
|
||||
|
||||
print ("Epoch {}: accuracy {:.3f} and loss {:.3f}".format(epoch, 1-train_err, train_loss))
|
||||
|
||||
stat['epoch'] = epoch
|
||||
stat['acc_tr'] = train_err
|
||||
stat['loss_tr'] = train_loss
|
||||
|
||||
for k, v in stat.items():
|
||||
stats[k].append(v)
|
||||
|
||||
epoch += 1
|
||||
|
||||
except KeyboardInterrupt:
|
||||
inp = input("LR num or Q or SAVE or GAP or MAXEPOCHS: ")
|
||||
if inp.startswith('LR'):
|
||||
lr = float(inp.split(' ')[-1])
|
||||
cur_lr = next(iter(opt.param_groups))['lr']
|
||||
PR("New LR: {}".format(lr))
|
||||
for g in opt.param_groups: g['lr'] = lr
|
||||
if inp.startswith('Q'):
|
||||
stop_training = True
|
||||
if inp.startswith('SAVE'):
|
||||
fpath = inp.split(' ')[-1]
|
||||
stats['best_model'] = (best_val, best_model.cpu())
|
||||
torch.save({
|
||||
'model': copy.deepcopy(model).cpu(),
|
||||
'stats': stats,
|
||||
'opt': copy.deepcopy(opt.state_dict())  # optimizers have no .cpu(); save the state dict
|
||||
}, fpath)
|
||||
PR(f'Saved to {fpath}')
|
||||
if inp.startswith('GAP'):
|
||||
_, gap = inp.split(' ')
|
||||
gap = int(gap)
|
||||
print ("epoch gap: {} -> {}".format(epoch_gap, gap))
|
||||
epoch_gap = gap
|
||||
if inp.startswith('MAXEPOCHS'):
|
||||
_, me = inp.split(' ')
|
||||
me = int(me)
|
||||
print ("max_epochs: {} -> {}".format(max_epochs, me))
|
||||
max_epochs = me
|
||||
|
||||
stats['best_model'] = (best_val, best_model.cpu())
|
||||
stats['init_model'] = init_model
|
||||
return stats
|
||||
|
||||
|
||||
def fit_model(model, loss, opt, train_dl, valid_dl, sch=None, epsilon=1e-2, is_loss_epsilon=False, update_gap=50, update_print_gap=50, gap=None,
|
||||
print_info=True, save_grads=False, test_dl=None, skip_epoch_eval=True, sample_pct=0.5, sample_loss_threshold=0.75, save_models=False,
|
||||
print_grads=False, print_model_layers=False, tr_batch_fn=None, te_batch_fn=None, device=None, max_updates=800_000, patience_updates=1,
|
||||
enable_redo=False, save_best_model=True, save_init_model=True, max_epochs=100000, **misc):
|
||||
|
||||
# setup update metadata
|
||||
MAX_LOSS_VAL = 1000000.
|
||||
PR = lambda x: print (x) if print_info else None
|
||||
use_epoch = False
|
||||
if gap is not None: update_gap = update_print_gap = gap
|
||||
bs_ratio = int(len(train_dl.dataset)/float(train_dl.batch_size))
|
||||
act_update_gap = update_gap if not use_epoch else update_gap*bs_ratio
|
||||
act_pr_update_gap = update_print_gap if not use_epoch else update_print_gap*bs_ratio
|
||||
PR("accuracy/loss measured every {} updates".format(act_update_gap))
|
||||
|
||||
if save_models:
|
||||
PR("saving models every {} updates".format(act_update_gap))
|
||||
|
||||
PR("update_print_gap: {}, epss: {}, bs: {}, device: {}".format(act_pr_update_gap, epsilon, train_dl.batch_size, device or 'cpu'))
|
||||
|
||||
# init_save setup
|
||||
init_model = copy.deepcopy(model).cpu() if save_init_model else None
|
||||
|
||||
# redo setup
|
||||
if enable_redo:
|
||||
init_model_sd = copy.deepcopy(model.state_dict())
|
||||
init_opt_sd = copy.deepcopy(opt.state_dict())
|
||||
else:
|
||||
init_model_sd = None
|
||||
init_opt_sd = None
|
||||
|
||||
# best model setup
|
||||
best_val, best_model = 0, None
|
||||
|
||||
# tracking setup
|
||||
start_time = time.time()
|
||||
num_evals, num_epochs, num_updates, num_patience = 0, 0, 0, 0
|
||||
stats = dict(loss_tr=[], loss_te=[], acc_tr=[], acc_te=[], acc_test=[], loss_test=[], models=[], gradients=[])
|
||||
if save_models: stats['models'].append(copy.deepcopy(model).cpu())
|
||||
first_run, converged = True, False
|
||||
print_stats_flag = update_print_gap is not None
|
||||
exceeded_max = False
|
||||
diverged = False
|
||||
|
||||
def _evaluate(device=device):
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
prev_loss = stats['loss_tr'][-1] if stats['loss_tr'] else 1.
|
||||
tr_sample_pct = sample_pct if prev_loss > sample_loss_threshold else 1.
|
||||
acc_tr, loss_tr = compute_loss_and_accuracy_from_dl(train_dl,model,loss,sample_pct=tr_sample_pct,device=device,transform_fn=tr_batch_fn)
|
||||
acc_te, loss_te = compute_loss_and_accuracy_from_dl(valid_dl,model,loss,sample_pct=1.,device=device,transform_fn=te_batch_fn)
|
||||
acc_tr, loss_tr, acc_te, loss_te = map(lambda x: x.item(), [acc_tr, loss_tr, acc_te, loss_te])
|
||||
stats['loss_tr'].append(loss_tr)
|
||||
stats['loss_te'].append(loss_te)
|
||||
stats['acc_tr'].append(acc_tr)
|
||||
stats['acc_te'].append(acc_te)
|
||||
|
||||
if test_dl is not None:
|
||||
acc_test, loss_test = compute_loss_and_accuracy_from_dl(test_dl,model,loss,sample_pct=1.,device=device,transform_fn=te_batch_fn)
|
||||
acc_test, loss_test = acc_test.item(), loss_test.item()
|
||||
stats['acc_test'].append(acc_test)
|
||||
stats['loss_test'].append(loss_test)
|
||||
|
||||
if save_models:
|
||||
stats['models'].append(copy.deepcopy(model).cpu())
|
||||
|
||||
def _update(x,y,diff_device, device=device, save_grads=False, print_grads=False):
|
||||
model.train()
|
||||
|
||||
# if diff_device:
|
||||
# x = x.to(device, non_blocking=False)
|
||||
# y = y.to(device, non_blocking=False)
|
||||
|
||||
opt.zero_grad()
|
||||
out = model(x)
|
||||
if loss is F.cross_entropy or loss is hinge_loss:
|
||||
bloss = loss(out, y)
|
||||
elif loss is F.binary_cross_entropy_with_logits:
|
||||
bloss = loss(out, y.float().unsqueeze(1))
|
||||
else:
|
||||
try:
|
||||
bloss = loss(out, y)
|
||||
except:
|
||||
assert False, "unknown loss function"
|
||||
|
||||
bloss.backward()
|
||||
if print_grads and print_info: print_model_gradients(model)
|
||||
#clip_gradient(model, clip_value)
|
||||
opt.step()
|
||||
|
||||
if save_grads:
|
||||
g = {k: v.grad.data.cpu().numpy() for k, v in model.named_parameters() if v.requires_grad}
|
||||
stats['gradients'].append(g)
|
||||
|
||||
opt.zero_grad()
|
||||
model.eval()
|
||||
|
||||
def print_time():
|
||||
end_time = time.time()
|
||||
minutes, seconds = divmod(end_time-start_time, 60)
|
||||
gap_valid = len(stats['acc_tr']) > 0
|
||||
gap = round(stats['acc_tr'][-1]-stats['acc_te'][-1],4) if gap_valid else 'na'
|
||||
PR("converged after {} epochs in {}m {:1f}s, gap: {}".format(num_epochs, minutes, seconds, gap))
|
||||
|
||||
def print_stats(force_print=False):
|
||||
|
||||
if test_dl is None:
|
||||
atr, ate, ltr = [stats[k][-1] for k in ['acc_tr', 'acc_te', 'loss_tr']]
|
||||
PR("{} {:.4f} {:.4f} {:.4f}".format(num_updates, atr, ate, ltr))
|
||||
if not print_info and force_print:
|
||||
print ("{} {:.4f} {:.4f} {:.4f}".format(num_updates, atr, ate, ltr))
|
||||
else:
|
||||
atr, aval, ate, ltr = [stats[k][-1] for k in ['acc_tr', 'acc_te', 'acc_test', 'loss_tr']]
|
||||
PR("{} {:.4f} {:.4f} {:.4f} {:.4f}".format(num_updates, atr, aval, ate, ltr))
|
||||
if not print_info and force_print:
|
||||
print ("{} {:.4f} {:.4f} {:.4f} {:.4f}".format(num_updates, atr, aval, ate, ltr))
|
||||
|
||||
#xb_, yb_ = next(iter(train_dl))
|
||||
diff_device = True #xb_.device != device
|
||||
|
||||
if test_dl is None: PR("#updates, train acc, test acc, train loss")
|
||||
else: PR("#updates, train acc, val acc, test acc, train loss")
|
||||
|
||||
while not converged or num_patience < patience_updates:
|
||||
try:
|
||||
model.train()
|
||||
for xb, yb in train_dl:
|
||||
|
||||
if tr_batch_fn:
|
||||
xb, yb = tr_batch_fn(xb, yb)
|
||||
|
||||
if diff_device:
|
||||
xb = xb.to(device, non_blocking=False)
|
||||
yb = yb.to(device, non_blocking=False)
|
||||
|
||||
if converged:
|
||||
num_patience += 1
|
||||
|
||||
if converged and num_patience == patience_updates:
|
||||
_evaluate()
|
||||
print_stats()
|
||||
break
|
||||
|
||||
# update flag for printing gradients
|
||||
update_flag = print_model_layers and (num_updates == 0 or (num_updates % act_update_gap == 0 and print_grads))
|
||||
_update(xb, yb, diff_device, device=device, save_grads=save_grads, print_grads=update_flag)
|
||||
|
||||
if (num_evals == 0 or num_updates % act_update_gap == 0):
|
||||
num_evals += 1
|
||||
_evaluate()
|
||||
print_stats()
|
||||
|
||||
val_acc = stats['acc_te'][-1]
|
||||
if num_updates > 0 and val_acc >= best_val:
|
||||
best_val = val_acc
|
||||
best_model = copy.deepcopy(model).eval()
|
||||
|
||||
# check if loss has diverged
|
||||
loss_val = max(stats['loss_tr'][-1], stats['loss_te'][-1])
|
||||
if loss_val > MAX_LOSS_VAL: diverged = True
|
||||
if not np.isfinite(loss_val): diverged = True
|
||||
|
||||
|
||||
if is_loss_epsilon: stop = stats['loss_tr'][-1] < epsilon
|
||||
else: stop = stats['acc_tr'][-1] >= 1-epsilon
|
||||
|
||||
if not converged and diverged:
|
||||
converged = True
|
||||
print_time()
|
||||
PR("loss diverging...exiting".format(patience_updates))
|
||||
|
||||
if not converged and stop:
|
||||
converged = True
|
||||
print_time()
|
||||
PR("init-ing patience ({} updates)".format(patience_updates))
|
||||
|
||||
num_updates += 1
|
||||
first_run = False
|
||||
|
||||
if num_updates > max_updates:
|
||||
converged = True
|
||||
exceeded_max = True
|
||||
num_patience = patience_updates
|
||||
PR("Exceeded max updates")
|
||||
print_stats()
|
||||
print_time()
|
||||
break
|
||||
|
||||
# re-eval at the end of epoch
|
||||
if not converged:
|
||||
num_epochs += 1
|
||||
|
||||
if not converged and num_epochs >= max_epochs:
|
||||
converged = True
|
||||
exceeded_max = True
|
||||
num_patience = patience_updates
|
||||
PR("Exceeded max epochs")
|
||||
print_stats()
|
||||
print_time()
|
||||
break
|
||||
|
||||
if not skip_epoch_eval:
|
||||
_evaluate()
|
||||
print_stats()
|
||||
|
||||
if is_loss_epsilon: stop = stats['loss_tr'][-1] < epsilon
|
||||
else: stop = stats['acc_tr'][-1] >= 1-epsilon
|
||||
|
||||
if not converged and stop:
|
||||
converged = True
|
||||
print_time()
|
||||
PR("init-ing patience ({} updates)".format(patience_updates))
|
||||
|
||||
if num_patience >= patience_updates:
|
||||
_evaluate()
|
||||
print_stats()
|
||||
break
|
||||
|
||||
# update LR via scheduler
|
||||
if sch is not None:
|
||||
cur_lr = next(iter(opt.param_groups))['lr']
|
||||
sch.step()
|
||||
new_lr = next(iter(opt.param_groups))['lr']
|
||||
if new_lr != cur_lr:
|
||||
PR('Epoch {}, LR : {} -> {}'.format(num_epochs, cur_lr, new_lr))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
inp = input("LR num or Q or GAP num or SAVE fpath or EVAL or REDO: ")
|
||||
if inp.startswith('LR'):
|
||||
lr = float(inp.split(' ')[-1])
|
||||
cur_lr = next(iter(opt.param_groups))['lr']
|
||||
PR("LR: {} - > {}".format(cur_lr, lr))
|
||||
for g in opt.param_groups: g['lr'] = lr
|
||||
elif inp.startswith('GAP'):
|
||||
gap = int(inp.split(' ')[-1])
|
||||
act_update_gap = act_pr_update_gap = gap
|
||||
elif inp == "Q":
|
||||
converged = True
|
||||
num_patience = patience_updates
|
||||
print_time()
|
||||
elif inp.startswith('SAVE'):
|
||||
fpath = inp.split(' ')[-1]
|
||||
torch.save({
|
||||
'model': model,
|
||||
'opt': opt,
|
||||
'update_gap': update_gap
|
||||
}, fpath)
|
||||
elif inp == 'EVAL':
|
||||
_evaluate()
|
||||
print_stats(True)
|
||||
elif inp == 'REDO':
|
||||
if enable_redo:
|
||||
model.load_state_dict(init_model_sd)
|
||||
opt.load_state_dict(init_opt_sd)
|
||||
else:
|
||||
print ("REDO disabled")
|
||||
|
||||
best_test = None
|
||||
if test_dl is not None:
|
||||
best_test = compute_loss_and_accuracy_from_dl(test_dl, best_model, loss, sample_pct=1.0, device=device)[0].item()
|
||||
|
||||
stats['num_updates'] = num_updates
|
||||
stats['num_epochs'] = num_epochs
|
||||
stats['update_gap'] = update_gap
|
||||
|
||||
stats['best_model'] = (best_val, best_test, best_model.cpu() if best_model else model.cpu())
|
||||
stats['init_model'] = init_model
|
||||
if save_models: stats['models'].append(copy.deepcopy(model).cpu())
|
||||
|
||||
stats['x_updates']= list(range(0, num_evals*(update_gap+1), update_gap))
|
||||
stats['x'] = stats['x_updates'][:]
|
||||
stats['x_epochs'] = list(range(num_epochs))
|
||||
stats['gap'] = stats['acc_tr'][-1]-stats['acc_te'][-1]
|
||||
return stats
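
# Usage sketch (illustrative addition; model_demo and the loaders are assumed
# stand-ins, and the optimizer settings mirror the slab config used in this commit):
def _demo_fit_model(model_demo, tr_dl_demo, va_dl_demo, device_demo):
    opt_demo = optim.SGD(model_demo.parameters(), lr=0.1, weight_decay=5e-5)
    stats = fit_model(model_demo, F.cross_entropy, opt_demo, tr_dl_demo, va_dl_demo,
                      epsilon=1e-2, gap=200, device=device_demo, max_epochs=100)
    best_val, best_test, best_model = stats['best_model']
    return best_val, stats['gap']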
|
||||
|
||||
def save_pickle(fname, d, mode='w'):
|
||||
with open(fname, mode) as f:
|
||||
pickle.dump(d, f)
|
||||
|
||||
def load_pickle(fname, mode='r'):
|
||||
with open(fname, mode) as f:
|
||||
return pickle.load(f)
|
||||
|
||||
def update_ax(ax, title=None, xlabel=None, ylabel=None, legend_loc='best', ticks=True, ticks_fs=10, label_fs=12, legend_fs=12, title_fs=14, hide_xlabels=False, hide_ylabels=False, despine=True):
|
||||
if title: ax.set_title(title, fontsize=title_fs)
|
||||
if xlabel: ax.set_xlabel(xlabel, fontsize=label_fs)
|
||||
if ylabel: ax.set_ylabel(ylabel, fontsize=label_fs)
|
||||
if legend_loc: ax.legend(loc=legend_loc, fontsize=legend_fs)
|
||||
if despine: sns.despine(ax=ax)
|
||||
|
||||
if ticks:
|
||||
# ax.minorticks_on()
|
||||
ax.tick_params(direction='in', length=6, width=2, colors='k', which='major', top=False, right=False)
|
||||
ax.tick_params(direction='in', length=4, width=1, colors='k', which='minor', top=False, right=False)
|
||||
ax.tick_params(labelsize=ticks_fs)
|
||||
|
||||
if hide_xlabels: ax.set_xticks([])
|
||||
if hide_ylabels: ax.set_yticks([])
|
||||
return ax
|
|
@@ -0,0 +1,92 @@
|
|||
import sys
|
||||
import random
|
||||
import os, copy, pickle, time
|
||||
import argparse
|
||||
import itertools
|
||||
from collections import defaultdict, Counter, OrderedDict
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import numpy as np
|
||||
import seaborn as sns
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import optim, nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import TensorDataset, DataLoader
|
||||
from torch.autograd import Variable
|
||||
import pandas as pd
|
||||
|
||||
import utils.scripts.gpu_utils as gu
|
||||
import utils.scripts.gendata as gendata
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.enabled = True
|
||||
|
||||
DEVICE_ID = 0 # GPU_ID or None (CPU)
|
||||
DEVICE = gu.get_device(DEVICE_ID)
|
||||
|
||||
def get_data(num_samples, spur_corr, total_slabs):
|
||||
|
||||
if total_slabs== 5:
|
||||
slab_5_flag= 1
|
||||
slab_7_flag= 0
|
||||
elif total_slabs== 7:
|
||||
slab_5_flag= 0
|
||||
slab_7_flag= 1
|
||||
|
||||
c = config = {
|
||||
'num_train': num_samples, # training dataset size
|
||||
'dim': 2, # input dimension
|
||||
'lin_margin': 0.1, # linear margin
|
||||
'slab_margin': 0.1, # slab margin,
|
||||
'same_margin': True, # keep same margin
|
||||
'random_transform': True, # apply random (orthogonal) transformation
|
||||
'width': 1, # data width in standard basis
|
||||
'num_lin': 1, # number of linear components
|
||||
'num_slabs': slab_5_flag, #. number of 5 slabs
|
||||
'num_slabs7': slab_7_flag, # number of 7 slabs
|
||||
'num_slabs3': 0, # number of 3 slabs
|
||||
'bs': 256, # batch size
|
||||
'corrupt_lin': spur_corr, # p_noise
|
||||
'corrupt_lin_margin': True, # noise model
|
||||
'corrupt_slab': 0.0, # slab corruption
|
||||
'num_test': 0, # test dataset size
|
||||
'hdim': 100, # model width
|
||||
'hl': 2, # model depth
|
||||
'mtype': 'fcn', # model architecture
|
||||
'device': gu.get_device(DEVICE_ID), # GPU device
|
||||
'lr': 0.1, # step size
|
||||
'weight_decay': 5e-5 # weight decay
|
||||
}
|
||||
|
||||
smargin = c['lin_margin'] if c['same_margin'] else c['slab_margin']
|
||||
data_func = gendata.generate_ub_linslab_data_v2
|
||||
spc = [3]*c['num_slabs3']+[5]*c['num_slabs'] + [7]*c['num_slabs7']
|
||||
data = data_func(c['num_train'], c['dim'], c['lin_margin'], slabs_per_coord=spc, eff_slab_margin=smargin, random_transform=c['random_transform'], N_te=c['num_test'],
|
||||
corrupt_lin_margin=c['corrupt_lin_margin'], num_lin=c['num_lin'], num_slabs=c['num_slabs3']+c['num_slabs']+c['num_slabs7'], width=c['width'], bs=c['bs'], corrupt_lin=c['corrupt_lin'], corrupt_slab=c['corrupt_slab'])
|
||||
|
||||
|
||||
#Project data
|
||||
W = data['W']
|
||||
X, Y = data['X'], data['Y']
|
||||
X = X.numpy().dot(W.T)
|
||||
Y = Y.numpy()
|
||||
|
||||
#Generate objects/slab_id array
|
||||
# Equation: k*slab_len + (k-1)*0.2 = 2: k is total number of slabs
|
||||
# Slabs: [-1, -1 + slab_len], [-1+slab_len+0.2, -1 + 0.2 + 2*slab_len ]
|
||||
# Slabs: For i in range(0, k): [-1 + i*(0.2 + slab_len), -1 + slab_len + i*(0.2 + slab_len)]
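# Worked example (added for illustration): with k=5 slabs,
# slab_len = (2 - 0.2*4)/5 = 0.24, giving slab intervals
# [-1.00, -0.76], [-0.56, -0.32], [-0.12, 0.12], [0.32, 0.56], [0.76, 1.00]
# along X[:, 1], separated by gaps of width 0.2; O[idx] records which of
# these k intervals contains X[idx, 1].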
|
||||
|
||||
k=total_slabs
|
||||
slab_len= (2 - 0.2*(k-1))/k
|
||||
O=np.zeros(X.shape[0])
|
||||
for idx in range(X.shape[0]):
|
||||
for slab_id in range(k):
|
||||
start= -1 + slab_id*(0.2 + slab_len)
|
||||
end= start+ slab_len
|
||||
if X[idx, 1] >= start and X[idx, 1] <= end:
|
||||
O[idx]= slab_id
|
||||
break
|
||||
|
||||
return data, X, Y, O
|
||||
|