Adding re-ranking for image retrieval (#515)
This commit is contained in:
Родитель
0b503c40e9
Коммит
735f7ff474
26
NOTICE.txt
26
NOTICE.txt
|
@ -500,3 +500,29 @@ Permission is hereby granted, free of charge, to any person obtaining a copy of
|
|||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
--
|
||||
|
||||
https://github.com/layumi/Person_reID_baseline_pytorch
|
||||
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2018 Zhedong Zheng
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -27,6 +27,11 @@ Below are a subset of popular papers in the field with reported accuracies on st
|
|||
| [Classification is a Strong Baseline for DeepMetric Learning](https://arxiv.org/abs/1811.12649) <br> (Implemented in this repository) | BMVC 2019 | No | **84%** (512-dim) <br> **89%** (2048-dim) | 61% (512-dim) <br> **65%** (2048-dim) | **78%** (512-dim) <br> **80%** (2048-dim) |
|
||||
|
||||
|
||||
## Re-ranking
|
||||
|
||||
In addition to the SOTA method introduced above, we provide an implementation of a popular re-ranking approach published in the CVPR 2017 paper [Re-ranking Person Re-identification with k-reciprocal Encoding](http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf). Re-ranking is a post-processing step to improve retrieval accuracy. The proposed approach is fast, fully automatic, unsupervised, and shown to outperform other state-of-the-art methods with regards to accuracy.
|
||||
|
||||
|
||||
## Frequently asked questions
|
||||
|
||||
Answers to Frequently Asked Questions such as "How many images do I need to train a model?" or "How to annotate images?" can be found in the [FAQ.md](FAQ.md) file. For image classification specified questions, see the [FAQ.md](../classification/FAQ.md) in the classification folder.
|
||||
|
|
|
@ -16,7 +16,7 @@ import random
|
|||
from PIL import Image
|
||||
from torch import tensor
|
||||
from pathlib import Path
|
||||
from fastai.vision import cnn_learner, models
|
||||
from fastai.vision import cnn_learner, DatasetType, models
|
||||
from fastai.vision.data import ImageList, imagenet_stats
|
||||
from typing import List, Tuple
|
||||
from tempfile import TemporaryDirectory
|
||||
|
@ -35,6 +35,7 @@ from utils_cv.detection.model import (
|
|||
_apply_threshold,
|
||||
)
|
||||
from utils_cv.similarity.data import Urls as is_urls
|
||||
from utils_cv.similarity.model import compute_features_learner
|
||||
|
||||
|
||||
def path_classification_notebooks():
|
||||
|
@ -279,7 +280,7 @@ def tiny_ic_databunch(tmp_session):
|
|||
.split_by_rand_pct(valid_pct=0.1, seed=20)
|
||||
.label_from_folder()
|
||||
.transform(size=50)
|
||||
.databunch(bs=16, num_workers = db_num_workers())
|
||||
.databunch(bs=16, num_workers=db_num_workers())
|
||||
.normalize(imagenet_stats)
|
||||
)
|
||||
|
||||
|
@ -351,7 +352,7 @@ def testing_databunch(tmp_session):
|
|||
.split_by_rand_pct(valid_pct=0.2, seed=20)
|
||||
.label_from_folder()
|
||||
.transform(size=300)
|
||||
.databunch(bs=16, num_workers = db_num_workers())
|
||||
.databunch(bs=16, num_workers=db_num_workers())
|
||||
.normalize(imagenet_stats)
|
||||
)
|
||||
|
||||
|
@ -735,6 +736,7 @@ def workspace_region(request):
|
|||
|
||||
# ------|-- Similarity ---------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tiny_is_data_path(tmp_session) -> str:
|
||||
""" Returns the path to the tiny fridge objects dataset. """
|
||||
|
@ -743,4 +745,14 @@ def tiny_is_data_path(tmp_session) -> str:
|
|||
fpath=tmp_session,
|
||||
dest=tmp_session,
|
||||
exist_ok=True,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tiny_ic_databunch_valid_features(tiny_ic_databunch):
|
||||
learn = cnn_learner(tiny_ic_databunch, models.resnet18)
|
||||
embedding_layer = learn.model[1][6]
|
||||
features = compute_features_learner(
|
||||
tiny_ic_databunch, DatasetType.Valid, learn, embedding_layer
|
||||
)
|
||||
return features
|
||||
|
|
|
@ -8,6 +8,7 @@ from pytest import approx
|
|||
from utils_cv.similarity.data import comparative_set_builder
|
||||
from utils_cv.similarity.metrics import (
|
||||
compute_distances,
|
||||
evaluate,
|
||||
positive_image_ranks,
|
||||
recall_at_k,
|
||||
vector_distance,
|
||||
|
@ -64,3 +65,31 @@ def test_recall_at_k():
|
|||
assert recall_at_k(rank_list, 3) == 60
|
||||
assert recall_at_k(rank_list, 6) == 100
|
||||
assert recall_at_k(rank_list, 10) == 100
|
||||
|
||||
|
||||
def test_evaluate(tiny_ic_databunch, tiny_ic_databunch_valid_features):
|
||||
(rank_accs, mAP) = evaluate(
|
||||
tiny_ic_databunch.valid_ds,
|
||||
tiny_ic_databunch_valid_features,
|
||||
use_rerank=False,
|
||||
)
|
||||
assert 0 <= mAP <= 1.0
|
||||
assert len(rank_accs) == 6
|
||||
assert max(rank_accs) <= 1.001
|
||||
assert min(rank_accs) >= -0.001
|
||||
for i in range(len(rank_accs) - 1):
|
||||
rank_accs[i] <= rank_accs[i + 1]
|
||||
|
||||
(rank_accs, ap) = evaluate(
|
||||
tiny_ic_databunch.valid_ds,
|
||||
tiny_ic_databunch_valid_features,
|
||||
use_rerank=True,
|
||||
rerank_k1=2,
|
||||
rerank_k2=3,
|
||||
)
|
||||
assert 0 <= mAP <= 1.0
|
||||
assert len(rank_accs) == 6
|
||||
assert max(rank_accs) <= 1.001
|
||||
assert min(rank_accs) >= -0.001
|
||||
for i in range(len(rank_accs) - 1):
|
||||
rank_accs[i] <= rank_accs[i + 1]
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import List
|
||||
|
||||
from typing import Dict, List
|
||||
import numpy as np
|
||||
import scipy
|
||||
|
||||
from fastai.vision import LabelList
|
||||
from .references.evaluate import evaluate_with_query_set
|
||||
|
||||
|
||||
def vector_distance(
|
||||
vec1: np.ndarray,
|
||||
|
@ -105,3 +106,52 @@ def recall_at_k(ranks: List[int], k: int) -> float:
|
|||
below_threshold = [x for x in ranks if x <= k]
|
||||
percent_in_top_k = round(100.0 * len(below_threshold) / len(ranks), 1)
|
||||
return percent_in_top_k
|
||||
|
||||
|
||||
def evaluate(
|
||||
data: LabelList,
|
||||
features: Dict[str, np.array],
|
||||
use_rerank=False,
|
||||
rerank_k1=20,
|
||||
rerank_k2=6,
|
||||
rerank_lambda=0.3,
|
||||
):
|
||||
"""
|
||||
Computes rank@1 through rank@10 accuracy as well as mAP, optionally with re-ranking
|
||||
post-processor to improve accuracy (see the re-ranking implementation for more info).
|
||||
|
||||
Args:
|
||||
data: Fastai's image labellist
|
||||
features: Dictionary of DNN features for each image
|
||||
use_rerank: use re-ranking
|
||||
rerank_k1, rerank_k2, rerank_lambda: re-ranking parameters
|
||||
Returns:
|
||||
rank_accs: accuracy at rank1 through rank10
|
||||
mAP: average precision
|
||||
|
||||
"""
|
||||
|
||||
labels = np.array([data.y[i].obj for i in range(len(data.y))])
|
||||
features = np.array([features[str(s)] for s in data.items])
|
||||
|
||||
# Assign each image into its own group. This serves as id during evaluation to
|
||||
# ensure a query image is not compared to itself during rank computation.
|
||||
# For the market-1501 dataset, the group ids can be used to ensure that a query
|
||||
# can not match to an image taken from the same camera.
|
||||
groups = np.array(range(len(labels)))
|
||||
assert len(labels) == len(groups) == features.shape[0]
|
||||
|
||||
# Run evaluation
|
||||
rank_accs, mAP = evaluate_with_query_set(
|
||||
labels,
|
||||
groups,
|
||||
features,
|
||||
labels,
|
||||
groups,
|
||||
features,
|
||||
use_rerank,
|
||||
rerank_k1,
|
||||
rerank_k2,
|
||||
rerank_lambda,
|
||||
)
|
||||
return rank_accs, mAP
|
||||
|
|
|
@ -0,0 +1,141 @@
|
|||
# Most of the code in this file is copied and slightly modified from:
|
||||
# https://github.com/layumi/Person_reID_baseline_pytorch/blob/master/evaluate.py
|
||||
|
||||
import numpy as np
|
||||
import time
|
||||
import torch
|
||||
|
||||
from .re_ranking import re_ranking
|
||||
|
||||
|
||||
# Note: the Market1501 dataset has a slightly different evaluation procedure which can be used
|
||||
# by setting is_market1501=True.
|
||||
def evaluate_with_query_set(
|
||||
gallery_labels,
|
||||
gallery_groups,
|
||||
gallery_features,
|
||||
query_labels,
|
||||
query_groups,
|
||||
query_features,
|
||||
use_rerank=False,
|
||||
rerank_k1=20,
|
||||
rerank_k2=6,
|
||||
rerank_lambda=0.3,
|
||||
is_market1501=False,
|
||||
):
|
||||
|
||||
# Init
|
||||
ap = 0.0
|
||||
CMC = torch.IntTensor(len(gallery_labels)).zero_()
|
||||
|
||||
# Compute pairwise distance
|
||||
q_g_dist = np.dot(query_features, np.transpose(gallery_features))
|
||||
|
||||
# Improve pairwise distances using re-ranking
|
||||
if use_rerank:
|
||||
print("Calculate re-ranked distances..")
|
||||
q_q_dist = np.dot(query_features, np.transpose(query_features))
|
||||
g_g_dist = np.dot(gallery_features, np.transpose(gallery_features))
|
||||
since = time.time()
|
||||
distances = re_ranking(
|
||||
q_g_dist, q_q_dist, g_g_dist, k1=rerank_k1, k2=rerank_k2, lambda_value=rerank_lambda,
|
||||
)
|
||||
time_elapsed = time.time() - since
|
||||
print(
|
||||
"Reranking complete in {:.0f}m {:.0f}s".format(
|
||||
time_elapsed // 60, time_elapsed % 60
|
||||
)
|
||||
)
|
||||
else:
|
||||
distances = -q_g_dist
|
||||
|
||||
# Compute accuracies
|
||||
norm = 0
|
||||
skip = 1 # set to >1 to only consider a subset of the query images
|
||||
for i in range(len(query_labels))[::skip]:
|
||||
ap_tmp, CMC_tmp = evaluate_helper(
|
||||
distances[i, :],
|
||||
query_labels[i],
|
||||
query_groups[i],
|
||||
gallery_labels,
|
||||
gallery_groups,
|
||||
is_market1501,
|
||||
)
|
||||
if CMC_tmp[0] == -1:
|
||||
continue
|
||||
norm += 1
|
||||
ap += ap_tmp
|
||||
CMC = CMC + CMC_tmp
|
||||
|
||||
# Print accuracy. Note that Market1501 normalizes by dividing over number of query images.
|
||||
if is_market1501:
|
||||
norm = len(query_labels) / float(skip)
|
||||
ap = ap / norm
|
||||
CMC = CMC.float()
|
||||
CMC = CMC / norm
|
||||
print(
|
||||
"Rank@1:{:.1f}, rank@5:{:.1f}, mAP:{:.2f}".format(100 * CMC[0], 100 * CMC[4], ap)
|
||||
)
|
||||
|
||||
return (CMC, ap)
|
||||
|
||||
|
||||
# Explanation:
|
||||
# - query_index: all images in the reference set with the same label as the query image ("true match")
|
||||
# - camera_index: all images which share the same group (called "camera" since the code was originally written for the Market-1501 dataset).
|
||||
# - junk_index2: all reference images with the same group ("camera") as the query are considered "false matches".
|
||||
# - junk_index1: for the market1501 dataset, images with label -1 should be ignored.
|
||||
def evaluate_helper(score, ql, qc, gl, gc, is_market1501=False):
|
||||
assert type(gl) == np.ndarray, "Input gl has to be a numpy ndarray"
|
||||
assert type(gc) == np.ndarray, "Input gc has to be a numpy ndarray"
|
||||
|
||||
# Sort scores
|
||||
index = np.argsort(score) # from small to large
|
||||
|
||||
# Compare reference images to the query image.
|
||||
query_index = np.argwhere(gl == ql)
|
||||
camera_index = np.argwhere(gc == qc)
|
||||
good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
|
||||
junk_index2 = np.intersect1d(query_index, camera_index)
|
||||
|
||||
# For market 1501 dataset, ignore images with label -1
|
||||
if is_market1501:
|
||||
junk_index1a = np.argwhere(gl == -1)
|
||||
junk_index1b = np.argwhere(gl == "-1")
|
||||
junk_index1 = np.append(junk_index1a, junk_index1b)
|
||||
junk_index = np.append(junk_index2, junk_index1)
|
||||
else:
|
||||
junk_index = junk_index2
|
||||
|
||||
CMC_tmp = compute_mAP(index, good_index, junk_index)
|
||||
return CMC_tmp
|
||||
|
||||
|
||||
def compute_mAP(index, good_index, junk_index):
|
||||
ap = 0
|
||||
cmc = torch.IntTensor(len(index)).zero_()
|
||||
if good_index.size == 0: # if empty
|
||||
cmc[0] = -1
|
||||
return ap, cmc
|
||||
|
||||
# remove junk_index
|
||||
mask = np.in1d(index, junk_index, invert=True)
|
||||
index = index[mask]
|
||||
|
||||
# find good_index index
|
||||
ngood = len(good_index)
|
||||
mask = np.in1d(index, good_index)
|
||||
rows_good = np.argwhere(mask) # == True)
|
||||
rows_good = rows_good.flatten()
|
||||
|
||||
cmc[rows_good[0] :] = 1
|
||||
for i in range(ngood):
|
||||
d_recall = 1.0 / ngood
|
||||
precision = (i + 1) * 1.0 / (rows_good[i] + 1)
|
||||
if rows_good[i] != 0:
|
||||
old_precision = i * 1.0 / rows_good[i]
|
||||
else:
|
||||
old_precision = 1.0
|
||||
ap = ap + d_recall * (old_precision + precision) / 2
|
||||
|
||||
return ap, cmc
|
|
@ -0,0 +1,89 @@
|
|||
# This code is copied (without modification) from:
|
||||
# https://github.com/layumi/Person_reID_baseline_pytorch/blob/master/re_ranking.py
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
"""
|
||||
CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
|
||||
url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
|
||||
Matlab version: https://github.com/zhunzhong07/person-re-ranking
|
||||
"""
|
||||
|
||||
"""
|
||||
API
|
||||
q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery]
|
||||
q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query]
|
||||
g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery]
|
||||
k1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3)
|
||||
Returns:
|
||||
final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery]
|
||||
"""
|
||||
def k_reciprocal_neigh( initial_rank, i, k1):
|
||||
forward_k_neigh_index = initial_rank[i,:k1+1]
|
||||
backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1]
|
||||
fi = np.where(backward_k_neigh_index==i)[0]
|
||||
return forward_k_neigh_index[fi]
|
||||
|
||||
|
||||
def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3):
|
||||
# The following naming, e.g. gallery_num, is different from outer scope.
|
||||
# Don't care about it.
|
||||
original_dist = np.concatenate(
|
||||
[np.concatenate([q_q_dist, q_g_dist], axis=1),
|
||||
np.concatenate([q_g_dist.T, g_g_dist], axis=1)],
|
||||
axis=0)
|
||||
original_dist = 2. - 2 * original_dist # change the cosine similarity metric to euclidean similarity metric
|
||||
original_dist = np.power(original_dist, 2).astype(np.float32)
|
||||
original_dist = np.transpose(1. * original_dist/np.max(original_dist,axis = 0))
|
||||
V = np.zeros_like(original_dist).astype(np.float32)
|
||||
#initial_rank = np.argsort(original_dist).astype(np.int32)
|
||||
# top K1+1
|
||||
initial_rank = np.argpartition( original_dist, range(1,k1+1) )
|
||||
|
||||
query_num = q_g_dist.shape[0]
|
||||
all_num = original_dist.shape[0]
|
||||
|
||||
for i in range(all_num):
|
||||
# k-reciprocal neighbors
|
||||
k_reciprocal_index = k_reciprocal_neigh( initial_rank, i, k1)
|
||||
k_reciprocal_expansion_index = k_reciprocal_index
|
||||
for j in range(len(k_reciprocal_index)):
|
||||
candidate = k_reciprocal_index[j]
|
||||
candidate_k_reciprocal_index = k_reciprocal_neigh( initial_rank, candidate, int(np.around(k1/2)))
|
||||
if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index):
|
||||
k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index)
|
||||
|
||||
k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
|
||||
weight = np.exp(-original_dist[i,k_reciprocal_expansion_index])
|
||||
V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight)
|
||||
|
||||
original_dist = original_dist[:query_num,]
|
||||
if k2 != 1:
|
||||
V_qe = np.zeros_like(V,dtype=np.float32)
|
||||
for i in range(all_num):
|
||||
V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0)
|
||||
V = V_qe
|
||||
del V_qe
|
||||
del initial_rank
|
||||
invIndex = []
|
||||
for i in range(all_num):
|
||||
invIndex.append(np.where(V[:,i] != 0)[0])
|
||||
|
||||
jaccard_dist = np.zeros_like(original_dist,dtype = np.float32)
|
||||
|
||||
for i in range(query_num):
|
||||
temp_min = np.zeros(shape=[1,all_num],dtype=np.float32)
|
||||
indNonZero = np.where(V[i,:] != 0)[0]
|
||||
indImages = []
|
||||
indImages = [invIndex[ind] for ind in indNonZero]
|
||||
for j in range(len(indNonZero)):
|
||||
temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]])
|
||||
jaccard_dist[i] = 1-temp_min/(2.-temp_min)
|
||||
|
||||
final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value
|
||||
del original_dist
|
||||
del V
|
||||
del jaccard_dist
|
||||
final_dist = final_dist[:query_num,query_num:]
|
||||
return final_dist
|
Загрузка…
Ссылка в новой задаче