Adding re-ranking for image retrieval (#515)

This commit is contained in:
PatrickBue 2020-03-18 19:00:32 -04:00 коммит произвёл GitHub
Родитель 782288a101
Коммит f9595d3fba
8 изменённых файлов: 462 добавлений и 58 удалений

Просмотреть файл

@ -500,3 +500,29 @@ Permission is hereby granted, free of charge, to any person obtaining a copy of
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--
https://github.com/layumi/Person_reID_baseline_pytorch
MIT License
Copyright (c) 2018 Zhedong Zheng
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -27,6 +27,11 @@ Below are a subset of popular papers in the field with reported accuracies on st
| [Classification is a Strong Baseline for DeepMetric Learning](https://arxiv.org/abs/1811.12649) <br> (Implemented in this repository) | BMVC 2019 | No | **84%** (512-dim) <br> **89%** (2048-dim) | 61% (512-dim) <br> **65%** (2048-dim) | **78%** (512-dim) <br> **80%** (2048-dim) |
## Re-ranking
In addition to the SOTA method introduced above, we provide an implementation of a popular re-ranking approach published in the CVPR 2017 paper [Re-ranking Person Re-identification with k-reciprocal Encoding](http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf). Re-ranking is a post-processing step to improve retrieval accuracy. The proposed approach is fast, fully automatic, unsupervised, and shown to outperform other state-of-the-art methods with regards to accuracy.
## Frequently asked questions
Answers to Frequently Asked Questions such as "How many images do I need to train a model?" or "How to annotate images?" can be found in the [FAQ.md](FAQ.md) file. For image classification specified questions, see the [FAQ.md](../classification/FAQ.md) in the classification folder.

Просмотреть файл

@ -16,7 +16,7 @@ import random
from PIL import Image
from torch import tensor
from pathlib import Path
from fastai.vision import cnn_learner, models
from fastai.vision import cnn_learner, DatasetType, models
from fastai.vision.data import ImageList, imagenet_stats
from typing import List, Tuple
from tempfile import TemporaryDirectory
@ -35,6 +35,7 @@ from utils_cv.detection.model import (
_apply_threshold,
)
from utils_cv.similarity.data import Urls as is_urls
from utils_cv.similarity.model import compute_features_learner
def path_classification_notebooks():
@ -279,7 +280,7 @@ def tiny_ic_databunch(tmp_session):
.split_by_rand_pct(valid_pct=0.1, seed=20)
.label_from_folder()
.transform(size=50)
.databunch(bs=16, num_workers = db_num_workers())
.databunch(bs=16, num_workers=db_num_workers())
.normalize(imagenet_stats)
)
@ -351,7 +352,7 @@ def testing_databunch(tmp_session):
.split_by_rand_pct(valid_pct=0.2, seed=20)
.label_from_folder()
.transform(size=300)
.databunch(bs=16, num_workers = db_num_workers())
.databunch(bs=16, num_workers=db_num_workers())
.normalize(imagenet_stats)
)
@ -735,6 +736,7 @@ def workspace_region(request):
# ------|-- Similarity ---------------------------------------------
@pytest.fixture(scope="session")
def tiny_is_data_path(tmp_session) -> str:
""" Returns the path to the tiny fridge objects dataset. """
@ -743,4 +745,14 @@ def tiny_is_data_path(tmp_session) -> str:
fpath=tmp_session,
dest=tmp_session,
exist_ok=True,
)
)
@pytest.fixture(scope="session")
def tiny_ic_databunch_valid_features(tiny_ic_databunch):
learn = cnn_learner(tiny_ic_databunch, models.resnet18)
embedding_layer = learn.model[1][6]
features = compute_features_learner(
tiny_ic_databunch, DatasetType.Valid, learn, embedding_layer
)
return features

Просмотреть файл

