Zhirong Wu 2018-08-14 15:57:15 +08:00
Commit 01d8f905df
16 changed files: 1195 additions and 0 deletions

10
.gitignore vendored Normal file

@@ -0,0 +1,10 @@
data/*
checkpoint/*
logs/*
others/*
*.pyc
*.bak
*.log
*.tar
*.pth

66
README.md Normal file

@@ -0,0 +1,66 @@
## Improving Generalization via Scalable Neighborhood Component Analysis
This repo contains the PyTorch implementation for the ECCV 2018 paper [(arxiv)](https://arxiv.org/pdf/.pdf).
The project learns deep feature representations optimized for nearest neighbor classifiers, which may generalize to new object categories.
Much of the code is borrowed from the previous [unsupervised learning project](https://arxiv.org/pdf/1805.01978.pdf).
Please refer to [this repo](https://github.com/zhirongw/lemniscate.pytorch) for more details.
## Pretrained Model
Currently, we provide three pretrained ResNet models.
Each release contains the feature representations of all ImageNet training images (600 MB) and the model weights (100-200 MB).
You can also reproduce these representations by forwarding the network over all ImageNet training images.
- [ResNet 18](http://zhirongw.westus2.cloudapp.azure.com/models/snca_resnet18.pth.tar) (top 1 accuracy 70.59%)
- [ResNet 34](http://zhirongw.westus2.cloudapp.azure.com/models/snca_resnet34.pth.tar) (top 1 accuracy 74.41%)
- [ResNet 50](http://zhirongw.westus2.cloudapp.azure.com/models/snca_resnet50.pth.tar) (top 1 accuracy 76.57%)
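For example, a released checkpoint can be loaded as follows. This is a minimal sketch that assumes the file follows the dictionary layout written by `save_checkpoint` in `main.py`, and that the weights were saved from a `DataParallel` wrapper (hence the `module.` prefix stripping):

```
import torch
import models

# Assumption: the release follows the checkpoint layout used by save_checkpoint in main.py.
checkpoint = torch.load('snca_resnet18.pth.tar')
model = models.__dict__[checkpoint['arch']](low_dim=128)
# Assumption: weights were saved from a DataParallel model, so keys carry a 'module.' prefix.
state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
model.load_state_dict(state_dict)
model.eval()

# The stored memory bank holds one low-dimensional feature per ImageNet training image.
train_features = checkpoint['lemniscate'].memory  # shape: (num_train_images, 128)
```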
## Nearest Neighbor
Please follow [this link](http://zhirongw.westus2.cloudapp.azure.com/nn.html) for a list of nearest neighbors on ImageNet.
Results are visualized using our ResNet50 features, compared with baseline ResNet50 features, raw image features, and supervised features.
The first column is the query image, followed by 20 retrievals ranked by similarity.
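As a rough sketch of how such a list is produced: since all features are L2-normalized, retrieval reduces to an inner product with the stored training features (`train_features` and `query_feature` below are placeholder names):

```
import torch

# Placeholder tensors: train_features is (N, 128) from the memory bank,
# query_feature is the (128,) L2-normalized embedding of the query image.
scores = torch.mv(train_features, query_feature)  # cosine similarity to every training image
top_scores, top_indices = scores.topk(20)         # the 20 nearest neighbors shown per row
```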
## Usage
Our code extends the PyTorch ImageNet classification example from the [official PyTorch examples](https://github.com/pytorch/examples/tree/master/imagenet).
Please refer to the official repo for details of data preparation and hardware configurations.
- Install Python 2 and [PyTorch 0.3](http://pytorch.org)
- Clone this repo: `git clone https://github.com/zhirongw/snca.pytorch`
- Training on ImageNet:
`python main.py DATAPATH --arch resnet18 -j 32 --temperature 0.05 --low-dim 128 -b 256`
- During training, we monitor the supervised validation accuracy with a k-nearest-neighbor classifier using k=1, since it is faster and gives a good estimate of feature quality.
- Testing on ImageNet:
`python main.py DATAPATH --arch resnet18 --resume input_model.pth.tar -e` runs testing with the default K=30 neighbors.
- Training on CIFAR10:
`python cifar.py --temperature 0.05 --lr 0.1`
## Citation
```
@inproceedings{wu2018improving,
title={Improving Generalization via Scalable Neighborhood Component Analysis},
author={Wu, Zhirong and Efros, Alexei A and Yu, Stella},
booktitle={European Conference on Computer Vision (ECCV) 2018},
year={2018}
}
```
## Contact
For any questions, please feel free to reach out:
```
Zhirong Wu: xavibrowu@gmail.com
```

165
cifar.py Normal file

@@ -0,0 +1,165 @@
'''Train CIFAR10 with PyTorch.'''
from __future__ import print_function
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import argparse
import time
import models
import datasets
import math
from lib.LinearAverage import LinearAverage
from lib.NCA import NCACrossEntropy
from lib.utils import AverageMeter
from test import NN, kNN
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', default='', type=str, help='resume from checkpoint')
parser.add_argument('--test-only', action='store_true', help='test only')
parser.add_argument('--low-dim', default=128, type=int,
metavar='D', help='feature dimension')
parser.add_argument('--temperature', default=0.05, type=float,
metavar='T', help='temperature parameter for softmax')
parser.add_argument('--memory-momentum', default=0.5, type=float,
metavar='M', help='momentum for non-parametric updates')
args = parser.parse_args()
use_cuda = torch.cuda.is_available()
best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
#transforms.RandomCrop(32, padding=4),
transforms.RandomResizedCrop(size=32, scale=(0.2,1.)),
transforms.RandomGrayscale(p=0.2),
transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = datasets.CIFAR10Instance(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = datasets.CIFAR10Instance(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
ndata = trainset.__len__()
# Model
if args.test_only or len(args.resume)>0:
# Load checkpoint.
print('==> Resuming from checkpoint..')
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
checkpoint = torch.load('./checkpoint/'+args.resume)
net = checkpoint['net']
lemniscate = checkpoint['lemniscate']
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']
else:
print('==> Building model..')
net = models.__dict__['ResNet18'](low_dim=args.low_dim)
# define lemniscate (non-parametric memory bank)
lemniscate = LinearAverage(args.low_dim, ndata, args.temperature, args.memory_momentum)
# define loss function
criterion = NCACrossEntropy(torch.LongTensor(trainloader.dataset.train_labels))
if use_cuda:
net.cuda()
net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
lemniscate.cuda()
criterion.cuda()
cudnn.benchmark = True
if args.test_only:
acc = kNN(0, net, lemniscate, trainloader, testloader, 30, args.temperature)
sys.exit(0)
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)
def adjust_learning_rate(optimizer, epoch):
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
lr = args.lr * (0.1 ** (epoch // 50))
print(lr)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# Training
def train(epoch):
print('\nEpoch: %d' % epoch)
adjust_learning_rate(optimizer, epoch)
train_loss = AverageMeter()
data_time = AverageMeter()
batch_time = AverageMeter()
correct = 0
total = 0
# switch to train mode
net.train()
end = time.time()
for batch_idx, (inputs, targets, indexes) in enumerate(trainloader):
data_time.update(time.time() - end)
if use_cuda:
inputs, targets, indexes = inputs.cuda(), targets.cuda(), indexes.cuda()
optimizer.zero_grad()
features = net(inputs)
outputs = lemniscate(features, indexes)
loss = criterion(outputs, indexes)
loss.backward()
optimizer.step()
train_loss.update(loss.item(), inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
print('Epoch: [{}][{}/{}] '
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
'Data: {data_time.val:.3f} ({data_time.avg:.3f}) '
'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f})'.format(
epoch, batch_idx, len(trainloader), batch_time=batch_time, data_time=data_time, train_loss=train_loss))
for epoch in range(start_epoch, start_epoch+200):
train(epoch)
acc = kNN(epoch, net, lemniscate, trainloader, testloader, 30, args.temperature)
if acc > best_acc:
print('Saving..')
state = {
'net': net.module if use_cuda else net,
'lemniscate': lemniscate,
'acc': acc,
'epoch': epoch,
}
if not os.path.isdir('checkpoint'):
os.mkdir('checkpoint')
torch.save(state, './checkpoint/ckpt.t7')
best_acc = acc
print('best accuracy: {:.2f}'.format(best_acc*100))

5
datasets/__init__.py Normal file

@@ -0,0 +1,5 @@
from .folder import ImageFolderInstance
from .cifar import CIFAR10Instance, CIFAR100Instance
__all__ = ('ImageFolderInstance', 'CIFAR10Instance', 'CIFAR100Instance')

42
datasets/cifar.py Normal file

@@ -0,0 +1,42 @@
from __future__ import print_function
from PIL import Image
import torchvision.datasets as datasets
import torch.utils.data as data
class CIFAR10Instance(datasets.CIFAR10):
"""CIFAR10Instance Dataset.
"""
def __getitem__(self, index):
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target, index
class CIFAR100Instance(CIFAR10Instance):
"""CIFAR100Instance Dataset.
This is a subclass of the `CIFAR10Instance` Dataset.
"""
base_folder = 'cifar-100-python'
url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
filename = "cifar-100-python.tar.gz"
tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
train_list = [
['train', '16019d7e3df5f24257cddd939b257f8d'],
]
test_list = [
['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
]

21
datasets/folder.py Normal file

@@ -0,0 +1,21 @@
import torchvision.datasets as datasets
class ImageFolderInstance(datasets.ImageFolder):
""": Folder datasets which returns the index of the image as well::
"""
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target, index) where target is the class index of the target class and index is the position of the image in the dataset.
"""
path, target = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target, index

58
lib/LinearAverage.py Normal file

@@ -0,0 +1,58 @@
import torch
from torch.autograd import Function
from torch import nn
import math
class LinearAverageOp(Function):
@staticmethod
def forward(self, x, y, memory, params):
T = params[0].item()
batchSize = x.size(0)
# inner product
out = torch.mm(x.data, memory.t())
out.div_(T) # batchSize * N
self.save_for_backward(x, memory, y, params)
return out
@staticmethod
def backward(self, gradOutput):
x, memory, y, params = self.saved_tensors
batchSize = gradOutput.size(0)
T = params[0].item()
momentum = params[1].item()
# apply temperature scaling to the gradients
gradOutput.data.div_(T)
# gradient of linear
gradInput = torch.mm(gradOutput.data, memory)
gradInput.resize_as_(x)
# update the non-parametric data
weight_pos = memory.index_select(0, y.data.view(-1)).resize_as_(x)
weight_pos.mul_(momentum)
weight_pos.add_(torch.mul(x.data, 1-momentum))
w_norm = weight_pos.pow(2).sum(1, keepdim=True).pow(0.5)
updated_weight = weight_pos.div(w_norm)
memory.index_copy_(0, y, updated_weight)
return gradInput, None, None, None
class LinearAverage(nn.Module):
def __init__(self, inputSize, outputSize, T=0.05, momentum=0.5):
super(LinearAverage, self).__init__()
self.nLem = outputSize
self.register_buffer('params', torch.tensor([T, momentum]))
stdv = 1. / math.sqrt(inputSize / 3)
self.register_buffer('memory', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))
def forward(self, x, y):
out = LinearAverageOp.apply(x, y, self.memory, self.params)
return out
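# Usage sketch (illustrative, mirroring cifar.py / main.py):
#   lemniscate = LinearAverage(args.low_dim, ndata, args.temperature, args.memory_momentum)
#   similarities = lemniscate(features, indexes)  # (batchSize, ndata) inner products scaled by 1/T
# During backward, the memory rows at `indexes` are updated by a momentum average with the new
# features and re-normalized to unit length, so inner products stay cosine similarities.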

44
lib/NCA.py Normal file

@@ -0,0 +1,44 @@
import torch
from torch import nn
from torch.autograd import Function
import math
eps = 1e-8
class NCACrossEntropy(nn.Module):
''' NCA cross entropy: loss = -(1/B) \sum_i \log( \sum_{j \in C_i, j \neq i} p_{ij} ).
Stores all the labels of the dataset;
only the indexes of the training instances are passed during forward.
'''
def __init__(self, labels, margin=0):
super(NCACrossEntropy, self).__init__()
self.register_buffer('labels', torch.LongTensor(labels.size(0)))
self.labels = labels
self.margin = margin
def forward(self, x, indexes):
batchSize = x.size(0)
n = x.size(1)
exp = torch.exp(x)
# labels for the current batch
y = torch.index_select(self.labels, 0, indexes.data).view(batchSize, 1)
same = y.repeat(1, n).eq_(self.labels)
# exclude the self probability by zeroing the self entries (in-place on the memory for efficiency)
exp.data.scatter_(1, indexes.data.view(-1,1), 0)
p = torch.mul(exp, same.float()).sum(dim=1)
Z = exp.sum(dim=1)
Z_exclude = Z - p
p = p.div(math.exp(self.margin))
Z = Z_exclude + p
prob = torch.div(p, Z)
prob_masked = torch.masked_select(prob, prob.ne(0))
loss = prob_masked.log().sum(0)
return - loss / batchSize
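# Usage sketch (illustrative, mirroring cifar.py): the criterion is built from the full label list
# and consumes the memory-bank similarities together with the dataset indices of the batch:
#   criterion = NCACrossEntropy(torch.LongTensor(trainloader.dataset.train_labels))
#   outputs = lemniscate(features, indexes)
#   loss = criterion(outputs, indexes)
# With margin=0 this computes -(1/B) * sum_i log( sum_{j in class(i), j != i} exp(s_ij/T) / sum_{k != i} exp(s_ik/T) ).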

1
lib/__init__.py Normal file

@@ -0,0 +1 @@
# nothing

14
lib/normalize.py Normal file

@@ -0,0 +1,14 @@
import torch
from torch.autograd import Variable
from torch import nn
class Normalize(nn.Module):
def __init__(self, power=2):
super(Normalize, self).__init__()
self.power = power
def forward(self, x):
norm = x.pow(self.power).sum(1, keepdim=True).pow(1./self.power)
out = x.div(norm)
return out

16
lib/utils.py Normal file

@@ -0,0 +1,16 @@
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
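# Usage sketch (illustrative):
#   losses = AverageMeter()
#   losses.update(0.7, n=128)  # val=0.7, avg=0.7
#   losses.update(0.5, n=128)  # val=0.5, avg=0.6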

281
main.py Normal file

@@ -0,0 +1,281 @@
import argparse
import os
import sys
import shutil
import time
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import datasets
import models
import math
from lib.LinearAverage import LinearAverage
from lib.NCA import NCACrossEntropy
from lib.utils import AverageMeter
from test import NN, kNN
model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name]))
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
choices=model_names,
help='model architecture: ' +
' | '.join(model_names) +
' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=130, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('--test-only', action='store_true', help='test only')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
parser.add_argument('--world-size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='gloo', type=str,
help='distributed backend')
parser.add_argument('--low-dim', default=128, type=int,
metavar='D', help='feature dimension')
parser.add_argument('--temperature', default=0.05, type=float,
metavar='T', help='temperature parameter')
parser.add_argument('--memory-momentum', '--m-mementum', default=0.5, type=float,
metavar='M', help='momentum for non-parametric updates')
parser.add_argument('--iter-size', default=1, type=int,
help='caffe style iter size')
parser.add_argument('--margin', default=0.0, type=float,
help='classification margin')
best_prec1 = 0
def main():
global args, best_prec1
args = parser.parse_args()
args.distributed = args.world_size > 1
if args.distributed:
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size)
# Data loading code
traindir = os.path.join(args.data, 'train')
valdir = os.path.join(args.data, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
train_dataset = datasets.ImageFolderInstance(
traindir,
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
else:
train_sampler = None
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=args.workers, pin_memory=True, sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(
datasets.ImageFolderInstance(valdir, transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])),
batch_size=args.batch_size, shuffle=False,
num_workers=args.workers, pin_memory=True)
# create model
if args.pretrained:
print("=> using pre-trained model '{}'".format(args.arch))
model = models.__dict__[args.arch](pretrained=True)
else:
print("=> creating model '{}'".format(args.arch))
model = models.__dict__[args.arch](low_dim=args.low_dim)
if not args.distributed:
if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
model.features = torch.nn.DataParallel(model.features)
model.cuda()
else:
model = torch.nn.DataParallel(model).cuda()
else:
model.cuda()
model = torch.nn.parallel.DistributedDataParallel(model)
optimizer = torch.optim.SGD(model.parameters(), args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay, nesterov=True)
# optionally resume from a checkpoint
if args.resume:
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume)
args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
model.load_state_dict(checkpoint['state_dict'])
lemniscate = checkpoint['lemniscate']
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
else:
print("=> no checkpoint found at '{}'".format(args.resume))
else:
# define lemniscate and loss function (criterion)
ndata = train_dataset.__len__()
lemniscate = LinearAverage(args.low_dim, ndata, args.temperature, args.memory_momentum).cuda()
criterion = NCACrossEntropy(torch.LongTensor([y for (p, y) in train_loader.dataset.imgs]),
args.margin / args.temperature).cuda()
cudnn.benchmark = True
if args.evaluate:
prec1 = kNN(0, model, lemniscate, train_loader, val_loader, 30, args.temperature, 0)
return
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
adjust_learning_rate(optimizer, epoch)
adjust_memory_update_rate(lemniscate, epoch)
# train for one epoch
train(train_loader, model, lemniscate, criterion, optimizer, epoch)
# evaluate on validation set
prec1 = NN(epoch, model, lemniscate, train_loader, val_loader)
# remember best prec@1 and save checkpoint
is_best = prec1 > best_prec1
best_prec1 = max(prec1, best_prec1)
save_checkpoint({
'epoch': epoch + 1,
'arch': args.arch,
'state_dict': model.state_dict(),
'lemniscate': lemniscate,
'best_prec1': best_prec1,
'optimizer' : optimizer.state_dict(),
}, is_best)
def train(train_loader, model, lemniscate, criterion, optimizer, epoch):
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
# switch to train mode
model.train()
end = time.time()
optimizer.zero_grad()
for i, (input, target, index) in enumerate(train_loader):
# measure data loading time
data_time.update(time.time() - end)
target = target.cuda(async=True)
index = index.cuda(async=True)
input_var = torch.autograd.Variable(input)
target_var = torch.autograd.Variable(target)
index_var = torch.autograd.Variable(index)
# compute output
feature = model(input_var)
output = lemniscate(feature, index_var)
loss = criterion(output, index_var) / args.iter_size
loss.backward()
# measure accuracy and record loss
losses.update(loss.data[0] * args.iter_size, input.size(0))
if (i+1) % args.iter_size == 0:
# compute gradient and do SGD step
optimizer.step()
optimizer.zero_grad()
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % args.print_freq == 0:
print('Epoch: [{0}][{1}/{2}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
epoch, i, len(train_loader), batch_time=batch_time,
data_time=data_time, loss=losses))
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, 'model_best.pth.tar')
def adjust_memory_update_rate(lemniscate, epoch):
if epoch >= 80:
lemniscate.params[1] = 0.8
if epoch >= 120:
lemniscate.params[1] = 0.9
def adjust_learning_rate(optimizer, epoch):
"""Sets the learning rate to the initial LR decayed by 10 every 40 epochs"""
lr = args.lr * (0.1 ** (epoch // 40))
print(lr)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
if __name__ == '__main__':
main()

2
models/__init__.py Normal file

@@ -0,0 +1,2 @@
from .resnet import *
from .resnet_cifar import *

208
models/resnet.py Normal file

@@ -0,0 +1,208 @@
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
from lib.normalize import Normalize
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152']
model_urls = { }
def conv3x3(in_planes, out_planes, stride=1):
"3x3 convolution with padding"
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers, low_dim=128):
self.inplanes = 64
super(ResNet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AvgPool2d(7, stride=1)
self.fc = nn.Linear(512 * block.expansion, low_dim)
self.l2norm = Normalize(2)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
x = self.l2norm(x)
return x
def resnet18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
return model
def resnet34(pretrained=False, **kwargs):
"""Constructs a ResNet-34 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
return model
def resnet50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
return model
def resnet101(pretrained=False, **kwargs):
"""Constructs a ResNet-101 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
return model
def resnet152(pretrained=False, **kwargs):
"""Constructs a ResNet-152 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
return model
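# Usage sketch (illustrative): each constructor takes the embedding dimension used in this project,
# and the forward pass returns L2-normalized low-dimensional features instead of class logits:
#   net = resnet18(low_dim=128)
#   feat = net(torch.randn(2, 3, 224, 224))  # shape (2, 128), each row has unit L2 norm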

126
models/resnet_cifar.py Normal file

@@ -0,0 +1,126 @@
'''ResNet in PyTorch.
For Pre-activation ResNet, see 'preact_resnet.py'.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from lib.normalize import Normalize
from torch.autograd import Variable
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, in_planes, planes, stride=1):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(self.expansion*planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, num_blocks, low_dim=128):
super(ResNet, self).__init__()
self.in_planes = 64
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.linear = nn.Linear(512*block.expansion, low_dim)
self.l2norm = Normalize(2)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1]*(num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = F.avg_pool2d(out, 4)
out = out.view(out.size(0), -1)
out = self.linear(out)
out = self.l2norm(out)
return out
def ResNet18(low_dim=128):
return ResNet(BasicBlock, [2,2,2,2], low_dim)
def ResNet34(low_dim=128):
return ResNet(BasicBlock, [3,4,6,3], low_dim)
def ResNet50(low_dim=128):
return ResNet(Bottleneck, [3,4,6,3], low_dim)
def ResNet101(low_dim=128):
return ResNet(Bottleneck, [3,4,23,3], low_dim)
def ResNet152(low_dim=128):
return ResNet(Bottleneck, [3,8,36,3], low_dim)
def test():
net = ResNet18()
y = net(Variable(torch.randn(1,3,32,32)))
print(y.size())
# test()

136
test.py Normal file

@@ -0,0 +1,136 @@
import torch
import time
import datasets
from lib.utils import AverageMeter
import torchvision.transforms as transforms
import numpy as np
def NN(epoch, net, lemniscate, trainloader, testloader, recompute_memory=0):
net.eval()
net_time = AverageMeter()
cls_time = AverageMeter()
losses = AverageMeter()
correct = 0.
total = 0
testsize = testloader.dataset.__len__()
trainFeatures = lemniscate.memory.t()
if hasattr(trainloader.dataset, 'imgs'):
trainLabels = torch.LongTensor([y for (p, y) in trainloader.dataset.imgs]).cuda()
else:
trainLabels = torch.LongTensor(trainloader.dataset.train_labels).cuda()
if recompute_memory:
transform_bak = trainloader.dataset.transform
trainloader.dataset.transform = testloader.dataset.transform
temploader = torch.utils.data.DataLoader(trainloader.dataset, batch_size=100, shuffle=False, num_workers=1)
for batch_idx, (inputs, targets, indexes) in enumerate(temploader):
targets = targets.cuda(async=True)
batchSize = inputs.size(0)
features = net(inputs)
trainFeatures[:, batch_idx*batchSize:batch_idx*batchSize+batchSize] = features.data.t()
trainLabels = torch.LongTensor(temploader.dataset.train_labels).cuda()
trainloader.dataset.transform = transform_bak
end = time.time()
with torch.no_grad():
for batch_idx, (inputs, targets, indexes) in enumerate(testloader):
targets = targets.cuda(async=True)
batchSize = inputs.size(0)
features = net(inputs)
net_time.update(time.time() - end)
end = time.time()
dist = torch.mm(features, trainFeatures)
yd, yi = dist.topk(1, dim=1, largest=True, sorted=True)
candidates = trainLabels.view(1,-1).expand(batchSize, -1)
retrieval = torch.gather(candidates, 1, yi)
retrieval = retrieval.narrow(1, 0, 1).clone().view(-1)
yd = yd.narrow(1, 0, 1)
total += targets.size(0)
correct += retrieval.eq(targets.data).sum().item()
cls_time.update(time.time() - end)
end = time.time()
print('Test [{}/{}]\t'
'Net Time {net_time.val:.3f} ({net_time.avg:.3f})\t'
'Cls Time {cls_time.val:.3f} ({cls_time.avg:.3f})\t'
'Top1: {:.2f}'.format(
total, testsize, correct*100./total, net_time=net_time, cls_time=cls_time))
return correct/total
def kNN(epoch, net, lemniscate, trainloader, testloader, K, sigma, recompute_memory=0):
net.eval()
net_time = AverageMeter()
cls_time = AverageMeter()
total = 0
testsize = testloader.dataset.__len__()
trainFeatures = lemniscate.memory.t()
if hasattr(trainloader.dataset, 'imgs'):
trainLabels = torch.LongTensor([y for (p, y) in trainloader.dataset.imgs]).cuda()
else:
trainLabels = torch.LongTensor(trainloader.dataset.train_labels).cuda()
C = trainLabels.max() + 1
if recompute_memory:
transform_bak = trainloader.dataset.transform
trainloader.dataset.transform = testloader.dataset.transform
temploader = torch.utils.data.DataLoader(trainloader.dataset, batch_size=100, shuffle=False, num_workers=1)
for batch_idx, (inputs, targets, indexes) in enumerate(temploader):
targets = targets.cuda(async=True)
batchSize = inputs.size(0)
features = net(inputs)
trainFeatures[:, batch_idx*batchSize:batch_idx*batchSize+batchSize] = features.data.t()
trainLabels = torch.LongTensor(temploader.dataset.train_labels).cuda()
trainloader.dataset.transform = transform_bak
top1 = 0.
top5 = 0.
end = time.time()
with torch.no_grad():
retrieval_one_hot = torch.zeros(K, C).cuda()
for batch_idx, (inputs, targets, indexes) in enumerate(testloader):
end = time.time()
targets = targets.cuda(async=True)
batchSize = inputs.size(0)
features = net(inputs)
net_time.update(time.time() - end)
end = time.time()
dist = torch.mm(features, trainFeatures)
yd, yi = dist.topk(K, dim=1, largest=True, sorted=True)
candidates = trainLabels.view(1,-1).expand(batchSize, -1)
retrieval = torch.gather(candidates, 1, yi)
retrieval_one_hot.resize_(batchSize * K, C).zero_()
retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1)
yd_transform = yd.clone().div_(sigma).exp_()
probs = torch.sum(torch.mul(retrieval_one_hot.view(batchSize, -1 , C), yd_transform.view(batchSize, -1, 1)), 1)
_, predictions = probs.sort(1, True)
# Find which predictions match the target
correct = predictions.eq(targets.data.view(-1,1))
cls_time.update(time.time() - end)
top1 = top1 + correct.narrow(1,0,1).sum().item()
top5 = top5 + correct.narrow(1,0,5).sum().item()
total += targets.size(0)
print('Test [{}/{}]\t'
'Net Time {net_time.val:.3f} ({net_time.avg:.3f})\t'
'Cls Time {cls_time.val:.3f} ({cls_time.avg:.3f})\t'
'Top1: {:.2f} Top5: {:.2f}'.format(
total, testsize, top1*100./total, top5*100./total, net_time=net_time, cls_time=cls_time))
print(top1*100./total)
return top1/total
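# Classification sketch (illustrative): kNN() weights each of the K retrieved neighbors by
# exp(similarity / sigma) and accumulates the weights per class, i.e.
#   score(c) = sum over the top-K neighbors k with label c of exp(s_k / sigma)
# The highest-scoring classes give the top-1 / top-5 predictions.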