Mirror of https://github.com/microsoft/EdgeML.git
DROCC Code (#196)
* drocc * data processing files * Update main_tabular.py * Update README.md * CIFAR, data processing scripts folder * Update README.md * Update README.md * arg parser, other touchups * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update process_epilepsy.py * Update README.md * DROCC: update docs * Update README.md * abalone preprocessing improved Co-authored-by: Moksh Jain <mokshjn00@gmail.com>
Parent
f835ab1abc
Commit
a1021dfa12
@@ -17,6 +17,7 @@ Algorithms that shine in this setting in terms of both model size and compute, n
- **EMI-RNN**: Training routine to recover the critical signature from time series data for faster and accurate RNN predictions.
- **Shallow RNN**: A meta-architecture for training RNNs that can be applied to streaming data.
- **FastRNN & FastGRNN - FastCells**: **F**ast, **A**ccurate, **S**table and **T**iny (**G**ated) RNN cells.
- **DROCC**: **D**eep **R**obust **O**ne-**C**lass **C**lassification for training robust anomaly detectors.

These algorithms can train models for classical supervised learning problems
with memory requirements that are orders of magnitude lower than other modern
@@ -0,0 +1,86 @@
# Deep Robust One-Class Classification

In this directory we present examples of how to use the `DROCCTrainer` to replicate the results in the [paper](https://proceedings.icml.cc/book/4293.pdf).

`DROCCTrainer` is part of the `edgeml_pytorch` package. Please install the `edgeml_pytorch` package as follows:
```
git clone https://github.com/microsoft/EdgeML
cd EdgeML/pytorch
pip install -r requirements-gpu.txt
pip install -e .
```

## Tabular Experiments
Data is expected in the following format:
```
train_data.npy: features of the training data
test_data.npy: features of the test data
train_labels.npy: labels for the training data (the normal class is labelled 1)
test_labels.npy: labels for the test data
```
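To run the tabular example on a custom dataset, these four files just need to be saved as NumPy arrays. A minimal sketch (toy random data; the shapes and values below are illustrative assumptions, not outputs of the provided scripts):
```
import numpy as np

# toy example: 200 normal training points, 50 normal + 50 anomalous test points, 8 features each
train_data = np.random.randn(200, 8)
train_labels = np.ones(200)                                 # normal class labelled 1
test_data = np.random.randn(100, 8)
test_labels = np.concatenate([np.ones(50), np.zeros(50)])   # anomalies labelled 0

np.save('train_data.npy', train_data)
np.save('train_labels.npy', train_labels)
np.save('test_data.npy', test_data)
np.save('test_labels.npy', test_labels)
```
The directory containing these files is what is passed to `main_tabular.py` via `-d`.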
### Arrhythmia and Thyroid
* Download the datasets from the ODDS Repository: [Arrhythmia](http://odds.cs.stonybrook.edu/arrhythmia-dataset/) and [Thyroid](http://odds.cs.stonybrook.edu/annthyroid-dataset/). This gives you `arrhythmia.mat` or `annthyroid.mat`.
* The data is split for training as in previous work: [DAGMM](https://openreview.net/forum?id=BJJLHbb0-) and [GOAD](https://openreview.net/forum?id=H1lK_lBtvS).
* To generate the training and test data, use the `data_process_scripts/process_odds.py` script as follows:
```
python data_process_scripts/process_odds.py -d <path/to/downloaded_data/file_name.mat> -o <output path>
```
The output path is referred to as "root_data" in the following sections.

### Abalone
* Download the `abalone.data` file from the UCI Repository [here](http://archive.ics.uci.edu/ml/datasets/Abalone).
* To generate the training and test data, use the `data_process_scripts/process_abalone.py` script as follows:
```
python data_process_scripts/process_abalone.py -d <path/to/data/abalone.data> -o <output path>
```
The output path is referred to as "root_data" in the following sections.

### Commands to reproduce results
#### Arrhythmia
```
python3 main_tabular.py --hd 128 --lr 0.0001 --lamda 1 --gamma 2 --ascent_step_size 0.001 --radius 16 --batch_size 256 --epochs 200 --optim 0 --restore 0 --metric F1 -d "root_data"
```

#### Thyroid
```
python3 main_tabular.py --hd 128 --lr 0.001 --lamda 1 --gamma 2 --ascent_step_size 0.001 --radius 2.5 --batch_size 256 --epochs 100 --optim 0 --restore 0 --metric F1 -d "root_data"
```

#### Abalone
```
python3 main_tabular.py --hd 128 --lr 0.001 --lamda 1 --gamma 2 --ascent_step_size 0.001 --radius 3 --batch_size 256 --epochs 200 --optim 0 --restore 0 --metric F1 -d "root_data"
```

## Time-Series Experiments
### Data Processing
#### Epilepsy
* Download the dataset from the UCI Repository [here](https://archive.ics.uci.edu/ml/datasets/Epileptic+Seizure+Recognition). This consists of a `data.csv` file.
* To generate the training and test data, use the `data_process_scripts/process_epilepsy.py` script as follows:
```
python data_process_scripts/process_epilepsy.py -d <path/to/data/data.csv> -o <output path>
```
The output path is referred to as "root_data" in the following section.

### Example Usage for the Epilepsy Dataset
```
python3 main_timeseries.py --hd 128 --lr 0.00001 --lamda 0.5 --gamma 2 --ascent_step_size 0.1 --radius 10 --batch_size 256 --epochs 200 --optim 0 --restore 0 --metric AUC -d "root_data"
```

## CIFAR Experiments
```
python3 main_cifar.py --lamda 1 --radius 8 --lr 0.001 --gamma 1 --ascent_step_size 0.001 --batch_size 256 --epochs 40 --optim 0 --normal_class 0
```

### Argument Details
normal_class => CIFAR-10 class to be considered as normal
lamda => weight on the loss from adversarially sampled negative points (\mu in the paper)
radius => radius corresponding to the definition of the set N_i(r)
hd => LSTM hidden dimension
optim => 0: Adam, 1: SGD (with momentum)
ascent_step_size => step size for gradient ascent to generate adversarial anomalies
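For reference, `radius` and `gamma` together determine the region from which adversarial negatives are generated around each normal training point x_i. Based on the projection step in `drocc_trainer.py` (which clamps the distance of the perturbed point to lie between `radius` and `gamma * radius`), the set N_i(r) used above can be read as:
```
N_i(r) = { x : r <= || x - x_i ||_2 <= gamma * r }
```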
@@ -0,0 +1,35 @@
import os
import pandas as pd
import numpy as np
import argparse

parser = argparse.ArgumentParser(description='Preprocess Abalone Data')
parser.add_argument('-d', '--data_path', type=str, default='./abalone.data')
parser.add_argument('-o', '--output_path', type=str, default='.')
args = parser.parse_args()

data = pd.read_csv(args.data_path, header=None, sep=',')

# column 8 (number of rings) is used as the label
data = data.rename(columns={8: 'y'})

# rings 8, 9 and 10 are treated as the normal class (marked -1 for now);
# rings 3 and 21 are treated as anomalies (marked 0)
data['y'].replace([8, 9, 10], -1, inplace=True)
data['y'].replace([3, 21], 0, inplace=True)
# encode the sex column: M -> 0, F -> 1, I -> 2
data.iloc[:, 0].replace('M', 0, inplace=True)
data.iloc[:, 0].replace('F', 1, inplace=True)
data.iloc[:, 0].replace('I', 2, inplace=True)

test = data[data['y'] == 0]
num_normal_samples_test = test.shape[0]

normal = data[data['y'] == -1].sample(frac=1)

# test set: all anomalies plus an equal number of held-out (shuffled) normal points
test_data = np.concatenate((test.drop('y', axis=1), normal[:num_normal_samples_test].drop('y', axis=1)), axis=0)
train = normal[num_normal_samples_test:]
train_data = train.drop('y', axis=1).values
# DROCC expects the normal class to be labelled 1 and anomalies 0
train_labels = train['y'].replace(-1, 1)
test_labels = np.concatenate((test['y'], normal[:num_normal_samples_test]['y'].replace(-1, 1)), axis=0)

np.save(os.path.join(args.output_path, 'train_data.npy'), train_data)
np.save(os.path.join(args.output_path, 'train_labels.npy'), train_labels)
np.save(os.path.join(args.output_path, 'test_data.npy'), test_data)
np.save(os.path.join(args.output_path, 'test_labels.npy'), test_labels)
@@ -0,0 +1,155 @@
'''
Code borrowed from https://github.com/lukasruff/Deep-SVDD-PyTorch
'''
from PIL import Image
import numpy as np
from random import sample
from abc import ABC, abstractmethod
import torch
from torch.utils.data import Subset
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


class BaseADDataset(ABC):
    """Anomaly detection dataset base class."""

    def __init__(self, root: str):
        super().__init__()
        self.root = root  # root path to data

        self.n_classes = 2  # 0: normal, 1: outlier
        self.normal_classes = None  # tuple with original class labels that define the normal class
        self.outlier_classes = None  # tuple with original class labels that define the outlier class

        self.train_set = None  # must be of type torch.utils.data.Dataset
        self.test_set = None  # must be of type torch.utils.data.Dataset

    @abstractmethod
    def loaders(self, batch_size: int, shuffle_train=True, shuffle_test=False, num_workers: int = 0) -> (
            DataLoader, DataLoader):
        """Implement data loaders of type torch.utils.data.DataLoader for train_set and test_set."""
        pass

    def __repr__(self):
        return self.__class__.__name__


class TorchvisionDataset(BaseADDataset):
    """TorchvisionDataset class for datasets already implemented in torchvision.datasets."""

    def __init__(self, root: str):
        super().__init__(root)

    def loaders(self, batch_size: int, shuffle_train=True, shuffle_test=False, num_workers: int = 0) -> (
            DataLoader, DataLoader):
        train_loader = DataLoader(dataset=self.train_set, batch_size=batch_size, shuffle=shuffle_train,
                                  num_workers=num_workers)
        test_loader = DataLoader(dataset=self.test_set, batch_size=batch_size, shuffle=shuffle_test,
                                 num_workers=num_workers)
        return train_loader, test_loader


class CIFAR10_Dataset(TorchvisionDataset):

    def __init__(self, root: str, normal_class=5):
        super().__init__(root)

        self.n_classes = 2  # 0: normal, 1: outlier
        self.normal_classes = tuple([normal_class])
        self.outlier_classes = list(range(0, 10))
        self.outlier_classes.remove(normal_class)

        # Pre-computed min and max values (after applying GCN) from train data per class
        # min_max = [(-28.94083453598571, 13.802961825439636),
        #            (-6.681770233365245, 9.158067708230273),
        #            (-34.924463588638204, 14.419298165027628),
        #            (-10.599172931391799, 11.093187820377565),
        #            (-11.945022995801637, 10.628045447867583),
        #            (-9.691969487694928, 8.948326776180823),
        #            (-9.174940012342555, 13.847014686472365),
        #            (-6.876682005899029, 12.282371383343161),
        #            (-15.603507135507172, 15.2464923804279),
        #            (-6.132882973622672, 8.046098172351265)]
        # CIFAR-10 preprocessing: GCN (with L1 norm) and min-max feature scaling to [0,1]
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                                             std=[0.247, 0.243, 0.261])])

        target_transform = transforms.Lambda(lambda x: int(x not in self.outlier_classes))

        train_set = MyCIFAR10(root=self.root, train=True, download=True,
                              transform=transform, target_transform=target_transform)

        # Subset train set to normal class
        train_idx_normal = get_target_label_idx(train_set.targets, self.normal_classes)
        # train_idx_normal_train = sample(train_idx_normal, 4000)
        # val_idx_normal = [x for x in train_idx_normal if x not in train_idx_normal_train]

        # rest_train_classes = get_target_label_idx(train_set.train_labels, self.outlier_classes)
        # rest_train_classes_subset = sample(rest_train_classes, 9000)
        # val_idx = val_idx_normal + rest_train_classes_subset
        self.train_set = Subset(train_set, train_idx_normal)
        # self.test_set = Subset(train_set, val_idx)
        self.test_set = MyCIFAR10(root=self.root, train=False, download=True,
                                  transform=transform, target_transform=target_transform)


class MyCIFAR10(CIFAR10):
    """Torchvision CIFAR10 class with patch of __getitem__ method to also return the index of a data sample."""

    def __init__(self, *args, **kwargs):
        super(MyCIFAR10, self).__init__(*args, **kwargs)

    def __getitem__(self, index):
        """Override the original method of the CIFAR10 class.
        Args:
            index (int): Index
        Returns:
            triple: (image, target, index) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index  # only line changed


def get_target_label_idx(labels, targets):
    """
    Get the indices of labels that are included in targets.
    :param labels: array of labels
    :param targets: list/tuple of target labels
    :return: list with indices of target labels
    """
    return np.argwhere(np.isin(labels, targets)).flatten().tolist()


def global_contrast_normalization(x: torch.tensor, scale='l2'):
    """
    Apply global contrast normalization to tensor, i.e. subtract mean across features (pixels) and normalize by scale,
    which is either the standard deviation, L1- or L2-norm across features (pixels).
    Note this is a *per sample* normalization globally across features (and not across the dataset).
    """

    assert scale in ('l1', 'l2')

    n_features = int(np.prod(x.shape))

    mean = torch.mean(x)  # mean over all features (pixels) per sample
    x -= mean

    if scale == 'l1':
        x_scale = torch.mean(torch.abs(x))

    if scale == 'l2':
        x_scale = torch.sqrt(torch.sum(x ** 2)) / n_features

    x /= x_scale

    return x
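# Usage sketch for global_contrast_normalization (illustrative values, not part of the
# CIFAR pipeline above, which uses transforms.Normalize instead). The function modifies
# its argument in place and returns it, so pass a clone if the original tensor is reused:
#
#     img = torch.rand(3, 32, 32)
#     img_gcn = global_contrast_normalization(img.clone(), scale='l1')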
@@ -0,0 +1,36 @@
import os
import argparse
import pandas as pd
import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data_path', type=str, default='./data.csv')
parser.add_argument('-o', '--output_path', type=str, default='.')
args = parser.parse_args()

data = pd.read_csv(args.data_path)

# in the raw data, class 1 is seizure activity and classes 2-5 are non-seizure;
# relabel: seizure -> 0 (anomaly), non-seizure -> 1 (normal)
data['y'] = data['y'].replace(1, 0)
data['y'] = data['y'].replace([2, 3, 4, 5], 1)

# test set: all anomalies plus 2300 held-out normal points
test = data[data['y'] == 0]
normal = data[data['y'] == 1].sample(frac=1).reset_index(drop=True)

test = pd.concat([test, normal.iloc[:2300]])
normal = normal.iloc[2300:]

normal = normal.drop(['y', 'Unnamed: 0'], axis=1)
np.save(os.path.join(args.output_path, 'train.npy'), normal.values)

test = test.drop('Unnamed: 0', axis=1)
test = test.sample(frac=1).reset_index(drop=True)

labels = test['y'].values
test = test.drop('y', axis=1).values
np.save(os.path.join(args.output_path, 'test_data.npy'), test)
np.save(os.path.join(args.output_path, 'test_labels.npy'), labels)
@@ -0,0 +1,18 @@
import numpy as np

train_f = np.load('train_seven.npz')['features']   # containing only the class marvin
others_f = np.load('other_seven.npz')['features']  # containing classes other than marvin

np.random.shuffle(train_f)
np.random.shuffle(others_f)

# 80/20 split of the normal class; the test set gets an equal number of anomalous clips
len_train = int(0.8 * len(train_f))
len_test = len(train_f) - len_train

data = train_f[:len_train]
np.save('train.npy', data)

test_data = np.concatenate((train_f[len_train:], others_f[:len_test]), axis=0)
labels = [1] * len_test + [0] * len_test
np.save('test_data.npy', test_data)
np.save('test_labels.npy', labels)
@@ -0,0 +1,37 @@
import os
import numpy as np
from scipy.io import loadmat
import argparse

parser = argparse.ArgumentParser(description='Preprocess Dataset from ODDS Repository')
parser.add_argument('-d', '--data_path', type=str, default='./arrhythmia.mat')
parser.add_argument('-o', '--output_path', type=str, default='.')
args = parser.parse_args()

dataset = loadmat(args.data_path)

data = np.concatenate((dataset['X'], dataset['y']), axis=1)

test = data[data[:, -1] == 1]
num_normal_samples_test = test.shape[0]

normal = data[data[:, -1] == 0]
np.random.shuffle(normal)

test = np.concatenate((test, normal[:num_normal_samples_test]), axis=0)

train = normal[num_normal_samples_test:]
train_data = train[:, :-1]
# DROCC requires normal data to be labelled 1
train_labels = np.ones(train_data.shape[0])

test_data = test[:, :-1]
# DROCC requires normal data to be labelled 1 and anomalies 0
test_labels = np.concatenate((
    np.zeros(num_normal_samples_test), np.ones(num_normal_samples_test)),
    axis=0)

np.save(os.path.join(args.output_path, 'train_data.npy'), train_data)
np.save(os.path.join(args.output_path, 'train_labels.npy'), train_labels)
np.save(os.path.join(args.output_path, 'test_data.npy'), test_data)
np.save(os.path.join(args.output_path, 'test_labels.npy'), test_labels)
@@ -0,0 +1,152 @@
from __future__ import print_function
import os
import numpy as np
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from collections import OrderedDict
from data_process_scripts.process_cifar import CIFAR10_Dataset
from edgeml_pytorch.trainer.drocc_trainer import DROCCTrainer


class CIFAR10_LeNet(nn.Module):

    def __init__(self):
        super(CIFAR10_LeNet, self).__init__()

        self.rep_dim = 128
        self.pool = nn.MaxPool2d(2, 2)

        self.conv1 = nn.Conv2d(3, 32, 5, bias=False, padding=2)
        self.bn2d1 = nn.BatchNorm2d(32, eps=1e-04, affine=False)
        self.conv2 = nn.Conv2d(32, 64, 5, bias=False, padding=2)
        self.bn2d2 = nn.BatchNorm2d(64, eps=1e-04, affine=False)
        self.conv3 = nn.Conv2d(64, 128, 5, bias=False, padding=2)
        self.bn2d3 = nn.BatchNorm2d(128, eps=1e-04, affine=False)
        self.fc1 = nn.Linear(128 * 4 * 4, self.rep_dim, bias=False)
        self.fc2 = nn.Linear(self.rep_dim, int(self.rep_dim / 2), bias=False)
        self.fc3 = nn.Linear(int(self.rep_dim / 2), 1, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(F.leaky_relu(self.bn2d1(x)))
        x = self.conv2(x)
        x = self.pool(F.leaky_relu(self.bn2d2(x)))
        x = self.conv3(x)
        x = self.pool(F.leaky_relu(self.bn2d3(x)))
        x = x.view(x.size(0), -1)
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        x = self.fc3(x)
        return x


def adjust_learning_rate(epoch, total_epochs, only_ce_epochs, learning_rate, optimizer):
    """Adjust learning rate during training.

    Parameters
    ----------
    epoch: Current training epoch.
    total_epochs: Total number of epochs for training.
    only_ce_epochs: Number of epochs for initial pretraining.
    learning_rate: Initial learning rate for training.
    """
    # We don't want to consider the only-CE
    # epochs for the lr scheduler
    epoch = epoch - only_ce_epochs
    drocc_epochs = total_epochs - only_ce_epochs
    # lr = learning_rate
    if epoch <= drocc_epochs:
        lr = learning_rate * 0.01
    if epoch <= 0.80 * drocc_epochs:
        lr = learning_rate * 0.1
    if epoch <= 0.40 * drocc_epochs:
        lr = learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return optimizer


def main():

    dataset = CIFAR10_Dataset("data", args.normal_class)
    train_loader, test_loader = dataset.loaders(batch_size=args.batch_size)
    model = CIFAR10_LeNet().to(device)
    model = nn.DataParallel(model)

    if args.optim == 1:
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.mom)
        print("using SGD")
    else:
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr)
        print("using Adam")

    # Training the model
    trainer = DROCCTrainer(model, optimizer, args.lamda, args.radius, args.gamma, device)

    # Restore from checkpoint
    if args.restore == 1:
        if os.path.exists(os.path.join(args.model_dir, 'model.pt')):
            trainer.load(args.model_dir)
            print("Saved Model Loaded")

    trainer.train(train_loader, test_loader, args.lr, adjust_learning_rate, args.epochs,
                  metric=args.metric, ascent_step_size=args.ascent_step_size, only_ce_epochs=0)

    trainer.save(args.model_dir)


if __name__ == '__main__':
    torch.set_printoptions(precision=5)

    parser = argparse.ArgumentParser(description='PyTorch Simple Training')
    parser.add_argument('--normal_class', type=int, default=0, metavar='N',
                        help='CIFAR10 normal class index')
    parser.add_argument('--batch_size', type=int, default=128, metavar='N',
                        help='batch size for training')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train')
    parser.add_argument('-oce', '--only_ce_epochs', type=int, default=50, metavar='N',
                        help='number of epochs to train with only CE loss')
    parser.add_argument('--ascent_num_steps', type=int, default=50, metavar='N',
                        help='Number of gradient ascent steps')
    parser.add_argument('--hd', type=int, default=128, metavar='N',
                        help='Num hidden nodes for LSTM model')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate')
    parser.add_argument('--ascent_step_size', type=float, default=0.001, metavar='LR',
                        help='step size of gradient ascent')
    parser.add_argument('--mom', type=float, default=0.99, metavar='M',
                        help='momentum')
    parser.add_argument('--model_dir', default='log',
                        help='path where to save checkpoint')
    parser.add_argument('--one_class_adv', type=int, default=1, metavar='N',
                        help='adv loss to be used or not, 1: use, 0: do not use (only CE)')
    parser.add_argument('--radius', type=float, default=0.2, metavar='N',
                        help='radius corresponding to the definition of set N_i(r)')
    parser.add_argument('--lamda', type=float, default=1, metavar='N',
                        help='Weight to the adversarial loss')
    parser.add_argument('--reg', type=float, default=0, metavar='N',
                        help='weight reg')
    parser.add_argument('--restore', type=int, default=0, metavar='N',
                        help='whether to load a pretrained model, 1: load, 0: train from scratch')
    parser.add_argument('--optim', type=int, default=0, metavar='N',
                        help='0: Adam, 1: SGD')
    parser.add_argument('--gamma', type=float, default=2.0, metavar='N',
                        help='r to gamma * r projection for the set N_i(r)')
    parser.add_argument('-d', '--data_path', type=str, default='.')
    parser.add_argument('--metric', type=str, default='AUC')
    args = parser.parse_args()

    # settings
    # Checkpoint store path
    model_dir = args.model_dir
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    main()
@@ -0,0 +1,179 @@
from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from collections import OrderedDict
import numpy as np
from edgeml_pytorch.trainer.drocc_trainer import DROCCTrainer


class MLP(nn.Module):
    """
    Multi-layer perceptron with a single hidden layer.
    """
    def __init__(self,
                 input_dim=2,
                 num_classes=1,
                 num_hidden_nodes=20):
        super(MLP, self).__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.num_hidden_nodes = num_hidden_nodes
        activ = nn.ReLU(True)
        self.feature_extractor = nn.Sequential(OrderedDict([
            ('fc', nn.Linear(self.input_dim, self.num_hidden_nodes)),
            ('relu1', activ)]))
        self.size_final = self.num_hidden_nodes

        self.classifier = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(self.size_final, self.num_classes))]))

    def forward(self, input):
        features = self.feature_extractor(input)
        logits = self.classifier(features.view(-1, self.size_final))
        return logits


def adjust_learning_rate(epoch, total_epochs, only_ce_epochs, learning_rate, optimizer):
    """Adjust learning rate during training.

    Parameters
    ----------
    epoch: Current training epoch.
    total_epochs: Total number of epochs for training.
    only_ce_epochs: Number of epochs for initial pretraining.
    learning_rate: Initial learning rate for training.
    """
    # We don't want to consider the only-CE
    # epochs for the lr scheduler
    epoch = epoch - only_ce_epochs
    drocc_epochs = total_epochs - only_ce_epochs
    # lr = learning_rate
    if epoch <= drocc_epochs:
        lr = learning_rate * 0.001
    if epoch <= 0.90 * drocc_epochs:
        lr = learning_rate * 0.01
    if epoch <= 0.60 * drocc_epochs:
        lr = learning_rate * 0.1
    if epoch <= 0.30 * drocc_epochs:
        lr = learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return optimizer


class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return torch.from_numpy(self.data[idx]), (self.labels[idx]), torch.tensor([0])


def load_data(path):
    train_data = np.load(os.path.join(path, 'train_data.npy'), allow_pickle=True)
    train_lab = np.ones((train_data.shape[0]))  # all training points are labelled positive (normal)
    test_data = np.load(os.path.join(path, 'test_data.npy'), allow_pickle=True)
    test_lab = np.load(os.path.join(path, 'test_labels.npy'), allow_pickle=True)

    # preprocessing: standardize features using the training statistics
    mean = np.mean(train_data, 0)
    std = np.std(train_data, 0)
    train_data = (train_data - mean) / (std + 1e-4)
    num_features = train_data.shape[1]
    test_data = (test_data - mean) / (std + 1e-4)

    train_samples = train_data.shape[0]
    test_samples = test_data.shape[0]
    print("Train Samples: ", train_samples)
    print("Test Samples: ", test_samples)

    return CustomDataset(train_data, train_lab), CustomDataset(test_data, test_lab), num_features


def main():
    train_dataset, test_dataset, num_features = load_data(args.data_path)
    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, args.batch_size, shuffle=True)

    model = MLP(input_dim=num_features, num_hidden_nodes=args.hd, num_classes=1).to(device)
    if args.optim == 1:
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.mom)
        print("using SGD")
    else:
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr)
        print("using Adam")

    # Training the model
    trainer = DROCCTrainer(model, optimizer, args.lamda, args.radius, args.gamma, device)

    # Restore from checkpoint
    if args.restore == 1:
        if os.path.exists(os.path.join(args.model_dir, 'model.pt')):
            trainer.load(args.model_dir)
            print("Saved Model Loaded")

    trainer.train(train_loader, test_loader, args.lr, adjust_learning_rate, args.epochs,
                  metric=args.metric, ascent_step_size=args.ascent_step_size, only_ce_epochs=args.only_ce_epochs)

    trainer.save(args.model_dir)


if __name__ == '__main__':
    torch.set_printoptions(precision=5)

    parser = argparse.ArgumentParser(description='PyTorch Simple Training')
    parser.add_argument('--batch_size', type=int, default=128, metavar='N',
                        help='batch size for training')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train')
    parser.add_argument('-oce', '--only_ce_epochs', type=int, default=50, metavar='N',
                        help='number of epochs to train with only CE loss')
    parser.add_argument('--ascent_num_steps', type=int, default=50, metavar='N',
                        help='Number of gradient ascent steps')
    parser.add_argument('--hd', type=int, default=128, metavar='N',
                        help='Number of hidden nodes for the MLP model')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate')
    parser.add_argument('--ascent_step_size', type=float, default=0.001, metavar='LR',
                        help='step size of gradient ascent')
    parser.add_argument('--mom', type=float, default=0.99, metavar='M',
                        help='momentum')
    parser.add_argument('--model_dir', default='log',
                        help='path where to save checkpoint')
    parser.add_argument('--one_class_adv', type=int, default=1, metavar='N',
                        help='adv loss to be used or not, 1: use, 0: do not use (only CE)')
    parser.add_argument('--radius', type=float, default=0.2, metavar='N',
                        help='radius corresponding to the definition of set N_i(r)')
    parser.add_argument('--lamda', type=float, default=1, metavar='N',
                        help='Weight to the adversarial loss')
    parser.add_argument('--reg', type=float, default=0, metavar='N',
                        help='weight reg')
    parser.add_argument('--restore', type=int, default=0, metavar='N',
                        help='whether to load a pretrained model, 1: load, 0: train from scratch')
    parser.add_argument('--optim', type=int, default=0, metavar='N',
                        help='0: Adam, 1: SGD')
    parser.add_argument('--gamma', type=float, default=2.0, metavar='N',
                        help='r to gamma * r projection for the set N_i(r)')
    parser.add_argument('-d', '--data_path', type=str, default='.')
    parser.add_argument('--metric', type=str, default='F1')
    args = parser.parse_args()

    # settings
    # Checkpoint store path
    model_dir = args.model_dir
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    main()
@@ -0,0 +1,179 @@
from __future__ import print_function
import numpy as np
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from collections import OrderedDict
from edgeml_pytorch.trainer.drocc_trainer import DROCCTrainer


class LSTM_FC(nn.Module):
    """
    Single-layer LSTM with a fully connected layer
    on the last hidden state
    """
    def __init__(self,
                 input_dim=32,
                 num_classes=1,
                 num_hidden_nodes=8
                 ):
        super(LSTM_FC, self).__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.num_hidden_nodes = num_hidden_nodes
        self.encoder = nn.LSTM(input_size=self.input_dim,
                               hidden_size=self.num_hidden_nodes,
                               num_layers=1, batch_first=True)
        self.fc = nn.Linear(self.num_hidden_nodes,
                            self.num_classes)
        activ = nn.ReLU(True)

    def forward(self, input):
        # the last hidden state of the LSTM is fed to the linear classifier
        features = self.encoder(input)[0][:, -1, :]
        logits = self.fc(features)
        return logits


class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return torch.from_numpy(self.data[idx]), (self.labels[idx]), torch.tensor([0])


def adjust_learning_rate(epoch, total_epochs, only_ce_epochs, learning_rate, optimizer):
    """Adjust learning rate during training.

    Parameters
    ----------
    epoch: Current training epoch.
    total_epochs: Total number of epochs for training.
    only_ce_epochs: Number of epochs for initial pretraining.
    learning_rate: Initial learning rate for training.
    """
    # We don't want to consider the only-CE
    # epochs for the lr scheduler
    epoch = epoch - only_ce_epochs
    drocc_epochs = total_epochs - only_ce_epochs
    # lr = learning_rate
    if epoch <= drocc_epochs:
        lr = learning_rate * 0.1
    if epoch <= 0.50 * drocc_epochs:
        lr = learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return optimizer


def load_data(path):
    train_data = np.load(os.path.join(path, 'train.npy'), allow_pickle=True)
    print("Train data Shape : ", train_data.shape)
    train_lab = np.ones((train_data.shape[0]))  # all training points are labelled positive (normal)
    test_data = np.load(os.path.join(path, 'test_data.npy'), allow_pickle=True)
    test_lab = np.load(os.path.join(path, 'test_labels.npy'), allow_pickle=True)

    # preprocessing: standardize features using the training statistics
    mean = np.mean(train_data, 0)
    std = np.std(train_data, 0)
    train_data = (train_data - mean) / std

    test_data = (test_data - mean) / std
    train_samples = train_data.shape[0]
    test_samples = test_data.shape[0]
    print("Train Samples: ", train_samples)
    print("Test Samples: ", test_samples)
    # add a trailing feature dimension so each series has shape (sequence_length, 1)
    train_data = np.expand_dims(train_data, axis=2)
    test_data = np.expand_dims(test_data, axis=2)
    num_features = train_data.shape[2]
    return CustomDataset(train_data, train_lab), CustomDataset(test_data, test_lab), num_features


def main():
    train_dataset, test_dataset, num_features = load_data(args.data_path)
    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, args.batch_size, shuffle=True)

    model = LSTM_FC(input_dim=1, num_classes=1, num_hidden_nodes=args.hd).to(device)
    if args.optim == 1:
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.mom)
        print("using SGD")
    else:
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr)
        print("using Adam")

    # Training the model
    trainer = DROCCTrainer(model, optimizer, args.lamda, args.radius, args.gamma, device)

    # Restore from checkpoint
    if args.restore == 1:
        if os.path.exists(os.path.join(args.model_dir, 'model.pt')):
            trainer.load(args.model_dir)
            print("Saved Model Loaded")

    trainer.train(train_loader, test_loader, args.lr, adjust_learning_rate, args.epochs,
                  metric=args.metric, ascent_step_size=args.ascent_step_size, only_ce_epochs=args.only_ce_epochs)

    trainer.save(args.model_dir)


if __name__ == '__main__':
    torch.set_printoptions(precision=5)

    parser = argparse.ArgumentParser(description='PyTorch Simple Training')
    parser.add_argument('--batch_size', type=int, default=128, metavar='N',
                        help='batch size for training')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train')
    parser.add_argument('-oce', '--only_ce_epochs', type=int, default=0, metavar='N',
                        help='number of epochs to train with only CE loss')
    parser.add_argument('--ascent_num_steps', type=int, default=50, metavar='N',
                        help='Number of gradient ascent steps')
    parser.add_argument('--hd', type=int, default=128, metavar='N',
                        help='Number of hidden nodes for LSTM model')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate')
    parser.add_argument('--ascent_step_size', type=float, default=0.001, metavar='LR',
                        help='step size of gradient ascent')
    parser.add_argument('--mom', type=float, default=0.99, metavar='M',
                        help='momentum')
    parser.add_argument('--model_dir', default='log',
                        help='path where to save checkpoint')
    parser.add_argument('--one_class_adv', type=int, default=1, metavar='N',
                        help='adv loss to be used or not, 1: use, 0: do not use (only CE)')
    parser.add_argument('--radius', type=float, default=0.2, metavar='N',
                        help='radius corresponding to the definition of set N_i(r)')
    parser.add_argument('--lamda', type=float, default=1, metavar='N',
                        help='Weight to the adversarial loss')
    parser.add_argument('--reg', type=float, default=0, metavar='N',
                        help='weight reg')
    parser.add_argument('--restore', type=int, default=1, metavar='N',
                        help='whether to load a pretrained model, 1: load, 0: train from scratch')
    parser.add_argument('--optim', type=int, default=0, metavar='N',
                        help='0: Adam, 1: SGD')
    parser.add_argument('--gamma', type=float, default=2.0, metavar='N',
                        help='r to gamma * r projection for the set N_i(r)')
    parser.add_argument('-d', '--data_path', type=str, default='.')
    parser.add_argument('--metric', type=str, default='AUC')
    args = parser.parse_args()

    # settings
    # Checkpoint store path
    model_dir = args.model_dir
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    main()
@@ -28,6 +28,8 @@ for these algorithms are in `edgeml_pytorch.trainer`.
4. [S-RNN](https://github.com/microsoft/EdgeML/blob/master/docs/publications/SRNN.pdf): `edgeml_pytorch.graph.rnn.SRNN2` implements a
2-layer SRNN network which can be instantiated with a choice of RNN cell. The training
routine for SRNN is in `edgeml_pytorch.trainer.srnnTrainer`.
5. DROCC: `edgeml_pytorch.trainer.drocc_trainer` implements a meta-trainer for training any given model architecture
for one-class classification on the supplied dataset.

Usage directions and example notebooks for this package are provided [here](https://github.com/microsoft/EdgeML/blob/master/examples/pytorch).
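As a quick orientation, a minimal sketch of how `DROCCTrainer` is wired up, following the pattern of `main_tabular.py` added in this commit (the toy model, loaders, and hyperparameter values here are illustrative assumptions, not prescribed defaults):
```
import torch
import torch.nn as nn
import torch.optim as optim
from edgeml_pytorch.trainer.drocc_trainer import DROCCTrainer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# toy one-class scorer: any model that maps a batch of inputs to one logit per sample works
model = nn.Sequential(nn.Linear(8, 128), nn.ReLU(True), nn.Linear(128, 1)).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# lamda, radius and gamma mirror the --lamda, --radius and --gamma flags of the example scripts
trainer = DROCCTrainer(model, optimizer, lamda=1.0, radius=3.0, gamma=2.0, device=device)

# train_loader / test_loader must yield (data, target, index) triples, with normal points labelled 1,
# and lr_scheduler is a callable like adjust_learning_rate(epoch, total_epochs, only_ce_epochs, lr, optimizer):
# trainer.train(train_loader, test_loader, learning_rate=1e-3, lr_scheduler=adjust_learning_rate,
#               total_epochs=100, only_ce_epochs=50, ascent_step_size=0.001, metric='F1')
# trainer.save(model_dir)
```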
@@ -0,0 +1,211 @@
import os
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support

# Trainer class for DROCC
class DROCCTrainer:
    """
    Trainer class that implements the DROCC algorithm proposed in
    https://arxiv.org/abs/2002.12718
    """

    def __init__(self, model, optimizer, lamda, radius, gamma, device):
        """Initialize the DROCC Trainer class

        Parameters
        ----------
        model: Torch neural network object
        optimizer: Torch optimizer used to update the model parameters.
        lamda: Adversarial loss weight for input layer
        radius: Radius of hypersphere to sample points from.
        gamma: Parameter to vary projection.
        device: torch.device object for device to use.
        """
        self.model = model
        self.optimizer = optimizer
        self.lamda = lamda
        self.radius = radius
        self.gamma = gamma
        self.device = device

    def train(self, train_loader, val_loader, learning_rate, lr_scheduler, total_epochs,
              only_ce_epochs=50, ascent_step_size=0.001, ascent_num_steps=50,
              metric='AUC'):
        """Trains the model on the given training dataset with periodic
        evaluation on the validation dataset.

        Parameters
        ----------
        train_loader: Dataloader object for the training dataset.
        val_loader: Dataloader object for the validation dataset.
        learning_rate: Initial learning rate for training.
        total_epochs: Total number of epochs for training.
        only_ce_epochs: Number of epochs for initial pretraining.
        ascent_step_size: Step size for gradient ascent for adversarial
                          generation of negative points.
        ascent_num_steps: Number of gradient ascent steps for adversarial
                          generation of negative points.
        metric: Metric used for evaluation (AUC / F1).
        """
        self.ascent_num_steps = ascent_num_steps
        self.ascent_step_size = ascent_step_size
        for epoch in range(total_epochs):
            # Make the weights trainable
            self.model.train()
            lr_scheduler(epoch, total_epochs, only_ce_epochs, learning_rate, self.optimizer)

            # Placeholders for the two loss values
            epoch_adv_loss = torch.tensor([0]).type(torch.float32).detach()  # AdvLoss @ Input Layer
            epoch_ce_loss = 0  # Cross entropy Loss

            batch_idx = -1
            for data, target, _ in train_loader:
                batch_idx += 1
                data, target = data.to(self.device), target.to(self.device)
                # Data Processing
                data = data.to(torch.float)
                target = target.to(torch.float)
                target = torch.squeeze(target)

                self.optimizer.zero_grad()

                # Extract the logits for cross entropy loss
                logits = self.model(data)
                logits = torch.squeeze(logits, dim=1)
                ce_loss = F.binary_cross_entropy_with_logits(logits, target)
                # Add to the epoch variable for printing average CE Loss
                epoch_ce_loss += ce_loss

                '''
                Adversarial Loss is calculated only for the positive data points (label == 1).
                '''
                if epoch >= only_ce_epochs:
                    data = data[target == 1]
                    # AdvLoss
                    adv_loss_inp = self.one_class_adv_loss(data)
                    epoch_adv_loss += adv_loss_inp

                    loss = ce_loss + adv_loss_inp * self.lamda
                else:
                    # If only CE based training has to be done
                    loss = ce_loss

                # Backprop
                loss.backward()
                self.optimizer.step()

            epoch_ce_loss = epoch_ce_loss / (batch_idx + 1)  # Average CE Loss
            epoch_adv_loss = epoch_adv_loss / (batch_idx + 1)  # Average AdvLoss @ Input Layer

            test_score = self.test(val_loader, metric)

            print('Epoch: {}, CE Loss: {}, AdvLoss: {}, {}: {}'.format(
                epoch, epoch_ce_loss.item(), epoch_adv_loss.item(),
                metric, test_score))

    def test(self, test_loader, metric):
        """Evaluate the model on the given test dataset.

        Parameters
        ----------
        test_loader: Dataloader object for the test dataset.
        metric: Metric used for evaluation (AUC / F1).
        """
        self.model.eval()
        label_score = []
        batch_idx = -1
        for data, target, _ in test_loader:
            batch_idx += 1
            data, target = data.to(self.device), target.to(self.device)
            data = data.to(torch.float)
            target = target.to(torch.float)
            target = torch.squeeze(target)

            logits = self.model(data)
            logits = torch.squeeze(logits, dim=1)
            sigmoid_logits = torch.sigmoid(logits)
            scores = sigmoid_logits
            label_score += list(zip(target.cpu().data.numpy().tolist(),
                                    scores.cpu().data.numpy().tolist()))
        # Compute test score
        labels, scores = zip(*label_score)
        labels = np.array(labels)
        scores = np.array(scores)
        if metric == 'F1':
            # Evaluation based on https://openreview.net/forum?id=BJJLHbb0-
            thresh = np.percentile(scores, 20)
            y_pred = np.where(scores >= thresh, 1, 0)
            prec, recall, test_metric, _ = precision_recall_fscore_support(
                labels, y_pred, average="binary")
        if metric == 'AUC':
            test_metric = roc_auc_score(labels, scores)
        return test_metric

    def one_class_adv_loss(self, x_train_data):
        """Computes the adversarial loss:
        1) Sample points initially at random around the positive training
           data points
        2) Run gradient ascent to find points in the set N_i(r) that the
           current model classifies as positive; this is done by maximizing
           the CE loss with respect to label 0
        3) Project the points between the spheres of radius R and gamma * R
           (the set N_i(r))
        4) Pass the calculated adversarial points through the model,
           and calculate the CE loss with respect to target class 0

        Parameters
        ----------
        x_train_data: Batch of data to compute loss on.
        """
        batch_size = len(x_train_data)
        # Randomly sample points around the training data
        # We will perform SGD on these to find the adversarial points
        x_adv = torch.randn(x_train_data.shape).to(self.device).detach().requires_grad_()
        x_adv_sampled = x_adv + x_train_data

        for step in range(self.ascent_num_steps):
            with torch.enable_grad():

                new_targets = torch.zeros(batch_size, 1).to(self.device)
                new_targets = torch.squeeze(new_targets)
                new_targets = new_targets.to(torch.float)

                logits = self.model(x_adv_sampled)
                logits = torch.squeeze(logits, dim=1)
                new_loss = F.binary_cross_entropy_with_logits(logits, new_targets)

                grad = torch.autograd.grad(new_loss, [x_adv_sampled])[0]
                grad_norm = torch.norm(grad, p=2, dim=tuple(range(1, grad.dim())))
                grad_norm = grad_norm.view(-1, *[1] * (grad.dim() - 1))
                grad_normalized = grad / grad_norm
            with torch.no_grad():
                x_adv_sampled.add_(self.ascent_step_size * grad_normalized)

            if (step + 1) % 10 == 0:
                # Project the adversarial points to the set N_i(r)
                h = x_adv_sampled - x_train_data
                norm_h = torch.sqrt(torch.sum(h ** 2,
                                              dim=tuple(range(1, h.dim()))))
                alpha = torch.clamp(norm_h, self.radius,
                                    self.gamma * self.radius).to(self.device)
                # Make use of broadcast to project h
                proj = (alpha / norm_h).view(-1, *[1] * (h.dim() - 1))
                h = proj * h
                x_adv_sampled = x_train_data + h  # the adversarial points now lie in the shell between radius R and gamma * R

        adv_pred = self.model(x_adv_sampled)
        adv_pred = torch.squeeze(adv_pred, dim=1)
        adv_loss = F.binary_cross_entropy_with_logits(adv_pred, (new_targets * 0))

        return adv_loss

    def save(self, path):
        torch.save(self.model.state_dict(), os.path.join(path, 'model.pt'))

    def load(self, path):
        self.model.load_state_dict(torch.load(os.path.join(path, 'model.pt')))