Slab Synthetic Data integrated with RobustDG

divyat09 2021-03-17 00:41:15 +00:00
Parent 514a3d92c8
Commit 7ee0caa5a9
18 changed files: 2665 additions and 5 deletions


@ -53,6 +53,10 @@ class BaseAlgo():
from models.lenet import LeNet5
phi= LeNet5()
if self.args.model_name == 'slab':
from models.slab import SlabClf
phi= SlabClf(self.args.slab_data_dim, self.args.out_classes)
if self.args.model_name == 'fc':
from models.fc import FC
if self.args.method_name in ['csd', 'matchdg_ctr']:
@ -148,7 +152,6 @@ class BaseAlgo():
with torch.no_grad():
x_e= x_e.to(self.cuda)
y_e= torch.argmax(y_e, dim=1).to(self.cuda)
d_e = torch.argmax(d_e, dim=1).numpy()
#Forward Pass
out= self.phi(x_e)

algorithms/perf_match.py Normal file

@ -0,0 +1,121 @@
import sys
import numpy as np
import argparse
import copy
import random
import json
import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils
from .algo import BaseAlgo
from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity, slab_batch_process
class PerfMatch(BaseAlgo):
def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda):
super().__init__(args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda)
def train(self):
self.max_epoch=-1
self.max_val_acc=0.0
for epoch in range(self.args.epochs):
penalty_erm=0
penalty_ws=0
train_acc= 0.0
train_size=0
#Batch iteration over single epoch
for batch_idx, (x_e, y_e ,d_e, idx_e) in enumerate(self.train_dataset):
# print('Batch Idx: ', batch_idx)
#Process current batch as per slab dataset
# x_e, y_e ,d_e, idx_e= slab_batch_process(x_e, y_e ,d_e, idx_e)
self.opt.zero_grad()
loss_e= torch.tensor(0.0).to(self.cuda)
x_e= x_e.to(self.cuda)
y_e= torch.argmax(y_e, dim=1).to(self.cuda)
#Forward Pass
out= self.phi(x_e)
erm_loss= F.cross_entropy(out, y_e.long()).to(self.cuda)
loss_e+= erm_loss
penalty_erm += float(loss_e)
#Perfect Match Penalty
ws_loss= torch.tensor(0.0).to(self.cuda)
counter=0
match_objs= np.unique(idx_e)
feat= self.phi.feat_net(x_e)
for obj in match_objs:
indices= idx_e == obj
feat_obj= feat[indices]
d_obj= d_e[indices]
match_domains= torch.unique(d_obj)
if len(match_domains) != len(torch.unique(d_e)):
# print('Error: Positivity Violation, objects not present in all the domains')
continue
for d_i in range(len(match_domains)):
for d_j in range(len(match_domains)):
if d_j <= d_i:
continue
x1= feat_obj[ d_obj == d_i ]
x2= feat_obj[ d_obj == d_j ]
#Typecasting
# print(x1.shape, x2.shape)
x1= x1.view(x1.shape[0], 1, x1.shape[1])
ws_loss+= torch.sum( torch.sum( torch.sum( (x1 -x2)**2, dim=2), dim=1 ) )
# ws_loss+= torch.sum( torch.sum( torch.sum( torch.abs(x1 -x2), dim=2), dim=1 ) )
counter+= x1.shape[0]*x2.shape[0]
ws_loss= ws_loss/counter
penalty_ws += float(ws_loss)
#Backprop
loss_e+= self.args.penalty_ws*ws_loss*((epoch+1)/self.args.epochs)
loss_e.backward(retain_graph=False)
self.opt.step()
del erm_loss
del loss_e
torch.cuda.empty_cache()
train_acc+= torch.sum(torch.argmax(out, dim=1) == y_e ).item()
train_size+= y_e.shape[0]
print('Train Loss Basic : ', penalty_erm, penalty_ws )
print('Train Acc Env : ', 100*train_acc/train_size )
print('Done Training for epoch: ', epoch)
#Train Dataset Accuracy
self.train_acc.append( 100*train_acc/train_size )
#Val Dataset Accuracy
self.val_acc.append( self.get_test_accuracy('val') )
#Test Dataset Accuracy
self.final_acc.append( self.get_test_accuracy('test') )
#Save the model if the current epoch is the best so far as per validation accuracy
if self.val_acc[-1] > self.max_val_acc:
self.max_val_acc=self.val_acc[-1]
self.max_epoch= epoch
self.save_model()
# Save the model's weights post training
self.save_model()
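
A minimal, self-contained sketch of the perfect-match penalty computed in the training loop above: for every matched object index that occurs in all training domains, squared L2 distances between its feat_net representations are summed over domain pairs and normalised by the number of pairs (the tensors below are toy stand-ins, not the repository's loader output).

import torch

def match_penalty(feat, idx, d):
    """feat: (N, F) features, idx: (N,) object ids, d: (N,) domain ids."""
    ws_loss, counter = feat.new_tensor(0.0), 0
    for obj in torch.unique(idx):
        sel = idx == obj
        feat_obj, d_obj = feat[sel], d[sel]
        domains = torch.unique(d_obj)
        if len(domains) != len(torch.unique(d)):
            continue  # positivity violation: object missing from some domain
        for i in range(len(domains)):
            for j in range(i + 1, len(domains)):
                x1 = feat_obj[d_obj == domains[i]].unsqueeze(1)  # (n1, 1, F)
                x2 = feat_obj[d_obj == domains[j]]               # (n2, F)
                ws_loss = ws_loss + ((x1 - x2) ** 2).sum()       # broadcast over all pairs
                counter += x1.shape[0] * x2.shape[0]
    return ws_loss / max(counter, 1)

feat = torch.randn(6, 4)                # 3 matched objects x 2 domains, 4-dim features
idx = torch.tensor([0, 0, 1, 1, 2, 2])
d = torch.tensor([0, 1, 0, 1, 0, 1])
print(match_penalty(feat, idx, d))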

data/slab_loader.py Normal file

@ -0,0 +1,102 @@
#Common imports
import os
import random
import copy
import numpy as np
import h5py
from PIL import Image
#Pytorch
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms
#Base Class
from .data_loader import BaseDataLoader
#Specific Modules
from utils.slab_data import *
class SlabData(BaseDataLoader):
def __init__(self, args, list_train_domains, root, transform=None, data_case='train', match_func=False, base_size=10000, freq_ratio=50, data_dim=2, total_slabs=5):
super().__init__(args, list_train_domains, root, transform, data_case, match_func)
self.base_size = base_size
self.freq_ratio= freq_ratio
self.data_dim= data_dim
self.total_slabs= total_slabs
print(list_train_domains)
if self.data_case == 'train':
self.domain_size = [self.base_size, self.base_size]
# Default Train Domains: [0.0, 0.10]
self.spur_probs= [ float(domain) for domain in list_train_domains ]
elif self.data_case == 'val':
self.domain_size = [int(self.base_size/4), int(self.base_size/4)]
self.spur_probs= [ float(domain) for domain in list_train_domains ]
elif self.data_case == 'test':
self.domain_size = [self.base_size]
self.spur_probs= [1.0]
print('\n')
print('Data Case: ', self.data_case)
self.train_data, self.train_labels, self.train_domain, self.train_indices = self._get_data(self.domain_size, self.data_dim, self.total_slabs, self.spur_probs)
def _get_data(self, domain_size, data_dim, total_slabs, spur_probs):
list_data = []
list_labels = []
list_objs= []
total_domains= len(domain_size)
for idx in range(total_domains):
num_samples= domain_size[idx]
spur_prob= spur_probs[idx]
_, data, labels, match_obj= get_data(num_samples, spur_prob, total_slabs)
print('Source Domain: ', idx, ' Size: ', data.shape, labels.shape, match_obj.shape)
list_data.append(torch.tensor(data))
list_labels.append(torch.tensor(labels))
list_objs.append(match_obj)
# Stack data from the different domains
data_feat = torch.cat(list_data)
data_labels = torch.cat(list_labels)
data_objs= np.hstack(list_objs)
# Create domain labels
data_domains = torch.zeros(data_labels.size())
domain_start=0
for idx in range(total_domains):
curr_domain_size= domain_size[idx]
data_domains[ domain_start: domain_start+ curr_domain_size ] += idx
domain_start+= curr_domain_size
# Shuffle everything one more time
# inds = np.arange(data_labels.size()[0])
# np.random.shuffle(inds)
# data_feat = data_feat[inds]
# data_labels = data_labels[inds].long()
# data_domains = data_domains[inds].long()
# Convert to onehot
y = torch.eye(2)
data_labels = y[data_labels]
# # Convert to onehot
# d = torch.eye(len(self.list_train_domains))
# data_domains = d[data_domains]
#Type Casting
data_feat= data_feat.type(torch.FloatTensor)
data_labels = data_labels.long()
data_domains = data_domains.long()
print('Final Dataset: ', data_feat.shape, data_labels.shape, data_domains.shape, data_objs.shape)
return data_feat, data_labels, data_domains, data_objs
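
A toy illustration of the label bookkeeping in _get_data above: domain labels are filled block by block and class labels are one-hot encoded by indexing torch.eye (the sizes here are made up).

import torch

domain_size = [4, 2]                        # two source domains
data_labels = torch.tensor([0, 1, 1, 0, 1, 0])
data_domains = torch.zeros(data_labels.size())
start = 0
for idx, size in enumerate(domain_size):
    data_domains[start:start + size] += idx
    start += size
one_hot = torch.eye(2)[data_labels]         # (6, 2) one-hot class labels
print(data_domains.long())                  # tensor([0, 0, 0, 0, 1, 1])
print(one_hot.long())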


@ -63,6 +63,10 @@ class BaseEval():
from models.lenet import LeNet5
phi= LeNet5()
if self.args.model_name == 'slab':
from models.slab import SlabClf
phi= SlabClf(self.args.slab_data_dim, self.args.out_classes)
if self.args.model_name == 'fc':
from models.fc import FC
if self.args.method_name in ['csd', 'matchdg_ctr']:
@ -111,7 +115,7 @@ class BaseEval():
def load_model(self, run_matchdg_erm):
if self.args.method_name in ['erm_match', 'csd', 'irm']:
if self.args.method_name in ['erm_match', 'csd', 'irm', 'perf_match']:
self.save_path= self.base_res_dir + '/Model_' + self.post_string
elif self.args.method_name == 'matchdg_ctr':

models/slab.py Normal file

@ -0,0 +1,35 @@
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
class SlabClf(nn.Module):
def __init__(self, inp_shape, out_shape):
super(SlabClf, self).__init__()
self.inp_shape = inp_shape
self.out_shape = out_shape
self.hidden_dim = 100
self.feat_net= nn.Sequential(
nn.Linear( self.inp_shape, self.hidden_dim),
nn.ReLU(),
)
self.fc= nn.Sequential(
nn.Linear( self.hidden_dim, self.hidden_dim),
nn.Linear( self.hidden_dim, self.out_shape),
)
self.disc= nn.Sequential(
nn.Linear( self.hidden_dim, self.hidden_dim),
nn.Linear( self.hidden_dim, 2),
)
self.embedding = nn.Embedding(2, self.hidden_dim)
def forward(self, x):
return self.fc(self.feat_net(x))
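
A quick shape check for the classifier above (input sizes are illustrative): feat_net produces the hidden representation compared by the match penalty, while fc and disc are the classification and discriminator heads.

import torch

clf = SlabClf(inp_shape=2, out_shape=2)
x = torch.randn(8, 2)
feat = clf.feat_net(x)    # (8, 100) hidden features used by the match penalty
logits = clf(x)           # (8, 2) class logits from the fc head
print(feat.shape, logits.shape)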


@ -45,6 +45,10 @@ parser.add_argument('--img_h', type=int, default= 224,
help='Height of the image in dataset')
parser.add_argument('--img_w', type=int, default= 224,
help='Width of the image in dataset')
parser.add_argument('--slab_data_dim', type=int, default= 2,
help='Number of features in the slab dataset')
parser.add_argument('--slab_total_slabs', type=int, default=7)
parser.add_argument('--slab_num_samples', type=int, default=1000)
parser.add_argument('--fc_layer', type=int, default= 1,
help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet')
parser.add_argument('--match_layer', type=str, default='logit_match',


@ -42,6 +42,10 @@ parser.add_argument('--img_h', type=int, default= 224,
help='Height of the image in dataset')
parser.add_argument('--img_w', type=int, default= 224,
help='Width of the image in dataset')
parser.add_argument('--slab_data_dim', type=int, default= 2,
help='Number of features in the slab dataset')
parser.add_argument('--slab_total_slabs', type=int, default=7)
parser.add_argument('--slab_num_samples', type=int, default=1000)
parser.add_argument('--fc_layer', type=int, default= 1,
help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet')
parser.add_argument('--match_layer', type=str, default='logit_match',
@ -172,6 +176,13 @@ for run in range(args.n_runs):
test_dataset, base_res_dir,
run, cuda
)
if args.method_name == 'perf_match':
from algorithms.perf_match import PerfMatch
train_method= PerfMatch(
args, train_dataset, val_dataset,
test_dataset, base_res_dir,
run, cuda
)
elif args.method_name == 'matchdg_ctr':
from algorithms.match_dg import MatchDG
ctr_phase=1


@ -14,6 +14,23 @@ from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils
# Slab Dataset: Flatten the tensor along batch and domain axis
# Input of the shape (Batch, Domain, Feat)
def slab_batch_process(x, y, d, o):
if len(x.shape) > 2:
x= x.flatten(start_dim=0, end_dim=1)
if len(y.shape) > 1:
y= y.flatten(start_dim=0, end_dim=1)
if len(d.shape) > 1:
d= d.flatten(start_dim=0, end_dim=1)
if len(o.shape) > 1:
o= o.flatten(start_dim=0, end_dim=1)
return x, y, d, o
def t_sne_plot(X):
X= X.detach().cpu().numpy()
X= TSNE(n_components=2).fit_transform(X)
@ -214,6 +231,8 @@ def get_dataloader(args, run, domains, data_case, eval_case, kwargs):
else:
from data.pacs_loader import PACS
elif args.dataset_name == 'slab':
from data.slab_loader import SlabData
if data_case == 'train':
match_func=True
@ -237,7 +256,10 @@ def get_dataloader(args, run, domains, data_case, eval_case, kwargs):
except AttributeError:
batch_size= batch_size
if args.dataset_name in ['pacs', 'vlcs']:
if args.dataset_name == 'slab':
data_obj= SlabData(args, domains, '/slab/', data_case=data_case, match_func=match_func, base_size=args.slab_num_samples, freq_ratio=50, data_dim=args.slab_data_dim, total_slabs=args.slab_total_slabs)
elif args.dataset_name in ['pacs', 'vlcs']:
data_obj= PACS(args, domains, '/pacs/train_val_splits/', data_case=data_case, match_func=match_func)
elif args.dataset_name in ['chestxray']:
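
The slab_batch_process helper added above collapses the (Batch, Domain, Feat) tensors produced by the matched slab loader into flat (Batch*Domain, ...) tensors before the usual per-example processing. A toy check, with made-up shapes and the object ids passed as a tensor:

import torch

x = torch.randn(4, 2, 3)                        # (Batch, Domain, Feat)
y = torch.randint(0, 2, (4, 2))                 # per-domain class labels
d = torch.tensor([[0, 1]] * 4)                  # per-domain domain ids
o = torch.arange(4).unsqueeze(1).repeat(1, 2)   # matched object ids
x, y, d, o = slab_batch_process(x, y, d, o)
print(x.shape, y.shape, d.shape, o.shape)       # (8, 3) (8,) (8,) (8,)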

utils/scripts/data_utils.py Normal file

@ -0,0 +1,202 @@
import sys
import random, os, copy, pickle, time, argparse, itertools
from collections import defaultdict, Counter, OrderedDict
import numpy as np
import torch
import torchvision
from torch import optim, nn
import torch.nn.functional as F
from sklearn import metrics
from torch.utils.data import TensorDataset, DataLoader
import utils.scripts.gpu_utils as gu
import utils.scripts.lms_utils as au
import utils.scripts.synth_models as synth_models
import utils.scripts.utils as utils
import matplotlib.pyplot as plt
import pathlib
try:
sys.path.append('../../cifar10_models/')
import cifar10_models as c10
c10_not_found = False
except:
c10_not_found = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
REPO_DIR = pathlib.Path(__file__).parent.parent.absolute()
DOWNLOAD_DIR = os.path.join(REPO_DIR, 'datasets')
def msd(x, r=3):
return np.round(np.mean(x), r), np.round(np.std(x), r)
def _get_dataloaders(trd, ted, bs, pm=True, shuffle=True):
train_dl = DataLoader(trd, batch_size=bs, shuffle=shuffle, pin_memory=pm)
test_dl = DataLoader(ted, batch_size=bs, pin_memory=pm)
return train_dl, test_dl
def get_cifar10_models(device=None, pretrained=True):
if c10_not_found: return {}
device = gu.get_device(None) if device is None else device
get_lmbda = lambda cls: (lambda: cls(pretrained=pretrained).eval().to(device))
return {
'vgg11_bn': get_lmbda(c10.vgg11_bn),
'vgg13_bn': get_lmbda(c10.vgg13_bn),
'vgg16_bn': get_lmbda(c10.vgg16_bn),
'vgg19_bn': get_lmbda(c10.vgg19_bn),
'resnet18': get_lmbda(c10.resnet18),
'resnet34': get_lmbda(c10.resnet34),
'resnet50': get_lmbda(c10.resnet50),
'densenet121': get_lmbda(c10.densenet121),
'densenet161': get_lmbda(c10.densenet161),
'densenet169': get_lmbda(c10.densenet169),
'mobilenet_v2': get_lmbda(c10.mobilenet_v2),
'googlenet': get_lmbda(c10.googlenet),
'inception_v3': get_lmbda(c10.inception_v3)
}
def plot_decision_boundary(dl, model, c1, c2, ax=None, print_info=True):
if ax is None: fig, ax = plt.subplots(1,1,figsize=(6,4))
model = model.cpu()
deps = sorted(au.get_feature_deps(dl, model).items(), key=lambda t: t[-1])
if print_info:
for k, v in deps: print ('{}:{:.3f}'.format(k,v), end=', ')
print ("")
X, Y = utils.extract_numpy_from_loader(dl)
K = 100_000
U = np.random.uniform(low=X.min(), high=X.max(), size=(K, X.shape[1])) # copy.deepcopy(X)
U[:, c1] = np.random.uniform(low=X[:, c1].min(), high=X[:, c1].max(), size=K)
U[:, c2] = np.random.uniform(low=X[:, c2].min(), high=X[:, c2].max(), size=K)
U = torch.Tensor(U)
with torch.no_grad():
out = model(U)
Yu = torch.argmax(out, 1)
ax.scatter(U[:,c1], U[:,c2], c=Yu, alpha=0.3, s=24)
ax.scatter(X[:,c1], X[:,c2], c=Y, cmap='coolwarm', s=12)
def get_binary_datasets(X, Y, y1, y2, image_width=28, use_cnn=False):
assert type(X) is np.ndarray and type(Y) is np.ndarray
idx0 = (Y==y1).nonzero()[0]
idx1 = (Y==y2).nonzero()[0]
idx = np.concatenate((idx0, idx1))
X_, Y_ = X[idx,:], (Y[idx]==y2).astype(int)
P = np.random.permutation(len(X_))
X_, Y_ = X_[P,:], Y_[P]
if use_cnn: X_ = X_.reshape(X_.shape[0], -1, image_width)[:, None, :, :]
return X_[P,:], Y_[P]
def get_binary_loader(dl, y1, y2):
X, Y = utils.extract_numpy_from_loader(dl)
X, Y = get_binary_datasets(X, Y, y1, y2)
return utils._to_dl(X, Y, bs=dl.batch_size)
def get_mnist(fpath=DOWNLOAD_DIR, flatten=False, binarize=False, normalize=True, y0={0,1,2,3,4}):
"""get preprocessed mnist torch.TensorDataset class"""
def _to_torch(d):
X, Y = [], []
for xb, yb in d:
X.append(xb)
Y.append(yb)
return torch.Tensor(np.stack(X)), torch.LongTensor(np.stack(Y))
to_tensor = torchvision.transforms.ToTensor()
to_flat = torchvision.transforms.Lambda(lambda X: X.reshape(-1).squeeze())
to_norm = torchvision.transforms.Normalize((0.5, ), (0.5, ))
to_binary = torchvision.transforms.Lambda(lambda y: 0 if y in y0 else 1)
transforms = [to_tensor]
if normalize: transforms.append(to_norm)
if flatten: transforms.append(to_flat)
tf = torchvision.transforms.Compose(transforms)
ttf = to_binary if binarize else None
X_tr = torchvision.datasets.MNIST(fpath, download=True, transform=tf, target_transform=ttf)
X_te = torchvision.datasets.MNIST(fpath, download=True, train=False, transform=tf, target_transform=ttf)
return _to_torch(X_tr), _to_torch(X_te)
def get_mnist_dl(fpath=DOWNLOAD_DIR, to_np=False, bs=128, pm=False, shuffle=False,
normalize=True, flatten=False, binarize=False, y0={0,1,2,3,4}):
(X_tr, Y_tr), (X_te, Y_te) = get_mnist(fpath, normalize=normalize, flatten=flatten, binarize=binarize, y0=y0)
tr_dl = DataLoader(TensorDataset(X_tr, Y_tr), batch_size=bs, shuffle=shuffle, pin_memory=pm)
te_dl = DataLoader(TensorDataset(X_te, Y_te), batch_size=bs, pin_memory=pm)
return tr_dl, te_dl
def get_cifar(fpath=DOWNLOAD_DIR, use_cifar10=False, flatten_data=False, transform_type='none',
means=None, std=None, use_grayscale=False, binarize=False, normalize=True, y0={0,1,2,3,4}):
"""get preprocessed cifar torch.Dataset class"""
if transform_type == 'none':
normalize_cifar = lambda: torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
tensorize = torchvision.transforms.ToTensor()
to_grayscale = torchvision.transforms.Grayscale()
flatten = torchvision.transforms.Lambda(lambda X: X.reshape(-1).squeeze())
transforms = [tensorize]
if use_grayscale: transforms = [to_grayscale] + transforms
if normalize: transforms.append(normalize_cifar())
if flatten_data: transforms.append(flatten)
tr_transforms = te_transforms = torchvision.transforms.Compose(transforms)
if transform_type == 'basic':
normalize_cifar = lambda: torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
tr_transforms= [
torchvision.transforms.RandomCrop(32, padding=4),
torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.ToTensor()
]
te_transforms = [
torchvision.transforms.Resize(32),
torchvision.transforms.CenterCrop(32),
torchvision.transforms.ToTensor(),
]
if normalize:
tr_transforms.append(normalize_cifar())
te_transforms.append(normalize_cifar())
tr_transforms = torchvision.transforms.Compose(tr_transforms)
te_transforms = torchvision.transforms.Compose(te_transforms)
to_binary = torchvision.transforms.Lambda(lambda y: 0 if y in y0 else 1)
target_transforms = to_binary if binarize else None
dset = 'cifar10' if use_cifar10 else 'cifar100'
func = torchvision.datasets.CIFAR10 if use_cifar10 else torchvision.datasets.CIFAR100
X_tr = func(fpath, download=True, transform=tr_transforms, target_transform=target_transforms)
X_te = func(fpath, download=True, train=False, transform=te_transforms, target_transform=target_transforms)
return X_tr, X_te
def get_cifar_dl(fpath=DOWNLOAD_DIR, use_cifar10=False, bs=128, shuffle=True, transform_type='none',
means=None, std=None, normalize=True, flatten_data=False, use_grayscale=False, nw=4, pm=False, binarize=False, y0={0,1,2,3,4}):
"""data in dataloaders have has shape (B, C, W, H)"""
d_tr, d_te = get_cifar(fpath, use_cifar10=use_cifar10, use_grayscale=use_grayscale, transform_type=transform_type, normalize=normalize, means=means, std=std, flatten_data=flatten_data, binarize=binarize, y0=y0)
tr_dl = DataLoader(d_tr, batch_size=bs, shuffle=shuffle, num_workers=nw, pin_memory=pm)
te_dl = DataLoader(d_te, batch_size=bs, num_workers=nw, pin_memory=pm)
return tr_dl, te_dl
def get_cifar_np(fpath=DOWNLOAD_DIR, use_cifar10=False, flatten_data=False, transform_type='none', normalize=True, binarize=False, y0={0,1,2,3,4}, use_grayscale=False):
"""get numpy matrices of preprocessed cifar data"""
def _to_np(d):
X, Y = [], []
for xb, yb in d:
X.append(xb)
Y.append(yb)
return map(np.stack, [X,Y])
d_tr, d_te = get_cifar(fpath, use_cifar10=use_cifar10, use_grayscale=use_grayscale, transform_type=transform_type, normalize=normalize, flatten_data=flatten_data, binarize=binarize, y0=y0)
return _to_np(d_tr), _to_np(d_te)
if __name__ == '__main__':
pass
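
Possible usage of the MNIST helper above (the first call downloads MNIST into DOWNLOAD_DIR; flatten and binarize are shown only to exercise the flags):

tr_dl, te_dl = get_mnist_dl(bs=256, shuffle=True, flatten=True, binarize=True)
xb, yb = next(iter(tr_dl))
print(xb.shape, yb.shape)   # torch.Size([256, 784]) torch.Size([256]), labels in {0, 1}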

utils/scripts/ensemble.py Normal file

@ -0,0 +1,130 @@
import os, copy, pickle, time
import random, itertools
from collections import defaultdict, Counter, OrderedDict
import numpy as np
import torch
import pandas as pd
import torchvision
from torch.utils.data import TensorDataset, DataLoader
from torch import optim, nn
import torch.nn.functional as F
import dill
import gpu_utils as gu
import data_utils as du
import synth_models as sm
import utils
class Ensemble(nn.Module):
def _get_dummy_classifier(self):
def dummy(x):
return x
return dummy
def __init__(self, models, num_classes, use_softmax=False):
super(Ensemble, self).__init__()
self.num_classes = num_classes
self.use_softmax = use_softmax
# register models as pytorch modules
self.models = []
for idx, m in enumerate(models,1):
setattr(self, 'm{}'.format(idx), m.eval())
self.models.append(getattr(self, 'm{}'.format(idx)))
self.classifier = self._get_dummy_classifier()
def _forward(self, x):
return x
def forward(self, x):
outs = self._forward(x)
return self.classifier(outs)
def get_output_loader(self, dl, device=gu.get_device(None), bs=None):
"""return dataloader of model output (logit or softmax prob)"""
X, Y = [], []
with torch.no_grad():
for xb, yb in dl:
xb = xb.to(device)
out = self._forward(xb).cpu()
X.append(out)
Y.append(yb)
X, Y = torch.cat(X), torch.cat(Y)
return DataLoader(TensorDataset(X, Y), batch_size=bs or dl.batch_size)
def fit_classifier(self, tr_dl, te_dl, lr=0.05, adam=False, wd=5e-5, device=None, **fit_kw):
device = gu.get_device(None) if device is None else device
self = self.to(device)
c = dict(gap=1000, epsilon=1e-2, wd=5e-5, is_loss_epsilon=True)
c.update(**fit_kw)
tro_dl = self.get_output_loader(tr_dl, device)
teo_dl = self.get_output_loader(te_dl, device)
if adam: opt = optim.Adam(self.classifier.parameters())
else: opt = optim.SGD(self.classifier.parameters(), lr=lr, weight_decay=wd)
stats = utils.fit_model(self.classifier, F.cross_entropy, opt, tro_dl, teo_dl, device=device, **c)
self.classifier = stats['best_model'][-1].to(device)
self = self.cpu()
return stats
class EnsembleLinear(Ensemble):
def _get_classifier(self):
# linear with equal weights and zero bias
nl = self.num_classes*len(self.models)
linear = nn.Linear(nl, self.num_classes, bias=self.use_bias)
nn.init.ones_(linear.weight.data)
linear.weight.data /= float(nl)
if self.use_bias: linear.bias.data.zero_()
return linear
def __init__(self, models, num_classes=2, use_softmax=False, use_bias=True):
super(EnsembleLinear, self).__init__(models, num_classes, use_softmax)
self.use_bias = use_bias
self.classifier = self._get_classifier()
def _forward(self, x):
outs = [m(x) for m in self.models]
if self.use_softmax: outs = [F.softmax(o, dim=1) for o in outs]
outs = torch.stack(outs, dim=2)
outs = outs.reshape(outs.shape[0], -1)
return outs
class EnsembleMLP(Ensemble):
def _get_classifier(self):
nl = self.num_classes*len(self.models)
fcn = sm.get_fcn(nl, self.hdim or nl, self.num_classes, hl=self.hl)
return fcn
def __init__(self, models, num_classes=2, use_softmax=False, hdim=None, hl=1):
super(EnsembleMLP, self).__init__(models, num_classes, use_softmax)
self.hdim = hdim
self.hl = hl
self.classifier = self._get_classifier()
def _forward(self, x):
outs = [m(x) for m in self.models]
if self.use_softmax: outs = [F.softmax(o, dim=1) for o in outs]
outs = torch.stack(outs, dim=2)
outs = outs.reshape(outs.shape[0], -1)
return outs
class EnsembleAverage(Ensemble):
def __init__(self, models, num_classes=2, use_softmax=False):
super(EnsembleAverage, self).__init__(models, num_classes, use_softmax)
self.classifier = self._get_dummy_classifier()
def _forward(self, x):
outs = [m(x) for m in self.models]
if self.use_softmax: outs = [F.softmax(o, dim=1) for o in outs]
outs = torch.stack(outs)
return outs.mean(dim=0)
def fit_classifier(self, *args, **kwargs):
return None
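
Hypothetical usage of EnsembleAverage above: average the softmax outputs of two (here untrained) models that share an input dimension.

import torch
from torch import nn

m1, m2 = nn.Linear(20, 2), nn.Linear(20, 2)
ens = EnsembleAverage([m1, m2], num_classes=2, use_softmax=True)
probs = ens(torch.randn(8, 20))   # (8, 2): mean of the two softmax outputs
print(probs.sum(dim=1))           # each row sums to 1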

utils/scripts/gendata.py Normal file

@ -0,0 +1,316 @@
import numpy as np
import scipy.stats as scs
import random
from collections import Counter
import torch
from torch.utils.data import TensorDataset, DataLoader
import utils.scripts.utils as utils
import utils.scripts.gpu_utils as gu
def _prep_data(X, Y, N_tr, N_te, bs, nw, pm, w, orth_matrix=None):
X_te, Y_te = torch.Tensor(X[:N_te,:]), torch.Tensor(Y[:N_te])
X_tr, Y_tr = torch.Tensor(X[N_te:,:]), torch.Tensor(Y[N_te:])
Y_te, Y_tr = map(lambda Z: Z.long(), [Y_te, Y_tr])
tr_dl = DataLoader(TensorDataset(X_tr, Y_tr), batch_size=bs, num_workers=nw, pin_memory=pm, shuffle=True)
te_dl = DataLoader(TensorDataset(X_te, Y_te), batch_size=bs, num_workers=nw, pin_memory=pm, shuffle=False)
return {
'X': torch.tensor(X).float(),
'Y': torch.tensor(Y).long(),
'w': w,
'tr_dl': tr_dl,
'te_dl': te_dl,
'N': (N_tr, N_te),
'W': orth_matrix
}
def _get_random_data(N, dim, scale):
X = np.random.uniform(size=(N, dim))
X *= scale
Y = np.random.choice([0,1], size=N)
return X, Y
def generate_linsep_data_v2(N_tr, dim, eff_margin, width=10., bs=256, scale_noise=True, pm=True, nw=0, no_width=False, N_te=5000): # no unif_max.
assert eff_margin < 1, "equal range constraint"
margin = eff_margin if no_width else eff_margin*width
N = N_tr + N_te
w = np.zeros(shape=dim)
w[0] = 1
X, Y = _get_random_data(N, dim, width if scale_noise else 1.)
U = np.random.uniform(size=N)
if no_width: X[:,0] = (2*Y-1)*margin
else: X[:, 0] = (2*Y-1)*(margin + (width-margin)*U)
P = np.random.permutation(X.shape[0])
X, Y = X[P,:], Y[P]
return _prep_data(X, Y, N_tr, N_te, bs, nw, pm, w)
def sample_from_unif_union_of_unifs(unifs, size):
x = []
choices = Counter(np.random.choice(list(range(len(unifs))), size=size))
for choice, sz in choices.items():
s = np.random.uniform(low=unifs[choice][0], high=unifs[choice][1], size=sz)
x.append(s)
x = np.concatenate(x)
return x
def generate_ub_linslab_data_diffmargin_v2(N_tr, dim, eff_lin_margins, eff_slab_margins,
slabs_per_coord, slab_p_vals, corrupt_lin=0., corrupt_slab=0.,
corrupt_slab7=0., scale_noise=True, width=10., lin_coord=0, lin_shift=0.,
slab_shift=0., indep_slabs=True, bs=256, pm=True, nw=0, N_te=10000,
random_transform=False, corrupt_lin_margin=False, corrupt_5slab_margin=False):
get_unif = lambda a: np.random.uniform(size=a)
get_bool = lambda a: np.random.choice([0,1], size=a)
get_sign = lambda a: 2*get_bool(a)-1.
def get_slab_width(NS, B, SM):
if NS==3: return (2.*B-4.*SM)/3.
if NS==5: return (2.*B-8.*SM)/5.
if NS==7: return (2.*B-12.*SM)/7.
return None
num_lin, num_slabs = map(len, [eff_lin_margins, eff_slab_margins])
assert 0 <= corrupt_lin <= 1, "input is probability"
assert num_lin + num_slabs <= dim, "dim constraint, num_lin: {}, num_slabs: {}, dim: {}".format(num_lin, num_slabs, dim)
for elm in eff_lin_margins: assert 0 < elm < 1, "equal range constraint (0 < eff_lin_margin={} < 1)".format(elm)
for esm in eff_slab_margins: assert 0 < esm < 1, "equal range constraint (0 < eff_slab_margin={} < 0.25)".format(esm)
lin_margins = list(map(lambda x: x*width, eff_lin_margins))
slab_margins = list(map(lambda x: x*width, eff_slab_margins))
# hyperplane
N = N_tr + N_te
half_N = N//2
w = np.zeros(shape=dim); w[0] = 1
X, Y = _get_random_data(N, dim, width if scale_noise else 1.)
nrange = list(range(N))
# linear
total_corrupt = int(round(N*corrupt_lin))
no_linear = num_lin == 0
if not no_linear:
for coord, lin_margin in enumerate(lin_margins):
if indep_slabs:
P = np.random.permutation(N)
X, Y = X[P, :], Y[P]
X[:, coord] = (2*Y-1)*(lin_margin+(width-lin_margin)*get_unif(N)) + lin_shift*width
# corrupt linear coordinate
if total_corrupt > 0:
corrupt_sample = np.random.choice(nrange, size=total_corrupt, replace=False)
if corrupt_lin_margin:
X[corrupt_sample, 0] = np.random.uniform(low=-lin_margin, high=lin_margin, size=total_corrupt)
else:
X[corrupt_sample, 0] *= -1
# slabs
i = (num_lin)*int(not no_linear)
for idx, coord in enumerate(range(i, i+num_slabs)):
slab_per = slabs_per_coord[idx]
assert slab_per in [3, 5, 7], "Invalid slabs_per_coord"
slab_pval = slab_p_vals[idx]
slab_margin = slab_margins[idx]
slab_width = get_slab_width(slab_per, width, slab_margin)
if indep_slabs:
P = np.random.permutation(N)
X, Y = X[P, :], Y[P]
if slab_per == 3:
# positive slabs
idx_p = (Y==1).nonzero()[0]
offset = 0.5*slab_width + 2*slab_margin
X[idx_p, coord] = get_sign(len(idx_p))*(offset+slab_width*get_unif(len(idx_p)))
# negative center
idx_n = (Y==0).nonzero()[0]
X[idx_n, coord] = 0.5*get_sign(len(idx_n))*slab_width*get_unif(len(idx_n))
if slab_per == 5:
# positive slabs
idx_p = (Y==1).nonzero()[0]
offset = (width+6*slab_margin)/5.
X[idx_p, coord] = get_sign(len(idx_p))*(offset+slab_width*get_unif(len(idx_p)))
# negative slabs partitioned using p val
idx_n = (Y==0).nonzero()[0]
in_ctr = np.random.choice([0,1], p=[1-slab_pval, slab_pval], size=len(idx_n))
idx_nc, idx_ns = idx_n[(in_ctr==1)], idx_n[(in_ctr==0)]
# negative center
X[idx_nc, coord] = 0.5*get_sign(len(idx_nc))*slab_width*get_unif(len(idx_nc))
# negative sides
offset = (8*slab_margin+3*width)/5.
X[idx_ns, coord] = get_sign(len(idx_ns))*(offset+slab_width*get_unif(len(idx_ns)))
# corrupt slab 5
total_corrupt = int(round(N*corrupt_slab))
if total_corrupt > 0:
if corrupt_5slab_margin:
offset1 = (width+6*slab_margin)/5.
offset2 = (8*slab_margin+3*width)/5.
unifs = [
(0.5*slab_width, offset1),
(offset1+slab_width, offset2),
(-offset1, -0.5*slab_width),
(-offset2, -(offset1+slab_width))
]
idx = np.random.choice(range(N), size=total_corrupt, replace=False)
X[idx, coord] = sample_from_unif_union_of_unifs(unifs, total_corrupt)
else:
# get corrupt sample
idx = np.random.choice(range(N), size=total_corrupt, replace=False)
idx_p = idx[np.argwhere((Y[idx]==1))].reshape(-1)
idx_n = idx[np.argwhere((Y[idx]==0))].reshape(-1)
# move negative points to random positive slabs
offset = (0.5*slab_width+2*slab_margin)
X[idx_n, coord] = torch.Tensor(get_sign(len(idx_n))*(offset+slab_width*get_unif(len(idx_n))))
# pick negative slab for each positive point
mv_to_ctr = np.random.choice([0, 1], size=len(idx_p))
idx_p_ctr = idx_p[np.argwhere(mv_to_ctr==1)].reshape(-1)
idx_p_sid = idx_p[np.argwhere(mv_to_ctr==0)].reshape(-1)
# move positive points to negative slabs
X[idx_p_ctr, coord] = torch.Tensor(0.5*get_sign(len(idx_p_ctr))*slab_width*get_unif(len(idx_p_ctr)))
# move remaining positive points to negative side slabs
offset = 1.5*slab_width + 4*slab_margin
X[idx_p_sid, coord] = torch.Tensor(get_sign(len(idx_p_sid))*(offset+slab_width*get_unif(len(idx_p_sid))))
if slab_per == 7:
# positive slabs
idx_p = (Y==1).nonzero()[0]
in_s0 = np.random.choice([0,1], p=[1-slab_pval, slab_pval], size=len(idx_p))
idx_p0, idx_p1 = idx_p[(in_s0==1)], idx_p[(in_s0==0)]
# positive slab 0 (inner)
offset = 0.5*slab_width+2*slab_margin
X[idx_p0, coord] = get_sign(len(idx_p0))*(offset+slab_width*get_unif(len(idx_p0)))
# positive slab 1 (outer)
offset = 2.5*slab_width+6*slab_margin
X[idx_p1, coord] = get_sign(len(idx_p1))*(offset+slab_width*get_unif(len(idx_p1)))
# negative slabs
idx_n = (Y==0).nonzero()[0]
in_s0 = get_bool(len(idx_n))
idx_n0, idx_n1 = idx_n[(in_s0==1)], idx_n[(in_s0==0)]
# negative slab 0 (center)
X[idx_n0, coord] = 0.5*get_sign(len(idx_n0))*slab_width*get_unif(len(idx_n0))
# negative slab 1 (outer)
offset = 1.5*slab_width+4*slab_margin
X[idx_n1, coord] = get_sign(len(idx_n1))*(offset+slab_width*get_unif(len(idx_n1)))
# corrupt slab7
total_corrupt = int(round(N*corrupt_slab7))
if total_corrupt > 0:
# corrupt data
idx = np.random.choice(range(len(X)), size=total_corrupt, replace=False)
idx_p = idx[np.argwhere((Y[idx]==1))].reshape(-1)
idx_n = idx[np.argwhere((Y[idx]==0))].reshape(-1)
# pick positive slab for each negative point
mv_to_inner = get_bool(len(idx_n))
idx_n_inner = idx_n[np.argwhere(mv_to_inner==1)].reshape(-1)
idx_n_outer = idx_n[np.argwhere(mv_to_inner==0)].reshape(-1)
# move to idx_n_inner and outer
offset = 0.5*slab_width+2*slab_margin
X[idx_n_inner, coord] = torch.Tensor(get_sign(len(idx_n_inner))*(offset+slab_width*get_unif(len(idx_n_inner))))
offset = 2.5*slab_width+6*slab_margin
X[idx_n_outer, coord] = torch.Tensor(get_sign(len(idx_n_outer))*(offset+slab_width*get_unif(len(idx_n_outer))))
# pick negative slab for each positive point
mv_to_ctr = get_bool(len(idx_p))
idx_p_ctr = idx_p[np.argwhere(mv_to_ctr==1)].reshape(-1)
idx_p_sid = idx_p[np.argwhere(mv_to_ctr==0)].reshape(-1)
# move to idx_n_inner and outer
X[idx_p_ctr, coord] = torch.Tensor(0.5*get_sign(len(idx_p_ctr))*(slab_width*get_unif(len(idx_p_ctr))))
offset = 1.5*slab_width+4*slab_margin
X[idx_p_sid, coord] = torch.Tensor(get_sign(len(idx_p_sid))*(offset+slab_width*get_unif(len(idx_p_sid))))
# shift
X[:, coord] += slab_shift*width
# reshuffle
P = np.random.permutation(N)
X, Y = X[P,:], Y[P]
# lin coord position
if not random_transform and lin_coord != 0:
X[:, [0, lin_coord]] = X[:, [lin_coord, 0]]
# transform
W = np.eye(dim)
if random_transform: W = utils.get_orthonormal_matrix(dim)
X = X.dot(W)
return _prep_data(X, Y, N_tr, N_te, bs, nw, pm, w, orth_matrix=W)
def generate_ub_linslab_data_v2(N_tr, dim, eff_lin_margin, eff_slab_margin=None, lin_coord=0,
corrupt_lin=0., corrupt_slab=0., corrupt_slab3=0., corrupt_slab7=0.,
scale_noise=True, num_lin=1, lin_shift=0., slab_shift=0., random_transform=False,
num_slabs=1, slabs_per_coord=5, width=10., indep_slabs=True, no_linear=False,
bs=256, pm=True, nw=0, N_te=10000, corrupt_lin_margin=False, slab5_pval=3/4.,
slab3_pval=1/2., slab7_pval=7/8., corrupt_5slab_margin=False):
slab_p_map = {5: slab5_pval, 7: slab7_pval, 3: slab3_pval}
slabs_per_coord = [slabs_per_coord]*num_slabs if type(slabs_per_coord) is int else slabs_per_coord[:]
for x in slabs_per_coord: assert x in slab_p_map
slab_p_vals = [slab_p_map[x] for x in slabs_per_coord]
lms = [eff_lin_margin]*num_lin
sms = eff_slab_margin if type(eff_slab_margin) is list else [eff_slab_margin]*num_slabs
return generate_ub_linslab_data_diffmargin_v2(N_tr, dim, lms, sms, slabs_per_coord, slab_p_vals, lin_coord=lin_coord, corrupt_slab=corrupt_slab,
corrupt_slab7=corrupt_slab7, corrupt_lin=corrupt_lin, scale_noise=scale_noise, width=width,
lin_shift=lin_shift, slab_shift=slab_shift, random_transform=random_transform, indep_slabs=indep_slabs,
pm=pm, bs=bs, corrupt_lin_margin=corrupt_lin_margin, nw=nw, N_te=N_te, corrupt_5slab_margin=corrupt_5slab_margin)
def get_lms_data(**kw):
c = config = {
'num_train': 100_000,
'dim': 20,
'lin_margin': 0.1,
'slab_margin': 0.1,
'same_margin': False,
'random_transform': False,
'width': 1, # data width
'bs': 256,
'corrupt_lin': 0.0,
'corrupt_lin_margin': False,
'corrupt_slab': 0.0,
'num_test': 2_000,
'hdim': 200, # model width
'hl': 2, # model depth
'device': gu.get_device(0),
'input_dropout': 0,
'num_lin': 1,
'num_slabs': 19,
'num_slabs7': 0,
'num_slabs3': 0,
}
c.update(kw)
smargin = c['lin_margin'] if c['same_margin'] else c['slab_margin']
data_func = generate_ub_linslab_data_v2
spc = [3]*c['num_slabs3']+[5]*c['num_slabs'] + [7]*c['num_slabs7']
data = data_func(c['num_train'], c['dim'], c['lin_margin'], slabs_per_coord=spc, eff_slab_margin=smargin, random_transform=c['random_transform'], N_te=c['num_test'],
corrupt_lin_margin=c['corrupt_lin_margin'], num_lin=c['num_lin'], num_slabs=c['num_slabs3']+c['num_slabs']+c['num_slabs7'], width=c['width'], bs=c['bs'],
corrupt_lin=c['corrupt_lin'], corrupt_slab=c['corrupt_slab'])
return data, c
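
get_lms_data bundles the generator defaults for the linear-plus-slab data; an illustrative call with reduced sizes (these are not the defaults):

data, cfg = get_lms_data(num_train=1000, num_test=500, dim=5,
                         num_slabs=4, num_slabs3=0, num_slabs7=0)
X, Y, w = data['X'], data['Y'], data['w']
print(X.shape, Y.shape, w)            # (1500, 5) inputs, binary labels, linear direction e_0
tr_dl, te_dl = data['tr_dl'], data['te_dl']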

utils/scripts/gpu_utils.py Normal file

@ -0,0 +1,100 @@
try: import pycuda.driver as cuda
except: print ("pycuda not available")
import torch
import sys, os, glob, subprocess
def get_gpu_info(print_info=True, get_specs=False):
cuda.init()
if get_specs: gpu_specs = cuda.Device(0).get_attributes() # assume same for all (dnnx)
else: gpu_specs = None
gpu_info = {
'available': torch.cuda.is_available(),
'num_devices': torch.cuda.device_count(),
'devices': set([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]),
'current device id': torch.cuda.current_device(),
'allocated memory': torch.cuda.memory_allocated(),
'cached memory': torch.cuda.memory_cached()
}
if print_info:
for k,v in gpu_info.items(): print ("{}: {}".format(k, v))
return gpu_info, gpu_specs
def get_device(device_id=None): # None -> cpu
device = 'cuda:{}'.format(device_id) if device_id is not None else 'cpu'
device = torch.device(device if torch.cuda.is_available() and device_id is not None else 'cpu')
return device
def get_gpu_name():
try:
out_str = subprocess.run(["nvidia-smi", "--query-gpu=gpu_name", "--format=csv"], stdout=subprocess.PIPE).stdout
out_list = out_str.decode("utf-8").split('\n')
out_list = out_list[1:-1]
return out_list
except Exception as e:
print(e)
def get_cuda_version():
"""Get CUDA version"""
if sys.platform == 'win32':
raise NotImplementedError("Implement this!")
# This breaks on linux:
#cuda=!ls "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA"
#path = "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\" + str(cuda[0]) +"\\version.txt"
elif sys.platform == 'linux' or sys.platform == 'darwin':
path = '/usr/local/cuda/version.txt'
else:
raise ValueError("Not in Windows, Linux or Mac")
if os.path.isfile(path):
with open(path, 'r') as f:
data = f.read().replace('\n','')
return data
else:
return "No CUDA in this machine"
def get_cudnn_version():
"""Get CUDNN version"""
if sys.platform == 'win32':
raise NotImplementedError("Implement this!")
# This breaks on linux:
#cuda=!ls "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA"
#candidates = ["C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\" + str(cuda[0]) +"\\include\\cudnn.h"]
elif sys.platform == 'linux':
candidates = ['/usr/include/x86_64-linux-gnu/cudnn_v[0-99].h',
'/usr/local/cuda/include/cudnn.h',
'/usr/include/cudnn.h']
elif sys.platform == 'darwin':
candidates = ['/usr/local/cuda/include/cudnn.h',
'/usr/include/cudnn.h']
else:
raise ValueError("Not in Windows, Linux or Mac")
for c in candidates:
file = glob.glob(c)
if file: break
if file:
with open(file[0], 'r') as f:
version = ''
for line in f:
if "#define CUDNN_MAJOR" in line:
version = line.split()[-1]
if "#define CUDNN_MINOR" in line:
version += '.' + line.split()[-1]
if "#define CUDNN_PATCHLEVEL" in line:
version += '.' + line.split()[-1]
if version:
return version
else:
return "Cannot find CUDNN version"
else:
return "No CUDNN in this machine"
if __name__=='__main__':
print ('gpu name', get_gpu_name())
print ('cuda', get_cuda_version())
print ('cudnn', get_cudnn_version())
print ('device0', get_device(0))
print ('available', torch.cuda.is_available())

utils/scripts/lms_utils.py Normal file

@ -0,0 +1,307 @@
import seaborn as sns
import utils.scripts.gpu_utils as gu
import utils.scripts.data_utils as du
import utils
import random
import os, copy, pickle, time
import itertools
from collections import defaultdict, Counter, OrderedDict
import matplotlib.pyplot as plt
import numpy as np
import torch
import pandas as pd
#import foolbox
from torch.utils.data import TensorDataset, DataLoader
from torch import optim, nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
def parse_data(exps=None, root='/', **funcs_kw):
"""
main function (parse data files and run added functions on it)
"""
exps = exps if exps is not None else os.listdir(root)
total = len(exps)
print ("total: {}".format(total))
parsed = defaultdict(dict)
if total == 0: return parsed
for idx, exp in enumerate(exps):
if (idx+1) % 1 == 0: print (idx+1, end=' ', flush=True)
# load data
fpath = os.path.join(root, exp)
try:
data = torch.load(fpath, map_location=lambda storage, loc: storage)
except:
print ("File {} corrupted, skip.".format(fpath))
continue
config = data['config']
# make exp config key
config['run'] = int(exp.rsplit('.', 1)[0][-1])
config['fname'] = exp
ckeys = ['exp_name', 'dim', 'num_train', 'lin_margin', 'slab_margin', 'num_slabs',
'width', 'hdim', 'hl', 'linear', 'use_bn', 'run', 'fname',
'weight_decay', 'dropout']
ckeys = [c for c in ckeys if c in config]
cvals = [config[k] for k in ckeys]
ckv = tuple(zip(ckeys, cvals))
# save config
parsed[ckv]['config'] = config
# save functions
for func_name, func in funcs_kw.items():
parsed[ckv][func_name] = func(data)
return parsed
def parse_exp_stats(data):
"""training summary statistics"""
stats = data['stats']
s = {}
# loss + accuracy
for t1, t2 in itertools.product(['acc', 'loss'], ['tr', 'te']):
s['{}_{}'.format(t1,t2)] = stats['{}_{}'.format(t1, t2)][-1]
s['orig_stats'] = stats
s['acc_gap'] = s['acc_tr']-s['acc_te']
s['loss_gap'] = s['loss_te']-s['loss_tr']
s['fin_acc_te'] = stats['acc_te'][-1]
s['fin_acc_tr'] = stats['acc_tr'][-1]
# updates
s['update_gap'] = stats['update_gap']
s['num_updates'] = stats['num_updates']
# effective number of updates
for acc_threshold in [0.96, 0.97, 0.98, 0.99, 1]:
eff = np.argmin(np.abs(np.array(stats['acc_tr'])-acc_threshold))*s['update_gap']
s['effective_num_updates{}'.format(int(acc_threshold*100))] = eff
return s
def parse_exp_model(data):
"""model parameter stats"""
depth = data['config']['hl']
linear = data['config']['linear']
mtype = data['config'].get('mtype', 'fcn')
d = {}
if mtype == 'fcn' and depth == 1 and not linear: d = parse_exp_depth1_model(data)
if mtype == 'fcn' and depth == 1 and linear: d = parse_exp_linear_model(data)
return d
def parse_exp_depth1_model(data):
"""cosine + w2"""
device = gu.get_device()
model = data['model'].to(device)
p = W1, b1, w2, b2 = list(map(lambda x: x.detach().numpy(), model.parameters()))
s = {}
s['params'] = p
s['cosine'] = W1[:, 0]/np.linalg.norm(W1, axis=1)
s['l2'] = np.linalg.norm(W1, axis=1)
s['w2'] = w2
s['corr0'] = np.corrcoef(s['cosine'], w2[0, :])[0,1]
s['corr1'] = np.corrcoef(s['cosine'], w2[1, :])[0,1]
s['max_weight_cosine'] = s['cosine'][np.argmax(s['w2'][1,:])]
return s
def parse_exp_linear_model(data):
"""cosine"""
device = gu.get_device()
model = data['model'].to(device)
p = W,b = list(map(lambda x: x.detach().numpy(), model.parameters()))
s = {}
s['cosine0'], s['cosine1'] = W[:, 0]/np.linalg.norm(W, axis=1)
return s
def parse_exp_data(data, load_X=False):
s = {}
model = data['model'].to(gu.get_device())
data = data['data']
X, Y = data['X'], data['Y']
if type(X) != np.ndarray:
X = data['X'].detach().cpu()
if type(Y) != np.ndarray:
Y = data['Y'].detach().cpu()
s['Y'] = Y
if load_X: s['X'] = X
s['Y_'] = get_yhat(model, X)
s['model'] = model
return s
def get_yhat(model, data):
if type(data)==np.ndarray: data = torch.Tensor(data)
return torch.argmax(model(data), 1)
def get_acc(y,yhat):
n = float(len(y))
return (y==yhat).sum().item()/n
def parse_and_get_df(root, prefix, files=None, device_id=None, only_load=False, only_linear=False, sample_pct=0.5, load_X=False, use_model_pred=False):
exps = files if files is not None else [f for f in os.listdir(root) if f.startswith(prefix)]
funcs = {
'config': lambda d: d['config'],
'stats': parse_exp_stats,
'model': parse_exp_model,
'data': lambda x: parse_exp_data(x, load_X=load_X),
'random_dep': lambda d: get_feature_deps(d['data']['te_dl'], d['model'], only_linear=only_linear, W=d['data'].get('W', None), dep_type='random', use_model_pred=use_model_pred, print_info=False, sample_pct=sample_pct, device_id=device_id),
'swap_dep': lambda d: get_feature_deps(d['data']['te_dl'], d['model'], only_linear=only_linear, W=d['data'].get('W', None), dep_type='swap', use_model_pred=use_model_pred, print_info=False, sample_pct=sample_pct, device_id=device_id),
}
P = parse_data(root=root, exps=exps, **funcs)
if only_load: return P
D = []
for idx, (k,v) in enumerate(P.items(),1):
d = OrderedDict()
for a,b in k: d[a] = b
for vk in ['model', 'data', 'stats', 'config']:
for a,b in v[vk].items(): d[a] = b
for vk in ['random_dep', 'swap_dep']:
for coord, dep in v[vk].items():
d[f'{vk[0]}dep_{coord}'] = dep
D.append(d)
df = pd.DataFrame(D)
if len(df): df['nd'] = df['num_train']/df['dim']
return df
def viz(d, c1, c2, k=80_000, info=True, plot_dm=True, plot_data=True, use_yhat=False, unif_k=False, width=10, title=None, is_binary=False, dep_type='swap', ax=None):
if 'W' not in d['data']: W = np.eye(d['config']['dim'])
else: W = d['data']['W']
if W is None: W = np.eye(d['config']['dim'])
z = parse_exp_data(d)
X = d['data']['X']
# visualize un-transformed data...
X_ = np.array(X).dot(W.T)
Y, Y_ = z['Y'], z['Y_']
model = d['model'].cpu()
D = X.shape[1]
kn = k if unif_k else len(X)
K = torch.Tensor(np.random.uniform(size=(k, D)))*width if unif_k else np.array(X_)
K[:, c1] = torch.Tensor(np.random.uniform(low=min(X_[:,c1]), high=max(X_[:,c1]), size=kn))
K[:, c2] = torch.Tensor(np.random.uniform(low=min(X_[:,c2]), high=max(X_[:,c2]), size=kn))
KO = model(torch.Tensor(np.array(K).dot(W)))
if is_binary: KY = (KO > 0).squeeze().numpy()
else: KY = torch.argmax(KO, 1).numpy()
if info:
deps = get_feature_deps(d['data']['te_dl'], d['model'], W=d['data'].get('W', None), dep_type=dep_type)
for k,v in sorted(deps.items(), reverse=False, key=lambda t: t[-1]): print ('{}:{:.3f}'.format(k,v), end=' ')
print ("\n")
if ax is None: fig, ax = plt.subplots(1,1,figsize=(6,4))
if plot_dm: ax.scatter(K[:, c1], K[:, c2], c=KY, cmap='binary', s=8, alpha=.2)
if plot_data: ax.scatter(X_[:, c1], X_[:, c2], c=Y_ if use_yhat else Y, cmap='coolwarm', s=8, alpha=.4)
ax.set_xlabel('e_{}'.format(c1))
ax.set_ylabel('e_{}'.format(c2))
ax.set_title(title if title else '')
plt.tight_layout()
return ax
def visualize_boundary(model, data, c1, c2, dim, ax=None, is_binary=False, use_yhat=False, width=1, unif_k=True, k=100_000, print_info=True, dep_type='random'):
agg = {'model': model, 'data': data, 'config': dict(dim=dim)}
return viz(agg, c1, c2, unif_k=unif_k, width=width, dep_type=dep_type, is_binary=is_binary, use_yhat=use_yhat, ax=ax, info=print_info)
def get_randomized_loader(dl, W, coordinates):
"""
dl: dataloader
W: rotation matrix
coordinates: list of coordinates to randomize
output: randomized dataloader
"""
def _randomize(X, coords):
p = torch.randperm(len(X))
for c in coords: X[:, c] = X[p, c]
return X
# rotate data
X, Y = map(copy.deepcopy, dl.dataset.tensors)
dim = X.shape[1]
if W is None: W = np.eye(dim)
rt_X = torch.Tensor(X.numpy().dot(W.T))
rand_rt_X = _randomize(rt_X, coordinates)
rand_X = torch.Tensor(rand_rt_X.numpy().dot(W))
return utils._to_dl(rand_X, Y, dl.batch_size)
def get_feature_deps(dl, model, W=None, dep_type='random', only_linear=False, coords=None, metric='accuracy',
use_model_pred=False, print_info=False, sample_pct=1.0, device_id=None):
"""Compute feature dependencies using randomization or swapping"""
def _randomize(X, Y, coords):
p = torch.randperm(len(X))
for c in coords: X[:, c] = X[p, c]
return X
def _swap(X, Y, coords):
idx0, idx1 = map(lambda c: (Y.numpy()==c).nonzero()[0], [0, 1])
idx0_new = np.random.choice(idx1, size=len(idx0), replace=True)
idx1_new = np.random.choice(idx0, size=len(idx1), replace=True)
for c in coords: X[idx0, c], X[idx1, c] = X[idx0_new, c], X[idx1_new, c]
return X
def _get_dep_data(X, Y, coords):
return dict(random=_randomize, swap=_swap)[dep_type](X, Y, coords)
assert metric in {'accuracy', 'loss', 'auc'}
# setup data
device = gu.get_device(device_id)
model = model.to(device)
X, Y = map(lambda Z: Z.to(device), dl.dataset.tensors)
Yh = get_yhat(model, X)
dim = X.shape[1]
if W is None: W = np.eye(dim)
W = torch.Tensor(W).to(device)
rt_X = torch.mm(X, torch.transpose(W,0,1))
# subsample data
n_samp = int(round(sample_pct*len(rt_X)))
perm = torch.randperm(len(rt_X))[:n_samp]
rt_X, Y, Yh = rt_X[perm, :], Y[perm], Yh[perm]
# compute deps
deps = {}
dims = list(range(dim))
if coords is None and not only_linear: coords = dims
if coords is None and only_linear: coords = [0,1]
for idx, coord in enumerate(coords):
if print_info: print ('{}/{}'.format(idx, len(coords)), end=' ')
rt_X_ = copy.deepcopy(rt_X).to(device)
rt_X_ = _get_dep_data(rt_X_, Y, coord if type(coord) in (list, tuple) else [coord])
X_ = torch.mm(rt_X_, W)
Ys = get_yhat(model, X_)
key = tuple(coord) if type(coord) in (list, tuple) else coord
if metric == 'auc':
L = utils.get_logits_given_tensor(X_, model, device=device, bs=250)
S = L[:,1]-L[:,0]
auc = roc_auc_score(Y.cpu().numpy(), S.cpu().numpy())
deps[key] = auc
elif metric == 'accuracy':
deps[key] = get_acc(Yh if use_model_pred else Y, Ys)
elif metric == 'loss':
L = utils.get_logits_given_tensor(X_, model, device=device, bs=250)
with torch.no_grad():
loss_val = F.cross_entropy(L, Y).item()
deps[key] = loss_val
return deps
def get_subset_feature_deps(dl, model, coords_set, comb_size, W=None, dep_type='random', sample_pct=0.5, device_id=None, print_info=False):
coords = list(itertools.combinations(coords_set, comb_size))
return get_feature_deps(dl, model, W=W, dep_type=dep_type, coords=coords, print_info=print_info, sample_pct=sample_pct, device_id=device_id)
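
get_feature_deps scores how much a model relies on each (rotated) input coordinate by randomizing or swapping that coordinate and re-measuring accuracy, loss, or AUC. An illustrative call on toy data with an untrained linear model (scores near 0.5 simply mean no learned dependence yet):

import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader

X = torch.randn(512, 3)
Y = (X[:, 0] > 0).long()                 # labels depend only on coordinate 0
dl = DataLoader(TensorDataset(X, Y), batch_size=128)
model = nn.Linear(3, 2)
deps = get_feature_deps(dl, model, dep_type='random', metric='accuracy')
print(deps)                              # {0: acc_without_coord_0, 1: ..., 2: ...}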


@ -0,0 +1,98 @@
import random
import os, copy, pickle, time
import itertools
from collections import defaultdict, Counter, OrderedDict
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
import utils
import gpu_utils as gu
import data_utils as du
def get_binary_mnist(y1=0, y2=1, apply_padding=True, repeat_channels=True):
def _make_cifar_compatible(X):
if apply_padding: X = np.stack([np.pad(X[i][0], 2)[None,:] for i in range(len(X))]) # pad
if repeat_channels: X = np.repeat(X, 3, axis=1) # add channels
return X
binarize = lambda X,Y: du.get_binary_datasets(X, Y, y1=y1, y2=y2)
tr_dl, te_dl = du.get_mnist_dl(normalize=False)
Xtr, Ytr = binarize(*utils.extract_numpy_from_loader(tr_dl))
Xte, Yte = binarize(*utils.extract_numpy_from_loader(te_dl))
Xtr, Xte = map(_make_cifar_compatible, [Xtr, Xte])
return (Xtr, Ytr), (Xte, Yte)
def get_binary_cifar(y1=3, y2=5, c={0,1,2,3,4}, use_cifar10=True):
binarize = lambda X,Y: du.get_binary_datasets(X, Y, y1=y1, y2=y2)
binary = False if y1 is not None and y2 is not None else True
if binary: print ("grouping cifar classes")
tr_dl, te_dl = du.get_cifar_dl(use_cifar10=use_cifar10, shuffle=False, normalize=False, binarize=binary, y0=c)
Xtr, Ytr = binarize(*utils.extract_numpy_from_loader(tr_dl))
Xte, Yte = binarize(*utils.extract_numpy_from_loader(te_dl))
return (Xtr, Ytr), (Xte, Yte)
def combine_datasets(Xm, Ym, Xc, Yc, randomize_order=False, randomize_first_block=False, randomize_second_block=False):
"""combine two datasets"""
def partition(X, Y, randomize=False):
"""partition randomly or using labels"""
if randomize:
n = len(Y)
p = np.random.permutation(n)
ni, pi = p[:n//2], p[n//2:]
else:
ni, pi = (Y==0).nonzero()[0], (Y==1).nonzero()[0]
return X[pi], X[ni]
def _combine(X1, X2):
"""concatenate images from two sources"""
X = []
for i in range(min(len(X1), len(X2))):
x1, x2 = X1[i], X2[i]
# randomize order
if randomize_order and random.random() < 0.5:
x1, x2 = x2, x1
x = np.concatenate((x1,x2), axis=1)
X.append(x)
return np.stack(X)
Xmp, Xmn = partition(Xm, Ym, randomize=randomize_first_block)
Xcp, Xcn = partition(Xc, Yc, randomize=randomize_second_block)
n = min(map(len, [Xmp, Xmn, Xcp, Xcn]))
Xmp, Xmn, Xcp, Xcn = map(lambda Z: Z[:n], [Xmp, Xmn, Xcp, Xcn])
Xp = _combine(Xmp, Xcp)
Yp = np.ones(len(Xp))
Xn = _combine(Xmn, Xcn)
Yn = np.zeros(len(Xn))
X = np.concatenate([Xp, Xn], axis=0)
Y = np.concatenate([Yp, Yn], axis=0)
P = np.random.permutation(len(X))
X, Y = X[P], Y[P]
return X, Y
def get_mnist_cifar(mnist_classes=(0,1), cifar_classes=None, c={0,1,2,3,4},
randomize_mnist=False, randomize_cifar=False):
y1, y2 = mnist_classes
(Xtrm, Ytrm), (Xtem, Ytem) = get_binary_mnist(y1=y1, y2=y2)
y1, y2 = (None, None) if cifar_classes is None else cifar_classes
(Xtrc, Ytrc), (Xtec, Ytec) = get_binary_cifar(c=c, y1=y1, y2=y2)
Xtr, Ytr = combine_datasets(Xtrm, Ytrm, Xtrc, Ytrc, randomize_first_block=randomize_mnist, randomize_second_block=randomize_cifar)
Xte, Yte = combine_datasets(Xtem, Ytem, Xtec, Ytec, randomize_first_block=randomize_mnist, randomize_second_block=randomize_cifar)
return (Xtr, Ytr), (Xte, Yte)
def get_mnist_cifar_dl(mnist_classes=(0,1), cifar_classes=None, c={0,1,2,3,4}, bs=256,
randomize_mnist=False, randomize_cifar=False):
(Xtr, Ytr), (Xte, Yte) = get_mnist_cifar(mnist_classes=mnist_classes, cifar_classes=cifar_classes,
c=c, randomize_mnist=randomize_mnist, randomize_cifar=randomize_cifar)
tr_dl = utils._to_dl(Xtr, Ytr, bs=bs, shuffle=True)
te_dl = utils._to_dl(Xte, Yte, bs=100, shuffle=False)
return tr_dl, te_dl
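
Possible usage of the combined loader above (downloads MNIST and CIFAR-10 on first use, and assumes utils._to_dl wraps the arrays in a standard TensorDataset loader): each example stacks a padded 3-channel MNIST digit on top of a CIFAR image along the height axis.

tr_dl, te_dl = get_mnist_cifar_dl(mnist_classes=(0, 1), bs=128)
xb, yb = next(iter(tr_dl))
print(xb.shape, yb.shape)   # (128, 3, 64, 32): MNIST block on top, CIFAR block below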

utils/scripts/ptb_utils.py Normal file

@ -0,0 +1,236 @@
import seaborn as sns
import utils
import random
import os, copy, pickle, time
import itertools
from collections import defaultdict, Counter, OrderedDict
import matplotlib.pyplot as plt
import numpy as np
import torch
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader
from torch import optim, nn
import torch.nn.functional as F
import utils.scripts.gpu_utils as gu
import utils.scripts.data_utils as du
import utils.scripts.synth_models as synth_models
#import foolbox as fb
#from autoattack import AutoAttack
# Misc
def get_yhat(model, data): return torch.argmax(model(data), 1)
def get_acc(y,yhat): return (y==yhat).sum().item()/float(len(y))
class PGD_Attack(object):
def __init__(self, eps, lr, num_iter, loss_type, rand_eps=1e-3,
num_classes=2, bounds=(0.,1.), minimal=False, restarts=1, device=None):
self.eps = eps
self.lr = lr
self.num_iter = num_iter
self.B = bounds
self.restarts = restarts
self.rand_eps = rand_eps
self.device = device or gu.get_device(None)
self.loss_type = loss_type
self.num_classes = num_classes
self.classes = list(range(self.num_classes))
self.delta = None
self.minimal = minimal # early stop + no eps
self.project = not self.minimal
self.loss = -np.inf
def evaluate_attack(self, dl, model):
model = model.to(self.device)
Xa, Ya, Yh, P = [], [], [], []
for xb, yb in dl:
xb, yb = xb.to(self.device), yb.to(self.device)
delta = self.perturb(xb, yb, model)
xba = xb+delta
with torch.no_grad():
out = model(xba).detach()
yh = torch.argmax(out, dim=1)
xb, yb, yh, xba, delta = xb.cpu(), yb.cpu(), yh.cpu(), xba.cpu(), delta.cpu()
Ya.append(yb)
Yh.append(yh)
Xa.append(xba)
P.append(delta)
Xa, Ya, Yh, P = map(torch.cat, [Xa, Ya, Yh, P])
ta_dl = utils._to_dl(Xa, Ya, dl.batch_size)
acc, loss = utils.compute_loss_and_accuracy_from_dl(ta_dl, model,
F.cross_entropy,
device=self.device)
return {
'acc': acc.item(),
'loss': loss.item(),
'ta_dl': ta_dl,
'Xa': Xa.numpy(),
'Ya': Ya.numpy(),
'Yh': Yh.numpy(),
'P': P.numpy()
}
def perturb(self, xb, yb, model, cpu=False):
model, xb, yb = model.to(self.device), xb.to(self.device), yb.to(self.device)
if self.eps == 0: return torch.zeros_like(xb)
# compute perturbations and track best perturbations
self.loss = -np.inf
max_delta = self._perturb_once(xb, yb, model)
with torch.no_grad():
out = model(xb+max_delta)
max_loss = nn.CrossEntropyLoss(reduction='none')(out, yb)
for _ in range(self.restarts-1):
delta = self._perturb_once(xb, yb, model)
with torch.no_grad():
out = model(xb+delta)
all_loss = nn.CrossEntropyLoss(reduction='none')(out, yb)
loss_flag = all_loss >= max_loss
max_delta[loss_flag] = delta[loss_flag]
max_loss = torch.max(max_loss, all_loss)
if cpu: max_delta = max_delta.cpu()
return max_delta
def _perturb_once(self, xb, yb, model, track_scores=False, stop_const=1e-5):
self.delta = self._init_delta(xb, yb)
scores = []
# (minimal) mask perturbations if model already misclassifies
for t in range(self.num_iter):
loss, out = self._get_loss(xb, yb, model, get_scores=True)
if self.minimal:
yh = torch.argmax(out, dim=1).detach()
not_flipped = yh == yb
not_flipped_ratio = not_flipped.sum().item()/float(len(yb))
else:
not_flipped = None
not_flipped_ratio = 1.0
# stop if almost all examples in the batch misclassified
if not_flipped_ratio < stop_const:
break
if track_scores:
scores.append(out.detach().cpu().numpy())
# compute loss, update + clamp delta
loss.backward()
self.loss = max(self.loss, loss.item())
self.delta = self._update_delta(xb, yb, update_mask=not_flipped)
self.delta = self._clamp_input(xb, yb)
d = self.delta.detach()
if track_scores:
scores = np.stack(scores).swapaxes(0, 1)
return d, scores
return d
def _init_delta(self, xb, yb):
delta = torch.empty_like(xb)
delta = delta.uniform_(-self.rand_eps, self.rand_eps)
delta = delta.to(self.device)
delta.requires_grad = True
return delta
def _clamp_input(self, xb, yb):
# clamp delta s.t. X+delta in valid input range
self.delta.data = torch.max(self.B[0]-xb,
torch.min(self.B[1]-xb,
self.delta.data))
return self.delta
def _get_loss(self, xb, yb, model, get_scores=False):
out = model(xb+self.delta)
if self.loss_type == 'untargeted':
L = -1*F.cross_entropy(out, yb)
elif self.loss_type == 'targeted':
L = nn.CrossEntropyLoss()(out, yb)
elif self.loss_type == 'random_targeted':
rand_yb = torch.randint(low=0, high=self.num_classes, size=(len(yb),), device=self.device)
#rand_yb[rand_yb==yb] = (yb[rand_yb==yb]+1) % self.num_classes
L = nn.CrossEntropyLoss()(out, rand_yb)
elif self.loss_type == 'plusone_targeted':
next_yb = (yb+1) % self.num_classes
L = nn.CrossEntropyLoss()(out, next_yb)
elif self.loss_type == 'binary_targeted':
yb_opp = 1-yb
L = nn.CrossEntropyLoss()(out, yb_opp)
elif self.loss_type == 'binary_hybrid':
yb_opp = 1-yb
L = nn.CrossEntropyLoss()(out, yb_opp) - nn.CrossEntropyLoss()(out, yb)
else:
assert False, "unknown loss type"
if get_scores: return L, out
return L
class L2_PGD_Attack(PGD_Attack):
OVERFLOW_CONST = 1e-10
def get_norms(self, X):
nch = len(X.shape)
return X.view(X.shape[0], -1).norm(dim=1)[(...,) + (None,)*(nch-1)]
def _update_delta(self, xb, yb, update_mask=None):
# normalize gradients
grad = self.delta.grad.detach()
norms = self.get_norms(grad)
        grad = grad/(norms+self.OVERFLOW_CONST) # small constant avoids division by zero for zero-gradient rows
# steepest descent
if self.minimal and update_mask is not None:
um = update_mask
self.delta.data[um] = self.delta.data[um] - self.lr*grad[um]
else:
self.delta.data = self.delta.data - self.lr*grad
# l2 ball projection
if self.project:
delta_norms = self.get_norms(self.delta.data)
self.delta.data = self.eps*self.delta.data / (delta_norms.clamp(min=self.eps))
self.delta.grad.zero_()
return self.delta
def _init_delta(self, xb, yb):
# random vector with L2 norm rand_eps
delta = torch.zeros_like(xb)
delta = delta.uniform_(-self.rand_eps, self.rand_eps)
delta_norm = self.get_norms(delta)
delta = self.rand_eps*delta/(delta_norm+self.OVERFLOW_CONST)
delta = delta.to(self.device)
delta.requires_grad = True
return delta
class Linf_PGD_Attack(PGD_Attack):
def _update_delta(self, xb, yb, **kw):
# steepest descent + linf projection (GD)
self.delta.data = self.delta.data - self.lr*(self.delta.grad.detach().sign())
self.delta.data = self.delta.data.clamp(-self.eps, self.eps)
self.delta.grad.zero_()
return self.delta
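# A minimal self-contained sketch of the L-inf update implemented by
# Linf_PGD_Attack._update_delta above (sign-gradient step, eps-ball clamp,
# input-range clamp). The PGD_Attack constructor is not shown in this hunk, so
# the sketch avoids it; the toy model, eps, lr, num_iter and x_range below are
# illustrative assumptions, not the repo's configured values.
import torch
from torch import nn
import torch.nn.functional as F

def linf_pgd_sketch(model, xb, yb, eps=0.1, lr=0.02, num_iter=10, x_range=(0.0, 1.0)):
    """Untargeted L-inf PGD: sign-gradient descent on -CE, then eps-ball and input-range clamps."""
    delta = torch.empty_like(xb).uniform_(-eps, eps).requires_grad_(True)
    for _ in range(num_iter):
        loss = -F.cross_entropy(model(xb + delta), yb)   # the 'untargeted' loss above
        loss.backward()
        with torch.no_grad():
            delta -= lr * delta.grad.sign()               # steepest-descent step
            delta.clamp_(-eps, eps)                       # project onto the eps ball
            delta.copy_(torch.clamp(xb + delta, min=x_range[0], max=x_range[1]) - xb)
        delta.grad.zero_()
    return delta.detach()

if __name__ == '__main__':
    toy_model = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 2))
    xb, yb = torch.rand(8, 2), torch.randint(0, 2, (8,))
    delta = linf_pgd_sketch(toy_model, xb, yb)
    adv_acc = (toy_model(xb + delta).argmax(1) == yb).float().mean().item()
    print('adversarial accuracy on toy batch:', adv_acc)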


@ -0,0 +1,160 @@
import sys, copy
import torch, torchvision
from torch import optim, nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import utils.scripts.gendata as gendata
import utils.scripts.utils as utils
import numpy as np
import utils.scripts.gpu_utils as gu
import utils.scripts.ptb_utils as pu
def kaiming_init(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight.data)
        # kaiming init requires a tensor with >= 2 dims; biases are 1-D, so zero-init them instead
        nn.init.zeros_(m.bias.data)
class SequenceClassifier(nn.Module):
def __init__(self, seq_model, idim, hdim, hl, input_size, num_classes=2, many_to_many=False, unsqueeze_input=True):
super(SequenceClassifier, self).__init__()
self.seq_model = seq_model
self.hdim = hdim
self.hl = hl
self.input_size = input_size
self.idim = idim
self.num_classes = num_classes
self.unsqueeze_input = unsqueeze_input
self.many_to_many = many_to_many
self.seq_length = self.idim//self.input_size
self.seq = self.seq_model(input_size=input_size, hidden_size=hdim, num_layers=hl, batch_first=True)
self.lin_idim = hdim*self.seq_length if many_to_many else hdim
self.lin = nn.Linear(self.lin_idim, num_classes)
def forward(self, x):
if self.unsqueeze_input: x = x.unsqueeze(2)
bsize, idim, _ = x.shape
seq_length = idim//self.input_size
x = x.view((bsize, seq_length, self.input_size))
out, hidden = self.seq(x)
lin_in = out[:,-1,:]
if self.many_to_many: lin_in = out.contiguous().view((bsize, -1))
lin_out = self.lin(lin_in)
return lin_out
class GRUClassifier(SequenceClassifier):
def __init__(self, idim, hdim, hl, input_size, num_classes=2, many_to_many=False, unsqueeze_input=True):
super(GRUClassifier, self).__init__(nn.GRU, idim, hdim, hl, input_size, many_to_many=many_to_many, num_classes=num_classes, unsqueeze_input=unsqueeze_input)
class LSTMClassifier(SequenceClassifier):
def __init__(self, idim, hdim, hl, input_size, num_classes=2, many_to_many=False, unsqueeze_input=True):
super(LSTMClassifier, self).__init__(nn.LSTM, idim, hdim, hl, input_size, many_to_many=many_to_many, num_classes=num_classes, unsqueeze_input=unsqueeze_input)
class CNNClassifier(nn.Module):
def __init__(self, out_channels, hl, kernel_size, idim, num_classes=2, padding=None, stride=1, maxpool_kernel_size=None, use_maxpool=False):
"""
Fixed architecture:
- default max pool kernel size half of convolution kernel size
- default padding = kernel size - 1 // 2 to maintain same dimension
- stride = 1
- 1 FC layer
"""
        if padding is None: assert kernel_size % 2 == 1, "use odd kernel size, equal padding constraint"
super(CNNClassifier, self).__init__()
self.out_channels = out_channels
self.num_conv = hl
self.kernel_size = kernel_size
self.padding = padding or (self.kernel_size-1)//2
self.stride = 1
        self.num_classes = num_classes
self.idim = idim
self.use_maxpool = use_maxpool
self.maxpool_kernel_size = maxpool_kernel_size or self.kernel_size//2
self.maxpool = nn.MaxPool1d(self.maxpool_kernel_size)
self.ih_conv = nn.Conv1d(1, self.out_channels, self.kernel_size, padding=self.padding, stride=self.stride)
self.hh_convs = []
for _ in range(self.num_conv-1):
self.hh_convs.append(nn.Conv1d(self.out_channels, self.out_channels, self.kernel_size, padding=self.padding, stride=self.stride))
self.hh_convs.append(nn.ReLU())
self.hh_convs = nn.Sequential(*self.hh_convs)
fc_idim = int(self.idim/self.maxpool_kernel_size) if self.use_maxpool else self.idim
self.fc_layer = nn.Linear(self.out_channels*fc_idim, self.idim)
self.out_layer = nn.Linear(self.idim, self.num_classes)
self.relu = nn.ReLU()
def forward(self, x):
bs = x.shape[0]
x_ = x.unsqueeze(1)
x_ = self.relu(self.ih_conv(x_))
x_ = self.hh_convs(x_)
if self.use_maxpool: x_ = self.maxpool(x_)
x_ = self.relu(self.fc_layer(x_.view(bs, -1)))
return self.out_layer(x_)
class CNN2DClassifier(nn.Module):
def __init__(self, num_filters, filter_size, num_layers, input_shape, input_channels=1, stride=2, padding=None, num_stride2_layers=2, fc_idim=None, fc_odim=None, num_classes=2, use_avgpool=True, avgpool_ksize=5):
super(CNN2DClassifier, self).__init__()
self.outch = num_filters
self.fsize = filter_size
self.input_channels = input_channels
self.hl = num_layers
self.padding = (self.fsize-1)//2 if padding is None else padding
self.num_classes = num_classes
num_stride2_layers = num_stride2_layers
self.strides = iter([stride]*num_stride2_layers+[1]*(num_layers-num_stride2_layers))
self.use_avgpool = use_avgpool
self.avgpool_ksize = avgpool_ksize
self.convs = [nn.Conv2d(self.input_channels, self.outch, self.fsize, padding=self.padding, stride=next(self.strides)), nn.ReLU()]
if self.use_avgpool: self.convs.append(nn.AvgPool2d(self.avgpool_ksize))
for _ in range(self.hl-1):
self.convs.append(nn.Conv2d(self.outch, self.outch, self.fsize, stride=next(self.strides), padding=self.padding))
self.convs.append(nn.ReLU())
if self.use_avgpool: self.convs.append(nn.AvgPool2d(self.avgpool_ksize))
self.convs = nn.Sequential(*self.convs) # need to wrap for gpu
sl = min(self.hl, num_stride2_layers)
self.fc_idim = int(num_filters*input_shape[0]*input_shape[1]/float(4**sl)) if fc_idim is None else fc_idim
self.fc_odim = fc_odim if fc_odim is not None else self.fc_idim
self.fc = nn.Linear(self.fc_idim, self.fc_odim)
self.out = nn.Linear(self.fc_odim, self.num_classes)
def forward(self, x):
x = self.convs(x)
x = x.reshape(x.shape[0], -1)
return self.out(F.relu(self.fc(x)))
def get_linear(input_dim, num_classes):
return nn.Sequential(nn.Linear(input_dim, num_classes))
def get_fcn(idim, hdim, odim, hl=1, init=False, activation=nn.ReLU, use_activation=True, use_bn=False, input_dropout=0, dropout=0):
use_dropout = dropout > 0
layers = []
if input_dropout > 0: layers.append(nn.Dropout(input_dropout))
layers.append(nn.Linear(idim, hdim))
if use_activation: layers.append(activation())
if use_dropout: layers.append(nn.Dropout(dropout))
if use_bn: layers.append(nn.BatchNorm1d(hdim))
for _ in range(hl-1):
l = [nn.Linear(hdim, hdim)]
if use_activation: l.append(activation())
if use_dropout: l.append(nn.Dropout(dropout))
if use_bn: l.append(nn.BatchNorm1d(hdim))
layers.extend(l)
layers.append(nn.Linear(hdim, odim))
model = nn.Sequential(*layers)
if init: model.apply(kaiming_init)
return model
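# A small usage sketch for the classifiers above, with arbitrary toy shapes and
# hyperparameters (the real experiments configure these elsewhere). CNN2DClassifier
# is constructed with use_avgpool=False so that the default fc_idim formula matches
# the flattened convolutional feature size.
if __name__ == '__main__':
    fcn = get_fcn(idim=2, hdim=100, odim=2, hl=2)                    # 2 -> 100 -> 100 -> 2 MLP
    gru = GRUClassifier(idim=50, hdim=32, hl=1, input_size=1)        # 50-step sequence, 1 feature/step
    cnn2d = CNN2DClassifier(num_filters=8, filter_size=3, num_layers=2,
                            input_shape=(28, 28), use_avgpool=False)
    x_tab = torch.randn(16, 2)
    x_seq = torch.randn(16, 50)
    x_img = torch.randn(16, 1, 28, 28)
    print(fcn(x_tab).shape, gru(x_seq).shape, cnn2d(x_img).shape)    # each -> torch.Size([16, 2])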

717
utils/scripts/utils.py Normal file

@ -0,0 +1,717 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import copy
from collections import defaultdict, Counter, OrderedDict
import time
from torch.utils.data import TensorDataset, DataLoader
import torchvision
from torch import optim, nn
import torch.nn.functional as F
from scipy.linalg import qr
import utils.scripts.lms_utils as au
import utils.scripts.ptb_utils as pu
import utils.scripts.gpu_utils as gu
from sklearn import metrics
import collections
from sklearn.metrics import roc_auc_score
plt.style.use('seaborn-ticks')
import matplotlib.ticker as ticker
def get_orthonormal_matrix(n):
H = np.random.randn(n, n)
s = np.linalg.svd(H)[1]
s = s[s>1e-7]
if len(s) != n: return get_orthonormal_matrix(n)
Q, R = qr(H)
return Q
def get_dataloader(X, Y, bs, **kw):
return DataLoader(TensorDataset(X, Y), batch_size=bs, **kw)
def split_dataloader(dl, frac=0.5):
bs = dl.batch_size
X, Y = dl.dataset.tensors
p = torch.randperm(len(X))
X, Y = X[p, :], Y[p]
n = int(round(len(X)*frac))
X0, Y0 = X[:n, :], Y[:n]
X1, Y1 = X[n:, :], Y[n:]
dl0 = DataLoader(TensorDataset(torch.Tensor(X0), torch.LongTensor(Y0)), batch_size=bs, shuffle=True)
dl1 = DataLoader(TensorDataset(torch.Tensor(X1), torch.LongTensor(Y1)), batch_size=bs, shuffle=True)
return dl0, dl1
def _to_dl(X, Y, bs, shuffle=True):
return DataLoader(TensorDataset(torch.Tensor(X), torch.LongTensor(Y)), batch_size=bs, shuffle=shuffle)
def extract_tensors_from_loader(dl, repeat=1, transform_fn=None):
X, Y = [], []
for _ in range(repeat):
for xb, yb in dl:
if transform_fn:
xb, yb = transform_fn(xb, yb)
X.append(xb)
Y.append(yb)
X = torch.FloatTensor(torch.cat(X))
Y = torch.LongTensor(torch.cat(Y))
return X, Y
def extract_numpy_from_loader(dl, repeat=1, transform_fn=None):
X, Y = extract_tensors_from_loader(dl, repeat=repeat, transform_fn=transform_fn)
return X.numpy(), Y.numpy()
def _to_tensor_dl(dl, repeat=1, bs=None):
X, Y = extract_numpy_from_loader(dl, repeat=repeat)
dl = _to_dl(X, Y, bs if bs else dl.batch_size)
return dl
def flatten_loader(dl, bs=None):
X, Y = extract_numpy_from_loader(dl)
X = X.reshape(X.shape[0], -1)
return _to_dl(X, Y, bs=bs if bs else dl.batch_size)
def merge_loaders(dla, dlb):
bs = dla.batch_size
Xa, Ya = extract_numpy_from_loader(dla)
Xb, Yb = extract_numpy_from_loader(dlb)
return _to_dl(np.concatenate([Xa, Xb]), np.concatenate([Ya, Yb]), bs)
def transform_loader(dl, func, shuffle=True):
#assert type(dl.sampler) is torch.utils.data.sampler.SequentialSampler
X, Y = extract_numpy_from_loader(dl, transform_fn=func)
return _to_dl(X, Y, dl.batch_size, shuffle=shuffle)
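# A quick sketch of the loader helpers above on toy numpy data (shapes are arbitrary):
if __name__ == '__main__':
    X, Y = np.random.randn(100, 2), np.random.randint(0, 2, 100)
    dl = _to_dl(X, Y, bs=32)
    dl_a, dl_b = split_dataloader(dl, frac=0.8)              # random 80/20 split, same batch size
    dl_all = merge_loaders(dl_a, dl_b)                       # concatenates both splits into one loader
    dl_neg = transform_loader(dl, lambda xb, yb: (-xb, yb))  # applies a per-batch transform eagerly
    print(len(dl_a.dataset), len(dl_b.dataset), len(dl_all.dataset), len(dl_neg.dataset))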
def visualize_tensors(P, size=8, normalize=True, scale_each=False, permute=True, ax=None, pad_value=0.):
if ax is None: _, ax = plt.subplots(1,1,figsize=(20,4))
if permute:
s = np.random.choice(len(P), size=size, replace=False)
p = P[s]
else:
p = P[:size]
g = torchvision.utils.make_grid(torch.FloatTensor(p), nrow=size, normalize=normalize, scale_each=scale_each, pad_value=pad_value)
g = g.permute(1,2,0).numpy()
ax.imshow(g)
ax.set_xticks([])
ax.set_yticks([])
return ax
def visualize_loader(dl, ax=None, size=8, normalize=True, scale_each=False, reshape=None):
if ax is None: _, ax = plt.subplots(1,1,figsize=(20,4))
for xb, yb in dl: break
if reshape: xb = xb.reshape(len(xb), *reshape)
return visualize_tensors(xb, size=size, normalize=normalize, scale_each=scale_each, permute=True, ax=ax)
def visualize_loader_by_class(dl, ax=None, size=8, normalize=True, scale_each=False, reshape=None):
for xb, yb in dl: break
if reshape: xb = xb.reshape(len(xb), *reshape)
classes = list(set(list(yb.numpy())))
fig, axs = plt.subplots(len(classes), 1, figsize=(15, 3*len(classes)))
for y, ax in zip(classes, axs):
xb_ = xb[yb==y]
ax = visualize_tensors(xb_, size=size, normalize=normalize, scale_each=scale_each, permute=True, ax=ax)
ax.set_title('Class: {}'.format(y))
return fig
def visualize_perturbations(P, transform_fn=None):
if transform_fn is not None:
P = transform_fn(P)
plt.figure(figsize=(20,4))
s = np.random.choice(len(P), size=8, replace=False)
p = P[s]
g = torchvision.utils.make_grid(torch.FloatTensor(p))
g = g.permute(1,2,0).numpy()
    g = (g - g.min())/(g.max() - g.min())
plt.imshow(g)
def get_logits_given_tensor(X, model, device=None, bs=250, softmax=False):
if device is None: device = gu.get_device(None)
sampler = torch.utils.data.SequentialSampler(X)
sampler = torch.utils.data.BatchSampler(sampler, bs, False)
logits = []
with torch.no_grad():
model = model.to(device)
for idx in sampler:
xb = X[idx].to(device)
out = model(xb)
logits.append(out)
L = torch.cat(logits)
if softmax: return F.softmax(L, 1)
return L
def get_predictions_given_tensor(X, model, device=None, bs=250):
out = get_logits_given_tensor(X, model, device=device, bs=bs)
return torch.argmax(out, 1)
def get_accuracy_given_tensor(X, Y, model, device=None, bs=250):
if device is None: device = gu.get_device(None)
Y = torch.LongTensor(Y).to(device)
yhat = get_predictions_given_tensor(X, model, device=device, bs=bs)
return (Y==yhat).float().mean().item()
def compute_accuracy(X, Y, model):
with torch.no_grad():
pred = torch.argmax(model(X),1)
correct = (pred == Y).sum().item()
accuracy = correct/float(len(Y))
return accuracy
def compute_loss_and_accuracy_from_dl(dl, model, loss_fn, sample_pct=1.0, device=None, transform_fn=None):
in_tr_mode = model.training
model = model.eval()
data_size = float(len(dl.dataset))
samp_size = int(np.ceil(sample_pct*data_size))
num_eval = 0.
bs = dl.batch_size
accs, losses, bss = [], [], []
with torch.no_grad():
for xb, yb in dl:
xb, yb = xb.to(device, non_blocking=False), yb.to(device, non_blocking=False)
if transform_fn:
xb, yb = transform_fn(xb, yb)
sc = model(xb)
if loss_fn is F.cross_entropy:
loss = loss_fn(sc, yb, reduction='mean')
pred = torch.argmax(sc, 1)
elif loss_fn is F.binary_cross_entropy_with_logits:
loss = loss_fn(sc, yb.float().unsqueeze(1))
pred = (sc > 0.).long().squeeze()
elif loss_fn is hinge_loss:
loss = loss_fn(sc, yb)
pred = (sc > 0).long().squeeze()
else:
try:
loss = loss_fn(sc, yb)
pred = torch.argmax(sc, 1)
except:
assert False, "unknown loss function"
correct = (pred==yb).sum().float()
n = float(len(xb))
            losses.append(loss.item())
            accs.append((correct/n).item())
bss.append(n)
num_eval += n
if num_eval >= samp_size: break
accs, losses, bss = map(np.array, [accs, losses, bss])
if in_tr_mode: model = model.train()
    # weight both accuracy and loss by per-batch size (bss), not the loader batch size
    return np.sum(bss*accs)/num_eval, np.sum(bss*losses)/num_eval
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def get_logits(model, loader, device):
S, Y = [], []
with torch.no_grad():
for xb, yb in loader:
xb = xb.to(device)
out = model(xb).cpu().numpy()
S.append(out)
Y.append(list(yb))
S, Y = map(np.concatenate, [S, Y])
return S, Y
def get_scores(model, loader, device):
"""binary tasks only"""
S, Y = get_logits(model, loader, device)
return S[:,1]-S[:,0], Y
def get_multiclass_logit_score(L, Y):
scores = []
for idx, (l, y) in enumerate(zip(L, Y)):
sc_y = l[y]
indices = np.argsort(l)
best2_idx, best1_idx = indices[-2:]
sc_max = l[best2_idx] if y == best1_idx else l[best1_idx]
score = sc_y - sc_max
scores.append(score)
return np.array(scores)
def get_binary_auc(model, loader, device):
S, Y = get_scores(model, loader, device)
return roc_auc_score(Y, S)
def get_multiclass_auc(model, loader, device, one_vs_rest=True):
X, Y = extract_tensors_from_loader(loader)
S = get_logits_given_tensor(X, model, device=device, softmax=True).cpu()
mc = 'ovr' if one_vs_rest is True else 'ovo'
S, Y = S.numpy(), Y.numpy()
return roc_auc_score(Y, S, multi_class=mc)
def clip_gradient(model, clip_value):
params = list(filter(lambda p: p.grad is not None, model.parameters()))
for p in params: p.grad.data.clamp_(-clip_value, clip_value)
def print_model_gradients(model, print_bias=True):
for name, params in model.named_parameters():
if not print_bias and 'bias' in name: continue
if not params.requires_grad: continue
avg_grad = np.mean(params.grad.cpu().numpy())
print (name, params.shape, avg_grad)
def hinge_loss(out, y):
y_ = (2*y.float()-1).unsqueeze(1)
return torch.mean(F.relu(1-out*y_))
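# Worked example for hinge_loss: labels {0,1} are mapped to signs {-1,+1}, and the
# loss is mean(relu(1 - score*sign)). With out = [[2.0], [-0.5]] and y = [1, 0], the
# per-example margins are [2.0, 0.5], so hinge_loss(out, y) = mean([0.0, 0.5]) = 0.25.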
def pgd_adv_fit_model(model, opt, tr_dl, te_dl, attack, eval_attack=None, device=None, sch=None, max_epochs=100, epoch_gap=2,
min_loss=0.001, print_info=True, save_init_model=True):
# setup tracking
PR = lambda x: print (x) if print_info else None
stop_training = False
stats = defaultdict(list)
best_val, best_model = np.inf, None
adv_epoch_timer = []
epoch_gap_timer = [time.time()]
init_model = copy.deepcopy(model).cpu() if save_init_model else None
# eval attack
eval_attack = eval_attack or attack
print ("Min loss: {}".format(min_loss))
def standard_epoch(loader, model, optimizer=None, sch=None):
"""compute accuracy and loss. Backprop if optimizer provided"""
total_loss, total_err = 0.,0.
model = model.eval() if optimizer is None else model.train()
model = model.to(device)
update_params = optimizer is not None
with torch.set_grad_enabled(update_params):
for xb, yb in loader:
xb, yb = xb.to(device), yb.to(device)
yp = model(xb)
loss = F.cross_entropy(yp, yb)
if update_params:
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_err += (yp.max(dim=1)[1] != yb).sum().item()
total_loss += loss.item() * xb.shape[0]
return total_err / len(loader.dataset), total_loss / len(loader.dataset)
def adv_epoch(loader, model, attack, optimizer=None, sch=None):
"""compute adv accuracy and loss. Backprop if optimizer provided"""
start_time = time.time()
total_loss, total_err = 0.,0.
model = model.eval() if optimizer is None else model.train()
model = model.to(device)
update_params = optimizer is not None
for xb, yb in loader:
torch.set_grad_enabled(True)
xb, yb = xb.to(device), yb.to(device)
delta = attack.perturb(xb, yb, model).to(device)
xb = xb + delta
with torch.set_grad_enabled(update_params):
yp = model(xb)
loss = F.cross_entropy(yp, yb)
if update_params:
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_err += (yp.max(dim=1)[1] != yb).sum().item()
total_loss += loss.item() * xb.shape[0]
if optimizer is not None and sch is not None:
cur_lr = next(iter(opt.param_groups))['lr']
sch.step()
new_lr = next(iter(opt.param_groups))['lr']
if new_lr != cur_lr:
PR('Epoch {}, LR : {} -> {}'.format(epoch, cur_lr, new_lr))
total_time = time.time()-start_time
adv_epoch_timer.append(total_time)
return total_err / len(loader.dataset), total_loss / len(loader.dataset)
epoch = 0
while epoch < max_epochs:
if stop_training:
break
try:
stat = {}
model = model.train()
train_err, train_loss = adv_epoch(tr_dl, model, attack, optimizer=opt, sch=sch)
if epoch % epoch_gap == 0:
model = model.eval()
test_err, test_loss = standard_epoch(te_dl, model, optimizer=None, sch=None)
adv_err, adv_loss = adv_epoch(te_dl, model, eval_attack, optimizer=None, sch=None)
                # note: adv_err/test_err are error rates despite the 'acc_*' key names
                stat['acc_te'], stat['acc_te_std'] = adv_err, test_err
                stat['loss_te'], stat['loss_te_std'] = adv_loss, test_loss
if adv_err < best_val:
best_val = adv_err
best_model = copy.deepcopy(model).eval()
if print_info:
if epoch==0: print ("Epoch", "l-tr", "a-tr", "a-te", "s-te", "time", sep='\t')
#print (epoch, *("{:.4f}".format(i) for i in (train_loss, train_err)), sep=' ')
diff_time = time.time()-epoch_gap_timer[-1]
epoch_gap_timer.append(time.time())
print (epoch, *("{:.4f}".format(i) for i in (train_loss, 1.-train_err, 1.-adv_err, 1.-test_err, diff_time)), sep=' ')
if train_loss < min_loss:
stop_training = True
print ("Epoch {}: accuracy {:.3f} and loss {:.3f}".format(epoch, 1-train_err, train_loss))
stat['epoch'] = epoch
stat['acc_tr'] = train_err
stat['loss_tr'] = train_loss
for k, v in stat.items():
stats[k].append(v)
epoch += 1
except KeyboardInterrupt:
inp = input("LR num or Q or SAVE or GAP or MAXEPOCHS: ")
if inp.startswith('LR'):
lr = float(inp.split(' ')[-1])
cur_lr = next(iter(opt.param_groups))['lr']
PR("New LR: {}".format(lr))
for g in opt.param_groups: g['lr'] = lr
if inp.startswith('Q'):
stop_training = True
if inp.startswith('SAVE'):
fpath = inp.split(' ')[-1]
stats['best_model'] = (best_val, best_model.cpu())
torch.save({
'model': copy.deepcopy(model).cpu(),
'stats': stats,
                    'opt': copy.deepcopy(opt.state_dict())  # optimizers have no .cpu(); save the state dict instead
}, fpath)
PR(f'Saved to {fpath}')
if inp.startswith('GAP'):
_, gap = inp.split(' ')
gap = int(gap)
print ("epoch gap: {} -> {}".format(epoch_gap, gap))
epoch_gap = gap
if inp.startswith('MAXEPOCHS'):
_, me = inp.split(' ')
me = int(me)
print ("max_epochs: {} -> {}".format(max_epochs, me))
max_epochs = me
stats['best_model'] = (best_val, best_model.cpu())
stats['init_model'] = init_model
return stats
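# Hedged usage sketch for pgd_adv_fit_model. The toy dataset, model, optimizer and
# the one-step sign-gradient attack below are illustrative stand-ins; the real runs
# use the PGD attack classes from utils.scripts.ptb_utils.
class _ToyFGSM:
    """Exposes the .perturb(xb, yb, model) interface used by adv_epoch above."""
    def __init__(self, eps=0.1):
        self.eps = eps
    def perturb(self, xb, yb, model):
        xb = xb.clone().requires_grad_(True)
        F.cross_entropy(model(xb), yb).backward()
        return self.eps * xb.grad.sign()

if __name__ == '__main__':
    tr_dl = _to_dl(np.random.randn(512, 2), np.random.randint(0, 2, 512), bs=64)
    te_dl = _to_dl(np.random.randn(128, 2), np.random.randint(0, 2, 128), bs=64)
    toy_model = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 2))
    toy_opt = optim.SGD(toy_model.parameters(), lr=0.1)
    adv_stats = pgd_adv_fit_model(toy_model, toy_opt, tr_dl, te_dl, _ToyFGSM(eps=0.1),
                                  device='cpu', max_epochs=2, min_loss=1e-3)
    print('best adv error on toy data:', adv_stats['best_model'][0])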
def fit_model(model, loss, opt, train_dl, valid_dl, sch=None, epsilon=1e-2, is_loss_epsilon=False, update_gap=50, update_print_gap=50, gap=None,
print_info=True, save_grads=False, test_dl=None, skip_epoch_eval=True, sample_pct=0.5, sample_loss_threshold=0.75, save_models=False,
print_grads=False, print_model_layers=False, tr_batch_fn=None, te_batch_fn=None, device=None, max_updates=800_000, patience_updates=1,
enable_redo=False, save_best_model=True, save_init_model=True, max_epochs=100000, **misc):
# setup update metadata
MAX_LOSS_VAL = 1000000.
PR = lambda x: print (x) if print_info else None
use_epoch = False
if gap is not None: update_gap = update_print_gap = gap
bs_ratio = int(len(train_dl.dataset)/float(train_dl.batch_size))
act_update_gap = update_gap if not use_epoch else update_gap*bs_ratio
act_pr_update_gap = update_print_gap if not use_epoch else update_print_gap*bs_ratio
PR("accuracy/loss measured every {} updates".format(act_update_gap))
if save_models:
PR("saving models every {} updates".format(act_update_gap))
PR("update_print_gap: {}, epss: {}, bs: {}, device: {}".format(act_pr_update_gap, epsilon, train_dl.batch_size, device or 'cpu'))
# init_save setup
init_model = copy.deepcopy(model).cpu() if save_init_model else None
# redo setup
if enable_redo:
init_model_sd = copy.deepcopy(model.state_dict())
init_opt_sd = copy.deepcopy(opt.state_dict())
else:
init_model_sd = None
init_opt_sd = None
# best model setup
best_val, best_model = 0, None
# tracking setup
start_time = time.time()
num_evals, num_epochs, num_updates, num_patience = 0, 0, 0, 0
stats = dict(loss_tr=[], loss_te=[], acc_tr=[], acc_te=[], acc_test=[], loss_test=[], models=[], gradients=[])
if save_models: stats['models'].append(copy.deepcopy(model).cpu())
first_run, converged = True, False
print_stats_flag = update_print_gap is not None
exceeded_max = False
diverged = False
def _evaluate(device=device):
model.eval()
with torch.no_grad():
prev_loss = stats['loss_tr'][-1] if stats['loss_tr'] else 1.
tr_sample_pct = sample_pct if prev_loss > sample_loss_threshold else 1.
acc_tr, loss_tr = compute_loss_and_accuracy_from_dl(train_dl,model,loss,sample_pct=tr_sample_pct,device=device,transform_fn=tr_batch_fn)
acc_te, loss_te = compute_loss_and_accuracy_from_dl(valid_dl,model,loss,sample_pct=1.,device=device,transform_fn=te_batch_fn)
acc_tr, loss_tr, acc_te, loss_te = map(lambda x: x.item(), [acc_tr, loss_tr, acc_te, loss_te])
stats['loss_tr'].append(loss_tr)
stats['loss_te'].append(loss_te)
stats['acc_tr'].append(acc_tr)
stats['acc_te'].append(acc_te)
if test_dl is not None:
acc_test, loss_test = compute_loss_and_accuracy_from_dl(test_dl,model,loss,sample_pct=1.,device=device,transform_fn=te_batch_fn)
acc_test, loss_test = acc_test.item(), loss_test.item()
stats['acc_test'].append(acc_test)
stats['loss_test'].append(loss_test)
if save_models:
stats['models'].append(copy.deepcopy(model).cpu())
def _update(x,y,diff_device, device=device, save_grads=False, print_grads=False):
model.train()
# if diff_device:
# x = x.to(device, non_blocking=False)
# y = y.to(device, non_blocking=False)
opt.zero_grad()
out = model(x)
if loss is F.cross_entropy or loss is hinge_loss:
bloss = loss(out, y)
elif loss is F.binary_cross_entropy_with_logits:
bloss = loss(out, y.float().unsqueeze(1))
else:
try:
bloss = loss(out, y)
except:
assert False, "unknown loss function"
bloss.backward()
if print_grads and print_info: print_model_gradients(model)
#clip_gradient(model, clip_value)
opt.step()
if save_grads:
g = {k: v.grad.data.cpu().numpy() for k, v in model.named_parameters() if v.requires_grad}
stats['gradients'].append(g)
opt.zero_grad()
model.eval()
def print_time():
end_time = time.time()
minutes, seconds = divmod(end_time-start_time, 60)
gap_valid = len(stats['acc_tr']) > 0
gap = round(stats['acc_tr'][-1]-stats['acc_te'][-1],4) if gap_valid else 'na'
PR("converged after {} epochs in {}m {:1f}s, gap: {}".format(num_epochs, minutes, seconds, gap))
def print_stats(force_print=False):
if test_dl is None:
atr, ate, ltr = [stats[k][-1] for k in ['acc_tr', 'acc_te', 'loss_tr']]
PR("{} {:.4f} {:.4f} {:.4f}".format(num_updates, atr, ate, ltr))
if not print_info and force_print:
print ("{} {:.4f} {:.4f} {:.4f}".format(num_updates, atr, ate, ltr))
else:
atr, aval, ate, ltr = [stats[k][-1] for k in ['acc_tr', 'acc_te', 'acc_test', 'loss_tr']]
PR("{} {:.4f} {:.4f} {:.4f} {:.4f}".format(num_updates, atr, aval, ate, ltr))
if not print_info and force_print:
print ("{} {:.4f} {:.4f} {:.4f} {:.4f}".format(num_updates, atr, aval, ate, ltr))
#xb_, yb_ = next(iter(train_dl))
diff_device = True #xb_.device != device
if test_dl is None: PR("#updates, train acc, test acc, train loss")
else: PR("#updates, train acc, val acc, test acc, train loss")
while not converged or num_patience < patience_updates:
try:
model.train()
for xb, yb in train_dl:
if tr_batch_fn:
xb, yb = tr_batch_fn(xb, yb)
if diff_device:
xb = xb.to(device, non_blocking=False)
yb = yb.to(device, non_blocking=False)
if converged:
num_patience += 1
if converged and num_patience == patience_updates:
_evaluate()
print_stats()
break
# update flag for printing gradients
update_flag = print_model_layers and (num_updates == 0 or (num_updates % act_update_gap == 0 and print_grads))
_update(xb, yb, diff_device, device=device, save_grads=save_grads, print_grads=update_flag)
if (num_evals == 0 or num_updates % act_update_gap == 0):
num_evals += 1
_evaluate()
print_stats()
val_acc = stats['acc_te'][-1]
if num_updates > 0 and val_acc >= best_val:
best_val = val_acc
best_model = copy.deepcopy(model).eval()
# check if loss has diverged
loss_val = max(stats['loss_tr'][-1], stats['loss_te'][-1])
if loss_val > MAX_LOSS_VAL: diverged = True
if not np.isfinite(loss_val): diverged = True
if is_loss_epsilon: stop = stats['loss_tr'][-1] < epsilon
else: stop = stats['acc_tr'][-1] >= 1-epsilon
if not converged and diverged:
converged = True
print_time()
PR("loss diverging...exiting".format(patience_updates))
if not converged and stop:
converged = True
print_time()
PR("init-ing patience ({} updates)".format(patience_updates))
num_updates += 1
first_run = False
if num_updates > max_updates:
converged = True
exceeded_max = True
num_patience = patience_updates
PR("Exceeded max updates")
print_stats()
print_time()
break
# re-eval at the end of epoch
if not converged:
num_epochs += 1
if not converged and num_epochs >= max_epochs:
converged = True
exceeded_max = True
num_patience = patience_updates
PR("Exceeded max epochs")
print_stats()
print_time()
break
if not skip_epoch_eval:
_evaluate()
print_stats()
if is_loss_epsilon: stop = stats['loss_tr'][-1] < epsilon
else: stop = stats['acc_tr'][-1] >= 1-epsilon
if not converged and stop:
converged = True
print_time()
PR("init-ing patience ({} updates)".format(patience_updates))
if num_patience >= patience_updates:
_evaluate()
print_stats()
break
# update LR via scheduler
if sch is not None:
cur_lr = next(iter(opt.param_groups))['lr']
sch.step()
new_lr = next(iter(opt.param_groups))['lr']
if new_lr != cur_lr:
PR('Epoch {}, LR : {} -> {}'.format(num_epochs, cur_lr, new_lr))
except KeyboardInterrupt:
inp = input("LR num or Q or GAP num or SAVE fpath or EVAL or REDO: ")
if inp.startswith('LR'):
lr = float(inp.split(' ')[-1])
cur_lr = next(iter(opt.param_groups))['lr']
PR("LR: {} - > {}".format(cur_lr, lr))
for g in opt.param_groups: g['lr'] = lr
elif inp.startswith('GAP'):
gap = int(inp.split(' ')[-1])
act_update_gap = act_pr_update_gap = gap
elif inp == "Q":
converged = True
num_patience = patience_updates
print_time()
elif inp.startswith('SAVE'):
fpath = inp.split(' ')[-1]
torch.save({
'model': model,
'opt': opt,
'update_gap': update_gap
}, fpath)
elif inp == 'EVAL':
_evaluate()
print_stats(True)
elif inp == 'REDO':
if enable_redo:
model.load_state_dict(init_model_sd)
opt.load_state_dict(init_opt_sd)
else:
print ("REDO disabled")
best_test = None
if test_dl is not None:
best_test = compute_loss_and_accuracy_from_dl(test_dl, best_model, loss, sample_pct=1.0, device=device)[0].item()
stats['num_updates'] = num_updates
stats['num_epochs'] = num_epochs
stats['update_gap'] = update_gap
stats['best_model'] = (best_val, best_test, best_model.cpu() if best_model else model.cpu())
stats['init_model'] = init_model
if save_models: stats['models'].append(copy.deepcopy(model).cpu())
stats['x_updates']= list(range(0, num_evals*(update_gap+1), update_gap))
stats['x'] = stats['x_updates'][:]
stats['x_epochs'] = list(range(num_epochs))
stats['gap'] = stats['acc_tr'][-1]-stats['acc_te'][-1]
return stats
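# Illustrative call to fit_model on a toy binary problem; the hyperparameters below
# are arbitrary and not the repo's configured values.
if __name__ == '__main__':
    toy_net = nn.Sequential(nn.Linear(2, 50), nn.ReLU(), nn.Linear(50, 2))
    toy_opt = optim.SGD(toy_net.parameters(), lr=0.1, weight_decay=5e-5)
    tr_dl = _to_dl(np.random.randn(512, 2), np.random.randint(0, 2, 512), bs=64)
    va_dl = _to_dl(np.random.randn(128, 2), np.random.randint(0, 2, 128), bs=64)
    fit_stats = fit_model(toy_net, F.cross_entropy, toy_opt, tr_dl, va_dl,
                          epsilon=0.05, gap=10, max_epochs=5, device=torch.device('cpu'))
    print(fit_stats['acc_tr'][-1], fit_stats['acc_te'][-1], fit_stats['gap'])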
def save_pickle(fname, d, mode='wb'):
    # pickle requires binary file modes
    with open(fname, mode) as f:
        pickle.dump(d, f)
def load_pickle(fname, mode='rb'):
    with open(fname, mode) as f:
        return pickle.load(f)
def update_ax(ax, title=None, xlabel=None, ylabel=None, legend_loc='best', ticks=True, ticks_fs=10, label_fs=12, legend_fs=12, title_fs=14, hide_xlabels=False, hide_ylabels=False, despine=True):
if title: ax.set_title(title, fontsize=title_fs)
if xlabel: ax.set_xlabel(xlabel, fontsize=label_fs)
if ylabel: ax.set_ylabel(ylabel, fontsize=label_fs)
if legend_loc: ax.legend(loc=legend_loc, fontsize=legend_fs)
if despine: sns.despine(ax=ax)
if ticks:
# ax.minorticks_on()
ax.tick_params(direction='in', length=6, width=2, colors='k', which='major', top=False, right=False)
ax.tick_params(direction='in', length=4, width=1, colors='k', which='minor', top=False, right=False)
ax.tick_params(labelsize=ticks_fs)
if hide_xlabels: ax.set_xticks([])
if hide_ylabels: ax.set_yticks([])
return ax

92
utils/slab_data.py Normal file

@ -0,0 +1,92 @@
import sys
import random
import os, copy, pickle, time
import argparse
import itertools
from collections import defaultdict, Counter, OrderedDict
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torchvision
from torch import optim, nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch.autograd import Variable
import pandas as pd
import utils.scripts.gpu_utils as gu
import utils.scripts.gendata as gendata
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
DEVICE_ID = 0 # GPU_ID or None (CPU)
DEVICE = gu.get_device(DEVICE_ID)
def get_data(num_samples, spur_corr, total_slabs):
    if total_slabs == 5:
        slab_5_flag= 1
        slab_7_flag= 0
    elif total_slabs == 7:
        slab_5_flag= 0
        slab_7_flag= 1
    else:
        raise ValueError('total_slabs must be 5 or 7')
c = config = {
'num_train': num_samples, # training dataset size
'dim': 2, # input dimension
'lin_margin': 0.1, # linear margin
'slab_margin': 0.1, # slab margin,
'same_margin': True, # keep same margin
'random_transform': True, # apply random (orthogonal) transformation
'width': 1, # data width in standard basis
'num_lin': 1, # number of linear components
'num_slabs': slab_5_flag, #. number of 5 slabs
'num_slabs7': slab_7_flag, # number of 7 slabs
'num_slabs3': 0, # number of 3 slabs
'bs': 256, # batch size
'corrupt_lin': spur_corr, # p_noise
'corrupt_lin_margin': True, # noise model
'corrupt_slab': 0.0, # slab corruption
'num_test': 0, # test dataset size
'hdim': 100, # model width
'hl': 2, # model depth
'mtype': 'fcn', # model architecture
'device': gu.get_device(DEVICE_ID), # GPU device
'lr': 0.1, # step size
'weight_decay': 5e-5 # weight decay
}
smargin = c['lin_margin'] if c['same_margin'] else c['slab_margin']
data_func = gendata.generate_ub_linslab_data_v2
spc = [3]*c['num_slabs3']+[5]*c['num_slabs'] + [7]*c['num_slabs7']
data = data_func(c['num_train'], c['dim'], c['lin_margin'], slabs_per_coord=spc, eff_slab_margin=smargin, random_transform=c['random_transform'], N_te=c['num_test'],
corrupt_lin_margin=c['corrupt_lin_margin'], num_lin=c['num_lin'], num_slabs=c['num_slabs3']+c['num_slabs']+c['num_slabs7'], width=c['width'], bs=c['bs'], corrupt_lin=c['corrupt_lin'], corrupt_slab=c['corrupt_slab'])
#Project data
W = data['W']
X, Y = data['X'], data['Y']
X = X.numpy().dot(W.T)
Y = Y.numpy()
#Generate objects/slab_id array
    # Equation: k*slab_len + (k-1)*0.2 = 2, where k is the total number of slabs and 0.2 is the gap width
    # Slabs: [-1, -1 + slab_len], [-1 + slab_len + 0.2, -1 + 0.2 + 2*slab_len], ...
    # Slabs: for i in range(0, k): [-1 + i*(0.2 + slab_len), -1 + slab_len + i*(0.2 + slab_len)]
k=total_slabs
slab_len= (2 - 0.2*(k-1))/k
O=np.zeros(X.shape[0])
for idx in range(X.shape[0]):
for slab_id in range(k):
start= -1 + slab_id*(0.2 + slab_len)
end= start+ slab_len
if X[idx, 1] >= start and X[idx, 1] <= end:
O[idx]= slab_id
break
return data, X, Y, O
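# Worked example of the slab-width arithmetic in get_data: k slabs plus (k-1) gaps
# of width 0.2 must tile the coordinate range [-1, 1] (total length 2), giving
# slab_len = (2 - 0.2*(k-1)) / k.
if __name__ == '__main__':
    for k in (5, 7):
        slab_len = (2 - 0.2*(k - 1)) / k
        slabs = [(-1 + i*(0.2 + slab_len), -1 + slab_len + i*(0.2 + slab_len)) for i in range(k)]
        print(k, round(slab_len, 3), slabs[0], slabs[-1])
    # k=5 -> slab_len = 0.24,  first slab [-1.0, -0.76],  last slab [0.76, 1.0]
    # k=7 -> slab_len ~ 0.114, first slab [-1.0, -0.886], last slab [0.886, 1.0]
    # hypothetical call (requires the gendata/gpu utilities imported above):
    # data, X, Y, O = get_data(num_samples=1000, spur_corr=0.0, total_slabs=5)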