This commit is contained in:
Коммит
01d8f905df
|
@ -0,0 +1,10 @@
|
|||
data/*
|
||||
checkpoint/*
|
||||
logs/*
|
||||
others/*
|
||||
|
||||
*.pyc
|
||||
*.bak
|
||||
*.log
|
||||
*.tar
|
||||
*.pth
|
|
@ -0,0 +1,66 @@
|
|||
## Improving Generalization via Scalable Neighborhood Component Analysis
|
||||
|
||||
This repo contains the pytorch implementation for the ECCV2018 paper [(arxiv)](https://arxiv.org/pdf/1808.04699.pdf).
|
||||
The project is about deep learning feature representations optimized for
|
||||
nearest neighbor classifiers, which may generalize to new object categories.
|
||||
|
||||
Much of code is borrowed from the previous [unsupervised learning project](https://arxiv.org/pdf/1805.01978.pdf).
|
||||
Please refer to [this repo](https://github.com/zhirongw/lemniscate.pytorch) for more details.
|
||||
|
||||
|
||||
## Pretrained Model
|
||||
|
||||
Currently, we provide 3 pretrained ResNet models.
|
||||
Each release contains the feature representation of all ImageNet training images (600 mb) and model weights (100-200mb).
|
||||
You can also get these representations by forwarding the network for the entire ImageNet images.
|
||||
|
||||
- [ResNet 18](http://zhirongw.westus2.cloudapp.azure.com/models/snca_resnet18.pth.tar) (top 1 accuracy 70.59%)
|
||||
- [ResNet 34](http://zhirongw.westus2.cloudapp.azure.com/models/snca_resnet34.pth.tar) (top 1 accuracy 74.41%)
|
||||
- [ResNet 50](http://zhirongw.westus2.cloudapp.azure.com/models/snca_resnet50.pth.tar) (top 1 accuracy 76.57%)
|
||||
|
||||
## Nearest Neighbor
|
||||
|
||||
Please follow [this link](http://zhirongw.westus2.cloudapp.azure.com/nn.html) for a list of nearest neighbors on ImageNet.
|
||||
Results are visualized from our ResNet50 feature, compared with baseline ResNet50 feature, raw image features and supervised features.
|
||||
First column is the query image, followed by 20 retrievals ranked by the similarity.
|
||||
|
||||
## Usage
|
||||
|
||||
Our code extends the pytorch implementation of imagenet classification in [official pytorch release](https://github.com/pytorch/examples/tree/master/imagenet).
|
||||
Please refer to the official repo for details of data preparation and hardware configurations.
|
||||
|
||||
- install python2 and [pytorch=0.3](http://pytorch.org)
|
||||
|
||||
- clone this repo: `git clone https://github.com/zhirongw/snca.pytorch`
|
||||
|
||||
- Training on ImageNet:
|
||||
|
||||
`python main.py DATAPATH --arch resnet18 -j 32 --temperature 0.05 --low-dim 128 -b 256 `
|
||||
|
||||
- During training, we monitor the supervised validation accuracy by K nearest neighbor with k=1, as it's faster, and gives a good estimation of the feature quality.
|
||||
|
||||
- Testing on ImageNet:
|
||||
|
||||
`python main.py DATAPATH --arch resnet18 --resume input_model.pth.tar -e` runs testing with default K=30 neighbors.
|
||||
|
||||
- Training on CIFAR10:
|
||||
|
||||
`python cifar.py --nce-t 0.05 --lr 0.1`
|
||||
|
||||
|
||||
## Citation
|
||||
```
|
||||
@inproceedings{wu2018improving,
|
||||
title={Improving Generalization via Scalable Neighborhood Component Analysis},
|
||||
author={Wu, Zhirong and Efros, Alexei A and Yu, Stella},
|
||||
booktitle={European Conference on Computer Vision (ECCV) 2018},
|
||||
year={2018}
|
||||
}
|
||||
```
|
||||
|
||||
## Contact
|
||||
|
||||
For any questions, please feel free to reach out:
|
||||
```
|
||||
Zhirong Wu: xavibrowu@gmail.com
|
||||
```
|
|
@ -0,0 +1,165 @@
|
|||
'''Train CIFAR10 with PyTorch.'''
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
import torch.backends.cudnn as cudnn
|
||||
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import models
|
||||
import datasets
|
||||
import math
|
||||
|
||||
from lib.LinearAverage import LinearAverage
|
||||
from lib.NCA import NCACrossEntropy
|
||||
from lib.utils import AverageMeter
|
||||
from test import NN, kNN
|
||||
|
||||
# ---- Command-line configuration -------------------------------------------
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', default='', type=str, help='resume from checkpoint')
parser.add_argument('--test-only', action='store_true', help='test only')
parser.add_argument('--low-dim', default=128, type=int,
                    metavar='D', help='feature dimension')
parser.add_argument('--temperature', default=0.05, type=float,
                    metavar='T', help='temperature parameter for softmax')
parser.add_argument('--memory-momentum', default=0.5, type=float,
                    metavar='M', help='momentum for non-parametric updates')

args = parser.parse_args()

use_cuda = torch.cuda.is_available()
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# ---- Data ------------------------------------------------------------------
print('==> Preparing data..')
# Heavy augmentation (random resized crop, grayscale, color jitter).
transform_train = transforms.Compose([
    #transforms.RandomCrop(32, padding=4),
    transforms.RandomResizedCrop(size=32, scale=(0.2,1.)),
    transforms.RandomGrayscale(p=0.2),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# CIFAR10Instance also returns each sample's index (used to address the
# non-parametric feature memory).
trainset = datasets.CIFAR10Instance(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = datasets.CIFAR10Instance(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
ndata = trainset.__len__()

# ---- Model -----------------------------------------------------------------
if args.test_only or len(args.resume)>0:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/'+args.resume)
    net = checkpoint['net']
    lemniscate = checkpoint['lemniscate']
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
else:
    print('==> Building model..')
    net = models.__dict__['ResNet18'](low_dim=args.low_dim)
    # define lemniscate: one memory slot per training image
    lemniscate = LinearAverage(args.low_dim, ndata, args.temperature, args.memory_momentum)

# define loss function: NCA cross-entropy over the full training label set
criterion = NCACrossEntropy(torch.LongTensor(trainloader.dataset.train_labels))

if use_cuda:
    net.cuda()
    net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    lemniscate.cuda()
    criterion.cuda()
    cudnn.benchmark = True

if args.test_only:
    acc = kNN(0, net, lemniscate, trainloader, testloader, 30, args.temperature)
    sys.exit(0)

optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)
|
||||
|
||||
def adjust_learning_rate(optimizer, epoch):
    """Set the learning rate to the initial LR decayed by 10 every 50 epochs.

    (The old docstring said "every 30 epochs", but the schedule below uses
    ``epoch // 50``.)  The base rate comes from the module-level ``args.lr``.
    """
    lr = args.lr * (0.1 ** (epoch // 50))
    print(lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
|
||||
|
||||
# Training
|
||||
def train(epoch):
    """Run one training epoch over ``trainloader``.

    Uses the module-level ``net``, ``lemniscate`` (non-parametric feature
    memory), ``criterion`` (NCA loss) and ``optimizer``.  Per-batch timing
    and loss statistics are printed as training progresses.

    Fix: removed the unused ``correct``/``total`` counters — nothing in the
    body ever updated or read them.
    """
    print('\nEpoch: %d' % epoch)
    adjust_learning_rate(optimizer, epoch)
    train_loss = AverageMeter()
    data_time = AverageMeter()
    batch_time = AverageMeter()

    # switch to train mode
    net.train()

    end = time.time()
    for batch_idx, (inputs, targets, indexes) in enumerate(trainloader):
        data_time.update(time.time() - end)
        if use_cuda:
            inputs, targets, indexes = inputs.cuda(), targets.cuda(), indexes.cuda()
        optimizer.zero_grad()

        # forward: embed the batch, then score it against the feature memory
        features = net(inputs)
        outputs = lemniscate(features, indexes)
        loss = criterion(outputs, indexes)

        loss.backward()
        optimizer.step()

        train_loss.update(loss.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        print('Epoch: [{}][{}/{}]'
              'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
              'Data: {data_time.val:.3f} ({data_time.avg:.3f}) '
              'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f})'.format(
              epoch, batch_idx, len(trainloader), batch_time=batch_time, data_time=data_time, train_loss=train_loss))
|
||||
|
||||
# Main loop: train one epoch, evaluate with kNN (K=30) on the memory
# features, and checkpoint whenever validation accuracy improves.
for epoch in range(start_epoch, start_epoch+200):
    train(epoch)
    acc = kNN(epoch, net, lemniscate, trainloader, testloader, 30, args.temperature)

    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.module if use_cuda else net,  # unwrap DataParallel before saving
            'lemniscate': lemniscate,
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.t7')
        best_acc = acc

print('best accuracy: {:.2f}'.format(best_acc*100))
|
|
@ -0,0 +1,5 @@
|
|||
# Dataset wrappers whose __getitem__ additionally returns the sample index,
# which the non-parametric feature memory uses to address its slots.
from .folder import ImageFolderInstance
from .cifar import CIFAR10Instance, CIFAR100Instance

__all__ = ('ImageFolderInstance', 'CIFAR10Instance', 'CIFAR100Instance')
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
from __future__ import print_function
|
||||
from PIL import Image
|
||||
import torchvision.datasets as datasets
|
||||
import torch.utils.data as data
|
||||
|
||||
class CIFAR10Instance(datasets.CIFAR10):
    """CIFAR-10 dataset that also yields the sample index.

    ``__getitem__`` returns ``(image, target, index)`` rather than the usual
    ``(image, target)`` pair, so callers can key into per-instance state.
    """
    def __getitem__(self, index):
        # Pick the split-appropriate arrays once, then index both.
        data = self.train_data if self.train else self.test_data
        labels = self.train_labels if self.train else self.test_labels
        raw, target = data[index], labels[index]

        # Return a PIL Image for consistency with all other datasets.
        img = Image.fromarray(raw)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index
|
||||
|
||||
class CIFAR100Instance(CIFAR10Instance):
    """CIFAR100Instance Dataset.

    This is a subclass of the `CIFAR10Instance` Dataset: the index-returning
    ``__getitem__`` is inherited, and only the download/layout metadata is
    overridden so the base-class machinery loads the CIFAR-100 archive.
    """
    base_folder = 'cifar-100-python'
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    # (archive member name, md5) pairs — presumably checked by the
    # torchvision CIFAR loader this class inherits from.
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]

    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]
|
|
@ -0,0 +1,21 @@
|
|||
import torchvision.datasets as datasets
|
||||
|
||||
class ImageFolderInstance(datasets.ImageFolder):
    """ImageFolder dataset that additionally returns the sample index.

    Identical to ``torchvision.datasets.ImageFolder`` except that
    ``__getitem__`` yields ``(image, target, index)``.
    """
    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target, index) where target is class_index of the target class.
        """
        path, target = self.imgs[index]
        sample = self.loader(path)

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target, index
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
import torch
|
||||
from torch.autograd import Function
|
||||
from torch import nn
|
||||
import math
|
||||
|
||||
class LinearAverageOp(Function):
    """Similarity of a feature batch against the full memory bank.

    forward : out[i][j] = <x_i, memory_j> / T  for every memory slot j.
    backward: returns dL/dx and, as a deliberate side effect, refreshes the
    memory rows of the current batch with a momentum update — the memory has
    no learnable parameters, hence this non-standard update-inside-backward.
    """
    @staticmethod
    def forward(self, x, y, memory, params):
        # `self` is the autograd context; params holds [temperature T, momentum].
        T = params[0].item()
        batchSize = x.size(0)

        # inner product against every memory slot
        out = torch.mm(x.data, memory.t())
        out.div_(T) # batchSize * N

        self.save_for_backward(x, memory, y, params)

        return out

    @staticmethod
    def backward(self, gradOutput):
        x, memory, y, params = self.saved_tensors
        batchSize = gradOutput.size(0)
        T = params[0].item()
        momentum = params[1].item()

        # add temperature (chain rule through the division by T in forward)
        gradOutput.data.div_(T)

        # gradient of linear
        gradInput = torch.mm(gradOutput.data, memory)
        gradInput.resize_as_(x)

        # update the non-parametric memory: blend each slot with the fresh
        # feature (momentum-weighted average), then re-normalize to unit L2
        weight_pos = memory.index_select(0, y.data.view(-1)).resize_as_(x)
        weight_pos.mul_(momentum)
        weight_pos.add_(torch.mul(x.data, 1-momentum))
        w_norm = weight_pos.pow(2).sum(1, keepdim=True).pow(0.5)
        updated_weight = weight_pos.div(w_norm)
        memory.index_copy_(0, y, updated_weight)

        # gradients only w.r.t. x; y/memory/params are not differentiated
        return gradInput, None, None, None
|
||||
|
||||
class LinearAverage(nn.Module):
    """Non-parametric memory bank scored by inner product.

    Holds one feature vector per training instance (``outputSize`` rows of
    dimension ``inputSize``) and returns the temperature-scaled similarity of
    a batch against the entire memory.  The rows for the current batch are
    refreshed with momentum inside :class:`LinearAverageOp`'s backward pass.

    Fix: removed the dead ``stdv = 1 / math.sqrt(inputSize)`` assignment that
    was immediately overwritten, and a stray trailing semicolon.
    """

    def __init__(self, inputSize, outputSize, T=0.05, momentum=0.5):
        super(LinearAverage, self).__init__()
        self.nLem = outputSize

        # Buffers move with the module (.cuda()) and land in state_dict.
        self.register_buffer('params', torch.tensor([T, momentum]))
        # Uniform init in [-stdv, stdv), matching the original release.
        stdv = 1. / math.sqrt(inputSize/3)
        self.register_buffer('memory', torch.rand(outputSize, inputSize).mul_(2*stdv).add_(-stdv))

    def forward(self, x, y):
        """Return similarities of batch ``x`` (instance indices ``y``) vs. memory."""
        out = LinearAverageOp.apply(x, y, self.memory, self.params)
        return out
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
import math
|
||||
|
||||
eps = 1e-8
|
||||
|
||||
class NCACrossEntropy(nn.Module):
    ''' \sum_{j=C} log(p_{ij})

    Neighbourhood Component Analysis loss over full-dataset similarities.
    Stores all the labels of the dataset; only the indexes of the training
    instances in the current batch are passed during forward.  ``x`` is a
    (batchSize, ndata) similarity matrix.
    '''
    def __init__(self, labels, margin=0):
        super(NCACrossEntropy, self).__init__()
        self.register_buffer('labels', torch.LongTensor(labels.size(0)))
        # NOTE(review): the buffer registered above is immediately shadowed by
        # this plain attribute assignment, so `labels` may not track
        # .cuda()/state_dict the way a buffer would — confirm intended.
        self.labels = labels
        self.margin = margin

    def forward(self, x, indexes):
        batchSize = x.size(0)
        n = x.size(1)           # number of dataset instances (one column each)
        exp = torch.exp(x)

        # labels for current batch
        y = torch.index_select(self.labels, 0, indexes.data).view(batchSize, 1)
        # same[i][j] == 1 iff instance j shares sample i's class (in-place eq_)
        same = y.repeat(1, n).eq_(self.labels)

        # self prob exclusion: zero each sample's own column
        # (in-place on .data as a memory-efficiency hack)
        exp.data.scatter_(1, indexes.data.view(-1,1), 0)

        p = torch.mul(exp, same.float()).sum(dim=1)   # mass on same-class neighbors
        Z = exp.sum(dim=1)                            # total mass (self excluded)

        # apply the (optional) margin to the positive mass only
        Z_exclude = Z - p
        p = p.div(math.exp(self.margin))
        Z = Z_exclude + p

        prob = torch.div(p, Z)
        # drop zero-probability entries (samples whose class has no other
        # instance) so the log below stays finite
        prob_masked = torch.masked_select(prob, prob.ne(0))

        loss = prob_masked.log().sum(0)

        return - loss / batchSize
|
||||
|
|
@ -0,0 +1 @@
|
|||
# nothing
|
|
@ -0,0 +1,14 @@
|
|||
import torch
|
||||
from torch.autograd import Variable
|
||||
from torch import nn
|
||||
|
||||
class Normalize(nn.Module):
    """Scale every row of the input to unit Lp norm along dimension 1.

    ``power`` selects the norm (default 2, i.e. ordinary L2 normalization).
    """

    def __init__(self, power=2):
        super(Normalize, self).__init__()
        self.power = power

    def forward(self, x):
        """Return ``x`` with each row divided by its Lp norm."""
        scale = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
        return x.div(scale)
|
|
@ -0,0 +1,16 @@
|
|||
class AverageMeter(object):
    """Track the latest value and the running average of a series."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` (a mean over ``n`` samples) and refresh the average."""
        self.val = val
        self.count += n
        self.sum += n * val
        self.avg = self.sum / self.count
|
|
@ -0,0 +1,281 @@
|
|||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.parallel
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.distributed as dist
|
||||
import torch.optim
|
||||
import torch.utils.data
|
||||
import torch.utils.data.distributed
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
import datasets
|
||||
import models
|
||||
import math
|
||||
|
||||
from lib.LinearAverage import LinearAverage
|
||||
from lib.NCA import NCACrossEntropy
|
||||
from lib.utils import AverageMeter
|
||||
from test import NN, kNN
|
||||
|
||||
# Names of all model constructors exported by the local `models` package
# (lower-case callables, e.g. resnet18 ... resnet152).
model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))

# Command-line configuration for ImageNet training.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=130, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--test-only', action='store_true', help='test only')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=1, type=int,
                    help='number of distributed processes')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='gloo', type=str,
                    help='distributed backend')
parser.add_argument('--low-dim', default=128, type=int,
                    metavar='D', help='feature dimension')
parser.add_argument('--temperature', default=0.05, type=float,
                    metavar='T', help='temperature parameter')
# NOTE(review): the alias '--m-mementum' looks like a typo for
# '--m-momentum'; kept as-is since changing it would alter the CLI.
parser.add_argument('--memory-momentum', '--m-mementum', default=0.5, type=float,
                    metavar='M', help='momentum for non-parametric updates')
parser.add_argument('--iter-size', default=1, type=int,
                    help='caffe style iter size')
parser.add_argument('--margin', default=0.0, type=float,
                    help='classification margin')

# best validation precision@1 seen so far (updated inside main())
best_prec1 = 0
|
||||
|
||||
def main():
    """Entry point: parse args, build the ImageNet data pipeline, the model,
    the non-parametric memory and the NCA loss, then run the train /
    NN-evaluate / checkpoint loop."""
    global args, best_prec1
    args = parser.parse_args()

    # single-process unless a world size > 1 is requested
    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size)

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # ImageFolderInstance also yields each sample's index, which addresses
    # the feature memory below.
    train_dataset = datasets.ImageFolderInstance(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolderInstance(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](low_dim=args.low_dim)

    if not args.distributed:
        # alexnet/vgg parallelize only the conv features; other nets are
        # wrapped whole
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay, nesterov=True)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            lemniscate = checkpoint['lemniscate']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            # NOTE(review): falling through here leaves `lemniscate`
            # undefined, so later uses raise NameError — confirm intended.
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        # define lemniscate and loss function (criterion)
        ndata = train_dataset.__len__()
        lemniscate = LinearAverage(args.low_dim, ndata, args.temperature, args.memory_momentum).cuda()


    # margin is divided by the temperature — presumably to match the 1/T
    # scaling applied to the similarities; labels come from the folder index
    criterion = NCACrossEntropy(torch.LongTensor([y for (p, y) in train_loader.dataset.imgs]),
                                args.margin / args.temperature).cuda()
    cudnn.benchmark = True

    if args.evaluate:
        prec1 = kNN(0, model, lemniscate, train_loader, val_loader, 30, args.temperature, 0)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)
        adjust_memory_update_rate(lemniscate, epoch)

        # train for one epoch
        train(train_loader, model, lemniscate, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = NN(epoch, model, lemniscate, train_loader, val_loader)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'lemniscate': lemniscate,
            'best_prec1': best_prec1,
            'optimizer' : optimizer.state_dict(),
        }, is_best)
|
||||
|
||||
|
||||
def train(train_loader, model, lemniscate, criterion, optimizer, epoch):
    """Train ``model`` for one epoch on ``train_loader``.

    Each batch is embedded, scored against the full feature memory
    (``lemniscate``) and fed to the NCA ``criterion``; gradients accumulate
    over ``args.iter_size`` batches before every optimizer step.

    Fixes: ``.format()`` was applied to ``print(...)``'s return value, so the
    placeholders were never substituted under Python 3; ``async=True`` is a
    SyntaxError from Python 3.7 (renamed ``non_blocking`` in torch 0.4);
    ``loss.data[0]`` replaced by ``loss.item()`` consistent with cifar.py;
    unused ``target_var`` dropped.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    optimizer.zero_grad()
    for i, (input, target, index) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        target = target.cuda(non_blocking=True)
        index = index.cuda(non_blocking=True)
        input_var = torch.autograd.Variable(input)
        index_var = torch.autograd.Variable(index)

        # compute output: embedding -> memory similarities -> NCA loss
        feature = model(input_var)
        output = lemniscate(feature, index_var)
        loss = criterion(output, index_var) / args.iter_size

        loss.backward()
        # record the un-scaled loss
        losses.update(loss.item() * args.iter_size, input.size(0))

        if (i+1) % args.iter_size == 0:
            # compute gradient and do SGD step (caffe-style iter_size)
            optimizer.step()
            optimizer.zero_grad()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                  epoch, i, len(train_loader), batch_time=batch_time,
                  data_time=data_time, loss=losses))
|
||||
|
||||
|
||||
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename``; when ``is_best`` is true, also
    mirror it to ``model_best.pth.tar``."""
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')
|
||||
|
||||
def adjust_memory_update_rate(lemniscate, epoch):
    """Anneal the memory-momentum schedule: 0.8 from epoch 80, 0.9 from 120.

    Mutates ``lemniscate.params[1]`` in place; epochs below 80 are untouched.
    """
    for threshold, momentum in ((80, 0.8), (120, 0.9)):
        if epoch >= threshold:
            lemniscate.params[1] = momentum
|
||||
|
||||
def adjust_learning_rate(optimizer, epoch):
    """Set the LR to the initial ``args.lr`` decayed by 10 every 40 epochs."""
    decayed = args.lr * (0.1 ** (epoch // 40))
    print(decayed)
    for group in optimizer.param_groups:
        group['lr'] = decayed
|
||||
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,2 @@
|
|||
from .resnet import *
|
||||
from .resnet_cifar import *
|
|
@ -0,0 +1,208 @@
|
|||
import torch.nn as nn
|
||||
import math
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
from lib.normalize import Normalize
|
||||
|
||||
|
||||
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
||||
'resnet152']
|
||||
|
||||
model_urls = { }
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 convolution with padding 1 (shape-preserving at
    stride 1)."""
    return nn.Conv2d(in_planes, out_planes,
                     kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv/BN stages plus a skip connection.

    ``downsample`` (if given) is applied to the input to match the main
    branch's shape before the addition.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += shortcut
        return self.relu(out)
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce, 3x3, 1x1 expand (x4), plus a
    skip connection.

    ``downsample`` (if given) reshapes the input to match the main branch
    before the addition.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += shortcut
        return self.relu(out)
|
||||
|
||||
|
||||
class ResNet(nn.Module):
    """ResNet backbone whose classifier head is replaced by a ``low_dim``
    linear embedding followed by L2 normalization, producing unit-length
    feature vectors for nearest-neighbor / NCA training."""

    def __init__(self, block, layers, low_dim=128):
        # `block` is BasicBlock or Bottleneck; `layers` gives the number of
        # residual blocks in each of the four stages.
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        # 7x7 average pool — assumes 224x224 inputs (7x7 final maps); confirm
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, low_dim)
        self.l2norm = Normalize(2)

        # He-style init for convs (normal with std sqrt(2/fan_out));
        # BatchNorm starts as identity.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; only the first may change
        stride/width, using a 1x1 conv + BN on the shortcut when needed."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return an L2-normalized ``low_dim`` embedding per input image."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.l2norm(x)

        return x
|
||||
|
||||
|
||||
def resnet18(pretrained=False, **kwargs):
|
||||
"""Constructs a ResNet-18 model.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
"""
|
||||
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
|
||||
if pretrained:
|
||||
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
|
||||
return model
|
||||
|
||||
|
||||
def resnet34(pretrained=False, **kwargs):
    """Build a ResNet-34 embedding network.

    Args:
        pretrained (bool): If True, load ImageNet pre-trained weights.
    """
    net = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['resnet34'])
        net.load_state_dict(state)
    return net
def resnet50(pretrained=False, **kwargs):
    """Build a ResNet-50 embedding network.

    Args:
        pretrained (bool): If True, load ImageNet pre-trained weights.
    """
    net = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['resnet50'])
        net.load_state_dict(state)
    return net
def resnet101(pretrained=False, **kwargs):
    """Build a ResNet-101 embedding network.

    Args:
        pretrained (bool): If True, load ImageNet pre-trained weights.
    """
    net = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['resnet101'])
        net.load_state_dict(state)
    return net
def resnet152(pretrained=False, **kwargs):
    """Build a ResNet-152 embedding network.

    Args:
        pretrained (bool): If True, load ImageNet pre-trained weights.
    """
    net = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['resnet152'])
        net.load_state_dict(state)
    return net
@ -0,0 +1,126 @@
|
|||
'''ResNet in PyTorch.
|
||||
|
||||
For Pre-activation ResNet, see 'preact_resnet.py'.
|
||||
|
||||
Reference:
|
||||
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
|
||||
Deep Residual Learning for Image Recognition. arXiv:1512.03385
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from lib.normalize import Normalize
|
||||
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
    """Two 3x3 convs with a residual (identity or projected) shortcut."""

    # Output channels = expansion * planes.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut unless the block changes resolution or width,
        # in which case project with a strided 1x1 conv + BN.
        out_planes = self.expansion * planes
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + residual)
class Bottleneck(nn.Module):
    """1x1 reduce -> 3x3 -> 1x1 expand residual unit (4x channel expansion)."""

    # Output channels = expansion * planes.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        width_out = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, width_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width_out)

        # Projection shortcut only when the shape changes; identity otherwise.
        if stride != 1 or in_planes != width_out:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width_out, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(width_out),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return F.relu(out + residual)
class ResNet(nn.Module):
    """CIFAR-style ResNet producing L2-normalized ``low_dim`` embeddings.

    Uses a 3x3 stem without max-pooling (inputs are small, e.g. 32x32) and
    ends with a linear projection followed by L2 normalization instead of a
    classification head.
    """

    def __init__(self, block, num_blocks, low_dim=128):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, low_dim)
        self.l2norm = Normalize(2)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` units; only the first may downsample."""
        layers = []
        for unit_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, unit_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        out = F.avg_pool2d(out, 4)       # global 4x4 pool for 32x32 inputs
        out = out.view(out.size(0), -1)  # flatten to (batch, features)
        out = self.linear(out)
        return self.l2norm(out)          # unit-length embedding
def ResNet18(low_dim=128):
    """ResNet-18 backbone emitting ``low_dim``-dimensional embeddings."""
    return ResNet(BasicBlock, [2, 2, 2, 2], low_dim)
def ResNet34(low_dim=128):
    """ResNet-34 backbone emitting ``low_dim``-dimensional embeddings."""
    return ResNet(BasicBlock, [3, 4, 6, 3], low_dim)
def ResNet50(low_dim=128):
    """ResNet-50 backbone emitting ``low_dim``-dimensional embeddings."""
    return ResNet(Bottleneck, [3, 4, 6, 3], low_dim)
def ResNet101(low_dim=128):
    """ResNet-101 backbone emitting ``low_dim``-dimensional embeddings."""
    return ResNet(Bottleneck, [3, 4, 23, 3], low_dim)
def ResNet152(low_dim=128):
    """ResNet-152 backbone emitting ``low_dim``-dimensional embeddings."""
    return ResNet(Bottleneck, [3, 8, 36, 3], low_dim)
def test():
    """Smoke test: forward one random CIFAR-sized batch and print its shape."""
    net = ResNet18()
    batch = Variable(torch.randn(1, 3, 32, 32))
    out = net(batch)
    print(out.size())
||||
# test()
|
|
@ -0,0 +1,136 @@
|
|||
import torch
|
||||
import time
|
||||
import datasets
|
||||
from lib.utils import AverageMeter
|
||||
import torchvision.transforms as transforms
|
||||
import numpy as np
|
||||
|
||||
def NN(epoch, net, lemniscate, trainloader, testloader, recompute_memory=0):
    """Evaluate ``net`` on the test set with a 1-nearest-neighbor classifier.

    Training-set features come from the non-parametric memory bank
    ``lemniscate.memory``; when ``recompute_memory`` is truthy they are
    re-extracted by forwarding the training set with the test-time transform.

    Args:
        epoch: unused; kept for interface parity with ``kNN``.
        net: feature-extraction network (switched to eval mode here).
        lemniscate: object holding the (num_train, dim) feature memory.
        trainloader, testloader: loaders yielding (inputs, targets, indexes).
        recompute_memory: if truthy, refresh the memory bank before testing.

    Returns:
        float: top-1 nearest-neighbor accuracy on the test set.
    """
    net.eval()
    net_time = AverageMeter()
    cls_time = AverageMeter()
    correct = 0.
    total = 0
    testsize = testloader.dataset.__len__()

    # Memory bank is (num_train, dim); transpose for feature x memory matmul.
    trainFeatures = lemniscate.memory.t()
    if hasattr(trainloader.dataset, 'imgs'):
        # ImageFolder-style dataset: labels live in (path, class) pairs.
        trainLabels = torch.LongTensor([y for (p, y) in trainloader.dataset.imgs]).cuda()
    else:
        trainLabels = torch.LongTensor(trainloader.dataset.train_labels).cuda()

    if recompute_memory:
        # Temporarily evaluate the training set with the *test* transform so
        # stored features match test-time preprocessing.
        transform_bak = trainloader.dataset.transform
        trainloader.dataset.transform = testloader.dataset.transform
        temploader = torch.utils.data.DataLoader(trainloader.dataset, batch_size=100, shuffle=False, num_workers=1)
        for batch_idx, (inputs, targets, indexes) in enumerate(temploader):
            # `async` became a reserved keyword in Python 3.7;
            # `non_blocking` is the PyTorch >= 0.4 spelling of this argument.
            targets = targets.cuda(non_blocking=True)
            batchSize = inputs.size(0)
            features = net(inputs)
            trainFeatures[:, batch_idx*batchSize:batch_idx*batchSize+batchSize] = features.data.t()
        trainLabels = torch.LongTensor(temploader.dataset.train_labels).cuda()
        trainloader.dataset.transform = transform_bak

    end = time.time()
    with torch.no_grad():
        for batch_idx, (inputs, targets, indexes) in enumerate(testloader):
            targets = targets.cuda(non_blocking=True)
            batchSize = inputs.size(0)
            features = net(inputs)
            net_time.update(time.time() - end)
            end = time.time()

            # Similarity of each test feature against every stored feature.
            dist = torch.mm(features, trainFeatures)

            yd, yi = dist.topk(1, dim=1, largest=True, sorted=True)
            candidates = trainLabels.view(1, -1).expand(batchSize, -1)
            retrieval = torch.gather(candidates, 1, yi)

            # Label of the single nearest neighbor per test sample.
            retrieval = retrieval.narrow(1, 0, 1).clone().view(-1)

            total += targets.size(0)
            correct += retrieval.eq(targets.data).sum().item()

            cls_time.update(time.time() - end)
            end = time.time()

            print('Test [{}/{}]\t'
                  'Net Time {net_time.val:.3f} ({net_time.avg:.3f})\t'
                  'Cls Time {cls_time.val:.3f} ({cls_time.avg:.3f})\t'
                  'Top1: {:.2f}'.format(
                  total, testsize, correct*100./total, net_time=net_time, cls_time=cls_time))

    return correct/total
def kNN(epoch, net, lemniscate, trainloader, testloader, K, sigma, recompute_memory=0):
    """Evaluate ``net`` with a weighted K-nearest-neighbor classifier.

    Each of the K retrieved neighbors votes for its class with weight
    exp(similarity / sigma); class scores are the summed weights.

    Args:
        epoch: unused; kept for interface parity with ``NN``.
        net: feature-extraction network (switched to eval mode here).
        lemniscate: object holding the (num_train, dim) feature memory.
        trainloader, testloader: loaders yielding (inputs, targets, indexes).
        K (int): number of neighbors to retrieve.
        sigma (float): softmax temperature for neighbor weighting.
        recompute_memory: if truthy, refresh the memory bank before testing.

    Returns:
        float: top-1 accuracy on the test set.
    """
    net.eval()
    net_time = AverageMeter()
    cls_time = AverageMeter()
    total = 0
    testsize = testloader.dataset.__len__()

    # Memory bank is (num_train, dim); transpose for feature x memory matmul.
    trainFeatures = lemniscate.memory.t()
    if hasattr(trainloader.dataset, 'imgs'):
        # ImageFolder-style dataset: labels live in (path, class) pairs.
        trainLabels = torch.LongTensor([y for (p, y) in trainloader.dataset.imgs]).cuda()
    else:
        trainLabels = torch.LongTensor(trainloader.dataset.train_labels).cuda()
    # int() keeps this a plain Python int: on PyTorch >= 0.4 max() returns a
    # 0-dim tensor, which torch.zeros() below would reject as a size.
    C = int(trainLabels.max()) + 1

    if recompute_memory:
        # Temporarily evaluate the training set with the *test* transform so
        # stored features match test-time preprocessing.
        transform_bak = trainloader.dataset.transform
        trainloader.dataset.transform = testloader.dataset.transform
        temploader = torch.utils.data.DataLoader(trainloader.dataset, batch_size=100, shuffle=False, num_workers=1)
        for batch_idx, (inputs, targets, indexes) in enumerate(temploader):
            # `async` became a reserved keyword in Python 3.7;
            # `non_blocking` is the PyTorch >= 0.4 spelling of this argument.
            targets = targets.cuda(non_blocking=True)
            batchSize = inputs.size(0)
            features = net(inputs)
            trainFeatures[:, batch_idx*batchSize:batch_idx*batchSize+batchSize] = features.data.t()
        trainLabels = torch.LongTensor(temploader.dataset.train_labels).cuda()
        trainloader.dataset.transform = transform_bak

    top1 = 0.
    top5 = 0.
    end = time.time()
    with torch.no_grad():
        # Scratch buffer reused (via resize_) for per-batch one-hot votes.
        retrieval_one_hot = torch.zeros(K, C).cuda()
        for batch_idx, (inputs, targets, indexes) in enumerate(testloader):
            end = time.time()
            targets = targets.cuda(non_blocking=True)
            batchSize = inputs.size(0)
            features = net(inputs)
            net_time.update(time.time() - end)
            end = time.time()

            dist = torch.mm(features, trainFeatures)

            yd, yi = dist.topk(K, dim=1, largest=True, sorted=True)
            candidates = trainLabels.view(1, -1).expand(batchSize, -1)
            retrieval = torch.gather(candidates, 1, yi)

            retrieval_one_hot.resize_(batchSize * K, C).zero_()
            retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1)
            # Weight each neighbor's vote by exp(similarity / sigma).
            yd_transform = yd.clone().div_(sigma).exp_()
            probs = torch.sum(torch.mul(retrieval_one_hot.view(batchSize, -1, C), yd_transform.view(batchSize, -1, 1)), 1)
            _, predictions = probs.sort(1, True)

            # Find which predictions match the target.
            correct = predictions.eq(targets.data.view(-1, 1))
            cls_time.update(time.time() - end)

            top1 = top1 + correct.narrow(1, 0, 1).sum().item()
            # min() guards datasets with fewer than 5 classes, where a
            # 5-column narrow over the C prediction columns would fail.
            top5 = top5 + correct.narrow(1, 0, min(5, C)).sum().item()

            total += targets.size(0)

            print('Test [{}/{}]\t'
                  'Net Time {net_time.val:.3f} ({net_time.avg:.3f})\t'
                  'Cls Time {cls_time.val:.3f} ({cls_time.avg:.3f})\t'
                  'Top1: {:.2f}  Top5: {:.2f}'.format(
                  total, testsize, top1*100./total, top5*100./total, net_time=net_time, cls_time=cls_time))

    print(top1*100./total)

    return top1/total
Загрузка…
Ссылка в новой задаче