diskann code
|
@ -0,0 +1,14 @@
|
|||
# Set the default behavior, in case people don't have core.autocrlf set.
|
||||
* text=auto
|
||||
|
||||
# Explicitly declare text files you want to always be normalized and converted
|
||||
# to native line endings on checkout.
|
||||
*.c text
|
||||
*.h text
|
||||
|
||||
# Declare files that will always have CRLF line endings on checkout.
|
||||
*.sln text eol=crlf
|
||||
|
||||
# Denote all files that are truly binary and should not be modified.
|
||||
*.png binary
|
||||
*.jpg binary
|
|
@ -0,0 +1,10 @@
|
|||
build/*
|
||||
/vcproj/nsg/x64/Debug/nsg.Build.CppClean.log
|
||||
/vcproj/test_recall/x64/Debug/test_recall.Build.CppClean.log
|
||||
/vcproj/test_recall/test_recall.vcxproj.user
|
||||
/.vs
|
||||
/data/SIFT1M
|
||||
/out/build/x64-Debug
|
||||
/src/#index_nsg.cpp#
|
||||
cscope*
|
||||
|
|
@ -1,17 +1,102 @@
|
|||
cmake_minimum_required(VERSION 2.8)
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
if(MSVC)
|
||||
cmake_minimum_required(VERSION 3.15)
|
||||
else()
|
||||
cmake_minimum_required(VERSION 2.8)
|
||||
endif()
|
||||
set(CMAKE_STANDARD 14)
|
||||
|
||||
|
||||
#set(CMAKE_USER_MAKE_RULES_OVERRIDE "${CMAKE_CURRENT_LIST_DIR}/CompilerOptions.cmake")
|
||||
if(MSVC)
|
||||
set(CMAKE_CXX_COMPILER $ENV{VCToolsInstallDir}/bin/Hostx64/x64/cl.exe)
|
||||
set(CMAKE_CXX_LINK_EXECUTABLE $ENV{VCToolsInstallDir}/bin/Hostx64/x64/link.exe)
|
||||
else()
|
||||
set(CMAKE_CXX_COMPILER g++)
|
||||
endif()
|
||||
|
||||
project(diskann)
|
||||
include_directories(${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/include/dll)
|
||||
|
||||
project(efanna2e)
|
||||
include_directories(${PROJECT_SOURCE_DIR}/include)
|
||||
#OpenMP
|
||||
find_package(OpenMP)
|
||||
if (OPENMP_FOUND)
|
||||
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
|
||||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
|
||||
else()
|
||||
message(FATAL_ERROR "no OpenMP supprot")
|
||||
message(FATAL_ERROR "No OpenMP support")
|
||||
endif()
|
||||
|
||||
add_definitions (-std=c++11 -O3 -lboost -march=native -Wall -DINFO)
|
||||
|
||||
function(checkEnvAndSetLocalVar env_var msg local_var)
|
||||
if (NOT EXISTS "$ENV{${env_var}}" )
|
||||
message (FATAL_ERROR ${msg})
|
||||
else()
|
||||
if ($ENV{${env_var}} MATCHES "\\$" OR $ENV{${env_var}} MATCHES "/$" )
|
||||
set(${local_var} $ENV{${env_var}} PARENT_SCOPE)
|
||||
else()
|
||||
message(STATUS "Appending trailing backslash to ${env_var}")
|
||||
set(${local_var} "$ENV{${env_var}}\\" PARENT_SCOPE)
|
||||
endif()
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
|
||||
|
||||
#MKL Config
|
||||
if (MSVC)
|
||||
checkEnvAndSetLocalVar("INTEL_ROOT" "Please install Intel MKL libraries and set the env variable INTEL_ROOT to the intel software directory. Should be similar to: C:\\Program Files (x86)\\IntelSWTools\\compilers_and_libraries\\windows\\. " "INTEL_ROOT")
|
||||
set(MKL_ROOT ${INTEL_ROOT}/mkl)
|
||||
add_compile_options(/arch:AVX2 /Qpar)
|
||||
link_libraries("${INTEL_ROOT}/mkl/lib/intel64/mkl_core_dll.lib" "${INTEL_ROOT}/mkl/lib/intel64/mkl_rt.lib" "${INTEL_ROOT}/mkl/lib/intel64/mkl_intel_thread_dll.lib" "${INTEL_ROOT}/compiler/lib/intel64/libiomp5md.lib" "${INTEL_ROOT}/mkl/lib/intel64/mkl_intel_ilp64_dll.lib" "${INTEL_ROOT}/mkl/lib/intel64/mkl_sequential_dll.lib")
|
||||
|
||||
else()
|
||||
set(INTEL_ROOT /opt/intel/compilers_and_libraries/linux)
|
||||
set(MKL_ROOT ${INTEL_ROOT}/mkl)
|
||||
add_compile_options(-m64 -Wl,--no-as-needed)
|
||||
link_libraries(mkl_intel_ilp64 mkl_intel_thread mkl_core iomp5 pthread m dl)
|
||||
link_directories(${INTEL_ROOT}/lib/intel64 ${MKL_ROOT}/lib/intel64)
|
||||
endif()
|
||||
|
||||
add_definitions(-DMKL_ILP64)
|
||||
include_directories(include ${INTEL_ROOT}/include ${MKL_ROOT}/include)
|
||||
|
||||
|
||||
#Main compiler/linker settings
|
||||
if(MSVC)
|
||||
#language options
|
||||
add_compile_options(/permissive- /openmp:experimental /Zc:wchar_t /Zc:twoPhase- /Zc:forScope /Zc:inline /WX- /std:c++14 /Gd /W3 /MP /Zi /FC /nologo /diagnostics:classic)
|
||||
#code generation options
|
||||
add_compile_options(/Qpar /fp:fast /Zp8 /fp:except- /EHsc /GS- /Gm- /Gy )
|
||||
#optimization options
|
||||
add_compile_options(/Ot /Oy /Oi)
|
||||
#path options
|
||||
#add_compile_options(/Fdx64/Release/vc141.pdb /Fox64/Release/)
|
||||
add_definitions(-DUSE_AVX2 -DUSE_ACCELERATED_PQ -D_WINDOWS -DNOMINMAX -DUNICODE)
|
||||
|
||||
set(CMAKE_SHARED_LIBRARY_CXX_LINK_FLAGS "/MANIFEST /MACHINE:X64 /DEBUG:FULL /LTCG:incremental /NXCOMPAT /DYNAMICBASE /OPT:REF /SUBSYSTEM:CONSOLE /MANIFESTUAC:\"level='asInvoker' uiAccess='false'\"")
|
||||
set(CMAKE_EXECUTABLE_CXX_LINK_FLAGS "/MANIFEST /MACHINE:X64 /DEBUG:FULL /LTCG:incremental /NXCOMPAT /DYNAMICBASE /OPT:REF /SUBSYSTEM:CONSOLE /MANIFESTUAC:\"level='asInvoker' uiAccess='false'\"")
|
||||
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /D_DEBUG")
|
||||
set(CMAKE_SHARED_LIBRARY_CXX_LINK_FLAGS_DEBUG "${CMAKE_SHARED_LIBRARY_CXX_LINK_FLAGS_DEBUG} /DEBUG")
|
||||
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/x64/Debug)
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/x64/Debug)
|
||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/x64/Debug)
|
||||
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/x64/Release)
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/x64/Release)
|
||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/x64/Release)
|
||||
else()
|
||||
set(ENV{TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD} 500000000000)
|
||||
# set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DDEBUG -O0 -fsanitize=address -fsanitize=leak -fsanitize=undefined")
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DDEBUG")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -Ofast -DNDEBUG -march=native -mtune=native -ftree-vectorize")
|
||||
add_compile_options(-march=native -Wall -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors -DUSE_ACCELERATED_PQ -DUSE_AVX2)
|
||||
endif()
|
||||
|
||||
add_subdirectory(src)
|
||||
add_subdirectory(tests)
|
||||
add_subdirectory(tests/utils)
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
# Microsoft Open Source Code of Conduct
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
|
||||
Resources:
|
||||
|
||||
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
|
||||
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
||||
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
|
|
@ -0,0 +1,9 @@
|
|||
# Contributing
|
||||
|
||||
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
||||
|
||||
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
|
||||
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
|
||||
provided by the bot. You will only need to do this once across all repos using our CLA.
|
|
@ -0,0 +1,21 @@
|
|||
if(MSVC)
|
||||
#changing default target to X64
|
||||
string(REGEX REPLACE "/[M|m][A|a][C|c][H|h][I|i][N|n][E|e]:[X|x]86" "/MACHINE:X64" CMAKE_EXE_LINKER_FLAGS_INIT "${CMAKE_EXE_LINKER_FLAGS_INIT}")
|
||||
string(REGEX REPLACE "/[M|m][A|a][C|c][H|h][I|i][N|n][E|e]:[X|x]86" "/MACHINE:X64" CMAKE_MODULE_LINKER_FLAGS_INIT "${CMAKE_MODULE_LINKER_FLAGS_INIT}")
|
||||
string(REGEX REPLACE "/[M|m][A|a][C|c][H|h][I|i][N|n][E|e]:[X|x]86" "/MACHINE:X64" CMAKE_SHARED_LINKER_FLAGS_INIT "${CMAKE_SHARED_LINKER_FLAGS_INIT}")
|
||||
string(REGEX REPLACE "/[M|m][A|a][C|c][H|h][I|i][N|n][E|e]:[X|x]86" "/MACHINE:X64" CMAKE_STATIC_LINKER_FLAGS_INIT "${CMAKE_STATIC_LINKER_FLAGS_INIT}")
|
||||
string(REGEX REPLACE "/[M|m][A|a][C|c][H|h][I|i][N|n][E|e]:[X|x]86" "/MACHINE:X64" CMAKE_EXE_LINKER_FLAGS_INIT "${CMAKE_EXE_LINKER_FLAGS_INIT}")
|
||||
string(REGEX REPLACE "/[M|m][A|a][C|c][H|h][I|i][N|n][E|e]:[X|x]86" "/MACHINE:X64" CMAKE_MODULE_LINKER_FLAGS_INIT "${CMAKE_MODULE_LINKER_FLAGS_INIT}")
|
||||
string(REGEX REPLACE "/[M|m][A|a][C|c][H|h][I|i][N|n][E|e]:[X|x]86" "/MACHINE:X64" CMAKE_SHARED_LINKER_FLAGS_INIT "${CMAKE_SHARED_LINKER_FLAGS_INIT}")
|
||||
string(REGEX REPLACE "/[M|m][A|a][C|c][H|h][I|i][N|n][E|e]:[X|x]86" "/MACHINE:X64" CMAKE_STATIC_LINKER_FLAGS_INIT "${CMAKE_STATIC_LINKER_FLAGS_INIT}")
|
||||
string(REGEX REPLACE "Debug" "Release" CMAKE_BUILD_TYPE_INIT "${CMAKE_BUILD_TYPE_INIT}")
|
||||
endif()
|
||||
|
||||
|
||||
get_cmake_property(_varNames VARIABLES)
|
||||
list (REMOVE_DUPLICATES _varNames)
|
||||
list (SORT _varNames)
|
||||
foreach (_varName ${_varNames})
|
||||
message(STATUS "${_varName}=${${_varName}}")
|
||||
endforeach()
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
FROM ubuntu:16.04
|
||||
MAINTAINER Changxu Wang <wang_changxu@zju.edu.cn>
|
||||
|
||||
RUN apt-get update -y
|
||||
RUN apt-get install -y g++ cmake libboost-dev libgoogle-perftools-dev
|
||||
|
||||
COPY . /opt/nsg
|
||||
|
||||
WORKDIR /opt/nsg
|
||||
|
||||
RUN mkdir -p build && cd build && \
|
||||
cmake -DCMAKE_BUILD_TYPE=Release .. && \
|
||||
make -j $(nproc)
|
|
@ -0,0 +1,23 @@
|
|||
DiskANN
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) Microsoft Corporation.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE
|
|
@ -0,0 +1,23 @@
|
|||
This algorithms builds upon [code for NSG](https://github.com/ZJULearning/nsg), commit: 335e8e, licensed under the following terms.
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2018 Cong Fu, Changxu Wang, Deng Cai
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
215
README.md
|
@ -1,162 +1,117 @@
|
|||
NSG : Navigating Spread-out Graph For Approximate Nearest Neighbor Search
|
||||
======
|
||||
NSG is a graph-based approximate nearest neighbor search (ANNS) algorithm. It provides a flexible and efficient solution for the metric-free large-scale ANNS on dense real vectors. It implements the algorithm of our paper, [Fast Approximate Nearest Neighbor Search With Navigating Spread-out Graphs.](https://arxiv.org/abs/1707.00143)
|
||||
NSG has been intergrated into the search engine of Taobao (Alibaba Group) for billion scale ANNS in E-commerce scenario.
|
||||
# DiskANN
|
||||
|
||||
Benchmark data set
|
||||
------
|
||||
* [SIFT1M and GIST1M](http://corpus-texmex.irisa.fr/)
|
||||
* Synthetic data set: RAND4M and GAUSS5M
|
||||
* RAND4M: 4 million 128-dimension vectors sampled from a uniform distribution of [-1, 1].
|
||||
* GAUSS5M: 5 million 128-dimension vectors sampled from a gaussion ditribution N(0,3).
|
||||
The goal of the project is to build scalable, performant and cost-effective approximate nearest neighbor search algorithms.
|
||||
The initial release has the in-memory version of the [DiskANN paper](https://papers.nips.cc/paper/9527-rand-nsg-fast-accurate-billion-point-nearest-neighbor-search-on-a-single-node.pdf) published in NeurIPS 2019. The SSD based index will be released later.
|
||||
This code reuses and builds upon some of the [code for NSG](https://github.com/ZJULearning/nsg) algoritm.
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
||||
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
||||
|
||||
See [guidelines](CONTRIBUTING.md) for contributing to this project.
|
||||
|
||||
|
||||
ANNS performance
|
||||
------
|
||||
|
||||
**Compared Algorithms:**
|
||||
##Linux build:
|
||||
|
||||
Graph-based ANNS algorithms:
|
||||
* [kGraph](http://www.kgraph.org)
|
||||
* [FANNG](https://pdfs.semanticscholar.org/9ea6/5687a21c869fce7ecf17ca25ffcadbf77d69.pdf) : FANNG: Fast Approximate Nearest Neighbour Graphs
|
||||
* [HNSW:code](https://github.com/searchivarius/nmslib), [paper](https://arxiv.org/abs/1603.09320) : Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs
|
||||
* [DPG:code](https://github.com/DBWangGroupUNSW/nns_benchmark), [paper](https://arxiv.org/abs/1610.02455) : Approximate Nearest Neighbor Search on High Dimensional Data --- Experiments, Analyses, and Improvement (v1.0)
|
||||
* [Efanna:code](https://github.com/fc731097343/efanna), [paper](https://arxiv.org/abs/1609.07228) : EFANNA : An Extremely Fast Approximate Nearest Neighbor Search Algorithm Based on kNN Graph
|
||||
* NSG-naive: a designed based-line, please refer to [our paper](https://arxiv.org/abs/1707.00143).
|
||||
* NSG: This project, please refer to [our paper](https://arxiv.org/abs/1707.00143).
|
||||
Install the following packages through apt-get, and Intel MKL either by downloading the installer or using [apt](https://software.intel.com/en-us/articles/installing-intel-free-libs-and-python-apt-repo) (we tested with build 2019.4-070).
|
||||
```
|
||||
sudo apt install cmake g++ libaio-dev libgoogle-perftools-dev clang-format-4.0
|
||||
```
|
||||
|
||||
Other popular ANNS algorithms
|
||||
* [flann](http://www.cs.ubc.ca/research/flann/)
|
||||
* [FALCONN](https://github.com/FALCONN-LIB/FALCONN)
|
||||
* [Annoy](https://github.com/spotify/annoy)
|
||||
* [Faiss](https://github.com/facebookresearch/faiss)
|
||||
Build
|
||||
```
|
||||
mkdir build && cd build && cmake .. && make -j
|
||||
```
|
||||
|
||||
The performance was tested without parallelism.
|
||||
NSG achieved the **best** search performance among all the compared algorithms on all the four datasets.
|
||||
Among all the ***graph-based algorithms***, NSG has ***the smallest index size*** and ***the best search performance***.
|
||||
##Windows build:
|
||||
|
||||
The Windows version has been tested with the Enterprise editions of Visual Studio 2017 and Visual Studio 2019. It should work with the Community and Professional editions as well without any changes.
|
||||
|
||||
**Prerequisites:**
|
||||
|
||||
* Install CMAKE (v3.15.2 or later) from https://cmake.org
|
||||
* Install MKL from https://software.intel.com/en-us/mkl
|
||||
|
||||
* Environment variables:
|
||||
* Set a new System environment variable, called INTEL_ROOT to the "windows" folder under your MKL installation
|
||||
(For instance, if your install folder is "C:\Program Files (x86)\IntelSWtools", set INTEL_ROOT to "C:\Program Files (x86)\IntelSWtools\compilers_and_libraries\windows")
|
||||
|
||||
**Build steps:**
|
||||
- Open a new command prompt window
|
||||
- Create a "build" directory under diskann
|
||||
- Change to the "build" directory and run
|
||||
```
|
||||
<full-path-to-cmake>\cmake -G "Visual Studio 16 2019" -B. -A x64 ..
|
||||
```
|
||||
OR
|
||||
```
|
||||
<full-path-to-cmake>\cmake -G "Visual Studio 15 2017" -B. -A x64 ..
|
||||
```
|
||||
|
||||
**Note: Since VS comes with its own (older) version of cmake, you have to specify the full path to cmake to ensure that the right version is used.**
|
||||
- This will create a “diskann” solution file in the "build" directory
|
||||
- Open the "diskann" solution and build the “diskann” project.
|
||||
- Then build all the other binaries using the ALL_BUILD project that is part of the solution
|
||||
- Generated binaries are stored in the diskann/x64/Debug or diskann/x64/Release directories.
|
||||
|
||||
To build from command line, change to the "build" directory and use msbuild to first build the "diskann" project. And then build the entire solution, as shown below.
|
||||
```
|
||||
msbuild src\dll\diskann.vcxproj
|
||||
msbuild diskann.sln
|
||||
```
|
||||
Check msbuild docs for additional options including choosing between debug and release builds.
|
||||
|
||||
|
||||
**SIFT1M-100NN-All-Algorithms**
|
||||
##Usage:
|
||||
|
||||
![SIFT1M-100NN-All-Algorithms](figures/siftall.png)
|
||||
|
||||
**SIFT1M-100NN-Graphs-Only**
|
||||
**Usage for in-memory indices**
|
||||
================================
|
||||
|
||||
![SIFT1M-100NN-Graphs-Only](figures/sift_graph.png)
|
||||
To generate index, use the `tests/build_memory_index` program.
|
||||
--------------------------------------------------------------
|
||||
|
||||
**GIST1M-100NN-All-Algorithms**
|
||||
```
|
||||
./tests/build_memory_index [data_type<int8/uint8/float>] [data_file.bin] [output_index_file] [R] [L] [alpha] [num_threads_to_use]
|
||||
```
|
||||
|
||||
![GIST1M-100NN-All-Algorithms](figures/gistall.png)
|
||||
The arguments are as follows:
|
||||
|
||||
**GIST1M-100NN-Graphs-Only**
|
||||
(i) data_type: same as (i) above in building disk index.
|
||||
|
||||
![GIST1M-100NN-Graphs-Only](figures/gist_graph.png)
|
||||
(ii) data_file: same as (ii) above in building disk index, the input data file in .bin format of type int8/uint8/float.
|
||||
|
||||
**RAND4M-100NN-All-Algorithms**
|
||||
(iii) output_index_file: memory index will be saved here.
|
||||
|
||||
![RAND4M-100NN-All-Algorithms](figures/randall.png)
|
||||
(iv) R: max degree of index: larger is typically better, range (50-150). Preferrably ensure that L is at least R.
|
||||
|
||||
**RAND4M-100NN-Graphs-Only**
|
||||
(v) L: candidate_list_size for building index, larger is better (typical range: 75 to 200)
|
||||
|
||||
![RAND4M-100NN-Graphs-Only](figures/rand_graph.png)
|
||||
(vi) alpha: float value which determines how dense our overall graph will be, and diameter will be log of n base alpha (roughly). Typical values are between 1 to 1.5. 1 will yield sparsest graph, 1.5 will yield denser graphs.
|
||||
|
||||
**GAUSS5M-100NN-All-Algorithms**
|
||||
(vii) number of threads to use: indexing uses specified number of threads.
|
||||
|
||||
![GAUSS5M-100NN-All-Algorithms](figures/gaussall.png)
|
||||
|
||||
**GAUSS5M-100NN-Graphs-Only**
|
||||
To search the generated index, use the `tests/search_memory_index` program:
|
||||
---------------------------------------------------------------------------
|
||||
|
||||
![GAUSS5M-100NN-Graphs-Only](figures/gauss_graph.png)
|
||||
```
|
||||
./tests/search_memory_index [index_type<float/int8/uint8>] [data_file.bin] [memory_index_path] [query_file.bin] [truthset.bin (use "null" for none)] [K] [result_output_prefix] [L1] [L2] etc.
|
||||
```
|
||||
|
||||
How to use
|
||||
------
|
||||
1. Compile
|
||||
Prerequisite : openmp, cmake, boost
|
||||
Compile:
|
||||
a) Go to the root directory of faiss, it's under the directory of extern_libraries aside of ours.
|
||||
b)
|
||||
|
||||
$ cd nsg/
|
||||
$ cmake .
|
||||
$ make
|
||||
|
||||
|
||||
2. Usage
|
||||
The main interfaces and classes have its respective test codes under directory tests/
|
||||
Temporarilly several essential functions have been implemented. To use my algorithm, you should first build an index. It takes several steps as below:
|
||||
|
||||
**a) Build a kNN graph**
|
||||
|
||||
You can use [efanna\_graph](https://github.com/ZJULearning/efanna\_graph) to build the kNN graph, or you can build the kNN graph by yourself.
|
||||
|
||||
**b)Convert a kNN graph to a NSG**
|
||||
For example:
|
||||
```
|
||||
$ cd tests/
|
||||
$ ./test_nsg_index data_path nn_graph_path L R save_graph_file
|
||||
```
|
||||
**data\_path** is the path of the origin data.
|
||||
**nn\_graph\_path** is the path of the pre-built kNN graph.
|
||||
**L** controls the quality of the NSG, the larger the better, L > R.
|
||||
**R** controls the index size of the graph, the best R is related to the intrinsic dimension of the dataset.
|
||||
|
||||
**c) Use NSG for search**
|
||||
For example:
|
||||
```
|
||||
$ cd tests/
|
||||
$ ./test_nsg_optimized_search data_path query_path nsg_path search_L search_K result_path
|
||||
```
|
||||
**data\_path** is the path of the origin data.
|
||||
**query\_path** is the path of the query data.
|
||||
**nsg\_path** is the path of the pre-built NSG.
|
||||
**search\_L** controls the quality of the search results, the larger the better but slower. The **search_L** cannot be samller than the **search_K**
|
||||
**search\_K** controls the number of neighbors we want to find.
|
||||
The arguments are as follows:
|
||||
|
||||
For now, we only provide interface for search for only one query at a time, and test the performance with single thread.
|
||||
There is another program in tests folder which is test_nsg_search. The parameters of test_nsg_search are exactly same as test_nsg_optimized_search.
|
||||
test_nsg_search is slower than test_nsg_optimized_search but requires less memory. In the situations memory consumption is extremely important, one can use test_nsg_search instead of test_nsg_optimized_search.
|
||||
(i) data type: same as (i) above in building index.
|
||||
|
||||
Note
|
||||
------
|
||||
**The data\_align()** function we provided is essential for the correctness of our procedure, because we use SIMD instructions for acceleration of numerical computing such as AVX and SSE2.
|
||||
You should use it to ensure your data elements (feature) is aligned with 8 or 16 int or float.
|
||||
For example, if your features are of dimension 70, then it should be extend to dimension 72. And the last 2 dimension should be filled with 0 to ensure the correctness of the distance computing. And this is what data\_align() does.
|
||||
(ii) memory_index_path: enter path of index built (argument (iii) above in building memory index).
|
||||
|
||||
Only data-type int32 and float32 are supported for now.
|
||||
(iii) query_bin: search on these queries, same format as data file (ii) above. The query file must be the same type as specified in (i).
|
||||
|
||||
Input of NSG
|
||||
------
|
||||
Because there is no unified format for input data, users may need to write input function to read your own data. You may imitate the input function in our sample code in the *tests/* directory to load the data.
|
||||
(iv) Truthset file. Must be in the following format: n, the number of queries (4 bytes) followed by d, the number of ground truth elements per query (4 bytes), followed by n*d entries per query representing the d closest IDs per query in integer format, followed by n*d entries representing the corresponding distances (float). Total file size is 8 + 4*n*d + 4*n*d. The groundtruth file, if not available, can be calculated using our program, tests/utils/compute_groundtruth.
|
||||
|
||||
Output of NSG
|
||||
------
|
||||
The output format of the search results follows the same format of the **fvecs** in [SIFT1M](http://corpus-texmex.irisa.fr/)
|
||||
(v) K: search for recall@k, meaning accuracy of retrieving top-k nearest neighbors.
|
||||
|
||||
Parameters to get the index in Fig. 4/5 in [our paper](https://arxiv.org/abs/1707.00143). (We use [efanna_graph](https://github.com/ZJULearning/efanna_graph) to build the kNN graph)
|
||||
------
|
||||
|
||||
$efanna_graph/tests/test_nndescent sift.fvecs sift.50nngraph 50 70 8 10 100
|
||||
$nsg/tests/test_nsg_index sift.fvecs sift.50nngraph 90 40 sift.nsg
|
||||
$efanna_graph/tests/test_nndescent gist.fvecs gist.100nngraph 100 120 10 15 100
|
||||
$nsg/tests/test_nsg_index gist.fvecs gist.100nngraph 150 70 gist.nsg
|
||||
|
||||
For RAND4M and GAUSS5M, we build the kNN graph with Faiss for efficiency.
|
||||
Here, we use nn-descent to build the kNN Graph. If it cannot a good-quality graph (accuracy > 90%), you may turn to other solutions, such as Faiss or Efanna.
|
||||
|
||||
|
||||
$nsg/tests/test_nsg_index rand4m.fvecs rand4m.200nngraph 400 200 rand4m.nsg
|
||||
$nsg/tests/test_nsg_index gauss5m.fvecs gauss5m.200nngraph 500 200 gauss5m.nsg
|
||||
|
||||
|
||||
Performance on Taobao E-commerce data
|
||||
------
|
||||
**Environments:**
|
||||
Xeon E5-2630.
|
||||
**Single thread test:**
|
||||
Dataset: 10,000,000 128-dimension vectors.
|
||||
Latency: 1ms (average) on 10,000 query.
|
||||
**Distributed search test:**
|
||||
Dataset: 45,000,000 128-dimension vectors.
|
||||
Distribute: randomly divide the dataset into 12 subsets and build 12 NSGs. Search in parallel and merge results.
|
||||
Latency: 1ms (average) on 10,000 query.
|
||||
(vi) result output prefix: will search and store the computed results in the files with specified prefix in bin format.
|
||||
|
||||
(vii, viii, ...) various search_list sizes to perform search with. Larger will result in slower latencies, but higher accuracies. Must be atleast the recall@ value in (vi).
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.5 BLOCK -->
|
||||
|
||||
## Security
|
||||
|
||||
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
|
||||
|
||||
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
|
||||
|
||||
## Reporting Security Issues
|
||||
|
||||
**Please do not report security vulnerabilities through public GitHub issues.**
|
||||
|
||||
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
|
||||
|
||||
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
|
||||
|
||||
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
|
||||
|
||||
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
|
||||
|
||||
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
|
||||
* Full paths of source file(s) related to the manifestation of the issue
|
||||
* The location of the affected source code (tag/branch/commit or direct URL)
|
||||
* Any special configuration required to reproduce the issue
|
||||
* Step-by-step instructions to reproduce the issue
|
||||
* Proof-of-concept or exploit code (if possible)
|
||||
* Impact of the issue, including how an attacker might exploit the issue
|
||||
|
||||
This information will help us triage your report more quickly.
|
||||
|
||||
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
|
||||
|
||||
## Preferred Languages
|
||||
|
||||
We prefer all communications to be in English.
|
||||
|
||||
## Policy
|
||||
|
||||
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
|
||||
|
||||
<!-- END MICROSOFT SECURITY.MD BLOCK -->
|
|
@ -0,0 +1,16 @@
|
|||
(1) Building disk index by merging shards currently needs a lot of working disk space. For each shard, we need to store the index; then we merge them all using merge_shards program to get a merged memory index, and then we create a disk index using this merged memory index. We should have a direct program which adds a new shard to the final disk index (and also perhaps parallelize this code to speed it up).
|
||||
|
||||
(2) Add the uint8 distance AVX2.
|
||||
|
||||
(3) Do a thorough walk through of code to detect possible crashes due to corner cases
|
||||
|
||||
(4) Connect up index if disconnected by just linking start points or rebuilding an index on excluded points
|
||||
|
||||
(5) Eliminate parameters all together? It is giving warnings in linux compiler.
|
||||
|
||||
(6) Add indexing capability with cosine similarity
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
---
|
||||
Language: Cpp
|
||||
# BasedOnStyle: Google
|
||||
AccessModifierOffset: -1
|
||||
AlignAfterOpenBracket: Align
|
||||
AlignConsecutiveAssignments: false
|
||||
AlignConsecutiveDeclarations: true
|
||||
AlignEscapedNewlinesLeft: true
|
||||
AlignOperands: true
|
||||
AlignTrailingComments: true
|
||||
AllowAllParametersOfDeclarationOnNextLine: true
|
||||
AllowShortBlocksOnASingleLine: false
|
||||
AllowShortCaseLabelsOnASingleLine: false
|
||||
AllowShortFunctionsOnASingleLine: None
|
||||
AllowShortIfStatementsOnASingleLine: false
|
||||
AllowShortLoopsOnASingleLine: false
|
||||
AlwaysBreakAfterDefinitionReturnType: None
|
||||
AlwaysBreakAfterReturnType: None
|
||||
AlwaysBreakBeforeMultilineStrings: true
|
||||
AlwaysBreakTemplateDeclarations: true
|
||||
BinPackArguments: true
|
||||
BinPackParameters: true
|
||||
BraceWrapping:
|
||||
AfterClass: true
|
||||
AfterControlStatement: false
|
||||
AfterEnum: false
|
||||
AfterFunction: false
|
||||
AfterNamespace: false
|
||||
AfterObjCDeclaration: false
|
||||
AfterStruct: false
|
||||
AfterUnion: false
|
||||
BeforeCatch: false
|
||||
BeforeElse: false
|
||||
IndentBraces: false
|
||||
BreakBeforeBinaryOperators: None
|
||||
BreakBeforeBraces: Attach
|
||||
BreakBeforeTernaryOperators: true
|
||||
BreakConstructorInitializersBeforeComma: false
|
||||
BreakAfterJavaFieldAnnotations: false
|
||||
BreakStringLiterals: true
|
||||
ColumnLimit: 80
|
||||
CommentPragmas: '^ IWYU pragma:'
|
||||
ConstructorInitializerAllOnOneLineOrOnePerLine: false
|
||||
ConstructorInitializerIndentWidth: 4
|
||||
ContinuationIndentWidth: 4
|
||||
Cpp11BracedListStyle: true
|
||||
DerivePointerAlignment: true
|
||||
DisableFormat: false
|
||||
ExperimentalAutoDetectBinPacking: false
|
||||
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
|
||||
IncludeCategories:
|
||||
- Regex: '^<.*\.h>'
|
||||
Priority: 1
|
||||
- Regex: '^<.*'
|
||||
Priority: 2
|
||||
- Regex: '.*'
|
||||
Priority: 3
|
||||
IncludeIsMainRegex: '([-_](test|unittest))?$'
|
||||
IndentCaseLabels: true
|
||||
IndentWidth: 2
|
||||
IndentWrappedFunctionNames: false
|
||||
JavaScriptQuotes: Leave
|
||||
JavaScriptWrapImports: true
|
||||
KeepEmptyLinesAtTheStartOfBlocks: false
|
||||
MacroBlockBegin: ''
|
||||
MacroBlockEnd: ''
|
||||
MaxEmptyLinesToKeep: 1
|
||||
NamespaceIndentation: All
|
||||
ObjCBlockIndentWidth: 2
|
||||
ObjCSpaceAfterProperty: false
|
||||
ObjCSpaceBeforeProtocolList: false
|
||||
PenaltyBreakBeforeFirstCallParameter: 1
|
||||
PenaltyBreakComment: 300
|
||||
PenaltyBreakFirstLessLess: 120
|
||||
PenaltyBreakString: 1000
|
||||
PenaltyExcessCharacter: 1000000
|
||||
PenaltyReturnTypeOnItsOwnLine: 200
|
||||
PointerAlignment: Right
|
||||
ReflowComments: true
|
||||
SortIncludes: false
|
||||
SpaceAfterCStyleCast: true
|
||||
SpaceAfterTemplateKeyword: false
|
||||
SpaceBeforeAssignmentOperators: true
|
||||
SpaceBeforeParens: ControlStatements
|
||||
SpaceInEmptyParentheses: false
|
||||
SpacesBeforeTrailingComments: 2
|
||||
SpacesInAngles: false
|
||||
SpacesInContainerLiterals: true
|
||||
SpacesInCStyleCastParentheses: false
|
||||
SpacesInParentheses: false
|
||||
SpacesInSquareBrackets: false
|
||||
Standard: Cpp11
|
||||
TabWidth: 4
|
||||
UseTab: Never
|
||||
...
|
Двоичные данные
figures/gauss_graph.png
До Ширина: | Высота: | Размер: 241 KiB |
Двоичные данные
figures/gaussall.png
До Ширина: | Высота: | Размер: 277 KiB |
Двоичные данные
figures/gist_graph.png
До Ширина: | Высота: | Размер: 233 KiB |
Двоичные данные
figures/gistall.png
До Ширина: | Высота: | Размер: 301 KiB |
Двоичные данные
figures/rand_graph.png
До Ширина: | Высота: | Размер: 252 KiB |
Двоичные данные
figures/randall.png
До Ширина: | Высота: | Размер: 250 KiB |
Двоичные данные
figures/sift_graph.png
До Ширина: | Высота: | Размер: 218 KiB |
Двоичные данные
figures/siftall.png
До Ширина: | Высота: | Размер: 285 KiB |
|
@ -0,0 +1,103 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
|
||||
#define MAX_IO_DEPTH 128
|
||||
|
||||
#include <vector>
|
||||
#include <atomic>
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <fcntl.h>
|
||||
#include <libaio.h>
|
||||
#include <unistd.h>
|
||||
typedef io_context_t IOContext;
|
||||
#else
|
||||
#include <Windows.h>
|
||||
#include <minwinbase.h>
|
||||
|
||||
#ifndef USE_BING_INFRA
|
||||
typedef struct {
|
||||
HANDLE fhandle = NULL;
|
||||
HANDLE iocp = NULL;
|
||||
std::vector<OVERLAPPED> reqs;
|
||||
} IOContext;
|
||||
#else
|
||||
#include "IDiskPriorityIO.h"
|
||||
#include <atomic>
|
||||
// TODO: Caller code is very callous about copying IOContext objects
|
||||
// all over the place. MUST verify that it won't cause leaks/logical
|
||||
// errors.
|
||||
// Because of such callous copying, we have to use ptr->atomic instead
|
||||
// of atomic, as atomic is not copyable.
|
||||
struct IOContext {
|
||||
enum Status { READ_WAIT = 0, READ_SUCCESS, READ_FAILED, PROCESS_COMPLETE };
|
||||
|
||||
std::shared_ptr<ANNIndex::IDiskPriorityIO> m_pDiskIO = nullptr;
|
||||
std::shared_ptr<std::vector<ANNIndex::AsyncReadRequest>> m_pRequests;
|
||||
std::shared_ptr<std::vector<Status>> m_pRequestsStatus;
|
||||
|
||||
IOContext()
|
||||
: m_pRequestsStatus(new std::vector<Status>()),
|
||||
m_pRequests(new std::vector<ANNIndex::AsyncReadRequest>()) {
|
||||
(*m_pRequestsStatus).reserve(MAX_IO_DEPTH);
|
||||
(*m_pRequests).reserve(MAX_IO_DEPTH);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
#include <malloc.h>
|
||||
#include <cstdio>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include "tsl/robin_map.h"
|
||||
#include "utils.h"
|
||||
|
||||
// NOTE :: all 3 fields must be 512-aligned
|
||||
struct AlignedRead {
|
||||
uint64_t offset; // where to read from
|
||||
uint64_t len; // how much to read
|
||||
void* buf; // where to read into
|
||||
|
||||
AlignedRead() : offset(0), len(0), buf(nullptr) {
|
||||
}
|
||||
|
||||
AlignedRead(uint64_t offset, uint64_t len, void* buf)
|
||||
: offset(offset), len(len), buf(buf) {
|
||||
assert(IS_512_ALIGNED(offset));
|
||||
assert(IS_512_ALIGNED(len));
|
||||
assert(IS_512_ALIGNED(buf));
|
||||
// assert(malloc_usable_size(buf) >= len);
|
||||
}
|
||||
};
|
||||
|
||||
class AlignedFileReader {
|
||||
protected:
|
||||
tsl::robin_map<std::thread::id, IOContext> ctx_map;
|
||||
std::mutex ctx_mut;
|
||||
|
||||
public:
|
||||
// returns the thread-specific context
|
||||
// returns (io_context_t)(-1) if thread is not registered
|
||||
virtual IOContext& get_ctx() = 0;
|
||||
|
||||
virtual ~AlignedFileReader(){};
|
||||
|
||||
// register thread-id for a context
|
||||
virtual void register_thread() = 0;
|
||||
// de-register thread-id for a context
|
||||
virtual void deregister_thread() = 0;
|
||||
|
||||
// Open & close ops
|
||||
// Blocking calls
|
||||
virtual void open(const std::string& fname) = 0;
|
||||
virtual void close() = 0;
|
||||
|
||||
// process batch of aligned requests in parallel
|
||||
// NOTE :: blocking call
|
||||
virtual void read(std::vector<AlignedRead>& read_reqs, IOContext& ctx,
|
||||
bool async = false) = 0;
|
||||
};
|
|
@ -0,0 +1,30 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include "windows_customizations.h"
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#define __FUNCSIG__ __PRETTY_FUNCTION__
|
||||
#endif
|
||||
|
||||
namespace diskann {
|
||||
class ANNException {
|
||||
public:
|
||||
DISKANN_DLLEXPORT ANNException(const std::string& message, int errorCode);
|
||||
DISKANN_DLLEXPORT ANNException(const std::string& message, int errorCode,
|
||||
const std::string& funcSig,
|
||||
const std::string& fileName,
|
||||
unsigned int lineNum);
|
||||
|
||||
DISKANN_DLLEXPORT std::string message() const;
|
||||
|
||||
private:
|
||||
int _errorCode;
|
||||
std::string _message;
|
||||
std::string _funcSig;
|
||||
std::string _fileName;
|
||||
unsigned int _lineNum;
|
||||
};
|
||||
} // namespace diskann
|
|
@ -0,0 +1,95 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
#include <algorithm>
|
||||
#include <fcntl.h>
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <set>
|
||||
#ifdef __APPLE__
|
||||
#else
|
||||
#include <malloc.h>
|
||||
#endif
|
||||
|
||||
#ifdef _WINDOWS
|
||||
#include <Windows.h>
|
||||
typedef HANDLE FileHandle;
|
||||
#else
|
||||
#include <unistd.h>
|
||||
typedef int FileHandle;
|
||||
#endif
|
||||
|
||||
#include "cached_io.h"
|
||||
#include "common_includes.h"
|
||||
#include "utils.h"
|
||||
#include "windows_customizations.h"
|
||||
|
||||
namespace diskann {
|
||||
const size_t TRAINING_SET_SIZE = 1500000;
|
||||
const double SPACE_FOR_CACHED_NODES_IN_GB = 0.25;
|
||||
const double THRESHOLD_FOR_CACHING_IN_GB = 1.0;
|
||||
const uint32_t NUM_NODES_TO_CACHE = 250000;
|
||||
const uint32_t WARMUP_L = 20;
|
||||
|
||||
template<typename T>
|
||||
class PQFlashIndex;
|
||||
|
||||
DISKANN_DLLEXPORT double calculate_recall(
|
||||
unsigned num_queries, unsigned *gold_std, float *gs_dist, unsigned dim_gs,
|
||||
unsigned *our_results, unsigned dim_or, unsigned recall_at);
|
||||
|
||||
DISKANN_DLLEXPORT void read_idmap(const std::string & fname,
|
||||
std::vector<unsigned> &ivecs);
|
||||
|
||||
#ifdef EXEC_ENV_OLS
|
||||
template<typename T>
|
||||
DISKANN_DLLEXPORT T *load_warmup(MemoryMappedFiles &files,
|
||||
const std::string &cache_warmup_file,
|
||||
uint64_t &warmup_num, uint64_t warmup_dim,
|
||||
uint64_t warmup_aligned_dim);
|
||||
#else
|
||||
template<typename T>
|
||||
DISKANN_DLLEXPORT T *load_warmup(const std::string &cache_warmup_file,
|
||||
uint64_t &warmup_num, uint64_t warmup_dim,
|
||||
uint64_t warmup_aligned_dim);
|
||||
#endif
|
||||
|
||||
DISKANN_DLLEXPORT int merge_shards(const std::string &vamana_prefix,
|
||||
const std::string &vamana_suffix,
|
||||
const std::string &idmaps_prefix,
|
||||
const std::string &idmaps_suffix,
|
||||
const _u64 nshards, unsigned max_degree,
|
||||
const std::string &output_vamana,
|
||||
const std::string &medoids_file);
|
||||
|
||||
template<typename T>
|
||||
DISKANN_DLLEXPORT int build_merged_vamana_index(
|
||||
std::string base_file, diskann::Metric _compareMetric, unsigned L,
|
||||
unsigned R, double sampling_rate, double ram_budget,
|
||||
std::string mem_index_path, std::string medoids_file,
|
||||
std::string centroids_file);
|
||||
|
||||
template<typename T>
|
||||
DISKANN_DLLEXPORT uint32_t optimize_beamwidth(
|
||||
std::unique_ptr<diskann::PQFlashIndex<T>> &_pFlashIndex, T *tuning_sample,
|
||||
_u64 tuning_sample_num, _u64 tuning_sample_aligned_dim, uint32_t L,
|
||||
uint32_t nthreads, uint32_t start_bw = 2);
|
||||
|
||||
template<typename T>
|
||||
DISKANN_DLLEXPORT bool build_disk_index(const char * dataFilePath,
|
||||
const char * indexFilePath,
|
||||
const char * indexBuildParameters,
|
||||
diskann::Metric _compareMetric);
|
||||
|
||||
template<typename T>
|
||||
DISKANN_DLLEXPORT void create_disk_layout(const std::string base_file,
|
||||
const std::string mem_index_file,
|
||||
const std::string output_file);
|
||||
|
||||
} // namespace diskann
|
|
@ -0,0 +1,167 @@
|
|||
#pragma once
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "logger.h"
|
||||
#include "ann_exception.h"
|
||||
|
||||
// sequential cached reads
|
||||
class cached_ifstream {
|
||||
public:
|
||||
cached_ifstream() {
|
||||
}
|
||||
cached_ifstream(const std::string& filename, uint64_t cacheSize)
|
||||
: cache_size(cacheSize), cur_off(0) {
|
||||
this->open(filename, cache_size);
|
||||
}
|
||||
~cached_ifstream() {
|
||||
delete[] cache_buf;
|
||||
reader.close();
|
||||
}
|
||||
|
||||
void open(const std::string& filename, uint64_t cacheSize) {
|
||||
this->cur_off = 0;
|
||||
reader.open(filename, std::ios::binary | std::ios::ate);
|
||||
fsize = reader.tellg();
|
||||
reader.seekg(0, std::ios::beg);
|
||||
assert(reader.is_open());
|
||||
assert(cacheSize > 0);
|
||||
cacheSize = (std::min)(cacheSize, fsize);
|
||||
this->cache_size = cacheSize;
|
||||
cache_buf = new char[cacheSize];
|
||||
reader.read(cache_buf, cacheSize);
|
||||
diskann::cout << "Opened: " << filename.c_str() << ", size: " << fsize
|
||||
<< ", cache_size: " << cacheSize << std::endl;
|
||||
}
|
||||
|
||||
size_t get_file_size() {
|
||||
return fsize;
|
||||
}
|
||||
void read(char* read_buf, uint64_t n_bytes) {
|
||||
assert(cache_buf != nullptr);
|
||||
assert(read_buf != nullptr);
|
||||
if (n_bytes <= (cache_size - cur_off)) {
|
||||
// case 1: cache contains all data
|
||||
memcpy(read_buf, cache_buf + cur_off, n_bytes);
|
||||
cur_off += n_bytes;
|
||||
} else {
|
||||
// case 2: cache contains some data
|
||||
uint64_t cached_bytes = cache_size - cur_off;
|
||||
if (n_bytes - cached_bytes > fsize - reader.tellg()) {
|
||||
std::stringstream stream;
|
||||
stream << "Reading beyond end of file" << std::endl;
|
||||
stream << "n_bytes: " << n_bytes << " cached_bytes: " << cached_bytes
|
||||
<< " fsize: " << fsize << " current pos:" << reader.tellg()
|
||||
<< std::endl;
|
||||
diskann::cout << stream.str() << std::endl;
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__,
|
||||
__LINE__);
|
||||
}
|
||||
memcpy(read_buf, cache_buf + cur_off, cached_bytes);
|
||||
|
||||
// go to disk and fetch more data
|
||||
reader.read(read_buf + cached_bytes, n_bytes - cached_bytes);
|
||||
// reset cur off
|
||||
cur_off = cache_size;
|
||||
|
||||
uint64_t size_left = fsize - reader.tellg();
|
||||
|
||||
if (size_left >= cache_size) {
|
||||
reader.read(cache_buf, cache_size);
|
||||
cur_off = 0;
|
||||
}
|
||||
|
||||
// note that if size_left < cache_size, then cur_off = cache_size, so
|
||||
// subsequent reads will all be directly from file
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// underlying ifstream
|
||||
std::ifstream reader;
|
||||
// # bytes to cache in one shot read
|
||||
uint64_t cache_size = 0;
|
||||
// underlying buf for cache
|
||||
char* cache_buf = nullptr;
|
||||
// offset into cache_buf for cur_pos
|
||||
uint64_t cur_off = 0;
|
||||
// file size
|
||||
uint64_t fsize = 0;
|
||||
};
|
||||
|
||||
// sequential cached writes
|
||||
class cached_ofstream {
|
||||
public:
|
||||
cached_ofstream(const std::string& filename, uint64_t cache_size)
|
||||
: cache_size(cache_size), cur_off(0) {
|
||||
writer.open(filename, std::ios::binary);
|
||||
assert(writer.is_open());
|
||||
assert(cache_size > 0);
|
||||
cache_buf = new char[cache_size];
|
||||
diskann::cout << "Opened: " << filename.c_str()
|
||||
<< ", cache_size: " << cache_size << std::endl;
|
||||
}
|
||||
|
||||
~cached_ofstream() {
|
||||
// dump any remaining data in memory
|
||||
if (cur_off > 0) {
|
||||
this->flush_cache();
|
||||
}
|
||||
|
||||
delete[] cache_buf;
|
||||
writer.close();
|
||||
diskann::cout << "Finished writing " << fsize << "B" << std::endl;
|
||||
}
|
||||
|
||||
size_t get_file_size() {
|
||||
return fsize;
|
||||
}
|
||||
// writes n_bytes from write_buf to the underlying ofstream/cache
|
||||
void write(char* write_buf, uint64_t n_bytes) {
|
||||
assert(cache_buf != nullptr);
|
||||
if (n_bytes <= (cache_size - cur_off)) {
|
||||
// case 1: cache can take all data
|
||||
memcpy(cache_buf + cur_off, write_buf, n_bytes);
|
||||
cur_off += n_bytes;
|
||||
} else {
|
||||
// case 2: cache cant take all data
|
||||
// go to disk and write existing cache data
|
||||
writer.write(cache_buf, cur_off);
|
||||
fsize += cur_off;
|
||||
// write the new data to disk
|
||||
writer.write(write_buf, n_bytes);
|
||||
fsize += n_bytes;
|
||||
// memset all cache data and reset cur_off
|
||||
memset(cache_buf, 0, cache_size);
|
||||
cur_off = 0;
|
||||
}
|
||||
}
|
||||
|
||||
void flush_cache() {
|
||||
assert(cache_buf != nullptr);
|
||||
writer.write(cache_buf, cur_off);
|
||||
fsize += cur_off;
|
||||
memset(cache_buf, 0, cache_size);
|
||||
cur_off = 0;
|
||||
}
|
||||
|
||||
void reset() {
|
||||
flush_cache();
|
||||
writer.seekp(0);
|
||||
}
|
||||
|
||||
private:
|
||||
// underlying ofstream
|
||||
std::ofstream writer;
|
||||
// # bytes to cache for one shot write
|
||||
uint64_t cache_size = 0;
|
||||
// underlying buf for cache
|
||||
char* cache_buf = nullptr;
|
||||
// offset into cache_buf for cur_pos
|
||||
uint64_t cur_off = 0;
|
||||
|
||||
// file size
|
||||
uint64_t fsize = 0;
|
||||
};
|
|
@ -0,0 +1,18 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fcntl.h>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <queue>
|
||||
#include <random>
|
||||
#include <string.h>
|
||||
#include <sys/stat.h>
|
||||
#include <time.h>
|
||||
#include <vector>
|
|
@ -0,0 +1,114 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <condition_variable>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <thread>
|
||||
#include <type_traits>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace diskann {
|
||||
|
||||
template<typename T>
|
||||
class ConcurrentQueue {
|
||||
typedef std::chrono::microseconds chrono_us_t;
|
||||
typedef std::unique_lock<std::mutex> mutex_locker;
|
||||
|
||||
std::queue<T> q;
|
||||
std::mutex mut;
|
||||
std::mutex push_mut;
|
||||
std::mutex pop_mut;
|
||||
std::condition_variable push_cv;
|
||||
std::condition_variable pop_cv;
|
||||
T null_T;
|
||||
|
||||
public:
|
||||
ConcurrentQueue() {
|
||||
}
|
||||
|
||||
ConcurrentQueue(T nullT) {
|
||||
this->null_T = nullT;
|
||||
}
|
||||
|
||||
~ConcurrentQueue() {
|
||||
this->push_cv.notify_all();
|
||||
this->pop_cv.notify_all();
|
||||
}
|
||||
|
||||
// queue stats
|
||||
uint64_t size() {
|
||||
mutex_locker lk(this->mut);
|
||||
uint64_t ret = q.size();
|
||||
lk.unlock();
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool empty() {
|
||||
return (this->size() == 0);
|
||||
}
|
||||
|
||||
// PUSH BACK
|
||||
void push(T& new_val) {
|
||||
mutex_locker lk(this->mut);
|
||||
this->q.push(new_val);
|
||||
lk.unlock();
|
||||
}
|
||||
|
||||
template<class Iterator>
|
||||
void insert(Iterator iter_begin, Iterator iter_end) {
|
||||
mutex_locker lk(this->mut);
|
||||
for (Iterator it = iter_begin; it != iter_end; it++) {
|
||||
this->q.push(*it);
|
||||
}
|
||||
lk.unlock();
|
||||
}
|
||||
|
||||
// POP FRONT
|
||||
T pop() {
|
||||
mutex_locker lk(this->mut);
|
||||
if (this->q.empty()) {
|
||||
lk.unlock();
|
||||
return this->null_T;
|
||||
} else {
|
||||
T ret = this->q.front();
|
||||
this->q.pop();
|
||||
// diskann::cout << "thread_id: " << std::this_thread::get_id() << ",
|
||||
// ctx: "
|
||||
// << ret.ctx << "\n";
|
||||
lk.unlock();
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
// register for notifications
|
||||
void wait_for_push_notify(chrono_us_t wait_time = chrono_us_t{10}) {
|
||||
mutex_locker lk(this->push_mut);
|
||||
this->push_cv.wait_for(lk, wait_time);
|
||||
lk.unlock();
|
||||
}
|
||||
|
||||
void wait_for_pop_notify(chrono_us_t wait_time = chrono_us_t{10}) {
|
||||
mutex_locker lk(this->pop_mut);
|
||||
this->pop_cv.wait_for(lk, wait_time);
|
||||
lk.unlock();
|
||||
}
|
||||
|
||||
// just notify functions
|
||||
void push_notify_one() {
|
||||
this->push_cv.notify_one();
|
||||
}
|
||||
void push_notify_all() {
|
||||
this->push_cv.notify_all();
|
||||
}
|
||||
void pop_notify_one() {
|
||||
this->pop_cv.notify_one();
|
||||
}
|
||||
void pop_notify_all() {
|
||||
this->pop_cv.notify_all();
|
||||
}
|
||||
};
|
||||
} // namespace diskann
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
|
||||
namespace diskann {
|
||||
template<typename T>
|
||||
inline float compute_l2_norm(const T* vector, uint64_t ndims) {
|
||||
float norm = 0.0f;
|
||||
for (uint64_t i = 0; i < ndims; i++) {
|
||||
norm += vector[i] * vector[i];
|
||||
}
|
||||
return std::sqrt(norm);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline float compute_cosine_similarity(const T* left, const T* right,
|
||||
uint64_t ndims) {
|
||||
float left_norm = compute_l2_norm<T>(left, ndims);
|
||||
float right_norm = compute_l2_norm<T>(right, ndims);
|
||||
float dot = 0.0f;
|
||||
for (uint64_t i = 0; i < ndims; i++) {
|
||||
dot += left[i] * right[i];
|
||||
}
|
||||
float cos_sim = dot / (left_norm * right_norm);
|
||||
return cos_sim;
|
||||
}
|
||||
|
||||
inline std::vector<float> compute_cosine_similarity_batch(
|
||||
const float* query, const unsigned* indices, const float* all_data,
|
||||
const unsigned ndims, const unsigned npts) {
|
||||
std::vector<float> cos_dists;
|
||||
cos_dists.reserve(npts);
|
||||
|
||||
for (size_t i = 0; i < npts; i++) {
|
||||
const float* point = all_data + (size_t)(indices[i]) * (size_t)(ndims);
|
||||
cos_dists.push_back(
|
||||
compute_cosine_similarity<float>(point, query, ndims));
|
||||
}
|
||||
return cos_dists;
|
||||
}
|
||||
} // namespace diskann
|
|
@ -0,0 +1,329 @@
|
|||
#pragma once
|
||||
|
||||
#include <utils.h>
|
||||
#ifdef _WINDOWS
|
||||
#include <immintrin.h>
|
||||
#include <smmintrin.h>
|
||||
#include <tmmintrin.h>
|
||||
#include <intrin.h>
|
||||
#else
|
||||
#include <immintrin.h>
|
||||
#endif
|
||||
|
||||
#include <cosine_similarity.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace {
|
||||
static inline __m128 _mm_mulhi_epi8(__m128i X) {
|
||||
__m128i zero = _mm_setzero_si128();
|
||||
__m128i sign_x = _mm_cmplt_epi8(X, zero);
|
||||
__m128i xhi = _mm_unpackhi_epi8(X, sign_x);
|
||||
|
||||
return _mm_cvtepi32_ps(
|
||||
_mm_add_epi32(_mm_setzero_si128(), _mm_madd_epi16(xhi, xhi)));
|
||||
}
|
||||
|
||||
static inline __m128 _mm_mulhi_epi8_shift32(__m128i X) {
|
||||
__m128i zero = _mm_setzero_si128();
|
||||
X = _mm_srli_epi64(X, 32);
|
||||
__m128i sign_x = _mm_cmplt_epi8(X, zero);
|
||||
__m128i xhi = _mm_unpackhi_epi8(X, sign_x);
|
||||
|
||||
return _mm_cvtepi32_ps(
|
||||
_mm_add_epi32(_mm_setzero_si128(), _mm_madd_epi16(xhi, xhi)));
|
||||
}
|
||||
static inline __m128 _mm_mul_epi8(__m128i X, __m128i Y) {
|
||||
__m128i zero = _mm_setzero_si128();
|
||||
|
||||
__m128i sign_x = _mm_cmplt_epi8(X, zero);
|
||||
__m128i sign_y = _mm_cmplt_epi8(Y, zero);
|
||||
|
||||
__m128i xlo = _mm_unpacklo_epi8(X, sign_x);
|
||||
__m128i xhi = _mm_unpackhi_epi8(X, sign_x);
|
||||
__m128i ylo = _mm_unpacklo_epi8(Y, sign_y);
|
||||
__m128i yhi = _mm_unpackhi_epi8(Y, sign_y);
|
||||
|
||||
return _mm_cvtepi32_ps(
|
||||
_mm_add_epi32(_mm_madd_epi16(xlo, ylo), _mm_madd_epi16(xhi, yhi)));
|
||||
}
|
||||
static inline __m128 _mm_mul_epi8(__m128i X) {
|
||||
__m128i zero = _mm_setzero_si128();
|
||||
__m128i sign_x = _mm_cmplt_epi8(X, zero);
|
||||
__m128i xlo = _mm_unpacklo_epi8(X, sign_x);
|
||||
__m128i xhi = _mm_unpackhi_epi8(X, sign_x);
|
||||
|
||||
return _mm_cvtepi32_ps(
|
||||
_mm_add_epi32(_mm_madd_epi16(xlo, xlo), _mm_madd_epi16(xhi, xhi)));
|
||||
}
|
||||
|
||||
static inline __m128 _mm_mul32_pi8(__m128i X, __m128i Y) {
|
||||
__m128i xlo = _mm_cvtepi8_epi16(X), ylo = _mm_cvtepi8_epi16(Y);
|
||||
return _mm_cvtepi32_ps(
|
||||
_mm_unpacklo_epi32(_mm_madd_epi16(xlo, ylo), _mm_setzero_si128()));
|
||||
}
|
||||
|
||||
static inline __m256 _mm256_mul_epi8(__m256i X, __m256i Y) {
|
||||
__m256i zero = _mm256_setzero_si256();
|
||||
|
||||
__m256i sign_x = _mm256_cmpgt_epi8(zero, X);
|
||||
__m256i sign_y = _mm256_cmpgt_epi8(zero, Y);
|
||||
|
||||
__m256i xlo = _mm256_unpacklo_epi8(X, sign_x);
|
||||
__m256i xhi = _mm256_unpackhi_epi8(X, sign_x);
|
||||
__m256i ylo = _mm256_unpacklo_epi8(Y, sign_y);
|
||||
__m256i yhi = _mm256_unpackhi_epi8(Y, sign_y);
|
||||
|
||||
return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(xlo, ylo),
|
||||
_mm256_madd_epi16(xhi, yhi)));
|
||||
}
|
||||
|
||||
static inline __m256 _mm256_mul32_pi8(__m128i X, __m128i Y) {
|
||||
__m256i xlo = _mm256_cvtepi8_epi16(X), ylo = _mm256_cvtepi8_epi16(Y);
|
||||
return _mm256_blend_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(xlo, ylo)),
|
||||
_mm256_setzero_ps(), 252);
|
||||
}
|
||||
|
||||
static inline float _mm256_reduce_add_ps(__m256 x) {
|
||||
/* ( x3+x7, x2+x6, x1+x5, x0+x4 ) */
|
||||
const __m128 x128 =
|
||||
_mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
|
||||
/* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */
|
||||
const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
|
||||
/* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */
|
||||
const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
|
||||
/* Conversion to float is a no-op on x86-64 */
|
||||
return _mm_cvtss_f32(x32);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace diskann {
|
||||
// enum Metric { L2 = 0, INNER_PRODUCT = 1, FAST_L2 = 2, PQ = 3 };
|
||||
template<typename T>
|
||||
class Distance {
|
||||
public:
|
||||
virtual float compare(const T *a, const T *b, unsigned length) const = 0;
|
||||
virtual ~Distance() {
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
class DistanceCosine : public Distance<T> {
|
||||
float compare(const T *a, const T *b, unsigned length) const {
|
||||
return diskann::compute_cosine_similarity<T>(a, b, length);
|
||||
}
|
||||
};
|
||||
|
||||
class DistanceL2Int8 : public Distance<int8_t> {
|
||||
public:
|
||||
float compare(const int8_t *a, const int8_t *b, unsigned size) const {
|
||||
int32_t result = 0;
|
||||
|
||||
#ifdef _WINDOWS
|
||||
#ifdef USE_AVX2
|
||||
__m256 r = _mm256_setzero_ps();
|
||||
char * pX = (char *) a, *pY = (char *) b;
|
||||
while (size >= 32) {
|
||||
__m256i r1 = _mm256_subs_epi8(_mm256_loadu_si256((__m256i *) pX),
|
||||
_mm256_loadu_si256((__m256i *) pY));
|
||||
r = _mm256_add_ps(r, _mm256_mul_epi8(r1, r1));
|
||||
pX += 32;
|
||||
pY += 32;
|
||||
size -= 32;
|
||||
}
|
||||
while (size > 0) {
|
||||
__m128i r2 = _mm_subs_epi8(_mm_loadu_si128((__m128i *) pX),
|
||||
_mm_loadu_si128((__m128i *) pY));
|
||||
r = _mm256_add_ps(r, _mm256_mul32_pi8(r2, r2));
|
||||
pX += 4;
|
||||
pY += 4;
|
||||
size -= 4;
|
||||
}
|
||||
r = _mm256_hadd_ps(_mm256_hadd_ps(r, r), r);
|
||||
return r.m256_f32[0] + r.m256_f32[4];
|
||||
#else
|
||||
#pragma omp simd reduction(+ : result) aligned(a, b : 8)
|
||||
for (_s32 i = 0; i < (_s32) size; i++) {
|
||||
result += ((int32_t)((int16_t) a[i] - (int16_t) b[i])) *
|
||||
((int32_t)((int16_t) a[i] - (int16_t) b[i]));
|
||||
}
|
||||
return (float) result;
|
||||
#endif
|
||||
#else
|
||||
#pragma omp simd reduction(+ : result) aligned(a, b : 8)
|
||||
for (_s32 i = 0; i < (_s32) size; i++) {
|
||||
result += ((int32_t)((int16_t) a[i] - (int16_t) b[i])) *
|
||||
((int32_t)((int16_t) a[i] - (int16_t) b[i]));
|
||||
}
|
||||
return (float) result;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
class DistanceL2UInt8 : public Distance<uint8_t> {
|
||||
public:
|
||||
float compare(const uint8_t *a, const uint8_t *b, unsigned size) const {
|
||||
uint32_t result = 0;
|
||||
#ifndef _WINDOWS
|
||||
#pragma omp simd reduction(+ : result) aligned(a, b : 8)
|
||||
#endif
|
||||
for (_s32 i = 0; i < (_s32) size; i++) {
|
||||
result += ((int32_t)((int16_t) a[i] - (int16_t) b[i])) *
|
||||
((int32_t)((int16_t) a[i] - (int16_t) b[i]));
|
||||
}
|
||||
return (float) result;
|
||||
}
|
||||
};
|
||||
|
||||
class DistanceL2 : public Distance<float> {
|
||||
public:
|
||||
#ifndef _WINDOWS
|
||||
float compare(const float *a, const float *b, unsigned size) const
|
||||
__attribute__((hot)) {
|
||||
a = (const float *) __builtin_assume_aligned(a, 32);
|
||||
b = (const float *) __builtin_assume_aligned(b, 32);
|
||||
#else
|
||||
float compare(const float *a, const float *b, unsigned size) const {
|
||||
#endif
|
||||
|
||||
float result = 0;
|
||||
#ifdef USE_AVX2
|
||||
// assume size is divisible by 8
|
||||
_u16 niters = size / 8;
|
||||
__m256 sum = _mm256_setzero_ps();
|
||||
for (_u16 j = 0; j < niters; j++) {
|
||||
// scope is a[8j:8j+7], b[8j:8j+7]
|
||||
// load a_vec
|
||||
if (j < (niters - 1)) {
|
||||
_mm_prefetch((char *) (a + 8 * (j + 1)), _MM_HINT_T0);
|
||||
_mm_prefetch((char *) (b + 8 * (j + 1)), _MM_HINT_T0);
|
||||
}
|
||||
__m256 a_vec = _mm256_load_ps(a + 8 * j);
|
||||
// load b_vec
|
||||
__m256 b_vec = _mm256_load_ps(b + 8 * j);
|
||||
// a_vec - b_vec
|
||||
__m256 tmp_vec = _mm256_sub_ps(a_vec, b_vec);
|
||||
/*
|
||||
// (a_vec - b_vec)**2
|
||||
__m256 tmp_vec2 = _mm256_mul_ps(tmp_vec, tmp_vec);
|
||||
// accumulate sum
|
||||
sum = _mm256_add_ps(sum, tmp_vec2);
|
||||
*/
|
||||
// sum = (tmp_vec**2) + sum
|
||||
sum = _mm256_fmadd_ps(tmp_vec, tmp_vec, sum);
|
||||
}
|
||||
|
||||
// horizontal add sum
|
||||
result = _mm256_reduce_add_ps(sum);
|
||||
#else
|
||||
#ifndef _WINDOWS
|
||||
#pragma omp simd reduction(+ : result) aligned(a, b : 32)
|
||||
#endif
|
||||
for (_s32 i = 0; i < (_s32) size; i++) {
|
||||
result += (a[i] - b[i]) * (a[i] - b[i]);
|
||||
}
|
||||
#endif
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
// Gopal. Slow implementations of the distance functions to get diskann to
|
||||
// work in v14 machines that do not have AVX2 support. Performance here is not
|
||||
// a concern, so we are using the simplest possible implementation.
|
||||
template<typename T>
|
||||
class SlowDistanceL2Int : public Distance<T> {
|
||||
virtual float compare(const T *a, const T *b, unsigned length) const {
|
||||
uint32_t result = 0;
|
||||
for (_u32 i = 0; i < length; i++) {
|
||||
result += ((int32_t)((int16_t) a[i] - (int16_t) b[i])) *
|
||||
((int32_t)((int16_t) a[i] - (int16_t) b[i]));
|
||||
}
|
||||
return (float) result;
|
||||
}
|
||||
};
|
||||
|
||||
class SlowDistanceL2Float : public Distance<float> {
|
||||
virtual float compare(const float *a, const float *b,
|
||||
unsigned length) const {
|
||||
float result = 0.0f;
|
||||
for (_u32 i = 0; i < length; i++) {
|
||||
result += (a[i] - b[i]) * (a[i] - b[i]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
class AVXDistanceL2Int8 : public Distance<int8_t> {
|
||||
public:
|
||||
virtual float compare(const int8_t *a, const int8_t *b,
|
||||
unsigned int length) const {
|
||||
#ifndef _WINDOWS
|
||||
std::cout << "AVX only supported in Windows build.";
|
||||
return 0;
|
||||
}
|
||||
#else
|
||||
__m128 r = _mm_setzero_ps();
|
||||
__m128i r1;
|
||||
while (length >= 16) {
|
||||
r1 = _mm_subs_epi8(_mm_load_si128((__m128i *) a),
|
||||
_mm_load_si128((__m128i *) b));
|
||||
r = _mm_add_ps(r, _mm_mul_epi8(r1));
|
||||
a += 16;
|
||||
b += 16;
|
||||
length -= 16;
|
||||
}
|
||||
r = _mm_hadd_ps(_mm_hadd_ps(r, r), r);
|
||||
float res = r.m128_f32[0];
|
||||
|
||||
if (length >= 8) {
|
||||
__m128 r2 = _mm_setzero_ps();
|
||||
__m128i r3 = _mm_subs_epi8(_mm_load_si128((__m128i *) (a - 8)),
|
||||
_mm_load_si128((__m128i *) (b - 8)));
|
||||
r2 = _mm_add_ps(r2, _mm_mulhi_epi8(r3));
|
||||
a += 8;
|
||||
b += 8;
|
||||
length -= 8;
|
||||
r2 = _mm_hadd_ps(_mm_hadd_ps(r2, r2), r2);
|
||||
res += r2.m128_f32[0];
|
||||
}
|
||||
|
||||
if (length >= 4) {
|
||||
__m128 r2 = _mm_setzero_ps();
|
||||
__m128i r3 = _mm_subs_epi8(_mm_load_si128((__m128i *) (a - 12)),
|
||||
_mm_load_si128((__m128i *) (b - 12)));
|
||||
r2 = _mm_add_ps(r2, _mm_mulhi_epi8_shift32(r3));
|
||||
res += r2.m128_f32[0] + r2.m128_f32[1];
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
class AVXDistanceL2Float : public Distance<float> {
|
||||
public:
|
||||
virtual float compare(const float *a, const float *b,
|
||||
unsigned int length) const {
|
||||
#ifndef _WINDOWS
|
||||
std::cout << "AVX only supported in Windows build.";
|
||||
return 0;
|
||||
}
|
||||
#else
|
||||
__m128 diff, v1, v2;
|
||||
__m128 sum = _mm_set1_ps(0);
|
||||
|
||||
while (length >= 4) {
|
||||
v1 = _mm_loadu_ps(a);
|
||||
a += 4;
|
||||
v2 = _mm_loadu_ps(b);
|
||||
b += 4;
|
||||
diff = _mm_sub_ps(v1, v2);
|
||||
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
|
||||
length -= 4;
|
||||
}
|
||||
|
||||
return sum.m128_f32[0] + sum.m128_f32[1] + sum.m128_f32[2] +
|
||||
sum.m128_f32[3];
|
||||
}
|
||||
#endif
|
||||
};
|
||||
} // namespace diskann
|
|
@ -1,328 +0,0 @@
|
|||
//
|
||||
// Created by 付聪 on 2017/6/21.
|
||||
//
|
||||
|
||||
#ifndef EFANNA2E_DISTANCE_H
|
||||
#define EFANNA2E_DISTANCE_H
|
||||
|
||||
#include <x86intrin.h>
|
||||
#include <iostream>
|
||||
namespace efanna2e{
|
||||
enum Metric{
|
||||
L2 = 0,
|
||||
INNER_PRODUCT = 1,
|
||||
FAST_L2 = 2,
|
||||
PQ = 3
|
||||
};
|
||||
class Distance {
|
||||
public:
|
||||
virtual float compare(const float* a, const float* b, unsigned length) const = 0;
|
||||
virtual ~Distance() {}
|
||||
};
|
||||
|
||||
class DistanceL2 : public Distance{
|
||||
public:
|
||||
float compare(const float* a, const float* b, unsigned size) const {
|
||||
float result = 0;
|
||||
|
||||
#ifdef __GNUC__
|
||||
#ifdef __AVX__
|
||||
|
||||
#define AVX_L2SQR(addr1, addr2, dest, tmp1, tmp2) \
|
||||
tmp1 = _mm256_loadu_ps(addr1);\
|
||||
tmp2 = _mm256_loadu_ps(addr2);\
|
||||
tmp1 = _mm256_sub_ps(tmp1, tmp2); \
|
||||
tmp1 = _mm256_mul_ps(tmp1, tmp1); \
|
||||
dest = _mm256_add_ps(dest, tmp1);
|
||||
|
||||
__m256 sum;
|
||||
__m256 l0, l1;
|
||||
__m256 r0, r1;
|
||||
unsigned D = (size + 7) & ~7U;
|
||||
unsigned DR = D % 16;
|
||||
unsigned DD = D - DR;
|
||||
const float *l = a;
|
||||
const float *r = b;
|
||||
const float *e_l = l + DD;
|
||||
const float *e_r = r + DD;
|
||||
float unpack[8] __attribute__ ((aligned (32))) = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||
|
||||
sum = _mm256_loadu_ps(unpack);
|
||||
if(DR){AVX_L2SQR(e_l, e_r, sum, l0, r0);}
|
||||
|
||||
for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) {
|
||||
AVX_L2SQR(l, r, sum, l0, r0);
|
||||
AVX_L2SQR(l + 8, r + 8, sum, l1, r1);
|
||||
}
|
||||
_mm256_storeu_ps(unpack, sum);
|
||||
result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + unpack[5] + unpack[6] + unpack[7];
|
||||
|
||||
#else
|
||||
#ifdef __SSE2__
|
||||
#define SSE_L2SQR(addr1, addr2, dest, tmp1, tmp2) \
|
||||
tmp1 = _mm_load_ps(addr1);\
|
||||
tmp2 = _mm_load_ps(addr2);\
|
||||
tmp1 = _mm_sub_ps(tmp1, tmp2); \
|
||||
tmp1 = _mm_mul_ps(tmp1, tmp1); \
|
||||
dest = _mm_add_ps(dest, tmp1);
|
||||
|
||||
__m128 sum;
|
||||
__m128 l0, l1, l2, l3;
|
||||
__m128 r0, r1, r2, r3;
|
||||
unsigned D = (size + 3) & ~3U;
|
||||
unsigned DR = D % 16;
|
||||
unsigned DD = D - DR;
|
||||
const float *l = a;
|
||||
const float *r = b;
|
||||
const float *e_l = l + DD;
|
||||
const float *e_r = r + DD;
|
||||
float unpack[4] __attribute__ ((aligned (16))) = {0, 0, 0, 0};
|
||||
|
||||
sum = _mm_load_ps(unpack);
|
||||
switch (DR) {
|
||||
case 12:
|
||||
SSE_L2SQR(e_l+8, e_r+8, sum, l2, r2);
|
||||
case 8:
|
||||
SSE_L2SQR(e_l+4, e_r+4, sum, l1, r1);
|
||||
case 4:
|
||||
SSE_L2SQR(e_l, e_r, sum, l0, r0);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) {
|
||||
SSE_L2SQR(l, r, sum, l0, r0);
|
||||
SSE_L2SQR(l + 4, r + 4, sum, l1, r1);
|
||||
SSE_L2SQR(l + 8, r + 8, sum, l2, r2);
|
||||
SSE_L2SQR(l + 12, r + 12, sum, l3, r3);
|
||||
}
|
||||
_mm_storeu_ps(unpack, sum);
|
||||
result += unpack[0] + unpack[1] + unpack[2] + unpack[3];
|
||||
|
||||
//nomal distance
|
||||
#else
|
||||
|
||||
float diff0, diff1, diff2, diff3;
|
||||
const float* last = a + size;
|
||||
const float* unroll_group = last - 3;
|
||||
|
||||
/* Process 4 items with each loop for efficiency. */
|
||||
while (a < unroll_group) {
|
||||
diff0 = a[0] - b[0];
|
||||
diff1 = a[1] - b[1];
|
||||
diff2 = a[2] - b[2];
|
||||
diff3 = a[3] - b[3];
|
||||
result += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
|
||||
a += 4;
|
||||
b += 4;
|
||||
}
|
||||
/* Process last 0-3 pixels. Not needed for standard vector lengths. */
|
||||
while (a < last) {
|
||||
diff0 = *a++ - *b++;
|
||||
result += diff0 * diff0;
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
class DistanceInnerProduct : public Distance{
|
||||
public:
|
||||
float compare(const float* a, const float* b, unsigned size) const {
|
||||
float result = 0;
|
||||
#ifdef __GNUC__
|
||||
#ifdef __AVX__
|
||||
#define AVX_DOT(addr1, addr2, dest, tmp1, tmp2) \
|
||||
tmp1 = _mm256_loadu_ps(addr1);\
|
||||
tmp2 = _mm256_loadu_ps(addr2);\
|
||||
tmp1 = _mm256_mul_ps(tmp1, tmp2); \
|
||||
dest = _mm256_add_ps(dest, tmp1);
|
||||
|
||||
__m256 sum;
|
||||
__m256 l0, l1;
|
||||
__m256 r0, r1;
|
||||
unsigned D = (size + 7) & ~7U;
|
||||
unsigned DR = D % 16;
|
||||
unsigned DD = D - DR;
|
||||
const float *l = a;
|
||||
const float *r = b;
|
||||
const float *e_l = l + DD;
|
||||
const float *e_r = r + DD;
|
||||
float unpack[8] __attribute__ ((aligned (32))) = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||
|
||||
sum = _mm256_loadu_ps(unpack);
|
||||
if(DR){AVX_DOT(e_l, e_r, sum, l0, r0);}
|
||||
|
||||
for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) {
|
||||
AVX_DOT(l, r, sum, l0, r0);
|
||||
AVX_DOT(l + 8, r + 8, sum, l1, r1);
|
||||
}
|
||||
_mm256_storeu_ps(unpack, sum);
|
||||
result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + unpack[5] + unpack[6] + unpack[7];
|
||||
|
||||
#else
|
||||
#ifdef __SSE2__
|
||||
#define SSE_DOT(addr1, addr2, dest, tmp1, tmp2) \
|
||||
tmp1 = _mm128_loadu_ps(addr1);\
|
||||
tmp2 = _mm128_loadu_ps(addr2);\
|
||||
tmp1 = _mm128_mul_ps(tmp1, tmp2); \
|
||||
dest = _mm128_add_ps(dest, tmp1);
|
||||
__m128 sum;
|
||||
__m128 l0, l1, l2, l3;
|
||||
__m128 r0, r1, r2, r3;
|
||||
unsigned D = (size + 3) & ~3U;
|
||||
unsigned DR = D % 16;
|
||||
unsigned DD = D - DR;
|
||||
const float *l = a;
|
||||
const float *r = b;
|
||||
const float *e_l = l + DD;
|
||||
const float *e_r = r + DD;
|
||||
float unpack[4] __attribute__ ((aligned (16))) = {0, 0, 0, 0};
|
||||
|
||||
sum = _mm_load_ps(unpack);
|
||||
switch (DR) {
|
||||
case 12:
|
||||
SSE_DOT(e_l+8, e_r+8, sum, l2, r2);
|
||||
case 8:
|
||||
SSE_DOT(e_l+4, e_r+4, sum, l1, r1);
|
||||
case 4:
|
||||
SSE_DOT(e_l, e_r, sum, l0, r0);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) {
|
||||
SSE_DOT(l, r, sum, l0, r0);
|
||||
SSE_DOT(l + 4, r + 4, sum, l1, r1);
|
||||
SSE_DOT(l + 8, r + 8, sum, l2, r2);
|
||||
SSE_DOT(l + 12, r + 12, sum, l3, r3);
|
||||
}
|
||||
_mm_storeu_ps(unpack, sum);
|
||||
result += unpack[0] + unpack[1] + unpack[2] + unpack[3];
|
||||
#else
|
||||
|
||||
float dot0, dot1, dot2, dot3;
|
||||
const float* last = a + size;
|
||||
const float* unroll_group = last - 3;
|
||||
|
||||
/* Process 4 items with each loop for efficiency. */
|
||||
while (a < unroll_group) {
|
||||
dot0 = a[0] * b[0];
|
||||
dot1 = a[1] * b[1];
|
||||
dot2 = a[2] * b[2];
|
||||
dot3 = a[3] * b[3];
|
||||
result += dot0 + dot1 + dot2 + dot3;
|
||||
a += 4;
|
||||
b += 4;
|
||||
}
|
||||
/* Process last 0-3 pixels. Not needed for standard vector lengths. */
|
||||
while (a < last) {
|
||||
result += *a++ * *b++;
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
return result;
|
||||
}
|
||||
|
||||
};
|
||||
class DistanceFastL2 : public DistanceInnerProduct{
|
||||
public:
|
||||
float norm(const float* a, unsigned size) const{
|
||||
float result = 0;
|
||||
#ifdef __GNUC__
|
||||
#ifdef __AVX__
|
||||
#define AVX_L2NORM(addr, dest, tmp) \
|
||||
tmp = _mm256_loadu_ps(addr); \
|
||||
tmp = _mm256_mul_ps(tmp, tmp); \
|
||||
dest = _mm256_add_ps(dest, tmp);
|
||||
|
||||
__m256 sum;
|
||||
__m256 l0, l1;
|
||||
unsigned D = (size + 7) & ~7U;
|
||||
unsigned DR = D % 16;
|
||||
unsigned DD = D - DR;
|
||||
const float *l = a;
|
||||
const float *e_l = l + DD;
|
||||
float unpack[8] __attribute__ ((aligned (32))) = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||
|
||||
sum = _mm256_loadu_ps(unpack);
|
||||
if(DR){AVX_L2NORM(e_l, sum, l0);}
|
||||
for (unsigned i = 0; i < DD; i += 16, l += 16) {
|
||||
AVX_L2NORM(l, sum, l0);
|
||||
AVX_L2NORM(l + 8, sum, l1);
|
||||
}
|
||||
_mm256_storeu_ps(unpack, sum);
|
||||
result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + unpack[5] + unpack[6] + unpack[7];
|
||||
#else
|
||||
#ifdef __SSE2__
|
||||
#define SSE_L2NORM(addr, dest, tmp) \
|
||||
tmp = _mm128_loadu_ps(addr); \
|
||||
tmp = _mm128_mul_ps(tmp, tmp); \
|
||||
dest = _mm128_add_ps(dest, tmp);
|
||||
|
||||
__m128 sum;
|
||||
__m128 l0, l1, l2, l3;
|
||||
unsigned D = (size + 3) & ~3U;
|
||||
unsigned DR = D % 16;
|
||||
unsigned DD = D - DR;
|
||||
const float *l = a;
|
||||
const float *e_l = l + DD;
|
||||
float unpack[4] __attribute__ ((aligned (16))) = {0, 0, 0, 0};
|
||||
|
||||
sum = _mm_load_ps(unpack);
|
||||
switch (DR) {
|
||||
case 12:
|
||||
SSE_L2NORM(e_l+8, sum, l2);
|
||||
case 8:
|
||||
SSE_L2NORM(e_l+4, sum, l1);
|
||||
case 4:
|
||||
SSE_L2NORM(e_l, sum, l0);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
for (unsigned i = 0; i < DD; i += 16, l += 16) {
|
||||
SSE_L2NORM(l, sum, l0);
|
||||
SSE_L2NORM(l + 4, sum, l1);
|
||||
SSE_L2NORM(l + 8, sum, l2);
|
||||
SSE_L2NORM(l + 12, sum, l3);
|
||||
}
|
||||
_mm_storeu_ps(unpack, sum);
|
||||
result += unpack[0] + unpack[1] + unpack[2] + unpack[3];
|
||||
#else
|
||||
float dot0, dot1, dot2, dot3;
|
||||
const float* last = a + size;
|
||||
const float* unroll_group = last - 3;
|
||||
|
||||
/* Process 4 items with each loop for efficiency. */
|
||||
while (a < unroll_group) {
|
||||
dot0 = a[0] * a[0];
|
||||
dot1 = a[1] * a[1];
|
||||
dot2 = a[2] * a[2];
|
||||
dot3 = a[3] * a[3];
|
||||
result += dot0 + dot1 + dot2 + dot3;
|
||||
a += 4;
|
||||
}
|
||||
/* Process last 0-3 pixels. Not needed for standard vector lengths. */
|
||||
while (a < last) {
|
||||
result += (*a) * (*a);
|
||||
a++;
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
return result;
|
||||
}
|
||||
using DistanceInnerProduct::compare;
|
||||
float compare(const float* a, const float* b, float norm, unsigned size) const {//not implement
|
||||
float result = -2 * DistanceInnerProduct::compare(a, b, size);
|
||||
result += norm;
|
||||
return result;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
|
||||
#endif //EFANNA2E_DISTANCE_H
|
|
@ -1,21 +0,0 @@
|
|||
//
|
||||
// Copyright (c) 2017 ZJULearning. All rights reserved.
|
||||
//
|
||||
// This source code is licensed under the MIT license.
|
||||
//
|
||||
|
||||
#ifndef EFANNA2E_EXCEPTIONS_H
|
||||
#define EFANNA2E_EXCEPTIONS_H
|
||||
|
||||
#include <stdexcept>
|
||||
|
||||
namespace efanna2e {
|
||||
|
||||
class NotImplementedException : public std::logic_error {
|
||||
public:
|
||||
NotImplementedException() : std::logic_error("Function not yet implemented.") {}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif //EFANNA2E_EXCEPTIONS_H
|
|
@ -1,56 +0,0 @@
|
|||
//
|
||||
// Copyright (c) 2017 ZJULearning. All rights reserved.
|
||||
//
|
||||
// This source code is licensed under the MIT license.
|
||||
//
|
||||
|
||||
#ifndef EFANNA2E_INDEX_H
|
||||
#define EFANNA2E_INDEX_H
|
||||
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include "distance.h"
|
||||
#include "parameters.h"
|
||||
|
||||
namespace efanna2e {
|
||||
|
||||
class Index {
|
||||
public:
|
||||
explicit Index(const size_t dimension, const size_t n, Metric metric);
|
||||
|
||||
|
||||
virtual ~Index();
|
||||
|
||||
virtual void Build(size_t n, const float *data, const Parameters ¶meters) = 0;
|
||||
|
||||
virtual void Search(
|
||||
const float *query,
|
||||
const float *x,
|
||||
size_t k,
|
||||
const Parameters ¶meters,
|
||||
unsigned *indices) = 0;
|
||||
|
||||
virtual void Save(const char *filename) = 0;
|
||||
|
||||
virtual void Load(const char *filename) = 0;
|
||||
|
||||
inline bool HasBuilt() const { return has_built; }
|
||||
|
||||
inline size_t GetDimension() const { return dimension_; };
|
||||
|
||||
inline size_t GetSizeOfDataset() const { return nd_; }
|
||||
|
||||
inline const float *GetDataset() const { return data_; }
|
||||
protected:
|
||||
const size_t dimension_;
|
||||
const float *data_;
|
||||
size_t nd_;
|
||||
bool has_built;
|
||||
Distance* distance_;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif //EFANNA2E_INDEX_H
|
|
@ -1,78 +0,0 @@
|
|||
#ifndef EFANNA2E_INDEX_NSG_H
|
||||
#define EFANNA2E_INDEX_NSG_H
|
||||
|
||||
#include "util.h"
|
||||
#include "parameters.h"
|
||||
#include "neighbor.h"
|
||||
#include "index.h"
|
||||
#include <cassert>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <boost/dynamic_bitset.hpp>
|
||||
#include <stack>
|
||||
|
||||
namespace efanna2e {
|
||||
|
||||
class IndexNSG : public Index {
|
||||
public:
|
||||
explicit IndexNSG(const size_t dimension, const size_t n, Metric m, Index *initializer);
|
||||
|
||||
|
||||
virtual ~IndexNSG();
|
||||
|
||||
virtual void Save(const char *filename)override;
|
||||
virtual void Load(const char *filename)override;
|
||||
|
||||
|
||||
virtual void Build(size_t n, const float *data, const Parameters ¶meters) override;
|
||||
|
||||
virtual void Search(
|
||||
const float *query,
|
||||
const float *x,
|
||||
size_t k,
|
||||
const Parameters ¶meters,
|
||||
unsigned *indices) override;
|
||||
void SearchWithOptGraph(
|
||||
const float *query,
|
||||
size_t K,
|
||||
const Parameters ¶meters,
|
||||
unsigned *indices);
|
||||
void OptimizeGraph(float* data);
|
||||
|
||||
protected:
|
||||
typedef std::vector<std::vector<unsigned > > CompactGraph;
|
||||
typedef std::vector<LockNeighbor > LockGraph;
|
||||
typedef std::vector<nhood> KNNGraph;
|
||||
|
||||
CompactGraph final_graph_;
|
||||
|
||||
Index *initializer_;
|
||||
void init_graph(const Parameters ¶meters);
|
||||
void get_neighbors(
|
||||
const float *query,
|
||||
const Parameters ¶meter,
|
||||
std::vector<Neighbor> &retset,
|
||||
std::vector<Neighbor> &fullset);
|
||||
void add_cnn(unsigned des, Neighbor p, unsigned range, LockGraph& cut_graph_);
|
||||
void sync_prune(unsigned q, std::vector<Neighbor>& pool, const Parameters ¶meter, LockGraph& cut_graph_);
|
||||
void Link(const Parameters ¶meters, LockGraph& cut_graph_);
|
||||
void Load_nn_graph(const char *filename);
|
||||
void tree_grow(const Parameters ¶meter);
|
||||
void DFS(boost::dynamic_bitset<> &flag, unsigned root, unsigned &cnt);
|
||||
void findroot(boost::dynamic_bitset<> &flag, unsigned &root, const Parameters ¶meter);
|
||||
|
||||
|
||||
private:
|
||||
unsigned width;
|
||||
unsigned ep_;
|
||||
std::vector<std::mutex> locks;
|
||||
char* opt_graph_;
|
||||
size_t node_size;
|
||||
size_t data_len;
|
||||
size_t neighbor_len;
|
||||
KNNGraph nnd_graph;
|
||||
};
|
||||
}
|
||||
|
||||
#endif //EFANNA2E_INDEX_NSG_H
|
|
@ -1,124 +0,0 @@
|
|||
//
|
||||
// Copyright (c) 2017 ZJULearning. All rights reserved.
|
||||
//
|
||||
// This source code is licensed under the MIT license.
|
||||
//
|
||||
|
||||
#ifndef EFANNA2E_GRAPH_H
|
||||
#define EFANNA2E_GRAPH_H
|
||||
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
|
||||
namespace efanna2e {
|
||||
|
||||
struct Neighbor {
|
||||
unsigned id;
|
||||
float distance;
|
||||
bool flag;
|
||||
|
||||
Neighbor() = default;
|
||||
Neighbor(unsigned id, float distance, bool f) : id{id}, distance{distance}, flag(f) {}
|
||||
|
||||
inline bool operator<(const Neighbor &other) const {
|
||||
return distance < other.distance;
|
||||
}
|
||||
};
|
||||
|
||||
typedef std::lock_guard<std::mutex> LockGuard;
|
||||
struct nhood{
|
||||
std::mutex lock;
|
||||
std::vector<Neighbor> pool;
|
||||
unsigned M;
|
||||
|
||||
std::vector<unsigned> nn_old;
|
||||
std::vector<unsigned> nn_new;
|
||||
std::vector<unsigned> rnn_old;
|
||||
std::vector<unsigned> rnn_new;
|
||||
|
||||
nhood(){}
|
||||
nhood(unsigned l, unsigned s, std::mt19937 &rng, unsigned N){
|
||||
M = s;
|
||||
nn_new.resize(s * 2);
|
||||
GenRandom(rng, &nn_new[0], (unsigned)nn_new.size(), N);
|
||||
nn_new.reserve(s * 2);
|
||||
pool.reserve(l);
|
||||
}
|
||||
|
||||
nhood(const nhood &other){
|
||||
M = other.M;
|
||||
std::copy(other.nn_new.begin(), other.nn_new.end(), std::back_inserter(nn_new));
|
||||
nn_new.reserve(other.nn_new.capacity());
|
||||
pool.reserve(other.pool.capacity());
|
||||
}
|
||||
void insert (unsigned id, float dist) {
|
||||
LockGuard guard(lock);
|
||||
if (dist > pool.front().distance) return;
|
||||
for(unsigned i=0; i<pool.size(); i++){
|
||||
if(id == pool[i].id)return;
|
||||
}
|
||||
if(pool.size() < pool.capacity()){
|
||||
pool.push_back(Neighbor(id, dist, true));
|
||||
std::push_heap(pool.begin(), pool.end());
|
||||
}else{
|
||||
std::pop_heap(pool.begin(), pool.end());
|
||||
pool[pool.size()-1] = Neighbor(id, dist, true);
|
||||
std::push_heap(pool.begin(), pool.end());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template <typename C>
|
||||
void join (C callback) const {
|
||||
for (unsigned const i: nn_new) {
|
||||
for (unsigned const j: nn_new) {
|
||||
if (i < j) {
|
||||
callback(i, j);
|
||||
}
|
||||
}
|
||||
for (unsigned j: nn_old) {
|
||||
callback(i, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct LockNeighbor{
|
||||
std::mutex lock;
|
||||
std::vector<Neighbor> pool;
|
||||
};
|
||||
|
||||
static inline int InsertIntoPool (Neighbor *addr, unsigned K, Neighbor nn) {
|
||||
// find the location to insert
|
||||
int left=0,right=K-1;
|
||||
if(addr[left].distance>nn.distance){
|
||||
memmove((char *)&addr[left+1], &addr[left],K * sizeof(Neighbor));
|
||||
addr[left] = nn;
|
||||
return left;
|
||||
}
|
||||
if(addr[right].distance<nn.distance){
|
||||
addr[K] = nn;
|
||||
return K;
|
||||
}
|
||||
while(left<right-1){
|
||||
int mid=(left+right)/2;
|
||||
if(addr[mid].distance>nn.distance)right=mid;
|
||||
else left=mid;
|
||||
}
|
||||
//check equal ID
|
||||
|
||||
while (left > 0){
|
||||
if (addr[left].distance < nn.distance) break;
|
||||
if (addr[left].id == nn.id) return K + 1;
|
||||
left--;
|
||||
}
|
||||
if(addr[left].id == nn.id||addr[right].id==nn.id)return K+1;
|
||||
memmove((char *)&addr[right+1], &addr[right],(K-right) * sizeof(Neighbor));
|
||||
addr[right]=nn;
|
||||
return right;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif //EFANNA2E_GRAPH_H
|
|
@ -1,61 +0,0 @@
|
|||
//
|
||||
// Copyright (c) 2017 ZJULearning. All rights reserved.
|
||||
//
|
||||
// This source code is licensed under the MIT license.
|
||||
//
|
||||
|
||||
#ifndef EFANNA2E_PARAMETERS_H
|
||||
#define EFANNA2E_PARAMETERS_H
|
||||
|
||||
#include <unordered_map>
|
||||
#include <sstream>
|
||||
#include <typeinfo>
|
||||
namespace efanna2e {
|
||||
|
||||
class Parameters {
|
||||
public:
|
||||
template<typename ParamType>
|
||||
inline void Set(const std::string &name, const ParamType &value) {
|
||||
std::stringstream sstream;
|
||||
sstream << value;
|
||||
params[name] = sstream.str();
|
||||
}
|
||||
|
||||
template<typename ParamType>
|
||||
inline ParamType Get(const std::string &name) const {
|
||||
auto item = params.find(name);
|
||||
if (item == params.end()) {
|
||||
throw std::invalid_argument("Invalid parameter name.");
|
||||
} else {
|
||||
return ConvertStrToValue<ParamType>(item->second);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename ParamType>
|
||||
inline ParamType Get(const std::string &name, const ParamType &default_value) {
|
||||
try {
|
||||
return Get<ParamType>(name);
|
||||
} catch (std::invalid_argument e) {
|
||||
return default_value;
|
||||
}
|
||||
}
|
||||
private:
|
||||
std::unordered_map<std::string, std::string> params;
|
||||
|
||||
template<typename ParamType>
|
||||
inline ParamType ConvertStrToValue(const std::string &str) const {
|
||||
std::stringstream sstream(str);
|
||||
ParamType value;
|
||||
if (!(sstream >> value) || !sstream.eof()) {
|
||||
std::stringstream err;
|
||||
err << "Failed to convert value '" << str << "' to type: " << typeid(value).name();
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif //EFANNA2E_PARAMETERS_H
|
|
@ -1,71 +0,0 @@
|
|||
//
|
||||
// Created by 付聪 on 2017/6/21.
|
||||
//
|
||||
|
||||
#ifndef EFANNA2E_UTIL_H
|
||||
#define EFANNA2E_UTIL_H
|
||||
#include <random>
|
||||
#include <iostream>
|
||||
#include <cstring>
|
||||
#include <algorithm>
|
||||
#ifdef __APPLE__
|
||||
#else
|
||||
#include <malloc.h>
|
||||
#endif
|
||||
namespace efanna2e {
|
||||
|
||||
static void GenRandom(std::mt19937 &rng, unsigned *addr, unsigned size, unsigned N) {
|
||||
for (unsigned i = 0; i < size; ++i) {
|
||||
addr[i] = rng() % (N - size);
|
||||
}
|
||||
std::sort(addr, addr + size);
|
||||
for (unsigned i = 1; i < size; ++i) {
|
||||
if (addr[i] <= addr[i - 1]) {
|
||||
addr[i] = addr[i - 1] + 1;
|
||||
}
|
||||
}
|
||||
unsigned off = rng() % N;
|
||||
for (unsigned i = 0; i < size; ++i) {
|
||||
addr[i] = (addr[i] + off) % N;
|
||||
}
|
||||
}
|
||||
|
||||
inline float* data_align(float* data_ori, unsigned point_num, unsigned& dim){
|
||||
#ifdef __GNUC__
|
||||
#ifdef __AVX__
|
||||
#define DATA_ALIGN_FACTOR 8
|
||||
#else
|
||||
#ifdef __SSE2__
|
||||
#define DATA_ALIGN_FACTOR 4
|
||||
#else
|
||||
#define DATA_ALIGN_FACTOR 1
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
||||
std::cout << "align with : "<<DATA_ALIGN_FACTOR << std::endl;
|
||||
float* data_new=0;
|
||||
unsigned new_dim = (dim + DATA_ALIGN_FACTOR - 1) / DATA_ALIGN_FACTOR * DATA_ALIGN_FACTOR;
|
||||
std::cout << "align to new dim: "<<new_dim << std::endl;
|
||||
#ifdef __APPLE__
|
||||
data_new = new float[new_dim * point_num];
|
||||
#else
|
||||
data_new = (float*)memalign(DATA_ALIGN_FACTOR * 4, point_num * new_dim * sizeof(float));
|
||||
#endif
|
||||
|
||||
for(unsigned i=0; i<point_num; i++){
|
||||
memcpy(data_new + i * new_dim, data_ori + i * dim, dim * sizeof(float));
|
||||
memset(data_new + i * new_dim + dim, 0, (new_dim - dim) * sizeof(float));
|
||||
}
|
||||
dim = new_dim;
|
||||
#ifdef __APPLE__
|
||||
delete[] data_ori;
|
||||
#else
|
||||
free(data_ori);
|
||||
#endif
|
||||
return data_new;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif //EFANNA2E_UTIL_H
|
|
@ -0,0 +1,15 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
#include <stdexcept>
|
||||
|
||||
namespace diskann {
|
||||
|
||||
class NotImplementedException : public std::logic_error {
|
||||
public:
|
||||
NotImplementedException()
|
||||
: std::logic_error("Function not yet implemented.") {
|
||||
}
|
||||
};
|
||||
}
|
|
@ -0,0 +1,201 @@
|
|||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "tsl/robin_set.h"
|
||||
|
||||
#include "distance.h"
|
||||
#include "neighbor.h"
|
||||
#include "parameters.h"
|
||||
#include "utils.h"
|
||||
#include "windows_customizations.h"
|
||||
|
||||
#define SLACK_FACTOR 1.3
|
||||
|
||||
#define ESTIMATE_RAM_USAGE(size, dim, datasize, degree) \
|
||||
(1.30 * (((double) size * dim) * datasize + \
|
||||
((double) size * degree) * sizeof(unsigned) * SLACK_FACTOR))
|
||||
|
||||
namespace diskann {
|
||||
template<typename T, typename TagT = int>
|
||||
class Index {
|
||||
public:
|
||||
DISKANN_DLLEXPORT Index(Metric m, const char *filename,
|
||||
const size_t max_points = 0, const size_t nd = 0,
|
||||
const size_t num_frozen_pts = 0,
|
||||
const bool enable_tags = false,
|
||||
const bool store_data = true,
|
||||
const bool support_eager_delete = false);
|
||||
DISKANN_DLLEXPORT ~Index();
|
||||
|
||||
// checks if data is consolidated, saves graph, metadata and associated
|
||||
// tags.
|
||||
DISKANN_DLLEXPORT void save(const char *filename);
|
||||
DISKANN_DLLEXPORT void load(const char *filename,
|
||||
const bool load_tags = false,
|
||||
const char *tag_filename = NULL);
|
||||
// generates one or more frozen points that will never get deleted from the
|
||||
// graph
|
||||
DISKANN_DLLEXPORT int generate_random_frozen_points(
|
||||
const char *filename = NULL);
|
||||
|
||||
DISKANN_DLLEXPORT void build(
|
||||
Parameters & parameters,
|
||||
const std::vector<TagT> &tags = std::vector<TagT>());
|
||||
|
||||
// Gopal. Added search overload that takes L as parameter, so that we
|
||||
// can customize L on a per-query basis without tampering with "Parameters"
|
||||
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search(const T *query,
|
||||
const size_t K,
|
||||
const unsigned L,
|
||||
unsigned *indices);
|
||||
|
||||
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search(
|
||||
const T *query, const uint64_t K, const unsigned L,
|
||||
std::vector<unsigned> init_ids, uint64_t *indices, float *distances);
|
||||
|
||||
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search_with_tags(
|
||||
const T *query, const size_t K, const unsigned L, TagT *tags,
|
||||
unsigned frozen_pts, unsigned *indices_buffer = NULL);
|
||||
|
||||
// repositions frozen points to the end of _data - if they have been moved
|
||||
// during deletion
|
||||
DISKANN_DLLEXPORT void readjust_data(unsigned _num_frozen_pts);
|
||||
|
||||
/* insertions possible only when id corresponding to tag does not already
|
||||
* exist in the graph */
|
||||
DISKANN_DLLEXPORT int insert_point(const T * point,
|
||||
const Parameters & parameter,
|
||||
std::vector<Neighbor> & pool,
|
||||
std::vector<Neighbor> & tmp,
|
||||
tsl::robin_set<unsigned> & visited,
|
||||
std::vector<SimpleNeighbor> &cut_graph,
|
||||
const TagT tag);
|
||||
|
||||
// call before triggering deleteions - sets important flags required for
|
||||
// deletion related operations
|
||||
DISKANN_DLLEXPORT int enable_delete();
|
||||
|
||||
// call after all delete requests have been served, checks if deletions were
|
||||
// executed correctly, rearranges metadata in case of lazy deletes
|
||||
DISKANN_DLLEXPORT int disable_delete(const Parameters ¶meters,
|
||||
const bool consolidate = false);
|
||||
|
||||
// Record deleted point now and restructure graph later. Return -1 if tag
|
||||
// not found, 0 if OK. Do not call if _eager_delete was called earlier and
|
||||
// data was not consolidated
|
||||
DISKANN_DLLEXPORT int delete_point(const TagT tag);
|
||||
|
||||
// Delete point from graph and restructure it immediately. Do not call if
|
||||
// _lazy_delete was called earlier and data was not consolidated
|
||||
DISKANN_DLLEXPORT int eager_delete(const TagT tag,
|
||||
const Parameters ¶meters);
|
||||
|
||||
/* Internals of the library */
|
||||
protected:
|
||||
typedef std::vector<SimpleNeighbor> vecNgh;
|
||||
typedef std::vector<std::vector<unsigned>> CompactGraph;
|
||||
CompactGraph _final_graph;
|
||||
CompactGraph _in_graph;
|
||||
|
||||
// determines navigating node of the graph by calculating medoid of data
|
||||
unsigned calculate_entry_point();
|
||||
// called only when _eager_delete is to be supported
|
||||
void update_in_graph();
|
||||
|
||||
std::pair<uint32_t, uint32_t> iterate_to_fixed_point(
|
||||
const T *node_coords, const unsigned Lindex,
|
||||
const std::vector<unsigned> &init_ids,
|
||||
std::vector<Neighbor> & expanded_nodes_info,
|
||||
tsl::robin_set<unsigned> & expanded_nodes_ids,
|
||||
std::vector<Neighbor> & best_L_nodes);
|
||||
|
||||
void get_expanded_nodes(const size_t node, const unsigned Lindex,
|
||||
std::vector<unsigned> init_ids,
|
||||
std::vector<Neighbor> & expanded_nodes_info,
|
||||
tsl::robin_set<unsigned> &expanded_nodes_ids);
|
||||
|
||||
void inter_insert(unsigned n, std::vector<unsigned> &pruned_list,
|
||||
const Parameters ¶meter, bool update_in_graph);
|
||||
|
||||
void prune_neighbors(const unsigned location, std::vector<Neighbor> &pool,
|
||||
const Parameters & parameter,
|
||||
std::vector<unsigned> &pruned_list);
|
||||
|
||||
void occlude_list(std::vector<Neighbor> &pool, const unsigned location,
|
||||
const float alpha, const unsigned degree,
|
||||
const unsigned maxc, std::vector<Neighbor> &result);
|
||||
|
||||
void occlude_list(std::vector<Neighbor> &pool, const unsigned location,
|
||||
const float alpha, const unsigned degree,
|
||||
const unsigned maxc, std::vector<Neighbor> &result,
|
||||
std::vector<float> &occlude_factor);
|
||||
|
||||
void batch_inter_insert(unsigned n,
|
||||
const std::vector<unsigned> &pruned_list,
|
||||
const Parameters & parameter,
|
||||
std::vector<unsigned> & need_to_sync);
|
||||
|
||||
void link(Parameters ¶meters);
|
||||
|
||||
// WARNING: Do not call reserve_location() without acquiring change_lock_
|
||||
unsigned reserve_location();
|
||||
|
||||
// get new location corresponding to each undeleted tag after deletions
|
||||
std::vector<unsigned> get_new_location(unsigned &active);
|
||||
|
||||
// renumber nodes, update tag and location maps and compact the graph, mode
|
||||
// = _consolidated_order in case of lazy deletion and _compacted_order in
|
||||
// case of eager deletion
|
||||
void compact_data(std::vector<unsigned> new_location, unsigned active,
|
||||
bool &mode);
|
||||
|
||||
// WARNING: Do not call consolidate_deletes without acquiring change_lock_
|
||||
// Returns number of live points left after consolidation
|
||||
size_t consolidate_deletes(const Parameters ¶meters);
|
||||
|
||||
private:
|
||||
size_t _dim;
|
||||
size_t _aligned_dim;
|
||||
T * _data;
|
||||
size_t _nd; // number of active points i.e. existing in the graph
|
||||
size_t _max_points; // total number of points in given data set
|
||||
size_t _num_frozen_pts;
|
||||
bool _has_built;
|
||||
Distance<T> *_distance;
|
||||
unsigned _width;
|
||||
unsigned _ep;
|
||||
bool _saturate_graph = false;
|
||||
std::vector<std::mutex> _locks; // Per node lock, cardinality=max_points_
|
||||
|
||||
char * _opt_graph;
|
||||
size_t _node_size;
|
||||
size_t _data_len;
|
||||
size_t _neighbor_len;
|
||||
|
||||
bool _can_delete;
|
||||
bool _eager_done; // true if eager deletions have been made
|
||||
bool _lazy_done; // true if lazy deletions have been made
|
||||
bool _compacted_order; // true if after eager deletions, data has been
|
||||
// consolidated
|
||||
bool _enable_tags;
|
||||
bool _consolidated_order; // true if after lazy deletions, data has been
|
||||
// consolidated
|
||||
bool _support_eager_delete; //_support_eager_delete = activates extra data
|
||||
// structures and functions required for eager
|
||||
// deletion
|
||||
bool _store_data;
|
||||
|
||||
std::unordered_map<TagT, unsigned> _tag_to_location;
|
||||
std::unordered_map<unsigned, TagT> _location_to_tag;
|
||||
|
||||
tsl::robin_set<unsigned> _delete_set;
|
||||
tsl::robin_set<unsigned> _empty_slots;
|
||||
|
||||
std::mutex _change_lock; // Allow only 1 thread to insert/delete
|
||||
};
|
||||
} // namespace diskann
|
|
@ -0,0 +1,38 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
#ifndef _WINDOWS
|
||||
|
||||
#include "aligned_file_reader.h"
|
||||
|
||||
class LinuxAlignedFileReader : public AlignedFileReader {
|
||||
private:
|
||||
uint64_t file_sz;
|
||||
FileHandle file_desc;
|
||||
io_context_t bad_ctx = (io_context_t) -1;
|
||||
|
||||
public:
|
||||
LinuxAlignedFileReader();
|
||||
~LinuxAlignedFileReader();
|
||||
|
||||
IOContext &get_ctx();
|
||||
|
||||
// register thread-id for a context
|
||||
void register_thread();
|
||||
|
||||
// de-register thread-id for a context
|
||||
void deregister_thread();
|
||||
|
||||
// Open & close ops
|
||||
// Blocking calls
|
||||
void open(const std::string &fname);
|
||||
void close();
|
||||
|
||||
// process batch of aligned requests in parallel
|
||||
// NOTE :: blocking call
|
||||
void read(std::vector<AlignedRead> &read_reqs, IOContext &ctx,
|
||||
bool async = false);
|
||||
};
|
||||
|
||||
#endif
|
|
@ -0,0 +1,16 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "windows_customizations.h"
|
||||
|
||||
namespace diskann {
|
||||
#if defined(DISKANN_DLL)
|
||||
extern std::basic_ostream<char> cout;
|
||||
extern std::basic_ostream<char> cerr;
|
||||
#else
|
||||
DISKANN_DLLIMPORT extern std::basic_ostream<char> cout;
|
||||
DISKANN_DLLIMPORT extern std::basic_ostream<char> cerr;
|
||||
#endif
|
||||
|
||||
} // namespace diskann
|
|
@ -0,0 +1,77 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <sstream>
|
||||
#include <mutex>
|
||||
|
||||
#ifdef EXEC_ENV_OLS
|
||||
#include "IANNIndex.h"
|
||||
#include "ANNLogging.h"
|
||||
#endif
|
||||
|
||||
#include "ann_exception.h"
|
||||
|
||||
#ifndef EXEC_ENV_OLS
|
||||
namespace ANNIndex {
|
||||
enum LogLevel {
|
||||
LL_Debug = 0,
|
||||
LL_Info,
|
||||
LL_Status,
|
||||
LL_Warning,
|
||||
LL_Error,
|
||||
LL_Assert,
|
||||
LL_Count
|
||||
};
|
||||
};
|
||||
#endif
|
||||
|
||||
namespace diskann {
|
||||
class ANNStreamBuf : public std::basic_streambuf<char> {
|
||||
public:
|
||||
DISKANN_DLLEXPORT explicit ANNStreamBuf(FILE* fp);
|
||||
DISKANN_DLLEXPORT ~ANNStreamBuf();
|
||||
|
||||
DISKANN_DLLEXPORT bool is_open() const {
|
||||
return true; // because stdout and stderr are always open.
|
||||
}
|
||||
DISKANN_DLLEXPORT void close();
|
||||
DISKANN_DLLEXPORT virtual int underflow();
|
||||
DISKANN_DLLEXPORT virtual int overflow(int c);
|
||||
DISKANN_DLLEXPORT virtual int sync();
|
||||
|
||||
private:
|
||||
FILE* _fp;
|
||||
char* _buf;
|
||||
int _bufIndex;
|
||||
std::mutex _mutex;
|
||||
ANNIndex::LogLevel _logLevel;
|
||||
|
||||
int flush();
|
||||
void logImpl(char* str, int numchars);
|
||||
|
||||
// Why the two buffer-sizes? If we are running normally, we are basically
|
||||
// interacting with a character output system, so we short-circuit the
|
||||
// output process by keeping an empty buffer and writing each character
|
||||
// to stdout/stderr. But if we are running in OLS, we have to take all
|
||||
// the text that is written to diskann::cout/diskann:cerr, consolidate it
|
||||
// and push it out in one-shot, because the OLS infra does not give us
|
||||
// character based output. Therefore, we use a larger buffer that is large
|
||||
// enough to store the longest message, and continuously add characters
|
||||
// to it. When the calling code outputs a std::endl or std::flush, sync()
|
||||
// will be called and will output a log level, component name, and the text
|
||||
// that has been collected. (sync() is also called if the buffer is full, so
|
||||
// overflows/missing text are not a concern).
|
||||
// This implies calling code _must_ either print std::endl or std::flush
|
||||
// to ensure that the message is written immediately.
|
||||
#ifdef EXEC_ENV_OLS
|
||||
static const int BUFFER_SIZE = 1024;
|
||||
#else
|
||||
static const int BUFFER_SIZE = 0;
|
||||
#endif
|
||||
|
||||
ANNStreamBuf(const ANNStreamBuf&);
|
||||
ANNStreamBuf& operator=(const ANNStreamBuf&);
|
||||
};
|
||||
} // namespace diskann
|
|
@ -0,0 +1,97 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
#include <mkl.h>
|
||||
|
||||
#include "common_includes.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace math_utils {
|
||||
|
||||
float calc_distance(float* vec_1, float* vec_2, size_t dim);
|
||||
|
||||
// compute l2-squared norms of data stored in row major num_points * dim,
|
||||
// needs
|
||||
// to be pre-allocated
|
||||
void compute_vecs_l2sq(float* vecs_l2sq, float* data, const size_t num_points,
|
||||
const size_t dim);
|
||||
|
||||
void rotate_data_randomly(float* data, size_t num_points, size_t dim,
|
||||
float* rot_mat, float*& new_mat,
|
||||
bool transpose_rot = false);
|
||||
|
||||
// calculate closest center to data of num_points * dim (row major)
|
||||
// centers is num_centers * dim (row major)
|
||||
// data_l2sq has pre-computed squared norms of data
|
||||
// centers_l2sq has pre-computed squared norms of centers
|
||||
// pre-allocated center_index will contain id of k nearest centers
|
||||
// pre-allocated dist_matrix shound be num_points * num_centers and contain
|
||||
// squared distances
|
||||
|
||||
// Ideally used only by compute_closest_centers
|
||||
void compute_closest_centers_in_block(
|
||||
const float* const data, const size_t num_points, const size_t dim,
|
||||
const float* const centers, const size_t num_centers,
|
||||
const float* const docs_l2sq, const float* const centers_l2sq,
|
||||
uint32_t* center_index, float* const dist_matrix, size_t k = 1);
|
||||
|
||||
// Given data in num_points * new_dim row major
|
||||
// Pivots stored in full_pivot_data as k * new_dim row major
|
||||
// Calculate the closest pivot for each point and store it in vector
|
||||
// closest_centers_ivf (which needs to be allocated outside)
|
||||
// Additionally, if inverted index is not null (and pre-allocated), it will
|
||||
// return inverted index for each center Additionally, if pts_norms_squared is
|
||||
// not null, then it will assume that point norms are pre-computed and use
|
||||
// those
|
||||
// values
|
||||
|
||||
void compute_closest_centers(float* data, size_t num_points, size_t dim,
|
||||
float* pivot_data, size_t num_centers, size_t k,
|
||||
uint32_t* closest_centers_ivf,
|
||||
std::vector<size_t>* inverted_index = NULL,
|
||||
float* pts_norms_squared = NULL);
|
||||
|
||||
// if to_subtract is 1, will subtract nearest center from each row. Else will
|
||||
// add. Output will be in data_load iself.
|
||||
// Nearest centers need to be provided in closst_centers.
|
||||
|
||||
void process_residuals(float* data_load, size_t num_points, size_t dim,
|
||||
float* cur_pivot_data, size_t num_centers,
|
||||
uint32_t* closest_centers, bool to_subtract);
|
||||
|
||||
} // namespace math_utils
|
||||
|
||||
namespace kmeans {
|
||||
|
||||
// run Lloyds one iteration
|
||||
// Given data in row major num_points * dim, and centers in row major
|
||||
// num_centers * dim
|
||||
// And squared lengths of data points, output the closest center to each data
|
||||
// point, update centers, and also return inverted index.
|
||||
// If closest_centers == NULL, will allocate memory and return.
|
||||
// Similarly, if closest_docs == NULL, will allocate memory and return.
|
||||
|
||||
float lloyds_iter(float* data, size_t num_points, size_t dim, float* centers,
|
||||
size_t num_centers, float* docs_l2sq,
|
||||
std::vector<size_t>* closest_docs,
|
||||
uint32_t*& closest_center);
|
||||
|
||||
// Run Lloyds until max_reps or stopping criterion
|
||||
// If you pass NULL for closest_docs and closest_center, it will NOT return
|
||||
// the results, else it will assume appriate allocation as closest_docs = new
|
||||
// vector<size_t> [num_centers], and closest_center = new size_t[num_points]
|
||||
// Final centers are output in centers as row major num_centers * dim
|
||||
//
|
||||
float run_lloyds(float* data, size_t num_points, size_t dim, float* centers,
|
||||
const size_t num_centers, const size_t max_reps,
|
||||
std::vector<size_t>* closest_docs, uint32_t* closest_center);
|
||||
|
||||
// assumes already memory allocated for pivot_data as new
|
||||
// float[num_centers*dim] and select randomly num_centers points as pivots
|
||||
void selecting_pivots(float* data, size_t num_points, size_t dim,
|
||||
float* pivot_data, size_t num_centers);
|
||||
|
||||
void kmeanspp_selecting_pivots(float* data, size_t num_points, size_t dim,
|
||||
float* pivot_data, size_t num_centers);
|
||||
} // namespace kmeans
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <fcntl.h>
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#else
|
||||
#include <Windows.h>
|
||||
#endif
|
||||
#include <string>
|
||||
|
||||
namespace diskann {
|
||||
class MemoryMapper {
|
||||
private:
|
||||
#ifndef _WINDOWS
|
||||
int _fd;
|
||||
#else
|
||||
HANDLE _bareFile;
|
||||
HANDLE _fd;
|
||||
|
||||
#endif
|
||||
char* _buf;
|
||||
size_t _fileSize;
|
||||
const char* _fileName;
|
||||
|
||||
public:
|
||||
MemoryMapper(const char* filename);
|
||||
MemoryMapper(const std::string& filename);
|
||||
|
||||
char* getBuf();
|
||||
size_t getFileSize();
|
||||
|
||||
~MemoryMapper();
|
||||
};
|
||||
}
|
|
@ -0,0 +1,145 @@
|
|||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
#include "utils.h"
|
||||
|
||||
namespace diskann {
|
||||
|
||||
struct Neighbor {
|
||||
unsigned id;
|
||||
float distance;
|
||||
bool flag;
|
||||
|
||||
Neighbor() = default;
|
||||
Neighbor(unsigned id, float distance, bool f)
|
||||
: id{id}, distance{distance}, flag(f) {
|
||||
}
|
||||
|
||||
inline bool operator<(const Neighbor &other) const {
|
||||
return distance < other.distance;
|
||||
}
|
||||
inline bool operator==(const Neighbor &other) const {
|
||||
return (id == other.id);
|
||||
}
|
||||
};
|
||||
|
||||
typedef std::lock_guard<std::mutex> LockGuard;
|
||||
struct nhood {
|
||||
std::mutex lock;
|
||||
std::vector<Neighbor> pool;
|
||||
unsigned M;
|
||||
|
||||
std::vector<unsigned> nn_old;
|
||||
std::vector<unsigned> nn_new;
|
||||
std::vector<unsigned> rnn_old;
|
||||
std::vector<unsigned> rnn_new;
|
||||
|
||||
nhood() {
|
||||
}
|
||||
nhood(unsigned l, unsigned s, std::mt19937 &rng, unsigned N) {
|
||||
M = s;
|
||||
nn_new.resize(s * 2);
|
||||
GenRandom(rng, &nn_new[0], (unsigned) nn_new.size(), N);
|
||||
nn_new.reserve(s * 2);
|
||||
pool.reserve(l);
|
||||
}
|
||||
|
||||
nhood(const nhood &other) {
|
||||
M = other.M;
|
||||
std::copy(other.nn_new.begin(), other.nn_new.end(),
|
||||
std::back_inserter(nn_new));
|
||||
nn_new.reserve(other.nn_new.capacity());
|
||||
pool.reserve(other.pool.capacity());
|
||||
}
|
||||
void insert(unsigned id, float dist) {
|
||||
LockGuard guard(lock);
|
||||
if (dist > pool.front().distance)
|
||||
return;
|
||||
for (unsigned i = 0; i < pool.size(); i++) {
|
||||
if (id == pool[i].id)
|
||||
return;
|
||||
}
|
||||
if (pool.size() < pool.capacity()) {
|
||||
pool.push_back(Neighbor(id, dist, true));
|
||||
std::push_heap(pool.begin(), pool.end());
|
||||
} else {
|
||||
std::pop_heap(pool.begin(), pool.end());
|
||||
pool[pool.size() - 1] = Neighbor(id, dist, true);
|
||||
std::push_heap(pool.begin(), pool.end());
|
||||
}
|
||||
}
|
||||
|
||||
template<typename C>
|
||||
void join(C callback) const {
|
||||
for (unsigned const i : nn_new) {
|
||||
for (unsigned const j : nn_new) {
|
||||
if (i < j) {
|
||||
callback(i, j);
|
||||
}
|
||||
}
|
||||
for (unsigned j : nn_old) {
|
||||
callback(i, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct SimpleNeighbor {
|
||||
unsigned id;
|
||||
float distance;
|
||||
|
||||
SimpleNeighbor() = default;
|
||||
SimpleNeighbor(unsigned id, float distance) : id(id), distance(distance) {
|
||||
}
|
||||
|
||||
inline bool operator<(const SimpleNeighbor &other) const {
|
||||
return distance < other.distance;
|
||||
}
|
||||
|
||||
inline bool operator==(const SimpleNeighbor &other) const {
|
||||
return id == other.id;
|
||||
}
|
||||
};
|
||||
struct SimpleNeighbors {
|
||||
std::vector<SimpleNeighbor> pool;
|
||||
};
|
||||
|
||||
static inline unsigned InsertIntoPool(Neighbor *addr, unsigned K,
|
||||
Neighbor nn) {
|
||||
// find the location to insert
|
||||
unsigned left = 0, right = K - 1;
|
||||
if (addr[left].distance > nn.distance) {
|
||||
memmove((char *) &addr[left + 1], &addr[left], K * sizeof(Neighbor));
|
||||
addr[left] = nn;
|
||||
return left;
|
||||
}
|
||||
if (addr[right].distance < nn.distance) {
|
||||
addr[K] = nn;
|
||||
return K;
|
||||
}
|
||||
while (right > 1 && left < right - 1) {
|
||||
unsigned mid = (left + right) / 2;
|
||||
if (addr[mid].distance > nn.distance)
|
||||
right = mid;
|
||||
else
|
||||
left = mid;
|
||||
}
|
||||
// check equal ID
|
||||
|
||||
while (left > 0) {
|
||||
if (addr[left].distance < nn.distance)
|
||||
break;
|
||||
if (addr[left].id == nn.id)
|
||||
return K + 1;
|
||||
left--;
|
||||
}
|
||||
if (addr[left].id == nn.id || addr[right].id == nn.id)
|
||||
return K + 1;
|
||||
memmove((char *) &addr[right + 1], &addr[right],
|
||||
(K - right) * sizeof(Neighbor));
|
||||
addr[right] = nn;
|
||||
return right;
|
||||
}
|
||||
} // namespace diskann
|
|
@ -0,0 +1,77 @@
|
|||
#pragma once
|
||||
#include <sstream>
|
||||
#include <typeinfo>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace diskann {
|
||||
|
||||
class Parameters {
|
||||
public:
|
||||
Parameters() {
|
||||
int *p = new int;
|
||||
*p = 0;
|
||||
params["num_threads"] = p;
|
||||
}
|
||||
|
||||
template<typename ParamType>
|
||||
inline void Set(const std::string &name, const ParamType &value) {
|
||||
// ParamType *ptr = (ParamType *) malloc(sizeof(ParamType));
|
||||
ParamType *ptr = new ParamType;
|
||||
*ptr = value;
|
||||
params[name] = (void *) ptr;
|
||||
}
|
||||
|
||||
template<typename ParamType>
|
||||
inline ParamType Get(const std::string &name) const {
|
||||
auto item = params.find(name);
|
||||
if (item == params.end()) {
|
||||
throw std::invalid_argument("Invalid parameter name.");
|
||||
} else {
|
||||
// return ConvertStrToValue<ParamType>(item->second);
|
||||
if (item->second == nullptr) {
|
||||
throw std::invalid_argument(std::string("Parameter ") + name +
|
||||
" has value null.");
|
||||
} else {
|
||||
return *(static_cast<ParamType *>(item->second));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename ParamType>
|
||||
inline ParamType Get(const std::string &name,
|
||||
const ParamType & default_value) {
|
||||
try {
|
||||
return Get<ParamType>(name);
|
||||
} catch (std::invalid_argument e) {
|
||||
return default_value;
|
||||
}
|
||||
}
|
||||
|
||||
~Parameters() {
|
||||
for (auto iter = params.begin(); iter != params.end(); iter++) {
|
||||
if (iter->second != nullptr)
|
||||
free(iter->second);
|
||||
// delete iter->second;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, void *> params;
|
||||
|
||||
Parameters(const Parameters &);
|
||||
Parameters &operator=(const Parameters &);
|
||||
|
||||
template<typename ParamType>
|
||||
inline ParamType ConvertStrToValue(const std::string &str) const {
|
||||
std::stringstream sstream(str);
|
||||
ParamType value;
|
||||
if (!(sstream >> value) || !sstream.eof()) {
|
||||
std::stringstream err;
|
||||
err << "Failed to convert value '" << str
|
||||
<< "' to type: " << typeid(value).name();
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
return value;
|
||||
}
|
||||
};
|
||||
} // namespace diskann
|
|
@ -0,0 +1,63 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
#include <cassert>
|
||||
#include <sstream>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "neighbor.h"
|
||||
#include "parameters.h"
|
||||
#include "tsl/robin_set.h"
|
||||
#include "utils.h"
|
||||
|
||||
#include "windows_customizations.h"
|
||||
|
||||
template<typename T>
|
||||
void gen_random_slice(const std::string base_file,
|
||||
const std::string output_prefix, double sampling_rate);
|
||||
|
||||
template<typename T>
|
||||
void gen_random_slice(const std::string data_file, double p_val,
|
||||
float *&sampled_data, size_t &slice_size, size_t &ndims);
|
||||
|
||||
template<typename T>
|
||||
void gen_random_slice(const T *inputdata, size_t npts, size_t ndims,
|
||||
double p_val, float *&sampled_data, size_t &slice_size);
|
||||
|
||||
template<typename T>
|
||||
int estimate_cluster_sizes(const std::string data_file, float *pivots,
|
||||
const size_t num_centers, const size_t dim,
|
||||
const size_t k_base,
|
||||
std::vector<size_t> &cluster_sizes);
|
||||
|
||||
template<typename T>
|
||||
int shard_data_into_clusters(const std::string data_file, float *pivots,
|
||||
const size_t num_centers, const size_t dim,
|
||||
const size_t k_base, std::string prefix_path);
|
||||
|
||||
template<typename T>
|
||||
int partition(const std::string data_file, const float sampling_rate,
|
||||
size_t num_centers, size_t max_k_means_reps,
|
||||
const std::string prefix_path, size_t k_base);
|
||||
|
||||
template<typename T>
|
||||
int partition_with_ram_budget(const std::string data_file,
|
||||
const double sampling_rate, double ram_budget,
|
||||
size_t graph_degree,
|
||||
const std::string prefix_path, size_t k_base);
|
||||
|
||||
DISKANN_DLLEXPORT int generate_pq_pivots(const float *train_data,
|
||||
size_t num_train, unsigned dim,
|
||||
unsigned num_centers,
|
||||
unsigned num_pq_chunks,
|
||||
unsigned max_k_means_reps,
|
||||
std::string pq_pivots_path);
|
||||
|
||||
template<typename T>
|
||||
int generate_pq_data_from_pivots(const std::string data_file,
|
||||
unsigned num_centers, unsigned num_pq_chunks,
|
||||
std::string pq_pivots_path,
|
||||
std::string pq_compressed_vectors_path);
|
|
@ -0,0 +1,61 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#ifdef _WINDOWS
|
||||
#include <numeric>
|
||||
#endif
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "distance.h"
|
||||
#include "parameters.h"
|
||||
|
||||
namespace diskann {
|
||||
struct QueryStats {
|
||||
double total_us = 0; // total time to process query in micros
|
||||
double n_4k = 0; // # of 4kB reads
|
||||
double n_8k = 0; // # of 8kB reads
|
||||
double n_12k = 0; // # of 12kB reads
|
||||
double n_ios = 0; // total # of IOs issued
|
||||
double read_size = 0; // total # of bytes read
|
||||
double io_us = 0; // total time spent in IO
|
||||
double cpu_us = 0; // total time spent in CPU
|
||||
double n_cmps_saved = 0; // # cmps saved
|
||||
double n_cmps = 0; // # cmps
|
||||
double n_cache_hits = 0; // # cache_hits
|
||||
double n_hops = 0; // # search hops
|
||||
};
|
||||
|
||||
inline double get_percentile_stats(
|
||||
QueryStats *stats, uint64_t len, float percentile,
|
||||
const std::function<double(const QueryStats &)> &member_fn) {
|
||||
std::vector<double> vals(len);
|
||||
for (uint64_t i = 0; i < len; i++) {
|
||||
vals[i] = member_fn(stats[i]);
|
||||
}
|
||||
|
||||
std::sort(
|
||||
vals.begin(), vals.end(),
|
||||
[](const double &left, const double &right) { return left < right; });
|
||||
|
||||
auto retval = vals[(uint64_t)(percentile * len)];
|
||||
vals.clear();
|
||||
return retval;
|
||||
}
|
||||
|
||||
inline double get_mean_stats(
|
||||
QueryStats *stats, uint64_t len,
|
||||
const std::function<double(const QueryStats &)> &member_fn) {
|
||||
double avg = 0;
|
||||
for (uint64_t i = 0; i < len; i++) {
|
||||
avg += member_fn(stats[i]);
|
||||
}
|
||||
return avg / len;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,156 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
namespace diskann {
|
||||
template<typename T>
|
||||
class FixedChunkPQTable {
|
||||
// data_dim = n_chunks * chunk_size;
|
||||
float* tables =
|
||||
nullptr; // pq_tables = float* [[2^8 * [chunk_size]] * n_chunks]
|
||||
// _u64 n_chunks; // n_chunks = # of chunks ndims is split into
|
||||
// _u64 chunk_size; // chunk_size = chunk size of each dimension chunk
|
||||
_u64 ndims; // ndims = chunk_size * n_chunks
|
||||
_u64 n_chunks;
|
||||
_u32* chunk_offsets = nullptr;
|
||||
_u32* rearrangement = nullptr;
|
||||
float* centroid = nullptr;
|
||||
float* tables_T = nullptr; // same as pq_tables, but col-major
|
||||
public:
|
||||
FixedChunkPQTable() {
|
||||
}
|
||||
|
||||
virtual ~FixedChunkPQTable() {
|
||||
#ifndef EXEC_ENV_OLS
|
||||
if (tables != nullptr)
|
||||
delete[] tables;
|
||||
if (tables_T != nullptr)
|
||||
delete[] tables_T;
|
||||
if (rearrangement != nullptr)
|
||||
delete[] rearrangement;
|
||||
if (chunk_offsets != nullptr)
|
||||
delete[] chunk_offsets;
|
||||
if (centroid != nullptr)
|
||||
delete[] centroid;
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef EXEC_ENV_OLS
|
||||
void load_pq_centroid_bin(MemoryMappedFiles& files,
|
||||
const char* pq_table_file, size_t num_chunks){
|
||||
#else
|
||||
void load_pq_centroid_bin(const char* pq_table_file, size_t num_chunks) {
|
||||
#endif
|
||||
std::string rearrangement_file = std::string(pq_table_file) +
|
||||
"_rearrangement_perm.bin";
|
||||
std::string chunk_offset_file =
|
||||
std::string(pq_table_file) + "_chunk_offsets.bin";
|
||||
std::string centroid_file = std::string(pq_table_file) + "_centroid.bin";
|
||||
|
||||
// bin structure: [256][ndims][ndims(float)]
|
||||
uint64_t numr, numc;
|
||||
size_t npts_u64, ndims_u64;
|
||||
#ifdef EXEC_ENV_OLS
|
||||
diskann::load_bin<float>(files, pq_table_file, tables, npts_u64, ndims_u64);
|
||||
#else
|
||||
diskann::load_bin<float>(pq_table_file, tables, npts_u64, ndims_u64);
|
||||
#endif
|
||||
this->ndims = ndims_u64;
|
||||
|
||||
if (file_exists(chunk_offset_file)) {
|
||||
#ifdef EXEC_ENV_OLS
|
||||
diskann::load_bin<_u32>(files, rearrangement_file, rearrangement, numr,
|
||||
numc);
|
||||
#else
|
||||
diskann::load_bin<_u32>(rearrangement_file, rearrangement, numr, numc);
|
||||
#endif
|
||||
if (numr != ndims_u64 || numc != 1) {
|
||||
diskann::cerr << "Error loading rearrangement file" << std::endl;
|
||||
throw diskann::ANNException("Error loading rearrangement file", -1,
|
||||
__FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
|
||||
#ifdef EXEC_ENV_OLS
|
||||
diskann::load_bin<_u32>(files, chunk_offset_file, chunk_offsets, numr,
|
||||
numc);
|
||||
#else
|
||||
diskann::load_bin<_u32>(chunk_offset_file, chunk_offsets, numr, numc);
|
||||
#endif
|
||||
if (numc != 1 || numr != num_chunks + 1) {
|
||||
diskann::cerr << "Error loading chunk offsets file. numc: " << numc
|
||||
<< " (should be 1). numr: " << numr << " (should be "
|
||||
<< num_chunks + 1 << ")" << std::endl;
|
||||
throw diskann::ANNException("Error loading chunk offsets file", -1,
|
||||
__FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
|
||||
this->n_chunks = numr - 1;
|
||||
|
||||
#ifdef EXEC_ENV_OLS
|
||||
diskann::load_bin<float>(files, centroid_file, centroid, numr, numc);
|
||||
#else
|
||||
diskann::load_bin<float>(centroid_file, centroid, numr, numc);
|
||||
#endif
|
||||
if (numc != 1 || numr != ndims_u64) {
|
||||
diskann::cerr << "Error loading centroid file" << std::endl;
|
||||
throw diskann::ANNException("Error loading centroid file", -1,
|
||||
__FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
} else {
|
||||
this->n_chunks = num_chunks;
|
||||
rearrangement = new uint32_t[ndims];
|
||||
|
||||
uint64_t chunk_size = DIV_ROUND_UP(ndims, num_chunks);
|
||||
for (uint32_t d = 0; d < ndims; d++)
|
||||
rearrangement[d] = d;
|
||||
chunk_offsets = new uint32_t[num_chunks + 1];
|
||||
for (uint32_t d = 0; d <= num_chunks; d++)
|
||||
chunk_offsets[d] = (_u32)(std::min)(ndims, d * chunk_size);
|
||||
centroid = new float[ndims];
|
||||
std::memset(centroid, 0, ndims * sizeof(float));
|
||||
}
|
||||
|
||||
diskann::cout << "PQ Pivots: #ctrs: " << npts_u64
|
||||
<< ", #dims: " << ndims_u64 << ", #chunks: " << n_chunks
|
||||
<< std::endl;
|
||||
// assert((_u64) ndims_u32 == n_chunks * chunk_size);
|
||||
// alloc and compute transpose
|
||||
tables_T = new float[256 * ndims_u64];
|
||||
for (_u64 i = 0; i < 256; i++) {
|
||||
for (_u64 j = 0; j < ndims_u64; j++) {
|
||||
tables_T[j * 256 + i] = tables[i * ndims_u64 + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
populate_chunk_distances(const T* query_vec, float* dist_vec) {
|
||||
memset(dist_vec, 0, 256 * n_chunks * sizeof(float));
|
||||
// chunk wise distance computation
|
||||
for (_u64 chunk = 0; chunk < n_chunks; chunk++) {
|
||||
// sum (q-c)^2 for the dimensions associated with this chunk
|
||||
float* chunk_dists = dist_vec + (256 * chunk);
|
||||
for (_u64 j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) {
|
||||
_u64 permuted_dim_in_query = rearrangement[j];
|
||||
const float* centers_dim_vec = tables_T + (256 * j);
|
||||
for (_u64 idx = 0; idx < 256; idx++) {
|
||||
// Gopal. Fixing crash in v14 machines.
|
||||
// float diff = centers_dim_vec[idx] -
|
||||
// ((float) query_vec[permuted_dim_in_query] -
|
||||
// centroid[permuted_dim_in_query]);
|
||||
// chunk_dists[idx] += (diff * diff);
|
||||
double diff =
|
||||
centers_dim_vec[idx] - (query_vec[permuted_dim_in_query] -
|
||||
centroid[permuted_dim_in_query]);
|
||||
chunk_dists[idx] += (float) (diff * diff);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace diskann
|
|
@ -0,0 +1,22 @@
|
|||
#include <chrono>
|
||||
|
||||
namespace diskann {
|
||||
class Timer {
|
||||
typedef std::chrono::high_resolution_clock _clock;
|
||||
std::chrono::time_point<_clock> check_point;
|
||||
|
||||
public:
|
||||
Timer() : check_point(_clock::now()) {
|
||||
}
|
||||
|
||||
void reset() {
|
||||
check_point = _clock::now();
|
||||
}
|
||||
|
||||
long long elapsed() const {
|
||||
return std::chrono::duration_cast<std::chrono::microseconds>(
|
||||
_clock::now() - check_point)
|
||||
.count();
|
||||
}
|
||||
};
|
||||
}
|
|
@ -0,0 +1,330 @@
|
|||
/**
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2017 Tessil
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*/
|
||||
#ifndef TSL_ROBIN_GROWTH_POLICY_H
|
||||
#define TSL_ROBIN_GROWTH_POLICY_H
|
||||
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <iterator>
|
||||
#include <limits>
|
||||
#include <ratio>
|
||||
#include <stdexcept>
|
||||
|
||||
|
||||
#ifndef tsl_assert
|
||||
# ifdef TSL_DEBUG
|
||||
# define tsl_assert(expr) assert(expr)
|
||||
# else
|
||||
# define tsl_assert(expr) (static_cast<void>(0))
|
||||
# endif
|
||||
#endif
|
||||
|
||||
|
||||
/**
|
||||
* If exceptions are enabled, throw the exception passed in parameter, otherwise call std::terminate.
|
||||
*/
|
||||
#ifndef TSL_THROW_OR_TERMINATE
|
||||
# if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || (defined (_MSC_VER) && defined (_CPPUNWIND))) && !defined(TSL_NO_EXCEPTIONS)
|
||||
# define TSL_THROW_OR_TERMINATE(ex, msg) throw ex(msg)
|
||||
# else
|
||||
# ifdef NDEBUG
|
||||
# define TSL_THROW_OR_TERMINATE(ex, msg) std::terminate()
|
||||
# else
|
||||
# include <cstdio>
|
||||
# define TSL_THROW_OR_TERMINATE(ex, msg) do { std::fprintf(stderr, msg); std::terminate(); } while(0)
|
||||
# endif
|
||||
# endif
|
||||
#endif
|
||||
|
||||
|
||||
#ifndef TSL_LIKELY
|
||||
# if defined(__GNUC__) || defined(__clang__)
|
||||
# define TSL_LIKELY(exp) (__builtin_expect(!!(exp), true))
|
||||
# else
|
||||
# define TSL_LIKELY(exp) (exp)
|
||||
# endif
|
||||
#endif
|
||||
|
||||
|
||||
namespace tsl {
|
||||
namespace rh {
|
||||
|
||||
/**
|
||||
* Grow the hash table by a factor of GrowthFactor keeping the bucket count to a power of two. It allows
|
||||
* the table to use a mask operation instead of a modulo operation to map a hash to a bucket.
|
||||
*
|
||||
* GrowthFactor must be a power of two >= 2.
|
||||
*/
|
||||
template<std::size_t GrowthFactor>
|
||||
class power_of_two_growth_policy {
|
||||
public:
|
||||
/**
|
||||
* Called on the hash table creation and on rehash. The number of buckets for the table is passed in parameter.
|
||||
* This number is a minimum, the policy may update this value with a higher value if needed (but not lower).
|
||||
*
|
||||
* If 0 is given, min_bucket_count_in_out must still be 0 after the policy creation and
|
||||
* bucket_for_hash must always return 0 in this case.
|
||||
*/
|
||||
explicit power_of_two_growth_policy(std::size_t& min_bucket_count_in_out) {
|
||||
if(min_bucket_count_in_out > max_bucket_count()) {
|
||||
TSL_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size.");
|
||||
}
|
||||
|
||||
if(min_bucket_count_in_out > 0) {
|
||||
min_bucket_count_in_out = round_up_to_power_of_two(min_bucket_count_in_out);
|
||||
m_mask = min_bucket_count_in_out - 1;
|
||||
}
|
||||
else {
|
||||
m_mask = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the bucket [0, bucket_count()) to which the hash belongs.
|
||||
* If bucket_count() is 0, it must always return 0.
|
||||
*/
|
||||
std::size_t bucket_for_hash(std::size_t hash) const noexcept {
|
||||
return hash & m_mask;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the number of buckets that should be used on next growth.
|
||||
*/
|
||||
std::size_t next_bucket_count() const {
|
||||
if((m_mask + 1) > max_bucket_count() / GrowthFactor) {
|
||||
TSL_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size.");
|
||||
}
|
||||
|
||||
return (m_mask + 1) * GrowthFactor;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the maximum number of buckets supported by the policy.
|
||||
*/
|
||||
std::size_t max_bucket_count() const {
|
||||
// Largest power of two.
|
||||
return ((std::numeric_limits<std::size_t>::max)() / 2) + 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset the growth policy as if it was created with a bucket count of 0.
|
||||
* After a clear, the policy must always return 0 when bucket_for_hash is called.
|
||||
*/
|
||||
void clear() noexcept {
|
||||
m_mask = 0;
|
||||
}
|
||||
|
||||
private:
|
||||
static std::size_t round_up_to_power_of_two(std::size_t value) {
|
||||
if(is_power_of_two(value)) {
|
||||
return value;
|
||||
}
|
||||
|
||||
if(value == 0) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
--value;
|
||||
for(std::size_t i = 1; i < sizeof(std::size_t) * CHAR_BIT; i *= 2) {
|
||||
value |= value >> i;
|
||||
}
|
||||
|
||||
return value + 1;
|
||||
}
|
||||
|
||||
static constexpr bool is_power_of_two(std::size_t value) {
|
||||
return value != 0 && (value & (value - 1)) == 0;
|
||||
}
|
||||
|
||||
protected:
|
||||
static_assert(is_power_of_two(GrowthFactor) && GrowthFactor >= 2, "GrowthFactor must be a power of two >= 2.");
|
||||
|
||||
std::size_t m_mask;
|
||||
};
|
||||
|
||||
|
||||
/**
|
||||
* Grow the hash table by GrowthFactor::num / GrowthFactor::den and use a modulo to map a hash
|
||||
* to a bucket. Slower but it can be useful if you want a slower growth.
|
||||
*/
|
||||
template<class GrowthFactor = std::ratio<3, 2>>
|
||||
class mod_growth_policy {
|
||||
public:
|
||||
explicit mod_growth_policy(std::size_t& min_bucket_count_in_out) {
|
||||
if(min_bucket_count_in_out > max_bucket_count()) {
|
||||
TSL_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size.");
|
||||
}
|
||||
|
||||
if(min_bucket_count_in_out > 0) {
|
||||
m_mod = min_bucket_count_in_out;
|
||||
}
|
||||
else {
|
||||
m_mod = 1;
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t bucket_for_hash(std::size_t hash) const noexcept {
|
||||
return hash % m_mod;
|
||||
}
|
||||
|
||||
std::size_t next_bucket_count() const {
|
||||
if(m_mod == max_bucket_count()) {
|
||||
TSL_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size.");
|
||||
}
|
||||
|
||||
const double next_bucket_count = std::ceil(double(m_mod) * REHASH_SIZE_MULTIPLICATION_FACTOR);
|
||||
if(!std::isnormal(next_bucket_count)) {
|
||||
TSL_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size.");
|
||||
}
|
||||
|
||||
if(next_bucket_count > double(max_bucket_count())) {
|
||||
return max_bucket_count();
|
||||
}
|
||||
else {
|
||||
return std::size_t(next_bucket_count);
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t max_bucket_count() const {
|
||||
return MAX_BUCKET_COUNT;
|
||||
}
|
||||
|
||||
void clear() noexcept {
|
||||
m_mod = 1;
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr double REHASH_SIZE_MULTIPLICATION_FACTOR = 1.0 * GrowthFactor::num / GrowthFactor::den;
|
||||
static const std::size_t MAX_BUCKET_COUNT =
|
||||
std::size_t(double(
|
||||
(std::numeric_limits<std::size_t>::max)() / REHASH_SIZE_MULTIPLICATION_FACTOR
|
||||
));
|
||||
|
||||
static_assert(REHASH_SIZE_MULTIPLICATION_FACTOR >= 1.1, "Growth factor should be >= 1.1.");
|
||||
|
||||
std::size_t m_mod;
|
||||
};
|
||||
|
||||
|
||||
|
||||
namespace detail {
|
||||
|
||||
static constexpr const std::array<std::size_t, 40> PRIMES = {{
|
||||
1ul, 5ul, 17ul, 29ul, 37ul, 53ul, 67ul, 79ul, 97ul, 131ul, 193ul, 257ul, 389ul, 521ul, 769ul, 1031ul,
|
||||
1543ul, 2053ul, 3079ul, 6151ul, 12289ul, 24593ul, 49157ul, 98317ul, 196613ul, 393241ul, 786433ul,
|
||||
1572869ul, 3145739ul, 6291469ul, 12582917ul, 25165843ul, 50331653ul, 100663319ul, 201326611ul,
|
||||
402653189ul, 805306457ul, 1610612741ul, 3221225473ul, 4294967291ul
|
||||
}};
|
||||
|
||||
template<unsigned int IPrime>
|
||||
static constexpr std::size_t mod(std::size_t hash) { return hash % PRIMES[IPrime]; }
|
||||
|
||||
// MOD_PRIME[iprime](hash) returns hash % PRIMES[iprime]. This table allows for faster modulo as the
|
||||
// compiler can optimize the modulo code better with a constant known at the compilation.
|
||||
static constexpr const std::array<std::size_t(*)(std::size_t), 40> MOD_PRIME = {{
|
||||
&mod<0>, &mod<1>, &mod<2>, &mod<3>, &mod<4>, &mod<5>, &mod<6>, &mod<7>, &mod<8>, &mod<9>, &mod<10>,
|
||||
&mod<11>, &mod<12>, &mod<13>, &mod<14>, &mod<15>, &mod<16>, &mod<17>, &mod<18>, &mod<19>, &mod<20>,
|
||||
&mod<21>, &mod<22>, &mod<23>, &mod<24>, &mod<25>, &mod<26>, &mod<27>, &mod<28>, &mod<29>, &mod<30>,
|
||||
&mod<31>, &mod<32>, &mod<33>, &mod<34>, &mod<35>, &mod<36>, &mod<37> , &mod<38>, &mod<39>
|
||||
}};
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Grow the hash table by using prime numbers as bucket count. Slower than tsl::rh::power_of_two_growth_policy in
|
||||
* general but will probably distribute the values around better in the buckets with a poor hash function.
|
||||
*
|
||||
* To allow the compiler to optimize the modulo operation, a lookup table is used with constant primes numbers.
|
||||
*
|
||||
* With a switch the code would look like:
|
||||
* \code
|
||||
* switch(iprime) { // iprime is the current prime of the hash table
|
||||
* case 0: hash % 5ul;
|
||||
* break;
|
||||
* case 1: hash % 17ul;
|
||||
* break;
|
||||
* case 2: hash % 29ul;
|
||||
* break;
|
||||
* ...
|
||||
* }
|
||||
* \endcode
|
||||
*
|
||||
* Due to the constant variable in the modulo the compiler is able to optimize the operation
|
||||
* by a series of multiplications, substractions and shifts.
|
||||
*
|
||||
* The 'hash % 5' could become something like 'hash - (hash * 0xCCCCCCCD) >> 34) * 5' in a 64 bits environement.
|
||||
*/
|
||||
class prime_growth_policy {
|
||||
public:
|
||||
explicit prime_growth_policy(std::size_t& min_bucket_count_in_out) {
|
||||
auto it_prime = std::lower_bound(detail::PRIMES.begin(),
|
||||
detail::PRIMES.end(), min_bucket_count_in_out);
|
||||
if(it_prime == detail::PRIMES.end()) {
|
||||
TSL_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size.");
|
||||
}
|
||||
|
||||
m_iprime = static_cast<unsigned int>(std::distance(detail::PRIMES.begin(), it_prime));
|
||||
if(min_bucket_count_in_out > 0) {
|
||||
min_bucket_count_in_out = *it_prime;
|
||||
}
|
||||
else {
|
||||
min_bucket_count_in_out = 0;
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t bucket_for_hash(std::size_t hash) const noexcept {
|
||||
return detail::MOD_PRIME[m_iprime](hash);
|
||||
}
|
||||
|
||||
std::size_t next_bucket_count() const {
|
||||
if(m_iprime + 1 >= detail::PRIMES.size()) {
|
||||
TSL_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size.");
|
||||
}
|
||||
|
||||
return detail::PRIMES[m_iprime + 1];
|
||||
}
|
||||
|
||||
std::size_t max_bucket_count() const {
|
||||
return detail::PRIMES.back();
|
||||
}
|
||||
|
||||
void clear() noexcept {
|
||||
m_iprime = 0;
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned int m_iprime;
|
||||
|
||||
static_assert((std::numeric_limits<decltype(m_iprime)>::max)() >= detail::PRIMES.size(),
|
||||
"The type of m_iprime is not big enough.");
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -0,0 +1,668 @@
|
|||
/**
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2017 Tessil
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*/
|
||||
#ifndef TSL_ROBIN_MAP_H
|
||||
#define TSL_ROBIN_MAP_H
|
||||
|
||||
|
||||
#include <cstddef>
|
||||
#include <functional>
|
||||
#include <initializer_list>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include "robin_hash.h"
|
||||
|
||||
|
||||
namespace tsl {
|
||||
|
||||
|
||||
/**
|
||||
* Implementation of a hash map using open-adressing and the robin hood hashing algorithm with backward shift deletion.
|
||||
*
|
||||
* For operations modifying the hash map (insert, erase, rehash, ...), the strong exception guarantee
|
||||
* is only guaranteed when the expression `std::is_nothrow_swappable<std::pair<Key, T>>::value &&
|
||||
* std::is_nothrow_move_constructible<std::pair<Key, T>>::value` is true, otherwise if an exception
|
||||
* is thrown during the swap or the move, the hash map may end up in a undefined state. Per the standard
|
||||
* a `Key` or `T` with a noexcept copy constructor and no move constructor also satisfies the
|
||||
* `std::is_nothrow_move_constructible<std::pair<Key, T>>::value` criterion (and will thus guarantee the
|
||||
* strong exception for the map).
|
||||
*
|
||||
* When `StoreHash` is true, 32 bits of the hash are stored alongside the values. It can improve
|
||||
* the performance during lookups if the `KeyEqual` function takes time (if it engenders a cache-miss for example)
|
||||
* as we then compare the stored hashes before comparing the keys. When `tsl::rh::power_of_two_growth_policy` is used
|
||||
* as `GrowthPolicy`, it may also speed-up the rehash process as we can avoid to recalculate the hash.
|
||||
* When it is detected that storing the hash will not incur any memory penality due to alignement (i.e.
|
||||
* `sizeof(tsl::detail_robin_hash::bucket_entry<ValueType, true>) ==
|
||||
* sizeof(tsl::detail_robin_hash::bucket_entry<ValueType, false>)`) and `tsl::rh::power_of_two_growth_policy` is
|
||||
* used, the hash will be stored even if `StoreHash` is false so that we can speed-up the rehash (but it will
|
||||
* not be used on lookups unless `StoreHash` is true).
|
||||
*
|
||||
* `GrowthPolicy` defines how the map grows and consequently how a hash value is mapped to a bucket.
|
||||
* By default the map uses `tsl::rh::power_of_two_growth_policy`. This policy keeps the number of buckets
|
||||
* to a power of two and uses a mask to map the hash to a bucket instead of the slow modulo.
|
||||
* Other growth policies are available and you may define your own growth policy,
|
||||
* check `tsl::rh::power_of_two_growth_policy` for the interface.
|
||||
*
|
||||
* If the destructor of `Key` or `T` throws an exception, the behaviour of the class is undefined.
|
||||
*
|
||||
* Iterators invalidation:
|
||||
* - clear, operator=, reserve, rehash: always invalidate the iterators.
|
||||
* - insert, emplace, emplace_hint, operator[]: if there is an effective insert, invalidate the iterators.
|
||||
* - erase: always invalidate the iterators.
|
||||
*/
|
||||
template<class Key,
|
||||
class T,
|
||||
class Hash = std::hash<Key>,
|
||||
class KeyEqual = std::equal_to<Key>,
|
||||
class Allocator = std::allocator<std::pair<Key, T>>,
|
||||
bool StoreHash = false,
|
||||
class GrowthPolicy = tsl::rh::power_of_two_growth_policy<2>>
|
||||
class robin_map {
|
||||
private:
|
||||
template<typename U>
|
||||
using has_is_transparent = tsl::detail_robin_hash::has_is_transparent<U>;
|
||||
|
||||
class KeySelect {
|
||||
public:
|
||||
using key_type = Key;
|
||||
|
||||
const key_type& operator()(const std::pair<Key, T>& key_value) const noexcept {
|
||||
return key_value.first;
|
||||
}
|
||||
|
||||
key_type& operator()(std::pair<Key, T>& key_value) noexcept {
|
||||
return key_value.first;
|
||||
}
|
||||
};
|
||||
|
||||
class ValueSelect {
|
||||
public:
|
||||
using value_type = T;
|
||||
|
||||
const value_type& operator()(const std::pair<Key, T>& key_value) const noexcept {
|
||||
return key_value.second;
|
||||
}
|
||||
|
||||
value_type& operator()(std::pair<Key, T>& key_value) noexcept {
|
||||
return key_value.second;
|
||||
}
|
||||
};
|
||||
|
||||
using ht = detail_robin_hash::robin_hash<std::pair<Key, T>, KeySelect, ValueSelect,
|
||||
Hash, KeyEqual, Allocator, StoreHash, GrowthPolicy>;
|
||||
|
||||
public:
|
||||
using key_type = typename ht::key_type;
|
||||
using mapped_type = T;
|
||||
using value_type = typename ht::value_type;
|
||||
using size_type = typename ht::size_type;
|
||||
using difference_type = typename ht::difference_type;
|
||||
using hasher = typename ht::hasher;
|
||||
using key_equal = typename ht::key_equal;
|
||||
using allocator_type = typename ht::allocator_type;
|
||||
using reference = typename ht::reference;
|
||||
using const_reference = typename ht::const_reference;
|
||||
using pointer = typename ht::pointer;
|
||||
using const_pointer = typename ht::const_pointer;
|
||||
using iterator = typename ht::iterator;
|
||||
using const_iterator = typename ht::const_iterator;
|
||||
|
||||
|
||||
public:
|
||||
/*
|
||||
* Constructors
|
||||
*/
|
||||
robin_map(): robin_map(ht::DEFAULT_INIT_BUCKETS_SIZE) {
|
||||
}
|
||||
|
||||
explicit robin_map(size_type bucket_count,
|
||||
const Hash& hash = Hash(),
|
||||
const KeyEqual& equal = KeyEqual(),
|
||||
const Allocator& alloc = Allocator()):
|
||||
m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR)
|
||||
{
|
||||
}
|
||||
|
||||
robin_map(size_type bucket_count,
|
||||
const Allocator& alloc): robin_map(bucket_count, Hash(), KeyEqual(), alloc)
|
||||
{
|
||||
}
|
||||
|
||||
robin_map(size_type bucket_count,
|
||||
const Hash& hash,
|
||||
const Allocator& alloc): robin_map(bucket_count, hash, KeyEqual(), alloc)
|
||||
{
|
||||
}
|
||||
|
||||
explicit robin_map(const Allocator& alloc): robin_map(ht::DEFAULT_INIT_BUCKETS_SIZE, alloc) {
|
||||
}
|
||||
|
||||
template<class InputIt>
|
||||
robin_map(InputIt first, InputIt last,
|
||||
size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
|
||||
const Hash& hash = Hash(),
|
||||
const KeyEqual& equal = KeyEqual(),
|
||||
const Allocator& alloc = Allocator()): robin_map(bucket_count, hash, equal, alloc)
|
||||
{
|
||||
insert(first, last);
|
||||
}
|
||||
|
||||
template<class InputIt>
|
||||
robin_map(InputIt first, InputIt last,
|
||||
size_type bucket_count,
|
||||
const Allocator& alloc): robin_map(first, last, bucket_count, Hash(), KeyEqual(), alloc)
|
||||
{
|
||||
}
|
||||
|
||||
template<class InputIt>
|
||||
robin_map(InputIt first, InputIt last,
|
||||
size_type bucket_count,
|
||||
const Hash& hash,
|
||||
const Allocator& alloc): robin_map(first, last, bucket_count, hash, KeyEqual(), alloc)
|
||||
{
|
||||
}
|
||||
|
||||
robin_map(std::initializer_list<value_type> init,
|
||||
size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
|
||||
const Hash& hash = Hash(),
|
||||
const KeyEqual& equal = KeyEqual(),
|
||||
const Allocator& alloc = Allocator()):
|
||||
robin_map(init.begin(), init.end(), bucket_count, hash, equal, alloc)
|
||||
{
|
||||
}
|
||||
|
||||
robin_map(std::initializer_list<value_type> init,
|
||||
size_type bucket_count,
|
||||
const Allocator& alloc):
|
||||
robin_map(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(), alloc)
|
||||
{
|
||||
}
|
||||
|
||||
robin_map(std::initializer_list<value_type> init,
|
||||
size_type bucket_count,
|
||||
const Hash& hash,
|
||||
const Allocator& alloc):
|
||||
robin_map(init.begin(), init.end(), bucket_count, hash, KeyEqual(), alloc)
|
||||
{
|
||||
}
|
||||
|
||||
robin_map& operator=(std::initializer_list<value_type> ilist) {
|
||||
m_ht.clear();
|
||||
|
||||
m_ht.reserve(ilist.size());
|
||||
m_ht.insert(ilist.begin(), ilist.end());
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
allocator_type get_allocator() const { return m_ht.get_allocator(); }
|
||||
|
||||
|
||||
/*
|
||||
* Iterators
|
||||
*/
|
||||
iterator begin() noexcept { return m_ht.begin(); }
|
||||
const_iterator begin() const noexcept { return m_ht.begin(); }
|
||||
const_iterator cbegin() const noexcept { return m_ht.cbegin(); }
|
||||
|
||||
iterator end() noexcept { return m_ht.end(); }
|
||||
const_iterator end() const noexcept { return m_ht.end(); }
|
||||
const_iterator cend() const noexcept { return m_ht.cend(); }
|
||||
|
||||
|
||||
/*
|
||||
* Capacity
|
||||
*/
|
||||
bool empty() const noexcept { return m_ht.empty(); }
|
||||
size_type size() const noexcept { return m_ht.size(); }
|
||||
size_type max_size() const noexcept { return m_ht.max_size(); }
|
||||
|
||||
/*
|
||||
* Modifiers
|
||||
*/
|
||||
void clear() noexcept { m_ht.clear(); }
|
||||
|
||||
|
||||
|
||||
std::pair<iterator, bool> insert(const value_type& value) {
|
||||
return m_ht.insert(value);
|
||||
}
|
||||
|
||||
template<class P, typename std::enable_if<std::is_constructible<value_type, P&&>::value>::type* = nullptr>
|
||||
std::pair<iterator, bool> insert(P&& value) {
|
||||
return m_ht.emplace(std::forward<P>(value));
|
||||
}
|
||||
|
||||
std::pair<iterator, bool> insert(value_type&& value) {
|
||||
return m_ht.insert(std::move(value));
|
||||
}
|
||||
|
||||
|
||||
iterator insert(const_iterator hint, const value_type& value) {
|
||||
return m_ht.insert(hint, value);
|
||||
}
|
||||
|
||||
template<class P, typename std::enable_if<std::is_constructible<value_type, P&&>::value>::type* = nullptr>
|
||||
iterator insert(const_iterator hint, P&& value) {
|
||||
return m_ht.emplace_hint(hint, std::forward<P>(value));
|
||||
}
|
||||
|
||||
iterator insert(const_iterator hint, value_type&& value) {
|
||||
return m_ht.insert(hint, std::move(value));
|
||||
}
|
||||
|
||||
|
||||
template<class InputIt>
|
||||
void insert(InputIt first, InputIt last) {
|
||||
m_ht.insert(first, last);
|
||||
}
|
||||
|
||||
void insert(std::initializer_list<value_type> ilist) {
|
||||
m_ht.insert(ilist.begin(), ilist.end());
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template<class M>
|
||||
std::pair<iterator, bool> insert_or_assign(const key_type& k, M&& obj) {
|
||||
return m_ht.insert_or_assign(k, std::forward<M>(obj));
|
||||
}
|
||||
|
||||
template<class M>
|
||||
std::pair<iterator, bool> insert_or_assign(key_type&& k, M&& obj) {
|
||||
return m_ht.insert_or_assign(std::move(k), std::forward<M>(obj));
|
||||
}
|
||||
|
||||
template<class M>
|
||||
iterator insert_or_assign(const_iterator hint, const key_type& k, M&& obj) {
|
||||
return m_ht.insert_or_assign(hint, k, std::forward<M>(obj));
|
||||
}
|
||||
|
||||
template<class M>
|
||||
iterator insert_or_assign(const_iterator hint, key_type&& k, M&& obj) {
|
||||
return m_ht.insert_or_assign(hint, std::move(k), std::forward<M>(obj));
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Due to the way elements are stored, emplace will need to move or copy the key-value once.
|
||||
* The method is equivalent to insert(value_type(std::forward<Args>(args)...));
|
||||
*
|
||||
* Mainly here for compatibility with the std::unordered_map interface.
|
||||
*/
|
||||
template<class... Args>
|
||||
std::pair<iterator, bool> emplace(Args&&... args) {
|
||||
return m_ht.emplace(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Due to the way elements are stored, emplace_hint will need to move or copy the key-value once.
|
||||
* The method is equivalent to insert(hint, value_type(std::forward<Args>(args)...));
|
||||
*
|
||||
* Mainly here for compatibility with the std::unordered_map interface.
|
||||
*/
|
||||
template<class... Args>
|
||||
iterator emplace_hint(const_iterator hint, Args&&... args) {
|
||||
return m_ht.emplace_hint(hint, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template<class... Args>
|
||||
std::pair<iterator, bool> try_emplace(const key_type& k, Args&&... args) {
|
||||
return m_ht.try_emplace(k, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template<class... Args>
|
||||
std::pair<iterator, bool> try_emplace(key_type&& k, Args&&... args) {
|
||||
return m_ht.try_emplace(std::move(k), std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template<class... Args>
|
||||
iterator try_emplace(const_iterator hint, const key_type& k, Args&&... args) {
|
||||
return m_ht.try_emplace(hint, k, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template<class... Args>
|
||||
iterator try_emplace(const_iterator hint, key_type&& k, Args&&... args) {
|
||||
return m_ht.try_emplace(hint, std::move(k), std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
iterator erase(iterator pos) { return m_ht.erase(pos); }
|
||||
iterator erase(const_iterator pos) { return m_ht.erase(pos); }
|
||||
iterator erase(const_iterator first, const_iterator last) { return m_ht.erase(first, last); }
|
||||
size_type erase(const key_type& key) { return m_ht.erase(key); }
|
||||
|
||||
/**
|
||||
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
|
||||
* as hash_function()(key). Usefull to speed-up the lookup to the value if you already have the hash.
|
||||
*/
|
||||
size_type erase(const key_type& key, std::size_t precalculated_hash) {
|
||||
return m_ht.erase(key, precalculated_hash);
|
||||
}
|
||||
|
||||
/**
|
||||
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
|
||||
* If so, K must be hashable and comparable to Key.
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
size_type erase(const K& key) { return m_ht.erase(key); }
|
||||
|
||||
/**
|
||||
* @copydoc erase(const K& key)
|
||||
*
|
||||
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
|
||||
* as hash_function()(key). Usefull to speed-up the lookup to the value if you already have the hash.
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
size_type erase(const K& key, std::size_t precalculated_hash) {
|
||||
return m_ht.erase(key, precalculated_hash);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void swap(robin_map& other) { other.m_ht.swap(m_ht); }
|
||||
|
||||
|
||||
|
||||
/*
|
||||
* Lookup
|
||||
*/
|
||||
T& at(const Key& key) { return m_ht.at(key); }
|
||||
|
||||
/**
|
||||
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
|
||||
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
|
||||
*/
|
||||
T& at(const Key& key, std::size_t precalculated_hash) { return m_ht.at(key, precalculated_hash); }
|
||||
|
||||
|
||||
const T& at(const Key& key) const { return m_ht.at(key); }
|
||||
|
||||
/**
|
||||
* @copydoc at(const Key& key, std::size_t precalculated_hash)
|
||||
*/
|
||||
const T& at(const Key& key, std::size_t precalculated_hash) const { return m_ht.at(key, precalculated_hash); }
|
||||
|
||||
|
||||
/**
|
||||
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
|
||||
* If so, K must be hashable and comparable to Key.
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
T& at(const K& key) { return m_ht.at(key); }
|
||||
|
||||
/**
|
||||
* @copydoc at(const K& key)
|
||||
*
|
||||
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
|
||||
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
T& at(const K& key, std::size_t precalculated_hash) { return m_ht.at(key, precalculated_hash); }
|
||||
|
||||
|
||||
/**
|
||||
* @copydoc at(const K& key)
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
const T& at(const K& key) const { return m_ht.at(key); }
|
||||
|
||||
/**
|
||||
* @copydoc at(const K& key, std::size_t precalculated_hash)
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
const T& at(const K& key, std::size_t precalculated_hash) const { return m_ht.at(key, precalculated_hash); }
|
||||
|
||||
|
||||
|
||||
|
||||
T& operator[](const Key& key) { return m_ht[key]; }
|
||||
T& operator[](Key&& key) { return m_ht[std::move(key)]; }
|
||||
|
||||
|
||||
|
||||
|
||||
size_type count(const Key& key) const { return m_ht.count(key); }
|
||||
|
||||
/**
|
||||
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
|
||||
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
|
||||
*/
|
||||
size_type count(const Key& key, std::size_t precalculated_hash) const {
|
||||
return m_ht.count(key, precalculated_hash);
|
||||
}
|
||||
|
||||
/**
|
||||
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
|
||||
* If so, K must be hashable and comparable to Key.
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
size_type count(const K& key) const { return m_ht.count(key); }
|
||||
|
||||
/**
|
||||
* @copydoc count(const K& key) const
|
||||
*
|
||||
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
|
||||
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
size_type count(const K& key, std::size_t precalculated_hash) const { return m_ht.count(key, precalculated_hash); }
|
||||
|
||||
|
||||
|
||||
|
||||
iterator find(const Key& key) { return m_ht.find(key); }
|
||||
|
||||
/**
|
||||
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
|
||||
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
|
||||
*/
|
||||
iterator find(const Key& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }
|
||||
|
||||
const_iterator find(const Key& key) const { return m_ht.find(key); }
|
||||
|
||||
/**
|
||||
* @copydoc find(const Key& key, std::size_t precalculated_hash)
|
||||
*/
|
||||
const_iterator find(const Key& key, std::size_t precalculated_hash) const {
|
||||
return m_ht.find(key, precalculated_hash);
|
||||
}
|
||||
|
||||
/**
|
||||
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
|
||||
* If so, K must be hashable and comparable to Key.
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
iterator find(const K& key) { return m_ht.find(key); }
|
||||
|
||||
/**
|
||||
* @copydoc find(const K& key)
|
||||
*
|
||||
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
|
||||
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
iterator find(const K& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }
|
||||
|
||||
/**
|
||||
* @copydoc find(const K& key)
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
const_iterator find(const K& key) const { return m_ht.find(key); }
|
||||
|
||||
/**
|
||||
* @copydoc find(const K& key)
|
||||
*
|
||||
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
|
||||
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
const_iterator find(const K& key, std::size_t precalculated_hash) const {
|
||||
return m_ht.find(key, precalculated_hash);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
std::pair<iterator, iterator> equal_range(const Key& key) { return m_ht.equal_range(key); }
|
||||
|
||||
/**
|
||||
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
|
||||
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
|
||||
*/
|
||||
std::pair<iterator, iterator> equal_range(const Key& key, std::size_t precalculated_hash) {
|
||||
return m_ht.equal_range(key, precalculated_hash);
|
||||
}
|
||||
|
||||
std::pair<const_iterator, const_iterator> equal_range(const Key& key) const { return m_ht.equal_range(key); }
|
||||
|
||||
/**
|
||||
* @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
|
||||
*/
|
||||
std::pair<const_iterator, const_iterator> equal_range(const Key& key, std::size_t precalculated_hash) const {
|
||||
return m_ht.equal_range(key, precalculated_hash);
|
||||
}
|
||||
|
||||
/**
|
||||
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
|
||||
* If so, K must be hashable and comparable to Key.
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
std::pair<iterator, iterator> equal_range(const K& key) { return m_ht.equal_range(key); }
|
||||
|
||||
|
||||
/**
|
||||
* @copydoc equal_range(const K& key)
|
||||
*
|
||||
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
|
||||
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
std::pair<iterator, iterator> equal_range(const K& key, std::size_t precalculated_hash) {
|
||||
return m_ht.equal_range(key, precalculated_hash);
|
||||
}
|
||||
|
||||
/**
|
||||
* @copydoc equal_range(const K& key)
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
std::pair<const_iterator, const_iterator> equal_range(const K& key) const { return m_ht.equal_range(key); }
|
||||
|
||||
/**
|
||||
* @copydoc equal_range(const K& key, std::size_t precalculated_hash)
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
std::pair<const_iterator, const_iterator> equal_range(const K& key, std::size_t precalculated_hash) const {
|
||||
return m_ht.equal_range(key, precalculated_hash);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
/*
|
||||
* Bucket interface
|
||||
*/
|
||||
size_type bucket_count() const { return m_ht.bucket_count(); }
|
||||
size_type max_bucket_count() const { return m_ht.max_bucket_count(); }
|
||||
|
||||
|
||||
/*
|
||||
* Hash policy
|
||||
*/
|
||||
float load_factor() const { return m_ht.load_factor(); }
|
||||
float max_load_factor() const { return m_ht.max_load_factor(); }
|
||||
void max_load_factor(float ml) { m_ht.max_load_factor(ml); }
|
||||
|
||||
void rehash(size_type count) { m_ht.rehash(count); }
|
||||
void reserve(size_type count) { m_ht.reserve(count); }
|
||||
|
||||
|
||||
/*
|
||||
* Observers
|
||||
*/
|
||||
hasher hash_function() const { return m_ht.hash_function(); }
|
||||
key_equal key_eq() const { return m_ht.key_eq(); }
|
||||
|
||||
/*
|
||||
* Other
|
||||
*/
|
||||
|
||||
/**
|
||||
* Convert a const_iterator to an iterator.
|
||||
*/
|
||||
iterator mutable_iterator(const_iterator pos) {
|
||||
return m_ht.mutable_iterator(pos);
|
||||
}
|
||||
|
||||
friend bool operator==(const robin_map& lhs, const robin_map& rhs) {
|
||||
if(lhs.size() != rhs.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for(const auto& element_lhs: lhs) {
|
||||
const auto it_element_rhs = rhs.find(element_lhs.first);
|
||||
if(it_element_rhs == rhs.cend() || element_lhs.second != it_element_rhs->second) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
friend bool operator!=(const robin_map& lhs, const robin_map& rhs) {
|
||||
return !operator==(lhs, rhs);
|
||||
}
|
||||
|
||||
friend void swap(robin_map& lhs, robin_map& rhs) {
|
||||
lhs.swap(rhs);
|
||||
}
|
||||
|
||||
private:
|
||||
ht m_ht;
|
||||
};
|
||||
|
||||
|
||||
/**
|
||||
* Same as `tsl::robin_map<Key, T, Hash, KeyEqual, Allocator, StoreHash, tsl::rh::prime_growth_policy>`.
|
||||
*/
|
||||
template<class Key,
|
||||
class T,
|
||||
class Hash = std::hash<Key>,
|
||||
class KeyEqual = std::equal_to<Key>,
|
||||
class Allocator = std::allocator<std::pair<Key, T>>,
|
||||
bool StoreHash = false>
|
||||
using robin_pg_map = robin_map<Key, T, Hash, KeyEqual, Allocator, StoreHash, tsl::rh::prime_growth_policy>;
|
||||
|
||||
} // end namespace tsl
|
||||
|
||||
#endif
|
|
@ -0,0 +1,535 @@
|
|||
/**
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2017 Tessil
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*/
|
||||
#ifndef TSL_ROBIN_SET_H
|
||||
#define TSL_ROBIN_SET_H
|
||||
|
||||
|
||||
#include <cstddef>
|
||||
#include <functional>
|
||||
#include <initializer_list>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include "robin_hash.h"
|
||||
|
||||
|
||||
namespace tsl {
|
||||
|
||||
|
||||
/**
|
||||
* Implementation of a hash set using open-adressing and the robin hood hashing algorithm with backward shift deletion.
|
||||
*
|
||||
* For operations modifying the hash set (insert, erase, rehash, ...), the strong exception guarantee
|
||||
* is only guaranteed when the expression `std::is_nothrow_swappable<Key>::value &&
|
||||
* std::is_nothrow_move_constructible<Key>::value` is true, otherwise if an exception
|
||||
* is thrown during the swap or the move, the hash set may end up in a undefined state. Per the standard
|
||||
* a `Key` with a noexcept copy constructor and no move constructor also satisfies the
|
||||
* `std::is_nothrow_move_constructible<Key>::value` criterion (and will thus guarantee the
|
||||
* strong exception for the set).
|
||||
*
|
||||
* When `StoreHash` is true, 32 bits of the hash are stored alongside the values. It can improve
|
||||
* the performance during lookups if the `KeyEqual` function takes time (or engenders a cache-miss for example)
|
||||
* as we then compare the stored hashes before comparing the keys. When `tsl::rh::power_of_two_growth_policy` is used
|
||||
* as `GrowthPolicy`, it may also speed-up the rehash process as we can avoid to recalculate the hash.
|
||||
* When it is detected that storing the hash will not incur any memory penality due to alignement (i.e.
|
||||
* `sizeof(tsl::detail_robin_hash::bucket_entry<ValueType, true>) ==
|
||||
* sizeof(tsl::detail_robin_hash::bucket_entry<ValueType, false>)`) and `tsl::rh::power_of_two_growth_policy` is
|
||||
* used, the hash will be stored even if `StoreHash` is false so that we can speed-up the rehash (but it will
|
||||
* not be used on lookups unless `StoreHash` is true).
|
||||
*
|
||||
* `GrowthPolicy` defines how the set grows and consequently how a hash value is mapped to a bucket.
|
||||
* By default the set uses `tsl::rh::power_of_two_growth_policy`. This policy keeps the number of buckets
|
||||
* to a power of two and uses a mask to set the hash to a bucket instead of the slow modulo.
|
||||
* Other growth policies are available and you may define your own growth policy,
|
||||
* check `tsl::rh::power_of_two_growth_policy` for the interface.
|
||||
*
|
||||
* If the destructor of `Key` throws an exception, the behaviour of the class is undefined.
|
||||
*
|
||||
* Iterators invalidation:
|
||||
* - clear, operator=, reserve, rehash: always invalidate the iterators.
|
||||
* - insert, emplace, emplace_hint, operator[]: if there is an effective insert, invalidate the iterators.
|
||||
* - erase: always invalidate the iterators.
|
||||
*/
|
||||
template<class Key,
|
||||
class Hash = std::hash<Key>,
|
||||
class KeyEqual = std::equal_to<Key>,
|
||||
class Allocator = std::allocator<Key>,
|
||||
bool StoreHash = false,
|
||||
class GrowthPolicy = tsl::rh::power_of_two_growth_policy<2>>
|
||||
class robin_set {
|
||||
private:
|
||||
template<typename U>
|
||||
using has_is_transparent = tsl::detail_robin_hash::has_is_transparent<U>;
|
||||
|
||||
class KeySelect {
|
||||
public:
|
||||
using key_type = Key;
|
||||
|
||||
const key_type& operator()(const Key& key) const noexcept {
|
||||
return key;
|
||||
}
|
||||
|
||||
key_type& operator()(Key& key) noexcept {
|
||||
return key;
|
||||
}
|
||||
};
|
||||
|
||||
using ht = detail_robin_hash::robin_hash<Key, KeySelect, void,
|
||||
Hash, KeyEqual, Allocator, StoreHash, GrowthPolicy>;
|
||||
|
||||
public:
|
||||
using key_type = typename ht::key_type;
|
||||
using value_type = typename ht::value_type;
|
||||
using size_type = typename ht::size_type;
|
||||
using difference_type = typename ht::difference_type;
|
||||
using hasher = typename ht::hasher;
|
||||
using key_equal = typename ht::key_equal;
|
||||
using allocator_type = typename ht::allocator_type;
|
||||
using reference = typename ht::reference;
|
||||
using const_reference = typename ht::const_reference;
|
||||
using pointer = typename ht::pointer;
|
||||
using const_pointer = typename ht::const_pointer;
|
||||
using iterator = typename ht::iterator;
|
||||
using const_iterator = typename ht::const_iterator;
|
||||
|
||||
|
||||
/*
|
||||
* Constructors
|
||||
*/
|
||||
robin_set(): robin_set(ht::DEFAULT_INIT_BUCKETS_SIZE) {
|
||||
}
|
||||
|
||||
explicit robin_set(size_type bucket_count,
|
||||
const Hash& hash = Hash(),
|
||||
const KeyEqual& equal = KeyEqual(),
|
||||
const Allocator& alloc = Allocator()):
|
||||
m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR)
|
||||
{
|
||||
}
|
||||
|
||||
robin_set(size_type bucket_count,
|
||||
const Allocator& alloc): robin_set(bucket_count, Hash(), KeyEqual(), alloc)
|
||||
{
|
||||
}
|
||||
|
||||
robin_set(size_type bucket_count,
|
||||
const Hash& hash,
|
||||
const Allocator& alloc): robin_set(bucket_count, hash, KeyEqual(), alloc)
|
||||
{
|
||||
}
|
||||
|
||||
explicit robin_set(const Allocator& alloc): robin_set(ht::DEFAULT_INIT_BUCKETS_SIZE, alloc) {
|
||||
}
|
||||
|
||||
template<class InputIt>
|
||||
robin_set(InputIt first, InputIt last,
|
||||
size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
|
||||
const Hash& hash = Hash(),
|
||||
const KeyEqual& equal = KeyEqual(),
|
||||
const Allocator& alloc = Allocator()): robin_set(bucket_count, hash, equal, alloc)
|
||||
{
|
||||
insert(first, last);
|
||||
}
|
||||
|
||||
template<class InputIt>
|
||||
robin_set(InputIt first, InputIt last,
|
||||
size_type bucket_count,
|
||||
const Allocator& alloc): robin_set(first, last, bucket_count, Hash(), KeyEqual(), alloc)
|
||||
{
|
||||
}
|
||||
|
||||
template<class InputIt>
|
||||
robin_set(InputIt first, InputIt last,
|
||||
size_type bucket_count,
|
||||
const Hash& hash,
|
||||
const Allocator& alloc): robin_set(first, last, bucket_count, hash, KeyEqual(), alloc)
|
||||
{
|
||||
}
|
||||
|
||||
robin_set(std::initializer_list<value_type> init,
|
||||
size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
|
||||
const Hash& hash = Hash(),
|
||||
const KeyEqual& equal = KeyEqual(),
|
||||
const Allocator& alloc = Allocator()):
|
||||
robin_set(init.begin(), init.end(), bucket_count, hash, equal, alloc)
|
||||
{
|
||||
}
|
||||
|
||||
robin_set(std::initializer_list<value_type> init,
|
||||
size_type bucket_count,
|
||||
const Allocator& alloc):
|
||||
robin_set(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(), alloc)
|
||||
{
|
||||
}
|
||||
|
||||
robin_set(std::initializer_list<value_type> init,
|
||||
size_type bucket_count,
|
||||
const Hash& hash,
|
||||
const Allocator& alloc):
|
||||
robin_set(init.begin(), init.end(), bucket_count, hash, KeyEqual(), alloc)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
robin_set& operator=(std::initializer_list<value_type> ilist) {
|
||||
m_ht.clear();
|
||||
|
||||
m_ht.reserve(ilist.size());
|
||||
m_ht.insert(ilist.begin(), ilist.end());
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
allocator_type get_allocator() const { return m_ht.get_allocator(); }
|
||||
|
||||
|
||||
/*
|
||||
* Iterators
|
||||
*/
|
||||
iterator begin() noexcept { return m_ht.begin(); }
|
||||
const_iterator begin() const noexcept { return m_ht.begin(); }
|
||||
const_iterator cbegin() const noexcept { return m_ht.cbegin(); }
|
||||
|
||||
iterator end() noexcept { return m_ht.end(); }
|
||||
const_iterator end() const noexcept { return m_ht.end(); }
|
||||
const_iterator cend() const noexcept { return m_ht.cend(); }
|
||||
|
||||
|
||||
/*
|
||||
* Capacity
|
||||
*/
|
||||
bool empty() const noexcept { return m_ht.empty(); }
|
||||
size_type size() const noexcept { return m_ht.size(); }
|
||||
size_type max_size() const noexcept { return m_ht.max_size(); }
|
||||
|
||||
/*
|
||||
* Modifiers
|
||||
*/
|
||||
void clear() noexcept { m_ht.clear(); }
|
||||
|
||||
|
||||
|
||||
|
||||
std::pair<iterator, bool> insert(const value_type& value) {
|
||||
return m_ht.insert(value);
|
||||
}
|
||||
|
||||
std::pair<iterator, bool> insert(value_type&& value) {
|
||||
return m_ht.insert(std::move(value));
|
||||
}
|
||||
|
||||
iterator insert(const_iterator hint, const value_type& value) {
|
||||
return m_ht.insert(hint, value);
|
||||
}
|
||||
|
||||
iterator insert(const_iterator hint, value_type&& value) {
|
||||
return m_ht.insert(hint, std::move(value));
|
||||
}
|
||||
|
||||
template<class InputIt>
|
||||
void insert(InputIt first, InputIt last) {
|
||||
m_ht.insert(first, last);
|
||||
}
|
||||
|
||||
void insert(std::initializer_list<value_type> ilist) {
|
||||
m_ht.insert(ilist.begin(), ilist.end());
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Due to the way elements are stored, emplace will need to move or copy the key-value once.
|
||||
* The method is equivalent to insert(value_type(std::forward<Args>(args)...));
|
||||
*
|
||||
* Mainly here for compatibility with the std::unordered_map interface.
|
||||
*/
|
||||
template<class... Args>
|
||||
std::pair<iterator, bool> emplace(Args&&... args) {
|
||||
return m_ht.emplace(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Due to the way elements are stored, emplace_hint will need to move or copy the key-value once.
|
||||
* The method is equivalent to insert(hint, value_type(std::forward<Args>(args)...));
|
||||
*
|
||||
* Mainly here for compatibility with the std::unordered_map interface.
|
||||
*/
|
||||
template<class... Args>
|
||||
iterator emplace_hint(const_iterator hint, Args&&... args) {
|
||||
return m_ht.emplace_hint(hint, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
|
||||
|
||||
iterator erase(iterator pos) { return m_ht.erase(pos); }
|
||||
iterator erase(const_iterator pos) { return m_ht.erase(pos); }
|
||||
iterator erase(const_iterator first, const_iterator last) { return m_ht.erase(first, last); }
|
||||
size_type erase(const key_type& key) { return m_ht.erase(key); }
|
||||
|
||||
/**
|
||||
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
|
||||
* as hash_function()(key). Usefull to speed-up the lookup to the value if you already have the hash.
|
||||
*/
|
||||
size_type erase(const key_type& key, std::size_t precalculated_hash) {
|
||||
return m_ht.erase(key, precalculated_hash);
|
||||
}
|
||||
|
||||
/**
|
||||
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
|
||||
* If so, K must be hashable and comparable to Key.
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
size_type erase(const K& key) { return m_ht.erase(key); }
|
||||
|
||||
/**
|
||||
* @copydoc erase(const K& key)
|
||||
*
|
||||
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
|
||||
* as hash_function()(key). Usefull to speed-up the lookup to the value if you already have the hash.
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
size_type erase(const K& key, std::size_t precalculated_hash) {
|
||||
return m_ht.erase(key, precalculated_hash);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void swap(robin_set& other) { other.m_ht.swap(m_ht); }
|
||||
|
||||
|
||||
|
||||
/*
|
||||
* Lookup
|
||||
*/
|
||||
size_type count(const Key& key) const { return m_ht.count(key); }
|
||||
|
||||
/**
|
||||
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
|
||||
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
|
||||
*/
|
||||
size_type count(const Key& key, std::size_t precalculated_hash) const { return m_ht.count(key, precalculated_hash); }
|
||||
|
||||
/**
|
||||
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
|
||||
* If so, K must be hashable and comparable to Key.
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
size_type count(const K& key) const { return m_ht.count(key); }
|
||||
|
||||
/**
|
||||
* @copydoc count(const K& key) const
|
||||
*
|
||||
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
|
||||
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
size_type count(const K& key, std::size_t precalculated_hash) const { return m_ht.count(key, precalculated_hash); }
|
||||
|
||||
|
||||
|
||||
|
||||
iterator find(const Key& key) { return m_ht.find(key); }
|
||||
|
||||
/**
|
||||
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
|
||||
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
|
||||
*/
|
||||
iterator find(const Key& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }
|
||||
|
||||
const_iterator find(const Key& key) const { return m_ht.find(key); }
|
||||
|
||||
/**
|
||||
* @copydoc find(const Key& key, std::size_t precalculated_hash)
|
||||
*/
|
||||
const_iterator find(const Key& key, std::size_t precalculated_hash) const { return m_ht.find(key, precalculated_hash); }
|
||||
|
||||
/**
|
||||
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
|
||||
* If so, K must be hashable and comparable to Key.
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
iterator find(const K& key) { return m_ht.find(key); }
|
||||
|
||||
/**
|
||||
* @copydoc find(const K& key)
|
||||
*
|
||||
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
|
||||
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
iterator find(const K& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }
|
||||
|
||||
/**
|
||||
* @copydoc find(const K& key)
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
const_iterator find(const K& key) const { return m_ht.find(key); }
|
||||
|
||||
/**
|
||||
* @copydoc find(const K& key)
|
||||
*
|
||||
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
|
||||
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
const_iterator find(const K& key, std::size_t precalculated_hash) const { return m_ht.find(key, precalculated_hash); }
|
||||
|
||||
|
||||
|
||||
|
||||
std::pair<iterator, iterator> equal_range(const Key& key) { return m_ht.equal_range(key); }
|
||||
|
||||
/**
|
||||
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
|
||||
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
|
||||
*/
|
||||
std::pair<iterator, iterator> equal_range(const Key& key, std::size_t precalculated_hash) {
|
||||
return m_ht.equal_range(key, precalculated_hash);
|
||||
}
|
||||
|
||||
std::pair<const_iterator, const_iterator> equal_range(const Key& key) const { return m_ht.equal_range(key); }
|
||||
|
||||
/**
|
||||
* @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
|
||||
*/
|
||||
std::pair<const_iterator, const_iterator> equal_range(const Key& key, std::size_t precalculated_hash) const {
|
||||
return m_ht.equal_range(key, precalculated_hash);
|
||||
}
|
||||
|
||||
/**
|
||||
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
|
||||
* If so, K must be hashable and comparable to Key.
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
std::pair<iterator, iterator> equal_range(const K& key) { return m_ht.equal_range(key); }
|
||||
|
||||
/**
|
||||
* @copydoc equal_range(const K& key)
|
||||
*
|
||||
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
|
||||
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
std::pair<iterator, iterator> equal_range(const K& key, std::size_t precalculated_hash) {
|
||||
return m_ht.equal_range(key, precalculated_hash);
|
||||
}
|
||||
|
||||
/**
|
||||
* @copydoc equal_range(const K& key)
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
std::pair<const_iterator, const_iterator> equal_range(const K& key) const { return m_ht.equal_range(key); }
|
||||
|
||||
/**
|
||||
* @copydoc equal_range(const K& key, std::size_t precalculated_hash)
|
||||
*/
|
||||
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
|
||||
std::pair<const_iterator, const_iterator> equal_range(const K& key, std::size_t precalculated_hash) const {
|
||||
return m_ht.equal_range(key, precalculated_hash);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
/*
|
||||
* Bucket interface
|
||||
*/
|
||||
size_type bucket_count() const { return m_ht.bucket_count(); }
|
||||
size_type max_bucket_count() const { return m_ht.max_bucket_count(); }
|
||||
|
||||
|
||||
/*
|
||||
* Hash policy
|
||||
*/
|
||||
float load_factor() const { return m_ht.load_factor(); }
|
||||
float max_load_factor() const { return m_ht.max_load_factor(); }
|
||||
void max_load_factor(float ml) { m_ht.max_load_factor(ml); }
|
||||
|
||||
void rehash(size_type count) { m_ht.rehash(count); }
|
||||
void reserve(size_type count) { m_ht.reserve(count); }
|
||||
|
||||
|
||||
/*
|
||||
* Observers
|
||||
*/
|
||||
hasher hash_function() const { return m_ht.hash_function(); }
|
||||
key_equal key_eq() const { return m_ht.key_eq(); }
|
||||
|
||||
|
||||
/*
|
||||
* Other
|
||||
*/
|
||||
|
||||
/**
|
||||
* Convert a const_iterator to an iterator.
|
||||
*/
|
||||
iterator mutable_iterator(const_iterator pos) {
|
||||
return m_ht.mutable_iterator(pos);
|
||||
}
|
||||
|
||||
friend bool operator==(const robin_set& lhs, const robin_set& rhs) {
|
||||
if(lhs.size() != rhs.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for(const auto& element_lhs: lhs) {
|
||||
const auto it_element_rhs = rhs.find(element_lhs);
|
||||
if(it_element_rhs == rhs.cend()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
friend bool operator!=(const robin_set& lhs, const robin_set& rhs) {
|
||||
return !operator==(lhs, rhs);
|
||||
}
|
||||
|
||||
friend void swap(robin_set& lhs, robin_set& rhs) {
|
||||
lhs.swap(rhs);
|
||||
}
|
||||
|
||||
private:
|
||||
ht m_ht;
|
||||
};
|
||||
|
||||
|
||||
/**
|
||||
* Same as `tsl::robin_set<Key, Hash, KeyEqual, Allocator, StoreHash, tsl::rh::prime_growth_policy>`.
|
||||
*/
|
||||
template<class Key,
|
||||
class Hash = std::hash<Key>,
|
||||
class KeyEqual = std::equal_to<Key>,
|
||||
class Allocator = std::allocator<Key>,
|
||||
bool StoreHash = false>
|
||||
using robin_pg_set = robin_set<Key, Hash, KeyEqual, Allocator, StoreHash, tsl::rh::prime_growth_policy>;
|
||||
|
||||
} // end namespace tsl
|
||||
|
||||
#endif
|
||||
|
|
@ -0,0 +1,538 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
#include <fcntl.h>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <set>
|
||||
#ifdef __APPLE__
|
||||
#else
|
||||
#include <malloc.h>
|
||||
#endif
|
||||
|
||||
#ifdef _WINDOWS
|
||||
#include <Windows.h>
|
||||
typedef HANDLE FileHandle;
|
||||
#else
|
||||
#include <unistd.h>
|
||||
typedef int FileHandle;
|
||||
#endif
|
||||
|
||||
#include "logger.h"
|
||||
#include "cached_io.h"
|
||||
#include "common_includes.h"
|
||||
#include "windows_customizations.h"
|
||||
|
||||
#ifdef EXEC_ENV_OLS
|
||||
#include "content_buf.h"
|
||||
#include "memory_mapped_files.h"
|
||||
#endif
|
||||
|
||||
// taken from
|
||||
// https://github.com/Microsoft/BLAS-on-flash/blob/master/include/utils.h
|
||||
// round up X to the nearest multiple of Y
|
||||
#define ROUND_UP(X, Y) \
|
||||
((((uint64_t)(X) / (Y)) + ((uint64_t)(X) % (Y) != 0)) * (Y))
|
||||
|
||||
#define DIV_ROUND_UP(X, Y) (((uint64_t)(X) / (Y)) + ((uint64_t)(X) % (Y) != 0))
|
||||
|
||||
// round down X to the nearest multiple of Y
|
||||
#define ROUND_DOWN(X, Y) (((uint64_t)(X) / (Y)) * (Y))
|
||||
|
||||
// alignment tests
|
||||
#define IS_ALIGNED(X, Y) ((uint64_t)(X) % (uint64_t)(Y) == 0)
|
||||
#define IS_512_ALIGNED(X) IS_ALIGNED(X, 512)
|
||||
#define IS_4096_ALIGNED(X) IS_ALIGNED(X, 4096)
|
||||
|
||||
typedef uint64_t _u64;
|
||||
typedef int64_t _s64;
|
||||
typedef uint32_t _u32;
|
||||
typedef int32_t _s32;
|
||||
typedef uint16_t _u16;
|
||||
typedef int16_t _s16;
|
||||
typedef uint8_t _u8;
|
||||
typedef int8_t _s8;
|
||||
|
||||
namespace diskann {
|
||||
static const size_t MAX_SIZE_OF_STREAMBUF = 2LL * 1024 * 1024 * 1024;
|
||||
|
||||
enum Metric { L2 = 0, INNER_PRODUCT = 1, FAST_L2 = 2, PQ = 3 };
|
||||
|
||||
inline void alloc_aligned(void** ptr, size_t size, size_t align) {
|
||||
*ptr = nullptr;
|
||||
assert(IS_ALIGNED(size, align));
|
||||
#ifndef _WINDOWS
|
||||
*ptr = ::aligned_alloc(align, size);
|
||||
#else
|
||||
*ptr = ::_aligned_malloc(size, align); // note the swapped arguments!
|
||||
#endif
|
||||
assert(*ptr != nullptr);
|
||||
}
|
||||
|
||||
inline void aligned_free(void* ptr) {
|
||||
// Gopal. Must have a check here if the pointer was actually allocated by
|
||||
// _alloc_aligned
|
||||
if (ptr == nullptr) {
|
||||
return;
|
||||
}
|
||||
#ifndef _WINDOWS
|
||||
free(ptr);
|
||||
#else
|
||||
::_aligned_free(ptr);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline void GenRandom(std::mt19937& rng, unsigned* addr, unsigned size,
|
||||
unsigned N) {
|
||||
for (unsigned i = 0; i < size; ++i) {
|
||||
addr[i] = rng() % (N - size);
|
||||
}
|
||||
|
||||
std::sort(addr, addr + size);
|
||||
for (unsigned i = 1; i < size; ++i) {
|
||||
if (addr[i] <= addr[i - 1]) {
|
||||
addr[i] = addr[i - 1] + 1;
|
||||
}
|
||||
}
|
||||
unsigned off = rng() % N;
|
||||
for (unsigned i = 0; i < size; ++i) {
|
||||
addr[i] = (addr[i] + off) % N;
|
||||
}
|
||||
}
|
||||
|
||||
// get_bin_metadata functions START
|
||||
inline void get_bin_metadata_impl(std::basic_istream<char>& reader,
|
||||
size_t& nrows, size_t& ncols) {
|
||||
int nrows_32, ncols_32;
|
||||
reader.read((char*) &nrows_32, sizeof(int));
|
||||
reader.read((char*) &ncols_32, sizeof(int));
|
||||
nrows = nrows_32;
|
||||
ncols = ncols_32;
|
||||
}
|
||||
|
||||
#ifdef EXEC_ENV_OLS
|
||||
inline void get_bin_metadata(MemoryMappedFiles& files,
|
||||
const std::string& bin_file, size_t& nrows,
|
||||
size_t& ncols) {
|
||||
diskann::cout << "Getting metadata for file: " << bin_file << std::endl;
|
||||
auto fc = files.getContent(bin_file);
|
||||
auto cb = ContentBuf((char*) fc._content, fc._size);
|
||||
std::basic_istream<char> reader(&cb);
|
||||
get_bin_metadata_impl(reader, nrows, ncols);
|
||||
}
|
||||
#endif
|
||||
|
||||
inline void get_bin_metadata(const std::string& bin_file, size_t& nrows,
|
||||
size_t& ncols) {
|
||||
std::ifstream reader(bin_file.c_str(), std::ios::binary);
|
||||
get_bin_metadata_impl(reader, nrows, ncols);
|
||||
}
|
||||
// get_bin_metadata functions END
|
||||
|
||||
template<typename T>
|
||||
inline std::string getValues(T* data, size_t num) {
|
||||
std::stringstream stream;
|
||||
stream << "[";
|
||||
for (size_t i = 0; i < num; i++) {
|
||||
stream << std::to_string(data[i]) << ",";
|
||||
}
|
||||
stream << "]" << std::endl;
|
||||
|
||||
return stream.str();
|
||||
}
|
||||
|
||||
// load_bin functions START
|
||||
template<typename T>
|
||||
inline void load_bin_impl(std::basic_istream<char>& reader,
|
||||
size_t actual_file_size, T*& data, size_t& npts,
|
||||
size_t& dim) {
|
||||
int npts_i32, dim_i32;
|
||||
reader.read((char*) &npts_i32, sizeof(int));
|
||||
reader.read((char*) &dim_i32, sizeof(int));
|
||||
npts = (unsigned) npts_i32;
|
||||
dim = (unsigned) dim_i32;
|
||||
|
||||
diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "..."
|
||||
<< std::endl;
|
||||
|
||||
size_t expected_actual_file_size =
|
||||
npts * dim * sizeof(T) + 2 * sizeof(uint32_t);
|
||||
if (actual_file_size != expected_actual_file_size) {
|
||||
std::stringstream stream;
|
||||
stream << "Error. File size mismatch. Actual size is " << actual_file_size
|
||||
<< " while expected size is " << expected_actual_file_size
|
||||
<< " npts = " << npts << " dim = " << dim
|
||||
<< " size of <T>= " << sizeof(T) << std::endl;
|
||||
diskann::cout << stream.str();
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__,
|
||||
__LINE__);
|
||||
}
|
||||
|
||||
data = new T[npts * dim];
|
||||
reader.read((char*) data, npts * dim * sizeof(T));
|
||||
|
||||
// diskann::cout << "Last bytes: "
|
||||
// << getValues<T>(data + (npts - 2) * dim, dim);
|
||||
// diskann::cout << "Finished reading bin file." << std::endl;
|
||||
}
|
||||
|
||||
#ifdef EXEC_ENV_OLS
|
||||
template<typename T>
|
||||
inline void load_bin(MemoryMappedFiles& files, const std::string& bin_file,
|
||||
T*& data, size_t& npts, size_t& dim) {
|
||||
diskann::cout << "Reading bin file " << bin_file.c_str() << " ..."
|
||||
<< std::endl;
|
||||
|
||||
auto fc = files.getContent(bin_file);
|
||||
|
||||
uint32_t t_npts, t_dim;
|
||||
uint32_t* contentAsIntPtr = (uint32_t*) (fc._content);
|
||||
t_npts = *(contentAsIntPtr);
|
||||
t_dim = *(contentAsIntPtr + 1);
|
||||
|
||||
npts = t_npts;
|
||||
dim = t_dim;
|
||||
|
||||
auto actual_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t);
|
||||
if (actual_file_size != fc._size) {
|
||||
std::stringstream stream;
|
||||
stream << "Error. File size mismatch. Actual size is " << fc._size
|
||||
<< " while expected size is " << actual_file_size
|
||||
<< " npts = " << npts << " dim = " << dim
|
||||
<< " size of <T>= " << sizeof(T) << std::endl;
|
||||
diskann::cout << stream.str();
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__,
|
||||
__LINE__);
|
||||
}
|
||||
|
||||
data =
|
||||
(T*) ((char*) fc._content + 2 * sizeof(uint32_t)); // No need to copy!
|
||||
}
|
||||
#endif
|
||||
|
||||
template<typename T>
|
||||
inline void load_bin(const std::string& bin_file, T*& data, size_t& npts,
|
||||
size_t& dim) {
|
||||
// OLS
|
||||
//_u64 read_blk_size = 64 * 1024 * 1024;
|
||||
// cached_ifstream reader(bin_file, read_blk_size);
|
||||
// size_t actual_file_size = reader.get_file_size();
|
||||
// END OLS
|
||||
diskann::cout << "Reading bin file " << bin_file.c_str() << " ..."
|
||||
<< std::endl;
|
||||
std::ifstream reader(bin_file, std::ios::binary | std::ios::ate);
|
||||
uint64_t fsize = reader.tellg();
|
||||
reader.seekg(0);
|
||||
|
||||
load_bin_impl<T>(reader, fsize, data, npts, dim);
|
||||
}
|
||||
// load_bin functions END
|
||||
|
||||
inline void load_truthset(const std::string& bin_file, uint32_t*& ids,
|
||||
float*& dists, size_t& npts, size_t& dim) {
|
||||
_u64 read_blk_size = 64 * 1024 * 1024;
|
||||
cached_ifstream reader(bin_file, read_blk_size);
|
||||
diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..."
|
||||
<< std::endl;
|
||||
size_t actual_file_size = reader.get_file_size();
|
||||
|
||||
int npts_i32, dim_i32;
|
||||
reader.read((char*) &npts_i32, sizeof(int));
|
||||
reader.read((char*) &dim_i32, sizeof(int));
|
||||
npts = (unsigned) npts_i32;
|
||||
dim = (unsigned) dim_i32;
|
||||
|
||||
diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "..."
|
||||
<< std::endl;
|
||||
|
||||
int truthset_type = -1; // 1 means truthset has ids and distances, 2 means
|
||||
// only ids, -1 is error
|
||||
size_t expected_file_size_with_dists =
|
||||
2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
|
||||
|
||||
if (actual_file_size == expected_file_size_with_dists)
|
||||
truthset_type = 1;
|
||||
|
||||
size_t expected_file_size_just_ids =
|
||||
npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
|
||||
|
||||
if (actual_file_size == expected_file_size_just_ids)
|
||||
truthset_type = 2;
|
||||
|
||||
if (truthset_type == -1) {
|
||||
std::stringstream stream;
|
||||
stream << "Error. File size mismatch. File should have bin format, with "
|
||||
"npts followed by ngt followed by npts*ngt ids and optionally "
|
||||
"followed by npts*ngt distance values; actual size: "
|
||||
<< actual_file_size
|
||||
<< ", expected: " << expected_file_size_with_dists << " or "
|
||||
<< expected_file_size_just_ids;
|
||||
diskann::cout << stream.str();
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__,
|
||||
__LINE__);
|
||||
}
|
||||
|
||||
ids = new uint32_t[npts * dim];
|
||||
reader.read((char*) ids, npts * dim * sizeof(uint32_t));
|
||||
|
||||
if (truthset_type == 1) {
|
||||
dists = new float[npts * dim];
|
||||
reader.read((char*) dists, npts * dim * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef EXEC_ENV_OLS
|
||||
template<typename T>
|
||||
inline void load_bin(MemoryMappedFiles& files, const std::string& bin_file,
|
||||
std::unique_ptr<T[]>& data, size_t& npts, size_t& dim) {
|
||||
T* ptr;
|
||||
load_bin<T>(files, bin_file, ptr, npts, dim);
|
||||
data.reset(ptr);
|
||||
}
|
||||
#endif
|
||||
|
||||
template<typename T>
|
||||
inline void load_bin(const std::string& bin_file, std::unique_ptr<T[]>& data,
|
||||
size_t& npts, size_t& dim) {
|
||||
T* ptr;
|
||||
load_bin<T>(bin_file, ptr, npts, dim);
|
||||
data.reset(ptr);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline void save_bin(const std::string& filename, T* data, size_t npts,
|
||||
size_t ndims) {
|
||||
std::ofstream writer(filename, std::ios::binary | std::ios::out);
|
||||
diskann::cout << "Writing bin: " << filename.c_str() << std::endl;
|
||||
int npts_i32 = (int) npts, ndims_i32 = (int) ndims;
|
||||
writer.write((char*) &npts_i32, sizeof(int));
|
||||
writer.write((char*) &ndims_i32, sizeof(int));
|
||||
diskann::cout << "bin: #pts = " << npts << ", #dims = " << ndims
|
||||
<< ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int)
|
||||
<< "B" << std::endl;
|
||||
|
||||
// data = new T[npts_u64 * ndims_u64];
|
||||
writer.write((char*) data, npts * ndims * sizeof(T));
|
||||
writer.close();
|
||||
diskann::cout << "Finished writing bin." << std::endl;
|
||||
}
|
||||
|
||||
// load_aligned_bin functions START
|
||||
|
||||
template<typename T>
|
||||
inline void load_aligned_bin_impl(std::basic_istream<char>& reader,
|
||||
size_t actual_file_size, T*& data,
|
||||
size_t& npts, size_t& dim,
|
||||
size_t& rounded_dim) {
|
||||
int npts_i32, dim_i32;
|
||||
reader.read((char*) &npts_i32, sizeof(int));
|
||||
reader.read((char*) &dim_i32, sizeof(int));
|
||||
npts = (unsigned) npts_i32;
|
||||
dim = (unsigned) dim_i32;
|
||||
|
||||
size_t expected_actual_file_size =
|
||||
npts * dim * sizeof(T) + 2 * sizeof(uint32_t);
|
||||
if (actual_file_size != expected_actual_file_size) {
|
||||
std::stringstream stream;
|
||||
stream << "Error. File size mismatch. Actual size is " << actual_file_size
|
||||
<< " while expected size is " << expected_actual_file_size
|
||||
<< " npts = " << npts << " dim = " << dim
|
||||
<< " size of <T>= " << sizeof(T) << std::endl;
|
||||
diskann::cout << stream.str() << std::endl;
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__,
|
||||
__LINE__);
|
||||
}
|
||||
rounded_dim = ROUND_UP(dim, 8);
|
||||
diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim
|
||||
<< ", aligned_dim = " << rounded_dim << "..." << std::flush;
|
||||
size_t allocSize = npts * rounded_dim * sizeof(T);
|
||||
diskann::cout << "allocating aligned memory, " << allocSize << " bytes..."
|
||||
<< std::flush;
|
||||
alloc_aligned(((void**) &data), allocSize, 8 * sizeof(T));
|
||||
diskann::cout << "done. Copying data..." << std::flush;
|
||||
|
||||
for (size_t i = 0; i < npts; i++) {
|
||||
reader.read((char*) (data + i * rounded_dim), dim * sizeof(T));
|
||||
memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T));
|
||||
}
|
||||
diskann::cout << " done." << std::endl;
|
||||
}
|
||||
|
||||
#ifdef EXEC_ENV_OLS
|
||||
template<typename T>
|
||||
inline void load_aligned_bin(MemoryMappedFiles& files,
|
||||
const std::string& bin_file, T*& data,
|
||||
size_t& npts, size_t& dim, size_t& rounded_dim) {
|
||||
diskann::cout << "Reading bin file " << bin_file << " ..." << std::flush;
|
||||
FileContent fc = files.getContent(bin_file);
|
||||
ContentBuf buf((char*) fc._content, fc._size);
|
||||
std::basic_istream<char> reader(&buf);
|
||||
|
||||
size_t actual_file_size = fc._size;
|
||||
load_aligned_bin_impl(reader, actual_file_size, data, npts, dim,
|
||||
rounded_dim);
|
||||
}
|
||||
#endif
|
||||
|
||||
template<typename T>
|
||||
inline void load_aligned_bin(const std::string& bin_file, T*& data,
|
||||
size_t& npts, size_t& dim, size_t& rounded_dim) {
|
||||
diskann::cout << "Reading bin file " << bin_file << " ..." << std::flush;
|
||||
// START OLS
|
||||
//_u64 read_blk_size = 64 * 1024 * 1024;
|
||||
// cached_ifstream reader(bin_file, read_blk_size);
|
||||
// size_t actual_file_size = reader.get_file_size();
|
||||
// END OLS
|
||||
|
||||
std::ifstream reader(bin_file, std::ios::binary | std::ios::ate);
|
||||
uint64_t fsize = reader.tellg();
|
||||
reader.seekg(0);
|
||||
|
||||
load_aligned_bin_impl(reader, fsize, data, npts, dim, rounded_dim);
|
||||
}
|
||||
|
||||
template<typename InType, typename OutType>
|
||||
void convert_types(const InType* srcmat, OutType* destmat, size_t npts,
|
||||
size_t dim) {
|
||||
#pragma omp parallel for schedule(static, 65536)
|
||||
for (int64_t i = 0; i < (_s64) npts; i++) {
|
||||
for (uint64_t j = 0; j < dim; j++) {
|
||||
destmat[i * dim + j] = (OutType) srcmat[i * dim + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// plain saves data as npts X ndims array into filename
|
||||
template<typename T>
|
||||
void save_Tvecs(const char* filename, T* data, size_t npts, size_t ndims) {
|
||||
std::string fname(filename);
|
||||
|
||||
// create cached ofstream with 64MB cache
|
||||
cached_ofstream writer(fname, 64 * 1048576);
|
||||
|
||||
unsigned dims_u32 = (unsigned) ndims;
|
||||
|
||||
// start writing
|
||||
for (uint64_t i = 0; i < npts; i++) {
|
||||
// write dims in u32
|
||||
writer.write((char*) &dims_u32, sizeof(unsigned));
|
||||
|
||||
// get cur point in data
|
||||
T* cur_pt = data + i * ndims;
|
||||
writer.write((char*) cur_pt, ndims * sizeof(T));
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE :: good efficiency when total_vec_size is integral multiple of 64
|
||||
inline void prefetch_vector(const char* vec, size_t vecsize) {
|
||||
size_t max_prefetch_size = (vecsize / 64) * 64;
|
||||
for (size_t d = 0; d < max_prefetch_size; d += 64)
|
||||
_mm_prefetch((const char*) vec + d, _MM_HINT_T0);
|
||||
}
|
||||
|
||||
// NOTE :: good efficiency when total_vec_size is integral multiple of 64
|
||||
inline void prefetch_vector_l2(const char* vec, size_t vecsize) {
|
||||
size_t max_prefetch_size = (vecsize / 64) * 64;
|
||||
for (size_t d = 0; d < max_prefetch_size; d += 64)
|
||||
_mm_prefetch((const char*) vec + d, _MM_HINT_T1);
|
||||
}
|
||||
}; // namespace diskann
|
||||
|
||||
struct PivotContainer {
|
||||
PivotContainer() = default;
|
||||
|
||||
PivotContainer(size_t pivo_id, float pivo_dist)
|
||||
: piv_id{pivo_id}, piv_dist{pivo_dist} {
|
||||
}
|
||||
|
||||
bool operator<(const PivotContainer& p) const {
|
||||
return p.piv_dist < piv_dist;
|
||||
}
|
||||
|
||||
bool operator>(const PivotContainer& p) const {
|
||||
return p.piv_dist > piv_dist;
|
||||
}
|
||||
|
||||
size_t piv_id;
|
||||
float piv_dist;
|
||||
};
|
||||
|
||||
inline bool file_exists(const std::string& name) {
|
||||
struct stat buffer;
|
||||
auto val = stat(name.c_str(), &buffer);
|
||||
diskann::cout << " Stat(" << name.c_str() << ") returned: " << val
|
||||
<< std::endl;
|
||||
return (val == 0);
|
||||
}
|
||||
|
||||
inline _u64 get_file_size(const std::string& fname) {
|
||||
std::ifstream reader(fname, std::ios::binary | std::ios::ate);
|
||||
if (!reader.fail() && reader.is_open()) {
|
||||
_u64 end_pos = reader.tellg();
|
||||
diskann::cout << " Tellg: " << reader.tellg() << " as u64: " << end_pos
|
||||
<< std::endl;
|
||||
reader.close();
|
||||
return end_pos;
|
||||
} else {
|
||||
diskann::cout << "Could not open file: " << fname << std::endl;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool validate_file_size(const std::string& name) {
|
||||
std::ifstream in(std::string(name), std::ios::binary);
|
||||
in.seekg(0, in.end);
|
||||
size_t actual_file_size = in.tellg();
|
||||
in.seekg(0, in.beg);
|
||||
size_t expected_file_size;
|
||||
in.read((char*) &expected_file_size, sizeof(uint64_t));
|
||||
if (actual_file_size != expected_file_size) {
|
||||
diskann::cout << "Error loading" << name << ". Expected "
|
||||
"size (metadata): "
|
||||
<< expected_file_size
|
||||
<< ", actual file size : " << actual_file_size
|
||||
<< ". Exitting." << std::endl;
|
||||
in.close();
|
||||
return false;
|
||||
}
|
||||
in.close();
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifdef _WINDOWS
|
||||
#include <intrin.h>
|
||||
#include <Psapi.h>
|
||||
|
||||
inline void printProcessMemory(const char* message) {
|
||||
PROCESS_MEMORY_COUNTERS counters;
|
||||
HANDLE h = GetCurrentProcess();
|
||||
GetProcessMemoryInfo(h, &counters, sizeof(counters));
|
||||
diskann::cout << message << " [Peaking Working Set size: "
|
||||
<< counters.PeakWorkingSetSize * 1.0 / (1024 * 1024 * 1024)
|
||||
<< "GB Working set size: "
|
||||
<< counters.WorkingSetSize * 1.0 / (1024 * 1024 * 1024)
|
||||
<< "GB Private bytes "
|
||||
<< counters.PagefileUsage * 1.0 / (1024 * 1024 * 1024) << "GB]"
|
||||
<< std::endl;
|
||||
}
|
||||
#else
|
||||
|
||||
// need to check and change this
|
||||
inline bool avx2Supported() {
|
||||
return true;
|
||||
}
|
||||
|
||||
inline void printProcessMemory(const char* message) {
|
||||
}
|
||||
#endif
|
||||
|
||||
extern bool AvxSupportedCPU;
|
||||
extern bool Avx2SupportedCPU;
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
#ifdef _WINDOWS
|
||||
#ifndef USE_BING_INFRA
|
||||
#include <Windows.h>
|
||||
#include <fcntl.h>
|
||||
#include <malloc.h>
|
||||
#include <minwinbase.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include "aligned_file_reader.h"
|
||||
#include "tsl/robin_map.h"
|
||||
#include "utils.h"
|
||||
|
||||
class WindowsAlignedFileReader : public AlignedFileReader {
|
||||
private:
|
||||
std::wstring m_filename;
|
||||
|
||||
protected:
|
||||
// virtual IOContext createContext();
|
||||
|
||||
public:
|
||||
WindowsAlignedFileReader(){};
|
||||
virtual ~WindowsAlignedFileReader(){};
|
||||
|
||||
// Open & close ops
|
||||
// Blocking calls
|
||||
virtual void open(const std::string &fname);
|
||||
virtual void close();
|
||||
|
||||
virtual void register_thread();
|
||||
virtual void deregister_thread() {
|
||||
}
|
||||
virtual IOContext &get_ctx();
|
||||
|
||||
// process batch of aligned requests in parallel
|
||||
// NOTE :: blocking call for the calling thread, but can thread-safe
|
||||
virtual void read(std::vector<AlignedRead> &read_reqs, IOContext &ctx);
|
||||
};
|
||||
#endif // USE_BING_INFRA
|
||||
#endif //_WINDOWS
|
|
@ -0,0 +1,12 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef _WINDOWS
|
||||
#define DISKANN_DLLEXPORT __declspec(dllexport)
|
||||
#define DISKANN_DLLIMPORT __declspec(dllimport)
|
||||
#else
|
||||
#define DISKANN_DLLEXPORT
|
||||
#define DISKANN_DLLIMPORT
|
||||
#endif
|
|
@ -1,8 +1,16 @@
|
|||
set(CMAKE_CXX_STANDARD 11)
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
file(GLOB_RECURSE CPP_SOURCES *.cpp)
|
||||
set(CMAKE_CXX_STANDARD 14)
|
||||
|
||||
add_library(${PROJECT_NAME} ${CPP_SOURCES})
|
||||
add_library(${PROJECT_NAME}_s STATIC ${CPP_SOURCES})
|
||||
|
||||
#install()
|
||||
if(MSVC)
|
||||
add_subdirectory(dll)
|
||||
else()
|
||||
#file(GLOB CPP_SOURCES *.cpp)
|
||||
set(CPP_SOURCES ann_exception.cpp aux_utils.cpp index.cpp
|
||||
linux_aligned_file_reader.cpp math_utils.cpp memory_mapper.cpp
|
||||
partition_and_pq.cpp logger.cpp utils.cpp)
|
||||
add_library(${PROJECT_NAME} ${CPP_SOURCES})
|
||||
add_library(${PROJECT_NAME}_s STATIC ${CPP_SOURCES})
|
||||
endif()
|
||||
install()
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "ann_exception.h"
|
||||
#include <sstream>
|
||||
|
||||
namespace diskann {
|
||||
ANNException::ANNException(const std::string& message, int errorCode)
|
||||
: _errorCode(errorCode), _message(message), _funcSig(""), _fileName(""),
|
||||
_lineNum(0) {
|
||||
}
|
||||
|
||||
ANNException::ANNException(const std::string& message, int errorCode,
|
||||
const std::string& funcSig,
|
||||
const std::string& fileName, unsigned lineNum)
|
||||
: ANNException(message, errorCode) {
|
||||
_funcSig = funcSig;
|
||||
_fileName = fileName;
|
||||
_lineNum = lineNum;
|
||||
}
|
||||
|
||||
std::string ANNException::message() const {
|
||||
std::stringstream sstream;
|
||||
|
||||
sstream << "Exception: " << _message;
|
||||
if (_funcSig != "")
|
||||
sstream << ". occurred at: " << _funcSig;
|
||||
if (_fileName != "" && _lineNum != 0)
|
||||
sstream << " defined in file: " << _fileName << " at line: " << _lineNum;
|
||||
if (_errorCode != -1)
|
||||
sstream << ". OS error code: " << std::hex << _errorCode;
|
||||
|
||||
return sstream.str();
|
||||
}
|
||||
|
||||
} // namespace diskann
|
|
@ -0,0 +1,825 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "logger.h"
|
||||
#include "aux_utils.h"
|
||||
#include "cached_io.h"
|
||||
#include "index.h"
|
||||
#include "mkl.h"
|
||||
#include "omp.h"
|
||||
#include "partition_and_pq.h"
|
||||
#include "percentile_stats.h"
|
||||
//#include "pq_flash_index.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace diskann {
|
||||
|
||||
double get_memory_budget(const std::string &mem_budget_str) {
|
||||
double mem_ram_budget = atof(mem_budget_str.c_str());
|
||||
double final_index_ram_limit = mem_ram_budget;
|
||||
if (mem_ram_budget - SPACE_FOR_CACHED_NODES_IN_GB >
|
||||
THRESHOLD_FOR_CACHING_IN_GB) { // slack for space used by cached
|
||||
// nodes
|
||||
final_index_ram_limit = mem_ram_budget - SPACE_FOR_CACHED_NODES_IN_GB;
|
||||
}
|
||||
return final_index_ram_limit * 1024 * 1024 * 1024;
|
||||
}
|
||||
|
||||
double calculate_recall(unsigned num_queries, unsigned *gold_std,
|
||||
float *gs_dist, unsigned dim_gs,
|
||||
unsigned *our_results, unsigned dim_or,
|
||||
unsigned recall_at) {
|
||||
double total_recall = 0;
|
||||
std::set<unsigned> gt, res;
|
||||
|
||||
for (size_t i = 0; i < num_queries; i++) {
|
||||
gt.clear();
|
||||
res.clear();
|
||||
unsigned *gt_vec = gold_std + dim_gs * i;
|
||||
unsigned *res_vec = our_results + dim_or * i;
|
||||
size_t tie_breaker = recall_at;
|
||||
if (gs_dist != nullptr) {
|
||||
tie_breaker = recall_at - 1;
|
||||
float *gt_dist_vec = gs_dist + dim_gs * i;
|
||||
while (tie_breaker < dim_gs &&
|
||||
gt_dist_vec[tie_breaker] == gt_dist_vec[recall_at - 1])
|
||||
tie_breaker++;
|
||||
}
|
||||
|
||||
gt.insert(gt_vec, gt_vec + tie_breaker);
|
||||
res.insert(res_vec, res_vec + recall_at);
|
||||
unsigned cur_recall = 0;
|
||||
for (auto &v : gt) {
|
||||
if (res.find(v) != res.end()) {
|
||||
cur_recall++;
|
||||
}
|
||||
}
|
||||
total_recall += cur_recall;
|
||||
}
|
||||
return total_recall / (num_queries) * (100.0 / recall_at);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
T *generateRandomWarmup(uint64_t warmup_num, uint64_t warmup_dim,
|
||||
uint64_t warmup_aligned_dim) {
|
||||
T *warmup = nullptr;
|
||||
warmup_num = 100000;
|
||||
diskann::cout << "Generating random warmup file with dim " << warmup_dim
|
||||
<< " and aligned dim " << warmup_aligned_dim << std::flush;
|
||||
diskann::alloc_aligned(((void **) &warmup),
|
||||
warmup_num * warmup_aligned_dim * sizeof(T),
|
||||
8 * sizeof(T));
|
||||
std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T));
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_int_distribution<> dis(-128, 127);
|
||||
for (uint32_t i = 0; i < warmup_num; i++) {
|
||||
for (uint32_t d = 0; d < warmup_dim; d++) {
|
||||
warmup[i * warmup_aligned_dim + d] = (T) dis(gen);
|
||||
}
|
||||
}
|
||||
diskann::cout << "..done" << std::endl;
|
||||
return warmup;
|
||||
}
|
||||
|
||||
#ifdef EXEC_ENV_OLS
|
||||
template<typename T>
|
||||
T *load_warmup(MemoryMappedFiles &files, const std::string &cache_warmup_file,
|
||||
uint64_t &warmup_num, uint64_t warmup_dim,
|
||||
uint64_t warmup_aligned_dim) {
|
||||
T * warmup = nullptr;
|
||||
uint64_t file_dim, file_aligned_dim;
|
||||
|
||||
if (files.fileExists(cache_warmup_file)) {
|
||||
diskann::load_aligned_bin<T>(files, cache_warmup_file, warmup, warmup_num,
|
||||
file_dim, file_aligned_dim);
|
||||
if (file_dim != warmup_dim || file_aligned_dim != warmup_aligned_dim) {
|
||||
std::stringstream stream;
|
||||
stream << "Mismatched dimensions in sample file. file_dim = "
|
||||
<< file_dim << " file_aligned_dim: " << file_aligned_dim
|
||||
<< " index_dim: " << warmup_dim
|
||||
<< " index_aligned_dim: " << warmup_aligned_dim << std::endl;
|
||||
throw diskann::ANNException(stream.str(), -1);
|
||||
}
|
||||
} else {
|
||||
warmup =
|
||||
generateRandomWarmup<T>(warmup_num, warmup_dim, warmup_aligned_dim);
|
||||
}
|
||||
return warmup;
|
||||
}
|
||||
#endif
|
||||
|
||||
template<typename T>
|
||||
T *load_warmup(const std::string &cache_warmup_file, uint64_t &warmup_num,
|
||||
uint64_t warmup_dim, uint64_t warmup_aligned_dim) {
|
||||
T * warmup = nullptr;
|
||||
uint64_t file_dim, file_aligned_dim;
|
||||
|
||||
if (file_exists(cache_warmup_file)) {
|
||||
diskann::load_aligned_bin<T>(cache_warmup_file, warmup, warmup_num,
|
||||
file_dim, file_aligned_dim);
|
||||
if (file_dim != warmup_dim || file_aligned_dim != warmup_aligned_dim) {
|
||||
std::stringstream stream;
|
||||
stream << "Mismatched dimensions in sample file. file_dim = "
|
||||
<< file_dim << " file_aligned_dim: " << file_aligned_dim
|
||||
<< " index_dim: " << warmup_dim
|
||||
<< " index_aligned_dim: " << warmup_aligned_dim << std::endl;
|
||||
throw diskann::ANNException(stream.str(), -1);
|
||||
}
|
||||
} else {
|
||||
warmup =
|
||||
generateRandomWarmup<T>(warmup_num, warmup_dim, warmup_aligned_dim);
|
||||
}
|
||||
return warmup;
|
||||
}
|
||||
|
||||
/***************************************************
|
||||
Support for Merging Many Vamana Indices
|
||||
***************************************************/
|
||||
|
||||
void read_idmap(const std::string &fname, std::vector<unsigned> &ivecs) {
|
||||
uint32_t npts32, dim;
|
||||
size_t actual_file_size = get_file_size(fname);
|
||||
std::ifstream reader(fname.c_str(), std::ios::binary);
|
||||
reader.read((char *) &npts32, sizeof(uint32_t));
|
||||
reader.read((char *) &dim, sizeof(uint32_t));
|
||||
if (dim != 1 ||
|
||||
actual_file_size !=
|
||||
((size_t) npts32) * sizeof(uint32_t) + 2 * sizeof(uint32_t)) {
|
||||
std::stringstream stream;
|
||||
stream << "Error reading idmap file. Check if the file is bin file with "
|
||||
"1 dimensional data. Actual: "
|
||||
<< actual_file_size
|
||||
<< ", expected: " << (size_t) npts32 + 2 * sizeof(uint32_t)
|
||||
<< std::endl;
|
||||
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__,
|
||||
__LINE__);
|
||||
}
|
||||
ivecs.resize(npts32);
|
||||
reader.read((char *) ivecs.data(), ((size_t) npts32) * sizeof(uint32_t));
|
||||
reader.close();
|
||||
}
|
||||
|
||||
int merge_shards(const std::string &vamana_prefix,
|
||||
const std::string &vamana_suffix,
|
||||
const std::string &idmaps_prefix,
|
||||
const std::string &idmaps_suffix, const _u64 nshards,
|
||||
unsigned max_degree, const std::string &output_vamana,
|
||||
const std::string &medoids_file) {
|
||||
// Read ID maps
|
||||
std::vector<std::string> vamana_names(nshards);
|
||||
std::vector<std::vector<unsigned>> idmaps(nshards);
|
||||
for (_u64 shard = 0; shard < nshards; shard++) {
|
||||
vamana_names[shard] =
|
||||
vamana_prefix + std::to_string(shard) + vamana_suffix;
|
||||
read_idmap(idmaps_prefix + std::to_string(shard) + idmaps_suffix,
|
||||
idmaps[shard]);
|
||||
}
|
||||
|
||||
// find max node id
|
||||
_u64 nnodes = 0;
|
||||
_u64 nelems = 0;
|
||||
for (auto &idmap : idmaps) {
|
||||
for (auto &id : idmap) {
|
||||
nnodes = std::max(nnodes, (_u64) id);
|
||||
}
|
||||
nelems += idmap.size();
|
||||
}
|
||||
nnodes++;
|
||||
diskann::cout << "# nodes: " << nnodes << ", max. degree: " << max_degree
|
||||
<< std::endl;
|
||||
|
||||
// compute inverse map: node -> shards
|
||||
std::vector<std::pair<unsigned, unsigned>> node_shard;
|
||||
node_shard.reserve(nelems);
|
||||
for (_u64 shard = 0; shard < nshards; shard++) {
|
||||
diskann::cout << "Creating inverse map -- shard #" << shard << std::endl;
|
||||
for (_u64 idx = 0; idx < idmaps[shard].size(); idx++) {
|
||||
_u64 node_id = idmaps[shard][idx];
|
||||
node_shard.push_back(std::make_pair((_u32) node_id, (_u32) shard));
|
||||
}
|
||||
}
|
||||
std::sort(node_shard.begin(), node_shard.end(), [](const auto &left,
|
||||
const auto &right) {
|
||||
return left.first < right.first ||
|
||||
(left.first == right.first && left.second < right.second);
|
||||
});
|
||||
diskann::cout << "Finished computing node -> shards map" << std::endl;
|
||||
|
||||
// create cached vamana readers
|
||||
std::vector<cached_ifstream> vamana_readers(nshards);
|
||||
for (_u64 i = 0; i < nshards; i++) {
|
||||
vamana_readers[i].open(vamana_names[i], 1024 * 1048576);
|
||||
size_t actual_file_size = get_file_size(vamana_names[i]);
|
||||
size_t expected_file_size;
|
||||
vamana_readers[i].read((char *) &expected_file_size, sizeof(uint64_t));
|
||||
if (actual_file_size != expected_file_size) {
|
||||
std::stringstream stream;
|
||||
stream << "Error in Vamana Index file " << vamana_names[i]
|
||||
<< " Actual file size: " << actual_file_size
|
||||
<< " does not match expected file size: " << expected_file_size
|
||||
<< std::endl;
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__,
|
||||
__LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
size_t merged_index_size = 16;
|
||||
// create cached vamana writers
|
||||
cached_ofstream diskann_writer(output_vamana, 1024 * 1048576);
|
||||
diskann_writer.write((char *) &merged_index_size, sizeof(uint64_t));
|
||||
|
||||
unsigned output_width = max_degree;
|
||||
unsigned max_input_width = 0;
|
||||
// read width from each vamana to advance buffer by sizeof(unsigned) bytes
|
||||
for (auto &reader : vamana_readers) {
|
||||
unsigned input_width;
|
||||
reader.read((char *) &input_width, sizeof(unsigned));
|
||||
max_input_width =
|
||||
input_width > max_input_width ? input_width : max_input_width;
|
||||
}
|
||||
|
||||
diskann::cout << "Max input width: " << max_input_width
|
||||
<< ", output width: " << output_width << std::endl;
|
||||
|
||||
diskann_writer.write((char *) &output_width, sizeof(unsigned));
|
||||
std::ofstream medoid_writer(medoids_file.c_str(), std::ios::binary);
|
||||
_u32 nshards_u32 = (_u32) nshards;
|
||||
_u32 one_val = 1;
|
||||
medoid_writer.write((char *) &nshards_u32, sizeof(uint32_t));
|
||||
medoid_writer.write((char *) &one_val, sizeof(uint32_t));
|
||||
|
||||
for (_u64 shard = 0; shard < nshards; shard++) {
|
||||
unsigned medoid;
|
||||
// read medoid
|
||||
vamana_readers[shard].read((char *) &medoid, sizeof(unsigned));
|
||||
// rename medoid
|
||||
medoid = idmaps[shard][medoid];
|
||||
|
||||
medoid_writer.write((char *) &medoid, sizeof(uint32_t));
|
||||
// write renamed medoid
|
||||
if (shard == (nshards - 1)) //--> uncomment if running hierarchical
|
||||
diskann_writer.write((char *) &medoid, sizeof(unsigned));
|
||||
}
|
||||
medoid_writer.close();
|
||||
|
||||
diskann::cout << "Starting merge" << std::endl;
|
||||
|
||||
// Gopal. random_shuffle() is deprecated.
|
||||
std::random_device rng;
|
||||
std::mt19937 urng(rng());
|
||||
|
||||
std::vector<bool> nhood_set(nnodes, 0);
|
||||
std::vector<unsigned> final_nhood;
|
||||
|
||||
unsigned nnbrs = 0, shard_nnbrs = 0;
|
||||
unsigned cur_id = 0;
|
||||
for (const auto &id_shard : node_shard) {
|
||||
unsigned node_id = id_shard.first;
|
||||
unsigned shard_id = id_shard.second;
|
||||
if (cur_id < node_id) {
|
||||
// Gopal. random_shuffle() is deprecated.
|
||||
std::shuffle(final_nhood.begin(), final_nhood.end(), urng);
|
||||
nnbrs =
|
||||
(unsigned) (std::min)(final_nhood.size(), (uint64_t) max_degree);
|
||||
// write into merged ofstream
|
||||
diskann_writer.write((char *) &nnbrs, sizeof(unsigned));
|
||||
diskann_writer.write((char *) final_nhood.data(),
|
||||
nnbrs * sizeof(unsigned));
|
||||
merged_index_size += (sizeof(unsigned) + nnbrs * sizeof(unsigned));
|
||||
if (cur_id % 499999 == 1) {
|
||||
diskann::cout << "." << std::flush;
|
||||
}
|
||||
cur_id = node_id;
|
||||
nnbrs = 0;
|
||||
for (auto &p : final_nhood)
|
||||
nhood_set[p] = 0;
|
||||
final_nhood.clear();
|
||||
}
|
||||
// read from shard_id ifstream
|
||||
vamana_readers[shard_id].read((char *) &shard_nnbrs, sizeof(unsigned));
|
||||
std::vector<unsigned> shard_nhood(shard_nnbrs);
|
||||
vamana_readers[shard_id].read((char *) shard_nhood.data(),
|
||||
shard_nnbrs * sizeof(unsigned));
|
||||
|
||||
// rename nodes
|
||||
for (_u64 j = 0; j < shard_nnbrs; j++) {
|
||||
if (nhood_set[idmaps[shard_id][shard_nhood[j]]] == 0) {
|
||||
nhood_set[idmaps[shard_id][shard_nhood[j]]] = 1;
|
||||
final_nhood.emplace_back(idmaps[shard_id][shard_nhood[j]]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Gopal. random_shuffle() is deprecated.
|
||||
std::shuffle(final_nhood.begin(), final_nhood.end(), urng);
|
||||
nnbrs = (unsigned) (std::min)(final_nhood.size(), (uint64_t) max_degree);
|
||||
// write into merged ofstream
|
||||
diskann_writer.write((char *) &nnbrs, sizeof(unsigned));
|
||||
diskann_writer.write((char *) final_nhood.data(), nnbrs * sizeof(unsigned));
|
||||
merged_index_size += (sizeof(unsigned) + nnbrs * sizeof(unsigned));
|
||||
for (auto &p : final_nhood)
|
||||
nhood_set[p] = 0;
|
||||
final_nhood.clear();
|
||||
|
||||
diskann::cout << "Expected size: " << merged_index_size << std::endl;
|
||||
|
||||
diskann_writer.reset();
|
||||
diskann_writer.write((char *) &merged_index_size, sizeof(uint64_t));
|
||||
|
||||
diskann::cout << "Finished merge" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
int build_merged_vamana_index(std::string base_file,
|
||||
diskann::Metric _compareMetric, unsigned L,
|
||||
unsigned R, double sampling_rate,
|
||||
double ram_budget, std::string mem_index_path,
|
||||
std::string medoids_file,
|
||||
std::string centroids_file) {
|
||||
size_t base_num, base_dim;
|
||||
diskann::get_bin_metadata(base_file, base_num, base_dim);
|
||||
|
||||
double full_index_ram =
|
||||
ESTIMATE_RAM_USAGE(base_num, base_dim, sizeof(T), R);
|
||||
if (full_index_ram < ram_budget * 1024 * 1024 * 1024) {
|
||||
diskann::cout << "Full index fits in RAM, building in one shot"
|
||||
<< std::endl;
|
||||
diskann::Parameters paras;
|
||||
paras.Set<unsigned>("L", (unsigned) L);
|
||||
paras.Set<unsigned>("R", (unsigned) R);
|
||||
paras.Set<unsigned>("C", 750);
|
||||
paras.Set<float>("alpha", 2.0f);
|
||||
paras.Set<unsigned>("num_rnds", 2);
|
||||
paras.Set<bool>("saturate_graph", 1);
|
||||
paras.Set<std::string>("save_path", mem_index_path);
|
||||
|
||||
std::unique_ptr<diskann::Index<T>> _pvamanaIndex =
|
||||
std::unique_ptr<diskann::Index<T>>(
|
||||
new diskann::Index<T>(_compareMetric, base_file.c_str()));
|
||||
_pvamanaIndex->build(paras);
|
||||
_pvamanaIndex->save(mem_index_path.c_str());
|
||||
std::remove(medoids_file.c_str());
|
||||
std::remove(centroids_file.c_str());
|
||||
return 0;
|
||||
}
|
||||
std::string merged_index_prefix = mem_index_path + "_tempFiles";
|
||||
int num_parts =
|
||||
partition_with_ram_budget<T>(base_file, sampling_rate, ram_budget,
|
||||
2 * R / 3, merged_index_prefix, 2);
|
||||
|
||||
std::string cur_centroid_filepath = merged_index_prefix + "_centroids.bin";
|
||||
std::rename(cur_centroid_filepath.c_str(), centroids_file.c_str());
|
||||
|
||||
for (int p = 0; p < num_parts; p++) {
|
||||
std::string shard_base_file =
|
||||
merged_index_prefix + "_subshard-" + std::to_string(p) + ".bin";
|
||||
std::string shard_index_file =
|
||||
merged_index_prefix + "_subshard-" + std::to_string(p) + "_mem.index";
|
||||
|
||||
diskann::Parameters paras;
|
||||
paras.Set<unsigned>("L", L);
|
||||
paras.Set<unsigned>("R", (2 * (R / 3)));
|
||||
paras.Set<unsigned>("C", 750);
|
||||
paras.Set<float>("alpha", 2.0f);
|
||||
paras.Set<unsigned>("num_rnds", 2);
|
||||
paras.Set<bool>("saturate_graph", 1);
|
||||
paras.Set<std::string>("save_path", shard_index_file);
|
||||
|
||||
std::unique_ptr<diskann::Index<T>> _pvamanaIndex =
|
||||
std::unique_ptr<diskann::Index<T>>(
|
||||
new diskann::Index<T>(_compareMetric, shard_base_file.c_str()));
|
||||
_pvamanaIndex->build(paras);
|
||||
_pvamanaIndex->save(shard_index_file.c_str());
|
||||
}
|
||||
|
||||
diskann::merge_shards(merged_index_prefix + "_subshard-", "_mem.index",
|
||||
merged_index_prefix + "_subshard-", "_ids_uint32.bin",
|
||||
num_parts, R, mem_index_path, medoids_file);
|
||||
|
||||
// delete tempFiles
|
||||
for (int p = 0; p < num_parts; p++) {
|
||||
std::string shard_base_file =
|
||||
merged_index_prefix + "_subshard-" + std::to_string(p) + ".bin";
|
||||
std::string shard_id_file = merged_index_prefix + "_subshard-" +
|
||||
std::to_string(p) + "_ids_uint32.bin";
|
||||
std::string shard_index_file =
|
||||
merged_index_prefix + "_subshard-" + std::to_string(p) + "_mem.index";
|
||||
std::remove(shard_base_file.c_str());
|
||||
std::remove(shard_id_file.c_str());
|
||||
std::remove(shard_index_file.c_str());
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// General purpose support for DiskANN interface
|
||||
//
|
||||
//
|
||||
|
||||
// optimizes the beamwidth to maximize QPS for a given L_search subject to
|
||||
// 99.9 latency not blowing up
|
||||
// template<typename T>
|
||||
// uint32_t optimize_beamwidth(
|
||||
// std::unique_ptr<diskann::PQFlashIndex<T>> &pFlashIndex, T
|
||||
// *tuning_sample,
|
||||
// _u64 tuning_sample_num, _u64 tuning_sample_aligned_dim, uint32_t L,
|
||||
// uint32_t nthreads, uint32_t start_bw) {
|
||||
// uint32_t cur_bw = start_bw;
|
||||
// double max_qps = 0;
|
||||
// uint32_t best_bw = start_bw;
|
||||
// bool stop_flag = false;
|
||||
|
||||
// while (!stop_flag) {
|
||||
// std::vector<uint64_t> tuning_sample_result_ids_64(tuning_sample_num,
|
||||
// 0);
|
||||
// std::vector<float> tuning_sample_result_dists(tuning_sample_num,
|
||||
// 0);
|
||||
// diskann::QueryStats * stats = new
|
||||
// diskann::QueryStats[tuning_sample_num];
|
||||
|
||||
// auto s = std::chrono::high_resolution_clock::now();
|
||||
// #pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads)
|
||||
// for (_s64 i = 0; i < (int64_t) tuning_sample_num; i++) {
|
||||
// pFlashIndex->cached_beam_search(
|
||||
// tuning_sample + (i * tuning_sample_aligned_dim), 1, L,
|
||||
// tuning_sample_result_ids_64.data() + (i * 1),
|
||||
// tuning_sample_result_dists.data() + (i * 1), cur_bw, stats +
|
||||
// i);
|
||||
// }
|
||||
// auto e = std::chrono::high_resolution_clock::now();
|
||||
// std::chrono::duration<double> diff = e - s;
|
||||
// double qps = (1.0f * tuning_sample_num) / (1.0f * diff.count());
|
||||
|
||||
// double lat_999 = diskann::get_percentile_stats(
|
||||
// stats, tuning_sample_num, 0.999,
|
||||
// [](const diskann::QueryStats &stats) { return stats.total_us; });
|
||||
|
||||
// double mean_latency = diskann::get_mean_stats(
|
||||
// stats, tuning_sample_num,
|
||||
// [](const diskann::QueryStats &stats) { return stats.total_us; });
|
||||
|
||||
// // diskann::cout << "For bw: " << cur_bw << " qps: " << qps
|
||||
// // << " max_qps: " << max_qps << " mean_lat: " <<
|
||||
// mean_latency
|
||||
// // << " lat_999: " << lat_999 << std::endl;
|
||||
|
||||
// if (qps > max_qps && lat_999 < (15000) + mean_latency * 2) {
|
||||
// // if (qps > max_qps) {
|
||||
// max_qps = qps;
|
||||
// best_bw = cur_bw;
|
||||
// // diskann::cout<<"cur_bw: " << cur_bw <<", qps: " << qps
|
||||
// <<",
|
||||
// // mean_lat: " << mean_latency/1000<<", 99.9lat: " <<
|
||||
// // lat_999/1000<<std::endl;
|
||||
// cur_bw = (uint32_t)(std::ceil)((float) cur_bw * 1.1);
|
||||
// } else {
|
||||
// stop_flag = true;
|
||||
// // diskann::cout << "Stopping at bw: " << best_bw << " max_qps: "
|
||||
// <<
|
||||
// // max_qps
|
||||
// // << std::endl;
|
||||
// // diskann::cout<<"cur_bw: " << cur_bw <<", qps: " << qps
|
||||
// <<",
|
||||
// // mean_lat: " << mean_latency/1000<<", 99.9lat: " <<
|
||||
// // lat_999/1000<<std::endl;
|
||||
// }
|
||||
// if (cur_bw > 64)
|
||||
// stop_flag = true;
|
||||
|
||||
// delete[] stats;
|
||||
// }
|
||||
// return best_bw;
|
||||
// }
|
||||
|
||||
// template<typename T>
|
||||
// void create_disk_layout(const std::string base_file,
|
||||
// const std::string mem_index_file,
|
||||
// const std::string output_file) {
|
||||
// unsigned npts, ndims;
|
||||
|
||||
// // amount to read or write in one shot
|
||||
// _u64 read_blk_size = 64 * 1024 * 1024;
|
||||
// _u64 write_blk_size = read_blk_size;
|
||||
// cached_ifstream base_reader(base_file, read_blk_size);
|
||||
// base_reader.read((char *) &npts, sizeof(uint32_t));
|
||||
// base_reader.read((char *) &ndims, sizeof(uint32_t));
|
||||
|
||||
// size_t npts_64, ndims_64;
|
||||
// npts_64 = npts;
|
||||
// ndims_64 = ndims;
|
||||
|
||||
// // create cached reader + writer
|
||||
// size_t actual_file_size = get_file_size(mem_index_file);
|
||||
// cached_ifstream vamana_reader(mem_index_file, read_blk_size);
|
||||
// cached_ofstream diskann_writer(output_file, write_blk_size);
|
||||
|
||||
// // metadata: width, medoid
|
||||
// unsigned width_u32, medoid_u32;
|
||||
// size_t index_file_size;
|
||||
|
||||
// vamana_reader.read((char *) &index_file_size, sizeof(uint64_t));
|
||||
// if (index_file_size != actual_file_size) {
|
||||
// std::stringstream stream;
|
||||
// stream << "Vamana Index file size does not match expected size per "
|
||||
// "meta-data."
|
||||
// << " file size from file: " << index_file_size
|
||||
// << " actual file size: " << actual_file_size << std::endl;
|
||||
|
||||
// throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__,
|
||||
// __LINE__);
|
||||
// }
|
||||
|
||||
// vamana_reader.read((char *) &width_u32, sizeof(unsigned));
|
||||
// vamana_reader.read((char *) &medoid_u32, sizeof(unsigned));
|
||||
|
||||
// // compute
|
||||
// _u64 medoid, max_node_len, nnodes_per_sector;
|
||||
// npts_64 = (_u64) npts;
|
||||
// medoid = (_u64) medoid_u32;
|
||||
// max_node_len =
|
||||
// (((_u64) width_u32 + 1) * sizeof(unsigned)) + (ndims_64 * sizeof(T));
|
||||
// nnodes_per_sector = SECTOR_LEN / max_node_len;
|
||||
|
||||
// diskann::cout << "medoid: " << medoid << "B" << std::endl;
|
||||
// diskann::cout << "max_node_len: " << max_node_len << "B" << std::endl;
|
||||
// diskann::cout << "nnodes_per_sector: " << nnodes_per_sector << "B"
|
||||
// << std::endl;
|
||||
|
||||
// // SECTOR_LEN buffer for each sector
|
||||
// std::unique_ptr<char[]> sector_buf =
|
||||
// std::make_unique<char[]>(SECTOR_LEN);
|
||||
// std::unique_ptr<char[]> node_buf =
|
||||
// std::make_unique<char[]>(max_node_len);
|
||||
// unsigned &nnbrs = *(unsigned *) (node_buf.get() + ndims_64 * sizeof(T));
|
||||
// unsigned *nhood_buf =
|
||||
// (unsigned *) (node_buf.get() + (ndims_64 * sizeof(T)) +
|
||||
// sizeof(unsigned));
|
||||
|
||||
// // number of sectors (1 for meta data)
|
||||
// _u64 n_sectors = ROUND_UP(npts_64, nnodes_per_sector) /
|
||||
// nnodes_per_sector;
|
||||
// _u64 disk_index_file_size = (n_sectors + 1) * SECTOR_LEN;
|
||||
// // write first sector with metadata
|
||||
// *(_u64 *) (sector_buf.get() + 0 * sizeof(_u64)) = disk_index_file_size;
|
||||
// *(_u64 *) (sector_buf.get() + 1 * sizeof(_u64)) = npts_64;
|
||||
// *(_u64 *) (sector_buf.get() + 2 * sizeof(_u64)) = medoid;
|
||||
// *(_u64 *) (sector_buf.get() + 3 * sizeof(_u64)) = max_node_len;
|
||||
// *(_u64 *) (sector_buf.get() + 4 * sizeof(_u64)) = nnodes_per_sector;
|
||||
// diskann_writer.write(sector_buf.get(), SECTOR_LEN);
|
||||
|
||||
// std::unique_ptr<T[]> cur_node_coords = std::make_unique<T[]>(ndims_64);
|
||||
// diskann::cout << "# sectors: " << n_sectors << std::endl;
|
||||
// _u64 cur_node_id = 0;
|
||||
// for (_u64 sector = 0; sector < n_sectors; sector++) {
|
||||
// if (sector % 100000 == 0) {
|
||||
// diskann::cout << "Sector #" << sector << "written" << std::endl;
|
||||
// }
|
||||
// memset(sector_buf.get(), 0, SECTOR_LEN);
|
||||
// for (_u64 sector_node_id = 0;
|
||||
// sector_node_id < nnodes_per_sector && cur_node_id < npts_64;
|
||||
// sector_node_id++) {
|
||||
// memset(node_buf.get(), 0, max_node_len);
|
||||
// // read cur node's nnbrs
|
||||
// vamana_reader.read((char *) &nnbrs, sizeof(unsigned));
|
||||
|
||||
// // sanity checks on nnbrs
|
||||
// assert(nnbrs > 0);
|
||||
// assert(nnbrs <= width_u32);
|
||||
|
||||
// // read node's nhood
|
||||
// vamana_reader.read((char *) nhood_buf, nnbrs * sizeof(unsigned));
|
||||
|
||||
// // write coords of node first
|
||||
// // T *node_coords = data + ((_u64) ndims_64 * cur_node_id);
|
||||
// base_reader.read((char *) cur_node_coords.get(), sizeof(T) *
|
||||
// ndims_64);
|
||||
// memcpy(node_buf.get(), cur_node_coords.get(), ndims_64 * sizeof(T));
|
||||
|
||||
// // write nnbrs
|
||||
// *(unsigned *) (node_buf.get() + ndims_64 * sizeof(T)) = nnbrs;
|
||||
|
||||
// // write nhood next
|
||||
// memcpy(node_buf.get() + ndims_64 * sizeof(T) + sizeof(unsigned),
|
||||
// nhood_buf, nnbrs * sizeof(unsigned));
|
||||
|
||||
// // get offset into sector_buf
|
||||
// char *sector_node_buf =
|
||||
// sector_buf.get() + (sector_node_id * max_node_len);
|
||||
|
||||
// // copy node buf into sector_node_buf
|
||||
// memcpy(sector_node_buf, node_buf.get(), max_node_len);
|
||||
// cur_node_id++;
|
||||
// }
|
||||
// // flush sector to disk
|
||||
// diskann_writer.write(sector_buf.get(), SECTOR_LEN);
|
||||
// }
|
||||
// diskann::cout << "Output file written." << std::endl;
|
||||
// }
|
||||
|
||||
// template<typename T>
|
||||
// bool build_disk_index(const char *dataFilePath, const char *indexFilePath,
|
||||
// const char * indexBuildParameters,
|
||||
// diskann::Metric _compareMetric) {
|
||||
// std::stringstream parser;
|
||||
// parser << std::string(indexBuildParameters);
|
||||
// std::string cur_param;
|
||||
// std::vector<std::string> param_list;
|
||||
// while (parser >> cur_param)
|
||||
// param_list.push_back(cur_param);
|
||||
|
||||
// if (param_list.size() != 5) {
|
||||
// diskann::cout
|
||||
// << "Correct usage of parameters is R (max degree) "
|
||||
// "L (indexing list size, better if >= R) B (RAM limit of final "
|
||||
// "index in "
|
||||
// "GB) M (memory limit while indexing) T (number of threads for "
|
||||
// "indexing)"
|
||||
// << std::endl;
|
||||
// return false;
|
||||
// }
|
||||
|
||||
// std::string index_prefix_path(indexFilePath);
|
||||
// std::string pq_pivots_path = index_prefix_path + "_pq_pivots.bin";
|
||||
// std::string pq_compressed_vectors_path =
|
||||
// index_prefix_path + "_pq_compressed.bin";
|
||||
// std::string mem_index_path = index_prefix_path + "_mem.index";
|
||||
// std::string disk_index_path = index_prefix_path + "_disk.index";
|
||||
// std::string medoids_path = disk_index_path + "_medoids.bin";
|
||||
// std::string centroids_path = disk_index_path + "_centroids.bin";
|
||||
// std::string sample_base_prefix = index_prefix_path + "_sample";
|
||||
|
||||
// unsigned R = (unsigned) atoi(param_list[0].c_str());
|
||||
// unsigned L = (unsigned) atoi(param_list[1].c_str());
|
||||
|
||||
// double final_index_ram_limit = get_memory_budget(param_list[2]);
|
||||
// if (final_index_ram_limit <= 0) {
|
||||
// std::cerr << "Insufficient memory budget (or string was not in right "
|
||||
// "format). Should be > 0."
|
||||
// << std::endl;
|
||||
// return false;
|
||||
// }
|
||||
// double indexing_ram_budget = (float) atof(param_list[3].c_str());
|
||||
// if (indexing_ram_budget <= 0) {
|
||||
// std::cerr << "Not building index. Please provide more RAM budget"
|
||||
// << std::endl;
|
||||
// return false;
|
||||
// }
|
||||
// _u32 num_threads = (_u32) atoi(param_list[4].c_str());
|
||||
|
||||
// if (num_threads != 0) {
|
||||
// omp_set_num_threads(num_threads);
|
||||
// mkl_set_num_threads(num_threads);
|
||||
// }
|
||||
|
||||
// diskann::cout << "Starting index build: R=" << R << " L=" << L
|
||||
// << " Query RAM budget: " << final_index_ram_limit
|
||||
// << " Indexing ram budget: " << indexing_ram_budget
|
||||
// << " T: " << num_threads << std::endl;
|
||||
|
||||
// auto s = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// size_t points_num, dim;
|
||||
|
||||
// diskann::get_bin_metadata(dataFilePath, points_num, dim);
|
||||
|
||||
// size_t num_pq_chunks =
|
||||
// (size_t)(std::floor)(_u64(final_index_ram_limit / points_num));
|
||||
|
||||
// num_pq_chunks = num_pq_chunks <= 0 ? 1 : num_pq_chunks;
|
||||
// num_pq_chunks = num_pq_chunks > dim ? dim : num_pq_chunks;
|
||||
// num_pq_chunks =
|
||||
// num_pq_chunks > MAX_PQ_CHUNKS ? MAX_PQ_CHUNKS : num_pq_chunks;
|
||||
|
||||
// diskann::cout << "Compressing " << dim << "-dimensional data into "
|
||||
// << num_pq_chunks << " bytes per vector." << std::endl;
|
||||
|
||||
// size_t train_size, train_dim;
|
||||
// float *train_data;
|
||||
|
||||
// double p_val = ((double) TRAINING_SET_SIZE / (double) points_num);
|
||||
// // generates random sample and sets it to train_data and updates
|
||||
// train_size
|
||||
// gen_random_slice<T>(dataFilePath, p_val, train_data, train_size,
|
||||
// train_dim);
|
||||
|
||||
// diskann::cout << "Training data loaded of size " << train_size <<
|
||||
// std::endl;
|
||||
|
||||
// generate_pq_pivots(train_data, train_size, (uint32_t) dim, 256,
|
||||
// (uint32_t) num_pq_chunks, 15, pq_pivots_path);
|
||||
// generate_pq_data_from_pivots<T>(dataFilePath, 256, (uint32_t)
|
||||
// num_pq_chunks,
|
||||
// pq_pivots_path,
|
||||
// pq_compressed_vectors_path);
|
||||
|
||||
// delete[] train_data;
|
||||
|
||||
// train_data = nullptr;
|
||||
|
||||
// diskann::build_merged_vamana_index<T>(
|
||||
// dataFilePath, _compareMetric, L, R, p_val, indexing_ram_budget,
|
||||
// mem_index_path, medoids_path, centroids_path);
|
||||
|
||||
// diskann::create_disk_layout<T>(dataFilePath, mem_index_path,
|
||||
// disk_index_path);
|
||||
|
||||
// double sample_sampling_rate = (150000.0 / points_num);
|
||||
// gen_random_slice<T>(dataFilePath, sample_base_prefix,
|
||||
// sample_sampling_rate);
|
||||
|
||||
// std::remove(mem_index_path.c_str());
|
||||
|
||||
// auto e =
|
||||
// std::chrono::high_resolution_clock::now();
|
||||
// std::chrono::duration<double> diff = e - s;
|
||||
// diskann::cout << "Indexing time: " << diff.count() << std::endl;
|
||||
|
||||
// return true;
|
||||
// }
|
||||
|
||||
// template DISKANN_DLLEXPORT void create_disk_layout<int8_t>(
|
||||
// const std::string base_file, const std::string mem_index_file,
|
||||
// const std::string output_file);
|
||||
|
||||
// template DISKANN_DLLEXPORT void create_disk_layout<uint8_t>(
|
||||
// const std::string base_file, const std::string mem_index_file,
|
||||
// const std::string output_file);
|
||||
// template DISKANN_DLLEXPORT void create_disk_layout<float>(
|
||||
// const std::string base_file, const std::string mem_index_file,
|
||||
// const std::string output_file);
|
||||
|
||||
template DISKANN_DLLEXPORT int8_t *load_warmup<int8_t>(
|
||||
const std::string &cache_warmup_file, uint64_t &warmup_num,
|
||||
uint64_t warmup_dim, uint64_t warmup_aligned_dim);
|
||||
template DISKANN_DLLEXPORT uint8_t *load_warmup<uint8_t>(
|
||||
const std::string &cache_warmup_file, uint64_t &warmup_num,
|
||||
uint64_t warmup_dim, uint64_t warmup_aligned_dim);
|
||||
template DISKANN_DLLEXPORT float *load_warmup<float>(
|
||||
const std::string &cache_warmup_file, uint64_t &warmup_num,
|
||||
uint64_t warmup_dim, uint64_t warmup_aligned_dim);
|
||||
|
||||
#ifdef EXEC_ENV_OLS
|
||||
template DISKANN_DLLEXPORT int8_t *load_warmup<int8_t>(
|
||||
MemoryMappedFiles &files, const std::string &cache_warmup_file,
|
||||
uint64_t &warmup_num, uint64_t warmup_dim, uint64_t warmup_aligned_dim);
|
||||
template DISKANN_DLLEXPORT uint8_t *load_warmup<uint8_t>(
|
||||
MemoryMappedFiles &files, const std::string &cache_warmup_file,
|
||||
uint64_t &warmup_num, uint64_t warmup_dim, uint64_t warmup_aligned_dim);
|
||||
template DISKANN_DLLEXPORT float *load_warmup<float>(
|
||||
MemoryMappedFiles &files, const std::string &cache_warmup_file,
|
||||
uint64_t &warmup_num, uint64_t warmup_dim, uint64_t warmup_aligned_dim);
|
||||
#endif
|
||||
|
||||
// template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<int8_t>(
|
||||
// std::unique_ptr<diskann::PQFlashIndex<int8_t>> &pFlashIndex,
|
||||
// int8_t *tuning_sample, _u64 tuning_sample_num,
|
||||
// _u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
|
||||
// uint32_t start_bw);
|
||||
// template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<uint8_t>(
|
||||
// std::unique_ptr<diskann::PQFlashIndex<uint8_t>> &pFlashIndex,
|
||||
// uint8_t *tuning_sample, _u64 tuning_sample_num,
|
||||
// _u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
|
||||
// uint32_t start_bw);
|
||||
// template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<float>(
|
||||
// std::unique_ptr<diskann::PQFlashIndex<float>> &pFlashIndex,
|
||||
// float *tuning_sample, _u64 tuning_sample_num,
|
||||
// _u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
|
||||
// uint32_t start_bw);
|
||||
|
||||
// template DISKANN_DLLEXPORT bool build_disk_index<int8_t>(
|
||||
// const char *dataFilePath, const char *indexFilePath,
|
||||
// const char *indexBuildParameters, diskann::Metric _compareMetric);
|
||||
// template DISKANN_DLLEXPORT bool build_disk_index<uint8_t>(
|
||||
// const char *dataFilePath, const char *indexFilePath,
|
||||
// const char *indexBuildParameters, diskann::Metric _compareMetric);
|
||||
// template DISKANN_DLLEXPORT bool build_disk_index<float>(
|
||||
// const char *dataFilePath, const char *indexFilePath,
|
||||
// const char *indexBuildParameters, diskann::Metric _compareMetric);
|
||||
|
||||
template DISKANN_DLLEXPORT int build_merged_vamana_index<int8_t>(
|
||||
std::string base_file, diskann::Metric _compareMetric, unsigned L,
|
||||
unsigned R, double sampling_rate, double ram_budget,
|
||||
std::string mem_index_path, std::string medoids_path,
|
||||
std::string centroids_file);
|
||||
template DISKANN_DLLEXPORT int build_merged_vamana_index<float>(
|
||||
std::string base_file, diskann::Metric _compareMetric, unsigned L,
|
||||
unsigned R, double sampling_rate, double ram_budget,
|
||||
std::string mem_index_path, std::string medoids_path,
|
||||
std::string centroids_file);
|
||||
template DISKANN_DLLEXPORT int build_merged_vamana_index<uint8_t>(
|
||||
std::string base_file, diskann::Metric _compareMetric, unsigned L,
|
||||
unsigned R, double sampling_rate, double ram_budget,
|
||||
std::string mem_index_path, std::string medoids_path,
|
||||
std::string centroids_file);
|
||||
}; // namespace diskann
|
|
@ -0,0 +1,54 @@
|
|||
add_library(${PROJECT_NAME} SHARED dllmain.cpp ../partition_and_pq.cpp ../logger.cpp ../utils.cpp ../index.cpp ../math_utils.cpp ../aux_utils.cpp ../ann_exception.cpp)
|
||||
if (MSVC)
|
||||
add_definitions(-D_USRDLL -D_WINDLL -DDISKANN_DLL)
|
||||
add_compile_options(/MD)
|
||||
target_link_options(${PROJECT_NAME} PRIVATE /DLL /MACHINE:X64 /DEBUG:FULL "/INCLUDE:_tcmalloc")
|
||||
target_link_options(${PROJECT_NAME} PRIVATE $<$<CONFIG:Debug>:/IMPLIB:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/${PROJECT_NAME}.lib>
|
||||
$<$<CONFIG:Release>:/IMPLIB:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/${PROJECT_NAME}.lib>
|
||||
)
|
||||
target_link_libraries(${PROJECT_NAME} debug ${PROJECT_SOURCE_DIR}/dependencies/windows/tcmalloc/libtcmalloc_minimal.lib)
|
||||
target_link_libraries(${PROJECT_NAME} optimized ${PROJECT_SOURCE_DIR}/dependencies/windows/tcmalloc/libtcmalloc_minimal.lib)
|
||||
|
||||
|
||||
add_custom_command(TARGET
|
||||
${PROJECT_NAME}
|
||||
POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy ${INTEL_ROOT}redist/intel64/compiler/libiomp5md.dll "$<$<CONFIG:debug>:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}>$<$<CONFIG:release>:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}>" )
|
||||
add_custom_command(TARGET
|
||||
${PROJECT_NAME}
|
||||
POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy ${INTEL_ROOT}redist/intel64/mkl/mkl_avx.dll "$<$<CONFIG:debug>:\"${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}\">$<$<CONFIG:release>:\"${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}\">" )
|
||||
add_custom_command(TARGET
|
||||
${PROJECT_NAME}
|
||||
POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy ${INTEL_ROOT}redist/intel64/mkl/mkl_avx2.dll "$<$<CONFIG:debug>:\"${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}\">$<$<CONFIG:release>:\"${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}\">" )
|
||||
add_custom_command(TARGET
|
||||
${PROJECT_NAME}
|
||||
POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy ${INTEL_ROOT}redist/intel64/mkl/mkl_avx512.dll "$<$<CONFIG:debug>:\"${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}\">$<$<CONFIG:release>:\"${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}\">" )
|
||||
add_custom_command(TARGET
|
||||
${PROJECT_NAME}
|
||||
POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy ${INTEL_ROOT}redist/intel64/mkl/mkl_core.dll "$<$<CONFIG:debug>:\"${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}\">$<$<CONFIG:release>:\"${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}\">" )
|
||||
add_custom_command(TARGET
|
||||
${PROJECT_NAME}
|
||||
POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy ${INTEL_ROOT}redist/intel64/mkl/mkl_def.dll "$<$<CONFIG:debug>:\"${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}\">$<$<CONFIG:release>:\"${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}\">" )
|
||||
add_custom_command(TARGET
|
||||
${PROJECT_NAME}
|
||||
POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy ${INTEL_ROOT}redist/intel64/mkl/mkl_intel_thread.dll "$<$<CONFIG:debug>:\"${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}\">$<$<CONFIG:release>:\"${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}\">" )
|
||||
add_custom_command(TARGET
|
||||
${PROJECT_NAME}
|
||||
POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy ${INTEL_ROOT}redist/intel64/mkl/mkl_rt.dll "$<$<CONFIG:debug>:\"${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}\">$<$<CONFIG:release>:\"${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}\">" )
|
||||
add_custom_command(TARGET
|
||||
${PROJECT_NAME}
|
||||
POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy ${PROJECT_SOURCE_DIR}/dependencies/windows/tcmalloc/libtcmalloc_minimal.dll "$<$<CONFIG:debug>:\"${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}\">$<$<CONFIG:release>:\"${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}\">" )
|
||||
|
||||
|
||||
endif()
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
// dllmain.cpp : Defines the entry point for the DLL application.
|
||||
#include <windows.h>
|
||||
|
||||
BOOL APIENTRY DllMain(HMODULE hModule, DWORD ul_reason_for_call,
|
||||
LPVOID lpReserved) {
|
||||
switch (ul_reason_for_call) {
|
||||
case DLL_PROCESS_ATTACH:
|
||||
case DLL_THREAD_ATTACH:
|
||||
case DLL_THREAD_DETACH:
|
||||
case DLL_PROCESS_DETACH:
|
||||
break;
|
||||
}
|
||||
return TRUE;
|
||||
}
|
1702
src/index.cpp
|
@ -1,539 +0,0 @@
|
|||
#include <efanna2e/index_nsg.h>
|
||||
#include <efanna2e/exceptions.h>
|
||||
#include <efanna2e/parameters.h>
|
||||
#include <omp.h>
|
||||
#include <chrono>
|
||||
#include <boost/dynamic_bitset.hpp>
|
||||
#include <bitset>
|
||||
#include <cmath>
|
||||
|
||||
namespace efanna2e {
|
||||
#define _CONTROL_NUM 100
|
||||
IndexNSG::IndexNSG(const size_t dimension, const size_t n, Metric m, Index *initializer) : Index(dimension, n, m),
|
||||
initializer_{initializer} {
|
||||
}
|
||||
|
||||
IndexNSG::~IndexNSG() {}
|
||||
|
||||
void IndexNSG::Save(const char *filename) {
|
||||
std::ofstream out(filename, std::ios::binary | std::ios::out);
|
||||
assert(final_graph_.size() == nd_);
|
||||
|
||||
out.write((char *) &width, sizeof(unsigned));
|
||||
out.write((char *) &ep_, sizeof(unsigned));
|
||||
for (unsigned i = 0; i < nd_; i++) {
|
||||
unsigned GK = (unsigned) final_graph_[i].size();
|
||||
out.write((char *) &GK, sizeof(unsigned));
|
||||
out.write((char *) final_graph_[i].data(), GK * sizeof(unsigned));
|
||||
}
|
||||
out.close();
|
||||
}
|
||||
|
||||
void IndexNSG::Load(const char *filename) {
|
||||
std::ifstream in(filename, std::ios::binary);
|
||||
in.read((char *) &width, sizeof(unsigned));
|
||||
in.read((char *) &ep_, sizeof(unsigned));
|
||||
//width=100;
|
||||
unsigned cc=0;
|
||||
while (!in.eof()) {
|
||||
unsigned k;
|
||||
in.read((char *) &k, sizeof(unsigned));
|
||||
if (in.eof())break;
|
||||
cc += k;
|
||||
std::vector<unsigned> tmp(k);
|
||||
in.read((char *) tmp.data(), k * sizeof(unsigned));
|
||||
final_graph_.push_back(tmp);
|
||||
}
|
||||
cc /= nd_;
|
||||
std::cout<<cc<<std::endl;
|
||||
}
|
||||
void IndexNSG::Load_nn_graph(const char *filename) {
|
||||
std::ifstream in(filename, std::ios::binary);
|
||||
unsigned k;
|
||||
in.read((char *) &k, sizeof(unsigned));
|
||||
in.seekg(0, std::ios::end);
|
||||
std::ios::pos_type ss = in.tellg();
|
||||
size_t fsize = (size_t) ss;
|
||||
size_t num = (unsigned) (fsize / (k + 1) / 4);
|
||||
in.seekg(0, std::ios::beg);
|
||||
|
||||
final_graph_.resize(num);
|
||||
for (size_t i = 0; i < num; i++) {
|
||||
in.seekg(4, std::ios::cur);
|
||||
final_graph_[i].resize(k);
|
||||
in.read((char *) final_graph_[i].data(), k * sizeof(unsigned));
|
||||
}
|
||||
in.close();
|
||||
}
|
||||
|
||||
void IndexNSG::get_neighbors(
|
||||
const float *query,
|
||||
const Parameters ¶meter,
|
||||
std::vector <Neighbor> &retset, std::vector <Neighbor> &fullset) {
|
||||
unsigned L = parameter.Get<unsigned>("L");
|
||||
|
||||
retset.resize(L + 1);
|
||||
std::vector<unsigned> init_ids(L);
|
||||
//initializer_->Search(query, nullptr, L, parameter, init_ids.data());
|
||||
|
||||
boost::dynamic_bitset<> flags{nd_, 0};
|
||||
L = 0;
|
||||
for(unsigned i=0; i < init_ids.size() && i < final_graph_[ep_].size(); i++){
|
||||
init_ids[i] = final_graph_[ep_][i];
|
||||
flags[init_ids[i]] = true;
|
||||
L++;
|
||||
}
|
||||
while(L < init_ids.size()){
|
||||
unsigned id = rand() % nd_;
|
||||
if(flags[id])continue;
|
||||
init_ids[L] = id;
|
||||
L++;
|
||||
flags[id] = true;
|
||||
}
|
||||
|
||||
L = 0;
|
||||
for (unsigned i = 0; i < init_ids.size(); i++) {
|
||||
unsigned id = init_ids[i];
|
||||
if(id >= nd_)continue;
|
||||
//std::cout<<id<<std::endl;
|
||||
float dist = distance_->compare(data_ + dimension_ * id, query, (unsigned) dimension_);
|
||||
retset[i] = Neighbor(id, dist, true);
|
||||
//flags[id] = 1;
|
||||
L++;
|
||||
}
|
||||
|
||||
std::sort(retset.begin(), retset.begin() + L);
|
||||
int k = 0;
|
||||
while (k < (int) L) {
|
||||
int nk = L;
|
||||
|
||||
if (retset[k].flag) {
|
||||
retset[k].flag = false;
|
||||
unsigned n = retset[k].id;
|
||||
|
||||
for (unsigned m = 0; m < final_graph_[n].size(); ++m) {
|
||||
unsigned id = final_graph_[n][m];
|
||||
if (flags[id])continue;
|
||||
flags[id] = 1;
|
||||
|
||||
float dist = distance_->compare(query, data_ + dimension_ * id, (unsigned) dimension_);
|
||||
Neighbor nn(id, dist, true);
|
||||
fullset.push_back(nn);
|
||||
if (dist >= retset[L - 1].distance)continue;
|
||||
int r = InsertIntoPool(retset.data(), L, nn);
|
||||
|
||||
if(L+1 < retset.size()) ++L;
|
||||
if (r < nk)nk = r;
|
||||
}
|
||||
|
||||
}
|
||||
if (nk <= k)k = nk;
|
||||
else ++k;
|
||||
}
|
||||
}
|
||||
|
||||
void IndexNSG::init_graph(const Parameters ¶meters) {
|
||||
float *center = new float[dimension_];
|
||||
for (unsigned j = 0; j < dimension_; j++)center[j] = 0;
|
||||
for (unsigned i = 0; i < nd_; i++) {
|
||||
for (unsigned j = 0; j < dimension_; j++) {
|
||||
center[j] += data_[i * dimension_ + j];
|
||||
}
|
||||
}
|
||||
for (unsigned j = 0; j < dimension_; j++) {
|
||||
center[j] /= nd_;
|
||||
}
|
||||
std::vector <Neighbor> tmp, pool;
|
||||
get_neighbors(center, parameters, tmp, pool);
|
||||
ep_ = tmp[0].id;
|
||||
}
|
||||
|
||||
void IndexNSG::add_cnn(unsigned des, Neighbor p, unsigned range, LockGraph &cut_graph_) {
|
||||
LockGuard guard(cut_graph_[des].lock);
|
||||
for (unsigned i = 0; i < cut_graph_[des].pool.size(); i++) {
|
||||
if (p.id == cut_graph_[des].pool[i].id)return;
|
||||
}
|
||||
cut_graph_[des].pool.push_back(p);
|
||||
if (cut_graph_[des].pool.size() > range) {
|
||||
std::vector <Neighbor> result;
|
||||
std::vector <Neighbor> &pool = cut_graph_[des].pool;
|
||||
unsigned start = 0;
|
||||
std::sort(pool.begin(), pool.end());
|
||||
result.push_back(pool[start]);
|
||||
|
||||
while (result.size() < range && (++start) < pool.size()) {
|
||||
auto &p = pool[start];
|
||||
bool occlude = false;
|
||||
for (unsigned t = 0; t < result.size(); t++) {
|
||||
if (p.id == result[t].id) {
|
||||
occlude = true;
|
||||
break;
|
||||
}
|
||||
float djk = distance_->compare(data_ + dimension_ * result[t].id, data_ + dimension_ * p.id, dimension_);
|
||||
if (djk < p.distance/* dik */) {
|
||||
occlude = true;
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
if (!occlude)result.push_back(p);
|
||||
}
|
||||
pool.swap(result);
|
||||
}
|
||||
|
||||
}
|
||||
void IndexNSG::sync_prune(unsigned q,
|
||||
std::vector <Neighbor> &pool,
|
||||
const Parameters ¶meter,
|
||||
LockGraph &cut_graph_) {
|
||||
unsigned range = parameter.Get<unsigned>("R");
|
||||
width = range;
|
||||
unsigned start = 0;
|
||||
|
||||
boost::dynamic_bitset<> flags{nd_, 0};
|
||||
for (unsigned i = 0; i < pool.size(); i++)flags[pool[i].id] = 1;
|
||||
for (unsigned nn = 0; nn < final_graph_[q].size(); nn++) {
|
||||
unsigned id = final_graph_[q][nn];
|
||||
if (flags[id])continue;
|
||||
float dist = distance_->compare(data_ + dimension_ * q, data_ + dimension_ * id, dimension_);
|
||||
pool.push_back(Neighbor(id, dist, true));
|
||||
}
|
||||
|
||||
std::sort(pool.begin(), pool.end());
|
||||
std::vector <Neighbor> result;
|
||||
if(pool[start].id == q)start++;
|
||||
result.push_back(pool[start]);
|
||||
|
||||
while (result.size() < range && (++start) < pool.size()) {
|
||||
auto &p = pool[start];
|
||||
bool occlude = false;
|
||||
for (unsigned t = 0; t < result.size(); t++) {
|
||||
if (p.id == result[t].id) {
|
||||
occlude = true;
|
||||
break;
|
||||
}
|
||||
float djk = distance_->compare(data_ + dimension_ * result[t].id, data_ + dimension_ * p.id, dimension_);
|
||||
if (djk < p.distance/* dik */) {
|
||||
occlude = true;
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
if (!occlude)result.push_back(p);
|
||||
}
|
||||
for (unsigned t = 0; t < result.size(); t++) {
|
||||
add_cnn(q, result[t], range, cut_graph_);
|
||||
add_cnn(result[t].id, Neighbor(q, result[t].distance, true), range, cut_graph_);
|
||||
}
|
||||
}
|
||||
|
||||
void IndexNSG::Link(const Parameters ¶meters, LockGraph &cut_graph_) {
|
||||
std::cout << " graph link" << std::endl;
|
||||
unsigned progress=0;
|
||||
unsigned percent = 10;
|
||||
unsigned step_size = nd_/percent;
|
||||
std::mutex progress_lock;
|
||||
|
||||
#pragma omp parallel
|
||||
{
|
||||
unsigned cnt = 0;
|
||||
#pragma omp for
|
||||
for (unsigned n = 0; n < nd_; ++n) {
|
||||
std::vector <Neighbor> pool, tmp;
|
||||
get_neighbors(data_ + dimension_ * n, parameters, tmp, pool);
|
||||
sync_prune(n, pool, parameters, cut_graph_);
|
||||
|
||||
cnt++;
|
||||
if(cnt % step_size == 0){
|
||||
LockGuard g(progress_lock);
|
||||
std::cout<<progress++ <<"/"<< percent << " completed" <<std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
void IndexNSG::Build(size_t n, const float *data, const Parameters ¶meters) {
|
||||
std::string nn_graph_path = parameters.Get<std::string>("nn_graph_path");
|
||||
Load_nn_graph(nn_graph_path.c_str());
|
||||
data_ = data;
|
||||
init_graph(parameters);
|
||||
LockGraph cut_graph_(nd_);
|
||||
Link(parameters, cut_graph_);
|
||||
final_graph_.resize(nd_);
|
||||
unsigned max = 0, min = 1e6, avg = 0, cnt=0;
|
||||
for (unsigned i = 0; i < nd_; i++) {
|
||||
auto &pool = cut_graph_[i].pool;
|
||||
max = max < pool.size() ? pool.size() : max;
|
||||
min = min > pool.size() ? pool.size() : min;
|
||||
avg += pool.size();
|
||||
if(pool.size() < 2)cnt++;
|
||||
final_graph_[i].resize(pool.size());
|
||||
for (unsigned j = 0; j < pool.size(); j++) {
|
||||
final_graph_[i][j] = pool[j].id;
|
||||
}
|
||||
}
|
||||
avg /= nd_;
|
||||
std::cout << max << ":" << avg << ":" << min << ":" << cnt << "\n";
|
||||
tree_grow(parameters);
|
||||
has_built = true;
|
||||
}
|
||||
|
||||
void IndexNSG::Search(
|
||||
const float *query,
|
||||
const float *x,
|
||||
size_t K,
|
||||
const Parameters ¶meters,
|
||||
unsigned *indices) {
|
||||
const unsigned L = parameters.Get<unsigned>("L_search");
|
||||
data_ = x;
|
||||
std::vector <Neighbor> retset(L + 1);
|
||||
std::vector<unsigned> init_ids(L);
|
||||
boost::dynamic_bitset<> flags{nd_, 0};
|
||||
//std::mt19937 rng(rand());
|
||||
//GenRandom(rng, init_ids.data(), L, (unsigned) nd_);
|
||||
|
||||
unsigned tmp_l = 0;
|
||||
for(; tmp_l<L && tmp_l<final_graph_[ep_].size(); tmp_l++){
|
||||
init_ids[tmp_l] = final_graph_[ep_][tmp_l];
|
||||
flags[init_ids[tmp_l]] = true;
|
||||
}
|
||||
|
||||
while(tmp_l < L){
|
||||
unsigned id = rand() % nd_;
|
||||
if(flags[id])continue;
|
||||
flags[id] = true;
|
||||
init_ids[tmp_l] = id;
|
||||
tmp_l++;
|
||||
}
|
||||
|
||||
|
||||
for (unsigned i = 0; i < init_ids.size(); i++) {
|
||||
unsigned id = init_ids[i];
|
||||
float dist = distance_->compare(data_ + dimension_ * id, query, (unsigned) dimension_);
|
||||
retset[i] = Neighbor(id, dist, true);
|
||||
//flags[id] = true;
|
||||
}
|
||||
|
||||
std::sort(retset.begin(), retset.begin() + L);
|
||||
int k = 0;
|
||||
while (k < (int) L) {
|
||||
int nk = L;
|
||||
|
||||
if (retset[k].flag) {
|
||||
retset[k].flag = false;
|
||||
unsigned n = retset[k].id;
|
||||
|
||||
for (unsigned m = 0; m < final_graph_[n].size(); ++m) {
|
||||
unsigned id = final_graph_[n][m];
|
||||
if (flags[id])continue;
|
||||
flags[id] = 1;
|
||||
float dist = distance_->compare(query, data_ + dimension_ * id, (unsigned) dimension_);
|
||||
if (dist >= retset[L - 1].distance)continue;
|
||||
Neighbor nn(id, dist, true);
|
||||
int r = InsertIntoPool(retset.data(), L, nn);
|
||||
|
||||
if (r < nk)nk = r;
|
||||
}
|
||||
}
|
||||
if (nk <= k)k = nk;
|
||||
else ++k;
|
||||
}
|
||||
for (size_t i = 0; i < K; i++) {
|
||||
indices[i] = retset[i].id;
|
||||
}
|
||||
}
|
||||
|
||||
void IndexNSG::SearchWithOptGraph(
|
||||
const float *query,
|
||||
size_t K,
|
||||
const Parameters ¶meters,
|
||||
unsigned *indices){
|
||||
unsigned L = parameters.Get<unsigned>("L_search");
|
||||
unsigned P = parameters.Get<unsigned>("P_search");
|
||||
DistanceFastL2* dist_fast = (DistanceFastL2*)distance_;
|
||||
|
||||
P = P > K ? P : K;
|
||||
std::vector <Neighbor> retset(P + 1);
|
||||
std::vector<unsigned> init_ids(L);
|
||||
//std::mt19937 rng(rand());
|
||||
//GenRandom(rng, init_ids.data(), L, (unsigned) nd_);
|
||||
|
||||
boost::dynamic_bitset<> flags{nd_, 0};
|
||||
unsigned tmp_l = 0;
|
||||
unsigned *neighbors = (unsigned*)(opt_graph_ + node_size * ep_ + data_len);
|
||||
unsigned MaxM_ep = *neighbors;
|
||||
neighbors++;
|
||||
|
||||
for(; tmp_l < L && tmp_l < MaxM_ep; tmp_l++){
|
||||
init_ids[tmp_l] = neighbors[tmp_l];
|
||||
flags[init_ids[tmp_l]] = true;
|
||||
}
|
||||
|
||||
while(tmp_l < L){
|
||||
unsigned id = rand() % nd_;
|
||||
if(flags[id])continue;
|
||||
flags[id] = true;
|
||||
init_ids[tmp_l] = id;
|
||||
tmp_l++;
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < init_ids.size(); i++){
|
||||
unsigned id = init_ids[i];
|
||||
if(id >= nd_)continue;
|
||||
_mm_prefetch(opt_graph_ + node_size * id, _MM_HINT_T0);
|
||||
}
|
||||
L = 0;
|
||||
for (unsigned i = 0; i < init_ids.size(); i++) {
|
||||
unsigned id = init_ids[i];
|
||||
if(id >= nd_)continue;
|
||||
float *x = (float*)(opt_graph_ + node_size * id);
|
||||
float norm_x = *x;x++;
|
||||
float dist = dist_fast->compare(x, query, norm_x, (unsigned) dimension_);
|
||||
retset[i] = Neighbor(id, dist, true);
|
||||
flags[id] = true;
|
||||
L++;
|
||||
}
|
||||
//std::cout<<L<<std::endl;
|
||||
|
||||
std::sort(retset.begin(), retset.begin() + L);
|
||||
int k = 0;
|
||||
while (k < (int) L) {
|
||||
int nk = L;
|
||||
|
||||
if (retset[k].flag) {
|
||||
retset[k].flag = false;
|
||||
unsigned n = retset[k].id;
|
||||
|
||||
_mm_prefetch(opt_graph_ + node_size * n + data_len, _MM_HINT_T0);
|
||||
unsigned *neighbors = (unsigned*)(opt_graph_ + node_size * n + data_len);
|
||||
unsigned MaxM = *neighbors;
|
||||
neighbors++;
|
||||
for(unsigned m=0; m<MaxM; ++m)
|
||||
_mm_prefetch(opt_graph_ + node_size * neighbors[m], _MM_HINT_T0);
|
||||
for (unsigned m = 0; m < MaxM; ++m) {
|
||||
unsigned id = neighbors[m];
|
||||
if (flags[id])continue;
|
||||
flags[id] = 1;
|
||||
float *data = (float*)(opt_graph_ + node_size * id);
|
||||
float norm = *data;data++;
|
||||
float dist = dist_fast->compare(query, data, norm, (unsigned) dimension_);
|
||||
if (dist >= retset[L - 1].distance)continue;
|
||||
Neighbor nn(id, dist, true);
|
||||
int r = InsertIntoPool(retset.data(), L, nn);
|
||||
|
||||
//if(L+1 < retset.size()) ++L;
|
||||
if (r < nk)nk = r;
|
||||
}
|
||||
|
||||
}
|
||||
if (nk <= k)k = nk;
|
||||
else ++k;
|
||||
}
|
||||
for (size_t i = 0; i < K; i++) {
|
||||
indices[i] = retset[i].id;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void IndexNSG::OptimizeGraph(float* data){//use after build or load
|
||||
|
||||
data_ = data;
|
||||
data_len = (dimension_ + 1) * sizeof(float);
|
||||
neighbor_len = (width + 1) * sizeof(unsigned);
|
||||
node_size = data_len + neighbor_len;
|
||||
opt_graph_ = (char*)malloc(node_size * nd_);
|
||||
DistanceFastL2* dist_fast = (DistanceFastL2*)distance_;
|
||||
for(unsigned i=0; i<nd_; i++){
|
||||
char* cur_node_offset = opt_graph_ + i * node_size;
|
||||
float cur_norm = dist_fast->norm(data_ + i * dimension_, dimension_);
|
||||
std::memcpy(cur_node_offset, &cur_norm, sizeof(float));
|
||||
std::memcpy(cur_node_offset + sizeof(float), data_ + i * dimension_, data_len-sizeof(float));
|
||||
|
||||
cur_node_offset += data_len;
|
||||
unsigned k = final_graph_[i].size();
|
||||
std::memcpy(cur_node_offset, &k, sizeof(unsigned));
|
||||
std::memcpy(cur_node_offset + sizeof(unsigned), final_graph_[i].data(), k * sizeof(unsigned));
|
||||
std::vector<unsigned>().swap(final_graph_[i]);
|
||||
}
|
||||
free(data);
|
||||
data_ = nullptr;
|
||||
CompactGraph().swap(final_graph_);
|
||||
}
|
||||
|
||||
void IndexNSG::DFS(boost::dynamic_bitset<> &flag, unsigned root, unsigned &cnt){
|
||||
unsigned tmp = root;
|
||||
std::stack<unsigned> s;
|
||||
s.push(root);
|
||||
if(!flag[root])cnt++;
|
||||
flag[root] = true;
|
||||
while(!s.empty()){
|
||||
|
||||
unsigned next = nd_ + 1;
|
||||
for(unsigned i=0; i<final_graph_[tmp].size(); i++){
|
||||
if(flag[final_graph_[tmp][i]] == false){
|
||||
next = final_graph_[tmp][i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
//std::cout << next <<":"<<cnt <<":"<<tmp <<":"<<s.size()<< '\n';
|
||||
if(next == (nd_ + 1)){
|
||||
s.pop();
|
||||
if(s.empty())break;
|
||||
tmp = s.top();
|
||||
continue;
|
||||
}
|
||||
tmp = next;
|
||||
flag[tmp] = true;s.push(tmp);cnt++;
|
||||
}
|
||||
}
|
||||
|
||||
void IndexNSG::findroot(boost::dynamic_bitset<> &flag, unsigned &root, const Parameters ¶meter){
|
||||
unsigned id;
|
||||
for(unsigned i=0; i<nd_; i++){
|
||||
if(flag[i] == false){
|
||||
id = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
std::vector <Neighbor> tmp, pool;
|
||||
get_neighbors(data_ + dimension_ * id, parameter, tmp, pool);
|
||||
std::sort(pool.begin(), pool.end());
|
||||
|
||||
unsigned found = 0;
|
||||
for(unsigned i=0; i<pool.size(); i++){
|
||||
if(flag[pool[i].id]){
|
||||
std::cout << pool[i].id << '\n';
|
||||
root = pool[i].id;
|
||||
found = 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(found == 0){
|
||||
while(true){
|
||||
unsigned rid = rand() % nd_;
|
||||
if(flag[rid]){
|
||||
root = rid;
|
||||
std::cout << root << '\n';
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
final_graph_[root].push_back(id);
|
||||
|
||||
}
|
||||
void IndexNSG::tree_grow(const Parameters ¶meter){
|
||||
unsigned root = ep_;
|
||||
boost::dynamic_bitset<> flags{nd_, 0};
|
||||
unsigned unlinked_cnt = 0;
|
||||
while(unlinked_cnt < nd_){
|
||||
DFS(flags, root, unlinked_cnt);
|
||||
std::cout << unlinked_cnt << '\n';
|
||||
if(unlinked_cnt >= nd_)break;
|
||||
findroot(flags, root, parameter);
|
||||
std::cout << "new root"<<":"<<root << '\n';
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,193 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "linux_aligned_file_reader.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdio>
|
||||
#include <iostream>
|
||||
#include "tsl/robin_map.h"
|
||||
#include "utils.h"
|
||||
#define MAX_EVENTS 1024
|
||||
|
||||
namespace {
|
||||
typedef struct io_event io_event_t;
|
||||
typedef struct iocb iocb_t;
|
||||
|
||||
void execute_io(io_context_t ctx, int fd, std::vector<AlignedRead> &read_reqs,
|
||||
uint64_t n_retries = 0) {
|
||||
#ifdef DEBUG
|
||||
for (auto &req : read_reqs) {
|
||||
assert(IS_ALIGNED(req.len, 512));
|
||||
// std::cout << "request:"<<req.offset<<":"<<req.len << std::endl;
|
||||
assert(IS_ALIGNED(req.offset, 512));
|
||||
assert(IS_ALIGNED(req.buf, 512));
|
||||
// assert(malloc_usable_size(req.buf) >= req.len);
|
||||
}
|
||||
#endif
|
||||
|
||||
// break-up requests into chunks of size MAX_EVENTS each
|
||||
uint64_t n_iters = ROUND_UP(read_reqs.size(), MAX_EVENTS) / MAX_EVENTS;
|
||||
for (uint64_t iter = 0; iter < n_iters; iter++) {
|
||||
uint64_t n_ops =
|
||||
std::min((uint64_t) read_reqs.size() - (iter * MAX_EVENTS),
|
||||
(uint64_t) MAX_EVENTS);
|
||||
std::vector<iocb_t *> cbs(n_ops, nullptr);
|
||||
std::vector<io_event_t> evts(n_ops);
|
||||
std::vector<struct iocb> cb(n_ops);
|
||||
for (uint64_t j = 0; j < n_ops; j++) {
|
||||
io_prep_pread(cb.data() + j, fd, read_reqs[j + iter * MAX_EVENTS].buf,
|
||||
read_reqs[j + iter * MAX_EVENTS].len,
|
||||
read_reqs[j + iter * MAX_EVENTS].offset);
|
||||
}
|
||||
|
||||
// initialize `cbs` using `cb` array
|
||||
//
|
||||
|
||||
for (uint64_t i = 0; i < n_ops; i++) {
|
||||
cbs[i] = cb.data() + i;
|
||||
}
|
||||
|
||||
uint64_t n_tries = 0;
|
||||
while (n_tries <= n_retries) {
|
||||
// issue reads
|
||||
int64_t ret = io_submit(ctx, (int64_t) n_ops, cbs.data());
|
||||
// if requests didn't get accepted
|
||||
if (ret != (int64_t) n_ops) {
|
||||
std::cerr << "io_submit() failed; returned " << ret
|
||||
<< ", expected=" << n_ops << ", ernno=" << errno << "="
|
||||
<< ::strerror(-ret) << ", try #" << n_tries + 1;
|
||||
std::cout << "ctx: " << ctx << "\n";
|
||||
exit(-1);
|
||||
} else {
|
||||
// wait on io_getevents
|
||||
ret = io_getevents(ctx, (int64_t) n_ops, (int64_t) n_ops, evts.data(),
|
||||
nullptr);
|
||||
// if requests didn't complete
|
||||
if (ret != (int64_t) n_ops) {
|
||||
std::cerr << "io_getevents() failed; returned " << ret
|
||||
<< ", expected=" << n_ops << ", ernno=" << errno << "="
|
||||
<< ::strerror(-ret) << ", try #" << n_tries + 1;
|
||||
exit(-1);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// disabled since req.buf could be an offset into another buf
|
||||
/*
|
||||
for (auto &req : read_reqs) {
|
||||
// corruption check
|
||||
assert(malloc_usable_size(req.buf) >= req.len);
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
/*
|
||||
for(unsigned i=0;i<64;i++){
|
||||
std::cout << *((unsigned*)read_reqs[0].buf + i) << " ";
|
||||
}
|
||||
std::cout << std::endl;*/
|
||||
}
|
||||
}
|
||||
|
||||
LinuxAlignedFileReader::LinuxAlignedFileReader() {
|
||||
this->file_desc = -1;
|
||||
}
|
||||
|
||||
LinuxAlignedFileReader::~LinuxAlignedFileReader() {
|
||||
int64_t ret;
|
||||
// check to make sure file_desc is closed
|
||||
ret = ::fcntl(this->file_desc, F_GETFD);
|
||||
if (ret == -1) {
|
||||
if (errno != EBADF) {
|
||||
std::cerr << "close() not called" << std::endl;
|
||||
// close file desc
|
||||
ret = ::close(this->file_desc);
|
||||
// error checks
|
||||
if (ret == -1) {
|
||||
std::cerr << "close() failed; returned " << ret << ", errno=" << errno
|
||||
<< ":" << ::strerror(errno) << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
io_context_t &LinuxAlignedFileReader::get_ctx() {
|
||||
std::unique_lock<std::mutex> lk(ctx_mut);
|
||||
// perform checks only in DEBUG mode
|
||||
if (ctx_map.find(std::this_thread::get_id()) == ctx_map.end()) {
|
||||
std::cerr << "bad thread access; returning -1 as io_context_t" << std::endl;
|
||||
return this->bad_ctx;
|
||||
} else {
|
||||
return ctx_map[std::this_thread::get_id()];
|
||||
}
|
||||
}
|
||||
|
||||
void LinuxAlignedFileReader::register_thread() {
|
||||
auto my_id = std::this_thread::get_id();
|
||||
std::unique_lock<std::mutex> lk(ctx_mut);
|
||||
if (ctx_map.find(my_id) != ctx_map.end()) {
|
||||
std::cerr << "multiple calls to register_thread from the same thread"
|
||||
<< std::endl;
|
||||
return;
|
||||
}
|
||||
io_context_t ctx = 0;
|
||||
int ret = io_setup(MAX_EVENTS, &ctx);
|
||||
if (ret != 0) {
|
||||
lk.unlock();
|
||||
assert(errno != EAGAIN);
|
||||
assert(errno != ENOMEM);
|
||||
std::cerr << "io_setup() failed; returned " << ret << ", errno=" << errno
|
||||
<< ":" << ::strerror(errno) << std::endl;
|
||||
} else {
|
||||
std::cerr << "allocating ctx: " << ctx << " to thread-id:" << my_id
|
||||
<< std::endl;
|
||||
ctx_map[my_id] = ctx;
|
||||
}
|
||||
lk.unlock();
|
||||
}
|
||||
|
||||
void LinuxAlignedFileReader::deregister_thread() {
|
||||
auto my_id = std::this_thread::get_id();
|
||||
std::unique_lock<std::mutex> lk(ctx_mut);
|
||||
assert(ctx_map.find(my_id) != ctx_map.end());
|
||||
|
||||
lk.unlock();
|
||||
io_context_t ctx = this->get_ctx();
|
||||
io_destroy(ctx);
|
||||
// assert(ret == 0);
|
||||
lk.lock();
|
||||
ctx_map.erase(my_id);
|
||||
std::cerr << "returned ctx from thread-id:" << my_id << std::endl;
|
||||
lk.unlock();
|
||||
}
|
||||
|
||||
void LinuxAlignedFileReader::open(const std::string &fname) {
|
||||
int flags = O_DIRECT | O_RDONLY | O_LARGEFILE;
|
||||
this->file_desc = ::open(fname.c_str(), flags);
|
||||
// error checks
|
||||
assert(this->file_desc != -1);
|
||||
std::cerr << "Opened file : " << fname << std::endl;
|
||||
}
|
||||
|
||||
void LinuxAlignedFileReader::close() {
|
||||
// int64_t ret;
|
||||
|
||||
// check to make sure file_desc is closed
|
||||
::fcntl(this->file_desc, F_GETFD);
|
||||
// assert(ret != -1);
|
||||
|
||||
::close(this->file_desc);
|
||||
// assert(ret != -1);
|
||||
}
|
||||
|
||||
void LinuxAlignedFileReader::read(std::vector<AlignedRead> &read_reqs,
|
||||
io_context_t &ctx, bool async) {
|
||||
assert(this->file_desc != -1);
|
||||
//#pragma omp critical
|
||||
// std::cout << "thread: " << std::this_thread::get_id() << ", crtx: " <<
|
||||
// ctx
|
||||
//<< "\n";
|
||||
execute_io(ctx, this->file_desc, read_reqs);
|
||||
}
|
|
@ -0,0 +1,87 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
|
||||
#ifdef EXEC_ENV_OLS
|
||||
#include "ANNLoggingImpl.hpp"
|
||||
#endif
|
||||
|
||||
#include "logger_impl.h"
|
||||
#include "windows_customizations.h"
|
||||
|
||||
namespace diskann {
|
||||
|
||||
DISKANN_DLLEXPORT ANNStreamBuf coutBuff(stdout);
|
||||
DISKANN_DLLEXPORT ANNStreamBuf cerrBuff(stderr);
|
||||
|
||||
DISKANN_DLLEXPORT std::basic_ostream<char> cout(&coutBuff);
|
||||
DISKANN_DLLEXPORT std::basic_ostream<char> cerr(&cerrBuff);
|
||||
|
||||
ANNStreamBuf::ANNStreamBuf(FILE* fp) {
|
||||
if (fp == nullptr) {
|
||||
throw diskann::ANNException(
|
||||
"File pointer passed to ANNStreamBuf() cannot be null", -1);
|
||||
}
|
||||
if (fp != stdout && fp != stderr) {
|
||||
throw diskann::ANNException(
|
||||
"The custom logger only supports stdout and stderr.", -1);
|
||||
}
|
||||
_fp = fp;
|
||||
_logLevel = (_fp == stdout) ? ANNIndex::LogLevel::LL_Info
|
||||
: ANNIndex::LogLevel::LL_Error;
|
||||
#ifdef EXEC_ENV_OLS
|
||||
_buf = new char[BUFFER_SIZE + 1]; // See comment in the header
|
||||
#else
|
||||
_buf = new char[BUFFER_SIZE]; // See comment in the header
|
||||
#endif
|
||||
|
||||
std::memset(_buf, 0, (BUFFER_SIZE) * sizeof(char));
|
||||
setp(_buf, _buf + BUFFER_SIZE);
|
||||
}
|
||||
|
||||
ANNStreamBuf::~ANNStreamBuf() {
|
||||
sync();
|
||||
_fp = nullptr; // we'll not close because we can't.
|
||||
delete[] _buf;
|
||||
}
|
||||
|
||||
int ANNStreamBuf::overflow(int c) {
|
||||
std::lock_guard<std::mutex> lock(_mutex);
|
||||
if (c != EOF) {
|
||||
*pptr() = (char) c;
|
||||
pbump(1);
|
||||
}
|
||||
flush();
|
||||
return c;
|
||||
}
|
||||
|
||||
int ANNStreamBuf::sync() {
|
||||
std::lock_guard<std::mutex> lock(_mutex);
|
||||
flush();
|
||||
return 0;
|
||||
}
|
||||
|
||||
int ANNStreamBuf::underflow() {
|
||||
throw diskann::ANNException(
|
||||
"Attempt to read on streambuf meant only for writing.", -1);
|
||||
}
|
||||
|
||||
int ANNStreamBuf::flush() {
|
||||
const int num = (int) (pptr() - pbase());
|
||||
logImpl(pbase(), num);
|
||||
pbump(-num);
|
||||
return num;
|
||||
}
|
||||
void ANNStreamBuf::logImpl(char* str, int num) {
|
||||
#ifdef EXEC_ENV_OLS
|
||||
str[num] = '\0'; // Safe. See the c'tor.
|
||||
RandNSGLogging(_logLevel, str);
|
||||
#else
|
||||
fwrite(str, sizeof(char), num, _fp);
|
||||
fflush(_fp);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace diskann
|
|
@ -0,0 +1,464 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <limits>
|
||||
#include <malloc.h>
|
||||
#include <math_utils.h>
|
||||
#include "logger.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace math_utils {
|
||||
|
||||
float calc_distance(float* vec_1, float* vec_2, size_t dim) {
|
||||
float dist = 0;
|
||||
for (size_t j = 0; j < dim; j++) {
|
||||
dist += (vec_1[j] - vec_2[j]) * (vec_1[j] - vec_2[j]);
|
||||
}
|
||||
return dist;
|
||||
}
|
||||
|
||||
// compute l2-squared norms of data stored in row major num_points * dim,
|
||||
// needs
|
||||
// to be pre-allocated
|
||||
void compute_vecs_l2sq(float* vecs_l2sq, float* data, const size_t num_points,
|
||||
const size_t dim) {
|
||||
#pragma omp parallel for schedule(static, 8192)
|
||||
for (int64_t n_iter = 0; n_iter < (_s64) num_points; n_iter++) {
|
||||
vecs_l2sq[n_iter] =
|
||||
cblas_snrm2((MKL_INT) dim, (data + (n_iter * dim)), 1);
|
||||
vecs_l2sq[n_iter] *= vecs_l2sq[n_iter];
|
||||
}
|
||||
}
|
||||
|
||||
void rotate_data_randomly(float* data, size_t num_points, size_t dim,
|
||||
float* rot_mat, float*& new_mat,
|
||||
bool transpose_rot) {
|
||||
CBLAS_TRANSPOSE transpose = CblasNoTrans;
|
||||
if (transpose_rot) {
|
||||
diskann::cout << "Transposing rotation matrix.." << std::flush;
|
||||
transpose = CblasTrans;
|
||||
}
|
||||
diskann::cout << "done Rotating data with random matrix.." << std::flush;
|
||||
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, transpose, (MKL_INT) num_points,
|
||||
(MKL_INT) dim, (MKL_INT) dim, 1.0, data, (MKL_INT) dim, rot_mat,
|
||||
(MKL_INT) dim, 0, new_mat, (MKL_INT) dim);
|
||||
|
||||
diskann::cout << "done." << std::endl;
|
||||
}
|
||||
|
||||
// calculate k closest centers to data of num_points * dim (row major)
|
||||
// centers is num_centers * dim (row major)
|
||||
// data_l2sq has pre-computed squared norms of data
|
||||
// centers_l2sq has pre-computed squared norms of centers
|
||||
// pre-allocated center_index will contain id of nearest center
|
||||
// pre-allocated dist_matrix shound be num_points * num_centers and contain
|
||||
// squared distances
|
||||
// Default value of k is 1
|
||||
|
||||
// Ideally used only by compute_closest_centers
|
||||
void compute_closest_centers_in_block(
|
||||
const float* const data, const size_t num_points, const size_t dim,
|
||||
const float* const centers, const size_t num_centers,
|
||||
const float* const docs_l2sq, const float* const centers_l2sq,
|
||||
uint32_t* center_index, float* const dist_matrix, size_t k) {
|
||||
if (k > num_centers) {
|
||||
diskann::cout << "ERROR: k (" << k << ") > num_center(" << num_centers
|
||||
<< ")" << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
float* ones_a = new float[num_centers];
|
||||
float* ones_b = new float[num_points];
|
||||
|
||||
for (size_t i = 0; i < num_centers; i++) {
|
||||
ones_a[i] = 1.0;
|
||||
}
|
||||
for (size_t i = 0; i < num_points; i++) {
|
||||
ones_b[i] = 1.0;
|
||||
}
|
||||
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, (MKL_INT) num_points,
|
||||
(MKL_INT) num_centers, (MKL_INT) 1, 1.0f, docs_l2sq,
|
||||
(MKL_INT) 1, ones_a, (MKL_INT) 1, 0.0f, dist_matrix,
|
||||
(MKL_INT) num_centers);
|
||||
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, (MKL_INT) num_points,
|
||||
(MKL_INT) num_centers, (MKL_INT) 1, 1.0f, ones_b, (MKL_INT) 1,
|
||||
centers_l2sq, (MKL_INT) 1, 1.0f, dist_matrix,
|
||||
(MKL_INT) num_centers);
|
||||
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, (MKL_INT) num_points,
|
||||
(MKL_INT) num_centers, (MKL_INT) dim, -2.0f, data,
|
||||
(MKL_INT) dim, centers, (MKL_INT) dim, 1.0f, dist_matrix,
|
||||
(MKL_INT) num_centers);
|
||||
|
||||
if (k == 1) {
|
||||
#pragma omp parallel for schedule(static, 8192)
|
||||
for (int64_t i = 0; i < (_s64) num_points; i++) {
|
||||
float min = std::numeric_limits<float>::max();
|
||||
float* current = dist_matrix + (i * num_centers);
|
||||
for (size_t j = 0; j < num_centers; j++) {
|
||||
if (current[j] < min) {
|
||||
center_index[i] = (uint32_t) j;
|
||||
min = current[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma omp parallel for schedule(static, 8192)
|
||||
for (int64_t i = 0; i < (_s64) num_points; i++) {
|
||||
std::priority_queue<PivotContainer> top_k_queue;
|
||||
float* current = dist_matrix + (i * num_centers);
|
||||
for (size_t j = 0; j < num_centers; j++) {
|
||||
PivotContainer this_piv(j, current[j]);
|
||||
top_k_queue.push(this_piv);
|
||||
}
|
||||
for (size_t j = 0; j < k; j++) {
|
||||
PivotContainer this_piv = top_k_queue.top();
|
||||
center_index[i * k + j] = (uint32_t) this_piv.piv_id;
|
||||
top_k_queue.pop();
|
||||
}
|
||||
}
|
||||
}
|
||||
delete[] ones_a;
|
||||
delete[] ones_b;
|
||||
}
|
||||
|
||||
// Given data in num_points * new_dim row major
|
||||
// Pivots stored in full_pivot_data as num_centers * new_dim row major
|
||||
// Calculate the k closest pivot for each point and store it in vector
|
||||
// closest_centers_ivf (row major, num_points*k) (which needs to be allocated
|
||||
// outside) Additionally, if inverted index is not null (and pre-allocated),
|
||||
// it
|
||||
// will return inverted index for each center, assuming each of the inverted
|
||||
// indices is an empty vector. Additionally, if pts_norms_squared is not null,
|
||||
// then it will assume that point norms are pre-computed and use those values
|
||||
|
||||
void compute_closest_centers(float* data, size_t num_points, size_t dim,
|
||||
float* pivot_data, size_t num_centers, size_t k,
|
||||
uint32_t* closest_centers_ivf,
|
||||
std::vector<size_t>* inverted_index,
|
||||
float* pts_norms_squared) {
|
||||
if (k > num_centers) {
|
||||
diskann::cout << "ERROR: k (" << k << ") > num_center(" << num_centers
|
||||
<< ")" << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
bool is_norm_given_for_pts = (pts_norms_squared != NULL);
|
||||
|
||||
float* pivs_norms_squared = new float[num_centers];
|
||||
if (!is_norm_given_for_pts)
|
||||
pts_norms_squared = new float[num_points];
|
||||
|
||||
size_t PAR_BLOCK_SIZE = num_points;
|
||||
size_t N_BLOCKS = (num_points % PAR_BLOCK_SIZE) == 0
|
||||
? (num_points / PAR_BLOCK_SIZE)
|
||||
: (num_points / PAR_BLOCK_SIZE) + 1;
|
||||
|
||||
if (!is_norm_given_for_pts)
|
||||
math_utils::compute_vecs_l2sq(pts_norms_squared, data, num_points, dim);
|
||||
math_utils::compute_vecs_l2sq(pivs_norms_squared, pivot_data, num_centers,
|
||||
dim);
|
||||
uint32_t* closest_centers = new uint32_t[PAR_BLOCK_SIZE * k];
|
||||
float* distance_matrix = new float[num_centers * PAR_BLOCK_SIZE];
|
||||
|
||||
for (size_t cur_blk = 0; cur_blk < N_BLOCKS; cur_blk++) {
|
||||
float* data_cur_blk = data + cur_blk * PAR_BLOCK_SIZE * dim;
|
||||
size_t num_pts_blk =
|
||||
std::min(PAR_BLOCK_SIZE, num_points - cur_blk * PAR_BLOCK_SIZE);
|
||||
float* pts_norms_blk = pts_norms_squared + cur_blk * PAR_BLOCK_SIZE;
|
||||
|
||||
math_utils::compute_closest_centers_in_block(
|
||||
data_cur_blk, num_pts_blk, dim, pivot_data, num_centers,
|
||||
pts_norms_blk, pivs_norms_squared, closest_centers, distance_matrix,
|
||||
k);
|
||||
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
for (int64_t j = cur_blk * PAR_BLOCK_SIZE;
|
||||
j <
|
||||
std::min((_s64) num_points, (_s64)((cur_blk + 1) * PAR_BLOCK_SIZE));
|
||||
j++) {
|
||||
for (size_t l = 0; l < k; l++) {
|
||||
size_t this_center_id =
|
||||
closest_centers[(j - cur_blk * PAR_BLOCK_SIZE) * k + l];
|
||||
closest_centers_ivf[j * k + l] = (uint32_t) this_center_id;
|
||||
if (inverted_index != NULL) {
|
||||
#pragma omp critical
|
||||
inverted_index[this_center_id].push_back(j);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete[] closest_centers;
|
||||
delete[] distance_matrix;
|
||||
delete[] pivs_norms_squared;
|
||||
if (!is_norm_given_for_pts)
|
||||
delete[] pts_norms_squared;
|
||||
}
|
||||
|
||||
// if to_subtract is 1, will subtract nearest center from each row. Else will
|
||||
// add. Output will be in data_load iself.
|
||||
// Nearest centers need to be provided in closst_centers.
|
||||
void process_residuals(float* data_load, size_t num_points, size_t dim,
|
||||
float* cur_pivot_data, size_t num_centers,
|
||||
uint32_t* closest_centers, bool to_subtract) {
|
||||
diskann::cout << "Processing residuals of " << num_points << " points in "
|
||||
<< dim << " dimensions using " << num_centers << " centers "
|
||||
<< std::endl;
|
||||
#pragma omp parallel for schedule(static, 8192)
|
||||
for (int64_t n_iter = 0; n_iter < (_s64) num_points; n_iter++) {
|
||||
for (size_t d_iter = 0; d_iter < dim; d_iter++) {
|
||||
if (to_subtract == 1)
|
||||
data_load[n_iter * dim + d_iter] =
|
||||
data_load[n_iter * dim + d_iter] -
|
||||
cur_pivot_data[closest_centers[n_iter] * dim + d_iter];
|
||||
else
|
||||
data_load[n_iter * dim + d_iter] =
|
||||
data_load[n_iter * dim + d_iter] +
|
||||
cur_pivot_data[closest_centers[n_iter] * dim + d_iter];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace math_utils
|
||||
|
||||
namespace kmeans {
|
||||
|
||||
// run Lloyds one iteration
|
||||
// Given data in row major num_points * dim, and centers in row major
|
||||
// num_centers * dim And squared lengths of data points, output the closest
|
||||
// center to each data point, update centers, and also return inverted index.
|
||||
// If
|
||||
// closest_centers == NULL, will allocate memory and return. Similarly, if
|
||||
// closest_docs == NULL, will allocate memory and return.
|
||||
|
||||
float lloyds_iter(float* data, size_t num_points, size_t dim, float* centers,
|
||||
size_t num_centers, float* docs_l2sq,
|
||||
std::vector<size_t>* closest_docs,
|
||||
uint32_t*& closest_center) {
|
||||
bool compute_residual = true;
|
||||
// Timer timer;
|
||||
|
||||
if (closest_center == NULL)
|
||||
closest_center = new uint32_t[num_points];
|
||||
if (closest_docs == NULL)
|
||||
closest_docs = new std::vector<size_t>[num_centers];
|
||||
else
|
||||
for (size_t c = 0; c < num_centers; ++c)
|
||||
closest_docs[c].clear();
|
||||
|
||||
math_utils::compute_closest_centers(data, num_points, dim, centers,
|
||||
num_centers, 1, closest_center,
|
||||
closest_docs, docs_l2sq);
|
||||
|
||||
// diskann::cout << "closest centerss calculation done " << std::endl;
|
||||
|
||||
memset(centers, 0, sizeof(float) * (size_t) num_centers * (size_t) dim);
|
||||
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
for (int64_t c = 0; c < (_s64) num_centers; ++c) {
|
||||
float* center = centers + (size_t) c * (size_t) dim;
|
||||
double* cluster_sum = new double[dim];
|
||||
for (size_t i = 0; i < dim; i++)
|
||||
cluster_sum[i] = 0.0;
|
||||
for (size_t i = 0; i < closest_docs[c].size(); i++) {
|
||||
float* current = data + ((closest_docs[c][i]) * dim);
|
||||
for (size_t j = 0; j < dim; j++) {
|
||||
cluster_sum[j] += (double) current[j];
|
||||
}
|
||||
}
|
||||
if (closest_docs[c].size() > 0) {
|
||||
for (size_t i = 0; i < dim; i++)
|
||||
center[i] =
|
||||
(float) (cluster_sum[i] / ((double) closest_docs[c].size()));
|
||||
}
|
||||
delete[] cluster_sum;
|
||||
}
|
||||
|
||||
float residual = 0.0;
|
||||
if (compute_residual) {
|
||||
size_t BUF_PAD = 32;
|
||||
size_t CHUNK_SIZE = 2 * 8192;
|
||||
size_t nchunks =
|
||||
num_points / CHUNK_SIZE + (num_points % CHUNK_SIZE == 0 ? 0 : 1);
|
||||
std::vector<float> residuals(nchunks * BUF_PAD, 0.0);
|
||||
|
||||
#pragma omp parallel for schedule(static, 32)
|
||||
for (int64_t chunk = 0; chunk < (_s64) nchunks; ++chunk)
|
||||
for (size_t d = chunk * CHUNK_SIZE;
|
||||
d < num_points && d < (chunk + 1) * CHUNK_SIZE; ++d)
|
||||
residuals[chunk * BUF_PAD] += math_utils::calc_distance(
|
||||
data + (d * dim),
|
||||
centers + (size_t) closest_center[d] * (size_t) dim, dim);
|
||||
|
||||
for (size_t chunk = 0; chunk < nchunks; ++chunk)
|
||||
residual += residuals[chunk * BUF_PAD];
|
||||
}
|
||||
|
||||
return residual;
|
||||
}
|
||||
|
||||
// Run Lloyds until max_reps or stopping criterion
|
||||
// If you pass NULL for closest_docs and closest_center, it will NOT return
|
||||
// the
|
||||
// results, else it will assume appriate allocation as closest_docs = new
|
||||
// vector<size_t> [num_centers], and closest_center = new size_t[num_points]
|
||||
// Final centers are output in centers as row major num_centers * dim
|
||||
//
|
||||
float run_lloyds(float* data, size_t num_points, size_t dim, float* centers,
|
||||
const size_t num_centers, const size_t max_reps,
|
||||
std::vector<size_t>* closest_docs,
|
||||
uint32_t* closest_center) {
|
||||
float residual = std::numeric_limits<float>::max();
|
||||
bool ret_closest_docs = true;
|
||||
bool ret_closest_center = true;
|
||||
if (closest_docs == NULL) {
|
||||
closest_docs = new std::vector<size_t>[num_centers];
|
||||
ret_closest_docs = false;
|
||||
}
|
||||
if (closest_center == NULL) {
|
||||
closest_center = new uint32_t[num_points];
|
||||
ret_closest_center = false;
|
||||
}
|
||||
|
||||
float* docs_l2sq = new float[num_points];
|
||||
math_utils::compute_vecs_l2sq(docs_l2sq, data, num_points, dim);
|
||||
|
||||
float old_residual;
|
||||
// Timer timer;
|
||||
for (size_t i = 0; i < max_reps; ++i) {
|
||||
old_residual = residual;
|
||||
|
||||
residual = lloyds_iter(data, num_points, dim, centers, num_centers,
|
||||
docs_l2sq, closest_docs, closest_center);
|
||||
|
||||
diskann::cout << "Lloyd's iter " << i
|
||||
<< " dist_sq residual: " << residual << std::endl;
|
||||
|
||||
if (((i != 0) && ((old_residual - residual) / residual) < 0.00001) ||
|
||||
(residual < std::numeric_limits<float>::epsilon())) {
|
||||
diskann::cout << "Residuals unchanged: " << old_residual << " becomes "
|
||||
<< residual << ". Early termination." << std::endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
delete[] docs_l2sq;
|
||||
if (!ret_closest_docs)
|
||||
delete[] closest_docs;
|
||||
if (!ret_closest_center)
|
||||
delete[] closest_center;
|
||||
return residual;
|
||||
}
|
||||
|
||||
// assumes memory allocated for pivot_data as new
|
||||
// float[num_centers*dim]
|
||||
// and select randomly num_centers points as pivots
|
||||
void selecting_pivots(float* data, size_t num_points, size_t dim,
|
||||
float* pivot_data, size_t num_centers) {
|
||||
// pivot_data = new float[num_centers * dim];
|
||||
|
||||
std::vector<size_t> picked;
|
||||
diskann::cout << "Selecting " << num_centers << " pivots from "
|
||||
<< num_points << " points using ";
|
||||
std::random_device rd;
|
||||
auto x = rd();
|
||||
diskann::cout << "random seed " << x << std::endl;
|
||||
std::mt19937 generator(x);
|
||||
std::uniform_int_distribution<size_t> distribution(0, num_points - 1);
|
||||
|
||||
size_t tmp_pivot;
|
||||
for (size_t j = 0; j < num_centers; j++) {
|
||||
tmp_pivot = distribution(generator);
|
||||
if (std::find(picked.begin(), picked.end(), tmp_pivot) != picked.end())
|
||||
continue;
|
||||
picked.push_back(tmp_pivot);
|
||||
std::memcpy(pivot_data + j * dim, data + tmp_pivot * dim,
|
||||
dim * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
void kmeanspp_selecting_pivots(float* data, size_t num_points, size_t dim,
|
||||
float* pivot_data, size_t num_centers) {
|
||||
if (num_points > 1 << 23) {
|
||||
diskann::cout << "ERROR: n_pts " << num_points
|
||||
<< " currently not supported for k-means++, maximum is "
|
||||
"8388608. Falling back to random pivot "
|
||||
"selection."
|
||||
<< std::endl;
|
||||
selecting_pivots(data, num_points, dim, pivot_data, num_centers);
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<size_t> picked;
|
||||
diskann::cout << "Selecting " << num_centers << " pivots from "
|
||||
<< num_points << " points using ";
|
||||
std::random_device rd;
|
||||
auto x = rd();
|
||||
diskann::cout << "random seed " << x << ": " << std::flush;
|
||||
std::mt19937 generator(x);
|
||||
std::uniform_real_distribution<> distribution(0, 1);
|
||||
std::uniform_int_distribution<size_t> int_dist(0, num_points - 1);
|
||||
size_t init_id = int_dist(generator);
|
||||
size_t num_picked = 1;
|
||||
|
||||
picked.push_back(init_id);
|
||||
std::memcpy(pivot_data, data + init_id * dim, dim * sizeof(float));
|
||||
|
||||
float* dist = new float[num_points];
|
||||
|
||||
#pragma omp parallel for schedule(static, 8192)
|
||||
for (int64_t i = 0; i < (_s64) num_points; i++) {
|
||||
dist[i] =
|
||||
math_utils::calc_distance(data + i * dim, data + init_id * dim, dim);
|
||||
}
|
||||
|
||||
double dart_val;
|
||||
size_t tmp_pivot;
|
||||
bool sum_flag = false;
|
||||
|
||||
while (num_picked < num_centers) {
|
||||
dart_val = distribution(generator);
|
||||
|
||||
double sum = 0;
|
||||
for (size_t i = 0; i < num_points; i++) {
|
||||
sum = sum + dist[i];
|
||||
}
|
||||
if (sum == 0)
|
||||
sum_flag = true;
|
||||
|
||||
dart_val *= sum;
|
||||
|
||||
double prefix_sum = 0;
|
||||
for (size_t i = 0; i < (num_points); i++) {
|
||||
tmp_pivot = i;
|
||||
if (dart_val >= prefix_sum && dart_val < prefix_sum + dist[i]) {
|
||||
break;
|
||||
}
|
||||
|
||||
prefix_sum += dist[i];
|
||||
}
|
||||
|
||||
if (std::find(picked.begin(), picked.end(), tmp_pivot) != picked.end() &&
|
||||
(sum_flag == false))
|
||||
continue;
|
||||
picked.push_back(tmp_pivot);
|
||||
std::memcpy(pivot_data + num_picked * dim, data + tmp_pivot * dim,
|
||||
dim * sizeof(float));
|
||||
|
||||
#pragma omp parallel for schedule(static, 8192)
|
||||
for (int64_t i = 0; i < (_s64) num_points; i++) {
|
||||
dist[i] = (std::min)(
|
||||
dist[i], math_utils::calc_distance(data + i * dim,
|
||||
data + tmp_pivot * dim, dim));
|
||||
}
|
||||
num_picked++;
|
||||
if (num_picked % 32 == 0)
|
||||
diskann::cout << "." << std::flush;
|
||||
}
|
||||
diskann::cout << "done." << std::endl;
|
||||
delete[] dist;
|
||||
}
|
||||
|
||||
} // namespace kmeans
|
|
@ -0,0 +1,98 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "logger.h"
|
||||
#include "memory_mapper.h"
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
using namespace diskann;
|
||||
|
||||
MemoryMapper::MemoryMapper(const std::string& filename)
|
||||
: MemoryMapper(filename.c_str()) {
|
||||
}
|
||||
|
||||
MemoryMapper::MemoryMapper(const char* filename) {
|
||||
#ifndef _WINDOWS
|
||||
_fd = open(filename, O_RDONLY);
|
||||
if (_fd <= 0) {
|
||||
std::cerr << "Inner vertices file not found" << std::endl;
|
||||
return;
|
||||
}
|
||||
struct stat sb;
|
||||
if (fstat(_fd, &sb) != 0) {
|
||||
std::cerr << "Inner vertices file not dound. " << std::endl;
|
||||
return;
|
||||
}
|
||||
_fileSize = sb.st_size;
|
||||
diskann::cout << "File Size: " << _fileSize << std::endl;
|
||||
_buf = (char*) mmap(NULL, _fileSize, PROT_READ, MAP_PRIVATE, _fd, 0);
|
||||
#else
|
||||
_bareFile = CreateFileA(filename, GENERIC_READ | GENERIC_EXECUTE, 0, NULL,
|
||||
OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
|
||||
if (_bareFile == nullptr) {
|
||||
std::ostringstream message;
|
||||
message << "CreateFileA(" << filename << ") failed with error "
|
||||
<< GetLastError() << std::endl;
|
||||
std::cerr << message.str();
|
||||
throw std::exception(message.str().c_str());
|
||||
}
|
||||
|
||||
_fd = CreateFileMapping(_bareFile, NULL, PAGE_EXECUTE_READ, 0, 0, NULL);
|
||||
if (_fd == nullptr) {
|
||||
std::ostringstream message;
|
||||
message << "CreateFileMapping(" << filename << ") failed with error "
|
||||
<< GetLastError() << std::endl;
|
||||
std::cerr << message.str() << std::endl;
|
||||
throw std::exception(message.str().c_str());
|
||||
}
|
||||
|
||||
_buf = (char*) MapViewOfFile(_fd, FILE_MAP_READ, 0, 0, 0);
|
||||
if (_buf == nullptr) {
|
||||
std::ostringstream message;
|
||||
message << "MapViewOfFile(" << filename
|
||||
<< ") failed with error: " << GetLastError() << std::endl;
|
||||
std::cerr << message.str() << std::endl;
|
||||
throw std::exception(message.str().c_str());
|
||||
}
|
||||
|
||||
LARGE_INTEGER fSize;
|
||||
if (TRUE == GetFileSizeEx(_bareFile, &fSize)) {
|
||||
_fileSize = fSize.QuadPart; // take the 64-bit value
|
||||
diskann::cout << "File Size: " << _fileSize << std::endl;
|
||||
} else {
|
||||
std::cerr << "Failed to get size of file " << filename << std::endl;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
char* MemoryMapper::getBuf() {
|
||||
return _buf;
|
||||
}
|
||||
|
||||
size_t MemoryMapper::getFileSize() {
|
||||
return _fileSize;
|
||||
}
|
||||
|
||||
MemoryMapper::~MemoryMapper() {
|
||||
#ifndef _WINDOWS
|
||||
if (munmap(_buf, _fileSize) != 0)
|
||||
std::cerr << "ERROR unmapping. CHECK!" << std::endl;
|
||||
close(_fd);
|
||||
#else
|
||||
if (FALSE == UnmapViewOfFile(_buf)) {
|
||||
std::cerr << "Unmap view of file failed. Error: " << GetLastError()
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if (FALSE == CloseHandle(_fd)) {
|
||||
std::cerr << "Failed to close memory mapped file. Error: " << GetLastError()
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if (FALSE == CloseHandle(_bareFile)) {
|
||||
std::cerr << "Failed to close file: " << _fileName
|
||||
<< " Error: " << GetLastError() << std::endl;
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
|
@ -0,0 +1,899 @@
|
|||
#include <math_utils.h>
|
||||
#include <omp.h>
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <ctime>
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <iterator>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include "logger.h"
|
||||
#include "exceptions.h"
|
||||
#include "index.h"
|
||||
#include "parameters.h"
|
||||
#include "tsl/robin_set.h"
|
||||
#include "utils.h"
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <sys/stat.h>
|
||||
#include <time.h>
|
||||
#include <typeinfo>
|
||||
#include <tsl/robin_map.h>
|
||||
|
||||
#include <cassert>
|
||||
#include "memory_mapper.h"
|
||||
#include "partition_and_pq.h"
|
||||
#ifdef _WINDOWS
|
||||
#include <xmmintrin.h>
|
||||
#endif
|
||||
|
||||
#define BLOCK_SIZE 5000000
|
||||
|
||||
template<typename T>
|
||||
void gen_random_slice(const std::string base_file,
|
||||
const std::string output_prefix, double sampling_rate) {
|
||||
_u64 read_blk_size = 64 * 1024 * 1024;
|
||||
cached_ifstream base_reader(base_file.c_str(), read_blk_size);
|
||||
std::ofstream sample_writer(std::string(output_prefix + "_data.bin").c_str(),
|
||||
std::ios::binary);
|
||||
std::ofstream sample_id_writer(
|
||||
std::string(output_prefix + "_ids.bin").c_str(), std::ios::binary);
|
||||
|
||||
std::random_device
|
||||
rd; // Will be used to obtain a seed for the random number engine
|
||||
auto x = rd();
|
||||
std::mt19937 generator(
|
||||
x); // Standard mersenne_twister_engine seeded with rd()
|
||||
std::uniform_real_distribution<float> distribution(0, 1);
|
||||
|
||||
size_t npts, nd;
|
||||
uint32_t npts_u32, nd_u32;
|
||||
uint32_t num_sampled_pts_u32 = 0;
|
||||
uint32_t one_const = 1;
|
||||
|
||||
base_reader.read((char *) &npts_u32, sizeof(uint32_t));
|
||||
base_reader.read((char *) &nd_u32, sizeof(uint32_t));
|
||||
diskann::cout << "Loading base " << base_file << ". #points: " << npts_u32
|
||||
<< ". #dim: " << nd_u32 << "." << std::endl;
|
||||
sample_writer.write((char *) &num_sampled_pts_u32, sizeof(uint32_t));
|
||||
sample_writer.write((char *) &nd_u32, sizeof(uint32_t));
|
||||
sample_id_writer.write((char *) &num_sampled_pts_u32, sizeof(uint32_t));
|
||||
sample_id_writer.write((char *) &one_const, sizeof(uint32_t));
|
||||
|
||||
npts = npts_u32;
|
||||
nd = nd_u32;
|
||||
std::unique_ptr<T[]> cur_row = std::make_unique<T[]>(nd);
|
||||
|
||||
for (size_t i = 0; i < npts; i++) {
|
||||
base_reader.read((char *) cur_row.get(), sizeof(T) * nd);
|
||||
float sample = distribution(generator);
|
||||
if (sample < sampling_rate) {
|
||||
sample_writer.write((char *) cur_row.get(), sizeof(T) * nd);
|
||||
uint32_t cur_i_u32 = (_u32) i;
|
||||
sample_id_writer.write((char *) &cur_i_u32, sizeof(uint32_t));
|
||||
num_sampled_pts_u32++;
|
||||
}
|
||||
}
|
||||
sample_writer.seekp(0, std::ios::beg);
|
||||
sample_writer.write((char *) &num_sampled_pts_u32, sizeof(uint32_t));
|
||||
sample_id_writer.seekp(0, std::ios::beg);
|
||||
sample_id_writer.write((char *) &num_sampled_pts_u32, sizeof(uint32_t));
|
||||
sample_writer.close();
|
||||
sample_id_writer.close();
|
||||
diskann::cout << "Wrote " << num_sampled_pts_u32
|
||||
<< " points to sample file: " << output_prefix + "_data.bin"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// streams data from the file, and samples each vector with probability p_val
|
||||
// and returns a matrix of size slice_size* ndims as floating point type.
|
||||
// the slice_size and ndims are set inside the function.
|
||||
|
||||
/***********************************
|
||||
* Reimplement using gen_random_slice(const T* inputdata,...)
|
||||
************************************/
|
||||
|
||||
template<typename T>
|
||||
void gen_random_slice(const std::string data_file, double p_val,
|
||||
float *&sampled_data, size_t &slice_size, size_t &ndims) {
|
||||
size_t npts;
|
||||
uint32_t npts32, ndims32;
|
||||
std::vector<std::vector<float>> sampled_vectors;
|
||||
|
||||
// amount to read in one shot
|
||||
_u64 read_blk_size = 64 * 1024 * 1024;
|
||||
// create cached reader + writer
|
||||
cached_ifstream base_reader(data_file.c_str(), read_blk_size);
|
||||
|
||||
// metadata: npts, ndims
|
||||
base_reader.read((char *) &npts32, sizeof(unsigned));
|
||||
base_reader.read((char *) &ndims32, sizeof(unsigned));
|
||||
npts = npts32;
|
||||
ndims = ndims32;
|
||||
|
||||
std::unique_ptr<T[]> cur_vector_T = std::make_unique<T[]>(ndims);
|
||||
p_val = p_val < 1 ? p_val : 1;
|
||||
|
||||
std::random_device rd; // Will be used to obtain a seed for the random number
|
||||
size_t x = rd();
|
||||
std::mt19937 generator((unsigned) x);
|
||||
std::uniform_real_distribution<float> distribution(0, 1);
|
||||
|
||||
for (size_t i = 0; i < npts; i++) {
|
||||
base_reader.read((char *) cur_vector_T.get(), ndims * sizeof(T));
|
||||
float rnd_val = distribution(generator);
|
||||
if (rnd_val < p_val) {
|
||||
std::vector<float> cur_vector_float;
|
||||
for (size_t d = 0; d < ndims; d++)
|
||||
cur_vector_float.push_back(cur_vector_T[d]);
|
||||
sampled_vectors.push_back(cur_vector_float);
|
||||
}
|
||||
}
|
||||
slice_size = sampled_vectors.size();
|
||||
sampled_data = new float[slice_size * ndims];
|
||||
for (size_t i = 0; i < slice_size; i++) {
|
||||
for (size_t j = 0; j < ndims; j++) {
|
||||
sampled_data[i * ndims + j] = sampled_vectors[i][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// same as above, but samples from the matrix inputdata instead of a file of
|
||||
// npts*ndims to return sampled_data of size slice_size*ndims.
|
||||
template<typename T>
|
||||
void gen_random_slice(const T *inputdata, size_t npts, size_t ndims,
|
||||
double p_val, float *&sampled_data, size_t &slice_size) {
|
||||
std::vector<std::vector<float>> sampled_vectors;
|
||||
const T * cur_vector_T;
|
||||
|
||||
p_val = p_val < 1 ? p_val : 1;
|
||||
|
||||
std::random_device
|
||||
rd; // Will be used to obtain a seed for the random number engine
|
||||
size_t x = rd();
|
||||
std::mt19937 generator(
|
||||
(unsigned) x); // Standard mersenne_twister_engine seeded with rd()
|
||||
std::uniform_real_distribution<float> distribution(0, 1);
|
||||
|
||||
for (size_t i = 0; i < npts; i++) {
|
||||
cur_vector_T = inputdata + ndims * i;
|
||||
float rnd_val = distribution(generator);
|
||||
if (rnd_val < p_val) {
|
||||
std::vector<float> cur_vector_float;
|
||||
for (size_t d = 0; d < ndims; d++)
|
||||
cur_vector_float.push_back(cur_vector_T[d]);
|
||||
sampled_vectors.push_back(cur_vector_float);
|
||||
}
|
||||
}
|
||||
slice_size = sampled_vectors.size();
|
||||
sampled_data = new float[slice_size * ndims];
|
||||
for (size_t i = 0; i < slice_size; i++) {
|
||||
for (size_t j = 0; j < ndims; j++) {
|
||||
sampled_data[i * ndims + j] = sampled_vectors[i][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// given training data in train_data of dimensions num_train * dim, generate PQ
|
||||
// pivots using k-means algorithm to partition the co-ordinates into
|
||||
// num_pq_chunks (if it divides dimension, else rounded) chunks, and runs
|
||||
// k-means in each chunk to compute the PQ pivots and stores in bin format in
|
||||
// file pq_pivots_path as a s num_centers*dim floating point binary file
|
||||
int generate_pq_pivots(const float *passed_train_data, size_t num_train,
|
||||
unsigned dim, unsigned num_centers,
|
||||
unsigned num_pq_chunks, unsigned max_k_means_reps,
|
||||
std::string pq_pivots_path) {
|
||||
if (num_pq_chunks > dim) {
|
||||
diskann::cout << " Error: number of chunks more than dimension"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::unique_ptr<float[]> train_data =
|
||||
std::make_unique<float[]>(num_train * dim);
|
||||
std::memcpy(train_data.get(), passed_train_data,
|
||||
num_train * dim * sizeof(float));
|
||||
|
||||
for (uint64_t i = 0; i < num_train; i++) {
|
||||
for (uint64_t j = 0; j < dim; j++) {
|
||||
if (passed_train_data[i * dim + j] != train_data[i * dim + j])
|
||||
diskann::cout << "error in copy" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<float[]> full_pivot_data;
|
||||
|
||||
if (file_exists(pq_pivots_path)) {
|
||||
size_t file_dim, file_num_centers;
|
||||
diskann::load_bin<float>(pq_pivots_path, full_pivot_data, file_num_centers,
|
||||
file_dim);
|
||||
if (file_dim == dim && file_num_centers == num_centers) {
|
||||
diskann::cout << "PQ pivot file exists. Not generating again"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate centroid and center the training data
|
||||
std::unique_ptr<float[]> centroid = std::make_unique<float[]>(dim);
|
||||
for (uint64_t d = 0; d < dim; d++) {
|
||||
centroid[d] = 0;
|
||||
for (uint64_t p = 0; p < num_train; p++) {
|
||||
centroid[d] += train_data[p * dim + d];
|
||||
}
|
||||
centroid[d] /= num_train;
|
||||
}
|
||||
|
||||
// std::memset(centroid, 0 , dim*sizeof(float));
|
||||
|
||||
for (uint64_t d = 0; d < dim; d++) {
|
||||
for (uint64_t p = 0; p < num_train; p++) {
|
||||
train_data[p * dim + d] -= centroid[d];
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<uint32_t> rearrangement;
|
||||
std::vector<uint32_t> chunk_offsets;
|
||||
|
||||
size_t low_val = (size_t) std::floor((double) dim / (double) num_pq_chunks);
|
||||
size_t high_val = (size_t) std::ceil((double) dim / (double) num_pq_chunks);
|
||||
size_t max_num_high = dim - (low_val * num_pq_chunks);
|
||||
size_t cur_num_high = 0;
|
||||
size_t cur_bin_threshold = high_val;
|
||||
|
||||
std::vector<std::vector<uint32_t>> bin_to_dims(num_pq_chunks);
|
||||
tsl::robin_map<uint32_t, uint32_t> dim_to_bin;
|
||||
std::vector<float> bin_loads(num_pq_chunks, 0);
|
||||
|
||||
// Process dimensions not inserted by previous loop
|
||||
for (uint32_t d = 0; d < dim; d++) {
|
||||
if (dim_to_bin.find(d) != dim_to_bin.end())
|
||||
continue;
|
||||
auto cur_best = num_pq_chunks + 1;
|
||||
float cur_best_load = std::numeric_limits<float>::max();
|
||||
for (uint32_t b = 0; b < num_pq_chunks; b++) {
|
||||
if (bin_loads[b] < cur_best_load &&
|
||||
bin_to_dims[b].size() < cur_bin_threshold) {
|
||||
cur_best = b;
|
||||
cur_best_load = bin_loads[b];
|
||||
}
|
||||
}
|
||||
diskann::cout << " Pushing " << d << " into bin #: " << cur_best
|
||||
<< std::endl;
|
||||
bin_to_dims[cur_best].push_back(d);
|
||||
if (bin_to_dims[cur_best].size() == high_val) {
|
||||
cur_num_high++;
|
||||
if (cur_num_high == max_num_high)
|
||||
cur_bin_threshold = low_val;
|
||||
}
|
||||
}
|
||||
|
||||
rearrangement.clear();
|
||||
chunk_offsets.clear();
|
||||
chunk_offsets.push_back(0);
|
||||
|
||||
for (uint32_t b = 0; b < num_pq_chunks; b++) {
|
||||
diskann::cout << "[ ";
|
||||
for (auto p : bin_to_dims[b]) {
|
||||
rearrangement.push_back(p);
|
||||
diskann::cout << p << ",";
|
||||
}
|
||||
diskann::cout << "] " << std::endl;
|
||||
if (b > 0)
|
||||
chunk_offsets.push_back(chunk_offsets[b - 1] +
|
||||
(unsigned) bin_to_dims[b - 1].size());
|
||||
}
|
||||
chunk_offsets.push_back(dim);
|
||||
|
||||
diskann::cout << "\nCross-checking rearranged order of coordinates:"
|
||||
<< std::endl;
|
||||
for (auto p : rearrangement)
|
||||
diskann::cout << p << " ";
|
||||
diskann::cout << std::endl;
|
||||
|
||||
full_pivot_data.reset(new float[num_centers * dim]);
|
||||
|
||||
for (size_t i = 0; i < num_pq_chunks; i++) {
|
||||
size_t cur_chunk_size = chunk_offsets[i + 1] - chunk_offsets[i];
|
||||
|
||||
if (cur_chunk_size == 0)
|
||||
continue;
|
||||
std::unique_ptr<float[]> cur_pivot_data =
|
||||
std::make_unique<float[]>(num_centers * cur_chunk_size);
|
||||
std::unique_ptr<float[]> cur_data =
|
||||
std::make_unique<float[]>(num_train * cur_chunk_size);
|
||||
std::unique_ptr<uint32_t[]> closest_center =
|
||||
std::make_unique<uint32_t[]>(num_train);
|
||||
|
||||
diskann::cout << "Processing chunk " << i << " with dimensions ["
|
||||
<< chunk_offsets[i] << ", " << chunk_offsets[i + 1] << ")"
|
||||
<< std::endl;
|
||||
|
||||
#pragma omp parallel for schedule(static, 65536)
|
||||
for (int64_t j = 0; j < (_s64) num_train; j++) {
|
||||
std::memcpy(cur_data.get() + j * cur_chunk_size,
|
||||
train_data.get() + j * dim + chunk_offsets[i],
|
||||
cur_chunk_size * sizeof(float));
|
||||
}
|
||||
|
||||
kmeans::kmeanspp_selecting_pivots(cur_data.get(), num_train, cur_chunk_size,
|
||||
cur_pivot_data.get(), num_centers);
|
||||
|
||||
kmeans::run_lloyds(cur_data.get(), num_train, cur_chunk_size,
|
||||
cur_pivot_data.get(), num_centers, max_k_means_reps,
|
||||
NULL, closest_center.get());
|
||||
|
||||
for (uint64_t j = 0; j < num_centers; j++) {
|
||||
std::memcpy(full_pivot_data.get() + j * dim + chunk_offsets[i],
|
||||
cur_pivot_data.get() + j * cur_chunk_size,
|
||||
cur_chunk_size * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
diskann::save_bin<float>(pq_pivots_path.c_str(), full_pivot_data.get(),
|
||||
(size_t) num_centers, dim);
|
||||
std::string centroids_path = pq_pivots_path + "_centroid.bin";
|
||||
diskann::save_bin<float>(centroids_path.c_str(), centroid.get(), (size_t) dim,
|
||||
1);
|
||||
std::string rearrangement_path = pq_pivots_path + "_rearrangement_perm.bin";
|
||||
diskann::save_bin<uint32_t>(rearrangement_path.c_str(), rearrangement.data(),
|
||||
rearrangement.size(), 1);
|
||||
std::string chunk_offsets_path = pq_pivots_path + "_chunk_offsets.bin";
|
||||
diskann::save_bin<uint32_t>(chunk_offsets_path.c_str(), chunk_offsets.data(),
|
||||
chunk_offsets.size(), 1);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// streams the base file (data_file), and computes the closest centers in each
|
||||
// chunk to generate the compressed data_file and stores it in
|
||||
// pq_compressed_vectors_path.
|
||||
// If the numbber of centers is < 256, it stores as byte vector, else as 4-byte
|
||||
// vector in binary format.
|
||||
template<typename T>
|
||||
int generate_pq_data_from_pivots(const std::string data_file,
|
||||
unsigned num_centers, unsigned num_pq_chunks,
|
||||
std::string pq_pivots_path,
|
||||
std::string pq_compressed_vectors_path) {
|
||||
_u64 read_blk_size = 64 * 1024 * 1024;
|
||||
cached_ifstream base_reader(data_file, read_blk_size);
|
||||
_u32 npts32;
|
||||
_u32 basedim32;
|
||||
base_reader.read((char *) &npts32, sizeof(uint32_t));
|
||||
base_reader.read((char *) &basedim32, sizeof(uint32_t));
|
||||
size_t num_points = npts32;
|
||||
size_t dim = basedim32;
|
||||
|
||||
std::unique_ptr<float[]> full_pivot_data;
|
||||
std::unique_ptr<float[]> centroid;
|
||||
std::unique_ptr<uint32_t[]> rearrangement;
|
||||
std::unique_ptr<uint32_t[]> chunk_offsets;
|
||||
|
||||
if (!file_exists(pq_pivots_path)) {
|
||||
diskann::cout << "ERROR: PQ k-means pivot file not found" << std::endl;
|
||||
throw diskann::ANNException("PQ k-means pivot file not found", -1);
|
||||
} else {
|
||||
uint64_t numr, numc;
|
||||
|
||||
std::string centroids_path = pq_pivots_path + "_centroid.bin";
|
||||
diskann::load_bin<float>(centroids_path.c_str(), centroid, numr, numc);
|
||||
|
||||
if (numr != dim || numc != 1) {
|
||||
diskann::cout << "Error reading centroid file." << std::endl;
|
||||
throw diskann::ANNException("Error reading centroid file.", -1,
|
||||
__FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
std::string rearrangement_path = pq_pivots_path + "_rearrangement_perm.bin";
|
||||
diskann::load_bin<uint32_t>(rearrangement_path.c_str(), rearrangement, numr,
|
||||
numc);
|
||||
if (numr != dim || numc != 1) {
|
||||
diskann::cout << "Error reading rearrangement file." << std::endl;
|
||||
throw diskann::ANNException("Error reading rearrangement file.", -1,
|
||||
__FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
std::string chunk_offsets_path = pq_pivots_path + "_chunk_offsets.bin";
|
||||
diskann::load_bin<uint32_t>(chunk_offsets_path.c_str(), chunk_offsets, numr,
|
||||
numc);
|
||||
if (numr != (uint64_t) num_pq_chunks + 1 || numc != 1) {
|
||||
diskann::cout << "Error reading chunk offsets file." << std::endl;
|
||||
throw diskann::ANNException("Error reading chunk offsets file.", -1,
|
||||
__FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
|
||||
size_t file_num_centers;
|
||||
size_t file_dim;
|
||||
diskann::load_bin<float>(pq_pivots_path, full_pivot_data, file_num_centers,
|
||||
file_dim);
|
||||
|
||||
if (file_num_centers != num_centers) {
|
||||
std::stringstream stream;
|
||||
stream << "ERROR: file number of PQ centers " << file_num_centers
|
||||
<< " does "
|
||||
"not match input argument "
|
||||
<< num_centers << std::endl;
|
||||
diskann::cout << stream.str() << std::endl;
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__,
|
||||
__LINE__);
|
||||
}
|
||||
if (file_dim != dim) {
|
||||
std::stringstream stream;
|
||||
stream << "ERROR: PQ pivot dimension does "
|
||||
"not match base file dimension"
|
||||
<< std::endl;
|
||||
diskann::cout << stream.str() << std::endl;
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__,
|
||||
__LINE__);
|
||||
}
|
||||
diskann::cout << "Loaded PQ pivot information" << std::endl;
|
||||
}
|
||||
|
||||
std::ofstream compressed_file_writer(pq_compressed_vectors_path,
|
||||
std::ios::binary);
|
||||
_u32 num_pq_chunks_u32 = num_pq_chunks;
|
||||
|
||||
compressed_file_writer.write((char *) &num_points, sizeof(uint32_t));
|
||||
compressed_file_writer.write((char *) &num_pq_chunks_u32, sizeof(uint32_t));
|
||||
|
||||
size_t block_size = num_points <= BLOCK_SIZE ? num_points : BLOCK_SIZE;
|
||||
std::unique_ptr<_u32[]> block_compressed_base =
|
||||
std::make_unique<_u32[]>(block_size * (_u64) num_pq_chunks);
|
||||
std::memset(block_compressed_base.get(), 0,
|
||||
block_size * (_u64) num_pq_chunks * sizeof(uint32_t));
|
||||
|
||||
std::unique_ptr<T[]> block_data_T = std::make_unique<T[]>(block_size * dim);
|
||||
std::unique_ptr<float[]> block_data_float =
|
||||
std::make_unique<float[]>(block_size * dim);
|
||||
std::unique_ptr<float[]> block_data_tmp =
|
||||
std::make_unique<float[]>(block_size * dim);
|
||||
|
||||
size_t num_blocks = DIV_ROUND_UP(num_points, block_size);
|
||||
|
||||
for (size_t block = 0; block < num_blocks; block++) {
|
||||
size_t start_id = block * block_size;
|
||||
size_t end_id = (std::min)((block + 1) * block_size, num_points);
|
||||
size_t cur_blk_size = end_id - start_id;
|
||||
|
||||
base_reader.read((char *) (block_data_T.get()),
|
||||
sizeof(T) * (cur_blk_size * dim));
|
||||
diskann::convert_types<T, float>(block_data_T.get(), block_data_tmp.get(),
|
||||
cur_blk_size, dim);
|
||||
|
||||
diskann::cout << "Processing points [" << start_id << ", " << end_id
|
||||
<< ").." << std::flush;
|
||||
|
||||
for (uint64_t p = 0; p < cur_blk_size; p++) {
|
||||
for (uint64_t d = 0; d < dim; d++) {
|
||||
block_data_tmp[p * dim + d] -= centroid[d];
|
||||
}
|
||||
}
|
||||
|
||||
for (uint64_t p = 0; p < cur_blk_size; p++) {
|
||||
for (uint64_t d = 0; d < dim; d++) {
|
||||
block_data_float[p * dim + d] =
|
||||
block_data_tmp[p * dim + rearrangement[d]];
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < num_pq_chunks; i++) {
|
||||
size_t cur_chunk_size = chunk_offsets[i + 1] - chunk_offsets[i];
|
||||
if (cur_chunk_size == 0)
|
||||
continue;
|
||||
|
||||
std::unique_ptr<float[]> cur_pivot_data =
|
||||
std::make_unique<float[]>(num_centers * cur_chunk_size);
|
||||
std::unique_ptr<float[]> cur_data =
|
||||
std::make_unique<float[]>(cur_blk_size * cur_chunk_size);
|
||||
std::unique_ptr<uint32_t[]> closest_center =
|
||||
std::make_unique<uint32_t[]>(cur_blk_size);
|
||||
|
||||
#pragma omp parallel for schedule(static, 8192)
|
||||
for (int64_t j = 0; j < (_s64) cur_blk_size; j++) {
|
||||
for (uint64_t k = 0; k < cur_chunk_size; k++)
|
||||
cur_data[j * cur_chunk_size + k] =
|
||||
block_data_float[j * dim + chunk_offsets[i] + k];
|
||||
}
|
||||
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
for (int64_t j = 0; j < (_s64) num_centers; j++) {
|
||||
std::memcpy(cur_pivot_data.get() + j * cur_chunk_size,
|
||||
full_pivot_data.get() + j * dim + chunk_offsets[i],
|
||||
cur_chunk_size * sizeof(float));
|
||||
}
|
||||
|
||||
math_utils::compute_closest_centers(cur_data.get(), cur_blk_size,
|
||||
cur_chunk_size, cur_pivot_data.get(),
|
||||
num_centers, 1, closest_center.get());
|
||||
|
||||
#pragma omp parallel for schedule(static, 8192)
|
||||
for (int64_t j = 0; j < (_s64) cur_blk_size; j++) {
|
||||
block_compressed_base[j * num_pq_chunks + i] = closest_center[j];
|
||||
}
|
||||
}
|
||||
|
||||
if (num_centers > 256) {
|
||||
compressed_file_writer.write(
|
||||
(char *) (block_compressed_base.get()),
|
||||
cur_blk_size * num_pq_chunks * sizeof(uint32_t));
|
||||
} else {
|
||||
std::unique_ptr<uint8_t[]> pVec =
|
||||
std::make_unique<uint8_t[]>(cur_blk_size * num_pq_chunks);
|
||||
diskann::convert_types<uint32_t, uint8_t>(
|
||||
block_compressed_base.get(), pVec.get(), cur_blk_size, num_pq_chunks);
|
||||
compressed_file_writer.write(
|
||||
(char *) (pVec.get()),
|
||||
cur_blk_size * num_pq_chunks * sizeof(uint8_t));
|
||||
}
|
||||
diskann::cout << ".done." << std::endl;
|
||||
}
|
||||
// Gopal. Splittng nsg_dll into separate DLLs for search and build.
|
||||
// This code should only be available in the "build" DLL.
|
||||
#ifdef DISKANN_BUILD
|
||||
MallocExtension::instance()->ReleaseFreeMemory();
|
||||
#endif
|
||||
compressed_file_writer.close();
|
||||
return 0;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
int estimate_cluster_sizes(const std::string data_file, float *pivots,
|
||||
const size_t num_centers, const size_t dim,
|
||||
const size_t k_base,
|
||||
std::vector<size_t> &cluster_sizes) {
|
||||
cluster_sizes.clear();
|
||||
|
||||
size_t num_test, test_dim;
|
||||
float *test_data_float;
|
||||
double sampling_rate = 0.01;
|
||||
|
||||
gen_random_slice<T>(data_file, sampling_rate, test_data_float, num_test,
|
||||
test_dim);
|
||||
|
||||
if (test_dim != dim) {
|
||||
diskann::cout << "Error. dimensions dont match for pivot set and base set"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
size_t *shard_counts = new size_t[num_centers];
|
||||
|
||||
for (size_t i = 0; i < num_centers; i++) {
|
||||
shard_counts[i] = 0;
|
||||
}
|
||||
|
||||
size_t num_points = 0, num_dim = 0;
|
||||
diskann::get_bin_metadata(data_file, num_points, num_dim);
|
||||
size_t block_size = num_points <= BLOCK_SIZE ? num_points : BLOCK_SIZE;
|
||||
_u32 * block_closest_centers = new _u32[block_size * k_base];
|
||||
float *block_data_float;
|
||||
|
||||
size_t num_blocks = DIV_ROUND_UP(num_test, block_size);
|
||||
|
||||
for (size_t block = 0; block < num_blocks; block++) {
|
||||
size_t start_id = block * block_size;
|
||||
size_t end_id = (std::min)((block + 1) * block_size, num_test);
|
||||
size_t cur_blk_size = end_id - start_id;
|
||||
|
||||
block_data_float = test_data_float + start_id * test_dim;
|
||||
|
||||
math_utils::compute_closest_centers(block_data_float, cur_blk_size, dim,
|
||||
pivots, num_centers, k_base,
|
||||
block_closest_centers);
|
||||
|
||||
for (size_t p = 0; p < cur_blk_size; p++) {
|
||||
for (size_t p1 = 0; p1 < k_base; p1++) {
|
||||
size_t shard_id = block_closest_centers[p * k_base + p1];
|
||||
shard_counts[shard_id]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
diskann::cout << "Estimated cluster sizes: ";
|
||||
for (size_t i = 0; i < num_centers; i++) {
|
||||
_u32 cur_shard_count = (_u32) shard_counts[i];
|
||||
cluster_sizes.push_back(
|
||||
size_t(((double) cur_shard_count) * (1.0 / sampling_rate)));
|
||||
diskann::cout << cur_shard_count * (1.0 / sampling_rate) << " ";
|
||||
}
|
||||
diskann::cout << std::endl;
|
||||
delete[] shard_counts;
|
||||
delete[] block_closest_centers;
|
||||
return 0;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
int shard_data_into_clusters(const std::string data_file, float *pivots,
|
||||
const size_t num_centers, const size_t dim,
|
||||
const size_t k_base, std::string prefix_path) {
|
||||
_u64 read_blk_size = 64 * 1024 * 1024;
|
||||
// _u64 write_blk_size = 64 * 1024 * 1024;
|
||||
// create cached reader + writer
|
||||
cached_ifstream base_reader(data_file, read_blk_size);
|
||||
_u32 npts32;
|
||||
_u32 basedim32;
|
||||
base_reader.read((char *) &npts32, sizeof(uint32_t));
|
||||
base_reader.read((char *) &basedim32, sizeof(uint32_t));
|
||||
size_t num_points = npts32;
|
||||
if (basedim32 != dim) {
|
||||
diskann::cout << "Error. dimensions dont match for train set and base set"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::unique_ptr<size_t[]> shard_counts =
|
||||
std::make_unique<size_t[]>(num_centers);
|
||||
std::vector<std::ofstream> shard_data_writer(num_centers);
|
||||
std::vector<std::ofstream> shard_idmap_writer(num_centers);
|
||||
_u32 dummy_size = 0;
|
||||
_u32 const_one = 1;
|
||||
|
||||
for (size_t i = 0; i < num_centers; i++) {
|
||||
std::string data_filename =
|
||||
prefix_path + "_subshard-" + std::to_string(i) + ".bin";
|
||||
std::string idmap_filename =
|
||||
prefix_path + "_subshard-" + std::to_string(i) + "_ids_uint32.bin";
|
||||
shard_data_writer[i] =
|
||||
std::ofstream(data_filename.c_str(), std::ios::binary);
|
||||
shard_idmap_writer[i] =
|
||||
std::ofstream(idmap_filename.c_str(), std::ios::binary);
|
||||
shard_data_writer[i].write((char *) &dummy_size, sizeof(uint32_t));
|
||||
shard_data_writer[i].write((char *) &basedim32, sizeof(uint32_t));
|
||||
shard_idmap_writer[i].write((char *) &dummy_size, sizeof(uint32_t));
|
||||
shard_idmap_writer[i].write((char *) &const_one, sizeof(uint32_t));
|
||||
shard_counts[i] = 0;
|
||||
}
|
||||
|
||||
size_t block_size = num_points <= BLOCK_SIZE ? num_points : BLOCK_SIZE;
|
||||
std::unique_ptr<_u32[]> block_closest_centers =
|
||||
std::make_unique<_u32[]>(block_size * k_base);
|
||||
std::unique_ptr<T[]> block_data_T = std::make_unique<T[]>(block_size * dim);
|
||||
std::unique_ptr<float[]> block_data_float =
|
||||
std::make_unique<float[]>(block_size * dim);
|
||||
|
||||
size_t num_blocks = DIV_ROUND_UP(num_points, block_size);
|
||||
|
||||
for (size_t block = 0; block < num_blocks; block++) {
|
||||
size_t start_id = block * block_size;
|
||||
size_t end_id = (std::min)((block + 1) * block_size, num_points);
|
||||
size_t cur_blk_size = end_id - start_id;
|
||||
|
||||
base_reader.read((char *) block_data_T.get(),
|
||||
sizeof(T) * (cur_blk_size * dim));
|
||||
diskann::convert_types<T, float>(block_data_T.get(), block_data_float.get(),
|
||||
cur_blk_size, dim);
|
||||
|
||||
math_utils::compute_closest_centers(block_data_float.get(), cur_blk_size,
|
||||
dim, pivots, num_centers, k_base,
|
||||
block_closest_centers.get());
|
||||
|
||||
for (size_t p = 0; p < cur_blk_size; p++) {
|
||||
for (size_t p1 = 0; p1 < k_base; p1++) {
|
||||
size_t shard_id = block_closest_centers[p * k_base + p1];
|
||||
uint32_t original_point_map_id = (uint32_t)(start_id + p);
|
||||
shard_data_writer[shard_id].write(
|
||||
(char *) (block_data_T.get() + p * dim), sizeof(T) * dim);
|
||||
shard_idmap_writer[shard_id].write((char *) &original_point_map_id,
|
||||
sizeof(uint32_t));
|
||||
shard_counts[shard_id]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t total_count = 0;
|
||||
diskann::cout << "Actual shard sizes: " << std::flush;
|
||||
for (size_t i = 0; i < num_centers; i++) {
|
||||
_u32 cur_shard_count = (_u32) shard_counts[i];
|
||||
total_count += cur_shard_count;
|
||||
diskann::cout << cur_shard_count << " ";
|
||||
shard_data_writer[i].seekp(0);
|
||||
shard_data_writer[i].write((char *) &cur_shard_count, sizeof(uint32_t));
|
||||
shard_data_writer[i].close();
|
||||
shard_idmap_writer[i].seekp(0);
|
||||
shard_idmap_writer[i].write((char *) &cur_shard_count, sizeof(uint32_t));
|
||||
shard_idmap_writer[i].close();
|
||||
}
|
||||
|
||||
diskann::cout << "\n Partitioned " << num_points
|
||||
<< " with replication factor " << k_base << " to get "
|
||||
<< total_count << " points across " << num_centers << " shards "
|
||||
<< std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
// partitions a large base file into many shards using k-means hueristic
|
||||
// on a random sample generated using sampling_rate probability. After this, it
|
||||
// assignes each base point to the closest k_base nearest centers and creates
|
||||
// the shards.
|
||||
// The total number of points across all shards will be k_base * num_points.
|
||||
|
||||
template<typename T>
|
||||
int partition(const std::string data_file, const float sampling_rate,
|
||||
size_t num_parts, size_t max_k_means_reps,
|
||||
const std::string prefix_path, size_t k_base) {
|
||||
size_t train_dim;
|
||||
size_t num_train;
|
||||
float *train_data_float;
|
||||
|
||||
gen_random_slice<T>(data_file, sampling_rate, train_data_float, num_train,
|
||||
train_dim);
|
||||
|
||||
float *pivot_data;
|
||||
|
||||
std::string cur_file = std::string(prefix_path);
|
||||
std::string output_file;
|
||||
|
||||
// kmeans_partitioning on training data
|
||||
|
||||
// cur_file = cur_file + "_kmeans_partitioning-" + std::to_string(num_parts);
|
||||
output_file = cur_file + "_centroids.bin";
|
||||
|
||||
pivot_data = new float[num_parts * train_dim];
|
||||
|
||||
// Process Global k-means for kmeans_partitioning Step
|
||||
diskann::cout << "Processing global k-means (kmeans_partitioning Step)"
|
||||
<< std::endl;
|
||||
kmeans::kmeanspp_selecting_pivots(train_data_float, num_train, train_dim,
|
||||
pivot_data, num_parts);
|
||||
|
||||
kmeans::run_lloyds(train_data_float, num_train, train_dim, pivot_data,
|
||||
num_parts, max_k_means_reps, NULL, NULL);
|
||||
|
||||
diskann::cout << "Saving global k-center pivots" << std::endl;
|
||||
diskann::save_bin<float>(output_file.c_str(), pivot_data, (size_t) num_parts,
|
||||
train_dim);
|
||||
|
||||
// now pivots are ready. need to stream base points and assign them to
|
||||
// closest clusters.
|
||||
|
||||
std::vector<size_t> cluster_sizes;
|
||||
estimate_cluster_sizes<T>(data_file, pivot_data, num_parts, train_dim, k_base,
|
||||
cluster_sizes);
|
||||
|
||||
shard_data_into_clusters<T>(data_file, pivot_data, num_parts, train_dim,
|
||||
k_base, prefix_path);
|
||||
delete[] pivot_data;
|
||||
delete[] train_data_float;
|
||||
return 0;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
int partition_with_ram_budget(const std::string data_file,
|
||||
const double sampling_rate, double ram_budget,
|
||||
size_t graph_degree,
|
||||
const std::string prefix_path, size_t k_base) {
|
||||
size_t train_dim;
|
||||
size_t num_train;
|
||||
float *train_data_float;
|
||||
size_t max_k_means_reps = 20;
|
||||
|
||||
int num_parts = 3;
|
||||
bool fit_in_ram = false;
|
||||
|
||||
gen_random_slice<T>(data_file, sampling_rate, train_data_float, num_train,
|
||||
train_dim);
|
||||
|
||||
float *pivot_data = nullptr;
|
||||
|
||||
std::string cur_file = std::string(prefix_path);
|
||||
std::string output_file;
|
||||
|
||||
// kmeans_partitioning on training data
|
||||
|
||||
// cur_file = cur_file + "_kmeans_partitioning-" + std::to_string(num_parts);
|
||||
output_file = cur_file + "_centroids.bin";
|
||||
|
||||
while (!fit_in_ram) {
|
||||
fit_in_ram = true;
|
||||
|
||||
double max_ram_usage = 0;
|
||||
if (pivot_data != nullptr)
|
||||
delete[] pivot_data;
|
||||
|
||||
pivot_data = new float[num_parts * train_dim];
|
||||
// Process Global k-means for kmeans_partitioning Step
|
||||
diskann::cout << "Processing global k-means (kmeans_partitioning Step)"
|
||||
<< std::endl;
|
||||
kmeans::kmeanspp_selecting_pivots(train_data_float, num_train, train_dim,
|
||||
pivot_data, num_parts);
|
||||
|
||||
kmeans::run_lloyds(train_data_float, num_train, train_dim, pivot_data,
|
||||
num_parts, max_k_means_reps, NULL, NULL);
|
||||
|
||||
// now pivots are ready. need to stream base points and assign them to
|
||||
// closest clusters.
|
||||
|
||||
std::vector<size_t> cluster_sizes;
|
||||
estimate_cluster_sizes<T>(data_file, pivot_data, num_parts, train_dim,
|
||||
k_base, cluster_sizes);
|
||||
|
||||
for (auto &p : cluster_sizes) {
|
||||
double cur_shard_ram_estimate =
|
||||
ESTIMATE_RAM_USAGE(p, train_dim, sizeof(T), graph_degree);
|
||||
|
||||
if (cur_shard_ram_estimate > max_ram_usage)
|
||||
max_ram_usage = cur_shard_ram_estimate;
|
||||
}
|
||||
diskann::cout << "With " << num_parts << " parts, max estimated RAM usage: "
|
||||
<< max_ram_usage / (1024 * 1024 * 1024)
|
||||
<< "GB, budget given is " << ram_budget << std::endl;
|
||||
if (max_ram_usage > 1024 * 1024 * 1024 * ram_budget) {
|
||||
fit_in_ram = false;
|
||||
num_parts++;
|
||||
}
|
||||
}
|
||||
|
||||
diskann::cout << "Saving global k-center pivots" << std::endl;
|
||||
diskann::save_bin<float>(output_file.c_str(), pivot_data, (size_t) num_parts,
|
||||
train_dim);
|
||||
|
||||
shard_data_into_clusters<T>(data_file, pivot_data, num_parts, train_dim,
|
||||
k_base, prefix_path);
|
||||
delete[] pivot_data;
|
||||
delete[] train_data_float;
|
||||
return num_parts;
|
||||
}
|
||||
|
||||
// Instantations of supported templates
|
||||
|
||||
template void DISKANN_DLLEXPORT
|
||||
gen_random_slice<int8_t>(const std::string base_file,
|
||||
const std::string output_prefix, double sampling_rate);
|
||||
template void DISKANN_DLLEXPORT gen_random_slice<uint8_t>(
|
||||
const std::string base_file, const std::string output_prefix,
|
||||
double sampling_rate);
|
||||
template void DISKANN_DLLEXPORT
|
||||
gen_random_slice<float>(const std::string base_file,
|
||||
const std::string output_prefix, double sampling_rate);
|
||||
|
||||
template void DISKANN_DLLEXPORT
|
||||
gen_random_slice<float>(const float *inputdata, size_t npts, size_t ndims,
|
||||
double p_val, float *&sampled_data, size_t &slice_size);
|
||||
template void DISKANN_DLLEXPORT gen_random_slice<uint8_t>(
|
||||
const uint8_t *inputdata, size_t npts, size_t ndims, double p_val,
|
||||
float *&sampled_data, size_t &slice_size);
|
||||
template void DISKANN_DLLEXPORT gen_random_slice<int8_t>(
|
||||
const int8_t *inputdata, size_t npts, size_t ndims, double p_val,
|
||||
float *&sampled_data, size_t &slice_size);
|
||||
|
||||
template void DISKANN_DLLEXPORT gen_random_slice<float>(
|
||||
const std::string data_file, double p_val, float *&sampled_data,
|
||||
size_t &slice_size, size_t &ndims);
|
||||
template void DISKANN_DLLEXPORT gen_random_slice<uint8_t>(
|
||||
const std::string data_file, double p_val, float *&sampled_data,
|
||||
size_t &slice_size, size_t &ndims);
|
||||
template void DISKANN_DLLEXPORT gen_random_slice<int8_t>(
|
||||
const std::string data_file, double p_val, float *&sampled_data,
|
||||
size_t &slice_size, size_t &ndims);
|
||||
|
||||
template DISKANN_DLLEXPORT int partition<int8_t>(
|
||||
const std::string data_file, const float sampling_rate, size_t num_centers,
|
||||
size_t max_k_means_reps, const std::string prefix_path, size_t k_base);
|
||||
template DISKANN_DLLEXPORT int partition<uint8_t>(
|
||||
const std::string data_file, const float sampling_rate, size_t num_centers,
|
||||
size_t max_k_means_reps, const std::string prefix_path, size_t k_base);
|
||||
template DISKANN_DLLEXPORT int partition<float>(
|
||||
const std::string data_file, const float sampling_rate, size_t num_centers,
|
||||
size_t max_k_means_reps, const std::string prefix_path, size_t k_base);
|
||||
|
||||
template DISKANN_DLLEXPORT int partition_with_ram_budget<int8_t>(
|
||||
const std::string data_file, const double sampling_rate, double ram_budget,
|
||||
size_t graph_degree, const std::string prefix_path, size_t k_base);
|
||||
template DISKANN_DLLEXPORT int partition_with_ram_budget<uint8_t>(
|
||||
const std::string data_file, const double sampling_rate, double ram_budget,
|
||||
size_t graph_degree, const std::string prefix_path, size_t k_base);
|
||||
template DISKANN_DLLEXPORT int partition_with_ram_budget<float>(
|
||||
const std::string data_file, const double sampling_rate, double ram_budget,
|
||||
size_t graph_degree, const std::string prefix_path, size_t k_base);
|
||||
|
||||
template DISKANN_DLLEXPORT int generate_pq_data_from_pivots<int8_t>(
|
||||
const std::string data_file, unsigned num_centers, unsigned num_pq_chunks,
|
||||
std::string pq_pivots_path, std::string pq_compressed_vectors_path);
|
||||
template DISKANN_DLLEXPORT int generate_pq_data_from_pivots<uint8_t>(
|
||||
const std::string data_file, unsigned num_centers, unsigned num_pq_chunks,
|
||||
std::string pq_pivots_path, std::string pq_compressed_vectors_path);
|
||||
template DISKANN_DLLEXPORT int generate_pq_data_from_pivots<float>(
|
||||
const std::string data_file, unsigned num_centers, unsigned num_pq_chunks,
|
||||
std::string pq_pivots_path, std::string pq_compressed_vectors_path);
|
|
@ -0,0 +1,60 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#ifdef _WINDOWS
|
||||
#include <intrin.h>
|
||||
|
||||
// Taken from:
|
||||
// https://insufficientlycomplicated.wordpress.com/2011/11/07/detecting-intel-advanced-vector-extensions-avx-in-visual-studio/
|
||||
bool cpuHasAvxSupport() {
|
||||
bool avxSupported = false;
|
||||
|
||||
// Checking for AVX requires 3 things:
|
||||
// 1) CPUID indicates that the OS uses XSAVE and XRSTORE
|
||||
// instructions (allowing saving YMM registers on context
|
||||
// switch)
|
||||
// 2) CPUID indicates support for AVX
|
||||
// 3) XGETBV indicates the AVX registers will be saved and
|
||||
// restored on context switch
|
||||
//
|
||||
// Note that XGETBV is only available on 686 or later CPUs, so
|
||||
// the instruction needs to be conditionally run.
|
||||
int cpuInfo[4];
|
||||
__cpuid(cpuInfo, 1);
|
||||
|
||||
bool osUsesXSAVE_XRSTORE = cpuInfo[2] & (1 << 27) || false;
|
||||
bool cpuAVXSuport = cpuInfo[2] & (1 << 28) || false;
|
||||
|
||||
if (osUsesXSAVE_XRSTORE && cpuAVXSuport) {
|
||||
// Check if the OS will save the YMM registers
|
||||
unsigned long long xcrFeatureMask = _xgetbv(_XCR_XFEATURE_ENABLED_MASK);
|
||||
avxSupported = (xcrFeatureMask & 0x6) || false;
|
||||
}
|
||||
|
||||
return avxSupported;
|
||||
}
|
||||
|
||||
bool cpuHasAvx2Support() {
|
||||
int cpuInfo[4];
|
||||
__cpuid(cpuInfo, 0);
|
||||
int n = cpuInfo[0];
|
||||
if (n >= 7) {
|
||||
__cpuidex(cpuInfo, 7, 0);
|
||||
static int avx2Mask = 0x20;
|
||||
return (cpuInfo[1] & avx2Mask) > 0;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifndef _WINDOWS
|
||||
bool AvxSupportedCPU = false;
|
||||
bool Avx2SupportedCPU = true;
|
||||
#else
|
||||
bool AvxSupportedCPU = cpuHasAvxSupport();
|
||||
bool Avx2SupportedCPU = cpuHasAvx2Support();
|
||||
#endif
|
|
@ -0,0 +1,158 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#ifdef _WINDOWS
|
||||
#ifndef USE_BING_INFRA
|
||||
#include "windows_aligned_file_reader.h"
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
#define SECTOR_LEN 4096
|
||||
|
||||
void WindowsAlignedFileReader::open(const std::string& fname) {
|
||||
m_filename = std::wstring(fname.begin(), fname.end());
|
||||
this->register_thread();
|
||||
}
|
||||
|
||||
void WindowsAlignedFileReader::close() {
|
||||
for (auto& k_v : ctx_map) {
|
||||
IOContext ctx = ctx_map[k_v.first];
|
||||
CloseHandle(ctx.fhandle);
|
||||
}
|
||||
}
|
||||
|
||||
void WindowsAlignedFileReader::register_thread() {
|
||||
std::unique_lock<std::mutex> lk(this->ctx_mut);
|
||||
if (this->ctx_map.find(std::this_thread::get_id()) != ctx_map.end()) {
|
||||
diskann::cout << "Warning:: Duplicate registration for thread_id : "
|
||||
<< std::this_thread::get_id() << std::endl;
|
||||
}
|
||||
|
||||
IOContext ctx;
|
||||
ctx.fhandle = CreateFile(m_filename.c_str(), GENERIC_READ, FILE_SHARE_READ,
|
||||
NULL, OPEN_EXISTING,
|
||||
FILE_ATTRIBUTE_READONLY | FILE_FLAG_NO_BUFFERING |
|
||||
FILE_FLAG_OVERLAPPED | FILE_FLAG_RANDOM_ACCESS,
|
||||
NULL);
|
||||
if (ctx.fhandle == INVALID_HANDLE_VALUE) {
|
||||
diskann::cout << "Error opening " << m_filename.c_str()
|
||||
<< " -- error=" << GetLastError() << std::endl;
|
||||
}
|
||||
|
||||
// create IOCompletionPort
|
||||
ctx.iocp = CreateIoCompletionPort(ctx.fhandle, ctx.iocp, 0, 0);
|
||||
|
||||
// create MAX_DEPTH # of reqs
|
||||
for (_u64 i = 0; i < MAX_IO_DEPTH; i++) {
|
||||
OVERLAPPED os;
|
||||
memset(&os, 0, sizeof(OVERLAPPED));
|
||||
// os.hEvent = CreateEventA(NULL, TRUE, FALSE, NULL);
|
||||
ctx.reqs.push_back(os);
|
||||
}
|
||||
this->ctx_map.insert(std::make_pair(std::this_thread::get_id(), ctx));
|
||||
}
|
||||
|
||||
IOContext& WindowsAlignedFileReader::get_ctx() {
|
||||
std::unique_lock<std::mutex> lk(this->ctx_mut);
|
||||
if (ctx_map.find(std::this_thread::get_id()) == ctx_map.end()) {
|
||||
std::stringstream stream;
|
||||
stream << "unable to find IOContext for thread_id : "
|
||||
<< std::this_thread::get_id() << "\n";
|
||||
throw diskann::ANNException(stream.str(), -2, __FUNCSIG__, __FILE__,
|
||||
__LINE__);
|
||||
}
|
||||
IOContext& ctx = ctx_map[std::this_thread::get_id()];
|
||||
lk.unlock();
|
||||
return ctx;
|
||||
}
|
||||
|
||||
void WindowsAlignedFileReader::read(std::vector<AlignedRead>& read_reqs,
|
||||
IOContext& ctx) {
|
||||
using namespace std::chrono_literals;
|
||||
// execute each request sequentially
|
||||
_u64 n_reqs = read_reqs.size();
|
||||
_u64 n_batches = ROUND_UP(n_reqs, MAX_IO_DEPTH) / MAX_IO_DEPTH;
|
||||
for (_u64 i = 0; i < n_batches; i++) {
|
||||
// reset all OVERLAPPED objects
|
||||
for (auto& os : ctx.reqs) {
|
||||
// HANDLE evt = os.hEvent;
|
||||
memset(&os, 0, sizeof(os));
|
||||
// os.hEvent = evt;
|
||||
|
||||
/*
|
||||
if (ResetEvent(os.hEvent) == 0) {
|
||||
diskann::cerr << "ResetEvent failed" << std::endl;
|
||||
exit(-3);
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
// batch start/end
|
||||
_u64 batch_start = MAX_IO_DEPTH * i;
|
||||
_u64 batch_size =
|
||||
std::min((_u64)(n_reqs - batch_start), (_u64) MAX_IO_DEPTH);
|
||||
|
||||
// fill OVERLAPPED and issue them
|
||||
for (_u64 j = 0; j < batch_size; j++) {
|
||||
AlignedRead& req = read_reqs[batch_start + j];
|
||||
OVERLAPPED& os = ctx.reqs[j];
|
||||
|
||||
_u64 offset = req.offset;
|
||||
_u64 nbytes = req.len;
|
||||
char* read_buf = (char*) req.buf;
|
||||
assert(IS_ALIGNED(read_buf, SECTOR_LEN));
|
||||
assert(IS_ALIGNED(offset, SECTOR_LEN));
|
||||
assert(IS_ALIGNED(nbytes, SECTOR_LEN));
|
||||
|
||||
// fill in OVERLAPPED struct
|
||||
os.Offset = offset & 0xffffffff;
|
||||
os.OffsetHigh = (offset >> 32);
|
||||
|
||||
BOOL ret = ReadFile(ctx.fhandle, read_buf, nbytes, NULL, &os);
|
||||
if (ret == FALSE) {
|
||||
auto error = GetLastError();
|
||||
if (error != ERROR_IO_PENDING) {
|
||||
diskann::cerr << "Error queuing IO -- " << error << "\n";
|
||||
}
|
||||
} else {
|
||||
diskann::cerr << "Error queueing IO -- ReadFile returned TRUE"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
DWORD n_read = 0;
|
||||
_u64 n_complete = 0;
|
||||
ULONG_PTR completion_key = 0;
|
||||
OVERLAPPED* lp_os;
|
||||
while (n_complete < batch_size) {
|
||||
if (GetQueuedCompletionStatus(ctx.iocp, &n_read, &completion_key, &lp_os,
|
||||
INFINITE) != 0) {
|
||||
// successfully dequeued a completed I/O
|
||||
n_complete++;
|
||||
} else {
|
||||
// failed to dequeue OR dequeued failed I/O
|
||||
if (lp_os == NULL) {
|
||||
DWORD error = GetLastError();
|
||||
if (error != WAIT_TIMEOUT) {
|
||||
diskann::cerr << "GetQueuedCompletionStatus() failed with error = "
|
||||
<< error << std::endl;
|
||||
throw diskann::ANNException(
|
||||
"GetQueuedCompletionStatus failed with error: ", error,
|
||||
__FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
// no completion packet dequeued ==> sleep for 5us and try again
|
||||
std::this_thread::sleep_for(5us);
|
||||
} else {
|
||||
// completion packet for failed IO dequeued
|
||||
auto op_idx = lp_os - ctx.reqs.data();
|
||||
std::stringstream stream;
|
||||
stream << "I/O failed , offset: " << read_reqs[op_idx].offset
|
||||
<< "with error code: " << GetLastError() << std::endl;
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__,
|
||||
__LINE__);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
|
@ -1,11 +1,24 @@
|
|||
set(CMAKE_CXX_STANDARD 11)
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
add_executable(test_nsg_index test_nsg_index.cpp)
|
||||
target_link_libraries(test_nsg_index ${PROJECT_NAME})
|
||||
set(CMAKE_CXX_STANDARD 14)
|
||||
|
||||
add_executable(test_nsg_search test_nsg_search.cpp)
|
||||
target_link_libraries(test_nsg_search ${PROJECT_NAME})
|
||||
add_executable(build_memory_index build_memory_index.cpp )
|
||||
if(MSVC)
|
||||
target_link_options(build_memory_index PRIVATE /MACHINE:x64 /DEBUG:FULL "/INCLUDE:_tcmalloc")
|
||||
target_link_libraries(build_memory_index debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/${PROJECT_NAME}.lib ${PROJECT_SOURCE_DIR}/dependencies/windows/tcmalloc/libtcmalloc_minimal.lib)
|
||||
target_link_libraries(build_memory_index optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/${PROJECT_NAME}.lib ${PROJECT_SOURCE_DIR}/dependencies/windows/tcmalloc/libtcmalloc_minimal.lib)
|
||||
else()
|
||||
target_link_libraries(build_memory_index ${PROJECT_NAME} -ltcmalloc)
|
||||
endif()
|
||||
|
||||
add_executable(test_nsg_optimized_search test_nsg_optimized_search.cpp)
|
||||
target_link_libraries(test_nsg_optimized_search ${PROJECT_NAME})
|
||||
add_executable(search_memory_index search_memory_index.cpp )
|
||||
if(MSVC)
|
||||
target_link_options(search_memory_index PRIVATE /MACHINE:x64 /DEBUG:FULL)
|
||||
target_link_libraries(search_memory_index debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/${PROJECT_NAME}.lib)
|
||||
target_link_libraries(search_memory_index optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/${PROJECT_NAME}.lib)
|
||||
else()
|
||||
target_link_libraries(search_memory_index ${PROJECT_NAME} aio -ltcmalloc)
|
||||
endif()
|
||||
|
||||
# formatter
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
|
||||
//
|
||||
// Created by 付聪 on 2017/6/21.
|
||||
//
|
||||
|
||||
#include <index.h>
|
||||
#include <omp.h>
|
||||
#include <string.h>
|
||||
#include "utils.h"
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/mman.h>
|
||||
#include <unistd.h>
|
||||
#else
|
||||
#include <Windows.h>
|
||||
#endif
|
||||
|
||||
#include "memory_mapper.h"
|
||||
|
||||
template<typename T>
|
||||
int build_in_memory_index(const std::string& data_path, const unsigned R,
|
||||
const unsigned L, const float alpha,
|
||||
const std::string& save_path,
|
||||
const unsigned num_threads) {
|
||||
diskann::Parameters paras;
|
||||
paras.Set<unsigned>("R", R);
|
||||
paras.Set<unsigned>("L", L);
|
||||
paras.Set<unsigned>(
|
||||
"C", 750); // maximum candidate set size during pruning procedure
|
||||
paras.Set<float>("alpha", alpha);
|
||||
paras.Set<bool>("saturate_graph", 0);
|
||||
paras.Set<unsigned>("num_threads", num_threads);
|
||||
|
||||
diskann::Index<T> index(diskann::L2, data_path.c_str());
|
||||
auto s = std::chrono::high_resolution_clock::now();
|
||||
index.build(paras);
|
||||
std::chrono::duration<double> diff =
|
||||
std::chrono::high_resolution_clock::now() - s;
|
||||
|
||||
std::cout << "Indexing time: " << diff.count() << "\n";
|
||||
index.save(save_path.c_str());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc != 8) {
|
||||
std::cout << "Usage: " << argv[0]
|
||||
<< " [data_type<int8/uint8/float>] [data_file.bin] "
|
||||
"[output_index_file] "
|
||||
<< "[R] [L] [alpha]"
|
||||
<< " [num_threads_to_use]. See README for more information on "
|
||||
"parameters."
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
const std::string data_path(argv[2]);
|
||||
const std::string save_path(argv[3]);
|
||||
const unsigned R = (unsigned) atoi(argv[4]);
|
||||
const unsigned L = (unsigned) atoi(argv[5]);
|
||||
const float alpha = (float) atof(argv[6]);
|
||||
const unsigned num_threads = (unsigned) atoi(argv[7]);
|
||||
|
||||
if (std::string(argv[1]) == std::string("int8"))
|
||||
build_in_memory_index<int8_t>(data_path, R, L, alpha, save_path,
|
||||
num_threads);
|
||||
else if (std::string(argv[1]) == std::string("uint8"))
|
||||
build_in_memory_index<uint8_t>(data_path, R, L, alpha, save_path,
|
||||
num_threads);
|
||||
else if (std::string(argv[1]) == std::string("float"))
|
||||
build_in_memory_index<float>(data_path, R, L, alpha, save_path,
|
||||
num_threads);
|
||||
else
|
||||
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
|
||||
}
|
|
@ -0,0 +1,156 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <cstring>
|
||||
#include <iomanip>
|
||||
#include <omp.h>
|
||||
#include <set>
|
||||
#include <string.h>
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <time.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#include "aux_utils.h"
|
||||
#include "index.h"
|
||||
#include "memory_mapper.h"
|
||||
#include "utils.h"
|
||||
|
||||
template<typename T>
|
||||
int search_memory_index(int argc, char** argv) {
|
||||
T* query = nullptr;
|
||||
unsigned* gt_ids = nullptr;
|
||||
float* gt_dists = nullptr;
|
||||
size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim;
|
||||
std::vector<_u64> Lvec;
|
||||
|
||||
std::string data_file(argv[2]);
|
||||
std::string memory_index_file(argv[3]);
|
||||
std::string query_bin(argv[4]);
|
||||
std::string truthset_bin(argv[5]);
|
||||
_u64 recall_at = std::atoi(argv[6]);
|
||||
std::string result_output_prefix(argv[7]);
|
||||
|
||||
bool calc_recall_flag = false;
|
||||
|
||||
for (int ctr = 8; ctr < argc; ctr++) {
|
||||
_u64 curL = std::atoi(argv[ctr]);
|
||||
if (curL >= recall_at)
|
||||
Lvec.push_back(curL);
|
||||
}
|
||||
|
||||
if (Lvec.size() == 0) {
|
||||
std::cout << "No valid Lsearch found. Lsearch must be at least recall_at."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
diskann::load_aligned_bin<T>(query_bin, query, query_num, query_dim,
|
||||
query_aligned_dim);
|
||||
|
||||
if (file_exists(truthset_bin)) {
|
||||
diskann::load_truthset(truthset_bin, gt_ids, gt_dists, gt_num, gt_dim);
|
||||
if (gt_num != query_num) {
|
||||
std::cout << "Error. Mismatch in number of queries and ground truth data"
|
||||
<< std::endl;
|
||||
}
|
||||
calc_recall_flag = true;
|
||||
}
|
||||
|
||||
std::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
|
||||
std::cout.precision(2);
|
||||
|
||||
diskann::Index<T> index(diskann::L2, data_file.c_str());
|
||||
index.load(memory_index_file.c_str()); // to load NSG
|
||||
std::cout << "Index loaded" << std::endl;
|
||||
|
||||
diskann::Parameters paras;
|
||||
std::string recall_string = "Recall@" + std::to_string(recall_at);
|
||||
std::cout << std::setw(4) << "Ls" << std::setw(12) << "QPS " << std::setw(18)
|
||||
<< "Mean Latency (ms)" << std::setw(15) << "99.9 Latency"
|
||||
<< std::setw(12) << recall_string << std::endl;
|
||||
std::cout << "==============================================================="
|
||||
"==============="
|
||||
<< std::endl;
|
||||
|
||||
std::vector<std::vector<uint32_t>> query_result_ids(Lvec.size());
|
||||
std::vector<std::vector<float>> query_result_dists(Lvec.size());
|
||||
|
||||
std::vector<double> latency_stats(query_num, 0);
|
||||
|
||||
for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) {
|
||||
_u64 L = Lvec[test_id];
|
||||
query_result_ids[test_id].resize(recall_at * query_num);
|
||||
|
||||
auto s = std::chrono::high_resolution_clock::now();
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (int64_t i = 0; i < (int64_t) query_num; i++) {
|
||||
auto qs = std::chrono::high_resolution_clock::now();
|
||||
index.search(query + i * query_aligned_dim, recall_at, L,
|
||||
query_result_ids[test_id].data() + i * recall_at);
|
||||
auto qe = std::chrono::high_resolution_clock::now();
|
||||
std::chrono::duration<double> diff = qe - qs;
|
||||
latency_stats[i] = diff.count() * 1000;
|
||||
}
|
||||
auto e = std::chrono::high_resolution_clock::now();
|
||||
std::chrono::duration<double> diff = e - s;
|
||||
|
||||
float qps = (query_num / diff.count());
|
||||
|
||||
float recall = 0;
|
||||
if (calc_recall_flag)
|
||||
recall = diskann::calculate_recall(query_num, gt_ids, gt_dists, gt_dim,
|
||||
query_result_ids[test_id].data(),
|
||||
recall_at, recall_at);
|
||||
|
||||
std::sort(latency_stats.begin(), latency_stats.end());
|
||||
double mean_latency = 0;
|
||||
for (uint64_t q = 0; q < query_num; q++) {
|
||||
mean_latency += latency_stats[q];
|
||||
}
|
||||
mean_latency /= query_num;
|
||||
|
||||
std::cout << std::setw(4) << L << std::setw(12) << qps << std::setw(18)
|
||||
<< (float) mean_latency << std::setw(15)
|
||||
<< (float) latency_stats[(_u64)(0.999 * query_num)]
|
||||
<< std::setw(12) << recall << std::endl;
|
||||
}
|
||||
|
||||
std::cout << "Done searching. Now saving results " << std::endl;
|
||||
_u64 test_id = 0;
|
||||
for (auto L : Lvec) {
|
||||
std::string cur_result_path =
|
||||
result_output_prefix + "_" + std::to_string(L) + "_idx_uint32.bin";
|
||||
diskann::save_bin<_u32>(cur_result_path, query_result_ids[test_id].data(),
|
||||
query_num, recall_at);
|
||||
test_id++;
|
||||
}
|
||||
|
||||
diskann::aligned_free(query);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc < 9) {
|
||||
std::cout
|
||||
<< "Usage: " << argv[0]
|
||||
<< " [index_type<float/int8/uint8>] [data_file.bin] "
|
||||
"[memory_index_path] "
|
||||
"[query_file.bin] [truthset.bin (use \"null\" for none)] "
|
||||
" [K] [result_output_prefix] "
|
||||
" [L1] [L2] etc. See README for more information on parameters. "
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
if (std::string(argv[1]) == std::string("int8"))
|
||||
search_memory_index<int8_t>(argc, argv);
|
||||
else if (std::string(argv[1]) == std::string("uint8"))
|
||||
search_memory_index<uint8_t>(argc, argv);
|
||||
else if (std::string(argv[1]) == std::string("float"))
|
||||
search_memory_index<float>(argc, argv);
|
||||
else
|
||||
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
|
||||
}
|
|
@ -0,0 +1,127 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <index.h>
|
||||
#include <numeric>
|
||||
#include <omp.h>
|
||||
#include <string.h>
|
||||
#include <time.h>
|
||||
#include <timer.h>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#include "memory_mapper.h"
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc != 10) {
|
||||
std::cout << "Correct usage: " << argv[0]
|
||||
<< " data_file L R C alpha num_rounds "
|
||||
<< "save_graph_file #incr_points #frozen_points" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
float* data_load = NULL;
|
||||
size_t num_points, dim, aligned_dim;
|
||||
|
||||
diskann::load_aligned_bin<float>(argv[1], data_load, num_points, dim,
|
||||
aligned_dim);
|
||||
|
||||
unsigned L = (unsigned) atoi(argv[2]);
|
||||
unsigned R = (unsigned) atoi(argv[3]);
|
||||
unsigned C = (unsigned) atoi(argv[4]);
|
||||
float alpha = (float) std::atof(argv[5]);
|
||||
unsigned num_rnds = (unsigned) std::atoi(argv[6]);
|
||||
std::string save_path(argv[7]);
|
||||
unsigned num_incr = (unsigned) atoi(argv[8]);
|
||||
unsigned num_frozen = (unsigned) atoi(argv[9]);
|
||||
|
||||
diskann::Parameters paras;
|
||||
paras.Set<unsigned>("L", L);
|
||||
paras.Set<unsigned>("R", R);
|
||||
paras.Set<unsigned>("C", C);
|
||||
paras.Set<float>("alpha", alpha);
|
||||
paras.Set<bool>("saturate_graph", false);
|
||||
paras.Set<unsigned>("num_rnds", num_rnds);
|
||||
|
||||
typedef int TagT;
|
||||
diskann::Index<float, TagT> index(diskann::L2, argv[1], num_points,
|
||||
num_points - num_incr, num_frozen, true,
|
||||
true, true);
|
||||
{
|
||||
std::vector<TagT> tags(num_points - num_incr);
|
||||
std::iota(tags.begin(), tags.end(), 0);
|
||||
|
||||
if (argc > 10) {
|
||||
std::string frozen_points_file(argv[10]);
|
||||
index.generate_random_frozen_points(frozen_points_file.c_str());
|
||||
} else
|
||||
index.generate_random_frozen_points();
|
||||
|
||||
diskann::Timer timer;
|
||||
index.build(paras, tags);
|
||||
std::cout << "Index build time: " << timer.elapsed() / 1000 << "ms\n";
|
||||
}
|
||||
|
||||
std::vector<diskann::Neighbor> pool, tmp;
|
||||
tsl::robin_set<unsigned> visited;
|
||||
std::vector<diskann::SimpleNeighbor> cut_graph;
|
||||
index.readjust_data(num_frozen);
|
||||
|
||||
{
|
||||
diskann::Timer timer;
|
||||
for (size_t i = num_points - num_incr; i < num_points; ++i) {
|
||||
index.insert_point(data_load + i * aligned_dim, paras, pool, tmp, visited,
|
||||
cut_graph, i);
|
||||
}
|
||||
std::cout << "Incremental time: " << timer.elapsed() / 1000 << "ms\n";
|
||||
auto save_path_inc = save_path + ".inc";
|
||||
index.save(save_path_inc.c_str());
|
||||
}
|
||||
|
||||
tsl::robin_set<unsigned> delete_list;
|
||||
while (delete_list.size() < num_incr)
|
||||
delete_list.insert(rand() % num_points);
|
||||
std::cout << "Deleting " << delete_list.size() << " elements" << std::endl;
|
||||
|
||||
{
|
||||
diskann::Timer timer;
|
||||
index.enable_delete();
|
||||
for (auto p : delete_list)
|
||||
|
||||
if (index.eager_delete(p, paras) != 0)
|
||||
// if (index.delete_point(p) != 0)
|
||||
std::cerr << "Delete tag " << p << " not found" << std::endl;
|
||||
|
||||
if (index.disable_delete(paras, true) != 0) {
|
||||
std::cerr << "Disable delete failed" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
std::cout << "Delete time: " << timer.elapsed() / 1000 << "ms\n";
|
||||
}
|
||||
|
||||
auto save_path_del = save_path + ".del";
|
||||
index.save(save_path_del.c_str());
|
||||
|
||||
index.readjust_data(num_frozen);
|
||||
{
|
||||
diskann::Timer timer;
|
||||
for (auto p : delete_list) {
|
||||
index.insert_point(data_load + (size_t) p * (size_t) aligned_dim, paras,
|
||||
pool, tmp, visited, cut_graph, p);
|
||||
}
|
||||
std::cout << "Re-incremental time: " << timer.elapsed() / 1000 << "ms\n";
|
||||
}
|
||||
|
||||
auto save_path_reinc = save_path + ".reinc";
|
||||
index.save(save_path_reinc.c_str());
|
||||
|
||||
delete[] data_load;
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -1,53 +0,0 @@
|
|||
//
|
||||
// Created by 付聪 on 2017/6/21.
|
||||
//
|
||||
|
||||
#include <efanna2e/index_nsg.h>
|
||||
#include <efanna2e/util.h>
|
||||
|
||||
|
||||
void load_data(char* filename, float*& data, unsigned& num,unsigned& dim){// load data with sift10K pattern
|
||||
std::ifstream in(filename, std::ios::binary);
|
||||
if(!in.is_open()){std::cout<<"open file error"<<std::endl;exit(-1);}
|
||||
in.read((char*)&dim,4);
|
||||
std::cout<<"data dimension: "<<dim<<std::endl;
|
||||
in.seekg(0,std::ios::end);
|
||||
std::ios::pos_type ss = in.tellg();
|
||||
size_t fsize = (size_t)ss;
|
||||
num = (unsigned)(fsize / (dim+1) / 4);
|
||||
data = new float[num * dim * sizeof(float)];
|
||||
|
||||
in.seekg(0,std::ios::beg);
|
||||
for(size_t i = 0; i < num; i++){
|
||||
in.seekg(4,std::ios::cur);
|
||||
in.read((char*)(data+i*dim),dim*4);
|
||||
}
|
||||
in.close();
|
||||
}
|
||||
int main(int argc, char** argv){
|
||||
if(argc!=6){std::cout<< argv[0] <<" data_file nn_graph_path L R save_graph_file"<<std::endl; exit(-1);}
|
||||
float* data_load = NULL;
|
||||
unsigned points_num, dim;
|
||||
load_data(argv[1], data_load, points_num, dim);
|
||||
|
||||
std::string nn_graph_path(argv[2]);
|
||||
unsigned L = (unsigned)atoi(argv[3]);
|
||||
unsigned R = (unsigned)atoi(argv[4]);
|
||||
|
||||
data_load = efanna2e::data_align(data_load, points_num, dim);//one must align the data before build
|
||||
efanna2e::IndexNSG index(dim, points_num, efanna2e::L2, nullptr);
|
||||
|
||||
auto s = std::chrono::high_resolution_clock::now();
|
||||
efanna2e::Parameters paras;
|
||||
paras.Set<unsigned>("L", L);
|
||||
paras.Set<unsigned>("R", R);
|
||||
paras.Set<std::string>("nn_graph_path", nn_graph_path);
|
||||
|
||||
index.Build(points_num, data_load, paras);
|
||||
auto e = std::chrono::high_resolution_clock::now();
|
||||
std::chrono::duration<double> diff = e-s;
|
||||
std::cout << "indexing time: " << diff.count() << "\n";
|
||||
index.Save(argv[5]);
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -1,79 +0,0 @@
|
|||
//
|
||||
// Created by 付聪 on 2017/6/21.
|
||||
//
|
||||
|
||||
#include <efanna2e/index_nsg.h>
|
||||
#include <efanna2e/util.h>
|
||||
#include <chrono>
|
||||
|
||||
|
||||
void load_data(char* filename, float*& data, unsigned& num,unsigned& dim){// load data with sift10K pattern
|
||||
std::ifstream in(filename, std::ios::binary);
|
||||
if(!in.is_open()){std::cout<<"open file error"<<std::endl;exit(-1);}
|
||||
in.read((char*)&dim,4);
|
||||
std::cout<<"data dimension: "<<dim<<std::endl;
|
||||
in.seekg(0,std::ios::end);
|
||||
std::ios::pos_type ss = in.tellg();
|
||||
size_t fsize = (size_t)ss;
|
||||
num = (unsigned)(fsize / (dim+1) / 4);
|
||||
data = new float[num * dim * sizeof(float)];
|
||||
|
||||
in.seekg(0,std::ios::beg);
|
||||
for(size_t i = 0; i < num; i++){
|
||||
in.seekg(4,std::ios::cur);
|
||||
in.read((char*)(data+i*dim),dim*4);
|
||||
}
|
||||
in.close();
|
||||
}
|
||||
|
||||
void save_result(char* filename, std::vector<std::vector<unsigned> > &results){
|
||||
std::ofstream out(filename, std::ios::binary | std::ios::out);
|
||||
|
||||
for (unsigned i = 0; i < results.size(); i++) {
|
||||
unsigned GK = (unsigned) results[i].size();
|
||||
out.write((char *) &GK, sizeof(unsigned));
|
||||
out.write((char *) results[i].data(), GK * sizeof(unsigned));
|
||||
}
|
||||
out.close();
|
||||
}
|
||||
int main(int argc, char** argv){
|
||||
if(argc!=7){std::cout<< argv[0] <<" data_file query_file nsg_path search_L search_K result_path"<<std::endl; exit(-1);}
|
||||
float* data_load = NULL;
|
||||
unsigned points_num, dim;
|
||||
load_data(argv[1], data_load, points_num, dim);
|
||||
float* query_load = NULL;
|
||||
unsigned query_num, query_dim;
|
||||
load_data(argv[2], query_load, query_num, query_dim);
|
||||
assert(dim == query_dim);
|
||||
|
||||
unsigned L = (unsigned)atoi(argv[4]);
|
||||
unsigned K = (unsigned)atoi(argv[5]);
|
||||
|
||||
if(L < K){std::cout<< "search_L cannot be smaller than search_K!"<<std::endl; exit(-1);}
|
||||
|
||||
data_load = efanna2e::data_align(data_load, points_num, dim);//one must align the data before build
|
||||
query_load = efanna2e::data_align(query_load, query_num, query_dim);
|
||||
efanna2e::IndexNSG index(dim, points_num, efanna2e::FAST_L2, nullptr);
|
||||
index.Load(argv[3]);
|
||||
index.OptimizeGraph(data_load);
|
||||
|
||||
efanna2e::Parameters paras;
|
||||
paras.Set<unsigned>("L_search", L);
|
||||
paras.Set<unsigned>("P_search", L);
|
||||
|
||||
std::vector<std::vector<unsigned> > res(query_num);
|
||||
for(unsigned i=0; i<query_num; i++)res[i].resize(K);
|
||||
|
||||
auto s = std::chrono::high_resolution_clock::now();
|
||||
for(unsigned i=0; i<query_num; i++){
|
||||
index.SearchWithOptGraph(query_load + i * dim, K, paras, res[i].data());
|
||||
}
|
||||
auto e = std::chrono::high_resolution_clock::now();
|
||||
std::chrono::duration<double> diff = e-s;
|
||||
std::cout << "search time: " << diff.count() << "\n";
|
||||
|
||||
|
||||
save_result(argv[6], res);
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -1,78 +0,0 @@
|
|||
//
|
||||
// Created by 付聪 on 2017/6/21.
|
||||
//
|
||||
|
||||
#include <efanna2e/index_nsg.h>
|
||||
#include <efanna2e/util.h>
|
||||
|
||||
|
||||
void load_data(char* filename, float*& data, unsigned& num,unsigned& dim){// load data with sift10K pattern
|
||||
std::ifstream in(filename, std::ios::binary);
|
||||
if(!in.is_open()){std::cout<<"open file error"<<std::endl;exit(-1);}
|
||||
in.read((char*)&dim,4);
|
||||
std::cout<<"data dimension: "<<dim<<std::endl;
|
||||
in.seekg(0,std::ios::end);
|
||||
std::ios::pos_type ss = in.tellg();
|
||||
size_t fsize = (size_t)ss;
|
||||
num = (unsigned)(fsize / (dim+1) / 4);
|
||||
data = new float[num * dim * sizeof(float)];
|
||||
|
||||
in.seekg(0,std::ios::beg);
|
||||
for(size_t i = 0; i < num; i++){
|
||||
in.seekg(4,std::ios::cur);
|
||||
in.read((char*)(data+i*dim),dim*4);
|
||||
}
|
||||
in.close();
|
||||
}
|
||||
|
||||
void save_result(char* filename, std::vector<std::vector<unsigned> > &results){
|
||||
std::ofstream out(filename, std::ios::binary | std::ios::out);
|
||||
|
||||
for (unsigned i = 0; i < results.size(); i++) {
|
||||
unsigned GK = (unsigned) results[i].size();
|
||||
out.write((char *) &GK, sizeof(unsigned));
|
||||
out.write((char *) results[i].data(), GK * sizeof(unsigned));
|
||||
}
|
||||
out.close();
|
||||
}
|
||||
int main(int argc, char** argv){
|
||||
if(argc!=7){std::cout<< argv[0] <<" data_file query_file nsg_path search_L search_K result_path"<<std::endl; exit(-1);}
|
||||
float* data_load = NULL;
|
||||
unsigned points_num, dim;
|
||||
load_data(argv[1], data_load, points_num, dim);
|
||||
float* query_load = NULL;
|
||||
unsigned query_num, query_dim;
|
||||
load_data(argv[2], query_load, query_num, query_dim);
|
||||
assert(dim == query_dim);
|
||||
|
||||
unsigned L = (unsigned)atoi(argv[4]);
|
||||
unsigned K = (unsigned)atoi(argv[5]);
|
||||
|
||||
if(L < K){std::cout<< "search_L cannot be smaller than search_K!"<<std::endl; exit(-1);}
|
||||
|
||||
|
||||
//data_load = efanna2e::data_align(data_load, points_num, dim);//one must align the data before build
|
||||
//query_load = efanna2e::data_align(query_load, query_num, query_dim);
|
||||
efanna2e::IndexNSG index(dim, points_num, efanna2e::L2, nullptr);
|
||||
index.Load(argv[3]);
|
||||
|
||||
efanna2e::Parameters paras;
|
||||
paras.Set<unsigned>("L_search", L);
|
||||
paras.Set<unsigned>("P_search", L);
|
||||
|
||||
auto s = std::chrono::high_resolution_clock::now();
|
||||
std::vector<std::vector<unsigned> > res;
|
||||
for(unsigned i=0; i<query_num; i++){
|
||||
std::vector<unsigned> tmp(K);
|
||||
index.Search(query_load + i * dim, data_load, K, paras, tmp.data());
|
||||
res.push_back(tmp);
|
||||
}
|
||||
auto e = std::chrono::high_resolution_clock::now();
|
||||
std::chrono::duration<double> diff = e-s;
|
||||
std::cout << "search time: " << diff.count() << "\n";
|
||||
|
||||
|
||||
save_result(argv[6], res);
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
set(CMAKE_CXX_STANDARD 14)
|
||||
|
||||
add_executable(fvecs_to_bin fvecs_to_bin.cpp)
|
||||
if(MSVC)
|
||||
target_link_options(fvecs_to_bin PRIVATE /MACHINE:x64)
|
||||
target_link_libraries(fvecs_to_bin debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/${PROJECT_NAME}.lib)
|
||||
target_link_libraries(fvecs_to_bin optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/${PROJECT_NAME}.lib)
|
||||
endif()
|
||||
|
||||
add_executable(ivecs_to_bin ivecs_to_bin.cpp)
|
||||
if(MSVC)
|
||||
target_link_options(ivecs_to_bin PRIVATE /MACHINE:x64)
|
||||
target_link_libraries(ivecs_to_bin debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/${PROJECT_NAME}.lib)
|
||||
target_link_libraries(ivecs_to_bin optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/${PROJECT_NAME}.lib)
|
||||
endif()
|
||||
|
||||
add_executable(int8_to_float int8_to_float.cpp)
|
||||
if(MSVC)
|
||||
target_link_options(int8_to_float PRIVATE /MACHINE:x64)
|
||||
target_link_libraries(int8_to_float debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/${PROJECT_NAME}.lib)
|
||||
target_link_libraries(int8_to_float optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/${PROJECT_NAME}.lib)
|
||||
else()
|
||||
target_link_libraries(int8_to_float ${PROJECT_NAME})
|
||||
endif()
|
||||
|
||||
add_executable(uint32_to_uint8 uint32_to_uint8.cpp)
|
||||
if(MSVC)
|
||||
target_link_options(uint32_to_uint8 PRIVATE /MACHINE:x64)
|
||||
target_link_libraries(uint32_to_uint8 debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/${PROJECT_NAME}.lib)
|
||||
target_link_libraries(uint32_to_uint8 optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/${PROJECT_NAME}.lib)
|
||||
else()
|
||||
target_link_libraries(uint32_to_uint8 ${PROJECT_NAME})
|
||||
endif()
|
||||
|
||||
add_executable(gen_random_slice gen_random_slice.cpp)
|
||||
if(MSVC)
|
||||
target_link_options(gen_random_slice PRIVATE /MACHINE:x64)
|
||||
target_link_libraries(gen_random_slice debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/${PROJECT_NAME}.lib)
|
||||
target_link_libraries(gen_random_slice optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/${PROJECT_NAME}.lib)
|
||||
else()
|
||||
target_link_libraries(gen_random_slice ${PROJECT_NAME} -ltcmalloc)
|
||||
endif()
|
||||
|
||||
add_executable(calculate_recall calculate_recall.cpp)
|
||||
if(MSVC)
|
||||
target_link_options(calculate_recall PRIVATE /MACHINE:x64)
|
||||
target_link_libraries(calculate_recall debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/${PROJECT_NAME}.lib)
|
||||
target_link_libraries(calculate_recall optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/${PROJECT_NAME}.lib)
|
||||
else()
|
||||
target_link_libraries(calculate_recall ${PROJECT_NAME} aio -ltcmalloc)
|
||||
endif()
|
||||
|
||||
|
||||
add_executable(compute_groundtruth compute_groundtruth.cpp)
|
||||
if(MSVC)
|
||||
target_link_options(compute_groundtruth PRIVATE /MACHINE:x64)
|
||||
target_link_libraries(compute_groundtruth debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/${PROJECT_NAME}.lib)
|
||||
target_link_libraries(compute_groundtruth optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/${PROJECT_NAME}.lib)
|
||||
else()
|
||||
target_link_libraries(compute_groundtruth ${PROJECT_NAME} aio)
|
||||
endif()
|
||||
|
||||
|
||||
add_executable(generate_pq generate_pq.cpp)
|
||||
if(MSVC)
|
||||
target_link_options(generate_pq PRIVATE /MACHINE:x64 /DEBUG:FULL)
|
||||
target_link_libraries(generate_pq debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/${PROJECT_NAME}.lib)
|
||||
target_link_libraries(generate_pq optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/${PROJECT_NAME}.lib)
|
||||
else()
|
||||
target_link_libraries(generate_pq ${PROJECT_NAME} -ltcmalloc)
|
||||
endif()
|
||||
|
||||
|
||||
# formatter
|
||||
if (NOT MSVC)
|
||||
# add_custom_command(TARGET gen_random_slice PRE_BUILD COMMAND clang-format-4.0 -i ../../../include/*.h ../../../include/dll/*.h ../../../src/*.cpp ../../../tests/*.cpp ../../../src/dll/*.cpp ../../../tests/utils/*.cpp)
|
||||
endif()
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "util.h"
|
||||
|
||||
void block_convert(std::ifstream& writr, std::ofstream& readr, float* read_buf,
|
||||
float* write_buf, _u64 npts, _u64 ndims) {
|
||||
writr.write((char*) read_buf,
|
||||
npts * (ndims * sizeof(float) + sizeof(unsigned)));
|
||||
#pragma omp parallel for
|
||||
for (_u64 i = 0; i < npts; i++) {
|
||||
memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1,
|
||||
ndims * sizeof(float));
|
||||
}
|
||||
readr.read((char*) write_buf, npts * ndims * sizeof(float));
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc != 3) {
|
||||
std::cout << argv[0] << " input_bin output_fvecs" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
std::ifstream readr(argv[1], std::ios::binary);
|
||||
int npts_s32;
|
||||
int ndims_s32;
|
||||
readr.read((char*) &npts_s32, sizeof(_s32));
|
||||
readr.read((char*) &ndims_s32, sizeof(_s32));
|
||||
size_t npts = npts_s32;
|
||||
size_t ndims = ndims_s32;
|
||||
_u32 ndims_u32 = (_u32) ndims_s32;
|
||||
// _u64 fsize = writr.tellg();
|
||||
readr.seekg(0, std::ios::beg);
|
||||
|
||||
unsigned ndims_u32;
|
||||
writr.write((char*) &ndims_u32, sizeof(unsigned));
|
||||
writr.seekg(0, std::ios::beg);
|
||||
_u64 ndims = (_u64) ndims_u32;
|
||||
_u64 npts = fsize / ((ndims + 1) * sizeof(float));
|
||||
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims
|
||||
<< std::endl;
|
||||
|
||||
_u64 blk_size = 131072;
|
||||
_u64 nblks = ROUND_UP(npts, blk_size) / blk_size;
|
||||
std::cout << "# blks: " << nblks << std::endl;
|
||||
|
||||
std::ofstream writr(argv[2], std::ios::binary);
|
||||
float* read_buf = new float[npts * (ndims + 1)];
|
||||
float* write_buf = new float[npts * ndims];
|
||||
for (_u64 i = 0; i < nblks; i++) {
|
||||
_u64 cblk_size = std::min(npts - i * blk_size, blk_size);
|
||||
block_convert(writr, readr, read_buf, write_buf, cblk_size, ndims);
|
||||
std::cout << "Block #" << i << " written" << std::endl;
|
||||
}
|
||||
|
||||
delete[] read_buf;
|
||||
delete[] write_buf;
|
||||
|
||||
writr.close();
|
||||
readr.close();
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "aux_utils.h"
|
||||
#include "utils.h"
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc != 4) {
|
||||
std::cout << argv[0] << " <ground_truth_bin> <our_results_bin> <r> "
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
unsigned* gold_std = NULL;
|
||||
float* gs_dist = nullptr;
|
||||
unsigned* our_results = NULL;
|
||||
float* or_dist = nullptr;
|
||||
size_t points_num, points_num_gs, points_num_or;
|
||||
size_t dim_gs;
|
||||
size_t dim_or;
|
||||
diskann::load_truthset(argv[1], gold_std, gs_dist, points_num_gs, dim_gs);
|
||||
diskann::load_truthset(argv[2], our_results, or_dist, points_num_or, dim_or);
|
||||
|
||||
if (points_num_gs != points_num_or) {
|
||||
std::cout
|
||||
<< "Error. Number of queries mismatch in ground truth and our results"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
points_num = points_num_gs;
|
||||
|
||||
uint32_t recall_at = std::atoi(argv[3]);
|
||||
|
||||
if ((dim_or < recall_at) || (recall_at > dim_gs)) {
|
||||
std::cout << "ground truth has size " << dim_gs << "; our set has "
|
||||
<< dim_or << " points. Asking for recall " << recall_at
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
std::cout << "Calculating recall@" << recall_at << std::endl;
|
||||
float recall_val = diskann::calculate_recall(
|
||||
points_num, gold_std, gs_dist, dim_gs, our_results, dim_or, recall_at);
|
||||
|
||||
// double avg_recall = (recall*1.0)/(points_num*1.0);
|
||||
std::cout << "Avg. recall@" << recall_at << " is " << recall_val << "\n";
|
||||
}
|
|
@ -0,0 +1,352 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <cassert>
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <random>
|
||||
#include <limits>
|
||||
#include <cstring>
|
||||
#include <queue>
|
||||
|
||||
#ifdef _WINDOWS
|
||||
#include <malloc.h>
|
||||
#else
|
||||
#include <stdlib.h>
|
||||
#endif
|
||||
|
||||
#include "mkl.h"
|
||||
#include "omp.h"
|
||||
#include "utils.h"
|
||||
|
||||
// WORKS FOR UPTO 2 BILLION POINTS (as we use INT INSTEAD OF UNSIGNED)
|
||||
|
||||
#define PARTSIZE 10000000
|
||||
#define ALIGNMENT 512
|
||||
|
||||
void command_line_help() {
|
||||
std::cerr
|
||||
<< "<exact-kann> <int8/uint8/float> <base bin file> <query bin "
|
||||
"file> <K: # nearest neighbors to compute> <output-truthset-file>"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
T div_round_up(const T numerator, const T denominator) {
|
||||
return (numerator % denominator == 0) ? (numerator / denominator)
|
||||
: 1 + (numerator / denominator);
|
||||
}
|
||||
|
||||
using pairIF = std::pair<int, float>;
|
||||
struct cmpmaxstruct {
|
||||
bool operator()(const pairIF &l, const pairIF &r) {
|
||||
return l.second < r.second;
|
||||
};
|
||||
};
|
||||
|
||||
using maxPQIFCS =
|
||||
std::priority_queue<pairIF, std::vector<pairIF>, cmpmaxstruct>;
|
||||
|
||||
template<class T>
|
||||
T *aligned_malloc(const size_t n, const size_t alignment) {
|
||||
#ifdef _WINDOWS
|
||||
return (T *) _aligned_malloc(sizeof(T) * n, alignment);
|
||||
#else
|
||||
return static_cast<T *>(aligned_alloc(alignment, sizeof(T) * n));
|
||||
#endif
|
||||
}
|
||||
|
||||
inline bool custom_dist(const std::pair<uint32_t, float> &a,
|
||||
const std::pair<uint32_t, float> &b) {
|
||||
return a.second < b.second;
|
||||
}
|
||||
|
||||
void compute_l2sq(float *const points_l2sq, const float *const matrix,
|
||||
const int64_t num_points, const int dim) {
|
||||
assert(points_l2sq != NULL);
|
||||
#pragma omp parallel for schedule(static, 65536)
|
||||
for (int64_t d = 0; d < num_points; ++d)
|
||||
points_l2sq[d] = cblas_sdot(dim, matrix + (ptrdiff_t) d * (ptrdiff_t) dim,
|
||||
1, matrix + (ptrdiff_t) d * (ptrdiff_t) dim, 1);
|
||||
}
|
||||
|
||||
void distsq_to_points(
|
||||
const size_t dim,
|
||||
float * dist_matrix, // Col Major, cols are queries, rows are points
|
||||
size_t npoints, const float *const points,
|
||||
const float *const points_l2sq, // points in Col major
|
||||
size_t nqueries, const float *const queries,
|
||||
const float *const queries_l2sq, // queries in Col major
|
||||
float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
|
||||
{
|
||||
bool ones_vec_alloc = false;
|
||||
if (ones_vec == NULL) {
|
||||
ones_vec = new float[nqueries > npoints ? nqueries : npoints];
|
||||
std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float) 1.0);
|
||||
ones_vec_alloc = true;
|
||||
}
|
||||
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim,
|
||||
(float) -2.0, points, dim, queries, dim, (float) 0.0, dist_matrix,
|
||||
npoints);
|
||||
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1,
|
||||
(float) 1.0, points_l2sq, npoints, ones_vec, nqueries,
|
||||
(float) 1.0, dist_matrix, npoints);
|
||||
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1,
|
||||
(float) 1.0, ones_vec, npoints, queries_l2sq, nqueries,
|
||||
(float) 1.0, dist_matrix, npoints);
|
||||
if (ones_vec_alloc)
|
||||
delete[] ones_vec;
|
||||
}
|
||||
|
||||
void exact_knn(const size_t dim, const size_t k,
|
||||
int *const closest_points, // k * num_queries preallocated, col
|
||||
// major, queries columns
|
||||
float *const dist_closest_points, // k * num_queries
|
||||
// preallocated, Dist to
|
||||
// corresponding closes_points
|
||||
size_t npoints,
|
||||
const float *const points, // points in Col major
|
||||
size_t nqueries,
|
||||
const float *const queries) // queries in Col major
|
||||
{
|
||||
float *points_l2sq = new float[npoints];
|
||||
// std::cout<<"jere"<<std::endl;
|
||||
float *queries_l2sq = new float[nqueries];
|
||||
// std::cout<<"jere "<<npoints<<" " <<dim << " " << nqueries <<std::endl;
|
||||
compute_l2sq(points_l2sq, points, npoints, dim);
|
||||
compute_l2sq(queries_l2sq, queries, nqueries, dim);
|
||||
|
||||
size_t q_batch_size = (1 << 9);
|
||||
float *dist_matrix = new float[(size_t) q_batch_size * (size_t) npoints];
|
||||
|
||||
for (_u64 b = 0; b < div_round_up(nqueries, q_batch_size); ++b) {
|
||||
int64_t q_b = b * q_batch_size;
|
||||
int64_t q_e =
|
||||
((b + 1) * q_batch_size > nqueries) ? nqueries : (b + 1) * q_batch_size;
|
||||
|
||||
distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, q_e - q_b,
|
||||
queries + (ptrdiff_t) q_b * (ptrdiff_t) dim,
|
||||
queries_l2sq + q_b);
|
||||
std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")"
|
||||
<< std::endl;
|
||||
|
||||
#pragma omp parallel for schedule(dynamic, 16)
|
||||
for (long long q = q_b; q < q_e; q++) {
|
||||
maxPQIFCS point_dist;
|
||||
for (_u64 p = 0; p < k; p++)
|
||||
point_dist.emplace(
|
||||
p, dist_matrix[(ptrdiff_t) p +
|
||||
(ptrdiff_t)(q - q_b) * (ptrdiff_t) npoints]);
|
||||
for (_u64 p = k; p < npoints; p++) {
|
||||
if (point_dist.top().second >
|
||||
dist_matrix[(ptrdiff_t) p +
|
||||
(ptrdiff_t)(q - q_b) * (ptrdiff_t) npoints])
|
||||
point_dist.emplace(
|
||||
p, dist_matrix[(ptrdiff_t) p +
|
||||
(ptrdiff_t)(q - q_b) * (ptrdiff_t) npoints]);
|
||||
if (point_dist.size() > k)
|
||||
point_dist.pop();
|
||||
}
|
||||
for (ptrdiff_t l = 0; l < (ptrdiff_t) k; ++l) {
|
||||
closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t) q * (ptrdiff_t) k] =
|
||||
point_dist.top().first;
|
||||
dist_closest_points[(ptrdiff_t)(k - 1 - l) +
|
||||
(ptrdiff_t) q * (ptrdiff_t) k] =
|
||||
point_dist.top().second;
|
||||
point_dist.pop();
|
||||
}
|
||||
assert(std::is_sorted(
|
||||
dist_closest_points + (ptrdiff_t) q * (ptrdiff_t) k,
|
||||
dist_closest_points + (ptrdiff_t)(q + 1) * (ptrdiff_t) k));
|
||||
/*std::sort(point_dist.begin(), point_dist.end(),
|
||||
[](const auto &l, const auto &r) {return l.second < r.second; });
|
||||
for (int l = 0; l < k; ++l) {
|
||||
closest_points[(ptrdiff_t)l + (ptrdiff_t)q * (ptrdiff_t)k] =
|
||||
point_dist[l].first;
|
||||
dist_closest_points[(ptrdiff_t)l + (ptrdiff_t)q * (ptrdiff_t)k] =
|
||||
point_dist[l].second;
|
||||
}*/
|
||||
}
|
||||
std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e
|
||||
<< ")" << std::endl;
|
||||
}
|
||||
|
||||
delete[] dist_matrix;
|
||||
|
||||
delete[] points_l2sq;
|
||||
delete[] queries_l2sq;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline int get_num_parts(const char *filename) {
|
||||
std::ifstream reader(filename, std::ios::binary);
|
||||
std::cout << "Reading bin file " << filename << " ...\n";
|
||||
int npts_i32, ndims_i32;
|
||||
reader.read((char *) &npts_i32, sizeof(int));
|
||||
reader.read((char *) &ndims_i32, sizeof(int));
|
||||
std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl;
|
||||
reader.close();
|
||||
int num_parts = (npts_i32 % PARTSIZE) == 0
|
||||
? npts_i32 / PARTSIZE
|
||||
: std::floor(npts_i32 / PARTSIZE) + 1;
|
||||
std::cout << "Number of parts: " << num_parts << std::endl;
|
||||
return num_parts;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline void load_bin_as_float(const char *filename, float *&data, size_t &npts,
|
||||
size_t &ndims, int part_num) {
|
||||
std::ifstream reader(filename, std::ios::binary);
|
||||
std::cout << "Reading bin file " << filename << " ...\n";
|
||||
int npts_i32, ndims_i32;
|
||||
reader.read((char *) &npts_i32, sizeof(int));
|
||||
reader.read((char *) &ndims_i32, sizeof(int));
|
||||
uint64_t start_id = part_num * PARTSIZE;
|
||||
uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t) npts_i32);
|
||||
npts = end_id - start_id;
|
||||
ndims = (unsigned) ndims_i32;
|
||||
uint64_t nptsuint64_t = (uint64_t) npts;
|
||||
uint64_t ndimsuint64_t = (uint64_t) ndims;
|
||||
std::cout << "#pts in part = " << npts << ", #dims = " << ndims
|
||||
<< ", size = " << nptsuint64_t * ndimsuint64_t * sizeof(T) << "B"
|
||||
<< std::endl;
|
||||
|
||||
reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t),
|
||||
std::ios::beg);
|
||||
// data = new T[nptsuint64_t * ndimsuint64_t];
|
||||
T *data_T = new T[nptsuint64_t * ndimsuint64_t];
|
||||
reader.read((char *) data_T, sizeof(T) * nptsuint64_t * ndimsuint64_t);
|
||||
std::cout << "Finished reading part of the bin file." << std::endl;
|
||||
reader.close();
|
||||
// data = (nptsuint64_t*ndimsuint64_t, ALIGNMENT);
|
||||
data = aligned_malloc<float>(nptsuint64_t * ndimsuint64_t, ALIGNMENT);
|
||||
#pragma omp parallel for schedule(dynamic, 32768)
|
||||
for (int64_t i = 0; i < (int64_t) nptsuint64_t; i++) {
|
||||
for (int64_t j = 0; j < (int64_t) ndimsuint64_t; j++) {
|
||||
float cur_val_float = (float) data_T[i * ndimsuint64_t + j];
|
||||
std::memcpy((char *) (data + i * ndimsuint64_t + j),
|
||||
(char *) &cur_val_float, sizeof(float));
|
||||
}
|
||||
}
|
||||
delete[] data_T;
|
||||
std::cout << "Finished converting part data to float." << std::endl;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline void save_bin(const std::string filename, T *data, size_t npts,
|
||||
size_t ndims) {
|
||||
std::ofstream writer(filename, std::ios::binary | std::ios::out);
|
||||
std::cout << "Writing bin: " << filename << "\n";
|
||||
int npts_i32 = (int) npts, ndims_i32 = (int) ndims;
|
||||
writer.write((char *) &npts_i32, sizeof(int));
|
||||
writer.write((char *) &ndims_i32, sizeof(int));
|
||||
std::cout << "bin: #pts = " << npts << ", #dims = " << ndims
|
||||
<< ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int) << "B"
|
||||
<< std::endl;
|
||||
|
||||
// data = new T[npts_u64 * ndims_u64];
|
||||
writer.write((char *) data, npts * ndims * sizeof(T));
|
||||
writer.close();
|
||||
std::cout << "Finished writing bin" << std::endl;
|
||||
}
|
||||
|
||||
inline void save_groundtruth_as_one_file(const std::string filename,
|
||||
int32_t *data, float *distances,
|
||||
size_t npts, size_t ndims) {
|
||||
std::ofstream writer(filename, std::ios::binary | std::ios::out);
|
||||
int npts_i32 = (int) npts, ndims_i32 = (int) ndims;
|
||||
writer.write((char *) &npts_i32, sizeof(int));
|
||||
writer.write((char *) &ndims_i32, sizeof(int));
|
||||
std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, "
|
||||
"npts*dim dist-matrix) with npts = "
|
||||
<< npts << ", dim = " << ndims << ", size = "
|
||||
<< 2 * npts * ndims * sizeof(unsigned) + 2 * sizeof(int) << "B"
|
||||
<< std::endl;
|
||||
|
||||
// data = new T[npts_u64 * ndims_u64];
|
||||
writer.write((char *) data, npts * ndims * sizeof(uint32_t));
|
||||
writer.write((char *) distances, npts * ndims * sizeof(float));
|
||||
writer.close();
|
||||
std::cout << "Finished writing truthset" << std::endl;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
int aux_main(int argv, char **argc) {
|
||||
size_t npoints, nqueries, dim;
|
||||
std::string base_file(argc[2]);
|
||||
std::string query_file(argc[3]);
|
||||
size_t k = atoi(argc[4]);
|
||||
std::string gt_file(argc[5]);
|
||||
|
||||
float *base_data;
|
||||
float *query_data;
|
||||
|
||||
int num_parts = get_num_parts<T>(base_file.c_str());
|
||||
load_bin_as_float<T>(query_file.c_str(), query_data, nqueries, dim, 0);
|
||||
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> results(nqueries);
|
||||
|
||||
int * closest_points = new int[nqueries * k];
|
||||
float *dist_closest_points = new float[nqueries * k];
|
||||
|
||||
for (int p = 0; p < num_parts; p++) {
|
||||
size_t start_id = p * PARTSIZE;
|
||||
load_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p);
|
||||
int * closest_points_part = new int[nqueries * k];
|
||||
float *dist_closest_points_part = new float[nqueries * k];
|
||||
|
||||
exact_knn(dim, k, closest_points_part, dist_closest_points_part, npoints,
|
||||
base_data, nqueries, query_data);
|
||||
|
||||
for (_u64 i = 0; i < nqueries; i++) {
|
||||
for (_u64 j = 0; j < k; j++) {
|
||||
results[i].push_back(std::make_pair(
|
||||
(uint32_t)(closest_points_part[i * k + j] + start_id),
|
||||
dist_closest_points_part[i * k + j]));
|
||||
}
|
||||
}
|
||||
|
||||
delete[] closest_points_part;
|
||||
delete[] dist_closest_points_part;
|
||||
diskann::aligned_free(base_data);
|
||||
}
|
||||
|
||||
for (_u64 i = 0; i < nqueries; i++) {
|
||||
std::vector<std::pair<uint32_t, float>> &cur_res = results[i];
|
||||
std::sort(cur_res.begin(), cur_res.end(), custom_dist);
|
||||
for (_u64 j = 0; j < k; j++) {
|
||||
closest_points[i * k + j] = (int32_t) cur_res[j].first;
|
||||
dist_closest_points[i * k + j] = cur_res[j].second;
|
||||
}
|
||||
}
|
||||
|
||||
// save_bin<int>(gt_file + std::string("_ids.bin"), closest_points, nqueries,
|
||||
// k);
|
||||
// save_bin<float>(gt_file + std::string("_dist.bin"), dist_closest_points,
|
||||
// nqueries, k);
|
||||
save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points,
|
||||
nqueries, k);
|
||||
diskann::aligned_free(query_data);
|
||||
delete[] closest_points;
|
||||
delete[] dist_closest_points;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
if (argc != 6) {
|
||||
command_line_help();
|
||||
return -1;
|
||||
}
|
||||
if (std::string(argv[1]) == std::string("float"))
|
||||
aux_main<float>(argc, argv);
|
||||
if (std::string(argv[1]) == std::string("int8"))
|
||||
aux_main<int8_t>(argc, argv);
|
||||
if (std::string(argv[1]) == std::string("uint8"))
|
||||
aux_main<uint8_t>(argc, argv);
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
void block_convert(std::ifstream& reader, std::ofstream& writer,
|
||||
float* read_buf, float* write_buf, _u64 npts, _u64 ndims) {
|
||||
reader.read((char*) read_buf,
|
||||
npts * (ndims * sizeof(float) + sizeof(unsigned)));
|
||||
for (_u64 i = 0; i < npts; i++) {
|
||||
memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1,
|
||||
ndims * sizeof(float));
|
||||
}
|
||||
writer.write((char*) write_buf, npts * ndims * sizeof(float));
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc != 3) {
|
||||
std::cout << argv[0] << " input_fvecs output_bin" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
std::ifstream reader(argv[1], std::ios::binary | std::ios::ate);
|
||||
_u64 fsize = reader.tellg();
|
||||
reader.seekg(0, std::ios::beg);
|
||||
|
||||
unsigned ndims_u32;
|
||||
reader.read((char*) &ndims_u32, sizeof(unsigned));
|
||||
reader.seekg(0, std::ios::beg);
|
||||
_u64 ndims = (_u64) ndims_u32;
|
||||
_u64 npts = fsize / ((ndims + 1) * sizeof(float));
|
||||
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims
|
||||
<< std::endl;
|
||||
|
||||
_u64 blk_size = 131072;
|
||||
_u64 nblks = ROUND_UP(npts, blk_size) / blk_size;
|
||||
std::cout << "# blks: " << nblks << std::endl;
|
||||
std::ofstream writer(argv[2], std::ios::binary);
|
||||
int npts_s32 = (_s32) npts;
|
||||
int ndims_s32 = (_s32) ndims;
|
||||
writer.write((char*) &npts_s32, sizeof(_s32));
|
||||
writer.write((char*) &ndims_s32, sizeof(_s32));
|
||||
float* read_buf = new float[npts * (ndims + 1)];
|
||||
float* write_buf = new float[npts * ndims];
|
||||
for (_u64 i = 0; i < nblks; i++) {
|
||||
_u64 cblk_size = std::min(npts - i * blk_size, blk_size);
|
||||
block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims);
|
||||
std::cout << "Block #" << i << " written" << std::endl;
|
||||
}
|
||||
|
||||
delete[] read_buf;
|
||||
delete[] write_buf;
|
||||
|
||||
reader.close();
|
||||
writer.close();
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <omp.h>
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <ctime>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include "partition_and_pq.h"
|
||||
#include "utils.h"
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <sys/stat.h>
|
||||
#include <time.h>
|
||||
#include <typeinfo>
|
||||
|
||||
template<typename T>
|
||||
int aux_main(int argc, char** argv) {
|
||||
std::string base_file(argv[2]);
|
||||
std::string output_prefix(argv[3]);
|
||||
float sampling_rate = (float) (std::atof(argv[4]));
|
||||
gen_random_slice<T>(base_file, output_prefix, sampling_rate);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc != 5) {
|
||||
std::cout << argv[0] << " data_type [fliat/int8/uint8] base_bin_file "
|
||||
"sample_output_prefix sampling_probability"
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (std::string(argv[1]) == std::string("float")) {
|
||||
aux_main<float>(argc, argv);
|
||||
} else if (std::string(argv[1]) == std::string("int8")) {
|
||||
aux_main<int8_t>(argc, argv);
|
||||
} else if (std::string(argv[1]) == std::string("uint8")) {
|
||||
aux_main<uint8_t>(argc, argv);
|
||||
} else
|
||||
std::cout << "Unsupported type. Use float/int8/uint8." << std::endl;
|
||||
return 0;
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "math_utils.h"
|
||||
#include "partition_and_pq.h"
|
||||
|
||||
#define KMEANS_ITERS_FOR_PQ 15
|
||||
|
||||
template<typename T>
|
||||
bool generate_pq(const std::string& data_path,
|
||||
const std::string& index_prefix_path,
|
||||
const size_t num_pq_centers, const size_t num_pq_chunks,
|
||||
const float sampling_rate) {
|
||||
std::string pq_pivots_path = index_prefix_path + "_pq_pivots.bin";
|
||||
std::string pq_compressed_vectors_path =
|
||||
index_prefix_path + "_compressed.bin";
|
||||
|
||||
// generates random sample and sets it to train_data and updates train_size
|
||||
size_t train_size, train_dim;
|
||||
float* train_data;
|
||||
gen_random_slice<T>(data_path, sampling_rate, train_data, train_size,
|
||||
train_dim);
|
||||
std::cout << "For computing pivots, loaded sample data of size " << train_size
|
||||
<< std::endl;
|
||||
|
||||
generate_pq_pivots(train_data, train_size, train_dim, num_pq_centers,
|
||||
num_pq_chunks, KMEANS_ITERS_FOR_PQ, pq_pivots_path);
|
||||
generate_pq_data_from_pivots<T>(data_path, num_pq_centers, num_pq_chunks,
|
||||
pq_pivots_path, pq_compressed_vectors_path);
|
||||
|
||||
delete[] train_data;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc != 6) {
|
||||
std::cout
|
||||
<< "Usage: \n"
|
||||
<< argv[0]
|
||||
<< " <data_type[float/uint8/int8]> <data_file[.bin]>"
|
||||
" <PQ_prefix_path> <target-bytes/data-point> <sampling_rate>"
|
||||
<< std::endl;
|
||||
} else {
|
||||
const std::string data_path(argv[2]);
|
||||
const std::string index_prefix_path(argv[3]);
|
||||
const size_t num_pq_centers = 256;
|
||||
const size_t num_pq_chunks = (size_t) atoi(argv[4]);
|
||||
const float sampling_rate = atof(argv[5]);
|
||||
|
||||
if (std::string(argv[1]) == std::string("float"))
|
||||
generate_pq<float>(data_path, index_prefix_path, num_pq_centers,
|
||||
num_pq_chunks, sampling_rate);
|
||||
else if (std::string(argv[1]) == std::string("int8"))
|
||||
generate_pq<int8_t>(data_path, index_prefix_path, num_pq_centers,
|
||||
num_pq_chunks, sampling_rate);
|
||||
else if (std::string(argv[1]) == std::string("uint8"))
|
||||
generate_pq<uint8_t>(data_path, index_prefix_path, num_pq_centers,
|
||||
num_pq_chunks, sampling_rate);
|
||||
else
|
||||
std::cout << "Error. wrong file type" << std::endl;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc != 3) {
|
||||
std::cout << argv[0] << " input_int8_bin output_float_bin" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int8_t* input;
|
||||
size_t npts, nd;
|
||||
diskann::load_bin<int8_t>(argv[1], input, npts, nd);
|
||||
float* output = new float[npts * nd];
|
||||
diskann::convert_types<int8_t, float>(input, output, npts, nd);
|
||||
diskann::save_bin<float>(argv[2], output, npts, nd);
|
||||
delete[] output;
|
||||
delete[] input;
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
void block_convert(std::ifstream& reader, std::ofstream& writer, _u32* read_buf,
|
||||
_u32* write_buf, _u64 npts, _u64 ndims) {
|
||||
reader.read((char*) read_buf,
|
||||
npts * (ndims * sizeof(_u32) + sizeof(unsigned)));
|
||||
for (_u64 i = 0; i < npts; i++) {
|
||||
memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1,
|
||||
ndims * sizeof(_u32));
|
||||
}
|
||||
writer.write((char*) write_buf, npts * ndims * sizeof(_u32));
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc != 3) {
|
||||
std::cout << argv[0] << " input_ivecs output_bin" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
std::ifstream reader(argv[1], std::ios::binary | std::ios::ate);
|
||||
_u64 fsize = reader.tellg();
|
||||
reader.seekg(0, std::ios::beg);
|
||||
|
||||
unsigned ndims_u32;
|
||||
reader.read((char*) &ndims_u32, sizeof(unsigned));
|
||||
reader.seekg(0, std::ios::beg);
|
||||
_u64 ndims = (_u64) ndims_u32;
|
||||
_u64 npts = fsize / ((ndims + 1) * sizeof(_u32));
|
||||
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims
|
||||
<< std::endl;
|
||||
|
||||
_u64 blk_size = 131072;
|
||||
_u64 nblks = ROUND_UP(npts, blk_size) / blk_size;
|
||||
std::cout << "# blks: " << nblks << std::endl;
|
||||
std::ofstream writer(argv[2], std::ios::binary);
|
||||
int npts_s32 = (_s32) npts;
|
||||
int ndims_s32 = (_s32) ndims;
|
||||
writer.write((char*) &npts_s32, sizeof(_s32));
|
||||
writer.write((char*) &ndims_s32, sizeof(_s32));
|
||||
_u32* read_buf = new _u32[npts * (ndims + 1)];
|
||||
_u32* write_buf = new _u32[npts * ndims];
|
||||
for (_u64 i = 0; i < nblks; i++) {
|
||||
_u64 cblk_size = std::min(npts - i * blk_size, blk_size);
|
||||
block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims);
|
||||
std::cout << "Block #" << i << " written" << std::endl;
|
||||
}
|
||||
|
||||
delete[] read_buf;
|
||||
delete[] write_buf;
|
||||
|
||||
reader.close();
|
||||
writer.close();
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "aux_utils.h"
|
||||
#include "cached_io.h"
|
||||
#include "utils.h"
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
if (argc != 9) {
|
||||
std::cout
|
||||
<< argv[0]
|
||||
<< " vamana_index_prefix[1] vamana_index_suffix[2] idmaps_prefix[3] "
|
||||
"idmaps_suffix[4] n_shards[5] max_degree[6] output_vamana_path[7] "
|
||||
"output_medoids_path[8]"
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
std::string vamana_prefix(argv[1]);
|
||||
std::string vamana_suffix(argv[2]);
|
||||
std::string idmaps_prefix(argv[3]);
|
||||
std::string idmaps_suffix(argv[4]);
|
||||
_u64 nshards = (_u64) std::atoi(argv[5]);
|
||||
_u32 max_degree = (_u64) std::atoi(argv[6]);
|
||||
std::string output_index(argv[7]);
|
||||
std::string output_medoids(argv[8]);
|
||||
|
||||
return diskann::merge_shards(vamana_prefix, vamana_suffix, idmaps_prefix,
|
||||
idmaps_suffix, nshards, max_degree, output_index,
|
||||
output_medoids);
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc != 3) {
|
||||
std::cout << argv[0] << " input_uint32_bin output_int8_bin" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
uint32_t* input;
|
||||
size_t npts, nd;
|
||||
diskann::load_bin<uint32_t>(argv[1], input, npts, nd);
|
||||
uint8_t* output = new uint8_t[npts * nd];
|
||||
diskann::convert_types<uint32_t, uint8_t>(input, output, npts, nd);
|
||||
diskann::save_bin<uint8_t>(argv[2], output, npts, nd);
|
||||
delete[] output;
|
||||
delete[] input;
|
||||
}
|