@ -8,6 +8,7 @@ from pytest import approx
from utils_cv.similarity.data import comparative_set_builder
from utils_cv.similarity.metrics import (
compute_distances,
evaluate,
positive_image_ranks,
recall_at_k,
vector_distance,
@ -64,3 +65,31 @@ def test_recall_at_k():
assert recall_at_k(rank_list, 3) == 60
assert recall_at_k(rank_list, 6) == 100
assert recall_at_k(rank_list, 10) == 100
def test_evaluate(tiny_ic_databunch, tiny_ic_databunch_valid_features):
(rank_accs, mAP) = evaluate(
tiny_ic_databunch.valid_ds,
tiny_ic_databunch_valid_features,
use_rerank=False,
)
assert 0 <= mAP <= 1.0
assert len(rank_accs) == 6
assert max(rank_accs) <= 1.001
assert min(rank_accs) >= -0.001
for i in range(len(rank_accs) - 1):
rank_accs[i] <= rank_accs[i + 1]
(rank_accs, ap) = evaluate(
tiny_ic_databunch.valid_ds,
tiny_ic_databunch_valid_features,
use_rerank=True,
rerank_k1=2,
rerank_k2=3,
)
assert 0 <= mAP <= 1.0
assert len(rank_accs) == 6
assert max(rank_accs) <= 1.001
assert min(rank_accs) >= -0.001
for i in range(len(rank_accs) - 1):
rank_accs[i] <= rank_accs[i + 1]

Просмотреть файл

@ -1,11 +1,12 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from typing import List
from typing import Dict, List
import numpy as np
import scipy
from fastai.vision import LabelList
from .references.evaluate import evaluate_with_query_set
def vector_distance(
vec1: np.ndarray,
@ -105,3 +106,52 @@ def recall_at_k(ranks: List[int], k: int) -> float:
below_threshold = [x for x in ranks if x <= k]
percent_in_top_k = round(100.0 * len(below_threshold) / len(ranks), 1)
return percent_in_top_k
def evaluate(
data: LabelList,
features: Dict[str, np.array],
use_rerank=False,
rerank_k1=20,
rerank_k2=6,
rerank_lambda=0.3,
):
"""
Computes rank@1 through rank@10 accuracy as well as mAP, optionally with re-ranking
post-processor to improve accuracy (see the re-ranking implementation for more info).
Args:
data: Fastai's image labellist
features: Dictionary of DNN features for each image
use_rerank: use re-ranking
rerank_k1, rerank_k2, rerank_lambda: re-ranking parameters
Returns:
rank_accs: accuracy at rank1 through rank10
mAP: average precision
"""
labels = np.array([data.y[i].obj for i in range(len(data.y))])
features = np.array([features[str(s)] for s in data.items])
# Assign each image into its own group. This serves as id during evaluation to
# ensure a query image is not compared to itself during rank computation.
# For the market-1501 dataset, the group ids can be used to ensure that a query
# can not match to an image taken from the same camera.
groups = np.array(range(len(labels)))
assert len(labels) == len(groups) == features.shape[0]
# Run evaluation
rank_accs, mAP = evaluate_with_query_set(
labels,
groups,
features,
labels,
groups,
features,
use_rerank,
rerank_k1,
rerank_k2,
rerank_lambda,
)
return rank_accs, mAP

Просмотреть файл

@ -0,0 +1,141 @@
# Most of the code in this file is copied and slightly modified from:
# https://github.com/layumi/Person_reID_baseline_pytorch/blob/master/evaluate.py
import numpy as np
import time
import torch
from .re_ranking import re_ranking
# Note: the Market1501 dataset has a slightly different evaluation procedure which can be used
# by setting is_market1501=True.
def evaluate_with_query_set(
gallery_labels,
gallery_groups,
gallery_features,
query_labels,
query_groups,
query_features,
use_rerank=False,
rerank_k1=20,
rerank_k2=6,
rerank_lambda=0.3,
is_market1501=False,
):
# Init
ap = 0.0
CMC = torch.IntTensor(len(gallery_labels)).zero_()
# Compute pairwise distance
q_g_dist = np.dot(query_features, np.transpose(gallery_features))
# Improve pairwise distances using re-ranking
if use_rerank:
print("Calculate re-ranked distances..")
q_q_dist = np.dot(query_features, np.transpose(query_features))
g_g_dist = np.dot(gallery_features, np.transpose(gallery_features))
since = time.time()
distances = re_ranking(
q_g_dist, q_q_dist, g_g_dist, k1=rerank_k1, k2=rerank_k2, lambda_value=rerank_lambda,
)
time_elapsed = time.time() - since
print(
"Reranking complete in {:.0f}m {:.0f}s".format(
time_elapsed // 60, time_elapsed % 60
)
)
else:
distances = -q_g_dist
# Compute accuracies
norm = 0
skip = 1 # set to >1 to only consider a subset of the query images
for i in range(len(query_labels))[::skip]:
ap_tmp, CMC_tmp = evaluate_helper(
distances[i, :],
query_labels[i],
query_groups[i],
gallery_labels,
gallery_groups,
is_market1501,
)
if CMC_tmp[0] == -1:
continue
norm += 1
ap += ap_tmp
CMC = CMC + CMC_tmp
# Print accuracy. Note that Market1501 normalizes by dividing over number of query images.
if is_market1501:
norm = len(query_labels) / float(skip)
ap = ap / norm
CMC = CMC.float()
CMC = CMC / norm
print(
"Rank@1:{:.1f}, rank@5:{:.1f}, mAP:{:.2f}".format(100 * CMC[0], 100 * CMC[4], ap)
)
return (CMC, ap)
# Explanation:
# - query_index: all images in the reference set with the same label as the query image ("true match")
# - camera_index: all images which share the same group (called "camera" since the code was originally written for the Market-1501 dataset).
# - junk_index2: all reference images with the same group ("camera") as the query are considered "false matches".
# - junk_index1: for the market1501 dataset, images with label -1 should be ignored.
def evaluate_helper(score, ql, qc, gl, gc, is_market1501=False):
assert type(gl) == np.ndarray, "Input gl has to be a numpy ndarray"
assert type(gc) == np.ndarray, "Input gc has to be a numpy ndarray"
# Sort scores
index = np.argsort(score) # from small to large
# Compare reference images to the query image.
query_index = np.argwhere(gl == ql)
camera_index = np.argwhere(gc == qc)
good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
junk_index2 = np.intersect1d(query_index, camera_index)
# For market 1501 dataset, ignore images with label -1
if is_market1501:
junk_index1a = np.argwhere(gl == -1)
junk_index1b = np.argwhere(gl == "-1")
junk_index1 = np.append(junk_index1a, junk_index1b)
junk_index = np.append(junk_index2, junk_index1)
else:
junk_index = junk_index2
CMC_tmp = compute_mAP(index, good_index, junk_index)
return CMC_tmp
def compute_mAP(index, good_index, junk_index):
ap = 0
cmc = torch.IntTensor(len(index)).zero_()
if good_index.size == 0: # if empty
cmc[0] = -1
return ap, cmc
# remove junk_index
mask = np.in1d(index, junk_index, invert=True)
index = index[mask]
# find good_index index
ngood = len(good_index)
mask = np.in1d(index, good_index)
rows_good = np.argwhere(mask) # == True)
rows_good = rows_good.flatten()
cmc[rows_good[0] :] = 1
for i in range(ngood):
d_recall = 1.0 / ngood
precision = (i + 1) * 1.0 / (rows_good[i] + 1)
if rows_good[i] != 0:
old_precision = i * 1.0 / rows_good[i]
else:
old_precision = 1.0
ap = ap + d_recall * (old_precision + precision) / 2
return ap, cmc

Просмотреть файл

@ -0,0 +1,89 @@
# This code is copied (without modification) from:
# https://github.com/layumi/Person_reID_baseline_pytorch/blob/master/re_ranking.py
import numpy as np
"""
CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
Matlab version: https://github.com/zhunzhong07/person-re-ranking
"""
"""
API
q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery]
q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query]
g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery]
k1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3)
Returns:
final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery]
"""
def k_reciprocal_neigh( initial_rank, i, k1):
forward_k_neigh_index = initial_rank[i,:k1+1]
backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1]
fi = np.where(backward_k_neigh_index==i)[0]
return forward_k_neigh_index[fi]
def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3):
# The following naming, e.g. gallery_num, is different from outer scope.
# Don't care about it.
original_dist = np.concatenate(
[np.concatenate([q_q_dist, q_g_dist], axis=1),
np.concatenate([q_g_dist.T, g_g_dist], axis=1)],
axis=0)
original_dist = 2. - 2 * original_dist # change the cosine similarity metric to euclidean similarity metric
original_dist = np.power(original_dist, 2).astype(np.float32)
original_dist = np.transpose(1. * original_dist/np.max(original_dist,axis = 0))
V = np.zeros_like(original_dist).astype(np.float32)
#initial_rank = np.argsort(original_dist).astype(np.int32)
# top K1+1
initial_rank = np.argpartition( original_dist, range(1,k1+1) )
query_num = q_g_dist.shape[0]
all_num = original_dist.shape[0]
for i in range(all_num):
# k-reciprocal neighbors
k_reciprocal_index = k_reciprocal_neigh( initial_rank, i, k1)
k_reciprocal_expansion_index = k_reciprocal_index
for j in range(len(k_reciprocal_index)):
candidate = k_reciprocal_index[j]
candidate_k_reciprocal_index = k_reciprocal_neigh( initial_rank, candidate, int(np.around(k1/2)))
if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index):
k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index)
k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
weight = np.exp(-original_dist[i,k_reciprocal_expansion_index])
V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight)
original_dist = original_dist[:query_num,]
if k2 != 1:
V_qe = np.zeros_like(V,dtype=np.float32)
for i in range(all_num):
V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0)
V = V_qe
del V_qe
del initial_rank
invIndex = []
for i in range(all_num):
invIndex.append(np.where(V[:,i] != 0)[0])
jaccard_dist = np.zeros_like(original_dist,dtype = np.float32)
for i in range(query_num):
temp_min = np.zeros(shape=[1,all_num],dtype=np.float32)
indNonZero = np.where(V[i,:] != 0)[0]
indImages = []
indImages = [invIndex[ind] for ind in indNonZero]
for j in range(len(indNonZero)):
temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]])
jaccard_dist[i] = 1-temp_min/(2.-temp_min)
final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value
del original_dist
del V
del jaccard_dist
final_dist = final_dist[:query_num,query_num:]
return final_dist