From 7ee0caa5a93c9868abc82644d73df489a08c2079 Mon Sep 17 00:00:00 2001 From: divyat09 Date: Wed, 17 Mar 2021 00:41:15 +0000 Subject: [PATCH] Slab Synthetic Data integrated with RobustDG --- algorithms/algo.py | 5 +- algorithms/perf_match.py | 121 +++++ data/slab_loader.py | 102 +++++ evaluation/base_eval.py | 8 +- models/slab.py | 35 ++ test.py | 4 + train.py | 11 + utils/helper.py | 26 +- utils/scripts/data_utils.py | 202 +++++++++ utils/scripts/ensemble.py | 130 ++++++ utils/scripts/gendata.py | 316 +++++++++++++ utils/scripts/gpu_utils.py | 100 +++++ utils/scripts/lms_utils.py | 307 +++++++++++++ utils/scripts/mnistcifar_utils.py | 98 ++++ utils/scripts/ptb_utils.py | 236 ++++++++++ utils/scripts/synth_models.py | 160 +++++++ utils/scripts/utils.py | 717 ++++++++++++++++++++++++++++++ utils/slab_data.py | 92 ++++ 18 files changed, 2665 insertions(+), 5 deletions(-) create mode 100644 algorithms/perf_match.py create mode 100644 data/slab_loader.py create mode 100644 models/slab.py create mode 100644 utils/scripts/data_utils.py create mode 100644 utils/scripts/ensemble.py create mode 100644 utils/scripts/gendata.py create mode 100644 utils/scripts/gpu_utils.py create mode 100644 utils/scripts/lms_utils.py create mode 100644 utils/scripts/mnistcifar_utils.py create mode 100644 utils/scripts/ptb_utils.py create mode 100644 utils/scripts/synth_models.py create mode 100644 utils/scripts/utils.py create mode 100644 utils/slab_data.py diff --git a/algorithms/algo.py b/algorithms/algo.py index c966b81..9918e97 100644 --- a/algorithms/algo.py +++ b/algorithms/algo.py @@ -52,6 +52,10 @@ class BaseAlgo(): if self.args.model_name == 'lenet': from models.lenet import LeNet5 phi= LeNet5() + + if self.args.model_name == 'slab': + from models.slab import SlabClf + phi= SlabClf(self.args.slab_data_dim, self.args.out_classes) if self.args.model_name == 'fc': from models.fc import FC @@ -148,7 +152,6 @@ class BaseAlgo(): with torch.no_grad(): x_e= x_e.to(self.cuda) y_e= torch.argmax(y_e, dim=1).to(self.cuda) - d_e = torch.argmax(d_e, dim=1).numpy() #Forward Pass out= self.phi(x_e) diff --git a/algorithms/perf_match.py b/algorithms/perf_match.py new file mode 100644 index 0000000..dda6e3f --- /dev/null +++ b/algorithms/perf_match.py @@ -0,0 +1,121 @@ +import sys +import numpy as np +import argparse +import copy +import random +import json + +import torch +from torch.autograd import grad +from torch import nn, optim +from torch.nn import functional as F +from torchvision import datasets, transforms +from torchvision.utils import save_image +from torch.autograd import Variable +import torch.utils.data as data_utils + +from .algo import BaseAlgo +from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity, slab_batch_process + +class PerfMatch(BaseAlgo): + def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda): + + super().__init__(args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda) + + def train(self): + + self.max_epoch=-1 + self.max_val_acc=0.0 + for epoch in range(self.args.epochs): + + penalty_erm=0 + penalty_ws=0 + train_acc= 0.0 + train_size=0 + + #Batch iteration over single epoch + for batch_idx, (x_e, y_e ,d_e, idx_e) in enumerate(self.train_dataset): + # print('Batch Idx: ', batch_idx) + + #Process current batch as per slab dataset +# x_e, y_e ,d_e, idx_e= slab_batch_process(x_e, y_e ,d_e, idx_e) + + self.opt.zero_grad() + loss_e= torch.tensor(0.0).to(self.cuda) + + x_e= x_e.to(self.cuda) + y_e= 
torch.argmax(y_e, dim=1).to(self.cuda) + + #Forward Pass + out= self.phi(x_e) + erm_loss= F.cross_entropy(out, y_e.long()).to(self.cuda) + loss_e+= erm_loss + penalty_erm += float(loss_e) + + #Perfect Match Penalty + ws_loss= torch.tensor(0.0).to(self.cuda) + counter=0 + match_objs= np.unique(idx_e) + feat= self.phi.feat_net(x_e) + for obj in match_objs: + indices= idx_e == obj + feat_obj= feat[indices] + d_obj= d_e[indices] + + match_domains= torch.unique(d_obj) + + if len(match_domains) != len(torch.unique(d_e)): + # print('Error: Positivty Violation, objects not present in all the domains') + continue + + for d_i in range(len(match_domains)): + for d_j in range(len(match_domains)): + if d_j <= d_i: + continue + x1= feat_obj[ d_obj == d_i ] + x2= feat_obj[ d_obj == d_j ] + + #Typecasting + # print(x1.shape, x2.shape) + x1= x1.view(x1.shape[0], 1, x1.shape[1]) + ws_loss= torch.sum( torch.sum( torch.sum( (x1 -x2)**2, dim=2), dim=1 ) ) + # ws_loss= torch.sum( torch.sum( torch.sum( torch.abs(x1 -x2), dim=2), dim=1 ) ) + counter+= x1.shape[0]*x2.shape[0] + + ws_loss= ws_loss/counter + penalty_ws += float(ws_loss) + + #Backprop + loss_e+= self.args.penalty_ws*ws_loss*((epoch+1)/self.args.epochs) + loss_e.backward(retain_graph=False) + self.opt.step() + + del erm_loss + del loss_e + torch.cuda.empty_cache() + + train_acc+= torch.sum(torch.argmax(out, dim=1) == y_e ).item() + train_size+= y_e.shape[0] + + + print('Train Loss Basic : ', penalty_erm, penalty_ws ) + print('Train Acc Env : ', 100*train_acc/train_size ) + print('Done Training for epoch: ', epoch) + + #Train Dataset Accuracy + self.train_acc.append( 100*train_acc/train_size ) + + #Val Dataset Accuracy + self.val_acc.append( self.get_test_accuracy('val') ) + + #Test Dataset Accuracy + self.final_acc.append( self.get_test_accuracy('test') ) + + #Save the model if current best epoch as per validation loss + if self.val_acc[-1] > self.max_val_acc: + self.max_val_acc=self.val_acc[-1] + self.max_epoch= epoch + self.save_model() + + # Save the model's weights post training + self.save_model() \ No newline at end of file diff --git a/data/slab_loader.py b/data/slab_loader.py new file mode 100644 index 0000000..a347adf --- /dev/null +++ b/data/slab_loader.py @@ -0,0 +1,102 @@ +#Common imports +import os +import random +import copy +import numpy as np +import h5py +from PIL import Image + +#Pytorch +import torch +import torch.utils.data as data_utils +from torchvision import datasets, transforms +from torchvision import datasets, transforms + +#Base Class +from .data_loader import BaseDataLoader + +#Specific Modules +from utils.slab_data import * + + +class SlabData(BaseDataLoader): + def __init__(self, args, list_train_domains, root, transform=None, data_case='train', match_func=False, base_size=10000, freq_ratio=50, data_dim=2, total_slabs=5): + + super().__init__(args, list_train_domains, root, transform, data_case, match_func) + + self.base_size = base_size + self.freq_ratio= freq_ratio + self.data_dim= data_dim + self.total_slabs= total_slabs + + print(list_train_domains) + if self.data_case == 'train': + self.domain_size = [self.base_size, self.base_size] + # Default Train Domains: [0.0, 0.10] + self.spur_probs= [ float(domain) for domain in list_train_domains ] + + elif self.data_case == 'val': + self.domain_size = [int(self.base_size/4), int(self.base_size/4)] + self.spur_probs= [ float(domain) for domain in list_train_domains ] + + elif self.data_case == 'test': + self.domain_size = [self.base_size] + self.spur_probs= [1.0] + + 
print('\n') + print('Data Case: ', self.data_case) + + self.train_data, self.train_labels, self.train_domain, self.train_indices = self._get_data(self.domain_size, self.data_dim, self.total_slabs, self.spur_probs) + + def _get_data(self, domain_size, data_dim, total_slabs, spur_probs): + + list_data = [] + list_labels = [] + list_objs= [] + total_domains= len(domain_size) + + for idx in range(total_domains): + + num_samples= domain_size[idx] + spur_prob= spur_probs[idx] + + _, data, labels, match_obj= get_data(num_samples, spur_prob, total_slabs) + print('Source Domain: ', idx, ' Size: ', data.shape, labels.shape, match_obj.shape) + list_data.append(torch.tensor(data)) + list_labels.append(torch.tensor(labels)) + list_objs.append(match_obj) + + # Stack data from the different domains + data_feat = torch.cat(list_data) + data_labels = torch.cat(list_labels) + data_objs= np.hstack(list_objs) + + # Create domain labels + data_domains = torch.zeros(data_labels.size()) + domain_start=0 + for idx in range(total_domains): + curr_domain_size= domain_size[idx] + data_domains[ domain_start: domain_start+ curr_domain_size ] += idx + domain_start+= curr_domain_size + + # Shuffle everything one more time + # inds = np.arange(data_labels.size()[0]) + # np.random.shuffle(inds) + # data_feat = data_feat[inds] + # data_labels = data_labels[inds].long() + # data_domains = data_domains[inds].long() + + # Convert to onehot + y = torch.eye(2) + data_labels = y[data_labels] +# # Convert to onehot +# d = torch.eye(len(self.list_train_domains)) +# data_domains = d[data_domains] + + #Type Casting + data_feat= data_feat.type(torch.FloatTensor) + data_labels = data_labels.long() + data_domains = data_domains.long() + + print('Final Dataset: ', data_feat.shape, data_labels.shape, data_domains.shape, data_objs.shape) + return data_feat, data_labels, data_domains, data_objs \ No newline at end of file diff --git a/evaluation/base_eval.py b/evaluation/base_eval.py index baa9afd..e9bd2a3 100644 --- a/evaluation/base_eval.py +++ b/evaluation/base_eval.py @@ -62,7 +62,11 @@ class BaseEval(): if self.args.model_name == 'lenet': from models.lenet import LeNet5 phi= LeNet5() - + + if self.args.model_name == 'slab': + from models.slab import SlabClf + phi= SlabClf(self.args.slab_data_dim, self.args.out_classes) + if self.args.model_name == 'fc': from models.fc import FC if self.args.method_name in ['csd', 'matchdg_ctr']: @@ -111,7 +115,7 @@ class BaseEval(): def load_model(self, run_matchdg_erm): - if self.args.method_name in ['erm_match', 'csd', 'irm']: + if self.args.method_name in ['erm_match', 'csd', 'irm', 'perf_match']: self.save_path= self.base_res_dir + '/Model_' + self.post_string elif self.args.method_name == 'matchdg_ctr': diff --git a/models/slab.py b/models/slab.py new file mode 100644 index 0000000..9f1d10c --- /dev/null +++ b/models/slab.py @@ -0,0 +1,35 @@ +import torch +import torch.utils.data +from torch import nn, optim +from torch.nn import functional as F +from torchvision import datasets, transforms +from torchvision.utils import save_image +from torch.autograd import Variable + + +class SlabClf(nn.Module): + def __init__(self, inp_shape, out_shape): + + super(SlabClf, self).__init__() + self.inp_shape = inp_shape + self.out_shape = out_shape + self.hidden_dim = 100 + self.feat_net= nn.Sequential( + nn.Linear( self.inp_shape, self.hidden_dim), + nn.ReLU(), + ) + + self.fc= nn.Sequential( + nn.Linear( self.hidden_dim, self.hidden_dim), + nn.Linear( self.hidden_dim, self.out_shape), + ) + + self.disc= 
nn.Sequential( + nn.Linear( self.hidden_dim, self.hidden_dim), + nn.Linear( self.hidden_dim, 2), + ) + + self.embedding = nn.Embedding(2, self.hidden_dim) + + def forward(self, x): + return self.fc(self.feat_net(x)) \ No newline at end of file diff --git a/test.py b/test.py index 6495818..bef8881 100644 --- a/test.py +++ b/test.py @@ -45,6 +45,10 @@ parser.add_argument('--img_h', type=int, default= 224, help='Height of the image in dataset') parser.add_argument('--img_w', type=int, default= 224, help='Width of the image in dataset') +parser.add_argument('--slab_data_dim', type=int, default= 2, + help='Number of features in the slab dataset') +parser.add_argument('--slab_total_slabs', type=int, default=7) +parser.add_argument('--slab_num_samples', type=int, default=1000) parser.add_argument('--fc_layer', type=int, default= 1, help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet') parser.add_argument('--match_layer', type=str, default='logit_match', diff --git a/train.py b/train.py index 89849cc..a7189a0 100644 --- a/train.py +++ b/train.py @@ -42,6 +42,10 @@ parser.add_argument('--img_h', type=int, default= 224, help='Height of the image in dataset') parser.add_argument('--img_w', type=int, default= 224, help='Width of the image in dataset') +parser.add_argument('--slab_data_dim', type=int, default= 2, + help='Number of features in the slab dataset') +parser.add_argument('--slab_total_slabs', type=int, default=7) +parser.add_argument('--slab_num_samples', type=int, default=1000) parser.add_argument('--fc_layer', type=int, default= 1, help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet') parser.add_argument('--match_layer', type=str, default='logit_match', @@ -172,6 +176,13 @@ for run in range(args.n_runs): test_dataset, base_res_dir, run, cuda ) + if args.method_name == 'perf_match': + from algorithms.perf_match import PerfMatch + train_method= PerfMatch( + args, train_dataset, val_dataset, + test_dataset, base_res_dir, + run, cuda + ) elif args.method_name == 'matchdg_ctr': from algorithms.match_dg import MatchDG ctr_phase=1 diff --git a/utils/helper.py b/utils/helper.py index a6075b6..333fe08 100644 --- a/utils/helper.py +++ b/utils/helper.py @@ -14,6 +14,23 @@ from torchvision.utils import save_image from torch.autograd import Variable import torch.utils.data as data_utils +# Slab Dataset: Flatten the tensor along batch and domain axis +# Input of the shape (Batch, Domain, Feat) +def slab_batch_process(x, y, d, o): + if len(x.shape) > 2: + x= x.flatten(start_dim=0, end_dim=1) + + if len(y.shape) > 1: + y= y.flatten(start_dim=0, end_dim=1) + + if len(d.shape) > 1: + d= d.flatten(start_dim=0, end_dim=1) + + if len(o.shape) > 1: + o= o.flatten(start_dim=0, end_dim=1) + + return x, y, d, o + def t_sne_plot(X): X= X.detach().cpu().numpy() X= TSNE(n_components=2).fit_transform(X) @@ -213,7 +230,9 @@ def get_dataloader(args, run, domains, data_case, eval_case, kwargs): from data.pacs_loader_aug import PACSAug as PACS else: from data.pacs_loader import PACS - + + elif args.dataset_name == 'slab': + from data.slab_loader import SlabData if data_case == 'train': match_func=True @@ -237,7 +256,10 @@ def get_dataloader(args, run, domains, data_case, eval_case, kwargs): except AttributeError: batch_size= batch_size - if args.dataset_name in ['pacs', 'vlcs']: + if args.dataset_name == 'slab': + data_obj= SlabData(args, domains, '/slab/', data_case=data_case, match_func=match_func, 
base_size=args.slab_num_samples, freq_ratio=50, data_dim=args.slab_data_dim, total_slabs=args.slab_total_slabs) + + elif args.dataset_name in ['pacs', 'vlcs']: data_obj= PACS(args, domains, '/pacs/train_val_splits/', data_case=data_case, match_func=match_func) elif args.dataset_name in ['chestxray']: diff --git a/utils/scripts/data_utils.py b/utils/scripts/data_utils.py new file mode 100644 index 0000000..d801e3a --- /dev/null +++ b/utils/scripts/data_utils.py @@ -0,0 +1,202 @@ +import sys + +import random, os, copy, pickle, time, random, argparse, itertools +from collections import defaultdict, Counter, OrderedDict +import numpy as np +import torch +import torchvision +from torch import optim, nn +import torch.nn.functional as F +from sklearn import metrics +from torch.utils.data import TensorDataset, DataLoader + +import utils.scripts.gpu_utils as gu +import utils.scripts.lms_utils as au +import utils.scripts.synth_models as synth_models +import utils.scripts.utils as utils +import matplotlib.pyplot as plt +import pathlib + +try: + sys.path.append('../../cifar10_models/') + import cifar10_models as c10 + c10_not_found = False +except: + c10_not_found = True + +torch.backends.cudnn.benchmark = True +torch.backends.cudnn.enabled = True + +REPO_DIR = pathlib.Path(__file__).parent.parent.absolute() +DOWNLOAD_DIR = os.path.join(REPO_DIR, 'datasets') + +def msd(x, r=3): + return np.round(np.mean(x), r), np.round(np.std(x), r) + +def _get_dataloaders(trd, ted, bs, pm=True, shuffle=True): + train_dl = DataLoader(trd, batch_size=bs, shuffle=shuffle, pin_memory=pm) + test_dl = DataLoader(ted, batch_size=bs, pin_memory=pm) + return train_dl, test_dl + +def get_cifar10_models(device=None, pretrained=True): + if c10_not_found: return {} + device = gu.get_device(None) if device is None else device + get_lmbda = lambda cls: (lambda: cls(pretrained=pretrained).eval().to(device)) + return { + 'vgg11_bn': get_lmbda(c10.vgg11_bn), + 'vgg13_bn': get_lmbda(c10.vgg13_bn), + 'vgg16_bn': get_lmbda(c10.vgg16_bn), + 'vgg19_bn': get_lmbda(c10.vgg19_bn), + 'resnet18': get_lmbda(c10.resnet18), + 'resnet34': get_lmbda(c10.resnet34), + 'resnet50': get_lmbda(c10.resnet50), + 'densenet121': get_lmbda(c10.densenet121), + 'densenet161': get_lmbda(c10.densenet161), + 'densenet169': get_lmbda(c10.densenet169), + 'mobilenet_v2': get_lmbda(c10.mobilenet_v2), + 'googlenet': get_lmbda(c10.googlenet), + 'inception_v3': get_lmbda(c10.inception_v3) + } + +def plot_decision_boundary(dl, model, c1, c2, ax=None, print_info=True): + if ax is None: fig, ax = plt.subplots(1,1,figsize=(6,4)) + model = model.cpu() + deps = sorted(au.get_feature_deps(dl, model).items(), key=lambda t: t[-1]) + + if print_info: + for k, v in deps: print ('{}:{:.3f}'.format(k,v), end=', ') + print ("") + + X, Y = utils.extract_numpy_from_loader(dl) + K = 100_000 + U = np.random.uniform(low=X.min(), high=X.max(), size=(K, X.shape[1])) # copy.deepcopy(X) + U[:, c1] = np.random.uniform(low=X[:, c1].min(), high=X[:, c1].max(), size=K) + U[:, c2] = np.random.uniform(low=X[:, c2].min(), high=X[:, c2].max(), size=K) + U = torch.Tensor(U) + + with torch.no_grad(): + out = model(U) + Yu = torch.argmax(out, 1) + + ax.scatter(U[:,c1], U[:,c2], c=Yu, alpha=0.3, s=24) + ax.scatter(X[:,c1], X[:,c2], c=Y, cmap='coolwarm', s=12) + +def get_binary_datasets(X, Y, y1, y2, image_width=28, use_cnn=False): + assert type(X) is np.ndarray and type(Y) is np.ndarray + idx0 = (Y==y1).nonzero()[0] + idx1 = (Y==y2).nonzero()[0] + idx = np.concatenate((idx0, idx1)) + X_, Y_ = X[idx,:], 
(Y[idx]==y2).astype(int) + P = np.random.permutation(len(X_)) + X_, Y_ = X_[P,:], Y_[P] + if use_cnn: X_ = X_.reshape(X.shape[0], -1, image_width)[:, None, :, :] + return X_[P,:], Y_[P] + +def get_binary_loader(dl, y1, y2): + X, Y = utils.extract_numpy_from_loader(dl) + X, Y = get_binary_datasets(X, Y, y1, y2) + return utils._to_dl(X, Y, bs=dl.batch_size) + +def get_mnist(fpath=DOWNLOAD_DIR, flatten=False, binarize=False, normalize=True, y0={0,1,2,3,4}): + """get preprocessed mnist torch.TensorDataset class""" + def _to_torch(d): + X, Y = [], [] + for xb, yb in d: + X.append(xb) + Y.append(yb) + return torch.Tensor(np.stack(X)), torch.LongTensor(np.stack(Y)) + + to_tensor = torchvision.transforms.ToTensor() + to_flat = torchvision.transforms.Lambda(lambda X: X.reshape(-1).squeeze()) + to_norm = torchvision.transforms.Normalize((0.5, ), (0.5, )) + to_binary = torchvision.transforms.Lambda(lambda y: 0 if y in y0 else 1) + + transforms = [to_tensor] + if normalize: transforms.append(to_norm) + if flatten: transforms.append(to_flat) + tf = torchvision.transforms.Compose(transforms) + ttf = to_binary if binarize else None + + X_tr = torchvision.datasets.MNIST(fpath, download=True, transform=tf, target_transform=ttf) + X_te = torchvision.datasets.MNIST(fpath, download=True, train=False, transform=tf, target_transform=ttf) + + return _to_torch(X_tr), _to_torch(X_te) + +def get_mnist_dl(fpath=DOWNLOAD_DIR, to_np=False, bs=128, pm=False, shuffle=False, + normalize=True, flatten=False, binarize=False, y0={0,1,2,3,4}): + (X_tr, Y_tr), (X_te, Y_te) = get_mnist(fpath, normalize=normalize, flatten=flatten, binarize=binarize, y0=y0) + tr_dl = DataLoader(TensorDataset(X_tr, Y_tr), batch_size=bs, shuffle=shuffle, pin_memory=pm) + te_dl = DataLoader(TensorDataset(X_te, Y_te), batch_size=bs, pin_memory=pm) + return tr_dl, te_dl + +def get_cifar(fpath=DOWNLOAD_DIR, use_cifar10=False, flatten_data=False, transform_type='none', + means=None, std=None, use_grayscale=False, binarize=False, normalize=True, y0={0,1,2,3,4}): + """get preprocessed cifar torch.Dataset class""" + + if transform_type == 'none': + normalize_cifar = lambda: torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) + tensorize = torchvision.transforms.ToTensor() + to_grayscale = torchvision.transforms.Grayscale() + flatten = torchvision.transforms.Lambda(lambda X: X.reshape(-1).squeeze()) + + transforms = [tensorize] + if use_grayscale: transforms = [to_grayscale] + transforms + if normalize: transforms.append(normalize_cifar()) + if flatten_data: transforms.append(flatten) + tr_transforms = te_transforms = torchvision.transforms.Compose(transforms) + + if transform_type == 'basic': + normalize_cifar = lambda: torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) + + tr_transforms= [ + torchvision.transforms.RandomCrop(32, padding=4), + torchvision.transforms.RandomHorizontalFlip(), + torchvision.transforms.ToTensor() + ] + + te_transforms = [ + torchvision.transforms.Resize(32), + torchvision.transforms.CenterCrop(32), + torchvision.transforms.ToTensor(), + ] + + if normalize: + tr_transforms.append(normalize_cifar()) + te_transforms.append(normalize_cifar()) + + tr_transforms = torchvision.transforms.Compose(tr_transforms) + te_transforms = torchvision.transforms.Compose(te_transforms) + + to_binary = torchvision.transforms.Lambda(lambda y: 0 if y in y0 else 1) + target_transforms = to_binary if binarize else None + dset = 'cifar10' if use_cifar10 else 'cifar100' + func = 
torchvision.datasets.CIFAR10 if use_cifar10 else torchvision.datasets.CIFAR100 + + X_tr = func(fpath, download=True, transform=tr_transforms, target_transform=target_transforms) + X_te = func(fpath, download=True, train=False, transform=te_transforms, target_transform=target_transforms) + + return X_tr, X_te + +def get_cifar_dl(fpath=DOWNLOAD_DIR, use_cifar10=False, bs=128, shuffle=True, transform_type='none', + means=None, std=None, normalize=True, flatten_data=False, use_grayscale=False, nw=4, pm=False, binarize=False, y0={0,1,2,3,4}): + """data in dataloaders have has shape (B, C, W, H)""" + d_tr, d_te = get_cifar(fpath, use_cifar10=use_cifar10, use_grayscale=use_grayscale, transform_type=transform_type, normalize=normalize, means=means, std=std, flatten_data=flatten_data, binarize=binarize, y0=y0) + tr_dl = DataLoader(d_tr, batch_size=bs, shuffle=shuffle, num_workers=nw, pin_memory=pm) + te_dl = DataLoader(d_te, batch_size=bs, num_workers=nw, pin_memory=pm) + return tr_dl, te_dl + +def get_cifar_np(fpath=DOWNLOAD_DIR, use_cifar10=False, flatten_data=False, transform_type='none', normalize=True, binarize=False, y0={0,1,2,3,4}, use_grayscale=False): + """get numpy matrices of preprocessed cifar data""" + + def _to_np(d): + X, Y = [], [] + for xb, yb in d: + X.append(xb) + Y.append(yb) + return map(np.stack, [X,Y]) + + d_tr, d_te = get_cifar(fpath, use_cifar10=use_cifar10, use_grayscale=use_grayscale, transform_type=transform_type, normalize=normalize, flatten_data=flatten_data, binarize=binarize, y0=y0) + return _to_np(d_tr), _to_np(d_te) + +if __name__ == '__main__': + pass \ No newline at end of file diff --git a/utils/scripts/ensemble.py b/utils/scripts/ensemble.py new file mode 100644 index 0000000..fb5a04b --- /dev/null +++ b/utils/scripts/ensemble.py @@ -0,0 +1,130 @@ +import os, copy, pickle, time +import random, itertools +from collections import defaultdict, Counter, OrderedDict +import numpy as np +import torch +import pandas as pd +import torchvision +from torch.utils.data import TensorDataset, DataLoader +from torch import optim, nn +import torch.nn.functional as F +import dill +import gpu_utils as gu +import data_utils as du +import synth_models as sm +import utils + +class Ensemble(nn.Module): + + def _get_dummy_classifier(self): + def dummy(x): + return x + return dummy + + def __init__(self, models, num_classes, use_softmax=False): + super(Ensemble, self).__init__() + self.num_classes = num_classes + self.use_softmax = use_softmax + + # register models as pytorch modules + self.models = [] + for idx, m in enumerate(models,1): + setattr(self, 'm{}'.format(idx), m.eval()) + self.models.append(getattr(self, 'm{}'.format(idx))) + + self.classifier = self._get_dummy_classifier() + + def _forward(self, x): + return x + + def forward(self, x): + outs = self._forward(x) + return self.classifier(outs) + + def get_output_loader(self, dl, device=gu.get_device(None), bs=None): + """return dataloader of model output (logit or softmax prob)""" + X, Y = [], [] + with torch.no_grad(): + for xb, yb in dl: + xb = xb.to(device) + out = self._forward(xb).cpu() + X.append(out) + Y.append(yb) + X, Y = torch.cat(X), torch.cat(Y) + return DataLoader(TensorDataset(X, Y), batch_size=bs or dl.batch_size) + + def fit_classifier(self, tr_dl, te_dl, lr=0.05, adam=False, wd=5e-5, device=None, **fit_kw): + device = gu.get_device(None) if device is None else device + self = self.to(device) + + c = dict(gap=1000, epsilon=1e-2, wd=5e-5, is_loss_epsilon=True) + c.update(**fit_kw) + + tro_dl = 
self.get_output_loader(tr_dl, device) + teo_dl = self.get_output_loader(te_dl, device) + + if adam: opt = optim.Adam(self.classifier.parameters()) + else: opt = optim.SGD(self.classifier.parameters(), lr=lr, weight_decay=wd) + stats = utils.fit_model(self.classifier, F.cross_entropy, opt, tro_dl, teo_dl, device=device, **c) + + self.classifier = stats['best_model'][-1].to(device) + self = self.cpu() + return stats + +class EnsembleLinear(Ensemble): + + def _get_classifier(self): + # linear with equal weights and zero bias + nl = self.num_classes*len(self.models) + linear = nn.Linear(nl, self.num_classes, bias=self.use_bias) + nn.init.ones_(linear.weight.data) + linear.weight.data /= float(nl) + if self.use_bias: linear.bias.data.zero_() + return linear + + def __init__(self, models, num_classes=2, use_softmax=False, use_bias=True): + super(EnsembleLinear, self).__init__(models, num_classes, use_softmax) + self.use_bias = use_bias + self.classifier = self._get_classifier() + + def _forward(self, x): + outs = [m(x) for m in self.models] + if self.use_softmax: outs = [F.softmax(o, dim=1) for o in outs] + outs = torch.stack(outs, dim=2) + outs = outs.reshape(outs.shape[0], -1) + return outs + +class EnsembleMLP(Ensemble): + + def _get_classifier(self): + nl = self.num_classes*len(self.models) + fcn = sm.get_fcn(nl, self.hdim or nl, self.num_classes, hl=self.hl) + return fcn + + def __init__(self, models, num_classes=2, use_softmax=False, hdim=None, hl=1): + super(EnsembleMLP, self).__init__(models, num_classes, use_softmax) + self.hdim = hdim + self.hl = hl + self.classifier = self._get_classifier() + + def _forward(self, x): + outs = [m(x) for m in self.models] + if self.use_softmax: outs = [F.softmax(o, dim=1) for o in outs] + outs = torch.stack(outs, dim=2) + outs = outs.reshape(outs.shape[0], -1) + return outs + +class EnsembleAverage(Ensemble): + + def __init__(self, models, num_classes=2, use_softmax=False): + super(EnsembleAverage, self).__init__(models, num_classes, use_softmax) + self.classifier = self._get_dummy_classifier() + + def _forward(self, x): + outs = [m(x) for m in self.models] + if self.use_softmax: outs = [F.softmax(o, dim=1) for o in outs] + outs = torch.stack(outs) + return outs.mean(dim=0) + + def fit_classifier(self, *args, **kwargs): + return None diff --git a/utils/scripts/gendata.py b/utils/scripts/gendata.py new file mode 100644 index 0000000..2749eb8 --- /dev/null +++ b/utils/scripts/gendata.py @@ -0,0 +1,316 @@ +import numpy as np +import scipy.stats as scs +import random +from collections import Counter +import torch +from torch.utils.data import TensorDataset, DataLoader +import utils.scripts.utils as utils +import utils.scripts.gpu_utils as gu + +def _prep_data(X, Y, N_tr, N_te, bs, nw, pm, w, orth_matrix=None): + X_te, Y_te = torch.Tensor(X[:N_te,:]), torch.Tensor(Y[:N_te]) + X_tr, Y_tr = torch.Tensor(X[N_te:,:]), torch.Tensor(Y[N_te:]) + Y_te, Y_tr = map(lambda Z: Z.long(), [Y_te, Y_tr]) + + tr_dl = DataLoader(TensorDataset(X_tr, Y_tr), batch_size=bs, num_workers=nw, pin_memory=pm, shuffle=True) + te_dl = DataLoader(TensorDataset(X_te, Y_te), batch_size=bs, num_workers=nw, pin_memory=pm, shuffle=False) + + return { + 'X': torch.tensor(X).float(), + 'Y': torch.tensor(Y).long(), + 'w': w, + 'tr_dl': tr_dl, + 'te_dl': te_dl, + 'N': (N_tr, N_te), + 'W': orth_matrix + } + +def _get_random_data(N, dim, scale): + X = np.random.uniform(size=(N, dim)) + X *= scale + Y = np.random.choice([0,1], size=N) + return X, Y + +def generate_linsep_data_v2(N_tr, dim, 
eff_margin, width=10., bs=256, scale_noise=True, pm=True, nw=0, no_width=False, N_te=5000): # no unif_max. + assert eff_margin < 1, "equal range constraint" + margin = eff_margin if no_width else eff_margin*width + + N = N_tr + N_te + w = np.zeros(shape=dim) + w[0] = 1 + + X, Y = _get_random_data(N, dim, width if scale_noise else 1.) + + U = np.random.uniform(size=N) + if no_width: X[:,0] = (2*Y-1)*margin + else: X[:, 0] = (2*Y-1)*(margin + (width-margin)*U) + + P = np.random.permutation(X.shape[0]) + X, Y = X[P,:], Y[P] + + return _prep_data(X, Y, N_tr, N_te, bs, nw, pm, w) + +def sample_from_unif_union_of_unifs(unifs, size): + x = [] + choices = Counter(np.random.choice(list(range(len(unifs))), size=size)) + for choice, sz in choices.items(): + s = np.random.uniform(low=unifs[choice][0], high=unifs[choice][1], size=sz) + x.append(s) + x = np.concatenate(x) + return x + +def generate_ub_linslab_data_diffmargin_v2(N_tr, dim, eff_lin_margins, eff_slab_margins, + slabs_per_coord, slab_p_vals, corrupt_lin=0., corrupt_slab=0., + corrupt_slab7=0., scale_noise=True, width=10., lin_coord=0, lin_shift=0., + slab_shift=0., indep_slabs=True, bs=256, pm=True, nw=0, N_te=10000, + random_transform=False, corrupt_lin_margin=False, corrupt_5slab_margin=False): + get_unif = lambda a: np.random.uniform(size=a) + get_bool = lambda a: np.random.choice([0,1], size=a) + get_sign = lambda a: 2*get_bool(a)-1. + + def get_slab_width(NS, B, SM): + if NS==3: return (2.*B-4.*SM)/3. + if NS==5: return (2.*B-8.*SM)/5. + if NS==7: return (2.*B-12.*SM)/7. + return None + + num_lin, num_slabs = map(len, [eff_lin_margins, eff_slab_margins]) + assert 0 <= corrupt_lin <= 1, "input is probability" + assert num_lin + num_slabs <= dim, "dim constraint, num_lin: {}, num_slabs: {}, dim: {}".format(num_lin, num_slabs, dim) + for elm in eff_lin_margins: assert 0 < elm < 1, "equal range constraint (0 < eff_lin_margin={} < 1)".format(elm) + for esm in eff_slab_margins: assert 0 < esm < 1, "equal range constraint (0 < eff_slab_margin={} < 0.25)".format(esm) + + lin_margins = list(map(lambda x: x*width, eff_lin_margins)) + slab_margins = list(map(lambda x: x*width, eff_slab_margins)) + + # hyperplane + N = N_tr + N_te + half_N = N//2 + w = np.zeros(shape=dim); w[0] = 1 + + X, Y = _get_random_data(N, dim, width if scale_noise else 1.) 
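+    # X starts as uniform noise over [0, width]^dim (when scale_noise is set); the
+    # blocks below overwrite the first num_lin coordinates with a linearly separable
+    # feature of margin lin_margin, and the next num_slabs coordinates with 3/5/7-slab
+    # features of margin slab_margin, corrupting the requested fraction of points.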
+ nrange = list(range(N)) + # linear + total_corrupt = int(round(N*corrupt_lin)) + no_linear = num_lin == 0 + if not no_linear: + for coord, lin_margin in enumerate(lin_margins): + if indep_slabs: + P = np.random.permutation(N) + X, Y = X[P, :], Y[P] + X[:, coord] = (2*Y-1)*(lin_margin+(width-lin_margin)*get_unif(N)) + lin_shift*width + + # corrupt linear coordinate + if total_corrupt > 0: + corrupt_sample = np.random.choice(nrange, size=total_corrupt, replace=False) + if corrupt_lin_margin: + X[corrupt_sample, 0] = np.random.uniform(low=-lin_margin, high=lin_margin, size=total_corrupt) + else: + X[corrupt_sample, 0] *= -1 + + # slabs + i = (num_lin)*int(not no_linear) + for idx, coord in enumerate(range(i, i+num_slabs)): + slab_per = slabs_per_coord[idx] + assert slab_per in [3, 5, 7], "Invalid slabs_per_coord" + + slab_pval = slab_p_vals[idx] + slab_margin = slab_margins[idx] + slab_width = get_slab_width(slab_per, width, slab_margin) + + if indep_slabs: + P = np.random.permutation(N) + X, Y = X[P, :], Y[P] + + if slab_per == 3: + # positive slabs + idx_p = (Y==1).nonzero()[0] + offset = 0.5*slab_width + 2*slab_margin + X[idx_p, coord] = get_sign(len(idx_p))*(offset+slab_width*get_unif(len(idx_p))) + + # negative center + idx_n = (Y==0).nonzero()[0] + X[idx_n, coord] = 0.5*get_sign(len(idx_n))*slab_width*get_unif(len(idx_n)) + + if slab_per == 5: + # positive slabs + idx_p = (Y==1).nonzero()[0] + offset = (width+6*slab_margin)/5. + X[idx_p, coord] = get_sign(len(idx_p))*(offset+slab_width*get_unif(len(idx_p))) + + # negative slabs partitioned using p val + idx_n = (Y==0).nonzero()[0] + in_ctr = np.random.choice([0,1], p=[1-slab_pval, slab_pval], size=len(idx_n)) + idx_nc, idx_ns = idx_n[(in_ctr==1)], idx_n[(in_ctr==0)] + + # negative center + X[idx_nc, coord] = 0.5*get_sign(len(idx_nc))*slab_width*get_unif(len(idx_nc)) + + # negative sides + offset = (8*slab_margin+3*width)/5. + X[idx_ns, coord] = get_sign(len(idx_ns))*(offset+slab_width*get_unif(len(idx_ns))) + + # corrupt slab 5 + total_corrupt = int(round(N*corrupt_slab)) + if total_corrupt > 0: + if corrupt_5slab_margin: + offset1 = (width+6*slab_margin)/5. + offset2 = (8*slab_margin+3*width)/5. 
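+                    # unifs enumerates the four margin gaps between adjacent slabs
+                    # (two on each side of the center slab); margin-corrupted points
+                    # are drawn uniformly from the union of these intervals.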
+ unifs = [ + (0.5*slab_width, offset1), + (offset1+slab_width, offset2), + (-offset1, -0.5*slab_width), + (-offset2, -(offset1+slab_width)) + ] + + idx = np.random.choice(range(N), size=total_corrupt, replace=False) + X[idx, coord] = sample_from_unif_union_of_unifs(unifs, total_corrupt) + else: + # get corrupt sample + idx = np.random.choice(range(N), size=total_corrupt, replace=False) + idx_p = idx[np.argwhere((Y[idx]==1))].reshape(-1) + idx_n = idx[np.argwhere((Y[idx]==0))].reshape(-1) + + # move negative points to random positive slabs + offset = (0.5*slab_width+2*slab_margin) + X[idx_n, coord] = torch.Tensor(get_sign(len(idx_n))*(offset+slab_width*get_unif(len(idx_n)))) + + # pick negative slab for each positve point + mv_to_ctr = np.random.choice([0, 1], size=len(idx_p)) + idx_p_ctr = idx_p[np.argwhere(mv_to_ctr==1)].reshape(-1) + idx_p_sid = idx_p[np.argwhere(mv_to_ctr==0)].reshape(-1) + + # move positive points to negative slabs + X[idx_p_ctr, coord] = torch.Tensor(0.5*get_sign(len(idx_p_ctr))*slab_width*get_unif(len(idx_p_ctr))) + + # move negative points to positve slabs + offset = 1.5*slab_width + 4*slab_margin + X[idx_p_sid, coord] = torch.Tensor(get_sign(len(idx_p_sid))*(offset+slab_width*get_unif(len(idx_p_sid)))) + + if slab_per == 7: + # positive slabs + idx_p = (Y==1).nonzero()[0] + in_s0 = np.random.choice([0,1], p=[1-slab_pval, slab_pval], size=len(idx_p)) + idx_p0, idx_p1 = idx_p[(in_s0==1)], idx_p[(in_s0==0)] + + # positive slab 0 (inner) + offset = 0.5*slab_width+2*slab_margin + X[idx_p0, coord] = get_sign(len(idx_p0))*(offset+slab_width*get_unif(len(idx_p0))) + + # positive slab 1 (outer) + offset = 2.5*slab_width+6*slab_margin + X[idx_p1, coord] = get_sign(len(idx_p1))*(offset+slab_width*get_unif(len(idx_p1))) + + # negative slabs + idx_n = (Y==0).nonzero()[0] + in_s0 = get_bool(len(idx_n)) + idx_n0, idx_n1 = idx_n[(in_s0==1)], idx_n[(in_s0==0)] + + # negative slab 0 (center) + X[idx_n0, coord] = 0.5*get_sign(len(idx_n0))*slab_width*get_unif(len(idx_n0)) + + # negative slab 1 (outer) + offset = 1.5*slab_width+4*slab_margin + X[idx_n1, coord] = get_sign(len(idx_n1))*(offset+slab_width*get_unif(len(idx_n1))) + + # corrupt slab7 + total_corrupt = int(round(N*corrupt_slab7)) + if total_corrupt > 0: + # corrupt data + idx = np.random.choice(range(len(X)), size=total_corrupt, replace=False) + idx_p = idx[np.argwhere((Y[idx]==1))].reshape(-1) + idx_n = idx[np.argwhere((Y[idx]==0))].reshape(-1) + + # pick positive slab for each negative slab + mv_to_inner = get_bool(len(idx_n)) + idx_n_inner = idx_n[np.argwhere(mv_to_inner==1)].reshape(-1) + idx_n_outer = idx_n[np.argwhere(mv_to_inner==0)].reshape(-1) + + # move to idx_n_inner and outer + offset = 0.5*slab_width+2*slab_margin + X[idx_n_inner, coord] = torch.Tensor(get_sign(len(idx_n_inner))*(offset+slab_width*get_unif(len(idx_n_inner)))) + offset = 2.5*slab_width+6*slab_margin + X[idx_n_outer, coord] = torch.Tensor(get_sign(len(idx_n_outer))*(offset+slab_width*get_unif(len(idx_n_outer)))) + + # pick negative slab for each positive point + mv_to_ctr = get_bool(len(idx_p)) + idx_p_ctr = idx_p[np.argwhere(mv_to_ctr==1)].reshape(-1) + idx_p_sid = idx_p[np.argwhere(mv_to_ctr==0)].reshape(-1) + + # move to idx_n_inner and outer + X[idx_p_ctr, coord] = torch.Tensor(0.5*get_sign(len(idx_p_ctr))*(slab_width*get_unif(len(idx_p_ctr)))) + offset = 1.5*slab_width+4*slab_margin + X[idx_p_sid, coord] = torch.Tensor(get_sign(len(idx_p_sid))*(offset+slab_width*get_unif(len(idx_p_sid)))) + + # shift + X[:, coord] += slab_shift*width + + # 
reshuffle + P = np.random.permutation(N) + X, Y = X[P,:], Y[P] + + # lin coord position + if not random_transform and lin_coord != 0: + X[:, [0, lin_coord]] = X[:, [lin_coord, 0]] + + # transform + W = np.eye(dim) + if random_transform: W = utils.get_orthonormal_matrix(dim) + X = X.dot(W) + + return _prep_data(X, Y, N_tr, N_te, bs, nw, pm, w, orth_matrix=W) + + +def generate_ub_linslab_data_v2(N_tr, dim, eff_lin_margin, eff_slab_margin=None, lin_coord=0, + corrupt_lin=0., corrupt_slab=0., corrupt_slab3=0., corrupt_slab7=0., + scale_noise=True, num_lin=1, lin_shift=0., slab_shift=0., random_transform=False, + num_slabs=1, slabs_per_coord=5, width=10., indep_slabs=True, no_linear=False, + bs=256, pm=True, nw=0, N_te=10000, corrupt_lin_margin=False, slab5_pval=3/4., + slab3_pval=1/2., slab7_pval=7/8., corrupt_5slab_margin=False): + slab_p_map = {5: slab5_pval, 7: slab7_pval, 3: slab3_pval} + slabs_per_coord = [slabs_per_coord]*num_slabs if type(slabs_per_coord) is int else slabs_per_coord[:] + for x in slabs_per_coord: assert x in slab_p_map + slab_p_vals = [slab_p_map[x] for x in slabs_per_coord] + lms = [eff_lin_margin]*num_lin + sms = eff_slab_margin if type(eff_slab_margin) is list else [eff_slab_margin]*num_slabs + return generate_ub_linslab_data_diffmargin_v2(N_tr, dim, lms, sms, slabs_per_coord, slab_p_vals, lin_coord=lin_coord, corrupt_slab=corrupt_slab, + corrupt_slab7=corrupt_slab7, corrupt_lin=corrupt_lin, scale_noise=scale_noise, width=width, + lin_shift=lin_shift, slab_shift=slab_shift, random_transform=random_transform, indep_slabs=indep_slabs, + pm=pm, bs=bs, corrupt_lin_margin=corrupt_lin_margin, nw=nw, N_te=N_te, corrupt_5slab_margin=corrupt_5slab_margin) + + +def get_lms_data(**kw): + + c = config = { + 'num_train': 100_000, + 'dim': 20, + 'lin_margin': 0.1, + 'slab_margin': 0.1, + 'same_margin': False, + 'random_transform': False, + 'width': 1, # data width + 'bs': 256, + 'corrupt_lin': 0.0, + 'corrupt_lin_margin': False, + 'corrupt_slab': 0.0, + 'num_test': 2_000, + 'hdim': 200, # model width + 'hl': 2, # model depth + 'device': gu.get_device(0), + 'input_dropout': 0, + 'num_lin': 1, + 'num_slabs': 19, + 'num_slabs7': 0, + 'num_slabs3': 0, + } + + c.update(kw) + + smargin = c['lin_margin'] if c['same_margin'] else c['slab_margin'] + data_func = generate_ub_linslab_data_v2 + spc = [3]*c['num_slabs3']+[5]*c['num_slabs'] + [7]*c['num_slabs7'] + data = data_func(c['num_train'], c['dim'], c['lin_margin'], slabs_per_coord=spc, eff_slab_margin=smargin, random_transform=c['random_transform'], N_te=c['num_test'], + corrupt_lin_margin=c['corrupt_lin_margin'], num_lin=c['num_lin'], num_slabs=c['num_slabs3']+c['num_slabs']+c['num_slabs7'], width=c['width'], bs=c['bs'], + corrupt_lin=c['corrupt_lin'], corrupt_slab=c['corrupt_slab']) + return data, c + diff --git a/utils/scripts/gpu_utils.py b/utils/scripts/gpu_utils.py new file mode 100644 index 0000000..13cdd6f --- /dev/null +++ b/utils/scripts/gpu_utils.py @@ -0,0 +1,100 @@ +try: import pycuda.driver as cuda +except: print ("pycuda not available") + +import torch +import sys, os, glob, subprocess + +def get_gpu_info(print_info=True, get_specs=False): + cuda.init() + if get_specs: gpu_specs = cuda.Device(0).get_attributes() # assume same for all (dnnx) + else: gpu_specs = None + + gpu_info = { + 'available': torch.cuda.is_available(), + 'num_devices': torch.cuda.device_count(), + 'devices': set([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]), + 'current device id': torch.cuda.current_device(), + 'allocated 
memory': torch.cuda.memory_allocated(), + 'cached memory': torch.cuda.memory_cached() + } + + if print_info: + for k,v in gpu_info.items(): print ("{}: {}".format(k, v)) + + return gpu_info, gpu_specs + + +def get_device(device_id=None): # None -> cpu + device = 'cuda:{}'.format(device_id) if device_id is not None else 'cpu' + device = torch.device(device if torch.cuda.is_available() and device_id is not None else 'cpu') + return device + +def get_gpu_name(): + try: + out_str = subprocess.run(["nvidia-smi", "--query-gpu=gpu_name", "--format=csv"], stdout=subprocess.PIPE).stdout + out_list = out_str.decode("utf-8").split('\n') + out_list = out_list[1:-1] + return out_list + except Exception as e: + print(e) + +def get_cuda_version(): + """Get CUDA version""" + if sys.platform == 'win32': + raise NotImplementedError("Implement this!") + # This breaks on linux: + #cuda=!ls "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA" + #path = "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\" + str(cuda[0]) +"\\version.txt" + elif sys.platform == 'linux' or sys.platform == 'darwin': + path = '/usr/local/cuda/version.txt' + else: + raise ValueError("Not in Windows, Linux or Mac") + if os.path.isfile(path): + with open(path, 'r') as f: + data = f.read().replace('\n','') + return data + else: + return "No CUDA in this machine" + +def get_cudnn_version(): + """Get CUDNN version""" + if sys.platform == 'win32': + raise NotImplementedError("Implement this!") + # This breaks on linux: + #cuda=!ls "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA" + #candidates = ["C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\" + str(cuda[0]) +"\\include\\cudnn.h"] + elif sys.platform == 'linux': + candidates = ['/usr/include/x86_64-linux-gnu/cudnn_v[0-99].h', + '/usr/local/cuda/include/cudnn.h', + '/usr/include/cudnn.h'] + elif sys.platform == 'darwin': + candidates = ['/usr/local/cuda/include/cudnn.h', + '/usr/include/cudnn.h'] + else: + raise ValueError("Not in Windows, Linux or Mac") + for c in candidates: + file = glob.glob(c) + if file: break + if file: + with open(file[0], 'r') as f: + version = '' + for line in f: + if "#define CUDNN_MAJOR" in line: + version = line.split()[-1] + if "#define CUDNN_MINOR" in line: + version += '.' + line.split()[-1] + if "#define CUDNN_PATCHLEVEL" in line: + version += '.' 
+ line.split()[-1] + if version: + return version + else: + return "Cannot find CUDNN version" + else: + return "No CUDNN in this machine" + +if __name__=='__main__': + print ('gpu name', get_gpu_name()) + print ('cuda', get_cuda_version()) + print ('cudnn', get_cudnn_version()) + print ('device0', get_device(0)) + print ('available', torch.cuda.is_available()) \ No newline at end of file diff --git a/utils/scripts/lms_utils.py b/utils/scripts/lms_utils.py new file mode 100644 index 0000000..cb4c6d5 --- /dev/null +++ b/utils/scripts/lms_utils.py @@ -0,0 +1,307 @@ +import seaborn as sns +import utils.scripts.gpu_utils as gu +import utils.scripts.data_utils as du +import utils +import random +import os, copy, pickle, time +import itertools +from collections import defaultdict, Counter, OrderedDict +import matplotlib.pyplot as plt +import numpy as np +import torch +import pandas as pd +#import foolbox +from torch.utils.data import TensorDataset, DataLoader +from torch import optim, nn +import torch.nn.functional as F +from sklearn.metrics import roc_auc_score + +def parse_data(exps=None, root='/', **funcs_kw): + """ + main function (parse data files and run added functions on it) + """ + exps = exps if exps is not None else os.listdir(root) + total = len(exps) + print ("total: {}".format(total)) + parsed = defaultdict(dict) + if total == 0: return parsed + for idx, exp in enumerate(exps): + if (idx+1) % 1 == 0: print (idx+1, end=' ', flush=True) + # load data + fpath = os.path.join(root, exp) + try: + data = torch.load(fpath, map_location=lambda storage, loc: storage) + except: + print ("File {} corrupted, skip.".format(fpath)) + continue + config = data['config'] + + # make exp config key + config['run'] = int(exp.rsplit('.', 1)[0][-1]) + config['fname'] = exp + ckeys = ['exp_name', 'dim', 'num_train', 'lin_margin', 'slab_margin', 'num_slabs', + 'num_slabs', 'width', 'hdim', 'hl', 'linear', 'use_bn', 'run', 'fname', + 'weight_decay', 'dropout'] + ckeys = [c for c in ckeys if c in config] + cvals = [config[k] for k in ckeys] + ckv = tuple(zip(ckeys, cvals)) + + # save config + parsed[ckv]['config'] = config + + # save functions + for func_name, func in funcs_kw.items(): + parsed[ckv][func_name] = func(data) + + return parsed + +def parse_exp_stats(data): + """training summary statistics""" + stats = data['stats'] + s = {} + + # loss + accuracy + for t1, t2 in itertools.product(['acc', 'loss'], ['tr', 'te']): + s['{}_{}'.format(t1,t2)] = stats['{}_{}'.format(t1, t2)][-1] + s['orig_stats'] = stats + s['acc_gap'] = s['acc_tr']-s['acc_te'] + s['loss_gap'] = s['loss_te']-s['loss_tr'] + s['fin_acc_te'] = stats['acc_te'][-1] + s['fin_acc_tr'] = stats['acc_tr'][-1] + + # updates + s['update_gap'] = stats['update_gap'] + s['num_updates'] = stats['num_updates'] + + # effective number of updates + for acc_threshold in [0.96, 0.97, 0.98, 0.99, 1]: + eff = np.argmin(np.abs(np.array(stats['acc_tr'])-acc_threshold))*s['update_gap'] + s['effective_num_updates{}'.format(int(acc_threshold*100))] = eff + + return s + +def parse_exp_model(data): + """model parameter stats""" + depth = data['config']['hl'] + linear = data['config']['linear'] + mtype = data['config'].get('mtype', 'fcn') + if mtype == 'fcn' and depth == 1 and not linear: d = parse_exp_depth1_model(data) + if mtype == 'fcn' and depth == 1 and linear: d = parse_exp_linear_model(data) + return {} + +def parse_exp_depth1_model(data): + """cosine + w2""" + device = gu.get_device() + model = data['model'].to(device) + p = W1, b1, w2, b2 = list(map(lambda 
x: x.detach().numpy(), model.parameters())) + s = {} + s['params'] = p + s['cosine'] = W1[:, 0]/np.linalg.norm(W1, axis=1) + s['l2'] = np.linalg.norm(W1, axis=1) + s['w2'] = w2 + s['corr0'] = np.corrcoef(s['cosine'], w2[0, :])[0,1] + s['corr1'] = np.corrcoef(s['cosine'], w2[1, :])[0,1] + s['max_weight_cosine'] = s['cosine'][np.argmax(s['w2'][1,:])] + return s + +def parse_exp_linear_model(data): + """cosine""" + device = gu.get_device() + model = data['model'].to(device) + p = W,b = list(map(lambda x: x.detach().numpy(), model.parameters())) + s = {} + s['cosine0'], s['cosine1'] = W[:, 0]/np.linalg.norm(W, axis=1) + return s + +def parse_exp_data(data, load_X=False): + s = {} + model = data['model'].to(gu.get_device()) + data = data['data'] + X, Y = data['X'], data['Y'] + + if type(X) != np.ndarray: + X = data['X'].detach().cpu() + + if type(X) != np.ndarray: + Y = data['Y'].detach().cpu() + + s['Y'] = Y + if load_X: s['X'] = X + s['Y_'] = get_yhat(model, X) + s['model'] = model + return s + +def get_yhat(model, data): + if type(data)==np.ndarray: data = torch.Tensor(data) + return torch.argmax(model(data), 1) + +def get_acc(y,yhat): + n = float(len(y)) + return (y==yhat).sum().item()/n + +def parse_and_get_df(root, prefix, files=None, device_id=None, only_load=False, only_linear=False, sample_pct=0.5, load_X=False, use_model_pred=False): + exps = files if files is not None else [f for f in os.listdir(root) if f.startswith(prefix)] + + funcs = { + 'config': lambda d: d['config'], + 'stats': parse_exp_stats, + 'model': parse_exp_model, + 'data': lambda x: parse_exp_data(x, load_X=load_X), + 'random_dep': lambda d: get_feature_deps(d['data']['te_dl'], d['model'], only_linear=only_linear, W=d['data'].get('W', None), dep_type='random', use_model_pred=use_model_pred, print_info=False, sample_pct=sample_pct, device_id=device_id), + 'swap_dep': lambda d: get_feature_deps(d['data']['te_dl'], d['model'], only_linear=only_linear, W=d['data'].get('W', None), dep_type='swap', use_model_pred=use_model_pred, print_info=False, sample_pct=sample_pct, device_id=device_id), + } + + P = parse_data(root=root, exps=exps, **funcs) + if only_load: return P + + D = [] + for idx, (k,v) in enumerate(P.items(),1): + d = OrderedDict() + for a,b in k: d[a] = b + for vk in ['model', 'data', 'stats', 'config']: + for a,b in v[vk].items(): d[a] = b + for vk in ['random_dep', 'swap_dep']: + for coord, dep in v[vk].items(): + d[f'{vk[0]}dep_{coord}'] = dep + D.append(d) + + df = pd.DataFrame(D) + if len(df): df['nd'] = df['num_train']/df['dim'] + return df + +def viz(d, c1, c2, k=80_000, info=True, plot_dm=True, plot_data=True, use_yhat=False, unif_k=False, width=10, title=None, is_binary=False, dep_type='swap', ax=None): + if 'W' not in d['data']: W = np.eye(d['config']['dim']) + else: W = d['data']['W'] + if W is None: W = np.eye(d['config']['dim']) + + z = parse_exp_data(d) + X = d['data']['X'] + + # visualize un-transformed data... 
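+    # W is the orthonormal transform applied at data-generation time (identity unless
+    # random_transform was used), so X.dot(W.T) maps points back to the canonical basis
+    # where the linear and slab coordinates are axis-aligned.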
+ X_ = np.array(X).dot(W.T) + Y, Y_ = z['Y'], z['Y_'] + model = d['model'].cpu() + D = X.shape[1] + kn = k if unif_k else len(X) + K = torch.Tensor(np.random.uniform(size=(k, D)))*width if unif_k else np.array(X_) + K[:, c1] = torch.Tensor(np.random.uniform(low=min(X_[:,c1]), high=max(X_[:,c1]), size=kn)) + K[:, c2] = torch.Tensor(np.random.uniform(low=min(X_[:,c2]), high=max(X_[:,c2]), size=kn)) + KO = model(torch.Tensor(np.array(K).dot(W))) + if is_binary: KY = (KO > 0).squeeze().numpy() + else: KY = torch.argmax(KO, 1).numpy() + + if info: + deps = get_feature_deps(d['data']['te_dl'], d['model'], W=d['data'].get('W', None), dep_type=dep_type) + for k,v in sorted(deps.items(), reverse=False, key=lambda t: t[-1]): print ('{}:{:.3f}'.format(k,v), end=' ') + print ("\n") + + if ax is None: fig, ax = plt.subplots(1,1,figsize=(6,4)) + + if plot_dm: ax.scatter(K[:, c1], K[:, c2], c=KY, cmap='binary', s=8, alpha=.2) + if plot_data: ax.scatter(X_[:, c1], X_[:, c2], c=Y_ if use_yhat else Y, cmap='coolwarm', s=8, alpha=.4) + + ax.set_xlabel('e_{}'.format(c1)) + ax.set_ylabel('e_{}'.format(c2)) + ax.set_title(title if title else '') + plt.tight_layout() + return ax + +def visualize_boundary(model, data, c1, c2, dim, ax=None, is_binary=False, use_yhat=False, width=1, unif_k=True, k=100_000, print_info=True, dep_type='random'): + agg = {'model': model, 'data': data, 'config': dict(dim=dim)} + return viz(agg, c1, c2, unif_k=unif_k, width=width, dep_type=dep_type, is_binary=is_binary, use_yhat=use_yhat, ax=ax, info=print_info) + +def get_randomized_loader(dl, W, coordinates): + """ + dl: dataloader + W: rotation matrix + coordinates: list of coordinates to randomize + output: randomized dataloader + """ + + def _randomize(X, coords): + p = torch.randperm(len(X)) + for c in coords: X[:, c] = X[p, c] + return X + + # rotate data + X, Y = map(copy.deepcopy, dl.dataset.tensors) + dim = X.shape[1] + if W is None: W = np.eye(dim) + + rt_X = torch.Tensor(X.numpy().dot(W.T)) + rand_rt_X = _randomize(rt_X, coordinates) + rand_X = torch.Tensor(rand_rt_X.numpy().dot(W)) + + return utils._to_dl(rand_X, Y, dl.batch_size) + + +def get_feature_deps(dl, model, W=None, dep_type='random', only_linear=False, coords=None, metric='accuracy', + use_model_pred=False, print_info=False, sample_pct=1.0, device_id=None): + """Compute feature dependencies using randomization or swapping""" + def _randomize(X, Y, coords): + p = torch.randperm(len(X)) + for c in coords: X[:, c] = X[p, c] + return X + + def _swap(X, Y, coords): + idx0, idx1 = map(lambda c: (Y.numpy()==c).nonzero()[0], [0, 1]) + idx0_new = np.random.choice(idx1, size=len(idx0), replace=True) + idx1_new = np.random.choice(idx0, size=len(idx1), replace=True) + for c in coords: X[idx0, c], X[idx1, c] = X[idx0_new, c], X[idx1_new, c] + return X + + def _get_dep_data(X, Y, coords): + return dict(random=_randomize, swap=_swap)[dep_type](X, Y, coords) + + + assert metric in {'accuracy', 'loss', 'auc'} + + # setup data + device = gu.get_device(device_id) + model = model.to(device) + X, Y = map(lambda Z: Z.to(device), dl.dataset.tensors) + Yh = get_yhat(model, X) + dim = X.shape[1] + if W is None: W = np.eye(dim) + W = torch.Tensor(W).to(device) + rt_X = torch.mm(X, torch.transpose(W,0,1)) + + # subsample data + n_samp = int(round(sample_pct*len(rt_X))) + perm = torch.randperm(len(rt_X))[:n_samp] + rt_X, Y, Yh = rt_X[perm, :], Y[perm], Yh[perm] + + # compute deps + deps = {} + + dims = list(range(dim)) + if coords is None and not only_linear: coords = dims + if coords is None 
and only_linear: coords = [0,1] + + for idx, coord in enumerate(coords): + if print_info: print ('{}/{}'.format(idx, len(coords)), end=' ') + rt_X_ = copy.deepcopy(rt_X).to(device) + rt_X_ = _get_dep_data(rt_X_, Y, coord if type(coord) in (list, tuple) else [coord]) + X_ = torch.mm(rt_X_, W) + Ys = get_yhat(model, X_) + + key = tuple(coord) if type(coord) in (list, tuple) else coord + + if metric == 'auc': + L = utils.get_logits_given_tensor(X_, model, device=device, bs=250) + S = L[:,1]-L[:,0] + auc = roc_auc_score(Y.cpu().numpy(), S.cpu().numpy()) + deps[key] = auc + elif metric == 'accuracy': + deps[key] = get_acc(Yh if use_model_pred else Y, Ys) + elif metric == 'loss': + L = utils.get_logits_given_tensor(X_, model, device=device, bs=250) + with torch.no_grad(): + loss_val = F.cross_entropy(L, Y).item() + deps[key] = loss_val + + return deps + +def get_subset_feature_deps(dl, model, coords_set, comb_size, W=None, dep_type='random', sample_pct=0.5, device_id=None, print_info=False): + coords = list(itertools.combinations(coords_set, comb_size)) + return get_feature_deps(dl, model, W=W, dep_type=dep_type, coords=coords, print_info=print_info, sample_pct=sample_pct, device_id=device_id) diff --git a/utils/scripts/mnistcifar_utils.py b/utils/scripts/mnistcifar_utils.py new file mode 100644 index 0000000..7d3bdbc --- /dev/null +++ b/utils/scripts/mnistcifar_utils.py @@ -0,0 +1,98 @@ +import random +import os, copy, pickle, time +import itertools +from collections import defaultdict, Counter, OrderedDict +import numpy as np +import torch +from torch.utils.data import TensorDataset, DataLoader +import utils +import gpu_utils as gu +import data_utils as du + +def get_binary_mnist(y1=0, y2=1, apply_padding=True, repeat_channels=True): + + def _make_cifar_compatible(X): + if apply_padding: X = np.stack([np.pad(X[i][0], 2)[None,:] for i in range(len(X))]) # pad + if repeat_channels: X = np.repeat(X, 3, axis=1) # add channels + return X + + binarize = lambda X,Y: du.get_binary_datasets(X, Y, y1=y1, y2=y2) + + tr_dl, te_dl = du.get_mnist_dl(normalize=False) + Xtr, Ytr = binarize(*utils.extract_numpy_from_loader(tr_dl)) + Xte, Yte = binarize(*utils.extract_numpy_from_loader(te_dl)) + Xtr, Xte = map(_make_cifar_compatible, [Xtr, Xte]) + return (Xtr, Ytr), (Xte, Yte) + +def get_binary_cifar(y1=3, y2=5, c={0,1,2,3,4}, use_cifar10=True): + binarize = lambda X,Y: du.get_binary_datasets(X, Y, y1=y1, y2=y2) + binary = False if y1 is not None and y2 is not None else True + if binary: print ("grouping cifar classes") + tr_dl, te_dl = du.get_cifar_dl(use_cifar10=use_cifar10, shuffle=False, normalize=False, binarize=binary, y0=c) + + Xtr, Ytr = binarize(*utils.extract_numpy_from_loader(tr_dl)) + Xte, Yte = binarize(*utils.extract_numpy_from_loader(te_dl)) + return (Xtr, Ytr), (Xte, Yte) + +def combine_datasets(Xm, Ym, Xc, Yc, randomize_order=False, randomize_first_block=False, randomize_second_block=False): + """combine two datasets""" + + def partition(X, Y, randomize=False): + """partition randomly or using labels""" + if randomize: + n = len(Y) + p = np.random.permutation(n) + ni, pi = p[:n//2], p[n//2:] + else: + ni, pi = (Y==0).nonzero()[0], (Y==1).nonzero()[0] + return X[pi], X[ni] + + def _combine(X1, X2): + """concatenate images from two sources""" + X = [] + for i in range(min(len(X1), len(X2))): + x1, x2 = X1[i], X2[i] + # randomize order + if randomize_order and random.random() < 0.5: + x1, x2 = x2, x1 + x = np.concatenate((x1,x2), axis=1) + X.append(x) + return np.stack(X) + + Xmp, Xmn = 
partition(Xm, Ym, randomize=randomize_first_block) + Xcp, Xcn = partition(Xc, Yc, randomize=randomize_second_block) + n = min(map(len, [Xmp, Xmn, Xcp, Xcn])) + Xmp, Xmn, Xcp, Xcn = map(lambda Z: Z[:n], [Xmp, Xmn, Xcp, Xcn]) + + Xp = _combine(Xmp, Xcp) + Yp = np.ones(len(Xp)) + + Xn = _combine(Xmn, Xcn) + Yn = np.zeros(len(Xn)) + + X = np.concatenate([Xp, Xn], axis=0) + Y = np.concatenate([Yp, Yn], axis=0) + P = np.random.permutation(len(X)) + X, Y = X[P], Y[P] + return X, Y + +def get_mnist_cifar(mnist_classes=(0,1), cifar_classes=None, c={0,1,2,3,4}, + randomize_mnist=False, randomize_cifar=False): + + y1, y2 = mnist_classes + (Xtrm, Ytrm), (Xtem, Ytem) = get_binary_mnist(y1=y1, y2=y2) + + y1, y2 = (None, None) if cifar_classes is None else cifar_classes + (Xtrc, Ytrc), (Xtec, Ytec) = get_binary_cifar(c=c, y1=y1, y2=y2) + + Xtr, Ytr = combine_datasets(Xtrm, Ytrm, Xtrc, Ytrc, randomize_first_block=randomize_mnist, randomize_second_block=randomize_cifar) + Xte, Yte = combine_datasets(Xtem, Ytem, Xtec, Ytec, randomize_first_block=randomize_mnist, randomize_second_block=randomize_cifar) + return (Xtr, Ytr), (Xte, Yte) + +def get_mnist_cifar_dl(mnist_classes=(0,1), cifar_classes=None, c={0,1,2,3,4}, bs=256, + randomize_mnist=False, randomize_cifar=False): + (Xtr, Ytr), (Xte, Yte) = get_mnist_cifar(mnist_classes=mnist_classes, cifar_classes=cifar_classes, + c=c, randomize_mnist=randomize_mnist, randomize_cifar=randomize_cifar) + tr_dl = utils._to_dl(Xtr, Ytr, bs=bs, shuffle=True) + te_dl = utils._to_dl(Xte, Yte, bs=100, shuffle=False) + return tr_dl, te_dl \ No newline at end of file diff --git a/utils/scripts/ptb_utils.py b/utils/scripts/ptb_utils.py new file mode 100644 index 0000000..d0ab661 --- /dev/null +++ b/utils/scripts/ptb_utils.py @@ -0,0 +1,236 @@ +import seaborn as sns +import utils +import random +import os, copy, pickle, time +import itertools +from collections import defaultdict, Counter, OrderedDict +import matplotlib.pyplot as plt +import numpy as np +import torch +import pandas as pd +from torch.utils.data import TensorDataset, DataLoader +from torch import optim, nn +import torch.nn.functional as F + +import utils.scripts.gpu_utils as gu +import utils.scripts.data_utils as du +import utils.scripts.synth_models as synth_models + +#import foolbox as fb +#from autoattack import AutoAttack + +# Misc +def get_yhat(model, data): return torch.argmax(model(data), 1) +def get_acc(y,yhat): return (y==yhat).sum().item()/float(len(y)) + +class PGD_Attack(object): + + def __init__(self, eps, lr, num_iter, loss_type, rand_eps=1e-3, + num_classes=2, bounds=(0.,1.), minimal=False, restarts=1, device=None): + self.eps = eps + self.lr = lr + self.num_iter = num_iter + self.B = bounds + self.restarts = restarts + self.rand_eps = rand_eps + self.device = device or gu.get_device(None) + self.loss_type = loss_type + self.num_classes = num_classes + self.classes = list(range(self.num_classes)) + self.delta = None + self.minimal = minimal # early stop + no eps + self.project = not self.minimal + self.loss = -np.inf + + def evaluate_attack(self, dl, model): + model = model.to(self.device) + Xa, Ya, Yh, P = [], [], [], [] + + for xb, yb in dl: + xb, yb = xb.to(self.device), yb.to(self.device) + delta = self.perturb(xb, yb, model) + xba = xb+delta + + with torch.no_grad(): + out = model(xba).detach() + yh = torch.argmax(out, dim=1) + xb, yb, yh, xba, delta = xb.cpu(), yb.cpu(), yh.cpu(), xba.cpu(), delta.cpu() + + Ya.append(yb) + Yh.append(yh) + Xa.append(xba) + P.append(delta) + + Xa, Ya, Yh, P = 
map(torch.cat, [Xa, Ya, Yh, P]) + ta_dl = utils._to_dl(Xa, Ya, dl.batch_size) + acc, loss = utils.compute_loss_and_accuracy_from_dl(ta_dl, model, + F.cross_entropy, + device=self.device) + return { + 'acc': acc.item(), + 'loss': loss.item(), + 'ta_dl': ta_dl, + 'Xa': Xa.numpy(), + 'Ya': Ya.numpy(), + 'Yh': Yh.numpy(), + 'P': P.numpy() + } + + def perturb(self, xb, yb, model, cpu=False): + model, xb, yb = model.to(self.device), xb.to(self.device), yb.to(self.device) + if self.eps == 0: return torch.zeros_like(xb) + + # compute perturbations and track best perturbations + self.loss = -np.inf + max_delta = self._perturb_once(xb, yb, model) + + with torch.no_grad(): + out = model(xb+max_delta) + max_loss = nn.CrossEntropyLoss(reduction='none')(out, yb) + + for _ in range(self.restarts-1): + delta = self._perturb_once(xb, yb, model) + + with torch.no_grad(): + out = model(xb+delta) + all_loss = nn.CrossEntropyLoss(reduction='none')(out, yb) + + loss_flag = all_loss >= max_loss + max_delta[loss_flag] = delta[loss_flag] + max_loss = torch.max(max_loss, all_loss) + + if cpu: max_delta = max_delta.cpu() + return max_delta + + def _perturb_once(self, xb, yb, model, track_scores=False, stop_const=1e-5): + self.delta = self._init_delta(xb, yb) + scores = [] + + # (minimal) mask perturbations if model already misclassifies + for t in range(self.num_iter): + loss, out = self._get_loss(xb, yb, model, get_scores=True) + + if self.minimal: + yh = torch.argmax(out, dim=1).detach() + not_flipped = yh == yb + not_flipped_ratio = not_flipped.sum().item()/float(len(yb)) + else: + not_flipped = None + not_flipped_ratio = 1.0 + + # stop if almost all examples in the batch misclassified + if not_flipped_ratio < stop_const: + break + + if track_scores: + scores.append(out.detach().cpu().numpy()) + + # compute loss, update + clamp delta + loss.backward() + self.loss = max(self.loss, loss.item()) + + self.delta = self._update_delta(xb, yb, update_mask=not_flipped) + self.delta = self._clamp_input(xb, yb) + + d = self.delta.detach() + + if track_scores: + scores = np.stack(scores).swapaxes(0, 1) + return d, scores + + return d + + def _init_delta(self, xb, yb): + delta = torch.empty_like(xb) + delta = delta.uniform_(-self.rand_eps, self.rand_eps) + delta = delta.to(self.device) + delta.requires_grad = True + return delta + + def _clamp_input(self, xb, yb): + # clamp delta s.t. 
X+delta in valid input range + self.delta.data = torch.max(self.B[0]-xb, + torch.min(self.B[1]-xb, + self.delta.data)) + return self.delta + + def _get_loss(self, xb, yb, model, get_scores=False): + out = model(xb+self.delta) + + if self.loss_type == 'untargeted': + L = -1*F.cross_entropy(out, yb) + + elif self.loss_type == 'targeted': + L = nn.CrossEntropyLoss()(out, yb) + + elif self.loss_type == 'random_targeted': + rand_yb = torch.randint(low=0, high=self.num_classes, size=(len(yb),), device=self.device) + #rand_yb[rand_yb==yb] = (yb[rand_yb==yb]+1) % self.num_classes + L = nn.CrossEntropyLoss()(out, rand_yb) + + elif self.loss_type == 'plusone_targeted': + next_yb = (yb+1) % self.num_classes + L = nn.CrossEntropyLoss()(out, next_yb) + + elif self.loss_type == 'binary_targeted': + yb_opp = 1-yb + L = nn.CrossEntropyLoss()(out, yb_opp) + + elif self.loss_type == 'binary_hybrid': + yb_opp = 1-yb + L = nn.CrossEntropyLoss()(out, yb_opp) - nn.CrossEntropyLoss()(out, yb) + + else: + assert False, "unknown loss type" + + if get_scores: return L, out + return L + +class L2_PGD_Attack(PGD_Attack): + + OVERFLOW_CONST = 1e-10 + + def get_norms(self, X): + nch = len(X.shape) + return X.view(X.shape[0], -1).norm(dim=1)[(...,) + (None,)*(nch-1)] + + def _update_delta(self, xb, yb, update_mask=None): + # normalize gradients + grad = self.delta.grad.detach() + norms = self.get_norms(grad) + grad = grad/(norms+self.OVERFLOW_CONST) # add const to avoid overflow + + # steepest descent + if self.minimal and update_mask is not None: + um = update_mask + self.delta.data[um] = self.delta.data[um] - self.lr*grad[um] + else: + self.delta.data = self.delta.data - self.lr*grad + + # l2 ball projection + if self.project: + delta_norms = self.get_norms(self.delta.data) + self.delta.data = self.eps*self.delta.data / (delta_norms.clamp(min=self.eps)) + + self.delta.grad.zero_() + return self.delta + + def _init_delta(self, xb, yb): + # random vector with L2 norm rand_eps + delta = torch.zeros_like(xb) + delta = delta.uniform_(-self.rand_eps, self.rand_eps) + delta_norm = self.get_norms(delta) + delta = self.rand_eps*delta/(delta_norm+self.OVERFLOW_CONST) + delta = delta.to(self.device) + delta.requires_grad = True + return delta + +class Linf_PGD_Attack(PGD_Attack): + + def _update_delta(self, xb, yb, **kw): + # steepest descent + linf projection (GD) + self.delta.data = self.delta.data - self.lr*(self.delta.grad.detach().sign()) + self.delta.data = self.delta.data.clamp(-self.eps, self.eps) + self.delta.grad.zero_() + return self.delta + diff --git a/utils/scripts/synth_models.py b/utils/scripts/synth_models.py new file mode 100644 index 0000000..8a08817 --- /dev/null +++ b/utils/scripts/synth_models.py @@ -0,0 +1,160 @@ +import sys, copy +import torch, torchvision +from torch import optim, nn +import torch.nn.functional as F +from torch.utils.data import TensorDataset, DataLoader +import utils.scripts.gendata as gendata +import utils.scripts.utils as utils +import numpy as np +import utils.scripts.gpu_utils as gu +import utils.scripts.ptb_utils as pu + +def kaiming_init(m): + if isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight.data) + nn.init.kaiming_uniform_(m.bias.data) + +class SequenceClassifier(nn.Module): + + def __init__(self, seq_model, idim, hdim, hl, input_size, num_classes=2, many_to_many=False, unsqueeze_input=True): + super(SequenceClassifier, self).__init__() + self.seq_model = seq_model + self.hdim = hdim + self.hl = hl + self.input_size = input_size + self.idim = idim + 
self.num_classes = num_classes + self.unsqueeze_input = unsqueeze_input + self.many_to_many = many_to_many + + self.seq_length = self.idim//self.input_size + self.seq = self.seq_model(input_size=input_size, hidden_size=hdim, num_layers=hl, batch_first=True) + self.lin_idim = hdim*self.seq_length if many_to_many else hdim + self.lin = nn.Linear(self.lin_idim, num_classes) + + def forward(self, x): + if self.unsqueeze_input: x = x.unsqueeze(2) + bsize, idim, _ = x.shape + seq_length = idim//self.input_size + x = x.view((bsize, seq_length, self.input_size)) + out, hidden = self.seq(x) + lin_in = out[:,-1,:] + if self.many_to_many: lin_in = out.contiguous().view((bsize, -1)) + lin_out = self.lin(lin_in) + return lin_out + +class GRUClassifier(SequenceClassifier): + + def __init__(self, idim, hdim, hl, input_size, num_classes=2, many_to_many=False, unsqueeze_input=True): + super(GRUClassifier, self).__init__(nn.GRU, idim, hdim, hl, input_size, many_to_many=many_to_many, num_classes=num_classes, unsqueeze_input=unsqueeze_input) + +class LSTMClassifier(SequenceClassifier): + + def __init__(self, idim, hdim, hl, input_size, num_classes=2, many_to_many=False, unsqueeze_input=True): + super(LSTMClassifier, self).__init__(nn.LSTM, idim, hdim, hl, input_size, many_to_many=many_to_many, num_classes=num_classes, unsqueeze_input=unsqueeze_input) + +class CNNClassifier(nn.Module): + + def __init__(self, out_channels, hl, kernel_size, idim, num_classes=2, padding=None, stride=1, maxpool_kernel_size=None, use_maxpool=False): + """ + Fixed architecture: + - default max pool kernel size half of convolution kernel size + - default padding = kernel size - 1 // 2 to maintain same dimension + - stride = 1 + - 1 FC layer + """ + if padding == None: assert kernel_size % 2 == 1, "use odd kernel size, equal padding constraint" + super(CNNClassifier, self).__init__() + self.out_channels = out_channels + self.num_conv = hl + self.kernel_size = kernel_size + self.padding = padding or (self.kernel_size-1)//2 + self.stride = 1 + self.num_classes = 2 + self.idim = idim + self.use_maxpool = use_maxpool + self.maxpool_kernel_size = maxpool_kernel_size or self.kernel_size//2 + + self.maxpool = nn.MaxPool1d(self.maxpool_kernel_size) + self.ih_conv = nn.Conv1d(1, self.out_channels, self.kernel_size, padding=self.padding, stride=self.stride) + + self.hh_convs = [] + for _ in range(self.num_conv-1): + self.hh_convs.append(nn.Conv1d(self.out_channels, self.out_channels, self.kernel_size, padding=self.padding, stride=self.stride)) + self.hh_convs.append(nn.ReLU()) + self.hh_convs = nn.Sequential(*self.hh_convs) + + fc_idim = int(self.idim/self.maxpool_kernel_size) if self.use_maxpool else self.idim + self.fc_layer = nn.Linear(self.out_channels*fc_idim, self.idim) + self.out_layer = nn.Linear(self.idim, self.num_classes) + self.relu = nn.ReLU() + + def forward(self, x): + bs = x.shape[0] + x_ = x.unsqueeze(1) + + x_ = self.relu(self.ih_conv(x_)) + x_ = self.hh_convs(x_) + + if self.use_maxpool: x_ = self.maxpool(x_) + x_ = self.relu(self.fc_layer(x_.view(bs, -1))) + + return self.out_layer(x_) + +class CNN2DClassifier(nn.Module): + + def __init__(self, num_filters, filter_size, num_layers, input_shape, input_channels=1, stride=2, padding=None, num_stride2_layers=2, fc_idim=None, fc_odim=None, num_classes=2, use_avgpool=True, avgpool_ksize=5): + super(CNN2DClassifier, self).__init__() + self.outch = num_filters + self.fsize = filter_size + self.input_channels = input_channels + self.hl = num_layers + self.padding = (self.fsize-1)//2 
if padding is None else padding + self.num_classes = num_classes + num_stride2_layers = num_stride2_layers + self.strides = iter([stride]*num_stride2_layers+[1]*(num_layers-num_stride2_layers)) + self.use_avgpool = use_avgpool + self.avgpool_ksize = avgpool_ksize + + self.convs = [nn.Conv2d(self.input_channels, self.outch, self.fsize, padding=self.padding, stride=next(self.strides)), nn.ReLU()] + if self.use_avgpool: self.convs.append(nn.AvgPool2d(self.avgpool_ksize)) + + for _ in range(self.hl-1): + self.convs.append(nn.Conv2d(self.outch, self.outch, self.fsize, stride=next(self.strides), padding=self.padding)) + self.convs.append(nn.ReLU()) + if self.use_avgpool: self.convs.append(nn.AvgPool2d(self.avgpool_ksize)) + + self.convs = nn.Sequential(*self.convs) # need to wrap for gpu + sl = min(self.hl, num_stride2_layers) + self.fc_idim = int(num_filters*input_shape[0]*input_shape[1]/float(4**sl)) if fc_idim is None else fc_idim + self.fc_odim = fc_odim if fc_odim is not None else self.fc_idim + self.fc = nn.Linear(self.fc_idim, self.fc_odim) + self.out = nn.Linear(self.fc_odim, self.num_classes) + + def forward(self, x): + x = self.convs(x) + x = x.reshape(x.shape[0], -1) + return self.out(F.relu(self.fc(x))) + +def get_linear(input_dim, num_classes): + return nn.Sequential(nn.Linear(input_dim, num_classes)) + +def get_fcn(idim, hdim, odim, hl=1, init=False, activation=nn.ReLU, use_activation=True, use_bn=False, input_dropout=0, dropout=0): + use_dropout = dropout > 0 + layers = [] + if input_dropout > 0: layers.append(nn.Dropout(input_dropout)) + layers.append(nn.Linear(idim, hdim)) + if use_activation: layers.append(activation()) + if use_dropout: layers.append(nn.Dropout(dropout)) + if use_bn: layers.append(nn.BatchNorm1d(hdim)) + for _ in range(hl-1): + l = [nn.Linear(hdim, hdim)] + if use_activation: l.append(activation()) + if use_dropout: l.append(nn.Dropout(dropout)) + if use_bn: l.append(nn.BatchNorm1d(hdim)) + layers.extend(l) + layers.append(nn.Linear(hdim, odim)) + model = nn.Sequential(*layers) + + if init: model.apply(kaiming_init) + return model \ No newline at end of file diff --git a/utils/scripts/utils.py b/utils/scripts/utils.py new file mode 100644 index 0000000..e7cb2d9 --- /dev/null +++ b/utils/scripts/utils.py @@ -0,0 +1,717 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +import pickle +import copy +from collections import defaultdict, Counter, OrderedDict +import time +from torch.utils.data import TensorDataset, DataLoader +import torchvision +from torch import optim, nn +import torch.nn.functional as F +from scipy.linalg import qr +import utils.scripts.lms_utils as au +import utils.scripts.ptb_utils as pu +import utils.scripts.gpu_utils as gu +from sklearn import metrics +import collections +from sklearn.metrics import roc_auc_score + +plt.style.use('seaborn-ticks') +import matplotlib.ticker as ticker + +def get_orthonormal_matrix(n): + H = np.random.randn(n, n) + s = np.linalg.svd(H)[1] + s = s[s>1e-7] + if len(s) != n: return get_orthonormal_matrix(n) + Q, R = qr(H) + return Q + +def get_dataloader(X, Y, bs, **kw): + return DataLoader(TensorDataset(X, Y), batch_size=bs, **kw) + +def split_dataloader(dl, frac=0.5): + bs = dl.batch_size + X, Y = dl.dataset.tensors + p = torch.randperm(len(X)) + X, Y = X[p, :], Y[p] + n = int(round(len(X)*frac)) + X0, Y0 = X[:n, :], Y[:n] + X1, Y1 = X[n:, :], Y[n:] + dl0 = DataLoader(TensorDataset(torch.Tensor(X0), torch.LongTensor(Y0)), batch_size=bs, shuffle=True) + dl1 = 
DataLoader(TensorDataset(torch.Tensor(X1), torch.LongTensor(Y1)), batch_size=bs, shuffle=True) + return dl0, dl1 + +def _to_dl(X, Y, bs, shuffle=True): + return DataLoader(TensorDataset(torch.Tensor(X), torch.LongTensor(Y)), batch_size=bs, shuffle=shuffle) + +def extract_tensors_from_loader(dl, repeat=1, transform_fn=None): + X, Y = [], [] + for _ in range(repeat): + for xb, yb in dl: + if transform_fn: + xb, yb = transform_fn(xb, yb) + X.append(xb) + Y.append(yb) + X = torch.FloatTensor(torch.cat(X)) + Y = torch.LongTensor(torch.cat(Y)) + return X, Y + +def extract_numpy_from_loader(dl, repeat=1, transform_fn=None): + X, Y = extract_tensors_from_loader(dl, repeat=repeat, transform_fn=transform_fn) + return X.numpy(), Y.numpy() + +def _to_tensor_dl(dl, repeat=1, bs=None): + X, Y = extract_numpy_from_loader(dl, repeat=repeat) + dl = _to_dl(X, Y, bs if bs else dl.batch_size) + return dl + +def flatten_loader(dl, bs=None): + X, Y = extract_numpy_from_loader(dl) + X = X.reshape(X.shape[0], -1) + return _to_dl(X, Y, bs=bs if bs else dl.batch_size) + +def merge_loaders(dla, dlb): + bs = dla.batch_size + Xa, Ya = extract_numpy_from_loader(dla) + Xb, Yb = extract_numpy_from_loader(dlb) + return _to_dl(np.concatenate([Xa, Xb]), np.concatenate([Ya, Yb]), bs) + +def transform_loader(dl, func, shuffle=True): + #assert type(dl.sampler) is torch.utils.data.sampler.SequentialSampler + X, Y = extract_numpy_from_loader(dl, transform_fn=func) + return _to_dl(X, Y, dl.batch_size, shuffle=shuffle) + +def visualize_tensors(P, size=8, normalize=True, scale_each=False, permute=True, ax=None, pad_value=0.): + if ax is None: _, ax = plt.subplots(1,1,figsize=(20,4)) + if permute: + s = np.random.choice(len(P), size=size, replace=False) + p = P[s] + else: + p = P[:size] + g = torchvision.utils.make_grid(torch.FloatTensor(p), nrow=size, normalize=normalize, scale_each=scale_each, pad_value=pad_value) + g = g.permute(1,2,0).numpy() + ax.imshow(g) + ax.set_xticks([]) + ax.set_yticks([]) + return ax + +def visualize_loader(dl, ax=None, size=8, normalize=True, scale_each=False, reshape=None): + if ax is None: _, ax = plt.subplots(1,1,figsize=(20,4)) + for xb, yb in dl: break + if reshape: xb = xb.reshape(len(xb), *reshape) + return visualize_tensors(xb, size=size, normalize=normalize, scale_each=scale_each, permute=True, ax=ax) + +def visualize_loader_by_class(dl, ax=None, size=8, normalize=True, scale_each=False, reshape=None): + for xb, yb in dl: break + if reshape: xb = xb.reshape(len(xb), *reshape) + + classes = list(set(list(yb.numpy()))) + fig, axs = plt.subplots(len(classes), 1, figsize=(15, 3*len(classes))) + + for y, ax in zip(classes, axs): + xb_ = xb[yb==y] + ax = visualize_tensors(xb_, size=size, normalize=normalize, scale_each=scale_each, permute=True, ax=ax) + ax.set_title('Class: {}'.format(y)) + + return fig + +def visualize_perturbations(P, transform_fn=None): + if transform_fn is not None: + P = transform_fn(P) + plt.figure(figsize=(20,4)) + s = np.random.choice(len(P), size=8, replace=False) + p = P[s] + g = torchvision.utils.make_grid(torch.FloatTensor(p)) + g = g.permute(1,2,0).numpy() + g = (g-g.min())/g.max() + plt.imshow(g) + +def get_logits_given_tensor(X, model, device=None, bs=250, softmax=False): + if device is None: device = gu.get_device(None) + sampler = torch.utils.data.SequentialSampler(X) + sampler = torch.utils.data.BatchSampler(sampler, bs, False) + + logits = [] + + with torch.no_grad(): + model = model.to(device) + for idx in sampler: + xb = X[idx].to(device) + out = model(xb) + 
logits.append(out) + + L = torch.cat(logits) + if softmax: return F.softmax(L, 1) + return L + +def get_predictions_given_tensor(X, model, device=None, bs=250): + out = get_logits_given_tensor(X, model, device=device, bs=bs) + return torch.argmax(out, 1) + +def get_accuracy_given_tensor(X, Y, model, device=None, bs=250): + if device is None: device = gu.get_device(None) + Y = torch.LongTensor(Y).to(device) + yhat = get_predictions_given_tensor(X, model, device=device, bs=bs) + return (Y==yhat).float().mean().item() + +def compute_accuracy(X, Y, model): + with torch.no_grad(): + pred = torch.argmax(model(X),1) + correct = (pred == Y).sum().item() + accuracy = correct/float(len(Y)) + return accuracy + +def compute_loss_and_accuracy_from_dl(dl, model, loss_fn, sample_pct=1.0, device=None, transform_fn=None): + in_tr_mode = model.training + model = model.eval() + data_size = float(len(dl.dataset)) + samp_size = int(np.ceil(sample_pct*data_size)) + num_eval = 0. + bs = dl.batch_size + accs, losses, bss = [], [], [] + + with torch.no_grad(): + for xb, yb in dl: + xb, yb = xb.to(device, non_blocking=False), yb.to(device, non_blocking=False) + + if transform_fn: + xb, yb = transform_fn(xb, yb) + + sc = model(xb) + + if loss_fn is F.cross_entropy: + loss = loss_fn(sc, yb, reduction='mean') + pred = torch.argmax(sc, 1) + elif loss_fn is F.binary_cross_entropy_with_logits: + loss = loss_fn(sc, yb.float().unsqueeze(1)) + pred = (sc > 0.).long().squeeze() + elif loss_fn is hinge_loss: + loss = loss_fn(sc, yb) + pred = (sc > 0).long().squeeze() + else: + try: + loss = loss_fn(sc, yb) + pred = torch.argmax(sc, 1) + except: + assert False, "unknown loss function" + + correct = (pred==yb).sum().float() + n = float(len(xb)) + losses.append(loss) + accs.append(correct/n) + bss.append(n) + + num_eval += n + if num_eval >= samp_size: break + + accs, losses, bss = map(np.array, [accs, losses, bss]) + if in_tr_mode: model = model.train() + return np.sum(bss*accs)/num_eval, np.sum(bs*losses)/num_eval + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + +def get_logits(model, loader, device): + S, Y = [], [] + with torch.no_grad(): + for xb, yb in loader: + xb = xb.to(device) + out = model(xb).cpu().numpy() + S.append(out) + Y.append(list(yb)) + S, Y = map(np.concatenate, [S, Y]) + return S, Y + +def get_scores(model, loader, device): + """binary tasks only""" + S, Y = get_logits(model, loader, device) + return S[:,1]-S[:,0], Y + +def get_multiclass_logit_score(L, Y): + scores = [] + for idx, (l, y) in enumerate(zip(L, Y)): + sc_y = l[y] + + indices = np.argsort(l) + best2_idx, best1_idx = indices[-2:] + sc_max = l[best2_idx] if y == best1_idx else l[best1_idx] + + score = sc_y - sc_max + scores.append(score) + + return np.array(scores) + +def get_binary_auc(model, loader, device): + S, Y = get_scores(model, loader, device) + return roc_auc_score(Y, S) + +def get_multiclass_auc(model, loader, device, one_vs_rest=True): + X, Y = extract_tensors_from_loader(loader) + S = get_logits_given_tensor(X, model, device=device, softmax=True).cpu() + mc = 'ovr' if one_vs_rest is True else 'ovo' + S, Y = S.numpy(), Y.numpy() + return roc_auc_score(Y, S, multi_class=mc) + +def clip_gradient(model, clip_value): + params = list(filter(lambda p: p.grad is not None, model.parameters())) + for p in params: p.grad.data.clamp_(-clip_value, clip_value) + +def print_model_gradients(model, print_bias=True): + for name, params in model.named_parameters(): + if not print_bias and 'bias' 
in name: continue + if not params.requires_grad: continue + avg_grad = np.mean(params.grad.cpu().numpy()) + print (name, params.shape, avg_grad) + +def hinge_loss(out, y): + y_ = (2*y.float()-1).unsqueeze(1) + return torch.mean(F.relu(1-out*y_)) + +def pgd_adv_fit_model(model, opt, tr_dl, te_dl, attack, eval_attack=None, device=None, sch=None, max_epochs=100, epoch_gap=2, + min_loss=0.001, print_info=True, save_init_model=True): + + # setup tracking + PR = lambda x: print (x) if print_info else None + stop_training = False + stats = defaultdict(list) + best_val, best_model = np.inf, None + adv_epoch_timer = [] + epoch_gap_timer = [time.time()] + init_model = copy.deepcopy(model).cpu() if save_init_model else None + + # eval attack + eval_attack = eval_attack or attack + + print ("Min loss: {}".format(min_loss)) + + def standard_epoch(loader, model, optimizer=None, sch=None): + """compute accuracy and loss. Backprop if optimizer provided""" + total_loss, total_err = 0.,0. + model = model.eval() if optimizer is None else model.train() + model = model.to(device) + update_params = optimizer is not None + + with torch.set_grad_enabled(update_params): + for xb, yb in loader: + xb, yb = xb.to(device), yb.to(device) + yp = model(xb) + loss = F.cross_entropy(yp, yb) + if update_params: + optimizer.zero_grad() + loss.backward() + optimizer.step() + total_err += (yp.max(dim=1)[1] != yb).sum().item() + total_loss += loss.item() * xb.shape[0] + return total_err / len(loader.dataset), total_loss / len(loader.dataset) + + def adv_epoch(loader, model, attack, optimizer=None, sch=None): + """compute adv accuracy and loss. Backprop if optimizer provided""" + start_time = time.time() + total_loss, total_err = 0.,0. + model = model.eval() if optimizer is None else model.train() + model = model.to(device) + update_params = optimizer is not None + + for xb, yb in loader: + torch.set_grad_enabled(True) + xb, yb = xb.to(device), yb.to(device) + delta = attack.perturb(xb, yb, model).to(device) + xb = xb + delta + with torch.set_grad_enabled(update_params): + yp = model(xb) + loss = F.cross_entropy(yp, yb) + if update_params: + optimizer.zero_grad() + loss.backward() + optimizer.step() + total_err += (yp.max(dim=1)[1] != yb).sum().item() + total_loss += loss.item() * xb.shape[0] + + if optimizer is not None and sch is not None: + cur_lr = next(iter(opt.param_groups))['lr'] + sch.step() + new_lr = next(iter(opt.param_groups))['lr'] + if new_lr != cur_lr: + PR('Epoch {}, LR : {} -> {}'.format(epoch, cur_lr, new_lr)) + + total_time = time.time()-start_time + adv_epoch_timer.append(total_time) + return total_err / len(loader.dataset), total_loss / len(loader.dataset) + + epoch = 0 + while epoch < max_epochs: + if stop_training: + break + try: + stat = {} + model = model.train() + train_err, train_loss = adv_epoch(tr_dl, model, attack, optimizer=opt, sch=sch) + + if epoch % epoch_gap == 0: + model = model.eval() + test_err, test_loss = standard_epoch(te_dl, model, optimizer=None, sch=None) + adv_err, adv_loss = adv_epoch(te_dl, model, eval_attack, optimizer=None, sch=None) + stat['acc_te'], stat['acc_te_std'] = adv_err, test_err + stat['loss_te'], stat['loss_te_std'] = adv_loss, test_loss + + if adv_err < best_val: + best_val = adv_err + best_model = copy.deepcopy(model).eval() + + if print_info: + if epoch==0: print ("Epoch", "l-tr", "a-tr", "a-te", "s-te", "time", sep='\t') + #print (epoch, *("{:.4f}".format(i) for i in (train_loss, train_err)), sep=' ') + diff_time = time.time()-epoch_gap_timer[-1] + 
epoch_gap_timer.append(time.time()) + print (epoch, *("{:.4f}".format(i) for i in (train_loss, 1.-train_err, 1.-adv_err, 1.-test_err, diff_time)), sep=' ') + + if train_loss < min_loss: + stop_training = True + + print ("Epoch {}: accuracy {:.3f} and loss {:.3f}".format(epoch, 1-train_err, train_loss)) + + stat['epoch'] = epoch + stat['acc_tr'] = train_err + stat['loss_tr'] = train_loss + + for k, v in stat.items(): + stats[k].append(v) + + epoch += 1 + + except KeyboardInterrupt: + inp = input("LR num or Q or SAVE or GAP or MAXEPOCHS: ") + if inp.startswith('LR'): + lr = float(inp.split(' ')[-1]) + cur_lr = next(iter(opt.param_groups))['lr'] + PR("New LR: {}".format(lr)) + for g in opt.param_groups: g['lr'] = lr + if inp.startswith('Q'): + stop_training = True + if inp.startswith('SAVE'): + fpath = inp.split(' ')[-1] + stats['best_model'] = (best_val, best_model.cpu()) + torch.save({ + 'model': copy.deepcopy(model).cpu(), + 'stats': stats, + 'opt': copy.deepcopy(opt).cpu() + }, fpath) + PR(f'Saved to {fpath}') + if inp.startswith('GAP'): + _, gap = inp.split(' ') + gap = int(gap) + print ("epoch gap: {} -> {}".format(epoch_gap, gap)) + epoch_gap = gap + if inp.startswith('MAXEPOCHS'): + _, me = inp.split(' ') + me = int(me) + print ("max_epochs: {} -> {}".format(max_epochs, me)) + max_epochs = me + + stats['best_model'] = (best_val, best_model.cpu()) + stats['init_model'] = init_model + return stats + + +def fit_model(model, loss, opt, train_dl, valid_dl, sch=None, epsilon=1e-2, is_loss_epsilon=False, update_gap=50, update_print_gap=50, gap=None, + print_info=True, save_grads=False, test_dl=None, skip_epoch_eval=True, sample_pct=0.5, sample_loss_threshold=0.75, save_models=False, + print_grads=False, print_model_layers=False, tr_batch_fn=None, te_batch_fn=None, device=None, max_updates=800_000, patience_updates=1, + enable_redo=False, save_best_model=True, save_init_model=True, max_epochs=100000, **misc): + + # setup update metadata + MAX_LOSS_VAL = 1000000. + PR = lambda x: print (x) if print_info else None + use_epoch = False + if gap is not None: update_gap = update_print_gap = gap + bs_ratio = int(len(train_dl.dataset)/float(train_dl.batch_size)) + act_update_gap = update_gap if not use_epoch else update_gap*bs_ratio + act_pr_update_gap = update_print_gap if not use_epoch else update_print_gap*bs_ratio + PR("accuracy/loss measured every {} updates".format(act_update_gap)) + + if save_models: + PR("saving models every {} updates".format(act_update_gap)) + + PR("update_print_gap: {}, epss: {}, bs: {}, device: {}".format(act_pr_update_gap, epsilon, train_dl.batch_size, device or 'cpu')) + + # init_save setup + init_model = copy.deepcopy(model).cpu() if save_init_model else None + + # redo setup + if enable_redo: + init_model_sd = copy.deepcopy(model.state_dict()) + init_opt_sd = copy.deepcopy(opt.state_dict()) + else: + init_model_sd = None + init_opt_sd = None + + # best model setup + best_val, best_model = 0, None + + # tracking setup + start_time = time.time() + num_evals, num_epochs, num_updates, num_patience = 0, 0, 0, 0 + stats = dict(loss_tr=[], loss_te=[], acc_tr=[], acc_te=[], acc_test=[], loss_test=[], models=[], gradients=[]) + if save_models: stats['models'].append(copy.deepcopy(model).cpu()) + first_run, converged = True, False + print_stats_flag = update_print_gap is not None + exceeded_max = False + diverged = False + + def _evaluate(device=device): + model.eval() + with torch.no_grad(): + prev_loss = stats['loss_tr'][-1] if stats['loss_tr'] else 1. 
+ tr_sample_pct = sample_pct if prev_loss > sample_loss_threshold else 1. + acc_tr, loss_tr = compute_loss_and_accuracy_from_dl(train_dl,model,loss,sample_pct=tr_sample_pct,device=device,transform_fn=tr_batch_fn) + acc_te, loss_te = compute_loss_and_accuracy_from_dl(valid_dl,model,loss,sample_pct=1.,device=device,transform_fn=te_batch_fn) + acc_tr, loss_tr, acc_te, loss_te = map(lambda x: x.item(), [acc_tr, loss_tr, acc_te, loss_te]) + stats['loss_tr'].append(loss_tr) + stats['loss_te'].append(loss_te) + stats['acc_tr'].append(acc_tr) + stats['acc_te'].append(acc_te) + + if test_dl is not None: + acc_test, loss_test = compute_loss_and_accuracy_from_dl(test_dl,model,loss,sample_pct=1.,device=device,transform_fn=te_batch_fn) + acc_test, loss_test = acc_test.item(), loss_test.item() + stats['acc_test'].append(acc_test) + stats['loss_test'].append(loss_test) + + if save_models: + stats['models'].append(copy.deepcopy(model).cpu()) + + def _update(x,y,diff_device, device=device, save_grads=False, print_grads=False): + model.train() + + # if diff_device: + # x = x.to(device, non_blocking=False) + # y = y.to(device, non_blocking=False) + + opt.zero_grad() + out = model(x) + if loss is F.cross_entropy or loss is hinge_loss: + bloss = loss(out, y) + elif loss is F.binary_cross_entropy_with_logits: + bloss = loss(out, y.float().unsqueeze(1)) + else: + try: + bloss = loss(out, y) + except: + assert False, "unknown loss function" + + bloss.backward() + if print_grads and print_info: print_model_gradients(model) + #clip_gradient(model, clip_value) + opt.step() + + if save_grads: + g = {k: v.grad.data.cpu().numpy() for k, v in model.named_parameters() if v.requires_grad} + stats['gradients'].append(g) + + opt.zero_grad() + model.eval() + + def print_time(): + end_time = time.time() + minutes, seconds = divmod(end_time-start_time, 60) + gap_valid = len(stats['acc_tr']) > 0 + gap = round(stats['acc_tr'][-1]-stats['acc_te'][-1],4) if gap_valid else 'na' + PR("converged after {} epochs in {}m {:1f}s, gap: {}".format(num_epochs, minutes, seconds, gap)) + + def print_stats(force_print=False): + + if test_dl is None: + atr, ate, ltr = [stats[k][-1] for k in ['acc_tr', 'acc_te', 'loss_tr']] + PR("{} {:.4f} {:.4f} {:.4f}".format(num_updates, atr, ate, ltr)) + if not print_info and force_print: + print ("{} {:.4f} {:.4f} {:.4f}".format(num_updates, atr, ate, ltr)) + else: + atr, aval, ate, ltr = [stats[k][-1] for k in ['acc_tr', 'acc_te', 'acc_test', 'loss_tr']] + PR("{} {:.4f} {:.4f} {:.4f} {:.4f}".format(num_updates, atr, aval, ate, ltr)) + if not print_info and force_print: + print ("{} {:.4f} {:.4f} {:.4f} {:.4f}".format(num_updates, atr, aval, ate, ltr)) + + #xb_, yb_ = next(iter(train_dl)) + diff_device = True #xb_.device != device + + if test_dl is None: PR("#updates, train acc, test acc, train loss") + else: PR("#updates, train acc, val acc, test acc, train loss") + + while not converged or num_patience < patience_updates: + try: + model.train() + for xb, yb in train_dl: + + if tr_batch_fn: + xb, yb = tr_batch_fn(xb, yb) + + if diff_device: + xb = xb.to(device, non_blocking=False) + yb = yb.to(device, non_blocking=False) + + if converged: + num_patience += 1 + + if converged and num_patience == patience_updates: + _evaluate() + print_stats() + break + + # update flag for printing gradients + update_flag = print_model_layers and (num_updates == 0 or (num_updates % act_update_gap == 0 and print_grads)) + _update(xb, yb, diff_device, device=device, save_grads=save_grads, print_grads=update_flag) + + if 
(num_evals == 0 or num_updates % act_update_gap == 0): + num_evals += 1 + _evaluate() + print_stats() + + val_acc = stats['acc_te'][-1] + if num_updates > 0 and val_acc >= best_val: + best_val = val_acc + best_model = copy.deepcopy(model).eval() + + # check if loss has diverged + loss_val = max(stats['loss_tr'][-1], stats['loss_te'][-1]) + if loss_val > MAX_LOSS_VAL: diverged = True + if not np.isfinite(loss_val): diverged = True + + + if is_loss_epsilon: stop = stats['loss_tr'][-1] < epsilon + else: stop = stats['acc_tr'][-1] >= 1-epsilon + + if not converged and diverged: + converged = True + print_time() + PR("loss diverging...exiting".format(patience_updates)) + + if not converged and stop: + converged = True + print_time() + PR("init-ing patience ({} updates)".format(patience_updates)) + + num_updates += 1 + first_run = False + + if num_updates > max_updates: + converged = True + exceeded_max = True + num_patience = patience_updates + PR("Exceeded max updates") + print_stats() + print_time() + break + + # re-eval at the end of epoch + if not converged: + num_epochs += 1 + + if not converged and num_epochs >= max_epochs: + converged = True + exceeded_max = True + num_patience = patience_updates + PR("Exceeded max epochs") + print_stats() + print_time() + break + + if not skip_epoch_eval: + _evaluate() + print_stats() + + if is_loss_epsilon: stop = stats['loss_tr'][-1] < epsilon + else: stop = stats['acc_tr'][-1] >= 1-epsilon + + if not converged and stop: + converged = True + print_time() + PR("init-ing patience ({} updates)".format(patience_updates)) + + if num_patience >= patience_updates: + _evaluate() + print_stats() + break + + # update LR via scheduler + if sch is not None: + cur_lr = next(iter(opt.param_groups))['lr'] + sch.step() + new_lr = next(iter(opt.param_groups))['lr'] + if new_lr != cur_lr: + PR('Epoch {}, LR : {} -> {}'.format(num_epochs, cur_lr, new_lr)) + + except KeyboardInterrupt: + inp = input("LR num or Q or GAP num or SAVE fpath or EVAL or REDO: ") + if inp.startswith('LR'): + lr = float(inp.split(' ')[-1]) + cur_lr = next(iter(opt.param_groups))['lr'] + PR("LR: {} - > {}".format(cur_lr, lr)) + for g in opt.param_groups: g['lr'] = lr + elif inp.startswith('GAP'): + gap = int(inp.split(' ')[-1]) + act_update_gap = act_pr_update_gap = gap + elif inp == "Q": + converged = True + num_patience = patience_updates + print_time() + elif inp.startswith('SAVE'): + fpath = inp.split(' ')[-1] + torch.save({ + 'model': model, + 'opt': opt, + 'update_gap': update_gap + }, fpath) + elif inp == 'EVAL': + _evaluate() + print_stats(True) + elif inp == 'REDO': + if enable_redo: + model.load_state_dict(init_model_sd) + opt.load_state_dict(init_opt_sd) + else: + print ("REDO disabled") + + best_test = None + if test_dl is not None: + best_test = compute_loss_and_accuracy_from_dl(test_dl, best_model, loss, sample_pct=1.0, device=device)[0].item() + + stats['num_updates'] = num_updates + stats['num_epochs'] = num_epochs + stats['update_gap'] = update_gap + + stats['best_model'] = (best_val, best_test, best_model.cpu() if best_model else model.cpu()) + stats['init_model'] = init_model + if save_models: stats['models'].append(copy.deepcopy(model).cpu()) + + stats['x_updates']= list(range(0, num_evals*(update_gap+1), update_gap)) + stats['x'] = stats['x_updates'][:] + stats['x_epochs'] = list(range(num_epochs)) + stats['gap'] = stats['acc_tr'][-1]-stats['acc_te'][-1] + return stats + +def save_pickle(fname, d, mode='w'): + with open(fname, mode) as f: + pickle.dump(d, f) + +def 
load_pickle(fname, mode='r'): + with open(fname, mode) as f: + return pickle.load(f) + +def update_ax(ax, title=None, xlabel=None, ylabel=None, legend_loc='best', ticks=True, ticks_fs=10, label_fs=12, legend_fs=12, title_fs=14, hide_xlabels=False, hide_ylabels=False, despine=True): + if title: ax.set_title(title, fontsize=title_fs) + if xlabel: ax.set_xlabel(xlabel, fontsize=label_fs) + if ylabel: ax.set_ylabel(ylabel, fontsize=label_fs) + if legend_loc: ax.legend(loc=legend_loc, fontsize=legend_fs) + if despine: sns.despine(ax=ax) + + if ticks: + # ax.minorticks_on() + ax.tick_params(direction='in', length=6, width=2, colors='k', which='major', top=False, right=False) + ax.tick_params(direction='in', length=4, width=1, colors='k', which='minor', top=False, right=False) + ax.tick_params(labelsize=ticks_fs) + + if hide_xlabels: ax.set_xticks([]) + if hide_ylabels: ax.set_yticks([]) + return ax \ No newline at end of file diff --git a/utils/slab_data.py b/utils/slab_data.py new file mode 100644 index 0000000..36384e3 --- /dev/null +++ b/utils/slab_data.py @@ -0,0 +1,92 @@ +import sys +import random +import os, copy, pickle, time +import argparse +import itertools +from collections import defaultdict, Counter, OrderedDict +import matplotlib.pyplot as plt + +import numpy as np +import seaborn as sns +import torch +import torchvision +from torch import optim, nn +import torch.nn.functional as F +from torch.utils.data import TensorDataset, DataLoader +from torch.autograd import Variable +import pandas as pd + +import utils.scripts.gpu_utils as gu +import utils.scripts.gendata as gendata + +torch.backends.cudnn.benchmark = True +torch.backends.cudnn.enabled = True + +DEVICE_ID = 0 # GPU_ID or None (CPU) +DEVICE = gu.get_device(DEVICE_ID) + +def get_data(num_samples, spur_corr, total_slabs): + + if total_slabs== 5: + slab_5_flag= 1 + slab_7_flag= 0 + elif total_slabs== 7: + slab_5_flag= 0 + slab_7_flag= 1 + + c = config = { + 'num_train': num_samples, # training dataset size + 'dim': 2, # input dimension + 'lin_margin': 0.1, # linear margin + 'slab_margin': 0.1, # slab margin, + 'same_margin': True, # keep same margin + 'random_transform': True, # apply random (orthogonal) transformation + 'width': 1, # data width in standard basis + 'num_lin': 1, # number of linear components + 'num_slabs': slab_5_flag, #. 
number of 5 slabs
+        'num_slabs7': slab_7_flag, # number of 7 slabs
+        'num_slabs3': 0, # number of 3 slabs
+        'bs': 256, # batch size
+        'corrupt_lin': spur_corr, # p_noise
+        'corrupt_lin_margin': True, # noise model
+        'corrupt_slab': 0.0, # slab corruption
+        'num_test': 0, # test dataset size
+        'hdim': 100, # model width
+        'hl': 2, # model depth
+        'mtype': 'fcn', # model architecture
+        'device': gu.get_device(DEVICE_ID), # GPU device
+        'lr': 0.1, # step size
+        'weight_decay': 5e-5 # weight decay
+    }
+
+    smargin = c['lin_margin'] if c['same_margin'] else c['slab_margin']
+    data_func = gendata.generate_ub_linslab_data_v2
+    spc = [3]*c['num_slabs3'] + [5]*c['num_slabs'] + [7]*c['num_slabs7']
+    data = data_func(c['num_train'], c['dim'], c['lin_margin'], slabs_per_coord=spc, eff_slab_margin=smargin, random_transform=c['random_transform'], N_te=c['num_test'],
+                     corrupt_lin_margin=c['corrupt_lin_margin'], num_lin=c['num_lin'], num_slabs=c['num_slabs3']+c['num_slabs']+c['num_slabs7'], width=c['width'], bs=c['bs'], corrupt_lin=c['corrupt_lin'], corrupt_slab=c['corrupt_slab'])
+
+    #Project data
+    W = data['W']
+    X, Y = data['X'], data['Y']
+    X = X.numpy().dot(W.T)
+    Y = Y.numpy()
+
+    #Generate objects/slab_id array: O[idx] is the index of the slab (along coordinate 1) that sample idx falls into
+    # Equation: k*slab_len + (k-1)*0.2 = 2, where k is the total number of slabs
+    # Slabs: [-1, -1 + slab_len], [-1 + slab_len + 0.2, -1 + 0.2 + 2*slab_len], ...
+    # Slabs: For i in range(0, k): [-1 + i*(0.2 + slab_len), -1 + slab_len + i*(0.2 + slab_len)]
+    k = total_slabs
+    slab_len = (2 - 0.2*(k-1))/k
+    O = np.zeros(X.shape[0])
+    for idx in range(X.shape[0]):
+        for slab_id in range(k):
+            start = -1 + slab_id*(0.2 + slab_len)
+            end = start + slab_len
+            if X[idx, 1] >= start and X[idx, 1] <= end:
+                O[idx] = slab_id
+                break
+
+    return data, X, Y, O
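
Usage note (illustrative sketch, not part of the patch): the snippet below shows one way the get_data helper added in utils/slab_data.py above could be exercised directly, assuming the default 2-D, 5-slab configuration; the sample count and per-domain spurious-correlation probabilities are hypothetical values chosen only for illustration.

    import numpy as np
    from utils.slab_data import get_data

    base_size = 10000            # hypothetical number of samples per domain
    total_slabs = 5              # 5-slab configuration
    spur_probs = [0.0, 0.10]     # hypothetical spurious-correlation probability per domain

    for domain_idx, p in enumerate(spur_probs):
        data, X, Y, O = get_data(base_size, p, total_slabs)
        # X: 2-D inputs (coordinate 1 is the slab coordinate used to assign O)
        # Y: binary labels; O: per-sample slab/object index
        print(domain_idx, X.shape, Y.shape, np.unique(O))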