New python interface, build setup, apps and unit tests (#308)
--------- Co-authored-by: Dax Pryce <daxpryce@microsoft.com> * Adding some diagnostics to a pr build in an attempt to see what is going on with our systems prior to running our streaming/incremental tests * fix cast error and add some status prints to in-mem-dynamic app * Adding unit tests for both memory and disk index builder methods * After the refactor and polish of the API was left half done, I also left half a jillion bugs in the library. At least I'm confident that build_memory_index and StaticMemoryIndex work in some cases, whereas before they barely were getting off the ground * Sanity checks of static index (not comprehensive coverage), and tombstone file for test_dynamic_memory_index * Argument range checks of some of the static memory index values. * fixes for dynamic index in python interface (#334) * create separate default number of frozen points for dynamic indices * consolidate works * remove superfluous param from dynamic index * remove superfluous param from dynamic index * batch insert and args modification to apps * batch insert and args modification to apps * typo * Committing the updated unit tests. At least the initial sanity checks of StaticMemory are done * Fixing an error in the static memory index ctor * Formatting python with black * Have to disable initial load with DynamicMemoryIndex, as there is no way to build a memory index with an associated tags file yet, making it impossible to load an index without tags * Working on unit tests and need to pull harsha's changes * I think I aligned this such that we can execute it via command line with the right behaviors * Providing rest of parameters build_memory_index requires * For some reason argparse is allowing a bunch of blank space to come in on arguments and they need stripped. It also needs to be using the right types. * Recall test now works * More unit tests for dynamic memory index * Adding different range check for alpha, as the values are only really that realistic between 1 and 2. Below 1 is an error, and above 2 we'll probably make a warning going forward * Storing this while I cut a new branch and walk back some work for a future branch * Undoing the auto load of the dynamic index until I can debug why my tag vector files cause an error in diskann * Updating the documentation for the python bindings. It's a lot closer than it was. * Fixing a unit test * add timers to dyanmic apps (#337) * add timers to dyanmic apps * clang format * np.uintc vs. int for dtype of tags * fixes to types in dynamic app * cast tags to np.uintc array * more timers * added example code in comments in app file * round elapsed * fix typo * fix typo --------- Co-authored-by: Harsha Vardhan Simhadri <harsha-simhadri@users.noreply.github.com> Co-authored-by: harsha vardhan simhadri <harsha.v.simhadri@gmail.com>
This commit is contained in:
Родитель
45a54090d7
Коммит
38d8c44cd5
|
@ -242,8 +242,10 @@ else()
|
|||
endif()
|
||||
|
||||
add_subdirectory(src)
|
||||
add_subdirectory(tests)
|
||||
add_subdirectory(tests/utils)
|
||||
if (NOT PYBIND)
|
||||
add_subdirectory(tests)
|
||||
add_subdirectory(tests/utils)
|
||||
endif()
|
||||
|
||||
if (MSVC)
|
||||
message(STATUS "The ${PROJECT_NAME}.sln has been created, opened it from VisualStudio to build Release or Debug configurations.\n"
|
||||
|
|
|
@ -10,4 +10,3 @@ recursive-include python *
|
|||
recursive-include windows *
|
||||
prune python/tests
|
||||
recursive-include src *
|
||||
recursive-include tests *
|
||||
|
|
|
@ -10,10 +10,10 @@ namespace defaults
|
|||
{
|
||||
const float ALPHA = 1.2f;
|
||||
const uint32_t NUM_THREADS = 0;
|
||||
const uint32_t NUM_ROUNDS = 2;
|
||||
const uint32_t MAX_OCCLUSION_SIZE = 750;
|
||||
const uint32_t FILTER_LIST_SIZE = 0;
|
||||
const uint32_t NUM_FROZEN_POINTS = 0;
|
||||
const uint32_t NUM_FROZEN_POINTS_STATIC = 0;
|
||||
const uint32_t NUM_FROZEN_POINTS_DYNAMIC = 1;
|
||||
// following constants should always be specified, but are useful as a
|
||||
// sensible default at cli / python boundaries
|
||||
const uint32_t MAX_DEGREE = 64;
|
||||
|
|
|
@ -106,14 +106,15 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
|
|||
|
||||
// Batch build from a file. Optionally pass tags vector.
|
||||
DISKANN_DLLEXPORT void build(const char *filename, const size_t num_points_to_load,
|
||||
IndexWriteParameters ¶meters, const std::vector<TagT> &tags = std::vector<TagT>());
|
||||
const IndexWriteParameters ¶meters,
|
||||
const std::vector<TagT> &tags = std::vector<TagT>());
|
||||
|
||||
// Batch build from a file. Optionally pass tags file.
|
||||
DISKANN_DLLEXPORT void build(const char *filename, const size_t num_points_to_load,
|
||||
IndexWriteParameters ¶meters, const char *tag_filename);
|
||||
const IndexWriteParameters ¶meters, const char *tag_filename);
|
||||
|
||||
// Batch build from a data array, which must pad vectors to aligned_dim
|
||||
DISKANN_DLLEXPORT void build(const T *data, const size_t num_points_to_load, IndexWriteParameters ¶meters,
|
||||
DISKANN_DLLEXPORT void build(const T *data, const size_t num_points_to_load, const IndexWriteParameters ¶meters,
|
||||
const std::vector<TagT> &tags);
|
||||
|
||||
// Filtered Support
|
||||
|
@ -215,7 +216,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
|
|||
|
||||
// Use after _data and _nd have been populated
|
||||
// Acquire exclusive _update_lock before calling
|
||||
void build_with_data_populated(IndexWriteParameters ¶meters, const std::vector<TagT> &tags);
|
||||
void build_with_data_populated(const IndexWriteParameters ¶meters, const std::vector<TagT> &tags);
|
||||
|
||||
// generates 1 frozen point that will never be deleted from the graph
|
||||
// This is not visible to the user
|
||||
|
@ -261,7 +262,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
|
|||
void inter_insert(uint32_t n, std::vector<uint32_t> &pruned_list, InMemQueryScratch<T> *scratch);
|
||||
|
||||
// Acquire exclusive _update_lock before calling
|
||||
void link(IndexWriteParameters ¶meters);
|
||||
void link(const IndexWriteParameters ¶meters);
|
||||
|
||||
// Acquire exclusive _tag_lock and _delete_lock before calling
|
||||
int reserve_location();
|
||||
|
|
|
@ -20,17 +20,16 @@ class IndexWriteParameters
|
|||
const bool saturate_graph;
|
||||
const uint32_t max_occlusion_size; // C
|
||||
const float alpha;
|
||||
const uint32_t num_rounds;
|
||||
const uint32_t num_threads;
|
||||
const uint32_t filter_list_size; // Lf
|
||||
const uint32_t num_frozen_points;
|
||||
|
||||
private:
|
||||
IndexWriteParameters(const uint32_t search_list_size, const uint32_t max_degree, const bool saturate_graph,
|
||||
const uint32_t max_occlusion_size, const float alpha, const uint32_t num_rounds,
|
||||
const uint32_t num_threads, const uint32_t filter_list_size, const uint32_t num_frozen_points)
|
||||
const uint32_t max_occlusion_size, const float alpha, const uint32_t num_threads,
|
||||
const uint32_t filter_list_size, const uint32_t num_frozen_points)
|
||||
: search_list_size(search_list_size), max_degree(max_degree), saturate_graph(saturate_graph),
|
||||
max_occlusion_size(max_occlusion_size), alpha(alpha), num_rounds(num_rounds), num_threads(num_threads),
|
||||
max_occlusion_size(max_occlusion_size), alpha(alpha), num_threads(num_threads),
|
||||
filter_list_size(filter_list_size), num_frozen_points(num_frozen_points)
|
||||
{
|
||||
}
|
||||
|
@ -70,21 +69,15 @@ class IndexWriteParametersBuilder
|
|||
return *this;
|
||||
}
|
||||
|
||||
IndexWriteParametersBuilder &with_num_rounds(const uint32_t num_rounds)
|
||||
{
|
||||
_num_rounds = num_rounds;
|
||||
return *this;
|
||||
}
|
||||
|
||||
IndexWriteParametersBuilder &with_num_threads(const uint32_t num_threads)
|
||||
{
|
||||
_num_threads = num_threads;
|
||||
_num_threads = num_threads == 0 ? omp_get_num_threads() : num_threads;
|
||||
return *this;
|
||||
}
|
||||
|
||||
IndexWriteParametersBuilder &with_filter_list_size(const uint32_t filter_list_size)
|
||||
{
|
||||
_filter_list_size = filter_list_size;
|
||||
_filter_list_size = filter_list_size == 0 ? _search_list_size : filter_list_size;
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
@ -97,13 +90,13 @@ class IndexWriteParametersBuilder
|
|||
IndexWriteParameters build() const
|
||||
{
|
||||
return IndexWriteParameters(_search_list_size, _max_degree, _saturate_graph, _max_occlusion_size, _alpha,
|
||||
_num_rounds, _num_threads, _filter_list_size, _num_frozen_points);
|
||||
_num_threads, _filter_list_size, _num_frozen_points);
|
||||
}
|
||||
|
||||
IndexWriteParametersBuilder(const IndexWriteParameters &wp)
|
||||
: _search_list_size(wp.search_list_size), _max_degree(wp.max_degree),
|
||||
_max_occlusion_size(wp.max_occlusion_size), _saturate_graph(wp.saturate_graph), _alpha(wp.alpha),
|
||||
_num_rounds(wp.num_rounds), _filter_list_size(wp.filter_list_size), _num_frozen_points(wp.num_frozen_points)
|
||||
_filter_list_size(wp.filter_list_size), _num_frozen_points(wp.num_frozen_points)
|
||||
{
|
||||
}
|
||||
IndexWriteParametersBuilder(const IndexWriteParametersBuilder &) = delete;
|
||||
|
@ -115,10 +108,9 @@ class IndexWriteParametersBuilder
|
|||
uint32_t _max_occlusion_size{defaults::MAX_OCCLUSION_SIZE};
|
||||
bool _saturate_graph{defaults::SATURATE_GRAPH};
|
||||
float _alpha{defaults::ALPHA};
|
||||
uint32_t _num_rounds{defaults::NUM_ROUNDS};
|
||||
uint32_t _num_threads{defaults::NUM_THREADS};
|
||||
uint32_t _filter_list_size{defaults::FILTER_LIST_SIZE};
|
||||
uint32_t _num_frozen_points{defaults::NUM_FROZEN_POINTS};
|
||||
uint32_t _num_frozen_points{defaults::NUM_FROZEN_POINTS_STATIC};
|
||||
};
|
||||
|
||||
} // namespace diskann
|
||||
|
|
|
@ -5,6 +5,7 @@ requires = [
|
|||
"cmake>=3.22",
|
||||
"numpy>=1.21",
|
||||
"wheel",
|
||||
"ninja"
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import argparse
|
||||
import utils
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="cluster", description="kmeans cluster points in a file"
|
||||
)
|
||||
|
||||
parser.add_argument("-d", "--data_type", required=True)
|
||||
parser.add_argument("-i", "--indexdata_file", required=True)
|
||||
parser.add_argument("-k", "--num_clusters", type=int, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
npts, ndims = get_bin_metadata(indexdata_file)
|
||||
|
||||
data = utils.bin_to_numpy(args.data_type, args.indexdata_file)
|
||||
|
||||
offsets, permutation = utils.cluster_and_permute(
|
||||
args.data_type, npts, ndims, data, args.num_clusters
|
||||
)
|
||||
|
||||
permuted_data = data[permutation]
|
||||
|
||||
utils.numpy_to_bin(permuted_data, args.indexdata_file + ".cluster")
|
|
@ -0,0 +1,112 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import argparse
|
||||
|
||||
import diskannpy
|
||||
import numpy as np
|
||||
import utils
|
||||
|
||||
def insert_and_search(
|
||||
dtype_str,
|
||||
indexdata_file,
|
||||
querydata_file,
|
||||
Lb,
|
||||
graph_degree,
|
||||
K,
|
||||
Ls,
|
||||
num_insert_threads,
|
||||
num_search_threads,
|
||||
gt_file,
|
||||
):
|
||||
npts, ndims = utils.get_bin_metadata(indexdata_file)
|
||||
|
||||
if dtype_str == "float":
|
||||
index = diskannpy.DynamicMemoryIndex(
|
||||
"l2", np.float32, ndims, npts, Lb, graph_degree
|
||||
)
|
||||
queries = utils.bin_to_numpy(np.float32, querydata_file)
|
||||
data = utils.bin_to_numpy(np.float32, indexdata_file)
|
||||
elif dtype_str == "int8":
|
||||
index = diskannpy.DynamicMemoryIndex(
|
||||
"l2", np.int8, ndims, npts, Lb, graph_degree
|
||||
)
|
||||
queries = utils.bin_to_numpy(np.int8, querydata_file)
|
||||
data = utils.bin_to_numpy(np.int8, indexdata_file)
|
||||
elif dtype_str == "uint8":
|
||||
index = diskannpy.DynamicMemoryIndex(
|
||||
"l2", np.uint8, ndims, npts, Lb, graph_degree
|
||||
)
|
||||
queries = utils.bin_to_numpy(np.uint8, querydata_file)
|
||||
data = utils.bin_to_numpy(np.uint8, indexdata_file)
|
||||
else:
|
||||
raise ValueError("data_type must be float, int8 or uint8")
|
||||
|
||||
tags = np.zeros(npts, dtype=np.uintc)
|
||||
timer = utils.timer()
|
||||
for i in range(npts):
|
||||
tags[i] = i + 1
|
||||
index.batch_insert(data, tags, num_insert_threads)
|
||||
print('batch_insert complete in', timer.elapsed(), 's')
|
||||
|
||||
delete_tags = np.random.choice(
|
||||
np.array(range(1, npts + 1, 1), dtype=np.uintc),
|
||||
size=int(0.5 * npts),
|
||||
replace=False
|
||||
)
|
||||
for tag in delete_tags:
|
||||
index.mark_deleted(tag)
|
||||
print('mark deletion completed in', timer.elapsed(), 's')
|
||||
|
||||
index.consolidate_delete()
|
||||
print('consolidation completed in', timer.elapsed(), 's')
|
||||
|
||||
deleted_data = data[delete_tags - 1, :]
|
||||
|
||||
index.batch_insert(deleted_data, delete_tags, num_insert_threads)
|
||||
print('re-insertion completed in', timer.elapsed(), 's')
|
||||
|
||||
tags, dists = index.batch_search(queries, K, Ls, num_search_threads)
|
||||
print('Batch searched', queries.shape[0], ' queries in ', timer.elapsed(), 's')
|
||||
|
||||
res_ids = tags - 1
|
||||
if gt_file != "":
|
||||
recall = utils.calculate_recall_from_gt_file(K, res_ids, gt_file)
|
||||
print(f"recall@{K} is {recall}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="in-mem-dynamic",
|
||||
description="Inserts points dynamically in a clustered order and search from vectors in a file.",
|
||||
)
|
||||
|
||||
parser.add_argument("-d", "--data_type", required=True)
|
||||
parser.add_argument("-i", "--indexdata_file", required=True)
|
||||
parser.add_argument("-q", "--querydata_file", required=True)
|
||||
parser.add_argument("-Lb", "--Lbuild", default=50, type=int)
|
||||
parser.add_argument("-Ls", "--Lsearch", default=50, type=int)
|
||||
parser.add_argument("-R", "--graph_degree", default=32, type=int)
|
||||
parser.add_argument("-TI", "--num_insert_threads", default=8, type=int)
|
||||
parser.add_argument("-TS", "--num_search_threads", default=8, type=int)
|
||||
parser.add_argument("-K", default=10, type=int)
|
||||
parser.add_argument("--gt_file", default="")
|
||||
args = parser.parse_args()
|
||||
|
||||
insert_and_search(
|
||||
args.data_type,
|
||||
args.indexdata_file,
|
||||
args.querydata_file,
|
||||
args.Lbuild,
|
||||
args.graph_degree, # Build args
|
||||
args.K,
|
||||
args.Lsearch,
|
||||
args.num_insert_threads,
|
||||
args.num_search_threads, # search args
|
||||
args.gt_file,
|
||||
)
|
||||
|
||||
# An ingest optimized example with SIFT1M
|
||||
# python3 ~/DiskANN/python/apps/in-mem-dynamic.py -d float \
|
||||
# -i sift_base.fbin -q sift_query.fbin --gt_file gt100_base \
|
||||
# -Lb 10 -R 30 -Ls 200
|
|
@ -0,0 +1,101 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import argparse
|
||||
from xml.dom.pulldom import default_bufsize
|
||||
|
||||
import diskannpy
|
||||
import numpy as np
|
||||
import utils
|
||||
|
||||
|
||||
def build_and_search(
|
||||
dtype_str,
|
||||
index_directory,
|
||||
indexdata_file,
|
||||
querydata_file,
|
||||
Lb,
|
||||
graph_degree,
|
||||
K,
|
||||
Ls,
|
||||
num_threads,
|
||||
gt_file,
|
||||
index_prefix
|
||||
):
|
||||
if dtype_str == "float":
|
||||
dtype = np.single
|
||||
elif dtype_str == "int8":
|
||||
dtype = np.byte
|
||||
elif dtype_str == "uint8":
|
||||
dtype = np.ubyte
|
||||
else:
|
||||
raise ValueError("data_type must be float, int8 or uint8")
|
||||
|
||||
# build index
|
||||
diskannpy.build_memory_index(
|
||||
data=indexdata_file,
|
||||
metric="l2",
|
||||
vector_dtype=dtype,
|
||||
index_directory=index_directory,
|
||||
complexity=Lb,
|
||||
graph_degree=graph_degree,
|
||||
num_threads=num_threads,
|
||||
index_prefix=index_prefix,
|
||||
alpha=1.2,
|
||||
use_pq_build=False,
|
||||
num_pq_bytes=8,
|
||||
use_opq=False,
|
||||
)
|
||||
|
||||
# ready search object
|
||||
index = diskannpy.StaticMemoryIndex(
|
||||
metric="l2",
|
||||
vector_dtype=dtype,
|
||||
data_path=indexdata_file,
|
||||
index_directory=index_directory,
|
||||
num_threads=num_threads, # this can be different at search time if you would like
|
||||
initial_search_complexity=Ls,
|
||||
index_prefix=index_prefix
|
||||
)
|
||||
|
||||
queries = utils.bin_to_numpy(dtype, querydata_file)
|
||||
|
||||
ids, dists = index.batch_search(queries, 10, Ls, num_threads)
|
||||
|
||||
if gt_file != "":
|
||||
recall = utils.calculate_recall_from_gt_file(K, ids, gt_file)
|
||||
print(f"recall@{K} is {recall}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="in-mem-static",
|
||||
description="Static in-memory build and search from vectors in a file",
|
||||
)
|
||||
|
||||
parser.add_argument("-d", "--data_type", required=True)
|
||||
parser.add_argument("-id", "--index_directory", required=False, default=".")
|
||||
parser.add_argument("-i", "--indexdata_file", required=True)
|
||||
parser.add_argument("-q", "--querydata_file", required=True)
|
||||
parser.add_argument("-Lb", "--Lbuild", default=50, type=int)
|
||||
parser.add_argument("-Ls", "--Lsearch", default=50, type=int)
|
||||
parser.add_argument("-R", "--graph_degree", default=32, type=int)
|
||||
parser.add_argument("-T", "--num_threads", default=8, type=int)
|
||||
parser.add_argument("-K", default=10, type=int)
|
||||
parser.add_argument("--gt_file", default="")
|
||||
parser.add_argument("-ip", "--index_prefix", required=False, default="ann")
|
||||
args = parser.parse_args()
|
||||
|
||||
build_and_search(
|
||||
args.data_type,
|
||||
args.index_directory.strip(),
|
||||
args.indexdata_file.strip(),
|
||||
args.querydata_file.strip(),
|
||||
args.Lbuild,
|
||||
args.graph_degree, # Build args
|
||||
args.K,
|
||||
args.Lsearch,
|
||||
args.num_threads, # search args
|
||||
args.gt_file,
|
||||
args.index_prefix
|
||||
)
|
|
@ -0,0 +1,104 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import argparse
|
||||
|
||||
import diskannpy
|
||||
import numpy as np
|
||||
import utils
|
||||
|
||||
|
||||
def insert_and_search(
|
||||
dtype_str,
|
||||
indexdata_file,
|
||||
querydata_file,
|
||||
Lb,
|
||||
graph_degree,
|
||||
num_clusters,
|
||||
num_insert_threads,
|
||||
K,
|
||||
Ls,
|
||||
num_search_threads,
|
||||
gt_file,
|
||||
):
|
||||
npts, ndims = utils.get_bin_metadata(indexdata_file)
|
||||
|
||||
if dtype_str == "float":
|
||||
index = diskannpy.DynamicMemoryIndex(
|
||||
"l2", np.float32, ndims, npts, Lb, graph_degree, False
|
||||
)
|
||||
queries = utils.bin_to_numpy(np.float32, querydata_file)
|
||||
data = utils.bin_to_numpy(np.float32, indexdata_file)
|
||||
elif dtype_str == "int8":
|
||||
index = diskannpy.DynamicMemoryIndex(
|
||||
"l2", np.int8, ndims, npts, Lb, graph_degree
|
||||
)
|
||||
queries = utils.bin_to_numpy(np.int8, querydata_file)
|
||||
data = utils.bin_to_numpy(np.int8, indexdata_file)
|
||||
elif dtype_str == "uint8":
|
||||
index = diskannpy.DynamicMemoryIndex(
|
||||
"l2", np.uint8, ndims, npts, Lb, graph_degree
|
||||
)
|
||||
queries = utils.bin_to_numpy(np.uint8, querydata_file)
|
||||
data = utils.bin_to_numpy(np.uint8, indexdata_file)
|
||||
else:
|
||||
raise ValueError("data_type must be float, int8 or uint8")
|
||||
|
||||
offsets, permutation = utils.cluster_and_permute(
|
||||
dtype_str, npts, ndims, data, num_clusters
|
||||
)
|
||||
|
||||
i = 0
|
||||
timer = utils.timer()
|
||||
for c in range(num_clusters):
|
||||
cluster_index_range = range(offsets[c], offsets[c + 1])
|
||||
cluster_indices = np.array(permutation[cluster_index_range], dtype=np.uintc)
|
||||
cluster_data = data[cluster_indices, :]
|
||||
index.batch_insert(cluster_data, cluster_indices + 1, num_insert_threads)
|
||||
print('Inserted cluster', c, 'in', timer.elapsed(), 's')
|
||||
tags, dists = index.batch_search(queries, K, Ls, num_search_threads)
|
||||
print('Batch searched', queries.shape[0], 'queries in', timer.elapsed(), 's')
|
||||
res_ids = tags - 1
|
||||
|
||||
if gt_file != "":
|
||||
recall = utils.calculate_recall_from_gt_file(K, res_ids, gt_file)
|
||||
print(f"recall@{K} is {recall}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="in-mem-dynamic",
|
||||
description="Inserts points dynamically in a clustered order and search from vectors in a file.",
|
||||
)
|
||||
|
||||
parser.add_argument("-d", "--data_type", required=True)
|
||||
parser.add_argument("-i", "--indexdata_file", required=True)
|
||||
parser.add_argument("-q", "--querydata_file", required=True)
|
||||
parser.add_argument("-Lb", "--Lbuild", default=50, type=int)
|
||||
parser.add_argument("-Ls", "--Lsearch", default=50, type=int)
|
||||
parser.add_argument("-R", "--graph_degree", default=32, type=int)
|
||||
parser.add_argument("-TI", "--num_insert_threads", default=8, type=int)
|
||||
parser.add_argument("-TS", "--num_search_threads", default=8, type=int)
|
||||
parser.add_argument("-C", "--num_clusters", default=32, type=int)
|
||||
parser.add_argument("-K", default=10, type=int)
|
||||
parser.add_argument("--gt_file", default="")
|
||||
args = parser.parse_args()
|
||||
|
||||
insert_and_search(
|
||||
args.data_type,
|
||||
args.indexdata_file,
|
||||
args.querydata_file,
|
||||
args.Lbuild,
|
||||
args.graph_degree, # Build args
|
||||
args.num_clusters,
|
||||
args.num_insert_threads,
|
||||
args.K,
|
||||
args.Lsearch,
|
||||
args.num_search_threads, # search args
|
||||
args.gt_file,
|
||||
)
|
||||
|
||||
# An ingest optimized example with SIFT1M
|
||||
# python3 ~/DiskANN/python/apps/insert-in-clustered-order.py -d float \
|
||||
# -i sift_base.fbin -q sift_query.fbin --gt_file gt100_base \
|
||||
# -Lb 10 -R 30 -Ls 200 -C 32
|
|
@ -0,0 +1,115 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import numpy as np
|
||||
from scipy.cluster.vq import vq, kmeans2
|
||||
from typing import Tuple
|
||||
from time import perf_counter
|
||||
|
||||
|
||||
def get_bin_metadata(bin_file) -> Tuple[int, int]:
|
||||
array = np.fromfile(file=bin_file, dtype=np.uint32, count=2)
|
||||
return array[0], array[1]
|
||||
|
||||
|
||||
def bin_to_numpy(dtype, bin_file) -> np.ndarray:
|
||||
npts, ndims = get_bin_metadata(bin_file)
|
||||
return np.fromfile(file=bin_file, dtype=dtype, offset=8).reshape(npts, ndims)
|
||||
|
||||
class timer:
|
||||
last = perf_counter()
|
||||
|
||||
def elapsed(self, round_digit:int = 3):
|
||||
new = perf_counter()
|
||||
elapsed_time = new - self.last
|
||||
self.last = new
|
||||
return round(elapsed_time, round_digit)
|
||||
|
||||
|
||||
def numpy_to_bin(array, out_file):
|
||||
shape = np.shape(array)
|
||||
npts = shape[0].astype(np.uint32)
|
||||
ndims = shape[1].astype(np.uint32)
|
||||
f = open(out_file, "wb")
|
||||
f.write(npts.tobytes())
|
||||
f.write(ndims.tobytes())
|
||||
f.write(array.tobytes())
|
||||
f.close()
|
||||
|
||||
|
||||
def read_gt_file(gt_file) -> Tuple[np.ndarray[int], np.ndarray[float]]:
|
||||
"""
|
||||
Return ids and distances to queries
|
||||
"""
|
||||
nq, K = get_bin_metadata(gt_file)
|
||||
ids = np.fromfile(file=gt_file, dtype=np.uint32, offset=8, count=nq * K).reshape(
|
||||
nq, K
|
||||
)
|
||||
dists = np.fromfile(
|
||||
file=gt_file, dtype=np.float32, offset=8 + nq * K * 4, count=nq * K
|
||||
).reshape(nq, K)
|
||||
return ids, dists
|
||||
|
||||
|
||||
def calculate_recall(
|
||||
result_set_indices: np.ndarray[int],
|
||||
truth_set_indices: np.ndarray[int],
|
||||
recall_at: int = 5,
|
||||
) -> float:
|
||||
"""
|
||||
result_set_indices and truth_set_indices correspond by row index. the columns in each row contain the indices of
|
||||
the nearest neighbors, with result_set_indices being the approximate nearest neighbor results and truth_set_indices
|
||||
being the brute force nearest neighbor calculation via sklearn's NearestNeighbor class.
|
||||
:param result_set_indices:
|
||||
:param truth_set_indices:
|
||||
:param recall_at:
|
||||
:return:
|
||||
"""
|
||||
found = 0
|
||||
for i in range(0, result_set_indices.shape[0]):
|
||||
result_set_set = set(result_set_indices[i][0:recall_at])
|
||||
truth_set_set = set(truth_set_indices[i][0:recall_at])
|
||||
found += len(result_set_set.intersection(truth_set_set))
|
||||
return found / (result_set_indices.shape[0] * recall_at)
|
||||
|
||||
|
||||
def calculate_recall_from_gt_file(K: int, ids: np.ndarray[int], gt_file: str) -> float:
|
||||
"""
|
||||
Calculate recall from ids returned from search and those read from file
|
||||
"""
|
||||
gt_ids, gt_dists = read_gt_file(gt_file)
|
||||
return calculate_recall(ids, gt_ids, K)
|
||||
|
||||
|
||||
def cluster_and_permute(
|
||||
dtype_str, npts, ndims, data, num_clusters
|
||||
) -> Tuple[np.ndarray[int], np.ndarray[int]]:
|
||||
"""
|
||||
Cluster the data and return permutation of row indices
|
||||
that would group indices of the same cluster together
|
||||
"""
|
||||
sample_size = min(100000, npts)
|
||||
sample_indices = np.random.choice(range(npts), size=sample_size, replace=False)
|
||||
sampled_data = data[sample_indices, :]
|
||||
centroids, sample_labels = kmeans2(sampled_data, num_clusters, minit="++", iter=10)
|
||||
labels, dist = vq(data, centroids)
|
||||
|
||||
count = np.zeros(num_clusters)
|
||||
for i in range(npts):
|
||||
count[labels[i]] += 1
|
||||
print("Cluster counts")
|
||||
print(count)
|
||||
|
||||
offsets = np.zeros(num_clusters + 1, dtype=int)
|
||||
for i in range(0, num_clusters, 1):
|
||||
offsets[i + 1] = offsets[i] + count[i]
|
||||
|
||||
permutation = np.zeros(npts, dtype=int)
|
||||
counters = np.zeros(num_clusters, dtype=int)
|
||||
for i in range(npts):
|
||||
label = labels[i]
|
||||
row = offsets[label] + counters[label]
|
||||
counters[label] += 1
|
||||
permutation[row] = i
|
||||
|
||||
return offsets, permutation
|
|
@ -1,6 +1,13 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from ._wrapper import (DiskIndex, VectorDType,
|
||||
build_disk_index_from_vector_file,
|
||||
build_disk_index_from_vectors, numpy_to_diskann_file)
|
||||
from ._builder import (
|
||||
build_disk_index,
|
||||
build_memory_index,
|
||||
numpy_to_diskann_file,
|
||||
)
|
||||
from ._common import VectorDType
|
||||
from ._disk_index import DiskIndex
|
||||
from ._diskannpy import INNER_PRODUCT, L2, Metric, defaults
|
||||
from ._dynamic_memory_index import DynamicMemoryIndex
|
||||
from ._static_memory_index import StaticMemoryIndex
|
||||
|
|
|
@ -0,0 +1,272 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import BinaryIO, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import _diskannpy as _native_dap
|
||||
from ._common import (
|
||||
VectorDType,
|
||||
_assert,
|
||||
_assert_2d,
|
||||
_assert_dtype,
|
||||
_assert_existing_file,
|
||||
_assert_is_nonnegative_uint32,
|
||||
_assert_is_positive_uint32,
|
||||
_get_valid_metric,
|
||||
)
|
||||
from ._diskannpy import defaults
|
||||
|
||||
|
||||
def numpy_to_diskann_file(vectors: np.ndarray, file_handler: BinaryIO):
|
||||
"""
|
||||
Utility function that writes a DiskANN binary vector formatted file to the location of your choosing.
|
||||
|
||||
:param vectors: A 2d array of dtype ``numpy.single``, ``numpy.ubyte``, or ``numpy.byte``
|
||||
:type vectors: numpy.ndarray, dtype in set {numpy.single, numpy.ubyte, numpy.byte}
|
||||
:param file_handler: An open binary file handler (typing.BinaryIO).
|
||||
:type file_handler: io.BinaryIO
|
||||
:raises ValueError: If vectors are the wrong shape or an unsupported dtype
|
||||
:raises ValueError: If output_path is not a str or ``io.BinaryIO``
|
||||
"""
|
||||
_assert_2d(vectors, "vectors")
|
||||
_assert_dtype(vectors.dtype, "vectors.dtype")
|
||||
|
||||
_ = file_handler.write(np.array(vectors.shape, dtype=np.intc).tobytes())
|
||||
_ = file_handler.write(vectors.tobytes())
|
||||
|
||||
|
||||
def _valid_path_and_dtype(
|
||||
data: Union[str, np.ndarray], vector_dtype: Optional[VectorDType], index_path: str
|
||||
) -> Tuple[str, VectorDType]:
|
||||
if isinstance(data, np.ndarray):
|
||||
_assert_2d(data, "data")
|
||||
_assert_dtype(data.dtype, "data.dtype")
|
||||
|
||||
vector_bin_path = os.path.join(index_path, "vectors.bin")
|
||||
if Path(vector_bin_path).exists():
|
||||
raise ValueError(
|
||||
f"The path {vector_bin_path} already exists. Remove it and try again."
|
||||
)
|
||||
with open(vector_bin_path, "wb") as temp_vector_bin:
|
||||
numpy_to_diskann_file(data, temp_vector_bin)
|
||||
vector_dtype_actual = data.dtype
|
||||
else:
|
||||
vector_bin_path = data
|
||||
_assert(
|
||||
Path(data).exists() and Path(data).is_file(),
|
||||
"if data is of type `str`, it must both exist and be a file",
|
||||
)
|
||||
vector_dtype_actual = vector_dtype
|
||||
return vector_bin_path, vector_dtype_actual
|
||||
|
||||
|
||||
def build_disk_index(
|
||||
data: Union[str, np.ndarray],
|
||||
metric: Literal["l2", "mips"],
|
||||
index_directory: str,
|
||||
complexity: int,
|
||||
graph_degree: int,
|
||||
search_memory_maximum: float,
|
||||
build_memory_maximum: float,
|
||||
num_threads: int,
|
||||
pq_disk_bytes: int = defaults.PQ_DISK_BYTES,
|
||||
vector_dtype: Optional[VectorDType] = None,
|
||||
index_prefix: str = "ann",
|
||||
):
|
||||
"""
|
||||
This function will construct a DiskANN Disk Index and save it to disk.
|
||||
|
||||
If you provide a numpy array, it will save this array to disk in a temp location
|
||||
in the format DiskANN's PQ Flash Index builder requires. This temp folder is deleted upon index creation completion
|
||||
or error.
|
||||
|
||||
:param data: Either a ``str`` representing a path to a DiskANN vector bin file, or a numpy.ndarray,
|
||||
of a supported dtype, in 2 dimensions. Note that vector_dtype must be provided if vector_path_or_np_array is a
|
||||
``str``
|
||||
:type data: Union[str, numpy.ndarray]
|
||||
:param metric: One of {"l2", "mips"}. L2 is supported for all 3 vector dtypes, but MIPS is only
|
||||
available for single point floating numbers (numpy.single)
|
||||
:type metric: str
|
||||
:param index_directory: The path on disk that the index will be created in.
|
||||
:type index_directory: str
|
||||
:param complexity: The size of queue to use when building the index for search. Values between 75 and 200 are
|
||||
typical. Larger values will take more time to build but result in indices that provide higher recall for
|
||||
the same search complexity. Use a value that is at least as large as R unless you are prepared to
|
||||
somewhat compromise on quality
|
||||
:type complexity: int
|
||||
:param graph_degree: The degree of the graph index, typically between 60 and 150. A larger maximum degree will
|
||||
result in larger indices and longer indexing times, but better search quality.
|
||||
:type graph_degree int
|
||||
:param search_memory_maximum: Build index with the expectation that the search will use at most
|
||||
``search_memory_maximum``
|
||||
:type search_memory_maximum: float
|
||||
:param build_memory_maximum: Build index using at most ``build_memory_maximum``
|
||||
:type build_memory_maximum: float
|
||||
:param num_threads: Number of threads to use when creating this index.0 indicates we should use all available
|
||||
system threads.
|
||||
:type num_threads: int
|
||||
:param pq_disk_bytes: Use 0 to store uncompressed data on SSD. This allows the index to asymptote to 100%
|
||||
recall. If your vectors are too large to store in SSD, this parameter provides the option to compress the
|
||||
vectors using PQ for storing on SSD. This will trade off recall. You would also want this to be greater
|
||||
than the number of bytes used for the PQ compressed data stored in-memory. Default is ``0``.
|
||||
:type pq_disk_bytes: int (default = 0)
|
||||
:param vector_dtype: Required if the provided ``vector_path_or_np_array`` is of type ``str``, else we use the
|
||||
``vector_path_or_np_array.dtype`` if np array.
|
||||
:type vector_dtype: Optional[VectorDType], default is ``None``.
|
||||
:param index_prefix: The prefix to give your index files. Defaults to ``ann``.
|
||||
:type index_prefix: str, default="ann"
|
||||
:raises ValueError: If vectors are not 2d numpy array or are not a supported dtype
|
||||
:raises ValueError: If any numeric value is in an invalid range
|
||||
"""
|
||||
|
||||
_assert(
|
||||
(isinstance(data, str) and vector_dtype is not None)
|
||||
or isinstance(data, np.ndarray),
|
||||
"vector_dtype is required if data is a str representing a path to the vector bin file",
|
||||
)
|
||||
dap_metric = _get_valid_metric(metric)
|
||||
_assert_is_positive_uint32(complexity, "complexity")
|
||||
_assert_is_positive_uint32(graph_degree, "graph_degree")
|
||||
_assert(search_memory_maximum > 0, "search_memory_maximum must be larger than 0")
|
||||
_assert(build_memory_maximum > 0, "build_memory_maximum must be larger than 0")
|
||||
_assert_is_nonnegative_uint32(num_threads, "num_threads")
|
||||
_assert_is_nonnegative_uint32(pq_disk_bytes, "pq_disk_bytes")
|
||||
_assert(index_prefix != "", "index_prefix cannot be an empty string")
|
||||
|
||||
index_path = Path(index_directory)
|
||||
_assert(
|
||||
index_path.exists() and index_path.is_dir(),
|
||||
"index_directory must both exist and be a directory",
|
||||
)
|
||||
|
||||
vector_bin_path, vector_dtype_actual = _valid_path_and_dtype(
|
||||
data, vector_dtype, index_prefix
|
||||
)
|
||||
|
||||
if vector_dtype_actual == np.single:
|
||||
_builder = _native_dap.build_disk_float_index
|
||||
elif vector_dtype_actual == np.ubyte:
|
||||
_builder = _native_dap.build_disk_uint8_index
|
||||
else:
|
||||
_builder = _native_dap.build_disk_int8_index
|
||||
|
||||
_builder(
|
||||
metric=dap_metric,
|
||||
data_file_path=vector_bin_path,
|
||||
index_prefix_path=os.path.join(index_directory, index_prefix),
|
||||
complexity=complexity,
|
||||
graph_degree=graph_degree,
|
||||
final_index_ram_limit=search_memory_maximum,
|
||||
indexing_ram_budget=build_memory_maximum,
|
||||
num_threads=num_threads,
|
||||
pq_disk_bytes=pq_disk_bytes,
|
||||
)
|
||||
|
||||
|
||||
def build_memory_index(
|
||||
data: Union[str, np.ndarray],
|
||||
metric: Literal["l2", "mips"],
|
||||
index_directory: str,
|
||||
complexity: int,
|
||||
graph_degree: int,
|
||||
num_threads: int,
|
||||
alpha: float = defaults.ALPHA,
|
||||
use_pq_build: bool = defaults.USE_PQ_BUILD,
|
||||
num_pq_bytes: int = defaults.NUM_PQ_BYTES,
|
||||
use_opq: bool = defaults.USE_OPQ,
|
||||
vector_dtype: Optional[VectorDType] = None,
|
||||
label_file: str = "",
|
||||
universal_label: str = "",
|
||||
filter_complexity: int = defaults.FILTER_COMPLEXITY,
|
||||
index_prefix: str = "ann"
|
||||
):
|
||||
"""
|
||||
Builds a memory index and saves it to disk to be loaded into ``StaticMemoryIndex``.
|
||||
|
||||
:param data: Either a ``str`` representing a path to a DiskANN vector bin file, or a numpy.ndarray,
|
||||
of a supported dtype, in 2 dimensions. Note that vector_dtype must be provided if vector_path_or_np_array is a
|
||||
``str``
|
||||
:type data: Union[str, numpy.ndarray]
|
||||
:param metric: One of {"l2", "mips"}. L2 is supported for all 3 vector dtypes, but MIPS is only
|
||||
available for single point floating numbers (numpy.single)
|
||||
:type metric: str
|
||||
:param index_directory: The path on disk that the index will be created in.
|
||||
:type index_directory: str
|
||||
:param complexity: The size of queue to use when building the index for search. Values between 75 and 200 are
|
||||
typical. Larger values will take more time to build but result in indices that provide higher recall for
|
||||
the same search complexity. Use a value that is at least as large as R unless you are prepared to
|
||||
somewhat compromise on quality
|
||||
:type complexity: int
|
||||
:param graph_degree: The degree of the graph index, typically between 60 and 150. A larger maximum degree will
|
||||
result in larger indices and longer indexing times, but better search quality.
|
||||
:type graph_degree int
|
||||
:param num_threads: Number of threads to use when creating this index. 0 indicates we should use all available
|
||||
system threads.
|
||||
:type num_threads: int
|
||||
:param alpha:
|
||||
:param use_pq_build:
|
||||
:param num_pq_bytes:
|
||||
:param use_opq:
|
||||
:param vector_dtype: Required if the provided ``vector_path_or_np_array`` is of type ``str``, else we use the
|
||||
``vector_path_or_np_array.dtype`` if np array.
|
||||
:type vector_dtype: Optional[VectorDType], default is ``None``.
|
||||
:param label_file: Defaults to ""
|
||||
:type label_file: str
|
||||
:param universal_label: Defaults to ""
|
||||
:param filter_complexity: Complexity to use when using filters. Default is 0.
|
||||
:type filter_complexity: int
|
||||
:param index_prefix: The prefix to give your index files. Defaults to ``ann``.
|
||||
:type index_prefix: str, default="ann"
|
||||
:return:
|
||||
"""
|
||||
_assert(
|
||||
(isinstance(data, str) and vector_dtype is not None)
|
||||
or isinstance(data, np.ndarray),
|
||||
"vector_dtype is required if data is a str representing a path to the vector bin file",
|
||||
)
|
||||
dap_metric = _get_valid_metric(metric)
|
||||
_assert_is_positive_uint32(complexity, "complexity")
|
||||
_assert_is_positive_uint32(graph_degree, "graph_degree")
|
||||
_assert(alpha >= 1, "alpha must be >= 1, and realistically should be kept between [1.0, 2.0)")
|
||||
_assert_is_nonnegative_uint32(num_threads, "num_threads")
|
||||
_assert_is_nonnegative_uint32(num_pq_bytes, "num_pq_bytes")
|
||||
_assert_is_nonnegative_uint32(filter_complexity, "filter_complexity")
|
||||
_assert(index_prefix != "", "index_prefix cannot be an empty string")
|
||||
|
||||
index_path = Path(index_directory)
|
||||
_assert(
|
||||
index_path.exists() and index_path.is_dir(),
|
||||
"index_directory must both exist and be a directory",
|
||||
)
|
||||
|
||||
vector_bin_path, vector_dtype_actual = _valid_path_and_dtype(
|
||||
data, vector_dtype, index_directory
|
||||
)
|
||||
|
||||
if vector_dtype_actual == np.single:
|
||||
_builder = _native_dap.build_in_memory_float_index
|
||||
elif vector_dtype_actual == np.ubyte:
|
||||
_builder = _native_dap.build_in_memory_uint8_index
|
||||
else:
|
||||
_builder = _native_dap.build_in_memory_int8_index
|
||||
|
||||
_builder(
|
||||
metric=dap_metric,
|
||||
data_file_path=vector_bin_path,
|
||||
index_output_path=os.path.join(index_directory, index_prefix),
|
||||
complexity=complexity,
|
||||
graph_degree=graph_degree,
|
||||
alpha=alpha,
|
||||
num_threads=num_threads,
|
||||
use_pq_build=use_pq_build,
|
||||
num_pq_bytes=num_pq_bytes,
|
||||
use_opq=use_opq,
|
||||
label_file=label_file,
|
||||
universal_label=universal_label,
|
||||
filter_complexity=filter_complexity
|
||||
)
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import BinaryIO, Literal, overload
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._common import VectorDType
|
||||
|
||||
def numpy_to_diskann_file(vectors: np.ndarray, file_handler: BinaryIO): ...
|
||||
@overload
|
||||
def build_disk_index(
|
||||
data: str,
|
||||
metric: Literal["l2", "mips"],
|
||||
index_directory: str,
|
||||
complexity: int,
|
||||
graph_degree: int,
|
||||
search_memory_maximum: float,
|
||||
build_memory_maximum: float,
|
||||
num_threads: int,
|
||||
pq_disk_bytes: int,
|
||||
vector_dtype: VectorDType,
|
||||
index_prefix: str,
|
||||
): ...
|
||||
@overload
|
||||
def build_disk_index(
|
||||
data: np.ndarray,
|
||||
metric: Literal["l2", "mips"],
|
||||
index_directory: str,
|
||||
complexity: int,
|
||||
graph_degree: int,
|
||||
search_memory_maximum: float,
|
||||
build_memory_maximum: float,
|
||||
num_threads: int,
|
||||
pq_disk_bytes: int,
|
||||
index_prefix: str,
|
||||
): ...
|
||||
@overload
|
||||
def build_memory_index(
|
||||
data: np.ndarray,
|
||||
metric: Literal["l2", "mips"],
|
||||
index_directory: str,
|
||||
complexity: int,
|
||||
graph_degree: int,
|
||||
alpha: float,
|
||||
num_threads: int,
|
||||
use_pq_build: bool,
|
||||
num_pq_bytes: int,
|
||||
use_opq: bool,
|
||||
label_file: str,
|
||||
universal_label: str,
|
||||
filter_complexity: int,
|
||||
index_prefix: str,
|
||||
): ...
|
||||
@overload
|
||||
def build_memory_index(
|
||||
data: str,
|
||||
metric: Literal["l2", "mips"],
|
||||
index_directory: str,
|
||||
complexity: int,
|
||||
graph_degree: int,
|
||||
alpha: float,
|
||||
num_threads: int,
|
||||
use_pq_build: bool,
|
||||
num_pq_bytes: int,
|
||||
use_opq: bool,
|
||||
vector_dtype: VectorDType,
|
||||
label_file: str,
|
||||
universal_label: str,
|
||||
filter_complexity: int,
|
||||
index_prefix: str,
|
||||
): ...
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Type, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import _diskannpy as _native_dap
|
||||
|
||||
__ALL__ = ["VectorDType"]
|
||||
|
||||
_VALID_DTYPES = [np.single, np.float32, np.byte, np.int8, np.ubyte, np.uint8]
|
||||
|
||||
VectorDType = TypeVar(
|
||||
"VectorDType",
|
||||
Type[np.single],
|
||||
Type[np.float32],
|
||||
Type[np.ubyte],
|
||||
Type[np.uint8],
|
||||
Type[np.byte],
|
||||
Type[np.int8],
|
||||
)
|
||||
|
||||
|
||||
def _assert(statement_eval: bool, message: str):
|
||||
if not statement_eval:
|
||||
raise ValueError(message)
|
||||
|
||||
|
||||
def _get_valid_metric(metric: str) -> _native_dap.Metric:
|
||||
if not isinstance(metric, str):
|
||||
raise ValueError("metric must be a string")
|
||||
if metric.lower() == "l2":
|
||||
return _native_dap.L2
|
||||
elif metric.lower() == "mips":
|
||||
return _native_dap.INNER_PRODUCT
|
||||
else:
|
||||
raise ValueError("metric must be one of 'l2' or 'mips'")
|
||||
|
||||
|
||||
def _assert_dtype(vectors: np.dtype, name: str):
|
||||
_assert(
|
||||
vectors in _VALID_DTYPES,
|
||||
name
|
||||
+ " must be of one of type {(np.single, np.float32), (np.byte, np.int8), (np.ubyte, np.uint8)}",
|
||||
)
|
||||
|
||||
|
||||
def _assert_2d(vectors: np.ndarray, name: str):
|
||||
_assert(len(vectors.shape) == 2, f"{name} must be 2d numpy array")
|
||||
|
||||
|
||||
__MAX_UINT_VAL = 4_294_967_295
|
||||
|
||||
|
||||
def _assert_is_positive_uint32(test_value: int, parameter: str):
|
||||
_assert(
|
||||
0 < test_value < __MAX_UINT_VAL,
|
||||
f"{parameter} must be a positive integer in the uint32 range",
|
||||
)
|
||||
|
||||
|
||||
def _assert_is_nonnegative_uint32(test_value: int, parameter: str):
|
||||
_assert(
|
||||
-1 < test_value < __MAX_UINT_VAL,
|
||||
f"{parameter} must be a non-negative integer in the uint32 range",
|
||||
)
|
||||
|
||||
|
||||
def _assert_existing_directory(path: str, parameter: str):
|
||||
_path = Path(path)
|
||||
_assert(
|
||||
_path.exists() and _path.is_dir(), f"{parameter} must be an existing directory"
|
||||
)
|
||||
|
||||
|
||||
def _assert_existing_file(path: str, parameter: str):
|
||||
_path = Path(path)
|
||||
_assert(_path.exists() and _path.is_file(), f"{parameter} must be an existing file")
|
|
@ -0,0 +1,194 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Literal, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import _diskannpy as _native_dap
|
||||
from ._common import (
|
||||
VectorDType,
|
||||
_assert,
|
||||
_assert_2d,
|
||||
_assert_dtype,
|
||||
_assert_is_nonnegative_uint32,
|
||||
_assert_is_positive_uint32,
|
||||
_get_valid_metric,
|
||||
)
|
||||
|
||||
__ALL__ = ["DiskIndex"]
|
||||
|
||||
|
||||
class DiskIndex:
|
||||
def __init__(
|
||||
self,
|
||||
metric: Literal["l2", "mips"],
|
||||
vector_dtype: VectorDType,
|
||||
index_directory: str,
|
||||
num_threads: int,
|
||||
num_nodes_to_cache: int,
|
||||
cache_mechanism: int = 1,
|
||||
index_prefix: str = "ann",
|
||||
):
|
||||
"""
|
||||
The diskannpy.DiskIndex represents our python API into the DiskANN Product Quantization Flash Index library.
|
||||
|
||||
This class is responsible for searching a DiskANN disk index.
|
||||
|
||||
:param metric: One of {"l2", "mips"}. L2 is supported for all 3 vector dtypes, but MIPS is only
|
||||
available for single point floating numbers (numpy.single)
|
||||
:type metric: str
|
||||
:param vector_dtype: The vector dtype this index will be exposing.
|
||||
:type vector_dtype: Type[numpy.single], Type[numpy.byte], Type[numpy.ubyte]
|
||||
:param index_directory: Path on disk where the disk index is stored
|
||||
:type index_directory: str
|
||||
:param num_threads: Number of threads used to load the index (>= 0)
|
||||
:type num_threads: int
|
||||
:param num_nodes_to_cache: Number of nodes to cache in memory (> -1)
|
||||
:type num_nodes_to_cache: int
|
||||
:param cache_mechanism: 1 -> use the generated sample_data.bin file for
|
||||
the index to initialize a set of cached nodes, up to ``num_nodes_to_cache``, 2 -> ready the cache for up to
|
||||
``num_nodes_to_cache``, but do not initialize it with any nodes. Any other value disables node caching.
|
||||
:param index_prefix: A shared prefix that all files in this index will use. Default is "ann".
|
||||
:type index_prefix: str
|
||||
:raises ValueError: If metric is not a valid metric
|
||||
:raises ValueError: If vector dtype is not a supported dtype
|
||||
:raises ValueError: If num_threads or num_nodes_to_cache is an invalid range.
|
||||
"""
|
||||
dap_metric = _get_valid_metric(metric)
|
||||
_assert_dtype(vector_dtype, "vector_dtype")
|
||||
_assert_is_nonnegative_uint32(num_threads, "num_threads")
|
||||
_assert_is_nonnegative_uint32(num_nodes_to_cache, "num_nodes_to_cache")
|
||||
index_path = Path(index_directory)
|
||||
_assert(
|
||||
index_path.exists() and index_path.is_dir(),
|
||||
"index_directory must both exist and be a directory",
|
||||
)
|
||||
|
||||
self._vector_dtype = vector_dtype
|
||||
if vector_dtype == np.single:
|
||||
_index = _native_dap.DiskFloatIndex
|
||||
elif vector_dtype == np.ubyte:
|
||||
_index = _native_dap.DiskUInt8Index
|
||||
else:
|
||||
_index = _native_dap.DiskInt8Index
|
||||
self._index = _index(
|
||||
metric=dap_metric,
|
||||
index_path_prefix=os.path.join(index_directory, index_prefix),
|
||||
num_threads=num_threads,
|
||||
num_nodes_to_cache=num_nodes_to_cache,
|
||||
cache_mechanism=cache_mechanism,
|
||||
)
|
||||
|
||||
def search(
|
||||
self, query: np.ndarray, k_neighbors: int, complexity: int, beam_width: int = 2
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Searches the disk index by a single query vector in a 1d numpy array.
|
||||
|
||||
numpy array dtype must match index.
|
||||
|
||||
:param query: 1d numpy array of the same dimensionality and dtype of the index.
|
||||
:type query: numpy.ndarray
|
||||
:param k_neighbors: Number of neighbors to be returned. If query vector exists in index, it almost definitely
|
||||
will be returned as well, so adjust your ``k_neighbors`` as appropriate. (> 0)
|
||||
:type k_neighbors: int
|
||||
:param complexity: Size of list to use while searching. List size increases accuracy at the cost of latency. Must
|
||||
be at least k_neighbors in size.
|
||||
:type complexity: int
|
||||
:param beam_width: The beamwidth to be used for search. This is the maximum number of IO requests each query
|
||||
will issue per iteration of search code. Larger beamwidth will result in fewer IO round-trips per query,
|
||||
but might result in slightly higher total number of IO requests to SSD per query. For the highest query
|
||||
throughput with a fixed SSD IOps rating, use W=1. For best latency, use W=4,8 or higher complexity search.
|
||||
Specifying 0 will optimize the beamwidth depending on the number of threads performing search, but will
|
||||
involve some tuning overhead.
|
||||
:type beam_width: int
|
||||
:return: Returns a tuple of 1-d numpy ndarrays; the first including the indices of the approximate nearest
|
||||
neighbors, the second their distances. These are aligned arrays.
|
||||
"""
|
||||
_assert(len(query.shape) == 1, "query vector must be 1-d")
|
||||
_assert_dtype(query.dtype, "query.dtype")
|
||||
_assert_is_positive_uint32(k_neighbors, "k_neighbors")
|
||||
_assert_is_positive_uint32(complexity, "complexity")
|
||||
_assert_is_positive_uint32(beam_width, "beam_width")
|
||||
|
||||
if k_neighbors > complexity:
|
||||
warnings.warn(
|
||||
f"k_neighbors={k_neighbors} asked for, but list_size={complexity} was smaller. Increasing {complexity} to {k_neighbors}"
|
||||
)
|
||||
complexity = k_neighbors
|
||||
|
||||
return self._index.search(
|
||||
query=query,
|
||||
knn=k_neighbors,
|
||||
complexity=complexity,
|
||||
beam_width=beam_width,
|
||||
)
|
||||
|
||||
def batch_search(
|
||||
self,
|
||||
queries: np.ndarray,
|
||||
k_neighbors: int,
|
||||
complexity: int,
|
||||
num_threads: int,
|
||||
beam_width: int = 2,
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Searches the disk index for many query vectors in a 2d numpy array.
|
||||
|
||||
numpy array dtype must match index.
|
||||
|
||||
This search is parallelized and far more efficient than searching for each vector individually.
|
||||
|
||||
:param queries: 2d numpy array, with column dimensionality matching the index and row dimensionality being the
|
||||
number of queries intended to search for in parallel. Dtype must match dtype of the index.
|
||||
:type queries: numpy.ndarray
|
||||
:param k_neighbors: Number of neighbors to be returned. If query vector exists in index, it almost definitely
|
||||
will be returned as well, so adjust your ``k_neighbors`` as appropriate. (> 0)
|
||||
:type k_neighbors: int
|
||||
:param complexity: Size of list to use while searching. List size increases accuracy at the cost of latency. Must
|
||||
be at least k_neighbors in size.
|
||||
:type complexity: int
|
||||
:param num_threads: Number of threads to use when searching this index. (>= 0), 0 = num_threads in system
|
||||
:type num_threads: int
|
||||
:param beam_width: The beamwidth to be used for search. This is the maximum number of IO requests each query
|
||||
will issue per iteration of search code. Larger beamwidth will result in fewer IO round-trips per query,
|
||||
but might result in slightly higher total number of IO requests to SSD per query. For the highest query
|
||||
throughput with a fixed SSD IOps rating, use W=1. For best latency, use W=4,8 or higher complexity search.
|
||||
Specifying 0 will optimize the beamwidth depending on the number of threads performing search, but will
|
||||
involve some tuning overhead.
|
||||
:type beam_width: int
|
||||
:return: Returns a tuple of 2-d numpy ndarrays; each row corresponds to the query vector in the same index,
|
||||
and elements in row corresponding from 1..k_neighbors approximate nearest neighbors. The second ndarray
|
||||
contains the distances, of the same form: row index will match query index, column index refers to
|
||||
1..k_neighbors distance. These are aligned arrays.
|
||||
"""
|
||||
_assert_2d(queries, "queries")
|
||||
_assert(
|
||||
queries.dtype == self._vector_dtype,
|
||||
f"DiskIndex was built expecting a dtype of {self._vector_dtype}, but the query vectors are of dtype "
|
||||
f"{queries.dtype}",
|
||||
)
|
||||
_assert_is_positive_uint32(k_neighbors, "k_neighbors")
|
||||
_assert_is_positive_uint32(complexity, "complexity")
|
||||
_assert_is_nonnegative_uint32(num_threads, "num_threads")
|
||||
_assert_is_positive_uint32(beam_width, "beam_width")
|
||||
|
||||
if k_neighbors > complexity:
|
||||
warnings.warn(
|
||||
f"k_neighbors={k_neighbors} asked for, but list_size={complexity} was smaller. Increasing {complexity} to {k_neighbors}"
|
||||
)
|
||||
complexity = k_neighbors
|
||||
|
||||
num_queries, dim = queries.shape
|
||||
return self._index.batch_search(
|
||||
queries=queries,
|
||||
num_queries=num_queries,
|
||||
knn=k_neighbors,
|
||||
complexity=complexity,
|
||||
beam_width=beam_width,
|
||||
num_threads=num_threads,
|
||||
)
|
|
@ -0,0 +1,295 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from typing import Literal, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import _diskannpy as _native_dap
|
||||
from ._common import (
|
||||
VectorDType,
|
||||
_assert,
|
||||
_assert_2d,
|
||||
_assert_dtype,
|
||||
_assert_existing_directory,
|
||||
_assert_is_nonnegative_uint32,
|
||||
_assert_is_positive_uint32,
|
||||
_get_valid_metric,
|
||||
)
|
||||
from ._diskannpy import defaults
|
||||
|
||||
__ALL__ = ["DynamicMemoryIndex"]
|
||||
|
||||
|
||||
class DynamicMemoryIndex:
|
||||
def __init__(
|
||||
self,
|
||||
metric: Literal["l2", "mips"],
|
||||
vector_dtype: VectorDType,
|
||||
dim: int,
|
||||
max_points: int,
|
||||
complexity: int,
|
||||
graph_degree: int,
|
||||
saturate_graph: bool = defaults.SATURATE_GRAPH,
|
||||
max_occlusion_size: int = defaults.MAX_OCCLUSION_SIZE,
|
||||
alpha: float = defaults.ALPHA,
|
||||
num_threads: int = defaults.NUM_THREADS,
|
||||
filter_complexity: int = defaults.FILTER_COMPLEXITY,
|
||||
num_frozen_points: int = defaults.NUM_FROZEN_POINTS_DYNAMIC,
|
||||
initial_search_complexity: int = 0,
|
||||
search_threads: int = 0,
|
||||
concurrent_consolidation: bool = True,
|
||||
):
|
||||
"""
|
||||
The diskannpy.DynamicMemoryIndex represents our python API into a dynamic DiskANN InMemory Index library.
|
||||
|
||||
This dynamic index is unlike the DiskIndex and StaticMemoryIndex, in that after loading it you can continue
|
||||
to insert and delete vectors.
|
||||
|
||||
Deletions are completed lazily, until the user executes `DynamicMemoryIndex.consolidate_deletes()`
|
||||
|
||||
:param metric: One of {"l2", "mips"}. L2 is supported for all 3 vector dtypes, but MIPS is only
|
||||
available for single point floating numbers (numpy.single)
|
||||
:type metric: str
|
||||
:param vector_dtype: The vector dtype this index will be exposing.
|
||||
:type vector_dtype: Type[numpy.single], Type[numpy.byte], Type[numpy.ubyte]
|
||||
:param dim: The vector dimensionality of this index. All new vectors inserted must be the same dimensionality.
|
||||
:type dim: int
|
||||
:param max_points: Capacity of the data store for future insertions
|
||||
:type max_points: int
|
||||
:param graph_degree: The degree of the graph index, typically between 60 and 150. A larger maximum degree will
|
||||
result in larger indices and longer indexing times, but better search quality.
|
||||
:type graph_degree: int
|
||||
:param saturate_graph:
|
||||
:type saturate_graph: bool
|
||||
:param max_occlusion_size:
|
||||
:type max_occlusion_size: int
|
||||
:param alpha:
|
||||
:type alpha: float
|
||||
:param num_threads:
|
||||
:type num_threads: int
|
||||
:param filter_complexity:
|
||||
:type filter_complexity: int
|
||||
:param num_frozen_points:
|
||||
:type num_frozen_points: int
|
||||
:param initial_search_complexity: The working scratch memory allocated is predicated off of
|
||||
initial_search_complexity * search_threads. If a larger list_size * num_threads value is
|
||||
ultimately provided by the individual action executed in `batch_query` than provided in this constructor,
|
||||
the scratch space is extended. If a smaller list_size * num_threads is provided by the action than the
|
||||
constructor, the pre-allocated scratch space is used as-is.
|
||||
:type initial_search_complexity: int
|
||||
:param search_threads: Should be set to the most common batch_query num_threads size. The working
|
||||
scratch memory allocated is predicated off of initial_search_list_size * initial_search_threads. If a
|
||||
larger list_size * num_threads value is ultimately provided by the individual action executed in
|
||||
`batch_query` than provided in this constructor, the scratch space is extended. If a smaller
|
||||
list_size * num_threads is provided by the action than the constructor, the pre-allocated scratch space
|
||||
is used as-is.
|
||||
:type search_threads: int
|
||||
:param concurrent_consolidation:
|
||||
:type concurrent_consolidation: bool
|
||||
"""
|
||||
dap_metric = _get_valid_metric(metric)
|
||||
_assert_dtype(vector_dtype, "vector_dtype")
|
||||
self._vector_dtype = vector_dtype
|
||||
|
||||
_assert_is_positive_uint32(dim, "dim")
|
||||
_assert_is_positive_uint32(max_points, "max_points")
|
||||
_assert_is_positive_uint32(complexity, "complexity")
|
||||
_assert_is_positive_uint32(graph_degree, "graph_degree")
|
||||
_assert(alpha >= 1, "alpha must be >= 1, and realistically should be kept between [1.0, 2.0)")
|
||||
_assert_is_nonnegative_uint32(max_occlusion_size, "max_occlusion_size")
|
||||
_assert_is_nonnegative_uint32(num_threads, "num_threads")
|
||||
_assert_is_nonnegative_uint32(filter_complexity, "filter_complexity")
|
||||
_assert_is_nonnegative_uint32(num_frozen_points, "num_frozen_points")
|
||||
_assert_is_nonnegative_uint32(
|
||||
initial_search_complexity, "initial_search_complexity"
|
||||
)
|
||||
_assert_is_nonnegative_uint32(search_threads, "search_threads")
|
||||
|
||||
self._index_path = ""
|
||||
|
||||
self._dims = dim
|
||||
|
||||
if vector_dtype == np.single:
|
||||
_index = _native_dap.DynamicMemoryFloatIndex
|
||||
elif vector_dtype == np.ubyte:
|
||||
_index = _native_dap.DynamicMemoryUInt8Index
|
||||
else:
|
||||
_index = _native_dap.DynamicMemoryInt8Index
|
||||
self._index = _index(
|
||||
metric=dap_metric,
|
||||
dim=dim,
|
||||
max_points=max_points,
|
||||
complexity=complexity,
|
||||
graph_degree=graph_degree,
|
||||
saturate_graph=saturate_graph,
|
||||
max_occlusion_size=max_occlusion_size,
|
||||
alpha=alpha,
|
||||
num_threads=num_threads,
|
||||
filter_complexity=filter_complexity,
|
||||
num_frozen_points=num_frozen_points,
|
||||
initial_search_complexity=initial_search_complexity,
|
||||
search_threads=search_threads,
|
||||
concurrent_consolidation=concurrent_consolidation,
|
||||
index_path=self._index_path
|
||||
)
|
||||
|
||||
def search(
|
||||
self, query: np.ndarray, k_neighbors: int, complexity: int
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Searches the disk index by a single query vector in a 1d numpy array.
|
||||
|
||||
numpy array dtype must match index.
|
||||
|
||||
:param query: 1d numpy array of the same dimensionality and dtype of the index.
|
||||
:type query: numpy.ndarray
|
||||
:param k_neighbors: Number of neighbors to be returned. If query vector exists in index, it almost definitely
|
||||
will be returned as well, so adjust your ``k_neighbors`` as appropriate. (> 0)
|
||||
:type k_neighbors: int
|
||||
:param complexity: Size of list to use while searching. List size increases accuracy at the cost of latency. Must
|
||||
be at least k_neighbors in size.
|
||||
:type complexity: int
|
||||
:return: Returns a tuple of 1-d numpy ndarrays; the first including the indices of the approximate nearest
|
||||
neighbors, the second their distances. These are aligned arrays.
|
||||
"""
|
||||
_assert(len(query.shape) == 1, "query vector must be 1-d")
|
||||
_assert(
|
||||
query.dtype == self._vector_dtype,
|
||||
f"DynamicMemoryIndex was built expecting a dtype of {self._vector_dtype}, but the query vector is of dtype "
|
||||
f"{query.dtype}",
|
||||
)
|
||||
_assert_is_positive_uint32(k_neighbors, "k_neighbors")
|
||||
_assert_is_nonnegative_uint32(complexity, "complexity")
|
||||
|
||||
if k_neighbors > complexity:
|
||||
warnings.warn(
|
||||
f"k_neighbors={k_neighbors} asked for, but list_size={complexity} was smaller. Increasing {complexity} to {k_neighbors}"
|
||||
)
|
||||
complexity = k_neighbors
|
||||
return self._index.search(query=query, knn=k_neighbors, complexity=complexity)
|
||||
|
||||
def batch_search(
|
||||
self, queries: np.ndarray, k_neighbors: int, complexity: int, num_threads: int
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Searches the disk index for many query vectors in a 2d numpy array.
|
||||
|
||||
numpy array dtype must match index.
|
||||
|
||||
This search is parallelized and far more efficient than searching for each vector individually.
|
||||
|
||||
:param queries: 2d numpy array, with column dimensionality matching the index and row dimensionality being the
|
||||
number of queries intended to search for in parallel. Dtype must match dtype of the index.
|
||||
:type queries: numpy.ndarray
|
||||
:param k_neighbors: Number of neighbors to be returned. If query vector exists in index, it almost definitely
|
||||
will be returned as well, so adjust your ``k_neighbors`` as appropriate. (> 0)
|
||||
:type k_neighbors: int
|
||||
:param complexity: Size of list to use while searching. List size increases accuracy at the cost of latency. Must
|
||||
be at least k_neighbors in size.
|
||||
:type complexity: int
|
||||
:param num_threads: Number of threads to use when searching this index. (>= 0), 0 = num_threads in system
|
||||
:type num_threads: int
|
||||
:return: Returns a tuple of 2-d numpy ndarrays; each row corresponds to the query vector in the same index,
|
||||
and elements in row corresponding from 1..k_neighbors approximate nearest neighbors. The second ndarray
|
||||
contains the distances, of the same form: row index will match query index, column index refers to
|
||||
1..k_neighbors distance. These are aligned arrays.
|
||||
"""
|
||||
_assert_2d(queries, "queries")
|
||||
_assert(
|
||||
queries.dtype == self._vector_dtype,
|
||||
f"StaticMemoryIndex was built expecting a dtype of {self._vector_dtype}, but the query vectors are of dtype "
|
||||
f"{queries.dtype}",
|
||||
)
|
||||
_assert_is_positive_uint32(k_neighbors, "k_neighbors")
|
||||
_assert_is_positive_uint32(complexity, "complexity")
|
||||
_assert_is_nonnegative_uint32(num_threads, "num_threads")
|
||||
|
||||
if k_neighbors > complexity:
|
||||
warnings.warn(
|
||||
f"k_neighbors={k_neighbors} asked for, but list_size={complexity} was smaller. Increasing {complexity} to {k_neighbors}"
|
||||
)
|
||||
complexity = k_neighbors
|
||||
|
||||
num_queries, dim = queries.shape
|
||||
return self._index.batch_search(
|
||||
queries=queries,
|
||||
num_queries=num_queries,
|
||||
knn=k_neighbors,
|
||||
complexity=complexity,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
|
||||
def save(self, save_path: str, compact_before_save: bool = False):
|
||||
"""
|
||||
Saves this index to file.
|
||||
:param save_path: The path to save these index files to.
|
||||
:type save_path: str
|
||||
:param compact_before_save:
|
||||
"""
|
||||
if save_path == "" and self._index_path == "":
|
||||
raise ValueError(
|
||||
"save_path cannot be empty if index_path is not set to a valid path in the constructor"
|
||||
)
|
||||
self._index.save(save_path=save_path, compact_before_save=compact_before_save)
|
||||
|
||||
def insert(self, vector: np.ndarray, vector_id: int):
|
||||
"""
|
||||
Inserts a single vector into the index with the provided vector_id.
|
||||
:param vector: The vector to insert. Note that dtype must match.
|
||||
:type vector: np.ndarray
|
||||
:param vector_id: The vector_id to use for this vector.
|
||||
"""
|
||||
_assert(len(vector.shape) == 1, "insert vector must be 1-d")
|
||||
_assert(
|
||||
vector.dtype == self._vector_dtype,
|
||||
f"DynamicMemoryIndex was built expecting a dtype of {self._vector_dtype}, but the insert vector is of dtype "
|
||||
f"{vector.dtype}",
|
||||
)
|
||||
_assert_is_positive_uint32(vector_id, "vector_id")
|
||||
return self._index.insert(vector, vector_id)
|
||||
|
||||
def batch_insert(
|
||||
self, vectors: np.ndarray, vector_ids: np.ndarray, num_threads: int = 0
|
||||
):
|
||||
"""
|
||||
:param vectors: The 2d numpy array of vectors to insert.
|
||||
:type vectors: np.ndarray
|
||||
:param vector_ids: The 1d array of vector ids to use. This array must have the same number of elements as
|
||||
the vectors array has rows. The dtype of vector_ids must be ``np.uintc`` (or any alias that is your
|
||||
platform's equivalent)
|
||||
:param num_threads: Number of threads to use when inserting into this index. (>= 0), 0 = num_threads in system
|
||||
:type num_threads: int
|
||||
"""
|
||||
_assert(len(vectors.shape) == 2, "vectors must be a 2-d array")
|
||||
_assert(
|
||||
vectors.dtype == self._vector_dtype,
|
||||
f"DynamicMemoryIndex was built expecting a dtype of {self._vector_dtype}, but the insert vector is of dtype "
|
||||
f"{vectors.dtype}",
|
||||
)
|
||||
_assert(
|
||||
vectors.shape[0] == vector_ids.shape[0], "#vectors must be equal to #ids"
|
||||
)
|
||||
_assert(vector_ids.dtype == np.uintc, "vector_ids must have a dtype of np.uintc (32 bit, unsigned integer)")
|
||||
return self._index.batch_insert(
|
||||
vectors, vector_ids, vector_ids.shape[0], num_threads
|
||||
)
|
||||
|
||||
def mark_deleted(self, vector_id: int):
|
||||
"""
|
||||
Mark vector for deletion. This is a soft delete that won't return the vector id in any results, but does not
|
||||
remove it from the underlying index files or memory structure. To execute a hard delete, call this method and
|
||||
then call the much more expensive ``consolidate_delete`` method on this index.
|
||||
:param vector_id: The vector id to delete. Must be a uint32.
|
||||
:type vector_id: int
|
||||
"""
|
||||
_assert_is_positive_uint32(vector_id, "vector_id")
|
||||
self._index.mark_deleted(vector_id)
|
||||
|
||||
def consolidate_delete(self):
|
||||
"""
|
||||
This method actually restructures the DiskANN index to remove the items that have been marked for deletion.
|
||||
"""
|
||||
self._index.consolidate_delete()
|
|
@ -0,0 +1,178 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from typing import Literal, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import _diskannpy as _native_dap
|
||||
from ._common import (
|
||||
_VALID_DTYPES,
|
||||
VectorDType,
|
||||
_assert,
|
||||
_assert_is_nonnegative_uint32,
|
||||
_assert_is_positive_uint32,
|
||||
_get_valid_metric,
|
||||
_assert_existing_directory,
|
||||
_assert_existing_file,
|
||||
)
|
||||
|
||||
__ALL__ = ["StaticMemoryIndex"]
|
||||
|
||||
|
||||
class StaticMemoryIndex:
|
||||
def __init__(
|
||||
self,
|
||||
metric: Literal["l2", "mips"],
|
||||
vector_dtype: VectorDType,
|
||||
data_path: str,
|
||||
index_directory: str,
|
||||
num_threads: int,
|
||||
initial_search_complexity: int,
|
||||
index_prefix: str = "ann",
|
||||
):
|
||||
"""
|
||||
The diskannpy.StaticMemoryIndex represents our python API into a static DiskANN InMemory Index library.
|
||||
|
||||
This static index is treated exactly like the DiskIndex, in that it can only be loaded and searched.
|
||||
|
||||
:param metric: One of {"l2", "mips"}. L2 is supported for all 3 vector dtypes, but MIPS is only
|
||||
available for single point floating numbers (numpy.single)
|
||||
:type metric: str
|
||||
:param vector_dtype: The vector dtype this index will be exposing.
|
||||
:type vector_dtype: Type[numpy.single], Type[numpy.byte], Type[numpy.ubyte]
|
||||
:param data_path: The path to the vector bin file that created this index. Note that if you use a numpy
|
||||
array to build the index, you will still need to save this array as well via the
|
||||
``diskannpy.numpy_to_diskann_file`` and provide the path to it here.
|
||||
:type data_path: str
|
||||
:param index_directory: The directory the index files reside in
|
||||
:type index_directory: str
|
||||
:param initial_search_complexity: A positive integer that tunes how much work should be completed in the
|
||||
conduct of a search. This can be overridden on a per search basis, but this initial value allows us
|
||||
to pre-allocate a search scratch space. It is suggested that you set this value to the P95 of your
|
||||
search complexity values.
|
||||
:type initial_search_complexity: int
|
||||
:param index_prefix: A shared prefix that all files in this index will use. Default is "ann".
|
||||
:type index_prefix: str
|
||||
"""
|
||||
dap_metric = _get_valid_metric(metric)
|
||||
_assert(
|
||||
vector_dtype in _VALID_DTYPES,
|
||||
f"vector_dtype {vector_dtype} is not in list of valid dtypes supported: {_VALID_DTYPES}",
|
||||
)
|
||||
_assert_is_nonnegative_uint32(num_threads, "num_threads")
|
||||
_assert_is_positive_uint32(
|
||||
initial_search_complexity, "initial_search_complexity"
|
||||
)
|
||||
_assert_existing_file(data_path, "data_path")
|
||||
_assert_existing_directory(index_directory, "index_directory")
|
||||
|
||||
_assert(index_prefix != "", "index_prefix cannot be an empty string")
|
||||
|
||||
self._vector_dtype = vector_dtype
|
||||
if vector_dtype == np.single:
|
||||
_index = _native_dap.StaticMemoryFloatIndex
|
||||
elif vector_dtype == np.ubyte:
|
||||
_index = _native_dap.StaticMemoryUInt8Index
|
||||
else:
|
||||
_index = _native_dap.StaticMemoryInt8Index
|
||||
self._index = _index(
|
||||
metric=dap_metric,
|
||||
data_path=data_path,
|
||||
index_path=os.path.join(index_directory, index_prefix),
|
||||
num_threads=num_threads,
|
||||
initial_search_complexity=initial_search_complexity,
|
||||
)
|
||||
|
||||
def search(self, query: np.ndarray, k_neighbors: int, complexity: int):
|
||||
"""
|
||||
Searches the static in memory index by a single query vector in a 1d numpy array.
|
||||
|
||||
numpy array dtype must match index.
|
||||
|
||||
:param query: 1d numpy array of the same dimensionality and dtype of the index.
|
||||
:type query: numpy.ndarray
|
||||
:param k_neighbors: Number of neighbors to be returned. If query vector exists in index, it almost definitely
|
||||
will be returned as well, so adjust your ``k_neighbors`` as appropriate. (> 0)
|
||||
:type k_neighbors: int
|
||||
:param complexity: Size of list to use while searching. List size increases accuracy at the cost of latency. Must
|
||||
be at least k_neighbors in size.
|
||||
:type complexity: int
|
||||
:param beam_width: The beamwidth to be used for search. This is the maximum number of IO requests each query
|
||||
will issue per iteration of search code. Larger beamwidth will result in fewer IO round-trips per query,
|
||||
but might result in slightly higher total number of IO requests to SSD per query. For the highest query
|
||||
throughput with a fixed SSD IOps rating, use W=1. For best latency, use W=4,8 or higher complexity search.
|
||||
Specifying 0 will optimize the beamwidth depending on the number of threads performing search, but will
|
||||
involve some tuning overhead.
|
||||
:type beam_width: int
|
||||
:return: Returns a tuple of 1-d numpy ndarrays; the first including the indices of the approximate nearest
|
||||
neighbors, the second their distances. These are aligned arrays.
|
||||
"""
|
||||
_assert(len(query.shape) == 1, "query vector must be 1-d")
|
||||
_assert(
|
||||
query.dtype == self._vector_dtype,
|
||||
f"StaticMemoryIndex was built expecting a dtype of {self._vector_dtype}, but the query vector is of dtype "
|
||||
f"{query.dtype}",
|
||||
)
|
||||
_assert_is_positive_uint32(k_neighbors, "k_neighbors")
|
||||
_assert_is_nonnegative_uint32(complexity, "complexity")
|
||||
|
||||
if k_neighbors > complexity:
|
||||
warnings.warn(
|
||||
f"k_neighbors={k_neighbors} asked for, but list_size={complexity} was smaller. Increasing {complexity} to {k_neighbors}"
|
||||
)
|
||||
complexity = k_neighbors
|
||||
return self._index.search(query=query, knn=k_neighbors, complexity=complexity)
|
||||
|
||||
def batch_search(
|
||||
self, queries: np.ndarray, k_neighbors: int, complexity: int, num_threads: int
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Searches the static, in memory index for many query vectors in a 2d numpy array.
|
||||
|
||||
numpy array dtype must match index.
|
||||
|
||||
This search is parallelized and far more efficient than searching for each vector individually.
|
||||
|
||||
:param queries: 2d numpy array, with column dimensionality matching the index and row dimensionality being the
|
||||
number of queries intended to search for in parallel. Dtype must match dtype of the index.
|
||||
:type queries: numpy.ndarray
|
||||
:param k_neighbors: Number of neighbors to be returned. If query vector exists in index, it almost definitely
|
||||
will be returned as well, so adjust your ``k_neighbors`` as appropriate. (> 0)
|
||||
:type k_neighbors: int
|
||||
:param complexity: Size of list to use while searching. List size increases accuracy at the cost of latency. Must
|
||||
be at least k_neighbors in size.
|
||||
:type complexity: int
|
||||
:param num_threads: Number of threads to use when searching this index. (>= 0), 0 = num_threads in system
|
||||
:type num_threads: int
|
||||
:return: Returns a tuple of 2-d numpy ndarrays; each row corresponds to the query vector in the same index,
|
||||
and elements in row corresponding from 1..k_neighbors approximate nearest neighbors. The second ndarray
|
||||
contains the distances, of the same form: row index will match query index, column index refers to
|
||||
1..k_neighbors distance. These are aligned arrays.
|
||||
"""
|
||||
_assert(len(queries.shape) == 2, "queries must must be 2-d np array")
|
||||
_assert(
|
||||
queries.dtype == self._vector_dtype,
|
||||
f"StaticMemoryIndex was built expecting a dtype of {self._vector_dtype}, but the query vectors are of dtype "
|
||||
f"{queries.dtype}",
|
||||
)
|
||||
_assert_is_positive_uint32(k_neighbors, "k_neighbors")
|
||||
_assert_is_positive_uint32(complexity, "complexity")
|
||||
_assert_is_nonnegative_uint32(num_threads, "num_threads")
|
||||
|
||||
if k_neighbors > complexity:
|
||||
warnings.warn(
|
||||
f"k_neighbors={k_neighbors} asked for, but list_size={complexity} was smaller. Increasing {complexity} to {k_neighbors}"
|
||||
)
|
||||
complexity = k_neighbors
|
||||
|
||||
num_queries, dim = queries.shape
|
||||
return self._index.batch_search(
|
||||
queries=queries,
|
||||
num_queries=num_queries,
|
||||
knn=k_neighbors,
|
||||
complexity=complexity,
|
||||
num_threads=num_threads,
|
||||
)
|
|
@ -1,428 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import warnings
|
||||
from typing import BinaryIO, Literal, Tuple, Type, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import _diskannpy as _native_dap
|
||||
|
||||
__ALL__ = [
|
||||
"build_disk_index_from_vector_file",
|
||||
"build_disk_index_from_vectors",
|
||||
"numpy_to_diskann_file",
|
||||
"VectorDType",
|
||||
"DiskIndex",
|
||||
]
|
||||
|
||||
|
||||
_VALID_DTYPES = [np.single, np.byte, np.ubyte]
|
||||
|
||||
_DTYPE_TO_NATIVE_INDEX = {
|
||||
np.single: _native_dap.DiskANNFloatIndex,
|
||||
np.ubyte: _native_dap.DiskANNUInt8Index,
|
||||
np.byte: _native_dap.DiskANNInt8Index,
|
||||
}
|
||||
|
||||
_DTYPE_TO_NATIVE_INMEM_DYNAMIC_INDEX = {
|
||||
np.single: _native_dap.DiskANNDynamicInMemFloatIndex,
|
||||
np.ubyte: _native_dap.DiskANNDynamicInMemUint8Index,
|
||||
np.byte: _native_dap.DiskANNDynamicInMemInt8Index,
|
||||
}
|
||||
|
||||
_DTYPE_TO_NATIVE_INMEM_STATIC_INDEX = {
|
||||
np.single: _native_dap.DiskANNStaticInMemFloatIndex,
|
||||
np.ubyte: _native_dap.DiskANNStaticInMemUint8Index,
|
||||
np.byte: _native_dap.DiskANNStaticInMemInt8Index,
|
||||
}
|
||||
|
||||
|
||||
VectorDType = TypeVar("VectorDType", Type[np.single], Type[np.ubyte], Type[np.byte])
|
||||
|
||||
|
||||
def _get_valid_metric(metric: str) -> _native_dap.Metric:
|
||||
if not isinstance(metric, str):
|
||||
raise ValueError("metric must be a string")
|
||||
if metric.lower() == "l2":
|
||||
return _native_dap.L2
|
||||
elif metric.lower() == "mips":
|
||||
return _native_dap.INNER_PRODUCT
|
||||
else:
|
||||
raise ValueError("metric must be one of 'l2' or 'mips'")
|
||||
|
||||
|
||||
def _validate_dtype(vectors: np.ndarray):
|
||||
if vectors.dtype not in _VALID_DTYPES:
|
||||
raise ValueError(
|
||||
f"vectors provided had dtype {vectors.dtype}, but must be single precision float "
|
||||
f"(numpy.single), unsigned 8bit integer (numpy.ubyte), or signed 8bit integer (numpy.byte)."
|
||||
)
|
||||
|
||||
|
||||
def _validate_shape(vectors: np.ndarray):
|
||||
if len(vectors.shape) != 2:
|
||||
raise ValueError("vectors must be 2d numpy array")
|
||||
|
||||
|
||||
def _numpy_to_diskann_file(
|
||||
vectors: np.ndarray,
|
||||
file_handler: BinaryIO,
|
||||
):
|
||||
_validate_shape(vectors)
|
||||
_validate_dtype(vectors)
|
||||
_ = file_handler.write(np.array(vectors.shape, dtype=np.int32).tobytes())
|
||||
_ = file_handler.write(vectors.tobytes())
|
||||
|
||||
|
||||
def numpy_to_diskann_file(vectors: np.ndarray, output_path: Union[str, BinaryIO]):
|
||||
"""
|
||||
Utility function that writes a DiskANN binary vector formatted file to the location of your choosing.
|
||||
|
||||
:param vectors: A 2d array of dtype ``numpy.single``, ``numpy.ubyte``, or ``numpy.byte``
|
||||
:type vectors: numpy.ndarray, dtype in set {numpy.single, numpy.ubyte, numpy.byte}
|
||||
:param output_path: Where to write the file. If a string is provided, a binary writer will be opened at that
|
||||
location. Otherwise it is presumed ``output_path`` is a BinaryIO file handler and will write to it.
|
||||
:type output_path: Union[str, io.BinaryIO]
|
||||
:raises ValueError: If vectors are the wrong shape or an unsupported dtype
|
||||
:raises ValueError: If output_path is not a str or ``io.BinaryIO``
|
||||
"""
|
||||
if isinstance(output_path, BinaryIO):
|
||||
_numpy_to_diskann_file(vectors, output_path)
|
||||
elif isinstance(output_path, str):
|
||||
with open(output_path, "wb") as binary_out:
|
||||
_numpy_to_diskann_file(vectors, binary_out)
|
||||
else:
|
||||
raise ValueError(
|
||||
"output_path must be either a str or an open binary file handler (e.g. `handler = open('my_file_path', 'wb')`)"
|
||||
)
|
||||
|
||||
|
||||
def build_disk_index_from_vector_file(
|
||||
vector_bin_file: str,
|
||||
metric: Literal["l2", "mips"],
|
||||
vector_dtype: VectorDType,
|
||||
index_path: str,
|
||||
max_degree: int,
|
||||
list_size: int,
|
||||
search_memory_maximum: float,
|
||||
build_memory_maximum: float,
|
||||
num_threads: int,
|
||||
pq_disk_bytes: int,
|
||||
index_prefix: str = "ann",
|
||||
):
|
||||
"""
|
||||
Builds a DiskANN disk index based on a provided DiskANN formatted binary file path.
|
||||
|
||||
:param vector_bin_file: Must be a binary file formatted in the expected DiskANN file format.
|
||||
Use ``diskannpy.numpy_to_diskann_file`` to create it.
|
||||
:type vector_bin_file: str
|
||||
:param metric: One of {"l2", "mips"}. L2 is supported for all 3 vector dtypes, but MIPS is only
|
||||
available for single point floating numbers (numpy.single)
|
||||
:type metric: str
|
||||
:param vector_dtype: The vector dtype this index will be exposing.
|
||||
:type vector_dtype: Type[numpy.single], Type[numpy.byte], Type[numpy.ubyte]
|
||||
:param index_path: The path on disk that the index will be created in.
|
||||
:type index_path: str
|
||||
:param max_degree: The degree of the graph index, typically between 60 and 50. A larger maximum degree will
|
||||
result in larger indices and longer indexing times, but better search quality.
|
||||
:type max_degree: int
|
||||
:param list_size: The size of queue to use when building the index for search. Values between 75 and 200 are
|
||||
typical. Larger values will take more time to build but result in indices that provide higher recall for
|
||||
the same search complexity. Use a value that is at least as large as R unless you are prepared to
|
||||
somewhat compromise on quality
|
||||
:type list_size: int
|
||||
:param search_memory_maximum: Build index with the expectation that the search will use at most
|
||||
``search_memory_maximum``
|
||||
:type search_memory_maximum: float
|
||||
:param build_memory_maximum: Build index using at most ``build_memory_maximum``
|
||||
:type build_memory_maximum: float
|
||||
:param num_threads: Number of threads to use when creating this index.
|
||||
:type num_threads: int
|
||||
:param pq_disk_bytes: Use 0 to store uncompressed data on SSD. This allows the index to asymptote to 100%
|
||||
recall. If your vectors are too large to store in SSD, this parameter provides the option to compress the
|
||||
vectors using PQ for storing on SSD. This will trade off recall. You would also want this to be greater
|
||||
than the number of bytes used for the PQ compressed data stored in-memory. Default is ``0``.
|
||||
:type pq_disk_bytes: int (default = 0)
|
||||
:param index_prefix: The prefix to give your index files. Defaults to ``ann``.
|
||||
:type index_prefix: str, default="ann"
|
||||
:raises ValueError: If any numeric parameter is in an invalid range.
|
||||
"""
|
||||
dap_metric = _get_valid_metric(metric)
|
||||
if vector_dtype not in _VALID_DTYPES:
|
||||
raise ValueError(
|
||||
f"vector_dtype {vector_dtype} is not in list of valid dtypes supported: {_VALID_DTYPES}"
|
||||
)
|
||||
if list_size <= 0:
|
||||
raise ValueError("list_size must be a positive integer")
|
||||
if max_degree <= 0:
|
||||
raise ValueError("max_degree must be a positive integer")
|
||||
if search_memory_maximum <= 0:
|
||||
raise ValueError("search_memory_maximum must be larger than 0")
|
||||
if build_memory_maximum <= 0:
|
||||
raise ValueError("build_memory_maximum must be larger than 0")
|
||||
if num_threads < 0:
|
||||
raise ValueError("num_threads must be a nonnegative integer")
|
||||
if pq_disk_bytes < 0:
|
||||
raise ValueError("pq_disk_bytes must be nonnegative integer")
|
||||
|
||||
index = _DTYPE_TO_NATIVE_INDEX[vector_dtype](dap_metric)
|
||||
index.build(
|
||||
data_file_path=vector_bin_file,
|
||||
index_prefix_path=os.path.join(index_path, index_prefix),
|
||||
R=max_degree,
|
||||
L=list_size,
|
||||
final_index_ram_limit=search_memory_maximum,
|
||||
indexing_ram_limit=build_memory_maximum,
|
||||
num_threads=num_threads,
|
||||
pq_disk_bytes=pq_disk_bytes,
|
||||
)
|
||||
|
||||
|
||||
def build_disk_index_from_vectors(
|
||||
vectors: np.ndarray,
|
||||
metric: Literal["l2", "mips"],
|
||||
index_path: str,
|
||||
max_degree: int,
|
||||
list_size: int,
|
||||
search_memory_maximum: float,
|
||||
build_memory_maximum: float,
|
||||
num_threads: int,
|
||||
pq_disk_bytes: int,
|
||||
index_prefix: str = "ann",
|
||||
):
|
||||
"""
|
||||
This function is a convenience function for first converting the provided numpy 2-d array into the binary format
|
||||
expected by the DiskANN library, and then using that to generate the index as per
|
||||
``DiskIndex.build_from_vector_file()``. After completion, this temporary file is deleted.
|
||||
|
||||
:param vectors: A numpy.ndarray, of a supported dtype, in 2 dimensions
|
||||
:type vectors: numpy.ndarray
|
||||
:param metric: One of {"l2", "mips"}. L2 is supported for all 3 vector dtypes, but MIPS is only
|
||||
available for single point floating numbers (numpy.single)
|
||||
:type metric: str
|
||||
:param index_path: The path on disk that the index will be created in.
|
||||
:type index_path: str
|
||||
:param max_degree: The degree of the graph index, typically between 60 and 50. A larger maximum degree will
|
||||
result in larger indices and longer indexing times, but better search quality.
|
||||
:type max_degree: int
|
||||
:param list_size: The size of queue to use when building the index for search. Values between 75 and 200 are
|
||||
typical. Larger values will take more time to build but result in indices that provide higher recall for
|
||||
the same search complexity. Use a value that is at least as large as R unless you are prepared to
|
||||
somewhat compromise on quality
|
||||
:type list_size: int
|
||||
:param search_memory_maximum: Build index with the expectation that the search will use at most
|
||||
``search_memory_maximum``
|
||||
:type search_memory_maximum: float
|
||||
:param build_memory_maximum: Build index using at most ``build_memory_maximum``
|
||||
:type build_memory_maximum: float
|
||||
:param num_threads: Number of threads to use when creating this index.
|
||||
:type num_threads: int
|
||||
:param pq_disk_bytes: Use 0 to store uncompressed data on SSD. This allows the index to asymptote to 100%
|
||||
recall. If your vectors are too large to store in SSD, this parameter provides the option to compress the
|
||||
vectors using PQ for storing on SSD. This will trade off recall. You would also want this to be greater
|
||||
than the number of bytes used for the PQ compressed data stored in-memory. Default is ``0``.
|
||||
:type pq_disk_bytes: int (default = 0)
|
||||
:param index_prefix: The prefix to give your index files. Defaults to ``ann``.
|
||||
:type index_prefix: str, default="ann"
|
||||
:raises ValueError: If vectors are not 2d numpy array or are not a supported dtype
|
||||
:raises ValueError: If any numeric value is in an invalid range
|
||||
"""
|
||||
_validate_dtype(vectors)
|
||||
_validate_shape(vectors)
|
||||
|
||||
_temp_work_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
temp_vector_bin_path = os.path.join(_temp_work_dir, "vectors.bin")
|
||||
with open(os.path.join(_temp_work_dir, "vectors.bin"), "wb") as temp_vector_bin:
|
||||
numpy_to_diskann_file(vectors, temp_vector_bin)
|
||||
build_disk_index_from_vector_file(
|
||||
vector_bin_file=temp_vector_bin_path,
|
||||
metric=metric,
|
||||
vector_dtype=vectors.dtype,
|
||||
index_path=index_path,
|
||||
max_degree=max_degree,
|
||||
list_size=list_size,
|
||||
search_memory_maximum=search_memory_maximum,
|
||||
build_memory_maximum=build_memory_maximum,
|
||||
num_threads=num_threads,
|
||||
pq_disk_bytes=pq_disk_bytes,
|
||||
index_prefix=index_prefix,
|
||||
)
|
||||
finally:
|
||||
shutil.rmtree(_temp_work_dir)
|
||||
|
||||
|
||||
class DiskIndex:
|
||||
def __init__(
|
||||
self,
|
||||
metric: Literal["l2", "mips"],
|
||||
vector_dtype: VectorDType,
|
||||
index_path: str,
|
||||
num_threads: int,
|
||||
num_nodes_to_cache: int,
|
||||
index_prefix: str = "ann",
|
||||
):
|
||||
"""
|
||||
The diskannpy.DiskIndex represents our python API into the DiskANN Product Quantization Flash Index library.
|
||||
|
||||
This class is responsible for searching a DiskANN disk index.
|
||||
|
||||
:param metric: One of {"l2", "mips"}. L2 is supported for all 3 vector dtypes, but MIPS is only
|
||||
available for single point floating numbers (numpy.single)
|
||||
:type metric: str
|
||||
:param vector_dtype: The vector dtype this index will be exposing.
|
||||
:type vector_dtype: Type[numpy.single], Type[numpy.byte], Type[numpy.ubyte]
|
||||
:param index_path: Path on disk where the disk index is stored
|
||||
:type index_path: str
|
||||
:param num_threads: Number of threads used to load the index (>= 0)
|
||||
:type num_threads: int
|
||||
:param num_nodes_to_cache: Number of nodes to cache in memory (> -1)
|
||||
:type num_nodes_to_cache: int
|
||||
:param index_prefix: A shared prefix that all files in this index will use. Default is "ann".
|
||||
:type index_prefix: str
|
||||
:raises ValueError: If metric is not a valid metric
|
||||
:raises ValueError: If vector dtype is not a supported dtype
|
||||
:raises ValueError: If num_threads or num_nodes_to_cache is an invalid range.
|
||||
"""
|
||||
dap_metric = _get_valid_metric(metric)
|
||||
if vector_dtype not in _VALID_DTYPES:
|
||||
raise ValueError(
|
||||
f"vector_dtype {vector_dtype} is not in list of valid dtypes supported: {_VALID_DTYPES}"
|
||||
)
|
||||
if num_threads < 0:
|
||||
raise ValueError("num_threads must be a non-negative integer")
|
||||
if num_nodes_to_cache < 0:
|
||||
raise ValueError("num_nodes_to_cache must be a non-negative integer")
|
||||
self._vector_dtype = vector_dtype
|
||||
self._index = _DTYPE_TO_NATIVE_INDEX[vector_dtype](dap_metric)
|
||||
self._index.load_index(
|
||||
index_path_prefix=os.path.join(index_path, index_prefix),
|
||||
num_threads=num_threads,
|
||||
num_nodes_to_cache=num_nodes_to_cache,
|
||||
)
|
||||
|
||||
def search(
|
||||
self, query: np.ndarray, k_neighbors: int, list_size: int, beam_width: int = 2
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Searches the disk index by a single query vector in a 1d numpy array.
|
||||
|
||||
numpy array dtype must match index.
|
||||
|
||||
:param query: 1d numpy array of the same dimensionality and dtype of the index.
|
||||
:type query: numpy.ndarray
|
||||
:param k_neighbors: Number of neighbors to be returned. If query vector exists in index, it almost definitely
|
||||
will be returned as well, so adjust your ``k_neighbors`` as appropriate. (> 0)
|
||||
:type k_neighbors: int
|
||||
:param list_size: Size of list to use while searching. List size increases accuracy at the cost of latency. Must
|
||||
be at least k_neighbors in size.
|
||||
:type list_size: int
|
||||
:param beam_width: The beamwidth to be used for search. This is the maximum number of IO requests each query
|
||||
will issue per iteration of search code. Larger beamwidth will result in fewer IO round-trips per query,
|
||||
but might result in slightly higher total number of IO requests to SSD per query. For the highest query
|
||||
throughput with a fixed SSD IOps rating, use W=1. For best latency, use W=4,8 or higher complexity search.
|
||||
Specifying 0 will optimize the beamwidth depending on the number of threads performing search, but will
|
||||
involve some tuning overhead.
|
||||
:type beam_width: int
|
||||
:return: Returns a tuple of 1-d numpy ndarrays; the first including the indices of the approximate nearest
|
||||
neighbors, the second their distances. These are aligned arrays.
|
||||
"""
|
||||
if len(query.shape) != 1:
|
||||
raise ValueError("query vector must be 1-d")
|
||||
if query.dtype != self._vector_dtype:
|
||||
raise ValueError(
|
||||
f"DiskIndex was built expecting a dtype of {self._vector_dtype}, but the query vector is "
|
||||
f"of dtype {query.dtype}"
|
||||
)
|
||||
if k_neighbors <= 0:
|
||||
raise ValueError("k_neighbors must be a positive integer")
|
||||
if list_size <= 0:
|
||||
raise ValueError("list_size must be a positive integer")
|
||||
if beam_width <= 0:
|
||||
raise ValueError("beam_width must be a positive integer")
|
||||
|
||||
if k_neighbors > list_size:
|
||||
warnings.warn(
|
||||
f"k_neighbors={k_neighbors} asked for, but list_size={list_size} was smaller. Increasing {list_size} to {k_neighbors}"
|
||||
)
|
||||
list_size = k_neighbors
|
||||
return self._index.search(
|
||||
query=query,
|
||||
knn=k_neighbors,
|
||||
l_search=list_size,
|
||||
beam_width=beam_width,
|
||||
)
|
||||
|
||||
def batch_search(
|
||||
self,
|
||||
queries: np.ndarray,
|
||||
k_neighbors: int,
|
||||
list_size: int,
|
||||
num_threads: int,
|
||||
beam_width: int = 2,
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Searches the disk index for many query vectors in a 2d numpy array.
|
||||
|
||||
numpy array dtype must match index.
|
||||
|
||||
This search is parallelized and far more efficient than searching for each vector individually.
|
||||
|
||||
:param queries: 2d numpy array, with column dimensionality matching the index and row dimensionality being the
|
||||
number of queries intended to search for in parallel. Dtype must match dtype of the index.
|
||||
:type queries: numpy.ndarray
|
||||
:param k_neighbors: Number of neighbors to be returned. If query vector exists in index, it almost definitely
|
||||
will be returned as well, so adjust your ``k_neighbors`` as appropriate. (> 0)
|
||||
:type k_neighbors: int
|
||||
:param list_size: Size of list to use while searching. List size increases accuracy at the cost of latency. Must
|
||||
be at least k_neighbors in size.
|
||||
:type list_size: int
|
||||
:param num_threads: Number of threads to use when searching this index. (>= 0), 0 = num_threads in system
|
||||
:type num_threads: int
|
||||
:param beam_width: The beamwidth to be used for search. This is the maximum number of IO requests each query
|
||||
will issue per iteration of search code. Larger beamwidth will result in fewer IO round-trips per query,
|
||||
but might result in slightly higher total number of IO requests to SSD per query. For the highest query
|
||||
throughput with a fixed SSD IOps rating, use W=1. For best latency, use W=4,8 or higher complexity search.
|
||||
Specifying 0 will optimize the beamwidth depending on the number of threads performing search, but will
|
||||
involve some tuning overhead.
|
||||
:type beam_width: int
|
||||
:return: Returns a tuple of 2-d numpy ndarrays; each row corresponds to the query vector in the same index,
|
||||
and elements in row corresponding from 1..k_neighbors approximate nearest neighbors. The second ndarray
|
||||
contains the distances, of the same form: row index will match query index, column index refers to
|
||||
1..k_neighbors distance. These are aligned arrays.
|
||||
"""
|
||||
if len(queries.shape) != 2:
|
||||
raise ValueError("queries must must be 2-d np array")
|
||||
if queries.dtype != self._vector_dtype:
|
||||
raise ValueError(
|
||||
f"DiskIndex was built expecting a dtype of {self._vector_dtype}, but the query vectors "
|
||||
f"are of dtype {queries.dtype}"
|
||||
)
|
||||
if k_neighbors <= 0:
|
||||
raise ValueError("k_neighbors must be a positive integer")
|
||||
if list_size <= 0:
|
||||
raise ValueError("list_size must be a positive integer")
|
||||
if num_threads < 0:
|
||||
raise ValueError("num_threads must be a nonnegative integer")
|
||||
if beam_width <= 0:
|
||||
raise ValueError("beam_width must be a positive integer")
|
||||
|
||||
if k_neighbors > list_size:
|
||||
warnings.warn(
|
||||
f"k_neighbors={k_neighbors} asked for, but list_size={list_size} was smaller. Increasing {list_size} to {k_neighbors}"
|
||||
)
|
||||
list_size = k_neighbors
|
||||
|
||||
num_queries, dim = queries.shape
|
||||
return self._index.batch_search(
|
||||
queries=queries,
|
||||
num_queries=num_queries,
|
||||
knn=k_neighbors,
|
||||
l_search=list_size,
|
||||
beam_width=beam_width,
|
||||
num_threads=num_threads,
|
||||
)
|
|
@ -21,8 +21,9 @@
|
|||
#include "disk_utils.h"
|
||||
#include "index.h"
|
||||
#include "pq_flash_index.h"
|
||||
#include "utils.h"
|
||||
|
||||
PYBIND11_MAKE_OPAQUE(std::vector<unsigned>);
|
||||
PYBIND11_MAKE_OPAQUE(std::vector<uint32_t>);
|
||||
PYBIND11_MAKE_OPAQUE(std::vector<float>);
|
||||
PYBIND11_MAKE_OPAQUE(std::vector<int8_t>);
|
||||
PYBIND11_MAKE_OPAQUE(std::vector<uint8_t>);
|
||||
|
@ -30,64 +31,28 @@ PYBIND11_MAKE_OPAQUE(std::vector<uint8_t>);
|
|||
namespace py = pybind11;
|
||||
using namespace diskann;
|
||||
|
||||
template <class T> struct DiskANNIndex
|
||||
#ifdef _WINDOWS
|
||||
typedef WindowsAlignedFileReader PlatformSpecificAlignedFileReader;
|
||||
#else
|
||||
typedef LinuxAlignedFileReader PlatformSpecificAlignedFileReader;
|
||||
#endif
|
||||
|
||||
template <class T> struct DiskIndex
|
||||
{
|
||||
PQFlashIndex<T> *pq_flash_index;
|
||||
PQFlashIndex<T> *_pq_flash_index;
|
||||
std::shared_ptr<AlignedFileReader> reader;
|
||||
|
||||
DiskANNIndex(diskann::Metric metric)
|
||||
DiskIndex(const diskann::Metric metric, const std::string &index_path_prefix, const uint32_t num_threads,
|
||||
const size_t num_nodes_to_cache, const uint32_t cache_mechanism)
|
||||
{
|
||||
#ifdef _WINDOWS
|
||||
reader = std::make_shared<WindowsAlignedFileReader>();
|
||||
#else
|
||||
reader = std::make_shared<LinuxAlignedFileReader>();
|
||||
#endif
|
||||
pq_flash_index = new PQFlashIndex<T>(reader, metric);
|
||||
}
|
||||
|
||||
~DiskANNIndex()
|
||||
{
|
||||
delete pq_flash_index;
|
||||
}
|
||||
|
||||
auto get_metric()
|
||||
{
|
||||
return pq_flash_index->get_metric();
|
||||
}
|
||||
|
||||
void cache_bfs_levels(size_t num_nodes_to_cache)
|
||||
{
|
||||
std::vector<uint32_t> node_list;
|
||||
pq_flash_index->cache_bfs_levels(num_nodes_to_cache, node_list);
|
||||
pq_flash_index->load_cache_list(node_list);
|
||||
}
|
||||
|
||||
void cache_sample_paths(size_t num_nodes_to_cache, const std::string &warmup_query_file, uint32_t num_threads)
|
||||
{
|
||||
if (!file_exists(warmup_query_file))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<uint32_t> node_list;
|
||||
pq_flash_index->generate_cache_list_from_sample_queries(warmup_query_file, 15, 4, num_nodes_to_cache,
|
||||
num_threads, node_list);
|
||||
pq_flash_index->load_cache_list(node_list);
|
||||
}
|
||||
|
||||
int load_index(const std::string &index_path_prefix, const int num_threads, const size_t num_nodes_to_cache,
|
||||
int cache_mechanism)
|
||||
{
|
||||
int load_success = pq_flash_index->load(num_threads, index_path_prefix.c_str());
|
||||
reader = std::make_shared<PlatformSpecificAlignedFileReader>();
|
||||
_pq_flash_index = new PQFlashIndex<T>(reader, metric);
|
||||
int load_success = _pq_flash_index->load(num_threads, index_path_prefix.c_str());
|
||||
if (load_success != 0)
|
||||
{
|
||||
throw std::runtime_error("load_index failed.");
|
||||
throw std::runtime_error("index load failed.");
|
||||
}
|
||||
if (cache_mechanism == 0)
|
||||
{
|
||||
// Nothing to do
|
||||
}
|
||||
else if (cache_mechanism == 1)
|
||||
if (cache_mechanism == 1)
|
||||
{
|
||||
std::string sample_file = index_path_prefix + std::string("_sample_data.bin");
|
||||
cache_sample_paths(num_nodes_to_cache, sample_file, num_threads);
|
||||
|
@ -96,21 +61,46 @@ template <class T> struct DiskANNIndex
|
|||
{
|
||||
cache_bfs_levels(num_nodes_to_cache);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
~DiskIndex()
|
||||
{
|
||||
delete _pq_flash_index;
|
||||
}
|
||||
|
||||
void cache_bfs_levels(const size_t num_nodes_to_cache)
|
||||
{
|
||||
std::vector<uint32_t> node_list;
|
||||
_pq_flash_index->cache_bfs_levels(num_nodes_to_cache, node_list);
|
||||
_pq_flash_index->load_cache_list(node_list);
|
||||
}
|
||||
|
||||
void cache_sample_paths(const size_t num_nodes_to_cache, const std::string &warmup_query_file,
|
||||
const uint32_t num_threads)
|
||||
{
|
||||
if (!file_exists(warmup_query_file))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<uint32_t> node_list;
|
||||
_pq_flash_index->generate_cache_list_from_sample_queries(warmup_query_file, 15, 4, num_nodes_to_cache,
|
||||
num_threads, node_list);
|
||||
_pq_flash_index->load_cache_list(node_list);
|
||||
}
|
||||
|
||||
auto search(py::array_t<T, py::array::c_style | py::array::forcecast> &query, const uint64_t knn,
|
||||
const uint64_t l_search, const uint64_t beam_width)
|
||||
const uint64_t complexity, const uint64_t beam_width)
|
||||
{
|
||||
py::array_t<unsigned> ids(knn);
|
||||
py::array_t<uint32_t> ids(knn);
|
||||
py::array_t<float> dists(knn);
|
||||
|
||||
std::vector<unsigned> u32_ids(knn);
|
||||
std::vector<uint32_t> u32_ids(knn);
|
||||
std::vector<uint64_t> u64_ids(knn);
|
||||
QueryStats stats;
|
||||
|
||||
pq_flash_index->cached_beam_search(query.data(), knn, l_search, u64_ids.data(), dists.mutable_data(),
|
||||
beam_width, false, &stats);
|
||||
_pq_flash_index->cached_beam_search(query.data(), knn, complexity, u64_ids.data(), dists.mutable_data(),
|
||||
beam_width, false, &stats);
|
||||
|
||||
auto r = ids.mutable_unchecked<1>();
|
||||
for (uint64_t i = 0; i < knn; ++i)
|
||||
|
@ -120,9 +110,9 @@ template <class T> struct DiskANNIndex
|
|||
}
|
||||
|
||||
auto batch_search(py::array_t<T, py::array::c_style | py::array::forcecast> &queries, const uint64_t num_queries,
|
||||
const uint64_t knn, const uint64_t l_search, const uint64_t beam_width, const int num_threads)
|
||||
const uint64_t knn, const uint64_t complexity, const uint64_t beam_width, const int num_threads)
|
||||
{
|
||||
py::array_t<unsigned> ids({num_queries, knn});
|
||||
py::array_t<uint32_t> ids({num_queries, knn});
|
||||
py::array_t<float> dists({num_queries, knn});
|
||||
|
||||
omp_set_num_threads(num_threads);
|
||||
|
@ -132,14 +122,14 @@ template <class T> struct DiskANNIndex
|
|||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (int64_t i = 0; i < (int64_t)num_queries; i++)
|
||||
{
|
||||
pq_flash_index->cached_beam_search(queries.data(i), knn, l_search, u64_ids.data() + i * knn,
|
||||
dists.mutable_data(i), beam_width);
|
||||
_pq_flash_index->cached_beam_search(queries.data(i), knn, complexity, u64_ids.data() + i * knn,
|
||||
dists.mutable_data(i), beam_width);
|
||||
}
|
||||
|
||||
auto r = ids.mutable_unchecked();
|
||||
for (uint64_t i = 0; i < num_queries; ++i)
|
||||
for (uint64_t j = 0; j < knn; ++j)
|
||||
r(i, j) = (unsigned)u64_ids[i * knn + j];
|
||||
r(i, j) = (uint32_t)u64_ids[i * knn + j];
|
||||
|
||||
return std::make_pair(ids, dists);
|
||||
}
|
||||
|
@ -151,25 +141,46 @@ typedef uint32_t filterT;
|
|||
template <class T> struct DynamicInMemIndex
|
||||
{
|
||||
Index<T, IdT, filterT> *_index;
|
||||
const IndexWriteParameters write_params;
|
||||
IndexWriteParameters _write_params;
|
||||
const std::string &_index_path;
|
||||
|
||||
DynamicInMemIndex(Metric m, const size_t dim, const size_t max_points, const IndexWriteParameters &index_parameters,
|
||||
const uint32_t initial_search_list_size, const uint32_t search_threads,
|
||||
const bool concurrent_consolidate)
|
||||
: write_params(index_parameters)
|
||||
DynamicInMemIndex(const Metric m, const size_t dim, const size_t max_points, const uint32_t complexity,
|
||||
const uint32_t graph_degree, const bool saturate_graph, const uint32_t max_occlusion_size,
|
||||
const float alpha, const uint32_t num_threads, const uint32_t filter_complexity,
|
||||
const uint32_t num_frozen_points, const uint32_t initial_search_complexity,
|
||||
const uint32_t initial_search_threads, const bool concurrent_consolidation,
|
||||
const std::string &index_path = "")
|
||||
: _write_params(IndexWriteParametersBuilder(complexity, graph_degree)
|
||||
.with_saturate_graph(saturate_graph)
|
||||
.with_max_occlusion_size(max_occlusion_size)
|
||||
.with_alpha(alpha)
|
||||
.with_num_threads(num_threads)
|
||||
.with_filter_list_size(filter_complexity)
|
||||
.with_num_frozen_points(num_frozen_points)
|
||||
.build()),
|
||||
_index_path(index_path)
|
||||
{
|
||||
const uint32_t _initial_search_complexity =
|
||||
initial_search_complexity != 0 ? initial_search_complexity : complexity;
|
||||
const uint32_t _initial_search_threads =
|
||||
initial_search_threads != 0 ? initial_search_threads : omp_get_num_threads();
|
||||
|
||||
_index = new Index<T>(m, dim, max_points,
|
||||
true, // dynamic_index
|
||||
index_parameters, // used for insert
|
||||
initial_search_list_size, // used to prepare the scratch space for
|
||||
// searching. can / may be expanded if the
|
||||
// search asks for a larger L.
|
||||
search_threads, // also used for the scratch space
|
||||
true, // enable_tags
|
||||
concurrent_consolidate,
|
||||
true, // dynamic_index
|
||||
_write_params, // used for insert
|
||||
_initial_search_complexity, // used to prepare the scratch space for searching. can / may
|
||||
// be expanded if the search asks for a larger L.
|
||||
_initial_search_threads, // also used for the scratch space
|
||||
true, // enable_tags
|
||||
concurrent_consolidation,
|
||||
false, // pq_dist_build
|
||||
0, // num_pq_chunks
|
||||
false); // use_opq = false
|
||||
if (!index_path.empty())
|
||||
{
|
||||
_index->load(index_path.c_str(), _write_params.num_threads, complexity);
|
||||
}
|
||||
_index->enable_delete();
|
||||
}
|
||||
|
||||
~DynamicInMemIndex()
|
||||
|
@ -182,34 +193,68 @@ template <class T> struct DynamicInMemIndex
|
|||
return _index->insert_point(vector.data(), id);
|
||||
}
|
||||
|
||||
auto batch_insert(py::array_t<T, py::array::c_style | py::array::forcecast> &vectors,
|
||||
py::array_t<IdT, py::array::c_style | py::array::forcecast> &ids, const size_t num_inserts,
|
||||
const int num_threads = 0)
|
||||
{
|
||||
if (num_threads == 0)
|
||||
omp_set_num_threads(omp_get_num_procs());
|
||||
else
|
||||
omp_set_num_threads(num_threads);
|
||||
py::array_t<int> insert_retvals(num_inserts);
|
||||
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (size_t i = 0; i < num_inserts; i++)
|
||||
{
|
||||
insert_retvals.mutable_data()[i] = _index->insert_point(vectors.data(i), *(ids.data(i)));
|
||||
}
|
||||
|
||||
return insert_retvals;
|
||||
}
|
||||
|
||||
int mark_deleted(const IdT id)
|
||||
{
|
||||
return _index->lazy_delete(id);
|
||||
}
|
||||
|
||||
void save(const std::string &save_path = "", const bool compact_before_save = false)
|
||||
{
|
||||
const std::string path = !save_path.empty() ? save_path : _index_path;
|
||||
if (path.empty())
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"A save_path must be provided if a starting index was not provided in the DynamicMemoryIndex "
|
||||
"constructor via the index_path parameter");
|
||||
}
|
||||
_index->save(path.c_str(), compact_before_save);
|
||||
}
|
||||
|
||||
auto search(py::array_t<T, py::array::c_style | py::array::forcecast> &query, const uint64_t knn,
|
||||
const uint64_t l_search)
|
||||
const uint64_t complexity)
|
||||
{
|
||||
py::array_t<IdT> ids(knn);
|
||||
py::array_t<float> dists(knn);
|
||||
std::vector<T *> empty_vector;
|
||||
_index->search_with_tags(query.data(), knn, l_search, ids.mutable_data(), dists.mutable_data(), empty_vector);
|
||||
_index->search_with_tags(query.data(), knn, complexity, ids.mutable_data(), dists.mutable_data(), empty_vector);
|
||||
return std::make_pair(ids, dists);
|
||||
}
|
||||
|
||||
auto batch_search(py::array_t<T, py::array::c_style | py::array::forcecast> &queries, const uint64_t num_queries,
|
||||
const uint64_t knn, const uint64_t l_search, const int num_threads)
|
||||
const uint64_t knn, const uint64_t complexity, const int num_threads)
|
||||
{
|
||||
py::array_t<unsigned> ids({num_queries, knn});
|
||||
py::array_t<float> dists({num_queries, knn});
|
||||
std::vector<T *> empty_vector;
|
||||
|
||||
omp_set_num_threads(num_threads);
|
||||
if (num_threads == 0)
|
||||
omp_set_num_threads(omp_get_num_procs());
|
||||
else
|
||||
omp_set_num_threads(num_threads);
|
||||
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (int64_t i = 0; i < (int64_t)num_queries; i++)
|
||||
{
|
||||
_index->search_with_tags(queries.data(i), knn, l_search, ids.mutable_data(i), dists.mutable_data(i),
|
||||
_index->search_with_tags(queries.data(i), knn, complexity, ids.mutable_data(i), dists.mutable_data(i),
|
||||
empty_vector);
|
||||
}
|
||||
|
||||
|
@ -218,7 +263,7 @@ template <class T> struct DynamicInMemIndex
|
|||
|
||||
auto consolidate_delete()
|
||||
{
|
||||
return _index->consolidate_deletes(write_params);
|
||||
_index->consolidate_deletes(_write_params);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -226,8 +271,15 @@ template <class T> struct StaticInMemIndex
|
|||
{
|
||||
Index<T, IdT, filterT> *_index;
|
||||
|
||||
StaticInMemIndex(Metric m, const std::string &data_path, IndexWriteParameters &index_parameters)
|
||||
StaticInMemIndex(const Metric m, const std::string &data_path, const std::string &index_prefix,
|
||||
const uint32_t num_threads, const uint32_t initial_search_complexity)
|
||||
{
|
||||
const uint32_t _num_threads = num_threads != 0 ? num_threads : omp_get_num_threads();
|
||||
if (initial_search_complexity == 0)
|
||||
{
|
||||
throw std::runtime_error("initial_search_complexity must be a positive uint32_t");
|
||||
}
|
||||
|
||||
size_t ndims, npoints;
|
||||
diskann::get_bin_metadata(data_path, npoints, ndims);
|
||||
_index = new Index<T>(m, ndims, npoints,
|
||||
|
@ -238,7 +290,7 @@ template <class T> struct StaticInMemIndex
|
|||
0, // num_pq_chunks
|
||||
false, // use_opq = false
|
||||
0); // num_frozen_pts = 0
|
||||
_index->build(data_path.c_str(), npoints, index_parameters);
|
||||
_index->load(index_prefix.c_str(), _num_threads, initial_search_complexity);
|
||||
}
|
||||
|
||||
~StaticInMemIndex()
|
||||
|
@ -247,34 +299,162 @@ template <class T> struct StaticInMemIndex
|
|||
}
|
||||
|
||||
auto search(py::array_t<T, py::array::c_style | py::array::forcecast> &query, const uint64_t knn,
|
||||
const uint64_t l_search)
|
||||
const uint64_t complexity)
|
||||
{
|
||||
py::array_t<IdT> ids(knn);
|
||||
py::array_t<float> dists(knn);
|
||||
std::vector<T *> empty_vector;
|
||||
_index->search(query.data(), knn, l_search, ids.mutable_data(), dists.mutable_data());
|
||||
_index->search(query.data(), knn, complexity, ids.mutable_data(), dists.mutable_data());
|
||||
return std::make_pair(ids, dists);
|
||||
}
|
||||
|
||||
auto batch_search(py::array_t<T, py::array::c_style | py::array::forcecast> &queries, const uint64_t num_queries,
|
||||
const uint64_t knn, const uint64_t l_search, const int num_threads)
|
||||
const uint64_t knn, const uint64_t complexity, const int num_threads)
|
||||
{
|
||||
const uint32_t _num_threads = num_threads != 0 ? num_threads : omp_get_num_threads();
|
||||
py::array_t<unsigned> ids({num_queries, knn});
|
||||
py::array_t<float> dists({num_queries, knn});
|
||||
std::vector<T *> empty_vector;
|
||||
|
||||
omp_set_num_threads(num_threads);
|
||||
omp_set_num_threads(_num_threads);
|
||||
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (int64_t i = 0; i < (int64_t)num_queries; i++)
|
||||
{
|
||||
_index->search(queries.data(i), knn, l_search, ids.mutable_data(i), dists.mutable_data(i));
|
||||
_index->search(queries.data(i), knn, complexity, ids.mutable_data(i), dists.mutable_data(i));
|
||||
}
|
||||
|
||||
return std::make_pair(ids, dists);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void build_disk_index2(const diskann::Metric metric, const std::string &data_file_path,
|
||||
const std::string &index_prefix_path, const uint32_t complexity, const uint32_t graph_degree,
|
||||
const double final_index_ram_limit, const double indexing_ram_budget, const uint32_t num_threads,
|
||||
const uint32_t pq_disk_bytes)
|
||||
{
|
||||
std::string params = std::to_string(graph_degree) + " " + std::to_string(complexity) + " " +
|
||||
std::to_string(final_index_ram_limit) + " " + std::to_string(indexing_ram_budget) + " " +
|
||||
std::to_string(num_threads);
|
||||
if (pq_disk_bytes > 0)
|
||||
params = params + " " + std::to_string(pq_disk_bytes);
|
||||
diskann::build_disk_index<T>(data_file_path.c_str(), index_prefix_path.c_str(), params.c_str(), metric);
|
||||
}
|
||||
|
||||
template <typename T, typename TagT = IdT, typename LabelT = filterT>
|
||||
void build_in_memory_index(const diskann::Metric &metric, const std::string &vector_bin_path,
|
||||
const std::string &index_output_path, const uint32_t graph_degree, const uint32_t complexity,
|
||||
const float alpha, const uint32_t num_threads, const bool use_pq_build,
|
||||
const size_t num_pq_bytes, const bool use_opq, const std::string &label_file,
|
||||
const std::string &universal_label, const uint32_t filter_complexity,
|
||||
const bool use_tags = false)
|
||||
{
|
||||
diskann::IndexWriteParameters index_build_params = diskann::IndexWriteParametersBuilder(complexity, graph_degree)
|
||||
.with_filter_list_size(filter_complexity)
|
||||
.with_alpha(alpha)
|
||||
.with_saturate_graph(false)
|
||||
.with_num_threads(num_threads)
|
||||
.build();
|
||||
size_t data_num, data_dim;
|
||||
diskann::get_bin_metadata(vector_bin_path, data_num, data_dim);
|
||||
diskann::Index<T, TagT, LabelT> index(metric, data_dim, data_num, false, use_tags, false, use_pq_build,
|
||||
num_pq_bytes, use_opq);
|
||||
if (label_file == "")
|
||||
{
|
||||
index.build(vector_bin_path.c_str(), data_num, index_build_params);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::string labels_file_to_use = index_output_path + "_label_formatted.txt";
|
||||
std::string mem_labels_int_map_file = index_output_path + "_labels_map.txt";
|
||||
convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
|
||||
if (universal_label != "")
|
||||
{
|
||||
filterT unv_label_as_num = 0;
|
||||
index.set_universal_label(unv_label_as_num);
|
||||
}
|
||||
index.build_filtered_index(index_output_path.c_str(), labels_file_to_use, data_num, index_build_params);
|
||||
}
|
||||
index.save(index_output_path.c_str());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void add_variant(py::module_ &m, const std::string &build_name, const std::string &class_name)
|
||||
{
|
||||
const std::string build_disk_name = "build_disk_" + build_name + "_index";
|
||||
m.def(build_disk_name.c_str(), &build_disk_index2<T>, py::arg("metric"), py::arg("data_file_path"),
|
||||
py::arg("index_prefix_path"), py::arg("complexity"), py::arg("graph_degree"),
|
||||
py::arg("final_index_ram_limit"), py::arg("indexing_ram_budget"), py::arg("num_threads"),
|
||||
py::arg("pq_disk_bytes"));
|
||||
|
||||
const std::string build_in_memory_name = "build_in_memory_" + build_name + "_index";
|
||||
m.def(build_in_memory_name.c_str(), &build_in_memory_index<T>, py::arg("metric"), py::arg("data_file_path"),
|
||||
py::arg("index_output_path"), py::arg("graph_degree"), py::arg("complexity"), py::arg("alpha"),
|
||||
py::arg("num_threads"), py::arg("use_pq_build"), py::arg("num_pq_bytes"), py::arg("use_opq"),
|
||||
py::arg("label_file") = "", py::arg("universal_label") = "", py::arg("filter_complexity") = 0,
|
||||
py::arg("use_tags") = false);
|
||||
|
||||
const std::string static_index = "StaticMemory" + class_name + "Index";
|
||||
py::class_<StaticInMemIndex<T>>(m, static_index.c_str())
|
||||
.def(py::init([](const diskann::Metric metric, const std::string &data_path, const std::string &index_path,
|
||||
const uint32_t num_threads, const uint32_t initial_search_complexity) {
|
||||
return std::unique_ptr<StaticInMemIndex<T>>(
|
||||
new StaticInMemIndex<T>(metric, data_path, index_path, num_threads, initial_search_complexity));
|
||||
}),
|
||||
py::arg("metric"), py::arg("data_path"), py::arg("index_path"), py::arg("num_threads"),
|
||||
py::arg("initial_search_complexity"))
|
||||
.def("search", &StaticInMemIndex<T>::search, py::arg("query"), py::arg("knn"), py::arg("complexity"))
|
||||
.def("batch_search", &StaticInMemIndex<T>::batch_search, py::arg("queries"), py::arg("num_queries"),
|
||||
py::arg("knn"), py::arg("complexity"), py::arg("num_threads"));
|
||||
|
||||
const std::string dynamic_index = "DynamicMemory" + class_name + "Index";
|
||||
py::class_<DynamicInMemIndex<T>>(m, dynamic_index.c_str())
|
||||
.def(py::init([](const diskann::Metric metric, const size_t dim, const size_t max_points,
|
||||
const uint32_t complexity, const uint32_t graph_degree, const bool saturate_graph,
|
||||
const uint32_t max_occlusion_size, const float alpha, const uint32_t num_threads,
|
||||
const uint32_t filter_complexity, const uint32_t num_frozen_points,
|
||||
const uint32_t initial_search_complexity, const uint32_t search_threads,
|
||||
const bool concurrent_consolidation, const std::string &index_path) {
|
||||
return std::unique_ptr<DynamicInMemIndex<T>>(new DynamicInMemIndex<T>(
|
||||
metric, dim, max_points, complexity, graph_degree, saturate_graph, max_occlusion_size, alpha,
|
||||
num_threads, filter_complexity, num_frozen_points, initial_search_complexity, search_threads,
|
||||
concurrent_consolidation, index_path));
|
||||
}),
|
||||
py::arg("metric"), py::arg("dim"), py::arg("max_points"), py::arg("complexity"), py::arg("graph_degree"),
|
||||
py::arg("saturate_graph") = diskann::defaults::SATURATE_GRAPH,
|
||||
py::arg("max_occlusion_size") = diskann::defaults::MAX_OCCLUSION_SIZE,
|
||||
py::arg("alpha") = diskann::defaults::ALPHA, py::arg("num_threads") = diskann::defaults::NUM_THREADS,
|
||||
py::arg("filter_complexity") = diskann::defaults::FILTER_LIST_SIZE,
|
||||
py::arg("num_frozen_points") = diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC,
|
||||
py::arg("initial_search_complexity") = 0, py::arg("search_threads") = 0,
|
||||
py::arg("concurrent_consolidation") = true, py::arg("index_path") = "")
|
||||
.def("search", &DynamicInMemIndex<T>::search, py::arg("query"), py::arg("knn"), py::arg("complexity"))
|
||||
.def("batch_search", &DynamicInMemIndex<T>::batch_search, py::arg("queries"), py::arg("num_queries"),
|
||||
py::arg("knn"), py::arg("complexity"), py::arg("num_threads"))
|
||||
.def("batch_insert", &DynamicInMemIndex<T>::batch_insert, py::arg("vectors"), py::arg("ids"),
|
||||
py::arg("num_inserts"), py::arg("num_threads"))
|
||||
.def("save", &DynamicInMemIndex<T>::save, py::arg("save_path") = "", py::arg("compact_before_save") = false)
|
||||
.def("insert", &DynamicInMemIndex<T>::insert, py::arg("vector"), py::arg("id"))
|
||||
.def("mark_deleted", &DynamicInMemIndex<T>::mark_deleted, py::arg("id"))
|
||||
.def("consolidate_delete", &DynamicInMemIndex<T>::consolidate_delete);
|
||||
|
||||
const std::string disk_name = "Disk" + class_name + "Index";
|
||||
py::class_<DiskIndex<T>>(m, disk_name.c_str())
|
||||
.def(py::init([](const diskann::Metric metric, const std::string &index_path_prefix, const uint32_t num_threads,
|
||||
const size_t num_nodes_to_cache, const uint32_t cache_mechanism) {
|
||||
return std::unique_ptr<DiskIndex<T>>(
|
||||
new DiskIndex<T>(metric, index_path_prefix, num_threads, num_nodes_to_cache, cache_mechanism));
|
||||
}),
|
||||
py::arg("metric"), py::arg("index_path_prefix"), py::arg("num_threads"), py::arg("num_nodes_to_cache"),
|
||||
py::arg("cache_mechanism") = 1)
|
||||
.def("cache_bfs_levels", &DiskIndex<T>::cache_bfs_levels, py::arg("num_nodes_to_cache"))
|
||||
.def("search", &DiskIndex<T>::search, py::arg("query"), py::arg("knn"), py::arg("complexity"),
|
||||
py::arg("beam_width"))
|
||||
.def("batch_search", &DiskIndex<T>::batch_search, py::arg("queries"), py::arg("num_queries"), py::arg("knn"),
|
||||
py::arg("complexity"), py::arg("beam_width"), py::arg("num_threads"));
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_diskannpy, m)
|
||||
{
|
||||
m.doc() = "DiskANN Python Bindings";
|
||||
|
@ -284,164 +464,33 @@ PYBIND11_MODULE(_diskannpy, m)
|
|||
m.attr("__version__") = "dev";
|
||||
#endif
|
||||
|
||||
// let's re-export our defaults
|
||||
py::module_ default_values = m.def_submodule(
|
||||
"defaults",
|
||||
"A collection of the default values used for common diskann operations. `GRAPH_DEGREE` and `COMPLEXITY` are not"
|
||||
" set as defaults, but some semi-reasonable default values are selected for your convenience. We urge you to "
|
||||
"investigate their meaning and adjust them for your use cases.");
|
||||
|
||||
default_values.attr("ALPHA") = diskann::defaults::ALPHA;
|
||||
default_values.attr("NUM_THREADS") = diskann::defaults::NUM_THREADS;
|
||||
default_values.attr("MAX_OCCLUSION_SIZE") = diskann::defaults::MAX_OCCLUSION_SIZE;
|
||||
default_values.attr("FILTER_COMPLEXITY") = diskann::defaults::FILTER_LIST_SIZE;
|
||||
default_values.attr("NUM_FROZEN_POINTS_STATIC") = diskann::defaults::NUM_FROZEN_POINTS_STATIC;
|
||||
default_values.attr("NUM_FROZEN_POINTS_DYNAMIC") = diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC;
|
||||
default_values.attr("SATURATE_GRAPH") = diskann::defaults::SATURATE_GRAPH;
|
||||
default_values.attr("GRAPH_DEGREE") = diskann::defaults::MAX_DEGREE;
|
||||
default_values.attr("COMPLEXITY") = diskann::defaults::BUILD_LIST_SIZE;
|
||||
default_values.attr("PQ_DISK_BYTES") = (uint32_t)0;
|
||||
default_values.attr("USE_PQ_BUILD") = false;
|
||||
default_values.attr("NUM_PQ_BYTES") = (uint32_t)0;
|
||||
default_values.attr("USE_OPQ") = false;
|
||||
|
||||
add_variant<float>(m, "float", "Float");
|
||||
add_variant<uint8_t>(m, "uint8", "UInt8");
|
||||
add_variant<int8_t>(m, "int8", "Int8");
|
||||
|
||||
py::enum_<Metric>(m, "Metric")
|
||||
.value("L2", Metric::L2)
|
||||
.value("INNER_PRODUCT", Metric::INNER_PRODUCT)
|
||||
.export_values();
|
||||
|
||||
py::class_<StaticInMemIndex<float>>(m, "DiskANNStaticInMemFloatIndex")
|
||||
.def(py::init([](diskann::Metric metric, const std::string &data_path, IndexWriteParameters &index_parameters) {
|
||||
return std::unique_ptr<StaticInMemIndex<float>>(
|
||||
new StaticInMemIndex<float>(metric, data_path, index_parameters));
|
||||
}))
|
||||
.def("search", &StaticInMemIndex<float>::search, py::arg("query"), py::arg("knn"), py::arg("l_search"))
|
||||
.def("batch_search", &StaticInMemIndex<float>::batch_search, py::arg("queries"), py::arg("num_queries"),
|
||||
py::arg("knn"), py::arg("l_search"), py::arg("num_threads"));
|
||||
|
||||
py::class_<StaticInMemIndex<int8_t>>(m, "DiskANNStaticInMemInt8Index")
|
||||
.def(py::init([](diskann::Metric metric, const std::string &data_path, IndexWriteParameters &index_parameters) {
|
||||
return std::unique_ptr<StaticInMemIndex<int8_t>>(
|
||||
new StaticInMemIndex<int8_t>(metric, data_path, index_parameters));
|
||||
}))
|
||||
.def("search", &StaticInMemIndex<int8_t>::search, py::arg("query"), py::arg("knn"), py::arg("l_search"))
|
||||
.def("batch_search", &StaticInMemIndex<int8_t>::batch_search, py::arg("queries"), py::arg("num_queries"),
|
||||
py::arg("knn"), py::arg("l_search"), py::arg("num_threads"));
|
||||
|
||||
py::class_<StaticInMemIndex<uint8_t>>(m, "DiskANNStaticInMemUint8Index")
|
||||
.def(py::init([](diskann::Metric metric, const std::string &data_path, IndexWriteParameters &index_parameters) {
|
||||
return std::unique_ptr<StaticInMemIndex<uint8_t>>(
|
||||
new StaticInMemIndex<uint8_t>(metric, data_path, index_parameters));
|
||||
}))
|
||||
.def("search", &StaticInMemIndex<uint8_t>::search, py::arg("query"), py::arg("knn"), py::arg("l_search"))
|
||||
.def("batch_search", &StaticInMemIndex<uint8_t>::batch_search, py::arg("queries"), py::arg("num_queries"),
|
||||
py::arg("knn"), py::arg("l_search"), py::arg("num_threads"));
|
||||
|
||||
py::class_<DynamicInMemIndex<float>>(m, "DiskANNDynamicInMemFloatIndex")
|
||||
.def(py::init([](diskann::Metric metric, const size_t dim, const size_t max_points,
|
||||
const IndexWriteParameters &index_parameters, const uint32_t initial_search_list_size,
|
||||
const uint32_t search_threads, const bool concurrent_consolidate) {
|
||||
return std::unique_ptr<DynamicInMemIndex<float>>(
|
||||
new DynamicInMemIndex<float>(metric, dim, max_points, index_parameters, initial_search_list_size,
|
||||
search_threads, concurrent_consolidate));
|
||||
}))
|
||||
.def("search", &DynamicInMemIndex<float>::search, py::arg("query"), py::arg("knn"), py::arg("l_search"))
|
||||
.def("batch_search", &DynamicInMemIndex<float>::batch_search, py::arg("queries"), py::arg("num_queries"),
|
||||
py::arg("knn"), py::arg("l_search"), py::arg("num_threads"))
|
||||
.def("insert", &DynamicInMemIndex<float>::insert, py::arg("vector"), py::arg("id"))
|
||||
.def("mark_deleted", &DynamicInMemIndex<float>::mark_deleted, py::arg("id"))
|
||||
.def("consolidate_delete", &DynamicInMemIndex<float>::consolidate_delete);
|
||||
|
||||
py::class_<DynamicInMemIndex<int8_t>>(m, "DiskANNDynamicInMemInt8Index")
|
||||
.def(py::init([](diskann::Metric metric, const size_t dim, const size_t max_points,
|
||||
const IndexWriteParameters &index_parameters, const uint32_t initial_search_list_size,
|
||||
const uint32_t search_threads, const bool concurrent_consolidate) {
|
||||
return std::unique_ptr<DynamicInMemIndex<int8_t>>(
|
||||
new DynamicInMemIndex<int8_t>(metric, dim, max_points, index_parameters, initial_search_list_size,
|
||||
search_threads, concurrent_consolidate));
|
||||
}))
|
||||
.def("search", &DynamicInMemIndex<int8_t>::search, py::arg("query"), py::arg("knn"), py::arg("l_search"))
|
||||
.def("batch_search", &DynamicInMemIndex<int8_t>::batch_search, py::arg("queries"), py::arg("num_queries"),
|
||||
py::arg("knn"), py::arg("l_search"), py::arg("num_threads"))
|
||||
.def("insert", &DynamicInMemIndex<int8_t>::insert, py::arg("vector"), py::arg("id"))
|
||||
.def("mark_deleted", &DynamicInMemIndex<int8_t>::mark_deleted, py::arg("id"))
|
||||
.def("consolidate_delete", &DynamicInMemIndex<int8_t>::consolidate_delete);
|
||||
|
||||
py::class_<DynamicInMemIndex<uint8_t>>(m, "DiskANNDynamicInMemUint8Index")
|
||||
.def(py::init([](diskann::Metric metric, const size_t dim, const size_t max_points,
|
||||
const IndexWriteParameters &index_parameters, const uint32_t initial_search_list_size,
|
||||
const uint32_t search_threads, const bool concurrent_consolidate) {
|
||||
return std::unique_ptr<DynamicInMemIndex<uint8_t>>(
|
||||
new DynamicInMemIndex<uint8_t>(metric, dim, max_points, index_parameters, initial_search_list_size,
|
||||
search_threads, concurrent_consolidate));
|
||||
}))
|
||||
.def("search", &DynamicInMemIndex<uint8_t>::search, py::arg("query"), py::arg("knn"), py::arg("l_search"))
|
||||
.def("batch_search", &DynamicInMemIndex<uint8_t>::batch_search, py::arg("queries"), py::arg("num_queries"),
|
||||
py::arg("knn"), py::arg("l_search"), py::arg("num_threads"))
|
||||
.def("insert", &DynamicInMemIndex<uint8_t>::insert, py::arg("vector"), py::arg("id"))
|
||||
.def("mark_deleted", &DynamicInMemIndex<uint8_t>::mark_deleted, py::arg("id"))
|
||||
.def("consolidate_delete", &DynamicInMemIndex<uint8_t>::consolidate_delete);
|
||||
|
||||
py::class_<DiskANNIndex<float>>(m, "DiskANNFloatIndex")
|
||||
.def(py::init([](diskann::Metric metric) {
|
||||
return std::unique_ptr<DiskANNIndex<float>>(new DiskANNIndex<float>(metric));
|
||||
}))
|
||||
.def("cache_bfs_levels", &DiskANNIndex<float>::cache_bfs_levels, py::arg("num_nodes_to_cache"))
|
||||
.def("load_index", &DiskANNIndex<float>::load_index, py::arg("index_path_prefix"), py::arg("num_threads"),
|
||||
py::arg("num_nodes_to_cache"), py::arg("cache_mechanism") = 1)
|
||||
.def("search", &DiskANNIndex<float>::search, py::arg("query"), py::arg("knn"), py::arg("l_search"),
|
||||
py::arg("beam_width"))
|
||||
.def("batch_search", &DiskANNIndex<float>::batch_search, py::arg("queries"), py::arg("num_queries"),
|
||||
py::arg("knn"), py::arg("l_search"), py::arg("beam_width"), py::arg("num_threads"))
|
||||
.def(
|
||||
"build",
|
||||
[](DiskANNIndex<float> &self, const char *data_file_path, const char *index_prefix_path, unsigned R,
|
||||
unsigned L, double final_index_ram_limit, double indexing_ram_budget, unsigned num_threads,
|
||||
unsigned pq_disk_bytes) {
|
||||
std::string params = std::to_string(R) + " " + std::to_string(L) + " " +
|
||||
std::to_string(final_index_ram_limit) + " " + std::to_string(indexing_ram_budget) +
|
||||
" " + std::to_string(num_threads);
|
||||
if (pq_disk_bytes > 0)
|
||||
{
|
||||
params = params + " " + std::to_string(pq_disk_bytes);
|
||||
}
|
||||
diskann::build_disk_index<float>(data_file_path, index_prefix_path, params.c_str(), self.get_metric());
|
||||
},
|
||||
py::arg("data_file_path"), py::arg("index_prefix_path"), py::arg("R"), py::arg("L"),
|
||||
py::arg("final_index_ram_limit"), py::arg("indexing_ram_limit"), py::arg("num_threads"),
|
||||
py::arg("pq_disk_bytes") = 0);
|
||||
|
||||
py::class_<DiskANNIndex<int8_t>>(m, "DiskANNInt8Index")
|
||||
.def(py::init([](diskann::Metric metric) {
|
||||
return std::unique_ptr<DiskANNIndex<int8_t>>(new DiskANNIndex<int8_t>(metric));
|
||||
}))
|
||||
.def("cache_bfs_levels", &DiskANNIndex<int8_t>::cache_bfs_levels, py::arg("num_nodes_to_cache"))
|
||||
.def("load_index", &DiskANNIndex<int8_t>::load_index, py::arg("index_path_prefix"), py::arg("num_threads"),
|
||||
py::arg("num_nodes_to_cache"), py::arg("cache_mechanism") = 1)
|
||||
.def("search", &DiskANNIndex<int8_t>::search, py::arg("query"), py::arg("knn"), py::arg("l_search"),
|
||||
py::arg("beam_width"))
|
||||
.def("batch_search", &DiskANNIndex<int8_t>::batch_search, py::arg("queries"), py::arg("num_queries"),
|
||||
py::arg("knn"), py::arg("l_search"), py::arg("beam_width"), py::arg("num_threads"))
|
||||
.def(
|
||||
"build",
|
||||
[](DiskANNIndex<int8_t> &self, const char *data_file_path, const char *index_prefix_path, unsigned R,
|
||||
unsigned L, double final_index_ram_limit, double indexing_ram_budget, unsigned num_threads,
|
||||
unsigned pq_disk_bytes) {
|
||||
std::string params = std::to_string(R) + " " + std::to_string(L) + " " +
|
||||
std::to_string(final_index_ram_limit) + " " + std::to_string(indexing_ram_budget) +
|
||||
" " + std::to_string(num_threads);
|
||||
if (pq_disk_bytes > 0)
|
||||
params = params + " " + std::to_string(pq_disk_bytes);
|
||||
diskann::build_disk_index<int8_t>(data_file_path, index_prefix_path, params.c_str(), self.get_metric());
|
||||
},
|
||||
py::arg("data_file_path"), py::arg("index_prefix_path"), py::arg("R"), py::arg("L"),
|
||||
py::arg("final_index_ram_limit"), py::arg("indexing_ram_limit"), py::arg("num_threads"),
|
||||
py::arg("pq_disk_bytes") = 0);
|
||||
|
||||
py::class_<DiskANNIndex<uint8_t>>(m, "DiskANNUInt8Index")
|
||||
.def(py::init([](diskann::Metric metric) {
|
||||
return std::unique_ptr<DiskANNIndex<uint8_t>>(new DiskANNIndex<uint8_t>(metric));
|
||||
}))
|
||||
.def("cache_bfs_levels", &DiskANNIndex<uint8_t>::cache_bfs_levels, py::arg("num_nodes_to_cache"))
|
||||
.def("load_index", &DiskANNIndex<uint8_t>::load_index, py::arg("index_path_prefix"), py::arg("num_threads"),
|
||||
py::arg("num_nodes_to_cache"), py::arg("cache_mechanism") = 1)
|
||||
.def("search", &DiskANNIndex<uint8_t>::search, py::arg("query"), py::arg("knn"), py::arg("l_search"),
|
||||
py::arg("beam_width"))
|
||||
.def("batch_search", &DiskANNIndex<uint8_t>::batch_search, py::arg("queries"), py::arg("num_queries"),
|
||||
py::arg("knn"), py::arg("l_search"), py::arg("beam_width"), py::arg("num_threads"))
|
||||
.def(
|
||||
"build",
|
||||
[](DiskANNIndex<uint8_t> &self, const char *data_file_path, const char *index_prefix_path, unsigned R,
|
||||
unsigned L, double final_index_ram_limit, double indexing_ram_budget, unsigned num_threads,
|
||||
unsigned pq_disk_bytes) {
|
||||
std::string params = std::to_string(R) + " " + std::to_string(L) + " " +
|
||||
std::to_string(final_index_ram_limit) + " " + std::to_string(indexing_ram_budget) +
|
||||
" " + std::to_string(num_threads);
|
||||
if (pq_disk_bytes > 0)
|
||||
params = params + " " + std::to_string(pq_disk_bytes);
|
||||
diskann::build_disk_index<uint8_t>(data_file_path, index_prefix_path, params.c_str(),
|
||||
self.get_metric());
|
||||
},
|
||||
py::arg("data_file_path"), py::arg("index_prefix_path"), py::arg("R"), py::arg("L"),
|
||||
py::arg("final_index_ram_limit"), py::arg("indexing_ram_limit"), py::arg("num_threads"),
|
||||
py::arg("pq_disk_bytes") = 0);
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .create_test_data import (random_vectors, vectors_as_temp_file,
|
||||
write_vectors)
|
||||
from .build_memory_index import build_random_vectors_and_memory_index
|
||||
from .create_test_data import random_vectors, vectors_as_temp_file, write_vectors
|
||||
from .recall import calculate_recall
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
|
||||
from tempfile import mkdtemp
|
||||
|
||||
import diskannpy as dap
|
||||
import numpy as np
|
||||
|
||||
from .create_test_data import random_vectors
|
||||
|
||||
|
||||
def build_random_vectors_and_memory_index(
|
||||
dtype,
|
||||
metric,
|
||||
with_tags: bool = False,
|
||||
index_prefix: str ="ann",
|
||||
seed: int = 12345
|
||||
):
|
||||
query_vectors: np.ndarray = random_vectors(1000, 10, dtype=dtype, seed=seed)
|
||||
index_vectors: np.ndarray = random_vectors(10000, 10, dtype=dtype, seed=seed)
|
||||
ann_dir = mkdtemp()
|
||||
|
||||
if with_tags:
|
||||
rng = np.random.default_rng(seed)
|
||||
tags = np.arange(start=1, stop=10001, dtype=np.uint32)
|
||||
rng.shuffle(tags)
|
||||
else:
|
||||
tags = None
|
||||
|
||||
dap.build_memory_index(
|
||||
data=index_vectors,
|
||||
metric=metric,
|
||||
index_directory=ann_dir,
|
||||
graph_degree=16,
|
||||
complexity=32,
|
||||
alpha=1.2,
|
||||
num_threads=0,
|
||||
use_pq_build=False,
|
||||
num_pq_bytes=8,
|
||||
use_opq=False,
|
||||
filter_complexity=32,
|
||||
index_prefix=index_prefix
|
||||
)
|
||||
return (
|
||||
metric,
|
||||
dtype,
|
||||
query_vectors,
|
||||
index_vectors,
|
||||
ann_dir,
|
||||
os.path.join(ann_dir, "vectors.bin"),
|
||||
tags
|
||||
)
|
|
@ -0,0 +1,137 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import unittest
|
||||
|
||||
import diskannpy as dap
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestBuildDiskIndex(unittest.TestCase):
|
||||
def test_valid_shape(self):
|
||||
rng = np.random.default_rng(12345)
|
||||
rando = rng.random((1000, 100, 5), dtype=np.single)
|
||||
with self.assertRaises(ValueError):
|
||||
dap.build_disk_index(
|
||||
data=rando,
|
||||
metric="l2",
|
||||
index_directory="test",
|
||||
complexity=5,
|
||||
graph_degree=5,
|
||||
search_memory_maximum=0.01,
|
||||
build_memory_maximum=0.01,
|
||||
num_threads=1,
|
||||
pq_disk_bytes=0,
|
||||
)
|
||||
|
||||
rando = rng.random(1000, dtype=np.single)
|
||||
with self.assertRaises(ValueError):
|
||||
dap.build_disk_index(
|
||||
data=rando,
|
||||
metric="l2",
|
||||
index_directory="test",
|
||||
complexity=5,
|
||||
graph_degree=5,
|
||||
search_memory_maximum=0.01,
|
||||
build_memory_maximum=0.01,
|
||||
num_threads=1,
|
||||
pq_disk_bytes=0,
|
||||
)
|
||||
|
||||
def test_value_ranges_build(self):
|
||||
good_ranges = {
|
||||
"vector_dtype": np.single,
|
||||
"metric": "l2",
|
||||
"graph_degree": 5,
|
||||
"complexity": 5,
|
||||
"search_memory_maximum": 0.01,
|
||||
"build_memory_maximum": 0.01,
|
||||
"num_threads": 1,
|
||||
"pq_disk_bytes": 0,
|
||||
}
|
||||
bad_ranges = {
|
||||
"vector_dtype": np.float64,
|
||||
"metric": "soups this time",
|
||||
"graph_degree": -1,
|
||||
"complexity": -1,
|
||||
"search_memory_maximum": 0,
|
||||
"build_memory_maximum": 0,
|
||||
"num_threads": -1,
|
||||
"pq_disk_bytes": -1,
|
||||
}
|
||||
for bad_value_key in good_ranges.keys():
|
||||
kwargs = good_ranges.copy()
|
||||
kwargs[bad_value_key] = bad_ranges[bad_value_key]
|
||||
with self.subTest(
|
||||
f"testing bad value key: {bad_value_key} with bad value: {bad_ranges[bad_value_key]}"
|
||||
):
|
||||
with self.assertRaises(ValueError):
|
||||
dap.build_disk_index(data="test", index_directory="test", **kwargs)
|
||||
|
||||
|
||||
class TestBuildMemoryIndex(unittest.TestCase):
|
||||
def test_valid_shape(self):
|
||||
rng = np.random.default_rng(12345)
|
||||
rando = rng.random((1000, 100, 5), dtype=np.single)
|
||||
with self.assertRaises(ValueError):
|
||||
dap.build_memory_index(
|
||||
data=rando,
|
||||
metric="l2",
|
||||
index_directory="test",
|
||||
complexity=5,
|
||||
graph_degree=5,
|
||||
alpha=1.2,
|
||||
num_threads=1,
|
||||
use_pq_build=False,
|
||||
num_pq_bytes=0,
|
||||
use_opq=False,
|
||||
)
|
||||
|
||||
rando = rng.random(1000, dtype=np.single)
|
||||
with self.assertRaises(ValueError):
|
||||
dap.build_memory_index(
|
||||
data=rando,
|
||||
metric="l2",
|
||||
index_directory="test",
|
||||
complexity=5,
|
||||
graph_degree=5,
|
||||
alpha=1.2,
|
||||
num_threads=1,
|
||||
use_pq_build=False,
|
||||
num_pq_bytes=0,
|
||||
use_opq=False,
|
||||
)
|
||||
|
||||
def test_value_ranges_build(self):
|
||||
good_ranges = {
|
||||
"vector_dtype": np.single,
|
||||
"metric": "l2",
|
||||
"graph_degree": 5,
|
||||
"complexity": 5,
|
||||
"alpha": 1.2,
|
||||
"num_threads": 1,
|
||||
"num_pq_bytes": 0,
|
||||
}
|
||||
bad_ranges = {
|
||||
"vector_dtype": np.float64,
|
||||
"metric": "soups this time",
|
||||
"graph_degree": -1,
|
||||
"complexity": -1,
|
||||
"alpha": -1.2,
|
||||
"num_threads": 1,
|
||||
"num_pq_bytes": -60,
|
||||
}
|
||||
for bad_value_key in good_ranges.keys():
|
||||
kwargs = good_ranges.copy()
|
||||
kwargs[bad_value_key] = bad_ranges[bad_value_key]
|
||||
with self.subTest(
|
||||
f"testing bad value key: {bad_value_key} with bad value: {bad_ranges[bad_value_key]}"
|
||||
):
|
||||
with self.assertRaises(ValueError):
|
||||
dap.build_memory_index(
|
||||
data="test",
|
||||
index_directory="test",
|
||||
use_pq_build=True,
|
||||
use_opq=False,
|
||||
**kwargs,
|
||||
)
|
|
@ -7,57 +7,8 @@ from tempfile import mkdtemp
|
|||
|
||||
import diskannpy as dap
|
||||
import numpy as np
|
||||
from sklearn.neighbors import NearestNeighbors
|
||||
|
||||
from fixtures import calculate_recall, random_vectors, vectors_as_temp_file
|
||||
|
||||
|
||||
class TestBuildIndex(unittest.TestCase):
|
||||
def test_valid_shape(self):
|
||||
rng = np.random.default_rng(12345)
|
||||
rando = rng.random((1000, 100, 5), dtype=np.single)
|
||||
with self.assertRaises(ValueError):
|
||||
dap.build_disk_index_from_vectors(
|
||||
rando, "l2", "test", 5, 5, 0.01, 0.01, 1, 0
|
||||
)
|
||||
|
||||
rando = rng.random(1000, dtype=np.single)
|
||||
with self.assertRaises(ValueError):
|
||||
dap.build_disk_index_from_vectors(
|
||||
rando, "l2", "test", 5, 5, 0.01, 0.01, 1, 0
|
||||
)
|
||||
|
||||
def test_value_ranges_build(self):
|
||||
good_ranges = {
|
||||
"vector_dtype": np.single,
|
||||
"metric": "l2",
|
||||
"max_degree": 5,
|
||||
"list_size": 5,
|
||||
"search_memory_maximum": 0.01,
|
||||
"build_memory_maximum": 0.01,
|
||||
"num_threads": 1,
|
||||
"pq_disk_bytes": 0,
|
||||
}
|
||||
bad_ranges = {
|
||||
"vector_dtype": np.float64,
|
||||
"metric": "soups this time",
|
||||
"max_degree": -1,
|
||||
"list_size": -1,
|
||||
"search_memory_maximum": 0,
|
||||
"build_memory_maximum": 0,
|
||||
"num_threads": -1,
|
||||
"pq_disk_bytes": -1,
|
||||
}
|
||||
for bad_value_key in good_ranges.keys():
|
||||
kwargs = good_ranges.copy()
|
||||
kwargs[bad_value_key] = bad_ranges[bad_value_key]
|
||||
with self.subTest(
|
||||
f"testing bad value key: {bad_value_key} with bad value: {bad_ranges[bad_value_key]}"
|
||||
):
|
||||
with self.assertRaises(ValueError):
|
||||
dap.build_disk_index_from_vector_file(
|
||||
vector_bin_file="test", index_path="test", **kwargs
|
||||
)
|
||||
from sklearn.neighbors import NearestNeighbors
|
||||
|
||||
|
||||
def _build_random_vectors_and_index(dtype, metric):
|
||||
|
@ -65,13 +16,13 @@ def _build_random_vectors_and_index(dtype, metric):
|
|||
index_vectors = random_vectors(10000, 10, dtype=dtype)
|
||||
with vectors_as_temp_file(index_vectors) as vector_temp:
|
||||
ann_dir = mkdtemp()
|
||||
dap.build_disk_index_from_vector_file(
|
||||
vector_bin_file=vector_temp,
|
||||
dap.build_disk_index(
|
||||
data=vector_temp,
|
||||
metric=metric,
|
||||
vector_dtype=dtype,
|
||||
index_path=ann_dir,
|
||||
max_degree=16,
|
||||
list_size=32,
|
||||
index_directory=ann_dir,
|
||||
graph_degree=16,
|
||||
complexity=32,
|
||||
search_memory_maximum=0.00003,
|
||||
build_memory_maximum=1,
|
||||
num_threads=1,
|
||||
|
@ -105,7 +56,7 @@ class TestDiskIndex(unittest.TestCase):
|
|||
index = dap.DiskIndex(
|
||||
metric="l2",
|
||||
vector_dtype=dtype,
|
||||
index_path=ann_dir,
|
||||
index_directory=ann_dir,
|
||||
num_threads=16,
|
||||
num_nodes_to_cache=10,
|
||||
)
|
||||
|
@ -114,7 +65,7 @@ class TestDiskIndex(unittest.TestCase):
|
|||
diskann_neighbors, diskann_distances = index.batch_search(
|
||||
query_vectors,
|
||||
k_neighbors=k,
|
||||
list_size=5,
|
||||
complexity=5,
|
||||
beam_width=2,
|
||||
num_threads=16,
|
||||
)
|
||||
|
@ -124,10 +75,11 @@ class TestDiskIndex(unittest.TestCase):
|
|||
)
|
||||
knn.fit(index_vectors)
|
||||
knn_distances, knn_indices = knn.kneighbors(query_vectors)
|
||||
recall = calculate_recall(diskann_neighbors, knn_indices, k)
|
||||
self.assertTrue(
|
||||
calculate_recall(diskann_neighbors, knn_indices, k) > 0.70,
|
||||
"Recall was not over 0.7",
|
||||
)
|
||||
recall > 0.70,
|
||||
f"Recall [{recall}] was not over 0.7",
|
||||
)
|
||||
|
||||
def test_single(self):
|
||||
for metric, dtype, query_vectors, index_vectors, ann_dir in self._test_matrix:
|
||||
|
@ -135,14 +87,14 @@ class TestDiskIndex(unittest.TestCase):
|
|||
index = dap.DiskIndex(
|
||||
metric="l2",
|
||||
vector_dtype=dtype,
|
||||
index_path=ann_dir,
|
||||
index_directory=ann_dir,
|
||||
num_threads=16,
|
||||
num_nodes_to_cache=10,
|
||||
)
|
||||
|
||||
k = 5
|
||||
ids, dists = index.search(
|
||||
query_vectors[0], k_neighbors=k, list_size=5, beam_width=2
|
||||
query_vectors[0], k_neighbors=k, complexity=5, beam_width=2
|
||||
)
|
||||
self.assertEqual(ids.shape[0], k)
|
||||
self.assertEqual(dists.shape[0], k)
|
||||
|
@ -153,7 +105,7 @@ class TestDiskIndex(unittest.TestCase):
|
|||
dap.DiskIndex(
|
||||
metric="sandwich",
|
||||
vector_dtype=np.single,
|
||||
index_path=ann_dir,
|
||||
index_directory=ann_dir,
|
||||
num_threads=16,
|
||||
num_nodes_to_cache=10,
|
||||
)
|
||||
|
@ -161,28 +113,28 @@ class TestDiskIndex(unittest.TestCase):
|
|||
dap.DiskIndex(
|
||||
metric=None,
|
||||
vector_dtype=np.single,
|
||||
index_path=ann_dir,
|
||||
index_directory=ann_dir,
|
||||
num_threads=16,
|
||||
num_nodes_to_cache=10,
|
||||
)
|
||||
dap.DiskIndex(
|
||||
metric="l2",
|
||||
vector_dtype=np.single,
|
||||
index_path=ann_dir,
|
||||
index_directory=ann_dir,
|
||||
num_threads=16,
|
||||
num_nodes_to_cache=10,
|
||||
)
|
||||
dap.DiskIndex(
|
||||
metric="mips",
|
||||
vector_dtype=np.single,
|
||||
index_path=ann_dir,
|
||||
index_directory=ann_dir,
|
||||
num_threads=16,
|
||||
num_nodes_to_cache=10,
|
||||
)
|
||||
dap.DiskIndex(
|
||||
metric="MiPs",
|
||||
vector_dtype=np.single,
|
||||
index_path=ann_dir,
|
||||
index_directory=ann_dir,
|
||||
num_threads=16,
|
||||
num_nodes_to_cache=10,
|
||||
)
|
||||
|
@ -194,7 +146,7 @@ class TestDiskIndex(unittest.TestCase):
|
|||
index = dap.DiskIndex(
|
||||
metric="l2",
|
||||
vector_dtype=aliases[dtype],
|
||||
index_path=ann_dir,
|
||||
index_directory=ann_dir,
|
||||
num_threads=16,
|
||||
num_nodes_to_cache=10,
|
||||
)
|
||||
|
@ -206,14 +158,14 @@ class TestDiskIndex(unittest.TestCase):
|
|||
dap.DiskIndex(
|
||||
metric="l2",
|
||||
vector_dtype=invalid_vector_dtype,
|
||||
index_path=ann_dir,
|
||||
index_directory=ann_dir,
|
||||
num_threads=16,
|
||||
num_nodes_to_cache=10,
|
||||
)
|
||||
|
||||
def test_value_ranges_search(self):
|
||||
good_ranges = {"list_size": 5, "k_neighbors": 10, "beam_width": 2}
|
||||
bad_ranges = {"list_size": -1, "k_neighbors": 0, "beam_width": 0}
|
||||
good_ranges = {"complexity": 5, "k_neighbors": 10, "beam_width": 2}
|
||||
bad_ranges = {"complexity": -1, "k_neighbors": 0, "beam_width": 0}
|
||||
for bad_value_key in good_ranges.keys():
|
||||
kwargs = good_ranges.copy()
|
||||
kwargs[bad_value_key] = bad_ranges[bad_value_key]
|
||||
|
@ -222,7 +174,7 @@ class TestDiskIndex(unittest.TestCase):
|
|||
index = dap.DiskIndex(
|
||||
metric="l2",
|
||||
vector_dtype=np.single,
|
||||
index_path=self._example_ann_dir,
|
||||
index_directory=self._example_ann_dir,
|
||||
num_threads=16,
|
||||
num_nodes_to_cache=10,
|
||||
)
|
||||
|
@ -230,13 +182,13 @@ class TestDiskIndex(unittest.TestCase):
|
|||
|
||||
def test_value_ranges_batch_search(self):
|
||||
good_ranges = {
|
||||
"list_size": 5,
|
||||
"complexity": 5,
|
||||
"k_neighbors": 10,
|
||||
"beam_width": 2,
|
||||
"num_threads": 5,
|
||||
}
|
||||
bad_ranges = {
|
||||
"list_size": 0,
|
||||
"complexity": 0,
|
||||
"k_neighbors": 0,
|
||||
"beam_width": -1,
|
||||
"num_threads": -1,
|
||||
|
@ -249,7 +201,7 @@ class TestDiskIndex(unittest.TestCase):
|
|||
index = dap.DiskIndex(
|
||||
metric="l2",
|
||||
vector_dtype=np.single,
|
||||
index_path=self._example_ann_dir,
|
||||
index_directory=self._example_ann_dir,
|
||||
num_threads=16,
|
||||
num_nodes_to_cache=10,
|
||||
)
|
||||
|
|
|
@ -0,0 +1,298 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import shutil
|
||||
import unittest
|
||||
|
||||
import diskannpy as dap
|
||||
import numpy as np
|
||||
from fixtures import build_random_vectors_and_memory_index
|
||||
from sklearn.neighbors import NearestNeighbors
|
||||
|
||||
|
||||
def _calculate_recall(
|
||||
result_set_tags: np.ndarray,
|
||||
original_indices_to_tags: np.ndarray,
|
||||
truth_set_indices: np.ndarray,
|
||||
recall_at: int = 5
|
||||
) -> float:
|
||||
|
||||
found = 0
|
||||
for i in range(0, result_set_tags.shape[0]):
|
||||
result_set_set = set(result_set_tags[i][0:recall_at])
|
||||
truth_set_set = set()
|
||||
for knn_index in truth_set_indices[i][0:recall_at]:
|
||||
truth_set_set.add(original_indices_to_tags[knn_index]) # mapped into our tag number instead
|
||||
found += len(result_set_set.intersection(truth_set_set))
|
||||
return found / (result_set_tags.shape[0] * recall_at)
|
||||
|
||||
|
||||
class TestDynamicMemoryIndex(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
cls._test_matrix = [
|
||||
build_random_vectors_and_memory_index(np.single, "l2", with_tags=True),
|
||||
build_random_vectors_and_memory_index(np.ubyte, "l2", with_tags=True),
|
||||
build_random_vectors_and_memory_index(np.byte, "l2", with_tags=True),
|
||||
]
|
||||
cls._example_ann_dir = cls._test_matrix[0][4]
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
for test in cls._test_matrix:
|
||||
try:
|
||||
ann_dir = test[4]
|
||||
shutil.rmtree(ann_dir, ignore_errors=True)
|
||||
except:
|
||||
pass
|
||||
|
||||
def test_recall_and_batch(self):
|
||||
for (
|
||||
metric,
|
||||
dtype,
|
||||
query_vectors,
|
||||
index_vectors,
|
||||
ann_dir,
|
||||
vector_bin_file,
|
||||
generated_tags
|
||||
) in self._test_matrix:
|
||||
with self.subTest():
|
||||
index = dap.DynamicMemoryIndex(
|
||||
metric="l2",
|
||||
vector_dtype=dtype,
|
||||
dim=10,
|
||||
max_points=11_000,
|
||||
complexity=64,
|
||||
graph_degree=32,
|
||||
num_threads=16,
|
||||
)
|
||||
index.batch_insert(vectors=index_vectors, vector_ids=generated_tags)
|
||||
|
||||
k = 5
|
||||
diskann_neighbors, diskann_distances = index.batch_search(
|
||||
query_vectors,
|
||||
k_neighbors=k,
|
||||
complexity=5,
|
||||
num_threads=16,
|
||||
)
|
||||
if metric == "l2":
|
||||
knn = NearestNeighbors(
|
||||
n_neighbors=100, algorithm="auto", metric="l2"
|
||||
)
|
||||
knn.fit(index_vectors)
|
||||
knn_distances, knn_indices = knn.kneighbors(query_vectors)
|
||||
recall = _calculate_recall(diskann_neighbors, generated_tags, knn_indices, k)
|
||||
self.assertTrue(
|
||||
recall > 0.70,
|
||||
f"Recall [{recall}] was not over 0.7",
|
||||
)
|
||||
|
||||
def test_single(self):
|
||||
for (
|
||||
metric,
|
||||
dtype,
|
||||
query_vectors,
|
||||
index_vectors,
|
||||
ann_dir,
|
||||
vector_bin_file,
|
||||
generated_tags
|
||||
) in self._test_matrix:
|
||||
with self.subTest():
|
||||
index = dap.DynamicMemoryIndex(
|
||||
metric="l2",
|
||||
vector_dtype=dtype,
|
||||
dim=10,
|
||||
max_points=11_000,
|
||||
complexity=64,
|
||||
graph_degree=32,
|
||||
num_threads=16,
|
||||
)
|
||||
index.batch_insert(vectors=index_vectors, vector_ids=generated_tags)
|
||||
|
||||
k = 5
|
||||
ids, dists = index.search(query_vectors[0], k_neighbors=k, complexity=5)
|
||||
self.assertEqual(ids.shape[0], k)
|
||||
self.assertEqual(dists.shape[0], k)
|
||||
|
||||
def test_valid_metric(self):
|
||||
ann_dir = self._example_ann_dir
|
||||
vector_bin_file = self._test_matrix[0][5]
|
||||
with self.assertRaises(ValueError):
|
||||
dap.DynamicMemoryIndex(
|
||||
metric="sandwich",
|
||||
vector_dtype=np.single,
|
||||
dim=10,
|
||||
max_points=11_000,
|
||||
complexity=64,
|
||||
graph_degree=32,
|
||||
num_threads=16,
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
dap.DynamicMemoryIndex(
|
||||
metric=None,
|
||||
vector_dtype=np.single,
|
||||
dim=10,
|
||||
max_points=11_000,
|
||||
complexity=64,
|
||||
graph_degree=32,
|
||||
num_threads=16,
|
||||
)
|
||||
dap.DynamicMemoryIndex(
|
||||
metric="l2",
|
||||
vector_dtype=np.single,
|
||||
dim=10,
|
||||
max_points=11_000,
|
||||
complexity=64,
|
||||
graph_degree=32,
|
||||
num_threads=16,
|
||||
)
|
||||
dap.DynamicMemoryIndex(
|
||||
metric="mips",
|
||||
vector_dtype=np.single,
|
||||
dim=10,
|
||||
max_points=11_000,
|
||||
complexity=64,
|
||||
graph_degree=32,
|
||||
num_threads=16,
|
||||
)
|
||||
dap.DynamicMemoryIndex(
|
||||
metric="MiPs",
|
||||
vector_dtype=np.single,
|
||||
dim=10,
|
||||
max_points=11_000,
|
||||
complexity=64,
|
||||
graph_degree=32,
|
||||
num_threads=16,
|
||||
)
|
||||
|
||||
def test_valid_vector_dtype(self):
|
||||
aliases = {np.single: np.float32, np.byte: np.int8, np.ubyte: np.uint8}
|
||||
for (
|
||||
metric,
|
||||
dtype,
|
||||
query_vectors,
|
||||
index_vectors,
|
||||
ann_dir,
|
||||
vector_bin_file,
|
||||
generated_tags
|
||||
) in self._test_matrix:
|
||||
with self.subTest():
|
||||
index = dap.DynamicMemoryIndex(
|
||||
metric="l2",
|
||||
vector_dtype=aliases[dtype],
|
||||
dim=10,
|
||||
max_points=11_000,
|
||||
complexity=64,
|
||||
graph_degree=32,
|
||||
num_threads=16,
|
||||
)
|
||||
|
||||
invalid = [np.double, np.float64, np.ulonglong, np.float16]
|
||||
for invalid_vector_dtype in invalid:
|
||||
with self.subTest():
|
||||
with self.assertRaises(ValueError):
|
||||
dap.DynamicMemoryIndex(
|
||||
metric="l2",
|
||||
vector_dtype=invalid_vector_dtype,
|
||||
dim=10,
|
||||
max_points=11_000,
|
||||
complexity=64,
|
||||
graph_degree=32,
|
||||
num_threads=16,
|
||||
)
|
||||
|
||||
def test_value_ranges_ctor(self):
|
||||
(
|
||||
metric,
|
||||
dtype,
|
||||
query_vectors,
|
||||
index_vectors,
|
||||
ann_dir,
|
||||
vector_bin_file,
|
||||
generated_tags
|
||||
) = build_random_vectors_and_memory_index(np.single, "l2", with_tags=True, index_prefix="not_ann")
|
||||
good_ranges = {
|
||||
"metric": "l2",
|
||||
"vector_dtype": np.single,
|
||||
"dim": 10,
|
||||
"max_points": 11_000,
|
||||
"complexity": 64,
|
||||
"graph_degree": 32,
|
||||
"max_occlusion_size": 10,
|
||||
"alpha": 1.2,
|
||||
"num_threads": 16,
|
||||
"filter_complexity": 10,
|
||||
"num_frozen_points": 10,
|
||||
"initial_search_complexity": 32,
|
||||
"search_threads": 0
|
||||
}
|
||||
|
||||
bad_ranges = {
|
||||
"metric": "l200000",
|
||||
"vector_dtype": np.double,
|
||||
"dim": -1,
|
||||
"max_points": -1,
|
||||
"complexity": 0,
|
||||
"graph_degree": 0,
|
||||
"max_occlusion_size": -1,
|
||||
"alpha": -1,
|
||||
"num_threads": -1,
|
||||
"filter_complexity": -1,
|
||||
"num_frozen_points": -1,
|
||||
"initial_search_complexity": -1,
|
||||
"search_threads": -1,
|
||||
}
|
||||
for bad_value_key in good_ranges.keys():
|
||||
kwargs = good_ranges.copy()
|
||||
kwargs[bad_value_key] = bad_ranges[bad_value_key]
|
||||
with self.subTest():
|
||||
with self.assertRaises(ValueError, msg=f"expected to fail with parameter {bad_value_key}={bad_ranges[bad_value_key]}"):
|
||||
index = dap.DynamicMemoryIndex(saturate_graph=False, **kwargs)
|
||||
|
||||
def test_value_ranges_search(self):
|
||||
good_ranges = {"complexity": 5, "k_neighbors": 10}
|
||||
bad_ranges = {"complexity": -1, "k_neighbors": 0}
|
||||
vector_bin_file = self._test_matrix[0][5]
|
||||
for bad_value_key in good_ranges.keys():
|
||||
kwargs = good_ranges.copy()
|
||||
kwargs[bad_value_key] = bad_ranges[bad_value_key]
|
||||
with self.subTest():
|
||||
with self.assertRaises(ValueError):
|
||||
index = dap.StaticMemoryIndex(
|
||||
metric="l2",
|
||||
vector_dtype=np.single,
|
||||
data_path=vector_bin_file,
|
||||
index_directory=self._example_ann_dir,
|
||||
num_threads=16,
|
||||
initial_search_complexity=32,
|
||||
)
|
||||
index.search(query=np.array([], dtype=np.single), **kwargs)
|
||||
|
||||
def test_value_ranges_batch_search(self):
|
||||
good_ranges = {
|
||||
"complexity": 5,
|
||||
"k_neighbors": 10,
|
||||
"num_threads": 5,
|
||||
}
|
||||
bad_ranges = {
|
||||
"complexity": 0,
|
||||
"k_neighbors": 0,
|
||||
"num_threads": -1,
|
||||
}
|
||||
vector_bin_file = self._test_matrix[0][5]
|
||||
for bad_value_key in good_ranges.keys():
|
||||
kwargs = good_ranges.copy()
|
||||
kwargs[bad_value_key] = bad_ranges[bad_value_key]
|
||||
with self.subTest():
|
||||
with self.assertRaises(ValueError):
|
||||
index = dap.StaticMemoryIndex(
|
||||
metric="l2",
|
||||
vector_dtype=np.single,
|
||||
data_path=vector_bin_file,
|
||||
index_directory=self._example_ann_dir,
|
||||
num_threads=16,
|
||||
initial_search_complexity=32,
|
||||
)
|
||||
index.batch_search(
|
||||
queries=np.array([[]], dtype=np.single), **kwargs
|
||||
)
|
|
@ -0,0 +1,260 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import unittest
|
||||
|
||||
import diskannpy as dap
|
||||
import numpy as np
|
||||
from fixtures import build_random_vectors_and_memory_index, calculate_recall
|
||||
from sklearn.neighbors import NearestNeighbors
|
||||
|
||||
|
||||
class TestStaticMemoryIndex(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
cls._test_matrix = [
|
||||
build_random_vectors_and_memory_index(np.single, "l2"),
|
||||
build_random_vectors_and_memory_index(np.ubyte, "l2"),
|
||||
build_random_vectors_and_memory_index(np.byte, "l2"),
|
||||
]
|
||||
cls._example_ann_dir = cls._test_matrix[0][4]
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
for test in cls._test_matrix:
|
||||
try:
|
||||
ann_dir = test[4]
|
||||
shutil.rmtree(ann_dir, ignore_errors=True)
|
||||
except:
|
||||
pass
|
||||
|
||||
def test_recall_and_batch(self):
|
||||
for (
|
||||
metric,
|
||||
dtype,
|
||||
query_vectors,
|
||||
index_vectors,
|
||||
ann_dir,
|
||||
vector_bin_file,
|
||||
_
|
||||
) in self._test_matrix:
|
||||
with self.subTest():
|
||||
index = dap.StaticMemoryIndex(
|
||||
metric="l2",
|
||||
vector_dtype=dtype,
|
||||
data_path=os.path.join(ann_dir, "vectors.bin"),
|
||||
index_directory=ann_dir,
|
||||
num_threads=16,
|
||||
initial_search_complexity=32,
|
||||
)
|
||||
|
||||
k = 5
|
||||
diskann_neighbors, diskann_distances = index.batch_search(
|
||||
query_vectors,
|
||||
k_neighbors=k,
|
||||
complexity=5,
|
||||
num_threads=16,
|
||||
)
|
||||
if metric == "l2":
|
||||
knn = NearestNeighbors(
|
||||
n_neighbors=100, algorithm="auto", metric="l2"
|
||||
)
|
||||
knn.fit(index_vectors)
|
||||
knn_distances, knn_indices = knn.kneighbors(query_vectors)
|
||||
recall = calculate_recall(diskann_neighbors, knn_indices, k)
|
||||
self.assertTrue(
|
||||
recall > 0.70,
|
||||
f"Recall [{recall}] was not over 0.7",
|
||||
)
|
||||
|
||||
def test_single(self):
|
||||
for (
|
||||
metric,
|
||||
dtype,
|
||||
query_vectors,
|
||||
index_vectors,
|
||||
ann_dir,
|
||||
vector_bin_file,
|
||||
_
|
||||
) in self._test_matrix:
|
||||
with self.subTest():
|
||||
index = dap.StaticMemoryIndex(
|
||||
metric="l2",
|
||||
vector_dtype=dtype,
|
||||
data_path=vector_bin_file,
|
||||
index_directory=ann_dir,
|
||||
num_threads=16,
|
||||
initial_search_complexity=32,
|
||||
)
|
||||
|
||||
k = 5
|
||||
ids, dists = index.search(query_vectors[0], k_neighbors=k, complexity=5)
|
||||
self.assertEqual(ids.shape[0], k)
|
||||
self.assertEqual(dists.shape[0], k)
|
||||
|
||||
def test_valid_metric(self):
|
||||
ann_dir = self._example_ann_dir
|
||||
vector_bin_file = self._test_matrix[0][5]
|
||||
with self.assertRaises(ValueError):
|
||||
dap.StaticMemoryIndex(
|
||||
metric="sandwich",
|
||||
vector_dtype=np.single,
|
||||
data_path=vector_bin_file,
|
||||
index_directory=ann_dir,
|
||||
num_threads=16,
|
||||
initial_search_complexity=32,
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
dap.StaticMemoryIndex(
|
||||
metric=None,
|
||||
vector_dtype=np.single,
|
||||
data_path=vector_bin_file,
|
||||
index_directory=ann_dir,
|
||||
num_threads=16,
|
||||
initial_search_complexity=32,
|
||||
)
|
||||
dap.StaticMemoryIndex(
|
||||
metric="l2",
|
||||
vector_dtype=np.single,
|
||||
data_path=vector_bin_file,
|
||||
index_directory=ann_dir,
|
||||
num_threads=16,
|
||||
initial_search_complexity=32,
|
||||
)
|
||||
dap.StaticMemoryIndex(
|
||||
metric="mips",
|
||||
vector_dtype=np.single,
|
||||
data_path=vector_bin_file,
|
||||
index_directory=ann_dir,
|
||||
num_threads=16,
|
||||
initial_search_complexity=32,
|
||||
)
|
||||
dap.StaticMemoryIndex(
|
||||
metric="MiPs",
|
||||
vector_dtype=np.single,
|
||||
data_path=vector_bin_file,
|
||||
index_directory=ann_dir,
|
||||
num_threads=16,
|
||||
initial_search_complexity=32,
|
||||
)
|
||||
|
||||
def test_valid_vector_dtype(self):
|
||||
aliases = {np.single: np.float32, np.byte: np.int8, np.ubyte: np.uint8}
|
||||
for (
|
||||
metric,
|
||||
dtype,
|
||||
query_vectors,
|
||||
index_vectors,
|
||||
ann_dir,
|
||||
vector_bin_file,
|
||||
_
|
||||
) in self._test_matrix:
|
||||
with self.subTest():
|
||||
index = dap.StaticMemoryIndex(
|
||||
metric="l2",
|
||||
vector_dtype=aliases[dtype],
|
||||
data_path=vector_bin_file,
|
||||
index_directory=ann_dir,
|
||||
num_threads=16,
|
||||
initial_search_complexity=32,
|
||||
)
|
||||
ann_dir = self._example_ann_dir
|
||||
vector_bin_file = self._test_matrix[0][5]
|
||||
invalid = [np.double, np.float64, np.ulonglong, np.float16]
|
||||
for invalid_vector_dtype in invalid:
|
||||
with self.subTest():
|
||||
with self.assertRaises(ValueError):
|
||||
dap.StaticMemoryIndex(
|
||||
metric="l2",
|
||||
vector_dtype=invalid_vector_dtype,
|
||||
data_path=vector_bin_file,
|
||||
index_directory=ann_dir,
|
||||
num_threads=16,
|
||||
initial_search_complexity=32,
|
||||
)
|
||||
|
||||
def test_value_ranges_ctor(self):
|
||||
(
|
||||
metric,
|
||||
dtype,
|
||||
query_vectors,
|
||||
index_vectors,
|
||||
ann_dir,
|
||||
vector_bin_file,
|
||||
_
|
||||
) = build_random_vectors_and_memory_index(np.single, "l2", "not_ann")
|
||||
good_ranges = {
|
||||
"metric": "l2",
|
||||
"vector_dtype": np.single,
|
||||
"data_path": vector_bin_file,
|
||||
"index_directory": ann_dir,
|
||||
"num_threads": 16,
|
||||
"initial_search_complexity": 32,
|
||||
"index_prefix": "not_ann",
|
||||
}
|
||||
|
||||
bad_ranges = {
|
||||
"metric": "l200000",
|
||||
"vector_dtype": np.double,
|
||||
"data_path": "I do not exist.bin",
|
||||
"index_directory": "sandwiches",
|
||||
"num_threads": -100,
|
||||
"initial_search_complexity": 0,
|
||||
"index_prefix": "",
|
||||
}
|
||||
for bad_value_key in good_ranges.keys():
|
||||
kwargs = good_ranges.copy()
|
||||
kwargs[bad_value_key] = bad_ranges[bad_value_key]
|
||||
with self.subTest():
|
||||
with self.assertRaises(ValueError):
|
||||
index = dap.StaticMemoryIndex(**kwargs)
|
||||
|
||||
def test_value_ranges_search(self):
|
||||
good_ranges = {"complexity": 5, "k_neighbors": 10}
|
||||
bad_ranges = {"complexity": -1, "k_neighbors": 0}
|
||||
vector_bin_file = self._test_matrix[0][5]
|
||||
for bad_value_key in good_ranges.keys():
|
||||
kwargs = good_ranges.copy()
|
||||
kwargs[bad_value_key] = bad_ranges[bad_value_key]
|
||||
with self.subTest():
|
||||
with self.assertRaises(ValueError):
|
||||
index = dap.StaticMemoryIndex(
|
||||
metric="l2",
|
||||
vector_dtype=np.single,
|
||||
data_path=vector_bin_file,
|
||||
index_directory=self._example_ann_dir,
|
||||
num_threads=16,
|
||||
initial_search_complexity=32,
|
||||
)
|
||||
index.search(query=np.array([], dtype=np.single), **kwargs)
|
||||
|
||||
def test_value_ranges_batch_search(self):
|
||||
good_ranges = {
|
||||
"complexity": 5,
|
||||
"k_neighbors": 10,
|
||||
"num_threads": 5,
|
||||
}
|
||||
bad_ranges = {
|
||||
"complexity": 0,
|
||||
"k_neighbors": 0,
|
||||
"num_threads": -1,
|
||||
}
|
||||
vector_bin_file = self._test_matrix[0][5]
|
||||
for bad_value_key in good_ranges.keys():
|
||||
kwargs = good_ranges.copy()
|
||||
kwargs[bad_value_key] = bad_ranges[bad_value_key]
|
||||
with self.subTest():
|
||||
with self.assertRaises(ValueError):
|
||||
index = dap.StaticMemoryIndex(
|
||||
metric="l2",
|
||||
vector_dtype=np.single,
|
||||
data_path=vector_bin_file,
|
||||
index_directory=self._example_ann_dir,
|
||||
num_threads=16,
|
||||
initial_search_complexity=32,
|
||||
)
|
||||
index.batch_search(
|
||||
queries=np.array([[]], dtype=np.single), **kwargs
|
||||
)
|
|
@ -1344,7 +1344,7 @@ void Index<T, TagT, LabelT>::inter_insert(uint32_t n, std::vector<uint32_t> &pru
|
|||
}
|
||||
|
||||
template <typename T, typename TagT, typename LabelT>
|
||||
void Index<T, TagT, LabelT>::link(IndexWriteParameters ¶meters)
|
||||
void Index<T, TagT, LabelT>::link(const IndexWriteParameters ¶meters)
|
||||
{
|
||||
uint32_t num_threads = parameters.num_threads;
|
||||
if (num_threads != 0)
|
||||
|
@ -1577,7 +1577,8 @@ void Index<T, TagT, LabelT>::set_start_points_at_random(T radius, uint32_t rando
|
|||
}
|
||||
|
||||
template <typename T, typename TagT, typename LabelT>
|
||||
void Index<T, TagT, LabelT>::build_with_data_populated(IndexWriteParameters ¶meters, const std::vector<TagT> &tags)
|
||||
void Index<T, TagT, LabelT>::build_with_data_populated(const IndexWriteParameters ¶meters,
|
||||
const std::vector<TagT> &tags)
|
||||
{
|
||||
diskann::cout << "Starting index build with " << _nd << " points... " << std::endl;
|
||||
|
||||
|
@ -1633,8 +1634,8 @@ void Index<T, TagT, LabelT>::build_with_data_populated(IndexWriteParameters &par
|
|||
}
|
||||
|
||||
template <typename T, typename TagT, typename LabelT>
|
||||
void Index<T, TagT, LabelT>::build(const T *data, const size_t num_points_to_load, IndexWriteParameters ¶meters,
|
||||
const std::vector<TagT> &tags)
|
||||
void Index<T, TagT, LabelT>::build(const T *data, const size_t num_points_to_load,
|
||||
const IndexWriteParameters ¶meters, const std::vector<TagT> &tags)
|
||||
{
|
||||
if (num_points_to_load == 0)
|
||||
{
|
||||
|
@ -1670,7 +1671,7 @@ void Index<T, TagT, LabelT>::build(const T *data, const size_t num_points_to_loa
|
|||
|
||||
template <typename T, typename TagT, typename LabelT>
|
||||
void Index<T, TagT, LabelT>::build(const char *filename, const size_t num_points_to_load,
|
||||
IndexWriteParameters ¶meters, const std::vector<TagT> &tags)
|
||||
const IndexWriteParameters ¶meters, const std::vector<TagT> &tags)
|
||||
{
|
||||
std::unique_lock<std::shared_timed_mutex> ul(_update_lock);
|
||||
if (num_points_to_load == 0)
|
||||
|
@ -1760,7 +1761,7 @@ void Index<T, TagT, LabelT>::build(const char *filename, const size_t num_points
|
|||
|
||||
template <typename T, typename TagT, typename LabelT>
|
||||
void Index<T, TagT, LabelT>::build(const char *filename, const size_t num_points_to_load,
|
||||
IndexWriteParameters ¶meters, const char *tag_filename)
|
||||
const IndexWriteParameters ¶meters, const char *tag_filename)
|
||||
{
|
||||
std::vector<TagT> tags;
|
||||
|
||||
|
|
|
@ -145,7 +145,6 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con
|
|||
diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R)
|
||||
.with_max_occlusion_size(500) // C = 500
|
||||
.with_alpha(alpha)
|
||||
.with_num_rounds(1)
|
||||
.with_num_threads(thread_count)
|
||||
.with_num_frozen_points(num_start_pts)
|
||||
.build();
|
||||
|
@ -367,9 +366,11 @@ int main(int argc, char **argv)
|
|||
desc.add_options()("start_deletes_after", po::value<uint64_t>(&start_deletes_after)->default_value(0), "");
|
||||
desc.add_options()("start_point_norm", po::value<float>(&start_point_norm)->default_value(0),
|
||||
"Set the start point to a random point on a sphere of this radius");
|
||||
desc.add_options()("num_start_points", po::value<uint32_t>(&num_start_pts)->default_value(0),
|
||||
"Set the number of random start (frozen) points to use when "
|
||||
"inserting and searching");
|
||||
desc.add_options()(
|
||||
"num_start_points",
|
||||
po::value<uint32_t>(&num_start_pts)->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC),
|
||||
"Set the number of random start (frozen) points to use when "
|
||||
"inserting and searching");
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
|
|
|
@ -177,7 +177,6 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con
|
|||
.with_max_occlusion_size(C)
|
||||
.with_alpha(alpha)
|
||||
.with_saturate_graph(saturate_graph)
|
||||
.with_num_rounds(1)
|
||||
.with_num_threads(insert_threads)
|
||||
.with_num_frozen_points(num_start_pts)
|
||||
.build();
|
||||
|
@ -186,7 +185,6 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con
|
|||
.with_max_occlusion_size(C)
|
||||
.with_alpha(alpha)
|
||||
.with_saturate_graph(saturate_graph)
|
||||
.with_num_rounds(1)
|
||||
.with_num_threads(consolidate_threads)
|
||||
.build();
|
||||
|
||||
|
@ -320,9 +318,11 @@ int main(int argc, char **argv)
|
|||
"the window while deleting the same number from the left");
|
||||
desc.add_options()("start_point_norm", po::value<float>(&start_point_norm)->required(),
|
||||
"Set the start point to a random point on a sphere of this radius");
|
||||
desc.add_options()("num_start_points", po::value<uint32_t>(&num_start_pts)->default_value(0),
|
||||
"Set the number of random start (frozen) points to use when "
|
||||
"inserting and searching");
|
||||
desc.add_options()(
|
||||
"num_start_points",
|
||||
po::value<uint32_t>(&num_start_pts)->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC),
|
||||
"Set the number of random start (frozen) points to use when "
|
||||
"inserting and searching");
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
|
|
Загрузка…
Ссылка в новой задаче